From ed0b2551ed149c115414a33ad804f65f549517a2 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Thu, 23 May 2024 13:09:16 +0800 Subject: [PATCH] kdf-sm3: mult 8 way avx2 version #222 --- sm3/kdf_amd64.go | 4 + sm3/{kdf_mult_asm.go => kdf_mult4_asm.go} | 0 sm3/kdf_mult8_amd64.go | 97 +++++ sm3/sm3blocks_avx2_amd64.s | 443 ++++++++++++++++++++++ sm3/sm3blocks_avx2_test.go | 147 +++++++ 5 files changed, 691 insertions(+) rename sm3/{kdf_mult_asm.go => kdf_mult4_asm.go} (100%) create mode 100644 sm3/kdf_mult8_amd64.go create mode 100644 sm3/sm3blocks_avx2_amd64.s create mode 100644 sm3/sm3blocks_avx2_test.go diff --git a/sm3/kdf_amd64.go b/sm3/kdf_amd64.go index a7f5a7b..ce53788 100644 --- a/sm3/kdf_amd64.go +++ b/sm3/kdf_amd64.go @@ -7,5 +7,9 @@ func kdf(baseMD *digest, keyLen int, limit int) []byte { return kdfGeneric(baseMD, keyLen, limit) } + if useAVX2 && limit >= 8 { + return kdfBy8(baseMD, keyLen, limit) + } + return kdfBy4(baseMD, keyLen, limit) } diff --git a/sm3/kdf_mult_asm.go b/sm3/kdf_mult4_asm.go similarity index 100% rename from sm3/kdf_mult_asm.go rename to sm3/kdf_mult4_asm.go diff --git a/sm3/kdf_mult8_amd64.go b/sm3/kdf_mult8_amd64.go new file mode 100644 index 0000000..8adf6c4 --- /dev/null +++ b/sm3/kdf_mult8_amd64.go @@ -0,0 +1,97 @@ +//go:build !purego + +package sm3 + +import "encoding/binary" + +// p || state || words +// p = 64 * 8 * 2 = 1024 +// state = 8 * 32 = 256 +// words = 68 * 32 = 2176 +const preallocSizeBy8 = 3456 + +const parallelSize8 = 8 + +func kdfBy8(baseMD *digest, keyLen int, limit int) []byte { + var t uint64 + blocks := 1 + len := baseMD.len + 4 + remainlen := len % 64 + if remainlen < 56 { + t = 56 - remainlen + } else { + t = 64 + 56 - remainlen + blocks = 2 + } + len <<= 3 + + var ct uint32 = 1 + k := make([]byte, keyLen) + ret := k + + // prepare temporary buffer + tmpStart := parallelSize8 * blocks * BlockSize + buffer := make([]byte, preallocSizeBy8) + tmp := buffer[tmpStart:] + // prepare processing data + var data [parallelSize8]*byte + var digs [parallelSize8]*[8]uint32 + var states [parallelSize8][8]uint32 + for j := 0; j < parallelSize8; j++ { + digs[j] = &states[j] + } + + times := limit / parallelSize8 + for i := 0; i < times; i++ { + for j := 0; j < parallelSize8; j++ { + // prepare states + states[j] = baseMD.h + // prepare data + p := buffer[blocks*BlockSize*j:] + data[j] = &p[0] + prepareData(baseMD, p, ct, len, t) + ct++ + } + blockMultBy8(&digs[0], &data[0], &tmp[0], blocks) + for j := 0; j < parallelSize8; j++ { + copyResult(ret, digs[j]) + ret = ret[Size:] + } + } + + remain := limit % parallelSize8 + if remain >= 4 { + for j := 0; j < 4; j++ { + // prepare states + states[j] = baseMD.h + // prepare data + p := buffer[blocks*BlockSize*j:] + data[j] = &p[0] + prepareData(baseMD, p, ct, len, t) + ct++ + } + blockMultBy4(&digs[0], &data[0], &tmp[0], blocks) + for j := 0; j < 4; j++ { + copyResult(ret, digs[j]) + ret = ret[Size:] + } + remain -= 4 + } + + for i := 0; i < remain; i++ { + binary.BigEndian.PutUint32(tmp[:], ct) + md := *baseMD + md.Write(tmp[:4]) + h := md.checkSum() + copy(ret[i*Size:], h[:]) + ct++ + } + + return k +} + +//go:noescape +func blockMultBy8(dig **[8]uint32, p **byte, buffer *byte, blocks int) + +//go:noescape +func transposeMatrix8x8(dig **[8]uint32) diff --git a/sm3/sm3blocks_avx2_amd64.s b/sm3/sm3blocks_avx2_amd64.s new file mode 100644 index 0000000..3a0d49d --- /dev/null +++ b/sm3/sm3blocks_avx2_amd64.s @@ -0,0 +1,443 @@ +//go:build !purego + +#include "textflag.h" + +// shuffle byte order from LE to BE 
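+// After TRANSPOSE_MATRIX each Y register holds the same word index from all
+// eight message streams, so a single VPSHUFB with this mask reverses the
+// bytes of every 32-bit lane at once, turning the little-endian loads into
+// the big-endian words SM3 operates on.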
+DATA flip_mask<>+0x00(SB)/8, $0x0405060700010203 +DATA flip_mask<>+0x08(SB)/8, $0x0c0d0e0f08090a0b +DATA flip_mask<>+0x10(SB)/8, $0x0405060700010203 +DATA flip_mask<>+0x18(SB)/8, $0x0c0d0e0f08090a0b +GLOBL flip_mask<>(SB), 8, $32 + +// left rotations of 32-bit words by 8-bit increments +DATA r08_mask<>+0x00(SB)/8, $0x0605040702010003 +DATA r08_mask<>+0x08(SB)/8, $0x0E0D0C0F0A09080B +DATA r08_mask<>+0x10(SB)/8, $0x0605040702010003 +DATA r08_mask<>+0x18(SB)/8, $0x0E0D0C0F0A09080B +GLOBL r08_mask<>(SB), 8, $32 + +#define a Y0 +#define b Y1 +#define c Y2 +#define d Y3 +#define e Y4 +#define f Y5 +#define g Y6 +#define h Y7 +#define TMP1 Y8 +#define TMP2 Y9 +#define TMP3 Y10 +#define TMP4 Y11 + +#define srcPtr1 CX +#define srcPtr2 R8 +#define srcPtr3 R9 +#define srcPtr4 R10 +#define srcPtr5 R11 +#define srcPtr6 R12 +#define srcPtr7 R13 +#define srcPtr8 R14 + +// transpose matrix function, AVX2 version +// parameters: +// - r0: 256 bits register as input/output data +// - r1: 256 bits register as input/output data +// - r2: 256 bits register as input/output data +// - r3: 256 bits register as input/output data +// - r4: 256 bits register as input/output data +// - r5: 256 bits register as input/output data +// - r6: 256 bits register as input/output data +// - r7: 256 bits register as input/output data +// - tmp1: 256 bits temp register +// - tmp2: 256 bits temp register +// - tmp3: 256 bits temp register +// - tmp4: 256 bits temp register +#define TRANSPOSE_MATRIX(r0, r1, r2, r3, r4, r5, r6, r7, tmp1, tmp2, tmp3, tmp4) \ + ; \ // [r0, r1, r2, r3] => [tmp3, tmp4, tmp2, tmp1] + VPUNPCKHDQ r1, r0, tmp4; \ // tmp4 = [w15, w7, w14, w6, w11, w3, w10, w2] + VPUNPCKLDQ r1, r0, r0; \ // r0 = [w13, w5, w12, w4, w9, w1, w8, w0] + VPUNPCKLDQ r3, r2, tmp3; \ // tmp3 = [w29, w21, w28, w20, w25, w17, w24, w16] + VPUNPCKHDQ r3, r2, r2; \ // r2 = [w31, w27, w30, w22, w27, w19, w26, w18] + VPUNPCKHQDQ tmp3, r0, tmp2; \ // tmp2 = [w29, w21, w13, w5, w25, w17, w9, w1] + VPUNPCKLQDQ tmp3, r0, tmp1; \ // tmp1 = [w28, w20, w12, w4, w24, w16, w8, w0] + VPUNPCKHQDQ r2, tmp4, tmp3; \ // tmp3 = [w31, w23, w15, w7, w27, w19, w11, w3] + VPUNPCKLQDQ r2, tmp4, tmp4; \ // tmp4 = [w30, w22, w14, w6, w26, w18, w10, w2] + ; \ // [r4, r5, r6, r7] => [r4, r5, r6, r7] + VPUNPCKHDQ r5, r4, r1; \ // r1 = [w47, w39, w46, w38, w43, w35, w42, w34] + VPUNPCKLDQ r5, r4, r4; \ // r4 = [w45, w37, w44, w36, w41, w33, w40, w32] + VPUNPCKLDQ r7, r6, r0; \ // r0 = [w61, w53, w60, w52, w57, w49, w56, w48] + VPUNPCKHDQ r7, r6, r6; \ // r6 = [w63, w59, w52, w54, w59, w51, w58, w50] + VPUNPCKHQDQ r0, r4, r5; \ // r5 = [w61, w53, w45, w37, w57, w49, w41, w33] + VPUNPCKLQDQ r0, r4, r4; \ // r4 = [w60, w52, w44, w36, w56, w48, w40, w32] + VPUNPCKHQDQ r6, r1, r7; \ // r7 = [w63, w55, w47, w39, w59, w51, w43, w35] + VPUNPCKLQDQ r6, r1, r6; \ // r6 = [w62, w54, w46, w38, w58, w50, w42, w34] + ; \ // [tmp3, tmp4, tmp2, tmp1], [r4, r5, r6, r7] => [r0, r1, r2, r3, r4, r5, r6, r7] + VPERM2I128 $0x20, r4, tmp1, r0; \ // r0 = [w56, w48, w40, w32, w24, w16, w8, w0] + VPERM2I128 $0x20, r5, tmp2, r1; \ // r1 = [w57, w49, w41, w33, w25, w17, w9, w1] + VPERM2I128 $0x20, r6, tmp4, r2; \ // r2 = [w58, w50, w42, w34, w26, w18, w10, w2] + VPERM2I128 $0x20, r7, tmp3, r3; \ // r3 = [w59, w51, w43, w35, w27, w19, w11, w3] + VPERM2I128 $0x31, r4, tmp1, r4; \ // r4 = [w60, w52, w44, w36, w28, w20, w12, w4] + VPERM2I128 $0x31, r5, tmp2, r5; \ // r5 = [w61, w53, w45, w37, w29, w21, w13, w5] + VPERM2I128 $0x31, r6, tmp4, r6; \ // r6 = [w62, w54, w46, w38, w30, w22, w14, w6] + 
VPERM2I128 $0x31, r7, tmp3, r7; \ // r7 = [w63, w55, w47, w39, w31, w23, w15, w7] + +// xorm (mem), reg +// xor reg to mem using reg-mem xor and store +#define xorm(P1, P2) \ + VPXOR P1, P2, P2; \ + VMOVDQU P2, P1 + +// store 256 bits +#define storeWord(W, j) VMOVDQU W, (256+(j)*32)(BX) +// load 256 bits +#define loadWord(W, i) VMOVDQU (256+(i)*32)(BX), W + +#define prepare8Words(i) \ + VMOVDQU (i*32)(srcPtr1), a; \ + VMOVDQU (i*32)(srcPtr2), b; \ + VMOVDQU (i*32)(srcPtr3), c; \ + VMOVDQU (i*32)(srcPtr4), d; \ + VMOVDQU (i*32)(srcPtr5), e; \ + VMOVDQU (i*32)(srcPtr6), f; \ + VMOVDQU (i*32)(srcPtr7), g; \ + VMOVDQU (i*32)(srcPtr8), h; \ + ; \ + TRANSPOSE_MATRIX(a, b, c, d, e, f, g, h, TMP1, TMP2, TMP3, TMP4); \ + VPSHUFB flip_mask<>(SB), a, a; \ + VPSHUFB flip_mask<>(SB), b, b; \ + VPSHUFB flip_mask<>(SB), c, c; \ + VPSHUFB flip_mask<>(SB), d, d; \ + VPSHUFB flip_mask<>(SB), e, e; \ + VPSHUFB flip_mask<>(SB), f, f; \ + VPSHUFB flip_mask<>(SB), g, g; \ + VPSHUFB flip_mask<>(SB), h, h; \ + ; \ + storeWord(a, 8*i+0); \ + storeWord(b, 8*i+1); \ + storeWord(c, 8*i+2); \ + storeWord(d, 8*i+3); \ + storeWord(e, 8*i+4); \ + storeWord(f, 8*i+5); \ + storeWord(g, 8*i+6); \ + storeWord(h, 8*i+7) + +#define saveState \ + VMOVDQU a, (0*32)(BX); \ + VMOVDQU b, (1*32)(BX); \ + VMOVDQU c, (2*32)(BX); \ + VMOVDQU d, (3*32)(BX); \ + VMOVDQU e, (4*32)(BX); \ + VMOVDQU f, (5*32)(BX); \ + VMOVDQU g, (6*32)(BX); \ + VMOVDQU h, (7*32)(BX) + +#define loadState \ + VMOVDQU (0*32)(BX), a; \ + VMOVDQU (1*32)(BX), b; \ + VMOVDQU (2*32)(BX), c; \ + VMOVDQU (3*32)(BX), d; \ + VMOVDQU (4*32)(BX), e; \ + VMOVDQU (5*32)(BX), f; \ + VMOVDQU (6*32)(BX), g; \ + VMOVDQU (7*32)(BX), h + +// r <<< n +#define VPROLD(r, n) \ + VPSLLD $(n), r, TMP1; \ + VPSRLD $(32-n), r, r; \ + VPOR TMP1, r, r + +// d = r <<< n +#define VPROLD2(r, d, n) \ + VPSLLD $(n), r, TMP1; \ + VPSRLD $(32-n), r, d; \ + VPOR TMP1, d, d + +#define LOAD_T(index, T) \ + VPBROADCASTD (index*4)(AX), T + +#define ROUND_00_11(index, a, b, c, d, e, f, g, h) \ + VPROLD2(a, Y13, 12); \ // a <<< 12 + LOAD_T(index, Y12); \ + VPADDD Y12, Y13, Y12; \ + VPADDD e, Y12, Y12; \ + VPROLD(Y12, 7); \ // SS1 + VPXOR Y12, Y13, Y13; \ // SS2 + ; \ + VPXOR a, b, Y14; \ + VPXOR c, Y14, Y14; \ // (a XOR b XOR c) + VPADDD d, Y14, Y14; \ // (a XOR b XOR c) + d + loadWord(Y10, index); \ + loadWord(Y11, index+4); \ + VPXOR Y10, Y11, Y11; \ //Wt XOR Wt+4 + VPADDD Y11, Y14, Y14; \ // (a XOR b XOR c) + d + Wt XOR Wt+4 + VPADDD Y14, Y13, Y13; \ // TT1 + VPADDD h, Y10, Y10; \ // Wt + h + VPADDD Y12, Y10, Y10; \ // Wt + h + SS1 + VPXOR e, f, Y11; \ + VPXOR g, Y11, Y11; \ // (e XOR f XOR g) + VPADDD Y11, Y10, Y10; \ // TT2 = (e XOR f XOR g) + Wt + h + SS1 + ; \ // copy result + VPROLD(b, 9); \ + VMOVDQU Y13, h; \ + VPROLD(f, 19); \ + VPROLD2(Y10, Y13, 9); \ // tt2 <<< 9 + VPSHUFB r08_mask<>(SB), Y13, Y11; \ // ROTL(17, tt2) + VPXOR Y10, Y13, Y13; \ // tt2 XOR ROTL(9, tt2) + VPXOR Y11, Y13, d + +#define MESSAGE_SCHEDULE(index) \ + loadWord(Y10, index+1); \ // Wj-3 + VPROLD(Y10, 15); \ + VPXOR (256+(index-12)*32)(BX), Y10, Y10; \ // Wj-16 + VPXOR (256+(index-5)*32)(BX), Y10, Y10; \ // Wj-9 + ; \ // P1 + VPROLD2(Y10, Y11, 15); \ + VPXOR Y11, Y10, Y10; \ + VPSHUFB r08_mask<>(SB), Y11, Y11; \ + VPXOR Y11, Y10, Y10; \ // P1 + loadWord(Y11, index-9); \ // Wj-13 + VPROLD(Y11, 7); \ + VPXOR Y11, Y10, Y10; \ + VPXOR (256+(index-2)*32)(BX), Y10, Y11; \ + storeWord(Y11, index+4) + +#define ROUND_12_15(index, a, b, c, d, e, f, g, h) \ + MESSAGE_SCHEDULE(index); \ + ROUND_00_11(index, a, b, c, d, e, f, g, h) + 
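+// ROUND_16_63 expands the message schedule like ROUND_12_15 and then runs the
+// same dataflow as ROUND_00_11, but with the boolean functions defined for
+// rounds 16..63:
+//   FF(a,b,c) = (a AND b) OR (a AND c) OR (b AND c)             (majority)
+//   GG(e,f,g) = (e AND f) OR ((NOT e) AND g) = ((f XOR g) AND e) XOR g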
+#define ROUND_16_63(index, a, b, c, d, e, f, g, h) \ + MESSAGE_SCHEDULE(index); \ // Y11 is Wt+4 now, Pls do not use it + VPROLD2(a, Y13, 12); \ // a <<< 12 + LOAD_T(index, Y12); \ + VPADDD Y12, Y13, Y12; \ + VPADDD e, Y12, Y12; \ + VPROLD(Y12, 7); \ // SS1 + VPXOR Y12, Y13, Y13; \ // SS2 + ; \ + VPOR a, b, Y14; \ + VPAND a, b, Y10; \ + VPAND c, Y14, Y14; \ + VPOR Y10, Y14, Y14; \ // (a AND b) OR (a AND c) OR (b AND c) + VPADDD d, Y14, Y14; \ // (a AND b) OR (a AND c) OR (b AND c) + d + loadWord(Y10, index); \ + VPXOR Y10, Y11, Y11; \ //Wt XOR Wt+4 + VPADDD Y11, Y14, Y14; \ // (a AND b) OR (a AND c) OR (b AND c) + d + Wt XOR Wt+4 + VPADDD Y14, Y13, Y13; \ // TT1 + ; \ + VPADDD h, Y10, Y10; \ // Wt + h + VPADDD Y12, Y10, Y10; \ // Wt + h + SS1 + VPXOR f, g, Y11; \ + VPAND e, Y11, Y11; \ + VPXOR g, Y11, Y11; \ // (f XOR g) AND e XOR g + VPADDD Y11, Y10, Y10; \ // TT2 = (e XOR f XOR g) + Wt + h + SS1 + ; \ // copy result + VPROLD(b, 9); \ + VMOVDQU Y13, h; \ + VPROLD(f, 19); \ + VPROLD2(Y10, Y13, 9); \ // tt2 <<< 9 + VPSHUFB r08_mask<>(SB), Y13, Y11; \ // ROTL(17, tt2) + VPXOR Y10, Y13, Y13; \ // tt2 XOR ROTL(9, tt2) + VPXOR Y11, Y13, d + +// transposeMatrix8x8(dig **[8]uint32) +TEXT ·transposeMatrix8x8(SB),NOSPLIT,$0 + MOVQ dig+0(FP), DI + + // load state + MOVQ (DI), R8 + VMOVDQU (R8), a + MOVQ 8(DI), R8 + VMOVDQU (R8), b + MOVQ 16(DI), R8 + VMOVDQU (R8), c + MOVQ 24(DI), R8 + VMOVDQU (R8), d + MOVQ 32(DI), R8 + VMOVDQU (R8), e + MOVQ 40(DI), R8 + VMOVDQU (R8), f + MOVQ 48(DI), R8 + VMOVDQU (R8), g + MOVQ 56(DI), R8 + VMOVDQU (R8), h + + TRANSPOSE_MATRIX(a, b, c, d, e, f, g, h, TMP1, TMP2, TMP3, TMP4) + + // save state + MOVQ (DI), R8 + VMOVDQU a, (R8) + MOVQ 8(DI), R8 + VMOVDQU b, (R8) + MOVQ 16(DI), R8 + VMOVDQU c, (R8) + MOVQ 24(DI), R8 + VMOVDQU d, (R8) + MOVQ 32(DI), R8 + VMOVDQU e, (R8) + MOVQ 40(DI), R8 + VMOVDQU f, (R8) + MOVQ 48(DI), R8 + VMOVDQU g, (R8) + MOVQ 56(DI), R8 + VMOVDQU h, (R8) + + VZEROUPPER + + RET + +// blockMultBy8(dig **[8]uint32, p *[]byte, buffer *byte, blocks int) +TEXT ·blockMultBy8(SB),NOSPLIT,$0 + MOVQ dig+0(FP), DI + MOVQ p+8(FP), SI + MOVQ buffer+16(FP), BX + MOVQ blocks+24(FP), DX + + // load state + MOVQ (DI), R8 + VMOVDQU (R8), a + MOVQ 8(DI), R8 + VMOVDQU (R8), b + MOVQ 16(DI), R8 + VMOVDQU (R8), c + MOVQ 24(DI), R8 + VMOVDQU (R8), d + MOVQ 32(DI), R8 + VMOVDQU (R8), e + MOVQ 40(DI), R8 + VMOVDQU (R8), f + MOVQ 48(DI), R8 + VMOVDQU (R8), g + MOVQ 56(DI), R8 + VMOVDQU (R8), h + + TRANSPOSE_MATRIX(a, b, c, d, e, f, g, h, TMP1, TMP2, TMP3, TMP4) + + saveState + + MOVQ $·_K+0(SB), AX + MOVQ (0*8)(SI), srcPtr1 + MOVQ (1*8)(SI), srcPtr2 + MOVQ (2*8)(SI), srcPtr3 + MOVQ (3*8)(SI), srcPtr4 + MOVQ (4*8)(SI), srcPtr5 + MOVQ (5*8)(SI), srcPtr6 + MOVQ (6*8)(SI), srcPtr7 + MOVQ (7*8)(SI), srcPtr8 + +loop: + prepare8Words(0) + prepare8Words(1) + + loadState + + ROUND_00_11(0, a, b, c, d, e, f, g, h) + ROUND_00_11(1, h, a, b, c, d, e, f, g) + ROUND_00_11(2, g, h, a, b, c, d, e, f) + ROUND_00_11(3, f, g, h, a, b, c, d, e) + ROUND_00_11(4, e, f, g, h, a, b, c, d) + ROUND_00_11(5, d, e, f, g, h, a, b, c) + ROUND_00_11(6, c, d, e, f, g, h, a, b) + ROUND_00_11(7, b, c, d, e, f, g, h, a) + ROUND_00_11(8, a, b, c, d, e, f, g, h) + ROUND_00_11(9, h, a, b, c, d, e, f, g) + ROUND_00_11(10, g, h, a, b, c, d, e, f) + ROUND_00_11(11, f, g, h, a, b, c, d, e) + + ROUND_12_15(12, e, f, g, h, a, b, c, d) + ROUND_12_15(13, d, e, f, g, h, a, b, c) + ROUND_12_15(14, c, d, e, f, g, h, a, b) + ROUND_12_15(15, b, c, d, e, f, g, h, a) + + ROUND_16_63(16, a, b, c, d, e, f, g, h) + 
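+	// rounds 17..63 keep rotating the roles of the eight working registers
+	// by one position per round, exactly as in rounds 0..16 above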
ROUND_16_63(17, h, a, b, c, d, e, f, g) + ROUND_16_63(18, g, h, a, b, c, d, e, f) + ROUND_16_63(19, f, g, h, a, b, c, d, e) + ROUND_16_63(20, e, f, g, h, a, b, c, d) + ROUND_16_63(21, d, e, f, g, h, a, b, c) + ROUND_16_63(22, c, d, e, f, g, h, a, b) + ROUND_16_63(23, b, c, d, e, f, g, h, a) + ROUND_16_63(24, a, b, c, d, e, f, g, h) + ROUND_16_63(25, h, a, b, c, d, e, f, g) + ROUND_16_63(26, g, h, a, b, c, d, e, f) + ROUND_16_63(27, f, g, h, a, b, c, d, e) + ROUND_16_63(28, e, f, g, h, a, b, c, d) + ROUND_16_63(29, d, e, f, g, h, a, b, c) + ROUND_16_63(30, c, d, e, f, g, h, a, b) + ROUND_16_63(31, b, c, d, e, f, g, h, a) + ROUND_16_63(32, a, b, c, d, e, f, g, h) + ROUND_16_63(33, h, a, b, c, d, e, f, g) + ROUND_16_63(34, g, h, a, b, c, d, e, f) + ROUND_16_63(35, f, g, h, a, b, c, d, e) + ROUND_16_63(36, e, f, g, h, a, b, c, d) + ROUND_16_63(37, d, e, f, g, h, a, b, c) + ROUND_16_63(38, c, d, e, f, g, h, a, b) + ROUND_16_63(39, b, c, d, e, f, g, h, a) + ROUND_16_63(40, a, b, c, d, e, f, g, h) + ROUND_16_63(41, h, a, b, c, d, e, f, g) + ROUND_16_63(42, g, h, a, b, c, d, e, f) + ROUND_16_63(43, f, g, h, a, b, c, d, e) + ROUND_16_63(44, e, f, g, h, a, b, c, d) + ROUND_16_63(45, d, e, f, g, h, a, b, c) + ROUND_16_63(46, c, d, e, f, g, h, a, b) + ROUND_16_63(47, b, c, d, e, f, g, h, a) + ROUND_16_63(48, a, b, c, d, e, f, g, h) + ROUND_16_63(49, h, a, b, c, d, e, f, g) + ROUND_16_63(50, g, h, a, b, c, d, e, f) + ROUND_16_63(51, f, g, h, a, b, c, d, e) + ROUND_16_63(52, e, f, g, h, a, b, c, d) + ROUND_16_63(53, d, e, f, g, h, a, b, c) + ROUND_16_63(54, c, d, e, f, g, h, a, b) + ROUND_16_63(55, b, c, d, e, f, g, h, a) + ROUND_16_63(56, a, b, c, d, e, f, g, h) + ROUND_16_63(57, h, a, b, c, d, e, f, g) + ROUND_16_63(58, g, h, a, b, c, d, e, f) + ROUND_16_63(59, f, g, h, a, b, c, d, e) + ROUND_16_63(60, e, f, g, h, a, b, c, d) + ROUND_16_63(61, d, e, f, g, h, a, b, c) + ROUND_16_63(62, c, d, e, f, g, h, a, b) + ROUND_16_63(63, b, c, d, e, f, g, h, a) + + xorm( 0(BX), a) + xorm( 32(BX), b) + xorm( 64(BX), c) + xorm( 96(BX), d) + xorm( 128(BX), e) + xorm( 160(BX), f) + xorm( 192(BX), g) + xorm(224(BX), h) + + LEAQ 64(srcPtr1), srcPtr1 + LEAQ 64(srcPtr2), srcPtr2 + LEAQ 64(srcPtr3), srcPtr3 + LEAQ 64(srcPtr4), srcPtr4 + LEAQ 64(srcPtr5), srcPtr5 + LEAQ 64(srcPtr6), srcPtr6 + LEAQ 64(srcPtr7), srcPtr7 + LEAQ 64(srcPtr8), srcPtr8 + + DECQ DX + JNZ loop + + TRANSPOSE_MATRIX(a, b, c, d, e, f, g, h, TMP1, TMP2, TMP3, TMP4) + + // save state + MOVQ (DI), R8 + VMOVDQU a, (R8) + MOVQ 8(DI), R8 + VMOVDQU b, (R8) + MOVQ 16(DI), R8 + VMOVDQU c, (R8) + MOVQ 24(DI), R8 + VMOVDQU d, (R8) + MOVQ 32(DI), R8 + VMOVDQU e, (R8) + MOVQ 40(DI), R8 + VMOVDQU f, (R8) + MOVQ 48(DI), R8 + VMOVDQU g, (R8) + MOVQ 56(DI), R8 + VMOVDQU h, (R8) + + VZEROUPPER + RET diff --git a/sm3/sm3blocks_avx2_test.go b/sm3/sm3blocks_avx2_test.go new file mode 100644 index 0000000..4e088ef --- /dev/null +++ b/sm3/sm3blocks_avx2_test.go @@ -0,0 +1,147 @@ +//go:build amd64 && !purego + +package sm3 + +import ( + "fmt" + "testing" +) + +func initState8() [8]*[8]uint32 { + d := new(digest) + d.Reset() + var dig1 = d.h + var dig2 = d.h + var dig3 = d.h + var dig4 = d.h + var dig5 = d.h + var dig6 = d.h + var dig7 = d.h + return [8]*[8]uint32{&d.h, &dig1, &dig2, &dig3, &dig4, &dig5, &dig6, &dig7} +} + +func createOneBlockBy8() [8]*byte { + var p1 [64]byte + p1[0] = 0x61 + p1[1] = 0x62 + p1[2] = 0x63 + p1[3] = 0x80 + p1[63] = 0x18 + var p2 = p1 + var p3 = p1 + var p4 = p1 + var p5 = p1 + var p6 = p1 + var p7 = p1 + var p8 = p1 + return [8]*byte{&p1[0], 
&p2[0], &p3[0], &p4[0], &p5[0], &p6[0], &p7[0], &p8[0]} +} + +func createTwoBlocksBy8() [8]*byte { + var p1 [128]byte + p1[0] = 0x61 + p1[1] = 0x62 + p1[2] = 0x63 + p1[3] = 0x64 + copy(p1[4:], p1[:4]) + copy(p1[8:], p1[:8]) + copy(p1[16:], p1[:16]) + copy(p1[32:], p1[:32]) + p1[64] = 0x80 + p1[126] = 0x02 + var p2 = p1 + var p3 = p1 + var p4 = p1 + var p5 = p1 + var p6 = p1 + var p7 = p1 + var p8 = p1 + return [8]*byte{&p1[0], &p2[0], &p3[0], &p4[0], &p5[0], &p6[0], &p7[0], &p8[0]} +} + +func TestTransposeMatrix8x8(t *testing.T) { + if !useAVX2 { + t.Skip("AVX2 is not supported") + } + var m [8][8]uint32 + for i := 0; i < 8; i++ { + for j := 0; j < 8; j++ { + m[i][j] = uint32(i*8 + j) + } + } + input := [8]*[8]uint32{&m[0], &m[1], &m[2], &m[3], &m[4], &m[5], &m[6], &m[7]} + transposeMatrix8x8(&input[0]) + for i := 0; i < 8; i++ { + for j := 0; j < 8; j++ { + if m[j][i] != uint32(i*8+j) { + t.Errorf("m[%d][%d] got %d", i, j, m[j][i]) + } + } + } + transposeMatrix8x8(&input[0]) + for i := 0; i < 8; i++ { + for j := 0; j < 8; j++ { + if m[i][j] != uint32(i*8+j) { + t.Errorf("m[%d][%d] got %d", i, j, m[i][j]) + } + } + } +} + +func TestBlockMultBy8(t *testing.T) { + if !useAVX2 { + t.Skip("AVX2 is not supported") + } + digs := initState8() + p := createOneBlockBy8() + buffer := make([]byte, preallocSizeBy8) + blockMultBy8(&digs[0], &p[0], &buffer[0], 1) + expected := "[66c7f0f4 62eeedd9 d1f2d46b dc10e4e2 4167c487 5cf2f7a2 297da02b 8f4ba8e0]" + for i := 0; i < 8; i++ { + s := fmt.Sprintf("%x", digs[i][:]) + if s != expected { + t.Errorf("digs[%d] got %s", i, s) + } + } + + digs = initState8() + p = createTwoBlocksBy8() + blockMultBy8(&digs[0], &p[0], &buffer[0], 2) + expected = "[debe9ff9 2275b8a1 38604889 c18e5a4d 6fdb70e5 387e5765 293dcba3 9c0c5732]" + for i := 0; i < 8; i++ { + s := fmt.Sprintf("%x", digs[i][:]) + if s != expected { + t.Errorf("digs[%d] got %s", i, s) + } + } +} + +func BenchmarkOneBlockBy8(b *testing.B) { + if !useAVX2 { + b.Skip("AVX2 is not supported") + } + digs := initState8() + p := createOneBlockBy8() + buffer := make([]byte, preallocSizeBy8) + b.SetBytes(64 * 8) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + blockMultBy8(&digs[0], &p[0], &buffer[0], 1) + } +} + +func BenchmarkTwoBlocksBy8(b *testing.B) { + if !useAVX2 { + b.Skip("AVX2 is not supported") + } + digs := initState8() + p := createTwoBlocksBy8() + buffer := make([]byte, preallocSizeBy8) + b.SetBytes(64 * 2 * 8) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + blockMultBy8(&digs[0], &p[0], &buffer[0], 2) + } +}
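Note on the construction being accelerated: kdfBy8 (like the existing kdfBy4) implements the SM2 key derivation function of GB/T 32918.4, in which the i-th 32-byte chunk of output is SM3(Z || ct) for a big-endian 32-bit counter ct starting at 1, truncated to keyLen bytes. A minimal sequential sketch of the same construction, written against the package's public API (sm3.New, sm3.Size); the import path and the helper name kdfReference are illustrative, not part of this patch:

    package main

    import (
        "encoding/binary"
        "fmt"

        "github.com/emmansun/gmsm/sm3"
    )

    // kdfReference derives keyLen bytes the sequential way: one full SM3
    // computation per 32-byte chunk, hashing z followed by a big-endian
    // 32-bit counter that starts at 1. kdfBy8 yields the same bytes but
    // hashes eight counter values per pass of the AVX2 kernel.
    func kdfReference(z []byte, keyLen int) []byte {
        out := make([]byte, 0, keyLen+sm3.Size)
        var ct [4]byte
        for counter := uint32(1); len(out) < keyLen; counter++ {
            binary.BigEndian.PutUint32(ct[:], counter)
            h := sm3.New()
            h.Write(z)
            h.Write(ct[:])
            out = h.Sum(out) // append this chunk's digest
        }
        return out[:keyLen]
    }

    func main() {
        fmt.Printf("%x\n", kdfReference([]byte("shared secret Z"), 48))
    }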