From 8a2ba16639bbba64a4a51b6124767b4a26330f25 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Fri, 30 Aug 2024 16:25:25 +0800 Subject: [PATCH] internal/cryptotest: add tests for the cipher.AEAD interface --- cipher/gcm_sm4_test.go | 41 +++- internal/cryptotest/aead.go | 387 ++++++++++++++++++++++++++++++++++++ 2 files changed, 427 insertions(+), 1 deletion(-) create mode 100644 internal/cryptotest/aead.go diff --git a/cipher/gcm_sm4_test.go b/cipher/gcm_sm4_test.go index be09fa2..e3796e4 100644 --- a/cipher/gcm_sm4_test.go +++ b/cipher/gcm_sm4_test.go @@ -5,9 +5,11 @@ import ( "crypto/cipher" "crypto/rand" "encoding/hex" + "fmt" "io" "testing" + "github.com/emmansun/gmsm/internal/cryptotest" "github.com/emmansun/gmsm/sm4" ) @@ -398,7 +400,7 @@ func TestSM4GCMRandom(t *testing.T) { key := []byte("0123456789ABCDEF") nonce := []byte("0123456789AB") plaintext := make([]byte, 0x198) - + io.ReadFull(rand.Reader, plaintext) c, err := sm4.NewCipher(key) if err != nil { @@ -418,3 +420,40 @@ func TestSM4GCMRandom(t *testing.T) { t.Error("gcm seal/open 408 bytes fail") } } + +// Test GCM against the general cipher.AEAD interface tester. +func TestGCMAEAD(t *testing.T) { + minTagSize := 12 + + for _, keySize := range []int{128} { + // Use AES as underlying block cipher at different key sizes for GCM. + t.Run(fmt.Sprintf("SM4-%d", keySize), func(t *testing.T) { + rng := newRandReader(t) + + key := make([]byte, keySize/8) + rng.Read(key) + + block, err := sm4.NewCipher(key) + if err != nil { + panic(err) + } + + // Test GCM with the current AES block with the standard nonce and tag + // sizes. + cryptotest.TestAEAD(t, func() (cipher.AEAD, error) { return cipher.NewGCM(block) }) + + // Test non-standard tag sizes. + t.Run("MinTagSize", func(t *testing.T) { + cryptotest.TestAEAD(t, func() (cipher.AEAD, error) { return cipher.NewGCMWithTagSize(block, minTagSize) }) + }) + + // Test non-standard nonce sizes. + for _, nonceSize := range []int{1, 16, 100} { + t.Run(fmt.Sprintf("NonceSize-%d", nonceSize), func(t *testing.T) { + + cryptotest.TestAEAD(t, func() (cipher.AEAD, error) { return cipher.NewGCMWithNonceSize(block, nonceSize) }) + }) + } + }) + } +} diff --git a/internal/cryptotest/aead.go b/internal/cryptotest/aead.go new file mode 100644 index 0000000..e17cdf8 --- /dev/null +++ b/internal/cryptotest/aead.go @@ -0,0 +1,387 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cryptotest + +import ( + "bytes" + "crypto/cipher" + "fmt" + "testing" +) + +var lengths = []int{0, 156, 8192, 8193, 8208} + +// MakeAEAD returns a cipher.AEAD instance. +// +// Multiple calls to MakeAEAD must return equivalent instances, so for example +// the key must be fixed. +type MakeAEAD func() (cipher.AEAD, error) + +// TestAEAD performs a set of tests on cipher.AEAD implementations, checking +// the documented requirements of NonceSize, Overhead, Seal and Open. +func TestAEAD(t *testing.T, mAEAD MakeAEAD) { + aead, err := mAEAD() + if err != nil { + t.Fatal(err) + } + + t.Run("Roundtrip", func(t *testing.T) { + + // Test all combinations of plaintext and additional data lengths. + for _, ptLen := range lengths { + for _, adLen := range lengths { + t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) { + rng := newRandReader(t) + + nonce := make([]byte, aead.NonceSize()) + rng.Read(nonce) + + before, addData := make([]byte, adLen), make([]byte, ptLen) + rng.Read(before) + rng.Read(addData) + + ciphertext := sealMsg(t, aead, nil, nonce, before, addData) + after := openWithoutError(t, aead, nil, nonce, ciphertext, addData) + + if !bytes.Equal(after, before) { + t.Errorf("plaintext is different after a seal/open cycle; got %s, want %s", truncateHex(after), truncateHex(before)) + } + }) + } + } + }) + + t.Run("InputNotModified", func(t *testing.T) { + + // Test all combinations of plaintext and additional data lengths. + for _, ptLen := range lengths { + for _, adLen := range lengths { + t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) { + t.Run("Seal", func(t *testing.T) { + rng := newRandReader(t) + + nonce := make([]byte, aead.NonceSize()) + rng.Read(nonce) + + src, before := make([]byte, ptLen), make([]byte, ptLen) + rng.Read(src) + copy(before, src) + + addData := make([]byte, adLen) + rng.Read(addData) + + sealMsg(t, aead, nil, nonce, src, addData) + if !bytes.Equal(src, before) { + t.Errorf("Seal modified src; got %s, want %s", truncateHex(src), truncateHex(before)) + } + }) + + t.Run("Open", func(t *testing.T) { + rng := newRandReader(t) + + nonce := make([]byte, aead.NonceSize()) + rng.Read(nonce) + + plaintext, addData := make([]byte, ptLen), make([]byte, adLen) + rng.Read(plaintext) + rng.Read(addData) + + // Record the ciphertext that shouldn't be modified as the input of + // Open. + ciphertext := sealMsg(t, aead, nil, nonce, plaintext, addData) + before := make([]byte, len(ciphertext)) + copy(before, ciphertext) + + openWithoutError(t, aead, nil, nonce, ciphertext, addData) + if !bytes.Equal(ciphertext, before) { + t.Errorf("Open modified src; got %s, want %s", truncateHex(ciphertext), truncateHex(before)) + } + }) + }) + } + } + }) + + t.Run("BufferOverlap", func(t *testing.T) { + + // Test all combinations of plaintext and additional data lengths. + for _, ptLen := range lengths { + if ptLen <= 1 { // We need enough room for an overlap to occur. + continue + } + for _, adLen := range lengths { + t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) { + t.Run("Seal", func(t *testing.T) { + rng := newRandReader(t) + + nonce := make([]byte, aead.NonceSize()) + rng.Read(nonce) + + // Make a buffer that can hold a plaintext and ciphertext as we + // overlap their slices to check for panic on inexact overlaps. + ctLen := ptLen + aead.Overhead() + buff := make([]byte, ptLen+ctLen) + rng.Read(buff) + + addData := make([]byte, adLen) + rng.Read(addData) + + // Make plaintext and dst slices point to same array with inexact overlap. + plaintext := buff[:ptLen] + dst := buff[1:1] // Shift dst to not start at start of plaintext. + mustPanic(t, "invalid buffer overlap", func() { sealMsg(t, aead, dst, nonce, plaintext, addData) }) + + // Only overlap on one byte + plaintext = buff[:ptLen] + dst = buff[ptLen-1 : ptLen-1] + mustPanic(t, "invalid buffer overlap", func() { sealMsg(t, aead, dst, nonce, plaintext, addData) }) + }) + + t.Run("Open", func(t *testing.T) { + rng := newRandReader(t) + + nonce := make([]byte, aead.NonceSize()) + rng.Read(nonce) + + // Create a valid ciphertext to test Open with. + plaintext := make([]byte, ptLen) + rng.Read(plaintext) + addData := make([]byte, adLen) + rng.Read(addData) + validCT := sealMsg(t, aead, nil, nonce, plaintext, addData) + + // Make a buffer that can hold a plaintext and ciphertext as we + // overlap their slices to check for panic on inexact overlaps. + buff := make([]byte, ptLen+len(validCT)) + + // Make ciphertext and dst slices point to same array with inexact overlap. + ciphertext := buff[:len(validCT)] + copy(ciphertext, validCT) + dst := buff[1:1] // Shift dst to not start at start of ciphertext. + mustPanic(t, "invalid buffer overlap", func() { aead.Open(dst, nonce, ciphertext, addData) }) + + // Only overlap on one byte. + ciphertext = buff[:len(validCT)] + copy(ciphertext, validCT) + // Make sure it is the actual ciphertext being overlapped and not + // the hash digest which might be extracted/truncated in some + // implementations: Go one byte past the hash digest/tag and into + // the ciphertext. + beforeTag := len(validCT) - aead.Overhead() + dst = buff[beforeTag-1 : beforeTag-1] + mustPanic(t, "invalid buffer overlap", func() { aead.Open(dst, nonce, ciphertext, addData) }) + }) + }) + } + } + }) + + t.Run("AppendDst", func(t *testing.T) { + + // Test all combinations of plaintext and additional data lengths. + for _, ptLen := range lengths { + for _, adLen := range lengths { + t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) { + + t.Run("Seal", func(t *testing.T) { + rng := newRandReader(t) + + nonce := make([]byte, aead.NonceSize()) + rng.Read(nonce) + + shortBuff := []byte("a") + longBuff := make([]byte, 512) + rng.Read(longBuff) + prefixes := [][]byte{shortBuff, longBuff} + + // Check each prefix gets appended to by Seal with altering them. + for _, prefix := range prefixes { + plaintext, addData := make([]byte, ptLen), make([]byte, adLen) + rng.Read(plaintext) + rng.Read(addData) + out := sealMsg(t, aead, prefix, nonce, plaintext, addData) + + // Check that Seal didn't alter the prefix + if !bytes.Equal(out[0:len(prefix)], prefix) { + t.Errorf("Seal alters dst instead of appending; got %s, want %s", truncateHex(out[0:len(prefix)]), truncateHex(prefix)) + } + + ciphertext := out[len(prefix):] + // Check that the appended ciphertext wasn't affected by the prefix + if expectedCT := sealMsg(t, aead, nil, nonce, plaintext, addData); !bytes.Equal(ciphertext, expectedCT) { + t.Errorf("Seal behavior affected by pre-existing data in dst; got %s, want %s", truncateHex(ciphertext), truncateHex(expectedCT)) + } + } + }) + + t.Run("Open", func(t *testing.T) { + rng := newRandReader(t) + + nonce := make([]byte, aead.NonceSize()) + rng.Read(nonce) + + shortBuff := []byte("a") + longBuff := make([]byte, 512) + rng.Read(longBuff) + prefixes := [][]byte{shortBuff, longBuff} + + // Check each prefix gets appended to by Open with altering them. + for _, prefix := range prefixes { + before, addData := make([]byte, adLen), make([]byte, ptLen) + rng.Read(before) + rng.Read(addData) + ciphertext := sealMsg(t, aead, nil, nonce, before, addData) + + out := openWithoutError(t, aead, prefix, nonce, ciphertext, addData) + + // Check that Open didn't alter the prefix + if !bytes.Equal(out[0:len(prefix)], prefix) { + t.Errorf("Open alters dst instead of appending; got %s, want %s", truncateHex(out[0:len(prefix)]), truncateHex(prefix)) + } + + after := out[len(prefix):] + // Check that the appended plaintext wasn't affected by the prefix + if !bytes.Equal(after, before) { + t.Errorf("Open behavior affected by pre-existing data in dst; got %s, want %s", truncateHex(after), truncateHex(before)) + } + } + }) + }) + } + } + }) + + t.Run("WrongNonce", func(t *testing.T) { + + // Test all combinations of plaintext and additional data lengths. + for _, ptLen := range lengths { + for _, adLen := range lengths { + t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) { + rng := newRandReader(t) + + nonce := make([]byte, aead.NonceSize()) + rng.Read(nonce) + + plaintext, addData := make([]byte, ptLen), make([]byte, adLen) + rng.Read(plaintext) + rng.Read(addData) + + ciphertext := sealMsg(t, aead, nil, nonce, plaintext, addData) + + // Perturb the nonce and check for an error when Opening + alterNonce := make([]byte, aead.NonceSize()) + copy(alterNonce, nonce) + alterNonce[len(alterNonce)-1] += 1 + _, err := aead.Open(nil, alterNonce, ciphertext, addData) + + if err == nil { + t.Errorf("Open did not error when given different nonce than Sealed with") + } + }) + } + } + }) + + t.Run("WrongAddData", func(t *testing.T) { + + // Test all combinations of plaintext and additional data lengths. + for _, ptLen := range lengths { + for _, adLen := range lengths { + if adLen == 0 { + continue + } + + t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) { + rng := newRandReader(t) + + nonce := make([]byte, aead.NonceSize()) + rng.Read(nonce) + + plaintext, addData := make([]byte, ptLen), make([]byte, adLen) + rng.Read(plaintext) + rng.Read(addData) + + ciphertext := sealMsg(t, aead, nil, nonce, plaintext, addData) + + // Perturb the Additional Data and check for an error when Opening + alterAD := make([]byte, adLen) + copy(alterAD, addData) + alterAD[len(alterAD)-1] += 1 + _, err := aead.Open(nil, nonce, ciphertext, alterAD) + + if err == nil { + t.Errorf("Open did not error when given different Additional Data than Sealed with") + } + }) + } + } + }) + + t.Run("WrongCiphertext", func(t *testing.T) { + + // Test all combinations of plaintext and additional data lengths. + for _, ptLen := range lengths { + for _, adLen := range lengths { + + t.Run(fmt.Sprintf("Plaintext-Length=%d,AddData-Length=%d", ptLen, adLen), func(t *testing.T) { + rng := newRandReader(t) + + nonce := make([]byte, aead.NonceSize()) + rng.Read(nonce) + + plaintext, addData := make([]byte, ptLen), make([]byte, adLen) + rng.Read(plaintext) + rng.Read(addData) + + ciphertext := sealMsg(t, aead, nil, nonce, plaintext, addData) + + // Perturb the ciphertext and check for an error when Opening + alterCT := make([]byte, len(ciphertext)) + copy(alterCT, ciphertext) + alterCT[len(alterCT)-1] += 1 + _, err := aead.Open(nil, nonce, alterCT, addData) + + if err == nil { + t.Errorf("Open did not error when given different ciphertext than was produced by Seal") + } + }) + } + } + }) +} + +// Helper function to Seal a plaintext with additional data. Checks that +// ciphertext isn't bigger than the plaintext length plus Overhead() +func sealMsg(t *testing.T, aead cipher.AEAD, ciphertext, nonce, plaintext, addData []byte) []byte { + t.Helper() + + initialLen := len(ciphertext) + + ciphertext = aead.Seal(ciphertext, nonce, plaintext, addData) + + lenCT := len(ciphertext) - initialLen + + // Appended ciphertext shouldn't ever be longer than the length of the + // plaintext plus Overhead + if lenCT > len(plaintext)+aead.Overhead() { + t.Errorf("length of ciphertext from Seal exceeds length of plaintext by more than Overhead(); got %d, want <=%d", lenCT, len(plaintext)+aead.Overhead()) + } + + return ciphertext +} + +// Helper function to Open and authenticate ciphertext. Checks that Open +// doesn't error (assuming ciphertext was well-formed with corresponding nonce +// and additional data). +func openWithoutError(t *testing.T, aead cipher.AEAD, plaintext, nonce, ciphertext, addData []byte) []byte { + t.Helper() + + plaintext, err := aead.Open(plaintext, nonce, ciphertext, addData) + if err != nil { + t.Fatalf("Open returned error on properly formed ciphertext; got \"%s\", want \"nil\"", err) + } + + return plaintext +}