mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-26 20:26:19 +08:00
sm9: clean code and unit test
This commit is contained in:
parent
24765d0e35
commit
711508985e
@ -1,6 +1,7 @@
|
|||||||
package bn256
|
package bn256
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
@ -141,6 +142,34 @@ func TestG1BaseMult(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestG1ScaleMult(t *testing.T) {
|
||||||
|
k, e, err := RandomG1(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
e.p.MakeAffine()
|
||||||
|
|
||||||
|
e2, e3 := &G1{}, &G1{}
|
||||||
|
|
||||||
|
if e2.p == nil {
|
||||||
|
e2.p = &curvePoint{}
|
||||||
|
}
|
||||||
|
|
||||||
|
e2.p.Mul(curveGen, k)
|
||||||
|
e2.p.MakeAffine()
|
||||||
|
|
||||||
|
if !e.Equal(e2) {
|
||||||
|
t.Errorf("not same")
|
||||||
|
}
|
||||||
|
|
||||||
|
e3.ScalarMult(Gen1, k)
|
||||||
|
e3.p.MakeAffine()
|
||||||
|
|
||||||
|
if !e.Equal(e3) {
|
||||||
|
t.Errorf("not same")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestFuzz(t *testing.T) {
|
func TestFuzz(t *testing.T) {
|
||||||
g1 := g1Curve
|
g1 := g1Curve
|
||||||
g1Generic := g1.Params()
|
g1Generic := g1.Params()
|
||||||
@ -277,21 +306,74 @@ func TestInfinity(t *testing.T) {
|
|||||||
*/
|
*/
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func testAllCurves(t *testing.T, f func(*testing.T, Curve)) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
curve Curve
|
||||||
|
}{
|
||||||
|
{"g1", g1Curve},
|
||||||
|
{"g1/Params", g1Curve.params},
|
||||||
|
}
|
||||||
|
if testing.Short() {
|
||||||
|
tests = tests[:1]
|
||||||
|
}
|
||||||
|
for _, test := range tests {
|
||||||
|
curve := test.curve
|
||||||
|
t.Run(test.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
f(t, curve)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestMarshal(t *testing.T) {
|
func TestMarshal(t *testing.T) {
|
||||||
_, x, y, err := GenerateKey(g1Curve, rand.Reader)
|
testAllCurves(t, func(t *testing.T, curve Curve) {
|
||||||
|
_, x, y, err := GenerateKey(curve, rand.Reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
serialized := Marshal(g1Curve, x, y)
|
serialized := Marshal(curve, x, y)
|
||||||
xx, yy := Unmarshal(g1Curve, serialized)
|
xx, yy := Unmarshal(curve, serialized)
|
||||||
if xx == nil {
|
if xx == nil {
|
||||||
t.Fatal("failed to unmarshal")
|
t.Fatal("failed to unmarshal")
|
||||||
}
|
}
|
||||||
if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
|
if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 {
|
||||||
t.Fatal("unmarshal returned different values")
|
t.Fatal("unmarshal returned different values")
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMarshalCompressed(t *testing.T) {
|
||||||
|
testAllCurves(t, func(t *testing.T, curve Curve) {
|
||||||
|
_, x, y, err := GenerateKey(curve, rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
testMarshalCompressed(t, curve, x, y, nil)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func testMarshalCompressed(t *testing.T, curve Curve, x, y *big.Int, want []byte) {
|
||||||
|
if !curve.IsOnCurve(x, y) {
|
||||||
|
t.Fatal("invalid test point")
|
||||||
|
}
|
||||||
|
got := MarshalCompressed(curve, x, y)
|
||||||
|
if want != nil && !bytes.Equal(got, want) {
|
||||||
|
t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
X, Y := UnmarshalCompressed(curve, got)
|
||||||
|
if X == nil || Y == nil {
|
||||||
|
t.Fatalf("UnmarshalCompressed failed unexpectedly")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !curve.IsOnCurve(X, Y) {
|
||||||
|
t.Error("UnmarshalCompressed returned a point not on the curve")
|
||||||
|
}
|
||||||
|
if X.Cmp(x) != 0 || Y.Cmp(y) != 0 {
|
||||||
|
t.Errorf("point did not round-trip correctly: got (%v, %v), want (%v, %v)", X, Y, x, y)
|
||||||
|
}
|
||||||
|
}
|
||||||
func TestInvalidCoordinates(t *testing.T) {
|
func TestInvalidCoordinates(t *testing.T) {
|
||||||
checkIsOnCurveFalse := func(name string, x, y *big.Int) {
|
checkIsOnCurveFalse := func(name string, x, y *big.Int) {
|
||||||
if g1Curve.IsOnCurve(x, y) {
|
if g1Curve.IsOnCurve(x, y) {
|
||||||
|
@ -74,6 +74,25 @@ func Test_G2MarshalCompressed(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestScaleMult(t *testing.T) {
|
||||||
|
k, e, err := RandomG2(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
e.p.MakeAffine()
|
||||||
|
|
||||||
|
e2, e3 := &G2{}, &G2{}
|
||||||
|
e3.p = &twistPoint{}
|
||||||
|
e3.p.Mul(twistGen, k)
|
||||||
|
e3.p.MakeAffine()
|
||||||
|
|
||||||
|
e2.ScalarMult(Gen2, k)
|
||||||
|
e2.p.MakeAffine()
|
||||||
|
if !e.Equal(e2) {
|
||||||
|
t.Errorf("not same")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkG2(b *testing.B) {
|
func BenchmarkG2(b *testing.B) {
|
||||||
x, _ := rand.Int(rand.Reader, Order)
|
x, _ := rand.Int(rand.Reader, Order)
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
@ -1,13 +1,10 @@
|
|||||||
package bn256
|
package bn256
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
|
||||||
"golang.org/x/crypto/hkdf"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type gfP [4]uint64
|
type gfP [4]uint64
|
||||||
@ -58,32 +55,6 @@ func fromBigInt(x *big.Int) (out *gfP) {
|
|||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
// hashToBase implements hashing a message to an element of the field.
|
|
||||||
//
|
|
||||||
// L = ceil((256+128)/8)=48, ctr = 0, i = 1
|
|
||||||
func hashToBase(msg, dst []byte) *gfP {
|
|
||||||
var t [48]byte
|
|
||||||
info := []byte{'H', '2', 'C', byte(0), byte(1)}
|
|
||||||
r := hkdf.New(sha256.New, msg, dst, info)
|
|
||||||
if _, err := r.Read(t[:]); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
var x big.Int
|
|
||||||
v := x.SetBytes(t[:]).Mod(&x, p).Bytes()
|
|
||||||
v32 := [32]byte{}
|
|
||||||
for i := len(v) - 1; i >= 0; i-- {
|
|
||||||
v32[len(v)-1-i] = v[i]
|
|
||||||
}
|
|
||||||
u := &gfP{
|
|
||||||
binary.LittleEndian.Uint64(v32[0*8 : 1*8]),
|
|
||||||
binary.LittleEndian.Uint64(v32[1*8 : 2*8]),
|
|
||||||
binary.LittleEndian.Uint64(v32[2*8 : 3*8]),
|
|
||||||
binary.LittleEndian.Uint64(v32[3*8 : 4*8]),
|
|
||||||
}
|
|
||||||
montEncode(u, u)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *gfP) String() string {
|
func (e *gfP) String() string {
|
||||||
return fmt.Sprintf("%16.16x%16.16x%16.16x%16.16x", e[3], e[2], e[1], e[0])
|
return fmt.Sprintf("%16.16x%16.16x%16.16x%16.16x", e[3], e[2], e[1], e[0])
|
||||||
}
|
}
|
||||||
@ -165,47 +136,6 @@ func (e *gfP) Unmarshal(in []byte) error {
|
|||||||
func montEncode(c, a *gfP) { gfpMul(c, a, r2) }
|
func montEncode(c, a *gfP) { gfpMul(c, a, r2) }
|
||||||
func montDecode(c, a *gfP) { gfpMul(c, a, &gfP{1}) }
|
func montDecode(c, a *gfP) { gfpMul(c, a, &gfP{1}) }
|
||||||
|
|
||||||
func sign0(e *gfP) int {
|
|
||||||
x := &gfP{}
|
|
||||||
montDecode(x, e)
|
|
||||||
for w := 3; w >= 0; w-- {
|
|
||||||
if x[w] > pMinus1Over2[w] {
|
|
||||||
return 1
|
|
||||||
} else if x[w] < pMinus1Over2[w] {
|
|
||||||
return -1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func legendre(e *gfP) int {
|
|
||||||
f := &gfP{}
|
|
||||||
// Since p = 8k+5, then e^(4k+2) is the Legendre symbol of e.
|
|
||||||
f.exp(e, pMinus1Over2)
|
|
||||||
|
|
||||||
montDecode(f, f)
|
|
||||||
|
|
||||||
if *f != [4]uint64{} {
|
|
||||||
return 2*int(f[0]&1) - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *gfP) Div2(f *gfP) *gfP {
|
|
||||||
ret := &gfP{}
|
|
||||||
gfpMul(ret, f, twoInvert)
|
|
||||||
e.Set(ret)
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
|
|
||||||
var twoInvert = &gfP{}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
t1 := newGFp(2)
|
|
||||||
twoInvert.Invert(t1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// cmovznzU64 is a single-word conditional move.
|
// cmovznzU64 is a single-word conditional move.
|
||||||
//
|
//
|
||||||
// Postconditions:
|
// Postconditions:
|
||||||
|
@ -281,16 +281,6 @@ func (ret *gfP2) Sqrt(a *gfP2) *gfP2 {
|
|||||||
return ret
|
return ret
|
||||||
}
|
}
|
||||||
|
|
||||||
// Div2 e = f / 2, not used currently
|
|
||||||
func (e *gfP2) Div2(f *gfP2) *gfP2 {
|
|
||||||
t := &gfP2{}
|
|
||||||
t.x.Div2(&f.x)
|
|
||||||
t.y.Div2(&f.y)
|
|
||||||
|
|
||||||
e.Set(t)
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
|
|
||||||
// Select sets e to p1 if cond == 1, and to p2 if cond == 0.
|
// Select sets e to p1 if cond == 1, and to p2 if cond == 0.
|
||||||
func (e *gfP2) Select(p1, p2 *gfP2, cond int) *gfP2 {
|
func (e *gfP2) Select(p1, p2 *gfP2, cond int) *gfP2 {
|
||||||
e.x.Select(&p1.x, &p2.x, cond)
|
e.x.Select(&p1.x, &p2.x, cond)
|
||||||
|
@ -106,19 +106,6 @@ func Test_gfP2Frobenius(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_gfP2Div2(t *testing.T) {
|
|
||||||
x := &gfP2{
|
|
||||||
*fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")),
|
|
||||||
*fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")),
|
|
||||||
}
|
|
||||||
ret := &gfP2{}
|
|
||||||
ret.Div2(x)
|
|
||||||
ret.Add(ret, ret)
|
|
||||||
if *ret != *x {
|
|
||||||
t.Errorf("got %v, expected %v", ret, x)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_gfP2Sqrt(t *testing.T) {
|
func Test_gfP2Sqrt(t *testing.T) {
|
||||||
x := &gfP2{
|
x := &gfP2{
|
||||||
*fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")),
|
*fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")),
|
||||||
|
@ -113,13 +113,3 @@ func TestInvert(t *testing.T) {
|
|||||||
t.Errorf("got %v, expected %v", y, one)
|
t.Errorf("got %v, expected %v", y, one)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDiv(t *testing.T) {
|
|
||||||
x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))
|
|
||||||
ret := &gfP{}
|
|
||||||
ret.Div2(x)
|
|
||||||
gfpAdd(ret, ret, ret)
|
|
||||||
if *ret != *x {
|
|
||||||
t.Errorf("got %v, expected %v", ret, x)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user