package symm

import (
	"bytes"
	"testing"
)

// FuzzAesCBCRoundTrip feeds arbitrary payloads through EncryptAesCBC and
// DecryptAesCBC with a fixed 16-byte key and 16-byte IV, failing when either
// call returns an error or the decrypted bytes differ from the input.
func FuzzAesCBCRoundTrip(f *testing.F) {
	// Seed corpus: one short text payload and the empty payload.
	for _, seed := range [][]byte{[]byte("fuzz-aes-cbc"), {}} {
		f.Add(seed)
	}

	var (
		key = []byte("0123456789abcdef")
		iv  = []byte("abcdef9876543210")
	)

	f.Fuzz(func(t *testing.T, data []byte) {
		// NOTE(review): the empty fourth argument presumably selects the
		// default padding mode; confirm against EncryptAesCBC.
		ciphertext, err := EncryptAesCBC(data, key, iv, "")
		if err != nil {
			t.Fatalf("EncryptAesCBC failed: %v", err)
		}

		recovered, err := DecryptAesCBC(ciphertext, key, iv, "")
		if err != nil {
			t.Fatalf("DecryptAesCBC failed: %v", err)
		}

		if !bytes.Equal(recovered, data) {
			t.Fatal("aes cbc fuzz roundtrip mismatch")
		}
	})
}

// FuzzChaCha20RoundTrip feeds arbitrary payloads through EncryptChaCha20 and
// DecryptChaCha20 with a fixed 32-byte key and 12-byte nonce, failing when
// either call returns an error or the decrypted bytes differ from the input.
func FuzzChaCha20RoundTrip(f *testing.F) {
	// Seed corpus: one short text payload and the empty payload.
	for _, seed := range [][]byte{[]byte("fuzz-chacha20"), {}} {
		f.Add(seed)
	}

	var (
		key   = []byte("0123456789abcdef0123456789abcdef")
		nonce = []byte("123456789012")
	)

	f.Fuzz(func(t *testing.T, data []byte) {
		ciphertext, err := EncryptChaCha20(data, key, nonce)
		if err != nil {
			t.Fatalf("EncryptChaCha20 failed: %v", err)
		}

		recovered, err := DecryptChaCha20(ciphertext, key, nonce)
		if err != nil {
			t.Fatalf("DecryptChaCha20 failed: %v", err)
		}

		if !bytes.Equal(recovered, data) {
			t.Fatal("chacha20 fuzz roundtrip mismatch")
		}
	})
}

// FuzzAesCBCStreamRoundTrip is the streaming counterpart of
// FuzzAesCBCRoundTrip: it pipes each payload through EncryptAesCBCStream into
// a buffer, pipes that buffer through DecryptAesCBCStream, and fails when
// either call returns an error or the final bytes differ from the input.
func FuzzAesCBCStreamRoundTrip(f *testing.F) {
	// Seed corpus: one short text payload and the empty payload.
	for _, seed := range [][]byte{[]byte("fuzz-aes-stream"), {}} {
		f.Add(seed)
	}

	var (
		key = []byte("0123456789abcdef")
		iv  = []byte("abcdef9876543210")
	)

	f.Fuzz(func(t *testing.T, data []byte) {
		var sealed, opened bytes.Buffer

		err := EncryptAesCBCStream(&sealed, bytes.NewReader(data), key, iv, "")
		if err != nil {
			t.Fatalf("EncryptAesCBCStream failed: %v", err)
		}

		err = DecryptAesCBCStream(&opened, bytes.NewReader(sealed.Bytes()), key, iv, "")
		if err != nil {
			t.Fatalf("DecryptAesCBCStream failed: %v", err)
		}

		if !bytes.Equal(opened.Bytes(), data) {
			t.Fatal("aes cbc stream fuzz roundtrip mismatch")
		}
	})
}

