precompute part 1

This commit is contained in:
Sun Yimin 2022-06-13 13:50:27 +08:00 committed by GitHub
parent cadc48f630
commit f78fd3c105
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 351 additions and 21 deletions

View File

@ -1,6 +1,9 @@
package sm9 package sm9
import "math/big" import (
"crypto/subtle"
"math/big"
)
// curvePoint implements the elliptic curve y²=x³+5. Points are kept in Jacobian // curvePoint implements the elliptic curve y²=x³+5. Points are kept in Jacobian
// form and t=z² when valid. G₁ is the set of points of this curve on GF(p). // form and t=z² when valid. G₁ is the set of points of this curve on GF(p).
@ -49,6 +52,18 @@ func (c *curvePoint) IsOnCurve() bool {
return *y2 == *x3 return *y2 == *x3
} }
func NewCurvePoint() *curvePoint {
c := &curvePoint{}
c.SetInfinity()
return c
}
func NewCurveGenerator() *curvePoint {
c := &curvePoint{}
c.Set(curveGen)
return c
}
func (c *curvePoint) SetInfinity() { func (c *curvePoint) SetInfinity() {
c.x = *zero c.x = *zero
c.y = *one c.y = *one
@ -226,3 +241,30 @@ func (c *curvePoint) Neg(a *curvePoint) {
c.z.Set(&a.z) c.z.Set(&a.z)
c.t = *zero c.t = *zero
} }
// Select sets q to p1 if cond == 1, and to p2 if cond == 0.
func (q *curvePoint) Select(p1, p2 *curvePoint, cond int) *curvePoint {
q.x.Select(&p1.x, &p2.x, cond)
q.y.Select(&p1.y, &p2.y, cond)
q.z.Select(&p1.z, &p2.z, cond)
q.t.Select(&p1.t, &p2.t, cond)
return q
}
// A curvePointTable holds the first 15 multiples of a point at offset -1, so [1]P
// is at table[0], [15]P is at table[14], and [0]P is implicitly the identity
// point.
type curvePointTable [15]*curvePoint
// Select selects the n-th multiple of the table base point into p. It works in
// constant time by iterating over every entry of the table. n must be in [0, 15].
func (table *curvePointTable) Select(p *curvePoint, n uint8) {
if n >= 16 {
panic("sm9: internal error: curvePointTable called with out-of-bounds value")
}
p.SetInfinity()
for i := uint8(1); i < 16; i++ {
cond := subtle.ConstantTimeByteEq(i, n)
p.Select(table[i-1], p, cond)
}
}

100
sm9/g1.go
View File

@ -6,6 +6,7 @@ import (
"io" "io"
"math/big" "math/big"
"math/bits" "math/bits"
"sync"
) )
func randomK(r io.Reader) (k *big.Int, err error) { func randomK(r io.Reader) (k *big.Int, err error) {
@ -26,6 +27,31 @@ type G1 struct {
//Gen1 is the generator of G1. //Gen1 is the generator of G1.
var Gen1 = &G1{curveGen} var Gen1 = &G1{curveGen}
var g1GeneratorTable *[32 * 2]curvePointTable
var g1GeneratorTableOnce sync.Once
func (g *G1) generatorTable() *[32 * 2]curvePointTable {
g1GeneratorTableOnce.Do(func() {
g1GeneratorTable = new([32 * 2]curvePointTable)
base := NewCurveGenerator()
for i := 0; i < 32*2; i++ {
g1GeneratorTable[i][0] = &curvePoint{}
g1GeneratorTable[i][0].Set(base)
for j := 1; j < 15; j += 2 {
g1GeneratorTable[i][j] = &curvePoint{}
g1GeneratorTable[i][j].Double(g1GeneratorTable[i][j/2])
g1GeneratorTable[i][j+1] = &curvePoint{}
g1GeneratorTable[i][j+1].Add(g1GeneratorTable[i][j], base)
}
base.Double(base)
base.Double(base)
base.Double(base)
base.Double(base)
}
})
return g1GeneratorTable
}
// RandomG1 returns x and g₁ˣ where x is a random, non-zero number read from r. // RandomG1 returns x and g₁ˣ where x is a random, non-zero number read from r.
func RandomG1(r io.Reader) (*big.Int, *G1, error) { func RandomG1(r io.Reader) (*big.Int, *G1, error) {
k, err := randomK(r) k, err := randomK(r)
@ -40,13 +66,48 @@ func (g *G1) String() string {
return "sm9.G1" + g.p.String() return "sm9.G1" + g.p.String()
} }
func normalizeScalar(scalar []byte) []byte {
if len(scalar) == 32 {
return scalar
}
s := new(big.Int).SetBytes(scalar)
if len(scalar) > 32 {
s.Mod(s, Order)
}
out := make([]byte, 32)
return s.FillBytes(out)
}
// ScalarBaseMult sets e to g*k where g is the generator of the group and then // ScalarBaseMult sets e to g*k where g is the generator of the group and then
// returns e. // returns e.
func (e *G1) ScalarBaseMult(k *big.Int) *G1 { func (e *G1) ScalarBaseMult(k *big.Int) *G1 {
if e.p == nil { if e.p == nil {
e.p = &curvePoint{} e.p = &curvePoint{}
} }
e.p.Mul(curveGen, k)
//e.p.Mul(curveGen, k)
scalar := normalizeScalar(k.Bytes())
tables := e.generatorTable()
// This is also a scalar multiplication with a four-bit window like in
// ScalarMult, but in this case the doublings are precomputed. The value
// [windowValue]G added at iteration k would normally get doubled
// (totIterations-k)×4 times, but with a larger precomputation we can
// instead add [2^((totIterations-k)×4)][windowValue]G and avoid the
// doublings between iterations.
t := NewCurvePoint()
e.p.SetInfinity()
tableIndex := len(tables) - 1
for _, byte := range scalar {
windowValue := byte >> 4
tables[tableIndex].Select(t, windowValue)
e.p.Add(e.p, t)
tableIndex--
windowValue = byte & 0b1111
tables[tableIndex].Select(t, windowValue)
e.p.Add(e.p, t)
tableIndex--
}
return e return e
} }
@ -55,7 +116,42 @@ func (e *G1) ScalarMult(a *G1, k *big.Int) *G1 {
if e.p == nil { if e.p == nil {
e.p = &curvePoint{} e.p = &curvePoint{}
} }
e.p.Mul(a.p, k) //e.p.Mul(a.p, k)
// Compute a curvePointTable for the base point a.
var table = curvePointTable{NewCurvePoint(), NewCurvePoint(), NewCurvePoint(),
NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), NewCurvePoint(),
NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), NewCurvePoint(),
NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), NewCurvePoint()}
table[0].Set(a.p)
for i := 1; i < 15; i += 2 {
table[i].Double(table[i/2])
table[i+1].Add(table[i], a.p)
}
// Instead of doing the classic double-and-add chain, we do it with a
// four-bit window: we double four times, and then add [0-15]P.
t := &G1{NewCurvePoint()}
e.p.SetInfinity()
scalarBytes := normalizeScalar(k.Bytes())
for i, byte := range scalarBytes {
// No need to double on the first iteration, as p is the identity at
// this point, and [N]∞ = ∞.
if i != 0 {
e.Double(e)
e.Double(e)
e.Double(e)
e.Double(e)
}
windowValue := byte >> 4
table.Select(t.p, windowValue)
e.Add(e, t)
e.Double(e)
e.Double(e)
e.Double(e)
e.Double(e)
windowValue = byte & 0b1111
table.Select(t.p, windowValue)
e.Add(e, t)
}
return e return e
} }

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"io" "io"
"math/big" "math/big"
"sync"
) )
// G2 is an abstract cyclic group. The zero value is suitable for use as the // G2 is an abstract cyclic group. The zero value is suitable for use as the
@ -15,6 +16,31 @@ type G2 struct {
//Gen2 is the generator of G2. //Gen2 is the generator of G2.
var Gen2 = &G2{twistGen} var Gen2 = &G2{twistGen}
var g2GeneratorTable *[32 * 2]twistPointTable
var g2GeneratorTableOnce sync.Once
func (g *G2) generatorTable() *[32 * 2]twistPointTable {
g2GeneratorTableOnce.Do(func() {
g2GeneratorTable = new([32 * 2]twistPointTable)
base := NewTwistGenerator()
for i := 0; i < 32*2; i++ {
g2GeneratorTable[i][0] = &twistPoint{}
g2GeneratorTable[i][0].Set(base)
for j := 1; j < 15; j += 2 {
g2GeneratorTable[i][j] = &twistPoint{}
g2GeneratorTable[i][j].Double(g2GeneratorTable[i][j/2])
g2GeneratorTable[i][j+1] = &twistPoint{}
g2GeneratorTable[i][j+1].Add(g2GeneratorTable[i][j], base)
}
base.Double(base)
base.Double(base)
base.Double(base)
base.Double(base)
}
})
return g2GeneratorTable
}
// RandomG2 returns x and g₂ˣ where x is a random, non-zero number read from r. // RandomG2 returns x and g₂ˣ where x is a random, non-zero number read from r.
func RandomG2(r io.Reader) (*big.Int, *G2, error) { func RandomG2(r io.Reader) (*big.Int, *G2, error) {
k, err := randomK(r) k, err := randomK(r)
@ -35,7 +61,30 @@ func (e *G2) ScalarBaseMult(k *big.Int) *G2 {
if e.p == nil { if e.p == nil {
e.p = &twistPoint{} e.p = &twistPoint{}
} }
e.p.Mul(twistGen, k) //e.p.Mul(twistGen, k)
scalar := normalizeScalar(k.Bytes())
tables := e.generatorTable()
// This is also a scalar multiplication with a four-bit window like in
// ScalarMult, but in this case the doublings are precomputed. The value
// [windowValue]G added at iteration k would normally get doubled
// (totIterations-k)×4 times, but with a larger precomputation we can
// instead add [2^((totIterations-k)×4)][windowValue]G and avoid the
// doublings between iterations.
t := NewTwistPoint()
e.p.SetInfinity()
tableIndex := len(tables) - 1
for _, byte := range scalar {
windowValue := byte >> 4
tables[tableIndex].Select(t, windowValue)
e.p.Add(e.p, t)
tableIndex--
windowValue = byte & 0b1111
tables[tableIndex].Select(t, windowValue)
e.p.Add(e.p, t)
tableIndex--
}
return e return e
} }
@ -44,7 +93,42 @@ func (e *G2) ScalarMult(a *G2, k *big.Int) *G2 {
if e.p == nil { if e.p == nil {
e.p = &twistPoint{} e.p = &twistPoint{}
} }
e.p.Mul(a.p, k) //e.p.Mul(a.p, k)
// Compute a twistPointTable for the base point a.
var table = twistPointTable{NewTwistPoint(), NewTwistPoint(), NewTwistPoint(),
NewTwistPoint(), NewTwistPoint(), NewTwistPoint(), NewTwistPoint(),
NewTwistPoint(), NewTwistPoint(), NewTwistPoint(), NewTwistPoint(),
NewTwistPoint(), NewTwistPoint(), NewTwistPoint(), NewTwistPoint()}
table[0].Set(a.p)
for i := 1; i < 15; i += 2 {
table[i].Double(table[i/2])
table[i+1].Add(table[i], a.p)
}
// Instead of doing the classic double-and-add chain, we do it with a
// four-bit window: we double four times, and then add [0-15]P.
t := &G2{NewTwistPoint()}
e.p.SetInfinity()
scalarBytes := normalizeScalar(k.Bytes())
for i, byte := range scalarBytes {
// No need to double on the first iteration, as p is the identity at
// this point, and [N]∞ = ∞.
if i != 0 {
e.p.Double(e.p)
e.p.Double(e.p)
e.p.Double(e.p)
e.p.Double(e.p)
}
windowValue := byte >> 4
table.Select(t.p, windowValue)
e.Add(e, t)
e.p.Double(e.p)
e.p.Double(e.p)
e.p.Double(e.p)
e.p.Double(e.p)
windowValue = byte & 0b1111
table.Select(t.p, windowValue)
e.Add(e, t)
}
return e return e
} }

View File

@ -195,3 +195,37 @@ func init() {
t1 := newGFp(2) t1 := newGFp(2)
twoInvert.Invert(t1) twoInvert.Invert(t1)
} }
// cmovznzU64 is a single-word conditional move.
//
// Postconditions:
// out1 = (if arg1 = 0 then arg2 else arg3)
//
// Input Bounds:
// arg1: [0x0 ~> 0x1]
// arg2: [0x0 ~> 0xffffffffffffffff]
// arg3: [0x0 ~> 0xffffffffffffffff]
// Output Bounds:
// out1: [0x0 ~> 0xffffffffffffffff]
func cmovznzU64(out1 *uint64, arg1 uint64, arg2 uint64, arg3 uint64) {
x1 := (uint64(arg1) * 0xffffffffffffffff)
x2 := ((x1 & arg3) | ((^x1) & arg2))
*out1 = x2
}
// Select sets e to p1 if cond == 1, and to p2 if cond == 0.
func (e *gfP) Select(p1, p2 *gfP, cond int) *gfP {
var x1 uint64
cmovznzU64(&x1, uint64(cond), p2[0], p1[0])
var x2 uint64
cmovznzU64(&x2, uint64(cond), p2[1], p1[1])
var x3 uint64
cmovznzU64(&x3, uint64(cond), p2[2], p1[2])
var x4 uint64
cmovznzU64(&x4, uint64(cond), p2[3], p1[3])
e[0] = x1
e[1] = x2
e[2] = x3
e[3] = x4
return e
}

