bigmod: optimization for 256 bits

This commit is contained in:
Sun Yimin 2023-06-07 09:43:20 +08:00 committed by GitHub
parent 207fd1e7a4
commit df8cb4d95d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 171 additions and 23 deletions

View File

@ -636,6 +636,19 @@ func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
// optimized for the sizes most used in RSA. addMulVVW is implemented in // optimized for the sizes most used in RSA. addMulVVW is implemented in
// assembly with loop unrolling depending on the architecture and bounds // assembly with loop unrolling depending on the architecture and bounds
// checks are removed by the compiler thanks to the constant size. // checks are removed by the compiler thanks to the constant size.
case 256 / _W: // optimization for 256 bits nat
const n = 256 / _W // compiler hint
T := make([]uint, n*2)
var c uint
for i := 0; i < n; i++ {
d := bLimbs[i]
c1 := addMulVVW256(&T[i], &aLimbs[0], d)
Y := T[i] * m.m0inv
c2 := addMulVVW256(&T[i], &mLimbs[0], Y)
T[n+i], c = bits.Add(c1, c2, c)
}
copy(x.reset(n).limbs, T[n:])
x.maybeSubtractModulus(choice(c), m)
case 1024 / _W: case 1024 / _W:
const n = 1024 / _W // compiler hint const n = 1024 / _W // compiler hint
T := make([]uint, n*2) T := make([]uint, n*2)

View File

@ -7,6 +7,11 @@
#include "textflag.h" #include "textflag.h"
// func addMulVVW256(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW256(SB), $0-16
MOVL $8, BX
JMP addMulVVWx(SB)
// func addMulVVW1024(z, x *uint, y uint) (c uint) // func addMulVVW1024(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1024(SB), $0-16 TEXT ·addMulVVW1024(SB), $0-16
MOVL $32, BX MOVL $32, BX

View File

@ -3,6 +3,94 @@
//go:build !purego //go:build !purego
// +build !purego // +build !purego
// func addMulVVW256(z *uint, x *uint, y uint) (c uint)
// Requires: ADX, BMI2
TEXT ·addMulVVW256(SB), $0-32
CMPB ·supportADX+0(SB), $0x01
JEQ adx
MOVQ z+0(FP), CX
MOVQ x+8(FP), BX
MOVQ y+16(FP), SI
XORQ DI, DI
// Iteration 0
MOVQ (BX), AX
MULQ SI
ADDQ (CX), AX
ADCQ $0x00, DX
ADDQ DI, AX
ADCQ $0x00, DX
MOVQ DX, DI
MOVQ AX, (CX)
// Iteration 1
MOVQ 8(BX), AX
MULQ SI
ADDQ 8(CX), AX
ADCQ $0x00, DX
ADDQ DI, AX
ADCQ $0x00, DX
MOVQ DX, DI
MOVQ AX, 8(CX)
// Iteration 2
MOVQ 16(BX), AX
MULQ SI
ADDQ 16(CX), AX
ADCQ $0x00, DX
ADDQ DI, AX
ADCQ $0x00, DX
MOVQ DX, DI
MOVQ AX, 16(CX)
// Iteration 3
MOVQ 24(BX), AX
MULQ SI
ADDQ 24(CX), AX
ADCQ $0x00, DX
ADDQ DI, AX
ADCQ $0x00, DX
MOVQ DX, DI
MOVQ AX, 24(CX)
RET
adx:
MOVQ z+0(FP), AX
MOVQ x+8(FP), CX
MOVQ y+16(FP), DX
XORQ BX, BX
XORQ SI, SI
// Iteration 0
MULXQ (CX), R8, DI
ADCXQ BX, R8
ADOXQ (AX), R8
MOVQ R8, (AX)
// Iteration 1
MULXQ 8(CX), R8, BX
ADCXQ DI, R8
ADOXQ 8(AX), R8
MOVQ R8, 8(AX)
// Iteration 2
MULXQ 16(CX), R8, DI
ADCXQ BX, R8
ADOXQ 16(AX), R8
MOVQ R8, 16(AX)
// Iteration 3
MULXQ 24(CX), R8, BX
ADCXQ DI, R8
ADOXQ 24(AX), R8
MOVQ R8, 24(AX)
// Add back carry flags and return
ADCXQ SI, BX
ADOXQ SI, BX
MOVQ BX, c+24(FP)
RET
// func addMulVVW1024(z *uint, x *uint, y uint) (c uint) // func addMulVVW1024(z *uint, x *uint, y uint) (c uint)
// Requires: ADX, BMI2 // Requires: ADX, BMI2
TEXT ·addMulVVW1024(SB), $0-32 TEXT ·addMulVVW1024(SB), $0-32