// xtsFuzzDataUnitSize maps an arbitrary fuzz-provided selector onto one of six
// data-unit sizes: 16, 32, 64, 128, 256 or 512.
func xtsFuzzDataUnitSize(selector uint8) int {
	// The candidates are consecutive powers of two starting at 16, so the
	// selector (reduced modulo the number of candidates) is the shift amount.
	const choices = 6
	return 16 << (uint(selector) % choices)
}

// xtsFuzzNormalizeData returns a prefix of data that is at most maxLen bytes
// long and whose length is a multiple of 16. The result shares data's backing
// array.
func xtsFuzzNormalizeData(data []byte, maxLen int) []byte {
	capped := data
	if maxLen < len(capped) {
		capped = capped[:maxLen]
	}

	// Clearing the low four bits rounds the length down to a multiple of 16.
	aligned := len(capped) &^ 15
	return capped[:aligned]
}

// FuzzAesXTSRoundTrip exercises the AES-XTS helpers with fuzzed payloads,
// data-unit indexes and data-unit sizes. For every input it checks that:
//   - EncryptAesXTSAt followed by DecryptAesXTSAt recovers the plaintext,
//   - EncryptAesXTSStreamAt produces the same bytes as EncryptAesXTSAt,
//   - DecryptAesXTSStreamAt recovers the plaintext, and
//   - both encrypt entry points return an error when the input length is not
//     a multiple of 16.
func FuzzAesXTSRoundTrip(f *testing.F) {
	f.Add([]byte("fuzz-aes-xts-seed-0000"), uint64(0), uint8(0))
	f.Add(bytes.Repeat([]byte{0x42}, 65), uint64(7), uint8(3))

	var (
		key1 = []byte("0123456789abcdef")
		key2 = []byte("fedcba9876543210")
	)

	f.Fuzz(func(t *testing.T, data []byte, dataUnitIndex uint64, selector uint8) {
		// Keep one byte beyond the plaintext cap so that a maximum-size input
		// can still be used for the misaligned-input check below.
		const plainCap = 4096
		raw := data
		if len(raw) > plainCap+1 {
			raw = raw[:plainCap+1]
		}
		plain := xtsFuzzNormalizeData(raw, plainCap)
		unitSize := xtsFuzzDataUnitSize(selector)

		// Byte-slice API round trip.
		sealed, err := EncryptAesXTSAt(plain, key1, key2, unitSize, dataUnitIndex)
		if err != nil {
			t.Fatalf("EncryptAesXTSAt failed: %v", err)
		}
		opened, err := DecryptAesXTSAt(sealed, key1, key2, unitSize, dataUnitIndex)
		if err != nil {
			t.Fatalf("DecryptAesXTSAt failed: %v", err)
		}
		if !bytes.Equal(opened, plain) {
			t.Fatal("aes xts roundtrip mismatch")
		}

		// The streaming encryptor must agree with the byte-slice encryptor.
		var sealedStream bytes.Buffer
		err = EncryptAesXTSStreamAt(&sealedStream, bytes.NewReader(plain), key1, key2, unitSize, dataUnitIndex)
		if err != nil {
			t.Fatalf("EncryptAesXTSStreamAt failed: %v", err)
		}
		if !bytes.Equal(sealedStream.Bytes(), sealed) {
			t.Fatal("aes xts bytes/stream encrypt mismatch")
		}

		// The streaming decryptor must recover the plaintext.
		var openedStream bytes.Buffer
		err = DecryptAesXTSStreamAt(&openedStream, bytes.NewReader(sealed), key1, key2, unitSize, dataUnitIndex)
		if err != nil {
			t.Fatalf("DecryptAesXTSStreamAt failed: %v", err)
		}
		if !bytes.Equal(openedStream.Bytes(), plain) {
			t.Fatal("aes xts stream decrypt mismatch")
		}

		// Negative checks only apply when the raw input is misaligned.
		if len(raw)%16 == 0 {
			return
		}
		if _, err := EncryptAesXTSAt(raw, key1, key2, unitSize, dataUnitIndex); err == nil {
			t.Fatal("expected aes xts bytes error for non-block-aligned input")
		}
		err = EncryptAesXTSStreamAt(new(bytes.Buffer), bytes.NewReader(raw), key1, key2, unitSize, dataUnitIndex)
		if err == nil {
			t.Fatal("expected aes xts stream error for non-block-aligned input")
		}
	})
}

