diff --git a/sm2/sm2_test.go b/sm2/sm2_test.go index 891b60d..873f0cb 100644 --- a/sm2/sm2_test.go +++ b/sm2/sm2_test.go @@ -82,6 +82,13 @@ func Test_encryptDecrypt_ASN1(t *testing.T) { if !reflect.DeepEqual(string(plaintext), tt.plainText) { t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) } + plaintext, err = priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } }) } } @@ -201,6 +208,13 @@ func Test_encryptDecrypt(t *testing.T) { if !reflect.DeepEqual(string(plaintext), tt.plainText) { t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) } + plaintext, err = Decrypt(priv, ciphertext) + if err != nil { + t.Fatalf("decrypt failed %v", err) + } + if !reflect.DeepEqual(string(plaintext), tt.plainText) { + t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText) + } }) } } diff --git a/sm9/g1.go b/sm9/g1.go index e76c846..61d8a96 100644 --- a/sm9/g1.go +++ b/sm9/g1.go @@ -100,23 +100,41 @@ func (e *G1) Marshal() []byte { // Each value is a 256-bit number. const numBytes = 256 / 8 + ret := make([]byte, numBytes*2) + + e.fillBytes(ret) + return ret +} + +// Marshal converts e to a byte slice with prefix +func (e *G1) MarshalUncompressed() []byte { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + + ret := make([]byte, numBytes*2+1) + ret[0] = 4 + + e.fillBytes(ret[1:]) + return ret +} + +func (e *G1) fillBytes(buffer []byte) { + const numBytes = 256 / 8 + if e.p == nil { e.p = &curvePoint{} } e.p.MakeAffine() - ret := make([]byte, numBytes*2) if e.p.IsInfinity() { - return ret + return } temp := &gfP{} montDecode(temp, &e.p.x) - temp.Marshal(ret) + temp.Marshal(buffer) montDecode(temp, &e.p.y) - temp.Marshal(ret[numBytes:]) - - return ret + temp.Marshal(buffer[numBytes:]) } // Unmarshal sets e to the result of converting the output of Marshal back into diff --git a/sm9/g2.go b/sm9/g2.go index 877362c..3c24312 100644 --- a/sm9/g2.go +++ b/sm9/g2.go @@ -12,6 +12,9 @@ type G2 struct { p *twistPoint } +//Gen2 is the generator of G2. +var Gen2 = &G2{twistGen} + // RandomG2 returns x and g₂ˣ where x is a random, non-zero number read from r. func RandomG2(r io.Reader) (*big.Int, *G2, error) { k, err := randomK(r) @@ -76,28 +79,43 @@ func (e *G2) Set(a *G2) *G2 { func (e *G2) Marshal() []byte { // Each value is a 256-bit number. const numBytes = 256 / 8 + ret := make([]byte, numBytes*4) + e.fillBytes(ret) + return ret +} + +// Marshal converts e into a byte slice with prefix +func (e *G2) MarshalUncompressed() []byte { + // Each value is a 256-bit number. + const numBytes = 256 / 8 + ret := make([]byte, numBytes*4+1) + ret[0] = 4 + e.fillBytes(ret[1:]) + return ret +} + +func (e *G2) fillBytes(buffer []byte) { + // Each value is a 256-bit number. + const numBytes = 256 / 8 if e.p == nil { e.p = &twistPoint{} } e.p.MakeAffine() - ret := make([]byte, numBytes*4) if e.p.IsInfinity() { - return ret + return } temp := &gfP{} montDecode(temp, &e.p.x.x) - temp.Marshal(ret) + temp.Marshal(buffer) montDecode(temp, &e.p.x.y) - temp.Marshal(ret[numBytes:]) + temp.Marshal(buffer[numBytes:]) montDecode(temp, &e.p.y.x) - temp.Marshal(ret[2*numBytes:]) + temp.Marshal(buffer[2*numBytes:]) montDecode(temp, &e.p.y.y) - temp.Marshal(ret[3*numBytes:]) - - return ret + temp.Marshal(buffer[3*numBytes:]) } // Unmarshal sets e to the result of converting the output of Marshal back into diff --git a/sm9/sm9.go b/sm9/sm9.go index c8b02c5..a7e2a2e 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -1,4 +1,467 @@ // Package sm9 handle shangmi sm9 algorithm and its curves and pairing implementation package sm9 -// TODO: implement SM9 algorithm based on basic curves, G1/G2/GT and r-ate pairing implementation. +import ( + "crypto" + goSubtle "crypto/subtle" + "encoding/binary" + "errors" + "fmt" + "io" + "math/big" + + "github.com/emmansun/gmsm/internal/xor" + "github.com/emmansun/gmsm/sm3" + "golang.org/x/crypto/cryptobyte" + "golang.org/x/crypto/cryptobyte/asn1" +) + +var bigOne = big.NewInt(1) + +type hashMode byte + +const ( + // hashmode used in h1: 0x01 + H1 hashMode = 1 + iota + // hashmode used in h2: 0x02 + H2 +) + +type encryptType byte + +const ( + ENC_TYPE_XOR encryptType = 0 + ENC_TYPE_ECB encryptType = 1 + ENC_TYPE_CBC encryptType = 2 + ENC_TYPE_OFB encryptType = 4 + ENC_TYPE_CFB encryptType = 8 +) + +//hash implements H1(Z,n) or H2(Z,n) in sm9 algorithm. +func hash(z []byte, h hashMode) *big.Int { + md := sm3.New() + var ha [64]byte + var countBytes [4]byte + var ct uint32 = 1 + + for i := 0; i < 2; i++ { + binary.BigEndian.PutUint32(countBytes[:], ct) + md.Write([]byte{byte(h)}) + md.Write(z) + md.Write(countBytes[:]) + copy(ha[i*sm3.Size:], md.Sum(nil)) + ct++ + md.Reset() + } + k := new(big.Int).SetBytes(ha[:40]) + n := new(big.Int).Sub(Order, bigOne) + k.Mod(k, n) + k.Add(k, bigOne) + return k +} + +func hashH1(z []byte) *big.Int { + return hash(z, H1) +} + +func hashH2(z []byte) *big.Int { + return hash(z, H2) +} + +// randFieldElement returns a random element of the order of the given +// curve using the procedure given in FIPS 186-4, Appendix B.5.1. +func randFieldElement(rand io.Reader) (k *big.Int, err error) { + b := make([]byte, 40) // (256 + 64) / 8 + _, err = io.ReadFull(rand, b) + if err != nil { + return + } + + k = new(big.Int).SetBytes(b) + n := new(big.Int).Sub(Order, bigOne) + k.Mod(k, n) + k.Add(k, bigOne) + return +} + +// Sign signs a hash (which should be the result of hashing a larger message) +// using the user dsa key. It returns the signature as a pair of h and s. +func Sign(rand io.Reader, priv *SignPrivateKey, hash []byte) (h *big.Int, s *G1, err error) { + g := Pair(Gen1, priv.SignMasterPublicKey.MasterPublicKey) + var r *big.Int + for { + r, err = randFieldElement(rand) + if err != nil { + return + } + w := new(GT).ScalarMult(g, r) + + var buffer []byte + buffer = append(buffer, hash...) + buffer = append(buffer, w.Marshal()...) + + h = hashH2(buffer) + + l := new(big.Int).Sub(r, h) + l.Mod(l, Order) + + if l.Sign() != 0 { + s = new(G1).ScalarMult(priv.PrivateKey, l) + break + } + } + return +} + +// Sign signs digest with user's DSA key, reading randomness from rand. The opts argument +// is not currently used but, in keeping with the crypto.Signer interface. +func (priv *SignPrivateKey) Sign(rand io.Reader, hash []byte, opts crypto.SignerOpts) ([]byte, error) { + h, s, err := Sign(rand, priv, hash) + if err != nil { + return nil, err + } + + hBytes := make([]byte, 32) + h.FillBytes(hBytes) + + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1OctetString(hBytes) + b.AddASN1BitString(s.MarshalUncompressed()) + }) + return b.Bytes() +} + +// SignASN1 signs a hash (which should be the result of hashing a larger message) +// using the private key, priv. It returns the ASN.1 encoded signature. +func SignASN1(rand io.Reader, priv *SignPrivateKey, hash []byte) ([]byte, error) { + return priv.Sign(rand, hash, nil) +} + +// GenerateUserPublicKey generate user sign public key +func (pub *SignMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *G2 { + var buffer []byte + buffer = append(buffer, uid...) + buffer = append(buffer, hid) + h1 := hashH1(buffer) + p := new(G2).ScalarBaseMult(h1) + p.Add(p, pub.MasterPublicKey) + return p +} + +// Verify verifies the signature in h, s of hash using the master dsa public key and user id, uid and hid. +// Its return value records whether the signature is valid. +func Verify(pub *SignMasterPublicKey, uid []byte, hid byte, hash []byte, h *big.Int, s *G1) bool { + if h.Sign() <= 0 || h.Cmp(Order) >= 0 { + return false + } + if !s.p.IsOnCurve() { + return false + } + g := Pair(Gen1, pub.MasterPublicKey) + t := new(GT).ScalarMult(g, h) + + // user sign public key p generation + p := pub.GenerateUserPublicKey(uid, hid) + + u := Pair(s, p) + w := new(GT).Add(u, t) + + var buffer []byte + buffer = append(buffer, hash...) + buffer = append(buffer, w.Marshal()...) + h2 := hashH2(buffer) + + return h.Cmp(h2) == 0 +} + +// VerifyASN1 verifies the ASN.1 encoded signature, sig, of hash using the +// public key, pub. Its return value records whether the signature is valid. +func VerifyASN1(pub *SignMasterPublicKey, uid []byte, hid byte, hash, sig []byte) bool { + var ( + hBytes []byte + sBytes []byte + inner cryptobyte.String + ) + input := cryptobyte.String(sig) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Bytes(&hBytes, asn1.OCTET_STRING) || + !inner.ReadASN1BitStringAsBytes(&sBytes) || + !inner.Empty() { + return false + } + h := new(big.Int).SetBytes(hBytes) + if sBytes[0] != 4 { + return false + } + s := new(G1) + _, err := s.Unmarshal(sBytes[1:]) + if err != nil { + return false + } + + return Verify(pub, uid, hid, hash, h, s) +} + +// Verify verifies the ASN.1 encoded signature, sig, of hash using the +// public key, pub. Its return value records whether the signature is valid. +func (pub *SignMasterPublicKey) Verify(uid []byte, hid byte, hash, sig []byte) bool { + return VerifyASN1(pub, uid, hid, hash, sig) +} + +func (pub *EncryptMasterPublicKey) GenerateUserPublicKey(uid []byte, hid byte) *G1 { + var buffer []byte + buffer = append(buffer, uid...) + buffer = append(buffer, hid) + h1 := hashH1(buffer) + p := new(G1).ScalarBaseMult(h1) + p.Add(p, pub.MasterPublicKey) + return p +} + +// WrappKey generate and wrapp key wtih reciever's uid and system hid +func WrappKey(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, kLen int) (key []byte, cipher *G1, err error) { + q := pub.GenerateUserPublicKey(uid, hid) + var r *big.Int + var ok bool + for { + r, err = randFieldElement(rand) + if err != nil { + return + } + + cipher = new(G1).ScalarMult(q, r) + + g := Pair(pub.MasterPublicKey, Gen2) + w := new(GT).ScalarMult(g, r) + + var buffer []byte + buffer = append(buffer, cipher.Marshal()...) + buffer = append(buffer, w.Marshal()...) + buffer = append(buffer, uid...) + + key, ok = sm3.Kdf(buffer, kLen) + if ok { + break + } + } + return +} + +// WrappKey wrapp key and marshal the cipher as ASN1 format. +func (pub *EncryptMasterPublicKey) WrappKey(rand io.Reader, uid []byte, hid byte, kLen int) ([]byte, []byte, error) { + key, cipher, err := WrappKey(rand, pub, uid, hid, kLen) + if err != nil { + return nil, nil, err + } + var b cryptobyte.Builder + b.AddASN1BitString(cipher.MarshalUncompressed()) + cipherASN1, err := b.Bytes() + + return key, cipherASN1, err +} + +// WrappKey wrapp key and marshal the result of SM9KeyPackage as ASN1 format. according +// SM9 cryptographic algorithm application specification +func (pub *EncryptMasterPublicKey) WrappKeyASN1(rand io.Reader, uid []byte, hid byte, kLen int) ([]byte, error) { + key, cipher, err := WrappKey(rand, pub, uid, hid, kLen) + if err != nil { + return nil, err + } + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1OctetString(key) + b.AddASN1BitString(cipher.MarshalUncompressed()) + }) + return b.Bytes() +} + +// UnmarshalSM9KeyPackage is an utility to unmarshal SM9KeyPackage +func UnmarshalSM9KeyPackage(der []byte) ([]byte, *G1, error) { + input := cryptobyte.String(der) + var ( + key []byte + cipherBytes []byte + inner cryptobyte.String + ) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Bytes(&key, asn1.OCTET_STRING) || + !inner.ReadASN1BitStringAsBytes(&cipherBytes) || + !inner.Empty() { + return nil, nil, errors.New("sm9: invalid SM9KeyPackage asn.1 data") + } + g := new(G1) + _, err := g.Unmarshal(cipherBytes[1:]) + if err != nil { + return nil, nil, err + } + return key, g, nil +} + +// UnwrappKey unwrapper key from cipher, user id and aligned key length +func UnwrappKey(priv *EncryptPrivateKey, uid []byte, cipher *G1, kLen int) ([]byte, error) { + if !cipher.p.IsOnCurve() { + return nil, errors.New("sm9: invalid cipher, it's NOT on curve") + } + + w := Pair(cipher, priv.PrivateKey) + + var buffer []byte + buffer = append(buffer, cipher.Marshal()...) + buffer = append(buffer, w.Marshal()...) + buffer = append(buffer, uid...) + + key, ok := sm3.Kdf(buffer, kLen) + if !ok { + return nil, errors.New("sm9: invalid cipher") + } + return key, nil +} + +func (priv *EncryptPrivateKey) UnwrappKey(uid, cipherDer []byte, kLen int) ([]byte, error) { + bytes := make([]byte, 64+1) + input := cryptobyte.String(cipherDer) + if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() { + return nil, errors.New("sm9: invalid chipher asn1 data") + } + if bytes[0] != 4 { + return nil, fmt.Errorf("sm9: unsupport curve point marshal format <%v>", bytes[0]) + } + g := new(G1) + _, err := g.Unmarshal(bytes[1:]) + if err != nil { + return nil, err + } + return UnwrappKey(priv, uid, g, kLen) +} + +// Encrypt encrypt plaintext, output ciphertext with format C1||C3||C2 +func Encrypt(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte) ([]byte, error) { + key, cipher, err := WrappKey(rand, pub, uid, hid, len(plaintext)+sm3.Size) + if err != nil { + return nil, err + } + xor.XorBytes(key, key[:len(plaintext)], plaintext) + + hash := sm3.New() + hash.Write(key) + c3 := hash.Sum(nil) + + ciphertext := append(cipher.Marshal(), c3...) + ciphertext = append(ciphertext, key[:len(plaintext)]...) + return ciphertext, nil +} + +// EncryptASN1 encrypt plaintext and output ciphertext with ASN.1 format according +// SM9 cryptographic algorithm application specification +func EncryptASN1(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte) ([]byte, error) { + return pub.Encrypt(rand, uid, hid, plaintext) +} + +// Encrypt encrypt plaintext and output ciphertext with ASN.1 format according +// SM9 cryptographic algorithm application specification +func (pub *EncryptMasterPublicKey) Encrypt(rand io.Reader, uid []byte, hid byte, plaintext []byte) ([]byte, error) { + key, cipher, err := WrappKey(rand, pub, uid, hid, len(plaintext)+sm3.Size) + if err != nil { + return nil, err + } + xor.XorBytes(key, key[:len(plaintext)], plaintext) + + hash := sm3.New() + hash.Write(key) + c3 := hash.Sum(nil) + + var b cryptobyte.Builder + b.AddASN1(asn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1Int64(int64(ENC_TYPE_XOR)) + b.AddASN1BitString(cipher.MarshalUncompressed()) + b.AddASN1OctetString(c3) + b.AddASN1OctetString(key[:len(plaintext)]) + }) + return b.Bytes() +} + +// Decrypt decrypt chipher, ciphertext should be with format C1||C3||C2 +func Decrypt(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error) { + c := &G1{} + c3, err := c.Unmarshal(ciphertext) + if err != nil { + return nil, err + } + + key, err := UnwrappKey(priv, uid, c, len(c3)) + if err != nil { + return nil, err + } + + c2 := c3[sm3.Size:] + + hash := sm3.New() + hash.Write(c2) + hash.Write(key[len(c2):]) + c32 := hash.Sum(nil) + + if goSubtle.ConstantTimeCompare(c3[:sm3.Size], c32) != 1 { + return nil, errors.New("sm9: invalid mac value") + } + + xor.XorBytes(key, c2, key[:len(c2)]) + return key[:len(c2)], nil +} + +// DecryptASN1 decrypt chipher, ciphertext should be with ASN.1 format according +// SM9 cryptographic algorithm application specification +func DecryptASN1(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error) { + if len(ciphertext) <= 32+65 { + return nil, errors.New("sm9: invalid ciphertext length") + } + var ( + encType int + c3Bytes []byte + c1Bytes []byte + c2Bytes []byte + inner cryptobyte.String + ) + input := cryptobyte.String(ciphertext) + if !input.ReadASN1(&inner, asn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Integer(&encType) || + !inner.ReadASN1BitStringAsBytes(&c1Bytes) || + !inner.ReadASN1Bytes(&c3Bytes, asn1.OCTET_STRING) || + !inner.ReadASN1Bytes(&c2Bytes, asn1.OCTET_STRING) || + !inner.Empty() { + return nil, errors.New("sm9: invalid ciphertext asn.1 data") + } + if encType != int(ENC_TYPE_XOR) { + return nil, fmt.Errorf("sm9: does not support this kind of encrypt type <%v> yet", encType) + } + if c1Bytes[0] != 4 { + return nil, fmt.Errorf("sm9: unsupport curve point marshal format <%v>", c1Bytes[0]) + } + c := &G1{} + _, err := c.Unmarshal(c1Bytes[1:]) + if err != nil { + return nil, err + } + + key, err := UnwrappKey(priv, uid, c, len(c2Bytes)+len(c3Bytes)) + if err != nil { + return nil, err + } + if err != nil { + return nil, err + } + + hash := sm3.New() + hash.Write(c2Bytes) + hash.Write(key[len(c2Bytes):]) + c32 := hash.Sum(nil) + + if goSubtle.ConstantTimeCompare(c3Bytes, c32) != 1 { + return nil, errors.New("sm9: invalid mac value") + } + xor.XorBytes(key, c2Bytes, key[:len(c2Bytes)]) + return key[:len(c2Bytes)], nil +} diff --git a/sm9/sm9_key.go b/sm9/sm9_key.go new file mode 100644 index 0000000..2758cf0 --- /dev/null +++ b/sm9/sm9_key.go @@ -0,0 +1,290 @@ +package sm9 + +import ( + "errors" + "io" + "math/big" + + "golang.org/x/crypto/cryptobyte" +) + +type SignMasterPrivateKey struct { + SignMasterPublicKey + D *big.Int +} + +type SignMasterPublicKey struct { + MasterPublicKey *G2 +} + +type SignPrivateKey struct { + PrivateKey *G1 + SignMasterPublicKey +} + +type EncryptMasterPrivateKey struct { + EncryptMasterPublicKey + D *big.Int +} + +type EncryptMasterPublicKey struct { + MasterPublicKey *G1 +} + +type EncryptPrivateKey struct { + PrivateKey *G2 + EncryptMasterPublicKey +} + +// GenerateSignMasterKey generates a master public and private key pair for DSA usage. +func GenerateSignMasterKey(rand io.Reader) (*SignMasterPrivateKey, error) { + k, err := randFieldElement(rand) + if err != nil { + return nil, err + } + + priv := new(SignMasterPrivateKey) + priv.D = k + priv.MasterPublicKey = new(G2).ScalarBaseMult(k) + return priv, nil +} + +// MarshalASN1 marshal sign master private key to asn.1 format data according +// SM9 cryptographic algorithm application specification +func (master *SignMasterPrivateKey) MarshalASN1() ([]byte, error) { + var b cryptobyte.Builder + b.AddASN1BigInt(master.D) + return b.Bytes() +} + +// UnmarshalASN1 unmarsal der data to sign master private key +func (master *SignMasterPrivateKey) UnmarshalASN1(der []byte) error { + input := cryptobyte.String(der) + d := &big.Int{} + if !input.ReadASN1Integer(d) || !input.Empty() { + return errors.New("sm9: invalid sign master key asn1 data") + } + master.D = d + master.MasterPublicKey = new(G2).ScalarBaseMult(d) + return nil +} + +// GenerateUserKey generate an user dsa key. +func (master *SignMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*SignPrivateKey, error) { + var id []byte + id = append(id, uid...) + id = append(id, hid) + + t1 := hashH1(id) + t1.Add(t1, master.D) + if t1.Sign() == 0 { + return nil, errors.New("sm9: need to re-generate sign master private key") + } + t1.ModInverse(t1, Order) + t2 := new(big.Int).Mul(t1, master.D) + t2.Mod(t2, Order) + priv := new(SignPrivateKey) + priv.SignMasterPublicKey = master.SignMasterPublicKey + priv.PrivateKey = new(G1).ScalarBaseMult(t2) + + return priv, nil +} + +// Public returns the public key corresponding to priv. +func (master *SignMasterPrivateKey) Public() *SignMasterPublicKey { + return &master.SignMasterPublicKey +} + +// MarshalASN1 marshal sign master public key to asn.1 format data according +// SM9 cryptographic algorithm application specification +func (pub *SignMasterPublicKey) MarshalASN1() ([]byte, error) { + var b cryptobyte.Builder + b.AddASN1BitString(pub.MasterPublicKey.MarshalUncompressed()) + return b.Bytes() +} + +// UnmarshalASN1 unmarsal der data to sign master public key +func (pub *SignMasterPublicKey) UnmarshalASN1(der []byte) error { + var bytes []byte + input := cryptobyte.String(der) + if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() { + return errors.New("sm9: invalid sign master public key asn1 data") + } + if bytes[0] != 4 { + return errors.New("sm9: invalid prefix of sign master public key") + } + g2 := new(G2) + _, err := g2.Unmarshal(bytes[1:]) + if err != nil { + return err + } + pub.MasterPublicKey = g2 + return nil +} + +// MasterPublic returns the master public key corresponding to priv. +func (priv *SignPrivateKey) MasterPublic() *SignMasterPublicKey { + return &priv.SignMasterPublicKey +} + +// SetMasterPublicKey bind the sign master public key to it. +func (priv *SignPrivateKey) SetMasterPublicKey(pub *SignMasterPublicKey) { + if priv.SignMasterPublicKey.MasterPublicKey == nil { + priv.SignMasterPublicKey = *pub + } +} + +// MarshalASN1 marshal sign user private key to asn.1 format data according +// SM9 cryptographic algorithm application specification +func (priv *SignPrivateKey) MarshalASN1() ([]byte, error) { + var b cryptobyte.Builder + b.AddASN1BitString(priv.PrivateKey.MarshalUncompressed()) + return b.Bytes() +} + +// UnmarshalASN1 unmarsal der data to sign user private key +// Note, priv's SignMasterPublicKey should be handled separately. +func (priv *SignPrivateKey) UnmarshalASN1(der []byte) error { + var bytes []byte + input := cryptobyte.String(der) + if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() { + return errors.New("sm9: invalid sign user private key asn1 data") + } + if bytes[0] != 4 { + return errors.New("sm9: invalid prefix of sign user private key") + } + g := new(G1) + _, err := g.Unmarshal(bytes[1:]) + if err != nil { + return err + } + priv.PrivateKey = g + return nil +} + +// GenerateEncryptMasterKey generates a master public and private key pair for encryption usage. +func GenerateEncryptMasterKey(rand io.Reader) (*EncryptMasterPrivateKey, error) { + k, err := randFieldElement(rand) + if err != nil { + return nil, err + } + + priv := new(EncryptMasterPrivateKey) + priv.D = k + priv.MasterPublicKey = new(G1).ScalarBaseMult(k) + return priv, nil +} + +// GenerateUserKey generate an user key for encryption. +func (master *EncryptMasterPrivateKey) GenerateUserKey(uid []byte, hid byte) (*EncryptPrivateKey, error) { + var id []byte + id = append(id, uid...) + id = append(id, hid) + + t1 := hashH1(id) + t1.Add(t1, master.D) + if t1.Sign() == 0 { + return nil, errors.New("sm9: need to re-generate encrypt master private key") + } + t1.ModInverse(t1, Order) + t2 := new(big.Int).Mul(t1, master.D) + t2.Mod(t2, Order) + + priv := new(EncryptPrivateKey) + priv.EncryptMasterPublicKey = master.EncryptMasterPublicKey + priv.PrivateKey = new(G2).ScalarBaseMult(t2) + + return priv, nil +} + +// Public returns the public key corresponding to priv. +func (master *EncryptMasterPrivateKey) Public() *EncryptMasterPublicKey { + return &master.EncryptMasterPublicKey +} + +// MarshalASN1 marshal encrypt master private key to asn.1 format data according +// SM9 cryptographic algorithm application specification +func (master *EncryptMasterPrivateKey) MarshalASN1() ([]byte, error) { + var b cryptobyte.Builder + b.AddASN1BigInt(master.D) + return b.Bytes() +} + +// UnmarshalASN1 unmarsal der data to encrpt master private key +func (master *EncryptMasterPrivateKey) UnmarshalASN1(der []byte) error { + input := cryptobyte.String(der) + d := &big.Int{} + if !input.ReadASN1Integer(d) || !input.Empty() { + return errors.New("sm9: invalid encrpt master key asn1 data") + } + master.D = d + master.MasterPublicKey = new(G1).ScalarBaseMult(d) + return nil +} + +// MarshalASN1 marshal encrypt master public key to asn.1 format data according +// SM9 cryptographic algorithm application specification +func (pub *EncryptMasterPublicKey) MarshalASN1() ([]byte, error) { + var b cryptobyte.Builder + b.AddASN1BitString(pub.MasterPublicKey.MarshalUncompressed()) + return b.Bytes() +} + +// UnmarshalASN1 unmarsal der data to encrypt master public key +func (pub *EncryptMasterPublicKey) UnmarshalASN1(der []byte) error { + var bytes []byte + input := cryptobyte.String(der) + if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() { + return errors.New("sm9: invalid encrypt master public key asn1 data") + } + if bytes[0] != 4 { + return errors.New("sm9: invalid prefix of encrypt master public key") + } + g := new(G1) + _, err := g.Unmarshal(bytes[1:]) + if err != nil { + return err + } + pub.MasterPublicKey = g + return nil +} + +// MasterPublic returns the master public key corresponding to priv. +func (priv *EncryptPrivateKey) MasterPublic() *EncryptMasterPublicKey { + return &priv.EncryptMasterPublicKey +} + +// SetMasterPublicKey bind the encrypt master public key to it. +func (priv *EncryptPrivateKey) SetMasterPublicKey(pub *EncryptMasterPublicKey) { + if priv.EncryptMasterPublicKey.MasterPublicKey == nil { + priv.EncryptMasterPublicKey = *pub + } +} + +// MarshalASN1 marshal encrypt user private key to asn.1 format data according +// SM9 cryptographic algorithm application specification +func (priv *EncryptPrivateKey) MarshalASN1() ([]byte, error) { + var b cryptobyte.Builder + b.AddASN1BitString(priv.PrivateKey.MarshalUncompressed()) + return b.Bytes() +} + +// UnmarshalASN1 unmarsal der data to encrypt user private key +// Note, priv's EncryptMasterPublicKey should be handled separately. +func (priv *EncryptPrivateKey) UnmarshalASN1(der []byte) error { + var bytes []byte + input := cryptobyte.String(der) + if !input.ReadASN1BitStringAsBytes(&bytes) || !input.Empty() { + return errors.New("sm9: invalid encrypt user private key asn1 data") + } + if bytes[0] != 4 { + return errors.New("sm9: invalid prefix of encrypt user private key") + } + g := new(G2) + _, err := g.Unmarshal(bytes[1:]) + if err != nil { + return err + } + priv.PrivateKey = g + return nil +} diff --git a/sm9/sm9_key_test.go b/sm9/sm9_key_test.go new file mode 100644 index 0000000..10b6ac2 --- /dev/null +++ b/sm9/sm9_key_test.go @@ -0,0 +1,167 @@ +package sm9 + +import ( + "crypto/rand" + "encoding/hex" + "testing" +) + +func TestSignMasterPrivateKeyMarshalASN1(t *testing.T) { + masterKey, err := GenerateSignMasterKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + der, err := masterKey.MarshalASN1() + if err != nil { + t.Fatal(err) + } + masterKey2 := new(SignMasterPrivateKey) + err = masterKey2.UnmarshalASN1(der) + if err != nil { + t.Fatal(err) + } + if masterKey.D.Cmp(masterKey2.D) != 0 { + t.Errorf("expected %v, got %v", hex.EncodeToString(masterKey.D.Bytes()), hex.EncodeToString(masterKey2.D.Bytes())) + } +} + +func TestSignMasterPublicKeyMarshalASN1(t *testing.T) { + masterKey, err := GenerateSignMasterKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + der, err := masterKey.Public().MarshalASN1() + if err != nil { + t.Fatal(err) + } + pub2 := new(SignMasterPublicKey) + err = pub2.UnmarshalASN1(der) + if err != nil { + t.Fatal(err) + } + if masterKey.MasterPublicKey.p.x != pub2.MasterPublicKey.p.x || masterKey.MasterPublicKey.p.y != pub2.MasterPublicKey.p.y || masterKey.MasterPublicKey.p.z != pub2.MasterPublicKey.p.z { + t.Errorf("not same") + } +} + +func TestSignUserPrivateKeyMarshalASN1(t *testing.T) { + masterKey, err := GenerateSignMasterKey(rand.Reader) + uid := []byte("emmansun") + hid := byte(0x01) + if err != nil { + t.Fatal(err) + } + userKey, err := masterKey.GenerateUserKey(uid, hid) + if err != nil { + t.Fatal(err) + } + der, err := userKey.MarshalASN1() + if err != nil { + t.Fatal(err) + } + userKey2 := new(SignPrivateKey) + err = userKey2.UnmarshalASN1(der) + if err != nil { + t.Fatal(err) + } + if userKey.PrivateKey.p.x != userKey2.PrivateKey.p.x || userKey.PrivateKey.p.y != userKey2.PrivateKey.p.y { + t.Errorf("not same") + } +} + +func TestEncryptMasterPrivateKeyMarshalASN1(t *testing.T) { + masterKey, err := GenerateEncryptMasterKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + der, err := masterKey.MarshalASN1() + if err != nil { + t.Fatal(err) + } + masterKey2 := new(SignMasterPrivateKey) + err = masterKey2.UnmarshalASN1(der) + if err != nil { + t.Fatal(err) + } + if masterKey.D.Cmp(masterKey2.D) != 0 { + t.Errorf("expected %v, got %v", hex.EncodeToString(masterKey.D.Bytes()), hex.EncodeToString(masterKey2.D.Bytes())) + } +} + +func TestEncryptMasterPublicKeyMarshalASN1(t *testing.T) { + masterKey, err := GenerateEncryptMasterKey(rand.Reader) + if err != nil { + t.Fatal(err) + } + der, err := masterKey.Public().MarshalASN1() + if err != nil { + t.Fatal(err) + } + pub2 := new(EncryptMasterPublicKey) + err = pub2.UnmarshalASN1(der) + if err != nil { + t.Fatal(err) + } + if masterKey.MasterPublicKey.p.x != pub2.MasterPublicKey.p.x || masterKey.MasterPublicKey.p.y != pub2.MasterPublicKey.p.y { + t.Errorf("not same") + } +} + +func TestEncryptUserPrivateKeyMarshalASN1(t *testing.T) { + masterKey, err := GenerateEncryptMasterKey(rand.Reader) + uid := []byte("emmansun") + hid := byte(0x01) + if err != nil { + t.Fatal(err) + } + userKey, err := masterKey.GenerateUserKey(uid, hid) + if err != nil { + t.Fatal(err) + } + der, err := userKey.MarshalASN1() + if err != nil { + t.Fatal(err) + } + userKey2 := new(EncryptPrivateKey) + err = userKey2.UnmarshalASN1(der) + if err != nil { + t.Fatal(err) + } + if userKey.PrivateKey.p.x != userKey2.PrivateKey.p.x || userKey.PrivateKey.p.y != userKey2.PrivateKey.p.y || userKey.PrivateKey.p.z != userKey2.PrivateKey.p.z { + t.Errorf("not same") + } +} + +func BenchmarkGenerateSignPrivKey(b *testing.B) { + masterKey, err := GenerateSignMasterKey(rand.Reader) + uid := []byte("emmansun") + hid := byte(0x01) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if _, err := masterKey.GenerateUserKey(uid, hid); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkGenerateEncryptPrivKey(b *testing.B) { + masterKey, err := GenerateEncryptMasterKey(rand.Reader) + uid := []byte("emmansun") + hid := byte(0x01) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + if _, err := masterKey.GenerateUserKey(uid, hid); err != nil { + b.Fatal(err) + } + } +} diff --git a/sm9/sm9_test.go b/sm9/sm9_test.go new file mode 100644 index 0000000..660a02f --- /dev/null +++ b/sm9/sm9_test.go @@ -0,0 +1,462 @@ +package sm9 + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "math/big" + "testing" + + "github.com/emmansun/gmsm/internal/xor" + "github.com/emmansun/gmsm/sm3" +) + +func TestHashH1(t *testing.T) { + expected := "2acc468c3926b0bdb2767e99ff26e084de9ced8dbc7d5fbf418027b667862fab" + h := hashH1([]byte{0x41, 0x6c, 0x69, 0x63, 0x65, 0x01}) + if hex.EncodeToString(h.Bytes()) != expected { + t.Errorf("got %v, expected %v", hex.EncodeToString(h.Bytes()), expected) + } +} + +func TestHashH2(t *testing.T) { + expected := "823c4b21e4bd2dfe1ed92c606653e996668563152fc33f55d7bfbb9bd9705adb" + zStr := "4368696E65736520494253207374616E6461726481377B8FDBC2839B4FA2D0E0F8AA6853BBBE9E9C4099608F8612C6078ACD7563815AEBA217AD502DA0F48704CC73CABB3C06209BD87142E14CBD99E8BCA1680F30DADC5CD9E207AEE32209F6C3CA3EC0D800A1A42D33C73153DED47C70A39D2E8EAF5D179A1836B359A9D1D9BFC19F2EFCDB829328620962BD3FDF15F2567F58A543D25609AE943920679194ED30328BB33FD15660BDE485C6B79A7B32B013983F012DB04BA59FE88DB889321CC2373D4C0C35E84F7AB1FF33679BCA575D67654F8624EB435B838CCA77B2D0347E65D5E46964412A096F4150D8C5EDE5440DDF0656FCB663D24731E80292188A2471B8B68AA993899268499D23C89755A1A89744643CEAD40F0965F28E1CD2895C3D118E4F65C9A0E3E741B6DD52C0EE2D25F5898D60848026B7EFB8FCC1B2442ECF0795F8A81CEE99A6248F294C82C90D26BD6A814AAF475F128AEF43A128E37F80154AE6CB92CAD7D1501BAE30F750B3A9BD1F96B08E97997363911314705BFB9A9DBB97F75553EC90FBB2DDAE53C8F68E42" + z, err := hex.DecodeString(zStr) + if err != nil { + t.Fatal(err) + } + h := hashH2(z) + if hex.EncodeToString(h.Bytes()) != expected { + t.Errorf("got %v, expected %v", hex.EncodeToString(h.Bytes()), expected) + } +} + +func TestSign(t *testing.T) { + masterKey, err := GenerateSignMasterKey(rand.Reader) + hashed := []byte("Chinese IBS standard") + uid := []byte("emmansun") + hid := byte(0x01) + if err != nil { + t.Fatal(err) + } + userKey, err := masterKey.GenerateUserKey(uid, hid) + if err != nil { + t.Fatal(err) + } + h, s, err := Sign(rand.Reader, userKey, hashed) + if err != nil { + t.Fatal(err) + } + if !Verify(masterKey.Public(), uid, hid, hashed, h, s) { + t.Errorf("Verify failed") + } + hashed[0] ^= 0xff + if Verify(masterKey.Public(), uid, hid, hashed, h, s) { + t.Errorf("Verify always works!") + } +} + +func TestSignASN1(t *testing.T) { + masterKey, err := GenerateSignMasterKey(rand.Reader) + hashed := []byte("Chinese IBS standard") + uid := []byte("emmansun") + hid := byte(0x01) + if err != nil { + t.Fatal(err) + } + userKey, err := masterKey.GenerateUserKey(uid, hid) + if err != nil { + t.Fatal(err) + } + sig, err := SignASN1(rand.Reader, userKey, hashed) + if err != nil { + t.Fatal(err) + } + if !VerifyASN1(masterKey.Public(), uid, hid, hashed, sig) { + t.Errorf("Verify failed") + } +} + +func TestWrappKey(t *testing.T) { + masterKey, err := GenerateEncryptMasterKey(rand.Reader) + hid := byte(0x01) + uid := []byte("emmansun") + if err != nil { + t.Fatal(err) + } + userKey, err := masterKey.GenerateUserKey(uid, hid) + if err != nil { + t.Fatal(err) + } + key, cipher, err := WrappKey(rand.Reader, masterKey.Public(), uid, hid, 16) + if err != nil { + t.Fatal(err) + } + + key2, err := UnwrappKey(userKey, uid, cipher, 16) + if err != nil { + t.Fatal(err) + } + + if hex.EncodeToString(key) != hex.EncodeToString(key2) { + t.Errorf("expected %v, got %v\n", hex.EncodeToString(key), hex.EncodeToString(key2)) + } +} + +func TestWrappKeyASN1(t *testing.T) { + masterKey, err := GenerateEncryptMasterKey(rand.Reader) + hid := byte(0x01) + uid := []byte("emmansun") + if err != nil { + t.Fatal(err) + } + userKey, err := masterKey.GenerateUserKey(uid, hid) + if err != nil { + t.Fatal(err) + } + key, cipher, err := masterKey.Public().WrappKey(rand.Reader, uid, hid, 16) + if err != nil { + t.Fatal(err) + } + + key2, err := userKey.UnwrappKey(uid, cipher, 16) + if err != nil { + t.Fatal(err) + } + + if hex.EncodeToString(key) != hex.EncodeToString(key2) { + t.Errorf("expected %v, got %v\n", hex.EncodeToString(key), hex.EncodeToString(key2)) + } +} + +func TestUnmarshalSM9KeyPackage(t *testing.T) { + masterKey, err := GenerateEncryptMasterKey(rand.Reader) + hid := byte(0x01) + uid := []byte("emmansun") + if err != nil { + t.Fatal(err) + } + userKey, err := masterKey.GenerateUserKey(uid, hid) + if err != nil { + t.Fatal(err) + } + p, err := masterKey.Public().WrappKeyASN1(rand.Reader, uid, hid, 16) + if err != nil { + t.Fatal(err) + } + + key, cipher, err := UnmarshalSM9KeyPackage(p) + if err != nil { + t.Fatal(err) + } + + key2, err := UnwrappKey(userKey, uid, cipher, 16) + if err != nil { + t.Fatal(err) + } + + if hex.EncodeToString(key) != hex.EncodeToString(key2) { + t.Errorf("expected %v, got %v\n", hex.EncodeToString(key), hex.EncodeToString(key2)) + } +} + +func TestWrappKeySM9Sample(t *testing.T) { + expectedKey := "4ff5cf86d2ad40c8f4bac98d76abdbde0c0e2f0a829d3f911ef5b2bce0695480" + masterKey := new(EncryptMasterPrivateKey) + masterKey.D = bigFromHex("01EDEE3778F441F8DEA3D9FA0ACC4E07EE36C93F9A08618AF4AD85CEDE1C22") + masterKey.MasterPublicKey = new(G1).ScalarBaseMult(masterKey.D) + fmt.Printf("Pub-e=%v\n", hex.EncodeToString(masterKey.MasterPublicKey.Marshal())) + + uid := []byte("Bob") + hid := byte(0x03) + + userKey, err := masterKey.GenerateUserKey(uid, hid) + if err != nil { + t.Fatal(err) + } + fmt.Printf("UserPrivKey=%v\n", hex.EncodeToString(userKey.PrivateKey.Marshal())) + + q := masterKey.Public().GenerateUserPublicKey(uid, hid) + fmt.Printf("Qb=%v\n", hex.EncodeToString(q.Marshal())) + var r *big.Int = bigFromHex("74015F8489C01EF4270456F9E6475BFB602BDE7F33FD482AB4E3684A6722") + + cipher := new(G1).ScalarMult(q, r) + fmt.Printf("C=%v\n", hex.EncodeToString(cipher.Marshal())) + + g := Pair(masterKey.Public().MasterPublicKey, Gen2) + w := new(GT).ScalarMult(g, r) + + var buffer []byte + buffer = append(buffer, cipher.Marshal()...) + buffer = append(buffer, w.Marshal()...) + buffer = append(buffer, uid...) + + key, ok := sm3.Kdf(buffer, 32) + if !ok { + t.Failed() + } + if hex.EncodeToString(key) != expectedKey { + t.Errorf("expected %v, got %v\n", expectedKey, hex.EncodeToString(key)) + } + + key2, err := UnwrappKey(userKey, uid, cipher, 32) + if err != nil { + t.Fatal(err) + } + if hex.EncodeToString(key2) != expectedKey { + t.Errorf("expected %v, got %v\n", expectedKey, hex.EncodeToString(key2)) + } +} + +func TestEncryptSM9Sample(t *testing.T) { + plaintext := []byte("Chinese IBE standard") + expectedCiphertext := "2445471164490618e1ee20528ff1d545b0f14c8bcaa44544f03dab5dac07d8ff42ffca97d57cddc05ea405f2e586feb3a6930715532b8000759f13059ed59ac0ba672387bcd6de5016a158a52bb2e7fc429197bcab70b25afee37a2b9db9f3671b5f5b0e951489682f3e64e1378cdd5da9513b1c" + masterKey := new(EncryptMasterPrivateKey) + masterKey.D = bigFromHex("01EDEE3778F441F8DEA3D9FA0ACC4E07EE36C93F9A08618AF4AD85CEDE1C22") + masterKey.MasterPublicKey = new(G1).ScalarBaseMult(masterKey.D) + fmt.Printf("Pub-e=%v\n", hex.EncodeToString(masterKey.MasterPublicKey.Marshal())) + + uid := []byte("Bob") + hid := byte(0x03) + + userKey, err := masterKey.GenerateUserKey(uid, hid) + if err != nil { + t.Fatal(err) + } + fmt.Printf("UserPrivKey=%v\n", hex.EncodeToString(userKey.PrivateKey.Marshal())) + + q := masterKey.Public().GenerateUserPublicKey(uid, hid) + fmt.Printf("Qb=%v\n", hex.EncodeToString(q.Marshal())) + var r *big.Int = bigFromHex("AAC0541779C8FC45E3E2CB25C12B5D2576B2129AE8BB5EE2CBE5EC9E785C") + + cipher := new(G1).ScalarMult(q, r) + fmt.Printf("C=%v\n", hex.EncodeToString(cipher.Marshal())) + + g := Pair(masterKey.Public().MasterPublicKey, Gen2) + w := new(GT).ScalarMult(g, r) + + var buffer []byte + buffer = append(buffer, cipher.Marshal()...) + buffer = append(buffer, w.Marshal()...) + buffer = append(buffer, uid...) + + key, ok := sm3.Kdf(buffer, len(plaintext)+32) + if !ok { + t.Failed() + } + + fmt.Printf("key=%v\n", hex.EncodeToString(key)) + xor.XorBytes(key, key[:len(plaintext)], plaintext) + + hash := sm3.New() + hash.Write(key) + c3 := hash.Sum(nil) + + ciphertext := append(cipher.Marshal(), c3...) + ciphertext = append(ciphertext, key[:len(plaintext)]...) + if hex.EncodeToString(ciphertext) != expectedCiphertext { + t.Errorf("expected %v, got %v\n", expectedCiphertext, hex.EncodeToString(ciphertext)) + } +} + +func TestEncryptDecrypt(t *testing.T) { + plaintext := []byte("Chinese IBE standard") + masterKey, err := GenerateEncryptMasterKey(rand.Reader) + hid := byte(0x01) + uid := []byte("emmansun") + if err != nil { + t.Fatal(err) + } + userKey, err := masterKey.GenerateUserKey(uid, hid) + if err != nil { + t.Fatal(err) + } + cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext) + if err != nil { + t.Fatal(err) + } + + got, err := Decrypt(userKey, uid, cipher) + if err != nil { + t.Fatal(err) + } + if string(got) != string(plaintext) { + t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) + } + + got, err = Decrypt(userKey, uid, cipher) + if err != nil { + t.Fatal(err) + } + + if string(got) != string(plaintext) { + t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) + } +} + +func TestEncryptDecryptASN1(t *testing.T) { + plaintext := []byte("Chinese IBE standard") + masterKey, err := GenerateEncryptMasterKey(rand.Reader) + hid := byte(0x01) + uid := []byte("emmansun") + if err != nil { + t.Fatal(err) + } + userKey, err := masterKey.GenerateUserKey(uid, hid) + if err != nil { + t.Fatal(err) + } + cipher, err := EncryptASN1(rand.Reader, masterKey.Public(), uid, hid, plaintext) + if err != nil { + t.Fatal(err) + } + + got, err := DecryptASN1(userKey, uid, cipher) + if err != nil { + t.Fatal(err) + } + + if string(got) != string(plaintext) { + t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) + } + + got, err = DecryptASN1(userKey, uid, cipher) + if err != nil { + t.Fatal(err) + } + + if string(got) != string(plaintext) { + t.Errorf("expected %v, got %v\n", string(plaintext), string(got)) + } +} + +func BenchmarkSign(b *testing.B) { + hashed := []byte("Chinese IBS standard") + uid := []byte("emmansun") + hid := byte(0x01) + + masterKey, err := GenerateSignMasterKey(rand.Reader) + if err != nil { + b.Fatal(err) + } + userKey, err := masterKey.GenerateUserKey(uid, hid) + if err != nil { + b.Fatal(err) + } + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + sig, err := SignASN1(rand.Reader, userKey, hashed) + if err != nil { + b.Fatal(err) + } + // Prevent the compiler from optimizing out the operation. + hashed[0] = sig[0] + } +} + +func BenchmarkVerify(b *testing.B) { + hashed := []byte("Chinese IBS standard") + uid := []byte("emmansun") + hid := byte(0x01) + + masterKey, err := GenerateSignMasterKey(rand.Reader) + if err != nil { + b.Fatal(err) + } + userKey, err := masterKey.GenerateUserKey(uid, hid) + if err != nil { + b.Fatal(err) + } + sig, err := SignASN1(rand.Reader, userKey, hashed) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + if !VerifyASN1(masterKey.Public(), uid, hid, hashed, sig) { + b.Fatal("verify failed") + } + } +} + +func BenchmarkEncrypt(b *testing.B) { + plaintext := []byte("Chinese IBE standard") + masterKey, err := GenerateEncryptMasterKey(rand.Reader) + hid := byte(0x01) + uid := []byte("emmansun") + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext) + if err != nil { + b.Fatal(err) + } + // Prevent the compiler from optimizing out the operation. + plaintext[0] = cipher[0] + } +} + +func BenchmarkDecrypt(b *testing.B) { + plaintext := []byte("Chinese IBE standard") + masterKey, err := GenerateEncryptMasterKey(rand.Reader) + hid := byte(0x01) + uid := []byte("emmansun") + if err != nil { + b.Fatal(err) + } + userKey, err := masterKey.GenerateUserKey(uid, hid) + if err != nil { + b.Fatal(err) + } + cipher, err := Encrypt(rand.Reader, masterKey.Public(), uid, hid, plaintext) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + got, err := Decrypt(userKey, uid, cipher) + if err != nil { + b.Fatal(err) + } + if string(got) != string(plaintext) { + b.Errorf("expected %v, got %v\n", string(plaintext), string(got)) + } + } +} + +func BenchmarkDecryptASN1(b *testing.B) { + plaintext := []byte("Chinese IBE standard") + masterKey, err := GenerateEncryptMasterKey(rand.Reader) + hid := byte(0x01) + uid := []byte("emmansun") + if err != nil { + b.Fatal(err) + } + userKey, err := masterKey.GenerateUserKey(uid, hid) + if err != nil { + b.Fatal(err) + } + cipher, err := EncryptASN1(rand.Reader, masterKey.Public(), uid, hid, plaintext) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + got, err := DecryptASN1(userKey, uid, cipher) + if err != nil { + b.Fatal(err) + } + if string(got) != string(plaintext) { + b.Errorf("expected %v, got %v\n", string(plaintext), string(got)) + } + } +}