View File

@ -258,3 +258,10 @@ func (e *gfP2) Div2(f *gfP2) *gfP2 {
e.Set(t) e.Set(t)
return e return e
} }
// Select sets e to p1 if cond == 1, and to p2 if cond == 0.
func (e *gfP2) Select(p1, p2 *gfP2, cond int) *gfP2 {
e.x.Select(&p1.x, &p2.x, cond)
e.y.Select(&p1.y, &p2.y, cond)
return e
}

View File

@ -84,10 +84,17 @@ func randFieldElement(rand io.Reader) (k *big.Int, err error) {
return return
} }
func (pub *SignMasterPublicKey) Pair() *GT {
pub.pairOnce.Do(func() {
pub.basePoint = Pair(Gen1, pub.MasterPublicKey)
})
return pub.basePoint
}
// Sign signs a hash (which should be the result of hashing a larger message) // Sign signs a hash (which should be the result of hashing a larger message)
// using the user dsa key. It returns the signature as a pair of h and s. // using the user dsa key. It returns the signature as a pair of h and s.
func Sign(rand io.Reader, priv *SignPrivateKey, hash []byte) (h *big.Int, s *G1, err error) { func Sign(rand io.Reader, priv *SignPrivateKey, hash []byte) (h *big.Int, s *G1, err error) {
g := Pair(Gen1, priv.SignMasterPublicKey.MasterPublicKey) g := priv.Pair()
var r *big.Int var r *big.Int
for { for {
r, err = randFieldElement(rand) r, err = randFieldElement(rand)
@ -103,7 +110,10 @@ func Sign(rand io.Reader, priv *SignPrivateKey, hash []byte) (h *big.Int, s *G1,
h = hashH2(buffer) h = hashH2(buffer)
l := new(big.Int).Sub(r, h) l := new(big.Int).Sub(r, h)
l.Mod(l, Order)
if l.Sign() < 0 {
l.Add(l, Order)
}
if l.Sign() != 0 { if l.Sign() != 0 {
s = new(G1).ScalarMult(priv.PrivateKey, l) s = new(G1).ScalarMult(priv.PrivateKey, l)
@ -138,17 +148,6 @@ func SignASN1(rand io.Reader, priv *SignPrivateKey, hash []byte) ([]byte, error)
return priv.Sign(rand, hash, nil) return priv.Sign(rand, hash, nil)
} }
// GenerateUserPublicKey generate user sign public key
func (pub *SignMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *G2 {
var buffer []byte
buffer = append(buffer, uid...)
buffer = append(buffer, hid)
h1 := hashH1(buffer)
p := new(G2).ScalarBaseMult(h1)
p.Add(p, pub.MasterPublicKey)
return p
}
// Verify verifies the signature in h, s of hash using the master dsa public key and user id, uid and hid. // Verify verifies the signature in h, s of hash using the master dsa public key and user id, uid and hid.
// Its return value records whether the signature is valid. // Its return value records whether the signature is valid.
func Verify(pub *SignMasterPublicKey, uid []byte, hid byte, hash []byte, h *big.Int, s *G1) bool { func Verify(pub *SignMasterPublicKey, uid []byte, hid byte, hash []byte, h *big.Int, s *G1) bool {
@ -158,7 +157,8 @@ func Verify(pub *SignMasterPublicKey, uid []byte, hid byte, hash []byte, h *big.
if !s.p.IsOnCurve() { if !s.p.IsOnCurve() {
return false return false
} }
g := Pair(Gen1, pub.MasterPublicKey) g := pub.Pair()
t := new(GT).ScalarMult(g, h) t := new(GT).ScalarMult(g, h)
// user sign public key p generation // user sign public key p generation
@ -220,6 +220,13 @@ func (pub *EncryptMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *
return p return p
} }
func (pub *EncryptMasterPublicKey) Pair() *GT {
pub.pairOnce.Do(func() {
pub.basePoint = Pair(pub.MasterPublicKey, Gen2)
})
return pub.basePoint
}
// WrappKey generate and wrapp key wtih reciever's uid and system hid // WrappKey generate and wrapp key wtih reciever's uid and system hid
func WrappKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, kLen int) (key []byte, cipher *G1, err error) { func WrappKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, kLen int) (key []byte, cipher *G1, err error) {
q := pub.GenerateUserPublicKey(uid, hid) q := pub.GenerateUserPublicKey(uid, hid)
@ -233,7 +240,7 @@ func WrappKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte,
cipher = new(G1).ScalarMult(q, r) cipher = new(G1).ScalarMult(q, r)
g := Pair(pub.MasterPublicKey, Gen2) g := pub.Pair()
w := new(GT).ScalarMult(g, r) w := new(GT).ScalarMult(g, r)
var buffer []byte var buffer []byte

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"io" "io"
"math/big" "math/big"
"sync"
"golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte"
) )
@ -15,6 +16,8 @@ type SignMasterPrivateKey struct {
type SignMasterPublicKey struct { type SignMasterPublicKey struct {
MasterPublicKey *G2 MasterPublicKey *G2
pairOnce sync.Once
basePoint *GT
} }
type SignPrivateKey struct { type SignPrivateKey struct {
@ -29,6 +32,8 @@ type EncryptMasterPrivateKey struct {
type EncryptMasterPublicKey struct { type EncryptMasterPublicKey struct {
MasterPublicKey *G1 MasterPublicKey *G1
pairOnce sync.Once
basePoint *GT
} }
type EncryptPrivateKey struct { type EncryptPrivateKey struct {
@ -123,6 +128,17 @@ func (pub *SignMasterPublicKey) UnmarshalASN1(der []byte) error {
return nil return nil
} }
// GenerateUserPublicKey generate user sign public key
func (pub *SignMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *G2 {
var buffer []byte
buffer = append(buffer, uid...)
buffer = append(buffer, hid)
h1 := hashH1(buffer)
p := new(G2).ScalarBaseMult(h1)
p.Add(p, pub.MasterPublicKey)
return p
}
// MasterPublic returns the master public key corresponding to priv. // MasterPublic returns the master public key corresponding to priv.
func (priv *SignPrivateKey) MasterPublic() *SignMasterPublicKey { func (priv *SignPrivateKey) MasterPublic() *SignMasterPublicKey {
return &priv.SignMasterPublicKey return &priv.SignMasterPublicKey

View File

@ -344,6 +344,7 @@ func BenchmarkSign(b *testing.B) {
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
SignASN1(rand.Reader, userKey, hashed) // fire precompute
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()

View File

@ -1,6 +1,9 @@
package sm9 package sm9
import "math/big" import (
"crypto/subtle"
"math/big"
)
// twistPoint implements the elliptic curve y²=x³+5/ξ (y²=x³+5i) over GF(p²). Points are // twistPoint implements the elliptic curve y²=x³+5/ξ (y²=x³+5i) over GF(p²). Points are
// kept in Jacobian form and t=z² when valid. The group G₂ is the set of // kept in Jacobian form and t=z² when valid. The group G₂ is the set of
@ -41,6 +44,18 @@ func (c *twistPoint) Set(a *twistPoint) {
c.t.Set(&a.t) c.t.Set(&a.t)
} }
func NewTwistPoint() *twistPoint {
c := &twistPoint{}
c.SetInfinity()
return c
}
func NewTwistGenerator() *twistPoint {
c := &twistPoint{}
c.Set(twistGen)
return c
}
// IsOnCurve returns true iff c is on the curve. // IsOnCurve returns true iff c is on the curve.
func (c *twistPoint) IsOnCurve() bool { func (c *twistPoint) IsOnCurve() bool {
c.MakeAffine() c.MakeAffine()
@ -154,6 +169,7 @@ func (c *twistPoint) Double(a *twistPoint) {
c.y.Sub(t2, t) c.y.Sub(t2, t)
} }
// TODO: improve it
func (c *twistPoint) Mul(a *twistPoint, scalar *big.Int) { func (c *twistPoint) Mul(a *twistPoint, scalar *big.Int) {
sum, t := &twistPoint{}, &twistPoint{} sum, t := &twistPoint{}, &twistPoint{}
@ -220,6 +236,33 @@ func (c *twistPoint) NegFrobeniusP2(a *twistPoint) {
c.t.Square(&a.z) c.t.Square(&a.z)
} }
// Select sets q to p1 if cond == 1, and to p2 if cond == 0.
func (q *twistPoint) Select(p1, p2 *twistPoint, cond int) *twistPoint {
q.x.Select(&p1.x, &p2.x, cond)
q.y.Select(&p1.y, &p2.y, cond)
q.z.Select(&p1.z, &p2.z, cond)
q.t.Select(&p1.t, &p2.t, cond)
return q
}
// A twistPointTable holds the first 15 multiples of a point at offset -1, so [1]P
// is at table[0], [15]P is at table[14], and [0]P is implicitly the identity
// point.
type twistPointTable [15]*twistPoint
// Select selects the n-th multiple of the table base point into p. It works in
// constant time by iterating over every entry of the table. n must be in [0, 15].
func (table *twistPointTable) Select(p *twistPoint, n uint8) {
if n >= 16 {
panic("sm9: internal error: twistPointTable called with out-of-bounds value")
}
p.SetInfinity()
for i := uint8(1); i < 16; i++ {
cond := subtle.ConstantTimeByteEq(i, n)
p.Select(table[i-1], p, cond)
}
}
/* /*
//code logic is from https://github.com/miracl/MIRACL/blob/master/source/curve/pairing/bn_pair.cpp //code logic is from https://github.com/miracl/MIRACL/blob/master/source/curve/pairing/bn_pair.cpp
func (c *twistPoint) Frobenius(a *twistPoint) { func (c *twistPoint) Frobenius(a *twistPoint) {