package asymm import ( "crypto" "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/pem" "errors" "math/big" "golang.org/x/crypto/ssh" ) func GenerateRsaKey(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) { private, err := rsa.GenerateKey(rand.Reader, bits) if err != nil { return nil, nil, err } return private, &private.PublicKey, nil } func EncodeRsaPrivateKey(private *rsa.PrivateKey, secret string) ([]byte, error) { der := x509.MarshalPKCS1PrivateKey(private) if secret == "" { return pem.EncodeToMemory(&pem.Block{ Type: "RSA PRIVATE KEY", Bytes: der, }), nil } blk, err := x509.EncryptPEMBlock(rand.Reader, "RSA PRIVATE KEY", der, []byte(secret), x509.PEMCipherAES256) if err != nil { return nil, err } return pem.EncodeToMemory(blk), nil } func EncodeRsaPublicKey(public *rsa.PublicKey) ([]byte, error) { publicBytes, err := x509.MarshalPKIXPublicKey(public) if err != nil { return nil, err } return pem.EncodeToMemory(&pem.Block{ Type: "PUBLIC KEY", Bytes: publicBytes, }), nil } func DecodeRsaPrivateKey(private []byte, password string) (*rsa.PrivateKey, error) { blk, _ := pem.Decode(private) if blk == nil { return nil, errors.New("private key error") } bytes, err := decodePEMBlockBytes(blk, password) if err != nil { return nil, err } if prikey, err := x509.ParsePKCS1PrivateKey(bytes); err == nil { return prikey, nil } pkcs8, err := x509.ParsePKCS8PrivateKey(bytes) if err != nil { return nil, err } prikey, ok := pkcs8.(*rsa.PrivateKey) if !ok { return nil, errors.New("private key is not RSA") } return prikey, nil } func DecodeRsaPublicKey(pubStr []byte) (*rsa.PublicKey, error) { blk, _ := pem.Decode(pubStr) if blk == nil { return nil, errors.New("public key error") } pub, err := x509.ParsePKIXPublicKey(blk.Bytes) if err != nil { return nil, err } rsapub, ok := pub.(*rsa.PublicKey) if !ok { return nil, errors.New("public key is not RSA") } return rsapub, nil } func EncodeRsaSSHPublicKey(public *rsa.PublicKey) ([]byte, error) { publicKey, err := ssh.NewPublicKey(public) if err != nil { return nil, err } return ssh.MarshalAuthorizedKey(publicKey), nil } func GenerateRsaSSHKeyPair(bits int, secret string) (string, string, error) { pkey, pubkey, err := GenerateRsaKey(bits) if err != nil { return "", "", err } pub, err := EncodeRsaSSHPublicKey(pubkey) if err != nil { return "", "", err } priv, err := EncodeRsaPrivateKey(pkey, secret) if err != nil { return "", "", err } return string(priv), string(pub), nil } func RSAEncrypt(pub *rsa.PublicKey, data []byte) ([]byte, error) { return rsa.EncryptPKCS1v15(rand.Reader, pub, data) } func RSADecrypt(prikey *rsa.PrivateKey, data []byte) ([]byte, error) { return rsa.DecryptPKCS1v15(rand.Reader, prikey, data) } func RSASign(msg, priKey []byte, password string, hashType crypto.Hash) ([]byte, error) { prikey, err := DecodeRsaPrivateKey(priKey, password) if err != nil { return nil, err } hashed, err := hashMessage(msg, hashType) if err != nil { return nil, err } return rsa.SignPKCS1v15(rand.Reader, prikey, hashType, hashed) } func RSAVerify(sig, msg, pubKey []byte, hashType crypto.Hash) error { pubkey, err := DecodeRsaPublicKey(pubKey) if err != nil { return err } hashed, err := hashMessage(msg, hashType) if err != nil { return err } return rsa.VerifyPKCS1v15(pubkey, hashType, hashed, sig) } func RSAEncryptByPrivkey(priv *rsa.PrivateKey, data []byte) ([]byte, error) { return rsa.SignPKCS1v15(nil, priv, crypto.Hash(0), data) } func RSADecryptByPubkey(pub *rsa.PublicKey, data []byte) ([]byte, error) { c := new(big.Int).SetBytes(data) m := new(big.Int).Exp(c, big.NewInt(int64(pub.E)), pub.N) em := leftPad(m.Bytes(), (pub.N.BitLen()+7)/8) return unLeftPad(em) } func hashMessage(msg []byte, hashType crypto.Hash) ([]byte, error) { if hashType == 0 { return msg, nil } if !hashType.Available() { return nil, errors.New("hash function is not available") } h := hashType.New() _, err := h.Write(msg) if err != nil { return nil, err } return h.Sum(nil), nil } func leftPad(input []byte, size int) []byte { n := len(input) if n > size { n = size } out := make([]byte, size) copy(out[len(out)-n:], input) return out } func unLeftPad(input []byte) ([]byte, error) { // PKCS#1 v1.5 block format: 0x00 || 0x01 || PS(0xff...) || 0x00 || M if len(input) < 3 { return nil, errors.New("invalid RSA block") } if input[0] != 0x00 || input[1] != 0x01 { return nil, errors.New("invalid RSA block header") } i := 2 for i < len(input) && input[i] == 0xff { i++ } if i >= len(input) || input[i] != 0x00 { return nil, errors.New("invalid RSA block padding") } i++ if i > len(input) { return nil, errors.New("invalid RSA block payload") } out := make([]byte, len(input)-i) copy(out, input[i:]) return out, nil }