mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-26 20:26:19 +08:00
sm9: use bigmod instead of math/big, part 2
This commit is contained in:
parent
a592631459
commit
c477816aa7
@ -123,11 +123,11 @@ func (x *Nat) Set(y *Nat) *Nat {
|
||||
return x
|
||||
}
|
||||
|
||||
// setBig assigns x = n, optionally resizing n to the appropriate size.
|
||||
// SetBig assigns x = n, optionally resizing n to the appropriate size.
|
||||
//
|
||||
// The announced length of x is set based on the actual bit size of the input,
|
||||
// ignoring leading zeroes.
|
||||
func (x *Nat) setBig(n *big.Int) *Nat {
|
||||
func (x *Nat) SetBig(n *big.Int) *Nat {
|
||||
requiredLimbs := (n.BitLen() + _W - 1) / _W
|
||||
x.reset(requiredLimbs)
|
||||
|
||||
@ -386,7 +386,7 @@ func minusInverseModW(x uint) uint {
|
||||
// The Int must be odd. The number of significant bits must be leakable.
|
||||
func NewModulusFromBig(n *big.Int) *Modulus {
|
||||
m := &Modulus{}
|
||||
m.nat = NewNat().setBig(n)
|
||||
m.nat = NewNat().SetBig(n)
|
||||
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
|
||||
m.m0inv = minusInverseModW(m.nat.limbs[0])
|
||||
m.rr = rr(m)
|
||||
@ -426,13 +426,20 @@ func (m *Modulus) Nat() *Nat {
|
||||
//
|
||||
// This assumes that x is already reduced mod m, and that y < 2^_W.
|
||||
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
|
||||
d := NewNat().resetFor(m)
|
||||
return x.shiftInNat(y, m.nat)
|
||||
}
|
||||
|
||||
// shiftIn calculates x = x << _W + y mod m.
|
||||
//
|
||||
// This assumes that x is already reduced mod m, and that y < 2^_W.
|
||||
func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
|
||||
d := NewNat().reset(len(m.limbs))
|
||||
|
||||
// Eliminate bounds checks in the loop.
|
||||
size := len(m.nat.limbs)
|
||||
size := len(m.limbs)
|
||||
xLimbs := x.limbs[:size]
|
||||
dLimbs := d.limbs[:size]
|
||||
mLimbs := m.nat.limbs[:size]
|
||||
mLimbs := m.limbs[:size]
|
||||
|
||||
// Each iteration of this loop computes x = 2x + b mod m, where b is a bit
|
||||
// from y. Effectively, it left-shifts x and adds y one bit at a time,
|
||||
@ -469,7 +476,16 @@ func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
|
||||
//
|
||||
// The output will be resized to the size of m and overwritten.
|
||||
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
|
||||
out.resetFor(m)
|
||||
return out.ModNat(x, m.nat)
|
||||
}
|
||||
|
||||
// Mod calculates out = x mod m.
|
||||
//
|
||||
// This works regardless how large the value of x is.
|
||||
//
|
||||
// The output will be resized to the size of m and overwritten.
|
||||
func (out *Nat) ModNat(x *Nat, m *Nat) *Nat {
|
||||
out.reset(len(m.limbs))
|
||||
// Working our way from the most significant to the least significant limb,
|
||||
// we can insert each limb at the least significant position, shifting all
|
||||
// previous limbs left by _W. This way each limb will get shifted by the
|
||||
@ -478,7 +494,7 @@ func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
|
||||
i := len(x.limbs) - 1
|
||||
// For the first N - 1 limbs we can skip the actual shifting and position
|
||||
// them at the shifted position, which starts at min(N - 2, i).
|
||||
start := len(m.nat.limbs) - 2
|
||||
start := len(m.limbs) - 2
|
||||
if i < start {
|
||||
start = i
|
||||
}
|
||||
@ -488,7 +504,7 @@ func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
|
||||
}
|
||||
// We shift in the remaining limbs, reducing modulo m each time.
|
||||
for i >= 0 {
|
||||
out.shiftIn(x.limbs[i], m)
|
||||
out.shiftInNat(x.limbs[i], m)
|
||||
i--
|
||||
}
|
||||
return out
|
||||
|
@ -244,6 +244,52 @@ func TestMod(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModNat(t *testing.T) {
|
||||
order, _ := new(big.Int).SetString("b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25", 16)
|
||||
orderNat := NewModulusFromBig(order)
|
||||
oneNat, err := NewNat().SetBytes(big.NewInt(1).Bytes(), orderNat)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
orderMinus1, _ := new(big.Int).SetString("b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf24", 16)
|
||||
hashValue1, _ := new(big.Int).SetString("1000000000000000a640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25", 16)
|
||||
hashValue2, _ := new(big.Int).SetString("1000000000000000b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf23", 16)
|
||||
examples := []struct {
|
||||
in *big.Int
|
||||
expected *big.Int
|
||||
}{
|
||||
{
|
||||
big.NewInt(1),
|
||||
big.NewInt(2),
|
||||
},
|
||||
{
|
||||
orderMinus1,
|
||||
big.NewInt(1),
|
||||
},
|
||||
{
|
||||
order,
|
||||
big.NewInt(2),
|
||||
},
|
||||
{
|
||||
hashValue1,
|
||||
new(big.Int).Add(new(big.Int).Mod(hashValue1, orderMinus1), big.NewInt(1)),
|
||||
},
|
||||
{
|
||||
hashValue2,
|
||||
new(big.Int).Add(new(big.Int).Mod(hashValue2, orderMinus1), big.NewInt(1)),
|
||||
},
|
||||
}
|
||||
for i, tt := range examples {
|
||||
kNat := NewNat().SetBig(tt.in)
|
||||
kNat = NewNat().ModNat(kNat, NewNat().SetBig(orderMinus1))
|
||||
kNat.Add(oneNat, orderNat)
|
||||
out := new(big.Int).SetBytes(kNat.Bytes(orderNat))
|
||||
if out.Cmp(tt.expected) != 0 {
|
||||
t.Errorf("%d: got %x, expected %x", i, out, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestModSub(t *testing.T) {
|
||||
m := modulusFromBytes([]byte{13})
|
||||
x := &Nat{[]uint{6}}
|
||||
@ -293,7 +339,7 @@ func natBytes(n *Nat) []byte {
|
||||
|
||||
func natFromBytes(b []byte) *Nat {
|
||||
bb := new(big.Int).SetBytes(b)
|
||||
return NewNat().setBig(bb)
|
||||
return NewNat().SetBig(bb)
|
||||
}
|
||||
|
||||
func modulusFromBytes(b []byte) *Modulus {
|
||||
|
30
sm9/sm9.go
30
sm9/sm9.go
@ -24,7 +24,12 @@ import (
|
||||
var orderNat = bigmod.NewModulusFromBig(bn256.Order)
|
||||
var orderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes()
|
||||
var bigOne = big.NewInt(1)
|
||||
var orderMinus1 = new(big.Int).Sub(bn256.Order, bigOne)
|
||||
var bigOneNat *bigmod.Nat
|
||||
var orderMinus1 = bigmod.NewNat().SetBig(new(big.Int).Sub(bn256.Order, bigOne))
|
||||
|
||||
func init() {
|
||||
bigOneNat, _ = bigmod.NewNat().SetBytes(bigOne.Bytes(), orderNat)
|
||||
}
|
||||
|
||||
type hashMode byte
|
||||
|
||||
@ -46,7 +51,7 @@ const (
|
||||
)
|
||||
|
||||
//hash implements H1(Z,n) or H2(Z,n) in sm9 algorithm.
|
||||
func hash(z []byte, h hashMode) *big.Int {
|
||||
func hash(z []byte, h hashMode) *bigmod.Nat {
|
||||
md := sm3.New()
|
||||
var ha [64]byte
|
||||
var countBytes [4]byte
|
||||
@ -61,18 +66,18 @@ func hash(z []byte, h hashMode) *big.Int {
|
||||
ct++
|
||||
md.Reset()
|
||||
}
|
||||
//TODO: how to rewrite this part with nat?
|
||||
k := new(big.Int).SetBytes(ha[:40])
|
||||
k.Mod(k, orderMinus1)
|
||||
k.Add(k, bigOne)
|
||||
return k
|
||||
kNat := bigmod.NewNat().SetBig(k)
|
||||
kNat = bigmod.NewNat().ModNat(kNat, orderMinus1)
|
||||
kNat.Add(bigOneNat, orderNat)
|
||||
return kNat
|
||||
}
|
||||
|
||||
func hashH1(z []byte) *big.Int {
|
||||
func hashH1(z []byte) *bigmod.Nat {
|
||||
return hash(z, H1)
|
||||
}
|
||||
|
||||
func hashH2(z []byte) *big.Int {
|
||||
func hashH2(z []byte) *bigmod.Nat {
|
||||
return hash(z, H2)
|
||||
}
|
||||
|
||||
@ -131,14 +136,11 @@ func Sign(rand io.Reader, priv *SignPrivateKey, hash []byte) (h *big.Int, s *bn2
|
||||
buffer = append(buffer, hash...)
|
||||
buffer = append(buffer, w.Marshal()...)
|
||||
|
||||
h = hashH2(buffer)
|
||||
hNat, err = bigmod.NewNat().SetBytes(h.Bytes(), orderNat)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
hNat = hashH2(buffer)
|
||||
r.Sub(hNat, orderNat)
|
||||
|
||||
if r.IsZero() == 0 {
|
||||
h = new(big.Int).SetBytes(hNat.Bytes(orderNat))
|
||||
s, err = new(bn256.G1).ScalarMult(priv.PrivateKey, r.Bytes(orderNat))
|
||||
break
|
||||
}
|
||||
@ -203,7 +205,7 @@ func Verify(pub *SignMasterPublicKey, uid []byte, hid byte, hash []byte, h *big.
|
||||
buffer = append(buffer, w.Marshal()...)
|
||||
h2 := hashH2(buffer)
|
||||
|
||||
return h.Cmp(h2) == 0
|
||||
return h2.Equal(hNat) == 1
|
||||
}
|
||||
|
||||
// VerifyASN1 verifies the ASN.1 encoded signature of type SM9Signature, sig, of hash using the
|
||||
|
@ -116,12 +116,7 @@ func (master *SignMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*Sign
|
||||
id = append(id, uid...)
|
||||
id = append(id, hid)
|
||||
|
||||
t1 := hashH1(id)
|
||||
|
||||
t1Nat, err := bigmod.NewNat().SetBytes(t1.Bytes(), orderNat)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t1Nat := hashH1(id)
|
||||
|
||||
d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), orderNat)
|
||||
if err != nil {
|
||||
@ -180,7 +175,7 @@ func (pub *SignMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *bn2
|
||||
buffer = append(buffer, uid...)
|
||||
buffer = append(buffer, hid)
|
||||
h1 := hashH1(buffer)
|
||||
p, err := new(bn256.G2).ScalarBaseMult(bn256.NormalizeScalar(h1.Bytes()))
|
||||
p, err := new(bn256.G2).ScalarBaseMult(h1.Bytes(orderNat))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@ -376,12 +371,7 @@ func (master *EncryptMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*E
|
||||
id = append(id, uid...)
|
||||
id = append(id, hid)
|
||||
|
||||
t1 := hashH1(id)
|
||||
|
||||
t1Nat, err := bigmod.NewNat().SetBytes(t1.Bytes(), orderNat)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t1Nat := hashH1(id)
|
||||
|
||||
d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), orderNat)
|
||||
if err != nil {
|
||||
@ -476,7 +466,7 @@ func (pub *EncryptMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *
|
||||
buffer = append(buffer, uid...)
|
||||
buffer = append(buffer, hid)
|
||||
h1 := hashH1(buffer)
|
||||
p, err := new(bn256.G1).ScalarBaseMult(bn256.NormalizeScalar(h1.Bytes()))
|
||||
p, err := new(bn256.G1).ScalarBaseMult(h1.Bytes(orderNat))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -24,8 +24,8 @@ func bigFromHex(s string) *big.Int {
|
||||
func TestHashH1(t *testing.T) {
|
||||
expected := "2acc468c3926b0bdb2767e99ff26e084de9ced8dbc7d5fbf418027b667862fab"
|
||||
h := hashH1([]byte{0x41, 0x6c, 0x69, 0x63, 0x65, 0x01})
|
||||
if hex.EncodeToString(h.Bytes()) != expected {
|
||||
t.Errorf("got %v, expected %v", hex.EncodeToString(h.Bytes()), expected)
|
||||
if hex.EncodeToString(h.Bytes(orderNat)) != expected {
|
||||
t.Errorf("got %v, expected %v", h.Bytes(orderNat), expected)
|
||||
}
|
||||
}
|
||||
|
||||
@ -37,8 +37,8 @@ func TestHashH2(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
h := hashH2(z)
|
||||
if hex.EncodeToString(h.Bytes()) != expected {
|
||||
t.Errorf("got %v, expected %v", hex.EncodeToString(h.Bytes()), expected)
|
||||
if hex.EncodeToString(h.Bytes(orderNat)) != expected {
|
||||
t.Errorf("got %v, expected %v", h.Bytes(orderNat), expected)
|
||||
}
|
||||
}
|
||||
|
||||
@ -91,11 +91,19 @@ func TestSignASN1(t *testing.T) {
|
||||
// SM9 Appendix A
|
||||
func TestSignSM9Sample(t *testing.T) {
|
||||
expectedH := bigFromHex("823c4b21e4bd2dfe1ed92c606653e996668563152fc33f55d7bfbb9bd9705adb")
|
||||
expectedHNat, err := bigmod.NewNat().SetBytes(expectedH.Bytes(), orderNat)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectedS := "0473bf96923ce58b6ad0e13e9643a406d8eb98417c50ef1b29cef9adb48b6d598c856712f1c2e0968ab7769f42a99586aed139d5b8b3e15891827cc2aced9baa05"
|
||||
hash := []byte("Chinese IBS standard")
|
||||
hid := byte(0x01)
|
||||
uid := []byte("Alice")
|
||||
r := bigFromHex("033c8616b06704813203dfd00965022ed15975c662337aed648835dc4b1cbe")
|
||||
rNat, err := bigmod.NewNat().SetBytes(r.Bytes(), orderNat)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
masterKey := new(SignMasterPrivateKey)
|
||||
masterKey.D = bigFromHex("0130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4")
|
||||
@ -118,17 +126,13 @@ func TestSignSM9Sample(t *testing.T) {
|
||||
buffer = append(buffer, w.Marshal()...)
|
||||
|
||||
h := hashH2(buffer)
|
||||
if h.Cmp(expectedH) != 0 {
|
||||
if h.Equal(expectedHNat) == 0 {
|
||||
t.Fatal("not same h")
|
||||
}
|
||||
|
||||
l := new(big.Int).Sub(r, h)
|
||||
rNat.Sub(h, orderNat)
|
||||
|
||||
if l.Sign() < 0 {
|
||||
l.Add(l, bn256.Order)
|
||||
}
|
||||
|
||||
s, err := new(bn256.G1).ScalarMult(userKey.PrivateKey, bn256.NormalizeScalar(l.Bytes()))
|
||||
s, err := new(bn256.G1).ScalarMult(userKey.PrivateKey, rNat.Bytes(orderNat))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user