From 9ef3fdc7d5b3821e8872474c9c8284e5c8f8f1a5 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Fri, 17 May 2024 08:40:27 +0800 Subject: [PATCH] kdf: refactoring, create one interface --- kdf/kdf.go | 12 ++++++++++-- sm3/sm3.go | 8 ++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/kdf/kdf.go b/kdf/kdf.go index cb12023..a0e7a69 100644 --- a/kdf/kdf.go +++ b/kdf/kdf.go @@ -7,10 +7,19 @@ import ( "hash" ) +// KdfInterface is the interface implemented by some specific Hash implementations. +type KdfInterface interface { + Kdf(z []byte, keyLen int) []byte +} + // Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3. // ANSI-X9.63-KDF func Kdf(newHash func() hash.Hash, z []byte, keyLen int) []byte { baseMD := newHash() + // If the hash implements KdfInterface, use the optimized Kdf method. + if kdfImpl, ok := baseMD.(KdfInterface); ok { + return kdfImpl.Kdf(z, keyLen) + } limit := uint64(keyLen+baseMD.Size()-1) / uint64(baseMD.Size()) if limit >= uint64(1<<32)-1 { panic("kdf: key length too long") @@ -19,8 +28,7 @@ func Kdf(newHash func() hash.Hash, z []byte, keyLen int) []byte { var ct uint32 = 1 var k []byte - marshaler, ok := baseMD.(encoding.BinaryMarshaler) - if limit == 1 || len(z) < baseMD.BlockSize() || !ok { + if marshaler, ok := baseMD.(encoding.BinaryMarshaler); limit == 1 || len(z) < baseMD.BlockSize() || !ok { for i := 0; i < int(limit); i++ { binary.BigEndian.PutUint32(countBytes[:], ct) baseMD.Write(z) diff --git a/sm3/sm3.go b/sm3/sm3.go index 8924846..c014689 100644 --- a/sm3/sm3.go +++ b/sm3/sm3.go @@ -213,7 +213,7 @@ func Sum(data []byte) [Size]byte { } // Kdf key derivation function using SM3, compliance with GB/T 32918.4-2016 5.4.3. -func Kdf(z []byte, keyLen int) []byte { +func (baseMD *digest) Kdf(z []byte, keyLen int) []byte { limit := uint64(keyLen+Size-1) / uint64(Size) if limit >= uint64(1<<32)-1 { panic("sm3: key length too long") @@ -221,7 +221,6 @@ func Kdf(z []byte, keyLen int) []byte { var countBytes [4]byte var ct uint32 = 1 k := make([]byte, keyLen) - baseMD := new(digest) baseMD.Reset() baseMD.Write(z) for i := 0; i < int(limit); i++ { @@ -234,3 +233,8 @@ func Kdf(z []byte, keyLen int) []byte { } return k } + +func Kdf(z []byte, keyLen int) []byte { + baseMD := new(digest) + return baseMD.Kdf(z, keyLen) +}