optimize purego sm3/sm4 and reduce bounds checking for sm9

This commit is contained in:
Sun Yimin 2023-06-16 16:06:38 +08:00 committed by GitHub
parent 3bd048c903
commit 3cbabc3d1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 107 additions and 35 deletions

View File

@ -2,10 +2,12 @@ package sm3
import "math/bits" import "math/bits"
var _T = [2]uint32{
0x79cc4519, const (
0x7a879d8a, _T0 = 0x79cc4519
} _T1 = 0x7a879d8a
)
func p1(x uint32) uint32 { func p1(x uint32) uint32 {
return x ^ bits.RotateLeft32(x, 15) ^ bits.RotateLeft32(x, 23) return x ^ bits.RotateLeft32(x, 15) ^ bits.RotateLeft32(x, 23)
@ -39,7 +41,7 @@ func blockGeneric(dig *digest, p []byte) {
// handle first 12 rounds state // handle first 12 rounds state
for i := 0; i < 12; i++ { for i := 0; i < 12; i++ {
ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T[0], i), 7) ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T0, i), 7)
ss2 := ss1 ^ bits.RotateLeft32(a, 12) ss2 := ss1 ^ bits.RotateLeft32(a, 12)
tt1 := a ^ b ^ c + d + ss2 + (w[i] ^ w[i+4]) tt1 := a ^ b ^ c + d + ss2 + (w[i] ^ w[i+4])
tt2 := e ^ f ^ g + h + ss1 + w[i] tt2 := e ^ f ^ g + h + ss1 + w[i]
@ -56,7 +58,7 @@ func blockGeneric(dig *digest, p []byte) {
// handle next 4 rounds state // handle next 4 rounds state
for i := 12; i < 16; i++ { for i := 12; i < 16; i++ {
w[i+4] = p1(w[i-12]^w[i-5]^bits.RotateLeft32(w[i+1], 15)) ^ bits.RotateLeft32(w[i-9], 7) ^ w[i-2] w[i+4] = p1(w[i-12]^w[i-5]^bits.RotateLeft32(w[i+1], 15)) ^ bits.RotateLeft32(w[i-9], 7) ^ w[i-2]
ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T[0], i), 7) ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T0, i), 7)
ss2 := ss1 ^ bits.RotateLeft32(a, 12) ss2 := ss1 ^ bits.RotateLeft32(a, 12)
tt1 := a ^ b ^ c + d + ss2 + (w[i] ^ w[i+4]) tt1 := a ^ b ^ c + d + ss2 + (w[i] ^ w[i+4])
tt2 := e ^ f ^ g + h + ss1 + w[i] tt2 := e ^ f ^ g + h + ss1 + w[i]
@ -73,7 +75,7 @@ func blockGeneric(dig *digest, p []byte) {
// handle last 48 rounds state // handle last 48 rounds state
for i := 16; i < 64; i++ { for i := 16; i < 64; i++ {
w[i+4] = p1(w[i-12]^w[i-5]^bits.RotateLeft32(w[i+1], 15)) ^ bits.RotateLeft32(w[i-9], 7) ^ w[i-2] w[i+4] = p1(w[i-12]^w[i-5]^bits.RotateLeft32(w[i+1], 15)) ^ bits.RotateLeft32(w[i-9], 7) ^ w[i-2]
ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T[1], i), 7) ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T1, i), 7)
ss2 := ss1 ^ bits.RotateLeft32(a, 12) ss2 := ss1 ^ bits.RotateLeft32(a, 12)
tt1 := (a & b) | (a & c) | (b & c) + d + ss2 + (w[i] ^ w[i+4]) tt1 := (a & b) | (a & c) | (b & c) + d + ss2 + (w[i] ^ w[i+4])
tt2 := (e & f) | (^e & g) + h + ss1 + w[i] tt2 := (e & f) | (^e & g) + h + ss1 + w[i]

View File

