mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-26 04:06:18 +08:00
SM2: add both C1C2C3 & C1C3C2 cipher text splicing order
This commit is contained in:
parent
d5e7461d58
commit
6e3f8e5d1c
4
go.mod
4
go.mod
@ -3,6 +3,6 @@ module github.com/emmansun/gmsm
|
|||||||
go 1.14
|
go 1.14
|
||||||
|
|
||||||
require (
|
require (
|
||||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519
|
golang.org/x/crypto v0.0.0-20211117183948-ae814b36b871
|
||||||
golang.org/x/sys v0.0.0-20211103235746-7861aae1554b
|
golang.org/x/sys v0.0.0-20211124211545-fe61309f8881
|
||||||
)
|
)
|
||||||
|
103
sm2/sm2.go
103
sm2/sm2.go
@ -55,9 +55,30 @@ const (
|
|||||||
MarshalMixed
|
MarshalMixed
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type cipherTextSplicingOrder byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
C1C3C2 cipherTextSplicingOrder = iota
|
||||||
|
C1C2C3
|
||||||
|
)
|
||||||
|
|
||||||
// EncrypterOpts encryption options
|
// EncrypterOpts encryption options
|
||||||
type EncrypterOpts struct {
|
type EncrypterOpts struct {
|
||||||
PointMarshalMode pointMarshalMode
|
PointMarshalMode pointMarshalMode
|
||||||
|
CipherTextSplicingOrder cipherTextSplicingOrder
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecrypterOpts decryption options
|
||||||
|
type DecrypterOpts struct {
|
||||||
|
CipherTextSplicingOrder cipherTextSplicingOrder
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewEncrypterOpts(marhsalMode pointMarshalMode, splicingOrder cipherTextSplicingOrder) *EncrypterOpts {
|
||||||
|
return &EncrypterOpts{marhsalMode, splicingOrder}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDecrypterOpts(splicingOrder cipherTextSplicingOrder) *DecrypterOpts {
|
||||||
|
return &DecrypterOpts{splicingOrder}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte {
|
func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte {
|
||||||
@ -71,7 +92,7 @@ func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var defaultEncrypterOpts = EncrypterOpts{MarshalUncompressed}
|
var defaultEncrypterOpts = &EncrypterOpts{MarshalUncompressed, C1C3C2}
|
||||||
|
|
||||||
// directSigning is a standard Hash value that signals that no pre-hashing
|
// directSigning is a standard Hash value that signals that no pre-hashing
|
||||||
// should be performed.
|
// should be performed.
|
||||||
@ -148,7 +169,9 @@ func (priv *PrivateKey) SignWithSM2(rand io.Reader, uid, msg []byte) ([]byte, er
|
|||||||
// Decrypt decrypts msg. The opts argument should be appropriate for
|
// Decrypt decrypts msg. The opts argument should be appropriate for
|
||||||
// the primitive used.
|
// the primitive used.
|
||||||
func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
|
func (priv *PrivateKey) Decrypt(rand io.Reader, msg []byte, opts crypto.DecrypterOpts) (plaintext []byte, err error) {
|
||||||
return Decrypt(priv, msg)
|
var sm2Opts *DecrypterOpts
|
||||||
|
sm2Opts, _ = opts.(*DecrypterOpts)
|
||||||
|
return decrypt(priv, msg, sm2Opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -222,7 +245,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter
|
|||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
if opts == nil {
|
if opts == nil {
|
||||||
opts = &defaultEncrypterOpts
|
opts = defaultEncrypterOpts
|
||||||
}
|
}
|
||||||
//A3, requirement is to check if h*P is infinite point, h is 1
|
//A3, requirement is to check if h*P is infinite point, h is 1
|
||||||
if pub.X.Sign() == 0 && pub.Y.Sign() == 0 {
|
if pub.X.Sign() == 0 && pub.Y.Sign() == 0 {
|
||||||
@ -262,8 +285,12 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter
|
|||||||
//A7, C3 = hash(x2||M||y2)
|
//A7, C3 = hash(x2||M||y2)
|
||||||
c3 := calculateC3(curve, x2, y2, msg)
|
c3 := calculateC3(curve, x2, y2, msg)
|
||||||
|
|
||||||
// c1 || c3 || c2
|
if opts.CipherTextSplicingOrder == C1C3C2 {
|
||||||
return append(append(c1, c3...), c2...), nil
|
// c1 || c3 || c2
|
||||||
|
return append(append(c1, c3...), c2...), nil
|
||||||
|
}
|
||||||
|
// c1 || c2 || c3
|
||||||
|
return append(append(c1, c2...), c3...), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -282,8 +309,16 @@ func GenerateKey(rand io.Reader) (*PrivateKey, error) {
|
|||||||
return priv, nil
|
return priv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrypt sm2 decrypt implementation
|
// Decrypt sm2 decrypt implementation by default DecrypterOpts{C1C3C2}
|
||||||
func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
||||||
|
return decrypt(priv, ciphertext, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, error) {
|
||||||
|
splicingOrder := C1C3C2
|
||||||
|
if opts != nil {
|
||||||
|
splicingOrder = opts.CipherTextSplicingOrder
|
||||||
|
}
|
||||||
ciphertextLen := len(ciphertext)
|
ciphertextLen := len(ciphertext)
|
||||||
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
|
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
|
||||||
return nil, errors.New("SM2: invalid ciphertext length")
|
return nil, errors.New("SM2: invalid ciphertext length")
|
||||||
@ -300,7 +335,12 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
|||||||
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
|
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
|
||||||
|
|
||||||
//B4, calculate t=KDF(x2||y2, klen)
|
//B4, calculate t=KDF(x2||y2, klen)
|
||||||
c2 := ciphertext[c3Start+sm3.Size:]
|
var c2, c3 []byte
|
||||||
|
if splicingOrder == C1C3C2 {
|
||||||
|
c2 = ciphertext[c3Start+sm3.Size:]
|
||||||
|
} else {
|
||||||
|
c2 = ciphertext[c3Start : ciphertextLen-sm3.Size]
|
||||||
|
}
|
||||||
msgLen := len(c2)
|
msgLen := len(c2)
|
||||||
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
||||||
if !success {
|
if !success {
|
||||||
@ -314,7 +354,11 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
//B6, calculate hash and compare it
|
//B6, calculate hash and compare it
|
||||||
c3 := ciphertext[c3Start : c3Start+sm3.Size]
|
if splicingOrder == C1C3C2 {
|
||||||
|
c3 = ciphertext[c3Start : c3Start+sm3.Size]
|
||||||
|
} else {
|
||||||
|
c3 = ciphertext[ciphertextLen-sm3.Size:]
|
||||||
|
}
|
||||||
u := calculateC3(curve, x2, y2, msg)
|
u := calculateC3(curve, x2, y2, msg)
|
||||||
for i := 0; i < sm3.Size; i++ {
|
for i := 0; i < sm3.Size; i++ {
|
||||||
if c3[i] != u[i] {
|
if c3[i] != u[i] {
|
||||||
@ -325,6 +369,47 @@ func Decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) {
|
|||||||
return msg, nil
|
return msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AdjustCipherTextSplicingOrder(pub *ecdsa.PublicKey, ciphertext []byte, from, to cipherTextSplicingOrder) ([]byte, error) {
|
||||||
|
if from == to {
|
||||||
|
return ciphertext, nil
|
||||||
|
}
|
||||||
|
ciphertextLen := len(ciphertext)
|
||||||
|
if ciphertextLen <= 1+(pub.Params().BitSize/8)+sm3.Size {
|
||||||
|
return nil, errors.New("SM2: invalid ciphertext length")
|
||||||
|
}
|
||||||
|
curve := pub.Curve
|
||||||
|
|
||||||
|
// get C1, and check C1
|
||||||
|
_, _, c3Start, err := bytes2Point(curve, ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var c1, c2, c3 []byte
|
||||||
|
|
||||||
|
c1 = ciphertext[:c3Start]
|
||||||
|
if from == C1C3C2 {
|
||||||
|
c2 = ciphertext[c3Start+sm3.Size:]
|
||||||
|
c3 = ciphertext[c3Start : c3Start+sm3.Size]
|
||||||
|
} else {
|
||||||
|
c2 = ciphertext[c3Start : ciphertextLen-sm3.Size]
|
||||||
|
c3 = ciphertext[ciphertextLen-sm3.Size:]
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]byte, ciphertextLen)
|
||||||
|
copy(result, c1)
|
||||||
|
if to == C1C3C2 {
|
||||||
|
// c1 || c3 || c2
|
||||||
|
copy(result[c3Start:], c3)
|
||||||
|
copy(result[c3Start+sm3.Size:], c2)
|
||||||
|
} else {
|
||||||
|
// c1 || c2 || c3
|
||||||
|
copy(result[c3Start:], c2)
|
||||||
|
copy(result[ciphertextLen-sm3.Size:], c3)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
// hashToInt converts a hash value to an integer. There is some disagreement
|
// hashToInt converts a hash value to an integer. There is some disagreement
|
||||||
// about how this is done. [NSA] suggests that this is done in the obvious
|
// about how this is done. [NSA] suggests that this is done in the obvious
|
||||||
// manner, but [SECG] truncates the hash to the bit-length of the curve order
|
// manner, but [SECG] truncates the hash to the bit-length of the curve order
|
||||||
|
@ -30,6 +30,52 @@ func Test_kdf(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_SplicingOrder(t *testing.T) {
|
||||||
|
priv, _ := GenerateKey(rand.Reader)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
plainText string
|
||||||
|
from cipherTextSplicingOrder
|
||||||
|
to cipherTextSplicingOrder
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
{"less than 32 1", "encryption standard", C1C2C3, C1C3C2},
|
||||||
|
{"less than 32 2", "encryption standard", C1C3C2, C1C2C3},
|
||||||
|
{"equals 32 1", "encryption standard encryption ", C1C2C3, C1C3C2},
|
||||||
|
{"equals 32 2", "encryption standard encryption ", C1C3C2, C1C2C3},
|
||||||
|
{"long than 32 1", "encryption standard encryption standard", C1C2C3, C1C3C2},
|
||||||
|
{"long than 32 2", "encryption standard encryption standard", C1C3C2, C1C2C3},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), NewEncrypterOpts(MarshalUncompressed, tt.from))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt failed %v", err)
|
||||||
|
}
|
||||||
|
plaintext, err := priv.Decrypt(rand.Reader, ciphertext, NewDecrypterOpts(tt.from))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt failed %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Adjust splicing order
|
||||||
|
ciphertext, err = AdjustCipherTextSplicingOrder(&priv.PublicKey, ciphertext, tt.from, tt.to)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("adjust splicing order failed %v", err)
|
||||||
|
}
|
||||||
|
plaintext, err = priv.Decrypt(rand.Reader, ciphertext, NewDecrypterOpts(tt.to))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt failed after adjust splicing order %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_encryptDecrypt(t *testing.T) {
|
func Test_encryptDecrypt(t *testing.T) {
|
||||||
priv, _ := GenerateKey(rand.Reader)
|
priv, _ := GenerateKey(rand.Reader)
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@ -55,8 +101,8 @@ func Test_encryptDecrypt(t *testing.T) {
|
|||||||
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
}
|
}
|
||||||
// compress mode
|
// compress mode
|
||||||
encrypterOpts := EncrypterOpts{MarshalCompressed}
|
encrypterOpts := NewEncrypterOpts(MarshalCompressed, C1C3C2)
|
||||||
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts)
|
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("encrypt failed %v", err)
|
t.Fatalf("encrypt failed %v", err)
|
||||||
}
|
}
|
||||||
@ -69,8 +115,8 @@ func Test_encryptDecrypt(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// mixed mode
|
// mixed mode
|
||||||
encrypterOpts = EncrypterOpts{MarshalMixed}
|
encrypterOpts = NewEncrypterOpts(MarshalMixed, C1C3C2)
|
||||||
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), &encrypterOpts)
|
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("encrypt failed %v", err)
|
t.Fatalf("encrypt failed %v", err)
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user