mirror of
https://github.com/emmansun/gmsm.git
synced 2025-04-26 12:16:20 +08:00
MAGIC - sm2, basic implementation
This commit is contained in:
parent
4d7305a6f6
commit
be62e3a042
156
sm2/sm2.go
Normal file
156
sm2/sm2.go
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
package sm2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"gmsm/sm3"
|
||||||
|
"io"
|
||||||
|
"math/big"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
Uncompressed byte = 0x04
|
||||||
|
Compressed_02 byte = 0x02
|
||||||
|
Compressed_03 byte = 0x03
|
||||||
|
Mixed_06 byte = 0x06
|
||||||
|
Mixed_07 byte = 0x07
|
||||||
|
)
|
||||||
|
|
||||||
|
///////////////// below code ship from golan crypto/ecdsa ////////////////////
|
||||||
|
var one = new(big.Int).SetInt64(1)
|
||||||
|
|
||||||
|
// randFieldElement returns a random element of the field underlying the given
|
||||||
|
// curve using the procedure given in [NSA] A.2.1.
|
||||||
|
func randFieldElement(c elliptic.Curve, rand io.Reader) (k *big.Int, err error) {
|
||||||
|
params := c.Params()
|
||||||
|
b := make([]byte, params.BitSize/8+8)
|
||||||
|
_, err = io.ReadFull(rand, b)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
k = new(big.Int).SetBytes(b)
|
||||||
|
n := new(big.Int).Sub(params.N, one)
|
||||||
|
k.Mod(k, n)
|
||||||
|
k.Add(k, one)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////////////////////
|
||||||
|
func kdf(z []byte, len int) ([]byte, bool) {
|
||||||
|
limit := (len + sm3.Size - 1) / sm3.Size
|
||||||
|
sm3Hasher := sm3.New()
|
||||||
|
var countBytes [4]byte
|
||||||
|
var ct uint32 = 1
|
||||||
|
k := make([]byte, len+sm3.Size-1)
|
||||||
|
for i := 0; i < limit; i++ {
|
||||||
|
binary.BigEndian.PutUint32(countBytes[:], ct)
|
||||||
|
sm3Hasher.Write(z)
|
||||||
|
sm3Hasher.Write(countBytes[:])
|
||||||
|
copy(k[i*sm3.Size:], sm3Hasher.Sum(nil))
|
||||||
|
ct++
|
||||||
|
sm3Hasher.Reset()
|
||||||
|
}
|
||||||
|
for i := 0; i < len; i++ {
|
||||||
|
if k[i] != 0 {
|
||||||
|
return k[:len], true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return k, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func calculateC3(curve elliptic.Curve, x2, y2 *big.Int, msg []byte) []byte {
|
||||||
|
hasher := sm3.New()
|
||||||
|
hasher.Write(toBytes(curve, x2))
|
||||||
|
hasher.Write(msg)
|
||||||
|
hasher.Write(toBytes(curve, y2))
|
||||||
|
return hasher.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encrypt sm2 encrypt implementation
|
||||||
|
func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte) ([]byte, error) {
|
||||||
|
curve := pub.Curve
|
||||||
|
msgLen := len(msg)
|
||||||
|
for {
|
||||||
|
//A1, generate random k
|
||||||
|
k, err := randFieldElement(curve, random)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
//A2, calculate C1 = k * G
|
||||||
|
x1, y1 := curve.ScalarBaseMult(k.Bytes())
|
||||||
|
c1 := point2CompressedBytes(curve, x1, y1)
|
||||||
|
|
||||||
|
//A3, skipped
|
||||||
|
//A4, calculate k * P (point of Public Key)
|
||||||
|
x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
|
||||||
|
|
||||||
|
//A5, calculate t=KDF(x2||y2, klen)
|
||||||
|
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
||||||
|
if !success {
|
||||||
|
fmt.Println("A5, failed to get valid t")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
//A6, C2 = M + t;
|
||||||
|
c2 := make([]byte, msgLen)
|
||||||
|
for i := 0; i < msgLen; i++ {
|
||||||
|
c2[i] = msg[i] ^ t[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
//A7, C3 = hash(x2||M||y2)
|
||||||
|
c3 := calculateC3(curve, x2, y2, msg)
|
||||||
|
|
||||||
|
return append(append(c1, c2...), c3...), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrypt sm2 decrypt implementation
|
||||||
|
func Decrypt(priv *ecdsa.PrivateKey, ciphertext []byte) ([]byte, error) {
|
||||||
|
ciphertextLen := len(ciphertext)
|
||||||
|
if ciphertextLen <= 1+(priv.Params().BitSize/8)+sm3.Size {
|
||||||
|
return nil, errors.New("invalid ciphertext length")
|
||||||
|
}
|
||||||
|
curve := priv.Curve
|
||||||
|
// B1, get C1, and check C1
|
||||||
|
x1, y1, c2Start, err := bytes2Point(curve, ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !curve.IsOnCurve(x1, y1) {
|
||||||
|
return nil, fmt.Errorf("point c1 is not on curve %s", curve.Params().Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
//B2 is ignored
|
||||||
|
//B3, calculate x2, y2
|
||||||
|
x2, y2 := curve.ScalarMult(x1, y1, priv.D.Bytes())
|
||||||
|
|
||||||
|
//B4, calculate t=KDF(x2||y2, klen)
|
||||||
|
c2 := ciphertext[c2Start : ciphertextLen-sm3.Size]
|
||||||
|
msgLen := len(c2)
|
||||||
|
t, success := kdf(append(toBytes(curve, x2), toBytes(curve, y2)...), msgLen)
|
||||||
|
if !success {
|
||||||
|
return nil, errors.New("invalid cipher text")
|
||||||
|
}
|
||||||
|
|
||||||
|
//B5, calculate msg = c2 ^ t
|
||||||
|
msg := make([]byte, msgLen)
|
||||||
|
for i := 0; i < msgLen; i++ {
|
||||||
|
msg[i] = c2[i] ^ t[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
//B6, calculate hash and compare it
|
||||||
|
c3 := ciphertext[ciphertextLen-sm3.Size:]
|
||||||
|
u := calculateC3(curve, x2, y2, msg)
|
||||||
|
for i := 0; i < sm3.Size; i++ {
|
||||||
|
if c3[i] != u[i] {
|
||||||
|
return nil, errors.New("invalid hash value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return msg, nil
|
||||||
|
}
|
57
sm2/sm2_test.go
Normal file
57
sm2/sm2_test.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package sm2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"math/big"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_kdf(t *testing.T) {
|
||||||
|
x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16)
|
||||||
|
y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16)
|
||||||
|
|
||||||
|
expected := "006e30dae231b071dfad8aa379e90264491603"
|
||||||
|
|
||||||
|
result, success := kdf(append(x2.Bytes(), y2.Bytes()...), 19)
|
||||||
|
if !success {
|
||||||
|
t.Fatalf("failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
resultStr := hex.EncodeToString(result)
|
||||||
|
|
||||||
|
if expected != resultStr {
|
||||||
|
t.Fatalf("expected %s, real value %s", expected, resultStr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_encryptDecrypt(t *testing.T) {
|
||||||
|
priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
plainText string
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
{"less than 32", "encryption standard"},
|
||||||
|
{"equals 32", "encryption standard encryption "},
|
||||||
|
{"long than 32", "encryption standard encryption standard"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encrypt failed %v", err)
|
||||||
|
}
|
||||||
|
plaintext, err := Decrypt(priv, ciphertext)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decrypt failed %v", err)
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(string(plaintext), tt.plainText) {
|
||||||
|
t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
120
sm2/util.go
Normal file
120
sm2/util.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
package sm2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/elliptic"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/big"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var zero = new(big.Int).SetInt64(0)
|
||||||
|
|
||||||
|
func toBytes(curve elliptic.Curve, value *big.Int) []byte {
|
||||||
|
bytes := value.Bytes()
|
||||||
|
byteLen := (curve.Params().BitSize + 7) >> 3
|
||||||
|
if byteLen == len(bytes) {
|
||||||
|
return bytes
|
||||||
|
}
|
||||||
|
result := make([]byte, byteLen)
|
||||||
|
copy(result[byteLen-len(bytes):], bytes)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func point2UncompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
|
||||||
|
return elliptic.Marshal(curve, x, y)
|
||||||
|
}
|
||||||
|
|
||||||
|
func point2CompressedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
|
||||||
|
buffer := make([]byte, (curve.Params().BitSize+7)>>3+1)
|
||||||
|
copy(buffer[1:], toBytes(curve, x))
|
||||||
|
if getLastBitOfY(x, y) > 0 {
|
||||||
|
buffer[0] = Compressed_03
|
||||||
|
} else {
|
||||||
|
buffer[0] = Compressed_02
|
||||||
|
}
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func point2MixedBytes(curve elliptic.Curve, x, y *big.Int) []byte {
|
||||||
|
buffer := elliptic.Marshal(curve, x, y)
|
||||||
|
if getLastBitOfY(x, y) > 0 {
|
||||||
|
buffer[0] = Mixed_07
|
||||||
|
} else {
|
||||||
|
buffer[0] = Mixed_06
|
||||||
|
}
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func getLastBitOfY(x, y *big.Int) uint {
|
||||||
|
if x.Cmp(zero) == 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return y.Bit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func toPointXY(bytes []byte) *big.Int {
|
||||||
|
return new(big.Int).SetBytes(bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func calculatePrimeCurveY(curve elliptic.Curve, x *big.Int) (*big.Int, error) {
|
||||||
|
x3 := new(big.Int).Mul(x, x)
|
||||||
|
x3.Mul(x3, x)
|
||||||
|
|
||||||
|
threeX := new(big.Int).Lsh(x, 1)
|
||||||
|
threeX.Add(threeX, x)
|
||||||
|
|
||||||
|
x3.Sub(x3, threeX)
|
||||||
|
x3.Add(x3, curve.Params().B)
|
||||||
|
x3.Mod(x3, curve.Params().P)
|
||||||
|
y := x3.ModSqrt(x3, curve.Params().P)
|
||||||
|
|
||||||
|
if y == nil {
|
||||||
|
return nil, errors.New("can't calculate y based on x")
|
||||||
|
}
|
||||||
|
return y, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func bytes2Point(curve elliptic.Curve, bytes []byte) (*big.Int, *big.Int, int, error) {
|
||||||
|
if len(bytes) < 1+(curve.Params().BitSize/8) {
|
||||||
|
return nil, nil, 0, fmt.Errorf("invalid bytes length %d", len(bytes))
|
||||||
|
}
|
||||||
|
format := bytes[0]
|
||||||
|
byteLen := (curve.Params().BitSize + 7) >> 3
|
||||||
|
switch format {
|
||||||
|
case Uncompressed:
|
||||||
|
if len(bytes) < 1+byteLen*2 {
|
||||||
|
return nil, nil, 0, fmt.Errorf("invalid uncompressed bytes length %d", len(bytes))
|
||||||
|
}
|
||||||
|
x := toPointXY(bytes[1 : 1+byteLen])
|
||||||
|
y := toPointXY(bytes[1+byteLen : 1+byteLen*2])
|
||||||
|
return x, y, 1 + byteLen*2, nil
|
||||||
|
case Compressed_02, Compressed_03:
|
||||||
|
if len(bytes) < 1+byteLen {
|
||||||
|
return nil, nil, 0, fmt.Errorf("invalid compressed bytes length %d", len(bytes))
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(curve.Params().Name, "P-") {
|
||||||
|
// y² = x³ - 3x + b
|
||||||
|
x := toPointXY(bytes[1 : 1+byteLen])
|
||||||
|
y, err := calculatePrimeCurveY(curve, x)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if (getLastBitOfY(x, y) > 0 && format == Compressed_02) || (getLastBitOfY(x, y) == 0 && format == Compressed_03) {
|
||||||
|
y.Sub(curve.Params().P, y)
|
||||||
|
}
|
||||||
|
return x, y, 1 + byteLen, nil
|
||||||
|
}
|
||||||
|
return nil, nil, 0, fmt.Errorf("unsupport bytes format %d, curve %s", format, curve.Params().Name)
|
||||||
|
case Mixed_06, Mixed_07:
|
||||||
|
// what's the mixed format purpose?
|
||||||
|
if len(bytes) < 1+byteLen*2 {
|
||||||
|
return nil, nil, 0, fmt.Errorf("invalid mixed bytes length %d", len(bytes))
|
||||||
|
}
|
||||||
|
x := toPointXY(bytes[1 : 1+byteLen])
|
||||||
|
y := toPointXY(bytes[1+byteLen : 1+byteLen*2])
|
||||||
|
return x, y, 1 + byteLen*2, nil
|
||||||
|
}
|
||||||
|
return nil, nil, 0, fmt.Errorf("unknown bytes format %d", format)
|
||||||
|
}
|
79
sm2/util_test.go
Normal file
79
sm2/util_test.go
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
package sm2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/elliptic"
|
||||||
|
"encoding/hex"
|
||||||
|
"math/big"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_toBytes(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
value string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
{"less than 32", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "00d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"},
|
||||||
|
{"equals 32", args{"58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
v, _ := new(big.Int).SetString(tt.args.value, 16)
|
||||||
|
if got := toBytes(elliptic.P256(), v); !reflect.DeepEqual(hex.EncodeToString(got), tt.want) {
|
||||||
|
t.Errorf("toBytes() = %v, want %v", hex.EncodeToString(got), tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_getLastBitOfY(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
y string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want uint
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
{"0", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, 0},
|
||||||
|
{"1", args{"d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865ff"}, 1},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
y, _ := new(big.Int).SetString(tt.args.y, 16)
|
||||||
|
if got := getLastBitOfY(y, y); got != tt.want {
|
||||||
|
t.Errorf("getLastBitOfY() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_toPointXY(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
bytes string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
// TODO: Add test cases.
|
||||||
|
{"has zero padding", args{"00d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"},
|
||||||
|
{"no zero padding", args{"58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"}, "58d20d27d0632957f8028c1e024f6b02edf23102a566c932ae8bd613a8e865fe"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
bytes, _ := hex.DecodeString(tt.args.bytes)
|
||||||
|
expectedInt, _ := new(big.Int).SetString(tt.want, 16)
|
||||||
|
if got := toPointXY(bytes); !reflect.DeepEqual(got, expectedInt) {
|
||||||
|
t.Errorf("toPointXY() = %v, want %v", got, expectedInt)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@ -7,10 +7,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Size the size of a SM3 checksum in bytes.
|
// Size the size of a SM3 checksum in bytes.
|
||||||
const Size = 32
|
const Size int = 32
|
||||||
|
|
||||||
// BlockSize the blocksize of SM3 in bytes.
|
// BlockSize the blocksize of SM3 in bytes.
|
||||||
const BlockSize = 64
|
const BlockSize int = 64
|
||||||
|
|
||||||
const (
|
const (
|
||||||
chunk = 64
|
chunk = 64
|
||||||
|
Loading…
x
Reference in New Issue
Block a user