sm2: recover public keys from signature

This commit is contained in:
Sun Yimin 2024-05-31 18:30:58 +08:00 committed by GitHub
parent 81b0c7f5ae
commit 63affe5127
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 120 additions and 0 deletions

View File

@ -705,6 +705,83 @@ func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) {
})
}
var ErrInvalidSignature = errors.New("sm2: invalid signature")
// RecoverPublicKeysFromSM2Signature recovers two SM2 public keys from a given signature and hash.
// It takes the hash and signature as input and returns the recovered public keys as []*ecdsa.PublicKey.
// If the signature or hash is invalid, it returns an error.
// The function follows the SM2 algorithm to recover the public keys.
func RecoverPublicKeysFromSM2Signature(hash, sig []byte) ([]*ecdsa.PublicKey, error) {
c := p256()
rBytes, sBytes, err := parseSignature(sig)
if err != nil {
return nil, err
}
r, err := bigmod.NewNat().SetBytes(rBytes, c.N)
if err != nil || r.IsZero() == 1 {
return nil, ErrInvalidSignature
}
s, err := bigmod.NewNat().SetBytes(sBytes, c.N)
if err != nil || s.IsZero() == 1 {
return nil, ErrInvalidSignature
}
e := bigmod.NewNat()
hashToNat(c, e, hash)
// p₁ = [-s]G
negS := bigmod.NewNat().ExpandFor(c.N).Sub(s, c.N)
p1, err := c.newPoint().ScalarBaseMult(negS.Bytes(c.N))
if err != nil {
return nil, err
}
// s = [r + s]
s.Add(r, c.N)
if s.IsZero() == 1 {
return nil, ErrInvalidSignature
}
sBytes, err = _sm2ec.P256OrdInverse(s.Bytes(c.N))
if err != nil {
return nil, err
}
// Rx = r - e
if r.CmpGeq(e) == 0 {
// If r < e, then Rx = N - e + r
n0 := bigmod.NewNat().Set(c.N.Nat())
n0.Sub(e, c.P)
r.Add(n0, c.P)
} else {
r.Sub(e, c.P)
}
if r.IsZero() == 1 {
return nil, ErrInvalidSignature
}
rBytes = r.Bytes(c.P)
tmp := make([]byte, len(rBytes)+1)
copy(tmp[1:], rBytes)
compressFlags := []byte{compressed02, compressed03}
pks := make([]*ecdsa.PublicKey, 0, 2)
for _, flag := range compressFlags {
tmp[0] = flag
p0, err := c.newPoint().SetBytes(tmp)
if err != nil {
return nil, err
}
p0.Add(p0, p1)
p0.ScalarMult(p0, sBytes)
pk := new(ecdsa.PublicKey)
pk.Curve = c.curve
pk.X, pk.Y, err = c.pointToAffine(p0)
if err != nil {
return nil, err
}
pks = append(pks, pk)
}
return pks, nil
}
// VerifyASN1 verifies the ASN.1 encoded signature, sig, of hash using the
// public key, pub. Its return value records whether the signature is valid.
//
@ -922,6 +999,7 @@ type sm2Curve struct {
newPoint func() *_sm2ec.SM2P256Point
curve elliptic.Curve
N *bigmod.Modulus
P *bigmod.Modulus
nMinus1 *bigmod.Nat
nMinus2 []byte
}
@ -975,6 +1053,7 @@ func precomputeParams(c *sm2Curve, curve elliptic.Curve) {
params := curve.Params()
c.curve = curve
c.N, _ = bigmod.NewModulusFromBig(params.N)
c.P, _ = bigmod.NewModulusFromBig(params.P)
c.nMinus2 = new(big.Int).Sub(params.N, big.NewInt(2)).Bytes()
c.nMinus1, _ = bigmod.NewNat().SetBytes(new(big.Int).Sub(params.N, big.NewInt(1)).Bytes(), c.N)
}

View File

@ -470,6 +470,47 @@ func TestSignVerify(t *testing.T) {
}
}
func TestRecoverSM2PublicKeyFromSig(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
tests := []struct {
name string
plainText string
}{
{"less than 32", "encryption standard"},
{"equals 32", "encryption standard encryption "},
{"long than 32", "encryption standard encryption standard"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
hashValue, err := CalculateSM2Hash(&priv.PublicKey, []byte(tt.plainText), nil)
if err != nil {
t.Fatalf("hash failed %v", err)
}
sig, err := priv.Sign(rand.Reader, hashValue, nil)
if err != nil {
t.Fatalf("sign failed %v", err)
}
pubs, err := RecoverPublicKeysFromSM2Signature(hashValue, sig)
if err != nil {
t.Fatalf("recover failed %v", err)
}
found := false
for _, pub := range pubs {
if !VerifyASN1(pub, hashValue, sig) {
t.Errorf("failed to verify hash")
}
if pub.Equal(&priv.PublicKey) {
found = true
}
}
if !found {
t.Errorf("recover failed, not found public key")
}
})
}
}
func TestSignVerifyLegacy(t *testing.T) {
priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
tests := []struct {