From 19636d09c105a81c97e18328d007c4fadd8b9b7c Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Wed, 31 Jan 2024 13:08:51 +0800 Subject: [PATCH] sm4: code refactor for block --- sm4/block.go | 42 ++++++++++++++++++++---------------------- sm4/cipher.go | 2 +- sm4/cipher_asm.go | 2 +- 3 files changed, 22 insertions(+), 24 deletions(-) diff --git a/sm4/block.go b/sm4/block.go index 7529a49..46b8db2 100644 --- a/sm4/block.go +++ b/sm4/block.go @@ -8,9 +8,8 @@ import ( // Encrypt one block from src into dst, using the expanded key xk. func encryptBlockGo(xk []uint32, dst, src []byte) { - _ = src[15] // early bounds check - dst = dst[:16] // early bounds check - _ = xk[31] // bounds check elimination hint + _ = src[15] // early bounds check + _ = xk[31] // bounds check elimination hint var b0, b1, b2, b3 uint32 b0 = binary.BigEndian.Uint32(src[0:4]) @@ -18,11 +17,13 @@ func encryptBlockGo(xk []uint32, dst, src []byte) { b2 = binary.BigEndian.Uint32(src[8:12]) b3 = binary.BigEndian.Uint32(src[12:16]) + // First round uses s-box directly and T transformation. b0 ^= t(b1 ^ b2 ^ b3 ^ xk[0]) b1 ^= t(b2 ^ b3 ^ b0 ^ xk[1]) b2 ^= t(b3 ^ b0 ^ b1 ^ xk[2]) b3 ^= t(b0 ^ b1 ^ b2 ^ xk[3]) + // Middle rounds (unroll loop) uses precomputed tables. b0 ^= precompute_t(b1 ^ b2 ^ b3 ^ xk[4]) b1 ^= precompute_t(b2 ^ b3 ^ b0 ^ xk[5]) b2 ^= precompute_t(b3 ^ b0 ^ b1 ^ xk[6]) @@ -53,15 +54,17 @@ func encryptBlockGo(xk []uint32, dst, src []byte) { b2 ^= precompute_t(b3 ^ b0 ^ b1 ^ xk[26]) b3 ^= precompute_t(b0 ^ b1 ^ b2 ^ xk[27]) + // Last round uses s-box directly and and T transformation to produce output. b0 ^= t(b1 ^ b2 ^ b3 ^ xk[28]) b1 ^= t(b2 ^ b3 ^ b0 ^ xk[29]) b2 ^= t(b3 ^ b0 ^ b1 ^ xk[30]) b3 ^= t(b0 ^ b1 ^ b2 ^ xk[31]) - binary.BigEndian.PutUint32(dst[:], b3) - binary.BigEndian.PutUint32(dst[4:], b2) - binary.BigEndian.PutUint32(dst[8:], b1) - binary.BigEndian.PutUint32(dst[12:], b0) + _ = dst[15] // early bounds check + binary.BigEndian.PutUint32(dst[0:4], b3) + binary.BigEndian.PutUint32(dst[4:8], b2) + binary.BigEndian.PutUint32(dst[8:12], b1) + binary.BigEndian.PutUint32(dst[12:16], b0) } // Key expansion algorithm. @@ -149,19 +152,14 @@ func expandKeyGo(key []byte, enc, dec []uint32) { enc[31], dec[0] = b3, b3 } -// Decrypt one block from src into dst, using the expanded key xk. -func decryptBlockGo(xk []uint32, dst, src []byte) { - encryptBlockGo(xk, dst, src) -} - // T func t(in uint32) uint32 { var b uint32 b = uint32(sbox[in&0xff]) - b |= uint32(sbox[(in>>8)&0xff]) << 8 - b |= uint32(sbox[(in>>16)&0xff]) << 16 - b |= uint32(sbox[(in>>24)&0xff]) << 24 + b |= uint32(sbox[in>>8&0xff]) << 8 + b |= uint32(sbox[in>>16&0xff]) << 16 + b |= uint32(sbox[in>>24&0xff]) << 24 // L return b ^ (b<<2 | b>>30) ^ (b<<10 | b>>22) ^ (b<<18 | b>>14) ^ (b<<24 | b>>8) @@ -172,17 +170,17 @@ func t2(in uint32) uint32 { var b uint32 b = uint32(sbox[in&0xff]) - b |= uint32(sbox[(in>>8)&0xff]) << 8 - b |= uint32(sbox[(in>>16)&0xff]) << 16 - b |= uint32(sbox[(in>>24)&0xff]) << 24 + b |= uint32(sbox[in>>8&0xff]) << 8 + b |= uint32(sbox[in>>16&0xff]) << 16 + b |= uint32(sbox[in>>24&0xff]) << 24 // L2 return b ^ (b<<13 | b>>19) ^ (b<<23 | b>>9) } func precompute_t(in uint32) uint32 { - return sbox_t0[byte(in>>24)] ^ - sbox_t1[byte(in>>16)] ^ - sbox_t2[byte(in>>8)] ^ - sbox_t3[byte(in)] + return sbox_t0[in>>24&0xff] ^ + sbox_t1[in>>16&0xff] ^ + sbox_t2[in>>8&0xff] ^ + sbox_t3[in&0xff] } diff --git a/sm4/cipher.go b/sm4/cipher.go index 29316e6..e3225a7 100644 --- a/sm4/cipher.go +++ b/sm4/cipher.go @@ -65,5 +65,5 @@ func (c *sm4Cipher) Decrypt(dst, src []byte) { if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { panic("sm4: invalid buffer overlap") } - decryptBlockGo(c.dec, dst, src) + encryptBlockGo(c.dec, dst, src) } diff --git a/sm4/cipher_asm.go b/sm4/cipher_asm.go index f926575..56cbaf1 100644 --- a/sm4/cipher_asm.go +++ b/sm4/cipher_asm.go @@ -91,7 +91,7 @@ func (c *sm4CipherAsm) Decrypt(dst, src []byte) { if useAESNI4SingleBlock { encryptBlockAsm(&c.dec[0], &dst[0], &src[0], INST_AES) } else { - decryptBlockGo(c.dec, dst, src) + encryptBlockGo(c.dec, dst, src) } }