diff --git a/sm9/curve.go b/sm9/curve.go index 583e226..bfe8e41 100644 --- a/sm9/curve.go +++ b/sm9/curve.go @@ -1,6 +1,9 @@ package sm9 -import "math/big" +import ( + "crypto/subtle" + "math/big" +) // curvePoint implements the elliptic curve y²=x³+5. Points are kept in Jacobian // form and t=z² when valid. G₁ is the set of points of this curve on GF(p). @@ -49,6 +52,18 @@ func (c *curvePoint) IsOnCurve() bool { return *y2 == *x3 } +func NewCurvePoint() *curvePoint { + c := &curvePoint{} + c.SetInfinity() + return c +} + +func NewCurveGenerator() *curvePoint { + c := &curvePoint{} + c.Set(curveGen) + return c +} + func (c *curvePoint) SetInfinity() { c.x = *zero c.y = *one @@ -226,3 +241,30 @@ func (c *curvePoint) Neg(a *curvePoint) { c.z.Set(&a.z) c.t = *zero } + +// Select sets q to p1 if cond == 1, and to p2 if cond == 0. +func (q *curvePoint) Select(p1, p2 *curvePoint, cond int) *curvePoint { + q.x.Select(&p1.x, &p2.x, cond) + q.y.Select(&p1.y, &p2.y, cond) + q.z.Select(&p1.z, &p2.z, cond) + q.t.Select(&p1.t, &p2.t, cond) + return q +} + +// A curvePointTable holds the first 15 multiples of a point at offset -1, so [1]P +// is at table[0], [15]P is at table[14], and [0]P is implicitly the identity +// point. +type curvePointTable [15]*curvePoint + +// Select selects the n-th multiple of the table base point into p. It works in +// constant time by iterating over every entry of the table. n must be in [0, 15]. +func (table *curvePointTable) Select(p *curvePoint, n uint8) { + if n >= 16 { + panic("sm9: internal error: curvePointTable called with out-of-bounds value") + } + p.SetInfinity() + for i := uint8(1); i < 16; i++ { + cond := subtle.ConstantTimeByteEq(i, n) + p.Select(table[i-1], p, cond) + } +} diff --git a/sm9/g1.go b/sm9/g1.go index 61d8a96..a535b68 100644 --- a/sm9/g1.go +++ b/sm9/g1.go @@ -6,6 +6,7 @@ import ( "io" "math/big" "math/bits" + "sync" ) func randomK(r io.Reader) (k *big.Int, err error) { @@ -26,6 +27,31 @@ type G1 struct { //Gen1 is the generator of G1. var Gen1 = &G1{curveGen} +var g1GeneratorTable *[32 * 2]curvePointTable +var g1GeneratorTableOnce sync.Once + +func (g *G1) generatorTable() *[32 * 2]curvePointTable { + g1GeneratorTableOnce.Do(func() { + g1GeneratorTable = new([32 * 2]curvePointTable) + base := NewCurveGenerator() + for i := 0; i < 32*2; i++ { + g1GeneratorTable[i][0] = &curvePoint{} + g1GeneratorTable[i][0].Set(base) + for j := 1; j < 15; j += 2 { + g1GeneratorTable[i][j] = &curvePoint{} + g1GeneratorTable[i][j].Double(g1GeneratorTable[i][j/2]) + g1GeneratorTable[i][j+1] = &curvePoint{} + g1GeneratorTable[i][j+1].Add(g1GeneratorTable[i][j], base) + } + base.Double(base) + base.Double(base) + base.Double(base) + base.Double(base) + } + }) + return g1GeneratorTable +} + // RandomG1 returns x and g₁ˣ where x is a random, non-zero number read from r. func RandomG1(r io.Reader) (*big.Int, *G1, error) { k, err := randomK(r) @@ -40,13 +66,48 @@ func (g *G1) String() string { return "sm9.G1" + g.p.String() } +func normalizeScalar(scalar []byte) []byte { + if len(scalar) == 32 { + return scalar + } + s := new(big.Int).SetBytes(scalar) + if len(scalar) > 32 { + s.Mod(s, Order) + } + out := make([]byte, 32) + return s.FillBytes(out) +} + // ScalarBaseMult sets e to g*k where g is the generator of the group and then // returns e. func (e *G1) ScalarBaseMult(k *big.Int) *G1 { if e.p == nil { e.p = &curvePoint{} } - e.p.Mul(curveGen, k) + + //e.p.Mul(curveGen, k) + + scalar := normalizeScalar(k.Bytes()) + tables := e.generatorTable() + // This is also a scalar multiplication with a four-bit window like in + // ScalarMult, but in this case the doublings are precomputed. The value + // [windowValue]G added at iteration k would normally get doubled + // (totIterations-k)×4 times, but with a larger precomputation we can + // instead add [2^((totIterations-k)×4)][windowValue]G and avoid the + // doublings between iterations. + t := NewCurvePoint() + e.p.SetInfinity() + tableIndex := len(tables) - 1 + for _, byte := range scalar { + windowValue := byte >> 4 + tables[tableIndex].Select(t, windowValue) + e.p.Add(e.p, t) + tableIndex-- + windowValue = byte & 0b1111 + tables[tableIndex].Select(t, windowValue) + e.p.Add(e.p, t) + tableIndex-- + } return e } @@ -55,7 +116,42 @@ func (e *G1) ScalarMult(a *G1, k *big.Int) *G1 { if e.p == nil { e.p = &curvePoint{} } - e.p.Mul(a.p, k) + //e.p.Mul(a.p, k) + // Compute a curvePointTable for the base point a. + var table = curvePointTable{NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), + NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), + NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), + NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), NewCurvePoint()} + table[0].Set(a.p) + for i := 1; i < 15; i += 2 { + table[i].Double(table[i/2]) + table[i+1].Add(table[i], a.p) + } + // Instead of doing the classic double-and-add chain, we do it with a + // four-bit window: we double four times, and then add [0-15]P. + t := &G1{NewCurvePoint()} + e.p.SetInfinity() + scalarBytes := normalizeScalar(k.Bytes()) + for i, byte := range scalarBytes { + // No need to double on the first iteration, as p is the identity at + // this point, and [N]∞ = ∞. + if i != 0 { + e.Double(e) + e.Double(e) + e.Double(e) + e.Double(e) + } + windowValue := byte >> 4 + table.Select(t.p, windowValue) + e.Add(e, t) + e.Double(e) + e.Double(e) + e.Double(e) + e.Double(e) + windowValue = byte & 0b1111 + table.Select(t.p, windowValue) + e.Add(e, t) + } return e } diff --git a/sm9/g2.go b/sm9/g2.go index 3c24312..c5af3aa 100644 --- a/sm9/g2.go +++ b/sm9/g2.go @@ -4,6 +4,7 @@ import ( "errors" "io" "math/big" + "sync" ) // G2 is an abstract cyclic group. The zero value is suitable for use as the @@ -15,6 +16,31 @@ type G2 struct { //Gen2 is the generator of G2. var Gen2 = &G2{twistGen} +var g2GeneratorTable *[32 * 2]twistPointTable +var g2GeneratorTableOnce sync.Once + +func (g *G2) generatorTable() *[32 * 2]twistPointTable { + g2GeneratorTableOnce.Do(func() { + g2GeneratorTable = new([32 * 2]twistPointTable) + base := NewTwistGenerator() + for i := 0; i < 32*2; i++ { + g2GeneratorTable[i][0] = &twistPoint{} + g2GeneratorTable[i][0].Set(base) + for j := 1; j < 15; j += 2 { + g2GeneratorTable[i][j] = &twistPoint{} + g2GeneratorTable[i][j].Double(g2GeneratorTable[i][j/2]) + g2GeneratorTable[i][j+1] = &twistPoint{} + g2GeneratorTable[i][j+1].Add(g2GeneratorTable[i][j], base) + } + base.Double(base) + base.Double(base) + base.Double(base) + base.Double(base) + } + }) + return g2GeneratorTable +} + // RandomG2 returns x and g₂ˣ where x is a random, non-zero number read from r. func RandomG2(r io.Reader) (*big.Int, *G2, error) { k, err := randomK(r) @@ -35,7 +61,30 @@ func (e *G2) ScalarBaseMult(k *big.Int) *G2 { if e.p == nil { e.p = &twistPoint{} } - e.p.Mul(twistGen, k) + //e.p.Mul(twistGen, k) + + scalar := normalizeScalar(k.Bytes()) + tables := e.generatorTable() + // This is also a scalar multiplication with a four-bit window like in + // ScalarMult, but in this case the doublings are precomputed. The value + // [windowValue]G added at iteration k would normally get doubled + // (totIterations-k)×4 times, but with a larger precomputation we can + // instead add [2^((totIterations-k)×4)][windowValue]G and avoid the + // doublings between iterations. + t := NewTwistPoint() + e.p.SetInfinity() + tableIndex := len(tables) - 1 + for _, byte := range scalar { + windowValue := byte >> 4 + tables[tableIndex].Select(t, windowValue) + e.p.Add(e.p, t) + tableIndex-- + windowValue = byte & 0b1111 + tables[tableIndex].Select(t, windowValue) + e.p.Add(e.p, t) + tableIndex-- + } + return e } @@ -44,7 +93,42 @@ func (e *G2) ScalarMult(a *G2, k *big.Int) *G2 { if e.p == nil { e.p = &twistPoint{} } - e.p.Mul(a.p, k) + //e.p.Mul(a.p, k) + // Compute a twistPointTable for the base point a. + var table = twistPointTable{NewTwistPoint(), NewTwistPoint(), NewTwistPoint(), + NewTwistPoint(), NewTwistPoint(), NewTwistPoint(), NewTwistPoint(), + NewTwistPoint(), NewTwistPoint(), NewTwistPoint(), NewTwistPoint(), + NewTwistPoint(), NewTwistPoint(), NewTwistPoint(), NewTwistPoint()} + table[0].Set(a.p) + for i := 1; i < 15; i += 2 { + table[i].Double(table[i/2]) + table[i+1].Add(table[i], a.p) + } + // Instead of doing the classic double-and-add chain, we do it with a + // four-bit window: we double four times, and then add [0-15]P. + t := &G2{NewTwistPoint()} + e.p.SetInfinity() + scalarBytes := normalizeScalar(k.Bytes()) + for i, byte := range scalarBytes { + // No need to double on the first iteration, as p is the identity at + // this point, and [N]∞ = ∞. + if i != 0 { + e.p.Double(e.p) + e.p.Double(e.p) + e.p.Double(e.p) + e.p.Double(e.p) + } + windowValue := byte >> 4 + table.Select(t.p, windowValue) + e.Add(e, t) + e.p.Double(e.p) + e.p.Double(e.p) + e.p.Double(e.p) + e.p.Double(e.p) + windowValue = byte & 0b1111 + table.Select(t.p, windowValue) + e.Add(e, t) + } return e } diff --git a/sm9/gfp.go b/sm9/gfp.go index 0aee698..a9370ca 100644 --- a/sm9/gfp.go +++ b/sm9/gfp.go @@ -195,3 +195,37 @@ func init() { t1 := newGFp(2) twoInvert.Invert(t1) } + +// cmovznzU64 is a single-word conditional move. +// +// Postconditions: +// out1 = (if arg1 = 0 then arg2 else arg3) +// +// Input Bounds: +// arg1: [0x0 ~> 0x1] +// arg2: [0x0 ~> 0xffffffffffffffff] +// arg3: [0x0 ~> 0xffffffffffffffff] +// Output Bounds: +// out1: [0x0 ~> 0xffffffffffffffff] +func cmovznzU64(out1 *uint64, arg1 uint64, arg2 uint64, arg3 uint64) { + x1 := (uint64(arg1) * 0xffffffffffffffff) + x2 := ((x1 & arg3) | ((^x1) & arg2)) + *out1 = x2 +} + +// Select sets e to p1 if cond == 1, and to p2 if cond == 0. +func (e *gfP) Select(p1, p2 *gfP, cond int) *gfP { + var x1 uint64 + cmovznzU64(&x1, uint64(cond), p2[0], p1[0]) + var x2 uint64 + cmovznzU64(&x2, uint64(cond), p2[1], p1[1]) + var x3 uint64 + cmovznzU64(&x3, uint64(cond), p2[2], p1[2]) + var x4 uint64 + cmovznzU64(&x4, uint64(cond), p2[3], p1[3]) + e[0] = x1 + e[1] = x2 + e[2] = x3 + e[3] = x4 + return e +} diff --git a/sm9/gfp2.go b/sm9/gfp2.go index b42860e..0fc405a 100644 --- a/sm9/gfp2.go +++ b/sm9/gfp2.go @@ -258,3 +258,10 @@ func (e *gfP2) Div2(f *gfP2) *gfP2 { e.Set(t) return e } + +// Select sets e to p1 if cond == 1, and to p2 if cond == 0. +func (e *gfP2) Select(p1, p2 *gfP2, cond int) *gfP2 { + e.x.Select(&p1.x, &p2.x, cond) + e.y.Select(&p1.y, &p2.y, cond) + return e +} diff --git a/sm9/sm9.go b/sm9/sm9.go index a7e2a2e..2e91fd1 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -84,10 +84,17 @@ func randFieldElement(rand io.Reader) (k *big.Int, err error) { return } +func (pub *SignMasterPublicKey) Pair() *GT { + pub.pairOnce.Do(func() { + pub.basePoint = Pair(Gen1, pub.MasterPublicKey) + }) + return pub.basePoint +} + // Sign signs a hash (which should be the result of hashing a larger message) // using the user dsa key. It returns the signature as a pair of h and s. func Sign(rand io.Reader, priv *SignPrivateKey, hash []byte) (h *big.Int, s *G1, err error) { - g := Pair(Gen1, priv.SignMasterPublicKey.MasterPublicKey) + g := priv.Pair() var r *big.Int for { r, err = randFieldElement(rand) @@ -103,7 +110,10 @@ func Sign(rand io.Reader, priv *SignPrivateKey, hash []byte) (h *big.Int, s *G1, h = hashH2(buffer) l := new(big.Int).Sub(r, h) - l.Mod(l, Order) + + if l.Sign() < 0 { + l.Add(l, Order) + } if l.Sign() != 0 { s = new(G1).ScalarMult(priv.PrivateKey, l) @@ -138,17 +148,6 @@ func SignASN1(rand io.Reader, priv *SignPrivateKey, hash []byte) ([]byte, error) return priv.Sign(rand, hash, nil) } -// GenerateUserPublicKey generate user sign public key -func (pub *SignMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *G2 { - var buffer []byte - buffer = append(buffer, uid...) - buffer = append(buffer, hid) - h1 := hashH1(buffer) - p := new(G2).ScalarBaseMult(h1) - p.Add(p, pub.MasterPublicKey) - return p -} - // Verify verifies the signature in h, s of hash using the master dsa public key and user id, uid and hid. // Its return value records whether the signature is valid. func Verify(pub *SignMasterPublicKey, uid []byte, hid byte, hash []byte, h *big.Int, s *G1) bool { @@ -158,7 +157,8 @@ func Verify(pub *SignMasterPublicKey, uid []byte, hid byte, hash []byte, h *big. if !s.p.IsOnCurve() { return false } - g := Pair(Gen1, pub.MasterPublicKey) + g := pub.Pair() + t := new(GT).ScalarMult(g, h) // user sign public key p generation @@ -220,6 +220,13 @@ func (pub *EncryptMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) * return p } +func (pub *EncryptMasterPublicKey) Pair() *GT { + pub.pairOnce.Do(func() { + pub.basePoint = Pair(pub.MasterPublicKey, Gen2) + }) + return pub.basePoint +} + // WrappKey generate and wrapp key wtih reciever's uid and system hid func WrappKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, kLen int) (key []byte, cipher *G1, err error) { q := pub.GenerateUserPublicKey(uid, hid) @@ -233,7 +240,7 @@ func WrappKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, cipher = new(G1).ScalarMult(q, r) - g := Pair(pub.MasterPublicKey, Gen2) + g := pub.Pair() w := new(GT).ScalarMult(g, r) var buffer []byte diff --git a/sm9/sm9_key.go b/sm9/sm9_key.go index 2101ef1..b3761d5 100644 --- a/sm9/sm9_key.go +++ b/sm9/sm9_key.go @@ -4,6 +4,7 @@ import ( "errors" "io" "math/big" + "sync" "golang.org/x/crypto/cryptobyte" ) @@ -15,6 +16,8 @@ type SignMasterPrivateKey struct { type SignMasterPublicKey struct { MasterPublicKey *G2 + pairOnce sync.Once + basePoint *GT } type SignPrivateKey struct { @@ -29,6 +32,8 @@ type EncryptMasterPrivateKey struct { type EncryptMasterPublicKey struct { MasterPublicKey *G1 + pairOnce sync.Once + basePoint *GT } type EncryptPrivateKey struct { @@ -123,6 +128,17 @@ func (pub *SignMasterPublicKey) UnmarshalASN1(der []byte) error { return nil } +// GenerateUserPublicKey generate user sign public key +func (pub *SignMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *G2 { + var buffer []byte + buffer = append(buffer, uid...) + buffer = append(buffer, hid) + h1 := hashH1(buffer) + p := new(G2).ScalarBaseMult(h1) + p.Add(p, pub.MasterPublicKey) + return p +} + // MasterPublic returns the master public key corresponding to priv. func (priv *SignPrivateKey) MasterPublic() *SignMasterPublicKey { return &priv.SignMasterPublicKey diff --git a/sm9/sm9_test.go b/sm9/sm9_test.go index 660a02f..3d3a5d4 100644 --- a/sm9/sm9_test.go +++ b/sm9/sm9_test.go @@ -344,6 +344,7 @@ func BenchmarkSign(b *testing.B) { if err != nil { b.Fatal(err) } + SignASN1(rand.Reader, userKey, hashed) // fire precompute b.ReportAllocs() b.ResetTimer() diff --git a/sm9/twist.go b/sm9/twist.go index 611ceec..37736b4 100644 --- a/sm9/twist.go +++ b/sm9/twist.go @@ -1,6 +1,9 @@ package sm9 -import "math/big" +import ( + "crypto/subtle" + "math/big" +) // twistPoint implements the elliptic curve y²=x³+5/ξ (y²=x³+5i) over GF(p²). Points are // kept in Jacobian form and t=z² when valid. The group G₂ is the set of @@ -41,6 +44,18 @@ func (c *twistPoint) Set(a *twistPoint) { c.t.Set(&a.t) } +func NewTwistPoint() *twistPoint { + c := &twistPoint{} + c.SetInfinity() + return c +} + +func NewTwistGenerator() *twistPoint { + c := &twistPoint{} + c.Set(twistGen) + return c +} + // IsOnCurve returns true iff c is on the curve. func (c *twistPoint) IsOnCurve() bool { c.MakeAffine() @@ -154,6 +169,7 @@ func (c *twistPoint) Double(a *twistPoint) { c.y.Sub(t2, t) } +// TODO: improve it func (c *twistPoint) Mul(a *twistPoint, scalar *big.Int) { sum, t := &twistPoint{}, &twistPoint{} @@ -220,6 +236,33 @@ func (c *twistPoint) NegFrobeniusP2(a *twistPoint) { c.t.Square(&a.z) } +// Select sets q to p1 if cond == 1, and to p2 if cond == 0. +func (q *twistPoint) Select(p1, p2 *twistPoint, cond int) *twistPoint { + q.x.Select(&p1.x, &p2.x, cond) + q.y.Select(&p1.y, &p2.y, cond) + q.z.Select(&p1.z, &p2.z, cond) + q.t.Select(&p1.t, &p2.t, cond) + return q +} + +// A twistPointTable holds the first 15 multiples of a point at offset -1, so [1]P +// is at table[0], [15]P is at table[14], and [0]P is implicitly the identity +// point. +type twistPointTable [15]*twistPoint + +// Select selects the n-th multiple of the table base point into p. It works in +// constant time by iterating over every entry of the table. n must be in [0, 15]. +func (table *twistPointTable) Select(p *twistPoint, n uint8) { + if n >= 16 { + panic("sm9: internal error: twistPointTable called with out-of-bounds value") + } + p.SetInfinity() + for i := uint8(1); i < 16; i++ { + cond := subtle.ConstantTimeByteEq(i, n) + p.Select(table[i-1], p, cond) + } +} + /* //code logic is from https://github.com/miracl/MIRACL/blob/master/source/curve/pairing/bn_pair.cpp func (c *twistPoint) Frobenius(a *twistPoint) {