// Mirror of https://github.com/emmansun/gmsm.git
// (synced 2025-04-22 02:06:18 +08:00).
// Go source file: 144 lines, 3.7 KiB.
package sm2
|
|
|
|
import (
|
|
"crypto/elliptic"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
// zero is a shared *big.Int holding the value 0, used for comparisons.
// It is package-level state: never pass it as the receiver of a mutating
// big.Int method.
var zero = new(big.Int)
|
|
|
|
// toBytes returns value as a big-endian byte string that is exactly
// (BitSize+7)/8 bytes long for the given curve, left-padding with zeros when
// the minimal encoding is shorter.
//
// value must fit in that width. A wider value makes the padding slice
// expression go negative and panic.
func toBytes(curve elliptic.Curve, value *big.Int) []byte {
	width := (curve.Params().BitSize + 7) / 8
	raw := value.Bytes()
	if len(raw) == width {
		// Already full width; value.Bytes() returns a fresh slice, so it is
		// safe to hand it out directly.
		return raw
	}

	padded := make([]byte, width)
	copy(padded[width-len(raw):], raw)
	return padded
}
|
|
|
|
// point2UncompressedBytes encodes (x, y) in the uncompressed point form
// produced by elliptic.Marshal: the uncompressed prefix byte followed by both
// coordinates, each padded to the curve's field width.
func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
	encoded := elliptic.Marshal(curve, x, y)
	return encoded
}
|
|
|
|
func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
|
|
buffer := make([]byte, (curve.Params().BitSize+7)>>3+1)
|
|
copy(buffer[1:], toBytes(curve, x))
|
|
if getLastBitOfY(x, y) > 0 {
|
|
buffer[0] = compressed03
|
|
} else {
|
|
buffer[0] = compressed02
|
|
}
|
|
return buffer
|
|
}
|
|
|
|
func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
|
|
buffer := elliptic.Marshal(curve, x, y)
|
|
if getLastBitOfY(x, y) > 0 {
|
|
buffer[0] = mixed07
|
|
} else {
|
|
buffer[0] = mixed06
|
|
}
|
|
return buffer
|
|
}
|
|
|
|
// getLastBitOfY returns the least-significant bit of y (0 or 1). The point
// encoders use it to choose a prefix byte, and bytes2Point uses it to pick
// the matching square root.
//
// When x is zero it always returns 0, regardless of y.
func getLastBitOfY(x, y *big.Int) uint {
	if x.Sign() != 0 {
		return y.Bit(0)
	}
	return 0
}
|
|
|
|
// toPointXY interprets bytes as an unsigned big-endian integer and returns it
// as a newly allocated *big.Int. An empty slice yields 0.
func toPointXY(bytes []byte) *big.Int {
	var coord big.Int
	coord.SetBytes(bytes)
	return &coord
}
|
|
|
|
// calculatePrimeCurveY recovers a y coordinate for x on a short Weierstrass
// curve with a = -3, i.e. y² = x³ - 3x + b (mod p), using the curve's B and P.
// Of the two roots y and p - y it returns whichever ModSqrt yields, so the
// caller must select the one with the wanted parity.
//
// An error is returned when x is not a canonical field element
// (x < 0 or x >= p), or when x³ - 3x + b is not a square mod p, meaning no
// point with that x exists.
func calculatePrimeCurveY(curve elliptic.Curve, x *big.Int) (*big.Int, error) {
	params := curve.Params()

	// Without this check, an x >= p would be silently reduced mod p.
	// Several different encodings would then decode to the same point, and
	// the caller would get back an unreduced x coordinate.
	if x.Sign() < 0 || x.Cmp(params.P) >= 0 {
		return nil, errors.New("x is out of range [0, p)")
	}

	// rhs = x³ - 3x + b (mod p)
	rhs := new(big.Int).Mul(x, x)
	rhs.Mul(rhs, x)

	threeX := new(big.Int).Lsh(x, 1)
	threeX.Add(threeX, x)

	rhs.Sub(rhs, threeX)
	rhs.Add(rhs, params.B)
	rhs.Mod(rhs, params.P)

	// ModSqrt returns nil when rhs is not a quadratic residue mod p.
	y := new(big.Int).ModSqrt(rhs, params.P)
	if y == nil {
		return nil, errors.New("can't calculate y based on x")
	}
	return y, nil
}
|
|
|
|
func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) {
|
|
if len(bytes) < 1+(curve.Params().BitSize/8) {
|
|
return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes))
|
|
}
|
|
format := bytes[0]
|
|
byteLen := (curve.Params().BitSize + 7) >> 3
|
|
switch format {
|
|
case uncompressed, mixed06, mixed07: // what's the mixed format purpose?
|
|
if len(bytes) < 1+byteLen*2 {
|
|
return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes))
|
|
}
|
|
x := toPointXY(bytes[1 : 1+byteLen])
|
|
y := toPointXY(bytes[1+byteLen : 1+byteLen*2])
|
|
if !curve.IsOnCurve(x, y) {
|
|
return nil, nil, 0, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name)
|
|
}
|
|
return x, y, 1 + byteLen*2, nil
|
|
case compressed02, compressed03:
|
|
if len(bytes) < 1+byteLen {
|
|
return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes))
|
|
}
|
|
if strings.HasPrefix(curve.Params().Name, "P-") || strings.EqualFold(curve.Params().Name, p256.CurveParams.Name) {
|
|
// y² = x³ - 3x + b, prime curves
|
|
x := toPointXY(bytes[1 : 1+byteLen])
|
|
y, err := calculatePrimeCurveY(curve, x)
|
|
if err != nil {
|
|
return nil, nil, 0, err
|
|
}
|
|
|
|
if (getLastBitOfY(x, y) > 0 && format == compressed02) || (getLastBitOfY(x, y) == 0 && format == compressed03) {
|
|
y.Sub(curve.Params().P, y)
|
|
}
|
|
return x, y, 1 + byteLen, nil
|
|
}
|
|
return nil, nil, 0, fmt.Errorf("unsupport bytes format %d, curve %s", format, curve.Params().Name)
|
|
}
|
|
return nil, nil, 0, fmt.Errorf("unknown bytes format %d", format)
|
|
}
|
|
|
|
// closedChan is created and closed exactly once, guarded by closedChanOnce.
// Once closed it is always ready to receive, which lets maybeReadByte build a
// select with two simultaneously-ready cases.
var (
	closedChanOnce sync.Once
	closedChan     chan struct{}
)

// maybeReadByte consumes a single byte from r with roughly 50% probability.
//
// It exists so that callers cannot come to rely on an exact mapping from a
// random stream to their output. For example, they should not assume that
// generating a key is deterministic for a given stream. Streams that yield
// the same byte everywhere, such as an all-zero reader used in tests, are
// unaffected.
//
// Errors and short reads from r are deliberately ignored: this is a
// best-effort perturbation, not real input.
func maybeReadByte(r io.Reader) {
	closedChanOnce.Do(func() {
		ch := make(chan struct{})
		close(ch)
		closedChan = ch
	})

	// Both cases are always ready because closedChan is closed, and select
	// picks uniformly at random among ready cases. That gives the ~50% split.
	select {
	case <-closedChan:
		var one [1]byte
		r.Read(one[:])
	case <-closedChan:
	}
}
|