diff --git a/.gitignore b/.gitignore index 357532a..c1473dc 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ # Local Go module/cache sandboxes used in this repo .gopath/ .gomodcache/ +.gocache/ # Build artifacts *.exe @@ -27,3 +28,8 @@ Thumbs.db # Temp/test artifacts sm3/ifile + +# Agent workflow artifacts +.sentrux/ +agent_readme.md +target.md diff --git a/CHANGELOG.md b/CHANGELOG.md index ac22290..56f04ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,52 @@ -# Changelog +# Changelog + +## 2026-03-17 + +--- + +### Added + +- Added AES-XTS and SM4-XTS APIs in `symm` and root wrappers: + - bytes APIs: `Encrypt*/Decrypt*XTS` + - data-unit indexed APIs: `Encrypt*/Decrypt*XTSAt` + - stream APIs: `Encrypt*/Decrypt*XTSStream` + - stream + data-unit indexed APIs: `Encrypt*/Decrypt*XTSStreamAt` +- Added XTS master-key split helpers: + - `SplitXTSMasterKey` + - `SplitAesXTSMasterKey` + - `SplitSM4XTSMasterKey` +- Added XTS parameter validation and explicit no-CTS behavior: + - dual keys must be non-empty and equal length + - `dataUnitSize` must be a positive multiple of 16 + - non-stream input must be 16-byte aligned + - stream tail must be 16-byte aligned +- Added key-derivation APIs in `hashx` and root wrappers: + - `DerivePBKDF2SHA256Key` (`crypto/pbkdf2`, stdlib) + - `DerivePBKDF2SHA512Key` (`crypto/pbkdf2`, stdlib) + - `DeriveArgon2idKey` / `DeriveArgon2iKey` (`golang.org/x/crypto/argon2`) + - `Argon2Params` + `DefaultArgon2idParams` +- Added benchmark coverage for symmetric hot paths: + - AES/SM4 `GCM`, `CCM`, `XTS` + - AES/SM4 `XTS` stream path +- Added file random fill APIs: + - `FillWithRandom` (math/rand pseudo-random) + - `FillWithCryptoRandom` (crypto/rand secure random, may be slower) +- Added HMAC verify APIs (bytes + hex wrappers) in `macx` and root package. +- Added XTS standard-vector tests (IEEE P1619 subset) and XTS fuzz tests. +- Added CCM/XTS related test coverage for root wrappers and `symm`. + +### Changed + +- Refactored XTS internals to reduce duplication via shared factory/path. +- Switched AEAD options behavior to require explicit `Nonce` (no `IV` fallback in GCM/CCM paths). +- Reworked CFB-8 register update to ring-buffer state handling. +- Refined `README.md` structure and added: + - XTS usage/constraints documentation + - AEAD nonce non-reuse requirement + - AEAD `CipherOptions` nonce-only behavior note + - legacy GCM/CCM stream decryption memory warning + - `FillWithRandom` vs `FillWithCryptoRandom` guidance + - XTS `dataUnitIndex` mapping consistency note ## Unreleased @@ -11,7 +59,7 @@ - Added unified symmetric cipher options API: - `CipherOptions{Mode, Padding, IV, Nonce, AAD}`. - Added AEAD APIs and wrappers: - - `AES-GCM`, `SM4-GCM` (bytes + stream helper APIs). + - `AES-GCM`, `SM4-GCM`, `AES-CCM`, `SM4-CCM` (bytes + chunk + stream helper APIs). - Added more symmetric mode coverage for SM4: - `ECB/CBC/CFB/OFB/CTR` (bytes + stream derived APIs). - Added comprehensive tests across packages and root wrappers. diff --git a/README.md b/README.md index 47182e3..c4d3060 100644 --- a/README.md +++ b/README.md @@ -1,37 +1,56 @@ -# starcrypto +# starcrypto `starcrypto` 是一个 Go 单包聚合风格的密码学工具库:根包可直接调用,同时内部拆分为子包实现,兼顾易用性与可维护性。 -## 特性 - -- 根包直调 + 子包分层(`asymm` / `symm` / `hashx` / `encodingx` / `paddingx` 等) -- 对称算法:AES、SM4、DES、3DES、ChaCha20、ChaCha20-Poly1305 -- 非对称算法:RSA、ECDSA、SM2、SM9 -- 哈希与消息认证:SHA 系列、MD5/MD4、RIPEMD160、SM3、CRC32/CRC32A、HMAC -- 编解码:Base64、Base85、Base91、Base128 -- 支持内存 `[]byte` 与流式 `io.Reader/io.Writer`(见下方能力矩阵) - ## 安装 ```bash go get b612.me/starcrypto ``` +## 特性 + +- 根包直调 + 子包分层(`asymm` / `symm` / `hashx` / `encodingx` / `paddingx` 等) +- 对称算法:AES、SM4、DES、3DES、ChaCha20、ChaCha20-Poly1305 +- AEAD:AES-GCM/AES-CCM、SM4-GCM/SM4-CCM、ChaCha20-Poly1305 +- 存储加密:AES-XTS、SM4-XTS(双 key,支持 data unit 索引与流式) +- 非对称算法:RSA、ECDSA、SM2、SM9 +- 哈希与消息认证:SHA 系列、MD5/MD4、RIPEMD160、SM3、CRC32/CRC32A、HMAC +- 支持内存 `[]byte` 与流式 `io.Reader/io.Writer` + ## 推荐用法(安全优先) -优先使用带认证的 AEAD: +- 通用数据传输优先使用 AEAD: + - `EncryptAesGCM/DecryptAesGCM` + - `EncryptAesCCM/DecryptAesCCM` + - `EncryptSM4GCM/DecryptSM4GCM` + - `EncryptSM4CCM/DecryptSM4CCM` + - `EncryptChaCha20Poly1305/DecryptChaCha20Poly1305` +- 磁盘/扇区类场景可用 XTS: + - `EncryptAesXTS/DecryptAesXTS` + - `EncryptSM4XTS/DecryptSM4XTS` +- 统一选项接口(默认 `GCM`): + - `EncryptAesWithOptions/DecryptAesWithOptions` + - `EncryptSM4WithOptions/DecryptSM4WithOptions` -- `EncryptAesGCM/DecryptAesGCM` -- `EncryptSM4GCM/DecryptSM4GCM` -- `EncryptChaCha20Poly1305/DecryptChaCha20Poly1305` +> `CBC/CFB/OFB/CTR/ECB/XTS` 不提供完整性校验,不能替代 AEAD 的篡改检测能力。 +> AEAD (`GCM/CCM/ChaCha20-Poly1305`) 下,同一把 key 绝不能重复使用同一个 nonce。 +> 使用 `CipherOptions` 时,AEAD 模式仅读取 `Nonce` 字段,不再回退 `IV`。 +> GCM/CCM 流解密在 legacy 兼容分支会缓冲完整密文到内存;大文件建议使用分块流格式(带 `SCG1/SCC1` 头)。 -或使用统一选项接口(默认 `GCM`): +## XTS 约束(重要) -- `EncryptAesWithOptions/DecryptAesWithOptions` -- `EncryptSM4WithOptions/DecryptSM4WithOptions` +- 两个密钥分开传入:`key1`, `key2`,且长度必须相同且非空。 +- `dataUnitSize`(数据单元大小)由调用方自定义,但必须是正整数且为 `16` 的倍数。 +- `dataUnitIndex` 必须与存储层的逻辑块映射保持稳定一致;同一数据单元在加密和解密时索引必须一致。 +- 当前实现不做 CTS(ciphertext stealing)。 +- 非流式 API:输入长度必须满足 `len(data) % 16 == 0`,否则直接返回错误。 +- 流式 API:最终尾块若不是 `16` 字节对齐会返回错误。 -> `CBC/CFB/OFB/CTR/ECB` 仅提供机密性,不提供完整性校验,可能被篡改后无法检测。 +## 文件随机填充 +- `FillWithRandom` 使用 `math/rand` 伪随机,速度更高,但不适合安全用途。 +- `FillWithCryptoRandom` 使用 `crypto/rand`,安全性更高,但速度可能受限。 ## 快速示例 ### 统一 Options(默认 GCM) @@ -65,7 +84,7 @@ func main() { } ``` -### AES CBC 流式接口(兼容模式) +### AES-XTS(按 data unit) ```go package main @@ -78,19 +97,20 @@ import ( ) func main() { - key := []byte("0123456789abcdef") - iv := []byte("abcdef9876543210") + k1 := []byte("0123456789abcdef") + k2 := []byte("fedcba9876543210") + plain := bytes.Repeat([]byte("0123456789abcdef"), 8) // 128 bytes, 16-byte aligned - src := bytes.NewReader([]byte("stream content")) - encBuf := &bytes.Buffer{} - decBuf := &bytes.Buffer{} - - if err := starcrypto.EncryptAesCBCStream(encBuf, src, key, iv, ""); err != nil { + // 每个 data unit 64 字节 + enc, err := starcrypto.EncryptAesXTS(plain, k1, k2, 64) + if err != nil { log.Fatal(err) } - if err := starcrypto.DecryptAesCBCStream(decBuf, bytes.NewReader(encBuf.Bytes()), key, iv, ""); err != nil { + dec, err := starcrypto.DecryptAesXTS(enc, k1, k2, 64) + if err != nil { log.Fatal(err) } + _ = dec } ``` @@ -98,8 +118,8 @@ func main() { | 算法 | 模式 / 方案 | AEAD | Stream | 建议 | |---|---|---:|---:|---| -| AES | ECB/CBC/CFB/OFB/CTR/GCM | GCM: 是 | 是 | 生产优先 GCM | -| SM4 | ECB/CBC/CFB/OFB/CTR/GCM | GCM: 是 | 是 | 生产优先 GCM | +| AES | ECB/CBC/CFB/OFB/CTR/GCM/CCM/XTS | GCM/CCM: 是;XTS: 否 | 是 | 生产优先 GCM/CCM;存储场景可用 XTS | +| SM4 | ECB/CBC/CFB/OFB/CTR/GCM/CCM/XTS | GCM/CCM: 是;XTS: 否 | 是 | 生产优先 GCM/CCM;存储场景可用 XTS | | ChaCha20 | ChaCha20 / ChaCha20-Poly1305 | Poly1305: 是 | ChaCha20: 是 | 生产优先 ChaCha20-Poly1305 | | DES | CBC | 否 | 是 | 仅兼容历史系统 | | 3DES | CBC | 否 | 是 | 仅兼容历史系统 | @@ -108,11 +128,11 @@ func main() { - AES/SM4 的 CBC/ECB 默认:`PKCS7` - DES/3DES 的 CBC 默认:`PKCS5` -- CFB/OFB/CTR/GCM/ChaCha20 不使用填充 +- CFB/OFB/CTR/GCM/CCM/XTS/ChaCha20 不使用填充 ## 兼容性说明 -库中保留了部分历史/兼容用途算法与接口(例如 `ECB`、`DES/3DES`)。如无兼容要求,建议使用 AEAD 方案并统一通过 `CipherOptions` 管理参数。 +库中保留了部分历史/兼容用途算法与接口(例如 `ECB`、`DES/3DES`)。如无兼容要求,建议优先使用 AEAD 方案。 ## 许可证 diff --git a/THIRD_PARTY_NOTICES.md b/THIRD_PARTY_NOTICES.md new file mode 100644 index 0000000..1fcc949 --- /dev/null +++ b/THIRD_PARTY_NOTICES.md @@ -0,0 +1,34 @@ +# Third-Party Notices + +This repository is primarily licensed under Apache-2.0. Some files are copied from third-party projects and remain under their original licenses. + +## Pion DTLS CCM implementation + +- Upstream project: https://github.com/pion/dtls +- Upstream path: `pkg/crypto/ccm` +- Local files: + - `ccm/ccm.go` + - `ccm/ccm_test.go` +- License: MIT + +### MIT License (Pion DTLS) + +Copyright (c) 2026 The Pion community + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/aes.go b/aes.go index 3464ab3..97c487e 100644 --- a/aes.go +++ b/aes.go @@ -20,6 +20,7 @@ const ( MODEOFB = symm.MODEOFB MODECTR = symm.MODECTR MODEGCM = symm.MODEGCM + MODECCM = symm.MODECCM ) func EncryptAes(data, key, iv []byte, mode, paddingType string) ([]byte, error) { @@ -62,6 +63,14 @@ func DecryptAesGCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { return symm.DecryptAesGCM(ciphertext, key, nonce, aad) } +func EncryptAesGCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + return symm.EncryptAesGCMChunk(plain, key, nonce, aad, chunkIndex) +} + +func DecryptAesGCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + return symm.DecryptAesGCMChunk(ciphertext, key, nonce, aad, chunkIndex) +} + func EncryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { return symm.EncryptAesGCMStream(dst, src, key, nonce, aad) } @@ -70,6 +79,30 @@ func DecryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) e return symm.DecryptAesGCMStream(dst, src, key, nonce, aad) } +func EncryptAesCCM(plain, key, nonce, aad []byte) ([]byte, error) { + return symm.EncryptAesCCM(plain, key, nonce, aad) +} + +func DecryptAesCCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { + return symm.DecryptAesCCM(ciphertext, key, nonce, aad) +} + +func EncryptAesCCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + return symm.EncryptAesCCMChunk(plain, key, nonce, aad, chunkIndex) +} + +func DecryptAesCCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + return symm.DecryptAesCCMChunk(ciphertext, key, nonce, aad, chunkIndex) +} + +func EncryptAesCCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + return symm.EncryptAesCCMStream(dst, src, key, nonce, aad) +} + +func DecryptAesCCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + return symm.DecryptAesCCMStream(dst, src, key, nonce, aad) +} + func EncryptAesECB(data, key []byte, paddingType string) ([]byte, error) { return symm.EncryptAesECB(data, key, paddingType) } @@ -110,6 +143,14 @@ func DecryptAesCTR(src, key, iv []byte) ([]byte, error) { return symm.DecryptAesCTR(src, key, iv) } +func EncryptAesCTRAt(data, key, iv []byte, offset int64) ([]byte, error) { + return symm.EncryptAesCTRAt(data, key, iv, offset) +} + +func DecryptAesCTRAt(src, key, iv []byte, offset int64) ([]byte, error) { + return symm.DecryptAesCTRAt(src, key, iv, offset) +} + func EncryptAesECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error { return symm.EncryptAesECBStream(dst, src, key, paddingType) } @@ -181,3 +222,62 @@ func PKCS7Padding(cipherText []byte, blockSize int) []byte { func PKCS7Trimming(encrypt []byte, blockSize int) []byte { return symm.PKCS7Trimming(encrypt, blockSize) } + +func EncryptAesCFB8(data, key, iv []byte) ([]byte, error) { + return symm.EncryptAesCFB8(data, key, iv) +} + +func DecryptAesCFB8(src, key, iv []byte) ([]byte, error) { + return symm.DecryptAesCFB8(src, key, iv) +} + +func EncryptAesCFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.EncryptAesCFB8Stream(dst, src, key, iv) +} + +func DecryptAesCFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.DecryptAesCFB8Stream(dst, src, key, iv) +} + +func DecryptAesECBBlocks(src, key []byte) ([]byte, error) { + return symm.DecryptAesECBBlocks(src, key) +} + +func DecryptAesCBCFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) { + return symm.DecryptAesCBCFromSecondBlock(src, key, prevCipherBlock) +} + +func DecryptAesCFBFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) { + return symm.DecryptAesCFBFromSecondBlock(src, key, prevCipherBlock) +} +func EncryptAesXTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) { + return symm.EncryptAesXTS(plain, key1, key2, dataUnitSize) +} + +func DecryptAesXTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) { + return symm.DecryptAesXTS(ciphertext, key1, key2, dataUnitSize) +} + +func EncryptAesXTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) { + return symm.EncryptAesXTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex) +} + +func DecryptAesXTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) { + return symm.DecryptAesXTSAt(ciphertext, key1, key2, dataUnitSize, dataUnitIndex) +} + +func EncryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error { + return symm.EncryptAesXTSStream(dst, src, key1, key2, dataUnitSize) +} + +func DecryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error { + return symm.DecryptAesXTSStream(dst, src, key1, key2, dataUnitSize) +} + +func EncryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error { + return symm.EncryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, dataUnitIndex) +} + +func DecryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error { + return symm.DecryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, dataUnitIndex) +} diff --git a/api_security_test.go b/api_security_test.go new file mode 100644 index 0000000..a7919a6 --- /dev/null +++ b/api_security_test.go @@ -0,0 +1,48 @@ +package starcrypto + +import ( + "os" + "path/filepath" + "testing" +) + +func TestRootHmacVerifyWrappers(t *testing.T) { + msg := []byte("hmac-verify-message") + key := []byte("hmac-verify-key") + + sum := HmacSHA256(msg, key) + if !VerifyHmacSHA256(msg, key, sum) { + t.Fatalf("VerifyHmacSHA256 should pass for correct digest") + } + + bad := make([]byte, len(sum)) + copy(bad, sum) + bad[0] ^= 0xff + if VerifyHmacSHA256(msg, key, bad) { + t.Fatalf("VerifyHmacSHA256 should fail for wrong digest") + } + + hexSum := HmacSHA256Str(msg, key) + if !VerifyHmacSHA256Str(msg, key, hexSum) { + t.Fatalf("VerifyHmacSHA256Str should pass for correct digest") + } + if VerifyHmacSHA256Str(msg, key, "not-hex") { + t.Fatalf("VerifyHmacSHA256Str should fail for invalid hex") + } +} + +func TestRootFillWithCryptoRandomWrapper(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "secure.bin") + + if err := FillWithCryptoRandom(path, 1024, 128, nil); err != nil { + t.Fatalf("FillWithCryptoRandom failed: %v", err) + } + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat file failed: %v", err) + } + if info.Size() != 1024 { + t.Fatalf("unexpected file size: %d", info.Size()) + } +} diff --git a/api_test.go b/api_test.go index 62d11e5..889797c 100644 --- a/api_test.go +++ b/api_test.go @@ -263,3 +263,353 @@ func TestRootOptionsAndGCMWrappers(t *testing.T) { t.Fatalf("sm4 options wrapper mismatch") } } + +func TestRootCTRAtAndGCMChunkWrappers(t *testing.T) { + aesKey := []byte("0123456789abcdef") + aesIV := []byte("abcdef9876543210") + aesNonce := []byte("123456789012") + plain := bytes.Repeat([]byte("root-ctr-offset-"), 64) + + aesFull, err := EncryptAesCTR(plain, aesKey, aesIV) + if err != nil { + t.Fatalf("EncryptAesCTR failed: %v", err) + } + offset := 33 + seg := aesFull[offset : offset+100] + decSeg, err := DecryptAesCTRAt(seg, aesKey, aesIV, int64(offset)) + if err != nil { + t.Fatalf("DecryptAesCTRAt failed: %v", err) + } + if !bytes.Equal(decSeg, plain[offset:offset+100]) { + t.Fatalf("root aes ctr at mismatch") + } + + chunkCipher, err := EncryptAesGCMChunk([]byte("root-aes-gcm-chunk"), aesKey, aesNonce, []byte("aad"), 2) + if err != nil { + t.Fatalf("EncryptAesGCMChunk failed: %v", err) + } + chunkPlain, err := DecryptAesGCMChunk(chunkCipher, aesKey, aesNonce, []byte("aad"), 2) + if err != nil { + t.Fatalf("DecryptAesGCMChunk failed: %v", err) + } + if !bytes.Equal(chunkPlain, []byte("root-aes-gcm-chunk")) { + t.Fatalf("root aes gcm chunk mismatch") + } + + sm4Key := []byte("0123456789abcdef") + sm4IV := []byte("abcdef9876543210") + sm4Nonce := []byte("123456789012") + sm4Full, err := EncryptSM4CTR(plain, sm4Key, sm4IV) + if err != nil { + t.Fatalf("EncryptSM4CTR failed: %v", err) + } + sm4Seg := sm4Full[offset : offset+100] + sm4DecSeg, err := DecryptSM4CTRAt(sm4Seg, sm4Key, sm4IV, int64(offset)) + if err != nil { + t.Fatalf("DecryptSM4CTRAt failed: %v", err) + } + if !bytes.Equal(sm4DecSeg, plain[offset:offset+100]) { + t.Fatalf("root sm4 ctr at mismatch") + } + + sm4ChunkCipher, err := EncryptSM4GCMChunk([]byte("root-sm4-gcm-chunk"), sm4Key, sm4Nonce, []byte("aad"), 3) + if err != nil { + t.Fatalf("EncryptSM4GCMChunk failed: %v", err) + } + sm4ChunkPlain, err := DecryptSM4GCMChunk(sm4ChunkCipher, sm4Key, sm4Nonce, []byte("aad"), 3) + if err != nil { + t.Fatalf("DecryptSM4GCMChunk failed: %v", err) + } + if !bytes.Equal(sm4ChunkPlain, []byte("root-sm4-gcm-chunk")) { + t.Fatalf("root sm4 gcm chunk mismatch") + } +} + +func TestRootCFB8AndSegmentDecryptWrappers(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := bytes.Repeat([]byte("0123456789abcdef"), 4) + + aesCFB8, err := EncryptAesCFB8(plain, key, iv) + if err != nil { + t.Fatalf("EncryptAesCFB8 failed: %v", err) + } + aesCFB8Dec, err := DecryptAesCFB8(aesCFB8, key, iv) + if err != nil { + t.Fatalf("DecryptAesCFB8 failed: %v", err) + } + if !bytes.Equal(aesCFB8Dec, plain) { + t.Fatalf("root aes cfb8 mismatch") + } + + aesCBC, err := EncryptAesCBC(plain, key, iv, ZEROPADDING) + if err != nil { + t.Fatalf("EncryptAesCBC failed: %v", err) + } + aesCBCSegDec, err := DecryptAesCBCFromSecondBlock(aesCBC[len(iv):], key, aesCBC[:len(iv)]) + if err != nil { + t.Fatalf("DecryptAesCBCFromSecondBlock failed: %v", err) + } + if !bytes.Equal(aesCBCSegDec, plain[len(iv):]) { + t.Fatalf("root aes cbc from-second-block mismatch") + } + + aesCFB, err := EncryptAesCFB(plain, key, iv) + if err != nil { + t.Fatalf("EncryptAesCFB failed: %v", err) + } + aesCFBSegDec, err := DecryptAesCFBFromSecondBlock(aesCFB[len(iv):], key, aesCFB[:len(iv)]) + if err != nil { + t.Fatalf("DecryptAesCFBFromSecondBlock failed: %v", err) + } + if !bytes.Equal(aesCFBSegDec, plain[len(iv):]) { + t.Fatalf("root aes cfb from-second-block mismatch") + } + + sm4CFB8, err := EncryptSM4CFB8(plain, key, iv) + if err != nil { + t.Fatalf("EncryptSM4CFB8 failed: %v", err) + } + sm4CFB8Dec, err := DecryptSM4CFB8(sm4CFB8, key, iv) + if err != nil { + t.Fatalf("DecryptSM4CFB8 failed: %v", err) + } + if !bytes.Equal(sm4CFB8Dec, plain) { + t.Fatalf("root sm4 cfb8 mismatch") + } + + sm4CBC, err := EncryptSM4CBC(plain, key, iv, ZEROPADDING) + if err != nil { + t.Fatalf("EncryptSM4CBC failed: %v", err) + } + sm4CBCSegDec, err := DecryptSM4CBCFromSecondBlock(sm4CBC[len(iv):], key, sm4CBC[:len(iv)]) + if err != nil { + t.Fatalf("DecryptSM4CBCFromSecondBlock failed: %v", err) + } + if !bytes.Equal(sm4CBCSegDec, plain[len(iv):]) { + t.Fatalf("root sm4 cbc from-second-block mismatch") + } + + sm4CFB, err := EncryptSM4CFBNoBlock(plain, key, iv) + if err != nil { + t.Fatalf("EncryptSM4CFB failed: %v", err) + } + sm4CFBSegDec, err := DecryptSM4CFBFromSecondBlock(sm4CFB[len(iv):], key, sm4CFB[:len(iv)]) + if err != nil { + t.Fatalf("DecryptSM4CFBFromSecondBlock failed: %v", err) + } + if !bytes.Equal(sm4CFBSegDec, plain[len(iv):]) { + t.Fatalf("root sm4 cfb from-second-block mismatch") + } +} +func TestRootOptionsAndCCMWrappers(t *testing.T) { + aesKey := []byte("0123456789abcdef") + sm4Key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := []byte("root-options-ccm") + + aesEnc, err := EncryptAesCCM(plain, aesKey, nonce, aad) + if err != nil { + t.Fatalf("EncryptAesCCM failed: %v", err) + } + aesDec, err := DecryptAesCCM(aesEnc, aesKey, nonce, aad) + if err != nil { + t.Fatalf("DecryptAesCCM failed: %v", err) + } + if !bytes.Equal(aesDec, plain) { + t.Fatalf("aes ccm wrapper mismatch") + } + + aesOpts := &CipherOptions{Mode: MODECCM, Nonce: nonce, AAD: aad} + aesOptEnc, err := EncryptAesWithOptions(plain, aesKey, aesOpts) + if err != nil { + t.Fatalf("EncryptAesWithOptions CCM failed: %v", err) + } + aesOptDec, err := DecryptAesWithOptions(aesOptEnc, aesKey, aesOpts) + if err != nil { + t.Fatalf("DecryptAesWithOptions CCM failed: %v", err) + } + if !bytes.Equal(aesOptDec, plain) { + t.Fatalf("aes options ccm wrapper mismatch") + } + + aesChunkCipher, err := EncryptAesCCMChunk([]byte("root-aes-ccm-chunk"), aesKey, nonce, aad, 4) + if err != nil { + t.Fatalf("EncryptAesCCMChunk failed: %v", err) + } + aesChunkPlain, err := DecryptAesCCMChunk(aesChunkCipher, aesKey, nonce, aad, 4) + if err != nil { + t.Fatalf("DecryptAesCCMChunk failed: %v", err) + } + if !bytes.Equal(aesChunkPlain, []byte("root-aes-ccm-chunk")) { + t.Fatalf("root aes ccm chunk mismatch") + } + + aesStreamEnc := &bytes.Buffer{} + if err := EncryptAesCCMStream(aesStreamEnc, bytes.NewReader(plain), aesKey, nonce, aad); err != nil { + t.Fatalf("EncryptAesCCMStream failed: %v", err) + } + aesStreamDec := &bytes.Buffer{} + if err := DecryptAesCCMStream(aesStreamDec, bytes.NewReader(aesStreamEnc.Bytes()), aesKey, nonce, aad); err != nil { + t.Fatalf("DecryptAesCCMStream failed: %v", err) + } + if !bytes.Equal(aesStreamDec.Bytes(), plain) { + t.Fatalf("aes ccm stream wrapper mismatch") + } + + sm4Enc, err := EncryptSM4CCM(plain, sm4Key, nonce, aad) + if err != nil { + t.Fatalf("EncryptSM4CCM failed: %v", err) + } + sm4Dec, err := DecryptSM4CCM(sm4Enc, sm4Key, nonce, aad) + if err != nil { + t.Fatalf("DecryptSM4CCM failed: %v", err) + } + if !bytes.Equal(sm4Dec, plain) { + t.Fatalf("sm4 ccm wrapper mismatch") + } + + sm4Opts := &CipherOptions{Mode: MODECCM, Nonce: nonce, AAD: aad} + sm4OptEnc, err := EncryptSM4WithOptions(plain, sm4Key, sm4Opts) + if err != nil { + t.Fatalf("EncryptSM4WithOptions CCM failed: %v", err) + } + sm4OptDec, err := DecryptSM4WithOptions(sm4OptEnc, sm4Key, sm4Opts) + if err != nil { + t.Fatalf("DecryptSM4WithOptions CCM failed: %v", err) + } + if !bytes.Equal(sm4OptDec, plain) { + t.Fatalf("sm4 options ccm wrapper mismatch") + } + + sm4ChunkCipher, err := EncryptSM4CCMChunk([]byte("root-sm4-ccm-chunk"), sm4Key, nonce, aad, 6) + if err != nil { + t.Fatalf("EncryptSM4CCMChunk failed: %v", err) + } + sm4ChunkPlain, err := DecryptSM4CCMChunk(sm4ChunkCipher, sm4Key, nonce, aad, 6) + if err != nil { + t.Fatalf("DecryptSM4CCMChunk failed: %v", err) + } + if !bytes.Equal(sm4ChunkPlain, []byte("root-sm4-ccm-chunk")) { + t.Fatalf("root sm4 ccm chunk mismatch") + } + + sm4StreamEnc := &bytes.Buffer{} + if err := EncryptSM4CCMStream(sm4StreamEnc, bytes.NewReader(plain), sm4Key, nonce, aad); err != nil { + t.Fatalf("EncryptSM4CCMStream failed: %v", err) + } + sm4StreamDec := &bytes.Buffer{} + if err := DecryptSM4CCMStream(sm4StreamDec, bytes.NewReader(sm4StreamEnc.Bytes()), sm4Key, nonce, aad); err != nil { + t.Fatalf("DecryptSM4CCMStream failed: %v", err) + } + if !bytes.Equal(sm4StreamDec.Bytes(), plain) { + t.Fatalf("sm4 ccm stream wrapper mismatch") + } +} +func TestRootXTSWrappers(t *testing.T) { + k1 := []byte("0123456789abcdef") + k2 := []byte("fedcba9876543210") + plain := bytes.Repeat([]byte("0123456789abcdef"), 16) + + aesEnc, err := EncryptAesXTS(plain, k1, k2, 64) + if err != nil { + t.Fatalf("EncryptAesXTS failed: %v", err) + } + aesDec, err := DecryptAesXTS(aesEnc, k1, k2, 64) + if err != nil { + t.Fatalf("DecryptAesXTS failed: %v", err) + } + if !bytes.Equal(aesDec, plain) { + t.Fatalf("root aes xts mismatch") + } + + aesSegEnc, err := EncryptAesXTSAt(plain[:64], k1, k2, 64, 1) + if err != nil { + t.Fatalf("EncryptAesXTSAt failed: %v", err) + } + aesSegDec, err := DecryptAesXTSAt(aesSegEnc, k1, k2, 64, 1) + if err != nil { + t.Fatalf("DecryptAesXTSAt failed: %v", err) + } + if !bytes.Equal(aesSegDec, plain[:64]) { + t.Fatalf("root aes xts at mismatch") + } + + aesStreamEnc := &bytes.Buffer{} + if err := EncryptAesXTSStream(aesStreamEnc, bytes.NewReader(plain), k1, k2, 64); err != nil { + t.Fatalf("EncryptAesXTSStream failed: %v", err) + } + aesStreamDec := &bytes.Buffer{} + if err := DecryptAesXTSStream(aesStreamDec, bytes.NewReader(aesStreamEnc.Bytes()), k1, k2, 64); err != nil { + t.Fatalf("DecryptAesXTSStream failed: %v", err) + } + if !bytes.Equal(aesStreamDec.Bytes(), plain) { + t.Fatalf("root aes xts stream mismatch") + } + + sm4Enc, err := EncryptSM4XTS(plain, k1, k2, 64) + if err != nil { + t.Fatalf("EncryptSM4XTS failed: %v", err) + } + sm4Dec, err := DecryptSM4XTS(sm4Enc, k1, k2, 64) + if err != nil { + t.Fatalf("DecryptSM4XTS failed: %v", err) + } + if !bytes.Equal(sm4Dec, plain) { + t.Fatalf("root sm4 xts mismatch") + } + + sm4SegEnc, err := EncryptSM4XTSAt(plain[:64], k1, k2, 64, 2) + if err != nil { + t.Fatalf("EncryptSM4XTSAt failed: %v", err) + } + sm4SegDec, err := DecryptSM4XTSAt(sm4SegEnc, k1, k2, 64, 2) + if err != nil { + t.Fatalf("DecryptSM4XTSAt failed: %v", err) + } + if !bytes.Equal(sm4SegDec, plain[:64]) { + t.Fatalf("root sm4 xts at mismatch") + } + + sm4StreamEnc := &bytes.Buffer{} + if err := EncryptSM4XTSStream(sm4StreamEnc, bytes.NewReader(plain), k1, k2, 64); err != nil { + t.Fatalf("EncryptSM4XTSStream failed: %v", err) + } + sm4StreamDec := &bytes.Buffer{} + if err := DecryptSM4XTSStream(sm4StreamDec, bytes.NewReader(sm4StreamEnc.Bytes()), k1, k2, 64); err != nil { + t.Fatalf("DecryptSM4XTSStream failed: %v", err) + } + if !bytes.Equal(sm4StreamDec.Bytes(), plain) { + t.Fatalf("root sm4 xts stream mismatch") + } +} +func TestRootKDFAndXTSKeySplitWrappers(t *testing.T) { + pbk, err := DerivePBKDF2SHA256Key("password", []byte("salt"), 1, 32) + if err != nil { + t.Fatalf("DerivePBKDF2SHA256Key failed: %v", err) + } + if len(pbk) != 32 { + t.Fatalf("pbkdf2 key length mismatch") + } + + argonParams := DefaultArgon2idParams() + argonParams.Memory = 32 * 1024 + argonParams.Threads = 1 + argon, err := DeriveArgon2idKey("password", []byte("salt-salt"), argonParams) + if err != nil { + t.Fatalf("DeriveArgon2idKey failed: %v", err) + } + if len(argon) != int(argonParams.KeyLen) { + t.Fatalf("argon2 key length mismatch") + } + + master := []byte("0123456789abcdef0123456789abcdef") + k1, k2, err := SplitXTSMasterKey(master) + if err != nil { + t.Fatalf("SplitXTSMasterKey failed: %v", err) + } + if len(k1) != 16 || len(k2) != 16 { + t.Fatalf("split xts wrapper key lengths mismatch") + } +} diff --git a/asymm/key.go b/asymm/key.go index 3401930..c246cf6 100644 --- a/asymm/key.go +++ b/asymm/key.go @@ -10,6 +10,7 @@ import ( "encoding/pem" "errors" + "github.com/emmansun/gmsm/pkcs8" "github.com/emmansun/gmsm/sm2" "github.com/emmansun/gmsm/smx509" "golang.org/x/crypto/ssh" @@ -86,6 +87,11 @@ func DecodePrivateKey(private []byte, password string) (crypto.PrivateKey, error return key, nil } return smx509.ParsePKCS8PrivateKey(bytes) + case "ENCRYPTED PRIVATE KEY": + if password == "" { + return nil, errors.New("private key is encrypted but password is empty") + } + return pkcs8.ParsePKCS8PrivateKey(blk.Bytes, []byte(password)) case "OPENSSH PRIVATE KEY": if password == "" { return ssh.ParseRawPrivateKey(private) diff --git a/asymm/rsa.go b/asymm/rsa.go index 9799072..9b09f21 100644 --- a/asymm/rsa.go +++ b/asymm/rsa.go @@ -9,6 +9,7 @@ import ( "errors" "math/big" + "github.com/emmansun/gmsm/pkcs8" "golang.org/x/crypto/ssh" ) @@ -21,6 +22,37 @@ func GenerateRsaKey(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) { } func EncodeRsaPrivateKey(private *rsa.PrivateKey, secret string) ([]byte, error) { + return EncodeRsaPrivateKeyWithLegacy(private, secret, true) +} + +func EncodeRsaPrivateKeyWithLegacy(private *rsa.PrivateKey, secret string, legacy bool) ([]byte, error) { + if legacy { + return encodeRsaPrivateKeyLegacy(private, secret) + } + return EncodeRsaPrivateKeyPKCS8(private, secret) +} + +func EncodeRsaPrivateKeyPKCS8(private *rsa.PrivateKey, secret string) ([]byte, error) { + password := []byte(secret) + var ( + der []byte + blockType = "PRIVATE KEY" + err error + ) + + if secret == "" { + der, err = pkcs8.MarshalPrivateKey(private, nil, nil) + } else { + der, err = pkcs8.MarshalPrivateKey(private, password, pkcs8.DefaultOpts) + blockType = "ENCRYPTED PRIVATE KEY" + } + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: blockType, Bytes: der}), nil +} + +func encodeRsaPrivateKeyLegacy(private *rsa.PrivateKey, secret string) ([]byte, error) { der := x509.MarshalPKCS1PrivateKey(private) if secret == "" { return pem.EncodeToMemory(&pem.Block{ @@ -52,6 +84,46 @@ func DecodeRsaPrivateKey(private []byte, password string) (*rsa.PrivateKey, erro return nil, errors.New("private key error") } + switch blk.Type { + case "PRIVATE KEY", "ENCRYPTED PRIVATE KEY": + return DecodeRsaPrivateKeyPKCS8(private, password) + default: + return decodeRsaPrivateKeyLegacy(private, password) + } +} + +func DecodeRsaPrivateKeyWithLegacy(private []byte, password string, legacy bool) (*rsa.PrivateKey, error) { + if legacy { + return decodeRsaPrivateKeyLegacy(private, password) + } + return DecodeRsaPrivateKeyPKCS8(private, password) +} + +func DecodeRsaPrivateKeyPKCS8(private []byte, password string) (*rsa.PrivateKey, error) { + blk, _ := pem.Decode(private) + if blk == nil { + return nil, errors.New("private key error") + } + + switch blk.Type { + case "PRIVATE KEY": + return pkcs8.ParsePKCS8PrivateKeyRSA(blk.Bytes) + case "ENCRYPTED PRIVATE KEY": + if password == "" { + return nil, errors.New("private key is encrypted but password is empty") + } + return pkcs8.ParsePKCS8PrivateKeyRSA(blk.Bytes, []byte(password)) + default: + return nil, errors.New("private key is not PKCS#8") + } +} + +func decodeRsaPrivateKeyLegacy(private []byte, password string) (*rsa.PrivateKey, error) { + blk, _ := pem.Decode(private) + if blk == nil { + return nil, errors.New("private key error") + } + bytes, err := decodePEMBlockBytes(blk, password) if err != nil { return nil, err @@ -60,11 +132,11 @@ func DecodeRsaPrivateKey(private []byte, password string) (*rsa.PrivateKey, erro if prikey, err := x509.ParsePKCS1PrivateKey(bytes); err == nil { return prikey, nil } - pkcs8, err := x509.ParsePKCS8PrivateKey(bytes) + pkcs8key, err := x509.ParsePKCS8PrivateKey(bytes) if err != nil { return nil, err } - prikey, ok := pkcs8.(*rsa.PrivateKey) + prikey, ok := pkcs8key.(*rsa.PrivateKey) if !ok { return nil, errors.New("private key is not RSA") } @@ -96,6 +168,10 @@ func EncodeRsaSSHPublicKey(public *rsa.PublicKey) ([]byte, error) { } func GenerateRsaSSHKeyPair(bits int, secret string) (string, string, error) { + return GenerateRsaSSHKeyPairWithLegacy(bits, secret, true) +} + +func GenerateRsaSSHKeyPairWithLegacy(bits int, secret string, legacy bool) (string, string, error) { pkey, pubkey, err := GenerateRsaKey(bits) if err != nil { return "", "", err @@ -104,7 +180,7 @@ func GenerateRsaSSHKeyPair(bits int, secret string) (string, string, error) { if err != nil { return "", "", err } - priv, err := EncodeRsaPrivateKey(pkey, secret) + priv, err := EncodeRsaPrivateKeyWithLegacy(pkey, secret, legacy) if err != nil { return "", "", err } @@ -119,6 +195,22 @@ func RSADecrypt(prikey *rsa.PrivateKey, data []byte) ([]byte, error) { return rsa.DecryptPKCS1v15(rand.Reader, prikey, data) } +func RSAEncryptOAEP(pub *rsa.PublicKey, data, label []byte, hashType crypto.Hash) ([]byte, error) { + hashType, err := normalizeModernRSAHash(hashType) + if err != nil { + return nil, err + } + return rsa.EncryptOAEP(hashType.New(), rand.Reader, pub, data, label) +} + +func RSADecryptOAEP(prikey *rsa.PrivateKey, data, label []byte, hashType crypto.Hash) ([]byte, error) { + hashType, err := normalizeModernRSAHash(hashType) + if err != nil { + return nil, err + } + return rsa.DecryptOAEP(hashType.New(), rand.Reader, prikey, data, label) +} + func RSASign(msg, priKey []byte, password string, hashType crypto.Hash) ([]byte, error) { prikey, err := DecodeRsaPrivateKey(priKey, password) if err != nil { @@ -143,6 +235,38 @@ func RSAVerify(sig, msg, pubKey []byte, hashType crypto.Hash) error { return rsa.VerifyPKCS1v15(pubkey, hashType, hashed, sig) } +func RSASignPSS(msg, priKey []byte, password string, hashType crypto.Hash, opts *rsa.PSSOptions) ([]byte, error) { + prikey, err := DecodeRsaPrivateKey(priKey, password) + if err != nil { + return nil, err + } + hashType, err = normalizeModernRSAHash(hashType) + if err != nil { + return nil, err + } + hashed, err := hashMessage(msg, hashType) + if err != nil { + return nil, err + } + return rsa.SignPSS(rand.Reader, prikey, hashType, hashed, opts) +} + +func RSAVerifyPSS(sig, msg, pubKey []byte, hashType crypto.Hash, opts *rsa.PSSOptions) error { + pubkey, err := DecodeRsaPublicKey(pubKey) + if err != nil { + return err + } + hashType, err = normalizeModernRSAHash(hashType) + if err != nil { + return err + } + hashed, err := hashMessage(msg, hashType) + if err != nil { + return err + } + return rsa.VerifyPSS(pubkey, hashType, hashed, sig, opts) +} + func RSAEncryptByPrivkey(priv *rsa.PrivateKey, data []byte) ([]byte, error) { return rsa.SignPKCS1v15(nil, priv, crypto.Hash(0), data) } @@ -154,6 +278,16 @@ func RSADecryptByPubkey(pub *rsa.PublicKey, data []byte) ([]byte, error) { return unLeftPad(em) } +func normalizeModernRSAHash(hashType crypto.Hash) (crypto.Hash, error) { + if hashType == 0 { + hashType = crypto.SHA256 + } + if !hashType.Available() { + return 0, errors.New("hash function is not available") + } + return hashType, nil +} + func hashMessage(msg []byte, hashType crypto.Hash) ([]byte, error) { if hashType == 0 { return msg, nil diff --git a/asymm/rsa_test.go b/asymm/rsa_test.go new file mode 100644 index 0000000..ef7fd1c --- /dev/null +++ b/asymm/rsa_test.go @@ -0,0 +1,134 @@ +package asymm + +import ( + "bytes" + "crypto/rsa" + "testing" +) + +func TestRsaPrivateKeyEncodeDecodeWithLegacyFlag(t *testing.T) { + priv, pub, err := GenerateRsaKey(1024) + if err != nil { + t.Fatalf("GenerateRsaKey failed: %v", err) + } + + modernPEM, err := EncodeRsaPrivateKeyWithLegacy(priv, "pwd", false) + if err != nil { + t.Fatalf("EncodeRsaPrivateKeyWithLegacy modern failed: %v", err) + } + if !bytes.Contains(modernPEM, []byte("ENCRYPTED PRIVATE KEY")) { + t.Fatalf("modern PEM should be ENCRYPTED PRIVATE KEY") + } + modernPriv, err := DecodeRsaPrivateKeyWithLegacy(modernPEM, "pwd", false) + if err != nil { + t.Fatalf("DecodeRsaPrivateKeyWithLegacy modern failed: %v", err) + } + if modernPriv.D.Cmp(priv.D) != 0 { + t.Fatalf("modern decoded private key mismatch") + } + + autoPriv, err := DecodeRsaPrivateKey(modernPEM, "pwd") + if err != nil { + t.Fatalf("DecodeRsaPrivateKey auto failed: %v", err) + } + if autoPriv.D.Cmp(priv.D) != 0 { + t.Fatalf("auto decoded private key mismatch") + } + + if _, err := DecodePrivateKey(modernPEM, "pwd"); err != nil { + t.Fatalf("DecodePrivateKey for encrypted PKCS8 failed: %v", err) + } + + legacyPEM, err := EncodeRsaPrivateKeyWithLegacy(priv, "pwd", true) + if err != nil { + t.Fatalf("EncodeRsaPrivateKeyWithLegacy legacy failed: %v", err) + } + if !bytes.Contains(legacyPEM, []byte("RSA PRIVATE KEY")) { + t.Fatalf("legacy PEM should be RSA PRIVATE KEY") + } + legacyPriv, err := DecodeRsaPrivateKeyWithLegacy(legacyPEM, "pwd", true) + if err != nil { + t.Fatalf("DecodeRsaPrivateKeyWithLegacy legacy failed: %v", err) + } + if legacyPriv.D.Cmp(priv.D) != 0 { + t.Fatalf("legacy decoded private key mismatch") + } + + pubPEM, err := EncodeRsaPublicKey(pub) + if err != nil { + t.Fatalf("EncodeRsaPublicKey failed: %v", err) + } + decodedPub, err := DecodeRsaPublicKey(pubPEM) + if err != nil { + t.Fatalf("DecodeRsaPublicKey failed: %v", err) + } + if decodedPub.N.Cmp(pub.N) != 0 || decodedPub.E != pub.E { + t.Fatalf("decoded public key mismatch") + } +} + +func TestRSAOAEPEncryptDecrypt(t *testing.T) { + priv, pub, err := GenerateRsaKey(1024) + if err != nil { + t.Fatalf("GenerateRsaKey failed: %v", err) + } + + msg := []byte("rsa-oaep-message") + label := []byte("label") + enc, err := RSAEncryptOAEP(pub, msg, label, 0) + if err != nil { + t.Fatalf("RSAEncryptOAEP failed: %v", err) + } + dec, err := RSADecryptOAEP(priv, enc, label, 0) + if err != nil { + t.Fatalf("RSADecryptOAEP failed: %v", err) + } + if !bytes.Equal(dec, msg) { + t.Fatalf("oaep decrypt mismatch") + } +} + +func TestRSAPSSSignVerify(t *testing.T) { + priv, pub, err := GenerateRsaKey(1024) + if err != nil { + t.Fatalf("GenerateRsaKey failed: %v", err) + } + + privPEM, err := EncodeRsaPrivateKeyWithLegacy(priv, "pwd", false) + if err != nil { + t.Fatalf("EncodeRsaPrivateKeyWithLegacy failed: %v", err) + } + pubPEM, err := EncodeRsaPublicKey(pub) + if err != nil { + t.Fatalf("EncodeRsaPublicKey failed: %v", err) + } + + msg := []byte("rsa-pss-message") + sig, err := RSASignPSS(msg, privPEM, "pwd", 0, &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}) + if err != nil { + t.Fatalf("RSASignPSS failed: %v", err) + } + if err := RSAVerifyPSS(sig, msg, pubPEM, 0, &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}); err != nil { + t.Fatalf("RSAVerifyPSS failed: %v", err) + } +} + +func TestRSAPKCS1v15EncryptDecryptCompatibility(t *testing.T) { + priv, pub, err := GenerateRsaKey(1024) + if err != nil { + t.Fatalf("GenerateRsaKey failed: %v", err) + } + + msg := []byte("rsa-pkcs1v15-message") + enc, err := RSAEncrypt(pub, msg) + if err != nil { + t.Fatalf("RSAEncrypt failed: %v", err) + } + dec, err := RSADecrypt(priv, enc) + if err != nil { + t.Fatalf("RSADecrypt failed: %v", err) + } + if !bytes.Equal(dec, msg) { + t.Fatalf("pkcs1v15 decrypt mismatch") + } +} diff --git a/ccm/ccm.go b/ccm/ccm.go new file mode 100644 index 0000000..406ede4 --- /dev/null +++ b/ccm/ccm.go @@ -0,0 +1,259 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +// Package ccm implements a CCM, Counter with CBC-MAC +// as per RFC 3610. +// +// See https://tools.ietf.org/html/rfc3610 +// +// This code is derived from https://github.com/pion/dtls (pkg/crypto/ccm). +// The original upstream license is MIT and is preserved in this repository. +// +// A request for including CCM into the Go standard library +// can be found as issue #27484 on the https://github.com/golang/go/ +// repository. +package ccm + +import ( + "crypto/cipher" + "crypto/subtle" + "encoding/binary" + "errors" + "math" +) + +// ccm represents a Counter with CBC-MAC with a specific key. +type ccm struct { + b cipher.Block + M uint8 + L uint8 +} + +const ccmBlockSize = 16 + +// CCM is a block cipher in Counter with CBC-MAC mode. +// Providing authenticated encryption with associated data via the cipher.AEAD interface. +type CCM interface { + cipher.AEAD + // MaxLength returns the maxium length of plaintext in calls to Seal. + // The maximum length of ciphertext in calls to Open is MaxLength()+Overhead(). + // The maximum length is related to CCM's `L` parameter (15-noncesize) and + // is 1<<(8*L) - 1 (but also limited by the maxium size of an int). + MaxLength() int +} + +var ( + errInvalidBlockSize = errors.New("ccm: NewCCM requires 128-bit block cipher") + errInvalidTagSize = errors.New("ccm: tagsize must be 4, 6, 8, 10, 12, 14, or 16") + errInvalidNonceSize = errors.New("ccm: invalid nonce size") +) + +// NewCCM returns the given 128-bit block cipher wrapped in CCM. +// The tagsize must be an even integer between 4 and 16 inclusive +// and is used as CCM's `M` parameter. +// The noncesize must be an integer between 7 and 13 inclusive, +// 15-noncesize is used as CCM's `L` parameter. +func NewCCM(b cipher.Block, tagsize, noncesize int) (CCM, error) { + if b.BlockSize() != ccmBlockSize { + return nil, errInvalidBlockSize + } + if tagsize < 4 || tagsize > 16 || tagsize&1 != 0 { + return nil, errInvalidTagSize + } + lensize := 15 - noncesize + if lensize < 2 || lensize > 8 { + return nil, errInvalidNonceSize + } + c := &ccm{b: b, M: uint8(tagsize), L: uint8(lensize)} //nolint:gosec // G114 + + return c, nil +} + +func (c *ccm) NonceSize() int { return 15 - int(c.L) } +func (c *ccm) Overhead() int { return int(c.M) } +func (c *ccm) MaxLength() int { return maxlen(c.L, c.Overhead()) } + +func maxlen(l uint8, tagsize int) int { + mLen := (uint64(1) << (8 * l)) - 1 + if m64 := uint64(math.MaxInt64) - uint64(tagsize); l > 8 || mLen > m64 { //nolint:gosec // G114 + mLen = m64 // The maximum lentgh on a 64bit arch + } + if mLen != uint64(int(mLen)) { //nolint:gosec // G114 + return math.MaxInt32 - tagsize // We have only 32bit int's + } + + return int(mLen) //nolint:gosec // G114 +} + +// MaxNonceLength returns the maximum nonce length for a given plaintext length. +// A return value <= 0 indicates that plaintext length is too large for +// any nonce length. +func MaxNonceLength(pdatalen int) int { + const tagsize = 16 + for L := 2; L <= 8; L++ { + if maxlen(uint8(L), tagsize) >= pdatalen { //nolint:gosec // G115 + return 15 - L + } + } + + return 0 +} + +func (c *ccm) cbcRound(mac, data []byte) { + for i := range ccmBlockSize { + mac[i] ^= data[i] + } + c.b.Encrypt(mac, mac) +} + +func (c *ccm) cbcData(mac, data []byte) { + for len(data) >= ccmBlockSize { + c.cbcRound(mac, data[:ccmBlockSize]) + data = data[ccmBlockSize:] + } + if len(data) > 0 { + var block [ccmBlockSize]byte + copy(block[:], data) + c.cbcRound(mac, block[:]) + } +} + +var errPlaintextTooLong = errors.New("ccm: plaintext too large") + +func (c *ccm) tag(nonce, plaintext, adata []byte) ([]byte, error) { + var mac [ccmBlockSize]byte + + if len(adata) > 0 { + mac[0] |= 1 << 6 + } + mac[0] |= (c.M - 2) << 2 + mac[0] |= c.L - 1 + if len(nonce) != c.NonceSize() { + return nil, errInvalidNonceSize + } + if len(plaintext) > c.MaxLength() { + return nil, errPlaintextTooLong + } + binary.BigEndian.PutUint64(mac[ccmBlockSize-8:], uint64(len(plaintext))) + copy(mac[1:ccmBlockSize-c.L], nonce) + c.b.Encrypt(mac[:], mac[:]) + + var block [ccmBlockSize]byte + if adataLength := uint64(len(adata)); adataLength > 0 { //nolint:nestif + // First adata block includes adata length + i := 2 + if adataLength <= 0xfeff { + binary.BigEndian.PutUint16(block[:i], uint16(adataLength)) + } else { + binary.BigEndian.PutUint16(block[0:2], 0xfeff) + if adataLength < uint64(1<<32) { + i = 2 + 4 + binary.BigEndian.PutUint32(block[2:i], uint32(adataLength)) //nolint:gosec // G115 + } else { + i = 2 + 8 + binary.BigEndian.PutUint64(block[2:i], adataLength) + } + } + i = copy(block[i:], adata) + c.cbcRound(mac[:], block[:]) + c.cbcData(mac[:], adata[i:]) + } + + if len(plaintext) > 0 { + c.cbcData(mac[:], plaintext) + } + + return mac[:c.M], nil +} + +// sliceForAppend takes a slice and a requested number of bytes. It returns a +// slice with the contents of the given slice followed by that many bytes and a +// second slice that aliases into it and contains only the extra bytes. If the +// original slice has sufficient capacity then no allocation is performed. +// From crypto/cipher/gcm.go +// . +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + + return +} + +// Seal encrypts and authenticates plaintext, authenticates the +// additional data and appends the result to dst, returning the updated +// slice. The nonce must be NonceSize() bytes long and unique for all +// time, for a given key. +// The plaintext must be no longer than MaxLength() bytes long. +// +// The plaintext and dst may alias exactly or not at all. +func (c *ccm) Seal(dst, nonce, plaintext, adata []byte) []byte { + tag, err := c.tag(nonce, plaintext, adata) + if err != nil { + // The cipher.AEAD interface doesn't allow for an error return. + panic(err) // nolint + } + + var iv, s0 [ccmBlockSize]byte + iv[0] = c.L - 1 + copy(iv[1:ccmBlockSize-c.L], nonce) + c.b.Encrypt(s0[:], iv[:]) + for i := 0; i < int(c.M); i++ { + tag[i] ^= s0[i] + } + iv[len(iv)-1] |= 1 + stream := cipher.NewCTR(c.b, iv[:]) + ret, out := sliceForAppend(dst, len(plaintext)+int(c.M)) + stream.XORKeyStream(out, plaintext) + copy(out[len(plaintext):], tag) + + return ret +} + +var ( + errOpen = errors.New("ccm: message authentication failed") + errCiphertextTooShort = errors.New("ccm: ciphertext too short") + errCiphertextTooLong = errors.New("ccm: ciphertext too long") +) + +func (c *ccm) Open(dst, nonce, ciphertext, adata []byte) ([]byte, error) { + if len(ciphertext) < int(c.M) { + return nil, errCiphertextTooShort + } + if len(ciphertext) > c.MaxLength()+c.Overhead() { + return nil, errCiphertextTooLong + } + + tag := make([]byte, int(c.M)) + copy(tag, ciphertext[len(ciphertext)-int(c.M):]) + ciphertextWithoutTag := ciphertext[:len(ciphertext)-int(c.M)] + + var iv, s0 [ccmBlockSize]byte + iv[0] = c.L - 1 + copy(iv[1:ccmBlockSize-c.L], nonce) + c.b.Encrypt(s0[:], iv[:]) + for i := 0; i < int(c.M); i++ { + tag[i] ^= s0[i] + } + iv[len(iv)-1] |= 1 + stream := cipher.NewCTR(c.b, iv[:]) + + // Cannot decrypt directly to dst since we're not supposed to + // reveal the plaintext to the caller if authentication fails. + plaintext := make([]byte, len(ciphertextWithoutTag)) + stream.XORKeyStream(plaintext, ciphertextWithoutTag) + expectedTag, err := c.tag(nonce, plaintext, adata) + if err != nil { + return nil, err + } + + if subtle.ConstantTimeCompare(tag, expectedTag) != 1 { + return nil, errOpen + } + + return append(dst, plaintext...), nil +} diff --git a/ccm/ccm_test.go b/ccm/ccm_test.go new file mode 100644 index 0000000..41a98a2 --- /dev/null +++ b/ccm/ccm_test.go @@ -0,0 +1,451 @@ +// SPDX-FileCopyrightText: 2026 The Pion community +// SPDX-License-Identifier: MIT + +package ccm + +// Refer to RFC 3610 section 8 for the vectors. + +import ( + "crypto/aes" + "encoding/hex" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func mustHexDecode(t *testing.T, s string) []byte { + t.Helper() + r, err := hex.DecodeString(s) + assert.NoError(t, err) + + return r +} + +func aesKey1to12(t *testing.T) []byte { + t.Helper() + + return mustHexDecode(t, "c0c1c2c3c4c5c6c7c8c9cacbcccdcecf") +} + +func aesKey13to24(t *testing.T) []byte { + t.Helper() + + return mustHexDecode(t, "d7828d13b2b0bdc325a76236df93cc6b") +} + +// AESKey: AES Key +// CipherText: Authenticated and encrypted output +// ClearHeaderOctets: Input with X cleartext header octets +// Data: Input with X cleartext header octets +// M: length(CBC-MAC) +// Nonce: Nonce. +type vector struct { + AESKey []byte + CipherText []byte + ClearHeaderOctets int + Data []byte + M int + Nonce []byte +} + +func TestRFC3610Vectors(t *testing.T) { //nolint:maintidx + cases := []vector{ + // Vectors 1-12 + { + AESKey: aesKey1to12(t), + CipherText: mustHexDecode(t, + "0001020304050607588c979a61c663d2f066d0c2c0f989806d5f6b61dac38417e8d12cfdf926e0"), + ClearHeaderOctets: 8, + Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e"), + M: 8, + Nonce: mustHexDecode(t, "00000003020100a0a1a2a3a4a5"), + }, + { + AESKey: aesKey1to12(t), + CipherText: mustHexDecode(t, + "000102030405060772c91a36e135f8cf291ca894085c87e3cc15c439c9e43a3ba091d56e10400916"), + ClearHeaderOctets: 8, + Data: mustHexDecode(t, "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F"), + M: 8, + Nonce: mustHexDecode(t, "00000004030201a0a1a2a3a4a5"), + }, + { + AESKey: aesKey1to12(t), + CipherText: mustHexDecode(t, + "000102030405060751b1e5f44a197d1da46b0f8e2d282ae871e838bb64da8596574adaa76fbd9fb0c5", + ), + ClearHeaderOctets: 8, + Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"), + M: 8, + Nonce: mustHexDecode(t, "00000005040302a0a1a2a3a4a5"), + }, + { + AESKey: aesKey1to12(t), + CipherText: mustHexDecode(t, + "000102030405060708090a0ba28c6865939a9a79faaa5c4c2a9d4a91cdac8c96c861b9c9e61ef1"), + ClearHeaderOctets: 12, + Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e"), + M: 8, + Nonce: mustHexDecode(t, "00000006050403a0a1a2a3a4a5"), + }, + { + AESKey: aesKey1to12(t), + CipherText: mustHexDecode(t, + "000102030405060708090a0bdcf1fb7b5d9e23fb9d4e131253658ad86ebdca3e51e83f077d9c2d93"), + ClearHeaderOctets: 12, + Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), + M: 8, + Nonce: mustHexDecode(t, "00000007060504a0a1a2a3a4a5"), + }, + { + AESKey: aesKey1to12(t), + CipherText: mustHexDecode(t, + "000102030405060708090a0b6fc1b011f006568b5171a42d953d469b2570a4bd87405a0443ac91cb94", + ), + ClearHeaderOctets: 12, + Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"), + M: 8, + Nonce: mustHexDecode(t, "00000008070605a0a1a2a3a4a5"), + }, + { + AESKey: aesKey1to12(t), + CipherText: mustHexDecode(t, + "00010203040506070135d1b2c95f41d5d1d4fec185d166b8094e999dfed96c048c56602c97acbb7490", + ), + ClearHeaderOctets: 8, + Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e"), + M: 10, + Nonce: mustHexDecode(t, "00000009080706a0a1a2a3a4a5"), + }, + { + AESKey: aesKey1to12(t), + CipherText: mustHexDecode(t, + "00010203040506077b75399ac0831dd2f0bbd75879a2fd8f6cae6b6cd9b7db24c17b4433f434963f34b4", + ), + ClearHeaderOctets: 8, + Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), + M: 10, + Nonce: mustHexDecode(t, "0000000a090807a0a1a2a3a4a5"), + }, + { + AESKey: aesKey1to12(t), + CipherText: mustHexDecode(t, + "000102030405060782531a60cc24945a4b8279181ab5c84df21ce7f9b73f42e197ea9c07e56b5eb17e5f4e", + ), + ClearHeaderOctets: 8, + Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"), + M: 10, + Nonce: mustHexDecode(t, "0000000b0a0908a0a1a2a3a4a5"), + }, + { + AESKey: aesKey1to12(t), + CipherText: mustHexDecode(t, + "000102030405060708090a0b07342594157785152b074098330abb141b947b566aa9406b4d999988dd", + ), + ClearHeaderOctets: 12, + Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e"), + M: 10, + Nonce: mustHexDecode(t, "0000000c0b0a09a0a1a2a3a4a5"), + }, + { + AESKey: aesKey1to12(t), + CipherText: mustHexDecode(t, + "000102030405060708090a0b676bb20380b0e301e8ab79590a396da78b834934f53aa2e9107a8b6c022c", + ), + ClearHeaderOctets: 12, + Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"), + M: 10, + Nonce: mustHexDecode(t, "0000000d0c0b0aa0a1a2a3a4a5"), + }, + { + AESKey: aesKey1to12(t), + CipherText: mustHexDecode(t, + "000102030405060708090a0bc0ffa0d6f05bdb67f24d43a4338d2aa4bed7b20e43cd1aa31662e7ad65d6db", + ), + ClearHeaderOctets: 12, + Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"), + M: 10, + Nonce: mustHexDecode(t, "0000000e0d0c0ba0a1a2a3a4a5"), + }, + // Vectors 13-24 + { + AESKey: aesKey13to24(t), + CipherText: mustHexDecode(t, + "0be1a88bace018b14cb97f86a2a4689a877947ab8091ef5386a6ffbdd080f8e78cf7cb0cddd7b3"), + ClearHeaderOctets: 8, + Data: mustHexDecode(t, "0be1a88bace018b108e8cf97d820ea258460e96ad9cf5289054d895ceac47c"), + M: 8, + Nonce: mustHexDecode(t, "00412b4ea9cdbe3c9696766cfa"), + }, + { + AESKey: aesKey13to24(t), + CipherText: mustHexDecode(t, + "63018f76dc8a1bcb4ccb1e7ca981befaa0726c55d378061298c85c92814abc33c52ee81d7d77c08a"), + ClearHeaderOctets: 8, + Data: mustHexDecode(t, "63018f76dc8a1bcb9020ea6f91bdd85afa0039ba4baff9bfb79c7028949cd0ec"), + M: 8, + Nonce: mustHexDecode(t, "0033568ef7b2633c9696766cfa"), + }, + { + AESKey: aesKey13to24(t), + CipherText: mustHexDecode(t, + "aa6cfa36cae86b40b1d23a2220ddc0ac900d9aa03c61fcf4a559a4417767089708a776796edb723506", + ), + ClearHeaderOctets: 8, + Data: mustHexDecode(t, "aa6cfa36cae86b40b916e0eacc1c00d7dcec68ec0b3bbb1a02de8a2d1aa346132e"), + M: 8, + Nonce: mustHexDecode(t, "00103fe41336713c9696766cfa"), + }, + { + AESKey: aesKey13to24(t), + CipherText: mustHexDecode(t, + "d0d0735c531e1becf049c24414d253c3967b70609b7cbb7c499160283245269a6f49975bcadeaf"), + ClearHeaderOctets: 12, + Data: mustHexDecode(t, "d0d0735c531e1becf049c24412daac5630efa5396f770ce1a66b21f7b2101c"), + M: 8, + Nonce: mustHexDecode(t, "00764c63b8058e3c9696766cfa"), + }, + { + AESKey: aesKey13to24(t), + CipherText: mustHexDecode(t, + "77b60f011c03e1525899bcae5545ff1a085ee2efbf52b2e04bee1e2336c73e3f762c0c7744fe7e3c"), + ClearHeaderOctets: 12, + Data: mustHexDecode(t, "77b60f011c03e1525899bcaee88b6a46c78d63e52eb8c546efb5de6f75e9cc0d"), + M: 8, + Nonce: mustHexDecode(t, "00f8b678094e3b3c9696766cfa"), + }, + { + AESKey: aesKey13to24(t), + CipherText: mustHexDecode(t, + "cd9044d2b71fdb8120ea60c0009769ecabdf48625594c59251e6035722675e04c847099e5ae0704551", + ), + ClearHeaderOctets: 12, + Data: mustHexDecode(t, "cd9044d2b71fdb8120ea60c06435acbafb11a82e2f071d7ca4a5ebd93a803ba87f"), + M: 8, + Nonce: mustHexDecode(t, "00d560912d3f703c9696766cfa"), + }, + { + AESKey: aesKey13to24(t), + CipherText: mustHexDecode(t, + "d85bc7e69f944fb8bc218daa947427b6db386a99ac1aef23ade0b52939cb6a637cf9bec2408897c6ba", + ), + ClearHeaderOctets: 8, + Data: mustHexDecode(t, "d85bc7e69f944fb88a19b950bcf71a018e5e6701c91787659809d67dbedd18"), + M: 10, + Nonce: mustHexDecode(t, "0042fff8f1951c3c9696766cfa"), + }, + { + AESKey: aesKey13to24(t), + CipherText: mustHexDecode(t, + "74a0ebc9069f5b375810e6fd25874022e80361a478e3e9cf484ab04f447efff6f0a477cc2fc9bf548944", + ), + ClearHeaderOctets: 8, + Data: mustHexDecode(t, "74a0ebc9069f5b371761433c37c5a35fc1f39f406302eb907c6163be38c98437"), + M: 10, + Nonce: mustHexDecode(t, "00920f40e56cdc3c9696766cfa"), + }, + { + AESKey: aesKey13to24(t), + CipherText: mustHexDecode(t, + "44a3aa3aae6475caf2beed7bc5098e83feb5b31608f8e29c38819a89c8e776f1544d4151a4ed3a8b87b9ce", + ), + ClearHeaderOctets: 8, + Data: mustHexDecode(t, "44a3aa3aae6475caa434a8e58500c6e41530538862d686ea9e81301b5ae4226bfa"), + M: 10, + Nonce: mustHexDecode(t, "0027ca0c7120bc3c9696766cfa"), + }, + { + AESKey: aesKey13to24(t), + CipherText: mustHexDecode(t, + "ec46bb63b02520c33c49fd7031d750a09da3ed7fddd49a2032aabf17ec8ebf7d22c8088c666be5c197", + ), + ClearHeaderOctets: 12, + Data: mustHexDecode(t, "ec46bb63b02520c33c49fd70b96b49e21d621741632875db7f6c9243d2d7c2"), + M: 10, + Nonce: mustHexDecode(t, "005b8ccbcd9af83c9696766cfa"), + }, + { + AESKey: aesKey13to24(t), + CipherText: mustHexDecode(t, + "47a65ac78b3d594227e85e71e882f1dbd38ce3eda7c23f04dd65071eb41342acdf7e00dccec7ae52987d", + ), + ClearHeaderOctets: 12, + Data: mustHexDecode(t, "47a65ac78b3d594227e85e71e2fcfbb880442c731bf95167c8ffd7895e337076"), + M: 10, + Nonce: mustHexDecode(t, "003ebe94044b9a3c9696766cfa"), + }, + { + AESKey: aesKey13to24(t), + CipherText: mustHexDecode(t, + "6e37a6ef546d955d34ab6059f32905b88a641b04b9c9ffb58cc390900f3da12ab16dce9e82efa16da62059", + ), + ClearHeaderOctets: 12, + Data: mustHexDecode(t, "6e37a6ef546d955d34ab6059abf21c0b02feb88f856df4a37381bce3cc128517d4"), + M: 10, + Nonce: mustHexDecode(t, "008d493b30ae8b3c9696766cfa"), + }, + } + + assert.Equal(t, 24, len(cases)) + + for idx, testCase := range cases { + t.Run(fmt.Sprintf("packet vector #%d", idx+1), func(t *testing.T) { + blk, err := aes.NewCipher(testCase.AESKey) + assert.NoError(t, err, "could not initialize AES block cipher from key") + + lccm, err := NewCCM(blk, testCase.M, len(testCase.Nonce)) + assert.NoError(t, err, "could not create CCM") + + t.Run("seal", func(t *testing.T) { + var dst []byte + dst = lccm.Seal( + dst, + testCase.Nonce, + testCase.Data[testCase.ClearHeaderOctets:], + testCase.Data[:testCase.ClearHeaderOctets], + ) + assert.Equal(t, testCase.CipherText[testCase.ClearHeaderOctets:], dst) + }) + + t.Run("open", func(t *testing.T) { + var dst []byte + dst, err = lccm.Open( + dst, + testCase.Nonce, + testCase.CipherText[testCase.ClearHeaderOctets:], + testCase.CipherText[:testCase.ClearHeaderOctets], + ) + assert.NoError(t, err) + assert.Equal(t, testCase.Data[testCase.ClearHeaderOctets:], dst) + }) + }) + } +} + +func TestNewCCMError(t *testing.T) { + cases := map[string]struct { + vector + err error + }{ + "ShortNonceLength": { + vector{ + AESKey: aesKey1to12(t), + M: 8, + Nonce: mustHexDecode(t, "a0a1a2a3a4a5"), + }, errInvalidNonceSize, + }, + "LongNonceLength": { + vector{ + AESKey: aesKey1to12(t), + M: 8, + Nonce: mustHexDecode(t, "0001020304050607080910111213"), + }, errInvalidNonceSize, + }, + "ShortTag": { + vector{ + AESKey: aesKey1to12(t), + M: 3, + Nonce: mustHexDecode(t, "00010203040506070809101112"), + }, errInvalidTagSize, + }, + "LongTag": { + vector{ + AESKey: aesKey1to12(t), + M: 17, + Nonce: mustHexDecode(t, "00010203040506070809101112"), + }, errInvalidTagSize, + }, + } + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + blk, err := aes.NewCipher(c.AESKey) + assert.NoError(t, err, "could not initialize AES block cipher from key") + + _, err = NewCCM(blk, c.M, len(c.Nonce)) + assert.ErrorIs(t, err, c.err) + }) + } +} + +func TestSealError(t *testing.T) { + cases := map[string]struct { + vector + err error + }{ + "InvalidNonceLength": { + vector{ + Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e"), + M: 8, + Nonce: mustHexDecode(t, "00000003020100a0a1a2a3a4"), // short + }, errInvalidNonceSize, + }, + "PlaintextTooLong": { + vector{ + Data: make([]byte, 100000), + M: 8, + Nonce: mustHexDecode(t, "00000003020100a0a1a2a3a4a5"), + }, errPlaintextTooLong, + }, + } + + blk, err := aes.NewCipher(aesKey1to12(t)) + assert.NoError(t, err) + + lccm, err := NewCCM(blk, 8, 13) + assert.NoError(t, err) + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + defer func() { + err, ok := recover().(error) + assert.True(t, ok) + assert.ErrorIs(t, err, testCase.err) + }() + var dst []byte + _ = lccm.Seal( + dst, + testCase.Nonce, + testCase.Data[testCase.ClearHeaderOctets:], + testCase.Data[:testCase.ClearHeaderOctets], + ) + }) + } +} + +func TestOpenError(t *testing.T) { + cases := map[string]struct { + vector + err error + }{ + "CiphertextTooShort": { + vector{ + CipherText: make([]byte, 10), + ClearHeaderOctets: 8, + Nonce: mustHexDecode(t, "00000003020100a0a1a2a3a4a5"), + }, errCiphertextTooShort, + }, + "CiphertextTooLong": { + vector{ + CipherText: make([]byte, 100000), + ClearHeaderOctets: 8, + Nonce: mustHexDecode(t, "00000003020100a0a1a2a3a4a5"), + }, errCiphertextTooLong, + }, + } + + blk, err := aes.NewCipher(aesKey1to12(t)) + assert.NoError(t, err, "could not initialize AES block cipher from key") + + lccm, err := NewCCM(blk, 8, 13) + assert.NoError(t, err, "could not create CCM") + + for name, c := range cases { + t.Run(name, func(t *testing.T) { + var dst []byte + _, err = lccm.Open(dst, c.Nonce, c.CipherText[c.ClearHeaderOctets:], c.CipherText[:c.ClearHeaderOctets]) + assert.ErrorIs(t, err, c.err) + }) + } +} diff --git a/encodingx/encoding.go b/encodingx/encoding.go index f743c49..0b29f14 100644 --- a/encodingx/encoding.go +++ b/encodingx/encoding.go @@ -6,6 +6,7 @@ import ( "errors" "io" "os" + "strings" ) var ( @@ -192,12 +193,12 @@ func Base85Encode(bstr []byte) string { } func Base85Decode(str string) ([]byte, error) { - out := make([]byte, len(str)) - n, _, err := ascii85.Decode(out, []byte(str), true) + dec := ascii85.NewDecoder(strings.NewReader(str)) + out, err := io.ReadAll(dec) if err != nil { return nil, err } - return out[:n], nil + return out, nil } func Base85EncodeFile(src, dst string, progress func(float64)) error { diff --git a/encodingx/encoding_test.go b/encodingx/encoding_test.go index c016ca9..bdb3dc0 100644 --- a/encodingx/encoding_test.go +++ b/encodingx/encoding_test.go @@ -96,3 +96,36 @@ func TestBase64AndBase85FileRoundTrip(t *testing.T) { t.Fatalf("base85 file roundtrip mismatch") } } + +func TestBase85RoundTripEdgeLengths(t *testing.T) { + for n := 0; n <= 128; n++ { + plain := make([]byte, n) + for i := range plain { + plain[i] = byte((i*37 + n) % 256) + } + + e := Base85Encode(plain) + d, err := Base85Decode(e) + if err != nil { + t.Fatalf("Base85Decode failed at len=%d: %v", n, err) + } + if !bytes.Equal(d, plain) { + t.Fatalf("base85 mismatch at len=%d", n) + } + } +} + +func TestBase91RoundTripEdgeLengths(t *testing.T) { + for n := 0; n <= 256; n++ { + plain := make([]byte, n) + for i := range plain { + plain[i] = byte((i*53 + n) % 256) + } + + e := Base91Encode(plain) + d := Base91Decode(e) + if !bytes.Equal(d, plain) { + t.Fatalf("base91 mismatch at len=%d", n) + } + } +} diff --git a/file.go b/file.go index f1b46a5..d8a704d 100644 --- a/file.go +++ b/file.go @@ -18,6 +18,13 @@ func MergeFile(src, dst string, shell func(float64)) error { return filex.MergeFile(src, dst, shell) } +// FillWithRandom uses math/rand pseudo-random bytes and is not cryptographically secure. func FillWithRandom(filepath string, filesize int, bufcap int, bufnum int, shell func(float64)) error { return filex.FillWithRandom(filepath, filesize, bufcap, bufnum, shell) } + +// FillWithCryptoRandom uses crypto/rand secure random bytes. +// Throughput may be lower than FillWithRandom. +func FillWithCryptoRandom(filepath string, filesize int, bufcap int, shell func(float64)) error { + return filex.FillWithCryptoRandom(filepath, filesize, bufcap, shell) +} diff --git a/filex/file.go b/filex/file.go index c0a7160..ccbc3a5 100644 --- a/filex/file.go +++ b/filex/file.go @@ -2,10 +2,11 @@ package filex import ( "bufio" + crand "crypto/rand" "errors" "fmt" "io" - "math/rand" + mrand "math/rand" "os" "path/filepath" "regexp" @@ -15,6 +16,8 @@ import ( "time" ) +var ErrInvalidSplitPattern = errors.New("split dst pattern must contain exactly one '*'") + func Attach(src, dst, output string) error { fpsrc, err := os.Open(src) if err != nil { @@ -81,6 +84,9 @@ func SplitFile(src, dst string, num int, bynum bool, progress func(float64)) err if num <= 0 { return errors.New("num must be greater than zero") } + if strings.Count(dst, "*") != 1 { + return ErrInvalidSplitPattern + } fpsrc, err := os.Open(src) if err != nil { @@ -244,6 +250,8 @@ func MergeFile(src, dst string, progress func(float64)) error { return nil } +// FillWithRandom fills file with pseudo-random bytes generated by math/rand. +// It is fast but not cryptographically secure. func FillWithRandom(path string, filesize, bufcap, bufnum int, progress func(float64)) error { if filesize < 0 { return errors.New("filesize must be non-negative") @@ -258,7 +266,7 @@ func FillWithRandom(path string, filesize, bufcap, bufnum int, progress func(flo bufcap = filesize } - rand.Seed(time.Now().UnixNano()) + r := mrand.New(mrand.NewSource(time.Now().UnixNano())) fp, err := os.Create(path) if err != nil { @@ -267,18 +275,17 @@ func FillWithRandom(path string, filesize, bufcap, bufnum int, progress func(flo defer fp.Close() writer := bufio.NewWriter(fp) - defer writer.Flush() if filesize == 0 { reportProgress(progress, 0, 0) - return nil + return writer.Flush() } pool := make([][]byte, 0, bufnum) for i := 0; i < bufnum; i++ { b := make([]byte, bufcap) for j := 0; j < bufcap; j++ { - b[j] = byte(rand.Intn(256)) + b[j] = byte(r.Intn(256)) } pool = append(pool, b) } @@ -289,14 +296,59 @@ func FillWithRandom(path string, filesize, bufcap, bufnum int, progress func(flo if filesize-written < chunk { chunk = filesize - written } - buf := pool[rand.Intn(len(pool))][:chunk] + buf := pool[r.Intn(len(pool))][:chunk] if _, err := writer.Write(buf); err != nil { return err } written += chunk reportProgress(progress, int64(written), int64(filesize)) } - return nil + return writer.Flush() +} + +// FillWithCryptoRandom fills file with cryptographically secure random bytes from crypto/rand. +// Security is stronger than FillWithRandom, but throughput may be lower. +func FillWithCryptoRandom(path string, filesize, bufcap int, progress func(float64)) error { + if filesize < 0 { + return errors.New("filesize must be non-negative") + } + if bufcap <= 0 { + bufcap = 1 + } + if bufcap > filesize && filesize > 0 { + bufcap = filesize + } + + fp, err := os.Create(path) + if err != nil { + return err + } + defer fp.Close() + + writer := bufio.NewWriter(fp) + + if filesize == 0 { + reportProgress(progress, 0, 0) + return writer.Flush() + } + + buf := make([]byte, bufcap) + written := 0 + for written < filesize { + chunk := bufcap + if filesize-written < chunk { + chunk = filesize - written + } + if _, err := io.ReadFull(crand.Reader, buf[:chunk]); err != nil { + return err + } + if _, err := writer.Write(buf[:chunk]); err != nil { + return err + } + written += chunk + reportProgress(progress, int64(written), int64(filesize)) + } + return writer.Flush() } func reportProgress(progress func(float64), current, total int64) { diff --git a/filex/file_random_test.go b/filex/file_random_test.go new file mode 100644 index 0000000..b74e857 --- /dev/null +++ b/filex/file_random_test.go @@ -0,0 +1,65 @@ +package filex + +import ( + "bytes" + "os" + "path/filepath" + "testing" +) + +func TestFillWithRandomAndCryptoRandom(t *testing.T) { + dir := t.TempDir() + pseudoPath := filepath.Join(dir, "pseudo.bin") + securePath := filepath.Join(dir, "secure.bin") + + if err := FillWithRandom(pseudoPath, 2048, 128, 4, nil); err != nil { + t.Fatalf("FillWithRandom failed: %v", err) + } + if err := FillWithCryptoRandom(securePath, 2048, 128, nil); err != nil { + t.Fatalf("FillWithCryptoRandom failed: %v", err) + } + + pseudoInfo, err := os.Stat(pseudoPath) + if err != nil { + t.Fatalf("stat pseudo file failed: %v", err) + } + if pseudoInfo.Size() != 2048 { + t.Fatalf("unexpected pseudo size: %d", pseudoInfo.Size()) + } + + secureInfo, err := os.Stat(securePath) + if err != nil { + t.Fatalf("stat secure file failed: %v", err) + } + if secureInfo.Size() != 2048 { + t.Fatalf("unexpected secure size: %d", secureInfo.Size()) + } + + pseudo, err := os.ReadFile(pseudoPath) + if err != nil { + t.Fatalf("read pseudo file failed: %v", err) + } + secure, err := os.ReadFile(securePath) + if err != nil { + t.Fatalf("read secure file failed: %v", err) + } + + if bytes.Equal(secure, make([]byte, len(secure))) { + t.Fatalf("secure random output should not be all zero") + } + if bytes.Equal(pseudo, secure) { + t.Fatalf("pseudo and secure random outputs unexpectedly identical") + } +} + +func TestFillWithRandomInvalidArgs(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "bad.bin") + + if err := FillWithRandom(path, -1, 16, 1, nil); err == nil { + t.Fatalf("expected FillWithRandom negative filesize error") + } + if err := FillWithCryptoRandom(path, -1, 16, nil); err == nil { + t.Fatalf("expected FillWithCryptoRandom negative filesize error") + } +} diff --git a/filex/file_split_test.go b/filex/file_split_test.go new file mode 100644 index 0000000..097e90d --- /dev/null +++ b/filex/file_split_test.go @@ -0,0 +1,36 @@ +package filex + +import ( + "bytes" + "errors" + "os" + "path/filepath" + "testing" +) + +func TestSplitFilePatternValidation(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "src.bin") + if err := os.WriteFile(src, bytes.Repeat([]byte{0x7f}, 64), 0o600); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + + if err := SplitFile(src, filepath.Join(dir, "part.bin"), 2, true, nil); !errors.Is(err, ErrInvalidSplitPattern) { + t.Fatalf("expected ErrInvalidSplitPattern for missing '*', got: %v", err) + } + if err := SplitFile(src, filepath.Join(dir, "part_*_*.bin"), 2, true, nil); !errors.Is(err, ErrInvalidSplitPattern) { + t.Fatalf("expected ErrInvalidSplitPattern for multiple '*', got: %v", err) + } + + pattern := filepath.Join(dir, "part_*.bin") + if err := SplitFile(src, pattern, 2, true, nil); err != nil { + t.Fatalf("SplitFile valid pattern failed: %v", err) + } + + if _, err := os.Stat(filepath.Join(dir, "part_0.bin")); err != nil { + t.Fatalf("part_0.bin not found: %v", err) + } + if _, err := os.Stat(filepath.Join(dir, "part_1.bin")); err != nil { + t.Fatalf("part_1.bin not found: %v", err) + } +} diff --git a/go.mod b/go.mod index 693cc34..bdbfd73 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,13 @@ go 1.24.0 require ( github.com/emmansun/gmsm v0.41.1 + github.com/stretchr/testify v1.11.1 golang.org/x/crypto v0.48.0 ) -require golang.org/x/sys v0.41.0 // indirect +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.41.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum index dc1cad3..5f3e76a 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,18 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/emmansun/gmsm v0.41.1 h1:mD1MqmaXTEqt+9UVmDpRYvcEMIa5vuslFEnw7IWp6/w= github.com/emmansun/gmsm v0.41.1/go.mod h1:FD1EQk4XcSMkahZFzNwFoI/uXzAlODB9JVsJ9G5N7Do= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/hashx/hashx_test.go b/hashx/hashx_test.go index a4ccda3..ae9c727 100644 --- a/hashx/hashx_test.go +++ b/hashx/hashx_test.go @@ -1,6 +1,7 @@ package hashx import ( + "bytes" "crypto/md5" "encoding/hex" "os" @@ -87,3 +88,52 @@ func TestFileSumUnsupportedMethod(t *testing.T) { t.Fatalf("expected unsupported method error") } } +func TestPBKDF2SHA256Vector(t *testing.T) { + got, err := DerivePBKDF2SHA256Key("password", []byte("salt"), 1, 32) + if err != nil { + t.Fatalf("DerivePBKDF2SHA256Key failed: %v", err) + } + const want = "120fb6cffcf8b32c43e7225256c4f837a86548c92ccc35480805987cb70be17b" + if hex.EncodeToString(got) != want { + t.Fatalf("pbkdf2-sha256 vector mismatch: got %x want %s", got, want) + } +} + +func TestPBKDF2AndArgon2Deterministic(t *testing.T) { + a, err := DerivePBKDF2SHA512Key("password", []byte("salt"), 1000, 32) + if err != nil { + t.Fatalf("DerivePBKDF2SHA512Key failed: %v", err) + } + b, err := DerivePBKDF2SHA512Key("password", []byte("salt"), 1000, 32) + if err != nil { + t.Fatalf("DerivePBKDF2SHA512Key failed: %v", err) + } + if !bytes.Equal(a, b) { + t.Fatalf("pbkdf2-sha512 must be deterministic") + } + + params := Argon2Params{Time: 1, Memory: 32 * 1024, Threads: 1, KeyLen: 32} + argA, err := DeriveArgon2idKey("password", []byte("salt-salt"), params) + if err != nil { + t.Fatalf("DeriveArgon2idKey failed: %v", err) + } + argB, err := DeriveArgon2idKey("password", []byte("salt-salt"), params) + if err != nil { + t.Fatalf("DeriveArgon2idKey failed: %v", err) + } + if !bytes.Equal(argA, argB) { + t.Fatalf("argon2id must be deterministic") + } +} + +func TestKDFInvalidParams(t *testing.T) { + if _, err := DerivePBKDF2SHA256Key("password", nil, 1, 32); err == nil { + t.Fatalf("expected pbkdf2 salt error") + } + if _, err := DerivePBKDF2SHA256Key("password", []byte("salt"), 0, 32); err == nil { + t.Fatalf("expected pbkdf2 iterations error") + } + if _, err := DeriveArgon2idKey("password", []byte("salt"), Argon2Params{}); err == nil { + t.Fatalf("expected argon2 params error") + } +} diff --git a/hashx/kdf.go b/hashx/kdf.go new file mode 100644 index 0000000..857578e --- /dev/null +++ b/hashx/kdf.go @@ -0,0 +1,90 @@ +package hashx + +import ( + "crypto/pbkdf2" + "crypto/sha256" + "crypto/sha512" + "errors" + + "golang.org/x/crypto/argon2" +) + +var ( + ErrInvalidKDFSalt = errors.New("kdf salt must be non-empty") + ErrInvalidKDFIterations = errors.New("kdf iterations must be > 0") + ErrInvalidKDFKeyLength = errors.New("kdf key length must be > 0") + ErrInvalidArgon2Params = errors.New("argon2 params must have time, memory, threads, and key length > 0") +) + +// Argon2Params configures Argon2 key derivation. +type Argon2Params struct { + Time uint32 + Memory uint32 + Threads uint8 + KeyLen uint32 +} + +// DefaultArgon2idParams returns a conservative default suitable for general online usage. +func DefaultArgon2idParams() Argon2Params { + return Argon2Params{ + Time: 1, + Memory: 64 * 1024, // 64 MiB in KiB + Threads: 4, + KeyLen: 32, + } +} + +func validatePBKDF2Params(salt []byte, iterations, keyLen int) error { + if len(salt) == 0 { + return ErrInvalidKDFSalt + } + if iterations <= 0 { + return ErrInvalidKDFIterations + } + if keyLen <= 0 { + return ErrInvalidKDFKeyLength + } + return nil +} + +func validateArgon2Params(salt []byte, params Argon2Params) error { + if len(salt) == 0 { + return ErrInvalidKDFSalt + } + if params.Time == 0 || params.Memory == 0 || params.Threads == 0 || params.KeyLen == 0 { + return ErrInvalidArgon2Params + } + return nil +} + +// DerivePBKDF2SHA256Key derives a key with PBKDF2-HMAC-SHA256. +func DerivePBKDF2SHA256Key(password string, salt []byte, iterations, keyLen int) ([]byte, error) { + if err := validatePBKDF2Params(salt, iterations, keyLen); err != nil { + return nil, err + } + return pbkdf2.Key(sha256.New, password, salt, iterations, keyLen) +} + +// DerivePBKDF2SHA512Key derives a key with PBKDF2-HMAC-SHA512. +func DerivePBKDF2SHA512Key(password string, salt []byte, iterations, keyLen int) ([]byte, error) { + if err := validatePBKDF2Params(salt, iterations, keyLen); err != nil { + return nil, err + } + return pbkdf2.Key(sha512.New, password, salt, iterations, keyLen) +} + +// DeriveArgon2idKey derives a key with Argon2id. +func DeriveArgon2idKey(password string, salt []byte, params Argon2Params) ([]byte, error) { + if err := validateArgon2Params(salt, params); err != nil { + return nil, err + } + return argon2.IDKey([]byte(password), salt, params.Time, params.Memory, params.Threads, params.KeyLen), nil +} + +// DeriveArgon2iKey derives a key with Argon2i. +func DeriveArgon2iKey(password string, salt []byte, params Argon2Params) ([]byte, error) { + if err := validateArgon2Params(salt, params); err != nil { + return nil, err + } + return argon2.Key([]byte(password), salt, params.Time, params.Memory, params.Threads, params.KeyLen), nil +} diff --git a/hmac.go b/hmac.go index 61eb237..8185f8d 100644 --- a/hmac.go +++ b/hmac.go @@ -10,6 +10,14 @@ func HmacMd4Str(message, key []byte) string { return macx.HmacMd4Str(message, key) } +func VerifyHmacMd4(message, key, sum []byte) bool { + return macx.VerifyHmacMd4(message, key, sum) +} + +func VerifyHmacMd4Str(message, key []byte, hexSum string) bool { + return macx.VerifyHmacMd4Str(message, key, hexSum) +} + func HmacMd5(message, key []byte) []byte { return macx.HmacMd5(message, key) } @@ -18,6 +26,14 @@ func HmacMd5Str(message, key []byte) string { return macx.HmacMd5Str(message, key) } +func VerifyHmacMd5(message, key, sum []byte) bool { + return macx.VerifyHmacMd5(message, key, sum) +} + +func VerifyHmacMd5Str(message, key []byte, hexSum string) bool { + return macx.VerifyHmacMd5Str(message, key, hexSum) +} + func HmacSHA1(message, key []byte) []byte { return macx.HmacSHA1(message, key) } @@ -26,6 +42,14 @@ func HmacSHA1Str(message, key []byte) string { return macx.HmacSHA1Str(message, key) } +func VerifyHmacSHA1(message, key, sum []byte) bool { + return macx.VerifyHmacSHA1(message, key, sum) +} + +func VerifyHmacSHA1Str(message, key []byte, hexSum string) bool { + return macx.VerifyHmacSHA1Str(message, key, hexSum) +} + func HmacSHA256(message, key []byte) []byte { return macx.HmacSHA256(message, key) } @@ -34,6 +58,14 @@ func HmacSHA256Str(message, key []byte) string { return macx.HmacSHA256Str(message, key) } +func VerifyHmacSHA256(message, key, sum []byte) bool { + return macx.VerifyHmacSHA256(message, key, sum) +} + +func VerifyHmacSHA256Str(message, key []byte, hexSum string) bool { + return macx.VerifyHmacSHA256Str(message, key, hexSum) +} + func HmacSHA384(message, key []byte) []byte { return macx.HmacSHA384(message, key) } @@ -42,6 +74,14 @@ func HmacSHA384Str(message, key []byte) string { return macx.HmacSHA384Str(message, key) } +func VerifyHmacSHA384(message, key, sum []byte) bool { + return macx.VerifyHmacSHA384(message, key, sum) +} + +func VerifyHmacSHA384Str(message, key []byte, hexSum string) bool { + return macx.VerifyHmacSHA384Str(message, key, hexSum) +} + func HmacSHA512(message, key []byte) []byte { return macx.HmacSHA512(message, key) } @@ -50,6 +90,14 @@ func HmacSHA512Str(message, key []byte) string { return macx.HmacSHA512Str(message, key) } +func VerifyHmacSHA512(message, key, sum []byte) bool { + return macx.VerifyHmacSHA512(message, key, sum) +} + +func VerifyHmacSHA512Str(message, key []byte, hexSum string) bool { + return macx.VerifyHmacSHA512Str(message, key, hexSum) +} + func HmacSHA224(message, key []byte) []byte { return macx.HmacSHA224(message, key) } @@ -58,6 +106,14 @@ func HmacSHA224Str(message, key []byte) string { return macx.HmacSHA224Str(message, key) } +func VerifyHmacSHA224(message, key, sum []byte) bool { + return macx.VerifyHmacSHA224(message, key, sum) +} + +func VerifyHmacSHA224Str(message, key []byte, hexSum string) bool { + return macx.VerifyHmacSHA224Str(message, key, hexSum) +} + func HmacRipeMd160(message, key []byte) []byte { return macx.HmacRipeMd160(message, key) } @@ -65,3 +121,11 @@ func HmacRipeMd160(message, key []byte) []byte { func HmacRipeMd160Str(message, key []byte) string { return macx.HmacRipeMd160Str(message, key) } + +func VerifyHmacRipeMd160(message, key, sum []byte) bool { + return macx.VerifyHmacRipeMd160(message, key, sum) +} + +func VerifyHmacRipeMd160Str(message, key []byte, hexSum string) bool { + return macx.VerifyHmacRipeMd160Str(message, key, hexSum) +} diff --git a/kdf.go b/kdf.go new file mode 100644 index 0000000..df01930 --- /dev/null +++ b/kdf.go @@ -0,0 +1,25 @@ +package starcrypto + +import "b612.me/starcrypto/hashx" + +type Argon2Params = hashx.Argon2Params + +func DefaultArgon2idParams() Argon2Params { + return hashx.DefaultArgon2idParams() +} + +func DerivePBKDF2SHA256Key(password string, salt []byte, iterations, keyLen int) ([]byte, error) { + return hashx.DerivePBKDF2SHA256Key(password, salt, iterations, keyLen) +} + +func DerivePBKDF2SHA512Key(password string, salt []byte, iterations, keyLen int) ([]byte, error) { + return hashx.DerivePBKDF2SHA512Key(password, salt, iterations, keyLen) +} + +func DeriveArgon2idKey(password string, salt []byte, params Argon2Params) ([]byte, error) { + return hashx.DeriveArgon2idKey(password, salt, params) +} + +func DeriveArgon2iKey(password string, salt []byte, params Argon2Params) ([]byte, error) { + return hashx.DeriveArgon2iKey(password, salt, params) +} diff --git a/macx/hmac.go b/macx/hmac.go index a286ea6..e1aa61b 100644 --- a/macx/hmac.go +++ b/macx/hmac.go @@ -8,6 +8,7 @@ import ( "crypto/sha512" "encoding/hex" "hash" + "strings" "golang.org/x/crypto/md4" "golang.org/x/crypto/ripemd160" @@ -23,6 +24,19 @@ func chmacStr(message, key []byte, f func() hash.Hash) string { return hex.EncodeToString(chmac(message, key, f)) } +func verifyHMAC(message, key, sum []byte, f func() hash.Hash) bool { + expected := chmac(message, key, f) + return hmac.Equal(expected, sum) +} + +func verifyHMACStr(message, key []byte, hexSum string, f func() hash.Hash) bool { + sum, err := hex.DecodeString(strings.TrimSpace(hexSum)) + if err != nil { + return false + } + return verifyHMAC(message, key, sum, f) +} + func HmacMd4(message, key []byte) []byte { return chmac(message, key, md4.New) } @@ -31,6 +45,14 @@ func HmacMd4Str(message, key []byte) string { return chmacStr(message, key, md4.New) } +func VerifyHmacMd4(message, key, sum []byte) bool { + return verifyHMAC(message, key, sum, md4.New) +} + +func VerifyHmacMd4Str(message, key []byte, hexSum string) bool { + return verifyHMACStr(message, key, hexSum, md4.New) +} + func HmacMd5(message, key []byte) []byte { return chmac(message, key, md5.New) } @@ -39,6 +61,14 @@ func HmacMd5Str(message, key []byte) string { return chmacStr(message, key, md5.New) } +func VerifyHmacMd5(message, key, sum []byte) bool { + return verifyHMAC(message, key, sum, md5.New) +} + +func VerifyHmacMd5Str(message, key []byte, hexSum string) bool { + return verifyHMACStr(message, key, hexSum, md5.New) +} + func HmacSHA1(message, key []byte) []byte { return chmac(message, key, sha1.New) } @@ -47,6 +77,14 @@ func HmacSHA1Str(message, key []byte) string { return chmacStr(message, key, sha1.New) } +func VerifyHmacSHA1(message, key, sum []byte) bool { + return verifyHMAC(message, key, sum, sha1.New) +} + +func VerifyHmacSHA1Str(message, key []byte, hexSum string) bool { + return verifyHMACStr(message, key, hexSum, sha1.New) +} + func HmacSHA256(message, key []byte) []byte { return chmac(message, key, sha256.New) } @@ -55,6 +93,14 @@ func HmacSHA256Str(message, key []byte) string { return chmacStr(message, key, sha256.New) } +func VerifyHmacSHA256(message, key, sum []byte) bool { + return verifyHMAC(message, key, sum, sha256.New) +} + +func VerifyHmacSHA256Str(message, key []byte, hexSum string) bool { + return verifyHMACStr(message, key, hexSum, sha256.New) +} + func HmacSHA384(message, key []byte) []byte { return chmac(message, key, sha512.New384) } @@ -63,6 +109,14 @@ func HmacSHA384Str(message, key []byte) string { return chmacStr(message, key, sha512.New384) } +func VerifyHmacSHA384(message, key, sum []byte) bool { + return verifyHMAC(message, key, sum, sha512.New384) +} + +func VerifyHmacSHA384Str(message, key []byte, hexSum string) bool { + return verifyHMACStr(message, key, hexSum, sha512.New384) +} + func HmacSHA512(message, key []byte) []byte { return chmac(message, key, sha512.New) } @@ -71,6 +125,14 @@ func HmacSHA512Str(message, key []byte) string { return chmacStr(message, key, sha512.New) } +func VerifyHmacSHA512(message, key, sum []byte) bool { + return verifyHMAC(message, key, sum, sha512.New) +} + +func VerifyHmacSHA512Str(message, key []byte, hexSum string) bool { + return verifyHMACStr(message, key, hexSum, sha512.New) +} + func HmacSHA224(message, key []byte) []byte { return chmac(message, key, sha256.New224) } @@ -79,6 +141,14 @@ func HmacSHA224Str(message, key []byte) string { return chmacStr(message, key, sha256.New224) } +func VerifyHmacSHA224(message, key, sum []byte) bool { + return verifyHMAC(message, key, sum, sha256.New224) +} + +func VerifyHmacSHA224Str(message, key []byte, hexSum string) bool { + return verifyHMACStr(message, key, hexSum, sha256.New224) +} + func HmacRipeMd160(message, key []byte) []byte { return chmac(message, key, ripemd160.New) } @@ -86,3 +156,11 @@ func HmacRipeMd160(message, key []byte) []byte { func HmacRipeMd160Str(message, key []byte) string { return chmacStr(message, key, ripemd160.New) } + +func VerifyHmacRipeMd160(message, key, sum []byte) bool { + return verifyHMAC(message, key, sum, ripemd160.New) +} + +func VerifyHmacRipeMd160Str(message, key []byte, hexSum string) bool { + return verifyHMACStr(message, key, hexSum, ripemd160.New) +} diff --git a/macx/hmac_test.go b/macx/hmac_test.go new file mode 100644 index 0000000..c210d3e --- /dev/null +++ b/macx/hmac_test.go @@ -0,0 +1,58 @@ +package macx + +import "testing" + +type hmacCase struct { + name string + sum func([]byte, []byte) []byte + sumStr func([]byte, []byte) string + verify func([]byte, []byte, []byte) bool + verifyStr func([]byte, []byte, string) bool +} + +func TestHMACVerifyCases(t *testing.T) { + msg := []byte("macx-verify-message") + key := []byte("macx-verify-key") + + cases := []hmacCase{ + {"md4", HmacMd4, HmacMd4Str, VerifyHmacMd4, VerifyHmacMd4Str}, + {"md5", HmacMd5, HmacMd5Str, VerifyHmacMd5, VerifyHmacMd5Str}, + {"sha1", HmacSHA1, HmacSHA1Str, VerifyHmacSHA1, VerifyHmacSHA1Str}, + {"sha224", HmacSHA224, HmacSHA224Str, VerifyHmacSHA224, VerifyHmacSHA224Str}, + {"sha256", HmacSHA256, HmacSHA256Str, VerifyHmacSHA256, VerifyHmacSHA256Str}, + {"sha384", HmacSHA384, HmacSHA384Str, VerifyHmacSHA384, VerifyHmacSHA384Str}, + {"sha512", HmacSHA512, HmacSHA512Str, VerifyHmacSHA512, VerifyHmacSHA512Str}, + {"ripemd160", HmacRipeMd160, HmacRipeMd160Str, VerifyHmacRipeMd160, VerifyHmacRipeMd160Str}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + sum := tc.sum(msg, key) + if !tc.verify(msg, key, sum) { + t.Fatalf("verify bytes should pass") + } + + hexSum := tc.sumStr(msg, key) + if !tc.verifyStr(msg, key, hexSum) { + t.Fatalf("verify hex should pass") + } + if !tc.verifyStr(msg, key, " \t"+hexSum+"\n") { + t.Fatalf("verify hex with spaces should pass") + } + + bad := make([]byte, len(sum)) + copy(bad, sum) + bad[0] ^= 0xff + if tc.verify(msg, key, bad) { + t.Fatalf("verify bytes should fail for tampered sum") + } + + if tc.verifyStr(msg, key, "not-hex") { + t.Fatalf("verify hex should fail for invalid hex") + } + if tc.verify([]byte("wrong-msg"), key, sum) { + t.Fatalf("verify bytes should fail for wrong message") + } + }) + } +} diff --git a/rsa.go b/rsa.go index 19ab67e..5e1037c 100644 --- a/rsa.go +++ b/rsa.go @@ -14,6 +14,14 @@ func EncodeRsaPrivateKey(private *rsa.PrivateKey, secret string) ([]byte, error) return asymm.EncodeRsaPrivateKey(private, secret) } +func EncodeRsaPrivateKeyWithLegacy(private *rsa.PrivateKey, secret string, legacy bool) ([]byte, error) { + return asymm.EncodeRsaPrivateKeyWithLegacy(private, secret, legacy) +} + +func EncodeRsaPrivateKeyPKCS8(private *rsa.PrivateKey, secret string) ([]byte, error) { + return asymm.EncodeRsaPrivateKeyPKCS8(private, secret) +} + func EncodeRsaPublicKey(public *rsa.PublicKey) ([]byte, error) { return asymm.EncodeRsaPublicKey(public) } @@ -22,6 +30,14 @@ func DecodeRsaPrivateKey(private []byte, password string) (*rsa.PrivateKey, erro return asymm.DecodeRsaPrivateKey(private, password) } +func DecodeRsaPrivateKeyWithLegacy(private []byte, password string, legacy bool) (*rsa.PrivateKey, error) { + return asymm.DecodeRsaPrivateKeyWithLegacy(private, password, legacy) +} + +func DecodeRsaPrivateKeyPKCS8(private []byte, password string) (*rsa.PrivateKey, error) { + return asymm.DecodeRsaPrivateKeyPKCS8(private, password) +} + func DecodeRsaPublicKey(pubStr []byte) (*rsa.PublicKey, error) { return asymm.DecodeRsaPublicKey(pubStr) } @@ -34,6 +50,10 @@ func GenerateRsaSSHKeyPair(bits int, secret string) (string, string, error) { return asymm.GenerateRsaSSHKeyPair(bits, secret) } +func GenerateRsaSSHKeyPairWithLegacy(bits int, secret string, legacy bool) (string, string, error) { + return asymm.GenerateRsaSSHKeyPairWithLegacy(bits, secret, legacy) +} + func RSAEncrypt(pub *rsa.PublicKey, data []byte) ([]byte, error) { return asymm.RSAEncrypt(pub, data) } @@ -42,6 +62,14 @@ func RSADecrypt(prikey *rsa.PrivateKey, data []byte) ([]byte, error) { return asymm.RSADecrypt(prikey, data) } +func RSAEncryptOAEP(pub *rsa.PublicKey, data, label []byte, hashType crypto.Hash) ([]byte, error) { + return asymm.RSAEncryptOAEP(pub, data, label, hashType) +} + +func RSADecryptOAEP(prikey *rsa.PrivateKey, data, label []byte, hashType crypto.Hash) ([]byte, error) { + return asymm.RSADecryptOAEP(prikey, data, label, hashType) +} + func RSASign(msg, priKey []byte, password string, hashType crypto.Hash) ([]byte, error) { return asymm.RSASign(msg, priKey, password, hashType) } @@ -50,6 +78,14 @@ func RSAVerify(data, msg, pubKey []byte, hashType crypto.Hash) error { return asymm.RSAVerify(data, msg, pubKey, hashType) } +func RSASignPSS(msg, priKey []byte, password string, hashType crypto.Hash, opts *rsa.PSSOptions) ([]byte, error) { + return asymm.RSASignPSS(msg, priKey, password, hashType, opts) +} + +func RSAVerifyPSS(sig, msg, pubKey []byte, hashType crypto.Hash, opts *rsa.PSSOptions) error { + return asymm.RSAVerifyPSS(sig, msg, pubKey, hashType, opts) +} + func RSAEncryptByPrivkey(privt *rsa.PrivateKey, data []byte) ([]byte, error) { return asymm.RSAEncryptByPrivkey(privt, data) } diff --git a/sm4.go b/sm4.go index df7fb04..e79dbe4 100644 --- a/sm4.go +++ b/sm4.go @@ -46,6 +46,14 @@ func DecryptSM4GCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { return symm.DecryptSM4GCM(ciphertext, key, nonce, aad) } +func EncryptSM4GCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + return symm.EncryptSM4GCMChunk(plain, key, nonce, aad, chunkIndex) +} + +func DecryptSM4GCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + return symm.DecryptSM4GCMChunk(ciphertext, key, nonce, aad, chunkIndex) +} + func EncryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { return symm.EncryptSM4GCMStream(dst, src, key, nonce, aad) } @@ -54,6 +62,30 @@ func DecryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) e return symm.DecryptSM4GCMStream(dst, src, key, nonce, aad) } +func EncryptSM4CCM(plain, key, nonce, aad []byte) ([]byte, error) { + return symm.EncryptSM4CCM(plain, key, nonce, aad) +} + +func DecryptSM4CCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { + return symm.DecryptSM4CCM(ciphertext, key, nonce, aad) +} + +func EncryptSM4CCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + return symm.EncryptSM4CCMChunk(plain, key, nonce, aad, chunkIndex) +} + +func DecryptSM4CCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + return symm.DecryptSM4CCMChunk(ciphertext, key, nonce, aad, chunkIndex) +} + +func EncryptSM4CCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + return symm.EncryptSM4CCMStream(dst, src, key, nonce, aad) +} + +func DecryptSM4CCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + return symm.DecryptSM4CCMStream(dst, src, key, nonce, aad) +} + func EncryptSM4CFB(origData, key []byte) ([]byte, error) { return symm.EncryptSM4CFB(origData, key) } @@ -102,6 +134,14 @@ func DecryptSM4CTR(src, key, iv []byte) ([]byte, error) { return symm.DecryptSM4CTR(src, key, iv) } +func EncryptSM4CTRAt(data, key, iv []byte, offset int64) ([]byte, error) { + return symm.EncryptSM4CTRAt(data, key, iv, offset) +} + +func DecryptSM4CTRAt(src, key, iv []byte, offset int64) ([]byte, error) { + return symm.DecryptSM4CTRAt(src, key, iv, offset) +} + func EncryptSM4ECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error { return symm.EncryptSM4ECBStream(dst, src, key, paddingType) } @@ -141,3 +181,62 @@ func EncryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error { func DecryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error { return symm.DecryptSM4CTRStream(dst, src, key, iv) } + +func EncryptSM4CFB8(data, key, iv []byte) ([]byte, error) { + return symm.EncryptSM4CFB8(data, key, iv) +} + +func DecryptSM4CFB8(src, key, iv []byte) ([]byte, error) { + return symm.DecryptSM4CFB8(src, key, iv) +} + +func EncryptSM4CFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.EncryptSM4CFB8Stream(dst, src, key, iv) +} + +func DecryptSM4CFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.DecryptSM4CFB8Stream(dst, src, key, iv) +} + +func DecryptSM4ECBBlocks(src, key []byte) ([]byte, error) { + return symm.DecryptSM4ECBBlocks(src, key) +} + +func DecryptSM4CBCFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) { + return symm.DecryptSM4CBCFromSecondBlock(src, key, prevCipherBlock) +} + +func DecryptSM4CFBFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) { + return symm.DecryptSM4CFBFromSecondBlock(src, key, prevCipherBlock) +} +func EncryptSM4XTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) { + return symm.EncryptSM4XTS(plain, key1, key2, dataUnitSize) +} + +func DecryptSM4XTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) { + return symm.DecryptSM4XTS(ciphertext, key1, key2, dataUnitSize) +} + +func EncryptSM4XTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) { + return symm.EncryptSM4XTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex) +} + +func DecryptSM4XTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) { + return symm.DecryptSM4XTSAt(ciphertext, key1, key2, dataUnitSize, dataUnitIndex) +} + +func EncryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error { + return symm.EncryptSM4XTSStream(dst, src, key1, key2, dataUnitSize) +} + +func DecryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error { + return symm.DecryptSM4XTSStream(dst, src, key1, key2, dataUnitSize) +} + +func EncryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error { + return symm.EncryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, dataUnitIndex) +} + +func DecryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error { + return symm.DecryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, dataUnitIndex) +} diff --git a/symm/aes.go b/symm/aes.go index 1c4518b..4bcbfa6 100644 --- a/symm/aes.go +++ b/symm/aes.go @@ -7,6 +7,7 @@ import ( "errors" "io" + "b612.me/starcrypto/ccm" "b612.me/starcrypto/paddingx" ) @@ -17,13 +18,27 @@ const ( ANSIX923PADDING = paddingx.ANSIX923 ) -var ErrInvalidGCMNonceLength = errors.New("gcm nonce length must be 12 bytes") +var ( + ErrInvalidGCMNonceLength = errors.New("gcm nonce length must be 12 bytes") + ErrInvalidCCMNonceLength = errors.New("ccm nonce length must be 12 bytes") +) + +const ( + aeadCCMTagSize = 16 + aeadCCMNonceSize = 12 +) func EncryptAes(data, key, iv []byte, mode, paddingType string) ([]byte, error) { normalizedMode := normalizeCipherMode(mode) + if normalizedMode == "" { + normalizedMode = MODEGCM + } if normalizedMode == MODEGCM { return EncryptAesGCM(data, key, iv, nil) } + if normalizedMode == MODECCM { + return EncryptAesCCM(data, key, iv, nil) + } block, err := aes.NewCipher(key) if err != nil { @@ -34,9 +49,15 @@ func EncryptAes(data, key, iv []byte, mode, paddingType string) ([]byte, error) func DecryptAes(src, key, iv []byte, mode, paddingType string) ([]byte, error) { normalizedMode := normalizeCipherMode(mode) + if normalizedMode == "" { + normalizedMode = MODEGCM + } if normalizedMode == MODEGCM { return DecryptAesGCM(src, key, iv, nil) } + if normalizedMode == MODECCM { + return DecryptAesCCM(src, key, iv, nil) + } block, err := aes.NewCipher(key) if err != nil { @@ -47,9 +68,15 @@ func DecryptAes(src, key, iv []byte, mode, paddingType string) ([]byte, error) { func EncryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { normalizedMode := normalizeCipherMode(mode) + if normalizedMode == "" { + normalizedMode = MODEGCM + } if normalizedMode == MODEGCM { return EncryptAesGCMStream(dst, src, key, iv, nil) } + if normalizedMode == MODECCM { + return EncryptAesCCMStream(dst, src, key, iv, nil) + } block, err := aes.NewCipher(key) if err != nil { @@ -60,9 +87,15 @@ func EncryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddin func DecryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { normalizedMode := normalizeCipherMode(mode) + if normalizedMode == "" { + normalizedMode = MODEGCM + } if normalizedMode == MODEGCM { return DecryptAesGCMStream(dst, src, key, iv, nil) } + if normalizedMode == MODECCM { + return DecryptAesCCMStream(dst, src, key, iv, nil) + } block, err := aes.NewCipher(key) if err != nil { @@ -80,6 +113,9 @@ func EncryptAesWithOptions(data, key []byte, opts *CipherOptions) ([]byte, error if mode == MODEGCM { return EncryptAesGCM(data, key, nonceFromOptions(cfg), cfg.AAD) } + if mode == MODECCM { + return EncryptAesCCM(data, key, nonceFromOptions(cfg), cfg.AAD) + } return EncryptAes(data, key, cfg.IV, mode, cfg.Padding) } @@ -92,6 +128,9 @@ func DecryptAesWithOptions(src, key []byte, opts *CipherOptions) ([]byte, error) if mode == MODEGCM { return DecryptAesGCM(src, key, nonceFromOptions(cfg), cfg.AAD) } + if mode == MODECCM { + return DecryptAesCCM(src, key, nonceFromOptions(cfg), cfg.AAD) + } return DecryptAes(src, key, cfg.IV, mode, cfg.Padding) } @@ -104,6 +143,9 @@ func EncryptAesStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts if mode == MODEGCM { return EncryptAesGCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) } + if mode == MODECCM { + return EncryptAesCCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) + } return EncryptAesStream(dst, src, key, cfg.IV, mode, cfg.Padding) } @@ -116,6 +158,9 @@ func DecryptAesStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts if mode == MODEGCM { return DecryptAesGCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) } + if mode == MODECCM { + return DecryptAesCCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) + } return DecryptAesStream(dst, src, key, cfg.IV, mode, cfg.Padding) } @@ -149,30 +194,138 @@ func DecryptAesGCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { return gcm.Open(nil, nonce, ciphertext, aad) } +func newAesCCM(key []byte) (cipher.AEAD, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return ccm.NewCCM(block, aeadCCMTagSize, aeadCCMNonceSize) +} + +func EncryptAesCCM(plain, key, nonce, aad []byte) ([]byte, error) { + aead, err := newAesCCM(key) + if err != nil { + return nil, err + } + if len(nonce) != aead.NonceSize() { + return nil, ErrInvalidCCMNonceLength + } + return aead.Seal(nil, nonce, plain, aad), nil +} + +func DecryptAesCCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { + aead, err := newAesCCM(key) + if err != nil { + return nil, err + } + if len(nonce) != aead.NonceSize() { + return nil, ErrInvalidCCMNonceLength + } + return aead.Open(nil, nonce, ciphertext, aad) +} + +func EncryptAesCCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + aead, err := newAesCCM(key) + if err != nil { + return nil, err + } + if len(nonce) != aead.NonceSize() { + return nil, ErrInvalidCCMNonceLength + } + return encryptCCMChunk(aead, plain, nonce, aad, chunkIndex), nil +} + +func DecryptAesCCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + aead, err := newAesCCM(key) + if err != nil { + return nil, err + } + if len(nonce) != aead.NonceSize() { + return nil, ErrInvalidCCMNonceLength + } + return decryptCCMChunk(aead, ciphertext, nonce, aad, chunkIndex) +} + +func EncryptAesCCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + aead, err := newAesCCM(key) + if err != nil { + return err + } + if len(nonce) != aead.NonceSize() { + return ErrInvalidCCMNonceLength + } + return encryptCCMChunkedStream(dst, src, aead, nonce, aad) +} + +func DecryptAesCCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + aead, err := newAesCCM(key) + if err != nil { + return err + } + if len(nonce) != aead.NonceSize() { + return ErrInvalidCCMNonceLength + } + return decryptCCMChunkedOrLegacyStream(dst, src, aead, nonce, aad) +} + +func EncryptAesGCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + if len(nonce) != gcm.NonceSize() { + return nil, ErrInvalidGCMNonceLength + } + return encryptGCMChunk(gcm, plain, nonce, aad, chunkIndex), nil +} + +func DecryptAesGCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + if len(nonce) != gcm.NonceSize() { + return nil, ErrInvalidGCMNonceLength + } + return decryptGCMChunk(gcm, ciphertext, nonce, aad, chunkIndex) +} + func EncryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { - plain, err := io.ReadAll(src) + block, err := aes.NewCipher(key) if err != nil { return err } - out, err := EncryptAesGCM(plain, key, nonce, aad) + gcm, err := cipher.NewGCM(block) if err != nil { return err } - _, err = dst.Write(out) - return err + if len(nonce) != gcm.NonceSize() { + return ErrInvalidGCMNonceLength + } + return encryptGCMChunkedStream(dst, src, gcm, nonce, aad) } func DecryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { - enc, err := io.ReadAll(src) + block, err := aes.NewCipher(key) if err != nil { return err } - out, err := DecryptAesGCM(enc, key, nonce, aad) + gcm, err := cipher.NewGCM(block) if err != nil { return err } - _, err = dst.Write(out) - return err + if len(nonce) != gcm.NonceSize() { + return ErrInvalidGCMNonceLength + } + return decryptGCMChunkedOrLegacyStream(dst, src, gcm, nonce, aad) } func EncryptAesECB(data, key []byte, paddingType string) ([]byte, error) { @@ -215,6 +368,18 @@ func DecryptAesCTR(src, key, iv []byte) ([]byte, error) { return DecryptAes(src, key, iv, MODECTR, "") } +func EncryptAesCTRAt(data, key, iv []byte, offset int64) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return xorCTRAtOffset(block, data, iv, offset) +} + +func DecryptAesCTRAt(src, key, iv []byte, offset int64) ([]byte, error) { + return EncryptAesCTRAt(src, key, iv, offset) +} + func EncryptAesECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error { return EncryptAesStream(dst, src, key, nil, MODEECB, paddingType) } @@ -329,3 +494,63 @@ func PKCS7Trimming(encrypted []byte, blockSize int) []byte { } return out } + +func EncryptAesCFB8(data, key, iv []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return encryptCFB8(block, data, iv) +} + +func DecryptAesCFB8(src, key, iv []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return decryptCFB8(block, src, iv) +} + +func EncryptAesCFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error { + block, err := aes.NewCipher(key) + if err != nil { + return err + } + return encryptCFB8Stream(block, dst, src, iv, false) +} + +func DecryptAesCFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error { + block, err := aes.NewCipher(key) + if err != nil { + return err + } + return encryptCFB8Stream(block, dst, src, iv, true) +} + +func DecryptAesECBBlocks(src, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return decryptECBBlocks(block, src) +} + +// DecryptAesCBCFromSecondBlock decrypts a CBC ciphertext segment that starts from block 2 or later. +// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock. +func DecryptAesCBCFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return decryptCBCFromSecondBlock(block, src, prevCipherBlock) +} + +// DecryptAesCFBFromSecondBlock decrypts a CFB ciphertext segment that starts from block 2 or later. +// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock. +func DecryptAesCFBFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return decryptCFBFromSecondBlock(block, src, prevCipherBlock) +} diff --git a/symm/bench_test.go b/symm/bench_test.go new file mode 100644 index 0000000..f974d3d --- /dev/null +++ b/symm/bench_test.go @@ -0,0 +1,98 @@ +package symm + +import ( + "bytes" + "testing" +) + +var ( + benchPlain4K = bytes.Repeat([]byte("0123456789abcdef"), 256) // 4 KiB + benchPlain256K = bytes.Repeat([]byte("0123456789abcdef"), 16384) // 256 KiB + benchAAD = []byte("benchmark-aad") + benchAESKey = []byte("0123456789abcdef") + benchSM4Key = []byte("0123456789abcdef") + benchXTSKey2 = []byte("fedcba9876543210") + benchNonce12Byte = []byte("123456789012") +) + +func BenchmarkAesGCMEncrypt4K(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(benchPlain4K))) + for i := 0; i < b.N; i++ { + if _, err := EncryptAesGCM(benchPlain4K, benchAESKey, benchNonce12Byte, benchAAD); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkAesCCMEncrypt4K(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(benchPlain4K))) + for i := 0; i < b.N; i++ { + if _, err := EncryptAesCCM(benchPlain4K, benchAESKey, benchNonce12Byte, benchAAD); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkAesXTSEncrypt4K(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(benchPlain4K))) + for i := 0; i < b.N; i++ { + if _, err := EncryptAesXTS(benchPlain4K, benchAESKey, benchXTSKey2, 512); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSM4GCMEncrypt4K(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(benchPlain4K))) + for i := 0; i < b.N; i++ { + if _, err := EncryptSM4GCM(benchPlain4K, benchSM4Key, benchNonce12Byte, benchAAD); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSM4CCMEncrypt4K(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(benchPlain4K))) + for i := 0; i < b.N; i++ { + if _, err := EncryptSM4CCM(benchPlain4K, benchSM4Key, benchNonce12Byte, benchAAD); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSM4XTSEncrypt4K(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(benchPlain4K))) + for i := 0; i < b.N; i++ { + if _, err := EncryptSM4XTS(benchPlain4K, benchSM4Key, benchXTSKey2, 512); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkAesXTSStreamEncrypt256K(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(benchPlain256K))) + for i := 0; i < b.N; i++ { + var dst bytes.Buffer + if err := EncryptAesXTSStream(&dst, bytes.NewReader(benchPlain256K), benchAESKey, benchXTSKey2, 512); err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSM4XTSStreamEncrypt256K(b *testing.B) { + b.ReportAllocs() + b.SetBytes(int64(len(benchPlain256K))) + for i := 0; i < b.N; i++ { + var dst bytes.Buffer + if err := EncryptSM4XTSStream(&dst, bytes.NewReader(benchPlain256K), benchSM4Key, benchXTSKey2, 512); err != nil { + b.Fatal(err) + } + } +} diff --git a/symm/ccm_stream.go b/symm/ccm_stream.go new file mode 100644 index 0000000..a99a4c5 --- /dev/null +++ b/symm/ccm_stream.go @@ -0,0 +1,127 @@ +package symm + +import ( + "bytes" + "crypto/cipher" + "encoding/binary" + "errors" + "io" +) + +const ccmStreamMagic = "SCC1" + +var ErrInvalidCCMStreamChunk = errors.New("invalid ccm stream chunk") + +func encryptCCMChunk(aead cipher.AEAD, plain, nonce, aad []byte, chunkIndex uint64) []byte { + chunkNonce := deriveChunkNonce(nonce, chunkIndex) + return aead.Seal(nil, chunkNonce, plain, aad) +} + +func decryptCCMChunk(aead cipher.AEAD, ciphertext, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + chunkNonce := deriveChunkNonce(nonce, chunkIndex) + return aead.Open(nil, chunkNonce, ciphertext, aad) +} + +func encryptCCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error { + if _, err := dst.Write([]byte(ccmStreamMagic)); err != nil { + return err + } + + buf := make([]byte, gcmStreamChunkSize) + lenBuf := make([]byte, 4) + var chunkIndex uint64 + + for { + n, err := src.Read(buf) + if n > 0 { + sealed := encryptCCMChunk(aead, buf[:n], nonce, aad, chunkIndex) + binary.BigEndian.PutUint32(lenBuf, uint32(len(sealed))) + if _, werr := dst.Write(lenBuf); werr != nil { + return werr + } + if _, werr := dst.Write(sealed); werr != nil { + return werr + } + chunkIndex++ + } + if err != nil { + if err == io.EOF { + return nil + } + return err + } + } +} + +func decryptCCMChunkedOrLegacyStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error { + header := make([]byte, len(ccmStreamMagic)) + n, err := io.ReadFull(src, header) + if err != nil { + if err == io.EOF { + return nil + } + if err != io.ErrUnexpectedEOF { + return err + } + return decryptCCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header[:n]), src), aead, nonce, aad) + } + + if string(header) != ccmStreamMagic { + return decryptCCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header), src), aead, nonce, aad) + } + return decryptCCMChunkedStream(dst, src, aead, nonce, aad) +} + +func decryptCCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error { + lenBuf := make([]byte, 4) + maxChunkLen := uint32(gcmStreamChunkSize + aead.Overhead()) + var chunkIndex uint64 + + for { + _, err := io.ReadFull(src, lenBuf) + if err != nil { + if err == io.EOF { + return nil + } + if err == io.ErrUnexpectedEOF { + return io.ErrUnexpectedEOF + } + return err + } + + chunkLen := binary.BigEndian.Uint32(lenBuf) + if chunkLen < uint32(aead.Overhead()) || chunkLen > maxChunkLen { + return ErrInvalidCCMStreamChunk + } + + chunk := make([]byte, chunkLen) + if _, err := io.ReadFull(src, chunk); err != nil { + if err == io.ErrUnexpectedEOF { + return io.ErrUnexpectedEOF + } + return err + } + + plain, err := decryptCCMChunk(aead, chunk, nonce, aad, chunkIndex) + if err != nil { + return err + } + if _, err := dst.Write(plain); err != nil { + return err + } + chunkIndex++ + } +} + +func decryptCCMLegacyBuffered(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error { + enc, err := io.ReadAll(src) + if err != nil { + return err + } + plain, err := aead.Open(nil, nonce, enc, aad) + if err != nil { + return err + } + _, err = dst.Write(plain) + return err +} diff --git a/symm/cfb8.go b/symm/cfb8.go new file mode 100644 index 0000000..c1c8d94 --- /dev/null +++ b/symm/cfb8.go @@ -0,0 +1,103 @@ +package symm + +import ( + "crypto/cipher" + "errors" + "io" +) + +func encryptCFB8(block cipher.Block, data, iv []byte) ([]byte, error) { + if len(iv) != block.BlockSize() { + return nil, errors.New("iv length must match block size") + } + reg := make([]byte, len(iv)) + copy(reg, iv) + regView := make([]byte, block.BlockSize()) + streamBlock := make([]byte, block.BlockSize()) + out := make([]byte, len(data)) + head := 0 + + for i := 0; i < len(data); i++ { + buildCFB8Register(regView, reg, head) + block.Encrypt(streamBlock, regView) + c := data[i] ^ streamBlock[0] + out[i] = c + advanceCFB8Register(reg, &head, c) + } + return out, nil +} + +func decryptCFB8(block cipher.Block, src, iv []byte) ([]byte, error) { + if len(iv) != block.BlockSize() { + return nil, errors.New("iv length must match block size") + } + reg := make([]byte, len(iv)) + copy(reg, iv) + regView := make([]byte, block.BlockSize()) + streamBlock := make([]byte, block.BlockSize()) + out := make([]byte, len(src)) + head := 0 + + for i := 0; i < len(src); i++ { + buildCFB8Register(regView, reg, head) + block.Encrypt(streamBlock, regView) + p := src[i] ^ streamBlock[0] + out[i] = p + advanceCFB8Register(reg, &head, src[i]) + } + return out, nil +} + +func encryptCFB8Stream(block cipher.Block, dst io.Writer, src io.Reader, iv []byte, decrypt bool) error { + if len(iv) != block.BlockSize() { + return errors.New("iv length must match block size") + } + reg := make([]byte, len(iv)) + copy(reg, iv) + regView := make([]byte, block.BlockSize()) + streamBlock := make([]byte, block.BlockSize()) + buf := make([]byte, 32*1024) + out := make([]byte, 32*1024) + head := 0 + + for { + n, err := src.Read(buf) + if n > 0 { + for i := 0; i < n; i++ { + buildCFB8Register(regView, reg, head) + block.Encrypt(streamBlock, regView) + if decrypt { + out[i] = buf[i] ^ streamBlock[0] + advanceCFB8Register(reg, &head, buf[i]) + } else { + c := buf[i] ^ streamBlock[0] + out[i] = c + advanceCFB8Register(reg, &head, c) + } + } + if _, werr := dst.Write(out[:n]); werr != nil { + return werr + } + } + if err != nil { + if err == io.EOF { + return nil + } + return err + } + } +} + +func buildCFB8Register(dst, reg []byte, head int) { + first := len(reg) - head + copy(dst, reg[head:]) + copy(dst[first:], reg[:head]) +} + +func advanceCFB8Register(reg []byte, head *int, feedback byte) { + reg[*head] = feedback + *head = *head + 1 + if *head == len(reg) { + *head = 0 + } +} diff --git a/symm/ctr_seek.go b/symm/ctr_seek.go new file mode 100644 index 0000000..3a241ba --- /dev/null +++ b/symm/ctr_seek.go @@ -0,0 +1,56 @@ +package symm + +import ( + "crypto/cipher" + "errors" +) + +var ( + ErrInvalidCTROffset = errors.New("ctr offset must be non-negative") + ErrCTRCounterOverflow = errors.New("ctr counter overflow") +) + +func xorCTRAtOffset(block cipher.Block, data, iv []byte, offset int64) ([]byte, error) { + if offset < 0 { + return nil, ErrInvalidCTROffset + } + if len(iv) != block.BlockSize() { + return nil, errors.New("iv length must match block size") + } + + counter := make([]byte, len(iv)) + copy(counter, iv) + + blockSize := int64(block.BlockSize()) + blockOffset := uint64(offset / blockSize) + byteOffset := int(offset % blockSize) + if err := addUint64ToCounter(counter, blockOffset); err != nil { + return nil, err + } + + stream := cipher.NewCTR(block, counter) + if byteOffset > 0 { + skip := make([]byte, byteOffset) + stream.XORKeyStream(skip, skip) + } + + out := make([]byte, len(data)) + stream.XORKeyStream(out, data) + return out, nil +} + +func addUint64ToCounter(counter []byte, inc uint64) error { + if inc == 0 { + return nil + } + + for i := len(counter) - 1; i >= 0 && inc > 0; i-- { + sum := uint64(counter[i]) + (inc & 0xff) + counter[i] = byte(sum) + inc = (inc >> 8) + (sum >> 8) + } + if inc > 0 { + return ErrCTRCounterOverflow + } + return nil +} diff --git a/symm/fuzz_test.go b/symm/fuzz_test.go index 80bdf46..6e6f44f 100644 --- a/symm/fuzz_test.go +++ b/symm/fuzz_test.go @@ -70,3 +70,122 @@ func FuzzAesCBCStreamRoundTrip(f *testing.F) { } }) } +func xtsFuzzDataUnitSize(selector uint8) int { + sizes := [...]int{16, 32, 64, 128, 256, 512} + return sizes[int(selector)%len(sizes)] +} + +func xtsFuzzNormalizeData(data []byte, maxLen int) []byte { + if len(data) > maxLen { + data = data[:maxLen] + } + return data[:len(data)/16*16] +} + +func FuzzAesXTSRoundTrip(f *testing.F) { + f.Add([]byte("fuzz-aes-xts-seed-0000"), uint64(0), uint8(0)) + f.Add(bytes.Repeat([]byte{0x42}, 65), uint64(7), uint8(3)) + + key1 := []byte("0123456789abcdef") + key2 := []byte("fedcba9876543210") + + f.Fuzz(func(t *testing.T, data []byte, dataUnitIndex uint64, selector uint8) { + bounded := data + if len(bounded) > 4097 { + bounded = bounded[:4097] + } + plain := xtsFuzzNormalizeData(bounded, 4096) + dataUnitSize := xtsFuzzDataUnitSize(selector) + + enc, err := EncryptAesXTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex) + if err != nil { + t.Fatalf("EncryptAesXTSAt failed: %v", err) + } + dec, err := DecryptAesXTSAt(enc, key1, key2, dataUnitSize, dataUnitIndex) + if err != nil { + t.Fatalf("DecryptAesXTSAt failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("aes xts roundtrip mismatch") + } + + encStream := &bytes.Buffer{} + if err := EncryptAesXTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, dataUnitIndex); err != nil { + t.Fatalf("EncryptAesXTSStreamAt failed: %v", err) + } + if !bytes.Equal(encStream.Bytes(), enc) { + t.Fatalf("aes xts bytes/stream encrypt mismatch") + } + + decStream := &bytes.Buffer{} + if err := DecryptAesXTSStreamAt(decStream, bytes.NewReader(enc), key1, key2, dataUnitSize, dataUnitIndex); err != nil { + t.Fatalf("DecryptAesXTSStreamAt failed: %v", err) + } + if !bytes.Equal(decStream.Bytes(), plain) { + t.Fatalf("aes xts stream decrypt mismatch") + } + + if len(bounded)%16 != 0 { + if _, err := EncryptAesXTSAt(bounded, key1, key2, dataUnitSize, dataUnitIndex); err == nil { + t.Fatalf("expected aes xts bytes error for non-block-aligned input") + } + if err := EncryptAesXTSStreamAt(&bytes.Buffer{}, bytes.NewReader(bounded), key1, key2, dataUnitSize, dataUnitIndex); err == nil { + t.Fatalf("expected aes xts stream error for non-block-aligned input") + } + } + }) +} + +func FuzzSM4XTSRoundTrip(f *testing.F) { + f.Add([]byte("fuzz-sm4-xts-seed-0000"), uint64(0), uint8(0)) + f.Add(bytes.Repeat([]byte{0x5a}, 79), uint64(11), uint8(4)) + + key1 := []byte("0123456789abcdef") + key2 := []byte("fedcba9876543210") + + f.Fuzz(func(t *testing.T, data []byte, dataUnitIndex uint64, selector uint8) { + bounded := data + if len(bounded) > 4097 { + bounded = bounded[:4097] + } + plain := xtsFuzzNormalizeData(bounded, 4096) + dataUnitSize := xtsFuzzDataUnitSize(selector) + + enc, err := EncryptSM4XTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex) + if err != nil { + t.Fatalf("EncryptSM4XTSAt failed: %v", err) + } + dec, err := DecryptSM4XTSAt(enc, key1, key2, dataUnitSize, dataUnitIndex) + if err != nil { + t.Fatalf("DecryptSM4XTSAt failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("sm4 xts roundtrip mismatch") + } + + encStream := &bytes.Buffer{} + if err := EncryptSM4XTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, dataUnitIndex); err != nil { + t.Fatalf("EncryptSM4XTSStreamAt failed: %v", err) + } + if !bytes.Equal(encStream.Bytes(), enc) { + t.Fatalf("sm4 xts bytes/stream encrypt mismatch") + } + + decStream := &bytes.Buffer{} + if err := DecryptSM4XTSStreamAt(decStream, bytes.NewReader(enc), key1, key2, dataUnitSize, dataUnitIndex); err != nil { + t.Fatalf("DecryptSM4XTSStreamAt failed: %v", err) + } + if !bytes.Equal(decStream.Bytes(), plain) { + t.Fatalf("sm4 xts stream decrypt mismatch") + } + + if len(bounded)%16 != 0 { + if _, err := EncryptSM4XTSAt(bounded, key1, key2, dataUnitSize, dataUnitIndex); err == nil { + t.Fatalf("expected sm4 xts bytes error for non-block-aligned input") + } + if err := EncryptSM4XTSStreamAt(&bytes.Buffer{}, bytes.NewReader(bounded), key1, key2, dataUnitSize, dataUnitIndex); err == nil { + t.Fatalf("expected sm4 xts stream error for non-block-aligned input") + } + } + }) +} diff --git a/symm/gcm_stream.go b/symm/gcm_stream.go new file mode 100644 index 0000000..3e12fda --- /dev/null +++ b/symm/gcm_stream.go @@ -0,0 +1,146 @@ +package symm + +import ( + "bytes" + "crypto/cipher" + "encoding/binary" + "errors" + "io" +) + +const ( + gcmStreamMagic = "SCG1" + gcmStreamChunkSize = 32 * 1024 +) + +var ErrInvalidGCMStreamChunk = errors.New("invalid gcm stream chunk") + +func encryptGCMChunk(aead cipher.AEAD, plain, nonce, aad []byte, chunkIndex uint64) []byte { + chunkNonce := deriveChunkNonce(nonce, chunkIndex) + return aead.Seal(nil, chunkNonce, plain, aad) +} + +func decryptGCMChunk(aead cipher.AEAD, ciphertext, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + chunkNonce := deriveChunkNonce(nonce, chunkIndex) + return aead.Open(nil, chunkNonce, ciphertext, aad) +} + +func encryptGCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error { + if _, err := dst.Write([]byte(gcmStreamMagic)); err != nil { + return err + } + + buf := make([]byte, gcmStreamChunkSize) + lenBuf := make([]byte, 4) + var chunkIndex uint64 + + for { + n, err := src.Read(buf) + if n > 0 { + sealed := encryptGCMChunk(aead, buf[:n], nonce, aad, chunkIndex) + binary.BigEndian.PutUint32(lenBuf, uint32(len(sealed))) + if _, werr := dst.Write(lenBuf); werr != nil { + return werr + } + if _, werr := dst.Write(sealed); werr != nil { + return werr + } + chunkIndex++ + } + if err != nil { + if err == io.EOF { + return nil + } + return err + } + } +} + +func decryptGCMChunkedOrLegacyStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error { + header := make([]byte, len(gcmStreamMagic)) + n, err := io.ReadFull(src, header) + if err != nil { + if err == io.EOF { + return nil + } + if err != io.ErrUnexpectedEOF { + return err + } + return decryptGCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header[:n]), src), aead, nonce, aad) + } + + if string(header) != gcmStreamMagic { + return decryptGCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header), src), aead, nonce, aad) + } + return decryptGCMChunkedStream(dst, src, aead, nonce, aad) +} + +func decryptGCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error { + lenBuf := make([]byte, 4) + maxChunkLen := uint32(gcmStreamChunkSize + aead.Overhead()) + var chunkIndex uint64 + + for { + _, err := io.ReadFull(src, lenBuf) + if err != nil { + if err == io.EOF { + return nil + } + if err == io.ErrUnexpectedEOF { + return io.ErrUnexpectedEOF + } + return err + } + + chunkLen := binary.BigEndian.Uint32(lenBuf) + if chunkLen < uint32(aead.Overhead()) || chunkLen > maxChunkLen { + return ErrInvalidGCMStreamChunk + } + + chunk := make([]byte, chunkLen) + if _, err := io.ReadFull(src, chunk); err != nil { + if err == io.ErrUnexpectedEOF { + return io.ErrUnexpectedEOF + } + return err + } + + plain, err := decryptGCMChunk(aead, chunk, nonce, aad, chunkIndex) + if err != nil { + return err + } + if _, err := dst.Write(plain); err != nil { + return err + } + chunkIndex++ + } +} + +func decryptGCMLegacyBuffered(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error { + enc, err := io.ReadAll(src) + if err != nil { + return err + } + plain, err := aead.Open(nil, nonce, enc, aad) + if err != nil { + return err + } + _, err = dst.Write(plain) + return err +} + +func deriveChunkNonce(baseNonce []byte, chunkIndex uint64) []byte { + nonce := make([]byte, len(baseNonce)) + copy(nonce, baseNonce) + if len(nonce) < 8 { + return nonce + } + + var indexBytes [8]byte + binary.BigEndian.PutUint64(indexBytes[:], chunkIndex) + off := len(nonce) - 8 + for i := 0; i < 8; i++ { + nonce[off+i] ^= indexBytes[i] + } + return nonce +} diff --git a/symm/mode.go b/symm/mode.go index 561e6cd..843e66c 100644 --- a/symm/mode.go +++ b/symm/mode.go @@ -16,6 +16,7 @@ const ( MODEOFB = "OFB" MODECTR = "CTR" MODEGCM = "GCM" + MODECCM = "CCM" ) var ErrUnsupportedCipherMode = errors.New("cipher mode not supported") diff --git a/symm/options.go b/symm/options.go index 6b03b1f..4041e54 100644 --- a/symm/options.go +++ b/symm/options.go @@ -1,7 +1,7 @@ package symm // CipherOptions provides a unified configuration for symmetric APIs. -// For GCM mode, Nonce is used; if Nonce is empty, IV is used as fallback. +// For AEAD modes (GCM/CCM), Nonce must be set explicitly. type CipherOptions struct { Mode string Padding string @@ -18,8 +18,5 @@ func normalizeCipherOptions(opts *CipherOptions) CipherOptions { } func nonceFromOptions(opts CipherOptions) []byte { - if len(opts.Nonce) > 0 { - return opts.Nonce - } - return opts.IV + return opts.Nonce } diff --git a/symm/options_nonce_test.go b/symm/options_nonce_test.go new file mode 100644 index 0000000..3662ef2 --- /dev/null +++ b/symm/options_nonce_test.go @@ -0,0 +1,28 @@ +package symm + +import ( + "errors" + "testing" +) + +func TestAEADOptionsRequireNonce(t *testing.T) { + aesKey := []byte("0123456789abcdef") + sm4Key := []byte("0123456789abcdef") + plain := []byte("nonce-required") + + gcmIVOnly := &CipherOptions{Mode: MODEGCM, IV: []byte("123456789012")} + if _, err := EncryptAesWithOptions(plain, aesKey, gcmIVOnly); !errors.Is(err, ErrInvalidGCMNonceLength) { + t.Fatalf("expected ErrInvalidGCMNonceLength for AES GCM with IV-only opts, got: %v", err) + } + if _, err := EncryptSM4WithOptions(plain, sm4Key, gcmIVOnly); !errors.Is(err, ErrInvalidGCMNonceLength) { + t.Fatalf("expected ErrInvalidGCMNonceLength for SM4 GCM with IV-only opts, got: %v", err) + } + + ccmIVOnly := &CipherOptions{Mode: MODECCM, IV: []byte("123456789012")} + if _, err := EncryptAesWithOptions(plain, aesKey, ccmIVOnly); !errors.Is(err, ErrInvalidCCMNonceLength) { + t.Fatalf("expected ErrInvalidCCMNonceLength for AES CCM with IV-only opts, got: %v", err) + } + if _, err := EncryptSM4WithOptions(plain, sm4Key, ccmIVOnly); !errors.Is(err, ErrInvalidCCMNonceLength) { + t.Fatalf("expected ErrInvalidCCMNonceLength for SM4 CCM with IV-only opts, got: %v", err) + } +} diff --git a/symm/segment_decrypt.go b/symm/segment_decrypt.go new file mode 100644 index 0000000..e27d1b2 --- /dev/null +++ b/symm/segment_decrypt.go @@ -0,0 +1,54 @@ +package symm + +import ( + "crypto/cipher" + "errors" +) + +var ( + ErrSegmentNotFullBlocks = errors.New("ciphertext segment is not a full block size") +) + +func decryptECBBlocks(block cipher.Block, src []byte) ([]byte, error) { + if len(src) == 0 { + return []byte{}, nil + } + if len(src)%block.BlockSize() != 0 { + return nil, ErrSegmentNotFullBlocks + } + out := make([]byte, len(src)) + ecbDecryptBlocks(block, out, src) + return out, nil +} + +// decryptCBCFromSecondBlock decrypts a CBC ciphertext segment that starts from the second block (or later). +// prevCipherBlock must be the previous ciphertext block (C[i-1]); for i=1 this is the original IV. +func decryptCBCFromSecondBlock(block cipher.Block, src, prevCipherBlock []byte) ([]byte, error) { + if len(src) == 0 { + return []byte{}, nil + } + if len(prevCipherBlock) != block.BlockSize() { + return nil, errors.New("prev cipher block length must match block size") + } + if len(src)%block.BlockSize() != 0 { + return nil, ErrSegmentNotFullBlocks + } + out := make([]byte, len(src)) + cipher.NewCBCDecrypter(block, prevCipherBlock).CryptBlocks(out, src) + return out, nil +} + +// decryptCFBFromSecondBlock decrypts a CFB ciphertext segment that starts from the second block (or later). +// prevCipherBlock must be the previous ciphertext block (C[i-1]); for i=1 this is the original IV. +func decryptCFBFromSecondBlock(block cipher.Block, src, prevCipherBlock []byte) ([]byte, error) { + if len(src) == 0 { + return []byte{}, nil + } + if len(prevCipherBlock) != block.BlockSize() { + return nil, errors.New("prev cipher block length must match block size") + } + stream := cipher.NewCFBDecrypter(block, prevCipherBlock) + out := make([]byte, len(src)) + stream.XORKeyStream(out, src) + return out, nil +} diff --git a/symm/sm4.go b/symm/sm4.go index cdfd877..4a171f0 100644 --- a/symm/sm4.go +++ b/symm/sm4.go @@ -6,14 +6,21 @@ import ( "errors" "io" + "b612.me/starcrypto/ccm" "github.com/emmansun/gmsm/sm4" ) func EncryptSM4(data, key, iv []byte, mode, paddingType string) ([]byte, error) { normalizedMode := normalizeCipherMode(mode) + if normalizedMode == "" { + normalizedMode = MODEGCM + } if normalizedMode == MODEGCM { return EncryptSM4GCM(data, key, iv, nil) } + if normalizedMode == MODECCM { + return EncryptSM4CCM(data, key, iv, nil) + } block, err := sm4.NewCipher(key) if err != nil { @@ -24,9 +31,15 @@ func EncryptSM4(data, key, iv []byte, mode, paddingType string) ([]byte, error) func DecryptSM4(src, key, iv []byte, mode, paddingType string) ([]byte, error) { normalizedMode := normalizeCipherMode(mode) + if normalizedMode == "" { + normalizedMode = MODEGCM + } if normalizedMode == MODEGCM { return DecryptSM4GCM(src, key, iv, nil) } + if normalizedMode == MODECCM { + return DecryptSM4CCM(src, key, iv, nil) + } block, err := sm4.NewCipher(key) if err != nil { @@ -37,9 +50,15 @@ func DecryptSM4(src, key, iv []byte, mode, paddingType string) ([]byte, error) { func EncryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { normalizedMode := normalizeCipherMode(mode) + if normalizedMode == "" { + normalizedMode = MODEGCM + } if normalizedMode == MODEGCM { return EncryptSM4GCMStream(dst, src, key, iv, nil) } + if normalizedMode == MODECCM { + return EncryptSM4CCMStream(dst, src, key, iv, nil) + } block, err := sm4.NewCipher(key) if err != nil { @@ -50,9 +69,15 @@ func EncryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddin func DecryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { normalizedMode := normalizeCipherMode(mode) + if normalizedMode == "" { + normalizedMode = MODEGCM + } if normalizedMode == MODEGCM { return DecryptSM4GCMStream(dst, src, key, iv, nil) } + if normalizedMode == MODECCM { + return DecryptSM4CCMStream(dst, src, key, iv, nil) + } block, err := sm4.NewCipher(key) if err != nil { @@ -70,6 +95,9 @@ func EncryptSM4WithOptions(data, key []byte, opts *CipherOptions) ([]byte, error if mode == MODEGCM { return EncryptSM4GCM(data, key, nonceFromOptions(cfg), cfg.AAD) } + if mode == MODECCM { + return EncryptSM4CCM(data, key, nonceFromOptions(cfg), cfg.AAD) + } return EncryptSM4(data, key, cfg.IV, mode, cfg.Padding) } @@ -82,6 +110,9 @@ func DecryptSM4WithOptions(src, key []byte, opts *CipherOptions) ([]byte, error) if mode == MODEGCM { return DecryptSM4GCM(src, key, nonceFromOptions(cfg), cfg.AAD) } + if mode == MODECCM { + return DecryptSM4CCM(src, key, nonceFromOptions(cfg), cfg.AAD) + } return DecryptSM4(src, key, cfg.IV, mode, cfg.Padding) } @@ -94,6 +125,9 @@ func EncryptSM4StreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts if mode == MODEGCM { return EncryptSM4GCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) } + if mode == MODECCM { + return EncryptSM4CCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) + } return EncryptSM4Stream(dst, src, key, cfg.IV, mode, cfg.Padding) } @@ -106,6 +140,9 @@ func DecryptSM4StreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts if mode == MODEGCM { return DecryptSM4GCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) } + if mode == MODECCM { + return DecryptSM4CCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) + } return DecryptSM4Stream(dst, src, key, cfg.IV, mode, cfg.Padding) } @@ -139,30 +176,138 @@ func DecryptSM4GCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { return gcm.Open(nil, nonce, ciphertext, aad) } +func EncryptSM4GCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + if len(nonce) != gcm.NonceSize() { + return nil, ErrInvalidGCMNonceLength + } + return encryptGCMChunk(gcm, plain, nonce, aad, chunkIndex), nil +} + +func DecryptSM4GCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + if len(nonce) != gcm.NonceSize() { + return nil, ErrInvalidGCMNonceLength + } + return decryptGCMChunk(gcm, ciphertext, nonce, aad, chunkIndex) +} + func EncryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { - plain, err := io.ReadAll(src) + block, err := sm4.NewCipher(key) if err != nil { return err } - out, err := EncryptSM4GCM(plain, key, nonce, aad) + gcm, err := cipher.NewGCM(block) if err != nil { return err } - _, err = dst.Write(out) - return err + if len(nonce) != gcm.NonceSize() { + return ErrInvalidGCMNonceLength + } + return encryptGCMChunkedStream(dst, src, gcm, nonce, aad) } func DecryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { - enc, err := io.ReadAll(src) + block, err := sm4.NewCipher(key) if err != nil { return err } - out, err := DecryptSM4GCM(enc, key, nonce, aad) + gcm, err := cipher.NewGCM(block) if err != nil { return err } - _, err = dst.Write(out) - return err + if len(nonce) != gcm.NonceSize() { + return ErrInvalidGCMNonceLength + } + return decryptGCMChunkedOrLegacyStream(dst, src, gcm, nonce, aad) +} + +func newSM4CCM(key []byte) (cipher.AEAD, error) { + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + return ccm.NewCCM(block, aeadCCMTagSize, aeadCCMNonceSize) +} + +func EncryptSM4CCM(plain, key, nonce, aad []byte) ([]byte, error) { + aead, err := newSM4CCM(key) + if err != nil { + return nil, err + } + if len(nonce) != aead.NonceSize() { + return nil, ErrInvalidCCMNonceLength + } + return aead.Seal(nil, nonce, plain, aad), nil +} + +func DecryptSM4CCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { + aead, err := newSM4CCM(key) + if err != nil { + return nil, err + } + if len(nonce) != aead.NonceSize() { + return nil, ErrInvalidCCMNonceLength + } + return aead.Open(nil, nonce, ciphertext, aad) +} + +func EncryptSM4CCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + aead, err := newSM4CCM(key) + if err != nil { + return nil, err + } + if len(nonce) != aead.NonceSize() { + return nil, ErrInvalidCCMNonceLength + } + return encryptCCMChunk(aead, plain, nonce, aad, chunkIndex), nil +} + +func DecryptSM4CCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) { + aead, err := newSM4CCM(key) + if err != nil { + return nil, err + } + if len(nonce) != aead.NonceSize() { + return nil, ErrInvalidCCMNonceLength + } + return decryptCCMChunk(aead, ciphertext, nonce, aad, chunkIndex) +} + +func EncryptSM4CCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + aead, err := newSM4CCM(key) + if err != nil { + return err + } + if len(nonce) != aead.NonceSize() { + return ErrInvalidCCMNonceLength + } + return encryptCCMChunkedStream(dst, src, aead, nonce, aad) +} + +func DecryptSM4CCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + aead, err := newSM4CCM(key) + if err != nil { + return err + } + if len(nonce) != aead.NonceSize() { + return ErrInvalidCCMNonceLength + } + return decryptCCMChunkedOrLegacyStream(dst, src, aead, nonce, aad) } func EncryptSM4CFB(origData, key []byte) ([]byte, error) { @@ -235,6 +380,18 @@ func DecryptSM4CTR(src, key, iv []byte) ([]byte, error) { return DecryptSM4(src, key, iv, MODECTR, "") } +func EncryptSM4CTRAt(data, key, iv []byte, offset int64) ([]byte, error) { + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + return xorCTRAtOffset(block, data, iv, offset) +} + +func DecryptSM4CTRAt(src, key, iv []byte, offset int64) ([]byte, error) { + return EncryptSM4CTRAt(src, key, iv, offset) +} + func EncryptSM4ECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error { return EncryptSM4Stream(dst, src, key, nil, MODEECB, paddingType) } @@ -274,3 +431,63 @@ func EncryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error { func DecryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error { return DecryptSM4Stream(dst, src, key, iv, MODECTR, "") } + +func EncryptSM4CFB8(data, key, iv []byte) ([]byte, error) { + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + return encryptCFB8(block, data, iv) +} + +func DecryptSM4CFB8(src, key, iv []byte) ([]byte, error) { + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + return decryptCFB8(block, src, iv) +} + +func EncryptSM4CFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error { + block, err := sm4.NewCipher(key) + if err != nil { + return err + } + return encryptCFB8Stream(block, dst, src, iv, false) +} + +func DecryptSM4CFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error { + block, err := sm4.NewCipher(key) + if err != nil { + return err + } + return encryptCFB8Stream(block, dst, src, iv, true) +} + +func DecryptSM4ECBBlocks(src, key []byte) ([]byte, error) { + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + return decryptECBBlocks(block, src) +} + +// DecryptSM4CBCFromSecondBlock decrypts a CBC ciphertext segment that starts from block 2 or later. +// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock. +func DecryptSM4CBCFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) { + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + return decryptCBCFromSecondBlock(block, src, prevCipherBlock) +} + +// DecryptSM4CFBFromSecondBlock decrypts a CFB ciphertext segment that starts from block 2 or later. +// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock. +func DecryptSM4CFBFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) { + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + return decryptCFBFromSecondBlock(block, src, prevCipherBlock) +} diff --git a/symm/symm_test.go b/symm/symm_test.go index 404598b..92ba072 100644 --- a/symm/symm_test.go +++ b/symm/symm_test.go @@ -6,21 +6,39 @@ import ( "testing" ) -func TestEncryptAesDefaultModeCBC(t *testing.T) { +func TestEncryptAesDefaultModeGCM(t *testing.T) { key := []byte("0123456789abcdef") - iv := []byte("abcdef9876543210") - plain := []byte("aes-default-mode-cbc") + nonce := []byte("123456789012") + plain := []byte("aes-default-mode-gcm") - encDefault, err := EncryptAes(plain, key, iv, "", "") + encDefault, err := EncryptAes(plain, key, nonce, "", "") if err != nil { t.Fatalf("EncryptAes default failed: %v", err) } - encCBC, err := EncryptAesCBC(plain, key, iv, "") + encGCM, err := EncryptAesGCM(plain, key, nonce, nil) if err != nil { - t.Fatalf("EncryptAesCBC failed: %v", err) + t.Fatalf("EncryptAesGCM failed: %v", err) } - if !bytes.Equal(encDefault, encCBC) { - t.Fatalf("default mode should match CBC mode") + if !bytes.Equal(encDefault, encGCM) { + t.Fatalf("default mode should match GCM mode") + } +} + +func TestEncryptSM4DefaultModeGCM(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + plain := []byte("sm4-default-mode-gcm") + + encDefault, err := EncryptSM4(plain, key, nonce, "", "") + if err != nil { + t.Fatalf("EncryptSM4 default failed: %v", err) + } + encGCM, err := EncryptSM4GCM(plain, key, nonce, nil) + if err != nil { + t.Fatalf("EncryptSM4GCM failed: %v", err) + } + if !bytes.Equal(encDefault, encGCM) { + t.Fatalf("default mode should match GCM mode") } } @@ -529,6 +547,62 @@ func TestChaCha20Poly1305RFCVector(t *testing.T) { t.Fatalf("ChaCha20-Poly1305 vector mismatch: got %x want %x", enc, want) } } +func TestAesGCMStreamRoundTripChunked(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := bytes.Repeat([]byte("aes-gcm-stream-chunk-"), 10000) + + enc := &bytes.Buffer{} + if err := EncryptAesGCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil { + t.Fatalf("EncryptAesGCMStream failed: %v", err) + } + dec := &bytes.Buffer{} + if err := DecryptAesGCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil { + t.Fatalf("DecryptAesGCMStream failed: %v", err) + } + if !bytes.Equal(dec.Bytes(), plain) { + t.Fatalf("aes gcm stream mismatch") + } +} + +func TestAesGCMStreamLegacyCompatDecrypt(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := []byte("aes-gcm-legacy-compat") + + legacyCipher, err := EncryptAesGCM(plain, key, nonce, aad) + if err != nil { + t.Fatalf("EncryptAesGCM failed: %v", err) + } + dec := &bytes.Buffer{} + if err := DecryptAesGCMStream(dec, bytes.NewReader(legacyCipher), key, nonce, aad); err != nil { + t.Fatalf("DecryptAesGCMStream failed: %v", err) + } + if !bytes.Equal(dec.Bytes(), plain) { + t.Fatalf("aes gcm legacy decrypt mismatch") + } +} + +func TestSM4GCMStreamRoundTripChunked(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := bytes.Repeat([]byte("sm4-gcm-stream-chunk-"), 10000) + + enc := &bytes.Buffer{} + if err := EncryptSM4GCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil { + t.Fatalf("EncryptSM4GCMStream failed: %v", err) + } + dec := &bytes.Buffer{} + if err := DecryptSM4GCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil { + t.Fatalf("DecryptSM4GCMStream failed: %v", err) + } + if !bytes.Equal(dec.Bytes(), plain) { + t.Fatalf("sm4 gcm stream mismatch") + } +} func TestAesOptionsDefaultToGCM(t *testing.T) { key := []byte("0123456789abcdef") @@ -598,6 +672,113 @@ func TestLargeStreamRoundTrip(t *testing.T) { } } +func TestAesCTRAtOffsetSegment(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := bytes.Repeat([]byte("aes-ctr-offset-"), 256) + + full, err := EncryptAesCTR(plain, key, iv) + if err != nil { + t.Fatalf("EncryptAesCTR failed: %v", err) + } + + offset := 137 + length := 521 + segCipher := full[offset : offset+length] + + segPlain, err := DecryptAesCTRAt(segCipher, key, iv, int64(offset)) + if err != nil { + t.Fatalf("DecryptAesCTRAt failed: %v", err) + } + if !bytes.Equal(segPlain, plain[offset:offset+length]) { + t.Fatalf("aes ctr offset decrypt mismatch") + } + + encSeg, err := EncryptAesCTRAt(plain[offset:offset+length], key, iv, int64(offset)) + if err != nil { + t.Fatalf("EncryptAesCTRAt failed: %v", err) + } + if !bytes.Equal(encSeg, segCipher) { + t.Fatalf("aes ctr offset encrypt mismatch") + } +} + +func TestSM4CTRAtOffsetSegment(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := bytes.Repeat([]byte("sm4-ctr-offset-"), 256) + + full, err := EncryptSM4CTR(plain, key, iv) + if err != nil { + t.Fatalf("EncryptSM4CTR failed: %v", err) + } + + offset := 193 + length := 487 + segCipher := full[offset : offset+length] + + segPlain, err := DecryptSM4CTRAt(segCipher, key, iv, int64(offset)) + if err != nil { + t.Fatalf("DecryptSM4CTRAt failed: %v", err) + } + if !bytes.Equal(segPlain, plain[offset:offset+length]) { + t.Fatalf("sm4 ctr offset decrypt mismatch") + } + + encSeg, err := EncryptSM4CTRAt(plain[offset:offset+length], key, iv, int64(offset)) + if err != nil { + t.Fatalf("EncryptSM4CTRAt failed: %v", err) + } + if !bytes.Equal(encSeg, segCipher) { + t.Fatalf("sm4 ctr offset encrypt mismatch") + } +} + +func TestAesGCMChunkRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := []byte("aes-gcm-chunk") + chunkIndex := uint64(7) + + enc, err := EncryptAesGCMChunk(plain, key, nonce, aad, chunkIndex) + if err != nil { + t.Fatalf("EncryptAesGCMChunk failed: %v", err) + } + dec, err := DecryptAesGCMChunk(enc, key, nonce, aad, chunkIndex) + if err != nil { + t.Fatalf("DecryptAesGCMChunk failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("aes gcm chunk decrypt mismatch") + } + if _, err := DecryptAesGCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil { + t.Fatalf("expected decrypt error for wrong chunk index") + } +} + +func TestSM4GCMChunkRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := []byte("sm4-gcm-chunk") + chunkIndex := uint64(11) + + enc, err := EncryptSM4GCMChunk(plain, key, nonce, aad, chunkIndex) + if err != nil { + t.Fatalf("EncryptSM4GCMChunk failed: %v", err) + } + dec, err := DecryptSM4GCMChunk(enc, key, nonce, aad, chunkIndex) + if err != nil { + t.Fatalf("DecryptSM4GCMChunk failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("sm4 gcm chunk decrypt mismatch") + } + if _, err := DecryptSM4GCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil { + t.Fatalf("expected decrypt error for wrong chunk index") + } +} func mustHex(t *testing.T, s string) []byte { t.Helper() b, err := hex.DecodeString(s) @@ -606,3 +787,553 @@ func mustHex(t *testing.T, s string) []byte { } return b } + +func TestAesCFB8RoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := []byte("aes-cfb8-roundtrip-content") + + enc, err := EncryptAesCFB8(plain, key, iv) + if err != nil { + t.Fatalf("EncryptAesCFB8 failed: %v", err) + } + dec, err := DecryptAesCFB8(enc, key, iv) + if err != nil { + t.Fatalf("DecryptAesCFB8 failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("aes cfb8 mismatch") + } +} + +func TestSM4CFB8RoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := []byte("sm4-cfb8-roundtrip-content") + + enc, err := EncryptSM4CFB8(plain, key, iv) + if err != nil { + t.Fatalf("EncryptSM4CFB8 failed: %v", err) + } + dec, err := DecryptSM4CFB8(enc, key, iv) + if err != nil { + t.Fatalf("DecryptSM4CFB8 failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("sm4 cfb8 mismatch") + } +} + +func TestAesCFB8StreamRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := bytes.Repeat([]byte("aes-cfb8-stream-"), 512) + + enc := &bytes.Buffer{} + if err := EncryptAesCFB8Stream(enc, bytes.NewReader(plain), key, iv); err != nil { + t.Fatalf("EncryptAesCFB8Stream failed: %v", err) + } + dec := &bytes.Buffer{} + if err := DecryptAesCFB8Stream(dec, bytes.NewReader(enc.Bytes()), key, iv); err != nil { + t.Fatalf("DecryptAesCFB8Stream failed: %v", err) + } + if !bytes.Equal(dec.Bytes(), plain) { + t.Fatalf("aes cfb8 stream mismatch") + } +} + +func TestSM4CFB8StreamRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := bytes.Repeat([]byte("sm4-cfb8-stream-"), 512) + + enc := &bytes.Buffer{} + if err := EncryptSM4CFB8Stream(enc, bytes.NewReader(plain), key, iv); err != nil { + t.Fatalf("EncryptSM4CFB8Stream failed: %v", err) + } + dec := &bytes.Buffer{} + if err := DecryptSM4CFB8Stream(dec, bytes.NewReader(enc.Bytes()), key, iv); err != nil { + t.Fatalf("DecryptSM4CFB8Stream failed: %v", err) + } + if !bytes.Equal(dec.Bytes(), plain) { + t.Fatalf("sm4 cfb8 stream mismatch") + } +} + +func TestAesSegmentDecryptModes(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := bytes.Repeat([]byte("0123456789abcdef"), 4) + + ecbEnc, err := EncryptAesECB(plain, key, ZEROPADDING) + if err != nil { + t.Fatalf("EncryptAesECB failed: %v", err) + } + ecbDec, err := DecryptAesECBBlocks(ecbEnc, key) + if err != nil { + t.Fatalf("DecryptAesECBBlocks failed: %v", err) + } + if !bytes.Equal(ecbDec, plain) { + t.Fatalf("aes ecb segment mismatch") + } + + cbcEnc, err := EncryptAesCBC(plain, key, iv, ZEROPADDING) + if err != nil { + t.Fatalf("EncryptAesCBC failed: %v", err) + } + cbcSegDec, err := DecryptAesCBCFromSecondBlock(cbcEnc[len(iv):], key, cbcEnc[:len(iv)]) + if err != nil { + t.Fatalf("DecryptAesCBCFromSecondBlock failed: %v", err) + } + if !bytes.Equal(cbcSegDec, plain[len(iv):]) { + t.Fatalf("aes cbc from-second-block mismatch") + } + + cfbEnc, err := EncryptAesCFB(plain, key, iv) + if err != nil { + t.Fatalf("EncryptAesCFB failed: %v", err) + } + cfbSegDec, err := DecryptAesCFBFromSecondBlock(cfbEnc[len(iv):], key, cfbEnc[:len(iv)]) + if err != nil { + t.Fatalf("DecryptAesCFBFromSecondBlock failed: %v", err) + } + if !bytes.Equal(cfbSegDec, plain[len(iv):]) { + t.Fatalf("aes cfb from-second-block mismatch") + } +} + +func TestSM4SegmentDecryptModes(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := bytes.Repeat([]byte("0123456789abcdef"), 4) + + ecbEnc, err := EncryptSM4ECB(plain, key, ZEROPADDING) + if err != nil { + t.Fatalf("EncryptSM4ECB failed: %v", err) + } + ecbDec, err := DecryptSM4ECBBlocks(ecbEnc, key) + if err != nil { + t.Fatalf("DecryptSM4ECBBlocks failed: %v", err) + } + if !bytes.Equal(ecbDec, plain) { + t.Fatalf("sm4 ecb segment mismatch") + } + + cbcEnc, err := EncryptSM4CBC(plain, key, iv, ZEROPADDING) + if err != nil { + t.Fatalf("EncryptSM4CBC failed: %v", err) + } + cbcSegDec, err := DecryptSM4CBCFromSecondBlock(cbcEnc[len(iv):], key, cbcEnc[:len(iv)]) + if err != nil { + t.Fatalf("DecryptSM4CBCFromSecondBlock failed: %v", err) + } + if !bytes.Equal(cbcSegDec, plain[len(iv):]) { + t.Fatalf("sm4 cbc from-second-block mismatch") + } + + cfbEnc, err := EncryptSM4CFBNoBlock(plain, key, iv) + if err != nil { + t.Fatalf("EncryptSM4CFB failed: %v", err) + } + cfbSegDec, err := DecryptSM4CFBFromSecondBlock(cfbEnc[len(iv):], key, cfbEnc[:len(iv)]) + if err != nil { + t.Fatalf("DecryptSM4CFBFromSecondBlock failed: %v", err) + } + if !bytes.Equal(cfbSegDec, plain[len(iv):]) { + t.Fatalf("sm4 cfb from-second-block mismatch") + } +} +func TestAesCCMRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := []byte("aes-ccm-roundtrip") + + enc, err := EncryptAesCCM(plain, key, nonce, aad) + if err != nil { + t.Fatalf("EncryptAesCCM failed: %v", err) + } + dec, err := DecryptAesCCM(enc, key, nonce, aad) + if err != nil { + t.Fatalf("DecryptAesCCM failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("aes ccm mismatch") + } +} + +func TestSM4CCMRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := []byte("sm4-ccm-roundtrip") + + enc, err := EncryptSM4CCM(plain, key, nonce, aad) + if err != nil { + t.Fatalf("EncryptSM4CCM failed: %v", err) + } + dec, err := DecryptSM4CCM(enc, key, nonce, aad) + if err != nil { + t.Fatalf("DecryptSM4CCM failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("sm4 ccm mismatch") + } +} + +func TestAesCCMStreamRoundTripChunked(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := bytes.Repeat([]byte("aes-ccm-stream-chunk-"), 10000) + + enc := &bytes.Buffer{} + if err := EncryptAesCCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil { + t.Fatalf("EncryptAesCCMStream failed: %v", err) + } + dec := &bytes.Buffer{} + if err := DecryptAesCCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil { + t.Fatalf("DecryptAesCCMStream failed: %v", err) + } + if !bytes.Equal(dec.Bytes(), plain) { + t.Fatalf("aes ccm stream mismatch") + } +} + +func TestAesCCMStreamLegacyCompatDecrypt(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := []byte("aes-ccm-legacy-compat") + + legacyCipher, err := EncryptAesCCM(plain, key, nonce, aad) + if err != nil { + t.Fatalf("EncryptAesCCM failed: %v", err) + } + dec := &bytes.Buffer{} + if err := DecryptAesCCMStream(dec, bytes.NewReader(legacyCipher), key, nonce, aad); err != nil { + t.Fatalf("DecryptAesCCMStream failed: %v", err) + } + if !bytes.Equal(dec.Bytes(), plain) { + t.Fatalf("aes ccm legacy decrypt mismatch") + } +} + +func TestSM4CCMStreamRoundTripChunked(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := bytes.Repeat([]byte("sm4-ccm-stream-chunk-"), 10000) + + enc := &bytes.Buffer{} + if err := EncryptSM4CCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil { + t.Fatalf("EncryptSM4CCMStream failed: %v", err) + } + dec := &bytes.Buffer{} + if err := DecryptSM4CCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil { + t.Fatalf("DecryptSM4CCMStream failed: %v", err) + } + if !bytes.Equal(dec.Bytes(), plain) { + t.Fatalf("sm4 ccm stream mismatch") + } +} + +func TestAesCCMChunkRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := []byte("aes-ccm-chunk") + chunkIndex := uint64(5) + + enc, err := EncryptAesCCMChunk(plain, key, nonce, aad, chunkIndex) + if err != nil { + t.Fatalf("EncryptAesCCMChunk failed: %v", err) + } + dec, err := DecryptAesCCMChunk(enc, key, nonce, aad, chunkIndex) + if err != nil { + t.Fatalf("DecryptAesCCMChunk failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("aes ccm chunk decrypt mismatch") + } + if _, err := DecryptAesCCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil { + t.Fatalf("expected decrypt error for wrong chunk index") + } +} + +func TestSM4CCMChunkRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := []byte("sm4-ccm-chunk") + chunkIndex := uint64(9) + + enc, err := EncryptSM4CCMChunk(plain, key, nonce, aad, chunkIndex) + if err != nil { + t.Fatalf("EncryptSM4CCMChunk failed: %v", err) + } + dec, err := DecryptSM4CCMChunk(enc, key, nonce, aad, chunkIndex) + if err != nil { + t.Fatalf("DecryptSM4CCMChunk failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("sm4 ccm chunk decrypt mismatch") + } + if _, err := DecryptSM4CCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil { + t.Fatalf("expected decrypt error for wrong chunk index") + } +} + +func TestAesOptionsModeCCM(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := []byte("aes-options-ccm") + + opts := &CipherOptions{Mode: MODECCM, Nonce: nonce, AAD: aad} + enc, err := EncryptAesWithOptions(plain, key, opts) + if err != nil { + t.Fatalf("EncryptAesWithOptions CCM failed: %v", err) + } + dec, err := DecryptAesWithOptions(enc, key, opts) + if err != nil { + t.Fatalf("DecryptAesWithOptions CCM failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("aes options ccm mismatch") + } +} + +func TestSM4OptionsModeCCM(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := []byte("sm4-options-ccm") + + opts := &CipherOptions{Mode: MODECCM, Nonce: nonce, AAD: aad} + enc, err := EncryptSM4WithOptions(plain, key, opts) + if err != nil { + t.Fatalf("EncryptSM4WithOptions CCM failed: %v", err) + } + dec, err := DecryptSM4WithOptions(enc, key, opts) + if err != nil { + t.Fatalf("DecryptSM4WithOptions CCM failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("sm4 options ccm mismatch") + } +} + +func TestCCMInvalidNonceLength(t *testing.T) { + key := []byte("0123456789abcdef") + shortNonce := []byte("short") + if _, err := EncryptAesCCM([]byte("x"), key, shortNonce, nil); err == nil { + t.Fatalf("expected aes ccm nonce length error") + } + if _, err := EncryptSM4CCM([]byte("x"), key, shortNonce, nil); err == nil { + t.Fatalf("expected sm4 ccm nonce length error") + } +} +func TestAesXTSRoundTrip(t *testing.T) { + k1 := []byte("0123456789abcdef") + k2 := []byte("fedcba9876543210") + plain := bytes.Repeat([]byte("0123456789abcdef"), 8) + + enc, err := EncryptAesXTS(plain, k1, k2, 32) + if err != nil { + t.Fatalf("EncryptAesXTS failed: %v", err) + } + dec, err := DecryptAesXTS(enc, k1, k2, 32) + if err != nil { + t.Fatalf("DecryptAesXTS failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("aes xts mismatch") + } +} + +func TestSM4XTSRoundTrip(t *testing.T) { + k1 := []byte("0123456789abcdef") + k2 := []byte("fedcba9876543210") + plain := bytes.Repeat([]byte("0123456789abcdef"), 8) + + enc, err := EncryptSM4XTS(plain, k1, k2, 32) + if err != nil { + t.Fatalf("EncryptSM4XTS failed: %v", err) + } + dec, err := DecryptSM4XTS(enc, k1, k2, 32) + if err != nil { + t.Fatalf("DecryptSM4XTS failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("sm4 xts mismatch") + } +} + +func TestAesXTSAtDataUnit(t *testing.T) { + k1 := []byte("0123456789abcdef") + k2 := []byte("fedcba9876543210") + plain := bytes.Repeat([]byte("0123456789abcdef"), 8) + dataUnitSize := 32 + + full, err := EncryptAesXTS(plain, k1, k2, dataUnitSize) + if err != nil { + t.Fatalf("EncryptAesXTS failed: %v", err) + } + segPlain := plain[64:96] + segEnc, err := EncryptAesXTSAt(segPlain, k1, k2, dataUnitSize, 2) + if err != nil { + t.Fatalf("EncryptAesXTSAt failed: %v", err) + } + if !bytes.Equal(segEnc, full[64:96]) { + t.Fatalf("aes xts at encrypt mismatch") + } + + segDec, err := DecryptAesXTSAt(full[64:96], k1, k2, dataUnitSize, 2) + if err != nil { + t.Fatalf("DecryptAesXTSAt failed: %v", err) + } + if !bytes.Equal(segDec, segPlain) { + t.Fatalf("aes xts at decrypt mismatch") + } +} + +func TestSM4XTSAtDataUnit(t *testing.T) { + k1 := []byte("0123456789abcdef") + k2 := []byte("fedcba9876543210") + plain := bytes.Repeat([]byte("0123456789abcdef"), 8) + dataUnitSize := 32 + + full, err := EncryptSM4XTS(plain, k1, k2, dataUnitSize) + if err != nil { + t.Fatalf("EncryptSM4XTS failed: %v", err) + } + segPlain := plain[32:64] + segEnc, err := EncryptSM4XTSAt(segPlain, k1, k2, dataUnitSize, 1) + if err != nil { + t.Fatalf("EncryptSM4XTSAt failed: %v", err) + } + if !bytes.Equal(segEnc, full[32:64]) { + t.Fatalf("sm4 xts at encrypt mismatch") + } + + segDec, err := DecryptSM4XTSAt(full[32:64], k1, k2, dataUnitSize, 1) + if err != nil { + t.Fatalf("DecryptSM4XTSAt failed: %v", err) + } + if !bytes.Equal(segDec, segPlain) { + t.Fatalf("sm4 xts at decrypt mismatch") + } +} + +func TestAesXTSStreamRoundTrip(t *testing.T) { + k1 := []byte("0123456789abcdef") + k2 := []byte("fedcba9876543210") + plain := bytes.Repeat([]byte("0123456789abcdef"), 2048) + + enc := &bytes.Buffer{} + if err := EncryptAesXTSStream(enc, bytes.NewReader(plain), k1, k2, 512); err != nil { + t.Fatalf("EncryptAesXTSStream failed: %v", err) + } + dec := &bytes.Buffer{} + if err := DecryptAesXTSStream(dec, bytes.NewReader(enc.Bytes()), k1, k2, 512); err != nil { + t.Fatalf("DecryptAesXTSStream failed: %v", err) + } + if !bytes.Equal(dec.Bytes(), plain) { + t.Fatalf("aes xts stream mismatch") + } +} + +func TestSM4XTSStreamRoundTrip(t *testing.T) { + k1 := []byte("0123456789abcdef") + k2 := []byte("fedcba9876543210") + plain := bytes.Repeat([]byte("0123456789abcdef"), 2048) + + enc := &bytes.Buffer{} + if err := EncryptSM4XTSStream(enc, bytes.NewReader(plain), k1, k2, 512); err != nil { + t.Fatalf("EncryptSM4XTSStream failed: %v", err) + } + dec := &bytes.Buffer{} + if err := DecryptSM4XTSStream(dec, bytes.NewReader(enc.Bytes()), k1, k2, 512); err != nil { + t.Fatalf("DecryptSM4XTSStream failed: %v", err) + } + if !bytes.Equal(dec.Bytes(), plain) { + t.Fatalf("sm4 xts stream mismatch") + } +} + +func TestXTSRejectsNonBlockMultiple(t *testing.T) { + k1 := []byte("0123456789abcdef") + k2 := []byte("fedcba9876543210") + if _, err := EncryptAesXTS([]byte("short"), k1, k2, 32); err == nil { + t.Fatalf("expected aes xts non-block error") + } + if _, err := EncryptSM4XTS([]byte("short"), k1, k2, 32); err == nil { + t.Fatalf("expected sm4 xts non-block error") + } +} + +func TestXTSStreamRejectsTailNotFullBlock(t *testing.T) { + k1 := []byte("0123456789abcdef") + k2 := []byte("fedcba9876543210") + err := EncryptAesXTSStream(&bytes.Buffer{}, bytes.NewReader([]byte("tail-not-16")), k1, k2, 32) + if err == nil { + t.Fatalf("expected aes xts stream tail error") + } + err = EncryptSM4XTSStream(&bytes.Buffer{}, bytes.NewReader([]byte("tail-not-16")), k1, k2, 32) + if err == nil { + t.Fatalf("expected sm4 xts stream tail error") + } +} + +func TestXTSInvalidDataUnitSize(t *testing.T) { + k1 := []byte("0123456789abcdef") + k2 := []byte("fedcba9876543210") + plain := bytes.Repeat([]byte("0123456789abcdef"), 2) + if _, err := EncryptAesXTS(plain, k1, k2, 30); err == nil { + t.Fatalf("expected aes xts invalid data unit size error") + } + if _, err := EncryptSM4XTS(plain, k1, k2, 30); err == nil { + t.Fatalf("expected sm4 xts invalid data unit size error") + } +} +func TestSplitXTSMasterKeyHelpers(t *testing.T) { + master := []byte("0123456789abcdef0123456789abcdef") + + k1, k2, err := SplitXTSMasterKey(master) + if err != nil { + t.Fatalf("SplitXTSMasterKey failed: %v", err) + } + if len(k1) != 16 || len(k2) != 16 { + t.Fatalf("split key lengths mismatch") + } + if !bytes.Equal(append(k1, k2...), master) { + t.Fatalf("split key content mismatch") + } + + aesK1, aesK2, err := SplitAesXTSMasterKey(master) + if err != nil { + t.Fatalf("SplitAesXTSMasterKey failed: %v", err) + } + if !bytes.Equal(aesK1, k1) || !bytes.Equal(aesK2, k2) { + t.Fatalf("aes split mismatch") + } + + sm4K1, sm4K2, err := SplitSM4XTSMasterKey(master) + if err != nil { + t.Fatalf("SplitSM4XTSMasterKey failed: %v", err) + } + if !bytes.Equal(sm4K1, k1) || !bytes.Equal(sm4K2, k2) { + t.Fatalf("sm4 split mismatch") + } + + if _, _, err := SplitXTSMasterKey([]byte("abc")); err == nil { + t.Fatalf("expected odd-length split error") + } + if _, _, err := SplitAesXTSMasterKey([]byte("0123456789abcdef0123456789abcdef01")); err == nil { + t.Fatalf("expected aes master length error") + } + if _, _, err := SplitSM4XTSMasterKey([]byte("0123456789abcdef")); err == nil { + t.Fatalf("expected sm4 master length error") + } +} diff --git a/symm/xts.go b/symm/xts.go new file mode 100644 index 0000000..22b0604 --- /dev/null +++ b/symm/xts.go @@ -0,0 +1,263 @@ +package symm + +import ( + "crypto/aes" + "crypto/cipher" + "errors" + "io" + + "github.com/emmansun/gmsm/sm4" + "golang.org/x/crypto/xts" +) + +const xtsBlockSize = 16 + +var ( + ErrInvalidXTSDataUnitSize = errors.New("xts data unit size must be a positive multiple of 16") + ErrInvalidXTSDataLength = errors.New("xts data length must be a multiple of 16") + ErrInvalidXTSKeyLength = errors.New("xts key lengths must be non-empty and equal") + ErrInvalidXTSMasterKeyLength = errors.New("xts master key length must be non-empty and even") + ErrInvalidAESXTSMasterKeyLength = errors.New("aes xts master key length must be 32, 48, or 64 bytes") + ErrInvalidSM4XTSMasterKeyLength = errors.New("sm4 xts master key length must be 32 bytes") +) + +type xtsCipherFactory func(key1, key2 []byte) (*xts.Cipher, error) + +func combineXTSKeys(key1, key2 []byte) ([]byte, error) { + if len(key1) == 0 || len(key1) != len(key2) { + return nil, ErrInvalidXTSKeyLength + } + out := make([]byte, len(key1)+len(key2)) + copy(out, key1) + copy(out[len(key1):], key2) + return out, nil +} + +func splitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) { + if len(masterKey) == 0 || len(masterKey)%2 != 0 { + return nil, nil, ErrInvalidXTSMasterKeyLength + } + half := len(masterKey) / 2 + k1 := make([]byte, half) + k2 := make([]byte, half) + copy(k1, masterKey[:half]) + copy(k2, masterKey[half:]) + return k1, k2, nil +} + +// SplitXTSMasterKey splits a master key into two equal XTS keys. +func SplitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) { + return splitXTSMasterKey(masterKey) +} + +// SplitAesXTSMasterKey splits AES-XTS master key and validates length (32/48/64 bytes). +func SplitAesXTSMasterKey(masterKey []byte) ([]byte, []byte, error) { + switch len(masterKey) { + case 32, 48, 64: + return splitXTSMasterKey(masterKey) + default: + return nil, nil, ErrInvalidAESXTSMasterKeyLength + } +} + +// SplitSM4XTSMasterKey splits SM4-XTS master key and validates length (32 bytes). +func SplitSM4XTSMasterKey(masterKey []byte) ([]byte, []byte, error) { + if len(masterKey) != 32 { + return nil, nil, ErrInvalidSM4XTSMasterKeyLength + } + return splitXTSMasterKey(masterKey) +} + +func validateXTSDataUnitSize(dataUnitSize int) error { + if dataUnitSize <= 0 || dataUnitSize%xtsBlockSize != 0 { + return ErrInvalidXTSDataUnitSize + } + return nil +} + +func validateXTSDataLength(data []byte) error { + if len(data)%xtsBlockSize != 0 { + return ErrInvalidXTSDataLength + } + return nil +} + +func cryptXTSAt(c *xts.Cipher, in []byte, dataUnitSize int, dataUnitIndex uint64, decrypt bool) ([]byte, error) { + if err := validateXTSDataUnitSize(dataUnitSize); err != nil { + return nil, err + } + if err := validateXTSDataLength(in); err != nil { + return nil, err + } + if len(in) == 0 { + return []byte{}, nil + } + + out := make([]byte, len(in)) + off := 0 + unit := dataUnitIndex + for off < len(in) { + chunkLen := dataUnitSize + if remain := len(in) - off; remain < chunkLen { + chunkLen = remain + } + if decrypt { + c.Decrypt(out[off:off+chunkLen], in[off:off+chunkLen], unit) + } else { + c.Encrypt(out[off:off+chunkLen], in[off:off+chunkLen], unit) + } + off += chunkLen + unit++ + } + return out, nil +} + +func cryptXTSStreamAt(dst io.Writer, src io.Reader, c *xts.Cipher, dataUnitSize int, dataUnitIndex uint64, decrypt bool) error { + if err := validateXTSDataUnitSize(dataUnitSize); err != nil { + return err + } + + buf := make([]byte, 32*1024) + pending := make([]byte, 0, dataUnitSize*2) + unit := dataUnitIndex + + for { + n, err := src.Read(buf) + if n > 0 { + pending = append(pending, buf[:n]...) + processLen := len(pending) / dataUnitSize * dataUnitSize + if processLen > 0 { + out := make([]byte, processLen) + for off := 0; off < processLen; off += dataUnitSize { + if decrypt { + c.Decrypt(out[off:off+dataUnitSize], pending[off:off+dataUnitSize], unit) + } else { + c.Encrypt(out[off:off+dataUnitSize], pending[off:off+dataUnitSize], unit) + } + unit++ + } + if _, werr := dst.Write(out); werr != nil { + return werr + } + pending = append([]byte(nil), pending[processLen:]...) + } + } + if err != nil { + if err == io.EOF { + break + } + return err + } + } + + if len(pending) == 0 { + return nil + } + if err := validateXTSDataLength(pending); err != nil { + return err + } + out := make([]byte, len(pending)) + if decrypt { + c.Decrypt(out, pending, unit) + } else { + c.Encrypt(out, pending, unit) + } + _, err := dst.Write(out) + return err +} + +func newXTSCipher(newBlock func([]byte) (cipher.Block, error), key1, key2 []byte) (*xts.Cipher, error) { + key, err := combineXTSKeys(key1, key2) + if err != nil { + return nil, err + } + return xts.NewCipher(newBlock, key) +} + +func newAesXTS(key1, key2 []byte) (*xts.Cipher, error) { + return newXTSCipher(aes.NewCipher, key1, key2) +} + +func newSM4XTS(key1, key2 []byte) (*xts.Cipher, error) { + return newXTSCipher(sm4.NewCipher, key1, key2) +} + +func cryptXTSAtWithFactory(factory xtsCipherFactory, in []byte, dataUnitSize int, dataUnitIndex uint64, decrypt bool, key1, key2 []byte) ([]byte, error) { + c, err := factory(key1, key2) + if err != nil { + return nil, err + } + return cryptXTSAt(c, in, dataUnitSize, dataUnitIndex, decrypt) +} + +func cryptXTSStreamAtWithFactory(factory xtsCipherFactory, dst io.Writer, src io.Reader, dataUnitSize int, dataUnitIndex uint64, decrypt bool, key1, key2 []byte) error { + c, err := factory(key1, key2) + if err != nil { + return err + } + return cryptXTSStreamAt(dst, src, c, dataUnitSize, dataUnitIndex, decrypt) +} + +func EncryptAesXTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) { + return EncryptAesXTSAt(plain, key1, key2, dataUnitSize, 0) +} + +func DecryptAesXTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) { + return DecryptAesXTSAt(ciphertext, key1, key2, dataUnitSize, 0) +} + +func EncryptAesXTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) { + return cryptXTSAtWithFactory(newAesXTS, plain, dataUnitSize, dataUnitIndex, false, key1, key2) +} + +func DecryptAesXTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) { + return cryptXTSAtWithFactory(newAesXTS, ciphertext, dataUnitSize, dataUnitIndex, true, key1, key2) +} + +func EncryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error { + return EncryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, 0) +} + +func DecryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error { + return DecryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, 0) +} + +func EncryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error { + return cryptXTSStreamAtWithFactory(newAesXTS, dst, src, dataUnitSize, dataUnitIndex, false, key1, key2) +} + +func DecryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error { + return cryptXTSStreamAtWithFactory(newAesXTS, dst, src, dataUnitSize, dataUnitIndex, true, key1, key2) +} + +func EncryptSM4XTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) { + return EncryptSM4XTSAt(plain, key1, key2, dataUnitSize, 0) +} + +func DecryptSM4XTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) { + return DecryptSM4XTSAt(ciphertext, key1, key2, dataUnitSize, 0) +} + +func EncryptSM4XTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) { + return cryptXTSAtWithFactory(newSM4XTS, plain, dataUnitSize, dataUnitIndex, false, key1, key2) +} + +func DecryptSM4XTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) { + return cryptXTSAtWithFactory(newSM4XTS, ciphertext, dataUnitSize, dataUnitIndex, true, key1, key2) +} + +func EncryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error { + return EncryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, 0) +} + +func DecryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error { + return DecryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, 0) +} + +func EncryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error { + return cryptXTSStreamAtWithFactory(newSM4XTS, dst, src, dataUnitSize, dataUnitIndex, false, key1, key2) +} + +func DecryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error { + return cryptXTSStreamAtWithFactory(newSM4XTS, dst, src, dataUnitSize, dataUnitIndex, true, key1, key2) +} diff --git a/symm/xts_vector_test.go b/symm/xts_vector_test.go new file mode 100644 index 0000000..b0bc9ba --- /dev/null +++ b/symm/xts_vector_test.go @@ -0,0 +1,79 @@ +package symm + +import ( + "bytes" + "testing" +) + +// AES-XTS vectors from IEEE P1619/D16 Annex B (same set used by golang.org/x/crypto/xts tests). +var aesXTSStandardVectors = []struct { + key string + dataUnitIndex uint64 + plaintext string + ciphertext string +}{ + { + key: "0000000000000000000000000000000000000000000000000000000000000000", + dataUnitIndex: 0, + plaintext: "0000000000000000000000000000000000000000000000000000000000000000", + ciphertext: "917cf69ebd68b2ec9b9fe9a3eadda692cd43d2f59598ed858c02c2652fbf922e", + }, + { + key: "1111111111111111111111111111111122222222222222222222222222222222", + dataUnitIndex: 0x3333333333, + plaintext: "4444444444444444444444444444444444444444444444444444444444444444", + ciphertext: "c454185e6a16936e39334038acef838bfb186fff7480adc4289382ecd6d394f0", + }, + { + key: "fffefdfcfbfaf9f8f7f6f5f4f3f2f1f022222222222222222222222222222222", + dataUnitIndex: 0x3333333333, + plaintext: "4444444444444444444444444444444444444444444444444444444444444444", + ciphertext: "af85336b597afc1a900b2eb21ec949d292df4c047e0b21532186a5971a227a89", + }, +} + +func TestAesXTSStandardVectors(t *testing.T) { + for i, tc := range aesXTSStandardVectors { + master := mustHex(t, tc.key) + key1, key2, err := SplitAesXTSMasterKey(master) + if err != nil { + t.Fatalf("#%d split key failed: %v", i, err) + } + + plain := mustHex(t, tc.plaintext) + wantCipher := mustHex(t, tc.ciphertext) + dataUnitSize := len(plain) + + gotCipher, err := EncryptAesXTSAt(plain, key1, key2, dataUnitSize, tc.dataUnitIndex) + if err != nil { + t.Fatalf("#%d EncryptAesXTSAt failed: %v", i, err) + } + if !bytes.Equal(gotCipher, wantCipher) { + t.Fatalf("#%d ciphertext mismatch", i) + } + + gotPlain, err := DecryptAesXTSAt(wantCipher, key1, key2, dataUnitSize, tc.dataUnitIndex) + if err != nil { + t.Fatalf("#%d DecryptAesXTSAt failed: %v", i, err) + } + if !bytes.Equal(gotPlain, plain) { + t.Fatalf("#%d plaintext mismatch", i) + } + + encStream := &bytes.Buffer{} + if err := EncryptAesXTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, tc.dataUnitIndex); err != nil { + t.Fatalf("#%d EncryptAesXTSStreamAt failed: %v", i, err) + } + if !bytes.Equal(encStream.Bytes(), wantCipher) { + t.Fatalf("#%d stream ciphertext mismatch", i) + } + + decStream := &bytes.Buffer{} + if err := DecryptAesXTSStreamAt(decStream, bytes.NewReader(wantCipher), key1, key2, dataUnitSize, tc.dataUnitIndex); err != nil { + t.Fatalf("#%d DecryptAesXTSStreamAt failed: %v", i, err) + } + if !bytes.Equal(decStream.Bytes(), plain) { + t.Fatalf("#%d stream plaintext mismatch", i) + } + } +} diff --git a/xts_keysplit.go b/xts_keysplit.go new file mode 100644 index 0000000..3c190d5 --- /dev/null +++ b/xts_keysplit.go @@ -0,0 +1,15 @@ +package starcrypto + +import "b612.me/starcrypto/symm" + +func SplitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) { + return symm.SplitXTSMasterKey(masterKey) +} + +func SplitAesXTSMasterKey(masterKey []byte) ([]byte, []byte, error) { + return symm.SplitAesXTSMasterKey(masterKey) +} + +func SplitSM4XTSMasterKey(masterKey []byte) ([]byte, []byte, error) { + return symm.SplitSM4XTSMasterKey(masterKey) +}