pkcs: supplement test cases

This commit is contained in:
Sun Yimin 2024-07-10 14:47:27 +08:00 committed by GitHub
parent d5b39e6176
commit ba1836fa45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 373 additions and 39 deletions

View File

@ -108,15 +108,15 @@ func newPRFParamFromHash(h Hash) (pkix.AlgorithmIdentifier, error) {
return pkix.AlgorithmIdentifier{}, errors.New("pbes/pbkdf2: unsupported hash function")
}
// PBKDF2-params ::= SEQUENCE {
// salt CHOICE {
// specified OCTET STRING,
// otherSource AlgorithmIdentifier {{PBKDF2-SaltSources}}
// },
// iterationCount INTEGER (1..MAX),
// keyLength INTEGER (1..MAX) OPTIONAL,
// prf AlgorithmIdentifier {{PBKDF2-PRFs}} DEFAULT algid-hmacWithSHA1
//}
// PBKDF2-params ::= SEQUENCE {
// salt CHOICE {
// specified OCTET STRING,
// otherSource AlgorithmIdentifier {{PBKDF2-SaltSources}}
// },
// iterationCount INTEGER (1..MAX),
// keyLength INTEGER (1..MAX) OPTIONAL,
// prf AlgorithmIdentifier {{PBKDF2-PRFs}} DEFAULT algid-hmacWithSHA1
// }
type pbkdf2Params struct {
Salt []byte
IterationCount int
@ -132,6 +132,11 @@ func (p pbkdf2Params) DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, s
return pbkdf2.Key(password, p.Salt, p.IterationCount, size, h), nil
}
// KeyLength returns the length of the derived key.
func (p pbkdf2Params) KeyLength() int {
return p.KeyLen
}
// PBKDF2Opts contains options for the PBKDF2 key derivation function.
type PBKDF2Opts struct {
SaltSize int

79
pkcs/kdf_pbkdf2_test.go Normal file
View File

@ -0,0 +1,79 @@
package pkcs
import (
"bytes"
"crypto/sha1"
"crypto/x509/pkix"
"testing"
"github.com/emmansun/gmsm/sm3"
)
func TestNewHashFromPRF(t *testing.T) {
h, err := newHashFromPRF(oidPKCS5PBKDF2, pkix.AlgorithmIdentifier{})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
hash := h()
if hash.Size() != sha1.Size {
t.Errorf("unexpected hash size: got %d, want %d", hash.Size(), sha1.Size)
}
h, err = newHashFromPRF(oidSMPBKDF, pkix.AlgorithmIdentifier{})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
hash = h()
if hash.Size() != sm3.Size {
t.Errorf("unexpected hash size: got %d, want %d", hash.Size(), sm3.Size)
}
}
func TestPBKDF2DeriveKey(t *testing.T) {
testCases := []struct {
name string
opts PBKDF2Opts
}{
{
name: "PBKDF2-SHA1",
opts: NewPBKDF2Opts(SHA1, 16, 1000),
},
{
name: "PBKDF2-SHA256",
opts: NewPBKDF2Opts(SHA256, 16, 1000),
},
{
name: "SMPBKDF2",
opts: NewSMPBKDF2Opts(16, 1000),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
key, params, err := tc.opts.DeriveKey([]byte("password"), []byte("saltsaltsaltsalt"), 32)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if len(key) != 32 {
t.Errorf("unexpected key length: got %d, want 32", len(key))
}
if params.KeyLength() != 32 {
t.Errorf("unexpected key length: got %d, want 32", params.KeyLength())
}
if len(params.(pbkdf2Params).Salt) != tc.opts.SaltSize {
t.Errorf("unexpected salt length: got %d, want %d", len(params.(pbkdf2Params).Salt), tc.opts.SaltSize)
}
if params.(pbkdf2Params).IterationCount != tc.opts.IterationCount {
t.Errorf("unexpected iteration count: got %d, want %d", params.(pbkdf2Params).IterationCount, tc.opts.IterationCount)
}
if params.(pbkdf2Params).KeyLen != 32 {
t.Errorf("unexpected key length: got %d, want 32", params.(pbkdf2Params).KeyLen)
}
key2, err := params.DeriveKey(nil, []byte("password"), 32)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if !bytes.Equal(key, key2) {
t.Errorf("unexpected key: got %x, want %x", key2, key)
}
})
}
}

View File

