diff --git a/sm3/gen_sm3block_ni.go b/sm3/gen_sm3block_ni.go
new file mode 100644
index 0000000..808312c
--- /dev/null
+++ b/sm3/gen_sm3block_ni.go
@@ -0,0 +1,211 @@
+// Not used yet!!!
+// go run gen_sm3block_ni.go
+
+//go:build ignore
+// +build ignore
+
+// Command gen_sm3block_ni writes sm3block_arm64_ni.s, an ARM64 assembly
+// version of the SM3 block function that uses the SM3 instructions. Those
+// instructions are emitted as raw WORD directives whose values come from the
+// encoder helpers below.
+//
+// NOTE(review): work in progress; the output is not used by the package yet.
+// Before enabling it, confirm the emitted text assembles (Go assembly
+// normally writes immediates as "$n" and vector operands with an arrangement
+// such as V0.B16) and that the byte-swapped WORD values suit the target.
+package main
+
+import (
+	"bytes"
+	"fmt"
+	"log"
+	"math/bits"
+	"os"
+)
+
+// sm3partw1 returns the byte-swapped encoding of
+// SM3PARTW1 <Vd>.4S, <Vn>.4S, <Vm>.4S.
+func sm3partw1(Vd, Vn, Vm byte) uint32 {
+	inst := uint32(0xce60c000) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | (uint32(Vm&0x1f) << 16)
+	return bits.ReverseBytes32(inst)
+}
+
+// sm3partw2 returns the byte-swapped encoding of
+// SM3PARTW2 <Vd>.4S, <Vn>.4S, <Vm>.4S.
+func sm3partw2(Vd, Vn, Vm byte) uint32 {
+	inst := uint32(0xce60c400) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | (uint32(Vm&0x1f) << 16)
+	return bits.ReverseBytes32(inst)
+}
+
+// sm3ss1 returns the byte-swapped encoding of
+// SM3SS1 <Vd>.4S, <Vn>.4S, <Vm>.4S, <Va>.4S.
+func sm3ss1(Vd, Vn, Vm, Va byte) uint32 {
+	inst := uint32(0xce400000) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | uint32(Va&0x1f)<<10 | uint32(Vm&0x1f)<<16
+	return bits.ReverseBytes32(inst)
+}
+
+// sm3tt1a returns the byte-swapped encoding of
+// SM3TT1A <Vd>.4S, <Vn>.4S, <Vm>.S[<imm2>].
+func sm3tt1a(Vd, Vn, Vm, imm2 byte) uint32 {
+	inst := uint32(0xce408000) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | uint32(imm2&0x3)<<12 | uint32(Vm&0x1f)<<16
+	return bits.ReverseBytes32(inst)
+}
+
+// sm3tt1b returns the byte-swapped encoding of
+// SM3TT1B <Vd>.4S, <Vn>.4S, <Vm>.S[<imm2>].
+func sm3tt1b(Vd, Vn, Vm, imm2 byte) uint32 {
+	inst := uint32(0xce408400) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | uint32(imm2&0x3)<<12 | uint32(Vm&0x1f)<<16
+	return bits.ReverseBytes32(inst)
+}
+
+// sm3tt2a returns the byte-swapped encoding of
+// SM3TT2A <Vd>.4S, <Vn>.4S, <Vm>.S[<imm2>].
+func sm3tt2a(Vd, Vn, Vm, imm2 byte) uint32 {
+	inst := uint32(0xce408800) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | uint32(imm2&0x3)<<12 | uint32(Vm&0x1f)<<16
+	return bits.ReverseBytes32(inst)
+}
+
+// sm3tt2b returns the byte-swapped encoding of
+// SM3TT2B <Vd>.4S, <Vn>.4S, <Vm>.S[<imm2>].
+func sm3tt2b(Vd, Vn, Vm, imm2 byte) uint32 {
+	inst := uint32(0xce408c00) | uint32(Vd&0x1f) | uint32(Vn&0x1f)<<5 | uint32(imm2&0x3)<<12 | uint32(Vm&0x1f)<<16
+	return bits.ReverseBytes32(inst)
+}
+
+// roundA emits one round of the first 16 rounds: SM3SS1 into temp register
+// V5, a one-bit left shift of the constant register t, then SM3TT1A and
+// SM3TT2A on the state registers st1 and st2 using lane i of wt and w.
+//
+// NOTE(review): VSHL shifts t rather than rotating it, so the top bit is
+// dropped each round - confirm this matches the SM3 round constant schedule.
+func roundA(buf *bytes.Buffer, i, t, st1, st2, w, wt byte) {
+	fmt.Fprintf(buf, "\tWORD 0x%08x //SM3SS1 V%d.4S, V%d.4S, V%d.4S, V%d.4S\n", sm3ss1(5, st1, t, st2), 5, st1, t, st2)
+	fmt.Fprintf(buf, "\tVSHL $1, V%d.S4, V%d.S4\n", t, t)
+	fmt.Fprintf(buf, "\tWORD 0x%08x //SM3TT1A V%d.4S, V%d.4S, V%d.S[%d]\n", sm3tt1a(st1, 5, wt, i), st1, 5, wt, i)
+	fmt.Fprintf(buf, "\tWORD 0x%08x //SM3TT2A V%d.4S, V%d.4S, V%d.S[%d]\n", sm3tt2a(st2, 5, w, i), st2, 5, w, i)
+}
+
+// roundB is roundA for the remaining 48 rounds: it emits SM3TT1B and SM3TT2B
+// instead of SM3TT1A and SM3TT2A. V5 is used as the temp register.
+func roundB(buf *bytes.Buffer, i, t, st1, st2, w, wt byte) {
+	fmt.Fprintf(buf, "\tWORD 0x%08x //SM3SS1 V%d.4S, V%d.4S, V%d.4S, V%d.4S\n", sm3ss1(5, st1, t, st2), 5, st1, t, st2)
+	fmt.Fprintf(buf, "\tVSHL $1, V%d.S4, V%d.S4\n", t, t)
+	fmt.Fprintf(buf, "\tWORD 0x%08x //SM3TT1B V%d.4S, V%d.4S, V%d.S[%d]\n", sm3tt1b(st1, 5, wt, i), st1, 5, wt, i)
+	fmt.Fprintf(buf, "\tWORD 0x%08x //SM3TT2B V%d.4S, V%d.4S, V%d.S[%d]\n", sm3tt2b(st2, 5, w, i), st2, 5, w, i)
+}
+
+// qroundA compresses 4 words and generates the next 4 words, using V6, V7 and
+// V10 as temp registers.
+//
+//	t         register holding the round constant
+//	st1, st2  registers holding the SM3 state
+//	s0        W(4i) W(4i+1) W(4i+2) W(4i+3)
+//	s1        W(4i+4) W(4i+5) W(4i+6) W(4i+7)
+//	s2        W(4i+8) W(4i+9) W(4i+10) W(4i+11)
+//	s3        W(4i+12) W(4i+13) W(4i+14) W(4i+15)
+//	s4        receives the next 4 words
+func qroundA(buf *bytes.Buffer, t, st1, st2, s0, s1, s2, s3, s4 byte) {
+	fmt.Fprintf(buf, "\t// Extension\n")
+	fmt.Fprintf(buf, "\tVEXT 3, V%d, V%d, V%d\n", s2, s1, s4)
+	fmt.Fprintf(buf, "\tVEXT 3, V%d, V%d, V%d\n", s1, s0, 6)
+	fmt.Fprintf(buf, "\tVEXT 2, V%d, V%d, V%d\n", s3, s2, 7)
+	fmt.Fprintf(buf, "\tWORD 0x%08x //SM3PARTW1 V%d.4S, V%d.4S, V%d.4S\n", sm3partw1(s4, s0, s3), s4, s0, s3)
+	fmt.Fprintf(buf, "\tWORD 0x%08x //SM3PARTW2 V%d.4S, V%d.4S, V%d.4S\n", sm3partw2(s4, 7, 6), s4, 7, 6)
+	fmt.Fprintf(buf, "\tVEOR V%d, V%d, V10\n", s1, s0)
+	fmt.Fprintf(buf, "\t// Compression\n")
+	roundA(buf, 0, t, st1, st2, s0, 10)
+	roundA(buf, 1, t, st1, st2, s0, 10)
+	roundA(buf, 2, t, st1, st2, s0, 10)
+	roundA(buf, 3, t, st1, st2, s0, 10)
+	fmt.Fprintf(buf, "\n")
+}
+
+// qroundB is qroundA built on roundB, used for the second 48 rounds. Passing
+// s4 == 0xff skips the message extension; main does this for its last three
+// calls. V6, V7 and V10 are used as temp registers.
+func qroundB(buf *bytes.Buffer, t, st1, st2, s0, s1, s2, s3, s4 byte) {
+	if s4 != 0xff {
+		fmt.Fprintf(buf, "\t// Extension\n")
+		fmt.Fprintf(buf, "\tVEXT 3, V%d, V%d, V%d\n", s2, s1, s4)
+		fmt.Fprintf(buf, "\tVEXT 3, V%d, V%d, V%d\n", s1, s0, 6)
+		fmt.Fprintf(buf, "\tVEXT 2, V%d, V%d, V%d\n", s3, s2, 7)
+		fmt.Fprintf(buf, "\tWORD 0x%08x //SM3PARTW1 V%d.4S, V%d.4S, V%d.4S\n", sm3partw1(s4, s0, s3), s4, s0, s3)
+		fmt.Fprintf(buf, "\tWORD 0x%08x //SM3PARTW2 V%d.4S, V%d.4S, V%d.4S\n", sm3partw2(s4, 7, 6), s4, 7, 6)
+	}
+	fmt.Fprintf(buf, "\tVEOR V%d, V%d, V10\n", s1, s0)
+	fmt.Fprintf(buf, "\t// Compression\n")
+	roundB(buf, 0, t, st1, st2, s0, 10)
+	roundB(buf, 1, t, st1, st2, s0, 10)
+	roundB(buf, 2, t, st1, st2, s0, 10)
+	roundB(buf, 3, t, st1, st2, s0, 10)
+	fmt.Fprintf(buf, "\n")
+}
+
+// main renders the assembly into a buffer and writes it to
+// sm3block_arm64_ni.s in the current directory.
+func main() {
+	buf := new(bytes.Buffer)
+	fmt.Fprint(buf, `
+// Generated by gen_sm3block_ni.go. DO NOT EDIT.
+
+#include "textflag.h"
+
+// func blockSM3NI(h []uint32, p []byte, t []uint32)
+TEXT ·blockSM3NI(SB), 0, $0
+	MOVD h_base+0(FP), R0 // Hash value first address
+	MOVD p_base+24(FP), R1 // message first address
+	MOVD p_len+32(FP), R3 // message length
+	MOVD t_base+48(FP), R2 // t constants first address
+
+	VLD1 (R0), [V8.S4, V9.S4] // load h(a,b,c,d,e,f,g,h)
+	LDPW (0*8)(R2), (R5, R6) // load t constants
+
+blockloop:
+	VLD1.P 64(R1), [V0.B16, V1.B16, V2.B16, V3.B16] // load 64bytes message
+	VMOV V8.B16, V15.B16 // backup: V8 h(dcba)
+	VMOV V9.B16, V16.B16 // backup: V9 h(hgfe)
+	VREV32 V0.B16, V0.B16 // prepare for using message in Byte format
+	VREV32 V1.B16, V1.B16
+	VREV32 V2.B16, V2.B16
+	VREV32 V3.B16, V3.B16
+	// first 16 rounds
+	VMOV R5, V11.S[3]
+`[1:])
+	qroundA(buf, 11, 8, 9, 0, 1, 2, 3, 4)
+	qroundA(buf, 11, 8, 9, 1, 2, 3, 4, 0)
+	qroundA(buf, 11, 8, 9, 2, 3, 4, 0, 1)
+	qroundA(buf, 11, 8, 9, 3, 4, 0, 1, 2)
+
+	fmt.Fprintf(buf, "\t// second 48 rounds\n")
+	fmt.Fprintf(buf, "\tVMOV R6, V11.S[3]\n")
+	qroundB(buf, 11, 8, 9, 4, 0, 1, 2, 3)
+	qroundB(buf, 11, 8, 9, 0, 1, 2, 3, 4)
+	qroundB(buf, 11, 8, 9, 1, 2, 3, 4, 0)
+	qroundB(buf, 11, 8, 9, 2, 3, 4, 0, 1)
+	qroundB(buf, 11, 8, 9, 3, 4, 0, 1, 2)
+	qroundB(buf, 11, 8, 9, 4, 0, 1, 2, 3)
+	qroundB(buf, 11, 8, 9, 0, 1, 2, 3, 4)
+	qroundB(buf, 11, 8, 9, 1, 2, 3, 4, 0)
+	qroundB(buf, 11, 8, 9, 2, 3, 4, 0, 1)
+	qroundB(buf, 11, 8, 9, 3, 4, 0xff, 0xff, 0xff)
+	qroundB(buf, 11, 8, 9, 4, 0, 0xff, 0xff, 0xff)
+	qroundB(buf, 11, 8, 9, 0, 1, 0xff, 0xff, 0xff)
+
+	fmt.Fprint(buf, `
+	SUB $64, R3, R3 // message length - 64bytes, then compare with 64bytes
+	VEOR V8.S4, V15.S4, V8.S4
+	VEOR V9.S4, V16.S4, V9.S4
+	CBNZ R3, blockloop
+
+sm3ret:
+	VST1 [V8.S4, V9.S4], (R0) // store hash value H
+	RET
+`[1:])
+	src := buf.Bytes()
+	// fmt.Println(string(src))
+
+	err := os.WriteFile("sm3block_arm64_ni.s", src, 0644)
+	if err != nil {
+		log.Fatal(err)
+	}
+}
diff --git a/sm3/sm3block_amd64.go b/sm3/sm3block_amd64.go
index 040cd35..3c2a762 100644
--- a/sm3/sm3block_amd64.go
+++ b/sm3/sm3block_amd64.go
@@ -6,3 +6,6 @@ package sm3
 import "golang.org/x/sys/cpu"
 
 var useAVX2 = cpu.X86.HasAVX2 && cpu.X86.HasBMI2
