diff --git a/internal/bigmod/nat.go b/internal/bigmod/nat.go index 642214f..335e80d 100644 --- a/internal/bigmod/nat.go +++ b/internal/bigmod/nat.go @@ -123,11 +123,11 @@ func (x *Nat) Set(y *Nat) *Nat { return x } -// setBig assigns x = n, optionally resizing n to the appropriate size. +// SetBig assigns x = n, optionally resizing n to the appropriate size. // // The announced length of x is set based on the actual bit size of the input, // ignoring leading zeroes. -func (x *Nat) setBig(n *big.Int) *Nat { +func (x *Nat) SetBig(n *big.Int) *Nat { requiredLimbs := (n.BitLen() + _W - 1) / _W x.reset(requiredLimbs) @@ -386,7 +386,7 @@ func minusInverseModW(x uint) uint { // The Int must be odd. The number of significant bits must be leakable. func NewModulusFromBig(n *big.Int) *Modulus { m := &Modulus{} - m.nat = NewNat().setBig(n) + m.nat = NewNat().SetBig(n) m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1]) m.m0inv = minusInverseModW(m.nat.limbs[0]) m.rr = rr(m) @@ -426,13 +426,20 @@ func (m *Modulus) Nat() *Nat { // // This assumes that x is already reduced mod m, and that y < 2^_W. func (x *Nat) shiftIn(y uint, m *Modulus) *Nat { - d := NewNat().resetFor(m) + return x.shiftInNat(y, m.nat) +} + +// shiftIn calculates x = x << _W + y mod m. +// +// This assumes that x is already reduced mod m, and that y < 2^_W. +func (x *Nat) shiftInNat(y uint, m *Nat) *Nat { + d := NewNat().reset(len(m.limbs)) // Eliminate bounds checks in the loop. - size := len(m.nat.limbs) + size := len(m.limbs) xLimbs := x.limbs[:size] dLimbs := d.limbs[:size] - mLimbs := m.nat.limbs[:size] + mLimbs := m.limbs[:size] // Each iteration of this loop computes x = 2x + b mod m, where b is a bit // from y. Effectively, it left-shifts x and adds y one bit at a time, @@ -469,7 +476,16 @@ func (x *Nat) shiftIn(y uint, m *Modulus) *Nat { // // The output will be resized to the size of m and overwritten. func (out *Nat) Mod(x *Nat, m *Modulus) *Nat { - out.resetFor(m) + return out.ModNat(x, m.nat) +} + +// Mod calculates out = x mod m. +// +// This works regardless how large the value of x is. +// +// The output will be resized to the size of m and overwritten. +func (out *Nat) ModNat(x *Nat, m *Nat) *Nat { + out.reset(len(m.limbs)) // Working our way from the most significant to the least significant limb, // we can insert each limb at the least significant position, shifting all // previous limbs left by _W. This way each limb will get shifted by the @@ -478,7 +494,7 @@ func (out *Nat) Mod(x *Nat, m *Modulus) *Nat { i := len(x.limbs) - 1 // For the first N - 1 limbs we can skip the actual shifting and position // them at the shifted position, which starts at min(N - 2, i). - start := len(m.nat.limbs) - 2 + start := len(m.limbs) - 2 if i < start { start = i } @@ -488,7 +504,7 @@ func (out *Nat) Mod(x *Nat, m *Modulus) *Nat { } // We shift in the remaining limbs, reducing modulo m each time. for i >= 0 { - out.shiftIn(x.limbs[i], m) + out.shiftInNat(x.limbs[i], m) i-- } return out diff --git a/internal/bigmod/nat_test.go b/internal/bigmod/nat_test.go index 89ce818..8d82165 100644 --- a/internal/bigmod/nat_test.go +++ b/internal/bigmod/nat_test.go @@ -244,6 +244,52 @@ func TestMod(t *testing.T) { } } +func TestModNat(t *testing.T) { + order, _ := new(big.Int).SetString("b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25", 16) + orderNat := NewModulusFromBig(order) + oneNat, err := NewNat().SetBytes(big.NewInt(1).Bytes(), orderNat) + if err != nil { + t.Fatal(err) + } + orderMinus1, _ := new(big.Int).SetString("b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf24", 16) + hashValue1, _ := new(big.Int).SetString("1000000000000000a640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25", 16) + hashValue2, _ := new(big.Int).SetString("1000000000000000b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf23", 16) + examples := []struct { + in *big.Int + expected *big.Int + }{ + { + big.NewInt(1), + big.NewInt(2), + }, + { + orderMinus1, + big.NewInt(1), + }, + { + order, + big.NewInt(2), + }, + { + hashValue1, + new(big.Int).Add(new(big.Int).Mod(hashValue1, orderMinus1), big.NewInt(1)), + }, + { + hashValue2, + new(big.Int).Add(new(big.Int).Mod(hashValue2, orderMinus1), big.NewInt(1)), + }, + } + for i, tt := range examples { + kNat := NewNat().SetBig(tt.in) + kNat = NewNat().ModNat(kNat, NewNat().SetBig(orderMinus1)) + kNat.Add(oneNat, orderNat) + out := new(big.Int).SetBytes(kNat.Bytes(orderNat)) + if out.Cmp(tt.expected) != 0 { + t.Errorf("%d: got %x, expected %x", i, out, tt.expected) + } + } +} + func TestModSub(t *testing.T) { m := modulusFromBytes([]byte{13}) x := &Nat{[]uint{6}} @@ -293,7 +339,7 @@ func natBytes(n *Nat) []byte { func natFromBytes(b []byte) *Nat { bb := new(big.Int).SetBytes(b) - return NewNat().setBig(bb) + return NewNat().SetBig(bb) } func modulusFromBytes(b []byte) *Modulus { diff --git a/sm9/sm9.go b/sm9/sm9.go index cd1211d..dc00564 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -24,7 +24,12 @@ import ( var orderNat = bigmod.NewModulusFromBig(bn256.Order) var orderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes() var bigOne = big.NewInt(1) -var orderMinus1 = new(big.Int).Sub(bn256.Order, bigOne) +var bigOneNat *bigmod.Nat +var orderMinus1 = bigmod.NewNat().SetBig(new(big.Int).Sub(bn256.Order, bigOne)) + +func init() { + bigOneNat, _ = bigmod.NewNat().SetBytes(bigOne.Bytes(), orderNat) +} type hashMode byte @@ -46,7 +51,7 @@ const ( ) //hash implements H1(Z,n) or H2(Z,n) in sm9 algorithm. -func hash(z []byte, h hashMode) *big.Int { +func hash(z []byte, h hashMode) *bigmod.Nat { md := sm3.New() var ha [64]byte var countBytes [4]byte @@ -61,18 +66,18 @@ func hash(z []byte, h hashMode) *big.Int { ct++ md.Reset() } - //TODO: how to rewrite this part with nat? k := new(big.Int).SetBytes(ha[:40]) - k.Mod(k, orderMinus1) - k.Add(k, bigOne) - return k + kNat := bigmod.NewNat().SetBig(k) + kNat = bigmod.NewNat().ModNat(kNat, orderMinus1) + kNat.Add(bigOneNat, orderNat) + return kNat } -func hashH1(z []byte) *big.Int { +func hashH1(z []byte) *bigmod.Nat { return hash(z, H1) } -func hashH2(z []byte) *big.Int { +func hashH2(z []byte) *bigmod.Nat { return hash(z, H2) } @@ -131,14 +136,11 @@ func Sign(rand io.Reader, priv *SignPrivateKey, hash []byte) (h *big.Int, s *bn2 buffer = append(buffer, hash...) buffer = append(buffer, w.Marshal()...) - h = hashH2(buffer) - hNat, err = bigmod.NewNat().SetBytes(h.Bytes(), orderNat) - if err != nil { - return - } + hNat = hashH2(buffer) r.Sub(hNat, orderNat) if r.IsZero() == 0 { + h = new(big.Int).SetBytes(hNat.Bytes(orderNat)) s, err = new(bn256.G1).ScalarMult(priv.PrivateKey, r.Bytes(orderNat)) break } @@ -203,7 +205,7 @@ func Verify(pub *SignMasterPublicKey, uid []byte, hid byte, hash []byte, h *big. buffer = append(buffer, w.Marshal()...) h2 := hashH2(buffer) - return h.Cmp(h2) == 0 + return h2.Equal(hNat) == 1 } // VerifyASN1 verifies the ASN.1 encoded signature of type SM9Signature, sig, of hash using the diff --git a/sm9/sm9_key.go b/sm9/sm9_key.go index 8291e85..f2cc7ec 100644 --- a/sm9/sm9_key.go +++ b/sm9/sm9_key.go @@ -116,12 +116,7 @@ func (master *SignMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*Sign id = append(id, uid...) id = append(id, hid) - t1 := hashH1(id) - - t1Nat, err := bigmod.NewNat().SetBytes(t1.Bytes(), orderNat) - if err != nil { - return nil, err - } + t1Nat := hashH1(id) d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), orderNat) if err != nil { @@ -180,7 +175,7 @@ func (pub *SignMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *bn2 buffer = append(buffer, uid...) buffer = append(buffer, hid) h1 := hashH1(buffer) - p, err := new(bn256.G2).ScalarBaseMult(bn256.NormalizeScalar(h1.Bytes())) + p, err := new(bn256.G2).ScalarBaseMult(h1.Bytes(orderNat)) if err != nil { panic(err) } @@ -376,12 +371,7 @@ func (master *EncryptMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*E id = append(id, uid...) id = append(id, hid) - t1 := hashH1(id) - - t1Nat, err := bigmod.NewNat().SetBytes(t1.Bytes(), orderNat) - if err != nil { - return nil, err - } + t1Nat := hashH1(id) d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), orderNat) if err != nil { @@ -476,7 +466,7 @@ func (pub *EncryptMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) * buffer = append(buffer, uid...) buffer = append(buffer, hid) h1 := hashH1(buffer) - p, err := new(bn256.G1).ScalarBaseMult(bn256.NormalizeScalar(h1.Bytes())) + p, err := new(bn256.G1).ScalarBaseMult(h1.Bytes(orderNat)) if err != nil { panic(err) } diff --git a/sm9/sm9_test.go b/sm9/sm9_test.go index 13a5650..81eed76 100644 --- a/sm9/sm9_test.go +++ b/sm9/sm9_test.go @@ -24,8 +24,8 @@ func bigFromHex(s string) *big.Int { func TestHashH1(t *testing.T) { expected := "2acc468c3926b0bdb2767e99ff26e084de9ced8dbc7d5fbf418027b667862fab" h := hashH1([]byte{0x41, 0x6c, 0x69, 0x63, 0x65, 0x01}) - if hex.EncodeToString(h.Bytes()) != expected { - t.Errorf("got %v, expected %v", hex.EncodeToString(h.Bytes()), expected) + if hex.EncodeToString(h.Bytes(orderNat)) != expected { + t.Errorf("got %v, expected %v", h.Bytes(orderNat), expected) } } @@ -37,8 +37,8 @@ func TestHashH2(t *testing.T) { t.Fatal(err) } h := hashH2(z) - if hex.EncodeToString(h.Bytes()) != expected { - t.Errorf("got %v, expected %v", hex.EncodeToString(h.Bytes()), expected) + if hex.EncodeToString(h.Bytes(orderNat)) != expected { + t.Errorf("got %v, expected %v", h.Bytes(orderNat), expected) } } @@ -91,11 +91,19 @@ func TestSignASN1(t *testing.T) { // SM9 Appendix A func TestSignSM9Sample(t *testing.T) { expectedH := bigFromHex("823c4b21e4bd2dfe1ed92c606653e996668563152fc33f55d7bfbb9bd9705adb") + expectedHNat, err := bigmod.NewNat().SetBytes(expectedH.Bytes(), orderNat) + if err != nil { + t.Fatal(err) + } expectedS := "0473bf96923ce58b6ad0e13e9643a406d8eb98417c50ef1b29cef9adb48b6d598c856712f1c2e0968ab7769f42a99586aed139d5b8b3e15891827cc2aced9baa05" hash := []byte("Chinese IBS standard") hid := byte(0x01) uid := []byte("Alice") r := bigFromHex("033c8616b06704813203dfd00965022ed15975c662337aed648835dc4b1cbe") + rNat, err := bigmod.NewNat().SetBytes(r.Bytes(), orderNat) + if err != nil { + t.Fatal(err) + } masterKey := new(SignMasterPrivateKey) masterKey.D = bigFromHex("0130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4") @@ -118,17 +126,13 @@ func TestSignSM9Sample(t *testing.T) { buffer = append(buffer, w.Marshal()...) h := hashH2(buffer) - if h.Cmp(expectedH) != 0 { + if h.Equal(expectedHNat) == 0 { t.Fatal("not same h") } - l := new(big.Int).Sub(r, h) + rNat.Sub(h, orderNat) - if l.Sign() < 0 { - l.Add(l, bn256.Order) - } - - s, err := new(bn256.G1).ScalarMult(userKey.PrivateKey, bn256.NormalizeScalar(l.Bytes())) + s, err := new(bn256.G1).ScalarMult(userKey.PrivateKey, rNat.Bytes(orderNat)) if err != nil { t.Fatal(err) }