From ecab51741142471998f596fb08912d92afd74fff Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Wed, 21 Jun 2023 15:45:06 +0800 Subject: [PATCH] sm9/bn256: curvePointMovCond twistPointMovCond asm implementation for amd64 & arm64 --- sm9/bn256/curve.go | 2 +- sm9/bn256/select_amd64.s | 320 ++++++++++++++++++++++++++++++++++++ sm9/bn256/select_arm64.s | 159 ++++++++++++++++++ sm9/bn256/select_decl.go | 10 ++ sm9/bn256/select_generic.go | 8 + sm9/bn256/twist.go | 2 +- 6 files changed, 499 insertions(+), 2 deletions(-) diff --git a/sm9/bn256/curve.go b/sm9/bn256/curve.go index 46382b4..b3b782f 100644 --- a/sm9/bn256/curve.go +++ b/sm9/bn256/curve.go @@ -272,6 +272,6 @@ func (table *curvePointTable) Select(p *curvePoint, n uint8) { p.SetInfinity() for i, f := range table { cond := subtle.ConstantTimeByteEq(uint8(i+1), n) - p.Select(f, p, cond) + curvePointMovCond(p, f, p, cond) } } diff --git a/sm9/bn256/select_amd64.s b/sm9/bn256/select_amd64.s index a8ab327..4d4a69d 100644 --- a/sm9/bn256/select_amd64.s +++ b/sm9/bn256/select_amd64.s @@ -294,3 +294,323 @@ move_avx2: VZEROUPPER RET + +// func curvePointMovCond(res, a, b *curvePoint, cond int) +TEXT ·curvePointMovCond(SB),NOSPLIT,$0 + MOVQ res+0(FP), res_ptr + MOVQ a+8(FP), x_ptr + MOVQ b+16(FP), y_ptr + MOVQ cond+24(FP), X12 + + CMPB ·supportAVX2+0(SB), $0x01 + JEQ move_avx2 + + PXOR X13, X13 + PSHUFD $0, X12, X12 + PCMPEQL X13, X12 + + MOVOU X12, X0 + MOVOU (16*0)(x_ptr), X6 + PANDN X6, X0 + + MOVOU X12, X1 + MOVOU (16*1)(x_ptr), X7 + PANDN X7, X1 + + MOVOU X12, X2 + MOVOU (16*2)(x_ptr), X8 + PANDN X8, X2 + + MOVOU X12, X3 + MOVOU (16*3)(x_ptr), X9 + PANDN X9, X3 + + MOVOU X12, X4 + MOVOU (16*4)(x_ptr), X10 + PANDN X10, X4 + + MOVOU X12, X5 + MOVOU (16*5)(x_ptr), X11 + PANDN X11, X5 + + MOVOU (16*0)(y_ptr), X6 + MOVOU (16*1)(y_ptr), X7 + MOVOU (16*2)(y_ptr), X8 + MOVOU (16*3)(y_ptr), X9 + MOVOU (16*4)(y_ptr), X10 + MOVOU (16*5)(y_ptr), X11 + + PAND X12, X6 + PAND X12, X7 + PAND X12, X8 + PAND X12, X9 + PAND X12, X10 + PAND X12, X11 + + PXOR X6, X0 + PXOR X7, X1 + PXOR X8, X2 + PXOR X9, X3 + PXOR X10, X4 + PXOR X11, X5 + + MOVOU X0, (16*0)(res_ptr) + MOVOU X1, (16*1)(res_ptr) + MOVOU X2, (16*2)(res_ptr) + MOVOU X3, (16*3)(res_ptr) + MOVOU X4, (16*4)(res_ptr) + MOVOU X5, (16*5)(res_ptr) + + MOVOU X12, X0 + MOVOU (16*6)(x_ptr), X6 + PANDN X6, X0 + + MOVOU X12, X1 + MOVOU (16*7)(x_ptr), X7 + PANDN X7, X1 + + MOVOU (16*6)(y_ptr), X6 + MOVOU (16*7)(y_ptr), X7 + + PAND X12, X6 + PAND X12, X7 + + PXOR X6, X0 + PXOR X7, X1 + + MOVOU X0, (16*6)(res_ptr) + MOVOU X1, (16*7)(res_ptr) + + RET + +move_avx2: + VPXOR Y13, Y13, Y13 + VPBROADCASTD X12, Y12 + VPCMPEQD Y13, Y12, Y12 + + VPANDN (32*0)(x_ptr), Y12, Y0 + VPANDN (32*1)(x_ptr), Y12, Y1 + VPANDN (32*2)(x_ptr), Y12, Y2 + VPANDN (32*3)(x_ptr), Y12, Y3 + + VPAND (32*0)(y_ptr), Y12, Y6 + VPAND (32*1)(y_ptr), Y12, Y7 + VPAND (32*2)(y_ptr), Y12, Y8 + VPAND (32*3)(y_ptr), Y12, Y9 + + VPXOR Y6, Y0, Y0 + VPXOR Y7, Y1, Y1 + VPXOR Y8, Y2, Y2 + VPXOR Y9, Y3, Y3 + + VMOVDQU Y0, (32*0)(res_ptr) + VMOVDQU Y1, (32*1)(res_ptr) + VMOVDQU Y2, (32*2)(res_ptr) + VMOVDQU Y3, (32*3)(res_ptr) + + VZEROUPPER + RET + +// func twistPointMovCond(res, a, b *twistPoint, cond int) +TEXT ·twistPointMovCond(SB),NOSPLIT,$0 + MOVQ res+0(FP), res_ptr + MOVQ a+8(FP), x_ptr + MOVQ b+16(FP), y_ptr + MOVQ cond+24(FP), X12 + + CMPB ·supportAVX2+0(SB), $0x01 + JEQ move_avx2 + + PXOR X13, X13 + PSHUFD $0, X12, X12 + PCMPEQL X13, X12 + + MOVOU X12, X0 + MOVOU (16*0)(x_ptr), X6 + PANDN X6, X0 + + MOVOU X12, X1 + MOVOU (16*1)(x_ptr), X7 + PANDN X7, X1 + + MOVOU X12, X2 + MOVOU (16*2)(x_ptr), X8 + PANDN X8, X2 + + MOVOU X12, X3 + MOVOU (16*3)(x_ptr), X9 + PANDN X9, X3 + + MOVOU X12, X4 + MOVOU (16*4)(x_ptr), X10 + PANDN X10, X4 + + MOVOU X12, X5 + MOVOU (16*5)(x_ptr), X11 + PANDN X11, X5 + + MOVOU (16*0)(y_ptr), X6 + MOVOU (16*1)(y_ptr), X7 + MOVOU (16*2)(y_ptr), X8 + MOVOU (16*3)(y_ptr), X9 + MOVOU (16*4)(y_ptr), X10 + MOVOU (16*5)(y_ptr), X11 + + PAND X12, X6 + PAND X12, X7 + PAND X12, X8 + PAND X12, X9 + PAND X12, X10 + PAND X12, X11 + + PXOR X6, X0 + PXOR X7, X1 + PXOR X8, X2 + PXOR X9, X3 + PXOR X10, X4 + PXOR X11, X5 + + MOVOU X0, (16*0)(res_ptr) + MOVOU X1, (16*1)(res_ptr) + MOVOU X2, (16*2)(res_ptr) + MOVOU X3, (16*3)(res_ptr) + MOVOU X4, (16*4)(res_ptr) + MOVOU X5, (16*5)(res_ptr) + + MOVOU X12, X0 + MOVOU (16*6)(x_ptr), X6 + PANDN X6, X0 + + MOVOU X12, X1 + MOVOU (16*7)(x_ptr), X7 + PANDN X7, X1 + + MOVOU X12, X2 + MOVOU (16*8)(x_ptr), X8 + PANDN X8, X2 + + MOVOU X12, X3 + MOVOU (16*9)(x_ptr), X9 + PANDN X9, X3 + + MOVOU X12, X4 + MOVOU (16*10)(x_ptr), X10 + PANDN X10, X4 + + MOVOU X12, X5 + MOVOU (16*11)(x_ptr), X11 + PANDN X11, X5 + + MOVOU (16*6)(y_ptr), X6 + MOVOU (16*7)(y_ptr), X7 + MOVOU (16*8)(y_ptr), X8 + MOVOU (16*9)(y_ptr), X9 + MOVOU (16*10)(y_ptr), X10 + MOVOU (16*11)(y_ptr), X11 + + PAND X12, X6 + PAND X12, X7 + PAND X12, X8 + PAND X12, X9 + PAND X12, X10 + PAND X12, X11 + + PXOR X6, X0 + PXOR X7, X1 + PXOR X8, X2 + PXOR X9, X3 + PXOR X10, X4 + PXOR X11, X5 + + MOVOU X0, (16*6)(res_ptr) + MOVOU X1, (16*7)(res_ptr) + MOVOU X2, (16*8)(res_ptr) + MOVOU X3, (16*9)(res_ptr) + MOVOU X4, (16*10)(res_ptr) + MOVOU X5, (16*11)(res_ptr) + + MOVOU X12, X0 + MOVOU (16*12)(x_ptr), X6 + PANDN X6, X0 + + MOVOU X12, X1 + MOVOU (16*13)(x_ptr), X7 + PANDN X7, X1 + + MOVOU X12, X2 + MOVOU (16*14)(x_ptr), X8 + PANDN X8, X2 + + MOVOU X12, X3 + MOVOU (16*15)(x_ptr), X9 + PANDN X9, X3 + + MOVOU (16*12)(y_ptr), X6 + MOVOU (16*13)(y_ptr), X7 + MOVOU (16*14)(y_ptr), X8 + MOVOU (16*15)(y_ptr), X9 + + PAND X12, X6 + PAND X12, X7 + PAND X12, X8 + PAND X12, X9 + + PXOR X6, X0 + PXOR X7, X1 + PXOR X8, X2 + PXOR X9, X3 + + MOVOU X0, (16*12)(res_ptr) + MOVOU X1, (16*13)(res_ptr) + MOVOU X2, (16*14)(res_ptr) + MOVOU X3, (16*15)(res_ptr) + + RET + +move_avx2: + VPXOR Y13, Y13, Y13 + VPBROADCASTD X12, Y12 + VPCMPEQD Y13, Y12, Y12 + + VPANDN (32*0)(x_ptr), Y12, Y0 + VPANDN (32*1)(x_ptr), Y12, Y1 + VPANDN (32*2)(x_ptr), Y12, Y2 + VPANDN (32*3)(x_ptr), Y12, Y3 + VPANDN (32*4)(x_ptr), Y12, Y4 + VPANDN (32*5)(x_ptr), Y12, Y5 + + VPAND (32*0)(y_ptr), Y12, Y6 + VPAND (32*1)(y_ptr), Y12, Y7 + VPAND (32*2)(y_ptr), Y12, Y8 + VPAND (32*3)(y_ptr), Y12, Y9 + VPAND (32*4)(y_ptr), Y12, Y10 + VPAND (32*5)(y_ptr), Y12, Y11 + + VPXOR Y6, Y0, Y0 + VPXOR Y7, Y1, Y1 + VPXOR Y8, Y2, Y2 + VPXOR Y9, Y3, Y3 + VPXOR Y10, Y4, Y4 + VPXOR Y11, Y5, Y5 + + VMOVDQU Y0, (32*0)(res_ptr) + VMOVDQU Y1, (32*1)(res_ptr) + VMOVDQU Y2, (32*2)(res_ptr) + VMOVDQU Y3, (32*3)(res_ptr) + VMOVDQU Y4, (32*4)(res_ptr) + VMOVDQU Y5, (32*5)(res_ptr) + + VPANDN (32*6)(x_ptr), Y12, Y0 + VPANDN (32*7)(x_ptr), Y12, Y1 + + VPAND (32*6)(y_ptr), Y12, Y6 + VPAND (32*7)(y_ptr), Y12, Y7 + + VPXOR Y6, Y0, Y0 + VPXOR Y7, Y1, Y1 + + VMOVDQU Y0, (32*6)(res_ptr) + VMOVDQU Y1, (32*7)(res_ptr) + + VZEROUPPER + RET diff --git a/sm9/bn256/select_arm64.s b/sm9/bn256/select_arm64.s index 2f88300..519a382 100644 --- a/sm9/bn256/select_arm64.s +++ b/sm9/bn256/select_arm64.s @@ -149,3 +149,162 @@ TEXT ·gfP12MovCond(SB),NOSPLIT,$0 STP (R8, R9), 23*16(res_ptr) RET + +/* ---------------------------------------*/ +// func curvePointMovCond(res, a, b *curvePoint, cond int) +// If cond == 0 res=b, else res=a +TEXT ·curvePointMovCond(SB),NOSPLIT,$0 + MOVD res+0(FP), res_ptr + MOVD a+8(FP), a_ptr + MOVD b+16(FP), b_ptr + MOVD cond+24(FP), R3 + + CMP $0, R3 + // Two remarks: + // 1) Will want to revisit NEON, when support is better + // 2) CSEL might not be constant time on all ARM processors + LDP 0*16(a_ptr), (R4, R5) + LDP 1*16(a_ptr), (R6, R7) + LDP 2*16(a_ptr), (R8, R9) + LDP 0*16(b_ptr), (R16, R17) + LDP 1*16(b_ptr), (R19, R20) + LDP 2*16(b_ptr), (R21, R22) + CSEL EQ, R16, R4, R4 + CSEL EQ, R17, R5, R5 + CSEL EQ, R19, R6, R6 + CSEL EQ, R20, R7, R7 + CSEL EQ, R21, R8, R8 + CSEL EQ, R22, R9, R9 + STP (R4, R5), 0*16(res_ptr) + STP (R6, R7), 1*16(res_ptr) + STP (R8, R9), 2*16(res_ptr) + + LDP 3*16(a_ptr), (R4, R5) + LDP 4*16(a_ptr), (R6, R7) + LDP 5*16(a_ptr), (R8, R9) + LDP 3*16(b_ptr), (R16, R17) + LDP 4*16(b_ptr), (R19, R20) + LDP 5*16(b_ptr), (R21, R22) + CSEL EQ, R16, R4, R4 + CSEL EQ, R17, R5, R5 + CSEL EQ, R19, R6, R6 + CSEL EQ, R20, R7, R7 + CSEL EQ, R21, R8, R8 + CSEL EQ, R22, R9, R9 + STP (R4, R5), 3*16(res_ptr) + STP (R6, R7), 4*16(res_ptr) + STP (R8, R9), 5*16(res_ptr) + + LDP 6*16(a_ptr), (R4, R5) + LDP 7*16(a_ptr), (R6, R7) + LDP 6*16(b_ptr), (R16, R17) + LDP 7*16(b_ptr), (R19, R20) + CSEL EQ, R16, R4, R4 + CSEL EQ, R17, R5, R5 + CSEL EQ, R19, R6, R6 + CSEL EQ, R20, R7, R7 + STP (R4, R5), 6*16(res_ptr) + STP (R6, R7), 7*16(res_ptr) + + RET + +/* ---------------------------------------*/ +// func twistPointMovCond(res, a, b *twistPoint, cond int) +// If cond == 0 res=b, else res=a +TEXT ·twistPointMovCond(SB),NOSPLIT,$0 + MOVD res+0(FP), res_ptr + MOVD a+8(FP), a_ptr + MOVD b+16(FP), b_ptr + MOVD cond+24(FP), R3 + + CMP $0, R3 + // Two remarks: + // 1) Will want to revisit NEON, when support is better + // 2) CSEL might not be constant time on all ARM processors + LDP 0*16(a_ptr), (R4, R5) + LDP 1*16(a_ptr), (R6, R7) + LDP 2*16(a_ptr), (R8, R9) + LDP 0*16(b_ptr), (R16, R17) + LDP 1*16(b_ptr), (R19, R20) + LDP 2*16(b_ptr), (R21, R22) + CSEL EQ, R16, R4, R4 + CSEL EQ, R17, R5, R5 + CSEL EQ, R19, R6, R6 + CSEL EQ, R20, R7, R7 + CSEL EQ, R21, R8, R8 + CSEL EQ, R22, R9, R9 + STP (R4, R5), 0*16(res_ptr) + STP (R6, R7), 1*16(res_ptr) + STP (R8, R9), 2*16(res_ptr) + + LDP 3*16(a_ptr), (R4, R5) + LDP 4*16(a_ptr), (R6, R7) + LDP 5*16(a_ptr), (R8, R9) + LDP 3*16(b_ptr), (R16, R17) + LDP 4*16(b_ptr), (R19, R20) + LDP 5*16(b_ptr), (R21, R22) + CSEL EQ, R16, R4, R4 + CSEL EQ, R17, R5, R5 + CSEL EQ, R19, R6, R6 + CSEL EQ, R20, R7, R7 + CSEL EQ, R21, R8, R8 + CSEL EQ, R22, R9, R9 + STP (R4, R5), 3*16(res_ptr) + STP (R6, R7), 4*16(res_ptr) + STP (R8, R9), 5*16(res_ptr) + + LDP 6*16(a_ptr), (R4, R5) + LDP 7*16(a_ptr), (R6, R7) + LDP 8*16(a_ptr), (R8, R9) + LDP 6*16(b_ptr), (R16, R17) + LDP 7*16(b_ptr), (R19, R20) + LDP 8*16(b_ptr), (R21, R22) + CSEL EQ, R16, R4, R4 + CSEL EQ, R17, R5, R5 + CSEL EQ, R19, R6, R6 + CSEL EQ, R20, R7, R7 + CSEL EQ, R21, R8, R8 + CSEL EQ, R22, R9, R9 + STP (R4, R5), 6*16(res_ptr) + STP (R6, R7), 7*16(res_ptr) + STP (R8, R9), 8*16(res_ptr) + + LDP 9*16(a_ptr), (R4, R5) + LDP 10*16(a_ptr), (R6, R7) + LDP 11*16(a_ptr), (R8, R9) + LDP 9*16(b_ptr), (R16, R17) + LDP 10*16(b_ptr), (R19, R20) + LDP 11*16(b_ptr), (R21, R22) + CSEL EQ, R16, R4, R4 + CSEL EQ, R17, R5, R5 + CSEL EQ, R19, R6, R6 + CSEL EQ, R20, R7, R7 + CSEL EQ, R21, R8, R8 + CSEL EQ, R22, R9, R9 + STP (R4, R5), 9*16(res_ptr) + STP (R6, R7), 10*16(res_ptr) + STP (R8, R9), 11*16(res_ptr) + + LDP 12*16(a_ptr), (R4, R5) + LDP 13*16(a_ptr), (R6, R7) + LDP 14*16(a_ptr), (R8, R9) + LDP 12*16(b_ptr), (R16, R17) + LDP 13*16(b_ptr), (R19, R20) + LDP 14*16(b_ptr), (R21, R22) + CSEL EQ, R16, R4, R4 + CSEL EQ, R17, R5, R5 + CSEL EQ, R19, R6, R6 + CSEL EQ, R20, R7, R7 + CSEL EQ, R21, R8, R8 + CSEL EQ, R22, R9, R9 + STP (R4, R5), 12*16(res_ptr) + STP (R6, R7), 13*16(res_ptr) + STP (R8, R9), 14*16(res_ptr) + + LDP 15*16(a_ptr), (R4, R5) + LDP 15*16(b_ptr), (R16, R17) + CSEL EQ, R16, R4, R4 + CSEL EQ, R17, R5, R5 + STP (R4, R5), 15*16(res_ptr) + + RET diff --git a/sm9/bn256/select_decl.go b/sm9/bn256/select_decl.go index 444272d..0d5c0da 100644 --- a/sm9/bn256/select_decl.go +++ b/sm9/bn256/select_decl.go @@ -11,3 +11,13 @@ var supportAVX2 = cpu.X86.HasAVX2 // //go:noescape func gfP12MovCond(res, a, b *gfP12, cond int) + +// If cond is 0, sets res = b, otherwise sets res = a. +// +//go:noescape +func curvePointMovCond(res, a, b *curvePoint, cond int) + +// If cond is 0, sets res = b, otherwise sets res = a. +// +//go:noescape +func twistPointMovCond(res, a, b *twistPoint, cond int) diff --git a/sm9/bn256/select_generic.go b/sm9/bn256/select_generic.go index c8d9aa2..7500455 100644 --- a/sm9/bn256/select_generic.go +++ b/sm9/bn256/select_generic.go @@ -6,3 +6,11 @@ package bn256 func gfP12MovCond(res, a, b *gfP12, cond int) { res.Select(a, b, cond) } + +func curvePointMovCond(res, a, b *curvePoint, cond int) { + res.Select(a, b, cond) +} + +func twistPointMovCond(res, a, b *twistPoint, cond int) { + res.Select(a, b, cond) +} diff --git a/sm9/bn256/twist.go b/sm9/bn256/twist.go index b5cb9f7..222db35 100644 --- a/sm9/bn256/twist.go +++ b/sm9/bn256/twist.go @@ -264,7 +264,7 @@ func (table *twistPointTable) Select(p *twistPoint, n uint8) { p.SetInfinity() for i, f := range table { cond := subtle.ConstantTimeByteEq(uint8(i+1), n) - p.Select(f, p, cond) + twistPointMovCond(p, f, p, cond) } }