sm9/bn256: asm implementation for gfP Marshal/Unmarshal #140

This commit is contained in:
Sun Yimin 2023-07-07 18:09:49 +08:00 committed by GitHub
parent ebe5aca2d8
commit 0e54e68bfd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 222 additions and 102 deletions

View File

@ -56,7 +56,7 @@ func (c *curvePoint) IsOnCurve() bool {
x3 := c.polynomial(&c.x)
return *y2 == *x3
return y2.Equal(x3) == 1
}
func NewCurvePoint() *curvePoint {
@ -72,14 +72,14 @@ func NewCurveGenerator() *curvePoint {
}
func (c *curvePoint) SetInfinity() {
c.x = *zero
c.y = *one
c.z = *zero
c.t = *zero
c.x.Set(zero)
c.y.Set(one)
c.z.Set(zero)
c.t.Set(zero)
}
func (c *curvePoint) IsInfinity() bool {
return c.z == *zero
return c.z.Equal(zero) == 1
}
func (c *curvePoint) Add(a, b *curvePoint) {
@ -122,7 +122,6 @@ func (c *curvePoint) Add(a, b *curvePoint) {
// with the notations below.
h := &gfP{}
gfpSub(h, u2, u1)
xEqual := *h == *zero
gfpAdd(t, h, h)
// i = 4h²
@ -133,8 +132,8 @@ func (c *curvePoint) Add(a, b *curvePoint) {
gfpMul(j, h, i)
gfpSub(t, s2, s1)
yEqual := *t == *one
if xEqual && yEqual {
if h.Equal(zero) == 1 && t.Equal(one) == 1 {
c.Double(a)
return
}
@ -219,12 +218,12 @@ func (c *curvePoint) Mul(a *curvePoint, scalar *big.Int) {
}
func (c *curvePoint) MakeAffine() {
if c.z == *one {
if c.z.Equal(one) == 1 {
return
} else if c.z == *zero {
c.x = *zero
c.y = *one
c.t = *zero
} else if c.z.Equal(zero) == 1 {
c.x.Set(zero)
c.y.Set(one)
c.t.Set(zero)
return
}
@ -238,24 +237,15 @@ func (c *curvePoint) MakeAffine() {
gfpMul(&c.x, &c.x, zInv2)
gfpMul(&c.y, t, zInv2)
c.z = *one
c.t = *one
c.z.Set(one)
c.t.Set(one)
}
func (c *curvePoint) Neg(a *curvePoint) {
c.x.Set(&a.x)
gfpNeg(&c.y, &a.y)
c.z.Set(&a.z)
c.t = *zero
}
// Select sets q to p1 if cond == 1, and to p2 if cond == 0.
func (q *curvePoint) Select(p1, p2 *curvePoint, cond int) *curvePoint {
q.x.Select(&p1.x, &p2.x, cond)
q.y.Select(&p1.y, &p2.y, cond)
q.z.Select(&p1.z, &p2.z, cond)
q.t.Select(&p1.t, &p2.t, cond)
return q
c.t.Set(zero)
}
// A curvePointTable holds the first 15 multiples of a point at offset -1, so [1]P

View File

@ -5,7 +5,6 @@ import (
"errors"
"io"
"math/big"
"math/bits"
"sync"
)
@ -282,7 +281,8 @@ func (e *G1) UnmarshalCompressed(data []byte) ([]byte, error) {
if e.p == nil {
e.p = &curvePoint{}
} else {
e.p.x, e.p.y = gfP{0}, gfP{0}
e.p.x.Set(zero)
e.p.y.Set(zero)
}
e.p.x.Unmarshal(data[1:])
montEncode(&e.p.x, &e.p.x)
@ -292,14 +292,12 @@ func (e *G1) UnmarshalCompressed(data []byte) ([]byte, error) {
if byte(x3[0]&1) != data[0]&1 {
gfpNeg(&e.p.y, &e.p.y)
}
if e.p.x == *zero && e.p.y == *zero {
if e.p.x.Equal(zero) == 1 && e.p.y.Equal(zero) == 1 {
// This is the point at infinity.
e.p.y = *newGFp(1)
e.p.z = gfP{0}
e.p.t = gfP{0}
e.p.SetInfinity()
} else {
e.p.z = *newGFp(1)
e.p.t = *newGFp(1)
e.p.z.Set(one)
e.p.t.Set(one)
if !e.p.IsOnCurve() {
return nil, errors.New("sm9.G1: malformed point")
@ -341,7 +339,8 @@ func (e *G1) Unmarshal(m []byte) ([]byte, error) {
if e.p == nil {
e.p = &curvePoint{}
} else {
e.p.x, e.p.y = gfP{0}, gfP{0}
e.p.x.Set(zero)
e.p.y.Set(zero)
}
e.p.x.Unmarshal(m)
@ -349,14 +348,12 @@ func (e *G1) Unmarshal(m []byte) ([]byte, error) {
montEncode(&e.p.x, &e.p.x)
montEncode(&e.p.y, &e.p.y)
if e.p.x == *zero && e.p.y == *zero {
if e.p.x.Equal(zero) == 1 && e.p.y.Equal(zero) == 1 {
// This is the point at infinity.
e.p.y = *newGFp(1)
e.p.z = gfP{0}
e.p.t = gfP{0}
e.p.SetInfinity()
} else {
e.p.z = *newGFp(1)
e.p.t = *newGFp(1)
e.p.z.Set(one)
e.p.t.Set(one)
if !e.p.IsOnCurve() {
return nil, errors.New("sm9.G1: malformed point")
@ -371,10 +368,10 @@ func (e *G1) Equal(other *G1) bool {
if e.p == nil && other.p == nil {
return true
}
return e.p.x == other.p.x &&
e.p.y == other.p.y &&
e.p.z == other.p.z &&
e.p.t == other.p.t
return e.p.x.Equal(&other.p.x) == 1 &&
e.p.y.Equal(&other.p.y) == 1 &&
e.p.z.Equal(&other.p.z) == 1 &&
e.p.t.Equal(&other.p.t) == 1
}
// IsOnCurve returns true if e is on the curve.
@ -492,15 +489,6 @@ func (g1 *G1Curve) IsOnCurve(x, y *big.Int) bool {
return err == nil
}
// lessThanP returns the final borrow of the multi-word subtraction x - p2,
// taken limb by limb from index 0 upwards: 1 if x < p2 and 0 otherwise.
// The borrow chain has no data-dependent branches.
func lessThanP(x *gfP) int {
	var borrow uint64
	for i := range x {
		_, borrow = bits.Sub64(x[i], p2[i], borrow)
	}
	return int(borrow)
}
func (curve *G1Curve) UnmarshalCompressed(data []byte) (x, y *big.Int) {
if len(data) != 33 || (data[0] != 2 && data[0] != 3) {
return nil, nil

View File

@ -354,10 +354,10 @@ func (e *G2) Equal(other *G2) bool {
if e.p == nil && other.p == nil {
return true
}
return e.p.x == other.p.x &&
e.p.y == other.p.y &&
e.p.z == other.p.z &&
e.p.t == other.p.t
return e.p.x.Equal(&other.p.x) == 1 &&
e.p.y.Equal(&other.p.y) == 1 &&
e.p.z.Equal(&other.p.z) == 1 &&
e.p.t.Equal(&other.p.t) == 1
}
// IsOnCurve returns true if e is on the twist curve.

View File

@ -5,6 +5,8 @@ import (
"errors"
"fmt"
"math/big"
"math/bits"
"unsafe"
)
type gfP [4]uint64
@ -94,10 +96,11 @@ func (e *gfP) Square(a *gfP, n int) *gfP {
// Equal returns 1 if e == t, and zero otherwise.
func (e *gfP) Equal(t *gfP) int {
if *e == *t {
return 1
var acc uint64
for i := range e {
acc |= e[i] ^ t[i]
}
return 0
return uint64IsZero(acc)
}
func (e *gfP) Sqrt(f *gfP) {
@ -117,23 +120,43 @@ func (e *gfP) Sqrt(f *gfP) {
e.Set(i)
}
// toElementArray returns a *[32]byte that aliases the first 32 bytes of b;
// no data is copied, so writes through the result are visible in b.
//
// It panics if len(b) < 32, which is also what the direct conversion
// (*[32]byte)(b) does since Go 1.17. The unsafe form is kept only so the
// package still builds with older Go versions.
func toElementArray(b []byte) *[32]byte {
	// Explicit bounds check: the unsafe cast below performs none, so a short
	// slice would otherwise let callers read or write past the end of b.
	_ = b[31]
	// A slice header starts with its data pointer, so reinterpreting &b as a
	// pointer to unsafe.Pointer yields the address of b's first element.
	tmpPtr := (*unsafe.Pointer)(unsafe.Pointer(&b))
	return (*[32]byte)(*tmpPtr)
}
func (e *gfP) Marshal(out []byte) {
for w := uint(0); w < 4; w++ {
for b := uint(0); b < 8; b++ {
out[8*w+b] = byte(e[3-w] >> (56 - 8*b))
}
}
gfpMarshal(toElementArray(out), e)
}
// uint64IsZero returns 1 if x is zero and zero otherwise.
//
// It inverts x and then ANDs the word onto itself at halving distances
// (32, 16, ..., 1), so bit 0 ends up set only when every bit of the inverted
// word was set, i.e. when x was 0. The shift sequence is fixed and does not
// depend on x.
func uint64IsZero(x uint64) int {
	folded := ^x
	for shift := uint(32); shift > 0; shift >>= 1 {
		folded &= folded >> shift
	}
	return int(folded & 1)
}
// lessThanP returns the final borrow of the multi-word subtraction x - p2,
// taken limb by limb from index 0 upwards: 1 if x < p2 and 0 otherwise.
// The borrow chain has no data-dependent branches.
func lessThanP(x *gfP) int {
	var borrow uint64
	for i := range x {
		_, borrow = bits.Sub64(x[i], p2[i], borrow)
	}
	return int(borrow)
}
func (e *gfP) Unmarshal(in []byte) error {
// Unmarshal the bytes into little endian form
for w := uint(0); w < 4; w++ {
e[3-w] = 0
for b := uint(0); b < 8; b++ {
e[3-w] += uint64(in[8*w+b]) << (56 - 8*b)
}
}
gfpUnmarshal(e, toElementArray(in))
// Ensure the point respects the curve modulus
// TODO: Do we need to change it to constant time version ?
for i := 3; i >= 0; i-- {
if e[i] < p2[i] {
return nil
@ -151,14 +174,18 @@ func montDecode(c, a *gfP) { gfpFromMont(c, a) }
// cmovznzU64 is a single-word conditional move.
//
// Postconditions:
// out1 = (if arg1 = 0 then arg2 else arg3)
//
// out1 = (if arg1 = 0 then arg2 else arg3)
//
// Input Bounds:
// arg1: [0x0 ~> 0x1]
// arg2: [0x0 ~> 0xffffffffffffffff]
// arg3: [0x0 ~> 0xffffffffffffffff]
//
// arg1: [0x0 ~> 0x1]
// arg2: [0x0 ~> 0xffffffffffffffff]
// arg3: [0x0 ~> 0xffffffffffffffff]
//
// Output Bounds:
// out1: [0x0 ~> 0xffffffffffffffff]
//
// out1: [0x0 ~> 0xffffffffffffffff]
func cmovznzU64(out1 *uint64, arg1 uint64, arg2 uint64, arg3 uint64) {
x1 := (uint64(arg1) * 0xffffffffffffffff)
x2 := ((x1 & arg3) | ((^x1) & arg2))

View File

@ -140,6 +140,7 @@ func (e *gfP12) Mul(a, b *gfP12) *gfP12 {
return e
}
// MulNC is Mul without the value copy: it works on e directly, so e must not be the same as a or b.
func (e *gfP12) MulNC(a, b *gfP12) *gfP12 {
// (z0 + y0*w + x0*w^2)* (z1 + y1*w + x1*w^2)
// z0*z1 + z0*y1*w + z0*x1*w^2
@ -187,6 +188,7 @@ func (e *gfP12) Square(a *gfP12) *gfP12 {
return e
}
// SquareNC is Square without the value copy: it works on e directly, so e must not be the same as a.
func (e *gfP12) SquareNC(a *gfP12) *gfP12 {
// (z + y*w + x*w^2)* (z + y*w + x*w^2)
// z^2 + z*y*w + z*x*w^2 + y*z*w + y^2*w^2 + y*x*v + x*z*w^2 + x*y*v + x^2 *v *w
@ -303,6 +305,7 @@ func (e *gfP12) SpecialSquares(a *gfP12, n int) *gfP12 {
return e
}
// SpecialSquareNC is the special square without the value copy: it works on e directly, so e must not be the same as a.
func (e *gfP12) SpecialSquareNC(a *gfP12) *gfP12 {
tx, ty, tz := &gfP4{}, &gfP4{}, &gfP4{}

View File

@ -132,6 +132,7 @@ func (e *gfP12b6) Mul(a, b *gfP12b6) *gfP12b6 {
return e
}
// MulNC is Mul without the value copy: it works on e directly, so e must not be the same as a or b.
func (e *gfP12b6) MulNC(a, b *gfP12b6) *gfP12b6 {
// "Multiplication and Squaring on Pairing-Friendly Fields"
// Section 4, Karatsuba method.
@ -183,6 +184,7 @@ func (e *gfP12b6) Square(a *gfP12b6) *gfP12b6 {
return e
}
// SquareNC is Square without the value copy: it works on e directly, so e must not be the same as a.
func (e *gfP12b6) SquareNC(a *gfP12b6) *gfP12b6 {
// Complex squaring algorithm
// (xt+y)² = (x^2*s + y^2) + 2*x*y*t
@ -211,6 +213,7 @@ func (e *gfP12b6) SpecialSquare(a *gfP12b6) *gfP12b6 {
return e
}
// SpecialSquareNC is SpecialSquare without the value copy: it works on e directly, so e must not be the same as a.
func (e *gfP12b6) SpecialSquareNC(a *gfP12b6) *gfP12b6 {
f02 := &e.y.x
f01 := &e.y.y

View File

@ -31,35 +31,46 @@ func (e *gfP2) Set(a *gfP2) *gfP2 {
}
func (e *gfP2) SetZero() *gfP2 {
e.x = *zero
e.y = *zero
e.x.Set(zero)
e.y.Set(zero)
return e
}
func (e *gfP2) SetOne() *gfP2 {
e.x = *zero
e.y = *one
e.x.Set(zero)
e.y.Set(one)
return e
}
func (e *gfP2) SetU() *gfP2 {
e.x = *one
e.y = *zero
e.x.Set(one)
e.y.Set(zero)
return e
}
func (e *gfP2) SetFrobConstant() *gfP2 {
e.x = *zero
e.y = *frobConstant
e.x.Set(zero)
e.y.Set(frobConstant)
return e
}
// Equal returns 1 if e == t, and zero otherwise.
//
// The XOR of every pair of corresponding limbs of both coordinates is OR-ed
// into one accumulator, which is reduced to 0/1 by uint64IsZero. All limbs
// are always visited; there is no early exit on the first difference.
func (e *gfP2) Equal(t *gfP2) int {
	var diff uint64
	// x and y are both gfP values, so one index walks both coordinates.
	for i := range e.x {
		diff |= (e.x[i] ^ t.x[i]) | (e.y[i] ^ t.y[i])
	}
	return uint64IsZero(diff)
}
func (e *gfP2) IsZero() bool {
return e.x == *zero && e.y == *zero
return (e.x.Equal(zero) == 1) && (e.y.Equal(zero) == 1)
}
func (e *gfP2) IsOne() bool {
return e.x == *zero && e.y == *one
return (e.x.Equal(zero) == 1) && (e.y.Equal(one) == 1)
}
func (e *gfP2) Conjugate(a *gfP2) *gfP2 {
@ -114,7 +125,7 @@ func (e *gfP2) Mul(a, b *gfP2) *gfP2 {
return e
}
// Mul without Copy
// MulNC is Mul without the value copy: it works on e directly, so e must not be the same as a or b.
func (e *gfP2) MulNC(a, b *gfP2) *gfP2 {
tx := &e.x
ty := &e.y
@ -142,6 +153,7 @@ func (e *gfP2) MulU(a, b *gfP2) *gfP2 {
return e
}
// MulU without value copy, will use e directly, so e can't be same as a and b.
// MulU: a * b * u
// (a0+a1*u)(b0+b1*u)*u=c0+c1*u, where
// c1 = (a0*b0 - 2a1*b1)u
@ -192,6 +204,7 @@ func (e *gfP2) Square(a *gfP2) *gfP2 {
return e
}
// SquareNC is Square without the value copy: it works on e directly, so e must not be the same as a.
func (e *gfP2) SquareNC(a *gfP2) *gfP2 {
// Complex squaring algorithm:
// (xu+y)² = y^2-2*x^2 + 2*u*x*y
@ -219,6 +232,7 @@ func (e *gfP2) SquareU(a *gfP2) *gfP2 {
return e
}
// SquareUNC is SquareU without the value copy: it works on e directly, so e must not be the same as a.
func (e *gfP2) SquareUNC(a *gfP2) *gfP2 {
// Complex squaring algorithm:
// (xu+y)²*u = (y^2-2*x^2)u - 4*x*y
@ -312,7 +326,7 @@ func (ret *gfP2) Sqrt(a *gfP2) *gfP2 {
a0 = gfP2Decode(a0)
*/
t.Mul(bq, b)
if t.x == *zero && t.y == *one {
if t.x.Equal(zero) == 1 && t.y.Equal(one) == 1 {
t.Mul(b2, a)
x0.Sqrt(&t.y)
t.MulScalar(bq, x0)

View File

@ -104,6 +104,7 @@ func (e *gfP4) Mul(a, b *gfP4) *gfP4 {
return e
}
// MulNC is Mul without the value copy: it works on e directly, so e must not be the same as a or b.
func (e *gfP4) MulNC(a, b *gfP4) *gfP4 {
// "Multiplication and Squaring on Pairing-Friendly Fields"
// Section 4, Karatsuba method.
@ -130,7 +131,7 @@ func (e *gfP4) MulNC(a, b *gfP4) *gfP4 {
}
// MulNC2 muls a with (xv+y), this method is used in mulLine function
// to avoid gfP4 instance construction.
// to avoid gfP4 instance construction. e can't be same as a.
func (e *gfP4) MulNC2(a *gfP4, x, y *gfP2) *gfP4 {
// "Multiplication and Squaring on Pairing-Friendly Fields"
// Section 4, Karatsuba method.
@ -169,6 +170,7 @@ func (e *gfP4) MulV(a, b *gfP4) *gfP4 {
return e
}
// MulVNC is MulV without the value copy: it works on e directly, so e must not be the same as a or b.
func (e *gfP4) MulVNC(a, b *gfP4) *gfP4 {
tx := &e.x
ty := &e.y
@ -211,6 +213,7 @@ func (e *gfP4) Square(a *gfP4) *gfP4 {
return e
}
// SquareNC is Square without the value copy: it works on e directly, so e must not be the same as a.
func (e *gfP4) SquareNC(a *gfP4) *gfP4 {
// Complex squaring algorithm:
// (xv+y)² = (x^2*u + y^2) + 2*x*y*v
@ -228,6 +231,7 @@ func (e *gfP4) SquareNC(a *gfP4) *gfP4 {
return e
}
// SquareV without value copy, will use e directly, so e can't be same as a.
// SquareV: (a^2) * v
// v*(xv+y)² = (x^2*u + y^2)v + 2*x*y*u
func (e *gfP4) SquareV(a *gfP4) *gfP4 {

View File

@ -108,6 +108,7 @@ func (e *gfP6) Mul(a, b *gfP6) *gfP6 {
return e
}
// MulNC is Mul without the value copy: it works on e directly, so e must not be the same as a or b.
func (e *gfP6) MulNC(a, b *gfP6) *gfP6 {
// (z0 + y0*s + x0*s²)* (z1 + y1*s + x1*s²)
// z0*z1 + z0*y1*s + z0*x1*s²
@ -172,6 +173,7 @@ func (e *gfP6) Square(a *gfP6) *gfP6 {
return e
}
// SquareNC is Square without the value copy: it works on e directly, so e must not be the same as a.
func (e *gfP6) SquareNC(a *gfP6) *gfP6 {
// (z + y*s + x*s²)* (z + y*s + x*s²)
// z^2 + z*y*s + z*x*s² + y*z*s + y^2*s² + y*x*u + x*z*s² + x*y*u + x^2 *u *s

View File

@ -1162,3 +1162,31 @@ TEXT ·gfpFromMont(SB),NOSPLIT,$0
gfpCarryWithoutCarry(acc4, acc5, acc0, acc1, x_ptr, acc3, t0, t1)
storeBlock(acc4,acc5,acc0,acc1, 0(res_ptr))
RET
/* ---------------------------------------*/
// func gfpUnmarshal(res *gfP, in *[32]byte)
//
// Jumps straight into gfpMarshal. That routine byte-swaps each 64-bit word
// and reverses the word order; applying it a second time restores the
// original bytes, so the same code serves both directions. Both functions
// take (destination pointer, source pointer), so gfpMarshal reads this
// function's arguments at the same frame offsets.
TEXT ·gfpUnmarshal(SB),NOSPLIT,$0
JMP ·gfpMarshal(SB)
/* ---------------------------------------*/
// func gfpMarshal(res *[32]byte, in *gfP)
//
// Copies the four 64-bit words of in to res, reversing the byte order inside
// each word and the order of the words.
TEXT ·gfpMarshal(SB),NOSPLIT,$0
MOVQ res+0(FP), res_ptr
MOVQ in+8(FP), x_ptr
// Load the four 64-bit words of the input.
MOVQ (8*0)(x_ptr), acc0
MOVQ (8*1)(x_ptr), acc1
MOVQ (8*2)(x_ptr), acc2
MOVQ (8*3)(x_ptr), acc3
// Reverse the byte order within each word.
BSWAPQ acc0
BSWAPQ acc1
BSWAPQ acc2
BSWAPQ acc3
// Store the words in reverse order: input word 3 first, input word 0 last.
MOVQ acc3, (8*0)(res_ptr)
MOVQ acc2, (8*1)(res_ptr)
MOVQ acc1, (8*2)(res_ptr)
MOVQ acc0, (8*3)(res_ptr)
RET

View File

@ -682,3 +682,26 @@ TEXT ·gfpFromMont(SB),NOSPLIT,$0
STP (x2, x3), 1*16(res_ptr)
RET
/* ---------------------------------------*/
// func gfpUnmarshal(res *gfP, in *[32]byte)
//
// Jumps straight into gfpMarshal. That routine byte-reverses each 64-bit
// word and reverses the word order; applying it a second time restores the
// original bytes, so the same code serves both directions. Both functions
// take (destination pointer, source pointer), so gfpMarshal reads this
// function's arguments at the same frame offsets.
TEXT ·gfpUnmarshal(SB),NOSPLIT,$0
JMP ·gfpMarshal(SB)
/* ---------------------------------------*/
// func gfpMarshal(res *[32]byte, in *gfP)
//
// Copies the four 64-bit words of in to res, reversing the byte order inside
// each word and the order of the words.
TEXT ·gfpMarshal(SB),NOSPLIT,$0
MOVD res+0(FP), res_ptr
MOVD in+8(FP), a_ptr
// Load the four 64-bit words of the input, two per LDP.
LDP 0*16(a_ptr), (acc0, acc1)
LDP 1*16(a_ptr), (acc2, acc3)
// Reverse the byte order within each word.
REV acc0, acc0
REV acc1, acc1
REV acc2, acc2
REV acc3, acc3
// Store the words in reverse order: input word 3 first, input word 0 last.
STP (acc3, acc2), 0*16(res_ptr)
STP (acc1, acc0), 1*16(res_ptr)
RET

View File

@ -46,3 +46,13 @@ func gfpSqr(res, in *gfP, n int)
//
//go:noescape
func gfpFromMont(res, in *gfP)
// gfpMarshal writes in to res in big-endian byte order. It is implemented in
// assembly.
//
// The first parameter is named res because the assembly refers to it as
// res+0(FP); go vet's asmdecl check requires the names to match.
//
//go:noescape
func gfpMarshal(res *[32]byte, in *gfP)
// gfpUnmarshal reads the 32 big-endian bytes of in into the limbs of res
// (res[0] receives the least significant 64 bits). It is implemented in
// assembly.
//
// The first parameter is named res to match the assembly's own signature
// comment and the other assembly-backed declarations in this file.
//
//go:noescape
func gfpUnmarshal(res *gfP, in *[32]byte)

View File

@ -146,3 +146,20 @@ func gfpFromMont(res, in *gfP) {
*res = gfP{T[4], T[5], T[6], T[7]}
gfpCarry(res, carry)
}
// gfpMarshal writes in to out in big-endian byte order: out[0] is the most
// significant byte of in[3] and out[31] is the least significant byte of
// in[0].
func gfpMarshal(out *[32]byte, in *gfP) {
	for i := range out {
		limb := in[3-i/8]           // limbs are consumed from the highest index down
		shift := uint(56 - 8*(i%8)) // most significant byte of the limb first
		out[i] = byte(limb >> shift)
	}
}
// gfpUnmarshal reads 32 big-endian bytes from in into the limbs of out:
// in[0:8] becomes out[3] and in[24:32] becomes out[0]. No range check
// against the modulus is performed here.
func gfpUnmarshal(out *gfP, in *[32]byte) {
	for w := 0; w < 4; w++ {
		var limb uint64
		// Fold the eight bytes of this word, most significant byte first.
		for _, b := range in[8*w : 8*w+8] {
			limb = limb<<8 | uint64(b)
		}
		out[3-w] = limb
	}
}

View File

@ -195,6 +195,17 @@ func TestGfpNeg(t *testing.T) {
}
}
// BenchmarkGfPUnmarshal measures gfP.Unmarshal on a fixed 32-byte encoding
// produced once by Marshal.
func BenchmarkGfPUnmarshal(b *testing.B) {
	x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))
	// Build the input before the timer is reset so setup is not measured.
	var out [32]byte
	x.Marshal(out[:])

	var err error
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		err = x.Unmarshal(out[:])
	}
	// Checked once after the loop so the timed body stays a bare call.
	if err != nil {
		b.Fatal(err)
	}
}
func BenchmarkGfPMul(b *testing.B) {
x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596"))
b.ReportAllocs()

View File

@ -49,6 +49,15 @@ func TestGT(t *testing.T) {
}
}
// BenchmarkGTMarshal measures GT.Marshal on the element wrapping gfP12Gen.
func BenchmarkGTMarshal(b *testing.B) {
	g := &GT{gfP12Gen}
	b.ReportAllocs()
	b.ResetTimer()
	for n := b.N; n > 0; n-- {
		g.Marshal()
	}
}
func BenchmarkGT(b *testing.B) {
x, _ := rand.Int(rand.Reader, Order)
b.ReportAllocs()

View File

@ -73,7 +73,7 @@ func (c *twistPoint) IsOnCurve() bool {
y2.SquareNC(&c.y)
x3 := c.polynomial(&c.x)
return *y2 == *x3
return y2.Equal(x3) == 1
}
func (c *twistPoint) SetInfinity() {
@ -241,15 +241,6 @@ func (c *twistPoint) NegFrobeniusP2(a *twistPoint) {
c.t.Square(&a.z)
}
// Select sets c to p1 if cond == 1, and to p2 if cond == 0, by applying the
// per-coordinate Select to x, y, z and t in turn. It returns c.
func (c *twistPoint) Select(p1, p2 *twistPoint, cond int) *twistPoint {
	c.x.Select(&p1.x, &p2.x, cond)
	c.y.Select(&p1.y, &p2.y, cond)
	c.z.Select(&p1.z, &p2.z, cond)
	c.t.Select(&p1.t, &p2.t, cond)
	return c
}
// A twistPointTable holds the first 15 multiples of a point at offset -1, so [1]P
// is at table[0], [15]P is at table[14], and [0]P is implicitly the identity
// point.