From ef51a679a74bc50cd1101752eaf36e6140357fca Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Fri, 10 Jun 2022 10:29:12 +0800 Subject: [PATCH] extract kdf to sm3 --- sm2/sm2.go | 40 ++++++---------------------------------- sm2/sm2_test.go | 20 -------------------- sm3/sm3.go | 23 +++++++++++++++++++++++ sm3/sm3_test.go | 20 ++++++++++++++++++++ 4 files changed, 49 insertions(+), 54 deletions(-) diff --git a/sm2/sm2.go b/sm2/sm2.go index c972426..3ecf77e 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -16,7 +16,6 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/sha512" - "encoding/binary" "errors" "fmt" "io" @@ -25,6 +24,7 @@ import ( "sync" "github.com/emmansun/gmsm/internal/randutil" + "github.com/emmansun/gmsm/internal/xor" "github.com/emmansun/gmsm/sm3" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" @@ -287,29 +287,6 @@ func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) const maxRetryLimit = 100 -// kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3. -func kdf(z []byte, len int) ([]byte, bool) { - limit := (len + sm3.Size - 1) >> sm3.SizeBitSize - md := sm3.New() - var countBytes [4]byte - var ct uint32 = 1 - k := make([]byte, len+sm3.Size-1) - for i := 0; i < limit; i++ { - binary.BigEndian.PutUint32(countBytes[:], ct) - md.Write(z) - md.Write(countBytes[:]) - copy(k[i*sm3.Size:], md.Sum(nil)) - ct++ - md.Reset() - } - for i := 0; i < len; i++ { - if k[i] != 0 { - return k[:len], true - } - } - return k, false -} - func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte { md := sm3.New() md.Write(toBytes(curve, x2)) @@ -364,7 +341,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter //A5, calculate t=KDF(x2||y2, klen) var kdfCount int = 0 - t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) + c2, success := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) if !success { kdfCount++ if kdfCount > maxRetryLimit { @@ -374,10 +351,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter } //A6, C2 = M + t; - c2 := make([]byte, msgLen) - for i := 0; i < msgLen; i++ { - c2[i] = msg[i] ^ t[i] - } + xor.XorBytes(c2, msg, c2) //A7, C3 = hash(x2||M||y2) c3 := calculateC3(curve, x2, y2, msg) @@ -428,16 +402,14 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error curve := priv.Curve x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes()) msgLen := len(c2) - t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) + msg, success := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) if !success { return nil, errors.New("sm2: invalid cipher text") } //B5, calculate msg = c2 ^ t - msg := make([]byte, msgLen) - for i := 0; i < msgLen; i++ { - msg[i] = c2[i] ^ t[i] - } + xor.XorBytes(msg, c2, msg) + u := calculateC3(curve, x2, y2, msg) for i := 0; i < sm3.Size; i++ { if c3[i] != u[i] { diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go index 3dfa8ed..891b60d 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_test.go @@ -5,32 +5,12 @@ import ( "crypto/ecdsa" "crypto/elliptic" "crypto/rand" - "encoding/hex" - "math/big" "reflect" "testing" "github.com/emmansun/gmsm/sm3" ) -func Test_kdf(t *testing.T) { - x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16) - y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16) - - expected := "006e30dae231b071dfad8aa379e90264491603" - - result, success := kdf(append(x2.Bytes(), y2.Bytes()...), 19) - if !success { - t.Fatalf("failed") - } - - resultStr := hex.EncodeToString(result) - - if expected != resultStr { - t.Fatalf("expected %s, real value %s", expected, resultStr) - } -} - func Test_SplicingOrder(t *testing.T) { priv, _ := GenerateKey(rand.Reader) tests := []struct { diff --git a/sm3/sm3.go b/sm3/sm3.go index 42b4493..c1bb8bb 100644 --- a/sm3/sm3.go +++ b/sm3/sm3.go @@ -209,3 +209,26 @@ func Sum(data []byte) [Size]byte { d.Write(data) return d.checkSum() } + +// Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3. +func Kdf(z []byte, len int) ([]byte, bool) { + limit := (len + Size - 1) >> SizeBitSize + md := New() + var countBytes [4]byte + var ct uint32 = 1 + k := make([]byte, len+Size-1) + for i := 0; i < limit; i++ { + binary.BigEndian.PutUint32(countBytes[:], ct) + md.Write(z) + md.Write(countBytes[:]) + copy(k[i*Size:], md.Sum(nil)) + ct++ + md.Reset() + } + for i := 0; i < len; i++ { + if k[i] != 0 { + return k[:len], true + } + } + return k, false +} diff --git a/sm3/sm3_test.go b/sm3/sm3_test.go index 77b96ba..17c75df 100644 --- a/sm3/sm3_test.go +++ b/sm3/sm3_test.go @@ -5,9 +5,11 @@ import ( "crypto/sha256" "encoding" "encoding/base64" + "encoding/hex" "fmt" "hash" "io" + "math/big" "testing" "golang.org/x/sys/cpu" @@ -113,6 +115,24 @@ func TestBlockSize(t *testing.T) { fmt.Printf("ARM64 has sm3 %v, has sm4 %v, has aes %v\n", cpu.ARM64.HasSM3, cpu.ARM64.HasSM4, cpu.ARM64.HasAES) } +func Test_kdf(t *testing.T) { + x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16) + y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16) + + expected := "006e30dae231b071dfad8aa379e90264491603" + + result, success := Kdf(append(x2.Bytes(), y2.Bytes()...), 19) + if !success { + t.Fatalf("failed") + } + + resultStr := hex.EncodeToString(result) + + if expected != resultStr { + t.Fatalf("expected %s, real value %s", expected, resultStr) + } +} + var bench = New() var benchSH256 = sha256.New() var buf = make([]byte, 8192)