From 63affe51274c1533056339954300ebec5d30cae9 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Fri, 31 May 2024 18:30:58 +0800 Subject: [PATCH] sm2: recover public keys from signature --- sm2/sm2.go | 79 +++++++++++++++++++++++++++++++++++++++++++++++++ sm2/sm2_test.go | 41 +++++++++++++++++++++++++ 2 files changed, 120 insertions(+) diff --git a/sm2/sm2.go b/sm2/sm2.go index 526373e..d7a5c6d 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -705,6 +705,83 @@ func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) { }) } +var ErrInvalidSignature = errors.New("sm2: invalid signature") + +// RecoverPublicKeysFromSM2Signature recovers two SM2 public keys from a given signature and hash. +// It takes the hash and signature as input and returns the recovered public keys as []*ecdsa.PublicKey. +// If the signature or hash is invalid, it returns an error. +// The function follows the SM2 algorithm to recover the public keys. +func RecoverPublicKeysFromSM2Signature(hash, sig []byte) ([]*ecdsa.PublicKey, error) { + c := p256() + rBytes, sBytes, err := parseSignature(sig) + if err != nil { + return nil, err + } + r, err := bigmod.NewNat().SetBytes(rBytes, c.N) + if err != nil || r.IsZero() == 1 { + return nil, ErrInvalidSignature + } + s, err := bigmod.NewNat().SetBytes(sBytes, c.N) + if err != nil || s.IsZero() == 1 { + return nil, ErrInvalidSignature + } + + e := bigmod.NewNat() + hashToNat(c, e, hash) + + // p₁ = [-s]G + negS := bigmod.NewNat().ExpandFor(c.N).Sub(s, c.N) + p1, err := c.newPoint().ScalarBaseMult(negS.Bytes(c.N)) + if err != nil { + return nil, err + } + + // s = [r + s] + s.Add(r, c.N) + if s.IsZero() == 1 { + return nil, ErrInvalidSignature + } + sBytes, err = _sm2ec.P256OrdInverse(s.Bytes(c.N)) + if err != nil { + return nil, err + } + + // Rx = r - e + if r.CmpGeq(e) == 0 { + // If r < e, then Rx = N - e + r + n0 := bigmod.NewNat().Set(c.N.Nat()) + n0.Sub(e, c.P) + r.Add(n0, c.P) + } else { + r.Sub(e, c.P) + } + if r.IsZero() == 1 { + return nil, ErrInvalidSignature + } + rBytes = r.Bytes(c.P) + tmp := make([]byte, len(rBytes)+1) + copy(tmp[1:], rBytes) + compressFlags := []byte{compressed02, compressed03} + pks := make([]*ecdsa.PublicKey, 0, 2) + for _, flag := range compressFlags { + tmp[0] = flag + p0, err := c.newPoint().SetBytes(tmp) + if err != nil { + return nil, err + } + p0.Add(p0, p1) + p0.ScalarMult(p0, sBytes) + pk := new(ecdsa.PublicKey) + pk.Curve = c.curve + pk.X, pk.Y, err = c.pointToAffine(p0) + if err != nil { + return nil, err + } + pks = append(pks, pk) + } + return pks, nil +} + // VerifyASN1 verifies the ASN.1 encoded signature, sig, of hash using the // public key, pub. Its return value records whether the signature is valid. // @@ -922,6 +999,7 @@ type sm2Curve struct { newPoint func() *_sm2ec.SM2P256Point curve elliptic.Curve N *bigmod.Modulus + P *bigmod.Modulus nMinus1 *bigmod.Nat nMinus2 []byte } @@ -975,6 +1053,7 @@ func precomputeParams(c *sm2Curve, curve elliptic.Curve) { params := curve.Params() c.curve = curve c.N, _ = bigmod.NewModulusFromBig(params.N) + c.P, _ = bigmod.NewModulusFromBig(params.P) c.nMinus2 = new(big.Int).Sub(params.N, big.NewInt(2)).Bytes() c.nMinus1, _ = bigmod.NewNat().SetBytes(new(big.Int).Sub(params.N, big.NewInt(1)).Bytes(), c.N) } diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go index c5cd193..c998fb2 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_test.go @@ -470,6 +470,47 @@ func TestSignVerify(t *testing.T) { } } +func TestRecoverSM2PublicKeyFromSig(t *testing.T) { + priv, _ := GenerateKey(rand.Reader) + tests := []struct { + name string + plainText string + }{ + {"less than 32", "encryption standard"}, + {"equals 32", "encryption standard encryption "}, + {"long than 32", "encryption standard encryption standard"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hashValue, err := CalculateSM2Hash(&priv.PublicKey, []byte(tt.plainText), nil) + if err != nil { + t.Fatalf("hash failed %v", err) + } + sig, err := priv.Sign(rand.Reader, hashValue, nil) + if err != nil { + t.Fatalf("sign failed %v", err) + } + + pubs, err := RecoverPublicKeysFromSM2Signature(hashValue, sig) + if err != nil { + t.Fatalf("recover failed %v", err) + } + found := false + for _, pub := range pubs { + if !VerifyASN1(pub, hashValue, sig) { + t.Errorf("failed to verify hash") + } + if pub.Equal(&priv.PublicKey) { + found = true + } + } + if !found { + t.Errorf("recover failed, not found public key") + } + }) + } +} + func TestSignVerifyLegacy(t *testing.T) { priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) tests := []struct {