feat: 新增XTS/CCM流式与KDF能力,补充安全测试并更新README/CHANGELOG

This commit is contained in:
2026-03-18 13:43:18 +08:00
parent e89350b56a
commit 4fa79744e8
44 changed files with 4636 additions and 77 deletions
+234 -9
View File
@@ -7,6 +7,7 @@ import (
"errors"
"io"
"b612.me/starcrypto/ccm"
"b612.me/starcrypto/paddingx"
)
@@ -17,13 +18,27 @@ const (
ANSIX923PADDING = paddingx.ANSIX923
)
var ErrInvalidGCMNonceLength = errors.New("gcm nonce length must be 12 bytes")
var (
ErrInvalidGCMNonceLength = errors.New("gcm nonce length must be 12 bytes")
ErrInvalidCCMNonceLength = errors.New("ccm nonce length must be 12 bytes")
)
const (
aeadCCMTagSize = 16
aeadCCMNonceSize = 12
)
func EncryptAes(data, key, iv []byte, mode, paddingType string) ([]byte, error) {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return EncryptAesGCM(data, key, iv, nil)
}
if normalizedMode == MODECCM {
return EncryptAesCCM(data, key, iv, nil)
}
block, err := aes.NewCipher(key)
if err != nil {
@@ -34,9 +49,15 @@ func EncryptAes(data, key, iv []byte, mode, paddingType string) ([]byte, error)
func DecryptAes(src, key, iv []byte, mode, paddingType string) ([]byte, error) {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return DecryptAesGCM(src, key, iv, nil)
}
if normalizedMode == MODECCM {
return DecryptAesCCM(src, key, iv, nil)
}
block, err := aes.NewCipher(key)
if err != nil {
@@ -47,9 +68,15 @@ func DecryptAes(src, key, iv []byte, mode, paddingType string) ([]byte, error) {
func EncryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return EncryptAesGCMStream(dst, src, key, iv, nil)
}
if normalizedMode == MODECCM {
return EncryptAesCCMStream(dst, src, key, iv, nil)
}
block, err := aes.NewCipher(key)
if err != nil {
@@ -60,9 +87,15 @@ func EncryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddin
func DecryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return DecryptAesGCMStream(dst, src, key, iv, nil)
}
if normalizedMode == MODECCM {
return DecryptAesCCMStream(dst, src, key, iv, nil)
}
block, err := aes.NewCipher(key)
if err != nil {
@@ -80,6 +113,9 @@ func EncryptAesWithOptions(data, key []byte, opts *CipherOptions) ([]byte, error
if mode == MODEGCM {
return EncryptAesGCM(data, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return EncryptAesCCM(data, key, nonceFromOptions(cfg), cfg.AAD)
}
return EncryptAes(data, key, cfg.IV, mode, cfg.Padding)
}
@@ -92,6 +128,9 @@ func DecryptAesWithOptions(src, key []byte, opts *CipherOptions) ([]byte, error)
if mode == MODEGCM {
return DecryptAesGCM(src, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return DecryptAesCCM(src, key, nonceFromOptions(cfg), cfg.AAD)
}
return DecryptAes(src, key, cfg.IV, mode, cfg.Padding)
}
@@ -104,6 +143,9 @@ func EncryptAesStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts
if mode == MODEGCM {
return EncryptAesGCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return EncryptAesCCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
return EncryptAesStream(dst, src, key, cfg.IV, mode, cfg.Padding)
}
@@ -116,6 +158,9 @@ func DecryptAesStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts
if mode == MODEGCM {
return DecryptAesGCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return DecryptAesCCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
return DecryptAesStream(dst, src, key, cfg.IV, mode, cfg.Padding)
}
@@ -149,30 +194,138 @@ func DecryptAesGCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
return gcm.Open(nil, nonce, ciphertext, aad)
}
func newAesCCM(key []byte) (cipher.AEAD, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return ccm.NewCCM(block, aeadCCMTagSize, aeadCCMNonceSize)
}
func EncryptAesCCM(plain, key, nonce, aad []byte) ([]byte, error) {
aead, err := newAesCCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return aead.Seal(nil, nonce, plain, aad), nil
}
func DecryptAesCCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
aead, err := newAesCCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return aead.Open(nil, nonce, ciphertext, aad)
}
func EncryptAesCCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
aead, err := newAesCCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return encryptCCMChunk(aead, plain, nonce, aad, chunkIndex), nil
}
func DecryptAesCCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
aead, err := newAesCCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return decryptCCMChunk(aead, ciphertext, nonce, aad, chunkIndex)
}
func EncryptAesCCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
aead, err := newAesCCM(key)
if err != nil {
return err
}
if len(nonce) != aead.NonceSize() {
return ErrInvalidCCMNonceLength
}
return encryptCCMChunkedStream(dst, src, aead, nonce, aad)
}
func DecryptAesCCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
aead, err := newAesCCM(key)
if err != nil {
return err
}
if len(nonce) != aead.NonceSize() {
return ErrInvalidCCMNonceLength
}
return decryptCCMChunkedOrLegacyStream(dst, src, aead, nonce, aad)
}
func EncryptAesGCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
if len(nonce) != gcm.NonceSize() {
return nil, ErrInvalidGCMNonceLength
}
return encryptGCMChunk(gcm, plain, nonce, aad, chunkIndex), nil
}
func DecryptAesGCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
if len(nonce) != gcm.NonceSize() {
return nil, ErrInvalidGCMNonceLength
}
return decryptGCMChunk(gcm, ciphertext, nonce, aad, chunkIndex)
}
func EncryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
plain, err := io.ReadAll(src)
block, err := aes.NewCipher(key)
if err != nil {
return err
}
out, err := EncryptAesGCM(plain, key, nonce, aad)
gcm, err := cipher.NewGCM(block)
if err != nil {
return err
}
_, err = dst.Write(out)
return err
if len(nonce) != gcm.NonceSize() {
return ErrInvalidGCMNonceLength
}
return encryptGCMChunkedStream(dst, src, gcm, nonce, aad)
}
func DecryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
enc, err := io.ReadAll(src)
block, err := aes.NewCipher(key)
if err != nil {
return err
}
out, err := DecryptAesGCM(enc, key, nonce, aad)
gcm, err := cipher.NewGCM(block)
if err != nil {
return err
}
_, err = dst.Write(out)
return err
if len(nonce) != gcm.NonceSize() {
return ErrInvalidGCMNonceLength
}
return decryptGCMChunkedOrLegacyStream(dst, src, gcm, nonce, aad)
}
func EncryptAesECB(data, key []byte, paddingType string) ([]byte, error) {
@@ -215,6 +368,18 @@ func DecryptAesCTR(src, key, iv []byte) ([]byte, error) {
return DecryptAes(src, key, iv, MODECTR, "")
}
func EncryptAesCTRAt(data, key, iv []byte, offset int64) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return xorCTRAtOffset(block, data, iv, offset)
}
func DecryptAesCTRAt(src, key, iv []byte, offset int64) ([]byte, error) {
return EncryptAesCTRAt(src, key, iv, offset)
}
func EncryptAesECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error {
return EncryptAesStream(dst, src, key, nil, MODEECB, paddingType)
}
@@ -329,3 +494,63 @@ func PKCS7Trimming(encrypted []byte, blockSize int) []byte {
}
return out
}
func EncryptAesCFB8(data, key, iv []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return encryptCFB8(block, data, iv)
}
func DecryptAesCFB8(src, key, iv []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return decryptCFB8(block, src, iv)
}
func EncryptAesCFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
block, err := aes.NewCipher(key)
if err != nil {
return err
}
return encryptCFB8Stream(block, dst, src, iv, false)
}
func DecryptAesCFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
block, err := aes.NewCipher(key)
if err != nil {
return err
}
return encryptCFB8Stream(block, dst, src, iv, true)
}
func DecryptAesECBBlocks(src, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return decryptECBBlocks(block, src)
}
// DecryptAesCBCFromSecondBlock decrypts a CBC ciphertext segment that starts from block 2 or later.
// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock.
func DecryptAesCBCFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return decryptCBCFromSecondBlock(block, src, prevCipherBlock)
}
// DecryptAesCFBFromSecondBlock decrypts a CFB ciphertext segment that starts from block 2 or later.
// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock.
func DecryptAesCFBFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return decryptCFBFromSecondBlock(block, src, prevCipherBlock)
}
+98
View File
@@ -0,0 +1,98 @@
package symm
import (
"bytes"
"testing"
)
var (
benchPlain4K = bytes.Repeat([]byte("0123456789abcdef"), 256) // 4 KiB
benchPlain256K = bytes.Repeat([]byte("0123456789abcdef"), 16384) // 256 KiB
benchAAD = []byte("benchmark-aad")
benchAESKey = []byte("0123456789abcdef")
benchSM4Key = []byte("0123456789abcdef")
benchXTSKey2 = []byte("fedcba9876543210")
benchNonce12Byte = []byte("123456789012")
)
func BenchmarkAesGCMEncrypt4K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain4K)))
for i := 0; i < b.N; i++ {
if _, err := EncryptAesGCM(benchPlain4K, benchAESKey, benchNonce12Byte, benchAAD); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkAesCCMEncrypt4K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain4K)))
for i := 0; i < b.N; i++ {
if _, err := EncryptAesCCM(benchPlain4K, benchAESKey, benchNonce12Byte, benchAAD); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkAesXTSEncrypt4K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain4K)))
for i := 0; i < b.N; i++ {
if _, err := EncryptAesXTS(benchPlain4K, benchAESKey, benchXTSKey2, 512); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSM4GCMEncrypt4K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain4K)))
for i := 0; i < b.N; i++ {
if _, err := EncryptSM4GCM(benchPlain4K, benchSM4Key, benchNonce12Byte, benchAAD); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSM4CCMEncrypt4K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain4K)))
for i := 0; i < b.N; i++ {
if _, err := EncryptSM4CCM(benchPlain4K, benchSM4Key, benchNonce12Byte, benchAAD); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSM4XTSEncrypt4K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain4K)))
for i := 0; i < b.N; i++ {
if _, err := EncryptSM4XTS(benchPlain4K, benchSM4Key, benchXTSKey2, 512); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkAesXTSStreamEncrypt256K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain256K)))
for i := 0; i < b.N; i++ {
var dst bytes.Buffer
if err := EncryptAesXTSStream(&dst, bytes.NewReader(benchPlain256K), benchAESKey, benchXTSKey2, 512); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSM4XTSStreamEncrypt256K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain256K)))
for i := 0; i < b.N; i++ {
var dst bytes.Buffer
if err := EncryptSM4XTSStream(&dst, bytes.NewReader(benchPlain256K), benchSM4Key, benchXTSKey2, 512); err != nil {
b.Fatal(err)
}
}
}
+127
View File
@@ -0,0 +1,127 @@
package symm
import (
"bytes"
"crypto/cipher"
"encoding/binary"
"errors"
"io"
)
const ccmStreamMagic = "SCC1"
var ErrInvalidCCMStreamChunk = errors.New("invalid ccm stream chunk")
func encryptCCMChunk(aead cipher.AEAD, plain, nonce, aad []byte, chunkIndex uint64) []byte {
chunkNonce := deriveChunkNonce(nonce, chunkIndex)
return aead.Seal(nil, chunkNonce, plain, aad)
}
func decryptCCMChunk(aead cipher.AEAD, ciphertext, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
chunkNonce := deriveChunkNonce(nonce, chunkIndex)
return aead.Open(nil, chunkNonce, ciphertext, aad)
}
func encryptCCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
if _, err := dst.Write([]byte(ccmStreamMagic)); err != nil {
return err
}
buf := make([]byte, gcmStreamChunkSize)
lenBuf := make([]byte, 4)
var chunkIndex uint64
for {
n, err := src.Read(buf)
if n > 0 {
sealed := encryptCCMChunk(aead, buf[:n], nonce, aad, chunkIndex)
binary.BigEndian.PutUint32(lenBuf, uint32(len(sealed)))
if _, werr := dst.Write(lenBuf); werr != nil {
return werr
}
if _, werr := dst.Write(sealed); werr != nil {
return werr
}
chunkIndex++
}
if err != nil {
if err == io.EOF {
return nil
}
return err
}
}
}
func decryptCCMChunkedOrLegacyStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
header := make([]byte, len(ccmStreamMagic))
n, err := io.ReadFull(src, header)
if err != nil {
if err == io.EOF {
return nil
}
if err != io.ErrUnexpectedEOF {
return err
}
return decryptCCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header[:n]), src), aead, nonce, aad)
}
if string(header) != ccmStreamMagic {
return decryptCCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header), src), aead, nonce, aad)
}
return decryptCCMChunkedStream(dst, src, aead, nonce, aad)
}
func decryptCCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
lenBuf := make([]byte, 4)
maxChunkLen := uint32(gcmStreamChunkSize + aead.Overhead())
var chunkIndex uint64
for {
_, err := io.ReadFull(src, lenBuf)
if err != nil {
if err == io.EOF {
return nil
}
if err == io.ErrUnexpectedEOF {
return io.ErrUnexpectedEOF
}
return err
}
chunkLen := binary.BigEndian.Uint32(lenBuf)
if chunkLen < uint32(aead.Overhead()) || chunkLen > maxChunkLen {
return ErrInvalidCCMStreamChunk
}
chunk := make([]byte, chunkLen)
if _, err := io.ReadFull(src, chunk); err != nil {
if err == io.ErrUnexpectedEOF {
return io.ErrUnexpectedEOF
}
return err
}
plain, err := decryptCCMChunk(aead, chunk, nonce, aad, chunkIndex)
if err != nil {
return err
}
if _, err := dst.Write(plain); err != nil {
return err
}
chunkIndex++
}
}
func decryptCCMLegacyBuffered(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
enc, err := io.ReadAll(src)
if err != nil {
return err
}
plain, err := aead.Open(nil, nonce, enc, aad)
if err != nil {
return err
}
_, err = dst.Write(plain)
return err
}
+103
View File
@@ -0,0 +1,103 @@
package symm
import (
"crypto/cipher"
"errors"
"io"
)
func encryptCFB8(block cipher.Block, data, iv []byte) ([]byte, error) {
if len(iv) != block.BlockSize() {
return nil, errors.New("iv length must match block size")
}
reg := make([]byte, len(iv))
copy(reg, iv)
regView := make([]byte, block.BlockSize())
streamBlock := make([]byte, block.BlockSize())
out := make([]byte, len(data))
head := 0
for i := 0; i < len(data); i++ {
buildCFB8Register(regView, reg, head)
block.Encrypt(streamBlock, regView)
c := data[i] ^ streamBlock[0]
out[i] = c
advanceCFB8Register(reg, &head, c)
}
return out, nil
}
func decryptCFB8(block cipher.Block, src, iv []byte) ([]byte, error) {
if len(iv) != block.BlockSize() {
return nil, errors.New("iv length must match block size")
}
reg := make([]byte, len(iv))
copy(reg, iv)
regView := make([]byte, block.BlockSize())
streamBlock := make([]byte, block.BlockSize())
out := make([]byte, len(src))
head := 0
for i := 0; i < len(src); i++ {
buildCFB8Register(regView, reg, head)
block.Encrypt(streamBlock, regView)
p := src[i] ^ streamBlock[0]
out[i] = p
advanceCFB8Register(reg, &head, src[i])
}
return out, nil
}
func encryptCFB8Stream(block cipher.Block, dst io.Writer, src io.Reader, iv []byte, decrypt bool) error {
if len(iv) != block.BlockSize() {
return errors.New("iv length must match block size")
}
reg := make([]byte, len(iv))
copy(reg, iv)
regView := make([]byte, block.BlockSize())
streamBlock := make([]byte, block.BlockSize())
buf := make([]byte, 32*1024)
out := make([]byte, 32*1024)
head := 0
for {
n, err := src.Read(buf)
if n > 0 {
for i := 0; i < n; i++ {
buildCFB8Register(regView, reg, head)
block.Encrypt(streamBlock, regView)
if decrypt {
out[i] = buf[i] ^ streamBlock[0]
advanceCFB8Register(reg, &head, buf[i])
} else {
c := buf[i] ^ streamBlock[0]
out[i] = c
advanceCFB8Register(reg, &head, c)
}
}
if _, werr := dst.Write(out[:n]); werr != nil {
return werr
}
}
if err != nil {
if err == io.EOF {
return nil
}
return err
}
}
}
func buildCFB8Register(dst, reg []byte, head int) {
first := len(reg) - head
copy(dst, reg[head:])
copy(dst[first:], reg[:head])
}
func advanceCFB8Register(reg []byte, head *int, feedback byte) {
reg[*head] = feedback
*head = *head + 1
if *head == len(reg) {
*head = 0
}
}
+56
View File
@@ -0,0 +1,56 @@
package symm
import (
"crypto/cipher"
"errors"
)
var (
ErrInvalidCTROffset = errors.New("ctr offset must be non-negative")
ErrCTRCounterOverflow = errors.New("ctr counter overflow")
)
func xorCTRAtOffset(block cipher.Block, data, iv []byte, offset int64) ([]byte, error) {
if offset < 0 {
return nil, ErrInvalidCTROffset
}
if len(iv) != block.BlockSize() {
return nil, errors.New("iv length must match block size")
}
counter := make([]byte, len(iv))
copy(counter, iv)
blockSize := int64(block.BlockSize())
blockOffset := uint64(offset / blockSize)
byteOffset := int(offset % blockSize)
if err := addUint64ToCounter(counter, blockOffset); err != nil {
return nil, err
}
stream := cipher.NewCTR(block, counter)
if byteOffset > 0 {
skip := make([]byte, byteOffset)
stream.XORKeyStream(skip, skip)
}
out := make([]byte, len(data))
stream.XORKeyStream(out, data)
return out, nil
}
func addUint64ToCounter(counter []byte, inc uint64) error {
if inc == 0 {
return nil
}
for i := len(counter) - 1; i >= 0 && inc > 0; i-- {
sum := uint64(counter[i]) + (inc & 0xff)
counter[i] = byte(sum)
inc = (inc >> 8) + (sum >> 8)
}
if inc > 0 {
return ErrCTRCounterOverflow
}
return nil
}
+119
View File
@@ -70,3 +70,122 @@ func FuzzAesCBCStreamRoundTrip(f *testing.F) {
}
})
}
func xtsFuzzDataUnitSize(selector uint8) int {
sizes := [...]int{16, 32, 64, 128, 256, 512}
return sizes[int(selector)%len(sizes)]
}
func xtsFuzzNormalizeData(data []byte, maxLen int) []byte {
if len(data) > maxLen {
data = data[:maxLen]
}
return data[:len(data)/16*16]
}
func FuzzAesXTSRoundTrip(f *testing.F) {
f.Add([]byte("fuzz-aes-xts-seed-0000"), uint64(0), uint8(0))
f.Add(bytes.Repeat([]byte{0x42}, 65), uint64(7), uint8(3))
key1 := []byte("0123456789abcdef")
key2 := []byte("fedcba9876543210")
f.Fuzz(func(t *testing.T, data []byte, dataUnitIndex uint64, selector uint8) {
bounded := data
if len(bounded) > 4097 {
bounded = bounded[:4097]
}
plain := xtsFuzzNormalizeData(bounded, 4096)
dataUnitSize := xtsFuzzDataUnitSize(selector)
enc, err := EncryptAesXTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex)
if err != nil {
t.Fatalf("EncryptAesXTSAt failed: %v", err)
}
dec, err := DecryptAesXTSAt(enc, key1, key2, dataUnitSize, dataUnitIndex)
if err != nil {
t.Fatalf("DecryptAesXTSAt failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("aes xts roundtrip mismatch")
}
encStream := &bytes.Buffer{}
if err := EncryptAesXTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
t.Fatalf("EncryptAesXTSStreamAt failed: %v", err)
}
if !bytes.Equal(encStream.Bytes(), enc) {
t.Fatalf("aes xts bytes/stream encrypt mismatch")
}
decStream := &bytes.Buffer{}
if err := DecryptAesXTSStreamAt(decStream, bytes.NewReader(enc), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
t.Fatalf("DecryptAesXTSStreamAt failed: %v", err)
}
if !bytes.Equal(decStream.Bytes(), plain) {
t.Fatalf("aes xts stream decrypt mismatch")
}
if len(bounded)%16 != 0 {
if _, err := EncryptAesXTSAt(bounded, key1, key2, dataUnitSize, dataUnitIndex); err == nil {
t.Fatalf("expected aes xts bytes error for non-block-aligned input")
}
if err := EncryptAesXTSStreamAt(&bytes.Buffer{}, bytes.NewReader(bounded), key1, key2, dataUnitSize, dataUnitIndex); err == nil {
t.Fatalf("expected aes xts stream error for non-block-aligned input")
}
}
})
}
func FuzzSM4XTSRoundTrip(f *testing.F) {
f.Add([]byte("fuzz-sm4-xts-seed-0000"), uint64(0), uint8(0))
f.Add(bytes.Repeat([]byte{0x5a}, 79), uint64(11), uint8(4))
key1 := []byte("0123456789abcdef")
key2 := []byte("fedcba9876543210")
f.Fuzz(func(t *testing.T, data []byte, dataUnitIndex uint64, selector uint8) {
bounded := data
if len(bounded) > 4097 {
bounded = bounded[:4097]
}
plain := xtsFuzzNormalizeData(bounded, 4096)
dataUnitSize := xtsFuzzDataUnitSize(selector)
enc, err := EncryptSM4XTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex)
if err != nil {
t.Fatalf("EncryptSM4XTSAt failed: %v", err)
}
dec, err := DecryptSM4XTSAt(enc, key1, key2, dataUnitSize, dataUnitIndex)
if err != nil {
t.Fatalf("DecryptSM4XTSAt failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("sm4 xts roundtrip mismatch")
}
encStream := &bytes.Buffer{}
if err := EncryptSM4XTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
t.Fatalf("EncryptSM4XTSStreamAt failed: %v", err)
}
if !bytes.Equal(encStream.Bytes(), enc) {
t.Fatalf("sm4 xts bytes/stream encrypt mismatch")
}
decStream := &bytes.Buffer{}
if err := DecryptSM4XTSStreamAt(decStream, bytes.NewReader(enc), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
t.Fatalf("DecryptSM4XTSStreamAt failed: %v", err)
}
if !bytes.Equal(decStream.Bytes(), plain) {
t.Fatalf("sm4 xts stream decrypt mismatch")
}
if len(bounded)%16 != 0 {
if _, err := EncryptSM4XTSAt(bounded, key1, key2, dataUnitSize, dataUnitIndex); err == nil {
t.Fatalf("expected sm4 xts bytes error for non-block-aligned input")
}
if err := EncryptSM4XTSStreamAt(&bytes.Buffer{}, bytes.NewReader(bounded), key1, key2, dataUnitSize, dataUnitIndex); err == nil {
t.Fatalf("expected sm4 xts stream error for non-block-aligned input")
}
}
})
}
+146
View File
@@ -0,0 +1,146 @@
package symm
import (
"bytes"
"crypto/cipher"
"encoding/binary"
"errors"
"io"
)
const (
gcmStreamMagic = "SCG1"
gcmStreamChunkSize = 32 * 1024
)
var ErrInvalidGCMStreamChunk = errors.New("invalid gcm stream chunk")
func encryptGCMChunk(aead cipher.AEAD, plain, nonce, aad []byte, chunkIndex uint64) []byte {
chunkNonce := deriveChunkNonce(nonce, chunkIndex)
return aead.Seal(nil, chunkNonce, plain, aad)
}
func decryptGCMChunk(aead cipher.AEAD, ciphertext, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
chunkNonce := deriveChunkNonce(nonce, chunkIndex)
return aead.Open(nil, chunkNonce, ciphertext, aad)
}
func encryptGCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
if _, err := dst.Write([]byte(gcmStreamMagic)); err != nil {
return err
}
buf := make([]byte, gcmStreamChunkSize)
lenBuf := make([]byte, 4)
var chunkIndex uint64
for {
n, err := src.Read(buf)
if n > 0 {
sealed := encryptGCMChunk(aead, buf[:n], nonce, aad, chunkIndex)
binary.BigEndian.PutUint32(lenBuf, uint32(len(sealed)))
if _, werr := dst.Write(lenBuf); werr != nil {
return werr
}
if _, werr := dst.Write(sealed); werr != nil {
return werr
}
chunkIndex++
}
if err != nil {
if err == io.EOF {
return nil
}
return err
}
}
}
func decryptGCMChunkedOrLegacyStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
header := make([]byte, len(gcmStreamMagic))
n, err := io.ReadFull(src, header)
if err != nil {
if err == io.EOF {
return nil
}
if err != io.ErrUnexpectedEOF {
return err
}
return decryptGCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header[:n]), src), aead, nonce, aad)
}
if string(header) != gcmStreamMagic {
return decryptGCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header), src), aead, nonce, aad)
}
return decryptGCMChunkedStream(dst, src, aead, nonce, aad)
}
func decryptGCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
lenBuf := make([]byte, 4)
maxChunkLen := uint32(gcmStreamChunkSize + aead.Overhead())
var chunkIndex uint64
for {
_, err := io.ReadFull(src, lenBuf)
if err != nil {
if err == io.EOF {
return nil
}
if err == io.ErrUnexpectedEOF {
return io.ErrUnexpectedEOF
}
return err
}
chunkLen := binary.BigEndian.Uint32(lenBuf)
if chunkLen < uint32(aead.Overhead()) || chunkLen > maxChunkLen {
return ErrInvalidGCMStreamChunk
}
chunk := make([]byte, chunkLen)
if _, err := io.ReadFull(src, chunk); err != nil {
if err == io.ErrUnexpectedEOF {
return io.ErrUnexpectedEOF
}
return err
}
plain, err := decryptGCMChunk(aead, chunk, nonce, aad, chunkIndex)
if err != nil {
return err
}
if _, err := dst.Write(plain); err != nil {
return err
}
chunkIndex++
}
}
func decryptGCMLegacyBuffered(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
enc, err := io.ReadAll(src)
if err != nil {
return err
}
plain, err := aead.Open(nil, nonce, enc, aad)
if err != nil {
return err
}
_, err = dst.Write(plain)
return err
}
func deriveChunkNonce(baseNonce []byte, chunkIndex uint64) []byte {
nonce := make([]byte, len(baseNonce))
copy(nonce, baseNonce)
if len(nonce) < 8 {
return nonce
}
var indexBytes [8]byte
binary.BigEndian.PutUint64(indexBytes[:], chunkIndex)
off := len(nonce) - 8
for i := 0; i < 8; i++ {
nonce[off+i] ^= indexBytes[i]
}
return nonce
}
+1
View File
@@ -16,6 +16,7 @@ const (
MODEOFB = "OFB"
MODECTR = "CTR"
MODEGCM = "GCM"
MODECCM = "CCM"
)
var ErrUnsupportedCipherMode = errors.New("cipher mode not supported")
+2 -5
View File
@@ -1,7 +1,7 @@
package symm
// CipherOptions provides a unified configuration for symmetric APIs.
// For GCM mode, Nonce is used; if Nonce is empty, IV is used as fallback.
// For AEAD modes (GCM/CCM), Nonce must be set explicitly.
type CipherOptions struct {
Mode string
Padding string
@@ -18,8 +18,5 @@ func normalizeCipherOptions(opts *CipherOptions) CipherOptions {
}
func nonceFromOptions(opts CipherOptions) []byte {
if len(opts.Nonce) > 0 {
return opts.Nonce
}
return opts.IV
return opts.Nonce
}
+28
View File
@@ -0,0 +1,28 @@
package symm
import (
"errors"
"testing"
)
func TestAEADOptionsRequireNonce(t *testing.T) {
aesKey := []byte("0123456789abcdef")
sm4Key := []byte("0123456789abcdef")
plain := []byte("nonce-required")
gcmIVOnly := &CipherOptions{Mode: MODEGCM, IV: []byte("123456789012")}
if _, err := EncryptAesWithOptions(plain, aesKey, gcmIVOnly); !errors.Is(err, ErrInvalidGCMNonceLength) {
t.Fatalf("expected ErrInvalidGCMNonceLength for AES GCM with IV-only opts, got: %v", err)
}
if _, err := EncryptSM4WithOptions(plain, sm4Key, gcmIVOnly); !errors.Is(err, ErrInvalidGCMNonceLength) {
t.Fatalf("expected ErrInvalidGCMNonceLength for SM4 GCM with IV-only opts, got: %v", err)
}
ccmIVOnly := &CipherOptions{Mode: MODECCM, IV: []byte("123456789012")}
if _, err := EncryptAesWithOptions(plain, aesKey, ccmIVOnly); !errors.Is(err, ErrInvalidCCMNonceLength) {
t.Fatalf("expected ErrInvalidCCMNonceLength for AES CCM with IV-only opts, got: %v", err)
}
if _, err := EncryptSM4WithOptions(plain, sm4Key, ccmIVOnly); !errors.Is(err, ErrInvalidCCMNonceLength) {
t.Fatalf("expected ErrInvalidCCMNonceLength for SM4 CCM with IV-only opts, got: %v", err)
}
}
+54
View File
@@ -0,0 +1,54 @@
package symm
import (
"crypto/cipher"
"errors"
)
var (
ErrSegmentNotFullBlocks = errors.New("ciphertext segment is not a full block size")
)
func decryptECBBlocks(block cipher.Block, src []byte) ([]byte, error) {
if len(src) == 0 {
return []byte{}, nil
}
if len(src)%block.BlockSize() != 0 {
return nil, ErrSegmentNotFullBlocks
}
out := make([]byte, len(src))
ecbDecryptBlocks(block, out, src)
return out, nil
}
// decryptCBCFromSecondBlock decrypts a CBC ciphertext segment that starts from the second block (or later).
// prevCipherBlock must be the previous ciphertext block (C[i-1]); for i=1 this is the original IV.
func decryptCBCFromSecondBlock(block cipher.Block, src, prevCipherBlock []byte) ([]byte, error) {
if len(src) == 0 {
return []byte{}, nil
}
if len(prevCipherBlock) != block.BlockSize() {
return nil, errors.New("prev cipher block length must match block size")
}
if len(src)%block.BlockSize() != 0 {
return nil, ErrSegmentNotFullBlocks
}
out := make([]byte, len(src))
cipher.NewCBCDecrypter(block, prevCipherBlock).CryptBlocks(out, src)
return out, nil
}
// decryptCFBFromSecondBlock decrypts a CFB ciphertext segment that starts from the second block (or later).
// prevCipherBlock must be the previous ciphertext block (C[i-1]); for i=1 this is the original IV.
func decryptCFBFromSecondBlock(block cipher.Block, src, prevCipherBlock []byte) ([]byte, error) {
if len(src) == 0 {
return []byte{}, nil
}
if len(prevCipherBlock) != block.BlockSize() {
return nil, errors.New("prev cipher block length must match block size")
}
stream := cipher.NewCFBDecrypter(block, prevCipherBlock)
out := make([]byte, len(src))
stream.XORKeyStream(out, src)
return out, nil
}
+225 -8
View File
@@ -6,14 +6,21 @@ import (
"errors"
"io"
"b612.me/starcrypto/ccm"
"github.com/emmansun/gmsm/sm4"
)
func EncryptSM4(data, key, iv []byte, mode, paddingType string) ([]byte, error) {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return EncryptSM4GCM(data, key, iv, nil)
}
if normalizedMode == MODECCM {
return EncryptSM4CCM(data, key, iv, nil)
}
block, err := sm4.NewCipher(key)
if err != nil {
@@ -24,9 +31,15 @@ func EncryptSM4(data, key, iv []byte, mode, paddingType string) ([]byte, error)
func DecryptSM4(src, key, iv []byte, mode, paddingType string) ([]byte, error) {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return DecryptSM4GCM(src, key, iv, nil)
}
if normalizedMode == MODECCM {
return DecryptSM4CCM(src, key, iv, nil)
}
block, err := sm4.NewCipher(key)
if err != nil {
@@ -37,9 +50,15 @@ func DecryptSM4(src, key, iv []byte, mode, paddingType string) ([]byte, error) {
func EncryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return EncryptSM4GCMStream(dst, src, key, iv, nil)
}
if normalizedMode == MODECCM {
return EncryptSM4CCMStream(dst, src, key, iv, nil)
}
block, err := sm4.NewCipher(key)
if err != nil {
@@ -50,9 +69,15 @@ func EncryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddin
func DecryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return DecryptSM4GCMStream(dst, src, key, iv, nil)
}
if normalizedMode == MODECCM {
return DecryptSM4CCMStream(dst, src, key, iv, nil)
}
block, err := sm4.NewCipher(key)
if err != nil {
@@ -70,6 +95,9 @@ func EncryptSM4WithOptions(data, key []byte, opts *CipherOptions) ([]byte, error
if mode == MODEGCM {
return EncryptSM4GCM(data, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return EncryptSM4CCM(data, key, nonceFromOptions(cfg), cfg.AAD)
}
return EncryptSM4(data, key, cfg.IV, mode, cfg.Padding)
}
@@ -82,6 +110,9 @@ func DecryptSM4WithOptions(src, key []byte, opts *CipherOptions) ([]byte, error)
if mode == MODEGCM {
return DecryptSM4GCM(src, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return DecryptSM4CCM(src, key, nonceFromOptions(cfg), cfg.AAD)
}
return DecryptSM4(src, key, cfg.IV, mode, cfg.Padding)
}
@@ -94,6 +125,9 @@ func EncryptSM4StreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts
if mode == MODEGCM {
return EncryptSM4GCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return EncryptSM4CCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
return EncryptSM4Stream(dst, src, key, cfg.IV, mode, cfg.Padding)
}
@@ -106,6 +140,9 @@ func DecryptSM4StreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts
if mode == MODEGCM {
return DecryptSM4GCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return DecryptSM4CCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
return DecryptSM4Stream(dst, src, key, cfg.IV, mode, cfg.Padding)
}
@@ -139,30 +176,138 @@ func DecryptSM4GCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
return gcm.Open(nil, nonce, ciphertext, aad)
}
func EncryptSM4GCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
if len(nonce) != gcm.NonceSize() {
return nil, ErrInvalidGCMNonceLength
}
return encryptGCMChunk(gcm, plain, nonce, aad, chunkIndex), nil
}
func DecryptSM4GCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
if len(nonce) != gcm.NonceSize() {
return nil, ErrInvalidGCMNonceLength
}
return decryptGCMChunk(gcm, ciphertext, nonce, aad, chunkIndex)
}
func EncryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
plain, err := io.ReadAll(src)
block, err := sm4.NewCipher(key)
if err != nil {
return err
}
out, err := EncryptSM4GCM(plain, key, nonce, aad)
gcm, err := cipher.NewGCM(block)
if err != nil {
return err
}
_, err = dst.Write(out)
return err
if len(nonce) != gcm.NonceSize() {
return ErrInvalidGCMNonceLength
}
return encryptGCMChunkedStream(dst, src, gcm, nonce, aad)
}
func DecryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
enc, err := io.ReadAll(src)
block, err := sm4.NewCipher(key)
if err != nil {
return err
}
out, err := DecryptSM4GCM(enc, key, nonce, aad)
gcm, err := cipher.NewGCM(block)
if err != nil {
return err
}
_, err = dst.Write(out)
return err
if len(nonce) != gcm.NonceSize() {
return ErrInvalidGCMNonceLength
}
return decryptGCMChunkedOrLegacyStream(dst, src, gcm, nonce, aad)
}
func newSM4CCM(key []byte) (cipher.AEAD, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
return ccm.NewCCM(block, aeadCCMTagSize, aeadCCMNonceSize)
}
func EncryptSM4CCM(plain, key, nonce, aad []byte) ([]byte, error) {
aead, err := newSM4CCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return aead.Seal(nil, nonce, plain, aad), nil
}
func DecryptSM4CCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
aead, err := newSM4CCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return aead.Open(nil, nonce, ciphertext, aad)
}
func EncryptSM4CCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
aead, err := newSM4CCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return encryptCCMChunk(aead, plain, nonce, aad, chunkIndex), nil
}
func DecryptSM4CCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
aead, err := newSM4CCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return decryptCCMChunk(aead, ciphertext, nonce, aad, chunkIndex)
}
func EncryptSM4CCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
aead, err := newSM4CCM(key)
if err != nil {
return err
}
if len(nonce) != aead.NonceSize() {
return ErrInvalidCCMNonceLength
}
return encryptCCMChunkedStream(dst, src, aead, nonce, aad)
}
func DecryptSM4CCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
aead, err := newSM4CCM(key)
if err != nil {
return err
}
if len(nonce) != aead.NonceSize() {
return ErrInvalidCCMNonceLength
}
return decryptCCMChunkedOrLegacyStream(dst, src, aead, nonce, aad)
}
func EncryptSM4CFB(origData, key []byte) ([]byte, error) {
@@ -235,6 +380,18 @@ func DecryptSM4CTR(src, key, iv []byte) ([]byte, error) {
return DecryptSM4(src, key, iv, MODECTR, "")
}
func EncryptSM4CTRAt(data, key, iv []byte, offset int64) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
return xorCTRAtOffset(block, data, iv, offset)
}
func DecryptSM4CTRAt(src, key, iv []byte, offset int64) ([]byte, error) {
return EncryptSM4CTRAt(src, key, iv, offset)
}
func EncryptSM4ECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error {
return EncryptSM4Stream(dst, src, key, nil, MODEECB, paddingType)
}
@@ -274,3 +431,63 @@ func EncryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error {
func DecryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error {
return DecryptSM4Stream(dst, src, key, iv, MODECTR, "")
}
func EncryptSM4CFB8(data, key, iv []byte) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
return encryptCFB8(block, data, iv)
}
func DecryptSM4CFB8(src, key, iv []byte) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
return decryptCFB8(block, src, iv)
}
func EncryptSM4CFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
block, err := sm4.NewCipher(key)
if err != nil {
return err
}
return encryptCFB8Stream(block, dst, src, iv, false)
}
func DecryptSM4CFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
block, err := sm4.NewCipher(key)
if err != nil {
return err
}
return encryptCFB8Stream(block, dst, src, iv, true)
}
func DecryptSM4ECBBlocks(src, key []byte) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
return decryptECBBlocks(block, src)
}
// DecryptSM4CBCFromSecondBlock decrypts a CBC ciphertext segment that starts from block 2 or later.
// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock.
func DecryptSM4CBCFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
return decryptCBCFromSecondBlock(block, src, prevCipherBlock)
}
// DecryptSM4CFBFromSecondBlock decrypts a CFB ciphertext segment that starts from block 2 or later.
// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock.
func DecryptSM4CFBFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
return decryptCFBFromSecondBlock(block, src, prevCipherBlock)
}
+739 -8
View File
@@ -6,21 +6,39 @@ import (
"testing"
)
func TestEncryptAesDefaultModeCBC(t *testing.T) {
func TestEncryptAesDefaultModeGCM(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := []byte("aes-default-mode-cbc")
nonce := []byte("123456789012")
plain := []byte("aes-default-mode-gcm")
encDefault, err := EncryptAes(plain, key, iv, "", "")
encDefault, err := EncryptAes(plain, key, nonce, "", "")
if err != nil {
t.Fatalf("EncryptAes default failed: %v", err)
}
encCBC, err := EncryptAesCBC(plain, key, iv, "")
encGCM, err := EncryptAesGCM(plain, key, nonce, nil)
if err != nil {
t.Fatalf("EncryptAesCBC failed: %v", err)
t.Fatalf("EncryptAesGCM failed: %v", err)
}
if !bytes.Equal(encDefault, encCBC) {
t.Fatalf("default mode should match CBC mode")
if !bytes.Equal(encDefault, encGCM) {
t.Fatalf("default mode should match GCM mode")
}
}
func TestEncryptSM4DefaultModeGCM(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
plain := []byte("sm4-default-mode-gcm")
encDefault, err := EncryptSM4(plain, key, nonce, "", "")
if err != nil {
t.Fatalf("EncryptSM4 default failed: %v", err)
}
encGCM, err := EncryptSM4GCM(plain, key, nonce, nil)
if err != nil {
t.Fatalf("EncryptSM4GCM failed: %v", err)
}
if !bytes.Equal(encDefault, encGCM) {
t.Fatalf("default mode should match GCM mode")
}
}
@@ -529,6 +547,62 @@ func TestChaCha20Poly1305RFCVector(t *testing.T) {
t.Fatalf("ChaCha20-Poly1305 vector mismatch: got %x want %x", enc, want)
}
}
func TestAesGCMStreamRoundTripChunked(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := bytes.Repeat([]byte("aes-gcm-stream-chunk-"), 10000)
enc := &bytes.Buffer{}
if err := EncryptAesGCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil {
t.Fatalf("EncryptAesGCMStream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptAesGCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil {
t.Fatalf("DecryptAesGCMStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("aes gcm stream mismatch")
}
}
func TestAesGCMStreamLegacyCompatDecrypt(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("aes-gcm-legacy-compat")
legacyCipher, err := EncryptAesGCM(plain, key, nonce, aad)
if err != nil {
t.Fatalf("EncryptAesGCM failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptAesGCMStream(dec, bytes.NewReader(legacyCipher), key, nonce, aad); err != nil {
t.Fatalf("DecryptAesGCMStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("aes gcm legacy decrypt mismatch")
}
}
func TestSM4GCMStreamRoundTripChunked(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := bytes.Repeat([]byte("sm4-gcm-stream-chunk-"), 10000)
enc := &bytes.Buffer{}
if err := EncryptSM4GCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil {
t.Fatalf("EncryptSM4GCMStream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptSM4GCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil {
t.Fatalf("DecryptSM4GCMStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("sm4 gcm stream mismatch")
}
}
func TestAesOptionsDefaultToGCM(t *testing.T) {
key := []byte("0123456789abcdef")
@@ -598,6 +672,113 @@ func TestLargeStreamRoundTrip(t *testing.T) {
}
}
func TestAesCTRAtOffsetSegment(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := bytes.Repeat([]byte("aes-ctr-offset-"), 256)
full, err := EncryptAesCTR(plain, key, iv)
if err != nil {
t.Fatalf("EncryptAesCTR failed: %v", err)
}
offset := 137
length := 521
segCipher := full[offset : offset+length]
segPlain, err := DecryptAesCTRAt(segCipher, key, iv, int64(offset))
if err != nil {
t.Fatalf("DecryptAesCTRAt failed: %v", err)
}
if !bytes.Equal(segPlain, plain[offset:offset+length]) {
t.Fatalf("aes ctr offset decrypt mismatch")
}
encSeg, err := EncryptAesCTRAt(plain[offset:offset+length], key, iv, int64(offset))
if err != nil {
t.Fatalf("EncryptAesCTRAt failed: %v", err)
}
if !bytes.Equal(encSeg, segCipher) {
t.Fatalf("aes ctr offset encrypt mismatch")
}
}
func TestSM4CTRAtOffsetSegment(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := bytes.Repeat([]byte("sm4-ctr-offset-"), 256)
full, err := EncryptSM4CTR(plain, key, iv)
if err != nil {
t.Fatalf("EncryptSM4CTR failed: %v", err)
}
offset := 193
length := 487
segCipher := full[offset : offset+length]
segPlain, err := DecryptSM4CTRAt(segCipher, key, iv, int64(offset))
if err != nil {
t.Fatalf("DecryptSM4CTRAt failed: %v", err)
}
if !bytes.Equal(segPlain, plain[offset:offset+length]) {
t.Fatalf("sm4 ctr offset decrypt mismatch")
}
encSeg, err := EncryptSM4CTRAt(plain[offset:offset+length], key, iv, int64(offset))
if err != nil {
t.Fatalf("EncryptSM4CTRAt failed: %v", err)
}
if !bytes.Equal(encSeg, segCipher) {
t.Fatalf("sm4 ctr offset encrypt mismatch")
}
}
func TestAesGCMChunkRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("aes-gcm-chunk")
chunkIndex := uint64(7)
enc, err := EncryptAesGCMChunk(plain, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("EncryptAesGCMChunk failed: %v", err)
}
dec, err := DecryptAesGCMChunk(enc, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("DecryptAesGCMChunk failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("aes gcm chunk decrypt mismatch")
}
if _, err := DecryptAesGCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil {
t.Fatalf("expected decrypt error for wrong chunk index")
}
}
func TestSM4GCMChunkRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("sm4-gcm-chunk")
chunkIndex := uint64(11)
enc, err := EncryptSM4GCMChunk(plain, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("EncryptSM4GCMChunk failed: %v", err)
}
dec, err := DecryptSM4GCMChunk(enc, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("DecryptSM4GCMChunk failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("sm4 gcm chunk decrypt mismatch")
}
if _, err := DecryptSM4GCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil {
t.Fatalf("expected decrypt error for wrong chunk index")
}
}
func mustHex(t *testing.T, s string) []byte {
t.Helper()
b, err := hex.DecodeString(s)
@@ -606,3 +787,553 @@ func mustHex(t *testing.T, s string) []byte {
}
return b
}
func TestAesCFB8RoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := []byte("aes-cfb8-roundtrip-content")
enc, err := EncryptAesCFB8(plain, key, iv)
if err != nil {
t.Fatalf("EncryptAesCFB8 failed: %v", err)
}
dec, err := DecryptAesCFB8(enc, key, iv)
if err != nil {
t.Fatalf("DecryptAesCFB8 failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("aes cfb8 mismatch")
}
}
func TestSM4CFB8RoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := []byte("sm4-cfb8-roundtrip-content")
enc, err := EncryptSM4CFB8(plain, key, iv)
if err != nil {
t.Fatalf("EncryptSM4CFB8 failed: %v", err)
}
dec, err := DecryptSM4CFB8(enc, key, iv)
if err != nil {
t.Fatalf("DecryptSM4CFB8 failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("sm4 cfb8 mismatch")
}
}
func TestAesCFB8StreamRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := bytes.Repeat([]byte("aes-cfb8-stream-"), 512)
enc := &bytes.Buffer{}
if err := EncryptAesCFB8Stream(enc, bytes.NewReader(plain), key, iv); err != nil {
t.Fatalf("EncryptAesCFB8Stream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptAesCFB8Stream(dec, bytes.NewReader(enc.Bytes()), key, iv); err != nil {
t.Fatalf("DecryptAesCFB8Stream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("aes cfb8 stream mismatch")
}
}
func TestSM4CFB8StreamRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := bytes.Repeat([]byte("sm4-cfb8-stream-"), 512)
enc := &bytes.Buffer{}
if err := EncryptSM4CFB8Stream(enc, bytes.NewReader(plain), key, iv); err != nil {
t.Fatalf("EncryptSM4CFB8Stream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptSM4CFB8Stream(dec, bytes.NewReader(enc.Bytes()), key, iv); err != nil {
t.Fatalf("DecryptSM4CFB8Stream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("sm4 cfb8 stream mismatch")
}
}
func TestAesSegmentDecryptModes(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 4)
ecbEnc, err := EncryptAesECB(plain, key, ZEROPADDING)
if err != nil {
t.Fatalf("EncryptAesECB failed: %v", err)
}
ecbDec, err := DecryptAesECBBlocks(ecbEnc, key)
if err != nil {
t.Fatalf("DecryptAesECBBlocks failed: %v", err)
}
if !bytes.Equal(ecbDec, plain) {
t.Fatalf("aes ecb segment mismatch")
}
cbcEnc, err := EncryptAesCBC(plain, key, iv, ZEROPADDING)
if err != nil {
t.Fatalf("EncryptAesCBC failed: %v", err)
}
cbcSegDec, err := DecryptAesCBCFromSecondBlock(cbcEnc[len(iv):], key, cbcEnc[:len(iv)])
if err != nil {
t.Fatalf("DecryptAesCBCFromSecondBlock failed: %v", err)
}
if !bytes.Equal(cbcSegDec, plain[len(iv):]) {
t.Fatalf("aes cbc from-second-block mismatch")
}
cfbEnc, err := EncryptAesCFB(plain, key, iv)
if err != nil {
t.Fatalf("EncryptAesCFB failed: %v", err)
}
cfbSegDec, err := DecryptAesCFBFromSecondBlock(cfbEnc[len(iv):], key, cfbEnc[:len(iv)])
if err != nil {
t.Fatalf("DecryptAesCFBFromSecondBlock failed: %v", err)
}
if !bytes.Equal(cfbSegDec, plain[len(iv):]) {
t.Fatalf("aes cfb from-second-block mismatch")
}
}
func TestSM4SegmentDecryptModes(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 4)
ecbEnc, err := EncryptSM4ECB(plain, key, ZEROPADDING)
if err != nil {
t.Fatalf("EncryptSM4ECB failed: %v", err)
}
ecbDec, err := DecryptSM4ECBBlocks(ecbEnc, key)
if err != nil {
t.Fatalf("DecryptSM4ECBBlocks failed: %v", err)
}
if !bytes.Equal(ecbDec, plain) {
t.Fatalf("sm4 ecb segment mismatch")
}
cbcEnc, err := EncryptSM4CBC(plain, key, iv, ZEROPADDING)
if err != nil {
t.Fatalf("EncryptSM4CBC failed: %v", err)
}
cbcSegDec, err := DecryptSM4CBCFromSecondBlock(cbcEnc[len(iv):], key, cbcEnc[:len(iv)])
if err != nil {
t.Fatalf("DecryptSM4CBCFromSecondBlock failed: %v", err)
}
if !bytes.Equal(cbcSegDec, plain[len(iv):]) {
t.Fatalf("sm4 cbc from-second-block mismatch")
}
cfbEnc, err := EncryptSM4CFBNoBlock(plain, key, iv)
if err != nil {
t.Fatalf("EncryptSM4CFB failed: %v", err)
}
cfbSegDec, err := DecryptSM4CFBFromSecondBlock(cfbEnc[len(iv):], key, cfbEnc[:len(iv)])
if err != nil {
t.Fatalf("DecryptSM4CFBFromSecondBlock failed: %v", err)
}
if !bytes.Equal(cfbSegDec, plain[len(iv):]) {
t.Fatalf("sm4 cfb from-second-block mismatch")
}
}
func TestAesCCMRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("aes-ccm-roundtrip")
enc, err := EncryptAesCCM(plain, key, nonce, aad)
if err != nil {
t.Fatalf("EncryptAesCCM failed: %v", err)
}
dec, err := DecryptAesCCM(enc, key, nonce, aad)
if err != nil {
t.Fatalf("DecryptAesCCM failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("aes ccm mismatch")
}
}
func TestSM4CCMRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("sm4-ccm-roundtrip")
enc, err := EncryptSM4CCM(plain, key, nonce, aad)
if err != nil {
t.Fatalf("EncryptSM4CCM failed: %v", err)
}
dec, err := DecryptSM4CCM(enc, key, nonce, aad)
if err != nil {
t.Fatalf("DecryptSM4CCM failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("sm4 ccm mismatch")
}
}
func TestAesCCMStreamRoundTripChunked(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := bytes.Repeat([]byte("aes-ccm-stream-chunk-"), 10000)
enc := &bytes.Buffer{}
if err := EncryptAesCCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil {
t.Fatalf("EncryptAesCCMStream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptAesCCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil {
t.Fatalf("DecryptAesCCMStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("aes ccm stream mismatch")
}
}
func TestAesCCMStreamLegacyCompatDecrypt(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("aes-ccm-legacy-compat")
legacyCipher, err := EncryptAesCCM(plain, key, nonce, aad)
if err != nil {
t.Fatalf("EncryptAesCCM failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptAesCCMStream(dec, bytes.NewReader(legacyCipher), key, nonce, aad); err != nil {
t.Fatalf("DecryptAesCCMStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("aes ccm legacy decrypt mismatch")
}
}
func TestSM4CCMStreamRoundTripChunked(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := bytes.Repeat([]byte("sm4-ccm-stream-chunk-"), 10000)
enc := &bytes.Buffer{}
if err := EncryptSM4CCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil {
t.Fatalf("EncryptSM4CCMStream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptSM4CCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil {
t.Fatalf("DecryptSM4CCMStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("sm4 ccm stream mismatch")
}
}
func TestAesCCMChunkRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("aes-ccm-chunk")
chunkIndex := uint64(5)
enc, err := EncryptAesCCMChunk(plain, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("EncryptAesCCMChunk failed: %v", err)
}
dec, err := DecryptAesCCMChunk(enc, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("DecryptAesCCMChunk failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("aes ccm chunk decrypt mismatch")
}
if _, err := DecryptAesCCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil {
t.Fatalf("expected decrypt error for wrong chunk index")
}
}
func TestSM4CCMChunkRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("sm4-ccm-chunk")
chunkIndex := uint64(9)
enc, err := EncryptSM4CCMChunk(plain, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("EncryptSM4CCMChunk failed: %v", err)
}
dec, err := DecryptSM4CCMChunk(enc, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("DecryptSM4CCMChunk failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("sm4 ccm chunk decrypt mismatch")
}
if _, err := DecryptSM4CCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil {
t.Fatalf("expected decrypt error for wrong chunk index")
}
}
func TestAesOptionsModeCCM(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("aes-options-ccm")
opts := &CipherOptions{Mode: MODECCM, Nonce: nonce, AAD: aad}
enc, err := EncryptAesWithOptions(plain, key, opts)
if err != nil {
t.Fatalf("EncryptAesWithOptions CCM failed: %v", err)
}
dec, err := DecryptAesWithOptions(enc, key, opts)
if err != nil {
t.Fatalf("DecryptAesWithOptions CCM failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("aes options ccm mismatch")
}
}
func TestSM4OptionsModeCCM(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("sm4-options-ccm")
opts := &CipherOptions{Mode: MODECCM, Nonce: nonce, AAD: aad}
enc, err := EncryptSM4WithOptions(plain, key, opts)
if err != nil {
t.Fatalf("EncryptSM4WithOptions CCM failed: %v", err)
}
dec, err := DecryptSM4WithOptions(enc, key, opts)
if err != nil {
t.Fatalf("DecryptSM4WithOptions CCM failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("sm4 options ccm mismatch")
}
}
func TestCCMInvalidNonceLength(t *testing.T) {
key := []byte("0123456789abcdef")
shortNonce := []byte("short")
if _, err := EncryptAesCCM([]byte("x"), key, shortNonce, nil); err == nil {
t.Fatalf("expected aes ccm nonce length error")
}
if _, err := EncryptSM4CCM([]byte("x"), key, shortNonce, nil); err == nil {
t.Fatalf("expected sm4 ccm nonce length error")
}
}
func TestAesXTSRoundTrip(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 8)
enc, err := EncryptAesXTS(plain, k1, k2, 32)
if err != nil {
t.Fatalf("EncryptAesXTS failed: %v", err)
}
dec, err := DecryptAesXTS(enc, k1, k2, 32)
if err != nil {
t.Fatalf("DecryptAesXTS failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("aes xts mismatch")
}
}
func TestSM4XTSRoundTrip(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 8)
enc, err := EncryptSM4XTS(plain, k1, k2, 32)
if err != nil {
t.Fatalf("EncryptSM4XTS failed: %v", err)
}
dec, err := DecryptSM4XTS(enc, k1, k2, 32)
if err != nil {
t.Fatalf("DecryptSM4XTS failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("sm4 xts mismatch")
}
}
func TestAesXTSAtDataUnit(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 8)
dataUnitSize := 32
full, err := EncryptAesXTS(plain, k1, k2, dataUnitSize)
if err != nil {
t.Fatalf("EncryptAesXTS failed: %v", err)
}
segPlain := plain[64:96]
segEnc, err := EncryptAesXTSAt(segPlain, k1, k2, dataUnitSize, 2)
if err != nil {
t.Fatalf("EncryptAesXTSAt failed: %v", err)
}
if !bytes.Equal(segEnc, full[64:96]) {
t.Fatalf("aes xts at encrypt mismatch")
}
segDec, err := DecryptAesXTSAt(full[64:96], k1, k2, dataUnitSize, 2)
if err != nil {
t.Fatalf("DecryptAesXTSAt failed: %v", err)
}
if !bytes.Equal(segDec, segPlain) {
t.Fatalf("aes xts at decrypt mismatch")
}
}
func TestSM4XTSAtDataUnit(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 8)
dataUnitSize := 32
full, err := EncryptSM4XTS(plain, k1, k2, dataUnitSize)
if err != nil {
t.Fatalf("EncryptSM4XTS failed: %v", err)
}
segPlain := plain[32:64]
segEnc, err := EncryptSM4XTSAt(segPlain, k1, k2, dataUnitSize, 1)
if err != nil {
t.Fatalf("EncryptSM4XTSAt failed: %v", err)
}
if !bytes.Equal(segEnc, full[32:64]) {
t.Fatalf("sm4 xts at encrypt mismatch")
}
segDec, err := DecryptSM4XTSAt(full[32:64], k1, k2, dataUnitSize, 1)
if err != nil {
t.Fatalf("DecryptSM4XTSAt failed: %v", err)
}
if !bytes.Equal(segDec, segPlain) {
t.Fatalf("sm4 xts at decrypt mismatch")
}
}
func TestAesXTSStreamRoundTrip(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 2048)
enc := &bytes.Buffer{}
if err := EncryptAesXTSStream(enc, bytes.NewReader(plain), k1, k2, 512); err != nil {
t.Fatalf("EncryptAesXTSStream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptAesXTSStream(dec, bytes.NewReader(enc.Bytes()), k1, k2, 512); err != nil {
t.Fatalf("DecryptAesXTSStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("aes xts stream mismatch")
}
}
func TestSM4XTSStreamRoundTrip(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 2048)
enc := &bytes.Buffer{}
if err := EncryptSM4XTSStream(enc, bytes.NewReader(plain), k1, k2, 512); err != nil {
t.Fatalf("EncryptSM4XTSStream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptSM4XTSStream(dec, bytes.NewReader(enc.Bytes()), k1, k2, 512); err != nil {
t.Fatalf("DecryptSM4XTSStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("sm4 xts stream mismatch")
}
}
func TestXTSRejectsNonBlockMultiple(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
if _, err := EncryptAesXTS([]byte("short"), k1, k2, 32); err == nil {
t.Fatalf("expected aes xts non-block error")
}
if _, err := EncryptSM4XTS([]byte("short"), k1, k2, 32); err == nil {
t.Fatalf("expected sm4 xts non-block error")
}
}
func TestXTSStreamRejectsTailNotFullBlock(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
err := EncryptAesXTSStream(&bytes.Buffer{}, bytes.NewReader([]byte("tail-not-16")), k1, k2, 32)
if err == nil {
t.Fatalf("expected aes xts stream tail error")
}
err = EncryptSM4XTSStream(&bytes.Buffer{}, bytes.NewReader([]byte("tail-not-16")), k1, k2, 32)
if err == nil {
t.Fatalf("expected sm4 xts stream tail error")
}
}
func TestXTSInvalidDataUnitSize(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 2)
if _, err := EncryptAesXTS(plain, k1, k2, 30); err == nil {
t.Fatalf("expected aes xts invalid data unit size error")
}
if _, err := EncryptSM4XTS(plain, k1, k2, 30); err == nil {
t.Fatalf("expected sm4 xts invalid data unit size error")
}
}
func TestSplitXTSMasterKeyHelpers(t *testing.T) {
master := []byte("0123456789abcdef0123456789abcdef")
k1, k2, err := SplitXTSMasterKey(master)
if err != nil {
t.Fatalf("SplitXTSMasterKey failed: %v", err)
}
if len(k1) != 16 || len(k2) != 16 {
t.Fatalf("split key lengths mismatch")
}
if !bytes.Equal(append(k1, k2...), master) {
t.Fatalf("split key content mismatch")
}
aesK1, aesK2, err := SplitAesXTSMasterKey(master)
if err != nil {
t.Fatalf("SplitAesXTSMasterKey failed: %v", err)
}
if !bytes.Equal(aesK1, k1) || !bytes.Equal(aesK2, k2) {
t.Fatalf("aes split mismatch")
}
sm4K1, sm4K2, err := SplitSM4XTSMasterKey(master)
if err != nil {
t.Fatalf("SplitSM4XTSMasterKey failed: %v", err)
}
if !bytes.Equal(sm4K1, k1) || !bytes.Equal(sm4K2, k2) {
t.Fatalf("sm4 split mismatch")
}
if _, _, err := SplitXTSMasterKey([]byte("abc")); err == nil {
t.Fatalf("expected odd-length split error")
}
if _, _, err := SplitAesXTSMasterKey([]byte("0123456789abcdef0123456789abcdef01")); err == nil {
t.Fatalf("expected aes master length error")
}
if _, _, err := SplitSM4XTSMasterKey([]byte("0123456789abcdef")); err == nil {
t.Fatalf("expected sm4 master length error")
}
}
+263
View File
@@ -0,0 +1,263 @@
package symm
import (
"crypto/aes"
"crypto/cipher"
"errors"
"io"
"github.com/emmansun/gmsm/sm4"
"golang.org/x/crypto/xts"
)
const xtsBlockSize = 16
var (
ErrInvalidXTSDataUnitSize = errors.New("xts data unit size must be a positive multiple of 16")
ErrInvalidXTSDataLength = errors.New("xts data length must be a multiple of 16")
ErrInvalidXTSKeyLength = errors.New("xts key lengths must be non-empty and equal")
ErrInvalidXTSMasterKeyLength = errors.New("xts master key length must be non-empty and even")
ErrInvalidAESXTSMasterKeyLength = errors.New("aes xts master key length must be 32, 48, or 64 bytes")
ErrInvalidSM4XTSMasterKeyLength = errors.New("sm4 xts master key length must be 32 bytes")
)
type xtsCipherFactory func(key1, key2 []byte) (*xts.Cipher, error)
func combineXTSKeys(key1, key2 []byte) ([]byte, error) {
if len(key1) == 0 || len(key1) != len(key2) {
return nil, ErrInvalidXTSKeyLength
}
out := make([]byte, len(key1)+len(key2))
copy(out, key1)
copy(out[len(key1):], key2)
return out, nil
}
func splitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
if len(masterKey) == 0 || len(masterKey)%2 != 0 {
return nil, nil, ErrInvalidXTSMasterKeyLength
}
half := len(masterKey) / 2
k1 := make([]byte, half)
k2 := make([]byte, half)
copy(k1, masterKey[:half])
copy(k2, masterKey[half:])
return k1, k2, nil
}
// SplitXTSMasterKey splits a master key into two equal XTS keys.
func SplitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
return splitXTSMasterKey(masterKey)
}
// SplitAesXTSMasterKey splits AES-XTS master key and validates length (32/48/64 bytes).
func SplitAesXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
switch len(masterKey) {
case 32, 48, 64:
return splitXTSMasterKey(masterKey)
default:
return nil, nil, ErrInvalidAESXTSMasterKeyLength
}
}
// SplitSM4XTSMasterKey splits SM4-XTS master key and validates length (32 bytes).
func SplitSM4XTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
if len(masterKey) != 32 {
return nil, nil, ErrInvalidSM4XTSMasterKeyLength
}
return splitXTSMasterKey(masterKey)
}
func validateXTSDataUnitSize(dataUnitSize int) error {
if dataUnitSize <= 0 || dataUnitSize%xtsBlockSize != 0 {
return ErrInvalidXTSDataUnitSize
}
return nil
}
func validateXTSDataLength(data []byte) error {
if len(data)%xtsBlockSize != 0 {
return ErrInvalidXTSDataLength
}
return nil
}
func cryptXTSAt(c *xts.Cipher, in []byte, dataUnitSize int, dataUnitIndex uint64, decrypt bool) ([]byte, error) {
if err := validateXTSDataUnitSize(dataUnitSize); err != nil {
return nil, err
}
if err := validateXTSDataLength(in); err != nil {
return nil, err
}
if len(in) == 0 {
return []byte{}, nil
}
out := make([]byte, len(in))
off := 0
unit := dataUnitIndex
for off < len(in) {
chunkLen := dataUnitSize
if remain := len(in) - off; remain < chunkLen {
chunkLen = remain
}
if decrypt {
c.Decrypt(out[off:off+chunkLen], in[off:off+chunkLen], unit)
} else {
c.Encrypt(out[off:off+chunkLen], in[off:off+chunkLen], unit)
}
off += chunkLen
unit++
}
return out, nil
}
func cryptXTSStreamAt(dst io.Writer, src io.Reader, c *xts.Cipher, dataUnitSize int, dataUnitIndex uint64, decrypt bool) error {
if err := validateXTSDataUnitSize(dataUnitSize); err != nil {
return err
}
buf := make([]byte, 32*1024)
pending := make([]byte, 0, dataUnitSize*2)
unit := dataUnitIndex
for {
n, err := src.Read(buf)
if n > 0 {
pending = append(pending, buf[:n]...)
processLen := len(pending) / dataUnitSize * dataUnitSize
if processLen > 0 {
out := make([]byte, processLen)
for off := 0; off < processLen; off += dataUnitSize {
if decrypt {
c.Decrypt(out[off:off+dataUnitSize], pending[off:off+dataUnitSize], unit)
} else {
c.Encrypt(out[off:off+dataUnitSize], pending[off:off+dataUnitSize], unit)
}
unit++
}
if _, werr := dst.Write(out); werr != nil {
return werr
}
pending = append([]byte(nil), pending[processLen:]...)
}
}
if err != nil {
if err == io.EOF {
break
}
return err
}
}
if len(pending) == 0 {
return nil
}
if err := validateXTSDataLength(pending); err != nil {
return err
}
out := make([]byte, len(pending))
if decrypt {
c.Decrypt(out, pending, unit)
} else {
c.Encrypt(out, pending, unit)
}
_, err := dst.Write(out)
return err
}
func newXTSCipher(newBlock func([]byte) (cipher.Block, error), key1, key2 []byte) (*xts.Cipher, error) {
key, err := combineXTSKeys(key1, key2)
if err != nil {
return nil, err
}
return xts.NewCipher(newBlock, key)
}
func newAesXTS(key1, key2 []byte) (*xts.Cipher, error) {
return newXTSCipher(aes.NewCipher, key1, key2)
}
func newSM4XTS(key1, key2 []byte) (*xts.Cipher, error) {
return newXTSCipher(sm4.NewCipher, key1, key2)
}
func cryptXTSAtWithFactory(factory xtsCipherFactory, in []byte, dataUnitSize int, dataUnitIndex uint64, decrypt bool, key1, key2 []byte) ([]byte, error) {
c, err := factory(key1, key2)
if err != nil {
return nil, err
}
return cryptXTSAt(c, in, dataUnitSize, dataUnitIndex, decrypt)
}
func cryptXTSStreamAtWithFactory(factory xtsCipherFactory, dst io.Writer, src io.Reader, dataUnitSize int, dataUnitIndex uint64, decrypt bool, key1, key2 []byte) error {
c, err := factory(key1, key2)
if err != nil {
return err
}
return cryptXTSStreamAt(dst, src, c, dataUnitSize, dataUnitIndex, decrypt)
}
func EncryptAesXTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
return EncryptAesXTSAt(plain, key1, key2, dataUnitSize, 0)
}
func DecryptAesXTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
return DecryptAesXTSAt(ciphertext, key1, key2, dataUnitSize, 0)
}
func EncryptAesXTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
return cryptXTSAtWithFactory(newAesXTS, plain, dataUnitSize, dataUnitIndex, false, key1, key2)
}
func DecryptAesXTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
return cryptXTSAtWithFactory(newAesXTS, ciphertext, dataUnitSize, dataUnitIndex, true, key1, key2)
}
func EncryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
return EncryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
}
func DecryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
return DecryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
}
func EncryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
return cryptXTSStreamAtWithFactory(newAesXTS, dst, src, dataUnitSize, dataUnitIndex, false, key1, key2)
}
func DecryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
return cryptXTSStreamAtWithFactory(newAesXTS, dst, src, dataUnitSize, dataUnitIndex, true, key1, key2)
}
func EncryptSM4XTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
return EncryptSM4XTSAt(plain, key1, key2, dataUnitSize, 0)
}
func DecryptSM4XTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
return DecryptSM4XTSAt(ciphertext, key1, key2, dataUnitSize, 0)
}
func EncryptSM4XTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
return cryptXTSAtWithFactory(newSM4XTS, plain, dataUnitSize, dataUnitIndex, false, key1, key2)
}
func DecryptSM4XTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
return cryptXTSAtWithFactory(newSM4XTS, ciphertext, dataUnitSize, dataUnitIndex, true, key1, key2)
}
func EncryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
return EncryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
}
func DecryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
return DecryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
}
func EncryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
return cryptXTSStreamAtWithFactory(newSM4XTS, dst, src, dataUnitSize, dataUnitIndex, false, key1, key2)
}
func DecryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
return cryptXTSStreamAtWithFactory(newSM4XTS, dst, src, dataUnitSize, dataUnitIndex, true, key1, key2)
}
+79
View File
@@ -0,0 +1,79 @@
package symm
import (
"bytes"
"testing"
)
// AES-XTS vectors from IEEE P1619/D16 Annex B (same set used by golang.org/x/crypto/xts tests).
var aesXTSStandardVectors = []struct {
key string
dataUnitIndex uint64
plaintext string
ciphertext string
}{
{
key: "0000000000000000000000000000000000000000000000000000000000000000",
dataUnitIndex: 0,
plaintext: "0000000000000000000000000000000000000000000000000000000000000000",
ciphertext: "917cf69ebd68b2ec9b9fe9a3eadda692cd43d2f59598ed858c02c2652fbf922e",
},
{
key: "1111111111111111111111111111111122222222222222222222222222222222",
dataUnitIndex: 0x3333333333,
plaintext: "4444444444444444444444444444444444444444444444444444444444444444",
ciphertext: "c454185e6a16936e39334038acef838bfb186fff7480adc4289382ecd6d394f0",
},
{
key: "fffefdfcfbfaf9f8f7f6f5f4f3f2f1f022222222222222222222222222222222",
dataUnitIndex: 0x3333333333,
plaintext: "4444444444444444444444444444444444444444444444444444444444444444",
ciphertext: "af85336b597afc1a900b2eb21ec949d292df4c047e0b21532186a5971a227a89",
},
}
func TestAesXTSStandardVectors(t *testing.T) {
for i, tc := range aesXTSStandardVectors {
master := mustHex(t, tc.key)
key1, key2, err := SplitAesXTSMasterKey(master)
if err != nil {
t.Fatalf("#%d split key failed: %v", i, err)
}
plain := mustHex(t, tc.plaintext)
wantCipher := mustHex(t, tc.ciphertext)
dataUnitSize := len(plain)
gotCipher, err := EncryptAesXTSAt(plain, key1, key2, dataUnitSize, tc.dataUnitIndex)
if err != nil {
t.Fatalf("#%d EncryptAesXTSAt failed: %v", i, err)
}
if !bytes.Equal(gotCipher, wantCipher) {
t.Fatalf("#%d ciphertext mismatch", i)
}
gotPlain, err := DecryptAesXTSAt(wantCipher, key1, key2, dataUnitSize, tc.dataUnitIndex)
if err != nil {
t.Fatalf("#%d DecryptAesXTSAt failed: %v", i, err)
}
if !bytes.Equal(gotPlain, plain) {
t.Fatalf("#%d plaintext mismatch", i)
}
encStream := &bytes.Buffer{}
if err := EncryptAesXTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, tc.dataUnitIndex); err != nil {
t.Fatalf("#%d EncryptAesXTSStreamAt failed: %v", i, err)
}
if !bytes.Equal(encStream.Bytes(), wantCipher) {
t.Fatalf("#%d stream ciphertext mismatch", i)
}
decStream := &bytes.Buffer{}
if err := DecryptAesXTSStreamAt(decStream, bytes.NewReader(wantCipher), key1, key2, dataUnitSize, tc.dataUnitIndex); err != nil {
t.Fatalf("#%d DecryptAesXTSStreamAt failed: %v", i, err)
}
if !bytes.Equal(decStream.Bytes(), plain) {
t.Fatalf("#%d stream plaintext mismatch", i)
}
}
}