From df8cb4d95d3e4205163fc22c4759055c978f1782 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Wed, 7 Jun 2023 09:43:20 +0800 Subject: [PATCH] bigmod: optimization for 256 bits --- internal/bigmod/nat.go | 13 ++++++ internal/bigmod/nat_386.s | 5 ++ internal/bigmod/nat_amd64.s | 88 ++++++++++++++++++++++++++++++++++++ internal/bigmod/nat_arm.s | 5 ++ internal/bigmod/nat_arm64.s | 5 ++ internal/bigmod/nat_asm.go | 3 ++ internal/bigmod/nat_noasm.go | 8 ++++ internal/bigmod/nat_ppc64x.s | 5 ++ internal/bigmod/nat_s390x.s | 5 ++ internal/bigmod/nat_test.go | 57 +++++++++++++---------- 10 files changed, 171 insertions(+), 23 deletions(-) diff --git a/internal/bigmod/nat.go b/internal/bigmod/nat.go index c277e13..8fe0b2a 100644 --- a/internal/bigmod/nat.go +++ b/internal/bigmod/nat.go @@ -636,6 +636,19 @@ func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat { // optimized for the sizes most used in RSA. addMulVVW is implemented in // assembly with loop unrolling depending on the architecture and bounds // checks are removed by the compiler thanks to the constant size. 
+ case 256 / _W: // optimization for 256 bits nat + const n = 256 / _W // compiler hint + T := make([]uint, n*2) + var c uint + for i := 0; i < n; i++ { + d := bLimbs[i] + c1 := addMulVVW256(&T[i], &aLimbs[0], d) + Y := T[i] * m.m0inv + c2 := addMulVVW256(&T[i], &mLimbs[0], Y) + T[n+i], c = bits.Add(c1, c2, c) + } + copy(x.reset(n).limbs, T[n:]) + x.maybeSubtractModulus(choice(c), m) case 1024 / _W: const n = 1024 / _W // compiler hint T := make([]uint, n*2) diff --git a/internal/bigmod/nat_386.s b/internal/bigmod/nat_386.s index e094c0e..626878b 100644 --- a/internal/bigmod/nat_386.s +++ b/internal/bigmod/nat_386.s @@ -7,6 +7,11 @@ #include "textflag.h" +// func addMulVVW256(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW256(SB), $0-16 + MOVL $8, BX + JMP addMulVVWx(SB) + // func addMulVVW1024(z, x *uint, y uint) (c uint) TEXT ·addMulVVW1024(SB), $0-16 MOVL $32, BX diff --git a/internal/bigmod/nat_amd64.s b/internal/bigmod/nat_amd64.s index 16cb4a6..afef2be 100644 --- a/internal/bigmod/nat_amd64.s +++ b/internal/bigmod/nat_amd64.s @@ -3,6 +3,95 @@ //go:build !purego // +build !purego +// func addMulVVW256(z *uint, x *uint, y uint) (c uint) +// Requires: ADX, BMI2 +TEXT ·addMulVVW256(SB), $0-32 + CMPB ·supportADX+0(SB), $0x01 + JEQ adx + MOVQ z+0(FP), CX + MOVQ x+8(FP), BX + MOVQ y+16(FP), SI + XORQ DI, DI + + // Iteration 0 + MOVQ (BX), AX + MULQ SI + ADDQ (CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, (CX) + + // Iteration 1 + MOVQ 8(BX), AX + MULQ SI + ADDQ 8(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 8(CX) + + // Iteration 2 + MOVQ 16(BX), AX + MULQ SI + ADDQ 16(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 16(CX) + + // Iteration 3 + MOVQ 24(BX), AX + MULQ SI + ADDQ 24(CX), AX + ADCQ $0x00, DX + ADDQ DI, AX + ADCQ $0x00, DX + MOVQ DX, DI + MOVQ AX, 24(CX) + MOVQ DI, c+24(FP) + RET + +adx: + MOVQ z+0(FP), AX + MOVQ x+8(FP), CX + MOVQ y+16(FP), DX + XORQ BX, BX + XORQ SI, SI + 
+ // Iteration 0 + MULXQ (CX), R8, DI + ADCXQ BX, R8 + ADOXQ (AX), R8 + MOVQ R8, (AX) + + // Iteration 1 + MULXQ 8(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 8(AX), R8 + MOVQ R8, 8(AX) + + // Iteration 2 + MULXQ 16(CX), R8, DI + ADCXQ BX, R8 + ADOXQ 16(AX), R8 + MOVQ R8, 16(AX) + + // Iteration 3 + MULXQ 24(CX), R8, BX + ADCXQ DI, R8 + ADOXQ 24(AX), R8 + MOVQ R8, 24(AX) + + // Add back carry flags and return + ADCXQ SI, BX + ADOXQ SI, BX + MOVQ BX, c+24(FP) + RET + // func addMulVVW1024(z *uint, x *uint, y uint) (c uint) // Requires: ADX, BMI2 TEXT ·addMulVVW1024(SB), $0-32 diff --git a/internal/bigmod/nat_arm.s b/internal/bigmod/nat_arm.s index c1edac7..6570b3a 100644 --- a/internal/bigmod/nat_arm.s +++ b/internal/bigmod/nat_arm.s @@ -7,6 +7,11 @@ #include "textflag.h" +// func addMulVVW256(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW256(SB), $0-16 + MOVW $8, R5 + JMP addMulVVWx(SB) + // func addMulVVW1024(z, x *uint, y uint) (c uint) TEXT ·addMulVVW1024(SB), $0-16 MOVW $32, R5 diff --git a/internal/bigmod/nat_arm64.s b/internal/bigmod/nat_arm64.s index 98e691a..d4983d0 100644 --- a/internal/bigmod/nat_arm64.s +++ b/internal/bigmod/nat_arm64.s @@ -7,6 +7,11 @@ #include "textflag.h" +// func addMulVVW256(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW256(SB), $0-32 + MOVD $4, R0 + JMP addMulVVWx(SB) + // func addMulVVW1024(z, x *uint, y uint) (c uint) TEXT ·addMulVVW1024(SB), $0-32 MOVD $16, R0 diff --git a/internal/bigmod/nat_asm.go b/internal/bigmod/nat_asm.go index 26e2abc..2ba9c72 100644 --- a/internal/bigmod/nat_asm.go +++ b/internal/bigmod/nat_asm.go @@ -20,6 +20,9 @@ import "golang.org/x/sys/cpu" var supportADX = cpu.X86.HasADX && cpu.X86.HasBMI2 +//go:noescape +func addMulVVW256(z, x *uint, y uint) (c uint) + //go:noescape func addMulVVW1024(z, x *uint, y uint) (c uint) diff --git a/internal/bigmod/nat_noasm.go b/internal/bigmod/nat_noasm.go index 00a7625..429700b 100644 --- a/internal/bigmod/nat_noasm.go +++ b/internal/bigmod/nat_noasm.go @@ -10,6 +10,10 @@ package 
bigmod import "unsafe" // TODO: will use unsafe.Slice directly once upgrade golang sdk to 1.17+ +func slice256(ptr *uint) []uint { + return (*[256 / _W]uint)(unsafe.Pointer(ptr))[:] +} + func slice1024(ptr *uint) []uint { return (*[1024 / _W]uint)(unsafe.Pointer(ptr))[:] } @@ -22,6 +26,10 @@ func slice2048(ptr *uint) []uint { return (*[2048 / _W]uint)(unsafe.Pointer(ptr))[:] } +func addMulVVW256(z, x *uint, y uint) (c uint) { + return addMulVVW(slice256(z), slice256(x), y) +} + func addMulVVW1024(z, x *uint, y uint) (c uint) { return addMulVVW(slice1024(z), slice1024(x), y) } diff --git a/internal/bigmod/nat_ppc64x.s b/internal/bigmod/nat_ppc64x.s index d3ae981..735639c 100644 --- a/internal/bigmod/nat_ppc64x.s +++ b/internal/bigmod/nat_ppc64x.s @@ -8,6 +8,11 @@ #include "textflag.h" +// func addMulVVW256(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW256(SB), $0-32 + MOVD $4, R22 // R22 = z_len + JMP addMulVVWx(SB) + // func addMulVVW1024(z, x *uint, y uint) (c uint) TEXT ·addMulVVW1024(SB), $0-32 MOVD $16, R22 // R22 = z_len diff --git a/internal/bigmod/nat_s390x.s b/internal/bigmod/nat_s390x.s index 7d259c8..02389ac 100644 --- a/internal/bigmod/nat_s390x.s +++ b/internal/bigmod/nat_s390x.s @@ -7,6 +7,11 @@ #include "textflag.h" +// func addMulVVW256(z, x *uint, y uint) (c uint) +TEXT ·addMulVVW256(SB), $0-32 + MOVD $4, R5 + JMP addMulVVWx(SB) + // func addMulVVW1024(z, x *uint, y uint) (c uint) TEXT ·addMulVVW1024(SB), $0-32 MOVD $16, R5 diff --git a/internal/bigmod/nat_test.go b/internal/bigmod/nat_test.go index 9b5a9d1..5d9474d 100644 --- a/internal/bigmod/nat_test.go +++ b/internal/bigmod/nat_test.go @@ -363,13 +363,13 @@ func maxModulus(n uint) *Modulus { return m } -func makeBenchmarkModulus() *Modulus { - return maxModulus(32) +func makeBenchmarkModulus(n uint) *Modulus { + return maxModulus(n) } -func makeBenchmarkValue() *Nat { - x := make([]uint, 32) - for i := 0; i < 32; i++ { +func makeBenchmarkValue(n int) *Nat { + x := make([]uint, n) + for i := 0; i 
< n; i++ { x[i]-- } return &Nat{limbs: x} @@ -384,9 +384,9 @@ func makeBenchmarkExponent() []byte { } func BenchmarkModAdd(b *testing.B) { - x := makeBenchmarkValue() - y := makeBenchmarkValue() - m := makeBenchmarkModulus() + x := makeBenchmarkValue(32) + y := makeBenchmarkValue(32) + m := makeBenchmarkModulus(32) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -395,9 +395,9 @@ func BenchmarkModAdd(b *testing.B) { } func BenchmarkModSub(b *testing.B) { - x := makeBenchmarkValue() - y := makeBenchmarkValue() - m := makeBenchmarkModulus() + x := makeBenchmarkValue(32) + y := makeBenchmarkValue(32) + m := makeBenchmarkModulus(32) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -406,8 +406,8 @@ func BenchmarkModSub(b *testing.B) { } func BenchmarkMontgomeryRepr(b *testing.B) { - x := makeBenchmarkValue() - m := makeBenchmarkModulus() + x := makeBenchmarkValue(32) + m := makeBenchmarkModulus(32) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -416,10 +416,10 @@ func BenchmarkMontgomeryRepr(b *testing.B) { } func BenchmarkMontgomeryMul(b *testing.B) { - x := makeBenchmarkValue() - y := makeBenchmarkValue() - out := makeBenchmarkValue() - m := makeBenchmarkModulus() + x := makeBenchmarkValue(32) + y := makeBenchmarkValue(32) + out := makeBenchmarkValue(32) + m := makeBenchmarkModulus(32) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -428,9 +428,20 @@ func BenchmarkMontgomeryMul(b *testing.B) { } func BenchmarkModMul(b *testing.B) { - x := makeBenchmarkValue() - y := makeBenchmarkValue() - m := makeBenchmarkModulus() + x := makeBenchmarkValue(32) + y := makeBenchmarkValue(32) + m := makeBenchmarkModulus(32) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + x.Mul(y, m) + } +} + +func BenchmarkModMul256(b *testing.B) { + x := makeBenchmarkValue(4) + y := makeBenchmarkValue(4) + m := makeBenchmarkModulus(4) b.ResetTimer() for i := 0; i < b.N; i++ { @@ -454,10 +465,10 @@ func BenchmarkExpBig(b *testing.B) { } func BenchmarkExp(b *testing.B) { - x := makeBenchmarkValue() + x := 
makeBenchmarkValue(32) e := makeBenchmarkExponent() - out := makeBenchmarkValue() - m := makeBenchmarkModulus() + out := makeBenchmarkValue(32) + m := makeBenchmarkModulus(32) b.ResetTimer() for i := 0; i < b.N; i++ {