@ -4,7 +4,6 @@ package sm4
import ( import (
"encoding/binary" "encoding/binary"
"math/bits"
) )
// Encrypt one block from src into dst, using the expanded key xk. // Encrypt one block from src into dst, using the expanded key xk.
@ -164,7 +163,8 @@ func t(in uint32) uint32 {
b |= uint32(sbox[(in>>16)&0xff]) << 16 b |= uint32(sbox[(in>>16)&0xff]) << 16
b |= uint32(sbox[(in>>24)&0xff]) << 24 b |= uint32(sbox[(in>>24)&0xff]) << 24
return b ^ bits.RotateLeft32(b, 2) ^ bits.RotateLeft32(b, 10) ^ bits.RotateLeft32(b, 18) ^ bits.RotateLeft32(b, 24) // L
return b ^ (b<<2 | b>>30) ^ (b<<10 | b>>22) ^ (b<<18 | b>>14) ^ (b<<24 | b>>8)
} }
// T' // T'
@ -176,7 +176,8 @@ func t2(in uint32) uint32 {
b |= uint32(sbox[(in>>16)&0xff]) << 16 b |= uint32(sbox[(in>>16)&0xff]) << 16
b |= uint32(sbox[(in>>24)&0xff]) << 24 b |= uint32(sbox[(in>>24)&0xff]) << 24
return b ^ bits.RotateLeft32(b, 13) ^ bits.RotateLeft32(b, 23) // L2
return b ^ (b<<13 | b>>19) ^ (b<<23 | b>>9)
} }
func precompute_t(in uint32) uint32 { func precompute_t(in uint32) uint32 {

View File

@ -24,7 +24,7 @@ type G1 struct {
p *curvePoint p *curvePoint
} }
//Gen1 is the generator of G1. // Gen1 is the generator of G1.
var Gen1 = &G1{curveGen} var Gen1 = &G1{curveGen}
var g1GeneratorTable *[32 * 2]curvePointTable var g1GeneratorTable *[32 * 2]curvePointTable
@ -37,12 +37,42 @@ func (g *G1) generatorTable() *[32 * 2]curvePointTable {
for i := 0; i < 32*2; i++ { for i := 0; i < 32*2; i++ {
g1GeneratorTable[i][0] = &curvePoint{} g1GeneratorTable[i][0] = &curvePoint{}
g1GeneratorTable[i][0].Set(base) g1GeneratorTable[i][0].Set(base)
for j := 1; j < 15; j += 2 {
g1GeneratorTable[i][j] = &curvePoint{} g1GeneratorTable[i][1] = &curvePoint{}
g1GeneratorTable[i][j].Double(g1GeneratorTable[i][j/2]) g1GeneratorTable[i][1].Double(g1GeneratorTable[i][0])
g1GeneratorTable[i][j+1] = &curvePoint{} g1GeneratorTable[i][2] = &curvePoint{}
g1GeneratorTable[i][j+1].Add(g1GeneratorTable[i][j], base) g1GeneratorTable[i][2].Add(g1GeneratorTable[i][1], base)
}
g1GeneratorTable[i][3] = &curvePoint{}
g1GeneratorTable[i][3].Double(g1GeneratorTable[i][1])
g1GeneratorTable[i][4] = &curvePoint{}
g1GeneratorTable[i][4].Add(g1GeneratorTable[i][3], base)
g1GeneratorTable[i][5] = &curvePoint{}
g1GeneratorTable[i][5].Double(g1GeneratorTable[i][2])
g1GeneratorTable[i][6] = &curvePoint{}
g1GeneratorTable[i][6].Add(g1GeneratorTable[i][5], base)
g1GeneratorTable[i][7] = &curvePoint{}
g1GeneratorTable[i][7].Double(g1GeneratorTable[i][3])
g1GeneratorTable[i][8] = &curvePoint{}
g1GeneratorTable[i][8].Add(g1GeneratorTable[i][7], base)
g1GeneratorTable[i][9] = &curvePoint{}
g1GeneratorTable[i][9].Double(g1GeneratorTable[i][4])
g1GeneratorTable[i][10] = &curvePoint{}
g1GeneratorTable[i][10].Add(g1GeneratorTable[i][9], base)
g1GeneratorTable[i][11] = &curvePoint{}
g1GeneratorTable[i][11].Double(g1GeneratorTable[i][5])
g1GeneratorTable[i][12] = &curvePoint{}
g1GeneratorTable[i][12].Add(g1GeneratorTable[i][11], base)
g1GeneratorTable[i][13] = &curvePoint{}
g1GeneratorTable[i][13].Double(g1GeneratorTable[i][6])
g1GeneratorTable[i][14] = &curvePoint{}
g1GeneratorTable[i][14].Add(g1GeneratorTable[i][13], base)
base.Double(base) base.Double(base)
base.Double(base) base.Double(base)
base.Double(base) base.Double(base)
@ -229,6 +259,7 @@ func (e *G1) MarshalCompressed() []byte {
e.p.MakeAffine() e.p.MakeAffine()
temp := &gfP{} temp := &gfP{}
montDecode(temp, &e.p.y) montDecode(temp, &e.p.y)
temp.Marshal(ret[1:]) temp.Marshal(ret[1:])
ret[0] = (ret[numBytes] & 1) | 2 ret[0] = (ret[numBytes] & 1) | 2
montDecode(temp, &e.p.x) montDecode(temp, &e.p.x)

View File

@ -26,12 +26,42 @@ func (g *G2) generatorTable() *[32 * 2]twistPointTable {
for i := 0; i < 32*2; i++ { for i := 0; i < 32*2; i++ {
g2GeneratorTable[i][0] = &twistPoint{} g2GeneratorTable[i][0] = &twistPoint{}
g2GeneratorTable[i][0].Set(base) g2GeneratorTable[i][0].Set(base)
for j := 1; j < 15; j += 2 {
g2GeneratorTable[i][j] = &twistPoint{} g2GeneratorTable[i][1] = &twistPoint{}
g2GeneratorTable[i][j].Double(g2GeneratorTable[i][j/2]) g2GeneratorTable[i][1].Double(g2GeneratorTable[i][0])
g2GeneratorTable[i][j+1] = &twistPoint{} g2GeneratorTable[i][2] = &twistPoint{}
g2GeneratorTable[i][j+1].Add(g2GeneratorTable[i][j], base) g2GeneratorTable[i][2].Add(g2GeneratorTable[i][1], base)
}
g2GeneratorTable[i][3] = &twistPoint{}
g2GeneratorTable[i][3].Double(g2GeneratorTable[i][1])
g2GeneratorTable[i][4] = &twistPoint{}
g2GeneratorTable[i][4].Add(g2GeneratorTable[i][3], base)
g2GeneratorTable[i][5] = &twistPoint{}
g2GeneratorTable[i][5].Double(g2GeneratorTable[i][2])
g2GeneratorTable[i][6] = &twistPoint{}
g2GeneratorTable[i][6].Add(g2GeneratorTable[i][5], base)
g2GeneratorTable[i][7] = &twistPoint{}
g2GeneratorTable[i][7].Double(g2GeneratorTable[i][3])
g2GeneratorTable[i][8] = &twistPoint{}
g2GeneratorTable[i][8].Add(g2GeneratorTable[i][7], base)
g2GeneratorTable[i][9] = &twistPoint{}
g2GeneratorTable[i][9].Double(g2GeneratorTable[i][4])
g2GeneratorTable[i][10] = &twistPoint{}
g2GeneratorTable[i][10].Add(g2GeneratorTable[i][9], base)
g2GeneratorTable[i][11] = &twistPoint{}
g2GeneratorTable[i][11].Double(g2GeneratorTable[i][5])
g2GeneratorTable[i][12] = &twistPoint{}
g2GeneratorTable[i][12].Add(g2GeneratorTable[i][11], base)
g2GeneratorTable[i][13] = &twistPoint{}
g2GeneratorTable[i][13].Double(g2GeneratorTable[i][6])
g2GeneratorTable[i][14] = &twistPoint{}
g2GeneratorTable[i][14].Add(g2GeneratorTable[i][13], base)
base.Double(base) base.Double(base)
base.Double(base) base.Double(base)
base.Double(base) base.Double(base)

View File

@ -49,22 +49,27 @@ const (
ENC_TYPE_CFB encryptType = 8 ENC_TYPE_CFB encryptType = 8
) )
//hash implements H1(Z,n) or H2(Z,n) in sm9 algorithm. // hash implements H1(Z,n) or H2(Z,n) in sm9 algorithm.
func hash(z []byte, h hashMode) *bigmod.Nat { func hash(z []byte, h hashMode) *bigmod.Nat {
md := sm3.New() md := sm3.New()
var ha [64]byte var ha [64]byte
var countBytes [4]byte var countBytes [4]byte
var ct uint32 = 1 var ct uint32 = 1
for i := 0; i < 2; i++ {
binary.BigEndian.PutUint32(countBytes[:], ct) binary.BigEndian.PutUint32(countBytes[:], ct)
md.Write([]byte{byte(h)}) md.Write([]byte{byte(h)})
md.Write(z) md.Write(z)
md.Write(countBytes[:]) md.Write(countBytes[:])
copy(ha[i*sm3.Size:], md.Sum(nil)) copy(ha[:], md.Sum(nil))
ct++ ct++
md.Reset() md.Reset()
}
binary.BigEndian.PutUint32(countBytes[:], ct)
md.Write([]byte{byte(h)})
md.Write(z)
md.Write(countBytes[:])
copy(ha[sm3.Size:], md.Sum(nil))
k := new(big.Int).SetBytes(ha[:40]) k := new(big.Int).SetBytes(ha[:40])
kNat := bigmod.NewNat().SetBig(k) kNat := bigmod.NewNat().SetBig(k)
kNat = bigmod.NewNat().ModNat(kNat, orderMinus1) kNat = bigmod.NewNat().ModNat(kNat, orderMinus1)
@ -469,6 +474,8 @@ func Decrypt(priv *EncryptPrivateKey, uid, ciphertext []byte, opts EncrypterOpts
return nil, ErrDecryption return nil, ErrDecryption
} }
_ = c3c2[sm3.Size] // bounds check elimination hint
c3 := c3c2[:sm3.Size]
c2 := c3c2[sm3.Size:] c2 := c3c2[sm3.Size:]
key1Len := opts.GetKeySize(c2) key1Len := opts.GetKeySize(c2)
@ -476,8 +483,8 @@ func Decrypt(priv *EncryptPrivateKey, uid, ciphertext []byte, opts EncrypterOpts
if err != nil { if err != nil {
return nil, err return nil, err
} }
_ = key[key1Len] // bounds check elimination hint
return decrypt(c, key[:key1Len], key[key1Len:], c2, c3c2[:sm3.Size], opts) return decrypt(c, key[:key1Len], key[key1Len:], c2, c3, opts)
} }
func decrypt(cipher *bn256.G1, key1, key2, c2, c3 []byte, opts EncrypterOpts) ([]byte, error) { func decrypt(cipher *bn256.G1, key1, key2, c2, c3 []byte, opts EncrypterOpts) ([]byte, error) {
@ -532,6 +539,7 @@ func DecryptASN1(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error
return nil, err return nil, err
} }
_ = key[key1Len] // bounds check elimination hint
return decrypt(c, key[:key1Len], key[key1Len:], c2Bytes, c3Bytes, opts) return decrypt(c, key[:key1Len], key[key1Len:], c2Bytes, c3Bytes, opts)
} }