diff --git a/sm4/asm_amd64.s b/sm4/asm_amd64.s index 929a1c8..3c80ecd 100644 --- a/sm4/asm_amd64.s +++ b/sm4/asm_amd64.s @@ -290,7 +290,7 @@ GLOBL fk_mask<>(SB), RODATA, $16 AVX_SM4_TAO_L1(x, y); \ VPXOR x, t0, t0 -// func expandKeyAsm(key *byte, ck, enc, dec *uint32) +// func expandKeyAsm(key *byte, ck, enc, dec *uint32, inst int) TEXT ·expandKeyAsm(SB),NOSPLIT,$0 MOVQ key+0(FP), AX MOVQ ck+8(FP), BX @@ -321,7 +321,7 @@ loop: expand_end: RET -// func encryptBlocksAsm(xk *uint32, dst, src []byte) +// func encryptBlocksAsm(xk *uint32, dst, src []byte, inst int) TEXT ·encryptBlocksAsm(SB),NOSPLIT,$0 MOVQ xk+0(FP), AX MOVQ dst+8(FP), BX @@ -497,7 +497,7 @@ avx2_sm4_done: VZEROUPPER RET -// func encryptBlockAsm(xk *uint32, dst, src *byte) +// func encryptBlockAsm(xk *uint32, dst, src *byte, inst int) TEXT ·encryptBlockAsm(SB),NOSPLIT,$0 MOVQ xk+0(FP), AX MOVQ dst+8(FP), BX diff --git a/sm4/asm_arm64.s b/sm4/asm_arm64.s index 97f0163..480773e 100644 --- a/sm4/asm_arm64.s +++ b/sm4/asm_arm64.s @@ -164,13 +164,43 @@ GLOBL fk_mask<>(SB), (NOPTR+RODATA), $16 VMOV R0, R24_MASK.D[0] \ VMOV R1, R24_MASK.D[1] -// func expandKeyAsm(key *byte, ck, enc, dec *uint32) +#define SM4EKEY_EXPORT_KEYS() \ + VMOV V8.S[3], V10.S[0] \ + VMOV V8.S[2], V10.S[1] \ + VMOV V8.S[1], V10.S[2] \ + VMOV V8.S[0], V10.S[3] \ + VMOV V9.S[3], V11.S[0] \ + VMOV V9.S[2], V11.S[1] \ + VMOV V9.S[1], V11.S[2] \ + VMOV V9.S[0], V11.S[3] \ + VST1.P [V9.S4, V8.S4], 32(R10) \ + VST1.P [V10.S4, V11.S4], -32(R11) + +#define SM4E_ROUND() \ + VLD1.P 16(R10), [V8.B16] \ + VREV32 V8.B16, V8.B16 \ + WORD 0x0884c0ce \ + WORD 0x2884c0ce \ + WORD 0x4884c0ce \ + WORD 0x6884c0ce \ + WORD 0x8884c0ce \ + WORD 0xa884c0ce \ + WORD 0xc884c0ce \ + WORD 0xe884c0ce \ + VREV32 V8.B16, V8.B16 \ + VST1.P [V8.B16], 16(R9) + +// func expandKeyAsm(key *byte, ck, enc, dec *uint32, inst int) TEXT ·expandKeyAsm(SB),NOSPLIT,$0 MOVD key+0(FP), R8 MOVD ck+8(FP), R9 MOVD enc+16(FP), R10 MOVD dec+24(FP), R11 - + MOVD inst+32(FP), R12 + + CMP $1, R12 + BEQ sm4ekey + load_global_data_1() VLD1 (R8), [t0.B16] @@ -193,14 +223,45 @@ ksLoop: ADD $16, R0 CMP $128, R0 BNE ksLoop - RET -// func encryptBlocksAsm(xk *uint32, dst, src []byte) +sm4ekey: + LDP fk_mask<>(SB), (R0, R1) + VMOV R0, FK_MASK.D[0] + VMOV R1, FK_MASK.D[1] + VLD1 (R8), [V8.B16] + VREV32 V8.B16, V8.B16 + VEOR FK_MASK, V8, V8 + ADD $96, R11 + + VLD1.P 64(R9), [V0.S4, V1.S4, V2.S4, V3.S4] + WORD 0x09c960ce //SM4EKEY V9.4S, V8.4S, V0.4S + WORD 0x28c961ce //SM4EKEY V8.4S, V9.4S, V1.4S + SM4EKEY_EXPORT_KEYS() + + WORD 0x09c962ce //SM4EKEY V9.4S, V8.4S, V2.4S + WORD 0x28c963ce //SM4EKEY V8.4S, V9.4S, V3.4S + SM4EKEY_EXPORT_KEYS() + + VLD1.P 64(R9), [V0.S4, V1.S4, V2.S4, V3.S4] + WORD 0x09c960ce //SM4EKEY V9.4S, V8.4S, V0.4S + WORD 0x28c961ce //SM4EKEY V8.4S, V9.4S, V1.4S + SM4EKEY_EXPORT_KEYS() + + WORD 0x09c962ce //SM4EKEY V9.4S, V8.4S, V2.4S + WORD 0x28c963ce //SM4EKEY V8.4S, V9.4S, V3.4S + SM4EKEY_EXPORT_KEYS() + RET + +// func encryptBlocksAsm(xk *uint32, dst, src []byte, inst int) TEXT ·encryptBlocksAsm(SB),NOSPLIT,$0 MOVD xk+0(FP), R8 MOVD dst+8(FP), R9 MOVD src+32(FP), R10 + MOVD inst+24(FP), R11 + + CMP $1, R11 + BEQ sm4niblocks VLD1 (R10), [V5.S4, V6.S4, V7.S4, V8.S4] VMOV V5.S[0], t0.S[0] @@ -271,15 +332,26 @@ encryptBlocksLoop: VMOV t1.S[3], V8.S[2] VMOV t0.S[3], V8.S[3] VST1 [V8.B16], (R9) - RET +sm4niblocks: + VLD1.P 64(R8), [V0.S4, V1.S4, V2.S4, V3.S4] + VLD1.P 64(R8), [V4.S4, V5.S4, V6.S4, V7.S4] + SM4E_ROUND() + SM4E_ROUND() + SM4E_ROUND() + SM4E_ROUND() + RET -// func encryptBlockAsm(xk *uint32, dst, src *byte) +// func encryptBlockAsm(xk *uint32, dst, src *byte, inst int) TEXT ·encryptBlockAsm(SB),NOSPLIT,$0 MOVD xk+0(FP), R8 MOVD dst+8(FP), R9 MOVD src+16(FP), R10 + MOVD inst+24(FP), R11 + + CMP $1, R11 + BEQ sm4niblock VLD1 (R10), [t0.S4] VREV32 t0.B16, t0.B16 @@ -312,5 +384,21 @@ encryptBlockLoop: VMOV t1.S[0], V8.S[2] VMOV t0.S[0], V8.S[3] VST1 [V8.B16], (R9) + RET +sm4niblock: + VLD1 (R10), [V8.B16] + VREV32 V8.B16, V8.B16 + VLD1.P 64(R8), [V0.S4, V1.S4, V2.S4, V3.S4] + WORD 0x0884c0ce //SM4E V8.4S, V0.4S + WORD 0x2884c0ce //SM4E V8.4S, V1.4S + WORD 0x4884c0ce //SM4E V8.4S, V2.4S + WORD 0x6884c0ce //SM4E V8.4S, V3.4S + VLD1.P 64(R8), [V0.S4, V1.S4, V2.S4, V3.S4] + WORD 0x0884c0ce //SM4E V8.4S, V0.4S + WORD 0x2884c0ce //SM4E V8.4S, V1.4S + WORD 0x4884c0ce //SM4E V8.4S, V2.4S + WORD 0x6884c0ce //SM4E V8.4S, V3.4S + VREV32 V8.B16, V8.B16 + VST1 [V8.B16], (R9) RET diff --git a/sm4/cipher_asm.go b/sm4/cipher_asm.go index ad08b5e..d063732 100644 --- a/sm4/cipher_asm.go +++ b/sm4/cipher_asm.go @@ -16,13 +16,13 @@ var supportsGFMUL = cpu.X86.HasPCLMULQDQ || cpu.ARM64.HasPMULL var useAVX2 = cpu.X86.HasAVX2 && cpu.X86.HasBMI2 //go:noescape -func encryptBlocksAsm(xk *uint32, dst, src []byte) +func encryptBlocksAsm(xk *uint32, dst, src []byte, inst int) //go:noescape -func encryptBlockAsm(xk *uint32, dst, src *byte) +func encryptBlockAsm(xk *uint32, dst, src *byte, inst int) //go:noescape -func expandKeyAsm(key *byte, ck, enc, dec *uint32) +func expandKeyAsm(key *byte, ck, enc, dec *uint32, inst int) type sm4CipherAsm struct { sm4Cipher @@ -31,7 +31,7 @@ type sm4CipherAsm struct { } func newCipher(key []byte) (cipher.Block, error) { - if !supportsAES { + if !(supportsAES || supportSM4) { return newCipherGeneric(key) } blocks := 4 @@ -39,7 +39,11 @@ func newCipher(key []byte) (cipher.Block, error) { blocks = 8 } c := sm4CipherAsm{sm4Cipher{make([]uint32, rounds), make([]uint32, rounds)}, blocks, blocks * BlockSize} - expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0]) + if supportSM4 { + expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0], 1) + } else { + expandKeyAsm(&key[0], &ck[0], &c.enc[0], &c.dec[0], 0) + } if supportsAES && supportsGFMUL { return &sm4CipherGCM{c}, nil } @@ -50,6 +54,22 @@ func (c *sm4CipherAsm) BlockSize() int { return BlockSize } func (c *sm4CipherAsm) Concurrency() int { return c.batchBlocks } +func encryptBlockAsmInst(xk *uint32, dst, src *byte) { + if supportSM4 { + encryptBlockAsm(xk, dst, src, 1) + } else { + encryptBlockAsm(xk, dst, src, 0) + } +} + +func encryptBlocksAsmInst(xk *uint32, dst, src []byte) { + if supportSM4 { + encryptBlocksAsm(xk, dst, src, 1) + } else { + encryptBlocksAsm(xk, dst, src, 0) + } +} + func (c *sm4CipherAsm) Encrypt(dst, src []byte) { if len(src) < BlockSize { panic("sm4: input not full block") @@ -60,7 +80,7 @@ func (c *sm4CipherAsm) Encrypt(dst, src []byte) { if subtle.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { panic("sm4: invalid buffer overlap") } - encryptBlockAsm(&c.enc[0], &dst[0], &src[0]) + encryptBlockAsmInst(&c.enc[0], &dst[0], &src[0]) } func (c *sm4CipherAsm) EncryptBlocks(dst, src []byte) { @@ -73,7 +93,7 @@ func (c *sm4CipherAsm) EncryptBlocks(dst, src []byte) { if subtle.InexactOverlap(dst[:c.blocksSize], src[:c.blocksSize]) { panic("sm4: invalid buffer overlap") } - encryptBlocksAsm(&c.enc[0], dst, src) + encryptBlocksAsmInst(&c.enc[0], dst, src) } func (c *sm4CipherAsm) Decrypt(dst, src []byte) { @@ -86,7 +106,7 @@ func (c *sm4CipherAsm) Decrypt(dst, src []byte) { if subtle.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { panic("sm4: invalid buffer overlap") } - encryptBlockAsm(&c.dec[0], &dst[0], &src[0]) + encryptBlockAsmInst(&c.dec[0], &dst[0], &src[0]) } func (c *sm4CipherAsm) DecryptBlocks(dst, src []byte) { @@ -99,14 +119,16 @@ func (c *sm4CipherAsm) DecryptBlocks(dst, src []byte) { if subtle.InexactOverlap(dst[:c.blocksSize], src[:c.blocksSize]) { panic("sm4: invalid buffer overlap") } - encryptBlocksAsm(&c.dec[0], dst, src) + encryptBlocksAsmInst(&c.dec[0], dst, src) } // expandKey is used by BenchmarkExpand to ensure that the asm implementation // of key expansion is used for the benchmark when it is available. func expandKey(key []byte, enc, dec []uint32) { - if supportsAES { - expandKeyAsm(&key[0], &ck[0], &enc[0], &dec[0]) + if supportSM4 { + expandKeyAsm(&key[0], &ck[0], &enc[0], &dec[0], 1) + } else if supportsAES { + expandKeyAsm(&key[0], &ck[0], &enc[0], &dec[0], 0) } else { expandKeyGo(key, enc, dec) } diff --git a/sm4/cipher_asm_fuzzy_test.go b/sm4/cipher_asm_fuzzy_test.go index 22c3821..b3894cd 100644 --- a/sm4/cipher_asm_fuzzy_test.go +++ b/sm4/cipher_asm_fuzzy_test.go @@ -34,7 +34,7 @@ func TestExpandKey(t *testing.T) { } io.ReadFull(rand.Reader, key) expandKeyGo(key, encRes1, decRes1) - expandKeyAsm(&key[0], &ck[0], &encRes2[0], &decRes2[0]) + expandKey(key, encRes2, decRes2) if !reflect.DeepEqual(encRes1, encRes2) { t.Errorf("expected=%v, result=%v\n", encRes1, encRes2) } diff --git a/sm4/gen_arm64_ni.go b/sm4/gen_arm64_ni.go new file mode 100644 index 0000000..864d40a --- /dev/null +++ b/sm4/gen_arm64_ni.go @@ -0,0 +1,137 @@ +// Not used yet!!! +// go run gen_arm64_ni.go + +//go:build ignore +// +build ignore + +package main + +import ( + "bytes" + "fmt" + "log" + "math/bits" + "os" +) + +//SM4E .4S, .4S +func sm4e(Vd, Vn byte) uint32 { + inst := uint32(0xcec08400) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 + return bits.ReverseBytes32(inst) +} + +//SM4EKEY .4S, .4S, .4S +func sm4ekey(Vd, Vn, Vm byte) uint32 { + inst := uint32(0xce60c800) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | (uint32(Vm&0x1f) << 16) + return bits.ReverseBytes32(inst) +} + +func sm4ekeyRound(buf *bytes.Buffer, d, n, m byte) { + fmt.Fprintf(buf, "\tWORD 0x%08x //SM4EKEY V%d.4S, V%d.4S, V%d.4S\n", sm4ekey(d, n, m), d, n, m) +} + +func sm4eRound(buf *bytes.Buffer, d, n byte) { + fmt.Fprintf(buf, "\tWORD 0x%08x //SM4E V%d.4S, V%d.4S\n", sm4e(d, n), d, n) +} + +func main() { + buf := new(bytes.Buffer) + fmt.Fprint(buf, ` +// Generated by gen_arm64_ni.go. DO NOT EDIT. + +#include "textflag.h" + +// func expandKeySM4E(key *byte, fk, ck, enc *uint32) +TEXT ·expandKeySM4E(SB),NOSPLIT,$0 + MOVD key+0(FP), R8 + MOVD fk+8(FP), R9 + MOVD ck+16(FP), R10 + MOVD enc+24(FP), R11 + + VLD1 (R8), [V8.B16] + VREV32 V8.B16, V8.B16 + VLD1 (R9), [V9.S4] + VEOR V9, V8, V8 + VLD1.P 64(R10), [V0.S4, V1.S4, V2.S4, V3.S4] +`[1:]) + + sm4ekeyRound(buf, 9, 8, 0) + sm4ekeyRound(buf, 8, 9, 1) + fmt.Fprintf(buf, "\tVST1.P [V9.S4, V8.S4], 32(R11)\n") + sm4ekeyRound(buf, 9, 8, 2) + sm4ekeyRound(buf, 8, 9, 3) + fmt.Fprintf(buf, "\tVST1.P [V9.S4, V8.S4], 32(R11)\n") + fmt.Fprintf(buf, "\tVLD1.P 64(R10), [V0.S4, V1.S4, V2.S4, V3.S4]\n") + sm4ekeyRound(buf, 9, 8, 0) + sm4ekeyRound(buf, 8, 9, 1) + fmt.Fprintf(buf, "\tVST1.P [V9.S4, V8.S4], 32(R11)\n") + sm4ekeyRound(buf, 9, 8, 2) + sm4ekeyRound(buf, 8, 9, 3) + fmt.Fprintf(buf, ` + VST1.P [V9.S4, V8.S4], 32(R11) + RET +`[1:]) + fmt.Fprint(buf, ` + +// func encryptBlockSM4E(xk *uint32, dst, src *byte) +TEXT ·encryptBlockSM4E(SB),NOSPLIT,$0 + MOVD xk+0(FP), R8 + MOVD dst+8(FP), R9 + MOVD src+16(FP), R10 + + VLD1 (R10), [V8.B16] + VREV32 V8.B16, V8.B16 + VLD1.P 64(R8), [V0.S4, V1.S4, V2.S4, V3.S4] +`[1:]) + sm4eRound(buf, 8, 0) + sm4eRound(buf, 8, 1) + sm4eRound(buf, 8, 2) + sm4eRound(buf, 8, 3) + fmt.Fprintf(buf, "\tVLD1.P 64(R8), [V0.S4, V1.S4, V2.S4, V3.S4]\n") + sm4eRound(buf, 8, 0) + sm4eRound(buf, 8, 1) + sm4eRound(buf, 8, 2) + sm4eRound(buf, 8, 3) + fmt.Fprintf(buf, ` + VREV32 V8.B16, V8.B16 + VST1 [V8.B16], (R9) + RET +`[1:]) + + fmt.Fprint(buf, ` + +// func encryptBlocksSM4E(xk *uint32, dst, src *byte) +TEXT ·encryptBlocksSM4E(SB),NOSPLIT,$0 + MOVD xk+0(FP), R8 + MOVD dst+8(FP), R9 + MOVD src+16(FP), R10 + + VLD1.P 64(R8), [V0.S4, V1.S4, V2.S4, V3.S4] + VLD1.P 64(R8), [V4.S4, V5.S4, V6.S4, V7.S4] + +`[1:]) + for i := 0; i < 4; i++ { + fmt.Fprintf(buf, "\tVLD1.P 16(R10), [V8.B16]\n") + fmt.Fprintf(buf, "\tVREV32 V8.B16, V8.B16\n") + sm4eRound(buf, 8, 0) + sm4eRound(buf, 8, 1) + sm4eRound(buf, 8, 2) + sm4eRound(buf, 8, 3) + sm4eRound(buf, 8, 4) + sm4eRound(buf, 8, 5) + sm4eRound(buf, 8, 6) + sm4eRound(buf, 8, 7) + fmt.Fprintf(buf, "\tVREV32 V8.B16, V8.B16\n") + fmt.Fprintf(buf, "\tVST1.P [V8.B16], 16(R9)\n\n") + } + fmt.Fprintf(buf, ` + RET +`[1:]) + + src := buf.Bytes() + // fmt.Println(string(src)) + err := os.WriteFile("sm4e_arm64.s", src, 0644) + if err != nil { + log.Fatal(err) + } +}