SM9: G1 G2 support point compress

This commit is contained in:
Sun Yimin 2022-06-15 15:17:16 +08:00 committed by GitHub
parent 0ea5fa3966
commit 14af2513d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 350 additions and 48 deletions

View File

@ -41,6 +41,9 @@ var pMinus2 = [4]uint64{0xe56f9b27e351457b, 0x21f2934b1a7aeedb, 0xd603ab4ff58ec7
// pMinus1Over2 is (p-1)/2.
var pMinus1Over2 = [4]uint64{0xf2b7cd93f1a8a2be, 0x90f949a58d3d776d, 0xeb01d5a7fac763a2, 0x5b2000000151d378}
// pMinus1Over2Big is (p-1)/2.
var pMinus1Over2Big = bigFromHex("5b2000000151d378eb01d5a7fac763a290f949a58d3d776df2b7cd93f1a8a2be")
// pMinus1Over4 is (p-1)/4.
var pMinus1Over4 = bigFromHex("2d90000000a8e9bc7580ead3fd63b1d1487ca4d2c69ebbb6f95be6c9f8d4515f")

View File

@ -36,6 +36,14 @@ func (c *curvePoint) Set(a *curvePoint) {
c.t.Set(&a.t)
}
func (c *curvePoint) polynomial(x *gfP) *gfP {
x3 := &gfP{}
gfpMul(x3, x, x)
gfpMul(x3, x3, x)
gfpAdd(x3, x3, curveB)
return x3
}
// IsOnCurve returns true iff c is on the curve.
func (c *curvePoint) IsOnCurve() bool {
c.MakeAffine()
@ -43,11 +51,10 @@ func (c *curvePoint) IsOnCurve() bool {
return true
}
y2, x3 := &gfP{}, &gfP{}
y2 := &gfP{}
gfpMul(y2, &c.y, &c.y)
gfpMul(x3, &c.x, &c.x)
gfpMul(x3, x3, &c.x)
gfpAdd(x3, x3, curveB)
x3 := c.polynomial(&c.x)
return *y2 == *x3
}

View File

@ -202,7 +202,7 @@ func (e *G1) Marshal() []byte {
return ret
}
// Marshal converts e to a byte slice with prefix
// MarshalUncompressed converts e to a byte slice with prefix
func (e *G1) MarshalUncompressed() []byte {
// Each value is a 256-bit number.
const numBytes = 256 / 8
@ -214,6 +214,68 @@ func (e *G1) MarshalUncompressed() []byte {
return ret
}
// MarshalCompressed converts e to a byte slice with compress prefix.
// If the point is not on the curve (or is the conventional point at infinity), the behavior is undefined.
func (e *G1) MarshalCompressed() []byte {
// Each value is a 256-bit number.
const numBytes = 256 / 8
ret := make([]byte, numBytes+1)
if e.p == nil {
e.p = &curvePoint{}
}
e.p.MakeAffine()
temp := &gfP{}
montDecode(temp, &e.p.y)
temp.Marshal(ret[1:])
ret[0] = (ret[numBytes] & 1) | 2
montDecode(temp, &e.p.x)
temp.Marshal(ret[1:])
return ret
}
// UnmarshalCompressed sets e to the result of converting the output of Marshal back into
// a group element and then returns e.
func (e *G1) UnmarshalCompressed(data []byte) ([]byte, error) {
// Each value is a 256-bit number.
const numBytes = 256 / 8
if len(data) < 1+numBytes {
return nil, errors.New("sm9.G1: not enough data")
}
if data[0] != 2 && data[0] != 3 { // compressed form
return nil, errors.New("sm9.G1: invalid point compress byte")
}
if e.p == nil {
e.p = &curvePoint{}
} else {
e.p.x, e.p.y = gfP{0}, gfP{0}
}
e.p.x.Unmarshal(data[1:])
montEncode(&e.p.x, &e.p.x)
x3 := e.p.polynomial(&e.p.x)
e.p.y.Sqrt(x3)
montDecode(x3, &e.p.y)
if byte(x3[0]&1) != data[0]&1 {
gfpNeg(&e.p.y, &e.p.y)
}
if e.p.x == *zero && e.p.y == *zero {
// This is the point at infinity.
e.p.y = *newGFp(1)
e.p.z = gfP{0}
e.p.t = gfP{0}
} else {
e.p.z = *newGFp(1)
e.p.t = *newGFp(1)
if !e.p.IsOnCurve() {
return nil, errors.New("sm9.G1: malformed point")
}
}
return data[numBytes+1:], nil
}
func (e *G1) fillBytes(buffer []byte) {
const numBytes = 256 / 8
@ -254,8 +316,7 @@ func (e *G1) Unmarshal(m []byte) ([]byte, error) {
montEncode(&e.p.x, &e.p.x)
montEncode(&e.p.y, &e.p.y)
zero := gfP{0}
if e.p.x == zero && e.p.y == zero {
if e.p.x == *zero && e.p.y == *zero {
// This is the point at infinity.
e.p.y = *newGFp(1)
e.p.z = gfP{0}

View File

@ -349,6 +349,36 @@ func TestLargeIsOnCurve(t *testing.T) {
}
}
func Test_G1MarshalCompressed(t *testing.T) {
e, e2 := &G1{}, &G1{}
ret := e.MarshalCompressed()
_, err := e2.UnmarshalCompressed(ret)
if err != nil {
t.Fatal(err)
}
if !e2.p.IsInfinity() {
t.Errorf("not same")
}
e.p.Set(curveGen)
ret = e.MarshalCompressed()
_, err = e2.UnmarshalCompressed(ret)
if err != nil {
t.Fatal(err)
}
if e2.p.x != e.p.x || e2.p.y != e.p.y || e2.p.z != e.p.z {
t.Errorf("not same")
}
e.p.Neg(e.p)
ret = e.MarshalCompressed()
_, err = e2.UnmarshalCompressed(ret)
if err != nil {
t.Fatal(err)
}
if e2.p.x != e.p.x || e2.p.y != e.p.y || e2.p.z != e.p.z {
t.Errorf("not same")
}
}
func benchmarkAllCurves(b *testing.B, f func(*testing.B, Curve)) {
tests := []struct {
name string

View File

@ -168,7 +168,7 @@ func (e *G2) Marshal() []byte {
return ret
}
// Marshal converts e into a byte slice with prefix
// MarshalUncompressed converts e into a byte slice with uncompressed point prefix
func (e *G2) MarshalUncompressed() []byte {
// Each value is a 256-bit number.
const numBytes = 256 / 8
@ -178,6 +178,75 @@ func (e *G2) MarshalUncompressed() []byte {
return ret
}
// MarshalCompressed converts e into a byte slice with uncompressed point prefix
func (e *G2) MarshalCompressed() []byte {
// Each value is a 256-bit number.
const numBytes = 256 / 8
ret := make([]byte, numBytes*2+1)
if e.p == nil {
e.p = &twistPoint{}
}
e.p.MakeAffine()
temp := &gfP{}
montDecode(temp, &e.p.y.y)
temp.Marshal(ret[1:])
ret[0] = (ret[numBytes] & 1) | 2
montDecode(temp, &e.p.x.x)
temp.Marshal(ret[1:])
montDecode(temp, &e.p.x.y)
temp.Marshal(ret[numBytes+1:])
return ret
}
// UnmarshalCompressed sets e to the result of converting the output of Marshal back into
// a group element and then returns e.
func (e *G2) UnmarshalCompressed(data []byte) ([]byte, error) {
// Each value is a 256-bit number.
const numBytes = 256 / 8
if len(data) < 1+2*numBytes {
return nil, errors.New("sm9.G2: not enough data")
}
if data[0] != 2 && data[0] != 3 { // compressed form
return nil, errors.New("sm9.G2: invalid point compress byte")
}
var err error
// Unmarshal the points and check their caps
if e.p == nil {
e.p = &twistPoint{}
}
if err = e.p.x.x.Unmarshal(data[1:]); err != nil {
return nil, err
}
if err = e.p.x.y.Unmarshal(data[1+numBytes:]); err != nil {
return nil, err
}
montEncode(&e.p.x.x, &e.p.x.x)
montEncode(&e.p.x.y, &e.p.x.y)
x3 := e.p.polynomial(&e.p.x)
e.p.y.Sqrt(x3)
x3y := &gfP{}
montDecode(x3y, &e.p.y.y)
if byte(x3y[0]&1) != data[0]&1 {
e.p.y.Neg(&e.p.y)
}
if e.p.x.IsZero() && e.p.y.IsZero() {
// This is the point at infinity.
e.p.y.SetOne()
e.p.z.SetZero()
e.p.t.SetZero()
} else {
e.p.z.SetOne()
e.p.t.SetOne()
if !e.p.IsOnCurve() {
return nil, errors.New("sm9.G2: malformed point")
}
}
return data[1+2*numBytes:], nil
}
func (e *G2) fillBytes(buffer []byte) {
// Each value is a 256-bit number.
const numBytes = 256 / 8

View File

@ -41,6 +41,39 @@ func TestG2Marshal(t *testing.T) {
}
}
func Test_G2MarshalCompressed(t *testing.T) {
e, e2 := &G2{}, &G2{}
ret := e.MarshalCompressed()
_, err := e2.UnmarshalCompressed(ret)
if err != nil {
t.Fatal(err)
}
if !e2.p.IsInfinity() {
t.Errorf("not same")
}
e.p.Set(twistGen)
ret = e.MarshalCompressed()
_, err = e2.UnmarshalCompressed(ret)
if err != nil {
t.Fatal(err)
}
if e2.p.x != e.p.x || e2.p.y != e.p.y || e2.p.z != e.p.z {
t.Errorf("not same")
}
e.p.Neg(e.p)
ret = e.MarshalCompressed()
_, err = e2.UnmarshalCompressed(ret)
if err != nil {
t.Fatal(err)
}
if e2.p.x != e.p.x || e2.p.y != e.p.y || e2.p.z != e.p.z {
t.Errorf("not same")
}
if e2.p.x == twistGen.x && e2.p.y == twistGen.y && e2.p.z == twistGen.z {
t.Errorf("not expected")
}
}
func BenchmarkG2(b *testing.B) {
x, _ := rand.Int(rand.Reader, Order)
b.ReportAllocs()

View File

@ -1,6 +1,8 @@
package sm9
import "math/big"
import (
"math/big"
)
// For details of the algorithms used, see "Multiplication and Squaring on
// Pairing-Friendly Fields, Devegili et al.
@ -239,17 +241,47 @@ func (e *gfP2) Frobenius(a *gfP2) *gfP2 {
}
// Sqrt method is only required when we implement compressed format
func (e *gfP2) Sqrt(f *gfP2) *gfP2 {
func (ret *gfP2) Sqrt(a *gfP2) *gfP2 {
// Algorithm 10 https://eprint.iacr.org/2012/685.pdf
// TODO
ret.SetZero()
c := &twistGen.x
b, b2, bq := &gfP2{}, &gfP2{}, &gfP2{}
b.Exp(f, pMinus1Over4)
b.Exp(a, pMinus1Over4)
b2.Mul(b, b)
bq.Exp(b, p)
return bq
t := &gfP2{}
x0 := &gfP{}
/* ignore sqrt existing check
a0 := &gfP2{}
a0.Exp(b2, p)
a0.Mul(a0, b2)
a0 = gfP2Decode(a0)
*/
t.Mul(bq, b)
if t.x == *zero && t.y == *one {
t.Mul(b2, a)
x0.Sqrt(&t.y)
t.MulScalar(bq, x0)
ret.Set(t)
} else {
d, e, f := &gfP2{}, &gfP2{}, &gfP2{}
d.Exp(c, pMinus1Over2Big)
e.Mul(d, c)
f.Square(e)
e.Invert(e)
t.Mul(b2, a)
t.Mul(t, f)
x0.Sqrt(&t.y)
t.MulScalar(bq, x0)
t.Mul(t, e)
ret.Set(t)
}
return ret
}
// Div2 e = f / 2, not used currently
func (e *gfP2) Div2(f *gfP2) *gfP2 {
t := &gfP2{}
t.x.Div2(&f.x)

View File

@ -118,3 +118,52 @@ func Test_gfP2Div2(t *testing.T) {
t.Errorf("got %v, expected %v", ret, x)
}
}
func Test_gfP2Sqrt(t *testing.T) {
x := &gfP2{
*fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")),
*fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")),
}
x2, x3, sqrt, sqrtNeg := &gfP2{}, &gfP2{}, &gfP2{}, &gfP2{}
x2.Mul(x, x)
sqrt.Sqrt(x2)
sqrtNeg.Neg(sqrt)
x3.Mul(sqrt, sqrt)
if *x3 != *x2 {
t.Errorf("not correct")
}
if *sqrt != *x && *sqrtNeg != *x {
t.Errorf("sqrt not expected")
}
}
/*
func Test_gfP2QuadraticResidue(t *testing.T) {
x := &gfP2{
*fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")),
*fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")),
}
n := bigFromHex("40df880001e10199aa9f985292a7740a5f3e998ff60a2401e81d08b99ba6f8ff691684e427df891a9250c20f55961961fe81f6fc785a9512ad93e28f5cfb4f84")
y := &gfP2{}
x2 := &gfP2{}
x2.Exp(x, n)
x2 = gfP2Decode(x2)
fmt.Printf("%v\n", x2)
for {
k, err := randomK(rand.Reader)
if err != nil {
t.Fatal(err)
}
x2.Exp(x, k)
y.Exp(x2, n)
if y.x == *zero && y.y == *one {
break
}
}
x2 = gfP2Decode(x2)
fmt.Printf("%v\n", x2)
}
*/

View File

@ -389,8 +389,7 @@ func UnmarshalSM9KeyPackage(der []byte) ([]byte, *G1, error) {
!inner.Empty() {
return nil, nil, errors.New("sm9: invalid SM9KeyPackage asn.1 data")
}
g := new(G1)
_, err := g.Unmarshal(cipherBytes[1:])
g, err := unmarshalG1(cipherBytes)
if err != nil {
return nil, nil, err
}
@ -418,16 +417,12 @@ func UnwrapKey(priv *EncryptPrivateKey, uid []byte, cipher *G1, kLen int) ([]byt
}
func (priv *EncryptPrivateKey) UnwrapKey(uid, cipherDer []byte, kLen int) ([]byte, error) {
bytes := make([]byte, 64+1)
var bytes []byte
input := cryptobyte.String(cipherDer)
if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() {
return nil, errors.New("sm9: invalid chipher asn1 data")
}
if bytes[0] != 4 {
return nil, fmt.Errorf("sm9: unsupport curve point marshal format <%v>", bytes[0])
}
g := new(G1)
_, err := g.Unmarshal(bytes[1:])
g, err := unmarshalG1(bytes)
if err != nil {
return nil, err
}
@ -534,11 +529,7 @@ func DecryptASN1(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error
if encType != int(ENC_TYPE_XOR) {
return nil, fmt.Errorf("sm9: does not support this kind of encrypt type <%v> yet", encType)
}
if c1Bytes[0] != 4 {
return nil, fmt.Errorf("sm9: unsupport curve point marshal format <%v>", c1Bytes[0])
}
c := &G1{}
_, err := c.Unmarshal(c1Bytes[1:])
c, err := unmarshalG1(c1Bytes)
if err != nil {
return nil, err
}

View File

@ -124,6 +124,25 @@ func (pub *SignMasterPublicKey) MarshalASN1() ([]byte, error) {
return b.Bytes()
}
func unmarshalG2(bytes []byte) (*G2, error) {
g2 := new(G2)
switch bytes[0] {
case 4:
_, err := g2.Unmarshal(bytes[1:])
if err != nil {
return nil, err
}
case 2, 3:
_, err := g2.UnmarshalCompressed(bytes)
if err != nil {
return nil, err
}
default:
return nil, errors.New("sm9: invalid point identity byte")
}
return g2, nil
}
// UnmarshalASN1 unmarsal der data to sign master public key
func (pub *SignMasterPublicKey) UnmarshalASN1(der []byte) error {
var bytes []byte
@ -131,11 +150,7 @@ func (pub *SignMasterPublicKey) UnmarshalASN1(der []byte) error {
if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() {
return errors.New("sm9: invalid sign master public key asn1 data")
}
if bytes[0] != 4 {
return errors.New("sm9: invalid prefix of sign master public key")
}
g2 := new(G2)
_, err := g2.Unmarshal(bytes[1:])
g2, err := unmarshalG2(bytes)
if err != nil {
return err
}
@ -163,6 +178,25 @@ func (priv *SignPrivateKey) MarshalASN1() ([]byte, error) {
return b.Bytes()
}
func unmarshalG1(bytes []byte) (*G1, error) {
g := new(G1)
switch bytes[0] {
case 4:
_, err := g.Unmarshal(bytes[1:])
if err != nil {
return nil, err
}
case 2, 3:
_, err := g.UnmarshalCompressed(bytes)
if err != nil {
return nil, err
}
default:
return nil, errors.New("sm9: invalid point identity byte")
}
return g, nil
}
// UnmarshalASN1 unmarsal der data to sign user private key
// Note, priv's SignMasterPublicKey should be handled separately.
func (priv *SignPrivateKey) UnmarshalASN1(der []byte) error {
@ -171,11 +205,7 @@ func (priv *SignPrivateKey) UnmarshalASN1(der []byte) error {
if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() {
return errors.New("sm9: invalid sign user private key asn1 data")
}
if bytes[0] != 4 {
return errors.New("sm9: invalid prefix of sign user private key")
}
g := new(G1)
_, err := g.Unmarshal(bytes[1:])
g, err := unmarshalG1(bytes)
if err != nil {
return err
}
@ -269,11 +299,7 @@ func (pub *EncryptMasterPublicKey) UnmarshalASN1(der []byte) error {
if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() {
return errors.New("sm9: invalid encrypt master public key asn1 data")
}
if bytes[0] != 4 {
return errors.New("sm9: invalid prefix of encrypt master public key")
}
g := new(G1)
_, err := g.Unmarshal(bytes[1:])
g, err := unmarshalG1(bytes)
if err != nil {
return err
}
@ -309,11 +335,7 @@ func (priv *EncryptPrivateKey) UnmarshalASN1(der []byte) error {
if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() {
return errors.New("sm9: invalid encrypt user private key asn1 data")
}
if bytes[0] != 4 {
return errors.New("sm9: invalid prefix of encrypt user private key")
}
g := new(G2)
_, err := g.Unmarshal(bytes[1:])
g, err := unmarshalG2(bytes)
if err != nil {
return err
}

View File

@ -56,6 +56,12 @@ func NewTwistGenerator() *twistPoint {
return c
}
func (c *twistPoint) polynomial(x *gfP2) *gfP2 {
x3 := &gfP2{}
x3.Square(x).Mul(x3, x).Add(x3, twistB)
return x3
}
// IsOnCurve returns true iff c is on the curve.
func (c *twistPoint) IsOnCurve() bool {
c.MakeAffine()
@ -63,9 +69,9 @@ func (c *twistPoint) IsOnCurve() bool {
return true
}
y2, x3 := &gfP2{}, &gfP2{}
y2 := &gfP2{}
y2.Square(&c.y)
x3.Square(&c.x).Mul(x3, &c.x).Add(x3, twistB)
x3 := c.polynomial(&c.x)
return *y2 == *x3
}
@ -169,7 +175,6 @@ func (c *twistPoint) Double(a *twistPoint) {
c.y.Sub(t2, t)
}
// TODO: improve it
func (c *twistPoint) Mul(a *twistPoint, scalar *big.Int) {
sum, t := &twistPoint{}, &twistPoint{}