diff --git a/cipher/ctr_sm4_test.go b/cipher/ctr_sm4_test.go index 691733f..a702558 100644 --- a/cipher/ctr_sm4_test.go +++ b/cipher/ctr_sm4_test.go @@ -8,8 +8,6 @@ import ( "github.com/emmansun/gmsm/sm4" ) -var commonCounter = []byte{0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff} - var ctrSM4Tests = []struct { name string key []byte diff --git a/cipher/ecb.go b/cipher/ecb.go new file mode 100644 index 0000000..45c3b70 --- /dev/null +++ b/cipher/ecb.go @@ -0,0 +1,99 @@ +// Electronic Code Book (ECB) mode. + +// Please do NOT use this mode alone. +package cipher + +import ( + goCipher "crypto/cipher" + + "github.com/emmansun/gmsm/internal/alias" +) + +type ecb struct { + b goCipher.Block + blockSize int +} + +func newECB(b goCipher.Block) *ecb { + return &ecb{ + b: b, + blockSize: b.BlockSize(), + } +} + +func (x *ecb) validate(dst, src []byte) { + if len(src)%x.blockSize != 0 { + panic("cipher: input not full blocks") + } + if len(dst) < len(src) { + panic("cipher: output smaller than input") + } + if alias.InexactOverlap(dst[:len(src)], src) { + panic("cipher: invalid buffer overlap") + } +} + +type ecbEncrypter ecb + +// ecbEncAble is an interface implemented by ciphers that have a specific +// optimized implementation of ECB encryption, like sm4. +// NewECBEncrypter will check for this interface and return the specific +// BlockMode if found. +type ecbEncAble interface { + NewECBEncrypter() goCipher.BlockMode +} + +// NewECBEncrypter returns a BlockMode which encrypts in electronic code book +// mode, using the given Block. +func NewECBEncrypter(b goCipher.Block) goCipher.BlockMode { + if ecb, ok := b.(ecbEncAble); ok { + return ecb.NewECBEncrypter() + } + return (*ecbEncrypter)(newECB(b)) +} + +func (x *ecbEncrypter) BlockSize() int { return x.blockSize } + +func (x *ecbEncrypter) CryptBlocks(dst, src []byte) { + (*ecb)(x).validate(dst, src) + + for len(src) > 0 { + x.b.Encrypt(dst[:x.blockSize], src[:x.blockSize]) + src = src[x.blockSize:] + dst = dst[x.blockSize:] + } +} + +type ecbDecrypter ecb + +// ecbDecAble is an interface implemented by ciphers that have a specific +// optimized implementation of ECB decryption, like sm4. +// NewECBDecrypter will check for this interface and return the specific +// BlockMode if found. +type ecbDecAble interface { + NewECBDecrypter() goCipher.BlockMode +} + +// NewECBDecrypter returns a BlockMode which decrypts in electronic code book +// mode, using the given Block. +func NewECBDecrypter(b goCipher.Block) goCipher.BlockMode { + if ecb, ok := b.(ecbDecAble); ok { + return ecb.NewECBDecrypter() + } + return (*ecbDecrypter)(newECB(b)) +} + +func (x *ecbDecrypter) BlockSize() int { return x.blockSize } + +func (x *ecbDecrypter) CryptBlocks(dst, src []byte) { + (*ecb)(x).validate(dst, src) + + if len(src) == 0 { + return + } + for len(src) > 0 { + x.b.Decrypt(dst[:x.blockSize], src[:x.blockSize]) + src = src[x.blockSize:] + dst = dst[x.blockSize:] + } +} diff --git a/cipher/ecb_sm4_test.go b/cipher/ecb_sm4_test.go new file mode 100644 index 0000000..e898a3c --- /dev/null +++ b/cipher/ecb_sm4_test.go @@ -0,0 +1,119 @@ +package cipher_test + +import ( + "bytes" + "testing" + + "github.com/emmansun/gmsm/cipher" + "github.com/emmansun/gmsm/sm4" +) + +var ecbSM4Tests = []struct { + name string + key []byte + in []byte +}{ + { + "1 block", + []byte("0123456789ABCDEF"), + []byte("exampleplaintext"), + }, + { + "2 same blocks", + []byte("0123456789ABCDEF"), + []byte("exampleplaintextexampleplaintext"), + }, + { + "2 different blocks", + []byte("0123456789ABCDEF"), + []byte("exampleplaintextfedcba9876543210"), + }, + { + "3 same blocks", + []byte("0123456789ABCDEF"), + []byte("exampleplaintextexampleplaintextexampleplaintext"), + }, + { + "4 same blocks", + []byte("0123456789ABCDEF"), + []byte("exampleplaintextexampleplaintextexampleplaintextexampleplaintext"), + }, + { + "5 same blocks", + []byte("0123456789ABCDEF"), + []byte("exampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintext"), + }, + { + "6 same blocks", + []byte("0123456789ABCDEF"), + []byte("exampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintext"), + }, + { + "7 same blocks", + []byte("0123456789ABCDEF"), + []byte("exampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintext"), + }, + { + "8 same blocks", + []byte("0123456789ABCDEF"), + []byte("exampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintext"), + }, + { + "9 same blocks", + []byte("0123456789ABCDEF"), + []byte("exampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintextexampleplaintext"), + }, +} + +func TestECBBasic(t *testing.T) { + for _, test := range ecbSM4Tests { + c, err := sm4.NewCipher(test.key) + if err != nil { + t.Errorf("%s: NewCipher(%d bytes) = %s", test.name, len(test.key), err) + continue + } + encrypter := cipher.NewECBEncrypter(c) + ciphertext := make([]byte, len(test.in)) + encrypter.CryptBlocks(ciphertext, test.in) + + plaintext := make([]byte, len(test.in)) + decrypter := cipher.NewECBDecrypter(c) + decrypter.CryptBlocks(plaintext, ciphertext) + if !bytes.Equal(test.in, plaintext) { + t.Errorf("%s: ECB encrypt/decrypt failed", test.name) + } + } +} + +func shouldPanic(t *testing.T, f func()) { + t.Helper() + defer func() { _ = recover() }() + f() + t.Errorf("should have panicked") +} + +func TestECBValidate(t *testing.T) { + key := make([]byte, 16) + src := make([]byte, 32) + c, err := sm4.NewCipher(key) + if err != nil { + t.Fatal(err) + } + + decrypter := cipher.NewECBDecrypter(c) + // test len(src) == 0 + decrypter.CryptBlocks(nil, nil) + + // cipher: input not full blocks + shouldPanic(t, func() { + decrypter.CryptBlocks(src, src[1:]) + }) + // cipher: output smaller than input + shouldPanic(t, func() { + decrypter.CryptBlocks(src[1:], src) + }) + // cipher: invalid buffer overlap + shouldPanic(t, func() { + decrypter.CryptBlocks(src[1:17], src[2:18]) + }) +} diff --git a/cipher/example_test.go b/cipher/example_test.go new file mode 100644 index 0000000..a817ab6 --- /dev/null +++ b/cipher/example_test.go @@ -0,0 +1,76 @@ +package cipher_test + +import ( + "crypto/aes" + "encoding/hex" + "fmt" + + "github.com/emmansun/gmsm/cipher" +) + +func ExampleNewECBEncrypter() { + // Load your secret key from a safe place and reuse it across multiple + // NewCipher calls. (Obviously don't use this example key for anything + // real.) If you want to convert a passphrase to a key, use a suitable + // package like bcrypt or scrypt. + key, _ := hex.DecodeString("6368616e676520746869732070617373") + plaintext := []byte("exampleplaintextexampleplaintext") + + // ECB mode works on blocks so plaintexts may need to be padded to the + // next whole block. For an example of such padding, see + // https://tools.ietf.org/html/rfc5246#section-6.2.3.2. Here we'll + // assume that the plaintext is already of the correct length. + if len(plaintext)%aes.BlockSize != 0 { + panic("plaintext is not a multiple of the block size") + } + + block, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + + ciphertext := make([]byte, len(plaintext)) + mode := cipher.NewECBEncrypter(block) + mode.CryptBlocks(ciphertext, plaintext) + + // It's important to remember that ciphertexts must be authenticated + // (i.e. by using crypto/hmac) as well as being encrypted in order to + // be secure. + + fmt.Printf("%x\n", ciphertext) +} + +func ExampleNewECBDecrypter() { + // Load your secret key from a safe place and reuse it across multiple + // NewCipher calls. (Obviously don't use this example key for anything + // real.) If you want to convert a passphrase to a key, use a suitable + // package like bcrypt or scrypt. + key, _ := hex.DecodeString("6368616e676520746869732070617373") + ciphertext, _ := hex.DecodeString("f42512e1e4039213bd449ba47faa1b74f42512e1e4039213bd449ba47faa1b74") + + block, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + + // ECB mode always works in whole blocks. + if len(ciphertext)%aes.BlockSize != 0 { + panic("ciphertext is not a multiple of the block size") + } + + mode := cipher.NewECBDecrypter(block) + + // CryptBlocks can work in-place if the two arguments are the same. + mode.CryptBlocks(ciphertext, ciphertext) + + // If the original plaintext lengths are not a multiple of the block + // size, padding would have to be added when encrypting, which would be + // removed at this point. For an example, see + // https://tools.ietf.org/html/rfc5246#section-6.2.3.2. However, it's + // critical to note that ciphertexts must be authenticated (i.e. by + // using crypto/hmac) before being decrypted in order to avoid creating + // a padding oracle. + + fmt.Printf("%s\n", ciphertext) + // Output: exampleplaintextexampleplaintext +} diff --git a/sm4/ctr_cipher_asm.go b/sm4/ctr_cipher_asm.go index 48cd965..333d3ae 100644 --- a/sm4/ctr_cipher_asm.go +++ b/sm4/ctr_cipher_asm.go @@ -34,7 +34,7 @@ func (c *sm4CipherAsm) NewCTR(iv []byte) cipher.Stream { } s := &ctr{ b: c, - ctr: make([]byte, c.batchBlocks*len(iv)), + ctr: make([]byte, c.blocksSize), out: make([]byte, 0, bufSize), outUsed: 0, } diff --git a/sm4/ecb_cipher_asm.go b/sm4/ecb_cipher_asm.go new file mode 100644 index 0000000..d4975e5 --- /dev/null +++ b/sm4/ecb_cipher_asm.go @@ -0,0 +1,82 @@ +//go:build (amd64 && !purego) || (arm64 && !purego) +// +build amd64,!purego arm64,!purego + +package sm4 + +import ( + "crypto/cipher" + + "github.com/emmansun/gmsm/internal/alias" +) + +// Assert that sm4CipherAsm implements the ecbEncAble and ecbDecAble interfaces. +var _ ecbEncAble = (*sm4CipherAsm)(nil) +var _ ecbDecAble = (*sm4CipherAsm)(nil) + +const ecbEncrypt = 1 +const ecbDecrypt = 0 + +type ecb struct { + b *sm4CipherAsm + enc int +} + +func (x *ecb) validate(dst, src []byte) { + if len(src)%BlockSize != 0 { + panic("cipher: input not full blocks") + } + if len(dst) < len(src) { + panic("cipher: output smaller than input") + } + if alias.InexactOverlap(dst[:len(src)], src) { + panic("cipher: invalid buffer overlap") + } +} + +func (b *sm4CipherAsm) NewECBEncrypter() cipher.BlockMode { + var c ecb + c.b = b + c.enc = ecbEncrypt + return &c +} + +func (b *sm4CipherAsm) NewECBDecrypter() cipher.BlockMode { + var c ecb + c.b = b + c.enc = ecbDecrypt + return &c +} + +func (x *ecb) BlockSize() int { return BlockSize } + +func (x *ecb) CryptBlocks(dst, src []byte) { + x.validate(dst, src) + if len(src) == 0 { + return + } + for len(src) >= x.b.blocksSize { + if x.enc == ecbEncrypt { + x.b.EncryptBlocks(dst[:x.b.blocksSize], src[:x.b.blocksSize]) + } else { + x.b.DecryptBlocks(dst[:x.b.blocksSize], src[:x.b.blocksSize]) + } + src = src[x.b.blocksSize:] + dst = dst[x.b.blocksSize:] + } + if len(src) > BlockSize { + temp := make([]byte, x.b.blocksSize) + copy(temp, src) + if x.enc == ecbEncrypt { + x.b.EncryptBlocks(temp, temp) + } else { + x.b.DecryptBlocks(temp, temp) + } + copy(dst, temp[:len(src)]) + } else if len(src) > 0 { + if x.enc == ecbEncrypt { + x.b.Encrypt(dst, src) + } else { + x.b.Decrypt(dst, src) + } + } +} diff --git a/sm4/ecb_cipher_asm_test.go b/sm4/ecb_cipher_asm_test.go new file mode 100644 index 0000000..130ba76 --- /dev/null +++ b/sm4/ecb_cipher_asm_test.go @@ -0,0 +1,33 @@ +package sm4 + +import ( + "testing" + + "github.com/emmansun/gmsm/cipher" +) + +func TestECBValidate(t *testing.T) { + key := make([]byte, 16) + src := make([]byte, 32) + c, err := NewCipher(key) + if err != nil { + t.Fatal(err) + } + + decrypter := cipher.NewECBDecrypter(c) + // test len(src) == 0 + decrypter.CryptBlocks(nil, nil) + + // cipher: input not full blocks + shouldPanic(t, func() { + decrypter.CryptBlocks(src, src[1:]) + }) + // cipher: output smaller than input + shouldPanic(t, func() { + decrypter.CryptBlocks(src[1:], src) + }) + // cipher: invalid buffer overlap + shouldPanic(t, func() { + decrypter.CryptBlocks(src[1:17], src[2:18]) + }) +} diff --git a/sm4/modes.go b/sm4/modes.go index 61e58e2..9993f69 100644 --- a/sm4/modes.go +++ b/sm4/modes.go @@ -2,6 +2,20 @@ package sm4 import "crypto/cipher" +// ecbcEncAble is implemented by cipher.Blocks that can provide an optimized +// implementation of ECB encryption through the cipher.BlockMode interface. +// See crypto/ecb.go. +type ecbEncAble interface { + NewECBEncrypter() cipher.BlockMode +} + +// ecbDecAble is implemented by cipher.Blocks that can provide an optimized +// implementation of ECB decryption through the cipher.BlockMode interface. +// See crypto/ecb.go. +type ecbDecAble interface { + NewECBDecrypter() cipher.BlockMode +} + // cbcEncAble is implemented by cipher.Blocks that can provide an optimized // implementation of CBC encryption through the cipher.BlockMode interface. // See crypto/cipher/cbc.go. diff --git a/sm9/README.md b/sm9/README.md index 6a80047..db6536b 100644 --- a/sm9/README.md +++ b/sm9/README.md @@ -3,7 +3,7 @@ 2.Sign/Verify 3.Key Exchange 4.Wrap/Unwrap Key -5.Encryption/Decryption (XOR mode) +5.Encryption/Decryption ## SM9 current performance: diff --git a/sm9/example_test.go b/sm9/example_test.go index b744280..880d5be 100644 --- a/sm9/example_test.go +++ b/sm9/example_test.go @@ -135,7 +135,7 @@ func ExampleEncryptPrivateKey_Decrypt() { } uid := []byte("Bob") cipherDer, _ := hex.DecodeString("307f020100034200042cb3e90b0977211597652f26ee4abbe275ccb18dd7f431876ab5d40cc2fc563d9417791c75bc8909336a4e6562450836cc863f51002e31ecf0c4aae8d98641070420638ca5bfb35d25cff7cbd684f3ed75f2d919da86a921a2e3e2e2f4cbcf583f240414b7e776811774722a8720752fb1355ce45dc3d0df") - plaintext, err := userKey.Decrypt(uid, cipherDer) + plaintext, err := userKey.Decrypt(uid, cipherDer, nil) if err != nil { fmt.Fprintf(os.Stderr, "Error from Decrypt: %s\n", err) return @@ -156,7 +156,7 @@ func ExampleEncryptMasterPublicKey_Encrypt() { hid := byte(0x03) uid := []byte("Bob") - ciphertext, err := masterPubKey.Encrypt(rand.Reader, uid, hid, []byte("Chinese IBE standard")) + ciphertext, err := masterPubKey.Encrypt(rand.Reader, uid, hid, []byte("Chinese IBE standard"), sm9.DefaultEncrypterOpts) if err != nil { fmt.Fprintf(os.Stderr, "Error from Encrypt: %s\n", err) return diff --git a/sm9/sm9.go b/sm9/sm9.go index 9530f76..9ed2043 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -3,6 +3,7 @@ package sm9 import ( "crypto" + "crypto/cipher" goSubtle "crypto/subtle" "encoding/binary" "errors" @@ -10,10 +11,13 @@ import ( "io" "math/big" + _cipher "github.com/emmansun/gmsm/cipher" "github.com/emmansun/gmsm/internal/bigmod" "github.com/emmansun/gmsm/internal/subtle" "github.com/emmansun/gmsm/kdf" + "github.com/emmansun/gmsm/padding" "github.com/emmansun/gmsm/sm3" + "github.com/emmansun/gmsm/sm4" "github.com/emmansun/gmsm/sm9/bn256" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" @@ -267,7 +271,7 @@ func (pub *SignMasterPublicKey) Verify(uid []byte, hid byte, hash, sig []byte) b return VerifyASN1(pub, uid, hid, hash, sig) } -// WrapKey generate and wrap key with reciever's uid and system hid +// WrapKey generates and wraps key with reciever's uid and system hid, returns generated key and cipher. func WrapKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, kLen int) (key []byte, cipher *bn256.G1, err error) { q := pub.GenerateUserPublicKey(uid, hid) var ( @@ -392,78 +396,269 @@ func (priv *EncryptPrivateKey) UnwrapKey(uid, cipherDer []byte, kLen int) ([]byt return UnwrapKey(priv, uid, g, kLen) } -// Encrypt encrypt plaintext, output ciphertext with format C1||C3||C2 -func Encrypt(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte) ([]byte, error) { - key, cipher, err := WrapKey(rand, pub, uid, hid, len(plaintext)+sm3.Size) +type CipherFactory func(key []byte) (cipher.Block, error) + +// EncrypterOpts indicate encrypt/decrypt options +type EncrypterOpts struct { + EncryptType encryptType + Padding padding.Padding + CipherFactory CipherFactory + CipherKeySize int +} + +type DecrypterOpts EncrypterOpts + +func (opts *EncrypterOpts) getKeySize(plaintext []byte) int { + if opts.EncryptType == ENC_TYPE_XOR { + return len(plaintext) + } + return opts.CipherKeySize +} + +// NewEncrypterOpts creates EncrypterOpts with given parameters +func NewEncrypterOpts(encType encryptType, padMode padding.Padding, factory CipherFactory, cipherKeySize int) *EncrypterOpts { + opts := new(EncrypterOpts) + opts.EncryptType = encType + opts.Padding = padMode + opts.CipherFactory = factory + opts.CipherKeySize = cipherKeySize + return opts +} + +// DefaultEncrypterOpts default option represents XOR mode +var DefaultEncrypterOpts = NewEncrypterOpts(ENC_TYPE_XOR, nil, nil, 0) + +// SM4ECBEncrypterOpts option represents SM4 ECB mode +var SM4ECBEncrypterOpts = NewEncrypterOpts(ENC_TYPE_ECB, padding.NewPKCS7Padding(sm4.BlockSize), sm4.NewCipher, sm4.BlockSize) + +// SM4CBCEncrypterOpts option represents SM4 CBC mode +var SM4CBCEncrypterOpts = NewEncrypterOpts(ENC_TYPE_CBC, padding.NewPKCS7Padding(sm4.BlockSize), sm4.NewCipher, sm4.BlockSize) + +// SM4CFBEncrypterOpts option represents SM4 CFB mode +var SM4CFBEncrypterOpts = NewEncrypterOpts(ENC_TYPE_CFB, padding.NewPKCS7Padding(sm4.BlockSize), sm4.NewCipher, sm4.BlockSize) + +// SM4OFBEncrypterOpts option represents SM4 OFB mode +var SM4OFBEncrypterOpts = NewEncrypterOpts(ENC_TYPE_OFB, padding.NewPKCS7Padding(sm4.BlockSize), sm4.NewCipher, sm4.BlockSize) + +// Encrypt encrypt plaintext, output ciphertext with format C1||C3||C2. +func Encrypt(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte, opts *EncrypterOpts) ([]byte, error) { + c1, c2, c3, err := encrypt(rand, pub, uid, hid, plaintext, opts) if err != nil { return nil, err } - subtle.XORBytes(key, key[:len(plaintext)], plaintext) + ciphertext := append(c1.Marshal(), c3...) + ciphertext = append(ciphertext, c2...) + return ciphertext, nil +} + +func encrypt(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte, opts *EncrypterOpts) (c1 *bn256.G1, c2, c3 []byte, err error) { + if opts == nil { + opts = DefaultEncrypterOpts + } + key1Len := opts.getKeySize(plaintext) + key, c1, err := WrapKey(rand, pub, uid, hid, key1Len+sm3.Size) + if err != nil { + return nil, nil, nil, err + } + c2, err = encryptPlaintext(rand, key[:key1Len], plaintext, opts) + if err != nil { + return nil, nil, nil, err + } hash := sm3.New() - hash.Write(key) - c3 := hash.Sum(nil) + hash.Write(c2) + hash.Write(key[key1Len:]) + c3 = hash.Sum(nil) - ciphertext := append(cipher.Marshal(), c3...) - ciphertext = append(ciphertext, key[:len(plaintext)]...) - return ciphertext, nil + return +} + +func encryptPlaintext(rand io.Reader, key, plaintext []byte, opts *EncrypterOpts) ([]byte, error) { + switch opts.EncryptType { + case ENC_TYPE_XOR: + subtle.XORBytes(key, key, plaintext) + return key, nil + case ENC_TYPE_ECB: + block, err := opts.CipherFactory(key) + if err != nil { + return nil, err + } + paddedPlainText := opts.Padding.Pad(plaintext) + ciphertext := make([]byte, len(paddedPlainText)) + mode := _cipher.NewECBEncrypter(block) + mode.CryptBlocks(ciphertext, paddedPlainText) + return ciphertext, nil + case ENC_TYPE_CBC: + block, err := opts.CipherFactory(key) + if err != nil { + return nil, err + } + paddedPlainText := opts.Padding.Pad(plaintext) + blockSize := block.BlockSize() + ciphertext := make([]byte, blockSize+len(paddedPlainText)) + iv := ciphertext[:blockSize] + if _, err := io.ReadFull(rand, iv); err != nil { + return nil, err + } + mode := cipher.NewCBCEncrypter(block, iv) + mode.CryptBlocks(ciphertext[blockSize:], paddedPlainText) + return ciphertext, nil + case ENC_TYPE_CFB: + block, err := opts.CipherFactory(key) + if err != nil { + return nil, err + } + blockSize := block.BlockSize() + ciphertext := make([]byte, blockSize+len(plaintext)) + iv := ciphertext[:blockSize] + if _, err := io.ReadFull(rand, iv); err != nil { + return nil, err + } + stream := cipher.NewCFBEncrypter(block, iv) + stream.XORKeyStream(ciphertext[blockSize:], plaintext) + return ciphertext, nil + case ENC_TYPE_OFB: + block, err := opts.CipherFactory(key) + if err != nil { + return nil, err + } + blockSize := block.BlockSize() + ciphertext := make([]byte, blockSize+len(plaintext)) + iv := ciphertext[:blockSize] + if _, err := io.ReadFull(rand, iv); err != nil { + return nil, err + } + stream := cipher.NewOFB(block, iv) + stream.XORKeyStream(ciphertext[blockSize:], plaintext) + return ciphertext, nil + } + return nil, fmt.Errorf("sm9: unsupported encryption type <%v>", opts.EncryptType) } // EncryptASN1 encrypt plaintext and output ciphertext with ASN.1 format according // SM9 cryptographic algorithm application specification, SM9Cipher definition. -func EncryptASN1(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte) ([]byte, error) { - return pub.Encrypt(rand, uid, hid, plaintext) +func EncryptASN1(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte, opts *EncrypterOpts) ([]byte, error) { + return pub.Encrypt(rand, uid, hid, plaintext, opts) } // Encrypt encrypt plaintext and output ciphertext with ASN.1 format according // SM9 cryptographic algorithm application specification, SM9Cipher definition. -func (pub *EncryptMasterPublicKey) Encrypt(rand io.Reader, uid []byte, hid byte, plaintext []byte) ([]byte, error) { - key, cipher, err := WrapKey(rand, pub, uid, hid, len(plaintext)+sm3.Size) +func (pub *EncryptMasterPublicKey) Encrypt(rand io.Reader, uid []byte, hid byte, plaintext []byte, opts *EncrypterOpts) ([]byte, error) { + if opts == nil { + opts = DefaultEncrypterOpts + } + c1, c2, c3, err := encrypt(rand, pub, uid, hid, plaintext, opts) if err != nil { return nil, err } - subtle.XORBytes(key, key[:len(plaintext)], plaintext) - - hash := sm3.New() - hash.Write(key) - c3 := hash.Sum(nil) var b cryptobyte.Builder b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { - b.AddASN1Int64(int64(ENC_TYPE_XOR)) - b.AddASN1BitString(cipher.MarshalUncompressed()) + b.AddASN1Int64(int64(opts.EncryptType)) + b.AddASN1BitString(c1.MarshalUncompressed()) b.AddASN1OctetString(c3) - b.AddASN1OctetString(key[:len(plaintext)]) + b.AddASN1OctetString(c2) }) return b.Bytes() } // Decrypt decrypt chipher, ciphertext should be with format C1||C3||C2 -func Decrypt(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error) { +func Decrypt(priv *EncryptPrivateKey, uid, ciphertext []byte, opts *EncrypterOpts) ([]byte, error) { + if opts == nil { + opts = DefaultEncrypterOpts + } + c := &bn256.G1{} - c3, err := c.Unmarshal(ciphertext) + c3c2, err := c.Unmarshal(ciphertext) if err != nil { return nil, ErrDecryption } - key, err := UnwrapKey(priv, uid, c, len(c3)) + c2 := c3c2[sm3.Size:] + key1Len := opts.getKeySize(c2) + + key, err := UnwrapKey(priv, uid, c, key1Len+sm3.Size) if err != nil { return nil, err } - c2 := c3[sm3.Size:] + return decrypt(c, key[:key1Len], key[key1Len:], c2, c3c2[:sm3.Size], opts) +} +func decrypt(cipher *bn256.G1, key1, key2, c2, c3 []byte, opts *EncrypterOpts) ([]byte, error) { hash := sm3.New() hash.Write(c2) - hash.Write(key[len(c2):]) + hash.Write(key2) c32 := hash.Sum(nil) - if goSubtle.ConstantTimeCompare(c3[:sm3.Size], c32) != 1 { + if goSubtle.ConstantTimeCompare(c3, c32) != 1 { return nil, ErrDecryption } - subtle.XORBytes(key, c2, key[:len(c2)]) - return key[:len(c2)], nil + return decryptCiphertext(key1, c2, opts) +} + +func decryptCiphertext(key, ciphertext []byte, opts *EncrypterOpts) ([]byte, error) { + switch opts.EncryptType { + case ENC_TYPE_XOR: + subtle.XORBytes(key, ciphertext, key) + return key, nil + case ENC_TYPE_ECB: + block, err := opts.CipherFactory(key) + if err != nil { + return nil, err + } + plaintext := make([]byte, len(ciphertext)) + mode := _cipher.NewECBDecrypter(block) + mode.CryptBlocks(plaintext, ciphertext) + return opts.Padding.Unpad(plaintext) + case ENC_TYPE_CBC: + block, err := opts.CipherFactory(key) + if err != nil { + return nil, err + } + blockSize := block.BlockSize() + if len(ciphertext) < blockSize { + return nil, ErrDecryption + } + iv := ciphertext[:blockSize] + ciphertext = ciphertext[blockSize:] + plaintext := make([]byte, len(ciphertext)) + mode := cipher.NewCBCDecrypter(block, iv) + mode.CryptBlocks(plaintext, ciphertext) + return opts.Padding.Unpad(plaintext) + case ENC_TYPE_CFB: + block, err := opts.CipherFactory(key) + if err != nil { + return nil, err + } + blockSize := block.BlockSize() + if len(ciphertext) < blockSize { + return nil, ErrDecryption + } + iv := ciphertext[:blockSize] + ciphertext = ciphertext[blockSize:] + plaintext := make([]byte, len(ciphertext)) + stream := cipher.NewCFBDecrypter(block, iv) + stream.XORKeyStream(plaintext, ciphertext) + return plaintext, nil + case ENC_TYPE_OFB: + block, err := opts.CipherFactory(key) + if err != nil { + return nil, err + } + blockSize := block.BlockSize() + if len(ciphertext) < blockSize { + return nil, ErrDecryption + } + iv := ciphertext[:blockSize] + ciphertext = ciphertext[blockSize:] + plaintext := make([]byte, len(ciphertext)) + stream := cipher.NewOFB(block, iv) + stream.XORKeyStream(plaintext, ciphertext) + return plaintext, nil + } + return nil, fmt.Errorf("sm9: unsupported encryption type <%v>", opts.EncryptType) } // DecryptASN1 decrypt chipher, ciphertext should be with ASN.1 format according @@ -489,39 +684,30 @@ func DecryptASN1(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error !inner.Empty() { return nil, errors.New("sm9: invalid ciphertext asn.1 data") } - if encType != int(ENC_TYPE_XOR) { - return nil, fmt.Errorf("sm9: does not support this kind of encrypt type <%v> yet", encType) - } + // We just make assumption block cipher is SM4 and padding scheme is pkcs7 + opts := NewEncrypterOpts(encryptType(encType), padding.NewPKCS7Padding(sm4.BlockSize), sm4.NewCipher, sm4.BlockSize) c, err := unmarshalG1(c1Bytes) if err != nil { return nil, ErrDecryption } - key, err := UnwrapKey(priv, uid, c, len(c2Bytes)+len(c3Bytes)) + key1Len := opts.getKeySize(c2Bytes) + key, err := UnwrapKey(priv, uid, c, key1Len+sm3.Size) if err != nil { return nil, err } - hash := sm3.New() - hash.Write(c2Bytes) - hash.Write(key[len(c2Bytes):]) - c32 := hash.Sum(nil) - - if goSubtle.ConstantTimeCompare(c3Bytes, c32) != 1 { - return nil, ErrDecryption - } - subtle.XORBytes(key, c2Bytes, key[:len(c2Bytes)]) - return key[:len(c2Bytes)], nil + return decrypt(c, key[:key1Len], key[key1Len:], c2Bytes, c3Bytes, opts) } // Decrypt decrypt chipher, ciphertext should be with ASN.1 format according // SM9 cryptographic algorithm application specification, SM9Cipher definition. -func (priv *EncryptPrivateKey) Decrypt(uid, ciphertext []byte) ([]byte, error) { +func (priv *EncryptPrivateKey) Decrypt(uid, ciphertext []byte, opts *EncrypterOpts) ([]byte, error) { if ciphertext[0] == 0x30 { // should be ASN.1 format return DecryptASN1(priv, uid, ciphertext) } // fallback to C1||C3||C2 raw format - return Decrypt(priv, uid, ciphertext) + return Decrypt(priv, uid, ciphertext, opts) } // KeyExchange key exchange struct, include internal stat in whole key exchange flow. diff --git a/sm9/sm9_test.go b/sm9/sm9_test.go index 23f26ca..271044d 100644 --- a/sm9/sm9_test.go +++ b/sm9/sm9_test.go @@ -659,26 +659,31 @@ func TestEncryptDecrypt(t *testing.T) { if err != nil { t.Fatal(err) } - cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext) - if err != nil { - t.Fatal(err) + encTypes := []*EncrypterOpts{ + DefaultEncrypterOpts, SM4ECBEncrypterOpts, SM4CBCEncrypterOpts, SM4CFBEncrypterOpts, SM4OFBEncrypterOpts, } + for _, opts := range encTypes { + cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext, opts) + if err != nil { + t.Fatal(err) + } - got, err := Decrypt(userKey, uid, cipher) - if err != nil { - t.Fatal(err) - } - if string(got) != string(plaintext) { - t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) - } + got, err := Decrypt(userKey, uid, cipher, opts) + if err != nil { + t.Fatal(err) + } + if string(got) != string(plaintext) { + t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) + } - got, err = userKey.Decrypt(uid, cipher) - if err != nil { - t.Fatal(err) - } + got, err = userKey.Decrypt(uid, cipher, opts) + if err != nil { + t.Fatal(err) + } - if string(got) != string(plaintext) { - t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) + if string(got) != string(plaintext) { + t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) + } } } @@ -694,27 +699,32 @@ func TestEncryptDecryptASN1(t *testing.T) { if err != nil { t.Fatal(err) } - cipher, err := EncryptASN1(rand.Reader, masterKey.Public(), uid, hid, plaintext) - if err != nil { - t.Fatal(err) + encTypes := []*EncrypterOpts{ + DefaultEncrypterOpts, SM4ECBEncrypterOpts, SM4CBCEncrypterOpts, SM4CFBEncrypterOpts, SM4OFBEncrypterOpts, } + for _, opts := range encTypes { + cipher, err := EncryptASN1(rand.Reader, masterKey.Public(), uid, hid, plaintext, opts) + if err != nil { + t.Fatal(err) + } - got, err := DecryptASN1(userKey, uid, cipher) - if err != nil { - t.Fatal(err) - } + got, err := DecryptASN1(userKey, uid, cipher) + if err != nil { + t.Fatal(err) + } - if string(got) != string(plaintext) { - t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) - } + if string(got) != string(plaintext) { + t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) + } - got, err = userKey.Decrypt(uid, cipher) - if err != nil { - t.Fatal(err) - } + got, err = userKey.Decrypt(uid, cipher, opts) + if err != nil { + t.Fatal(err) + } - if string(got) != string(plaintext) { - t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) + if string(got) != string(plaintext) { + t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) + } } } @@ -782,7 +792,7 @@ func BenchmarkEncrypt(b *testing.B) { b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext) + cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext, nil) if err != nil { b.Fatal(err) } @@ -803,14 +813,14 @@ func BenchmarkDecrypt(b *testing.B) { if err != nil { b.Fatal(err) } - cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext) + cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext, nil) if err != nil { b.Fatal(err) } b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - got, err := Decrypt(userKey, uid, cipher) + got, err := Decrypt(userKey, uid, cipher, nil) if err != nil { b.Fatal(err) } @@ -832,7 +842,7 @@ func BenchmarkDecryptASN1(b *testing.B) { if err != nil { b.Fatal(err) } - cipher, err := EncryptASN1(rand.Reader, masterKey.Public(), uid, hid, plaintext) + cipher, err := EncryptASN1(rand.Reader, masterKey.Public(), uid, hid, plaintext, nil) if err != nil { b.Fatal(err) }