feat: 新增XTS/CCM流式与KDF能力,补充安全测试并更新README/CHANGELOG

This commit is contained in:
兔子 2026-03-18 13:43:18 +08:00
parent e89350b56a
commit 4fa79744e8
Signed by: b612
GPG Key ID: 99DD2222B612B612
44 changed files with 4636 additions and 77 deletions

6
.gitignore vendored
View File

@ -5,6 +5,7 @@
# Local Go module/cache sandboxes used in this repo
.gopath/
.gomodcache/
.gocache/
# Build artifacts
*.exe
@ -27,3 +28,8 @@ Thumbs.db
# Temp/test artifacts
sm3/ifile
# Agent workflow artifacts
.sentrux/
agent_readme.md
target.md

View File

@ -1,4 +1,52 @@
# Changelog
# Changelog
## 2026-03-17
---
### Added
- Added AES-XTS and SM4-XTS APIs in `symm` and root wrappers:
- bytes APIs: `Encrypt*/Decrypt*XTS`
- data-unit indexed APIs: `Encrypt*/Decrypt*XTSAt`
- stream APIs: `Encrypt*/Decrypt*XTSStream`
- stream + data-unit indexed APIs: `Encrypt*/Decrypt*XTSStreamAt`
- Added XTS master-key split helpers:
- `SplitXTSMasterKey`
- `SplitAesXTSMasterKey`
- `SplitSM4XTSMasterKey`
- Added XTS parameter validation and explicit no-CTS behavior:
- dual keys must be non-empty and equal length
- `dataUnitSize` must be a positive multiple of 16
- non-stream input must be 16-byte aligned
- stream tail must be 16-byte aligned
- Added key-derivation APIs in `hashx` and root wrappers:
- `DerivePBKDF2SHA256Key` (`crypto/pbkdf2`, stdlib)
- `DerivePBKDF2SHA512Key` (`crypto/pbkdf2`, stdlib)
- `DeriveArgon2idKey` / `DeriveArgon2iKey` (`golang.org/x/crypto/argon2`)
- `Argon2Params` + `DefaultArgon2idParams`
- Added benchmark coverage for symmetric hot paths:
- AES/SM4 `GCM`, `CCM`, `XTS`
- AES/SM4 `XTS` stream path
- Added file random fill APIs:
- `FillWithRandom` (math/rand pseudo-random)
- `FillWithCryptoRandom` (crypto/rand secure random, may be slower)
- Added HMAC verify APIs (bytes + hex wrappers) in `macx` and root package.
- Added XTS standard-vector tests (IEEE P1619 subset) and XTS fuzz tests.
- Added CCM/XTS related test coverage for root wrappers and `symm`.
### Changed
- Refactored XTS internals to reduce duplication via shared factory/path.
- Switched AEAD options behavior to require explicit `Nonce` (no `IV` fallback in GCM/CCM paths).
- Reworked CFB-8 register update to ring-buffer state handling.
- Refined `README.md` structure and added:
- XTS usage/constraints documentation
- AEAD nonce non-reuse requirement
- AEAD `CipherOptions` nonce-only behavior note
- legacy GCM/CCM stream decryption memory warning
- `FillWithRandom` vs `FillWithCryptoRandom` guidance
- XTS `dataUnitIndex` mapping consistency note
## Unreleased
@ -11,7 +59,7 @@
- Added unified symmetric cipher options API:
- `CipherOptions{Mode, Padding, IV, Nonce, AAD}`.
- Added AEAD APIs and wrappers:
- `AES-GCM`, `SM4-GCM` (bytes + stream helper APIs).
- `AES-GCM`, `SM4-GCM`, `AES-CCM`, `SM4-CCM` (bytes + chunk + stream helper APIs).
- Added more symmetric mode coverage for SM4:
- `ECB/CBC/CFB/OFB/CTR` (bytes + stream derived APIs).
- Added comprehensive tests across packages and root wrappers.

View File

