mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-26 12:16:20 +08:00
SM2: add both C1C2C3 & C1C3C2 cipher text splicing order
This commit is contained in:
parent
d5e7461d58
commit
6e3f8e5d1c
4
go.mod
4
go.mod
@ -3,6 +3,6 @@ module github.com/emmansun/gmsm
|
||||
go 1.14
|
||||
|
||||
require (
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
|
||||
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b
|
||||
golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871
|
||||
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881
|
||||
)
|
||||
|
103
sm2/sm2.go
103
sm2/sm2.go
@ -55,9 +55,30 @@ const (
|
||||
MarshalMixed
|
||||
)
|
||||
|
||||
type cipherTextSplicingOrder byte
|
||||
|
||||
const (
|
||||
C1C3C2 cipherTextSplicingOrder = iota
|
||||
C1C2C3
|
||||
)
|
||||
|
||||
// EncrypterOpts encryption options
|
||||
type EncrypterOpts struct {
|
||||
PointMarshalMode pointMarshalMode
|
||||
PointMarshalMode pointMarshalMode
|
||||
CipherTextSplicingOrder cipherTextSplicingOrder
|
||||
}
|
||||
|
||||
// DecrypterOpts decryption options
|
||||
type DecrypterOpts struct {
|
||||
CipherTextSplicingOrder cipherTextSplicingOrder
|
||||
}
|
||||
|
||||
func NewEncrypterOpts(marhsalMode pointMarshalMode, splicingOrder cipherTextSplicingOrder) *EncrypterOpts {
|
||||
return &EncrypterOpts{marhsalMode, splicingOrder}
|
||||
}
|
||||
|
||||
func NewDecrypterOpts(splicingOrder cipherTextSplicingOrder) *DecrypterOpts {
|
||||
return &DecrypterOpts{splicingOrder}
|
||||
}
|
||||
|
||||
func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte {
|
||||
@ -71,7 +92,7 @@ func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte
|
||||
}
|
||||
}
|
||||
|
||||
var defaultEncrypterOpts = EncrypterOpts{MarshalUncompressed}
|
||||
var defaultEncrypterOpts = &EncrypterOpts{MarshalUncompressed, C1C3C2}
|
||||
|
||||
// directSigning is a standard Hash value that signals that no pre-hashing
|
||||
// should be performed.
|
||||
@ -148,7 +169,9 @@ func (priv *PrivateKey) SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, er
|
||||
// Decrypt decrypts msg. The opts argument should be appropriate for
|
||||
// the primitive used.
|
||||
func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
|
||||
return Decrypt(priv, msg)
|
||||
var sm2Opts *DecrypterOpts
|
||||
sm2Opts, _ = opts.(*DecrypterOpts)
|
||||
return decrypt(priv, msg, sm2Opts)
|
||||
}
|
||||
|
||||
var (
|
||||
@ -222,7 +245,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter
|
||||
return nil, nil
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &defaultEncrypterOpts
|
||||
opts = defaultEncrypterOpts
|
||||
}
|
||||
//A3, requirement is to check if h*P is infinite point, h is 1
|
||||
if pub.X.Sign() == 0 && pub.Y.Sign() == 0 {
|
||||
@ -262,8 +285,12 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter
|
||||
//A7, C3 = hash(x2||M||y2)
|
||||
c3 := calculateC3(curve, x2, y2, msg)
|
||||
|
||||
// c1 || c3 || c2
|
||||
return append(append(c1, c3...), c2...), nil
|
||||
if opts.CipherTextSplicingOrder == C1C3C2 {
|
||||
// c1 || c3 || c2
|
||||
return append(append(c1, c3...), c2...), nil
|
||||
}
|
||||
// c1 || c2 || c3
|
||||
return append(append(c1, c2...), c3...), nil
|
||||
}
|
||||
}
|
||||
|
||||
@ -282,8 +309,16 @@ func GenerateKey(rand io.Reader) (*PrivateKey, error) {
|
||||
return priv, nil
|
||||
}
|
||||
|
||||
// Decrypt sm2 decrypt implementation
|
||||
// Decrypt sm2 decrypt implementation by default DecrypterOpts{C1C3C2}
|
||||
func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
||||
return decrypt(priv, ciphertext, nil)
|
||||
}
|
||||
|
||||
func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
|
||||
splicingOrder := C1C3C2
|
||||
if opts != nil {
|
||||
splicingOrder = opts.CipherTextSplicingOrder
|
||||
}
|
||||
ciphertextLen := len(ciphertext)
|
||||
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
|
||||
return nil, errors.New("SM2: invalid ciphertext length")
|
||||
@ -300,7 +335,12 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
||||
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
|
||||
|
||||
//B4, calculate t=KDF(x2||y2, klen)
|
||||
c2 := ciphertext[c3Start+sm3.Size:]
|
||||
var c2, c3 []byte
|
||||
if splicingOrder == C1C3C2 {
|
||||
c2 = ciphertext[c3Start+sm3.Size:]
|
||||
} else {
|
||||
c2 = ciphertext[c3Start : ciphertextLen-sm3.Size]
|
||||
}
|
||||
msgLen := len(c2)
|
||||
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
||||
if !success {
|
||||
@ -314,7 +354,11 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
//B6, calculate hash and compare it
|
||||
c3 := ciphertext[c3Start : c3Start+sm3.Size]
|
||||
if splicingOrder == C1C3C2 {
|
||||
c3 = ciphertext[c3Start : c3Start+sm3.Size]
|
||||
} else {
|
||||
c3 = ciphertext[ciphertextLen-sm3.Size:]
|
||||
}
|
||||
u := calculateC3(curve, x2, y2, msg)
|
||||
for i := 0; i < sm3.Size; i++ {
|
||||
if c3[i] != u[i] {
|
||||
@ -325,6 +369,47 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func AdjustCipherTextSplicingOrder(pub *ecdsa.PublicKey, ciphertext []byte, from, to cipherTextSplicingOrder) ([]byte, error) {
|
||||
if from == to {
|
||||
return ciphertext, nil
|
||||
}
|
||||
ciphertextLen := len(ciphertext)
|
||||
if ciphertextLen <= 1+(pub.Params().BitSize/8)+sm3.Size {
|
||||
return nil, errors.New("SM2: invalid ciphertext length")
|
||||
}
|
||||
curve := pub.Curve
|
||||
|
||||
// get C1, and check C1
|
||||
_, _, c3Start, err := bytes2Point(curve, ciphertext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var c1, c2, c3 []byte
|
||||
|
||||
c1 = ciphertext[:c3Start]
|
||||
if from == C1C3C2 {
|
||||
c2 = ciphertext[c3Start+sm3.Size:]
|
||||
c3 = ciphertext[c3Start : c3Start+sm3.Size]
|
||||
} else {
|
||||
c2 = ciphertext[c3Start : ciphertextLen-sm3.Size]
|
||||
c3 = ciphertext[ciphertextLen-sm3.Size:]
|
||||
}
|
||||
|
||||
result := make([]byte, ciphertextLen)
|
||||
copy(result, c1)
|
||||
if to == C1C3C2 {
|
||||
// c1 || c3 || c2
|
||||
copy(result[c3Start:], c3)
|
||||
copy(result[c3Start+sm3.Size:], c2)
|
||||
} else {
|
||||
// c1 || c2 || c3
|
||||
copy(result[c3Start:], c2)
|
||||
copy(result[ciphertextLen-sm3.Size:], c3)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// hashToInt converts a hash value to an integer. There is some disagreement
|
||||
// about how this is done. [NSA] suggests that this is done in the obvious
|
||||
// manner, but [SECG] truncates the hash to the bit-length of the curve order
|
||||
|
@ -30,6 +30,52 @@ func Test_kdf(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SplicingOrder(t *testing.T) {
|
||||
priv, _ := GenerateKey(rand.Reader)
|
||||
tests := []struct {
|
||||
name string
|
||||
plainText string
|
||||
from cipherTextSplicingOrder
|
||||
to cipherTextSplicingOrder
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
{"less than 32 1", "encryption standard", C1C2C3, C1C3C2},
|
||||
{"less than 32 2", "encryption standard", C1C3C2, C1C2C3},
|
||||
{"equals 32 1", "encryption standard encryption ", C1C2C3, C1C3C2},
|
||||
{"equals 32 2", "encryption standard encryption ", C1C3C2, C1C2C3},
|
||||
{"long than 32 1", "encryption standard encryption standard", C1C2C3, C1C3C2},
|
||||
{"long than 32 2", "encryption standard encryption standard", C1C3C2, C1C2C3},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), NewEncrypterOpts(MarshalUncompressed, tt.from))
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt failed %v", err)
|
||||
}
|
||||
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, NewDecrypterOpts(tt.from))
|
||||
if err != nil {
|
||||
t.Fatalf("decrypt failed %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||
}
|
||||
|
||||
//Adjust splicing order
|
||||
ciphertext, err = AdjustCipherTextSplicingOrder(&priv.PublicKey, ciphertext, tt.from, tt.to)
|
||||
if err != nil {
|
||||
t.Fatalf("adjust splicing order failed %v", err)
|
||||
}
|
||||
plaintext, err = priv.Decrypt(rand.Reader, ciphertext, NewDecrypterOpts(tt.to))
|
||||
if err != nil {
|
||||
t.Fatalf("decrypt failed after adjust splicing order %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_encryptDecrypt(t *testing.T) {
|
||||
priv, _ := GenerateKey(rand.Reader)
|
||||
tests := []struct {
|
||||
@ -55,8 +101,8 @@ func Test_encryptDecrypt(t *testing.T) {
|
||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||
}
|
||||
// compress mode
|
||||
encrypterOpts := EncrypterOpts{MarshalCompressed}
|
||||
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts)
|
||||
encrypterOpts := NewEncrypterOpts(MarshalCompressed, C1C3C2)
|
||||
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts)
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt failed %v", err)
|
||||
}
|
||||
@ -69,8 +115,8 @@ func Test_encryptDecrypt(t *testing.T) {
|
||||
}
|
||||
|
||||
// mixed mode
|
||||
encrypterOpts = EncrypterOpts{MarshalMixed}
|
||||
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts)
|
||||
encrypterOpts = NewEncrypterOpts(MarshalMixed, C1C3C2)
|
||||
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts)
|
||||
if err != nil {
|
||||
t.Fatalf("encrypt failed %v", err)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user