sm9: use bigmod instead of math/big

This commit is contained in:
Sun Yimin 2022-11-25 10:11:46 +08:00 committed by GitHub
parent aede405cdd
commit d7e853277a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 357 additions and 123 deletions

View File

@ -53,7 +53,10 @@ func init() {
func Test_Pairing_A2(t *testing.T) { func Test_Pairing_A2(t *testing.T) {
pk := bigFromHex("0130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4") pk := bigFromHex("0130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4")
g2 := &G2{} g2 := &G2{}
g2.ScalarBaseMult(pk) _, err := g2.ScalarBaseMult(NormalizeScalar(pk.Bytes()))
if err != nil {
t.Fatal(err)
}
ret := pairing(g2.p, curveGen) ret := pairing(g2.p, curveGen)
if ret.x != expected1.x || ret.y != expected1.y || ret.z != expected1.z { if ret.x != expected1.x || ret.y != expected1.y || ret.z != expected1.z {
t.Errorf("not expected") t.Errorf("not expected")

View File

@ -59,14 +59,15 @@ func RandomG1(r io.Reader) (*big.Int, *G1, error) {
return nil, nil, err return nil, nil, err
} }
return k, new(G1).ScalarBaseMult(k), nil g1, err := new(G1).ScalarBaseMult(NormalizeScalar(k.Bytes()))
return k, g1, err
} }
func (g *G1) String() string { func (g *G1) String() string {
return "sm9.G1" + g.p.String() return "sm9.G1" + g.p.String()
} }
func normalizeScalar(scalar []byte) []byte { func NormalizeScalar(scalar []byte) []byte {
if len(scalar) == 32 { if len(scalar) == 32 {
return scalar return scalar
} }
@ -78,16 +79,18 @@ func normalizeScalar(scalar []byte) []byte {
return s.FillBytes(out) return s.FillBytes(out)
} }
// ScalarBaseMult sets e to g*k where g is the generator of the group and then // ScalarBaseMult sets e to scaler*g where g is the generator of the group and then
// returns e. // returns e.
func (e *G1) ScalarBaseMult(k *big.Int) *G1 { func (e *G1) ScalarBaseMult(scalar []byte) (*G1, error) {
if len(scalar) != 32 {
return nil, errors.New("invalid scalar length")
}
if e.p == nil { if e.p == nil {
e.p = &curvePoint{} e.p = &curvePoint{}
} }
//e.p.Mul(curveGen, k) //e.p.Mul(curveGen, k)
scalar := normalizeScalar(k.Bytes())
tables := e.generatorTable() tables := e.generatorTable()
// This is also a scalar multiplication with a four-bit window like in // This is also a scalar multiplication with a four-bit window like in
// ScalarMult, but in this case the doublings are precomputed. The value // ScalarMult, but in this case the doublings are precomputed. The value
@ -108,11 +111,11 @@ func (e *G1) ScalarBaseMult(k *big.Int) *G1 {
e.p.Add(e.p, t) e.p.Add(e.p, t)
tableIndex-- tableIndex--
} }
return e return e, nil
} }
// ScalarMult sets e to a*k and then returns e. // ScalarMult sets e to a*k and then returns e.
func (e *G1) ScalarMult(a *G1, k *big.Int) *G1 { func (e *G1) ScalarMult(a *G1, scalar []byte) (*G1, error) {
if e.p == nil { if e.p == nil {
e.p = &curvePoint{} e.p = &curvePoint{}
} }
@ -131,8 +134,7 @@ func (e *G1) ScalarMult(a *G1, k *big.Int) *G1 {
// four-bit window: we double four times, and then add [0-15]P. // four-bit window: we double four times, and then add [0-15]P.
t := &G1{NewCurvePoint()} t := &G1{NewCurvePoint()}
e.p.SetInfinity() e.p.SetInfinity()
scalarBytes := normalizeScalar(k.Bytes()) for i, byte := range scalar {
for i, byte := range scalarBytes {
// No need to double on the first iteration, as p is the identity at // No need to double on the first iteration, as p is the identity at
// this point, and [N]∞ = ∞. // this point, and [N]∞ = ∞.
if i != 0 { if i != 0 {
@ -152,7 +154,7 @@ func (e *G1) ScalarMult(a *G1, k *big.Int) *G1 {
table.Select(t.p, windowValue) table.Select(t.p, windowValue)
e.Add(e, t) e.Add(e, t)
} }
return e return e, nil
} }
// Add sets e to a+b and then returns e. // Add sets e to a+b and then returns e.
@ -398,27 +400,37 @@ func (g1 *G1Curve) Params() *CurveParams {
// normalizeScalar brings the scalar within the byte size of the order of the // normalizeScalar brings the scalar within the byte size of the order of the
// curve, as expected by the nistec scalar multiplication functions. // curve, as expected by the nistec scalar multiplication functions.
func (curve *G1Curve) normalizeScalar(scalar []byte) *big.Int { func (curve *G1Curve) normalizeScalar(scalar []byte) []byte {
byteSize := (curve.params.N.BitLen() + 7) / 8 byteSize := (curve.params.N.BitLen() + 7) / 8
s := new(big.Int).SetBytes(scalar) s := new(big.Int).SetBytes(scalar)
if len(scalar) > byteSize { if len(scalar) > byteSize {
s.Mod(s, curve.params.N) s.Mod(s, curve.params.N)
} }
return s out := make([]byte, byteSize)
return s.FillBytes(out)
} }
func (g1 *G1Curve) ScalarBaseMult(k []byte) (*big.Int, *big.Int) { func (g1 *G1Curve) ScalarBaseMult(scalar []byte) (*big.Int, *big.Int) {
scalar := g1.normalizeScalar(k) scalar = g1.normalizeScalar(scalar)
res := g1.g.ScalarBaseMult(scalar).Marshal() p, err := g1.g.ScalarBaseMult(scalar)
if err != nil {
panic("sm9: g1 rejected normalized scalar")
}
res := p.Marshal()
return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:]) return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:])
} }
func (g1 *G1Curve) ScalarMult(Bx, By *big.Int, k []byte) (*big.Int, *big.Int) { func (g1 *G1Curve) ScalarMult(Bx, By *big.Int, scalar []byte) (*big.Int, *big.Int) {
a, err := g1.pointFromAffine(Bx, By) a, err := g1.pointFromAffine(Bx, By)
if err != nil { if err != nil {
panic("sm9: ScalarMult was called on an invalid point") panic("sm9: ScalarMult was called on an invalid point")
} }
res := g1.g.ScalarMult(a, new(big.Int).SetBytes(k)).Marshal() scalar = g1.normalizeScalar(scalar)
p, err := g1.g.ScalarMult(a, scalar)
if err != nil {
panic("sm9: g1 rejected normalized scalar")
}
res := p.Marshal()
return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:]) return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:])
} }

