sm2: public recover from signature 2

This commit is contained in:
Sun Yimin 2024-06-04 08:26:51 +08:00 committed by GitHub
parent 126ee25d2a
commit cad5d3504e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 43 additions and 26 deletions

View File

@ -707,7 +707,7 @@ func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) {
var ErrInvalidSignature = errors.New("sm2: invalid signature") var ErrInvalidSignature = errors.New("sm2: invalid signature")
// RecoverPublicKeysFromSM2Signature recovers two SM2 public keys from a given signature and hash. // RecoverPublicKeysFromSM2Signature recovers two or four SM2 public keys from a given signature and hash.
// It takes the hash and signature as input and returns the recovered public keys as []*ecdsa.PublicKey. // It takes the hash and signature as input and returns the recovered public keys as []*ecdsa.PublicKey.
// If the signature or hash is invalid, it returns an error. // If the signature or hash is invalid, it returns an error.
// The function follows the SM2 algorithm to recover the public keys. // The function follows the SM2 algorithm to recover the public keys.
@ -741,38 +741,55 @@ func RecoverPublicKeysFromSM2Signature(hash, sig []byte) ([]*ecdsa.PublicKey, er
if s.IsZero() == 1 { if s.IsZero() == 1 {
return nil, ErrInvalidSignature return nil, ErrInvalidSignature
} }
// sBytes = (r+s)⁻¹
sBytes, err = _sm2ec.P256OrdInverse(s.Bytes(c.N)) sBytes, err = _sm2ec.P256OrdInverse(s.Bytes(c.N))
if err != nil { if err != nil {
return nil, err return nil, err
} }
// r = (Rx + e) mod N
// Rx = r - e // Rx = r - e
r.Sub(e, c.N) r.Sub(e, c.N)
if r.IsZero() == 1 { if r.IsZero() == 1 {
return nil, ErrInvalidSignature return nil, ErrInvalidSignature
} }
rBytes = r.Bytes(c.N) pointRx := make([]*bigmod.Nat, 0, 2)
tmp := make([]byte, len(rBytes)+1) pointRx = append(pointRx, r)
copy(tmp[1:], rBytes) // check if Rx in (N, P), small probability event
s.Set(r)
s = s.Add(c.N.Nat(), c.P)
if s.CmpGeq(c.N.Nat()) == 1 {
pointRx = append(pointRx, s)
}
pubs := make([]*ecdsa.PublicKey, 0, 4)
bytes := make([]byte, len(rBytes)+1)
compressFlags := []byte{compressed02, compressed03} compressFlags := []byte{compressed02, compressed03}
pks := make([]*ecdsa.PublicKey, 0, 2) // Rx has one or two possible values, so point R has two or four possible values
for _, x := range pointRx {
rBytes = x.Bytes(c.N)
copy(bytes[1:], rBytes)
for _, flag := range compressFlags { for _, flag := range compressFlags {
tmp[0] = flag bytes[0] = flag
p0, err := c.newPoint().SetBytes(tmp) // p0 = R
p0, err := c.newPoint().SetBytes(bytes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// p0 = R - [s]G
p0.Add(p0, p1) p0.Add(p0, p1)
// Pub = [(r + s)⁻¹](R - [s]G)
p0.ScalarMult(p0, sBytes) p0.ScalarMult(p0, sBytes)
pk := new(ecdsa.PublicKey) pub := new(ecdsa.PublicKey)
pk.Curve = c.curve pub.Curve = c.curve
pk.X, pk.Y, err = c.pointToAffine(p0) pub.X, pub.Y, err = c.pointToAffine(p0)
if err != nil { if err != nil {
return nil, err return nil, err
} }
pks = append(pks, pk) pubs = append(pubs, pub)
} }
return pks, nil }
return pubs, nil
} }
// VerifyASN1 verifies the ASN.1 encoded signature, sig, of hash using the // VerifyASN1 verifies the ASN.1 encoded signature, sig, of hash using the

View File

@ -470,7 +470,7 @@ func TestSignVerify(t *testing.T) {
} }
} }
func TestRecoverSM2PublicKeyFromSig(t *testing.T) { func TestRecoverPublicKeysFromSM2Signature(t *testing.T) {
priv, _ := GenerateKey(rand.Reader) priv, _ := GenerateKey(rand.Reader)
tests := []struct { tests := []struct {
name string name string
@ -493,19 +493,19 @@ func TestRecoverSM2PublicKeyFromSig(t *testing.T) {
pubs, err := RecoverPublicKeysFromSM2Signature(hashValue, sig) pubs, err := RecoverPublicKeysFromSM2Signature(hashValue, sig)
if err != nil { if err != nil {
t.Fatalf("recover failed %v", err) t.Fatalf("recover sig=%x, priv=%x, failed %v", sig, priv.D.Bytes(), err)
} }
found := false found := false
for _, pub := range pubs { for _, pub := range pubs {
if !VerifyASN1(pub, hashValue, sig) { if !VerifyASN1(pub, hashValue, sig) {
t.Errorf("failed to verify hash") t.Errorf("failed to verify hash for sig=%x, priv=%x", sig, priv.D.Bytes())
} }
if pub.Equal(&priv.PublicKey) { if pub.Equal(&priv.PublicKey) {
found = true found = true
} }
} }
if !found { if !found {
t.Errorf("recover failed, not found public key") t.Errorf("recover failed, not found public key for sig=%x, priv=%x", sig, priv.D.Bytes())
} }
}) })
} }