starcrypto/asymm/rsa.go

339 lines
8.6 KiB
Go

package asymm
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"math/big"
"github.com/emmansun/gmsm/pkcs8"
"golang.org/x/crypto/ssh"
)
// GenerateRsaKey creates a fresh RSA key pair whose modulus is bits long,
// drawing randomness from crypto/rand. The returned public key is a pointer to
// the PublicKey embedded in the returned private key, not a separate copy.
func GenerateRsaKey(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) {
	key, genErr := rsa.GenerateKey(rand.Reader, bits)
	if genErr == nil {
		return key, &key.PublicKey, nil
	}
	return nil, nil, genErr
}
// EncodeRsaPrivateKey PEM-encodes private in the legacy format, optionally
// encrypting it with secret. It is shorthand for
// EncodeRsaPrivateKeyWithLegacy(private, secret, true).
func EncodeRsaPrivateKey(private *rsa.PrivateKey, secret string) ([]byte, error) {
	const useLegacyFormat = true
	return EncodeRsaPrivateKeyWithLegacy(private, secret, useLegacyFormat)
}
func EncodeRsaPrivateKeyWithLegacy(private *rsa.PrivateKey, secret string, legacy bool) ([]byte, error) {
if legacy {
return encodeRsaPrivateKeyLegacy(private, secret)
}
return EncodeRsaPrivateKeyPKCS8(private, secret)
}
// EncodeRsaPrivateKeyPKCS8 serializes private as PKCS#8 DER inside a PEM block.
//
// With an empty secret, the key is stored unencrypted under the block type
// "PRIVATE KEY". Otherwise it is encrypted with secret using
// pkcs8.DefaultOpts and stored under "ENCRYPTED PRIVATE KEY".
func EncodeRsaPrivateKeyPKCS8(private *rsa.PrivateKey, secret string) ([]byte, error) {
	if secret == "" {
		plainDER, err := pkcs8.MarshalPrivateKey(private, nil, nil)
		if err != nil {
			return nil, err
		}
		return pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: plainDER}), nil
	}

	sealedDER, err := pkcs8.MarshalPrivateKey(private, []byte(secret), pkcs8.DefaultOpts)
	if err != nil {
		return nil, err
	}
	return pem.EncodeToMemory(&pem.Block{Type: "ENCRYPTED PRIVATE KEY", Bytes: sealedDER}), nil
}
// encodeRsaPrivateKeyLegacy writes private as a PKCS#1 DER key inside an
// "RSA PRIVATE KEY" PEM block. When secret is non-empty, the DER is encrypted
// with AES-256 through x509.EncryptPEMBlock (RFC 1423 PEM encryption).
//
// NOTE: x509.EncryptPEMBlock is deprecated in the standard library because
// RFC 1423 encryption does not authenticate the ciphertext. It is kept here
// only for compatibility with consumers of this legacy format.
func encodeRsaPrivateKeyLegacy(private *rsa.PrivateKey, secret string) ([]byte, error) {
	const pemType = "RSA PRIVATE KEY"

	der := x509.MarshalPKCS1PrivateKey(private)
	block := &pem.Block{Type: pemType, Bytes: der}
	if secret != "" {
		encrypted, err := x509.EncryptPEMBlock(rand.Reader, pemType, der, []byte(secret), x509.PEMCipherAES256)
		if err != nil {
			return nil, err
		}
		block = encrypted
	}
	return pem.EncodeToMemory(block), nil
}
// EncodeRsaPublicKey marshals public as a PKIX (SubjectPublicKeyInfo) DER
// structure and wraps it in a "PUBLIC KEY" PEM block.
func EncodeRsaPublicKey(public *rsa.PublicKey) ([]byte, error) {
	der, err := x509.MarshalPKIXPublicKey(public)
	if err != nil {
		return nil, err
	}
	block := pem.Block{Type: "PUBLIC KEY", Bytes: der}
	return pem.EncodeToMemory(&block), nil
}
// DecodeRsaPrivateKey parses a PEM-encoded RSA private key and detects its
// format from the first PEM block's type.
//
// "PRIVATE KEY" and "ENCRYPTED PRIVATE KEY" blocks go to
// DecodeRsaPrivateKeyPKCS8. Any other type goes to the legacy decoder.
// password is forwarded unchanged to whichever decoder runs.
func DecodeRsaPrivateKey(private []byte, password string) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode(private)
	if block == nil {
		return nil, errors.New("private key error")
	}
	if block.Type == "PRIVATE KEY" || block.Type == "ENCRYPTED PRIVATE KEY" {
		return DecodeRsaPrivateKeyPKCS8(private, password)
	}
	return decodeRsaPrivateKeyLegacy(private, password)
}
func DecodeRsaPrivateKeyWithLegacy(private []byte, password string, legacy bool) (*rsa.PrivateKey, error) {
if legacy {
return decodeRsaPrivateKeyLegacy(private, password)
}
return DecodeRsaPrivateKeyPKCS8(private, password)
}
// DecodeRsaPrivateKeyPKCS8 parses the first PEM block of private as a PKCS#8
// RSA private key.
//
// "PRIVATE KEY" blocks are parsed as plaintext. "ENCRYPTED PRIVATE KEY"
// blocks are decrypted with password, which must then be non-empty. Any
// other block type is rejected as not PKCS#8.
func DecodeRsaPrivateKeyPKCS8(private []byte, password string) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode(private)
	if block == nil {
		return nil, errors.New("private key error")
	}
	if block.Type == "PRIVATE KEY" {
		return pkcs8.ParsePKCS8PrivateKeyRSA(block.Bytes)
	}
	if block.Type != "ENCRYPTED PRIVATE KEY" {
		return nil, errors.New("private key is not PKCS#8")
	}
	if password == "" {
		return nil, errors.New("private key is encrypted but password is empty")
	}
	return pkcs8.ParsePKCS8PrivateKeyRSA(block.Bytes, []byte(password))
}
// decodeRsaPrivateKeyLegacy parses the first PEM block of private in the
// traditional format.
//
// The block is first passed with password through decodePEMBlockBytes.
// NOTE(review): that helper is defined elsewhere and presumably decrypts
// RFC 1423 encrypted blocks; confirm. The resulting DER is tried as PKCS#1
// first. If that fails, it is tried as unencrypted PKCS#8, and a PKCS#8 key
// that is not RSA is rejected.
func decodeRsaPrivateKeyLegacy(private []byte, password string) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode(private)
	if block == nil {
		return nil, errors.New("private key error")
	}
	der, err := decodePEMBlockBytes(block, password)
	if err != nil {
		return nil, err
	}

	key, pkcs1Err := x509.ParsePKCS1PrivateKey(der)
	if pkcs1Err == nil {
		return key, nil
	}

	// Not PKCS#1: fall back to PKCS#8. If that also fails, its error is reported.
	parsed, err := x509.ParsePKCS8PrivateKey(der)
	if err != nil {
		return nil, err
	}
	if rsaKey, ok := parsed.(*rsa.PrivateKey); ok {
		return rsaKey, nil
	}
	return nil, errors.New("private key is not RSA")
}
// DecodeRsaPublicKey parses the first PEM block of pubStr as an RSA public key.
//
// The DER is first parsed as PKIX (SubjectPublicKeyInfo, the "PUBLIC KEY"
// form produced by x509.MarshalPKIXPublicKey). If that fails, it is parsed as
// a PKCS#1 RSAPublicKey (the "RSA PUBLIC KEY" form produced by
// x509.MarshalPKCS1PublicKey). This mirrors how the private-key decoder falls
// back between formats.
//
// If neither parser accepts the data, the PKIX parse error is returned. A
// valid PKIX key of another algorithm yields "public key is not RSA".
func DecodeRsaPublicKey(pubStr []byte) (*rsa.PublicKey, error) {
	blk, _ := pem.Decode(pubStr)
	if blk == nil {
		return nil, errors.New("public key error")
	}
	pub, err := x509.ParsePKIXPublicKey(blk.Bytes)
	if err != nil {
		// PKCS#1 public keys are not valid PKIX, so accept them here instead of
		// failing. The PKIX error is kept for callers if this also fails.
		if rsapub, pkcs1Err := x509.ParsePKCS1PublicKey(blk.Bytes); pkcs1Err == nil {
			return rsapub, nil
		}
		return nil, err
	}
	rsapub, ok := pub.(*rsa.PublicKey)
	if !ok {
		return nil, errors.New("public key is not RSA")
	}
	return rsapub, nil
}
// EncodeRsaSSHPublicKey converts public into an ssh.PublicKey and serializes
// it with ssh.MarshalAuthorizedKey, i.e. as a line suitable for an OpenSSH
// authorized_keys file.
func EncodeRsaSSHPublicKey(public *rsa.PublicKey) ([]byte, error) {
	sshKey, err := ssh.NewPublicKey(public)
	if err == nil {
		return ssh.MarshalAuthorizedKey(sshKey), nil
	}
	return nil, err
}
// GenerateRsaSSHKeyPair is GenerateRsaSSHKeyPairWithLegacy with legacy=true.
// It returns (private PEM, authorized_keys line, error).
func GenerateRsaSSHKeyPair(bits int, secret string) (string, string, error) {
	const legacy = true
	return GenerateRsaSSHKeyPairWithLegacy(bits, secret, legacy)
}
// GenerateRsaSSHKeyPairWithLegacy generates a bits-sized RSA key pair and
// returns it as (private key PEM, OpenSSH public key, error).
//
// The private key is encoded by EncodeRsaPrivateKeyWithLegacy, so legacy
// selects the container format and a non-empty secret encrypts it. The public
// key is encoded by EncodeRsaSSHPublicKey. On any error, both strings are empty.
func GenerateRsaSSHKeyPairWithLegacy(bits int, secret string, legacy bool) (string, string, error) {
	privKey, pubKey, err := GenerateRsaKey(bits)
	if err != nil {
		return "", "", err
	}

	var sshPub, privPEM []byte
	if sshPub, err = EncodeRsaSSHPublicKey(pubKey); err != nil {
		return "", "", err
	}
	if privPEM, err = EncodeRsaPrivateKeyWithLegacy(privKey, secret, legacy); err != nil {
		return "", "", err
	}
	return string(privPEM), string(sshPub), nil
}
// RSAEncrypt encrypts data for pub using RSAES-PKCS1-v1_5, with crypto/rand as
// the padding randomness. The crypto/rsa documentation recommends OAEP
// (RSAEncryptOAEP) for new protocols.
func RSAEncrypt(pub *rsa.PublicKey, data []byte) ([]byte, error) {
	random := rand.Reader
	return rsa.EncryptPKCS1v15(random, pub, data)
}
// RSADecrypt decrypts an RSAES-PKCS1-v1_5 ciphertext (as produced by
// RSAEncrypt) with prikey. crypto/rand is passed as the random source.
func RSADecrypt(prikey *rsa.PrivateKey, data []byte) ([]byte, error) {
	random := rand.Reader
	return rsa.DecryptPKCS1v15(random, prikey, data)
}
// RSAEncryptOAEP encrypts data for pub with RSAES-OAEP.
//
// hashType selects the OAEP hash, resolved by normalizeModernRSAHash, where
// zero means SHA-256. label is passed to rsa.EncryptOAEP unchanged.
// Decryption must use the same hash and label.
func RSAEncryptOAEP(pub *rsa.PublicKey, data, label []byte, hashType crypto.Hash) ([]byte, error) {
	h, err := normalizeModernRSAHash(hashType)
	if err != nil {
		return nil, err
	}
	return rsa.EncryptOAEP(h.New(), rand.Reader, pub, data, label)
}
// RSADecryptOAEP decrypts an RSAES-OAEP ciphertext with prikey.
//
// hashType and label must match those used for encryption. hashType is
// resolved by normalizeModernRSAHash, where zero means SHA-256.
func RSADecryptOAEP(prikey *rsa.PrivateKey, data, label []byte, hashType crypto.Hash) ([]byte, error) {
	h, err := normalizeModernRSAHash(hashType)
	if err != nil {
		return nil, err
	}
	return rsa.DecryptOAEP(h.New(), rand.Reader, prikey, data, label)
}
// RSASign produces an RSASSA-PKCS1-v1_5 signature over msg.
//
// priKey is a PEM private key decoded by DecodeRsaPrivateKey, which receives
// password. msg is first digested with hashType via hashMessage. With
// crypto.Hash(0), msg is signed as-is, without a DigestInfo prefix.
func RSASign(msg, priKey []byte, password string, hashType crypto.Hash) ([]byte, error) {
	key, err := DecodeRsaPrivateKey(priKey, password)
	if err != nil {
		return nil, err
	}
	digest, err := hashMessage(msg, hashType)
	if err != nil {
		return nil, err
	}
	return rsa.SignPKCS1v15(rand.Reader, key, hashType, digest)
}
// RSAVerify checks an RSASSA-PKCS1-v1_5 signature sig over msg.
//
// pubKey is a PEM public key decoded by DecodeRsaPublicKey. msg is digested
// with hashType via hashMessage, or used as-is for crypto.Hash(0). A nil
// return means the signature is valid.
func RSAVerify(sig, msg, pubKey []byte, hashType crypto.Hash) error {
	key, err := DecodeRsaPublicKey(pubKey)
	if err != nil {
		return err
	}
	digest, err := hashMessage(msg, hashType)
	if err != nil {
		return err
	}
	return rsa.VerifyPKCS1v15(key, hashType, digest, sig)
}
// RSASignPSS produces an RSASSA-PSS signature over msg.
//
// priKey is a PEM private key decoded by DecodeRsaPrivateKey with password.
// hashType is resolved by normalizeModernRSAHash (zero means SHA-256) and is
// used both to digest msg and as the PSS hash. opts is passed to rsa.SignPSS
// unchanged.
//
// NOTE(review): per crypto/rsa docs, a non-zero opts.Hash overrides the hash
// given to SignPSS, so callers should leave it zero or equal to hashType.
func RSASignPSS(msg, priKey []byte, password string, hashType crypto.Hash, opts *rsa.PSSOptions) ([]byte, error) {
	key, err := DecodeRsaPrivateKey(priKey, password)
	if err != nil {
		return nil, err
	}
	h, err := normalizeModernRSAHash(hashType)
	if err != nil {
		return nil, err
	}
	digest, err := hashMessage(msg, h)
	if err != nil {
		return nil, err
	}
	return rsa.SignPSS(rand.Reader, key, h, digest, opts)
}
// RSAVerifyPSS checks an RSASSA-PSS signature sig over msg.
//
// pubKey is a PEM public key decoded by DecodeRsaPublicKey. hashType is
// resolved by normalizeModernRSAHash (zero means SHA-256) and used both to
// digest msg and as the PSS hash. opts is passed to rsa.VerifyPSS unchanged.
// A nil return means the signature is valid.
func RSAVerifyPSS(sig, msg, pubKey []byte, hashType crypto.Hash, opts *rsa.PSSOptions) error {
	key, err := DecodeRsaPublicKey(pubKey)
	if err != nil {
		return err
	}
	h, err := normalizeModernRSAHash(hashType)
	if err != nil {
		return err
	}
	digest, err := hashMessage(msg, h)
	if err != nil {
		return err
	}
	return rsa.VerifyPSS(key, h, digest, sig, opts)
}
// RSAEncryptByPrivkey applies the PKCS#1 v1.5 private-key operation to data
// directly, with no hashing and no DigestInfo prefix ("encrypt with private
// key"). The output can be reversed with the public key (RSADecryptByPubkey).
//
// rsa.SignPKCS1v15 rejects data longer than the modulus size minus 11 bytes.
func RSAEncryptByPrivkey(priv *rsa.PrivateKey, data []byte) ([]byte, error) {
	// A zero hash tells SignPKCS1v15 to pad and sign data as-is.
	var noHash crypto.Hash
	return rsa.SignPKCS1v15(nil, priv, noHash, data)
}
// RSADecryptByPubkey reverses RSAEncryptByPrivkey. It applies the raw RSA
// public-key operation m = c^e mod n to data, left-pads the result to the
// modulus length, and strips the PKCS#1 v1.5 type-1 padding, returning the
// embedded payload.
//
// Per RFC 8017 (RSAVP1), the input representative c must lie in [0, n-1].
// Larger values are rejected rather than silently reduced mod n, so c and
// c+n are not both accepted. A nil or structurally invalid public key yields
// an error instead of a panic.
func RSADecryptByPubkey(pub *rsa.PublicKey, data []byte) ([]byte, error) {
	if pub == nil || pub.N == nil || pub.N.Sign() <= 0 || pub.E < 2 {
		return nil, errors.New("invalid RSA public key")
	}
	c := new(big.Int).SetBytes(data)
	if c.Cmp(pub.N) >= 0 {
		return nil, errors.New("invalid RSA block: input out of range")
	}
	m := new(big.Int).Exp(c, big.NewInt(int64(pub.E)), pub.N)
	em := leftPad(m.Bytes(), (pub.N.BitLen()+7)/8)
	return unLeftPad(em)
}
// normalizeModernRSAHash resolves the hash used for OAEP and PSS.
//
// The zero value selects crypto.SHA256. A hash whose implementation is not
// linked into the binary (Available() reports false) is rejected with 0 and
// an error. No other restriction is placed on the choice of hash.
func normalizeModernRSAHash(hashType crypto.Hash) (crypto.Hash, error) {
	resolved := hashType
	if resolved == 0 {
		resolved = crypto.SHA256
	}
	if resolved.Available() {
		return resolved, nil
	}
	return 0, errors.New("hash function is not available")
}
// hashMessage returns the digest of msg under hashType.
//
// crypto.Hash(0) is a pass-through that returns msg itself (the same slice,
// not a copy). A hash whose implementation is not linked in yields an error.
func hashMessage(msg []byte, hashType crypto.Hash) ([]byte, error) {
	switch {
	case hashType == 0:
		return msg, nil
	case !hashType.Available():
		return nil, errors.New("hash function is not available")
	}
	hasher := hashType.New()
	if _, err := hasher.Write(msg); err != nil {
		return nil, err
	}
	return hasher.Sum(nil), nil
}
// leftPad returns a newly allocated slice of exactly size bytes, with input
// right-aligned and the left side zero-filled. If input is longer than size,
// only its first size bytes are kept. The result never aliases input.
func leftPad(input []byte, size int) []byte {
	padded := make([]byte, size)
	keep := len(input)
	if keep > size {
		keep = size
	}
	copy(padded[size-keep:], input)
	return padded
}
// unLeftPad strips PKCS#1 v1.5 block-type-1 (signature) padding and returns
// the payload as a freshly allocated slice.
//
// The expected layout, per RFC 8017 section 9.2 (EMSA-PKCS1-v1_5), is
//
//	0x00 || 0x01 || PS || 0x00 || M
//
// where PS is at least 8 bytes of 0xFF. A block is rejected if it has a
// wrong header, has no 0x00 separator after the 0xFF run, or has a padding
// string shorter than 8 bytes.
func unLeftPad(input []byte) ([]byte, error) {
	// minPadLen is the minimum PS length mandated by RFC 8017 for type-1 blocks.
	const minPadLen = 8

	if len(input) < 3 {
		return nil, errors.New("invalid RSA block")
	}
	if input[0] != 0x00 || input[1] != 0x01 {
		return nil, errors.New("invalid RSA block header")
	}
	i := 2
	for i < len(input) && input[i] == 0xff {
		i++
	}
	// i now indexes the first non-0xFF byte. It must be the 0x00 separator,
	// and at least minPadLen 0xFF bytes must precede it.
	if i >= len(input) || input[i] != 0x00 || i-2 < minPadLen {
		return nil, errors.New("invalid RSA block padding")
	}
	payload := input[i+1:]
	out := make([]byte, len(payload))
	copy(out, payload)
	return out, nil
}