From 191cd2622d09ea31a80e6121d88817f84d1141ae Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Thu, 28 Nov 2024 08:44:30 +0800 Subject: [PATCH] stealth address fix private key issue --- .github/workflows/ci.yml | 2 +- ecdh/sm2ec.go | 25 +++++++++++++++++++------ ecdh/stealth_test.go | 38 ++++++++++++++++++++++++++++---------- 3 files changed, 48 insertions(+), 17 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c66b5d8..c877e88 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,7 +2,7 @@ name: ci on: push: - branches: [ main ] + branches: [ main, blockchain ] pull_request: branches: [ main ] diff --git a/ecdh/sm2ec.go b/ecdh/sm2ec.go index 7ed16b9..f3e1129 100644 --- a/ecdh/sm2ec.go +++ b/ecdh/sm2ec.go @@ -16,6 +16,7 @@ import ( type sm2Curve struct { name string newPoint func() *sm2ec.SM2P256Point + scalarOrder []byte scalarOrderMinus1 []byte constantA []byte constantB []byte @@ -49,10 +50,14 @@ func (c *sm2Curve) GenerateKey(rand io.Reader) (*PrivateKey, error) { } func (c *sm2Curve) NewPrivateKey(key []byte) (*PrivateKey, error) { - if len(key) != len(c.scalarOrderMinus1) { + return c.newPrivateKey(key, true) +} + +func (c *sm2Curve) newPrivateKey(key []byte, checkOrderMinus1 bool) (*PrivateKey, error) { + if len(key) != len(c.scalarOrder) { return nil, errors.New("ecdh: invalid private key size") } - if subtle.ConstantTimeAllZero(key) == 1 || !isLess(key, c.scalarOrderMinus1) { + if subtle.ConstantTimeAllZero(key) == 1 || (checkOrderMinus1 && !isLess(key, c.scalarOrderMinus1)) { return nil, errInvalidPrivateKey } return &PrivateKey{ @@ -84,11 +89,13 @@ func (c *sm2Curve) privateKeyToPublicKey(key *PrivateKey) *PublicKey { } } +// GenerateKeyFromScalar generates a private key from a scalar. The scalar will +// be reduced to the range [0, Order). func (c *sm2Curve) GenerateKeyFromScalar(scalar []byte) (*PrivateKey, error) { - if size := len(c.scalarOrderMinus1); len(scalar) > size { + if size := len(c.scalarOrder); len(scalar) > size { scalar = scalar[:size] } - m, err := bigmod.NewModulus(c.scalarOrderMinus1) + m, err := bigmod.NewModulus(c.scalarOrder) if err != nil { return nil, err } @@ -96,7 +103,7 @@ func (c *sm2Curve) GenerateKeyFromScalar(scalar []byte) (*PrivateKey, error) { if err != nil { return nil, err } - return c.NewPrivateKey(p.Bytes(m)) + return c.newPrivateKey(p.Bytes(m), false) } func (c *sm2Curve) NewPublicKey(key []byte) (*PublicKey, error) { @@ -141,7 +148,7 @@ func (c *sm2Curve) addPublicKeys(a, b *PublicKey) (*PublicKey, error) { } func (c *sm2Curve) addPrivateKeys(a, b *PrivateKey) (*PrivateKey, error) { - m, err := bigmod.NewModulus(c.scalarOrderMinus1) + m, err := bigmod.NewModulus(c.scalarOrder) if err != nil { return nil, err } @@ -240,12 +247,18 @@ func P256() Curve { return sm2P256 } var sm2P256 = &sm2Curve{ name: "sm2p256v1", newPoint: sm2ec.NewSM2P256Point, + scalarOrder: sm2P256Order, scalarOrderMinus1: sm2P256OrderMinus1, generator: sm2Generator, constantA: sm2ConstantA, constantB: sm2ConstantB, } +var sm2P256Order = []byte{ + 0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0x72, 0x03, 0xdf, 0x6b, 0x21, 0xc6, 0x05, 0x2b, + 0x53, 0xbb, 0xf4, 0x09, 0x39, 0xd5, 0x41, 0x23} var sm2P256OrderMinus1 = []byte{ 0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, diff --git a/ecdh/stealth_test.go b/ecdh/stealth_test.go index 2b28aae..f14b84d 100644 --- a/ecdh/stealth_test.go +++ b/ecdh/stealth_test.go @@ -3,6 +3,7 @@ package ecdh import ( "crypto/rand" "testing" + "time" "github.com/emmansun/gmsm/sm3" ) @@ -10,18 +11,18 @@ import ( // https://eips.ethereum.org/EIPS/eip-5564, but uses SM3 instead of Keccak256 // Generation - Generate stealth address from stealth meta-address -func generateStealthAddress(spendPub, viewPub *PublicKey) (ephemeralPub *PublicKey, stealth *PublicKey, err error) { +func generateStealthAddress(spendPub, viewPub *PublicKey) (ephemeralPub *PublicKey, stealth *PublicKey, viewTag byte, err error) { // generate ephemeral key pair ephemeralPriv, err := P256().GenerateKey(rand.Reader) if err != nil { - return nil, nil, err + return nil, nil, 0, err } ephemeralPub = ephemeralPriv.PublicKey() // compute shared secret key R, err := ephemeralPriv.SecretKey(viewPub) if err != nil { - return nil, nil, err + return nil, nil, 0, err } // the secret key is hashed @@ -30,20 +31,20 @@ func generateStealthAddress(spendPub, viewPub *PublicKey) (ephemeralPub *PublicK // multiply the hashed shared secret with the generator point shPriv, err := P256().GenerateKeyFromScalar(sh[:]) if err != nil { - return nil, nil, err + return nil, nil, 0, err } shPublic := shPriv.PublicKey() // compute the recipient's stealth public key stealth, err = shPublic.Add(spendPub) if err != nil { - return nil, nil, err + return nil, nil, 0, err } - return ephemeralPub, stealth, nil + return ephemeralPub, stealth, sh[0], nil } // Parsing - Locate one’s own stealth address -func checkStealthAddress(viewPriv *PrivateKey, spendPub, ephemeralPub, stealth *PublicKey) (bool, error) { +func checkStealthAddress(viewPriv *PrivateKey, spendPub, ephemeralPub, stealth *PublicKey, viewTag byte) (bool, error) { // compute shared secret key R, err := viewPriv.SecretKey(ephemeralPub) if err != nil { @@ -51,6 +52,9 @@ func checkStealthAddress(viewPriv *PrivateKey, spendPub, ephemeralPub, stealth * } // the secret key is hashed sh := sm3.Sum(R[1:]) + if sh[0] != viewTag { + return false, nil + } // multiply the hashed shared secret with the generator point shPriv, err := P256().GenerateKeyFromScalar(sh[:]) if err != nil { @@ -86,13 +90,13 @@ func computeStealthKey(spendPriv, viewPriv *PrivateKey, ephemeralPub *PublicKey) func testEIP5564StealthAddress(t *testing.T, spendPriv, viewPriv *PrivateKey) { t.Helper() - ephemeralPub, expectedStealth, err := generateStealthAddress(spendPriv.PublicKey(), viewPriv.PublicKey()) + ephemeralPub, expectedStealth, viewTag, err := generateStealthAddress(spendPriv.PublicKey(), viewPriv.PublicKey()) if err != nil { t.Fatalf("the recipient's stealth public key: failed to add public keys: %v", err) } - passed, err := checkStealthAddress(viewPriv, spendPriv.PublicKey(), ephemeralPub, expectedStealth) + passed, err := checkStealthAddress(viewPriv, spendPriv.PublicKey(), ephemeralPub, expectedStealth, viewTag) if err != nil { t.Fatal(err) } @@ -118,5 +122,19 @@ func TestEIP5564StealthAddress(t *testing.T) { if err != nil { t.Fatalf("failed to generate private key: %v", err) } - testEIP5564StealthAddress(t, privSpend, privView) + var timeout *time.Timer + + if testing.Short() { + timeout = time.NewTimer(50 * time.Millisecond) + } else { + timeout = time.NewTimer(5 * time.Second) + } + for { + select { + case <-timeout.C: + return + default: + } + testEIP5564StealthAddress(t, privSpend, privView) + } }