mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-26 20:26:19 +08:00
use golang sdk as much as possible
This commit is contained in:
parent
fc1411a702
commit
01063b1ff7
@ -103,11 +103,13 @@ func NewPlainDecrypterOpts(splicingOrder ciphertextSplicingOrder) *DecrypterOpts
|
|||||||
func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte {
|
func (mode pointMarshalMode) mashal(curve elliptic.Curve, x, y *big.Int) []byte {
|
||||||
switch mode {
|
switch mode {
|
||||||
case MarshalCompressed:
|
case MarshalCompressed:
|
||||||
return point2CompressedBytes(curve, x, y)
|
return elliptic.MarshalCompressed(curve, x, y)
|
||||||
case MarshalMixed:
|
case MarshalMixed:
|
||||||
return point2MixedBytes(curve, x, y)
|
buffer := elliptic.Marshal(curve, x, y)
|
||||||
|
buffer[0] = byte(y.Bit(0)) | mixed06
|
||||||
|
return buffer
|
||||||
default:
|
default:
|
||||||
return point2UncompressedBytes(curve, x, y)
|
return elliptic.Marshal(curve, x, y)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
65
sm2/util.go
65
sm2/util.go
@ -2,7 +2,6 @@ package sm2
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/elliptic"
|
"crypto/elliptic"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/big"
|
"math/big"
|
||||||
@ -11,55 +10,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func toBytes(curve elliptic.Curve, value *big.Int) []byte {
|
func toBytes(curve elliptic.Curve, value *big.Int) []byte {
|
||||||
bytes := value.Bytes()
|
|
||||||
byteLen := (curve.Params().BitSize + 7) >> 3
|
byteLen := (curve.Params().BitSize + 7) >> 3
|
||||||
if byteLen == len(bytes) {
|
|
||||||
return bytes
|
|
||||||
}
|
|
||||||
result := make([]byte, byteLen)
|
result := make([]byte, byteLen)
|
||||||
copy(result[byteLen-len(bytes):], bytes)
|
value.FillBytes(result[:])
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
|
|
||||||
return elliptic.Marshal(curve, x, y)
|
|
||||||
}
|
|
||||||
|
|
||||||
func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
|
|
||||||
buffer := make([]byte, (curve.Params().BitSize+7)>>3+1)
|
|
||||||
copy(buffer[1:], toBytes(curve, x))
|
|
||||||
buffer[0] = byte(y.Bit(0)) | compressed02
|
|
||||||
return buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
|
|
||||||
buffer := elliptic.Marshal(curve, x, y)
|
|
||||||
buffer[0] = byte(y.Bit(0)) | mixed06
|
|
||||||
return buffer
|
|
||||||
}
|
|
||||||
|
|
||||||
func toPointXY(bytes []byte) *big.Int {
|
|
||||||
return new(big.Int).SetBytes(bytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func calculatePrimeCurveY(curve elliptic.Curve, x *big.Int) (*big.Int, error) {
|
|
||||||
x3 := new(big.Int).Mul(x, x)
|
|
||||||
x3.Mul(x3, x)
|
|
||||||
|
|
||||||
threeX := new(big.Int).Lsh(x, 1)
|
|
||||||
threeX.Add(threeX, x)
|
|
||||||
|
|
||||||
x3.Sub(x3, threeX)
|
|
||||||
x3.Add(x3, curve.Params().B)
|
|
||||||
x3.Mod(x3, curve.Params().P)
|
|
||||||
y := x3.ModSqrt(x3, curve.Params().P)
|
|
||||||
|
|
||||||
if y == nil {
|
|
||||||
return nil, errors.New("can't calculate y based on x")
|
|
||||||
}
|
|
||||||
return y, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) {
|
func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) {
|
||||||
if len(bytes) < 1+(curve.Params().BitSize/8) {
|
if len(bytes) < 1+(curve.Params().BitSize/8) {
|
||||||
return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes))
|
return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes))
|
||||||
@ -71,9 +27,11 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e
|
|||||||
if len(bytes) < 1+byteLen*2 {
|
if len(bytes) < 1+byteLen*2 {
|
||||||
return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes))
|
return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes))
|
||||||
}
|
}
|
||||||
x := toPointXY(bytes[1 : 1+byteLen])
|
data := make([]byte, 1+byteLen*2)
|
||||||
y := toPointXY(bytes[1+byteLen : 1+byteLen*2])
|
data[0] = uncompressed
|
||||||
if !curve.IsOnCurve(x, y) {
|
copy(data[1:], bytes[1:1+byteLen*2])
|
||||||
|
x, y := elliptic.Unmarshal(curve, data)
|
||||||
|
if x == nil || y == nil {
|
||||||
return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name)
|
return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name)
|
||||||
}
|
}
|
||||||
return x, y, 1 + byteLen*2, nil
|
return x, y, 1 + byteLen*2, nil
|
||||||
@ -84,15 +42,8 @@ func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, e
|
|||||||
// Make sure it's NIST curve or SM2 P-256 curve
|
// Make sure it's NIST curve or SM2 P-256 curve
|
||||||
if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) {
|
if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) {
|
||||||
// y² = x³ - 3x + b, prime curves
|
// y² = x³ - 3x + b, prime curves
|
||||||
x := toPointXY(bytes[1 : 1+byteLen])
|
x, y := elliptic.UnmarshalCompressed(curve, bytes[:1+byteLen])
|
||||||
y, err := calculatePrimeCurveY(curve, x)
|
if x == nil || y == nil {
|
||||||
if err != nil {
|
|
||||||
return nil, nil, 0, err
|
|
||||||
}
|
|
||||||
if byte(y.Bit(0)) != bytes[0]&1 {
|
|
||||||
y.Neg(y).Mod(y, curve.Params().P)
|
|
||||||
}
|
|
||||||
if !curve.IsOnCurve(x, y) {
|
|
||||||
return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name)
|
return nil, nil, 0, fmt.Errorf("point is not on curve %s", curve.Params().Name)
|
||||||
}
|
}
|
||||||
return x, y, 1 + byteLen, nil
|
return x, y, 1 + byteLen, nil
|
||||||
|
@ -30,27 +30,3 @@ func Test_toBytes(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_toPointXY(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
bytes string
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want string
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
{"has zero padding", args{"00d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"},
|
|
||||||
{"no zero padding", args{"58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
bytes, _ := hex.DecodeString(tt.args.bytes)
|
|
||||||
expectedInt, _ := new(big.Int).SetString(tt.want, 16)
|
|
||||||
if got := toPointXY(bytes); !reflect.DeepEqual(got, expectedInt) {
|
|
||||||
t.Errorf("toPointXY() = %v, want %v", got, expectedInt)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user