mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-27 20:56:18 +08:00
sm9: use bigmod instead of math/big, part 2
This commit is contained in:
parent
a592631459
commit
c477816aa7
@ -123,11 +123,11 @@ func (x *Nat) Set(y *Nat) *Nat {
|
|||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
// setBig assigns x = n, optionally resizing n to the appropriate size.
|
// SetBig assigns x = n, optionally resizing n to the appropriate size.
|
||||||
//
|
//
|
||||||
// The announced length of x is set based on the actual bit size of the input,
|
// The announced length of x is set based on the actual bit size of the input,
|
||||||
// ignoring leading zeroes.
|
// ignoring leading zeroes.
|
||||||
func (x *Nat) setBig(n *big.Int) *Nat {
|
func (x *Nat) SetBig(n *big.Int) *Nat {
|
||||||
requiredLimbs := (n.BitLen() + _W - 1) / _W
|
requiredLimbs := (n.BitLen() + _W - 1) / _W
|
||||||
x.reset(requiredLimbs)
|
x.reset(requiredLimbs)
|
||||||
|
|
||||||
@ -386,7 +386,7 @@ func minusInverseModW(x uint) uint {
|
|||||||
// The Int must be odd. The number of significant bits must be leakable.
|
// The Int must be odd. The number of significant bits must be leakable.
|
||||||
func NewModulusFromBig(n *big.Int) *Modulus {
|
func NewModulusFromBig(n *big.Int) *Modulus {
|
||||||
m := &Modulus{}
|
m := &Modulus{}
|
||||||
m.nat = NewNat().setBig(n)
|
m.nat = NewNat().SetBig(n)
|
||||||
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
|
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
|
||||||
m.m0inv = minusInverseModW(m.nat.limbs[0])
|
m.m0inv = minusInverseModW(m.nat.limbs[0])
|
||||||
m.rr = rr(m)
|
m.rr = rr(m)
|
||||||
@ -426,13 +426,20 @@ func (m *Modulus) Nat() *Nat {
|
|||||||
//
|
//
|
||||||
// This assumes that x is already reduced mod m, and that y < 2^_W.
|
// This assumes that x is already reduced mod m, and that y < 2^_W.
|
||||||
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
|
func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
|
||||||
d := NewNat().resetFor(m)
|
return x.shiftInNat(y, m.nat)
|
||||||
|
}
|
||||||
|
|
||||||
|
// shiftIn calculates x = x << _W + y mod m.
|
||||||
|
//
|
||||||
|
// This assumes that x is already reduced mod m, and that y < 2^_W.
|
||||||
|
func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
|
||||||
|
d := NewNat().reset(len(m.limbs))
|
||||||
|
|
||||||
// Eliminate bounds checks in the loop.
|
// Eliminate bounds checks in the loop.
|
||||||
size := len(m.nat.limbs)
|
size := len(m.limbs)
|
||||||
xLimbs := x.limbs[:size]
|
xLimbs := x.limbs[:size]
|
||||||
dLimbs := d.limbs[:size]
|
dLimbs := d.limbs[:size]
|
||||||
mLimbs := m.nat.limbs[:size]
|
mLimbs := m.limbs[:size]
|
||||||
|
|
||||||
// Each iteration of this loop computes x = 2x + b mod m, where b is a bit
|
// Each iteration of this loop computes x = 2x + b mod m, where b is a bit
|
||||||
// from y. Effectively, it left-shifts x and adds y one bit at a time,
|
// from y. Effectively, it left-shifts x and adds y one bit at a time,
|
||||||
@ -469,7 +476,16 @@ func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
|
|||||||
//
|
//
|
||||||
// The output will be resized to the size of m and overwritten.
|
// The output will be resized to the size of m and overwritten.
|
||||||
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
|
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
|
||||||
out.resetFor(m)
|
return out.ModNat(x, m.nat)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mod calculates out = x mod m.
|
||||||
|
//
|
||||||
|
// This works regardless how large the value of x is.
|
||||||
|
//
|
||||||
|
// The output will be resized to the size of m and overwritten.
|
||||||
|
func (out *Nat) ModNat(x *Nat, m *Nat) *Nat {
|
||||||
|
out.reset(len(m.limbs))
|
||||||
// Working our way from the most significant to the least significant limb,
|
// Working our way from the most significant to the least significant limb,
|
||||||
// we can insert each limb at the least significant position, shifting all
|
// we can insert each limb at the least significant position, shifting all
|
||||||
// previous limbs left by _W. This way each limb will get shifted by the
|
// previous limbs left by _W. This way each limb will get shifted by the
|
||||||
@ -478,7 +494,7 @@ func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
|
|||||||
i := len(x.limbs) - 1
|
i := len(x.limbs) - 1
|
||||||
// For the first N - 1 limbs we can skip the actual shifting and position
|
// For the first N - 1 limbs we can skip the actual shifting and position
|
||||||
// them at the shifted position, which starts at min(N - 2, i).
|
// them at the shifted position, which starts at min(N - 2, i).
|
||||||
start := len(m.nat.limbs) - 2
|
start := len(m.limbs) - 2
|
||||||
if i < start {
|
if i < start {
|
||||||
start = i
|
start = i
|
||||||
}
|
}
|
||||||
@ -488,7 +504,7 @@ func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
|
|||||||
}
|
}
|
||||||
// We shift in the remaining limbs, reducing modulo m each time.
|
// We shift in the remaining limbs, reducing modulo m each time.
|
||||||
for i >= 0 {
|
for i >= 0 {
|
||||||
out.shiftIn(x.limbs[i], m)
|
out.shiftInNat(x.limbs[i], m)
|
||||||
i--
|
i--
|
||||||
}
|
}
|
||||||
return out
|
return out
|
||||||
|
@ -244,6 +244,52 @@ func TestMod(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestModNat(t *testing.T) {
|
||||||
|
order, _ := new(big.Int).SetString("b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25", 16)
|
||||||
|
orderNat := NewModulusFromBig(order)
|
||||||
|
oneNat, err := NewNat().SetBytes(big.NewInt(1).Bytes(), orderNat)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
orderMinus1, _ := new(big.Int).SetString("b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf24", 16)
|
||||||
|
hashValue1, _ := new(big.Int).SetString("1000000000000000a640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25", 16)
|
||||||
|
hashValue2, _ := new(big.Int).SetString("1000000000000000b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf23", 16)
|
||||||
|
examples := []struct {
|
||||||
|
in *big.Int
|
||||||
|
expected *big.Int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
big.NewInt(1),
|
||||||
|
big.NewInt(2),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
orderMinus1,
|
||||||
|
big.NewInt(1),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
order,
|
||||||
|
big.NewInt(2),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hashValue1,
|
||||||
|
new(big.Int).Add(new(big.Int).Mod(hashValue1, orderMinus1), big.NewInt(1)),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
hashValue2,
|
||||||
|
new(big.Int).Add(new(big.Int).Mod(hashValue2, orderMinus1), big.NewInt(1)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for i, tt := range examples {
|
||||||
|
kNat := NewNat().SetBig(tt.in)
|
||||||
|
kNat = NewNat().ModNat(kNat, NewNat().SetBig(orderMinus1))
|
||||||
|
kNat.Add(oneNat, orderNat)
|
||||||
|
out := new(big.Int).SetBytes(kNat.Bytes(orderNat))
|
||||||
|
if out.Cmp(tt.expected) != 0 {
|
||||||
|
t.Errorf("%d: got %x, expected %x", i, out, tt.expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestModSub(t *testing.T) {
|
func TestModSub(t *testing.T) {
|
||||||
m := modulusFromBytes([]byte{13})
|
m := modulusFromBytes([]byte{13})
|
||||||
x := &Nat{[]uint{6}}
|
x := &Nat{[]uint{6}}
|
||||||
@ -293,7 +339,7 @@ func natBytes(n *Nat) []byte {
|
|||||||
|
|
||||||
func natFromBytes(b []byte) *Nat {
|
func natFromBytes(b []byte) *Nat {
|
||||||
bb := new(big.Int).SetBytes(b)
|
bb := new(big.Int).SetBytes(b)
|
||||||
return NewNat().setBig(bb)
|
return NewNat().SetBig(bb)
|
||||||
}
|
}
|
||||||
|
|
||||||
func modulusFromBytes(b []byte) *Modulus {
|
func modulusFromBytes(b []byte) *Modulus {
|
||||||
|
30
sm9/sm9.go
30
sm9/sm9.go
@ -24,7 +24,12 @@ import (
|
|||||||
var orderNat = bigmod.NewModulusFromBig(bn256.Order)
|
var orderNat = bigmod.NewModulusFromBig(bn256.Order)
|
||||||
var orderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes()
|
var orderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes()
|
||||||
var bigOne = big.NewInt(1)
|
var bigOne = big.NewInt(1)
|
||||||
var orderMinus1 = new(big.Int).Sub(bn256.Order, bigOne)
|
var bigOneNat *bigmod.Nat
|
||||||
|
var orderMinus1 = bigmod.NewNat().SetBig(new(big.Int).Sub(bn256.Order, bigOne))
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
bigOneNat, _ = bigmod.NewNat().SetBytes(bigOne.Bytes(), orderNat)
|
||||||
|
}
|
||||||
|
|
||||||
type hashMode byte
|
type hashMode byte
|
||||||
|
|
||||||
@ -46,7 +51,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
//hash implements H1(Z,n) or H2(Z,n) in sm9 algorithm.
|
//hash implements H1(Z,n) or H2(Z,n) in sm9 algorithm.
|
||||||
func hash(z []byte, h hashMode) *big.Int {
|
func hash(z []byte, h hashMode) *bigmod.Nat {
|
||||||
md := sm3.New()
|
md := sm3.New()
|
||||||
var ha [64]byte
|
var ha [64]byte
|
||||||
var countBytes [4]byte
|
var countBytes [4]byte
|
||||||
@ -61,18 +66,18 @@ func hash(z []byte, h hashMode) *big.Int {
|
|||||||
ct++
|
ct++
|
||||||
md.Reset()
|
md.Reset()
|
||||||
}
|
}
|
||||||
//TODO: how to rewrite this part with nat?
|
|
||||||
k := new(big.Int).SetBytes(ha[:40])
|
k := new(big.Int).SetBytes(ha[:40])
|
||||||
k.Mod(k, orderMinus1)
|
kNat := bigmod.NewNat().SetBig(k)
|
||||||
k.Add(k, bigOne)
|
kNat = bigmod.NewNat().ModNat(kNat, orderMinus1)
|
||||||
return k
|
kNat.Add(bigOneNat, orderNat)
|
||||||
|
return kNat
|
||||||
}
|
}
|
||||||
|
|
||||||
func hashH1(z []byte) *big.Int {
|
func hashH1(z []byte) *bigmod.Nat {
|
||||||
return hash(z, H1)
|
return hash(z, H1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func hashH2(z []byte) *big.Int {
|
func hashH2(z []byte) *bigmod.Nat {
|
||||||
return hash(z, H2)
|
return hash(z, H2)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,14 +136,11 @@ func Sign(rand io.Reader, priv *SignPrivateKey, hash []byte) (h *big.Int, s *bn2
|
|||||||
buffer = append(buffer, hash...)
|
buffer = append(buffer, hash...)
|
||||||
buffer = append(buffer, w.Marshal()...)
|
buffer = append(buffer, w.Marshal()...)
|
||||||
|
|
||||||
h = hashH2(buffer)
|
hNat = hashH2(buffer)
|
||||||
hNat, err = bigmod.NewNat().SetBytes(h.Bytes(), orderNat)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r.Sub(hNat, orderNat)
|
r.Sub(hNat, orderNat)
|
||||||
|
|
||||||
if r.IsZero() == 0 {
|
if r.IsZero() == 0 {
|
||||||
|
h = new(big.Int).SetBytes(hNat.Bytes(orderNat))
|
||||||
s, err = new(bn256.G1).ScalarMult(priv.PrivateKey, r.Bytes(orderNat))
|
s, err = new(bn256.G1).ScalarMult(priv.PrivateKey, r.Bytes(orderNat))
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -203,7 +205,7 @@ func Verify(pub *SignMasterPublicKey, uid []byte, hid byte, hash []byte, h *big.
|
|||||||
buffer = append(buffer, w.Marshal()...)
|
buffer = append(buffer, w.Marshal()...)
|
||||||
h2 := hashH2(buffer)
|
h2 := hashH2(buffer)
|
||||||
|
|
||||||
return h.Cmp(h2) == 0
|
return h2.Equal(hNat) == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
// VerifyASN1 verifies the ASN.1 encoded signature of type SM9Signature, sig, of hash using the
|
// VerifyASN1 verifies the ASN.1 encoded signature of type SM9Signature, sig, of hash using the
|
||||||
|
@ -116,12 +116,7 @@ func (master *SignMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*Sign
|
|||||||
id = append(id, uid...)
|
id = append(id, uid...)
|
||||||
id = append(id, hid)
|
id = append(id, hid)
|
||||||
|
|
||||||
t1 := hashH1(id)
|
t1Nat := hashH1(id)
|
||||||
|
|
||||||
t1Nat, err := bigmod.NewNat().SetBytes(t1.Bytes(), orderNat)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), orderNat)
|
d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), orderNat)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -180,7 +175,7 @@ func (pub *SignMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *bn2
|
|||||||
buffer = append(buffer, uid...)
|
buffer = append(buffer, uid...)
|
||||||
buffer = append(buffer, hid)
|
buffer = append(buffer, hid)
|
||||||
h1 := hashH1(buffer)
|
h1 := hashH1(buffer)
|
||||||
p, err := new(bn256.G2).ScalarBaseMult(bn256.NormalizeScalar(h1.Bytes()))
|
p, err := new(bn256.G2).ScalarBaseMult(h1.Bytes(orderNat))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@ -376,12 +371,7 @@ func (master *EncryptMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*E
|
|||||||
id = append(id, uid...)
|
id = append(id, uid...)
|
||||||
id = append(id, hid)
|
id = append(id, hid)
|
||||||
|
|
||||||
t1 := hashH1(id)
|
t1Nat := hashH1(id)
|
||||||
|
|
||||||
t1Nat, err := bigmod.NewNat().SetBytes(t1.Bytes(), orderNat)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), orderNat)
|
d, err := bigmod.NewNat().SetBytes(master.D.Bytes(), orderNat)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -476,7 +466,7 @@ func (pub *EncryptMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *
|
|||||||
buffer = append(buffer, uid...)
|
buffer = append(buffer, uid...)
|
||||||
buffer = append(buffer, hid)
|
buffer = append(buffer, hid)
|
||||||
h1 := hashH1(buffer)
|
h1 := hashH1(buffer)
|
||||||
p, err := new(bn256.G1).ScalarBaseMult(bn256.NormalizeScalar(h1.Bytes()))
|
p, err := new(bn256.G1).ScalarBaseMult(h1.Bytes(orderNat))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
@ -24,8 +24,8 @@ func bigFromHex(s string) *big.Int {
|
|||||||
func TestHashH1(t *testing.T) {
|
func TestHashH1(t *testing.T) {
|
||||||
expected := "2acc468c3926b0bdb2767e99ff26e084de9ced8dbc7d5fbf418027b667862fab"
|
expected := "2acc468c3926b0bdb2767e99ff26e084de9ced8dbc7d5fbf418027b667862fab"
|
||||||
h := hashH1([]byte{0x41, 0x6c, 0x69, 0x63, 0x65, 0x01})
|
h := hashH1([]byte{0x41, 0x6c, 0x69, 0x63, 0x65, 0x01})
|
||||||
if hex.EncodeToString(h.Bytes()) != expected {
|
if hex.EncodeToString(h.Bytes(orderNat)) != expected {
|
||||||
t.Errorf("got %v, expected %v", hex.EncodeToString(h.Bytes()), expected)
|
t.Errorf("got %v, expected %v", h.Bytes(orderNat), expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -37,8 +37,8 @@ func TestHashH2(t *testing.T) {
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
h := hashH2(z)
|
h := hashH2(z)
|
||||||
if hex.EncodeToString(h.Bytes()) != expected {
|
if hex.EncodeToString(h.Bytes(orderNat)) != expected {
|
||||||
t.Errorf("got %v, expected %v", hex.EncodeToString(h.Bytes()), expected)
|
t.Errorf("got %v, expected %v", h.Bytes(orderNat), expected)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -91,11 +91,19 @@ func TestSignASN1(t *testing.T) {
|
|||||||
// SM9 Appendix A
|
// SM9 Appendix A
|
||||||
func TestSignSM9Sample(t *testing.T) {
|
func TestSignSM9Sample(t *testing.T) {
|
||||||
expectedH := bigFromHex("823c4b21e4bd2dfe1ed92c606653e996668563152fc33f55d7bfbb9bd9705adb")
|
expectedH := bigFromHex("823c4b21e4bd2dfe1ed92c606653e996668563152fc33f55d7bfbb9bd9705adb")
|
||||||
|
expectedHNat, err := bigmod.NewNat().SetBytes(expectedH.Bytes(), orderNat)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
expectedS := "0473bf96923ce58b6ad0e13e9643a406d8eb98417c50ef1b29cef9adb48b6d598c856712f1c2e0968ab7769f42a99586aed139d5b8b3e15891827cc2aced9baa05"
|
expectedS := "0473bf96923ce58b6ad0e13e9643a406d8eb98417c50ef1b29cef9adb48b6d598c856712f1c2e0968ab7769f42a99586aed139d5b8b3e15891827cc2aced9baa05"
|
||||||
hash := []byte("Chinese IBS standard")
|
hash := []byte("Chinese IBS standard")
|
||||||
hid := byte(0x01)
|
hid := byte(0x01)
|
||||||
uid := []byte("Alice")
|
uid := []byte("Alice")
|
||||||
r := bigFromHex("033c8616b06704813203dfd00965022ed15975c662337aed648835dc4b1cbe")
|
r := bigFromHex("033c8616b06704813203dfd00965022ed15975c662337aed648835dc4b1cbe")
|
||||||
|
rNat, err := bigmod.NewNat().SetBytes(r.Bytes(), orderNat)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
masterKey := new(SignMasterPrivateKey)
|
masterKey := new(SignMasterPrivateKey)
|
||||||
masterKey.D = bigFromHex("0130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4")
|
masterKey.D = bigFromHex("0130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4")
|
||||||
@ -118,17 +126,13 @@ func TestSignSM9Sample(t *testing.T) {
|
|||||||
buffer = append(buffer, w.Marshal()...)
|
buffer = append(buffer, w.Marshal()...)
|
||||||
|
|
||||||
h := hashH2(buffer)
|
h := hashH2(buffer)
|
||||||
if h.Cmp(expectedH) != 0 {
|
if h.Equal(expectedHNat) == 0 {
|
||||||
t.Fatal("not same h")
|
t.Fatal("not same h")
|
||||||
}
|
}
|
||||||
|
|
||||||
l := new(big.Int).Sub(r, h)
|
rNat.Sub(h, orderNat)
|
||||||
|
|
||||||
if l.Sign() < 0 {
|
s, err := new(bn256.G1).ScalarMult(userKey.PrivateKey, rNat.Bytes(orderNat))
|
||||||
l.Add(l, bn256.Order)
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := new(bn256.G1).ScalarMult(userKey.PrivateKey, bn256.NormalizeScalar(l.Bytes()))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user