diff --git a/sm3/kdf_mult4_asm.go b/sm3/kdf_mult4_asm.go
index 81d2dee..eeb037c 100644
--- a/sm3/kdf_mult4_asm.go
+++ b/sm3/kdf_mult4_asm.go
@@ -4,11 +4,11 @@ package sm3
 
 import "encoding/binary"
 
-func prepareData(baseMD *digest, p []byte, ct uint32, len, t uint64) {
+func prepareInitData(baseMD *digest, p []byte, len, t uint64) {
 	if baseMD.nx > 0 {
 		copy(p, baseMD.x[:baseMD.nx])
 	}
-	binary.BigEndian.PutUint32(p[baseMD.nx:], ct)
+	// binary.BigEndian.PutUint32(p[baseMD.nx:], ct)
 	// Padding. Add a 1 bit and 0 bits until 56 bytes mod 64.
 	var tmp [64 + 8]byte // padding + length buffer
 	tmp[0] = 0x80
@@ -17,17 +17,6 @@ func prepareData(baseMD *digest, p []byte, ct uint32, len, t uint64) {
 	copy(p[baseMD.nx+4:], padlen)
 }
 
-func copyResult(result []byte, dig *[8]uint32) {
-	binary.BigEndian.PutUint32(result[0:], dig[0])
-	binary.BigEndian.PutUint32(result[4:], dig[1])
-	binary.BigEndian.PutUint32(result[8:], dig[2])
-	binary.BigEndian.PutUint32(result[12:], dig[3])
-	binary.BigEndian.PutUint32(result[16:], dig[4])
-	binary.BigEndian.PutUint32(result[20:], dig[5])
-	binary.BigEndian.PutUint32(result[24:], dig[6])
-	binary.BigEndian.PutUint32(result[28:], dig[7])
-}
-
 // p || state || words
 // p = 64 * 4 * 2 = 512
 // state = 8 * 16 = 128
@@ -57,11 +46,21 @@ func kdfBy4(baseMD *digest, keyLen int, limit int) []byte {
 	buffer := make([]byte, preallocSizeBy4)
 	tmp := buffer[tmpStart:]
 	// prepare processing data
-	var data [parallelSize4]*byte
+	var dataPtrs [parallelSize4]*byte
+	var data [parallelSize4][]byte
 	var digs [parallelSize4]*[8]uint32
 	var states [parallelSize4][8]uint32
-	for j := 0; j < 4; j++ {
+
+	for j := 0; j < parallelSize4; j++ {
 		digs[j] = &states[j]
+		p := buffer[blocks*BlockSize*j:]
+		data[j] = p
+		dataPtrs[j] = &p[0]
+		if j == 0 {
+			prepareInitData(baseMD, p, len, t)
+		} else {
+			copy(p, data[0])
+		}
 	}
 
 	var ct uint32 = 1
@@ -73,16 +72,12 @@ func kdfBy4(baseMD *digest, keyLen int, limit int) []byte {
 			// prepare states
 			states[j] = baseMD.h
 			// prepare data
-			p := buffer[blocks*BlockSize*j:]
-			data[j] = &p[0]
-			prepareData(baseMD, p, ct, len, t)
+			binary.BigEndian.PutUint32(data[j][baseMD.nx:], ct)
 			ct++
 		}
-		blockMultBy4(&digs[0], &data[0], &tmp[0], blocks)
-		for j := 0; j < parallelSize4; j++ {
-			copyResult(ret, digs[j])
-			ret = ret[Size:]
-		}
+		blockMultBy4(&digs[0], &dataPtrs[0], &tmp[0], blocks)
+		copyResultsBy4(&states[0][0], &ret[0])
+		ret = ret[Size*parallelSize4:]
 	}
 	remain := limit % parallelSize4
 	for i := 0; i < remain; i++ {
@@ -99,3 +94,6 @@ func kdfBy4(baseMD *digest, keyLen int, limit int) []byte {
 
 //go:noescape
 func blockMultBy4(dig **[8]uint32, p **byte, buffer *byte, blocks int)
+
+//go:noescape
+func copyResultsBy4(dig *uint32, p *byte)
diff --git a/sm3/kdf_mult8_amd64.go b/sm3/kdf_mult8_amd64.go
index 8adf6c4..067a514 100644
--- a/sm3/kdf_mult8_amd64.go
+++ b/sm3/kdf_mult8_amd64.go
@@ -34,48 +34,50 @@ func kdfBy8(baseMD *digest, keyLen int, limit int) []byte {
 
 	buffer := make([]byte, preallocSizeBy8)
 	tmp := buffer[tmpStart:]
 	// prepare processing data
-	var data [parallelSize8]*byte
+	var dataPtrs [parallelSize8]*byte
+	var data [parallelSize8][]byte
 	var digs [parallelSize8]*[8]uint32
 	var states [parallelSize8][8]uint32
+
 	for j := 0; j < parallelSize8; j++ {
 		digs[j] = &states[j]
+		p := buffer[blocks*BlockSize*j:]
+		data[j] = p
+		dataPtrs[j] = &p[0]
+		if j == 0 {
+			prepareInitData(baseMD, p, len, t)
+		} else {
+			copy(p, data[0])
+		}
 	}
-
+
 	times := limit / parallelSize8
 	for i := 0; i < times; i++ {
 		for j := 0; j < parallelSize8; j++ {
 			// prepare states
 			states[j] = baseMD.h
 			// prepare data
-			p := buffer[blocks*BlockSize*j:]
-			data[j] = &p[0]
-			prepareData(baseMD, p, ct, len, t)
+			binary.BigEndian.PutUint32(data[j][baseMD.nx:], ct)
 			ct++
 		}
-		blockMultBy8(&digs[0], &data[0], &tmp[0], blocks)
-		for j := 0; j < parallelSize8; j++ {
-			copyResult(ret, digs[j])
-			ret = ret[Size:]
-		}
+		blockMultBy8(&digs[0], &dataPtrs[0], &tmp[0], blocks)
+		copyResultsBy8(&states[0][0], &ret[0])
+		ret = ret[Size*parallelSize8:]
 	}
 	remain := limit % parallelSize8
-	if remain >= 4 {
-		for j := 0; j < 4; j++ {
+	if remain >= parallelSize4 {
+		for j := 0; j < parallelSize4; j++ {
 			// prepare states
 			states[j] = baseMD.h
 			// prepare data
-			p := buffer[blocks*BlockSize*j:]
-			data[j] = &p[0]
-			prepareData(baseMD, p, ct, len, t)
+			binary.BigEndian.PutUint32(data[j][baseMD.nx:], ct)
 			ct++
 		}
-		blockMultBy4(&digs[0], &data[0], &tmp[0], blocks)
-		for j := 0; j < 4; j++ {
-			copyResult(ret, digs[j])
-			ret = ret[Size:]
-		}
-		remain -= 4
+		blockMultBy4(&digs[0], &dataPtrs[0], &tmp[0], blocks)
+		copyResultsBy4(&states[0][0], &ret[0])
+		ret = ret[Size*parallelSize4:]
+		remain -= parallelSize4
 	}
 
 	for i := 0; i < remain; i++ {
@@ -95,3 +97,6 @@ func blockMultBy8(dig **[8]uint32, p **byte, buffer *byte, blocks int)
 
 //go:noescape
 func transposeMatrix8x8(dig **[8]uint32)
+
+//go:noescape
+func copyResultsBy8(dig *uint32, p *byte)
diff --git a/sm3/sm3_test.go b/sm3/sm3_test.go
index ea49666..7b244c9 100644
--- a/sm3/sm3_test.go
+++ b/sm3/sm3_test.go
@@ -469,6 +469,7 @@ func BenchmarkKdfWithSM3(b *testing.B) {
 	z := make([]byte, 512)
 	for _, tt := range tests {
 		b.Run(fmt.Sprintf("zLen=%v-kLen=%v", tt.zLen, tt.kLen), func(b *testing.B) {
+			b.SetBytes(int64(tt.kLen))
 			b.ReportAllocs()
 			b.ResetTimer()
 			for i := 0; i < b.N; i++ {
diff --git a/sm3/sm3blocks_arm64.s b/sm3/sm3blocks_arm64.s
index 2b7cab8..26c050b 100644
--- a/sm3/sm3blocks_arm64.s
+++ b/sm3/sm3blocks_arm64.s
@@ -309,3 +309,28 @@ loop:
 	VST1 [h.S4], (R20)
 
 	RET
+
+// func copyResultsBy4(dig *uint32, dst *byte)
+TEXT ·copyResultsBy4(SB),NOSPLIT,$0
+#define digPtr R0
+#define dstPtr R1
+	MOVD dig+0(FP), digPtr
+	MOVD dst+8(FP), dstPtr
+
+	// load state
+	VLD1.P 64(digPtr), [a.S4, b.S4, c.S4, d.S4]
+	VLD1 (digPtr), [e.S4, f.S4, g.S4, h.S4]
+
+	VREV32 a.B16, a.B16
+	VREV32 b.B16, b.B16
+	VREV32 c.B16, c.B16
+	VREV32 d.B16, d.B16
+	VREV32 e.B16, e.B16
+	VREV32 f.B16, f.B16
+	VREV32 g.B16, g.B16
+	VREV32 h.B16, h.B16
+
+	VST1.P [a.B16, b.B16, c.B16, d.B16], 64(dstPtr)
+	VST1 [e.B16, f.B16, g.B16, h.B16], (dstPtr)
+
+	RET
diff --git a/sm3/sm3blocks_avx2_amd64.s b/sm3/sm3blocks_avx2_amd64.s
index 257b2d3..64ccdc1 100644
--- a/sm3/sm3blocks_avx2_amd64.s
+++ b/sm3/sm3blocks_avx2_amd64.s
@@ -440,3 +440,39 @@ end:
 	VZEROUPPER
 
 	RET
+
+// func copyResultsBy8(dig *uint32, dst *byte)
+TEXT ·copyResultsBy8(SB),NOSPLIT,$0
+	MOVQ dig+0(FP), DI
+	MOVQ dst+8(FP), SI
+
+	// load state
+	VMOVDQU (0*32)(DI), a
+	VMOVDQU (1*32)(DI), b
+	VMOVDQU (2*32)(DI), c
+	VMOVDQU (3*32)(DI), d
+	VMOVDQU (4*32)(DI), e
+	VMOVDQU (5*32)(DI), f
+	VMOVDQU (6*32)(DI), g
+	VMOVDQU (7*32)(DI), h
+
+	VPSHUFB flip_mask<>(SB), a, a
+	VPSHUFB flip_mask<>(SB), b, b
+	VPSHUFB flip_mask<>(SB), c, c
+	VPSHUFB flip_mask<>(SB), d, d
+	VPSHUFB flip_mask<>(SB), e, e
+	VPSHUFB flip_mask<>(SB), f, f
+	VPSHUFB flip_mask<>(SB), g, g
+	VPSHUFB flip_mask<>(SB), h, h
+
+	VMOVDQU a, (0*32)(SI)
+	VMOVDQU b, (1*32)(SI)
+	VMOVDQU c, (2*32)(SI)
+	VMOVDQU d, (3*32)(SI)
+	VMOVDQU e, (4*32)(SI)
+	VMOVDQU f, (5*32)(SI)
+	VMOVDQU g, (6*32)(SI)
+	VMOVDQU h, (7*32)(SI)
+
+	VZEROUPPER
+	RET
diff --git a/sm3/sm3blocks_simd_amd64.s b/sm3/sm3blocks_simd_amd64.s
index 07f0ab0..f4dfecf 100644
--- a/sm3/sm3blocks_simd_amd64.s
+++ b/sm3/sm3blocks_simd_amd64.s
@@ -661,3 +661,72 @@ avxEnd:
 	VMOVDQU h, (1*16)(R8)
 
 	RET
+
+// func copyResultsBy4(dig *uint32, dst *byte)
+TEXT ·copyResultsBy4(SB),NOSPLIT,$0
+	MOVQ dig+0(FP), DI
+	MOVQ dst+8(FP), SI
+
+	CMPB ·useAVX(SB), $1
+	JE   avx
+
+	// load state
+	MOVOU (0*16)(DI), a
+	MOVOU (1*16)(DI), b
+	MOVOU (2*16)(DI), c
+	MOVOU (3*16)(DI), d
+	MOVOU (4*16)(DI), e
+	MOVOU (5*16)(DI), f
+	MOVOU (6*16)(DI), g
+	MOVOU (7*16)(DI), h
+
+	MOVOU flip_mask<>(SB), tmp1
+	PSHUFB tmp1, a
+	PSHUFB tmp1, b
+	PSHUFB tmp1, c
+	PSHUFB tmp1, d
+	PSHUFB tmp1, e
+	PSHUFB tmp1, f
+	PSHUFB tmp1, g
+	PSHUFB tmp1, h
+	MOVOU a, (0*16)(SI)
+	MOVOU b, (1*16)(SI)
+	MOVOU c, (2*16)(SI)
+	MOVOU d, (3*16)(SI)
+	MOVOU e, (4*16)(SI)
+	MOVOU f, (5*16)(SI)
+	MOVOU g, (6*16)(SI)
+	MOVOU h, (7*16)(SI)
+
+	RET
+
+avx:
+	// load state
+	VMOVDQU (0*16)(DI), a
+	VMOVDQU (1*16)(DI), b
+	VMOVDQU (2*16)(DI), c
+	VMOVDQU (3*16)(DI), d
+	VMOVDQU (4*16)(DI), e
+	VMOVDQU (5*16)(DI), f
+	VMOVDQU (6*16)(DI), g
+	VMOVDQU (7*16)(DI), h
+
+	VPSHUFB flip_mask<>(SB), a, a
+	VPSHUFB flip_mask<>(SB), b, b
+	VPSHUFB flip_mask<>(SB), c, c
+	VPSHUFB flip_mask<>(SB), d, d
+	VPSHUFB flip_mask<>(SB), e, e
+	VPSHUFB flip_mask<>(SB), f, f
+	VPSHUFB flip_mask<>(SB), g, g
+	VPSHUFB flip_mask<>(SB), h, h
+
+	VMOVDQU a, (0*16)(SI)
+	VMOVDQU b, (1*16)(SI)
+	VMOVDQU c, (2*16)(SI)
+	VMOVDQU d, (3*16)(SI)
+	VMOVDQU e, (4*16)(SI)
+	VMOVDQU f, (5*16)(SI)
+	VMOVDQU g, (6*16)(SI)
+	VMOVDQU h, (7*16)(SI)
+
+	RET
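
For context, kdfBy4 and kdfBy8 both implement the counter-mode KDF over SM3: the output is H(Z || ct) for ct = 1, 2, ..., concatenated and truncated to keyLen bytes. The refactor above builds the padded input block once (prepareInitData), clones it into every lane buffer, and then rewrites only the 4-byte big-endian counter before each batch of compressions. A scalar reference is handy for cross-checking the vectorized paths; this is a minimal in-package sketch, assuming the package's exported New and Size and an encoding/binary import. kdfRef is a hypothetical name, not part of the diff.

// kdfRef is a hypothetical scalar reference for the vectorized KDF:
// K = H(Z || 1) || H(Z || 2) || ... truncated to keyLen bytes, with
// the 4-byte counter encoded big-endian, matching the PutUint32
// calls in the diff above.
func kdfRef(z []byte, keyLen int) []byte {
	var ct [4]byte
	out := make([]byte, 0, keyLen+Size)
	for i := uint32(1); len(out) < keyLen; i++ {
		binary.BigEndian.PutUint32(ct[:], i)
		h := New()
		h.Write(z)
		h.Write(ct[:])
		out = h.Sum(out) // Sum appends the digest to out
	}
	return out[:keyLen]
}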
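
The new copyResultsBy4/copyResultsBy8 routines replace the removed per-lane copyResult loop with a single bulk load, byte-swap (VREV32 on arm64, PSHUFB against flip_mask on amd64), and store. Functionally they serialize consecutive lane states to big-endian output; a portable sketch of what copyResultsBy4 computes, with copyResultsBy4Ref as a hypothetical name:

// copyResultsBy4Ref mirrors the assembly copyResultsBy4: it reads
// four contiguous [8]uint32 lane states and writes them to dst as
// big-endian bytes (4 lanes * 32 bytes = 128 bytes total).
func copyResultsBy4Ref(states *[4][8]uint32, dst []byte) {
	for j := range states {
		for i, w := range states[j] {
			binary.BigEndian.PutUint32(dst[j*32+i*4:], w)
		}
	}
}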
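
On the benchmark change: b.SetBytes(int64(tt.kLen)) credits each iteration with the number of derived key bytes, so go test -bench additionally reports throughput in MB/s, which makes the effect of the prepare-once and bulk-copy changes directly visible in the benchmark output.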