SM2: add both C1C2C3 & C1C3C2 cipher text splicing order

This commit is contained in:
Emman 2021-12-02 10:19:49 +08:00
parent d5e7461d58
commit 6e3f8e5d1c
3 changed files with 146 additions and 15 deletions

4
go.mod
View File

@ -3,6 +3,6 @@ module github.com/emmansun/gmsm
go 1.14 go 1.14
require ( require (
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b golang.org/x/sys v0.0.0-20211124211545-fe61309f8881
) )

View File

@ -55,9 +55,30 @@ const (
MarshalMixed MarshalMixed
) )
type cipherTextSplicingOrder byte
const (
C1C3C2 cipherTextSplicingOrder = iota
C1C2C3
)
// EncrypterOpts encryption options // EncrypterOpts encryption options
type EncrypterOpts struct { type EncrypterOpts struct {
PointMarshalMode pointMarshalMode PointMarshalMode pointMarshalMode
CipherTextSplicingOrder cipherTextSplicingOrder
}
// DecrypterOpts decryption options
type DecrypterOpts struct {
CipherTextSplicingOrder cipherTextSplicingOrder
}
func NewEncrypterOpts(marhsalMode pointMarshalMode, splicingOrder cipherTextSplicingOrder) *EncrypterOpts {
return &EncrypterOpts{marhsalMode, splicingOrder}
}
func NewDecrypterOpts(splicingOrder cipherTextSplicingOrder) *DecrypterOpts {
return &DecrypterOpts{splicingOrder}
} }
func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte { func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte {
@ -71,7 +92,7 @@ func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte
} }
} }
var defaultEncrypterOpts = EncrypterOpts{MarshalUncompressed} var defaultEncrypterOpts = &EncrypterOpts{MarshalUncompressed, C1C3C2}
// directSigning is a standard Hash value that signals that no pre-hashing // directSigning is a standard Hash value that signals that no pre-hashing
// should be performed. // should be performed.
@ -148,7 +169,9 @@ func (priv *PrivateKey) SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, er
// Decrypt decrypts msg. The opts argument should be appropriate for // Decrypt decrypts msg. The opts argument should be appropriate for
// the primitive used. // the primitive used.
func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) { func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
return Decrypt(priv, msg) var sm2Opts *DecrypterOpts
sm2Opts, _ = opts.(*DecrypterOpts)
return decrypt(priv, msg, sm2Opts)
} }
var ( var (
@ -222,7 +245,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter
return nil, nil return nil, nil
} }
if opts == nil { if opts == nil {
opts = &defaultEncrypterOpts opts = defaultEncrypterOpts
} }
//A3, requirement is to check if h*P is infinite point, h is 1 //A3, requirement is to check if h*P is infinite point, h is 1
if pub.X.Sign() == 0 && pub.Y.Sign() == 0 { if pub.X.Sign() == 0 && pub.Y.Sign() == 0 {
@ -262,8 +285,12 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter
//A7, C3 = hash(x2||M||y2) //A7, C3 = hash(x2||M||y2)
c3 := calculateC3(curve, x2, y2, msg) c3 := calculateC3(curve, x2, y2, msg)
// c1 || c3 || c2 if opts.CipherTextSplicingOrder == C1C3C2 {
return append(append(c1, c3...), c2...), nil // c1 || c3 || c2
return append(append(c1, c3...), c2...), nil
}
// c1 || c2 || c3
return append(append(c1, c2...), c3...), nil
} }
} }
@ -282,8 +309,16 @@ func GenerateKey(rand io.Reader) (*PrivateKey, error) {
return priv, nil return priv, nil
} }
// Decrypt sm2 decrypt implementation // Decrypt sm2 decrypt implementation by default DecrypterOpts{C1C3C2}
func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
return decrypt(priv, ciphertext, nil)
}
func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
splicingOrder := C1C3C2
if opts != nil {
splicingOrder = opts.CipherTextSplicingOrder
}
ciphertextLen := len(ciphertext) ciphertextLen := len(ciphertext)
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size { if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
return nil, errors.New("SM2: invalid ciphertext length") return nil, errors.New("SM2: invalid ciphertext length")
@ -300,7 +335,12 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes()) x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
//B4, calculate t=KDF(x2||y2, klen) //B4, calculate t=KDF(x2||y2, klen)
c2 := ciphertext[c3Start+sm3.Size:] var c2, c3 []byte
if splicingOrder == C1C3C2 {
c2 = ciphertext[c3Start+sm3.Size:]
} else {
c2 = ciphertext[c3Start : ciphertextLen-sm3.Size]
}
msgLen := len(c2) msgLen := len(c2)
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
if !success { if !success {
@ -314,7 +354,11 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
} }
//B6, calculate hash and compare it //B6, calculate hash and compare it
c3 := ciphertext[c3Start : c3Start+sm3.Size] if splicingOrder == C1C3C2 {
c3 = ciphertext[c3Start : c3Start+sm3.Size]
} else {
c3 = ciphertext[ciphertextLen-sm3.Size:]
}
u := calculateC3(curve, x2, y2, msg) u := calculateC3(curve, x2, y2, msg)
for i := 0; i < sm3.Size; i++ { for i := 0; i < sm3.Size; i++ {
if c3[i] != u[i] { if c3[i] != u[i] {
@ -325,6 +369,47 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
return msg, nil return msg, nil
} }
func AdjustCipherTextSplicingOrder(pub *ecdsa.PublicKey, ciphertext []byte, from, to cipherTextSplicingOrder) ([]byte, error) {
if from == to {
return ciphertext, nil
}
ciphertextLen := len(ciphertext)
if ciphertextLen <= 1+(pub.Params().BitSize/8)+sm3.Size {
return nil, errors.New("SM2: invalid ciphertext length")
}
curve := pub.Curve
// get C1, and check C1
_, _, c3Start, err := bytes2Point(curve, ciphertext)
if err != nil {
return nil, err
}
var c1, c2, c3 []byte
c1 = ciphertext[:c3Start]
if from == C1C3C2 {
c2 = ciphertext[c3Start+sm3.Size:]
c3 = ciphertext[c3Start : c3Start+sm3.Size]
} else {
c2 = ciphertext[c3Start : ciphertextLen-sm3.Size]
c3 = ciphertext[ciphertextLen-sm3.Size:]
}
result := make([]byte, ciphertextLen)
copy(result, c1)
if to == C1C3C2 {
// c1 || c3 || c2
copy(result[c3Start:], c3)
copy(result[c3Start+sm3.Size:], c2)
} else {
// c1 || c2 || c3
copy(result[c3Start:], c2)
copy(result[ciphertextLen-sm3.Size:], c3)
}
return result, nil
}
// hashToInt converts a hash value to an integer. There is some disagreement // hashToInt converts a hash value to an integer. There is some disagreement
// about how this is done. [NSA] suggests that this is done in the obvious // about how this is done. [NSA] suggests that this is done in the obvious
// manner, but [SECG] truncates the hash to the bit-length of the curve order // manner, but [SECG] truncates the hash to the bit-length of the curve order

View File

@ -30,6 +30,52 @@ func Test_kdf(t *testing.T) {
} }
} }
func Test_SplicingOrder(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
tests := []struct {
name string
plainText string
from cipherTextSplicingOrder
to cipherTextSplicingOrder
}{
// TODO: Add test cases.
{"less than 32 1", "encryption standard", C1C2C3, C1C3C2},
{"less than 32 2", "encryption standard", C1C3C2, C1C2C3},
{"equals 32 1", "encryption standard encryption ", C1C2C3, C1C3C2},
{"equals 32 2", "encryption standard encryption ", C1C3C2, C1C2C3},
{"long than 32 1", "encryption standard encryption standard", C1C2C3, C1C3C2},
{"long than 32 2", "encryption standard encryption standard", C1C3C2, C1C2C3},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), NewEncrypterOpts(MarshalUncompressed, tt.from))
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, NewDecrypterOpts(tt.from))
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
//Adjust splicing order
ciphertext, err = AdjustCipherTextSplicingOrder(&priv.PublicKey, ciphertext, tt.from, tt.to)
if err != nil {
t.Fatalf("adjust splicing order failed %v", err)
}
plaintext, err = priv.Decrypt(rand.Reader, ciphertext, NewDecrypterOpts(tt.to))
if err != nil {
t.Fatalf("decrypt failed after adjust splicing order %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
})
}
}
func Test_encryptDecrypt(t *testing.T) { func Test_encryptDecrypt(t *testing.T) {
priv, _ := GenerateKey(rand.Reader) priv, _ := GenerateKey(rand.Reader)
tests := []struct { tests := []struct {
@ -55,8 +101,8 @@ func Test_encryptDecrypt(t *testing.T) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
} }
// compress mode // compress mode
encrypterOpts := EncrypterOpts{MarshalCompressed} encrypterOpts := NewEncrypterOpts(MarshalCompressed, C1C3C2)
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts) ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts)
if err != nil { if err != nil {
t.Fatalf("encrypt failed %v", err) t.Fatalf("encrypt failed %v", err)
} }
@ -69,8 +115,8 @@ func Test_encryptDecrypt(t *testing.T) {
} }
// mixed mode // mixed mode
encrypterOpts = EncrypterOpts{MarshalMixed} encrypterOpts = NewEncrypterOpts(MarshalMixed, C1C3C2)
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts) ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts)
if err != nil { if err != nil {
t.Fatalf("encrypt failed %v", err) t.Fatalf("encrypt failed %v", err)
} }