refactoring, align error message pattern

This commit is contained in:
Sun Yimin 2022-05-27 17:46:14 +08:00 committed by GitHub
parent 255b3d3e7e
commit bb0f4f7996
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 30 deletions

View File

@ -1,4 +1,4 @@
// Package sm2 handle shangmi sm2 algorithm and its curve implementation // Package sm2 handle shangmi sm2 digital signature and public key encryption algorithm and its curve implementation
package sm2 package sm2
// Further references: // Further references:
@ -33,9 +33,9 @@ import (
const ( const (
uncompressed byte = 0x04 uncompressed byte = 0x04
compressed02 byte = 0x02 compressed02 byte = 0x02
compressed03 byte = 0x03 compressed03 byte = compressed02 | 0x01
mixed06 byte = 0x06 hybrid06 byte = 0x06
mixed07 byte = 0x07 hybrid07 byte = hybrid06 | 0x01
) )
// A invertible implements fast inverse in GF(N). // A invertible implements fast inverse in GF(N).
@ -62,8 +62,8 @@ const (
MarshalUncompressed pointMarshalMode = iota MarshalUncompressed pointMarshalMode = iota
//MarshalCompressed compressed mashal mode //MarshalCompressed compressed mashal mode
MarshalCompressed MarshalCompressed
//MarshalMixed mixed mashal mode //MarshalHybrid hybrid mashal mode
MarshalMixed MarshalHybrid
) )
type ciphertextSplicingOrder byte type ciphertextSplicingOrder byte
@ -105,9 +105,9 @@ func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte
switch mode { switch mode {
case MarshalCompressed: case MarshalCompressed:
return elliptic.MarshalCompressed(curve, x, y) return elliptic.MarshalCompressed(curve, x, y)
case MarshalMixed: case MarshalHybrid:
buffer := elliptic.Marshal(curve, x, y) buffer := elliptic.Marshal(curve, x, y)
buffer[0] = byte(y.Bit(0)) | mixed06 buffer[0] = byte(y.Bit(0)) | hybrid06
return buffer return buffer
default: default:
return elliptic.Marshal(curve, x, y) return elliptic.Marshal(curve, x, y)
@ -123,39 +123,39 @@ func toBytes(curve elliptic.Curve, value *big.Int) []byte {
func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) { func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) {
if len(bytes) < 1+(curve.Params().BitSize/8) { if len(bytes) < 1+(curve.Params().BitSize/8) {
return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes)) return nil, nil, 0, fmt.Errorf("sm2: invalid bytes length %d", len(bytes))
} }
format := bytes[0] format := bytes[0]
byteLen := (curve.Params().BitSize + 7) >> 3 byteLen := (curve.Params().BitSize + 7) >> 3
switch format { switch format {
case uncompressed, mixed06, mixed07: // what's the mixed format purpose? case uncompressed, hybrid06, hybrid07: // what's the hybrid format purpose?
if len(bytes) < 1+byteLen*2 { if len(bytes) < 1+byteLen*2 {
return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes)) return nil, nil, 0, fmt.Errorf("sm2: invalid point uncompressed/hybrid form bytes length %d", len(bytes))
} }
data := make([]byte, 1+byteLen*2) data := make([]byte, 1+byteLen*2)
data[0] = uncompressed data[0] = uncompressed
copy(data[1:], bytes[1:1+byteLen*2]) copy(data[1:], bytes[1:1+byteLen*2])
x, y := elliptic.Unmarshal(curve, data) x, y := elliptic.Unmarshal(curve, data)
if x == nil || y == nil { if x == nil || y == nil {
return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name) return nil, nil, 0, fmt.Errorf("sm2: point is not on curve %s", curve.Params().Name)
} }
return x, y, 1 + byteLen*2, nil return x, y, 1 + byteLen*2, nil
case compressed02, compressed03: case compressed02, compressed03:
if len(bytes) < 1+byteLen { if len(bytes) < 1+byteLen {
return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes)) return nil, nil, 0, fmt.Errorf("sm2: invalid point compressed form bytes length %d", len(bytes))
} }
// Make sure it's NIST curve or SM2 P-256 curve // Make sure it's NIST curve or SM2 P-256 curve
if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) { if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) {
// y² = x³ - 3x + b, prime curves // y² = x³ - 3x + b, prime curves
x, y := elliptic.UnmarshalCompressed(curve, bytes[:1+byteLen]) x, y := elliptic.UnmarshalCompressed(curve, bytes[:1+byteLen])
if x == nil || y == nil { if x == nil || y == nil {
return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name) return nil, nil, 0, fmt.Errorf("sm2: point is not on curve %s", curve.Params().Name)
} }
return x, y, 1 + byteLen, nil return x, y, 1 + byteLen, nil
} }
return nil, nil, 0, fmt.Errorf("unsupport bytes format %d, curve %s", format, curve.Params().Name) return nil, nil, 0, fmt.Errorf("sm2: unsupport point form %d, curve %s", format, curve.Params().Name)
} }
return nil, nil, 0, fmt.Errorf("unknown bytes format %d", format) return nil, nil, 0, fmt.Errorf("sm2: unknown point form %d", format)
} }
var defaultEncrypterOpts = &EncrypterOpts{ENCODING_PLAIN, MarshalUncompressed, C1C3C2} var defaultEncrypterOpts = &EncrypterOpts{ENCODING_PLAIN, MarshalUncompressed, C1C3C2}
@ -201,7 +201,7 @@ func (*SM2SignerOption) HashFunc() crypto.Hash {
// FromECPrivateKey convert an ecdsa private key to SM2 private key. // FromECPrivateKey convert an ecdsa private key to SM2 private key.
func (priv *PrivateKey) FromECPrivateKey(key *ecdsa.PrivateKey) (*PrivateKey, error) { func (priv *PrivateKey) FromECPrivateKey(key *ecdsa.PrivateKey) (*PrivateKey, error) {
if key.Curve != P256() { if key.Curve != P256() {
return nil, errors.New("SM2: it's NOT a sm2 curve private key") return nil, errors.New("sm2: it's NOT a sm2 curve private key")
} }
priv.PrivateKey = *key priv.PrivateKey = *key
return priv, nil return priv, nil
@ -346,7 +346,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter
} }
//A3, requirement is to check if h*P is infinite point, h is 1 //A3, requirement is to check if h*P is infinite point, h is 1
if pub.X.Sign() == 0 && pub.Y.Sign() == 0 { if pub.X.Sign() == 0 && pub.Y.Sign() == 0 {
return nil, errors.New("SM2: invalid public key") return nil, errors.New("sm2: invalid public key")
} }
for { for {
//A1, generate random k //A1, generate random k
@ -368,7 +368,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter
if !success { if !success {
kdfCount++ kdfCount++
if kdfCount > maxRetryLimit { if kdfCount > maxRetryLimit {
return nil, fmt.Errorf("SM2: A5, failed to calculate valid t, tried %v times", kdfCount) return nil, fmt.Errorf("sm2: A5, failed to calculate valid t, tried %v times", kdfCount)
} }
continue continue
} }
@ -430,7 +430,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error
msgLen := len(c2) msgLen := len(c2)
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen) t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
if !success { if !success {
return nil, errors.New("SM2: invalid cipher text") return nil, errors.New("sm2: invalid cipher text")
} }
//B5, calculate msg = c2 ^ t //B5, calculate msg = c2 ^ t
@ -441,7 +441,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error
u := calculateC3(curve, x2, y2, msg) u := calculateC3(curve, x2, y2, msg)
for i := 0; i < sm3.Size; i++ { for i := 0; i < sm3.Size; i++ {
if c3[i] != u[i] { if c3[i] != u[i] {
return nil, errors.New("SM2: invalid hash value") return nil, errors.New("sm2: invalid hash value")
} }
} }
return msg, nil return msg, nil
@ -460,7 +460,7 @@ func decrypt(priv *PrivateKey, ciphertext []byte, opts *DecrypterOpts) ([]byte,
} }
ciphertextLen := len(ciphertext) ciphertextLen := len(ciphertext)
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size { if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
return nil, errors.New("SM2: invalid ciphertext length") return nil, errors.New("sm2: invalid ciphertext length")
} }
curve := priv.Curve curve := priv.Curve
// B1, get C1, and check C1 // B1, get C1, and check C1
@ -496,7 +496,7 @@ func unmarshalASN1Ciphertext(ciphertext []byte) (*big.Int, *big.Int, []byte, []b
!inner.ReadASN1Bytes(&c3, asn1.OCTET_STRING) || !inner.ReadASN1Bytes(&c3, asn1.OCTET_STRING) ||
!inner.ReadASN1Bytes(&c2, asn1.OCTET_STRING) || !inner.ReadASN1Bytes(&c2, asn1.OCTET_STRING) ||
!inner.Empty() { !inner.Empty() {
return nil, nil, nil, nil, errors.New("SM2: invalid asn1 format ciphertext") return nil, nil, nil, nil, errors.New("sm2: invalid asn1 format ciphertext")
} }
return x1, y1, c2, c3, nil return x1, y1, c2, c3, nil
} }
@ -523,12 +523,12 @@ func ASN1Ciphertext2Plain(ciphertext []byte, opts *EncrypterOpts) ([]byte, error
// PlainCiphertext2ASN1 utility method to convert plain encoding ciphertext to ASN.1 encoding format // PlainCiphertext2ASN1 utility method to convert plain encoding ciphertext to ASN.1 encoding format
func PlainCiphertext2ASN1(ciphertext []byte, from ciphertextSplicingOrder) ([]byte, error) { func PlainCiphertext2ASN1(ciphertext []byte, from ciphertextSplicingOrder) ([]byte, error) {
if ciphertext[0] == 0x30 { if ciphertext[0] == 0x30 {
return nil, errors.New("SM2: invalid plain encoding ciphertext") return nil, errors.New("sm2: invalid plain encoding ciphertext")
} }
curve := P256() curve := P256()
ciphertextLen := len(ciphertext) ciphertextLen := len(ciphertext)
if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size { if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size {
return nil, errors.New("SM2: invalid ciphertext length") return nil, errors.New("sm2: invalid ciphertext length")
} }
// get C1, and check C1 // get C1, and check C1
x1, y1, c3Start, err := bytes2Point(curve, ciphertext) x1, y1, c3Start, err := bytes2Point(curve, ciphertext)
@ -556,7 +556,7 @@ func AdjustCiphertextSplicingOrder(ciphertext []byte, from, to ciphertextSplicin
} }
ciphertextLen := len(ciphertext) ciphertextLen := len(ciphertext)
if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size { if ciphertextLen <= 1+(curve.Params().BitSize/8)+sm3.Size {
return nil, errors.New("SM2: invalid ciphertext length") return nil, errors.New("sm2: invalid ciphertext length")
} }
// get C1, and check C1 // get C1, and check C1
@ -741,7 +741,7 @@ func CalculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) {
func calculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) { func calculateZA(pub *ecdsa.PublicKey, uid []byte) ([]byte, error) {
uidLen := len(uid) uidLen := len(uid)
if uidLen >= 0x2000 { if uidLen >= 0x2000 {
return nil, errors.New("the uid is too long") return nil, errors.New("sm2: the uid is too long")
} }
entla := uint16(uidLen) << 3 entla := uint16(uidLen) << 3
md := sm3.New() md := sm3.New()
@ -895,5 +895,5 @@ var zeroReader = &zr{}
// IsSM2PublicKey check if given public key is a SM2 public key or not // IsSM2PublicKey check if given public key is a SM2 public key or not
func IsSM2PublicKey(publicKey interface{}) bool { func IsSM2PublicKey(publicKey interface{}) bool {
pub, ok := publicKey.(*ecdsa.PublicKey) pub, ok := publicKey.(*ecdsa.PublicKey)
return ok && strings.EqualFold(P256().Params().Name, pub.Curve.Params().Name) return ok && pub.Curve == P256()
} }

View File

@ -208,8 +208,8 @@ func Test_encryptDecrypt(t *testing.T) {
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
} }
// mixed mode // hybrid mode
encrypterOpts = NewPlainEncrypterOpts(MarshalMixed, C1C3C2) encrypterOpts = NewPlainEncrypterOpts(MarshalHybrid, C1C3C2)
ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts) ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), encrypterOpts)
if err != nil { if err != nil {
t.Fatalf("encrypt failed %v", err) t.Fatalf("encrypt failed %v", err)