diff --git a/sm9/README.md b/sm9/README.md new file mode 100644 index 0000000..bc88904 --- /dev/null +++ b/sm9/README.md @@ -0,0 +1,4 @@ +This part codes mainly refer two projects: + +1. [bn256](https://github.com/cloudflare/bn256), 主要是基域运算 +2. [gmssl sm9](https://github.com/guanzhi/GmSSL/blob/develop/src/sm9_alg.c),主要是2-4-12塔式扩域,以及r-ate等 \ No newline at end of file diff --git a/sm9/bn_pair.go b/sm9/bn_pair.go new file mode 100644 index 0000000..8bce124 --- /dev/null +++ b/sm9/bn_pair.go @@ -0,0 +1,239 @@ +package sm9 + +import ( + "math/big" +) + +func lineFunctionAdd(r, p *twistPoint, q *curvePoint, r2 *gfP2) (a, b, c, d *gfP2, rOut *twistPoint) { + // See the mixed addition algorithm from "Faster Computation of the + // Tate Pairing", http://arxiv.org/pdf/0904.0854v3.pdf + B := (&gfP2{}).Mul(&p.x, &r.t) // B = Xp * Zr^2 + + d = (&gfP2{}).Mul(B, &r.z) // d = Xp * Zr^3 + D := (&gfP2{}).Mul(&r.z, &r.x) + d.Sub(D, d) // d = Xr*Zr - Xp * Zr^3 + + D = (&gfP2{}).Add(&p.y, &r.z) // D = Yp + Zr + D.Square(D).Sub(D, r2).Sub(D, &r.t).Mul(D, &r.t) // D = ((Yp + Zr)^2 - Zr^2 - Yp^2)*Zr^2 = 2Yp*Zr^3 + + H := (&gfP2{}).Sub(B, &r.x) // H = Xp * Zr^2 - Xr + I := (&gfP2{}).Square(H) // I = (Xp * Zr^2 - Xr)^2 = Xp^2*Zr^4 + Xr^2 - 2Xr*Xp*Zr^2 + + E := (&gfP2{}).Add(I, I) // E = 2*(Xp * Zr^2 - Xr)^2 + E.Add(E, E) // E = 4*(Xp * Zr^2 - Xr)^2 + + J := (&gfP2{}).Mul(H, E) // J = 4*(Xp * Zr^2 - Xr)^3 + + L1 := (&gfP2{}).Sub(D, &r.y) // L1 = 2Yp*Zr^3 - Yr + L1.Sub(L1, &r.y) // L1 = 2Yp*Zr^3 - 2*Yr + + V := (&gfP2{}).Mul(&r.x, E) // V = 4 * Xr * (Xp * Zr^2 - Xr)^2 + + rOut = &twistPoint{} + rOut.x.Square(L1).Sub(&rOut.x, J).Sub(&rOut.x, V).Sub(&rOut.x, V) // rOut.x = L1^2 - J - 2V + + rOut.z.Add(&r.z, H).Square(&rOut.z).Sub(&rOut.z, &r.t).Sub(&rOut.z, I) // rOut.z = (Zr + H)^2 - Zr^2 - I + + t := (&gfP2{}).Sub(V, &rOut.x) // t = V - rOut.x + t.Mul(t, L1) // t = L1*(V-rOut.x) + t2 := (&gfP2{}).Mul(&r.y, J) + t2.Add(t2, t2) // t2 = 2Yr * J + rOut.y.Sub(t, t2) // rOut.y = L1*(V-rOut.x) - 2Yr*J + + rOut.t.Square(&rOut.z) + + t.Add(&p.y, &rOut.z).Square(t).Sub(t, r2).Sub(t, &rOut.t) // t = (Yp + rOut.Z)^2 - Yp^2 - rOut.Z^2 = 2Yp*rOut.Z + + t2.Mul(L1, &p.x) + t2.Add(t2, t2) // t2 = 2 L1 * Xp + a = (&gfP2{}).Sub(t2, t) // a = 2 L1 * Xp - 2 Yp * rOut.z + + c = (&gfP2{}).MulScalar(&rOut.z, &q.y) + c.Add(c, c) + + b = (&gfP2{}).Neg(L1) + b.MulScalar(b, &q.x).Add(b, b) + + return +} + +func lineFunctionDouble(r *twistPoint, q *curvePoint) (a, b, c, d *gfP2, rOut *twistPoint) { + // See the doubling algorithm for a=0 from "Faster Computation of the + // Tate Pairing", http://arxiv.org/pdf/0904.0854v3.pdf + A := (&gfP2{}).Square(&r.x) + B := (&gfP2{}).Square(&r.y) + C := (&gfP2{}).Square(B) // C = Yr ^ 4 + + D := (&gfP2{}).Add(&r.x, B) + D.Square(D).Sub(D, A).Sub(D, C).Add(D, D) + + E := (&gfP2{}).Add(A, A) // + E.Add(E, A) // E = 3 * Xr ^ 2 + + G := (&gfP2{}).Square(E) // G = 9 * Xr^4 + + rOut = &twistPoint{} + rOut.x.Sub(G, D).Sub(&rOut.x, D) + + rOut.z.Add(&r.y, &r.z).Square(&rOut.z).Sub(&rOut.z, B).Sub(&rOut.z, &r.t) // Z3 = (Yr + Zr)^2 - Yr^2 - Zr^2 = 2Yr*Zr + + rOut.y.Sub(D, &rOut.x).Mul(&rOut.y, E) + t := (&gfP2{}).Add(C, C) // t = 2 * r.y ^ 4 + t.Add(t, t).Add(t, t) // t = 8 * Yr ^ 4 + rOut.y.Sub(&rOut.y, t) + + rOut.t.Square(&rOut.z) + + d = (&gfP2{}).Mul(&rOut.z, &rOut.t) // d = 2Yr*Zr^3 + + t.Mul(E, &r.t).Add(t, t) + b = (&gfP2{}).Neg(t) + b.MulScalar(b, &q.x) + + a = (&gfP2{}).Add(&r.x, E) + a.Square(a).Sub(a, A).Sub(a, G) + t.Add(B, B).Add(t, t) + a.Sub(a, t) + + c = (&gfP2{}).Mul(&rOut.z, &r.t) + c.Add(c, c).MulScalar(c, &q.y) + + return +} + +func mulLine(ret *gfP12, retDen *gfP4, a, b, c, d *gfP2) { + l := &gfP12{} + l.y.SetZero() + l.x.x.SetZero() + l.x.y.Set(b) + l.z.x.Set(c) + l.z.y.Set(a) + + ret.Mul(ret, l) + + lDen := &gfP4{} + lDen.x.Set(d) + lDen.y.SetZero() + retDen.Mul(retDen, lDen) +} + +// +// R-ate Pairing G2 x G1 -> GT +// +// P is a point of order q in G1. Q(x,y) is a point of order q in G2. +// Note that Q is a point on the sextic twist of the curve over Fp^2, P(x,y) is a point on the +// curve over the base field Fp +// +func miller(q *twistPoint, p *curvePoint) *gfP12 { + ret := (&gfP12{}).SetOne() + retDen := (&gfP4{}).SetOne() // denominator + + aAffine := &twistPoint{} + aAffine.Set(q) + aAffine.MakeAffine() + + bAffine := &curvePoint{} + bAffine.Set(p) + bAffine.MakeAffine() + + r := &twistPoint{} + r.Set(aAffine) + + r2 := (&gfP2{}).Square(&aAffine.y) + + for i := sixUPlus2.BitLen() - 2; i >= 0; i-- { + ret.Square(ret) + retDen.Square(retDen) + a, b, c, d, newR := lineFunctionDouble(r, bAffine) + mulLine(ret, retDen, a, b, c, d) + r = newR + if sixUPlus2.Bit(i) == 1 { + a, b, c, d, newR = lineFunctionAdd(r, aAffine, bAffine, r2) + mulLine(ret, retDen, a, b, c, d) + r = newR + } + } + q1 := &twistPoint{} + q1.x.Conjugate(&aAffine.x) + q1.x.MulScalar(&q1.x, betaToNegPPlus1Over3) + q1.y.Conjugate(&aAffine.y) + q1.y.MulScalar(&q1.y, betaToNegPPlus1Over2) + q1.z.SetOne() + q1.t.SetOne() + + minusQ2 := &twistPoint{} + minusQ2.x.Set(&aAffine.x) + minusQ2.x.MulScalar(&minusQ2.x, betaToNegP2Plus1Over3) + minusQ2.y.Neg(&aAffine.y) + minusQ2.y.MulScalar(&minusQ2.y, betaToNegP2Plus1Over2) + minusQ2.z.SetOne() + minusQ2.t.SetOne() + + r2.Square(&q1.y) + a, b, c, d, newR := lineFunctionAdd(r, q1, bAffine, r2) + mulLine(ret, retDen, a, b, c, d) + r = newR + + r2.Square(&minusQ2.y) + a, b, c, d, _ = lineFunctionAdd(r, minusQ2, bAffine, r2) + mulLine(ret, retDen, a, b, c, d) + + retDen.Invert(retDen) + ret.MulScalar(ret, retDen) + + return ret +} + +func finalExponentiationHardPart(in *gfP12) *gfP12 { + a, b, t0, t1 := &gfP12{}, &gfP12{}, &gfP12{}, &gfP12{} + + a.Exp(in, sixUPlus5) + a.Invert(a) + b.Frobenius(a) + b.Mul(a, b) // b = ab + + a.Mul(a, b) + t0.Frobenius(in) + t1.Mul(t0, in) // t1 = in ^(p+1) + t1.Exp(t1, big.NewInt(9)) + a.Mul(a, t1) + + t1.Square(in) + t1.Square(t1) + a.Mul(a, t1) + + t0.Square(t0) // (in^p)^2 + t0.Mul(t0, b) // b*(in^p)^2 + b.FrobeniusP2(in) + t0.Mul(b, t0) // b*(in^p)^2 * in^(p^2) + t0.Exp(t0, sixU2Plus1) + a.Mul(a, t0) + + b.FrobeniusP3(in) + b.Mul(a, b) + return b +} + +// finalExponentiation computes the (p¹²-1)/Order-th power of an element of +// GF(p¹²) to obtain an element of GT. https://eprint.iacr.org/2007/390.pdf +func finalExponentiation(in *gfP12) *gfP12 { + t0, t1 := &gfP12{}, &gfP12{} + + t0.FrobeniusP6(in) + t1.Invert(in) + t0.Mul(t0, t1) + t1.FrobeniusP2(t0) + t0.Mul(t0, t1) + + return finalExponentiationHardPart(t0) +} + +func pairing(a *twistPoint, b *curvePoint) *gfP12 { + e := miller(a, b) + ret := finalExponentiation(e) + + if a.IsInfinity() || b.IsInfinity() { + ret.SetOne() + } + return ret +} diff --git a/sm9/bn_pair_test.go b/sm9/bn_pair_test.go new file mode 100644 index 0000000..9ef8c9c --- /dev/null +++ b/sm9/bn_pair_test.go @@ -0,0 +1,148 @@ +package sm9 + +import ( + "math/big" + "testing" +) + +var expected1 = &gfP12{} +var expected_b2 = &gfP12{} +var expected_b2_2 = &gfP12{} + +func init() { + expected1.x.x.x = *fromBigInt(bigFromHex("4e378fb5561cd0668f906b731ac58fee25738edf09cadc7a29c0abc0177aea6d")) + expected1.x.x.y = *fromBigInt(bigFromHex("28b3404a61908f5d6198815c99af1990c8af38655930058c28c21bb539ce0000")) + expected1.x.y.x = *fromBigInt(bigFromHex("38bffe40a22d529a0c66124b2c308dac9229912656f62b4facfced408e02380f")) + expected1.x.y.y = *fromBigInt(bigFromHex("a01f2c8bee81769609462c69c96aa923fd863e209d3ce26dd889b55e2e3873db")) + expected1.y.x.x = *fromBigInt(bigFromHex("67e0e0c2eed7a6993dce28fe9aa2ef56834307860839677f96685f2b44d0911f")) + expected1.y.x.y = *fromBigInt(bigFromHex("5a1ae172102efd95df7338dbc577c66d8d6c15e0a0158c7507228efb078f42a6")) + expected1.y.y.x = *fromBigInt(bigFromHex("1604a3fcfa9783e667ce9fcb1062c2a5c6685c316dda62de0548baa6ba30038b")) + expected1.y.y.y = *fromBigInt(bigFromHex("93634f44fa13af76169f3cc8fbea880adaff8475d5fd28a75deb83c44362b439")) + expected1.z.x.x = *fromBigInt(bigFromHex("b3129a75d31d17194675a1bc56947920898fbf390a5bf5d931ce6cbb3340f66d")) + expected1.z.x.y = *fromBigInt(bigFromHex("4c744e69c4a2e1c8ed72f796d151a17ce2325b943260fc460b9f73cb57c9014b")) + expected1.z.y.x = *fromBigInt(bigFromHex("84b87422330d7936eaba1109fa5a7a7181ee16f2438b0aeb2f38fd5f7554e57a")) + expected1.z.y.y = *fromBigInt(bigFromHex("aab9f06a4eeba4323a7833db202e4e35639d93fa3305af73f0f071d7d284fcfb")) + + expected_b2.x.x.x = *fromBigInt(bigFromHex("28542FB6954C84BE6A5F2988A31CB6817BA0781966FA83D9673A9577D3C0C134")) + expected_b2.x.x.y = *fromBigInt(bigFromHex("5E27C19FC02ED9AE37F5BB7BE9C03C2B87DE027539CCF03E6B7D36DE4AB45CD1")) + expected_b2.x.y.x = *fromBigInt(bigFromHex("A1ABFCD30C57DB0F1A838E3A8F2BF823479C978BD137230506EA6249C891049E")) + expected_b2.x.y.y = *fromBigInt(bigFromHex("3497477913AB89F5E2960F382B1B5C8EE09DE0FA498BA95C4409D630D343DA40")) + expected_b2.y.x.x = *fromBigInt(bigFromHex("4FEC93472DA33A4DB6599095C0CF895E3A7B993EE5E4EBE3B9AB7D7D5FF2A3D1")) + expected_b2.y.x.y = *fromBigInt(bigFromHex("647BA154C3E8E185DFC33657C1F128D480F3F7E3F16801208029E19434C733BB")) + expected_b2.y.y.x = *fromBigInt(bigFromHex("73F21693C66FC23724DB26380C526223C705DAF6BA18B763A68623C86A632B05")) + expected_b2.y.y.y = *fromBigInt(bigFromHex("0F63A071A6D62EA45B59A1942DFF5335D1A232C9C5664FAD5D6AF54C11418B0D")) + expected_b2.z.x.x = *fromBigInt(bigFromHex("8C8E9D8D905780D50E779067F2C4B1C8F83A8B59D735BB52AF35F56730BDE5AC")) + expected_b2.z.x.y = *fromBigInt(bigFromHex("861CCD9978617267CE4AD9789F77739E62F2E57B48C2FF26D2E90A79A1D86B93")) + expected_b2.z.y.x = *fromBigInt(bigFromHex("9B1CA08F64712E33AEDA3F44BD6CB633E0F722211E344D73EC9BBEBC92142765")) + expected_b2.z.y.y = *fromBigInt(bigFromHex("6BA584CE742A2A3AB41C15D3EF94EDEB8EF74A2BDCDAAECC09ABA567981F6437")) + + expected_b2_2.x.x.x = *fromBigInt(bigFromHex("1052D6E9D13E381909DFF7B2B41E13C987D0A9068423B769480DACCE6A06F492")) + expected_b2_2.x.x.y = *fromBigInt(bigFromHex("5FFEB92AD870F97DC0893114DA22A44DBC9E7A8B6CA31A0CF0467265A1FB48C7")) + expected_b2_2.x.y.x = *fromBigInt(bigFromHex("2C5C3B37E4F2FF83DB33D98C0317BCBBBBF4AC6DF6B89ECA58268B280045E612")) + expected_b2_2.x.y.y = *fromBigInt(bigFromHex("6CED9E2D7C9CD3D5AD630DEFAB0B831506218037EE0F861CF9B43C78434AEC38")) + expected_b2_2.y.x.x = *fromBigInt(bigFromHex("0AE7BF3E1AEC0CB67A03440906C7DFB3BCD4B6EEEBB7E371F0094AD4A816088D")) + expected_b2_2.y.x.y = *fromBigInt(bigFromHex("98DBC791D0671CACA12236CDF8F39E15AEB96FAEB39606D5B04AC581746A663D")) + expected_b2_2.y.y.x = *fromBigInt(bigFromHex("00DD2B7416BAA91172E89D5309D834F78C1E31B4483BB97185931BAD7BE1B9B5")) + expected_b2_2.y.y.y = *fromBigInt(bigFromHex("7EBAC0349F8544469E60C32F6075FB0468A68147FF013537DF792FFCE024F857")) + expected_b2_2.z.x.x = *fromBigInt(bigFromHex("10CC2B561A62B62DA36AEFD60850714F49170FD94A0010C6D4B651B64F3A3A5E")) + expected_b2_2.z.x.y = *fromBigInt(bigFromHex("58C9687BEDDCD9E4FEDAB16B884D1FE6DFA117B2AB821F74E0BF7ACDA2269859")) + expected_b2_2.z.y.x = *fromBigInt(bigFromHex("2A430968F16086061904CE201847934B11CA0F9E9528F5A9D0CE8F015C9AEA79")) + expected_b2_2.z.y.y = *fromBigInt(bigFromHex("934FDDA6D3AB48C8571CE2354B79742AA498CB8CDDE6BD1FA5946345A1A652F6")) +} + +func Test_gfp12Gen(t *testing.T) { + ret := pairing(twistGen, curveGen) + if ret.x != gfP12Gen.x || ret.y != gfP12Gen.y || ret.z != gfP12Gen.z { + t.Errorf("not expected") + } +} + +func Test_Pairing_A2(t *testing.T) { + pk := bigFromHex("0130E78459D78545CB54C587E02CF480CE0B66340F319F348A1D5B1F2DC5F4") + g2 := &G2{} + g2.ScalarBaseMult(pk) + ret := pairing(g2.p, curveGen) + if ret.x != expected1.x || ret.y != expected1.y || ret.z != expected1.z { + t.Errorf("not expected") + } +} + +func Test_Pairing_B2(t *testing.T) { + deB := &twistPoint{} + deB.x.x = *fromBigInt(bigFromHex("74CCC3AC9C383C60AF083972B96D05C75F12C8907D128A17ADAFBAB8C5A4ACF7")) + deB.x.y = *fromBigInt(bigFromHex("01092FF4DE89362670C21711B6DBE52DCD5F8E40C6654B3DECE573C2AB3D29B2")) + deB.y.x = *fromBigInt(bigFromHex("44B0294AA04290E1524FF3E3DA8CFD432BB64DE3A8040B5B88D1B5FC86A4EBC1")) + deB.y.y = *fromBigInt(bigFromHex("8CFC48FB4FF37F1E27727464F3C34E2153861AD08E972D1625FC1A7BD18D5539")) + deB.z.SetOne() + deB.t.SetOne() + + rA := &curvePoint{} + rA.x = *fromBigInt(bigFromHex("7CBA5B19069EE66AA79D490413D11846B9BA76DD22567F809CF23B6D964BB265")) + rA.y = *fromBigInt(bigFromHex("A9760C99CB6F706343FED05637085864958D6C90902ABA7D405FBEDF7B781599")) + rA.z = *one + rA.t = *one + + ret := pairing(deB, rA) + if ret.x != expected_b2.x || ret.y != expected_b2.y || ret.z != expected_b2.z { + t.Errorf("not expected") + } +} + +func Test_Pairing_B2_2(t *testing.T) { + pubE := &curvePoint{} + pubE.x = *fromBigInt(bigFromHex("9174542668E8F14AB273C0945C3690C66E5DD09678B86F734C4350567ED06283")) + pubE.y = *fromBigInt(bigFromHex("54E598C6BF749A3DACC9FFFEDD9DB6866C50457CFC7AA2A4AD65C3168FF74210")) + pubE.z = *one + pubE.t = *one + + ret := pairing(twistGen, pubE) + ret.Exp(ret, bigFromHex("00018B98C44BEF9F8537FB7D071B2C928B3BC65BD3D69E1EEE213564905634FE")) + if ret.x != expected_b2_2.x || ret.y != expected_b2_2.y || ret.z != expected_b2_2.z { + t.Errorf("not expected") + } +} + +func Test_finalExponentiation(t *testing.T) { + x := &gfP12{ + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + } + got := finalExponentiation(x) + + exp := new(big.Int).Exp(p, big.NewInt(12), nil) + exp.Sub(exp, big.NewInt(1)) + exp.Div(exp, Order) + expected := (&gfP12{}).Exp(x, exp) + + if got.x != expected.x || got.y != expected.y || got.z != expected.z { + t.Errorf("got %v, expected %v\n", got, expected) + } +} diff --git a/sm9/constants.go b/sm9/constants.go new file mode 100644 index 0000000..41836f5 --- /dev/null +++ b/sm9/constants.go @@ -0,0 +1,82 @@ +package sm9 + +// u is the BN parameter that determines the prime: 600000000058f98a. +var u = bigFromHex("600000000058f98a") + +// sixUPlus2 = 6*u+2 +var sixUPlus2 = bigFromHex("02400000000215d93e") + +// sixUPlus5 = 6*u+5 +var sixUPlus5 = bigFromHex("02400000000215d941") + +// sixU2Plus1 = 6*u^2+1 +var sixU2Plus1 = bigFromHex("d8000000019062ed0000b98b0cb27659") + +// p is a prime over which we form a basic field: 36u⁴+36u³+24u²+6u+1. +var p = bigFromHex("b640000002a3a6f1d603ab4ff58ec74521f2934b1a7aeedbe56f9b27e351457d") + +// Order is the number of elements in both G₁ and G₂: 36u⁴+36u³+18u²+6u+1. +var Order = bigFromHex("b640000002a3a6f1d603ab4ff58ec74449f2934b18ea8beee56ee19cd69ecf25") + +// p2 is p, represented as little-endian 64-bit words. +var p2 = [4]uint64{0xe56f9b27e351457d, 0x21f2934b1a7aeedb, 0xd603ab4ff58ec745, 0xb640000002a3a6f1} + +// np is the negative inverse of p, mod 2^256. +var np = [4]uint64{0x892bc42c2f2ee42b, 0x181ae39613c8dbaf, 0x966a4b291522b137, 0xafd2bac5558a13b3} + +// rN1 is R^-1 where R = 2^256 mod p. +var rN1 = &gfP{0x0a1c7970e5df544d, 0xe74504e9a96b56cc, 0xcda02d92d4d62924, 0x7d2bc576fdf597d1} + +// r2 is R^2 where R = 2^256 mod p. +var r2 = &gfP{0x27dea312b417e2d2, 0x88f8105fae1a5d3f, 0xe479b522d6706e7b, 0x2ea795a656f62fbd} + +// r3 is R^3 where R = 2^256 mod p. +var r3 = &gfP{0x130257769df5827e, 0x36920fc0837ec76e, 0xcbec24519c22a142, 0x219be84a7c687090} + +// pMinus2 is p-2. +var pMinus2 = [4]uint64{0xe56f9b27e351457b, 0x21f2934b1a7aeedb, 0xd603ab4ff58ec745, 0xb640000002a3a6f1} + +// pMinus1Over2 is (p-1)/2. +var pMinus1Over2 = [4]uint64{0xf2b7cd93f1a8a2be, 0x90f949a58d3d776d, 0xeb01d5a7fac763a2, 0x5b2000000151d378} + +// pMinus1Over4 is (p-1)/4. +var pMinus1Over4 = bigFromHex("2d90000000a8e9bc7580ead3fd63b1d1487ca4d2c69ebbb6f95be6c9f8d4515f") + +// pMinus5Over8 is (p-5)/8. +var pMinus5Over8 = [4]uint64{0x7cadf364fc6a28af, 0xa43e5269634f5ddb, 0x3ac07569feb1d8e8, 0x16c80000005474de} + +// Montgomery encoding of 2^pMinus5Over8 +var twoExpPMinus5Over8 = &gfP{0xd5dd560c5235102a, 0xa3772bab091163ac, 0x0ed7304fd0711ab0, 0x8efb889ed7056e1e} + +// Frobenius Constant, frobConstant = i^((p-1)/6) +var frobConstant = fromBigInt(bigFromHex("3f23ea58e5720bdb843c6cfa9c08674947c5c86e0ddd04eda91d8354377b698b")) + +// vToPMinus1 is v^(p-1), vToPMinus1 ^ 2 = p - 1 +var vToPMinus1 = fromBigInt(bigFromHex("6c648de5dc0a3f2cf55acc93ee0baf159f9d411806dc5177f5b21fd3da24d011")) + +// wToPMinus1 is w^(p-1) +var wToPMinus1 = fromBigInt(bigFromHex("3f23ea58e5720bdb843c6cfa9c08674947c5c86e0ddd04eda91d8354377b698b")) + +// w2ToPMinus1 is (w^2)^(p-1) +var w2ToPMinus1 = fromBigInt(bigFromHex("0000000000000000f300000002a3a6f2780272354f8b78f4d5fc11967be65334")) + +// wToP2Minus1 is w^(p^2-1) +var wToP2Minus1 = fromBigInt(bigFromHex("0000000000000000f300000002a3a6f2780272354f8b78f4d5fc11967be65334")) + +// w2ToP2Minus1 is (w^2)^(p^2-1), w2ToP2Minus1 = vToPMinus1 * wToPMinus1 +var w2ToP2Minus1 = fromBigInt(bigFromHex("0000000000000000f300000002a3a6f2780272354f8b78f4d5fc11967be65333")) + +// vToPMinus1Mw2ToPMinus1 = vToPMinus1 * w2ToPMinus1 +var vToPMinus1Mw2ToPMinus1 = fromBigInt(bigFromHex("2d40a38cf6983351711e5f99520347cc57d778a9f8ff4c8a4c949c7fa2a96686")) + +// betaToNegPPlus1Over3 = i^(-(p-1)/3) +var betaToNegPPlus1Over3 = fromBigInt(bigFromHex("b640000002a3a6f0e303ab4ff2eb2052a9f02115caef75e70f738991676af24a")) + +// betaToNegPPlus1Over2 = i^(-(p-1)/2) +var betaToNegPPlus1Over2 = fromBigInt(bigFromHex("49db721a269967c4e0a8debc0783182f82555233139e9d63efbd7b54092c756c")) + +// betaToNegP2Plus1Over3 = i^(-(p^2-1)/3) +var betaToNegP2Plus1Over3 = fromBigInt(bigFromHex("b640000002a3a6f0e303ab4ff2eb2052a9f02115caef75e70f738991676af249")) + +// betaToNegP2Plus1Over2 = i^(-(p^2-1)/2) +var betaToNegP2Plus1Over2 = fromBigInt(bigFromHex("b640000002a3a6f1d603ab4ff58ec74521f2934b1a7aeedbe56f9b27e351457c")) diff --git a/sm9/curve.go b/sm9/curve.go new file mode 100644 index 0000000..583e226 --- /dev/null +++ b/sm9/curve.go @@ -0,0 +1,228 @@ +package sm9 + +import "math/big" + +// curvePoint implements the elliptic curve y²=x³+5. Points are kept in Jacobian +// form and t=z² when valid. G₁ is the set of points of this curve on GF(p). +type curvePoint struct { + x, y, z, t gfP +} + +var curveB = newGFp(5) + +// curveGen is the generator of G₁. +var curveGen = &curvePoint{ + x: *fromBigInt(bigFromHex("93DE051D62BF718FF5ED0704487D01D6E1E4086909DC3280E8C4E4817C66DDDD")), + y: *fromBigInt(bigFromHex("21FE8DDA4F21E607631065125C395BBC1C1C00CBFA6024350C464CD70A3EA616")), + z: *one, + t: *one, +} + +func (c *curvePoint) String() string { + c.MakeAffine() + x, y := &gfP{}, &gfP{} + montDecode(x, &c.x) + montDecode(y, &c.y) + return "(" + x.String() + ", " + y.String() + ")" +} + +func (c *curvePoint) Set(a *curvePoint) { + c.x.Set(&a.x) + c.y.Set(&a.y) + c.z.Set(&a.z) + c.t.Set(&a.t) +} + +// IsOnCurve returns true iff c is on the curve. +func (c *curvePoint) IsOnCurve() bool { + c.MakeAffine() + if c.IsInfinity() { // TBC: This is not same as golang elliptic + return true + } + + y2, x3 := &gfP{}, &gfP{} + gfpMul(y2, &c.y, &c.y) + gfpMul(x3, &c.x, &c.x) + gfpMul(x3, x3, &c.x) + gfpAdd(x3, x3, curveB) + + return *y2 == *x3 +} + +func (c *curvePoint) SetInfinity() { + c.x = *zero + c.y = *one + c.z = *zero + c.t = *zero +} + +func (c *curvePoint) IsInfinity() bool { + return c.z == *zero +} + +func (c *curvePoint) Add(a, b *curvePoint) { + if a.IsInfinity() { + c.Set(b) + return + } + if b.IsInfinity() { + c.Set(a) + return + } + + // See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/addition/add-2007-bl.op3 + + // Normalize the points by replacing a = [x1:y1:z1] and b = [x2:y2:z2] + // by [u1:s1:z1·z2] and [u2:s2:z1·z2] + // where u1 = x1·z2², s1 = y1·z2³ and u1 = x2·z1², s2 = y2·z1³ + z12, z22 := &gfP{}, &gfP{} + gfpMul(z12, &a.z, &a.z) + gfpMul(z22, &b.z, &b.z) + + u1, u2 := &gfP{}, &gfP{} + gfpMul(u1, &a.x, z22) + gfpMul(u2, &b.x, z12) + + t, s1 := &gfP{}, &gfP{} + gfpMul(t, &b.z, z22) + gfpMul(s1, &a.y, t) + + s2 := &gfP{} + gfpMul(t, &a.z, z12) + gfpMul(s2, &b.y, t) + + // Compute x = (2h)²(s²-u1-u2) + // where s = (s2-s1)/(u2-u1) is the slope of the line through + // (u1,s1) and (u2,s2). The extra factor 2h = 2(u2-u1) comes from the value of z below. + // This is also: + // 4(s2-s1)² - 4h²(u1+u2) = 4(s2-s1)² - 4h³ - 4h²(2u1) + // = r² - j - 2v + // with the notations below. + h := &gfP{} + gfpSub(h, u2, u1) + xEqual := *h == *zero + + gfpAdd(t, h, h) + // i = 4h² + i := &gfP{} + gfpMul(i, t, t) + // j = 4h³ + j := &gfP{} + gfpMul(j, h, i) + + gfpSub(t, s2, s1) + yEqual := *t == *one + if xEqual && yEqual { + c.Double(a) + return + } + r := &gfP{} + gfpAdd(r, t, t) + + v := &gfP{} + gfpMul(v, u1, i) + + // t4 = 4(s2-s1)² + t4, t6 := &gfP{}, &gfP{} + gfpMul(t4, r, r) + gfpAdd(t, v, v) + gfpSub(t6, t4, j) + + gfpSub(&c.x, t6, t) + + // Set y = -(2h)³(s1 + s*(x/4h²-u1)) + // This is also + // y = - 2·s1·j - (s2-s1)(2x - 2i·u1) = r(v-x) - 2·s1·j + gfpSub(t, v, &c.x) // t7 + gfpMul(t4, s1, j) // t8 + gfpAdd(t6, t4, t4) // t9 + gfpMul(t4, r, t) // t10 + gfpSub(&c.y, t4, t6) + + // Set z = 2(u2-u1)·z1·z2 = 2h·z1·z2 + gfpAdd(t, &a.z, &b.z) // t11 + gfpMul(t4, t, t) // t12 + gfpSub(t, t4, z12) // t13 + gfpSub(t4, t, z22) // t14 + gfpMul(&c.z, t4, h) +} + +func (c *curvePoint) Double(a *curvePoint) { + // See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/doubling/dbl-2009-l.op3 + A, B, C := &gfP{}, &gfP{}, &gfP{} + gfpMul(A, &a.x, &a.x) + gfpMul(B, &a.y, &a.y) + gfpMul(C, B, B) + + t, t2 := &gfP{}, &gfP{} + gfpAdd(t, &a.x, B) + gfpMul(t2, t, t) + gfpSub(t, t2, A) + gfpSub(t2, t, C) + + d, e, f := &gfP{}, &gfP{}, &gfP{} + gfpAdd(d, t2, t2) + gfpAdd(t, A, A) + gfpAdd(e, t, A) + gfpMul(f, e, e) + + gfpAdd(t, d, d) + gfpSub(&c.x, f, t) + + gfpMul(&c.z, &a.y, &a.z) + gfpAdd(&c.z, &c.z, &c.z) + + gfpAdd(t, C, C) + gfpAdd(t2, t, t) + gfpAdd(t, t2, t2) + gfpSub(&c.y, d, &c.x) + gfpMul(t2, e, &c.y) + gfpSub(&c.y, t2, t) +} + +func (c *curvePoint) Mul(a *curvePoint, scalar *big.Int) { + sum, t := &curvePoint{}, &curvePoint{} + sum.SetInfinity() + + for i := scalar.BitLen(); i >= 0; i-- { + t.Double(sum) + if scalar.Bit(i) != 0 { + sum.Add(t, a) + } else { + sum.Set(t) + } + } + + c.Set(sum) +} + +func (c *curvePoint) MakeAffine() { + if c.z == *one { + return + } else if c.z == *zero { + c.x = *zero + c.y = *one + c.t = *zero + return + } + + zInv := &gfP{} + zInv.Invert(&c.z) + + t, zInv2 := &gfP{}, &gfP{} + gfpMul(t, &c.y, zInv) + gfpMul(zInv2, zInv, zInv) + + gfpMul(&c.x, &c.x, zInv2) + gfpMul(&c.y, t, zInv2) + + c.z = *one + c.t = *one +} + +func (c *curvePoint) Neg(a *curvePoint) { + c.x.Set(&a.x) + gfpNeg(&c.y, &a.y) + c.z.Set(&a.z) + c.t = *zero +} diff --git a/sm9/elliptic.go b/sm9/elliptic.go new file mode 100644 index 0000000..85faafd --- /dev/null +++ b/sm9/elliptic.go @@ -0,0 +1,182 @@ +package sm9 + +import ( + "io" + "math/big" +) + +// A Curve represents a short-form Weierstrass curve with a=0. +// +// The behavior of Add, Double, and ScalarMult when the input is not a point on +// the curve is undefined. +// +// Note that the conventional point at infinity (0, 0) is not considered on the +// curve, although it can be returned by Add, Double, ScalarMult, or +// ScalarBaseMult (but not the Unmarshal or UnmarshalCompressed functions). +type Curve interface { + // Params returns the parameters for the curve. + Params() *CurveParams + // IsOnCurve reports whether the given (x,y) lies on the curve. + IsOnCurve(x, y *big.Int) bool + // Add returns the sum of (x1,y1) and (x2,y2) + Add(x1, y1, x2, y2 *big.Int) (x, y *big.Int) + // Double returns 2*(x,y) + Double(x1, y1 *big.Int) (x, y *big.Int) + // ScalarMult returns k*(Bx,By) where k is a number in big-endian form. + ScalarMult(x1, y1 *big.Int, k []byte) (x, y *big.Int) + // ScalarBaseMult returns k*G, where G is the base point of the group + // and k is an integer in big-endian form. + ScalarBaseMult(k []byte) (x, y *big.Int) +} + +var mask = []byte{0xff, 0x1, 0x3, 0x7, 0xf, 0x1f, 0x3f, 0x7f} + +// GenerateKey returns a public/private key pair. The private key is +// generated using the given reader, which must return random data. +func GenerateKey(curve Curve, rand io.Reader) (priv []byte, x, y *big.Int, err error) { + N := curve.Params().N + bitSize := N.BitLen() + byteLen := (bitSize + 7) / 8 + priv = make([]byte, byteLen) + + for x == nil { + _, err = io.ReadFull(rand, priv) + if err != nil { + return + } + // We have to mask off any excess bits in the case that the size of the + // underlying field is not a whole number of bytes. + priv[0] &= mask[bitSize%8] + // This is because, in tests, rand will return all zeros and we don't + // want to get the point at infinity and loop forever. + priv[1] ^= 0x42 + + // If the scalar is out of range, sample another random number. + if new(big.Int).SetBytes(priv).Cmp(N) >= 0 { + continue + } + + x, y = curve.ScalarBaseMult(priv) + } + return +} + +// Marshal converts a point on the curve into the uncompressed form specified in +// SEC 1, Version 2.0, Section 2.3.3. If the point is not on the curve (or is +// the conventional point at infinity), the behavior is undefined. +func Marshal(curve Curve, x, y *big.Int) []byte { + panicIfNotOnCurve(curve, x, y) + + byteLen := (curve.Params().BitSize + 7) / 8 + + ret := make([]byte, 1+2*byteLen) + ret[0] = 4 // uncompressed point + + x.FillBytes(ret[1 : 1+byteLen]) + y.FillBytes(ret[1+byteLen : 1+2*byteLen]) + + return ret +} + +// MarshalCompressed converts a point on the curve into the compressed form +// specified in SEC 1, Version 2.0, Section 2.3.3. If the point is not on the +// curve (or is the conventional point at infinity), the behavior is undefined. +func MarshalCompressed(curve Curve, x, y *big.Int) []byte { + panicIfNotOnCurve(curve, x, y) + byteLen := (curve.Params().BitSize + 7) / 8 + compressed := make([]byte, 1+byteLen) + compressed[0] = byte(y.Bit(0)) | 2 + x.FillBytes(compressed[1:]) + return compressed +} + +// unmarshaler is implemented by curves with their own constant-time Unmarshal. +// +// There isn't an equivalent interface for Marshal/MarshalCompressed because +// that doesn't involve any mathematical operations, only FillBytes and Bit. +type unmarshaler interface { + Unmarshal([]byte) (x, y *big.Int) + UnmarshalCompressed([]byte) (x, y *big.Int) +} + +// Unmarshal converts a point, serialized by Marshal, into an x, y pair. It is +// an error if the point is not in uncompressed form, is not on the curve, or is +// the point at infinity. On error, x = nil. +func Unmarshal(curve Curve, data []byte) (x, y *big.Int) { + if c, ok := curve.(unmarshaler); ok { + return c.Unmarshal(data) + } + + byteLen := (curve.Params().BitSize + 7) / 8 + if len(data) != 1+2*byteLen { + return nil, nil + } + if data[0] != 4 { // uncompressed form + return nil, nil + } + p := curve.Params().P + x = new(big.Int).SetBytes(data[1 : 1+byteLen]) + y = new(big.Int).SetBytes(data[1+byteLen:]) + if x.Cmp(p) >= 0 || y.Cmp(p) >= 0 { + return nil, nil + } + if !curve.IsOnCurve(x, y) { + return nil, nil + } + return +} + +// UnmarshalCompressed converts a point, serialized by MarshalCompressed, into +// an x, y pair. It is an error if the point is not in compressed form, is not +// on the curve, or is the point at infinity. On error, x = nil. +func UnmarshalCompressed(curve Curve, data []byte) (x, y *big.Int) { + if c, ok := curve.(unmarshaler); ok { + return c.UnmarshalCompressed(data) + } + + byteLen := (curve.Params().BitSize + 7) / 8 + if len(data) != 1+byteLen { + return nil, nil + } + if data[0] != 2 && data[0] != 3 { // compressed form + return nil, nil + } + p := curve.Params().P + x = new(big.Int).SetBytes(data[1:]) + if x.Cmp(p) >= 0 { + return nil, nil + } + // y² = x³ + b + y = curve.Params().polynomial(x) + y = y.ModSqrt(y, p) + if y == nil { + return nil, nil + } + if byte(y.Bit(0)) != data[0]&1 { + y.Neg(y).Mod(y, p) + } + if !curve.IsOnCurve(x, y) { + return nil, nil + } + return +} + +func panicIfNotOnCurve(curve Curve, x, y *big.Int) { + // (0, 0) is the point at infinity by convention. It's ok to operate on it, + // although IsOnCurve is documented to return false for it. See Issue 37294. + if x.Sign() == 0 && y.Sign() == 0 { + return + } + + if !curve.IsOnCurve(x, y) { + panic("sm9/elliptic: attempted operation on invalid point") + } +} + +func bigFromHex(s string) *big.Int { + b, ok := new(big.Int).SetString(s, 16) + if !ok { + panic("sm9/elliptic: internal error: invalid encoding") + } + return b +} diff --git a/sm9/g1.go b/sm9/g1.go new file mode 100644 index 0000000..e76c846 --- /dev/null +++ b/sm9/g1.go @@ -0,0 +1,331 @@ +package sm9 + +import ( + "crypto/rand" + "errors" + "io" + "math/big" + "math/bits" +) + +func randomK(r io.Reader) (k *big.Int, err error) { + for { + k, err = rand.Int(r, Order) + if k.Sign() > 0 || err != nil { + return + } + } +} + +// G1 is an abstract cyclic group. The zero value is suitable for use as the +// output of an operation, but cannot be used as an input. +type G1 struct { + p *curvePoint +} + +//Gen1 is the generator of G1. +var Gen1 = &G1{curveGen} + +// RandomG1 returns x and g₁ˣ where x is a random, non-zero number read from r. +func RandomG1(r io.Reader) (*big.Int, *G1, error) { + k, err := randomK(r) + if err != nil { + return nil, nil, err + } + + return k, new(G1).ScalarBaseMult(k), nil +} + +func (g *G1) String() string { + return "sm9.G1" + g.p.String() +} + +// ScalarBaseMult sets e to g*k where g is the generator of the group and then +// returns e. +func (e *G1) ScalarBaseMult(k *big.Int) *G1 { + if e.p == nil { + e.p = &curvePoint{} + } + e.p.Mul(curveGen, k) + return e +} + +// ScalarMult sets e to a*k and then returns e. +func (e *G1) ScalarMult(a *G1, k *big.Int) *G1 { + if e.p == nil { + e.p = &curvePoint{} + } + e.p.Mul(a.p, k) + return e +} + +// Add sets e to a+b and then returns e. +func (e *G1) Add(a, b *G1) *G1 { + if e.p == nil { + e.p = &curvePoint{} + } + e.p.Add(a.p, b.p) + return e +} + +// Double sets e to [2]a and then returns e. +func (e *G1) Double(a *G1) *G1 { + if e.p == nil { + e.p = &curvePoint{} + } + e.p.Double(a.p) + return e +} + +// Neg sets e to -a and then returns e. +func (e *G1) Neg(a *G1) *G1 { + if e.p == nil { + e.p = &curvePoint{} + } + e.p.Neg(a.p) + return e +} + +// Set sets e to a and then returns e. +func (e *G1) Set(a *G1) *G1 { + if e.p == nil { + e.p = &curvePoint{} + } + e.p.Set(a.p) + return e +} + +// Marshal converts e to a byte slice. +func (e *G1) Marshal() []byte { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + + if e.p == nil { + e.p = &curvePoint{} + } + + e.p.MakeAffine() + ret := make([]byte, numBytes*2) + if e.p.IsInfinity() { + return ret + } + temp := &gfP{} + + montDecode(temp, &e.p.x) + temp.Marshal(ret) + montDecode(temp, &e.p.y) + temp.Marshal(ret[numBytes:]) + + return ret +} + +// Unmarshal sets e to the result of converting the output of Marshal back into +// a group element and then returns e. +func (e *G1) Unmarshal(m []byte) ([]byte, error) { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + + if len(m) < 2*numBytes { + return nil, errors.New("sm9.G1: not enough data") + } + + if e.p == nil { + e.p = &curvePoint{} + } else { + e.p.x, e.p.y = gfP{0}, gfP{0} + } + + e.p.x.Unmarshal(m) + e.p.y.Unmarshal(m[numBytes:]) + montEncode(&e.p.x, &e.p.x) + montEncode(&e.p.y, &e.p.y) + + zero := gfP{0} + if e.p.x == zero && e.p.y == zero { + // This is the point at infinity. + e.p.y = *newGFp(1) + e.p.z = gfP{0} + e.p.t = gfP{0} + } else { + e.p.z = *newGFp(1) + e.p.t = *newGFp(1) + + if !e.p.IsOnCurve() { + return nil, errors.New("sm9.G1: malformed point") + } + } + + return m[2*numBytes:], nil +} + +type G1Curve struct { + params *CurveParams + g G1 +} + +var g1Curve = &G1Curve{ + params: &CurveParams{ + Name: "sm9", + BitSize: 256, + P: bigFromHex("B640000002A3A6F1D603AB4FF58EC74521F2934B1A7AEEDBE56F9B27E351457D"), + N: bigFromHex("B640000002A3A6F1D603AB4FF58EC74449F2934B18EA8BEEE56EE19CD69ECF25"), + B: bigFromHex("0000000000000000000000000000000000000000000000000000000000000005"), + Gx: bigFromHex("93DE051D62BF718FF5ED0704487D01D6E1E4086909DC3280E8C4E4817C66DDDD"), + Gy: bigFromHex("21FE8DDA4F21E607631065125C395BBC1C1C00CBFA6024350C464CD70A3EA616"), + }, + g: G1{}, +} + +func (g1 *G1Curve) pointFromAffine(x, y *big.Int) (a *G1, err error) { + a = &G1{&curvePoint{}} + if x.Sign() == 0 { + a.p.SetInfinity() + return a, nil + } + // Reject values that would not get correctly encoded. + if x.Sign() < 0 || y.Sign() < 0 { + return a, errors.New("negative coordinate") + } + if x.BitLen() > g1.params.BitSize || y.BitLen() > g1.params.BitSize { + return a, errors.New("overflowing coordinate") + } + a.p.x = *fromBigInt(x) + a.p.y = *fromBigInt(y) + a.p.z = *newGFp(1) + a.p.t = *newGFp(1) + + if !a.p.IsOnCurve() { + return a, errors.New("point not on G1 curve") + } + + return a, nil +} + +func (g1 *G1Curve) Params() *CurveParams { + return g1.params +} + +// normalizeScalar brings the scalar within the byte size of the order of the +// curve, as expected by the nistec scalar multiplication functions. +func (curve *G1Curve) normalizeScalar(scalar []byte) *big.Int { + byteSize := (curve.params.N.BitLen() + 7) / 8 + s := new(big.Int).SetBytes(scalar) + if len(scalar) > byteSize { + s.Mod(s, curve.params.N) + } + return s +} + +func (g1 *G1Curve) ScalarBaseMult(k []byte) (*big.Int, *big.Int) { + scalar := g1.normalizeScalar(k) + res := g1.g.ScalarBaseMult(scalar).Marshal() + return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:]) +} + +func (g1 *G1Curve) ScalarMult(Bx, By *big.Int, k []byte) (*big.Int, *big.Int) { + a, err := g1.pointFromAffine(Bx, By) + if err != nil { + panic("sm9: ScalarMult was called on an invalid point") + } + res := g1.g.ScalarMult(a, new(big.Int).SetBytes(k)).Marshal() + return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:]) +} + +func (g1 *G1Curve) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) { + a, err := g1.pointFromAffine(x1, y1) + if err != nil { + panic("sm9: Add was called on an invalid point") + } + b, err := g1.pointFromAffine(x2, y2) + if err != nil { + panic("sm9: Add was called on an invalid point") + } + res := g1.g.Add(a, b).Marshal() + return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:]) +} + +func (g1 *G1Curve) Double(x, y *big.Int) (*big.Int, *big.Int) { + a, err := g1.pointFromAffine(x, y) + if err != nil { + panic("sm9: Double was called on an invalid point") + } + res := g1.g.Double(a).Marshal() + return new(big.Int).SetBytes(res[:32]), new(big.Int).SetBytes(res[32:]) +} + +func (g1 *G1Curve) IsOnCurve(x, y *big.Int) bool { + _, err := g1.pointFromAffine(x, y) + return err == nil +} + +func lessThanP(x *gfP) int { + var b uint64 + _, b = bits.Sub64(x[0], p2[0], b) + _, b = bits.Sub64(x[1], p2[1], b) + _, b = bits.Sub64(x[2], p2[2], b) + _, b = bits.Sub64(x[3], p2[3], b) + return int(b) +} + +func (curve *G1Curve) UnmarshalCompressed(data []byte) (x, y *big.Int) { + if len(data) != 33 || (data[0] != 2 && data[0] != 3) { + return nil, nil + } + r := &gfP{} + r.Unmarshal(data[1:33]) + if lessThanP(r) == 0 { + return nil, nil + } + x = new(big.Int).SetBytes(data[1:33]) + p := &curvePoint{} + montEncode(r, r) + p.x = *r + p.z = *newGFp(1) + p.t = *newGFp(1) + y2 := &gfP{} + gfpMul(y2, r, r) + gfpMul(y2, y2, r) + gfpAdd(y2, y2, curveB) + y2.Sqrt(y2) + p.y = *y2 + if !p.IsOnCurve() { + return nil, nil + } + montDecode(y2, y2) + ret := make([]byte, 32) + y2.Marshal(ret) + y = new(big.Int).SetBytes(ret) + if byte(y.Bit(0)) != data[0]&1 { + gfpNeg(y2, y2) + y2.Marshal(ret) + y.SetBytes(ret) + } + return x, y +} + +func (curve *G1Curve) Unmarshal(data []byte) (x, y *big.Int) { + if len(data) != 65 || (data[0] != 4) { + return nil, nil + } + x1 := &gfP{} + x1.Unmarshal(data[1:33]) + y1 := &gfP{} + y1.Unmarshal(data[33:]) + if lessThanP(x1) == 0 || lessThanP(y1) == 0 { + return nil, nil + } + montEncode(x1, x1) + montEncode(y1, y1) + p := &curvePoint{ + x: *x1, + y: *y1, + z: *newGFp(1), + t: *newGFp(1), + } + if !p.IsOnCurve() { + return nil, nil + } + x = new(big.Int).SetBytes(data[1:33]) + y = new(big.Int).SetBytes(data[33:]) + return x, y +} diff --git a/sm9/g1_test.go b/sm9/g1_test.go new file mode 100644 index 0000000..42d1ede --- /dev/null +++ b/sm9/g1_test.go @@ -0,0 +1,414 @@ +package sm9 + +import ( + "crypto/rand" + "io" + "math/big" + "testing" + "time" +) + +type g1BaseMultTest struct { + k string +} + +var baseMultTests = []g1BaseMultTest{ + { + "112233445566778899", + }, + { + "112233445566778899112233445566778899", + }, + { + "6950511619965839450988900688150712778015737983940691968051900319680", + }, + { + "13479972933410060327035789020509431695094902435494295338570602119423", + }, + { + "13479971751745682581351455311314208093898607229429740618390390702079", + }, + { + "13479972931865328106486971546324465392952975980343228160962702868479", + }, + { + "11795773708834916026404142434151065506931607341523388140225443265536", + }, + { + "784254593043826236572847595991346435467177662189391577090", + }, + { + "13479767645505654746623887797783387853576174193480695826442858012671", + }, + { + "205688069665150753842126177372015544874550518966168735589597183", + }, + { + "13479966930919337728895168462090683249159702977113823384618282123295", + }, + { + "50210731791415612487756441341851895584393717453129007497216", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368041", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368042", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368043", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368044", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368045", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368046", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368047", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368048", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368049", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368050", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368051", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368052", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368053", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368054", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368055", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368056", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368057", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368058", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368059", + }, + { + "26959946667150639794667015087019625940457807714424391721682722368060", + }, +} + +func TestG1BaseMult(t *testing.T) { + g1 := g1Curve + g1Generic := g1.Params() + + scalars := make([]*big.Int, 0, len(baseMultTests)+1) + for i := 1; i <= 20; i++ { + k := new(big.Int).SetInt64(int64(i)) + scalars = append(scalars, k) + } + for _, e := range baseMultTests { + k, _ := new(big.Int).SetString(e.k, 10) + scalars = append(scalars, k) + } + k := new(big.Int).SetInt64(1) + k.Lsh(k, 500) + scalars = append(scalars, k) + + for i, k := range scalars { + x, y := g1.ScalarBaseMult(k.Bytes()) + x2, y2 := g1Generic.ScalarBaseMult(k.Bytes()) + if x.Cmp(x2) != 0 || y.Cmp(y2) != 0 { + t.Errorf("#%d: got (%x, %x), want (%x, %x)", i, x, y, x2, y2) + } + + if testing.Short() && i > 5 { + break + } + } +} + +func TestFuzz(t *testing.T) { + g1 := g1Curve + g1Generic := g1.Params() + + var scalar1 [32]byte + var scalar2 [32]byte + var timeout *time.Timer + + if testing.Short() { + timeout = time.NewTimer(10 * time.Millisecond) + } else { + timeout = time.NewTimer(2 * time.Second) + } + + for { + select { + case <-timeout.C: + return + default: + } + + io.ReadFull(rand.Reader, scalar1[:]) + io.ReadFull(rand.Reader, scalar2[:]) + + x, y := g1.ScalarBaseMult(scalar1[:]) + x2, y2 := g1Generic.ScalarBaseMult(scalar1[:]) + + xx, yy := g1.ScalarMult(x, y, scalar2[:]) + xx2, yy2 := g1Generic.ScalarMult(x2, y2, scalar2[:]) + + if x.Cmp(x2) != 0 || y.Cmp(y2) != 0 { + t.Fatalf("ScalarBaseMult does not match reference result with scalar: %x, please report this error to https://github.com/emmansun/gmsm/issues", scalar1) + } + + if xx.Cmp(xx2) != 0 || yy.Cmp(yy2) != 0 { + t.Fatalf("ScalarMult does not match reference result with scalars: %x and %x, please report this error to https://github.com/emmansun/gmsm/issues", scalar1, scalar2) + } + } +} + +func TestG1OnCurve(t *testing.T) { + if !g1Curve.IsOnCurve(g1Curve.Params().Gx, g1Curve.Params().Gy) { + t.Error("basepoint is not on the curve") + } +} + +func TestOffCurve(t *testing.T) { + x, y := new(big.Int).SetInt64(1), new(big.Int).SetInt64(1) + if g1Curve.IsOnCurve(x, y) { + t.Errorf("point off curve is claimed to be on the curve") + } + + byteLen := (g1Curve.Params().BitSize + 7) / 8 + b := make([]byte, 1+2*byteLen) + b[0] = 4 // uncompressed point + x.FillBytes(b[1 : 1+byteLen]) + y.FillBytes(b[1+byteLen : 1+2*byteLen]) + + x1, y1 := Unmarshal(g1Curve, b) + if x1 != nil || y1 != nil { + t.Errorf("unmarshaling a point not on the curve succeeded") + } +} + +func isInfinity(x, y *big.Int) bool { + return x.Sign() == 0 && y.Sign() == 0 +} + +func TestInfinity(t *testing.T) { + x0, y0 := new(big.Int), new(big.Int) + xG, yG := g1Curve.Params().Gx, g1Curve.Params().Gy + + if !isInfinity(g1Curve.ScalarMult(xG, yG, g1Curve.Params().N.Bytes())) { + t.Errorf("x^q != ∞") + } + if !isInfinity(g1Curve.ScalarMult(xG, yG, []byte{0})) { + t.Errorf("x^0 != ∞") + } + + if !isInfinity(g1Curve.ScalarMult(x0, y0, []byte{1, 2, 3})) { + t.Errorf("∞^k != ∞") + } + if !isInfinity(g1Curve.ScalarMult(x0, y0, []byte{0})) { + t.Errorf("∞^0 != ∞") + } + + if !isInfinity(g1Curve.ScalarBaseMult(g1Curve.Params().N.Bytes())) { + t.Errorf("b^q != ∞") + } + if !isInfinity(g1Curve.ScalarBaseMult([]byte{0})) { + t.Errorf("b^0 != ∞") + } + + if !isInfinity(g1Curve.Double(x0, y0)) { + t.Errorf("2∞ != ∞") + } + // There is no other point of order two on the NIST curves (as they have + // cofactor one), so Double can't otherwise return the point at infinity. + + nMinusOne := new(big.Int).Sub(g1Curve.Params().N, big.NewInt(1)) + x, y := g1Curve.ScalarMult(xG, yG, nMinusOne.Bytes()) + x, y = g1Curve.Add(x, y, xG, yG) + if !isInfinity(x, y) { + t.Errorf("x^(q-1) + x != ∞") + } + x, y = g1Curve.Add(xG, yG, x0, y0) + if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 { + t.Errorf("x+∞ != x") + } + x, y = g1Curve.Add(x0, y0, xG, yG) + if x.Cmp(xG) != 0 || y.Cmp(yG) != 0 { + t.Errorf("∞+x != x") + } + + if !g1Curve.IsOnCurve(x0, y0) { + t.Errorf("IsOnCurve(∞) != true") + } + + if xx, yy := Unmarshal(g1Curve, Marshal(g1Curve, x0, y0)); xx == nil || yy == nil { + t.Errorf("Unmarshal(Marshal(∞)) did return an error") + } + // We don't test UnmarshalCompressed(MarshalCompressed(∞)) because there are + // two valid points with x = 0. + if xx, yy := Unmarshal(g1Curve, []byte{0x00}); xx != nil || yy != nil { + t.Errorf("Unmarshal(∞) did not return an error") + } + byteLen := (g1Curve.Params().BitSize + 7) / 8 + buf := make([]byte, byteLen*2+1) + buf[0] = 4 // Uncompressed format. + if xx, yy := Unmarshal(g1Curve, buf); xx == nil || yy == nil { + t.Errorf("Unmarshal((0,0)) did return an error") + } +} + +func TestMarshal(t *testing.T) { + _, x, y, err := GenerateKey(g1Curve, rand.Reader) + if err != nil { + t.Fatal(err) + } + serialized := Marshal(g1Curve, x, y) + xx, yy := Unmarshal(g1Curve, serialized) + if xx == nil { + t.Fatal("failed to unmarshal") + } + if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { + t.Fatal("unmarshal returned different values") + } +} + +func TestInvalidCoordinates(t *testing.T) { + checkIsOnCurveFalse := func(name string, x, y *big.Int) { + if g1Curve.IsOnCurve(x, y) { + t.Errorf("IsOnCurve(%s) unexpectedly returned true", name) + } + } + + p := g1Curve.Params().P + _, x, y, _ := GenerateKey(g1Curve, rand.Reader) + xx, yy := new(big.Int), new(big.Int) + + // Check if the sign is getting dropped. + xx.Neg(x) + checkIsOnCurveFalse("-x, y", xx, y) + yy.Neg(y) + checkIsOnCurveFalse("x, -y", x, yy) + + // Check if negative values are reduced modulo P. + xx.Sub(x, p) + checkIsOnCurveFalse("x-P, y", xx, y) + yy.Sub(y, p) + checkIsOnCurveFalse("x, y-P", x, yy) + + // Check if positive values are reduced modulo P. + xx.Add(x, p) + checkIsOnCurveFalse("x+P, y", xx, y) + yy.Add(y, p) + checkIsOnCurveFalse("x, y+P", x, yy) + + // Check if the overflow is dropped. + xx.Add(x, new(big.Int).Lsh(big.NewInt(1), 535)) + checkIsOnCurveFalse("x+2⁵³⁵, y", xx, y) + yy.Add(y, new(big.Int).Lsh(big.NewInt(1), 535)) + checkIsOnCurveFalse("x, y+2⁵³⁵", x, yy) + + // Check if P is treated like zero (if possible). + // y^2 = x^3 + B + // y = mod_sqrt(x^3 + B) + // y = mod_sqrt(B) if x = 0 + // If there is no modsqrt, there is no point with x = 0, can't test x = P. + if yy := new(big.Int).ModSqrt(g1Curve.Params().B, p); yy != nil { + if !g1Curve.IsOnCurve(big.NewInt(0), yy) { + t.Fatal("(0, mod_sqrt(B)) is not on the curve?") + } + checkIsOnCurveFalse("P, y", p, yy) + } +} + +func TestLargeIsOnCurve(t *testing.T) { + large := big.NewInt(1) + large.Lsh(large, 1000) + if g1Curve.IsOnCurve(large, large) { + t.Errorf("(2^1000, 2^1000) is reported on the curve") + } +} + +func benchmarkAllCurves(b *testing.B, f func(*testing.B, Curve)) { + tests := []struct { + name string + curve Curve + }{ + {"sm9", g1Curve}, + {"sm9Parmas", g1Curve.Params()}, + } + for _, test := range tests { + curve := test.curve + b.Run(test.name, func(b *testing.B) { + f(b, curve) + }) + } +} + +func BenchmarkScalarBaseMult(b *testing.B) { + benchmarkAllCurves(b, func(b *testing.B, curve Curve) { + priv, _, _, _ := GenerateKey(curve, rand.Reader) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + x, _ := curve.ScalarBaseMult(priv) + // Prevent the compiler from optimizing out the operation. + priv[0] ^= byte(x.Bits()[0]) + } + }) +} + +func BenchmarkScalarMult(b *testing.B) { + benchmarkAllCurves(b, func(b *testing.B, curve Curve) { + _, x, y, _ := GenerateKey(curve, rand.Reader) + priv, _, _, _ := GenerateKey(curve, rand.Reader) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + x, y = curve.ScalarMult(x, y, priv) + } + }) +} + +func BenchmarkMarshalUnmarshal(b *testing.B) { + benchmarkAllCurves(b, func(b *testing.B, curve Curve) { + _, x, y, _ := GenerateKey(curve, rand.Reader) + b.Run("Uncompressed", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + buf := Marshal(curve, x, y) + xx, yy := Unmarshal(curve, buf) + if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { + b.Error("Unmarshal output different from Marshal input") + } + } + }) + b.Run("Compressed", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + buf := MarshalCompressed(curve, x, y) + xx, yy := UnmarshalCompressed(curve, buf) + if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { + b.Error("Unmarshal output different from Marshal input") + } + } + }) + }) +} diff --git a/sm9/g2.go b/sm9/g2.go new file mode 100644 index 0000000..877362c --- /dev/null +++ b/sm9/g2.go @@ -0,0 +1,148 @@ +package sm9 + +import ( + "errors" + "io" + "math/big" +) + +// G2 is an abstract cyclic group. The zero value is suitable for use as the +// output of an operation, but cannot be used as an input. +type G2 struct { + p *twistPoint +} + +// RandomG2 returns x and g₂ˣ where x is a random, non-zero number read from r. +func RandomG2(r io.Reader) (*big.Int, *G2, error) { + k, err := randomK(r) + if err != nil { + return nil, nil, err + } + + return k, new(G2).ScalarBaseMult(k), nil +} + +func (e *G2) String() string { + return "sm9.G2" + e.p.String() +} + +// ScalarBaseMult sets e to g*k where g is the generator of the group and then +// returns out. +func (e *G2) ScalarBaseMult(k *big.Int) *G2 { + if e.p == nil { + e.p = &twistPoint{} + } + e.p.Mul(twistGen, k) + return e +} + +// ScalarMult sets e to a*k and then returns e. +func (e *G2) ScalarMult(a *G2, k *big.Int) *G2 { + if e.p == nil { + e.p = &twistPoint{} + } + e.p.Mul(a.p, k) + return e +} + +// Add sets e to a+b and then returns e. +func (e *G2) Add(a, b *G2) *G2 { + if e.p == nil { + e.p = &twistPoint{} + } + e.p.Add(a.p, b.p) + return e +} + +// Neg sets e to -a and then returns e. +func (e *G2) Neg(a *G2) *G2 { + if e.p == nil { + e.p = &twistPoint{} + } + e.p.Neg(a.p) + return e +} + +// Set sets e to a and then returns e. +func (e *G2) Set(a *G2) *G2 { + if e.p == nil { + e.p = &twistPoint{} + } + e.p.Set(a.p) + return e +} + +// Marshal converts e into a byte slice. +func (e *G2) Marshal() []byte { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + + if e.p == nil { + e.p = &twistPoint{} + } + + e.p.MakeAffine() + ret := make([]byte, numBytes*4) + if e.p.IsInfinity() { + return ret + } + temp := &gfP{} + + montDecode(temp, &e.p.x.x) + temp.Marshal(ret) + montDecode(temp, &e.p.x.y) + temp.Marshal(ret[numBytes:]) + montDecode(temp, &e.p.y.x) + temp.Marshal(ret[2*numBytes:]) + montDecode(temp, &e.p.y.y) + temp.Marshal(ret[3*numBytes:]) + + return ret +} + +// Unmarshal sets e to the result of converting the output of Marshal back into +// a group element and then returns e. +func (e *G2) Unmarshal(m []byte) ([]byte, error) { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + if len(m) < 4*numBytes { + return nil, errors.New("sm9.G2: not enough data") + } + // Unmarshal the points and check their caps + if e.p == nil { + e.p = &twistPoint{} + } + var err error + if err = e.p.x.x.Unmarshal(m); err != nil { + return nil, err + } + if err = e.p.x.y.Unmarshal(m[numBytes:]); err != nil { + return nil, err + } + if err = e.p.y.x.Unmarshal(m[2*numBytes:]); err != nil { + return nil, err + } + if err = e.p.y.y.Unmarshal(m[3*numBytes:]); err != nil { + return nil, err + } + // Encode into Montgomery form and ensure it's on the curve + montEncode(&e.p.x.x, &e.p.x.x) + montEncode(&e.p.x.y, &e.p.x.y) + montEncode(&e.p.y.x, &e.p.y.x) + montEncode(&e.p.y.y, &e.p.y.y) + + if e.p.x.IsZero() && e.p.y.IsZero() { + // This is the point at infinity. + e.p.y.SetOne() + e.p.z.SetZero() + e.p.t.SetZero() + } else { + e.p.z.SetOne() + e.p.t.SetOne() + + if !e.p.IsOnCurve() { + return nil, errors.New("sm9.G2: malformed point") + } + } + return m[4*numBytes:], nil +} diff --git a/sm9/g2_test.go b/sm9/g2_test.go new file mode 100644 index 0000000..dc06c99 --- /dev/null +++ b/sm9/g2_test.go @@ -0,0 +1,52 @@ +package sm9 + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "testing" +) + +func TestG2(t *testing.T) { + k, Ga, err := RandomG2(rand.Reader) + if err != nil { + t.Fatal(err) + } + ma := Ga.Marshal() + + Gb := new(G2).ScalarBaseMult(k) + mb := Gb.Marshal() + + if !bytes.Equal(ma, mb) { + t.Errorf("bytes are different, expected %v, got %v", hex.EncodeToString(ma), hex.EncodeToString(mb)) + } +} + +func TestG2Marshal(t *testing.T) { + _, Ga, err := RandomG2(rand.Reader) + if err != nil { + t.Fatal(err) + } + ma := Ga.Marshal() + + Gb := new(G2) + _, err = Gb.Unmarshal(ma) + if err != nil { + t.Fatal(err) + } + mb := Gb.Marshal() + + if !bytes.Equal(ma, mb) { + t.Errorf("bytes are different, expected %v, got %v", hex.EncodeToString(ma), hex.EncodeToString(mb)) + } +} + +func BenchmarkG2(b *testing.B) { + x, _ := rand.Int(rand.Reader, Order) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + new(G2).ScalarBaseMult(x) + } +} diff --git a/sm9/gfp.go b/sm9/gfp.go new file mode 100644 index 0000000..0aee698 --- /dev/null +++ b/sm9/gfp.go @@ -0,0 +1,197 @@ +package sm9 + +import ( + "crypto/sha256" + "encoding/binary" + "errors" + "fmt" + "math/big" + + "golang.org/x/crypto/hkdf" +) + +type gfP [4]uint64 + +var zero = newGFp(0) +var one = newGFp(1) +var two = newGFp(2) + +func newGFp(x int64) (out *gfP) { + if x >= 0 { + out = &gfP{uint64(x)} + } else { + out = &gfP{uint64(-x)} + gfpNeg(out, out) + } + + montEncode(out, out) + return out +} + +func fromBigInt(x *big.Int) (out *gfP) { + out = &gfP{} + var a *big.Int + if x.Sign() >= 0 { + a = x + } else { + a = new(big.Int).Neg(x) + } + for i, v := range a.Bits() { + out[i] = uint64(v) + } + if x.Sign() < 0 { + gfpNeg(out, out) + } + if x.Sign() != 0 { + montEncode(out, out) + } + return out +} + +// hashToBase implements hashing a message to an element of the field. +// +// L = ceil((256+128)/8)=48, ctr = 0, i = 1 +func hashToBase(msg, dst []byte) *gfP { + var t [48]byte + info := []byte{'H', '2', 'C', byte(0), byte(1)} + r := hkdf.New(sha256.New, msg, dst, info) + if _, err := r.Read(t[:]); err != nil { + panic(err) + } + var x big.Int + v := x.SetBytes(t[:]).Mod(&x, p).Bytes() + v32 := [32]byte{} + for i := len(v) - 1; i >= 0; i-- { + v32[len(v)-1-i] = v[i] + } + u := &gfP{ + binary.LittleEndian.Uint64(v32[0*8 : 1*8]), + binary.LittleEndian.Uint64(v32[1*8 : 2*8]), + binary.LittleEndian.Uint64(v32[2*8 : 3*8]), + binary.LittleEndian.Uint64(v32[3*8 : 4*8]), + } + montEncode(u, u) + return u +} + +func (e *gfP) String() string { + return fmt.Sprintf("%16.16x%16.16x%16.16x%16.16x", e[3], e[2], e[1], e[0]) +} + +func (e *gfP) Set(f *gfP) { + e[0] = f[0] + e[1] = f[1] + e[2] = f[2] + e[3] = f[3] +} + +func (e *gfP) exp(f *gfP, bits [4]uint64) { + sum, power := &gfP{}, &gfP{} + sum.Set(rN1) + power.Set(f) + + for word := 0; word < 4; word++ { + for bit := uint(0); bit < 64; bit++ { + if (bits[word]>>bit)&1 == 1 { + gfpMul(sum, sum, power) + } + gfpMul(power, power, power) + } + } + + gfpMul(sum, sum, r3) + e.Set(sum) +} + +func (e *gfP) Invert(f *gfP) { + e.exp(f, pMinus2) +} + +func (e *gfP) Sqrt(f *gfP) { + // Since p = 8k+5, + // Atkin algorithm + // https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.896.6057&rep=rep1&type=pdf + // https://eprint.iacr.org/2012/685.pdf + // + a1, b, i := &gfP{}, &gfP{}, &gfP{} + a1.exp(f, pMinus5Over8) + gfpMul(b, twoExpPMinus5Over8, a1) // b=ta1 + gfpMul(a1, f, b) // a1=fb + gfpMul(i, two, a1) // i=2(fb) + gfpMul(i, i, b) // i=2(fb)b + gfpSub(i, i, one) // i=2(fb)b-1 + gfpMul(i, a1, i) // i=(fb)(2(fb)b-1) + e.Set(i) +} + +func (e *gfP) Marshal(out []byte) { + for w := uint(0); w < 4; w++ { + for b := uint(0); b < 8; b++ { + out[8*w+b] = byte(e[3-w] >> (56 - 8*b)) + } + } +} + +func (e *gfP) Unmarshal(in []byte) error { + // Unmarshal the bytes into little endian form + for w := uint(0); w < 4; w++ { + e[3-w] = 0 + for b := uint(0); b < 8; b++ { + e[3-w] += uint64(in[8*w+b]) << (56 - 8*b) + } + } + // Ensure the point respects the curve modulus + for i := 3; i >= 0; i-- { + if e[i] < p2[i] { + return nil + } + if e[i] > p2[i] { + return errors.New("sm9: coordinate exceeds modulus") + } + } + return errors.New("sm9: coordinate equals modulus") +} + +func montEncode(c, a *gfP) { gfpMul(c, a, r2) } +func montDecode(c, a *gfP) { gfpMul(c, a, &gfP{1}) } + +func sign0(e *gfP) int { + x := &gfP{} + montDecode(x, e) + for w := 3; w >= 0; w-- { + if x[w] > pMinus1Over2[w] { + return 1 + } else if x[w] < pMinus1Over2[w] { + return -1 + } + } + return 1 +} + +func legendre(e *gfP) int { + f := &gfP{} + // Since p = 8k+5, then e^(4k+2) is the Legendre symbol of e. + f.exp(e, pMinus1Over2) + + montDecode(f, f) + + if *f != [4]uint64{} { + return 2*int(f[0]&1) - 1 + } + + return 0 +} + +func (e *gfP) Div2(f *gfP) *gfP { + ret := &gfP{} + gfpMul(ret, f, twoInvert) + e.Set(ret) + return e +} + +var twoInvert = &gfP{} + +func init() { + t1 := newGFp(2) + twoInvert.Invert(t1) +} diff --git a/sm9/gfp12.go b/sm9/gfp12.go new file mode 100644 index 0000000..3346960 --- /dev/null +++ b/sm9/gfp12.go @@ -0,0 +1,369 @@ +package sm9 + +import "math/big" + +// For details of the algorithms used, see "Multiplication and Squaring on +// Pairing-Friendly Fields, Devegili et al. +// http://eprint.iacr.org/2006/471.pdf. +// + +// gfP12 implements the field of size p¹² as a cubic extension of gfP4 where v³=u +type gfP12 struct { + x, y, z gfP4 // value is xw² + yw + z +} + +func gfP12Decode(in *gfP12) *gfP12 { + out := &gfP12{} + out.x = *gfP4Decode(&in.x) + out.y = *gfP4Decode(&in.y) + out.z = *gfP4Decode(&in.z) + return out +} + +var gfP12Gen *gfP12 = &gfP12{ + x: gfP4{ + x: gfP2{ + x: *fromBigInt(bigFromHex("256943fbdb2bf87ab91ae7fbeaff14e146cf7e2279b9d155d13461e09b22f523")), + y: *fromBigInt(bigFromHex("0167b0280051495c6af1ec23ba2cd2ff1cdcdeca461a5ab0b5449e9091308310")), + }, + y: gfP2{ + x: *fromBigInt(bigFromHex("5e7addaddf7fbfe16291b4e89af50b8217ddc47ba3cba833c6e77c3fb027685e")), + y: *fromBigInt(bigFromHex("79d0c8337072c93fef482bb055f44d6247ccac8e8e12525854b3566236337ebe")), + }, + }, + y: gfP4{ + x: gfP2{ + x: *fromBigInt(bigFromHex("082cde173022da8cd09b28a2d80a8cee53894436a52007f978dc37f36116d39b")), + y: *fromBigInt(bigFromHex("3fa7ed741eaed99a58f53e3df82df7ccd3407bcc7b1d44a9441920ced5fb824f")), + }, + y: gfP2{ + x: *fromBigInt(bigFromHex("7fc6eb2aa771d99c9234fddd31752edfd60723e05a4ebfdeb5c33fbd47e0cf06")), + y: *fromBigInt(bigFromHex("6fa6b6fa6dd6b6d3b19a959a110e748154eef796dc0fc2dd766ea414de786968")), + }, + }, + z: gfP4{ + x: gfP2{ + x: *fromBigInt(bigFromHex("8ffe1c0e9de45fd0fed790ac26be91f6b3f0a49c084fe29a3fb6ed288ad7994d")), + y: *fromBigInt(bigFromHex("1664a1366beb3196f0443e15f5f9042a947354a5678430d45ba031cff06db927")), + }, + y: gfP2{ + x: *fromBigInt(bigFromHex("7f7c6d52b475e6aaa827fdc5b4175ac6929320f782d998f86b6b57cda42a0426")), + y: *fromBigInt(bigFromHex("36a699de7c136f78eee2dbac4ca9727bff0cee02ee920f5822e65ea170aa9669")), + }, + }, +} + +func (e *gfP12) String() string { + return "(" + e.x.String() + ", " + e.y.String() + ", " + e.z.String() + ")" +} + +func (e *gfP12) Set(a *gfP12) *gfP12 { + e.x.Set(&a.x) + e.y.Set(&a.y) + e.z.Set(&a.z) + return e +} + +func (e *gfP12) SetZero() *gfP12 { + e.x.SetZero() + e.y.SetZero() + e.z.SetZero() + return e +} + +func (e *gfP12) SetOne() *gfP12 { + e.x.SetZero() + e.y.SetZero() + e.z.SetOne() + return e +} + +func (e *gfP12) SetW() *gfP12 { + e.x.SetZero() + e.y.SetOne() + e.z.SetZero() + return e +} + +func (e *gfP12) SetW2() *gfP12 { + e.x.SetOne() + e.y.SetZero() + e.z.SetZero() + return e +} + +func (e *gfP12) IsZero() bool { + return e.x.IsZero() && e.y.IsZero() && e.z.IsZero() +} + +func (e *gfP12) IsOne() bool { + return e.x.IsZero() && e.y.IsZero() && e.z.IsOne() +} + +func (e *gfP12) Neg(a *gfP12) *gfP12 { + e.x.Neg(&a.x) + e.y.Neg(&a.y) + e.z.Neg(&a.z) + return e +} + +func (e *gfP12) Add(a, b *gfP12) *gfP12 { + e.x.Add(&a.x, &b.x) + e.y.Add(&a.y, &b.y) + e.z.Add(&a.z, &b.z) + return e +} + +func (e *gfP12) Sub(a, b *gfP12) *gfP12 { + e.x.Sub(&a.x, &b.x) + e.y.Sub(&a.y, &b.y) + e.z.Sub(&a.z, &b.z) + return e +} + +func (e *gfP12) MulScalar(a *gfP12, b *gfP4) *gfP12 { + e.x.Mul(&a.x, b) + e.y.Mul(&a.y, b) + e.z.Mul(&a.z, b) + return e +} + +func (e *gfP12) MulGFP2(a *gfP12, b *gfP2) *gfP12 { + e.x.MulScalar(&a.x, b) + e.y.MulScalar(&a.y, b) + e.z.MulScalar(&a.z, b) + return e +} + +func (e *gfP12) MulGFP(a *gfP12, b *gfP) *gfP12 { + e.x.MulGFP(&a.x, b) + e.y.MulGFP(&a.y, b) + e.z.MulGFP(&a.z, b) + return e +} + +func (e *gfP12) Mul(a, b *gfP12) *gfP12 { + // (z0 + y0*w + x0*w^2)* (z1 + y1*w + x1*w^2) + // z0*z1 + z0*y1*w + z0*x1*w^2 + // +y0*z1*w + y0*y1*w^2 + y0*x1*v + // +x0*z1*w^2 + x0*y1*v + x0*x1*v*w + //=(z0*z1+y0*x1*v+x0*y1*v) + (z0*y1+y0*z1+x0*x1*v)w + (z0*x1 + y0*y1 + x0*z1)*w^2 + tx, ty, tz, t := &gfP4{}, &gfP4{}, &gfP4{}, &gfP4{} + tz.Mul(&a.z, &b.z) + t.MulV(&a.y, &b.x) + tz.Add(tz, t) + t.MulV(&a.x, &b.y) + tz.Add(tz, t) + + ty.Mul(&a.z, &b.y) + t.Mul(&a.y, &b.z) + ty.Add(ty, t) + t.MulV(&a.x, &b.x) + ty.Add(ty, t) + + tx.Mul(&a.z, &b.x) + t.Mul(&a.y, &b.y) + tx.Add(tx, t) + t.Mul(&a.x, &b.z) + tx.Add(tx, t) + + e.x.Set(tx) + e.y.Set(ty) + e.z.Set(tz) + return e +} + +func (e *gfP12) Square(a *gfP12) *gfP12 { + // (z + y*w + x*w^2)* (z + y*w + x*w^2) + // z^2 + z*y*w + z*x*w^2 + y*z*w + y^2*w^2 + y*x*v + x*z*w^2 + x*y*v + x^2 *v *w + // (z^2 + y*x*v + x*y*v) + (z*y + y*z + v * x^2)w + (z*x + y^2 + x*z)*w^2 + // (z^2 + 2*x*y*v) + (v*x^2 + 2*y*z) *w + (y^2 + 2*x*z) * w^2 + tx, ty, tz, t := &gfP4{}, &gfP4{}, &gfP4{}, &gfP4{} + + tz.Square(&a.z) + t.MulV(&a.x, &a.y) + t.Add(t, t) + tz.Add(tz, t) + + ty.SquareV(&a.x) + t.Mul(&a.y, &a.z) + t.Add(t, t) + ty.Add(ty, t) + + tx.Square(&a.y) + t.Mul(&a.x, &a.z) + t.Add(t, t) + tx.Add(tx, t) + + e.x.Set(tx) + e.y.Set(ty) + e.z.Set(tz) + return e +} + +func (e *gfP12) Exp(f *gfP12, power *big.Int) *gfP12 { + sum := (&gfP12{}).SetOne() + t := &gfP12{} + + for i := power.BitLen() - 1; i >= 0; i-- { + t.Square(sum) + if power.Bit(i) != 0 { + sum.Mul(t, f) + } else { + sum.Set(t) + } + } + + e.Set(sum) + return e +} + +func (e *gfP12) Invert(a *gfP12) *gfP12 { + // See "Implementing cryptographic pairings", M. Scott, section 3.2. + // ftp://136.206.11.249/pub/crypto/pairings.pdf + + // Here we can give a short explanation of how it works: let j be a cubic root of + // unity in GF(p^4) so that 1+j+j²=0. + // Then (xτ² + yτ + z)(xj²τ² + yjτ + z)(xjτ² + yj²τ + z) + // = (xτ² + yτ + z)(Cτ²+Bτ+A) + // = (x³ξ²+y³ξ+z³-3ξxyz) = F is an element of the base field (the norm). + // + // On the other hand (xj²τ² + yjτ + z)(xjτ² + yj²τ + z) + // = τ²(y²-ξxz) + τ(ξx²-yz) + (z²-ξxy) + // + // So that's why A = (z²-ξxy), B = (ξx²-yz), C = (y²-ξxz) + t1 := (&gfP4{}).MulV(&a.x, &a.y) + A := (&gfP4{}).Square(&a.z) + A.Sub(A, t1) + + B := (&gfP4{}).SquareV(&a.x) + t1.Mul(&a.y, &a.z) + B.Sub(B, t1) + + C := (&gfP4{}).Square(&a.y) + t1.Mul(&a.x, &a.z) + C.Sub(C, t1) + + F := (&gfP4{}).MulV(C, &a.y) + t1.Mul(A, &a.z) + F.Add(F, t1) + t1.MulV(B, &a.x) + F.Add(F, t1) + + F.Invert(F) + + e.x.Mul(C, F) + e.y.Mul(B, F) + e.z.Mul(A, F) + return e +} + +// (z + y*w + x*w^2)^p +//= z^p + y^p*w*w^(p-1)+x^p*w^2*(w^2)^(p-1) +// w2ToP2Minus1 = vToPMinus1 * wToPMinus1 +func (e *gfP12) Frobenius(a *gfP12) *gfP12 { + x, y := &gfP2{}, &gfP2{} + + x.Conjugate(&a.z.x) + y.Conjugate(&a.z.y) + x.MulScalar(x, vToPMinus1) + e.z.x.Set(x) + e.z.y.Set(y) + + x.Conjugate(&a.y.x) + y.Conjugate(&a.y.y) + x.MulScalar(x, w2ToP2Minus1) + y.MulScalar(y, wToPMinus1) + e.y.x.Set(x) + e.y.y.Set(y) + + x.Conjugate(&a.x.x) + y.Conjugate(&a.x.y) + x.MulScalar(x, vToPMinus1Mw2ToPMinus1) + y.MulScalar(y, w2ToPMinus1) + e.x.x.Set(x) + e.x.y.Set(y) + + return e +} + +// (z + y*w + x*w^2)^(p^2) +//= z^(p^2) + y^(p^2)*w*w^((p^2)-1)+x^(p^2)*w^2*(w^2)^((p^2)-1) +func (e *gfP12) FrobeniusP2(a *gfP12) *gfP12 { + tx, ty, tz := &gfP4{}, &gfP4{}, &gfP4{} + + tz.Conjugate(&a.z) + + ty.Conjugate(&a.y) + ty.MulGFP(ty, wToP2Minus1) + + tx.Conjugate(&a.x) + tx.MulGFP(tx, w2ToP2Minus1) + + e.x.Set(tx) + e.y.Set(ty) + e.z.Set(tz) + + return e +} + +// (z + y*w + x*w^2)^(p^3) +//=z^(p^3) + y^(p^3)*w*w^((p^3)-1)+x^(p^3)*w^2*(w^2)^((p^3)-1) +//=z^(p^3) + y^(p^3)*w*vToPMinus1-x^(p^3)*w^2 +// vToPMinus1 * vToPMinus1 = -1 +func (e *gfP12) FrobeniusP3(a *gfP12) *gfP12 { + x, y := &gfP2{}, &gfP2{} + + x.Conjugate(&a.z.x) + y.Conjugate(&a.z.y) + x.MulScalar(x, vToPMinus1) + x.Neg(x) + e.z.x.Set(x) + e.z.y.Set(y) + + x.Conjugate(&a.y.x) + y.Conjugate(&a.y.y) + //x.MulScalar(x, vToPMinus1) + //x.Neg(x) + //x.MulScalar(x, vToPMinus1) + y.MulScalar(y, vToPMinus1) + e.y.x.Set(x) + e.y.y.Set(y) + + x.Conjugate(&a.x.x) + y.Conjugate(&a.x.y) + x.MulScalar(x, vToPMinus1) + y.Neg(y) + e.x.x.Set(x) + e.x.y.Set(y) + + return e +} + +// (z + y*w + x*w^2)^(p^6) +// = ((z + y*w + x*w^2)^(p^3))^(p^3) +func (e *gfP12) FrobeniusP6(a *gfP12) *gfP12 { + tx, ty, tz := &gfP4{}, &gfP4{}, &gfP4{} + + tz.Conjugate(&a.z) + + ty.Conjugate(&a.y) + ty.Neg(ty) + + tx.Conjugate(&a.x) + + e.x.Set(tx) + e.y.Set(ty) + e.z.Set(tz) + + return e +} + +// code logic from https://github.com/miracl/MIRACL/blob/master/source/curve/pairing/zzn12a.h +func (e *gfP12) Conjugate(a *gfP12) *gfP12 { + e.z.Conjugate(&a.z) + e.y.Conjugate(&a.y) + e.y.Neg(&e.y) + e.x.Conjugate(&a.x) + return e +} diff --git a/sm9/gfp12_test.go b/sm9/gfp12_test.go new file mode 100644 index 0000000..2d9b6ac --- /dev/null +++ b/sm9/gfp12_test.go @@ -0,0 +1,408 @@ +package sm9 + +import ( + "math/big" + "testing" +) + +func Test_gfP12Square(t *testing.T) { + x := &gfP12{ + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + *(&gfP4{}).SetOne(), + } + xmulx := &gfP12{} + xmulx.Mul(x, x) + xmulx = gfP12Decode(xmulx) + + x2 := &gfP12{} + x2.Square(x) + x2 = gfP12Decode(x2) + + if xmulx.x != x2.x || xmulx.y != x2.y || xmulx.z != x2.z { + t.Errorf("xmulx=%v, x2=%v", xmulx, x2) + } +} + +func Test_gfP12Invert(t *testing.T) { + x := &gfP12{ + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + *(&gfP4{}).SetOne(), + } + xInv := &gfP12{} + xInv.Invert(x) + + y := &gfP12{} + y.Mul(x, xInv) + if !y.IsOne() { + t.Fail() + } + x = &gfP12{ + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + *(&gfP4{}).SetZero(), + } + xInv.Invert(x) + + y.Mul(x, xInv) + if !y.IsOne() { + t.Fail() + } +} + +func Test_gfP12Frobenius_Case1(t *testing.T) { + expected := &gfP12{} + i := &gfP12{} + i.SetW() + pMinus1 := new(big.Int).Sub(p, big.NewInt(1)) + i.Exp(i, pMinus1) + i = gfP12Decode(i) + expected.z.x.SetZero() + expected.z.y.x.Set(zero) + expected.z.y.y.Set(fromBigInt(bigFromHex("3f23ea58e5720bdb843c6cfa9c08674947c5c86e0ddd04eda91d8354377b698b"))) + expected.x.SetZero() + expected.y.SetZero() + expected = gfP12Decode(expected) + if expected.x != i.x || expected.y != i.y || expected.z != i.z { + t.Errorf("got %v, expected %v", i, expected) + } +} + +func Test_gfP12Frobenius_Case2(t *testing.T) { + expected := &gfP12{} + i := &gfP12{} + i.SetW2() + pMinus1 := new(big.Int).Sub(p, big.NewInt(1)) + i.Exp(i, pMinus1) + i = gfP12Decode(i) + expected.z.x.SetZero() + expected.z.y.x.Set(zero) + expected.z.y.y.Set(fromBigInt(bigFromHex("0000000000000000f300000002a3a6f2780272354f8b78f4d5fc11967be65334"))) + expected.x.SetZero() + expected.y.SetZero() + expected = gfP12Decode(expected) + if expected.x != i.x || expected.y != i.y || expected.z != i.z { + t.Errorf("got %v, expected %v", i, expected) + } +} + +func Test_gfP12FrobeniusP2_Case1(t *testing.T) { + expected := &gfP12{} + i := &gfP12{} + i.SetW() + p2 := new(big.Int).Mul(p, p) + p2 = new(big.Int).Sub(p2, big.NewInt(1)) + i.Exp(i, p2) + i = gfP12Decode(i) + expected.z.x.SetZero() + expected.z.y.x.Set(zero) + expected.z.y.y.Set(fromBigInt(bigFromHex("0000000000000000f300000002a3a6f2780272354f8b78f4d5fc11967be65334"))) + expected.x.SetZero() + expected.y.SetZero() + expected = gfP12Decode(expected) + if expected.x != i.x || expected.y != i.y || expected.z != i.z { + t.Errorf("got %v, expected %v", i, expected) + } +} + +func Test_gfP12FrobeniusP2_Case2(t *testing.T) { + expected := &gfP12{} + i := &gfP12{} + i.SetW2() + p2 := new(big.Int).Mul(p, p) + p2 = new(big.Int).Sub(p2, big.NewInt(1)) + i.Exp(i, p2) + i = gfP12Decode(i) + expected.z.x.SetZero() + expected.z.y.x.Set(zero) + expected.z.y.y.Set(fromBigInt(bigFromHex("0000000000000000f300000002a3a6f2780272354f8b78f4d5fc11967be65333"))) + expected.x.SetZero() + expected.y.SetZero() + expected = gfP12Decode(expected) + if expected.x != i.x || expected.y != i.y || expected.z != i.z { + t.Errorf("got %v, expected %v", i, expected) + } +} + +func Test_gfP12FrobeniusP3_Case1(t *testing.T) { + expected := &gfP12{} + i := &gfP12{} + i.SetW() + p3 := new(big.Int).Mul(p, p) + p3.Mul(p3, p) + p3 = new(big.Int).Sub(p3, big.NewInt(1)) + i.Exp(i, p3) + i = gfP12Decode(i) + expected.z.x.SetZero() + expected.z.y.x.Set(zero) + expected.z.y.y.Set(fromBigInt(bigFromHex("6c648de5dc0a3f2cf55acc93ee0baf159f9d411806dc5177f5b21fd3da24d011"))) + expected.x.SetZero() + expected.y.SetZero() + expected = gfP12Decode(expected) + if expected.x != i.x || expected.y != i.y || expected.z != i.z { + t.Errorf("got %v, expected %v", i, expected) + } +} + +func Test_gfP12FrobeniusP3_Case2(t *testing.T) { + expected := &gfP12{} + i := &gfP12{} + i.SetW2() + p3 := new(big.Int).Mul(p, p) + p3.Mul(p3, p) + p3 = new(big.Int).Sub(p3, big.NewInt(1)) + i.Exp(i, p3) + i = gfP12Decode(i) + expected.z.x.SetZero() + expected.z.y.x.Set(zero) + expected.z.y.y.Set(fromBigInt(bigFromHex("b640000002a3a6f1d603ab4ff58ec74521f2934b1a7aeedbe56f9b27e351457c"))) // -1 + expected.x.SetZero() + expected.y.SetZero() + expected = gfP12Decode(expected) + if expected.x != i.x || expected.y != i.y || expected.z != i.z { + t.Errorf("got %v, expected %v", i, expected) + } +} + +func Test_gfP12Frobenius(t *testing.T) { + x := &gfP12{ + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + } + expected := &gfP12{} + expected.Exp(x, p) + got := &gfP12{} + got.Frobenius(x) + if expected.x != got.x || expected.y != got.y || expected.z != got.z { + t.Errorf("got %v, expected %v", got, expected) + } +} + +func Test_gfP12FrobeniusP2(t *testing.T) { + x := &gfP12{ + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + } + expected := &gfP12{} + p2 := new(big.Int).Mul(p, p) + expected.Exp(x, p2) + got := &gfP12{} + got.FrobeniusP2(x) + if expected.x != got.x || expected.y != got.y || expected.z != got.z { + t.Errorf("got %v, expected %v", got, expected) + } +} + +func Test_gfP12FrobeniusP3(t *testing.T) { + x := &gfP12{ + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + } + expected := &gfP12{} + p3 := new(big.Int).Mul(p, p) + p3.Mul(p3, p) + expected.Exp(x, p3) + got := &gfP12{} + got.FrobeniusP3(x) + if expected.x != got.x || expected.y != got.y || expected.z != got.z { + t.Errorf("got %v, expected %v", got, expected) + } +} + +func Test_gfP12FrobeniusP6(t *testing.T) { + x := &gfP12{ + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + }, + } + expected := &gfP12{} + p6 := new(big.Int).Mul(p, p) + p6.Mul(p6, p) + p6.Mul(p6, p6) + expected.Exp(x, p6) + got := &gfP12{} + got.FrobeniusP6(x) + if expected.x != got.x || expected.y != got.y || expected.z != got.z { + t.Errorf("got %v, expected %v", got, expected) + } +} + +func Test_W3(t *testing.T) { + w1 := (&gfP12{}).SetW() + w2 := (&gfP12{}).SetW2() + + w1.Mul(w2, w1) + w1 = gfP12Decode(w1) + gfp4zero := (&gfP4{}).SetZero() + gfp4v := (&gfP4{}).SetV() + gfp4v = gfP4Decode(gfp4v) + if w1.x != *gfp4zero || w1.y != *gfp4zero || w1.z != *gfp4v { + t.Errorf("not expected") + } +} diff --git a/sm9/gfp2.go b/sm9/gfp2.go new file mode 100644 index 0000000..b42860e --- /dev/null +++ b/sm9/gfp2.go @@ -0,0 +1,260 @@ +package sm9 + +import "math/big" + +// For details of the algorithms used, see "Multiplication and Squaring on +// Pairing-Friendly Fields, Devegili et al. +// http://eprint.iacr.org/2006/471.pdf. + +// gfP2 implements a field of size p² as a quadratic extension of the base field +// where i²=-2. +type gfP2 struct { + x, y gfP // value is xi+y. +} + +func gfP2Decode(in *gfP2) *gfP2 { + out := &gfP2{} + montDecode(&out.x, &in.x) + montDecode(&out.y, &in.y) + return out +} + +func (e *gfP2) String() string { + return "(" + e.x.String() + ", " + e.y.String() + ")" +} + +func (e *gfP2) Set(a *gfP2) *gfP2 { + e.x.Set(&a.x) + e.y.Set(&a.y) + return e +} + +func (e *gfP2) SetZero() *gfP2 { + e.x = *zero + e.y = *zero + return e +} + +func (e *gfP2) SetOne() *gfP2 { + e.x = *zero + e.y = *one + return e +} + +func (e *gfP2) SetU() *gfP2 { + e.x = *one + e.y = *zero + return e +} + +func (e *gfP2) SetFrobConstant() *gfP2 { + e.x = *zero + e.y = *frobConstant + return e +} + +func (e *gfP2) IsZero() bool { + return e.x == *zero && e.y == *zero +} + +func (e *gfP2) IsOne() bool { + return e.x == *zero && e.y == *one +} + +func (e *gfP2) Conjugate(a *gfP2) *gfP2 { + e.y.Set(&a.y) + gfpNeg(&e.x, &a.x) + return e +} + +func (e *gfP2) Neg(a *gfP2) *gfP2 { + gfpNeg(&e.x, &a.x) + gfpNeg(&e.y, &a.y) + return e +} + +func (e *gfP2) Add(a, b *gfP2) *gfP2 { + gfpAdd(&e.x, &a.x, &b.x) + gfpAdd(&e.y, &a.y, &b.y) + return e +} + +func (e *gfP2) Sub(a, b *gfP2) *gfP2 { + gfpSub(&e.x, &a.x, &b.x) + gfpSub(&e.y, &a.y, &b.y) + return e +} + +func (e *gfP2) Double(a *gfP2) *gfP2 { + gfpAdd(&e.x, &a.x, &a.x) + gfpAdd(&e.y, &a.y, &a.y) + return e +} + +func (e *gfP2) Triple(a *gfP2) *gfP2 { + gfpAdd(&e.x, &a.x, &a.x) + gfpAdd(&e.y, &a.y, &a.y) + + gfpAdd(&e.x, &e.x, &a.x) + gfpAdd(&e.y, &e.y, &a.y) + return e +} + +// See "Multiplication and Squaring in Pairing-Friendly Fields", +// http://eprint.iacr.org/2006/471.pdf +// The Karatsuba method +//(a0+a1*i)(b0+b1*i)=c0+c1*i, where +//c0 = a0*b0 - 2a1*b1 +//c1 = (a0 + a1)(b0 + b1) - a0*b0 - a1*b1 = a0*b1 + a1*b0 +func (e *gfP2) Mul(a, b *gfP2) *gfP2 { + tx, t := &gfP{}, &gfP{} + gfpMul(tx, &a.x, &b.y) + gfpMul(t, &b.x, &a.y) + gfpAdd(tx, tx, t) + + ty := &gfP{} + gfpMul(ty, &a.y, &b.y) + gfpMul(t, &a.x, &b.x) + gfpMul(t, t, two) + gfpSub(ty, ty, t) + + e.x.Set(tx) + e.y.Set(ty) + return e +} + +// MulU: a * b * i +//(a0+a1*i)(b0+b1*i)*i=c0+c1*i, where +//c1 = (a0*b0 - 2a1*b1)i +//c0 = -2 * ((a0 + a1)(b0 + b1) - a0*b0 - a1*b1) = -2 * (a0*b1 + a1*b0) +func (e *gfP2) MulU(a, b *gfP2) *gfP2 { + // ty = -2 * (a0 * b1 + a1 * b0) + ty, t := &gfP{}, &gfP{} + gfpMul(ty, &a.x, &b.y) + gfpMul(t, &b.x, &a.y) + gfpAdd(ty, ty, t) + gfpAdd(ty, ty, ty) + gfpNeg(ty, ty) + + // tx = a0 * b0 - 2 * a1 * b1 + tx := &gfP{} + gfpMul(tx, &a.y, &b.y) + gfpMul(t, &a.x, &b.x) + gfpMul(t, t, two) + gfpSub(tx, tx, t) + + e.x.Set(tx) + e.y.Set(ty) + return e +} + +func (e *gfP2) Square(a *gfP2) *gfP2 { + // Complex squaring algorithm: + // (xi+y)² = y^2-2*x^2 + 2*i*x*y + tx, ty := &gfP{}, &gfP{} + gfpMul(tx, &a.x, &a.x) + gfpMul(ty, &a.y, &a.y) + gfpSub(ty, ty, tx) + gfpSub(ty, ty, tx) + + gfpMul(tx, &a.x, &a.y) + gfpAdd(tx, tx, tx) + + e.x.Set(tx) + e.y.Set(ty) + return e +} + +func (e *gfP2) SquareU(a *gfP2) *gfP2 { + // Complex squaring algorithm: + // (xi+y)²*i = (y^2-2*x^2)i - 4*x*y + + tx, ty := &gfP{}, &gfP{} + // tx = a0^2 - 2 * a1^2 + gfpMul(ty, &a.x, &a.x) + gfpMul(tx, &a.y, &a.y) + gfpAdd(ty, ty, ty) + gfpSub(tx, tx, ty) + + // ty = -4 * a0 * a1 + gfpMul(ty, &a.x, &a.y) + gfpAdd(ty, ty, ty) + gfpAdd(ty, ty, ty) + gfpNeg(ty, ty) + + e.x.Set(tx) + e.y.Set(ty) + return e +} + +func (e *gfP2) MulScalar(a *gfP2, b *gfP) *gfP2 { + gfpMul(&e.x, &a.x, b) + gfpMul(&e.y, &a.y, b) + return e +} + +func (e *gfP2) Invert(a *gfP2) *gfP2 { + // See "Implementing cryptographic pairings", M. Scott, section 3.2. + // ftp://136.206.11.249/pub/crypto/pairings.pdf + t1, t2, t3 := &gfP{}, &gfP{}, &gfP{} + gfpMul(t1, &a.x, &a.x) + gfpAdd(t3, t1, t1) + gfpMul(t2, &a.y, &a.y) + gfpAdd(t3, t3, t2) + + inv := &gfP{} + inv.Invert(t3) // inv = (2 * a.x ^ 2 + a.y ^ 2) ^ (-1) + + gfpNeg(t1, &a.x) + + gfpMul(&e.x, t1, inv) // x = - a.x * inv + gfpMul(&e.y, &a.y, inv) // y = a.y * inv + return e +} + +func (e *gfP2) Exp(f *gfP2, power *big.Int) *gfP2 { + sum := (&gfP2{}).SetOne() + t := &gfP2{} + + for i := power.BitLen() - 1; i >= 0; i-- { + t.Square(sum) + if power.Bit(i) != 0 { + sum.Mul(t, f) + } else { + sum.Set(t) + } + } + + e.Set(sum) + return e +} + +// (xi+y)^p = x * i^p + y +// = x * i * i^(p-1) + y +// = (-x)*i + y +// here i^(p-1) = -1 +func (e *gfP2) Frobenius(a *gfP2) *gfP2 { + e.Conjugate(a) + return e +} + +// Sqrt method is only required when we implement compressed format +func (e *gfP2) Sqrt(f *gfP2) *gfP2 { + // Algorithm 10 https://eprint.iacr.org/2012/685.pdf + // TODO + b, b2, bq := &gfP2{}, &gfP2{}, &gfP2{} + b.Exp(f, pMinus1Over4) + b2.Mul(b, b) + bq.Exp(b, p) + + return bq +} + +func (e *gfP2) Div2(f *gfP2) *gfP2 { + t := &gfP2{} + t.x.Div2(&f.x) + t.y.Div2(&f.y) + + e.Set(t) + return e +} diff --git a/sm9/gfp2_test.go b/sm9/gfp2_test.go new file mode 100644 index 0000000..6f08a10 --- /dev/null +++ b/sm9/gfp2_test.go @@ -0,0 +1,120 @@ +package sm9 + +import ( + "math/big" + "testing" +) + +func Test_gfP2Square(t *testing.T) { + x := &gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + } + + xmulx := &gfP2{} + xmulx.Mul(x, x) + xmulx = gfP2Decode(xmulx) + + x2 := &gfP2{} + x2.Square(x) + x2 = gfP2Decode(x2) + + if xmulx.x != x2.x || xmulx.y != x2.y { + t.Errorf("xmulx=%v, x2=%v", xmulx, x2) + } +} + +func Test_gfP2Invert(t *testing.T) { + x := &gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + } + + xInv := &gfP2{} + xInv.Invert(x) + + y := &gfP2{} + y.Mul(x, xInv) + expected := (&gfP2{}).SetOne() + + if y.x != expected.x || y.y != expected.y { + t.Errorf("got %v, expected %v", y, expected) + } + + x = &gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *zero, + } + + xInv.Invert(x) + + y.Mul(x, xInv) + + if y.x != expected.x || y.y != expected.y { + t.Errorf("got %v, expected %v", y, expected) + } + + x = &gfP2{ + *zero, + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + } + + xInv.Invert(x) + + y.Mul(x, xInv) + + if y.x != expected.x || y.y != expected.y { + t.Errorf("got %v, expected %v", y, expected) + } +} + +func Test_gfP2Exp(t *testing.T) { + x := &gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + } + got := &gfP2{} + got.Exp(x, big.NewInt(1)) + if x.x != got.x || x.y != got.y { + t.Errorf("got %v, expected %v", got, x) + } +} + +func Test_gfP2Frobenius(t *testing.T) { + x := &gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + } + expected := &gfP2{} + expected.Exp(x, p) + got := &gfP2{} + got.Frobenius(x) + if expected.x != got.x || expected.y != got.y { + t.Errorf("got %v, expected %v", got, x) + } + + // make sure i^(p-1) = -1 + i := &gfP2{} + i.SetU() + i.Exp(i, bigFromHex("b640000002a3a6f1d603ab4ff58ec74521f2934b1a7aeedbe56f9b27e351457c")) + i = gfP2Decode(i) + expected.y.Set(newGFp(-1)) + expected.x.Set(zero) + expected = gfP2Decode(expected) + if expected.x != i.x || expected.y != i.y { + t.Errorf("got %v, expected %v", i, expected) + } +} + +func Test_gfP2Div2(t *testing.T) { + x := &gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + } + ret := &gfP2{} + ret.Div2(x) + ret.Add(ret, ret) + if *ret != *x { + t.Errorf("got %v, expected %v", ret, x) + } +} diff --git a/sm9/gfp4.go b/sm9/gfp4.go new file mode 100644 index 0000000..94a1b4b --- /dev/null +++ b/sm9/gfp4.go @@ -0,0 +1,243 @@ +package sm9 + +import "math/big" + +// For details of the algorithms used, see "Multiplication and Squaring on +// Pairing-Friendly Fields, Devegili et al. +// http://eprint.iacr.org/2006/471.pdf. +// + +// gfP4 implements the field of size p^4 as a quadratic extension of gfP2 +// where u²=i. +type gfP4 struct { + x, y gfP2 // value is xi+y. +} + +func gfP4Decode(in *gfP4) *gfP4 { + out := &gfP4{} + out.x = *gfP2Decode(&in.x) + out.y = *gfP2Decode(&in.y) + return out +} + +func (e *gfP4) String() string { + return "(" + e.x.String() + ", " + e.y.String() + ")" +} + +func (e *gfP4) Set(a *gfP4) *gfP4 { + e.x.Set(&a.x) + e.y.Set(&a.y) + return e +} + +func (e *gfP4) SetZero() *gfP4 { + e.x.SetZero() + e.y.SetZero() + return e +} + +func (e *gfP4) SetOne() *gfP4 { + e.x.SetZero() + e.y.SetOne() + return e +} + +func (e *gfP4) SetV() *gfP4 { + e.x.SetOne() + e.y.SetZero() + return e +} + +func (e *gfP4) IsZero() bool { + return e.x.IsZero() && e.y.IsZero() +} + +func (e *gfP4) IsOne() bool { + return e.x.IsZero() && e.y.IsOne() +} + +func (e *gfP4) Conjugate(a *gfP4) *gfP4 { + e.y.Set(&a.y) + e.x.Neg(&a.x) + return e +} + +func (e *gfP4) Neg(a *gfP4) *gfP4 { + e.x.Neg(&a.x) + e.y.Neg(&a.y) + return e +} + +func (e *gfP4) Add(a, b *gfP4) *gfP4 { + e.x.Add(&a.x, &b.x) + e.y.Add(&a.y, &b.y) + return e +} + +func (e *gfP4) Sub(a, b *gfP4) *gfP4 { + e.x.Sub(&a.x, &b.x) + e.y.Sub(&a.y, &b.y) + return e +} + +func (e *gfP4) MulScalar(a *gfP4, b *gfP2) *gfP4 { + e.x.Mul(&a.x, b) + e.y.Mul(&a.y, b) + return e +} + +func (e *gfP4) MulGFP(a *gfP4, b *gfP) *gfP4 { + e.x.MulScalar(&a.x, b) + e.y.MulScalar(&a.y, b) + return e +} + +func (e *gfP4) Mul(a, b *gfP4) *gfP4 { + // "Multiplication and Squaring on Pairing-Friendly Fields" + // Section 4, Karatsuba method. + // http://eprint.iacr.org/2006/471.pdf + //(a0+a1*v)(b0+b1*v)=c0+c1*v, where + //c0 = a0*b0 +a1*b1*u + //c1 = (a0 + a1)(b0 + b1) - a0*b0 - a1*b1 = a0*b1 + a1*b0 + tx, t := &gfP2{}, &gfP2{} + tx.Mul(&a.x, &b.y) + t.Mul(&a.y, &b.x) + tx.Add(tx, t) + + ty := &gfP2{} + ty.Mul(&a.y, &b.y) + t.MulU(&a.x, &b.x) + ty.Add(ty, t) + + e.x.Set(tx) + e.y.Set(ty) + return e +} + +// MulV: a * b * v +//(a0+a1*v)(b0+b1*v)*v=c0+c1*v, where +// (a0*b0 + a0*b1v + a1*b0*v + a1*b1*u)*v +// a0*b0*v + a0*b1*u + a1*b0*u + a1*b1*u*v +// c0 = a0*b1*u + a1*b0*u +// c1 = a0*b0 + a1*b1*u +func (e *gfP4) MulV(a, b *gfP4) *gfP4 { + tx, ty, t := &gfP2{}, &gfP2{}, &gfP2{} + ty.MulU(&a.y, &b.x) + t.MulU(&a.x, &b.y) + ty.Add(ty, t) + + tx.Mul(&a.y, &b.y) + t.MulU(&a.x, &b.x) + tx.Add(tx, t) + + e.x.Set(tx) + e.y.Set(ty) + return e +} + +func (e *gfP4) Square(a *gfP4) *gfP4 { + // Complex squaring algorithm: + // (xv+y)² = (x^2*u + y^2) + 2*x*y*v + tx, ty := &gfP2{}, &gfP2{} + tx.SquareU(&a.x) + ty.Square(&a.y) + ty.Add(tx, ty) + + tx.Mul(&a.x, &a.y) + tx.Add(tx, tx) + + e.x.Set(tx) + e.y.Set(ty) + return e +} + +// SquareV: (a^2) * v +// v*(xv+y)² = (x^2*u + y^2)v + 2*x*y*u +func (e *gfP4) SquareV(a *gfP4) *gfP4 { + tx, ty := &gfP2{}, &gfP2{} + tx.SquareU(&a.x) + ty.Square(&a.y) + tx.Add(tx, ty) + + ty.MulU(&a.x, &a.y) + ty.Add(ty, ty) + + e.x.Set(tx) + e.y.Set(ty) + return e +} + +func (e *gfP4) Invert(a *gfP4) *gfP4 { + // See "Implementing cryptographic pairings", M. Scott, section 3.2. + // ftp://136.206.11.249/pub/crypto/pairings.pdf + t1, t2, t3 := &gfP2{}, &gfP2{}, &gfP2{} + + t3.SquareU(&a.x) + t1.Square(&a.y) + t3.Sub(t3, t1) + t3.Invert(t3) + + t1.Mul(&a.y, t3) + t1.Neg(t1) + + t2.Mul(&a.x, t3) + + e.x.Set(t2) + e.y.Set(t1) + return e +} + +func (e *gfP4) Exp(f *gfP4, power *big.Int) *gfP4 { + sum := (&gfP4{}).SetOne() + t := &gfP4{} + + for i := power.BitLen() - 1; i >= 0; i-- { + t.Square(sum) + if power.Bit(i) != 0 { + sum.Mul(t, f) + } else { + sum.Set(t) + } + } + + e.Set(sum) + return e +} + +// (y+x*v)^p +// = y^p + x^p*v^p +// = f(y) + f(x) * v^p +// = f(y) + f(x) * v * v^(p-1) +func (e *gfP4) Frobenius(a *gfP4) *gfP4 { + x, y := &gfP2{}, &gfP2{} + x.Conjugate(&a.x) + y.Conjugate(&a.y) + x.MulScalar(x, vToPMinus1) + + e.x.Set(x) + e.y.Set(y) + + return e +} + +// (y+x*v)^(p^2) +// y + x*v * v^(p^2-1) +func (e *gfP4) FrobeniusP2(a *gfP4) *gfP4 { + e.Conjugate(a) + return e +} + +// (y+x*v)^(p^3) +// = ((y+x*v)^p)^(p^2) +func (e *gfP4) FrobeniusP3(a *gfP4) *gfP4 { + x, y := &gfP2{}, &gfP2{} + x.Conjugate(&a.x) + y.Conjugate(&a.y) + x.MulScalar(x, vToPMinus1) + x.Neg(x) + + e.x.Set(x) + e.y.Set(y) + + return e +} diff --git a/sm9/gfp4_test.go b/sm9/gfp4_test.go new file mode 100644 index 0000000..fd16eae --- /dev/null +++ b/sm9/gfp4_test.go @@ -0,0 +1,180 @@ +package sm9 + +import ( + "math/big" + "testing" +) + +func Test_gfP4Square(t *testing.T) { + x := &gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + } + xmulx := &gfP4{} + xmulx.Mul(x, x) + xmulx = gfP4Decode(xmulx) + + x2 := &gfP4{} + x2.Square(x) + x2 = gfP4Decode(x2) + + if xmulx.x != x2.x || xmulx.y != x2.y { + t.Errorf("xmulx=%v, x2=%v", xmulx, x2) + } +} + +func Test_gfP4Invert(t *testing.T) { + gfp2Zero := (&gfP2{}).SetZero() + x := &gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + } + + xInv := &gfP4{} + xInv.Invert(x) + + y := &gfP4{} + y.Mul(x, xInv) + if !y.IsOne() { + t.Fail() + } + + x = &gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + *gfp2Zero, + } + + xInv.Invert(x) + + y.Mul(x, xInv) + if !y.IsOne() { + t.Fail() + } + + x = &gfP4{ + *gfp2Zero, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + } + + xInv.Invert(x) + + y.Mul(x, xInv) + if !y.IsOne() { + t.Fail() + } +} + +func Test_gfP4Frobenius(t *testing.T) { + x := &gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + } + expected := &gfP4{} + expected.Exp(x, p) + got := &gfP4{} + got.Frobenius(x) + if expected.x != got.x || expected.y != got.y { + t.Errorf("got %v, expected %v", got, expected) + } +} + +// Generate vToPMinus1 +func Test_gfP4Frobenius_Case1(t *testing.T) { + expected := &gfP4{} + i := &gfP4{} + i.SetV() + pMinus1 := new(big.Int).Sub(p, big.NewInt(1)) + i.Exp(i, pMinus1) + i = gfP4Decode(i) + expected.y.x.Set(zero) + expected.y.y.Set(fromBigInt(bigFromHex("6c648de5dc0a3f2cf55acc93ee0baf159f9d411806dc5177f5b21fd3da24d011"))) + expected.x.SetZero() + expected = gfP4Decode(expected) + if expected.x != i.x || expected.y != i.y { + t.Errorf("got %v, expected %v", i, expected) + } +} + +func Test_gfP4FrobeniusP2(t *testing.T) { + x := &gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + } + expected := &gfP4{} + p2 := new(big.Int).Mul(p, p) + expected.Exp(x, p2) + got := &gfP4{} + got.FrobeniusP2(x) + if expected.x != got.x || expected.y != got.y { + t.Errorf("got %v, expected %v", got, expected) + } +} + +func Test_gfP4FrobeniusP2_Case1(t *testing.T) { + expected := &gfP4{} + i := &gfP4{} + i.SetV() + p2 := new(big.Int).Mul(p, p) + p2 = new(big.Int).Sub(p2, big.NewInt(1)) + i.Exp(i, p2) + i = gfP4Decode(i) + expected.y.x.Set(zero) + expected.y.y.Set(newGFp(-1)) + expected.x.SetZero() + expected = gfP4Decode(expected) + if expected.x != i.x || expected.y != i.y { + t.Errorf("got %v, expected %v", i, expected) + } +} + +func Test_gfP4FrobeniusP3(t *testing.T) { + x := &gfP4{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + } + expected := &gfP4{} + p3 := new(big.Int).Mul(p, p) + p3 = p3.Mul(p3, p) + expected.Exp(x, p3) + got := &gfP4{} + got.FrobeniusP3(x) + if expected.x != got.x || expected.y != got.y { + t.Errorf("got %v, expected %v", got, expected) + } +} diff --git a/sm9/gfp_amd64.s b/sm9/gfp_amd64.s new file mode 100644 index 0000000..64c97ea --- /dev/null +++ b/sm9/gfp_amd64.s @@ -0,0 +1,129 @@ +// +build amd64,!generic + +#define storeBlock(a0,a1,a2,a3, r) \ + MOVQ a0, 0+r \ + MOVQ a1, 8+r \ + MOVQ a2, 16+r \ + MOVQ a3, 24+r + +#define loadBlock(r, a0,a1,a2,a3) \ + MOVQ 0+r, a0 \ + MOVQ 8+r, a1 \ + MOVQ 16+r, a2 \ + MOVQ 24+r, a3 + +#define gfpCarry(a0,a1,a2,a3,a4, b0,b1,b2,b3,b4) \ + \ // b = a-p + MOVQ a0, b0 \ + MOVQ a1, b1 \ + MOVQ a2, b2 \ + MOVQ a3, b3 \ + MOVQ a4, b4 \ + \ + SUBQ ·p2+0(SB), b0 \ + SBBQ ·p2+8(SB), b1 \ + SBBQ ·p2+16(SB), b2 \ + SBBQ ·p2+24(SB), b3 \ + SBBQ $0, b4 \ + \ + \ // if b is negative then return a + \ // else return b + CMOVQCC b0, a0 \ + CMOVQCC b1, a1 \ + CMOVQCC b2, a2 \ + CMOVQCC b3, a3 + +#include "mul_amd64.h" +#include "mul_bmi2_amd64.h" + +TEXT ·gfpNeg(SB),0,$0-16 + MOVQ ·p2+0(SB), R8 + MOVQ ·p2+8(SB), R9 + MOVQ ·p2+16(SB), R10 + MOVQ ·p2+24(SB), R11 + + MOVQ a+8(FP), DI + SUBQ 0(DI), R8 + SBBQ 8(DI), R9 + SBBQ 16(DI), R10 + SBBQ 24(DI), R11 + + MOVQ $0, AX + gfpCarry(R8,R9,R10,R11,AX, R12,R13,R14,CX,BX) + + MOVQ c+0(FP), DI + storeBlock(R8,R9,R10,R11, 0(DI)) + RET + +TEXT ·gfpAdd(SB),0,$0-24 + MOVQ a+8(FP), DI + MOVQ b+16(FP), SI + + loadBlock(0(DI), R8,R9,R10,R11) + MOVQ $0, R12 + + ADDQ 0(SI), R8 + ADCQ 8(SI), R9 + ADCQ 16(SI), R10 + ADCQ 24(SI), R11 + ADCQ $0, R12 + + gfpCarry(R8,R9,R10,R11,R12, R13,R14,CX,AX,BX) + + MOVQ c+0(FP), DI + storeBlock(R8,R9,R10,R11, 0(DI)) + RET + +TEXT ·gfpSub(SB),0,$0-24 + MOVQ a+8(FP), DI + MOVQ b+16(FP), SI + + loadBlock(0(DI), R8,R9,R10,R11) + + MOVQ ·p2+0(SB), R12 + MOVQ ·p2+8(SB), R13 + MOVQ ·p2+16(SB), R14 + MOVQ ·p2+24(SB), CX + MOVQ $0, AX + + SUBQ 0(SI), R8 + SBBQ 8(SI), R9 + SBBQ 16(SI), R10 + SBBQ 24(SI), R11 + + CMOVQCC AX, R12 + CMOVQCC AX, R13 + CMOVQCC AX, R14 + CMOVQCC AX, CX + + ADDQ R12, R8 + ADCQ R13, R9 + ADCQ R14, R10 + ADCQ CX, R11 + + MOVQ c+0(FP), DI + storeBlock(R8,R9,R10,R11, 0(DI)) + RET + +TEXT ·gfpMul(SB),0,$160-24 + MOVQ a+8(FP), DI + MOVQ b+16(FP), SI + + // Jump to a slightly different implementation if MULX isn't supported. + CMPB ·hasBMI2(SB), $0 + JE nobmi2Mul + + mulBMI2(0(DI),8(DI),16(DI),24(DI), 0(SI)) + storeBlock( R8, R9,R10,R11, 0(SP)) + storeBlock(R12,R13,R14,CX, 32(SP)) + gfpReduceBMI2() + JMP end + +nobmi2Mul: + mul(0(DI),8(DI),16(DI),24(DI), 0(SI), 0(SP)) + gfpReduce(0(SP)) + +end: + MOVQ c+0(FP), DI + storeBlock(R12,R13,R14,CX, 0(DI)) + RET diff --git a/sm9/gfp_arm64.s b/sm9/gfp_arm64.s new file mode 100644 index 0000000..c65e801 --- /dev/null +++ b/sm9/gfp_arm64.s @@ -0,0 +1,113 @@ +// +build arm64,!generic + +#define storeBlock(a0,a1,a2,a3, r) \ + MOVD a0, 0+r \ + MOVD a1, 8+r \ + MOVD a2, 16+r \ + MOVD a3, 24+r + +#define loadBlock(r, a0,a1,a2,a3) \ + MOVD 0+r, a0 \ + MOVD 8+r, a1 \ + MOVD 16+r, a2 \ + MOVD 24+r, a3 + +#define loadModulus(p0,p1,p2,p3) \ + MOVD ·p2+0(SB), p0 \ + MOVD ·p2+8(SB), p1 \ + MOVD ·p2+16(SB), p2 \ + MOVD ·p2+24(SB), p3 + +#include "mul_arm64.h" + +TEXT ·gfpNeg(SB),0,$0-16 + MOVD a+8(FP), R0 + loadBlock(0(R0), R1,R2,R3,R4) + loadModulus(R5,R6,R7,R8) + + SUBS R1, R5, R1 + SBCS R2, R6, R2 + SBCS R3, R7, R3 + SBCS R4, R8, R4 + + SUBS R5, R1, R5 + SBCS R6, R2, R6 + SBCS R7, R3, R7 + SBCS R8, R4, R8 + + CSEL CS, R5, R1, R1 + CSEL CS, R6, R2, R2 + CSEL CS, R7, R3, R3 + CSEL CS, R8, R4, R4 + + MOVD c+0(FP), R0 + storeBlock(R1,R2,R3,R4, 0(R0)) + RET + +TEXT ·gfpAdd(SB),0,$0-24 + MOVD a+8(FP), R0 + loadBlock(0(R0), R1,R2,R3,R4) + MOVD b+16(FP), R0 + loadBlock(0(R0), R5,R6,R7,R8) + loadModulus(R9,R10,R11,R12) + MOVD ZR, R0 + + ADDS R5, R1 + ADCS R6, R2 + ADCS R7, R3 + ADCS R8, R4 + ADCS ZR, R0 + + SUBS R9, R1, R5 + SBCS R10, R2, R6 + SBCS R11, R3, R7 + SBCS R12, R4, R8 + SBCS ZR, R0, R0 + + CSEL CS, R5, R1, R1 + CSEL CS, R6, R2, R2 + CSEL CS, R7, R3, R3 + CSEL CS, R8, R4, R4 + + MOVD c+0(FP), R0 + storeBlock(R1,R2,R3,R4, 0(R0)) + RET + +TEXT ·gfpSub(SB),0,$0-24 + MOVD a+8(FP), R0 + loadBlock(0(R0), R1,R2,R3,R4) + MOVD b+16(FP), R0 + loadBlock(0(R0), R5,R6,R7,R8) + loadModulus(R9,R10,R11,R12) + + SUBS R5, R1 + SBCS R6, R2 + SBCS R7, R3 + SBCS R8, R4 + + CSEL CS, ZR, R9, R9 + CSEL CS, ZR, R10, R10 + CSEL CS, ZR, R11, R11 + CSEL CS, ZR, R12, R12 + + ADDS R9, R1 + ADCS R10, R2 + ADCS R11, R3 + ADCS R12, R4 + + MOVD c+0(FP), R0 + storeBlock(R1,R2,R3,R4, 0(R0)) + RET + +TEXT ·gfpMul(SB),0,$0-24 + MOVD a+8(FP), R0 + loadBlock(0(R0), R1,R2,R3,R4) + MOVD b+16(FP), R0 + loadBlock(0(R0), R5,R6,R7,R8) + + mul(R9,R10,R11,R12,R13,R14,R15,R16) + gfpReduce() + + MOVD c+0(FP), R0 + storeBlock(R1,R2,R3,R4, 0(R0)) + RET diff --git a/sm9/gfp_decl.go b/sm9/gfp_decl.go new file mode 100644 index 0000000..4a8621c --- /dev/null +++ b/sm9/gfp_decl.go @@ -0,0 +1,25 @@ +//go:build (amd64 && !generic) || (arm64 && !generic) +// +build amd64,!generic arm64,!generic + +package sm9 + +// This file contains forward declarations for the architecture-specific +// assembly implementations of these functions, provided that they exist. + +import ( + "golang.org/x/sys/cpu" +) + +var hasBMI2 = cpu.X86.HasBMI2 + +// go:noescape +func gfpNeg(c, a *gfP) + +//go:noescape +func gfpAdd(c, a, b *gfP) + +//go:noescape +func gfpSub(c, a, b *gfP) + +//go:noescape +func gfpMul(c, a, b *gfP) diff --git a/sm9/gfp_generic.go b/sm9/gfp_generic.go new file mode 100644 index 0000000..87b42c3 --- /dev/null +++ b/sm9/gfp_generic.go @@ -0,0 +1,174 @@ +//go:build !amd64 && !arm64 || generic +// +build !amd64,!arm64 generic + +package sm9 + +func gfpCarry(a *gfP, head uint64) { + b := &gfP{} + + var carry uint64 + for i, pi := range p2 { + ai := a[i] + bi := ai - pi - carry + b[i] = bi + carry = (pi&^ai | (pi|^ai)&bi) >> 63 + } + carry = carry &^ head + + // If b is negative, then return a. + // Else return b. + carry = -carry + ncarry := ^carry + for i := 0; i < 4; i++ { + a[i] = (a[i] & carry) | (b[i] & ncarry) + } +} + +func gfpNeg(c, a *gfP) { + var carry uint64 + for i, pi := range p2 { + ai := a[i] + ci := pi - ai - carry + c[i] = ci + carry = (ai&^pi | (ai|^pi)&ci) >> 63 + } + gfpCarry(c, 0) +} + +func gfpAdd(c, a, b *gfP) { + var carry uint64 + for i, ai := range a { + bi := b[i] + ci := ai + bi + carry + c[i] = ci + carry = (ai&bi | (ai|bi)&^ci) >> 63 + } + gfpCarry(c, carry) +} + +func gfpSub(c, a, b *gfP) { + t := &gfP{} + + var carry uint64 + for i, pi := range p2 { + bi := b[i] + ti := pi - bi - carry + t[i] = ti + carry = (bi&^pi | (bi|^pi)&ti) >> 63 + } + + carry = 0 + for i, ai := range a { + ti := t[i] + ci := ai + ti + carry + c[i] = ci + carry = (ai&ti | (ai|ti)&^ci) >> 63 + } + gfpCarry(c, carry) +} + +func mul(a, b [4]uint64) [8]uint64 { + const ( + mask16 uint64 = 0x0000ffff + mask32 uint64 = 0xffffffff + ) + + var buff [32]uint64 + for i, ai := range a { + a0, a1, a2, a3 := ai&mask16, (ai>>16)&mask16, (ai>>32)&mask16, ai>>48 + + for j, bj := range b { + b0, b2 := bj&mask32, bj>>32 + + off := 4 * (i + j) + buff[off+0] += a0 * b0 + buff[off+1] += a1 * b0 + buff[off+2] += a2*b0 + a0*b2 + buff[off+3] += a3*b0 + a1*b2 + buff[off+4] += a2 * b2 + buff[off+5] += a3 * b2 + } + } + + for i := uint(1); i < 4; i++ { + shift := 16 * i + + var head, carry uint64 + for j := uint(0); j < 8; j++ { + block := 4 * j + + xi := buff[block] + yi := (buff[block+i] << shift) + head + zi := xi + yi + carry + buff[block] = zi + carry = (xi&yi | (xi|yi)&^zi) >> 63 + + head = buff[block+i] >> (64 - shift) + } + } + + return [8]uint64{buff[0], buff[4], buff[8], buff[12], buff[16], buff[20], buff[24], buff[28]} +} + +func halfMul(a, b [4]uint64) [4]uint64 { + const ( + mask16 uint64 = 0x0000ffff + mask32 uint64 = 0xffffffff + ) + + var buff [18]uint64 + for i, ai := range a { + a0, a1, a2, a3 := ai&mask16, (ai>>16)&mask16, (ai>>32)&mask16, ai>>48 + + for j, bj := range b { + if i+j > 3 { + break + } + b0, b2 := bj&mask32, bj>>32 + + off := 4 * (i + j) + buff[off+0] += a0 * b0 + buff[off+1] += a1 * b0 + buff[off+2] += a2*b0 + a0*b2 + buff[off+3] += a3*b0 + a1*b2 + buff[off+4] += a2 * b2 + buff[off+5] += a3 * b2 + } + } + + for i := uint(1); i < 4; i++ { + shift := 16 * i + + var head, carry uint64 + for j := uint(0); j < 4; j++ { + block := 4 * j + + xi := buff[block] + yi := (buff[block+i] << shift) + head + zi := xi + yi + carry + buff[block] = zi + carry = (xi&yi | (xi|yi)&^zi) >> 63 + + head = buff[block+i] >> (64 - shift) + } + } + + return [4]uint64{buff[0], buff[4], buff[8], buff[12]} +} + +func gfpMul(c, a, b *gfP) { + T := mul(*a, *b) + m := halfMul([4]uint64{T[0], T[1], T[2], T[3]}, np) + t := mul([4]uint64{m[0], m[1], m[2], m[3]}, p2) + + var carry uint64 + for i, Ti := range T { + ti := t[i] + zi := Ti + ti + carry + T[i] = zi + carry = (Ti&ti | (Ti|ti)&^zi) >> 63 + } + + *c = gfP{T[4], T[5], T[6], T[7]} + gfpCarry(c, carry) +} diff --git a/sm9/gfp_test.go b/sm9/gfp_test.go new file mode 100644 index 0000000..ef5583e --- /dev/null +++ b/sm9/gfp_test.go @@ -0,0 +1,54 @@ +package sm9 + +import ( + "encoding/hex" + "math/big" + "testing" +) + +func TestSqrt(t *testing.T) { + tests := []string{ + "9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596", + "92fe90b700fbd4d8cc177d300ed16e4e15471a681b2c9e3728c1b82c885e49c2", + } + for i, test := range tests { + y2 := bigFromHex(test) + y21 := new(big.Int).ModSqrt(y2, p) + + y3 := new(big.Int).Mul(y21, y21) + y3.Mod(y3, p) + if y2.Cmp(y3) != 0 { + t.Error("Invalid sqrt") + } + + tmp := fromBigInt(y2) + tmp.Sqrt(tmp) + montDecode(tmp, tmp) + var res [32]byte + tmp.Marshal(res[:]) + if hex.EncodeToString(res[:]) != hex.EncodeToString(y21.Bytes()) { + t.Errorf("case %v, got %v, expected %v\n", i, hex.EncodeToString(res[:]), hex.EncodeToString(y21.Bytes())) + } + } +} + +func TestInvert(t *testing.T) { + x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596")) + xInv := &gfP{} + xInv.Invert(x) + y := &gfP{} + gfpMul(y, x, xInv) + if *y != *one { + t.Errorf("got %v, expected %v", y, one) + } +} + +func TestDiv(t *testing.T) { + x := fromBigInt(bigFromHex("9093a2b979e6186f43a9b28d41ba644d533377f2ede8c66b19774bf4a9c7a596")) + ret := &gfP{} + ret.Div2(x) + gfpAdd(ret, ret, ret) + if *ret != *x { + t.Errorf("got %v, expected %v", ret, x) + } +} diff --git a/sm9/gt.go b/sm9/gt.go new file mode 100644 index 0000000..f51de78 --- /dev/null +++ b/sm9/gt.go @@ -0,0 +1,199 @@ +package sm9 + +import ( + "errors" + "io" + "math/big" +) + +// GT is an abstract cyclic group. The zero value is suitable for use as the +// output of an operation, but cannot be used as an input. +type GT struct { + p *gfP12 +} + +// RandomGT returns x and e(g₁, g₂)ˣ where x is a random, non-zero number read +// from r. +func RandomGT(r io.Reader) (*big.Int, *GT, error) { + k, err := randomK(r) + if err != nil { + return nil, nil, err + } + + return k, new(GT).ScalarBaseMult(k), nil +} + +// Pair calculates an R-Ate pairing. +func Pair(g1 *G1, g2 *G2) *GT { + return >{pairing(g2.p, g1.p)} +} + +// Miller applies Miller's algorithm, which is a bilinear function from the +// source groups to F_p^12. Miller(g1, g2).Finalize() is equivalent to Pair(g1, +// g2). +func Miller(g1 *G1, g2 *G2) *GT { + return >{miller(g2.p, g1.p)} +} + +func (g *GT) String() string { + return "sm9.GT" + g.p.String() +} + +// ScalarBaseMult sets e to g*k where g is the generator of the group and then +// returns out. +func (e *GT) ScalarBaseMult(k *big.Int) *GT { + if e.p == nil { + e.p = &gfP12{} + } + e.p.Exp(gfP12Gen, k) + return e +} + +// ScalarMult sets e to a*k and then returns e. +func (e *GT) ScalarMult(a *GT, k *big.Int) *GT { + if e.p == nil { + e.p = &gfP12{} + } + e.p.Exp(a.p, k) + return e +} + +// Add sets e to a+b and then returns e. +func (e *GT) Add(a, b *GT) *GT { + if e.p == nil { + e.p = &gfP12{} + } + e.p.Mul(a.p, b.p) + return e +} + +// Neg sets e to -a and then returns e. +func (e *GT) Neg(a *GT) *GT { + if e.p == nil { + e.p = &gfP12{} + } + e.p.Neg(a.p) // TODO: fix it. + return e +} + +// Set sets e to a and then returns e. +func (e *GT) Set(a *GT) *GT { + if e.p == nil { + e.p = &gfP12{} + } + e.p.Set(a.p) + return e +} + +// Finalize is a linear function from F_p^12 to GT. +func (e *GT) Finalize() *GT { + ret := finalExponentiation(e.p) + e.p.Set(ret) + return e +} + +// Marshal converts e into a byte slice. +func (e *GT) Marshal() []byte { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + + ret := make([]byte, numBytes*12) + temp := &gfP{} + + montDecode(temp, &e.p.x.x.x) + temp.Marshal(ret) + montDecode(temp, &e.p.x.x.y) + temp.Marshal(ret[numBytes:]) + montDecode(temp, &e.p.x.y.x) + temp.Marshal(ret[2*numBytes:]) + montDecode(temp, &e.p.x.y.y) + temp.Marshal(ret[3*numBytes:]) + + montDecode(temp, &e.p.y.x.x) + temp.Marshal(ret[4*numBytes:]) + montDecode(temp, &e.p.y.x.y) + temp.Marshal(ret[5*numBytes:]) + montDecode(temp, &e.p.y.y.x) + temp.Marshal(ret[6*numBytes:]) + montDecode(temp, &e.p.y.y.y) + temp.Marshal(ret[7*numBytes:]) + + montDecode(temp, &e.p.z.x.x) + temp.Marshal(ret[8*numBytes:]) + montDecode(temp, &e.p.z.x.y) + temp.Marshal(ret[9*numBytes:]) + montDecode(temp, &e.p.z.y.x) + temp.Marshal(ret[10*numBytes:]) + montDecode(temp, &e.p.z.y.y) + temp.Marshal(ret[11*numBytes:]) + + return ret +} + +// Unmarshal sets e to the result of converting the output of Marshal back into +// a group element and then returns e. +func (e *GT) Unmarshal(m []byte) ([]byte, error) { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + + if len(m) < 12*numBytes { + return nil, errors.New("sm9.GT: not enough data") + } + + if e.p == nil { + e.p = &gfP12{} + } + + var err error + if err = e.p.x.x.x.Unmarshal(m); err != nil { + return nil, err + } + if err = e.p.x.x.y.Unmarshal(m[numBytes:]); err != nil { + return nil, err + } + if err = e.p.x.y.x.Unmarshal(m[2*numBytes:]); err != nil { + return nil, err + } + if err = e.p.x.y.y.Unmarshal(m[3*numBytes:]); err != nil { + return nil, err + } + if err = e.p.y.x.x.Unmarshal(m[4*numBytes:]); err != nil { + return nil, err + } + if err = e.p.y.x.y.Unmarshal(m[5*numBytes:]); err != nil { + return nil, err + } + if err = e.p.y.y.x.Unmarshal(m[6*numBytes:]); err != nil { + return nil, err + } + if err = e.p.y.y.y.Unmarshal(m[7*numBytes:]); err != nil { + return nil, err + } + if err = e.p.z.x.x.Unmarshal(m[8*numBytes:]); err != nil { + return nil, err + } + if err = e.p.z.x.y.Unmarshal(m[9*numBytes:]); err != nil { + return nil, err + } + if err = e.p.z.y.x.Unmarshal(m[10*numBytes:]); err != nil { + return nil, err + } + if err = e.p.z.y.y.Unmarshal(m[11*numBytes:]); err != nil { + return nil, err + } + + montEncode(&e.p.x.x.x, &e.p.x.x.x) + montEncode(&e.p.x.x.y, &e.p.x.x.y) + montEncode(&e.p.x.y.x, &e.p.x.y.x) + montEncode(&e.p.x.y.y, &e.p.x.y.y) + montEncode(&e.p.y.x.x, &e.p.y.x.x) + montEncode(&e.p.y.x.y, &e.p.y.x.y) + montEncode(&e.p.y.y.x, &e.p.y.y.x) + montEncode(&e.p.y.y.y, &e.p.y.y.y) + montEncode(&e.p.z.x.x, &e.p.z.x.x) + montEncode(&e.p.z.x.y, &e.p.z.x.y) + montEncode(&e.p.z.y.x, &e.p.z.y.x) + montEncode(&e.p.z.y.y, &e.p.z.y.y) + + return m[12*numBytes:], nil +} diff --git a/sm9/gt_test.go b/sm9/gt_test.go new file mode 100644 index 0000000..bd44c80 --- /dev/null +++ b/sm9/gt_test.go @@ -0,0 +1,44 @@ +package sm9 + +import ( + "bytes" + "crypto/rand" + "testing" +) + +func TestGT(t *testing.T) { + k, Ga, err := RandomGT(rand.Reader) + if err != nil { + t.Fatal(err) + } + ma := Ga.Marshal() + + Gb := new(GT) + _, err = Gb.Unmarshal((>{gfP12Gen}).Marshal()) + if err != nil { + t.Fatal("unmarshal not ok") + } + Gb.ScalarMult(Gb, k) + mb := Gb.Marshal() + + if !bytes.Equal(ma, mb) { + t.Fatal("bytes are different") + } +} + +func BenchmarkGT(b *testing.B) { + x, _ := rand.Int(rand.Reader, Order) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + new(GT).ScalarBaseMult(x) + } +} + +func BenchmarkPairing(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + Pair(&G1{curveGen}, &G2{twistGen}) + } +} diff --git a/sm9/mul_amd64.h b/sm9/mul_amd64.h new file mode 100644 index 0000000..9d8e4b3 --- /dev/null +++ b/sm9/mul_amd64.h @@ -0,0 +1,181 @@ +#define mul(a0,a1,a2,a3, rb, stack) \ + MOVQ a0, AX \ + MULQ 0+rb \ + MOVQ AX, R8 \ + MOVQ DX, R9 \ + MOVQ a0, AX \ + MULQ 8+rb \ + ADDQ AX, R9 \ + ADCQ $0, DX \ + MOVQ DX, R10 \ + MOVQ a0, AX \ + MULQ 16+rb \ + ADDQ AX, R10 \ + ADCQ $0, DX \ + MOVQ DX, R11 \ + MOVQ a0, AX \ + MULQ 24+rb \ + ADDQ AX, R11 \ + ADCQ $0, DX \ + MOVQ DX, R12 \ + \ + storeBlock(R8,R9,R10,R11, 0+stack) \ + MOVQ R12, 32+stack \ + \ + MOVQ a1, AX \ + MULQ 0+rb \ + MOVQ AX, R8 \ + MOVQ DX, R9 \ + MOVQ a1, AX \ + MULQ 8+rb \ + ADDQ AX, R9 \ + ADCQ $0, DX \ + MOVQ DX, R10 \ + MOVQ a1, AX \ + MULQ 16+rb \ + ADDQ AX, R10 \ + ADCQ $0, DX \ + MOVQ DX, R11 \ + MOVQ a1, AX \ + MULQ 24+rb \ + ADDQ AX, R11 \ + ADCQ $0, DX \ + MOVQ DX, R12 \ + \ + ADDQ 8+stack, R8 \ + ADCQ 16+stack, R9 \ + ADCQ 24+stack, R10 \ + ADCQ 32+stack, R11 \ + ADCQ $0, R12 \ + storeBlock(R8,R9,R10,R11, 8+stack) \ + MOVQ R12, 40+stack \ + \ + MOVQ a2, AX \ + MULQ 0+rb \ + MOVQ AX, R8 \ + MOVQ DX, R9 \ + MOVQ a2, AX \ + MULQ 8+rb \ + ADDQ AX, R9 \ + ADCQ $0, DX \ + MOVQ DX, R10 \ + MOVQ a2, AX \ + MULQ 16+rb \ + ADDQ AX, R10 \ + ADCQ $0, DX \ + MOVQ DX, R11 \ + MOVQ a2, AX \ + MULQ 24+rb \ + ADDQ AX, R11 \ + ADCQ $0, DX \ + MOVQ DX, R12 \ + \ + ADDQ 16+stack, R8 \ + ADCQ 24+stack, R9 \ + ADCQ 32+stack, R10 \ + ADCQ 40+stack, R11 \ + ADCQ $0, R12 \ + storeBlock(R8,R9,R10,R11, 16+stack) \ + MOVQ R12, 48+stack \ + \ + MOVQ a3, AX \ + MULQ 0+rb \ + MOVQ AX, R8 \ + MOVQ DX, R9 \ + MOVQ a3, AX \ + MULQ 8+rb \ + ADDQ AX, R9 \ + ADCQ $0, DX \ + MOVQ DX, R10 \ + MOVQ a3, AX \ + MULQ 16+rb \ + ADDQ AX, R10 \ + ADCQ $0, DX \ + MOVQ DX, R11 \ + MOVQ a3, AX \ + MULQ 24+rb \ + ADDQ AX, R11 \ + ADCQ $0, DX \ + MOVQ DX, R12 \ + \ + ADDQ 24+stack, R8 \ + ADCQ 32+stack, R9 \ + ADCQ 40+stack, R10 \ + ADCQ 48+stack, R11 \ + ADCQ $0, R12 \ + storeBlock(R8,R9,R10,R11, 24+stack) \ + MOVQ R12, 56+stack + +#define gfpReduce(stack) \ + \ // m = (T * N') mod R, store m in R8:R9:R10:R11 + MOVQ ·np+0(SB), AX \ + MULQ 0+stack \ + MOVQ AX, R8 \ + MOVQ DX, R9 \ + MOVQ ·np+0(SB), AX \ + MULQ 8+stack \ + ADDQ AX, R9 \ + ADCQ $0, DX \ + MOVQ DX, R10 \ + MOVQ ·np+0(SB), AX \ + MULQ 16+stack \ + ADDQ AX, R10 \ + ADCQ $0, DX \ + MOVQ DX, R11 \ + MOVQ ·np+0(SB), AX \ + MULQ 24+stack \ + ADDQ AX, R11 \ + \ + MOVQ ·np+8(SB), AX \ + MULQ 0+stack \ + MOVQ AX, R12 \ + MOVQ DX, R13 \ + MOVQ ·np+8(SB), AX \ + MULQ 8+stack \ + ADDQ AX, R13 \ + ADCQ $0, DX \ + MOVQ DX, R14 \ + MOVQ ·np+8(SB), AX \ + MULQ 16+stack \ + ADDQ AX, R14 \ + \ + ADDQ R12, R9 \ + ADCQ R13, R10 \ + ADCQ R14, R11 \ + \ + MOVQ ·np+16(SB), AX \ + MULQ 0+stack \ + MOVQ AX, R12 \ + MOVQ DX, R13 \ + MOVQ ·np+16(SB), AX \ + MULQ 8+stack \ + ADDQ AX, R13 \ + \ + ADDQ R12, R10 \ + ADCQ R13, R11 \ + \ + MOVQ ·np+24(SB), AX \ + MULQ 0+stack \ + ADDQ AX, R11 \ + \ + storeBlock(R8,R9,R10,R11, 64+stack) \ + \ + \ // m * N + mul(·p2+0(SB),·p2+8(SB),·p2+16(SB),·p2+24(SB), 64+stack, 96+stack) \ + \ + \ // Add the 512-bit intermediate to m*N + loadBlock(96+stack, R8,R9,R10,R11) \ + loadBlock(128+stack, R12,R13,R14,CX) \ + \ + MOVQ $0, AX \ + ADDQ 0+stack, R8 \ + ADCQ 8+stack, R9 \ + ADCQ 16+stack, R10 \ + ADCQ 24+stack, R11 \ + ADCQ 32+stack, R12 \ + ADCQ 40+stack, R13 \ + ADCQ 48+stack, R14 \ + ADCQ 56+stack, CX \ + ADCQ $0, AX \ + \ + gfpCarry(R12,R13,R14,CX,AX, R8,R9,R10,R11,BX) diff --git a/sm9/mul_arm64.h b/sm9/mul_arm64.h new file mode 100644 index 0000000..d405eb8 --- /dev/null +++ b/sm9/mul_arm64.h @@ -0,0 +1,133 @@ +#define mul(c0,c1,c2,c3,c4,c5,c6,c7) \ + MUL R1, R5, c0 \ + UMULH R1, R5, c1 \ + MUL R1, R6, R0 \ + ADDS R0, c1 \ + UMULH R1, R6, c2 \ + MUL R1, R7, R0 \ + ADCS R0, c2 \ + UMULH R1, R7, c3 \ + MUL R1, R8, R0 \ + ADCS R0, c3 \ + UMULH R1, R8, c4 \ + ADCS ZR, c4 \ + \ + MUL R2, R5, R1 \ + UMULH R2, R5, R26 \ + MUL R2, R6, R0 \ + ADDS R0, R26 \ + UMULH R2, R6, R27 \ + MUL R2, R7, R0 \ + ADCS R0, R27 \ + UMULH R2, R7, R29 \ + MUL R2, R8, R0 \ + ADCS R0, R29 \ + UMULH R2, R8, c5 \ + ADCS ZR, c5 \ + ADDS R1, c1 \ + ADCS R26, c2 \ + ADCS R27, c3 \ + ADCS R29, c4 \ + ADCS ZR, c5 \ + \ + MUL R3, R5, R1 \ + UMULH R3, R5, R26 \ + MUL R3, R6, R0 \ + ADDS R0, R26 \ + UMULH R3, R6, R27 \ + MUL R3, R7, R0 \ + ADCS R0, R27 \ + UMULH R3, R7, R29 \ + MUL R3, R8, R0 \ + ADCS R0, R29 \ + UMULH R3, R8, c6 \ + ADCS ZR, c6 \ + ADDS R1, c2 \ + ADCS R26, c3 \ + ADCS R27, c4 \ + ADCS R29, c5 \ + ADCS ZR, c6 \ + \ + MUL R4, R5, R1 \ + UMULH R4, R5, R26 \ + MUL R4, R6, R0 \ + ADDS R0, R26 \ + UMULH R4, R6, R27 \ + MUL R4, R7, R0 \ + ADCS R0, R27 \ + UMULH R4, R7, R29 \ + MUL R4, R8, R0 \ + ADCS R0, R29 \ + UMULH R4, R8, c7 \ + ADCS ZR, c7 \ + ADDS R1, c3 \ + ADCS R26, c4 \ + ADCS R27, c5 \ + ADCS R29, c6 \ + ADCS ZR, c7 + +#define gfpReduce() \ + \ // m = (T * N') mod R, store m in R1:R2:R3:R4 + MOVD ·np+0(SB), R17 \ + MOVD ·np+8(SB), R25 \ + MOVD ·np+16(SB), R19 \ + MOVD ·np+24(SB), R20 \ + \ + MUL R9, R17, R1 \ + UMULH R9, R17, R2 \ + MUL R9, R25, R0 \ + ADDS R0, R2 \ + UMULH R9, R25, R3 \ + MUL R9, R19, R0 \ + ADCS R0, R3 \ + UMULH R9, R19, R4 \ + MUL R9, R20, R0 \ + ADCS R0, R4 \ + \ + MUL R10, R17, R21 \ + UMULH R10, R17, R22 \ + MUL R10, R25, R0 \ + ADDS R0, R22 \ + UMULH R10, R25, R23 \ + MUL R10, R19, R0 \ + ADCS R0, R23 \ + ADDS R21, R2 \ + ADCS R22, R3 \ + ADCS R23, R4 \ + \ + MUL R11, R17, R21 \ + UMULH R11, R17, R22 \ + MUL R11, R25, R0 \ + ADDS R0, R22 \ + ADDS R21, R3 \ + ADCS R22, R4 \ + \ + MUL R12, R17, R21 \ + ADDS R21, R4 \ + \ + \ // m * N + loadModulus(R5,R6,R7,R8) \ + mul(R17,R25,R19,R20,R21,R22,R23,R24) \ + \ + \ // Add the 512-bit intermediate to m*N + MOVD ZR, R0 \ + ADDS R9, R17 \ + ADCS R10, R25 \ + ADCS R11, R19 \ + ADCS R12, R20 \ + ADCS R13, R21 \ + ADCS R14, R22 \ + ADCS R15, R23 \ + ADCS R16, R24 \ + ADCS ZR, R0 \ + \ + \ // Our output is R21:R22:R23:R24. Reduce mod p if necessary. + SUBS R5, R21, R10 \ + SBCS R6, R22, R11 \ + SBCS R7, R23, R12 \ + SBCS R8, R24, R13 \ + \ + CSEL CS, R10, R21, R1 \ + CSEL CS, R11, R22, R2 \ + CSEL CS, R12, R23, R3 \ + CSEL CS, R13, R24, R4 diff --git a/sm9/mul_bmi2_amd64.h b/sm9/mul_bmi2_amd64.h new file mode 100644 index 0000000..403566c --- /dev/null +++ b/sm9/mul_bmi2_amd64.h @@ -0,0 +1,112 @@ +#define mulBMI2(a0,a1,a2,a3, rb) \ + MOVQ a0, DX \ + MOVQ $0, R13 \ + MULXQ 0+rb, R8, R9 \ + MULXQ 8+rb, AX, R10 \ + ADDQ AX, R9 \ + MULXQ 16+rb, AX, R11 \ + ADCQ AX, R10 \ + MULXQ 24+rb, AX, R12 \ + ADCQ AX, R11 \ + ADCQ $0, R12 \ + ADCQ $0, R13 \ + \ + MOVQ a1, DX \ + MOVQ $0, R14 \ + MULXQ 0+rb, AX, BX \ + ADDQ AX, R9 \ + ADCQ BX, R10 \ + MULXQ 16+rb, AX, BX \ + ADCQ AX, R11 \ + ADCQ BX, R12 \ + ADCQ $0, R13 \ + MULXQ 8+rb, AX, BX \ + ADDQ AX, R10 \ + ADCQ BX, R11 \ + MULXQ 24+rb, AX, BX \ + ADCQ AX, R12 \ + ADCQ BX, R13 \ + ADCQ $0, R14 \ + \ + MOVQ a2, DX \ + MOVQ $0, CX \ + MULXQ 0+rb, AX, BX \ + ADDQ AX, R10 \ + ADCQ BX, R11 \ + MULXQ 16+rb, AX, BX \ + ADCQ AX, R12 \ + ADCQ BX, R13 \ + ADCQ $0, R14 \ + MULXQ 8+rb, AX, BX \ + ADDQ AX, R11 \ + ADCQ BX, R12 \ + MULXQ 24+rb, AX, BX \ + ADCQ AX, R13 \ + ADCQ BX, R14 \ + ADCQ $0, CX \ + \ + MOVQ a3, DX \ + MULXQ 0+rb, AX, BX \ + ADDQ AX, R11 \ + ADCQ BX, R12 \ + MULXQ 16+rb, AX, BX \ + ADCQ AX, R13 \ + ADCQ BX, R14 \ + ADCQ $0, CX \ + MULXQ 8+rb, AX, BX \ + ADDQ AX, R12 \ + ADCQ BX, R13 \ + MULXQ 24+rb, AX, BX \ + ADCQ AX, R14 \ + ADCQ BX, CX + +#define gfpReduceBMI2() \ + \ // m = (T * N') mod R, store m in R8:R9:R10:R11 + MOVQ ·np+0(SB), DX \ + MULXQ 0(SP), R8, R9 \ + MULXQ 8(SP), AX, R10 \ + ADDQ AX, R9 \ + MULXQ 16(SP), AX, R11 \ + ADCQ AX, R10 \ + MULXQ 24(SP), AX, BX \ + ADCQ AX, R11 \ + \ + MOVQ ·np+8(SB), DX \ + MULXQ 0(SP), AX, BX \ + ADDQ AX, R9 \ + ADCQ BX, R10 \ + MULXQ 16(SP), AX, BX \ + ADCQ AX, R11 \ + MULXQ 8(SP), AX, BX \ + ADDQ AX, R10 \ + ADCQ BX, R11 \ + \ + MOVQ ·np+16(SB), DX \ + MULXQ 0(SP), AX, BX \ + ADDQ AX, R10 \ + ADCQ BX, R11 \ + MULXQ 8(SP), AX, BX \ + ADDQ AX, R11 \ + \ + MOVQ ·np+24(SB), DX \ + MULXQ 0(SP), AX, BX \ + ADDQ AX, R11 \ + \ + storeBlock(R8,R9,R10,R11, 64(SP)) \ + \ + \ // m * N + mulBMI2(·p2+0(SB),·p2+8(SB),·p2+16(SB),·p2+24(SB), 64(SP)) \ + \ + \ // Add the 512-bit intermediate to m*N + MOVQ $0, AX \ + ADDQ 0(SP), R8 \ + ADCQ 8(SP), R9 \ + ADCQ 16(SP), R10 \ + ADCQ 24(SP), R11 \ + ADCQ 32(SP), R12 \ + ADCQ 40(SP), R13 \ + ADCQ 48(SP), R14 \ + ADCQ 56(SP), CX \ + ADCQ $0, AX \ + \ + gfpCarry(R12,R13,R14,CX,AX, R8,R9,R10,R11,BX) diff --git a/sm9/params.go b/sm9/params.go new file mode 100644 index 0000000..91c2e4a --- /dev/null +++ b/sm9/params.go @@ -0,0 +1,238 @@ +package sm9 + +import "math/big" + +// CurveParams contains the parameters of an elliptic curve and also provides +// a generic, non-constant time implementation of Curve. +type CurveParams struct { + P *big.Int // the order of the underlying field + N *big.Int // the order of the base point + B *big.Int // the constant of the curve equation + Gx, Gy *big.Int // (x,y) of the base point + BitSize int // the size of the underlying field + Name string // the canonical name of the curve +} + +func (curve *CurveParams) Params() *CurveParams { + return curve +} + +// CurveParams operates, internally, on Jacobian coordinates. For a given +// (x, y) position on the curve, the Jacobian coordinates are (x1, y1, z1) +// where x = x1/z1² and y = y1/z1³. The greatest speedups come when the whole +// calculation can be performed within the transform (as in ScalarMult and +// ScalarBaseMult). But even for Add and Double, it's faster to apply and +// reverse the transform than to operate in affine coordinates. + +// polynomial returns x³ + b. +func (curve *CurveParams) polynomial(x *big.Int) *big.Int { + x3 := new(big.Int).Mul(x, x) + x3.Mul(x3, x) + + x3.Add(x3, curve.B) + x3.Mod(x3, curve.P) + + return x3 +} + +func (curve *CurveParams) IsOnCurve(x, y *big.Int) bool { + if x.Sign() < 0 || x.Cmp(curve.P) >= 0 || + y.Sign() < 0 || y.Cmp(curve.P) >= 0 { + return false + } + + // y² = x³ + b + y2 := new(big.Int).Mul(y, y) + y2.Mod(y2, curve.P) + + return curve.polynomial(x).Cmp(y2) == 0 +} + +// zForAffine returns a Jacobian Z value for the affine point (x, y). If x and +// y are zero, it assumes that they represent the point at infinity because (0, +// 0) is not on the any of the curves handled here. +func zForAffine(x, y *big.Int) *big.Int { + z := new(big.Int) + if x.Sign() != 0 || y.Sign() != 0 { + z.SetInt64(1) + } + return z +} + +// affineFromJacobian reverses the Jacobian transform. See the comment at the +// top of the file. If the point is ∞ it returns 0, 0. +func (curve *CurveParams) affineFromJacobian(x, y, z *big.Int) (xOut, yOut *big.Int) { + if z.Sign() == 0 { + return new(big.Int), new(big.Int) + } + + zinv := new(big.Int).ModInverse(z, curve.P) + zinvsq := new(big.Int).Mul(zinv, zinv) + + xOut = new(big.Int).Mul(x, zinvsq) + xOut.Mod(xOut, curve.P) + zinvsq.Mul(zinvsq, zinv) + yOut = new(big.Int).Mul(y, zinvsq) + yOut.Mod(yOut, curve.P) + return +} + +func (curve *CurveParams) Add(x1, y1, x2, y2 *big.Int) (*big.Int, *big.Int) { + z1 := zForAffine(x1, y1) + z2 := zForAffine(x2, y2) + return curve.affineFromJacobian(curve.addJacobian(x1, y1, z1, x2, y2, z2)) +} + +// addJacobian takes two points in Jacobian coordinates, (x1, y1, z1) and +// (x2, y2, z2) and returns their sum, also in Jacobian form. +func (curve *CurveParams) addJacobian(x1, y1, z1, x2, y2, z2 *big.Int) (*big.Int, *big.Int, *big.Int) { + // See https://hyperelliptic.org/EFD/g1p/auto-shortw-jacobian-0.html#addition-add-2007-bl + x3, y3, z3 := new(big.Int), new(big.Int), new(big.Int) + if z1.Sign() == 0 { + x3.Set(x2) + y3.Set(y2) + z3.Set(z2) + return x3, y3, z3 + } + if z2.Sign() == 0 { + x3.Set(x1) + y3.Set(y1) + z3.Set(z1) + return x3, y3, z3 + } + + z1z1 := new(big.Int).Mul(z1, z1) + z1z1.Mod(z1z1, curve.P) + z2z2 := new(big.Int).Mul(z2, z2) + z2z2.Mod(z2z2, curve.P) + + u1 := new(big.Int).Mul(x1, z2z2) + u1.Mod(u1, curve.P) + u2 := new(big.Int).Mul(x2, z1z1) + u2.Mod(u2, curve.P) + h := new(big.Int).Sub(u2, u1) + xEqual := h.Sign() == 0 + if h.Sign() == -1 { + h.Add(h, curve.P) + } + i := new(big.Int).Lsh(h, 1) + i.Mul(i, i) + j := new(big.Int).Mul(h, i) + + s1 := new(big.Int).Mul(y1, z2) + s1.Mul(s1, z2z2) + s1.Mod(s1, curve.P) + s2 := new(big.Int).Mul(y2, z1) + s2.Mul(s2, z1z1) + s2.Mod(s2, curve.P) + r := new(big.Int).Sub(s2, s1) + if r.Sign() == -1 { + r.Add(r, curve.P) + } + yEqual := r.Sign() == 0 + if xEqual && yEqual { + return curve.doubleJacobian(x1, y1, z1) + } + r.Lsh(r, 1) + v := new(big.Int).Mul(u1, i) + + x3.Set(r) + x3.Mul(x3, x3) + x3.Sub(x3, j) + x3.Sub(x3, v) + x3.Sub(x3, v) + x3.Mod(x3, curve.P) + + y3.Set(r) + v.Sub(v, x3) + y3.Mul(y3, v) + s1.Mul(s1, j) + s1.Lsh(s1, 1) + y3.Sub(y3, s1) + y3.Mod(y3, curve.P) + + z3.Add(z1, z2) + z3.Mul(z3, z3) + z3.Sub(z3, z1z1) + z3.Sub(z3, z2z2) + z3.Mul(z3, h) + z3.Mod(z3, curve.P) + + return x3, y3, z3 +} + +func (curve *CurveParams) Double(x1, y1 *big.Int) (*big.Int, *big.Int) { + z1 := zForAffine(x1, y1) + return curve.affineFromJacobian(curve.doubleJacobian(x1, y1, z1)) +} + +// doubleJacobian takes a point in Jacobian coordinates, (x, y, z), and +// returns its double, also in Jacobian form. +func (curve *CurveParams) doubleJacobian(x, y, z *big.Int) (*big.Int, *big.Int, *big.Int) { + // See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/doubling/dbl-2009-l.op3 + a := new(big.Int).Mul(x, x) + a.Mod(a, curve.P) + b := new(big.Int).Mul(y, y) + b.Mod(b, curve.P) + c := new(big.Int).Mul(b, b) + c.Mod(c, curve.P) + + d := new(big.Int).Add(x, b) + d.Mul(d, d) + d.Sub(d, a) + d.Sub(d, c) + d.Lsh(d, 1) + if d.Sign() < 0 { + d.Add(d, curve.P) + } else { + d.Mod(d, curve.P) + } + + e := new(big.Int).Lsh(a, 1) + e.Add(e, a) + f := new(big.Int).Mul(e, e) + x3 := new(big.Int).Lsh(d, 1) + x3.Sub(f, x3) + if x3.Sign() < 0 { + x3.Add(x3, curve.P) + } else { + x3.Mod(x3, curve.P) + } + + y3 := new(big.Int).Sub(d, x3) + y3.Mul(y3, e) + c.Lsh(c, 3) + y3.Sub(y3, c) + if y3.Sign() < 0 { + y3.Add(y3, curve.P) + } else { + y3.Mod(y3, curve.P) + } + + z3 := new(big.Int).Mul(y, z) + z3.Lsh(z3, 1) + z3.Mod(z3, curve.P) + + return x3, y3, z3 +} + +func (curve *CurveParams) ScalarMult(Bx, By *big.Int, k []byte) (*big.Int, *big.Int) { + Bz := new(big.Int).SetInt64(1) + x, y, z := new(big.Int), new(big.Int), new(big.Int) + + for _, byte := range k { + for bitNum := 0; bitNum < 8; bitNum++ { + x, y, z = curve.doubleJacobian(x, y, z) + if byte&0x80 == 0x80 { + x, y, z = curve.addJacobian(Bx, By, Bz, x, y, z) + } + byte <<= 1 + } + } + + return curve.affineFromJacobian(x, y, z) +} + +func (curve *CurveParams) ScalarBaseMult(k []byte) (*big.Int, *big.Int) { + return curve.ScalarMult(curve.Gx, curve.Gy, k) +} diff --git a/sm9/params_test.go b/sm9/params_test.go new file mode 100644 index 0000000..fcab8f3 --- /dev/null +++ b/sm9/params_test.go @@ -0,0 +1,170 @@ +package sm9 + +import ( + "encoding/hex" + "fmt" + "math/big" + "testing" +) + +var secp256k1Params = &CurveParams{ + Name: "secp256k1", + BitSize: 256, + P: bigFromHex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f"), + N: bigFromHex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141"), + B: bigFromHex("0000000000000000000000000000000000000000000000000000000000000007"), + Gx: bigFromHex("79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"), + Gy: bigFromHex("483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8"), +} + +var sm9CurveParams = &CurveParams{ + Name: "sm9", + BitSize: 256, + P: bigFromHex("B640000002A3A6F1D603AB4FF58EC74521F2934B1A7AEEDBE56F9B27E351457D"), + N: bigFromHex("B640000002A3A6F1D603AB4FF58EC74449F2934B18EA8BEEE56EE19CD69ECF25"), + B: bigFromHex("0000000000000000000000000000000000000000000000000000000000000005"), + Gx: bigFromHex("93DE051D62BF718FF5ED0704487D01D6E1E4086909DC3280E8C4E4817C66DDDD"), + Gy: bigFromHex("21FE8DDA4F21E607631065125C395BBC1C1C00CBFA6024350C464CD70A3EA616"), +} + +type baseMultTest struct { + k string + x, y string +} + +var s256BaseMultTests = []baseMultTest{ + { + "AA5E28D6A97A2479A65527F7290311A3624D4CC0FA1578598EE3C2613BF99522", + "34F9460F0E4F08393D192B3C5133A6BA099AA0AD9FD54EBCCFACDFA239FF49C6", + "B71EA9BD730FD8923F6D25A7A91E7DD7728A960686CB5A901BB419E0F2CA232", + }, + { + "7E2B897B8CEBC6361663AD410835639826D590F393D90A9538881735256DFAE3", + "D74BF844B0862475103D96A611CF2D898447E288D34B360BC885CB8CE7C00575", + "131C670D414C4546B88AC3FF664611B1C38CEB1C21D76369D7A7A0969D61D97D", + }, + { + "6461E6DF0FE7DFD05329F41BF771B86578143D4DD1F7866FB4CA7E97C5FA945D", + "E8AECC370AEDD953483719A116711963CE201AC3EB21D3F3257BB48668C6A72F", + "C25CAF2F0EBA1DDB2F0F3F47866299EF907867B7D27E95B3873BF98397B24EE1", + }, + { + "376A3A2CDCD12581EFFF13EE4AD44C4044B8A0524C42422A7E1E181E4DEECCEC", + "14890E61FCD4B0BD92E5B36C81372CA6FED471EF3AA60A3E415EE4FE987DABA1", + "297B858D9F752AB42D3BCA67EE0EB6DCD1C2B7B0DBE23397E66ADC272263F982", + }, + { + "1B22644A7BE026548810C378D0B2994EEFA6D2B9881803CB02CEFF865287D1B9", + "F73C65EAD01C5126F28F442D087689BFA08E12763E0CEC1D35B01751FD735ED3", + "F449A8376906482A84ED01479BD18882B919C140D638307F0C0934BA12590BDE", + }, +} + +func TestBaseMult(t *testing.T) { + for i, e := range s256BaseMultTests { + k, ok := new(big.Int).SetString(e.k, 16) + if !ok { + t.Errorf("%d: bad value for k: %s", i, e.k) + } + x, y := secp256k1Params.ScalarBaseMult(k.Bytes()) + if fmt.Sprintf("%X", x) != e.x || fmt.Sprintf("%X", y) != e.y { + t.Errorf("%d: bad output for k=%s: got (%X, %X), want (%s, %s)", i, e.k, x, y, e.x, e.y) + } + } +} + +func TestOnCurve(t *testing.T) { + if !secp256k1Params.IsOnCurve(secp256k1Params.Gx, secp256k1Params.Gy) { + t.Errorf("point is not on curve") + } + if !sm9CurveParams.IsOnCurve(sm9CurveParams.Gx, sm9CurveParams.Gy) { + t.Errorf("point is not on curve") + } +} + +func TestPMode4And8(t *testing.T) { + res := new(big.Int).Mod(sm9CurveParams.P, big.NewInt(4)) + if res.Int64() != 1 { + t.Errorf("p mod 4 != 1") + } + res = new(big.Int).Mod(sm9CurveParams.P, big.NewInt(6)) + if res.Int64() != 1 { + t.Errorf("p mod 6 != 1") + } + res = new(big.Int).Mod(sm9CurveParams.P, big.NewInt(8)) + if res.Int64() != 5 { + t.Errorf("p mod 8 != 5") + } + res = new(big.Int).Sub(sm9CurveParams.P, big.NewInt(1)) + res.Div(res, big.NewInt(2)) + if hex.EncodeToString(res.Bytes()) != "5b2000000151d378eb01d5a7fac763a290f949a58d3d776df2b7cd93f1a8a2be" { + t.Errorf("expected %v, got %v\n", "5b2000000151d378eb01d5a7fac763a290f949a58d3d776df2b7cd93f1a8a2be", hex.EncodeToString(res.Bytes())) + } + + res = new(big.Int).Add(sm9CurveParams.P, big.NewInt(1)) + res.Div(res, big.NewInt(2)) + if hex.EncodeToString(res.Bytes()) != "5b2000000151d378eb01d5a7fac763a290f949a58d3d776df2b7cd93f1a8a2bf" { + t.Errorf("expected %v, got %v\n", "5b2000000151d378eb01d5a7fac763a290f949a58d3d776df2b7cd93f1a8a2bf", hex.EncodeToString(res.Bytes())) + } + + res = new(big.Int).Add(sm9CurveParams.P, big.NewInt(1)) + res.Div(res, big.NewInt(3)) + if hex.EncodeToString(res.Bytes()) != "3cc0000000e137a5f201391aa72f97c1b5fb866e5e28fa494c7a890d4bc5c1d4" { + t.Errorf("expected %v, got %v\n", "3cc0000000e137a5f201391aa72f97c1b5fb866e5e28fa494c7a890d4bc5c1d4", hex.EncodeToString(res.Bytes())) + } + + res = new(big.Int).Sub(sm9CurveParams.P, big.NewInt(1)) + res.Div(res, big.NewInt(4)) + if hex.EncodeToString(res.Bytes()) != "2d90000000a8e9bc7580ead3fd63b1d1487ca4d2c69ebbb6f95be6c9f8d4515f" { + t.Errorf("expected %v, got %v\n", "2d90000000a8e9bc7580ead3fd63b1d1487ca4d2c69ebbb6f95be6c9f8d4515f", hex.EncodeToString(res.Bytes())) + } + + res = new(big.Int).Sub(sm9CurveParams.P, big.NewInt(1)) + res.Div(res, big.NewInt(6)) + if hex.EncodeToString(res.Bytes()) != "1e60000000709bd2f9009c8d5397cbe0dafdc3372f147d24a63d4486a5e2e0ea" { + t.Errorf("expected %v, got %v\n", "1e60000000709bd2f9009c8d5397cbe0dafdc3372f147d24a63d4486a5e2e0ea", hex.EncodeToString(res.Bytes())) + } + + res = new(big.Int).Sub(sm9CurveParams.P, big.NewInt(1)) + res.Div(res, big.NewInt(3)) + if hex.EncodeToString(res.Bytes()) != "3cc0000000e137a5f201391aa72f97c1b5fb866e5e28fa494c7a890d4bc5c1d4" { + t.Errorf("expected %v, got %v\n", "3cc0000000e137a5f201391aa72f97c1b5fb866e5e28fa494c7a890d4bc5c1d4", hex.EncodeToString(res.Bytes())) + } + + res = new(big.Int).Mul(sm9CurveParams.P, sm9CurveParams.P) + res.Sub(res, big.NewInt(1)) + res.Div(res, big.NewInt(3)) + if hex.EncodeToString(res.Bytes()) != "2b3fb0000140abbbc71510370c6fa2b194d4665ff95c18014568b07bbd19fb54f0b9aded6fea5b670c35d6b4e3b966415456a4a8503c6361c90d41b4e8a78a58" { + t.Errorf("expected %v, got %v\n", "2b3fb0000140abbbc71510370c6fa2b194d4665ff95c18014568b07bbd19fb54f0b9aded6fea5b670c35d6b4e3b966415456a4a8503c6361c90d41b4e8a78a58", hex.EncodeToString(res.Bytes())) + } + + res = new(big.Int).Mul(sm9CurveParams.P, sm9CurveParams.P) + res.Sub(res, big.NewInt(1)) + res.Div(res, big.NewInt(2)) + if hex.EncodeToString(res.Bytes()) != "40df880001e10199aa9f985292a7740a5f3e998ff60a2401e81d08b99ba6f8ff691684e427df891a9250c20f55961961fe81f6fc785a9512ad93e28f5cfb4f84" { + t.Errorf("expected %v, got %v\n", "40df880001e10199aa9f985292a7740a5f3e998ff60a2401e81d08b99ba6f8ff691684e427df891a9250c20f55961961fe81f6fc785a9512ad93e28f5cfb4f84", hex.EncodeToString(res.Bytes())) + } + + res = new(big.Int).Sub(sm9CurveParams.P, big.NewInt(5)) + res.Div(res, big.NewInt(8)) + if hex.EncodeToString(res.Bytes()) != "16c80000005474de3ac07569feb1d8e8a43e5269634f5ddb7cadf364fc6a28af" { + t.Errorf("expected %v, got %v\n", "16c80000005474de3ac07569feb1d8e8a43e5269634f5ddb7cadf364fc6a28af", hex.EncodeToString(res.Bytes())) + } + + res.Exp(big.NewInt(2), res, sm9CurveParams.P) + if hex.EncodeToString(res.Bytes()) != "800db90d149e875b5b564505fe88efba5223f2bf170cc61fea968b3df63edd75" { + t.Errorf("expected %v, got %v\n", "800db90d149e875b5b564505fe88efba5223f2bf170cc61fea968b3df63edd75", hex.EncodeToString(res.Bytes())) + } + + res.Mul(u, big.NewInt(6)) + res.Add(res, big.NewInt(5)) + if hex.EncodeToString(res.Bytes()) != "02400000000215d941" { + t.Errorf("expected %v, got %v\n", "02400000000215d941", hex.EncodeToString(res.Bytes())) + } + res.Mul(u, big.NewInt(6)) + res.Mul(res, u) + res.Add(res, big.NewInt(1)) + if hex.EncodeToString(res.Bytes()) != "d8000000019062ed0000b98b0cb27659" { + t.Errorf("expected %v, got %v\n", "d8000000019062ed0000b98b0cb27659", hex.EncodeToString(res.Bytes())) + } +} diff --git a/sm9/sm9.go b/sm9/sm9.go new file mode 100644 index 0000000..a2cd67f --- /dev/null +++ b/sm9/sm9.go @@ -0,0 +1,4 @@ +// Package sm9 handle shangmi sm9 algorithm and its curves and pairing implementation +package sm9 + +// TODO: implement SM9 algorithm based on basic curves, G1/G2/GT and r-ate pairing implementation. \ No newline at end of file diff --git a/sm9/twist.go b/sm9/twist.go new file mode 100644 index 0000000..611ceec --- /dev/null +++ b/sm9/twist.go @@ -0,0 +1,280 @@ +package sm9 + +import "math/big" + +// twistPoint implements the elliptic curve y²=x³+5/ξ (y²=x³+5i) over GF(p²). Points are +// kept in Jacobian form and t=z² when valid. The group G₂ is the set of +// n-torsion points of this curve over GF(p²) (where n = Order) +type twistPoint struct { + x, y, z, t gfP2 +} + +var twistB = &gfP2{ + *newGFp(5), + *zero, +} + +// twistGen is the generator of group G₂. +var twistGen = &twistPoint{ + gfP2{ + *fromBigInt(bigFromHex("85AEF3D078640C98597B6027B441A01FF1DD2C190F5E93C454806C11D8806141")), + *fromBigInt(bigFromHex("3722755292130B08D2AAB97FD34EC120EE265948D19C17ABF9B7213BAF82D65B")), + }, + gfP2{ + *fromBigInt(bigFromHex("17509B092E845C1266BA0D262CBEE6ED0736A96FA347C8BD856DC76B84EBEB96")), + *fromBigInt(bigFromHex("A7CF28D519BE3DA65F3170153D278FF247EFBA98A71A08116215BBA5C999A7C7")), + }, + gfP2{*newGFp(0), *newGFp(1)}, + gfP2{*newGFp(0), *newGFp(1)}, +} + +func (c *twistPoint) String() string { + c.MakeAffine() + x, y := gfP2Decode(&c.x), gfP2Decode(&c.y) + return "(" + x.String() + ", " + y.String() + ")" +} + +func (c *twistPoint) Set(a *twistPoint) { + c.x.Set(&a.x) + c.y.Set(&a.y) + c.z.Set(&a.z) + c.t.Set(&a.t) +} + +// IsOnCurve returns true iff c is on the curve. +func (c *twistPoint) IsOnCurve() bool { + c.MakeAffine() + if c.IsInfinity() { + return true + } + + y2, x3 := &gfP2{}, &gfP2{} + y2.Square(&c.y) + x3.Square(&c.x).Mul(x3, &c.x).Add(x3, twistB) + + return *y2 == *x3 +} + +func (c *twistPoint) SetInfinity() { + c.x.SetZero() + c.y.SetOne() + c.z.SetZero() + c.t.SetZero() +} + +func (c *twistPoint) IsInfinity() bool { + return c.z.IsZero() +} + +func (c *twistPoint) Add(a, b *twistPoint) { + // For additional comments, see the same function in curve.go. + + if a.IsInfinity() { + c.Set(b) + return + } + if b.IsInfinity() { + c.Set(a) + return + } + + // See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/addition/add-2007-bl.op3 + z12 := (&gfP2{}).Square(&a.z) + z22 := (&gfP2{}).Square(&b.z) + u1 := (&gfP2{}).Mul(&a.x, z22) + u2 := (&gfP2{}).Mul(&b.x, z12) + + t := (&gfP2{}).Mul(&b.z, z22) + s1 := (&gfP2{}).Mul(&a.y, t) + + t.Mul(&a.z, z12) + s2 := (&gfP2{}).Mul(&b.y, t) + + h := (&gfP2{}).Sub(u2, u1) + xEqual := h.IsZero() + + t.Add(h, h) + i := (&gfP2{}).Square(t) + j := (&gfP2{}).Mul(h, i) + + t.Sub(s2, s1) + yEqual := t.IsZero() + if xEqual && yEqual { + c.Double(a) + return + } + r := (&gfP2{}).Add(t, t) + + v := (&gfP2{}).Mul(u1, i) + + t4 := (&gfP2{}).Square(r) + t.Add(v, v) + t6 := (&gfP2{}).Sub(t4, j) + c.x.Sub(t6, t) + + t.Sub(v, &c.x) // t7 + t4.Mul(s1, j) // t8 + t6.Add(t4, t4) // t9 + t4.Mul(r, t) // t10 + c.y.Sub(t4, t6) + + t.Add(&a.z, &b.z) // t11 + t4.Square(t) // t12 + t.Sub(t4, z12) // t13 + t4.Sub(t, z22) // t14 + c.z.Mul(t4, h) +} + +func (c *twistPoint) Double(a *twistPoint) { + // See http://hyperelliptic.org/EFD/g1p/auto-code/shortw/jacobian-0/doubling/dbl-2009-l.op3 + A := (&gfP2{}).Square(&a.x) + B := (&gfP2{}).Square(&a.y) + C := (&gfP2{}).Square(B) + + t := (&gfP2{}).Add(&a.x, B) + t2 := (&gfP2{}).Square(t) + t.Sub(t2, A) + t2.Sub(t, C) + d := (&gfP2{}).Add(t2, t2) + t.Add(A, A) + e := (&gfP2{}).Add(t, A) + f := (&gfP2{}).Square(e) + + t.Add(d, d) + c.x.Sub(f, t) + + c.z.Mul(&a.y, &a.z) + c.z.Add(&c.z, &c.z) + + t.Add(C, C) + t2.Add(t, t) + t.Add(t2, t2) + c.y.Sub(d, &c.x) + t2.Mul(e, &c.y) + c.y.Sub(t2, t) +} + +func (c *twistPoint) Mul(a *twistPoint, scalar *big.Int) { + sum, t := &twistPoint{}, &twistPoint{} + + for i := scalar.BitLen(); i >= 0; i-- { + t.Double(sum) + if scalar.Bit(i) != 0 { + sum.Add(t, a) + } else { + sum.Set(t) + } + } + + c.Set(sum) +} + +func (c *twistPoint) MakeAffine() { + if c.z.IsOne() { + return + } else if c.z.IsZero() { + c.x.SetZero() + c.y.SetOne() + c.t.SetZero() + return + } + + zInv := (&gfP2{}).Invert(&c.z) + t := (&gfP2{}).Mul(&c.y, zInv) + zInv2 := (&gfP2{}).Square(zInv) + c.y.Mul(t, zInv2) + t.Mul(&c.x, zInv2) + c.x.Set(t) + c.z.SetOne() + c.t.SetOne() +} + +func (c *twistPoint) Neg(a *twistPoint) { + c.x.Set(&a.x) + c.y.Neg(&a.y) + c.z.Set(&a.z) + c.t.SetZero() +} + +// code logic is form https://github.com/guanzhi/GmSSL/blob/develop/src/sm9_alg.c +// the value is not same as p*a +func (c *twistPoint) Frobenius(a *twistPoint) { + c.x.Conjugate(&a.x) + c.y.Conjugate(&a.y) + c.z.Conjugate(&a.z) + c.z.MulScalar(&a.z, frobConstant) + c.t.Square(&a.z) +} + +func (c *twistPoint) FrobeniusP2(a *twistPoint) { + c.x.Set(&a.x) + c.y.Set(&a.y) + c.z.MulScalar(&a.z, wToP2Minus1) + c.t.Square(&a.z) +} + +func (c *twistPoint) NegFrobeniusP2(a *twistPoint) { + c.x.Set(&a.x) + c.y.Neg(&a.y) + c.z.MulScalar(&a.z, wToP2Minus1) + c.t.Square(&a.z) +} + +/* +//code logic is from https://github.com/miracl/MIRACL/blob/master/source/curve/pairing/bn_pair.cpp +func (c *twistPoint) Frobenius(a *twistPoint) { + w, r, frob := &gfP2{}, &gfP2{}, &gfP2{} + frob.SetFrobConstant() + w.Square(frob) + + r.Conjugate(&twistGen.x) + r.Mul(r, w) + c.x.Set(r) + + r.Conjugate(&twistGen.y) + r.Mul(r, frob) + r.Mul(r, w) + c.y.Set(r) + + r.Conjugate(&twistGen.z) + c.z.Set(r) + + r.Square(&c.z) + c.t.Set(r) +} + +func (c *twistPoint) FrobeniusP2(a *twistPoint) { + ret := &twistPoint{} + ret.Frobenius(a) + c.Frobenius(ret) +} + +*/ +/* +// code logic from https://github.com/cloudflare/bn256/blob/master/optate.go +func (c *twistPoint) Frobenius(a *twistPoint) { + r := &gfP2{} + r.Conjugate(&a.x) + r.MulScalar(r, xiToPMinus1Over3) + c.x.Set(r) + r.Conjugate(&a.y) + r.MulScalar(r, xiToPMinus1Over2) + c.y.Set(r) + c.z.SetOne() + c.t.SetOne() +} + +func (c *twistPoint) FrobeniusP2(a *twistPoint) { + c.x.MulScalar(&a.x, xiToPSquaredMinus1Over3) + c.y.Neg(&a.y) + c.z.SetOne() + c.t.SetOne() +} + +func (c *twistPoint) NegFrobeniusP2(a *twistPoint) { + c.x.MulScalar(&a.x, xiToPSquaredMinus1Over3) + c.y.Set(&a.y) + c.z.SetOne() + c.t.SetOne() +} +*/ diff --git a/sm9/twist_test.go b/sm9/twist_test.go new file mode 100644 index 0000000..3eb4fad --- /dev/null +++ b/sm9/twist_test.go @@ -0,0 +1,110 @@ +package sm9 + +import ( + "testing" +) + +func TestIsOnCurve(t *testing.T) { + if !twistGen.IsOnCurve() { + t.Errorf("twist gen point should be on curve") + } + a := &twistPoint{} + a.SetInfinity() + if !a.IsOnCurve() { + t.Errorf("infinity zero point should be on curve") + } +} + +func TestAddNeg(t *testing.T) { + neg := &twistPoint{} + neg.Neg(twistGen) + res := &twistPoint{} + res.Add(twistGen, neg) + if !res.IsInfinity() { + t.Errorf("a add its neg should be zero") + } +} + +func Test_TwistFrobeniusP(t *testing.T) { + ret1, ret2 := &twistPoint{}, &twistPoint{} + ret1.Frobenius(twistGen) + ret1.MakeAffine() + + ret2.x.Conjugate(&twistGen.x) + ret2.x.MulScalar(&ret2.x, betaToNegPPlus1Over3) + + ret2.y.Conjugate(&twistGen.y) + ret2.y.MulScalar(&ret2.y, betaToNegPPlus1Over2) + ret2.z.SetOne() + ret2.t.SetOne() + if !ret2.IsOnCurve() { + t.Errorf("point should be on curve") + } + + if ret1.x != ret2.x || ret1.y != ret2.y || ret1.z != ret2.z || ret1.t != ret2.t { + t.Errorf("not same") + } +} + +func Test_TwistFrobeniusP2(t *testing.T) { + ret1, ret2 := &twistPoint{}, &twistPoint{} + ret1.Frobenius(twistGen) + ret1.Frobenius(ret1) + if !ret1.IsOnCurve() { + t.Errorf("point should be on curve") + } + + ret2.FrobeniusP2(twistGen) + if !ret2.IsOnCurve() { + t.Errorf("point should be on curve") + } + if ret1.x != ret2.x || ret1.y != ret2.y || ret1.z != ret2.z || ret1.t != ret2.t { + t.Errorf("not same") + } +} + +func Test_TwistFrobeniusP2_Case2(t *testing.T) { + ret1, ret2 := &twistPoint{}, &twistPoint{} + ret1.x.Set(&twistGen.x) + ret1.x.MulScalar(&ret1.x, betaToNegP2Plus1Over3) + + ret1.y.Set(&twistGen.y) + ret1.y.MulScalar(&ret1.y, betaToNegP2Plus1Over2) + ret1.z.SetOne() + ret1.t.SetOne() + if !ret1.IsOnCurve() { + t.Errorf("point should be on curve") + } + + ret2.FrobeniusP2(twistGen) + ret2.MakeAffine() + if !ret2.IsOnCurve() { + t.Errorf("point should be on curve") + } + if ret1.x != ret2.x || ret1.y != ret2.y || ret1.z != ret2.z || ret1.t != ret2.t { + t.Errorf("not same") + } +} + +func Test_TwistNegFrobeniusP2_Case2(t *testing.T) { + ret1, ret2 := &twistPoint{}, &twistPoint{} + ret1.x.Set(&twistGen.x) + ret1.x.MulScalar(&ret1.x, betaToNegP2Plus1Over3) + + ret1.y.Neg(&twistGen.y) + ret1.y.MulScalar(&ret1.y, betaToNegP2Plus1Over2) + ret1.z.SetOne() + ret1.t.SetOne() + if !ret1.IsOnCurve() { + t.Errorf("point should be on curve") + } + + ret2.NegFrobeniusP2(twistGen) + ret2.MakeAffine() + if !ret2.IsOnCurve() { + t.Errorf("point should be on curve") + } + if ret1.x != ret2.x || ret1.y != ret2.y || ret1.z != ret2.z || ret1.t != ret2.t { + t.Errorf("not same") + } +}