pkcs: supplement test cases

This commit is contained in:
Sun Yimin 2024-07-10 14:47:27 +08:00 committed by GitHub
parent d5b39e6176
commit ba1836fa45
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 373 additions and 39 deletions

View File

@ -116,7 +116,7 @@ func newPRFParamFromHash(h Hash) (pkix.AlgorithmIdentifier, error) {
// iterationCount INTEGER (1..MAX), // iterationCount INTEGER (1..MAX),
// keyLength INTEGER (1..MAX) OPTIONAL, // keyLength INTEGER (1..MAX) OPTIONAL,
// prf AlgorithmIdentifier {{PBKDF2-PRFs}} DEFAULT algid-hmacWithSHA1 // prf AlgorithmIdentifier {{PBKDF2-PRFs}} DEFAULT algid-hmacWithSHA1
//} // }
type pbkdf2Params struct { type pbkdf2Params struct {
Salt []byte Salt []byte
IterationCount int IterationCount int
@ -132,6 +132,11 @@ func (p pbkdf2Params) DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, s
return pbkdf2.Key(password, p.Salt, p.IterationCount, size, h), nil return pbkdf2.Key(password, p.Salt, p.IterationCount, size, h), nil
} }
// KeyLength returns the length of the derived key.
func (p pbkdf2Params) KeyLength() int {
return p.KeyLen
}
// PBKDF2Opts contains options for the PBKDF2 key derivation function. // PBKDF2Opts contains options for the PBKDF2 key derivation function.
type PBKDF2Opts struct { type PBKDF2Opts struct {
SaltSize int SaltSize int

79
pkcs/kdf_pbkdf2_test.go Normal file
View File

@ -0,0 +1,79 @@
package pkcs
import (
"bytes"
"crypto/sha1"
"crypto/x509/pkix"
"testing"
"github.com/emmansun/gmsm/sm3"
)
func TestNewHashFromPRF(t *testing.T) {
h, err := newHashFromPRF(oidPKCS5PBKDF2, pkix.AlgorithmIdentifier{})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
hash := h()
if hash.Size() != sha1.Size {
t.Errorf("unexpected hash size: got %d, want %d", hash.Size(), sha1.Size)
}
h, err = newHashFromPRF(oidSMPBKDF, pkix.AlgorithmIdentifier{})
if err != nil {
t.Errorf("unexpected error: %v", err)
}
hash = h()
if hash.Size() != sm3.Size {
t.Errorf("unexpected hash size: got %d, want %d", hash.Size(), sm3.Size)
}
}
func TestPBKDF2DeriveKey(t *testing.T) {
testCases := []struct {
name string
opts PBKDF2Opts
}{
{
name: "PBKDF2-SHA1",
opts: NewPBKDF2Opts(SHA1, 16, 1000),
},
{
name: "PBKDF2-SHA256",
opts: NewPBKDF2Opts(SHA256, 16, 1000),
},
{
name: "SMPBKDF2",
opts: NewSMPBKDF2Opts(16, 1000),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
key, params, err := tc.opts.DeriveKey([]byte("password"), []byte("saltsaltsaltsalt"), 32)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if len(key) != 32 {
t.Errorf("unexpected key length: got %d, want 32", len(key))
}
if params.KeyLength() != 32 {
t.Errorf("unexpected key length: got %d, want 32", params.KeyLength())
}
if len(params.(pbkdf2Params).Salt) != tc.opts.SaltSize {
t.Errorf("unexpected salt length: got %d, want %d", len(params.(pbkdf2Params).Salt), tc.opts.SaltSize)
}
if params.(pbkdf2Params).IterationCount != tc.opts.IterationCount {
t.Errorf("unexpected iteration count: got %d, want %d", params.(pbkdf2Params).IterationCount, tc.opts.IterationCount)
}
if params.(pbkdf2Params).KeyLen != 32 {
t.Errorf("unexpected key length: got %d, want 32", params.(pbkdf2Params).KeyLen)
}
key2, err := params.DeriveKey(nil, []byte("password"), 32)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if !bytes.Equal(key, key2) {
t.Errorf("unexpected key: got %x, want %x", key2, key)
}
})
}
}

View File

@ -25,6 +25,7 @@ type scryptParams struct {
CostParameter int CostParameter int
BlockSize int BlockSize int
ParallelizationParameter int ParallelizationParameter int
KeyLen int `asn1:"optional"`
} }
func (p scryptParams) DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, size int) (key []byte, err error) { func (p scryptParams) DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, size int) (key []byte, err error) {
@ -32,6 +33,10 @@ func (p scryptParams) DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, s
p.ParallelizationParameter, size) p.ParallelizationParameter, size)
} }
func (p scryptParams) KeyLength() int {
return p.KeyLen
}
// ScryptOpts contains options for the scrypt key derivation function. // ScryptOpts contains options for the scrypt key derivation function.
type ScryptOpts struct { type ScryptOpts struct {
SaltSize int SaltSize int
@ -63,6 +68,7 @@ func (p ScryptOpts) DeriveKey(password, salt []byte, size int) (
CostParameter: p.CostParameter, CostParameter: p.CostParameter,
ParallelizationParameter: p.ParallelizationParameter, ParallelizationParameter: p.ParallelizationParameter,
Salt: salt, Salt: salt,
KeyLen: size,
} }
return key, params, nil return key, params, nil
} }

