diff --git a/sm2/sm2.go b/sm2/sm2.go index 2594e49..aba13ac 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -136,7 +136,7 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e data := make([]byte, 1+byteLen*2) data[0] = uncompressed copy(data[1:], bytes[1:1+byteLen*2]) - x, y := elliptic.Unmarshal(curve, data) + x, y := sm2ec.Unmarshal(curve, data) if x == nil || y == nil { return nil, nil, 0, fmt.Errorf("sm2: point is not on curve %s", curve.Params().Name) } @@ -148,7 +148,7 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e // Make sure it's NIST curve or SM2 P-256 curve if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, sm2ec.P256().Params().Name) { // y² = x³ - 3x + b, prime curves - x, y := elliptic.UnmarshalCompressed(curve, bytes[:1+byteLen]) + x, y := sm2ec.UnmarshalCompressed(curve, bytes[:1+byteLen]) if x == nil || y == nil { return nil, nil, 0, fmt.Errorf("sm2: point is not on curve %s", curve.Params().Name) } diff --git a/sm2/sm2ec/elliptic.go b/sm2/sm2ec/elliptic.go index e34e6c5..b0c92a6 100644 --- a/sm2/sm2ec/elliptic.go +++ b/sm2/sm2ec/elliptic.go @@ -34,3 +34,29 @@ func P256() elliptic.Curve { initonce.Do(initAll) return sm2p256 } + +// Since golang 1.19 +// unmarshaler is implemented by curves with their own constant-time Unmarshal. +// There isn't an equivalent interface for Marshal/MarshalCompressed because +// that doesn't involve any mathematical operations, only FillBytes and Bit. +type unmarshaler interface { + Unmarshal([]byte) (x, y *big.Int) + UnmarshalCompressed([]byte) (x, y *big.Int) +} + +func Unmarshal(curve elliptic.Curve, data []byte) (x, y *big.Int) { + if c, ok := curve.(unmarshaler); ok { + return c.Unmarshal(data) + } + return elliptic.Unmarshal(curve, data) +} + +// UnmarshalCompressed converts a point, serialized by MarshalCompressed, into +// an x, y pair. It is an error if the point is not in compressed form, is not +// on the curve, or is the point at infinity. On error, x = nil. +func UnmarshalCompressed(curve elliptic.Curve, data []byte) (x, y *big.Int) { + if c, ok := curve.(unmarshaler); ok { + return c.UnmarshalCompressed(data) + } + return elliptic.UnmarshalCompressed(curve, data) +} diff --git a/sm2/sm2ec/elliptic_test.go b/sm2/sm2ec/elliptic_test.go index 760ec2f..a084c25 100644 --- a/sm2/sm2ec/elliptic_test.go +++ b/sm2/sm2ec/elliptic_test.go @@ -11,32 +11,6 @@ import ( var _ = elliptic.P256() // force NIST P curves init, avoid panic when we invoke generic implementation's method -// unmarshaler is implemented by curves with their own constant-time Unmarshal. -// Since golang 1.19 -// There isn't an equivalent interface for Marshal/MarshalCompressed because -// that doesn't involve any mathematical operations, only FillBytes and Bit. -type unmarshaler interface { - Unmarshal([]byte) (x, y *big.Int) - UnmarshalCompressed([]byte) (x, y *big.Int) -} - -func unmarshal(curve elliptic.Curve, data []byte) (x, y *big.Int) { - if c, ok := curve.(unmarshaler); ok { - return c.Unmarshal(data) - } - return elliptic.Unmarshal(curve, data) -} - -// UnmarshalCompressed converts a point, serialized by MarshalCompressed, into -// an x, y pair. It is an error if the point is not in compressed form, is not -// on the curve, or is the point at infinity. On error, x = nil. -func unmarshalCompressed(curve elliptic.Curve, data []byte) (x, y *big.Int) { - if c, ok := curve.(unmarshaler); ok { - return c.UnmarshalCompressed(data) - } - return elliptic.UnmarshalCompressed(curve, data) -} - // genericParamsForCurve returns the dereferenced CurveParams for // the specified curve. This is used to avoid the logic for // upgrading a curve to its specific implementation, forcing @@ -87,7 +61,7 @@ func TestOffCurve(t *testing.T) { x.FillBytes(b[1 : 1+byteLen]) y.FillBytes(b[1+byteLen : 1+2*byteLen]) - x1, y1 := unmarshal(curve, b) + x1, y1 := Unmarshal(curve, b) if x1 != nil || y1 != nil { t.Errorf("unmarshaling a point not on the curve succeeded") } @@ -152,18 +126,18 @@ func testInfinity(t *testing.T, curve elliptic.Curve) { t.Errorf("IsOnCurve(∞) == true") } - if xx, yy := unmarshal(curve, elliptic.Marshal(curve, x0, y0)); xx != nil || yy != nil { + if xx, yy := Unmarshal(curve, elliptic.Marshal(curve, x0, y0)); xx != nil || yy != nil { t.Errorf("Unmarshal(Marshal(∞)) did not return an error") } // We don't test UnmarshalCompressed(MarshalCompressed(∞)) because there are // two valid points with x = 0. - if xx, yy := unmarshal(curve, []byte{0x00}); xx != nil || yy != nil { + if xx, yy := Unmarshal(curve, []byte{0x00}); xx != nil || yy != nil { t.Errorf("Unmarshal(∞) did not return an error") } byteLen := (curve.Params().BitSize + 7) / 8 buf := make([]byte, byteLen*2+1) buf[0] = 4 // Uncompressed format. - if xx, yy := unmarshal(curve, buf); xx != nil || yy != nil { + if xx, yy := Unmarshal(curve, buf); xx != nil || yy != nil { t.Errorf("Unmarshal((0,0)) did not return an error") } } @@ -175,7 +149,7 @@ func TestMarshal(t *testing.T) { t.Fatal(err) } serialized := elliptic.Marshal(curve, x, y) - xx, yy := unmarshal(curve, serialized) + xx, yy := Unmarshal(curve, serialized) if xx == nil { t.Fatal("failed to unmarshal") } @@ -256,7 +230,7 @@ func TestMarshalCompressed(t *testing.T) { t.Run("Invalid", func(t *testing.T) { data, _ := hex.DecodeString("02fd4bf61763b46581fd9174d623516cf3c81edd40e29ffa2777fb6cb0ae3ce535") - X, Y := unmarshalCompressed(P256(), data) + X, Y := UnmarshalCompressed(P256(), data) if X != nil || Y != nil { t.Error("expected an error for invalid encoding") } @@ -284,7 +258,7 @@ func testMarshalCompressed(t *testing.T, curve elliptic.Curve, x, y *big.Int, wa t.Errorf("got unexpected MarshalCompressed result: got %x, want %x", got, want) } - X, Y := unmarshalCompressed(curve, got) + X, Y := UnmarshalCompressed(curve, got) if X == nil || Y == nil { t.Fatalf("UnmarshalCompressed failed unexpectedly") } @@ -354,7 +328,7 @@ func BenchmarkMarshalUnmarshal(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { buf := elliptic.Marshal(curve, x, y) - xx, yy := unmarshal(curve, buf) + xx, yy := Unmarshal(curve, buf) if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { b.Error("Unmarshal output different from Marshal input") } @@ -364,7 +338,7 @@ func BenchmarkMarshalUnmarshal(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { buf := elliptic.MarshalCompressed(curve, x, y) - xx, yy := unmarshalCompressed(curve, buf) + xx, yy := UnmarshalCompressed(curve, buf) if xx.Cmp(x) != 0 || yy.Cmp(y) != 0 { b.Error("Unmarshal output different from Marshal input") }