378 lines
8.9 KiB
Go
Raw Normal View History

2022-07-15 16:42:39 +08:00
package bn256
import (
"crypto/rand"
"errors"
"io"
"math/big"
"sync"
)
func randomK(r io.Reader) (k *big.Int, err error) {
for {
k, err = rand.Int(r, Order)
2023-03-07 08:55:09 +08:00
if err != nil || k.Sign() > 0 {
2022-07-15 16:42:39 +08:00
return
}
}
}
// G1 is an abstract cyclic group. The zero value is suitable for use as the
// output of an operation, but cannot be used as an input.
type G1 struct {
p *curvePoint
}
// Gen1 is the generator of G1.
2022-07-15 16:42:39 +08:00
var Gen1 = &G1{curveGen}
var g1GeneratorTable *[32 * 2]curvePointTable
var g1GeneratorTableOnce sync.Once
func (g *G1) generatorTable() *[32 * 2]curvePointTable {
g1GeneratorTableOnce.Do(func() {
g1GeneratorTable = new([32 * 2]curvePointTable)
base := NewCurveGenerator()
for i := 0; i < 32*2; i++ {
g1GeneratorTable[i][0] = &curvePoint{}
g1GeneratorTable[i][0].Set(base)
g1GeneratorTable[i][1] = &curvePoint{}
g1GeneratorTable[i][1].Double(g1GeneratorTable[i][0])
g1GeneratorTable[i][2] = &curvePoint{}
g1GeneratorTable[i][2].Add(g1GeneratorTable[i][1], base)
g1GeneratorTable[i][3] = &curvePoint{}
g1GeneratorTable[i][3].Double(g1GeneratorTable[i][1])
g1GeneratorTable[i][4] = &curvePoint{}
g1GeneratorTable[i][4].Add(g1GeneratorTable[i][3], base)
g1GeneratorTable[i][5] = &curvePoint{}
g1GeneratorTable[i][5].Double(g1GeneratorTable[i][2])
g1GeneratorTable[i][6] = &curvePoint{}
g1GeneratorTable[i][6].Add(g1GeneratorTable[i][5], base)
g1GeneratorTable[i][7] = &curvePoint{}
g1GeneratorTable[i][7].Double(g1GeneratorTable[i][3])
g1GeneratorTable[i][8] = &curvePoint{}
g1GeneratorTable[i][8].Add(g1GeneratorTable[i][7], base)
g1GeneratorTable[i][9] = &curvePoint{}
g1GeneratorTable[i][9].Double(g1GeneratorTable[i][4])
g1GeneratorTable[i][10] = &curvePoint{}
g1GeneratorTable[i][10].Add(g1GeneratorTable[i][9], base)
g1GeneratorTable[i][11] = &curvePoint{}
g1GeneratorTable[i][11].Double(g1GeneratorTable[i][5])
g1GeneratorTable[i][12] = &curvePoint{}
g1GeneratorTable[i][12].Add(g1GeneratorTable[i][11], base)
g1GeneratorTable[i][13] = &curvePoint{}
g1GeneratorTable[i][13].Double(g1GeneratorTable[i][6])
g1GeneratorTable[i][14] = &curvePoint{}
g1GeneratorTable[i][14].Add(g1GeneratorTable[i][13], base)
2022-07-15 16:42:39 +08:00
base.Double(base)
base.Double(base)
base.Double(base)
base.Double(base)
}
})
return g1GeneratorTable
}
// RandomG1 returns x and g₁ˣ where x is a random, non-zero number read from r.
func RandomG1(r io.Reader) (*big.Int, *G1, error) {
k, err := randomK(r)
if err != nil {
return nil, nil, err
}
2022-11-25 10:11:46 +08:00
g1, err := new(G1).ScalarBaseMult(NormalizeScalar(k.Bytes()))
return k, g1, err
2022-07-15 16:42:39 +08:00
}
// String returns a human-readable form of e, prefixed with "sm9.G1".
//
// A zero-value G1 (no underlying point yet) is formatted as the all-zero
// curvePoint. The other G1 methods allocate that same default lazily, so
// using it here avoids dereferencing a nil pointer.
func (e *G1) String() string {
	p := e.p
	if p == nil {
		p = &curvePoint{}
	}
	return "sm9.G1" + p.String()
}
2022-11-25 10:11:46 +08:00
func NormalizeScalar(scalar []byte) []byte {
2022-07-15 16:42:39 +08:00
if len(scalar) == 32 {
return scalar
}
s := new(big.Int).SetBytes(scalar)
if len(scalar) > 32 {
s.Mod(s, Order)
}
out := make([]byte, 32)
return s.FillBytes(out)
}
2022-11-25 10:11:46 +08:00
// ScalarBaseMult sets e to scaler*g where g is the generator of the group and then
2022-07-15 16:42:39 +08:00
// returns e.
2022-11-25 10:11:46 +08:00
func (e *G1) ScalarBaseMult(scalar []byte) (*G1, error) {
if len(scalar) != 32 {
return nil, errors.New("invalid scalar length")
}
2022-07-15 16:42:39 +08:00
if e.p == nil {
e.p = &curvePoint{}
}
//e.p.Mul(curveGen, k)
tables := e.generatorTable()
// This is also a scalar multiplication with a four-bit window like in
// ScalarMult, but in this case the doublings are precomputed. The value
// [windowValue]G added at iteration k would normally get doubled
// (totIterations-k)×4 times, but with a larger precomputation we can
// instead add [2^((totIterations-k)×4)][windowValue]G and avoid the
// doublings between iterations.
t := NewCurvePoint()
e.p.SetInfinity()
tableIndex := len(tables) - 1
for _, byte := range scalar {
windowValue := byte >> 4
tables[tableIndex].Select(t, windowValue)
e.p.Add(e.p, t)
tableIndex--
windowValue = byte & 0b1111
tables[tableIndex].Select(t, windowValue)
e.p.Add(e.p, t)
tableIndex--
}
2022-11-25 10:11:46 +08:00
return e, nil
2022-07-15 16:42:39 +08:00
}
// ScalarMult sets e to a*k and then returns e.
2022-11-25 10:11:46 +08:00
func (e *G1) ScalarMult(a *G1, scalar []byte) (*G1, error) {
2022-07-15 16:42:39 +08:00
if e.p == nil {
e.p = &curvePoint{}
}
//e.p.Mul(a.p, k)
// Compute a curvePointTable for the base point a.
var table = curvePointTable{NewCurvePoint(), NewCurvePoint(), NewCurvePoint(),
NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), NewCurvePoint(),
NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), NewCurvePoint(),
NewCurvePoint(), NewCurvePoint(), NewCurvePoint(), NewCurvePoint()}
table[0].Set(a.p)
for i := 1; i < 15; i += 2 {
table[i].Double(table[i/2])
table[i+1].Add(table[i], a.p)
}
// Instead of doing the classic double-and-add chain, we do it with a
// four-bit window: we double four times, and then add [0-15]P.
t := &G1{NewCurvePoint()}
e.p.SetInfinity()
2022-11-25 10:11:46 +08:00
for i, byte := range scalar {
2022-07-15 16:42:39 +08:00
// No need to double on the first iteration, as p is the identity at
// this point, and [N]∞ = ∞.
if i != 0 {
e.Double(e)
e.Double(e)
e.Double(e)
e.Double(e)
}
windowValue := byte >> 4
table.Select(t.p, windowValue)
e.Add(e, t)
e.Double(e)
e.Double(e)
e.Double(e)
e.Double(e)
windowValue = byte & 0b1111
table.Select(t.p, windowValue)
e.Add(e, t)
}
2022-11-25 10:11:46 +08:00
return e, nil
2022-07-15 16:42:39 +08:00
}
// Add stores the group sum a+b in e and returns e. e's point is allocated
// on first use, so a zero-value G1 may serve as the destination.
func (e *G1) Add(a, b *G1) *G1 {
	dst := e.p
	if dst == nil {
		dst = new(curvePoint)
		e.p = dst
	}
	dst.Add(a.p, b.p)
	return e
}
// Double stores [2]a in e and returns e. e's point is allocated on first
// use, so a zero-value G1 may serve as the destination.
func (e *G1) Double(a *G1) *G1 {
	dst := e.p
	if dst == nil {
		dst = new(curvePoint)
		e.p = dst
	}
	dst.Double(a.p)
	return e
}
// Neg stores the inverse -a in e and returns e. e's point is allocated on
// first use, so a zero-value G1 may serve as the destination.
func (e *G1) Neg(a *G1) *G1 {
	dst := e.p
	if dst == nil {
		dst = new(curvePoint)
		e.p = dst
	}
	dst.Neg(a.p)
	return e
}
// Set copies a into e and returns e. e's point is allocated on first use,
// so a zero-value G1 may serve as the destination.
func (e *G1) Set(a *G1) *G1 {
	dst := e.p
	if dst == nil {
		dst = new(curvePoint)
		e.p = dst
	}
	dst.Set(a.p)
	return e
}
// Marshal converts e to a 64-byte slice: the 32-byte encoding of the affine
// x coordinate followed by that of y. The point at infinity encodes as 64
// zero bytes. As a side effect e is normalized to affine form.
func (e *G1) Marshal() []byte {
	// Two coordinates, each a 256-bit number.
	const coordLen = 256 / 8
	out := make([]byte, 2*coordLen)
	e.fillBytes(out)
	return out
}
// MarshalUncompressed converts e to a 65-byte slice: the uncompressed-point
// prefix byte 0x04 followed by the same 64-byte x||y encoding that Marshal
// produces. As a side effect e is normalized to affine form.
func (e *G1) MarshalUncompressed() []byte {
	// Two coordinates, each a 256-bit number, plus one prefix byte.
	const coordLen = 256 / 8
	out := make([]byte, 1+2*coordLen)
	out[0] = 4
	e.fillBytes(out[1:])
	return out
}
// MarshalCompressed converts e to a byte slice with compress prefix.
// If the point is not on the curve (or is the conventional point at infinity), the behavior is undefined.
func (e *G1) MarshalCompressed() []byte {
// Each value is a 256-bit number.
const numBytes = 256 / 8
ret := make([]byte, numBytes+1)
if e.p == nil {
e.p = &curvePoint{}
}
e.p.MakeAffine()
temp := &gfP{}
montDecode(temp, &e.p.y)
2022-07-15 16:42:39 +08:00
temp.Marshal(ret[1:])
ret[0] = (ret[numBytes] & 1) | 2
montDecode(temp, &e.p.x)
temp.Marshal(ret[1:])
return ret
}
// UnmarshalCompressed sets e to the result of converting the output of Marshal back into
// a group element and then returns e.
func (e *G1) UnmarshalCompressed(data []byte) ([]byte, error) {
// Each value is a 256-bit number.
const numBytes = 256 / 8
if len(data) < 1+numBytes {
return nil, errors.New("sm9.G1: not enough data")
}
if data[0] != 2 && data[0] != 3 { // compressed form
return nil, errors.New("sm9.G1: invalid point compress byte")
}
if e.p == nil {
e.p = &curvePoint{}
} else {
e.p.x.Set(zero)
e.p.y.Set(zero)
2022-07-15 16:42:39 +08:00
}
e.p.x.Unmarshal(data[1:])
montEncode(&e.p.x, &e.p.x)
x3 := e.p.polynomial(&e.p.x)
2023-04-29 13:47:58 +08:00
e.p.y.Sqrt(x3)
2022-07-15 16:42:39 +08:00
montDecode(x3, &e.p.y)
if byte(x3[0]&1) != data[0]&1 {
gfpNeg(&e.p.y, &e.p.y)
}
if e.p.x.Equal(zero) == 1 && e.p.y.Equal(zero) == 1 {
2022-07-15 16:42:39 +08:00
// This is the point at infinity.
e.p.SetInfinity()
2022-07-15 16:42:39 +08:00
} else {
e.p.z.Set(one)
e.p.t.Set(one)
2022-07-15 16:42:39 +08:00
if !e.p.IsOnCurve() {
return nil, errors.New("sm9.G1: malformed point")
}
}
return data[numBytes+1:], nil
}
// fillBytes normalizes e to affine form, allocating its point if needed. It
// then writes the Montgomery-decoded x coordinate into buffer[0:32] and y
// into buffer[32:64].
//
// For the point at infinity nothing is written. Marshal and
// MarshalUncompressed pass freshly allocated (zeroed) buffers, so infinity
// comes out as all zeros.
func (e *G1) fillBytes(buffer []byte) {
	// Each coordinate is a 256-bit number.
	const coordLen = 256 / 8
	if e.p == nil {
		e.p = &curvePoint{}
	}
	e.p.MakeAffine()
	if e.p.IsInfinity() {
		return
	}
	var coord gfP
	montDecode(&coord, &e.p.x)
	coord.Marshal(buffer)
	montDecode(&coord, &e.p.y)
	coord.Marshal(buffer[coordLen:])
}
// Unmarshal sets e to the result of converting the output of Marshal back into
// a group element and then returns e.
func (e *G1) Unmarshal(m []byte) ([]byte, error) {
// Each value is a 256-bit number.
const numBytes = 256 / 8
if len(m) < 2*numBytes {
return nil, errors.New("sm9.G1: not enough data")
}
if e.p == nil {
e.p = &curvePoint{}
} else {
e.p.x.Set(zero)
e.p.y.Set(zero)
2022-07-15 16:42:39 +08:00
}
e.p.x.Unmarshal(m)
e.p.y.Unmarshal(m[numBytes:])
montEncode(&e.p.x, &e.p.x)
montEncode(&e.p.y, &e.p.y)
if e.p.x.Equal(zero) == 1 && e.p.y.Equal(zero) == 1 {
2022-07-15 16:42:39 +08:00
// This is the point at infinity.
e.p.SetInfinity()
2022-07-15 16:42:39 +08:00
} else {
e.p.z.Set(one)
e.p.t.Set(one)
2022-07-15 16:42:39 +08:00
if !e.p.IsOnCurve() {
return nil, errors.New("sm9.G1: malformed point")
}
}
return m[2*numBytes:], nil
}
// Equal compare e and other
func (e *G1) Equal(other *G1) bool {
if e.p == nil && other.p == nil {
return true
}
return e.p.Equal(other.p)
2022-07-15 16:42:39 +08:00
}
// IsOnCurve returns true if e is on the curve.
//
// A G1 whose underlying point has not been allocated is checked as the
// all-zero curvePoint. The other G1 methods allocate that same default
// lazily, so this avoids dereferencing a nil pointer. The check does not
// modify e in that case.
func (e *G1) IsOnCurve() bool {
	if e.p == nil {
		return (&curvePoint{}).IsOnCurve()
	}
	return e.p.IsOnCurve()
}