sm2ec: optimize amd64 with MULX & AVX2

Sun Yimin 2023-06-10 10:55:17 +08:00 committed by GitHub
parent df8cb4d95d
commit a0c4a389b8
6 changed files with 2003 additions and 264 deletions


@@ -649,6 +649,7 @@ func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
}
copy(x.reset(n).limbs, T[n:])
x.maybeSubtractModulus(choice(c), m)
case 1024 / _W:
const n = 1024 / _W // compiler hint
T := make([]uint, n*2)
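The "const n = 1024 / _W" in the new case follows the pattern of the existing fixed-size cases: a compile-time-constant limb count lets the compiler unroll the inner loops and hoist the bounds checks. A minimal sketch of that effect under the same assumption (the addLimbs helper below is illustrative, not part of this package):

package main

import (
	"fmt"
	"math/bits"
)

const _W = bits.UintSize // limb width, 64 on amd64/arm64

// addLimbs adds b into a over a fixed, compile-time-known number of limbs.
// Because n is a constant, the compiler can unroll the loop, and the single
// up-front index check lets it drop the per-iteration bounds checks, which is
// what the "compiler hint" comment in the diff is about.
func addLimbs(a, b []uint) (carry uint) {
	const n = 1024 / _W // same constant as the new montgomeryMul case
	_, _ = a[n-1], b[n-1] // one bounds check up front; the loop body needs none
	for i := 0; i < n; i++ {
		a[i], carry = bits.Add(a[i], b[i], carry)
	}
	return carry
}

func main() {
	a := make([]uint, 1024/_W)
	b := make([]uint, 1024/_W)
	a[0], b[0] = 3, 4
	fmt.Println(addLimbs(a, b), a[0]) // 0 7
}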

File diff suppressed because it is too large


@@ -1364,13 +1364,13 @@ TEXT ·p256PointDoubleAsm(SB),NOSPLIT,$136-16
LDP p256p<>+0x10(SB), (const2, const3)
// Begin point double
LDP 4*16(a_ptr), (x0, x1) // load z
LDP 5*16(a_ptr), (x2, x3)
CALL sm2P256SqrInternal<>(SB)
STP (y0, y1), zsqr(0*8) // store z^2
STP (y2, y3), zsqr(2*8)
LDP 0*16(a_ptr), (x0, x1) // load x
LDP 1*16(a_ptr), (x2, x3)
p256AddInline
STx(m)
@@ -1446,6 +1446,187 @@ TEXT ·p256PointDoubleAsm(SB),NOSPLIT,$136-16
CALL sm2P256Subinternal<>(SB)
STx(y3out)
RET
#define p256PointDoubleRound() \
LDx(z3out) \ // load z
CALL sm2P256SqrInternal<>(SB) \
STP (y0, y1), zsqr(0*8) \ // store z^2
STP (y2, y3), zsqr(2*8) \
\
LDx(x3out) \// load x
p256AddInline \
STx(m) \
\
LDx(z3out) \ // load z
LDy(y3out) \ // load y
CALL sm2P256MulInternal<>(SB) \
p256MulBy2Inline \
STx(z3out) \ // store result z
\
LDy(x3out) \ // load x
LDx(zsqr) \
CALL sm2P256Subinternal<>(SB) \
LDy(m) \
CALL sm2P256MulInternal<>(SB) \
\
\// Multiply by 3
p256MulBy2Inline \
p256AddInline \
STx(m) \
\
LDy(y3out) \ // load y
p256MulBy2Inline \
CALL sm2P256SqrInternal<>(SB) \
STy(s) \
MOVD y0, x0 \
MOVD y1, x1 \
MOVD y2, x2 \
MOVD y3, x3 \
CALL sm2P256SqrInternal<>(SB) \
\
\// Divide by 2
ADDS const0, y0, t0 \
ADCS const1, y1, t1 \
ADCS const2, y2, acc5 \
ADCS const3, y3, acc6 \
ADC $0, ZR, hlp0 \
\
ANDS $1, y0, ZR \
CSEL EQ, y0, t0, t0 \
CSEL EQ, y1, t1, t1 \
CSEL EQ, y2, acc5, acc5 \
CSEL EQ, y3, acc6, acc6 \
AND y0, hlp0, hlp0 \
\
EXTR $1, t0, t1, y0 \
EXTR $1, t1, acc5, y1 \
EXTR $1, acc5, acc6, y2 \
EXTR $1, acc6, hlp0, y3 \
STy(y3out) \
\
LDx(x3out) \ // load x
LDy(s) \
CALL sm2P256MulInternal<>(SB) \
STy(s) \
p256MulBy2Inline \
STx(tmp) \
\
LDx(m) \
CALL sm2P256SqrInternal<>(SB) \
LDx(tmp) \
CALL sm2P256Subinternal<>(SB) \
\
STx(x3out) \
\
LDy(s) \
CALL sm2P256Subinternal<>(SB) \
\
LDy(m) \
CALL sm2P256MulInternal<>(SB) \
\
LDx(y3out) \
CALL sm2P256Subinternal<>(SB) \
STx(y3out) \
//func p256PointDouble5TimesAsm(res, in *SM2P256Point)
TEXT ·p256PointDouble5TimesAsm(SB),NOSPLIT,$136-16
MOVD res+0(FP), res_ptr
MOVD in+8(FP), a_ptr
LDP p256p<>+0x00(SB), (const0, const1)
LDP p256p<>+0x10(SB), (const2, const3)
// Begin point double round 1
LDP 4*16(a_ptr), (x0, x1) // load z
LDP 5*16(a_ptr), (x2, x3)
CALL sm2P256SqrInternal<>(SB)
STP (y0, y1), zsqr(0*8) // store z^2
STP (y2, y3), zsqr(2*8)
LDP 0*16(a_ptr), (x0, x1) // load x
LDP 1*16(a_ptr), (x2, x3)
p256AddInline
STx(m)
LDx(z1in) // load z
LDy(y1in) // load y
CALL sm2P256MulInternal<>(SB)
p256MulBy2Inline
STx(z3out) // store result z
LDy(x1in) // load x
LDx(zsqr)
CALL sm2P256Subinternal<>(SB)
LDy(m)
CALL sm2P256MulInternal<>(SB)
// Multiply by 3
p256MulBy2Inline
p256AddInline
STx(m)
LDy(y1in) // load y
p256MulBy2Inline
CALL sm2P256SqrInternal<>(SB)
STy(s)
MOVD y0, x0
MOVD y1, x1
MOVD y2, x2
MOVD y3, x3
CALL sm2P256SqrInternal<>(SB)
// Divide by 2
ADDS const0, y0, t0
ADCS const1, y1, t1
ADCS const2, y2, acc5
ADCS const3, y3, acc6
ADC $0, ZR, hlp0
ANDS $1, y0, ZR
CSEL EQ, y0, t0, t0
CSEL EQ, y1, t1, t1
CSEL EQ, y2, acc5, acc5
CSEL EQ, y3, acc6, acc6
AND y0, hlp0, hlp0
EXTR $1, t0, t1, y0
EXTR $1, t1, acc5, y1
EXTR $1, acc5, acc6, y2
EXTR $1, acc6, hlp0, y3
STy(y3out)
LDx(x1in) // load x
LDy(s)
CALL sm2P256MulInternal<>(SB)
STy(s)
p256MulBy2Inline
STx(tmp)
LDx(m)
CALL sm2P256SqrInternal<>(SB)
LDx(tmp)
CALL sm2P256Subinternal<>(SB)
STx(x3out)
LDy(s)
CALL sm2P256Subinternal<>(SB)
LDy(m)
CALL sm2P256MulInternal<>(SB)
LDx(y3out)
CALL sm2P256Subinternal<>(SB)
STx(y3out)
// Begin point double rounds 2 - 5
p256PointDoubleRound()
p256PointDoubleRound()
p256PointDoubleRound()
p256PointDoubleRound()
RET
/* ---------------------------------------*/
#undef y2in
#undef x3out
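The "Divide by 2" block in p256PointDouble5TimesAsm (and in the p256PointDoubleRound macro) halves y modulo p in constant time: it conditionally adds p when y is odd, then shifts the 257-bit sum right by one bit. A hedged Go sketch of the same arithmetic on four 64-bit limbs (the helper name and layout are illustrative, not from this repo):

package main

import (
	"fmt"
	"math/bits"
)

// halveModP returns y/2 mod p for y < p, following the assembly's approach:
// conditionally add p when y is odd, then shift the 257-bit result right by 1.
func halveModP(y, p [4]uint64) [4]uint64 {
	odd := y[0] & 1 // the assembly tests this with ANDS $1, y0, ZR

	// t = y + p, with the carry kept as a fifth limb (hlp0 in the assembly).
	var t [4]uint64
	var hi uint64
	t[0], hi = bits.Add64(y[0], p[0], 0)
	t[1], hi = bits.Add64(y[1], p[1], hi)
	t[2], hi = bits.Add64(y[2], p[2], hi)
	t[3], hi = bits.Add64(y[3], p[3], hi)

	// Constant-time select: keep y when even, y+p when odd (the CSEL sequence).
	mask := -odd // all ones if odd, zero if even
	for i := range t {
		t[i] = (t[i] & mask) | (y[i] &^ mask)
	}
	hi &= mask

	// Shift the 257-bit value {hi, t3..t0} right by one bit (the EXTR sequence).
	return [4]uint64{
		t[0]>>1 | t[1]<<63,
		t[1]>>1 | t[2]<<63,
		t[2]>>1 | t[3]<<63,
		t[3]>>1 | hi<<63,
	}
}

func main() {
	// SM2 prime p as little-endian limbs (same value as the hex string in the tests).
	p := [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFF00000000, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFEFFFFFFFF}
	y := [4]uint64{3, 0, 0, 0} // odd, so the result is (3+p)/2 mod p
	fmt.Printf("%x\n", halveModP(y, p))
}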


@@ -162,6 +162,43 @@ func TestForSqrt(t *testing.T) {
exp.Div(exp, big.NewInt(4))
}
func TestEquivalents(t *testing.T) {
p := NewSM2P256Point().SetGenerator()
elementSize := 32
two := make([]byte, elementSize)
two[len(two)-1] = 2
nPlusTwo := make([]byte, elementSize)
new(big.Int).Add(sm2n, big.NewInt(2)).FillBytes(nPlusTwo)
p1 := NewSM2P256Point().Double(p)
p2 := NewSM2P256Point().Add(p, p)
p3, err := NewSM2P256Point().ScalarMult(p, two)
fatalIfErr(t, err)
p4, err := NewSM2P256Point().ScalarBaseMult(two)
fatalIfErr(t, err)
p5, err := NewSM2P256Point().ScalarMult(p, nPlusTwo)
fatalIfErr(t, err)
p6, err := NewSM2P256Point().ScalarBaseMult(nPlusTwo)
fatalIfErr(t, err)
if !bytes.Equal(p1.Bytes(), p2.Bytes()) {
t.Error("P+P != 2*P")
}
if !bytes.Equal(p1.Bytes(), p3.Bytes()) {
t.Error("P+P != [2]P")
}
if !bytes.Equal(p1.Bytes(), p4.Bytes()) {
t.Error("G+G != [2]G")
}
if !bytes.Equal(p1.Bytes(), p5.Bytes()) {
t.Error("P+P != [N+2]P")
}
if !bytes.Equal(p1.Bytes(), p6.Bytes()) {
t.Error("G+G != [N+2]G")
}
}
func TestScalarMult(t *testing.T) {
G := NewSM2P256Point().SetGenerator()
checkScalar := func(t *testing.T, scalar []byte) {


@@ -17,6 +17,8 @@ import (
"errors"
"math/bits"
"unsafe"
"golang.org/x/sys/cpu"
)
// p256Element is a P-256 base field element in [0, P-1] in the Montgomery
@@ -309,6 +311,9 @@ func p256Sqrt(e, x *p256Element) (isSquare bool) {
}
// The following assembly functions are implemented in p256_asm_*.s
var supportBMI2 = cpu.X86.HasBMI2
var supportAVX2 = cpu.X86.HasAVX2
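These flags are the runtime feature detection for the commit's fast paths: MULX (part of BMI2) is a widening 64x64 -> 128-bit multiply that leaves the flags register untouched, so long carry chains do not need to be re-established around each multiplication, and AVX2 provides 256-bit vector loads, which typically benefits the constant-time table selects (see the new BenchmarkP256SelectAffine below). A small hedged sketch of the detection plus the portable equivalent of a widening multiply; nothing below is taken from the suppressed amd64 file:

package main

import (
	"fmt"
	"math/bits"

	"golang.org/x/sys/cpu"
)

// Same detection the diff adds; the assembly can branch on these package-level
// variables to pick a MULX- or AVX2-based code path.
var (
	supportBMI2 = cpu.X86.HasBMI2
	supportAVX2 = cpu.X86.HasAVX2
)

// mulWide is the portable equivalent of what MULX computes in one instruction:
// the full 128-bit product of two 64-bit operands (MULX also avoids clobbering flags).
func mulWide(x, y uint64) (hi, lo uint64) {
	return bits.Mul64(x, y)
}

func main() {
	hi, lo := mulWide(1<<63, 6)
	fmt.Println(supportBMI2, supportAVX2, hi, lo) // e.g. true true 3 0
}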
// Montgomery multiplication. Sets res = in1 * in2 * R⁻¹ mod p.
//
@@ -411,6 +416,11 @@ func p256PointAddAsm(res, in1, in2 *SM2P256Point) int
//go:noescape
func p256PointDoubleAsm(res, in *SM2P256Point)
// Point doubling 5 times. in can be the point at infinity.
//
//go:noescape
func p256PointDouble5TimesAsm(res, in *SM2P256Point)
// p256OrdElement is a P-256 scalar field element in [0, ord(G)-1] in the
// Montgomery domain (with R 2²⁵⁶) as four uint64 limbs in little-endian order.
type p256OrdElement [4]uint64
@@ -867,11 +877,8 @@ func (p *SM2P256Point) p256ScalarMult(scalar *p256OrdElement) {
for index > 4 {
index -= 5
p256PointDouble5TimesAsm(p, p)
if index < 192 {
wvalue = ((scalar[index/64] >> (index % 64)) + (scalar[index/64+1] << (64 - (index % 64)))) & 0x3f
@@ -888,12 +895,7 @@ func (p *SM2P256Point) p256ScalarMult(scalar *p256OrdElement) {
p256MovCond(p, &t1, &t0, zero)
zero |= sel
}
p256PointDouble5TimesAsm(p, p)
wvalue = (scalar[0] << 1) & 0x3f
sel, sign = boothW5(uint(wvalue))
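The scalar multiplication walks the scalar in 5-bit Booth windows, which is why five doublings happen between table lookups and why a fused p256PointDouble5TimesAsm pays off. The wvalue expression extracts a 6-bit window that may straddle a 64-bit limb boundary. A standalone sketch of that extraction (toy types, not this package's code):

package main

import "fmt"

// window6 returns the 6-bit Booth window starting at bit position index of a
// little-endian 4-limb scalar, mirroring the wvalue computation above.
func window6(scalar [4]uint64, index int) uint64 {
	w := scalar[index/64] >> (index % 64)
	if index < 192 { // below the top limb, the window may spill into the next one
		w += scalar[index/64+1] << (64 - index%64)
	}
	return w & 0x3f
}

func main() {
	var s [4]uint64
	s[0] = 0xC000000000000000 // bits 62 and 63 set
	s[1] = 0x3                // bits 64 and 65 set
	fmt.Printf("%06b\n", window6(s, 62)) // 001111: window over scalar bits 62..67
}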


@@ -118,6 +118,22 @@ func TestFuzzyP256Sqr(t *testing.T) {
}
}
func BenchmarkP256Sqr(b *testing.B) {
p, _ := new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", 16)
r, _ := new(big.Int).SetString("10000000000000000000000000000000000000000000000000000000000000000", 16)
var scalar1 [32]byte
io.ReadFull(rand.Reader, scalar1[:])
x := new(big.Int).SetBytes(scalar1[:])
x1 := new(big.Int).Mul(x, r)
x1 = x1.Mod(x1, p)
ax := new(p256Element)
res := new(p256Element)
fromBig(ax, x1)
for i := 0; i < b.N; i++ {
p256Sqr(res, ax, 20)
}
}
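The multiplication by r = 2^256 converts the random input into the Montgomery domain before the timing loop, since p256Sqr operates on Montgomery-form elements (a Montgomery square returns a*a*R⁻¹ mod p). A small big.Int sketch of that convention, using the same p as the tests here (illustrative only, not package code):

package main

import (
	"fmt"
	"math/big"
)

func main() {
	p, _ := new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", 16)
	r := new(big.Int).Lsh(big.NewInt(1), 256) // R = 2^256, the Montgomery radix
	rInv := new(big.Int).ModInverse(r, p)

	x := big.NewInt(12345)
	xMont := new(big.Int).Mod(new(big.Int).Mul(x, r), p) // to the Montgomery domain: x*R mod p

	// A Montgomery square computes a*a*R^-1 mod p, so squaring xMont keeps the
	// result in the Montgomery domain: (x*R)^2 * R^-1 = x^2 * R (mod p).
	sqMont := new(big.Int).Mul(xMont, xMont)
	sqMont.Mul(sqMont, rInv).Mod(sqMont, p)

	// Convert back out of the Montgomery domain and compare with x^2 mod p.
	back := new(big.Int).Mod(new(big.Int).Mul(sqMont, rInv), p)
	want := new(big.Int).Exp(x, big.NewInt(2), p)
	fmt.Println(back.Cmp(want) == 0) // true
}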
func Test_p256Inverse(t *testing.T) {
r, _ := new(big.Int).SetString("10000000000000000000000000000000000000000000000000000000000000000", 16)
p, _ := new(big.Int).SetString("FFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF00000000FFFFFFFFFFFFFFFF", 16)
@@ -133,3 +149,10 @@ func Test_p256Inverse(t *testing.T) {
t.Errorf("expected %v, got %v", hex.EncodeToString(xInv.Bytes()), hex.EncodeToString(resInt.Bytes()))
}
}
func BenchmarkP256SelectAffine(b *testing.B) {
var t0 p256AffinePoint
for i := 0; i < b.N; i++ {
p256SelectAffine(&t0, &p256Precomputed[20], 20)
}
}
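p256SelectAffine is the constant-time lookup into a precomputed affine-point table, and it is the kind of routine the AVX2 path is most likely to speed up: every entry is read and the wanted one is kept by masking, so the memory access pattern is independent of the secret index. A portable Go sketch of that idea (the types and helper below are illustrative, not this package's API):

package main

import (
	"crypto/subtle"
	"fmt"
)

type affinePoint struct{ x, y [4]uint64 }

// selectAffine copies table[idx-1] into out for idx in 1..len(table) and
// leaves out zeroed for idx == 0, reading every entry so the access pattern
// does not depend on the secret index. A vectorized version can do the same
// scan with 256-bit masked moves.
func selectAffine(out *affinePoint, table []affinePoint, idx int) {
	*out = affinePoint{}
	for i := range table {
		mask := -uint64(subtle.ConstantTimeEq(int32(i+1), int32(idx))) // all ones only on the match
		for j := 0; j < 4; j++ {
			out.x[j] |= table[i].x[j] & mask
			out.y[j] |= table[i].y[j] & mask
		}
	}
}

func main() {
	table := make([]affinePoint, 32)
	for i := range table {
		table[i].x[0] = uint64(i + 1)
	}
	var pt affinePoint
	selectAffine(&pt, table, 20)
	fmt.Println(pt.x[0]) // 20
}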