diff --git a/pkcs7/decrypt.go b/pkcs7/decrypt.go index 43df784..bf183ab 100644 --- a/pkcs7/decrypt.go +++ b/pkcs7/decrypt.go @@ -3,13 +3,11 @@ package pkcs7 import ( "bytes" "crypto" - "crypto/rand" "encoding/asn1" "errors" "math/big" "github.com/emmansun/gmsm/pkcs" - "github.com/emmansun/gmsm/sm2" "github.com/emmansun/gmsm/smx509" ) @@ -93,26 +91,11 @@ func (p7 *PKCS7) decrypt(cert *smx509.Certificate, pkey crypto.PrivateKey, isCFC if recipient == nil { return nil, errors.New("pkcs7: no enveloped recipient for provided certificate") } - - switch pkey := pkey.(type) { - case crypto.Decrypter: - // Generic case to handle anything that provides the crypto.Decrypter interface. - encryptedKey := recipient.EncryptedKey - var decrypterOpts crypto.DecrypterOpts - if _, ok := pkey.(*sm2.PrivateKey); ok && isCFCA { - encryptedKey = make([]byte, len(recipient.EncryptedKey)+1) - encryptedKey[0] = 0x04 - copy(encryptedKey[1:], recipient.EncryptedKey) - decrypterOpts = sm2.NewPlainDecrypterOpts(sm2.C1C2C3) - } - - contentKey, err := pkey.Decrypt(rand.Reader, encryptedKey, decrypterOpts) - if err != nil { - return nil, err - } - return decryptableData.GetEncryptedContentInfo().decrypt(contentKey) + contentKey, err := p7.session.DecryptDataKey(recipient.EncryptedKey, pkey, cert, isCFCA) + if err != nil { + return nil, err } - return nil, ErrUnsupportedAlgorithm + return decryptableData.GetEncryptedContentInfo().decrypt(contentKey) } // DecryptUsingPSK decrypts encrypted data using caller provided diff --git a/pkcs7/envelope.go b/pkcs7/envelope.go index 45ab908..4c08d1b 100644 --- a/pkcs7/envelope.go +++ b/pkcs7/envelope.go @@ -2,16 +2,13 @@ package pkcs7 import ( "bytes" - "crypto/ecdsa" "crypto/rand" - "crypto/rsa" "crypto/sha1" "crypto/x509/pkix" "encoding/asn1" "errors" "github.com/emmansun/gmsm/pkcs" - "github.com/emmansun/gmsm/sm2" "github.com/emmansun/gmsm/smx509" "golang.org/x/crypto/cryptobyte" cryptobyte_asn1 "golang.org/x/crypto/cryptobyte/asn1" @@ -22,6 +19,7 @@ type EnvelopedData struct { key []byte contentType asn1.ObjectIdentifier encryptedContentType asn1.ObjectIdentifier + session Session } type envelopedData struct { @@ -106,7 +104,7 @@ func Encrypt(cipher pkcs.Cipher, content []byte, recipients []*smx509.Certificat } for _, recipient := range recipients { if err := ed.AddRecipient(recipient, 0, func(cert *smx509.Certificate, key []byte) ([]byte, error) { - return encryptKey(key, cert, false) + return ed.session.EncryptdDataKey(key, cert, nil) }); err != nil { return nil, err } @@ -136,7 +134,7 @@ func EncryptCFCA(cipher pkcs.Cipher, content []byte, recipients []*smx509.Certif // recipient keys for each recipient public key. // The OIDs use GM/T 0010 - 2012 set and the encrypted key uses ASN.1 format. // This function uses recipient's SubjectKeyIdentifier to identify the recipient. -// This function is used for CFCA compatibility. +// This function is used for CFCA compatibility. func EnvelopeMessageCFCA(cipher pkcs.Cipher, content []byte, recipients []*smx509.Certificate) ([]byte, error) { return encryptSM(cipher, content, recipients, 2, false) } @@ -148,7 +146,7 @@ func encryptSM(cipher pkcs.Cipher, content []byte, recipients []*smx509.Certific } for _, recipient := range recipients { if err := ed.AddRecipient(recipient, version, func(cert *smx509.Certificate, key []byte) ([]byte, error) { - return encryptKey(key, cert, isLegacyCFCA) + return ed.session.EncryptdDataKey(key, cert, isLegacyCFCA) }); err != nil { return nil, err } @@ -158,22 +156,35 @@ func encryptSM(cipher pkcs.Cipher, content []byte, recipients []*smx509.Certific // NewEnvelopedData creates a new EnvelopedData structure with the provided cipher and content. func NewEnvelopedData(cipher pkcs.Cipher, content []byte) (*EnvelopedData, error) { - return newEnvelopedData(cipher, content, OIDEnvelopedData) + return newEnvelopedData(cipher, content, OIDEnvelopedData, nil) } // NewSM2EnvelopedData creates a new EnvelopedData structure with the provided cipher and content. // The OIDs use GM/T 0010 - 2012 set. func NewSM2EnvelopedData(cipher pkcs.Cipher, content []byte) (*EnvelopedData, error) { - return newEnvelopedData(cipher, content, SM2OIDEnvelopedData) + return newEnvelopedData(cipher, content, SM2OIDEnvelopedData, nil) } -func newEnvelopedData(cipher pkcs.Cipher, content []byte, contentType asn1.ObjectIdentifier) (*EnvelopedData, error) { - var key []byte - var err error +// NewEnvelopedDataWithSession creates a new EnvelopedData structure with the provided cipher, content and sessionKey. +func NewEnvelopedDataWithSession(cipher pkcs.Cipher, content []byte, session Session) (*EnvelopedData, error) { + return newEnvelopedData(cipher, content, OIDEnvelopedData, session) +} - // Create key - key = make([]byte, cipher.KeySize()) - if _, err = rand.Read(key); err != nil { +// NewSM2EnvelopedDataWithSession creates a new EnvelopedData structure with the provided cipher, content and sessionKey. +// The OIDs use GM/T 0010 - 2012 set. +func NewSM2EnvelopedDataWithSession(cipher pkcs.Cipher, content []byte, session Session) (*EnvelopedData, error) { + return newEnvelopedData(cipher, content, SM2OIDEnvelopedData, session) +} + +func newEnvelopedData(cipher pkcs.Cipher, content []byte, contentType asn1.ObjectIdentifier, session Session) (*EnvelopedData, error) { + ed := &EnvelopedData{} + ed.session = session + if ed.session == nil { + ed.session = DefaultSession{} + } + + key, err := ed.session.GenerateDataKey(cipher.KeySize()) + if err != nil { return nil, err } @@ -181,7 +192,7 @@ func newEnvelopedData(cipher pkcs.Cipher, content []byte, contentType asn1.Objec if err != nil { return nil, err } - ed := &EnvelopedData{} + ed.contentType = contentType ed.encryptedContentType = OIDData version := 0 @@ -274,21 +285,3 @@ func newEncryptedContent(contentType asn1.ObjectIdentifier, alg *pkix.AlgorithmI func marshalEncryptedContent(content []byte) asn1.RawValue { return asn1.RawValue{Tag: 0, Class: asn1.ClassContextSpecific, Bytes: content} } - -func encryptKey(key []byte, recipient *smx509.Certificate, isCFCA bool) ([]byte, error) { - if pub, ok := recipient.PublicKey.(*rsa.PublicKey); ok { - return rsa.EncryptPKCS1v15(rand.Reader, pub, key) - } - if pub, ok := recipient.PublicKey.(*ecdsa.PublicKey); ok && pub.Curve == sm2.P256() { - if isCFCA { - encryptedKey, err := sm2.Encrypt(rand.Reader, pub, key, sm2.NewPlainEncrypterOpts(sm2.MarshalUncompressed, sm2.C1C2C3)) - if err != nil { - return nil, err - } - return encryptedKey[1:], nil - } else { - return sm2.EncryptASN1(rand.Reader, pub, key) - } - } - return nil, errors.New("pkcs7: only supports RSA/SM2 key") -} diff --git a/pkcs7/pkcs7.go b/pkcs7/pkcs7.go index 4594ca1..0fa2a6c 100644 --- a/pkcs7/pkcs7.go +++ b/pkcs7/pkcs7.go @@ -27,6 +27,7 @@ type PKCS7 struct { CRLs []pkix.CertificateList Signers []signerInfo raw any + session Session } type contentInfo struct { @@ -195,8 +196,13 @@ func getOIDForEncryptionAlgorithm(pkey any, OIDDigestAlg asn1.ObjectIdentifier) } -// Parse decodes a DER encoded PKCS7 package +// Parse decodes a DER encoded PKCS7 package and assign the default session to the PKCS7 object func Parse(data []byte) (p7 *PKCS7, err error) { + return ParseWithSession(DefaultSession{}, data) +} + +// ParseWithSession decodes a DER encoded PKCS7 package and assign the session to the PKCS7 object +func ParseWithSession(session Session, data []byte) (p7 *PKCS7, err error) { if len(data) == 0 { return nil, errors.New("pkcs7: input data is empty") } @@ -218,33 +224,35 @@ func Parse(data []byte) (p7 *PKCS7, err error) { case info.ContentType.Equal(OIDSignedData) || info.ContentType.Equal(SM2OIDSignedData): return parseSignedData(info.Content.Bytes) case info.ContentType.Equal(OIDEnvelopedData) || info.ContentType.Equal(SM2OIDEnvelopedData): - return parseEnvelopedData(info.Content.Bytes) + return parseEnvelopedData(session, info.Content.Bytes) case info.ContentType.Equal(OIDEncryptedData) || info.ContentType.Equal(SM2OIDEncryptedData): - return parseEncryptedData(info.Content.Bytes) + return parseEncryptedData(session, info.Content.Bytes) case info.ContentType.Equal(OIDSignedEnvelopedData) || info.ContentType.Equal(SM2OIDSignedEnvelopedData): - return parseSignedEnvelopedData(info.Content.Bytes) + return parseSignedEnvelopedData(session, info.Content.Bytes) default: return nil, ErrUnsupportedContentType } } -func parseEnvelopedData(data []byte) (*PKCS7, error) { +func parseEnvelopedData(session Session, data []byte) (*PKCS7, error) { var ed envelopedData if _, err := asn1.Unmarshal(data, &ed); err != nil { return nil, err } return &PKCS7{ - raw: ed, + raw: ed, + session: session, }, nil } -func parseEncryptedData(data []byte) (*PKCS7, error) { +func parseEncryptedData(session Session, data []byte) (*PKCS7, error) { var ed encryptedData if _, err := asn1.Unmarshal(data, &ed); err != nil { return nil, err } return &PKCS7{ - raw: ed, + raw: ed, + session: session, }, nil } diff --git a/pkcs7/session.go b/pkcs7/session.go new file mode 100644 index 0000000..ab74948 --- /dev/null +++ b/pkcs7/session.go @@ -0,0 +1,78 @@ +package pkcs7 + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/rsa" + "errors" + + "github.com/emmansun/gmsm/sm2" + "github.com/emmansun/gmsm/smx509" +) + +type Session interface { + // GenerateDataKey returns the data key to be used for encryption + GenerateDataKey(size int) ([]byte, error) + + // EncryptdDataKey encrypts the key with the provided certificate public key + EncryptdDataKey(key []byte, cert *smx509.Certificate, opts any) ([]byte, error) + + // DecryptDataKey decrypts the key with the provided certificate private key + DecryptDataKey(key []byte, priv crypto.PrivateKey, cert *smx509.Certificate, opts any) ([]byte, error) +} + +// DefaultSession is the default implementation of Session without any special handling +// Custom implementations can be provided to handle key reuse, cache, etc. +type DefaultSession struct{} + +func (d DefaultSession) GenerateDataKey(size int) ([]byte, error) { + key := make([]byte, size) + if _, err := rand.Read(key); err != nil { + return nil, err + } + return key, nil +} + +func (d DefaultSession) EncryptdDataKey(key []byte, cert *smx509.Certificate, opts any) ([]byte, error) { + switch pub := cert.PublicKey.(type) { + case *rsa.PublicKey: + return rsa.EncryptPKCS1v15(rand.Reader, pub, key) + case *ecdsa.PublicKey: + if pub.Curve == sm2.P256() { + if isLegacyCFCA, ok := opts.(bool); ok && isLegacyCFCA { + encryptedKey, err := sm2.Encrypt(rand.Reader, pub, key, sm2.NewPlainEncrypterOpts(sm2.MarshalUncompressed, sm2.C1C2C3)) + if err != nil { + return nil, err + } + return encryptedKey[1:], nil + } else { + return sm2.EncryptASN1(rand.Reader, pub, key) + } + } + } + return nil, errors.New("pkcs7: only supports RSA/SM2 key") +} + +func (d DefaultSession) DecryptDataKey(key []byte, priv crypto.PrivateKey, cert *smx509.Certificate, opts any) ([]byte, error) { + switch pkey := priv.(type) { + case crypto.Decrypter: + // Generic case to handle anything that provides the crypto.Decrypter interface. + encryptedKey := key + var decrypterOpts crypto.DecrypterOpts + if _, ok := pkey.(*sm2.PrivateKey); ok { + if isLegacyCFCA, ok := opts.(bool); ok && isLegacyCFCA { + encryptedKey = make([]byte, len(key)+1) + encryptedKey[0] = 0x04 + copy(encryptedKey[1:], key) + decrypterOpts = sm2.NewPlainDecrypterOpts(sm2.C1C2C3) + } + } + contentKey, err := pkey.Decrypt(rand.Reader, encryptedKey, decrypterOpts) + if err != nil { + return nil, err + } + return contentKey, nil + } + return nil, ErrUnsupportedAlgorithm +} diff --git a/pkcs7/sign_enveloped.go b/pkcs7/sign_enveloped.go index d2bb4f3..dba76c1 100644 --- a/pkcs7/sign_enveloped.go +++ b/pkcs7/sign_enveloped.go @@ -47,7 +47,7 @@ func (data signedEnvelopedData) GetEncryptedContentInfo() *encryptedContentInfo return &data.EncryptedContentInfo } -func parseSignedEnvelopedData(data []byte) (*PKCS7, error) { +func parseSignedEnvelopedData(session Session, data []byte) (*PKCS7, error) { var sed signedEnvelopedData if _, err := asn1.Unmarshal(data, &sed); err != nil { return nil, err @@ -61,7 +61,8 @@ func parseSignedEnvelopedData(data []byte) (*PKCS7, error) { Certificates: certs, CRLs: sed.CRLs, Signers: sed.SignerInfos, - raw: sed}, nil + raw: sed, + session: session}, nil } type VerifyFunc func() error @@ -140,6 +141,7 @@ type SignedAndEnvelopedData struct { data, cek []byte contentTypeOid asn1.ObjectIdentifier digestOid asn1.ObjectIdentifier + session Session } // NewSignedAndEnvelopedData takes data and cipher and initializes a new PKCS7 SignedAndEnvelopedData structure @@ -149,9 +151,11 @@ func NewSignedAndEnvelopedData(data []byte, cipher pkcs.Cipher) (*SignedAndEnvel var key []byte var err error - // Create key - key = make([]byte, cipher.KeySize()) - _, err = rand.Read(key) + result := &SignedAndEnvelopedData{ + session: DefaultSession{}, + } + + key, err = result.session.GenerateDataKey(cipher.KeySize()) if err != nil { return nil, err } @@ -160,8 +164,11 @@ func NewSignedAndEnvelopedData(data []byte, cipher pkcs.Cipher) (*SignedAndEnvel if err != nil { return nil, err } - - sed := signedEnvelopedData{ + result.cek = key + result.contentTypeOid = OIDSignedEnvelopedData + result.data = data + result.digestOid = OIDDigestAlgorithmSHA1 + result.sed = signedEnvelopedData{ Version: 1, // 0 or 1? EncryptedContentInfo: encryptedContentInfo{ ContentType: OIDData, @@ -169,7 +176,7 @@ func NewSignedAndEnvelopedData(data []byte, cipher pkcs.Cipher) (*SignedAndEnvel EncryptedContent: marshalEncryptedContent(ciphertext), }, } - return &SignedAndEnvelopedData{sed: sed, data: data, cek: key, digestOid: OIDDigestAlgorithmSHA1, contentTypeOid: OIDSignedEnvelopedData}, nil + return result, nil } // NewSMSignedAndEnvelopedData takes data and cipher and initializes a new PKCS7(SM) SignedAndEnvelopedData structure @@ -271,7 +278,7 @@ func (saed *SignedAndEnvelopedData) AddCertificate(cert *smx509.Certificate) { // AddRecipient adds a recipient to the payload func (saed *SignedAndEnvelopedData) AddRecipient(recipient *smx509.Certificate) error { - encryptedKey, err := encryptKey(saed.cek, recipient, false) //TODO: check if CFCA has such function + encryptedKey, err := saed.session.EncryptdDataKey(saed.cek, recipient, nil) if err != nil { return err }