// FuzzSM4XTSRoundTrip exercises the SM4-XTS helpers with fuzzed payloads,
// data-unit indexes and data-unit sizes. For every input it checks that:
//   - EncryptSM4XTSAt followed by DecryptSM4XTSAt recovers the plaintext,
//   - EncryptSM4XTSStreamAt produces the same bytes as EncryptSM4XTSAt,
//   - DecryptSM4XTSStreamAt recovers the plaintext, and
//   - both encrypt entry points return an error when the input length is not
//     a multiple of 16.
func FuzzSM4XTSRoundTrip(f *testing.F) {
	f.Add([]byte("fuzz-sm4-xts-seed-0000"), uint64(0), uint8(0))
	f.Add(bytes.Repeat([]byte{0x5a}, 79), uint64(11), uint8(4))

	var (
		key1 = []byte("0123456789abcdef")
		key2 = []byte("fedcba9876543210")
	)

	f.Fuzz(func(t *testing.T, data []byte, dataUnitIndex uint64, selector uint8) {
		// Keep one byte beyond the plaintext cap so that a maximum-size input
		// can still be used for the misaligned-input check below.
		const plainCap = 4096
		raw := data
		if len(raw) > plainCap+1 {
			raw = raw[:plainCap+1]
		}
		plain := xtsFuzzNormalizeData(raw, plainCap)
		unitSize := xtsFuzzDataUnitSize(selector)

		// Byte-slice API round trip.
		sealed, err := EncryptSM4XTSAt(plain, key1, key2, unitSize, dataUnitIndex)
		if err != nil {
			t.Fatalf("EncryptSM4XTSAt failed: %v", err)
		}
		opened, err := DecryptSM4XTSAt(sealed, key1, key2, unitSize, dataUnitIndex)
		if err != nil {
			t.Fatalf("DecryptSM4XTSAt failed: %v", err)
		}
		if !bytes.Equal(opened, plain) {
			t.Fatal("sm4 xts roundtrip mismatch")
		}

		// The streaming encryptor must agree with the byte-slice encryptor.
		var sealedStream bytes.Buffer
		err = EncryptSM4XTSStreamAt(&sealedStream, bytes.NewReader(plain), key1, key2, unitSize, dataUnitIndex)
		if err != nil {
			t.Fatalf("EncryptSM4XTSStreamAt failed: %v", err)
		}
		if !bytes.Equal(sealedStream.Bytes(), sealed) {
			t.Fatal("sm4 xts bytes/stream encrypt mismatch")
		}

		// The streaming decryptor must recover the plaintext.
		var openedStream bytes.Buffer
		err = DecryptSM4XTSStreamAt(&openedStream, bytes.NewReader(sealed), key1, key2, unitSize, dataUnitIndex)
		if err != nil {
			t.Fatalf("DecryptSM4XTSStreamAt failed: %v", err)
		}
		if !bytes.Equal(openedStream.Bytes(), plain) {
			t.Fatal("sm4 xts stream decrypt mismatch")
		}

		// Negative checks only apply when the raw input is misaligned.
		if len(raw)%16 == 0 {
			return
		}
		if _, err := EncryptSM4XTSAt(raw, key1, key2, unitSize, dataUnitIndex); err == nil {
			t.Fatal("expected sm4 xts bytes error for non-block-aligned input")
		}
		err = EncryptSM4XTSStreamAt(new(bytes.Buffer), bytes.NewReader(raw), key1, key2, unitSize, dataUnitIndex)
		if err == nil {
			t.Fatal("expected sm4 xts stream error for non-block-aligned input")
		}
	})
}