diff --git a/internal/sm2ec/sm2ec_test.go b/internal/sm2ec/sm2ec_test.go index 220b1e5..9f6a3dc 100644 --- a/internal/sm2ec/sm2ec_test.go +++ b/internal/sm2ec/sm2ec_test.go @@ -1,7 +1,9 @@ package sm2ec import ( + "bytes" "encoding/hex" + "fmt" "math/big" "testing" ) @@ -159,3 +161,68 @@ func TestForSqrt(t *testing.T) { exp := new(big.Int).Add(sm2Prime, big.NewInt(1)) exp.Div(exp, big.NewInt(4)) } + +func TestScalarMult(t *testing.T) { + G := NewSM2P256Point().SetGenerator() + checkScalar := func(t *testing.T, scalar []byte) { + p1, err := NewSM2P256Point().ScalarBaseMult(scalar) + fatalIfErr(t, err) + p2, err := NewSM2P256Point().ScalarMult(G, scalar) + fatalIfErr(t, err) + if !bytes.Equal(p1.Bytes(), p2.Bytes()) { + t.Error("[k]G != ScalarBaseMult(k)") + } + + d := new(big.Int).SetBytes(scalar) + d.Sub(sm2n, d) + d.Mod(d, sm2n) + g1, err := NewSM2P256Point().ScalarBaseMult(d.FillBytes(make([]byte, len(scalar)))) + fatalIfErr(t, err) + g1.Add(g1, p1) + if !bytes.Equal(g1.Bytes(), NewSM2P256Point().Bytes()) { + t.Error("[N - k]G + [k]G != ∞") + } + } + + byteLen := len(sm2n.Bytes()) + bitLen := sm2n.BitLen() + t.Run("0", func(t *testing.T) { checkScalar(t, make([]byte, byteLen)) }) + t.Run("1", func(t *testing.T) { + checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen))) + }) + t.Run("N-1", func(t *testing.T) { + checkScalar(t, new(big.Int).Sub(sm2n, big.NewInt(1)).Bytes()) + }) + t.Run("N", func(t *testing.T) { checkScalar(t, sm2n.Bytes()) }) + t.Run("N+1", func(t *testing.T) { + checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(1)).Bytes()) + }) + t.Run("all1s", func(t *testing.T) { + s := new(big.Int).Lsh(big.NewInt(1), uint(bitLen)) + s.Sub(s, big.NewInt(1)) + checkScalar(t, s.Bytes()) + }) + if testing.Short() { + return + } + for i := 0; i < bitLen; i++ { + t.Run(fmt.Sprintf("1<<%d", i), func(t *testing.T) { + s := new(big.Int).Lsh(big.NewInt(1), uint(i)) + checkScalar(t, s.FillBytes(make([]byte, byteLen))) + }) + } + // Test N+1...N+32 since they risk overlapping with precomputed table values + // in the final additions. + for i := int64(2); i <= 32; i++ { + t.Run(fmt.Sprintf("N+%d", i), func(t *testing.T) { + checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(i)).Bytes()) + }) + } +} + +func fatalIfErr(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} diff --git a/pkcs8/cipher.go b/pkcs8/cipher.go index 87e4f89..135f0f4 100644 --- a/pkcs8/cipher.go +++ b/pkcs8/cipher.go @@ -135,8 +135,7 @@ func (c cipherWithGCM) Encrypt(key, plaintext []byte) (*pkix.AlgorithmIdentifier encryptionAlgorithm := pkix.AlgorithmIdentifier{ Algorithm: c.oid, Parameters: asn1.RawValue{ - Tag: asn1.TagSequence, - Bytes: paramBytes, + FullBytes: paramBytes, }, } return &encryptionAlgorithm, ciphertext, nil @@ -148,7 +147,7 @@ func (c cipherWithGCM) Decrypt(key []byte, parameters *asn1.RawValue, encryptedK return nil, err } params := gcmParameters{} - _, err = asn1.Unmarshal(parameters.Bytes, ¶ms) + _, err = asn1.Unmarshal(parameters.FullBytes, ¶ms) if err != nil { return nil, err }