@ -25,6 +25,7 @@ type scryptParams struct {
CostParameter int
BlockSize int
ParallelizationParameter int
KeyLen int `asn1:"optional"`
}
func (p scryptParams) DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, size int) (key []byte, err error) {
@ -32,6 +33,10 @@ func (p scryptParams) DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, s
p.ParallelizationParameter, size)
}
func (p scryptParams) KeyLength() int {
return p.KeyLen
}
// ScryptOpts contains options for the scrypt key derivation function.
type ScryptOpts struct {
SaltSize int
@ -63,6 +68,7 @@ func (p ScryptOpts) DeriveKey(password, salt []byte, size int) (
CostParameter: p.CostParameter,
ParallelizationParameter: p.ParallelizationParameter,
Salt: salt,
KeyLen: size,
}
return key, params, nil
}

39
pkcs/kdf_scrypt_test.go Normal file
View File

@ -0,0 +1,39 @@
package pkcs
import (
"bytes"
"testing"
)
func TestScryptDeriveKey(t *testing.T) {
opts := NewScryptOpts(8, 16384, 8, 1)
key, params, err := opts.DeriveKey([]byte("password"), []byte("saltsalt"), 32)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if len(key) != 32 {
t.Errorf("unexpected key length: got %d, want 32", len(key))
}
if params.KeyLength() != 32 {
t.Errorf("unexpected key length: got %d, want 32", params.KeyLength())
}
if len(params.(scryptParams).Salt) != opts.SaltSize {
t.Errorf("unexpected salt length: got %d, want %d", len(params.(scryptParams).Salt), opts.SaltSize)
}
if params.(scryptParams).CostParameter != opts.CostParameter {
t.Errorf("unexpected cost parameter: got %d, want %d", params.(scryptParams).CostParameter, opts.CostParameter)
}
if params.(scryptParams).BlockSize != opts.BlockSize {
t.Errorf("unexpected block size: got %d, want %d", params.(scryptParams).BlockSize, opts.BlockSize)
}
if params.(scryptParams).ParallelizationParameter != opts.ParallelizationParameter {
t.Errorf("unexpected parallelization parameter: got %d, want %d", params.(scryptParams).ParallelizationParameter, opts.ParallelizationParameter)
}
key2, err := params.DeriveKey(nil, []byte("password"), 32)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if !bytes.Equal(key, key2) {
t.Errorf("unexpected key: got %x, want %x", key2, key)
}
}

View File

