Mirror of https://github.com/emmansun/gmsm.git, synced 2025-04-26 12:16:20 +08:00
[sync sdk] crypto/internal/bigmod: switch to saturated limbs
This commit is contained in: parent f7a04e74a1, commit f32b7e1afc
@@ -5,16 +5,17 @@
package bigmod

import (
    "encoding/binary"
    "errors"
    "math/big"
    "math/bits"
)

const (
    // _W is the number of bits we use for our limbs.
    _W = bits.UintSize - 1
    // _MASK selects _W bits from a full machine word.
    _MASK = (1 << _W) - 1
    // _W is the size in bits of our limbs.
    _W = bits.UintSize
    // _S is the size in bytes of our limbs.
    _S = _W / 8
)

// choice represents a constant-time boolean. The value of choice is always
@@ -27,15 +28,8 @@ func not(c choice) choice { return 1 ^ c }
const yes = choice(1)
const no = choice(0)

// ctSelect returns x if on == 1, and y if on == 0. The execution time of this
// function does not depend on its inputs. If on is any value besides 1 or 0,
// the result is undefined.
func ctSelect(on choice, x, y uint) uint {
    // When on == 1, mask is 0b111..., otherwise mask is 0b000...
    mask := -uint(on)
    // When mask is all zeros, we just have y, otherwise, y cancels with itself.
    return y ^ (mask & (y ^ x))
}
// ctMask is all 1s if on is yes, and all 0s otherwise.
func ctMask(on choice) uint { return -uint(on) }

// ctEq returns 1 if x == y, and 0 otherwise. The execution time of this
// function does not depend on its inputs.
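As an aside, the mask trick behind ctMask (and the removed ctSelect) is easiest to see with concrete values. A minimal, self-contained sketch, with illustrative names that are not part of the package:

package main

import "fmt"

// ctSelect returns x when on == 1 and y when on == 0, without branching:
// -on is all ones for on == 1 and all zeros for on == 0.
func ctSelect(on, x, y uint) uint {
    mask := -on
    return y ^ (mask & (y ^ x))
}

func main() {
    fmt.Println(ctSelect(1, 42, 7)) // 42
    fmt.Println(ctSelect(0, 42, 7)) // 7
}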
@@ -60,13 +54,7 @@ func ctGeq(x, y uint) choice {
// Operations on this number are allowed to leak this length, but will not leak
// any information about the values contained in those limbs.
type Nat struct {
    // limbs is a little-endian representation in base 2^W with
    // W = bits.UintSize - 1. The top bit is always unset between operations.
    //
    // The top bit is left unset to optimize Montgomery multiplication, in the
    // inner loop of exponentiation. Using fully saturated limbs would leave us
    // working with 129-bit numbers on 64-bit platforms, wasting a lot of space,
    // and thus time.
    // limbs is little-endian in base 2^W with W = bits.UintSize.
    limbs []uint
}

@@ -128,25 +116,10 @@ func (x *Nat) Set(y *Nat) *Nat {
// The announced length of x is set based on the actual bit size of the input,
// ignoring leading zeroes.
func (x *Nat) SetBig(n *big.Int) *Nat {
    requiredLimbs := (n.BitLen() + _W - 1) / _W
    x.reset(requiredLimbs)

    outI := 0
    shift := 0
    limbs := n.Bits()
    x.reset(len(limbs))
    for i := range limbs {
        xi := uint(limbs[i])
        x.limbs[outI] |= (xi << shift) & _MASK
        outI++
        if outI == requiredLimbs {
            return x
        }
        x.limbs[outI] = xi >> (_W - shift)
        shift++ // this assumes bits.UintSize - _W = 1
        if shift == _W {
            shift = 0
            outI++
        }
        x.limbs[i] = uint(limbs[i])
    }
    return x
}
@@ -156,24 +129,20 @@ func (x *Nat) SetBig(n *big.Int) *Nat {
//
// x must have the same size as m and it must be reduced modulo m.
func (x *Nat) Bytes(m *Modulus) []byte {
    bytes := make([]byte, m.Size())
    shift := 0
    outI := len(bytes) - 1
    i := m.Size()
    bytes := make([]byte, i)
    for _, limb := range x.limbs {
        remainingBits := _W
        for remainingBits >= 8 {
            bytes[outI] |= byte(limb) << shift
            consumed := 8 - shift
            limb >>= consumed
            remainingBits -= consumed
            shift = 0
            outI--
            if outI < 0 {
                return bytes
        for j := 0; j < _S; j++ {
            i--
            if i < 0 {
                if limb == 0 {
                    break
                }
                panic("bigmod: modulus is smaller than nat")
            }
            bytes[i] = byte(limb)
            limb >>= 8
        }
        bytes[outI] = byte(limb)
        shift = remainingBits
    }
    return bytes
}
@@ -192,9 +161,9 @@ func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
    return x, nil
}

// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes. SetOverflowingBytes
// returns an error if b has a longer bit length than m, but reduces overflowing
// values up to 2^⌈log2(m)⌉ - 1.
// SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes.
// SetOverflowingBytes returns an error if b has a longer bit length than m, but
// reduces overflowing values up to 2^⌈log2(m)⌉ - 1.
//
// The output will be resized to the size of m and overwritten.
func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
@@ -203,33 +172,35 @@ func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
    }
    leading := _W - bitLen(x.limbs[len(x.limbs)-1])
    if leading < m.leading {
        return nil, errors.New("input overflows the modulus")
        return nil, errors.New("input overflows the modulus size")
    }
    x.sub(x.cmpGeq(m.nat), m.nat)
    x.maybeSubtractModulus(no, m)
    return x, nil
}

// bigEndianUint returns the contents of buf interpreted as a
// big-endian encoded uint value.
func bigEndianUint(buf []byte) uint {
    if _W == 64 {
        return uint(binary.BigEndian.Uint64(buf))
    }
    return uint(binary.BigEndian.Uint32(buf))
}

func (x *Nat) setBytes(b []byte, m *Modulus) error {
    outI := 0
    shift := 0
    x.resetFor(m)
    for i := len(b) - 1; i >= 0; i-- {
        bi := b[i]
        x.limbs[outI] |= uint(bi) << shift
        shift += 8
        if shift >= _W {
            shift -= _W
            x.limbs[outI] &= _MASK
            overflow := bi >> (8 - shift)
            outI++
            if outI >= len(x.limbs) {
                if overflow > 0 || i > 0 {
                    return errors.New("input overflows the modulus")
                }
                break
            }
            x.limbs[outI] = uint(overflow)
        }
    i, k := len(b), 0
    for k < len(x.limbs) && i >= _S {
        x.limbs[k] = bigEndianUint(b[i-_S : i])
        i -= _S
        k++
    }
    for s := 0; s < _W && k < len(x.limbs) && i > 0; s += 8 {
        x.limbs[k] |= uint(b[i-1]) << s
        i--
    }
    if i > 0 {
        return errors.New("input overflows the modulus size")
    }
    return nil
}
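The new setBytes strategy is easier to follow in isolation: consume the big-endian input eight bytes at a time from the end, producing little-endian saturated limbs. A standalone sketch under the assumption of 64-bit limbs; the helper name is illustrative only:

package main

import (
    "encoding/binary"
    "fmt"
)

// limbsFromBytes decodes big-endian bytes into little-endian 64-bit limbs,
// whole words first, then the few leftover high-order bytes.
func limbsFromBytes(b []byte) []uint64 {
    limbs := make([]uint64, (len(b)+7)/8)
    i, k := len(b), 0
    for k < len(limbs) && i >= 8 {
        limbs[k] = binary.BigEndian.Uint64(b[i-8 : i])
        i -= 8
        k++
    }
    for s := 0; i > 0; s += 8 {
        limbs[k] |= uint64(b[i-1]) << s
        i--
    }
    return limbs
}

func main() {
    b := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09}
    fmt.Printf("%#x\n", limbsFromBytes(b)) // [0x203040506070809 0x1]
}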
@@ -274,7 +245,7 @@ func (x *Nat) cmpGeq(y *Nat) choice {

    var c uint
    for i := 0; i < size; i++ {
        c = (xLimbs[i] - yLimbs[i] - c) >> _W
        _, c = bits.Sub(xLimbs[i], yLimbs[i], c)
    }
    // If there was a carry, then subtracting y underflowed, so
    // x is not greater than or equal to y.
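With saturated limbs the comparison reduces to the final borrow of a limb-wise subtraction. A minimal sketch over uint64 limbs (illustrative only; the real cmpGeq returns a constant-time choice rather than a bool):

package main

import (
    "fmt"
    "math/bits"
)

// geq reports whether the little-endian limb value x is >= y:
// no final borrow means x - y did not underflow.
func geq(x, y []uint64) bool {
    var borrow uint64
    for i := range x {
        _, borrow = bits.Sub64(x[i], y[i], borrow)
    }
    return borrow == 0
}

func main() {
    fmt.Println(geq([]uint64{0, 1}, []uint64{^uint64(0), 0})) // true: 2^64 >= 2^64-1
    fmt.Println(geq([]uint64{5, 0}, []uint64{6, 0}))          // false
}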
@@ -290,44 +261,39 @@ func (x *Nat) assign(on choice, y *Nat) *Nat {
    xLimbs := x.limbs[:size]
    yLimbs := y.limbs[:size]

    mask := ctMask(on)
    for i := 0; i < size; i++ {
        xLimbs[i] = ctSelect(on, yLimbs[i], xLimbs[i])
        xLimbs[i] ^= mask & (xLimbs[i] ^ yLimbs[i])
    }
    return x
}

// add computes x += y if on == 1, and does nothing otherwise. It returns the
// carry of the addition regardless of on.
// add computes x += y and returns the carry.
//
// Both operands must have the same announced length.
func (x *Nat) add(on choice, y *Nat) (c uint) {
func (x *Nat) add(y *Nat) (c uint) {
    // Eliminate bounds checks in the loop.
    size := len(x.limbs)
    xLimbs := x.limbs[:size]
    yLimbs := y.limbs[:size]

    for i := 0; i < size; i++ {
        res := xLimbs[i] + yLimbs[i] + c
        xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
        c = res >> _W
        xLimbs[i], c = bits.Add(xLimbs[i], yLimbs[i], c)
    }
    return
}

// sub computes x -= y if on == 1, and does nothing otherwise. It returns the
// borrow of the subtraction regardless of on.
// sub computes x -= y. It returns the borrow of the subtraction.
//
// Both operands must have the same announced length.
func (x *Nat) sub(on choice, y *Nat) (c uint) {
func (x *Nat) sub(y *Nat) (c uint) {
    // Eliminate bounds checks in the loop.
    size := len(x.limbs)
    xLimbs := x.limbs[:size]
    yLimbs := y.limbs[:size]

    for i := 0; i < size; i++ {
        res := xLimbs[i] - yLimbs[i] - c
        xLimbs[i] = ctSelect(on, res&_MASK, xLimbs[i])
        c = res >> _W
        xLimbs[i], c = bits.Sub(xLimbs[i], yLimbs[i], c)
    }
    return
}
@@ -371,26 +337,32 @@ func minusInverseModW(x uint) uint {
    // Every iteration of this loop doubles the least-significant bits of
    // correct inverse in y. The first three bits are already correct (1⁻¹ = 1,
    // 3⁻¹ = 3, 5⁻¹ = 5, and 7⁻¹ = 7 mod 8), so doubling five times is enough
    // for 61 bits (and wastes only one iteration for 31 bits).
    // for 64 bits (and wastes only one iteration for 32 bits).
    //
    // See https://crypto.stackexchange.com/a/47496.
    y := x
    for i := 0; i < 5; i++ {
        y = y * (2 - x*y)
    }
    return (1 << _W) - (y & _MASK)
    return -y
}

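The Newton/Hensel iteration in minusInverseModW is easy to verify numerically. A self-contained sketch that assumes 64-bit words (the function name mirrors the diff but is written here standalone):

package main

import "fmt"

// minusInverseMod64 returns -x⁻¹ mod 2⁶⁴ for odd x. Each iteration doubles
// the number of correct low bits: 3 → 6 → 12 → 24 → 48 → 96 ≥ 64.
func minusInverseMod64(x uint64) uint64 {
    y := x // the low 3 bits start correct: 1, 3, 5, 7 are self-inverse mod 8
    for i := 0; i < 5; i++ {
        y = y * (2 - x*y)
    }
    return -y
}

func main() {
    x := uint64(0x10001)
    m0inv := minusInverseMod64(x)
    fmt.Printf("%#x\n", x*m0inv) // 0xffffffffffffffff, i.e. x * (-x⁻¹) ≡ -1 mod 2⁶⁴
}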
// NewModulusFromBig creates a new Modulus from a [big.Int].
//
// The Int must be odd. The number of significant bits must be leakable.
func NewModulusFromBig(n *big.Int) *Modulus {
// The Int must be odd. The number of significant bits (and nothing else) is
// leaked through timing side-channels.
func NewModulusFromBig(n *big.Int) (*Modulus, error) {
    if b := n.Bits(); len(b) == 0 {
        return nil, errors.New("modulus must be >= 0")
    } else if b[0]&1 != 1 {
        return nil, errors.New("modulus must be odd")
    }
    m := &Modulus{}
    m.nat = NewNat().SetBig(n)
    m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
    m.m0inv = minusInverseModW(m.nat.limbs[0])
    m.rr = rr(m)
    return m
    return m, nil
}

// bitLen is a version of bits.Len that only leaks the bit length of n, but not
@@ -447,25 +419,21 @@ func (x *Nat) shiftInNat(y uint, m *Nat) *Nat {
    //
    // To do the reduction, each iteration computes both 2x + b and 2x + b - m.
    // The next iteration (and finally the return line) will use either result
    // based on whether the subtraction underflowed.
    // based on whether 2x + b overflows m.
    needSubtraction := no
    for i := _W - 1; i >= 0; i-- {
        carry := (y >> i) & 1
        var borrow uint
        mask := ctMask(needSubtraction)
        for i := 0; i < size; i++ {
            l := ctSelect(needSubtraction, dLimbs[i], xLimbs[i])

            res := l<<1 + carry
            xLimbs[i] = res & _MASK
            carry = res >> _W

            res = xLimbs[i] - mLimbs[i] - borrow
            dLimbs[i] = res & _MASK
            borrow = res >> _W
            l := xLimbs[i] ^ (mask & (xLimbs[i] ^ dLimbs[i]))
            xLimbs[i], carry = bits.Add(l, l, carry)
            dLimbs[i], borrow = bits.Sub(xLimbs[i], mLimbs[i], borrow)
        }
        // See Add for how carry (aka overflow), borrow (aka underflow), and
        // needSubtraction relate.
        needSubtraction = ctEq(carry, borrow)
        // Like in maybeSubtractModulus, we need the subtraction if either it
        // didn't underflow (meaning 2x + b > m) or if computing 2x + b
        // overflowed (meaning 2x + b > 2^_W*n > m).
        needSubtraction = not(choice(borrow)) | choice(carry)
    }
    return x.assign(needSubtraction, d)
}
@@ -524,14 +492,34 @@ func (out *Nat) resetFor(m *Modulus) *Nat {
    return out.reset(len(m.nat.limbs))
}

// maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes.
//
// It can be used to reduce modulo m a value up to 2m - 1, which is a common
// range for results computed by higher level operations.
//
// always is usually a carry that indicates that the operation that produced x
// overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.
//
// x and m operands must have the same announced length.
func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
    t := NewNat().Set(x)
    underflow := t.sub(m.nat)
    // We keep the result if x - m didn't underflow (meaning x >= m)
    // or if always was set.
    keep := not(choice(underflow)) | choice(always)
    x.assign(keep, t)
}
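maybeSubtractModulus is the reduction workhorse for values in [0, 2m). A minimal standalone sketch of the same pattern over uint64 limbs (illustrative, not the package API):

package main

import (
    "fmt"
    "math/bits"
)

// reduceOnce conditionally subtracts m from x (little-endian limbs), keeping
// the subtraction only when it did not borrow, i.e. when x >= m.
func reduceOnce(x, m []uint64) {
    t := make([]uint64, len(x))
    var borrow uint64
    for i := range x {
        t[i], borrow = bits.Sub64(x[i], m[i], borrow)
    }
    keep := -(1 - borrow) // all ones when there was no borrow
    for i := range x {
        x[i] ^= keep & (x[i] ^ t[i])
    }
}

func main() {
    x := []uint64{20}
    m := []uint64{13}
    reduceOnce(x, m)
    fmt.Println(x[0]) // 7
}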
// Sub computes x = x - y mod m.
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
    underflow := x.sub(yes, y)
    underflow := x.sub(y)
    // If the subtraction underflowed, add m.
    x.add(choice(underflow), m.nat)
    t := NewNat().Set(x)
    t.add(m.nat)
    x.assign(choice(underflow), t)
    return x
}

@@ -540,34 +528,8 @@ func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
    overflow := x.add(yes, y)
    underflow := not(x.cmpGeq(m.nat)) // x < m

    // Three cases are possible:
    //
    // - overflow = 0, underflow = 0
    //
    //   In this case, addition fits in our limbs, but we can still subtract away
    //   m without an underflow, so we need to perform the subtraction to reduce
    //   our result.
    //
    // - overflow = 0, underflow = 1
    //
    //   The addition fits in our limbs, but we can't subtract m without
    //   underflowing. The result is already reduced.
    //
    // - overflow = 1, underflow = 1
    //
    //   The addition does not fit in our limbs, and the subtraction's borrow
    //   would cancel out with the addition's carry. We need to subtract m to
    //   reduce our result.
    //
    // The overflow = 1, underflow = 0 case is not possible, because y is at
    // most m - 1, and if adding m - 1 overflows, then subtracting m must
    // necessarily underflow.
    needSubtraction := ctEq(overflow, uint(underflow))

    x.sub(needSubtraction, m.nat)
    overflow := x.add(y)
    x.maybeSubtractModulus(choice(overflow), m)
    return x
}

@@ -581,7 +543,7 @@ func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat {
    // A Montgomery multiplication (which computes a * b / R) by R * R works out
    // to a multiplication by R, which takes the value out of the Montgomery domain.
    return x.montgomeryMul(NewNat().Set(x), m.rr, m)
    return x.montgomeryMul(x, m.rr, m)
}

// montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
@@ -592,77 +554,157 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
    // By Montgomery multiplying with 1 not in Montgomery representation, we
    // convert out back from Montgomery representation, because it works out to
    // dividing by R.
    t0 := NewNat().Set(x)
    t1 := NewNat().ExpandFor(m)
    t1.limbs[0] = 1
    return x.montgomeryMul(t0, t1, m)
    one := NewNat().ExpandFor(m)
    one.limbs[0] = 1
    return x.montgomeryMul(x, one, m)
}

// montgomeryMul calculates d = a * b / R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs), using the Montgomery Multiplication technique.
// montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and
// n = len(m.nat.limbs), also known as a Montgomery multiplication.
//
// All inputs should be the same length, not aliasing d, and already
// reduced modulo m. d will be resized to the size of m and overwritten.
func (d *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
    d.resetFor(m)
    if len(a.limbs) != len(m.nat.limbs) || len(b.limbs) != len(m.nat.limbs) {
        panic("bigmod: invalid montgomeryMul input")
    }
// All inputs should be the same length and already reduced modulo m.
// x will be resized to the size of m and overwritten.
func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
    n := len(m.nat.limbs)
    mLimbs := m.nat.limbs[:n]
    aLimbs := a.limbs[:n]
    bLimbs := b.limbs[:n]

    // See https://bearssl.org/bigint.html#montgomery-reduction-and-multiplication
    // for a description of the algorithm implemented mostly in montgomeryLoop.
    // See Add for how overflow, underflow, and needSubtraction relate.
    overflow := montgomeryLoop(d.limbs, a.limbs, b.limbs, m.nat.limbs, m.m0inv)
    underflow := not(d.cmpGeq(m.nat)) // d < m
    needSubtraction := ctEq(overflow, uint(underflow))
    d.sub(needSubtraction, m.nat)
    switch n {
    default:
        // Attempt to use a stack-allocated backing array.
        T := make([]uint, 0, preallocLimbs*2)
        if cap(T) < n*2 {
            T = make([]uint, 0, n*2)
        }
        T = T[:n*2]

    return d
}
        // This loop implements Word-by-Word Montgomery Multiplication, as
        // described in Algorithm 4 (Fig. 3) of "Efficient Software
        // Implementations of Modular Exponentiation" by Shay Gueron
        // [https://eprint.iacr.org/2011/239.pdf].
        var c uint
        for i := 0; i < n; i++ {
            _ = T[n+i] // bounds check elimination hint

func montgomeryLoopGeneric(d, a, b, m []uint, m0inv uint) (overflow uint) {
    // Eliminate bounds checks in the loop.
    size := len(d)
    a = a[:size]
    b = b[:size]
    m = m[:size]
            // Step 1 (T = a × b) is computed as a large pen-and-paper column
            // multiplication of two numbers with n base-2^_W digits. If we just
            // wanted to produce 2n-wide T, we would do
            //
            //   for i := 0; i < n; i++ {
            //       d := bLimbs[i]
            //       T[n+i] = addMulVVW(T[i:n+i], aLimbs, d)
            //   }
            //
            // where d is a digit of the multiplier, T[i:n+i] is the shifted
            // position of the product of that digit, and T[n+i] is the final carry.
            // Note that T[i] isn't modified after processing the i-th digit.
            //
            // Instead of running two loops, one for Step 1 and one for Steps 2–6,
            // the result of Step 1 is computed during the next loop. This is
            // possible because each iteration only uses T[i] in Step 2 and then
            // discards it in Step 6.
            d := bLimbs[i]
            c1 := addMulVVW(T[i:n+i], aLimbs, d)

    for _, ai := range a {
        // This is an unrolled iteration of the loop below with j = 0.
        hi, lo := bits.Mul(ai, b[0])
        z_lo, c := bits.Add(d[0], lo, 0)
        f := (z_lo * m0inv) & _MASK // (d[0] + a[i] * b[0]) * m0inv
        z_hi, _ := bits.Add(0, hi, c)
        hi, lo = bits.Mul(f, m[0])
        z_lo, c = bits.Add(z_lo, lo, 0)
        z_hi, _ = bits.Add(z_hi, hi, c)
        carry := z_hi<<1 | z_lo>>_W
            // Step 6 is replaced by shifting the virtual window we operate
            // over: T of the algorithm is T[i:] for us. That means that T1 in
            // Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv.
            Y := T[i] * m.m0inv

        for j := 1; j < size; j++ {
            // z = d[j] + a[i] * b[j] + f * m[j] + carry <= 2^(2W+1) - 2^(W+1) + 2^W
            hi, lo := bits.Mul(ai, b[j])
            z_lo, c := bits.Add(d[j], lo, 0)
            z_hi, _ := bits.Add(0, hi, c)
            hi, lo = bits.Mul(f, m[j])
            z_lo, c = bits.Add(z_lo, lo, 0)
            z_hi, _ = bits.Add(z_hi, hi, c)
            z_lo, c = bits.Add(z_lo, carry, 0)
            z_hi, _ = bits.Add(z_hi, 0, c)
            d[j-1] = z_lo & _MASK
            carry = z_hi<<1 | z_lo>>_W // carry <= 2^(W+1) - 2
            // Step 4 and 5 add Y × m to T, which as mentioned above is stored
            // at T[i:]. The two carries (from a × d and Y × m) are added up in
            // the next word T[n+i], and the carry bit from that addition is
            // brought forward to the next iteration.
            c2 := addMulVVW(T[i:n+i], mLimbs, Y)
            T[n+i], c = bits.Add(c1, c2, c)
        }

        z := overflow + carry // z <= 2^(W+1) - 1
        d[size-1] = z & _MASK
        overflow = z >> _W // overflow <= 1
        // Finally for Step 7 we copy the final T window into x, and subtract m
        // if necessary (which as explained in maybeSubtractModulus can be the
        // case both if x >= m, or if x overflowed).
        //
        // The paper suggests in Section 4 that we can do an "Almost Montgomery
        // Multiplication" by subtracting only in the overflow case, but the
        // cost is very similar since the constant time subtraction tells us if
        // x >= m as a side effect, and taking care of the broken invariant is
        // highly undesirable (see https://go.dev/issue/13907).
        copy(x.reset(n).limbs, T[n:])
        x.maybeSubtractModulus(choice(c), m)

    // The following specialized cases follow the exact same algorithm, but
    // optimized for the sizes most used in RSA. addMulVVW is implemented in
    // assembly with loop unrolling depending on the architecture and bounds
    // checks are removed by the compiler thanks to the constant size.
    case 1024 / _W:
        const n = 1024 / _W // compiler hint
        T := make([]uint, n*2)
        var c uint
        for i := 0; i < n; i++ {
            d := bLimbs[i]
            c1 := addMulVVW1024(&T[i], &aLimbs[0], d)
            Y := T[i] * m.m0inv
            c2 := addMulVVW1024(&T[i], &mLimbs[0], Y)
            T[n+i], c = bits.Add(c1, c2, c)
        }
        copy(x.reset(n).limbs, T[n:])
        x.maybeSubtractModulus(choice(c), m)

    case 1536 / _W:
        const n = 1536 / _W // compiler hint
        T := make([]uint, n*2)
        var c uint
        for i := 0; i < n; i++ {
            d := bLimbs[i]
            c1 := addMulVVW1536(&T[i], &aLimbs[0], d)
            Y := T[i] * m.m0inv
            c2 := addMulVVW1536(&T[i], &mLimbs[0], Y)
            T[n+i], c = bits.Add(c1, c2, c)
        }
        copy(x.reset(n).limbs, T[n:])
        x.maybeSubtractModulus(choice(c), m)

    case 2048 / _W:
        const n = 2048 / _W // compiler hint
        T := make([]uint, n*2)
        var c uint
        for i := 0; i < n; i++ {
            d := bLimbs[i]
            c1 := addMulVVW2048(&T[i], &aLimbs[0], d)
            Y := T[i] * m.m0inv
            c2 := addMulVVW2048(&T[i], &mLimbs[0], Y)
            T[n+i], c = bits.Add(c1, c2, c)
        }
        copy(x.reset(n).limbs, T[n:])
        x.maybeSubtractModulus(choice(c), m)
    }
    return

    return x
}
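To see the word-by-word algorithm do its job end to end, here is a hypothetical one-limb Montgomery multiplication (R = 2^64) with a round trip through the Montgomery domain; the branchy final subtraction is for clarity only, since the code above keeps it constant time, and the chosen modulus is purely illustrative:

package main

import (
    "fmt"
    "math/bits"
)

// montMul returns a*b*R⁻¹ mod m for odd m < 2^63, with R = 2^64 and
// m0inv = -m⁻¹ mod 2^64.
func montMul(a, b, m, m0inv uint64) uint64 {
    hi, lo := bits.Mul64(a, b)
    f := lo * m0inv // chosen so that lo + f*m ≡ 0 (mod 2^64)
    mHi, mLo := bits.Mul64(f, m)
    _, c := bits.Add64(lo, mLo, 0) // the low word becomes zero; keep the carry
    t, c2 := bits.Add64(hi, mHi, c)
    if c2 == 1 || t >= m { // result is < 2m, so one subtraction suffices
        t -= m
    }
    return t
}

func main() {
    const m = 1000003 // an odd modulus
    y := uint64(m)    // m0inv via the same Newton iteration as minusInverseModW
    for i := 0; i < 5; i++ {
        y = y * (2 - m*y)
    }
    m0inv := -y

    r := (uint64(1) << 63 % m) * 2 % m // R mod m
    rr := r * r % m                    // R² mod m, the equivalent of m.rr

    a, b := uint64(123456), uint64(654321)
    aR := montMul(a, rr, m, m0inv)   // into Montgomery form: a*R mod m
    bR := montMul(b, rr, m, m0inv)   // b*R mod m
    abR := montMul(aR, bR, m, m0inv) // (a*b)*R mod m
    ab := montMul(abR, 1, m, m0inv)  // back out of Montgomery form
    fmt.Println(ab == a*b%m)         // true
}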
// Mul calculates x *= y mod m.
// addMulVVW multiplies the multi-word value x by the single-word value y,
// adding the result to the multi-word value z and returning the final carry.
// It can be thought of as one row of a pen-and-paper column multiplication.
func addMulVVW(z, x []uint, y uint) (carry uint) {
    _ = x[len(z)-1] // bounds check elimination hint
    for i := range z {
        hi, lo := bits.Mul(x[i], y)
        lo, c := bits.Add(lo, z[i], 0)
        // We use bits.Add with zero to get an add-with-carry instruction that
        // absorbs the carry from the previous bits.Add.
        hi, _ = bits.Add(hi, 0, c)
        lo, c = bits.Add(lo, carry, 0)
        hi, _ = bits.Add(hi, 0, c)
        carry = hi
        z[i] = lo
    }
    return carry
}
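A quick usage sketch of the row-multiplication idea behind addMulVVW, with uint64 limbs and values small enough to check by hand:

package main

import (
    "fmt"
    "math/bits"
)

// addMulVVW adds x*y into z (little-endian limbs) and returns the carry out
// of the top limb, i.e. one row of the column multiplication.
func addMulVVW(z, x []uint64, y uint64) (carry uint64) {
    for i := range z {
        hi, lo := bits.Mul64(x[i], y)
        lo, c := bits.Add64(lo, z[i], 0)
        hi, _ = bits.Add64(hi, 0, c)
        lo, c = bits.Add64(lo, carry, 0)
        hi, _ = bits.Add64(hi, 0, c)
        carry = hi
        z[i] = lo
    }
    return carry
}

func main() {
    // z = 1, x = 2^64 + 5, y = 3: z + x*y = 3*2^64 + 16, so limbs [16, 3], carry 0.
    z := []uint64{1, 0}
    x := []uint64{5, 1}
    fmt.Println(addMulVVW(z, x, 3), z) // 0 [16 3]
}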
// Mul calculates x = x * y mod m.
//
// x and y must already be reduced modulo m, they must share its announced
// length, and they may not alias.
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
    // A Montgomery multiplication by a value out of the Montgomery domain
    // takes the result out of Montgomery representation.
@@ -677,7 +719,8 @@ func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
    // We use a 4 bit window. For our RSA workload, 4 bit windows are faster
    // than 2 bit windows, but use an extra 12 nats worth of scratch space.
    // Using bit sizes that don't divide 8 are more complex to implement.
    // Using bit sizes that don't divide 8 are more complex to implement, but
    // are likely to be more efficient if necessary.

    table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1)
        // newNat calls are unrolled so they are allocated on the stack.
@@ -693,27 +736,51 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
    out.resetFor(m)
    out.limbs[0] = 1
    out.montgomeryRepresentation(m)
    t0 := NewNat().ExpandFor(m)
    t1 := NewNat().ExpandFor(m)
    tmp := NewNat().ExpandFor(m)
    for _, b := range e {
        for _, j := range []int{4, 0} {
            // Square four times.
            t1.montgomeryMul(out, out, m)
            out.montgomeryMul(t1, t1, m)
            t1.montgomeryMul(out, out, m)
            out.montgomeryMul(t1, t1, m)
            // Square four times. Optimization note: this can be implemented
            // more efficiently than with generic Montgomery multiplication.
            out.montgomeryMul(out, out, m)
            out.montgomeryMul(out, out, m)
            out.montgomeryMul(out, out, m)
            out.montgomeryMul(out, out, m)

            // Select x^k in constant time from the table.
            k := uint((b >> j) & 0b1111)
            for i := range table {
                t0.assign(ctEq(k, uint(i+1)), table[i])
                tmp.assign(ctEq(k, uint(i+1)), table[i])
            }

            // Multiply by x^k, discarding the result if k = 0.
            t1.montgomeryMul(out, t0, m)
            out.assign(not(ctEq(k, 0)), t1)
            tmp.montgomeryMul(out, tmp, m)
            out.assign(not(ctEq(k, 0)), tmp)
        }
    }

    return out.montgomeryReduction(m)
}
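The 4-bit window schedule used by Exp (square four times per nibble, then multiply by a table entry) can be sketched with plain modular arithmetic. This toy version is not constant time and assumes the modulus fits in 32 bits so products do not overflow:

package main

import "fmt"

// expWindowed computes x^e mod m with a 4-bit window, scanning the exponent
// from the most significant nibble down, like Exp does byte by byte.
func expWindowed(x, e, m uint64) uint64 {
    var table [15]uint64 // table[i] = x^(i+1) mod m
    table[0] = x % m
    for i := 1; i < 15; i++ {
        table[i] = table[i-1] * x % m
    }
    result := uint64(1)
    for shift := 60; shift >= 0; shift -= 4 {
        for i := 0; i < 4; i++ { // square four times
            result = result * result % m
        }
        if k := (e >> shift) & 0xF; k != 0 { // multiply by x^k, skipping k = 0
            result = result * table[k-1] % m
        }
    }
    return result
}

func main() {
    fmt.Println(expWindowed(3, 12, 13)) // 1, since 3^12 ≡ 1 mod 13 by Fermat
}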
// ExpShort calculates out = x^e mod m.
//
// The output will be resized to the size of m and overwritten. x must already
// be reduced modulo m. This leaks the exact bit size of the exponent.
func (out *Nat) ExpShort(x *Nat, e uint, m *Modulus) *Nat {
    xR := NewNat().Set(x).montgomeryRepresentation(m)

    out.resetFor(m)
    out.limbs[0] = 1
    out.montgomeryRepresentation(m)

    // For short exponents, precomputing a table and using a window like in Exp
    // doesn't pay off. Instead, we do a simple constant-time conditional
    // square-and-multiply chain, skipping the initial run of zeroes.
    tmp := NewNat().ExpandFor(m)
    for i := bits.UintSize - bitLen(e); i < bits.UintSize; i++ {
        out.montgomeryMul(out, out, m)
        k := (e >> (bits.UintSize - i - 1)) & 1
        tmp.montgomeryMul(out, xR, m)
        out.assign(ctEq(k, 1), tmp)
    }
    return out.montgomeryReduction(m)
}
internal/bigmod/nat_386.s  48  Normal file
@@ -0,0 +1,48 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !purego
// +build !purego

#include "textflag.h"

// func addMulVVW1024(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1024(SB), $0-16
    MOVL $32, BX
    JMP addMulVVWx(SB)

// func addMulVVW1536(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1536(SB), $0-16
    MOVL $48, BX
    JMP addMulVVWx(SB)

// func addMulVVW2048(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW2048(SB), $0-16
    MOVL $64, BX
    JMP addMulVVWx(SB)

TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
    MOVL z+0(FP), DI
    MOVL x+4(FP), SI
    MOVL y+8(FP), BP
    LEAL (DI)(BX*4), DI
    LEAL (SI)(BX*4), SI
    NEGL BX // i = -n
    MOVL $0, CX // c = 0
    JMP E6

L6: MOVL (SI)(BX*4), AX
    MULL BP
    ADDL CX, AX
    ADCL $0, DX
    ADDL AX, (DI)(BX*4)
    ADCL $0, DX
    MOVL DX, CX
    ADDL $1, BX // i++

E6: CMPL BX, $0 // i < 0
    JL L6

    MOVL CX, c+12(FP)
    RET
@@ -1,7 +0,0 @@
//go:build amd64 && gc && !purego
// +build amd64,gc,!purego

package bigmod

//go:noescape
func montgomeryLoop(d []uint, a []uint, b []uint, m []uint, m0inv uint) uint

File diff suppressed because it is too large
internal/bigmod/nat_arm.s  48  Normal file
@@ -0,0 +1,48 @@
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !purego
// +build !purego

#include "textflag.h"

// func addMulVVW1024(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1024(SB), $0-16
    MOVW $32, R5
    JMP addMulVVWx(SB)

// func addMulVVW1536(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1536(SB), $0-16
    MOVW $48, R5
    JMP addMulVVWx(SB)

// func addMulVVW2048(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW2048(SB), $0-16
    MOVW $64, R5
    JMP addMulVVWx(SB)

TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
    MOVW $0, R0
    MOVW z+0(FP), R1
    MOVW x+4(FP), R2
    MOVW y+8(FP), R3
    ADD R5<<2, R1, R5
    MOVW $0, R4
    B E9

L9: MOVW.P 4(R2), R6
    MULLU R6, R3, (R7, R6)
    ADD.S R4, R6
    ADC R0, R7
    MOVW 0(R1), R4
    ADD.S R4, R6
    ADC R0, R7
    MOVW.P R6, 4(R1)
    MOVW R7, R4

E9: TEQ R1, R5
    BNE L9

    MOVW R4, c+12(FP)
    RET
internal/bigmod/nat_arm64.s  70  Normal file
@@ -0,0 +1,70 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !purego
// +build !purego

#include "textflag.h"

// func addMulVVW1024(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1024(SB), $0-32
    MOVD $16, R0
    JMP addMulVVWx(SB)

// func addMulVVW1536(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1536(SB), $0-32
    MOVD $24, R0
    JMP addMulVVWx(SB)

// func addMulVVW2048(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW2048(SB), $0-32
    MOVD $32, R0
    JMP addMulVVWx(SB)

TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
    MOVD z+0(FP), R1
    MOVD x+8(FP), R2
    MOVD y+16(FP), R3
    MOVD $0, R4

    // The main loop of this code operates on a block of 4 words every iteration
    // performing [R4:R12:R11:R10:R9] = R4 + R3 * [R8:R7:R6:R5] + [R12:R11:R10:R9]
    // where R4 is carried from the previous iteration, R8:R7:R6:R5 hold the next
    // 4 words of x, R3 is y and R12:R11:R10:R9 are part of the result z.
loop:
    CBZ R0, done

    LDP.P 16(R2), (R5, R6)
    LDP.P 16(R2), (R7, R8)

    LDP (R1), (R9, R10)
    ADDS R4, R9
    MUL R6, R3, R14
    ADCS R14, R10
    MUL R7, R3, R15
    LDP 16(R1), (R11, R12)
    ADCS R15, R11
    MUL R8, R3, R16
    ADCS R16, R12
    UMULH R8, R3, R20
    ADC $0, R20

    MUL R5, R3, R13
    ADDS R13, R9
    UMULH R5, R3, R17
    ADCS R17, R10
    UMULH R6, R3, R21
    STP.P (R9, R10), 16(R1)
    ADCS R21, R11
    UMULH R7, R3, R19
    ADCS R19, R12
    STP.P (R11, R12), 16(R1)
    ADC $0, R20, R4

    SUB $4, R0
    B loop

done:
    MOVD R4, c+24(FP)
    RET
internal/bigmod/nat_asm.go  30  Normal file
@@ -0,0 +1,30 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !purego && (386 || amd64 || arm || arm64 || ppc64 || ppc64le || s390x)
// +build !purego
// +build 386 amd64 arm arm64 ppc64 ppc64le s390x

package bigmod

import "golang.org/x/sys/cpu"

// amd64 assembly uses ADCX/ADOX/MULX if ADX is available to run two carry
// chains in the flags in parallel across the whole operation, and aggressively
// unrolls loops. arm64 processes four words at a time.
//
// It's unclear why the assembly for all other architectures, as well as for
// amd64 without ADX, perform better than the compiler output.
// TODO(filippo): file cmd/compile performance issue.

var supportADX = cpu.X86.HasADX && cpu.X86.HasBMI2

//go:noescape
func addMulVVW1024(z, x *uint, y uint) (c uint)

//go:noescape
func addMulVVW1536(z, x *uint, y uint) (c uint)

//go:noescape
func addMulVVW2048(z, x *uint, y uint) (c uint)
@@ -1,12 +1,22 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !amd64 || !gc || purego
// +build !amd64 !gc purego
//go:build purego || !(386 || amd64 || arm || arm64 || ppc64 || ppc64le || s390x)
// +build !386,!amd64,!arm,!arm64,!ppc64,!ppc64le,!s390x purego

package bigmod

func montgomeryLoop(d, a, b, m []uint, m0inv uint) uint {
    return montgomeryLoopGeneric(d, a, b, m, m0inv)
import "unsafe"

func addMulVVW1024(z, x *uint, y uint) (c uint) {
    return addMulVVW(unsafe.Slice(z, 1024/_W), unsafe.Slice(x, 1024/_W), y)
}

func addMulVVW1536(z, x *uint, y uint) (c uint) {
    return addMulVVW(unsafe.Slice(z, 1536/_W), unsafe.Slice(x, 1536/_W), y)
}

func addMulVVW2048(z, x *uint, y uint) (c uint) {
    return addMulVVW(unsafe.Slice(z, 2048/_W), unsafe.Slice(x, 2048/_W), y)
}
internal/bigmod/nat_ppc64x.s  53  Normal file
@@ -0,0 +1,53 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !purego && (ppc64 || ppc64le)
// +build !purego
// +build ppc64 ppc64le

#include "textflag.h"

// func addMulVVW1024(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1024(SB), $0-32
    MOVD $16, R22 // R22 = z_len
    JMP addMulVVWx(SB)

// func addMulVVW1536(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1536(SB), $0-32
    MOVD $24, R22 // R22 = z_len
    JMP addMulVVWx(SB)

// func addMulVVW2048(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW2048(SB), $0-32
    MOVD $32, R22 // R22 = z_len
    JMP addMulVVWx(SB)

TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
    MOVD z+0(FP), R10 // R10 = z[]
    MOVD x+8(FP), R8 // R8 = x[]
    MOVD y+16(FP), R9 // R9 = y

    MOVD R0, R3 // R3 will be the index register
    CMP R0, R22
    MOVD R0, R4 // R4 = c = 0
    MOVD R22, CTR // Initialize loop counter
    BEQ done
    PCALIGN $16

loop:
    MOVD (R8)(R3), R20 // Load x[i]
    MOVD (R10)(R3), R21 // Load z[i]
    MULLD R9, R20, R6 // R6 = Low-order(x[i]*y)
    MULHDU R9, R20, R7 // R7 = High-order(x[i]*y)
    ADDC R21, R6 // R6 = z0
    ADDZE R7 // R7 = z1
    ADDC R4, R6 // R6 = z0 + c + 0
    ADDZE R7, R4 // c += z1
    MOVD R6, (R10)(R3) // Store z[i]
    ADD $8, R3
    BC 16, 0, loop // bdnz

done:
    MOVD R4, c+24(FP)
    RET
internal/bigmod/nat_s390x.s  86  Normal file
@@ -0,0 +1,86 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !purego
// +build !purego

#include "textflag.h"

// func addMulVVW1024(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1024(SB), $0-32
    MOVD $16, R5
    JMP addMulVVWx(SB)

// func addMulVVW1536(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1536(SB), $0-32
    MOVD $24, R5
    JMP addMulVVWx(SB)

// func addMulVVW2048(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW2048(SB), $0-32
    MOVD $32, R5
    JMP addMulVVWx(SB)

TEXT addMulVVWx(SB), NOFRAME|NOSPLIT, $0
    MOVD z+0(FP), R2
    MOVD x+8(FP), R8
    MOVD y+16(FP), R9

    MOVD $0, R1 // i*8 = 0
    MOVD $0, R7 // i = 0
    MOVD $0, R0 // make sure it's zero
    MOVD $0, R4 // c = 0

    MOVD R5, R12
    AND $-2, R12
    CMPBGE R5, $2, A6
    BR E6

A6:
    MOVD (R8)(R1*1), R6
    MULHDU R9, R6
    MOVD (R2)(R1*1), R10
    ADDC R10, R11 // add to low order bits
    ADDE R0, R6
    ADDC R4, R11
    ADDE R0, R6
    MOVD R6, R4
    MOVD R11, (R2)(R1*1)

    MOVD (8)(R8)(R1*1), R6
    MULHDU R9, R6
    MOVD (8)(R2)(R1*1), R10
    ADDC R10, R11 // add to low order bits
    ADDE R0, R6
    ADDC R4, R11
    ADDE R0, R6
    MOVD R6, R4
    MOVD R11, (8)(R2)(R1*1)

    ADD $16, R1 // i*8 + 8
    ADD $2, R7 // i++

    CMPBLT R7, R12, A6
    BR E6

L6:
    // TODO: drop unused single-step loop.
    MOVD (R8)(R1*1), R6
    MULHDU R9, R6
    MOVD (R2)(R1*1), R10
    ADDC R10, R11 // add to low order bits
    ADDE R0, R6
    ADDC R4, R11
    ADDE R0, R6
    MOVD R6, R4
    MOVD R11, (R2)(R1*1)

    ADD $8, R1 // i*8 + 8
    ADD $1, R7 // i++

E6:
    CMPBLT R7, R5, L6 // i < n

    MOVD R4, c+24(FP)
    RET
@@ -5,14 +5,24 @@
package bigmod

import (
    "fmt"
    "math/big"
    "math/bits"
    "math/rand"
    "reflect"
    "strings"
    "testing"
    "testing/quick"
)

func (n *Nat) String() string {
    var limbs []string
    for i := range n.limbs {
        limbs = append(limbs, fmt.Sprintf("%016X", n.limbs[len(n.limbs)-1-i]))
    }
    return "{" + strings.Join(limbs, " ") + "}"
}

// Generate generates an even nat. It's used by testing/quick to produce random
// *nat values for quick.Check invocations.
func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
@@ -54,21 +64,23 @@ func TestModSubThenAddIdentity(t *testing.T) {
    }
}

func testMontgomeryRoundtrip(a *Nat) bool {
    one := &Nat{make([]uint, len(a.limbs))}
    one.limbs[0] = 1
    aPlusOne := new(big.Int).SetBytes(natBytes(a))
    aPlusOne.Add(aPlusOne, big.NewInt(1))
    m := NewModulusFromBig(aPlusOne)
    monty := new(Nat).Set(a)
    monty.montgomeryRepresentation(m)
    aAgain := new(Nat).Set(monty)
    aAgain.montgomeryMul(monty, one, m)
    return a.Equal(aAgain) == 1
}

func TestMontgomeryRoundtrip(t *testing.T) {
    err := quick.Check(testMontgomeryRoundtrip, &quick.Config{})
    err := quick.Check(func(a *Nat) bool {
        one := &Nat{make([]uint, len(a.limbs))}
        one.limbs[0] = 1
        aPlusOne := new(big.Int).SetBytes(natBytes(a))
        aPlusOne.Add(aPlusOne, big.NewInt(1))
        m, _ := NewModulusFromBig(aPlusOne)
        monty := new(Nat).Set(a)
        monty.montgomeryRepresentation(m)
        aAgain := new(Nat).Set(monty)
        aAgain.montgomeryMul(monty, one, m)
        if a.Equal(aAgain) != 1 {
            t.Errorf("%v != %v", a, aAgain)
            return false
        }
        return true
    }, &quick.Config{})
    if err != nil {
        t.Error(err)
    }
@@ -84,30 +96,30 @@ func TestShiftIn(t *testing.T) {
    }{{
        m: []byte{13},
        x: []byte{0},
        y: 0x7FFF_FFFF_FFFF_FFFF,
        expected: []byte{7},
        y: 0xFFFF_FFFF_FFFF_FFFF,
        expected: []byte{2},
    }, {
        m: []byte{13},
        x: []byte{7},
        y: 0x7FFF_FFFF_FFFF_FFFF,
        expected: []byte{11},
        y: 0xFFFF_FFFF_FFFF_FFFF,
        expected: []byte{10},
    }, {
        m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
        x: make([]byte, 9),
        y: 0x7FFF_FFFF_FFFF_FFFF,
        expected: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
        y: 0xFFFF_FFFF_FFFF_FFFF,
        expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
    }, {
        m: []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
        x: []byte{0x00, 0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
        x: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
        y: 0,
        expected: []byte{0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08},
        expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06},
    }}

    for i, tt := range examples {
        m := modulusFromBytes(tt.m)
        got := natFromBytes(tt.x).ExpandFor(m).shiftIn(uint(tt.y), m)
        if got.Equal(natFromBytes(tt.expected).ExpandFor(m)) != 1 {
            t.Errorf("%d: got %x, expected %x", i, got, tt.expected)
        if exp := natFromBytes(tt.expected).ExpandFor(m); got.Equal(exp) != 1 {
            t.Errorf("%d: got %v, expected %v", i, got, exp)
        }
    }
}
@@ -186,7 +198,7 @@ func TestSetBytes(t *testing.T) {
        continue
    }
    if expected := natFromBytes(tt.b).ExpandFor(m); got.Equal(expected) != yes {
        t.Errorf("%d: got %x, expected %x", i, got, expected)
        t.Errorf("%d: got %v, expected %v", i, got, expected)
    }
}

@@ -228,7 +240,7 @@ func TestExpand(t *testing.T) {
    for i, tt := range examples {
        got := (&Nat{tt.in}).expand(tt.n)
        if len(got.limbs) != len(tt.out) || got.Equal(&Nat{tt.out}) != 1 {
            t.Errorf("%d: got %x, expected %x", i, got, tt.out)
            t.Errorf("%d: got %v, expected %v", i, got, tt.out)
        }
    }
}
@@ -244,52 +256,6 @@ func TestMod(t *testing.T) {
    }
}

func TestModNat(t *testing.T) {
    order, _ := new(big.Int).SetString("b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25", 16)
    orderNat := NewModulusFromBig(order)
    oneNat, err := NewNat().SetBytes(big.NewInt(1).Bytes(), orderNat)
    if err != nil {
        t.Fatal(err)
    }
    orderMinus1, _ := new(big.Int).SetString("b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf24", 16)
    hashValue1, _ := new(big.Int).SetString("1000000000000000a640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25", 16)
    hashValue2, _ := new(big.Int).SetString("1000000000000000b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf23", 16)
    examples := []struct {
        in *big.Int
        expected *big.Int
    }{
        {
            big.NewInt(1),
            big.NewInt(2),
        },
        {
            orderMinus1,
            big.NewInt(1),
        },
        {
            order,
            big.NewInt(2),
        },
        {
            hashValue1,
            new(big.Int).Add(new(big.Int).Mod(hashValue1, orderMinus1), big.NewInt(1)),
        },
        {
            hashValue2,
            new(big.Int).Add(new(big.Int).Mod(hashValue2, orderMinus1), big.NewInt(1)),
        },
    }
    for i, tt := range examples {
        kNat := NewNat().SetBig(tt.in)
        kNat = NewNat().ModNat(kNat, NewNat().SetBig(orderMinus1))
        kNat.Add(oneNat, orderNat)
        out := new(big.Int).SetBytes(kNat.Bytes(orderNat))
        if out.Cmp(tt.expected) != 0 {
            t.Errorf("%d: got %x, expected %x", i, out, tt.expected)
        }
    }
}

func TestModSub(t *testing.T) {
    m := modulusFromBytes([]byte{13})
    x := &Nat{[]uint{6}}
@@ -333,26 +299,68 @@ func TestExp(t *testing.T) {
    }
}

func TestExpShort(t *testing.T) {
    m := modulusFromBytes([]byte{13})
    x := &Nat{[]uint{3}}
    out := &Nat{[]uint{0}}
    out.ExpShort(x, 12, m)
    expected := &Nat{[]uint{1}}
    if out.Equal(expected) != 1 {
        t.Errorf("%+v != %+v", out, expected)
    }
}

// TestMulReductions tests that Mul reduces results equal or slightly greater
// than the modulus. Some Montgomery algorithms don't and need extra care to
// return correct results. See https://go.dev/issue/13907.
func TestMulReductions(t *testing.T) {
    // Two short but multi-limb primes.
    a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10)
    b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10)
    n := new(big.Int).Mul(a, b)

    N, _ := NewModulusFromBig(n)
    A := NewNat().SetBig(a).ExpandFor(N)
    B := NewNat().SetBig(b).ExpandFor(N)

    if A.Mul(B, N).IsZero() != 1 {
        t.Error("a * b mod (a * b) != 0")
    }

    i := new(big.Int).ModInverse(a, b)
    N, _ = NewModulusFromBig(b)
    A = NewNat().SetBig(a).ExpandFor(N)
    I := NewNat().SetBig(i).ExpandFor(N)
    one := NewNat().SetBig(big.NewInt(1)).ExpandFor(N)

    if A.Mul(I, N).Equal(one) != 1 {
        t.Error("a * inv(a) mod b != 1")
    }
}

func natBytes(n *Nat) []byte {
    return n.Bytes(maxModulus(uint(len(n.limbs))))
}

func natFromBytes(b []byte) *Nat {
    // Must not use Nat.SetBytes as it's used in TestSetBytes.
    bb := new(big.Int).SetBytes(b)
    return NewNat().SetBig(bb)
}

func modulusFromBytes(b []byte) *Modulus {
    bb := new(big.Int).SetBytes(b)
    return NewModulusFromBig(bb)
    m, _ := NewModulusFromBig(bb)
    return m
}

// maxModulus returns the biggest modulus that can fit in n limbs.
func maxModulus(n uint) *Modulus {
    m := big.NewInt(1)
    m.Lsh(m, n*_W)
    m.Sub(m, big.NewInt(1))
    return NewModulusFromBig(m)
    b := big.NewInt(1)
    b.Lsh(b, n*_W)
    b.Sub(b, big.NewInt(1))
    m, _ := NewModulusFromBig(b)
    return m
}

func makeBenchmarkModulus() *Modulus {
@@ -362,7 +370,7 @@ func makeBenchmarkModulus() *Modulus {
func makeBenchmarkValue() *Nat {
    x := make([]uint, 32)
    for i := 0; i < 32; i++ {
        x[i] = _MASK - 1
        x[i]--
    }
    return &Nat{limbs: x}
}
@@ -456,3 +464,17 @@ func BenchmarkExp(b *testing.B) {
        out.Exp(x, e, m)
    }
}

func TestNewModFromBigZero(t *testing.T) {
    expected := "modulus must be >= 0"
    _, err := NewModulusFromBig(big.NewInt(0))
    if err == nil || err.Error() != expected {
        t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected)
    }

    expected = "modulus must be odd"
    _, err = NewModulusFromBig(big.NewInt(2))
    if err == nil || err.Error() != expected {
        t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected)
    }
}
@@ -931,6 +931,6 @@ func p256() *sm2Curve {
func precomputeParams(c *sm2Curve, curve elliptic.Curve) {
    params := curve.Params()
    c.curve = curve
    c.N = bigmod.NewModulusFromBig(params.N)
    c.N, _ = bigmod.NewModulusFromBig(params.N)
    c.nMinus2 = new(big.Int).Sub(params.N, big.NewInt(2)).Bytes()
}
@@ -20,7 +20,7 @@ import (

// SM9 ASN.1 format reference: Information security technology - SM9 cryptographic algorithm application specification

var orderNat = bigmod.NewModulusFromBig(bn256.Order)
var orderNat, _ = bigmod.NewModulusFromBig(bn256.Order)
var orderMinus2 = new(big.Int).Sub(bn256.Order, big.NewInt(2)).Bytes()
var bigOne = big.NewInt(1)
var bigOneNat *bigmod.Nat