gmsm/internal/bigmod/nat_test.go

825 lines
20 KiB
Go

// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package bigmod
import (
"bufio"
"bytes"
cryptorand "crypto/rand"
"encoding/hex"
"fmt"
"math/big"
"math/bits"
"math/rand"
"os"
"reflect"
"strings"
"testing"
"testing/quick"
)
// setBig assigns x = n and returns x. x is reset to the number of words
// reported by n.Bits(), and each word is then copied into x's limbs.
//
// The announced length of x is therefore set based on the actual bit size of
// the input, ignoring leading zeroes.
func (x *Nat) setBig(n *big.Int) *Nat {
	words := n.Bits()
	x.reset(len(words))
	for i, w := range words {
		x.limbs[i] = uint(w)
	}
	return x
}
// asBig returns a new *big.Int built from a copy of n's limbs, converting
// each limb to a big.Word.
func (n *Nat) asBig() *big.Int {
	words := make([]big.Word, 0, len(n.limbs))
	for _, limb := range n.limbs {
		words = append(words, big.Word(limb))
	}
	return new(big.Int).SetBits(words)
}
// String formats n for test output: the limbs are printed from the last
// element of n.limbs down to the first, each as 16 zero-padded uppercase hex
// digits, separated by spaces and wrapped in braces.
func (n *Nat) String() string {
	parts := make([]string, 0, len(n.limbs))
	for i := len(n.limbs) - 1; i >= 0; i-- {
		parts = append(parts, fmt.Sprintf("%016X", n.limbs[i]))
	}
	return "{" + strings.Join(parts, " ") + "}"
}
// Generate implements the testing/quick Generator interface so that
// quick.Check can produce random *Nat arguments. It returns a Nat with size
// limbs; the lowest bit of every limb is masked off, so the generated value
// is always even.
func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
	buf := make([]uint, size)
	for i := range buf {
		// (1 << _W) - 2 keeps every bit except bit 0.
		buf[i] = uint(r.Uint64()) & ((1 << _W) - 2)
	}
	return reflect.ValueOf(&Nat{buf})
}
// testModAddCommutative is a quick.Check property: it reports whether
// a + b and b + a agree modulo the largest modulus that fits in len(a.limbs)
// limbs. Neither a nor b is modified; the sums are computed on copies.
func testModAddCommutative(a *Nat, b *Nat) bool {
	mod := maxModulus(uint(len(a.limbs)))
	left, right := new(Nat).set(a), new(Nat).set(b)
	left.Add(b, mod)
	right.Add(a, mod)
	return left.Equal(right) == 1
}
// TestModAddCommutative runs testModAddCommutative under quick.Check with the
// default configuration.
func TestModAddCommutative(t *testing.T) {
	if err := quick.Check(testModAddCommutative, &quick.Config{}); err != nil {
		t.Error(err)
	}
}
// testModSubThenAddIdentity is a quick.Check property: it reports whether
// subtracting b from a and then adding b back (both modulo the largest modulus
// that fits in len(a.limbs) limbs) restores the starting value. a is modified
// in place; a snapshot taken beforehand is used for the comparison.
func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
	mod := maxModulus(uint(len(a.limbs)))
	snapshot := new(Nat).set(a)
	a.Sub(b, mod)
	a.Add(b, mod)
	unchanged := a.Equal(snapshot)
	return unchanged == 1
}
// TestModSubThenAddIdentity runs testModSubThenAddIdentity under quick.Check
// with the default configuration.
func TestModSubThenAddIdentity(t *testing.T) {
	if err := quick.Check(testModSubThenAddIdentity, &quick.Config{}); err != nil {
		t.Error(err)
	}
}
// TestMontgomeryRoundtrip checks, for random a, that converting a with
// montgomeryRepresentation and then calling montgomeryMul with the plain
// value one yields a again. The modulus used is a + 1.
func TestMontgomeryRoundtrip(t *testing.T) {
	roundtrip := func(a *Nat) bool {
		// The plain (non-Montgomery) value 1, sized like a.
		one := &Nat{make([]uint, len(a.limbs))}
		one.limbs[0] = 1

		// Build the modulus a + 1. The NewModulus error is discarded.
		modBig := new(big.Int).SetBytes(natBytes(a))
		modBig.Add(modBig, big.NewInt(1))
		m, _ := NewModulus(modBig.Bytes())

		monty := new(Nat).set(a)
		monty.montgomeryRepresentation(m)
		back := new(Nat).set(monty)
		back.montgomeryMul(monty, one, m)

		if a.Equal(back) == 1 {
			return true
		}
		t.Errorf("%v != %v", a, back)
		return false
	}
	if err := quick.Check(roundtrip, &quick.Config{}); err != nil {
		t.Error(err)
	}
}
// TestShiftIn checks shiftIn against fixed examples: x is expanded for the
// modulus m, the word y is shifted in, and the result is compared with
// expected. The examples are skipped unless uint is 64 bits wide.
func TestShiftIn(t *testing.T) {
	if bits.UintSize != 64 {
		t.Skip("examples are only valid in 64 bit")
	}
	cases := []struct {
		m, x, expected []byte
		y              uint64
	}{
		{
			m:        []byte{13},
			x:        []byte{0},
			y:        0xFFFF_FFFF_FFFF_FFFF,
			expected: []byte{2},
		},
		{
			m:        []byte{13},
			x:        []byte{7},
			y:        0xFFFF_FFFF_FFFF_FFFF,
			expected: []byte{10},
		},
		{
			m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
			x:        make([]byte, 9),
			y:        0xFFFF_FFFF_FFFF_FFFF,
			expected: []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
		},
		{
			m:        []byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d},
			x:        []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
			y:        0,
			expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06},
		},
	}
	for i, tc := range cases {
		mod := modulusFromBytes(tc.m)
		got := natFromBytes(tc.x).ExpandFor(mod).shiftIn(uint(tc.y), mod)
		want := natFromBytes(tc.expected).ExpandFor(mod)
		if got.Equal(want) != 1 {
			t.Errorf("%d: got %v, expected %v", i, got, want)
		}
	}
}
// TestModulusAndNatSizes is a smoke test with no assertions: it passes as
// long as neither ExpandFor nor SetBytes panics when the value is given as 16
// bytes whose top two bits are clear.
func TestModulusAndNatSizes(t *testing.T) {
	// Both the modulus and the value have 126 significant bits but are
	// serialized as 128 bits worth of bytes. Depending on whether leading
	// zeroes are stripped, such a value may need a different number of limbs,
	// which can be a problem because modulus strips leading zeroes and nat
	// does not.
	// NOTE(review): the historical comment quoted "two vs. three limbs" and
	// "2 * _W" for 126 bits, which presumably dates from a 63-bit limb size;
	// confirm against the current _W.
	modBytes := []byte{
		0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
	}
	valBytes := []byte{
		0x3f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe,
	}
	m := modulusFromBytes(modBytes)
	natFromBytes(valBytes).ExpandFor(m) // must not panic for shrinking
	NewNat().SetBytes(valBytes, m)
}
// TestSetBytes checks Nat.SetBytes in two ways. First, a table of (modulus,
// input) pairs states whether SetBytes must fail; on success the result must
// equal the value decoded independently via natFromBytes. Second, a
// quick.Check property feeds random byte slices with a modulus one limb
// larger than needed and requires SetBytes to succeed and match.
func TestSetBytes(t *testing.T) {
	tests := []struct {
		m, b []byte
		fail bool
	}{
		{
			m: []byte{0xff, 0xff},
			b: []byte{0x00, 0x01},
		},
		{
			m:    []byte{0xff, 0xff},
			b:    []byte{0xff, 0xff},
			fail: true,
		},
		{
			m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
			b: []byte{0x00, 0x01},
		},
		{
			m: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
			b: []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
		},
		{
			m:    []byte{0xff, 0xff},
			b:    []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
			fail: true,
		},
		{
			m:    []byte{0xff, 0xff},
			b:    []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
			fail: true,
		},
		{
			m: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
			b: []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
		},
		{
			m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
			b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
			fail: true,
		},
		{
			m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
			b:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
			fail: true,
		},
		{
			m:    []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
			b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfe},
			fail: true,
		},
		{
			m:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xfd},
			b:    []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff},
			fail: true,
		},
	}
	for i, tc := range tests {
		m := modulusFromBytes(tc.m)
		got, err := NewNat().SetBytes(tc.b, m)
		switch {
		case err != nil && !tc.fail:
			t.Errorf("%d: unexpected error: %v", i, err)
		case err != nil:
			// The error was expected; nothing more to check.
		case tc.fail:
			t.Errorf("%d: unexpected success", i)
		default:
			expected := natFromBytes(tc.b).ExpandFor(m)
			if got.Equal(expected) != yes {
				t.Errorf("%d: got %v, expected %v", i, got, expected)
			}
		}
	}

	// Property: any byte slice decodes successfully under a modulus that has
	// one more limb than len(xBytes)*8/_W.
	matchesReference := func(xBytes []byte) bool {
		m := maxModulus(uint(len(xBytes)*8/_W + 1))
		got, err := NewNat().SetBytes(xBytes, m)
		return err == nil && got.Equal(natFromBytes(xBytes).ExpandFor(m)) == yes
	}
	if err := quick.Check(matchesReference, &quick.Config{}); err != nil {
		t.Error(err)
	}
}
// TestExpand checks that Nat.expand(n) yields n limbs with the original limbs
// preserved and the new ones zero. The second case uses a two-element
// sub-slice of a four-element backing array whose hidden elements are
// non-zero.
func TestExpand(t *testing.T) {
	backing := []uint{1, 2, 3, 4}
	cases := []struct {
		in  []uint
		n   int
		out []uint
	}{
		{in: []uint{1, 2}, n: 4, out: []uint{1, 2, 0, 0}},
		{in: backing[:2], n: 4, out: []uint{1, 2, 0, 0}},
		{in: []uint{1, 2}, n: 2, out: []uint{1, 2}},
	}
	for i, tc := range cases {
		got := (&Nat{tc.in}).expand(tc.n)
		sameLen := len(got.limbs) == len(tc.out)
		// Equal is only consulted when the lengths already agree.
		if !sameLen || got.Equal(&Nat{tc.out}) != 1 {
			t.Errorf("%d: got %v, expected %v", i, got, tc.out)
		}
	}
}
// TestMod reduces a fixed 16-byte value modulo a fixed 9-byte modulus with
// Nat.Mod and compares the result against a precomputed remainder.
func TestMod(t *testing.T) {
	modulus := modulusFromBytes([]byte{0x06, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0d})
	input := natFromBytes([]byte{0x40, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01})
	want := natFromBytes([]byte{0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09})

	got := new(Nat)
	got.Mod(input, modulus)
	if got.Equal(want) != 1 {
		t.Errorf("%+v != %+v", got, want)
	}
}
// TestModSub subtracts 7 from 6 modulo 13 twice in a row, expecting 12 after
// the first subtraction and 5 after the second.
func TestModSub(t *testing.T) {
	m := modulusFromBytes([]byte{13})
	x, y := &Nat{[]uint{6}}, &Nat{[]uint{7}}
	for _, want := range []uint{12, 5} {
		x.Sub(y, m)
		expected := &Nat{[]uint{want}}
		if x.Equal(expected) != 1 {
			t.Errorf("%+v != %+v", x, expected)
		}
	}
}
// TestModAdd adds 7 to 6 modulo 13 twice in a row, expecting 0 after the
// first addition and 7 after the second.
func TestModAdd(t *testing.T) {
	m := modulusFromBytes([]byte{13})
	x, y := &Nat{[]uint{6}}, &Nat{[]uint{7}}
	for _, want := range []uint{0, 7} {
		x.Add(y, m)
		expected := &Nat{[]uint{want}}
		if x.Equal(expected) != 1 {
			t.Errorf("%+v != %+v", x, expected)
		}
	}
}
// TestExp checks that Nat.Exp computes 3^12 mod 13 == 1, with the exponent
// supplied as a byte slice.
func TestExp(t *testing.T) {
	mod := modulusFromBytes([]byte{13})
	base := &Nat{[]uint{3}}
	want := &Nat{[]uint{1}}

	got := &Nat{[]uint{0}}
	got.Exp(base, []byte{12}, mod)
	if got.Equal(want) != 1 {
		t.Errorf("%+v != %+v", got, want)
	}
}
// TestExpShort checks that Nat.ExpShortVarTime computes 3^12 mod 13 == 1,
// with the exponent supplied as a small integer.
func TestExpShort(t *testing.T) {
	mod := modulusFromBytes([]byte{13})
	base := &Nat{[]uint{3}}
	want := &Nat{[]uint{1}}

	got := &Nat{[]uint{0}}
	got.ExpShortVarTime(base, 12, mod)
	if got.Equal(want) != 1 {
		t.Errorf("%+v != %+v", got, want)
	}
}
// TestMulReductions tests that Mul reduces results equal or slightly greater
// than the modulus. Some Montgomery algorithms don't and need extra care to
// return correct results. See https://go.dev/issue/13907.
//
// Two properties are checked: a * b mod (a * b) is zero, and
// a * inv(a) mod b is one, where inv(a) comes from big.Int.ModInverse.
func TestMulReductions(t *testing.T) {
	// Two short but multi-limb primes.
	a, _ := new(big.Int).SetString("773608962677651230850240281261679752031633236267106044359907", 10)
	b, _ := new(big.Int).SetString("180692823610368451951102211649591374573781973061758082626801", 10)

	// First property: the product reduces to zero modulo itself.
	product := new(big.Int).Mul(a, b)
	modAB, _ := NewModulus(product.Bytes())
	aNat := NewNat().setBig(a).ExpandFor(modAB)
	bNat := NewNat().setBig(b).ExpandFor(modAB)
	if aNat.Mul(bNat, modAB).IsZero() != 1 {
		t.Error("a * b mod (a * b) != 0")
	}

	// Second property: a times its inverse modulo b reduces to one.
	inv := new(big.Int).ModInverse(a, b)
	modB, _ := NewModulus(b.Bytes())
	aNat = NewNat().setBig(a).ExpandFor(modB)
	invNat := NewNat().setBig(inv).ExpandFor(modB)
	one := NewNat().setBig(big.NewInt(1)).ExpandFor(modB)
	if aNat.Mul(invNat, modB).Equal(one) != 1 {
		t.Error("a * inv(a) mod b != 1")
	}
}
// TestMul runs testMul as a subtest for several operand sizes. Each subtest
// is named after the size in bits, and testMul receives the size in bytes.
func TestMul(t *testing.T) {
	for _, bitLen := range []int{760, 256, 1024, 1536, 2048} {
		byteLen := bitLen / 8
		t.Run(fmt.Sprint(bitLen), func(t *testing.T) { testMul(t, byteLen) })
	}
}
// testMul multiplies two random n-byte values modulo a third random n-byte
// value using Nat.Mul and compares the result with the same computation done
// with math/big.
//
// Three random byte strings are drawn; the largest (by bytes.Compare) is used
// as the modulus so that both operands are accepted by SetBytes. Any failure
// to obtain randomness or to construct the modulus or operands aborts the
// test.
func testMul(t *testing.T, n int) {
	a, b, m := make([]byte, n), make([]byte, n), make([]byte, n)
	for _, buf := range [][]byte{a, b, m} {
		// Do not silently continue with all-zero inputs if the random
		// source fails.
		if _, err := cryptorand.Read(buf); err != nil {
			t.Fatal(err)
		}
	}

	// Pick the highest as the modulus.
	if bytes.Compare(a, m) > 0 {
		a, m = m, a
	}
	if bytes.Compare(b, m) > 0 {
		b, m = m, b
	}

	M, err := NewModulus(m)
	if err != nil {
		t.Fatal(err)
	}
	A, err := NewNat().SetBytes(a, M)
	if err != nil {
		t.Fatal(err)
	}
	B, err := NewNat().SetBytes(b, M)
	if err != nil {
		t.Fatal(err)
	}

	A.Mul(B, M)
	ABytes := A.Bytes(M)

	// Reference result computed with math/big, padded to the same length.
	mBig := new(big.Int).SetBytes(m)
	aBig := new(big.Int).SetBytes(a)
	bBig := new(big.Int).SetBytes(b)
	nBig := new(big.Int).Mul(aBig, bBig)
	nBig.Mod(nBig, mBig)
	nBigBytes := make([]byte, len(ABytes))
	nBig.FillBytes(nBigBytes)

	if !bytes.Equal(ABytes, nBigBytes) {
		t.Errorf("got %x, want %x", ABytes, nBigBytes)
	}
}
// TestIs exercises IsMinusOne, IsZero, IsOne and IsOdd on a value that starts
// at 3 modulo 4 and is decremented with SubOne four times (3, 2, 1, 0, and
// then -1 mod 4). It finishes by checking IsOne on the value 1 stored under a
// two-limb modulus.
func TestIs(t *testing.T) {
	// expect reports msg as a test error when got differs from want.
	expect := func(want, got choice, msg string) {
		t.Helper()
		if got != want {
			t.Error(msg)
		}
	}

	mFour := modulusFromBytes([]byte{4})
	n, err := NewNat().SetBytes([]byte{3}, mFour)
	if err != nil {
		t.Fatal(err)
	}

	// n == 3
	expect(yes, n.IsMinusOne(mFour), "3 is not -1 mod 4")
	expect(no, n.IsZero(), "3 is zero")
	expect(no, n.IsOne(), "3 is one")
	expect(yes, n.IsOdd(), "3 is not odd")

	// n == 2
	n.SubOne(mFour)
	expect(no, n.IsMinusOne(mFour), "2 is -1 mod 4")
	expect(no, n.IsZero(), "2 is zero")
	expect(no, n.IsOne(), "2 is one")
	expect(no, n.IsOdd(), "2 is odd")

	// n == 1
	n.SubOne(mFour)
	expect(no, n.IsMinusOne(mFour), "1 is -1 mod 4")
	expect(no, n.IsZero(), "1 is zero")
	expect(yes, n.IsOne(), "1 is not one")
	expect(yes, n.IsOdd(), "1 is not odd")

	// n == 0
	n.SubOne(mFour)
	expect(no, n.IsMinusOne(mFour), "0 is -1 mod 4")
	expect(yes, n.IsZero(), "0 is not zero")
	expect(no, n.IsOne(), "0 is one")
	expect(no, n.IsOdd(), "0 is odd")

	// n == -1 mod 4
	n.SubOne(mFour)
	expect(yes, n.IsMinusOne(mFour), "-1 is not -1 mod 4")
	expect(no, n.IsZero(), "-1 is zero")
	expect(no, n.IsOne(), "-1 is one")
	expect(yes, n.IsOdd(), "-1 mod 4 is not odd")

	// IsOne on a value stored under a two-limb modulus.
	mTwoLimbs := maxModulus(2)
	n, err = NewNat().SetBytes([]byte{0x01}, mTwoLimbs)
	if err != nil {
		t.Fatal(err)
	}
	if n.IsOne() != 1 {
		t.Errorf("1 is not one")
	}
}
// TestTrailingZeroBits starts from a value whose lowest set bit is bit 129
// (a byte string ending in 0x7e, shifted left by 128) and shifts it right one
// bit at a time, checking that TrailingZeroBitsVarTime counts down from 129
// to 0.
func TestTrailingZeroBits(t *testing.T) {
	nb := new(big.Int).SetBytes([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7e})
	nb.Lsh(nb, 128)
	for want := 129; want >= 0; want-- {
		n := NewNat().setBig(nb)
		if got := n.TrailingZeroBitsVarTime(); got != uint(want) {
			t.Errorf("%d != %d", got, want)
		}
		nb.Rsh(nb, 1)
	}
}
// TestRightShift draws a random value below 2^1024 and, for a list of shift
// amounts and their immediate neighbours (shift-1, shift, shift+1), checks
// that ShiftRightVarTime keeps the limb count unchanged and matches
// big.Int.Rsh. Each shift amount runs as a subtest named after the amount.
func TestRightShift(t *testing.T) {
	bound := new(big.Int).Lsh(big.NewInt(1), 1024)
	nb, err := cryptorand.Int(cryptorand.Reader, bound)
	if err != nil {
		t.Fatal(err)
	}

	// check shifts a fresh copy of nb and compares it with the reference.
	check := func(t *testing.T, shift uint) {
		n := NewNat().setBig(nb)
		wantLen := len(n.limbs)
		n.ShiftRightVarTime(shift)
		if len(n.limbs) != wantLen {
			t.Errorf("len(n.limbs) = %d, want %d", len(n.limbs), wantLen)
		}
		exp := new(big.Int).Rsh(nb, shift)
		if n.asBig().Cmp(exp) != 0 {
			t.Errorf("%v != %v", n.asBig(), exp)
		}
	}

	for _, base := range []uint{1, 32, 64, 128, 1024 - 128, 1024 - 64, 1024 - 32, 1024 - 1} {
		for _, delta := range []uint{0, 1, 2} {
			amount := base - 1 + delta
			t.Run(fmt.Sprint(amount), func(t *testing.T) { check(t, amount) })
		}
	}
}
// natBytes serializes n with Nat.Bytes, using the largest modulus that fits
// in len(n.limbs) limbs.
func natBytes(n *Nat) []byte {
	limbCount := uint(len(n.limbs))
	return n.Bytes(maxModulus(limbCount))
}
// natFromBytes decodes the big-endian bytes b into a Nat by way of big.Int
// and setBig.
func natFromBytes(b []byte) *Nat {
	// Nat.SetBytes is deliberately avoided here: this helper produces the
	// reference values in TestSetBytes.
	return NewNat().setBig(new(big.Int).SetBytes(b))
}
// modulusFromBytes builds a Modulus from the big-endian bytes b. The bytes
// are round-tripped through big.Int first, so NewModulus receives the
// big.Int.Bytes encoding of the value rather than b itself.
//
// The error from NewModulus is discarded; callers get whatever *Modulus
// NewModulus returned alongside it.
func modulusFromBytes(b []byte) *Modulus {
	encoded := new(big.Int).SetBytes(b).Bytes()
	m, _ := NewModulus(encoded)
	return m
}
// maxModulus returns the biggest modulus that can fit in n limbs, that is
// 2^(n*_W) - 1.
//
// The error from NewModulus is discarded; callers get whatever *Modulus
// NewModulus returned alongside it.
func maxModulus(n uint) *Modulus {
	one := big.NewInt(1)
	v := new(big.Int).Lsh(one, n*_W)
	v.Sub(v, one)
	m, _ := NewModulus(v.Bytes())
	return m
}
// makeBenchmarkModulus returns the modulus used by the benchmarks: the
// largest modulus that fits in n limbs.
func makeBenchmarkModulus(n uint) *Modulus {
	m := maxModulus(n)
	return m
}
// makeBenchmarkValue returns a Nat with n limbs, every limb having all bits
// set.
func makeBenchmarkValue(n int) *Nat {
	limbs := make([]uint, n)
	for i := range limbs {
		limbs[i] = ^uint(0)
	}
	return &Nat{limbs: limbs}
}
// makeBenchmarkExponent returns a 256-byte slice whose first 32 bytes are
// 0xFF and whose remaining bytes are zero.
func makeBenchmarkExponent() []byte {
	const (
		total = 256
		ones  = 32
	)
	e := make([]byte, total)
	for i := range e[:ones] {
		e[i] = 0xFF
	}
	return e
}
// BenchmarkRR256 measures the cost of constructing the maximal 4-limb
// modulus via makeBenchmarkModulus.
// NOTE(review): the name suggests this is meant to time the precomputation
// done inside NewModulus for a 256-bit modulus — confirm against NewModulus.
func BenchmarkRR256(b *testing.B) {
	const limbs = 4
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		makeBenchmarkModulus(limbs)
	}
}
// BenchmarkModAdd measures Nat.Add on 32-limb operands with the maximal
// 32-limb modulus. The accumulator is updated in place on every iteration.
func BenchmarkModAdd(b *testing.B) {
	const limbs = 32
	acc := makeBenchmarkValue(limbs)
	addend := makeBenchmarkValue(limbs)
	mod := makeBenchmarkModulus(limbs)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		acc.Add(addend, mod)
	}
}
// BenchmarkModSub measures Nat.Sub on 32-limb operands with the maximal
// 32-limb modulus. The accumulator is updated in place on every iteration.
func BenchmarkModSub(b *testing.B) {
	const limbs = 32
	acc := makeBenchmarkValue(limbs)
	subtrahend := makeBenchmarkValue(limbs)
	mod := makeBenchmarkModulus(limbs)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		acc.Sub(subtrahend, mod)
	}
}
// BenchmarkMontgomeryRepr measures montgomeryRepresentation on a 32-limb
// value with the maximal 32-limb modulus. The same value is converted again
// on every iteration.
func BenchmarkMontgomeryRepr(b *testing.B) {
	const limbs = 32
	val := makeBenchmarkValue(limbs)
	mod := makeBenchmarkModulus(limbs)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		val.montgomeryRepresentation(mod)
	}
}
// BenchmarkMontgomeryMul measures montgomeryMul on 32-limb operands with the
// maximal 32-limb modulus, writing into a separate 32-limb destination.
func BenchmarkMontgomeryMul(b *testing.B) {
	const limbs = 32
	lhs := makeBenchmarkValue(limbs)
	rhs := makeBenchmarkValue(limbs)
	dst := makeBenchmarkValue(limbs)
	mod := makeBenchmarkModulus(limbs)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		dst.montgomeryMul(lhs, rhs, mod)
	}
}
// BenchmarkModMul measures Nat.Mul on 32-limb operands with the maximal
// 32-limb modulus. The accumulator is updated in place on every iteration.
func BenchmarkModMul(b *testing.B) {
	const limbs = 32
	acc := makeBenchmarkValue(limbs)
	factor := makeBenchmarkValue(limbs)
	mod := makeBenchmarkModulus(limbs)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		acc.Mul(factor, mod)
	}
}
// BenchmarkModMul256 measures Nat.Mul on 4-limb operands with the maximal
// 4-limb modulus. The accumulator is updated in place on every iteration.
func BenchmarkModMul256(b *testing.B) {
	const limbs = 4
	acc := makeBenchmarkValue(limbs)
	factor := makeBenchmarkValue(limbs)
	mod := makeBenchmarkModulus(limbs)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		acc.Mul(factor, mod)
	}
}
// BenchmarkExpBig measures math/big's Int.Exp as a baseline. The base and
// the exponent both equal the value encoded by makeBenchmarkExponent, and the
// modulus is that value plus one.
func BenchmarkExpBig(b *testing.B) {
	raw := makeBenchmarkExponent()
	base := new(big.Int).SetBytes(raw)
	exponent := new(big.Int).SetBytes(raw)
	modulus := new(big.Int).SetBytes(raw)
	modulus.Add(modulus, new(big.Int).SetUint64(1))
	result := new(big.Int)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		result.Exp(base, exponent, modulus)
	}
}
// BenchmarkExp measures Nat.Exp with a 32-limb base, the exponent from
// makeBenchmarkExponent and the maximal 32-limb modulus, writing into a
// separate 32-limb destination.
func BenchmarkExp(b *testing.B) {
	const limbs = 32
	base := makeBenchmarkValue(limbs)
	exponent := makeBenchmarkExponent()
	dst := makeBenchmarkValue(limbs)
	mod := makeBenchmarkModulus(limbs)

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		dst.Exp(base, exponent, mod)
	}
}
// TestNewModulus checks that NewModulus rejects encodings of 0 and 1
// (including the empty slice and zero-padded forms) with the error message
// "modulus must be > 1".
func TestNewModulus(t *testing.T) {
	const expected = "modulus must be > 1"
	cases := []struct {
		label string // how the call is described in the failure message
		in    []byte
	}{
		{"NewModulus(0)", []byte{}},
		{"NewModulus(0)", []byte{0}},
		{"NewModulus(0)", []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
		{"NewModulus(1)", []byte{1}},
		{"NewModulus(1)", []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}},
	}
	for _, tc := range cases {
		_, err := NewModulus(tc.in)
		if err == nil || err.Error() != expected {
			t.Errorf("%s got %q, want %q", tc.label, err, expected)
		}
	}
}
// TestOverflowedBytes checks Nat.SetOverflowedBytes against a math/big
// reference. For each hex input d and modulus m (the first case), the
// expected value is ((d mod (m-1)) + 1) mod m.
func TestOverflowedBytes(t *testing.T) {
	cases := []string{
		"b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25",
		"b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf23",
		"b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf24",
		"b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf24b640000002a3a6f1",
		"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff",
		"00",
	}

	// The first case doubles as the modulus.
	mBytes, _ := hex.DecodeString(cases[0])
	m, err := NewModulus(mBytes)
	if err != nil {
		t.Fatal(err)
	}
	bigOne := big.NewInt(1)
	mBigInt := new(big.Int).SetBytes(mBytes)
	mMinusOne := new(big.Int).Sub(mBigInt, bigOne)

	for _, c := range cases {
		raw, _ := hex.DecodeString(c)

		// Reference value computed with math/big.
		want := new(big.Int).SetBytes(raw)
		want.Mod(want, mMinusOne)
		want.Add(want, bigOne)
		want.Mod(want, mBigInt)

		kNat := NewNat().SetOverflowedBytes(raw, m)
		got := new(big.Int).SetBytes(kNat.Bytes(m))
		if !bytes.Equal(got.Bytes(), want.Bytes()) {
			t.Errorf("%s, expected %x, got %x", c, want.Bytes(), got.Bytes())
		}
	}
}
// makeTestValue returns nbits/_W limbs, every limb having all bits set.
func makeTestValue(nbits int) []uint {
	limbs := make([]uint, nbits/_W)
	for i := range limbs {
		limbs[i] = ^uint(0)
	}
	return limbs
}
// slicesEqual reports whether a and b have the same length and the same
// elements in the same order. A nil slice and an empty slice compare equal.
func slicesEqual(a, b []uint) bool {
	n := len(a)
	if n != len(b) {
		return false
	}
	// Advance past the common prefix; the slices are equal exactly when that
	// prefix covers everything.
	i := 0
	for i < n && a[i] == b[i] {
		i++
	}
	return i == n
}
// TestAddMulVVWSized compares each fixed-size addMulVVW variant with the
// generic addMulVVW on all-ones inputs: both the updated destination limbs
// and the returned carry must match.
func TestAddMulVVWSized(t *testing.T) {
	// Sized addMulVVW have architecture-specific implementations on
	// a number of architectures. Test that they match the generic
	// implementation.
	impls := []struct {
		n int // operand size in bits
		f func(z, x *uint, y uint) uint
	}{
		{n: 256, f: addMulVVW256},
		{n: 1024, f: addMulVVW1024},
		{n: 1536, f: addMulVVW1536},
		{n: 2048, f: addMulVVW2048},
	}
	for _, impl := range impls {
		t.Run(fmt.Sprint(impl.n), func(t *testing.T) {
			x := makeTestValue(impl.n)
			zGeneric := makeTestValue(impl.n)
			zSized := makeTestValue(impl.n)
			y := ^uint(0)

			cGeneric := addMulVVW(zGeneric, x, y)
			cSized := impl.f(&zSized[0], &x[0], y)
			if !slicesEqual(zGeneric, zSized) || cGeneric != cSized {
				t.Errorf("%016X, %016X != %016X, %016X", zGeneric, cGeneric, zSized, cSized)
			}
		})
	}
}
// TestInverse checks Nat.InverseVarTime against the vectors in
// testdata/mod_inv_tests.txt.
//
// The file is read line by line. Empty lines and lines starting with '#' are
// skipped; every other line is split at " = " into a key and a hex value.
// The keys "ModInv" and "A" are remembered, and each "M" line triggers a
// subtest (named after the line number) that uses the most recently seen
// ModInv and A values. Any other key fails the test.
//
// In a subtest, a modulus rejected by NewModulus causes a skip; otherwise A
// must decode under M, must be invertible, and its inverse must equal ModInv.
func TestInverse(t *testing.T) {
	f, err := os.Open("testdata/mod_inv_tests.txt")
	if err != nil {
		t.Fatal(err)
	}
	// Release the file handle when the test returns, including when it
	// exits through t.Fatal.
	defer f.Close()

	var ModInv, A, M string
	var lineNum int
	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		lineNum++
		line := scanner.Text()
		if len(line) == 0 || line[0] == '#' {
			continue
		}
		// A line without " = " leaves the whole line in k and falls through
		// to the unknown-key failure below.
		k, v, _ := strings.Cut(line, " = ")
		switch k {
		case "ModInv":
			ModInv = v
		case "A":
			A = v
		case "M":
			M = v
			t.Run(fmt.Sprintf("line %d", lineNum), func(t *testing.T) {
				m, err := NewModulus(decodeHex(t, M))
				if err != nil {
					t.Skip("modulus <= 1")
				}
				a, err := NewNat().SetBytes(decodeHex(t, A), m)
				if err != nil {
					t.Fatal(err)
				}
				got, ok := NewNat().InverseVarTime(a, m)
				if !ok {
					t.Fatal("not invertible")
				}
				exp, err := NewNat().SetBytes(decodeHex(t, ModInv), m)
				if err != nil {
					t.Fatal(err)
				}
				if got.Equal(exp) != 1 {
					t.Errorf("%v != %v", got, exp)
				}
			})
		default:
			t.Fatalf("unknown key %q on line %d", k, lineNum)
		}
	}
	if err := scanner.Err(); err != nil {
		t.Fatal(err)
	}
}
// decodeHex decodes the hex string s into bytes. A string with an odd number
// of digits is first left-padded with a single "0". A decoding error fails
// the test immediately via t.Fatalf.
func decodeHex(t *testing.T, s string) []byte {
	t.Helper()
	padded := s
	if len(padded)%2 == 1 {
		padded = "0" + padded
	}
	out, err := hex.DecodeString(padded)
	if err != nil {
		t.Fatalf("failed to decode hex %q: %v", padded, err)
	}
	return out
}