pkcs: refactoring, extract pbes2 from pkcs8

This commit is contained in:
Sun Yimin 2024-07-04 17:29:44 +08:00 committed by GitHub
parent ffddbdcfec
commit 2c87cdf8d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 461 additions and 413 deletions

View File

@ -3,24 +3,27 @@ package pkcs
import (
"crypto/cipher"
"crypto/rand"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"fmt"
"io"
smcipher "github.com/emmansun/gmsm/cipher"
"github.com/emmansun/gmsm/padding"
)
// Cipher represents a cipher for encrypting the key material.
// Cipher represents a cipher for encrypting the key material
// which is used in PBES2.
type Cipher interface {
// KeySize returns the key size of the cipher, in bytes.
KeySize() int
// Encrypt encrypts the key material.
Encrypt(key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error)
// Decrypt decrypts the key material.
Decrypt(key []byte, parameters *asn1.RawValue, encryptedKey []byte) ([]byte, error)
// Encrypt encrypts the key material. The returned AlgorithmIdentifier is
// the algorithm identifier used for encryption including parameters.
Encrypt(rand io.Reader, key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error)
// Decrypt decrypts the key material. The parameters are the parameters from the
// DER-encoded AlgorithmIdentifier's.
Decrypt(key []byte, parameters *asn1.RawValue, ciphertext []byte) ([]byte, error)
// OID returns the OID of the cipher specified.
OID() asn1.ObjectIdentifier
}
@ -33,6 +36,7 @@ func RegisterCipher(oid asn1.ObjectIdentifier, cipher func() Cipher) {
ciphers[oid.String()] = cipher
}
// GetCipher returns an instance of the cipher specified by the given algorithm identifier.
func GetCipher(alg pkix.AlgorithmIdentifier) (Cipher, error) {
oid := alg.Algorithm.String()
if oid == oidSM4.String() {
@ -67,7 +71,7 @@ type ecbBlockCipher struct {
baseBlockCipher
}
func (ecb *ecbBlockCipher) Encrypt(key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
func (ecb *ecbBlockCipher) Encrypt(rand io.Reader, key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
block, err := ecb.newBlock(key)
if err != nil {
return nil, nil, err
@ -106,15 +110,17 @@ type cbcBlockCipher struct {
ivSize int
}
func (c *cbcBlockCipher) Encrypt(key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
func (c *cbcBlockCipher) Encrypt(rand io.Reader, key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
block, err := c.newBlock(key)
if err != nil {
return nil, nil, err
}
iv, err := genRandom(c.ivSize)
if err != nil {
iv := make([]byte, c.ivSize)
if _, err := rand.Read(iv); err != nil {
return nil, nil, err
}
ciphertext, err := cbcEncrypt(block, iv, plaintext)
if err != nil {
return nil, nil, err
@ -133,7 +139,7 @@ func (c *cbcBlockCipher) Encrypt(key, plaintext []byte) (*pkix.AlgorithmIdentifi
return &encryptionScheme, ciphertext, nil
}
func (c *cbcBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, encryptedKey []byte) ([]byte, error) {
func (c *cbcBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, ciphertext []byte) ([]byte, error) {
block, err := c.newBlock(key)
if err != nil {
return nil, err
@ -144,7 +150,7 @@ func (c *cbcBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, encrypte
return nil, errors.New("pkcs: invalid cipher parameters")
}
return cbcDecrypt(block, iv, encryptedKey)
return cbcDecrypt(block, iv, ciphertext)
}
func cbcEncrypt(block cipher.Block, iv, plaintext []byte) ([]byte, error) {
@ -170,6 +176,7 @@ type gcmBlockCipher struct {
}
// https://datatracker.ietf.org/doc/rfc5084/
//
// GCMParameters ::= SEQUENCE {
// aes-nonce OCTET STRING, -- recommended size is 12 octets
// aes-ICVlen AES-GCM-ICVlen DEFAULT 12 }
@ -178,13 +185,14 @@ type gcmParameters struct {
ICVLen int `asn1:"default:12,optional"`
}
func (c *gcmBlockCipher) Encrypt(key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
func (c *gcmBlockCipher) Encrypt(rand io.Reader, key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
block, err := c.newBlock(key)
if err != nil {
return nil, nil, err
}
nonce, err := genRandom(c.nonceSize)
if err != nil {
nonce := make([]byte, c.nonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, nil, err
}
@ -210,7 +218,7 @@ func (c *gcmBlockCipher) Encrypt(key, plaintext []byte) (*pkix.AlgorithmIdentifi
return &encryptionAlgorithm, ciphertext, nil
}
func (c *gcmBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, encryptedKey []byte) ([]byte, error) {
func (c *gcmBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, ciphertext []byte) ([]byte, error) {
block, err := c.newBlock(key)
if err != nil {
return nil, err
@ -228,11 +236,5 @@ func (c *gcmBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, encrypte
return nil, errors.New("pkcs: we do not support non-standard tag size")
}
return aead.Open(nil, params.Nonce, encryptedKey, nil)
}
func genRandom(len int) ([]byte, error) {
value := make([]byte, len)
_, err := rand.Read(value)
return value, err
return aead.Open(nil, params.Nonce, ciphertext, nil)
}

View File

@ -2,6 +2,7 @@ package pkcs
import (
"bytes"
"crypto/rand"
"crypto/x509/pkix"
"encoding/asn1"
"testing"
@ -36,7 +37,7 @@ func TestGetCipher(t *testing.T) {
func TestInvalidKeyLen(t *testing.T) {
plaintext := []byte("Hello World")
invalidKey := []byte("123456")
_, _, err := SM4ECB.Encrypt(invalidKey, plaintext)
_, _, err := SM4ECB.Encrypt(rand.Reader, invalidKey, plaintext)
if err == nil {
t.Errorf("should be error")
}
@ -44,7 +45,7 @@ func TestInvalidKeyLen(t *testing.T) {
if err == nil {
t.Errorf("should be error")
}
_, _, err = SM4CBC.Encrypt(invalidKey, plaintext)
_, _, err = SM4CBC.Encrypt(rand.Reader, invalidKey, plaintext)
if err == nil {
t.Errorf("should be error")
}
@ -52,7 +53,7 @@ func TestInvalidKeyLen(t *testing.T) {
if err == nil {
t.Errorf("should be error")
}
_, _, err = SM4GCM.Encrypt(invalidKey, plaintext)
_, _, err = SM4GCM.Encrypt(rand.Reader, invalidKey, plaintext)
if err == nil {
t.Errorf("should be error")
}

View File

@ -1,4 +1,4 @@
package pkcs8
package pkcs
//
// Reference https://datatracker.ietf.org/doc/html/rfc8018#section-5.2
@ -17,7 +17,6 @@ import (
"golang.org/x/crypto/pbkdf2"
)
// http://gmssl.org/docs/oid.html
var (
oidPKCS5PBKDF2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 12}
oidHMACWithSHA1 = asn1.ObjectIdentifier{1, 2, 840, 113549, 2, 7}

View File

@ -1,4 +1,4 @@
package pkcs8
package pkcs
//
// Reference https://datatracker.ietf.org/doc/html/rfc7914

200
pkcs/pbes2.go Normal file
View File

@ -0,0 +1,200 @@
package pkcs
import (
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"crypto/x509/pkix"
"encoding/asn1"
"errors"
"fmt"
"hash"
"io"
"strconv"
"github.com/emmansun/gmsm/sm3"
)
var (
oidPBES2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13}
)
// Hash identifies a cryptographic hash function that is implemented in another
// package.
type Hash uint
const (
SHA1 Hash = 1 + iota
SHA224
SHA256
SHA384
SHA512
SHA512_224
SHA512_256
SM3
)
// New returns a new hash.Hash calculating the given hash function. New panics
// if the hash function is not linked into the binary.
func (h Hash) New() hash.Hash {
switch h {
case SM3:
return sm3.New()
case SHA1:
return sha1.New()
case SHA224:
return sha256.New224()
case SHA256:
return sha256.New()
case SHA384:
return sha512.New384()
case SHA512:
return sha512.New()
case SHA512_224:
return sha512.New512_224()
case SHA512_256:
return sha512.New512_256()
}
panic("pkcs5: requested hash function #" + strconv.Itoa(int(h)) + " is unavailable")
}
// PBKDF2Opts contains algorithm identifiers and related parameters for PBKDF2 key derivation function.
type PBES2Params struct {
KeyDerivationFunc pkix.AlgorithmIdentifier
EncryptionScheme pkix.AlgorithmIdentifier
}
// PBES2Opts contains options for encrypting a key using PBES2.
type PBES2Opts struct {
Cipher
KDFOpts
}
// DefaultOpts are the default options for encrypting a key if none are given.
// The defaults can be changed by the library user.
var DefaultOpts = &PBES2Opts{
Cipher: AES256CBC,
KDFOpts: PBKDF2Opts{
SaltSize: 16,
IterationCount: 2048,
HMACHash: SHA256,
},
}
// KDFOpts contains options for a key derivation function.
// An implementation of this interface must be specified when encrypting a PKCS#8 key.
type KDFOpts interface {
// DeriveKey derives a key of size bytes from the given password and salt.
// It returns the key and the ASN.1-encodable parameters used.
DeriveKey(password, salt []byte, size int) (key []byte, params KDFParameters, err error)
// GetSaltSize returns the salt size specified.
GetSaltSize() int
// OID returns the OID of the KDF specified.
OID() asn1.ObjectIdentifier
}
// KDFParameters contains parameters (salt, etc.) for a key deriviation function.
// It must be a ASN.1-decodable structure.
// An implementation of this interface is created when decoding an encrypted PKCS#8 key.
type KDFParameters interface {
// DeriveKey derives a key of size bytes from the given password.
// It uses the salt from the decoded parameters.
DeriveKey(password []byte, size int) (key []byte, err error)
}
var kdfs = make(map[string]func() KDFParameters)
// RegisterKDF registers a function that returns a new instance of the given KDF
// parameters. This allows the library to support client-provided KDFs.
func RegisterKDF(oid asn1.ObjectIdentifier, params func() KDFParameters) {
kdfs[oid.String()] = params
}
func (pbes2Params *PBES2Params) parseKeyDerivationFunc() (KDFParameters, error) {
oid := pbes2Params.KeyDerivationFunc.Algorithm.String()
newParams, ok := kdfs[oid]
if !ok {
return nil, fmt.Errorf("pkcs5: unsupported KDF (OID: %s)", oid)
}
params := newParams()
_, err := asn1.Unmarshal(pbes2Params.KeyDerivationFunc.Parameters.FullBytes, params)
if err != nil {
return nil, errors.New("pkcs5: invalid KDF parameters")
}
return params, nil
}
// Decrypt decrypts the given ciphertext using the given password and the options specified.
func (pbes2Params *PBES2Params) Decrypt(password, ciphertext []byte) ([]byte, KDFParameters, error) {
cipher, err := GetCipher(pbes2Params.EncryptionScheme)
if err != nil {
return nil, nil, err
}
kdfParams, err := pbes2Params.parseKeyDerivationFunc()
if err != nil {
return nil, nil, err
}
keySize := cipher.KeySize()
symkey, err := kdfParams.DeriveKey(password, keySize)
if err != nil {
return nil, nil, err
}
plaintext, err := cipher.Decrypt(symkey, &pbes2Params.EncryptionScheme.Parameters, ciphertext)
if err != nil {
return nil, nil, err
}
return plaintext, kdfParams, nil
}
// Encrypt encrypts the given plaintext using the given password and the options specified.
func (opts *PBES2Opts) Encrypt(rand io.Reader, password, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
// Generate a random salt
salt := make([]byte, opts.KDFOpts.GetSaltSize())
if _, err := rand.Read(salt); err != nil {
return nil, nil, err
}
// Derive the key
encAlg := opts.Cipher
key, kdfParams, err := opts.KDFOpts.DeriveKey(password, salt, encAlg.KeySize())
if err != nil {
return nil, nil, err
}
// Encrypt the plaintext
encryptionScheme, ciphertext, err := encAlg.Encrypt(rand, key, plaintext)
if err != nil {
return nil, nil, err
}
marshalledParams, err := asn1.Marshal(kdfParams)
if err != nil {
return nil, nil, err
}
keyDerivationFunc := pkix.AlgorithmIdentifier{
Algorithm: opts.KDFOpts.OID(),
Parameters: asn1.RawValue{FullBytes: marshalledParams},
}
encryptionAlgorithmParams := PBES2Params{
EncryptionScheme: *encryptionScheme,
KeyDerivationFunc: keyDerivationFunc,
}
marshalledEncryptionAlgorithmParams, err := asn1.Marshal(encryptionAlgorithmParams)
if err != nil {
return nil, nil, err
}
encryptionAlgorithm := pkix.AlgorithmIdentifier{
Algorithm: oidPBES2,
Parameters: asn1.RawValue{FullBytes: marshalledEncryptionAlgorithmParams},
}
return &encryptionAlgorithm, ciphertext, nil
}
func IsPBES2(algorithm pkix.AlgorithmIdentifier) bool {
return oidPBES2.Equal(algorithm.Algorithm)
}

View File

@ -1,6 +1,7 @@
package pkcs7
import (
"crypto/rand"
"encoding/asn1"
"errors"
@ -36,7 +37,7 @@ func encryptUsingPSK(cipher pkcs.Cipher, content []byte, key []byte, contentType
return nil, ErrPSKNotProvided
}
id, ciphertext, err := cipher.Encrypt(key, content)
id, ciphertext, err := cipher.Encrypt(rand.Reader, key, content)
if err != nil {
return nil, err
}

View File

@ -121,7 +121,7 @@ func NewEnvelopedData(cipher pkcs.Cipher, content []byte) (*EnvelopedData, error
return nil, err
}
id, ciphertext, err := cipher.Encrypt(key, content)
id, ciphertext, err := cipher.Encrypt(rand.Reader, key, content)
if err != nil {
return nil, err
}
@ -148,7 +148,7 @@ func NewSM2EnvelopedData(cipher pkcs.Cipher, content []byte) (*EnvelopedData, er
return nil, err
}
id, ciphertext, err := cipher.Encrypt(key, content)
id, ciphertext, err := cipher.Encrypt(rand.Reader, key, content)
if err != nil {
return nil, err
}

View File

@ -147,7 +147,7 @@ func NewSignedAndEnvelopedData(data []byte, cipher pkcs.Cipher) (*SignedAndEnvel
return nil, err
}
id, ciphertext, err := cipher.Encrypt(key, data)
id, ciphertext, err := cipher.Encrypt(rand.Reader, key, data)
if err != nil {
return nil, err
}

View File

@ -5,103 +5,29 @@ import (
"crypto/ecdsa"
"crypto/rand"
"crypto/rsa"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"errors"
"fmt"
"hash"
"strconv"
"github.com/emmansun/gmsm/pkcs"
"github.com/emmansun/gmsm/sm2"
"github.com/emmansun/gmsm/sm3"
"github.com/emmansun/gmsm/sm9"
"github.com/emmansun/gmsm/smx509"
)
// Hash identifies a cryptographic hash function that is implemented in another
// package.
type Hash uint
type Opts = pkcs.PBES2Opts
type PBKDF2Opts = pkcs.PBKDF2Opts
type ScryptOpts = pkcs.ScryptOpts
const (
SHA1 Hash = 1 + iota
SHA224
SHA256
SHA384
SHA512
SHA512_224
SHA512_256
SM3
)
// New returns a new hash.Hash calculating the given hash function. New panics
// if the hash function is not linked into the binary.
func (h Hash) New() hash.Hash {
switch h {
case SM3:
return sm3.New()
case SHA1:
return sha1.New()
case SHA224:
return sha256.New224()
case SHA256:
return sha256.New()
case SHA384:
return sha512.New384()
case SHA512:
return sha512.New()
case SHA512_224:
return sha512.New512_224()
case SHA512_256:
return sha512.New512_256()
}
panic("pkcs8: requested hash function #" + strconv.Itoa(int(h)) + " is unavailable")
}
// DefaultOpts are the default options for encrypting a key if none are given.
// The defaults can be changed by the library user.
var DefaultOpts = &Opts{
Cipher: pkcs.AES256CBC,
KDFOpts: PBKDF2Opts{
SaltSize: 8,
IterationCount: 10000,
HMACHash: SHA256,
},
}
// KDFOpts contains options for a key derivation function.
// An implementation of this interface must be specified when encrypting a PKCS#8 key.
type KDFOpts interface {
// DeriveKey derives a key of size bytes from the given password and salt.
// It returns the key and the ASN.1-encodable parameters used.
DeriveKey(password, salt []byte, size int) (key []byte, params KDFParameters, err error)
// GetSaltSize returns the salt size specified.
GetSaltSize() int
// OID returns the OID of the KDF specified.
OID() asn1.ObjectIdentifier
}
// KDFParameters contains parameters (salt, etc.) for a key deriviation function.
// It must be a ASN.1-decodable structure.
// An implementation of this interface is created when decoding an encrypted PKCS#8 key.
type KDFParameters interface {
// DeriveKey derives a key of size bytes from the given password.
// It uses the salt from the decoded parameters.
DeriveKey(password []byte, size int) (key []byte, err error)
}
var kdfs = make(map[string]func() KDFParameters)
// RegisterKDF registers a function that returns a new instance of the given KDF
// parameters. This allows the library to support client-provided KDFs.
func RegisterKDF(oid asn1.ObjectIdentifier, params func() KDFParameters) {
kdfs[oid.String()] = params
}
var SM3 = pkcs.SM3
var SHA1 = pkcs.SHA1
var SHA224 = pkcs.SHA224
var SHA256 = pkcs.SHA256
var SHA384 = pkcs.SHA384
var SHA512 = pkcs.SHA512
var SHA512_224 = pkcs.SHA512_224
var SHA512_256 = pkcs.SHA512_256
// for encrypted private-key information
type encryptedPrivateKeyInfo struct {
@ -109,40 +35,10 @@ type encryptedPrivateKeyInfo struct {
EncryptedData []byte
}
// Opts contains options for encrypting a PKCS#8 key.
type Opts struct {
Cipher pkcs.Cipher
KDFOpts KDFOpts
}
// Unecrypted PKCS8
var (
oidPBES2 = asn1.ObjectIdentifier{1, 2, 840, 113549, 1, 5, 13}
)
type pbes2Params struct {
KeyDerivationFunc pkix.AlgorithmIdentifier
EncryptionScheme pkix.AlgorithmIdentifier
}
func parseKeyDerivationFunc(keyDerivationFunc pkix.AlgorithmIdentifier) (KDFParameters, error) {
oid := keyDerivationFunc.Algorithm.String()
newParams, ok := kdfs[oid]
if !ok {
return nil, fmt.Errorf("pkcs8: unsupported KDF (OID: %s)", oid)
}
params := newParams()
_, err := asn1.Unmarshal(keyDerivationFunc.Parameters.FullBytes, params)
if err != nil {
return nil, errors.New("pkcs8: invalid KDF parameters")
}
return params, nil
}
// ParsePrivateKey parses a DER-encoded PKCS#8 private key.
// Password can be nil.
// This is equivalent to ParsePKCS8PrivateKey.
func ParsePrivateKey(der []byte, password []byte) (any, KDFParameters, error) {
func ParsePrivateKey(der []byte, password []byte) (any, pkcs.KDFParameters, error) {
// No password provided, assume the private key is unencrypted
if len(password) == 0 {
privateKey, err := smx509.ParsePKCS8PrivateKey(der)
@ -158,33 +54,16 @@ func ParsePrivateKey(der []byte, password []byte) (any, KDFParameters, error) {
return nil, nil, errors.New("pkcs8: only PKCS #5 v2.0 supported")
}
if !privKey.EncryptionAlgorithm.Algorithm.Equal(oidPBES2) {
if !pkcs.IsPBES2(privKey.EncryptionAlgorithm) {
return nil, nil, errors.New("pkcs8: only PBES2 supported")
}
var params pbes2Params
var params pkcs.PBES2Params
if _, err := asn1.Unmarshal(privKey.EncryptionAlgorithm.Parameters.FullBytes, &params); err != nil {
return nil, nil, errors.New("pkcs8: invalid PBES2 parameters")
}
cipher, err := pkcs.GetCipher(params.EncryptionScheme)
if err != nil {
return nil, nil, err
}
kdfParams, err := parseKeyDerivationFunc(params.KeyDerivationFunc)
if err != nil {
return nil, nil, err
}
keySize := cipher.KeySize()
symkey, err := kdfParams.DeriveKey(password, keySize)
if err != nil {
return nil, nil, err
}
encryptedKey := privKey.EncryptedData
decryptedKey, err := cipher.Decrypt(symkey, &params.EncryptionScheme.Parameters, encryptedKey)
decryptedKey, kdfParams, err := params.Decrypt(password, privKey.EncryptedData)
if err != nil {
return nil, nil, err
}
@ -204,7 +83,7 @@ func MarshalPrivateKey(priv any, password []byte, opts *Opts) ([]byte, error) {
}
if opts == nil {
opts = DefaultOpts
opts = pkcs.DefaultOpts
}
// Convert private key into PKCS8 format
@ -213,47 +92,13 @@ func MarshalPrivateKey(priv any, password []byte, opts *Opts) ([]byte, error) {
return nil, err
}
encAlg := opts.Cipher
salt := make([]byte, opts.KDFOpts.GetSaltSize())
_, err = rand.Read(salt)
encryptionAlgorithm, encryptedKey, err := opts.Encrypt(rand.Reader, password, pkey)
if err != nil {
return nil, err
}
key, kdfParams, err := opts.KDFOpts.DeriveKey(password, salt, encAlg.KeySize())
if err != nil {
return nil, err
}
encryptionScheme, encryptedKey, err := encAlg.Encrypt(key, pkey)
if err != nil {
return nil, err
}
marshalledParams, err := asn1.Marshal(kdfParams)
if err != nil {
return nil, err
}
keyDerivationFunc := pkix.AlgorithmIdentifier{
Algorithm: opts.KDFOpts.OID(),
Parameters: asn1.RawValue{FullBytes: marshalledParams},
}
encryptionAlgorithmParams := pbes2Params{
EncryptionScheme: *encryptionScheme,
KeyDerivationFunc: keyDerivationFunc,
}
marshalledEncryptionAlgorithmParams, err := asn1.Marshal(encryptionAlgorithmParams)
if err != nil {
return nil, err
}
encryptionAlgorithm := pkix.AlgorithmIdentifier{
Algorithm: oidPBES2,
Parameters: asn1.RawValue{FullBytes: marshalledEncryptionAlgorithmParams},
}
encryptedPkey := encryptedPrivateKeyInfo{
EncryptionAlgorithm: encryptionAlgorithm,
EncryptionAlgorithm: *encryptionAlgorithm,
EncryptedData: encryptedKey,
}