gmsm/sm9/bn256/gfp_test.go
2023-07-27 15:59:52 +08:00

380 lines
9.4 KiB
Go

package bn256
import (
"encoding/hex"
"math/big"
"testing"
)
// TestGfpBasicOperations checks gfpAdd, gfpDouble, gfpTriple, gfpSub, gfpNeg,
// gfpMul and gfpSqr on two fixed operands against precomputed results. Most
// operations are exercised twice: once with a separate output and once with
// the output aliasing an input, since the field routines must tolerate that.
// Every failure reports both the computed and the expected value.
func TestGfpBasicOperations(t *testing.T) {
	x := fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141"))
	y := fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B"))
	expectedAdd := fromBigInt(bigFromHex("0691692307d370af56226e57920199fbbe10f216c67fbc9468c7f225a4b1f21f"))
	expectedDouble := fromBigInt(bigFromHex("551de7a0ee24723edcf314ff72f478fac1c7c4e7044238acc3913cfbcdaf7d05"))
	expectedSub := fromBigInt(bigFromHex("67b381821c52a5624f3304a8149be8461e3bc07adcb872c38aa65051ba53ba97"))
	expectedNeg := fromBigInt(bigFromHex("7f1d8aad70909be90358f1d02240062433cc3a0248ded72febb879ec33ce6f22"))
	expectedMul := fromBigInt(bigFromHex("3d08bbad376584e4f74bd31f78f716372b96ba8c3f939c12b8d54e79b6489e76"))
	expectedMul2 := fromBigInt(bigFromHex("1df94a9e05a559ff38e0ab50cece734dc058d33738ceacaa15986a67cbff1ef6"))
	t.Parallel()
	t.Run("add", func(t *testing.T) {
		ret := &gfP{}
		gfpAdd(ret, x, y)
		if *expectedAdd != *ret {
			t.Errorf("add not same, got %v, expected %v", ret, expectedAdd)
		}
		// Output aliases the first input.
		x1 := &gfP{}
		x1.Set(x)
		gfpAdd(x1, x1, y)
		if *expectedAdd != *x1 {
			t.Errorf("add not same when add self, got %v, expected %v", x1, expectedAdd)
		}
	})
	t.Run("double", func(t *testing.T) {
		ret := &gfP{}
		gfpDouble(ret, x)
		if ret.Equal(expectedDouble) != 1 {
			t.Errorf("double not same, got %v, expected %v", ret, expectedDouble)
		}
		ret.Set(x)
		gfpDouble(ret, ret)
		if ret.Equal(expectedDouble) != 1 {
			t.Errorf("double not same, got %v, expected %v", ret, expectedDouble)
		}
	})
	t.Run("triple", func(t *testing.T) {
		// 3x is computed independently as x + 2x.
		expected := &gfP{}
		gfpAdd(expected, x, expectedDouble)
		ret := &gfP{}
		ret.Set(x)
		gfpTriple(ret, ret)
		if ret.Equal(expected) != 1 {
			t.Errorf("expected %v, got %v", expected, ret)
		}
	})
	t.Run("sub", func(t *testing.T) {
		ret := &gfP{}
		gfpSub(ret, y, x)
		if *expectedSub != *ret {
			t.Errorf("sub not same, got %v, expected %v", ret, expectedSub)
		}
		// Output aliases the subtrahend.
		x1 := &gfP{}
		x1.Set(x)
		gfpSub(x1, y, x1)
		if *expectedSub != *x1 {
			t.Errorf("sub not same when sub self, got %v, expected %v", x1, expectedSub)
		}
	})
	t.Run("neg", func(t *testing.T) {
		ret := &gfP{}
		gfpNeg(ret, y)
		if *expectedNeg != *ret {
			t.Errorf("neg not same, got %v, expected %v", ret, expectedNeg)
		}
		ret.Set(y)
		gfpNeg(ret, ret)
		if *expectedNeg != *ret {
			t.Errorf("neg not same when neg self, got %v, expected %v", ret, expectedNeg)
		}
	})
	t.Run("mul", func(t *testing.T) {
		ret := &gfP{}
		gfpMul(ret, x, y)
		if *expectedMul != *ret {
			t.Errorf("mul not same, got %v, expected %v", ret, expectedMul)
		}
		ret.Set(x)
		gfpMul(ret, ret, y)
		if *expectedMul != *ret {
			t.Errorf("mul not same when mul self, got %v, expected %v", ret, expectedMul)
		}
	})
	t.Run("square", func(t *testing.T) {
		ret, ret1, ret2 := &gfP{}, &gfP{}, &gfP{}
		gfpMul(ret, x, y)
		gfpMul(ret1, ret, ret)
		if *ret1 != *expectedMul2 {
			t.Errorf("mul not same, got %v, expected %v", ret1, expectedMul2)
		}
		// Two successive multiplications-by-self must match gfpSqr with a
		// repetition count of 2.
		gfpMul(ret1, ret1, ret1)
		gfpSqr(ret2, ret, 2)
		if *ret1 != *ret2 {
			t.Errorf("mul/sqr not same, mul got %v, sqr got %v", ret1, ret2)
		}
		ret2.Set(ret)
		gfpSqr(ret2, ret2, 2)
		if *ret1 != *ret2 {
			t.Errorf("mul/sqr not same when square self, mul got %v, sqr got %v", ret1, ret2)
		}
	})
}
// TestGfpSqr squares the field encodings of p-1 and p+1 once with gfpSqr and
// compares each result with the same square reduced mod p using math/big.
func TestGfpSqr(t *testing.T) {
	cases := []struct {
		name  string
		delta int64
	}{
		{name: "p-1", delta: -1},
		{name: "p+1", delta: 1},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			v := new(big.Int).Add(p, big.NewInt(tc.delta))
			in := fromBigInt(v)
			got := &gfP{}
			gfpSqr(got, in, 1)

			// Reference: v^2 mod p.
			v.Mul(v, v)
			v.Mod(v, p)
			if want := fromBigInt(v); *got != *want {
				t.Errorf("bad sqr")
			}
		})
	}
}
// TestFromMont checks that gfpFromMont(x) produces exactly the same limbs as
// gfpMul(x, gfP{1}), i.e. multiplying x by the raw limb value 1.
func TestFromMont(t *testing.T) {
	x := fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141"))

	var viaFromMont, viaMul gfP
	gfpFromMont(&viaFromMont, x)
	gfpMul(&viaMul, x, &gfP{1})

	if viaFromMont != viaMul {
		t.Errorf("mul/fromMont not same")
	}
}
// TestGfpExp checks gfP.exp with exponent pMinus2 against math/big.
//
// It verifies three things:
//   - x * x^(p-2) equals one in the gfP representation;
//   - the decoded value of x^(p-2) equals big.Int.Exp(x, p-2, p);
//   - that reference value equals big.Int.ModInverse(x, p).
func TestGfpExp(t *testing.T) {
	xI := bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596")
	x := fromBigInt(xI)
	ret, ret3 := &gfP{}, &gfP{}
	ret.exp(x, pMinus2)
	gfpMul(ret3, x, ret)
	if *ret3 != *one {
		t.Errorf("got %v, expected %v\n", ret3, one)
	}

	// Reference x^(p-2) mod p. The exponent is derived from p rather than
	// repeated as a hex literal so it cannot drift from the modulus.
	exponent := new(big.Int).Sub(p, big.NewInt(2))
	ret2 := new(big.Int).Exp(xI, exponent, p)

	// Compare fixed-width 32-byte encodings. big.Int.Bytes drops leading zero
	// bytes, so comparing its hex against a fixed-width encoding would fail
	// spuriously whenever the result is below 2^248.
	var got, want [32]byte
	montDecode(ret, ret)
	ret.Marshal(got[:])
	ret2.FillBytes(want[:])
	if got != want {
		t.Errorf("exp not same, got %v, expected %v\n", hex.EncodeToString(got[:]), hex.EncodeToString(want[:]))
	}

	xInv := new(big.Int).ModInverse(xI, p)
	if ret2.Cmp(xInv) != 0 {
		t.Errorf("exp not same, got %x, expected %x\n", ret2, xInv)
	}
	x2 := new(big.Int).Mul(xI, xInv)
	x2.Mod(x2, p)
	if big.NewInt(1).Cmp(x2) != 0 {
		t.Errorf("x * x^-1 mod p = %v, expected 1", x2)
	}

	// Multiplying by the inverse imported from math/big must also give one.
	xInvGfp := fromBigInt(xInv)
	gfpMul(ret, x, xInvGfp)
	if *ret != *one {
		t.Errorf("got %v, expected %v", ret, one)
	}
}
// TestSqrt checks the gfP.Sqrt method against big.Int.ModSqrt for known
// quadratic residues mod p. The root is montDecode'd and marshaled, then
// compared byte-for-byte with the math/big root.
func TestSqrt(t *testing.T) {
	tests := []string{
		"9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596",
		"92fe90b700fbd4d8cc177d300ed16e4e15471a681b2c9e3728c1b82c885e49c2",
	}
	for i, test := range tests {
		y2 := bigFromHex(test)
		y21 := new(big.Int).ModSqrt(y2, p)
		// ModSqrt returns nil for a non-residue; report it instead of
		// panicking on the nil dereference below.
		if y21 == nil {
			t.Errorf("case %v, %v is not a square mod p", i, test)
			continue
		}
		y3 := new(big.Int).Mul(y21, y21)
		y3.Mod(y3, p)
		if y2.Cmp(y3) != 0 {
			t.Error("Invalid sqrt")
		}
		tmp := fromBigInt(y2)
		tmp.Sqrt(tmp)
		montDecode(tmp, tmp)
		var res, want [32]byte
		tmp.Marshal(res[:])
		// Fixed-width encoding: y21.Bytes() would drop leading zero bytes.
		y21.FillBytes(want[:])
		if res != want {
			t.Errorf("case %v, got %v, expected %v\n", i, hex.EncodeToString(res[:]), hex.EncodeToString(want[:]))
		}
	}
}
// TestGeneratedSqrt checks the package-level Sqrt function against
// big.Int.ModSqrt for known quadratic residues mod p. The root is
// montDecode'd and marshaled, then compared byte-for-byte with the math/big
// root.
func TestGeneratedSqrt(t *testing.T) {
	tests := []string{
		"9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596",
		"92fe90b700fbd4d8cc177d300ed16e4e15471a681b2c9e3728c1b82c885e49c2",
	}
	for i, test := range tests {
		y2 := bigFromHex(test)
		y21 := new(big.Int).ModSqrt(y2, p)
		// ModSqrt returns nil for a non-residue; report it instead of
		// panicking on the nil dereference below.
		if y21 == nil {
			t.Errorf("case %v, %v is not a square mod p", i, test)
			continue
		}
		y3 := new(big.Int).Mul(y21, y21)
		y3.Mod(y3, p)
		if y2.Cmp(y3) != 0 {
			t.Error("Invalid sqrt")
		}
		tmp := fromBigInt(y2)
		e := &gfP{}
		Sqrt(e, tmp)
		montDecode(e, e)
		var res, want [32]byte
		e.Marshal(res[:])
		// Fixed-width encoding: y21.Bytes() would drop leading zero bytes.
		y21.FillBytes(want[:])
		if res != want {
			t.Errorf("case %v, got %v, expected %v\n", i, hex.EncodeToString(res[:]), hex.EncodeToString(want[:]))
		}
	}
}
// TestInvert checks that x multiplied by x.Invert() yields one.
func TestInvert(t *testing.T) {
	x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))

	var inverse, product gfP
	inverse.Invert(x)
	gfpMul(&product, x, &inverse)

	if product != *one {
		t.Errorf("got %v, expected %v", &product, one)
	}
}
// TestGfpNeg checks that gfpNeg(v) matches 0 - v computed with gfpSub, both
// for a nonzero element and for zero itself.
func TestGfpNeg(t *testing.T) {
	x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))

	for _, in := range []*gfP{x, zero} {
		got, expected := &gfP{}, &gfP{}
		gfpSub(got, zero, in)
		gfpNeg(expected, in)
		if *expected != *got {
			t.Errorf("got %v, expected %v", got, expected)
		}
	}
}
// BenchmarkGfPUnmarshal measures gfP.Unmarshal on a 32-byte buffer produced by
// gfP.Marshal.
func BenchmarkGfPUnmarshal(b *testing.B) {
	x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))
	// Build the input before resetting the timer so the one-off Marshal call
	// is excluded from the timing and allocation statistics.
	var out [32]byte
	x.Marshal(out[:])
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		x.Unmarshal(out[:])
	}
}
// BenchmarkGfPMul measures gfpMul of a fixed element by itself.
func BenchmarkGfPMul(b *testing.B) {
	x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))
	var out gfP
	b.ReportAllocs()
	b.ResetTimer()
	for n := b.N; n > 0; n-- {
		gfpMul(&out, x, x)
	}
}
// BenchmarkGfPSqr measures a single gfpSqr squaring (repetition count 1);
// compare with BenchmarkGfPMul.
func BenchmarkGfPSqr(b *testing.B) {
	x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))
	var out gfP
	b.ReportAllocs()
	b.ResetTimer()
	for n := b.N; n > 0; n-- {
		gfpSqr(&out, x, 1)
	}
}
// BenchmarkGfPTriple measures the dedicated gfpTriple routine; see
// BenchmarkGfPTriple2 for the two-addition baseline.
func BenchmarkGfPTriple(b *testing.B) {
	x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))
	var out gfP
	b.ReportAllocs()
	b.ResetTimer()
	for n := b.N; n > 0; n-- {
		gfpTriple(&out, x)
	}
}
// BenchmarkGfPTriple2 is the baseline for BenchmarkGfPTriple: 3x computed as
// (x + x) + x with two gfpAdd calls.
func BenchmarkGfPTriple2(b *testing.B) {
	x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))
	var out gfP
	b.ReportAllocs()
	b.ResetTimer()
	for n := b.N; n > 0; n-- {
		gfpAdd(&out, x, x)
		gfpAdd(&out, &out, x)
	}
}
// BenchmarkGfPDouble measures the dedicated gfpDouble routine; see
// BenchmarkGfPDouble2 for the gfpAdd baseline.
func BenchmarkGfPDouble(b *testing.B) {
	x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))
	var out gfP
	b.ReportAllocs()
	b.ResetTimer()
	for n := b.N; n > 0; n-- {
		gfpDouble(&out, x)
	}
}
// BenchmarkGfPDouble2 is the baseline for BenchmarkGfPDouble: 2x computed as
// x + x with gfpAdd.
func BenchmarkGfPDouble2(b *testing.B) {
	x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))
	var out gfP
	b.ReportAllocs()
	b.ResetTimer()
	for n := b.N; n > 0; n-- {
		gfpAdd(&out, x, x)
	}
}
// BenchmarkGfPNeg measures the dedicated gfpNeg routine; see BenchmarkGfPNeg2
// for the subtraction-from-zero baseline.
func BenchmarkGfPNeg(b *testing.B) {
	x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))
	var out gfP
	b.ReportAllocs()
	b.ResetTimer()
	for n := b.N; n > 0; n-- {
		gfpNeg(&out, x)
	}
}
// BenchmarkGfPNeg2 is the baseline for BenchmarkGfPNeg: -x computed as 0 - x
// with gfpSub.
func BenchmarkGfPNeg2(b *testing.B) {
	x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))
	var out gfP
	b.ReportAllocs()
	b.ResetTimer()
	for n := b.N; n > 0; n-- {
		gfpSub(&out, zero, x)
	}
}
// BenchmarkGfPInvert measures gfP.Invert; see BenchmarkGfPInvert2 for the
// exponentiation-by-pMinus2 baseline.
func BenchmarkGfPInvert(b *testing.B) {
	x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))
	var out gfP
	b.ReportAllocs()
	b.ResetTimer()
	for n := b.N; n > 0; n-- {
		out.Invert(x)
	}
}
// BenchmarkGfPInvert2 is the baseline for BenchmarkGfPInvert: inversion via
// gfP.exp with exponent pMinus2.
func BenchmarkGfPInvert2(b *testing.B) {
	x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))
	var out gfP
	b.ReportAllocs()
	b.ResetTimer()
	for n := b.N; n > 0; n-- {
		out.exp(x, pMinus2)
	}
}