+
+//go:noescape
+func block(dig *digest, p []byte)
diff --git a/sm3/sm3block_arm64.go b/sm3/sm3block_arm64.go
index d2b22e6..1076e8e 100644
--- a/sm3/sm3block_arm64.go
+++ b/sm3/sm3block_arm64.go
@@ -6,3 +6,26 @@ package sm3
 import "golang.org/x/sys/cpu"
 
 var useSM3NI = cpu.ARM64.HasSM3
+
+// t holds the two constants passed to blockSM3NI as its t argument.
+var t = []uint32{
+	0x79cc4519,
+	0x9d8a7a87,
+}
+
+//go:noescape
+func blockARM64(dig *digest, p []byte)
+
+//go:noescape
+func blockSM3NI(h []uint32, p []byte, t []uint32)
+
+// block processes p by calling blockARM64. The blockSM3NI path is left
+// commented out, so useSM3NI is not consulted here.
+func block(dig *digest, p []byte) {
+	//if !useSM3NI {
+	blockARM64(dig, p)
+	//} else {
+	//	h := dig.h[:]
+	//	blockSM3NI(h, p, t)
+	//}
+}
diff --git a/sm3/sm3block_arm64.s b/sm3/sm3block_arm64.s
index 016ca41..9e43c8d 100644
--- a/sm3/sm3block_arm64.s
+++ b/sm3/sm3block_arm64.s
@@ -127,8 +127,8 @@
 	SM3TT21(e, f, g, h); \
 	COPYRESULT(b, d, f, h)
 
-// func block(dig *digest, p []byte)
-TEXT ·block(SB), 0, $272-32
+// func blockARM64(dig *digest, p []byte)
+TEXT ·blockARM64(SB), 0, $272-32
 	MOVD dig+0(FP), hlp1
 	MOVD p_base+8(FP), SI
 	MOVD p_len+16(FP), DX
diff --git a/sm3/sm3block_decl.go b/sm3/sm3block_decl.go
deleted file mode 100644
index 611ebbb..0000000
--- a/sm3/sm3block_decl.go
+++ /dev/null
@@ -1,7 +0,0 @@
-//go:build amd64 || arm64
-// +build amd64 arm64
-
-package sm3
-
-//go:noescape
-func block(dig *digest, p []byte)