sm9: support non-xor modes

This commit is contained in:
Sun Yimin 2023-02-10 17:19:50 +08:00 committed by GitHub
parent 5bfdfeb9b5
commit ebf9a74d77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 705 additions and 88 deletions

View File

@ -8,8 +8,6 @@ import (
"github.com/emmansun/gmsm/sm4"
)
var commonCounter = []byte{0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff}
var ctrSM4Tests = []struct {
name string
key []byte

99
cipher/ecb.go Normal file
View File

@ -0,0 +1,99 @@
// Electronic Code Book (ECB) mode.
// Please do NOT use this mode alone.
package cipher
import (
goCipher "crypto/cipher"
"github.com/emmansun/gmsm/internal/alias"
)
type ecb struct {
b goCipher.Block
blockSize int
}
func newECB(b goCipher.Block) *ecb {
return &ecb{
b: b,
blockSize: b.BlockSize(),
}
}
func (x *ecb) validate(dst, src []byte) {
if len(src)%x.blockSize != 0 {
panic("cipher: input not full blocks")
}
if len(dst) < len(src) {
panic("cipher: output smaller than input")
}
if alias.InexactOverlap(dst[:len(src)], src) {
panic("cipher: invalid buffer overlap")
}
}
type ecbEncrypter ecb
// ecbEncAble is an interface implemented by ciphers that have a specific
// optimized implementation of ECB encryption, like sm4.
// NewECBEncrypter will check for this interface and return the specific
// BlockMode if found.
type ecbEncAble interface {
NewECBEncrypter() goCipher.BlockMode
}
// NewECBEncrypter returns a BlockMode which encrypts in electronic code book
// mode, using the given Block.
func NewECBEncrypter(b goCipher.Block) goCipher.BlockMode {
if ecb, ok := b.(ecbEncAble); ok {
return ecb.NewECBEncrypter()
}
return (*ecbEncrypter)(newECB(b))
}
func (x *ecbEncrypter) BlockSize() int { return x.blockSize }
func (x *ecbEncrypter) CryptBlocks(dst, src []byte) {
(*ecb)(x).validate(dst, src)
for len(src) > 0 {
x.b.Encrypt(dst[:x.blockSize], src[:x.blockSize])
src = src[x.blockSize:]
dst = dst[x.blockSize:]
}
}
type ecbDecrypter ecb
// ecbDecAble is an interface implemented by ciphers that have a specific
// optimized implementation of ECB decryption, like sm4.
// NewECBDecrypter will check for this interface and return the specific
// BlockMode if found.
type ecbDecAble interface {
NewECBDecrypter() goCipher.BlockMode
}
// NewECBDecrypter returns a BlockMode which decrypts in electronic code book
// mode, using the given Block.
func NewECBDecrypter(b goCipher.Block) goCipher.BlockMode {
if ecb, ok := b.(ecbDecAble); ok {
return ecb.NewECBDecrypter()
}
return (*ecbDecrypter)(newECB(b))
}
func (x *ecbDecrypter) BlockSize() int { return x.blockSize }
func (x *ecbDecrypter) CryptBlocks(dst, src []byte) {
(*ecb)(x).validate(dst, src)
if len(src) == 0 {
return
}
for len(src) > 0 {
x.b.Decrypt(dst[:x.blockSize], src[:x.blockSize])
src = src[x.blockSize:]
dst = dst[x.blockSize:]
}
}

119
cipher/ecb_sm4_test.go Normal file
View File

@ -0,0 +1,119 @@
package cipher_test
import (
"bytes"
"testing"
"github.com/emmansun/gmsm/cipher"
"github.com/emmansun/gmsm/sm4"
)
var ecbSM4Tests = []struct {
name string
key []byte
in []byte
}{
{
"1 block",
[]byte("0123456789ABCDEF"),
[]byte("exampleplaintext"),
},
{
"2 same blocks",
[]byte("0123456789ABCDEF"),
[]byte("exampleplaintextexampleplaintext"),
},
{
"2 different blocks",
[]byte("0123456789ABCDEF"),
[]byte("exampleplaintextfedcba9876543210"),
},
{
"3 same blocks",
[]byte("0123456789ABCDEF"),
[]byte("exampleplaintextexampleplaintextexampleplaintext"),
},
{
"4 same blocks",
[]byte("0123456789ABCDEF"),
[]byte("exampleplaintextexampleplaintextexampleplaintextexampleplaintext"),
},
{
"5 same blocks",
[]byte("0123456789ABCDEF"),
[]byte("exampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintext"),
},
{
"6 same blocks",
[]byte("0123456789ABCDEF"),
[]byte("exampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintext"),
},
{
"7 same blocks",
[]byte("0123456789ABCDEF"),
[]byte("exampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintext"),
},
{
"8 same blocks",
[]byte("0123456789ABCDEF"),
[]byte("exampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintext"),
},
{
"9 same blocks",
[]byte("0123456789ABCDEF"),
[]byte("exampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintext"),
},
}
func TestECBBasic(t *testing.T) {
for _, test := range ecbSM4Tests {
c, err := sm4.NewCipher(test.key)
if err != nil {
t.Errorf("%s: NewCipher(%d bytes) = %s", test.name, len(test.key), err)
continue
}
encrypter := cipher.NewECBEncrypter(c)
ciphertext := make([]byte, len(test.in))
encrypter.CryptBlocks(ciphertext, test.in)
plaintext := make([]byte, len(test.in))
decrypter := cipher.NewECBDecrypter(c)
decrypter.CryptBlocks(plaintext, ciphertext)
if !bytes.Equal(test.in, plaintext) {
t.Errorf("%s: ECB encrypt/decrypt failed", test.name)
}
}
}
func shouldPanic(t *testing.T, f func()) {
t.Helper()
defer func() { _ = recover() }()
f()
t.Errorf("should have panicked")
}
func TestECBValidate(t *testing.T) {
key := make([]byte, 16)
src := make([]byte, 32)
c, err := sm4.NewCipher(key)
if err != nil {
t.Fatal(err)
}
decrypter := cipher.NewECBDecrypter(c)
// test len(src) == 0
decrypter.CryptBlocks(nil, nil)
// cipher: input not full blocks
shouldPanic(t, func() {
decrypter.CryptBlocks(src, src[1:])
})
// cipher: output smaller than input
shouldPanic(t, func() {
decrypter.CryptBlocks(src[1:], src)
})
// cipher: invalid buffer overlap
shouldPanic(t, func() {
decrypter.CryptBlocks(src[1:17], src[2:18])
})
}

76
cipher/example_test.go Normal file
View File

@ -0,0 +1,76 @@
package cipher_test
import (
"crypto/aes"
"encoding/hex"
"fmt"
"github.com/emmansun/gmsm/cipher"
)
func ExampleNewECBEncrypter() {
// Load your secret key from a safe place and reuse it across multiple
// NewCipher calls. (Obviously don't use this example key for anything
// real.) If you want to convert a passphrase to a key, use a suitable
// package like bcrypt or scrypt.
key, _ := hex.DecodeString("6368616e676520746869732070617373")
plaintext := []byte("exampleplaintextexampleplaintext")
// ECB mode works on blocks so plaintexts may need to be padded to the
// next whole block. For an example of such padding, see
// https://tools.ietf.org/html/rfc5246#section-6.2.3.2. Here we'll
// assume that the plaintext is already of the correct length.
if len(plaintext)%aes.BlockSize != 0 {
panic("plaintext is not a multiple of the block size")
}
block, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
ciphertext := make([]byte, len(plaintext))
mode := cipher.NewECBEncrypter(block)
mode.CryptBlocks(ciphertext, plaintext)
// It's important to remember that ciphertexts must be authenticated
// (i.e. by using crypto/hmac) as well as being encrypted in order to
// be secure.
fmt.Printf("%x\n", ciphertext)
}
func ExampleNewECBDecrypter() {
// Load your secret key from a safe place and reuse it across multiple
// NewCipher calls. (Obviously don't use this example key for anything
// real.) If you want to convert a passphrase to a key, use a suitable
// package like bcrypt or scrypt.
key, _ := hex.DecodeString("6368616e676520746869732070617373")
ciphertext, _ := hex.DecodeString("f42512e1e4039213bd449ba47faa1b74f42512e1e4039213bd449ba47faa1b74")
block, err := aes.NewCipher(key)
if err != nil {
panic(err)
}
// ECB mode always works in whole blocks.
if len(ciphertext)%aes.BlockSize != 0 {
panic("ciphertext is not a multiple of the block size")
}
mode := cipher.NewECBDecrypter(block)
// CryptBlocks can work in-place if the two arguments are the same.
mode.CryptBlocks(ciphertext, ciphertext)
// If the original plaintext lengths are not a multiple of the block
// size, padding would have to be added when encrypting, which would be
// removed at this point. For an example, see
// https://tools.ietf.org/html/rfc5246#section-6.2.3.2. However, it's
// critical to note that ciphertexts must be authenticated (i.e. by
// using crypto/hmac) before being decrypted in order to avoid creating
// a padding oracle.
fmt.Printf("%s\n", ciphertext)
// Output: exampleplaintextexampleplaintext
}

View File

@ -34,7 +34,7 @@ func (c *sm4CipherAsm) NewCTR(iv []byte) cipher.Stream {
}
s := &ctr{
b: c,
ctr: make([]byte, c.batchBlocks*len(iv)),
ctr: make([]byte, c.blocksSize),
out: make([]byte, 0, bufSize),
outUsed: 0,
}

82
sm4/ecb_cipher_asm.go Normal file
View File

@ -0,0 +1,82 @@
//go:build (amd64 && !purego) || (arm64 && !purego)
// +build amd64,!purego arm64,!purego
package sm4
import (
"crypto/cipher"
"github.com/emmansun/gmsm/internal/alias"
)
// Assert that sm4CipherAsm implements the ecbEncAble and ecbDecAble interfaces.
var _ ecbEncAble = (*sm4CipherAsm)(nil)
var _ ecbDecAble = (*sm4CipherAsm)(nil)
const ecbEncrypt = 1
const ecbDecrypt = 0
type ecb struct {
b *sm4CipherAsm
enc int
}
func (x *ecb) validate(dst, src []byte) {
if len(src)%BlockSize != 0 {
panic("cipher: input not full blocks")
}
if len(dst) < len(src) {
panic("cipher: output smaller than input")
}
if alias.InexactOverlap(dst[:len(src)], src) {
panic("cipher: invalid buffer overlap")
}
}
func (b *sm4CipherAsm) NewECBEncrypter() cipher.BlockMode {
var c ecb
c.b = b
c.enc = ecbEncrypt
return &c
}
func (b *sm4CipherAsm) NewECBDecrypter() cipher.BlockMode {
var c ecb
c.b = b
c.enc = ecbDecrypt
return &c
}
func (x *ecb) BlockSize() int { return BlockSize }
func (x *ecb) CryptBlocks(dst, src []byte) {
x.validate(dst, src)
if len(src) == 0 {
return
}
for len(src) >= x.b.blocksSize {
if x.enc == ecbEncrypt {
x.b.EncryptBlocks(dst[:x.b.blocksSize], src[:x.b.blocksSize])
} else {
x.b.DecryptBlocks(dst[:x.b.blocksSize], src[:x.b.blocksSize])
}
src = src[x.b.blocksSize:]
dst = dst[x.b.blocksSize:]
}
if len(src) > BlockSize {
temp := make([]byte, x.b.blocksSize)
copy(temp, src)
if x.enc == ecbEncrypt {
x.b.EncryptBlocks(temp, temp)
} else {
x.b.DecryptBlocks(temp, temp)
}
copy(dst, temp[:len(src)])
} else if len(src) > 0 {
if x.enc == ecbEncrypt {
x.b.Encrypt(dst, src)
} else {
x.b.Decrypt(dst, src)
}
}
}

View File

@ -0,0 +1,33 @@
package sm4
import (
"testing"
"github.com/emmansun/gmsm/cipher"
)
func TestECBValidate(t *testing.T) {
key := make([]byte, 16)
src := make([]byte, 32)
c, err := NewCipher(key)
if err != nil {
t.Fatal(err)
}
decrypter := cipher.NewECBDecrypter(c)
// test len(src) == 0
decrypter.CryptBlocks(nil, nil)
// cipher: input not full blocks
shouldPanic(t, func() {
decrypter.CryptBlocks(src, src[1:])
})
// cipher: output smaller than input
shouldPanic(t, func() {
decrypter.CryptBlocks(src[1:], src)
})
// cipher: invalid buffer overlap
shouldPanic(t, func() {
decrypter.CryptBlocks(src[1:17], src[2:18])
})
}

View File

@ -2,6 +2,20 @@ package sm4
import "crypto/cipher"
// ecbcEncAble is implemented by cipher.Blocks that can provide an optimized
// implementation of ECB encryption through the cipher.BlockMode interface.
// See crypto/ecb.go.
type ecbEncAble interface {
NewECBEncrypter() cipher.BlockMode
}
// ecbDecAble is implemented by cipher.Blocks that can provide an optimized
// implementation of ECB decryption through the cipher.BlockMode interface.
// See crypto/ecb.go.
type ecbDecAble interface {
NewECBDecrypter() cipher.BlockMode
}
// cbcEncAble is implemented by cipher.Blocks that can provide an optimized
// implementation of CBC encryption through the cipher.BlockMode interface.
// See crypto/cipher/cbc.go.

View File

@ -3,7 +3,7 @@
2.Sign/Verify
3.Key Exchange
4.Wrap/Unwrap Key
5.Encryption/Decryption (XOR mode)
5.Encryption/Decryption
## SM9 current performance:

View File

@ -135,7 +135,7 @@ func ExampleEncryptPrivateKey_Decrypt() {
}
uid := []byte("Bob")
cipherDer, _ := hex.DecodeString("307f020100034200042cb3e90b0977211597652f26ee4abbe275ccb18dd7f431876ab5d40cc2fc563d9417791c75bc8909336a4e6562450836cc863f51002e31ecf0c4aae8d98641070420638ca5bfb35d25cff7cbd684f3ed75f2d919da86a921a2e3e2e2f4cbcf583f240414b7e776811774722a8720752fb1355ce45dc3d0df")
plaintext, err := userKey.Decrypt(uid, cipherDer)
plaintext, err := userKey.Decrypt(uid, cipherDer, nil)
if err != nil {
fmt.Fprintf(os.Stderr, "Error from Decrypt: %s\n", err)
return
@ -156,7 +156,7 @@ func ExampleEncryptMasterPublicKey_Encrypt() {
hid := byte(0x03)
uid := []byte("Bob")
ciphertext, err := masterPubKey.Encrypt(rand.Reader, uid, hid, []byte("Chinese IBE standard"))
ciphertext, err := masterPubKey.Encrypt(rand.Reader, uid, hid, []byte("Chinese IBE standard"), sm9.DefaultEncrypterOpts)
if err != nil {
fmt.Fprintf(os.Stderr, "Error from Encrypt: %s\n", err)
return

View File

@ -3,6 +3,7 @@ package sm9
import (
"crypto"
"crypto/cipher"
goSubtle "crypto/subtle"
"encoding/binary"
"errors"
@ -10,10 +11,13 @@ import (
"io"
"math/big"
_cipher "github.com/emmansun/gmsm/cipher"
"github.com/emmansun/gmsm/internal/bigmod"
"github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/padding"
"github.com/emmansun/gmsm/sm3"
"github.com/emmansun/gmsm/sm4"
"github.com/emmansun/gmsm/sm9/bn256"
"golang.org/x/crypto/cryptobyte"
"golang.org/x/crypto/cryptobyte/asn1"
@ -267,7 +271,7 @@ func (pub *SignMasterPublicKey) Verify(uid []byte, hid byte, hash, sig []byte) b
return VerifyASN1(pub, uid, hid, hash, sig)
}
// WrapKey generate and wrap key with reciever's uid and system hid
// WrapKey generates and wraps key with reciever's uid and system hid, returns generated key and cipher.
func WrapKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, kLen int) (key []byte, cipher *bn256.G1, err error) {
q := pub.GenerateUserPublicKey(uid, hid)
var (
@ -392,78 +396,269 @@ func (priv *EncryptPrivateKey) UnwrapKey(uid, cipherDer []byte, kLen int) ([]byt
return UnwrapKey(priv, uid, g, kLen)
}
// Encrypt encrypt plaintext, output ciphertext with format C1||C3||C2
func Encrypt(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte) ([]byte, error) {
key, cipher, err := WrapKey(rand, pub, uid, hid, len(plaintext)+sm3.Size)
type CipherFactory func(key []byte) (cipher.Block, error)
// EncrypterOpts indicate encrypt/decrypt options
type EncrypterOpts struct {
EncryptType encryptType
Padding padding.Padding
CipherFactory CipherFactory
CipherKeySize int
}
type DecrypterOpts EncrypterOpts
func (opts *EncrypterOpts) getKeySize(plaintext []byte) int {
if opts.EncryptType == ENC_TYPE_XOR {
return len(plaintext)
}
return opts.CipherKeySize
}
// NewEncrypterOpts creates EncrypterOpts with given parameters
func NewEncrypterOpts(encType encryptType, padMode padding.Padding, factory CipherFactory, cipherKeySize int) *EncrypterOpts {
opts := new(EncrypterOpts)
opts.EncryptType = encType
opts.Padding = padMode
opts.CipherFactory = factory
opts.CipherKeySize = cipherKeySize
return opts
}
// DefaultEncrypterOpts default option represents XOR mode
var DefaultEncrypterOpts = NewEncrypterOpts(ENC_TYPE_XOR, nil, nil, 0)
// SM4ECBEncrypterOpts option represents SM4 ECB mode
var SM4ECBEncrypterOpts = NewEncrypterOpts(ENC_TYPE_ECB, padding.NewPKCS7Padding(sm4.BlockSize), sm4.NewCipher, sm4.BlockSize)
// SM4CBCEncrypterOpts option represents SM4 CBC mode
var SM4CBCEncrypterOpts = NewEncrypterOpts(ENC_TYPE_CBC, padding.NewPKCS7Padding(sm4.BlockSize), sm4.NewCipher, sm4.BlockSize)
// SM4CFBEncrypterOpts option represents SM4 CFB mode
var SM4CFBEncrypterOpts = NewEncrypterOpts(ENC_TYPE_CFB, padding.NewPKCS7Padding(sm4.BlockSize), sm4.NewCipher, sm4.BlockSize)
// SM4OFBEncrypterOpts option represents SM4 OFB mode
var SM4OFBEncrypterOpts = NewEncrypterOpts(ENC_TYPE_OFB, padding.NewPKCS7Padding(sm4.BlockSize), sm4.NewCipher, sm4.BlockSize)
// Encrypt encrypt plaintext, output ciphertext with format C1||C3||C2.
func Encrypt(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte, opts *EncrypterOpts) ([]byte, error) {
c1, c2, c3, err := encrypt(rand, pub, uid, hid, plaintext, opts)
if err != nil {
return nil, err
}
subtle.XORBytes(key, key[:len(plaintext)], plaintext)
ciphertext := append(c1.Marshal(), c3...)
ciphertext = append(ciphertext, c2...)
return ciphertext, nil
}
func encrypt(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte, opts *EncrypterOpts) (c1 *bn256.G1, c2, c3 []byte, err error) {
if opts == nil {
opts = DefaultEncrypterOpts
}
key1Len := opts.getKeySize(plaintext)
key, c1, err := WrapKey(rand, pub, uid, hid, key1Len+sm3.Size)
if err != nil {
return nil, nil, nil, err
}
c2, err = encryptPlaintext(rand, key[:key1Len], plaintext, opts)
if err != nil {
return nil, nil, nil, err
}
hash := sm3.New()
hash.Write(key)
c3 := hash.Sum(nil)
hash.Write(c2)
hash.Write(key[key1Len:])
c3 = hash.Sum(nil)
ciphertext := append(cipher.Marshal(), c3...)
ciphertext = append(ciphertext, key[:len(plaintext)]...)
return ciphertext, nil
return
}
func encryptPlaintext(rand io.Reader, key, plaintext []byte, opts *EncrypterOpts) ([]byte, error) {
switch opts.EncryptType {
case ENC_TYPE_XOR:
subtle.XORBytes(key, key, plaintext)
return key, nil
case ENC_TYPE_ECB:
block, err := opts.CipherFactory(key)
if err != nil {
return nil, err
}
paddedPlainText := opts.Padding.Pad(plaintext)
ciphertext := make([]byte, len(paddedPlainText))
mode := _cipher.NewECBEncrypter(block)
mode.CryptBlocks(ciphertext, paddedPlainText)
return ciphertext, nil
case ENC_TYPE_CBC:
block, err := opts.CipherFactory(key)
if err != nil {
return nil, err
}
paddedPlainText := opts.Padding.Pad(plaintext)
blockSize := block.BlockSize()
ciphertext := make([]byte, blockSize+len(paddedPlainText))
iv := ciphertext[:blockSize]
if _, err := io.ReadFull(rand, iv); err != nil {
return nil, err
}
mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(ciphertext[blockSize:], paddedPlainText)
return ciphertext, nil
case ENC_TYPE_CFB:
block, err := opts.CipherFactory(key)
if err != nil {
return nil, err
}
blockSize := block.BlockSize()
ciphertext := make([]byte, blockSize+len(plaintext))
iv := ciphertext[:blockSize]
if _, err := io.ReadFull(rand, iv); err != nil {
return nil, err
}
stream := cipher.NewCFBEncrypter(block, iv)
stream.XORKeyStream(ciphertext[blockSize:], plaintext)
return ciphertext, nil
case ENC_TYPE_OFB:
block, err := opts.CipherFactory(key)
if err != nil {
return nil, err
}
blockSize := block.BlockSize()
ciphertext := make([]byte, blockSize+len(plaintext))
iv := ciphertext[:blockSize]
if _, err := io.ReadFull(rand, iv); err != nil {
return nil, err
}
stream := cipher.NewOFB(block, iv)
stream.XORKeyStream(ciphertext[blockSize:], plaintext)
return ciphertext, nil
}
return nil, fmt.Errorf("sm9: unsupported encryption type <%v>", opts.EncryptType)
}
// EncryptASN1 encrypt plaintext and output ciphertext with ASN.1 format according
// SM9 cryptographic algorithm application specification, SM9Cipher definition.
func EncryptASN1(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte) ([]byte, error) {
return pub.Encrypt(rand, uid, hid, plaintext)
func EncryptASN1(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte, opts *EncrypterOpts) ([]byte, error) {
return pub.Encrypt(rand, uid, hid, plaintext, opts)
}
// Encrypt encrypt plaintext and output ciphertext with ASN.1 format according
// SM9 cryptographic algorithm application specification, SM9Cipher definition.
func (pub *EncryptMasterPublicKey) Encrypt(rand io.Reader, uid []byte, hid byte, plaintext []byte) ([]byte, error) {
key, cipher, err := WrapKey(rand, pub, uid, hid, len(plaintext)+sm3.Size)
func (pub *EncryptMasterPublicKey) Encrypt(rand io.Reader, uid []byte, hid byte, plaintext []byte, opts *EncrypterOpts) ([]byte, error) {
if opts == nil {
opts = DefaultEncrypterOpts
}
c1, c2, c3, err := encrypt(rand, pub, uid, hid, plaintext, opts)
if err != nil {
return nil, err
}
subtle.XORBytes(key, key[:len(plaintext)], plaintext)
hash := sm3.New()
hash.Write(key)
c3 := hash.Sum(nil)
var b cryptobyte.Builder
b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) {
b.AddASN1Int64(int64(ENC_TYPE_XOR))
b.AddASN1BitString(cipher.MarshalUncompressed())
b.AddASN1Int64(int64(opts.EncryptType))
b.AddASN1BitString(c1.MarshalUncompressed())
b.AddASN1OctetString(c3)
b.AddASN1OctetString(key[:len(plaintext)])
b.AddASN1OctetString(c2)
})
return b.Bytes()
}
// Decrypt decrypt chipher, ciphertext should be with format C1||C3||C2
func Decrypt(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error) {
func Decrypt(priv *EncryptPrivateKey, uid, ciphertext []byte, opts *EncrypterOpts) ([]byte, error) {
if opts == nil {
opts = DefaultEncrypterOpts
}
c := &bn256.G1{}
c3, err := c.Unmarshal(ciphertext)
c3c2, err := c.Unmarshal(ciphertext)
if err != nil {
return nil, ErrDecryption
}
key, err := UnwrapKey(priv, uid, c, len(c3))
c2 := c3c2[sm3.Size:]
key1Len := opts.getKeySize(c2)
key, err := UnwrapKey(priv, uid, c, key1Len+sm3.Size)
if err != nil {
return nil, err
}
c2 := c3[sm3.Size:]
return decrypt(c, key[:key1Len], key[key1Len:], c2, c3c2[:sm3.Size], opts)
}
func decrypt(cipher *bn256.G1, key1, key2, c2, c3 []byte, opts *EncrypterOpts) ([]byte, error) {
hash := sm3.New()
hash.Write(c2)
hash.Write(key[len(c2):])
hash.Write(key2)
c32 := hash.Sum(nil)
if goSubtle.ConstantTimeCompare(c3[:sm3.Size], c32) != 1 {
if goSubtle.ConstantTimeCompare(c3, c32) != 1 {
return nil, ErrDecryption
}
subtle.XORBytes(key, c2, key[:len(c2)])
return key[:len(c2)], nil
return decryptCiphertext(key1, c2, opts)
}
func decryptCiphertext(key, ciphertext []byte, opts *EncrypterOpts) ([]byte, error) {
switch opts.EncryptType {
case ENC_TYPE_XOR:
subtle.XORBytes(key, ciphertext, key)
return key, nil
case ENC_TYPE_ECB:
block, err := opts.CipherFactory(key)
if err != nil {
return nil, err
}
plaintext := make([]byte, len(ciphertext))
mode := _cipher.NewECBDecrypter(block)
mode.CryptBlocks(plaintext, ciphertext)
return opts.Padding.Unpad(plaintext)
case ENC_TYPE_CBC:
block, err := opts.CipherFactory(key)
if err != nil {
return nil, err
}
blockSize := block.BlockSize()
if len(ciphertext) < blockSize {
return nil, ErrDecryption
}
iv := ciphertext[:blockSize]
ciphertext = ciphertext[blockSize:]
plaintext := make([]byte, len(ciphertext))
mode := cipher.NewCBCDecrypter(block, iv)
mode.CryptBlocks(plaintext, ciphertext)
return opts.Padding.Unpad(plaintext)
case ENC_TYPE_CFB:
block, err := opts.CipherFactory(key)
if err != nil {
return nil, err
}
blockSize := block.BlockSize()
if len(ciphertext) < blockSize {
return nil, ErrDecryption
}
iv := ciphertext[:blockSize]
ciphertext = ciphertext[blockSize:]
plaintext := make([]byte, len(ciphertext))
stream := cipher.NewCFBDecrypter(block, iv)
stream.XORKeyStream(plaintext, ciphertext)
return plaintext, nil
case ENC_TYPE_OFB:
block, err := opts.CipherFactory(key)
if err != nil {
return nil, err
}
blockSize := block.BlockSize()
if len(ciphertext) < blockSize {
return nil, ErrDecryption
}
iv := ciphertext[:blockSize]
ciphertext = ciphertext[blockSize:]
plaintext := make([]byte, len(ciphertext))
stream := cipher.NewOFB(block, iv)
stream.XORKeyStream(plaintext, ciphertext)
return plaintext, nil
}
return nil, fmt.Errorf("sm9: unsupported encryption type <%v>", opts.EncryptType)
}
// DecryptASN1 decrypt chipher, ciphertext should be with ASN.1 format according
@ -489,39 +684,30 @@ func DecryptASN1(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error
!inner.Empty() {
return nil, errors.New("sm9: invalid ciphertext asn.1 data")
}
if encType != int(ENC_TYPE_XOR) {
return nil, fmt.Errorf("sm9: does not support this kind of encrypt type <%v> yet", encType)
}
// We just make assumption block cipher is SM4 and padding scheme is pkcs7
opts := NewEncrypterOpts(encryptType(encType), padding.NewPKCS7Padding(sm4.BlockSize), sm4.NewCipher, sm4.BlockSize)
c, err := unmarshalG1(c1Bytes)
if err != nil {
return nil, ErrDecryption
}
key, err := UnwrapKey(priv, uid, c, len(c2Bytes)+len(c3Bytes))
key1Len := opts.getKeySize(c2Bytes)
key, err := UnwrapKey(priv, uid, c, key1Len+sm3.Size)
if err != nil {
return nil, err
}
hash := sm3.New()
hash.Write(c2Bytes)
hash.Write(key[len(c2Bytes):])
c32 := hash.Sum(nil)
if goSubtle.ConstantTimeCompare(c3Bytes, c32) != 1 {
return nil, ErrDecryption
}
subtle.XORBytes(key, c2Bytes, key[:len(c2Bytes)])
return key[:len(c2Bytes)], nil
return decrypt(c, key[:key1Len], key[key1Len:], c2Bytes, c3Bytes, opts)
}
// Decrypt decrypt chipher, ciphertext should be with ASN.1 format according
// SM9 cryptographic algorithm application specification, SM9Cipher definition.
func (priv *EncryptPrivateKey) Decrypt(uid, ciphertext []byte) ([]byte, error) {
func (priv *EncryptPrivateKey) Decrypt(uid, ciphertext []byte, opts *EncrypterOpts) ([]byte, error) {
if ciphertext[0] == 0x30 { // should be ASN.1 format
return DecryptASN1(priv, uid, ciphertext)
}
// fallback to C1||C3||C2 raw format
return Decrypt(priv, uid, ciphertext)
return Decrypt(priv, uid, ciphertext, opts)
}
// KeyExchange key exchange struct, include internal stat in whole key exchange flow.

View File

@ -659,26 +659,31 @@ func TestEncryptDecrypt(t *testing.T) {
if err != nil {
t.Fatal(err)
}
cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext)
if err != nil {
t.Fatal(err)
encTypes := []*EncrypterOpts{
DefaultEncrypterOpts, SM4ECBEncrypterOpts, SM4CBCEncrypterOpts, SM4CFBEncrypterOpts, SM4OFBEncrypterOpts,
}
for _, opts := range encTypes {
cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext, opts)
if err != nil {
t.Fatal(err)
}
got, err := Decrypt(userKey, uid, cipher)
if err != nil {
t.Fatal(err)
}
if string(got) != string(plaintext) {
t.Errorf("expected %v, got %v\n", string(plaintext), string(got))
}
got, err := Decrypt(userKey, uid, cipher, opts)
if err != nil {
t.Fatal(err)
}
if string(got) != string(plaintext) {
t.Errorf("expected %v, got %v\n", string(plaintext), string(got))
}
got, err = userKey.Decrypt(uid, cipher)
if err != nil {
t.Fatal(err)
}
got, err = userKey.Decrypt(uid, cipher, opts)
if err != nil {
t.Fatal(err)
}
if string(got) != string(plaintext) {
t.Errorf("expected %v, got %v\n", string(plaintext), string(got))
if string(got) != string(plaintext) {
t.Errorf("expected %v, got %v\n", string(plaintext), string(got))
}
}
}
@ -694,27 +699,32 @@ func TestEncryptDecryptASN1(t *testing.T) {
if err != nil {
t.Fatal(err)
}
cipher, err := EncryptASN1(rand.Reader, masterKey.Public(), uid, hid, plaintext)
if err != nil {
t.Fatal(err)
encTypes := []*EncrypterOpts{
DefaultEncrypterOpts, SM4ECBEncrypterOpts, SM4CBCEncrypterOpts, SM4CFBEncrypterOpts, SM4OFBEncrypterOpts,
}
for _, opts := range encTypes {
cipher, err := EncryptASN1(rand.Reader, masterKey.Public(), uid, hid, plaintext, opts)
if err != nil {
t.Fatal(err)
}
got, err := DecryptASN1(userKey, uid, cipher)
if err != nil {
t.Fatal(err)
}
got, err := DecryptASN1(userKey, uid, cipher)
if err != nil {
t.Fatal(err)
}
if string(got) != string(plaintext) {
t.Errorf("expected %v, got %v\n", string(plaintext), string(got))
}
if string(got) != string(plaintext) {
t.Errorf("expected %v, got %v\n", string(plaintext), string(got))
}
got, err = userKey.Decrypt(uid, cipher)
if err != nil {
t.Fatal(err)
}
got, err = userKey.Decrypt(uid, cipher, opts)
if err != nil {
t.Fatal(err)
}
if string(got) != string(plaintext) {
t.Errorf("expected %v, got %v\n", string(plaintext), string(got))
if string(got) != string(plaintext) {
t.Errorf("expected %v, got %v\n", string(plaintext), string(got))
}
}
}
@ -782,7 +792,7 @@ func BenchmarkEncrypt(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext)
cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext, nil)
if err != nil {
b.Fatal(err)
}
@ -803,14 +813,14 @@ func BenchmarkDecrypt(b *testing.B) {
if err != nil {
b.Fatal(err)
}
cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext)
cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext, nil)
if err != nil {
b.Fatal(err)
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
got, err := Decrypt(userKey, uid, cipher)
got, err := Decrypt(userKey, uid, cipher, nil)
if err != nil {
b.Fatal(err)
}
@ -832,7 +842,7 @@ func BenchmarkDecryptASN1(b *testing.B) {
if err != nil {
b.Fatal(err)
}
cipher, err := EncryptASN1(rand.Reader, masterKey.Public(), uid, hid, plaintext)
cipher, err := EncryptASN1(rand.Reader, masterKey.Public(), uid, hid, plaintext, nil)
if err != nil {
b.Fatal(err)
}