39
pkcs/kdf_scrypt_test.go Normal file
View File

@ -0,0 +1,39 @@
package pkcs
import (
"bytes"
"testing"
)
func TestScryptDeriveKey(t *testing.T) {
opts := NewScryptOpts(8, 16384, 8, 1)
key, params, err := opts.DeriveKey([]byte("password"), []byte("saltsalt"), 32)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if len(key) != 32 {
t.Errorf("unexpected key length: got %d, want 32", len(key))
}
if params.KeyLength() != 32 {
t.Errorf("unexpected key length: got %d, want 32", params.KeyLength())
}
if len(params.(scryptParams).Salt) != opts.SaltSize {
t.Errorf("unexpected salt length: got %d, want %d", len(params.(scryptParams).Salt), opts.SaltSize)
}
if params.(scryptParams).CostParameter != opts.CostParameter {
t.Errorf("unexpected cost parameter: got %d, want %d", params.(scryptParams).CostParameter, opts.CostParameter)
}
if params.(scryptParams).BlockSize != opts.BlockSize {
t.Errorf("unexpected block size: got %d, want %d", params.(scryptParams).BlockSize, opts.BlockSize)
}
if params.(scryptParams).ParallelizationParameter != opts.ParallelizationParameter {
t.Errorf("unexpected parallelization parameter: got %d, want %d", params.(scryptParams).ParallelizationParameter, opts.ParallelizationParameter)
}
key2, err := params.DeriveKey(nil, []byte("password"), 32)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if !bytes.Equal(key, key2) {
t.Errorf("unexpected key: got %x, want %x", key2, key)
}
}

View File

