diff --git a/sm2/sm2.go b/sm2/sm2.go index f165588..c0e2fa7 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -457,7 +457,7 @@ func unmarshalASN1Ciphertext(ciphertext []byte) (*big.Int, *big.Int, []byte, []b ) input := cryptobyte.String(ciphertext) if !input.ReadASN1(&inner, asn1.SEQUENCE) || - (!input.Empty() && !subtle.ConstantTimeAllZero(input)) || + !input.Empty() || !inner.ReadASN1Integer(x1) || !inner.ReadASN1Integer(y1) || !inner.ReadASN1Bytes(&c3, asn1.OCTET_STRING) || diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go index ef76ace..087c87f 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_test.go @@ -6,11 +6,11 @@ import ( "crypto/elliptic" "crypto/rand" "encoding/hex" + "errors" "math/big" "reflect" "testing" - "github.com/emmansun/gmsm/internal/subtle" "github.com/emmansun/gmsm/sm2/sm2ec" "github.com/emmansun/gmsm/sm3" "golang.org/x/crypto/cryptobyte" @@ -329,16 +329,41 @@ func TestEqual(t *testing.T) { } } +// a sample method to get frist ASN1 SEQUENCE data +func getFirstASN1Sequence(ciphertext []byte) ([]byte, []byte, error) { + input := cryptobyte.String(ciphertext) + var inner cryptobyte.String + if !input.ReadASN1(&inner, asn1.SEQUENCE) { + return nil, nil, errors.New("there are no sequence tag") + } + if len(input) == 0 { + return ciphertext, nil, nil + } + return ciphertext[:len(ciphertext)-len(input)], input, nil +} + func TestCipherASN1WithInvalidBytes(t *testing.T) { + ciphertext, _ := hex.DecodeString("3081980220298ED52AE2A0EBA8B7567D54DF41C5F9B310EDFA4A8E15ECCB44EDA94F9F1FC20220116BE33B0833C95D8E5FF9483CD2D7EFF7033C92FE5DEAB6197D809FF1EEE05F042097A90979A6FCEBDE883C2E07E9C286818E694EDE37C3CDAA70E4CD481BE883E00430D62160BB179CB20CE3B5ECA0F5A535BEB6E221566C78FEA92105F71BD37F3F850AD2F86F2D1E35F15E9356557DAC026A") + _, rest, err := getFirstASN1Sequence(ciphertext) + if err != nil || len(rest) != 0 { + t.FailNow() + } + + ciphertext, _ = hex.DecodeString("3081980220298ED52AE2A0EBA8B7567D54DF41C5F9B310EDFA4A8E15ECCB44EDA94F9F1FC20220116BE33B0833C95D8E5FF9483CD2D7EFF7033C92FE5DEAB6197D809FF1EEE05F042097A90979A6FCEBDE883C2E07E9C286818E694EDE37C3CDAA70E4CD481BE883E00430D62160BB179CB20CE3B5ECA0F5A535BEB6E221566C78FEA92105F71BD37F3F850AD2F86F2D1E35F15E9356557DAC026A0000") + seq, rest, err := getFirstASN1Sequence(ciphertext) + if err != nil || len(rest) != 2 { + t.FailNow() + } + var ( x1, y1 = &big.Int{}, &big.Int{} c2, c3 []byte inner cryptobyte.String ) - ciphertext, _ := hex.DecodeString("3081980220298ED52AE2A0EBA8B7567D54DF41C5F9B310EDFA4A8E15ECCB44EDA94F9F1FC20220116BE33B0833C95D8E5FF9483CD2D7EFF7033C92FE5DEAB6197D809FF1EEE05F042097A90979A6FCEBDE883C2E07E9C286818E694EDE37C3CDAA70E4CD481BE883E00430D62160BB179CB20CE3B5ECA0F5A535BEB6E221566C78FEA92105F71BD37F3F850AD2F86F2D1E35F15E9356557DAC026A0000") - input := cryptobyte.String(ciphertext) + + input := cryptobyte.String(seq) if !input.ReadASN1(&inner, asn1.SEQUENCE) || - (!input.Empty() && !subtle.ConstantTimeAllZero(input)) || + !input.Empty() || !inner.ReadASN1Integer(x1) || !inner.ReadASN1Integer(y1) || !inner.ReadASN1Bytes(&c3, asn1.OCTET_STRING) || diff --git a/sm9/sm9.go b/sm9/sm9.go index ef52d7b..ccba988 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -408,7 +408,7 @@ func DecryptASN1(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error ) input := cryptobyte.String(ciphertext) if !input.ReadASN1(&inner, asn1.SEQUENCE) || - (!input.Empty() && !subtle.ConstantTimeAllZero(input)) || + !input.Empty() || !inner.ReadASN1Integer(&encType) || !inner.ReadASN1BitStringAsBytes(&c1Bytes) || !inner.ReadASN1Bytes(&c3Bytes, asn1.OCTET_STRING) ||