kdf: share Z hash state #220

This commit is contained in:
Sun Yimin 2024-05-15 08:28:47 +08:00 committed by GitHub
parent 57318eaf5b
commit c99ad27ce1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 153 additions and 35 deletions

View File

@ -8,7 +8,6 @@ import (
"fmt" "fmt"
"math/big" "math/big"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/padding" "github.com/emmansun/gmsm/padding"
"github.com/emmansun/gmsm/pkcs" "github.com/emmansun/gmsm/pkcs"
"github.com/emmansun/gmsm/sm2" "github.com/emmansun/gmsm/sm2"
@ -59,7 +58,7 @@ func ParseSM2(password, data []byte) (*sm2.PrivateKey, *smx509.Certificate, erro
if !keys.EncryptedKey.Algorithm.Equal(oidSM4) && !keys.EncryptedKey.Algorithm.Equal(oidSM4CBC) { if !keys.EncryptedKey.Algorithm.Equal(oidSM4) && !keys.EncryptedKey.Algorithm.Equal(oidSM4CBC) {
return nil, nil, fmt.Errorf("cfca: unsupported algorithm <%v>", keys.EncryptedKey.Algorithm) return nil, nil, fmt.Errorf("cfca: unsupported algorithm <%v>", keys.EncryptedKey.Algorithm)
} }
ivkey := kdf.Kdf(sm3.New(), password, 32) ivkey := sm3.Kdf(password, 32)
marshalledIV, err := asn1.Marshal(ivkey[:16]) marshalledIV, err := asn1.Marshal(ivkey[:16])
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@ -91,7 +90,7 @@ func MarshalSM2(password []byte, key *sm2.PrivateKey, cert *smx509.Certificate)
if len(password) == 0 { if len(password) == 0 {
return nil, errors.New("cfca: invalid password") return nil, errors.New("cfca: invalid password")
} }
ivkey := kdf.Kdf(sm3.New(), password, 32) ivkey := sm3.Kdf(password, 32)
block, err := sm4.NewCipher(ivkey[16:]) block, err := sm4.NewCipher(ivkey[16:])
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -2,27 +2,48 @@
package kdf package kdf
import ( import (
"encoding"
"encoding/binary" "encoding/binary"
"hash" "hash"
) )
// Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3. // Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3.
// ANSI-X9.63-KDF // ANSI-X9.63-KDF
func Kdf(md hash.Hash, z []byte, len int) []byte { func Kdf(newHash func() hash.Hash, z []byte, keyLen int) []byte {
limit := uint64(len+md.Size()-1) / uint64(md.Size()) baseMD := newHash()
limit := uint64(keyLen+baseMD.Size()-1) / uint64(baseMD.Size())
if limit >= uint64(1<<32)-1 { if limit >= uint64(1<<32)-1 {
panic("kdf: key length too long") panic("kdf: key length too long")
} }
var countBytes [4]byte var countBytes [4]byte
var ct uint32 = 1 var ct uint32 = 1
var k []byte var k []byte
marshaler, ok := baseMD.(encoding.BinaryMarshaler)
if limit == 1 || len(z) < baseMD.BlockSize() || !ok {
for i := 0; i < int(limit); i++ { for i := 0; i < int(limit); i++ {
binary.BigEndian.PutUint32(countBytes[:], ct) binary.BigEndian.PutUint32(countBytes[:], ct)
md.Write(z) baseMD.Write(z)
baseMD.Write(countBytes[:])
k = baseMD.Sum(k)
ct++
baseMD.Reset()
}
} else {
baseMD.Write(z)
zstate, _ := marshaler.MarshalBinary()
for i := 0; i < int(limit); i++ {
md := newHash()
err := md.(encoding.BinaryUnmarshaler).UnmarshalBinary(zstate)
if err != nil {
panic(err)
}
binary.BigEndian.PutUint32(countBytes[:], ct)
md.Write(countBytes[:]) md.Write(countBytes[:])
k = md.Sum(k) k = md.Sum(k)
ct++ ct++
md.Reset()
} }
return k[:len] }
return k[:keyLen]
} }

View File

@ -11,6 +11,6 @@ import (
// This case should be failed on 32bits system. // This case should be failed on 32bits system.
func TestKdfPanic(t *testing.T) { func TestKdfPanic(t *testing.T) {
shouldPanic(t, func() { shouldPanic(t, func() {
Kdf(sm3.New(), []byte("123456"), 1<<37) Kdf(sm3.New, []byte("123456"), 1<<37)
}) })
} }

View File

@ -31,7 +31,7 @@ func TestKdf(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
wantBytes, _ := hex.DecodeString(tt.want) wantBytes, _ := hex.DecodeString(tt.want)
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if got := Kdf(tt.args.md, tt.args.z, tt.args.len); !reflect.DeepEqual(got, wantBytes) { if got := Kdf(sm3.New, tt.args.z, tt.args.len); !reflect.DeepEqual(got, wantBytes) {
t.Errorf("Kdf(%v) = %x, want %v", tt.name, got, tt.want) t.Errorf("Kdf(%v) = %x, want %v", tt.name, got, tt.want)
} }
}) })
@ -44,7 +44,7 @@ func TestKdfOldCase(t *testing.T) {
expected := "006e30dae231b071dfad8aa379e90264491603" expected := "006e30dae231b071dfad8aa379e90264491603"
result := Kdf(sm3.New(), append(x2.Bytes(), y2.Bytes()...), 19) result := Kdf(sm3.New, append(x2.Bytes(), y2.Bytes()...), 19)
resultStr := hex.EncodeToString(result) resultStr := hex.EncodeToString(result)
@ -71,16 +71,17 @@ func BenchmarkKdf(b *testing.B) {
{64, 32}, {64, 32},
{64, 64}, {64, 64},
{64, 128}, {64, 128},
{440, 32}, {64, 256},
{64, 512},
{64, 1024},
} }
sm3Hash := sm3.New()
z := make([]byte, 512) z := make([]byte, 512)
for _, tt := range tests { for _, tt := range tests {
b.Run(fmt.Sprintf("zLen=%v-kLen=%v", tt.zLen, tt.kLen), func(b *testing.B) { b.Run(fmt.Sprintf("zLen=%v-kLen=%v", tt.zLen, tt.kLen), func(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
Kdf(sm3Hash, z[:tt.zLen], tt.kLen) Kdf(sm3.New, z[:tt.zLen], tt.kLen)
} }
}) })
} }

