pkcs8: fix gcm parameters asn1 issue

This commit is contained in:
Sun Yimin 2023-02-28 13:43:00 +08:00 committed by GitHub
parent a47ae96293
commit 617d2591d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 69 additions and 3 deletions

View File

@ -1,7 +1,9 @@
package sm2ec
import (
"bytes"
"encoding/hex"
"fmt"
"math/big"
"testing"
)
@ -159,3 +161,68 @@ func TestForSqrt(t *testing.T) {
exp := new(big.Int).Add(sm2Prime, big.NewInt(1))
exp.Div(exp, big.NewInt(4))
}
func TestScalarMult(t *testing.T) {
G := NewSM2P256Point().SetGenerator()
checkScalar := func(t *testing.T, scalar []byte) {
p1, err := NewSM2P256Point().ScalarBaseMult(scalar)
fatalIfErr(t, err)
p2, err := NewSM2P256Point().ScalarMult(G, scalar)
fatalIfErr(t, err)
if !bytes.Equal(p1.Bytes(), p2.Bytes()) {
t.Error("[k]G != ScalarBaseMult(k)")
}
d := new(big.Int).SetBytes(scalar)
d.Sub(sm2n, d)
d.Mod(d, sm2n)
g1, err := NewSM2P256Point().ScalarBaseMult(d.FillBytes(make([]byte, len(scalar))))
fatalIfErr(t, err)
g1.Add(g1, p1)
if !bytes.Equal(g1.Bytes(), NewSM2P256Point().Bytes()) {
t.Error("[N - k]G + [k]G != ∞")
}
}
byteLen := len(sm2n.Bytes())
bitLen := sm2n.BitLen()
t.Run("0", func(t *testing.T) { checkScalar(t, make([]byte, byteLen)) })
t.Run("1", func(t *testing.T) {
checkScalar(t, big.NewInt(1).FillBytes(make([]byte, byteLen)))
})
t.Run("N-1", func(t *testing.T) {
checkScalar(t, new(big.Int).Sub(sm2n, big.NewInt(1)).Bytes())
})
t.Run("N", func(t *testing.T) { checkScalar(t, sm2n.Bytes()) })
t.Run("N+1", func(t *testing.T) {
checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(1)).Bytes())
})
t.Run("all1s", func(t *testing.T) {
s := new(big.Int).Lsh(big.NewInt(1), uint(bitLen))
s.Sub(s, big.NewInt(1))
checkScalar(t, s.Bytes())
})
if testing.Short() {
return
}
for i := 0; i < bitLen; i++ {
t.Run(fmt.Sprintf("1<<%d", i), func(t *testing.T) {
s := new(big.Int).Lsh(big.NewInt(1), uint(i))
checkScalar(t, s.FillBytes(make([]byte, byteLen)))
})
}
// Test N+1...N+32 since they risk overlapping with precomputed table values
// in the final additions.
for i := int64(2); i <= 32; i++ {
t.Run(fmt.Sprintf("N+%d", i), func(t *testing.T) {
checkScalar(t, new(big.Int).Add(sm2n, big.NewInt(i)).Bytes())
})
}
}
func fatalIfErr(t *testing.T, err error) {
t.Helper()
if err != nil {
t.Fatal(err)
}
}

View File

@ -135,8 +135,7 @@ func (c cipherWithGCM) Encrypt(key, plaintext []byte) (*pkix.AlgorithmIdentifier
encryptionAlgorithm := pkix.AlgorithmIdentifier{
Algorithm: c.oid,
Parameters: asn1.RawValue{
Tag: asn1.TagSequence,
Bytes: paramBytes,
FullBytes: paramBytes,
},
}
return &encryptionAlgorithm, ciphertext, nil
@ -148,7 +147,7 @@ func (c cipherWithGCM) Decrypt(key []byte, parameters *asn1.RawValue, encryptedK
return nil, err
}
params := gcmParameters{}
_, err = asn1.Unmarshal(parameters.Bytes, &params)
_, err = asn1.Unmarshal(parameters.FullBytes, &params)
if err != nil {
return nil, err
}