mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-26 04:06:18 +08:00
sm2: code review and refactor
This commit is contained in:
parent
a71e806a2d
commit
89317b8f0b
@ -1,21 +1,12 @@
|
|||||||
// Package sm2 implements ShangMi(SM) sm2 digital signature, public key encryption and key exchange algorithms.
|
|
||||||
package sm2
|
package sm2
|
||||||
|
|
||||||
// Further references:
|
|
||||||
// [NSA]: Suite B implementer's guide to FIPS 186-3
|
|
||||||
// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.182.4503&rep=rep1&type=pdf
|
|
||||||
// [SECG]: SECG, SEC1
|
|
||||||
// http://www.secg.org/sec1-v2.pdf
|
|
||||||
// [GM/T]: SM2 GB/T 32918.2-2016, GB/T 32918.4-2016
|
|
||||||
//
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto"
|
"crypto"
|
||||||
"crypto/ecdsa"
|
"crypto/ecdsa"
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
_subtle "crypto/subtle"
|
_subtle "crypto/subtle"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"hash"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
"sync"
|
"sync"
|
||||||
@ -24,91 +15,12 @@ import (
|
|||||||
"github.com/emmansun/gmsm/internal/bigmod"
|
"github.com/emmansun/gmsm/internal/bigmod"
|
||||||
"github.com/emmansun/gmsm/internal/randutil"
|
"github.com/emmansun/gmsm/internal/randutil"
|
||||||
_sm2ec "github.com/emmansun/gmsm/internal/sm2ec"
|
_sm2ec "github.com/emmansun/gmsm/internal/sm2ec"
|
||||||
"github.com/emmansun/gmsm/internal/subtle"
|
|
||||||
"github.com/emmansun/gmsm/sm2/sm2ec"
|
"github.com/emmansun/gmsm/sm2/sm2ec"
|
||||||
"github.com/emmansun/gmsm/sm3"
|
"github.com/emmansun/gmsm/sm3"
|
||||||
"golang.org/x/crypto/cryptobyte"
|
"golang.org/x/crypto/cryptobyte"
|
||||||
"golang.org/x/crypto/cryptobyte/asn1"
|
"golang.org/x/crypto/cryptobyte/asn1"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
|
||||||
uncompressed byte = 0x04
|
|
||||||
compressed02 byte = 0x02
|
|
||||||
compressed03 byte = compressed02 | 0x01
|
|
||||||
hybrid06 byte = 0x06
|
|
||||||
hybrid07 byte = hybrid06 | 0x01
|
|
||||||
)
|
|
||||||
|
|
||||||
// PrivateKey represents an ECDSA SM2 private key.
|
|
||||||
// It implemented both crypto.Decrypter and crypto.Signer interfaces.
|
|
||||||
type PrivateKey struct {
|
|
||||||
ecdsa.PrivateKey
|
|
||||||
// inverseOfKeyPlus1 is set under inverseOfKeyPlus1Once
|
|
||||||
inverseOfKeyPlus1 *bigmod.Nat
|
|
||||||
inverseOfKeyPlus1Once sync.Once
|
|
||||||
}
|
|
||||||
|
|
||||||
type pointMarshalMode byte
|
|
||||||
|
|
||||||
const (
|
|
||||||
//MarshalUncompressed uncompressed mashal mode
|
|
||||||
MarshalUncompressed pointMarshalMode = iota
|
|
||||||
//MarshalCompressed compressed mashal mode
|
|
||||||
MarshalCompressed
|
|
||||||
//MarshalHybrid hybrid mashal mode
|
|
||||||
MarshalHybrid
|
|
||||||
)
|
|
||||||
|
|
||||||
type ciphertextSplicingOrder byte
|
|
||||||
|
|
||||||
const (
|
|
||||||
C1C3C2 ciphertextSplicingOrder = iota
|
|
||||||
C1C2C3
|
|
||||||
)
|
|
||||||
|
|
||||||
type ciphertextEncoding byte
|
|
||||||
|
|
||||||
const (
|
|
||||||
ENCODING_PLAIN ciphertextEncoding = iota
|
|
||||||
ENCODING_ASN1
|
|
||||||
)
|
|
||||||
|
|
||||||
// EncrypterOpts encryption options
|
|
||||||
type EncrypterOpts struct {
|
|
||||||
ciphertextEncoding ciphertextEncoding
|
|
||||||
pointMarshalMode pointMarshalMode
|
|
||||||
ciphertextSplicingOrder ciphertextSplicingOrder
|
|
||||||
}
|
|
||||||
|
|
||||||
// DecrypterOpts decryption options
|
|
||||||
type DecrypterOpts struct {
|
|
||||||
ciphertextEncoding ciphertextEncoding
|
|
||||||
cipherTextSplicingOrder ciphertextSplicingOrder
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPlainEncrypterOpts creates a SM2 non-ASN1 encrypter options.
|
|
||||||
func NewPlainEncrypterOpts(marhsalMode pointMarshalMode, splicingOrder ciphertextSplicingOrder) *EncrypterOpts {
|
|
||||||
return &EncrypterOpts{ENCODING_PLAIN, marhsalMode, splicingOrder}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPlainDecrypterOpts creates a SM2 non-ASN1 decrypter options.
|
|
||||||
func NewPlainDecrypterOpts(splicingOrder ciphertextSplicingOrder) *DecrypterOpts {
|
|
||||||
return &DecrypterOpts{ENCODING_PLAIN, splicingOrder}
|
|
||||||
}
|
|
||||||
|
|
||||||
func toBytes(curve elliptic.Curve, value *big.Int) []byte {
|
|
||||||
byteLen := (curve.Params().BitSize + 7) >> 3
|
|
||||||
result := make([]byte, byteLen)
|
|
||||||
value.FillBytes(result)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
var defaultEncrypterOpts = &EncrypterOpts{ENCODING_PLAIN, MarshalUncompressed, C1C3C2}
|
|
||||||
|
|
||||||
var ASN1EncrypterOpts = &EncrypterOpts{ENCODING_ASN1, MarshalUncompressed, C1C3C2}
|
|
||||||
|
|
||||||
var ASN1DecrypterOpts = &DecrypterOpts{ENCODING_ASN1, C1C3C2}
|
|
||||||
|
|
||||||
// directSigning is a standard Hash value that signals that no pre-hashing
|
// directSigning is a standard Hash value that signals that no pre-hashing
|
||||||
// should be performed.
|
// should be performed.
|
||||||
var directSigning crypto.Hash = 0
|
var directSigning crypto.Hash = 0
|
||||||
@ -118,7 +30,7 @@ type Signer interface {
|
|||||||
SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, error)
|
SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SM2SignerOption implements crypto.SignerOpts interface.
|
// SM2SignerOption implements crypto.SignerOpts interface and is used for SM2-specific signing options.
|
||||||
// It is specific for SM2, used in private key's Sign method.
|
// It is specific for SM2, used in private key's Sign method.
|
||||||
type SM2SignerOption struct {
|
type SM2SignerOption struct {
|
||||||
uid []byte
|
uid []byte
|
||||||
@ -146,11 +58,28 @@ func (*SM2SignerOption) HashFunc() crypto.Hash {
|
|||||||
return directSigning
|
return directSigning
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
errInvalidPrivateKey = errors.New("sm2: invalid private key")
|
||||||
|
errInvalidPublicKey = errors.New("sm2: invalid public key")
|
||||||
|
)
|
||||||
|
|
||||||
|
// PrivateKey represents an ECDSA SM2 private key.
|
||||||
|
// It embeds ecdsa.PrivateKey and includes additional fields for SM2-specific operations.
|
||||||
|
// It implements both crypto.Decrypter and crypto.Signer interfaces.
|
||||||
|
type PrivateKey struct {
|
||||||
|
ecdsa.PrivateKey
|
||||||
|
// inverseOfKeyPlus1 stores the modular inverse of (private key + 1) modulo the curve order.
|
||||||
|
// It is computed lazily and cached using sync.Once to ensure it is only calculated once.
|
||||||
|
inverseOfKeyPlus1 *bigmod.Nat
|
||||||
|
inverseOfKeyPlus1Once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
// FromECPrivateKey convert an ecdsa private key to SM2 private key.
|
// FromECPrivateKey convert an ecdsa private key to SM2 private key.
|
||||||
func (priv *PrivateKey) FromECPrivateKey(key *ecdsa.PrivateKey) (*PrivateKey, error) {
|
func (priv *PrivateKey) FromECPrivateKey(key *ecdsa.PrivateKey) (*PrivateKey, error) {
|
||||||
if key.Curve != sm2ec.P256() {
|
if key.Curve != sm2ec.P256() {
|
||||||
return nil, errors.New("sm2: it's NOT a sm2 curve private key")
|
return nil, errors.New("sm2: not an SM2 curve private key")
|
||||||
}
|
}
|
||||||
|
// Copy the ECDSA private key fields to the SM2 private key
|
||||||
priv.PrivateKey = *key
|
priv.PrivateKey = *key
|
||||||
return priv, nil
|
return priv, nil
|
||||||
}
|
}
|
||||||
@ -160,13 +89,7 @@ func (priv *PrivateKey) Equal(x crypto.PrivateKey) bool {
|
|||||||
if !ok {
|
if !ok {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return priv.PublicKey.Equal(&xx.PublicKey) && bigIntEqual(priv.D, xx.D)
|
return priv.PublicKey.Equal(&xx.PublicKey) && _subtle.ConstantTimeCompare(priv.D.Bytes(), xx.D.Bytes()) == 1
|
||||||
}
|
|
||||||
|
|
||||||
// bigIntEqual reports whether a and b are equal leaking only their bit length
|
|
||||||
// through timing side-channels.
|
|
||||||
func bigIntEqual(a, b *big.Int) bool {
|
|
||||||
return _subtle.ConstantTimeCompare(a.Bytes(), b.Bytes()) == 1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sign signs digest with priv, reading randomness from rand. Compliance with GB/T 32918.2-2016.
|
// Sign signs digest with priv, reading randomness from rand. Compliance with GB/T 32918.2-2016.
|
||||||
@ -186,124 +109,6 @@ func (priv *PrivateKey) SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, er
|
|||||||
return priv.Sign(rand, msg, NewSM2SignerOption(true, uid))
|
return priv.Sign(rand, msg, NewSM2SignerOption(true, uid))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrypt decrypts ciphertext msg to plaintext.
|
|
||||||
// The opts argument should be appropriate for the primitive used.
|
|
||||||
// Compliance with GB/T 32918.4-2016 chapter 7.
|
|
||||||
func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
|
|
||||||
var sm2Opts *DecrypterOpts
|
|
||||||
sm2Opts, _ = opts.(*DecrypterOpts)
|
|
||||||
return decrypt(priv, msg, sm2Opts)
|
|
||||||
}
|
|
||||||
|
|
||||||
const maxRetryLimit = 100
|
|
||||||
|
|
||||||
var (
|
|
||||||
errCiphertextTooShort = errors.New("sm2: ciphertext too short")
|
|
||||||
)
|
|
||||||
|
|
||||||
// EncryptASN1 sm2 encrypt and output ASN.1 result, compliance with GB/T 32918.4-2016.
|
|
||||||
//
|
|
||||||
// The random parameter is used as a source of entropy to ensure that
|
|
||||||
// encrypting the same message twice doesn't result in the same ciphertext.
|
|
||||||
// Most applications should use [crypto/rand.Reader] as random.
|
|
||||||
func EncryptASN1(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) {
|
|
||||||
return Encrypt(random, pub, msg, ASN1EncrypterOpts)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encrypt sm2 encrypt implementation, compliance with GB/T 32918.4-2016.
|
|
||||||
//
|
|
||||||
// The random parameter is used as a source of entropy to ensure that
|
|
||||||
// encrypting the same message twice doesn't result in the same ciphertext.
|
|
||||||
// Most applications should use [crypto/rand.Reader] as random.
|
|
||||||
func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *EncrypterOpts) ([]byte, error) {
|
|
||||||
//A3, requirement is to check if h*P is infinite point, h is 1
|
|
||||||
if pub.X.Sign() == 0 && pub.Y.Sign() == 0 {
|
|
||||||
return nil, errors.New("sm2: public key point is the infinity")
|
|
||||||
}
|
|
||||||
if len(msg) == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if opts == nil {
|
|
||||||
opts = defaultEncrypterOpts
|
|
||||||
}
|
|
||||||
switch pub.Curve.Params() {
|
|
||||||
case P256().Params():
|
|
||||||
return encryptSM2EC(p256(), pub, random, msg, opts)
|
|
||||||
default:
|
|
||||||
return encryptLegacy(random, pub, msg, opts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func encryptSM2EC(c *sm2Curve, pub *ecdsa.PublicKey, random io.Reader, msg []byte, opts *EncrypterOpts) ([]byte, error) {
|
|
||||||
Q, err := c.pointFromAffine(pub.X, pub.Y)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var retryCount int = 0
|
|
||||||
for {
|
|
||||||
k, C1, err := randomPoint(c, random, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
C2, err := Q.ScalarMult(Q, k.Bytes(c.N))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
C2Bytes := C2.Bytes()[1:]
|
|
||||||
c2 := sm3.Kdf(C2Bytes, len(msg))
|
|
||||||
if subtle.ConstantTimeAllZero(c2) == 1 {
|
|
||||||
retryCount++
|
|
||||||
if retryCount > maxRetryLimit {
|
|
||||||
return nil, fmt.Errorf("sm2: A5, failed to calculate valid t, tried %v times", retryCount)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
//A6, C2 = M + t;
|
|
||||||
subtle.XORBytes(c2, msg, c2)
|
|
||||||
|
|
||||||
//A7, C3 = hash(x2||M||y2)
|
|
||||||
md := sm3.New()
|
|
||||||
md.Write(C2Bytes[:len(C2Bytes)/2])
|
|
||||||
md.Write(msg)
|
|
||||||
md.Write(C2Bytes[len(C2Bytes)/2:])
|
|
||||||
c3 := md.Sum(nil)
|
|
||||||
|
|
||||||
if opts.ciphertextEncoding == ENCODING_PLAIN {
|
|
||||||
return encodingCiphertext(opts, C1, c2, c3)
|
|
||||||
}
|
|
||||||
return encodingCiphertextASN1(C1, c2, c3)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodingCiphertext(opts *EncrypterOpts, C1 *_sm2ec.SM2P256Point, c2, c3 []byte) ([]byte, error) {
|
|
||||||
var c1 []byte
|
|
||||||
switch opts.pointMarshalMode {
|
|
||||||
case MarshalCompressed:
|
|
||||||
c1 = C1.BytesCompressed()
|
|
||||||
default:
|
|
||||||
c1 = C1.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
if opts.ciphertextSplicingOrder == C1C3C2 {
|
|
||||||
// c1 || c3 || c2
|
|
||||||
return append(append(c1, c3...), c2...), nil
|
|
||||||
}
|
|
||||||
// c1 || c2 || c3
|
|
||||||
return append(append(c1, c2...), c3...), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func encodingCiphertextASN1(C1 *_sm2ec.SM2P256Point, c2, c3 []byte) ([]byte, error) {
|
|
||||||
c1 := C1.Bytes()
|
|
||||||
var b cryptobyte.Builder
|
|
||||||
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
|
||||||
addASN1IntBytes(b, c1[1:len(c1)/2+1])
|
|
||||||
addASN1IntBytes(b, c1[len(c1)/2+1:])
|
|
||||||
b.AddASN1OctetString(c3)
|
|
||||||
b.AddASN1OctetString(c2)
|
|
||||||
})
|
|
||||||
return b.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GenerateKey generates a new SM2 private key.
|
// GenerateKey generates a new SM2 private key.
|
||||||
//
|
//
|
||||||
// Most applications should use [crypto/rand.Reader] as rand. Note that the
|
// Most applications should use [crypto/rand.Reader] as rand. Note that the
|
||||||
@ -358,23 +163,27 @@ func NewPrivateKey(key []byte) (*PrivateKey, error) {
|
|||||||
return priv, nil
|
return priv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPrivateKeyFromInt checks that key is valid and returns a SM2 PrivateKey.
|
// NewPrivateKeyFromInt creates a new SM2 private key from a given big integer.
|
||||||
|
// It returns an error if the provided key is nil.
|
||||||
func NewPrivateKeyFromInt(key *big.Int) (*PrivateKey, error) {
|
func NewPrivateKeyFromInt(key *big.Int) (*PrivateKey, error) {
|
||||||
if key == nil {
|
if key == nil {
|
||||||
return nil, errors.New("sm2: invalid private key size")
|
return nil, errors.New("sm2: private key is nil")
|
||||||
}
|
}
|
||||||
keyBytes := make([]byte, p256().N.Size())
|
keyBytes := make([]byte, p256().N.Size())
|
||||||
return NewPrivateKey(key.FillBytes(keyBytes))
|
return NewPrivateKey(key.FillBytes(keyBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPublicKey checks that key is valid and returns a PublicKey.
|
// NewPublicKey checks that the provided key is valid and returns an SM2 PublicKey.
|
||||||
//
|
//
|
||||||
// According GB/T 32918.1-2016, the private key must be in [1, n-2].
|
// The key parameter is a byte slice representing the public key in uncompressed format.
|
||||||
|
// According to GB/T 32918.1-2016, the public key must be in the correct format and on the curve.
|
||||||
func NewPublicKey(key []byte) (*ecdsa.PublicKey, error) {
|
func NewPublicKey(key []byte) (*ecdsa.PublicKey, error) {
|
||||||
c := p256()
|
c := p256()
|
||||||
// Reject the point at infinity and compressed encodings.
|
// Reject the point at infinity and compressed encodings.
|
||||||
|
// Points at infinity are invalid because they do not represent a valid point on the curve.
|
||||||
|
// Compressed encodings are not supported by this implementation, so they are also rejected.
|
||||||
if len(key) == 0 || key[0] != 4 {
|
if len(key) == 0 || key[0] != 4 {
|
||||||
return nil, errors.New("sm2: invalid public key")
|
return nil, errInvalidPublicKey
|
||||||
}
|
}
|
||||||
// SetBytes also checks that the point is on the curve.
|
// SetBytes also checks that the point is on the curve.
|
||||||
p, err := c.newPoint().SetBytes(key)
|
p, err := c.newPoint().SetBytes(key)
|
||||||
@ -390,138 +199,6 @@ func NewPublicKey(key []byte) (*ecdsa.PublicKey, error) {
|
|||||||
return k, nil
|
return k, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrypt sm2 decrypt implementation by default DecrypterOpts{C1C3C2}.
|
|
||||||
// Compliance with GB/T 32918.4-2016.
|
|
||||||
func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
|
||||||
return decrypt(priv, ciphertext, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrDecryption represents a failure to decrypt a message.
|
|
||||||
// It is deliberately vague to avoid adaptive attacks.
|
|
||||||
var ErrDecryption = errors.New("sm2: decryption error")
|
|
||||||
|
|
||||||
func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
|
|
||||||
ciphertextLen := len(ciphertext)
|
|
||||||
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
|
|
||||||
return nil, errCiphertextTooShort
|
|
||||||
}
|
|
||||||
switch priv.Curve.Params() {
|
|
||||||
case P256().Params():
|
|
||||||
return decryptSM2EC(p256(), priv, ciphertext, opts)
|
|
||||||
default:
|
|
||||||
return decryptLegacy(priv, ciphertext, opts)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func decryptSM2EC(c *sm2Curve, priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
|
|
||||||
C1, c2, c3, err := parseCiphertext(c, ciphertext, opts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, ErrDecryption
|
|
||||||
}
|
|
||||||
d, err := bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N)
|
|
||||||
if err != nil {
|
|
||||||
return nil, ErrDecryption
|
|
||||||
}
|
|
||||||
|
|
||||||
C2, err := C1.ScalarMult(C1, d.Bytes(c.N))
|
|
||||||
if err != nil {
|
|
||||||
return nil, ErrDecryption
|
|
||||||
}
|
|
||||||
C2Bytes := C2.Bytes()[1:]
|
|
||||||
msgLen := len(c2)
|
|
||||||
msg := sm3.Kdf(C2Bytes, msgLen)
|
|
||||||
if subtle.ConstantTimeAllZero(c2) == 1 {
|
|
||||||
return nil, ErrDecryption
|
|
||||||
}
|
|
||||||
|
|
||||||
//B5, calculate msg = c2 ^ t
|
|
||||||
subtle.XORBytes(msg, c2, msg)
|
|
||||||
|
|
||||||
md := sm3.New()
|
|
||||||
md.Write(C2Bytes[:len(C2Bytes)/2])
|
|
||||||
md.Write(msg)
|
|
||||||
md.Write(C2Bytes[len(C2Bytes)/2:])
|
|
||||||
u := md.Sum(nil)
|
|
||||||
|
|
||||||
if _subtle.ConstantTimeCompare(u, c3) == 1 {
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
return nil, ErrDecryption
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseCiphertext(c *sm2Curve, ciphertext []byte, opts *DecrypterOpts) (*_sm2ec.SM2P256Point, []byte, []byte, error) {
|
|
||||||
bitSize := c.curve.Params().BitSize
|
|
||||||
// Encode the coordinates and let SetBytes reject invalid points.
|
|
||||||
byteLen := (bitSize + 7) / 8
|
|
||||||
splicingOrder := C1C3C2
|
|
||||||
if opts != nil {
|
|
||||||
splicingOrder = opts.cipherTextSplicingOrder
|
|
||||||
}
|
|
||||||
|
|
||||||
b := ciphertext[0]
|
|
||||||
switch b {
|
|
||||||
case uncompressed:
|
|
||||||
if len(ciphertext) <= 1+2*byteLen+sm3.Size {
|
|
||||||
return nil, nil, nil, errCiphertextTooShort
|
|
||||||
}
|
|
||||||
C1, err := c.newPoint().SetBytes(ciphertext[:1+2*byteLen])
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
c2, c3 := parseCiphertextC2C3(ciphertext[1+2*byteLen:], splicingOrder)
|
|
||||||
return C1, c2, c3, nil
|
|
||||||
case compressed02, compressed03:
|
|
||||||
C1, err := c.newPoint().SetBytes(ciphertext[:1+byteLen])
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
c2, c3 := parseCiphertextC2C3(ciphertext[1+byteLen:], splicingOrder)
|
|
||||||
return C1, c2, c3, nil
|
|
||||||
case byte(0x30):
|
|
||||||
return parseCiphertextASN1(c, ciphertext)
|
|
||||||
default:
|
|
||||||
return nil, nil, nil, errors.New("sm2: invalid/unsupport ciphertext format")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseCiphertextC2C3(ciphertext []byte, order ciphertextSplicingOrder) ([]byte, []byte) {
|
|
||||||
if order == C1C3C2 {
|
|
||||||
return ciphertext[sm3.Size:], ciphertext[:sm3.Size]
|
|
||||||
}
|
|
||||||
return ciphertext[:len(ciphertext)-sm3.Size], ciphertext[len(ciphertext)-sm3.Size:]
|
|
||||||
}
|
|
||||||
|
|
||||||
func unmarshalASN1Ciphertext(ciphertext []byte) (*big.Int, *big.Int, []byte, []byte, error) {
|
|
||||||
var (
|
|
||||||
x1, y1 = &big.Int{}, &big.Int{}
|
|
||||||
c2, c3 []byte
|
|
||||||
inner cryptobyte.String
|
|
||||||
)
|
|
||||||
input := cryptobyte.String(ciphertext)
|
|
||||||
if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
|
|
||||||
!input.Empty() ||
|
|
||||||
!inner.ReadASN1Integer(x1) ||
|
|
||||||
!inner.ReadASN1Integer(y1) ||
|
|
||||||
!inner.ReadASN1Bytes(&c3, asn1.OCTET_STRING) ||
|
|
||||||
!inner.ReadASN1Bytes(&c2, asn1.OCTET_STRING) ||
|
|
||||||
!inner.Empty() {
|
|
||||||
return nil, nil, nil, nil, errors.New("sm2: invalid asn1 format ciphertext")
|
|
||||||
}
|
|
||||||
return x1, y1, c2, c3, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseCiphertextASN1(c *sm2Curve, ciphertext []byte) (*_sm2ec.SM2P256Point, []byte, []byte, error) {
|
|
||||||
x1, y1, c2, c3, err := unmarshalASN1Ciphertext(ciphertext)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
C1, err := c.pointFromAffine(x1, y1)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, nil, err
|
|
||||||
}
|
|
||||||
return C1, c2, c3, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38}
|
var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38}
|
||||||
|
|
||||||
// CalculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA).
|
// CalculateZA ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA).
|
||||||
@ -530,29 +207,53 @@ var defaultUID = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x
|
|||||||
// This function will not use default UID even the uid argument is empty.
|
// This function will not use default UID even the uid argument is empty.
|
||||||
func CalculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) {
|
func CalculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) {
|
||||||
uidLen := len(uid)
|
uidLen := len(uid)
|
||||||
if uidLen >= 0x2000 {
|
if uidLen > 0x1fff {
|
||||||
return nil, errors.New("sm2: the uid is too long")
|
return nil, errors.New("sm2: the uid is too long")
|
||||||
}
|
}
|
||||||
entla := uint16(uidLen) << 3
|
uidBitLength := uint16(uidLen) << 3
|
||||||
md := sm3.New()
|
md := sm3.New()
|
||||||
md.Write([]byte{byte(entla >> 8), byte(entla)})
|
md.Write([]byte{byte(uidBitLength >> 8), byte(uidBitLength)})
|
||||||
if uidLen > 0 {
|
if uidLen > 0 {
|
||||||
md.Write(uid)
|
md.Write(uid)
|
||||||
}
|
}
|
||||||
a := new(big.Int).Sub(pub.Params().P, big.NewInt(3))
|
writeCurveParams(md, pub.Curve)
|
||||||
md.Write(toBytes(pub.Curve, a))
|
md.Write(bigIntToBytes(pub.Curve, pub.X))
|
||||||
md.Write(toBytes(pub.Curve, pub.Params().B))
|
md.Write(bigIntToBytes(pub.Curve, pub.Y))
|
||||||
md.Write(toBytes(pub.Curve, pub.Params().Gx))
|
// Return the calculated ZA value
|
||||||
md.Write(toBytes(pub.Curve, pub.Params().Gy))
|
|
||||||
md.Write(toBytes(pub.Curve, pub.X))
|
|
||||||
md.Write(toBytes(pub.Curve, pub.Y))
|
|
||||||
return md.Sum(nil), nil
|
return md.Sum(nil), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CalculateSM2Hash calculates hash value for data including uid and public key parameters
|
// writeCurveParams writes the parameters of the given elliptic curve to the provided hash.Hash.
|
||||||
// according standards.
|
// It writes the following parameters in order:
|
||||||
|
// - a: P - 3 (where P is the prime specifying the base field of the curve)
|
||||||
|
// - B: the coefficient B of the curve equation
|
||||||
|
// - Gx: the x-coordinate of the base point G
|
||||||
|
// - Gy: the y-coordinate of the base point G
|
||||||
//
|
//
|
||||||
// uid can be nil, then it will use default uid (1234567812345678)
|
// Parameters:
|
||||||
|
// - md: the hash.Hash to write the curve parameters to
|
||||||
|
// - curve: the elliptic.Curve whose parameters are to be written
|
||||||
|
func writeCurveParams(md hash.Hash, curve elliptic.Curve) {
|
||||||
|
a := new(big.Int).Sub(curve.Params().P, big.NewInt(3))
|
||||||
|
md.Write(bigIntToBytes(curve, a))
|
||||||
|
md.Write(bigIntToBytes(curve, curve.Params().B))
|
||||||
|
md.Write(bigIntToBytes(curve, curve.Params().Gx))
|
||||||
|
md.Write(bigIntToBytes(curve, curve.Params().Gy))
|
||||||
|
}
|
||||||
|
|
||||||
|
// bigIntToBytes converts a big integer value to a byte slice of the appropriate length for the given elliptic curve.
|
||||||
|
// The byte slice is zero-padded to the left if necessary to match the curve's byte length.
|
||||||
|
func bigIntToBytes(curve elliptic.Curve, value *big.Int) []byte {
|
||||||
|
byteLen := (curve.Params().BitSize + 7) >> 3
|
||||||
|
byteArray := make([]byte, byteLen)
|
||||||
|
value.FillBytes(byteArray)
|
||||||
|
return byteArray
|
||||||
|
}
|
||||||
|
|
||||||
|
// CalculateSM2Hash calculates the SM2 hash for the given public key, data, and user ID (UID).
|
||||||
|
// If the UID is not provided, a default UID (1234567812345678) is used.
|
||||||
|
// The public key must be valid, otherwise will be panic.
|
||||||
|
// This function is used to calculate the hash value for SM2 signature.
|
||||||
func CalculateSM2Hash(pub *ecdsa.PublicKey, data, uid []byte) ([]byte, error) {
|
func CalculateSM2Hash(pub *ecdsa.PublicKey, data, uid []byte) ([]byte, error) {
|
||||||
if len(uid) == 0 {
|
if len(uid) == 0 {
|
||||||
uid = defaultUID
|
uid = defaultUID
|
||||||
@ -597,21 +298,24 @@ func SignASN1(rand io.Reader, priv *PrivateKey, hash []byte, opts crypto.SignerO
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// inverseOfPrivateKeyPlus1 calculates and returns the modular inverse of (private key + 1) modulo the curve order.
|
||||||
|
// It uses lazy initialization and caching to ensure the calculation is performed only once.
|
||||||
|
// If the private key is invalid, it returns an error.
|
||||||
func (priv *PrivateKey) inverseOfPrivateKeyPlus1(c *sm2Curve) (*bigmod.Nat, error) {
|
func (priv *PrivateKey) inverseOfPrivateKeyPlus1(c *sm2Curve) (*bigmod.Nat, error) {
|
||||||
var (
|
var (
|
||||||
err error
|
err error
|
||||||
dp1Inv, oneNat *bigmod.Nat
|
oneNat = bigmod.NewNat().SetUint(1, c.N)
|
||||||
dp1Bytes []byte
|
inverseDPlus1 *bigmod.Nat
|
||||||
|
dp1Bytes []byte
|
||||||
)
|
)
|
||||||
priv.inverseOfKeyPlus1Once.Do(func() {
|
priv.inverseOfKeyPlus1Once.Do(func() {
|
||||||
oneNat = bigmod.NewNat().SetUint(1, c.N)
|
inverseDPlus1, err = bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N)
|
||||||
dp1Inv, err = bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N)
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dp1Inv.Add(oneNat, c.N)
|
inverseDPlus1.Add(oneNat, c.N)
|
||||||
if dp1Inv.IsZero() == 1 { // make sure private key is NOT N-1
|
if inverseDPlus1.IsZero() == 1 { // make sure private key is NOT N-1
|
||||||
err = errInvalidPrivateKey
|
err = errInvalidPrivateKey
|
||||||
} else {
|
} else {
|
||||||
dp1Bytes, err = _sm2ec.P256OrdInverse(dp1Inv.Bytes(c.N))
|
dp1Bytes, err = _sm2ec.P256OrdInverse(inverseDPlus1.Bytes(c.N))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
priv.inverseOfKeyPlus1, err = bigmod.NewNat().SetBytes(dp1Bytes, c.N)
|
priv.inverseOfKeyPlus1, err = bigmod.NewNat().SetBytes(dp1Bytes, c.N)
|
||||||
}
|
}
|
||||||
@ -624,9 +328,27 @@ func (priv *PrivateKey) inverseOfPrivateKeyPlus1(c *sm2Curve) (*bigmod.Nat, erro
|
|||||||
return priv.inverseOfKeyPlus1, nil
|
return priv.inverseOfKeyPlus1, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// signSM2EC generates an SM2 digital signature using the provided private key and hash.
|
||||||
|
// It follows the SM2 signature algorithm as specified in the Chinese cryptographic standards.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - c: A pointer to the sm2Curve structure representing the elliptic curve parameters.
|
||||||
|
// - priv: A pointer to the PrivateKey structure containing the private key for signing.
|
||||||
|
// - rand: An io.Reader instance used to generate random values.
|
||||||
|
// - hash: A byte slice containing the hash of the message to be signed.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - sig: A byte slice containing the generated signature.
|
||||||
|
// - err: An error value indicating any issues encountered during the signing process.
|
||||||
|
//
|
||||||
|
// The function performs the following steps:
|
||||||
|
// 1. Computes the inverse of (d + 1) where d is the private key.
|
||||||
|
// 2. Converts the hash to an integer.
|
||||||
|
// 3. Generates a random point on the elliptic curve and computes the signature components (r, s).
|
||||||
|
// 4. Ensures that the signature components are non-zero and valid.
|
||||||
|
// 5. Encodes the signature components into a byte slice and returns it.
|
||||||
func signSM2EC(c *sm2Curve, priv *PrivateKey, rand io.Reader, hash []byte) (sig []byte, err error) {
|
func signSM2EC(c *sm2Curve, priv *PrivateKey, rand io.Reader, hash []byte) (sig []byte, err error) {
|
||||||
// dp1Inv = (d+1)⁻¹
|
inverseDPlus1, err := priv.inverseOfPrivateKeyPlus1(c)
|
||||||
dp1Inv, err := priv.inverseOfPrivateKeyPlus1(c)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -675,7 +397,7 @@ func signSM2EC(c *sm2Curve, priv *PrivateKey, rand io.Reader, hash []byte) (sig
|
|||||||
// k = [k - s]
|
// k = [k - s]
|
||||||
k.Sub(s, c.N)
|
k.Sub(s, c.N)
|
||||||
// k = [(d+1)⁻¹ * (k - r * d)]
|
// k = [(d+1)⁻¹ * (k - r * d)]
|
||||||
k.Mul(dp1Inv, c.N)
|
k.Mul(inverseDPlus1, c.N)
|
||||||
if k.IsZero() == 0 {
|
if k.IsZero() == 0 {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -713,91 +435,6 @@ func addASN1IntBytes(b *cryptobyte.Builder, bytes []byte) {
|
|||||||
|
|
||||||
var ErrInvalidSignature = errors.New("sm2: invalid signature")
|
var ErrInvalidSignature = errors.New("sm2: invalid signature")
|
||||||
|
|
||||||
// RecoverPublicKeysFromSM2Signature recovers two or four SM2 public keys from a given signature and hash.
|
|
||||||
// It takes the hash and signature as input and returns the recovered public keys as []*ecdsa.PublicKey.
|
|
||||||
// If the signature or hash is invalid, it returns an error.
|
|
||||||
// The function follows the SM2 algorithm to recover the public keys.
|
|
||||||
func RecoverPublicKeysFromSM2Signature(hash, sig []byte) ([]*ecdsa.PublicKey, error) {
|
|
||||||
c := p256()
|
|
||||||
rBytes, sBytes, err := parseSignature(sig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
r, err := bigmod.NewNat().SetBytes(rBytes, c.N)
|
|
||||||
if err != nil || r.IsZero() == 1 {
|
|
||||||
return nil, ErrInvalidSignature
|
|
||||||
}
|
|
||||||
s, err := bigmod.NewNat().SetBytes(sBytes, c.N)
|
|
||||||
if err != nil || s.IsZero() == 1 {
|
|
||||||
return nil, ErrInvalidSignature
|
|
||||||
}
|
|
||||||
|
|
||||||
e := bigmod.NewNat()
|
|
||||||
hashToNat(c, e, hash)
|
|
||||||
|
|
||||||
// p₁ = [-s]G
|
|
||||||
negS := bigmod.NewNat().ExpandFor(c.N).Sub(s, c.N)
|
|
||||||
p1, err := c.newPoint().ScalarBaseMult(negS.Bytes(c.N))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// s = [r + s]
|
|
||||||
s.Add(r, c.N)
|
|
||||||
if s.IsZero() == 1 {
|
|
||||||
return nil, ErrInvalidSignature
|
|
||||||
}
|
|
||||||
// sBytes = (r+s)⁻¹
|
|
||||||
sBytes, err = _sm2ec.P256OrdInverse(s.Bytes(c.N))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// r = (Rx + e) mod N
|
|
||||||
// Rx = r - e
|
|
||||||
r.Sub(e, c.N)
|
|
||||||
if r.IsZero() == 1 {
|
|
||||||
return nil, ErrInvalidSignature
|
|
||||||
}
|
|
||||||
pointRx := make([]*bigmod.Nat, 0, 2)
|
|
||||||
pointRx = append(pointRx, r)
|
|
||||||
// check if Rx in (N, P), small probability event
|
|
||||||
s.Set(r)
|
|
||||||
s = s.Add(c.N.Nat(), c.P)
|
|
||||||
if s.CmpGeq(c.N.Nat()) == 1 {
|
|
||||||
pointRx = append(pointRx, s)
|
|
||||||
}
|
|
||||||
pubs := make([]*ecdsa.PublicKey, 0, 4)
|
|
||||||
bytes := make([]byte, 32+1)
|
|
||||||
compressFlags := []byte{compressed02, compressed03}
|
|
||||||
// Rx has one or two possible values, so point R has two or four possible values
|
|
||||||
for _, x := range pointRx {
|
|
||||||
rBytes = x.Bytes(c.N)
|
|
||||||
copy(bytes[1:], rBytes)
|
|
||||||
for _, flag := range compressFlags {
|
|
||||||
bytes[0] = flag
|
|
||||||
// p0 = R
|
|
||||||
p0, err := c.newPoint().SetBytes(bytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// p0 = R - [s]G
|
|
||||||
p0.Add(p0, p1)
|
|
||||||
// Pub = [(r + s)⁻¹](R - [s]G)
|
|
||||||
p0.ScalarMult(p0, sBytes)
|
|
||||||
pub := new(ecdsa.PublicKey)
|
|
||||||
pub.Curve = c.curve
|
|
||||||
pub.X, pub.Y, err = c.pointToAffine(p0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
pubs = append(pubs, pub)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return pubs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// VerifyASN1 verifies the ASN.1 encoded signature, sig, of hash using the
|
// VerifyASN1 verifies the ASN.1 encoded signature, sig, of hash using the
|
||||||
// public key, pub. Its return value records whether the signature is valid.
|
// public key, pub. Its return value records whether the signature is valid.
|
||||||
//
|
//
|
||||||
@ -919,7 +556,9 @@ func hashToNat(c *sm2Curve, e *bigmod.Nat, hash []byte) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsSM2PublicKey check if given public key is a SM2 public key or not
|
// IsSM2PublicKey checks if the provided public key is an SM2 public key.
|
||||||
|
// It takes an interface{} as input and attempts to assert it to an *ecdsa.PublicKey.
|
||||||
|
// The function returns true if the assertion is successful and the public key's curve is SM2 P-256.
|
||||||
func IsSM2PublicKey(publicKey any) bool {
|
func IsSM2PublicKey(publicKey any) bool {
|
||||||
pub, ok := publicKey.(*ecdsa.PublicKey)
|
pub, ok := publicKey.(*ecdsa.PublicKey)
|
||||||
return ok && pub.Curve == sm2ec.P256()
|
return ok && pub.Curve == sm2ec.P256()
|
||||||
@ -939,7 +578,7 @@ func PublicKeyToECDH(k *ecdsa.PublicKey) (*ecdh.PublicKey, error) {
|
|||||||
return nil, errors.New("sm2: unsupported curve by ecdh")
|
return nil, errors.New("sm2: unsupported curve by ecdh")
|
||||||
}
|
}
|
||||||
if !k.Curve.IsOnCurve(k.X, k.Y) {
|
if !k.Curve.IsOnCurve(k.X, k.Y) {
|
||||||
return nil, errors.New("sm2: invalid public key")
|
return nil, errInvalidPublicKey
|
||||||
}
|
}
|
||||||
return c.NewPublicKey(elliptic.Marshal(k.Curve, k.X, k.Y))
|
return c.NewPublicKey(elliptic.Marshal(k.Curve, k.X, k.Y))
|
||||||
}
|
}
|
||||||
@ -954,7 +593,7 @@ func (k *PrivateKey) ECDH() (*ecdh.PrivateKey, error) {
|
|||||||
}
|
}
|
||||||
size := (k.Curve.Params().N.BitLen() + 7) / 8
|
size := (k.Curve.Params().N.BitLen() + 7) / 8
|
||||||
if k.D.BitLen() > size*8 {
|
if k.D.BitLen() > size*8 {
|
||||||
return nil, errors.New("sm2: invalid private key")
|
return nil, errInvalidPrivateKey
|
||||||
}
|
}
|
||||||
return c.NewPrivateKey(k.D.FillBytes(make([]byte, size)))
|
return c.NewPrivateKey(k.D.FillBytes(make([]byte, size)))
|
||||||
}
|
}
|
||||||
@ -1011,6 +650,110 @@ func randomPoint(c *sm2Curve, rand io.Reader, checkOrderMinus1 bool) (k *bigmod.
|
|||||||
// randomPoint rejects a candidate for being higher than the modulus.
|
// randomPoint rejects a candidate for being higher than the modulus.
|
||||||
var testingOnlyRejectionSamplingLooped func()
|
var testingOnlyRejectionSamplingLooped func()
|
||||||
|
|
||||||
|
|
||||||
|
// RecoverPublicKeysFromSM2Signature attempts to recover the public keys from an SM2 signature.
|
||||||
|
// This function takes a hash and a signature as input and returns a slice of possible public keys
|
||||||
|
// that could have generated the given signature.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - hash: The hash of the message that was signed.
|
||||||
|
// - sig: The SM2 signature.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - A slice of pointers to ecdsa.PublicKey, representing the possible public keys.
|
||||||
|
// - An error if the signature is invalid or if any other error occurs during the recovery process.
|
||||||
|
//
|
||||||
|
// The function performs the following steps:
|
||||||
|
// 1. Parses the signature to extract the r and s values.
|
||||||
|
// 2. Converts the hash to a big integer (Nat).
|
||||||
|
// 3. Computes the point p₁ = [-s]G.
|
||||||
|
// 4. Computes s = [r + s] and its modular inverse.
|
||||||
|
// 5. Computes the possible x-coordinates (Rx) for the point R.
|
||||||
|
// 6. For each possible Rx, computes the corresponding point R and derives the public key.
|
||||||
|
//
|
||||||
|
// Note: The function handles the case where there are one or two possible values for Rx,
|
||||||
|
// resulting in two or four possible public keys.
|
||||||
|
func RecoverPublicKeysFromSM2Signature(hash, sig []byte) ([]*ecdsa.PublicKey, error) {
|
||||||
|
c := p256()
|
||||||
|
rBytes, sBytes, err := parseSignature(sig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
r, err := bigmod.NewNat().SetBytes(rBytes, c.N)
|
||||||
|
if err != nil || r.IsZero() == 1 {
|
||||||
|
return nil, ErrInvalidSignature
|
||||||
|
}
|
||||||
|
s, err := bigmod.NewNat().SetBytes(sBytes, c.N)
|
||||||
|
if err != nil || s.IsZero() == 1 {
|
||||||
|
return nil, ErrInvalidSignature
|
||||||
|
}
|
||||||
|
|
||||||
|
e := bigmod.NewNat()
|
||||||
|
hashToNat(c, e, hash)
|
||||||
|
|
||||||
|
// p₁ = [-s]G
|
||||||
|
negS := bigmod.NewNat().ExpandFor(c.N).Sub(s, c.N)
|
||||||
|
p1, err := c.newPoint().ScalarBaseMult(negS.Bytes(c.N))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// s = [r + s]
|
||||||
|
s.Add(r, c.N)
|
||||||
|
if s.IsZero() == 1 {
|
||||||
|
return nil, ErrInvalidSignature
|
||||||
|
}
|
||||||
|
// sBytes = (r+s)⁻¹
|
||||||
|
sBytes, err = _sm2ec.P256OrdInverse(s.Bytes(c.N))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// r = (Rx + e) mod N
|
||||||
|
// Rx = r - e
|
||||||
|
r.Sub(e, c.N)
|
||||||
|
if r.IsZero() == 1 {
|
||||||
|
return nil, ErrInvalidSignature
|
||||||
|
}
|
||||||
|
pointRx := make([]*bigmod.Nat, 0, 2)
|
||||||
|
pointRx = append(pointRx, r)
|
||||||
|
// check if Rx in (N, P), small probability event
|
||||||
|
s.Set(r)
|
||||||
|
s = s.Add(c.N.Nat(), c.P)
|
||||||
|
if s.CmpGeq(c.N.Nat()) == 1 {
|
||||||
|
pointRx = append(pointRx, s)
|
||||||
|
}
|
||||||
|
pubs := make([]*ecdsa.PublicKey, 0, 4)
|
||||||
|
bytes := make([]byte, 32+1)
|
||||||
|
compressFlags := []byte{compressed02, compressed03}
|
||||||
|
// Rx has one or two possible values, so point R has two or four possible values
|
||||||
|
for _, x := range pointRx {
|
||||||
|
rBytes = x.Bytes(c.N)
|
||||||
|
copy(bytes[1:], rBytes)
|
||||||
|
for _, flag := range compressFlags {
|
||||||
|
bytes[0] = flag
|
||||||
|
// p0 = R
|
||||||
|
p0, err := c.newPoint().SetBytes(bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// p0 = R - [s]G
|
||||||
|
p0.Add(p0, p1)
|
||||||
|
// Pub = [(r + s)⁻¹](R - [s]G)
|
||||||
|
p0.ScalarMult(p0, sBytes)
|
||||||
|
pub := new(ecdsa.PublicKey)
|
||||||
|
pub.Curve = c.curve
|
||||||
|
pub.X, pub.Y, err = c.pointToAffine(p0)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
pubs = append(pubs, pub)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pubs, nil
|
||||||
|
}
|
||||||
|
|
||||||
type sm2Curve struct {
|
type sm2Curve struct {
|
||||||
newPoint func() *_sm2ec.SM2P256Point
|
newPoint func() *_sm2ec.SM2P256Point
|
||||||
curve elliptic.Curve
|
curve elliptic.Curve
|
||||||
@ -1073,5 +816,3 @@ func precomputeParams(c *sm2Curve, curve elliptic.Curve) {
|
|||||||
c.nMinus1 = c.N.Nat().SubOne(c.N)
|
c.nMinus1 = c.N.Nat().SubOne(c.N)
|
||||||
c.nMinus2 = new(bigmod.Nat).Set(c.nMinus1).SubOne(c.N).Bytes(c.N)
|
c.nMinus2 = new(bigmod.Nat).Set(c.nMinus1).SubOne(c.N).Bytes(c.N)
|
||||||
}
|
}
|
||||||
|
|
||||||
var errInvalidPrivateKey = errors.New("sm2: invalid private key")
|
|
@ -10,7 +10,6 @@ import (
|
|||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
"reflect"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/emmansun/gmsm/sm3"
|
"github.com/emmansun/gmsm/sm3"
|
||||||
@ -109,367 +108,6 @@ func TestNewPublicKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSplicingOrder(t *testing.T) {
|
|
||||||
priv, _ := GenerateKey(rand.Reader)
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
plainText string
|
|
||||||
from ciphertextSplicingOrder
|
|
||||||
to ciphertextSplicingOrder
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
{"less than 32 1", "encryption standard", C1C2C3, C1C3C2},
|
|
||||||
{"less than 32 2", "encryption standard", C1C3C2, C1C2C3},
|
|
||||||
{"equals 32 1", "encryption standard encryption ", C1C2C3, C1C3C2},
|
|
||||||
{"equals 32 2", "encryption standard encryption ", C1C3C2, C1C2C3},
|
|
||||||
{"long than 32 1", "encryption standard encryption standard", C1C2C3, C1C3C2},
|
|
||||||
{"long than 32 2", "encryption standard encryption standard", C1C3C2, C1C2C3},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), NewPlainEncrypterOpts(MarshalUncompressed, tt.from))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("encrypt failed %v", err)
|
|
||||||
}
|
|
||||||
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.from))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("decrypt failed %v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
|
||||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
|
||||||
}
|
|
||||||
|
|
||||||
//Adjust splicing order
|
|
||||||
ciphertext, err = AdjustCiphertextSplicingOrder(ciphertext, tt.from, tt.to)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("adjust splicing order failed %v", err)
|
|
||||||
}
|
|
||||||
plaintext, err = priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.to))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("decrypt failed after adjust splicing order %v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
|
||||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEncryptDecryptASN1(t *testing.T) {
|
|
||||||
priv, _ := GenerateKey(rand.Reader)
|
|
||||||
priv2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
key2 := new(PrivateKey)
|
|
||||||
key2.PrivateKey = *priv2
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
plainText string
|
|
||||||
priv *PrivateKey
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
{"less than 32", "encryption standard", priv},
|
|
||||||
{"equals 32", "encryption standard encryption ", priv},
|
|
||||||
{"long than 32", "encryption standard encryption standard", priv},
|
|
||||||
{"less than 32", "encryption standard", key2},
|
|
||||||
{"equals 32", "encryption standard encryption ", key2},
|
|
||||||
{"long than 32", "encryption standard encryption standard", key2},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
encrypterOpts := ASN1EncrypterOpts
|
|
||||||
ciphertext, err := Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("%v encrypt failed %v", tt.priv.Curve.Params().Name, err)
|
|
||||||
}
|
|
||||||
plaintext, err := tt.priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("%v decrypt 1 failed %v", tt.priv.Curve.Params().Name, err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
|
||||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
|
||||||
}
|
|
||||||
plaintext, err = tt.priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("%v decrypt 2 failed %v", tt.priv.Curve.Params().Name, err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
|
||||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPlainCiphertext2ASN1(t *testing.T) {
|
|
||||||
ciphertext, _ := hex.DecodeString("047928e22045eec8dc00e95639dd0c1c8dfb75cf8cedcf496731a6a6f423baa54c5014c60b73495886d8d7bc996a4a716cb58e6bfc8e03078b24e7b0f5cba0efd5b9272c27fc263bb59eaca6eabc97c0323bf1de953aeabaf59700b3bf49c9a1056decc08dd18544960541a2239afa7b1512df05")
|
|
||||||
_, err := PlainCiphertext2ASN1(append([]byte{0x30}, ciphertext...), C1C3C2)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected error")
|
|
||||||
}
|
|
||||||
_, err = PlainCiphertext2ASN1(ciphertext[:65], C1C3C2)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected error")
|
|
||||||
}
|
|
||||||
ciphertext[0] = 0x10
|
|
||||||
_, err = PlainCiphertext2ASN1(ciphertext, C1C3C2)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected error")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAdjustCiphertextSplicingOrder(t *testing.T) {
|
|
||||||
ciphertext, _ := hex.DecodeString("047928e22045eec8dc00e95639dd0c1c8dfb75cf8cedcf496731a6a6f423baa54c5014c60b73495886d8d7bc996a4a716cb58e6bfc8e03078b24e7b0f5cba0efd5b9272c27fc263bb59eaca6eabc97c0323bf1de953aeabaf59700b3bf49c9a1056decc08dd18544960541a2239afa7b1512df05")
|
|
||||||
res, err := AdjustCiphertextSplicingOrder(ciphertext, C1C3C2, C1C3C2)
|
|
||||||
if err != nil || &res[0] != &ciphertext[0] {
|
|
||||||
t.Fatalf("should be same one")
|
|
||||||
}
|
|
||||||
_, err = AdjustCiphertextSplicingOrder(ciphertext[:65], C1C3C2, C1C2C3)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected error")
|
|
||||||
}
|
|
||||||
ciphertext[0] = 0x10
|
|
||||||
_, err = AdjustCiphertextSplicingOrder(ciphertext, C1C3C2, C1C2C3)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("expected error")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCiphertext2ASN1(t *testing.T) {
|
|
||||||
priv, _ := GenerateKey(rand.Reader)
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
plainText string
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
{"less than 32", "encryption standard"},
|
|
||||||
{"equals 32", "encryption standard encryption "},
|
|
||||||
{"long than 32", "encryption standard encryption standard"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
ciphertext1, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("encrypt failed %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ciphertext, err := PlainCiphertext2ASN1(ciphertext1, C1C3C2)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("convert to ASN.1 failed %v", err)
|
|
||||||
}
|
|
||||||
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("decrypt failed %v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
|
||||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
|
||||||
}
|
|
||||||
|
|
||||||
ciphertext2, err := AdjustCiphertextSplicingOrder(ciphertext1, C1C3C2, C1C2C3)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("adjust order failed %v", err)
|
|
||||||
}
|
|
||||||
ciphertext, err = PlainCiphertext2ASN1(ciphertext2, C1C2C3)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("convert to ASN.1 failed %v", err)
|
|
||||||
}
|
|
||||||
plaintext, err = priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("decrypt failed %v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
|
||||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCiphertextASN12Plain(t *testing.T) {
|
|
||||||
priv, _ := GenerateKey(rand.Reader)
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
plainText string
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
{"less than 32", "encryption standard"},
|
|
||||||
{"equals 32", "encryption standard encryption "},
|
|
||||||
{"long than 32", "encryption standard encryption standard"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
ciphertext, err := EncryptASN1(rand.Reader, &priv.PublicKey, []byte(tt.plainText))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("encrypt failed %v", err)
|
|
||||||
}
|
|
||||||
ciphertext, err = ASN1Ciphertext2Plain(ciphertext, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("convert to plain failed %v", err)
|
|
||||||
}
|
|
||||||
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("decrypt failed %v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
|
||||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEncryptWithInfinitePublicKey(t *testing.T) {
|
|
||||||
pub := new(ecdsa.PublicKey)
|
|
||||||
pub.Curve = P256()
|
|
||||||
pub.X = big.NewInt(0)
|
|
||||||
pub.Y = big.NewInt(0)
|
|
||||||
|
|
||||||
_, err := Encrypt(rand.Reader, pub, []byte("sm2 encryption standard"), nil)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("should be failed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEncryptEmptyPlaintext(t *testing.T) {
|
|
||||||
priv, _ := GenerateKey(rand.Reader)
|
|
||||||
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, nil, nil)
|
|
||||||
if err != nil || ciphertext != nil {
|
|
||||||
t.Fatalf("nil plaintext should return nil")
|
|
||||||
}
|
|
||||||
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte{}, nil)
|
|
||||||
if err != nil || ciphertext != nil {
|
|
||||||
t.Fatalf("empty plaintext should return nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEncryptDecrypt(t *testing.T) {
|
|
||||||
priv, _ := GenerateKey(rand.Reader)
|
|
||||||
priv2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
|
||||||
key2 := new(PrivateKey)
|
|
||||||
key2.PrivateKey = *priv2
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
plainText string
|
|
||||||
priv *PrivateKey
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
{"less than 32", "encryption standard", priv},
|
|
||||||
{"equals 32", "encryption standard encryption ", priv},
|
|
||||||
{"long than 32", "encryption standard encryption standard", priv},
|
|
||||||
{"less than 32", "encryption standard", key2},
|
|
||||||
{"equals 32", "encryption standard encryption ", key2},
|
|
||||||
{"long than 32", "encryption standard encryption standard", key2},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
ciphertext, err := Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("encrypt failed %v", err)
|
|
||||||
}
|
|
||||||
plaintext, err := Decrypt(tt.priv, ciphertext)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("decrypt failed %v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
|
||||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
|
||||||
}
|
|
||||||
// compress mode
|
|
||||||
encrypterOpts := NewPlainEncrypterOpts(MarshalCompressed, C1C3C2)
|
|
||||||
ciphertext, err = Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("encrypt failed %v", err)
|
|
||||||
}
|
|
||||||
plaintext, err = Decrypt(tt.priv, ciphertext)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("decrypt failed %v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
|
||||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
|
||||||
}
|
|
||||||
|
|
||||||
// hybrid mode
|
|
||||||
encrypterOpts = NewPlainEncrypterOpts(MarshalHybrid, C1C3C2)
|
|
||||||
ciphertext, err = Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("encrypt failed %v", err)
|
|
||||||
}
|
|
||||||
plaintext, err = Decrypt(tt.priv, ciphertext)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("decrypt failed %v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
|
||||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
|
||||||
}
|
|
||||||
plaintext, err = Decrypt(tt.priv, ciphertext)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("decrypt failed %v", err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
|
||||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInvalidCiphertext(t *testing.T) {
|
|
||||||
priv, _ := GenerateKey(rand.Reader)
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
ciphertext []byte
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
{errCiphertextTooShort.Error(), make([]byte, 65)},
|
|
||||||
{ErrDecryption.Error(), append([]byte{0x04}, make([]byte, 96)...)},
|
|
||||||
{ErrDecryption.Error(), append([]byte{0x04}, make([]byte, 97)...)},
|
|
||||||
{ErrDecryption.Error(), append([]byte{0x02}, make([]byte, 65)...)},
|
|
||||||
{ErrDecryption.Error(), append([]byte{0x30}, make([]byte, 97)...)},
|
|
||||||
{ErrDecryption.Error(), make([]byte, 97)},
|
|
||||||
}
|
|
||||||
for i, tt := range tests {
|
|
||||||
_, err := Decrypt(priv, tt.ciphertext)
|
|
||||||
if err.Error() != tt.name {
|
|
||||||
t.Fatalf("case %v, expected %v, got %v\n", i, tt.name, err.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPrivateKeyPlus1WithOrderMinus1(t *testing.T) {
|
|
||||||
priv := new(PrivateKey)
|
|
||||||
priv.D = new(big.Int).Sub(P256().Params().N, big.NewInt(1))
|
|
||||||
priv.Curve = P256()
|
|
||||||
priv.PublicKey.X, priv.PublicKey.Y = P256().ScalarBaseMult(priv.D.Bytes())
|
|
||||||
|
|
||||||
_, err := priv.inverseOfPrivateKeyPlus1(p256())
|
|
||||||
if err == nil || err != errInvalidPrivateKey {
|
|
||||||
t.Errorf("expected invalid private key error")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSignVerify(t *testing.T) {
|
|
||||||
priv, _ := GenerateKey(rand.Reader)
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
plainText string
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
{"less than 32", "encryption standard"},
|
|
||||||
{"equals 32", "encryption standard encryption "},
|
|
||||||
{"long than 32", "encryption standard encryption standard"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
hashed := sm3.Sum([]byte(tt.plainText))
|
|
||||||
signature, err := priv.Sign(rand.Reader, hashed[:], nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("sign failed %v", err)
|
|
||||||
}
|
|
||||||
result := VerifyASN1(&priv.PublicKey, hashed[:], signature)
|
|
||||||
if !result {
|
|
||||||
t.Fatal("verify failed")
|
|
||||||
}
|
|
||||||
hashed[0] ^= 0xff
|
|
||||||
if VerifyASN1(&priv.PublicKey, hashed[:], signature) {
|
|
||||||
t.Errorf("VerifyASN1 always works!")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testRecoverPublicKeysFromSM2Signature(t *testing.T, priv *PrivateKey) {
|
func testRecoverPublicKeysFromSM2Signature(t *testing.T, priv *PrivateKey) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@ -774,6 +412,48 @@ func TestRandomPoint(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPrivateKeyPlus1WithOrderMinus1(t *testing.T) {
|
||||||
|
priv := new(PrivateKey)
|
||||||
|
priv.D = new(big.Int).Sub(P256().Params().N, big.NewInt(1))
|
||||||
|
priv.Curve = P256()
|
||||||
|
priv.PublicKey.X, priv.PublicKey.Y = P256().ScalarBaseMult(priv.D.Bytes())
|
||||||
|
|
||||||
|
_, err := priv.inverseOfPrivateKeyPlus1(p256())
|
||||||
|
if err == nil || err != errInvalidPrivateKey {
|
||||||
|
t.Errorf("expected invalid private key error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignVerify(t *testing.T) {
|
||||||
|
priv, _ := GenerateKey(rand.Reader)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
plainText string
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
{"less than 32", "encryption standard"},
|
||||||
|
{"equals 32", "encryption standard encryption "},
|
||||||
|
{"long than 32", "encryption standard encryption standard"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
hashed := sm3.Sum([]byte(tt.plainText))
|
||||||
|
signature, err := priv.Sign(rand.Reader, hashed[:], nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sign failed %v", err)
|
||||||
|
}
|
||||||
|
result := VerifyASN1(&priv.PublicKey, hashed[:], signature)
|
||||||
|
if !result {
|
||||||
|
t.Fatal("verify failed")
|
||||||
|
}
|
||||||
|
hashed[0] ^= 0xff
|
||||||
|
if VerifyASN1(&priv.PublicKey, hashed[:], signature) {
|
||||||
|
t.Errorf("VerifyASN1 always works!")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkGenerateKey_SM2(b *testing.B) {
|
func BenchmarkGenerateKey_SM2(b *testing.B) {
|
||||||
r := bufio.NewReaderSize(rand.Reader, 1<<15)
|
r := bufio.NewReaderSize(rand.Reader, 1<<15)
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
@ -894,45 +574,3 @@ func BenchmarkVerify_SM2(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func benchmarkEncrypt(b *testing.B, curve elliptic.Curve, plaintext []byte) {
|
|
||||||
r := bufio.NewReaderSize(rand.Reader, 1<<15)
|
|
||||||
priv, err := ecdsa.GenerateKey(curve, r)
|
|
||||||
if err != nil {
|
|
||||||
b.Fatal(err)
|
|
||||||
}
|
|
||||||
b.SetBytes(int64(len(plaintext)))
|
|
||||||
b.ReportAllocs()
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext), nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkEncryptNoMoreThan32_P256(b *testing.B) {
|
|
||||||
benchmarkEncrypt(b, elliptic.P256(), make([]byte, 31))
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkEncryptNoMoreThan32_SM2(b *testing.B) {
|
|
||||||
benchmarkEncrypt(b, P256(), make([]byte, 31))
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkEncrypt128_P256(b *testing.B) {
|
|
||||||
benchmarkEncrypt(b, elliptic.P256(), make([]byte, 128))
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkEncrypt128_SM2(b *testing.B) {
|
|
||||||
benchmarkEncrypt(b, P256(), make([]byte, 128))
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkEncrypt512_SM2(b *testing.B) {
|
|
||||||
benchmarkEncrypt(b, P256(), make([]byte, 512))
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkEncrypt1K_SM2(b *testing.B) {
|
|
||||||
benchmarkEncrypt(b, P256(), make([]byte, 1024))
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkEncrypt8K_SM2(b *testing.B) {
|
|
||||||
benchmarkEncrypt(b, P256(), make([]byte, 8*1024))
|
|
||||||
}
|
|
@ -79,8 +79,26 @@ func MarshalEnvelopedPrivateKey(rand io.Reader, pub *ecdsa.PublicKey, tobeEnvelo
|
|||||||
return b.Bytes()
|
return b.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseEnvelopedPrivateKey, parses and decrypts the enveloped SM2 private key.
|
// ParseEnvelopedPrivateKey parses an enveloped private key using the provided private key.
|
||||||
// This methed just supports SM4 cipher now.
|
// The enveloped key is expected to be in ASN.1 format and encrypted with a symmetric cipher.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - priv: The private key used to decrypt the symmetric key.
|
||||||
|
// - enveloped: The ASN.1 encoded and encrypted enveloped private key.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - A pointer to the decrypted PrivateKey.
|
||||||
|
// - An error if the parsing or decryption fails.
|
||||||
|
//
|
||||||
|
// The function performs the following steps:
|
||||||
|
// 1. Unmarshals the ASN.1 data to extract the symmetric algorithm identifier, encrypted symmetric key, public key, and encrypted private key.
|
||||||
|
// 2. Verifies that the symmetric algorithm is supported (SM4 or SM4ECB).
|
||||||
|
// 3. Parses the public key from the ASN.1 data.
|
||||||
|
// 4. Decrypts the symmetric key using the provided private key.
|
||||||
|
// 5. Decrypts the SM2 private key using the decrypted symmetric key.
|
||||||
|
// 6. Verifies that the decrypted private key matches the public key.
|
||||||
|
//
|
||||||
|
// Errors are returned if any of the steps fail, including invalid ASN.1 format, unsupported symmetric cipher, decryption failures, or key mismatches.
|
||||||
func ParseEnvelopedPrivateKey(priv *PrivateKey, enveloped []byte) (*PrivateKey, error) {
|
func ParseEnvelopedPrivateKey(priv *PrivateKey, enveloped []byte) (*PrivateKey, error) {
|
||||||
// unmarshal the asn.1 data
|
// unmarshal the asn.1 data
|
||||||
var (
|
var (
|
||||||
|
@ -149,34 +149,34 @@ func (ke *KeyExchange) InitKeyExchange(rand io.Reader) (*ecdsa.PublicKey, error)
|
|||||||
func (ke *KeyExchange) sign(isResponder bool, prefix byte) []byte {
|
func (ke *KeyExchange) sign(isResponder bool, prefix byte) []byte {
|
||||||
var buffer []byte
|
var buffer []byte
|
||||||
hash := sm3.New()
|
hash := sm3.New()
|
||||||
hash.Write(toBytes(ke.privateKey, ke.v.X))
|
hash.Write(bigIntToBytes(ke.privateKey, ke.v.X))
|
||||||
if isResponder {
|
if isResponder {
|
||||||
hash.Write(ke.peerZ)
|
hash.Write(ke.peerZ)
|
||||||
hash.Write(ke.z)
|
hash.Write(ke.z)
|
||||||
hash.Write(toBytes(ke.privateKey, ke.peerSecret.X))
|
hash.Write(bigIntToBytes(ke.privateKey, ke.peerSecret.X))
|
||||||
hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y))
|
hash.Write(bigIntToBytes(ke.privateKey, ke.peerSecret.Y))
|
||||||
hash.Write(toBytes(ke.privateKey, ke.secret.X))
|
hash.Write(bigIntToBytes(ke.privateKey, ke.secret.X))
|
||||||
hash.Write(toBytes(ke.privateKey, ke.secret.Y))
|
hash.Write(bigIntToBytes(ke.privateKey, ke.secret.Y))
|
||||||
} else {
|
} else {
|
||||||
hash.Write(ke.z)
|
hash.Write(ke.z)
|
||||||
hash.Write(ke.peerZ)
|
hash.Write(ke.peerZ)
|
||||||
hash.Write(toBytes(ke.privateKey, ke.secret.X))
|
hash.Write(bigIntToBytes(ke.privateKey, ke.secret.X))
|
||||||
hash.Write(toBytes(ke.privateKey, ke.secret.Y))
|
hash.Write(bigIntToBytes(ke.privateKey, ke.secret.Y))
|
||||||
hash.Write(toBytes(ke.privateKey, ke.peerSecret.X))
|
hash.Write(bigIntToBytes(ke.privateKey, ke.peerSecret.X))
|
||||||
hash.Write(toBytes(ke.privateKey, ke.peerSecret.Y))
|
hash.Write(bigIntToBytes(ke.privateKey, ke.peerSecret.Y))
|
||||||
}
|
}
|
||||||
buffer = hash.Sum(nil)
|
buffer = hash.Sum(nil)
|
||||||
hash.Reset()
|
hash.Reset()
|
||||||
hash.Write([]byte{prefix})
|
hash.Write([]byte{prefix})
|
||||||
hash.Write(toBytes(ke.privateKey, ke.v.Y))
|
hash.Write(bigIntToBytes(ke.privateKey, ke.v.Y))
|
||||||
hash.Write(buffer)
|
hash.Write(buffer)
|
||||||
return hash.Sum(nil)
|
return hash.Sum(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
|
func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
|
||||||
var buffer []byte
|
var buffer []byte
|
||||||
buffer = append(buffer, toBytes(ke.privateKey, ke.v.X)...)
|
buffer = append(buffer, bigIntToBytes(ke.privateKey, ke.v.X)...)
|
||||||
buffer = append(buffer, toBytes(ke.privateKey, ke.v.Y)...)
|
buffer = append(buffer, bigIntToBytes(ke.privateKey, ke.v.Y)...)
|
||||||
if isResponder {
|
if isResponder {
|
||||||
buffer = append(buffer, ke.peerZ...)
|
buffer = append(buffer, ke.peerZ...)
|
||||||
buffer = append(buffer, ke.z...)
|
buffer = append(buffer, ke.z...)
|
||||||
|
@ -301,12 +301,9 @@ func calculateSampleZA(pub *ecdsa.PublicKey, a *big.Int, uid []byte) ([]byte, er
|
|||||||
if uidLen > 0 {
|
if uidLen > 0 {
|
||||||
md.Write(uid)
|
md.Write(uid)
|
||||||
}
|
}
|
||||||
md.Write(toBytes(pub.Curve, a))
|
writeCurveParams(md, pub.Curve.Params())
|
||||||
md.Write(toBytes(pub.Curve, pub.Params().B))
|
md.Write(bigIntToBytes(pub.Curve, pub.X))
|
||||||
md.Write(toBytes(pub.Curve, pub.Params().Gx))
|
md.Write(bigIntToBytes(pub.Curve, pub.Y))
|
||||||
md.Write(toBytes(pub.Curve, pub.Params().Gy))
|
|
||||||
md.Write(toBytes(pub.Curve, pub.X))
|
|
||||||
md.Write(toBytes(pub.Curve, pub.Y))
|
|
||||||
return md.Sum(nil), nil
|
return md.Sum(nil), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -259,7 +259,7 @@ func encryptLegacy(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Enc
|
|||||||
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
|
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
|
||||||
|
|
||||||
//A5, calculate t=KDF(x2||y2, klen)
|
//A5, calculate t=KDF(x2||y2, klen)
|
||||||
c2 := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
c2 := sm3.Kdf(append(bigIntToBytes(curve, x2), bigIntToBytes(curve, y2)...), msgLen)
|
||||||
if subtle.ConstantTimeAllZero(c2) == 1 {
|
if subtle.ConstantTimeAllZero(c2) == 1 {
|
||||||
retryCount++
|
retryCount++
|
||||||
if retryCount > maxRetryLimit {
|
if retryCount > maxRetryLimit {
|
||||||
@ -289,9 +289,9 @@ func encryptLegacy(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Enc
|
|||||||
|
|
||||||
func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte {
|
func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte {
|
||||||
md := sm3.New()
|
md := sm3.New()
|
||||||
md.Write(toBytes(curve, x2))
|
md.Write(bigIntToBytes(curve, x2))
|
||||||
md.Write(msg)
|
md.Write(msg)
|
||||||
md.Write(toBytes(curve, y2))
|
md.Write(bigIntToBytes(curve, y2))
|
||||||
return md.Sum(nil)
|
return md.Sum(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -306,95 +306,6 @@ func mashalASN1Ciphertext(x1, y1 *big.Int, c2, c3 []byte) ([]byte, error) {
|
|||||||
return b.Bytes()
|
return b.Bytes()
|
||||||
}
|
}
|
||||||
|
|
||||||
// ASN1Ciphertext2Plain utility method to convert ASN.1 encoding ciphertext to plain encoding format
|
|
||||||
func ASN1Ciphertext2Plain(ciphertext []byte, opts *EncrypterOpts) ([]byte, error) {
|
|
||||||
if opts == nil {
|
|
||||||
opts = defaultEncrypterOpts
|
|
||||||
}
|
|
||||||
x1, y1, c2, c3, err := unmarshalASN1Ciphertext((ciphertext))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
curve := sm2ec.P256()
|
|
||||||
c1 := opts.pointMarshalMode.mashal(curve, x1, y1)
|
|
||||||
if opts.ciphertextSplicingOrder == C1C3C2 {
|
|
||||||
// c1 || c3 || c2
|
|
||||||
return append(append(c1, c3...), c2...), nil
|
|
||||||
}
|
|
||||||
// c1 || c2 || c3
|
|
||||||
return append(append(c1, c2...), c3...), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PlainCiphertext2ASN1 utility method to convert plain encoding ciphertext to ASN.1 encoding format
|
|
||||||
func PlainCiphertext2ASN1(ciphertext []byte, from ciphertextSplicingOrder) ([]byte, error) {
|
|
||||||
if ciphertext[0] == 0x30 {
|
|
||||||
return nil, errors.New("sm2: invalid plain encoding ciphertext")
|
|
||||||
}
|
|
||||||
curve := sm2ec.P256()
|
|
||||||
ciphertextLen := len(ciphertext)
|
|
||||||
if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size {
|
|
||||||
return nil, errCiphertextTooShort
|
|
||||||
}
|
|
||||||
// get C1, and check C1
|
|
||||||
x1, y1, c3Start, err := bytes2Point(curve, ciphertext)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var c2, c3 []byte
|
|
||||||
|
|
||||||
if from == C1C3C2 {
|
|
||||||
c2 = ciphertext[c3Start+sm3.Size:]
|
|
||||||
c3 = ciphertext[c3Start : c3Start+sm3.Size]
|
|
||||||
} else {
|
|
||||||
c2 = ciphertext[c3Start : ciphertextLen-sm3.Size]
|
|
||||||
c3 = ciphertext[ciphertextLen-sm3.Size:]
|
|
||||||
}
|
|
||||||
return mashalASN1Ciphertext(x1, y1, c2, c3)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AdjustCiphertextSplicingOrder utility method to change c2 c3 order
|
|
||||||
func AdjustCiphertextSplicingOrder(ciphertext []byte, from, to ciphertextSplicingOrder) ([]byte, error) {
|
|
||||||
curve := sm2ec.P256()
|
|
||||||
if from == to {
|
|
||||||
return ciphertext, nil
|
|
||||||
}
|
|
||||||
ciphertextLen := len(ciphertext)
|
|
||||||
if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size {
|
|
||||||
return nil, errCiphertextTooShort
|
|
||||||
}
|
|
||||||
|
|
||||||
// get C1, and check C1
|
|
||||||
_, _, c3Start, err := bytes2Point(curve, ciphertext)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var c1, c2, c3 []byte
|
|
||||||
|
|
||||||
c1 = ciphertext[:c3Start]
|
|
||||||
if from == C1C3C2 {
|
|
||||||
c2 = ciphertext[c3Start+sm3.Size:]
|
|
||||||
c3 = ciphertext[c3Start : c3Start+sm3.Size]
|
|
||||||
} else {
|
|
||||||
c2 = ciphertext[c3Start : ciphertextLen-sm3.Size]
|
|
||||||
c3 = ciphertext[ciphertextLen-sm3.Size:]
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make([]byte, ciphertextLen)
|
|
||||||
copy(result, c1)
|
|
||||||
if to == C1C3C2 {
|
|
||||||
// c1 || c3 || c2
|
|
||||||
copy(result[c3Start:], c3)
|
|
||||||
copy(result[c3Start+sm3.Size:], c2)
|
|
||||||
} else {
|
|
||||||
// c1 || c2 || c3
|
|
||||||
copy(result[c3Start:], c2)
|
|
||||||
copy(result[ciphertextLen-sm3.Size:], c3)
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func decryptASN1(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
func decryptASN1(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
||||||
x1, y1, c2, c3, err := unmarshalASN1Ciphertext(ciphertext)
|
x1, y1, c2, c3, err := unmarshalASN1Ciphertext(ciphertext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -407,7 +318,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error
|
|||||||
curve := priv.Curve
|
curve := priv.Curve
|
||||||
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
|
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
|
||||||
msgLen := len(c2)
|
msgLen := len(c2)
|
||||||
msg := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
msg := sm3.Kdf(append(bigIntToBytes(curve, x2), bigIntToBytes(curve, y2)...), msgLen)
|
||||||
if subtle.ConstantTimeAllZero(c2) == 1 {
|
if subtle.ConstantTimeAllZero(c2) == 1 {
|
||||||
return nil, ErrDecryption
|
return nil, ErrDecryption
|
||||||
}
|
}
|
||||||
@ -428,7 +339,7 @@ func decryptLegacy(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]
|
|||||||
if opts.ciphertextEncoding == ENCODING_ASN1 {
|
if opts.ciphertextEncoding == ENCODING_ASN1 {
|
||||||
return decryptASN1(priv, ciphertext)
|
return decryptASN1(priv, ciphertext)
|
||||||
}
|
}
|
||||||
splicingOrder = opts.cipherTextSplicingOrder
|
splicingOrder = opts.ciphertextSplicingOrder
|
||||||
}
|
}
|
||||||
if ciphertext[0] == 0x30 {
|
if ciphertext[0] == 0x30 {
|
||||||
return decryptASN1(priv, ciphertext)
|
return decryptASN1(priv, ciphertext)
|
||||||
@ -436,7 +347,7 @@ func decryptLegacy(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]
|
|||||||
ciphertextLen := len(ciphertext)
|
ciphertextLen := len(ciphertext)
|
||||||
curve := priv.Curve
|
curve := priv.Curve
|
||||||
// B1, get C1, and check C1
|
// B1, get C1, and check C1
|
||||||
x1, y1, c3Start, err := bytes2Point(curve, ciphertext)
|
x1, y1, c3Start, err := bytesToPoint(curve, ciphertext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, ErrDecryption
|
return nil, ErrDecryption
|
||||||
}
|
}
|
||||||
@ -454,7 +365,20 @@ func decryptLegacy(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]
|
|||||||
return rawDecrypt(priv, x1, y1, c2, c3)
|
return rawDecrypt(priv, x1, y1, c2, c3)
|
||||||
}
|
}
|
||||||
|
|
||||||
func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) {
|
func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte {
|
||||||
|
switch mode {
|
||||||
|
case MarshalCompressed:
|
||||||
|
return elliptic.MarshalCompressed(curve, x, y)
|
||||||
|
case MarshalHybrid:
|
||||||
|
buffer := elliptic.Marshal(curve, x, y)
|
||||||
|
buffer[0] = byte(y.Bit(0)) | hybrid06
|
||||||
|
return buffer
|
||||||
|
default:
|
||||||
|
return elliptic.Marshal(curve, x, y)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func bytesToPoint(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) {
|
||||||
if len(bytes) < 1+(curve.Params().BitSize/8) {
|
if len(bytes) < 1+(curve.Params().BitSize/8) {
|
||||||
return nil, nil, 0, fmt.Errorf("sm2: invalid bytes length %d", len(bytes))
|
return nil, nil, 0, fmt.Errorf("sm2: invalid bytes length %d", len(bytes))
|
||||||
}
|
}
|
||||||
@ -486,20 +410,7 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e
|
|||||||
}
|
}
|
||||||
return x, y, 1 + byteLen, nil
|
return x, y, 1 + byteLen, nil
|
||||||
}
|
}
|
||||||
return nil, nil, 0, fmt.Errorf("sm2: unsupport point form %d, curve %s", format, curve.Params().Name)
|
return nil, nil, 0, fmt.Errorf("sm2: unsupported point form %d, curve %s", format, curve.Params().Name)
|
||||||
}
|
}
|
||||||
return nil, nil, 0, fmt.Errorf("sm2: unknown point form %d", format)
|
return nil, nil, 0, fmt.Errorf("sm2: unknown point form %d", format)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte {
|
|
||||||
switch mode {
|
|
||||||
case MarshalCompressed:
|
|
||||||
return elliptic.MarshalCompressed(curve, x, y)
|
|
||||||
case MarshalHybrid:
|
|
||||||
buffer := elliptic.Marshal(curve, x, y)
|
|
||||||
buffer[0] = byte(y.Bit(0)) | hybrid06
|
|
||||||
return buffer
|
|
||||||
default:
|
|
||||||
return elliptic.Marshal(curve, x, y)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
396
sm2/sm2_pke.go
Normal file
396
sm2/sm2_pke.go
Normal file
@ -0,0 +1,396 @@
|
|||||||
|
// Package sm2 implements ShangMi(SM) sm2 digital signature, public key encryption and key exchange algorithms.
|
||||||
|
package sm2
|
||||||
|
|
||||||
|
// Further references:
|
||||||
|
// [NSA]: Suite B implementer's guide to FIPS 186-3
|
||||||
|
// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.182.4503&rep=rep1&type=pdf
|
||||||
|
// [SECG]: SECG, SEC1
|
||||||
|
// http://www.secg.org/sec1-v2.pdf
|
||||||
|
// [GM/T]: SM2 GB/T 32918.2-2016, GB/T 32918.4-2016
|
||||||
|
//
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
_subtle "crypto/subtle"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/big"
|
||||||
|
|
||||||
|
"github.com/emmansun/gmsm/internal/bigmod"
|
||||||
|
_sm2ec "github.com/emmansun/gmsm/internal/sm2ec"
|
||||||
|
"github.com/emmansun/gmsm/internal/subtle"
|
||||||
|
"github.com/emmansun/gmsm/sm3"
|
||||||
|
"golang.org/x/crypto/cryptobyte"
|
||||||
|
"golang.org/x/crypto/cryptobyte/asn1"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
uncompressed byte = 0x04
|
||||||
|
compressed02 byte = 0x02
|
||||||
|
compressed03 byte = compressed02 | 0x01
|
||||||
|
hybrid06 byte = 0x06
|
||||||
|
hybrid07 byte = hybrid06 | 0x01
|
||||||
|
)
|
||||||
|
|
||||||
|
type pointMarshalMode byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
//MarshalUncompressed uncompressed marshal mode
|
||||||
|
MarshalUncompressed pointMarshalMode = iota
|
||||||
|
//MarshalCompressed compressed marshal mode
|
||||||
|
MarshalCompressed
|
||||||
|
//MarshalHybrid hybrid marshal mode
|
||||||
|
MarshalHybrid
|
||||||
|
)
|
||||||
|
|
||||||
|
type ciphertextSplicingOrder byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
C1C3C2 ciphertextSplicingOrder = iota
|
||||||
|
C1C2C3
|
||||||
|
)
|
||||||
|
|
||||||
|
// splitC2C3 splits the given ciphertext into two parts, C2 and C3, based on the splicing order.
|
||||||
|
// If the order is C1C3C2, it returns the first sm3.Size bytes as C3 and the rest as C2.
|
||||||
|
// Otherwise, it returns the first part as C2 and the last sm3.Size bytes as C3.
|
||||||
|
func (order ciphertextSplicingOrder) splitC2C3(ciphertext []byte) ([]byte, []byte) {
|
||||||
|
if order == C1C3C2 {
|
||||||
|
return ciphertext[sm3.Size:], ciphertext[:sm3.Size]
|
||||||
|
}
|
||||||
|
return ciphertext[:len(ciphertext)-sm3.Size], ciphertext[len(ciphertext)-sm3.Size:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// spliceCiphertext splices the given ciphertext components together based on the splicing order.
|
||||||
|
func (order ciphertextSplicingOrder) spliceCiphertext(c1, c2, c3 []byte) ([]byte, error) {
|
||||||
|
switch order {
|
||||||
|
case C1C3C2:
|
||||||
|
return append(append(c1, c3...), c2...), nil
|
||||||
|
case C1C2C3:
|
||||||
|
return append(append(c1, c2...), c3...), nil
|
||||||
|
default:
|
||||||
|
return nil, errors.New("sm2: invalid ciphertext splicing order")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ciphertextEncoding byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
ENCODING_PLAIN ciphertextEncoding = iota
|
||||||
|
ENCODING_ASN1
|
||||||
|
)
|
||||||
|
|
||||||
|
// EncrypterOpts represents the options for the SM2 encryption process.
|
||||||
|
// It includes settings for ciphertext encoding, point marshaling mode,
|
||||||
|
// and the order in which the ciphertext components are spliced together.
|
||||||
|
type EncrypterOpts struct {
|
||||||
|
ciphertextEncoding ciphertextEncoding
|
||||||
|
pointMarshalMode pointMarshalMode
|
||||||
|
ciphertextSplicingOrder ciphertextSplicingOrder
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecrypterOpts represents the options for the decryption process.
|
||||||
|
// It includes settings for how the ciphertext is encoded and how the
|
||||||
|
// components of the ciphertext are spliced together.
|
||||||
|
//
|
||||||
|
// Fields:
|
||||||
|
// - ciphertextEncoding: Specifies the encoding format of the ciphertext.
|
||||||
|
// - ciphertextSplicingOrder: Defines the order in which the components
|
||||||
|
// of the ciphertext are spliced together.
|
||||||
|
type DecrypterOpts struct {
|
||||||
|
ciphertextEncoding ciphertextEncoding
|
||||||
|
ciphertextSplicingOrder ciphertextSplicingOrder
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPlainEncrypterOpts creates a SM2 non-ASN1 encrypter options.
|
||||||
|
func NewPlainEncrypterOpts(marshalMode pointMarshalMode, splicingOrder ciphertextSplicingOrder) *EncrypterOpts {
|
||||||
|
return &EncrypterOpts{ENCODING_PLAIN, marshalMode, splicingOrder}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPlainDecrypterOpts creates a SM2 non-ASN1 decrypter options.
|
||||||
|
func NewPlainDecrypterOpts(splicingOrder ciphertextSplicingOrder) *DecrypterOpts {
|
||||||
|
return &DecrypterOpts{ENCODING_PLAIN, splicingOrder}
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultEncrypterOpts = &EncrypterOpts{ENCODING_PLAIN, MarshalUncompressed, C1C3C2}
|
||||||
|
|
||||||
|
ASN1EncrypterOpts = &EncrypterOpts{ENCODING_ASN1, MarshalUncompressed, C1C3C2}
|
||||||
|
|
||||||
|
ASN1DecrypterOpts = &DecrypterOpts{ENCODING_ASN1, C1C3C2}
|
||||||
|
)
|
||||||
|
|
||||||
|
const maxRetryLimit = 100
|
||||||
|
|
||||||
|
var errCiphertextTooShort = errors.New("sm2: ciphertext too short")
|
||||||
|
|
||||||
|
// EncryptASN1 sm2 encrypt and output ASN.1 result, compliance with GB/T 32918.4-2016.
|
||||||
|
//
|
||||||
|
// The random parameter is used as a source of entropy to ensure that
|
||||||
|
// encrypting the same message twice doesn't result in the same ciphertext.
|
||||||
|
// Most applications should use [crypto/rand.Reader] as random.
|
||||||
|
func EncryptASN1(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) {
|
||||||
|
return Encrypt(random, pub, msg, ASN1EncrypterOpts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt sm2 encrypt implementation, compliance with GB/T 32918.4-2016.
|
||||||
|
//
|
||||||
|
// The random parameter is used as a source of entropy to ensure that
|
||||||
|
// encrypting the same message twice doesn't result in the same ciphertext.
|
||||||
|
// Most applications should use [crypto/rand.Reader] as random.
|
||||||
|
func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *EncrypterOpts) ([]byte, error) {
|
||||||
|
//A3, requirement is to check if h*P is infinite point, h is 1
|
||||||
|
if pub.X.Sign() == 0 && pub.Y.Sign() == 0 {
|
||||||
|
return nil, errors.New("sm2: public key point is the infinity")
|
||||||
|
}
|
||||||
|
if len(msg) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if opts == nil {
|
||||||
|
opts = defaultEncrypterOpts
|
||||||
|
}
|
||||||
|
switch pub.Curve.Params() {
|
||||||
|
case P256().Params():
|
||||||
|
return encryptSM2EC(p256(), pub, random, msg, opts)
|
||||||
|
default:
|
||||||
|
return encryptLegacy(random, pub, msg, opts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func encryptSM2EC(c *sm2Curve, pub *ecdsa.PublicKey, random io.Reader, msg []byte, opts *EncrypterOpts) ([]byte, error) {
|
||||||
|
Q, err := c.pointFromAffine(pub.X, pub.Y)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
retryCount := 0
|
||||||
|
for {
|
||||||
|
k, C1, err := randomPoint(c, random, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
C2, err := Q.ScalarMult(Q, k.Bytes(c.N))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
C2Bytes := C2.Bytes()[1:]
|
||||||
|
c2 := sm3.Kdf(C2Bytes, len(msg))
|
||||||
|
if subtle.ConstantTimeAllZero(c2) == 1 {
|
||||||
|
retryCount++
|
||||||
|
if retryCount > maxRetryLimit {
|
||||||
|
return nil, fmt.Errorf("sm2: A5, failed to calculate valid t, tried %v times", retryCount)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
//A6, C2 = M + t;
|
||||||
|
subtle.XORBytes(c2, msg, c2)
|
||||||
|
|
||||||
|
//A7, C3 = hash(x2||M||y2)
|
||||||
|
md := sm3.New()
|
||||||
|
md.Write(C2Bytes[:len(C2Bytes)/2])
|
||||||
|
md.Write(msg)
|
||||||
|
md.Write(C2Bytes[len(C2Bytes)/2:])
|
||||||
|
c3 := md.Sum(nil)
|
||||||
|
|
||||||
|
if opts.ciphertextEncoding == ENCODING_PLAIN {
|
||||||
|
return encodeCiphertext(opts, C1, c2, c3)
|
||||||
|
}
|
||||||
|
return encodingCiphertextASN1(C1, c2, c3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeCiphertext(opts *EncrypterOpts, C1 *_sm2ec.SM2P256Point, c2, c3 []byte) ([]byte, error) {
|
||||||
|
var c1 []byte
|
||||||
|
switch opts.pointMarshalMode {
|
||||||
|
case MarshalCompressed:
|
||||||
|
c1 = C1.BytesCompressed()
|
||||||
|
default:
|
||||||
|
c1 = C1.Bytes()
|
||||||
|
}
|
||||||
|
return opts.ciphertextSplicingOrder.spliceCiphertext(c1, c2, c3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodingCiphertextASN1(C1 *_sm2ec.SM2P256Point, c2, c3 []byte) ([]byte, error) {
|
||||||
|
c1 := C1.Bytes()
|
||||||
|
var b cryptobyte.Builder
|
||||||
|
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
|
||||||
|
addASN1IntBytes(b, c1[1:len(c1)/2+1])
|
||||||
|
addASN1IntBytes(b, c1[len(c1)/2+1:])
|
||||||
|
b.AddASN1OctetString(c3)
|
||||||
|
b.AddASN1OctetString(c2)
|
||||||
|
})
|
||||||
|
return b.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt decrypts ciphertext msg to plaintext.
|
||||||
|
// The opts argument should be appropriate for the primitive used.
|
||||||
|
// Compliance with GB/T 32918.4-2016 chapter 7.
|
||||||
|
func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
|
||||||
|
var sm2Opts *DecrypterOpts
|
||||||
|
sm2Opts, _ = opts.(*DecrypterOpts)
|
||||||
|
return decrypt(priv, msg, sm2Opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt sm2 decrypt implementation by default DecrypterOpts{C1C3C2}.
|
||||||
|
// Compliance with GB/T 32918.4-2016.
|
||||||
|
func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
||||||
|
return decrypt(priv, ciphertext, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrDecryption represents a failure to decrypt a message.
|
||||||
|
// It is deliberately vague to avoid adaptive attacks.
|
||||||
|
var ErrDecryption = errors.New("sm2: decryption error")
|
||||||
|
|
||||||
|
func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
|
||||||
|
ciphertextLen := len(ciphertext)
|
||||||
|
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
|
||||||
|
return nil, errCiphertextTooShort
|
||||||
|
}
|
||||||
|
switch priv.Curve.Params() {
|
||||||
|
case P256().Params():
|
||||||
|
return decryptSM2EC(p256(), priv, ciphertext, opts)
|
||||||
|
default:
|
||||||
|
return decryptLegacy(priv, ciphertext, opts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decryptSM2EC(c *sm2Curve, priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
|
||||||
|
C1, c2, c3, err := parseCiphertext(c, ciphertext, opts)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrDecryption
|
||||||
|
}
|
||||||
|
d, err := bigmod.NewNat().SetBytes(priv.D.Bytes(), c.N)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrDecryption
|
||||||
|
}
|
||||||
|
|
||||||
|
C2, err := C1.ScalarMult(C1, d.Bytes(c.N))
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrDecryption
|
||||||
|
}
|
||||||
|
C2Bytes := C2.Bytes()[1:]
|
||||||
|
msgLen := len(c2)
|
||||||
|
msg := sm3.Kdf(C2Bytes, msgLen)
|
||||||
|
if subtle.ConstantTimeAllZero(c2) == 1 {
|
||||||
|
return nil, ErrDecryption
|
||||||
|
}
|
||||||
|
|
||||||
|
//B5, calculate msg = c2 ^ t
|
||||||
|
subtle.XORBytes(msg, c2, msg)
|
||||||
|
|
||||||
|
md := sm3.New()
|
||||||
|
md.Write(C2Bytes[:len(C2Bytes)/2])
|
||||||
|
md.Write(msg)
|
||||||
|
md.Write(C2Bytes[len(C2Bytes)/2:])
|
||||||
|
u := md.Sum(nil)
|
||||||
|
|
||||||
|
if _subtle.ConstantTimeCompare(u, c3) == 1 {
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
return nil, ErrDecryption
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseCiphertext parses the given ciphertext according to the specified SM2 curve and decryption options.
|
||||||
|
// It returns the parsed SM2 point (C1), the decrypted message (C2), the message digest (C3), and an error if any.
|
||||||
|
func parseCiphertext(c *sm2Curve, ciphertext []byte, opts *DecrypterOpts) (*_sm2ec.SM2P256Point, []byte, []byte, error) {
|
||||||
|
bitSize := c.curve.Params().BitSize
|
||||||
|
byteLen := (bitSize + 7) / 8
|
||||||
|
splicingOrder := C1C3C2
|
||||||
|
if opts != nil {
|
||||||
|
splicingOrder = opts.ciphertextSplicingOrder
|
||||||
|
}
|
||||||
|
|
||||||
|
var ciphertextFormat byte = 0xff // invalid
|
||||||
|
if len(ciphertext) > 0 {
|
||||||
|
ciphertextFormat = ciphertext[0]
|
||||||
|
}
|
||||||
|
var c1Len int
|
||||||
|
switch ciphertextFormat {
|
||||||
|
case byte(asn1.SEQUENCE):
|
||||||
|
return parseCiphertextASN1(c, ciphertext)
|
||||||
|
case uncompressed:
|
||||||
|
c1Len = 1 + 2*byteLen
|
||||||
|
case compressed02, compressed03:
|
||||||
|
c1Len = 1 + byteLen
|
||||||
|
default:
|
||||||
|
return nil, nil, nil, errors.New("sm2: invalid/unsupported ciphertext format")
|
||||||
|
}
|
||||||
|
if len(ciphertext) < c1Len+sm3.Size {
|
||||||
|
return nil, nil, nil, errCiphertextTooShort
|
||||||
|
}
|
||||||
|
C1, err := c.newPoint().SetBytes(ciphertext[:c1Len])
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
c2, c3 := splicingOrder.splitC2C3(ciphertext[c1Len:])
|
||||||
|
return C1, c2, c3, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func unmarshalASN1Ciphertext(ciphertext []byte) (*big.Int, *big.Int, []byte, []byte, error) {
|
||||||
|
var (
|
||||||
|
x1, y1 = &big.Int{}, &big.Int{}
|
||||||
|
c2, c3 []byte
|
||||||
|
inner cryptobyte.String
|
||||||
|
)
|
||||||
|
input := cryptobyte.String(ciphertext)
|
||||||
|
if !input.ReadASN1(&inner, asn1.SEQUENCE) ||
|
||||||
|
!input.Empty() ||
|
||||||
|
!inner.ReadASN1Integer(x1) ||
|
||||||
|
!inner.ReadASN1Integer(y1) ||
|
||||||
|
!inner.ReadASN1Bytes(&c3, asn1.OCTET_STRING) ||
|
||||||
|
!inner.ReadASN1Bytes(&c2, asn1.OCTET_STRING) ||
|
||||||
|
!inner.Empty() {
|
||||||
|
return nil, nil, nil, nil, errors.New("sm2: invalid asn1 format ciphertext")
|
||||||
|
}
|
||||||
|
return x1, y1, c2, c3, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCiphertextASN1(c *sm2Curve, ciphertext []byte) (*_sm2ec.SM2P256Point, []byte, []byte, error) {
|
||||||
|
x1, y1, c2, c3, err := unmarshalASN1Ciphertext(ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
C1, err := c.pointFromAffine(x1, y1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
return C1, c2, c3, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdjustCiphertextSplicingOrder utility method to change c2 c3 order
|
||||||
|
func AdjustCiphertextSplicingOrder(ciphertext []byte, from, to ciphertextSplicingOrder) ([]byte, error) {
|
||||||
|
curve := p256()
|
||||||
|
if from == to {
|
||||||
|
return ciphertext, nil
|
||||||
|
}
|
||||||
|
C1, c2, c3, err := parseCiphertext(curve, ciphertext, NewPlainDecrypterOpts(from))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
opts := NewPlainEncrypterOpts(MarshalUncompressed, to)
|
||||||
|
if ciphertext[0] == compressed02 || ciphertext[0] == compressed03 {
|
||||||
|
opts.pointMarshalMode = MarshalCompressed
|
||||||
|
}
|
||||||
|
return encodeCiphertext(opts, C1, c2, c3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ASN1Ciphertext2Plain utility method to convert ASN.1 encoding ciphertext to plain encoding format
|
||||||
|
func ASN1Ciphertext2Plain(ciphertext []byte, opts *EncrypterOpts) ([]byte, error) {
|
||||||
|
if opts == nil {
|
||||||
|
opts = defaultEncrypterOpts
|
||||||
|
}
|
||||||
|
C1, c2, c3, err := parseCiphertextASN1(p256(), ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return encodeCiphertext(opts, C1, c2, c3)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlainCiphertext2ASN1 utility method to convert plain encoding ciphertext to ASN.1 encoding format
|
||||||
|
func PlainCiphertext2ASN1(ciphertext []byte, from ciphertextSplicingOrder) ([]byte, error) {
|
||||||
|
C1, c2, c3, err := parseCiphertext(p256(), ciphertext, NewPlainDecrypterOpts(from))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return encodingCiphertextASN1(C1, c2, c3)
|
||||||
|
}
|
374
sm2/sm2_pke_test.go
Normal file
374
sm2/sm2_pke_test.go
Normal file
@ -0,0 +1,374 @@
|
|||||||
|
package sm2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"math/big"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSplicingOrder(t *testing.T) {
|
||||||
|
priv, _ := GenerateKey(rand.Reader)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
plainText string
|
||||||
|
from ciphertextSplicingOrder
|
||||||
|
to ciphertextSplicingOrder
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
{"less than 32 1", "encryption standard", C1C2C3, C1C3C2},
|
||||||
|
{"less than 32 2", "encryption standard", C1C3C2, C1C2C3},
|
||||||
|
{"equals 32 1", "encryption standard encryption ", C1C2C3, C1C3C2},
|
||||||
|
{"equals 32 2", "encryption standard encryption ", C1C3C2, C1C2C3},
|
||||||
|
{"long than 32 1", "encryption standard encryption standard", C1C2C3, C1C3C2},
|
||||||
|
{"long than 32 2", "encryption standard encryption standard", C1C3C2, C1C2C3},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), NewPlainEncrypterOpts(MarshalUncompressed, tt.from))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt failed %v", err)
|
||||||
|
}
|
||||||
|
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.from))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt failed %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Adjust splicing order
|
||||||
|
ciphertext, err = AdjustCiphertextSplicingOrder(ciphertext, tt.from, tt.to)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("adjust splicing order failed %v", err)
|
||||||
|
}
|
||||||
|
plaintext, err = priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.to))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt failed after adjust splicing order %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptDecryptASN1(t *testing.T) {
|
||||||
|
priv, _ := GenerateKey(rand.Reader)
|
||||||
|
priv2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
key2 := new(PrivateKey)
|
||||||
|
key2.PrivateKey = *priv2
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
plainText string
|
||||||
|
priv *PrivateKey
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
{"less than 32", "encryption standard", priv},
|
||||||
|
{"equals 32", "encryption standard encryption ", priv},
|
||||||
|
{"long than 32", "encryption standard encryption standard", priv},
|
||||||
|
{"less than 32", "encryption standard", key2},
|
||||||
|
{"equals 32", "encryption standard encryption ", key2},
|
||||||
|
{"long than 32", "encryption standard encryption standard", key2},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
encrypterOpts := ASN1EncrypterOpts
|
||||||
|
ciphertext, err := Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%v encrypt failed %v", tt.priv.Curve.Params().Name, err)
|
||||||
|
}
|
||||||
|
plaintext, err := tt.priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%v decrypt 1 failed %v", tt.priv.Curve.Params().Name, err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
|
plaintext, err = tt.priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%v decrypt 2 failed %v", tt.priv.Curve.Params().Name, err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPlainCiphertext2ASN1(t *testing.T) {
|
||||||
|
ciphertext, _ := hex.DecodeString("047928e22045eec8dc00e95639dd0c1c8dfb75cf8cedcf496731a6a6f423baa54c5014c60b73495886d8d7bc996a4a716cb58e6bfc8e03078b24e7b0f5cba0efd5b9272c27fc263bb59eaca6eabc97c0323bf1de953aeabaf59700b3bf49c9a1056decc08dd18544960541a2239afa7b1512df05")
|
||||||
|
_, err := PlainCiphertext2ASN1(append([]byte{0x30}, ciphertext...), C1C3C2)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
_, err = PlainCiphertext2ASN1(ciphertext[:65], C1C3C2)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
ciphertext[0] = 0x10
|
||||||
|
_, err = PlainCiphertext2ASN1(ciphertext, C1C3C2)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdjustCiphertextSplicingOrder(t *testing.T) {
|
||||||
|
ciphertext, _ := hex.DecodeString("047928e22045eec8dc00e95639dd0c1c8dfb75cf8cedcf496731a6a6f423baa54c5014c60b73495886d8d7bc996a4a716cb58e6bfc8e03078b24e7b0f5cba0efd5b9272c27fc263bb59eaca6eabc97c0323bf1de953aeabaf59700b3bf49c9a1056decc08dd18544960541a2239afa7b1512df05")
|
||||||
|
res, err := AdjustCiphertextSplicingOrder(ciphertext, C1C3C2, C1C3C2)
|
||||||
|
if err != nil || &res[0] != &ciphertext[0] {
|
||||||
|
t.Fatalf("should be same one")
|
||||||
|
}
|
||||||
|
_, err = AdjustCiphertextSplicingOrder(ciphertext[:65], C1C3C2, C1C2C3)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
ciphertext[0] = 0x10
|
||||||
|
_, err = AdjustCiphertextSplicingOrder(ciphertext, C1C3C2, C1C2C3)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCiphertext2ASN1(t *testing.T) {
|
||||||
|
priv, _ := GenerateKey(rand.Reader)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
plainText string
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
{"less than 32", "encryption standard"},
|
||||||
|
{"equals 32", "encryption standard encryption "},
|
||||||
|
{"long than 32", "encryption standard encryption standard"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ciphertext1, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt failed %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ciphertext, err := PlainCiphertext2ASN1(ciphertext1, C1C3C2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("convert to ASN.1 failed %v", err)
|
||||||
|
}
|
||||||
|
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt failed %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
|
|
||||||
|
ciphertext2, err := AdjustCiphertextSplicingOrder(ciphertext1, C1C3C2, C1C2C3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("adjust order failed %v", err)
|
||||||
|
}
|
||||||
|
ciphertext, err = PlainCiphertext2ASN1(ciphertext2, C1C2C3)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("convert to ASN.1 failed %v", err)
|
||||||
|
}
|
||||||
|
plaintext, err = priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt failed %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCiphertextASN12Plain(t *testing.T) {
|
||||||
|
priv, _ := GenerateKey(rand.Reader)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
plainText string
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
{"less than 32", "encryption standard"},
|
||||||
|
{"equals 32", "encryption standard encryption "},
|
||||||
|
{"long than 32", "encryption standard encryption standard"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ciphertext, err := EncryptASN1(rand.Reader, &priv.PublicKey, []byte(tt.plainText))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt failed %v", err)
|
||||||
|
}
|
||||||
|
ciphertext, err = ASN1Ciphertext2Plain(ciphertext, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("convert to plain failed %v", err)
|
||||||
|
}
|
||||||
|
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt failed %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptWithInfinitePublicKey(t *testing.T) {
|
||||||
|
pub := new(ecdsa.PublicKey)
|
||||||
|
pub.Curve = P256()
|
||||||
|
pub.X = big.NewInt(0)
|
||||||
|
pub.Y = big.NewInt(0)
|
||||||
|
|
||||||
|
_, err := Encrypt(rand.Reader, pub, []byte("sm2 encryption standard"), nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("should be failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptEmptyPlaintext(t *testing.T) {
|
||||||
|
priv, _ := GenerateKey(rand.Reader)
|
||||||
|
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, nil, nil)
|
||||||
|
if err != nil || ciphertext != nil {
|
||||||
|
t.Fatalf("nil plaintext should return nil")
|
||||||
|
}
|
||||||
|
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte{}, nil)
|
||||||
|
if err != nil || ciphertext != nil {
|
||||||
|
t.Fatalf("empty plaintext should return nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEncryptDecrypt(t *testing.T) {
|
||||||
|
priv, _ := GenerateKey(rand.Reader)
|
||||||
|
priv2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
key2 := new(PrivateKey)
|
||||||
|
key2.PrivateKey = *priv2
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
plainText string
|
||||||
|
priv *PrivateKey
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
{"less than 32", "encryption standard", priv},
|
||||||
|
{"equals 32", "encryption standard encryption ", priv},
|
||||||
|
{"long than 32", "encryption standard encryption standard", priv},
|
||||||
|
{"less than 32", "encryption standard", key2},
|
||||||
|
{"equals 32", "encryption standard encryption ", key2},
|
||||||
|
{"long than 32", "encryption standard encryption standard", key2},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ciphertext, err := Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt failed %v", err)
|
||||||
|
}
|
||||||
|
plaintext, err := Decrypt(tt.priv, ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt failed %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
|
// compress mode
|
||||||
|
encrypterOpts := NewPlainEncrypterOpts(MarshalCompressed, C1C3C2)
|
||||||
|
ciphertext, err = Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt failed %v", err)
|
||||||
|
}
|
||||||
|
plaintext, err = Decrypt(tt.priv, ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt failed %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
|
|
||||||
|
// hybrid mode
|
||||||
|
encrypterOpts = NewPlainEncrypterOpts(MarshalHybrid, C1C3C2)
|
||||||
|
ciphertext, err = Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt failed %v", err)
|
||||||
|
}
|
||||||
|
plaintext, err = Decrypt(tt.priv, ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt failed %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
|
plaintext, err = Decrypt(tt.priv, ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt failed %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInvalidCiphertext(t *testing.T) {
|
||||||
|
priv, _ := GenerateKey(rand.Reader)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
ciphertext []byte
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
{errCiphertextTooShort.Error(), nil},
|
||||||
|
{errCiphertextTooShort.Error(), make([]byte, 65)},
|
||||||
|
{ErrDecryption.Error(), append([]byte{0x04}, make([]byte, 96)...)},
|
||||||
|
{ErrDecryption.Error(), append([]byte{0x04}, make([]byte, 97)...)},
|
||||||
|
{ErrDecryption.Error(), append([]byte{0x02}, make([]byte, 65)...)},
|
||||||
|
{ErrDecryption.Error(), append([]byte{0x30}, make([]byte, 97)...)},
|
||||||
|
{ErrDecryption.Error(), make([]byte, 97)},
|
||||||
|
}
|
||||||
|
for i, tt := range tests {
|
||||||
|
_, err := Decrypt(priv, tt.ciphertext)
|
||||||
|
if err.Error() != tt.name {
|
||||||
|
t.Fatalf("case %v, expected %v, got %v\n", i, tt.name, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkEncrypt(b *testing.B, curve elliptic.Curve, plaintext []byte) {
|
||||||
|
r := bufio.NewReaderSize(rand.Reader, 1<<15)
|
||||||
|
priv, err := ecdsa.GenerateKey(curve, r)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
b.SetBytes(int64(len(plaintext)))
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext), nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncryptNoMoreThan32_P256(b *testing.B) {
|
||||||
|
benchmarkEncrypt(b, elliptic.P256(), make([]byte, 31))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncryptNoMoreThan32_SM2(b *testing.B) {
|
||||||
|
benchmarkEncrypt(b, P256(), make([]byte, 31))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncrypt128_P256(b *testing.B) {
|
||||||
|
benchmarkEncrypt(b, elliptic.P256(), make([]byte, 128))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncrypt128_SM2(b *testing.B) {
|
||||||
|
benchmarkEncrypt(b, P256(), make([]byte, 128))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncrypt512_SM2(b *testing.B) {
|
||||||
|
benchmarkEncrypt(b, P256(), make([]byte, 512))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncrypt1K_SM2(b *testing.B) {
|
||||||
|
benchmarkEncrypt(b, P256(), make([]byte, 1024))
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncrypt8K_SM2(b *testing.B) {
|
||||||
|
benchmarkEncrypt(b, P256(), make([]byte, 8*1024))
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user