diff --git a/sm4/asm_amd64.s b/sm4/asm_amd64.s index 8533b00..4525002 100644 --- a/sm4/asm_amd64.s +++ b/sm4/asm_amd64.s @@ -16,6 +16,11 @@ DATA flip_mask<>+0x00(SB)/8, $0x0405060700010203 DATA flip_mask<>+0x08(SB)/8, $0x0c0d0e0f08090a0b GLOBL flip_mask<>(SB), RODATA, $16 +// shuffle byte and word order +DATA flip_mask2<>+0x00(SB)/8, $0x08090a0b0c0d0e0f +DATA flip_mask2<>+0x08(SB)/8, $0x0001020304050607 +GLOBL flip_mask2<>(SB), RODATA, $16 + //nibble mask DATA nibble_mask<>+0x00(SB)/8, $0x0F0F0F0F0F0F0F0F DATA nibble_mask<>+0x08(SB)/8, $0x0F0F0F0F0F0F0F0F @@ -121,6 +126,116 @@ GLOBL fk_mask<>(SB), RODATA, $16 PXOR XTMP7, y; \ PXOR y, x +#define SM4_ROUND(index, x, y, t0, t1, t2, t3) \ + PINSRD $0, (index * 4)(AX)(CX*1), x; \ + PSHUFD $0, x, x; \ + PXOR t1, x; \ + PXOR t2, x; \ + PXOR t3, x; \ + SM4_TAO_L1(x, y); \ + PXOR x, t0 + +#define SM4_SINGLE_ROUND(index, x, y, t0, t1, t2, t3) \ + PINSRD $0, (index * 4)(AX)(CX*1), x; \ + PXOR t1, x; \ + PXOR t2, x; \ + PXOR t3, x; \ + SM4_TAO_L1(x, y); \ + PXOR x, t0 + +#define SM4_EXPANDKEY_ROUND(index, x, y, t0, t1, t2, t3) \ + PINSRD $0, (index * 4)(BX)(CX*1), x; \ + PXOR t1, x; \ + PXOR t2, x; \ + PXOR t3, x; \ + SM4_TAO_L2(x, y); \ + PXOR x, t0; \ + PEXTRD $0, t0, R8; \ + MOVL R8, (index * 4)(DX)(CX*1); \ + MOVL R8, (12 - index * 4)(DI)(SI*1) + +#define XDWORD0 Y4 +#define XDWORD1 Y5 +#define XDWORD2 Y6 +#define XDWORD3 Y7 + +#define XWORD0 X4 +#define XWORD1 X5 +#define XWORD2 X6 +#define XWORD3 X7 + +#define XTMP0 Y0 +#define XTMP1 Y1 +#define XTMP2 Y2 +#define NIBBLE_MASK Y3 +#define X_NIBBLE_MASK X3 + +#define BYTE_FLIP_MASK Y13 // mask to convert LE -> BE +#define XDWORD Y8 +#define XWORD X8 +#define YDWORD Y9 +#define YWORD X9 + +#define TRANSPOSE_MATRIX(r0, r1, r2, r3) \ + VPUNPCKHDQ r1, r0, XTMP2; \ // XTMP2 = [w15, w7, w14, w6, w11, w3, w10, w2] + VPUNPCKLDQ r1, r0, r0; \ // r0 = [w13, w5, w12, w4, w9, w1, w8, w0] + VPUNPCKLDQ r3, r2, XTMP1; \ // XTMP1 = [w29, w21, w28, w20, w25, w17, w24, w16] + VPUNPCKHDQ r3, r2, r2; \ // r2 = [w31, w27, w30, w22, w27, w19, w26, w18] + VPUNPCKHQDQ XTMP1, r0, r1; \ // r1 = [w29, w21, w13, w5, w25, w17, w9, w1] + VPUNPCKLQDQ XTMP1, r0, r0; \ // r0 = [w28, w20, w12, w4, w24, w16, w8, w0] + VPUNPCKHQDQ r2, XTMP2, r3; \ // r3 = [w31, w27, w15, w7, w27, w19, w11, w3] + VPUNPCKLQDQ r2, XTMP2, r2 // r2 = [w30, w22, w14, w6, w26, w18, w10, w2] + +// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html +#define AVX2_SM4_SBOX(x, y) \ + VBROADCASTI128 nibble_mask<>(SB), NIBBLE_MASK; \ + VPAND NIBBLE_MASK, x, XTMP1; \ + VBROADCASTI128 m1_low<>(SB), y; \ + VPSHUFB XTMP1, y, y; \ + VPSRLQ $4, x, x; \ + VPAND NIBBLE_MASK, x, x; \ + VBROADCASTI128 m1_high<>(SB), XTMP1; \ + VPSHUFB x, XTMP1, x; \ + VPXOR y, x, x; \ + VBROADCASTI128 inverse_shift_rows<>(SB), XTMP1;\ + VPSHUFB XTMP1, x, x; \ + VEXTRACTI128 $1, x, YWORD \ + VAESENCLAST X_NIBBLE_MASK, XWORD, XWORD; \ + VAESENCLAST X_NIBBLE_MASK, YWORD, YWORD; \ + VINSERTI128 $1, YWORD, x, x; \ + VPANDN NIBBLE_MASK, x, XTMP1; \ + VBROADCASTI128 m2_low<>(SB), y; \ + VPSHUFB XTMP1, y, y; \ + VPSRLQ $4, x, x; \ + VPAND NIBBLE_MASK, x, x; \ + VBROADCASTI128 m2_high<>(SB), XTMP1; \ + VPSHUFB x, XTMP1, x; \ + VPXOR y, x, x + +#define AVX2_SM4_TAO_L1(x, y) \ + AVX2_SM4_SBOX(x, y); \ + VBROADCASTI128 r08_mask<>(SB), XTMP0; \ + VPSHUFB XTMP0, x, y; \ + VPXOR x, y, y; \ + VBROADCASTI128 r16_mask<>(SB), XTMP0; \ + VPSHUFB XTMP0, x, XTMP0; \ + VPXOR XTMP0, y, y; \ + VPSLLD $2, y, XTMP1; \ + VPSRLD $30, y, y; \ + VPXOR XTMP1, y, y; \ + VBROADCASTI128 r24_mask<>(SB), XTMP0; \ + VPSHUFB XTMP0, x, XTMP0; \ + VPXOR y, x, x; \ + VPXOR x, XTMP0, x + +#define AVX2_SM4_ROUND(index, x, y, t0, t1, t2, t3) \ + VPBROADCASTD (index * 4)(AX)(CX*1), x; \ + VPXOR t1, x, x; \ + VPXOR t2, x, x; \ + VPXOR t3, x, x; \ + AVX2_SM4_TAO_L1(x, y); \ + VPXOR x, t0, t0 + // func expandKeyAsm(key *byte, ck, enc, dec *uint32) TEXT ·expandKeyAsm(SB),NOSPLIT,$0 MOVQ key+0(FP), AX @@ -139,45 +254,10 @@ TEXT ·expandKeyAsm(SB),NOSPLIT,$0 MOVL $112, SI loop: - PINSRD $0, 0(BX)(CX*1), x - PXOR t1, x - PXOR t2, x - PXOR t3, x - SM4_TAO_L2(x, y) - PXOR x, t0 - PEXTRD $0, t0, R8 - MOVL R8, 0(DX)(CX*1) - MOVL R8, 12(DI)(SI*1) - - PINSRD $0, 4(BX)(CX*1), x - PXOR t0, x - PXOR t2, x - PXOR t3, x - SM4_TAO_L2(x, y) - PXOR x, t1 - PEXTRD $0, t1, R8 - MOVL R8, 4(DX)(CX*1) - MOVL R8, 8(DI)(SI*1) - - PINSRD $0, 8(BX)(CX*1), x - PXOR t0, x - PXOR t1, x - PXOR t3, x - SM4_TAO_L2(x, y) - PXOR x, t2 - PEXTRD $0, t2, R8 - MOVL R8, 8(DX)(CX*1) - MOVL R8, 4(DI)(SI*1) - - PINSRD $0, 12(BX)(CX*1), x - PXOR t0, x - PXOR t1, x - PXOR t2, x - SM4_TAO_L2(x, y) - PXOR x, t3 - PEXTRD $0, t3, R8 - MOVL R8, 12(DX)(CX*1) - MOVL R8, 0(DI)(SI*1) + SM4_EXPANDKEY_ROUND(0, x, y, t0, t1, t2, t3) + SM4_EXPANDKEY_ROUND(1, x, y, t1, t2, t3, t0) + SM4_EXPANDKEY_ROUND(2, x, y, t2, t3, t0, t1) + SM4_EXPANDKEY_ROUND(3, x, y, t3, t0, t1, t2) ADDL $16, CX SUBL $16, SI @@ -187,12 +267,20 @@ loop: expand_end: RET -// func encryptBlocksAsm(xk *uint32, dst, src *byte) +// func encryptBlocksAsm(xk *uint32, dst, src []byte) TEXT ·encryptBlocksAsm(SB),NOSPLIT,$0 MOVQ xk+0(FP), AX MOVQ dst+8(FP), BX - MOVQ src+16(FP), DX + MOVQ src+32(FP), DX + MOVQ src_len+40(FP), DI + CMPL DI, $64 + JBE non_avx2_start + + CMPB ·useAVX2(SB), $1 + JE avx2 + +non_avx2_start: PINSRD $0, 0(DX), t0 PINSRD $1, 16(DX), t0 PINSRD $2, 32(DX), t0 @@ -220,37 +308,10 @@ TEXT ·encryptBlocksAsm(SB),NOSPLIT,$0 XORL CX, CX loop: - PINSRD $0, 0(AX)(CX*1), x - PSHUFD $0, x, x - PXOR t1, x - PXOR t2, x - PXOR t3, x - SM4_TAO_L1(x, y) - PXOR x, t0 - - PINSRD $0, 4(AX)(CX*1), x - PSHUFD $0, x, x - PXOR t0, x - PXOR t2, x - PXOR t3, x - SM4_TAO_L1(x, y) - PXOR x, t1 - - PINSRD $0, 8(AX)(CX*1), x - PSHUFD $0, x, x - PXOR t0, x - PXOR t1, x - PXOR t3, x - SM4_TAO_L1(x, y) - PXOR x, t2 - - PINSRD $0, 12(AX)(CX*1), x - PSHUFD $0, x, x - PXOR t0, x - PXOR t1, x - PXOR t2, x - SM4_TAO_L1(x, y) - PXOR x, t3 + SM4_ROUND(0, x, y, t0, t1, t2, t3) + SM4_ROUND(1, x, y, t1, t2, t3, t0) + SM4_ROUND(2, x, y, t2, t3, t0, t1) + SM4_ROUND(3, x, y, t3, t0, t1, t2) ADDL $16, CX CMPL CX, $4*32 @@ -290,7 +351,52 @@ loop: MOVL R8, 56(BX) done_sm4: - RET + RET + +avx2: + VMOVDQU 0(DX), XDWORD0 + VMOVDQU 32(DX), XDWORD1 + VMOVDQU 64(DX), XDWORD2 + VMOVDQU 96(DX), XDWORD3 + VBROADCASTI128 flip_mask<>(SB), BYTE_FLIP_MASK + + // Apply Byte Flip Mask: LE -> BE + VPSHUFB BYTE_FLIP_MASK, XDWORD0, XDWORD0 + VPSHUFB BYTE_FLIP_MASK, XDWORD1, XDWORD1 + VPSHUFB BYTE_FLIP_MASK, XDWORD2, XDWORD2 + VPSHUFB BYTE_FLIP_MASK, XDWORD3, XDWORD3 + + // Transpose matrix 4 x 4 32bits word + TRANSPOSE_MATRIX(XDWORD0, XDWORD1, XDWORD2, XDWORD3) + + XORL CX, CX + +avx2_loop: + AVX2_SM4_ROUND(0, XDWORD, YDWORD, XDWORD0, XDWORD1, XDWORD2, XDWORD3) + AVX2_SM4_ROUND(1, XDWORD, YDWORD, XDWORD1, XDWORD2, XDWORD3, XDWORD0) + AVX2_SM4_ROUND(2, XDWORD, YDWORD, XDWORD2, XDWORD3, XDWORD0, XDWORD1) + AVX2_SM4_ROUND(3, XDWORD, YDWORD, XDWORD3, XDWORD0, XDWORD1, XDWORD2) + + ADDL $16, CX + CMPL CX, $4*32 + JB avx2_loop + + // Transpose matrix 4 x 4 32bits word + TRANSPOSE_MATRIX(XDWORD0, XDWORD1, XDWORD2, XDWORD3) + + VBROADCASTI128 flip_mask2<>(SB), BYTE_FLIP_MASK + VPSHUFB BYTE_FLIP_MASK, XDWORD0, XDWORD0 + VPSHUFB BYTE_FLIP_MASK, XDWORD1, XDWORD1 + VPSHUFB BYTE_FLIP_MASK, XDWORD2, XDWORD2 + VPSHUFB BYTE_FLIP_MASK, XDWORD3, XDWORD3 + + VMOVDQU XDWORD0, 0(BX) + VMOVDQU XDWORD1, 32(BX) + VMOVDQU XDWORD2, 64(BX) + VMOVDQU XDWORD3, 96(BX) + + VZEROUPPER + RET // func encryptBlockAsm(xk *uint32, dst, src *byte) TEXT ·encryptBlockAsm(SB),NOSPLIT,$0 @@ -313,33 +419,10 @@ TEXT ·encryptBlockAsm(SB),NOSPLIT,$0 XORL CX, CX loop: - PINSRD $0, 0(AX)(CX*1), x - PXOR t1, x - PXOR t2, x - PXOR t3, x - SM4_TAO_L1(x, y) - PXOR x, t0 - - PINSRD $0, 4(AX)(CX*1), x - PXOR t0, x - PXOR t2, x - PXOR t3, x - SM4_TAO_L1(x, y) - PXOR x, t1 - - PINSRD $0, 8(AX)(CX*1), x - PXOR t0, x - PXOR t1, x - PXOR t3, x - SM4_TAO_L1(x, y) - PXOR x, t2 - - PINSRD $0, 12(AX)(CX*1), x - PXOR t0, x - PXOR t1, x - PXOR t2, x - SM4_TAO_L1(x, y) - PXOR x, t3 + SM4_SINGLE_ROUND(0, x, y, t0, t1, t2, t3) + SM4_SINGLE_ROUND(1, x, y, t1, t2, t3, t0) + SM4_SINGLE_ROUND(2, x, y, t2, t3, t0, t1) + SM4_SINGLE_ROUND(3, x, y, t3, t0, t1, t2) ADDL $16, CX CMPL CX, $4*32 diff --git a/sm4/asm_arm64.s b/sm4/asm_arm64.s index 7b3047d..d2c9ff7 100644 --- a/sm4/asm_arm64.s +++ b/sm4/asm_arm64.s @@ -214,11 +214,11 @@ ksLoop: RET -// func encryptBlocksAsm(xk *uint32, dst, src *byte) +// func encryptBlocksAsm(xk *uint32, dst, src []byte) TEXT ·encryptBlocksAsm(SB),NOSPLIT,$0 MOVD xk+0(FP), R8 MOVD dst+8(FP), R9 - MOVD src+16(FP), R10 + MOVD src+32(FP), R10 VLD1 (R10), [V5.S4, V6.S4, V7.S4, V8.S4] VMOV V5.S[0], t0.S[0] diff --git a/sm4/cipher_asm.go b/sm4/cipher_asm.go index d835007..1cae603 100644 --- a/sm4/cipher_asm.go +++ b/sm4/cipher_asm.go @@ -13,9 +13,10 @@ import ( var supportSM4 = cpu.ARM64.HasSM4 var supportsAES = cpu.X86.HasAES || cpu.ARM64.HasAES var supportsGFMUL = cpu.X86.HasPCLMULQDQ || cpu.ARM64.HasPMULL +var useAVX2 = cpu.X86.HasAVX2 && cpu.X86.HasBMI2 //go:noescape -func encryptBlocksAsm(xk *uint32, dst, src *byte) +func encryptBlocksAsm(xk *uint32, dst, src []byte) //go:noescape func encryptBlockAsm(xk *uint32, dst, src *byte) @@ -33,7 +34,11 @@ func newCipher(key []byte) (cipher.Block, error) { if !supportsAES { return newCipherGeneric(key) } - c := sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, 4, 4 * BlockSize} + blocks := 4 + if useAVX2 { + blocks = 8 + } + c := sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, blocks, blocks * BlockSize} expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0]) if supportsAES && supportsGFMUL { return &sm4CipherGCM{c}, nil @@ -68,7 +73,7 @@ func (c *sm4CipherAsm) EncryptBlocks(dst, src []byte) { if subtle.InexactOverlap(dst[:c.blocksSize], src[:c.blocksSize]) { panic("sm4: invalid buffer overlap") } - encryptBlocksAsm(&c.enc[0], &dst[0], &src[0]) + encryptBlocksAsm(&c.enc[0], dst, src) } func (c *sm4CipherAsm) Decrypt(dst, src []byte) { @@ -94,7 +99,7 @@ func (c *sm4CipherAsm) DecryptBlocks(dst, src []byte) { if subtle.InexactOverlap(dst[:c.blocksSize], src[:c.blocksSize]) { panic("sm4: invalid buffer overlap") } - encryptBlocksAsm(&c.dec[0], &dst[0], &src[0]) + encryptBlocksAsm(&c.dec[0], dst, src) } // expandKey is used by BenchmarkExpand to ensure that the asm implementation diff --git a/sm4_test/cbc_sm4_test.go b/sm4_test/cbc_sm4_test.go index 9b90131..36ebab3 100644 --- a/sm4_test/cbc_sm4_test.go +++ b/sm4_test/cbc_sm4_test.go @@ -106,6 +106,31 @@ var cbcSM4Tests = []struct { 0xf7, 0x90, 0x47, 0x74, 0xaf, 0x40, 0xfd, 0x72, 0xc6, 0x17, 0xeb, 0xc0, 0x8b, 0x01, 0x71, 0x5c, }, }, + { + "17 blocks", + []byte("0123456789ABCDEF"), + []byte("0123456789ABCDEF"), + []byte("Hello World Hello World Hello World Hello World Hello World Hello World Hello World Hello World Hello World Hello World Hello World Hello World Hello World Hello World Hello World Hello World Hello World Hello World Hello World Hello World Hello World Hello World"), + []byte{ + 0xd3, 0x1e, 0x36, 0x83, 0xe4, 0xfc, 0x9b, 0x51, 0x6a, 0x2c, 0x0f, 0x98, 0x36, 0x76, 0xa9, 0xeb, + 0x1f, 0xdc, 0xc3, 0x2a, 0xf3, 0x84, 0x08, 0x97, 0x81, 0x57, 0xa2, 0x06, 0x5d, 0xe3, 0x4c, 0x6a, + 0xe0, 0x02, 0xd6, 0xe4, 0xf5, 0x66, 0x87, 0xc4, 0xcc, 0x54, 0x1d, 0x1f, 0x1c, 0xc4, 0x2f, 0xe6, + 0xe5, 0x1d, 0xea, 0x52, 0xb8, 0x0c, 0xc8, 0xbe, 0xae, 0xcc, 0x44, 0xa8, 0x51, 0x81, 0x08, 0x60, + 0xb6, 0x09, 0x7b, 0xb8, 0x7e, 0xdb, 0x53, 0x4b, 0xea, 0x2a, 0xc6, 0xa1, 0xe5, 0xa0, 0x2a, 0xe9, + 0x22, 0x65, 0x5b, 0xa3, 0xb9, 0xcc, 0x63, 0x92, 0x16, 0x0e, 0x2f, 0xf4, 0x3b, 0x93, 0x06, 0x82, + 0xb3, 0x8c, 0x26, 0x2e, 0x06, 0x51, 0x34, 0x2c, 0xe4, 0x3d, 0xd0, 0xc7, 0x2b, 0x8f, 0x31, 0x15, + 0x30, 0xa8, 0x96, 0x1c, 0xbc, 0x8e, 0xf7, 0x4f, 0x6b, 0x69, 0x9d, 0xc9, 0x40, 0x89, 0xd7, 0xe8, + 0x2a, 0xe8, 0xc3, 0x3d, 0xcb, 0x8a, 0x1c, 0xb3, 0x70, 0x7d, 0xe9, 0xe6, 0x88, 0x36, 0x65, 0x21, + 0x7b, 0x34, 0xac, 0x73, 0x8d, 0x4f, 0x11, 0xde, 0xd4, 0x21, 0x45, 0x9f, 0x1f, 0x3e, 0xe8, 0xcf, + 0x50, 0x92, 0x8c, 0xa4, 0x79, 0x58, 0x3a, 0x26, 0x01, 0x7b, 0x99, 0x5c, 0xff, 0x8d, 0x66, 0x5b, + 0x07, 0x86, 0x0e, 0x22, 0xb4, 0xb4, 0x83, 0x74, 0x33, 0x79, 0xd0, 0x54, 0x9f, 0x03, 0x6b, 0x60, + 0xa1, 0x52, 0x3c, 0x61, 0x1d, 0x91, 0xbf, 0x50, 0x00, 0xfb, 0x62, 0x58, 0xfa, 0xd3, 0xbd, 0x17, + 0x7d, 0x6f, 0xda, 0x76, 0x9a, 0xdb, 0x01, 0x96, 0x97, 0xc9, 0x5f, 0x64, 0x20, 0x3c, 0x70, 0x7a, + 0x40, 0x1f, 0x35, 0xc8, 0x22, 0xf2, 0x76, 0x6d, 0x8e, 0x4a, 0x78, 0xd7, 0x8d, 0x52, 0x51, 0x60, + 0x39, 0x14, 0xd8, 0xcd, 0xc7, 0x4b, 0x3f, 0xb3, 0x16, 0xdf, 0x52, 0xba, 0xcb, 0x98, 0x56, 0xaa, + 0x97, 0x8b, 0xab, 0xa7, 0xbf, 0xe8, 0x0f, 0x16, 0x27, 0xbb, 0x56, 0xce, 0x10, 0xe5, 0x90, 0x05, + }, + }, { "A.1", []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10},