View File

@ -7,6 +7,11 @@
#include "textflag.h" #include "textflag.h"
// func addMulVVW256(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW256(SB), $0-16
MOVW $8, R5
JMP addMulVVWx(SB)
// func addMulVVW1024(z, x *uint, y uint) (c uint) // func addMulVVW1024(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1024(SB), $0-16 TEXT ·addMulVVW1024(SB), $0-16
MOVW $32, R5 MOVW $32, R5

View File

@ -7,6 +7,11 @@
#include "textflag.h" #include "textflag.h"
// func addMulVVW256(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW256(SB), $0-32
MOVD $4, R0
JMP addMulVVWx(SB)
// func addMulVVW1024(z, x *uint, y uint) (c uint) // func addMulVVW1024(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1024(SB), $0-32 TEXT ·addMulVVW1024(SB), $0-32
MOVD $16, R0 MOVD $16, R0

View File

@ -20,6 +20,9 @@ import "golang.org/x/sys/cpu"
var supportADX = cpu.X86.HasADX && cpu.X86.HasBMI2 var supportADX = cpu.X86.HasADX && cpu.X86.HasBMI2
//go:noescape
func addMulVVW256(z, x *uint, y uint) (c uint)
//go:noescape //go:noescape
func addMulVVW1024(z, x *uint, y uint) (c uint) func addMulVVW1024(z, x *uint, y uint) (c uint)

View File

@ -10,6 +10,10 @@ package bigmod
import "unsafe" import "unsafe"
// TODO: will use unsafe.Slice directly once upgrade golang sdk to 1.17+ // TODO: will use unsafe.Slice directly once upgrade golang sdk to 1.17+
func slice256(ptr *uint) []uint {
return (*[256 / _W]uint)(unsafe.Pointer(ptr))[:]
}
func slice1024(ptr *uint) []uint { func slice1024(ptr *uint) []uint {
return (*[1024 / _W]uint)(unsafe.Pointer(ptr))[:] return (*[1024 / _W]uint)(unsafe.Pointer(ptr))[:]
} }
@ -22,6 +26,10 @@ func slice2048(ptr *uint) []uint {
return (*[2048 / _W]uint)(unsafe.Pointer(ptr))[:] return (*[2048 / _W]uint)(unsafe.Pointer(ptr))[:]
} }
func addMulVVW256(z, x *uint, y uint) (c uint) {
return addMulVVW(slice256(z), slice256(x), y)
}
func addMulVVW1024(z, x *uint, y uint) (c uint) { func addMulVVW1024(z, x *uint, y uint) (c uint) {
return addMulVVW(slice1024(z), slice1024(x), y) return addMulVVW(slice1024(z), slice1024(x), y)
} }

View File

@ -8,6 +8,11 @@
#include "textflag.h" #include "textflag.h"
// func addMulVVW256(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW256(SB), $0-32
MOVD $4, R22 // R22 = z_len
JMP addMulVVWx(SB)
// func addMulVVW1024(z, x *uint, y uint) (c uint) // func addMulVVW1024(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1024(SB), $0-32 TEXT ·addMulVVW1024(SB), $0-32
MOVD $16, R22 // R22 = z_len MOVD $16, R22 // R22 = z_len

View File

