mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-26 20:26:19 +08:00
kdf: share Z hash state #220
This commit is contained in:
parent
57318eaf5b
commit
c99ad27ce1
@ -8,7 +8,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
|
||||||
"github.com/emmansun/gmsm/kdf"
|
|
||||||
"github.com/emmansun/gmsm/padding"
|
"github.com/emmansun/gmsm/padding"
|
||||||
"github.com/emmansun/gmsm/pkcs"
|
"github.com/emmansun/gmsm/pkcs"
|
||||||
"github.com/emmansun/gmsm/sm2"
|
"github.com/emmansun/gmsm/sm2"
|
||||||
@ -59,7 +58,7 @@ func ParseSM2(password, data []byte) (*sm2.PrivateKey, *smx509.Certificate, erro
|
|||||||
if !keys.EncryptedKey.Algorithm.Equal(oidSM4) && !keys.EncryptedKey.Algorithm.Equal(oidSM4CBC) {
|
if !keys.EncryptedKey.Algorithm.Equal(oidSM4) && !keys.EncryptedKey.Algorithm.Equal(oidSM4CBC) {
|
||||||
return nil, nil, fmt.Errorf("cfca: unsupported algorithm <%v>", keys.EncryptedKey.Algorithm)
|
return nil, nil, fmt.Errorf("cfca: unsupported algorithm <%v>", keys.EncryptedKey.Algorithm)
|
||||||
}
|
}
|
||||||
ivkey := kdf.Kdf(sm3.New(), password, 32)
|
ivkey := sm3.Kdf(password, 32)
|
||||||
marshalledIV, err := asn1.Marshal(ivkey[:16])
|
marshalledIV, err := asn1.Marshal(ivkey[:16])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
@ -91,7 +90,7 @@ func MarshalSM2(password []byte, key *sm2.PrivateKey, cert *smx509.Certificate)
|
|||||||
if len(password) == 0 {
|
if len(password) == 0 {
|
||||||
return nil, errors.New("cfca: invalid password")
|
return nil, errors.New("cfca: invalid password")
|
||||||
}
|
}
|
||||||
ivkey := kdf.Kdf(sm3.New(), password, 32)
|
ivkey := sm3.Kdf(password, 32)
|
||||||
block, err := sm4.NewCipher(ivkey[16:])
|
block, err := sm4.NewCipher(ivkey[16:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
31
kdf/kdf.go
31
kdf/kdf.go
@ -2,27 +2,48 @@
|
|||||||
package kdf
|
package kdf
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"hash"
|
"hash"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3.
|
// Kdf key derivation function, compliance with GB/T 32918.4-2016 5.4.3.
|
||||||
// ANSI-X9.63-KDF
|
// ANSI-X9.63-KDF
|
||||||
func Kdf(md hash.Hash, z []byte, len int) []byte {
|
func Kdf(newHash func() hash.Hash, z []byte, keyLen int) []byte {
|
||||||
limit := uint64(len+md.Size()-1) / uint64(md.Size())
|
baseMD := newHash()
|
||||||
|
limit := uint64(keyLen+baseMD.Size()-1) / uint64(baseMD.Size())
|
||||||
if limit >= uint64(1<<32)-1 {
|
if limit >= uint64(1<<32)-1 {
|
||||||
panic("kdf: key length too long")
|
panic("kdf: key length too long")
|
||||||
}
|
}
|
||||||
var countBytes [4]byte
|
var countBytes [4]byte
|
||||||
var ct uint32 = 1
|
var ct uint32 = 1
|
||||||
var k []byte
|
var k []byte
|
||||||
|
|
||||||
|
marshaler, ok := baseMD.(encoding.BinaryMarshaler)
|
||||||
|
if limit == 1 || len(z) < baseMD.BlockSize() || !ok {
|
||||||
for i := 0; i < int(limit); i++ {
|
for i := 0; i < int(limit); i++ {
|
||||||
binary.BigEndian.PutUint32(countBytes[:], ct)
|
binary.BigEndian.PutUint32(countBytes[:], ct)
|
||||||
md.Write(z)
|
baseMD.Write(z)
|
||||||
|
baseMD.Write(countBytes[:])
|
||||||
|
k = baseMD.Sum(k)
|
||||||
|
ct++
|
||||||
|
baseMD.Reset()
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
baseMD.Write(z)
|
||||||
|
zstate, _ := marshaler.MarshalBinary()
|
||||||
|
for i := 0; i < int(limit); i++ {
|
||||||
|
md := newHash()
|
||||||
|
err := md.(encoding.BinaryUnmarshaler).UnmarshalBinary(zstate)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
binary.BigEndian.PutUint32(countBytes[:], ct)
|
||||||
md.Write(countBytes[:])
|
md.Write(countBytes[:])
|
||||||
k = md.Sum(k)
|
k = md.Sum(k)
|
||||||
ct++
|
ct++
|
||||||
md.Reset()
|
|
||||||
}
|
}
|
||||||
return k[:len]
|
}
|
||||||
|
|
||||||
|
return k[:keyLen]
|
||||||
}
|
}
|
||||||
|
@ -11,6 +11,6 @@ import (
|
|||||||
// This case should be failed on 32bits system.
|
// This case should be failed on 32bits system.
|
||||||
func TestKdfPanic(t *testing.T) {
|
func TestKdfPanic(t *testing.T) {
|
||||||
shouldPanic(t, func() {
|
shouldPanic(t, func() {
|
||||||
Kdf(sm3.New(), []byte("123456"), 1<<37)
|
Kdf(sm3.New, []byte("123456"), 1<<37)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -31,7 +31,7 @@ func TestKdf(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
wantBytes, _ := hex.DecodeString(tt.want)
|
wantBytes, _ := hex.DecodeString(tt.want)
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if got := Kdf(tt.args.md, tt.args.z, tt.args.len); !reflect.DeepEqual(got, wantBytes) {
|
if got := Kdf(sm3.New, tt.args.z, tt.args.len); !reflect.DeepEqual(got, wantBytes) {
|
||||||
t.Errorf("Kdf(%v) = %x, want %v", tt.name, got, tt.want)
|
t.Errorf("Kdf(%v) = %x, want %v", tt.name, got, tt.want)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
@ -44,7 +44,7 @@ func TestKdfOldCase(t *testing.T) {
|
|||||||
|
|
||||||
expected := "006e30dae231b071dfad8aa379e90264491603"
|
expected := "006e30dae231b071dfad8aa379e90264491603"
|
||||||
|
|
||||||
result := Kdf(sm3.New(), append(x2.Bytes(), y2.Bytes()...), 19)
|
result := Kdf(sm3.New, append(x2.Bytes(), y2.Bytes()...), 19)
|
||||||
|
|
||||||
resultStr := hex.EncodeToString(result)
|
resultStr := hex.EncodeToString(result)
|
||||||
|
|
||||||
@ -71,16 +71,17 @@ func BenchmarkKdf(b *testing.B) {
|
|||||||
{64, 32},
|
{64, 32},
|
||||||
{64, 64},
|
{64, 64},
|
||||||
{64, 128},
|
{64, 128},
|
||||||
{440, 32},
|
{64, 256},
|
||||||
|
{64, 512},
|
||||||
|
{64, 1024},
|
||||||
}
|
}
|
||||||
sm3Hash := sm3.New()
|
|
||||||
z := make([]byte, 512)
|
z := make([]byte, 512)
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
b.Run(fmt.Sprintf("zLen=%v-kLen=%v", tt.zLen, tt.kLen), func(b *testing.B) {
|
b.Run(fmt.Sprintf("zLen=%v-kLen=%v", tt.zLen, tt.kLen), func(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
Kdf(sm3Hash, z[:tt.zLen], tt.kLen)
|
Kdf(sm3.New, z[:tt.zLen], tt.kLen)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -25,7 +25,6 @@ import (
|
|||||||
"github.com/emmansun/gmsm/internal/randutil"
|
"github.com/emmansun/gmsm/internal/randutil"
|
||||||
_sm2ec "github.com/emmansun/gmsm/internal/sm2ec"
|
_sm2ec "github.com/emmansun/gmsm/internal/sm2ec"
|
||||||
"github.com/emmansun/gmsm/internal/subtle"
|
"github.com/emmansun/gmsm/internal/subtle"
|
||||||
"github.com/emmansun/gmsm/kdf"
|
|
||||||
"github.com/emmansun/gmsm/sm2/sm2ec"
|
"github.com/emmansun/gmsm/sm2/sm2ec"
|
||||||
"github.com/emmansun/gmsm/sm3"
|
"github.com/emmansun/gmsm/sm3"
|
||||||
"golang.org/x/crypto/cryptobyte"
|
"golang.org/x/crypto/cryptobyte"
|
||||||
@ -251,7 +250,7 @@ func encryptSM2EC(c *sm2Curve, pub *ecdsa.PublicKey, random io.Reader, msg []byt
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
C2Bytes := C2.Bytes()[1:]
|
C2Bytes := C2.Bytes()[1:]
|
||||||
c2 := kdf.Kdf(sm3.New(), C2Bytes, len(msg))
|
c2 := sm3.Kdf(C2Bytes, len(msg))
|
||||||
if subtle.ConstantTimeAllZero(c2) {
|
if subtle.ConstantTimeAllZero(c2) {
|
||||||
retryCount++
|
retryCount++
|
||||||
if retryCount > maxRetryLimit {
|
if retryCount > maxRetryLimit {
|
||||||
@ -424,7 +423,7 @@ func decryptSM2EC(c *sm2Curve, priv *PrivateKey, ciphertext []byte, opts *Decryp
|
|||||||
}
|
}
|
||||||
C2Bytes := C2.Bytes()[1:]
|
C2Bytes := C2.Bytes()[1:]
|
||||||
msgLen := len(c2)
|
msgLen := len(c2)
|
||||||
msg := kdf.Kdf(sm3.New(), C2Bytes, msgLen)
|
msg := sm3.Kdf(C2Bytes, msgLen)
|
||||||
if subtle.ConstantTimeAllZero(c2) {
|
if subtle.ConstantTimeAllZero(c2) {
|
||||||
return nil, ErrDecryption
|
return nil, ErrDecryption
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
|
|
||||||
"github.com/emmansun/gmsm/kdf"
|
|
||||||
"github.com/emmansun/gmsm/sm3"
|
"github.com/emmansun/gmsm/sm3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -185,7 +184,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
|
|||||||
buffer = append(buffer, ke.z...)
|
buffer = append(buffer, ke.z...)
|
||||||
buffer = append(buffer, ke.peerZ...)
|
buffer = append(buffer, ke.peerZ...)
|
||||||
}
|
}
|
||||||
return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil
|
return sm3.Kdf(buffer, ke.keyLength), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// avf is the associative value function.
|
// avf is the associative value function.
|
||||||
|
@ -11,7 +11,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/emmansun/gmsm/internal/subtle"
|
"github.com/emmansun/gmsm/internal/subtle"
|
||||||
"github.com/emmansun/gmsm/kdf"
|
|
||||||
"github.com/emmansun/gmsm/sm2/sm2ec"
|
"github.com/emmansun/gmsm/sm2/sm2ec"
|
||||||
"github.com/emmansun/gmsm/sm3"
|
"github.com/emmansun/gmsm/sm3"
|
||||||
"golang.org/x/crypto/cryptobyte"
|
"golang.org/x/crypto/cryptobyte"
|
||||||
@ -260,7 +259,7 @@ func encryptLegacy(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Enc
|
|||||||
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
|
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
|
||||||
|
|
||||||
//A5, calculate t=KDF(x2||y2, klen)
|
//A5, calculate t=KDF(x2||y2, klen)
|
||||||
c2 := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
c2 := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
||||||
if subtle.ConstantTimeAllZero(c2) {
|
if subtle.ConstantTimeAllZero(c2) {
|
||||||
retryCount++
|
retryCount++
|
||||||
if retryCount > maxRetryLimit {
|
if retryCount > maxRetryLimit {
|
||||||
@ -408,7 +407,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error
|
|||||||
curve := priv.Curve
|
curve := priv.Curve
|
||||||
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
|
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
|
||||||
msgLen := len(c2)
|
msgLen := len(c2)
|
||||||
msg := kdf.Kdf(sm3.New(), append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
msg := sm3.Kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
||||||
if subtle.ConstantTimeAllZero(c2) {
|
if subtle.ConstantTimeAllZero(c2) {
|
||||||
return nil, ErrDecryption
|
return nil, ErrDecryption
|
||||||
}
|
}
|
||||||
|
@ -851,3 +851,11 @@ func BenchmarkMoreThan32_P256(b *testing.B) {
|
|||||||
func BenchmarkMoreThan32_SM2(b *testing.B) {
|
func BenchmarkMoreThan32_SM2(b *testing.B) {
|
||||||
benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard")
|
benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncrypt512_SM2(b *testing.B) {
|
||||||
|
benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption s")
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkEncrypt1024_SM2(b *testing.B) {
|
||||||
|
benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption sencryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption s")
|
||||||
|
}
|
||||||
|
23
sm3/sm3.go
23
sm3/sm3.go
@ -211,3 +211,26 @@ func Sum(data []byte) [Size]byte {
|
|||||||
d.Write(data)
|
d.Write(data)
|
||||||
return d.checkSum()
|
return d.checkSum()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Kdf key derivation function using SM3, compliance with GB/T 32918.4-2016 5.4.3.
|
||||||
|
func Kdf(z []byte, keyLen int) []byte {
|
||||||
|
limit := uint64(keyLen+Size-1) / uint64(Size)
|
||||||
|
if limit >= uint64(1<<32)-1 {
|
||||||
|
panic("sm3: key length too long")
|
||||||
|
}
|
||||||
|
var countBytes [4]byte
|
||||||
|
var ct uint32 = 1
|
||||||
|
var k []byte
|
||||||
|
baseMD := new(digest)
|
||||||
|
baseMD.Reset()
|
||||||
|
baseMD.Write(z)
|
||||||
|
for i := 0; i < int(limit); i++ {
|
||||||
|
binary.BigEndian.PutUint32(countBytes[:], ct)
|
||||||
|
md := *baseMD
|
||||||
|
md.Write(countBytes[:])
|
||||||
|
h := md.checkSum()
|
||||||
|
k = append(k, h[:]...)
|
||||||
|
ct++
|
||||||
|
}
|
||||||
|
return k[:keyLen]
|
||||||
|
}
|
||||||
|
@ -9,6 +9,8 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"hash"
|
"hash"
|
||||||
"io"
|
"io"
|
||||||
|
"math/big"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.org/x/sys/cpu"
|
"golang.org/x/sys/cpu"
|
||||||
@ -403,6 +405,75 @@ func BenchmarkHash8K_SH256(b *testing.B) {
|
|||||||
benchmarkSize(benchSH256, b, 8192)
|
benchmarkSize(benchSH256, b, 8192)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestKdf(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
z []byte
|
||||||
|
len int
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{"sm3 case 1", args{[]byte("emmansun"), 16}, "708993ef1388a0ae4245a19bb6c02554"},
|
||||||
|
{"sm3 case 2", args{[]byte("emmansun"), 32}, "708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd4"},
|
||||||
|
{"sm3 case 3", args{[]byte("emmansun"), 48}, "708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493"},
|
||||||
|
{"sm3 case 4", args{[]byte("708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493"), 48}, "49cf14649f324a07e0d5bb2a00f7f05d5f5bdd6d14dff028e071327ec031104590eddb18f98b763e18bf382ff7c3875f"},
|
||||||
|
{"sm3 case 5", args{[]byte("708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493708993ef1388a0ae4245a19bb6c02554c632633e356ddb989beb804fda96cfd47eba4fa460e7b277bc6b4ce4d07ed493"), 128}, "49cf14649f324a07e0d5bb2a00f7f05d5f5bdd6d14dff028e071327ec031104590eddb18f98b763e18bf382ff7c3875f30277f3179baebd795e7853fa643fdf280d8d7b81a2ab7829f615e132ab376d32194cd315908d27090e1180ce442d9be99322523db5bfac40ac5acb03550f5c93e5b01b1d71f2630868909a6a1250edb"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
wantBytes, _ := hex.DecodeString(tt.want)
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := Kdf(tt.args.z, tt.args.len); !reflect.DeepEqual(got, wantBytes) {
|
||||||
|
t.Errorf("Kdf(%v) = %x, want %v", tt.name, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKdfOldCase(t *testing.T) {
|
||||||
|
x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16)
|
||||||
|
y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16)
|
||||||
|
|
||||||
|
expected := "006e30dae231b071dfad8aa379e90264491603"
|
||||||
|
|
||||||
|
result := Kdf(append(x2.Bytes(), y2.Bytes()...), 19)
|
||||||
|
|
||||||
|
resultStr := hex.EncodeToString(result)
|
||||||
|
|
||||||
|
if expected != resultStr {
|
||||||
|
t.Fatalf("expected %s, real value %s", expected, resultStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkKdfWithSM3(b *testing.B) {
|
||||||
|
tests := []struct {
|
||||||
|
zLen int
|
||||||
|
kLen int
|
||||||
|
}{
|
||||||
|
{32, 32},
|
||||||
|
{32, 64},
|
||||||
|
{32, 128},
|
||||||
|
{64, 32},
|
||||||
|
{64, 64},
|
||||||
|
{64, 128},
|
||||||
|
{64, 256},
|
||||||
|
{64, 512},
|
||||||
|
{64, 1024},
|
||||||
|
{64, 1024*8},
|
||||||
|
}
|
||||||
|
z := make([]byte, 512)
|
||||||
|
for _, tt := range tests {
|
||||||
|
b.Run(fmt.Sprintf("zLen=%v-kLen=%v", tt.zLen, tt.kLen), func(b *testing.B) {
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
Kdf(z[:tt.zLen], tt.kLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/*
|
/*
|
||||||
func round1(a, b, c, d, e, f, g, h string, i int) {
|
func round1(a, b, c, d, e, f, g, h string, i int) {
|
||||||
fmt.Printf("//Round %d\n", i+1)
|
fmt.Printf("//Round %d\n", i+1)
|
||||||
|
@ -12,7 +12,6 @@ import (
|
|||||||
"github.com/emmansun/gmsm/internal/bigmod"
|
"github.com/emmansun/gmsm/internal/bigmod"
|
||||||
"github.com/emmansun/gmsm/internal/randutil"
|
"github.com/emmansun/gmsm/internal/randutil"
|
||||||
"github.com/emmansun/gmsm/internal/subtle"
|
"github.com/emmansun/gmsm/internal/subtle"
|
||||||
"github.com/emmansun/gmsm/kdf"
|
|
||||||
"github.com/emmansun/gmsm/sm3"
|
"github.com/emmansun/gmsm/sm3"
|
||||||
"github.com/emmansun/gmsm/sm9/bn256"
|
"github.com/emmansun/gmsm/sm9/bn256"
|
||||||
"golang.org/x/crypto/cryptobyte"
|
"golang.org/x/crypto/cryptobyte"
|
||||||
@ -317,7 +316,7 @@ func WrapKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte,
|
|||||||
buffer = append(buffer, w.Marshal()...)
|
buffer = append(buffer, w.Marshal()...)
|
||||||
buffer = append(buffer, uid...)
|
buffer = append(buffer, uid...)
|
||||||
|
|
||||||
key = kdf.Kdf(sm3.New(), buffer, kLen)
|
key = sm3.Kdf(buffer, kLen)
|
||||||
if !subtle.ConstantTimeAllZero(key) {
|
if !subtle.ConstantTimeAllZero(key) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@ -403,7 +402,7 @@ func UnwrapKey(priv *EncryptPrivateKey, uid []byte, cipher *bn256.G1, kLen int)
|
|||||||
buffer = append(buffer, w.Marshal()...)
|
buffer = append(buffer, w.Marshal()...)
|
||||||
buffer = append(buffer, uid...)
|
buffer = append(buffer, uid...)
|
||||||
|
|
||||||
key := kdf.Kdf(sm3.New(), buffer, kLen)
|
key := sm3.Kdf(buffer, kLen)
|
||||||
if subtle.ConstantTimeAllZero(key) {
|
if subtle.ConstantTimeAllZero(key) {
|
||||||
return nil, ErrDecryption
|
return nil, ErrDecryption
|
||||||
}
|
}
|
||||||
@ -685,7 +684,7 @@ func (ke *KeyExchange) generateSharedKey(isResponder bool) ([]byte, error) {
|
|||||||
buffer = append(buffer, ke.g2.Marshal()...)
|
buffer = append(buffer, ke.g2.Marshal()...)
|
||||||
buffer = append(buffer, ke.g3.Marshal()...)
|
buffer = append(buffer, ke.g3.Marshal()...)
|
||||||
|
|
||||||
return kdf.Kdf(sm3.New(), buffer, ke.keyLength), nil
|
return sm3.Kdf(buffer, ke.keyLength), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func respondKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat, rA *bn256.G1) (*bn256.G1, []byte, error) {
|
func respondKeyExchange(ke *KeyExchange, hid byte, r *bigmod.Nat, rA *bn256.G1) (*bn256.G1, []byte, error) {
|
||||||
|
@ -8,7 +8,6 @@ import (
|
|||||||
|
|
||||||
"github.com/emmansun/gmsm/internal/bigmod"
|
"github.com/emmansun/gmsm/internal/bigmod"
|
||||||
"github.com/emmansun/gmsm/internal/subtle"
|
"github.com/emmansun/gmsm/internal/subtle"
|
||||||
"github.com/emmansun/gmsm/kdf"
|
|
||||||
"github.com/emmansun/gmsm/sm3"
|
"github.com/emmansun/gmsm/sm3"
|
||||||
"github.com/emmansun/gmsm/sm9/bn256"
|
"github.com/emmansun/gmsm/sm9/bn256"
|
||||||
"golang.org/x/crypto/cryptobyte"
|
"golang.org/x/crypto/cryptobyte"
|
||||||
@ -563,7 +562,7 @@ func TestWrapKeySM9Sample(t *testing.T) {
|
|||||||
buffer = append(buffer, w.Marshal()...)
|
buffer = append(buffer, w.Marshal()...)
|
||||||
buffer = append(buffer, uid...)
|
buffer = append(buffer, uid...)
|
||||||
|
|
||||||
key := kdf.Kdf(sm3.New(), buffer, 32)
|
key := sm3.Kdf(buffer, 32)
|
||||||
|
|
||||||
if hex.EncodeToString(key) != expectedKey {
|
if hex.EncodeToString(key) != expectedKey {
|
||||||
t.Errorf("expected %v, got %v\n", expectedKey, hex.EncodeToString(key))
|
t.Errorf("expected %v, got %v\n", expectedKey, hex.EncodeToString(key))
|
||||||
@ -629,7 +628,7 @@ func TestEncryptSM9Sample(t *testing.T) {
|
|||||||
buffer = append(buffer, w.Marshal()...)
|
buffer = append(buffer, w.Marshal()...)
|
||||||
buffer = append(buffer, uid...)
|
buffer = append(buffer, uid...)
|
||||||
|
|
||||||
key := kdf.Kdf(sm3.New(), buffer, len(plaintext)+32)
|
key := sm3.Kdf(buffer, len(plaintext)+32)
|
||||||
|
|
||||||
if hex.EncodeToString(key) != expectedKey {
|
if hex.EncodeToString(key) != expectedKey {
|
||||||
t.Errorf("not expected key")
|
t.Errorf("not expected key")
|
||||||
@ -697,7 +696,7 @@ func TestEncryptSM9SampleBlockMode(t *testing.T) {
|
|||||||
buffer = append(buffer, w.Marshal()...)
|
buffer = append(buffer, w.Marshal()...)
|
||||||
buffer = append(buffer, uid...)
|
buffer = append(buffer, uid...)
|
||||||
|
|
||||||
key := kdf.Kdf(sm3.New(), buffer, 16+32)
|
key := sm3.Kdf(buffer, 16+32)
|
||||||
|
|
||||||
if hex.EncodeToString(key) != expectedKey {
|
if hex.EncodeToString(key) != expectedKey {
|
||||||
t.Errorf("not expected key, expected %v, got %x\n", expectedKey, key)
|
t.Errorf("not expected key, expected %v, got %x\n", expectedKey, key)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user