From bb0f4f7996c046ab21f21076e669482af674031f Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Fri, 27 May 2022 17:46:14 +0800 Subject: [PATCH] refactoring, align error message pattern --- sm2/sm2.go | 56 ++++++++++++++++++++++++------------------------- sm2/sm2_test.go | 4 ++-- 2 files changed, 30 insertions(+), 30 deletions(-) diff --git a/sm2/sm2.go b/sm2/sm2.go index 8d2c2e6..c972426 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -1,4 +1,4 @@ -// Package sm2 handle shangmi sm2 algorithm and its curve implementation +// Package sm2 handle shangmi sm2 digital signature and public key encryption algorithm and its curve implementation package sm2 // Further references: @@ -33,9 +33,9 @@ import ( const ( uncompressed byte = 0x04 compressed02 byte = 0x02 - compressed03 byte = 0x03 - mixed06 byte = 0x06 - mixed07 byte = 0x07 + compressed03 byte = compressed02 | 0x01 + hybrid06 byte = 0x06 + hybrid07 byte = hybrid06 | 0x01 ) // A invertible implements fast inverse in GF(N). @@ -62,8 +62,8 @@ const ( MarshalUncompressed pointMarshalMode = iota //MarshalCompressed compressed mashal mode MarshalCompressed - //MarshalMixed mixed mashal mode - MarshalMixed + //MarshalHybrid hybrid mashal mode + MarshalHybrid ) type ciphertextSplicingOrder byte @@ -105,9 +105,9 @@ func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte switch mode { case MarshalCompressed: return elliptic.MarshalCompressed(curve, x, y) - case MarshalMixed: + case MarshalHybrid: buffer := elliptic.Marshal(curve, x, y) - buffer[0] = byte(y.Bit(0)) | mixed06 + buffer[0] = byte(y.Bit(0)) | hybrid06 return buffer default: return elliptic.Marshal(curve, x, y) @@ -123,39 +123,39 @@ func toBytes(curve elliptic.Curve, value *big.Int) []byte { func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) { if len(bytes) < 1+(curve.Params().BitSize/8) { - return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes)) + return nil, nil, 0, fmt.Errorf("sm2: invalid bytes length %d", len(bytes)) } format := bytes[0] byteLen := (curve.Params().BitSize + 7) >> 3 switch format { - case uncompressed, mixed06, mixed07: // what's the mixed format purpose? + case uncompressed, hybrid06, hybrid07: // what's the hybrid format purpose? if len(bytes) < 1+byteLen*2 { - return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes)) + return nil, nil, 0, fmt.Errorf("sm2: invalid point uncompressed/hybrid form bytes length %d", len(bytes)) } data := make([]byte, 1+byteLen*2) data[0] = uncompressed copy(data[1:], bytes[1:1+byteLen*2]) x, y := elliptic.Unmarshal(curve, data) if x == nil || y == nil { - return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name) + return nil, nil, 0, fmt.Errorf("sm2: point is not on curve %s", curve.Params().Name) } return x, y, 1 + byteLen*2, nil case compressed02, compressed03: if len(bytes) < 1+byteLen { - return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes)) + return nil, nil, 0, fmt.Errorf("sm2: invalid point compressed form bytes length %d", len(bytes)) } // Make sure it's NIST curve or SM2 P-256 curve if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) { // y² = x³ - 3x + b, prime curves x, y := elliptic.UnmarshalCompressed(curve, bytes[:1+byteLen]) if x == nil || y == nil { - return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name) + return nil, nil, 0, fmt.Errorf("sm2: point is not on curve %s", curve.Params().Name) } return x, y, 1 + byteLen, nil } - return nil, nil, 0, fmt.Errorf("unsupport bytes format %d, curve %s", format, curve.Params().Name) + return nil, nil, 0, fmt.Errorf("sm2: unsupport point form %d, curve %s", format, curve.Params().Name) } - return nil, nil, 0, fmt.Errorf("unknown bytes format %d", format) + return nil, nil, 0, fmt.Errorf("sm2: unknown point form %d", format) } var defaultEncrypterOpts = &EncrypterOpts{ENCODING_PLAIN, MarshalUncompressed, C1C3C2} @@ -201,7 +201,7 @@ func (*SM2SignerOption) HashFunc() crypto.Hash { // FromECPrivateKey convert an ecdsa private key to SM2 private key. func (priv *PrivateKey) FromECPrivateKey(key *ecdsa.PrivateKey) (*PrivateKey, error) { if key.Curve != P256() { - return nil, errors.New("SM2: it's NOT a sm2 curve private key") + return nil, errors.New("sm2: it's NOT a sm2 curve private key") } priv.PrivateKey = *key return priv, nil @@ -346,7 +346,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter } //A3, requirement is to check if h*P is infinite point, h is 1 if pub.X.Sign() == 0 && pub.Y.Sign() == 0 { - return nil, errors.New("SM2: invalid public key") + return nil, errors.New("sm2: invalid public key") } for { //A1, generate random k @@ -368,7 +368,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter if !success { kdfCount++ if kdfCount > maxRetryLimit { - return nil, fmt.Errorf("SM2: A5, failed to calculate valid t, tried %v times", kdfCount) + return nil, fmt.Errorf("sm2: A5, failed to calculate valid t, tried %v times", kdfCount) } continue } @@ -430,7 +430,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error msgLen := len(c2) t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) if !success { - return nil, errors.New("SM2: invalid cipher text") + return nil, errors.New("sm2: invalid cipher text") } //B5, calculate msg = c2 ^ t @@ -441,7 +441,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error u := calculateC3(curve, x2, y2, msg) for i := 0; i < sm3.Size; i++ { if c3[i] != u[i] { - return nil, errors.New("SM2: invalid hash value") + return nil, errors.New("sm2: invalid hash value") } } return msg, nil @@ -460,7 +460,7 @@ func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte, } ciphertextLen := len(ciphertext) if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size { - return nil, errors.New("SM2: invalid ciphertext length") + return nil, errors.New("sm2: invalid ciphertext length") } curve := priv.Curve // B1, get C1, and check C1 @@ -496,7 +496,7 @@ func unmarshalASN1Ciphertext(ciphertext []byte) (*big.Int, *big.Int, []byte, []b !inner.ReadASN1Bytes(&c3, asn1.OCTET_STRING) || !inner.ReadASN1Bytes(&c2, asn1.OCTET_STRING) || !inner.Empty() { - return nil, nil, nil, nil, errors.New("SM2: invalid asn1 format ciphertext") + return nil, nil, nil, nil, errors.New("sm2: invalid asn1 format ciphertext") } return x1, y1, c2, c3, nil } @@ -523,12 +523,12 @@ func ASN1Ciphertext2Plain(ciphertext []byte, opts *EncrypterOpts) ([]byte, error // PlainCiphertext2ASN1 utility method to convert plain encoding ciphertext to ASN.1 encoding format func PlainCiphertext2ASN1(ciphertext []byte, from ciphertextSplicingOrder) ([]byte, error) { if ciphertext[0] == 0x30 { - return nil, errors.New("SM2: invalid plain encoding ciphertext") + return nil, errors.New("sm2: invalid plain encoding ciphertext") } curve := P256() ciphertextLen := len(ciphertext) if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size { - return nil, errors.New("SM2: invalid ciphertext length") + return nil, errors.New("sm2: invalid ciphertext length") } // get C1, and check C1 x1, y1, c3Start, err := bytes2Point(curve, ciphertext) @@ -556,7 +556,7 @@ func AdjustCiphertextSplicingOrder(ciphertext []byte, from, to ciphertextSplicin } ciphertextLen := len(ciphertext) if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size { - return nil, errors.New("SM2: invalid ciphertext length") + return nil, errors.New("sm2: invalid ciphertext length") } // get C1, and check C1 @@ -741,7 +741,7 @@ func CalculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) { func calculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) { uidLen := len(uid) if uidLen >= 0x2000 { - return nil, errors.New("the uid is too long") + return nil, errors.New("sm2: the uid is too long") } entla := uint16(uidLen) << 3 md := sm3.New() @@ -895,5 +895,5 @@ var zeroReader = &zr{} // IsSM2PublicKey check if given public key is a SM2 public key or not func IsSM2PublicKey(publicKey interface{}) bool { pub, ok := publicKey.(*ecdsa.PublicKey) - return ok && strings.EqualFold(P256().Params().Name, pub.Curve.Params().Name) + return ok && pub.Curve == P256() } diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go index e0f0666..3dfa8ed 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_test.go @@ -208,8 +208,8 @@ func Test_encryptDecrypt(t *testing.T) { t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) } - // mixed mode - encrypterOpts = NewPlainEncrypterOpts(MarshalMixed, C1C3C2) + // hybrid mode + encrypterOpts = NewPlainEncrypterOpts(MarshalHybrid, C1C3C2) ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts) if err != nil { t.Fatalf("encrypt failed %v", err)