View File

@ -25,7 +25,6 @@ import (
"github.com/emmansun/gmsm/internal/randutil" "github.com/emmansun/gmsm/internal/randutil"
_sm2ec "github.com/emmansun/gmsm/internal/sm2ec" _sm2ec "github.com/emmansun/gmsm/internal/sm2ec"
"github.com/emmansun/gmsm/internal/subtle" "github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm2/sm2ec" "github.com/emmansun/gmsm/sm2/sm2ec"
"github.com/emmansun/gmsm/sm3" "github.com/emmansun/gmsm/sm3"
"golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte"
@ -251,7 +250,7 @@ func encryptSM2EC(c *sm2Curve, pub *ecdsa.PublicKey, random io.Reader, msg []byt
return nil, err return nil, err
} }
C2Bytes := C2.Bytes()[1:] C2Bytes := C2.Bytes()[1:]
c2 := kdf.Kdf(sm3.New(), C2Bytes, len(msg)) c2 := sm3.Kdf(C2Bytes, len(msg))
if subtle.ConstantTimeAllZero(c2) { if subtle.ConstantTimeAllZero(c2) {
retryCount++ retryCount++
if retryCount > maxRetryLimit { if retryCount > maxRetryLimit {
@ -424,7 +423,7 @@ func decryptSM2EC(c *sm2Curve, priv *PrivateKey, ciphertext []byte, opts *Decryp
} }
C2Bytes := C2.Bytes()[1:] C2Bytes := C2.Bytes()[1:]
msgLen := len(c2) msgLen := len(c2)
msg := kdf.Kdf(sm3.New(), C2Bytes, msgLen) msg := sm3.Kdf(C2Bytes, msgLen)
if subtle.ConstantTimeAllZero(c2) { if subtle.ConstantTimeAllZero(c2) {
return nil, ErrDecryption return nil, ErrDecryption
} }

View File

@ -7,7 +7,6 @@ import (
"io" "io"
"math/big" "math/big"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm3" "github.com/emmansun/gmsm/sm3"
) )
@ -185,7 +184,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
buffer = append(buffer, ke.z...) buffer = append(buffer, ke.z...)
buffer = append(buffer, ke.peerZ...) buffer = append(buffer, ke.peerZ...)
} }
return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil return sm3.Kdf(buffer, ke.keyLength), nil
} }
// avf is the associative value function. // avf is the associative value function.

View File

