sm9: clean code and unit test

This commit is contained in:
Sun Yimin 2022-07-19 08:58:12 +08:00 committed by GitHub
parent 24765d0e35
commit 711508985e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 111 additions and 113 deletions

View File

@ -1,6 +1,7 @@
package bn256 package bn256
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"io" "io"
"math/big" "math/big"
@ -141,6 +142,34 @@ func TestG1BaseMult(t *testing.T) {
} }
} }
func TestG1ScaleMult(t *testing.T) {
k, e, err := RandomG1(rand.Reader)
if err != nil {
t.Fatal(err)
}
e.p.MakeAffine()
e2, e3 := &G1{}, &G1{}
if e2.p == nil {
e2.p = &curvePoint{}
}
e2.p.Mul(curveGen, k)
e2.p.MakeAffine()
if !e.Equal(e2) {
t.Errorf("not same")
}
e3.ScalarMult(Gen1, k)
e3.p.MakeAffine()
if !e.Equal(e3) {
t.Errorf("not same")
}
}
func TestFuzz(t *testing.T) { func TestFuzz(t *testing.T) {
g1 := g1Curve g1 := g1Curve
g1Generic := g1.Params() g1Generic := g1.Params()
@ -277,21 +306,74 @@ func TestInfinity(t *testing.T) {
*/ */
} }
func testAllCurves(t *testing.T, f func(*testing.T, Curve)) {
tests := []struct {
name string
curve Curve
}{
{"g1", g1Curve},
{"g1/Params", g1Curve.params},
}
if testing.Short() {
tests = tests[:1]
}
for _, test := range tests {
curve := test.curve
t.Run(test.name, func(t *testing.T) {
t.Parallel()
f(t, curve)
})
}
}
func TestMarshal(t *testing.T) { func TestMarshal(t *testing.T) {
_, x, y, err := GenerateKey(g1Curve, rand.Reader) testAllCurves(t, func(t *testing.T, curve Curve) {
_, x, y, err := GenerateKey(curve, rand.Reader)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
serialized := Marshal(g1Curve, x, y) serialized := Marshal(curve, x, y)
xx, yy := Unmarshal(g1Curve, serialized) xx, yy := Unmarshal(curve, serialized)
if xx == nil { if xx == nil {
t.Fatal("failed to unmarshal") t.Fatal("failed to unmarshal")
} }
if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
t.Fatal("unmarshal returned different values") t.Fatal("unmarshal returned different values")
} }
})
} }
func TestMarshalCompressed(t *testing.T) {
testAllCurves(t, func(t *testing.T, curve Curve) {
_, x, y, err := GenerateKey(curve, rand.Reader)
if err != nil {
t.Fatal(err)
}
testMarshalCompressed(t, curve, x, y, nil)
})
}
func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) {
if !curve.IsOnCurve(x, y) {
t.Fatal("invalid test point")
}
got := MarshalCompressed(curve, x, y)
if want != nil && !bytes.Equal(got, want) {
t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want)
}
X, Y := UnmarshalCompressed(curve, got)
if X == nil || Y == nil {
t.Fatalf("UnmarshalCompressed failed unexpectedly")
}
if !curve.IsOnCurve(X, Y) {
t.Error("UnmarshalCompressed returned a point not on the curve")
}
if X.Cmp(x) != 0 || Y.Cmp(y) != 0 {
t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y)
}
}
func TestInvalidCoordinates(t *testing.T) { func TestInvalidCoordinates(t *testing.T) {
checkIsOnCurveFalse := func(name string, x, y *big.Int) { checkIsOnCurveFalse := func(name string, x, y *big.Int) {
if g1Curve.IsOnCurve(x, y) { if g1Curve.IsOnCurve(x, y) {

View File

@ -74,6 +74,25 @@ func Test_G2MarshalCompressed(t *testing.T) {
} }
} }
func TestScaleMult(t *testing.T) {
k, e, err := RandomG2(rand.Reader)
if err != nil {
t.Fatal(err)
}
e.p.MakeAffine()
e2, e3 := &G2{}, &G2{}
e3.p = &twistPoint{}
e3.p.Mul(twistGen, k)
e3.p.MakeAffine()
e2.ScalarMult(Gen2, k)
e2.p.MakeAffine()
if !e.Equal(e2) {
t.Errorf("not same")
}
}
func BenchmarkG2(b *testing.B) { func BenchmarkG2(b *testing.B) {
x, _ := rand.Int(rand.Reader, Order) x, _ := rand.Int(rand.Reader, Order)
b.ReportAllocs() b.ReportAllocs()

View File

@ -1,13 +1,10 @@
package bn256 package bn256
import ( import (
"crypto/sha256"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"math/big" "math/big"
"golang.org/x/crypto/hkdf"
) )
type gfP [4]uint64 type gfP [4]uint64
@ -58,32 +55,6 @@ func fromBigInt(x *big.Int) (out *gfP) {
return out return out
} }
// hashToBase implements hashing a message to an element of the field.
//
// L = ceil((256+128)/8)=48, ctr = 0, i = 1
func hashToBase(msg, dst []byte) *gfP {
var t [48]byte
info := []byte{'H', '2', 'C', byte(0), byte(1)}
r := hkdf.New(sha256.New, msg, dst, info)
if _, err := r.Read(t[:]); err != nil {
panic(err)
}
var x big.Int
v := x.SetBytes(t[:]).Mod(&x, p).Bytes()
v32 := [32]byte{}
for i := len(v) - 1; i >= 0; i-- {
v32[len(v)-1-i] = v[i]
}
u := &gfP{
binary.LittleEndian.Uint64(v32[0*8 : 1*8]),
binary.LittleEndian.Uint64(v32[1*8 : 2*8]),
binary.LittleEndian.Uint64(v32[2*8 : 3*8]),
binary.LittleEndian.Uint64(v32[3*8 : 4*8]),
}
montEncode(u, u)
return u
}
func (e *gfP) String() string { func (e *gfP) String() string {
return fmt.Sprintf("%16.16x%16.16x%16.16x%16.16x", e[3], e[2], e[1], e[0]) return fmt.Sprintf("%16.16x%16.16x%16.16x%16.16x", e[3], e[2], e[1], e[0])
} }
@ -165,47 +136,6 @@ func (e *gfP) Unmarshal(in []byte) error {
func montEncode(c, a *gfP) { gfpMul(c, a, r2) } func montEncode(c, a *gfP) { gfpMul(c, a, r2) }
func montDecode(c, a *gfP) { gfpMul(c, a, &gfP{1}) } func montDecode(c, a *gfP) { gfpMul(c, a, &gfP{1}) }
func sign0(e *gfP) int {
x := &gfP{}
montDecode(x, e)
for w := 3; w >= 0; w-- {
if x[w] > pMinus1Over2[w] {
return 1
} else if x[w] < pMinus1Over2[w] {
return -1
}
}
return 1
}
func legendre(e *gfP) int {
f := &gfP{}
// Since p = 8k+5, then e^(4k+2) is the Legendre symbol of e.
f.exp(e, pMinus1Over2)
montDecode(f, f)
if *f != [4]uint64{} {
return 2*int(f[0]&1) - 1
}
return 0
}
func (e *gfP) Div2(f *gfP) *gfP {
ret := &gfP{}
gfpMul(ret, f, twoInvert)
e.Set(ret)
return e
}
var twoInvert = &gfP{}
func init() {
t1 := newGFp(2)
twoInvert.Invert(t1)
}
// cmovznzU64 is a single-word conditional move. // cmovznzU64 is a single-word conditional move.
// //
// Postconditions: // Postconditions:

View File

@ -281,16 +281,6 @@ func (ret *gfP2) Sqrt(a *gfP2) *gfP2 {
return ret return ret
} }
// Div2 e = f / 2, not used currently
func (e *gfP2) Div2(f *gfP2) *gfP2 {
t := &gfP2{}
t.x.Div2(&f.x)
t.y.Div2(&f.y)
e.Set(t)
return e
}
// Select sets e to p1 if cond == 1, and to p2 if cond == 0. // Select sets e to p1 if cond == 1, and to p2 if cond == 0.
func (e *gfP2) Select(p1, p2 *gfP2, cond int) *gfP2 { func (e *gfP2) Select(p1, p2 *gfP2, cond int) *gfP2 {
e.x.Select(&p1.x, &p2.x, cond) e.x.Select(&p1.x, &p2.x, cond)

View File

@ -106,19 +106,6 @@ func Test_gfP2Frobenius(t *testing.T) {
} }
} }
func Test_gfP2Div2(t *testing.T) {
x := &gfP2{
*fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")),
*fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")),
}
ret := &gfP2{}
ret.Div2(x)
ret.Add(ret, ret)
if *ret != *x {
t.Errorf("got %v, expected %v", ret, x)
}
}
func Test_gfP2Sqrt(t *testing.T) { func Test_gfP2Sqrt(t *testing.T) {
x := &gfP2{ x := &gfP2{
*fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")),

View File

@ -113,13 +113,3 @@ func TestInvert(t *testing.T) {
t.Errorf("got %v, expected %v", y, one) t.Errorf("got %v, expected %v", y, one)
} }
} }
func TestDiv(t *testing.T) {
x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))
ret := &gfP{}
ret.Div2(x)
gfpAdd(ret, ret, ret)
if *ret != *x {
t.Errorf("got %v, expected %v", ret, x)
}
}