diff --git a/CHANGELOG.md b/CHANGELOG.md index 56f04ca..76ce0dc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -71,6 +71,9 @@ - Stream APIs expanded across AES/SM4/DES/3DES and ChaCha20. - Updated docs to include a security-first recommendation and algorithm capability matrix. - Updated dependencies and modules for current code paths (`gmsm`, `x/crypto`). +- Refactored duplicated AES/SM4 symmetric code paths by extracting shared dispatch/helpers in `symm`. +- Unified hash method validation semantics: `hashx.SumAll` and `hashx.FileSumAll` now return errors for unknown methods (aligned with `FileSum`). +- Updated `hashx` tests for unsupported-method behavior. ### Fixed - Fixed Base128 encode/decode round-trip bug in `encodingx`. diff --git a/README.md b/README.md index c4d3060..d265e72 100644 --- a/README.md +++ b/README.md @@ -133,6 +133,7 @@ func main() { ## 兼容性说明 库中保留了部分历史/兼容用途算法与接口(例如 `ECB`、`DES/3DES`)。如无兼容要求,建议优先使用 AEAD 方案。 +- `hashx` 的 `SumAll` / `FileSumAll` / `FileSum` 对未知算法名会返回错误(不再静默忽略)。 ## 许可证 diff --git a/hashx/hashx.go b/hashx/hashx.go index c0163a5..f24246e 100644 --- a/hashx/hashx.go +++ b/hashx/hashx.go @@ -8,6 +8,7 @@ import ( "encoding/binary" "encoding/hex" "errors" + "fmt" "hash" "hash/crc32" "io" @@ -146,6 +147,27 @@ func Crc32AStr(data []byte) string { return hex.EncodeToString(crc32aDigest(data)) } +func buildHasher(method string) (hash.Hash, hash.Hash32, error) { + switch method { + case "md5": + return md5.New(), nil, nil + case "sha1": + return sha1.New(), nil, nil + case "sha224": + return sha256.New224(), nil, nil + case "sha256": + return sha256.New(), nil, nil + case "sha384": + return sha512.New384(), nil, nil + case "sha512": + return sha512.New(), nil, nil + case "crc32": + return nil, crc32.NewIEEE(), nil + default: + return nil, nil, fmt.Errorf("%w: %s", ErrUnsupportedMethod, method) + } +} + func SumAll(data []byte, methods []string) (map[string][]byte, error) { if len(methods) == 0 { methods = []string{"sha512", "sha256", "sha384", "sha224", "sha1", "crc32", "md5"} @@ -156,25 +178,15 @@ func SumAll(data []byte, methods []string) (map[string][]byte, error) { var crc hash.Hash32 for _, method := range methods { - switch method { - case "md5": - hashers[method] = md5.New() - case "sha1": - hashers[method] = sha1.New() - case "sha224": - hashers[method] = sha256.New224() - case "sha256": - hashers[method] = sha256.New() - case "sha384": - hashers[method] = sha512.New384() - case "sha512": - hashers[method] = sha512.New() - case "crc32": - if crc == nil { - crc = crc32.NewIEEE() - } - default: - // Keep compatibility with previous behavior: unknown methods are ignored. + h, h32, err := buildHasher(method) + if err != nil { + return nil, err + } + if h != nil { + hashers[method] = h + } + if h32 != nil && crc == nil { + crc = h32 } } @@ -206,40 +218,22 @@ func FileSum(filePath, method string, progress func(float64)) (string, error) { return "", err } + h, h32, err := buildHasher(method) + if err != nil { + return "", err + } + var ( - h hash.Hash - h32 hash.Hash32 - is32 bool total int64 size = stat.Size() ) - switch method { - case "sha512": - h = sha512.New() - case "sha384": - h = sha512.New384() - case "sha256": - h = sha256.New() - case "sha224": - h = sha256.New224() - case "sha1": - h = sha1.New() - case "md5": - h = md5.New() - case "crc32": - h32 = crc32.NewIEEE() - is32 = true - default: - return "", errors.New(ErrUnsupportedMethod.Error() + ": " + method) - } - buf := make([]byte, 1024*1024) for { n, readErr := fp.Read(buf) if n > 0 { total += int64(n) - if is32 { + if h32 != nil { _, _ = h32.Write(buf[:n]) } else { _, _ = h.Write(buf[:n]) @@ -254,7 +248,7 @@ func FileSum(filePath, method string, progress func(float64)) (string, error) { } } - if is32 { + if h32 != nil { return hex.EncodeToString(h32.Sum(nil)), nil } return hex.EncodeToString(h.Sum(nil)), nil @@ -279,25 +273,15 @@ func FileSumAll(filePath string, methods []string, progress func(float64)) (map[ hashers := make(map[string]hash.Hash, len(methods)) var crc hash.Hash32 for _, method := range methods { - switch method { - case "md5": - hashers[method] = md5.New() - case "sha1": - hashers[method] = sha1.New() - case "sha224": - hashers[method] = sha256.New224() - case "sha256": - hashers[method] = sha256.New() - case "sha384": - hashers[method] = sha512.New384() - case "sha512": - hashers[method] = sha512.New() - case "crc32": - if crc == nil { - crc = crc32.NewIEEE() - } - default: - // Keep compatibility with previous behavior: unknown methods are ignored. + h, h32, err := buildHasher(method) + if err != nil { + return nil, err + } + if h != nil { + hashers[method] = h + } + if h32 != nil && crc == nil { + crc = h32 } } @@ -335,7 +319,6 @@ func FileSumAll(filePath string, methods []string, progress func(float64)) (map[ } return result, nil } - func reportProgress(progress func(float64), current, total int64) { if progress == nil { return diff --git a/hashx/hashx_test.go b/hashx/hashx_test.go index ae9c727..a32dc98 100644 --- a/hashx/hashx_test.go +++ b/hashx/hashx_test.go @@ -29,16 +29,10 @@ func TestSM3AndCRC32A(t *testing.T) { } } -func TestSumAllUnknownMethodIgnored(t *testing.T) { - res, err := SumAll([]byte("abc"), []string{"sha1", "unknown"}) - if err != nil { - t.Fatalf("SumAll returned error: %v", err) - } - if _, ok := res["sha1"]; !ok { - t.Fatalf("expected sha1 in result") - } - if _, ok := res["unknown"]; ok { - t.Fatalf("unknown method should be ignored") +func TestSumAllUnsupportedMethod(t *testing.T) { + _, err := SumAll([]byte("abc"), []string{"sha1", "unknown"}) + if err == nil { + t.Fatalf("expected unsupported method error") } } @@ -63,7 +57,7 @@ func TestFileSumAndFileSumAll(t *testing.T) { t.Fatalf("progress callback should be called") } - all, err := FileSumAll(path, []string{"sha1", "crc32", "unknown"}, nil) + all, err := FileSumAll(path, []string{"sha1", "crc32"}, nil) if err != nil { t.Fatalf("FileSumAll failed: %v", err) } @@ -73,9 +67,6 @@ func TestFileSumAndFileSumAll(t *testing.T) { if _, ok := all["crc32"]; !ok { t.Fatalf("expected crc32 in FileSumAll") } - if _, ok := all["unknown"]; ok { - t.Fatalf("unknown method should not appear") - } } func TestFileSumUnsupportedMethod(t *testing.T) { @@ -88,6 +79,18 @@ func TestFileSumUnsupportedMethod(t *testing.T) { t.Fatalf("expected unsupported method error") } } + +func TestFileSumAllUnsupportedMethod(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "sum.txt") + if err := os.WriteFile(path, []byte("x"), 0o644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + if _, err := FileSumAll(path, []string{"sha256", "not-support"}, nil); err == nil { + t.Fatalf("expected unsupported method error") + } +} + func TestPBKDF2SHA256Vector(t *testing.T) { got, err := DerivePBKDF2SHA256Key("password", []byte("salt"), 1, 32) if err != nil { diff --git a/symm/aes.go b/symm/aes.go index 4bcbfa6..87cb5f0 100644 --- a/symm/aes.go +++ b/symm/aes.go @@ -2,12 +2,10 @@ package symm import ( "crypto/aes" - "crypto/cipher" "crypto/rand" "errors" "io" - "b612.me/starcrypto/ccm" "b612.me/starcrypto/paddingx" ) @@ -28,306 +26,90 @@ const ( aeadCCMNonceSize = 12 ) -func EncryptAes(data, key, iv []byte, mode, paddingType string) ([]byte, error) { - normalizedMode := normalizeCipherMode(mode) - if normalizedMode == "" { - normalizedMode = MODEGCM - } - if normalizedMode == MODEGCM { - return EncryptAesGCM(data, key, iv, nil) - } - if normalizedMode == MODECCM { - return EncryptAesCCM(data, key, iv, nil) - } +var ( + aesGCMFactory = newGCMFactory(aes.NewCipher) + aesCCMFactory = newCCMFactory(aes.NewCipher) +) - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - return encryptWithBlockMode(block, data, iv, normalizedMode, paddingType, PKCS7PADDING) +func EncryptAes(data, key, iv []byte, mode, paddingType string) ([]byte, error) { + return encryptBlockCipher(data, key, iv, mode, paddingType, PKCS7PADDING, aes.NewCipher, EncryptAesGCM, EncryptAesCCM) } func DecryptAes(src, key, iv []byte, mode, paddingType string) ([]byte, error) { - normalizedMode := normalizeCipherMode(mode) - if normalizedMode == "" { - normalizedMode = MODEGCM - } - if normalizedMode == MODEGCM { - return DecryptAesGCM(src, key, iv, nil) - } - if normalizedMode == MODECCM { - return DecryptAesCCM(src, key, iv, nil) - } - - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - return decryptWithBlockMode(block, src, iv, normalizedMode, paddingType, PKCS7PADDING) + return decryptBlockCipher(src, key, iv, mode, paddingType, PKCS7PADDING, aes.NewCipher, DecryptAesGCM, DecryptAesCCM) } func EncryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { - normalizedMode := normalizeCipherMode(mode) - if normalizedMode == "" { - normalizedMode = MODEGCM - } - if normalizedMode == MODEGCM { - return EncryptAesGCMStream(dst, src, key, iv, nil) - } - if normalizedMode == MODECCM { - return EncryptAesCCMStream(dst, src, key, iv, nil) - } - - block, err := aes.NewCipher(key) - if err != nil { - return err - } - return encryptWithBlockModeStream(block, dst, src, iv, normalizedMode, paddingType, PKCS7PADDING) + return encryptBlockCipherStream(dst, src, key, iv, mode, paddingType, PKCS7PADDING, aes.NewCipher, EncryptAesGCMStream, EncryptAesCCMStream) } func DecryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { - normalizedMode := normalizeCipherMode(mode) - if normalizedMode == "" { - normalizedMode = MODEGCM - } - if normalizedMode == MODEGCM { - return DecryptAesGCMStream(dst, src, key, iv, nil) - } - if normalizedMode == MODECCM { - return DecryptAesCCMStream(dst, src, key, iv, nil) - } - - block, err := aes.NewCipher(key) - if err != nil { - return err - } - return decryptWithBlockModeStream(block, dst, src, iv, normalizedMode, paddingType, PKCS7PADDING) + return decryptBlockCipherStream(dst, src, key, iv, mode, paddingType, PKCS7PADDING, aes.NewCipher, DecryptAesGCMStream, DecryptAesCCMStream) } func EncryptAesWithOptions(data, key []byte, opts *CipherOptions) ([]byte, error) { - cfg := normalizeCipherOptions(opts) - mode := normalizeCipherMode(cfg.Mode) - if mode == "" { - mode = MODEGCM - } - if mode == MODEGCM { - return EncryptAesGCM(data, key, nonceFromOptions(cfg), cfg.AAD) - } - if mode == MODECCM { - return EncryptAesCCM(data, key, nonceFromOptions(cfg), cfg.AAD) - } - return EncryptAes(data, key, cfg.IV, mode, cfg.Padding) + return encryptBlockWithOptions(data, key, opts, EncryptAesGCM, EncryptAesCCM, EncryptAes) } func DecryptAesWithOptions(src, key []byte, opts *CipherOptions) ([]byte, error) { - cfg := normalizeCipherOptions(opts) - mode := normalizeCipherMode(cfg.Mode) - if mode == "" { - mode = MODEGCM - } - if mode == MODEGCM { - return DecryptAesGCM(src, key, nonceFromOptions(cfg), cfg.AAD) - } - if mode == MODECCM { - return DecryptAesCCM(src, key, nonceFromOptions(cfg), cfg.AAD) - } - return DecryptAes(src, key, cfg.IV, mode, cfg.Padding) + return decryptBlockWithOptions(src, key, opts, DecryptAesGCM, DecryptAesCCM, DecryptAes) } func EncryptAesStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts *CipherOptions) error { - cfg := normalizeCipherOptions(opts) - mode := normalizeCipherMode(cfg.Mode) - if mode == "" { - mode = MODEGCM - } - if mode == MODEGCM { - return EncryptAesGCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) - } - if mode == MODECCM { - return EncryptAesCCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) - } - return EncryptAesStream(dst, src, key, cfg.IV, mode, cfg.Padding) + return encryptBlockStreamWithOptions(dst, src, key, opts, EncryptAesGCMStream, EncryptAesCCMStream, EncryptAesStream) } func DecryptAesStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts *CipherOptions) error { - cfg := normalizeCipherOptions(opts) - mode := normalizeCipherMode(cfg.Mode) - if mode == "" { - mode = MODEGCM - } - if mode == MODEGCM { - return DecryptAesGCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) - } - if mode == MODECCM { - return DecryptAesCCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) - } - return DecryptAesStream(dst, src, key, cfg.IV, mode, cfg.Padding) + return decryptBlockStreamWithOptions(dst, src, key, opts, DecryptAesGCMStream, DecryptAesCCMStream, DecryptAesStream) } func EncryptAesGCM(plain, key, nonce, aad []byte) ([]byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - gcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - if len(nonce) != gcm.NonceSize() { - return nil, ErrInvalidGCMNonceLength - } - return gcm.Seal(nil, nonce, plain, aad), nil + return encryptAEAD(aesGCMFactory, plain, key, nonce, aad, ErrInvalidGCMNonceLength) } func DecryptAesGCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - gcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - if len(nonce) != gcm.NonceSize() { - return nil, ErrInvalidGCMNonceLength - } - return gcm.Open(nil, nonce, ciphertext, aad) -} - -func newAesCCM(key []byte) (cipher.AEAD, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - return ccm.NewCCM(block, aeadCCMTagSize, aeadCCMNonceSize) + return decryptAEAD(aesGCMFactory, ciphertext, key, nonce, aad, ErrInvalidGCMNonceLength) } func EncryptAesCCM(plain, key, nonce, aad []byte) ([]byte, error) { - aead, err := newAesCCM(key) - if err != nil { - return nil, err - } - if len(nonce) != aead.NonceSize() { - return nil, ErrInvalidCCMNonceLength - } - return aead.Seal(nil, nonce, plain, aad), nil + return encryptAEAD(aesCCMFactory, plain, key, nonce, aad, ErrInvalidCCMNonceLength) } func DecryptAesCCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { - aead, err := newAesCCM(key) - if err != nil { - return nil, err - } - if len(nonce) != aead.NonceSize() { - return nil, ErrInvalidCCMNonceLength - } - return aead.Open(nil, nonce, ciphertext, aad) + return decryptAEAD(aesCCMFactory, ciphertext, key, nonce, aad, ErrInvalidCCMNonceLength) } func EncryptAesCCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { - aead, err := newAesCCM(key) - if err != nil { - return nil, err - } - if len(nonce) != aead.NonceSize() { - return nil, ErrInvalidCCMNonceLength - } - return encryptCCMChunk(aead, plain, nonce, aad, chunkIndex), nil + return encryptAEADChunk(aesCCMFactory, plain, key, nonce, aad, chunkIndex, ErrInvalidCCMNonceLength, encryptCCMChunk) } func DecryptAesCCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { - aead, err := newAesCCM(key) - if err != nil { - return nil, err - } - if len(nonce) != aead.NonceSize() { - return nil, ErrInvalidCCMNonceLength - } - return decryptCCMChunk(aead, ciphertext, nonce, aad, chunkIndex) + return decryptAEADChunk(aesCCMFactory, ciphertext, key, nonce, aad, chunkIndex, ErrInvalidCCMNonceLength, decryptCCMChunk) } func EncryptAesCCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { - aead, err := newAesCCM(key) - if err != nil { - return err - } - if len(nonce) != aead.NonceSize() { - return ErrInvalidCCMNonceLength - } - return encryptCCMChunkedStream(dst, src, aead, nonce, aad) + return encryptAEADStream(aesCCMFactory, dst, src, key, nonce, aad, ErrInvalidCCMNonceLength, encryptCCMChunkedStream) } func DecryptAesCCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { - aead, err := newAesCCM(key) - if err != nil { - return err - } - if len(nonce) != aead.NonceSize() { - return ErrInvalidCCMNonceLength - } - return decryptCCMChunkedOrLegacyStream(dst, src, aead, nonce, aad) + return decryptAEADStream(aesCCMFactory, dst, src, key, nonce, aad, ErrInvalidCCMNonceLength, decryptCCMChunkedOrLegacyStream) } func EncryptAesGCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - gcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - if len(nonce) != gcm.NonceSize() { - return nil, ErrInvalidGCMNonceLength - } - return encryptGCMChunk(gcm, plain, nonce, aad, chunkIndex), nil + return encryptAEADChunk(aesGCMFactory, plain, key, nonce, aad, chunkIndex, ErrInvalidGCMNonceLength, encryptGCMChunk) } func DecryptAesGCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - gcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - if len(nonce) != gcm.NonceSize() { - return nil, ErrInvalidGCMNonceLength - } - return decryptGCMChunk(gcm, ciphertext, nonce, aad, chunkIndex) + return decryptAEADChunk(aesGCMFactory, ciphertext, key, nonce, aad, chunkIndex, ErrInvalidGCMNonceLength, decryptGCMChunk) } func EncryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { - block, err := aes.NewCipher(key) - if err != nil { - return err - } - gcm, err := cipher.NewGCM(block) - if err != nil { - return err - } - if len(nonce) != gcm.NonceSize() { - return ErrInvalidGCMNonceLength - } - return encryptGCMChunkedStream(dst, src, gcm, nonce, aad) + return encryptAEADStream(aesGCMFactory, dst, src, key, nonce, aad, ErrInvalidGCMNonceLength, encryptGCMChunkedStream) } func DecryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { - block, err := aes.NewCipher(key) - if err != nil { - return err - } - gcm, err := cipher.NewGCM(block) - if err != nil { - return err - } - if len(nonce) != gcm.NonceSize() { - return ErrInvalidGCMNonceLength - } - return decryptGCMChunkedOrLegacyStream(dst, src, gcm, nonce, aad) + return decryptAEADStream(aesGCMFactory, dst, src, key, nonce, aad, ErrInvalidGCMNonceLength, decryptGCMChunkedOrLegacyStream) } - func EncryptAesECB(data, key []byte, paddingType string) ([]byte, error) { return EncryptAes(data, key, nil, MODEECB, paddingType) } diff --git a/symm/cipher_common.go b/symm/cipher_common.go new file mode 100644 index 0000000..d05fb34 --- /dev/null +++ b/symm/cipher_common.go @@ -0,0 +1,223 @@ +package symm + +import ( + "crypto/cipher" + "io" + + "b612.me/starcrypto/ccm" +) + +type blockCipherFactory func(key []byte) (cipher.Block, error) +type aeadFactory func(key []byte) (cipher.AEAD, error) + +type aeadBytesFunc func(data, key, nonce, aad []byte) ([]byte, error) +type aeadStreamFunc func(dst io.Writer, src io.Reader, key, nonce, aad []byte) error +type blockModeBytesFunc func(data, key, iv []byte, mode, paddingType string) ([]byte, error) +type blockModeStreamFunc func(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error + +type aeadChunkEncryptFunc func(aead cipher.AEAD, plain, nonce, aad []byte, chunkIndex uint64) []byte +type aeadChunkDecryptFunc func(aead cipher.AEAD, ciphertext, nonce, aad []byte, chunkIndex uint64) ([]byte, error) +type aeadStreamCodec func(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error + +func newGCMFactory(newBlock blockCipherFactory) aeadFactory { + return func(key []byte) (cipher.AEAD, error) { + block, err := newBlock(key) + if err != nil { + return nil, err + } + return cipher.NewGCM(block) + } +} + +func newCCMFactory(newBlock blockCipherFactory) aeadFactory { + return func(key []byte) (cipher.AEAD, error) { + block, err := newBlock(key) + if err != nil { + return nil, err + } + return ccm.NewCCM(block, aeadCCMTagSize, aeadCCMNonceSize) + } +} + +func normalizeModeOrDefaultAEAD(mode string) string { + mode = normalizeCipherMode(mode) + if mode == "" { + mode = MODEGCM + } + return mode +} + +func encryptBlockCipher(data, key, iv []byte, mode, paddingType, defaultPadding string, newBlock blockCipherFactory, encryptGCM, encryptCCM aeadBytesFunc) ([]byte, error) { + mode = normalizeModeOrDefaultAEAD(mode) + switch mode { + case MODEGCM: + return encryptGCM(data, key, iv, nil) + case MODECCM: + return encryptCCM(data, key, iv, nil) + default: + block, err := newBlock(key) + if err != nil { + return nil, err + } + return encryptWithBlockMode(block, data, iv, mode, paddingType, defaultPadding) + } +} + +func decryptBlockCipher(src, key, iv []byte, mode, paddingType, defaultPadding string, newBlock blockCipherFactory, decryptGCM, decryptCCM aeadBytesFunc) ([]byte, error) { + mode = normalizeModeOrDefaultAEAD(mode) + switch mode { + case MODEGCM: + return decryptGCM(src, key, iv, nil) + case MODECCM: + return decryptCCM(src, key, iv, nil) + default: + block, err := newBlock(key) + if err != nil { + return nil, err + } + return decryptWithBlockMode(block, src, iv, mode, paddingType, defaultPadding) + } +} + +func encryptBlockCipherStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType, defaultPadding string, newBlock blockCipherFactory, encryptGCMStream, encryptCCMStream aeadStreamFunc) error { + mode = normalizeModeOrDefaultAEAD(mode) + switch mode { + case MODEGCM: + return encryptGCMStream(dst, src, key, iv, nil) + case MODECCM: + return encryptCCMStream(dst, src, key, iv, nil) + default: + block, err := newBlock(key) + if err != nil { + return err + } + return encryptWithBlockModeStream(block, dst, src, iv, mode, paddingType, defaultPadding) + } +} + +func decryptBlockCipherStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType, defaultPadding string, newBlock blockCipherFactory, decryptGCMStream, decryptCCMStream aeadStreamFunc) error { + mode = normalizeModeOrDefaultAEAD(mode) + switch mode { + case MODEGCM: + return decryptGCMStream(dst, src, key, iv, nil) + case MODECCM: + return decryptCCMStream(dst, src, key, iv, nil) + default: + block, err := newBlock(key) + if err != nil { + return err + } + return decryptWithBlockModeStream(block, dst, src, iv, mode, paddingType, defaultPadding) + } +} + +func encryptBlockWithOptions(data, key []byte, opts *CipherOptions, encryptGCM, encryptCCM aeadBytesFunc, encryptBlock blockModeBytesFunc) ([]byte, error) { + cfg := normalizeCipherOptions(opts) + mode := normalizeModeOrDefaultAEAD(cfg.Mode) + switch mode { + case MODEGCM: + return encryptGCM(data, key, nonceFromOptions(cfg), cfg.AAD) + case MODECCM: + return encryptCCM(data, key, nonceFromOptions(cfg), cfg.AAD) + default: + return encryptBlock(data, key, cfg.IV, mode, cfg.Padding) + } +} + +func decryptBlockWithOptions(data, key []byte, opts *CipherOptions, decryptGCM, decryptCCM aeadBytesFunc, decryptBlock blockModeBytesFunc) ([]byte, error) { + cfg := normalizeCipherOptions(opts) + mode := normalizeModeOrDefaultAEAD(cfg.Mode) + switch mode { + case MODEGCM: + return decryptGCM(data, key, nonceFromOptions(cfg), cfg.AAD) + case MODECCM: + return decryptCCM(data, key, nonceFromOptions(cfg), cfg.AAD) + default: + return decryptBlock(data, key, cfg.IV, mode, cfg.Padding) + } +} + +func encryptBlockStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts *CipherOptions, encryptGCM, encryptCCM aeadStreamFunc, encryptBlockStream blockModeStreamFunc) error { + cfg := normalizeCipherOptions(opts) + mode := normalizeModeOrDefaultAEAD(cfg.Mode) + switch mode { + case MODEGCM: + return encryptGCM(dst, src, key, nonceFromOptions(cfg), cfg.AAD) + case MODECCM: + return encryptCCM(dst, src, key, nonceFromOptions(cfg), cfg.AAD) + default: + return encryptBlockStream(dst, src, key, cfg.IV, mode, cfg.Padding) + } +} + +func decryptBlockStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts *CipherOptions, decryptGCM, decryptCCM aeadStreamFunc, decryptBlockStream blockModeStreamFunc) error { + cfg := normalizeCipherOptions(opts) + mode := normalizeModeOrDefaultAEAD(cfg.Mode) + switch mode { + case MODEGCM: + return decryptGCM(dst, src, key, nonceFromOptions(cfg), cfg.AAD) + case MODECCM: + return decryptCCM(dst, src, key, nonceFromOptions(cfg), cfg.AAD) + default: + return decryptBlockStream(dst, src, key, cfg.IV, mode, cfg.Padding) + } +} + +func buildAEAD(factory aeadFactory, key, nonce []byte, errInvalidNonce error) (cipher.AEAD, error) { + aead, err := factory(key) + if err != nil { + return nil, err + } + if len(nonce) != aead.NonceSize() { + return nil, errInvalidNonce + } + return aead, nil +} + +func encryptAEAD(factory aeadFactory, plain, key, nonce, aad []byte, errInvalidNonce error) ([]byte, error) { + aead, err := buildAEAD(factory, key, nonce, errInvalidNonce) + if err != nil { + return nil, err + } + return aead.Seal(nil, nonce, plain, aad), nil +} + +func decryptAEAD(factory aeadFactory, ciphertext, key, nonce, aad []byte, errInvalidNonce error) ([]byte, error) { + aead, err := buildAEAD(factory, key, nonce, errInvalidNonce) + if err != nil { + return nil, err + } + return aead.Open(nil, nonce, ciphertext, aad) +} + +func encryptAEADChunk(factory aeadFactory, plain, key, nonce, aad []byte, chunkIndex uint64, errInvalidNonce error, encryptChunk aeadChunkEncryptFunc) ([]byte, error) { + aead, err := buildAEAD(factory, key, nonce, errInvalidNonce) + if err != nil { + return nil, err + } + return encryptChunk(aead, plain, nonce, aad, chunkIndex), nil +} + +func decryptAEADChunk(factory aeadFactory, ciphertext, key, nonce, aad []byte, chunkIndex uint64, errInvalidNonce error, decryptChunk aeadChunkDecryptFunc) ([]byte, error) { + aead, err := buildAEAD(factory, key, nonce, errInvalidNonce) + if err != nil { + return nil, err + } + return decryptChunk(aead, ciphertext, nonce, aad, chunkIndex) +} + +func encryptAEADStream(factory aeadFactory, dst io.Writer, src io.Reader, key, nonce, aad []byte, errInvalidNonce error, encryptStream aeadStreamCodec) error { + aead, err := buildAEAD(factory, key, nonce, errInvalidNonce) + if err != nil { + return err + } + return encryptStream(dst, src, aead, nonce, aad) +} + +func decryptAEADStream(factory aeadFactory, dst io.Writer, src io.Reader, key, nonce, aad []byte, errInvalidNonce error, decryptStream aeadStreamCodec) error { + aead, err := buildAEAD(factory, key, nonce, errInvalidNonce) + if err != nil { + return err + } + return decryptStream(dst, src, aead, nonce, aad) +} diff --git a/symm/sm4.go b/symm/sm4.go index 4a171f0..0edee91 100644 --- a/symm/sm4.go +++ b/symm/sm4.go @@ -1,315 +1,97 @@ package symm import ( - "crypto/cipher" "crypto/rand" "errors" "io" - "b612.me/starcrypto/ccm" "github.com/emmansun/gmsm/sm4" ) -func EncryptSM4(data, key, iv []byte, mode, paddingType string) ([]byte, error) { - normalizedMode := normalizeCipherMode(mode) - if normalizedMode == "" { - normalizedMode = MODEGCM - } - if normalizedMode == MODEGCM { - return EncryptSM4GCM(data, key, iv, nil) - } - if normalizedMode == MODECCM { - return EncryptSM4CCM(data, key, iv, nil) - } +var ( + sm4GCMFactory = newGCMFactory(sm4.NewCipher) + sm4CCMFactory = newCCMFactory(sm4.NewCipher) +) - block, err := sm4.NewCipher(key) - if err != nil { - return nil, err - } - return encryptWithBlockMode(block, data, iv, normalizedMode, paddingType, PKCS7PADDING) +func EncryptSM4(data, key, iv []byte, mode, paddingType string) ([]byte, error) { + return encryptBlockCipher(data, key, iv, mode, paddingType, PKCS7PADDING, sm4.NewCipher, EncryptSM4GCM, EncryptSM4CCM) } func DecryptSM4(src, key, iv []byte, mode, paddingType string) ([]byte, error) { - normalizedMode := normalizeCipherMode(mode) - if normalizedMode == "" { - normalizedMode = MODEGCM - } - if normalizedMode == MODEGCM { - return DecryptSM4GCM(src, key, iv, nil) - } - if normalizedMode == MODECCM { - return DecryptSM4CCM(src, key, iv, nil) - } - - block, err := sm4.NewCipher(key) - if err != nil { - return nil, err - } - return decryptWithBlockMode(block, src, iv, normalizedMode, paddingType, PKCS7PADDING) + return decryptBlockCipher(src, key, iv, mode, paddingType, PKCS7PADDING, sm4.NewCipher, DecryptSM4GCM, DecryptSM4CCM) } func EncryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { - normalizedMode := normalizeCipherMode(mode) - if normalizedMode == "" { - normalizedMode = MODEGCM - } - if normalizedMode == MODEGCM { - return EncryptSM4GCMStream(dst, src, key, iv, nil) - } - if normalizedMode == MODECCM { - return EncryptSM4CCMStream(dst, src, key, iv, nil) - } - - block, err := sm4.NewCipher(key) - if err != nil { - return err - } - return encryptWithBlockModeStream(block, dst, src, iv, normalizedMode, paddingType, PKCS7PADDING) + return encryptBlockCipherStream(dst, src, key, iv, mode, paddingType, PKCS7PADDING, sm4.NewCipher, EncryptSM4GCMStream, EncryptSM4CCMStream) } func DecryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { - normalizedMode := normalizeCipherMode(mode) - if normalizedMode == "" { - normalizedMode = MODEGCM - } - if normalizedMode == MODEGCM { - return DecryptSM4GCMStream(dst, src, key, iv, nil) - } - if normalizedMode == MODECCM { - return DecryptSM4CCMStream(dst, src, key, iv, nil) - } - - block, err := sm4.NewCipher(key) - if err != nil { - return err - } - return decryptWithBlockModeStream(block, dst, src, iv, normalizedMode, paddingType, PKCS7PADDING) + return decryptBlockCipherStream(dst, src, key, iv, mode, paddingType, PKCS7PADDING, sm4.NewCipher, DecryptSM4GCMStream, DecryptSM4CCMStream) } func EncryptSM4WithOptions(data, key []byte, opts *CipherOptions) ([]byte, error) { - cfg := normalizeCipherOptions(opts) - mode := normalizeCipherMode(cfg.Mode) - if mode == "" { - mode = MODEGCM - } - if mode == MODEGCM { - return EncryptSM4GCM(data, key, nonceFromOptions(cfg), cfg.AAD) - } - if mode == MODECCM { - return EncryptSM4CCM(data, key, nonceFromOptions(cfg), cfg.AAD) - } - return EncryptSM4(data, key, cfg.IV, mode, cfg.Padding) + return encryptBlockWithOptions(data, key, opts, EncryptSM4GCM, EncryptSM4CCM, EncryptSM4) } func DecryptSM4WithOptions(src, key []byte, opts *CipherOptions) ([]byte, error) { - cfg := normalizeCipherOptions(opts) - mode := normalizeCipherMode(cfg.Mode) - if mode == "" { - mode = MODEGCM - } - if mode == MODEGCM { - return DecryptSM4GCM(src, key, nonceFromOptions(cfg), cfg.AAD) - } - if mode == MODECCM { - return DecryptSM4CCM(src, key, nonceFromOptions(cfg), cfg.AAD) - } - return DecryptSM4(src, key, cfg.IV, mode, cfg.Padding) + return decryptBlockWithOptions(src, key, opts, DecryptSM4GCM, DecryptSM4CCM, DecryptSM4) } func EncryptSM4StreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts *CipherOptions) error { - cfg := normalizeCipherOptions(opts) - mode := normalizeCipherMode(cfg.Mode) - if mode == "" { - mode = MODEGCM - } - if mode == MODEGCM { - return EncryptSM4GCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) - } - if mode == MODECCM { - return EncryptSM4CCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) - } - return EncryptSM4Stream(dst, src, key, cfg.IV, mode, cfg.Padding) + return encryptBlockStreamWithOptions(dst, src, key, opts, EncryptSM4GCMStream, EncryptSM4CCMStream, EncryptSM4Stream) } func DecryptSM4StreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts *CipherOptions) error { - cfg := normalizeCipherOptions(opts) - mode := normalizeCipherMode(cfg.Mode) - if mode == "" { - mode = MODEGCM - } - if mode == MODEGCM { - return DecryptSM4GCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) - } - if mode == MODECCM { - return DecryptSM4CCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) - } - return DecryptSM4Stream(dst, src, key, cfg.IV, mode, cfg.Padding) + return decryptBlockStreamWithOptions(dst, src, key, opts, DecryptSM4GCMStream, DecryptSM4CCMStream, DecryptSM4Stream) } func EncryptSM4GCM(plain, key, nonce, aad []byte) ([]byte, error) { - block, err := sm4.NewCipher(key) - if err != nil { - return nil, err - } - gcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - if len(nonce) != gcm.NonceSize() { - return nil, ErrInvalidGCMNonceLength - } - return gcm.Seal(nil, nonce, plain, aad), nil + return encryptAEAD(sm4GCMFactory, plain, key, nonce, aad, ErrInvalidGCMNonceLength) } func DecryptSM4GCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { - block, err := sm4.NewCipher(key) - if err != nil { - return nil, err - } - gcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - if len(nonce) != gcm.NonceSize() { - return nil, ErrInvalidGCMNonceLength - } - return gcm.Open(nil, nonce, ciphertext, aad) + return decryptAEAD(sm4GCMFactory, ciphertext, key, nonce, aad, ErrInvalidGCMNonceLength) } func EncryptSM4GCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { - block, err := sm4.NewCipher(key) - if err != nil { - return nil, err - } - gcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - if len(nonce) != gcm.NonceSize() { - return nil, ErrInvalidGCMNonceLength - } - return encryptGCMChunk(gcm, plain, nonce, aad, chunkIndex), nil + return encryptAEADChunk(sm4GCMFactory, plain, key, nonce, aad, chunkIndex, ErrInvalidGCMNonceLength, encryptGCMChunk) } func DecryptSM4GCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { - block, err := sm4.NewCipher(key) - if err != nil { - return nil, err - } - gcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - if len(nonce) != gcm.NonceSize() { - return nil, ErrInvalidGCMNonceLength - } - return decryptGCMChunk(gcm, ciphertext, nonce, aad, chunkIndex) + return decryptAEADChunk(sm4GCMFactory, ciphertext, key, nonce, aad, chunkIndex, ErrInvalidGCMNonceLength, decryptGCMChunk) } func EncryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { - block, err := sm4.NewCipher(key) - if err != nil { - return err - } - gcm, err := cipher.NewGCM(block) - if err != nil { - return err - } - if len(nonce) != gcm.NonceSize() { - return ErrInvalidGCMNonceLength - } - return encryptGCMChunkedStream(dst, src, gcm, nonce, aad) + return encryptAEADStream(sm4GCMFactory, dst, src, key, nonce, aad, ErrInvalidGCMNonceLength, encryptGCMChunkedStream) } func DecryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { - block, err := sm4.NewCipher(key) - if err != nil { - return err - } - gcm, err := cipher.NewGCM(block) - if err != nil { - return err - } - if len(nonce) != gcm.NonceSize() { - return ErrInvalidGCMNonceLength - } - return decryptGCMChunkedOrLegacyStream(dst, src, gcm, nonce, aad) -} - -func newSM4CCM(key []byte) (cipher.AEAD, error) { - block, err := sm4.NewCipher(key) - if err != nil { - return nil, err - } - return ccm.NewCCM(block, aeadCCMTagSize, aeadCCMNonceSize) + return decryptAEADStream(sm4GCMFactory, dst, src, key, nonce, aad, ErrInvalidGCMNonceLength, decryptGCMChunkedOrLegacyStream) } func EncryptSM4CCM(plain, key, nonce, aad []byte) ([]byte, error) { - aead, err := newSM4CCM(key) - if err != nil { - return nil, err - } - if len(nonce) != aead.NonceSize() { - return nil, ErrInvalidCCMNonceLength - } - return aead.Seal(nil, nonce, plain, aad), nil + return encryptAEAD(sm4CCMFactory, plain, key, nonce, aad, ErrInvalidCCMNonceLength) } func DecryptSM4CCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { - aead, err := newSM4CCM(key) - if err != nil { - return nil, err - } - if len(nonce) != aead.NonceSize() { - return nil, ErrInvalidCCMNonceLength - } - return aead.Open(nil, nonce, ciphertext, aad) + return decryptAEAD(sm4CCMFactory, ciphertext, key, nonce, aad, ErrInvalidCCMNonceLength) } func EncryptSM4CCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { - aead, err := newSM4CCM(key) - if err != nil { - return nil, err - } - if len(nonce) != aead.NonceSize() { - return nil, ErrInvalidCCMNonceLength - } - return encryptCCMChunk(aead, plain, nonce, aad, chunkIndex), nil + return encryptAEADChunk(sm4CCMFactory, plain, key, nonce, aad, chunkIndex, ErrInvalidCCMNonceLength, encryptCCMChunk) } func DecryptSM4CCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { - aead, err := newSM4CCM(key) - if err != nil { - return nil, err - } - if len(nonce) != aead.NonceSize() { - return nil, ErrInvalidCCMNonceLength - } - return decryptCCMChunk(aead, ciphertext, nonce, aad, chunkIndex) + return decryptAEADChunk(sm4CCMFactory, ciphertext, key, nonce, aad, chunkIndex, ErrInvalidCCMNonceLength, decryptCCMChunk) } func EncryptSM4CCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { - aead, err := newSM4CCM(key) - if err != nil { - return err - } - if len(nonce) != aead.NonceSize() { - return ErrInvalidCCMNonceLength - } - return encryptCCMChunkedStream(dst, src, aead, nonce, aad) + return encryptAEADStream(sm4CCMFactory, dst, src, key, nonce, aad, ErrInvalidCCMNonceLength, encryptCCMChunkedStream) } func DecryptSM4CCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { - aead, err := newSM4CCM(key) - if err != nil { - return err - } - if len(nonce) != aead.NonceSize() { - return ErrInvalidCCMNonceLength - } - return decryptCCMChunkedOrLegacyStream(dst, src, aead, nonce, aad) + return decryptAEADStream(sm4CCMFactory, dst, src, key, nonce, aad, ErrInvalidCCMNonceLength, decryptCCMChunkedOrLegacyStream) } - func EncryptSM4CFB(origData, key []byte) ([]byte, error) { block, err := sm4.NewCipher(key) if err != nil {