kdf: refactoring, create one interface

This commit is contained in:
Sun Yimin 2024-05-17 08:40:27 +08:00 committed by GitHub
parent 7fb729f4a8
commit 9ef3fdc7d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 4 deletions

View File

@ -7,10 +7,19 @@ import (
"hash" "hash"
) )
// KdfInterface is the interface implemented by some specific Hash implementations.
type KdfInterface interface {
Kdf(z []byte, keyLen int) []byte
}
// Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3. // Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3.
// ANSI-X9.63-KDF // ANSI-X9.63-KDF
func Kdf(newHash func() hash.Hash, z []byte, keyLen int) []byte { func Kdf(newHash func() hash.Hash, z []byte, keyLen int) []byte {
baseMD := newHash() baseMD := newHash()
// If the hash implements KdfInterface, use the optimized Kdf method.
if kdfImpl, ok := baseMD.(KdfInterface); ok {
return kdfImpl.Kdf(z, keyLen)
}
limit := uint64(keyLen+baseMD.Size()-1) / uint64(baseMD.Size()) limit := uint64(keyLen+baseMD.Size()-1) / uint64(baseMD.Size())
if limit >= uint64(1<<32)-1 { if limit >= uint64(1<<32)-1 {
panic("kdf: key length too long") panic("kdf: key length too long")
@ -19,8 +28,7 @@ func Kdf(newHash func() hash.Hash, z []byte, keyLen int) []byte {
var ct uint32 = 1 var ct uint32 = 1
var k []byte var k []byte
marshaler, ok := baseMD.(encoding.BinaryMarshaler) if marshaler, ok := baseMD.(encoding.BinaryMarshaler); limit == 1 || len(z) < baseMD.BlockSize() || !ok {
if limit == 1 || len(z) < baseMD.BlockSize() || !ok {
for i := 0; i < int(limit); i++ { for i := 0; i < int(limit); i++ {
binary.BigEndian.PutUint32(countBytes[:], ct) binary.BigEndian.PutUint32(countBytes[:], ct)
baseMD.Write(z) baseMD.Write(z)

View File

@ -213,7 +213,7 @@ func Sum(data []byte) [Size]byte {
} }
// Kdf key derivation function using SM3, compliance with GB/T 32918.4-2016 5.4.3. // Kdf key derivation function using SM3, compliance with GB/T 32918.4-2016 5.4.3.
func Kdf(z []byte, keyLen int) []byte { func (baseMD *digest) Kdf(z []byte, keyLen int) []byte {
limit := uint64(keyLen+Size-1) / uint64(Size) limit := uint64(keyLen+Size-1) / uint64(Size)
if limit >= uint64(1<<32)-1 { if limit >= uint64(1<<32)-1 {
panic("sm3: key length too long") panic("sm3: key length too long")
@ -221,7 +221,6 @@ func Kdf(z []byte, keyLen int) []byte {
var countBytes [4]byte var countBytes [4]byte
var ct uint32 = 1 var ct uint32 = 1
k := make([]byte, keyLen) k := make([]byte, keyLen)
baseMD := new(digest)
baseMD.Reset() baseMD.Reset()
baseMD.Write(z) baseMD.Write(z)
for i := 0; i < int(limit); i++ { for i := 0; i < int(limit); i++ {
@ -234,3 +233,8 @@ func Kdf(z []byte, keyLen int) []byte {
} }
return k return k
} }
func Kdf(z []byte, keyLen int) []byte {
baseMD := new(digest)
return baseMD.Kdf(z, keyLen)
}