View File

@ -177,7 +177,10 @@ func TestG1ScaleMult(t *testing.T) {
t.Errorf("not same") t.Errorf("not same")
} }
e3.ScalarMult(Gen1, k) _, err = e3.ScalarMult(Gen1, NormalizeScalar(k.Bytes()))
if err != nil {
t.Fatal(err)
}
e3.p.MakeAffine() e3.p.MakeAffine()
if !e.Equal(e3) { if !e.Equal(e3) {

View File

@ -47,8 +47,8 @@ func RandomG2(r io.Reader) (*big.Int, *G2, error) {
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
g2, err := new(G2).ScalarBaseMult(NormalizeScalar(k.Bytes()))
return k, new(G2).ScalarBaseMult(k), nil return k, g2, err
} }
func (e *G2) String() string { func (e *G2) String() string {
@ -57,13 +57,15 @@ func (e *G2) String() string {
// ScalarBaseMult sets e to g*k where g is the generator of the group and then // ScalarBaseMult sets e to g*k where g is the generator of the group and then
// returns out. // returns out.
func (e *G2) ScalarBaseMult(k *big.Int) *G2 { func (e *G2) ScalarBaseMult(scalar []byte) (*G2, error) {
if len(scalar) != 32 {
return nil, errors.New("invalid scalar length")
}
if e.p == nil { if e.p == nil {
e.p = &twistPoint{} e.p = &twistPoint{}
} }
//e.p.Mul(twistGen, k) //e.p.Mul(twistGen, k)
scalar := normalizeScalar(k.Bytes())
tables := e.generatorTable() tables := e.generatorTable()
// This is also a scalar multiplication with a four-bit window like in // This is also a scalar multiplication with a four-bit window like in
// ScalarMult, but in this case the doublings are precomputed. The value // ScalarMult, but in this case the doublings are precomputed. The value
@ -85,11 +87,11 @@ func (e *G2) ScalarBaseMult(k *big.Int) *G2 {
tableIndex-- tableIndex--
} }
return e return e, nil
} }
// ScalarMult sets e to a*k and then returns e. // ScalarMult sets e to a*k and then returns e.
func (e *G2) ScalarMult(a *G2, k *big.Int) *G2 { func (e *G2) ScalarMult(a *G2, scalar []byte) (*G2, error) {
if e.p == nil { if e.p == nil {
e.p = &twistPoint{} e.p = &twistPoint{}
} }
@ -108,8 +110,7 @@ func (e *G2) ScalarMult(a *G2, k *big.Int) *G2 {
// four-bit window: we double four times, and then add [0-15]P. // four-bit window: we double four times, and then add [0-15]P.
t := &G2{NewTwistPoint()} t := &G2{NewTwistPoint()}
e.p.SetInfinity() e.p.SetInfinity()
scalarBytes := normalizeScalar(k.Bytes()) for i, byte := range scalar {
for i, byte := range scalarBytes {
// No need to double on the first iteration, as p is the identity at // No need to double on the first iteration, as p is the identity at
// this point, and [N]∞ = ∞. // this point, and [N]∞ = ∞.
if i != 0 { if i != 0 {
@ -129,7 +130,7 @@ func (e *G2) ScalarMult(a *G2, k *big.Int) *G2 {
table.Select(t.p, windowValue) table.Select(t.p, windowValue)
e.Add(e, t) e.Add(e, t)
} }
return e return e, nil
} }
// Add sets e to a+b and then returns e. // Add sets e to a+b and then returns e.

View File

@ -14,7 +14,10 @@ func TestG2(t *testing.T) {
} }
ma := Ga.Marshal() ma := Ga.Marshal()
Gb := new(G2).ScalarBaseMult(k) Gb, err := new(G2).ScalarBaseMult(NormalizeScalar(k.Bytes()))
if err != nil {
t.Fatal(err)
}
mb := Gb.Marshal() mb := Gb.Marshal()
if !bytes.Equal(ma, mb) { if !bytes.Equal(ma, mb) {
@ -86,7 +89,10 @@ func TestScaleMult(t *testing.T) {
e3.p.Mul(twistGen, k) e3.p.Mul(twistGen, k)
e3.p.MakeAffine() e3.p.MakeAffine()
e2.ScalarMult(Gen2, k) _, err = e2.ScalarMult(Gen2, NormalizeScalar(k.Bytes()))
if err != nil {
t.Fatal(err)
}
e2.p.MakeAffine() e2.p.MakeAffine()
if !e.Equal(e2) { if !e.Equal(e2) {
t.Errorf("not same") t.Errorf("not same")
@ -110,10 +116,11 @@ func TestG2AddNeg(t *testing.T) {
func BenchmarkG2(b *testing.B) { func BenchmarkG2(b *testing.B) {
x, _ := rand.Int(rand.Reader, Order) x, _ := rand.Int(rand.Reader, Order)
xb := NormalizeScalar(x.Bytes())
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
new(G2).ScalarBaseMult(x) new(G2).ScalarBaseMult(xb)
} }
} }

View File

@ -241,8 +241,10 @@ func GenerateGTFieldTable(basePoint *GT) *[32 * 2]GTFieldTable {
} }
// ScalarBaseMultGT compute basepoint^r with precomputed table // ScalarBaseMultGT compute basepoint^r with precomputed table
func ScalarBaseMultGT(tables *[32 * 2]GTFieldTable, r *big.Int) *GT { func ScalarBaseMultGT(tables *[32 * 2]GTFieldTable, scalar []byte) (*GT, error) {
scalar := normalizeScalar(r.Bytes()) if len(scalar) != 32 {
return nil, errors.New("invalid scalar length")
}
// This is also a scalar multiplication with a four-bit window like in // This is also a scalar multiplication with a four-bit window like in
// ScalarMult, but in this case the doublings are precomputed. The value // ScalarMult, but in this case the doublings are precomputed. The value
// [windowValue]G added at iteration k would normally get doubled // [windowValue]G added at iteration k would normally get doubled
@ -263,5 +265,48 @@ func ScalarBaseMultGT(tables *[32 * 2]GTFieldTable, r *big.Int) *GT {
e.Add(e, t) e.Add(e, t)
tableIndex-- tableIndex--
} }
return e return e, nil
}
// ScalarMultGT compute a^scalar
func ScalarMultGT(a *GT, scalar []byte) (*GT, error) {
var table GTFieldTable
table[0] = &GT{}
table[0].Set(a)
for i := 1; i < 15; i += 2 {
table[i] = &GT{}
table[i].p = &gfP12{}
table[i].p.Square(table[i/2].p)
table[i+1] = &GT{}
table[i+1].p = &gfP12{}
table[i+1].Add(table[i], a)
}
e, t := &GT{}, &GT{}
e.SetOne()
t.SetOne()
for i, byte := range scalar {
// No need to double on the first iteration, as p is the identity at
// this point, and [N]∞ = ∞.
if i != 0 {
e.p.Square(e.p)
e.p.Square(e.p)
e.p.Square(e.p)
e.p.Square(e.p)
}
windowValue := byte >> 4
table.Select(t, windowValue)
e.Add(e, t)
e.p.Square(e.p)
e.p.Square(e.p)
e.p.Square(e.p)
e.p.Square(e.p)
windowValue = byte & 0b1111
table.Select(t, windowValue)
e.Add(e, t)
}
return e, nil
} }

View File

@ -24,6 +24,19 @@ func TestGT(t *testing.T) {
if !bytes.Equal(ma, mb) { if !bytes.Equal(ma, mb) {
t.Fatal("bytes are different") t.Fatal("bytes are different")
} }
_, err = Gb.Unmarshal((&GT{gfP12Gen}).Marshal())
if err != nil {
t.Fatal("unmarshal not ok")
}
Gc, err := ScalarMultGT(Gb, k.Bytes())
if err != nil {
t.Fatal(err)
}
mc := Gc.Marshal()
if !bytes.Equal(ma, mc) {
t.Fatal("bytes are different")
}
} }
func BenchmarkGT(b *testing.B) { func BenchmarkGT(b *testing.B) {

View File

@ -10,6 +10,7 @@ import (
"io" "io"
"math/big" "math/big"
"github.com/emmansun/gmsm/internal/bigmod"
"github.com/emmansun/gmsm/internal/subtle" "github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/kdf" "github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm3" "github.com/emmansun/gmsm/sm3"
@ -20,6 +21,10 @@ import (
// SM9 ASN.1 format reference: Information security technology - SM9 cryptographic algorithm application specification // SM9 ASN.1 format reference: Information security technology - SM9 cryptographic algorithm application specification
// OrderNat is the Nat presentation of Order
var OrderNat = bigmod.NewModulusFromBig(bn256.Order)
var OrderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes()
var bigOne = big.NewInt(1) var bigOne = big.NewInt(1)
type hashMode byte type hashMode byte
@ -57,6 +62,7 @@ func hash(z []byte, h hashMode) *big.Int {
ct++ ct++
md.Reset() md.Reset()
} }
//TODO: how to rewrite this part with nat?
k := new(big.Int).SetBytes(ha[:40]) k := new(big.Int).SetBytes(ha[:40])
n := new(big.Int).Sub(bn256.Order, bigOne) n := new(big.Int).Sub(bn256.Order, bigOne)
k.Mod(k, n) k.Mod(k, n)
@ -72,48 +78,70 @@ func hashH2(z []byte) *big.Int {
return hash(z, H2) return hash(z, H2)
} }
// randFieldElement returns a random element of the order of the given func randomScalar(rand io.Reader) (k *bigmod.Nat, err error) {
// curve using the procedure given in FIPS 186-4, Appendix B.5.1. k = bigmod.NewNat()
func randFieldElement(rand io.Reader) (k *big.Int, err error) { for {
b := make([]byte, 40) // (256 + 64 / 8 b := make([]byte, OrderNat.Size())
_, err = io.ReadFull(rand, b) if _, err = io.ReadFull(rand, b); err != nil {
if err != nil {
return return
} }
k = new(big.Int).SetBytes(b) // Mask off any excess bits to increase the chance of hitting a value in
n := new(big.Int).Sub(bn256.Order, bigOne) // (0, N). These are the most dangerous lines in the package and maybe in
k.Mod(k, n) // the library: a single bit of bias in the selection of nonces would likely
k.Add(k, bigOne) // lead to key recovery, but no tests would fail. Look but DO NOT TOUCH.
if excess := len(b)*8 - OrderNat.BitLen(); excess > 0 {
// Just to be safe, assert that this only happens for the one curve that
// doesn't have a round number of bits.
if excess != 0 {
panic("sm9: internal error: unexpectedly masking off bits")
}
b[0] >>= excess
}
// FIPS 186-4 makes us check k <= N - 2 and then add one.
// Checking 0 < k <= N - 1 is strictly equivalent.
// None of this matters anyway because the chance of selecting
// zero is cryptographically negligible.
if _, err = k.SetBytes(b, OrderNat); err == nil && k.IsZero() == 0 {
break
}
}
return return
} }
// Sign signs a hash (which should be the result of hashing a larger message) // Sign signs a hash (which should be the result of hashing a larger message)
// using the user dsa key. It returns the signature as a pair of h and s. // using the user dsa key. It returns the signature as a pair of h and s.
func Sign(rand io.Reader, priv *SignPrivateKey, hash []byte) (h *big.Int, s *bn256.G1, err error) { func Sign(rand io.Reader, priv *SignPrivateKey, hash []byte) (h *big.Int, s *bn256.G1, err error) {
var r *big.Int var (
r *bigmod.Nat
w *bn256.GT
hNat *bigmod.Nat
)
for { for {
r, err = randFieldElement(rand) r, err = randomScalar(rand)
if err != nil { if err != nil {
return return
} }
w := priv.SignMasterPublicKey.ScalarBaseMult(r) w, err = priv.SignMasterPublicKey.ScalarBaseMult(r.Bytes(OrderNat))
if err != nil {
return
}
var buffer []byte var buffer []byte
buffer = append(buffer, hash...) buffer = append(buffer, hash...)
buffer = append(buffer, w.Marshal()...) buffer = append(buffer, w.Marshal()...)
h = hashH2(buffer) h = hashH2(buffer)
hNat, err = bigmod.NewNat().SetBytes(h.Bytes(), OrderNat)
l := new(big.Int).Sub(r, h) if err != nil {
return
if l.Sign() < 0 {
l.Add(l, bn256.Order)
} }
r.Sub(hNat, OrderNat)
if l.Sign() != 0 { if r.IsZero() == 0 {
s = new(bn256.G1).ScalarMult(priv.PrivateKey, l) s, err = new(bn256.G1).ScalarMult(priv.PrivateKey, r.Bytes(OrderNat))
break break
} }
} }
@ -129,7 +157,7 @@ func (priv *SignPrivateKey) Sign(rand io.Reader, hash []byte, opts crypto.Signer
return nil, err return nil, err
} }
hBytes := make([]byte, 32) hBytes := make([]byte, OrderNat.Size())
h.FillBytes(hBytes) h.FillBytes(hBytes)
var b cryptobyte.Builder var b cryptobyte.Builder
@ -156,7 +184,15 @@ func Verify(pub *SignMasterPublicKey, uid []byte, hid byte, hash []byte, h *big.
return false return false
} }
t := pub.ScalarBaseMult(h) hNat, err := bigmod.NewNat().SetBytes(h.Bytes(), OrderNat)
if err != nil {
return false
}
t, err := pub.ScalarBaseMult(hNat.Bytes(OrderNat))
if err != nil {
return false
}
// user sign public key p generation // user sign public key p generation
p := pub.GenerateUserPublicKey(uid, hid) p := pub.GenerateUserPublicKey(uid, hid)
@ -210,17 +246,26 @@ func (pub *SignMasterPublicKey) Verify(uid []byte, hid byte, hash, sig []byte) b
// WrapKey generate and wrap key with reciever's uid and system hid // WrapKey generate and wrap key with reciever's uid and system hid
func WrapKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, kLen int) (key []byte, cipher *bn256.G1, err error) { func WrapKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, kLen int) (key []byte, cipher *bn256.G1, err error) {
q := pub.GenerateUserPublicKey(uid, hid) q := pub.GenerateUserPublicKey(uid, hid)
var r *big.Int var (
r *bigmod.Nat
w *bn256.GT
)
for { for {
r, err = randFieldElement(rand) r, err = randomScalar(rand)
if err != nil { if err != nil {
return return
} }
cipher = new(bn256.G1).ScalarMult(q, r) rBytes := r.Bytes(OrderNat)
cipher, err = new(bn256.G1).ScalarMult(q, rBytes)
w := pub.ScalarBaseMult(r) if err != nil {
return
}
w, err = pub.ScalarBaseMult(rBytes)
if err != nil {
return
}
var buffer []byte var buffer []byte
buffer = append(buffer, cipher.Marshal()...) buffer = append(buffer, cipher.Marshal()...)
buffer = append(buffer, w.Marshal()...) buffer = append(buffer, w.Marshal()...)
@ -463,7 +508,7 @@ type KeyExchange struct {
privateKey *EncryptPrivateKey // owner's encryption private key privateKey *EncryptPrivateKey // owner's encryption private key
uid []byte // owner uid uid []byte // owner uid
peerUID []byte // peer uid peerUID []byte // peer uid
r *big.Int // random which will be used to compute secret r *bigmod.Nat // random which will be used to compute secret
secret *bn256.G1 // generated secret which will be passed to peer secret *bn256.G1 // generated secret which will be passed to peer
peerSecret *bn256.G1 // received peer's secret peerSecret *bn256.G1 // received peer's secret
g1 *bn256.GT // internal state which will be used when compute the key and signature g1 *bn256.GT // internal state which will be used when compute the key and signature
@ -485,7 +530,7 @@ func NewKeyExchange(priv *EncryptPrivateKey, uid, peerUID []byte, keyLen int, ge
// Destroy clear all internal state and Ephemeral private/public keys // Destroy clear all internal state and Ephemeral private/public keys
func (ke *KeyExchange) Destroy() { func (ke *KeyExchange) Destroy() {
if ke.r != nil { if ke.r != nil {
ke.r.SetInt64(0) ke.r.SetBytes([]byte{0}, OrderNat)
} }
if ke.g1 != nil { if ke.g1 != nil {
ke.g1.SetOne() ke.g1.SetOne()
@ -498,16 +543,19 @@ func (ke *KeyExchange) Destroy() {
} }
} }
func initKeyExchange(ke *KeyExchange, hid byte, r *big.Int) { func initKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat) {
pubB := ke.privateKey.GenerateUserPublicKey(ke.peerUID, hid) pubB := ke.privateKey.GenerateUserPublicKey(ke.peerUID, hid)
ke.r = r ke.r = r
rA := new(bn256.G1).ScalarMult(pubB, ke.r) rA, err := new(bn256.G1).ScalarMult(pubB, ke.r.Bytes(OrderNat))
if err != nil {
panic(err)
}
ke.secret = rA ke.secret = rA
} }
// InitKeyExchange generate random with responder uid, for initiator's step A1-A4 // InitKeyExchange generate random with responder uid, for initiator's step A1-A4
func (ke *KeyExchange) InitKeyExchange(rand io.Reader, hid byte) (*bn256.G1, error) { func (ke *KeyExchange) InitKeyExchange(rand io.Reader, hid byte) (*bn256.G1, error) {
r, err := randFieldElement(rand) r, err := randomScalar(rand)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -559,20 +607,33 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil
} }
func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*bn256.G1, []byte, error) { func respondKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat, rA *bn256.G1) (*bn256.G1, []byte, error) {
if !rA.IsOnCurve() { if !rA.IsOnCurve() {
return nil, nil, errors.New("sm9: invalid initiator's ephemeral public key") return nil, nil, errors.New("sm9: invalid initiator's ephemeral public key")
} }
ke.peerSecret = rA ke.peerSecret = rA
pubA := ke.privateKey.GenerateUserPublicKey(ke.peerUID, hid) pubA := ke.privateKey.GenerateUserPublicKey(ke.peerUID, hid)
ke.r = r ke.r = r
rB := new(bn256.G1).ScalarMult(pubA, r) rBytes := r.Bytes(OrderNat)
rB, err := new(bn256.G1).ScalarMult(pubA, rBytes)
if err != nil {
return nil, nil, err
}
ke.secret = rB ke.secret = rB
ke.g1 = bn256.Pair(ke.peerSecret, ke.privateKey.PrivateKey) ke.g1 = bn256.Pair(ke.peerSecret, ke.privateKey.PrivateKey)
ke.g3 = &bn256.GT{} ke.g3 = &bn256.GT{}
ke.g3.ScalarMult(ke.g1, r) g3, err := bn256.ScalarMultGT(ke.g1, rBytes)
ke.g2 = ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(r) if err != nil {
return nil, nil, err
}
ke.g3 = g3
g2, err := ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(rBytes)
if err != nil {
return nil, nil, err
}
ke.g2 = g2
if !ke.genSignature { if !ke.genSignature {
return ke.secret, nil, nil return ke.secret, nil, nil
@ -583,7 +644,7 @@ func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*b
// RepondKeyExchange when responder receive rA, for responder's step B1-B7 // RepondKeyExchange when responder receive rA, for responder's step B1-B7
func (ke *KeyExchange) RepondKeyExchange(rand io.Reader, hid byte, rA *bn256.G1) (*bn256.G1, []byte, error) { func (ke *KeyExchange) RepondKeyExchange(rand io.Reader, hid byte, rA *bn256.G1) (*bn256.G1, []byte, error) {
r, err := randFieldElement(rand) r, err := randomScalar(rand)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -597,10 +658,18 @@ func (ke *KeyExchange) ConfirmResponder(rB *bn256.G1, sB []byte) ([]byte, []byte
} }
// step 5 // step 5
ke.peerSecret = rB ke.peerSecret = rB
ke.g1 = ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(ke.r) g1, err := ke.privateKey.EncryptMasterPublicKey.ScalarBaseMult(ke.r.Bytes(OrderNat))
if err != nil {
return nil, nil, err
}
ke.g1 = g1
ke.g2 = bn256.Pair(ke.peerSecret, ke.privateKey.PrivateKey) ke.g2 = bn256.Pair(ke.peerSecret, ke.privateKey.PrivateKey)
ke.g3 = &bn256.GT{} ke.g3 = &bn256.GT{}
ke.g3.ScalarMult(ke.g2, ke.r) g3, err := bn256.ScalarMultGT(ke.g2, ke.r.Bytes(OrderNat))
if err != nil {
return nil, nil, err
}
ke.g3 = g3
// step 6, verify signature // step 6, verify signature
if len(sB) > 0 { if len(sB) > 0 {
signature := ke.sign(false, 0x82) signature := ke.sign(false, 0x82)

View File

@ -8,6 +8,7 @@ import (
"math/big" "math/big"
"sync" "sync"
"github.com/emmansun/gmsm/internal/bigmod"
"github.com/emmansun/gmsm/sm9/bn256" "github.com/emmansun/gmsm/sm9/bn256"
"golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte"
cryptobyte_asn1 "golang.org/x/crypto/cryptobyte/asn1" cryptobyte_asn1 "golang.org/x/crypto/cryptobyte/asn1"
@ -57,14 +58,19 @@ type EncryptPrivateKey struct {
// GenerateSignMasterKey generates a master public and private key pair for DSA usage. // GenerateSignMasterKey generates a master public and private key pair for DSA usage.
func GenerateSignMasterKey(rand io.Reader) (*SignMasterPrivateKey, error) { func GenerateSignMasterKey(rand io.Reader) (*SignMasterPrivateKey, error) {
k, err := randFieldElement(rand) k, err := randomScalar(rand)
if err != nil {
return nil, err
}
kBytes := k.Bytes(OrderNat)
p, err := new(bn256.G2).ScalarBaseMult(kBytes)
if err != nil { if err != nil {
return nil, err return nil, err
} }
priv := new(SignMasterPrivateKey) priv := new(SignMasterPrivateKey)
priv.D = k priv.D = new(big.Int).SetBytes(kBytes)
priv.MasterPublicKey = new(bn256.G2).ScalarBaseMult(k) priv.MasterPublicKey = p
return priv, nil return priv, nil
} }
@ -96,7 +102,11 @@ func (master *SignMasterPrivateKey) UnmarshalASN1(der []byte) error {
return errors.New("sm9: invalid sign master private key asn1 data") return errors.New("sm9: invalid sign master private key asn1 data")
} }
master.D = d master.D = d
master.MasterPublicKey = new(bn256.G2).ScalarBaseMult(d) p, err := new(bn256.G2).ScalarBaseMult(bn256.NormalizeScalar(d.Bytes()))
if err != nil {
return err
}
master.MasterPublicKey = p
return nil return nil
} }
@ -107,17 +117,32 @@ func (master *SignMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*Sign
id = append(id, hid) id = append(id, hid)
t1 := hashH1(id) t1 := hashH1(id)
t1.Add(t1, master.D)
if t1.Sign() == 0 { t1Nat, err := bigmod.NewNat().SetBytes(t1.Bytes(), OrderNat)
if err != nil {
return nil, err
}
d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), OrderNat)
if err != nil {
return nil, err
}
t1Nat.Add(d, OrderNat)
if t1Nat.IsZero() == 1 {
return nil, errors.New("sm9: need to re-generate sign master private key") return nil, errors.New("sm9: need to re-generate sign master private key")
} }
t1 = fermatInverse(t1, bn256.Order)
t2 := new(big.Int).Mul(t1, master.D) t1Nat = bigmod.NewNat().Exp(t1Nat, OrderMinus2, OrderNat)
t2.Mod(t2, bn256.Order) t1Nat.Mul(d, OrderNat)
priv := new(SignPrivateKey) priv := new(SignPrivateKey)
priv.SignMasterPublicKey = master.SignMasterPublicKey priv.SignMasterPublicKey = master.SignMasterPublicKey
priv.PrivateKey = new(bn256.G1).ScalarBaseMult(t2) g1, err := new(bn256.G1).ScalarBaseMult(t1Nat.Bytes(OrderNat))
if err != nil {
return nil, err
}
priv.PrivateKey = g1
return priv, nil return priv, nil
} }
@ -144,9 +169,9 @@ func (pub *SignMasterPublicKey) generatorTable() *[32 * 2]bn256.GTFieldTable {
// ScalarBaseMult compute basepoint^r with precomputed table // ScalarBaseMult compute basepoint^r with precomputed table
// The base point = pair(Gen1, <master public key>) // The base point = pair(Gen1, <master public key>)
func (pub *SignMasterPublicKey) ScalarBaseMult(r *big.Int) *bn256.GT { func (pub *SignMasterPublicKey) ScalarBaseMult(scalar []byte) (*bn256.GT, error) {
tables := pub.generatorTable() tables := pub.generatorTable()
return bn256.ScalarBaseMultGT(tables, r) return bn256.ScalarBaseMultGT(tables, scalar)
} }
// GenerateUserPublicKey generate user sign public key // GenerateUserPublicKey generate user sign public key
@ -155,7 +180,10 @@ func (pub *SignMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *bn2
buffer = append(buffer, uid...) buffer = append(buffer, uid...)
buffer = append(buffer, hid) buffer = append(buffer, hid)
h1 := hashH1(buffer) h1 := hashH1(buffer)
p := new(bn256.G2).ScalarBaseMult(h1) p, err := new(bn256.G2).ScalarBaseMult(bn256.NormalizeScalar(h1.Bytes()))
if err != nil {
panic(err)
}
p.Add(p, pub.MasterPublicKey) p.Add(p, pub.MasterPublicKey)
return p return p
} }
@ -326,14 +354,19 @@ func (priv *SignPrivateKey) UnmarshalASN1(der []byte) error {
// GenerateEncryptMasterKey generates a master public and private key pair for encryption usage. // GenerateEncryptMasterKey generates a master public and private key pair for encryption usage.
func GenerateEncryptMasterKey(rand io.Reader) (*EncryptMasterPrivateKey, error) { func GenerateEncryptMasterKey(rand io.Reader) (*EncryptMasterPrivateKey, error) {
k, err := randFieldElement(rand) k, err := randomScalar(rand)
if err != nil { if err != nil {
return nil, err return nil, err
} }
kBytes := k.Bytes(OrderNat)
priv := new(EncryptMasterPrivateKey) priv := new(EncryptMasterPrivateKey)
priv.D = k priv.D = new(big.Int).SetBytes(kBytes)
priv.MasterPublicKey = new(bn256.G1).ScalarBaseMult(k) p, err := new(bn256.G1).ScalarBaseMult(kBytes)
if err != nil {
panic(err)
}
priv.MasterPublicKey = p
return priv, nil return priv, nil
} }
@ -344,17 +377,32 @@ func (master *EncryptMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*E
id = append(id, hid) id = append(id, hid)
t1 := hashH1(id) t1 := hashH1(id)
t1.Add(t1, master.D)
if t1.Sign() == 0 { t1Nat, err := bigmod.NewNat().SetBytes(t1.Bytes(), OrderNat)
if err != nil {
return nil, err
}
d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), OrderNat)
if err != nil {
return nil, err
}
t1Nat.Add(d, OrderNat)
if t1Nat.IsZero() == 1 {
return nil, errors.New("sm9: need to re-generate encrypt master private key") return nil, errors.New("sm9: need to re-generate encrypt master private key")
} }
t1 = fermatInverse(t1, bn256.Order)
t2 := new(big.Int).Mul(t1, master.D) t1Nat = bigmod.NewNat().Exp(t1Nat, OrderMinus2, OrderNat)
t2.Mod(t2, bn256.Order) t1Nat.Mul(d, OrderNat)
priv := new(EncryptPrivateKey) priv := new(EncryptPrivateKey)
priv.EncryptMasterPublicKey = master.EncryptMasterPublicKey priv.EncryptMasterPublicKey = master.EncryptMasterPublicKey
priv.PrivateKey = new(bn256.G2).ScalarBaseMult(t2) p, err := new(bn256.G2).ScalarBaseMult(t1Nat.Bytes(OrderNat))
if err != nil {
panic(err)
}
priv.PrivateKey = p
return priv, nil return priv, nil
} }
@ -392,7 +440,11 @@ func (master *EncryptMasterPrivateKey) UnmarshalASN1(der []byte) error {
return errors.New("sm9: invalid encrypt master private key asn1 data") return errors.New("sm9: invalid encrypt master private key asn1 data")
} }
master.D = d master.D = d
master.MasterPublicKey = new(bn256.G1).ScalarBaseMult(d) p, err := new(bn256.G1).ScalarBaseMult(bn256.NormalizeScalar(d.Bytes()))
if err != nil {
return err
}
master.MasterPublicKey = p
return nil return nil
} }
@ -413,9 +465,9 @@ func (pub *EncryptMasterPublicKey) generatorTable() *[32 * 2]bn256.GTFieldTable
// ScalarBaseMult compute basepoint^r with precomputed table. // ScalarBaseMult compute basepoint^r with precomputed table.
// The base point = pair(<master public key>, Gen2) // The base point = pair(<master public key>, Gen2)
func (pub *EncryptMasterPublicKey) ScalarBaseMult(r *big.Int) *bn256.GT { func (pub *EncryptMasterPublicKey) ScalarBaseMult(scalar []byte) (*bn256.GT, error) {
tables := pub.generatorTable() tables := pub.generatorTable()
return bn256.ScalarBaseMultGT(tables, r) return bn256.ScalarBaseMultGT(tables, scalar)
} }
// GenerateUserPublicKey generate user encrypt public key // GenerateUserPublicKey generate user encrypt public key
@ -424,7 +476,10 @@ func (pub *EncryptMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *
buffer = append(buffer, uid...) buffer = append(buffer, uid...)
buffer = append(buffer, hid) buffer = append(buffer, hid)
h1 := hashH1(buffer) h1 := hashH1(buffer)
p := new(bn256.G1).ScalarBaseMult(h1) p, err := new(bn256.G1).ScalarBaseMult(bn256.NormalizeScalar(h1.Bytes()))
if err != nil {
panic(err)
}
p.Add(p, pub.MasterPublicKey) p.Add(p, pub.MasterPublicKey)
return p return p
} }
@ -554,14 +609,3 @@ func (priv *EncryptPrivateKey) UnmarshalASN1(der []byte) error {
} }
return nil return nil
} }
// fermatInverse calculates the inverse of k in GF(P) using Fermat's method
// (exponentiation modulo P - 2, per Euler's theorem). This has better
// constant-time properties than Euclid's method (implemented in
// math/big.Int.ModInverse and FIPS 186-4, Appendix C.1) although math/big
// itself isn't strictly constant-time so it's not perfect.
func fermatInverse(k, N *big.Int) *big.Int {
two := big.NewInt(2)
nMinus2 := new(big.Int).Sub(N, two)
return new(big.Int).Exp(k, nMinus2, N)
}

View File

@ -6,6 +6,7 @@ import (
"math/big" "math/big"
"testing" "testing"
"github.com/emmansun/gmsm/internal/bigmod"
"github.com/emmansun/gmsm/internal/subtle" "github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/kdf" "github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm3" "github.com/emmansun/gmsm/sm3"
@ -98,12 +99,19 @@ func TestSignSM9Sample(t *testing.T) {
masterKey := new(SignMasterPrivateKey) masterKey := new(SignMasterPrivateKey)
masterKey.D = bigFromHex("0130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4") masterKey.D = bigFromHex("0130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4")
masterKey.MasterPublicKey = new(bn256.G2).ScalarBaseMult(masterKey.D) p, err := new(bn256.G2).ScalarBaseMult(bn256.NormalizeScalar(masterKey.D.Bytes()))
if err != nil {
t.Fatal(err)
}
masterKey.MasterPublicKey = p
userKey, err := masterKey.GenerateUserKey(uid, hid) userKey, err := masterKey.GenerateUserKey(uid, hid)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
w := userKey.SignMasterPublicKey.ScalarBaseMult(r) w, err := userKey.SignMasterPublicKey.ScalarBaseMult(bn256.NormalizeScalar(r.Bytes()))
if err != nil {
t.Fatal(err)
}
var buffer []byte var buffer []byte
buffer = append(buffer, hash...) buffer = append(buffer, hash...)
@ -120,7 +128,10 @@ func TestSignSM9Sample(t *testing.T) {
l.Add(l, bn256.Order) l.Add(l, bn256.Order)
} }
s := new(bn256.G1).ScalarMult(userKey.PrivateKey, l) s, err := new(bn256.G1).ScalarMult(userKey.PrivateKey, bn256.NormalizeScalar(l.Bytes()))
if err != nil {
t.Fatal(err)
}
if hex.EncodeToString(s.MarshalUncompressed()) != expectedS { if hex.EncodeToString(s.MarshalUncompressed()) != expectedS {
t.Fatal("not same S") t.Fatal("not same S")
@ -137,7 +148,11 @@ func TestKeyExchangeSample(t *testing.T) {
masterKey := new(EncryptMasterPrivateKey) masterKey := new(EncryptMasterPrivateKey)
masterKey.D = bigFromHex("02E65B0762D042F51F0D23542B13ED8CFA2E9A0E7206361E013A283905E31F") masterKey.D = bigFromHex("02E65B0762D042F51F0D23542B13ED8CFA2E9A0E7206361E013A283905E31F")
masterKey.MasterPublicKey = new(bn256.G1).ScalarBaseMult(masterKey.D) p, err := new(bn256.G1).ScalarBaseMult(bn256.NormalizeScalar(masterKey.D.Bytes()))
if err != nil {
t.Fatal(err)
}
masterKey.MasterPublicKey = p
if hex.EncodeToString(masterKey.MasterPublicKey.Marshal()) != expectedPube { if hex.EncodeToString(masterKey.MasterPublicKey.Marshal()) != expectedPube {
t.Errorf("not expected master public key") t.Errorf("not expected master public key")
@ -162,14 +177,22 @@ func TestKeyExchangeSample(t *testing.T) {
responder.Destroy() responder.Destroy()
}() }()
// A1-A4 // A1-A4
initKeyExchange(initiator, hid, bigFromHex("5879DD1D51E175946F23B1B41E93BA31C584AE59A426EC1046A4D03B06C8")) k, err := bigmod.NewNat().SetBytes(bigFromHex("5879DD1D51E175946F23B1B41E93BA31C584AE59A426EC1046A4D03B06C8").Bytes(), OrderNat)
if err != nil {
t.Fatal(err)
}
initKeyExchange(initiator, hid, k)
if hex.EncodeToString(initiator.secret.Marshal()) != "7cba5b19069ee66aa79d490413d11846b9ba76dd22567f809cf23b6d964bb265a9760c99cb6f706343fed05637085864958d6c90902aba7d405fbedf7b781599" { if hex.EncodeToString(initiator.secret.Marshal()) != "7cba5b19069ee66aa79d490413d11846b9ba76dd22567f809cf23b6d964bb265a9760c99cb6f706343fed05637085864958d6c90902aba7d405fbedf7b781599" {
t.Fatal("not same") t.Fatal("not same")
} }
// B1 - B7 // B1 - B7
rB, sigB, err := respondKeyExchange(responder, hid, bigFromHex("018B98C44BEF9F8537FB7D071B2C928B3BC65BD3D69E1EEE213564905634FE"), initiator.secret) k, err = bigmod.NewNat().SetBytes(bigFromHex("018B98C44BEF9F8537FB7D071B2C928B3BC65BD3D69E1EEE213564905634FE").Bytes(), OrderNat)
if err != nil {
t.Fatal(err)
}
rB, sigB, err := respondKeyExchange(responder, hid, k, initiator.secret)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -403,7 +426,11 @@ func TestWrapKeySM9Sample(t *testing.T) {
masterKey := new(EncryptMasterPrivateKey) masterKey := new(EncryptMasterPrivateKey)
masterKey.D = bigFromHex("01EDEE3778F441F8DEA3D9FA0ACC4E07EE36C93F9A08618AF4AD85CEDE1C22") masterKey.D = bigFromHex("01EDEE3778F441F8DEA3D9FA0ACC4E07EE36C93F9A08618AF4AD85CEDE1C22")
masterKey.MasterPublicKey = new(bn256.G1).ScalarBaseMult(masterKey.D) p, err := new(bn256.G1).ScalarBaseMult(bn256.NormalizeScalar(masterKey.D.Bytes()))
if err != nil {
t.Fatal(err)
}
masterKey.MasterPublicKey = p
if hex.EncodeToString(masterKey.MasterPublicKey.Marshal()) != expectedMasterPublicKey { if hex.EncodeToString(masterKey.MasterPublicKey.Marshal()) != expectedMasterPublicKey {
t.Errorf("not expected master public key") t.Errorf("not expected master public key")
} }
@ -425,7 +452,10 @@ func TestWrapKeySM9Sample(t *testing.T) {
} }
var r *big.Int = bigFromHex("74015F8489C01EF4270456F9E6475BFB602BDE7F33FD482AB4E3684A6722") var r *big.Int = bigFromHex("74015F8489C01EF4270456F9E6475BFB602BDE7F33FD482AB4E3684A6722")
cipher := new(bn256.G1).ScalarMult(q, r) cipher, err := new(bn256.G1).ScalarMult(q, bn256.NormalizeScalar(r.Bytes()))
if err != nil {
t.Fatal(err)
}
if hex.EncodeToString(cipher.Marshal()) != expectedCipher { if hex.EncodeToString(cipher.Marshal()) != expectedCipher {
t.Errorf("not expected cipher") t.Errorf("not expected cipher")
} }
@ -465,7 +495,11 @@ func TestEncryptSM9Sample(t *testing.T) {
masterKey := new(EncryptMasterPrivateKey) masterKey := new(EncryptMasterPrivateKey)
masterKey.D = bigFromHex("01EDEE3778F441F8DEA3D9FA0ACC4E07EE36C93F9A08618AF4AD85CEDE1C22") masterKey.D = bigFromHex("01EDEE3778F441F8DEA3D9FA0ACC4E07EE36C93F9A08618AF4AD85CEDE1C22")
masterKey.MasterPublicKey = new(bn256.G1).ScalarBaseMult(masterKey.D) p, err := new(bn256.G1).ScalarBaseMult(bn256.NormalizeScalar(masterKey.D.Bytes()))
if err != nil {
t.Fatal(err)
}
masterKey.MasterPublicKey = p
if hex.EncodeToString(masterKey.MasterPublicKey.Marshal()) != expectedMasterPublicKey { if hex.EncodeToString(masterKey.MasterPublicKey.Marshal()) != expectedMasterPublicKey {
t.Errorf("not expected master public key") t.Errorf("not expected master public key")
} }
@ -487,7 +521,10 @@ func TestEncryptSM9Sample(t *testing.T) {
} }
var r *big.Int = bigFromHex("AAC0541779C8FC45E3E2CB25C12B5D2576B2129AE8BB5EE2CBE5EC9E785C") var r *big.Int = bigFromHex("AAC0541779C8FC45E3E2CB25C12B5D2576B2129AE8BB5EE2CBE5EC9E785C")
cipher := new(bn256.G1).ScalarMult(q, r) cipher, err := new(bn256.G1).ScalarMult(q, bn256.NormalizeScalar(r.Bytes()))
if err != nil {
t.Fatal(err)
}
if hex.EncodeToString(cipher.Marshal()) != expectedCipher { if hex.EncodeToString(cipher.Marshal()) != expectedCipher {
t.Errorf("not expected cipher") t.Errorf("not expected cipher")
} }