sm2ec: sync with sdk

This commit is contained in:
Sun Yimin 2022-08-15 15:16:07 +08:00 committed by GitHub
parent c37e143c66
commit f254673618
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 327 additions and 193 deletions

View File

@ -120,12 +120,16 @@ func main() {
const tmplWrapper = `// Copyright 2021 The Go Authors. All rights reserved. const tmplWrapper = `// Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Code generated by generate.go. DO NOT EDIT. // Code generated by generate.go. DO NOT EDIT.
package fiat package fiat
import ( import (
"crypto/subtle" "crypto/subtle"
"errors" "errors"
) )
// {{ .Element }} is an integer modulo {{ .Prime }}. // {{ .Element }} is an integer modulo {{ .Prime }}.
// //
// The zero value is a valid zero element. // The zero value is a valid zero element.
@ -134,30 +138,37 @@ type {{ .Element }} struct {
// converted in Bytes and SetBytes. // converted in Bytes and SetBytes.
x {{ .Prefix }}MontgomeryDomainFieldElement x {{ .Prefix }}MontgomeryDomainFieldElement
} }
const {{ .Prefix }}ElementLen = {{ .BytesLen }} const {{ .Prefix }}ElementLen = {{ .BytesLen }}
type {{ .Prefix }}UntypedFieldElement = {{ .FiatType }} type {{ .Prefix }}UntypedFieldElement = {{ .FiatType }}
// One sets e = 1, and returns e. // One sets e = 1, and returns e.
func (e *{{ .Element }}) One() *{{ .Element }} { func (e *{{ .Element }}) One() *{{ .Element }} {
{{ .Prefix }}SetOne(&e.x) {{ .Prefix }}SetOne(&e.x)
return e return e
} }
// Equal returns 1 if e == t, and zero otherwise. // Equal returns 1 if e == t, and zero otherwise.
func (e *{{ .Element }}) Equal(t *{{ .Element }}) int { func (e *{{ .Element }}) Equal(t *{{ .Element }}) int {
eBytes := e.Bytes() eBytes := e.Bytes()
tBytes := t.Bytes() tBytes := t.Bytes()
return subtle.ConstantTimeCompare(eBytes, tBytes) return subtle.ConstantTimeCompare(eBytes, tBytes)
} }
var {{ .Prefix }}ZeroEncoding = new({{ .Element }}).Bytes()
// IsZero returns 1 if e == 0, and zero otherwise. // IsZero returns 1 if e == 0, and zero otherwise.
func (e *{{ .Element }}) IsZero() int { func (e *{{ .Element }}) IsZero() int {
zero := make([]byte, {{ .Prefix }}ElementLen)
eBytes := e.Bytes() eBytes := e.Bytes()
return subtle.ConstantTimeCompare(eBytes, {{ .Prefix }}ZeroEncoding) return subtle.ConstantTimeCompare(eBytes, zero)
} }
// Set sets e = t, and returns e. // Set sets e = t, and returns e.
func (e *{{ .Element }}) Set(t *{{ .Element }}) *{{ .Element }} { func (e *{{ .Element }}) Set(t *{{ .Element }}) *{{ .Element }} {
e.x = t.x e.x = t.x
return e return e
} }
// Bytes returns the {{ .BytesLen }}-byte big-endian encoding of e. // Bytes returns the {{ .BytesLen }}-byte big-endian encoding of e.
func (e *{{ .Element }}) Bytes() []byte { func (e *{{ .Element }}) Bytes() []byte {
// This function is outlined to make the allocations inline in the caller // This function is outlined to make the allocations inline in the caller
@ -165,6 +176,7 @@ func (e *{{ .Element }}) Bytes() []byte {
var out [{{ .Prefix }}ElementLen]byte var out [{{ .Prefix }}ElementLen]byte
return e.bytes(&out) return e.bytes(&out)
} }
func (e *{{ .Element }}) bytes(out *[{{ .Prefix }}ElementLen]byte) []byte { func (e *{{ .Element }}) bytes(out *[{{ .Prefix }}ElementLen]byte) []byte {
var tmp {{ .Prefix }}NonMontgomeryDomainFieldElement var tmp {{ .Prefix }}NonMontgomeryDomainFieldElement
{{ .Prefix }}FromMontgomery(&tmp, &e.x) {{ .Prefix }}FromMontgomery(&tmp, &e.x)
@ -172,11 +184,7 @@ func (e *{{ .Element }}) bytes(out *[{{ .Prefix }}ElementLen]byte) []byte {
{{ .Prefix }}InvertEndianness(out[:]) {{ .Prefix }}InvertEndianness(out[:])
return out[:] return out[:]
} }
// {{ .Prefix }}MinusOneEncoding is the encoding of -1 mod p, so p - 1, the
// highest canonical encoding. It is used by SetBytes to check for non-canonical
// encodings such as p + k, 2p + k, etc.
var {{ .Prefix }}MinusOneEncoding = new({{ .Element }}).Sub(
new({{ .Element }}), new({{ .Element }}).One()).Bytes()
// SetBytes sets e = v, where v is a big-endian {{ .BytesLen }}-byte encoding, and returns e. // SetBytes sets e = v, where v is a big-endian {{ .BytesLen }}-byte encoding, and returns e.
// If v is not {{ .BytesLen }} bytes or it encodes a value higher than {{ .Prime }}, // If v is not {{ .BytesLen }} bytes or it encodes a value higher than {{ .Prime }},
// SetBytes returns nil and an error, and e is unchanged. // SetBytes returns nil and an error, and e is unchanged.
@ -184,11 +192,15 @@ func (e *{{ .Element }}) SetBytes(v []byte) (*{{ .Element }}, error) {
if len(v) != {{ .Prefix }}ElementLen { if len(v) != {{ .Prefix }}ElementLen {
return nil, errors.New("invalid {{ .Element }} encoding") return nil, errors.New("invalid {{ .Element }} encoding")
} }
// Check for non-canonical encodings (p + k, 2p + k, etc.) by comparing to
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
var minusOneEncoding = new({{ .Element }}).Sub(
new({{ .Element }}), new({{ .Element }}).One()).Bytes()
for i := range v { for i := range v {
if v[i] < {{ .Prefix }}MinusOneEncoding[i] { if v[i] < minusOneEncoding[i] {
break break
} }
if v[i] > {{ .Prefix }}MinusOneEncoding[i] { if v[i] > minusOneEncoding[i] {
return nil, errors.New("invalid {{ .Element }} encoding") return nil, errors.New("invalid {{ .Element }} encoding")
} }
} }
@ -200,32 +212,38 @@ func (e *{{ .Element }}) SetBytes(v []byte) (*{{ .Element }}, error) {
{{ .Prefix }}ToMontgomery(&e.x, &tmp) {{ .Prefix }}ToMontgomery(&e.x, &tmp)
return e, nil return e, nil
} }
// Add sets e = t1 + t2, and returns e. // Add sets e = t1 + t2, and returns e.
func (e *{{ .Element }}) Add(t1, t2 *{{ .Element }}) *{{ .Element }} { func (e *{{ .Element }}) Add(t1, t2 *{{ .Element }}) *{{ .Element }} {
{{ .Prefix }}Add(&e.x, &t1.x, &t2.x) {{ .Prefix }}Add(&e.x, &t1.x, &t2.x)
return e return e
} }
// Sub sets e = t1 - t2, and returns e. // Sub sets e = t1 - t2, and returns e.
func (e *{{ .Element }}) Sub(t1, t2 *{{ .Element }}) *{{ .Element }} { func (e *{{ .Element }}) Sub(t1, t2 *{{ .Element }}) *{{ .Element }} {
{{ .Prefix }}Sub(&e.x, &t1.x, &t2.x) {{ .Prefix }}Sub(&e.x, &t1.x, &t2.x)
return e return e
} }
// Mul sets e = t1 * t2, and returns e. // Mul sets e = t1 * t2, and returns e.
func (e *{{ .Element }}) Mul(t1, t2 *{{ .Element }}) *{{ .Element }} { func (e *{{ .Element }}) Mul(t1, t2 *{{ .Element }}) *{{ .Element }} {
{{ .Prefix }}Mul(&e.x, &t1.x, &t2.x) {{ .Prefix }}Mul(&e.x, &t1.x, &t2.x)
return e return e
} }
// Square sets e = t * t, and returns e. // Square sets e = t * t, and returns e.
func (e *{{ .Element }}) Square(t *{{ .Element }}) *{{ .Element }} { func (e *{{ .Element }}) Square(t *{{ .Element }}) *{{ .Element }} {
{{ .Prefix }}Square(&e.x, &t.x) {{ .Prefix }}Square(&e.x, &t.x)
return e return e
} }
// Select sets v to a if cond == 1, and to b if cond == 0. // Select sets v to a if cond == 1, and to b if cond == 0.
func (v *{{ .Element }}) Select(a, b *{{ .Element }}, cond int) *{{ .Element }} { func (v *{{ .Element }}) Select(a, b *{{ .Element }}, cond int) *{{ .Element }} {
{{ .Prefix }}Selectznz((*{{ .Prefix }}UntypedFieldElement)(&v.x), {{ .Prefix }}Uint1(cond), {{ .Prefix }}Selectznz((*{{ .Prefix }}UntypedFieldElement)(&v.x), {{ .Prefix }}Uint1(cond),
(*{{ .Prefix }}UntypedFieldElement)(&b.x), (*{{ .Prefix }}UntypedFieldElement)(&a.x)) (*{{ .Prefix }}UntypedFieldElement)(&b.x), (*{{ .Prefix }}UntypedFieldElement)(&a.x))
return v return v
} }
func {{ .Prefix }}InvertEndianness(v []byte) { func {{ .Prefix }}InvertEndianness(v []byte) {
for i := 0; i < len(v)/2; i++ { for i := 0; i < len(v)/2; i++ {
v[i], v[len(v)-1-i] = v[len(v)-1-i], v[i] v[i], v[len(v)-1-i] = v[len(v)-1-i], v[i]

View File

@ -1,7 +1,9 @@
// Copyright 2021 The Go Authors. All rights reserved. // Copyright 2021 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Code generated by generate.go. DO NOT EDIT. // Code generated by generate.go. DO NOT EDIT.
package fiat package fiat
import ( import (
@ -35,12 +37,11 @@ func (e *SM2P256Element) Equal(t *SM2P256Element) int {
return subtle.ConstantTimeCompare(eBytes, tBytes) return subtle.ConstantTimeCompare(eBytes, tBytes)
} }
var sm2p256ZeroEncoding = new(SM2P256Element).Bytes()
// IsZero returns 1 if e == 0, and zero otherwise. // IsZero returns 1 if e == 0, and zero otherwise.
func (e *SM2P256Element) IsZero() int { func (e *SM2P256Element) IsZero() int {
zero := make([]byte, sm2p256ElementLen)
eBytes := e.Bytes() eBytes := e.Bytes()
return subtle.ConstantTimeCompare(eBytes, sm2p256ZeroEncoding) return subtle.ConstantTimeCompare(eBytes, zero)
} }
// Set sets e = t, and returns e. // Set sets e = t, and returns e.
@ -56,6 +57,7 @@ func (e *SM2P256Element) Bytes() []byte {
var out [sm2p256ElementLen]byte var out [sm2p256ElementLen]byte
return e.bytes(&out) return e.bytes(&out)
} }
func (e *SM2P256Element) bytes(out *[sm2p256ElementLen]byte) []byte { func (e *SM2P256Element) bytes(out *[sm2p256ElementLen]byte) []byte {
var tmp sm2p256NonMontgomeryDomainFieldElement var tmp sm2p256NonMontgomeryDomainFieldElement
sm2p256FromMontgomery(&tmp, &e.x) sm2p256FromMontgomery(&tmp, &e.x)
@ -77,14 +79,20 @@ func (e *SM2P256Element) SetBytes(v []byte) (*SM2P256Element, error) {
if len(v) != sm2p256ElementLen { if len(v) != sm2p256ElementLen {
return nil, errors.New("invalid SM2P256Element encoding") return nil, errors.New("invalid SM2P256Element encoding")
} }
// Check for non-canonical encodings (p + k, 2p + k, etc.) by comparing to
// the encoding of -1 mod p, so p - 1, the highest canonical encoding.
var minusOneEncoding = new(SM2P256Element).Sub(
new(SM2P256Element), new(SM2P256Element).One()).Bytes()
for i := range v { for i := range v {
if v[i] < sm2p256MinusOneEncoding[i] { if v[i] < minusOneEncoding[i] {
break break
} }
if v[i] > sm2p256MinusOneEncoding[i] { if v[i] > minusOneEncoding[i] {
return nil, errors.New("invalid SM2P256Element encoding") return nil, errors.New("invalid SM2P256Element encoding")
} }
} }
var in [sm2p256ElementLen]byte var in [sm2p256ElementLen]byte
copy(in[:], v) copy(in[:], v)
sm2p256InvertEndianness(in[:]) sm2p256InvertEndianness(in[:])
@ -124,6 +132,7 @@ func (v *SM2P256Element) Select(a, b *SM2P256Element, cond int) *SM2P256Element
(*sm2p256UntypedFieldElement)(&b.x), (*sm2p256UntypedFieldElement)(&a.x)) (*sm2p256UntypedFieldElement)(&b.x), (*sm2p256UntypedFieldElement)(&a.x))
return v return v
} }
func sm2p256InvertEndianness(v []byte) { func sm2p256InvertEndianness(v []byte) {
for i := 0; i < len(v)/2; i++ { for i := 0; i < len(v)/2; i++ {
v[i], v[len(v)-1-i] = v[len(v)-1-i], v[i] v[i], v[len(v)-1-i] = v[len(v)-1-i], v[i]

View File

@ -26,6 +26,14 @@ import (
"text/template" "text/template"
) )
func bigFromHex(s string) *big.Int {
b, ok := new(big.Int).SetString(s, 16)
if !ok {
panic("sm2/elliptic: internal error: invalid encoding")
}
return b
}
var curves = []struct { var curves = []struct {
P string P string
Element string Element string
@ -67,7 +75,8 @@ func main() {
p := strings.ToLower(c.P) p := strings.ToLower(c.P)
elementLen := (c.Params.BitSize + 7) / 8 elementLen := (c.Params.BitSize + 7) / 8
B := fmt.Sprintf("%#v", c.Params.B.FillBytes(make([]byte, elementLen))) B := fmt.Sprintf("%#v", c.Params.B.FillBytes(make([]byte, elementLen)))
G := fmt.Sprintf("%#v", elliptic.Marshal(c.Params, c.Params.Gx, c.Params.Gy)) Gx := fmt.Sprintf("%#v", c.Params.Gx.FillBytes(make([]byte, elementLen)))
Gy := fmt.Sprintf("%#v", c.Params.Gy.FillBytes(make([]byte, elementLen)))
log.Printf("Generating %s.go...", p) log.Printf("Generating %s.go...", p)
f, err := os.Create(p + ".go") f, err := os.Create(p + ".go")
@ -77,7 +86,7 @@ func main() {
defer f.Close() defer f.Close()
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
if err := t.Execute(buf, map[string]interface{}{ if err := t.Execute(buf, map[string]interface{}{
"P": c.P, "p": p, "B": B, "G": G, "P": c.P, "p": p, "B": B, "Gx": Gx, "Gy": Gy,
"Element": c.Element, "ElementLen": elementLen, "Element": c.Element, "ElementLen": elementLen,
"BuildTags": c.BuildTags, "BuildTags": c.BuildTags,
}); err != nil { }); err != nil {
@ -136,28 +145,32 @@ const tmplNISTEC = `// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Code generated by generate.go. DO NOT EDIT. // Code generated by generate.go. DO NOT EDIT.
{{ if .BuildTags }} {{ if .BuildTags }}
//go:build {{ .BuildTags }} //go:build {{ .BuildTags }}
// +build {{ .BuildTags }} // +build {{ .BuildTags }}
{{ end }} {{ end }}
package sm2ec package sm2ec
import ( import (
"github.com/emmansun/gmsm/internal/sm2ec/fiat" "github.com/emmansun/gmsm/internal/sm2ec/fiat"
"crypto/subtle" "crypto/subtle"
"errors" "errors"
"sync" "sync"
) )
var {{.p}}B, _ = new({{.Element}}).SetBytes({{.B}})
var {{.p}}G, _ = New{{.P}}Point().SetBytes({{.G}})
// {{.p}}ElementLength is the length of an element of the base or scalar field, // {{.p}}ElementLength is the length of an element of the base or scalar field,
// which have the same bytes length for all NIST P curves. // which have the same bytes length for all NIST P curves.
const {{.p}}ElementLength = {{ .ElementLen }} const {{.p}}ElementLength = {{ .ElementLen }}
// {{.P}}Point is a {{.P}} point. The zero value is NOT valid. // {{.P}}Point is a {{.P}} point. The zero value is NOT valid.
type {{.P}}Point struct { type {{.P}}Point struct {
// The point is represented in projective coordinates (X:Y:Z), // The point is represented in projective coordinates (X:Y:Z),
// where x = X/Z and y = Y/Z. // where x = X/Z and y = Y/Z.
x, y, z *{{.Element}} x, y, z *{{.Element}}
} }
// New{{.P}}Point returns a new {{.P}}Point representing the point at infinity point. // New{{.P}}Point returns a new {{.P}}Point representing the point at infinity point.
func New{{.P}}Point() *{{.P}}Point { func New{{.P}}Point() *{{.P}}Point {
return &{{.P}}Point{ return &{{.P}}Point{
@ -166,14 +179,15 @@ func New{{.P}}Point() *{{.P}}Point {
z: new({{.Element}}), z: new({{.Element}}),
} }
} }
// New{{.P}}Generator returns a new {{.P}}Point set to the canonical generator.
func New{{.P}}Generator() *{{.P}}Point { // SetGenerator sets p to the canonical generator and returns p.
return (&{{.P}}Point{ func (p *{{.P}}Point) SetGenerator() *{{.P}}Point {
x: new({{.Element}}), p.x.SetBytes({{.Gx}})
y: new({{.Element}}), p.y.SetBytes({{.Gy}})
z: new({{.Element}}), p.z.One()
}).Set({{.p}}G) return p
} }
// Set sets p = q and returns p. // Set sets p = q and returns p.
func (p *{{.P}}Point) Set(q *{{.P}}Point) *{{.P}}Point { func (p *{{.P}}Point) Set(q *{{.P}}Point) *{{.P}}Point {
p.x.Set(q.x) p.x.Set(q.x)
@ -181,6 +195,7 @@ func (p *{{.P}}Point) Set(q *{{.P}}Point) *{{.P}}Point {
p.z.Set(q.z) p.z.Set(q.z)
return p return p
} }
// SetBytes sets p to the compressed, uncompressed, or infinity value encoded in // SetBytes sets p to the compressed, uncompressed, or infinity value encoded in
// b, as specified in SEC 1, Version 2.0, Section 2.3.4. If the point is not on // b, as specified in SEC 1, Version 2.0, Section 2.3.4. If the point is not on
// the curve, it returns nil and an error, and the receiver is unchanged. // the curve, it returns nil and an error, and the receiver is unchanged.
@ -232,15 +247,29 @@ func (p *{{.P}}Point) SetBytes(b []byte) (*{{.P}}Point, error) {
return nil, errors.New("invalid {{.P}} point encoding") return nil, errors.New("invalid {{.P}} point encoding")
} }
} }
var _{{.p}}B *{{.Element}}
var _{{.p}}BOnce sync.Once
func {{.p}}B() *{{.Element}} {
_{{.p}}BOnce.Do(func() {
_{{.p}}B, _ = new({{.Element}}).SetBytes({{.B}})
})
return _{{.p}}B
}
// {{.p}}Polynomial sets y2 to x³ - 3x + b, and returns y2. // {{.p}}Polynomial sets y2 to x³ - 3x + b, and returns y2.
func {{.p}}Polynomial(y2, x *{{.Element}}) *{{.Element}} { func {{.p}}Polynomial(y2, x *{{.Element}}) *{{.Element}} {
y2.Square(x) y2.Square(x)
y2.Mul(y2, x) y2.Mul(y2, x)
threeX := new({{.Element}}).Add(x, x) threeX := new({{.Element}}).Add(x, x)
threeX.Add(threeX, x) threeX.Add(threeX, x)
y2.Sub(y2, threeX) y2.Sub(y2, threeX)
return y2.Add(y2, {{.p}}B)
return y2.Add(y2, {{.p}}B())
} }
func {{.p}}CheckOnCurve(x, y *{{.Element}}) error { func {{.p}}CheckOnCurve(x, y *{{.Element}}) error {
// y² = x³ - 3x + b // y² = x³ - 3x + b
rhs := {{.p}}Polynomial(new({{.Element}}), x) rhs := {{.p}}Polynomial(new({{.Element}}), x)
@ -250,6 +279,7 @@ func {{.p}}CheckOnCurve(x, y *{{.Element}}) error {
} }
return nil return nil
} }
// Bytes returns the uncompressed or infinity encoding of p, as specified in // Bytes returns the uncompressed or infinity encoding of p, as specified in
// SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the point at // SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the point at
// infinity is shorter than all other encodings. // infinity is shorter than all other encodings.
@ -259,6 +289,7 @@ func (p *{{.P}}Point) Bytes() []byte {
var out [1+2*{{.p}}ElementLength]byte var out [1+2*{{.p}}ElementLength]byte
return p.bytes(&out) return p.bytes(&out)
} }
func (p *{{.P}}Point) bytes(out *[1+2*{{.p}}ElementLength]byte) []byte { func (p *{{.P}}Point) bytes(out *[1+2*{{.p}}ElementLength]byte) []byte {
if p.z.IsZero() == 1 { if p.z.IsZero() == 1 {
return append(out[:0], 0) return append(out[:0], 0)
@ -271,6 +302,25 @@ func (p *{{.P}}Point) bytes(out *[1+2*{{.p}}ElementLength]byte) []byte {
buf = append(buf, y.Bytes()...) buf = append(buf, y.Bytes()...)
return buf return buf
} }
// BytesX returns the encoding of the x-coordinate of p, as specified in SEC 1,
// Version 2.0, Section 2.3.5, or an error if p is the point at infinity.
func (p *{{.P}}Point) BytesX() ([]byte, error) {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [{{.p}}ElementLength]byte
return p.bytesX(&out)
}
func (p *{{.P}}Point) bytesX(out *[{{.p}}ElementLength]byte) ([]byte, error) {
if p.z.IsZero() == 1 {
return nil, errors.New("{{.P}} point is the point at infinity")
}
zinv := new({{.Element}}).Invert(p.z)
x := new({{.Element}}).Mul(p.x, zinv)
return append(out[:0], x.Bytes()...), nil
}
// BytesCompressed returns the compressed or infinity encoding of p, as // BytesCompressed returns the compressed or infinity encoding of p, as
// specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the // specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the
// point at infinity is shorter than all other encodings. // point at infinity is shorter than all other encodings.
@ -280,6 +330,7 @@ func (p *{{.P}}Point) BytesCompressed() []byte {
var out [1 + {{.p}}ElementLength]byte var out [1 + {{.p}}ElementLength]byte
return p.bytesCompressed(&out) return p.bytesCompressed(&out)
} }
func (p *{{.P}}Point) bytesCompressed(out *[1 + {{.p}}ElementLength]byte) []byte { func (p *{{.P}}Point) bytesCompressed(out *[1 + {{.p}}ElementLength]byte) []byte {
if p.z.IsZero() == 1 { if p.z.IsZero() == 1 {
return append(out[:0], 0) return append(out[:0], 0)
@ -294,101 +345,106 @@ func (p *{{.P}}Point) bytesCompressed(out *[1 + {{.p}}ElementLength]byte) []byte
buf = append(buf, x.Bytes()...) buf = append(buf, x.Bytes()...)
return buf return buf
} }
// Add sets q = p1 + p2, and returns q. The points may overlap. // Add sets q = p1 + p2, and returns q. The points may overlap.
func (q *{{.P}}Point) Add(p1, p2 *{{.P}}Point) *{{.P}}Point { func (q *{{.P}}Point) Add(p1, p2 *{{.P}}Point) *{{.P}}Point {
// Complete addition formula for a = -3 from "Complete addition formulas for // Complete addition formula for a = -3 from "Complete addition formulas for
// prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2. // prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2.
t0 := new({{.Element}}).Mul(p1.x, p2.x) // t0 := X1 * X2 t0 := new({{.Element}}).Mul(p1.x, p2.x) // t0 := X1 * X2
t1 := new({{.Element}}).Mul(p1.y, p2.y) // t1 := Y1 * Y2 t1 := new({{.Element}}).Mul(p1.y, p2.y) // t1 := Y1 * Y2
t2 := new({{.Element}}).Mul(p1.z, p2.z) // t2 := Z1 * Z2 t2 := new({{.Element}}).Mul(p1.z, p2.z) // t2 := Z1 * Z2
t3 := new({{.Element}}).Add(p1.x, p1.y) // t3 := X1 + Y1 t3 := new({{.Element}}).Add(p1.x, p1.y) // t3 := X1 + Y1
t4 := new({{.Element}}).Add(p2.x, p2.y) // t4 := X2 + Y2 t4 := new({{.Element}}).Add(p2.x, p2.y) // t4 := X2 + Y2
t3.Mul(t3, t4) // t3 := t3 * t4 t3.Mul(t3, t4) // t3 := t3 * t4
t4.Add(t0, t1) // t4 := t0 + t1 t4.Add(t0, t1) // t4 := t0 + t1
t3.Sub(t3, t4) // t3 := t3 - t4 t3.Sub(t3, t4) // t3 := t3 - t4
t4.Add(p1.y, p1.z) // t4 := Y1 + Z1 t4.Add(p1.y, p1.z) // t4 := Y1 + Z1
x3 := new({{.Element}}).Add(p2.y, p2.z) // X3 := Y2 + Z2 x3 := new({{.Element}}).Add(p2.y, p2.z) // X3 := Y2 + Z2
t4.Mul(t4, x3) // t4 := t4 * X3 t4.Mul(t4, x3) // t4 := t4 * X3
x3.Add(t1, t2) // X3 := t1 + t2 x3.Add(t1, t2) // X3 := t1 + t2
t4.Sub(t4, x3) // t4 := t4 - X3 t4.Sub(t4, x3) // t4 := t4 - X3
x3.Add(p1.x, p1.z) // X3 := X1 + Z1 x3.Add(p1.x, p1.z) // X3 := X1 + Z1
y3 := new({{.Element}}).Add(p2.x, p2.z) // Y3 := X2 + Z2 y3 := new({{.Element}}).Add(p2.x, p2.z) // Y3 := X2 + Z2
x3.Mul(x3, y3) // X3 := X3 * Y3 x3.Mul(x3, y3) // X3 := X3 * Y3
y3.Add(t0, t2) // Y3 := t0 + t2 y3.Add(t0, t2) // Y3 := t0 + t2
y3.Sub(x3, y3) // Y3 := X3 - Y3 y3.Sub(x3, y3) // Y3 := X3 - Y3
z3 := new({{.Element}}).Mul({{.p}}B, t2) // Z3 := b * t2 z3 := new({{.Element}}).Mul({{.p}}B(), t2) // Z3 := b * t2
x3.Sub(y3, z3) // X3 := Y3 - Z3 x3.Sub(y3, z3) // X3 := Y3 - Z3
z3.Add(x3, x3) // Z3 := X3 + X3 z3.Add(x3, x3) // Z3 := X3 + X3
x3.Add(x3, z3) // X3 := X3 + Z3 x3.Add(x3, z3) // X3 := X3 + Z3
z3.Sub(t1, x3) // Z3 := t1 - X3 z3.Sub(t1, x3) // Z3 := t1 - X3
x3.Add(t1, x3) // X3 := t1 + X3 x3.Add(t1, x3) // X3 := t1 + X3
y3.Mul({{.p}}B, y3) // Y3 := b * Y3 y3.Mul({{.p}}B(), y3) // Y3 := b * Y3
t1.Add(t2, t2) // t1 := t2 + t2 t1.Add(t2, t2) // t1 := t2 + t2
t2.Add(t1, t2) // t2 := t1 + t2 t2.Add(t1, t2) // t2 := t1 + t2
y3.Sub(y3, t2) // Y3 := Y3 - t2 y3.Sub(y3, t2) // Y3 := Y3 - t2
y3.Sub(y3, t0) // Y3 := Y3 - t0 y3.Sub(y3, t0) // Y3 := Y3 - t0
t1.Add(y3, y3) // t1 := Y3 + Y3 t1.Add(y3, y3) // t1 := Y3 + Y3
y3.Add(t1, y3) // Y3 := t1 + Y3 y3.Add(t1, y3) // Y3 := t1 + Y3
t1.Add(t0, t0) // t1 := t0 + t0 t1.Add(t0, t0) // t1 := t0 + t0
t0.Add(t1, t0) // t0 := t1 + t0 t0.Add(t1, t0) // t0 := t1 + t0
t0.Sub(t0, t2) // t0 := t0 - t2 t0.Sub(t0, t2) // t0 := t0 - t2
t1.Mul(t4, y3) // t1 := t4 * Y3 t1.Mul(t4, y3) // t1 := t4 * Y3
t2.Mul(t0, y3) // t2 := t0 * Y3 t2.Mul(t0, y3) // t2 := t0 * Y3
y3.Mul(x3, z3) // Y3 := X3 * Z3 y3.Mul(x3, z3) // Y3 := X3 * Z3
y3.Add(y3, t2) // Y3 := Y3 + t2 y3.Add(y3, t2) // Y3 := Y3 + t2
x3.Mul(t3, x3) // X3 := t3 * X3 x3.Mul(t3, x3) // X3 := t3 * X3
x3.Sub(x3, t1) // X3 := X3 - t1 x3.Sub(x3, t1) // X3 := X3 - t1
z3.Mul(t4, z3) // Z3 := t4 * Z3 z3.Mul(t4, z3) // Z3 := t4 * Z3
t1.Mul(t3, t0) // t1 := t3 * t0 t1.Mul(t3, t0) // t1 := t3 * t0
z3.Add(z3, t1) // Z3 := Z3 + t1 z3.Add(z3, t1) // Z3 := Z3 + t1
q.x.Set(x3) q.x.Set(x3)
q.y.Set(y3) q.y.Set(y3)
q.z.Set(z3) q.z.Set(z3)
return q return q
} }
// Double sets q = p + p, and returns q. The points may overlap. // Double sets q = p + p, and returns q. The points may overlap.
func (q *{{.P}}Point) Double(p *{{.P}}Point) *{{.P}}Point { func (q *{{.P}}Point) Double(p *{{.P}}Point) *{{.P}}Point {
// Complete addition formula for a = -3 from "Complete addition formulas for // Complete addition formula for a = -3 from "Complete addition formulas for
// prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2. // prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2.
t0 := new({{.Element}}).Square(p.x) // t0 := X ^ 2 t0 := new({{.Element}}).Square(p.x) // t0 := X ^ 2
t1 := new({{.Element}}).Square(p.y) // t1 := Y ^ 2 t1 := new({{.Element}}).Square(p.y) // t1 := Y ^ 2
t2 := new({{.Element}}).Square(p.z) // t2 := Z ^ 2 t2 := new({{.Element}}).Square(p.z) // t2 := Z ^ 2
t3 := new({{.Element}}).Mul(p.x, p.y) // t3 := X * Y t3 := new({{.Element}}).Mul(p.x, p.y) // t3 := X * Y
t3.Add(t3, t3) // t3 := t3 + t3 t3.Add(t3, t3) // t3 := t3 + t3
z3 := new({{.Element}}).Mul(p.x, p.z) // Z3 := X * Z z3 := new({{.Element}}).Mul(p.x, p.z) // Z3 := X * Z
z3.Add(z3, z3) // Z3 := Z3 + Z3 z3.Add(z3, z3) // Z3 := Z3 + Z3
y3 := new({{.Element}}).Mul({{.p}}B, t2) // Y3 := b * t2 y3 := new({{.Element}}).Mul({{.p}}B(), t2) // Y3 := b * t2
y3.Sub(y3, z3) // Y3 := Y3 - Z3 y3.Sub(y3, z3) // Y3 := Y3 - Z3
x3 := new({{.Element}}).Add(y3, y3) // X3 := Y3 + Y3 x3 := new({{.Element}}).Add(y3, y3) // X3 := Y3 + Y3
y3.Add(x3, y3) // Y3 := X3 + Y3 y3.Add(x3, y3) // Y3 := X3 + Y3
x3.Sub(t1, y3) // X3 := t1 - Y3 x3.Sub(t1, y3) // X3 := t1 - Y3
y3.Add(t1, y3) // Y3 := t1 + Y3 y3.Add(t1, y3) // Y3 := t1 + Y3
y3.Mul(x3, y3) // Y3 := X3 * Y3 y3.Mul(x3, y3) // Y3 := X3 * Y3
x3.Mul(x3, t3) // X3 := X3 * t3 x3.Mul(x3, t3) // X3 := X3 * t3
t3.Add(t2, t2) // t3 := t2 + t2 t3.Add(t2, t2) // t3 := t2 + t2
t2.Add(t2, t3) // t2 := t2 + t3 t2.Add(t2, t3) // t2 := t2 + t3
z3.Mul({{.p}}B, z3) // Z3 := b * Z3 z3.Mul({{.p}}B(), z3) // Z3 := b * Z3
z3.Sub(z3, t2) // Z3 := Z3 - t2 z3.Sub(z3, t2) // Z3 := Z3 - t2
z3.Sub(z3, t0) // Z3 := Z3 - t0 z3.Sub(z3, t0) // Z3 := Z3 - t0
t3.Add(z3, z3) // t3 := Z3 + Z3 t3.Add(z3, z3) // t3 := Z3 + Z3
z3.Add(z3, t3) // Z3 := Z3 + t3 z3.Add(z3, t3) // Z3 := Z3 + t3
t3.Add(t0, t0) // t3 := t0 + t0 t3.Add(t0, t0) // t3 := t0 + t0
t0.Add(t3, t0) // t0 := t3 + t0 t0.Add(t3, t0) // t0 := t3 + t0
t0.Sub(t0, t2) // t0 := t0 - t2 t0.Sub(t0, t2) // t0 := t0 - t2
t0.Mul(t0, z3) // t0 := t0 * Z3 t0.Mul(t0, z3) // t0 := t0 * Z3
y3.Add(y3, t0) // Y3 := Y3 + t0 y3.Add(y3, t0) // Y3 := Y3 + t0
t0.Mul(p.y, p.z) // t0 := Y * Z t0.Mul(p.y, p.z) // t0 := Y * Z
t0.Add(t0, t0) // t0 := t0 + t0 t0.Add(t0, t0) // t0 := t0 + t0
z3.Mul(t0, z3) // Z3 := t0 * Z3 z3.Mul(t0, z3) // Z3 := t0 * Z3
x3.Sub(x3, z3) // X3 := X3 - Z3 x3.Sub(x3, z3) // X3 := X3 - Z3
z3.Mul(t0, t1) // Z3 := t0 * t1 z3.Mul(t0, t1) // Z3 := t0 * t1
z3.Add(z3, z3) // Z3 := Z3 + Z3 z3.Add(z3, z3) // Z3 := Z3 + Z3
z3.Add(z3, z3) // Z3 := Z3 + Z3 z3.Add(z3, z3) // Z3 := Z3 + Z3
q.x.Set(x3) q.x.Set(x3)
q.y.Set(y3) q.y.Set(y3)
q.z.Set(z3) q.z.Set(z3)
return q return q
} }
// Select sets q to p1 if cond == 1, and to p2 if cond == 0. // Select sets q to p1 if cond == 1, and to p2 if cond == 0.
func (q *{{.P}}Point) Select(p1, p2 *{{.P}}Point, cond int) *{{.P}}Point { func (q *{{.P}}Point) Select(p1, p2 *{{.P}}Point, cond int) *{{.P}}Point {
q.x.Select(p1.x, p2.x, cond) q.x.Select(p1.x, p2.x, cond)
@ -396,10 +452,12 @@ func (q *{{.P}}Point) Select(p1, p2 *{{.P}}Point, cond int) *{{.P}}Point {
q.z.Select(p1.z, p2.z, cond) q.z.Select(p1.z, p2.z, cond)
return q return q
} }
// A {{.p}}Table holds the first 15 multiples of a point at offset -1, so [1]P // A {{.p}}Table holds the first 15 multiples of a point at offset -1, so [1]P
// is at table[0], [15]P is at table[14], and [0]P is implicitly the identity // is at table[0], [15]P is at table[14], and [0]P is implicitly the identity
// point. // point.
type {{.p}}Table [15]*{{.P}}Point type {{.p}}Table [15]*{{.P}}Point
// Select selects the n-th multiple of the table base point into p. It works in // Select selects the n-th multiple of the table base point into p. It works in
// constant time by iterating over every entry of the table. n must be in [0, 15]. // constant time by iterating over every entry of the table. n must be in [0, 15].
func (table *{{.p}}Table) Select(p *{{.P}}Point, n uint8) { func (table *{{.p}}Table) Select(p *{{.P}}Point, n uint8) {
@ -412,6 +470,7 @@ func (table *{{.p}}Table) Select(p *{{.P}}Point, n uint8) {
p.Select(table[i-1], p, cond) p.Select(table[i-1], p, cond)
} }
} }
// ScalarMult sets p = scalar * q, and returns p. // ScalarMult sets p = scalar * q, and returns p.
func (p *{{.P}}Point) ScalarMult(q *{{.P}}Point, scalar []byte) (*{{.P}}Point, error) { func (p *{{.P}}Point) ScalarMult(q *{{.P}}Point, scalar []byte) (*{{.P}}Point, error) {
// Compute a {{.p}}Table for the base point q. The explicit New{{.P}}Point // Compute a {{.p}}Table for the base point q. The explicit New{{.P}}Point
@ -451,15 +510,17 @@ func (p *{{.P}}Point) ScalarMult(q *{{.P}}Point, scalar []byte) (*{{.P}}Point, e
} }
return p, nil return p, nil
} }
var {{.p}}GeneratorTable *[{{.p}}ElementLength * 2]{{.p}}Table var {{.p}}GeneratorTable *[{{.p}}ElementLength * 2]{{.p}}Table
var {{.p}}GeneratorTableOnce sync.Once var {{.p}}GeneratorTableOnce sync.Once
// generatorTable returns a sequence of {{.p}}Tables. The first table contains // generatorTable returns a sequence of {{.p}}Tables. The first table contains
// multiples of G. Each successive table is the previous table doubled four // multiples of G. Each successive table is the previous table doubled four
// times. // times.
func (p *{{.P}}Point) generatorTable() *[{{.p}}ElementLength * 2]{{.p}}Table { func (p *{{.P}}Point) generatorTable() *[{{.p}}ElementLength * 2]{{.p}}Table {
{{.p}}GeneratorTableOnce.Do(func() { {{.p}}GeneratorTableOnce.Do(func() {
{{.p}}GeneratorTable = new([{{.p}}ElementLength * 2]{{.p}}Table) {{.p}}GeneratorTable = new([{{.p}}ElementLength * 2]{{.p}}Table)
base := New{{.P}}Generator() base := New{{.P}}Point().SetGenerator()
for i := 0; i < {{.p}}ElementLength*2; i++ { for i := 0; i < {{.p}}ElementLength*2; i++ {
{{.p}}GeneratorTable[i][0] = New{{.P}}Point().Set(base) {{.p}}GeneratorTable[i][0] = New{{.P}}Point().Set(base)
for j := 1; j < 15; j++ { for j := 1; j < 15; j++ {
@ -473,6 +534,7 @@ func (p *{{.P}}Point) generatorTable() *[{{.p}}ElementLength * 2]{{.p}}Table {
}) })
return {{.p}}GeneratorTable return {{.p}}GeneratorTable
} }
// ScalarBaseMult sets p = scalar * B, where B is the canonical generator, and // ScalarBaseMult sets p = scalar * B, where B is the canonical generator, and
// returns p. // returns p.
func (p *{{.P}}Point) ScalarBaseMult(scalar []byte) (*{{.P}}Point, error) { func (p *{{.P}}Point) ScalarBaseMult(scalar []byte) (*{{.P}}Point, error) {
@ -494,13 +556,16 @@ func (p *{{.P}}Point) ScalarBaseMult(scalar []byte) (*{{.P}}Point, error) {
tables[tableIndex].Select(t, windowValue) tables[tableIndex].Select(t, windowValue)
p.Add(p, t) p.Add(p, t)
tableIndex-- tableIndex--
windowValue = byte & 0b1111 windowValue = byte & 0b1111
tables[tableIndex].Select(t, windowValue) tables[tableIndex].Select(t, windowValue)
p.Add(p, t) p.Add(p, t)
tableIndex-- tableIndex--
} }
return p, nil return p, nil
} }
// {{.p}}Sqrt sets e to a square root of x. If x is not a square, {{.p}}Sqrt returns // {{.p}}Sqrt sets e to a square root of x. If x is not a square, {{.p}}Sqrt returns
// false and e is unchanged. e and x can overlap. // false and e is unchanged. e and x can overlap.
func {{.p}}Sqrt(e, x *{{ .Element }}) (isSquare bool) { func {{.p}}Sqrt(e, x *{{ .Element }}) (isSquare bool) {

View File

@ -1,7 +1,9 @@
// Copyright 2022 The Go Authors. All rights reserved. // Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
// Code generated by generate.go. DO NOT EDIT. // Code generated by generate.go. DO NOT EDIT.
//go:build !amd64 && !arm64 || generic //go:build !amd64 && !arm64 || generic
// +build !amd64,!arm64 generic // +build !amd64,!arm64 generic
@ -14,9 +16,6 @@ import (
"sync" "sync"
) )
var sm2p256B, _ = new(fiat.SM2P256Element).SetBytes([]byte{0x28, 0xe9, 0xfa, 0x9e, 0x9d, 0x9f, 0x5e, 0x34, 0x4d, 0x5a, 0x9e, 0x4b, 0xcf, 0x65, 0x9, 0xa7, 0xf3, 0x97, 0x89, 0xf5, 0x15, 0xab, 0x8f, 0x92, 0xdd, 0xbc, 0xbd, 0x41, 0x4d, 0x94, 0xe, 0x93})
var sm2p256G, _ = NewSM2P256Point().SetBytes([]byte{0x4, 0x32, 0xc4, 0xae, 0x2c, 0x1f, 0x19, 0x81, 0x19, 0x5f, 0x99, 0x4, 0x46, 0x6a, 0x39, 0xc9, 0x94, 0x8f, 0xe3, 0xb, 0xbf, 0xf2, 0x66, 0xb, 0xe1, 0x71, 0x5a, 0x45, 0x89, 0x33, 0x4c, 0x74, 0xc7, 0xbc, 0x37, 0x36, 0xa2, 0xf4, 0xf6, 0x77, 0x9c, 0x59, 0xbd, 0xce, 0xe3, 0x6b, 0x69, 0x21, 0x53, 0xd0, 0xa9, 0x87, 0x7c, 0xc6, 0x2a, 0x47, 0x40, 0x2, 0xdf, 0x32, 0xe5, 0x21, 0x39, 0xf0, 0xa0})
// sm2p256ElementLength is the length of an element of the base or scalar field, // sm2p256ElementLength is the length of an element of the base or scalar field,
// which have the same bytes length for all NIST P curves. // which have the same bytes length for all NIST P curves.
const sm2p256ElementLength = 32 const sm2p256ElementLength = 32
@ -37,13 +36,12 @@ func NewSM2P256Point() *SM2P256Point {
} }
} }
// NewSM2P256Generator returns a new SM2P256Point set to the canonical generator. // SetGenerator sets p to the canonical generator and returns p.
func NewSM2P256Generator() *SM2P256Point { func (p *SM2P256Point) SetGenerator() *SM2P256Point {
return (&SM2P256Point{ p.x.SetBytes([]byte{0x32, 0xc4, 0xae, 0x2c, 0x1f, 0x19, 0x81, 0x19, 0x5f, 0x99, 0x4, 0x46, 0x6a, 0x39, 0xc9, 0x94, 0x8f, 0xe3, 0xb, 0xbf, 0xf2, 0x66, 0xb, 0xe1, 0x71, 0x5a, 0x45, 0x89, 0x33, 0x4c, 0x74, 0xc7})
x: new(fiat.SM2P256Element), p.y.SetBytes([]byte{0xbc, 0x37, 0x36, 0xa2, 0xf4, 0xf6, 0x77, 0x9c, 0x59, 0xbd, 0xce, 0xe3, 0x6b, 0x69, 0x21, 0x53, 0xd0, 0xa9, 0x87, 0x7c, 0xc6, 0x2a, 0x47, 0x40, 0x2, 0xdf, 0x32, 0xe5, 0x21, 0x39, 0xf0, 0xa0})
y: new(fiat.SM2P256Element), p.z.One()
z: new(fiat.SM2P256Element), return p
}).Set(sm2p256G)
} }
// Set sets p = q and returns p. // Set sets p = q and returns p.
@ -106,15 +104,29 @@ func (p *SM2P256Point) SetBytes(b []byte) (*SM2P256Point, error) {
} }
} }
var _sm2p256B *fiat.SM2P256Element
var _sm2p256BOnce sync.Once
func sm2p256B() *fiat.SM2P256Element {
_sm2p256BOnce.Do(func() {
_sm2p256B, _ = new(fiat.SM2P256Element).SetBytes([]byte{0x28, 0xe9, 0xfa, 0x9e, 0x9d, 0x9f, 0x5e, 0x34, 0x4d, 0x5a, 0x9e, 0x4b, 0xcf, 0x65, 0x9, 0xa7, 0xf3, 0x97, 0x89, 0xf5, 0x15, 0xab, 0x8f, 0x92, 0xdd, 0xbc, 0xbd, 0x41, 0x4d, 0x94, 0xe, 0x93})
})
return _sm2p256B
}
// sm2p256Polynomial sets y2 to x³ - 3x + b, and returns y2. // sm2p256Polynomial sets y2 to x³ - 3x + b, and returns y2.
func sm2p256Polynomial(y2, x *fiat.SM2P256Element) *fiat.SM2P256Element { func sm2p256Polynomial(y2, x *fiat.SM2P256Element) *fiat.SM2P256Element {
y2.Square(x) y2.Square(x)
y2.Mul(y2, x) y2.Mul(y2, x)
threeX := new(fiat.SM2P256Element).Add(x, x) threeX := new(fiat.SM2P256Element).Add(x, x)
threeX.Add(threeX, x) threeX.Add(threeX, x)
y2.Sub(y2, threeX) y2.Sub(y2, threeX)
return y2.Add(y2, sm2p256B)
return y2.Add(y2, sm2p256B())
} }
func sm2p256CheckOnCurve(x, y *fiat.SM2P256Element) error { func sm2p256CheckOnCurve(x, y *fiat.SM2P256Element) error {
// y² = x³ - 3x + b // y² = x³ - 3x + b
rhs := sm2p256Polynomial(new(fiat.SM2P256Element), x) rhs := sm2p256Polynomial(new(fiat.SM2P256Element), x)
@ -134,6 +146,7 @@ func (p *SM2P256Point) Bytes() []byte {
var out [1 + 2*sm2p256ElementLength]byte var out [1 + 2*sm2p256ElementLength]byte
return p.bytes(&out) return p.bytes(&out)
} }
func (p *SM2P256Point) bytes(out *[1 + 2*sm2p256ElementLength]byte) []byte { func (p *SM2P256Point) bytes(out *[1 + 2*sm2p256ElementLength]byte) []byte {
if p.z.IsZero() == 1 { if p.z.IsZero() == 1 {
return append(out[:0], 0) return append(out[:0], 0)
@ -147,6 +160,24 @@ func (p *SM2P256Point) bytes(out *[1 + 2*sm2p256ElementLength]byte) []byte {
return buf return buf
} }
// BytesX returns the encoding of the x-coordinate of p, as specified in SEC 1,
// Version 2.0, Section 2.3.5, or an error if p is the point at infinity.
func (p *SM2P256Point) BytesX() ([]byte, error) {
// This function is outlined to make the allocations inline in the caller
// rather than happen on the heap.
var out [sm2p256ElementLength]byte
return p.bytesX(&out)
}
func (p *SM2P256Point) bytesX(out *[sm2p256ElementLength]byte) ([]byte, error) {
if p.z.IsZero() == 1 {
return nil, errors.New("SM2P256 point is the point at infinity")
}
zinv := new(fiat.SM2P256Element).Invert(p.z)
x := new(fiat.SM2P256Element).Mul(p.x, zinv)
return append(out[:0], x.Bytes()...), nil
}
// BytesCompressed returns the compressed or infinity encoding of p, as // BytesCompressed returns the compressed or infinity encoding of p, as
// specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the // specified in SEC 1, Version 2.0, Section 2.3.3. Note that the encoding of the
// point at infinity is shorter than all other encodings. // point at infinity is shorter than all other encodings.
@ -156,6 +187,7 @@ func (p *SM2P256Point) BytesCompressed() []byte {
var out [1 + sm2p256ElementLength]byte var out [1 + sm2p256ElementLength]byte
return p.bytesCompressed(&out) return p.bytesCompressed(&out)
} }
func (p *SM2P256Point) bytesCompressed(out *[1 + sm2p256ElementLength]byte) []byte { func (p *SM2P256Point) bytesCompressed(out *[1 + sm2p256ElementLength]byte) []byte {
if p.z.IsZero() == 1 { if p.z.IsZero() == 1 {
return append(out[:0], 0) return append(out[:0], 0)
@ -175,49 +207,50 @@ func (p *SM2P256Point) bytesCompressed(out *[1 + sm2p256ElementLength]byte) []by
func (q *SM2P256Point) Add(p1, p2 *SM2P256Point) *SM2P256Point { func (q *SM2P256Point) Add(p1, p2 *SM2P256Point) *SM2P256Point {
// Complete addition formula for a = -3 from "Complete addition formulas for // Complete addition formula for a = -3 from "Complete addition formulas for
// prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2. // prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2.
t0 := new(fiat.SM2P256Element).Mul(p1.x, p2.x) // t0 := X1 * X2 t0 := new(fiat.SM2P256Element).Mul(p1.x, p2.x) // t0 := X1 * X2
t1 := new(fiat.SM2P256Element).Mul(p1.y, p2.y) // t1 := Y1 * Y2 t1 := new(fiat.SM2P256Element).Mul(p1.y, p2.y) // t1 := Y1 * Y2
t2 := new(fiat.SM2P256Element).Mul(p1.z, p2.z) // t2 := Z1 * Z2 t2 := new(fiat.SM2P256Element).Mul(p1.z, p2.z) // t2 := Z1 * Z2
t3 := new(fiat.SM2P256Element).Add(p1.x, p1.y) // t3 := X1 + Y1 t3 := new(fiat.SM2P256Element).Add(p1.x, p1.y) // t3 := X1 + Y1
t4 := new(fiat.SM2P256Element).Add(p2.x, p2.y) // t4 := X2 + Y2 t4 := new(fiat.SM2P256Element).Add(p2.x, p2.y) // t4 := X2 + Y2
t3.Mul(t3, t4) // t3 := t3 * t4 t3.Mul(t3, t4) // t3 := t3 * t4
t4.Add(t0, t1) // t4 := t0 + t1 t4.Add(t0, t1) // t4 := t0 + t1
t3.Sub(t3, t4) // t3 := t3 - t4 t3.Sub(t3, t4) // t3 := t3 - t4
t4.Add(p1.y, p1.z) // t4 := Y1 + Z1 t4.Add(p1.y, p1.z) // t4 := Y1 + Z1
x3 := new(fiat.SM2P256Element).Add(p2.y, p2.z) // X3 := Y2 + Z2 x3 := new(fiat.SM2P256Element).Add(p2.y, p2.z) // X3 := Y2 + Z2
t4.Mul(t4, x3) // t4 := t4 * X3 t4.Mul(t4, x3) // t4 := t4 * X3
x3.Add(t1, t2) // X3 := t1 + t2 x3.Add(t1, t2) // X3 := t1 + t2
t4.Sub(t4, x3) // t4 := t4 - X3 t4.Sub(t4, x3) // t4 := t4 - X3
x3.Add(p1.x, p1.z) // X3 := X1 + Z1 x3.Add(p1.x, p1.z) // X3 := X1 + Z1
y3 := new(fiat.SM2P256Element).Add(p2.x, p2.z) // Y3 := X2 + Z2 y3 := new(fiat.SM2P256Element).Add(p2.x, p2.z) // Y3 := X2 + Z2
x3.Mul(x3, y3) // X3 := X3 * Y3 x3.Mul(x3, y3) // X3 := X3 * Y3
y3.Add(t0, t2) // Y3 := t0 + t2 y3.Add(t0, t2) // Y3 := t0 + t2
y3.Sub(x3, y3) // Y3 := X3 - Y3 y3.Sub(x3, y3) // Y3 := X3 - Y3
z3 := new(fiat.SM2P256Element).Mul(sm2p256B, t2) // Z3 := b * t2 z3 := new(fiat.SM2P256Element).Mul(sm2p256B(), t2) // Z3 := b * t2
x3.Sub(y3, z3) // X3 := Y3 - Z3 x3.Sub(y3, z3) // X3 := Y3 - Z3
z3.Add(x3, x3) // Z3 := X3 + X3 z3.Add(x3, x3) // Z3 := X3 + X3
x3.Add(x3, z3) // X3 := X3 + Z3 x3.Add(x3, z3) // X3 := X3 + Z3
z3.Sub(t1, x3) // Z3 := t1 - X3 z3.Sub(t1, x3) // Z3 := t1 - X3
x3.Add(t1, x3) // X3 := t1 + X3 x3.Add(t1, x3) // X3 := t1 + X3
y3.Mul(sm2p256B, y3) // Y3 := b * Y3 y3.Mul(sm2p256B(), y3) // Y3 := b * Y3
t1.Add(t2, t2) // t1 := t2 + t2 t1.Add(t2, t2) // t1 := t2 + t2
t2.Add(t1, t2) // t2 := t1 + t2 t2.Add(t1, t2) // t2 := t1 + t2
y3.Sub(y3, t2) // Y3 := Y3 - t2 y3.Sub(y3, t2) // Y3 := Y3 - t2
y3.Sub(y3, t0) // Y3 := Y3 - t0 y3.Sub(y3, t0) // Y3 := Y3 - t0
t1.Add(y3, y3) // t1 := Y3 + Y3 t1.Add(y3, y3) // t1 := Y3 + Y3
y3.Add(t1, y3) // Y3 := t1 + Y3 y3.Add(t1, y3) // Y3 := t1 + Y3
t1.Add(t0, t0) // t1 := t0 + t0 t1.Add(t0, t0) // t1 := t0 + t0
t0.Add(t1, t0) // t0 := t1 + t0 t0.Add(t1, t0) // t0 := t1 + t0
t0.Sub(t0, t2) // t0 := t0 - t2 t0.Sub(t0, t2) // t0 := t0 - t2
t1.Mul(t4, y3) // t1 := t4 * Y3 t1.Mul(t4, y3) // t1 := t4 * Y3
t2.Mul(t0, y3) // t2 := t0 * Y3 t2.Mul(t0, y3) // t2 := t0 * Y3
y3.Mul(x3, z3) // Y3 := X3 * Z3 y3.Mul(x3, z3) // Y3 := X3 * Z3
y3.Add(y3, t2) // Y3 := Y3 + t2 y3.Add(y3, t2) // Y3 := Y3 + t2
x3.Mul(t3, x3) // X3 := t3 * X3 x3.Mul(t3, x3) // X3 := t3 * X3
x3.Sub(x3, t1) // X3 := X3 - t1 x3.Sub(x3, t1) // X3 := X3 - t1
z3.Mul(t4, z3) // Z3 := t4 * Z3 z3.Mul(t4, z3) // Z3 := t4 * Z3
t1.Mul(t3, t0) // t1 := t3 * t0 t1.Mul(t3, t0) // t1 := t3 * t0
z3.Add(z3, t1) // Z3 := Z3 + t1 z3.Add(z3, t1) // Z3 := Z3 + t1
q.x.Set(x3) q.x.Set(x3)
q.y.Set(y3) q.y.Set(y3)
q.z.Set(z3) q.z.Set(z3)
@ -228,40 +261,41 @@ func (q *SM2P256Point) Add(p1, p2 *SM2P256Point) *SM2P256Point {
func (q *SM2P256Point) Double(p *SM2P256Point) *SM2P256Point { func (q *SM2P256Point) Double(p *SM2P256Point) *SM2P256Point {
// Complete addition formula for a = -3 from "Complete addition formulas for // Complete addition formula for a = -3 from "Complete addition formulas for
// prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2. // prime order elliptic curves" (https://eprint.iacr.org/2015/1060), §A.2.
t0 := new(fiat.SM2P256Element).Square(p.x) // t0 := X ^ 2 t0 := new(fiat.SM2P256Element).Square(p.x) // t0 := X ^ 2
t1 := new(fiat.SM2P256Element).Square(p.y) // t1 := Y ^ 2 t1 := new(fiat.SM2P256Element).Square(p.y) // t1 := Y ^ 2
t2 := new(fiat.SM2P256Element).Square(p.z) // t2 := Z ^ 2 t2 := new(fiat.SM2P256Element).Square(p.z) // t2 := Z ^ 2
t3 := new(fiat.SM2P256Element).Mul(p.x, p.y) // t3 := X * Y t3 := new(fiat.SM2P256Element).Mul(p.x, p.y) // t3 := X * Y
t3.Add(t3, t3) // t3 := t3 + t3 t3.Add(t3, t3) // t3 := t3 + t3
z3 := new(fiat.SM2P256Element).Mul(p.x, p.z) // Z3 := X * Z z3 := new(fiat.SM2P256Element).Mul(p.x, p.z) // Z3 := X * Z
z3.Add(z3, z3) // Z3 := Z3 + Z3 z3.Add(z3, z3) // Z3 := Z3 + Z3
y3 := new(fiat.SM2P256Element).Mul(sm2p256B, t2) // Y3 := b * t2 y3 := new(fiat.SM2P256Element).Mul(sm2p256B(), t2) // Y3 := b * t2
y3.Sub(y3, z3) // Y3 := Y3 - Z3 y3.Sub(y3, z3) // Y3 := Y3 - Z3
x3 := new(fiat.SM2P256Element).Add(y3, y3) // X3 := Y3 + Y3 x3 := new(fiat.SM2P256Element).Add(y3, y3) // X3 := Y3 + Y3
y3.Add(x3, y3) // Y3 := X3 + Y3 y3.Add(x3, y3) // Y3 := X3 + Y3
x3.Sub(t1, y3) // X3 := t1 - Y3 x3.Sub(t1, y3) // X3 := t1 - Y3
y3.Add(t1, y3) // Y3 := t1 + Y3 y3.Add(t1, y3) // Y3 := t1 + Y3
y3.Mul(x3, y3) // Y3 := X3 * Y3 y3.Mul(x3, y3) // Y3 := X3 * Y3
x3.Mul(x3, t3) // X3 := X3 * t3 x3.Mul(x3, t3) // X3 := X3 * t3
t3.Add(t2, t2) // t3 := t2 + t2 t3.Add(t2, t2) // t3 := t2 + t2
t2.Add(t2, t3) // t2 := t2 + t3 t2.Add(t2, t3) // t2 := t2 + t3
z3.Mul(sm2p256B, z3) // Z3 := b * Z3 z3.Mul(sm2p256B(), z3) // Z3 := b * Z3
z3.Sub(z3, t2) // Z3 := Z3 - t2 z3.Sub(z3, t2) // Z3 := Z3 - t2
z3.Sub(z3, t0) // Z3 := Z3 - t0 z3.Sub(z3, t0) // Z3 := Z3 - t0
t3.Add(z3, z3) // t3 := Z3 + Z3 t3.Add(z3, z3) // t3 := Z3 + Z3
z3.Add(z3, t3) // Z3 := Z3 + t3 z3.Add(z3, t3) // Z3 := Z3 + t3
t3.Add(t0, t0) // t3 := t0 + t0 t3.Add(t0, t0) // t3 := t0 + t0
t0.Add(t3, t0) // t0 := t3 + t0 t0.Add(t3, t0) // t0 := t3 + t0
t0.Sub(t0, t2) // t0 := t0 - t2 t0.Sub(t0, t2) // t0 := t0 - t2
t0.Mul(t0, z3) // t0 := t0 * Z3 t0.Mul(t0, z3) // t0 := t0 * Z3
y3.Add(y3, t0) // Y3 := Y3 + t0 y3.Add(y3, t0) // Y3 := Y3 + t0
t0.Mul(p.y, p.z) // t0 := Y * Z t0.Mul(p.y, p.z) // t0 := Y * Z
t0.Add(t0, t0) // t0 := t0 + t0 t0.Add(t0, t0) // t0 := t0 + t0
z3.Mul(t0, z3) // Z3 := t0 * Z3 z3.Mul(t0, z3) // Z3 := t0 * Z3
x3.Sub(x3, z3) // X3 := X3 - Z3 x3.Sub(x3, z3) // X3 := X3 - Z3
z3.Mul(t0, t1) // Z3 := t0 * t1 z3.Mul(t0, t1) // Z3 := t0 * t1
z3.Add(z3, z3) // Z3 := Z3 + Z3 z3.Add(z3, z3) // Z3 := Z3 + Z3
z3.Add(z3, z3) // Z3 := Z3 + Z3 z3.Add(z3, z3) // Z3 := Z3 + Z3
q.x.Set(x3) q.x.Set(x3)
q.y.Set(y3) q.y.Set(y3)
q.z.Set(z3) q.z.Set(z3)
@ -307,6 +341,7 @@ func (p *SM2P256Point) ScalarMult(q *SM2P256Point, scalar []byte) (*SM2P256Point
table[i].Double(table[i/2]) table[i].Double(table[i/2])
table[i+1].Add(table[i], q) table[i+1].Add(table[i], q)
} }
// Instead of doing the classic double-and-add chain, we do it with a // Instead of doing the classic double-and-add chain, we do it with a
// four-bit window: we double four times, and then add [0-15]P. // four-bit window: we double four times, and then add [0-15]P.
t := NewSM2P256Point() t := NewSM2P256Point()
@ -320,17 +355,21 @@ func (p *SM2P256Point) ScalarMult(q *SM2P256Point, scalar []byte) (*SM2P256Point
p.Double(p) p.Double(p)
p.Double(p) p.Double(p)
} }
windowValue := byte >> 4 windowValue := byte >> 4
table.Select(t, windowValue) table.Select(t, windowValue)
p.Add(p, t) p.Add(p, t)
p.Double(p) p.Double(p)
p.Double(p) p.Double(p)
p.Double(p) p.Double(p)
p.Double(p) p.Double(p)
windowValue = byte & 0b1111 windowValue = byte & 0b1111
table.Select(t, windowValue) table.Select(t, windowValue)
p.Add(p, t) p.Add(p, t)
} }
return p, nil return p, nil
} }
@ -343,7 +382,7 @@ var sm2p256GeneratorTableOnce sync.Once
func (p *SM2P256Point) generatorTable() *[sm2p256ElementLength * 2]sm2p256Table { func (p *SM2P256Point) generatorTable() *[sm2p256ElementLength * 2]sm2p256Table {
sm2p256GeneratorTableOnce.Do(func() { sm2p256GeneratorTableOnce.Do(func() {
sm2p256GeneratorTable = new([sm2p256ElementLength * 2]sm2p256Table) sm2p256GeneratorTable = new([sm2p256ElementLength * 2]sm2p256Table)
base := NewSM2P256Generator() base := NewSM2P256Point().SetGenerator()
for i := 0; i < sm2p256ElementLength*2; i++ { for i := 0; i < sm2p256ElementLength*2; i++ {
sm2p256GeneratorTable[i][0] = NewSM2P256Point().Set(base) sm2p256GeneratorTable[i][0] = NewSM2P256Point().Set(base)
for j := 1; j < 15; j++ { for j := 1; j < 15; j++ {
@ -365,6 +404,7 @@ func (p *SM2P256Point) ScalarBaseMult(scalar []byte) (*SM2P256Point, error) {
return nil, errors.New("invalid scalar length") return nil, errors.New("invalid scalar length")
} }
tables := p.generatorTable() tables := p.generatorTable()
// This is also a scalar multiplication with a four-bit window like in // This is also a scalar multiplication with a four-bit window like in
// ScalarMult, but in this case the doublings are precomputed. The value // ScalarMult, but in this case the doublings are precomputed. The value
// [windowValue]G added at iteration k would normally get doubled // [windowValue]G added at iteration k would normally get doubled
@ -379,11 +419,13 @@ func (p *SM2P256Point) ScalarBaseMult(scalar []byte) (*SM2P256Point, error) {
tables[tableIndex].Select(t, windowValue) tables[tableIndex].Select(t, windowValue)
p.Add(p, t) p.Add(p, t)
tableIndex-- tableIndex--
windowValue = byte & 0b1111 windowValue = byte & 0b1111
tables[tableIndex].Select(t, windowValue) tables[tableIndex].Select(t, windowValue)
p.Add(p, t) p.Add(p, t)
tableIndex-- tableIndex--
} }
return p, nil return p, nil
} }