SM2: add both C1C2C3 & C1C3C2 cipher text splicing order

This commit is contained in:
Emman 2021-12-02 10:19:49 +08:00
parent d5e7461d58
commit 6e3f8e5d1c
3 changed files with 146 additions and 15 deletions

4
go.mod
View File

@ -3,6 +3,6 @@ module github.com/emmansun/gmsm
go 1.14
require (
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b
golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881
)

View File

@ -55,9 +55,30 @@ const (
MarshalMixed
)
type cipherTextSplicingOrder byte
const (
C1C3C2 cipherTextSplicingOrder = iota
C1C2C3
)
// EncrypterOpts encryption options
type EncrypterOpts struct {
PointMarshalMode pointMarshalMode
PointMarshalMode pointMarshalMode
CipherTextSplicingOrder cipherTextSplicingOrder
}
// DecrypterOpts decryption options
type DecrypterOpts struct {
CipherTextSplicingOrder cipherTextSplicingOrder
}
func NewEncrypterOpts(marhsalMode pointMarshalMode, splicingOrder cipherTextSplicingOrder) *EncrypterOpts {
return &EncrypterOpts{marhsalMode, splicingOrder}
}
func NewDecrypterOpts(splicingOrder cipherTextSplicingOrder) *DecrypterOpts {
return &DecrypterOpts{splicingOrder}
}
func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte {
@ -71,7 +92,7 @@ func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte
}
}
var defaultEncrypterOpts = EncrypterOpts{MarshalUncompressed}
var defaultEncrypterOpts = &EncrypterOpts{MarshalUncompressed, C1C3C2}
// directSigning is a standard Hash value that signals that no pre-hashing
// should be performed.
@ -148,7 +169,9 @@ func (priv *PrivateKey) SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, er
// Decrypt decrypts msg. The opts argument should be appropriate for
// the primitive used.
func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
return Decrypt(priv, msg)
var sm2Opts *DecrypterOpts
sm2Opts, _ = opts.(*DecrypterOpts)
return decrypt(priv, msg, sm2Opts)
}
var (
@ -222,7 +245,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter
return nil, nil
}
if opts == nil {
opts = &defaultEncrypterOpts
opts = defaultEncrypterOpts
}
//A3, requirement is to check if h*P is infinite point, h is 1
if pub.X.Sign() == 0 && pub.Y.Sign() == 0 {
@ -262,8 +285,12 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter
//A7, C3 = hash(x2||M||y2)
c3 := calculateC3(curve, x2, y2, msg)
// c1 || c3 || c2
return append(append(c1, c3...), c2...), nil
if opts.CipherTextSplicingOrder == C1C3C2 {
// c1 || c3 || c2
return append(append(c1, c3...), c2...), nil
}
// c1 || c2 || c3
return append(append(c1, c2...), c3...), nil
}
}
@ -282,8 +309,16 @@ func GenerateKey(rand io.Reader) (*PrivateKey, error) {
return priv, nil
}
// Decrypt sm2 decrypt implementation
// Decrypt sm2 decrypt implementation by default DecrypterOpts{C1C3C2}
func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
return decrypt(priv, ciphertext, nil)
}
func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
splicingOrder := C1C3C2
if opts != nil {
splicingOrder = opts.CipherTextSplicingOrder
}
ciphertextLen := len(ciphertext)
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
return nil, errors.New("SM2: invalid ciphertext length")
@ -300,7 +335,12 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
//B4, calculate t=KDF(x2||y2, klen)
c2 := ciphertext[c3Start+sm3.Size:]
var c2, c3 []byte
if splicingOrder == C1C3C2 {
c2 = ciphertext[c3Start+sm3.Size:]
} else {
c2 = ciphertext[c3Start : ciphertextLen-sm3.Size]
}
msgLen := len(c2)
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
if !success {
@ -314,7 +354,11 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
}
//B6, calculate hash and compare it
c3 := ciphertext[c3Start : c3Start+sm3.Size]
if splicingOrder == C1C3C2 {
c3 = ciphertext[c3Start : c3Start+sm3.Size]
} else {
c3 = ciphertext[ciphertextLen-sm3.Size:]
}
u := calculateC3(curve, x2, y2, msg)
for i := 0; i < sm3.Size; i++ {
if c3[i] != u[i] {
@ -325,6 +369,47 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
return msg, nil
}
func AdjustCipherTextSplicingOrder(pub *ecdsa.PublicKey, ciphertext []byte, from, to cipherTextSplicingOrder) ([]byte, error) {
if from == to {
return ciphertext, nil
}
ciphertextLen := len(ciphertext)
if ciphertextLen <= 1+(pub.Params().BitSize/8)+sm3.Size {
return nil, errors.New("SM2: invalid ciphertext length")
}
curve := pub.Curve
// get C1, and check C1
_, _, c3Start, err := bytes2Point(curve, ciphertext)
if err != nil {
return nil, err
}
var c1, c2, c3 []byte
c1 = ciphertext[:c3Start]
if from == C1C3C2 {
c2 = ciphertext[c3Start+sm3.Size:]
c3 = ciphertext[c3Start : c3Start+sm3.Size]
} else {
c2 = ciphertext[c3Start : ciphertextLen-sm3.Size]
c3 = ciphertext[ciphertextLen-sm3.Size:]
}
result := make([]byte, ciphertextLen)
copy(result, c1)
if to == C1C3C2 {
// c1 || c3 || c2
copy(result[c3Start:], c3)
copy(result[c3Start+sm3.Size:], c2)
} else {
// c1 || c2 || c3
copy(result[c3Start:], c2)
copy(result[ciphertextLen-sm3.Size:], c3)
}
return result, nil
}
// hashToInt converts a hash value to an integer. There is some disagreement
// about how this is done. [NSA] suggests that this is done in the obvious
// manner, but [SECG] truncates the hash to the bit-length of the curve order

View File

@ -30,6 +30,52 @@ func Test_kdf(t *testing.T) {
}
}
func Test_SplicingOrder(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
tests := []struct {
name string
plainText string
from cipherTextSplicingOrder
to cipherTextSplicingOrder
}{
// TODO: Add test cases.
{"less than 32 1", "encryption standard", C1C2C3, C1C3C2},
{"less than 32 2", "encryption standard", C1C3C2, C1C2C3},
{"equals 32 1", "encryption standard encryption ", C1C2C3, C1C3C2},
{"equals 32 2", "encryption standard encryption ", C1C3C2, C1C2C3},
{"long than 32 1", "encryption standard encryption standard", C1C2C3, C1C3C2},
{"long than 32 2", "encryption standard encryption standard", C1C3C2, C1C2C3},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), NewEncrypterOpts(MarshalUncompressed, tt.from))
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, NewDecrypterOpts(tt.from))
if err != nil {
t.Fatalf("decrypt failed %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
//Adjust splicing order
ciphertext, err = AdjustCipherTextSplicingOrder(&priv.PublicKey, ciphertext, tt.from, tt.to)
if err != nil {
t.Fatalf("adjust splicing order failed %v", err)
}
plaintext, err = priv.Decrypt(rand.Reader, ciphertext, NewDecrypterOpts(tt.to))
if err != nil {
t.Fatalf("decrypt failed after adjust splicing order %v", err)
}
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
})
}
}
func Test_encryptDecrypt(t *testing.T) {
priv, _ := GenerateKey(rand.Reader)
tests := []struct {
@ -55,8 +101,8 @@ func Test_encryptDecrypt(t *testing.T) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
}
// compress mode
encrypterOpts := EncrypterOpts{MarshalCompressed}
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts)
encrypterOpts := NewEncrypterOpts(MarshalCompressed, C1C3C2)
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts)
if err != nil {
t.Fatalf("encrypt failed %v", err)
}
@ -69,8 +115,8 @@ func Test_encryptDecrypt(t *testing.T) {
}
// mixed mode
encrypterOpts = EncrypterOpts{MarshalMixed}
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts)
encrypterOpts = NewEncrypterOpts(MarshalMixed, C1C3C2)
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts)
if err != nil {
t.Fatalf("encrypt failed %v", err)
}