@ -9,6 +9,7 @@ import (
"encoding/asn1" "encoding/asn1"
"errors" "errors"
"hash" "hash"
"io"
"github.com/emmansun/gmsm/pkcs/internal/md2" "github.com/emmansun/gmsm/pkcs/internal/md2"
"github.com/emmansun/gmsm/pkcs/internal/rc2" "github.com/emmansun/gmsm/pkcs/internal/rc2"
@ -33,8 +34,51 @@ type PBES1 struct {
Algorithm pkix.AlgorithmIdentifier Algorithm pkix.AlgorithmIdentifier
} }
// newPBES1 creates a new PBES1 instance.
func newPBES1(rand io.Reader, oid asn1.ObjectIdentifier, saltLen, iterations int) (*PBES1, error) {
salt := make([]byte, saltLen)
if _, err := rand.Read(salt); err != nil {
return nil, err
}
param := pbeParameter{Salt: salt, Iteration: iterations}
marshalledParams, err := asn1.Marshal(param)
if err != nil {
return nil, err
}
return &PBES1{
Algorithm: pkix.AlgorithmIdentifier{
Algorithm: oid,
Parameters: asn1.RawValue{FullBytes: marshalledParams},
},
}, nil
}
func NewPbeWithMD2AndDESCBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) {
return newPBES1(rand, pbeWithMD2AndDESCBC, saltLen, iterations)
}
func NewPbeWithMD2AndRC2CBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) {
return newPBES1(rand, pbeWithMD2AndRC2CBC, saltLen, iterations)
}
func NewPbeWithMD5AndDESCBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) {
return newPBES1(rand, pbeWithMD5AndDESCBC, saltLen, iterations)
}
func NewPbeWithMD5AndRC2CBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) {
return newPBES1(rand, pbeWithMD5AndRC2CBC, saltLen, iterations)
}
func NewPbeWithSHA1AndDESCBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) {
return newPBES1(rand, pbeWithSHA1AndDESCBC, saltLen, iterations)
}
func NewPbeWithSHA1AndRC2CBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) {
return newPBES1(rand, pbeWithSHA1AndRC2CBC, saltLen, iterations)
}
// Key returns the key derived from the password according PBKDF1. // Key returns the key derived from the password according PBKDF1.
func (pbes1 *PBES1) Key(password []byte) ([]byte, error) { func (pbes1 *PBES1) key(password []byte) ([]byte, error) {
param := new(pbeParameter) param := new(pbeParameter)
if _, err := asn1.Unmarshal(pbes1.Algorithm.Parameters.FullBytes, param); err != nil { if _, err := asn1.Unmarshal(pbes1.Algorithm.Parameters.FullBytes, param); err != nil {
return nil, err return nil, err
@ -61,24 +105,45 @@ func (pbes1 *PBES1) Key(password []byte) ([]byte, error) {
return key, nil return key, nil
} }
func (pbes1 *PBES1) Decrypt(password, ciphertext []byte) ([]byte, KDFParameters, error) { func (pbes1 *PBES1) newBlock(key []byte) (cipher.Block, error) {
key, err := pbes1.Key(password)
if err != nil {
return nil, nil, err
}
var block cipher.Block var block cipher.Block
switch { switch {
case pbes1.Algorithm.Algorithm.Equal(pbeWithMD2AndDESCBC) || case pbes1.Algorithm.Algorithm.Equal(pbeWithMD2AndDESCBC) ||
pbes1.Algorithm.Algorithm.Equal(pbeWithMD5AndDESCBC) || pbes1.Algorithm.Algorithm.Equal(pbeWithMD5AndDESCBC) ||
pbes1.Algorithm.Algorithm.Equal(pbeWithSHA1AndDESCBC): pbes1.Algorithm.Algorithm.Equal(pbeWithSHA1AndDESCBC):
block, err = des.NewCipher(key[:8]) block, _ = des.NewCipher(key[:8])
case pbes1.Algorithm.Algorithm.Equal(pbeWithMD2AndRC2CBC) || case pbes1.Algorithm.Algorithm.Equal(pbeWithMD2AndRC2CBC) ||
pbes1.Algorithm.Algorithm.Equal(pbeWithMD5AndRC2CBC) || pbes1.Algorithm.Algorithm.Equal(pbeWithMD5AndRC2CBC) ||
pbes1.Algorithm.Algorithm.Equal(pbeWithSHA1AndRC2CBC): pbes1.Algorithm.Algorithm.Equal(pbeWithSHA1AndRC2CBC):
block, err = rc2.NewCipher(key[:8]) block, _ = rc2.NewCipher(key[:8])
default: default:
return nil, nil, errors.New("pbes: unsupported pbes1 cipher") return nil, errors.New("pbes: unsupported pbes1 cipher")
} }
return block, nil
}
func (pbes1 *PBES1) Encrypt(rand io.Reader, password, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
key, err := pbes1.key(password)
if err != nil {
return nil, nil, err
}
block, err := pbes1.newBlock(key)
if err != nil {
return nil, nil, err
}
ciphertext, err := cbcEncrypt(block, key[8:16], plaintext)
if err != nil {
return nil, nil, err
}
return &pbes1.Algorithm, ciphertext, nil
}
func (pbes1 *PBES1) Decrypt(password, ciphertext []byte) ([]byte, KDFParameters, error) {
key, err := pbes1.key(password)
if err != nil {
return nil, nil, err
}
block, err := pbes1.newBlock(key)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

63
pkcs/pkcs5_pbes1_test.go Normal file
View File

@ -0,0 +1,63 @@
package pkcs
import (
"crypto/rand"
"testing"
)
func TestPBES1(t *testing.T) {
var testCases []*PBES1
pbes1, err := NewPbeWithMD2AndDESCBC(rand.Reader, 8, 1000)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
testCases = append(testCases, pbes1)
pbes1, err = NewPbeWithMD2AndRC2CBC(rand.Reader, 8, 1000)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
testCases = append(testCases, pbes1)
pbes1, err = NewPbeWithMD5AndDESCBC(rand.Reader, 8, 1000)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
testCases = append(testCases, pbes1)
pbes1, err = NewPbeWithMD5AndRC2CBC(rand.Reader, 8, 1000)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
testCases = append(testCases, pbes1)
pbes1, err = NewPbeWithSHA1AndDESCBC(rand.Reader, 8, 1000)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
testCases = append(testCases, pbes1)
pbes1, err = NewPbeWithSHA1AndRC2CBC(rand.Reader, 8, 1000)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
testCases = append(testCases, pbes1)
for _, pbes1 := range testCases {
t.Run("", func(t *testing.T) {
_, ciphertext, err := pbes1.Encrypt(rand.Reader, []byte("password"), []byte("pbes1"))
if err != nil {
t.Errorf("unexpected error: %v", err)
}
plaintext, _, err := pbes1.Decrypt([]byte("password"), ciphertext)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if string(plaintext) != "pbes1" {
t.Errorf("unexpected plaintext: got %s, want password", plaintext)
}
})
}
}

View File

@ -65,6 +65,7 @@ var (
) )
// PBKDF2Opts contains algorithm identifiers and related parameters for PBKDF2 key derivation function. // PBKDF2Opts contains algorithm identifiers and related parameters for PBKDF2 key derivation function.
//
// PBES2-params ::= SEQUENCE { // PBES2-params ::= SEQUENCE {
// keyDerivationFunc AlgorithmIdentifier {{PBES2-KDFs}}, // keyDerivationFunc AlgorithmIdentifier {{PBES2-KDFs}},
// encryptionScheme AlgorithmIdentifier {{PBES2-Encs}} // encryptionScheme AlgorithmIdentifier {{PBES2-Encs}}
@ -140,6 +141,8 @@ type KDFParameters interface {
// DeriveKey derives a key of size bytes from the given password. // DeriveKey derives a key of size bytes from the given password.
// It uses the salt from the decoded parameters. // It uses the salt from the decoded parameters.
DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, size int) (key []byte, err error) DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, size int) (key []byte, err error)
// KeyLength returns the length of the derived key from the params.
KeyLength() int
} }
var kdfs = make(map[string]func() KDFParameters) var kdfs = make(map[string]func() KDFParameters)
@ -189,39 +192,43 @@ func (pbes2Params *PBES2Params) Decrypt(password, ciphertext []byte) ([]byte, KD
return plaintext, kdfParams, nil return plaintext, kdfParams, nil
} }
// Encrypt encrypts the given plaintext using the given password and the options specified. func deriveKey(kdfOpts KDFOpts, rand io.Reader, password []byte, size int) ([]byte, *pkix.AlgorithmIdentifier, error) {
func (opts *PBES2Opts) Encrypt(rand io.Reader, password, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
// Generate a random salt // Generate a random salt
salt := make([]byte, opts.KDFOpts.GetSaltSize()) salt := make([]byte, kdfOpts.GetSaltSize())
if _, err := rand.Read(salt); err != nil { if _, err := rand.Read(salt); err != nil {
return nil, nil, err return nil, nil, err
} }
key, kdfParams, err := kdfOpts.DeriveKey(password, salt, size)
// Derive the key
encAlg := opts.Cipher
key, kdfParams, err := opts.KDFOpts.DeriveKey(password, salt, encAlg.KeySize())
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
marshalledParams, err := asn1.Marshal(kdfParams)
if err != nil {
return nil, nil, err
}
keyDerivationFunc := pkix.AlgorithmIdentifier{
Algorithm: kdfOpts.OID(),
Parameters: asn1.RawValue{FullBytes: marshalledParams},
}
return key, &keyDerivationFunc, nil
}
// Encrypt encrypts the given plaintext using the given password and the options specified.
func (opts *PBES2Opts) Encrypt(rand io.Reader, password, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
encAlg := opts.Cipher
key, keyDerivationFunc, err := deriveKey(opts.KDFOpts, rand, password, encAlg.KeySize())
if err != nil {
return nil, nil, err
}
// Encrypt the plaintext // Encrypt the plaintext
encryptionScheme, ciphertext, err := encAlg.Encrypt(rand, key, plaintext) encryptionScheme, ciphertext, err := encAlg.Encrypt(rand, key, plaintext)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
marshalledParams, err := asn1.Marshal(kdfParams)
if err != nil {
return nil, nil, err
}
keyDerivationFunc := pkix.AlgorithmIdentifier{
Algorithm: opts.KDFOpts.OID(),
Parameters: asn1.RawValue{FullBytes: marshalledParams},
}
encryptionAlgorithmParams := PBES2Params{ encryptionAlgorithmParams := PBES2Params{
EncryptionScheme: *encryptionScheme, EncryptionScheme: *encryptionScheme,
KeyDerivationFunc: keyDerivationFunc, KeyDerivationFunc: *keyDerivationFunc,
} }
marshalledEncryptionAlgorithmParams, err := asn1.Marshal(encryptionAlgorithmParams) marshalledEncryptionAlgorithmParams, err := asn1.Marshal(encryptionAlgorithmParams)
if err != nil { if err != nil {

70
pkcs/pkcs5_pbes2_test.go Normal file
View File

@ -0,0 +1,70 @@
package pkcs
import (
"crypto/rand"
"encoding/asn1"
"testing"
)
func TestPBES2(t *testing.T) {
testCases := []struct {
name string
opts PBESEncrypter
}{
{
name: "PBKDF2-AES128-CBC",
opts: NewPBESEncrypter(AES128CBC, NewPBKDF2Opts(SHA1, 16, 1000)),
},
{
name: "PBKDF2-AES192-CBC",
opts: NewPBESEncrypter(AES192CBC, NewPBKDF2Opts(SHA1, 16, 1000)),
},
{
name: "PBKDF2-AES256-CBC",
opts: NewPBESEncrypter(AES256CBC, NewPBKDF2Opts(SHA1, 16, 1000)),
},
{
name: "PBKDF2(SHA256)-AES128-CBC",
opts: NewPBESEncrypter(AES128CBC, NewPBKDF2Opts(SHA256, 16, 1000)),
},
{
name: "PBKDF2(SHA256)-AES192-CBC",
opts: NewPBESEncrypter(AES192CBC, NewPBKDF2Opts(SHA256, 16, 1000)),
},
{
name: "PBKDF2(SHA256)-AES256-CBC",
opts: NewPBESEncrypter(AES256CBC, NewPBKDF2Opts(SHA256, 16, 1000)),
},
{
name: "PBKDF2(SM3)-SM4-CBC",
opts: NewPBESEncrypter(SM4CBC, NewPBKDF2Opts(SM3, 16, 1000)),
},
{
name: "SMPBES",
opts: NewSMPBESEncrypter(16, 1000),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
alg, ciphertext, err := tc.opts.Encrypt(rand.Reader, []byte("password"), []byte("pbes2"))
if err != nil {
t.Errorf("unexpected error: %v", err)
}
pbes2Opts := tc.opts.(*PBES2Opts)
if !alg.Algorithm.Equal(pbes2Opts.pbesOID) {
t.Errorf("unexpected algorithm: got %v, want %v", alg.Algorithm, tc.opts.(*PBES2Opts).pbesOID)
}
var param PBES2Params
if _, err := asn1.Unmarshal(alg.Parameters.FullBytes, &param); err != nil {
t.Errorf("unexpected error: %v", err)
}
plaintext, _, err := param.Decrypt([]byte("password"), ciphertext)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if string(plaintext) != "pbes2" {
t.Errorf("unexpected plaintext: got %s, want pbes2", plaintext)
}
})
}
}