mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-26 20:26:19 +08:00
internal/bigmod: drop math/big dependency #273
This commit is contained in:
parent
cd60dad621
commit
9624b43515
@ -7,7 +7,6 @@ package bigmod
|
|||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"math/big"
|
|
||||||
"math/bits"
|
"math/bits"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -104,6 +103,27 @@ func (x *Nat) reset(n int) *Nat {
|
|||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// resetToBytes assigns x = b, where b is a slice of big-endian bytes, resizing
|
||||||
|
// n to the appropriate size.
|
||||||
|
//
|
||||||
|
// The announced length of x is set based on the actual bit size of the input,
|
||||||
|
// ignoring leading zeroes.
|
||||||
|
func (x *Nat) resetToBytes(b []byte) *Nat {
|
||||||
|
x.reset((len(b) + _S - 1) / _S)
|
||||||
|
if err := x.setBytes(b); err != nil {
|
||||||
|
panic("bigmod: internal error: bad arithmetic")
|
||||||
|
}
|
||||||
|
// Trim most significant (trailing in little-endian) zero limbs.
|
||||||
|
// We assume comparison with zero (but not the branch) is constant time.
|
||||||
|
for i := len(x.limbs) - 1; i >= 0; i-- {
|
||||||
|
if x.limbs[i] != 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
x.limbs = x.limbs[:i]
|
||||||
|
}
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
// set assigns x = y, optionally resizing x to the appropriate size.
|
// set assigns x = y, optionally resizing x to the appropriate size.
|
||||||
func (x *Nat) Set(y *Nat) *Nat {
|
func (x *Nat) Set(y *Nat) *Nat {
|
||||||
x.reset(len(y.limbs))
|
x.reset(len(y.limbs))
|
||||||
@ -111,19 +131,6 @@ func (x *Nat) Set(y *Nat) *Nat {
|
|||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetBig assigns x = n, optionally resizing n to the appropriate size.
|
|
||||||
//
|
|
||||||
// The announced length of x is set based on the actual bit size of the input,
|
|
||||||
// ignoring leading zeroes.
|
|
||||||
func (x *Nat) SetBig(n *big.Int) *Nat {
|
|
||||||
limbs := n.Bits()
|
|
||||||
x.reset(len(limbs))
|
|
||||||
for i := range limbs {
|
|
||||||
x.limbs[i] = uint(limbs[i])
|
|
||||||
}
|
|
||||||
return x
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bytes returns x as a zero-extended big-endian byte slice. The size of the
|
// Bytes returns x as a zero-extended big-endian byte slice. The size of the
|
||||||
// slice will match the size of m.
|
// slice will match the size of m.
|
||||||
//
|
//
|
||||||
@ -152,7 +159,8 @@ func (x *Nat) Bytes(m *Modulus) []byte {
|
|||||||
//
|
//
|
||||||
// The output will be resized to the size of m and overwritten.
|
// The output will be resized to the size of m and overwritten.
|
||||||
func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
|
func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
|
||||||
if err := x.setBytes(b, m); err != nil {
|
x.resetFor(m)
|
||||||
|
if err := x.setBytes(b); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if x.CmpGeq(m.nat) == yes {
|
if x.CmpGeq(m.nat) == yes {
|
||||||
@ -167,7 +175,8 @@ func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
|
|||||||
//
|
//
|
||||||
// The output will be resized to the size of m and overwritten.
|
// The output will be resized to the size of m and overwritten.
|
||||||
func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
|
func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
|
||||||
if err := x.setBytes(b, m); err != nil {
|
x.resetFor(m)
|
||||||
|
if err := x.setBytes(b); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
leading := _W - bitLen(x.limbs[len(x.limbs)-1])
|
leading := _W - bitLen(x.limbs[len(x.limbs)-1])
|
||||||
@ -178,6 +187,19 @@ func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
|
|||||||
return x, nil
|
return x, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetOverflowedBytes assigns x = (b mode (m-1)) + 1, where b is a slice of big-endian bytes.
|
||||||
|
//
|
||||||
|
// The output will be resized to the size of m and overwritten.
|
||||||
|
func (x *Nat) SetOverflowedBytes(b []byte, m *Modulus) *Nat {
|
||||||
|
mMinusOne := NewNat().Set(m.nat)
|
||||||
|
mMinusOne.limbs[0]-- // due to m is odd, so we can safely subtract 1
|
||||||
|
one := NewNat().resetFor(m)
|
||||||
|
one.limbs[0] = 1
|
||||||
|
x.resetToBytes(b)
|
||||||
|
x = NewNat().modNat(x, mMinusOne)
|
||||||
|
return x.Add(one, m)
|
||||||
|
}
|
||||||
|
|
||||||
// bigEndianUint returns the contents of buf interpreted as a
|
// bigEndianUint returns the contents of buf interpreted as a
|
||||||
// big-endian encoded uint value.
|
// big-endian encoded uint value.
|
||||||
func bigEndianUint(buf []byte) uint {
|
func bigEndianUint(buf []byte) uint {
|
||||||
@ -187,8 +209,7 @@ func bigEndianUint(buf []byte) uint {
|
|||||||
return uint(binary.BigEndian.Uint32(buf))
|
return uint(binary.BigEndian.Uint32(buf))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *Nat) setBytes(b []byte, m *Modulus) error {
|
func (x *Nat) setBytes(b []byte) error {
|
||||||
x.resetFor(m)
|
|
||||||
i, k := len(b), 0
|
i, k := len(b), 0
|
||||||
for k < len(x.limbs) && i >= _S {
|
for k < len(x.limbs) && i >= _S {
|
||||||
x.limbs[k] = bigEndianUint(b[i-_S : i])
|
x.limbs[k] = bigEndianUint(b[i-_S : i])
|
||||||
@ -381,18 +402,16 @@ func minusInverseModW(x uint) uint {
|
|||||||
return -y
|
return -y
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewModulusFromBig creates a new Modulus from a [big.Int].
|
// NewModulus creates a new Modulus from a slice of big-endian bytes.
|
||||||
//
|
//
|
||||||
// The Int must be odd. The number of significant bits (and nothing else) is
|
// The value must be odd. The number of significant bits (and nothing else) is
|
||||||
// leaked through timing side-channels.
|
// leaked through timing side-channels.
|
||||||
func NewModulusFromBig(n *big.Int) (*Modulus, error) {
|
func NewModulus(b []byte) (*Modulus, error) {
|
||||||
if b := n.Bits(); len(b) == 0 {
|
if len(b) == 0 || b[len(b)-1]&1 != 1 {
|
||||||
return nil, errors.New("modulus must be >= 0")
|
return nil, errors.New("modulus must be > 0 and odd")
|
||||||
} else if b[0]&1 != 1 {
|
|
||||||
return nil, errors.New("modulus must be odd")
|
|
||||||
}
|
}
|
||||||
m := &Modulus{}
|
m := &Modulus{}
|
||||||
m.nat = NewNat().SetBig(n)
|
m.nat = NewNat().resetToBytes(b)
|
||||||
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
|
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
|
||||||
m.m0inv = minusInverseModW(m.nat.limbs[0])
|
m.m0inv = minusInverseModW(m.nat.limbs[0])
|
||||||
m.rr = rr(m)
|
m.rr = rr(m)
|
||||||
@ -478,7 +497,7 @@ func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
|
|||||||
//
|
//
|
||||||
// The output will be resized to the size of m and overwritten.
|
// The output will be resized to the size of m and overwritten.
|
||||||
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
|
func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
|
||||||
return out.ModNat(x, m.nat)
|
return out.modNat(x, m.nat)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mod calculates out = x mod m.
|
// Mod calculates out = x mod m.
|
||||||
@ -486,7 +505,7 @@ func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
|
|||||||
// This works regardless how large the value of x is.
|
// This works regardless how large the value of x is.
|
||||||
//
|
//
|
||||||
// The output will be resized to the size of m and overwritten.
|
// The output will be resized to the size of m and overwritten.
|
||||||
func (out *Nat) ModNat(x *Nat, m *Nat) *Nat {
|
func (out *Nat) modNat(x *Nat, m *Nat) *Nat {
|
||||||
out.reset(len(m.limbs))
|
out.reset(len(m.limbs))
|
||||||
// Working our way from the most significant to the least significant limb,
|
// Working our way from the most significant to the least significant limb,
|
||||||
// we can insert each limb at the least significant position, shifting all
|
// we can insert each limb at the least significant position, shifting all
|
||||||
|
@ -5,6 +5,8 @@
|
|||||||
package bigmod
|
package bigmod
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
"math/bits"
|
"math/bits"
|
||||||
@ -70,7 +72,7 @@ func TestMontgomeryRoundtrip(t *testing.T) {
|
|||||||
one.limbs[0] = 1
|
one.limbs[0] = 1
|
||||||
aPlusOne := new(big.Int).SetBytes(natBytes(a))
|
aPlusOne := new(big.Int).SetBytes(natBytes(a))
|
||||||
aPlusOne.Add(aPlusOne, big.NewInt(1))
|
aPlusOne.Add(aPlusOne, big.NewInt(1))
|
||||||
m, _ := NewModulusFromBig(aPlusOne)
|
m, _ := NewModulus(aPlusOne.Bytes())
|
||||||
monty := new(Nat).Set(a)
|
monty := new(Nat).Set(a)
|
||||||
monty.montgomeryRepresentation(m)
|
monty.montgomeryRepresentation(m)
|
||||||
aAgain := new(Nat).Set(monty)
|
aAgain := new(Nat).Set(monty)
|
||||||
@ -310,6 +312,19 @@ func TestExpShort(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setBig assigns x = n, optionally resizing n to the appropriate size.
|
||||||
|
//
|
||||||
|
// The announced length of x is set based on the actual bit size of the input,
|
||||||
|
// ignoring leading zeroes.
|
||||||
|
func (x *Nat) setBig(n *big.Int) *Nat {
|
||||||
|
limbs := n.Bits()
|
||||||
|
x.reset(len(limbs))
|
||||||
|
for i := range limbs {
|
||||||
|
x.limbs[i] = uint(limbs[i])
|
||||||
|
}
|
||||||
|
return x
|
||||||
|
}
|
||||||
|
|
||||||
// TestMulReductions tests that Mul reduces results equal or slightly greater
|
// TestMulReductions tests that Mul reduces results equal or slightly greater
|
||||||
// than the modulus. Some Montgomery algorithms don't and need extra care to
|
// than the modulus. Some Montgomery algorithms don't and need extra care to
|
||||||
// return correct results. See https://go.dev/issue/13907.
|
// return correct results. See https://go.dev/issue/13907.
|
||||||
@ -319,19 +334,19 @@ func TestMulReductions(t *testing.T) {
|
|||||||
b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10)
|
b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10)
|
||||||
n := new(big.Int).Mul(a, b)
|
n := new(big.Int).Mul(a, b)
|
||||||
|
|
||||||
N, _ := NewModulusFromBig(n)
|
N, _ := NewModulus(n.Bytes())
|
||||||
A := NewNat().SetBig(a).ExpandFor(N)
|
A := NewNat().setBig(a).ExpandFor(N)
|
||||||
B := NewNat().SetBig(b).ExpandFor(N)
|
B := NewNat().setBig(b).ExpandFor(N)
|
||||||
|
|
||||||
if A.Mul(B, N).IsZero() != 1 {
|
if A.Mul(B, N).IsZero() != 1 {
|
||||||
t.Error("a * b mod (a * b) != 0")
|
t.Error("a * b mod (a * b) != 0")
|
||||||
}
|
}
|
||||||
|
|
||||||
i := new(big.Int).ModInverse(a, b)
|
i := new(big.Int).ModInverse(a, b)
|
||||||
N, _ = NewModulusFromBig(b)
|
N, _ = NewModulus(b.Bytes())
|
||||||
A = NewNat().SetBig(a).ExpandFor(N)
|
A = NewNat().setBig(a).ExpandFor(N)
|
||||||
I := NewNat().SetBig(i).ExpandFor(N)
|
I := NewNat().setBig(i).ExpandFor(N)
|
||||||
one := NewNat().SetBig(big.NewInt(1)).ExpandFor(N)
|
one := NewNat().setBig(big.NewInt(1)).ExpandFor(N)
|
||||||
|
|
||||||
if A.Mul(I, N).Equal(one) != 1 {
|
if A.Mul(I, N).Equal(one) != 1 {
|
||||||
t.Error("a * inv(a) mod b != 1")
|
t.Error("a * inv(a) mod b != 1")
|
||||||
@ -345,12 +360,12 @@ func natBytes(n *Nat) []byte {
|
|||||||
func natFromBytes(b []byte) *Nat {
|
func natFromBytes(b []byte) *Nat {
|
||||||
// Must not use Nat.SetBytes as it's used in TestSetBytes.
|
// Must not use Nat.SetBytes as it's used in TestSetBytes.
|
||||||
bb := new(big.Int).SetBytes(b)
|
bb := new(big.Int).SetBytes(b)
|
||||||
return NewNat().SetBig(bb)
|
return NewNat().setBig(bb)
|
||||||
}
|
}
|
||||||
|
|
||||||
func modulusFromBytes(b []byte) *Modulus {
|
func modulusFromBytes(b []byte) *Modulus {
|
||||||
bb := new(big.Int).SetBytes(b)
|
bb := new(big.Int).SetBytes(b)
|
||||||
m, _ := NewModulusFromBig(bb)
|
m, _ := NewModulus(bb.Bytes())
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -359,7 +374,7 @@ func maxModulus(n uint) *Modulus {
|
|||||||
b := big.NewInt(1)
|
b := big.NewInt(1)
|
||||||
b.Lsh(b, n*_W)
|
b.Lsh(b, n*_W)
|
||||||
b.Sub(b, big.NewInt(1))
|
b.Sub(b, big.NewInt(1))
|
||||||
m, _ := NewModulusFromBig(b)
|
m, _ := NewModulus(b.Bytes())
|
||||||
return m
|
return m
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -483,16 +498,56 @@ func BenchmarkExp(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNewModFromBigZero(t *testing.T) {
|
func TestNewModulus(t *testing.T) {
|
||||||
expected := "modulus must be >= 0"
|
expected := "modulus must be > 0 and odd"
|
||||||
_, err := NewModulusFromBig(big.NewInt(0))
|
_, err := NewModulus([]byte{})
|
||||||
if err == nil || err.Error() != expected {
|
if err == nil || err.Error() != expected {
|
||||||
t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected)
|
t.Errorf("NewModulus(0) got %q, want %q", err, expected)
|
||||||
}
|
}
|
||||||
|
_, err = NewModulus([]byte{0})
|
||||||
expected = "modulus must be odd"
|
|
||||||
_, err = NewModulusFromBig(big.NewInt(2))
|
|
||||||
if err == nil || err.Error() != expected {
|
if err == nil || err.Error() != expected {
|
||||||
t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected)
|
t.Errorf("NewModulus(0) got %q, want %q", err, expected)
|
||||||
|
}
|
||||||
|
_, err = NewModulus([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})
|
||||||
|
if err == nil || err.Error() != expected {
|
||||||
|
t.Errorf("NewModulus(0) got %q, want %q", err, expected)
|
||||||
|
}
|
||||||
|
_, err = NewModulus([]byte{1, 1, 1, 1, 2})
|
||||||
|
if err == nil || err.Error() != expected {
|
||||||
|
t.Errorf("NewModulus(2) got %q, want %q", err, expected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOverflowedBytes(t *testing.T) {
|
||||||
|
cases := []string{
|
||||||
|
"b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25",
|
||||||
|
"b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf23",
|
||||||
|
"b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf24",
|
||||||
|
"b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf24b640000002a3a6f1",
|
||||||
|
"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
|
||||||
|
"00",
|
||||||
|
}
|
||||||
|
mBytes, _ := hex.DecodeString(cases[0])
|
||||||
|
m, err := NewModulus(mBytes)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
bigOne := big.NewInt(1)
|
||||||
|
mBigInt := new(big.Int).SetBytes(mBytes)
|
||||||
|
mMinusOne := new(big.Int).Sub(mBigInt, bigOne)
|
||||||
|
|
||||||
|
for _, c := range cases {
|
||||||
|
d, _ := hex.DecodeString(c)
|
||||||
|
k := new(big.Int).SetBytes(d)
|
||||||
|
k = new(big.Int).Mod(k, mMinusOne)
|
||||||
|
k = new(big.Int).Add(k, bigOne)
|
||||||
|
k = new(big.Int).Mod(k, mBigInt)
|
||||||
|
|
||||||
|
kNat := NewNat().SetOverflowedBytes(d, m)
|
||||||
|
k2 := new(big.Int).SetBytes(kNat.Bytes(m))
|
||||||
|
|
||||||
|
if !bytes.Equal(k2.Bytes(), k.Bytes()) {
|
||||||
|
t.Errorf("%s, expected %x, got %x", c, k.Bytes(), k2.Bytes())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1062,8 +1062,8 @@ func p256() *sm2Curve {
|
|||||||
func precomputeParams(c *sm2Curve, curve elliptic.Curve) {
|
func precomputeParams(c *sm2Curve, curve elliptic.Curve) {
|
||||||
params := curve.Params()
|
params := curve.Params()
|
||||||
c.curve = curve
|
c.curve = curve
|
||||||
c.N, _ = bigmod.NewModulusFromBig(params.N)
|
c.N, _ = bigmod.NewModulus(params.N.Bytes())
|
||||||
c.P, _ = bigmod.NewModulusFromBig(params.P)
|
c.P, _ = bigmod.NewModulus(params.P.Bytes())
|
||||||
c.nMinus2 = new(big.Int).Sub(params.N, big.NewInt(2)).Bytes()
|
c.nMinus2 = new(big.Int).Sub(params.N, big.NewInt(2)).Bytes()
|
||||||
c.nMinus1, _ = bigmod.NewNat().SetBytes(new(big.Int).Sub(params.N, big.NewInt(1)).Bytes(), c.N)
|
c.nMinus1, _ = bigmod.NewNat().SetBytes(new(big.Int).Sub(params.N, big.NewInt(1)).Bytes(), c.N)
|
||||||
}
|
}
|
||||||
|
19
sm9/sm9.go
19
sm9/sm9.go
@ -19,15 +19,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// SM9 ASN.1 format reference: Information security technology - SM9 cryptographic algorithm application specification
|
// SM9 ASN.1 format reference: Information security technology - SM9 cryptographic algorithm application specification
|
||||||
|
var (
|
||||||
var orderNat, _ = bigmod.NewModulusFromBig(bn256.Order)
|
orderMinus2 []byte
|
||||||
var orderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes()
|
orderNat *bigmod.Modulus
|
||||||
var bigOne = big.NewInt(1)
|
)
|
||||||
var bigOneNat *bigmod.Nat
|
|
||||||
var orderMinus1 = bigmod.NewNat().SetBig(new(big.Int).Sub(bn256.Order, bigOne))
|
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
bigOneNat, _ = bigmod.NewNat().SetBytes(bigOne.Bytes(), orderNat)
|
orderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes()
|
||||||
|
orderNat, _ = bigmod.NewModulus(bn256.Order.Bytes())
|
||||||
}
|
}
|
||||||
|
|
||||||
type hashMode byte
|
type hashMode byte
|
||||||
@ -70,11 +69,7 @@ func hash(z []byte, h hashMode) *bigmod.Nat {
|
|||||||
md.Write(countBytes[:])
|
md.Write(countBytes[:])
|
||||||
copy(ha[sm3.Size:], md.Sum(nil))
|
copy(ha[sm3.Size:], md.Sum(nil))
|
||||||
|
|
||||||
k := new(big.Int).SetBytes(ha[:40])
|
return bigmod.NewNat().SetOverflowedBytes(ha[:40], orderNat)
|
||||||
kNat := bigmod.NewNat().SetBig(k)
|
|
||||||
kNat = bigmod.NewNat().ModNat(kNat, orderMinus1)
|
|
||||||
kNat.Add(bigOneNat, orderNat)
|
|
||||||
return kNat
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func hashH1(z []byte) *bigmod.Nat {
|
func hashH1(z []byte) *bigmod.Nat {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user