starcrypto/symm/fuzz_test.go

192 lines
5.7 KiB
Go
Raw Normal View History

package symm
import (
"bytes"
"testing"
)
// FuzzAesCBCRoundTrip checks that DecryptAesCBC undoes EncryptAesCBC for
// arbitrary plaintexts under a fixed 16-byte key and 16-byte IV, passing ""
// as the final argument to both calls (presumably the default padding mode —
// confirm against EncryptAesCBC).
func FuzzAesCBCRoundTrip(f *testing.F) {
	var (
		key = []byte("0123456789abcdef")
		iv  = []byte("abcdef9876543210")
	)
	for _, seed := range [][]byte{[]byte("fuzz-aes-cbc"), {}} {
		f.Add(seed)
	}
	f.Fuzz(func(t *testing.T, plaintext []byte) {
		ciphertext, err := EncryptAesCBC(plaintext, key, iv, "")
		if err != nil {
			t.Fatalf("EncryptAesCBC failed: %v", err)
		}
		recovered, err := DecryptAesCBC(ciphertext, key, iv, "")
		if err != nil {
			t.Fatalf("DecryptAesCBC failed: %v", err)
		}
		if !bytes.Equal(recovered, plaintext) {
			t.Fatalf("aes cbc fuzz roundtrip mismatch")
		}
	})
}
// FuzzChaCha20RoundTrip checks that DecryptChaCha20 undoes EncryptChaCha20
// for arbitrary plaintexts under a fixed 32-byte key and 12-byte nonce.
func FuzzChaCha20RoundTrip(f *testing.F) {
	var (
		key   = []byte("0123456789abcdef0123456789abcdef")
		nonce = []byte("123456789012")
	)
	for _, seed := range [][]byte{[]byte("fuzz-chacha20"), {}} {
		f.Add(seed)
	}
	f.Fuzz(func(t *testing.T, plaintext []byte) {
		ciphertext, err := EncryptChaCha20(plaintext, key, nonce)
		if err != nil {
			t.Fatalf("EncryptChaCha20 failed: %v", err)
		}
		recovered, err := DecryptChaCha20(ciphertext, key, nonce)
		if err != nil {
			t.Fatalf("DecryptChaCha20 failed: %v", err)
		}
		if !bytes.Equal(recovered, plaintext) {
			t.Fatalf("chacha20 fuzz roundtrip mismatch")
		}
	})
}
// FuzzAesCBCStreamRoundTrip checks the streaming AES-CBC helpers: the
// plaintext is encrypted from a reader into a buffer, that ciphertext is
// decrypted into a second buffer, and the result must equal the input.
// Key and IV are fixed 16-byte values; "" is passed as the final argument to
// both calls, matching FuzzAesCBCRoundTrip.
func FuzzAesCBCStreamRoundTrip(f *testing.F) {
	key := []byte("0123456789abcdef")
	iv := []byte("abcdef9876543210")
	for _, seed := range [][]byte{[]byte("fuzz-aes-stream"), {}} {
		f.Add(seed)
	}
	f.Fuzz(func(t *testing.T, plaintext []byte) {
		var sealed, opened bytes.Buffer
		if err := EncryptAesCBCStream(&sealed, bytes.NewReader(plaintext), key, iv, ""); err != nil {
			t.Fatalf("EncryptAesCBCStream failed: %v", err)
		}
		if err := DecryptAesCBCStream(&opened, bytes.NewReader(sealed.Bytes()), key, iv, ""); err != nil {
			t.Fatalf("DecryptAesCBCStream failed: %v", err)
		}
		if !bytes.Equal(opened.Bytes(), plaintext) {
			t.Fatalf("aes cbc stream fuzz roundtrip mismatch")
		}
	})
}
// xtsFuzzDataUnitSize maps a fuzzer-chosen selector onto one of the XTS
// data-unit sizes exercised by the fuzz targets: 16, 32, 64, 128, 256 or
// 512 bytes. The selector wraps modulo the number of sizes, so every uint8
// value yields a valid size.
func xtsFuzzDataUnitSize(selector uint8) int {
	// Each size is double the previous one, so the table is 16 << k for k in [0, 6).
	const sizeCount = 6
	return 16 << (int(selector) % sizeCount)
}
// xtsFuzzNormalizeData returns the longest prefix of data that is at most
// maxLen bytes long and a whole number of 16-byte blocks. The result shares
// data's backing array (no copy). A negative maxLen panics with a slice
// bounds error.
func xtsFuzzNormalizeData(data []byte, maxLen int) []byte {
	n := len(data)
	if maxLen < n {
		n = maxLen
	}
	// Clearing the low four bits floors n to a multiple of 16.
	return data[:n&^15]
}
// FuzzAesXTSRoundTrip fuzzes the AES-XTS slice and stream helpers under a
// fixed pair of 16-byte keys.
//
// The fuzzer supplies the plaintext bytes, the starting data-unit index
// (passed through unchanged) and a selector that xtsFuzzDataUnitSize maps to
// a data-unit size. The input is capped at 4097 bytes. The round-trip checks
// use that prefix truncated to at most 4096 bytes and rounded down to a
// multiple of 16. For that aligned plaintext the target checks that:
//
//   - EncryptAesXTSAt followed by DecryptAesXTSAt reproduces it,
//   - non-empty ciphertext differs from the plaintext,
//   - EncryptAesXTSStreamAt / DecryptAesXTSStreamAt produce exactly the
//     same bytes as the slice helpers.
//
// If the capped input length is not a multiple of 16, both encrypt helpers
// must also reject it.
func FuzzAesXTSRoundTrip(f *testing.F) {
	const (
		// maxPlain bounds the aligned plaintext used for the round trip.
		maxPlain = 4096
		// maxBounded keeps one byte beyond maxPlain so that an unaligned
		// length stays reachable for the rejection checks when the fuzzer
		// produces an oversized input.
		maxBounded = maxPlain + 1
	)
	f.Add([]byte("fuzz-aes-xts-seed-0000"), uint64(0), uint8(0))
	f.Add(bytes.Repeat([]byte{0x42}, 65), uint64(7), uint8(3))
	key1 := []byte("0123456789abcdef")
	key2 := []byte("fedcba9876543210")
	f.Fuzz(func(t *testing.T, data []byte, dataUnitIndex uint64, selector uint8) {
		bounded := data
		if len(bounded) > maxBounded {
			bounded = bounded[:maxBounded]
		}
		plain := xtsFuzzNormalizeData(bounded, maxPlain)
		dataUnitSize := xtsFuzzDataUnitSize(selector)

		enc, err := EncryptAesXTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex)
		if err != nil {
			t.Fatalf("EncryptAesXTSAt failed: %v", err)
		}
		// A round-trip comparison alone would accept a no-op "cipher". It is
		// vanishingly unlikely that a 128-bit block cipher maps every block of
		// a non-empty input to itself, so equal bytes mean encryption was
		// skipped.
		if len(plain) > 0 && bytes.Equal(enc, plain) {
			t.Fatalf("aes xts ciphertext equals plaintext")
		}
		dec, err := DecryptAesXTSAt(enc, key1, key2, dataUnitSize, dataUnitIndex)
		if err != nil {
			t.Fatalf("DecryptAesXTSAt failed: %v", err)
		}
		if !bytes.Equal(dec, plain) {
			t.Fatalf("aes xts roundtrip mismatch")
		}

		// The stream helpers must agree byte-for-byte with the slice helpers.
		encStream := &bytes.Buffer{}
		if err := EncryptAesXTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
			t.Fatalf("EncryptAesXTSStreamAt failed: %v", err)
		}
		if !bytes.Equal(encStream.Bytes(), enc) {
			t.Fatalf("aes xts bytes/stream encrypt mismatch")
		}
		decStream := &bytes.Buffer{}
		if err := DecryptAesXTSStreamAt(decStream, bytes.NewReader(enc), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
			t.Fatalf("DecryptAesXTSStreamAt failed: %v", err)
		}
		if !bytes.Equal(decStream.Bytes(), plain) {
			t.Fatalf("aes xts stream decrypt mismatch")
		}

		// Inputs that are not whole 16-byte blocks must be rejected.
		if len(bounded)%16 != 0 {
			if _, err := EncryptAesXTSAt(bounded, key1, key2, dataUnitSize, dataUnitIndex); err == nil {
				t.Fatalf("expected aes xts bytes error for non-block-aligned input")
			}
			if err := EncryptAesXTSStreamAt(&bytes.Buffer{}, bytes.NewReader(bounded), key1, key2, dataUnitSize, dataUnitIndex); err == nil {
				t.Fatalf("expected aes xts stream error for non-block-aligned input")
			}
		}
	})
}
// FuzzSM4XTSRoundTrip fuzzes the SM4-XTS slice and stream helpers under a
// fixed pair of 16-byte keys.
//
// The fuzzer supplies the plaintext bytes, the starting data-unit index
// (passed through unchanged) and a selector that xtsFuzzDataUnitSize maps to
// a data-unit size. The input is capped at 4097 bytes. The round-trip checks
// use that prefix truncated to at most 4096 bytes and rounded down to a
// multiple of 16. For that aligned plaintext the target checks that:
//
//   - EncryptSM4XTSAt followed by DecryptSM4XTSAt reproduces it,
//   - non-empty ciphertext differs from the plaintext,
//   - EncryptSM4XTSStreamAt / DecryptSM4XTSStreamAt produce exactly the
//     same bytes as the slice helpers.
//
// If the capped input length is not a multiple of 16, both encrypt helpers
// must also reject it.
func FuzzSM4XTSRoundTrip(f *testing.F) {
	const (
		// maxPlain bounds the aligned plaintext used for the round trip.
		maxPlain = 4096
		// maxBounded keeps one byte beyond maxPlain so that an unaligned
		// length stays reachable for the rejection checks when the fuzzer
		// produces an oversized input.
		maxBounded = maxPlain + 1
	)
	f.Add([]byte("fuzz-sm4-xts-seed-0000"), uint64(0), uint8(0))
	f.Add(bytes.Repeat([]byte{0x5a}, 79), uint64(11), uint8(4))
	key1 := []byte("0123456789abcdef")
	key2 := []byte("fedcba9876543210")
	f.Fuzz(func(t *testing.T, data []byte, dataUnitIndex uint64, selector uint8) {
		bounded := data
		if len(bounded) > maxBounded {
			bounded = bounded[:maxBounded]
		}
		plain := xtsFuzzNormalizeData(bounded, maxPlain)
		dataUnitSize := xtsFuzzDataUnitSize(selector)

		enc, err := EncryptSM4XTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex)
		if err != nil {
			t.Fatalf("EncryptSM4XTSAt failed: %v", err)
		}
		// A round-trip comparison alone would accept a no-op "cipher". SM4 has
		// a 128-bit block, so it is vanishingly unlikely that every block of a
		// non-empty input maps to itself. Equal bytes mean encryption was
		// skipped.
		if len(plain) > 0 && bytes.Equal(enc, plain) {
			t.Fatalf("sm4 xts ciphertext equals plaintext")
		}
		dec, err := DecryptSM4XTSAt(enc, key1, key2, dataUnitSize, dataUnitIndex)
		if err != nil {
			t.Fatalf("DecryptSM4XTSAt failed: %v", err)
		}
		if !bytes.Equal(dec, plain) {
			t.Fatalf("sm4 xts roundtrip mismatch")
		}

		// The stream helpers must agree byte-for-byte with the slice helpers.
		encStream := &bytes.Buffer{}
		if err := EncryptSM4XTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
			t.Fatalf("EncryptSM4XTSStreamAt failed: %v", err)
		}
		if !bytes.Equal(encStream.Bytes(), enc) {
			t.Fatalf("sm4 xts bytes/stream encrypt mismatch")
		}
		decStream := &bytes.Buffer{}
		if err := DecryptSM4XTSStreamAt(decStream, bytes.NewReader(enc), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
			t.Fatalf("DecryptSM4XTSStreamAt failed: %v", err)
		}
		if !bytes.Equal(decStream.Bytes(), plain) {
			t.Fatalf("sm4 xts stream decrypt mismatch")
		}

		// Inputs that are not whole 16-byte blocks must be rejected.
		if len(bounded)%16 != 0 {
			if _, err := EncryptSM4XTSAt(bounded, key1, key2, dataUnitSize, dataUnitIndex); err == nil {
				t.Fatalf("expected sm4 xts bytes error for non-block-aligned input")
			}
			if err := EncryptSM4XTSStreamAt(&bytes.Buffer{}, bytes.NewReader(bounded), key1, key2, dataUnitSize, dataUnitIndex); err == nil {
				t.Fatalf("expected sm4 xts stream error for non-block-aligned input")
			}
		}
	})
}