starcrypto/asymm/rsa.go

205 lines
4.8 KiB
Go

package asymm
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"math/big"
"golang.org/x/crypto/ssh"
)
// GenerateRsaKey creates a new RSA key with a modulus of the requested bit
// length, drawing randomness from crypto/rand.
//
// The second return value points at the PublicKey field embedded in the
// returned private key; it is not a separate copy.
func GenerateRsaKey(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) {
	key, err := rsa.GenerateKey(rand.Reader, bits)
	if err != nil {
		return nil, nil, err
	}
	pub := &key.PublicKey
	return key, pub, nil
}
// EncodeRsaPrivateKey serialises private as PKCS#1 DER inside an
// "RSA PRIVATE KEY" PEM block.
//
// With an empty secret the DER is stored in the clear. Otherwise it is
// encrypted with x509.EncryptPEMBlock using x509.PEMCipherAES256 and the
// secret as the password.
// NOTE(review): crypto/x509 marks EncryptPEMBlock as deprecated (legacy
// RFC 1423 encryption, no authentication); it is kept here for output-format
// compatibility with existing readers.
func EncodeRsaPrivateKey(private *rsa.PrivateKey, secret string) ([]byte, error) {
	const pemType = "RSA PRIVATE KEY"
	der := x509.MarshalPKCS1PrivateKey(private)

	var block *pem.Block
	if secret != "" {
		encrypted, err := x509.EncryptPEMBlock(rand.Reader, pemType, der, []byte(secret), x509.PEMCipherAES256)
		if err != nil {
			return nil, err
		}
		block = encrypted
	} else {
		block = &pem.Block{Type: pemType, Bytes: der}
	}
	return pem.EncodeToMemory(block), nil
}
// EncodeRsaPublicKey marshals public as PKIX (SubjectPublicKeyInfo) DER and
// wraps it in a "PUBLIC KEY" PEM block.
func EncodeRsaPublicKey(public *rsa.PublicKey) ([]byte, error) {
	der, err := x509.MarshalPKIXPublicKey(public)
	if err != nil {
		return nil, err
	}
	block := &pem.Block{Type: "PUBLIC KEY", Bytes: der}
	return pem.EncodeToMemory(block), nil
}
// DecodeRsaPrivateKey parses the first PEM block in private as an RSA private
// key.
//
// The block's payload is obtained through decodePEMBlockBytes together with
// password (defined elsewhere in this package; presumably it decrypts legacy
// encrypted PEM blocks — verify). The resulting DER is tried as PKCS#1 first
// and then as PKCS#8; a PKCS#8 key of any other algorithm is rejected.
func DecodeRsaPrivateKey(private []byte, password string) (*rsa.PrivateKey, error) {
	block, _ := pem.Decode(private)
	if block == nil {
		return nil, errors.New("private key error")
	}
	der, err := decodePEMBlockBytes(block, password)
	if err != nil {
		return nil, err
	}

	// PKCS#1 failure is expected for PKCS#8 input, so its error is dropped in
	// favour of the PKCS#8 attempt below.
	if key, pkcs1Err := x509.ParsePKCS1PrivateKey(der); pkcs1Err == nil {
		return key, nil
	}

	parsed, err := x509.ParsePKCS8PrivateKey(der)
	if err != nil {
		return nil, err
	}
	if key, ok := parsed.(*rsa.PrivateKey); ok {
		return key, nil
	}
	return nil, errors.New("private key is not RSA")
}
// DecodeRsaPublicKey parses the first PEM block in pubStr as an RSA public key.
//
// Both common encodings are accepted:
//   - PKIX / SubjectPublicKeyInfo DER (PEM type "PUBLIC KEY"), as produced by
//     EncodeRsaPublicKey;
//   - PKCS#1 DER (PEM type "RSA PUBLIC KEY").
//
// This mirrors the PKCS#1/PKCS#8 fallback used by DecodeRsaPrivateKey. The PEM
// type header itself is not inspected; the DER content decides.
//
// Errors:
//   - "public key error" if no PEM block is found;
//   - the PKIX parser's error if the DER is neither valid PKIX nor valid PKCS#1;
//   - "public key is not RSA" if the PKIX key uses another algorithm.
func DecodeRsaPublicKey(pubStr []byte) (*rsa.PublicKey, error) {
	blk, _ := pem.Decode(pubStr)
	if blk == nil {
		return nil, errors.New("public key error")
	}
	pub, err := x509.ParsePKIXPublicKey(blk.Bytes)
	if err != nil {
		// Not PKIX: try bare PKCS#1. If that also fails, report the PKIX error
		// so callers see the same error as before for undecodable input.
		if rsapub, pkcs1Err := x509.ParsePKCS1PublicKey(blk.Bytes); pkcs1Err == nil {
			return rsapub, nil
		}
		return nil, err
	}
	rsapub, ok := pub.(*rsa.PublicKey)
	if !ok {
		return nil, errors.New("public key is not RSA")
	}
	return rsapub, nil
}
// EncodeRsaSSHPublicKey renders public as a single OpenSSH authorized_keys
// line (as produced by ssh.MarshalAuthorizedKey).
func EncodeRsaSSHPublicKey(public *rsa.PublicKey) ([]byte, error) {
	sshPub, err := ssh.NewPublicKey(public)
	if err != nil {
		return nil, err
	}
	line := ssh.MarshalAuthorizedKey(sshPub)
	return line, nil
}
// GenerateRsaSSHKeyPair creates a fresh RSA key of the given size and returns
// (private key PEM, public key authorized_keys line, error).
//
// The private half is encoded by EncodeRsaPrivateKey, which encrypts it with
// secret when secret is non-empty. The public half is encoded by
// EncodeRsaSSHPublicKey.
func GenerateRsaSSHKeyPair(bits int, secret string) (string, string, error) {
	privKey, pubKey, err := GenerateRsaKey(bits)
	if err != nil {
		return "", "", err
	}

	sshLine, err := EncodeRsaSSHPublicKey(pubKey)
	if err != nil {
		return "", "", err
	}

	privPEM, err := EncodeRsaPrivateKey(privKey, secret)
	if err != nil {
		return "", "", err
	}

	return string(privPEM), string(sshLine), nil
}
// RSAEncrypt encrypts data for pub with RSAES-PKCS1-v1_5, using crypto/rand
// for the random padding. crypto/rsa rejects messages longer than the modulus
// size minus 11 bytes.
// NOTE(review): PKCS#1 v1.5 encryption is a legacy scheme; prefer OAEP for
// new protocols.
func RSAEncrypt(pub *rsa.PublicKey, data []byte) ([]byte, error) {
	ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, pub, data)
	return ciphertext, err
}
// RSADecrypt reverses RSAEncrypt: it decrypts an RSAES-PKCS1-v1_5 ciphertext
// with prikey, passing crypto/rand to crypto/rsa for blinding.
func RSADecrypt(prikey *rsa.PrivateKey, data []byte) ([]byte, error) {
	plaintext, err := rsa.DecryptPKCS1v15(rand.Reader, prikey, data)
	return plaintext, err
}
// RSASign produces an RSASSA-PKCS1-v1_5 signature over msg.
//
// priKey is a PEM-encoded private key, decoded (and decrypted with password if
// needed) by DecodeRsaPrivateKey. msg is hashed with hashType via hashMessage
// first; hashType 0 signs msg as-is.
func RSASign(msg, priKey []byte, password string, hashType crypto.Hash) ([]byte, error) {
	key, err := DecodeRsaPrivateKey(priKey, password)
	if err != nil {
		return nil, err
	}
	digest, err := hashMessage(msg, hashType)
	if err != nil {
		return nil, err
	}
	return rsa.SignPKCS1v15(rand.Reader, key, hashType, digest)
}
// RSAVerify checks an RSASSA-PKCS1-v1_5 signature sig over msg.
//
// pubKey is a PEM-encoded public key, decoded by DecodeRsaPublicKey. msg is
// hashed with hashType via hashMessage first; hashType 0 verifies msg as-is.
// A nil return means the signature is valid.
func RSAVerify(sig, msg, pubKey []byte, hashType crypto.Hash) error {
	key, err := DecodeRsaPublicKey(pubKey)
	if err != nil {
		return err
	}
	digest, err := hashMessage(msg, hashType)
	if err != nil {
		return err
	}
	return rsa.VerifyPKCS1v15(key, hashType, digest, sig)
}
// RSAEncryptByPrivkey applies the raw PKCS#1 v1.5 private-key operation to
// data. It calls rsa.SignPKCS1v15 with hash 0, so data is given block-type-1
// padding and signed as-is, with no hashing or DigestInfo prefix. The result
// can be reversed with RSADecryptByPubkey. crypto/rsa rejects data longer than
// the modulus size minus 11 bytes.
//
// crypto/rand is supplied as the random source, consistent with RSASign. On Go
// releases before 1.20 a nil source disabled RSA blinding, a timing side channel.
// Newer releases ignore the argument. The output is the same either way, because
// PKCS#1 v1.5 signatures are deterministic.
func RSAEncryptByPrivkey(priv *rsa.PrivateKey, data []byte) ([]byte, error) {
	const noHash crypto.Hash = 0
	return rsa.SignPKCS1v15(rand.Reader, priv, noHash, data)
}
// RSADecryptByPubkey reverses RSAEncryptByPrivkey. It interprets data as a
// big-endian integer c and computes m = c^e mod n. It then left-pads m to the
// modulus length and strips the PKCS#1 v1.5 block-type-1 padding with
// unLeftPad.
//
// Inputs whose integer value is not smaller than the modulus are rejected, as
// RSAVP1 (RFC 8017 §5.2.2) requires. Without this check, data and data+n would
// both be accepted and yield the same payload.
func RSADecryptByPubkey(pub *rsa.PublicKey, data []byte) ([]byte, error) {
	c := new(big.Int).SetBytes(data)
	if c.Cmp(pub.N) >= 0 {
		return nil, errors.New("RSA data out of range")
	}
	m := new(big.Int).Exp(c, big.NewInt(int64(pub.E)), pub.N)
	// m < n, so m.Bytes() never exceeds the modulus length and leftPad only
	// prepends zeros here.
	em := leftPad(m.Bytes(), (pub.N.BitLen()+7)/8)
	return unLeftPad(em)
}
// hashMessage returns the digest of msg under hashType.
//
// A zero hashType means "no hashing": msg itself is returned unchanged (same
// slice). A hash whose implementation is not linked into the binary yields an
// error.
func hashMessage(msg []byte, hashType crypto.Hash) ([]byte, error) {
	switch {
	case hashType == 0:
		return msg, nil
	case !hashType.Available():
		return nil, errors.New("hash function is not available")
	}

	hasher := hashType.New()
	if _, err := hasher.Write(msg); err != nil {
		return nil, err
	}
	return hasher.Sum(nil), nil
}
// leftPad returns a new slice of exactly size bytes that holds input
// right-aligned, with leading zero bytes.
//
// If input is already size bytes or longer, the result holds the first size
// bytes of input; the trailing excess is dropped.
func leftPad(input []byte, size int) []byte {
	padded := make([]byte, size)
	if len(input) >= size {
		copy(padded, input)
		return padded
	}
	copy(padded[size-len(input):], input)
	return padded
}
// unLeftPad strips PKCS#1 v1.5 block-type-1 padding and returns the payload.
//
// The expected layout is
//
//	0x00 || 0x01 || PS || 0x00 || M
//
// where PS is a run of 0xff bytes. RFC 8017 (EMSA-PKCS1-v1_5) requires PS to be
// at least 8 bytes, and crypto/rsa.SignPKCS1v15 always produces at least that
// many. Shorter runs are rejected so malformed blocks are not silently accepted.
// The returned payload is a fresh copy that does not alias input.
func unLeftPad(input []byte) ([]byte, error) {
	const minPSLen = 8

	if len(input) < 3 {
		return nil, errors.New("invalid RSA block")
	}
	if input[0] != 0x00 || input[1] != 0x01 {
		return nil, errors.New("invalid RSA block header")
	}

	i := 2
	for i < len(input) && input[i] == 0xff {
		i++
	}
	// The 0xff run must be terminated by a 0x00 separator and be long enough.
	if i >= len(input) || input[i] != 0x00 || i-2 < minPSLen {
		return nil, errors.New("invalid RSA block padding")
	}
	i++ // skip the 0x00 separator; i <= len(input) holds here

	out := make([]byte, len(input)-i)
	copy(out, input[i:])
	return out, nil
}