From 14af2513d81d0e3e34ed18b6be72366531375cb6 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Wed, 15 Jun 2022 15:17:16 +0800 Subject: [PATCH] SM9: G1 G2 support point compress --- sm9/constants.go | 3 ++ sm9/curve.go | 15 +++++++--- sm9/g1.go | 67 +++++++++++++++++++++++++++++++++++++++++++-- sm9/g1_test.go | 30 ++++++++++++++++++++ sm9/g2.go | 71 +++++++++++++++++++++++++++++++++++++++++++++++- sm9/g2_test.go | 33 ++++++++++++++++++++++ sm9/gfp2.go | 40 ++++++++++++++++++++++++--- sm9/gfp2_test.go | 49 +++++++++++++++++++++++++++++++++ sm9/sm9.go | 17 +++--------- sm9/sm9_key.go | 62 ++++++++++++++++++++++++++++-------------- sm9/twist.go | 11 ++++++-- 11 files changed, 350 insertions(+), 48 deletions(-) diff --git a/sm9/constants.go b/sm9/constants.go index c49a3b9..0b59dce 100644 --- a/sm9/constants.go +++ b/sm9/constants.go @@ -41,6 +41,9 @@ var pMinus2 = [4]uint64{0xe56f9b27e351457b, 0x21f2934b1a7aeedb, 0xd603ab4ff58ec7 // pMinus1Over2 is (p-1)/2. var pMinus1Over2 = [4]uint64{0xf2b7cd93f1a8a2be, 0x90f949a58d3d776d, 0xeb01d5a7fac763a2, 0x5b2000000151d378} +// pMinus1Over2Big is (p-1)/2. +var pMinus1Over2Big = bigFromHex("5b2000000151d378eb01d5a7fac763a290f949a58d3d776df2b7cd93f1a8a2be") + // pMinus1Over4 is (p-1)/4. var pMinus1Over4 = bigFromHex("2d90000000a8e9bc7580ead3fd63b1d1487ca4d2c69ebbb6f95be6c9f8d4515f") diff --git a/sm9/curve.go b/sm9/curve.go index bfe8e41..91f0e3e 100644 --- a/sm9/curve.go +++ b/sm9/curve.go @@ -36,6 +36,14 @@ func (c *curvePoint) Set(a *curvePoint) { c.t.Set(&a.t) } +func (c *curvePoint) polynomial(x *gfP) *gfP { + x3 := &gfP{} + gfpMul(x3, x, x) + gfpMul(x3, x3, x) + gfpAdd(x3, x3, curveB) + return x3 +} + // IsOnCurve returns true iff c is on the curve. func (c *curvePoint) IsOnCurve() bool { c.MakeAffine() @@ -43,11 +51,10 @@ func (c *curvePoint) IsOnCurve() bool { return true } - y2, x3 := &gfP{}, &gfP{} + y2 := &gfP{} gfpMul(y2, &c.y, &c.y) - gfpMul(x3, &c.x, &c.x) - gfpMul(x3, x3, &c.x) - gfpAdd(x3, x3, curveB) + + x3 := c.polynomial(&c.x) return *y2 == *x3 } diff --git a/sm9/g1.go b/sm9/g1.go index a535b68..76ae8db 100644 --- a/sm9/g1.go +++ b/sm9/g1.go @@ -202,7 +202,7 @@ func (e *G1) Marshal() []byte { return ret } -// Marshal converts e to a byte slice with prefix +// MarshalUncompressed converts e to a byte slice with prefix func (e *G1) MarshalUncompressed() []byte { // Each value is a 256-bit number. const numBytes = 256 / 8 @@ -214,6 +214,68 @@ func (e *G1) MarshalUncompressed() []byte { return ret } +// MarshalCompressed converts e to a byte slice with compress prefix. +// If the point is not on the curve (or is the conventional point at infinity), the behavior is undefined. +func (e *G1) MarshalCompressed() []byte { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + ret := make([]byte, numBytes+1) + if e.p == nil { + e.p = &curvePoint{} + } + + e.p.MakeAffine() + temp := &gfP{} + montDecode(temp, &e.p.y) + temp.Marshal(ret[1:]) + ret[0] = (ret[numBytes] & 1) | 2 + montDecode(temp, &e.p.x) + temp.Marshal(ret[1:]) + + return ret +} + +// UnmarshalCompressed sets e to the result of converting the output of Marshal back into +// a group element and then returns e. +func (e *G1) UnmarshalCompressed(data []byte) ([]byte, error) { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + if len(data) < 1+numBytes { + return nil, errors.New("sm9.G1: not enough data") + } + if data[0] != 2 && data[0] != 3 { // compressed form + return nil, errors.New("sm9.G1: invalid point compress byte") + } + if e.p == nil { + e.p = &curvePoint{} + } else { + e.p.x, e.p.y = gfP{0}, gfP{0} + } + e.p.x.Unmarshal(data[1:]) + montEncode(&e.p.x, &e.p.x) + x3 := e.p.polynomial(&e.p.x) + e.p.y.Sqrt(x3) + montDecode(x3, &e.p.y) + if byte(x3[0]&1) != data[0]&1 { + gfpNeg(&e.p.y, &e.p.y) + } + if e.p.x == *zero && e.p.y == *zero { + // This is the point at infinity. + e.p.y = *newGFp(1) + e.p.z = gfP{0} + e.p.t = gfP{0} + } else { + e.p.z = *newGFp(1) + e.p.t = *newGFp(1) + + if !e.p.IsOnCurve() { + return nil, errors.New("sm9.G1: malformed point") + } + } + + return data[numBytes+1:], nil +} + func (e *G1) fillBytes(buffer []byte) { const numBytes = 256 / 8 @@ -254,8 +316,7 @@ func (e *G1) Unmarshal(m []byte) ([]byte, error) { montEncode(&e.p.x, &e.p.x) montEncode(&e.p.y, &e.p.y) - zero := gfP{0} - if e.p.x == zero && e.p.y == zero { + if e.p.x == *zero && e.p.y == *zero { // This is the point at infinity. e.p.y = *newGFp(1) e.p.z = gfP{0} diff --git a/sm9/g1_test.go b/sm9/g1_test.go index d3b28c9..42139f0 100644 --- a/sm9/g1_test.go +++ b/sm9/g1_test.go @@ -349,6 +349,36 @@ func TestLargeIsOnCurve(t *testing.T) { } } +func Test_G1MarshalCompressed(t *testing.T) { + e, e2 := &G1{}, &G1{} + ret := e.MarshalCompressed() + _, err := e2.UnmarshalCompressed(ret) + if err != nil { + t.Fatal(err) + } + if !e2.p.IsInfinity() { + t.Errorf("not same") + } + e.p.Set(curveGen) + ret = e.MarshalCompressed() + _, err = e2.UnmarshalCompressed(ret) + if err != nil { + t.Fatal(err) + } + if e2.p.x != e.p.x || e2.p.y != e.p.y || e2.p.z != e.p.z { + t.Errorf("not same") + } + e.p.Neg(e.p) + ret = e.MarshalCompressed() + _, err = e2.UnmarshalCompressed(ret) + if err != nil { + t.Fatal(err) + } + if e2.p.x != e.p.x || e2.p.y != e.p.y || e2.p.z != e.p.z { + t.Errorf("not same") + } +} + func benchmarkAllCurves(b *testing.B, f func(*testing.B, Curve)) { tests := []struct { name string diff --git a/sm9/g2.go b/sm9/g2.go index c5af3aa..4e96e15 100644 --- a/sm9/g2.go +++ b/sm9/g2.go @@ -168,7 +168,7 @@ func (e *G2) Marshal() []byte { return ret } -// Marshal converts e into a byte slice with prefix +// MarshalUncompressed converts e into a byte slice with uncompressed point prefix func (e *G2) MarshalUncompressed() []byte { // Each value is a 256-bit number. const numBytes = 256 / 8 @@ -178,6 +178,75 @@ func (e *G2) MarshalUncompressed() []byte { return ret } +// MarshalCompressed converts e into a byte slice with uncompressed point prefix +func (e *G2) MarshalCompressed() []byte { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + ret := make([]byte, numBytes*2+1) + if e.p == nil { + e.p = &twistPoint{} + } + e.p.MakeAffine() + temp := &gfP{} + montDecode(temp, &e.p.y.y) + temp.Marshal(ret[1:]) + ret[0] = (ret[numBytes] & 1) | 2 + + montDecode(temp, &e.p.x.x) + temp.Marshal(ret[1:]) + montDecode(temp, &e.p.x.y) + temp.Marshal(ret[numBytes+1:]) + + return ret +} + +// UnmarshalCompressed sets e to the result of converting the output of Marshal back into +// a group element and then returns e. +func (e *G2) UnmarshalCompressed(data []byte) ([]byte, error) { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + if len(data) < 1+2*numBytes { + return nil, errors.New("sm9.G2: not enough data") + } + if data[0] != 2 && data[0] != 3 { // compressed form + return nil, errors.New("sm9.G2: invalid point compress byte") + } + var err error + // Unmarshal the points and check their caps + if e.p == nil { + e.p = &twistPoint{} + } + if err = e.p.x.x.Unmarshal(data[1:]); err != nil { + return nil, err + } + if err = e.p.x.y.Unmarshal(data[1+numBytes:]); err != nil { + return nil, err + } + montEncode(&e.p.x.x, &e.p.x.x) + montEncode(&e.p.x.y, &e.p.x.y) + x3 := e.p.polynomial(&e.p.x) + e.p.y.Sqrt(x3) + x3y := &gfP{} + montDecode(x3y, &e.p.y.y) + if byte(x3y[0]&1) != data[0]&1 { + e.p.y.Neg(&e.p.y) + } + if e.p.x.IsZero() && e.p.y.IsZero() { + // This is the point at infinity. + e.p.y.SetOne() + e.p.z.SetZero() + e.p.t.SetZero() + } else { + e.p.z.SetOne() + e.p.t.SetOne() + + if !e.p.IsOnCurve() { + return nil, errors.New("sm9.G2: malformed point") + } + } + return data[1+2*numBytes:], nil +} + func (e *G2) fillBytes(buffer []byte) { // Each value is a 256-bit number. const numBytes = 256 / 8 diff --git a/sm9/g2_test.go b/sm9/g2_test.go index dc06c99..7a5b2aa 100644 --- a/sm9/g2_test.go +++ b/sm9/g2_test.go @@ -41,6 +41,39 @@ func TestG2Marshal(t *testing.T) { } } +func Test_G2MarshalCompressed(t *testing.T) { + e, e2 := &G2{}, &G2{} + ret := e.MarshalCompressed() + _, err := e2.UnmarshalCompressed(ret) + if err != nil { + t.Fatal(err) + } + if !e2.p.IsInfinity() { + t.Errorf("not same") + } + e.p.Set(twistGen) + ret = e.MarshalCompressed() + _, err = e2.UnmarshalCompressed(ret) + if err != nil { + t.Fatal(err) + } + if e2.p.x != e.p.x || e2.p.y != e.p.y || e2.p.z != e.p.z { + t.Errorf("not same") + } + e.p.Neg(e.p) + ret = e.MarshalCompressed() + _, err = e2.UnmarshalCompressed(ret) + if err != nil { + t.Fatal(err) + } + if e2.p.x != e.p.x || e2.p.y != e.p.y || e2.p.z != e.p.z { + t.Errorf("not same") + } + if e2.p.x == twistGen.x && e2.p.y == twistGen.y && e2.p.z == twistGen.z { + t.Errorf("not expected") + } +} + func BenchmarkG2(b *testing.B) { x, _ := rand.Int(rand.Reader, Order) b.ReportAllocs() diff --git a/sm9/gfp2.go b/sm9/gfp2.go index 0fc405a..6c2c371 100644 --- a/sm9/gfp2.go +++ b/sm9/gfp2.go @@ -1,6 +1,8 @@ package sm9 -import "math/big" +import ( + "math/big" +) // For details of the algorithms used, see "Multiplication and Squaring on // Pairing-Friendly Fields, Devegili et al. @@ -239,17 +241,47 @@ func (e *gfP2) Frobenius(a *gfP2) *gfP2 { } // Sqrt method is only required when we implement compressed format -func (e *gfP2) Sqrt(f *gfP2) *gfP2 { +func (ret *gfP2) Sqrt(a *gfP2) *gfP2 { // Algorithm 10 https://eprint.iacr.org/2012/685.pdf // TODO + ret.SetZero() + c := &twistGen.x b, b2, bq := &gfP2{}, &gfP2{}, &gfP2{} - b.Exp(f, pMinus1Over4) + b.Exp(a, pMinus1Over4) b2.Mul(b, b) bq.Exp(b, p) - return bq + t := &gfP2{} + x0 := &gfP{} + /* ignore sqrt existing check + a0 := &gfP2{} + a0.Exp(b2, p) + a0.Mul(a0, b2) + a0 = gfP2Decode(a0) + */ + t.Mul(bq, b) + if t.x == *zero && t.y == *one { + t.Mul(b2, a) + x0.Sqrt(&t.y) + t.MulScalar(bq, x0) + ret.Set(t) + } else { + d, e, f := &gfP2{}, &gfP2{}, &gfP2{} + d.Exp(c, pMinus1Over2Big) + e.Mul(d, c) + f.Square(e) + e.Invert(e) + t.Mul(b2, a) + t.Mul(t, f) + x0.Sqrt(&t.y) + t.MulScalar(bq, x0) + t.Mul(t, e) + ret.Set(t) + } + return ret } +// Div2 e = f / 2, not used currently func (e *gfP2) Div2(f *gfP2) *gfP2 { t := &gfP2{} t.x.Div2(&f.x) diff --git a/sm9/gfp2_test.go b/sm9/gfp2_test.go index 6f08a10..20f4bf2 100644 --- a/sm9/gfp2_test.go +++ b/sm9/gfp2_test.go @@ -118,3 +118,52 @@ func Test_gfP2Div2(t *testing.T) { t.Errorf("got %v, expected %v", ret, x) } } + +func Test_gfP2Sqrt(t *testing.T) { + x := &gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + } + x2, x3, sqrt, sqrtNeg := &gfP2{}, &gfP2{}, &gfP2{}, &gfP2{} + x2.Mul(x, x) + sqrt.Sqrt(x2) + sqrtNeg.Neg(sqrt) + x3.Mul(sqrt, sqrt) + + if *x3 != *x2 { + t.Errorf("not correct") + } + + if *sqrt != *x && *sqrtNeg != *x { + t.Errorf("sqrt not expected") + } +} + +/* +func Test_gfP2QuadraticResidue(t *testing.T) { + x := &gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + } + n := bigFromHex("40df880001e10199aa9f985292a7740a5f3e998ff60a2401e81d08b99ba6f8ff691684e427df891a9250c20f55961961fe81f6fc785a9512ad93e28f5cfb4f84") + y := &gfP2{} + x2 := &gfP2{} + x2.Exp(x, n) + x2 = gfP2Decode(x2) + fmt.Printf("%v\n", x2) + for { + k, err := randomK(rand.Reader) + if err != nil { + t.Fatal(err) + } + + x2.Exp(x, k) + y.Exp(x2, n) + if y.x == *zero && y.y == *one { + break + } + } + x2 = gfP2Decode(x2) + fmt.Printf("%v\n", x2) +} +*/ diff --git a/sm9/sm9.go b/sm9/sm9.go index 773784f..1e4b336 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -389,8 +389,7 @@ func UnmarshalSM9KeyPackage(der []byte) ([]byte, *G1, error) { !inner.Empty() { return nil, nil, errors.New("sm9: invalid SM9KeyPackage asn.1 data") } - g := new(G1) - _, err := g.Unmarshal(cipherBytes[1:]) + g, err := unmarshalG1(cipherBytes) if err != nil { return nil, nil, err } @@ -418,16 +417,12 @@ func UnwrapKey(priv *EncryptPrivateKey, uid []byte, cipher *G1, kLen int) ([]byt } func (priv *EncryptPrivateKey) UnwrapKey(uid, cipherDer []byte, kLen int) ([]byte, error) { - bytes := make([]byte, 64+1) + var bytes []byte input := cryptobyte.String(cipherDer) if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() { return nil, errors.New("sm9: invalid chipher asn1 data") } - if bytes[0] != 4 { - return nil, fmt.Errorf("sm9: unsupport curve point marshal format <%v>", bytes[0]) - } - g := new(G1) - _, err := g.Unmarshal(bytes[1:]) + g, err := unmarshalG1(bytes) if err != nil { return nil, err } @@ -534,11 +529,7 @@ func DecryptASN1(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error if encType != int(ENC_TYPE_XOR) { return nil, fmt.Errorf("sm9: does not support this kind of encrypt type <%v> yet", encType) } - if c1Bytes[0] != 4 { - return nil, fmt.Errorf("sm9: unsupport curve point marshal format <%v>", c1Bytes[0]) - } - c := &G1{} - _, err := c.Unmarshal(c1Bytes[1:]) + c, err := unmarshalG1(c1Bytes) if err != nil { return nil, err } diff --git a/sm9/sm9_key.go b/sm9/sm9_key.go index bf8023c..1798365 100644 --- a/sm9/sm9_key.go +++ b/sm9/sm9_key.go @@ -124,6 +124,25 @@ func (pub *SignMasterPublicKey) MarshalASN1() ([]byte, error) { return b.Bytes() } +func unmarshalG2(bytes []byte) (*G2, error) { + g2 := new(G2) + switch bytes[0] { + case 4: + _, err := g2.Unmarshal(bytes[1:]) + if err != nil { + return nil, err + } + case 2, 3: + _, err := g2.UnmarshalCompressed(bytes) + if err != nil { + return nil, err + } + default: + return nil, errors.New("sm9: invalid point identity byte") + } + return g2, nil +} + // UnmarshalASN1 unmarsal der data to sign master public key func (pub *SignMasterPublicKey) UnmarshalASN1(der []byte) error { var bytes []byte @@ -131,11 +150,7 @@ func (pub *SignMasterPublicKey) UnmarshalASN1(der []byte) error { if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() { return errors.New("sm9: invalid sign master public key asn1 data") } - if bytes[0] != 4 { - return errors.New("sm9: invalid prefix of sign master public key") - } - g2 := new(G2) - _, err := g2.Unmarshal(bytes[1:]) + g2, err := unmarshalG2(bytes) if err != nil { return err } @@ -163,6 +178,25 @@ func (priv *SignPrivateKey) MarshalASN1() ([]byte, error) { return b.Bytes() } +func unmarshalG1(bytes []byte) (*G1, error) { + g := new(G1) + switch bytes[0] { + case 4: + _, err := g.Unmarshal(bytes[1:]) + if err != nil { + return nil, err + } + case 2, 3: + _, err := g.UnmarshalCompressed(bytes) + if err != nil { + return nil, err + } + default: + return nil, errors.New("sm9: invalid point identity byte") + } + return g, nil +} + // UnmarshalASN1 unmarsal der data to sign user private key // Note, priv's SignMasterPublicKey should be handled separately. func (priv *SignPrivateKey) UnmarshalASN1(der []byte) error { @@ -171,11 +205,7 @@ func (priv *SignPrivateKey) UnmarshalASN1(der []byte) error { if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() { return errors.New("sm9: invalid sign user private key asn1 data") } - if bytes[0] != 4 { - return errors.New("sm9: invalid prefix of sign user private key") - } - g := new(G1) - _, err := g.Unmarshal(bytes[1:]) + g, err := unmarshalG1(bytes) if err != nil { return err } @@ -269,11 +299,7 @@ func (pub *EncryptMasterPublicKey) UnmarshalASN1(der []byte) error { if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() { return errors.New("sm9: invalid encrypt master public key asn1 data") } - if bytes[0] != 4 { - return errors.New("sm9: invalid prefix of encrypt master public key") - } - g := new(G1) - _, err := g.Unmarshal(bytes[1:]) + g, err := unmarshalG1(bytes) if err != nil { return err } @@ -309,11 +335,7 @@ func (priv *EncryptPrivateKey) UnmarshalASN1(der []byte) error { if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() { return errors.New("sm9: invalid encrypt user private key asn1 data") } - if bytes[0] != 4 { - return errors.New("sm9: invalid prefix of encrypt user private key") - } - g := new(G2) - _, err := g.Unmarshal(bytes[1:]) + g, err := unmarshalG2(bytes) if err != nil { return err } diff --git a/sm9/twist.go b/sm9/twist.go index 37736b4..4a83923 100644 --- a/sm9/twist.go +++ b/sm9/twist.go @@ -56,6 +56,12 @@ func NewTwistGenerator() *twistPoint { return c } +func (c *twistPoint) polynomial(x *gfP2) *gfP2 { + x3 := &gfP2{} + x3.Square(x).Mul(x3, x).Add(x3, twistB) + return x3 +} + // IsOnCurve returns true iff c is on the curve. func (c *twistPoint) IsOnCurve() bool { c.MakeAffine() @@ -63,9 +69,9 @@ func (c *twistPoint) IsOnCurve() bool { return true } - y2, x3 := &gfP2{}, &gfP2{} + y2 := &gfP2{} y2.Square(&c.y) - x3.Square(&c.x).Mul(x3, &c.x).Add(x3, twistB) + x3 := c.polynomial(&c.x) return *y2 == *x3 } @@ -169,7 +175,6 @@ func (c *twistPoint) Double(a *twistPoint) { c.y.Sub(t2, t) } -// TODO: improve it func (c *twistPoint) Mul(a *twistPoint, scalar *big.Int) { sum, t := &twistPoint{}, &twistPoint{}