package sm2ec import ( "bytes" "crypto/rand" "encoding/hex" "fmt" "math/big" "testing" ) // r = 2^256 var r = bigFromHex("010000000000000000000000000000000000000000000000000000000000000000") var r0 = bigFromHex("010000000000000000") var sm2Prime = bigFromHex("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF") var sm2n = bigFromHex("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFF7203DF6B21C6052B53BBF40939D54123") var nistP256Prime = bigFromDecimal("115792089210356248762697446949407573530086143415290314195533631308867097853951") var nistP256N = bigFromDecimal("115792089210356248762697446949407573529996955224135760342422259061068512044369") func generateMontgomeryDomain(in *big.Int, p *big.Int) *big.Int { tmp := new(big.Int) tmp = tmp.Mul(in, r) return tmp.Mod(tmp, p) } func bigFromHex(s string) *big.Int { b, ok := new(big.Int).SetString(s, 16) if !ok { panic("sm2ec: internal error: invalid encoding") } return b } func bigFromDecimal(s string) *big.Int { b, ok := new(big.Int).SetString(s, 10) if !ok { panic("sm2ec: internal error: invalid encoding") } return b } func TestSM2P256MontgomeryDomain(t *testing.T) { tests := []struct { in string out string }{ { // One "01", "0000000100000000000000000000000000000000ffffffff0000000000000001", }, { // Gx "32C4AE2C1F1981195F9904466A39C9948FE30BBFF2660BE1715A4589334C74C7", "91167a5ee1c13b05d6a1ed99ac24c3c33e7981eddca6c05061328990f418029e", }, { // Gy "BC3736A2F4F6779C59BDCEE36B692153D0A9877CC62A474002DF32E52139F0A0", "63cd65d481d735bd8d4cfb066e2a48f8c1f5e5788d3295fac1354e593c2d0ddd", }, { // B "28E9FA9E9D9F5E344D5A9E4BCF6509A7F39789F515AB8F92DDBCBD414D940E93", "240fe188ba20e2c8527981505ea51c3c71cf379ae9b537ab90d230632bc0dd42", }, { // R "010000000000000000000000000000000000000000000000000000000000000000", "0400000002000000010000000100000002ffffffff0000000200000003", }, } for _, test := range tests { out := generateMontgomeryDomain(bigFromHex(test.in), sm2Prime) if out.Cmp(bigFromHex(test.out)) != 0 { t.Errorf("expected %v, got %v", test.out, hex.EncodeToString(out.Bytes())) } } } func TestSM2P256MontgomeryDomainN(t *testing.T) { tests := []struct { in string out string }{ { // One "01", "010000000000000000000000008dfc2094de39fad4ac440bf6c62abedd", }, { // R "010000000000000000000000000000000000000000000000000000000000000000", "1eb5e412a22b3d3b620fc84c3affe0d43464504ade6fa2fa901192af7c114f20", }, } for _, test := range tests { out := generateMontgomeryDomain(bigFromHex(test.in), sm2n) if out.Cmp(bigFromHex(test.out)) != 0 { t.Errorf("expected %v, got %v", test.out, hex.EncodeToString(out.Bytes())) } } } func TestSM2P256MontgomeryK0(t *testing.T) { tests := []struct { in *big.Int out string }{ { sm2n, "327f9e8872350975", }, { sm2Prime, "0000000000000001", }, } for _, test := range tests { // k0 = -in^(-1) mod 2^64 k0 := new(big.Int).ModInverse(test.in, r0) k0.Neg(k0) k0.Mod(k0, r0) if k0.Cmp(bigFromHex(test.out)) != 0 { t.Errorf("expected %v, got %v", test.out, hex.EncodeToString(k0.Bytes())) } } } func TestNISTP256MontgomeryDomain(t *testing.T) { tests := []struct { in string out string }{ { // One "01", "fffffffeffffffffffffffffffffffff000000000000000000000001", }, { // Gx "6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296", "18905f76a53755c679fb732b7762251075ba95fc5fedb60179e730d418a9143c", }, { // Gy "4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5", "8571ff1825885d85d2e88688dd21f3258b4ab8e4ba19e45cddf25357ce95560a", }, { // B "5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b", "dc30061d04874834e5a220abf7212ed6acf005cd78843090d89cdf6229c4bddf", }, { // R "010000000000000000000000000000000000000000000000000000000000000000", "04fffffffdfffffffffffffffefffffffbffffffff0000000000000003", }, } for _, test := range tests { out := generateMontgomeryDomain(bigFromHex(test.in), nistP256Prime) if out.Cmp(bigFromHex(test.out)) != 0 { t.Errorf("expected %v, got %v", test.out, hex.EncodeToString(out.Bytes())) } } } func TestForSqrt(t *testing.T) { mod4 := new(big.Int).Mod(sm2Prime, big.NewInt(4)) if mod4.Cmp(big.NewInt(3)) != 0 { t.Fatal("sm2 prime is not fulfill 3 mod 4") } exp := new(big.Int).Add(sm2Prime, big.NewInt(1)) exp.Div(exp, big.NewInt(4)) } func TestEquivalents(t *testing.T) { p := NewSM2P256Point().SetGenerator() elementSize := 32 two := make([]byte, elementSize) two[len(two)-1] = 2 nPlusTwo := make([]byte, elementSize) new(big.Int).Add(sm2n, big.NewInt(2)).FillBytes(nPlusTwo) p1 := NewSM2P256Point().Double(p) p2 := NewSM2P256Point().Add(p, p) p3, err := NewSM2P256Point().ScalarMult(p, two) fatalIfErr(t, err) p4, err := NewSM2P256Point().ScalarBaseMult(two) fatalIfErr(t, err) p5, err := NewSM2P256Point().ScalarMult(p, nPlusTwo) fatalIfErr(t, err) p6, err := NewSM2P256Point().ScalarBaseMult(nPlusTwo) fatalIfErr(t, err) if !bytes.Equal(p1.Bytes(), p2.Bytes()) { t.Error("P+P != 2*P") } if !bytes.Equal(p1.Bytes(), p3.Bytes()) { t.Error("P+P != [2]P") } if !bytes.Equal(p1.Bytes(), p4.Bytes()) { t.Error("G+G != [2]G") } if !bytes.Equal(p1.Bytes(), p5.Bytes()) { t.Error("P+P != [N+2]P") } if !bytes.Equal(p1.Bytes(), p6.Bytes()) { t.Error("G+G != [N+2]G") } } func TestScalarMult(t *testing.T) { G := NewSM2P256Point().SetGenerator() checkScalar := func(t *testing.T, scalar []byte) { p1, err := NewSM2P256Point().ScalarBaseMult(scalar) fatalIfErr(t, err) p2, err := NewSM2P256Point().ScalarMult(G, scalar) fatalIfErr(t, err) if !bytes.Equal(p1.Bytes(), p2.Bytes()) { t.Error("[k]G != ScalarBaseMult(k)") } d := new(big.Int).SetBytes(scalar) d.Sub(sm2n, d) d.Mod(d, sm2n) g1, err := NewSM2P256Point().ScalarBaseMult(d.FillBytes(make([]byte, len(scalar)))) fatalIfErr(t, err) g1.Add(g1, p1) if !bytes.Equal(g1.Bytes(), NewSM2P256Point().Bytes()) { t.Error("[N - k]G + [k]G != ∞") } } byteLen := len(sm2n.Bytes()) bitLen := sm2n.BitLen() t.Run("0", func(t *testing.T) { checkScalar(t, make([]byte, byteLen)) }) t.Run("1", func(t *testing.T) { checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen))) }) t.Run("N-6", func(t *testing.T) { checkScalar(t, new(big.Int).Sub(sm2n, big.NewInt(6)).Bytes()) }) t.Run("N-1", func(t *testing.T) { checkScalar(t, new(big.Int).Sub(sm2n, big.NewInt(1)).Bytes()) }) t.Run("N", func(t *testing.T) { checkScalar(t, sm2n.Bytes()) }) t.Run("N+1", func(t *testing.T) { checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(1)).Bytes()) }) t.Run("N+58", func(t *testing.T) { checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(58)).Bytes()) }) t.Run("all1s", func(t *testing.T) { s := new(big.Int).Lsh(big.NewInt(1), uint(bitLen)) s.Sub(s, big.NewInt(1)) checkScalar(t, s.Bytes()) }) if testing.Short() { return } for i := 0; i < bitLen; i++ { t.Run(fmt.Sprintf("1<<%d", i), func(t *testing.T) { s := new(big.Int).Lsh(big.NewInt(1), uint(i)) checkScalar(t, s.FillBytes(make([]byte, byteLen))) }) } for i := 0; i <= 64; i++ { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { checkScalar(t, big.NewInt(int64(i)).FillBytes(make([]byte, byteLen))) }) } // Test N-64...N+64 since they risk overlapping with precomputed table values // in the final additions. for i := int64(-64); i <= 64; i++ { t.Run(fmt.Sprintf("N%+d", i), func(t *testing.T) { checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(i)).Bytes()) }) } } func fatalIfErr(t *testing.T, err error) { t.Helper() if err != nil { t.Fatal(err) } } func BenchmarkScalarBaseMult(b *testing.B) { p := NewSM2P256Point().SetGenerator() scalar := make([]byte, 32) rand.Read(scalar) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { p.ScalarBaseMult(scalar) } } func BenchmarkScalarMult(b *testing.B) { p := NewSM2P256Point().SetGenerator() scalar := make([]byte, 32) rand.Read(scalar) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { p.ScalarMult(p, scalar) } }