diff --git a/sm2/p256.go b/sm2/p256.go index 13af9c3..fb891c1 100644 --- a/sm2/p256.go +++ b/sm2/p256.go @@ -24,7 +24,7 @@ var ( ) func initP256() { - p256.CurveParams = &elliptic.CurveParams{Name: "P-256/SM2"} + p256.CurveParams = &elliptic.CurveParams{Name: "sm2p256v1"} // 2**256 - 2**224 - 2**96 + 2**64 - 1 p256.P, _ = new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", 16) p256.N, _ = new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123", 16) diff --git a/sm2/sm2.go b/sm2/sm2.go index ad2be72..0581578 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -83,7 +83,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) //A2, calculate C1 = k * G x1, y1 := curve.ScalarBaseMult(k.Bytes()) - c1 := point2CompressedBytes(curve, x1, y1) + c1 := point2UncompressedBytes(curve, x1, y1) //A3, skipped //A4, calculate k * P (point of Public Key) @@ -105,7 +105,8 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) //A7, C3 = hash(x2||M||y2) c3 := calculateC3(curve, x2, y2, msg) - return append(append(c1, c2...), c3...), nil + // c1 || c3 || c2 + return append(append(c1, c3...), c2...), nil } } @@ -117,7 +118,7 @@ func Decrypt(priv *ecdsa.PrivateKey, ciphertext []byte) ([]byte, error) { } curve := priv.Curve // B1, get C1, and check C1 - x1, y1, c2Start, err := bytes2Point(curve, ciphertext) + x1, y1, c3Start, err := bytes2Point(curve, ciphertext) if err != nil { return nil, err } @@ -127,7 +128,7 @@ func Decrypt(priv *ecdsa.PrivateKey, ciphertext []byte) ([]byte, error) { x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes()) //B4, calculate t=KDF(x2||y2, klen) - c2 := ciphertext[c2Start : ciphertextLen-sm3.Size] + c2 := ciphertext[c3Start+sm3.Size:] msgLen := len(c2) t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) if !success { @@ -141,7 +142,7 @@ func Decrypt(priv *ecdsa.PrivateKey, ciphertext []byte) ([]byte, error) { } //B6, calculate hash and compare it - c3 := ciphertext[ciphertextLen-sm3.Size:] + c3 := ciphertext[c3Start : c3Start+sm3.Size] u := calculateC3(curve, x2, y2, msg) for i := 0; i < sm3.Size; i++ { if c3[i] != u[i] { diff --git a/sm2/util.go b/sm2/util.go index 4a8de1a..5e54bc1 100644 --- a/sm2/util.go +++ b/sm2/util.go @@ -96,7 +96,7 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e if len(bytes) < 1+byteLen { return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes)) } - if strings.HasPrefix(curve.Params().Name, "P-") { + if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) { // y² = x³ - 3x + b, prime curves x := toPointXY(bytes[1 : 1+byteLen]) y, err := calculatePrimeCurveY(curve, x) diff --git a/sm2/x509.go b/sm2/x509.go new file mode 100644 index 0000000..85402c5 --- /dev/null +++ b/sm2/x509.go @@ -0,0 +1,75 @@ +package sm2 + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "math/big" +) + +type publicKeyInfo struct { + Raw asn1.RawContent + Algorithm pkix.AlgorithmIdentifier + PublicKey asn1.BitString +} + +// pkcs1PublicKey reflects the ASN.1 structure of a PKCS#1 public key. +type pkcs1PublicKey struct { + N *big.Int + E int +} + +// http://gmssl.org/docs/oid.html +var ( + oidPublicKeyECDSA = asn1.ObjectIdentifier{1, 2, 840, 10045, 2, 1} + oidNamedCurveP256SM2 = asn1.ObjectIdentifier{1, 2, 156, 10197, 1, 301} +) + +// ParsePKIXPublicKey parses a public key in PKIX, ASN.1 DER form. +// +// It returns a *rsa.PublicKey, *dsa.PublicKey, *ecdsa.PublicKey, or +// ed25519.PublicKey. More types might be supported in the future. +// +// This kind of key is commonly encoded in PEM blocks of type "PUBLIC KEY". +func ParsePKIXPublicKey(derBytes []byte) (interface{}, error) { + var pki publicKeyInfo + if rest, err := asn1.Unmarshal(derBytes, &pki); err != nil { + if _, err := asn1.Unmarshal(derBytes, &pkcs1PublicKey{}); err == nil { + return nil, errors.New("x509: failed to parse public key (use ParsePKCS1PublicKey instead for this key format)") + } + return nil, err + } else if len(rest) != 0 { + return nil, errors.New("x509: trailing data after ASN.1 of public-key") + } + + if !pki.Algorithm.Algorithm.Equal(oidPublicKeyECDSA) { + return nil, errors.New("x509: invalid public key algorithm") + } + keyData := &pki + asn1Data := keyData.PublicKey.RightAlign() + paramsData := keyData.Algorithm.Parameters.FullBytes + namedCurveOID := new(asn1.ObjectIdentifier) + namedCurve := P256() + rest, err := asn1.Unmarshal(paramsData, namedCurveOID) + if err != nil { + return nil, errors.New("x509: failed to parse ECDSA parameters as named curve") + } + if len(rest) != 0 { + return nil, errors.New("x509: trailing data after ECDSA parameters") + } + if !namedCurveOID.Equal(oidNamedCurveP256SM2) { + return nil, errors.New("x509: it's not SM2 elliptic curve") + } + x, y := elliptic.Unmarshal(namedCurve, asn1Data) + if x == nil { + return nil, errors.New("x509: failed to unmarshal elliptic curve point") + } + pub := &ecdsa.PublicKey{ + Curve: namedCurve, + X: x, + Y: y, + } + return pub, nil +} diff --git a/sm2/x509_test.go b/sm2/x509_test.go new file mode 100644 index 0000000..3bcb961 --- /dev/null +++ b/sm2/x509_test.go @@ -0,0 +1,36 @@ +package sm2 + +import ( + "crypto/ecdsa" + "crypto/rand" + "encoding/pem" + "errors" + "testing" +) + +const publicKeyPemFromAliKms = ` +-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoEcz1UBgi0DQgAELfjZP28bYfGSvbODYlXiB5bcoXE+ +2LRjjpIH3DcCCct9FuVhi9cm60nDFrbW49k2D3GJco2iWPlr0+5LV+t4AQ== +-----END PUBLIC KEY----- +` + +func getPublicKey(pemContent []byte) (interface{}, error) { + block, _ := pem.Decode(pemContent) + if block == nil { + return nil, errors.New("Failed to parse PEM block") + } + return ParsePKIXPublicKey(block.Bytes) +} + +func TestParsePKIXPublicKey(t *testing.T) { + pub, err := getPublicKey([]byte(publicKeyPemFromAliKms)) + if err != nil { + t.Fatal(err) + } + pub1 := pub.(*ecdsa.PublicKey) + _, err = Encrypt(rand.Reader, pub1, []byte("testfile")) + if err != nil { + t.Fatal(err) + } +}