@ -11,7 +11,6 @@ import (
"strings" "strings"
"github.com/emmansun/gmsm/internal/subtle" "github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm2/sm2ec" "github.com/emmansun/gmsm/sm2/sm2ec"
"github.com/emmansun/gmsm/sm3" "github.com/emmansun/gmsm/sm3"
"golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte"
@ -260,7 +259,7 @@ func encryptLegacy(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Enc
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes()) x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
//A5, calculate t=KDF(x2||y2, klen) //A5, calculate t=KDF(x2||y2, klen)
c2 := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) c2 := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
if subtle.ConstantTimeAllZero(c2) { if subtle.ConstantTimeAllZero(c2) {
retryCount++ retryCount++
if retryCount > maxRetryLimit { if retryCount > maxRetryLimit {
@ -408,7 +407,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error
curve := priv.Curve curve := priv.Curve
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes()) x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
msgLen := len(c2) msgLen := len(c2)
msg := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) msg := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
if subtle.ConstantTimeAllZero(c2) { if subtle.ConstantTimeAllZero(c2) {
return nil, ErrDecryption return nil, ErrDecryption
} }

View File

@ -851,3 +851,11 @@ func BenchmarkMoreThan32_P256(b *testing.B) {
func BenchmarkMoreThan32_SM2(b *testing.B) { func BenchmarkMoreThan32_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard") benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard")
} }
func BenchmarkEncrypt512_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption s")
}
func BenchmarkEncrypt1024_SM2(b *testing.B) {
benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption sencryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption s")
}

View File

@ -211,3 +211,26 @@ func Sum(data []byte) [Size]byte {
d.Write(data) d.Write(data)
return d.checkSum() return d.checkSum()
} }
// Kdf key derivation function using SM3, compliance with GB/T 32918.4-2016 5.4.3.
func Kdf(z []byte, keyLen int) []byte {
limit := uint64(keyLen+Size-1) / uint64(Size)
if limit >= uint64(1<<32)-1 {
panic("sm3: key length too long")
}
var countBytes [4]byte
var ct uint32 = 1
var k []byte
baseMD := new(digest)
baseMD.Reset()
baseMD.Write(z)
for i := 0; i < int(limit); i++ {
binary.BigEndian.PutUint32(countBytes[:], ct)
md := *baseMD
md.Write(countBytes[:])
h := md.checkSum()
k = append(k, h[:]...)
ct++
}
return k[:keyLen]
}

View File

