diff --git a/sm4/block.go b/sm4/block.go index c7134d8..14bf11e 100644 --- a/sm4/block.go +++ b/sm4/block.go @@ -7,8 +7,6 @@ import ( "math/bits" ) -type convert func(uint32) uint32 - // Encrypt one block from src into dst, using the expanded key xk. func encryptBlockGo(xk []uint32, dst, src []byte) { _ = src[15] // early bounds check @@ -71,34 +69,85 @@ func encryptBlockGo(xk []uint32, dst, src []byte) { func expandKeyGo(key []byte, enc, dec []uint32) { // Encryption key setup. enc = enc[:rounds] - var i int - var mk [4]uint32 - var k [rounds + 4]uint32 - - key = key[:KeySize] - mk[0] = binary.BigEndian.Uint32(key) - k[0] = mk[0] ^ fk[0] - mk[1] = binary.BigEndian.Uint32(key[4:]) - k[1] = mk[1] ^ fk[1] - mk[2] = binary.BigEndian.Uint32(key[8:]) - k[2] = mk[2] ^ fk[2] - mk[3] = binary.BigEndian.Uint32(key[12:]) - k[3] = mk[3] ^ fk[3] - - for i = 0; i < rounds; i++ { - k[i+4] = k[i] ^ t2(k[i+1]^k[i+2]^k[i+3]^ck[i]) - enc[i] = k[i+4] - } - - // Derive decryption key from encryption key. - if dec == nil { - return - } - dec = dec[:rounds] - for i = 0; i < rounds; i++ { - dec[i] = enc[rounds-1-i] - } + key = key[:KeySize] + var b0, b1, b2, b3 uint32 + b0 = binary.BigEndian.Uint32(key[:4]) ^ fk[0] + b1 = binary.BigEndian.Uint32(key[4:8]) ^ fk[1] + b2 = binary.BigEndian.Uint32(key[8:12]) ^ fk[2] + b3 = binary.BigEndian.Uint32(key[12:16]) ^ fk[3] + + b0 = b0 ^ t2(b1^b2^b3^ck[0]) + enc[0], dec[31] = b0, b0 + b1 = b1 ^ t2(b2^b3^b0^ck[1]) + enc[1], dec[30] = b1, b1 + b2 = b2 ^ t2(b3^b0^b1^ck[2]) + enc[2], dec[29] = b2, b2 + b3 = b3 ^ t2(b0^b1^b2^ck[3]) + enc[3], dec[28] = b3, b3 + + b0 = b0 ^ t2(b1^b2^b3^ck[4]) + enc[4], dec[27] = b0, b0 + b1 = b1 ^ t2(b2^b3^b0^ck[5]) + enc[5], dec[26] = b1, b1 + b2 = b2 ^ t2(b3^b0^b1^ck[6]) + enc[6], dec[25] = b2, b2 + b3 = b3 ^ t2(b0^b1^b2^ck[7]) + enc[7], dec[24] = b3, b3 + + b0 = b0 ^ t2(b1^b2^b3^ck[8]) + enc[8], dec[23] = b0, b0 + b1 = b1 ^ t2(b2^b3^b0^ck[9]) + enc[9], dec[22] = b1, b1 + b2 = b2 ^ t2(b3^b0^b1^ck[10]) + enc[10], dec[21] = b2, b2 + b3 = b3 ^ t2(b0^b1^b2^ck[11]) + enc[11], dec[20] = b3, b3 + + b0 = b0 ^ t2(b1^b2^b3^ck[12]) + enc[12], dec[19] = b0, b0 + b1 = b1 ^ t2(b2^b3^b0^ck[13]) + enc[13], dec[18] = b1, b1 + b2 = b2 ^ t2(b3^b0^b1^ck[14]) + enc[14], dec[17] = b2, b2 + b3 = b3 ^ t2(b0^b1^b2^ck[15]) + enc[15], dec[16] = b3, b3 + + b0 = b0 ^ t2(b1^b2^b3^ck[16]) + enc[16], dec[15] = b0, b0 + b1 = b1 ^ t2(b2^b3^b0^ck[17]) + enc[17], dec[14] = b1, b1 + b2 = b2 ^ t2(b3^b0^b1^ck[18]) + enc[18], dec[13] = b2, b2 + b3 = b3 ^ t2(b0^b1^b2^ck[19]) + enc[19], dec[12] = b3, b3 + + b0 = b0 ^ t2(b1^b2^b3^ck[20]) + enc[20], dec[11] = b0, b0 + b1 = b1 ^ t2(b2^b3^b0^ck[21]) + enc[21], dec[10] = b1, b1 + b2 = b2 ^ t2(b3^b0^b1^ck[22]) + enc[22], dec[9] = b2, b2 + b3 = b3 ^ t2(b0^b1^b2^ck[23]) + enc[23], dec[8] = b3, b3 + + b0 = b0 ^ t2(b1^b2^b3^ck[24]) + enc[24], dec[7] = b0, b0 + b1 = b1 ^ t2(b2^b3^b0^ck[25]) + enc[25], dec[6] = b1, b1 + b2 = b2 ^ t2(b3^b0^b1^ck[26]) + enc[26], dec[5] = b2, b2 + b3 = b3 ^ t2(b0^b1^b2^ck[27]) + enc[27], dec[4] = b3, b3 + + b0 = b0 ^ t2(b1^b2^b3^ck[28]) + enc[28], dec[3] = b0, b0 + b1 = b1 ^ t2(b2^b3^b0^ck[29]) + enc[29], dec[2] = b1, b1 + b2 = b2 ^ t2(b3^b0^b1^ck[30]) + enc[30], dec[1] = b2, b2 + b3 = b3 ^ t2(b0^b1^b2^ck[31]) + enc[31], dec[0] = b3, b3 } // Decrypt one block from src into dst, using the expanded key xk. @@ -106,33 +155,28 @@ func decryptBlockGo(xk []uint32, dst, src []byte) { encryptBlockGo(xk, dst, src) } -// L(B) -func l(b uint32) uint32 { - return b ^ bits.RotateLeft32(b, 2) ^ bits.RotateLeft32(b, 10) ^ bits.RotateLeft32(b, 18) ^ bits.RotateLeft32(b, 24) -} - -// L'(B) -func l2(b uint32) uint32 { - return b ^ bits.RotateLeft32(b, 13) ^ bits.RotateLeft32(b, 23) -} - -func _t(in uint32, fn convert) uint32 { - var bytes [4]byte - binary.BigEndian.PutUint32(bytes[:], in) - for i := 0; i < 4; i++ { - bytes[i] = sbox[bytes[i]] - } - return fn(binary.BigEndian.Uint32(bytes[:])) -} - // T func t(in uint32) uint32 { - return _t(in, l) + var b uint32 + + b = uint32(sbox[in&0xff]) + b |= uint32(sbox[(in>>8)&0xff]) << 8 + b |= uint32(sbox[(in>>16)&0xff]) << 16 + b |= uint32(sbox[(in>>24)&0xff]) << 24 + + return b ^ bits.RotateLeft32(b, 2) ^ bits.RotateLeft32(b, 10) ^ bits.RotateLeft32(b, 18) ^ bits.RotateLeft32(b, 24) } // T' func t2(in uint32) uint32 { - return _t(in, l2) + var b uint32 + + b = uint32(sbox[in&0xff]) + b |= uint32(sbox[(in>>8)&0xff]) << 8 + b |= uint32(sbox[(in>>16)&0xff]) << 16 + b |= uint32(sbox[(in>>24)&0xff]) << 24 + + return b ^ bits.RotateLeft32(b, 13) ^ bits.RotateLeft32(b, 23) } func precompute_t(in uint32) uint32 {