diff --git a/cipher/ccm.go b/cipher/ccm.go index 44ad53e..e6d2b90 100644 --- a/cipher/ccm.go +++ b/cipher/ccm.go @@ -8,8 +8,8 @@ import ( "errors" + "github.com/emmansun/gmsm/internal/alias" "github.com/emmansun/gmsm/internal/subtle" - "github.com/emmansun/gmsm/internal/xor" ) const ( @@ -112,14 +112,14 @@ func (c *ccm) deriveCounter(counter *[ccmBlockSize]byte, nonce []byte) { func (c *ccm) cmac(out, data []byte) { for len(data) >= ccmBlockSize { - xor.XorBytes(out, out, data) + subtle.XORBytes(out, out, data) c.cipher.Encrypt(out, out) data = data[ccmBlockSize:] } if len(data) > 0 { var block [ccmBlockSize]byte copy(block[:], data) - xor.XorBytes(out, out, data) + subtle.XORBytes(out, out, data) c.cipher.Encrypt(out, out) } } @@ -168,7 +168,7 @@ func (c *ccm) auth(nonce, plaintext, additionalData []byte, tagMask *[ccmBlockSi if len(plaintext) > 0 { c.cmac(out[:], plaintext) } - xor.XorWords(out[:], out[:], tagMask[:]) + subtle.XORBytes(out[:], out[:], tagMask[:]) return out[:c.tagSize] } @@ -179,8 +179,8 @@ func (c *ccm) Seal(dst, nonce, plaintext, data []byte) []byte { if uint64(len(plaintext)) > uint64(c.MaxLength()) { panic("cipher: message too large for CCM") } - ret, out := subtle.SliceForAppend(dst, len(plaintext)+c.tagSize) - if subtle.InexactOverlap(out, plaintext) { + ret, out := alias.SliceForAppend(dst, len(plaintext)+c.tagSize) + if alias.InexactOverlap(out, plaintext) { panic("cipher: invalid buffer overlap") } @@ -225,8 +225,8 @@ func (c *ccm) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) { c.deriveCounter(&counter, nonce) c.cipher.Encrypt(tagMask[:], counter[:]) - ret, out := subtle.SliceForAppend(dst, len(ciphertext)) - if subtle.InexactOverlap(out, ciphertext) { + ret, out := alias.SliceForAppend(dst, len(ciphertext)) + if alias.InexactOverlap(out, ciphertext) { panic("cipher: invalid buffer overlap") } diff --git a/cipher/xts.go b/cipher/xts.go index 621342f..c4a60e7 100644 --- a/cipher/xts.go +++ b/cipher/xts.go @@ -6,8 +6,8 @@ import ( "errors" "sync" + "github.com/emmansun/gmsm/internal/alias" "github.com/emmansun/gmsm/internal/subtle" - "github.com/emmansun/gmsm/internal/xor" ) const GF128_FDBK byte = 0x87 @@ -89,7 +89,7 @@ func (c *xts) Encrypt(ciphertext, plaintext []byte, sectorNum uint64) { if len(plaintext) < blockSize { panic("xts: plaintext length is smaller than the block size") } - if subtle.InexactOverlap(ciphertext[:len(plaintext)], plaintext) { + if alias.InexactOverlap(ciphertext[:len(plaintext)], plaintext) { panic("xts: invalid buffer overlap") } @@ -112,18 +112,18 @@ func (c *xts) Encrypt(ciphertext, plaintext []byte, sectorNum uint64) { copy(tweaks[blockSize*i:], tweak[:]) mul2(tweak) } - xor.XorBytes(ciphertext, plaintext, tweaks) + subtle.XORBytes(ciphertext, plaintext, tweaks) concCipher.EncryptBlocks(ciphertext, ciphertext) - xor.XorBytes(ciphertext, ciphertext, tweaks) + subtle.XORBytes(ciphertext, ciphertext, tweaks) plaintext = plaintext[batchSize:] lastCiphertext = ciphertext[batchSize-blockSize:] ciphertext = ciphertext[batchSize:] } } for len(plaintext) >= blockSize { - xor.XorBytes(ciphertext, plaintext, tweak[:]) + subtle.XORBytes(ciphertext, plaintext, tweak[:]) c.k1.Encrypt(ciphertext, ciphertext) - xor.XorBytes(ciphertext, ciphertext, tweak[:]) + subtle.XORBytes(ciphertext, ciphertext, tweak[:]) plaintext = plaintext[blockSize:] lastCiphertext = ciphertext ciphertext = ciphertext[blockSize:] @@ -139,11 +139,11 @@ func (c *xts) Encrypt(ciphertext, plaintext []byte, sectorNum uint64) { //Steal ciphertext to complete the block copy(x[remain:], lastCiphertext[remain:blockSize]) //Merge the tweak into the input block - xor.XorBytes(x[:], x[:], tweak[:]) + subtle.XORBytes(x[:], x[:], tweak[:]) //Encrypt the final block using K1 c.k1.Encrypt(x[:], x[:]) //Merge the tweak into the output block - xor.XorBytes(lastCiphertext, x[:], tweak[:]) + subtle.XORBytes(lastCiphertext, x[:], tweak[:]) } tweakPool.Put(tweak) } @@ -158,7 +158,7 @@ func (c *xts) Decrypt(plaintext, ciphertext []byte, sectorNum uint64) { if len(ciphertext) < blockSize { panic("xts: ciphertext length is smaller than the block size") } - if subtle.InexactOverlap(plaintext[:len(ciphertext)], ciphertext) { + if alias.InexactOverlap(plaintext[:len(ciphertext)], ciphertext) { panic("xts: invalid buffer overlap") } @@ -179,18 +179,18 @@ func (c *xts) Decrypt(plaintext, ciphertext []byte, sectorNum uint64) { copy(tweaks[blockSize*i:], tweak[:]) mul2(tweak) } - xor.XorBytes(plaintext, ciphertext, tweaks) + subtle.XORBytes(plaintext, ciphertext, tweaks) concCipher.DecryptBlocks(plaintext, plaintext) - xor.XorBytes(plaintext, plaintext, tweaks) + subtle.XORBytes(plaintext, plaintext, tweaks) plaintext = plaintext[batchSize:] ciphertext = ciphertext[batchSize:] } } for len(ciphertext) >= 2*blockSize { - xor.XorBytes(plaintext, ciphertext, tweak[:]) + subtle.XORBytes(plaintext, ciphertext, tweak[:]) c.k1.Decrypt(plaintext, plaintext) - xor.XorBytes(plaintext, plaintext, tweak[:]) + subtle.XORBytes(plaintext, plaintext, tweak[:]) plaintext = plaintext[blockSize:] ciphertext = ciphertext[blockSize:] @@ -203,9 +203,9 @@ func (c *xts) Decrypt(plaintext, ciphertext []byte, sectorNum uint64) { var tt [blockSize]byte copy(tt[:], tweak[:]) mul2(&tt) - xor.XorBytes(x[:], ciphertext, tt[:]) + subtle.XORBytes(x[:], ciphertext, tt[:]) c.k1.Decrypt(x[:], x[:]) - xor.XorBytes(plaintext, x[:], tt[:]) + subtle.XORBytes(plaintext, x[:], tt[:]) //Retrieve the length of the final block remain -= blockSize @@ -220,9 +220,9 @@ func (c *xts) Decrypt(plaintext, ciphertext []byte, sectorNum uint64) { //The last block contains exactly 128 bits copy(x[:], ciphertext) } - xor.XorBytes(x[:], x[:], tweak[:]) + subtle.XORBytes(x[:], x[:], tweak[:]) c.k1.Decrypt(x[:], x[:]) - xor.XorBytes(plaintext, x[:], tweak[:]) + subtle.XORBytes(plaintext, x[:], tweak[:]) } tweakPool.Put(tweak) diff --git a/internal/alias/alias_test.go b/internal/alias/alias_test.go new file mode 100644 index 0000000..76d95d9 --- /dev/null +++ b/internal/alias/alias_test.go @@ -0,0 +1,42 @@ +package alias + +import "testing" + +var a, b [100]byte + +var aliasingTests = []struct { + x, y []byte + anyOverlap, inexactOverlap bool +}{ + {a[:], b[:], false, false}, + {a[:], b[:0], false, false}, + {a[:], b[:50], false, false}, + {a[40:50], a[50:60], false, false}, + {a[40:50], a[60:70], false, false}, + {a[:51], a[50:], true, true}, + {a[:], a[:], true, false}, + {a[:50], a[:60], true, false}, + {a[:], nil, false, false}, + {nil, nil, false, false}, + {a[:], a[:0], false, false}, + {a[:10], a[:10:20], true, false}, + {a[:10], a[5:10:20], true, true}, +} + +func testAliasing(t *testing.T, i int, x, y []byte, anyOverlap, inexactOverlap bool) { + any := AnyOverlap(x, y) + if any != anyOverlap { + t.Errorf("%d: wrong AnyOverlap result, expected %v, got %v", i, anyOverlap, any) + } + inexact := InexactOverlap(x, y) + if inexact != inexactOverlap { + t.Errorf("%d: wrong InexactOverlap result, expected %v, got %v", i, inexactOverlap, any) + } +} + +func TestAliasing(t *testing.T) { + for i, tt := range aliasingTests { + testAliasing(t, i, tt.x, tt.y, tt.anyOverlap, tt.inexactOverlap) + testAliasing(t, i, tt.y, tt.x, tt.anyOverlap, tt.inexactOverlap) + } +} diff --git a/internal/subtle/aliasing.go b/internal/alias/aliasing.go similarity index 98% rename from internal/subtle/aliasing.go rename to internal/alias/aliasing.go index edd60d6..9bc3ca7 100644 --- a/internal/subtle/aliasing.go +++ b/internal/alias/aliasing.go @@ -1,4 +1,4 @@ -package subtle +package alias import "unsafe" diff --git a/internal/subtle/xor.go b/internal/subtle/xor.go new file mode 100644 index 0000000..5120862 --- /dev/null +++ b/internal/subtle/xor.go @@ -0,0 +1,20 @@ +package subtle + +// XORBytes sets dst[i] = x[i] ^ y[i] for all i < n = min(len(x), len(y)), +// returning n, the number of bytes written to dst. +// If dst does not have length at least n, +// XORBytes panics without writing anything to dst. +func XORBytes(dst, x, y []byte) int { + n := len(x) + if len(y) < n { + n = len(y) + } + if n == 0 { + return 0 + } + if n > len(dst) { + panic("subtle.XORBytes: dst too short") + } + xorBytes(&dst[0], &x[0], &y[0], n) // arch-specific + return n +} diff --git a/internal/subtle/xor_amd64.go b/internal/subtle/xor_amd64.go new file mode 100644 index 0000000..c65a77f --- /dev/null +++ b/internal/subtle/xor_amd64.go @@ -0,0 +1,11 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +//go:build amd64 && !generic +// +build amd64,!generic + +package subtle + +//go:noescape +func xorBytes(dst, a, b *byte, n int) diff --git a/internal/xor/xor_amd64.s b/internal/subtle/xor_amd64.s similarity index 94% rename from internal/xor/xor_amd64.s rename to internal/subtle/xor_amd64.s index 3262818..eaea9e6 100644 --- a/internal/xor/xor_amd64.s +++ b/internal/subtle/xor_amd64.s @@ -1,13 +1,14 @@ // Copyright 2018 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// //go:build amd64 && !generic // +build amd64,!generic #include "textflag.h" -// func xorBytesSSE2(dst, a, b *byte, n int) -TEXT ·xorBytesSSE2(SB), NOSPLIT, $0 +// func xorBytes(dst, a, b *byte, n int) +TEXT ·xorBytes(SB), NOSPLIT, $0 MOVQ dst+0(FP), BX MOVQ a+8(FP), SI MOVQ b+16(FP), CX diff --git a/internal/subtle/xor_arm64.go b/internal/subtle/xor_arm64.go new file mode 100644 index 0000000..68989ca --- /dev/null +++ b/internal/subtle/xor_arm64.go @@ -0,0 +1,11 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +//go:build arm64 && !generic +// +build arm64,!generic + +package subtle + +//go:noescape +func xorBytes(dst, a, b *byte, n int) diff --git a/internal/xor/xor_arm64.s b/internal/subtle/xor_arm64.s similarity index 94% rename from internal/xor/xor_arm64.s rename to internal/subtle/xor_arm64.s index acbd2ea..052721f 100644 --- a/internal/xor/xor_arm64.s +++ b/internal/subtle/xor_arm64.s @@ -1,13 +1,14 @@ // Copyright 2020 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// //go:build arm64 && !generic // +build arm64,!generic #include "textflag.h" -// func xorBytesARM64(dst, a, b *byte, n int) -TEXT ·xorBytesARM64(SB), NOSPLIT|NOFRAME, $0 +// func xorBytes(dst, a, b *byte, n int) +TEXT ·xorBytes(SB), NOSPLIT|NOFRAME, $0 MOVD dst+0(FP), R0 MOVD a+8(FP), R1 MOVD b+16(FP), R2 diff --git a/internal/xor/xor_generic.go b/internal/subtle/xor_generic.go similarity index 54% rename from internal/xor/xor_generic.go rename to internal/subtle/xor_generic.go index 20395e6..3c7bae5 100644 --- a/internal/xor/xor_generic.go +++ b/internal/subtle/xor_generic.go @@ -1,27 +1,25 @@ // Copyright 2013 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. +// //go:build !amd64 && !arm64 || generic // +build !amd64,!arm64 generic -package xor +package subtle import ( "runtime" "unsafe" ) -// xorBytes xors the bytes in a and b. The destination should have enough -// space, otherwise xorBytes will panic. Returns the number of bytes xor'd. -func XorBytes(dst, a, b []byte) int { - n := len(a) - if len(b) < n { - n = len(b) - } - if n == 0 { - return 0 - } +const wordSize = int(unsafe.Sizeof(uintptr(0))) +const supportsUnaligned = runtime.GOARCH == "386" || + runtime.GOARCH == "ppc64" || + runtime.GOARCH == "ppc64le" || + runtime.GOARCH == "s390x" + +func xorBytes(dst, a, b []byte, n int) int { switch { case supportsUnaligned: fastXORBytes(dst, a, b, n) @@ -36,16 +34,10 @@ func XorBytes(dst, a, b []byte) int { return n } -const wordSize = int(unsafe.Sizeof(uintptr(0))) -const supportsUnaligned = runtime.GOARCH == "386" || runtime.GOARCH == "ppc64" || runtime.GOARCH == "ppc64le" || runtime.GOARCH == "s390x" - // fastXORBytes xors in bulk. It only works on architectures that // support unaligned read/writes. // n needs to be smaller or equal than the length of a and b. func fastXORBytes(dst, a, b []byte, n int) { - // Assert dst has enough space - _ = dst[n-1] - w := n / wordSize if w > 0 { dw := *(*[]uintptr)(unsafe.Pointer(&dst)) @@ -67,25 +59,3 @@ func safeXORBytes(dst, a, b []byte, n int) { dst[i] = a[i] ^ b[i] } } - -// fastXORWords XORs multiples of 4 or 8 bytes (depending on architecture.) -// The arguments are assumed to be of equal length. -func fastXORWords(dst, a, b []byte) { - dw := *(*[]uintptr)(unsafe.Pointer(&dst)) - aw := *(*[]uintptr)(unsafe.Pointer(&a)) - bw := *(*[]uintptr)(unsafe.Pointer(&b)) - n := len(b) / wordSize - for i := 0; i < n; i++ { - dw[i] = aw[i] ^ bw[i] - } -} - -// fastXORWords XORs multiples of 4 or 8 bytes (depending on architecture.) -// The slice arguments a and b are assumed to be of equal length. -func XorWords(dst, a, b []byte) { - if supportsUnaligned { - fastXORWords(dst, a, b) - } else { - safeXORBytes(dst, a, b, len(b)) - } -} diff --git a/internal/subtle/xor_test.go b/internal/subtle/xor_test.go new file mode 100644 index 0000000..0a6cf97 --- /dev/null +++ b/internal/subtle/xor_test.go @@ -0,0 +1,95 @@ +package subtle_test + +import ( + "bytes" + "crypto/rand" + "fmt" + "io" + "testing" + + "github.com/emmansun/gmsm/internal/subtle" +) + +func TestXORBytes(t *testing.T) { + for n := 1; n <= 1024; n++ { + if n > 16 && testing.Short() { + n += n >> 3 + } + for alignP := 0; alignP < 8; alignP++ { + for alignQ := 0; alignQ < 8; alignQ++ { + for alignD := 0; alignD < 8; alignD++ { + p := make([]byte, alignP+n, alignP+n+10)[alignP:] + q := make([]byte, alignQ+n, alignQ+n+10)[alignQ:] + if n&1 != 0 { + p = p[:n] + } else { + q = q[:n] + } + if _, err := io.ReadFull(rand.Reader, p); err != nil { + t.Fatal(err) + } + if _, err := io.ReadFull(rand.Reader, q); err != nil { + t.Fatal(err) + } + + d := make([]byte, alignD+n, alignD+n+10) + for i := range d { + d[i] = 0xdd + } + want := make([]byte, len(d), cap(d)) + copy(want[:cap(want)], d[:cap(d)]) + for i := 0; i < n; i++ { + want[alignD+i] = p[i] ^ q[i] + } + + if subtle.XORBytes(d[alignD:], p, q); !bytes.Equal(d, want) { + t.Fatalf("n=%d alignP=%d alignQ=%d alignD=%d:\n\tp = %x\n\tq = %x\n\td = %x\n\twant %x\n", n, alignP, alignQ, alignD, p, q, d, want) + } + } + } + } + } +} + +func TestXorBytesPanic(t *testing.T) { + mustPanic(t, "subtle.XORBytes: dst too short", func() { + subtle.XORBytes(nil, make([]byte, 1), make([]byte, 1)) + }) + mustPanic(t, "subtle.XORBytes: dst too short", func() { + subtle.XORBytes(make([]byte, 1), make([]byte, 2), make([]byte, 3)) + }) +} + +func BenchmarkXORBytes(b *testing.B) { + dst := make([]byte, 1<<15) + data0 := make([]byte, 1<<15) + data1 := make([]byte, 1<<15) + sizes := []int64{1 << 3, 1 << 7, 1 << 11, 1 << 15} + for _, size := range sizes { + b.Run(fmt.Sprintf("%dBytes", size), func(b *testing.B) { + s0 := data0[:size] + s1 := data1[:size] + b.SetBytes(int64(size)) + for i := 0; i < b.N; i++ { + subtle.XORBytes(dst, s0, s1) + } + }) + } +} + +func mustPanic(t *testing.T, expected string, f func()) { + t.Helper() + defer func() { + switch msg := recover().(type) { + case nil: + t.Errorf("expected panic(%q), but did not panic", expected) + case string: + if msg != expected { + t.Errorf("expected panic(%q), but got panic(%q)", expected, msg) + } + default: + t.Errorf("expected panic(%q), but got panic(%T%v)", expected, msg, msg) + } + }() + f() +} diff --git a/internal/xor/xor_amd64.go b/internal/xor/xor_amd64.go deleted file mode 100644 index 9ece77b..0000000 --- a/internal/xor/xor_amd64.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. -//go:build amd64 && !generic -// +build amd64,!generic - -package xor - -// XorBytes xors the bytes in a and b. The destination should have enough -// space, otherwise xorBytes will panic. Returns the number of bytes xor'd. -func XorBytes(dst, a, b []byte) int { - n := len(a) - if len(b) < n { - n = len(b) - } - if n == 0 { - return 0 - } - _ = dst[n-1] - xorBytesSSE2(&dst[0], &a[0], &b[0], n) // amd64 must have SSE2 - return n -} - -func XorWords(dst, a, b []byte) { - XorBytes(dst, a, b) -} - -//go:noescape -func xorBytesSSE2(dst, a, b *byte, n int) diff --git a/internal/xor/xor_arm64.go b/internal/xor/xor_arm64.go deleted file mode 100644 index 7453273..0000000 --- a/internal/xor/xor_arm64.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2018 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. -//go:build arm64 && !generic -// +build arm64,!generic - -package xor - -// xorBytes xors the bytes in a and b. The destination should have enough -// space, otherwise xorBytes will panic. Returns the number of bytes xor'd. -func XorBytes(dst, a, b []byte) int { - n := len(a) - if len(b) < n { - n = len(b) - } - if n == 0 { - return 0 - } - // make sure dst has enough space - _ = dst[n-1] - - xorBytesARM64(&dst[0], &a[0], &b[0], n) - return n -} - -func XorWords(dst, a, b []byte) { - XorBytes(dst, a, b) -} - -//go:noescape -func xorBytesARM64(dst, a, b *byte, n int) diff --git a/padding/ansi_x923.go b/padding/ansi_x923.go index 76f23ed..22a174a 100644 --- a/padding/ansi_x923.go +++ b/padding/ansi_x923.go @@ -3,7 +3,7 @@ package padding import ( "errors" - "github.com/emmansun/gmsm/internal/subtle" + "github.com/emmansun/gmsm/internal/alias" ) // https://www.ibm.com/docs/en/linux-on-systems?topic=processes-ansi-x923-cipher-block-chaining @@ -15,7 +15,7 @@ func (pad ansiX923Padding) BlockSize() int { func (pad ansiX923Padding) Pad(src []byte) []byte { overhead := pad.BlockSize() - len(src)%pad.BlockSize() - ret, out := subtle.SliceForAppend(src, overhead) + ret, out := alias.SliceForAppend(src, overhead) out[overhead-1] = byte(overhead) for i := 0; i < overhead-1; i++ { out[i] = 0 diff --git a/padding/pkcs7.go b/padding/pkcs7.go index a9f84f0..1a3ee62 100644 --- a/padding/pkcs7.go +++ b/padding/pkcs7.go @@ -4,7 +4,7 @@ package padding import ( "errors" - "github.com/emmansun/gmsm/internal/subtle" + "github.com/emmansun/gmsm/internal/alias" ) // https://datatracker.ietf.org/doc/html/rfc5652#section-6.3 @@ -16,7 +16,7 @@ func (pad pkcs7Padding) BlockSize() int { func (pad pkcs7Padding) Pad(src []byte) []byte { overhead := pad.BlockSize() - len(src)%pad.BlockSize() - ret, out := subtle.SliceForAppend(src, overhead) + ret, out := alias.SliceForAppend(src, overhead) for i := 0; i < overhead; i++ { out[i] = byte(overhead) } diff --git a/sm2/sm2.go b/sm2/sm2.go index aba13ac..47e5902 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -23,7 +23,7 @@ import ( "strings" "github.com/emmansun/gmsm/internal/randutil" - "github.com/emmansun/gmsm/internal/xor" + "github.com/emmansun/gmsm/internal/subtle" "github.com/emmansun/gmsm/sm2/sm2ec" "github.com/emmansun/gmsm/sm3" "golang.org/x/crypto/cryptobyte" @@ -345,7 +345,7 @@ func Encrypt(random io.Reader, pub *ecdsa.PublicKey, msg []byte, opts *Encrypter } //A6, C2 = M + t; - xor.XorBytes(c2, msg, c2) + subtle.XORBytes(c2, msg, c2) //A7, C3 = hash(x2||M||y2) c3 := calculateC3(curve, x2, y2, msg) @@ -402,7 +402,7 @@ func rawDecrypt(priv *PrivateKey, x1, y1 *big.Int, c2, c3 []byte) ([]byte, error } //B5, calculate msg = c2 ^ t - xor.XorBytes(msg, c2, msg) + subtle.XORBytes(msg, c2, msg) u := calculateC3(curve, x2, y2, msg) for i := 0; i < sm3.Size; i++ { diff --git a/sm2/sm2_keyexchange.go b/sm2/sm2_keyexchange.go index e2a7739..0170cda 100644 --- a/sm2/sm2_keyexchange.go +++ b/sm2/sm2_keyexchange.go @@ -93,6 +93,10 @@ func (ke *KeyExchange) SetPeerParameters(peerPub *ecdsa.PublicKey, peerUID []byt return errors.New("sm2: 'peerPub' already exists, please do not set it") } + if !IsSM2PublicKey(peerPub) { + return errors.New("sm2: peer public key is not expected/supported") + } + var err error ke.peerPub = peerPub ke.peerZ, err = calculateZA(ke.peerPub, peerUID) diff --git a/sm2/sm2_keyexchange_test.go b/sm2/sm2_keyexchange_test.go index 55bd5cc..047094a 100644 --- a/sm2/sm2_keyexchange_test.go +++ b/sm2/sm2_keyexchange_test.go @@ -1,6 +1,8 @@ package sm2 import ( + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "encoding/hex" "errors" @@ -83,6 +85,7 @@ func TestKeyExchangeSimplest(t *testing.T) { func TestSetPeerParameters(t *testing.T) { priv1, _ := GenerateKey(rand.Reader) priv2, _ := GenerateKey(rand.Reader) + priv3, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) uidA := []byte("Alice") uidB := []byte("Bob") @@ -101,6 +104,11 @@ func TestSetPeerParameters(t *testing.T) { } // 设置对端参数 + err = initiator.SetPeerParameters(&priv3.PublicKey, uidB) + if err == nil { + t.Errorf("should be failed") + } + err = initiator.SetPeerParameters(&priv2.PublicKey, uidB) if err != nil { t.Fatal(err) diff --git a/sm4/cbc_cipher_asm.go b/sm4/cbc_cipher_asm.go index 5a65b02..3b2385a 100644 --- a/sm4/cbc_cipher_asm.go +++ b/sm4/cbc_cipher_asm.go @@ -6,8 +6,8 @@ package sm4 import ( "crypto/cipher" + "github.com/emmansun/gmsm/internal/alias" "github.com/emmansun/gmsm/internal/subtle" - "github.com/emmansun/gmsm/internal/xor" ) // Assert that sm4CipherAsm implements the cbcEncAble and cbcDecAble interfaces. @@ -56,7 +56,7 @@ func (x *cbc) CryptBlocks(dst, src []byte) { if len(dst) < len(src) { panic("cipher: output smaller than input") } - if subtle.InexactOverlap(dst[:len(src)], src) { + if alias.InexactOverlap(dst[:len(src)], src) { panic("cipher: invalid buffer overlap") } if len(src) == 0 { @@ -79,7 +79,7 @@ func (x *cbc) CryptBlocks(dst, src []byte) { for start > 0 { x.b.DecryptBlocks(temp, src[start:end]) copy(batchSrc, src[start-BlockSize:]) - xor.XorBytes(dst[start:], temp, batchSrc) + subtle.XORBytes(dst[start:], temp, batchSrc) end = start start -= x.b.blocksSize } @@ -88,7 +88,7 @@ func (x *cbc) CryptBlocks(dst, src []byte) { copy(batchSrc[BlockSize:], src[:end]) x.b.DecryptBlocks(temp, batchSrc[BlockSize:]) copy(batchSrc, x.iv) - xor.XorBytes(dst, temp[:end], batchSrc) + subtle.XORBytes(dst, temp[:end], batchSrc) // Set the new iv to the first block we copied earlier. x.iv, x.tmp = x.tmp, x.iv diff --git a/sm4/cipher.go b/sm4/cipher.go index be575df..b07dc1d 100644 --- a/sm4/cipher.go +++ b/sm4/cipher.go @@ -5,7 +5,7 @@ import ( "crypto/cipher" "fmt" - "github.com/emmansun/gmsm/internal/subtle" + "github.com/emmansun/gmsm/internal/alias" ) // BlockSize the sm4 block size in bytes. @@ -49,7 +49,7 @@ func (c *sm4Cipher) Encrypt(dst, src []byte) { if len(dst) < BlockSize { panic("sm4: output not full block") } - if subtle.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { + if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { panic("sm4: invalid buffer overlap") } encryptBlockGo(c.enc, dst, src) @@ -62,7 +62,7 @@ func (c *sm4Cipher) Decrypt(dst, src []byte) { if len(dst) < BlockSize { panic("sm4: output not full block") } - if subtle.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { + if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { panic("sm4: invalid buffer overlap") } decryptBlockGo(c.dec, dst, src) diff --git a/sm4/cipher_asm.go b/sm4/cipher_asm.go index 0a0b6a9..18ea99f 100644 --- a/sm4/cipher_asm.go +++ b/sm4/cipher_asm.go @@ -6,7 +6,7 @@ package sm4 import ( "crypto/cipher" - "github.com/emmansun/gmsm/internal/subtle" + "github.com/emmansun/gmsm/internal/alias" "golang.org/x/sys/cpu" ) @@ -65,7 +65,7 @@ func (c *sm4CipherAsm) Encrypt(dst, src []byte) { if len(dst) < BlockSize { panic("sm4: output not full block") } - if subtle.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { + if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { panic("sm4: invalid buffer overlap") } encryptBlockAsm(&c.enc[0], &dst[0], &src[0], INST_AES) @@ -78,7 +78,7 @@ func (c *sm4CipherAsm) EncryptBlocks(dst, src []byte) { if len(dst) < c.blocksSize { panic("sm4: output not full blocks") } - if subtle.InexactOverlap(dst[:c.blocksSize], src[:c.blocksSize]) { + if alias.InexactOverlap(dst[:c.blocksSize], src[:c.blocksSize]) { panic("sm4: invalid buffer overlap") } encryptBlocksAsm(&c.enc[0], dst, src, INST_AES) @@ -91,7 +91,7 @@ func (c *sm4CipherAsm) Decrypt(dst, src []byte) { if len(dst) < BlockSize { panic("sm4: output not full block") } - if subtle.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { + if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { panic("sm4: invalid buffer overlap") } encryptBlockAsm(&c.dec[0], &dst[0], &src[0], INST_AES) @@ -104,7 +104,7 @@ func (c *sm4CipherAsm) DecryptBlocks(dst, src []byte) { if len(dst) < c.blocksSize { panic("sm4: output not full blocks") } - if subtle.InexactOverlap(dst[:c.blocksSize], src[:c.blocksSize]) { + if alias.InexactOverlap(dst[:c.blocksSize], src[:c.blocksSize]) { panic("sm4: invalid buffer overlap") } encryptBlocksAsm(&c.dec[0], dst, src, INST_AES) diff --git a/sm4/cipher_ni.go b/sm4/cipher_ni.go index 541af67..77d764a 100644 --- a/sm4/cipher_ni.go +++ b/sm4/cipher_ni.go @@ -6,7 +6,7 @@ package sm4 import ( "crypto/cipher" - "github.com/emmansun/gmsm/internal/subtle" + "github.com/emmansun/gmsm/internal/alias" ) type sm4CipherNI struct { @@ -29,7 +29,7 @@ func (c *sm4CipherNI) Encrypt(dst, src []byte) { if len(dst) < BlockSize { panic("sm4: output not full block") } - if subtle.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { + if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { panic("sm4: invalid buffer overlap") } encryptBlockAsm(&c.enc[0], &dst[0], &src[0], INST_SM4) @@ -42,7 +42,7 @@ func (c *sm4CipherNI) Decrypt(dst, src []byte) { if len(dst) < BlockSize { panic("sm4: output not full block") } - if subtle.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { + if alias.InexactOverlap(dst[:BlockSize], src[:BlockSize]) { panic("sm4: invalid buffer overlap") } encryptBlockAsm(&c.dec[0], &dst[0], &src[0], INST_SM4) diff --git a/sm4/ctr_cipher_asm.go b/sm4/ctr_cipher_asm.go index 7630eb1..9ff2f26 100644 --- a/sm4/ctr_cipher_asm.go +++ b/sm4/ctr_cipher_asm.go @@ -6,8 +6,8 @@ package sm4 import ( "crypto/cipher" + "github.com/emmansun/gmsm/internal/alias" "github.com/emmansun/gmsm/internal/subtle" - "github.com/emmansun/gmsm/internal/xor" ) // Assert that sm4CipherAsm implements the ctrAble interface. @@ -83,14 +83,14 @@ func (x *ctr) XORKeyStream(dst, src []byte) { if len(dst) < len(src) { panic("cipher: output smaller than input") } - if subtle.InexactOverlap(dst[:len(src)], src) { + if alias.InexactOverlap(dst[:len(src)], src) { panic("cipher: invalid buffer overlap") } for len(src) > 0 { if x.outUsed >= len(x.out)-BlockSize { x.refill() } - n := xor.XorBytes(dst, src, x.out[x.outUsed:]) + n := subtle.XORBytes(dst, src, x.out[x.outUsed:]) dst = dst[n:] src = src[n:] x.outUsed += n diff --git a/sm4/gcm_cipher_asm.go b/sm4/gcm_cipher_asm.go index 168a78f..ed1cc9c 100644 --- a/sm4/gcm_cipher_asm.go +++ b/sm4/gcm_cipher_asm.go @@ -9,8 +9,8 @@ import ( "encoding/binary" "errors" + "github.com/emmansun/gmsm/internal/alias" "github.com/emmansun/gmsm/internal/subtle" - "github.com/emmansun/gmsm/internal/xor" ) // Assert that sm4CipherAsm implements the gcmAble interface. @@ -86,8 +86,8 @@ func (g *gcm) Seal(dst, nonce, plaintext, data []byte) []byte { panic("cipher: message too large for GCM") } - ret, out := subtle.SliceForAppend(dst, len(plaintext)+g.tagSize) - if subtle.InexactOverlap(out, plaintext) { + ret, out := alias.SliceForAppend(dst, len(plaintext)+g.tagSize) + if alias.InexactOverlap(out, plaintext) { panic("cipher: invalid buffer overlap") } @@ -137,8 +137,8 @@ func (g *gcm) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) { var expectedTag [gcmTagSize]byte g.auth(expectedTag[:], ciphertext, data, &tagMask) - ret, out := subtle.SliceForAppend(dst, len(ciphertext)) - if subtle.InexactOverlap(out, ciphertext) { + ret, out := alias.SliceForAppend(dst, len(ciphertext)) + if alias.InexactOverlap(out, ciphertext) { panic("cipher: invalid buffer overlap") } @@ -274,7 +274,7 @@ func (g *gcm) counterCrypt(out, in []byte, counter *[gcmBlockSize]byte) { gcmInc32(counter) } g.cipher.EncryptBlocks(mask, counters) - xor.XorWords(out, in, mask[:]) + subtle.XORBytes(out, in, mask[:]) out = out[g.cipher.blocksSize:] in = in[g.cipher.blocksSize:] } @@ -286,7 +286,7 @@ func (g *gcm) counterCrypt(out, in []byte, counter *[gcmBlockSize]byte) { gcmInc32(counter) } g.cipher.EncryptBlocks(mask, counters) - xor.XorBytes(out, in, mask[:blocks*gcmBlockSize]) + subtle.XORBytes(out, in, mask[:blocks*gcmBlockSize]) } } @@ -328,5 +328,5 @@ func (g *gcm) auth(out, ciphertext, additionalData []byte, tagMask *[gcmTagSize] binary.BigEndian.PutUint64(out, y.low) binary.BigEndian.PutUint64(out[8:], y.high) - xor.XorWords(out, out, tagMask[:]) + subtle.XORBytes(out, out, tagMask[:]) } diff --git a/sm4/sm4_gcm_asm.go b/sm4/sm4_gcm_asm.go index eecf8a3..a8a8f14 100644 --- a/sm4/sm4_gcm_asm.go +++ b/sm4/sm4_gcm_asm.go @@ -5,9 +5,9 @@ package sm4 import ( "crypto/cipher" - goSubtle "crypto/subtle" + "crypto/subtle" - "github.com/emmansun/gmsm/internal/subtle" + "github.com/emmansun/gmsm/internal/alias" ) // sm4CipherGCM implements crypto/cipher.gcmAble so that crypto/cipher.NewGCM @@ -86,8 +86,8 @@ func (g *gcmAsm) Seal(dst, nonce, plaintext, data []byte) []byte { var tagOut [gcmTagSize]byte gcmSm4Data(&g.bytesProductTable, data, &tagOut) - ret, out := subtle.SliceForAppend(dst, len(plaintext)+g.tagSize) - if subtle.InexactOverlap(out[:len(plaintext)], plaintext) { + ret, out := alias.SliceForAppend(dst, len(plaintext)+g.tagSize) + if alias.InexactOverlap(out[:len(plaintext)], plaintext) { panic("cipher: invalid buffer overlap") } @@ -140,8 +140,8 @@ func (g *gcmAsm) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) { var expectedTag [gcmTagSize]byte gcmSm4Data(&g.bytesProductTable, data, &expectedTag) - ret, out := subtle.SliceForAppend(dst, len(ciphertext)) - if subtle.InexactOverlap(out, ciphertext) { + ret, out := alias.SliceForAppend(dst, len(ciphertext)) + if alias.InexactOverlap(out, ciphertext) { panic("cipher: invalid buffer overlap") } if len(ciphertext) > 0 { @@ -149,7 +149,7 @@ func (g *gcmAsm) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) { } gcmSm4Finish(&g.bytesProductTable, &tagMask, &expectedTag, uint64(len(ciphertext)), uint64(len(data))) - if goSubtle.ConstantTimeCompare(expectedTag[:g.tagSize], tag) != 1 { + if subtle.ConstantTimeCompare(expectedTag[:g.tagSize], tag) != 1 { for i := range out { out[i] = 0 } diff --git a/sm4/sm4ni_gcm_asm.go b/sm4/sm4ni_gcm_asm.go index 3b40f3d..8e66d95 100644 --- a/sm4/sm4ni_gcm_asm.go +++ b/sm4/sm4ni_gcm_asm.go @@ -5,9 +5,9 @@ package sm4 import ( "crypto/cipher" - goSubtle "crypto/subtle" + "crypto/subtle" - "github.com/emmansun/gmsm/internal/subtle" + "github.com/emmansun/gmsm/internal/alias" ) //go:noescape @@ -79,8 +79,8 @@ func (g *gcmNI) Seal(dst, nonce, plaintext, data []byte) []byte { var tagOut [gcmTagSize]byte gcmSm4Data(&g.bytesProductTable, data, &tagOut) - ret, out := subtle.SliceForAppend(dst, len(plaintext)+g.tagSize) - if subtle.InexactOverlap(out[:len(plaintext)], plaintext) { + ret, out := alias.SliceForAppend(dst, len(plaintext)+g.tagSize) + if alias.InexactOverlap(out[:len(plaintext)], plaintext) { panic("cipher: invalid buffer overlap") } @@ -133,8 +133,8 @@ func (g *gcmNI) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) { var expectedTag [gcmTagSize]byte gcmSm4Data(&g.bytesProductTable, data, &expectedTag) - ret, out := subtle.SliceForAppend(dst, len(ciphertext)) - if subtle.InexactOverlap(out, ciphertext) { + ret, out := alias.SliceForAppend(dst, len(ciphertext)) + if alias.InexactOverlap(out, ciphertext) { panic("cipher: invalid buffer overlap") } if len(ciphertext) > 0 { @@ -142,7 +142,7 @@ func (g *gcmNI) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) { } gcmSm4Finish(&g.bytesProductTable, &tagMask, &expectedTag, uint64(len(ciphertext)), uint64(len(data))) - if goSubtle.ConstantTimeCompare(expectedTag[:g.tagSize], tag) != 1 { + if subtle.ConstantTimeCompare(expectedTag[:g.tagSize], tag) != 1 { for i := range out { out[i] = 0 } diff --git a/sm9/sm9.go b/sm9/sm9.go index d7e13c7..4000011 100644 --- a/sm9/sm9.go +++ b/sm9/sm9.go @@ -3,14 +3,14 @@ package sm9 import ( "crypto" - "crypto/subtle" + goSubtle "crypto/subtle" "encoding/binary" "errors" "fmt" "io" "math/big" - "github.com/emmansun/gmsm/internal/xor" + "github.com/emmansun/gmsm/internal/subtle" "github.com/emmansun/gmsm/sm3" "github.com/emmansun/gmsm/sm9/bn256" "golang.org/x/crypto/cryptobyte" @@ -325,7 +325,7 @@ func Encrypt(rand io.Reader, pub *EncryptMasterPublicKey, uid []byte, hid byte, if err != nil { return nil, err } - xor.XorBytes(key, key[:len(plaintext)], plaintext) + subtle.XORBytes(key, key[:len(plaintext)], plaintext) hash := sm3.New() hash.Write(key) @@ -349,7 +349,7 @@ func (pub *EncryptMasterPublicKey) Encrypt(rand io.Reader, uid []byte, hid byte, if err != nil { return nil, err } - xor.XorBytes(key, key[:len(plaintext)], plaintext) + subtle.XORBytes(key, key[:len(plaintext)], plaintext) hash := sm3.New() hash.Write(key) @@ -385,11 +385,11 @@ func Decrypt(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error) { hash.Write(key[len(c2):]) c32 := hash.Sum(nil) - if subtle.ConstantTimeCompare(c3[:sm3.Size], c32) != 1 { + if goSubtle.ConstantTimeCompare(c3[:sm3.Size], c32) != 1 { return nil, errors.New("sm9: invalid mac value") } - xor.XorBytes(key, c2, key[:len(c2)]) + subtle.XORBytes(key, c2, key[:len(c2)]) return key[:len(c2)], nil } @@ -437,10 +437,10 @@ func DecryptASN1(priv *EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error hash.Write(key[len(c2Bytes):]) c32 := hash.Sum(nil) - if subtle.ConstantTimeCompare(c3Bytes, c32) != 1 { + if goSubtle.ConstantTimeCompare(c3Bytes, c32) != 1 { return nil, errors.New("sm9: invalid mac value") } - xor.XorBytes(key, c2Bytes, key[:len(c2Bytes)]) + subtle.XORBytes(key, c2Bytes, key[:len(c2Bytes)]) return key[:len(c2Bytes)], nil } @@ -597,7 +597,7 @@ func (ke *KeyExchange) ConfirmResponder(rB *bn256.G1, sB []byte) ([]byte, error) // step 6, verify signature if len(sB) > 0 { signature := ke.sign(false, 0x82) - if subtle.ConstantTimeCompare(signature, sB) != 1 { + if goSubtle.ConstantTimeCompare(signature, sB) != 1 { return nil, errors.New("sm9: invalid responder's signature") } } @@ -611,7 +611,7 @@ func (ke *KeyExchange) ConfirmResponder(rB *bn256.G1, sB []byte) ([]byte, error) // ConfirmInitiator for responder's step B8 func (ke *KeyExchange) ConfirmInitiator(s1 []byte) error { buffer := ke.sign(true, 0x83) - if subtle.ConstantTimeCompare(buffer, s1) != 1 { + if goSubtle.ConstantTimeCompare(buffer, s1) != 1 { return errors.New("sm9: invalid initiator's signature") } return nil diff --git a/sm9/sm9_test.go b/sm9/sm9_test.go index f0e2ab1..9873cfa 100644 --- a/sm9/sm9_test.go +++ b/sm9/sm9_test.go @@ -6,7 +6,7 @@ import ( "math/big" "testing" - "github.com/emmansun/gmsm/internal/xor" + "github.com/emmansun/gmsm/internal/subtle" "github.com/emmansun/gmsm/sm3" "github.com/emmansun/gmsm/sm9/bn256" ) @@ -493,7 +493,7 @@ func TestEncryptSM9Sample(t *testing.T) { if hex.EncodeToString(key) != expectedKey { t.Errorf("not expected key") } - xor.XorBytes(key, key[:len(plaintext)], plaintext) + subtle.XORBytes(key, key[:len(plaintext)], plaintext) hash := sm3.New() hash.Write(key) diff --git a/zuc/eea.go b/zuc/eea.go index d4fdaa1..2b5312a 100644 --- a/zuc/eea.go +++ b/zuc/eea.go @@ -4,8 +4,8 @@ import ( "crypto/cipher" "encoding/binary" + "github.com/emmansun/gmsm/internal/alias" "github.com/emmansun/gmsm/internal/subtle" - "github.com/emmansun/gmsm/internal/xor" ) const RoundWords = 32 @@ -35,7 +35,7 @@ func xorKeyStreamGeneric(c *zucState32, dst, src []byte) { for j := 0; j < RoundWords; j++ { binary.BigEndian.PutUint32(keyBytes[j*4:], keyWords[j]) } - xor.XorBytes(dst, src, keyBytes[:]) + subtle.XORBytes(dst, src, keyBytes[:]) dst = dst[RoundWords*4:] src = src[RoundWords*4:] } @@ -44,7 +44,7 @@ func xorKeyStreamGeneric(c *zucState32, dst, src []byte) { for j := 0; j < words-rounds*RoundWords; j++ { binary.BigEndian.PutUint32(keyBytes[j*4:], keyWords[j]) } - xor.XorBytes(dst, src, keyBytes[:]) + subtle.XORBytes(dst, src, keyBytes[:]) } } @@ -52,7 +52,7 @@ func (c *zucState32) XORKeyStream(dst, src []byte) { if len(dst) < len(src) { panic("zuc: output smaller than input") } - if subtle.InexactOverlap(dst[:len(src)], src) { + if alias.InexactOverlap(dst[:len(src)], src) { panic("zuc: invalid buffer overlap") } xorKeyStream(c, dst, src) diff --git a/zuc/eea_asm.go b/zuc/eea_asm.go index 464c1b7..c7e6ec2 100644 --- a/zuc/eea_asm.go +++ b/zuc/eea_asm.go @@ -4,7 +4,7 @@ package zuc import ( - "github.com/emmansun/gmsm/internal/xor" + "github.com/emmansun/gmsm/internal/subtle" ) //go:noescape @@ -17,13 +17,13 @@ func xorKeyStream(c *zucState32, dst, src []byte) { if words > 0 { dstWords := dst[:words*4] genKeyStreamRev32Asm(dstWords, c) - xor.XorBytes(dst, src, dstWords) + subtle.XORBytes(dst, src, dstWords) } // handle remain bytes if words*4 < len(src) { var singleWord [4]byte genKeyStreamRev32Asm(singleWord[:], c) - xor.XorBytes(dst[words*4:], src[words*4:], singleWord[:]) + subtle.XORBytes(dst[words*4:], src[words*4:], singleWord[:]) } } else { xorKeyStreamGeneric(c, dst, src)