From 3cbabc3d1ce30f4f02b24bf4f3ae3fa6949b4647 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Fri, 16 Jun 2023 16:06:38 +0800 Subject: [PATCH] optimize purego sm3/sm4 and reduce bounds checking for sm9 --- sm3/sm3block.go | 16 +++++++++------- sm4/block.go | 7 ++++--- sm9/bn256/g1.go | 45 ++++++++++++++++++++++++++++++++++++++------- sm9/bn256/g2.go | 42 ++++++++++++++++++++++++++++++++++++------ sm9/sm9.go | 32 ++++++++++++++++++++------------ 5 files changed, 107 insertions(+), 35 deletions(-) diff --git a/sm3/sm3block.go b/sm3/sm3block.go index d45e66e..e8963d1 100644 --- a/sm3/sm3block.go +++ b/sm3/sm3block.go @@ -2,10 +2,12 @@ package sm3 import "math/bits" -var _T = [2]uint32{ - 0x79cc4519, - 0x7a879d8a, -} + +const ( + _T0 = 0x79cc4519 + _T1 = 0x7a879d8a +) + func p1(x uint32) uint32 { return x ^ bits.RotateLeft32(x, 15) ^ bits.RotateLeft32(x, 23) @@ -39,7 +41,7 @@ func blockGeneric(dig *digest, p []byte) { // handle first 12 rounds state for i := 0; i < 12; i++ { - ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T[0], i), 7) + ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T0, i), 7) ss2 := ss1 ^ bits.RotateLeft32(a, 12) tt1 := a ^ b ^ c + d + ss2 + (w[i] ^ w[i+4]) tt2 := e ^ f ^ g + h + ss1 + w[i] @@ -56,7 +58,7 @@ func blockGeneric(dig *digest, p []byte) { // handle next 4 rounds state for i := 12; i < 16; i++ { w[i+4] = p1(w[i-12]^w[i-5]^bits.RotateLeft32(w[i+1], 15)) ^ bits.RotateLeft32(w[i-9], 7) ^ w[i-2] - ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T[0], i), 7) + ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T0, i), 7) ss2 := ss1 ^ bits.RotateLeft32(a, 12) tt1 := a ^ b ^ c + d + ss2 + (w[i] ^ w[i+4]) tt2 := e ^ f ^ g + h + ss1 + w[i] @@ -73,7 +75,7 @@ func blockGeneric(dig *digest, p []byte) { // handle last 48 rounds state for i := 16; i < 64; i++ { w[i+4] = p1(w[i-12]^w[i-5]^bits.RotateLeft32(w[i+1], 15)) ^ bits.RotateLeft32(w[i-9], 7) ^ w[i-2] - ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T[1], i), 7) + ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T1, i), 7) ss2 := ss1 ^ bits.RotateLeft32(a, 12) tt1 := (a & b) | (a & c) | (b & c) + d + ss2 + (w[i] ^ w[i+4]) tt2 := (e & f) | (^e & g) + h + ss1 + w[i] diff --git a/sm4/block.go b/sm4/block.go index 14bf11e..7529a49 100644 --- a/sm4/block.go +++ b/sm4/block.go @@ -4,7 +4,6 @@ package sm4 import ( "encoding/binary" - "math/bits" ) // Encrypt one block from src into dst, using the expanded key xk. @@ -164,7 +163,8 @@ func t(in uint32) uint32 { b |= uint32(sbox[(in>>16)&0xff]) << 16 b |= uint32(sbox[(in>>24)&0xff]) << 24 - return b ^ bits.RotateLeft32(b, 2) ^ bits.RotateLeft32(b, 10) ^ bits.RotateLeft32(b, 18) ^ bits.RotateLeft32(b, 24) + // L + return b ^ (b<<2 | b>>30) ^ (b<<10 | b>>22) ^ (b<<18 | b>>14) ^ (b<<24 | b>>8) } // T' @@ -176,7 +176,8 @@ func t2(in uint32) uint32 { b |= uint32(sbox[(in>>16)&0xff]) << 16 b |= uint32(sbox[(in>>24)&0xff]) << 24 - return b ^ bits.RotateLeft32(b, 13) ^ bits.RotateLeft32(b, 23) + // L2 + return b ^ (b<<13 | b>>19) ^ (b<<23 | b>>9) } func precompute_t(in uint32) uint32 { diff --git a/sm9/bn256/g1.go b/sm9/bn256/g1.go index a1777aa..c03a35b 100644 --- a/sm9/bn256/g1.go +++ b/sm9/bn256/g1.go @@ -24,7 +24,7 @@ type G1 struct { p *curvePoint } -//Gen1 is the generator of G1. +// Gen1 is the generator of G1. var Gen1 = &G1{curveGen} var g1GeneratorTable *[32 * 2]curvePointTable @@ -37,12 +37,42 @@ func (g *G1) generatorTable() *[32 * 2]curvePointTable { for i := 0; i < 32*2; i++ { g1GeneratorTable[i][0] = &curvePoint{} g1GeneratorTable[i][0].Set(base) - for j := 1; j < 15; j += 2 { - g1GeneratorTable[i][j] = &curvePoint{} - g1GeneratorTable[i][j].Double(g1GeneratorTable[i][j/2]) - g1GeneratorTable[i][j+1] = &curvePoint{} - g1GeneratorTable[i][j+1].Add(g1GeneratorTable[i][j], base) - } + + g1GeneratorTable[i][1] = &curvePoint{} + g1GeneratorTable[i][1].Double(g1GeneratorTable[i][0]) + g1GeneratorTable[i][2] = &curvePoint{} + g1GeneratorTable[i][2].Add(g1GeneratorTable[i][1], base) + + g1GeneratorTable[i][3] = &curvePoint{} + g1GeneratorTable[i][3].Double(g1GeneratorTable[i][1]) + g1GeneratorTable[i][4] = &curvePoint{} + g1GeneratorTable[i][4].Add(g1GeneratorTable[i][3], base) + + g1GeneratorTable[i][5] = &curvePoint{} + g1GeneratorTable[i][5].Double(g1GeneratorTable[i][2]) + g1GeneratorTable[i][6] = &curvePoint{} + g1GeneratorTable[i][6].Add(g1GeneratorTable[i][5], base) + + g1GeneratorTable[i][7] = &curvePoint{} + g1GeneratorTable[i][7].Double(g1GeneratorTable[i][3]) + g1GeneratorTable[i][8] = &curvePoint{} + g1GeneratorTable[i][8].Add(g1GeneratorTable[i][7], base) + + g1GeneratorTable[i][9] = &curvePoint{} + g1GeneratorTable[i][9].Double(g1GeneratorTable[i][4]) + g1GeneratorTable[i][10] = &curvePoint{} + g1GeneratorTable[i][10].Add(g1GeneratorTable[i][9], base) + + g1GeneratorTable[i][11] = &curvePoint{} + g1GeneratorTable[i][11].Double(g1GeneratorTable[i][5]) + g1GeneratorTable[i][12] = &curvePoint{} + g1GeneratorTable[i][12].Add(g1GeneratorTable[i][11], base) + + g1GeneratorTable[i][13] = &curvePoint{} + g1GeneratorTable[i][13].Double(g1GeneratorTable[i][6]) + g1GeneratorTable[i][14] = &curvePoint{} + g1GeneratorTable[i][14].Add(g1GeneratorTable[i][13], base) + base.Double(base) base.Double(base) base.Double(base) @@ -229,6 +259,7 @@ func (e *G1) MarshalCompressed() []byte { e.p.MakeAffine() temp := &gfP{} montDecode(temp, &e.p.y) + temp.Marshal(ret[1:]) ret[0] = (ret[numBytes] & 1) | 2 montDecode(temp, &e.p.x) diff --git a/sm9/bn256/g2.go b/sm9/bn256/g2.go index ba13436..ee93578 100644 --- a/sm9/bn256/g2.go +++ b/sm9/bn256/g2.go @@ -26,12 +26,42 @@ func (g *G2) generatorTable() *[32 * 2]twistPointTable { for i := 0; i < 32*2; i++ { g2GeneratorTable[i][0] = &twistPoint{} g2GeneratorTable[i][0].Set(base) - for j := 1; j < 15; j += 2 { - g2GeneratorTable[i][j] = &twistPoint{} - g2GeneratorTable[i][j].Double(g2GeneratorTable[i][j/2]) - g2GeneratorTable[i][j+1] = &twistPoint{} - g2GeneratorTable[i][j+1].Add(g2GeneratorTable[i][j], base) - } + + g2GeneratorTable[i][1] = &twistPoint{} + g2GeneratorTable[i][1].Double(g2GeneratorTable[i][0]) + g2GeneratorTable[i][2] = &twistPoint{} + g2GeneratorTable[i][2].Add(g2GeneratorTable[i][1], base) + + g2GeneratorTable[i][3] = &twistPoint{} + g2GeneratorTable[i][3].Double(g2GeneratorTable[i][1]) + g2GeneratorTable[i][4] = &twistPoint{} + g2GeneratorTable[i][4].Add(g2GeneratorTable[i][3], base) + + g2GeneratorTable[i][5] = &twistPoint{} + g2GeneratorTable[i][5].Double(g2GeneratorTable[i][2]) + g2GeneratorTable[i][6] = &twistPoint{} + g2GeneratorTable[i][6].Add(g2GeneratorTable[i][5], base) + + g2GeneratorTable[i][7] = &twistPoint{} + g2GeneratorTable[i][7].Double(g2GeneratorTable[i][3]) + g2GeneratorTable[i][8] = &twistPoint{} + g2GeneratorTable[i][8].Add(g2GeneratorTable[i][7], base) + + g2GeneratorTable[i][9] = &twistPoint{} + g2GeneratorTable[i][9].Double(g2GeneratorTable[i][4]) + g2GeneratorTable[i][10] = &twistPoint{} + g2GeneratorTable[i][10].Add(g2GeneratorTable[i][9], base) + + g2GeneratorTable[i][11] = &twistPoint{} + g2GeneratorTable[i][11].Double(g2GeneratorTable[i][5]) + g2GeneratorTable[i][12] = &twistPoint{} + g2GeneratorTable[i][12].Add(g2GeneratorTable[i][11], base) + + g2GeneratorTable[i][13] = &twistPoint{} + g2GeneratorTable[i][13].Double(g2GeneratorTable[i][6]) + g2GeneratorTable[i][14] = &twistPoint{} + g2GeneratorTable[i][14].Add(g2GeneratorTable[i][13], base) + base.Double(base) base.Double(base) base.Double(base) diff --git a/sm9/sm9.go b/sm9/sm9.go index e2c22c6..c9dd290 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -49,22 +49,27 @@ const ( ENC_TYPE_CFB encryptType = 8 ) -//hash implements H1(Z,n) or H2(Z,n) in sm9 algorithm. +// hash implements H1(Z,n) or H2(Z,n) in sm9 algorithm. func hash(z []byte, h hashMode) *bigmod.Nat { md := sm3.New() var ha [64]byte var countBytes [4]byte var ct uint32 = 1 - for i := 0; i < 2; i++ { - binary.BigEndian.PutUint32(countBytes[:], ct) - md.Write([]byte{byte(h)}) - md.Write(z) - md.Write(countBytes[:]) - copy(ha[i*sm3.Size:], md.Sum(nil)) - ct++ - md.Reset() - } + binary.BigEndian.PutUint32(countBytes[:], ct) + md.Write([]byte{byte(h)}) + md.Write(z) + md.Write(countBytes[:]) + copy(ha[:], md.Sum(nil)) + ct++ + md.Reset() + + binary.BigEndian.PutUint32(countBytes[:], ct) + md.Write([]byte{byte(h)}) + md.Write(z) + md.Write(countBytes[:]) + copy(ha[sm3.Size:], md.Sum(nil)) + k := new(big.Int).SetBytes(ha[:40]) kNat := bigmod.NewNat().SetBig(k) kNat = bigmod.NewNat().ModNat(kNat, orderMinus1) @@ -469,6 +474,8 @@ func Decrypt(priv *EncryptPrivateKey, uid, ciphertext []byte, opts EncrypterOpts return nil, ErrDecryption } + _ = c3c2[sm3.Size] // bounds check elimination hint + c3 := c3c2[:sm3.Size] c2 := c3c2[sm3.Size:] key1Len := opts.GetKeySize(c2) @@ -476,8 +483,8 @@ func Decrypt(priv *EncryptPrivateKey, uid, ciphertext []byte, opts EncrypterOpts if err != nil { return nil, err } - - return decrypt(c, key[:key1Len], key[key1Len:], c2, c3c2[:sm3.Size], opts) + _ = key[key1Len] // bounds check elimination hint + return decrypt(c, key[:key1Len], key[key1Len:], c2, c3, opts) } func decrypt(cipher *bn256.G1, key1, key2, c2, c3 []byte, opts EncrypterOpts) ([]byte, error) { @@ -532,6 +539,7 @@ func DecryptASN1(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error return nil, err } + _ = key[key1Len] // bounds check elimination hint return decrypt(c, key[:key1Len], key[key1Len:], c2Bytes, c3Bytes, opts) }