@ -7,6 +7,11 @@
#include "textflag.h" #include "textflag.h"
// func addMulVVW256(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW256(SB), $0-32
MOVD $4, R5
JMP addMulVVWx(SB)
// func addMulVVW1024(z, x *uint, y uint) (c uint) // func addMulVVW1024(z, x *uint, y uint) (c uint)
TEXT ·addMulVVW1024(SB), $0-32 TEXT ·addMulVVW1024(SB), $0-32
MOVD $16, R5 MOVD $16, R5

View File

@ -363,13 +363,13 @@ func maxModulus(n uint) *Modulus {
return m return m
} }
func makeBenchmarkModulus() *Modulus { func makeBenchmarkModulus(n uint) *Modulus {
return maxModulus(32) return maxModulus(n)
} }
func makeBenchmarkValue() *Nat { func makeBenchmarkValue(n int) *Nat {
x := make([]uint, 32) x := make([]uint, n)
for i := 0; i < 32; i++ { for i := 0; i < n; i++ {
x[i]-- x[i]--
} }
return &Nat{limbs: x} return &Nat{limbs: x}
@ -384,9 +384,9 @@ func makeBenchmarkExponent() []byte {
} }
func BenchmarkModAdd(b *testing.B) { func BenchmarkModAdd(b *testing.B) {
x := makeBenchmarkValue() x := makeBenchmarkValue(32)
y := makeBenchmarkValue() y := makeBenchmarkValue(32)
m := makeBenchmarkModulus() m := makeBenchmarkModulus(32)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -395,9 +395,9 @@ func BenchmarkModAdd(b *testing.B) {
} }
func BenchmarkModSub(b *testing.B) { func BenchmarkModSub(b *testing.B) {
x := makeBenchmarkValue() x := makeBenchmarkValue(32)
y := makeBenchmarkValue() y := makeBenchmarkValue(32)
m := makeBenchmarkModulus() m := makeBenchmarkModulus(32)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -406,8 +406,8 @@ func BenchmarkModSub(b *testing.B) {
} }
func BenchmarkMontgomeryRepr(b *testing.B) { func BenchmarkMontgomeryRepr(b *testing.B) {
x := makeBenchmarkValue() x := makeBenchmarkValue(32)
m := makeBenchmarkModulus() m := makeBenchmarkModulus(32)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -416,10 +416,10 @@ func BenchmarkMontgomeryRepr(b *testing.B) {
} }
func BenchmarkMontgomeryMul(b *testing.B) { func BenchmarkMontgomeryMul(b *testing.B) {
x := makeBenchmarkValue() x := makeBenchmarkValue(32)
y := makeBenchmarkValue() y := makeBenchmarkValue(32)
out := makeBenchmarkValue() out := makeBenchmarkValue(32)
m := makeBenchmarkModulus() m := makeBenchmarkModulus(32)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -428,9 +428,20 @@ func BenchmarkMontgomeryMul(b *testing.B) {
} }
func BenchmarkModMul(b *testing.B) { func BenchmarkModMul(b *testing.B) {
x := makeBenchmarkValue() x := makeBenchmarkValue(32)
y := makeBenchmarkValue() y := makeBenchmarkValue(32)
m := makeBenchmarkModulus() m := makeBenchmarkModulus(32)
b.ResetTimer()
for i := 0; i < b.N; i++ {
x.Mul(y, m)
}
}
func BenchmarkModMul256(b *testing.B) {
x := makeBenchmarkValue(4)
y := makeBenchmarkValue(4)
m := makeBenchmarkModulus(4)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
@ -454,10 +465,10 @@ func BenchmarkExpBig(b *testing.B) {
} }
func BenchmarkExp(b *testing.B) { func BenchmarkExp(b *testing.B) {
x := makeBenchmarkValue() x := makeBenchmarkValue(32)
e := makeBenchmarkExponent() e := makeBenchmarkExponent()
out := makeBenchmarkValue() out := makeBenchmarkValue(32)
m := makeBenchmarkModulus() m := makeBenchmarkModulus(32)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {