mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-26 20:26:19 +08:00
kdf: move Kdf() from sm3 to kdf
This commit is contained in:
parent
8f5dcb842e
commit
eedd5ebc2b
9
internal/subtle/constant_time.go
Normal file
9
internal/subtle/constant_time.go
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
package subtle
|
||||||
|
|
||||||
|
func ConstantTimeAllZero(bytes []byte) bool {
|
||||||
|
var b uint8
|
||||||
|
for _, v := range bytes {
|
||||||
|
b |= v
|
||||||
|
}
|
||||||
|
return b == 0
|
||||||
|
}
|
24
internal/subtle/constant_time_test.go
Normal file
24
internal/subtle/constant_time_test.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package subtle
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestConstantTimeAllZero(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
bytes []byte
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"all zero", args{[]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, true},
|
||||||
|
{"not all zero", args{[]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 1}}, false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := ConstantTimeAllZero(tt.args.bytes); got != tt.want {
|
||||||
|
t.Errorf("ConstantTimeAllZero() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
27
kdf/kdf.go
Normal file
27
kdf/kdf.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
package kdf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"hash"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3.
|
||||||
|
// ANSI-X9.63-KDF
|
||||||
|
func Kdf(md hash.Hash, z []byte, len int) []byte {
|
||||||
|
limit := uint64(len+md.Size()-1) / uint64(md.Size())
|
||||||
|
if limit >= uint64(1<<32)-1 {
|
||||||
|
panic("kdf: key length too long")
|
||||||
|
}
|
||||||
|
var countBytes [4]byte
|
||||||
|
var ct uint32 = 1
|
||||||
|
k := make([]byte, len+md.Size()-1)
|
||||||
|
for i := 0; i < int(limit); i++ {
|
||||||
|
binary.BigEndian.PutUint32(countBytes[:], ct)
|
||||||
|
md.Write(z)
|
||||||
|
md.Write(countBytes[:])
|
||||||
|
copy(k[i*md.Size():], md.Sum(nil))
|
||||||
|
ct++
|
||||||
|
md.Reset()
|
||||||
|
}
|
||||||
|
return k[:len]
|
||||||
|
}
|
64
kdf/kdf_test.go
Normal file
64
kdf/kdf_test.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
package kdf
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"hash"
|
||||||
|
"math/big"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/emmansun/gmsm/sm3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKdf(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
md hash.Hash
|
||||||
|
z []byte
|
||||||
|
len int
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []byte
|
||||||
|
}{
|
||||||
|
{"sm3 case 1", args{sm3.New(), []byte("emmansun"), 16}, []byte{112, 137, 147, 239, 19, 136, 160, 174, 66, 69, 161, 155, 182, 192, 37, 84}},
|
||||||
|
{"sm3 case 2", args{sm3.New(), []byte("emmansun"), 32}, []byte{112, 137, 147, 239, 19, 136, 160, 174, 66, 69, 161, 155, 182, 192, 37, 84, 198, 50, 99, 62, 53, 109, 219, 152, 155, 235, 128, 79, 218, 150, 207, 212}},
|
||||||
|
{"sm3 case 3", args{sm3.New(), []byte("emmansun"), 48}, []byte{112, 137, 147, 239, 19, 136, 160, 174, 66, 69, 161, 155, 182, 192, 37, 84, 198, 50, 99, 62, 53, 109, 219, 152, 155, 235, 128, 79, 218, 150, 207, 212, 126, 186, 79, 164, 96, 231, 178, 119, 188, 107, 76, 228, 208, 126, 212, 147}},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := Kdf(tt.args.md, tt.args.z, tt.args.len); !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("Kdf() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKdfOldCase(t *testing.T) {
|
||||||
|
x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16)
|
||||||
|
y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16)
|
||||||
|
|
||||||
|
expected := "006e30dae231b071dfad8aa379e90264491603"
|
||||||
|
|
||||||
|
result := Kdf(sm3.New(), append(x2.Bytes(), y2.Bytes()...), 19)
|
||||||
|
|
||||||
|
resultStr := hex.EncodeToString(result)
|
||||||
|
|
||||||
|
if expected != resultStr {
|
||||||
|
t.Fatalf("expected %s, real value %s", expected, resultStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldPanic(t *testing.T, f func()) {
|
||||||
|
t.Helper()
|
||||||
|
defer func() { _ = recover() }()
|
||||||
|
f()
|
||||||
|
t.Errorf("should have panicked")
|
||||||
|
}
|
||||||
|
|
||||||
|
// This case should be failed on 32bits system.
|
||||||
|
func TestKdfPanic(t *testing.T) {
|
||||||
|
shouldPanic(t, func() {
|
||||||
|
Kdf(sm3.New(), []byte("123456"), 1<<37)
|
||||||
|
})
|
||||||
|
}
|
@ -24,6 +24,7 @@ import (
|
|||||||
|
|
||||||
"github.com/emmansun/gmsm/internal/randutil"
|
"github.com/emmansun/gmsm/internal/randutil"
|
||||||
"github.com/emmansun/gmsm/internal/subtle"
|
"github.com/emmansun/gmsm/internal/subtle"
|
||||||
|
"github.com/emmansun/gmsm/kdf"
|
||||||
"github.com/emmansun/gmsm/sm2/sm2ec"
|
"github.com/emmansun/gmsm/sm2/sm2ec"
|
||||||
"github.com/emmansun/gmsm/sm3"
|
"github.com/emmansun/gmsm/sm3"
|
||||||
"golang.org/x/crypto/cryptobyte"
|
"golang.org/x/crypto/cryptobyte"
|
||||||
@ -335,7 +336,8 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter
|
|||||||
|
|
||||||
//A5, calculate t=KDF(x2||y2, klen)
|
//A5, calculate t=KDF(x2||y2, klen)
|
||||||
var kdfCount int = 0
|
var kdfCount int = 0
|
||||||
c2, success := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
c2 := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
||||||
|
success := subtle.ConstantTimeAllZero(c2)
|
||||||
if !success {
|
if !success {
|
||||||
kdfCount++
|
kdfCount++
|
||||||
if kdfCount > maxRetryLimit {
|
if kdfCount > maxRetryLimit {
|
||||||
@ -396,7 +398,8 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error
|
|||||||
curve := priv.Curve
|
curve := priv.Curve
|
||||||
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
|
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
|
||||||
msgLen := len(c2)
|
msgLen := len(c2)
|
||||||
msg, success := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
msg := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
||||||
|
success := subtle.ConstantTimeAllZero(c2)
|
||||||
if !success {
|
if !success {
|
||||||
return nil, errors.New("sm2: invalid cipher text")
|
return nil, errors.New("sm2: invalid cipher text")
|
||||||
}
|
}
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
|
||||||
|
"github.com/emmansun/gmsm/kdf"
|
||||||
"github.com/emmansun/gmsm/sm3"
|
"github.com/emmansun/gmsm/sm3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -184,11 +185,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
|
|||||||
buffer = append(buffer, ke.z...)
|
buffer = append(buffer, ke.z...)
|
||||||
buffer = append(buffer, ke.peerZ...)
|
buffer = append(buffer, ke.peerZ...)
|
||||||
}
|
}
|
||||||
key, ok := sm3.Kdf(buffer, ke.keyLength)
|
return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("sm2: internal error, kdf failed")
|
|
||||||
}
|
|
||||||
return key, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// avf is the associative value function.
|
// avf is the associative value function.
|
||||||
|
23
sm3/sm3.go
23
sm3/sm3.go
@ -209,26 +209,3 @@ func Sum(data []byte) [Size]byte {
|
|||||||
d.Write(data)
|
d.Write(data)
|
||||||
return d.checkSum()
|
return d.checkSum()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3.
|
|
||||||
func Kdf(z []byte, len int) ([]byte, bool) {
|
|
||||||
limit := (len + Size - 1) >> SizeBitSize
|
|
||||||
md := New()
|
|
||||||
var countBytes [4]byte
|
|
||||||
var ct uint32 = 1
|
|
||||||
k := make([]byte, len+Size-1)
|
|
||||||
for i := 0; i < limit; i++ {
|
|
||||||
binary.BigEndian.PutUint32(countBytes[:], ct)
|
|
||||||
md.Write(z)
|
|
||||||
md.Write(countBytes[:])
|
|
||||||
copy(k[i*Size:], md.Sum(nil))
|
|
||||||
ct++
|
|
||||||
md.Reset()
|
|
||||||
}
|
|
||||||
k = k[:len]
|
|
||||||
var b uint8
|
|
||||||
for _, v := range k {
|
|
||||||
b |= v
|
|
||||||
}
|
|
||||||
return k, int((uint32(b)-1)>>31) != 1
|
|
||||||
}
|
|
||||||
|
@ -5,11 +5,9 @@ import (
|
|||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding"
|
"encoding"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash"
|
"hash"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.org/x/sys/cpu"
|
"golang.org/x/sys/cpu"
|
||||||
@ -115,24 +113,6 @@ func TestBlockSize(t *testing.T) {
|
|||||||
fmt.Printf("ARM64 has sm3 %v, has sm4 %v, has aes %v\n", cpu.ARM64.HasSM3, cpu.ARM64.HasSM4, cpu.ARM64.HasAES)
|
fmt.Printf("ARM64 has sm3 %v, has sm4 %v, has aes %v\n", cpu.ARM64.HasSM3, cpu.ARM64.HasSM4, cpu.ARM64.HasAES)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_kdf(t *testing.T) {
|
|
||||||
x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16)
|
|
||||||
y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16)
|
|
||||||
|
|
||||||
expected := "006e30dae231b071dfad8aa379e90264491603"
|
|
||||||
|
|
||||||
result, success := Kdf(append(x2.Bytes(), y2.Bytes()...), 19)
|
|
||||||
if !success {
|
|
||||||
t.Fatalf("failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
resultStr := hex.EncodeToString(result)
|
|
||||||
|
|
||||||
if expected != resultStr {
|
|
||||||
t.Fatalf("expected %s, real value %s", expected, resultStr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var bench = New()
|
var bench = New()
|
||||||
var benchSH256 = sha256.New()
|
var benchSH256 = sha256.New()
|
||||||
var buf = make([]byte, 8192)
|
var buf = make([]byte, 8192)
|
||||||
|
13
sm9/sm9.go
13
sm9/sm9.go
@ -11,6 +11,7 @@ import (
|
|||||||
"math/big"
|
"math/big"
|
||||||
|
|
||||||
"github.com/emmansun/gmsm/internal/subtle"
|
"github.com/emmansun/gmsm/internal/subtle"
|
||||||
|
"github.com/emmansun/gmsm/kdf"
|
||||||
"github.com/emmansun/gmsm/sm3"
|
"github.com/emmansun/gmsm/sm3"
|
||||||
"github.com/emmansun/gmsm/sm9/bn256"
|
"github.com/emmansun/gmsm/sm9/bn256"
|
||||||
"golang.org/x/crypto/cryptobyte"
|
"golang.org/x/crypto/cryptobyte"
|
||||||
@ -226,7 +227,8 @@ func WrapKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte,
|
|||||||
buffer = append(buffer, w.Marshal()...)
|
buffer = append(buffer, w.Marshal()...)
|
||||||
buffer = append(buffer, uid...)
|
buffer = append(buffer, uid...)
|
||||||
|
|
||||||
key, ok = sm3.Kdf(buffer, kLen)
|
key = kdf.Kdf(sm3.New(), buffer, kLen)
|
||||||
|
ok = subtle.ConstantTimeAllZero(key)
|
||||||
if ok {
|
if ok {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -297,7 +299,8 @@ func UnwrapKey(priv *EncryptPrivateKey, uid []byte, cipher *bn256.G1, kLen int)
|
|||||||
buffer = append(buffer, w.Marshal()...)
|
buffer = append(buffer, w.Marshal()...)
|
||||||
buffer = append(buffer, uid...)
|
buffer = append(buffer, uid...)
|
||||||
|
|
||||||
key, ok := sm3.Kdf(buffer, kLen)
|
key := kdf.Kdf(sm3.New(), buffer, kLen)
|
||||||
|
ok := subtle.ConstantTimeAllZero(key)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("sm9: invalid cipher")
|
return nil, errors.New("sm9: invalid cipher")
|
||||||
}
|
}
|
||||||
@ -562,11 +565,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
|
|||||||
buffer = append(buffer, ke.g2.Marshal()...)
|
buffer = append(buffer, ke.g2.Marshal()...)
|
||||||
buffer = append(buffer, ke.g3.Marshal()...)
|
buffer = append(buffer, ke.g3.Marshal()...)
|
||||||
|
|
||||||
key, ok := sm3.Kdf(buffer, ke.keyLength)
|
return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("sm9: internal error, kdf failed")
|
|
||||||
}
|
|
||||||
return key, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*bn256.G1, []byte, error) {
|
func respondKeyExchange(ke *KeyExchange, hid byte, r *big.Int, rA *bn256.G1) (*bn256.G1, []byte, error) {
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/emmansun/gmsm/internal/subtle"
|
"github.com/emmansun/gmsm/internal/subtle"
|
||||||
|
"github.com/emmansun/gmsm/kdf"
|
||||||
"github.com/emmansun/gmsm/sm3"
|
"github.com/emmansun/gmsm/sm3"
|
||||||
"github.com/emmansun/gmsm/sm9/bn256"
|
"github.com/emmansun/gmsm/sm9/bn256"
|
||||||
)
|
)
|
||||||
@ -437,10 +438,8 @@ func TestWrapKeySM9Sample(t *testing.T) {
|
|||||||
buffer = append(buffer, w.Marshal()...)
|
buffer = append(buffer, w.Marshal()...)
|
||||||
buffer = append(buffer, uid...)
|
buffer = append(buffer, uid...)
|
||||||
|
|
||||||
key, ok := sm3.Kdf(buffer, 32)
|
key := kdf.Kdf(sm3.New(), buffer, 32)
|
||||||
if !ok {
|
|
||||||
t.Failed()
|
|
||||||
}
|
|
||||||
if hex.EncodeToString(key) != expectedKey {
|
if hex.EncodeToString(key) != expectedKey {
|
||||||
t.Errorf("expected %v, got %v\n", expectedKey, hex.EncodeToString(key))
|
t.Errorf("expected %v, got %v\n", expectedKey, hex.EncodeToString(key))
|
||||||
}
|
}
|
||||||
@ -501,10 +500,8 @@ func TestEncryptSM9Sample(t *testing.T) {
|
|||||||
buffer = append(buffer, w.Marshal()...)
|
buffer = append(buffer, w.Marshal()...)
|
||||||
buffer = append(buffer, uid...)
|
buffer = append(buffer, uid...)
|
||||||
|
|
||||||
key, ok := sm3.Kdf(buffer, len(plaintext)+32)
|
key := kdf.Kdf(sm3.New(), buffer, len(plaintext)+32)
|
||||||
if !ok {
|
|
||||||
t.Failed()
|
|
||||||
}
|
|
||||||
if hex.EncodeToString(key) != expectedKey {
|
if hex.EncodeToString(key) != expectedKey {
|
||||||
t.Errorf("not expected key")
|
t.Errorf("not expected key")
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user