gmsm/smx509/csr_rsp.go
2024-06-13 17:44:32 +08:00

147 lines
4.6 KiB
Go

// This file marshals and parses CSRResponse, as defined in GM/T 0092-2020,
// "Specification of certificate request syntax based on SM2 cryptographic algorithm".

package smx509
import (
"bytes"
"crypto/ecdsa"
"crypto/rand"
"encoding/asn1"
"errors"
"github.com/emmansun/gmsm/sm2"
)
// CSRResponse represents the response of a certificate signing request
// (GM/T 0092-2020). ParseCSRResponse returns values of this type.
type CSRResponse struct {
	// SignCerts holds the signing certificate(s). When parsed, the first
	// entry's public key has been checked against the requester's sign
	// private key.
	SignCerts []*Certificate
	// EncryptPrivateKey is the encryption private key recovered from the
	// enveloped key carried in the response, or nil if none was present.
	EncryptPrivateKey *sm2.PrivateKey
	// EncryptCerts holds the encryption certificate(s), if any. When
	// EncryptPrivateKey is non-nil after parsing, the first entry's public
	// key matches it.
	EncryptCerts []*Certificate
}
type (
	// tbsCSRResponse is the ASN.1 wire form of a CSRResponse. Field order and
	// struct tags define the encoding and must not change:
	//
	//	SignCerts            untagged, encoded as a SEQUENCE
	//	EncryptedPrivateKey  optional, context-specific [0]
	//	EncryptCerts         optional, context-specific [1]
	tbsCSRResponse struct {
		SignCerts           rawCertificates
		EncryptedPrivateKey asn1.RawValue   `asn1:"optional,tag:0"`
		EncryptCerts        rawCertificates `asn1:"optional,tag:1"`
	}

	// rawCertificates carries an already-encoded certificate list. Its only
	// field is an asn1.RawContent, so encoding/asn1 fills it with the complete
	// element bytes on Unmarshal and reuses its content on Marshal.
	rawCertificates struct {
		Raw asn1.RawContent
	}
)
// ParseCSRResponse parses a DER-encoded CSRResponse (GM/T 0092-2020).
//
// signPrivateKey is the requester's signing key. Its public key must equal the
// public key of the first sign certificate, and it is used to open the
// enveloped encryption private key when the response carries one. If an
// encryption private key is present, the response must also carry at least
// one encrypt certificate whose first entry matches that key.
//
// We do NOT verify the cert chain here; that is the caller's responsibility.
func ParseCSRResponse(signPrivateKey *sm2.PrivateKey, der []byte) (CSRResponse, error) {
	// Reject a nil key up front; dereferencing it below would panic.
	if signPrivateKey == nil {
		return CSRResponse{}, errors.New("smx509: missing sign private key")
	}

	var resp tbsCSRResponse
	rest, err := asn1.Unmarshal(der, &resp)
	if err != nil || len(rest) > 0 {
		return CSRResponse{}, errors.New("smx509: invalid CSRResponse asn1 data")
	}

	signCerts, err := resp.SignCerts.Parse()
	if err != nil || len(signCerts) == 0 {
		return CSRResponse{}, errors.New("smx509: invalid sign certificates")
	}
	// The response must be addressed to the holder of signPrivateKey.
	if !signPrivateKey.PublicKey.Equal(signCerts[0].PublicKey) {
		return CSRResponse{}, errors.New("smx509: sign cert public key mismatch")
	}

	var encPrivateKey *sm2.PrivateKey
	if len(resp.EncryptedPrivateKey.Bytes) > 0 {
		encPrivateKey, err = sm2.ParseEnvelopedPrivateKey(signPrivateKey, resp.EncryptedPrivateKey.Bytes)
		if err != nil {
			return CSRResponse{}, err
		}
	}

	// Parse returns (nil, nil) when the optional [1] element is absent.
	encryptCerts, err := resp.EncryptCerts.Parse()
	if err != nil {
		return CSRResponse{}, err
	}

	// A delivered encryption key must come with a certificate for it.
	if encPrivateKey != nil {
		if len(encryptCerts) == 0 {
			return CSRResponse{}, errors.New("smx509: missing encrypt certificate")
		}
		if !encPrivateKey.PublicKey.Equal(encryptCerts[0].PublicKey) {
			return CSRResponse{}, errors.New("smx509: encrypt key pair mismatch")
		}
	}

	return CSRResponse{
		SignCerts:         signCerts,
		EncryptPrivateKey: encPrivateKey,
		EncryptCerts:      encryptCerts,
	}, nil
}
// MarshalCSRResponse marshals a CSRResponse to DER format.
//
// signCerts must be non-empty, and its first certificate must hold an SM2
// public key. That key is passed to sm2.MarshalEnvelopedPrivateKey to envelope
// encryptPrivateKey. When encryptPrivateKey is non-nil, encryptCerts must be
// non-empty, and its first certificate's public key must equal
// encryptPrivateKey's public key. The enveloped key is then written as the
// context-specific [0] element and encryptCerts as the [1] element.
//
// Note: when encryptPrivateKey is nil, encryptCerts is not written at all,
// even if it is non-empty.
func MarshalCSRResponse(signCerts []*Certificate, encryptPrivateKey *sm2.PrivateKey, encryptCerts []*Certificate) ([]byte, error) {
	if len(signCerts) == 0 {
		return nil, errors.New("smx509: no sign certificate")
	}
	signPubKey, isECDSA := signCerts[0].PublicKey.(*ecdsa.PublicKey)
	if !isECDSA || !sm2.IsSM2PublicKey(signPubKey) {
		return nil, errors.New("smx509: invalid sign public key")
	}

	withEncryptKey := encryptPrivateKey != nil
	if withEncryptKey {
		// The encryption certificate must exist and certify encryptPrivateKey.
		switch {
		case len(encryptCerts) == 0:
			return nil, errors.New("smx509: missing encrypt certificate")
		case !encryptPrivateKey.PublicKey.Equal(encryptCerts[0].PublicKey):
			return nil, errors.New("smx509: encrypt key pair mismatch")
		}
	}

	out := tbsCSRResponse{SignCerts: marshalCertificates(signCerts)}
	if !withEncryptKey {
		return asn1.Marshal(out)
	}

	envelope, err := sm2.MarshalEnvelopedPrivateKey(rand.Reader, signPubKey, encryptPrivateKey)
	if err != nil {
		return nil, err
	}
	out.EncryptedPrivateKey = asn1.RawValue{
		Class:      asn1.ClassContextSpecific,
		Tag:        0,
		IsCompound: true,
		Bytes:      envelope,
	}
	out.EncryptCerts = marshalCertificates(encryptCerts)
	return asn1.Marshal(out)
}
// marshalCertificates concatenates the raw DER (cert.Raw) of every
// certificate in certs and wraps the result with marshalCertificateBytes.
//
// Every element of certs must be non-nil.
func marshalCertificates(certs []*Certificate) rawCertificates {
	// Pre-size the buffer so the concatenation needs a single allocation.
	size := 0
	for _, cert := range certs {
		size += len(cert.Raw)
	}
	var buf bytes.Buffer
	buf.Grow(size)
	for _, cert := range certs {
		buf.Write(cert.Raw)
	}
	// The error is intentionally discarded. marshalCertificateBytes can only
	// fail if asn1.Marshal of a plain asn1.RawValue fails, and encoding/asn1
	// is not expected to fail for that input.
	// NOTE(review): re-check this if the encoder's RawValue handling changes.
	rawCerts, _ := marshalCertificateBytes(buf.Bytes())
	return rawCerts
}
// marshalCertificateBytes wraps certs (concatenated DER certificates) in a
// constructed, context-specific [0] element. It stores the full encoding,
// including tag and length, in a rawCertificates.
//
// The wrapper's tag and length do not survive marshalling as-is: when a struct
// whose first field is a non-empty asn1.RawContent is marshalled,
// asn1.Marshal strips that content's leading tag and length and writes its
// own. The wrapper must still be present. If Raw held the bare certificates,
// asn1.Marshal would strip the first certificate's own tag and length instead.
func marshalCertificateBytes(certs []byte) (rawCertificates, error) {
	wrapped, err := asn1.Marshal(asn1.RawValue{
		Class:      asn1.ClassContextSpecific,
		Tag:        0,
		IsCompound: true,
		Bytes:      certs,
	})
	if err != nil {
		return rawCertificates{}, err
	}
	return rawCertificates{Raw: wrapped}, nil
}
// Parse decodes the certificate list held in rc.
//
// An empty rc yields (nil, nil). Otherwise the first ASN.1 element in Raw is
// decoded, and its tag and length are discarded (any trailing bytes are
// ignored). Its content octets are passed to ParseCertificates.
func (rc rawCertificates) Parse() ([]*Certificate, error) {
	if len(rc.Raw) == 0 {
		return nil, nil
	}
	var wrapper asn1.RawValue
	_, err := asn1.Unmarshal(rc.Raw, &wrapper)
	if err != nil {
		return nil, err
	}
	return ParseCertificates(wrapper.Bytes)
}