sm9: use bigmod instead of math/big, part 2

This commit is contained in:
Sun Yimin 2022-11-25 17:45:11 +08:00 committed by GitHub
parent a592631459
commit c477816aa7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 107 additions and 49 deletions

View File

@ -123,11 +123,11 @@ func (x *Nat) Set(y *Nat) *Nat {
return x return x
} }
// setBig assigns x = n, optionally resizing n to the appropriate size. // SetBig assigns x = n, optionally resizing n to the appropriate size.
// //
// The announced length of x is set based on the actual bit size of the input, // The announced length of x is set based on the actual bit size of the input,
// ignoring leading zeroes. // ignoring leading zeroes.
func (x *Nat) setBig(n *big.Int) *Nat { func (x *Nat) SetBig(n *big.Int) *Nat {
requiredLimbs := (n.BitLen() + _W - 1) / _W requiredLimbs := (n.BitLen() + _W - 1) / _W
x.reset(requiredLimbs) x.reset(requiredLimbs)
@ -386,7 +386,7 @@ func minusInverseModW(x uint) uint {
// The Int must be odd. The number of significant bits must be leakable. // The Int must be odd. The number of significant bits must be leakable.
func NewModulusFromBig(n *big.Int) *Modulus { func NewModulusFromBig(n *big.Int) *Modulus {
m := &Modulus{} m := &Modulus{}
m.nat = NewNat().setBig(n) m.nat = NewNat().SetBig(n)
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1]) m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
m.m0inv = minusInverseModW(m.nat.limbs[0]) m.m0inv = minusInverseModW(m.nat.limbs[0])
m.rr = rr(m) m.rr = rr(m)
@ -426,13 +426,20 @@ func (m *Modulus) Nat() *Nat {
// //
// This assumes that x is already reduced mod m, and that y < 2^_W. // This assumes that x is already reduced mod m, and that y < 2^_W.
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat { func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
d := NewNat().resetFor(m) return x.shiftInNat(y, m.nat)
}
// shiftIn calculates x = x << _W + y mod m.
//
// This assumes that x is already reduced mod m, and that y < 2^_W.
func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
d := NewNat().reset(len(m.limbs))
// Eliminate bounds checks in the loop. // Eliminate bounds checks in the loop.
size := len(m.nat.limbs) size := len(m.limbs)
xLimbs := x.limbs[:size] xLimbs := x.limbs[:size]
dLimbs := d.limbs[:size] dLimbs := d.limbs[:size]
mLimbs := m.nat.limbs[:size] mLimbs := m.limbs[:size]
// Each iteration of this loop computes x = 2x + b mod m, where b is a bit // Each iteration of this loop computes x = 2x + b mod m, where b is a bit
// from y. Effectively, it left-shifts x and adds y one bit at a time, // from y. Effectively, it left-shifts x and adds y one bit at a time,
@ -469,7 +476,16 @@ func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
// //
// The output will be resized to the size of m and overwritten. // The output will be resized to the size of m and overwritten.
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat { func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
out.resetFor(m) return out.ModNat(x, m.nat)
}
// Mod calculates out = x mod m.
//
// This works regardless how large the value of x is.
//
// The output will be resized to the size of m and overwritten.
func (out *Nat) ModNat(x *Nat, m *Nat) *Nat {
out.reset(len(m.limbs))
// Working our way from the most significant to the least significant limb, // Working our way from the most significant to the least significant limb,
// we can insert each limb at the least significant position, shifting all // we can insert each limb at the least significant position, shifting all
// previous limbs left by _W. This way each limb will get shifted by the // previous limbs left by _W. This way each limb will get shifted by the
@ -478,7 +494,7 @@ func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
i := len(x.limbs) - 1 i := len(x.limbs) - 1
// For the first N - 1 limbs we can skip the actual shifting and position // For the first N - 1 limbs we can skip the actual shifting and position
// them at the shifted position, which starts at min(N - 2, i). // them at the shifted position, which starts at min(N - 2, i).
start := len(m.nat.limbs) - 2 start := len(m.limbs) - 2
if i < start { if i < start {
start = i start = i
} }
@ -488,7 +504,7 @@ func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
} }
// We shift in the remaining limbs, reducing modulo m each time. // We shift in the remaining limbs, reducing modulo m each time.
for i >= 0 { for i >= 0 {
out.shiftIn(x.limbs[i], m) out.shiftInNat(x.limbs[i], m)
i-- i--
} }
return out return out

View File

@ -244,6 +244,52 @@ func TestMod(t *testing.T) {
} }
} }
func TestModNat(t *testing.T) {
order, _ := new(big.Int).SetString("b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25", 16)
orderNat := NewModulusFromBig(order)
oneNat, err := NewNat().SetBytes(big.NewInt(1).Bytes(), orderNat)
if err != nil {
t.Fatal(err)
}
orderMinus1, _ := new(big.Int).SetString("b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf24", 16)
hashValue1, _ := new(big.Int).SetString("1000000000000000a640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25", 16)
hashValue2, _ := new(big.Int).SetString("1000000000000000b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf23", 16)
examples := []struct {
in *big.Int
expected *big.Int
}{
{
big.NewInt(1),
big.NewInt(2),
},
{
orderMinus1,
big.NewInt(1),
},
{
order,
big.NewInt(2),
},
{
hashValue1,
new(big.Int).Add(new(big.Int).Mod(hashValue1, orderMinus1), big.NewInt(1)),
},
{
hashValue2,
new(big.Int).Add(new(big.Int).Mod(hashValue2, orderMinus1), big.NewInt(1)),
},
}
for i, tt := range examples {
kNat := NewNat().SetBig(tt.in)
kNat = NewNat().ModNat(kNat, NewNat().SetBig(orderMinus1))
kNat.Add(oneNat, orderNat)
out := new(big.Int).SetBytes(kNat.Bytes(orderNat))
if out.Cmp(tt.expected) != 0 {
t.Errorf("%d: got %x, expected %x", i, out, tt.expected)
}
}
}
func TestModSub(t *testing.T) { func TestModSub(t *testing.T) {
m := modulusFromBytes([]byte{13}) m := modulusFromBytes([]byte{13})
x := &Nat{[]uint{6}} x := &Nat{[]uint{6}}
@ -293,7 +339,7 @@ func natBytes(n *Nat) []byte {
func natFromBytes(b []byte) *Nat { func natFromBytes(b []byte) *Nat {
bb := new(big.Int).SetBytes(b) bb := new(big.Int).SetBytes(b)
return NewNat().setBig(bb) return NewNat().SetBig(bb)
} }
func modulusFromBytes(b []byte) *Modulus { func modulusFromBytes(b []byte) *Modulus {

View File

@ -24,7 +24,12 @@ import (
var orderNat = bigmod.NewModulusFromBig(bn256.Order) var orderNat = bigmod.NewModulusFromBig(bn256.Order)
var orderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes() var orderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes()
var bigOne = big.NewInt(1) var bigOne = big.NewInt(1)
var orderMinus1 = new(big.Int).Sub(bn256.Order, bigOne) var bigOneNat *bigmod.Nat
var orderMinus1 = bigmod.NewNat().SetBig(new(big.Int).Sub(bn256.Order, bigOne))
func init() {
bigOneNat, _ = bigmod.NewNat().SetBytes(bigOne.Bytes(), orderNat)
}
type hashMode byte type hashMode byte
@ -46,7 +51,7 @@ const (
) )
//hash implements H1(Z,n) or H2(Z,n) in sm9 algorithm. //hash implements H1(Z,n) or H2(Z,n) in sm9 algorithm.
func hash(z []byte, h hashMode) *big.Int { func hash(z []byte, h hashMode) *bigmod.Nat {
md := sm3.New() md := sm3.New()
var ha [64]byte var ha [64]byte
var countBytes [4]byte var countBytes [4]byte
@ -61,18 +66,18 @@ func hash(z []byte, h hashMode) *big.Int {
ct++ ct++
md.Reset() md.Reset()
} }
//TODO: how to rewrite this part with nat?
k := new(big.Int).SetBytes(ha[:40]) k := new(big.Int).SetBytes(ha[:40])
k.Mod(k, orderMinus1) kNat := bigmod.NewNat().SetBig(k)
k.Add(k, bigOne) kNat = bigmod.NewNat().ModNat(kNat, orderMinus1)
return k kNat.Add(bigOneNat, orderNat)
return kNat
} }
func hashH1(z []byte) *big.Int { func hashH1(z []byte) *bigmod.Nat {
return hash(z, H1) return hash(z, H1)
} }
func hashH2(z []byte) *big.Int { func hashH2(z []byte) *bigmod.Nat {
return hash(z, H2) return hash(z, H2)
} }
@ -131,14 +136,11 @@ func Sign(rand io.Reader, priv *SignPrivateKey, hash []byte) (h *big.Int, s *bn2
buffer = append(buffer, hash...) buffer = append(buffer, hash...)
buffer = append(buffer, w.Marshal()...) buffer = append(buffer, w.Marshal()...)
h = hashH2(buffer) hNat = hashH2(buffer)
hNat, err = bigmod.NewNat().SetBytes(h.Bytes(), orderNat)
if err != nil {
return
}
r.Sub(hNat, orderNat) r.Sub(hNat, orderNat)
if r.IsZero() == 0 { if r.IsZero() == 0 {
h = new(big.Int).SetBytes(hNat.Bytes(orderNat))
s, err = new(bn256.G1).ScalarMult(priv.PrivateKey, r.Bytes(orderNat)) s, err = new(bn256.G1).ScalarMult(priv.PrivateKey, r.Bytes(orderNat))
break break
} }
@ -203,7 +205,7 @@ func Verify(pub *SignMasterPublicKey, uid []byte, hid byte, hash []byte, h *big.
buffer = append(buffer, w.Marshal()...) buffer = append(buffer, w.Marshal()...)
h2 := hashH2(buffer) h2 := hashH2(buffer)
return h.Cmp(h2) == 0 return h2.Equal(hNat) == 1
} }
// VerifyASN1 verifies the ASN.1 encoded signature of type SM9Signature, sig, of hash using the // VerifyASN1 verifies the ASN.1 encoded signature of type SM9Signature, sig, of hash using the

View File

@ -116,12 +116,7 @@ func (master *SignMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*Sign
id = append(id, uid...) id = append(id, uid...)
id = append(id, hid) id = append(id, hid)
t1 := hashH1(id) t1Nat := hashH1(id)
t1Nat, err := bigmod.NewNat().SetBytes(t1.Bytes(), orderNat)
if err != nil {
return nil, err
}
d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), orderNat) d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), orderNat)
if err != nil { if err != nil {
@ -180,7 +175,7 @@ func (pub *SignMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *bn2
buffer = append(buffer, uid...) buffer = append(buffer, uid...)
buffer = append(buffer, hid) buffer = append(buffer, hid)
h1 := hashH1(buffer) h1 := hashH1(buffer)
p, err := new(bn256.G2).ScalarBaseMult(bn256.NormalizeScalar(h1.Bytes())) p, err := new(bn256.G2).ScalarBaseMult(h1.Bytes(orderNat))
if err != nil { if err != nil {
panic(err) panic(err)
} }
@ -376,12 +371,7 @@ func (master *EncryptMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*E
id = append(id, uid...) id = append(id, uid...)
id = append(id, hid) id = append(id, hid)
t1 := hashH1(id) t1Nat := hashH1(id)
t1Nat, err := bigmod.NewNat().SetBytes(t1.Bytes(), orderNat)
if err != nil {
return nil, err
}
d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), orderNat) d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), orderNat)
if err != nil { if err != nil {
@ -476,7 +466,7 @@ func (pub *EncryptMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *
buffer = append(buffer, uid...) buffer = append(buffer, uid...)
buffer = append(buffer, hid) buffer = append(buffer, hid)
h1 := hashH1(buffer) h1 := hashH1(buffer)
p, err := new(bn256.G1).ScalarBaseMult(bn256.NormalizeScalar(h1.Bytes())) p, err := new(bn256.G1).ScalarBaseMult(h1.Bytes(orderNat))
if err != nil { if err != nil {
panic(err) panic(err)
} }

View File

@ -24,8 +24,8 @@ func bigFromHex(s string) *big.Int {
func TestHashH1(t *testing.T) { func TestHashH1(t *testing.T) {
expected := "2acc468c3926b0bdb2767e99ff26e084de9ced8dbc7d5fbf418027b667862fab" expected := "2acc468c3926b0bdb2767e99ff26e084de9ced8dbc7d5fbf418027b667862fab"
h := hashH1([]byte{0x41, 0x6c, 0x69, 0x63, 0x65, 0x01}) h := hashH1([]byte{0x41, 0x6c, 0x69, 0x63, 0x65, 0x01})
if hex.EncodeToString(h.Bytes()) != expected { if hex.EncodeToString(h.Bytes(orderNat)) != expected {
t.Errorf("got %v, expected %v", hex.EncodeToString(h.Bytes()), expected) t.Errorf("got %v, expected %v", h.Bytes(orderNat), expected)
} }
} }
@ -37,8 +37,8 @@ func TestHashH2(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
h := hashH2(z) h := hashH2(z)
if hex.EncodeToString(h.Bytes()) != expected { if hex.EncodeToString(h.Bytes(orderNat)) != expected {
t.Errorf("got %v, expected %v", hex.EncodeToString(h.Bytes()), expected) t.Errorf("got %v, expected %v", h.Bytes(orderNat), expected)
} }
} }
@ -91,11 +91,19 @@ func TestSignASN1(t *testing.T) {
// SM9 Appendix A // SM9 Appendix A
func TestSignSM9Sample(t *testing.T) { func TestSignSM9Sample(t *testing.T) {
expectedH := bigFromHex("823c4b21e4bd2dfe1ed92c606653e996668563152fc33f55d7bfbb9bd9705adb") expectedH := bigFromHex("823c4b21e4bd2dfe1ed92c606653e996668563152fc33f55d7bfbb9bd9705adb")
expectedHNat, err := bigmod.NewNat().SetBytes(expectedH.Bytes(), orderNat)
if err != nil {
t.Fatal(err)
}
expectedS := "0473bf96923ce58b6ad0e13e9643a406d8eb98417c50ef1b29cef9adb48b6d598c856712f1c2e0968ab7769f42a99586aed139d5b8b3e15891827cc2aced9baa05" expectedS := "0473bf96923ce58b6ad0e13e9643a406d8eb98417c50ef1b29cef9adb48b6d598c856712f1c2e0968ab7769f42a99586aed139d5b8b3e15891827cc2aced9baa05"
hash := []byte("Chinese IBS standard") hash := []byte("Chinese IBS standard")
hid := byte(0x01) hid := byte(0x01)
uid := []byte("Alice") uid := []byte("Alice")
r := bigFromHex("033c8616b06704813203dfd00965022ed15975c662337aed648835dc4b1cbe") r := bigFromHex("033c8616b06704813203dfd00965022ed15975c662337aed648835dc4b1cbe")
rNat, err := bigmod.NewNat().SetBytes(r.Bytes(), orderNat)
if err != nil {
t.Fatal(err)
}
masterKey := new(SignMasterPrivateKey) masterKey := new(SignMasterPrivateKey)
masterKey.D = bigFromHex("0130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4") masterKey.D = bigFromHex("0130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4")
@ -118,17 +126,13 @@ func TestSignSM9Sample(t *testing.T) {
buffer = append(buffer, w.Marshal()...) buffer = append(buffer, w.Marshal()...)
h := hashH2(buffer) h := hashH2(buffer)
if h.Cmp(expectedH) != 0 { if h.Equal(expectedHNat) == 0 {
t.Fatal("not same h") t.Fatal("not same h")
} }
l := new(big.Int).Sub(r, h) rNat.Sub(h, orderNat)
if l.Sign() < 0 { s, err := new(bn256.G1).ScalarMult(userKey.PrivateKey, rNat.Bytes(orderNat))
l.Add(l, bn256.Order)
}
s, err := new(bn256.G1).ScalarMult(userKey.PrivateKey, bn256.NormalizeScalar(l.Bytes()))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }