From d7e853277a8226e844c457d03ca7ba1e75cf88dd Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Fri, 25 Nov 2022 10:11:46 +0800 Subject: [PATCH] sm9: use bigmod instead of math/big --- sm9/bn256/bn_pair_test.go | 5 +- sm9/bn256/g1.go | 46 +++++++----- sm9/bn256/g1_test.go | 5 +- sm9/bn256/g2.go | 19 ++--- sm9/bn256/g2_test.go | 13 +++- sm9/bn256/gt.go | 51 ++++++++++++- sm9/bn256/gt_test.go | 13 ++++ sm9/sm9.go | 151 +++++++++++++++++++++++++++----------- sm9/sm9_key.go | 118 +++++++++++++++++++---------- sm9/sm9_test.go | 59 ++++++++++++--- 10 files changed, 357 insertions(+), 123 deletions(-) diff --git a/sm9/bn256/bn_pair_test.go b/sm9/bn256/bn_pair_test.go index 019b553..0fce722 100644 --- a/sm9/bn256/bn_pair_test.go +++ b/sm9/bn256/bn_pair_test.go @@ -53,7 +53,10 @@ func init() { func Test_Pairing_A2(t *testing.T) { pk := bigFromHex("0130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4") g2 := &G2{} - g2.ScalarBaseMult(pk) + _, err := g2.ScalarBaseMult(NormalizeScalar(pk.Bytes())) + if err != nil { + t.Fatal(err) + } ret := pairing(g2.p, curveGen) if ret.x != expected1.x || ret.y != expected1.y || ret.z != expected1.z { t.Errorf("not expected") diff --git a/sm9/bn256/g1.go b/sm9/bn256/g1.go index 74b1a0a..7e0feed 100644 --- a/sm9/bn256/g1.go +++ b/sm9/bn256/g1.go @@ -59,14 +59,15 @@ func RandomG1(r io.Reader) (*big.Int, *G1, error) { return nil, nil, err } - return k, new(G1).ScalarBaseMult(k), nil + g1, err := new(G1).ScalarBaseMult(NormalizeScalar(k.Bytes())) + return k, g1, err } func (g *G1) String() string { return "sm9.G1" + g.p.String() } -func normalizeScalar(scalar []byte) []byte { +func NormalizeScalar(scalar []byte) []byte { if len(scalar) == 32 { return scalar } @@ -78,16 +79,18 @@ func normalizeScalar(scalar []byte) []byte { return s.FillBytes(out) } -// ScalarBaseMult sets e to g*k where g is the generator of the group and then +// ScalarBaseMult sets e to scaler*g where g is the generator of the group and then // returns e. -func (e *G1) ScalarBaseMult(k *big.Int) *G1 { +func (e *G1) ScalarBaseMult(scalar []byte) (*G1, error) { + if len(scalar) != 32 { + return nil, errors.New("invalid scalar length") + } if e.p == nil { e.p = &curvePoint{} } //e.p.Mul(curveGen, k) - scalar := normalizeScalar(k.Bytes()) tables := e.generatorTable() // This is also a scalar multiplication with a four-bit window like in // ScalarMult, but in this case the doublings are precomputed. The value @@ -108,11 +111,11 @@ func (e *G1) ScalarBaseMult(k *big.Int) *G1 { e.p.Add(e.p, t) tableIndex-- } - return e + return e, nil } // ScalarMult sets e to a*k and then returns e. -func (e *G1) ScalarMult(a *G1, k *big.Int) *G1 { +func (e *G1) ScalarMult(a *G1, scalar []byte) (*G1, error) { if e.p == nil { e.p = &curvePoint{} } @@ -131,8 +134,7 @@ func (e *G1) ScalarMult(a *G1, k *big.Int) *G1 { // four-bit window: we double four times, and then add [0-15]P. t := &G1{NewCurvePoint()} e.p.SetInfinity() - scalarBytes := normalizeScalar(k.Bytes()) - for i, byte := range scalarBytes { + for i, byte := range scalar { // No need to double on the first iteration, as p is the identity at // this point, and [N]∞ = ∞. if i != 0 { @@ -152,7 +154,7 @@ func (e *G1) ScalarMult(a *G1, k *big.Int) *G1 { table.Select(t.p, windowValue) e.Add(e, t) } - return e + return e, nil } // Add sets e to a+b and then returns e. @@ -398,27 +400,37 @@ func (g1 *G1Curve) Params() *CurveParams { // normalizeScalar brings the scalar within the byte size of the order of the // curve, as expected by the nistec scalar multiplication functions. -func (curve *G1Curve) normalizeScalar(scalar []byte) *big.Int { +func (curve *G1Curve) normalizeScalar(scalar []byte) []byte { byteSize := (curve.params.N.BitLen() + 7) / 8 s := new(big.Int).SetBytes(scalar) if len(scalar) > byteSize { s.Mod(s, curve.params.N) } - return s + out := make([]byte, byteSize) + return s.FillBytes(out) } -func (g1 *G1Curve) ScalarBaseMult(k []byte) (*big.Int, *big.Int) { - scalar := g1.normalizeScalar(k) - res := g1.g.ScalarBaseMult(scalar).Marshal() +func (g1 *G1Curve) ScalarBaseMult(scalar []byte) (*big.Int, *big.Int) { + scalar = g1.normalizeScalar(scalar) + p, err := g1.g.ScalarBaseMult(scalar) + if err != nil { + panic("sm9: g1 rejected normalized scalar") + } + res := p.Marshal() return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:]) } -func (g1 *G1Curve) ScalarMult(Bx, By *big.Int, k []byte) (*big.Int, *big.Int) { +func (g1 *G1Curve) ScalarMult(Bx, By *big.Int, scalar []byte) (*big.Int, *big.Int) { a, err := g1.pointFromAffine(Bx, By) if err != nil { panic("sm9: ScalarMult was called on an invalid point") } - res := g1.g.ScalarMult(a, new(big.Int).SetBytes(k)).Marshal() + scalar = g1.normalizeScalar(scalar) + p, err := g1.g.ScalarMult(a, scalar) + if err != nil { + panic("sm9: g1 rejected normalized scalar") + } + res := p.Marshal() return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:]) } diff --git a/sm9/bn256/g1_test.go b/sm9/bn256/g1_test.go index 273558b..acc8bb8 100644 --- a/sm9/bn256/g1_test.go +++ b/sm9/bn256/g1_test.go @@ -177,7 +177,10 @@ func TestG1ScaleMult(t *testing.T) { t.Errorf("not same") } - e3.ScalarMult(Gen1, k) + _, err = e3.ScalarMult(Gen1, NormalizeScalar(k.Bytes())) + if err != nil { + t.Fatal(err) + } e3.p.MakeAffine() if !e.Equal(e3) { diff --git a/sm9/bn256/g2.go b/sm9/bn256/g2.go index 9dc3d2c..ba13436 100644 --- a/sm9/bn256/g2.go +++ b/sm9/bn256/g2.go @@ -47,8 +47,8 @@ func RandomG2(r io.Reader) (*big.Int, *G2, error) { if err != nil { return nil, nil, err } - - return k, new(G2).ScalarBaseMult(k), nil + g2, err := new(G2).ScalarBaseMult(NormalizeScalar(k.Bytes())) + return k, g2, err } func (e *G2) String() string { @@ -57,13 +57,15 @@ func (e *G2) String() string { // ScalarBaseMult sets e to g*k where g is the generator of the group and then // returns out. -func (e *G2) ScalarBaseMult(k *big.Int) *G2 { +func (e *G2) ScalarBaseMult(scalar []byte) (*G2, error) { + if len(scalar) != 32 { + return nil, errors.New("invalid scalar length") + } if e.p == nil { e.p = &twistPoint{} } //e.p.Mul(twistGen, k) - scalar := normalizeScalar(k.Bytes()) tables := e.generatorTable() // This is also a scalar multiplication with a four-bit window like in // ScalarMult, but in this case the doublings are precomputed. The value @@ -85,11 +87,11 @@ func (e *G2) ScalarBaseMult(k *big.Int) *G2 { tableIndex-- } - return e + return e, nil } // ScalarMult sets e to a*k and then returns e. -func (e *G2) ScalarMult(a *G2, k *big.Int) *G2 { +func (e *G2) ScalarMult(a *G2, scalar []byte) (*G2, error) { if e.p == nil { e.p = &twistPoint{} } @@ -108,8 +110,7 @@ func (e *G2) ScalarMult(a *G2, k *big.Int) *G2 { // four-bit window: we double four times, and then add [0-15]P. t := &G2{NewTwistPoint()} e.p.SetInfinity() - scalarBytes := normalizeScalar(k.Bytes()) - for i, byte := range scalarBytes { + for i, byte := range scalar { // No need to double on the first iteration, as p is the identity at // this point, and [N]∞ = ∞. if i != 0 { @@ -129,7 +130,7 @@ func (e *G2) ScalarMult(a *G2, k *big.Int) *G2 { table.Select(t.p, windowValue) e.Add(e, t) } - return e + return e, nil } // Add sets e to a+b and then returns e. diff --git a/sm9/bn256/g2_test.go b/sm9/bn256/g2_test.go index d5fb608..376e662 100644 --- a/sm9/bn256/g2_test.go +++ b/sm9/bn256/g2_test.go @@ -14,7 +14,10 @@ func TestG2(t *testing.T) { } ma := Ga.Marshal() - Gb := new(G2).ScalarBaseMult(k) + Gb, err := new(G2).ScalarBaseMult(NormalizeScalar(k.Bytes())) + if err != nil { + t.Fatal(err) + } mb := Gb.Marshal() if !bytes.Equal(ma, mb) { @@ -86,7 +89,10 @@ func TestScaleMult(t *testing.T) { e3.p.Mul(twistGen, k) e3.p.MakeAffine() - e2.ScalarMult(Gen2, k) + _, err = e2.ScalarMult(Gen2, NormalizeScalar(k.Bytes())) + if err != nil { + t.Fatal(err) + } e2.p.MakeAffine() if !e.Equal(e2) { t.Errorf("not same") @@ -110,10 +116,11 @@ func TestG2AddNeg(t *testing.T) { func BenchmarkG2(b *testing.B) { x, _ := rand.Int(rand.Reader, Order) + xb := NormalizeScalar(x.Bytes()) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - new(G2).ScalarBaseMult(x) + new(G2).ScalarBaseMult(xb) } } diff --git a/sm9/bn256/gt.go b/sm9/bn256/gt.go index cab9390..004acbe 100644 --- a/sm9/bn256/gt.go +++ b/sm9/bn256/gt.go @@ -241,8 +241,10 @@ func GenerateGTFieldTable(basePoint *GT) *[32 * 2]GTFieldTable { } // ScalarBaseMultGT compute basepoint^r with precomputed table -func ScalarBaseMultGT(tables *[32 * 2]GTFieldTable, r *big.Int) *GT { - scalar := normalizeScalar(r.Bytes()) +func ScalarBaseMultGT(tables *[32 * 2]GTFieldTable, scalar []byte) (*GT, error) { + if len(scalar) != 32 { + return nil, errors.New("invalid scalar length") + } // This is also a scalar multiplication with a four-bit window like in // ScalarMult, but in this case the doublings are precomputed. The value // [windowValue]G added at iteration k would normally get doubled @@ -263,5 +265,48 @@ func ScalarBaseMultGT(tables *[32 * 2]GTFieldTable, r *big.Int) *GT { e.Add(e, t) tableIndex-- } - return e + return e, nil +} + +// ScalarMultGT compute a^scalar +func ScalarMultGT(a *GT, scalar []byte) (*GT, error) { + var table GTFieldTable + + table[0] = >{} + table[0].Set(a) + for i := 1; i < 15; i += 2 { + table[i] = >{} + table[i].p = &gfP12{} + table[i].p.Square(table[i/2].p) + + table[i+1] = >{} + table[i+1].p = &gfP12{} + table[i+1].Add(table[i], a) + } + + e, t := >{}, >{} + e.SetOne() + t.SetOne() + + for i, byte := range scalar { + // No need to double on the first iteration, as p is the identity at + // this point, and [N]∞ = ∞. + if i != 0 { + e.p.Square(e.p) + e.p.Square(e.p) + e.p.Square(e.p) + e.p.Square(e.p) + } + windowValue := byte >> 4 + table.Select(t, windowValue) + e.Add(e, t) + e.p.Square(e.p) + e.p.Square(e.p) + e.p.Square(e.p) + e.p.Square(e.p) + windowValue = byte & 0b1111 + table.Select(t, windowValue) + e.Add(e, t) + } + return e, nil } diff --git a/sm9/bn256/gt_test.go b/sm9/bn256/gt_test.go index 943ecf2..b64d6f4 100644 --- a/sm9/bn256/gt_test.go +++ b/sm9/bn256/gt_test.go @@ -24,6 +24,19 @@ func TestGT(t *testing.T) { if !bytes.Equal(ma, mb) { t.Fatal("bytes are different") } + + _, err = Gb.Unmarshal((>{gfP12Gen}).Marshal()) + if err != nil { + t.Fatal("unmarshal not ok") + } + Gc, err := ScalarMultGT(Gb, k.Bytes()) + if err != nil { + t.Fatal(err) + } + mc := Gc.Marshal() + if !bytes.Equal(ma, mc) { + t.Fatal("bytes are different") + } } func BenchmarkGT(b *testing.B) { diff --git a/sm9/sm9.go b/sm9/sm9.go index ccba988..24e2c1e 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -10,6 +10,7 @@ import ( "io" "math/big" + "github.com/emmansun/gmsm/internal/bigmod" "github.com/emmansun/gmsm/internal/subtle" "github.com/emmansun/gmsm/kdf" "github.com/emmansun/gmsm/sm3" @@ -20,6 +21,10 @@ import ( // SM9 ASN.1 format reference: Information security technology - SM9 cryptographic algorithm application specification +// OrderNat is the Nat presentation of Order +var OrderNat = bigmod.NewModulusFromBig(bn256.Order) +var OrderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes() + var bigOne = big.NewInt(1) type hashMode byte @@ -57,6 +62,7 @@ func hash(z []byte, h hashMode) *big.Int { ct++ md.Reset() } + //TODO: how to rewrite this part with nat? k := new(big.Int).SetBytes(ha[:40]) n := new(big.Int).Sub(bn256.Order, bigOne) k.Mod(k, n) @@ -72,48 +78,70 @@ func hashH2(z []byte) *big.Int { return hash(z, H2) } -// randFieldElement returns a random element of the order of the given -// curve using the procedure given in FIPS 186-4, Appendix B.5.1. -func randFieldElement(rand io.Reader) (k *big.Int, err error) { - b := make([]byte, 40) // (256 + 64) / 8 - _, err = io.ReadFull(rand, b) - if err != nil { - return - } +func randomScalar(rand io.Reader) (k *bigmod.Nat, err error) { + k = bigmod.NewNat() + for { + b := make([]byte, OrderNat.Size()) + if _, err = io.ReadFull(rand, b); err != nil { + return + } - k = new(big.Int).SetBytes(b) - n := new(big.Int).Sub(bn256.Order, bigOne) - k.Mod(k, n) - k.Add(k, bigOne) + // Mask off any excess bits to increase the chance of hitting a value in + // (0, N). These are the most dangerous lines in the package and maybe in + // the library: a single bit of bias in the selection of nonces would likely + // lead to key recovery, but no tests would fail. Look but DO NOT TOUCH. + if excess := len(b)*8 - OrderNat.BitLen(); excess > 0 { + // Just to be safe, assert that this only happens for the one curve that + // doesn't have a round number of bits. + if excess != 0 { + panic("sm9: internal error: unexpectedly masking off bits") + } + b[0] >>= excess + } + + // FIPS 186-4 makes us check k <= N - 2 and then add one. + // Checking 0 < k <= N - 1 is strictly equivalent. + // None of this matters anyway because the chance of selecting + // zero is cryptographically negligible. + if _, err = k.SetBytes(b, OrderNat); err == nil && k.IsZero() == 0 { + break + } + } return } // Sign signs a hash (which should be the result of hashing a larger message) // using the user dsa key. It returns the signature as a pair of h and s. func Sign(rand io.Reader, priv *SignPrivateKey, hash []byte) (h *big.Int, s *bn256.G1, err error) { - var r *big.Int + var ( + r *bigmod.Nat + w *bn256.GT + hNat *bigmod.Nat + ) for { - r, err = randFieldElement(rand) + r, err = randomScalar(rand) if err != nil { return } - w := priv.SignMasterPublicKey.ScalarBaseMult(r) + w, err = priv.SignMasterPublicKey.ScalarBaseMult(r.Bytes(OrderNat)) + if err != nil { + return + } var buffer []byte buffer = append(buffer, hash...) buffer = append(buffer, w.Marshal()...) h = hashH2(buffer) - - l := new(big.Int).Sub(r, h) - - if l.Sign() < 0 { - l.Add(l, bn256.Order) + hNat, err = bigmod.NewNat().SetBytes(h.Bytes(), OrderNat) + if err != nil { + return } + r.Sub(hNat, OrderNat) - if l.Sign() != 0 { - s = new(bn256.G1).ScalarMult(priv.PrivateKey, l) + if r.IsZero() == 0 { + s, err = new(bn256.G1).ScalarMult(priv.PrivateKey, r.Bytes(OrderNat)) break } } @@ -129,7 +157,7 @@ func (priv *SignPrivateKey) Sign(rand io.Reader, hash []byte, opts crypto.Signer return nil, err } - hBytes := make([]byte, 32) + hBytes := make([]byte, OrderNat.Size()) h.FillBytes(hBytes) var b cryptobyte.Builder @@ -156,7 +184,15 @@ func Verify(pub *SignMasterPublicKey, uid []byte, hid byte, hash []byte, h *big. return false } - t := pub.ScalarBaseMult(h) + hNat, err := bigmod.NewNat().SetBytes(h.Bytes(), OrderNat) + if err != nil { + return false + } + + t, err := pub.ScalarBaseMult(hNat.Bytes(OrderNat)) + if err != nil { + return false + } // user sign public key p generation p := pub.GenerateUserPublicKey(uid, hid) @@ -210,17 +246,26 @@ func (pub *SignMasterPublicKey) Verify(uid []byte, hid byte, hash, sig []byte) b // WrapKey generate and wrap key with reciever's uid and system hid func WrapKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, kLen int) (key []byte, cipher *bn256.G1, err error) { q := pub.GenerateUserPublicKey(uid, hid) - var r *big.Int + var ( + r *bigmod.Nat + w *bn256.GT + ) for { - r, err = randFieldElement(rand) + r, err = randomScalar(rand) if err != nil { return } - cipher = new(bn256.G1).ScalarMult(q, r) - - w := pub.ScalarBaseMult(r) + rBytes := r.Bytes(OrderNat) + cipher, err = new(bn256.G1).ScalarMult(q, rBytes) + if err != nil { + return + } + w, err = pub.ScalarBaseMult(rBytes) + if err != nil { + return + } var buffer []byte buffer = append(buffer, cipher.Marshal()...) buffer = append(buffer, w.Marshal()...) @@ -463,7 +508,7 @@ type KeyExchange struct { privateKey *EncryptPrivateKey // owner's encryption private key uid []byte // owner uid peerUID []byte // peer uid - r *big.Int // random which will be used to compute secret + r *bigmod.Nat // random which will be used to compute secret secret *bn256.G1 // generated secret which will be passed to peer peerSecret *bn256.G1 // received peer's secret g1 *bn256.GT // internal state which will be used when compute the key and signature @@ -485,7 +530,7 @@ func NewKeyExchange(priv *EncryptPrivateKey, uid, peerUID []byte, keyLen int, ge // Destroy clear all internal state and Ephemeral private/public keys func (ke *KeyExchange) Destroy() { if ke.r != nil { - ke.r.SetInt64(0) + ke.r.SetBytes([]byte{0}, OrderNat) } if ke.g1 != nil { ke.g1.SetOne() @@ -498,16 +543,19 @@ func (ke *KeyExchange) Destroy() { } } -func initKeyExchange(ke *KeyExchange, hid byte, r *big.Int) { +func initKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat) { pubB := ke.privateKey.GenerateUserPublicKey(ke.peerUID, hid) ke.r = r - rA := new(bn256.G1).ScalarMult(pubB, ke.r) + rA, err := new(bn256.G1).ScalarMult(pubB, ke.r.Bytes(OrderNat)) + if err != nil { + panic(err) + } ke.secret = rA } // InitKeyExchange generate random with responder uid, for initiator's step A1-A4 func (ke *KeyExchange) InitKeyExchange(rand io.Reader, hid byte) (*bn256.G1, error) { - r, err := randFieldElement(rand) + r, err := randomScalar(rand) if err != nil { return nil, err } @@ -559,20 +607,33 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) { return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil } -func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*bn256.G1, []byte, error) { +func respondKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat, rA *bn256.G1) (*bn256.G1, []byte, error) { if !rA.IsOnCurve() { return nil, nil, errors.New("sm9: invalid initiator's ephemeral public key") } ke.peerSecret = rA pubA := ke.privateKey.GenerateUserPublicKey(ke.peerUID, hid) ke.r = r - rB := new(bn256.G1).ScalarMult(pubA, r) + rBytes := r.Bytes(OrderNat) + rB, err := new(bn256.G1).ScalarMult(pubA, rBytes) + if err != nil { + return nil, nil, err + } ke.secret = rB ke.g1 = bn256.Pair(ke.peerSecret, ke.privateKey.PrivateKey) ke.g3 = &bn256.GT{} - ke.g3.ScalarMult(ke.g1, r) - ke.g2 = ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(r) + g3, err := bn256.ScalarMultGT(ke.g1, rBytes) + if err != nil { + return nil, nil, err + } + ke.g3 = g3 + + g2, err := ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(rBytes) + if err != nil { + return nil, nil, err + } + ke.g2 = g2 if !ke.genSignature { return ke.secret, nil, nil @@ -583,7 +644,7 @@ func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*b // RepondKeyExchange when responder receive rA, for responder's step B1-B7 func (ke *KeyExchange) RepondKeyExchange(rand io.Reader, hid byte, rA *bn256.G1) (*bn256.G1, []byte, error) { - r, err := randFieldElement(rand) + r, err := randomScalar(rand) if err != nil { return nil, nil, err } @@ -597,10 +658,18 @@ func (ke *KeyExchange) ConfirmResponder(rB *bn256.G1, sB []byte) ([]byte, []byte } // step 5 ke.peerSecret = rB - ke.g1 = ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(ke.r) + g1, err := ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(ke.r.Bytes(OrderNat)) + if err != nil { + return nil, nil, err + } + ke.g1 = g1 ke.g2 = bn256.Pair(ke.peerSecret, ke.privateKey.PrivateKey) ke.g3 = &bn256.GT{} - ke.g3.ScalarMult(ke.g2, ke.r) + g3, err := bn256.ScalarMultGT(ke.g2, ke.r.Bytes(OrderNat)) + if err != nil { + return nil, nil, err + } + ke.g3 = g3 // step 6, verify signature if len(sB) > 0 { signature := ke.sign(false, 0x82) diff --git a/sm9/sm9_key.go b/sm9/sm9_key.go index a97a481..e5b466a 100644 --- a/sm9/sm9_key.go +++ b/sm9/sm9_key.go @@ -8,6 +8,7 @@ import ( "math/big" "sync" + "github.com/emmansun/gmsm/internal/bigmod" "github.com/emmansun/gmsm/sm9/bn256" "golang.org/x/crypto/cryptobyte" cryptobyte_asn1 "golang.org/x/crypto/cryptobyte/asn1" @@ -57,14 +58,19 @@ type EncryptPrivateKey struct { // GenerateSignMasterKey generates a master public and private key pair for DSA usage. func GenerateSignMasterKey(rand io.Reader) (*SignMasterPrivateKey, error) { - k, err := randFieldElement(rand) + k, err := randomScalar(rand) + if err != nil { + return nil, err + } + kBytes := k.Bytes(OrderNat) + p, err := new(bn256.G2).ScalarBaseMult(kBytes) if err != nil { return nil, err } priv := new(SignMasterPrivateKey) - priv.D = k - priv.MasterPublicKey = new(bn256.G2).ScalarBaseMult(k) + priv.D = new(big.Int).SetBytes(kBytes) + priv.MasterPublicKey = p return priv, nil } @@ -96,7 +102,11 @@ func (master *SignMasterPrivateKey) UnmarshalASN1(der []byte) error { return errors.New("sm9: invalid sign master private key asn1 data") } master.D = d - master.MasterPublicKey = new(bn256.G2).ScalarBaseMult(d) + p, err := new(bn256.G2).ScalarBaseMult(bn256.NormalizeScalar(d.Bytes())) + if err != nil { + return err + } + master.MasterPublicKey = p return nil } @@ -107,17 +117,32 @@ func (master *SignMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*Sign id = append(id, hid) t1 := hashH1(id) - t1.Add(t1, master.D) - if t1.Sign() == 0 { + + t1Nat, err := bigmod.NewNat().SetBytes(t1.Bytes(), OrderNat) + if err != nil { + return nil, err + } + + d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), OrderNat) + if err != nil { + return nil, err + } + + t1Nat.Add(d, OrderNat) + if t1Nat.IsZero() == 1 { return nil, errors.New("sm9: need to re-generate sign master private key") } - t1 = fermatInverse(t1, bn256.Order) - t2 := new(big.Int).Mul(t1, master.D) - t2.Mod(t2, bn256.Order) + + t1Nat = bigmod.NewNat().Exp(t1Nat, OrderMinus2, OrderNat) + t1Nat.Mul(d, OrderNat) priv := new(SignPrivateKey) priv.SignMasterPublicKey = master.SignMasterPublicKey - priv.PrivateKey = new(bn256.G1).ScalarBaseMult(t2) + g1, err := new(bn256.G1).ScalarBaseMult(t1Nat.Bytes(OrderNat)) + if err != nil { + return nil, err + } + priv.PrivateKey = g1 return priv, nil } @@ -144,9 +169,9 @@ func (pub *SignMasterPublicKey) generatorTable() *[32 * 2]bn256.GTFieldTable { // ScalarBaseMult compute basepoint^r with precomputed table // The base point = pair(Gen1, ) -func (pub *SignMasterPublicKey) ScalarBaseMult(r *big.Int) *bn256.GT { +func (pub *SignMasterPublicKey) ScalarBaseMult(scalar []byte) (*bn256.GT, error) { tables := pub.generatorTable() - return bn256.ScalarBaseMultGT(tables, r) + return bn256.ScalarBaseMultGT(tables, scalar) } // GenerateUserPublicKey generate user sign public key @@ -155,7 +180,10 @@ func (pub *SignMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *bn2 buffer = append(buffer, uid...) buffer = append(buffer, hid) h1 := hashH1(buffer) - p := new(bn256.G2).ScalarBaseMult(h1) + p, err := new(bn256.G2).ScalarBaseMult(bn256.NormalizeScalar(h1.Bytes())) + if err != nil { + panic(err) + } p.Add(p, pub.MasterPublicKey) return p } @@ -326,14 +354,19 @@ func (priv *SignPrivateKey) UnmarshalASN1(der []byte) error { // GenerateEncryptMasterKey generates a master public and private key pair for encryption usage. func GenerateEncryptMasterKey(rand io.Reader) (*EncryptMasterPrivateKey, error) { - k, err := randFieldElement(rand) + k, err := randomScalar(rand) if err != nil { return nil, err } + kBytes := k.Bytes(OrderNat) priv := new(EncryptMasterPrivateKey) - priv.D = k - priv.MasterPublicKey = new(bn256.G1).ScalarBaseMult(k) + priv.D = new(big.Int).SetBytes(kBytes) + p, err := new(bn256.G1).ScalarBaseMult(kBytes) + if err != nil { + panic(err) + } + priv.MasterPublicKey = p return priv, nil } @@ -344,17 +377,32 @@ func (master *EncryptMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*E id = append(id, hid) t1 := hashH1(id) - t1.Add(t1, master.D) - if t1.Sign() == 0 { + + t1Nat, err := bigmod.NewNat().SetBytes(t1.Bytes(), OrderNat) + if err != nil { + return nil, err + } + + d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), OrderNat) + if err != nil { + return nil, err + } + + t1Nat.Add(d, OrderNat) + if t1Nat.IsZero() == 1 { return nil, errors.New("sm9: need to re-generate encrypt master private key") } - t1 = fermatInverse(t1, bn256.Order) - t2 := new(big.Int).Mul(t1, master.D) - t2.Mod(t2, bn256.Order) + + t1Nat = bigmod.NewNat().Exp(t1Nat, OrderMinus2, OrderNat) + t1Nat.Mul(d, OrderNat) priv := new(EncryptPrivateKey) priv.EncryptMasterPublicKey = master.EncryptMasterPublicKey - priv.PrivateKey = new(bn256.G2).ScalarBaseMult(t2) + p, err := new(bn256.G2).ScalarBaseMult(t1Nat.Bytes(OrderNat)) + if err != nil { + panic(err) + } + priv.PrivateKey = p return priv, nil } @@ -392,7 +440,11 @@ func (master *EncryptMasterPrivateKey) UnmarshalASN1(der []byte) error { return errors.New("sm9: invalid encrypt master private key asn1 data") } master.D = d - master.MasterPublicKey = new(bn256.G1).ScalarBaseMult(d) + p, err := new(bn256.G1).ScalarBaseMult(bn256.NormalizeScalar(d.Bytes())) + if err != nil { + return err + } + master.MasterPublicKey = p return nil } @@ -413,9 +465,9 @@ func (pub *EncryptMasterPublicKey) generatorTable() *[32 * 2]bn256.GTFieldTable // ScalarBaseMult compute basepoint^r with precomputed table. // The base point = pair(, Gen2) -func (pub *EncryptMasterPublicKey) ScalarBaseMult(r *big.Int) *bn256.GT { +func (pub *EncryptMasterPublicKey) ScalarBaseMult(scalar []byte) (*bn256.GT, error) { tables := pub.generatorTable() - return bn256.ScalarBaseMultGT(tables, r) + return bn256.ScalarBaseMultGT(tables, scalar) } // GenerateUserPublicKey generate user encrypt public key @@ -424,7 +476,10 @@ func (pub *EncryptMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) * buffer = append(buffer, uid...) buffer = append(buffer, hid) h1 := hashH1(buffer) - p := new(bn256.G1).ScalarBaseMult(h1) + p, err := new(bn256.G1).ScalarBaseMult(bn256.NormalizeScalar(h1.Bytes())) + if err != nil { + panic(err) + } p.Add(p, pub.MasterPublicKey) return p } @@ -554,14 +609,3 @@ func (priv *EncryptPrivateKey) UnmarshalASN1(der []byte) error { } return nil } - -// fermatInverse calculates the inverse of k in GF(P) using Fermat's method -// (exponentiation modulo P - 2, per Euler's theorem). This has better -// constant-time properties than Euclid's method (implemented in -// math/big.Int.ModInverse and FIPS 186-4, Appendix C.1) although math/big -// itself isn't strictly constant-time so it's not perfect. -func fermatInverse(k, N *big.Int) *big.Int { - two := big.NewInt(2) - nMinus2 := new(big.Int).Sub(N, two) - return new(big.Int).Exp(k, nMinus2, N) -} diff --git a/sm9/sm9_test.go b/sm9/sm9_test.go index b8157ce..a8b64c2 100644 --- a/sm9/sm9_test.go +++ b/sm9/sm9_test.go @@ -6,6 +6,7 @@ import ( "math/big" "testing" + "github.com/emmansun/gmsm/internal/bigmod" "github.com/emmansun/gmsm/internal/subtle" "github.com/emmansun/gmsm/kdf" "github.com/emmansun/gmsm/sm3" @@ -98,12 +99,19 @@ func TestSignSM9Sample(t *testing.T) { masterKey := new(SignMasterPrivateKey) masterKey.D = bigFromHex("0130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4") - masterKey.MasterPublicKey = new(bn256.G2).ScalarBaseMult(masterKey.D) + p, err := new(bn256.G2).ScalarBaseMult(bn256.NormalizeScalar(masterKey.D.Bytes())) + if err != nil { + t.Fatal(err) + } + masterKey.MasterPublicKey = p userKey, err := masterKey.GenerateUserKey(uid, hid) if err != nil { t.Fatal(err) } - w := userKey.SignMasterPublicKey.ScalarBaseMult(r) + w, err := userKey.SignMasterPublicKey.ScalarBaseMult(bn256.NormalizeScalar(r.Bytes())) + if err != nil { + t.Fatal(err) + } var buffer []byte buffer = append(buffer, hash...) @@ -120,7 +128,10 @@ func TestSignSM9Sample(t *testing.T) { l.Add(l, bn256.Order) } - s := new(bn256.G1).ScalarMult(userKey.PrivateKey, l) + s, err := new(bn256.G1).ScalarMult(userKey.PrivateKey, bn256.NormalizeScalar(l.Bytes())) + if err != nil { + t.Fatal(err) + } if hex.EncodeToString(s.MarshalUncompressed()) != expectedS { t.Fatal("not same S") @@ -137,7 +148,11 @@ func TestKeyExchangeSample(t *testing.T) { masterKey := new(EncryptMasterPrivateKey) masterKey.D = bigFromHex("02E65B0762D042F51F0D23542B13ED8CFA2E9A0E7206361E013A283905E31F") - masterKey.MasterPublicKey = new(bn256.G1).ScalarBaseMult(masterKey.D) + p, err := new(bn256.G1).ScalarBaseMult(bn256.NormalizeScalar(masterKey.D.Bytes())) + if err != nil { + t.Fatal(err) + } + masterKey.MasterPublicKey = p if hex.EncodeToString(masterKey.MasterPublicKey.Marshal()) != expectedPube { t.Errorf("not expected master public key") @@ -162,14 +177,22 @@ func TestKeyExchangeSample(t *testing.T) { responder.Destroy() }() // A1-A4 - initKeyExchange(initiator, hid, bigFromHex("5879DD1D51E175946F23B1B41E93BA31C584AE59A426EC1046A4D03B06C8")) + k, err := bigmod.NewNat().SetBytes(bigFromHex("5879DD1D51E175946F23B1B41E93BA31C584AE59A426EC1046A4D03B06C8").Bytes(), OrderNat) + if err != nil { + t.Fatal(err) + } + initKeyExchange(initiator, hid, k) if hex.EncodeToString(initiator.secret.Marshal()) != "7cba5b19069ee66aa79d490413d11846b9ba76dd22567f809cf23b6d964bb265a9760c99cb6f706343fed05637085864958d6c90902aba7d405fbedf7b781599" { t.Fatal("not same") } // B1 - B7 - rB, sigB, err := respondKeyExchange(responder, hid, bigFromHex("018B98C44BEF9F8537FB7D071B2C928B3BC65BD3D69E1EEE213564905634FE"), initiator.secret) + k, err = bigmod.NewNat().SetBytes(bigFromHex("018B98C44BEF9F8537FB7D071B2C928B3BC65BD3D69E1EEE213564905634FE").Bytes(), OrderNat) + if err != nil { + t.Fatal(err) + } + rB, sigB, err := respondKeyExchange(responder, hid, k, initiator.secret) if err != nil { t.Fatal(err) } @@ -403,7 +426,11 @@ func TestWrapKeySM9Sample(t *testing.T) { masterKey := new(EncryptMasterPrivateKey) masterKey.D = bigFromHex("01EDEE3778F441F8DEA3D9FA0ACC4E07EE36C93F9A08618AF4AD85CEDE1C22") - masterKey.MasterPublicKey = new(bn256.G1).ScalarBaseMult(masterKey.D) + p, err := new(bn256.G1).ScalarBaseMult(bn256.NormalizeScalar(masterKey.D.Bytes())) + if err != nil { + t.Fatal(err) + } + masterKey.MasterPublicKey = p if hex.EncodeToString(masterKey.MasterPublicKey.Marshal()) != expectedMasterPublicKey { t.Errorf("not expected master public key") } @@ -425,7 +452,10 @@ func TestWrapKeySM9Sample(t *testing.T) { } var r *big.Int = bigFromHex("74015F8489C01EF4270456F9E6475BFB602BDE7F33FD482AB4E3684A6722") - cipher := new(bn256.G1).ScalarMult(q, r) + cipher, err := new(bn256.G1).ScalarMult(q, bn256.NormalizeScalar(r.Bytes())) + if err != nil { + t.Fatal(err) + } if hex.EncodeToString(cipher.Marshal()) != expectedCipher { t.Errorf("not expected cipher") } @@ -465,7 +495,11 @@ func TestEncryptSM9Sample(t *testing.T) { masterKey := new(EncryptMasterPrivateKey) masterKey.D = bigFromHex("01EDEE3778F441F8DEA3D9FA0ACC4E07EE36C93F9A08618AF4AD85CEDE1C22") - masterKey.MasterPublicKey = new(bn256.G1).ScalarBaseMult(masterKey.D) + p, err := new(bn256.G1).ScalarBaseMult(bn256.NormalizeScalar(masterKey.D.Bytes())) + if err != nil { + t.Fatal(err) + } + masterKey.MasterPublicKey = p if hex.EncodeToString(masterKey.MasterPublicKey.Marshal()) != expectedMasterPublicKey { t.Errorf("not expected master public key") } @@ -487,7 +521,10 @@ func TestEncryptSM9Sample(t *testing.T) { } var r *big.Int = bigFromHex("AAC0541779C8FC45E3E2CB25C12B5D2576B2129AE8BB5EE2CBE5EC9E785C") - cipher := new(bn256.G1).ScalarMult(q, r) + cipher, err := new(bn256.G1).ScalarMult(q, bn256.NormalizeScalar(r.Bytes())) + if err != nil { + t.Fatal(err) + } if hex.EncodeToString(cipher.Marshal()) != expectedCipher { t.Errorf("not expected cipher") } @@ -501,7 +538,7 @@ func TestEncryptSM9Sample(t *testing.T) { buffer = append(buffer, uid...) key := kdf.Kdf(sm3.New(), buffer, len(plaintext)+32) - + if hex.EncodeToString(key) != expectedKey { t.Errorf("not expected key") }