diff --git a/pkcs/kdf_pbkdf2.go b/pkcs/kdf_pbkdf2.go index ed1b89f..19efff1 100644 --- a/pkcs/kdf_pbkdf2.go +++ b/pkcs/kdf_pbkdf2.go @@ -108,15 +108,15 @@ func newPRFParamFromHash(h Hash) (pkix.AlgorithmIdentifier, error) { return pkix.AlgorithmIdentifier{}, errors.New("pbes/pbkdf2: unsupported hash function") } -// PBKDF2-params ::= SEQUENCE { -// salt CHOICE { -// specified OCTET STRING, -// otherSource AlgorithmIdentifier {{PBKDF2-SaltSources}} -// }, -// iterationCount INTEGER (1..MAX), -// keyLength INTEGER (1..MAX) OPTIONAL, -// prf AlgorithmIdentifier {{PBKDF2-PRFs}} DEFAULT algid-hmacWithSHA1 -//} +// PBKDF2-params ::= SEQUENCE { +// salt CHOICE { +// specified OCTET STRING, +// otherSource AlgorithmIdentifier {{PBKDF2-SaltSources}} +// }, +// iterationCount INTEGER (1..MAX), +// keyLength INTEGER (1..MAX) OPTIONAL, +// prf AlgorithmIdentifier {{PBKDF2-PRFs}} DEFAULT algid-hmacWithSHA1 +// } type pbkdf2Params struct { Salt []byte IterationCount int @@ -132,6 +132,11 @@ func (p pbkdf2Params) DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, s return pbkdf2.Key(password, p.Salt, p.IterationCount, size, h), nil } +// KeyLength returns the length of the derived key. +func (p pbkdf2Params) KeyLength() int { + return p.KeyLen +} + // PBKDF2Opts contains options for the PBKDF2 key derivation function. type PBKDF2Opts struct { SaltSize int diff --git a/pkcs/kdf_pbkdf2_test.go b/pkcs/kdf_pbkdf2_test.go new file mode 100644 index 0000000..860f805 --- /dev/null +++ b/pkcs/kdf_pbkdf2_test.go @@ -0,0 +1,79 @@ +package pkcs + +import ( + "bytes" + "crypto/sha1" + "crypto/x509/pkix" + "testing" + + "github.com/emmansun/gmsm/sm3" +) + +func TestNewHashFromPRF(t *testing.T) { + h, err := newHashFromPRF(oidPKCS5PBKDF2, pkix.AlgorithmIdentifier{}) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + hash := h() + if hash.Size() != sha1.Size { + t.Errorf("unexpected hash size: got %d, want %d", hash.Size(), sha1.Size) + } + h, err = newHashFromPRF(oidSMPBKDF, pkix.AlgorithmIdentifier{}) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + hash = h() + if hash.Size() != sm3.Size { + t.Errorf("unexpected hash size: got %d, want %d", hash.Size(), sm3.Size) + } +} + +func TestPBKDF2DeriveKey(t *testing.T) { + testCases := []struct { + name string + opts PBKDF2Opts + }{ + { + name: "PBKDF2-SHA1", + opts: NewPBKDF2Opts(SHA1, 16, 1000), + }, + { + name: "PBKDF2-SHA256", + opts: NewPBKDF2Opts(SHA256, 16, 1000), + }, + { + name: "SMPBKDF2", + opts: NewSMPBKDF2Opts(16, 1000), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + key, params, err := tc.opts.DeriveKey([]byte("password"), []byte("saltsaltsaltsalt"), 32) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if len(key) != 32 { + t.Errorf("unexpected key length: got %d, want 32", len(key)) + } + if params.KeyLength() != 32 { + t.Errorf("unexpected key length: got %d, want 32", params.KeyLength()) + } + if len(params.(pbkdf2Params).Salt) != tc.opts.SaltSize { + t.Errorf("unexpected salt length: got %d, want %d", len(params.(pbkdf2Params).Salt), tc.opts.SaltSize) + } + if params.(pbkdf2Params).IterationCount != tc.opts.IterationCount { + t.Errorf("unexpected iteration count: got %d, want %d", params.(pbkdf2Params).IterationCount, tc.opts.IterationCount) + } + if params.(pbkdf2Params).KeyLen != 32 { + t.Errorf("unexpected key length: got %d, want 32", params.(pbkdf2Params).KeyLen) + } + key2, err := params.DeriveKey(nil, []byte("password"), 32) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if !bytes.Equal(key, key2) { + t.Errorf("unexpected key: got %x, want %x", key2, key) + } + }) + } +} diff --git a/pkcs/kdf_scrypt.go b/pkcs/kdf_scrypt.go index e6af136..7054c8f 100644 --- a/pkcs/kdf_scrypt.go +++ b/pkcs/kdf_scrypt.go @@ -25,6 +25,7 @@ type scryptParams struct { CostParameter int BlockSize int ParallelizationParameter int + KeyLen int `asn1:"optional"` } func (p scryptParams) DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, size int) (key []byte, err error) { @@ -32,6 +33,10 @@ func (p scryptParams) DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, s p.ParallelizationParameter, size) } +func (p scryptParams) KeyLength() int { + return p.KeyLen +} + // ScryptOpts contains options for the scrypt key derivation function. type ScryptOpts struct { SaltSize int @@ -63,6 +68,7 @@ func (p ScryptOpts) DeriveKey(password, salt []byte, size int) ( CostParameter: p.CostParameter, ParallelizationParameter: p.ParallelizationParameter, Salt: salt, + KeyLen: size, } return key, params, nil } diff --git a/pkcs/kdf_scrypt_test.go b/pkcs/kdf_scrypt_test.go new file mode 100644 index 0000000..2e418da --- /dev/null +++ b/pkcs/kdf_scrypt_test.go @@ -0,0 +1,39 @@ +package pkcs + +import ( + "bytes" + "testing" +) + +func TestScryptDeriveKey(t *testing.T) { + opts := NewScryptOpts(8, 16384, 8, 1) + key, params, err := opts.DeriveKey([]byte("password"), []byte("saltsalt"), 32) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if len(key) != 32 { + t.Errorf("unexpected key length: got %d, want 32", len(key)) + } + if params.KeyLength() != 32 { + t.Errorf("unexpected key length: got %d, want 32", params.KeyLength()) + } + if len(params.(scryptParams).Salt) != opts.SaltSize { + t.Errorf("unexpected salt length: got %d, want %d", len(params.(scryptParams).Salt), opts.SaltSize) + } + if params.(scryptParams).CostParameter != opts.CostParameter { + t.Errorf("unexpected cost parameter: got %d, want %d", params.(scryptParams).CostParameter, opts.CostParameter) + } + if params.(scryptParams).BlockSize != opts.BlockSize { + t.Errorf("unexpected block size: got %d, want %d", params.(scryptParams).BlockSize, opts.BlockSize) + } + if params.(scryptParams).ParallelizationParameter != opts.ParallelizationParameter { + t.Errorf("unexpected parallelization parameter: got %d, want %d", params.(scryptParams).ParallelizationParameter, opts.ParallelizationParameter) + } + key2, err := params.DeriveKey(nil, []byte("password"), 32) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if !bytes.Equal(key, key2) { + t.Errorf("unexpected key: got %x, want %x", key2, key) + } +} diff --git a/pkcs/pkcs5_pbes1.go b/pkcs/pkcs5_pbes1.go index 82da347..63bf458 100644 --- a/pkcs/pkcs5_pbes1.go +++ b/pkcs/pkcs5_pbes1.go @@ -9,6 +9,7 @@ import ( "encoding/asn1" "errors" "hash" + "io" "github.com/emmansun/gmsm/pkcs/internal/md2" "github.com/emmansun/gmsm/pkcs/internal/rc2" @@ -33,8 +34,51 @@ type PBES1 struct { Algorithm pkix.AlgorithmIdentifier } +// newPBES1 creates a new PBES1 instance. +func newPBES1(rand io.Reader, oid asn1.ObjectIdentifier, saltLen, iterations int) (*PBES1, error) { + salt := make([]byte, saltLen) + if _, err := rand.Read(salt); err != nil { + return nil, err + } + param := pbeParameter{Salt: salt, Iteration: iterations} + marshalledParams, err := asn1.Marshal(param) + if err != nil { + return nil, err + } + return &PBES1{ + Algorithm: pkix.AlgorithmIdentifier{ + Algorithm: oid, + Parameters: asn1.RawValue{FullBytes: marshalledParams}, + }, + }, nil +} + +func NewPbeWithMD2AndDESCBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) { + return newPBES1(rand, pbeWithMD2AndDESCBC, saltLen, iterations) +} + +func NewPbeWithMD2AndRC2CBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) { + return newPBES1(rand, pbeWithMD2AndRC2CBC, saltLen, iterations) +} + +func NewPbeWithMD5AndDESCBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) { + return newPBES1(rand, pbeWithMD5AndDESCBC, saltLen, iterations) +} + +func NewPbeWithMD5AndRC2CBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) { + return newPBES1(rand, pbeWithMD5AndRC2CBC, saltLen, iterations) +} + +func NewPbeWithSHA1AndDESCBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) { + return newPBES1(rand, pbeWithSHA1AndDESCBC, saltLen, iterations) +} + +func NewPbeWithSHA1AndRC2CBC(rand io.Reader, saltLen, iterations int) (*PBES1, error) { + return newPBES1(rand, pbeWithSHA1AndRC2CBC, saltLen, iterations) +} + // Key returns the key derived from the password according PBKDF1. -func (pbes1 *PBES1) Key(password []byte) ([]byte, error) { +func (pbes1 *PBES1) key(password []byte) ([]byte, error) { param := new(pbeParameter) if _, err := asn1.Unmarshal(pbes1.Algorithm.Parameters.FullBytes, param); err != nil { return nil, err @@ -61,24 +105,45 @@ func (pbes1 *PBES1) Key(password []byte) ([]byte, error) { return key, nil } -func (pbes1 *PBES1) Decrypt(password, ciphertext []byte) ([]byte, KDFParameters, error) { - key, err := pbes1.Key(password) - if err != nil { - return nil, nil, err - } +func (pbes1 *PBES1) newBlock(key []byte) (cipher.Block, error) { var block cipher.Block switch { case pbes1.Algorithm.Algorithm.Equal(pbeWithMD2AndDESCBC) || pbes1.Algorithm.Algorithm.Equal(pbeWithMD5AndDESCBC) || pbes1.Algorithm.Algorithm.Equal(pbeWithSHA1AndDESCBC): - block, err = des.NewCipher(key[:8]) + block, _ = des.NewCipher(key[:8]) case pbes1.Algorithm.Algorithm.Equal(pbeWithMD2AndRC2CBC) || pbes1.Algorithm.Algorithm.Equal(pbeWithMD5AndRC2CBC) || pbes1.Algorithm.Algorithm.Equal(pbeWithSHA1AndRC2CBC): - block, err = rc2.NewCipher(key[:8]) + block, _ = rc2.NewCipher(key[:8]) default: - return nil, nil, errors.New("pbes: unsupported pbes1 cipher") + return nil, errors.New("pbes: unsupported pbes1 cipher") } + return block, nil +} + +func (pbes1 *PBES1) Encrypt(rand io.Reader, password, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) { + key, err := pbes1.key(password) + if err != nil { + return nil, nil, err + } + block, err := pbes1.newBlock(key) + if err != nil { + return nil, nil, err + } + ciphertext, err := cbcEncrypt(block, key[8:16], plaintext) + if err != nil { + return nil, nil, err + } + return &pbes1.Algorithm, ciphertext, nil +} + +func (pbes1 *PBES1) Decrypt(password, ciphertext []byte) ([]byte, KDFParameters, error) { + key, err := pbes1.key(password) + if err != nil { + return nil, nil, err + } + block, err := pbes1.newBlock(key) if err != nil { return nil, nil, err } diff --git a/pkcs/pkcs5_pbes1_test.go b/pkcs/pkcs5_pbes1_test.go new file mode 100644 index 0000000..b25c4ef --- /dev/null +++ b/pkcs/pkcs5_pbes1_test.go @@ -0,0 +1,63 @@ +package pkcs + +import ( + "crypto/rand" + "testing" +) + +func TestPBES1(t *testing.T) { + var testCases []*PBES1 + + pbes1, err := NewPbeWithMD2AndDESCBC(rand.Reader, 8, 1000) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + testCases = append(testCases, pbes1) + + pbes1, err = NewPbeWithMD2AndRC2CBC(rand.Reader, 8, 1000) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + testCases = append(testCases, pbes1) + + pbes1, err = NewPbeWithMD5AndDESCBC(rand.Reader, 8, 1000) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + testCases = append(testCases, pbes1) + + pbes1, err = NewPbeWithMD5AndRC2CBC(rand.Reader, 8, 1000) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + testCases = append(testCases, pbes1) + + pbes1, err = NewPbeWithSHA1AndDESCBC(rand.Reader, 8, 1000) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + testCases = append(testCases, pbes1) + + pbes1, err = NewPbeWithSHA1AndRC2CBC(rand.Reader, 8, 1000) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + testCases = append(testCases, pbes1) + + for _, pbes1 := range testCases { + t.Run("", func(t *testing.T) { + _, ciphertext, err := pbes1.Encrypt(rand.Reader, []byte("password"), []byte("pbes1")) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + plaintext, _, err := pbes1.Decrypt([]byte("password"), ciphertext) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if string(plaintext) != "pbes1" { + t.Errorf("unexpected plaintext: got %s, want password", plaintext) + } + }) + } +} diff --git a/pkcs/pkcs5_pbes2.go b/pkcs/pkcs5_pbes2.go index 0c23797..f944c65 100644 --- a/pkcs/pkcs5_pbes2.go +++ b/pkcs/pkcs5_pbes2.go @@ -65,10 +65,11 @@ var ( ) // PBKDF2Opts contains algorithm identifiers and related parameters for PBKDF2 key derivation function. -// PBES2-params ::= SEQUENCE { -// keyDerivationFunc AlgorithmIdentifier {{PBES2-KDFs}}, -// encryptionScheme AlgorithmIdentifier {{PBES2-Encs}} -// } +// +// PBES2-params ::= SEQUENCE { +// keyDerivationFunc AlgorithmIdentifier {{PBES2-KDFs}}, +// encryptionScheme AlgorithmIdentifier {{PBES2-Encs}} +// } type PBES2Params struct { KeyDerivationFunc pkix.AlgorithmIdentifier EncryptionScheme pkix.AlgorithmIdentifier @@ -140,6 +141,8 @@ type KDFParameters interface { // DeriveKey derives a key of size bytes from the given password. // It uses the salt from the decoded parameters. DeriveKey(oidKDF asn1.ObjectIdentifier, password []byte, size int) (key []byte, err error) + // KeyLength returns the length of the derived key from the params. + KeyLength() int } var kdfs = make(map[string]func() KDFParameters) @@ -189,39 +192,43 @@ func (pbes2Params *PBES2Params) Decrypt(password, ciphertext []byte) ([]byte, KD return plaintext, kdfParams, nil } -// Encrypt encrypts the given plaintext using the given password and the options specified. -func (opts *PBES2Opts) Encrypt(rand io.Reader, password, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) { +func deriveKey(kdfOpts KDFOpts, rand io.Reader, password []byte, size int) ([]byte, *pkix.AlgorithmIdentifier, error) { // Generate a random salt - salt := make([]byte, opts.KDFOpts.GetSaltSize()) + salt := make([]byte, kdfOpts.GetSaltSize()) if _, err := rand.Read(salt); err != nil { return nil, nil, err } - - // Derive the key - encAlg := opts.Cipher - key, kdfParams, err := opts.KDFOpts.DeriveKey(password, salt, encAlg.KeySize()) + key, kdfParams, err := kdfOpts.DeriveKey(password, salt, size) if err != nil { return nil, nil, err } + marshalledParams, err := asn1.Marshal(kdfParams) + if err != nil { + return nil, nil, err + } + keyDerivationFunc := pkix.AlgorithmIdentifier{ + Algorithm: kdfOpts.OID(), + Parameters: asn1.RawValue{FullBytes: marshalledParams}, + } + return key, &keyDerivationFunc, nil +} +// Encrypt encrypts the given plaintext using the given password and the options specified. +func (opts *PBES2Opts) Encrypt(rand io.Reader, password, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) { + encAlg := opts.Cipher + key, keyDerivationFunc, err := deriveKey(opts.KDFOpts, rand, password, encAlg.KeySize()) + if err != nil { + return nil, nil, err + } // Encrypt the plaintext encryptionScheme, ciphertext, err := encAlg.Encrypt(rand, key, plaintext) if err != nil { return nil, nil, err } - marshalledParams, err := asn1.Marshal(kdfParams) - if err != nil { - return nil, nil, err - } - keyDerivationFunc := pkix.AlgorithmIdentifier{ - Algorithm: opts.KDFOpts.OID(), - Parameters: asn1.RawValue{FullBytes: marshalledParams}, - } - encryptionAlgorithmParams := PBES2Params{ EncryptionScheme: *encryptionScheme, - KeyDerivationFunc: keyDerivationFunc, + KeyDerivationFunc: *keyDerivationFunc, } marshalledEncryptionAlgorithmParams, err := asn1.Marshal(encryptionAlgorithmParams) if err != nil { diff --git a/pkcs/pkcs5_pbes2_test.go b/pkcs/pkcs5_pbes2_test.go new file mode 100644 index 0000000..35e97b3 --- /dev/null +++ b/pkcs/pkcs5_pbes2_test.go @@ -0,0 +1,70 @@ +package pkcs + +import ( + "crypto/rand" + "encoding/asn1" + "testing" +) + +func TestPBES2(t *testing.T) { + testCases := []struct { + name string + opts PBESEncrypter + }{ + { + name: "PBKDF2-AES128-CBC", + opts: NewPBESEncrypter(AES128CBC, NewPBKDF2Opts(SHA1, 16, 1000)), + }, + { + name: "PBKDF2-AES192-CBC", + opts: NewPBESEncrypter(AES192CBC, NewPBKDF2Opts(SHA1, 16, 1000)), + }, + { + name: "PBKDF2-AES256-CBC", + opts: NewPBESEncrypter(AES256CBC, NewPBKDF2Opts(SHA1, 16, 1000)), + }, + { + name: "PBKDF2(SHA256)-AES128-CBC", + opts: NewPBESEncrypter(AES128CBC, NewPBKDF2Opts(SHA256, 16, 1000)), + }, + { + name: "PBKDF2(SHA256)-AES192-CBC", + opts: NewPBESEncrypter(AES192CBC, NewPBKDF2Opts(SHA256, 16, 1000)), + }, + { + name: "PBKDF2(SHA256)-AES256-CBC", + opts: NewPBESEncrypter(AES256CBC, NewPBKDF2Opts(SHA256, 16, 1000)), + }, + { + name: "PBKDF2(SM3)-SM4-CBC", + opts: NewPBESEncrypter(SM4CBC, NewPBKDF2Opts(SM3, 16, 1000)), + }, + { + name: "SMPBES", + opts: NewSMPBESEncrypter(16, 1000), + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + alg, ciphertext, err := tc.opts.Encrypt(rand.Reader, []byte("password"), []byte("pbes2")) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + pbes2Opts := tc.opts.(*PBES2Opts) + if !alg.Algorithm.Equal(pbes2Opts.pbesOID) { + t.Errorf("unexpected algorithm: got %v, want %v", alg.Algorithm, tc.opts.(*PBES2Opts).pbesOID) + } + var param PBES2Params + if _, err := asn1.Unmarshal(alg.Parameters.FullBytes, ¶m); err != nil { + t.Errorf("unexpected error: %v", err) + } + plaintext, _, err := param.Decrypt([]byte("password"), ciphertext) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if string(plaintext) != "pbes2" { + t.Errorf("unexpected plaintext: got %s, want pbes2", plaintext) + } + }) + } +}