@ -9,6 +9,8 @@ import (
"fmt" "fmt"
"hash" "hash"
"io" "io"
"math/big"
"reflect"
"testing" "testing"
"golang.org/x/sys/cpu" "golang.org/x/sys/cpu"
@ -403,6 +405,75 @@ func BenchmarkHash8K_SH256(b *testing.B) {
benchmarkSize(benchSH256, b, 8192) benchmarkSize(benchSH256, b, 8192)
} }
func TestKdf(t *testing.T) {
type args struct {
z []byte
len int
}
tests := []struct {
name string
args args
want string
}{
{"sm3 case 1", args{[]byte("emmansun"), 16}, "708993ef1388a0ae4245a19bb6c02554"},
{"sm3 case 2", args{[]byte("emmansun"), 32}, "708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd4"},
{"sm3 case 3", args{[]byte("emmansun"), 48}, "708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493"},
{"sm3 case 4", args{[]byte("708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493"), 48}, "49cf14649f324a07e0d5bb2a00f7f05d5f5bdd6d14dff028e071327ec031104590eddb18f98b763e18bf382ff7c3875f"},
{"sm3 case 5", args{[]byte("708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493"), 128}, "49cf14649f324a07e0d5bb2a00f7f05d5f5bdd6d14dff028e071327ec031104590eddb18f98b763e18bf382ff7c3875f30277f3179baebd795e7853fa643fdf280d8d7b81a2ab7829f615e132ab376d32194cd315908d27090e1180ce442d9be99322523db5bfac40ac5acb03550f5c93e5b01b1d71f2630868909a6a1250edb"},
}
for _, tt := range tests {
wantBytes, _ := hex.DecodeString(tt.want)
t.Run(tt.name, func(t *testing.T) {
if got := Kdf(tt.args.z, tt.args.len); !reflect.DeepEqual(got, wantBytes) {
t.Errorf("Kdf(%v) = %x, want %v", tt.name, got, tt.want)
}
})
}
}
func TestKdfOldCase(t *testing.T) {
x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16)
y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16)
expected := "006e30dae231b071dfad8aa379e90264491603"
result := Kdf(append(x2.Bytes(), y2.Bytes()...), 19)
resultStr := hex.EncodeToString(result)
if expected != resultStr {
t.Fatalf("expected %s, real value %s", expected, resultStr)
}
}
func BenchmarkKdfWithSM3(b *testing.B) {
tests := []struct {
zLen int
kLen int
}{
{32, 32},
{32, 64},
{32, 128},
{64, 32},
{64, 64},
{64, 128},
{64, 256},
{64, 512},
{64, 1024},
{64, 1024*8},
}
z := make([]byte, 512)
for _, tt := range tests {
b.Run(fmt.Sprintf("zLen=%v-kLen=%v", tt.zLen, tt.kLen), func(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
Kdf(z[:tt.zLen], tt.kLen)
}
})
}
}
/* /*
func round1(a, b, c, d, e, f, g, h string, i int) { func round1(a, b, c, d, e, f, g, h string, i int) {
fmt.Printf("//Round %d\n", i+1) fmt.Printf("//Round %d\n", i+1)

View File

@ -12,7 +12,6 @@ import (
"github.com/emmansun/gmsm/internal/bigmod" "github.com/emmansun/gmsm/internal/bigmod"
"github.com/emmansun/gmsm/internal/randutil" "github.com/emmansun/gmsm/internal/randutil"
"github.com/emmansun/gmsm/internal/subtle" "github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm3" "github.com/emmansun/gmsm/sm3"
"github.com/emmansun/gmsm/sm9/bn256" "github.com/emmansun/gmsm/sm9/bn256"
"golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte"
@ -317,7 +316,7 @@ func WrapKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte,
buffer = append(buffer, w.Marshal()...) buffer = append(buffer, w.Marshal()...)
buffer = append(buffer, uid...) buffer = append(buffer, uid...)
key = kdf.Kdf(sm3.New(), buffer, kLen) key = sm3.Kdf(buffer, kLen)
if !subtle.ConstantTimeAllZero(key) { if !subtle.ConstantTimeAllZero(key) {
break break
} }
@ -403,7 +402,7 @@ func UnwrapKey(priv *EncryptPrivateKey, uid []byte, cipher *bn256.G1, kLen int)
buffer = append(buffer, w.Marshal()...) buffer = append(buffer, w.Marshal()...)
buffer = append(buffer, uid...) buffer = append(buffer, uid...)
key := kdf.Kdf(sm3.New(), buffer, kLen) key := sm3.Kdf(buffer, kLen)
if subtle.ConstantTimeAllZero(key) { if subtle.ConstantTimeAllZero(key) {
return nil, ErrDecryption return nil, ErrDecryption
} }
@ -685,7 +684,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
buffer = append(buffer, ke.g2.Marshal()...) buffer = append(buffer, ke.g2.Marshal()...)
buffer = append(buffer, ke.g3.Marshal()...) buffer = append(buffer, ke.g3.Marshal()...)
return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil return sm3.Kdf(buffer, ke.keyLength), nil
} }
func respondKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat, rA *bn256.G1) (*bn256.G1, []byte, error) { func respondKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat, rA *bn256.G1) (*bn256.G1, []byte, error) {

View File

@ -8,7 +8,6 @@ import (
"github.com/emmansun/gmsm/internal/bigmod" "github.com/emmansun/gmsm/internal/bigmod"
"github.com/emmansun/gmsm/internal/subtle" "github.com/emmansun/gmsm/internal/subtle"
"github.com/emmansun/gmsm/kdf"
"github.com/emmansun/gmsm/sm3" "github.com/emmansun/gmsm/sm3"
"github.com/emmansun/gmsm/sm9/bn256" "github.com/emmansun/gmsm/sm9/bn256"
"golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte"
@ -563,7 +562,7 @@ func TestWrapKeySM9Sample(t *testing.T) {
buffer = append(buffer, w.Marshal()...) buffer = append(buffer, w.Marshal()...)
buffer = append(buffer, uid...) buffer = append(buffer, uid...)
key := kdf.Kdf(sm3.New(), buffer, 32) key := sm3.Kdf(buffer, 32)
if hex.EncodeToString(key) != expectedKey { if hex.EncodeToString(key) != expectedKey {
t.Errorf("expected %v, got %v\n", expectedKey, hex.EncodeToString(key)) t.Errorf("expected %v, got %v\n", expectedKey, hex.EncodeToString(key))
@ -629,7 +628,7 @@ func TestEncryptSM9Sample(t *testing.T) {
buffer = append(buffer, w.Marshal()...) buffer = append(buffer, w.Marshal()...)
buffer = append(buffer, uid...) buffer = append(buffer, uid...)
key := kdf.Kdf(sm3.New(), buffer, len(plaintext)+32) key := sm3.Kdf(buffer, len(plaintext)+32)
if hex.EncodeToString(key) != expectedKey { if hex.EncodeToString(key) != expectedKey {
t.Errorf("not expected key") t.Errorf("not expected key")
@ -697,7 +696,7 @@ func TestEncryptSM9SampleBlockMode(t *testing.T) {
buffer = append(buffer, w.Marshal()...) buffer = append(buffer, w.Marshal()...)
buffer = append(buffer, uid...) buffer = append(buffer, uid...)
key := kdf.Kdf(sm3.New(), buffer, 16+32) key := sm3.Kdf(buffer, 16+32)
if hex.EncodeToString(key) != expectedKey { if hex.EncodeToString(key) != expectedKey {
t.Errorf("not expected key, expected %v, got %x\n", expectedKey, key) t.Errorf("not expected key, expected %v, got %x\n", expectedKey, key)