@ -1,37 +1,56 @@
# starcrypto
# starcrypto
`starcrypto` 是一个 Go 单包聚合风格的密码学工具库:根包可直接调用,同时内部拆分为子包实现,兼顾易用性与可维护性。
## 特性
- 根包直调 + 子包分层(`asymm` / `symm` / `hashx` / `encodingx` / `paddingx` 等)
- 对称算法AES、SM4、DES、3DES、ChaCha20、ChaCha20-Poly1305
- 非对称算法RSA、ECDSA、SM2、SM9
- 哈希与消息认证SHA 系列、MD5/MD4、RIPEMD160、SM3、CRC32/CRC32A、HMAC
- 编解码Base64、Base85、Base91、Base128
- 支持内存 `[]byte` 与流式 `io.Reader/io.Writer`(见下方能力矩阵)
## 安装
```bash
go get b612.me/starcrypto
```
## 特性
- 根包直调 + 子包分层(`asymm` / `symm` / `hashx` / `encodingx` / `paddingx` 等)
- 对称算法AES、SM4、DES、3DES、ChaCha20、ChaCha20-Poly1305
- AEADAES-GCM/AES-CCM、SM4-GCM/SM4-CCM、ChaCha20-Poly1305
- 存储加密AES-XTS、SM4-XTS双 key支持 data unit 索引与流式)
- 非对称算法RSA、ECDSA、SM2、SM9
- 哈希与消息认证SHA 系列、MD5/MD4、RIPEMD160、SM3、CRC32/CRC32A、HMAC
- 支持内存 `[]byte` 与流式 `io.Reader/io.Writer`
## 推荐用法(安全优先)
优先使用带认证的 AEAD
- 通用数据传输优先使用 AEAD
- `EncryptAesGCM/DecryptAesGCM`
- `EncryptAesCCM/DecryptAesCCM`
- `EncryptSM4GCM/DecryptSM4GCM`
- `EncryptSM4CCM/DecryptSM4CCM`
- `EncryptChaCha20Poly1305/DecryptChaCha20Poly1305`
- 磁盘/扇区类场景可用 XTS
- `EncryptAesXTS/DecryptAesXTS`
- `EncryptSM4XTS/DecryptSM4XTS`
- 统一选项接口(默认 `GCM`
- `EncryptAesWithOptions/DecryptAesWithOptions`
- `EncryptSM4WithOptions/DecryptSM4WithOptions`
- `EncryptAesGCM/DecryptAesGCM`
- `EncryptSM4GCM/DecryptSM4GCM`
- `EncryptChaCha20Poly1305/DecryptChaCha20Poly1305`
> `CBC/CFB/OFB/CTR/ECB/XTS` 不提供完整性校验,不能替代 AEAD 的篡改检测能力。
> AEAD (`GCM/CCM/ChaCha20-Poly1305`) 下,同一把 key 绝不能重复使用同一个 nonce。
> 使用 `CipherOptions`AEAD 模式仅读取 `Nonce` 字段,不再回退 `IV`
> GCM/CCM 流解密在 legacy 兼容分支会缓冲完整密文到内存;大文件建议使用分块流格式(带 `SCG1/SCC1` 头)。
或使用统一选项接口(默认 `GCM`
## XTS 约束(重要)
- `EncryptAesWithOptions/DecryptAesWithOptions`
- `EncryptSM4WithOptions/DecryptSM4WithOptions`
- 两个密钥分开传入:`key1`, `key2`,且长度必须相同且非空。
- `dataUnitSize`(数据单元大小)由调用方自定义,但必须是正整数且为 `16` 的倍数。
- `dataUnitIndex` 必须与存储层的逻辑块映射保持稳定一致;同一数据单元在加密和解密时索引必须一致。
- 当前实现不做 CTSciphertext stealing
- 非流式 API输入长度必须满足 `len(data) % 16 == 0`,否则直接返回错误。
- 流式 API最终尾块若不是 `16` 字节对齐会返回错误。
> `CBC/CFB/OFB/CTR/ECB` 仅提供机密性,不提供完整性校验,可能被篡改后无法检测。
## 文件随机填充
- `FillWithRandom` 使用 `math/rand` 伪随机,速度更高,但不适合安全用途。
- `FillWithCryptoRandom` 使用 `crypto/rand`,安全性更高,但速度可能受限。
## 快速示例
### 统一 Options默认 GCM
@ -65,7 +84,7 @@ func main() {
}
```
### AES CBC 流式接口(兼容模式
### AES-XTS按 data unit
```go
package main
@ -78,19 +97,20 @@ import (
)
func main() {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 8) // 128 bytes, 16-byte aligned
src := bytes.NewReader([]byte("stream content"))
encBuf := &bytes.Buffer{}
decBuf := &bytes.Buffer{}
if err := starcrypto.EncryptAesCBCStream(encBuf, src, key, iv, ""); err != nil {
// 每个 data unit 64 字节
enc, err := starcrypto.EncryptAesXTS(plain, k1, k2, 64)
if err != nil {
log.Fatal(err)
}
if err := starcrypto.DecryptAesCBCStream(decBuf, bytes.NewReader(encBuf.Bytes()), key, iv, ""); err != nil {
dec, err := starcrypto.DecryptAesXTS(enc, k1, k2, 64)
if err != nil {
log.Fatal(err)
}
_ = dec
}
```
@ -98,8 +118,8 @@ func main() {
| 算法 | 模式 / 方案 | AEAD | Stream | 建议 |
|---|---|---:|---:|---|
| AES | ECB/CBC/CFB/OFB/CTR/GCM | GCM: 是 | 是 | 生产优先 GCM |
| SM4 | ECB/CBC/CFB/OFB/CTR/GCM | GCM: 是 | 是 | 生产优先 GCM |
| AES | ECB/CBC/CFB/OFB/CTR/GCM/CCM/XTS | GCM/CCM: 是XTS: 否 | 是 | 生产优先 GCM/CCM存储场景可用 XTS |
| SM4 | ECB/CBC/CFB/OFB/CTR/GCM/CCM/XTS | GCM/CCM: 是XTS: 否 | 是 | 生产优先 GCM/CCM存储场景可用 XTS |
| ChaCha20 | ChaCha20 / ChaCha20-Poly1305 | Poly1305: 是 | ChaCha20: 是 | 生产优先 ChaCha20-Poly1305 |
| DES | CBC | 否 | 是 | 仅兼容历史系统 |
| 3DES | CBC | 否 | 是 | 仅兼容历史系统 |
@ -108,11 +128,11 @@ func main() {
- AES/SM4 的 CBC/ECB 默认:`PKCS7`
- DES/3DES 的 CBC 默认:`PKCS5`
- CFB/OFB/CTR/GCM/ChaCha20 不使用填充
- CFB/OFB/CTR/GCM/CCM/XTS/ChaCha20 不使用填充
## 兼容性说明
库中保留了部分历史/兼容用途算法与接口(例如 `ECB``DES/3DES`)。如无兼容要求,建议使用 AEAD 方案并统一通过 `CipherOptions` 管理参数
库中保留了部分历史/兼容用途算法与接口(例如 `ECB``DES/3DES`)。如无兼容要求,建议优先使用 AEAD 方案。
## 许可证

34
THIRD_PARTY_NOTICES.md Normal file
View File

@ -0,0 +1,34 @@
# Third-Party Notices
This repository is primarily licensed under Apache-2.0. Some files are copied from third-party projects and remain under their original licenses.
## Pion DTLS CCM implementation
- Upstream project: https://github.com/pion/dtls
- Upstream path: `pkg/crypto/ccm`
- Local files:
- `ccm/ccm.go`
- `ccm/ccm_test.go`
- License: MIT
### MIT License (Pion DTLS)
Copyright (c) 2026 The Pion community <https://pion.ly>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

100
aes.go
View File

@ -20,6 +20,7 @@ const (
MODEOFB = symm.MODEOFB
MODECTR = symm.MODECTR
MODEGCM = symm.MODEGCM
MODECCM = symm.MODECCM
)
func EncryptAes(data, key, iv []byte, mode, paddingType string) ([]byte, error) {
@ -62,6 +63,14 @@ func DecryptAesGCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
return symm.DecryptAesGCM(ciphertext, key, nonce, aad)
}
func EncryptAesGCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
return symm.EncryptAesGCMChunk(plain, key, nonce, aad, chunkIndex)
}
func DecryptAesGCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
return symm.DecryptAesGCMChunk(ciphertext, key, nonce, aad, chunkIndex)
}
func EncryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
return symm.EncryptAesGCMStream(dst, src, key, nonce, aad)
}
@ -70,6 +79,30 @@ func DecryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) e
return symm.DecryptAesGCMStream(dst, src, key, nonce, aad)
}
func EncryptAesCCM(plain, key, nonce, aad []byte) ([]byte, error) {
return symm.EncryptAesCCM(plain, key, nonce, aad)
}
func DecryptAesCCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
return symm.DecryptAesCCM(ciphertext, key, nonce, aad)
}
func EncryptAesCCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
return symm.EncryptAesCCMChunk(plain, key, nonce, aad, chunkIndex)
}
func DecryptAesCCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
return symm.DecryptAesCCMChunk(ciphertext, key, nonce, aad, chunkIndex)
}
func EncryptAesCCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
return symm.EncryptAesCCMStream(dst, src, key, nonce, aad)
}
func DecryptAesCCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
return symm.DecryptAesCCMStream(dst, src, key, nonce, aad)
}
func EncryptAesECB(data, key []byte, paddingType string) ([]byte, error) {
return symm.EncryptAesECB(data, key, paddingType)
}
@ -110,6 +143,14 @@ func DecryptAesCTR(src, key, iv []byte) ([]byte, error) {
return symm.DecryptAesCTR(src, key, iv)
}
func EncryptAesCTRAt(data, key, iv []byte, offset int64) ([]byte, error) {
return symm.EncryptAesCTRAt(data, key, iv, offset)
}
func DecryptAesCTRAt(src, key, iv []byte, offset int64) ([]byte, error) {
return symm.DecryptAesCTRAt(src, key, iv, offset)
}
func EncryptAesECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error {
return symm.EncryptAesECBStream(dst, src, key, paddingType)
}
@ -181,3 +222,62 @@ func PKCS7Padding(cipherText []byte, blockSize int) []byte {
func PKCS7Trimming(encrypt []byte, blockSize int) []byte {
return symm.PKCS7Trimming(encrypt, blockSize)
}
func EncryptAesCFB8(data, key, iv []byte) ([]byte, error) {
return symm.EncryptAesCFB8(data, key, iv)
}
func DecryptAesCFB8(src, key, iv []byte) ([]byte, error) {
return symm.DecryptAesCFB8(src, key, iv)
}
func EncryptAesCFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
return symm.EncryptAesCFB8Stream(dst, src, key, iv)
}
func DecryptAesCFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
return symm.DecryptAesCFB8Stream(dst, src, key, iv)
}
func DecryptAesECBBlocks(src, key []byte) ([]byte, error) {
return symm.DecryptAesECBBlocks(src, key)
}
func DecryptAesCBCFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
return symm.DecryptAesCBCFromSecondBlock(src, key, prevCipherBlock)
}
func DecryptAesCFBFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
return symm.DecryptAesCFBFromSecondBlock(src, key, prevCipherBlock)
}
func EncryptAesXTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
return symm.EncryptAesXTS(plain, key1, key2, dataUnitSize)
}
func DecryptAesXTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
return symm.DecryptAesXTS(ciphertext, key1, key2, dataUnitSize)
}
func EncryptAesXTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
return symm.EncryptAesXTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex)
}
func DecryptAesXTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
return symm.DecryptAesXTSAt(ciphertext, key1, key2, dataUnitSize, dataUnitIndex)
}
func EncryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
return symm.EncryptAesXTSStream(dst, src, key1, key2, dataUnitSize)
}
func DecryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
return symm.DecryptAesXTSStream(dst, src, key1, key2, dataUnitSize)
}
func EncryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
return symm.EncryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, dataUnitIndex)
}
func DecryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
return symm.DecryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, dataUnitIndex)
}

48
api_security_test.go Normal file
View File

@ -0,0 +1,48 @@
package starcrypto
import (
"os"
"path/filepath"
"testing"
)
func TestRootHmacVerifyWrappers(t *testing.T) {
msg := []byte("hmac-verify-message")
key := []byte("hmac-verify-key")
sum := HmacSHA256(msg, key)
if !VerifyHmacSHA256(msg, key, sum) {
t.Fatalf("VerifyHmacSHA256 should pass for correct digest")
}
bad := make([]byte, len(sum))
copy(bad, sum)
bad[0] ^= 0xff
if VerifyHmacSHA256(msg, key, bad) {
t.Fatalf("VerifyHmacSHA256 should fail for wrong digest")
}
hexSum := HmacSHA256Str(msg, key)
if !VerifyHmacSHA256Str(msg, key, hexSum) {
t.Fatalf("VerifyHmacSHA256Str should pass for correct digest")
}
if VerifyHmacSHA256Str(msg, key, "not-hex") {
t.Fatalf("VerifyHmacSHA256Str should fail for invalid hex")
}
}
func TestRootFillWithCryptoRandomWrapper(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "secure.bin")
if err := FillWithCryptoRandom(path, 1024, 128, nil); err != nil {
t.Fatalf("FillWithCryptoRandom failed: %v", err)
}
info, err := os.Stat(path)
if err != nil {
t.Fatalf("stat file failed: %v", err)
}
if info.Size() != 1024 {
t.Fatalf("unexpected file size: %d", info.Size())
}
}

View File

@ -263,3 +263,353 @@ func TestRootOptionsAndGCMWrappers(t *testing.T) {
t.Fatalf("sm4 options wrapper mismatch")
}
}
func TestRootCTRAtAndGCMChunkWrappers(t *testing.T) {
aesKey := []byte("0123456789abcdef")
aesIV := []byte("abcdef9876543210")
aesNonce := []byte("123456789012")
plain := bytes.Repeat([]byte("root-ctr-offset-"), 64)
aesFull, err := EncryptAesCTR(plain, aesKey, aesIV)
if err != nil {
t.Fatalf("EncryptAesCTR failed: %v", err)
}
offset := 33
seg := aesFull[offset : offset+100]
decSeg, err := DecryptAesCTRAt(seg, aesKey, aesIV, int64(offset))
if err != nil {
t.Fatalf("DecryptAesCTRAt failed: %v", err)
}
if !bytes.Equal(decSeg, plain[offset:offset+100]) {
t.Fatalf("root aes ctr at mismatch")
}
chunkCipher, err := EncryptAesGCMChunk([]byte("root-aes-gcm-chunk"), aesKey, aesNonce, []byte("aad"), 2)
if err != nil {
t.Fatalf("EncryptAesGCMChunk failed: %v", err)
}
chunkPlain, err := DecryptAesGCMChunk(chunkCipher, aesKey, aesNonce, []byte("aad"), 2)
if err != nil {
t.Fatalf("DecryptAesGCMChunk failed: %v", err)
}
if !bytes.Equal(chunkPlain, []byte("root-aes-gcm-chunk")) {
t.Fatalf("root aes gcm chunk mismatch")
}
sm4Key := []byte("0123456789abcdef")
sm4IV := []byte("abcdef9876543210")
sm4Nonce := []byte("123456789012")
sm4Full, err := EncryptSM4CTR(plain, sm4Key, sm4IV)
if err != nil {
t.Fatalf("EncryptSM4CTR failed: %v", err)
}
sm4Seg := sm4Full[offset : offset+100]
sm4DecSeg, err := DecryptSM4CTRAt(sm4Seg, sm4Key, sm4IV, int64(offset))
if err != nil {
t.Fatalf("DecryptSM4CTRAt failed: %v", err)
}
if !bytes.Equal(sm4DecSeg, plain[offset:offset+100]) {
t.Fatalf("root sm4 ctr at mismatch")
}
sm4ChunkCipher, err := EncryptSM4GCMChunk([]byte("root-sm4-gcm-chunk"), sm4Key, sm4Nonce, []byte("aad"), 3)
if err != nil {
t.Fatalf("EncryptSM4GCMChunk failed: %v", err)
}
sm4ChunkPlain, err := DecryptSM4GCMChunk(sm4ChunkCipher, sm4Key, sm4Nonce, []byte("aad"), 3)
if err != nil {
t.Fatalf("DecryptSM4GCMChunk failed: %v", err)
}
if !bytes.Equal(sm4ChunkPlain, []byte("root-sm4-gcm-chunk")) {
t.Fatalf("root sm4 gcm chunk mismatch")
}
}
func TestRootCFB8AndSegmentDecryptWrappers(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 4)
aesCFB8, err := EncryptAesCFB8(plain, key, iv)
if err != nil {
t.Fatalf("EncryptAesCFB8 failed: %v", err)
}
aesCFB8Dec, err := DecryptAesCFB8(aesCFB8, key, iv)
if err != nil {
t.Fatalf("DecryptAesCFB8 failed: %v", err)
}
if !bytes.Equal(aesCFB8Dec, plain) {
t.Fatalf("root aes cfb8 mismatch")
}
aesCBC, err := EncryptAesCBC(plain, key, iv, ZEROPADDING)
if err != nil {
t.Fatalf("EncryptAesCBC failed: %v", err)
}
aesCBCSegDec, err := DecryptAesCBCFromSecondBlock(aesCBC[len(iv):], key, aesCBC[:len(iv)])
if err != nil {
t.Fatalf("DecryptAesCBCFromSecondBlock failed: %v", err)
}
if !bytes.Equal(aesCBCSegDec, plain[len(iv):]) {
t.Fatalf("root aes cbc from-second-block mismatch")
}
aesCFB, err := EncryptAesCFB(plain, key, iv)
if err != nil {
t.Fatalf("EncryptAesCFB failed: %v", err)
}
aesCFBSegDec, err := DecryptAesCFBFromSecondBlock(aesCFB[len(iv):], key, aesCFB[:len(iv)])
if err != nil {
t.Fatalf("DecryptAesCFBFromSecondBlock failed: %v", err)
}
if !bytes.Equal(aesCFBSegDec, plain[len(iv):]) {
t.Fatalf("root aes cfb from-second-block mismatch")
}
sm4CFB8, err := EncryptSM4CFB8(plain, key, iv)
if err != nil {
t.Fatalf("EncryptSM4CFB8 failed: %v", err)
}
sm4CFB8Dec, err := DecryptSM4CFB8(sm4CFB8, key, iv)
if err != nil {
t.Fatalf("DecryptSM4CFB8 failed: %v", err)
}
if !bytes.Equal(sm4CFB8Dec, plain) {
t.Fatalf("root sm4 cfb8 mismatch")
}
sm4CBC, err := EncryptSM4CBC(plain, key, iv, ZEROPADDING)
if err != nil {
t.Fatalf("EncryptSM4CBC failed: %v", err)
}
sm4CBCSegDec, err := DecryptSM4CBCFromSecondBlock(sm4CBC[len(iv):], key, sm4CBC[:len(iv)])
if err != nil {
t.Fatalf("DecryptSM4CBCFromSecondBlock failed: %v", err)
}
if !bytes.Equal(sm4CBCSegDec, plain[len(iv):]) {
t.Fatalf("root sm4 cbc from-second-block mismatch")
}
sm4CFB, err := EncryptSM4CFBNoBlock(plain, key, iv)
if err != nil {
t.Fatalf("EncryptSM4CFB failed: %v", err)
}
sm4CFBSegDec, err := DecryptSM4CFBFromSecondBlock(sm4CFB[len(iv):], key, sm4CFB[:len(iv)])
if err != nil {
t.Fatalf("DecryptSM4CFBFromSecondBlock failed: %v", err)
}
if !bytes.Equal(sm4CFBSegDec, plain[len(iv):]) {
t.Fatalf("root sm4 cfb from-second-block mismatch")
}
}
func TestRootOptionsAndCCMWrappers(t *testing.T) {
aesKey := []byte("0123456789abcdef")
sm4Key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("root-options-ccm")
aesEnc, err := EncryptAesCCM(plain, aesKey, nonce, aad)
if err != nil {
t.Fatalf("EncryptAesCCM failed: %v", err)
}
aesDec, err := DecryptAesCCM(aesEnc, aesKey, nonce, aad)
if err != nil {
t.Fatalf("DecryptAesCCM failed: %v", err)
}
if !bytes.Equal(aesDec, plain) {
t.Fatalf("aes ccm wrapper mismatch")
}
aesOpts := &CipherOptions{Mode: MODECCM, Nonce: nonce, AAD: aad}
aesOptEnc, err := EncryptAesWithOptions(plain, aesKey, aesOpts)
if err != nil {
t.Fatalf("EncryptAesWithOptions CCM failed: %v", err)
}
aesOptDec, err := DecryptAesWithOptions(aesOptEnc, aesKey, aesOpts)
if err != nil {
t.Fatalf("DecryptAesWithOptions CCM failed: %v", err)
}
if !bytes.Equal(aesOptDec, plain) {
t.Fatalf("aes options ccm wrapper mismatch")
}
aesChunkCipher, err := EncryptAesCCMChunk([]byte("root-aes-ccm-chunk"), aesKey, nonce, aad, 4)
if err != nil {
t.Fatalf("EncryptAesCCMChunk failed: %v", err)
}
aesChunkPlain, err := DecryptAesCCMChunk(aesChunkCipher, aesKey, nonce, aad, 4)
if err != nil {
t.Fatalf("DecryptAesCCMChunk failed: %v", err)
}
if !bytes.Equal(aesChunkPlain, []byte("root-aes-ccm-chunk")) {
t.Fatalf("root aes ccm chunk mismatch")
}
aesStreamEnc := &bytes.Buffer{}
if err := EncryptAesCCMStream(aesStreamEnc, bytes.NewReader(plain), aesKey, nonce, aad); err != nil {
t.Fatalf("EncryptAesCCMStream failed: %v", err)
}
aesStreamDec := &bytes.Buffer{}
if err := DecryptAesCCMStream(aesStreamDec, bytes.NewReader(aesStreamEnc.Bytes()), aesKey, nonce, aad); err != nil {
t.Fatalf("DecryptAesCCMStream failed: %v", err)
}
if !bytes.Equal(aesStreamDec.Bytes(), plain) {
t.Fatalf("aes ccm stream wrapper mismatch")
}
sm4Enc, err := EncryptSM4CCM(plain, sm4Key, nonce, aad)
if err != nil {
t.Fatalf("EncryptSM4CCM failed: %v", err)
}
sm4Dec, err := DecryptSM4CCM(sm4Enc, sm4Key, nonce, aad)
if err != nil {
t.Fatalf("DecryptSM4CCM failed: %v", err)
}
if !bytes.Equal(sm4Dec, plain) {
t.Fatalf("sm4 ccm wrapper mismatch")
}
sm4Opts := &CipherOptions{Mode: MODECCM, Nonce: nonce, AAD: aad}
sm4OptEnc, err := EncryptSM4WithOptions(plain, sm4Key, sm4Opts)
if err != nil {
t.Fatalf("EncryptSM4WithOptions CCM failed: %v", err)
}
sm4OptDec, err := DecryptSM4WithOptions(sm4OptEnc, sm4Key, sm4Opts)
if err != nil {
t.Fatalf("DecryptSM4WithOptions CCM failed: %v", err)
}
if !bytes.Equal(sm4OptDec, plain) {
t.Fatalf("sm4 options ccm wrapper mismatch")
}
sm4ChunkCipher, err := EncryptSM4CCMChunk([]byte("root-sm4-ccm-chunk"), sm4Key, nonce, aad, 6)
if err != nil {
t.Fatalf("EncryptSM4CCMChunk failed: %v", err)
}
sm4ChunkPlain, err := DecryptSM4CCMChunk(sm4ChunkCipher, sm4Key, nonce, aad, 6)
if err != nil {
t.Fatalf("DecryptSM4CCMChunk failed: %v", err)
}
if !bytes.Equal(sm4ChunkPlain, []byte("root-sm4-ccm-chunk")) {
t.Fatalf("root sm4 ccm chunk mismatch")
}
sm4StreamEnc := &bytes.Buffer{}
if err := EncryptSM4CCMStream(sm4StreamEnc, bytes.NewReader(plain), sm4Key, nonce, aad); err != nil {
t.Fatalf("EncryptSM4CCMStream failed: %v", err)
}
sm4StreamDec := &bytes.Buffer{}
if err := DecryptSM4CCMStream(sm4StreamDec, bytes.NewReader(sm4StreamEnc.Bytes()), sm4Key, nonce, aad); err != nil {
t.Fatalf("DecryptSM4CCMStream failed: %v", err)
}
if !bytes.Equal(sm4StreamDec.Bytes(), plain) {
t.Fatalf("sm4 ccm stream wrapper mismatch")
}
}
func TestRootXTSWrappers(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 16)
aesEnc, err := EncryptAesXTS(plain, k1, k2, 64)
if err != nil {
t.Fatalf("EncryptAesXTS failed: %v", err)
}
aesDec, err := DecryptAesXTS(aesEnc, k1, k2, 64)
if err != nil {
t.Fatalf("DecryptAesXTS failed: %v", err)
}
if !bytes.Equal(aesDec, plain) {
t.Fatalf("root aes xts mismatch")
}
aesSegEnc, err := EncryptAesXTSAt(plain[:64], k1, k2, 64, 1)
if err != nil {
t.Fatalf("EncryptAesXTSAt failed: %v", err)
}
aesSegDec, err := DecryptAesXTSAt(aesSegEnc, k1, k2, 64, 1)
if err != nil {
t.Fatalf("DecryptAesXTSAt failed: %v", err)
}
if !bytes.Equal(aesSegDec, plain[:64]) {
t.Fatalf("root aes xts at mismatch")
}
aesStreamEnc := &bytes.Buffer{}
if err := EncryptAesXTSStream(aesStreamEnc, bytes.NewReader(plain), k1, k2, 64); err != nil {
t.Fatalf("EncryptAesXTSStream failed: %v", err)
}
aesStreamDec := &bytes.Buffer{}
if err := DecryptAesXTSStream(aesStreamDec, bytes.NewReader(aesStreamEnc.Bytes()), k1, k2, 64); err != nil {
t.Fatalf("DecryptAesXTSStream failed: %v", err)
}
if !bytes.Equal(aesStreamDec.Bytes(), plain) {
t.Fatalf("root aes xts stream mismatch")
}
sm4Enc, err := EncryptSM4XTS(plain, k1, k2, 64)
if err != nil {
t.Fatalf("EncryptSM4XTS failed: %v", err)
}
sm4Dec, err := DecryptSM4XTS(sm4Enc, k1, k2, 64)
if err != nil {
t.Fatalf("DecryptSM4XTS failed: %v", err)
}
if !bytes.Equal(sm4Dec, plain) {
t.Fatalf("root sm4 xts mismatch")
}
sm4SegEnc, err := EncryptSM4XTSAt(plain[:64], k1, k2, 64, 2)
if err != nil {
t.Fatalf("EncryptSM4XTSAt failed: %v", err)
}
sm4SegDec, err := DecryptSM4XTSAt(sm4SegEnc, k1, k2, 64, 2)
if err != nil {
t.Fatalf("DecryptSM4XTSAt failed: %v", err)
}
if !bytes.Equal(sm4SegDec, plain[:64]) {
t.Fatalf("root sm4 xts at mismatch")
}
sm4StreamEnc := &bytes.Buffer{}
if err := EncryptSM4XTSStream(sm4StreamEnc, bytes.NewReader(plain), k1, k2, 64); err != nil {
t.Fatalf("EncryptSM4XTSStream failed: %v", err)
}
sm4StreamDec := &bytes.Buffer{}
if err := DecryptSM4XTSStream(sm4StreamDec, bytes.NewReader(sm4StreamEnc.Bytes()), k1, k2, 64); err != nil {
t.Fatalf("DecryptSM4XTSStream failed: %v", err)
}
if !bytes.Equal(sm4StreamDec.Bytes(), plain) {
t.Fatalf("root sm4 xts stream mismatch")
}
}
func TestRootKDFAndXTSKeySplitWrappers(t *testing.T) {
pbk, err := DerivePBKDF2SHA256Key("password", []byte("salt"), 1, 32)
if err != nil {
t.Fatalf("DerivePBKDF2SHA256Key failed: %v", err)
}
if len(pbk) != 32 {
t.Fatalf("pbkdf2 key length mismatch")
}
argonParams := DefaultArgon2idParams()
argonParams.Memory = 32 * 1024
argonParams.Threads = 1
argon, err := DeriveArgon2idKey("password", []byte("salt-salt"), argonParams)
if err != nil {
t.Fatalf("DeriveArgon2idKey failed: %v", err)
}
if len(argon) != int(argonParams.KeyLen) {
t.Fatalf("argon2 key length mismatch")
}
master := []byte("0123456789abcdef0123456789abcdef")
k1, k2, err := SplitXTSMasterKey(master)
if err != nil {
t.Fatalf("SplitXTSMasterKey failed: %v", err)
}
if len(k1) != 16 || len(k2) != 16 {
t.Fatalf("split xts wrapper key lengths mismatch")
}
}

View File

@ -10,6 +10,7 @@ import (
"encoding/pem"
"errors"
"github.com/emmansun/gmsm/pkcs8"
"github.com/emmansun/gmsm/sm2"
"github.com/emmansun/gmsm/smx509"
"golang.org/x/crypto/ssh"
@ -86,6 +87,11 @@ func DecodePrivateKey(private []byte, password string) (crypto.PrivateKey, error
return key, nil
}
return smx509.ParsePKCS8PrivateKey(bytes)
case "ENCRYPTED PRIVATE KEY":
if password == "" {
return nil, errors.New("private key is encrypted but password is empty")
}
return pkcs8.ParsePKCS8PrivateKey(blk.Bytes, []byte(password))
case "OPENSSH PRIVATE KEY":
if password == "" {
return ssh.ParseRawPrivateKey(private)

View File

@ -9,6 +9,7 @@ import (
"errors"
"math/big"
"github.com/emmansun/gmsm/pkcs8"
"golang.org/x/crypto/ssh"
)
@ -21,6 +22,37 @@ func GenerateRsaKey(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) {
}
func EncodeRsaPrivateKey(private *rsa.PrivateKey, secret string) ([]byte, error) {
return EncodeRsaPrivateKeyWithLegacy(private, secret, true)
}
func EncodeRsaPrivateKeyWithLegacy(private *rsa.PrivateKey, secret string, legacy bool) ([]byte, error) {
if legacy {
return encodeRsaPrivateKeyLegacy(private, secret)
}
return EncodeRsaPrivateKeyPKCS8(private, secret)
}
func EncodeRsaPrivateKeyPKCS8(private *rsa.PrivateKey, secret string) ([]byte, error) {
password := []byte(secret)
var (
der []byte
blockType = "PRIVATE KEY"
err error
)
if secret == "" {
der, err = pkcs8.MarshalPrivateKey(private, nil, nil)
} else {
der, err = pkcs8.MarshalPrivateKey(private, password, pkcs8.DefaultOpts)
blockType = "ENCRYPTED PRIVATE KEY"
}
if err != nil {
return nil, err
}
return pem.EncodeToMemory(&pem.Block{Type: blockType, Bytes: der}), nil
}
func encodeRsaPrivateKeyLegacy(private *rsa.PrivateKey, secret string) ([]byte, error) {
der := x509.MarshalPKCS1PrivateKey(private)
if secret == "" {
return pem.EncodeToMemory(&pem.Block{
@ -52,6 +84,46 @@ func DecodeRsaPrivateKey(private []byte, password string) (*rsa.PrivateKey, erro
return nil, errors.New("private key error")
}
switch blk.Type {
case "PRIVATE KEY", "ENCRYPTED PRIVATE KEY":
return DecodeRsaPrivateKeyPKCS8(private, password)
default:
return decodeRsaPrivateKeyLegacy(private, password)
}
}
func DecodeRsaPrivateKeyWithLegacy(private []byte, password string, legacy bool) (*rsa.PrivateKey, error) {
if legacy {
return decodeRsaPrivateKeyLegacy(private, password)
}
return DecodeRsaPrivateKeyPKCS8(private, password)
}
func DecodeRsaPrivateKeyPKCS8(private []byte, password string) (*rsa.PrivateKey, error) {
blk, _ := pem.Decode(private)
if blk == nil {
return nil, errors.New("private key error")
}
switch blk.Type {
case "PRIVATE KEY":
return pkcs8.ParsePKCS8PrivateKeyRSA(blk.Bytes)
case "ENCRYPTED PRIVATE KEY":
if password == "" {
return nil, errors.New("private key is encrypted but password is empty")
}
return pkcs8.ParsePKCS8PrivateKeyRSA(blk.Bytes, []byte(password))
default:
return nil, errors.New("private key is not PKCS#8")
}
}
func decodeRsaPrivateKeyLegacy(private []byte, password string) (*rsa.PrivateKey, error) {
blk, _ := pem.Decode(private)
if blk == nil {
return nil, errors.New("private key error")
}
bytes, err := decodePEMBlockBytes(blk, password)
if err != nil {
return nil, err
@ -60,11 +132,11 @@ func DecodeRsaPrivateKey(private []byte, password string) (*rsa.PrivateKey, erro
if prikey, err := x509.ParsePKCS1PrivateKey(bytes); err == nil {
return prikey, nil
}
pkcs8, err := x509.ParsePKCS8PrivateKey(bytes)
pkcs8key, err := x509.ParsePKCS8PrivateKey(bytes)
if err != nil {
return nil, err
}
prikey, ok := pkcs8.(*rsa.PrivateKey)
prikey, ok := pkcs8key.(*rsa.PrivateKey)
if !ok {
return nil, errors.New("private key is not RSA")
}
@ -96,6 +168,10 @@ func EncodeRsaSSHPublicKey(public *rsa.PublicKey) ([]byte, error) {
}
func GenerateRsaSSHKeyPair(bits int, secret string) (string, string, error) {
return GenerateRsaSSHKeyPairWithLegacy(bits, secret, true)
}
func GenerateRsaSSHKeyPairWithLegacy(bits int, secret string, legacy bool) (string, string, error) {
pkey, pubkey, err := GenerateRsaKey(bits)
if err != nil {
return "", "", err
@ -104,7 +180,7 @@ func GenerateRsaSSHKeyPair(bits int, secret string) (string, string, error) {
if err != nil {
return "", "", err
}
priv, err := EncodeRsaPrivateKey(pkey, secret)
priv, err := EncodeRsaPrivateKeyWithLegacy(pkey, secret, legacy)
if err != nil {
return "", "", err
}
@ -119,6 +195,22 @@ func RSADecrypt(prikey *rsa.PrivateKey, data []byte) ([]byte, error) {
return rsa.DecryptPKCS1v15(rand.Reader, prikey, data)
}
func RSAEncryptOAEP(pub *rsa.PublicKey, data, label []byte, hashType crypto.Hash) ([]byte, error) {
hashType, err := normalizeModernRSAHash(hashType)
if err != nil {
return nil, err
}
return rsa.EncryptOAEP(hashType.New(), rand.Reader, pub, data, label)
}
func RSADecryptOAEP(prikey *rsa.PrivateKey, data, label []byte, hashType crypto.Hash) ([]byte, error) {
hashType, err := normalizeModernRSAHash(hashType)
if err != nil {
return nil, err
}
return rsa.DecryptOAEP(hashType.New(), rand.Reader, prikey, data, label)
}
func RSASign(msg, priKey []byte, password string, hashType crypto.Hash) ([]byte, error) {
prikey, err := DecodeRsaPrivateKey(priKey, password)
if err != nil {
@ -143,6 +235,38 @@ func RSAVerify(sig, msg, pubKey []byte, hashType crypto.Hash) error {
return rsa.VerifyPKCS1v15(pubkey, hashType, hashed, sig)
}
func RSASignPSS(msg, priKey []byte, password string, hashType crypto.Hash, opts *rsa.PSSOptions) ([]byte, error) {
prikey, err := DecodeRsaPrivateKey(priKey, password)
if err != nil {
return nil, err
}
hashType, err = normalizeModernRSAHash(hashType)
if err != nil {
return nil, err
}
hashed, err := hashMessage(msg, hashType)
if err != nil {
return nil, err
}
return rsa.SignPSS(rand.Reader, prikey, hashType, hashed, opts)
}
func RSAVerifyPSS(sig, msg, pubKey []byte, hashType crypto.Hash, opts *rsa.PSSOptions) error {
pubkey, err := DecodeRsaPublicKey(pubKey)
if err != nil {
return err
}
hashType, err = normalizeModernRSAHash(hashType)
if err != nil {
return err
}
hashed, err := hashMessage(msg, hashType)
if err != nil {
return err
}
return rsa.VerifyPSS(pubkey, hashType, hashed, sig, opts)
}
func RSAEncryptByPrivkey(priv *rsa.PrivateKey, data []byte) ([]byte, error) {
return rsa.SignPKCS1v15(nil, priv, crypto.Hash(0), data)
}
@ -154,6 +278,16 @@ func RSADecryptByPubkey(pub *rsa.PublicKey, data []byte) ([]byte, error) {
return unLeftPad(em)
}
func normalizeModernRSAHash(hashType crypto.Hash) (crypto.Hash, error) {
if hashType == 0 {
hashType = crypto.SHA256
}
if !hashType.Available() {
return 0, errors.New("hash function is not available")
}
return hashType, nil
}
func hashMessage(msg []byte, hashType crypto.Hash) ([]byte, error) {
if hashType == 0 {
return msg, nil

134
asymm/rsa_test.go Normal file
View File

@ -0,0 +1,134 @@
package asymm
import (
"bytes"
"crypto/rsa"
"testing"
)
func TestRsaPrivateKeyEncodeDecodeWithLegacyFlag(t *testing.T) {
priv, pub, err := GenerateRsaKey(1024)
if err != nil {
t.Fatalf("GenerateRsaKey failed: %v", err)
}
modernPEM, err := EncodeRsaPrivateKeyWithLegacy(priv, "pwd", false)
if err != nil {
t.Fatalf("EncodeRsaPrivateKeyWithLegacy modern failed: %v", err)
}
if !bytes.Contains(modernPEM, []byte("ENCRYPTED PRIVATE KEY")) {
t.Fatalf("modern PEM should be ENCRYPTED PRIVATE KEY")
}
modernPriv, err := DecodeRsaPrivateKeyWithLegacy(modernPEM, "pwd", false)
if err != nil {
t.Fatalf("DecodeRsaPrivateKeyWithLegacy modern failed: %v", err)
}
if modernPriv.D.Cmp(priv.D) != 0 {
t.Fatalf("modern decoded private key mismatch")
}
autoPriv, err := DecodeRsaPrivateKey(modernPEM, "pwd")
if err != nil {
t.Fatalf("DecodeRsaPrivateKey auto failed: %v", err)
}
if autoPriv.D.Cmp(priv.D) != 0 {
t.Fatalf("auto decoded private key mismatch")
}
if _, err := DecodePrivateKey(modernPEM, "pwd"); err != nil {
t.Fatalf("DecodePrivateKey for encrypted PKCS8 failed: %v", err)
}
legacyPEM, err := EncodeRsaPrivateKeyWithLegacy(priv, "pwd", true)
if err != nil {
t.Fatalf("EncodeRsaPrivateKeyWithLegacy legacy failed: %v", err)
}
if !bytes.Contains(legacyPEM, []byte("RSA PRIVATE KEY")) {
t.Fatalf("legacy PEM should be RSA PRIVATE KEY")
}
legacyPriv, err := DecodeRsaPrivateKeyWithLegacy(legacyPEM, "pwd", true)
if err != nil {
t.Fatalf("DecodeRsaPrivateKeyWithLegacy legacy failed: %v", err)
}
if legacyPriv.D.Cmp(priv.D) != 0 {
t.Fatalf("legacy decoded private key mismatch")
}
pubPEM, err := EncodeRsaPublicKey(pub)
if err != nil {
t.Fatalf("EncodeRsaPublicKey failed: %v", err)
}
decodedPub, err := DecodeRsaPublicKey(pubPEM)
if err != nil {
t.Fatalf("DecodeRsaPublicKey failed: %v", err)
}
if decodedPub.N.Cmp(pub.N) != 0 || decodedPub.E != pub.E {
t.Fatalf("decoded public key mismatch")
}
}
func TestRSAOAEPEncryptDecrypt(t *testing.T) {
priv, pub, err := GenerateRsaKey(1024)
if err != nil {
t.Fatalf("GenerateRsaKey failed: %v", err)
}
msg := []byte("rsa-oaep-message")
label := []byte("label")
enc, err := RSAEncryptOAEP(pub, msg, label, 0)
if err != nil {
t.Fatalf("RSAEncryptOAEP failed: %v", err)
}
dec, err := RSADecryptOAEP(priv, enc, label, 0)
if err != nil {
t.Fatalf("RSADecryptOAEP failed: %v", err)
}
if !bytes.Equal(dec, msg) {
t.Fatalf("oaep decrypt mismatch")
}
}
func TestRSAPSSSignVerify(t *testing.T) {
priv, pub, err := GenerateRsaKey(1024)
if err != nil {
t.Fatalf("GenerateRsaKey failed: %v", err)
}
privPEM, err := EncodeRsaPrivateKeyWithLegacy(priv, "pwd", false)
if err != nil {
t.Fatalf("EncodeRsaPrivateKeyWithLegacy failed: %v", err)
}
pubPEM, err := EncodeRsaPublicKey(pub)
if err != nil {
t.Fatalf("EncodeRsaPublicKey failed: %v", err)
}
msg := []byte("rsa-pss-message")
sig, err := RSASignPSS(msg, privPEM, "pwd", 0, &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash})
if err != nil {
t.Fatalf("RSASignPSS failed: %v", err)
}
if err := RSAVerifyPSS(sig, msg, pubPEM, 0, &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}); err != nil {
t.Fatalf("RSAVerifyPSS failed: %v", err)
}
}
func TestRSAPKCS1v15EncryptDecryptCompatibility(t *testing.T) {
priv, pub, err := GenerateRsaKey(1024)
if err != nil {
t.Fatalf("GenerateRsaKey failed: %v", err)
}
msg := []byte("rsa-pkcs1v15-message")
enc, err := RSAEncrypt(pub, msg)
if err != nil {
t.Fatalf("RSAEncrypt failed: %v", err)
}
dec, err := RSADecrypt(priv, enc)
if err != nil {
t.Fatalf("RSADecrypt failed: %v", err)
}
if !bytes.Equal(dec, msg) {
t.Fatalf("pkcs1v15 decrypt mismatch")
}
}

259
ccm/ccm.go Normal file
View File

@ -0,0 +1,259 @@
// SPDX-FileCopyrightText: 2026 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
// Package ccm implements a CCM, Counter with CBC-MAC
// as per RFC 3610.
//
// See https://tools.ietf.org/html/rfc3610
//
// This code is derived from https://github.com/pion/dtls (pkg/crypto/ccm).
// The original upstream license is MIT and is preserved in this repository.
//
// A request for including CCM into the Go standard library
// can be found as issue #27484 on the https://github.com/golang/go/
// repository.
package ccm
import (
"crypto/cipher"
"crypto/subtle"
"encoding/binary"
"errors"
"math"
)
// ccm represents a Counter with CBC-MAC with a specific key.
type ccm struct {
b cipher.Block
M uint8
L uint8
}
const ccmBlockSize = 16
// CCM is a block cipher in Counter with CBC-MAC mode.
// Providing authenticated encryption with associated data via the cipher.AEAD interface.
type CCM interface {
cipher.AEAD
// MaxLength returns the maxium length of plaintext in calls to Seal.
// The maximum length of ciphertext in calls to Open is MaxLength()+Overhead().
// The maximum length is related to CCM's `L` parameter (15-noncesize) and
// is 1<<(8*L) - 1 (but also limited by the maxium size of an int).
MaxLength() int
}
var (
errInvalidBlockSize = errors.New("ccm: NewCCM requires 128-bit block cipher")
errInvalidTagSize = errors.New("ccm: tagsize must be 4, 6, 8, 10, 12, 14, or 16")
errInvalidNonceSize = errors.New("ccm: invalid nonce size")
)
// NewCCM returns the given 128-bit block cipher wrapped in CCM.
// The tagsize must be an even integer between 4 and 16 inclusive
// and is used as CCM's `M` parameter.
// The noncesize must be an integer between 7 and 13 inclusive,
// 15-noncesize is used as CCM's `L` parameter.
func NewCCM(b cipher.Block, tagsize, noncesize int) (CCM, error) {
if b.BlockSize() != ccmBlockSize {
return nil, errInvalidBlockSize
}
if tagsize < 4 || tagsize > 16 || tagsize&1 != 0 {
return nil, errInvalidTagSize
}
lensize := 15 - noncesize
if lensize < 2 || lensize > 8 {
return nil, errInvalidNonceSize
}
c := &ccm{b: b, M: uint8(tagsize), L: uint8(lensize)} //nolint:gosec // G114
return c, nil
}
func (c *ccm) NonceSize() int { return 15 - int(c.L) }
func (c *ccm) Overhead() int { return int(c.M) }
func (c *ccm) MaxLength() int { return maxlen(c.L, c.Overhead()) }
func maxlen(l uint8, tagsize int) int {
mLen := (uint64(1) << (8 * l)) - 1
if m64 := uint64(math.MaxInt64) - uint64(tagsize); l > 8 || mLen > m64 { //nolint:gosec // G114
mLen = m64 // The maximum lentgh on a 64bit arch
}
if mLen != uint64(int(mLen)) { //nolint:gosec // G114
return math.MaxInt32 - tagsize // We have only 32bit int's
}
return int(mLen) //nolint:gosec // G114
}
// MaxNonceLength returns the maximum nonce length for a given plaintext length.
// A return value <= 0 indicates that plaintext length is too large for
// any nonce length.
func MaxNonceLength(pdatalen int) int {
const tagsize = 16
for L := 2; L <= 8; L++ {
if maxlen(uint8(L), tagsize) >= pdatalen { //nolint:gosec // G115
return 15 - L
}
}
return 0
}
func (c *ccm) cbcRound(mac, data []byte) {
for i := range ccmBlockSize {
mac[i] ^= data[i]
}
c.b.Encrypt(mac, mac)
}
func (c *ccm) cbcData(mac, data []byte) {
for len(data) >= ccmBlockSize {
c.cbcRound(mac, data[:ccmBlockSize])
data = data[ccmBlockSize:]
}
if len(data) > 0 {
var block [ccmBlockSize]byte
copy(block[:], data)
c.cbcRound(mac, block[:])
}
}
var errPlaintextTooLong = errors.New("ccm: plaintext too large")
func (c *ccm) tag(nonce, plaintext, adata []byte) ([]byte, error) {
var mac [ccmBlockSize]byte
if len(adata) > 0 {
mac[0] |= 1 << 6
}
mac[0] |= (c.M - 2) << 2
mac[0] |= c.L - 1
if len(nonce) != c.NonceSize() {
return nil, errInvalidNonceSize
}
if len(plaintext) > c.MaxLength() {
return nil, errPlaintextTooLong
}
binary.BigEndian.PutUint64(mac[ccmBlockSize-8:], uint64(len(plaintext)))
copy(mac[1:ccmBlockSize-c.L], nonce)
c.b.Encrypt(mac[:], mac[:])
var block [ccmBlockSize]byte
if adataLength := uint64(len(adata)); adataLength > 0 { //nolint:nestif
// First adata block includes adata length
i := 2
if adataLength <= 0xfeff {
binary.BigEndian.PutUint16(block[:i], uint16(adataLength))
} else {
binary.BigEndian.PutUint16(block[0:2], 0xfeff)
if adataLength < uint64(1<<32) {
i = 2 + 4
binary.BigEndian.PutUint32(block[2:i], uint32(adataLength)) //nolint:gosec // G115
} else {
i = 2 + 8
binary.BigEndian.PutUint64(block[2:i], adataLength)
}
}
i = copy(block[i:], adata)
c.cbcRound(mac[:], block[:])
c.cbcData(mac[:], adata[i:])
}
if len(plaintext) > 0 {
c.cbcData(mac[:], plaintext)
}
return mac[:c.M], nil
}
// sliceForAppend takes a slice and a requested number of bytes. It returns a
// slice with the contents of the given slice followed by that many bytes and a
// second slice that aliases into it and contains only the extra bytes. If the
// original slice has sufficient capacity then no allocation is performed.
// From crypto/cipher/gcm.go
// .
func sliceForAppend(in []byte, n int) (head, tail []byte) {
if total := len(in) + n; cap(in) >= total {
head = in[:total]
} else {
head = make([]byte, total)
copy(head, in)
}
tail = head[len(in):]
return
}
// Seal encrypts and authenticates plaintext, authenticates the
// additional data and appends the result to dst, returning the updated
// slice. The nonce must be NonceSize() bytes long and unique for all
// time, for a given key.
// The plaintext must be no longer than MaxLength() bytes long.
//
// The plaintext and dst may alias exactly or not at all.
func (c *ccm) Seal(dst, nonce, plaintext, adata []byte) []byte {
tag, err := c.tag(nonce, plaintext, adata)
if err != nil {
// The cipher.AEAD interface doesn't allow for an error return.
panic(err) // nolint
}
var iv, s0 [ccmBlockSize]byte
iv[0] = c.L - 1
copy(iv[1:ccmBlockSize-c.L], nonce)
c.b.Encrypt(s0[:], iv[:])
for i := 0; i < int(c.M); i++ {
tag[i] ^= s0[i]
}
iv[len(iv)-1] |= 1
stream := cipher.NewCTR(c.b, iv[:])
ret, out := sliceForAppend(dst, len(plaintext)+int(c.M))
stream.XORKeyStream(out, plaintext)
copy(out[len(plaintext):], tag)
return ret
}
var (
errOpen = errors.New("ccm: message authentication failed")
errCiphertextTooShort = errors.New("ccm: ciphertext too short")
errCiphertextTooLong = errors.New("ccm: ciphertext too long")
)
func (c *ccm) Open(dst, nonce, ciphertext, adata []byte) ([]byte, error) {
if len(ciphertext) < int(c.M) {
return nil, errCiphertextTooShort
}
if len(ciphertext) > c.MaxLength()+c.Overhead() {
return nil, errCiphertextTooLong
}
tag := make([]byte, int(c.M))
copy(tag, ciphertext[len(ciphertext)-int(c.M):])
ciphertextWithoutTag := ciphertext[:len(ciphertext)-int(c.M)]
var iv, s0 [ccmBlockSize]byte
iv[0] = c.L - 1
copy(iv[1:ccmBlockSize-c.L], nonce)
c.b.Encrypt(s0[:], iv[:])
for i := 0; i < int(c.M); i++ {
tag[i] ^= s0[i]
}
iv[len(iv)-1] |= 1
stream := cipher.NewCTR(c.b, iv[:])
// Cannot decrypt directly to dst since we're not supposed to
// reveal the plaintext to the caller if authentication fails.
plaintext := make([]byte, len(ciphertextWithoutTag))
stream.XORKeyStream(plaintext, ciphertextWithoutTag)
expectedTag, err := c.tag(nonce, plaintext, adata)
if err != nil {
return nil, err
}
if subtle.ConstantTimeCompare(tag, expectedTag) != 1 {
return nil, errOpen
}
return append(dst, plaintext...), nil
}

451
ccm/ccm_test.go Normal file
View File

@ -0,0 +1,451 @@
// SPDX-FileCopyrightText: 2026 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package ccm
// Refer to RFC 3610 section 8 for the vectors.
import (
"crypto/aes"
"encoding/hex"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func mustHexDecode(t *testing.T, s string) []byte {
t.Helper()
r, err := hex.DecodeString(s)
assert.NoError(t, err)
return r
}
func aesKey1to12(t *testing.T) []byte {
t.Helper()
return mustHexDecode(t, "c0c1c2c3c4c5c6c7c8c9cacbcccdcecf")
}
func aesKey13to24(t *testing.T) []byte {
t.Helper()
return mustHexDecode(t, "d7828d13b2b0bdc325a76236df93cc6b")
}
// AESKey: AES Key
// CipherText: Authenticated and encrypted output
// ClearHeaderOctets: Input with X cleartext header octets
// Data: Input with X cleartext header octets
// M: length(CBC-MAC)
// Nonce: Nonce.
type vector struct {
AESKey []byte
CipherText []byte
ClearHeaderOctets int
Data []byte
M int
Nonce []byte
}
func TestRFC3610Vectors(t *testing.T) { //nolint:maintidx
cases := []vector{
// Vectors 1-12
{
AESKey: aesKey1to12(t),
CipherText: mustHexDecode(t,
"0001020304050607588c979a61c663d2f066d0c2c0f989806d5f6b61dac38417e8d12cfdf926e0"),
ClearHeaderOctets: 8,
Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e"),
M: 8,
Nonce: mustHexDecode(t, "00000003020100a0a1a2a3a4a5"),
},
{
AESKey: aesKey1to12(t),
CipherText: mustHexDecode(t,
"000102030405060772c91a36e135f8cf291ca894085c87e3cc15c439c9e43a3ba091d56e10400916"),
ClearHeaderOctets: 8,
Data: mustHexDecode(t, "000102030405060708090A0B0C0D0E0F101112131415161718191A1B1C1D1E1F"),
M: 8,
Nonce: mustHexDecode(t, "00000004030201a0a1a2a3a4a5"),
},
{
AESKey: aesKey1to12(t),
CipherText: mustHexDecode(t,
"000102030405060751b1e5f44a197d1da46b0f8e2d282ae871e838bb64da8596574adaa76fbd9fb0c5",
),
ClearHeaderOctets: 8,
Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"),
M: 8,
Nonce: mustHexDecode(t, "00000005040302a0a1a2a3a4a5"),
},
{
AESKey: aesKey1to12(t),
CipherText: mustHexDecode(t,
"000102030405060708090a0ba28c6865939a9a79faaa5c4c2a9d4a91cdac8c96c861b9c9e61ef1"),
ClearHeaderOctets: 12,
Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e"),
M: 8,
Nonce: mustHexDecode(t, "00000006050403a0a1a2a3a4a5"),
},
{
AESKey: aesKey1to12(t),
CipherText: mustHexDecode(t,
"000102030405060708090a0bdcf1fb7b5d9e23fb9d4e131253658ad86ebdca3e51e83f077d9c2d93"),
ClearHeaderOctets: 12,
Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"),
M: 8,
Nonce: mustHexDecode(t, "00000007060504a0a1a2a3a4a5"),
},
{
AESKey: aesKey1to12(t),
CipherText: mustHexDecode(t,
"000102030405060708090a0b6fc1b011f006568b5171a42d953d469b2570a4bd87405a0443ac91cb94",
),
ClearHeaderOctets: 12,
Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"),
M: 8,
Nonce: mustHexDecode(t, "00000008070605a0a1a2a3a4a5"),
},
{
AESKey: aesKey1to12(t),
CipherText: mustHexDecode(t,
"00010203040506070135d1b2c95f41d5d1d4fec185d166b8094e999dfed96c048c56602c97acbb7490",
),
ClearHeaderOctets: 8,
Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e"),
M: 10,
Nonce: mustHexDecode(t, "00000009080706a0a1a2a3a4a5"),
},
{
AESKey: aesKey1to12(t),
CipherText: mustHexDecode(t,
"00010203040506077b75399ac0831dd2f0bbd75879a2fd8f6cae6b6cd9b7db24c17b4433f434963f34b4",
),
ClearHeaderOctets: 8,
Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"),
M: 10,
Nonce: mustHexDecode(t, "0000000a090807a0a1a2a3a4a5"),
},
{
AESKey: aesKey1to12(t),
CipherText: mustHexDecode(t,
"000102030405060782531a60cc24945a4b8279181ab5c84df21ce7f9b73f42e197ea9c07e56b5eb17e5f4e",
),
ClearHeaderOctets: 8,
Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"),
M: 10,
Nonce: mustHexDecode(t, "0000000b0a0908a0a1a2a3a4a5"),
},
{
AESKey: aesKey1to12(t),
CipherText: mustHexDecode(t,
"000102030405060708090a0b07342594157785152b074098330abb141b947b566aa9406b4d999988dd",
),
ClearHeaderOctets: 12,
Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e"),
M: 10,
Nonce: mustHexDecode(t, "0000000c0b0a09a0a1a2a3a4a5"),
},
{
AESKey: aesKey1to12(t),
CipherText: mustHexDecode(t,
"000102030405060708090a0b676bb20380b0e301e8ab79590a396da78b834934f53aa2e9107a8b6c022c",
),
ClearHeaderOctets: 12,
Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f"),
M: 10,
Nonce: mustHexDecode(t, "0000000d0c0b0aa0a1a2a3a4a5"),
},
{
AESKey: aesKey1to12(t),
CipherText: mustHexDecode(t,
"000102030405060708090a0bc0ffa0d6f05bdb67f24d43a4338d2aa4bed7b20e43cd1aa31662e7ad65d6db",
),
ClearHeaderOctets: 12,
Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f20"),
M: 10,
Nonce: mustHexDecode(t, "0000000e0d0c0ba0a1a2a3a4a5"),
},
// Vectors 13-24
{
AESKey: aesKey13to24(t),
CipherText: mustHexDecode(t,
"0be1a88bace018b14cb97f86a2a4689a877947ab8091ef5386a6ffbdd080f8e78cf7cb0cddd7b3"),
ClearHeaderOctets: 8,
Data: mustHexDecode(t, "0be1a88bace018b108e8cf97d820ea258460e96ad9cf5289054d895ceac47c"),
M: 8,
Nonce: mustHexDecode(t, "00412b4ea9cdbe3c9696766cfa"),
},
{
AESKey: aesKey13to24(t),
CipherText: mustHexDecode(t,
"63018f76dc8a1bcb4ccb1e7ca981befaa0726c55d378061298c85c92814abc33c52ee81d7d77c08a"),
ClearHeaderOctets: 8,
Data: mustHexDecode(t, "63018f76dc8a1bcb9020ea6f91bdd85afa0039ba4baff9bfb79c7028949cd0ec"),
M: 8,
Nonce: mustHexDecode(t, "0033568ef7b2633c9696766cfa"),
},
{
AESKey: aesKey13to24(t),
CipherText: mustHexDecode(t,
"aa6cfa36cae86b40b1d23a2220ddc0ac900d9aa03c61fcf4a559a4417767089708a776796edb723506",
),
ClearHeaderOctets: 8,
Data: mustHexDecode(t, "aa6cfa36cae86b40b916e0eacc1c00d7dcec68ec0b3bbb1a02de8a2d1aa346132e"),
M: 8,
Nonce: mustHexDecode(t, "00103fe41336713c9696766cfa"),
},
{
AESKey: aesKey13to24(t),
CipherText: mustHexDecode(t,
"d0d0735c531e1becf049c24414d253c3967b70609b7cbb7c499160283245269a6f49975bcadeaf"),
ClearHeaderOctets: 12,
Data: mustHexDecode(t, "d0d0735c531e1becf049c24412daac5630efa5396f770ce1a66b21f7b2101c"),
M: 8,
Nonce: mustHexDecode(t, "00764c63b8058e3c9696766cfa"),
},
{
AESKey: aesKey13to24(t),
CipherText: mustHexDecode(t,
"77b60f011c03e1525899bcae5545ff1a085ee2efbf52b2e04bee1e2336c73e3f762c0c7744fe7e3c"),
ClearHeaderOctets: 12,
Data: mustHexDecode(t, "77b60f011c03e1525899bcaee88b6a46c78d63e52eb8c546efb5de6f75e9cc0d"),
M: 8,
Nonce: mustHexDecode(t, "00f8b678094e3b3c9696766cfa"),
},
{
AESKey: aesKey13to24(t),
CipherText: mustHexDecode(t,
"cd9044d2b71fdb8120ea60c0009769ecabdf48625594c59251e6035722675e04c847099e5ae0704551",
),
ClearHeaderOctets: 12,
Data: mustHexDecode(t, "cd9044d2b71fdb8120ea60c06435acbafb11a82e2f071d7ca4a5ebd93a803ba87f"),
M: 8,
Nonce: mustHexDecode(t, "00d560912d3f703c9696766cfa"),
},
{
AESKey: aesKey13to24(t),
CipherText: mustHexDecode(t,
"d85bc7e69f944fb8bc218daa947427b6db386a99ac1aef23ade0b52939cb6a637cf9bec2408897c6ba",
),
ClearHeaderOctets: 8,
Data: mustHexDecode(t, "d85bc7e69f944fb88a19b950bcf71a018e5e6701c91787659809d67dbedd18"),
M: 10,
Nonce: mustHexDecode(t, "0042fff8f1951c3c9696766cfa"),
},
{
AESKey: aesKey13to24(t),
CipherText: mustHexDecode(t,
"74a0ebc9069f5b375810e6fd25874022e80361a478e3e9cf484ab04f447efff6f0a477cc2fc9bf548944",
),
ClearHeaderOctets: 8,
Data: mustHexDecode(t, "74a0ebc9069f5b371761433c37c5a35fc1f39f406302eb907c6163be38c98437"),
M: 10,
Nonce: mustHexDecode(t, "00920f40e56cdc3c9696766cfa"),
},
{
AESKey: aesKey13to24(t),
CipherText: mustHexDecode(t,
"44a3aa3aae6475caf2beed7bc5098e83feb5b31608f8e29c38819a89c8e776f1544d4151a4ed3a8b87b9ce",
),
ClearHeaderOctets: 8,
Data: mustHexDecode(t, "44a3aa3aae6475caa434a8e58500c6e41530538862d686ea9e81301b5ae4226bfa"),
M: 10,
Nonce: mustHexDecode(t, "0027ca0c7120bc3c9696766cfa"),
},
{
AESKey: aesKey13to24(t),
CipherText: mustHexDecode(t,
"ec46bb63b02520c33c49fd7031d750a09da3ed7fddd49a2032aabf17ec8ebf7d22c8088c666be5c197",
),
ClearHeaderOctets: 12,
Data: mustHexDecode(t, "ec46bb63b02520c33c49fd70b96b49e21d621741632875db7f6c9243d2d7c2"),
M: 10,
Nonce: mustHexDecode(t, "005b8ccbcd9af83c9696766cfa"),
},
{
AESKey: aesKey13to24(t),
CipherText: mustHexDecode(t,
"47a65ac78b3d594227e85e71e882f1dbd38ce3eda7c23f04dd65071eb41342acdf7e00dccec7ae52987d",
),
ClearHeaderOctets: 12,
Data: mustHexDecode(t, "47a65ac78b3d594227e85e71e2fcfbb880442c731bf95167c8ffd7895e337076"),
M: 10,
Nonce: mustHexDecode(t, "003ebe94044b9a3c9696766cfa"),
},
{
AESKey: aesKey13to24(t),
CipherText: mustHexDecode(t,
"6e37a6ef546d955d34ab6059f32905b88a641b04b9c9ffb58cc390900f3da12ab16dce9e82efa16da62059",
),
ClearHeaderOctets: 12,
Data: mustHexDecode(t, "6e37a6ef546d955d34ab6059abf21c0b02feb88f856df4a37381bce3cc128517d4"),
M: 10,
Nonce: mustHexDecode(t, "008d493b30ae8b3c9696766cfa"),
},
}
assert.Equal(t, 24, len(cases))
for idx, testCase := range cases {
t.Run(fmt.Sprintf("packet vector #%d", idx+1), func(t *testing.T) {
blk, err := aes.NewCipher(testCase.AESKey)
assert.NoError(t, err, "could not initialize AES block cipher from key")
lccm, err := NewCCM(blk, testCase.M, len(testCase.Nonce))
assert.NoError(t, err, "could not create CCM")
t.Run("seal", func(t *testing.T) {
var dst []byte
dst = lccm.Seal(
dst,
testCase.Nonce,
testCase.Data[testCase.ClearHeaderOctets:],
testCase.Data[:testCase.ClearHeaderOctets],
)
assert.Equal(t, testCase.CipherText[testCase.ClearHeaderOctets:], dst)
})
t.Run("open", func(t *testing.T) {
var dst []byte
dst, err = lccm.Open(
dst,
testCase.Nonce,
testCase.CipherText[testCase.ClearHeaderOctets:],
testCase.CipherText[:testCase.ClearHeaderOctets],
)
assert.NoError(t, err)
assert.Equal(t, testCase.Data[testCase.ClearHeaderOctets:], dst)
})
})
}
}
func TestNewCCMError(t *testing.T) {
cases := map[string]struct {
vector
err error
}{
"ShortNonceLength": {
vector{
AESKey: aesKey1to12(t),
M: 8,
Nonce: mustHexDecode(t, "a0a1a2a3a4a5"),
}, errInvalidNonceSize,
},
"LongNonceLength": {
vector{
AESKey: aesKey1to12(t),
M: 8,
Nonce: mustHexDecode(t, "0001020304050607080910111213"),
}, errInvalidNonceSize,
},
"ShortTag": {
vector{
AESKey: aesKey1to12(t),
M: 3,
Nonce: mustHexDecode(t, "00010203040506070809101112"),
}, errInvalidTagSize,
},
"LongTag": {
vector{
AESKey: aesKey1to12(t),
M: 17,
Nonce: mustHexDecode(t, "00010203040506070809101112"),
}, errInvalidTagSize,
},
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
blk, err := aes.NewCipher(c.AESKey)
assert.NoError(t, err, "could not initialize AES block cipher from key")
_, err = NewCCM(blk, c.M, len(c.Nonce))
assert.ErrorIs(t, err, c.err)
})
}
}
func TestSealError(t *testing.T) {
cases := map[string]struct {
vector
err error
}{
"InvalidNonceLength": {
vector{
Data: mustHexDecode(t, "000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e"),
M: 8,
Nonce: mustHexDecode(t, "00000003020100a0a1a2a3a4"), // short
}, errInvalidNonceSize,
},
"PlaintextTooLong": {
vector{
Data: make([]byte, 100000),
M: 8,
Nonce: mustHexDecode(t, "00000003020100a0a1a2a3a4a5"),
}, errPlaintextTooLong,
},
}
blk, err := aes.NewCipher(aesKey1to12(t))
assert.NoError(t, err)
lccm, err := NewCCM(blk, 8, 13)
assert.NoError(t, err)
for name, testCase := range cases {
t.Run(name, func(t *testing.T) {
defer func() {
err, ok := recover().(error)
assert.True(t, ok)
assert.ErrorIs(t, err, testCase.err)
}()
var dst []byte
_ = lccm.Seal(
dst,
testCase.Nonce,
testCase.Data[testCase.ClearHeaderOctets:],
testCase.Data[:testCase.ClearHeaderOctets],
)
})
}
}
func TestOpenError(t *testing.T) {
cases := map[string]struct {
vector
err error
}{
"CiphertextTooShort": {
vector{
CipherText: make([]byte, 10),
ClearHeaderOctets: 8,
Nonce: mustHexDecode(t, "00000003020100a0a1a2a3a4a5"),
}, errCiphertextTooShort,
},
"CiphertextTooLong": {
vector{
CipherText: make([]byte, 100000),
ClearHeaderOctets: 8,
Nonce: mustHexDecode(t, "00000003020100a0a1a2a3a4a5"),
}, errCiphertextTooLong,
},
}
blk, err := aes.NewCipher(aesKey1to12(t))
assert.NoError(t, err, "could not initialize AES block cipher from key")
lccm, err := NewCCM(blk, 8, 13)
assert.NoError(t, err, "could not create CCM")
for name, c := range cases {
t.Run(name, func(t *testing.T) {
var dst []byte
_, err = lccm.Open(dst, c.Nonce, c.CipherText[c.ClearHeaderOctets:], c.CipherText[:c.ClearHeaderOctets])
assert.ErrorIs(t, err, c.err)
})
}
}

View File

@ -6,6 +6,7 @@ import (
"errors"
"io"
"os"
"strings"
)
var (
@ -192,12 +193,12 @@ func Base85Encode(bstr []byte) string {
}
func Base85Decode(str string) ([]byte, error) {
out := make([]byte, len(str))
n, _, err := ascii85.Decode(out, []byte(str), true)
dec := ascii85.NewDecoder(strings.NewReader(str))
out, err := io.ReadAll(dec)
if err != nil {
return nil, err
}
return out[:n], nil
return out, nil
}
func Base85EncodeFile(src, dst string, progress func(float64)) error {

View File

@ -96,3 +96,36 @@ func TestBase64AndBase85FileRoundTrip(t *testing.T) {
t.Fatalf("base85 file roundtrip mismatch")
}
}
func TestBase85RoundTripEdgeLengths(t *testing.T) {
for n := 0; n <= 128; n++ {
plain := make([]byte, n)
for i := range plain {
plain[i] = byte((i*37 + n) % 256)
}
e := Base85Encode(plain)
d, err := Base85Decode(e)
if err != nil {
t.Fatalf("Base85Decode failed at len=%d: %v", n, err)
}
if !bytes.Equal(d, plain) {
t.Fatalf("base85 mismatch at len=%d", n)
}
}
}
func TestBase91RoundTripEdgeLengths(t *testing.T) {
for n := 0; n <= 256; n++ {
plain := make([]byte, n)
for i := range plain {
plain[i] = byte((i*53 + n) % 256)
}
e := Base91Encode(plain)
d := Base91Decode(e)
if !bytes.Equal(d, plain) {
t.Fatalf("base91 mismatch at len=%d", n)
}
}
}

View File

@ -18,6 +18,13 @@ func MergeFile(src, dst string, shell func(float64)) error {
return filex.MergeFile(src, dst, shell)
}
// FillWithRandom uses math/rand pseudo-random bytes and is not cryptographically secure.
func FillWithRandom(filepath string, filesize int, bufcap int, bufnum int, shell func(float64)) error {
return filex.FillWithRandom(filepath, filesize, bufcap, bufnum, shell)
}
// FillWithCryptoRandom uses crypto/rand secure random bytes.
// Throughput may be lower than FillWithRandom.
func FillWithCryptoRandom(filepath string, filesize int, bufcap int, shell func(float64)) error {
return filex.FillWithCryptoRandom(filepath, filesize, bufcap, shell)
}

View File

@ -2,10 +2,11 @@ package filex
import (
"bufio"
crand "crypto/rand"
"errors"
"fmt"
"io"
"math/rand"
mrand "math/rand"
"os"
"path/filepath"
"regexp"
@ -15,6 +16,8 @@ import (
"time"
)
var ErrInvalidSplitPattern = errors.New("split dst pattern must contain exactly one '*'")
func Attach(src, dst, output string) error {
fpsrc, err := os.Open(src)
if err != nil {
@ -81,6 +84,9 @@ func SplitFile(src, dst string, num int, bynum bool, progress func(float64)) err
if num <= 0 {
return errors.New("num must be greater than zero")
}
if strings.Count(dst, "*") != 1 {
return ErrInvalidSplitPattern
}
fpsrc, err := os.Open(src)
if err != nil {
@ -244,6 +250,8 @@ func MergeFile(src, dst string, progress func(float64)) error {
return nil
}
// FillWithRandom fills file with pseudo-random bytes generated by math/rand.
// It is fast but not cryptographically secure.
func FillWithRandom(path string, filesize, bufcap, bufnum int, progress func(float64)) error {
if filesize < 0 {
return errors.New("filesize must be non-negative")
@ -258,7 +266,7 @@ func FillWithRandom(path string, filesize, bufcap, bufnum int, progress func(flo
bufcap = filesize
}
rand.Seed(time.Now().UnixNano())
r := mrand.New(mrand.NewSource(time.Now().UnixNano()))
fp, err := os.Create(path)
if err != nil {
@ -267,18 +275,17 @@ func FillWithRandom(path string, filesize, bufcap, bufnum int, progress func(flo
defer fp.Close()
writer := bufio.NewWriter(fp)
defer writer.Flush()
if filesize == 0 {
reportProgress(progress, 0, 0)
return nil
return writer.Flush()
}
pool := make([][]byte, 0, bufnum)
for i := 0; i < bufnum; i++ {
b := make([]byte, bufcap)
for j := 0; j < bufcap; j++ {
b[j] = byte(rand.Intn(256))
b[j] = byte(r.Intn(256))
}
pool = append(pool, b)
}
@ -289,14 +296,59 @@ func FillWithRandom(path string, filesize, bufcap, bufnum int, progress func(flo
if filesize-written < chunk {
chunk = filesize - written
}
buf := pool[rand.Intn(len(pool))][:chunk]
buf := pool[r.Intn(len(pool))][:chunk]
if _, err := writer.Write(buf); err != nil {
return err
}
written += chunk
reportProgress(progress, int64(written), int64(filesize))
}
return nil
return writer.Flush()
}
// FillWithCryptoRandom fills file with cryptographically secure random bytes from crypto/rand.
// Security is stronger than FillWithRandom, but throughput may be lower.
func FillWithCryptoRandom(path string, filesize, bufcap int, progress func(float64)) error {
if filesize < 0 {
return errors.New("filesize must be non-negative")
}
if bufcap <= 0 {
bufcap = 1
}
if bufcap > filesize && filesize > 0 {
bufcap = filesize
}
fp, err := os.Create(path)
if err != nil {
return err
}
defer fp.Close()
writer := bufio.NewWriter(fp)
if filesize == 0 {
reportProgress(progress, 0, 0)
return writer.Flush()
}
buf := make([]byte, bufcap)
written := 0
for written < filesize {
chunk := bufcap
if filesize-written < chunk {
chunk = filesize - written
}
if _, err := io.ReadFull(crand.Reader, buf[:chunk]); err != nil {
return err
}
if _, err := writer.Write(buf[:chunk]); err != nil {
return err
}
written += chunk
reportProgress(progress, int64(written), int64(filesize))
}
return writer.Flush()
}
func reportProgress(progress func(float64), current, total int64) {

65
filex/file_random_test.go Normal file
View File

@ -0,0 +1,65 @@
package filex
import (
"bytes"
"os"
"path/filepath"
"testing"
)
func TestFillWithRandomAndCryptoRandom(t *testing.T) {
dir := t.TempDir()
pseudoPath := filepath.Join(dir, "pseudo.bin")
securePath := filepath.Join(dir, "secure.bin")
if err := FillWithRandom(pseudoPath, 2048, 128, 4, nil); err != nil {
t.Fatalf("FillWithRandom failed: %v", err)
}
if err := FillWithCryptoRandom(securePath, 2048, 128, nil); err != nil {
t.Fatalf("FillWithCryptoRandom failed: %v", err)
}
pseudoInfo, err := os.Stat(pseudoPath)
if err != nil {
t.Fatalf("stat pseudo file failed: %v", err)
}
if pseudoInfo.Size() != 2048 {
t.Fatalf("unexpected pseudo size: %d", pseudoInfo.Size())
}
secureInfo, err := os.Stat(securePath)
if err != nil {
t.Fatalf("stat secure file failed: %v", err)
}
if secureInfo.Size() != 2048 {
t.Fatalf("unexpected secure size: %d", secureInfo.Size())
}
pseudo, err := os.ReadFile(pseudoPath)
if err != nil {
t.Fatalf("read pseudo file failed: %v", err)
}
secure, err := os.ReadFile(securePath)
if err != nil {
t.Fatalf("read secure file failed: %v", err)
}
if bytes.Equal(secure, make([]byte, len(secure))) {
t.Fatalf("secure random output should not be all zero")
}
if bytes.Equal(pseudo, secure) {
t.Fatalf("pseudo and secure random outputs unexpectedly identical")
}
}
func TestFillWithRandomInvalidArgs(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "bad.bin")
if err := FillWithRandom(path, -1, 16, 1, nil); err == nil {
t.Fatalf("expected FillWithRandom negative filesize error")
}
if err := FillWithCryptoRandom(path, -1, 16, nil); err == nil {
t.Fatalf("expected FillWithCryptoRandom negative filesize error")
}
}

36
filex/file_split_test.go Normal file
View File

@ -0,0 +1,36 @@
package filex
import (
"bytes"
"errors"
"os"
"path/filepath"
"testing"
)
func TestSplitFilePatternValidation(t *testing.T) {
dir := t.TempDir()
src := filepath.Join(dir, "src.bin")
if err := os.WriteFile(src, bytes.Repeat([]byte{0x7f}, 64), 0o600); err != nil {
t.Fatalf("WriteFile failed: %v", err)
}
if err := SplitFile(src, filepath.Join(dir, "part.bin"), 2, true, nil); !errors.Is(err, ErrInvalidSplitPattern) {
t.Fatalf("expected ErrInvalidSplitPattern for missing '*', got: %v", err)
}
if err := SplitFile(src, filepath.Join(dir, "part_*_*.bin"), 2, true, nil); !errors.Is(err, ErrInvalidSplitPattern) {
t.Fatalf("expected ErrInvalidSplitPattern for multiple '*', got: %v", err)
}
pattern := filepath.Join(dir, "part_*.bin")
if err := SplitFile(src, pattern, 2, true, nil); err != nil {
t.Fatalf("SplitFile valid pattern failed: %v", err)
}
if _, err := os.Stat(filepath.Join(dir, "part_0.bin")); err != nil {
t.Fatalf("part_0.bin not found: %v", err)
}
if _, err := os.Stat(filepath.Join(dir, "part_1.bin")); err != nil {
t.Fatalf("part_1.bin not found: %v", err)
}
}

8
go.mod
View File

@ -4,7 +4,13 @@ go 1.24.0
require (
github.com/emmansun/gmsm v0.41.1
github.com/stretchr/testify v1.11.1
golang.org/x/crypto v0.48.0
)
require golang.org/x/sys v0.41.0 // indirect
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/sys v0.41.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

10
go.sum
View File

@ -1,8 +1,18 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/emmansun/gmsm v0.41.1 h1:mD1MqmaXTEqt+9UVmDpRYvcEMIa5vuslFEnw7IWp6/w=
github.com/emmansun/gmsm v0.41.1/go.mod h1:FD1EQk4XcSMkahZFzNwFoI/uXzAlODB9JVsJ9G5N7Do=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

View File

@ -1,6 +1,7 @@
package hashx
import (
"bytes"
"crypto/md5"
"encoding/hex"
"os"
@ -87,3 +88,52 @@ func TestFileSumUnsupportedMethod(t *testing.T) {
t.Fatalf("expected unsupported method error")
}
}
func TestPBKDF2SHA256Vector(t *testing.T) {
got, err := DerivePBKDF2SHA256Key("password", []byte("salt"), 1, 32)
if err != nil {
t.Fatalf("DerivePBKDF2SHA256Key failed: %v", err)
}
const want = "120fb6cffcf8b32c43e7225256c4f837a86548c92ccc35480805987cb70be17b"
if hex.EncodeToString(got) != want {
t.Fatalf("pbkdf2-sha256 vector mismatch: got %x want %s", got, want)
}
}
func TestPBKDF2AndArgon2Deterministic(t *testing.T) {
a, err := DerivePBKDF2SHA512Key("password", []byte("salt"), 1000, 32)
if err != nil {
t.Fatalf("DerivePBKDF2SHA512Key failed: %v", err)
}
b, err := DerivePBKDF2SHA512Key("password", []byte("salt"), 1000, 32)
if err != nil {
t.Fatalf("DerivePBKDF2SHA512Key failed: %v", err)
}
if !bytes.Equal(a, b) {
t.Fatalf("pbkdf2-sha512 must be deterministic")
}
params := Argon2Params{Time: 1, Memory: 32 * 1024, Threads: 1, KeyLen: 32}
argA, err := DeriveArgon2idKey("password", []byte("salt-salt"), params)
if err != nil {
t.Fatalf("DeriveArgon2idKey failed: %v", err)
}
argB, err := DeriveArgon2idKey("password", []byte("salt-salt"), params)
if err != nil {
t.Fatalf("DeriveArgon2idKey failed: %v", err)
}
if !bytes.Equal(argA, argB) {
t.Fatalf("argon2id must be deterministic")
}
}
func TestKDFInvalidParams(t *testing.T) {
if _, err := DerivePBKDF2SHA256Key("password", nil, 1, 32); err == nil {
t.Fatalf("expected pbkdf2 salt error")
}
if _, err := DerivePBKDF2SHA256Key("password", []byte("salt"), 0, 32); err == nil {
t.Fatalf("expected pbkdf2 iterations error")
}
if _, err := DeriveArgon2idKey("password", []byte("salt"), Argon2Params{}); err == nil {
t.Fatalf("expected argon2 params error")
}
}

90
hashx/kdf.go Normal file
View File

@ -0,0 +1,90 @@
package hashx
import (
"crypto/pbkdf2"
"crypto/sha256"
"crypto/sha512"
"errors"
"golang.org/x/crypto/argon2"
)
var (
ErrInvalidKDFSalt = errors.New("kdf salt must be non-empty")
ErrInvalidKDFIterations = errors.New("kdf iterations must be > 0")
ErrInvalidKDFKeyLength = errors.New("kdf key length must be > 0")
ErrInvalidArgon2Params = errors.New("argon2 params must have time, memory, threads, and key length > 0")
)
// Argon2Params configures Argon2 key derivation.
type Argon2Params struct {
Time uint32
Memory uint32
Threads uint8
KeyLen uint32
}
// DefaultArgon2idParams returns a conservative default suitable for general online usage.
func DefaultArgon2idParams() Argon2Params {
return Argon2Params{
Time: 1,
Memory: 64 * 1024, // 64 MiB in KiB
Threads: 4,
KeyLen: 32,
}
}
func validatePBKDF2Params(salt []byte, iterations, keyLen int) error {
if len(salt) == 0 {
return ErrInvalidKDFSalt
}
if iterations <= 0 {
return ErrInvalidKDFIterations
}
if keyLen <= 0 {
return ErrInvalidKDFKeyLength
}
return nil
}
func validateArgon2Params(salt []byte, params Argon2Params) error {
if len(salt) == 0 {
return ErrInvalidKDFSalt
}
if params.Time == 0 || params.Memory == 0 || params.Threads == 0 || params.KeyLen == 0 {
return ErrInvalidArgon2Params
}
return nil
}
// DerivePBKDF2SHA256Key derives a key with PBKDF2-HMAC-SHA256.
func DerivePBKDF2SHA256Key(password string, salt []byte, iterations, keyLen int) ([]byte, error) {
if err := validatePBKDF2Params(salt, iterations, keyLen); err != nil {
return nil, err
}
return pbkdf2.Key(sha256.New, password, salt, iterations, keyLen)
}
// DerivePBKDF2SHA512Key derives a key with PBKDF2-HMAC-SHA512.
func DerivePBKDF2SHA512Key(password string, salt []byte, iterations, keyLen int) ([]byte, error) {
if err := validatePBKDF2Params(salt, iterations, keyLen); err != nil {
return nil, err
}
return pbkdf2.Key(sha512.New, password, salt, iterations, keyLen)
}
// DeriveArgon2idKey derives a key with Argon2id.
func DeriveArgon2idKey(password string, salt []byte, params Argon2Params) ([]byte, error) {
if err := validateArgon2Params(salt, params); err != nil {
return nil, err
}
return argon2.IDKey([]byte(password), salt, params.Time, params.Memory, params.Threads, params.KeyLen), nil
}
// DeriveArgon2iKey derives a key with Argon2i.
func DeriveArgon2iKey(password string, salt []byte, params Argon2Params) ([]byte, error) {
if err := validateArgon2Params(salt, params); err != nil {
return nil, err
}
return argon2.Key([]byte(password), salt, params.Time, params.Memory, params.Threads, params.KeyLen), nil
}

64
hmac.go
View File

@ -10,6 +10,14 @@ func HmacMd4Str(message, key []byte) string {
return macx.HmacMd4Str(message, key)
}
func VerifyHmacMd4(message, key, sum []byte) bool {
return macx.VerifyHmacMd4(message, key, sum)
}
func VerifyHmacMd4Str(message, key []byte, hexSum string) bool {
return macx.VerifyHmacMd4Str(message, key, hexSum)
}
func HmacMd5(message, key []byte) []byte {
return macx.HmacMd5(message, key)
}
@ -18,6 +26,14 @@ func HmacMd5Str(message, key []byte) string {
return macx.HmacMd5Str(message, key)
}
func VerifyHmacMd5(message, key, sum []byte) bool {
return macx.VerifyHmacMd5(message, key, sum)
}
func VerifyHmacMd5Str(message, key []byte, hexSum string) bool {
return macx.VerifyHmacMd5Str(message, key, hexSum)
}
func HmacSHA1(message, key []byte) []byte {
return macx.HmacSHA1(message, key)
}
@ -26,6 +42,14 @@ func HmacSHA1Str(message, key []byte) string {
return macx.HmacSHA1Str(message, key)
}
func VerifyHmacSHA1(message, key, sum []byte) bool {
return macx.VerifyHmacSHA1(message, key, sum)
}
func VerifyHmacSHA1Str(message, key []byte, hexSum string) bool {
return macx.VerifyHmacSHA1Str(message, key, hexSum)
}
func HmacSHA256(message, key []byte) []byte {
return macx.HmacSHA256(message, key)
}
@ -34,6 +58,14 @@ func HmacSHA256Str(message, key []byte) string {
return macx.HmacSHA256Str(message, key)
}
func VerifyHmacSHA256(message, key, sum []byte) bool {
return macx.VerifyHmacSHA256(message, key, sum)
}
func VerifyHmacSHA256Str(message, key []byte, hexSum string) bool {
return macx.VerifyHmacSHA256Str(message, key, hexSum)
}
func HmacSHA384(message, key []byte) []byte {
return macx.HmacSHA384(message, key)
}
@ -42,6 +74,14 @@ func HmacSHA384Str(message, key []byte) string {
return macx.HmacSHA384Str(message, key)
}
func VerifyHmacSHA384(message, key, sum []byte) bool {
return macx.VerifyHmacSHA384(message, key, sum)
}
func VerifyHmacSHA384Str(message, key []byte, hexSum string) bool {
return macx.VerifyHmacSHA384Str(message, key, hexSum)
}
func HmacSHA512(message, key []byte) []byte {
return macx.HmacSHA512(message, key)
}
@ -50,6 +90,14 @@ func HmacSHA512Str(message, key []byte) string {
return macx.HmacSHA512Str(message, key)
}
func VerifyHmacSHA512(message, key, sum []byte) bool {
return macx.VerifyHmacSHA512(message, key, sum)
}
func VerifyHmacSHA512Str(message, key []byte, hexSum string) bool {
return macx.VerifyHmacSHA512Str(message, key, hexSum)
}
func HmacSHA224(message, key []byte) []byte {
return macx.HmacSHA224(message, key)
}
@ -58,6 +106,14 @@ func HmacSHA224Str(message, key []byte) string {
return macx.HmacSHA224Str(message, key)
}
func VerifyHmacSHA224(message, key, sum []byte) bool {
return macx.VerifyHmacSHA224(message, key, sum)
}
func VerifyHmacSHA224Str(message, key []byte, hexSum string) bool {
return macx.VerifyHmacSHA224Str(message, key, hexSum)
}
func HmacRipeMd160(message, key []byte) []byte {
return macx.HmacRipeMd160(message, key)
}
@ -65,3 +121,11 @@ func HmacRipeMd160(message, key []byte) []byte {
func HmacRipeMd160Str(message, key []byte) string {
return macx.HmacRipeMd160Str(message, key)
}
func VerifyHmacRipeMd160(message, key, sum []byte) bool {
return macx.VerifyHmacRipeMd160(message, key, sum)
}
func VerifyHmacRipeMd160Str(message, key []byte, hexSum string) bool {
return macx.VerifyHmacRipeMd160Str(message, key, hexSum)
}

25
kdf.go Normal file
View File

@ -0,0 +1,25 @@
package starcrypto
import "b612.me/starcrypto/hashx"
type Argon2Params = hashx.Argon2Params
func DefaultArgon2idParams() Argon2Params {
return hashx.DefaultArgon2idParams()
}
func DerivePBKDF2SHA256Key(password string, salt []byte, iterations, keyLen int) ([]byte, error) {
return hashx.DerivePBKDF2SHA256Key(password, salt, iterations, keyLen)
}
func DerivePBKDF2SHA512Key(password string, salt []byte, iterations, keyLen int) ([]byte, error) {
return hashx.DerivePBKDF2SHA512Key(password, salt, iterations, keyLen)
}
func DeriveArgon2idKey(password string, salt []byte, params Argon2Params) ([]byte, error) {
return hashx.DeriveArgon2idKey(password, salt, params)
}
func DeriveArgon2iKey(password string, salt []byte, params Argon2Params) ([]byte, error) {
return hashx.DeriveArgon2iKey(password, salt, params)
}

View File

@ -8,6 +8,7 @@ import (
"crypto/sha512"
"encoding/hex"
"hash"
"strings"
"golang.org/x/crypto/md4"
"golang.org/x/crypto/ripemd160"
@ -23,6 +24,19 @@ func chmacStr(message, key []byte, f func() hash.Hash) string {
return hex.EncodeToString(chmac(message, key, f))
}
func verifyHMAC(message, key, sum []byte, f func() hash.Hash) bool {
expected := chmac(message, key, f)
return hmac.Equal(expected, sum)
}
func verifyHMACStr(message, key []byte, hexSum string, f func() hash.Hash) bool {
sum, err := hex.DecodeString(strings.TrimSpace(hexSum))
if err != nil {
return false
}
return verifyHMAC(message, key, sum, f)
}
func HmacMd4(message, key []byte) []byte {
return chmac(message, key, md4.New)
}
@ -31,6 +45,14 @@ func HmacMd4Str(message, key []byte) string {
return chmacStr(message, key, md4.New)
}
func VerifyHmacMd4(message, key, sum []byte) bool {
return verifyHMAC(message, key, sum, md4.New)
}
func VerifyHmacMd4Str(message, key []byte, hexSum string) bool {
return verifyHMACStr(message, key, hexSum, md4.New)
}
func HmacMd5(message, key []byte) []byte {
return chmac(message, key, md5.New)
}
@ -39,6 +61,14 @@ func HmacMd5Str(message, key []byte) string {
return chmacStr(message, key, md5.New)
}
func VerifyHmacMd5(message, key, sum []byte) bool {
return verifyHMAC(message, key, sum, md5.New)
}
func VerifyHmacMd5Str(message, key []byte, hexSum string) bool {
return verifyHMACStr(message, key, hexSum, md5.New)
}
func HmacSHA1(message, key []byte) []byte {
return chmac(message, key, sha1.New)
}
@ -47,6 +77,14 @@ func HmacSHA1Str(message, key []byte) string {
return chmacStr(message, key, sha1.New)
}
func VerifyHmacSHA1(message, key, sum []byte) bool {
return verifyHMAC(message, key, sum, sha1.New)
}
func VerifyHmacSHA1Str(message, key []byte, hexSum string) bool {
return verifyHMACStr(message, key, hexSum, sha1.New)
}
func HmacSHA256(message, key []byte) []byte {
return chmac(message, key, sha256.New)
}
@ -55,6 +93,14 @@ func HmacSHA256Str(message, key []byte) string {
return chmacStr(message, key, sha256.New)
}
func VerifyHmacSHA256(message, key, sum []byte) bool {
return verifyHMAC(message, key, sum, sha256.New)
}
func VerifyHmacSHA256Str(message, key []byte, hexSum string) bool {
return verifyHMACStr(message, key, hexSum, sha256.New)
}
func HmacSHA384(message, key []byte) []byte {
return chmac(message, key, sha512.New384)
}
@ -63,6 +109,14 @@ func HmacSHA384Str(message, key []byte) string {
return chmacStr(message, key, sha512.New384)
}
func VerifyHmacSHA384(message, key, sum []byte) bool {
return verifyHMAC(message, key, sum, sha512.New384)
}
func VerifyHmacSHA384Str(message, key []byte, hexSum string) bool {
return verifyHMACStr(message, key, hexSum, sha512.New384)
}
func HmacSHA512(message, key []byte) []byte {
return chmac(message, key, sha512.New)
}
@ -71,6 +125,14 @@ func HmacSHA512Str(message, key []byte) string {
return chmacStr(message, key, sha512.New)
}
func VerifyHmacSHA512(message, key, sum []byte) bool {
return verifyHMAC(message, key, sum, sha512.New)
}
func VerifyHmacSHA512Str(message, key []byte, hexSum string) bool {
return verifyHMACStr(message, key, hexSum, sha512.New)
}
func HmacSHA224(message, key []byte) []byte {
return chmac(message, key, sha256.New224)
}
@ -79,6 +141,14 @@ func HmacSHA224Str(message, key []byte) string {
return chmacStr(message, key, sha256.New224)
}
func VerifyHmacSHA224(message, key, sum []byte) bool {
return verifyHMAC(message, key, sum, sha256.New224)
}
func VerifyHmacSHA224Str(message, key []byte, hexSum string) bool {
return verifyHMACStr(message, key, hexSum, sha256.New224)
}
func HmacRipeMd160(message, key []byte) []byte {
return chmac(message, key, ripemd160.New)
}
@ -86,3 +156,11 @@ func HmacRipeMd160(message, key []byte) []byte {
func HmacRipeMd160Str(message, key []byte) string {
return chmacStr(message, key, ripemd160.New)
}
func VerifyHmacRipeMd160(message, key, sum []byte) bool {
return verifyHMAC(message, key, sum, ripemd160.New)
}
func VerifyHmacRipeMd160Str(message, key []byte, hexSum string) bool {
return verifyHMACStr(message, key, hexSum, ripemd160.New)
}

58
macx/hmac_test.go Normal file
View File

@ -0,0 +1,58 @@
package macx
import "testing"
type hmacCase struct {
name string
sum func([]byte, []byte) []byte
sumStr func([]byte, []byte) string
verify func([]byte, []byte, []byte) bool
verifyStr func([]byte, []byte, string) bool
}
func TestHMACVerifyCases(t *testing.T) {
msg := []byte("macx-verify-message")
key := []byte("macx-verify-key")
cases := []hmacCase{
{"md4", HmacMd4, HmacMd4Str, VerifyHmacMd4, VerifyHmacMd4Str},
{"md5", HmacMd5, HmacMd5Str, VerifyHmacMd5, VerifyHmacMd5Str},
{"sha1", HmacSHA1, HmacSHA1Str, VerifyHmacSHA1, VerifyHmacSHA1Str},
{"sha224", HmacSHA224, HmacSHA224Str, VerifyHmacSHA224, VerifyHmacSHA224Str},
{"sha256", HmacSHA256, HmacSHA256Str, VerifyHmacSHA256, VerifyHmacSHA256Str},
{"sha384", HmacSHA384, HmacSHA384Str, VerifyHmacSHA384, VerifyHmacSHA384Str},
{"sha512", HmacSHA512, HmacSHA512Str, VerifyHmacSHA512, VerifyHmacSHA512Str},
{"ripemd160", HmacRipeMd160, HmacRipeMd160Str, VerifyHmacRipeMd160, VerifyHmacRipeMd160Str},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
sum := tc.sum(msg, key)
if !tc.verify(msg, key, sum) {
t.Fatalf("verify bytes should pass")
}
hexSum := tc.sumStr(msg, key)
if !tc.verifyStr(msg, key, hexSum) {
t.Fatalf("verify hex should pass")
}
if !tc.verifyStr(msg, key, " \t"+hexSum+"\n") {
t.Fatalf("verify hex with spaces should pass")
}
bad := make([]byte, len(sum))
copy(bad, sum)
bad[0] ^= 0xff
if tc.verify(msg, key, bad) {
t.Fatalf("verify bytes should fail for tampered sum")
}
if tc.verifyStr(msg, key, "not-hex") {
t.Fatalf("verify hex should fail for invalid hex")
}
if tc.verify([]byte("wrong-msg"), key, sum) {
t.Fatalf("verify bytes should fail for wrong message")
}
})
}
}

36
rsa.go
View File

@ -14,6 +14,14 @@ func EncodeRsaPrivateKey(private *rsa.PrivateKey, secret string) ([]byte, error)
return asymm.EncodeRsaPrivateKey(private, secret)
}
func EncodeRsaPrivateKeyWithLegacy(private *rsa.PrivateKey, secret string, legacy bool) ([]byte, error) {
return asymm.EncodeRsaPrivateKeyWithLegacy(private, secret, legacy)
}
func EncodeRsaPrivateKeyPKCS8(private *rsa.PrivateKey, secret string) ([]byte, error) {
return asymm.EncodeRsaPrivateKeyPKCS8(private, secret)
}
func EncodeRsaPublicKey(public *rsa.PublicKey) ([]byte, error) {
return asymm.EncodeRsaPublicKey(public)
}
@ -22,6 +30,14 @@ func DecodeRsaPrivateKey(private []byte, password string) (*rsa.PrivateKey, erro
return asymm.DecodeRsaPrivateKey(private, password)
}
func DecodeRsaPrivateKeyWithLegacy(private []byte, password string, legacy bool) (*rsa.PrivateKey, error) {
return asymm.DecodeRsaPrivateKeyWithLegacy(private, password, legacy)
}
func DecodeRsaPrivateKeyPKCS8(private []byte, password string) (*rsa.PrivateKey, error) {
return asymm.DecodeRsaPrivateKeyPKCS8(private, password)
}
func DecodeRsaPublicKey(pubStr []byte) (*rsa.PublicKey, error) {
return asymm.DecodeRsaPublicKey(pubStr)
}
@ -34,6 +50,10 @@ func GenerateRsaSSHKeyPair(bits int, secret string) (string, string, error) {
return asymm.GenerateRsaSSHKeyPair(bits, secret)
}
func GenerateRsaSSHKeyPairWithLegacy(bits int, secret string, legacy bool) (string, string, error) {
return asymm.GenerateRsaSSHKeyPairWithLegacy(bits, secret, legacy)
}
func RSAEncrypt(pub *rsa.PublicKey, data []byte) ([]byte, error) {
return asymm.RSAEncrypt(pub, data)
}
@ -42,6 +62,14 @@ func RSADecrypt(prikey *rsa.PrivateKey, data []byte) ([]byte, error) {
return asymm.RSADecrypt(prikey, data)
}
func RSAEncryptOAEP(pub *rsa.PublicKey, data, label []byte, hashType crypto.Hash) ([]byte, error) {
return asymm.RSAEncryptOAEP(pub, data, label, hashType)
}
func RSADecryptOAEP(prikey *rsa.PrivateKey, data, label []byte, hashType crypto.Hash) ([]byte, error) {
return asymm.RSADecryptOAEP(prikey, data, label, hashType)
}
func RSASign(msg, priKey []byte, password string, hashType crypto.Hash) ([]byte, error) {
return asymm.RSASign(msg, priKey, password, hashType)
}
@ -50,6 +78,14 @@ func RSAVerify(data, msg, pubKey []byte, hashType crypto.Hash) error {
return asymm.RSAVerify(data, msg, pubKey, hashType)
}
func RSASignPSS(msg, priKey []byte, password string, hashType crypto.Hash, opts *rsa.PSSOptions) ([]byte, error) {
return asymm.RSASignPSS(msg, priKey, password, hashType, opts)
}
func RSAVerifyPSS(sig, msg, pubKey []byte, hashType crypto.Hash, opts *rsa.PSSOptions) error {
return asymm.RSAVerifyPSS(sig, msg, pubKey, hashType, opts)
}
func RSAEncryptByPrivkey(privt *rsa.PrivateKey, data []byte) ([]byte, error) {
return asymm.RSAEncryptByPrivkey(privt, data)
}

99
sm4.go
View File

@ -46,6 +46,14 @@ func DecryptSM4GCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
return symm.DecryptSM4GCM(ciphertext, key, nonce, aad)
}
func EncryptSM4GCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
return symm.EncryptSM4GCMChunk(plain, key, nonce, aad, chunkIndex)
}
func DecryptSM4GCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
return symm.DecryptSM4GCMChunk(ciphertext, key, nonce, aad, chunkIndex)
}
func EncryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
return symm.EncryptSM4GCMStream(dst, src, key, nonce, aad)
}
@ -54,6 +62,30 @@ func DecryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) e
return symm.DecryptSM4GCMStream(dst, src, key, nonce, aad)
}
func EncryptSM4CCM(plain, key, nonce, aad []byte) ([]byte, error) {
return symm.EncryptSM4CCM(plain, key, nonce, aad)
}
func DecryptSM4CCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
return symm.DecryptSM4CCM(ciphertext, key, nonce, aad)
}
func EncryptSM4CCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
return symm.EncryptSM4CCMChunk(plain, key, nonce, aad, chunkIndex)
}
func DecryptSM4CCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
return symm.DecryptSM4CCMChunk(ciphertext, key, nonce, aad, chunkIndex)
}
func EncryptSM4CCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
return symm.EncryptSM4CCMStream(dst, src, key, nonce, aad)
}
func DecryptSM4CCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
return symm.DecryptSM4CCMStream(dst, src, key, nonce, aad)
}
func EncryptSM4CFB(origData, key []byte) ([]byte, error) {
return symm.EncryptSM4CFB(origData, key)
}
@ -102,6 +134,14 @@ func DecryptSM4CTR(src, key, iv []byte) ([]byte, error) {
return symm.DecryptSM4CTR(src, key, iv)
}
func EncryptSM4CTRAt(data, key, iv []byte, offset int64) ([]byte, error) {
return symm.EncryptSM4CTRAt(data, key, iv, offset)
}
func DecryptSM4CTRAt(src, key, iv []byte, offset int64) ([]byte, error) {
return symm.DecryptSM4CTRAt(src, key, iv, offset)
}
func EncryptSM4ECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error {
return symm.EncryptSM4ECBStream(dst, src, key, paddingType)
}
@ -141,3 +181,62 @@ func EncryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error {
func DecryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error {
return symm.DecryptSM4CTRStream(dst, src, key, iv)
}
func EncryptSM4CFB8(data, key, iv []byte) ([]byte, error) {
return symm.EncryptSM4CFB8(data, key, iv)
}
func DecryptSM4CFB8(src, key, iv []byte) ([]byte, error) {
return symm.DecryptSM4CFB8(src, key, iv)
}
func EncryptSM4CFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
return symm.EncryptSM4CFB8Stream(dst, src, key, iv)
}
func DecryptSM4CFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
return symm.DecryptSM4CFB8Stream(dst, src, key, iv)
}
func DecryptSM4ECBBlocks(src, key []byte) ([]byte, error) {
return symm.DecryptSM4ECBBlocks(src, key)
}
func DecryptSM4CBCFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
return symm.DecryptSM4CBCFromSecondBlock(src, key, prevCipherBlock)
}
func DecryptSM4CFBFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
return symm.DecryptSM4CFBFromSecondBlock(src, key, prevCipherBlock)
}
func EncryptSM4XTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
return symm.EncryptSM4XTS(plain, key1, key2, dataUnitSize)
}
func DecryptSM4XTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
return symm.DecryptSM4XTS(ciphertext, key1, key2, dataUnitSize)
}
func EncryptSM4XTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
return symm.EncryptSM4XTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex)
}
func DecryptSM4XTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
return symm.DecryptSM4XTSAt(ciphertext, key1, key2, dataUnitSize, dataUnitIndex)
}
func EncryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
return symm.EncryptSM4XTSStream(dst, src, key1, key2, dataUnitSize)
}
func DecryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
return symm.DecryptSM4XTSStream(dst, src, key1, key2, dataUnitSize)
}
func EncryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
return symm.EncryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, dataUnitIndex)
}
func DecryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
return symm.DecryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, dataUnitIndex)
}

View File

@ -7,6 +7,7 @@ import (
"errors"
"io"
"b612.me/starcrypto/ccm"
"b612.me/starcrypto/paddingx"
)
@ -17,13 +18,27 @@ const (
ANSIX923PADDING = paddingx.ANSIX923
)
var ErrInvalidGCMNonceLength = errors.New("gcm nonce length must be 12 bytes")
var (
ErrInvalidGCMNonceLength = errors.New("gcm nonce length must be 12 bytes")
ErrInvalidCCMNonceLength = errors.New("ccm nonce length must be 12 bytes")
)
const (
aeadCCMTagSize = 16
aeadCCMNonceSize = 12
)
func EncryptAes(data, key, iv []byte, mode, paddingType string) ([]byte, error) {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return EncryptAesGCM(data, key, iv, nil)
}
if normalizedMode == MODECCM {
return EncryptAesCCM(data, key, iv, nil)
}
block, err := aes.NewCipher(key)
if err != nil {
@ -34,9 +49,15 @@ func EncryptAes(data, key, iv []byte, mode, paddingType string) ([]byte, error)
func DecryptAes(src, key, iv []byte, mode, paddingType string) ([]byte, error) {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return DecryptAesGCM(src, key, iv, nil)
}
if normalizedMode == MODECCM {
return DecryptAesCCM(src, key, iv, nil)
}
block, err := aes.NewCipher(key)
if err != nil {
@ -47,9 +68,15 @@ func DecryptAes(src, key, iv []byte, mode, paddingType string) ([]byte, error) {
func EncryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return EncryptAesGCMStream(dst, src, key, iv, nil)
}
if normalizedMode == MODECCM {
return EncryptAesCCMStream(dst, src, key, iv, nil)
}
block, err := aes.NewCipher(key)
if err != nil {
@ -60,9 +87,15 @@ func EncryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddin
func DecryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return DecryptAesGCMStream(dst, src, key, iv, nil)
}
if normalizedMode == MODECCM {
return DecryptAesCCMStream(dst, src, key, iv, nil)
}
block, err := aes.NewCipher(key)
if err != nil {
@ -80,6 +113,9 @@ func EncryptAesWithOptions(data, key []byte, opts *CipherOptions) ([]byte, error
if mode == MODEGCM {
return EncryptAesGCM(data, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return EncryptAesCCM(data, key, nonceFromOptions(cfg), cfg.AAD)
}
return EncryptAes(data, key, cfg.IV, mode, cfg.Padding)
}
@ -92,6 +128,9 @@ func DecryptAesWithOptions(src, key []byte, opts *CipherOptions) ([]byte, error)
if mode == MODEGCM {
return DecryptAesGCM(src, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return DecryptAesCCM(src, key, nonceFromOptions(cfg), cfg.AAD)
}
return DecryptAes(src, key, cfg.IV, mode, cfg.Padding)
}
@ -104,6 +143,9 @@ func EncryptAesStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts
if mode == MODEGCM {
return EncryptAesGCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return EncryptAesCCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
return EncryptAesStream(dst, src, key, cfg.IV, mode, cfg.Padding)
}
@ -116,6 +158,9 @@ func DecryptAesStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts
if mode == MODEGCM {
return DecryptAesGCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return DecryptAesCCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
return DecryptAesStream(dst, src, key, cfg.IV, mode, cfg.Padding)
}
@ -149,30 +194,138 @@ func DecryptAesGCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
return gcm.Open(nil, nonce, ciphertext, aad)
}
func newAesCCM(key []byte) (cipher.AEAD, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return ccm.NewCCM(block, aeadCCMTagSize, aeadCCMNonceSize)
}
func EncryptAesCCM(plain, key, nonce, aad []byte) ([]byte, error) {
aead, err := newAesCCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return aead.Seal(nil, nonce, plain, aad), nil
}
func DecryptAesCCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
aead, err := newAesCCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return aead.Open(nil, nonce, ciphertext, aad)
}
func EncryptAesCCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
aead, err := newAesCCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return encryptCCMChunk(aead, plain, nonce, aad, chunkIndex), nil
}
func DecryptAesCCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
aead, err := newAesCCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return decryptCCMChunk(aead, ciphertext, nonce, aad, chunkIndex)
}
func EncryptAesCCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
aead, err := newAesCCM(key)
if err != nil {
return err
}
if len(nonce) != aead.NonceSize() {
return ErrInvalidCCMNonceLength
}
return encryptCCMChunkedStream(dst, src, aead, nonce, aad)
}
func DecryptAesCCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
aead, err := newAesCCM(key)
if err != nil {
return err
}
if len(nonce) != aead.NonceSize() {
return ErrInvalidCCMNonceLength
}
return decryptCCMChunkedOrLegacyStream(dst, src, aead, nonce, aad)
}
func EncryptAesGCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
if len(nonce) != gcm.NonceSize() {
return nil, ErrInvalidGCMNonceLength
}
return encryptGCMChunk(gcm, plain, nonce, aad, chunkIndex), nil
}
func DecryptAesGCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
if len(nonce) != gcm.NonceSize() {
return nil, ErrInvalidGCMNonceLength
}
return decryptGCMChunk(gcm, ciphertext, nonce, aad, chunkIndex)
}
func EncryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
plain, err := io.ReadAll(src)
block, err := aes.NewCipher(key)
if err != nil {
return err
}
out, err := EncryptAesGCM(plain, key, nonce, aad)
gcm, err := cipher.NewGCM(block)
if err != nil {
return err
}
_, err = dst.Write(out)
return err
if len(nonce) != gcm.NonceSize() {
return ErrInvalidGCMNonceLength
}
return encryptGCMChunkedStream(dst, src, gcm, nonce, aad)
}
func DecryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
enc, err := io.ReadAll(src)
block, err := aes.NewCipher(key)
if err != nil {
return err
}
out, err := DecryptAesGCM(enc, key, nonce, aad)
gcm, err := cipher.NewGCM(block)
if err != nil {
return err
}
_, err = dst.Write(out)
return err
if len(nonce) != gcm.NonceSize() {
return ErrInvalidGCMNonceLength
}
return decryptGCMChunkedOrLegacyStream(dst, src, gcm, nonce, aad)
}
func EncryptAesECB(data, key []byte, paddingType string) ([]byte, error) {
@ -215,6 +368,18 @@ func DecryptAesCTR(src, key, iv []byte) ([]byte, error) {
return DecryptAes(src, key, iv, MODECTR, "")
}
func EncryptAesCTRAt(data, key, iv []byte, offset int64) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return xorCTRAtOffset(block, data, iv, offset)
}
func DecryptAesCTRAt(src, key, iv []byte, offset int64) ([]byte, error) {
return EncryptAesCTRAt(src, key, iv, offset)
}
func EncryptAesECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error {
return EncryptAesStream(dst, src, key, nil, MODEECB, paddingType)
}
@ -329,3 +494,63 @@ func PKCS7Trimming(encrypted []byte, blockSize int) []byte {
}
return out
}
func EncryptAesCFB8(data, key, iv []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return encryptCFB8(block, data, iv)
}
func DecryptAesCFB8(src, key, iv []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return decryptCFB8(block, src, iv)
}
func EncryptAesCFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
block, err := aes.NewCipher(key)
if err != nil {
return err
}
return encryptCFB8Stream(block, dst, src, iv, false)
}
func DecryptAesCFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
block, err := aes.NewCipher(key)
if err != nil {
return err
}
return encryptCFB8Stream(block, dst, src, iv, true)
}
func DecryptAesECBBlocks(src, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return decryptECBBlocks(block, src)
}
// DecryptAesCBCFromSecondBlock decrypts a CBC ciphertext segment that starts from block 2 or later.
// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock.
func DecryptAesCBCFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return decryptCBCFromSecondBlock(block, src, prevCipherBlock)
}
// DecryptAesCFBFromSecondBlock decrypts a CFB ciphertext segment that starts from block 2 or later.
// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock.
func DecryptAesCFBFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
return decryptCFBFromSecondBlock(block, src, prevCipherBlock)
}

98
symm/bench_test.go Normal file
View File

@ -0,0 +1,98 @@
package symm
import (
"bytes"
"testing"
)
var (
benchPlain4K = bytes.Repeat([]byte("0123456789abcdef"), 256) // 4 KiB
benchPlain256K = bytes.Repeat([]byte("0123456789abcdef"), 16384) // 256 KiB
benchAAD = []byte("benchmark-aad")
benchAESKey = []byte("0123456789abcdef")
benchSM4Key = []byte("0123456789abcdef")
benchXTSKey2 = []byte("fedcba9876543210")
benchNonce12Byte = []byte("123456789012")
)
func BenchmarkAesGCMEncrypt4K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain4K)))
for i := 0; i < b.N; i++ {
if _, err := EncryptAesGCM(benchPlain4K, benchAESKey, benchNonce12Byte, benchAAD); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkAesCCMEncrypt4K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain4K)))
for i := 0; i < b.N; i++ {
if _, err := EncryptAesCCM(benchPlain4K, benchAESKey, benchNonce12Byte, benchAAD); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkAesXTSEncrypt4K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain4K)))
for i := 0; i < b.N; i++ {
if _, err := EncryptAesXTS(benchPlain4K, benchAESKey, benchXTSKey2, 512); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSM4GCMEncrypt4K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain4K)))
for i := 0; i < b.N; i++ {
if _, err := EncryptSM4GCM(benchPlain4K, benchSM4Key, benchNonce12Byte, benchAAD); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSM4CCMEncrypt4K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain4K)))
for i := 0; i < b.N; i++ {
if _, err := EncryptSM4CCM(benchPlain4K, benchSM4Key, benchNonce12Byte, benchAAD); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSM4XTSEncrypt4K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain4K)))
for i := 0; i < b.N; i++ {
if _, err := EncryptSM4XTS(benchPlain4K, benchSM4Key, benchXTSKey2, 512); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkAesXTSStreamEncrypt256K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain256K)))
for i := 0; i < b.N; i++ {
var dst bytes.Buffer
if err := EncryptAesXTSStream(&dst, bytes.NewReader(benchPlain256K), benchAESKey, benchXTSKey2, 512); err != nil {
b.Fatal(err)
}
}
}
func BenchmarkSM4XTSStreamEncrypt256K(b *testing.B) {
b.ReportAllocs()
b.SetBytes(int64(len(benchPlain256K)))
for i := 0; i < b.N; i++ {
var dst bytes.Buffer
if err := EncryptSM4XTSStream(&dst, bytes.NewReader(benchPlain256K), benchSM4Key, benchXTSKey2, 512); err != nil {
b.Fatal(err)
}
}
}

127
symm/ccm_stream.go Normal file
View File

@ -0,0 +1,127 @@
package symm
import (
"bytes"
"crypto/cipher"
"encoding/binary"
"errors"
"io"
)
const ccmStreamMagic = "SCC1"
var ErrInvalidCCMStreamChunk = errors.New("invalid ccm stream chunk")
func encryptCCMChunk(aead cipher.AEAD, plain, nonce, aad []byte, chunkIndex uint64) []byte {
chunkNonce := deriveChunkNonce(nonce, chunkIndex)
return aead.Seal(nil, chunkNonce, plain, aad)
}
func decryptCCMChunk(aead cipher.AEAD, ciphertext, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
chunkNonce := deriveChunkNonce(nonce, chunkIndex)
return aead.Open(nil, chunkNonce, ciphertext, aad)
}
func encryptCCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
if _, err := dst.Write([]byte(ccmStreamMagic)); err != nil {
return err
}
buf := make([]byte, gcmStreamChunkSize)
lenBuf := make([]byte, 4)
var chunkIndex uint64
for {
n, err := src.Read(buf)
if n > 0 {
sealed := encryptCCMChunk(aead, buf[:n], nonce, aad, chunkIndex)
binary.BigEndian.PutUint32(lenBuf, uint32(len(sealed)))
if _, werr := dst.Write(lenBuf); werr != nil {
return werr
}
if _, werr := dst.Write(sealed); werr != nil {
return werr
}
chunkIndex++
}
if err != nil {
if err == io.EOF {
return nil
}
return err
}
}
}
func decryptCCMChunkedOrLegacyStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
header := make([]byte, len(ccmStreamMagic))
n, err := io.ReadFull(src, header)
if err != nil {
if err == io.EOF {
return nil
}
if err != io.ErrUnexpectedEOF {
return err
}
return decryptCCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header[:n]), src), aead, nonce, aad)
}
if string(header) != ccmStreamMagic {
return decryptCCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header), src), aead, nonce, aad)
}
return decryptCCMChunkedStream(dst, src, aead, nonce, aad)
}
func decryptCCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
lenBuf := make([]byte, 4)
maxChunkLen := uint32(gcmStreamChunkSize + aead.Overhead())
var chunkIndex uint64
for {
_, err := io.ReadFull(src, lenBuf)
if err != nil {
if err == io.EOF {
return nil
}
if err == io.ErrUnexpectedEOF {
return io.ErrUnexpectedEOF
}
return err
}
chunkLen := binary.BigEndian.Uint32(lenBuf)
if chunkLen < uint32(aead.Overhead()) || chunkLen > maxChunkLen {
return ErrInvalidCCMStreamChunk
}
chunk := make([]byte, chunkLen)
if _, err := io.ReadFull(src, chunk); err != nil {
if err == io.ErrUnexpectedEOF {
return io.ErrUnexpectedEOF
}
return err
}
plain, err := decryptCCMChunk(aead, chunk, nonce, aad, chunkIndex)
if err != nil {
return err
}
if _, err := dst.Write(plain); err != nil {
return err
}
chunkIndex++
}
}
func decryptCCMLegacyBuffered(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
enc, err := io.ReadAll(src)
if err != nil {
return err
}
plain, err := aead.Open(nil, nonce, enc, aad)
if err != nil {
return err
}
_, err = dst.Write(plain)
return err
}

103
symm/cfb8.go Normal file
View File

@ -0,0 +1,103 @@
package symm
import (
"crypto/cipher"
"errors"
"io"
)
func encryptCFB8(block cipher.Block, data, iv []byte) ([]byte, error) {
if len(iv) != block.BlockSize() {
return nil, errors.New("iv length must match block size")
}
reg := make([]byte, len(iv))
copy(reg, iv)
regView := make([]byte, block.BlockSize())
streamBlock := make([]byte, block.BlockSize())
out := make([]byte, len(data))
head := 0
for i := 0; i < len(data); i++ {
buildCFB8Register(regView, reg, head)
block.Encrypt(streamBlock, regView)
c := data[i] ^ streamBlock[0]
out[i] = c
advanceCFB8Register(reg, &head, c)
}
return out, nil
}
func decryptCFB8(block cipher.Block, src, iv []byte) ([]byte, error) {
if len(iv) != block.BlockSize() {
return nil, errors.New("iv length must match block size")
}
reg := make([]byte, len(iv))
copy(reg, iv)
regView := make([]byte, block.BlockSize())
streamBlock := make([]byte, block.BlockSize())
out := make([]byte, len(src))
head := 0
for i := 0; i < len(src); i++ {
buildCFB8Register(regView, reg, head)
block.Encrypt(streamBlock, regView)
p := src[i] ^ streamBlock[0]
out[i] = p
advanceCFB8Register(reg, &head, src[i])
}
return out, nil
}
func encryptCFB8Stream(block cipher.Block, dst io.Writer, src io.Reader, iv []byte, decrypt bool) error {
if len(iv) != block.BlockSize() {
return errors.New("iv length must match block size")
}
reg := make([]byte, len(iv))
copy(reg, iv)
regView := make([]byte, block.BlockSize())
streamBlock := make([]byte, block.BlockSize())
buf := make([]byte, 32*1024)
out := make([]byte, 32*1024)
head := 0
for {
n, err := src.Read(buf)
if n > 0 {
for i := 0; i < n; i++ {
buildCFB8Register(regView, reg, head)
block.Encrypt(streamBlock, regView)
if decrypt {
out[i] = buf[i] ^ streamBlock[0]
advanceCFB8Register(reg, &head, buf[i])
} else {
c := buf[i] ^ streamBlock[0]
out[i] = c
advanceCFB8Register(reg, &head, c)
}
}
if _, werr := dst.Write(out[:n]); werr != nil {
return werr
}
}
if err != nil {
if err == io.EOF {
return nil
}
return err
}
}
}
func buildCFB8Register(dst, reg []byte, head int) {
first := len(reg) - head
copy(dst, reg[head:])
copy(dst[first:], reg[:head])
}
func advanceCFB8Register(reg []byte, head *int, feedback byte) {
reg[*head] = feedback
*head = *head + 1
if *head == len(reg) {
*head = 0
}
}

56
symm/ctr_seek.go Normal file
View File

@ -0,0 +1,56 @@
package symm
import (
"crypto/cipher"
"errors"
)
var (
ErrInvalidCTROffset = errors.New("ctr offset must be non-negative")
ErrCTRCounterOverflow = errors.New("ctr counter overflow")
)
func xorCTRAtOffset(block cipher.Block, data, iv []byte, offset int64) ([]byte, error) {
if offset < 0 {
return nil, ErrInvalidCTROffset
}
if len(iv) != block.BlockSize() {
return nil, errors.New("iv length must match block size")
}
counter := make([]byte, len(iv))
copy(counter, iv)
blockSize := int64(block.BlockSize())
blockOffset := uint64(offset / blockSize)
byteOffset := int(offset % blockSize)
if err := addUint64ToCounter(counter, blockOffset); err != nil {
return nil, err
}
stream := cipher.NewCTR(block, counter)
if byteOffset > 0 {
skip := make([]byte, byteOffset)
stream.XORKeyStream(skip, skip)
}
out := make([]byte, len(data))
stream.XORKeyStream(out, data)
return out, nil
}
func addUint64ToCounter(counter []byte, inc uint64) error {
if inc == 0 {
return nil
}
for i := len(counter) - 1; i >= 0 && inc > 0; i-- {
sum := uint64(counter[i]) + (inc & 0xff)
counter[i] = byte(sum)
inc = (inc >> 8) + (sum >> 8)
}
if inc > 0 {
return ErrCTRCounterOverflow
}
return nil
}

View File

@ -70,3 +70,122 @@ func FuzzAesCBCStreamRoundTrip(f *testing.F) {
}
})
}
func xtsFuzzDataUnitSize(selector uint8) int {
sizes := [...]int{16, 32, 64, 128, 256, 512}
return sizes[int(selector)%len(sizes)]
}
func xtsFuzzNormalizeData(data []byte, maxLen int) []byte {
if len(data) > maxLen {
data = data[:maxLen]
}
return data[:len(data)/16*16]
}
func FuzzAesXTSRoundTrip(f *testing.F) {
f.Add([]byte("fuzz-aes-xts-seed-0000"), uint64(0), uint8(0))
f.Add(bytes.Repeat([]byte{0x42}, 65), uint64(7), uint8(3))
key1 := []byte("0123456789abcdef")
key2 := []byte("fedcba9876543210")
f.Fuzz(func(t *testing.T, data []byte, dataUnitIndex uint64, selector uint8) {
bounded := data
if len(bounded) > 4097 {
bounded = bounded[:4097]
}
plain := xtsFuzzNormalizeData(bounded, 4096)
dataUnitSize := xtsFuzzDataUnitSize(selector)
enc, err := EncryptAesXTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex)
if err != nil {
t.Fatalf("EncryptAesXTSAt failed: %v", err)
}
dec, err := DecryptAesXTSAt(enc, key1, key2, dataUnitSize, dataUnitIndex)
if err != nil {
t.Fatalf("DecryptAesXTSAt failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("aes xts roundtrip mismatch")
}
encStream := &bytes.Buffer{}
if err := EncryptAesXTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
t.Fatalf("EncryptAesXTSStreamAt failed: %v", err)
}
if !bytes.Equal(encStream.Bytes(), enc) {
t.Fatalf("aes xts bytes/stream encrypt mismatch")
}
decStream := &bytes.Buffer{}
if err := DecryptAesXTSStreamAt(decStream, bytes.NewReader(enc), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
t.Fatalf("DecryptAesXTSStreamAt failed: %v", err)
}
if !bytes.Equal(decStream.Bytes(), plain) {
t.Fatalf("aes xts stream decrypt mismatch")
}
if len(bounded)%16 != 0 {
if _, err := EncryptAesXTSAt(bounded, key1, key2, dataUnitSize, dataUnitIndex); err == nil {
t.Fatalf("expected aes xts bytes error for non-block-aligned input")
}
if err := EncryptAesXTSStreamAt(&bytes.Buffer{}, bytes.NewReader(bounded), key1, key2, dataUnitSize, dataUnitIndex); err == nil {
t.Fatalf("expected aes xts stream error for non-block-aligned input")
}
}
})
}
func FuzzSM4XTSRoundTrip(f *testing.F) {
f.Add([]byte("fuzz-sm4-xts-seed-0000"), uint64(0), uint8(0))
f.Add(bytes.Repeat([]byte{0x5a}, 79), uint64(11), uint8(4))
key1 := []byte("0123456789abcdef")
key2 := []byte("fedcba9876543210")
f.Fuzz(func(t *testing.T, data []byte, dataUnitIndex uint64, selector uint8) {
bounded := data
if len(bounded) > 4097 {
bounded = bounded[:4097]
}
plain := xtsFuzzNormalizeData(bounded, 4096)
dataUnitSize := xtsFuzzDataUnitSize(selector)
enc, err := EncryptSM4XTSAt(plain, key1, key2, dataUnitSize, dataUnitIndex)
if err != nil {
t.Fatalf("EncryptSM4XTSAt failed: %v", err)
}
dec, err := DecryptSM4XTSAt(enc, key1, key2, dataUnitSize, dataUnitIndex)
if err != nil {
t.Fatalf("DecryptSM4XTSAt failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("sm4 xts roundtrip mismatch")
}
encStream := &bytes.Buffer{}
if err := EncryptSM4XTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
t.Fatalf("EncryptSM4XTSStreamAt failed: %v", err)
}
if !bytes.Equal(encStream.Bytes(), enc) {
t.Fatalf("sm4 xts bytes/stream encrypt mismatch")
}
decStream := &bytes.Buffer{}
if err := DecryptSM4XTSStreamAt(decStream, bytes.NewReader(enc), key1, key2, dataUnitSize, dataUnitIndex); err != nil {
t.Fatalf("DecryptSM4XTSStreamAt failed: %v", err)
}
if !bytes.Equal(decStream.Bytes(), plain) {
t.Fatalf("sm4 xts stream decrypt mismatch")
}
if len(bounded)%16 != 0 {
if _, err := EncryptSM4XTSAt(bounded, key1, key2, dataUnitSize, dataUnitIndex); err == nil {
t.Fatalf("expected sm4 xts bytes error for non-block-aligned input")
}
if err := EncryptSM4XTSStreamAt(&bytes.Buffer{}, bytes.NewReader(bounded), key1, key2, dataUnitSize, dataUnitIndex); err == nil {
t.Fatalf("expected sm4 xts stream error for non-block-aligned input")
}
}
})
}

146
symm/gcm_stream.go Normal file
View File

@ -0,0 +1,146 @@
package symm
import (
"bytes"
"crypto/cipher"
"encoding/binary"
"errors"
"io"
)
const (
gcmStreamMagic = "SCG1"
gcmStreamChunkSize = 32 * 1024
)
var ErrInvalidGCMStreamChunk = errors.New("invalid gcm stream chunk")
func encryptGCMChunk(aead cipher.AEAD, plain, nonce, aad []byte, chunkIndex uint64) []byte {
chunkNonce := deriveChunkNonce(nonce, chunkIndex)
return aead.Seal(nil, chunkNonce, plain, aad)
}
func decryptGCMChunk(aead cipher.AEAD, ciphertext, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
chunkNonce := deriveChunkNonce(nonce, chunkIndex)
return aead.Open(nil, chunkNonce, ciphertext, aad)
}
func encryptGCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
if _, err := dst.Write([]byte(gcmStreamMagic)); err != nil {
return err
}
buf := make([]byte, gcmStreamChunkSize)
lenBuf := make([]byte, 4)
var chunkIndex uint64
for {
n, err := src.Read(buf)
if n > 0 {
sealed := encryptGCMChunk(aead, buf[:n], nonce, aad, chunkIndex)
binary.BigEndian.PutUint32(lenBuf, uint32(len(sealed)))
if _, werr := dst.Write(lenBuf); werr != nil {
return werr
}
if _, werr := dst.Write(sealed); werr != nil {
return werr
}
chunkIndex++
}
if err != nil {
if err == io.EOF {
return nil
}
return err
}
}
}
func decryptGCMChunkedOrLegacyStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
header := make([]byte, len(gcmStreamMagic))
n, err := io.ReadFull(src, header)
if err != nil {
if err == io.EOF {
return nil
}
if err != io.ErrUnexpectedEOF {
return err
}
return decryptGCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header[:n]), src), aead, nonce, aad)
}
if string(header) != gcmStreamMagic {
return decryptGCMLegacyBuffered(dst, io.MultiReader(bytes.NewReader(header), src), aead, nonce, aad)
}
return decryptGCMChunkedStream(dst, src, aead, nonce, aad)
}
func decryptGCMChunkedStream(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
lenBuf := make([]byte, 4)
maxChunkLen := uint32(gcmStreamChunkSize + aead.Overhead())
var chunkIndex uint64
for {
_, err := io.ReadFull(src, lenBuf)
if err != nil {
if err == io.EOF {
return nil
}
if err == io.ErrUnexpectedEOF {
return io.ErrUnexpectedEOF
}
return err
}
chunkLen := binary.BigEndian.Uint32(lenBuf)
if chunkLen < uint32(aead.Overhead()) || chunkLen > maxChunkLen {
return ErrInvalidGCMStreamChunk
}
chunk := make([]byte, chunkLen)
if _, err := io.ReadFull(src, chunk); err != nil {
if err == io.ErrUnexpectedEOF {
return io.ErrUnexpectedEOF
}
return err
}
plain, err := decryptGCMChunk(aead, chunk, nonce, aad, chunkIndex)
if err != nil {
return err
}
if _, err := dst.Write(plain); err != nil {
return err
}
chunkIndex++
}
}
func decryptGCMLegacyBuffered(dst io.Writer, src io.Reader, aead cipher.AEAD, nonce, aad []byte) error {
enc, err := io.ReadAll(src)
if err != nil {
return err
}
plain, err := aead.Open(nil, nonce, enc, aad)
if err != nil {
return err
}
_, err = dst.Write(plain)
return err
}
func deriveChunkNonce(baseNonce []byte, chunkIndex uint64) []byte {
nonce := make([]byte, len(baseNonce))
copy(nonce, baseNonce)
if len(nonce) < 8 {
return nonce
}
var indexBytes [8]byte
binary.BigEndian.PutUint64(indexBytes[:], chunkIndex)
off := len(nonce) - 8
for i := 0; i < 8; i++ {
nonce[off+i] ^= indexBytes[i]
}
return nonce
}

View File

@ -16,6 +16,7 @@ const (
MODEOFB = "OFB"
MODECTR = "CTR"
MODEGCM = "GCM"
MODECCM = "CCM"
)
var ErrUnsupportedCipherMode = errors.New("cipher mode not supported")

View File

@ -1,7 +1,7 @@
package symm
// CipherOptions provides a unified configuration for symmetric APIs.
// For GCM mode, Nonce is used; if Nonce is empty, IV is used as fallback.
// For AEAD modes (GCM/CCM), Nonce must be set explicitly.
type CipherOptions struct {
Mode string
Padding string
@ -18,8 +18,5 @@ func normalizeCipherOptions(opts *CipherOptions) CipherOptions {
}
func nonceFromOptions(opts CipherOptions) []byte {
if len(opts.Nonce) > 0 {
return opts.Nonce
}
return opts.IV
return opts.Nonce
}

View File

@ -0,0 +1,28 @@
package symm
import (
"errors"
"testing"
)
func TestAEADOptionsRequireNonce(t *testing.T) {
aesKey := []byte("0123456789abcdef")
sm4Key := []byte("0123456789abcdef")
plain := []byte("nonce-required")
gcmIVOnly := &CipherOptions{Mode: MODEGCM, IV: []byte("123456789012")}
if _, err := EncryptAesWithOptions(plain, aesKey, gcmIVOnly); !errors.Is(err, ErrInvalidGCMNonceLength) {
t.Fatalf("expected ErrInvalidGCMNonceLength for AES GCM with IV-only opts, got: %v", err)
}
if _, err := EncryptSM4WithOptions(plain, sm4Key, gcmIVOnly); !errors.Is(err, ErrInvalidGCMNonceLength) {
t.Fatalf("expected ErrInvalidGCMNonceLength for SM4 GCM with IV-only opts, got: %v", err)
}
ccmIVOnly := &CipherOptions{Mode: MODECCM, IV: []byte("123456789012")}
if _, err := EncryptAesWithOptions(plain, aesKey, ccmIVOnly); !errors.Is(err, ErrInvalidCCMNonceLength) {
t.Fatalf("expected ErrInvalidCCMNonceLength for AES CCM with IV-only opts, got: %v", err)
}
if _, err := EncryptSM4WithOptions(plain, sm4Key, ccmIVOnly); !errors.Is(err, ErrInvalidCCMNonceLength) {
t.Fatalf("expected ErrInvalidCCMNonceLength for SM4 CCM with IV-only opts, got: %v", err)
}
}

54
symm/segment_decrypt.go Normal file
View File

@ -0,0 +1,54 @@
package symm
import (
"crypto/cipher"
"errors"
)
var (
ErrSegmentNotFullBlocks = errors.New("ciphertext segment is not a full block size")
)
func decryptECBBlocks(block cipher.Block, src []byte) ([]byte, error) {
if len(src) == 0 {
return []byte{}, nil
}
if len(src)%block.BlockSize() != 0 {
return nil, ErrSegmentNotFullBlocks
}
out := make([]byte, len(src))
ecbDecryptBlocks(block, out, src)
return out, nil
}
// decryptCBCFromSecondBlock decrypts a CBC ciphertext segment that starts from the second block (or later).
// prevCipherBlock must be the previous ciphertext block (C[i-1]); for i=1 this is the original IV.
func decryptCBCFromSecondBlock(block cipher.Block, src, prevCipherBlock []byte) ([]byte, error) {
if len(src) == 0 {
return []byte{}, nil
}
if len(prevCipherBlock) != block.BlockSize() {
return nil, errors.New("prev cipher block length must match block size")
}
if len(src)%block.BlockSize() != 0 {
return nil, ErrSegmentNotFullBlocks
}
out := make([]byte, len(src))
cipher.NewCBCDecrypter(block, prevCipherBlock).CryptBlocks(out, src)
return out, nil
}
// decryptCFBFromSecondBlock decrypts a CFB ciphertext segment that starts from the second block (or later).
// prevCipherBlock must be the previous ciphertext block (C[i-1]); for i=1 this is the original IV.
func decryptCFBFromSecondBlock(block cipher.Block, src, prevCipherBlock []byte) ([]byte, error) {
if len(src) == 0 {
return []byte{}, nil
}
if len(prevCipherBlock) != block.BlockSize() {
return nil, errors.New("prev cipher block length must match block size")
}
stream := cipher.NewCFBDecrypter(block, prevCipherBlock)
out := make([]byte, len(src))
stream.XORKeyStream(out, src)
return out, nil
}

View File

@ -6,14 +6,21 @@ import (
"errors"
"io"
"b612.me/starcrypto/ccm"
"github.com/emmansun/gmsm/sm4"
)
func EncryptSM4(data, key, iv []byte, mode, paddingType string) ([]byte, error) {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return EncryptSM4GCM(data, key, iv, nil)
}
if normalizedMode == MODECCM {
return EncryptSM4CCM(data, key, iv, nil)
}
block, err := sm4.NewCipher(key)
if err != nil {
@ -24,9 +31,15 @@ func EncryptSM4(data, key, iv []byte, mode, paddingType string) ([]byte, error)
func DecryptSM4(src, key, iv []byte, mode, paddingType string) ([]byte, error) {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return DecryptSM4GCM(src, key, iv, nil)
}
if normalizedMode == MODECCM {
return DecryptSM4CCM(src, key, iv, nil)
}
block, err := sm4.NewCipher(key)
if err != nil {
@ -37,9 +50,15 @@ func DecryptSM4(src, key, iv []byte, mode, paddingType string) ([]byte, error) {
func EncryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return EncryptSM4GCMStream(dst, src, key, iv, nil)
}
if normalizedMode == MODECCM {
return EncryptSM4CCMStream(dst, src, key, iv, nil)
}
block, err := sm4.NewCipher(key)
if err != nil {
@ -50,9 +69,15 @@ func EncryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddin
func DecryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error {
normalizedMode := normalizeCipherMode(mode)
if normalizedMode == "" {
normalizedMode = MODEGCM
}
if normalizedMode == MODEGCM {
return DecryptSM4GCMStream(dst, src, key, iv, nil)
}
if normalizedMode == MODECCM {
return DecryptSM4CCMStream(dst, src, key, iv, nil)
}
block, err := sm4.NewCipher(key)
if err != nil {
@ -70,6 +95,9 @@ func EncryptSM4WithOptions(data, key []byte, opts *CipherOptions) ([]byte, error
if mode == MODEGCM {
return EncryptSM4GCM(data, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return EncryptSM4CCM(data, key, nonceFromOptions(cfg), cfg.AAD)
}
return EncryptSM4(data, key, cfg.IV, mode, cfg.Padding)
}
@ -82,6 +110,9 @@ func DecryptSM4WithOptions(src, key []byte, opts *CipherOptions) ([]byte, error)
if mode == MODEGCM {
return DecryptSM4GCM(src, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return DecryptSM4CCM(src, key, nonceFromOptions(cfg), cfg.AAD)
}
return DecryptSM4(src, key, cfg.IV, mode, cfg.Padding)
}
@ -94,6 +125,9 @@ func EncryptSM4StreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts
if mode == MODEGCM {
return EncryptSM4GCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return EncryptSM4CCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
return EncryptSM4Stream(dst, src, key, cfg.IV, mode, cfg.Padding)
}
@ -106,6 +140,9 @@ func DecryptSM4StreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts
if mode == MODEGCM {
return DecryptSM4GCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
if mode == MODECCM {
return DecryptSM4CCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD)
}
return DecryptSM4Stream(dst, src, key, cfg.IV, mode, cfg.Padding)
}
@ -139,30 +176,138 @@ func DecryptSM4GCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
return gcm.Open(nil, nonce, ciphertext, aad)
}
func EncryptSM4GCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
if len(nonce) != gcm.NonceSize() {
return nil, ErrInvalidGCMNonceLength
}
return encryptGCMChunk(gcm, plain, nonce, aad, chunkIndex), nil
}
func DecryptSM4GCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
if len(nonce) != gcm.NonceSize() {
return nil, ErrInvalidGCMNonceLength
}
return decryptGCMChunk(gcm, ciphertext, nonce, aad, chunkIndex)
}
func EncryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
plain, err := io.ReadAll(src)
block, err := sm4.NewCipher(key)
if err != nil {
return err
}
out, err := EncryptSM4GCM(plain, key, nonce, aad)
gcm, err := cipher.NewGCM(block)
if err != nil {
return err
}
_, err = dst.Write(out)
return err
if len(nonce) != gcm.NonceSize() {
return ErrInvalidGCMNonceLength
}
return encryptGCMChunkedStream(dst, src, gcm, nonce, aad)
}
func DecryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
enc, err := io.ReadAll(src)
block, err := sm4.NewCipher(key)
if err != nil {
return err
}
out, err := DecryptSM4GCM(enc, key, nonce, aad)
gcm, err := cipher.NewGCM(block)
if err != nil {
return err
}
_, err = dst.Write(out)
return err
if len(nonce) != gcm.NonceSize() {
return ErrInvalidGCMNonceLength
}
return decryptGCMChunkedOrLegacyStream(dst, src, gcm, nonce, aad)
}
func newSM4CCM(key []byte) (cipher.AEAD, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
return ccm.NewCCM(block, aeadCCMTagSize, aeadCCMNonceSize)
}
func EncryptSM4CCM(plain, key, nonce, aad []byte) ([]byte, error) {
aead, err := newSM4CCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return aead.Seal(nil, nonce, plain, aad), nil
}
func DecryptSM4CCM(ciphertext, key, nonce, aad []byte) ([]byte, error) {
aead, err := newSM4CCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return aead.Open(nil, nonce, ciphertext, aad)
}
func EncryptSM4CCMChunk(plain, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
aead, err := newSM4CCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return encryptCCMChunk(aead, plain, nonce, aad, chunkIndex), nil
}
func DecryptSM4CCMChunk(ciphertext, key, nonce, aad []byte, chunkIndex uint64) ([]byte, error) {
aead, err := newSM4CCM(key)
if err != nil {
return nil, err
}
if len(nonce) != aead.NonceSize() {
return nil, ErrInvalidCCMNonceLength
}
return decryptCCMChunk(aead, ciphertext, nonce, aad, chunkIndex)
}
func EncryptSM4CCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
aead, err := newSM4CCM(key)
if err != nil {
return err
}
if len(nonce) != aead.NonceSize() {
return ErrInvalidCCMNonceLength
}
return encryptCCMChunkedStream(dst, src, aead, nonce, aad)
}
func DecryptSM4CCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error {
aead, err := newSM4CCM(key)
if err != nil {
return err
}
if len(nonce) != aead.NonceSize() {
return ErrInvalidCCMNonceLength
}
return decryptCCMChunkedOrLegacyStream(dst, src, aead, nonce, aad)
}
func EncryptSM4CFB(origData, key []byte) ([]byte, error) {
@ -235,6 +380,18 @@ func DecryptSM4CTR(src, key, iv []byte) ([]byte, error) {
return DecryptSM4(src, key, iv, MODECTR, "")
}
func EncryptSM4CTRAt(data, key, iv []byte, offset int64) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
return xorCTRAtOffset(block, data, iv, offset)
}
func DecryptSM4CTRAt(src, key, iv []byte, offset int64) ([]byte, error) {
return EncryptSM4CTRAt(src, key, iv, offset)
}
func EncryptSM4ECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error {
return EncryptSM4Stream(dst, src, key, nil, MODEECB, paddingType)
}
@ -274,3 +431,63 @@ func EncryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error {
func DecryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error {
return DecryptSM4Stream(dst, src, key, iv, MODECTR, "")
}
func EncryptSM4CFB8(data, key, iv []byte) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
return encryptCFB8(block, data, iv)
}
func DecryptSM4CFB8(src, key, iv []byte) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
return decryptCFB8(block, src, iv)
}
func EncryptSM4CFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
block, err := sm4.NewCipher(key)
if err != nil {
return err
}
return encryptCFB8Stream(block, dst, src, iv, false)
}
func DecryptSM4CFB8Stream(dst io.Writer, src io.Reader, key, iv []byte) error {
block, err := sm4.NewCipher(key)
if err != nil {
return err
}
return encryptCFB8Stream(block, dst, src, iv, true)
}
func DecryptSM4ECBBlocks(src, key []byte) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
return decryptECBBlocks(block, src)
}
// DecryptSM4CBCFromSecondBlock decrypts a CBC ciphertext segment that starts from block 2 or later.
// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock.
func DecryptSM4CBCFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
return decryptCBCFromSecondBlock(block, src, prevCipherBlock)
}
// DecryptSM4CFBFromSecondBlock decrypts a CFB ciphertext segment that starts from block 2 or later.
// prevCipherBlock must be the previous ciphertext block. For data from block 2, pass block 1 as prevCipherBlock.
func DecryptSM4CFBFromSecondBlock(src, key, prevCipherBlock []byte) ([]byte, error) {
block, err := sm4.NewCipher(key)
if err != nil {
return nil, err
}
return decryptCFBFromSecondBlock(block, src, prevCipherBlock)
}

View File

@ -6,21 +6,39 @@ import (
"testing"
)
func TestEncryptAesDefaultModeCBC(t *testing.T) {
func TestEncryptAesDefaultModeGCM(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := []byte("aes-default-mode-cbc")
nonce := []byte("123456789012")
plain := []byte("aes-default-mode-gcm")
encDefault, err := EncryptAes(plain, key, iv, "", "")
encDefault, err := EncryptAes(plain, key, nonce, "", "")
if err != nil {
t.Fatalf("EncryptAes default failed: %v", err)
}
encCBC, err := EncryptAesCBC(plain, key, iv, "")
encGCM, err := EncryptAesGCM(plain, key, nonce, nil)
if err != nil {
t.Fatalf("EncryptAesCBC failed: %v", err)
t.Fatalf("EncryptAesGCM failed: %v", err)
}
if !bytes.Equal(encDefault, encCBC) {
t.Fatalf("default mode should match CBC mode")
if !bytes.Equal(encDefault, encGCM) {
t.Fatalf("default mode should match GCM mode")
}
}
func TestEncryptSM4DefaultModeGCM(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
plain := []byte("sm4-default-mode-gcm")
encDefault, err := EncryptSM4(plain, key, nonce, "", "")
if err != nil {
t.Fatalf("EncryptSM4 default failed: %v", err)
}
encGCM, err := EncryptSM4GCM(plain, key, nonce, nil)
if err != nil {
t.Fatalf("EncryptSM4GCM failed: %v", err)
}
if !bytes.Equal(encDefault, encGCM) {
t.Fatalf("default mode should match GCM mode")
}
}
@ -529,6 +547,62 @@ func TestChaCha20Poly1305RFCVector(t *testing.T) {
t.Fatalf("ChaCha20-Poly1305 vector mismatch: got %x want %x", enc, want)
}
}
func TestAesGCMStreamRoundTripChunked(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := bytes.Repeat([]byte("aes-gcm-stream-chunk-"), 10000)
enc := &bytes.Buffer{}
if err := EncryptAesGCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil {
t.Fatalf("EncryptAesGCMStream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptAesGCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil {
t.Fatalf("DecryptAesGCMStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("aes gcm stream mismatch")
}
}
func TestAesGCMStreamLegacyCompatDecrypt(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("aes-gcm-legacy-compat")
legacyCipher, err := EncryptAesGCM(plain, key, nonce, aad)
if err != nil {
t.Fatalf("EncryptAesGCM failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptAesGCMStream(dec, bytes.NewReader(legacyCipher), key, nonce, aad); err != nil {
t.Fatalf("DecryptAesGCMStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("aes gcm legacy decrypt mismatch")
}
}
func TestSM4GCMStreamRoundTripChunked(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := bytes.Repeat([]byte("sm4-gcm-stream-chunk-"), 10000)
enc := &bytes.Buffer{}
if err := EncryptSM4GCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil {
t.Fatalf("EncryptSM4GCMStream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptSM4GCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil {
t.Fatalf("DecryptSM4GCMStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("sm4 gcm stream mismatch")
}
}
func TestAesOptionsDefaultToGCM(t *testing.T) {
key := []byte("0123456789abcdef")
@ -598,6 +672,113 @@ func TestLargeStreamRoundTrip(t *testing.T) {
}
}
func TestAesCTRAtOffsetSegment(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := bytes.Repeat([]byte("aes-ctr-offset-"), 256)
full, err := EncryptAesCTR(plain, key, iv)
if err != nil {
t.Fatalf("EncryptAesCTR failed: %v", err)
}
offset := 137
length := 521
segCipher := full[offset : offset+length]
segPlain, err := DecryptAesCTRAt(segCipher, key, iv, int64(offset))
if err != nil {
t.Fatalf("DecryptAesCTRAt failed: %v", err)
}
if !bytes.Equal(segPlain, plain[offset:offset+length]) {
t.Fatalf("aes ctr offset decrypt mismatch")
}
encSeg, err := EncryptAesCTRAt(plain[offset:offset+length], key, iv, int64(offset))
if err != nil {
t.Fatalf("EncryptAesCTRAt failed: %v", err)
}
if !bytes.Equal(encSeg, segCipher) {
t.Fatalf("aes ctr offset encrypt mismatch")
}
}
func TestSM4CTRAtOffsetSegment(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := bytes.Repeat([]byte("sm4-ctr-offset-"), 256)
full, err := EncryptSM4CTR(plain, key, iv)
if err != nil {
t.Fatalf("EncryptSM4CTR failed: %v", err)
}
offset := 193
length := 487
segCipher := full[offset : offset+length]
segPlain, err := DecryptSM4CTRAt(segCipher, key, iv, int64(offset))
if err != nil {
t.Fatalf("DecryptSM4CTRAt failed: %v", err)
}
if !bytes.Equal(segPlain, plain[offset:offset+length]) {
t.Fatalf("sm4 ctr offset decrypt mismatch")
}
encSeg, err := EncryptSM4CTRAt(plain[offset:offset+length], key, iv, int64(offset))
if err != nil {
t.Fatalf("EncryptSM4CTRAt failed: %v", err)
}
if !bytes.Equal(encSeg, segCipher) {
t.Fatalf("sm4 ctr offset encrypt mismatch")
}
}
func TestAesGCMChunkRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("aes-gcm-chunk")
chunkIndex := uint64(7)
enc, err := EncryptAesGCMChunk(plain, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("EncryptAesGCMChunk failed: %v", err)
}
dec, err := DecryptAesGCMChunk(enc, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("DecryptAesGCMChunk failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("aes gcm chunk decrypt mismatch")
}
if _, err := DecryptAesGCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil {
t.Fatalf("expected decrypt error for wrong chunk index")
}
}
func TestSM4GCMChunkRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("sm4-gcm-chunk")
chunkIndex := uint64(11)
enc, err := EncryptSM4GCMChunk(plain, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("EncryptSM4GCMChunk failed: %v", err)
}
dec, err := DecryptSM4GCMChunk(enc, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("DecryptSM4GCMChunk failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("sm4 gcm chunk decrypt mismatch")
}
if _, err := DecryptSM4GCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil {
t.Fatalf("expected decrypt error for wrong chunk index")
}
}
func mustHex(t *testing.T, s string) []byte {
t.Helper()
b, err := hex.DecodeString(s)
@ -606,3 +787,553 @@ func mustHex(t *testing.T, s string) []byte {
}
return b
}
func TestAesCFB8RoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := []byte("aes-cfb8-roundtrip-content")
enc, err := EncryptAesCFB8(plain, key, iv)
if err != nil {
t.Fatalf("EncryptAesCFB8 failed: %v", err)
}
dec, err := DecryptAesCFB8(enc, key, iv)
if err != nil {
t.Fatalf("DecryptAesCFB8 failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("aes cfb8 mismatch")
}
}
func TestSM4CFB8RoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := []byte("sm4-cfb8-roundtrip-content")
enc, err := EncryptSM4CFB8(plain, key, iv)
if err != nil {
t.Fatalf("EncryptSM4CFB8 failed: %v", err)
}
dec, err := DecryptSM4CFB8(enc, key, iv)
if err != nil {
t.Fatalf("DecryptSM4CFB8 failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("sm4 cfb8 mismatch")
}
}
func TestAesCFB8StreamRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := bytes.Repeat([]byte("aes-cfb8-stream-"), 512)
enc := &bytes.Buffer{}
if err := EncryptAesCFB8Stream(enc, bytes.NewReader(plain), key, iv); err != nil {
t.Fatalf("EncryptAesCFB8Stream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptAesCFB8Stream(dec, bytes.NewReader(enc.Bytes()), key, iv); err != nil {
t.Fatalf("DecryptAesCFB8Stream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("aes cfb8 stream mismatch")
}
}
func TestSM4CFB8StreamRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := bytes.Repeat([]byte("sm4-cfb8-stream-"), 512)
enc := &bytes.Buffer{}
if err := EncryptSM4CFB8Stream(enc, bytes.NewReader(plain), key, iv); err != nil {
t.Fatalf("EncryptSM4CFB8Stream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptSM4CFB8Stream(dec, bytes.NewReader(enc.Bytes()), key, iv); err != nil {
t.Fatalf("DecryptSM4CFB8Stream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("sm4 cfb8 stream mismatch")
}
}
func TestAesSegmentDecryptModes(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 4)
ecbEnc, err := EncryptAesECB(plain, key, ZEROPADDING)
if err != nil {
t.Fatalf("EncryptAesECB failed: %v", err)
}
ecbDec, err := DecryptAesECBBlocks(ecbEnc, key)
if err != nil {
t.Fatalf("DecryptAesECBBlocks failed: %v", err)
}
if !bytes.Equal(ecbDec, plain) {
t.Fatalf("aes ecb segment mismatch")
}
cbcEnc, err := EncryptAesCBC(plain, key, iv, ZEROPADDING)
if err != nil {
t.Fatalf("EncryptAesCBC failed: %v", err)
}
cbcSegDec, err := DecryptAesCBCFromSecondBlock(cbcEnc[len(iv):], key, cbcEnc[:len(iv)])
if err != nil {
t.Fatalf("DecryptAesCBCFromSecondBlock failed: %v", err)
}
if !bytes.Equal(cbcSegDec, plain[len(iv):]) {
t.Fatalf("aes cbc from-second-block mismatch")
}
cfbEnc, err := EncryptAesCFB(plain, key, iv)
if err != nil {
t.Fatalf("EncryptAesCFB failed: %v", err)
}
cfbSegDec, err := DecryptAesCFBFromSecondBlock(cfbEnc[len(iv):], key, cfbEnc[:len(iv)])
if err != nil {
t.Fatalf("DecryptAesCFBFromSecondBlock failed: %v", err)
}
if !bytes.Equal(cfbSegDec, plain[len(iv):]) {
t.Fatalf("aes cfb from-second-block mismatch")
}
}
func TestSM4SegmentDecryptModes(t *testing.T) {
key := []byte("0123456789abcdef")
iv := []byte("abcdef9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 4)
ecbEnc, err := EncryptSM4ECB(plain, key, ZEROPADDING)
if err != nil {
t.Fatalf("EncryptSM4ECB failed: %v", err)
}
ecbDec, err := DecryptSM4ECBBlocks(ecbEnc, key)
if err != nil {
t.Fatalf("DecryptSM4ECBBlocks failed: %v", err)
}
if !bytes.Equal(ecbDec, plain) {
t.Fatalf("sm4 ecb segment mismatch")
}
cbcEnc, err := EncryptSM4CBC(plain, key, iv, ZEROPADDING)
if err != nil {
t.Fatalf("EncryptSM4CBC failed: %v", err)
}
cbcSegDec, err := DecryptSM4CBCFromSecondBlock(cbcEnc[len(iv):], key, cbcEnc[:len(iv)])
if err != nil {
t.Fatalf("DecryptSM4CBCFromSecondBlock failed: %v", err)
}
if !bytes.Equal(cbcSegDec, plain[len(iv):]) {
t.Fatalf("sm4 cbc from-second-block mismatch")
}
cfbEnc, err := EncryptSM4CFBNoBlock(plain, key, iv)
if err != nil {
t.Fatalf("EncryptSM4CFB failed: %v", err)
}
cfbSegDec, err := DecryptSM4CFBFromSecondBlock(cfbEnc[len(iv):], key, cfbEnc[:len(iv)])
if err != nil {
t.Fatalf("DecryptSM4CFBFromSecondBlock failed: %v", err)
}
if !bytes.Equal(cfbSegDec, plain[len(iv):]) {
t.Fatalf("sm4 cfb from-second-block mismatch")
}
}
func TestAesCCMRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("aes-ccm-roundtrip")
enc, err := EncryptAesCCM(plain, key, nonce, aad)
if err != nil {
t.Fatalf("EncryptAesCCM failed: %v", err)
}
dec, err := DecryptAesCCM(enc, key, nonce, aad)
if err != nil {
t.Fatalf("DecryptAesCCM failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("aes ccm mismatch")
}
}
func TestSM4CCMRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("sm4-ccm-roundtrip")
enc, err := EncryptSM4CCM(plain, key, nonce, aad)
if err != nil {
t.Fatalf("EncryptSM4CCM failed: %v", err)
}
dec, err := DecryptSM4CCM(enc, key, nonce, aad)
if err != nil {
t.Fatalf("DecryptSM4CCM failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("sm4 ccm mismatch")
}
}
func TestAesCCMStreamRoundTripChunked(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := bytes.Repeat([]byte("aes-ccm-stream-chunk-"), 10000)
enc := &bytes.Buffer{}
if err := EncryptAesCCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil {
t.Fatalf("EncryptAesCCMStream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptAesCCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil {
t.Fatalf("DecryptAesCCMStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("aes ccm stream mismatch")
}
}
func TestAesCCMStreamLegacyCompatDecrypt(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("aes-ccm-legacy-compat")
legacyCipher, err := EncryptAesCCM(plain, key, nonce, aad)
if err != nil {
t.Fatalf("EncryptAesCCM failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptAesCCMStream(dec, bytes.NewReader(legacyCipher), key, nonce, aad); err != nil {
t.Fatalf("DecryptAesCCMStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("aes ccm legacy decrypt mismatch")
}
}
func TestSM4CCMStreamRoundTripChunked(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := bytes.Repeat([]byte("sm4-ccm-stream-chunk-"), 10000)
enc := &bytes.Buffer{}
if err := EncryptSM4CCMStream(enc, bytes.NewReader(plain), key, nonce, aad); err != nil {
t.Fatalf("EncryptSM4CCMStream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptSM4CCMStream(dec, bytes.NewReader(enc.Bytes()), key, nonce, aad); err != nil {
t.Fatalf("DecryptSM4CCMStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("sm4 ccm stream mismatch")
}
}
func TestAesCCMChunkRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("aes-ccm-chunk")
chunkIndex := uint64(5)
enc, err := EncryptAesCCMChunk(plain, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("EncryptAesCCMChunk failed: %v", err)
}
dec, err := DecryptAesCCMChunk(enc, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("DecryptAesCCMChunk failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("aes ccm chunk decrypt mismatch")
}
if _, err := DecryptAesCCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil {
t.Fatalf("expected decrypt error for wrong chunk index")
}
}
func TestSM4CCMChunkRoundTrip(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("sm4-ccm-chunk")
chunkIndex := uint64(9)
enc, err := EncryptSM4CCMChunk(plain, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("EncryptSM4CCMChunk failed: %v", err)
}
dec, err := DecryptSM4CCMChunk(enc, key, nonce, aad, chunkIndex)
if err != nil {
t.Fatalf("DecryptSM4CCMChunk failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("sm4 ccm chunk decrypt mismatch")
}
if _, err := DecryptSM4CCMChunk(enc, key, nonce, aad, chunkIndex+1); err == nil {
t.Fatalf("expected decrypt error for wrong chunk index")
}
}
func TestAesOptionsModeCCM(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("aes-options-ccm")
opts := &CipherOptions{Mode: MODECCM, Nonce: nonce, AAD: aad}
enc, err := EncryptAesWithOptions(plain, key, opts)
if err != nil {
t.Fatalf("EncryptAesWithOptions CCM failed: %v", err)
}
dec, err := DecryptAesWithOptions(enc, key, opts)
if err != nil {
t.Fatalf("DecryptAesWithOptions CCM failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("aes options ccm mismatch")
}
}
func TestSM4OptionsModeCCM(t *testing.T) {
key := []byte("0123456789abcdef")
nonce := []byte("123456789012")
aad := []byte("aad")
plain := []byte("sm4-options-ccm")
opts := &CipherOptions{Mode: MODECCM, Nonce: nonce, AAD: aad}
enc, err := EncryptSM4WithOptions(plain, key, opts)
if err != nil {
t.Fatalf("EncryptSM4WithOptions CCM failed: %v", err)
}
dec, err := DecryptSM4WithOptions(enc, key, opts)
if err != nil {
t.Fatalf("DecryptSM4WithOptions CCM failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("sm4 options ccm mismatch")
}
}
func TestCCMInvalidNonceLength(t *testing.T) {
key := []byte("0123456789abcdef")
shortNonce := []byte("short")
if _, err := EncryptAesCCM([]byte("x"), key, shortNonce, nil); err == nil {
t.Fatalf("expected aes ccm nonce length error")
}
if _, err := EncryptSM4CCM([]byte("x"), key, shortNonce, nil); err == nil {
t.Fatalf("expected sm4 ccm nonce length error")
}
}
func TestAesXTSRoundTrip(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 8)
enc, err := EncryptAesXTS(plain, k1, k2, 32)
if err != nil {
t.Fatalf("EncryptAesXTS failed: %v", err)
}
dec, err := DecryptAesXTS(enc, k1, k2, 32)
if err != nil {
t.Fatalf("DecryptAesXTS failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("aes xts mismatch")
}
}
func TestSM4XTSRoundTrip(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 8)
enc, err := EncryptSM4XTS(plain, k1, k2, 32)
if err != nil {
t.Fatalf("EncryptSM4XTS failed: %v", err)
}
dec, err := DecryptSM4XTS(enc, k1, k2, 32)
if err != nil {
t.Fatalf("DecryptSM4XTS failed: %v", err)
}
if !bytes.Equal(dec, plain) {
t.Fatalf("sm4 xts mismatch")
}
}
func TestAesXTSAtDataUnit(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 8)
dataUnitSize := 32
full, err := EncryptAesXTS(plain, k1, k2, dataUnitSize)
if err != nil {
t.Fatalf("EncryptAesXTS failed: %v", err)
}
segPlain := plain[64:96]
segEnc, err := EncryptAesXTSAt(segPlain, k1, k2, dataUnitSize, 2)
if err != nil {
t.Fatalf("EncryptAesXTSAt failed: %v", err)
}
if !bytes.Equal(segEnc, full[64:96]) {
t.Fatalf("aes xts at encrypt mismatch")
}
segDec, err := DecryptAesXTSAt(full[64:96], k1, k2, dataUnitSize, 2)
if err != nil {
t.Fatalf("DecryptAesXTSAt failed: %v", err)
}
if !bytes.Equal(segDec, segPlain) {
t.Fatalf("aes xts at decrypt mismatch")
}
}
func TestSM4XTSAtDataUnit(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 8)
dataUnitSize := 32
full, err := EncryptSM4XTS(plain, k1, k2, dataUnitSize)
if err != nil {
t.Fatalf("EncryptSM4XTS failed: %v", err)
}
segPlain := plain[32:64]
segEnc, err := EncryptSM4XTSAt(segPlain, k1, k2, dataUnitSize, 1)
if err != nil {
t.Fatalf("EncryptSM4XTSAt failed: %v", err)
}
if !bytes.Equal(segEnc, full[32:64]) {
t.Fatalf("sm4 xts at encrypt mismatch")
}
segDec, err := DecryptSM4XTSAt(full[32:64], k1, k2, dataUnitSize, 1)
if err != nil {
t.Fatalf("DecryptSM4XTSAt failed: %v", err)
}
if !bytes.Equal(segDec, segPlain) {
t.Fatalf("sm4 xts at decrypt mismatch")
}
}
func TestAesXTSStreamRoundTrip(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 2048)
enc := &bytes.Buffer{}
if err := EncryptAesXTSStream(enc, bytes.NewReader(plain), k1, k2, 512); err != nil {
t.Fatalf("EncryptAesXTSStream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptAesXTSStream(dec, bytes.NewReader(enc.Bytes()), k1, k2, 512); err != nil {
t.Fatalf("DecryptAesXTSStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("aes xts stream mismatch")
}
}
func TestSM4XTSStreamRoundTrip(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 2048)
enc := &bytes.Buffer{}
if err := EncryptSM4XTSStream(enc, bytes.NewReader(plain), k1, k2, 512); err != nil {
t.Fatalf("EncryptSM4XTSStream failed: %v", err)
}
dec := &bytes.Buffer{}
if err := DecryptSM4XTSStream(dec, bytes.NewReader(enc.Bytes()), k1, k2, 512); err != nil {
t.Fatalf("DecryptSM4XTSStream failed: %v", err)
}
if !bytes.Equal(dec.Bytes(), plain) {
t.Fatalf("sm4 xts stream mismatch")
}
}
func TestXTSRejectsNonBlockMultiple(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
if _, err := EncryptAesXTS([]byte("short"), k1, k2, 32); err == nil {
t.Fatalf("expected aes xts non-block error")
}
if _, err := EncryptSM4XTS([]byte("short"), k1, k2, 32); err == nil {
t.Fatalf("expected sm4 xts non-block error")
}
}
func TestXTSStreamRejectsTailNotFullBlock(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
err := EncryptAesXTSStream(&bytes.Buffer{}, bytes.NewReader([]byte("tail-not-16")), k1, k2, 32)
if err == nil {
t.Fatalf("expected aes xts stream tail error")
}
err = EncryptSM4XTSStream(&bytes.Buffer{}, bytes.NewReader([]byte("tail-not-16")), k1, k2, 32)
if err == nil {
t.Fatalf("expected sm4 xts stream tail error")
}
}
func TestXTSInvalidDataUnitSize(t *testing.T) {
k1 := []byte("0123456789abcdef")
k2 := []byte("fedcba9876543210")
plain := bytes.Repeat([]byte("0123456789abcdef"), 2)
if _, err := EncryptAesXTS(plain, k1, k2, 30); err == nil {
t.Fatalf("expected aes xts invalid data unit size error")
}
if _, err := EncryptSM4XTS(plain, k1, k2, 30); err == nil {
t.Fatalf("expected sm4 xts invalid data unit size error")
}
}
func TestSplitXTSMasterKeyHelpers(t *testing.T) {
master := []byte("0123456789abcdef0123456789abcdef")
k1, k2, err := SplitXTSMasterKey(master)
if err != nil {
t.Fatalf("SplitXTSMasterKey failed: %v", err)
}
if len(k1) != 16 || len(k2) != 16 {
t.Fatalf("split key lengths mismatch")
}
if !bytes.Equal(append(k1, k2...), master) {
t.Fatalf("split key content mismatch")
}
aesK1, aesK2, err := SplitAesXTSMasterKey(master)
if err != nil {
t.Fatalf("SplitAesXTSMasterKey failed: %v", err)
}
if !bytes.Equal(aesK1, k1) || !bytes.Equal(aesK2, k2) {
t.Fatalf("aes split mismatch")
}
sm4K1, sm4K2, err := SplitSM4XTSMasterKey(master)
if err != nil {
t.Fatalf("SplitSM4XTSMasterKey failed: %v", err)
}
if !bytes.Equal(sm4K1, k1) || !bytes.Equal(sm4K2, k2) {
t.Fatalf("sm4 split mismatch")
}
if _, _, err := SplitXTSMasterKey([]byte("abc")); err == nil {
t.Fatalf("expected odd-length split error")
}
if _, _, err := SplitAesXTSMasterKey([]byte("0123456789abcdef0123456789abcdef01")); err == nil {
t.Fatalf("expected aes master length error")
}
if _, _, err := SplitSM4XTSMasterKey([]byte("0123456789abcdef")); err == nil {
t.Fatalf("expected sm4 master length error")
}
}

263
symm/xts.go Normal file
View File

@ -0,0 +1,263 @@
package symm
import (
"crypto/aes"
"crypto/cipher"
"errors"
"io"
"github.com/emmansun/gmsm/sm4"
"golang.org/x/crypto/xts"
)
const xtsBlockSize = 16
var (
ErrInvalidXTSDataUnitSize = errors.New("xts data unit size must be a positive multiple of 16")
ErrInvalidXTSDataLength = errors.New("xts data length must be a multiple of 16")
ErrInvalidXTSKeyLength = errors.New("xts key lengths must be non-empty and equal")
ErrInvalidXTSMasterKeyLength = errors.New("xts master key length must be non-empty and even")
ErrInvalidAESXTSMasterKeyLength = errors.New("aes xts master key length must be 32, 48, or 64 bytes")
ErrInvalidSM4XTSMasterKeyLength = errors.New("sm4 xts master key length must be 32 bytes")
)
type xtsCipherFactory func(key1, key2 []byte) (*xts.Cipher, error)
func combineXTSKeys(key1, key2 []byte) ([]byte, error) {
if len(key1) == 0 || len(key1) != len(key2) {
return nil, ErrInvalidXTSKeyLength
}
out := make([]byte, len(key1)+len(key2))
copy(out, key1)
copy(out[len(key1):], key2)
return out, nil
}
func splitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
if len(masterKey) == 0 || len(masterKey)%2 != 0 {
return nil, nil, ErrInvalidXTSMasterKeyLength
}
half := len(masterKey) / 2
k1 := make([]byte, half)
k2 := make([]byte, half)
copy(k1, masterKey[:half])
copy(k2, masterKey[half:])
return k1, k2, nil
}
// SplitXTSMasterKey splits a master key into two equal XTS keys.
func SplitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
return splitXTSMasterKey(masterKey)
}
// SplitAesXTSMasterKey splits AES-XTS master key and validates length (32/48/64 bytes).
func SplitAesXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
switch len(masterKey) {
case 32, 48, 64:
return splitXTSMasterKey(masterKey)
default:
return nil, nil, ErrInvalidAESXTSMasterKeyLength
}
}
// SplitSM4XTSMasterKey splits SM4-XTS master key and validates length (32 bytes).
func SplitSM4XTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
if len(masterKey) != 32 {
return nil, nil, ErrInvalidSM4XTSMasterKeyLength
}
return splitXTSMasterKey(masterKey)
}
func validateXTSDataUnitSize(dataUnitSize int) error {
if dataUnitSize <= 0 || dataUnitSize%xtsBlockSize != 0 {
return ErrInvalidXTSDataUnitSize
}
return nil
}
func validateXTSDataLength(data []byte) error {
if len(data)%xtsBlockSize != 0 {
return ErrInvalidXTSDataLength
}
return nil
}
func cryptXTSAt(c *xts.Cipher, in []byte, dataUnitSize int, dataUnitIndex uint64, decrypt bool) ([]byte, error) {
if err := validateXTSDataUnitSize(dataUnitSize); err != nil {
return nil, err
}
if err := validateXTSDataLength(in); err != nil {
return nil, err
}
if len(in) == 0 {
return []byte{}, nil
}
out := make([]byte, len(in))
off := 0
unit := dataUnitIndex
for off < len(in) {
chunkLen := dataUnitSize
if remain := len(in) - off; remain < chunkLen {
chunkLen = remain
}
if decrypt {
c.Decrypt(out[off:off+chunkLen], in[off:off+chunkLen], unit)
} else {
c.Encrypt(out[off:off+chunkLen], in[off:off+chunkLen], unit)
}
off += chunkLen
unit++
}
return out, nil
}
func cryptXTSStreamAt(dst io.Writer, src io.Reader, c *xts.Cipher, dataUnitSize int, dataUnitIndex uint64, decrypt bool) error {
if err := validateXTSDataUnitSize(dataUnitSize); err != nil {
return err
}
buf := make([]byte, 32*1024)
pending := make([]byte, 0, dataUnitSize*2)
unit := dataUnitIndex
for {
n, err := src.Read(buf)
if n > 0 {
pending = append(pending, buf[:n]...)
processLen := len(pending) / dataUnitSize * dataUnitSize
if processLen > 0 {
out := make([]byte, processLen)
for off := 0; off < processLen; off += dataUnitSize {
if decrypt {
c.Decrypt(out[off:off+dataUnitSize], pending[off:off+dataUnitSize], unit)
} else {
c.Encrypt(out[off:off+dataUnitSize], pending[off:off+dataUnitSize], unit)
}
unit++
}
if _, werr := dst.Write(out); werr != nil {
return werr
}
pending = append([]byte(nil), pending[processLen:]...)
}
}
if err != nil {
if err == io.EOF {
break
}
return err
}
}
if len(pending) == 0 {
return nil
}
if err := validateXTSDataLength(pending); err != nil {
return err
}
out := make([]byte, len(pending))
if decrypt {
c.Decrypt(out, pending, unit)
} else {
c.Encrypt(out, pending, unit)
}
_, err := dst.Write(out)
return err
}
func newXTSCipher(newBlock func([]byte) (cipher.Block, error), key1, key2 []byte) (*xts.Cipher, error) {
key, err := combineXTSKeys(key1, key2)
if err != nil {
return nil, err
}
return xts.NewCipher(newBlock, key)
}
func newAesXTS(key1, key2 []byte) (*xts.Cipher, error) {
return newXTSCipher(aes.NewCipher, key1, key2)
}
func newSM4XTS(key1, key2 []byte) (*xts.Cipher, error) {
return newXTSCipher(sm4.NewCipher, key1, key2)
}
func cryptXTSAtWithFactory(factory xtsCipherFactory, in []byte, dataUnitSize int, dataUnitIndex uint64, decrypt bool, key1, key2 []byte) ([]byte, error) {
c, err := factory(key1, key2)
if err != nil {
return nil, err
}
return cryptXTSAt(c, in, dataUnitSize, dataUnitIndex, decrypt)
}
func cryptXTSStreamAtWithFactory(factory xtsCipherFactory, dst io.Writer, src io.Reader, dataUnitSize int, dataUnitIndex uint64, decrypt bool, key1, key2 []byte) error {
c, err := factory(key1, key2)
if err != nil {
return err
}
return cryptXTSStreamAt(dst, src, c, dataUnitSize, dataUnitIndex, decrypt)
}
func EncryptAesXTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
return EncryptAesXTSAt(plain, key1, key2, dataUnitSize, 0)
}
func DecryptAesXTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
return DecryptAesXTSAt(ciphertext, key1, key2, dataUnitSize, 0)
}
func EncryptAesXTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
return cryptXTSAtWithFactory(newAesXTS, plain, dataUnitSize, dataUnitIndex, false, key1, key2)
}
func DecryptAesXTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
return cryptXTSAtWithFactory(newAesXTS, ciphertext, dataUnitSize, dataUnitIndex, true, key1, key2)
}
func EncryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
return EncryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
}
func DecryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
return DecryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
}
func EncryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
return cryptXTSStreamAtWithFactory(newAesXTS, dst, src, dataUnitSize, dataUnitIndex, false, key1, key2)
}
func DecryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
return cryptXTSStreamAtWithFactory(newAesXTS, dst, src, dataUnitSize, dataUnitIndex, true, key1, key2)
}
func EncryptSM4XTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
return EncryptSM4XTSAt(plain, key1, key2, dataUnitSize, 0)
}
func DecryptSM4XTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
return DecryptSM4XTSAt(ciphertext, key1, key2, dataUnitSize, 0)
}
func EncryptSM4XTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
return cryptXTSAtWithFactory(newSM4XTS, plain, dataUnitSize, dataUnitIndex, false, key1, key2)
}
func DecryptSM4XTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
return cryptXTSAtWithFactory(newSM4XTS, ciphertext, dataUnitSize, dataUnitIndex, true, key1, key2)
}
func EncryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
return EncryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
}
func DecryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
return DecryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
}
func EncryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
return cryptXTSStreamAtWithFactory(newSM4XTS, dst, src, dataUnitSize, dataUnitIndex, false, key1, key2)
}
func DecryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
return cryptXTSStreamAtWithFactory(newSM4XTS, dst, src, dataUnitSize, dataUnitIndex, true, key1, key2)
}

79
symm/xts_vector_test.go Normal file
View File

@ -0,0 +1,79 @@
package symm
import (
"bytes"
"testing"
)
// AES-XTS vectors from IEEE P1619/D16 Annex B (same set used by golang.org/x/crypto/xts tests).
var aesXTSStandardVectors = []struct {
key string
dataUnitIndex uint64
plaintext string
ciphertext string
}{
{
key: "0000000000000000000000000000000000000000000000000000000000000000",
dataUnitIndex: 0,
plaintext: "0000000000000000000000000000000000000000000000000000000000000000",
ciphertext: "917cf69ebd68b2ec9b9fe9a3eadda692cd43d2f59598ed858c02c2652fbf922e",
},
{
key: "1111111111111111111111111111111122222222222222222222222222222222",
dataUnitIndex: 0x3333333333,
plaintext: "4444444444444444444444444444444444444444444444444444444444444444",
ciphertext: "c454185e6a16936e39334038acef838bfb186fff7480adc4289382ecd6d394f0",
},
{
key: "fffefdfcfbfaf9f8f7f6f5f4f3f2f1f022222222222222222222222222222222",
dataUnitIndex: 0x3333333333,
plaintext: "4444444444444444444444444444444444444444444444444444444444444444",
ciphertext: "af85336b597afc1a900b2eb21ec949d292df4c047e0b21532186a5971a227a89",
},
}
func TestAesXTSStandardVectors(t *testing.T) {
for i, tc := range aesXTSStandardVectors {
master := mustHex(t, tc.key)
key1, key2, err := SplitAesXTSMasterKey(master)
if err != nil {
t.Fatalf("#%d split key failed: %v", i, err)
}
plain := mustHex(t, tc.plaintext)
wantCipher := mustHex(t, tc.ciphertext)
dataUnitSize := len(plain)
gotCipher, err := EncryptAesXTSAt(plain, key1, key2, dataUnitSize, tc.dataUnitIndex)
if err != nil {
t.Fatalf("#%d EncryptAesXTSAt failed: %v", i, err)
}
if !bytes.Equal(gotCipher, wantCipher) {
t.Fatalf("#%d ciphertext mismatch", i)
}
gotPlain, err := DecryptAesXTSAt(wantCipher, key1, key2, dataUnitSize, tc.dataUnitIndex)
if err != nil {
t.Fatalf("#%d DecryptAesXTSAt failed: %v", i, err)
}
if !bytes.Equal(gotPlain, plain) {
t.Fatalf("#%d plaintext mismatch", i)
}
encStream := &bytes.Buffer{}
if err := EncryptAesXTSStreamAt(encStream, bytes.NewReader(plain), key1, key2, dataUnitSize, tc.dataUnitIndex); err != nil {
t.Fatalf("#%d EncryptAesXTSStreamAt failed: %v", i, err)
}
if !bytes.Equal(encStream.Bytes(), wantCipher) {
t.Fatalf("#%d stream ciphertext mismatch", i)
}
decStream := &bytes.Buffer{}
if err := DecryptAesXTSStreamAt(decStream, bytes.NewReader(wantCipher), key1, key2, dataUnitSize, tc.dataUnitIndex); err != nil {
t.Fatalf("#%d DecryptAesXTSStreamAt failed: %v", i, err)
}
if !bytes.Equal(decStream.Bytes(), plain) {
t.Fatalf("#%d stream plaintext mismatch", i)
}
}
}

15
xts_keysplit.go Normal file
View File

@ -0,0 +1,15 @@
package starcrypto
import "b612.me/starcrypto/symm"
func SplitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
return symm.SplitXTSMasterKey(masterKey)
}
func SplitAesXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
return symm.SplitAesXTSMasterKey(masterKey)
}
func SplitSM4XTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
return symm.SplitSM4XTSMasterKey(masterKey)
}