From 01063b1ff71dee91a48c4e81122beed101ddfdb6 Mon Sep 17 00:00:00 2001 From: Sun Yimin Date: Wed, 18 May 2022 15:33:33 +0800 Subject: [PATCH] use golang sdk as much as possible --- sm2/sm2.go | 8 +++--- sm2/util.go | 65 ++++++------------------------------------------ sm2/util_test.go | 24 ------------------ 3 files changed, 13 insertions(+), 84 deletions(-) diff --git a/sm2/sm2.go b/sm2/sm2.go index d380e39..fbdd9c7 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -103,11 +103,13 @@ func NewPlainDecrypterOpts(splicingOrder ciphertextSplicingOrder) *DecrypterOpts func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte { switch mode { case MarshalCompressed: - return point2CompressedBytes(curve, x, y) + return elliptic.MarshalCompressed(curve, x, y) case MarshalMixed: - return point2MixedBytes(curve, x, y) + buffer := elliptic.Marshal(curve, x, y) + buffer[0] = byte(y.Bit(0)) | mixed06 + return buffer default: - return point2UncompressedBytes(curve, x, y) + return elliptic.Marshal(curve, x, y) } } diff --git a/sm2/util.go b/sm2/util.go index b546ab5..927ec44 100644 --- a/sm2/util.go +++ b/sm2/util.go @@ -2,7 +2,6 @@ package sm2 import ( "crypto/elliptic" - "errors" "fmt" "io" "math/big" @@ -11,55 +10,12 @@ import ( ) func toBytes(curve elliptic.Curve, value *big.Int) []byte { - bytes := value.Bytes() byteLen := (curve.Params().BitSize + 7) >> 3 - if byteLen == len(bytes) { - return bytes - } result := make([]byte, byteLen) - copy(result[byteLen-len(bytes):], bytes) + value.FillBytes(result[:]) return result } -func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte { - return elliptic.Marshal(curve, x, y) -} - -func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte { - buffer := make([]byte, (curve.Params().BitSize+7)>>3+1) - copy(buffer[1:], toBytes(curve, x)) - buffer[0] = byte(y.Bit(0)) | compressed02 - return buffer -} - -func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte { - buffer := elliptic.Marshal(curve, x, y) - buffer[0] = byte(y.Bit(0)) | mixed06 - return buffer -} - -func toPointXY(bytes []byte) *big.Int { - return new(big.Int).SetBytes(bytes) -} - -func calculatePrimeCurveY(curve elliptic.Curve, x *big.Int) (*big.Int, error) { - x3 := new(big.Int).Mul(x, x) - x3.Mul(x3, x) - - threeX := new(big.Int).Lsh(x, 1) - threeX.Add(threeX, x) - - x3.Sub(x3, threeX) - x3.Add(x3, curve.Params().B) - x3.Mod(x3, curve.Params().P) - y := x3.ModSqrt(x3, curve.Params().P) - - if y == nil { - return nil, errors.New("can't calculate y based on x") - } - return y, nil -} - func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) { if len(bytes) < 1+(curve.Params().BitSize/8) { return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes)) @@ -71,9 +27,11 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e if len(bytes) < 1+byteLen*2 { return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes)) } - x := toPointXY(bytes[1 : 1+byteLen]) - y := toPointXY(bytes[1+byteLen : 1+byteLen*2]) - if !curve.IsOnCurve(x, y) { + data := make([]byte, 1+byteLen*2) + data[0] = uncompressed + copy(data[1:], bytes[1:1+byteLen*2]) + x, y := elliptic.Unmarshal(curve, data) + if x == nil || y == nil { return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name) } return x, y, 1 + byteLen*2, nil @@ -84,15 +42,8 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e // Make sure it's NIST curve or SM2 P-256 curve if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) { // y² = x³ - 3x + b, prime curves - x := toPointXY(bytes[1 : 1+byteLen]) - y, err := calculatePrimeCurveY(curve, x) - if err != nil { - return nil, nil, 0, err - } - if byte(y.Bit(0)) != bytes[0]&1 { - y.Neg(y).Mod(y, curve.Params().P) - } - if !curve.IsOnCurve(x, y) { + x, y := elliptic.UnmarshalCompressed(curve, bytes[:1+byteLen]) + if x == nil || y == nil { return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name) } return x, y, 1 + byteLen, nil diff --git a/sm2/util_test.go b/sm2/util_test.go index 6828f24..1376d3b 100644 --- a/sm2/util_test.go +++ b/sm2/util_test.go @@ -30,27 +30,3 @@ func Test_toBytes(t *testing.T) { }) } } - -func Test_toPointXY(t *testing.T) { - type args struct { - bytes string - } - tests := []struct { - name string - args args - want string - }{ - // TODO: Add test cases. - {"has zero padding", args{"00d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, - {"no zero padding", args{"58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - bytes, _ := hex.DecodeString(tt.args.bytes) - expectedInt, _ := new(big.Int).SetString(tt.want, 16) - if got := toPointXY(bytes); !reflect.DeepEqual(got, expectedInt) { - t.Errorf("toPointXY() = %v, want %v", got, expectedInt) - } - }) - } -}