@ -9,6 +9,7 @@ import (
"encoding/asn1"
"errors"
"hash"
"io"
"github.com/emmansun/gmsm/pkcs/internal/md2"
"github.com/emmansun/gmsm/pkcs/internal/rc2"
@ -33,8 +34,51 @@ type PBES1 struct {
Algorithm pkix.AlgorithmIdentifier
}
// newPBES1 creates a new PBES1 instance.
func newPBES1(rand io.Reader, oid asn1.ObjectIdentifier, saltLen, iterations int) (*PBES1, error) {
salt := make([]byte, saltLen)
if _, err := rand.Read(salt); err != nil {
return nil, err
}
param := pbeParameter{Salt: salt, Iteration: iterations}
marshalledParams, err := asn1.Marshal(param)
if err != nil {
return nil, err
}
return &PBES1{
Algorithm: pkix.AlgorithmIdentifier{
Algorithm: oid,
Parameters: asn1.RawValue{FullBytes: marshalledParams},
},
}, nil
}
func NewPbeWithMD2AndDESCBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) {
return newPBES1(rand, pbeWithMD2AndDESCBC, saltLen, iterations)
}
func NewPbeWithMD2AndRC2CBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) {
return newPBES1(rand, pbeWithMD2AndRC2CBC, saltLen, iterations)
}
func NewPbeWithMD5AndDESCBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) {
return newPBES1(rand, pbeWithMD5AndDESCBC, saltLen, iterations)
}
func NewPbeWithMD5AndRC2CBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) {
return newPBES1(rand, pbeWithMD5AndRC2CBC, saltLen, iterations)
}
func NewPbeWithSHA1AndDESCBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) {
return newPBES1(rand, pbeWithSHA1AndDESCBC, saltLen, iterations)
}
func NewPbeWithSHA1AndRC2CBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) {
return newPBES1(rand, pbeWithSHA1AndRC2CBC, saltLen, iterations)
}
// Key returns the key derived from the password according PBKDF1.
func (pbes1 *PBES1) Key(password []byte) ([]byte, error) {
func (pbes1 *PBES1) key(password []byte) ([]byte, error) {
param := new(pbeParameter)
if _, err := asn1.Unmarshal(pbes1.Algorithm.Parameters.FullBytes, param); err != nil {
return nil, err
@ -61,24 +105,45 @@ func (pbes1 *PBES1) Key(password []byte) ([]byte, error) {
return key, nil
}
func (pbes1 *PBES1) Decrypt(password, ciphertext []byte) ([]byte, KDFParameters, error) {
key, err := pbes1.Key(password)
if err != nil {
return nil, nil, err
}
func (pbes1 *PBES1) newBlock(key []byte) (cipher.Block, error) {
var block cipher.Block
switch {
case pbes1.Algorithm.Algorithm.Equal(pbeWithMD2AndDESCBC) ||
pbes1.Algorithm.Algorithm.Equal(pbeWithMD5AndDESCBC) ||
pbes1.Algorithm.Algorithm.Equal(pbeWithSHA1AndDESCBC):
block, err = des.NewCipher(key[:8])
block, _ = des.NewCipher(key[:8])
case pbes1.Algorithm.Algorithm.Equal(pbeWithMD2AndRC2CBC) ||
pbes1.Algorithm.Algorithm.Equal(pbeWithMD5AndRC2CBC) ||
pbes1.Algorithm.Algorithm.Equal(pbeWithSHA1AndRC2CBC):
block, err = rc2.NewCipher(key[:8])
block, _ = rc2.NewCipher(key[:8])
default:
return nil, nil, errors.New("pbes: unsupported pbes1 cipher")
return nil, errors.New("pbes: unsupported pbes1 cipher")
}
return block, nil
}
func (pbes1 *PBES1) Encrypt(rand io.Reader, password, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
key, err := pbes1.key(password)
if err != nil {
return nil, nil, err
}
block, err := pbes1.newBlock(key)
if err != nil {
return nil, nil, err
}
ciphertext, err := cbcEncrypt(block, key[8:16], plaintext)
if err != nil {
return nil, nil, err
}
return &pbes1.Algorithm, ciphertext, nil
}
func (pbes1 *PBES1) Decrypt(password, ciphertext []byte) ([]byte, KDFParameters, error) {
key, err := pbes1.key(password)
if err != nil {
return nil, nil, err
}
block, err := pbes1.newBlock(key)
if err != nil {
return nil, nil, err
}

63
pkcs/pkcs5_pbes1_test.go Normal file
View File

@ -0,0 +1,63 @@
package pkcs
import (
"crypto/rand"
"testing"
)
func TestPBES1(t *testing.T) {
var testCases []*PBES1
pbes1, err := NewPbeWithMD2AndDESCBC(rand.Reader, 8, 1000)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
testCases = append(testCases, pbes1)
pbes1, err = NewPbeWithMD2AndRC2CBC(rand.Reader, 8, 1000)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
testCases = append(testCases, pbes1)
pbes1, err = NewPbeWithMD5AndDESCBC(rand.Reader, 8, 1000)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
testCases = append(testCases, pbes1)
pbes1, err = NewPbeWithMD5AndRC2CBC(rand.Reader, 8, 1000)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
testCases = append(testCases, pbes1)
pbes1, err = NewPbeWithSHA1AndDESCBC(rand.Reader, 8, 1000)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
testCases = append(testCases, pbes1)
pbes1, err = NewPbeWithSHA1AndRC2CBC(rand.Reader, 8, 1000)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
testCases = append(testCases, pbes1)
for _, pbes1 := range testCases {
t.Run("", func(t *testing.T) {
_, ciphertext, err := pbes1.Encrypt(rand.Reader, []byte("password"), []byte("pbes1"))
if err != nil {
t.Errorf("unexpected error: %v", err)
}
plaintext, _, err := pbes1.Decrypt([]byte("password"), ciphertext)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if string(plaintext) != "pbes1" {
t.Errorf("unexpected plaintext: got %s, want password", plaintext)
}
})
}
}

View File

@ -65,10 +65,11 @@ var (
)
// PBKDF2Opts contains algorithm identifiers and related parameters for PBKDF2 key derivation function.
// PBES2-params ::= SEQUENCE {
// keyDerivationFunc AlgorithmIdentifier {{PBES2-KDFs}},
// encryptionScheme AlgorithmIdentifier {{PBES2-Encs}}
// }
//
// PBES2-params ::= SEQUENCE {
// keyDerivationFunc AlgorithmIdentifier {{PBES2-KDFs}},
// encryptionScheme AlgorithmIdentifier {{PBES2-Encs}}
// }
type PBES2Params struct {
KeyDerivationFunc pkix.AlgorithmIdentifier
EncryptionScheme pkix.AlgorithmIdentifier
@ -140,6 +141,8 @@ type KDFParameters interface {
// DeriveKey derives a key of size bytes from the given password.
// It uses the salt from the decoded parameters.
DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, size int) (key []byte, err error)
// KeyLength returns the length of the derived key from the params.
KeyLength() int
}
var kdfs = make(map[string]func() KDFParameters)
@ -189,39 +192,43 @@ func (pbes2Params *PBES2Params) Decrypt(password, ciphertext []byte) ([]byte, KD
return plaintext, kdfParams, nil
}
// Encrypt encrypts the given plaintext using the given password and the options specified.
func (opts *PBES2Opts) Encrypt(rand io.Reader, password, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
func deriveKey(kdfOpts KDFOpts, rand io.Reader, password []byte, size int) ([]byte, *pkix.AlgorithmIdentifier, error) {
// Generate a random salt
salt := make([]byte, opts.KDFOpts.GetSaltSize())
salt := make([]byte, kdfOpts.GetSaltSize())
if _, err := rand.Read(salt); err != nil {
return nil, nil, err
}
// Derive the key
encAlg := opts.Cipher
key, kdfParams, err := opts.KDFOpts.DeriveKey(password, salt, encAlg.KeySize())
key, kdfParams, err := kdfOpts.DeriveKey(password, salt, size)
if err != nil {
return nil, nil, err
}
marshalledParams, err := asn1.Marshal(kdfParams)
if err != nil {
return nil, nil, err
}
keyDerivationFunc := pkix.AlgorithmIdentifier{
Algorithm: kdfOpts.OID(),
Parameters: asn1.RawValue{FullBytes: marshalledParams},
}
return key, &keyDerivationFunc, nil
}
// Encrypt encrypts the given plaintext using the given password and the options specified.
func (opts *PBES2Opts) Encrypt(rand io.Reader, password, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
encAlg := opts.Cipher
key, keyDerivationFunc, err := deriveKey(opts.KDFOpts, rand, password, encAlg.KeySize())
if err != nil {
return nil, nil, err
}
// Encrypt the plaintext
encryptionScheme, ciphertext, err := encAlg.Encrypt(rand, key, plaintext)
if err != nil {
return nil, nil, err
}
marshalledParams, err := asn1.Marshal(kdfParams)
if err != nil {
return nil, nil, err
}
keyDerivationFunc := pkix.AlgorithmIdentifier{
Algorithm: opts.KDFOpts.OID(),
Parameters: asn1.RawValue{FullBytes: marshalledParams},
}
encryptionAlgorithmParams := PBES2Params{
EncryptionScheme: *encryptionScheme,
KeyDerivationFunc: keyDerivationFunc,
KeyDerivationFunc: *keyDerivationFunc,
}
marshalledEncryptionAlgorithmParams, err := asn1.Marshal(encryptionAlgorithmParams)
if err != nil {

70
pkcs/pkcs5_pbes2_test.go Normal file
View File

@ -0,0 +1,70 @@
package pkcs
import (
"crypto/rand"
"encoding/asn1"
"testing"
)
func TestPBES2(t *testing.T) {
testCases := []struct {
name string
opts PBESEncrypter
}{
{
name: "PBKDF2-AES128-CBC",
opts: NewPBESEncrypter(AES128CBC, NewPBKDF2Opts(SHA1, 16, 1000)),
},
{
name: "PBKDF2-AES192-CBC",
opts: NewPBESEncrypter(AES192CBC, NewPBKDF2Opts(SHA1, 16, 1000)),
},
{
name: "PBKDF2-AES256-CBC",
opts: NewPBESEncrypter(AES256CBC, NewPBKDF2Opts(SHA1, 16, 1000)),
},
{
name: "PBKDF2(SHA256)-AES128-CBC",
opts: NewPBESEncrypter(AES128CBC, NewPBKDF2Opts(SHA256, 16, 1000)),
},
{
name: "PBKDF2(SHA256)-AES192-CBC",
opts: NewPBESEncrypter(AES192CBC, NewPBKDF2Opts(SHA256, 16, 1000)),
},
{
name: "PBKDF2(SHA256)-AES256-CBC",
opts: NewPBESEncrypter(AES256CBC, NewPBKDF2Opts(SHA256, 16, 1000)),
},
{
name: "PBKDF2(SM3)-SM4-CBC",
opts: NewPBESEncrypter(SM4CBC, NewPBKDF2Opts(SM3, 16, 1000)),
},
{
name: "SMPBES",
opts: NewSMPBESEncrypter(16, 1000),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
alg, ciphertext, err := tc.opts.Encrypt(rand.Reader, []byte("password"), []byte("pbes2"))
if err != nil {
t.Errorf("unexpected error: %v", err)
}
pbes2Opts := tc.opts.(*PBES2Opts)
if !alg.Algorithm.Equal(pbes2Opts.pbesOID) {
t.Errorf("unexpected algorithm: got %v, want %v", alg.Algorithm, tc.opts.(*PBES2Opts).pbesOID)
}
var param PBES2Params
if _, err := asn1.Unmarshal(alg.Parameters.FullBytes, &param); err != nil {
t.Errorf("unexpected error: %v", err)
}
plaintext, _, err := param.Decrypt([]byte("password"), ciphertext)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if string(plaintext) != "pbes2" {
t.Errorf("unexpected plaintext: got %s, want pbes2", plaintext)
}
})
}
}