diff --git a/sm2/sm2.go b/sm2/sm2.go index 2f3eea4..bbd13ff 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -250,6 +250,17 @@ func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte { return md.Sum(nil) } +func mashalASN1Ciphertext(x1, y1 *big.Int, c2, c3 []byte) ([]byte, error) { + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1BigInt(x1) + b.AddASN1BigInt(y1) + b.AddASN1OctetString(c3) + b.AddASN1OctetString(c2) + }) + return b.Bytes() +} + // sm2 encrypt and output ASN.1 result func EncryptASN1(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) { return Encrypt(random, pub, msg, ASN1EncrypterOpts) @@ -310,16 +321,9 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter } // c1 || c2 || c3 return append(append(c1, c2...), c3...), nil - } else { // ASN.1 format will force C3 C2 order - var b cryptobyte.Builder - b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { - b.AddASN1BigInt(x1) - b.AddASN1BigInt(y1) - b.AddASN1OctetString(c3) - b.AddASN1OctetString(c2) - }) - return b.Bytes() } + // ASN.1 format will force C3 C2 order + return mashalASN1Ciphertext(x1, y1, c2, c3) } } @@ -344,20 +348,9 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { } func decryptASN1(priv *PrivateKey, ciphertext []byte) ([]byte, error) { - var ( - x1, y1 = &big.Int{}, &big.Int{} - c2, c3 []byte - inner cryptobyte.String - ) - input := cryptobyte.String(ciphertext) - if !input.ReadASN1(&inner, asn1.SEQUENCE) || - !input.Empty() || - !inner.ReadASN1Integer(x1) || - !inner.ReadASN1Integer(y1) || - !inner.ReadASN1Bytes(&c3, asn1.OCTET_STRING) || - !inner.ReadASN1Bytes(&c2, asn1.OCTET_STRING) || - !inner.Empty() { - return nil, errors.New("SM2: invalid asn1 format ciphertext") + x1, y1, c2, c3, err := unmarshalASN1Ciphertext(ciphertext) + if err != nil { + return nil, err } return rawDecrypt(priv, x1, y1, c2, c3) } @@ -420,11 +413,7 @@ func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, return rawDecrypt(priv, x1, y1, c2, c3) } -// utility method to convert ASN.1 encoding ciphertext to plain encoding format -func ASN1Ciphertext2Plain(ciphertext []byte, opts *EncrypterOpts) ([]byte, error) { - if opts == nil { - opts = defaultEncrypterOpts - } +func unmarshalASN1Ciphertext(ciphertext []byte) (*big.Int, *big.Int, []byte, []byte, error) { var ( x1, y1 = &big.Int{}, &big.Int{} c2, c3 []byte @@ -438,7 +427,19 @@ func ASN1Ciphertext2Plain(ciphertext []byte, opts *EncrypterOpts) ([]byte, error !inner.ReadASN1Bytes(&c3, asn1.OCTET_STRING) || !inner.ReadASN1Bytes(&c2, asn1.OCTET_STRING) || !inner.Empty() { - return nil, errors.New("SM2: invalid asn1 format ciphertext") + return nil, nil, nil, nil, errors.New("SM2: invalid asn1 format ciphertext") + } + return x1, y1, c2, c3, nil +} + +// ASN1Ciphertext2Plain utility method to convert ASN.1 encoding ciphertext to plain encoding format +func ASN1Ciphertext2Plain(ciphertext []byte, opts *EncrypterOpts) ([]byte, error) { + if opts == nil { + opts = defaultEncrypterOpts + } + x1, y1, c2, c3, err := unmarshalASN1Ciphertext((ciphertext)) + if err != nil { + return nil, err } curve := P256() c1 := opts.PointMarshalMode.mashal(curve, x1, y1) @@ -450,7 +451,7 @@ func ASN1Ciphertext2Plain(ciphertext []byte, opts *EncrypterOpts) ([]byte, error return append(append(c1, c2...), c3...), nil } -// utility method to convert plain encoding ciphertext to ASN.1 encoding format +// PlainCiphertext2ASN1 utility method to convert plain encoding ciphertext to ASN.1 encoding format func PlainCiphertext2ASN1(ciphertext []byte, from ciphertextSplicingOrder) ([]byte, error) { if ciphertext[0] == 0x30 { return nil, errors.New("SM2: invalid plain encoding ciphertext") @@ -475,17 +476,10 @@ func PlainCiphertext2ASN1(ciphertext []byte, from ciphertextSplicingOrder) ([]by c2 = ciphertext[c3Start : ciphertextLen-sm3.Size] c3 = ciphertext[ciphertextLen-sm3.Size:] } - var b cryptobyte.Builder - b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { - b.AddASN1BigInt(x1) - b.AddASN1BigInt(y1) - b.AddASN1OctetString(c3) - b.AddASN1OctetString(c2) - }) - return b.Bytes() + return mashalASN1Ciphertext(x1, y1, c2, c3) } -// utility method +// AdjustCiphertextSplicingOrder utility method to change c2 c3 order func AdjustCiphertextSplicingOrder(ciphertext []byte, from, to ciphertextSplicingOrder) ([]byte, error) { curve := P256() if from == to { diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go index 545c342..00bb930 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_test.go @@ -83,7 +83,7 @@ func Test_encryptDecrypt_ASN1(t *testing.T) { plainText string }{ // TODO: Add test cases. - {"less than 32", "emmansun"}, + {"less than 32", "encryption standard"}, {"equals 32", "encryption standard encryption "}, {"long than 32", "encryption standard encryption standard"}, } @@ -112,7 +112,7 @@ func Test_Ciphertext2ASN1(t *testing.T) { plainText string }{ // TODO: Add test cases. - {"less than 32", "emmansun"}, + {"less than 32", "encryption standard"}, {"equals 32", "encryption standard encryption "}, {"long than 32", "encryption standard encryption standard"}, } @@ -144,7 +144,7 @@ func Test_ASN1Ciphertext2Plain(t *testing.T) { plainText string }{ // TODO: Add test cases. - {"less than 32", "emmansun"}, + {"less than 32", "encryption standard"}, {"equals 32", "encryption standard encryption "}, {"long than 32", "encryption standard encryption standard"}, } diff --git a/smx509/x509_test.go b/smx509/x509_test.go index 74920fc..754429f 100644 --- a/smx509/x509_test.go +++ b/smx509/x509_test.go @@ -264,30 +264,26 @@ func TestSignByHuaweiVerifyAtLocal(t *testing.T) { } } -func TestParsePKIXPublicKey(t *testing.T) { - pub, err := getPublicKey([]byte(publicKeyPemFromAliKms)) - if err != nil { - t.Fatal(err) +func TestParsePKIXPublicKeyFromExternal(t *testing.T) { + tests := []struct { + name string + pem string + }{ + {"ALI", publicKeyPemFromAliKms}, + {"HUAWEI", publicKeyPemFromHuaweiKms}, } - pub1 := pub.(*ecdsa.PublicKey) - encrypted, err := sm2.Encrypt(rand.Reader, pub1, []byte("testfile"), nil) - if err != nil { - t.Fatal(err) + for _, test := range tests { + pub, err := getPublicKey([]byte(test.pem)) + if err != nil { + t.Fatalf("%s failed to get public key %v", test.name, err) + } + pub1 := pub.(*ecdsa.PublicKey) + encrypted, err := sm2.Encrypt(rand.Reader, pub1, []byte("encryption standard"), sm2.ASN1EncrypterOpts) + if err != nil { + t.Fatalf("%s failed to encrypt %v", test.name, err) + } + fmt.Printf("encrypted=%s\n", base64.RawURLEncoding.EncodeToString(encrypted)) } - fmt.Printf("encrypted=%s\n", base64.StdEncoding.EncodeToString(encrypted)) -} - -func TestParsePKIXPublicKeyFromHuawei(t *testing.T) { - pub, err := getPublicKey([]byte(publicKeyPemFromHuaweiKms)) - if err != nil { - t.Fatal(err) - } - pub1 := pub.(*ecdsa.PublicKey) - encrypted, err := sm2.Encrypt(rand.Reader, pub1, []byte("encryption standard"), sm2.ASN1EncrypterOpts) - if err != nil { - t.Fatal(err) - } - fmt.Printf("encrypted=%s\n", base64.RawURLEncoding.EncodeToString(encrypted)) } func TestMarshalPKIXPublicKey(t *testing.T) {