feat: 新增XTS/CCM流式与KDF能力,补充安全测试并更新README/CHANGELOG
This commit is contained in:
+234
-9
@@ -7,6 +7,7 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"b612.me/starcrypto/ccm"
|
||||
"b612.me/starcrypto/paddingx"
|
||||
)
|
||||
|
||||
@@ -17,13 +18,27 @@ const (
|
||||
ANSIX923PADDING = paddingx.ANSIX923
|
||||
)
|
||||
|
||||
var ErrInvalidGCMNonceLength = errors.New("gcm nonce length must be 12 bytes")
|
||||
var (
|
||||
ErrInvalidGCMNonceLength = errors.New("gcm nonce length must be 12 bytes")
|
||||
ErrInvalidCCMNonceLength = errors.New("ccm nonce length must be 12 bytes")
|
||||
)
|
||||
|
||||
const (
|
||||
aeadCCMTagSize = 16
|
||||
aeadCCMNonceSize = 12
|
||||
)
|
||||
|
||||
func EncryptAes(data, key, iv []byte, mode, paddingType string) ([]byte, error) {
|
||||
normalizedMode := normalizeCipherMode(mode)
|
||||
if normalizedMode == "" {
|
||||
normalizedMode = MODEGCM
|
||||
}
|
||||
if normalizedMode == MODEGCM {
|
||||
return EncryptAesGCM(data, key, iv, nil)
|
||||
}
|
||||
if normalizedMode == MODECCM {
|
||||
return EncryptAesCCM(data, key, iv, nil)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
@@ -34,9 +49,15 @@ func EncryptAes(data, key, iv []byte, mode, paddingType string) ([]byte, error)
|
||||
|
||||
func DecryptAes(src, key, iv []byte, mode, paddingType string) ([]byte, error) {
|
||||
normalizedMode := normalizeCipherMode(mode)
|
||||
if normalizedMode == "" {
|
||||
normalizedMode = MODEGCM
|
||||
}
|
||||
if normalizedMode == MODEGCM {
|
||||
return DecryptAesGCM(src, key, iv, nil)
|
||||
}
|
||||
if normalizedMode == MODECCM {
|
||||
return DecryptAesCCM(src, key, iv, nil)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
@@ -47,9 +68,15 @@ func DecryptAes(src, key, iv []byte, mode, paddingType string) ([]byte, error) {
|
||||
|
||||
func EncryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error {
|
||||
normalizedMode := normalizeCipherMode(mode)
|
||||
if normalizedMode == "" {
|
||||
normalizedMode = MODEGCM
|
||||
}
|
||||
if normalizedMode == MODEGCM {
|
||||
return EncryptAesGCMStream(dst, src, key, iv, nil)
|
||||
}
|
||||
if normalizedMode == MODECCM {
|
||||
return EncryptAesCCMStream(dst, src, key, iv, nil)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
@@ -60,9 +87,15 @@ func EncryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddin
|
||||
|
||||
func DecryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error {
|
||||
normalizedMode := normalizeCipherMode(mode)
|
||||
if normalizedMode == "" {
|
||||
normalizedMode = MODEGCM
|
||||
}
|
||||
if normalizedMode == MODEGCM {
|
||||
return DecryptAesGCMStream(dst, src, key, iv, nil)
|
||||
}
|
||||
if normalizedMode == MODECCM {
|
||||
return DecryptAesCCMStream(dst, src, key, iv, nil)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
@@ -80,6 +113,9 @@ func EncryptAesWithOptions(data, key []byte, opts *CipherOptions) ([]byte, error
|
||||
if mode == MODEGCM {
|
||||
return EncryptAesGCM(data, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
if mode == MODECCM {
|
||||
return EncryptAesCCM(data, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
return EncryptAes(data, key, cfg.IV, mode, cfg.Padding)
|
||||
}
|
||||
|
||||
@@ -92,6 +128,9 @@ func DecryptAesWithOptions(src, key []byte, opts *CipherOptions) ([]byte, error)
|
||||
if mode == MODEGCM {
|
||||
return DecryptAesGCM(src, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
if mode == MODECCM {
|
||||
return DecryptAesCCM(src, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
return DecryptAes(src, key, cfg.IV, mode, cfg.Padding)
|
||||
}
|
||||
|
||||
@@ -104,6 +143,9 @@ func EncryptAesStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts
|
||||
if mode == MODEGCM {
|
||||
return EncryptAesGCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
if mode == MODECCM {
|
||||
return EncryptAesCCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
return EncryptAesStream(dst, src, key, cfg.IV, mode, cfg.Padding)
|
||||
}
|
||||
|
||||
@@ -116,6 +158,9 @@ func DecryptAesStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts
|
||||
if mode == MODEGCM {
|
||||
return DecryptAesGCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
if mode == MODECCM {
|
||||
return DecryptAesCCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
return DecryptAesStream(dst, src, key, cfg.IV, mode, cfg.Padding)
|
||||
}
|
||||
|
||||
@@ -149,30 +194,138 @@ func DecryptAesGCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
|
||||
return gcm.Open(nil, nonce, ciphertext, aad)
|
||||
}
|
||||
|
||||
func newAesCCM(key []byte) (cipher.AEAD, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ccm.NewCCM(block, aeadCCMTagSize, aeadCCMNonceSize)
|
||||
}
|
||||
|
||||
func EncryptAesCCM(plain, key, nonce, aad []byte) ([]byte, error) {
|
||||
aead, err := newAesCCM(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nonce) != aead.NonceSize() {
|
||||
return nil, ErrInvalidCCMNonceLength
|
||||
}
|
||||
return aead.Seal(nil, nonce, plain, aad), nil
|
||||
}
|
||||
|
||||
func DecryptAesCCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
|
||||
aead, err := newAesCCM(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nonce) != aead.NonceSize() {
|
||||
return nil, ErrInvalidCCMNonceLength
|
||||
}
|
||||
return aead.Open(nil, nonce, ciphertext, aad)
|
||||
}
|
||||
|
||||
func EncryptAesCCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
|
||||
aead, err := newAesCCM(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nonce) != aead.NonceSize() {
|
||||
return nil, ErrInvalidCCMNonceLength
|
||||
}
|
||||
return encryptCCMChunk(aead, plain, nonce, aad, chunkIndex), nil
|
||||
}
|
||||
|
||||
func DecryptAesCCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
|
||||
aead, err := newAesCCM(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nonce) != aead.NonceSize() {
|
||||
return nil, ErrInvalidCCMNonceLength
|
||||
}
|
||||
return decryptCCMChunk(aead, ciphertext, nonce, aad, chunkIndex)
|
||||
}
|
||||
|
||||
func EncryptAesCCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
|
||||
aead, err := newAesCCM(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(nonce) != aead.NonceSize() {
|
||||
return ErrInvalidCCMNonceLength
|
||||
}
|
||||
return encryptCCMChunkedStream(dst, src, aead, nonce, aad)
|
||||
}
|
||||
|
||||
func DecryptAesCCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
|
||||
aead, err := newAesCCM(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(nonce) != aead.NonceSize() {
|
||||
return ErrInvalidCCMNonceLength
|
||||
}
|
||||
return decryptCCMChunkedOrLegacyStream(dst, src, aead, nonce, aad)
|
||||
}
|
||||
|
||||
func EncryptAesGCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nonce) != gcm.NonceSize() {
|
||||
return nil, ErrInvalidGCMNonceLength
|
||||
}
|
||||
return encryptGCMChunk(gcm, plain, nonce, aad, chunkIndex), nil
|
||||
}
|
||||
|
||||
func DecryptAesGCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nonce) != gcm.NonceSize() {
|
||||
return nil, ErrInvalidGCMNonceLength
|
||||
}
|
||||
return decryptGCMChunk(gcm, ciphertext, nonce, aad, chunkIndex)
|
||||
}
|
||||
|
||||
func EncryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
|
||||
plain, err := io.ReadAll(src)
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
out, err := EncryptAesGCM(plain, key, nonce, aad)
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = dst.Write(out)
|
||||
return err
|
||||
if len(nonce) != gcm.NonceSize() {
|
||||
return ErrInvalidGCMNonceLength
|
||||
}
|
||||
return encryptGCMChunkedStream(dst, src, gcm, nonce, aad)
|
||||
}
|
||||
|
||||
func DecryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
|
||||
enc, err := io.ReadAll(src)
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
out, err := DecryptAesGCM(enc, key, nonce, aad)
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = dst.Write(out)
|
||||
return err
|
||||
if len(nonce) != gcm.NonceSize() {
|
||||
return ErrInvalidGCMNonceLength
|
||||
}
|
||||
return decryptGCMChunkedOrLegacyStream(dst, src, gcm, nonce, aad)
|
||||
}
|
||||
|
||||
func EncryptAesECB(data, key []byte, paddingType string) ([]byte, error) {
|
||||
@@ -215,6 +368,18 @@ func DecryptAesCTR(src, key, iv []byte) ([]byte, error) {
|
||||
return DecryptAes(src, key, iv, MODECTR, "")
|
||||
}
|
||||
|
||||
func EncryptAesCTRAt(data, key, iv []byte, offset int64) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return xorCTRAtOffset(block, data, iv, offset)
|
||||
}
|
||||
|
||||
func DecryptAesCTRAt(src, key, iv []byte, offset int64) ([]byte, error) {
|
||||
return EncryptAesCTRAt(src, key, iv, offset)
|
||||
}
|
||||
|
||||
func EncryptAesECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error {
|
||||
return EncryptAesStream(dst, src, key, nil, MODEECB, paddingType)
|
||||
}
|
||||
@@ -329,3 +494,63 @@ func PKCS7Trimming(encrypted []byte, blockSize int) []byte {
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func EncryptAesCFB8(data, key, iv []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encryptCFB8(block, data, iv)
|
||||
}
|
||||
|
||||
func DecryptAesCFB8(src, key, iv []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decryptCFB8(block, src, iv)
|
||||
}
|
||||
|
||||
func EncryptAesCFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return encryptCFB8Stream(block, dst, src, iv, false)
|
||||
}
|
||||
|
||||
func DecryptAesCFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return encryptCFB8Stream(block, dst, src, iv, true)
|
||||
}
|
||||
|
||||
func DecryptAesECBBlocks(src, key []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decryptECBBlocks(block, src)
|
||||
}
|
||||
|
||||
// DecryptAesCBCFromSecondBlock decrypts a CBC ciphertext segment that starts from block 2 or later.
|
||||
// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock.
|
||||
func DecryptAesCBCFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decryptCBCFromSecondBlock(block, src, prevCipherBlock)
|
||||
}
|
||||
|
||||
// DecryptAesCFBFromSecondBlock decrypts a CFB ciphertext segment that starts from block 2 or later.
|
||||
// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock.
|
||||
func DecryptAesCFBFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decryptCFBFromSecondBlock(block, src, prevCipherBlock)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
package symm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
benchPlain4K = bytes.Repeat([]byte("0123456789abcdef"), 256) // 4 KiB
|
||||
benchPlain256K = bytes.Repeat([]byte("0123456789abcdef"), 16384) // 256 KiB
|
||||
benchAAD = []byte("benchmark-aad")
|
||||
benchAESKey = []byte("0123456789abcdef")
|
||||
benchSM4Key = []byte("0123456789abcdef")
|
||||
benchXTSKey2 = []byte("fedcba9876543210")
|
||||
benchNonce12Byte = []byte("123456789012")
|
||||
)
|
||||
|
||||
func BenchmarkAesGCMEncrypt4K(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(benchPlain4K)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := EncryptAesGCM(benchPlain4K, benchAESKey, benchNonce12Byte, benchAAD); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAesCCMEncrypt4K(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(benchPlain4K)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := EncryptAesCCM(benchPlain4K, benchAESKey, benchNonce12Byte, benchAAD); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAesXTSEncrypt4K(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(benchPlain4K)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := EncryptAesXTS(benchPlain4K, benchAESKey, benchXTSKey2, 512); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSM4GCMEncrypt4K(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(benchPlain4K)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := EncryptSM4GCM(benchPlain4K, benchSM4Key, benchNonce12Byte, benchAAD); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSM4CCMEncrypt4K(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(benchPlain4K)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := EncryptSM4CCM(benchPlain4K, benchSM4Key, benchNonce12Byte, benchAAD); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSM4XTSEncrypt4K(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(benchPlain4K)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := EncryptSM4XTS(benchPlain4K, benchSM4Key, benchXTSKey2, 512); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkAesXTSStreamEncrypt256K(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(benchPlain256K)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
var dst bytes.Buffer
|
||||
if err := EncryptAesXTSStream(&dst, bytes.NewReader(benchPlain256K), benchAESKey, benchXTSKey2, 512); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSM4XTSStreamEncrypt256K(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(len(benchPlain256K)))
|
||||
for i := 0; i < b.N; i++ {
|
||||
var dst bytes.Buffer
|
||||
if err := EncryptSM4XTSStream(&dst, bytes.NewReader(benchPlain256K), benchSM4Key, benchXTSKey2, 512); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package symm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
const ccmStreamMagic = "SCC1"
|
||||
|
||||
var ErrInvalidCCMStreamChunk = errors.New("invalid ccm stream chunk")
|
||||
|
||||
func encryptCCMChunk(aead cipher.AEAD, plain, nonce, aad []byte, chunkIndex uint64) []byte {
|
||||
chunkNonce := deriveChunkNonce(nonce, chunkIndex)
|
||||
return aead.Seal(nil, chunkNonce, plain, aad)
|
||||
}
|
||||
|
||||
func decryptCCMChunk(aead cipher.AEAD, ciphertext, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
|
||||
chunkNonce := deriveChunkNonce(nonce, chunkIndex)
|
||||
return aead.Open(nil, chunkNonce, ciphertext, aad)
|
||||
}
|
||||
|
||||
func encryptCCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
|
||||
if _, err := dst.Write([]byte(ccmStreamMagic)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, gcmStreamChunkSize)
|
||||
lenBuf := make([]byte, 4)
|
||||
var chunkIndex uint64
|
||||
|
||||
for {
|
||||
n, err := src.Read(buf)
|
||||
if n > 0 {
|
||||
sealed := encryptCCMChunk(aead, buf[:n], nonce, aad, chunkIndex)
|
||||
binary.BigEndian.PutUint32(lenBuf, uint32(len(sealed)))
|
||||
if _, werr := dst.Write(lenBuf); werr != nil {
|
||||
return werr
|
||||
}
|
||||
if _, werr := dst.Write(sealed); werr != nil {
|
||||
return werr
|
||||
}
|
||||
chunkIndex++
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decryptCCMChunkedOrLegacyStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
|
||||
header := make([]byte, len(ccmStreamMagic))
|
||||
n, err := io.ReadFull(src, header)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != io.ErrUnexpectedEOF {
|
||||
return err
|
||||
}
|
||||
return decryptCCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header[:n]), src), aead, nonce, aad)
|
||||
}
|
||||
|
||||
if string(header) != ccmStreamMagic {
|
||||
return decryptCCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header), src), aead, nonce, aad)
|
||||
}
|
||||
return decryptCCMChunkedStream(dst, src, aead, nonce, aad)
|
||||
}
|
||||
|
||||
func decryptCCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
|
||||
lenBuf := make([]byte, 4)
|
||||
maxChunkLen := uint32(gcmStreamChunkSize + aead.Overhead())
|
||||
var chunkIndex uint64
|
||||
|
||||
for {
|
||||
_, err := io.ReadFull(src, lenBuf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
chunkLen := binary.BigEndian.Uint32(lenBuf)
|
||||
if chunkLen < uint32(aead.Overhead()) || chunkLen > maxChunkLen {
|
||||
return ErrInvalidCCMStreamChunk
|
||||
}
|
||||
|
||||
chunk := make([]byte, chunkLen)
|
||||
if _, err := io.ReadFull(src, chunk); err != nil {
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
plain, err := decryptCCMChunk(aead, chunk, nonce, aad, chunkIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := dst.Write(plain); err != nil {
|
||||
return err
|
||||
}
|
||||
chunkIndex++
|
||||
}
|
||||
}
|
||||
|
||||
func decryptCCMLegacyBuffered(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
|
||||
enc, err := io.ReadAll(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
plain, err := aead.Open(nil, nonce, enc, aad)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = dst.Write(plain)
|
||||
return err
|
||||
}
|
||||
+103
@@ -0,0 +1,103 @@
|
||||
package symm
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
func encryptCFB8(block cipher.Block, data, iv []byte) ([]byte, error) {
|
||||
if len(iv) != block.BlockSize() {
|
||||
return nil, errors.New("iv length must match block size")
|
||||
}
|
||||
reg := make([]byte, len(iv))
|
||||
copy(reg, iv)
|
||||
regView := make([]byte, block.BlockSize())
|
||||
streamBlock := make([]byte, block.BlockSize())
|
||||
out := make([]byte, len(data))
|
||||
head := 0
|
||||
|
||||
for i := 0; i < len(data); i++ {
|
||||
buildCFB8Register(regView, reg, head)
|
||||
block.Encrypt(streamBlock, regView)
|
||||
c := data[i] ^ streamBlock[0]
|
||||
out[i] = c
|
||||
advanceCFB8Register(reg, &head, c)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func decryptCFB8(block cipher.Block, src, iv []byte) ([]byte, error) {
|
||||
if len(iv) != block.BlockSize() {
|
||||
return nil, errors.New("iv length must match block size")
|
||||
}
|
||||
reg := make([]byte, len(iv))
|
||||
copy(reg, iv)
|
||||
regView := make([]byte, block.BlockSize())
|
||||
streamBlock := make([]byte, block.BlockSize())
|
||||
out := make([]byte, len(src))
|
||||
head := 0
|
||||
|
||||
for i := 0; i < len(src); i++ {
|
||||
buildCFB8Register(regView, reg, head)
|
||||
block.Encrypt(streamBlock, regView)
|
||||
p := src[i] ^ streamBlock[0]
|
||||
out[i] = p
|
||||
advanceCFB8Register(reg, &head, src[i])
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func encryptCFB8Stream(block cipher.Block, dst io.Writer, src io.Reader, iv []byte, decrypt bool) error {
|
||||
if len(iv) != block.BlockSize() {
|
||||
return errors.New("iv length must match block size")
|
||||
}
|
||||
reg := make([]byte, len(iv))
|
||||
copy(reg, iv)
|
||||
regView := make([]byte, block.BlockSize())
|
||||
streamBlock := make([]byte, block.BlockSize())
|
||||
buf := make([]byte, 32*1024)
|
||||
out := make([]byte, 32*1024)
|
||||
head := 0
|
||||
|
||||
for {
|
||||
n, err := src.Read(buf)
|
||||
if n > 0 {
|
||||
for i := 0; i < n; i++ {
|
||||
buildCFB8Register(regView, reg, head)
|
||||
block.Encrypt(streamBlock, regView)
|
||||
if decrypt {
|
||||
out[i] = buf[i] ^ streamBlock[0]
|
||||
advanceCFB8Register(reg, &head, buf[i])
|
||||
} else {
|
||||
c := buf[i] ^ streamBlock[0]
|
||||
out[i] = c
|
||||
advanceCFB8Register(reg, &head, c)
|
||||
}
|
||||
}
|
||||
if _, werr := dst.Write(out[:n]); werr != nil {
|
||||
return werr
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func buildCFB8Register(dst, reg []byte, head int) {
|
||||
first := len(reg) - head
|
||||
copy(dst, reg[head:])
|
||||
copy(dst[first:], reg[:head])
|
||||
}
|
||||
|
||||
func advanceCFB8Register(reg []byte, head *int, feedback byte) {
|
||||
reg[*head] = feedback
|
||||
*head = *head + 1
|
||||
if *head == len(reg) {
|
||||
*head = 0
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
package symm
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidCTROffset = errors.New("ctr offset must be non-negative")
|
||||
ErrCTRCounterOverflow = errors.New("ctr counter overflow")
|
||||
)
|
||||
|
||||
func xorCTRAtOffset(block cipher.Block, data, iv []byte, offset int64) ([]byte, error) {
|
||||
if offset < 0 {
|
||||
return nil, ErrInvalidCTROffset
|
||||
}
|
||||
if len(iv) != block.BlockSize() {
|
||||
return nil, errors.New("iv length must match block size")
|
||||
}
|
||||
|
||||
counter := make([]byte, len(iv))
|
||||
copy(counter, iv)
|
||||
|
||||
blockSize := int64(block.BlockSize())
|
||||
blockOffset := uint64(offset / blockSize)
|
||||
byteOffset := int(offset % blockSize)
|
||||
if err := addUint64ToCounter(counter, blockOffset); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stream := cipher.NewCTR(block, counter)
|
||||
if byteOffset > 0 {
|
||||
skip := make([]byte, byteOffset)
|
||||
stream.XORKeyStream(skip, skip)
|
||||
}
|
||||
|
||||
out := make([]byte, len(data))
|
||||
stream.XORKeyStream(out, data)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func addUint64ToCounter(counter []byte, inc uint64) error {
|
||||
if inc == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := len(counter) - 1; i >= 0 && inc > 0; i-- {
|
||||
sum := uint64(counter[i]) + (inc & 0xff)
|
||||
counter[i] = byte(sum)
|
||||
inc = (inc >> 8) + (sum >> 8)
|
||||
}
|
||||
if inc > 0 {
|
||||
return ErrCTRCounterOverflow
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -70,3 +70,122 @@ func FuzzAesCBCStreamRoundTrip(f *testing.F) {
|
||||
}
|
||||
})
|
||||
}
|
||||
func xtsFuzzDataUnitSize(selector uint8) int {
|
||||
sizes := [...]int{16, 32, 64, 128, 256, 512}
|
||||
return sizes[int(selector)%len(sizes)]
|
||||
}
|
||||
|
||||
func xtsFuzzNormalizeData(data []byte, maxLen int) []byte {
|
||||
if len(data) > maxLen {
|
||||
data = data[:maxLen]
|
||||
}
|
||||
return data[:len(data)/16*16]
|
||||
}
|
||||
|
||||
func FuzzAesXTSRoundTrip(f *testing.F) {
|
||||
f.Add([]byte("fuzz-aes-xts-seed-0000"), uint64(0), uint8(0))
|
||||
f.Add(bytes.Repeat([]byte{0x42}, 65), uint64(7), uint8(3))
|
||||
|
||||
key1 := []byte("0123456789abcdef")
|
||||
key2 := []byte("fedcba9876543210")
|
||||
|
||||
f.Fuzz(func(t *testing.T, data []byte, dataUnitIndex uint64, selector uint8) {
|
||||
bounded := data
|
||||
if len(bounded) > 4097 {
|
||||
bounded = bounded[:4097]
|
||||
}
|
||||
plain := xtsFuzzNormalizeData(bounded, 4096)
|
||||
dataUnitSize := xtsFuzzDataUnitSize(selector)
|
||||
|
||||
enc, err := EncryptAesXTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesXTSAt failed: %v", err)
|
||||
}
|
||||
dec, err := DecryptAesXTSAt(enc, key1, key2, dataUnitSize, dataUnitIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptAesXTSAt failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec, plain) {
|
||||
t.Fatalf("aes xts roundtrip mismatch")
|
||||
}
|
||||
|
||||
encStream := &bytes.Buffer{}
|
||||
if err := EncryptAesXTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
|
||||
t.Fatalf("EncryptAesXTSStreamAt failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(encStream.Bytes(), enc) {
|
||||
t.Fatalf("aes xts bytes/stream encrypt mismatch")
|
||||
}
|
||||
|
||||
decStream := &bytes.Buffer{}
|
||||
if err := DecryptAesXTSStreamAt(decStream, bytes.NewReader(enc), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
|
||||
t.Fatalf("DecryptAesXTSStreamAt failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(decStream.Bytes(), plain) {
|
||||
t.Fatalf("aes xts stream decrypt mismatch")
|
||||
}
|
||||
|
||||
if len(bounded)%16 != 0 {
|
||||
if _, err := EncryptAesXTSAt(bounded, key1, key2, dataUnitSize, dataUnitIndex); err == nil {
|
||||
t.Fatalf("expected aes xts bytes error for non-block-aligned input")
|
||||
}
|
||||
if err := EncryptAesXTSStreamAt(&bytes.Buffer{}, bytes.NewReader(bounded), key1, key2, dataUnitSize, dataUnitIndex); err == nil {
|
||||
t.Fatalf("expected aes xts stream error for non-block-aligned input")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzSM4XTSRoundTrip(f *testing.F) {
|
||||
f.Add([]byte("fuzz-sm4-xts-seed-0000"), uint64(0), uint8(0))
|
||||
f.Add(bytes.Repeat([]byte{0x5a}, 79), uint64(11), uint8(4))
|
||||
|
||||
key1 := []byte("0123456789abcdef")
|
||||
key2 := []byte("fedcba9876543210")
|
||||
|
||||
f.Fuzz(func(t *testing.T, data []byte, dataUnitIndex uint64, selector uint8) {
|
||||
bounded := data
|
||||
if len(bounded) > 4097 {
|
||||
bounded = bounded[:4097]
|
||||
}
|
||||
plain := xtsFuzzNormalizeData(bounded, 4096)
|
||||
dataUnitSize := xtsFuzzDataUnitSize(selector)
|
||||
|
||||
enc, err := EncryptSM4XTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4XTSAt failed: %v", err)
|
||||
}
|
||||
dec, err := DecryptSM4XTSAt(enc, key1, key2, dataUnitSize, dataUnitIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptSM4XTSAt failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec, plain) {
|
||||
t.Fatalf("sm4 xts roundtrip mismatch")
|
||||
}
|
||||
|
||||
encStream := &bytes.Buffer{}
|
||||
if err := EncryptSM4XTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
|
||||
t.Fatalf("EncryptSM4XTSStreamAt failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(encStream.Bytes(), enc) {
|
||||
t.Fatalf("sm4 xts bytes/stream encrypt mismatch")
|
||||
}
|
||||
|
||||
decStream := &bytes.Buffer{}
|
||||
if err := DecryptSM4XTSStreamAt(decStream, bytes.NewReader(enc), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
|
||||
t.Fatalf("DecryptSM4XTSStreamAt failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(decStream.Bytes(), plain) {
|
||||
t.Fatalf("sm4 xts stream decrypt mismatch")
|
||||
}
|
||||
|
||||
if len(bounded)%16 != 0 {
|
||||
if _, err := EncryptSM4XTSAt(bounded, key1, key2, dataUnitSize, dataUnitIndex); err == nil {
|
||||
t.Fatalf("expected sm4 xts bytes error for non-block-aligned input")
|
||||
}
|
||||
if err := EncryptSM4XTSStreamAt(&bytes.Buffer{}, bytes.NewReader(bounded), key1, key2, dataUnitSize, dataUnitIndex); err == nil {
|
||||
t.Fatalf("expected sm4 xts stream error for non-block-aligned input")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
package symm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
)
|
||||
|
||||
const (
|
||||
gcmStreamMagic = "SCG1"
|
||||
gcmStreamChunkSize = 32 * 1024
|
||||
)
|
||||
|
||||
var ErrInvalidGCMStreamChunk = errors.New("invalid gcm stream chunk")
|
||||
|
||||
func encryptGCMChunk(aead cipher.AEAD, plain, nonce, aad []byte, chunkIndex uint64) []byte {
|
||||
chunkNonce := deriveChunkNonce(nonce, chunkIndex)
|
||||
return aead.Seal(nil, chunkNonce, plain, aad)
|
||||
}
|
||||
|
||||
func decryptGCMChunk(aead cipher.AEAD, ciphertext, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
|
||||
chunkNonce := deriveChunkNonce(nonce, chunkIndex)
|
||||
return aead.Open(nil, chunkNonce, ciphertext, aad)
|
||||
}
|
||||
|
||||
func encryptGCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
|
||||
if _, err := dst.Write([]byte(gcmStreamMagic)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, gcmStreamChunkSize)
|
||||
lenBuf := make([]byte, 4)
|
||||
var chunkIndex uint64
|
||||
|
||||
for {
|
||||
n, err := src.Read(buf)
|
||||
if n > 0 {
|
||||
sealed := encryptGCMChunk(aead, buf[:n], nonce, aad, chunkIndex)
|
||||
binary.BigEndian.PutUint32(lenBuf, uint32(len(sealed)))
|
||||
if _, werr := dst.Write(lenBuf); werr != nil {
|
||||
return werr
|
||||
}
|
||||
if _, werr := dst.Write(sealed); werr != nil {
|
||||
return werr
|
||||
}
|
||||
chunkIndex++
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func decryptGCMChunkedOrLegacyStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
|
||||
header := make([]byte, len(gcmStreamMagic))
|
||||
n, err := io.ReadFull(src, header)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err != io.ErrUnexpectedEOF {
|
||||
return err
|
||||
}
|
||||
return decryptGCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header[:n]), src), aead, nonce, aad)
|
||||
}
|
||||
|
||||
if string(header) != gcmStreamMagic {
|
||||
return decryptGCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header), src), aead, nonce, aad)
|
||||
}
|
||||
return decryptGCMChunkedStream(dst, src, aead, nonce, aad)
|
||||
}
|
||||
|
||||
func decryptGCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
|
||||
lenBuf := make([]byte, 4)
|
||||
maxChunkLen := uint32(gcmStreamChunkSize + aead.Overhead())
|
||||
var chunkIndex uint64
|
||||
|
||||
for {
|
||||
_, err := io.ReadFull(src, lenBuf)
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
chunkLen := binary.BigEndian.Uint32(lenBuf)
|
||||
if chunkLen < uint32(aead.Overhead()) || chunkLen > maxChunkLen {
|
||||
return ErrInvalidGCMStreamChunk
|
||||
}
|
||||
|
||||
chunk := make([]byte, chunkLen)
|
||||
if _, err := io.ReadFull(src, chunk); err != nil {
|
||||
if err == io.ErrUnexpectedEOF {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
plain, err := decryptGCMChunk(aead, chunk, nonce, aad, chunkIndex)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := dst.Write(plain); err != nil {
|
||||
return err
|
||||
}
|
||||
chunkIndex++
|
||||
}
|
||||
}
|
||||
|
||||
func decryptGCMLegacyBuffered(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
|
||||
enc, err := io.ReadAll(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
plain, err := aead.Open(nil, nonce, enc, aad)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = dst.Write(plain)
|
||||
return err
|
||||
}
|
||||
|
||||
func deriveChunkNonce(baseNonce []byte, chunkIndex uint64) []byte {
|
||||
nonce := make([]byte, len(baseNonce))
|
||||
copy(nonce, baseNonce)
|
||||
if len(nonce) < 8 {
|
||||
return nonce
|
||||
}
|
||||
|
||||
var indexBytes [8]byte
|
||||
binary.BigEndian.PutUint64(indexBytes[:], chunkIndex)
|
||||
off := len(nonce) - 8
|
||||
for i := 0; i < 8; i++ {
|
||||
nonce[off+i] ^= indexBytes[i]
|
||||
}
|
||||
return nonce
|
||||
}
|
||||
@@ -16,6 +16,7 @@ const (
|
||||
MODEOFB = "OFB"
|
||||
MODECTR = "CTR"
|
||||
MODEGCM = "GCM"
|
||||
MODECCM = "CCM"
|
||||
)
|
||||
|
||||
var ErrUnsupportedCipherMode = errors.New("cipher mode not supported")
|
||||
|
||||
+2
-5
@@ -1,7 +1,7 @@
|
||||
package symm
|
||||
|
||||
// CipherOptions provides a unified configuration for symmetric APIs.
|
||||
// For GCM mode, Nonce is used; if Nonce is empty, IV is used as fallback.
|
||||
// For AEAD modes (GCM/CCM), Nonce must be set explicitly.
|
||||
type CipherOptions struct {
|
||||
Mode string
|
||||
Padding string
|
||||
@@ -18,8 +18,5 @@ func normalizeCipherOptions(opts *CipherOptions) CipherOptions {
|
||||
}
|
||||
|
||||
func nonceFromOptions(opts CipherOptions) []byte {
|
||||
if len(opts.Nonce) > 0 {
|
||||
return opts.Nonce
|
||||
}
|
||||
return opts.IV
|
||||
return opts.Nonce
|
||||
}
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
package symm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAEADOptionsRequireNonce(t *testing.T) {
|
||||
aesKey := []byte("0123456789abcdef")
|
||||
sm4Key := []byte("0123456789abcdef")
|
||||
plain := []byte("nonce-required")
|
||||
|
||||
gcmIVOnly := &CipherOptions{Mode: MODEGCM, IV: []byte("123456789012")}
|
||||
if _, err := EncryptAesWithOptions(plain, aesKey, gcmIVOnly); !errors.Is(err, ErrInvalidGCMNonceLength) {
|
||||
t.Fatalf("expected ErrInvalidGCMNonceLength for AES GCM with IV-only opts, got: %v", err)
|
||||
}
|
||||
if _, err := EncryptSM4WithOptions(plain, sm4Key, gcmIVOnly); !errors.Is(err, ErrInvalidGCMNonceLength) {
|
||||
t.Fatalf("expected ErrInvalidGCMNonceLength for SM4 GCM with IV-only opts, got: %v", err)
|
||||
}
|
||||
|
||||
ccmIVOnly := &CipherOptions{Mode: MODECCM, IV: []byte("123456789012")}
|
||||
if _, err := EncryptAesWithOptions(plain, aesKey, ccmIVOnly); !errors.Is(err, ErrInvalidCCMNonceLength) {
|
||||
t.Fatalf("expected ErrInvalidCCMNonceLength for AES CCM with IV-only opts, got: %v", err)
|
||||
}
|
||||
if _, err := EncryptSM4WithOptions(plain, sm4Key, ccmIVOnly); !errors.Is(err, ErrInvalidCCMNonceLength) {
|
||||
t.Fatalf("expected ErrInvalidCCMNonceLength for SM4 CCM with IV-only opts, got: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package symm
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"errors"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSegmentNotFullBlocks = errors.New("ciphertext segment is not a full block size")
|
||||
)
|
||||
|
||||
func decryptECBBlocks(block cipher.Block, src []byte) ([]byte, error) {
|
||||
if len(src) == 0 {
|
||||
return []byte{}, nil
|
||||
}
|
||||
if len(src)%block.BlockSize() != 0 {
|
||||
return nil, ErrSegmentNotFullBlocks
|
||||
}
|
||||
out := make([]byte, len(src))
|
||||
ecbDecryptBlocks(block, out, src)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// decryptCBCFromSecondBlock decrypts a CBC ciphertext segment that starts from the second block (or later).
|
||||
// prevCipherBlock must be the previous ciphertext block (C[i-1]); for i=1 this is the original IV.
|
||||
func decryptCBCFromSecondBlock(block cipher.Block, src, prevCipherBlock []byte) ([]byte, error) {
|
||||
if len(src) == 0 {
|
||||
return []byte{}, nil
|
||||
}
|
||||
if len(prevCipherBlock) != block.BlockSize() {
|
||||
return nil, errors.New("prev cipher block length must match block size")
|
||||
}
|
||||
if len(src)%block.BlockSize() != 0 {
|
||||
return nil, ErrSegmentNotFullBlocks
|
||||
}
|
||||
out := make([]byte, len(src))
|
||||
cipher.NewCBCDecrypter(block, prevCipherBlock).CryptBlocks(out, src)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// decryptCFBFromSecondBlock decrypts a CFB ciphertext segment that starts from the second block (or later).
|
||||
// prevCipherBlock must be the previous ciphertext block (C[i-1]); for i=1 this is the original IV.
|
||||
func decryptCFBFromSecondBlock(block cipher.Block, src, prevCipherBlock []byte) ([]byte, error) {
|
||||
if len(src) == 0 {
|
||||
return []byte{}, nil
|
||||
}
|
||||
if len(prevCipherBlock) != block.BlockSize() {
|
||||
return nil, errors.New("prev cipher block length must match block size")
|
||||
}
|
||||
stream := cipher.NewCFBDecrypter(block, prevCipherBlock)
|
||||
out := make([]byte, len(src))
|
||||
stream.XORKeyStream(out, src)
|
||||
return out, nil
|
||||
}
|
||||
+225
-8
@@ -6,14 +6,21 @@ import (
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"b612.me/starcrypto/ccm"
|
||||
"github.com/emmansun/gmsm/sm4"
|
||||
)
|
||||
|
||||
func EncryptSM4(data, key, iv []byte, mode, paddingType string) ([]byte, error) {
|
||||
normalizedMode := normalizeCipherMode(mode)
|
||||
if normalizedMode == "" {
|
||||
normalizedMode = MODEGCM
|
||||
}
|
||||
if normalizedMode == MODEGCM {
|
||||
return EncryptSM4GCM(data, key, iv, nil)
|
||||
}
|
||||
if normalizedMode == MODECCM {
|
||||
return EncryptSM4CCM(data, key, iv, nil)
|
||||
}
|
||||
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
@@ -24,9 +31,15 @@ func EncryptSM4(data, key, iv []byte, mode, paddingType string) ([]byte, error)
|
||||
|
||||
func DecryptSM4(src, key, iv []byte, mode, paddingType string) ([]byte, error) {
|
||||
normalizedMode := normalizeCipherMode(mode)
|
||||
if normalizedMode == "" {
|
||||
normalizedMode = MODEGCM
|
||||
}
|
||||
if normalizedMode == MODEGCM {
|
||||
return DecryptSM4GCM(src, key, iv, nil)
|
||||
}
|
||||
if normalizedMode == MODECCM {
|
||||
return DecryptSM4CCM(src, key, iv, nil)
|
||||
}
|
||||
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
@@ -37,9 +50,15 @@ func DecryptSM4(src, key, iv []byte, mode, paddingType string) ([]byte, error) {
|
||||
|
||||
func EncryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error {
|
||||
normalizedMode := normalizeCipherMode(mode)
|
||||
if normalizedMode == "" {
|
||||
normalizedMode = MODEGCM
|
||||
}
|
||||
if normalizedMode == MODEGCM {
|
||||
return EncryptSM4GCMStream(dst, src, key, iv, nil)
|
||||
}
|
||||
if normalizedMode == MODECCM {
|
||||
return EncryptSM4CCMStream(dst, src, key, iv, nil)
|
||||
}
|
||||
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
@@ -50,9 +69,15 @@ func EncryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddin
|
||||
|
||||
func DecryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error {
|
||||
normalizedMode := normalizeCipherMode(mode)
|
||||
if normalizedMode == "" {
|
||||
normalizedMode = MODEGCM
|
||||
}
|
||||
if normalizedMode == MODEGCM {
|
||||
return DecryptSM4GCMStream(dst, src, key, iv, nil)
|
||||
}
|
||||
if normalizedMode == MODECCM {
|
||||
return DecryptSM4CCMStream(dst, src, key, iv, nil)
|
||||
}
|
||||
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
@@ -70,6 +95,9 @@ func EncryptSM4WithOptions(data, key []byte, opts *CipherOptions) ([]byte, error
|
||||
if mode == MODEGCM {
|
||||
return EncryptSM4GCM(data, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
if mode == MODECCM {
|
||||
return EncryptSM4CCM(data, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
return EncryptSM4(data, key, cfg.IV, mode, cfg.Padding)
|
||||
}
|
||||
|
||||
@@ -82,6 +110,9 @@ func DecryptSM4WithOptions(src, key []byte, opts *CipherOptions) ([]byte, error)
|
||||
if mode == MODEGCM {
|
||||
return DecryptSM4GCM(src, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
if mode == MODECCM {
|
||||
return DecryptSM4CCM(src, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
return DecryptSM4(src, key, cfg.IV, mode, cfg.Padding)
|
||||
}
|
||||
|
||||
@@ -94,6 +125,9 @@ func EncryptSM4StreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts
|
||||
if mode == MODEGCM {
|
||||
return EncryptSM4GCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
if mode == MODECCM {
|
||||
return EncryptSM4CCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
return EncryptSM4Stream(dst, src, key, cfg.IV, mode, cfg.Padding)
|
||||
}
|
||||
|
||||
@@ -106,6 +140,9 @@ func DecryptSM4StreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts
|
||||
if mode == MODEGCM {
|
||||
return DecryptSM4GCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
if mode == MODECCM {
|
||||
return DecryptSM4CCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
|
||||
}
|
||||
return DecryptSM4Stream(dst, src, key, cfg.IV, mode, cfg.Padding)
|
||||
}
|
||||
|
||||
@@ -139,30 +176,138 @@ func DecryptSM4GCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
|
||||
return gcm.Open(nil, nonce, ciphertext, aad)
|
||||
}
|
||||
|
||||
func EncryptSM4GCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nonce) != gcm.NonceSize() {
|
||||
return nil, ErrInvalidGCMNonceLength
|
||||
}
|
||||
return encryptGCMChunk(gcm, plain, nonce, aad, chunkIndex), nil
|
||||
}
|
||||
|
||||
func DecryptSM4GCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nonce) != gcm.NonceSize() {
|
||||
return nil, ErrInvalidGCMNonceLength
|
||||
}
|
||||
return decryptGCMChunk(gcm, ciphertext, nonce, aad, chunkIndex)
|
||||
}
|
||||
|
||||
func EncryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
|
||||
plain, err := io.ReadAll(src)
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
out, err := EncryptSM4GCM(plain, key, nonce, aad)
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = dst.Write(out)
|
||||
return err
|
||||
if len(nonce) != gcm.NonceSize() {
|
||||
return ErrInvalidGCMNonceLength
|
||||
}
|
||||
return encryptGCMChunkedStream(dst, src, gcm, nonce, aad)
|
||||
}
|
||||
|
||||
func DecryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
|
||||
enc, err := io.ReadAll(src)
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
out, err := DecryptSM4GCM(enc, key, nonce, aad)
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = dst.Write(out)
|
||||
return err
|
||||
if len(nonce) != gcm.NonceSize() {
|
||||
return ErrInvalidGCMNonceLength
|
||||
}
|
||||
return decryptGCMChunkedOrLegacyStream(dst, src, gcm, nonce, aad)
|
||||
}
|
||||
|
||||
func newSM4CCM(key []byte) (cipher.AEAD, error) {
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ccm.NewCCM(block, aeadCCMTagSize, aeadCCMNonceSize)
|
||||
}
|
||||
|
||||
func EncryptSM4CCM(plain, key, nonce, aad []byte) ([]byte, error) {
|
||||
aead, err := newSM4CCM(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nonce) != aead.NonceSize() {
|
||||
return nil, ErrInvalidCCMNonceLength
|
||||
}
|
||||
return aead.Seal(nil, nonce, plain, aad), nil
|
||||
}
|
||||
|
||||
func DecryptSM4CCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
|
||||
aead, err := newSM4CCM(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nonce) != aead.NonceSize() {
|
||||
return nil, ErrInvalidCCMNonceLength
|
||||
}
|
||||
return aead.Open(nil, nonce, ciphertext, aad)
|
||||
}
|
||||
|
||||
func EncryptSM4CCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
|
||||
aead, err := newSM4CCM(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nonce) != aead.NonceSize() {
|
||||
return nil, ErrInvalidCCMNonceLength
|
||||
}
|
||||
return encryptCCMChunk(aead, plain, nonce, aad, chunkIndex), nil
|
||||
}
|
||||
|
||||
func DecryptSM4CCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
|
||||
aead, err := newSM4CCM(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(nonce) != aead.NonceSize() {
|
||||
return nil, ErrInvalidCCMNonceLength
|
||||
}
|
||||
return decryptCCMChunk(aead, ciphertext, nonce, aad, chunkIndex)
|
||||
}
|
||||
|
||||
func EncryptSM4CCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
|
||||
aead, err := newSM4CCM(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(nonce) != aead.NonceSize() {
|
||||
return ErrInvalidCCMNonceLength
|
||||
}
|
||||
return encryptCCMChunkedStream(dst, src, aead, nonce, aad)
|
||||
}
|
||||
|
||||
func DecryptSM4CCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
|
||||
aead, err := newSM4CCM(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(nonce) != aead.NonceSize() {
|
||||
return ErrInvalidCCMNonceLength
|
||||
}
|
||||
return decryptCCMChunkedOrLegacyStream(dst, src, aead, nonce, aad)
|
||||
}
|
||||
|
||||
func EncryptSM4CFB(origData, key []byte) ([]byte, error) {
|
||||
@@ -235,6 +380,18 @@ func DecryptSM4CTR(src, key, iv []byte) ([]byte, error) {
|
||||
return DecryptSM4(src, key, iv, MODECTR, "")
|
||||
}
|
||||
|
||||
func EncryptSM4CTRAt(data, key, iv []byte, offset int64) ([]byte, error) {
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return xorCTRAtOffset(block, data, iv, offset)
|
||||
}
|
||||
|
||||
func DecryptSM4CTRAt(src, key, iv []byte, offset int64) ([]byte, error) {
|
||||
return EncryptSM4CTRAt(src, key, iv, offset)
|
||||
}
|
||||
|
||||
func EncryptSM4ECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error {
|
||||
return EncryptSM4Stream(dst, src, key, nil, MODEECB, paddingType)
|
||||
}
|
||||
@@ -274,3 +431,63 @@ func EncryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error {
|
||||
func DecryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error {
|
||||
return DecryptSM4Stream(dst, src, key, iv, MODECTR, "")
|
||||
}
|
||||
|
||||
func EncryptSM4CFB8(data, key, iv []byte) ([]byte, error) {
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encryptCFB8(block, data, iv)
|
||||
}
|
||||
|
||||
func DecryptSM4CFB8(src, key, iv []byte) ([]byte, error) {
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decryptCFB8(block, src, iv)
|
||||
}
|
||||
|
||||
func EncryptSM4CFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return encryptCFB8Stream(block, dst, src, iv, false)
|
||||
}
|
||||
|
||||
func DecryptSM4CFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return encryptCFB8Stream(block, dst, src, iv, true)
|
||||
}
|
||||
|
||||
func DecryptSM4ECBBlocks(src, key []byte) ([]byte, error) {
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decryptECBBlocks(block, src)
|
||||
}
|
||||
|
||||
// DecryptSM4CBCFromSecondBlock decrypts a CBC ciphertext segment that starts from block 2 or later.
|
||||
// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock.
|
||||
func DecryptSM4CBCFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decryptCBCFromSecondBlock(block, src, prevCipherBlock)
|
||||
}
|
||||
|
||||
// DecryptSM4CFBFromSecondBlock decrypts a CFB ciphertext segment that starts from block 2 or later.
|
||||
// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock.
|
||||
func DecryptSM4CFBFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
|
||||
block, err := sm4.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decryptCFBFromSecondBlock(block, src, prevCipherBlock)
|
||||
}
|
||||
|
||||
+739
-8
@@ -6,21 +6,39 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEncryptAesDefaultModeCBC(t *testing.T) {
|
||||
func TestEncryptAesDefaultModeGCM(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
iv := []byte("abcdef9876543210")
|
||||
plain := []byte("aes-default-mode-cbc")
|
||||
nonce := []byte("123456789012")
|
||||
plain := []byte("aes-default-mode-gcm")
|
||||
|
||||
encDefault, err := EncryptAes(plain, key, iv, "", "")
|
||||
encDefault, err := EncryptAes(plain, key, nonce, "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAes default failed: %v", err)
|
||||
}
|
||||
encCBC, err := EncryptAesCBC(plain, key, iv, "")
|
||||
encGCM, err := EncryptAesGCM(plain, key, nonce, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesCBC failed: %v", err)
|
||||
t.Fatalf("EncryptAesGCM failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(encDefault, encCBC) {
|
||||
t.Fatalf("default mode should match CBC mode")
|
||||
if !bytes.Equal(encDefault, encGCM) {
|
||||
t.Fatalf("default mode should match GCM mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptSM4DefaultModeGCM(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
nonce := []byte("123456789012")
|
||||
plain := []byte("sm4-default-mode-gcm")
|
||||
|
||||
encDefault, err := EncryptSM4(plain, key, nonce, "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4 default failed: %v", err)
|
||||
}
|
||||
encGCM, err := EncryptSM4GCM(plain, key, nonce, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4GCM failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(encDefault, encGCM) {
|
||||
t.Fatalf("default mode should match GCM mode")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -529,6 +547,62 @@ func TestChaCha20Poly1305RFCVector(t *testing.T) {
|
||||
t.Fatalf("ChaCha20-Poly1305 vector mismatch: got %x want %x", enc, want)
|
||||
}
|
||||
}
|
||||
func TestAesGCMStreamRoundTripChunked(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
nonce := []byte("123456789012")
|
||||
aad := []byte("aad")
|
||||
plain := bytes.Repeat([]byte("aes-gcm-stream-chunk-"), 10000)
|
||||
|
||||
enc := &bytes.Buffer{}
|
||||
if err := EncryptAesGCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil {
|
||||
t.Fatalf("EncryptAesGCMStream failed: %v", err)
|
||||
}
|
||||
dec := &bytes.Buffer{}
|
||||
if err := DecryptAesGCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil {
|
||||
t.Fatalf("DecryptAesGCMStream failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec.Bytes(), plain) {
|
||||
t.Fatalf("aes gcm stream mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAesGCMStreamLegacyCompatDecrypt(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
nonce := []byte("123456789012")
|
||||
aad := []byte("aad")
|
||||
plain := []byte("aes-gcm-legacy-compat")
|
||||
|
||||
legacyCipher, err := EncryptAesGCM(plain, key, nonce, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesGCM failed: %v", err)
|
||||
}
|
||||
dec := &bytes.Buffer{}
|
||||
if err := DecryptAesGCMStream(dec, bytes.NewReader(legacyCipher), key, nonce, aad); err != nil {
|
||||
t.Fatalf("DecryptAesGCMStream failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec.Bytes(), plain) {
|
||||
t.Fatalf("aes gcm legacy decrypt mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSM4GCMStreamRoundTripChunked(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
nonce := []byte("123456789012")
|
||||
aad := []byte("aad")
|
||||
plain := bytes.Repeat([]byte("sm4-gcm-stream-chunk-"), 10000)
|
||||
|
||||
enc := &bytes.Buffer{}
|
||||
if err := EncryptSM4GCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil {
|
||||
t.Fatalf("EncryptSM4GCMStream failed: %v", err)
|
||||
}
|
||||
dec := &bytes.Buffer{}
|
||||
if err := DecryptSM4GCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil {
|
||||
t.Fatalf("DecryptSM4GCMStream failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec.Bytes(), plain) {
|
||||
t.Fatalf("sm4 gcm stream mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAesOptionsDefaultToGCM(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
@@ -598,6 +672,113 @@ func TestLargeStreamRoundTrip(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAesCTRAtOffsetSegment(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
iv := []byte("abcdef9876543210")
|
||||
plain := bytes.Repeat([]byte("aes-ctr-offset-"), 256)
|
||||
|
||||
full, err := EncryptAesCTR(plain, key, iv)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesCTR failed: %v", err)
|
||||
}
|
||||
|
||||
offset := 137
|
||||
length := 521
|
||||
segCipher := full[offset : offset+length]
|
||||
|
||||
segPlain, err := DecryptAesCTRAt(segCipher, key, iv, int64(offset))
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptAesCTRAt failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(segPlain, plain[offset:offset+length]) {
|
||||
t.Fatalf("aes ctr offset decrypt mismatch")
|
||||
}
|
||||
|
||||
encSeg, err := EncryptAesCTRAt(plain[offset:offset+length], key, iv, int64(offset))
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesCTRAt failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(encSeg, segCipher) {
|
||||
t.Fatalf("aes ctr offset encrypt mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSM4CTRAtOffsetSegment(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
iv := []byte("abcdef9876543210")
|
||||
plain := bytes.Repeat([]byte("sm4-ctr-offset-"), 256)
|
||||
|
||||
full, err := EncryptSM4CTR(plain, key, iv)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4CTR failed: %v", err)
|
||||
}
|
||||
|
||||
offset := 193
|
||||
length := 487
|
||||
segCipher := full[offset : offset+length]
|
||||
|
||||
segPlain, err := DecryptSM4CTRAt(segCipher, key, iv, int64(offset))
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptSM4CTRAt failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(segPlain, plain[offset:offset+length]) {
|
||||
t.Fatalf("sm4 ctr offset decrypt mismatch")
|
||||
}
|
||||
|
||||
encSeg, err := EncryptSM4CTRAt(plain[offset:offset+length], key, iv, int64(offset))
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4CTRAt failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(encSeg, segCipher) {
|
||||
t.Fatalf("sm4 ctr offset encrypt mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAesGCMChunkRoundTrip(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
nonce := []byte("123456789012")
|
||||
aad := []byte("aad")
|
||||
plain := []byte("aes-gcm-chunk")
|
||||
chunkIndex := uint64(7)
|
||||
|
||||
enc, err := EncryptAesGCMChunk(plain, key, nonce, aad, chunkIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesGCMChunk failed: %v", err)
|
||||
}
|
||||
dec, err := DecryptAesGCMChunk(enc, key, nonce, aad, chunkIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptAesGCMChunk failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec, plain) {
|
||||
t.Fatalf("aes gcm chunk decrypt mismatch")
|
||||
}
|
||||
if _, err := DecryptAesGCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil {
|
||||
t.Fatalf("expected decrypt error for wrong chunk index")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSM4GCMChunkRoundTrip(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
nonce := []byte("123456789012")
|
||||
aad := []byte("aad")
|
||||
plain := []byte("sm4-gcm-chunk")
|
||||
chunkIndex := uint64(11)
|
||||
|
||||
enc, err := EncryptSM4GCMChunk(plain, key, nonce, aad, chunkIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4GCMChunk failed: %v", err)
|
||||
}
|
||||
dec, err := DecryptSM4GCMChunk(enc, key, nonce, aad, chunkIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptSM4GCMChunk failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec, plain) {
|
||||
t.Fatalf("sm4 gcm chunk decrypt mismatch")
|
||||
}
|
||||
if _, err := DecryptSM4GCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil {
|
||||
t.Fatalf("expected decrypt error for wrong chunk index")
|
||||
}
|
||||
}
|
||||
func mustHex(t *testing.T, s string) []byte {
|
||||
t.Helper()
|
||||
b, err := hex.DecodeString(s)
|
||||
@@ -606,3 +787,553 @@ func mustHex(t *testing.T, s string) []byte {
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
func TestAesCFB8RoundTrip(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
iv := []byte("abcdef9876543210")
|
||||
plain := []byte("aes-cfb8-roundtrip-content")
|
||||
|
||||
enc, err := EncryptAesCFB8(plain, key, iv)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesCFB8 failed: %v", err)
|
||||
}
|
||||
dec, err := DecryptAesCFB8(enc, key, iv)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptAesCFB8 failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec, plain) {
|
||||
t.Fatalf("aes cfb8 mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSM4CFB8RoundTrip(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
iv := []byte("abcdef9876543210")
|
||||
plain := []byte("sm4-cfb8-roundtrip-content")
|
||||
|
||||
enc, err := EncryptSM4CFB8(plain, key, iv)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4CFB8 failed: %v", err)
|
||||
}
|
||||
dec, err := DecryptSM4CFB8(enc, key, iv)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptSM4CFB8 failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec, plain) {
|
||||
t.Fatalf("sm4 cfb8 mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAesCFB8StreamRoundTrip(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
iv := []byte("abcdef9876543210")
|
||||
plain := bytes.Repeat([]byte("aes-cfb8-stream-"), 512)
|
||||
|
||||
enc := &bytes.Buffer{}
|
||||
if err := EncryptAesCFB8Stream(enc, bytes.NewReader(plain), key, iv); err != nil {
|
||||
t.Fatalf("EncryptAesCFB8Stream failed: %v", err)
|
||||
}
|
||||
dec := &bytes.Buffer{}
|
||||
if err := DecryptAesCFB8Stream(dec, bytes.NewReader(enc.Bytes()), key, iv); err != nil {
|
||||
t.Fatalf("DecryptAesCFB8Stream failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec.Bytes(), plain) {
|
||||
t.Fatalf("aes cfb8 stream mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSM4CFB8StreamRoundTrip(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
iv := []byte("abcdef9876543210")
|
||||
plain := bytes.Repeat([]byte("sm4-cfb8-stream-"), 512)
|
||||
|
||||
enc := &bytes.Buffer{}
|
||||
if err := EncryptSM4CFB8Stream(enc, bytes.NewReader(plain), key, iv); err != nil {
|
||||
t.Fatalf("EncryptSM4CFB8Stream failed: %v", err)
|
||||
}
|
||||
dec := &bytes.Buffer{}
|
||||
if err := DecryptSM4CFB8Stream(dec, bytes.NewReader(enc.Bytes()), key, iv); err != nil {
|
||||
t.Fatalf("DecryptSM4CFB8Stream failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec.Bytes(), plain) {
|
||||
t.Fatalf("sm4 cfb8 stream mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAesSegmentDecryptModes(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
iv := []byte("abcdef9876543210")
|
||||
plain := bytes.Repeat([]byte("0123456789abcdef"), 4)
|
||||
|
||||
ecbEnc, err := EncryptAesECB(plain, key, ZEROPADDING)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesECB failed: %v", err)
|
||||
}
|
||||
ecbDec, err := DecryptAesECBBlocks(ecbEnc, key)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptAesECBBlocks failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(ecbDec, plain) {
|
||||
t.Fatalf("aes ecb segment mismatch")
|
||||
}
|
||||
|
||||
cbcEnc, err := EncryptAesCBC(plain, key, iv, ZEROPADDING)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesCBC failed: %v", err)
|
||||
}
|
||||
cbcSegDec, err := DecryptAesCBCFromSecondBlock(cbcEnc[len(iv):], key, cbcEnc[:len(iv)])
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptAesCBCFromSecondBlock failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(cbcSegDec, plain[len(iv):]) {
|
||||
t.Fatalf("aes cbc from-second-block mismatch")
|
||||
}
|
||||
|
||||
cfbEnc, err := EncryptAesCFB(plain, key, iv)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesCFB failed: %v", err)
|
||||
}
|
||||
cfbSegDec, err := DecryptAesCFBFromSecondBlock(cfbEnc[len(iv):], key, cfbEnc[:len(iv)])
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptAesCFBFromSecondBlock failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(cfbSegDec, plain[len(iv):]) {
|
||||
t.Fatalf("aes cfb from-second-block mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSM4SegmentDecryptModes(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
iv := []byte("abcdef9876543210")
|
||||
plain := bytes.Repeat([]byte("0123456789abcdef"), 4)
|
||||
|
||||
ecbEnc, err := EncryptSM4ECB(plain, key, ZEROPADDING)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4ECB failed: %v", err)
|
||||
}
|
||||
ecbDec, err := DecryptSM4ECBBlocks(ecbEnc, key)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptSM4ECBBlocks failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(ecbDec, plain) {
|
||||
t.Fatalf("sm4 ecb segment mismatch")
|
||||
}
|
||||
|
||||
cbcEnc, err := EncryptSM4CBC(plain, key, iv, ZEROPADDING)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4CBC failed: %v", err)
|
||||
}
|
||||
cbcSegDec, err := DecryptSM4CBCFromSecondBlock(cbcEnc[len(iv):], key, cbcEnc[:len(iv)])
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptSM4CBCFromSecondBlock failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(cbcSegDec, plain[len(iv):]) {
|
||||
t.Fatalf("sm4 cbc from-second-block mismatch")
|
||||
}
|
||||
|
||||
cfbEnc, err := EncryptSM4CFBNoBlock(plain, key, iv)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4CFB failed: %v", err)
|
||||
}
|
||||
cfbSegDec, err := DecryptSM4CFBFromSecondBlock(cfbEnc[len(iv):], key, cfbEnc[:len(iv)])
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptSM4CFBFromSecondBlock failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(cfbSegDec, plain[len(iv):]) {
|
||||
t.Fatalf("sm4 cfb from-second-block mismatch")
|
||||
}
|
||||
}
|
||||
func TestAesCCMRoundTrip(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
nonce := []byte("123456789012")
|
||||
aad := []byte("aad")
|
||||
plain := []byte("aes-ccm-roundtrip")
|
||||
|
||||
enc, err := EncryptAesCCM(plain, key, nonce, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesCCM failed: %v", err)
|
||||
}
|
||||
dec, err := DecryptAesCCM(enc, key, nonce, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptAesCCM failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec, plain) {
|
||||
t.Fatalf("aes ccm mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSM4CCMRoundTrip(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
nonce := []byte("123456789012")
|
||||
aad := []byte("aad")
|
||||
plain := []byte("sm4-ccm-roundtrip")
|
||||
|
||||
enc, err := EncryptSM4CCM(plain, key, nonce, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4CCM failed: %v", err)
|
||||
}
|
||||
dec, err := DecryptSM4CCM(enc, key, nonce, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptSM4CCM failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec, plain) {
|
||||
t.Fatalf("sm4 ccm mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAesCCMStreamRoundTripChunked(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
nonce := []byte("123456789012")
|
||||
aad := []byte("aad")
|
||||
plain := bytes.Repeat([]byte("aes-ccm-stream-chunk-"), 10000)
|
||||
|
||||
enc := &bytes.Buffer{}
|
||||
if err := EncryptAesCCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil {
|
||||
t.Fatalf("EncryptAesCCMStream failed: %v", err)
|
||||
}
|
||||
dec := &bytes.Buffer{}
|
||||
if err := DecryptAesCCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil {
|
||||
t.Fatalf("DecryptAesCCMStream failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec.Bytes(), plain) {
|
||||
t.Fatalf("aes ccm stream mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAesCCMStreamLegacyCompatDecrypt(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
nonce := []byte("123456789012")
|
||||
aad := []byte("aad")
|
||||
plain := []byte("aes-ccm-legacy-compat")
|
||||
|
||||
legacyCipher, err := EncryptAesCCM(plain, key, nonce, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesCCM failed: %v", err)
|
||||
}
|
||||
dec := &bytes.Buffer{}
|
||||
if err := DecryptAesCCMStream(dec, bytes.NewReader(legacyCipher), key, nonce, aad); err != nil {
|
||||
t.Fatalf("DecryptAesCCMStream failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec.Bytes(), plain) {
|
||||
t.Fatalf("aes ccm legacy decrypt mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSM4CCMStreamRoundTripChunked(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
nonce := []byte("123456789012")
|
||||
aad := []byte("aad")
|
||||
plain := bytes.Repeat([]byte("sm4-ccm-stream-chunk-"), 10000)
|
||||
|
||||
enc := &bytes.Buffer{}
|
||||
if err := EncryptSM4CCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil {
|
||||
t.Fatalf("EncryptSM4CCMStream failed: %v", err)
|
||||
}
|
||||
dec := &bytes.Buffer{}
|
||||
if err := DecryptSM4CCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil {
|
||||
t.Fatalf("DecryptSM4CCMStream failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec.Bytes(), plain) {
|
||||
t.Fatalf("sm4 ccm stream mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAesCCMChunkRoundTrip(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
nonce := []byte("123456789012")
|
||||
aad := []byte("aad")
|
||||
plain := []byte("aes-ccm-chunk")
|
||||
chunkIndex := uint64(5)
|
||||
|
||||
enc, err := EncryptAesCCMChunk(plain, key, nonce, aad, chunkIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesCCMChunk failed: %v", err)
|
||||
}
|
||||
dec, err := DecryptAesCCMChunk(enc, key, nonce, aad, chunkIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptAesCCMChunk failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec, plain) {
|
||||
t.Fatalf("aes ccm chunk decrypt mismatch")
|
||||
}
|
||||
if _, err := DecryptAesCCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil {
|
||||
t.Fatalf("expected decrypt error for wrong chunk index")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSM4CCMChunkRoundTrip(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
nonce := []byte("123456789012")
|
||||
aad := []byte("aad")
|
||||
plain := []byte("sm4-ccm-chunk")
|
||||
chunkIndex := uint64(9)
|
||||
|
||||
enc, err := EncryptSM4CCMChunk(plain, key, nonce, aad, chunkIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4CCMChunk failed: %v", err)
|
||||
}
|
||||
dec, err := DecryptSM4CCMChunk(enc, key, nonce, aad, chunkIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptSM4CCMChunk failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec, plain) {
|
||||
t.Fatalf("sm4 ccm chunk decrypt mismatch")
|
||||
}
|
||||
if _, err := DecryptSM4CCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil {
|
||||
t.Fatalf("expected decrypt error for wrong chunk index")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAesOptionsModeCCM(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
nonce := []byte("123456789012")
|
||||
aad := []byte("aad")
|
||||
plain := []byte("aes-options-ccm")
|
||||
|
||||
opts := &CipherOptions{Mode: MODECCM, Nonce: nonce, AAD: aad}
|
||||
enc, err := EncryptAesWithOptions(plain, key, opts)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesWithOptions CCM failed: %v", err)
|
||||
}
|
||||
dec, err := DecryptAesWithOptions(enc, key, opts)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptAesWithOptions CCM failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec, plain) {
|
||||
t.Fatalf("aes options ccm mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSM4OptionsModeCCM(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
nonce := []byte("123456789012")
|
||||
aad := []byte("aad")
|
||||
plain := []byte("sm4-options-ccm")
|
||||
|
||||
opts := &CipherOptions{Mode: MODECCM, Nonce: nonce, AAD: aad}
|
||||
enc, err := EncryptSM4WithOptions(plain, key, opts)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4WithOptions CCM failed: %v", err)
|
||||
}
|
||||
dec, err := DecryptSM4WithOptions(enc, key, opts)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptSM4WithOptions CCM failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec, plain) {
|
||||
t.Fatalf("sm4 options ccm mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCCMInvalidNonceLength(t *testing.T) {
|
||||
key := []byte("0123456789abcdef")
|
||||
shortNonce := []byte("short")
|
||||
if _, err := EncryptAesCCM([]byte("x"), key, shortNonce, nil); err == nil {
|
||||
t.Fatalf("expected aes ccm nonce length error")
|
||||
}
|
||||
if _, err := EncryptSM4CCM([]byte("x"), key, shortNonce, nil); err == nil {
|
||||
t.Fatalf("expected sm4 ccm nonce length error")
|
||||
}
|
||||
}
|
||||
func TestAesXTSRoundTrip(t *testing.T) {
|
||||
k1 := []byte("0123456789abcdef")
|
||||
k2 := []byte("fedcba9876543210")
|
||||
plain := bytes.Repeat([]byte("0123456789abcdef"), 8)
|
||||
|
||||
enc, err := EncryptAesXTS(plain, k1, k2, 32)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesXTS failed: %v", err)
|
||||
}
|
||||
dec, err := DecryptAesXTS(enc, k1, k2, 32)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptAesXTS failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec, plain) {
|
||||
t.Fatalf("aes xts mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSM4XTSRoundTrip(t *testing.T) {
|
||||
k1 := []byte("0123456789abcdef")
|
||||
k2 := []byte("fedcba9876543210")
|
||||
plain := bytes.Repeat([]byte("0123456789abcdef"), 8)
|
||||
|
||||
enc, err := EncryptSM4XTS(plain, k1, k2, 32)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4XTS failed: %v", err)
|
||||
}
|
||||
dec, err := DecryptSM4XTS(enc, k1, k2, 32)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptSM4XTS failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec, plain) {
|
||||
t.Fatalf("sm4 xts mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAesXTSAtDataUnit(t *testing.T) {
|
||||
k1 := []byte("0123456789abcdef")
|
||||
k2 := []byte("fedcba9876543210")
|
||||
plain := bytes.Repeat([]byte("0123456789abcdef"), 8)
|
||||
dataUnitSize := 32
|
||||
|
||||
full, err := EncryptAesXTS(plain, k1, k2, dataUnitSize)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesXTS failed: %v", err)
|
||||
}
|
||||
segPlain := plain[64:96]
|
||||
segEnc, err := EncryptAesXTSAt(segPlain, k1, k2, dataUnitSize, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptAesXTSAt failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(segEnc, full[64:96]) {
|
||||
t.Fatalf("aes xts at encrypt mismatch")
|
||||
}
|
||||
|
||||
segDec, err := DecryptAesXTSAt(full[64:96], k1, k2, dataUnitSize, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptAesXTSAt failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(segDec, segPlain) {
|
||||
t.Fatalf("aes xts at decrypt mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSM4XTSAtDataUnit(t *testing.T) {
|
||||
k1 := []byte("0123456789abcdef")
|
||||
k2 := []byte("fedcba9876543210")
|
||||
plain := bytes.Repeat([]byte("0123456789abcdef"), 8)
|
||||
dataUnitSize := 32
|
||||
|
||||
full, err := EncryptSM4XTS(plain, k1, k2, dataUnitSize)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4XTS failed: %v", err)
|
||||
}
|
||||
segPlain := plain[32:64]
|
||||
segEnc, err := EncryptSM4XTSAt(segPlain, k1, k2, dataUnitSize, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptSM4XTSAt failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(segEnc, full[32:64]) {
|
||||
t.Fatalf("sm4 xts at encrypt mismatch")
|
||||
}
|
||||
|
||||
segDec, err := DecryptSM4XTSAt(full[32:64], k1, k2, dataUnitSize, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptSM4XTSAt failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(segDec, segPlain) {
|
||||
t.Fatalf("sm4 xts at decrypt mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAesXTSStreamRoundTrip(t *testing.T) {
|
||||
k1 := []byte("0123456789abcdef")
|
||||
k2 := []byte("fedcba9876543210")
|
||||
plain := bytes.Repeat([]byte("0123456789abcdef"), 2048)
|
||||
|
||||
enc := &bytes.Buffer{}
|
||||
if err := EncryptAesXTSStream(enc, bytes.NewReader(plain), k1, k2, 512); err != nil {
|
||||
t.Fatalf("EncryptAesXTSStream failed: %v", err)
|
||||
}
|
||||
dec := &bytes.Buffer{}
|
||||
if err := DecryptAesXTSStream(dec, bytes.NewReader(enc.Bytes()), k1, k2, 512); err != nil {
|
||||
t.Fatalf("DecryptAesXTSStream failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec.Bytes(), plain) {
|
||||
t.Fatalf("aes xts stream mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSM4XTSStreamRoundTrip(t *testing.T) {
|
||||
k1 := []byte("0123456789abcdef")
|
||||
k2 := []byte("fedcba9876543210")
|
||||
plain := bytes.Repeat([]byte("0123456789abcdef"), 2048)
|
||||
|
||||
enc := &bytes.Buffer{}
|
||||
if err := EncryptSM4XTSStream(enc, bytes.NewReader(plain), k1, k2, 512); err != nil {
|
||||
t.Fatalf("EncryptSM4XTSStream failed: %v", err)
|
||||
}
|
||||
dec := &bytes.Buffer{}
|
||||
if err := DecryptSM4XTSStream(dec, bytes.NewReader(enc.Bytes()), k1, k2, 512); err != nil {
|
||||
t.Fatalf("DecryptSM4XTSStream failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(dec.Bytes(), plain) {
|
||||
t.Fatalf("sm4 xts stream mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestXTSRejectsNonBlockMultiple(t *testing.T) {
|
||||
k1 := []byte("0123456789abcdef")
|
||||
k2 := []byte("fedcba9876543210")
|
||||
if _, err := EncryptAesXTS([]byte("short"), k1, k2, 32); err == nil {
|
||||
t.Fatalf("expected aes xts non-block error")
|
||||
}
|
||||
if _, err := EncryptSM4XTS([]byte("short"), k1, k2, 32); err == nil {
|
||||
t.Fatalf("expected sm4 xts non-block error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestXTSStreamRejectsTailNotFullBlock(t *testing.T) {
|
||||
k1 := []byte("0123456789abcdef")
|
||||
k2 := []byte("fedcba9876543210")
|
||||
err := EncryptAesXTSStream(&bytes.Buffer{}, bytes.NewReader([]byte("tail-not-16")), k1, k2, 32)
|
||||
if err == nil {
|
||||
t.Fatalf("expected aes xts stream tail error")
|
||||
}
|
||||
err = EncryptSM4XTSStream(&bytes.Buffer{}, bytes.NewReader([]byte("tail-not-16")), k1, k2, 32)
|
||||
if err == nil {
|
||||
t.Fatalf("expected sm4 xts stream tail error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestXTSInvalidDataUnitSize(t *testing.T) {
|
||||
k1 := []byte("0123456789abcdef")
|
||||
k2 := []byte("fedcba9876543210")
|
||||
plain := bytes.Repeat([]byte("0123456789abcdef"), 2)
|
||||
if _, err := EncryptAesXTS(plain, k1, k2, 30); err == nil {
|
||||
t.Fatalf("expected aes xts invalid data unit size error")
|
||||
}
|
||||
if _, err := EncryptSM4XTS(plain, k1, k2, 30); err == nil {
|
||||
t.Fatalf("expected sm4 xts invalid data unit size error")
|
||||
}
|
||||
}
|
||||
func TestSplitXTSMasterKeyHelpers(t *testing.T) {
|
||||
master := []byte("0123456789abcdef0123456789abcdef")
|
||||
|
||||
k1, k2, err := SplitXTSMasterKey(master)
|
||||
if err != nil {
|
||||
t.Fatalf("SplitXTSMasterKey failed: %v", err)
|
||||
}
|
||||
if len(k1) != 16 || len(k2) != 16 {
|
||||
t.Fatalf("split key lengths mismatch")
|
||||
}
|
||||
if !bytes.Equal(append(k1, k2...), master) {
|
||||
t.Fatalf("split key content mismatch")
|
||||
}
|
||||
|
||||
aesK1, aesK2, err := SplitAesXTSMasterKey(master)
|
||||
if err != nil {
|
||||
t.Fatalf("SplitAesXTSMasterKey failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(aesK1, k1) || !bytes.Equal(aesK2, k2) {
|
||||
t.Fatalf("aes split mismatch")
|
||||
}
|
||||
|
||||
sm4K1, sm4K2, err := SplitSM4XTSMasterKey(master)
|
||||
if err != nil {
|
||||
t.Fatalf("SplitSM4XTSMasterKey failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(sm4K1, k1) || !bytes.Equal(sm4K2, k2) {
|
||||
t.Fatalf("sm4 split mismatch")
|
||||
}
|
||||
|
||||
if _, _, err := SplitXTSMasterKey([]byte("abc")); err == nil {
|
||||
t.Fatalf("expected odd-length split error")
|
||||
}
|
||||
if _, _, err := SplitAesXTSMasterKey([]byte("0123456789abcdef0123456789abcdef01")); err == nil {
|
||||
t.Fatalf("expected aes master length error")
|
||||
}
|
||||
if _, _, err := SplitSM4XTSMasterKey([]byte("0123456789abcdef")); err == nil {
|
||||
t.Fatalf("expected sm4 master length error")
|
||||
}
|
||||
}
|
||||
|
||||
+263
@@ -0,0 +1,263 @@
|
||||
package symm
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/emmansun/gmsm/sm4"
|
||||
"golang.org/x/crypto/xts"
|
||||
)
|
||||
|
||||
const xtsBlockSize = 16
|
||||
|
||||
var (
|
||||
ErrInvalidXTSDataUnitSize = errors.New("xts data unit size must be a positive multiple of 16")
|
||||
ErrInvalidXTSDataLength = errors.New("xts data length must be a multiple of 16")
|
||||
ErrInvalidXTSKeyLength = errors.New("xts key lengths must be non-empty and equal")
|
||||
ErrInvalidXTSMasterKeyLength = errors.New("xts master key length must be non-empty and even")
|
||||
ErrInvalidAESXTSMasterKeyLength = errors.New("aes xts master key length must be 32, 48, or 64 bytes")
|
||||
ErrInvalidSM4XTSMasterKeyLength = errors.New("sm4 xts master key length must be 32 bytes")
|
||||
)
|
||||
|
||||
type xtsCipherFactory func(key1, key2 []byte) (*xts.Cipher, error)
|
||||
|
||||
func combineXTSKeys(key1, key2 []byte) ([]byte, error) {
|
||||
if len(key1) == 0 || len(key1) != len(key2) {
|
||||
return nil, ErrInvalidXTSKeyLength
|
||||
}
|
||||
out := make([]byte, len(key1)+len(key2))
|
||||
copy(out, key1)
|
||||
copy(out[len(key1):], key2)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func splitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
|
||||
if len(masterKey) == 0 || len(masterKey)%2 != 0 {
|
||||
return nil, nil, ErrInvalidXTSMasterKeyLength
|
||||
}
|
||||
half := len(masterKey) / 2
|
||||
k1 := make([]byte, half)
|
||||
k2 := make([]byte, half)
|
||||
copy(k1, masterKey[:half])
|
||||
copy(k2, masterKey[half:])
|
||||
return k1, k2, nil
|
||||
}
|
||||
|
||||
// SplitXTSMasterKey splits a master key into two equal XTS keys.
|
||||
func SplitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
|
||||
return splitXTSMasterKey(masterKey)
|
||||
}
|
||||
|
||||
// SplitAesXTSMasterKey splits AES-XTS master key and validates length (32/48/64 bytes).
|
||||
func SplitAesXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
|
||||
switch len(masterKey) {
|
||||
case 32, 48, 64:
|
||||
return splitXTSMasterKey(masterKey)
|
||||
default:
|
||||
return nil, nil, ErrInvalidAESXTSMasterKeyLength
|
||||
}
|
||||
}
|
||||
|
||||
// SplitSM4XTSMasterKey splits SM4-XTS master key and validates length (32 bytes).
|
||||
func SplitSM4XTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
|
||||
if len(masterKey) != 32 {
|
||||
return nil, nil, ErrInvalidSM4XTSMasterKeyLength
|
||||
}
|
||||
return splitXTSMasterKey(masterKey)
|
||||
}
|
||||
|
||||
func validateXTSDataUnitSize(dataUnitSize int) error {
|
||||
if dataUnitSize <= 0 || dataUnitSize%xtsBlockSize != 0 {
|
||||
return ErrInvalidXTSDataUnitSize
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateXTSDataLength(data []byte) error {
|
||||
if len(data)%xtsBlockSize != 0 {
|
||||
return ErrInvalidXTSDataLength
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func cryptXTSAt(c *xts.Cipher, in []byte, dataUnitSize int, dataUnitIndex uint64, decrypt bool) ([]byte, error) {
|
||||
if err := validateXTSDataUnitSize(dataUnitSize); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateXTSDataLength(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(in) == 0 {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
out := make([]byte, len(in))
|
||||
off := 0
|
||||
unit := dataUnitIndex
|
||||
for off < len(in) {
|
||||
chunkLen := dataUnitSize
|
||||
if remain := len(in) - off; remain < chunkLen {
|
||||
chunkLen = remain
|
||||
}
|
||||
if decrypt {
|
||||
c.Decrypt(out[off:off+chunkLen], in[off:off+chunkLen], unit)
|
||||
} else {
|
||||
c.Encrypt(out[off:off+chunkLen], in[off:off+chunkLen], unit)
|
||||
}
|
||||
off += chunkLen
|
||||
unit++
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func cryptXTSStreamAt(dst io.Writer, src io.Reader, c *xts.Cipher, dataUnitSize int, dataUnitIndex uint64, decrypt bool) error {
|
||||
if err := validateXTSDataUnitSize(dataUnitSize); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, 32*1024)
|
||||
pending := make([]byte, 0, dataUnitSize*2)
|
||||
unit := dataUnitIndex
|
||||
|
||||
for {
|
||||
n, err := src.Read(buf)
|
||||
if n > 0 {
|
||||
pending = append(pending, buf[:n]...)
|
||||
processLen := len(pending) / dataUnitSize * dataUnitSize
|
||||
if processLen > 0 {
|
||||
out := make([]byte, processLen)
|
||||
for off := 0; off < processLen; off += dataUnitSize {
|
||||
if decrypt {
|
||||
c.Decrypt(out[off:off+dataUnitSize], pending[off:off+dataUnitSize], unit)
|
||||
} else {
|
||||
c.Encrypt(out[off:off+dataUnitSize], pending[off:off+dataUnitSize], unit)
|
||||
}
|
||||
unit++
|
||||
}
|
||||
if _, werr := dst.Write(out); werr != nil {
|
||||
return werr
|
||||
}
|
||||
pending = append([]byte(nil), pending[processLen:]...)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(pending) == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := validateXTSDataLength(pending); err != nil {
|
||||
return err
|
||||
}
|
||||
out := make([]byte, len(pending))
|
||||
if decrypt {
|
||||
c.Decrypt(out, pending, unit)
|
||||
} else {
|
||||
c.Encrypt(out, pending, unit)
|
||||
}
|
||||
_, err := dst.Write(out)
|
||||
return err
|
||||
}
|
||||
|
||||
func newXTSCipher(newBlock func([]byte) (cipher.Block, error), key1, key2 []byte) (*xts.Cipher, error) {
|
||||
key, err := combineXTSKeys(key1, key2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return xts.NewCipher(newBlock, key)
|
||||
}
|
||||
|
||||
func newAesXTS(key1, key2 []byte) (*xts.Cipher, error) {
|
||||
return newXTSCipher(aes.NewCipher, key1, key2)
|
||||
}
|
||||
|
||||
func newSM4XTS(key1, key2 []byte) (*xts.Cipher, error) {
|
||||
return newXTSCipher(sm4.NewCipher, key1, key2)
|
||||
}
|
||||
|
||||
func cryptXTSAtWithFactory(factory xtsCipherFactory, in []byte, dataUnitSize int, dataUnitIndex uint64, decrypt bool, key1, key2 []byte) ([]byte, error) {
|
||||
c, err := factory(key1, key2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return cryptXTSAt(c, in, dataUnitSize, dataUnitIndex, decrypt)
|
||||
}
|
||||
|
||||
func cryptXTSStreamAtWithFactory(factory xtsCipherFactory, dst io.Writer, src io.Reader, dataUnitSize int, dataUnitIndex uint64, decrypt bool, key1, key2 []byte) error {
|
||||
c, err := factory(key1, key2)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return cryptXTSStreamAt(dst, src, c, dataUnitSize, dataUnitIndex, decrypt)
|
||||
}
|
||||
|
||||
func EncryptAesXTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
|
||||
return EncryptAesXTSAt(plain, key1, key2, dataUnitSize, 0)
|
||||
}
|
||||
|
||||
func DecryptAesXTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
|
||||
return DecryptAesXTSAt(ciphertext, key1, key2, dataUnitSize, 0)
|
||||
}
|
||||
|
||||
func EncryptAesXTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
|
||||
return cryptXTSAtWithFactory(newAesXTS, plain, dataUnitSize, dataUnitIndex, false, key1, key2)
|
||||
}
|
||||
|
||||
func DecryptAesXTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
|
||||
return cryptXTSAtWithFactory(newAesXTS, ciphertext, dataUnitSize, dataUnitIndex, true, key1, key2)
|
||||
}
|
||||
|
||||
func EncryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
|
||||
return EncryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
|
||||
}
|
||||
|
||||
func DecryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
|
||||
return DecryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
|
||||
}
|
||||
|
||||
func EncryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
|
||||
return cryptXTSStreamAtWithFactory(newAesXTS, dst, src, dataUnitSize, dataUnitIndex, false, key1, key2)
|
||||
}
|
||||
|
||||
func DecryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
|
||||
return cryptXTSStreamAtWithFactory(newAesXTS, dst, src, dataUnitSize, dataUnitIndex, true, key1, key2)
|
||||
}
|
||||
|
||||
func EncryptSM4XTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
|
||||
return EncryptSM4XTSAt(plain, key1, key2, dataUnitSize, 0)
|
||||
}
|
||||
|
||||
func DecryptSM4XTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
|
||||
return DecryptSM4XTSAt(ciphertext, key1, key2, dataUnitSize, 0)
|
||||
}
|
||||
|
||||
func EncryptSM4XTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
|
||||
return cryptXTSAtWithFactory(newSM4XTS, plain, dataUnitSize, dataUnitIndex, false, key1, key2)
|
||||
}
|
||||
|
||||
func DecryptSM4XTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
|
||||
return cryptXTSAtWithFactory(newSM4XTS, ciphertext, dataUnitSize, dataUnitIndex, true, key1, key2)
|
||||
}
|
||||
|
||||
func EncryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
|
||||
return EncryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
|
||||
}
|
||||
|
||||
func DecryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
|
||||
return DecryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
|
||||
}
|
||||
|
||||
func EncryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
|
||||
return cryptXTSStreamAtWithFactory(newSM4XTS, dst, src, dataUnitSize, dataUnitIndex, false, key1, key2)
|
||||
}
|
||||
|
||||
func DecryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
|
||||
return cryptXTSStreamAtWithFactory(newSM4XTS, dst, src, dataUnitSize, dataUnitIndex, true, key1, key2)
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package symm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// AES-XTS vectors from IEEE P1619/D16 Annex B (same set used by golang.org/x/crypto/xts tests).
|
||||
var aesXTSStandardVectors = []struct {
|
||||
key string
|
||||
dataUnitIndex uint64
|
||||
plaintext string
|
||||
ciphertext string
|
||||
}{
|
||||
{
|
||||
key: "0000000000000000000000000000000000000000000000000000000000000000",
|
||||
dataUnitIndex: 0,
|
||||
plaintext: "0000000000000000000000000000000000000000000000000000000000000000",
|
||||
ciphertext: "917cf69ebd68b2ec9b9fe9a3eadda692cd43d2f59598ed858c02c2652fbf922e",
|
||||
},
|
||||
{
|
||||
key: "1111111111111111111111111111111122222222222222222222222222222222",
|
||||
dataUnitIndex: 0x3333333333,
|
||||
plaintext: "4444444444444444444444444444444444444444444444444444444444444444",
|
||||
ciphertext: "c454185e6a16936e39334038acef838bfb186fff7480adc4289382ecd6d394f0",
|
||||
},
|
||||
{
|
||||
key: "fffefdfcfbfaf9f8f7f6f5f4f3f2f1f022222222222222222222222222222222",
|
||||
dataUnitIndex: 0x3333333333,
|
||||
plaintext: "4444444444444444444444444444444444444444444444444444444444444444",
|
||||
ciphertext: "af85336b597afc1a900b2eb21ec949d292df4c047e0b21532186a5971a227a89",
|
||||
},
|
||||
}
|
||||
|
||||
func TestAesXTSStandardVectors(t *testing.T) {
|
||||
for i, tc := range aesXTSStandardVectors {
|
||||
master := mustHex(t, tc.key)
|
||||
key1, key2, err := SplitAesXTSMasterKey(master)
|
||||
if err != nil {
|
||||
t.Fatalf("#%d split key failed: %v", i, err)
|
||||
}
|
||||
|
||||
plain := mustHex(t, tc.plaintext)
|
||||
wantCipher := mustHex(t, tc.ciphertext)
|
||||
dataUnitSize := len(plain)
|
||||
|
||||
gotCipher, err := EncryptAesXTSAt(plain, key1, key2, dataUnitSize, tc.dataUnitIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("#%d EncryptAesXTSAt failed: %v", i, err)
|
||||
}
|
||||
if !bytes.Equal(gotCipher, wantCipher) {
|
||||
t.Fatalf("#%d ciphertext mismatch", i)
|
||||
}
|
||||
|
||||
gotPlain, err := DecryptAesXTSAt(wantCipher, key1, key2, dataUnitSize, tc.dataUnitIndex)
|
||||
if err != nil {
|
||||
t.Fatalf("#%d DecryptAesXTSAt failed: %v", i, err)
|
||||
}
|
||||
if !bytes.Equal(gotPlain, plain) {
|
||||
t.Fatalf("#%d plaintext mismatch", i)
|
||||
}
|
||||
|
||||
encStream := &bytes.Buffer{}
|
||||
if err := EncryptAesXTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, tc.dataUnitIndex); err != nil {
|
||||
t.Fatalf("#%d EncryptAesXTSStreamAt failed: %v", i, err)
|
||||
}
|
||||
if !bytes.Equal(encStream.Bytes(), wantCipher) {
|
||||
t.Fatalf("#%d stream ciphertext mismatch", i)
|
||||
}
|
||||
|
||||
decStream := &bytes.Buffer{}
|
||||
if err := DecryptAesXTSStreamAt(decStream, bytes.NewReader(wantCipher), key1, key2, dataUnitSize, tc.dataUnitIndex); err != nil {
|
||||
t.Fatalf("#%d DecryptAesXTSStreamAt failed: %v", i, err)
|
||||
}
|
||||
if !bytes.Equal(decStream.Bytes(), plain) {
|
||||
t.Fatalf("#%d stream plaintext mismatch", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user