mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-26 12:16:20 +08:00
optimize purego sm3/sm4 and reduce bounds checking for sm9
This commit is contained in:
parent
3bd048c903
commit
3cbabc3d1c
@ -2,10 +2,12 @@ package sm3
|
||||
|
||||
import "math/bits"
|
||||
|
||||
var _T = [2]uint32{
|
||||
0x79cc4519,
|
||||
0x7a879d8a,
|
||||
}
|
||||
|
||||
const (
|
||||
_T0 = 0x79cc4519
|
||||
_T1 = 0x7a879d8a
|
||||
)
|
||||
|
||||
|
||||
func p1(x uint32) uint32 {
|
||||
return x ^ bits.RotateLeft32(x, 15) ^ bits.RotateLeft32(x, 23)
|
||||
@ -39,7 +41,7 @@ func blockGeneric(dig *digest, p []byte) {
|
||||
|
||||
// handle first 12 rounds state
|
||||
for i := 0; i < 12; i++ {
|
||||
ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T[0], i), 7)
|
||||
ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T0, i), 7)
|
||||
ss2 := ss1 ^ bits.RotateLeft32(a, 12)
|
||||
tt1 := a ^ b ^ c + d + ss2 + (w[i] ^ w[i+4])
|
||||
tt2 := e ^ f ^ g + h + ss1 + w[i]
|
||||
@ -56,7 +58,7 @@ func blockGeneric(dig *digest, p []byte) {
|
||||
// handle next 4 rounds state
|
||||
for i := 12; i < 16; i++ {
|
||||
w[i+4] = p1(w[i-12]^w[i-5]^bits.RotateLeft32(w[i+1], 15)) ^ bits.RotateLeft32(w[i-9], 7) ^ w[i-2]
|
||||
ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T[0], i), 7)
|
||||
ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T0, i), 7)
|
||||
ss2 := ss1 ^ bits.RotateLeft32(a, 12)
|
||||
tt1 := a ^ b ^ c + d + ss2 + (w[i] ^ w[i+4])
|
||||
tt2 := e ^ f ^ g + h + ss1 + w[i]
|
||||
@ -73,7 +75,7 @@ func blockGeneric(dig *digest, p []byte) {
|
||||
// handle last 48 rounds state
|
||||
for i := 16; i < 64; i++ {
|
||||
w[i+4] = p1(w[i-12]^w[i-5]^bits.RotateLeft32(w[i+1], 15)) ^ bits.RotateLeft32(w[i-9], 7) ^ w[i-2]
|
||||
ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T[1], i), 7)
|
||||
ss1 := bits.RotateLeft32(bits.RotateLeft32(a, 12)+e+bits.RotateLeft32(_T1, i), 7)
|
||||
ss2 := ss1 ^ bits.RotateLeft32(a, 12)
|
||||
tt1 := (a & b) | (a & c) | (b & c) + d + ss2 + (w[i] ^ w[i+4])
|
||||
tt2 := (e & f) | (^e & g) + h + ss1 + w[i]
|
||||
|
@ -4,7 +4,6 @@ package sm4
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
// Encrypt one block from src into dst, using the expanded key xk.
|
||||
@ -164,7 +163,8 @@ func t(in uint32) uint32 {
|
||||
b |= uint32(sbox[(in>>16)&0xff]) << 16
|
||||
b |= uint32(sbox[(in>>24)&0xff]) << 24
|
||||
|
||||
return b ^ bits.RotateLeft32(b, 2) ^ bits.RotateLeft32(b, 10) ^ bits.RotateLeft32(b, 18) ^ bits.RotateLeft32(b, 24)
|
||||
// L
|
||||
return b ^ (b<<2 | b>>30) ^ (b<<10 | b>>22) ^ (b<<18 | b>>14) ^ (b<<24 | b>>8)
|
||||
}
|
||||
|
||||
// T'
|
||||
@ -176,7 +176,8 @@ func t2(in uint32) uint32 {
|
||||
b |= uint32(sbox[(in>>16)&0xff]) << 16
|
||||
b |= uint32(sbox[(in>>24)&0xff]) << 24
|
||||
|
||||
return b ^ bits.RotateLeft32(b, 13) ^ bits.RotateLeft32(b, 23)
|
||||
// L2
|
||||
return b ^ (b<<13 | b>>19) ^ (b<<23 | b>>9)
|
||||
}
|
||||
|
||||
func precompute_t(in uint32) uint32 {
|
||||
|
@ -24,7 +24,7 @@ type G1 struct {
|
||||
p *curvePoint
|
||||
}
|
||||
|
||||
//Gen1 is the generator of G1.
|
||||
// Gen1 is the generator of G1.
|
||||
var Gen1 = &G1{curveGen}
|
||||
|
||||
var g1GeneratorTable *[32 * 2]curvePointTable
|
||||
@ -37,12 +37,42 @@ func (g *G1) generatorTable() *[32 * 2]curvePointTable {
|
||||
for i := 0; i < 32*2; i++ {
|
||||
g1GeneratorTable[i][0] = &curvePoint{}
|
||||
g1GeneratorTable[i][0].Set(base)
|
||||
for j := 1; j < 15; j += 2 {
|
||||
g1GeneratorTable[i][j] = &curvePoint{}
|
||||
g1GeneratorTable[i][j].Double(g1GeneratorTable[i][j/2])
|
||||
g1GeneratorTable[i][j+1] = &curvePoint{}
|
||||
g1GeneratorTable[i][j+1].Add(g1GeneratorTable[i][j], base)
|
||||
}
|
||||
|
||||
g1GeneratorTable[i][1] = &curvePoint{}
|
||||
g1GeneratorTable[i][1].Double(g1GeneratorTable[i][0])
|
||||
g1GeneratorTable[i][2] = &curvePoint{}
|
||||
g1GeneratorTable[i][2].Add(g1GeneratorTable[i][1], base)
|
||||
|
||||
g1GeneratorTable[i][3] = &curvePoint{}
|
||||
g1GeneratorTable[i][3].Double(g1GeneratorTable[i][1])
|
||||
g1GeneratorTable[i][4] = &curvePoint{}
|
||||
g1GeneratorTable[i][4].Add(g1GeneratorTable[i][3], base)
|
||||
|
||||
g1GeneratorTable[i][5] = &curvePoint{}
|
||||
g1GeneratorTable[i][5].Double(g1GeneratorTable[i][2])
|
||||
g1GeneratorTable[i][6] = &curvePoint{}
|
||||
g1GeneratorTable[i][6].Add(g1GeneratorTable[i][5], base)
|
||||
|
||||
g1GeneratorTable[i][7] = &curvePoint{}
|
||||
g1GeneratorTable[i][7].Double(g1GeneratorTable[i][3])
|
||||
g1GeneratorTable[i][8] = &curvePoint{}
|
||||
g1GeneratorTable[i][8].Add(g1GeneratorTable[i][7], base)
|
||||
|
||||
g1GeneratorTable[i][9] = &curvePoint{}
|
||||
g1GeneratorTable[i][9].Double(g1GeneratorTable[i][4])
|
||||
g1GeneratorTable[i][10] = &curvePoint{}
|
||||
g1GeneratorTable[i][10].Add(g1GeneratorTable[i][9], base)
|
||||
|
||||
g1GeneratorTable[i][11] = &curvePoint{}
|
||||
g1GeneratorTable[i][11].Double(g1GeneratorTable[i][5])
|
||||
g1GeneratorTable[i][12] = &curvePoint{}
|
||||
g1GeneratorTable[i][12].Add(g1GeneratorTable[i][11], base)
|
||||
|
||||
g1GeneratorTable[i][13] = &curvePoint{}
|
||||
g1GeneratorTable[i][13].Double(g1GeneratorTable[i][6])
|
||||
g1GeneratorTable[i][14] = &curvePoint{}
|
||||
g1GeneratorTable[i][14].Add(g1GeneratorTable[i][13], base)
|
||||
|
||||
base.Double(base)
|
||||
base.Double(base)
|
||||
base.Double(base)
|
||||
@ -229,6 +259,7 @@ func (e *G1) MarshalCompressed() []byte {
|
||||
e.p.MakeAffine()
|
||||
temp := &gfP{}
|
||||
montDecode(temp, &e.p.y)
|
||||
|
||||
temp.Marshal(ret[1:])
|
||||
ret[0] = (ret[numBytes] & 1) | 2
|
||||
montDecode(temp, &e.p.x)
|
||||
|
@ -26,12 +26,42 @@ func (g *G2) generatorTable() *[32 * 2]twistPointTable {
|
||||
for i := 0; i < 32*2; i++ {
|
||||
g2GeneratorTable[i][0] = &twistPoint{}
|
||||
g2GeneratorTable[i][0].Set(base)
|
||||
for j := 1; j < 15; j += 2 {
|
||||
g2GeneratorTable[i][j] = &twistPoint{}
|
||||
g2GeneratorTable[i][j].Double(g2GeneratorTable[i][j/2])
|
||||
g2GeneratorTable[i][j+1] = &twistPoint{}
|
||||
g2GeneratorTable[i][j+1].Add(g2GeneratorTable[i][j], base)
|
||||
}
|
||||
|
||||
g2GeneratorTable[i][1] = &twistPoint{}
|
||||
g2GeneratorTable[i][1].Double(g2GeneratorTable[i][0])
|
||||
g2GeneratorTable[i][2] = &twistPoint{}
|
||||
g2GeneratorTable[i][2].Add(g2GeneratorTable[i][1], base)
|
||||
|
||||
g2GeneratorTable[i][3] = &twistPoint{}
|
||||
g2GeneratorTable[i][3].Double(g2GeneratorTable[i][1])
|
||||
g2GeneratorTable[i][4] = &twistPoint{}
|
||||
g2GeneratorTable[i][4].Add(g2GeneratorTable[i][3], base)
|
||||
|
||||
g2GeneratorTable[i][5] = &twistPoint{}
|
||||
g2GeneratorTable[i][5].Double(g2GeneratorTable[i][2])
|
||||
g2GeneratorTable[i][6] = &twistPoint{}
|
||||
g2GeneratorTable[i][6].Add(g2GeneratorTable[i][5], base)
|
||||
|
||||
g2GeneratorTable[i][7] = &twistPoint{}
|
||||
g2GeneratorTable[i][7].Double(g2GeneratorTable[i][3])
|
||||
g2GeneratorTable[i][8] = &twistPoint{}
|
||||
g2GeneratorTable[i][8].Add(g2GeneratorTable[i][7], base)
|
||||
|
||||
g2GeneratorTable[i][9] = &twistPoint{}
|
||||
g2GeneratorTable[i][9].Double(g2GeneratorTable[i][4])
|
||||
g2GeneratorTable[i][10] = &twistPoint{}
|
||||
g2GeneratorTable[i][10].Add(g2GeneratorTable[i][9], base)
|
||||
|
||||
g2GeneratorTable[i][11] = &twistPoint{}
|
||||
g2GeneratorTable[i][11].Double(g2GeneratorTable[i][5])
|
||||
g2GeneratorTable[i][12] = &twistPoint{}
|
||||
g2GeneratorTable[i][12].Add(g2GeneratorTable[i][11], base)
|
||||
|
||||
g2GeneratorTable[i][13] = &twistPoint{}
|
||||
g2GeneratorTable[i][13].Double(g2GeneratorTable[i][6])
|
||||
g2GeneratorTable[i][14] = &twistPoint{}
|
||||
g2GeneratorTable[i][14].Add(g2GeneratorTable[i][13], base)
|
||||
|
||||
base.Double(base)
|
||||
base.Double(base)
|
||||
base.Double(base)
|
||||
|
20
sm9/sm9.go
20
sm9/sm9.go
@ -49,22 +49,27 @@ const (
|
||||
ENC_TYPE_CFB encryptType = 8
|
||||
)
|
||||
|
||||
//hash implements H1(Z,n) or H2(Z,n) in sm9 algorithm.
|
||||
// hash implements H1(Z,n) or H2(Z,n) in sm9 algorithm.
|
||||
func hash(z []byte, h hashMode) *bigmod.Nat {
|
||||
md := sm3.New()
|
||||
var ha [64]byte
|
||||
var countBytes [4]byte
|
||||
var ct uint32 = 1
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
binary.BigEndian.PutUint32(countBytes[:], ct)
|
||||
md.Write([]byte{byte(h)})
|
||||
md.Write(z)
|
||||
md.Write(countBytes[:])
|
||||
copy(ha[i*sm3.Size:], md.Sum(nil))
|
||||
copy(ha[:], md.Sum(nil))
|
||||
ct++
|
||||
md.Reset()
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint32(countBytes[:], ct)
|
||||
md.Write([]byte{byte(h)})
|
||||
md.Write(z)
|
||||
md.Write(countBytes[:])
|
||||
copy(ha[sm3.Size:], md.Sum(nil))
|
||||
|
||||
k := new(big.Int).SetBytes(ha[:40])
|
||||
kNat := bigmod.NewNat().SetBig(k)
|
||||
kNat = bigmod.NewNat().ModNat(kNat, orderMinus1)
|
||||
@ -469,6 +474,8 @@ func Decrypt(priv *EncryptPrivateKey, uid, ciphertext []byte, opts EncrypterOpts
|
||||
return nil, ErrDecryption
|
||||
}
|
||||
|
||||
_ = c3c2[sm3.Size] // bounds check elimination hint
|
||||
c3 := c3c2[:sm3.Size]
|
||||
c2 := c3c2[sm3.Size:]
|
||||
key1Len := opts.GetKeySize(c2)
|
||||
|
||||
@ -476,8 +483,8 @@ func Decrypt(priv *EncryptPrivateKey, uid, ciphertext []byte, opts EncrypterOpts
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return decrypt(c, key[:key1Len], key[key1Len:], c2, c3c2[:sm3.Size], opts)
|
||||
_ = key[key1Len] // bounds check elimination hint
|
||||
return decrypt(c, key[:key1Len], key[key1Len:], c2, c3, opts)
|
||||
}
|
||||
|
||||
func decrypt(cipher *bn256.G1, key1, key2, c2, c3 []byte, opts EncrypterOpts) ([]byte, error) {
|
||||
@ -532,6 +539,7 @@ func DecryptASN1(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_ = key[key1Len] // bounds check elimination hint
|
||||
return decrypt(c, key[:key1Len], key[key1Len:], c2Bytes, c3Bytes, opts)
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user