[sync sdk] crypto/internal/bigmod: switch to saturated limbs

This commit is contained in:
Sun Yimin 2023-06-01 10:39:12 +08:00 committed by GitHub
parent f7a04e74a1
commit f32b7e1afc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 1952 additions and 361 deletions

View File

@ -5,16 +5,17 @@
package bigmod
import (
"encoding/binary"
"errors"
"math/big"
"math/bits"
)
const (
// _W is the number of bits we use for our limbs.
_W = bits.UintSize - 1
// _MASK selects _W bits from a full machine word.
_MASK = (1 << _W) - 1
// _W is the size in bits of our limbs.
_W = bits.UintSize
// _S is the size in bytes of our limbs.
_S = _W / 8
)
// choice represents a constant-time boolean. The value of choice is always
@ -27,15 +28,8 @@ func not(c choice) choice { return 1 ^ c }
const yes = choice(1)
const no = choice(0)
// ctSelect returns x if on == 1, and y if on == 0. The execution time of this
// function does not depend on its inputs. If on is any value besides 1 or 0,
// the result is undefined.
func ctSelect(on choice, x, y uint) uint {
// When on == 1, mask is 0b111..., otherwise mask is 0b000...
mask := -uint(on)
// When mask is all zeros, we just have y, otherwise, y cancels with itself.
return y ^ (mask & (y ^ x))
}
// ctMask is all 1s if on is yes, and all 0s otherwise.
func ctMask(on choice) uint { return -uint(on) }
// ctEq returns 1 if x == y, and 0 otherwise. The execution time of this
// function does not depend on its inputs.
@ -60,13 +54,7 @@ func ctGeq(x, y uint) choice {
// Operations on this number are allowed to leak this length, but will not leak
// any information about the values contained in those limbs.
type Nat struct {
// limbs is a little-endian representation in base 2^W with
// W = bits.UintSize - 1. The top bit is always unset between operations.
//
// The top bit is left unset to optimize Montgomery multiplication, in the
// inner loop of exponentiation. Using fully saturated limbs would leave us
// working with 129-bit numbers on 64-bit platforms, wasting a lot of space,
// and thus time.
// limbs is little-endian in base 2^W with W = bits.UintSize.
limbs []uint
}
@ -128,25 +116,10 @@ func (x *Nat) Set(y *Nat) *Nat {
// The announced length of x is set based on the actual bit size of the input,
// ignoring leading zeroes.
func (x *Nat) SetBig(n *big.Int) *Nat {
requiredLimbs := (n.BitLen() + _W - 1) / _W
x.reset(requiredLimbs)
outI := 0
shift := 0
limbs := n.Bits()
x.reset(len(limbs))
for i := range limbs {
xi := uint(limbs[i])
x.limbs[outI] |= (xi << shift) & _MASK
outI++
if outI == requiredLimbs {
return x
}
x.limbs[outI] = xi >> (_W - shift)
shift++ // this assumes bits.UintSize - _W = 1
if shift == _W {
shift = 0
outI++
}
x.limbs[i] = uint(limbs[i])
}
return x
}
@ -156,24 +129,20 @@ func (x *Nat) SetBig(n *big.Int) *Nat {
//
// x must have the same size as m and it must be reduced modulo m.
func (x *Nat) Bytes(m *Modulus) []byte {
bytes := make([]byte, m.Size())
shift := 0
outI := len(bytes) - 1
i := m.Size()
bytes := make([]byte, i)
for _, limb := range x.limbs {
remainingBits := _W
for remainingBits >= 8 {
bytes[outI] |= byte(limb) << shift
consumed := 8 - shift
limb >>= consumed
remainingBits -= consumed
shift = 0
outI--
if outI < 0 {
return bytes
for j := 0; j < _S; j++ {
i--
if i < 0 {
if limb == 0 {
break
}
panic("bigmod: modulus is smaller than nat")
}
bytes[i] = byte(limb)
limb >>= 8
}
bytes[outI] = byte(limb)
shift = remainingBits
}
return bytes
}
@ -192,9 +161,9 @@ func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
return x, nil
}
// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes. SetOverflowingBytes
// returns an error if b has a longer bit length than m, but reduces overflowing
// values up to 2^⌈log2(m)⌉ - 1.
// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes.
// SetOverflowingBytes returns an error if b has a longer bit length than m, but
// reduces overflowing values up to 2^⌈log2(m)⌉ - 1.
//
// The output will be resized to the size of m and overwritten.
func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
@ -203,33 +172,35 @@ func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
}
leading := _W - bitLen(x.limbs[len(x.limbs)-1])
if leading < m.leading {
return nil, errors.New("input overflows the modulus")
return nil, errors.New("input overflows the modulus size")
}
x.sub(x.cmpGeq(m.nat), m.nat)
x.maybeSubtractModulus(no, m)
return x, nil
}
// bigEndianUint returns the contents of buf interpreted as a
// big-endian encoded uint value.
func bigEndianUint(buf []byte) uint {
if _W == 64 {
return uint(binary.BigEndian.Uint64(buf))
}
return uint(binary.BigEndian.Uint32(buf))
}
func (x *Nat) setBytes(b []byte, m *Modulus) error {
outI := 0
shift := 0
x.resetFor(m)
for i := len(b) - 1; i >= 0; i-- {
bi := b[i]
x.limbs[outI] |= uint(bi) << shift
shift += 8
if shift >= _W {
shift -= _W
x.limbs[outI] &= _MASK
overflow := bi >> (8 - shift)
outI++
if outI >= len(x.limbs) {
if overflow > 0 || i > 0 {
return errors.New("input overflows the modulus")
}
break
}
x.limbs[outI] = uint(overflow)
}
i, k := len(b), 0
for k < len(x.limbs) && i >= _S {
x.limbs[k] = bigEndianUint(b[i-_S : i])
i -= _S
k++
}
for s := 0; s < _W && k < len(x.limbs) && i > 0; s += 8 {
x.limbs[k] |= uint(b[i-1]) << s
i--
}
if i > 0 {
return errors.New("input overflows the modulus size")
}
return nil
}
@ -274,7 +245,7 @@ func (x *Nat) cmpGeq(y *Nat) choice {
var c uint
for i := 0; i < size; i++ {
c = (xLimbs[i] - yLimbs[i] - c) >> _W
_, c = bits.Sub(xLimbs[i], yLimbs[i], c)
}
// If there was a carry, then subtracting y underflowed, so
// x is not greater than or equal to y.
@ -290,44 +261,39 @@ func (x *Nat) assign(on choice, y *Nat) *Nat {
xLimbs := x.limbs[:size]
yLimbs := y.limbs[:size]
mask := ctMask(on)
for i := 0; i < size; i++ {
xLimbs[i] = ctSelect(on, yLimbs[i], xLimbs[i])
xLimbs[i] ^= mask & (xLimbs[i] ^ yLimbs[i])
}
return x
}
// add computes x += y if on == 1, and does nothing otherwise. It returns the
// carry of the addition regardless of on.
// add computes x += y and returns the carry.
//
// Both operands must have the same announced length.
func (x *Nat) add(on choice, y *Nat) (c uint) {
func (x *Nat) add(y *Nat) (c uint) {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
yLimbs := y.limbs[:size]
for i := 0; i < size; i++ {
res := xLimbs[i] + yLimbs[i] + c
xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
c = res >> _W
xLimbs[i], c = bits.Add(xLimbs[i], yLimbs[i], c)
}
return
}
// sub computes x -= y if on == 1, and does nothing otherwise. It returns the
// borrow of the subtraction regardless of on.
// sub computes x -= y. It returns the borrow of the subtraction.
//
// Both operands must have the same announced length.
func (x *Nat) sub(on choice, y *Nat) (c uint) {
func (x *Nat) sub(y *Nat) (c uint) {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
yLimbs := y.limbs[:size]
for i := 0; i < size; i++ {
res := xLimbs[i] - yLimbs[i] - c
xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
c = res >> _W
xLimbs[i], c = bits.Sub(xLimbs[i], yLimbs[i], c)
}
return
}
@ -371,26 +337,32 @@ func minusInverseModW(x uint) uint {
// Every iteration of this loop doubles the least-significant bits of
// correct inverse in y. The first three bits are already correct (1⁻¹ = 1,
// 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough
// for 61 bits (and wastes only one iteration for 31 bits).
// for 64 bits (and wastes only one iteration for 32 bits).
//
// See https://crypto.stackexchange.com/a/47496.
y := x
for i := 0; i < 5; i++ {
y = y * (2 - x*y)
}
return (1 << _W) - (y & _MASK)
return -y
}
// NewModulusFromBig creates a new Modulus from a [big.Int].
//
// The Int must be odd. The number of significant bits must be leakable.
func NewModulusFromBig(n *big.Int) *Modulus {
// The Int must be odd. The number of significant bits (and nothing else) is
// leaked through timing side-channels.
func NewModulusFromBig(n *big.Int) (*Modulus, error) {
if b := n.Bits(); len(b) == 0 {
return nil, errors.New("modulus must be >= 0")
} else if b[0]&1 != 1 {
return nil, errors.New("modulus must be odd")
}
m := &Modulus{}
m.nat = NewNat().SetBig(n)
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
m.m0inv = minusInverseModW(m.nat.limbs[0])
m.rr = rr(m)
return m
return m, nil
}
// bitLen is a version of bits.Len that only leaks the bit length of n, but not
@ -447,25 +419,21 @@ func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
//
// To do the reduction, each iteration computes both 2x + b and 2x + b - m.
// The next iteration (and finally the return line) will use either result
// based on whether the subtraction underflowed.
// based on whether 2x + b overflows m.
needSubtraction := no
for i := _W - 1; i >= 0; i-- {
carry := (y >> i) & 1
var borrow uint
mask := ctMask(needSubtraction)
for i := 0; i < size; i++ {
l := ctSelect(needSubtraction, dLimbs[i], xLimbs[i])
res := l<<1 + carry
xLimbs[i] = res & _MASK
carry = res >> _W
res = xLimbs[i] - mLimbs[i] - borrow
dLimbs[i] = res & _MASK
borrow = res >> _W
l := xLimbs[i] ^ (mask & (xLimbs[i] ^ dLimbs[i]))
xLimbs[i], carry = bits.Add(l, l, carry)
dLimbs[i], borrow = bits.Sub(xLimbs[i], mLimbs[i], borrow)
}
// See Add for how carry (aka overflow), borrow (aka underflow), and
// needSubtraction relate.
needSubtraction = ctEq(carry, borrow)
// Like in maybeSubtractModulus, we need the subtraction if either it
// didn't underflow (meaning 2x + b > m) or if computing 2x + b
// overflowed (meaning 2x + b > 2^_W*n > m).
needSubtraction = not(choice(borrow)) | choice(carry)
}
return x.assign(needSubtraction, d)
}
@ -524,14 +492,34 @@ func (out *Nat) resetFor(m *Modulus) *Nat {
return out.reset(len(m.nat.limbs))
}
// maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes.
//
// It can be used to reduce modulo m a value up to 2m - 1, which is a common
// range for results computed by higher level operations.
//
// always is usually a carry that indicates that the operation that produced x
// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.
//
// x and m operands must have the same announced length.
func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
t := NewNat().Set(x)
underflow := t.sub(m.nat)
// We keep the result if x - m didn't underflow (meaning x >= m)
// or if always was set.
keep := not(choice(underflow)) | choice(always)
x.assign(keep, t)
}
// Sub computes x = x - y mod m.
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
underflow := x.sub(yes, y)
underflow := x.sub(y)
// If the subtraction underflowed, add m.
x.add(choice(underflow), m.nat)
t := NewNat().Set(x)
t.add(m.nat)
x.assign(choice(underflow), t)
return x
}
@ -540,34 +528,8 @@ func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
overflow := x.add(yes, y)
underflow := not(x.cmpGeq(m.nat)) // x < m
// Three cases are possible:
//
// - overflow = 0, underflow = 0
//
// In this case, addition fits in our limbs, but we can still subtract away
// m without an underflow, so we need to perform the subtraction to reduce
// our result.
//
// - overflow = 0, underflow = 1
//
// The addition fits in our limbs, but we can't subtract m without
// underflowing. The result is already reduced.
//
// - overflow = 1, underflow = 1
//
// The addition does not fit in our limbs, and the subtraction's borrow
// would cancel out with the addition's carry. We need to subtract m to
// reduce our result.
//
// The overflow = 1, underflow = 0 case is not possible, because y is at
// most m - 1, and if adding m - 1 overflows, then subtracting m must
// necessarily underflow.
needSubtraction := ctEq(overflow, uint(underflow))
x.sub(needSubtraction, m.nat)
overflow := x.add(y)
x.maybeSubtractModulus(choice(overflow), m)
return x
}
@ -581,7 +543,7 @@ func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat {
// A Montgomery multiplication (which computes a * b / R) by R * R works out
// to a multiplication by R, which takes the value out of the Montgomery domain.
return x.montgomeryMul(NewNat().Set(x), m.rr, m)
return x.montgomeryMul(x, m.rr, m)
}
// montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
@ -592,77 +554,157 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
// By Montgomery multiplying with 1 not in Montgomery representation, we
// convert out back from Montgomery representation, because it works out to
// dividing by R.
t0 := NewNat().Set(x)
t1 := NewNat().ExpandFor(m)
t1.limbs[0] = 1
return x.montgomeryMul(t0, t1, m)
one := NewNat().ExpandFor(m)
one.limbs[0] = 1
return x.montgomeryMul(x, one, m)
}
// montgomeryMul calculates d = a * b / R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs), using the Montgomery Multiplication technique.
// montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs), also known as a Montgomery multiplication.
//
// All inputs should be the same length, not aliasing d, and already
// reduced modulo m. d will be resized to the size of m and overwritten.
func (d *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
d.resetFor(m)
if len(a.limbs) != len(m.nat.limbs) || len(b.limbs) != len(m.nat.limbs) {
panic("bigmod: invalid montgomeryMul input")
}
// All inputs should be the same length and already reduced modulo m.
// x will be resized to the size of m and overwritten.
func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
n := len(m.nat.limbs)
mLimbs := m.nat.limbs[:n]
aLimbs := a.limbs[:n]
bLimbs := b.limbs[:n]
// See https://bearssl.org/bigint.html#montgomery-reduction-and-multiplication
// for a description of the algorithm implemented mostly in montgomeryLoop.
// See Add for how overflow, underflow, and needSubtraction relate.
overflow := montgomeryLoop(d.limbs, a.limbs, b.limbs, m.nat.limbs, m.m0inv)
underflow := not(d.cmpGeq(m.nat)) // d < m
needSubtraction := ctEq(overflow, uint(underflow))
d.sub(needSubtraction, m.nat)
switch n {
default:
// Attempt to use a stack-allocated backing array.
T := make([]uint, 0, preallocLimbs*2)
if cap(T) < n*2 {
T = make([]uint, 0, n*2)
}
T = T[:n*2]
return d
}
// This loop implements Word-by-Word Montgomery Multiplication, as
// described in Algorithm 4 (Fig. 3) of "Efficient Software
// Implementations of Modular Exponentiation" by Shay Gueron
// [https://eprint.iacr.org/2011/239.pdf].
var c uint
for i := 0; i < n; i++ {
_ = T[n+i] // bounds check elimination hint
func montgomeryLoopGeneric(d, a, b, m []uint, m0inv uint) (overflow uint) {
// Eliminate bounds checks in the loop.
size := len(d)
a = a[:size]
b = b[:size]
m = m[:size]
// Step 1 (T = a × b) is computed as a large pen-and-paper column
// multiplication of two numbers with n base-2^_W digits. If we just
// wanted to produce 2n-wide T, we would do
//
// for i := 0; i < n; i++ {
// d := bLimbs[i]
// T[n+i] = addMulVVW(T[i:n+i], aLimbs, d)
// }
//
// where d is a digit of the multiplier, T[i:n+i] is the shifted
// position of the product of that digit, and T[n+i] is the final carry.
// Note that T[i] isn't modified after processing the i-th digit.
//
// Instead of running two loops, one for Step 1 and one for Steps 26,
// the result of Step 1 is computed during the next loop. This is
// possible because each iteration only uses T[i] in Step 2 and then
// discards it in Step 6.
d := bLimbs[i]
c1 := addMulVVW(T[i:n+i], aLimbs, d)
for _, ai := range a {
// This is an unrolled iteration of the loop below with j = 0.
hi, lo := bits.Mul(ai, b[0])
z_lo, c := bits.Add(d[0], lo, 0)
f := (z_lo * m0inv) & _MASK // (d[0] + a[i] * b[0]) * m0inv
z_hi, _ := bits.Add(0, hi, c)
hi, lo = bits.Mul(f, m[0])
z_lo, c = bits.Add(z_lo, lo, 0)
z_hi, _ = bits.Add(z_hi, hi, c)
carry := z_hi<<1 | z_lo>>_W
// Step 6 is replaced by shifting the virtual window we operate
// over: T of the algorithm is T[i:] for us. That means that T1 in
// Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv.
Y := T[i] * m.m0inv
for j := 1; j < size; j++ {
// z = d[j] + a[i] * b[j] + f * m[j] + carry <= 2^(2W+1) - 2^(W+1) + 2^W
hi, lo := bits.Mul(ai, b[j])
z_lo, c := bits.Add(d[j], lo, 0)
z_hi, _ := bits.Add(0, hi, c)
hi, lo = bits.Mul(f, m[j])
z_lo, c = bits.Add(z_lo, lo, 0)
z_hi, _ = bits.Add(z_hi, hi, c)
z_lo, c = bits.Add(z_lo, carry, 0)
z_hi, _ = bits.Add(z_hi, 0, c)
d[j-1] = z_lo & _MASK
carry = z_hi<<1 | z_lo>>_W // carry <= 2^(W+1) - 2
// Step 4 and 5 add Y × m to T, which as mentioned above is stored
// at T[i:]. The two carries (from a × d and Y × m) are added up in
// the next word T[n+i], and the carry bit from that addition is
// brought forward to the next iteration.
c2 := addMulVVW(T[i:n+i], mLimbs, Y)
T[n+i], c = bits.Add(c1, c2, c)
}
z := overflow + carry // z <= 2^(W+1) - 1
d[size-1] = z & _MASK
overflow = z >> _W // overflow <= 1
// Finally for Step 7 we copy the final T window into x, and subtract m
// if necessary (which as explained in maybeSubtractModulus can be the
// case both if x >= m, or if x overflowed).
//
// The paper suggests in Section 4 that we can do an "Almost Montgomery
// Multiplication" by subtracting only in the overflow case, but the
// cost is very similar since the constant time subtraction tells us if
// x >= m as a side effect, and taking care of the broken invariant is
// highly undesirable (see https://go.dev/issue/13907).
copy(x.reset(n).limbs, T[n:])
x.maybeSubtractModulus(choice(c), m)
// The following specialized cases follow the exact same algorithm, but
// optimized for the sizes most used in RSA. addMulVVW is implemented in
// assembly with loop unrolling depending on the architecture and bounds
// checks are removed by the compiler thanks to the constant size.
case 1024 / _W:
const n = 1024 / _W // compiler hint
T := make([]uint, n*2)
var c uint
for i := 0; i < n; i++ {
d := bLimbs[i]
c1 := addMulVVW1024(&T[i], &aLimbs[0], d)
Y := T[i] * m.m0inv
c2 := addMulVVW1024(&T[i], &mLimbs[0], Y)
T[n+i], c = bits.Add(c1, c2, c)
}
copy(x.reset(n).limbs, T[n:])
x.maybeSubtractModulus(choice(c), m)
case 1536 / _W:
const n = 1536 / _W // compiler hint
T := make([]uint, n*2)
var c uint
for i := 0; i < n; i++ {
d := bLimbs[i]
c1 := addMulVVW1536(&T[i], &aLimbs[0], d)
Y := T[i] * m.m0inv
c2 := addMulVVW1536(&T[i], &mLimbs[0], Y)
T[n+i], c = bits.Add(c1, c2, c)
}
copy(x.reset(n).limbs, T[n:])
x.maybeSubtractModulus(choice(c), m)
case 2048 / _W:
const n = 2048 / _W // compiler hint
T := make([]uint, n*2)
var c uint
for i := 0; i < n; i++ {
d := bLimbs[i]
c1 := addMulVVW2048(&T[i], &aLimbs[0], d)
Y := T[i] * m.m0inv
c2 := addMulVVW2048(&T[i], &mLimbs[0], Y)
T[n+i], c = bits.Add(c1, c2, c)
}
copy(x.reset(n).limbs, T[n:])
x.maybeSubtractModulus(choice(c), m)
}
return
return x
}
// Mul calculates x *= y mod m.
// addMulVVW multiplies the multi-word value x by the single-word value y,
// adding the result to the multi-word value z and returning the final carry.
// It can be thought of as one row of a pen-and-paper column multiplication.
func addMulVVW(z, x []uint, y uint) (carry uint) {
_ = x[len(z)-1] // bounds check elimination hint
for i := range z {
hi, lo := bits.Mul(x[i], y)
lo, c := bits.Add(lo, z[i], 0)
// We use bits.Add with zero to get an add-with-carry instruction that
// absorbs the carry from the previous bits.Add.
hi, _ = bits.Add(hi, 0, c)
lo, c = bits.Add(lo, carry, 0)
hi, _ = bits.Add(hi, 0, c)
carry = hi
z[i] = lo
}
return carry
}
// Mul calculates x = x * y mod m.
//
// x and y must already be reduced modulo m, they must share its announced
// length, and they may not alias.
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
// A Montgomery multiplication by a value out of the Montgomery domain
// takes the result out of Montgomery representation.
@ -677,7 +719,8 @@ func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
// than 2 bit windows, but use an extra 12 nats worth of scratch space.
// Using bit sizes that don't divide 8 are more complex to implement.
// Using bit sizes that don't divide 8 are more complex to implement, but
// are likely to be more efficient if necessary.
table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1)
// newNat calls are unrolled so they are allocated on the stack.
@ -693,27 +736,51 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
out.resetFor(m)
out.limbs[0] = 1
out.montgomeryRepresentation(m)
t0 := NewNat().ExpandFor(m)
t1 := NewNat().ExpandFor(m)
tmp := NewNat().ExpandFor(m)
for _, b := range e {
for _, j := range []int{4, 0} {
// Square four times.
t1.montgomeryMul(out, out, m)
out.montgomeryMul(t1, t1, m)
t1.montgomeryMul(out, out, m)
out.montgomeryMul(t1, t1, m)
// Square four times. Optimization note: this can be implemented
// more efficiently than with generic Montgomery multiplication.
out.montgomeryMul(out, out, m)
out.montgomeryMul(out, out, m)
out.montgomeryMul(out, out, m)
out.montgomeryMul(out, out, m)
// Select x^k in constant time from the table.
k := uint((b >> j) & 0b1111)
for i := range table {
t0.assign(ctEq(k, uint(i+1)), table[i])
tmp.assign(ctEq(k, uint(i+1)), table[i])
}
// Multiply by x^k, discarding the result if k = 0.
t1.montgomeryMul(out, t0, m)
out.assign(not(ctEq(k, 0)), t1)
tmp.montgomeryMul(out, tmp, m)
out.assign(not(ctEq(k, 0)), tmp)
}
}
return out.montgomeryReduction(m)
}
// ExpShort calculates out = x^e mod m.
//
// The output will be resized to the size of m and overwritten. x must already
// be reduced modulo m. This leaks the exact bit size of the exponent.
func (out *Nat) ExpShort(x *Nat, e uint, m *Modulus) *Nat {
xR := NewNat().Set(x).montgomeryRepresentation(m)
out.resetFor(m)
out.limbs[0] = 1
out.montgomeryRepresentation(m)
// For short exponents, precomputing a table and using a window like in Exp
// doesn't pay off. Instead, we do a simple constant-time conditional
// square-and-multiply chain, skipping the initial run of zeroes.
tmp := NewNat().ExpandFor(m)
for i := bits.UintSize - bitLen(e); i < bits.UintSize; i++ {
out.montgomeryMul(out, out, m)
k := (e >> (bits.UintSize - i - 1)) & 1
tmp.montgomeryMul(out, xR, m)
out.assign(ctEq(k, 1), tmp)
}
return out.montgomeryReduction(m)
}

48
internal/bigmod/nat_386.s Normal file
View File

@ -0,0 +1,48 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !purego
// +build !purego
#include "textflag.h"
// func addMulVVW1024(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1024(SB), $0-16
MOVL $32, BX
JMP addMulVVWx(SB)
// func addMulVVW1536(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1536(SB), $0-16
MOVL $48, BX
JMP addMulVVWx(SB)
// func addMulVVW2048(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW2048(SB), $0-16
MOVL $64, BX
JMP addMulVVWx(SB)
TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
MOVL z+0(FP), DI
MOVL x+4(FP), SI
MOVL y+8(FP), BP
LEAL (DI)(BX*4), DI
LEAL (SI)(BX*4), SI
NEGL BX // i = -n
MOVL $0, CX // c = 0
JMP E6
L6: MOVL (SI)(BX*4), AX
MULL BP
ADDL CX, AX
ADCL $0, DX
ADDL AX, (DI)(BX*4)
ADCL $0, DX
MOVL DX, CX
ADDL $1, BX // i++
E6: CMPL BX, $0 // i < 0
JL L6
MOVL CX, c+12(FP)
RET

View File

@ -1,7 +0,0 @@
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego
package bigmod
//go:noescape
func montgomeryLoop(d []uint, a []uint, b []uint, m []uint, m0inv uint) uint

File diff suppressed because it is too large Load Diff

48
internal/bigmod/nat_arm.s Normal file
View File

@ -0,0 +1,48 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !purego
// +build !purego
#include "textflag.h"
// func addMulVVW1024(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1024(SB), $0-16
MOVW $32, R5
JMP addMulVVWx(SB)
// func addMulVVW1536(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1536(SB), $0-16
MOVW $48, R5
JMP addMulVVWx(SB)
// func addMulVVW2048(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW2048(SB), $0-16
MOVW $64, R5
JMP addMulVVWx(SB)
TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
MOVW $0, R0
MOVW z+0(FP), R1
MOVW x+4(FP), R2
MOVW y+8(FP), R3
ADD R5<<2, R1, R5
MOVW $0, R4
B E9
L9: MOVW.P 4(R2), R6
MULLU R6, R3, (R7, R6)
ADD.S R4, R6
ADC R0, R7
MOVW 0(R1), R4
ADD.S R4, R6
ADC R0, R7
MOVW.P R6, 4(R1)
MOVW R7, R4
E9: TEQ R1, R5
BNE L9
MOVW R4, c+12(FP)
RET

View File

@ -0,0 +1,70 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !purego
// +build !purego
#include "textflag.h"
// func addMulVVW1024(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1024(SB), $0-32
MOVD $16, R0
JMP addMulVVWx(SB)
// func addMulVVW1536(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1536(SB), $0-32
MOVD $24, R0
JMP addMulVVWx(SB)
// func addMulVVW2048(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW2048(SB), $0-32
MOVD $32, R0
JMP addMulVVWx(SB)
TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
MOVD z+0(FP), R1
MOVD x+8(FP), R2
MOVD y+16(FP), R3
MOVD $0, R4
// The main loop of this code operates on a block of 4 words every iteration
// performing [R4:R12:R11:R10:R9] = R4 + R3 * [R8:R7:R6:R5] + [R12:R11:R10:R9]
// where R4 is carried from the previous iteration, R8:R7:R6:R5 hold the next
// 4 words of x, R3 is y and R12:R11:R10:R9 are part of the result z.
loop:
CBZ R0, done
LDP.P 16(R2), (R5, R6)
LDP.P 16(R2), (R7, R8)
LDP (R1), (R9, R10)
ADDS R4, R9
MUL R6, R3, R14
ADCS R14, R10
MUL R7, R3, R15
LDP 16(R1), (R11, R12)
ADCS R15, R11
MUL R8, R3, R16
ADCS R16, R12
UMULH R8, R3, R20
ADC $0, R20
MUL R5, R3, R13
ADDS R13, R9
UMULH R5, R3, R17
ADCS R17, R10
UMULH R6, R3, R21
STP.P (R9, R10), 16(R1)
ADCS R21, R11
UMULH R7, R3, R19
ADCS R19, R12
STP.P (R11, R12), 16(R1)
ADC $0, R20, R4
SUB $4, R0
B loop
done:
MOVD R4, c+24(FP)
RET

View File

@ -0,0 +1,30 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !purego && (386 || amd64 || arm || arm64 || ppc64 || ppc64le || s390x)
// +build !purego
// +build 386 amd64 arm arm64 ppc64 ppc64le s390x
package bigmod
import "golang.org/x/sys/cpu"
// amd64 assembly uses ADCX/ADOX/MULX if ADX is available to run two carry
// chains in the flags in parallel across the whole operation, and aggressively
// unrolls loops. arm64 processes four words at a time.
//
// It's unclear why the assembly for all other architectures, as well as for
// amd64 without ADX, perform better than the compiler output.
// TODO(filippo): file cmd/compile performance issue.
var supportADX = cpu.X86.HasADX && cpu.X86.HasBMI2
//go:noescape
func addMulVVW1024(z, x *uint, y uint) (c uint)
//go:noescape
func addMulVVW1536(z, x *uint, y uint) (c uint)
//go:noescape
func addMulVVW2048(z, x *uint, y uint) (c uint)

View File

@ -1,12 +1,22 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !amd64 || !gc || purego
// +build !amd64 !gc purego
//go:build purego || !(386 || amd64 || arm || arm64 || ppc64 || ppc64le || s390x)
// +build !386,!amd64,!arm,!arm64,!ppc64,!ppc64le,!s390x purego
package bigmod
func montgomeryLoop(d, a, b, m []uint, m0inv uint) uint {
return montgomeryLoopGeneric(d, a, b, m, m0inv)
import "unsafe"
func addMulVVW1024(z, x *uint, y uint) (c uint) {
return addMulVVW(unsafe.Slice(z, 1024/_W), unsafe.Slice(x, 1024/_W), y)
}
func addMulVVW1536(z, x *uint, y uint) (c uint) {
return addMulVVW(unsafe.Slice(z, 1536/_W), unsafe.Slice(x, 1536/_W), y)
}
func addMulVVW2048(z, x *uint, y uint) (c uint) {
return addMulVVW(unsafe.Slice(z, 2048/_W), unsafe.Slice(x, 2048/_W), y)
}

View File

@ -0,0 +1,53 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !purego && (ppc64 || ppc64le)
// +build !purego
// +build ppc64 ppc64le
#include "textflag.h"
// func addMulVVW1024(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1024(SB), $0-32
MOVD $16, R22 // R22 = z_len
JMP addMulVVWx(SB)
// func addMulVVW1536(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1536(SB), $0-32
MOVD $24, R22 // R22 = z_len
JMP addMulVVWx(SB)
// func addMulVVW2048(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW2048(SB), $0-32
MOVD $32, R22 // R22 = z_len
JMP addMulVVWx(SB)
TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
MOVD z+0(FP), R10 // R10 = z[]
MOVD x+8(FP), R8 // R8 = x[]
MOVD y+16(FP), R9 // R9 = y
MOVD R0, R3 // R3 will be the index register
CMP R0, R22
MOVD R0, R4 // R4 = c = 0
MOVD R22, CTR // Initialize loop counter
BEQ done
PCALIGN $16
loop:
MOVD (R8)(R3), R20 // Load x[i]
MOVD (R10)(R3), R21 // Load z[i]
MULLD R9, R20, R6 // R6 = Low-order(x[i]*y)
MULHDU R9, R20, R7 // R7 = High-order(x[i]*y)
ADDC R21, R6 // R6 = z0
ADDZE R7 // R7 = z1
ADDC R4, R6 // R6 = z0 + c + 0
ADDZE R7, R4 // c += z1
MOVD R6, (R10)(R3) // Store z[i]
ADD $8, R3
BC 16, 0, loop // bdnz
done:
MOVD R4, c+24(FP)
RET

View File

@ -0,0 +1,86 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !purego
// +build !purego
#include "textflag.h"
// func addMulVVW1024(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1024(SB), $0-32
MOVD $16, R5
JMP addMulVVWx(SB)
// func addMulVVW1536(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1536(SB), $0-32
MOVD $24, R5
JMP addMulVVWx(SB)
// func addMulVVW2048(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW2048(SB), $0-32
MOVD $32, R5
JMP addMulVVWx(SB)
TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
MOVD z+0(FP), R2
MOVD x+8(FP), R8
MOVD y+16(FP), R9
MOVD $0, R1 // i*8 = 0
MOVD $0, R7 // i = 0
MOVD $0, R0 // make sure it's zero
MOVD $0, R4 // c = 0
MOVD R5, R12
AND $-2, R12
CMPBGE R5, $2, A6
BR E6
A6:
MOVD (R8)(R1*1), R6
MULHDU R9, R6
MOVD (R2)(R1*1), R10
ADDC R10, R11 // add to low order bits
ADDE R0, R6
ADDC R4, R11
ADDE R0, R6
MOVD R6, R4
MOVD R11, (R2)(R1*1)
MOVD (8)(R8)(R1*1), R6
MULHDU R9, R6
MOVD (8)(R2)(R1*1), R10
ADDC R10, R11 // add to low order bits
ADDE R0, R6
ADDC R4, R11
ADDE R0, R6
MOVD R6, R4
MOVD R11, (8)(R2)(R1*1)
ADD $16, R1 // i*8 + 8
ADD $2, R7 // i++
CMPBLT R7, R12, A6
BR E6
L6:
// TODO: drop unused single-step loop.
MOVD (R8)(R1*1), R6
MULHDU R9, R6
MOVD (R2)(R1*1), R10
ADDC R10, R11 // add to low order bits
ADDE R0, R6
ADDC R4, R11
ADDE R0, R6
MOVD R6, R4
MOVD R11, (R2)(R1*1)
ADD $8, R1 // i*8 + 8
ADD $1, R7 // i++
E6:
CMPBLT R7, R5, L6 // i < n
MOVD R4, c+24(FP)
RET

View File

@ -5,14 +5,24 @@
package bigmod
import (
"fmt"
"math/big"
"math/bits"
"math/rand"
"reflect"
"strings"
"testing"
"testing/quick"
)
func (n *Nat) String() string {
var limbs []string
for i := range n.limbs {
limbs = append(limbs, fmt.Sprintf("%016X", n.limbs[len(n.limbs)-1-i]))
}
return "{" + strings.Join(limbs, " ") + "}"
}
// Generate generates an even nat. It's used by testing/quick to produce random
// *nat values for quick.Check invocations.
func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
@ -54,21 +64,23 @@ func TestModSubThenAddIdentity(t *testing.T) {
}
}
func testMontgomeryRoundtrip(a *Nat) bool {
one := &Nat{make([]uint, len(a.limbs))}
one.limbs[0] = 1
aPlusOne := new(big.Int).SetBytes(natBytes(a))
aPlusOne.Add(aPlusOne, big.NewInt(1))
m := NewModulusFromBig(aPlusOne)
monty := new(Nat).Set(a)
monty.montgomeryRepresentation(m)
aAgain := new(Nat).Set(monty)
aAgain.montgomeryMul(monty, one, m)
return a.Equal(aAgain) == 1
}
func TestMontgomeryRoundtrip(t *testing.T) {
err := quick.Check(testMontgomeryRoundtrip, &quick.Config{})
err := quick.Check(func(a *Nat) bool {
one := &Nat{make([]uint, len(a.limbs))}
one.limbs[0] = 1
aPlusOne := new(big.Int).SetBytes(natBytes(a))
aPlusOne.Add(aPlusOne, big.NewInt(1))
m, _ := NewModulusFromBig(aPlusOne)
monty := new(Nat).Set(a)
monty.montgomeryRepresentation(m)
aAgain := new(Nat).Set(monty)
aAgain.montgomeryMul(monty, one, m)
if a.Equal(aAgain) != 1 {
t.Errorf("%v != %v", a, aAgain)
return false
}
return true
}, &quick.Config{})
if err != nil {
t.Error(err)
}
@ -84,30 +96,30 @@ func TestShiftIn(t *testing.T) {
}{{
m: []byte{13},
x: []byte{0},
y: 0x7FFF_FFFF_FFFF_FFFF,
expected: []byte{7},
y: 0xFFFF_FFFF_FFFF_FFFF,
expected: []byte{2},
}, {
m: []byte{13},
x: []byte{7},
y: 0x7FFF_FFFF_FFFF_FFFF,
expected: []byte{11},
y: 0xFFFF_FFFF_FFFF_FFFF,
expected: []byte{10},
}, {
m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
x: make([]byte, 9),
y: 0x7FFF_FFFF_FFFF_FFFF,
expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
y: 0xFFFF_FFFF_FFFF_FFFF,
expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
}, {
m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
x: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
x: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
y: 0,
expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08},
expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06},
}}
for i, tt := range examples {
m := modulusFromBytes(tt.m)
got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m)
if got.Equal(natFromBytes(tt.expected).ExpandFor(m)) != 1 {
t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
if exp := natFromBytes(tt.expected).ExpandFor(m); got.Equal(exp) != 1 {
t.Errorf("%d: got %v, expected %v", i, got, exp)
}
}
}
@ -186,7 +198,7 @@ func TestSetBytes(t *testing.T) {
continue
}
if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes {
t.Errorf("%d: got %x, expected %x", i, got, expected)
t.Errorf("%d: got %v, expected %v", i, got, expected)
}
}
@ -228,7 +240,7 @@ func TestExpand(t *testing.T) {
for i, tt := range examples {
got := (&Nat{tt.in}).expand(tt.n)
if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 {
t.Errorf("%d: got %x, expected %x", i, got, tt.out)
t.Errorf("%d: got %v, expected %v", i, got, tt.out)
}
}
}
@ -244,52 +256,6 @@ func TestMod(t *testing.T) {
}
}
func TestModNat(t *testing.T) {
order, _ := new(big.Int).SetString("b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25", 16)
orderNat := NewModulusFromBig(order)
oneNat, err := NewNat().SetBytes(big.NewInt(1).Bytes(), orderNat)
if err != nil {
t.Fatal(err)
}
orderMinus1, _ := new(big.Int).SetString("b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf24", 16)
hashValue1, _ := new(big.Int).SetString("1000000000000000a640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25", 16)
hashValue2, _ := new(big.Int).SetString("1000000000000000b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf23", 16)
examples := []struct {
in *big.Int
expected *big.Int
}{
{
big.NewInt(1),
big.NewInt(2),
},
{
orderMinus1,
big.NewInt(1),
},
{
order,
big.NewInt(2),
},
{
hashValue1,
new(big.Int).Add(new(big.Int).Mod(hashValue1, orderMinus1), big.NewInt(1)),
},
{
hashValue2,
new(big.Int).Add(new(big.Int).Mod(hashValue2, orderMinus1), big.NewInt(1)),
},
}
for i, tt := range examples {
kNat := NewNat().SetBig(tt.in)
kNat = NewNat().ModNat(kNat, NewNat().SetBig(orderMinus1))
kNat.Add(oneNat, orderNat)
out := new(big.Int).SetBytes(kNat.Bytes(orderNat))
if out.Cmp(tt.expected) != 0 {
t.Errorf("%d: got %x, expected %x", i, out, tt.expected)
}
}
}
func TestModSub(t *testing.T) {
m := modulusFromBytes([]byte{13})
x := &Nat{[]uint{6}}
@ -333,26 +299,68 @@ func TestExp(t *testing.T) {
}
}
func TestExpShort(t *testing.T) {
m := modulusFromBytes([]byte{13})
x := &Nat{[]uint{3}}
out := &Nat{[]uint{0}}
out.ExpShort(x, 12, m)
expected := &Nat{[]uint{1}}
if out.Equal(expected) != 1 {
t.Errorf("%+v != %+v", out, expected)
}
}
// TestMulReductions tests that Mul reduces results equal or slightly greater
// than the modulus. Some Montgomery algorithms don't and need extra care to
// return correct results. See https://go.dev/issue/13907.
func TestMulReductions(t *testing.T) {
// Two short but multi-limb primes.
a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10)
b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10)
n := new(big.Int).Mul(a, b)
N, _ := NewModulusFromBig(n)
A := NewNat().SetBig(a).ExpandFor(N)
B := NewNat().SetBig(b).ExpandFor(N)
if A.Mul(B, N).IsZero() != 1 {
t.Error("a * b mod (a * b) != 0")
}
i := new(big.Int).ModInverse(a, b)
N, _ = NewModulusFromBig(b)
A = NewNat().SetBig(a).ExpandFor(N)
I := NewNat().SetBig(i).ExpandFor(N)
one := NewNat().SetBig(big.NewInt(1)).ExpandFor(N)
if A.Mul(I, N).Equal(one) != 1 {
t.Error("a * inv(a) mod b != 1")
}
}
func natBytes(n *Nat) []byte {
return n.Bytes(maxModulus(uint(len(n.limbs))))
}
func natFromBytes(b []byte) *Nat {
// Must not use Nat.SetBytes as it's used in TestSetBytes.
bb := new(big.Int).SetBytes(b)
return NewNat().SetBig(bb)
}
func modulusFromBytes(b []byte) *Modulus {
bb := new(big.Int).SetBytes(b)
return NewModulusFromBig(bb)
m, _ := NewModulusFromBig(bb)
return m
}
// maxModulus returns the biggest modulus that can fit in n limbs.
func maxModulus(n uint) *Modulus {
m := big.NewInt(1)
m.Lsh(m, n*_W)
m.Sub(m, big.NewInt(1))
return NewModulusFromBig(m)
b := big.NewInt(1)
b.Lsh(b, n*_W)
b.Sub(b, big.NewInt(1))
m, _ := NewModulusFromBig(b)
return m
}
func makeBenchmarkModulus() *Modulus {
@ -362,7 +370,7 @@ func makeBenchmarkModulus() *Modulus {
func makeBenchmarkValue() *Nat {
x := make([]uint, 32)
for i := 0; i < 32; i++ {
x[i] = _MASK - 1
x[i]--
}
return &Nat{limbs: x}
}
@ -456,3 +464,17 @@ func BenchmarkExp(b *testing.B) {
out.Exp(x, e, m)
}
}
func TestNewModFromBigZero(t *testing.T) {
expected := "modulus must be >= 0"
_, err := NewModulusFromBig(big.NewInt(0))
if err == nil || err.Error() != expected {
t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected)
}
expected = "modulus must be odd"
_, err = NewModulusFromBig(big.NewInt(2))
if err == nil || err.Error() != expected {
t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected)
}
}

View File

@ -931,6 +931,6 @@ func p256() *sm2Curve {
func precomputeParams(c *sm2Curve, curve elliptic.Curve) {
params := curve.Params()
c.curve = curve
c.N = bigmod.NewModulusFromBig(params.N)
c.N, _ = bigmod.NewModulusFromBig(params.N)
c.nMinus2 = new(big.Int).Sub(params.N, big.NewInt(2)).Bytes()
}

View File

@ -20,7 +20,7 @@ import (
// SM9 ASN.1 format reference: Information security technology - SM9 cryptographic algorithm application specification
var orderNat = bigmod.NewModulusFromBig(bn256.Order)
var orderNat, _ = bigmod.NewModulusFromBig(bn256.Order)
var orderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes()
var bigOne = big.NewInt(1)
var bigOneNat *bigmod.Nat