From eedd5ebc2b35d05caec8699b420928881c362f70 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Thu, 25 Aug 2022 11:48:41 +0800 Subject: [PATCH] kdf: move Kdf() from sm3 to kdf --- internal/subtle/constant_time.go | 9 ++++ internal/subtle/constant_time_test.go | 24 ++++++++++ kdf/kdf.go | 27 +++++++++++ kdf/kdf_test.go | 64 +++++++++++++++++++++++++++ sm2/sm2.go | 7 ++- sm2/sm2_keyexchange.go | 7 +-- sm3/sm3.go | 23 ---------- sm3/sm3_test.go | 20 --------- sm9/sm9.go | 13 +++--- sm9/sm9_test.go | 13 +++--- 10 files changed, 142 insertions(+), 65 deletions(-) create mode 100644 internal/subtle/constant_time.go create mode 100644 internal/subtle/constant_time_test.go create mode 100644 kdf/kdf.go create mode 100644 kdf/kdf_test.go diff --git a/internal/subtle/constant_time.go b/internal/subtle/constant_time.go new file mode 100644 index 0000000..f70ef38 --- /dev/null +++ b/internal/subtle/constant_time.go @@ -0,0 +1,9 @@ +package subtle + +func ConstantTimeAllZero(bytes []byte) bool { + var b uint8 + for _, v := range bytes { + b |= v + } + return b == 0 +} diff --git a/internal/subtle/constant_time_test.go b/internal/subtle/constant_time_test.go new file mode 100644 index 0000000..8134608 --- /dev/null +++ b/internal/subtle/constant_time_test.go @@ -0,0 +1,24 @@ +package subtle + +import "testing" + +func TestConstantTimeAllZero(t *testing.T) { + type args struct { + bytes []byte + } + tests := []struct { + name string + args args + want bool + }{ + {"all zero", args{[]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, true}, + {"not all zero", args{[]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 1}}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := ConstantTimeAllZero(tt.args.bytes); got != tt.want { + t.Errorf("ConstantTimeAllZero() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/kdf/kdf.go b/kdf/kdf.go new file mode 100644 index 0000000..7f81a42 --- /dev/null +++ b/kdf/kdf.go @@ -0,0 +1,27 @@ +package kdf + +import ( + "encoding/binary" + "hash" +) + +// Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3. +// ANSI-X9.63-KDF +func Kdf(md hash.Hash, z []byte, len int) []byte { + limit := uint64(len+md.Size()-1) / uint64(md.Size()) + if limit >= uint64(1<<32)-1 { + panic("kdf: key length too long") + } + var countBytes [4]byte + var ct uint32 = 1 + k := make([]byte, len+md.Size()-1) + for i := 0; i < int(limit); i++ { + binary.BigEndian.PutUint32(countBytes[:], ct) + md.Write(z) + md.Write(countBytes[:]) + copy(k[i*md.Size():], md.Sum(nil)) + ct++ + md.Reset() + } + return k[:len] +} diff --git a/kdf/kdf_test.go b/kdf/kdf_test.go new file mode 100644 index 0000000..8e09b52 --- /dev/null +++ b/kdf/kdf_test.go @@ -0,0 +1,64 @@ +package kdf + +import ( + "encoding/hex" + "hash" + "math/big" + "reflect" + "testing" + + "github.com/emmansun/gmsm/sm3" +) + +func TestKdf(t *testing.T) { + type args struct { + md hash.Hash + z []byte + len int + } + tests := []struct { + name string + args args + want []byte + }{ + {"sm3 case 1", args{sm3.New(), []byte("emmansun"), 16}, []byte{112, 137, 147, 239, 19, 136, 160, 174, 66, 69, 161, 155, 182, 192, 37, 84}}, + {"sm3 case 2", args{sm3.New(), []byte("emmansun"), 32}, []byte{112, 137, 147, 239, 19, 136, 160, 174, 66, 69, 161, 155, 182, 192, 37, 84, 198, 50, 99, 62, 53, 109, 219, 152, 155, 235, 128, 79, 218, 150, 207, 212}}, + {"sm3 case 3", args{sm3.New(), []byte("emmansun"), 48}, []byte{112, 137, 147, 239, 19, 136, 160, 174, 66, 69, 161, 155, 182, 192, 37, 84, 198, 50, 99, 62, 53, 109, 219, 152, 155, 235, 128, 79, 218, 150, 207, 212, 126, 186, 79, 164, 96, 231, 178, 119, 188, 107, 76, 228, 208, 126, 212, 147}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := Kdf(tt.args.md, tt.args.z, tt.args.len); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Kdf() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestKdfOldCase(t *testing.T) { + x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16) + y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16) + + expected := "006e30dae231b071dfad8aa379e90264491603" + + result := Kdf(sm3.New(), append(x2.Bytes(), y2.Bytes()...), 19) + + resultStr := hex.EncodeToString(result) + + if expected != resultStr { + t.Fatalf("expected %s, real value %s", expected, resultStr) + } +} + +func shouldPanic(t *testing.T, f func()) { + t.Helper() + defer func() { _ = recover() }() + f() + t.Errorf("should have panicked") +} + +// This case should be failed on 32bits system. +func TestKdfPanic(t *testing.T) { + shouldPanic(t, func() { + Kdf(sm3.New(), []byte("123456"), 1<<37) + }) +} diff --git a/sm2/sm2.go b/sm2/sm2.go index 47e5902..d5c6c44 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -24,6 +24,7 @@ import ( "github.com/emmansun/gmsm/internal/randutil" "github.com/emmansun/gmsm/internal/subtle" + "github.com/emmansun/gmsm/kdf" "github.com/emmansun/gmsm/sm2/sm2ec" "github.com/emmansun/gmsm/sm3" "golang.org/x/crypto/cryptobyte" @@ -335,7 +336,8 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter //A5, calculate t=KDF(x2||y2, klen) var kdfCount int = 0 - c2, success := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) + c2 := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) + success := subtle.ConstantTimeAllZero(c2) if !success { kdfCount++ if kdfCount > maxRetryLimit { @@ -396,7 +398,8 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error curve := priv.Curve x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes()) msgLen := len(c2) - msg, success := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) + msg := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) + success := subtle.ConstantTimeAllZero(c2) if !success { return nil, errors.New("sm2: invalid cipher text") } diff --git a/sm2/sm2_keyexchange.go b/sm2/sm2_keyexchange.go index ceef36d..1e9f815 100644 --- a/sm2/sm2_keyexchange.go +++ b/sm2/sm2_keyexchange.go @@ -7,6 +7,7 @@ import ( "io" "math/big" + "github.com/emmansun/gmsm/kdf" "github.com/emmansun/gmsm/sm3" ) @@ -184,11 +185,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) { buffer = append(buffer, ke.z...) buffer = append(buffer, ke.peerZ...) } - key, ok := sm3.Kdf(buffer, ke.keyLength) - if !ok { - return nil, errors.New("sm2: internal error, kdf failed") - } - return key, nil + return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil } // avf is the associative value function. diff --git a/sm3/sm3.go b/sm3/sm3.go index 091a327..828f006 100644 --- a/sm3/sm3.go +++ b/sm3/sm3.go @@ -209,26 +209,3 @@ func Sum(data []byte) [Size]byte { d.Write(data) return d.checkSum() } - -// Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3. -func Kdf(z []byte, len int) ([]byte, bool) { - limit := (len + Size - 1) >> SizeBitSize - md := New() - var countBytes [4]byte - var ct uint32 = 1 - k := make([]byte, len+Size-1) - for i := 0; i < limit; i++ { - binary.BigEndian.PutUint32(countBytes[:], ct) - md.Write(z) - md.Write(countBytes[:]) - copy(k[i*Size:], md.Sum(nil)) - ct++ - md.Reset() - } - k = k[:len] - var b uint8 - for _, v := range k { - b |= v - } - return k, int((uint32(b)-1)>>31) != 1 -} diff --git a/sm3/sm3_test.go b/sm3/sm3_test.go index 17c75df..77b96ba 100644 --- a/sm3/sm3_test.go +++ b/sm3/sm3_test.go @@ -5,11 +5,9 @@ import ( "crypto/sha256" "encoding" "encoding/base64" - "encoding/hex" "fmt" "hash" "io" - "math/big" "testing" "golang.org/x/sys/cpu" @@ -115,24 +113,6 @@ func TestBlockSize(t *testing.T) { fmt.Printf("ARM64 has sm3 %v, has sm4 %v, has aes %v\n", cpu.ARM64.HasSM3, cpu.ARM64.HasSM4, cpu.ARM64.HasAES) } -func Test_kdf(t *testing.T) { - x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16) - y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16) - - expected := "006e30dae231b071dfad8aa379e90264491603" - - result, success := Kdf(append(x2.Bytes(), y2.Bytes()...), 19) - if !success { - t.Fatalf("failed") - } - - resultStr := hex.EncodeToString(result) - - if expected != resultStr { - t.Fatalf("expected %s, real value %s", expected, resultStr) - } -} - var bench = New() var benchSH256 = sha256.New() var buf = make([]byte, 8192) diff --git a/sm9/sm9.go b/sm9/sm9.go index 535ea49..7a963f9 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -11,6 +11,7 @@ import ( "math/big" "github.com/emmansun/gmsm/internal/subtle" + "github.com/emmansun/gmsm/kdf" "github.com/emmansun/gmsm/sm3" "github.com/emmansun/gmsm/sm9/bn256" "golang.org/x/crypto/cryptobyte" @@ -226,7 +227,8 @@ func WrapKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, buffer = append(buffer, w.Marshal()...) buffer = append(buffer, uid...) - key, ok = sm3.Kdf(buffer, kLen) + key = kdf.Kdf(sm3.New(), buffer, kLen) + ok = subtle.ConstantTimeAllZero(key) if ok { break } @@ -297,7 +299,8 @@ func UnwrapKey(priv *EncryptPrivateKey, uid []byte, cipher *bn256.G1, kLen int) buffer = append(buffer, w.Marshal()...) buffer = append(buffer, uid...) - key, ok := sm3.Kdf(buffer, kLen) + key := kdf.Kdf(sm3.New(), buffer, kLen) + ok := subtle.ConstantTimeAllZero(key) if !ok { return nil, errors.New("sm9: invalid cipher") } @@ -562,11 +565,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) { buffer = append(buffer, ke.g2.Marshal()...) buffer = append(buffer, ke.g3.Marshal()...) - key, ok := sm3.Kdf(buffer, ke.keyLength) - if !ok { - return nil, errors.New("sm9: internal error, kdf failed") - } - return key, nil + return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil } func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*bn256.G1, []byte, error) { diff --git a/sm9/sm9_test.go b/sm9/sm9_test.go index c59f9cb..b8157ce 100644 --- a/sm9/sm9_test.go +++ b/sm9/sm9_test.go @@ -7,6 +7,7 @@ import ( "testing" "github.com/emmansun/gmsm/internal/subtle" + "github.com/emmansun/gmsm/kdf" "github.com/emmansun/gmsm/sm3" "github.com/emmansun/gmsm/sm9/bn256" ) @@ -437,10 +438,8 @@ func TestWrapKeySM9Sample(t *testing.T) { buffer = append(buffer, w.Marshal()...) buffer = append(buffer, uid...) - key, ok := sm3.Kdf(buffer, 32) - if !ok { - t.Failed() - } + key := kdf.Kdf(sm3.New(), buffer, 32) + if hex.EncodeToString(key) != expectedKey { t.Errorf("expected %v, got %v\n", expectedKey, hex.EncodeToString(key)) } @@ -501,10 +500,8 @@ func TestEncryptSM9Sample(t *testing.T) { buffer = append(buffer, w.Marshal()...) buffer = append(buffer, uid...) - key, ok := sm3.Kdf(buffer, len(plaintext)+32) - if !ok { - t.Failed() - } + key := kdf.Kdf(sm3.New(), buffer, len(plaintext)+32) + if hex.EncodeToString(key) != expectedKey { t.Errorf("not expected key") }