diff --git a/cipher/xts.go b/cipher/xts.go index dd21de5..fe23a73 100644 --- a/cipher/xts.go +++ b/cipher/xts.go @@ -220,12 +220,12 @@ func (c *xtsEncrypter) CryptBlocks(ciphertext, plaintext []byte) { // is there a final partial block to handle? if remain := len(plaintext); remain > 0 { var x [blockSize]byte - //Copy the final ciphertext bytes - copy(ciphertext, lastCiphertext[:remain]) //Copy the final plaintext bytes copy(x[:], plaintext) //Steal ciphertext to complete the block copy(x[remain:], lastCiphertext[remain:blockSize]) + //Copy the final ciphertext bytes + copy(ciphertext, lastCiphertext[:remain]) //Merge the tweak into the input block subtle.XORBytes(x[:], x[:], c.tweak[:]) //Encrypt the final block using K1 @@ -290,12 +290,12 @@ func (c *xtsDecrypter) CryptBlocks(plaintext, ciphertext []byte) { //Retrieve the length of the final block remain -= blockSize - //Copy the final plaintext bytes - copy(plaintext[blockSize:], plaintext) //Copy the final ciphertext bytes copy(x[:], ciphertext[blockSize:]) //Steal ciphertext to complete the block copy(x[remain:], plaintext[remain:blockSize]) + //Copy the final plaintext bytes + copy(plaintext[blockSize:], plaintext) } else { //The last block contains exactly 128 bits copy(x[:], ciphertext) diff --git a/cipher/xts_test.go b/cipher/xts_test.go index fe66987..7374548 100644 --- a/cipher/xts_test.go +++ b/cipher/xts_test.go @@ -81,17 +81,18 @@ func TestXTSWithAES(t *testing.T) { plaintext := fromHex(test.plaintext) ciphertext := make([]byte, len(plaintext)) - encrypter.CryptBlocks(ciphertext, plaintext) + copy(ciphertext, plaintext) + + encrypter.CryptBlocks(ciphertext, ciphertext) expectedCiphertext := fromHex(test.ciphertext) if !bytes.Equal(ciphertext, expectedCiphertext) { t.Errorf("#%d: encrypted failed, got: %x, want: %x", i, ciphertext, expectedCiphertext) continue } - decrypted := make([]byte, len(ciphertext)) - decrypter.CryptBlocks(decrypted, ciphertext) - if !bytes.Equal(decrypted, plaintext) { - t.Errorf("#%d: decryption failed, got: %x, want: %x", i, decrypted, plaintext) + decrypter.CryptBlocks(ciphertext, ciphertext) + if !bytes.Equal(ciphertext, plaintext) { + t.Errorf("#%d: decryption failed, got: %x, want: %x", i, ciphertext, plaintext) } } }