From f254673618162d8574097ad12880ed468b2c973a Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Mon, 15 Aug 2022 15:16:07 +0800 Subject: [PATCH] sm2ec: sync with sdk --- internal/sm2ec/fiat/generate.go | 36 +++-- internal/sm2ec/fiat/sm2p256.go | 19 ++- internal/sm2ec/generate.go | 245 ++++++++++++++++++++------------ internal/sm2ec/sm2p256.go | 220 ++++++++++++++++------------ 4 files changed, 327 insertions(+), 193 deletions(-) diff --git a/internal/sm2ec/fiat/generate.go b/internal/sm2ec/fiat/generate.go index 555ea26..8ecce9d 100644 --- a/internal/sm2ec/fiat/generate.go +++ b/internal/sm2ec/fiat/generate.go @@ -120,12 +120,16 @@ func main() { const tmplWrapper = `// Copyright 2021 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. + // Code generated by generate.go. DO NOT EDIT. + package fiat + import ( "crypto/subtle" "errors" ) + // {{ .Element }} is an integer modulo {{ .Prime }}. // // The zero value is a valid zero element. @@ -134,30 +138,37 @@ type {{ .Element }} struct { // converted in Bytes and SetBytes. x {{ .Prefix }}MontgomeryDomainFieldElement } + const {{ .Prefix }}ElementLen = {{ .BytesLen }} + type {{ .Prefix }}UntypedFieldElement = {{ .FiatType }} + // One sets e = 1, and returns e. func (e *{{ .Element }}) One() *{{ .Element }} { {{ .Prefix }}SetOne(&e.x) return e } + // Equal returns 1 if e == t, and zero otherwise. func (e *{{ .Element }}) Equal(t *{{ .Element }}) int { eBytes := e.Bytes() tBytes := t.Bytes() return subtle.ConstantTimeCompare(eBytes, tBytes) } -var {{ .Prefix }}ZeroEncoding = new({{ .Element }}).Bytes() + // IsZero returns 1 if e == 0, and zero otherwise. func (e *{{ .Element }}) IsZero() int { + zero := make([]byte, {{ .Prefix }}ElementLen) eBytes := e.Bytes() - return subtle.ConstantTimeCompare(eBytes, {{ .Prefix }}ZeroEncoding) + return subtle.ConstantTimeCompare(eBytes, zero) } + // Set sets e = t, and returns e. func (e *{{ .Element }}) Set(t *{{ .Element }}) *{{ .Element }} { e.x = t.x return e } + // Bytes returns the {{ .BytesLen }}-byte big-endian encoding of e. func (e *{{ .Element }}) Bytes() []byte { // This function is outlined to make the allocations inline in the caller @@ -165,6 +176,7 @@ func (e *{{ .Element }}) Bytes() []byte { var out [{{ .Prefix }}ElementLen]byte return e.bytes(&out) } + func (e *{{ .Element }}) bytes(out *[{{ .Prefix }}ElementLen]byte) []byte { var tmp {{ .Prefix }}NonMontgomeryDomainFieldElement {{ .Prefix }}FromMontgomery(&tmp, &e.x) @@ -172,11 +184,7 @@ func (e *{{ .Element }}) bytes(out *[{{ .Prefix }}ElementLen]byte) []byte { {{ .Prefix }}InvertEndianness(out[:]) return out[:] } -// {{ .Prefix }}MinusOneEncoding is the encoding of -1 mod p, so p - 1, the -// highest canonical encoding. It is used by SetBytes to check for non-canonical -// encodings such as p + k, 2p + k, etc. -var {{ .Prefix }}MinusOneEncoding = new({{ .Element }}).Sub( - new({{ .Element }}), new({{ .Element }}).One()).Bytes() + // SetBytes sets e = v, where v is a big-endian {{ .BytesLen }}-byte encoding, and returns e. // If v is not {{ .BytesLen }} bytes or it encodes a value higher than {{ .Prime }}, // SetBytes returns nil and an error, and e is unchanged. @@ -184,11 +192,15 @@ func (e *{{ .Element }}) SetBytes(v []byte) (*{{ .Element }}, error) { if len(v) != {{ .Prefix }}ElementLen { return nil, errors.New("invalid {{ .Element }} encoding") } + // Check for non-canonical encodings (p + k, 2p + k, etc.) by comparing to + // the encoding of -1 mod p, so p - 1, the highest canonical encoding. + var minusOneEncoding = new({{ .Element }}).Sub( + new({{ .Element }}), new({{ .Element }}).One()).Bytes() for i := range v { - if v[i] < {{ .Prefix }}MinusOneEncoding[i] { + if v[i] < minusOneEncoding[i] { break } - if v[i] > {{ .Prefix }}MinusOneEncoding[i] { + if v[i] > minusOneEncoding[i] { return nil, errors.New("invalid {{ .Element }} encoding") } } @@ -200,32 +212,38 @@ func (e *{{ .Element }}) SetBytes(v []byte) (*{{ .Element }}, error) { {{ .Prefix }}ToMontgomery(&e.x, &tmp) return e, nil } + // Add sets e = t1 + t2, and returns e. func (e *{{ .Element }}) Add(t1, t2 *{{ .Element }}) *{{ .Element }} { {{ .Prefix }}Add(&e.x, &t1.x, &t2.x) return e } + // Sub sets e = t1 - t2, and returns e. func (e *{{ .Element }}) Sub(t1, t2 *{{ .Element }}) *{{ .Element }} { {{ .Prefix }}Sub(&e.x, &t1.x, &t2.x) return e } + // Mul sets e = t1 * t2, and returns e. func (e *{{ .Element }}) Mul(t1, t2 *{{ .Element }}) *{{ .Element }} { {{ .Prefix }}Mul(&e.x, &t1.x, &t2.x) return e } + // Square sets e = t * t, and returns e. func (e *{{ .Element }}) Square(t *{{ .Element }}) *{{ .Element }} { {{ .Prefix }}Square(&e.x, &t.x) return e } + // Select sets v to a if cond == 1, and to b if cond == 0. func (v *{{ .Element }}) Select(a, b *{{ .Element }}, cond int) *{{ .Element }} { {{ .Prefix }}Selectznz((*{{ .Prefix }}UntypedFieldElement)(&v.x), {{ .Prefix }}Uint1(cond), (*{{ .Prefix }}UntypedFieldElement)(&b.x), (*{{ .Prefix }}UntypedFieldElement)(&a.x)) return v } + func {{ .Prefix }}InvertEndianness(v []byte) { for i := 0; i < len(v)/2; i++ { v[i], v[len(v)-1-i] = v[len(v)-1-i], v[i] diff --git a/internal/sm2ec/fiat/sm2p256.go b/internal/sm2ec/fiat/sm2p256.go index 33302ce..78e85bd 100644 --- a/internal/sm2ec/fiat/sm2p256.go +++ b/internal/sm2ec/fiat/sm2p256.go @@ -1,7 +1,9 @@ // Copyright 2021 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. + // Code generated by generate.go. DO NOT EDIT. + package fiat import ( @@ -35,12 +37,11 @@ func (e *SM2P256Element) Equal(t *SM2P256Element) int { return subtle.ConstantTimeCompare(eBytes, tBytes) } -var sm2p256ZeroEncoding = new(SM2P256Element).Bytes() - // IsZero returns 1 if e == 0, and zero otherwise. func (e *SM2P256Element) IsZero() int { + zero := make([]byte, sm2p256ElementLen) eBytes := e.Bytes() - return subtle.ConstantTimeCompare(eBytes, sm2p256ZeroEncoding) + return subtle.ConstantTimeCompare(eBytes, zero) } // Set sets e = t, and returns e. @@ -56,6 +57,7 @@ func (e *SM2P256Element) Bytes() []byte { var out [sm2p256ElementLen]byte return e.bytes(&out) } + func (e *SM2P256Element) bytes(out *[sm2p256ElementLen]byte) []byte { var tmp sm2p256NonMontgomeryDomainFieldElement sm2p256FromMontgomery(&tmp, &e.x) @@ -77,14 +79,20 @@ func (e *SM2P256Element) SetBytes(v []byte) (*SM2P256Element, error) { if len(v) != sm2p256ElementLen { return nil, errors.New("invalid SM2P256Element encoding") } + + // Check for non-canonical encodings (p + k, 2p + k, etc.) by comparing to + // the encoding of -1 mod p, so p - 1, the highest canonical encoding. + var minusOneEncoding = new(SM2P256Element).Sub( + new(SM2P256Element), new(SM2P256Element).One()).Bytes() for i := range v { - if v[i] < sm2p256MinusOneEncoding[i] { + if v[i] < minusOneEncoding[i] { break } - if v[i] > sm2p256MinusOneEncoding[i] { + if v[i] > minusOneEncoding[i] { return nil, errors.New("invalid SM2P256Element encoding") } } + var in [sm2p256ElementLen]byte copy(in[:], v) sm2p256InvertEndianness(in[:]) @@ -124,6 +132,7 @@ func (v *SM2P256Element) Select(a, b *SM2P256Element, cond int) *SM2P256Element (*sm2p256UntypedFieldElement)(&b.x), (*sm2p256UntypedFieldElement)(&a.x)) return v } + func sm2p256InvertEndianness(v []byte) { for i := 0; i < len(v)/2; i++ { v[i], v[len(v)-1-i] = v[len(v)-1-i], v[i] diff --git a/internal/sm2ec/generate.go b/internal/sm2ec/generate.go index 33661fa..5276125 100644 --- a/internal/sm2ec/generate.go +++ b/internal/sm2ec/generate.go @@ -26,6 +26,14 @@ import ( "text/template" ) +func bigFromHex(s string) *big.Int { + b, ok := new(big.Int).SetString(s, 16) + if !ok { + panic("sm2/elliptic: internal error: invalid encoding") + } + return b +} + var curves = []struct { P string Element string @@ -67,7 +75,8 @@ func main() { p := strings.ToLower(c.P) elementLen := (c.Params.BitSize + 7) / 8 B := fmt.Sprintf("%#v", c.Params.B.FillBytes(make([]byte, elementLen))) - G := fmt.Sprintf("%#v", elliptic.Marshal(c.Params, c.Params.Gx, c.Params.Gy)) + Gx := fmt.Sprintf("%#v", c.Params.Gx.FillBytes(make([]byte, elementLen))) + Gy := fmt.Sprintf("%#v", c.Params.Gy.FillBytes(make([]byte, elementLen))) log.Printf("Generating %s.go...", p) f, err := os.Create(p + ".go") @@ -77,7 +86,7 @@ func main() { defer f.Close() buf := &bytes.Buffer{} if err := t.Execute(buf, map[string]interface{}{ - "P": c.P, "p": p, "B": B, "G": G, + "P": c.P, "p": p, "B": B, "Gx": Gx, "Gy": Gy, "Element": c.Element, "ElementLen": elementLen, "BuildTags": c.BuildTags, }); err != nil { @@ -136,28 +145,32 @@ const tmplNISTEC = `// Copyright 2022 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Code generated by generate.go. DO NOT EDIT. + {{ if .BuildTags }} //go:build {{ .BuildTags }} // +build {{ .BuildTags }} {{ end }} + package sm2ec + import ( "github.com/emmansun/gmsm/internal/sm2ec/fiat" "crypto/subtle" "errors" "sync" ) -var {{.p}}B, _ = new({{.Element}}).SetBytes({{.B}}) -var {{.p}}G, _ = New{{.P}}Point().SetBytes({{.G}}) + // {{.p}}ElementLength is the length of an element of the base or scalar field, // which have the same bytes length for all NIST P curves. const {{.p}}ElementLength = {{ .ElementLen }} + // {{.P}}Point is a {{.P}} point. The zero value is NOT valid. type {{.P}}Point struct { // The point is represented in projective coordinates (X:Y:Z), // where x = X/Z and y = Y/Z. x, y, z *{{.Element}} } + // New{{.P}}Point returns a new {{.P}}Point representing the point at infinity point. func New{{.P}}Point() *{{.P}}Point { return &{{.P}}Point{ @@ -166,14 +179,15 @@ func New{{.P}}Point() *{{.P}}Point { z: new({{.Element}}), } } -// New{{.P}}Generator returns a new {{.P}}Point set to the canonical generator. -func New{{.P}}Generator() *{{.P}}Point { - return (&{{.P}}Point{ - x: new({{.Element}}), - y: new({{.Element}}), - z: new({{.Element}}), - }).Set({{.p}}G) + +// SetGenerator sets p to the canonical generator and returns p. +func (p *{{.P}}Point) SetGenerator() *{{.P}}Point { + p.x.SetBytes({{.Gx}}) + p.y.SetBytes({{.Gy}}) + p.z.One() + return p } + // Set sets p = q and returns p. func (p *{{.P}}Point) Set(q *{{.P}}Point) *{{.P}}Point { p.x.Set(q.x) @@ -181,6 +195,7 @@ func (p *{{.P}}Point) Set(q *{{.P}}Point) *{{.P}}Point { p.z.Set(q.z) return p } + // SetBytes sets p to the compressed, uncompressed, or infinity value encoded in // b, as specified in SEC 1, Version 2.0, Section 2.3.4. If the point is not on // the curve, it returns nil and an error, and the receiver is unchanged. @@ -232,15 +247,29 @@ func (p *{{.P}}Point) SetBytes(b []byte) (*{{.P}}Point, error) { return nil, errors.New("invalid {{.P}} point encoding") } } + +var _{{.p}}B *{{.Element}} +var _{{.p}}BOnce sync.Once +func {{.p}}B() *{{.Element}} { + _{{.p}}BOnce.Do(func() { + _{{.p}}B, _ = new({{.Element}}).SetBytes({{.B}}) + }) + return _{{.p}}B +} + // {{.p}}Polynomial sets y2 to x³ - 3x + b, and returns y2. func {{.p}}Polynomial(y2, x *{{.Element}}) *{{.Element}} { y2.Square(x) y2.Mul(y2, x) + threeX := new({{.Element}}).Add(x, x) threeX.Add(threeX, x) + y2.Sub(y2, threeX) - return y2.Add(y2, {{.p}}B) + + return y2.Add(y2, {{.p}}B()) } + func {{.p}}CheckOnCurve(x, y *{{.Element}}) error { // y² = x³ - 3x + b rhs := {{.p}}Polynomial(new({{.Element}}), x) @@ -250,6 +279,7 @@ func {{.p}}CheckOnCurve(x, y *{{.Element}}) error { } return nil } + // Bytes returns the uncompressed or infinity encoding of p, as specified in // SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the point at // infinity is shorter than all other encodings. @@ -259,6 +289,7 @@ func (p *{{.P}}Point) Bytes() []byte { var out [1+2*{{.p}}ElementLength]byte return p.bytes(&out) } + func (p *{{.P}}Point) bytes(out *[1+2*{{.p}}ElementLength]byte) []byte { if p.z.IsZero() == 1 { return append(out[:0], 0) @@ -271,6 +302,25 @@ func (p *{{.P}}Point) bytes(out *[1+2*{{.p}}ElementLength]byte) []byte { buf = append(buf, y.Bytes()...) return buf } + +// BytesX returns the encoding of the x-coordinate of p, as specified in SEC 1, +// Version 2.0, Section 2.3.5, or an error if p is the point at infinity. +func (p *{{.P}}Point) BytesX() ([]byte, error) { + // This function is outlined to make the allocations inline in the caller + // rather than happen on the heap. + var out [{{.p}}ElementLength]byte + return p.bytesX(&out) +} + +func (p *{{.P}}Point) bytesX(out *[{{.p}}ElementLength]byte) ([]byte, error) { + if p.z.IsZero() == 1 { + return nil, errors.New("{{.P}} point is the point at infinity") + } + zinv := new({{.Element}}).Invert(p.z) + x := new({{.Element}}).Mul(p.x, zinv) + return append(out[:0], x.Bytes()...), nil +} + // BytesCompressed returns the compressed or infinity encoding of p, as // specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the // point at infinity is shorter than all other encodings. @@ -280,6 +330,7 @@ func (p *{{.P}}Point) BytesCompressed() []byte { var out [1 + {{.p}}ElementLength]byte return p.bytesCompressed(&out) } + func (p *{{.P}}Point) bytesCompressed(out *[1 + {{.p}}ElementLength]byte) []byte { if p.z.IsZero() == 1 { return append(out[:0], 0) @@ -294,101 +345,106 @@ func (p *{{.P}}Point) bytesCompressed(out *[1 + {{.p}}ElementLength]byte) []byte buf = append(buf, x.Bytes()...) return buf } + // Add sets q = p1 + p2, and returns q. The points may overlap. func (q *{{.P}}Point) Add(p1, p2 *{{.P}}Point) *{{.P}}Point { // Complete addition formula for a = -3 from "Complete addition formulas for // prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2. - t0 := new({{.Element}}).Mul(p1.x, p2.x) // t0 := X1 * X2 - t1 := new({{.Element}}).Mul(p1.y, p2.y) // t1 := Y1 * Y2 - t2 := new({{.Element}}).Mul(p1.z, p2.z) // t2 := Z1 * Z2 - t3 := new({{.Element}}).Add(p1.x, p1.y) // t3 := X1 + Y1 - t4 := new({{.Element}}).Add(p2.x, p2.y) // t4 := X2 + Y2 - t3.Mul(t3, t4) // t3 := t3 * t4 - t4.Add(t0, t1) // t4 := t0 + t1 - t3.Sub(t3, t4) // t3 := t3 - t4 - t4.Add(p1.y, p1.z) // t4 := Y1 + Z1 - x3 := new({{.Element}}).Add(p2.y, p2.z) // X3 := Y2 + Z2 - t4.Mul(t4, x3) // t4 := t4 * X3 - x3.Add(t1, t2) // X3 := t1 + t2 - t4.Sub(t4, x3) // t4 := t4 - X3 - x3.Add(p1.x, p1.z) // X3 := X1 + Z1 - y3 := new({{.Element}}).Add(p2.x, p2.z) // Y3 := X2 + Z2 - x3.Mul(x3, y3) // X3 := X3 * Y3 - y3.Add(t0, t2) // Y3 := t0 + t2 - y3.Sub(x3, y3) // Y3 := X3 - Y3 - z3 := new({{.Element}}).Mul({{.p}}B, t2) // Z3 := b * t2 - x3.Sub(y3, z3) // X3 := Y3 - Z3 - z3.Add(x3, x3) // Z3 := X3 + X3 - x3.Add(x3, z3) // X3 := X3 + Z3 - z3.Sub(t1, x3) // Z3 := t1 - X3 - x3.Add(t1, x3) // X3 := t1 + X3 - y3.Mul({{.p}}B, y3) // Y3 := b * Y3 - t1.Add(t2, t2) // t1 := t2 + t2 - t2.Add(t1, t2) // t2 := t1 + t2 - y3.Sub(y3, t2) // Y3 := Y3 - t2 - y3.Sub(y3, t0) // Y3 := Y3 - t0 - t1.Add(y3, y3) // t1 := Y3 + Y3 - y3.Add(t1, y3) // Y3 := t1 + Y3 - t1.Add(t0, t0) // t1 := t0 + t0 - t0.Add(t1, t0) // t0 := t1 + t0 - t0.Sub(t0, t2) // t0 := t0 - t2 - t1.Mul(t4, y3) // t1 := t4 * Y3 - t2.Mul(t0, y3) // t2 := t0 * Y3 - y3.Mul(x3, z3) // Y3 := X3 * Z3 - y3.Add(y3, t2) // Y3 := Y3 + t2 - x3.Mul(t3, x3) // X3 := t3 * X3 - x3.Sub(x3, t1) // X3 := X3 - t1 - z3.Mul(t4, z3) // Z3 := t4 * Z3 - t1.Mul(t3, t0) // t1 := t3 * t0 - z3.Add(z3, t1) // Z3 := Z3 + t1 + t0 := new({{.Element}}).Mul(p1.x, p2.x) // t0 := X1 * X2 + t1 := new({{.Element}}).Mul(p1.y, p2.y) // t1 := Y1 * Y2 + t2 := new({{.Element}}).Mul(p1.z, p2.z) // t2 := Z1 * Z2 + t3 := new({{.Element}}).Add(p1.x, p1.y) // t3 := X1 + Y1 + t4 := new({{.Element}}).Add(p2.x, p2.y) // t4 := X2 + Y2 + t3.Mul(t3, t4) // t3 := t3 * t4 + t4.Add(t0, t1) // t4 := t0 + t1 + t3.Sub(t3, t4) // t3 := t3 - t4 + t4.Add(p1.y, p1.z) // t4 := Y1 + Z1 + x3 := new({{.Element}}).Add(p2.y, p2.z) // X3 := Y2 + Z2 + t4.Mul(t4, x3) // t4 := t4 * X3 + x3.Add(t1, t2) // X3 := t1 + t2 + t4.Sub(t4, x3) // t4 := t4 - X3 + x3.Add(p1.x, p1.z) // X3 := X1 + Z1 + y3 := new({{.Element}}).Add(p2.x, p2.z) // Y3 := X2 + Z2 + x3.Mul(x3, y3) // X3 := X3 * Y3 + y3.Add(t0, t2) // Y3 := t0 + t2 + y3.Sub(x3, y3) // Y3 := X3 - Y3 + z3 := new({{.Element}}).Mul({{.p}}B(), t2) // Z3 := b * t2 + x3.Sub(y3, z3) // X3 := Y3 - Z3 + z3.Add(x3, x3) // Z3 := X3 + X3 + x3.Add(x3, z3) // X3 := X3 + Z3 + z3.Sub(t1, x3) // Z3 := t1 - X3 + x3.Add(t1, x3) // X3 := t1 + X3 + y3.Mul({{.p}}B(), y3) // Y3 := b * Y3 + t1.Add(t2, t2) // t1 := t2 + t2 + t2.Add(t1, t2) // t2 := t1 + t2 + y3.Sub(y3, t2) // Y3 := Y3 - t2 + y3.Sub(y3, t0) // Y3 := Y3 - t0 + t1.Add(y3, y3) // t1 := Y3 + Y3 + y3.Add(t1, y3) // Y3 := t1 + Y3 + t1.Add(t0, t0) // t1 := t0 + t0 + t0.Add(t1, t0) // t0 := t1 + t0 + t0.Sub(t0, t2) // t0 := t0 - t2 + t1.Mul(t4, y3) // t1 := t4 * Y3 + t2.Mul(t0, y3) // t2 := t0 * Y3 + y3.Mul(x3, z3) // Y3 := X3 * Z3 + y3.Add(y3, t2) // Y3 := Y3 + t2 + x3.Mul(t3, x3) // X3 := t3 * X3 + x3.Sub(x3, t1) // X3 := X3 - t1 + z3.Mul(t4, z3) // Z3 := t4 * Z3 + t1.Mul(t3, t0) // t1 := t3 * t0 + z3.Add(z3, t1) // Z3 := Z3 + t1 + q.x.Set(x3) q.y.Set(y3) q.z.Set(z3) return q } + // Double sets q = p + p, and returns q. The points may overlap. func (q *{{.P}}Point) Double(p *{{.P}}Point) *{{.P}}Point { // Complete addition formula for a = -3 from "Complete addition formulas for // prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2. - t0 := new({{.Element}}).Square(p.x) // t0 := X ^ 2 - t1 := new({{.Element}}).Square(p.y) // t1 := Y ^ 2 - t2 := new({{.Element}}).Square(p.z) // t2 := Z ^ 2 - t3 := new({{.Element}}).Mul(p.x, p.y) // t3 := X * Y - t3.Add(t3, t3) // t3 := t3 + t3 - z3 := new({{.Element}}).Mul(p.x, p.z) // Z3 := X * Z - z3.Add(z3, z3) // Z3 := Z3 + Z3 - y3 := new({{.Element}}).Mul({{.p}}B, t2) // Y3 := b * t2 - y3.Sub(y3, z3) // Y3 := Y3 - Z3 - x3 := new({{.Element}}).Add(y3, y3) // X3 := Y3 + Y3 - y3.Add(x3, y3) // Y3 := X3 + Y3 - x3.Sub(t1, y3) // X3 := t1 - Y3 - y3.Add(t1, y3) // Y3 := t1 + Y3 - y3.Mul(x3, y3) // Y3 := X3 * Y3 - x3.Mul(x3, t3) // X3 := X3 * t3 - t3.Add(t2, t2) // t3 := t2 + t2 - t2.Add(t2, t3) // t2 := t2 + t3 - z3.Mul({{.p}}B, z3) // Z3 := b * Z3 - z3.Sub(z3, t2) // Z3 := Z3 - t2 - z3.Sub(z3, t0) // Z3 := Z3 - t0 - t3.Add(z3, z3) // t3 := Z3 + Z3 - z3.Add(z3, t3) // Z3 := Z3 + t3 - t3.Add(t0, t0) // t3 := t0 + t0 - t0.Add(t3, t0) // t0 := t3 + t0 - t0.Sub(t0, t2) // t0 := t0 - t2 - t0.Mul(t0, z3) // t0 := t0 * Z3 - y3.Add(y3, t0) // Y3 := Y3 + t0 - t0.Mul(p.y, p.z) // t0 := Y * Z - t0.Add(t0, t0) // t0 := t0 + t0 - z3.Mul(t0, z3) // Z3 := t0 * Z3 - x3.Sub(x3, z3) // X3 := X3 - Z3 - z3.Mul(t0, t1) // Z3 := t0 * t1 - z3.Add(z3, z3) // Z3 := Z3 + Z3 - z3.Add(z3, z3) // Z3 := Z3 + Z3 + t0 := new({{.Element}}).Square(p.x) // t0 := X ^ 2 + t1 := new({{.Element}}).Square(p.y) // t1 := Y ^ 2 + t2 := new({{.Element}}).Square(p.z) // t2 := Z ^ 2 + t3 := new({{.Element}}).Mul(p.x, p.y) // t3 := X * Y + t3.Add(t3, t3) // t3 := t3 + t3 + z3 := new({{.Element}}).Mul(p.x, p.z) // Z3 := X * Z + z3.Add(z3, z3) // Z3 := Z3 + Z3 + y3 := new({{.Element}}).Mul({{.p}}B(), t2) // Y3 := b * t2 + y3.Sub(y3, z3) // Y3 := Y3 - Z3 + x3 := new({{.Element}}).Add(y3, y3) // X3 := Y3 + Y3 + y3.Add(x3, y3) // Y3 := X3 + Y3 + x3.Sub(t1, y3) // X3 := t1 - Y3 + y3.Add(t1, y3) // Y3 := t1 + Y3 + y3.Mul(x3, y3) // Y3 := X3 * Y3 + x3.Mul(x3, t3) // X3 := X3 * t3 + t3.Add(t2, t2) // t3 := t2 + t2 + t2.Add(t2, t3) // t2 := t2 + t3 + z3.Mul({{.p}}B(), z3) // Z3 := b * Z3 + z3.Sub(z3, t2) // Z3 := Z3 - t2 + z3.Sub(z3, t0) // Z3 := Z3 - t0 + t3.Add(z3, z3) // t3 := Z3 + Z3 + z3.Add(z3, t3) // Z3 := Z3 + t3 + t3.Add(t0, t0) // t3 := t0 + t0 + t0.Add(t3, t0) // t0 := t3 + t0 + t0.Sub(t0, t2) // t0 := t0 - t2 + t0.Mul(t0, z3) // t0 := t0 * Z3 + y3.Add(y3, t0) // Y3 := Y3 + t0 + t0.Mul(p.y, p.z) // t0 := Y * Z + t0.Add(t0, t0) // t0 := t0 + t0 + z3.Mul(t0, z3) // Z3 := t0 * Z3 + x3.Sub(x3, z3) // X3 := X3 - Z3 + z3.Mul(t0, t1) // Z3 := t0 * t1 + z3.Add(z3, z3) // Z3 := Z3 + Z3 + z3.Add(z3, z3) // Z3 := Z3 + Z3 + q.x.Set(x3) q.y.Set(y3) q.z.Set(z3) return q } + // Select sets q to p1 if cond == 1, and to p2 if cond == 0. func (q *{{.P}}Point) Select(p1, p2 *{{.P}}Point, cond int) *{{.P}}Point { q.x.Select(p1.x, p2.x, cond) @@ -396,10 +452,12 @@ func (q *{{.P}}Point) Select(p1, p2 *{{.P}}Point, cond int) *{{.P}}Point { q.z.Select(p1.z, p2.z, cond) return q } + // A {{.p}}Table holds the first 15 multiples of a point at offset -1, so [1]P // is at table[0], [15]P is at table[14], and [0]P is implicitly the identity // point. type {{.p}}Table [15]*{{.P}}Point + // Select selects the n-th multiple of the table base point into p. It works in // constant time by iterating over every entry of the table. n must be in [0, 15]. func (table *{{.p}}Table) Select(p *{{.P}}Point, n uint8) { @@ -412,6 +470,7 @@ func (table *{{.p}}Table) Select(p *{{.P}}Point, n uint8) { p.Select(table[i-1], p, cond) } } + // ScalarMult sets p = scalar * q, and returns p. func (p *{{.P}}Point) ScalarMult(q *{{.P}}Point, scalar []byte) (*{{.P}}Point, error) { // Compute a {{.p}}Table for the base point q. The explicit New{{.P}}Point @@ -451,15 +510,17 @@ func (p *{{.P}}Point) ScalarMult(q *{{.P}}Point, scalar []byte) (*{{.P}}Point, e } return p, nil } + var {{.p}}GeneratorTable *[{{.p}}ElementLength * 2]{{.p}}Table var {{.p}}GeneratorTableOnce sync.Once + // generatorTable returns a sequence of {{.p}}Tables. The first table contains // multiples of G. Each successive table is the previous table doubled four // times. func (p *{{.P}}Point) generatorTable() *[{{.p}}ElementLength * 2]{{.p}}Table { {{.p}}GeneratorTableOnce.Do(func() { {{.p}}GeneratorTable = new([{{.p}}ElementLength * 2]{{.p}}Table) - base := New{{.P}}Generator() + base := New{{.P}}Point().SetGenerator() for i := 0; i < {{.p}}ElementLength*2; i++ { {{.p}}GeneratorTable[i][0] = New{{.P}}Point().Set(base) for j := 1; j < 15; j++ { @@ -473,6 +534,7 @@ func (p *{{.P}}Point) generatorTable() *[{{.p}}ElementLength * 2]{{.p}}Table { }) return {{.p}}GeneratorTable } + // ScalarBaseMult sets p = scalar * B, where B is the canonical generator, and // returns p. func (p *{{.P}}Point) ScalarBaseMult(scalar []byte) (*{{.P}}Point, error) { @@ -494,13 +556,16 @@ func (p *{{.P}}Point) ScalarBaseMult(scalar []byte) (*{{.P}}Point, error) { tables[tableIndex].Select(t, windowValue) p.Add(p, t) tableIndex-- + windowValue = byte & 0b1111 tables[tableIndex].Select(t, windowValue) p.Add(p, t) tableIndex-- } + return p, nil } + // {{.p}}Sqrt sets e to a square root of x. If x is not a square, {{.p}}Sqrt returns // false and e is unchanged. e and x can overlap. func {{.p}}Sqrt(e, x *{{ .Element }}) (isSquare bool) { diff --git a/internal/sm2ec/sm2p256.go b/internal/sm2ec/sm2p256.go index d9adf3f..e6e9f0e 100644 --- a/internal/sm2ec/sm2p256.go +++ b/internal/sm2ec/sm2p256.go @@ -1,7 +1,9 @@ // Copyright 2022 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. + // Code generated by generate.go. DO NOT EDIT. + //go:build !amd64 && !arm64 || generic // +build !amd64,!arm64 generic @@ -14,9 +16,6 @@ import ( "sync" ) -var sm2p256B, _ = new(fiat.SM2P256Element).SetBytes([]byte{0x28, 0xe9, 0xfa, 0x9e, 0x9d, 0x9f, 0x5e, 0x34, 0x4d, 0x5a, 0x9e, 0x4b, 0xcf, 0x65, 0x9, 0xa7, 0xf3, 0x97, 0x89, 0xf5, 0x15, 0xab, 0x8f, 0x92, 0xdd, 0xbc, 0xbd, 0x41, 0x4d, 0x94, 0xe, 0x93}) -var sm2p256G, _ = NewSM2P256Point().SetBytes([]byte{0x4, 0x32, 0xc4, 0xae, 0x2c, 0x1f, 0x19, 0x81, 0x19, 0x5f, 0x99, 0x4, 0x46, 0x6a, 0x39, 0xc9, 0x94, 0x8f, 0xe3, 0xb, 0xbf, 0xf2, 0x66, 0xb, 0xe1, 0x71, 0x5a, 0x45, 0x89, 0x33, 0x4c, 0x74, 0xc7, 0xbc, 0x37, 0x36, 0xa2, 0xf4, 0xf6, 0x77, 0x9c, 0x59, 0xbd, 0xce, 0xe3, 0x6b, 0x69, 0x21, 0x53, 0xd0, 0xa9, 0x87, 0x7c, 0xc6, 0x2a, 0x47, 0x40, 0x2, 0xdf, 0x32, 0xe5, 0x21, 0x39, 0xf0, 0xa0}) - // sm2p256ElementLength is the length of an element of the base or scalar field, // which have the same bytes length for all NIST P curves. const sm2p256ElementLength = 32 @@ -37,13 +36,12 @@ func NewSM2P256Point() *SM2P256Point { } } -// NewSM2P256Generator returns a new SM2P256Point set to the canonical generator. -func NewSM2P256Generator() *SM2P256Point { - return (&SM2P256Point{ - x: new(fiat.SM2P256Element), - y: new(fiat.SM2P256Element), - z: new(fiat.SM2P256Element), - }).Set(sm2p256G) +// SetGenerator sets p to the canonical generator and returns p. +func (p *SM2P256Point) SetGenerator() *SM2P256Point { + p.x.SetBytes([]byte{0x32, 0xc4, 0xae, 0x2c, 0x1f, 0x19, 0x81, 0x19, 0x5f, 0x99, 0x4, 0x46, 0x6a, 0x39, 0xc9, 0x94, 0x8f, 0xe3, 0xb, 0xbf, 0xf2, 0x66, 0xb, 0xe1, 0x71, 0x5a, 0x45, 0x89, 0x33, 0x4c, 0x74, 0xc7}) + p.y.SetBytes([]byte{0xbc, 0x37, 0x36, 0xa2, 0xf4, 0xf6, 0x77, 0x9c, 0x59, 0xbd, 0xce, 0xe3, 0x6b, 0x69, 0x21, 0x53, 0xd0, 0xa9, 0x87, 0x7c, 0xc6, 0x2a, 0x47, 0x40, 0x2, 0xdf, 0x32, 0xe5, 0x21, 0x39, 0xf0, 0xa0}) + p.z.One() + return p } // Set sets p = q and returns p. @@ -106,15 +104,29 @@ func (p *SM2P256Point) SetBytes(b []byte) (*SM2P256Point, error) { } } +var _sm2p256B *fiat.SM2P256Element +var _sm2p256BOnce sync.Once + +func sm2p256B() *fiat.SM2P256Element { + _sm2p256BOnce.Do(func() { + _sm2p256B, _ = new(fiat.SM2P256Element).SetBytes([]byte{0x28, 0xe9, 0xfa, 0x9e, 0x9d, 0x9f, 0x5e, 0x34, 0x4d, 0x5a, 0x9e, 0x4b, 0xcf, 0x65, 0x9, 0xa7, 0xf3, 0x97, 0x89, 0xf5, 0x15, 0xab, 0x8f, 0x92, 0xdd, 0xbc, 0xbd, 0x41, 0x4d, 0x94, 0xe, 0x93}) + }) + return _sm2p256B +} + // sm2p256Polynomial sets y2 to x³ - 3x + b, and returns y2. func sm2p256Polynomial(y2, x *fiat.SM2P256Element) *fiat.SM2P256Element { y2.Square(x) y2.Mul(y2, x) + threeX := new(fiat.SM2P256Element).Add(x, x) threeX.Add(threeX, x) + y2.Sub(y2, threeX) - return y2.Add(y2, sm2p256B) + + return y2.Add(y2, sm2p256B()) } + func sm2p256CheckOnCurve(x, y *fiat.SM2P256Element) error { // y² = x³ - 3x + b rhs := sm2p256Polynomial(new(fiat.SM2P256Element), x) @@ -134,6 +146,7 @@ func (p *SM2P256Point) Bytes() []byte { var out [1 + 2*sm2p256ElementLength]byte return p.bytes(&out) } + func (p *SM2P256Point) bytes(out *[1 + 2*sm2p256ElementLength]byte) []byte { if p.z.IsZero() == 1 { return append(out[:0], 0) @@ -147,6 +160,24 @@ func (p *SM2P256Point) bytes(out *[1 + 2*sm2p256ElementLength]byte) []byte { return buf } +// BytesX returns the encoding of the x-coordinate of p, as specified in SEC 1, +// Version 2.0, Section 2.3.5, or an error if p is the point at infinity. +func (p *SM2P256Point) BytesX() ([]byte, error) { + // This function is outlined to make the allocations inline in the caller + // rather than happen on the heap. + var out [sm2p256ElementLength]byte + return p.bytesX(&out) +} + +func (p *SM2P256Point) bytesX(out *[sm2p256ElementLength]byte) ([]byte, error) { + if p.z.IsZero() == 1 { + return nil, errors.New("SM2P256 point is the point at infinity") + } + zinv := new(fiat.SM2P256Element).Invert(p.z) + x := new(fiat.SM2P256Element).Mul(p.x, zinv) + return append(out[:0], x.Bytes()...), nil +} + // BytesCompressed returns the compressed or infinity encoding of p, as // specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the // point at infinity is shorter than all other encodings. @@ -156,6 +187,7 @@ func (p *SM2P256Point) BytesCompressed() []byte { var out [1 + sm2p256ElementLength]byte return p.bytesCompressed(&out) } + func (p *SM2P256Point) bytesCompressed(out *[1 + sm2p256ElementLength]byte) []byte { if p.z.IsZero() == 1 { return append(out[:0], 0) @@ -175,49 +207,50 @@ func (p *SM2P256Point) bytesCompressed(out *[1 + sm2p256ElementLength]byte) []by func (q *SM2P256Point) Add(p1, p2 *SM2P256Point) *SM2P256Point { // Complete addition formula for a = -3 from "Complete addition formulas for // prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2. - t0 := new(fiat.SM2P256Element).Mul(p1.x, p2.x) // t0 := X1 * X2 - t1 := new(fiat.SM2P256Element).Mul(p1.y, p2.y) // t1 := Y1 * Y2 - t2 := new(fiat.SM2P256Element).Mul(p1.z, p2.z) // t2 := Z1 * Z2 - t3 := new(fiat.SM2P256Element).Add(p1.x, p1.y) // t3 := X1 + Y1 - t4 := new(fiat.SM2P256Element).Add(p2.x, p2.y) // t4 := X2 + Y2 - t3.Mul(t3, t4) // t3 := t3 * t4 - t4.Add(t0, t1) // t4 := t0 + t1 - t3.Sub(t3, t4) // t3 := t3 - t4 - t4.Add(p1.y, p1.z) // t4 := Y1 + Z1 - x3 := new(fiat.SM2P256Element).Add(p2.y, p2.z) // X3 := Y2 + Z2 - t4.Mul(t4, x3) // t4 := t4 * X3 - x3.Add(t1, t2) // X3 := t1 + t2 - t4.Sub(t4, x3) // t4 := t4 - X3 - x3.Add(p1.x, p1.z) // X3 := X1 + Z1 - y3 := new(fiat.SM2P256Element).Add(p2.x, p2.z) // Y3 := X2 + Z2 - x3.Mul(x3, y3) // X3 := X3 * Y3 - y3.Add(t0, t2) // Y3 := t0 + t2 - y3.Sub(x3, y3) // Y3 := X3 - Y3 - z3 := new(fiat.SM2P256Element).Mul(sm2p256B, t2) // Z3 := b * t2 - x3.Sub(y3, z3) // X3 := Y3 - Z3 - z3.Add(x3, x3) // Z3 := X3 + X3 - x3.Add(x3, z3) // X3 := X3 + Z3 - z3.Sub(t1, x3) // Z3 := t1 - X3 - x3.Add(t1, x3) // X3 := t1 + X3 - y3.Mul(sm2p256B, y3) // Y3 := b * Y3 - t1.Add(t2, t2) // t1 := t2 + t2 - t2.Add(t1, t2) // t2 := t1 + t2 - y3.Sub(y3, t2) // Y3 := Y3 - t2 - y3.Sub(y3, t0) // Y3 := Y3 - t0 - t1.Add(y3, y3) // t1 := Y3 + Y3 - y3.Add(t1, y3) // Y3 := t1 + Y3 - t1.Add(t0, t0) // t1 := t0 + t0 - t0.Add(t1, t0) // t0 := t1 + t0 - t0.Sub(t0, t2) // t0 := t0 - t2 - t1.Mul(t4, y3) // t1 := t4 * Y3 - t2.Mul(t0, y3) // t2 := t0 * Y3 - y3.Mul(x3, z3) // Y3 := X3 * Z3 - y3.Add(y3, t2) // Y3 := Y3 + t2 - x3.Mul(t3, x3) // X3 := t3 * X3 - x3.Sub(x3, t1) // X3 := X3 - t1 - z3.Mul(t4, z3) // Z3 := t4 * Z3 - t1.Mul(t3, t0) // t1 := t3 * t0 - z3.Add(z3, t1) // Z3 := Z3 + t1 + t0 := new(fiat.SM2P256Element).Mul(p1.x, p2.x) // t0 := X1 * X2 + t1 := new(fiat.SM2P256Element).Mul(p1.y, p2.y) // t1 := Y1 * Y2 + t2 := new(fiat.SM2P256Element).Mul(p1.z, p2.z) // t2 := Z1 * Z2 + t3 := new(fiat.SM2P256Element).Add(p1.x, p1.y) // t3 := X1 + Y1 + t4 := new(fiat.SM2P256Element).Add(p2.x, p2.y) // t4 := X2 + Y2 + t3.Mul(t3, t4) // t3 := t3 * t4 + t4.Add(t0, t1) // t4 := t0 + t1 + t3.Sub(t3, t4) // t3 := t3 - t4 + t4.Add(p1.y, p1.z) // t4 := Y1 + Z1 + x3 := new(fiat.SM2P256Element).Add(p2.y, p2.z) // X3 := Y2 + Z2 + t4.Mul(t4, x3) // t4 := t4 * X3 + x3.Add(t1, t2) // X3 := t1 + t2 + t4.Sub(t4, x3) // t4 := t4 - X3 + x3.Add(p1.x, p1.z) // X3 := X1 + Z1 + y3 := new(fiat.SM2P256Element).Add(p2.x, p2.z) // Y3 := X2 + Z2 + x3.Mul(x3, y3) // X3 := X3 * Y3 + y3.Add(t0, t2) // Y3 := t0 + t2 + y3.Sub(x3, y3) // Y3 := X3 - Y3 + z3 := new(fiat.SM2P256Element).Mul(sm2p256B(), t2) // Z3 := b * t2 + x3.Sub(y3, z3) // X3 := Y3 - Z3 + z3.Add(x3, x3) // Z3 := X3 + X3 + x3.Add(x3, z3) // X3 := X3 + Z3 + z3.Sub(t1, x3) // Z3 := t1 - X3 + x3.Add(t1, x3) // X3 := t1 + X3 + y3.Mul(sm2p256B(), y3) // Y3 := b * Y3 + t1.Add(t2, t2) // t1 := t2 + t2 + t2.Add(t1, t2) // t2 := t1 + t2 + y3.Sub(y3, t2) // Y3 := Y3 - t2 + y3.Sub(y3, t0) // Y3 := Y3 - t0 + t1.Add(y3, y3) // t1 := Y3 + Y3 + y3.Add(t1, y3) // Y3 := t1 + Y3 + t1.Add(t0, t0) // t1 := t0 + t0 + t0.Add(t1, t0) // t0 := t1 + t0 + t0.Sub(t0, t2) // t0 := t0 - t2 + t1.Mul(t4, y3) // t1 := t4 * Y3 + t2.Mul(t0, y3) // t2 := t0 * Y3 + y3.Mul(x3, z3) // Y3 := X3 * Z3 + y3.Add(y3, t2) // Y3 := Y3 + t2 + x3.Mul(t3, x3) // X3 := t3 * X3 + x3.Sub(x3, t1) // X3 := X3 - t1 + z3.Mul(t4, z3) // Z3 := t4 * Z3 + t1.Mul(t3, t0) // t1 := t3 * t0 + z3.Add(z3, t1) // Z3 := Z3 + t1 + q.x.Set(x3) q.y.Set(y3) q.z.Set(z3) @@ -228,40 +261,41 @@ func (q *SM2P256Point) Add(p1, p2 *SM2P256Point) *SM2P256Point { func (q *SM2P256Point) Double(p *SM2P256Point) *SM2P256Point { // Complete addition formula for a = -3 from "Complete addition formulas for // prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2. - t0 := new(fiat.SM2P256Element).Square(p.x) // t0 := X ^ 2 - t1 := new(fiat.SM2P256Element).Square(p.y) // t1 := Y ^ 2 - t2 := new(fiat.SM2P256Element).Square(p.z) // t2 := Z ^ 2 - t3 := new(fiat.SM2P256Element).Mul(p.x, p.y) // t3 := X * Y - t3.Add(t3, t3) // t3 := t3 + t3 - z3 := new(fiat.SM2P256Element).Mul(p.x, p.z) // Z3 := X * Z - z3.Add(z3, z3) // Z3 := Z3 + Z3 - y3 := new(fiat.SM2P256Element).Mul(sm2p256B, t2) // Y3 := b * t2 - y3.Sub(y3, z3) // Y3 := Y3 - Z3 - x3 := new(fiat.SM2P256Element).Add(y3, y3) // X3 := Y3 + Y3 - y3.Add(x3, y3) // Y3 := X3 + Y3 - x3.Sub(t1, y3) // X3 := t1 - Y3 - y3.Add(t1, y3) // Y3 := t1 + Y3 - y3.Mul(x3, y3) // Y3 := X3 * Y3 - x3.Mul(x3, t3) // X3 := X3 * t3 - t3.Add(t2, t2) // t3 := t2 + t2 - t2.Add(t2, t3) // t2 := t2 + t3 - z3.Mul(sm2p256B, z3) // Z3 := b * Z3 - z3.Sub(z3, t2) // Z3 := Z3 - t2 - z3.Sub(z3, t0) // Z3 := Z3 - t0 - t3.Add(z3, z3) // t3 := Z3 + Z3 - z3.Add(z3, t3) // Z3 := Z3 + t3 - t3.Add(t0, t0) // t3 := t0 + t0 - t0.Add(t3, t0) // t0 := t3 + t0 - t0.Sub(t0, t2) // t0 := t0 - t2 - t0.Mul(t0, z3) // t0 := t0 * Z3 - y3.Add(y3, t0) // Y3 := Y3 + t0 - t0.Mul(p.y, p.z) // t0 := Y * Z - t0.Add(t0, t0) // t0 := t0 + t0 - z3.Mul(t0, z3) // Z3 := t0 * Z3 - x3.Sub(x3, z3) // X3 := X3 - Z3 - z3.Mul(t0, t1) // Z3 := t0 * t1 - z3.Add(z3, z3) // Z3 := Z3 + Z3 - z3.Add(z3, z3) // Z3 := Z3 + Z3 + t0 := new(fiat.SM2P256Element).Square(p.x) // t0 := X ^ 2 + t1 := new(fiat.SM2P256Element).Square(p.y) // t1 := Y ^ 2 + t2 := new(fiat.SM2P256Element).Square(p.z) // t2 := Z ^ 2 + t3 := new(fiat.SM2P256Element).Mul(p.x, p.y) // t3 := X * Y + t3.Add(t3, t3) // t3 := t3 + t3 + z3 := new(fiat.SM2P256Element).Mul(p.x, p.z) // Z3 := X * Z + z3.Add(z3, z3) // Z3 := Z3 + Z3 + y3 := new(fiat.SM2P256Element).Mul(sm2p256B(), t2) // Y3 := b * t2 + y3.Sub(y3, z3) // Y3 := Y3 - Z3 + x3 := new(fiat.SM2P256Element).Add(y3, y3) // X3 := Y3 + Y3 + y3.Add(x3, y3) // Y3 := X3 + Y3 + x3.Sub(t1, y3) // X3 := t1 - Y3 + y3.Add(t1, y3) // Y3 := t1 + Y3 + y3.Mul(x3, y3) // Y3 := X3 * Y3 + x3.Mul(x3, t3) // X3 := X3 * t3 + t3.Add(t2, t2) // t3 := t2 + t2 + t2.Add(t2, t3) // t2 := t2 + t3 + z3.Mul(sm2p256B(), z3) // Z3 := b * Z3 + z3.Sub(z3, t2) // Z3 := Z3 - t2 + z3.Sub(z3, t0) // Z3 := Z3 - t0 + t3.Add(z3, z3) // t3 := Z3 + Z3 + z3.Add(z3, t3) // Z3 := Z3 + t3 + t3.Add(t0, t0) // t3 := t0 + t0 + t0.Add(t3, t0) // t0 := t3 + t0 + t0.Sub(t0, t2) // t0 := t0 - t2 + t0.Mul(t0, z3) // t0 := t0 * Z3 + y3.Add(y3, t0) // Y3 := Y3 + t0 + t0.Mul(p.y, p.z) // t0 := Y * Z + t0.Add(t0, t0) // t0 := t0 + t0 + z3.Mul(t0, z3) // Z3 := t0 * Z3 + x3.Sub(x3, z3) // X3 := X3 - Z3 + z3.Mul(t0, t1) // Z3 := t0 * t1 + z3.Add(z3, z3) // Z3 := Z3 + Z3 + z3.Add(z3, z3) // Z3 := Z3 + Z3 + q.x.Set(x3) q.y.Set(y3) q.z.Set(z3) @@ -307,6 +341,7 @@ func (p *SM2P256Point) ScalarMult(q *SM2P256Point, scalar []byte) (*SM2P256Point table[i].Double(table[i/2]) table[i+1].Add(table[i], q) } + // Instead of doing the classic double-and-add chain, we do it with a // four-bit window: we double four times, and then add [0-15]P. t := NewSM2P256Point() @@ -320,17 +355,21 @@ func (p *SM2P256Point) ScalarMult(q *SM2P256Point, scalar []byte) (*SM2P256Point p.Double(p) p.Double(p) } + windowValue := byte >> 4 table.Select(t, windowValue) p.Add(p, t) + p.Double(p) p.Double(p) p.Double(p) p.Double(p) + windowValue = byte & 0b1111 table.Select(t, windowValue) p.Add(p, t) } + return p, nil } @@ -343,7 +382,7 @@ var sm2p256GeneratorTableOnce sync.Once func (p *SM2P256Point) generatorTable() *[sm2p256ElementLength * 2]sm2p256Table { sm2p256GeneratorTableOnce.Do(func() { sm2p256GeneratorTable = new([sm2p256ElementLength * 2]sm2p256Table) - base := NewSM2P256Generator() + base := NewSM2P256Point().SetGenerator() for i := 0; i < sm2p256ElementLength*2; i++ { sm2p256GeneratorTable[i][0] = NewSM2P256Point().Set(base) for j := 1; j < 15; j++ { @@ -365,6 +404,7 @@ func (p *SM2P256Point) ScalarBaseMult(scalar []byte) (*SM2P256Point, error) { return nil, errors.New("invalid scalar length") } tables := p.generatorTable() + // This is also a scalar multiplication with a four-bit window like in // ScalarMult, but in this case the doublings are precomputed. The value // [windowValue]G added at iteration k would normally get doubled @@ -379,11 +419,13 @@ func (p *SM2P256Point) ScalarBaseMult(scalar []byte) (*SM2P256Point, error) { tables[tableIndex].Select(t, windowValue) p.Add(p, t) tableIndex-- + windowValue = byte & 0b1111 tables[tableIndex].Select(t, windowValue) p.Add(p, t) tableIndex-- } + return p, nil }