diff --git a/go.mod b/go.mod index cf2b1c3..a59f2e1 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,6 @@ module github.com/emmansun/gmsm go 1.14 require ( - golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 - golang.org/x/sys v0.0.0-20211103235746-7861aae1554b + golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871 + golang.org/x/sys v0.0.0-20211124211545-fe61309f8881 ) diff --git a/sm2/sm2.go b/sm2/sm2.go index 177116a..7f5fec4 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -55,9 +55,30 @@ const ( MarshalMixed ) +type cipherTextSplicingOrder byte + +const ( + C1C3C2 cipherTextSplicingOrder = iota + C1C2C3 +) + // EncrypterOpts encryption options type EncrypterOpts struct { - PointMarshalMode pointMarshalMode + PointMarshalMode pointMarshalMode + CipherTextSplicingOrder cipherTextSplicingOrder +} + +// DecrypterOpts decryption options +type DecrypterOpts struct { + CipherTextSplicingOrder cipherTextSplicingOrder +} + +func NewEncrypterOpts(marhsalMode pointMarshalMode, splicingOrder cipherTextSplicingOrder) *EncrypterOpts { + return &EncrypterOpts{marhsalMode, splicingOrder} +} + +func NewDecrypterOpts(splicingOrder cipherTextSplicingOrder) *DecrypterOpts { + return &DecrypterOpts{splicingOrder} } func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte { @@ -71,7 +92,7 @@ func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte } } -var defaultEncrypterOpts = EncrypterOpts{MarshalUncompressed} +var defaultEncrypterOpts = &EncrypterOpts{MarshalUncompressed, C1C3C2} // directSigning is a standard Hash value that signals that no pre-hashing // should be performed. @@ -148,7 +169,9 @@ func (priv *PrivateKey) SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, er // Decrypt decrypts msg. The opts argument should be appropriate for // the primitive used. func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) { - return Decrypt(priv, msg) + var sm2Opts *DecrypterOpts + sm2Opts, _ = opts.(*DecrypterOpts) + return decrypt(priv, msg, sm2Opts) } var ( @@ -222,7 +245,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter return nil, nil } if opts == nil { - opts = &defaultEncrypterOpts + opts = defaultEncrypterOpts } //A3, requirement is to check if h*P is infinite point, h is 1 if pub.X.Sign() == 0 && pub.Y.Sign() == 0 { @@ -262,8 +285,12 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter //A7, C3 = hash(x2||M||y2) c3 := calculateC3(curve, x2, y2, msg) - // c1 || c3 || c2 - return append(append(c1, c3...), c2...), nil + if opts.CipherTextSplicingOrder == C1C3C2 { + // c1 || c3 || c2 + return append(append(c1, c3...), c2...), nil + } + // c1 || c2 || c3 + return append(append(c1, c2...), c3...), nil } } @@ -282,8 +309,16 @@ func GenerateKey(rand io.Reader) (*PrivateKey, error) { return priv, nil } -// Decrypt sm2 decrypt implementation +// Decrypt sm2 decrypt implementation by default DecrypterOpts{C1C3C2} func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { + return decrypt(priv, ciphertext, nil) +} + +func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) { + splicingOrder := C1C3C2 + if opts != nil { + splicingOrder = opts.CipherTextSplicingOrder + } ciphertextLen := len(ciphertext) if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size { return nil, errors.New("SM2: invalid ciphertext length") @@ -300,7 +335,12 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes()) //B4, calculate t=KDF(x2||y2, klen) - c2 := ciphertext[c3Start+sm3.Size:] + var c2, c3 []byte + if splicingOrder == C1C3C2 { + c2 = ciphertext[c3Start+sm3.Size:] + } else { + c2 = ciphertext[c3Start : ciphertextLen-sm3.Size] + } msgLen := len(c2) t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) if !success { @@ -314,7 +354,11 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { } //B6, calculate hash and compare it - c3 := ciphertext[c3Start : c3Start+sm3.Size] + if splicingOrder == C1C3C2 { + c3 = ciphertext[c3Start : c3Start+sm3.Size] + } else { + c3 = ciphertext[ciphertextLen-sm3.Size:] + } u := calculateC3(curve, x2, y2, msg) for i := 0; i < sm3.Size; i++ { if c3[i] != u[i] { @@ -325,6 +369,47 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { return msg, nil } +func AdjustCipherTextSplicingOrder(pub *ecdsa.PublicKey, ciphertext []byte, from, to cipherTextSplicingOrder) ([]byte, error) { + if from == to { + return ciphertext, nil + } + ciphertextLen := len(ciphertext) + if ciphertextLen <= 1+(pub.Params().BitSize/8)+sm3.Size { + return nil, errors.New("SM2: invalid ciphertext length") + } + curve := pub.Curve + + // get C1, and check C1 + _, _, c3Start, err := bytes2Point(curve, ciphertext) + if err != nil { + return nil, err + } + + var c1, c2, c3 []byte + + c1 = ciphertext[:c3Start] + if from == C1C3C2 { + c2 = ciphertext[c3Start+sm3.Size:] + c3 = ciphertext[c3Start : c3Start+sm3.Size] + } else { + c2 = ciphertext[c3Start : ciphertextLen-sm3.Size] + c3 = ciphertext[ciphertextLen-sm3.Size:] + } + + result := make([]byte, ciphertextLen) + copy(result, c1) + if to == C1C3C2 { + // c1 || c3 || c2 + copy(result[c3Start:], c3) + copy(result[c3Start+sm3.Size:], c2) + } else { + // c1 || c2 || c3 + copy(result[c3Start:], c2) + copy(result[ciphertextLen-sm3.Size:], c3) + } + return result, nil +} + // hashToInt converts a hash value to an integer. There is some disagreement // about how this is done. [NSA] suggests that this is done in the obvious // manner, but [SECG] truncates the hash to the bit-length of the curve order diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go index b5a5659..01a9549 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_test.go @@ -30,6 +30,52 @@ func Test_kdf(t *testing.T) { } } +func Test_SplicingOrder(t *testing.T) { + priv, _ := GenerateKey(rand.Reader) + tests := []struct { + name string + plainText string + from cipherTextSplicingOrder + to cipherTextSplicingOrder + }{ + // TODO: Add test cases. + {"less than 32 1", "encryption standard", C1C2C3, C1C3C2}, + {"less than 32 2", "encryption standard", C1C3C2, C1C2C3}, + {"equals 32 1", "encryption standard encryption ", C1C2C3, C1C3C2}, + {"equals 32 2", "encryption standard encryption ", C1C3C2, C1C2C3}, + {"long than 32 1", "encryption standard encryption standard", C1C2C3, C1C3C2}, + {"long than 32 2", "encryption standard encryption standard", C1C3C2, C1C2C3}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), NewEncrypterOpts(MarshalUncompressed, tt.from)) + if err != nil { + t.Fatalf("encrypt failed %v", err) + } + plaintext, err := priv.Decrypt(rand.Reader, ciphertext, NewDecrypterOpts(tt.from)) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + + //Adjust splicing order + ciphertext, err = AdjustCipherTextSplicingOrder(&priv.PublicKey, ciphertext, tt.from, tt.to) + if err != nil { + t.Fatalf("adjust splicing order failed %v", err) + } + plaintext, err = priv.Decrypt(rand.Reader, ciphertext, NewDecrypterOpts(tt.to)) + if err != nil { + t.Fatalf("decrypt failed after adjust splicing order %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } + }) + } +} + func Test_encryptDecrypt(t *testing.T) { priv, _ := GenerateKey(rand.Reader) tests := []struct { @@ -55,8 +101,8 @@ func Test_encryptDecrypt(t *testing.T) { t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) } // compress mode - encrypterOpts := EncrypterOpts{MarshalCompressed} - ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts) + encrypterOpts := NewEncrypterOpts(MarshalCompressed, C1C3C2) + ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts) if err != nil { t.Fatalf("encrypt failed %v", err) } @@ -69,8 +115,8 @@ func Test_encryptDecrypt(t *testing.T) { } // mixed mode - encrypterOpts = EncrypterOpts{MarshalMixed} - ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts) + encrypterOpts = NewEncrypterOpts(MarshalMixed, C1C3C2) + ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts) if err != nil { t.Fatalf("encrypt failed %v", err) }