192 lines
5.7 KiB
Go
192 lines
5.7 KiB
Go
package symm
|
|
|
|
import (
|
|
"bytes"
|
|
"testing"
|
|
)
|
|
|
|
func FuzzAesCBCRoundTrip(f *testing.F) {
|
|
f.Add([]byte("fuzz-aes-cbc"))
|
|
f.Add([]byte{})
|
|
|
|
key := []byte("0123456789abcdef")
|
|
iv := []byte("abcdef9876543210")
|
|
|
|
f.Fuzz(func(t *testing.T, data []byte) {
|
|
enc, err := EncryptAesCBC(data, key, iv, "")
|
|
if err != nil {
|
|
t.Fatalf("EncryptAesCBC failed: %v", err)
|
|
}
|
|
dec, err := DecryptAesCBC(enc, key, iv, "")
|
|
if err != nil {
|
|
t.Fatalf("DecryptAesCBC failed: %v", err)
|
|
}
|
|
if !bytes.Equal(dec, data) {
|
|
t.Fatalf("aes cbc fuzz roundtrip mismatch")
|
|
}
|
|
})
|
|
}
|
|
|
|
func FuzzChaCha20RoundTrip(f *testing.F) {
|
|
f.Add([]byte("fuzz-chacha20"))
|
|
f.Add([]byte{})
|
|
|
|
key := []byte("0123456789abcdef0123456789abcdef")
|
|
nonce := []byte("123456789012")
|
|
|
|
f.Fuzz(func(t *testing.T, data []byte) {
|
|
enc, err := EncryptChaCha20(data, key, nonce)
|
|
if err != nil {
|
|
t.Fatalf("EncryptChaCha20 failed: %v", err)
|
|
}
|
|
dec, err := DecryptChaCha20(enc, key, nonce)
|
|
if err != nil {
|
|
t.Fatalf("DecryptChaCha20 failed: %v", err)
|
|
}
|
|
if !bytes.Equal(dec, data) {
|
|
t.Fatalf("chacha20 fuzz roundtrip mismatch")
|
|
}
|
|
})
|
|
}
|
|
|
|
func FuzzAesCBCStreamRoundTrip(f *testing.F) {
|
|
f.Add([]byte("fuzz-aes-stream"))
|
|
f.Add([]byte{})
|
|
|
|
key := []byte("0123456789abcdef")
|
|
iv := []byte("abcdef9876543210")
|
|
|
|
f.Fuzz(func(t *testing.T, data []byte) {
|
|
enc := &bytes.Buffer{}
|
|
dec := &bytes.Buffer{}
|
|
if err := EncryptAesCBCStream(enc, bytes.NewReader(data), key, iv, ""); err != nil {
|
|
t.Fatalf("EncryptAesCBCStream failed: %v", err)
|
|
}
|
|
if err := DecryptAesCBCStream(dec, bytes.NewReader(enc.Bytes()), key, iv, ""); err != nil {
|
|
t.Fatalf("DecryptAesCBCStream failed: %v", err)
|
|
}
|
|
if !bytes.Equal(dec.Bytes(), data) {
|
|
t.Fatalf("aes cbc stream fuzz roundtrip mismatch")
|
|
}
|
|
})
|
|
}
|
|
// xtsFuzzDataUnitSize maps an arbitrary fuzzer-chosen selector onto one of six
// XTS data-unit sizes: 16, 32, 64, 128, 256 or 512 bytes. These are 16 shifted
// left by selector mod 6, so every selector value yields a valid size.
func xtsFuzzDataUnitSize(selector uint8) int {
	const choices = 6
	shift := int(selector) % choices
	return 16 << shift
}
|
|
|
|
// xtsFuzzNormalizeData turns an arbitrary fuzzer input into a block-aligned
// XTS plaintext. It keeps at most maxLen bytes of data and then drops any
// trailing partial 16-byte block, so the result length is always a multiple
// of 16 (possibly zero).
//
// The result is capacity-clipped with a full slice expression. Any append
// performed on it therefore reallocates instead of writing into the rest of
// data's backing array. That array belongs to the fuzzing engine, which must
// not have it modified, and callers in this file also reuse the untrimmed
// input after calling this helper.
//
// maxLen must not be negative: a negative value panics when data is sliced.
func xtsFuzzNormalizeData(data []byte, maxLen int) []byte {
	if len(data) > maxLen {
		data = data[:maxLen]
	}
	n := len(data) / 16 * 16
	return data[:n:n]
}
|
|
|
|
func FuzzAesXTSRoundTrip(f *testing.F) {
|
|
f.Add([]byte("fuzz-aes-xts-seed-0000"), uint64(0), uint8(0))
|
|
f.Add(bytes.Repeat([]byte{0x42}, 65), uint64(7), uint8(3))
|
|
|
|
key1 := []byte("0123456789abcdef")
|
|
key2 := []byte("fedcba9876543210")
|
|
|
|
f.Fuzz(func(t *testing.T, data []byte, dataUnitIndex uint64, selector uint8) {
|
|
bounded := data
|
|
if len(bounded) > 4097 {
|
|
bounded = bounded[:4097]
|
|
}
|
|
plain := xtsFuzzNormalizeData(bounded, 4096)
|
|
dataUnitSize := xtsFuzzDataUnitSize(selector)
|
|
|
|
enc, err := EncryptAesXTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex)
|
|
if err != nil {
|
|
t.Fatalf("EncryptAesXTSAt failed: %v", err)
|
|
}
|
|
dec, err := DecryptAesXTSAt(enc, key1, key2, dataUnitSize, dataUnitIndex)
|
|
if err != nil {
|
|
t.Fatalf("DecryptAesXTSAt failed: %v", err)
|
|
}
|
|
if !bytes.Equal(dec, plain) {
|
|
t.Fatalf("aes xts roundtrip mismatch")
|
|
}
|
|
|
|
encStream := &bytes.Buffer{}
|
|
if err := EncryptAesXTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
|
|
t.Fatalf("EncryptAesXTSStreamAt failed: %v", err)
|
|
}
|
|
if !bytes.Equal(encStream.Bytes(), enc) {
|
|
t.Fatalf("aes xts bytes/stream encrypt mismatch")
|
|
}
|
|
|
|
decStream := &bytes.Buffer{}
|
|
if err := DecryptAesXTSStreamAt(decStream, bytes.NewReader(enc), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
|
|
t.Fatalf("DecryptAesXTSStreamAt failed: %v", err)
|
|
}
|
|
if !bytes.Equal(decStream.Bytes(), plain) {
|
|
t.Fatalf("aes xts stream decrypt mismatch")
|
|
}
|
|
|
|
if len(bounded)%16 != 0 {
|
|
if _, err := EncryptAesXTSAt(bounded, key1, key2, dataUnitSize, dataUnitIndex); err == nil {
|
|
t.Fatalf("expected aes xts bytes error for non-block-aligned input")
|
|
}
|
|
if err := EncryptAesXTSStreamAt(&bytes.Buffer{}, bytes.NewReader(bounded), key1, key2, dataUnitSize, dataUnitIndex); err == nil {
|
|
t.Fatalf("expected aes xts stream error for non-block-aligned input")
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func FuzzSM4XTSRoundTrip(f *testing.F) {
|
|
f.Add([]byte("fuzz-sm4-xts-seed-0000"), uint64(0), uint8(0))
|
|
f.Add(bytes.Repeat([]byte{0x5a}, 79), uint64(11), uint8(4))
|
|
|
|
key1 := []byte("0123456789abcdef")
|
|
key2 := []byte("fedcba9876543210")
|
|
|
|
f.Fuzz(func(t *testing.T, data []byte, dataUnitIndex uint64, selector uint8) {
|
|
bounded := data
|
|
if len(bounded) > 4097 {
|
|
bounded = bounded[:4097]
|
|
}
|
|
plain := xtsFuzzNormalizeData(bounded, 4096)
|
|
dataUnitSize := xtsFuzzDataUnitSize(selector)
|
|
|
|
enc, err := EncryptSM4XTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex)
|
|
if err != nil {
|
|
t.Fatalf("EncryptSM4XTSAt failed: %v", err)
|
|
}
|
|
dec, err := DecryptSM4XTSAt(enc, key1, key2, dataUnitSize, dataUnitIndex)
|
|
if err != nil {
|
|
t.Fatalf("DecryptSM4XTSAt failed: %v", err)
|
|
}
|
|
if !bytes.Equal(dec, plain) {
|
|
t.Fatalf("sm4 xts roundtrip mismatch")
|
|
}
|
|
|
|
encStream := &bytes.Buffer{}
|
|
if err := EncryptSM4XTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
|
|
t.Fatalf("EncryptSM4XTSStreamAt failed: %v", err)
|
|
}
|
|
if !bytes.Equal(encStream.Bytes(), enc) {
|
|
t.Fatalf("sm4 xts bytes/stream encrypt mismatch")
|
|
}
|
|
|
|
decStream := &bytes.Buffer{}
|
|
if err := DecryptSM4XTSStreamAt(decStream, bytes.NewReader(enc), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
|
|
t.Fatalf("DecryptSM4XTSStreamAt failed: %v", err)
|
|
}
|
|
if !bytes.Equal(decStream.Bytes(), plain) {
|
|
t.Fatalf("sm4 xts stream decrypt mismatch")
|
|
}
|
|
|
|
if len(bounded)%16 != 0 {
|
|
if _, err := EncryptSM4XTSAt(bounded, key1, key2, dataUnitSize, dataUnitIndex); err == nil {
|
|
t.Fatalf("expected sm4 xts bytes error for non-block-aligned input")
|
|
}
|
|
if err := EncryptSM4XTSStreamAt(&bytes.Buffer{}, bytes.NewReader(bounded), key1, key2, dataUnitSize, dataUnitIndex); err == nil {
|
|
t.Fatalf("expected sm4 xts stream error for non-block-aligned input")
|
|
}
|
|
}
|
|
})
|
|
}
|