sm9: use bigmod instead of math/big, part 2

This commit is contained in:
Sun Yimin 2022-11-25 17:45:11 +08:00 committed by GitHub
parent a592631459
commit c477816aa7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 107 additions and 49 deletions

View File

@ -123,11 +123,11 @@ func (x *Nat) Set(y *Nat) *Nat {
return x
}
// setBig assigns x = n, optionally resizing n to the appropriate size.
// SetBig assigns x = n, optionally resizing n to the appropriate size.
//
// The announced length of x is set based on the actual bit size of the input,
// ignoring leading zeroes.
func (x *Nat) setBig(n *big.Int) *Nat {
func (x *Nat) SetBig(n *big.Int) *Nat {
requiredLimbs := (n.BitLen() + _W - 1) / _W
x.reset(requiredLimbs)
@ -386,7 +386,7 @@ func minusInverseModW(x uint) uint {
// The Int must be odd. The number of significant bits must be leakable.
func NewModulusFromBig(n *big.Int) *Modulus {
m := &Modulus{}
m.nat = NewNat().setBig(n)
m.nat = NewNat().SetBig(n)
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
m.m0inv = minusInverseModW(m.nat.limbs[0])
m.rr = rr(m)
@ -426,13 +426,20 @@ func (m *Modulus) Nat() *Nat {
//
// This assumes that x is already reduced mod m, and that y < 2^_W.
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
d := NewNat().resetFor(m)
return x.shiftInNat(y, m.nat)
}
// shiftIn calculates x = x << _W + y mod m.
//
// This assumes that x is already reduced mod m, and that y < 2^_W.
func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
d := NewNat().reset(len(m.limbs))
// Eliminate bounds checks in the loop.
size := len(m.nat.limbs)
size := len(m.limbs)
xLimbs := x.limbs[:size]
dLimbs := d.limbs[:size]
mLimbs := m.nat.limbs[:size]
mLimbs := m.limbs[:size]
// Each iteration of this loop computes x = 2x + b mod m, where b is a bit
// from y. Effectively, it left-shifts x and adds y one bit at a time,
@ -469,7 +476,16 @@ func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
//
// The output will be resized to the size of m and overwritten.
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
out.resetFor(m)
return out.ModNat(x, m.nat)
}
// Mod calculates out = x mod m.
//
// This works regardless how large the value of x is.
//
// The output will be resized to the size of m and overwritten.
func (out *Nat) ModNat(x *Nat, m *Nat) *Nat {
out.reset(len(m.limbs))
// Working our way from the most significant to the least significant limb,
// we can insert each limb at the least significant position, shifting all
// previous limbs left by _W. This way each limb will get shifted by the
@ -478,7 +494,7 @@ func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
i := len(x.limbs) - 1
// For the first N - 1 limbs we can skip the actual shifting and position
// them at the shifted position, which starts at min(N - 2, i).
start := len(m.nat.limbs) - 2
start := len(m.limbs) - 2
if i < start {
start = i
}
@ -488,7 +504,7 @@ func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
}
// We shift in the remaining limbs, reducing modulo m each time.
for i >= 0 {
out.shiftIn(x.limbs[i], m)
out.shiftInNat(x.limbs[i], m)
i--
}
return out

View File

@ -244,6 +244,52 @@ func TestMod(t *testing.T) {
}
}
func TestModNat(t *testing.T) {
order, _ := new(big.Int).SetString("b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25", 16)
orderNat := NewModulusFromBig(order)
oneNat, err := NewNat().SetBytes(big.NewInt(1).Bytes(), orderNat)
if err != nil {
t.Fatal(err)
}
orderMinus1, _ := new(big.Int).SetString("b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf24", 16)
hashValue1, _ := new(big.Int).SetString("1000000000000000a640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25", 16)
hashValue2, _ := new(big.Int).SetString("1000000000000000b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf23", 16)
examples := []struct {
in *big.Int
expected *big.Int
}{
{
big.NewInt(1),
big.NewInt(2),
},
{
orderMinus1,
big.NewInt(1),
},
{
order,
big.NewInt(2),
},
{
hashValue1,
new(big.Int).Add(new(big.Int).Mod(hashValue1, orderMinus1), big.NewInt(1)),
},
{
hashValue2,
new(big.Int).Add(new(big.Int).Mod(hashValue2, orderMinus1), big.NewInt(1)),
},
}
for i, tt := range examples {
kNat := NewNat().SetBig(tt.in)
kNat = NewNat().ModNat(kNat, NewNat().SetBig(orderMinus1))
kNat.Add(oneNat, orderNat)
out := new(big.Int).SetBytes(kNat.Bytes(orderNat))
if out.Cmp(tt.expected) != 0 {
t.Errorf("%d: got %x, expected %x", i, out, tt.expected)
}
}
}
func TestModSub(t *testing.T) {
m := modulusFromBytes([]byte{13})
x := &Nat{[]uint{6}}
@ -293,7 +339,7 @@ func natBytes(n *Nat) []byte {
func natFromBytes(b []byte) *Nat {
bb := new(big.Int).SetBytes(b)
return NewNat().setBig(bb)
return NewNat().SetBig(bb)
}
func modulusFromBytes(b []byte) *Modulus {

View File

@ -24,7 +24,12 @@ import (
var orderNat = bigmod.NewModulusFromBig(bn256.Order)
var orderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes()
var bigOne = big.NewInt(1)
var orderMinus1 = new(big.Int).Sub(bn256.Order, bigOne)
var bigOneNat *bigmod.Nat
var orderMinus1 = bigmod.NewNat().SetBig(new(big.Int).Sub(bn256.Order, bigOne))
func init() {
bigOneNat, _ = bigmod.NewNat().SetBytes(bigOne.Bytes(), orderNat)
}
type hashMode byte
@ -46,7 +51,7 @@ const (
)
//hash implements H1(Z,n) or H2(Z,n) in sm9 algorithm.
func hash(z []byte, h hashMode) *big.Int {
func hash(z []byte, h hashMode) *bigmod.Nat {
md := sm3.New()
var ha [64]byte
var countBytes [4]byte
@ -61,18 +66,18 @@ func hash(z []byte, h hashMode) *big.Int {
ct++
md.Reset()
}
//TODO: how to rewrite this part with nat?
k := new(big.Int).SetBytes(ha[:40])
k.Mod(k, orderMinus1)
k.Add(k, bigOne)
return k
kNat := bigmod.NewNat().SetBig(k)
kNat = bigmod.NewNat().ModNat(kNat, orderMinus1)
kNat.Add(bigOneNat, orderNat)
return kNat
}
func hashH1(z []byte) *big.Int {
func hashH1(z []byte) *bigmod.Nat {
return hash(z, H1)
}
func hashH2(z []byte) *big.Int {
func hashH2(z []byte) *bigmod.Nat {
return hash(z, H2)
}
@ -131,14 +136,11 @@ func Sign(rand io.Reader, priv *SignPrivateKey, hash []byte) (h *big.Int, s *bn2
buffer = append(buffer, hash...)
buffer = append(buffer, w.Marshal()...)
h = hashH2(buffer)
hNat, err = bigmod.NewNat().SetBytes(h.Bytes(), orderNat)
if err != nil {
return
}
hNat = hashH2(buffer)
r.Sub(hNat, orderNat)
if r.IsZero() == 0 {
h = new(big.Int).SetBytes(hNat.Bytes(orderNat))
s, err = new(bn256.G1).ScalarMult(priv.PrivateKey, r.Bytes(orderNat))
break
}
@ -203,7 +205,7 @@ func Verify(pub *SignMasterPublicKey, uid []byte, hid byte, hash []byte, h *big.
buffer = append(buffer, w.Marshal()...)
h2 := hashH2(buffer)
return h.Cmp(h2) == 0
return h2.Equal(hNat) == 1
}
// VerifyASN1 verifies the ASN.1 encoded signature of type SM9Signature, sig, of hash using the

View File

@ -116,12 +116,7 @@ func (master *SignMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*Sign
id = append(id, uid...)
id = append(id, hid)
t1 := hashH1(id)
t1Nat, err := bigmod.NewNat().SetBytes(t1.Bytes(), orderNat)
if err != nil {
return nil, err
}
t1Nat := hashH1(id)
d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), orderNat)
if err != nil {
@ -180,7 +175,7 @@ func (pub *SignMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *bn2
buffer = append(buffer, uid...)
buffer = append(buffer, hid)
h1 := hashH1(buffer)
p, err := new(bn256.G2).ScalarBaseMult(bn256.NormalizeScalar(h1.Bytes()))
p, err := new(bn256.G2).ScalarBaseMult(h1.Bytes(orderNat))
if err != nil {
panic(err)
}
@ -376,12 +371,7 @@ func (master *EncryptMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*E
id = append(id, uid...)
id = append(id, hid)
t1 := hashH1(id)
t1Nat, err := bigmod.NewNat().SetBytes(t1.Bytes(), orderNat)
if err != nil {
return nil, err
}
t1Nat := hashH1(id)
d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), orderNat)
if err != nil {
@ -476,7 +466,7 @@ func (pub *EncryptMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *
buffer = append(buffer, uid...)
buffer = append(buffer, hid)
h1 := hashH1(buffer)
p, err := new(bn256.G1).ScalarBaseMult(bn256.NormalizeScalar(h1.Bytes()))
p, err := new(bn256.G1).ScalarBaseMult(h1.Bytes(orderNat))
if err != nil {
panic(err)
}

View File

@ -24,8 +24,8 @@ func bigFromHex(s string) *big.Int {
func TestHashH1(t *testing.T) {
expected := "2acc468c3926b0bdb2767e99ff26e084de9ced8dbc7d5fbf418027b667862fab"
h := hashH1([]byte{0x41, 0x6c, 0x69, 0x63, 0x65, 0x01})
if hex.EncodeToString(h.Bytes()) != expected {
t.Errorf("got %v, expected %v", hex.EncodeToString(h.Bytes()), expected)
if hex.EncodeToString(h.Bytes(orderNat)) != expected {
t.Errorf("got %v, expected %v", h.Bytes(orderNat), expected)
}
}
@ -37,8 +37,8 @@ func TestHashH2(t *testing.T) {
t.Fatal(err)
}
h := hashH2(z)
if hex.EncodeToString(h.Bytes()) != expected {
t.Errorf("got %v, expected %v", hex.EncodeToString(h.Bytes()), expected)
if hex.EncodeToString(h.Bytes(orderNat)) != expected {
t.Errorf("got %v, expected %v", h.Bytes(orderNat), expected)
}
}
@ -91,11 +91,19 @@ func TestSignASN1(t *testing.T) {
// SM9 Appendix A
func TestSignSM9Sample(t *testing.T) {
expectedH := bigFromHex("823c4b21e4bd2dfe1ed92c606653e996668563152fc33f55d7bfbb9bd9705adb")
expectedHNat, err := bigmod.NewNat().SetBytes(expectedH.Bytes(), orderNat)
if err != nil {
t.Fatal(err)
}
expectedS := "0473bf96923ce58b6ad0e13e9643a406d8eb98417c50ef1b29cef9adb48b6d598c856712f1c2e0968ab7769f42a99586aed139d5b8b3e15891827cc2aced9baa05"
hash := []byte("Chinese IBS standard")
hid := byte(0x01)
uid := []byte("Alice")
r := bigFromHex("033c8616b06704813203dfd00965022ed15975c662337aed648835dc4b1cbe")
rNat, err := bigmod.NewNat().SetBytes(r.Bytes(), orderNat)
if err != nil {
t.Fatal(err)
}
masterKey := new(SignMasterPrivateKey)
masterKey.D = bigFromHex("0130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4")
@ -118,17 +126,13 @@ func TestSignSM9Sample(t *testing.T) {
buffer = append(buffer, w.Marshal()...)
h := hashH2(buffer)
if h.Cmp(expectedH) != 0 {
if h.Equal(expectedHNat) == 0 {
t.Fatal("not same h")
}
l := new(big.Int).Sub(r, h)
rNat.Sub(h, orderNat)
if l.Sign() < 0 {
l.Add(l, bn256.Order)
}
s, err := new(bn256.G1).ScalarMult(userKey.PrivateKey, bn256.NormalizeScalar(l.Bytes()))
s, err := new(bn256.G1).ScalarMult(userKey.PrivateKey, rNat.Bytes(orderNat))
if err != nil {
t.Fatal(err)
}