internal/sm9/bn256: make gfP.Unmarshal constant time

This commit is contained in:
Sun Yimin 2025-03-26 16:37:04 +08:00 committed by GitHub
parent 9ea8293d10
commit b8d52dd11d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 74 additions and 11 deletions

View File

@ -12,6 +12,11 @@ var zero = newGFp(0)
var one = newGFp(1) var one = newGFp(1)
var two = newGFp(2) var two = newGFp(2)
// newGFp creates a new gfP element from the given int64 value.
// If the input value is non-negative, it directly converts it to uint64.
// If the input value is negative, it converts the absolute value to uint64
// and then negates the resulting gfP element.
// The resulting gfP element is then encoded in Montgomery form.
func newGFp(x int64) (out *gfP) { func newGFp(x int64) (out *gfP) {
if x >= 0 { if x >= 0 {
out = &gfP{uint64(x)} out = &gfP{uint64(x)}
@ -24,6 +29,8 @@ func newGFp(x int64) (out *gfP) {
return out return out
} }
// newGFpFromBytes creates a new gfP element from a byte slice.
// It unmarshals the byte slice into a gfP element, then encodes it in Montgomery form.
func newGFpFromBytes(in []byte) (out *gfP) { func newGFpFromBytes(in []byte) (out *gfP) {
out = &gfP{} out = &gfP{}
gfpUnmarshal(out, (*[32]byte)(in)) gfpUnmarshal(out, (*[32]byte)(in))
@ -40,6 +47,17 @@ func (e *gfP) Set(f *gfP) *gfP {
return e return e
} }
// exp calculates the exponentiation of a given gfP element `f` raised to the power
// represented by the 256-bit integer `bits`. The result is stored in the gfP element `e`.
//
// The function uses a square-and-multiply algorithm to perform the exponentiation.
// It iterates over each bit of the 256-bit integer `bits`, and for each bit, it squares
// the current power and multiplies it to the result if the bit is set.
//
// Parameters:
// - f: The base gfP element to be exponentiated.
// - bits: A 256-bit integer represented as an array of 4 uint64 values, where bits[0]
// contains the least significant 64 bits and bits[3] contains the most significant 64 bits.
func (e *gfP) exp(f *gfP, bits [4]uint64) { func (e *gfP) exp(f *gfP, bits [4]uint64) {
sum, power := &gfP{}, &gfP{} sum, power := &gfP{}, &gfP{}
sum.Set(rN1) sum.Set(rN1)
@ -94,7 +112,12 @@ func (e *gfP) Sqrt(f *gfP) {
e.Set(i) e.Set(i)
} }
// Marshal serializes the gfP element into the provided byte slice.
// The output byte slice must be at least 32 bytes long.
func (e *gfP) Marshal(out []byte) { func (e *gfP) Marshal(out []byte) {
if len(out) < 32 {
panic("sm9: invalid out length")
}
gfpMarshal((*[32]byte)(out), e) gfpMarshal((*[32]byte)(out), e)
} }
@ -110,6 +133,9 @@ func uint64IsZero(x uint64) int {
return int(x & 1) return int(x & 1)
} }
// lessThanP returns 1 if the given gfP element x is less than the prime modulus p2,
// and 0 otherwise. It performs a subtraction of x from p2 and checks the borrow bit
// to determine if x is less than p2.
func lessThanP(x *gfP) int { func lessThanP(x *gfP) int {
var b uint64 var b uint64
_, b = bits.Sub64(x[0], p2[0], b) _, b = bits.Sub64(x[0], p2[0], b)
@ -119,19 +145,18 @@ func lessThanP(x *gfP) int {
return int(b) return int(b)
} }
// Unmarshal decodes a 32-byte big-endian representation of a gfP element.
// It returns an error if the input length is not 32 bytes or if the decoded
// value is not a valid gfP element (i.e., greater than or equal to the field prime).
func (e *gfP) Unmarshal(in []byte) error { func (e *gfP) Unmarshal(in []byte) error {
gfpUnmarshal(e, (*[32]byte)(in)) if len(in) < 32 {
// Ensure the point respects the curve modulus return errors.New("sm9: invalid input length")
// TODO: Do we need to change it to constant time version ?
for i := 3; i >= 0; i-- {
if e[i] < p2[i] {
return nil
}
if e[i] > p2[i] {
return errors.New("sm9: coordinate exceeds modulus")
}
} }
return errors.New("sm9: coordinate equals modulus") gfpUnmarshal(e, (*[32]byte)(in))
if lessThanP(e) == 0 {
return errors.New("sm9: invalid gfP encoding")
}
return nil
} }
func montEncode(c, a *gfP) { gfpMul(c, a, r2) } func montEncode(c, a *gfP) { gfpMul(c, a, r2) }

View File

@ -1,6 +1,7 @@
package bn256 package bn256
import ( import (
"bytes"
"encoding/hex" "encoding/hex"
"math/big" "math/big"
"testing" "testing"
@ -276,6 +277,43 @@ func TestGfpNeg(t *testing.T) {
} }
} }
func TestGfpUnmarshal(t *testing.T) {
validHex := "85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141"
invalidHex := "b640000002a3a6f1d603ab4ff58ec74521f2934b1a7aeedbe56f9b27e351457d"
t.Run("valid input", func(t *testing.T) {
x, _ := hex.DecodeString(validHex)
var out [32]byte
ret := &gfP{}
err := ret.Unmarshal(x[:])
if err != nil {
t.Errorf("unexpected error: %v", err)
}
ret.Marshal(out[:])
if !bytes.Equal(out[:], x) {
t.Errorf("got %x, expected %x", out, x)
}
})
t.Run("invalid length", func(t *testing.T) {
x, _ := hex.DecodeString(validHex)
ret := &gfP{}
err := ret.Unmarshal(x[1:])
if err == nil || err.Error() != "sm9: invalid input length" {
t.Errorf("expected error, got %v", err)
}
})
t.Run("invalid value", func(t *testing.T) {
x, _ := hex.DecodeString(invalidHex)
ret := &gfP{}
err := ret.Unmarshal(x[:])
if err == nil || err.Error() != "sm9: invalid gfP encoding" {
t.Errorf("expected error, got %v", err)
}
})
}
func BenchmarkGfPUnmarshal(b *testing.B) { func BenchmarkGfPUnmarshal(b *testing.B) {
x := newGFpFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596") x := newGFpFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596")
b.ReportAllocs() b.ReportAllocs()