diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..357532a --- /dev/null +++ b/.gitignore @@ -0,0 +1,29 @@ +# IDE / editor +.idea/ +.vscode/ + +# Local Go module/cache sandboxes used in this repo +.gopath/ +.gomodcache/ + +# Build artifacts +*.exe +*.dll +*.so +*.dylib +*.test +*.out +*.prof +bin/ +dist/ + +# Coverage +coverage.out +*.coverprofile + +# OS files +.DS_Store +Thumbs.db + +# Temp/test artifacts +sm3/ifile diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..ac22290 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,34 @@ +# Changelog + +## Unreleased + +### Added +- Introduced subpackages and root wrappers: + - `asymm`, `symm`, `hashx`, `encodingx`, `paddingx`, `filex`, `legacy`, `macx`. +- Added Chinese `README.md` and Apache-2.0 `LICENSE`. +- Added `SM9` support in asymmetric APIs. +- Added `ChaCha20` and `ChaCha20-Poly1305` APIs (memory + stream wrappers where applicable). +- Added unified symmetric cipher options API: + - `CipherOptions{Mode, Padding, IV, Nonce, AAD}`. +- Added AEAD APIs and wrappers: + - `AES-GCM`, `SM4-GCM` (bytes + stream helper APIs). +- Added more symmetric mode coverage for SM4: + - `ECB/CBC/CFB/OFB/CTR` (bytes + stream derived APIs). +- Added comprehensive tests across packages and root wrappers. +- Added fuzz tests for `paddingx`, `encodingx`, and `symm` round-trip invariants. + +### Changed +- Refactored monolithic implementation to subpackage architecture while preserving root-package convenience APIs. +- AES mode APIs now support generic mode selection and derived mode helpers. +- Stream APIs expanded across AES/SM4/DES/3DES and ChaCha20. +- Updated docs to include a security-first recommendation and algorithm capability matrix. +- Updated dependencies and modules for current code paths (`gmsm`, `x/crypto`). + +### Fixed +- Fixed Base128 encode/decode round-trip bug in `encodingx`. +- Corrected CRC32A test expectations and clarified CRC32A variant comments. +- Corrected default padding behavior for AES-CBC to PKCS7. + +### Notes +- Legacy/insecure algorithms and modes remain available for compatibility. +- Production recommendations now explicitly prefer AEAD schemes. diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..47182e3 --- /dev/null +++ b/README.md @@ -0,0 +1,119 @@ +# starcrypto + +`starcrypto` 是一个 Go 单包聚合风格的密码学工具库:根包可直接调用,同时内部拆分为子包实现,兼顾易用性与可维护性。 + +## 特性 + +- 根包直调 + 子包分层(`asymm` / `symm` / `hashx` / `encodingx` / `paddingx` 等) +- 对称算法:AES、SM4、DES、3DES、ChaCha20、ChaCha20-Poly1305 +- 非对称算法:RSA、ECDSA、SM2、SM9 +- 哈希与消息认证:SHA 系列、MD5/MD4、RIPEMD160、SM3、CRC32/CRC32A、HMAC +- 编解码:Base64、Base85、Base91、Base128 +- 支持内存 `[]byte` 与流式 `io.Reader/io.Writer`(见下方能力矩阵) + +## 安装 + +```bash +go get b612.me/starcrypto +``` + +## 推荐用法(安全优先) + +优先使用带认证的 AEAD: + +- `EncryptAesGCM/DecryptAesGCM` +- `EncryptSM4GCM/DecryptSM4GCM` +- `EncryptChaCha20Poly1305/DecryptChaCha20Poly1305` + +或使用统一选项接口(默认 `GCM`): + +- `EncryptAesWithOptions/DecryptAesWithOptions` +- `EncryptSM4WithOptions/DecryptSM4WithOptions` + +> `CBC/CFB/OFB/CTR/ECB` 仅提供机密性,不提供完整性校验,可能被篡改后无法检测。 + +## 快速示例 + +### 统一 Options(默认 GCM) + +```go +package main + +import ( + "fmt" + "log" + + "b612.me/starcrypto" +) + +func main() { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + plain := []byte("hello starcrypto") + + opts := &starcrypto.CipherOptions{Nonce: nonce} + enc, err := starcrypto.EncryptAesWithOptions(plain, key, opts) + if err != nil { + log.Fatal(err) + } + dec, err := starcrypto.DecryptAesWithOptions(enc, key, opts) + if err != nil { + log.Fatal(err) + } + + fmt.Println(string(dec)) +} +``` + +### AES CBC 流式接口(兼容模式) + +```go +package main + +import ( + "bytes" + "log" + + "b612.me/starcrypto" +) + +func main() { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + + src := bytes.NewReader([]byte("stream content")) + encBuf := &bytes.Buffer{} + decBuf := &bytes.Buffer{} + + if err := starcrypto.EncryptAesCBCStream(encBuf, src, key, iv, ""); err != nil { + log.Fatal(err) + } + if err := starcrypto.DecryptAesCBCStream(decBuf, bytes.NewReader(encBuf.Bytes()), key, iv, ""); err != nil { + log.Fatal(err) + } +} +``` + +## 算法能力矩阵 + +| 算法 | 模式 / 方案 | AEAD | Stream | 建议 | +|---|---|---:|---:|---| +| AES | ECB/CBC/CFB/OFB/CTR/GCM | GCM: 是 | 是 | 生产优先 GCM | +| SM4 | ECB/CBC/CFB/OFB/CTR/GCM | GCM: 是 | 是 | 生产优先 GCM | +| ChaCha20 | ChaCha20 / ChaCha20-Poly1305 | Poly1305: 是 | ChaCha20: 是 | 生产优先 ChaCha20-Poly1305 | +| DES | CBC | 否 | 是 | 仅兼容历史系统 | +| 3DES | CBC | 否 | 是 | 仅兼容历史系统 | + +## 默认填充策略 + +- AES/SM4 的 CBC/ECB 默认:`PKCS7` +- DES/3DES 的 CBC 默认:`PKCS5` +- CFB/OFB/CTR/GCM/ChaCha20 不使用填充 + +## 兼容性说明 + +库中保留了部分历史/兼容用途算法与接口(例如 `ECB`、`DES/3DES`)。如无兼容要求,建议使用 AEAD 方案并统一通过 `CipherOptions` 管理参数。 + +## 许可证 + +本项目使用 Apache-2.0 许可证,详见 `LICENSE`。 diff --git a/aes.go b/aes.go index 4dab0d0..3464ab3 100644 --- a/aes.go +++ b/aes.go @@ -1,122 +1,183 @@ package starcrypto import ( - "bytes" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "errors" "io" + + "b612.me/starcrypto/symm" ) +type CipherOptions = symm.CipherOptions + const ( - PKCS5PADDING = "PKCS5" + PKCS5PADDING = symm.PKCS5PADDING + PKCS7PADDING = symm.PKCS7PADDING + ZEROPADDING = symm.ZEROPADDING + ANSIX923PADDING = symm.ANSIX923PADDING + + MODEECB = symm.MODEECB + MODECBC = symm.MODECBC + MODECFB = symm.MODECFB + MODEOFB = symm.MODEOFB + MODECTR = symm.MODECTR + MODEGCM = symm.MODEGCM ) -func CustomEncryptAesCFB(origData []byte, key []byte) ([]byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - encrypted := make([]byte, aes.BlockSize+len(origData)) - iv := encrypted[:aes.BlockSize] - if _, err := io.ReadFull(rand.Reader, iv); err != nil { - return nil, err - } - stream := cipher.NewCFBEncrypter(block, iv) - stream.XORKeyStream(encrypted[aes.BlockSize:], origData) - return encrypted, nil +func EncryptAes(data, key, iv []byte, mode, paddingType string) ([]byte, error) { + return symm.EncryptAes(data, key, iv, mode, paddingType) } -func CustomDecryptAesCFB(encrypted []byte, key []byte) ([]byte, error) { - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - if len(encrypted) < aes.BlockSize { - return nil, errors.New("ciphertext too short") - } - iv := encrypted[:aes.BlockSize] - encrypted = encrypted[aes.BlockSize:] - - stream := cipher.NewCFBDecrypter(block, iv) - stream.XORKeyStream(encrypted, encrypted) - return encrypted, nil +func DecryptAes(src, key, iv []byte, mode, paddingType string) ([]byte, error) { + return symm.DecryptAes(src, key, iv, mode, paddingType) } -func CustomEncryptAesCFBNoBlock(origData []byte, key []byte, iv []byte) ([]byte, error) { - if len(iv) != 16 { - return nil, errors.New("iv length must be 16") - } - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - encrypted := make([]byte, len(origData)) - stream := cipher.NewCFBEncrypter(block, iv) - stream.XORKeyStream(encrypted, origData) - return encrypted, err +func EncryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { + return symm.EncryptAesStream(dst, src, key, iv, mode, paddingType) } -func CustomDecryptAesCFBNoBlock(encrypted []byte, key []byte, iv []byte) ([]byte, error) { - if len(iv) != 16 { - return nil, errors.New("iv length must be 16") - } - block, err := aes.NewCipher(key) - if err != nil { - return nil, err - } - stream := cipher.NewCFBDecrypter(block, iv) - stream.XORKeyStream(encrypted, encrypted) - return encrypted, err +func DecryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { + return symm.DecryptAesStream(dst, src, key, iv, mode, paddingType) +} + +func EncryptAesWithOptions(data, key []byte, opts *CipherOptions) ([]byte, error) { + return symm.EncryptAesWithOptions(data, key, opts) +} + +func DecryptAesWithOptions(src, key []byte, opts *CipherOptions) ([]byte, error) { + return symm.DecryptAesWithOptions(src, key, opts) +} + +func EncryptAesStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts *CipherOptions) error { + return symm.EncryptAesStreamWithOptions(dst, src, key, opts) +} + +func DecryptAesStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts *CipherOptions) error { + return symm.DecryptAesStreamWithOptions(dst, src, key, opts) +} + +func EncryptAesGCM(plain, key, nonce, aad []byte) ([]byte, error) { + return symm.EncryptAesGCM(plain, key, nonce, aad) +} + +func DecryptAesGCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { + return symm.DecryptAesGCM(ciphertext, key, nonce, aad) +} + +func EncryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + return symm.EncryptAesGCMStream(dst, src, key, nonce, aad) +} + +func DecryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + return symm.DecryptAesGCMStream(dst, src, key, nonce, aad) +} + +func EncryptAesECB(data, key []byte, paddingType string) ([]byte, error) { + return symm.EncryptAesECB(data, key, paddingType) +} + +func DecryptAesECB(src, key []byte, paddingType string) ([]byte, error) { + return symm.DecryptAesECB(src, key, paddingType) } func EncryptAesCBC(data, key []byte, iv []byte, paddingType string) ([]byte, error) { - var content []byte - aesBlockEncrypter, err := aes.NewCipher(key) - switch paddingType { - case PKCS5PADDING: - content = PKCS5Padding(data, aesBlockEncrypter.BlockSize()) - default: - return nil, errors.New("padding type not supported") - } - encrypted := make([]byte, len(content)) - if err != nil { - return nil, err - } - aesEncrypter := cipher.NewCBCEncrypter(aesBlockEncrypter, iv) - aesEncrypter.CryptBlocks(encrypted, content) - return encrypted, nil -} - -func PKCS5Padding(cipherText []byte, blockSize int) []byte { - padding := blockSize - len(cipherText)%blockSize - padText := bytes.Repeat([]byte{byte(padding)}, padding) - return append(cipherText, padText...) -} - -func PKCS5Trimming(encrypt []byte) []byte { - padding := encrypt[len(encrypt)-1] - if len(encrypt)-int(padding) < 0 { - return nil - } - return encrypt[:len(encrypt)-int(padding)] + return symm.EncryptAesCBC(data, key, iv, paddingType) } func DecryptAesCBC(src, key []byte, iv []byte, paddingType string) (data []byte, err error) { - decrypted := make([]byte, len(src)) - var aesBlockDecrypter cipher.Block - aesBlockDecrypter, err = aes.NewCipher(key) - if err != nil { - println(err.Error()) - return nil, err - } - aesDecrypter := cipher.NewCBCDecrypter(aesBlockDecrypter, iv) - aesDecrypter.CryptBlocks(decrypted, src) - switch paddingType { - case PKCS5PADDING: - return PKCS5Trimming(decrypted), nil - default: - return nil, errors.New("padding type not supported") - } + return symm.DecryptAesCBC(src, key, iv, paddingType) +} + +func EncryptAesCFB(data, key, iv []byte) ([]byte, error) { + return symm.EncryptAesCFB(data, key, iv) +} + +func DecryptAesCFB(src, key, iv []byte) ([]byte, error) { + return symm.DecryptAesCFB(src, key, iv) +} + +func EncryptAesOFB(data, key, iv []byte) ([]byte, error) { + return symm.EncryptAesOFB(data, key, iv) +} + +func DecryptAesOFB(src, key, iv []byte) ([]byte, error) { + return symm.DecryptAesOFB(src, key, iv) +} + +func EncryptAesCTR(data, key, iv []byte) ([]byte, error) { + return symm.EncryptAesCTR(data, key, iv) +} + +func DecryptAesCTR(src, key, iv []byte) ([]byte, error) { + return symm.DecryptAesCTR(src, key, iv) +} + +func EncryptAesECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error { + return symm.EncryptAesECBStream(dst, src, key, paddingType) +} + +func DecryptAesECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error { + return symm.DecryptAesECBStream(dst, src, key, paddingType) +} + +func EncryptAesCBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + return symm.EncryptAesCBCStream(dst, src, key, iv, paddingType) +} + +func DecryptAesCBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + return symm.DecryptAesCBCStream(dst, src, key, iv, paddingType) +} + +func EncryptAesCFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.EncryptAesCFBStream(dst, src, key, iv) +} + +func DecryptAesCFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.DecryptAesCFBStream(dst, src, key, iv) +} + +func EncryptAesOFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.EncryptAesOFBStream(dst, src, key, iv) +} + +func DecryptAesOFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.DecryptAesOFBStream(dst, src, key, iv) +} + +func EncryptAesCTRStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.EncryptAesCTRStream(dst, src, key, iv) +} + +func DecryptAesCTRStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.DecryptAesCTRStream(dst, src, key, iv) +} + +func CustomEncryptAesCFB(origData []byte, key []byte) ([]byte, error) { + return symm.CustomEncryptAesCFB(origData, key) +} + +func CustomDecryptAesCFB(encrypted []byte, key []byte) ([]byte, error) { + return symm.CustomDecryptAesCFB(encrypted, key) +} + +func CustomEncryptAesCFBNoBlock(origData []byte, key []byte, iv []byte) ([]byte, error) { + return symm.CustomEncryptAesCFBNoBlock(origData, key, iv) +} + +func CustomDecryptAesCFBNoBlock(encrypted []byte, key []byte, iv []byte) ([]byte, error) { + return symm.CustomDecryptAesCFBNoBlock(encrypted, key, iv) +} + +func PKCS5Padding(cipherText []byte, blockSize int) []byte { + return symm.PKCS5Padding(cipherText, blockSize) +} + +func PKCS5Trimming(encrypt []byte) []byte { + return symm.PKCS5Trimming(encrypt) +} + +func PKCS7Padding(cipherText []byte, blockSize int) []byte { + return symm.PKCS7Padding(cipherText, blockSize) +} + +func PKCS7Trimming(encrypt []byte, blockSize int) []byte { + return symm.PKCS7Trimming(encrypt, blockSize) } diff --git a/api_test.go b/api_test.go new file mode 100644 index 0000000..62d11e5 --- /dev/null +++ b/api_test.go @@ -0,0 +1,265 @@ +package starcrypto + +import ( + "bytes" + "testing" +) + +func TestRootSymmetricWrappers(t *testing.T) { + aesKey := []byte("0123456789abcdef") + aesIV := []byte("abcdef9876543210") + plain := []byte("root-wrapper-aes") + + aesEnc, err := EncryptAesCBC(plain, aesKey, aesIV, "") + if err != nil { + t.Fatalf("EncryptAesCBC failed: %v", err) + } + aesDec, err := DecryptAesCBC(aesEnc, aesKey, aesIV, "") + if err != nil { + t.Fatalf("DecryptAesCBC failed: %v", err) + } + if !bytes.Equal(aesDec, plain) { + t.Fatalf("aes wrapper mismatch") + } + + aesCFBEnc, err := EncryptAes(plain, aesKey, aesIV, MODECFB, "") + if err != nil { + t.Fatalf("EncryptAes failed: %v", err) + } + aesCFBDec, err := DecryptAes(aesCFBEnc, aesKey, aesIV, MODECFB, "") + if err != nil { + t.Fatalf("DecryptAes failed: %v", err) + } + if !bytes.Equal(aesCFBDec, plain) { + t.Fatalf("aes generic wrapper mismatch") + } + + aesStreamEnc := &bytes.Buffer{} + if err := EncryptAesCBCStream(aesStreamEnc, bytes.NewReader(plain), aesKey, aesIV, ""); err != nil { + t.Fatalf("EncryptAesCBCStream failed: %v", err) + } + aesStreamDec := &bytes.Buffer{} + if err := DecryptAesCBCStream(aesStreamDec, bytes.NewReader(aesStreamEnc.Bytes()), aesKey, aesIV, ""); err != nil { + t.Fatalf("DecryptAesCBCStream failed: %v", err) + } + if !bytes.Equal(aesStreamDec.Bytes(), plain) { + t.Fatalf("aes stream wrapper mismatch") + } + + sm4Enc, err := EncryptSM4CBC(plain, aesKey, aesIV, "") + if err != nil { + t.Fatalf("EncryptSM4CBC failed: %v", err) + } + sm4Dec, err := DecryptSM4CBC(sm4Enc, aesKey, aesIV, "") + if err != nil { + t.Fatalf("DecryptSM4CBC failed: %v", err) + } + if !bytes.Equal(sm4Dec, plain) { + t.Fatalf("sm4 wrapper mismatch") + } + + sm4StreamEnc := &bytes.Buffer{} + if err := EncryptSM4CBCStream(sm4StreamEnc, bytes.NewReader(plain), aesKey, aesIV, ""); err != nil { + t.Fatalf("EncryptSM4CBCStream failed: %v", err) + } + sm4StreamDec := &bytes.Buffer{} + if err := DecryptSM4CBCStream(sm4StreamDec, bytes.NewReader(sm4StreamEnc.Bytes()), aesKey, aesIV, ""); err != nil { + t.Fatalf("DecryptSM4CBCStream failed: %v", err) + } + if !bytes.Equal(sm4StreamDec.Bytes(), plain) { + t.Fatalf("sm4 stream wrapper mismatch") + } + + desKey := []byte("12345678") + desIV := []byte("abcdefgh") + desEnc, err := EncryptDESCBC(plain, desKey, desIV, "") + if err != nil { + t.Fatalf("EncryptDESCBC failed: %v", err) + } + desDec, err := DecryptDESCBC(desEnc, desKey, desIV, "") + if err != nil { + t.Fatalf("DecryptDESCBC failed: %v", err) + } + if !bytes.Equal(desDec, plain) { + t.Fatalf("des wrapper mismatch") + } + + desStreamEnc := &bytes.Buffer{} + if err := EncryptDESCBCStream(desStreamEnc, bytes.NewReader(plain), desKey, desIV, ""); err != nil { + t.Fatalf("EncryptDESCBCStream failed: %v", err) + } + desStreamDec := &bytes.Buffer{} + if err := DecryptDESCBCStream(desStreamDec, bytes.NewReader(desStreamEnc.Bytes()), desKey, desIV, ""); err != nil { + t.Fatalf("DecryptDESCBCStream failed: %v", err) + } + if !bytes.Equal(desStreamDec.Bytes(), plain) { + t.Fatalf("des stream wrapper mismatch") + } +} + +func TestRootSM2AndSM9Wrappers(t *testing.T) { + sm2Priv, sm2Pub, err := GenerateSM2Key() + if err != nil { + t.Fatalf("GenerateSM2Key failed: %v", err) + } + msg := []byte("root-sm") + sig, err := SM2Sign(sm2Priv, msg, nil) + if err != nil { + t.Fatalf("SM2Sign failed: %v", err) + } + if !SM2Verify(sm2Pub, msg, sig, nil) { + t.Fatalf("SM2Verify failed") + } + + signMaster, signPub, err := GenerateSM9SignMasterKey() + if err != nil { + t.Fatalf("GenerateSM9SignMasterKey failed: %v", err) + } + encryptMaster, encryptPub, err := GenerateSM9EncryptMasterKey() + if err != nil { + t.Fatalf("GenerateSM9EncryptMasterKey failed: %v", err) + } + uid := []byte("root@example.com") + signUser, err := GenerateSM9SignUserKey(signMaster, uid, 0) + if err != nil { + t.Fatalf("GenerateSM9SignUserKey failed: %v", err) + } + encryptUser, err := GenerateSM9EncryptUserKey(encryptMaster, uid, 0) + if err != nil { + t.Fatalf("GenerateSM9EncryptUserKey failed: %v", err) + } + + sm9Sig, err := SM9SignASN1(signUser, msg) + if err != nil { + t.Fatalf("SM9SignASN1 failed: %v", err) + } + if !SM9VerifyASN1(signPub, uid, 0, msg, sm9Sig) { + t.Fatalf("SM9VerifyASN1 failed") + } + + cipher, err := SM9Encrypt(encryptPub, uid, 0, msg) + if err != nil { + t.Fatalf("SM9Encrypt failed: %v", err) + } + dec, err := SM9Decrypt(encryptUser, uid, cipher) + if err != nil { + t.Fatalf("SM9Decrypt failed: %v", err) + } + if !bytes.Equal(dec, msg) { + t.Fatalf("SM9 wrapper decrypt mismatch") + } +} + +func TestRootChaChaAndSM4DerivedWrappers(t *testing.T) { + key := []byte("0123456789abcdef0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := []byte("root-chacha-wrapper") + + chachaEnc, err := EncryptChaCha20(plain, key, nonce) + if err != nil { + t.Fatalf("EncryptChaCha20 failed: %v", err) + } + chachaDec, err := DecryptChaCha20(chachaEnc, key, nonce) + if err != nil { + t.Fatalf("DecryptChaCha20 failed: %v", err) + } + if !bytes.Equal(chachaDec, plain) { + t.Fatalf("chacha wrapper mismatch") + } + + polyEnc, err := EncryptChaCha20Poly1305(plain, key, nonce, aad) + if err != nil { + t.Fatalf("EncryptChaCha20Poly1305 failed: %v", err) + } + polyDec, err := DecryptChaCha20Poly1305(polyEnc, key, nonce, aad) + if err != nil { + t.Fatalf("DecryptChaCha20Poly1305 failed: %v", err) + } + if !bytes.Equal(polyDec, plain) { + t.Fatalf("chacha20-poly1305 wrapper mismatch") + } + + sm4Key := []byte("0123456789abcdef") + sm4IV := []byte("abcdef9876543210") + sm4Plain := []byte("root-sm4-derived") + + ecbEnc, err := EncryptSM4ECB(sm4Plain, sm4Key, "") + if err != nil { + t.Fatalf("EncryptSM4ECB failed: %v", err) + } + ecbDec, err := DecryptSM4ECB(ecbEnc, sm4Key, "") + if err != nil { + t.Fatalf("DecryptSM4ECB failed: %v", err) + } + if !bytes.Equal(ecbDec, sm4Plain) { + t.Fatalf("sm4 ecb wrapper mismatch") + } + + ofbEnc, err := EncryptSM4OFB(sm4Plain, sm4Key, sm4IV) + if err != nil { + t.Fatalf("EncryptSM4OFB failed: %v", err) + } + ofbDec, err := DecryptSM4OFB(ofbEnc, sm4Key, sm4IV) + if err != nil { + t.Fatalf("DecryptSM4OFB failed: %v", err) + } + if !bytes.Equal(ofbDec, sm4Plain) { + t.Fatalf("sm4 ofb wrapper mismatch") + } + + ctrEnc, err := EncryptSM4CTR(sm4Plain, sm4Key, sm4IV) + if err != nil { + t.Fatalf("EncryptSM4CTR failed: %v", err) + } + ctrDec, err := DecryptSM4CTR(ctrEnc, sm4Key, sm4IV) + if err != nil { + t.Fatalf("DecryptSM4CTR failed: %v", err) + } + if !bytes.Equal(ctrDec, sm4Plain) { + t.Fatalf("sm4 ctr wrapper mismatch") + } + + ecbStreamEnc := &bytes.Buffer{} + if err := EncryptSM4ECBStream(ecbStreamEnc, bytes.NewReader(sm4Plain), sm4Key, ""); err != nil { + t.Fatalf("EncryptSM4ECBStream failed: %v", err) + } + ecbStreamDec := &bytes.Buffer{} + if err := DecryptSM4ECBStream(ecbStreamDec, bytes.NewReader(ecbStreamEnc.Bytes()), sm4Key, ""); err != nil { + t.Fatalf("DecryptSM4ECBStream failed: %v", err) + } + if !bytes.Equal(ecbStreamDec.Bytes(), sm4Plain) { + t.Fatalf("sm4 ecb stream wrapper mismatch") + } +} + +func TestRootOptionsAndGCMWrappers(t *testing.T) { + aesKey := []byte("0123456789abcdef") + sm4Key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + plain := []byte("root-options-gcm") + + aesEnc, err := EncryptAesGCM(plain, aesKey, nonce, []byte("aad")) + if err != nil { + t.Fatalf("EncryptAesGCM failed: %v", err) + } + aesDec, err := DecryptAesGCM(aesEnc, aesKey, nonce, []byte("aad")) + if err != nil { + t.Fatalf("DecryptAesGCM failed: %v", err) + } + if !bytes.Equal(aesDec, plain) { + t.Fatalf("aes gcm wrapper mismatch") + } + + sm4Enc, err := EncryptSM4WithOptions(plain, sm4Key, &CipherOptions{Nonce: nonce}) + if err != nil { + t.Fatalf("EncryptSM4WithOptions failed: %v", err) + } + sm4Dec, err := DecryptSM4WithOptions(sm4Enc, sm4Key, &CipherOptions{Nonce: nonce}) + if err != nil { + t.Fatalf("DecryptSM4WithOptions failed: %v", err) + } + if !bytes.Equal(sm4Dec, plain) { + t.Fatalf("sm4 options wrapper mismatch") + } +} diff --git a/asy.go b/asy.go index af2848c..d5180e4 100644 --- a/asy.go +++ b/asy.go @@ -1,149 +1,34 @@ package starcrypto import ( + "b612.me/starcrypto/asymm" "crypto" - "crypto/ecdsa" - "crypto/ed25519" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "errors" - "golang.org/x/crypto/ssh" - "reflect" ) func EncodePrivateKey(private crypto.PrivateKey, secret string) ([]byte, error) { - switch private.(type) { - case *rsa.PrivateKey: - return EncodeRsaPrivateKey(private.(*rsa.PrivateKey), secret) - case *ecdsa.PrivateKey: - return EncodeEcdsaPrivateKey(private.(*ecdsa.PrivateKey), secret) - default: - b, err := x509.MarshalPKCS8PrivateKey(private) - if err != nil { - return nil, err - } - if secret == "" { - return pem.EncodeToMemory(&pem.Block{ - Bytes: b, - Type: "PRIVATE KEY", - }), err - } - chiper := x509.PEMCipherAES256 - blk, err := x509.EncryptPEMBlock(rand.Reader, "PRIVATE KEY", b, []byte(secret), chiper) - if err != nil { - return nil, err - } - return pem.EncodeToMemory(blk), err - } + return asymm.EncodePrivateKey(private, secret) } func EncodePublicKey(public crypto.PublicKey) ([]byte, error) { - switch public.(type) { - case *rsa.PublicKey: - return EncodeRsaPublicKey(public.(*rsa.PublicKey)) - case *ecdsa.PublicKey: - return EncodeEcdsaPublicKey(public.(*ecdsa.PublicKey)) - default: - publicBytes, err := x509.MarshalPKIXPublicKey(public) - if err != nil { - return nil, err - } - return pem.EncodeToMemory(&pem.Block{ - Bytes: publicBytes, - Type: "PUBLIC KEY", - }), nil - } + return asymm.EncodePublicKey(public) } func DecodePrivateKey(private []byte, password string) (crypto.PrivateKey, error) { - blk, _ := pem.Decode(private) - if blk == nil { - return nil, errors.New("private key error") - } - switch blk.Type { - case "RSA PRIVATE KEY": - return DecodeRsaPrivateKey(private, password) - case "EC PRIVATE KEY": - return DecodeEcdsaPrivateKey(private, password) - case "PRIVATE KEY": - var prikey crypto.PrivateKey - var err error - var bytes []byte - blk, _ := pem.Decode(private) - if blk == nil { - return nil, errors.New("private key error!") - } - if password != "" { - tmp, err := x509.DecryptPEMBlock(blk, []byte(password)) - if err != nil { - return nil, err - } - bytes = tmp - } else { - bytes = blk.Bytes - } - prikey, err = x509.ParsePKCS8PrivateKey(bytes) - if err != nil { - return nil, err - } - return prikey, err - case "OPENSSH PRIVATE KEY": - var err error - var priv crypto.PrivateKey - if password == "" { - priv, err = ssh.ParseRawPrivateKey(private) - if err != nil { - return nil, err - } - } else { - priv, err = ssh.ParseRawPrivateKeyWithPassphrase(private, []byte(password)) - if err != nil { - return nil, err - } - } - return priv, nil - default: - return nil, errors.New("private key type error") - } + return asymm.DecodePrivateKey(private, password) } func EncodeOpenSSHPrivateKey(private crypto.PrivateKey, secret string) ([]byte, error) { - var key interface{} = private - var block *pem.Block - var err error - if reflect.TypeOf(key) == reflect.TypeOf(&ed25519.PrivateKey{}) { - key = *(key.(*ed25519.PrivateKey)) - } - if secret == "" { - block, err = ssh.MarshalPrivateKey(key, "") - } else { - block, err = ssh.MarshalPrivateKeyWithPassphrase(key, "", []byte(secret)) - } - return pem.EncodeToMemory(block), err + return asymm.EncodeOpenSSHPrivateKey(private, secret) } func DecodePublicKey(pubStr []byte) (crypto.PublicKey, error) { - blk, _ := pem.Decode(pubStr) - if blk == nil { - return nil, errors.New("public key error") - } - pub, err := x509.ParsePKIXPublicKey(blk.Bytes) - if err != nil { - return nil, err - } - return pub, nil + return asymm.DecodePublicKey(pubStr) } func EncodeSSHPublicKey(public crypto.PublicKey) ([]byte, error) { - publicKey, err := ssh.NewPublicKey(public) - if err != nil { - return nil, err - } - return ssh.MarshalAuthorizedKey(publicKey), nil + return asymm.EncodeSSHPublicKey(public) } func DecodeSSHPublicKey(pubStr []byte) (crypto.PublicKey, error) { - return ssh.ParsePublicKey(pubStr) + return asymm.DecodeSSHPublicKey(pubStr) } diff --git a/asymm/asymm_sm_test.go b/asymm/asymm_sm_test.go new file mode 100644 index 0000000..daf2a81 --- /dev/null +++ b/asymm/asymm_sm_test.go @@ -0,0 +1,138 @@ +package asymm + +import ( + "bytes" + "testing" +) + +func TestSM2SignVerifyAndEncryptDecrypt(t *testing.T) { + priv, pub, err := GenerateSM2Key() + if err != nil { + t.Fatalf("GenerateSM2Key failed: %v", err) + } + + msg := []byte("sm2-message") + uid := []byte("user123") + sig, err := SM2Sign(priv, msg, uid) + if err != nil { + t.Fatalf("SM2Sign failed: %v", err) + } + if !SM2Verify(pub, msg, sig, uid) { + t.Fatalf("SM2Verify failed") + } + + cipher, err := SM2EncryptASN1(pub, msg) + if err != nil { + t.Fatalf("SM2EncryptASN1 failed: %v", err) + } + plain, err := SM2DecryptASN1(priv, cipher) + if err != nil { + t.Fatalf("SM2DecryptASN1 failed: %v", err) + } + if !bytes.Equal(plain, msg) { + t.Fatalf("SM2 decrypt mismatch") + } +} + +func TestSM2PEMEncodeDecode(t *testing.T) { + priv, pub, err := GenerateSM2Key() + if err != nil { + t.Fatalf("GenerateSM2Key failed: %v", err) + } + + privPEM, err := EncodeSM2PrivateKey(priv, "pwd") + if err != nil { + t.Fatalf("EncodeSM2PrivateKey failed: %v", err) + } + pubPEM, err := EncodeSM2PublicKey(pub) + if err != nil { + t.Fatalf("EncodeSM2PublicKey failed: %v", err) + } + + decodedPriv, err := DecodeSM2PrivateKey(privPEM, "pwd") + if err != nil { + t.Fatalf("DecodeSM2PrivateKey failed: %v", err) + } + decodedPub, err := DecodeSM2PublicKey(pubPEM) + if err != nil { + t.Fatalf("DecodeSM2PublicKey failed: %v", err) + } + + msg := []byte("sm2-pem") + sig, err := SM2Sign(decodedPriv, msg, nil) + if err != nil { + t.Fatalf("SM2Sign failed: %v", err) + } + if !SM2Verify(decodedPub, msg, sig, nil) { + t.Fatalf("SM2 verify with decoded keys failed") + } +} + +func TestSM9SignVerifyAndEncryptDecrypt(t *testing.T) { + signMasterPriv, signMasterPub, err := GenerateSM9SignMasterKey() + if err != nil { + t.Fatalf("GenerateSM9SignMasterKey failed: %v", err) + } + encryptMasterPriv, encryptMasterPub, err := GenerateSM9EncryptMasterKey() + if err != nil { + t.Fatalf("GenerateSM9EncryptMasterKey failed: %v", err) + } + + uid := []byte("alice@example.com") + signUserKey, err := GenerateSM9SignUserKey(signMasterPriv, uid, 0) + if err != nil { + t.Fatalf("GenerateSM9SignUserKey failed: %v", err) + } + encryptUserKey, err := GenerateSM9EncryptUserKey(encryptMasterPriv, uid, 0) + if err != nil { + t.Fatalf("GenerateSM9EncryptUserKey failed: %v", err) + } + + msg := []byte("sm9-message") + sig, err := SM9SignASN1(signUserKey, msg) + if err != nil { + t.Fatalf("SM9SignASN1 failed: %v", err) + } + if !SM9VerifyASN1(signMasterPub, uid, 0, msg, sig) { + t.Fatalf("SM9VerifyASN1 failed") + } + + cipher, err := SM9Encrypt(encryptMasterPub, uid, 0, msg) + if err != nil { + t.Fatalf("SM9Encrypt failed: %v", err) + } + plain, err := SM9Decrypt(encryptUserKey, uid, cipher) + if err != nil { + t.Fatalf("SM9Decrypt failed: %v", err) + } + if !bytes.Equal(plain, msg) { + t.Fatalf("SM9 decrypt mismatch") + } +} + +func TestSM9PEMEncodeDecode(t *testing.T) { + signMasterPriv, _, err := GenerateSM9SignMasterKey() + if err != nil { + t.Fatalf("GenerateSM9SignMasterKey failed: %v", err) + } + encryptMasterPriv, _, err := GenerateSM9EncryptMasterKey() + if err != nil { + t.Fatalf("GenerateSM9EncryptMasterKey failed: %v", err) + } + + signPrivPEM, err := EncodeSM9SignMasterPrivateKey(signMasterPriv) + if err != nil { + t.Fatalf("EncodeSM9SignMasterPrivateKey failed: %v", err) + } + encPrivPEM, err := EncodeSM9EncryptMasterPrivateKey(encryptMasterPriv) + if err != nil { + t.Fatalf("EncodeSM9EncryptMasterPrivateKey failed: %v", err) + } + + if _, err := DecodeSM9SignMasterPrivateKey(signPrivPEM); err != nil { + t.Fatalf("DecodeSM9SignMasterPrivateKey failed: %v", err) + } + if _, err := DecodeSM9EncryptMasterPrivateKey(encPrivPEM); err != nil { + t.Fatalf("DecodeSM9EncryptMasterPrivateKey failed: %v", err) + } +} diff --git a/asymm/ecdsa.go b/asymm/ecdsa.go new file mode 100644 index 0000000..a087081 --- /dev/null +++ b/asymm/ecdsa.go @@ -0,0 +1,113 @@ +package asymm + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/pem" + "errors" + + "golang.org/x/crypto/ssh" +) + +func GenerateEcdsaKey(pubkeyCurve elliptic.Curve) (*ecdsa.PrivateKey, *ecdsa.PublicKey, error) { + priv, err := ecdsa.GenerateKey(pubkeyCurve, rand.Reader) + if err != nil { + return nil, nil, err + } + return priv, &priv.PublicKey, nil +} + +func EncodeEcdsaPrivateKey(private *ecdsa.PrivateKey, secret string) ([]byte, error) { + b, err := x509.MarshalECPrivateKey(private) + if err != nil { + return nil, err + } + if secret == "" { + return pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: b, + }), nil + } + blk, err := x509.EncryptPEMBlock(rand.Reader, "EC PRIVATE KEY", b, []byte(secret), x509.PEMCipherAES256) + if err != nil { + return nil, err + } + return pem.EncodeToMemory(blk), nil +} + +func EncodeEcdsaPublicKey(public *ecdsa.PublicKey) ([]byte, error) { + publicBytes, err := x509.MarshalPKIXPublicKey(public) + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicBytes, + }), nil +} + +func DecodeEcdsaPrivateKey(private []byte, password string) (*ecdsa.PrivateKey, error) { + blk, _ := pem.Decode(private) + if blk == nil { + return nil, errors.New("private key error") + } + bytes, err := decodePEMBlockBytes(blk, password) + if err != nil { + return nil, err + } + + if key, err := x509.ParseECPrivateKey(bytes); err == nil { + return key, nil + } + pkcs8, err := x509.ParsePKCS8PrivateKey(bytes) + if err != nil { + return nil, err + } + key, ok := pkcs8.(*ecdsa.PrivateKey) + if !ok { + return nil, errors.New("private key is not ECDSA") + } + return key, nil +} + +func DecodeEcdsaPublicKey(pubStr []byte) (*ecdsa.PublicKey, error) { + blk, _ := pem.Decode(pubStr) + if blk == nil { + return nil, errors.New("public key error") + } + pub, err := x509.ParsePKIXPublicKey(blk.Bytes) + if err != nil { + return nil, err + } + key, ok := pub.(*ecdsa.PublicKey) + if !ok { + return nil, errors.New("public key is not ECDSA") + } + return key, nil +} + +func EncodeEcdsaSSHPublicKey(public *ecdsa.PublicKey) ([]byte, error) { + publicKey, err := ssh.NewPublicKey(public) + if err != nil { + return nil, err + } + return ssh.MarshalAuthorizedKey(publicKey), nil +} + +func GenerateEcdsaSSHKeyPair(pubkeyCurve elliptic.Curve, secret string) (string, string, error) { + pkey, pubkey, err := GenerateEcdsaKey(pubkeyCurve) + if err != nil { + return "", "", err + } + pub, err := EncodeEcdsaSSHPublicKey(pubkey) + if err != nil { + return "", "", err + } + priv, err := EncodeEcdsaPrivateKey(pkey, secret) + if err != nil { + return "", "", err + } + return string(priv), string(pub), nil +} diff --git a/asymm/key.go b/asymm/key.go new file mode 100644 index 0000000..3401930 --- /dev/null +++ b/asymm/key.go @@ -0,0 +1,161 @@ +package asymm + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + + "github.com/emmansun/gmsm/sm2" + "github.com/emmansun/gmsm/smx509" + "golang.org/x/crypto/ssh" +) + +func EncodePrivateKey(private crypto.PrivateKey, secret string) ([]byte, error) { + switch key := private.(type) { + case *rsa.PrivateKey: + return EncodeRsaPrivateKey(key, secret) + case *ecdsa.PrivateKey: + return EncodeEcdsaPrivateKey(key, secret) + case *sm2.PrivateKey: + return EncodeSM2PrivateKey(key, secret) + default: + b, err := x509.MarshalPKCS8PrivateKey(private) + if err != nil { + return nil, err + } + if secret == "" { + return pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: b, + }), nil + } + blk, err := x509.EncryptPEMBlock(rand.Reader, "PRIVATE KEY", b, []byte(secret), x509.PEMCipherAES256) + if err != nil { + return nil, err + } + return pem.EncodeToMemory(blk), nil + } +} + +func EncodePublicKey(public crypto.PublicKey) ([]byte, error) { + switch key := public.(type) { + case *rsa.PublicKey: + return EncodeRsaPublicKey(key) + case *ecdsa.PublicKey: + if sm2.IsSM2PublicKey(key) { + return EncodeSM2PublicKey(key) + } + return EncodeEcdsaPublicKey(key) + default: + publicBytes, err := x509.MarshalPKIXPublicKey(public) + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicBytes, + }), nil + } +} + +func DecodePrivateKey(private []byte, password string) (crypto.PrivateKey, error) { + blk, _ := pem.Decode(private) + if blk == nil { + return nil, errors.New("private key error") + } + + switch blk.Type { + case "RSA PRIVATE KEY": + return DecodeRsaPrivateKey(private, password) + case "EC PRIVATE KEY": + return DecodeEcdsaPrivateKey(private, password) + case "SM2 PRIVATE KEY": + return DecodeSM2PrivateKey(private, password) + case "PRIVATE KEY": + bytes, err := decodePEMBlockBytes(blk, password) + if err != nil { + return nil, err + } + key, err := x509.ParsePKCS8PrivateKey(bytes) + if err == nil { + return key, nil + } + return smx509.ParsePKCS8PrivateKey(bytes) + case "OPENSSH PRIVATE KEY": + if password == "" { + return ssh.ParseRawPrivateKey(private) + } + return ssh.ParseRawPrivateKeyWithPassphrase(private, []byte(password)) + default: + return nil, errors.New("private key type error") + } +} + +func EncodeOpenSSHPrivateKey(private crypto.PrivateKey, secret string) ([]byte, error) { + key := interface{}(private) + if k, ok := private.(*ed25519.PrivateKey); ok { + key = *k + } + + var ( + block *pem.Block + err error + ) + if secret == "" { + block, err = ssh.MarshalPrivateKey(key, "") + } else { + block, err = ssh.MarshalPrivateKeyWithPassphrase(key, "", []byte(secret)) + } + if err != nil { + return nil, err + } + return pem.EncodeToMemory(block), nil +} + +func DecodePublicKey(pubStr []byte) (crypto.PublicKey, error) { + blk, _ := pem.Decode(pubStr) + if blk == nil { + return nil, errors.New("public key error") + } + key, err := x509.ParsePKIXPublicKey(blk.Bytes) + if err == nil { + return key, nil + } + return smx509.ParsePKIXPublicKey(blk.Bytes) +} + +func EncodeSSHPublicKey(public crypto.PublicKey) ([]byte, error) { + publicKey, err := ssh.NewPublicKey(public) + if err != nil { + return nil, err + } + return ssh.MarshalAuthorizedKey(publicKey), nil +} + +func DecodeSSHPublicKey(pubStr []byte) (crypto.PublicKey, error) { + return ssh.ParsePublicKey(pubStr) +} + +func decodePEMBlockBytes(blk *pem.Block, password string) ([]byte, error) { + if password == "" { + if x509.IsEncryptedPEMBlock(blk) || smx509.IsEncryptedPEMBlock(blk) { + return nil, errors.New("private key is encrypted but password is empty") + } + return blk.Bytes, nil + } + + if x509.IsEncryptedPEMBlock(blk) { + if b, err := x509.DecryptPEMBlock(blk, []byte(password)); err == nil { + return b, nil + } + } + if smx509.IsEncryptedPEMBlock(blk) { + return smx509.DecryptPEMBlock(blk, []byte(password)) + } + return blk.Bytes, nil +} diff --git a/asymm/rsa.go b/asymm/rsa.go new file mode 100644 index 0000000..9799072 --- /dev/null +++ b/asymm/rsa.go @@ -0,0 +1,204 @@ +package asymm + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "math/big" + + "golang.org/x/crypto/ssh" +) + +func GenerateRsaKey(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) { + private, err := rsa.GenerateKey(rand.Reader, bits) + if err != nil { + return nil, nil, err + } + return private, &private.PublicKey, nil +} + +func EncodeRsaPrivateKey(private *rsa.PrivateKey, secret string) ([]byte, error) { + der := x509.MarshalPKCS1PrivateKey(private) + if secret == "" { + return pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: der, + }), nil + } + blk, err := x509.EncryptPEMBlock(rand.Reader, "RSA PRIVATE KEY", der, []byte(secret), x509.PEMCipherAES256) + if err != nil { + return nil, err + } + return pem.EncodeToMemory(blk), nil +} + +func EncodeRsaPublicKey(public *rsa.PublicKey) ([]byte, error) { + publicBytes, err := x509.MarshalPKIXPublicKey(public) + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicBytes, + }), nil +} + +func DecodeRsaPrivateKey(private []byte, password string) (*rsa.PrivateKey, error) { + blk, _ := pem.Decode(private) + if blk == nil { + return nil, errors.New("private key error") + } + + bytes, err := decodePEMBlockBytes(blk, password) + if err != nil { + return nil, err + } + + if prikey, err := x509.ParsePKCS1PrivateKey(bytes); err == nil { + return prikey, nil + } + pkcs8, err := x509.ParsePKCS8PrivateKey(bytes) + if err != nil { + return nil, err + } + prikey, ok := pkcs8.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("private key is not RSA") + } + return prikey, nil +} + +func DecodeRsaPublicKey(pubStr []byte) (*rsa.PublicKey, error) { + blk, _ := pem.Decode(pubStr) + if blk == nil { + return nil, errors.New("public key error") + } + pub, err := x509.ParsePKIXPublicKey(blk.Bytes) + if err != nil { + return nil, err + } + rsapub, ok := pub.(*rsa.PublicKey) + if !ok { + return nil, errors.New("public key is not RSA") + } + return rsapub, nil +} + +func EncodeRsaSSHPublicKey(public *rsa.PublicKey) ([]byte, error) { + publicKey, err := ssh.NewPublicKey(public) + if err != nil { + return nil, err + } + return ssh.MarshalAuthorizedKey(publicKey), nil +} + +func GenerateRsaSSHKeyPair(bits int, secret string) (string, string, error) { + pkey, pubkey, err := GenerateRsaKey(bits) + if err != nil { + return "", "", err + } + pub, err := EncodeRsaSSHPublicKey(pubkey) + if err != nil { + return "", "", err + } + priv, err := EncodeRsaPrivateKey(pkey, secret) + if err != nil { + return "", "", err + } + return string(priv), string(pub), nil +} + +func RSAEncrypt(pub *rsa.PublicKey, data []byte) ([]byte, error) { + return rsa.EncryptPKCS1v15(rand.Reader, pub, data) +} + +func RSADecrypt(prikey *rsa.PrivateKey, data []byte) ([]byte, error) { + return rsa.DecryptPKCS1v15(rand.Reader, prikey, data) +} + +func RSASign(msg, priKey []byte, password string, hashType crypto.Hash) ([]byte, error) { + prikey, err := DecodeRsaPrivateKey(priKey, password) + if err != nil { + return nil, err + } + hashed, err := hashMessage(msg, hashType) + if err != nil { + return nil, err + } + return rsa.SignPKCS1v15(rand.Reader, prikey, hashType, hashed) +} + +func RSAVerify(sig, msg, pubKey []byte, hashType crypto.Hash) error { + pubkey, err := DecodeRsaPublicKey(pubKey) + if err != nil { + return err + } + hashed, err := hashMessage(msg, hashType) + if err != nil { + return err + } + return rsa.VerifyPKCS1v15(pubkey, hashType, hashed, sig) +} + +func RSAEncryptByPrivkey(priv *rsa.PrivateKey, data []byte) ([]byte, error) { + return rsa.SignPKCS1v15(nil, priv, crypto.Hash(0), data) +} + +func RSADecryptByPubkey(pub *rsa.PublicKey, data []byte) ([]byte, error) { + c := new(big.Int).SetBytes(data) + m := new(big.Int).Exp(c, big.NewInt(int64(pub.E)), pub.N) + em := leftPad(m.Bytes(), (pub.N.BitLen()+7)/8) + return unLeftPad(em) +} + +func hashMessage(msg []byte, hashType crypto.Hash) ([]byte, error) { + if hashType == 0 { + return msg, nil + } + if !hashType.Available() { + return nil, errors.New("hash function is not available") + } + h := hashType.New() + _, err := h.Write(msg) + if err != nil { + return nil, err + } + return h.Sum(nil), nil +} + +func leftPad(input []byte, size int) []byte { + n := len(input) + if n > size { + n = size + } + out := make([]byte, size) + copy(out[len(out)-n:], input) + return out +} + +func unLeftPad(input []byte) ([]byte, error) { + // PKCS#1 v1.5 block format: 0x00 || 0x01 || PS(0xff...) || 0x00 || M + if len(input) < 3 { + return nil, errors.New("invalid RSA block") + } + if input[0] != 0x00 || input[1] != 0x01 { + return nil, errors.New("invalid RSA block header") + } + i := 2 + for i < len(input) && input[i] == 0xff { + i++ + } + if i >= len(input) || input[i] != 0x00 { + return nil, errors.New("invalid RSA block padding") + } + i++ + if i > len(input) { + return nil, errors.New("invalid RSA block payload") + } + out := make([]byte, len(input)-i) + copy(out, input[i:]) + return out, nil +} diff --git a/asymm/sm2.go b/asymm/sm2.go new file mode 100644 index 0000000..2a27fb9 --- /dev/null +++ b/asymm/sm2.go @@ -0,0 +1,136 @@ +package asymm + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rand" + "encoding/pem" + "errors" + + "github.com/emmansun/gmsm/sm2" + "github.com/emmansun/gmsm/smx509" +) + +func GenerateSM2Key() (*sm2.PrivateKey, *ecdsa.PublicKey, error) { + priv, err := sm2.GenerateKey(rand.Reader) + if err != nil { + return nil, nil, err + } + return priv, &priv.PublicKey, nil +} + +func EncodeSM2PrivateKey(private *sm2.PrivateKey, secret string) ([]byte, error) { + der, err := smx509.MarshalSM2PrivateKey(private) + if err != nil { + return nil, err + } + if secret == "" { + return pem.EncodeToMemory(&pem.Block{Type: "SM2 PRIVATE KEY", Bytes: der}), nil + } + blk, err := smx509.EncryptPEMBlock(rand.Reader, "SM2 PRIVATE KEY", der, []byte(secret), smx509.PEMCipherAES256) + if err != nil { + return nil, err + } + return pem.EncodeToMemory(blk), nil +} + +func EncodeSM2PublicKey(public *ecdsa.PublicKey) ([]byte, error) { + der, err := smx509.MarshalPKIXPublicKey(public) + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: "PUBLIC KEY", Bytes: der}), nil +} + +func DecodeSM2PrivateKey(private []byte, password string) (*sm2.PrivateKey, error) { + blk, _ := pem.Decode(private) + if blk == nil { + return nil, errors.New("private key error") + } + + bytes := blk.Bytes + if smx509.IsEncryptedPEMBlock(blk) { + if password == "" { + return nil, errors.New("private key is encrypted but password is empty") + } + var err error + bytes, err = smx509.DecryptPEMBlock(blk, []byte(password)) + if err != nil { + return nil, err + } + } + + if key, err := smx509.ParseSM2PrivateKey(bytes); err == nil { + return key, nil + } + pkcs8, err := smx509.ParsePKCS8PrivateKey(bytes) + if err != nil { + return nil, err + } + key, ok := pkcs8.(*sm2.PrivateKey) + if !ok { + return nil, errors.New("private key is not SM2") + } + return key, nil +} + +func DecodeSM2PublicKey(pubStr []byte) (*ecdsa.PublicKey, error) { + blk, _ := pem.Decode(pubStr) + if blk == nil { + return nil, errors.New("public key error") + } + pubAny, err := smx509.ParsePKIXPublicKey(blk.Bytes) + if err != nil { + return nil, err + } + pub, ok := pubAny.(*ecdsa.PublicKey) + if !ok { + return nil, errors.New("public key is not ECDSA/SM2") + } + if !sm2.IsSM2PublicKey(pub) { + return nil, errors.New("public key is not SM2") + } + return pub, nil +} + +func SM2EncryptASN1(pub *ecdsa.PublicKey, data []byte) ([]byte, error) { + return sm2.EncryptASN1(rand.Reader, pub, data) +} + +func SM2DecryptASN1(priv *sm2.PrivateKey, data []byte) ([]byte, error) { + return priv.Decrypt(nil, data, sm2.ASN1DecrypterOpts) +} + +func SM2Sign(priv *sm2.PrivateKey, msg, uid []byte) ([]byte, error) { + if len(uid) == 0 { + uid = nil + } + return sm2.SignASN1(rand.Reader, priv, msg, sm2.NewSM2SignerOption(true, uid)) +} + +func SM2Verify(pub *ecdsa.PublicKey, msg, sig, uid []byte) bool { + if len(uid) == 0 { + uid = nil + } + return sm2.VerifyASN1WithSM2(pub, uid, msg, sig) +} + +func SM2SignByPEM(msg, priKey []byte, password string, uid []byte) ([]byte, error) { + priv, err := DecodeSM2PrivateKey(priKey, password) + if err != nil { + return nil, err + } + return SM2Sign(priv, msg, uid) +} + +func SM2VerifyByPEM(sig, msg, pubKey []byte, uid []byte) (bool, error) { + pub, err := DecodeSM2PublicKey(pubKey) + if err != nil { + return false, err + } + return SM2Verify(pub, msg, sig, uid), nil +} + +func IsSM2PublicKey(public crypto.PublicKey) bool { + return sm2.IsSM2PublicKey(public) +} diff --git a/asymm/sm9.go b/asymm/sm9.go new file mode 100644 index 0000000..993ec40 --- /dev/null +++ b/asymm/sm9.go @@ -0,0 +1,236 @@ +package asymm + +import ( + "crypto/rand" + "encoding/pem" + "errors" + + gmsm3 "github.com/emmansun/gmsm/sm3" + gmsm9 "github.com/emmansun/gmsm/sm9" +) + +const ( + SM9SignHID byte = 0x01 + SM9EncryptHID byte = 0x03 +) + +func GenerateSM9SignMasterKey() (*gmsm9.SignMasterPrivateKey, *gmsm9.SignMasterPublicKey, error) { + priv, err := gmsm9.GenerateSignMasterKey(rand.Reader) + if err != nil { + return nil, nil, err + } + return priv, priv.PublicKey(), nil +} + +func GenerateSM9EncryptMasterKey() (*gmsm9.EncryptMasterPrivateKey, *gmsm9.EncryptMasterPublicKey, error) { + priv, err := gmsm9.GenerateEncryptMasterKey(rand.Reader) + if err != nil { + return nil, nil, err + } + return priv, priv.PublicKey(), nil +} + +func GenerateSM9SignUserKey(master *gmsm9.SignMasterPrivateKey, uid []byte, hid byte) (*gmsm9.SignPrivateKey, error) { + if master == nil { + return nil, errors.New("sm9 sign master key is nil") + } + if hid == 0 { + hid = SM9SignHID + } + return master.GenerateUserKey(uid, hid) +} + +func GenerateSM9EncryptUserKey(master *gmsm9.EncryptMasterPrivateKey, uid []byte, hid byte) (*gmsm9.EncryptPrivateKey, error) { + if master == nil { + return nil, errors.New("sm9 encrypt master key is nil") + } + if hid == 0 { + hid = SM9EncryptHID + } + return master.GenerateUserKey(uid, hid) +} + +func EncodeSM9SignMasterPrivateKey(key *gmsm9.SignMasterPrivateKey) ([]byte, error) { + if key == nil { + return nil, errors.New("sm9 sign master private key is nil") + } + der, err := key.MarshalASN1() + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: "SM9 SIGN MASTER PRIVATE KEY", Bytes: der}), nil +} + +func DecodeSM9SignMasterPrivateKey(data []byte) (*gmsm9.SignMasterPrivateKey, error) { + der, err := pemOrDER(data) + if err != nil { + return nil, err + } + return gmsm9.UnmarshalSignMasterPrivateKeyASN1(der) +} + +func EncodeSM9SignMasterPublicKey(key *gmsm9.SignMasterPublicKey) ([]byte, error) { + if key == nil { + return nil, errors.New("sm9 sign master public key is nil") + } + der, err := key.MarshalASN1() + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: "SM9 SIGN MASTER PUBLIC KEY", Bytes: der}), nil +} + +func DecodeSM9SignMasterPublicKey(data []byte) (*gmsm9.SignMasterPublicKey, error) { + der, err := pemOrDER(data) + if err != nil { + return nil, err + } + return gmsm9.UnmarshalSignMasterPublicKeyASN1(der) +} + +func EncodeSM9SignPrivateKey(key *gmsm9.SignPrivateKey) ([]byte, error) { + if key == nil { + return nil, errors.New("sm9 sign private key is nil") + } + der, err := key.MarshalASN1() + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: "SM9 SIGN PRIVATE KEY", Bytes: der}), nil +} + +func DecodeSM9SignPrivateKey(data []byte) (*gmsm9.SignPrivateKey, error) { + der, err := pemOrDER(data) + if err != nil { + return nil, err + } + return gmsm9.UnmarshalSignPrivateKeyASN1(der) +} + +func EncodeSM9EncryptMasterPrivateKey(key *gmsm9.EncryptMasterPrivateKey) ([]byte, error) { + if key == nil { + return nil, errors.New("sm9 encrypt master private key is nil") + } + der, err := key.MarshalASN1() + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: "SM9 ENCRYPT MASTER PRIVATE KEY", Bytes: der}), nil +} + +func DecodeSM9EncryptMasterPrivateKey(data []byte) (*gmsm9.EncryptMasterPrivateKey, error) { + der, err := pemOrDER(data) + if err != nil { + return nil, err + } + return gmsm9.UnmarshalEncryptMasterPrivateKeyASN1(der) +} + +func EncodeSM9EncryptMasterPublicKey(key *gmsm9.EncryptMasterPublicKey) ([]byte, error) { + if key == nil { + return nil, errors.New("sm9 encrypt master public key is nil") + } + der, err := key.MarshalASN1() + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: "SM9 ENCRYPT MASTER PUBLIC KEY", Bytes: der}), nil +} + +func DecodeSM9EncryptMasterPublicKey(data []byte) (*gmsm9.EncryptMasterPublicKey, error) { + der, err := pemOrDER(data) + if err != nil { + return nil, err + } + return gmsm9.UnmarshalEncryptMasterPublicKeyASN1(der) +} + +func EncodeSM9EncryptPrivateKey(key *gmsm9.EncryptPrivateKey) ([]byte, error) { + if key == nil { + return nil, errors.New("sm9 encrypt private key is nil") + } + der, err := key.MarshalASN1() + if err != nil { + return nil, err + } + return pem.EncodeToMemory(&pem.Block{Type: "SM9 ENCRYPT PRIVATE KEY", Bytes: der}), nil +} + +func DecodeSM9EncryptPrivateKey(data []byte) (*gmsm9.EncryptPrivateKey, error) { + der, err := pemOrDER(data) + if err != nil { + return nil, err + } + return gmsm9.UnmarshalEncryptPrivateKeyASN1(der) +} + +func SM9SignHashASN1(priv *gmsm9.SignPrivateKey, hash []byte) ([]byte, error) { + if priv == nil { + return nil, errors.New("sm9 sign private key is nil") + } + return gmsm9.SignASN1(rand.Reader, priv, hash) +} + +func SM9SignASN1(priv *gmsm9.SignPrivateKey, message []byte) ([]byte, error) { + sum := gmsm3.Sum(message) + return SM9SignHashASN1(priv, sum[:]) +} + +func SM9VerifyHashASN1(pub *gmsm9.SignMasterPublicKey, uid []byte, hid byte, hash, sig []byte) bool { + if pub == nil { + return false + } + if hid == 0 { + hid = SM9SignHID + } + return gmsm9.VerifyASN1(pub, uid, hid, hash, sig) +} + +func SM9VerifyASN1(pub *gmsm9.SignMasterPublicKey, uid []byte, hid byte, message, sig []byte) bool { + sum := gmsm3.Sum(message) + return SM9VerifyHashASN1(pub, uid, hid, sum[:], sig) +} + +func SM9Encrypt(pub *gmsm9.EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte) ([]byte, error) { + if pub == nil { + return nil, errors.New("sm9 encrypt master public key is nil") + } + if hid == 0 { + hid = SM9EncryptHID + } + return gmsm9.Encrypt(rand.Reader, pub, uid, hid, plaintext, gmsm9.SM4CBCEncrypterOpts) +} + +func SM9Decrypt(priv *gmsm9.EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error) { + if priv == nil { + return nil, errors.New("sm9 encrypt private key is nil") + } + return gmsm9.Decrypt(priv, uid, ciphertext, gmsm9.SM4CBCEncrypterOpts) +} + +func SM9EncryptASN1(pub *gmsm9.EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte) ([]byte, error) { + if pub == nil { + return nil, errors.New("sm9 encrypt master public key is nil") + } + if hid == 0 { + hid = SM9EncryptHID + } + return gmsm9.EncryptASN1(rand.Reader, pub, uid, hid, plaintext, gmsm9.SM4CBCEncrypterOpts) +} + +func SM9DecryptASN1(priv *gmsm9.EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error) { + if priv == nil { + return nil, errors.New("sm9 encrypt private key is nil") + } + return gmsm9.DecryptASN1(priv, uid, ciphertext) +} + +func pemOrDER(data []byte) ([]byte, error) { + if len(data) == 0 { + return nil, errors.New("empty key data") + } + if blk, _ := pem.Decode(data); blk != nil { + return blk.Bytes, nil + } + return data, nil +} diff --git a/base648591.go b/base648591.go index a10f168..958d8e0 100644 --- a/base648591.go +++ b/base648591.go @@ -1,368 +1,80 @@ package starcrypto -import ( - "encoding/ascii85" - "encoding/base64" - "errors" - "io" - "os" -) +import "b612.me/starcrypto/encodingx" var ( - // ErrLength is returned from the Decode* methods if the input has an - // impossible length. - ErrLength = errors.New("base128: invalid length base128 string") - // ErrBit is returned from the Decode* methods if the input has a byte with - // the high-bit set (e.g. 0x80). This will never be the case for strings - // produced from the Encode* methods in this package. - ErrBit = errors.New("base128: high bit set in base128 string") + ErrLength = encodingx.ErrLength + ErrBit = encodingx.ErrBit ) -// Encoding table holds all the characters for base91 encoding -var enctab = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!#$%&()*+,./:;<=>?@[]^_`{|}~'") - -// Decoding table maps all the characters back to their integer values -var dectab = map[byte]byte{ - 'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, - 'I': 8, 'J': 9, 'K': 10, 'L': 11, 'M': 12, 'N': 13, 'O': 14, 'P': 15, - 'Q': 16, 'R': 17, 'S': 18, 'T': 19, 'U': 20, 'V': 21, 'W': 22, 'X': 23, - 'Y': 24, 'Z': 25, 'a': 26, 'b': 27, 'c': 28, 'd': 29, 'e': 30, 'f': 31, - 'g': 32, 'h': 33, 'i': 34, 'j': 35, 'k': 36, 'l': 37, 'm': 38, 'n': 39, - 'o': 40, 'p': 41, 'q': 42, 'r': 43, 's': 44, 't': 45, 'u': 46, 'v': 47, - 'w': 48, 'x': 49, 'y': 50, 'z': 51, '0': 52, '1': 53, '2': 54, '3': 55, - '4': 56, '5': 57, '6': 58, '7': 59, '8': 60, '9': 61, '!': 62, '#': 63, - '$': 64, '%': 65, '&': 66, '(': 67, ')': 68, '*': 69, '+': 70, ',': 71, - '.': 72, '/': 73, ':': 74, ';': 75, '<': 76, '=': 77, '>': 78, '?': 79, - '@': 80, '[': 81, ']': 82, '^': 83, '_': 84, '`': 85, '{': 86, '|': 87, - '}': 88, '~': 89, '\'': 90, -} - -// Base91EncodeToString encodes the given byte array and returns a string func Base91EncodeToString(d []byte) string { - return string(Base91Encode(d)) + return encodingx.Base91EncodeToString(d) } -// Base91Encode returns the base91 encoded string func Base91Encode(d []byte) []byte { - var n, b uint - var o []byte - - for i := 0; i < len(d); i++ { - b |= uint(d[i]) << n - n += 8 - if n > 13 { - v := b & 8191 - if v > 88 { - b >>= 13 - n -= 13 - } else { - v = b & 16383 - b >>= 14 - n -= 14 - } - o = append(o, enctab[v%91], enctab[v/91]) - } - } - if n > 0 { - o = append(o, enctab[b%91]) - if n > 7 || b > 90 { - o = append(o, enctab[b/91]) - } - } - return o + return encodingx.Base91Encode(d) } -// Base91DecodeToString decodes a given byte array are returns a string func Base91DecodeString(d string) []byte { - return Base91Decode([]byte(d)) + return encodingx.Base91DecodeString(d) } -// Base91Decode decodes a base91 encoded string and returns the result func Base91Decode(d []byte) []byte { - var b, n uint - var o []byte - v := -1 - - for i := 0; i < len(d); i++ { - c, ok := dectab[d[i]] - if !ok { - continue - } - if v < 0 { - v = int(c) - } else { - v += int(c) * 91 - b |= uint(v) << n - if v&8191 > 88 { - n += 13 - } else { - n += 14 - } - o = append(o, byte(b&255)) - b >>= 8 - n -= 8 - for n > 7 { - o = append(o, byte(b&255)) - b >>= 8 - n -= 8 - } - v = -1 - } - } - if v+1 > 0 { - o = append(o, byte((b|uint(v)<>whichByte)) - bufByte = (v&(1< 1 { - dst = append(dst, bufByte|(v>>(8-whichByte))) - } - bufByte = v << whichByte - if whichByte == 8 { - whichByte = 0 - } - whichByte++ - } - return len(dst), nil + return encodingx.Base128Decode(dst, src) } -// Base128DecodeString returns the bytes represented by the base128 string s. func Base128DecodeString(s string) ([]byte, error) { - src := []byte(s) - dst := make([]byte, Base128DecodedLen(len(src))) - if _, err := Base128Decode(dst, src); err != nil { - return nil, err - } - return dst, nil + return encodingx.Base128DecodeString(s) } -// Base128DecodedLen returns the number of bytes `encLen` encoded bytes decodes to. func Base128DecodedLen(encLen int) int { - return (encLen * 7 / 8) + return encodingx.Base128DecodedLen(encLen) } -// Base128EncodedLen returns the number of bytes that `dataLen` bytes will encode to. func Base128EncodedLen(dataLen int) int { - return (((dataLen * 8) + 6) / 7) + return encodingx.Base128EncodedLen(dataLen) } -// Base128EncodeToString returns the base128 encoding of src. func Base128EncodeToString(src []byte) string { - dst := make([]byte, Base128EncodedLen(len(src))) - Base128Encode(dst, src) - return string(dst) + return encodingx.Base128EncodeToString(src) } -// Base64Encode 输出格式化后的Base64字符串 func Base64Encode(bstr []byte) string { - return base64.StdEncoding.EncodeToString(bstr) + return encodingx.Base64Encode(bstr) } -// Base64Decode 输出解密前的Base64数据 func Base64Decode(str string) ([]byte, error) { - return base64.StdEncoding.DecodeString(str) + return encodingx.Base64Decode(str) } -// Base85Encode 输出格式化后的Base85字符串 func Base85Encode(bstr []byte) string { - var rtn []byte - rtn = make([]byte, ascii85.MaxEncodedLen(len(bstr))) - ascii85.Encode(rtn, bstr) - return string(rtn) + return encodingx.Base85Encode(bstr) } -// Base85Decode 输出解密前的Base85数据 func Base85Decode(str string) ([]byte, error) { - var rtn []byte - rtn = make([]byte, len(str)) - _, _, err := ascii85.Decode(rtn, []byte(str), true) - return rtn, err + return encodingx.Base85Decode(str) } -// Base85EncodeFile 用base85方法编码src文件到dst文件中去,shell传入当前进度 func Base85EncodeFile(src, dst string, shell func(float64)) error { - fpsrc, err := os.Open(src) - if err != nil { - return err - } - defer fpsrc.Close() - stat, _ := os.Stat(src) - filebig := float64(stat.Size()) - var sum int64 - fpdst, err := os.Create(dst) - if err != nil { - return err - } - defer fpdst.Close() - b85 := ascii85.NewEncoder(fpdst) - defer b85.Close() - for { - buf := make([]byte, 1024000) - n, err := fpsrc.Read(buf) - if err != nil { - if err == io.EOF { - break - } - return err - } - sum += int64(n) - go shell(float64(sum) / filebig * 100) - b85.Write(buf[0:n]) - } - return nil + return encodingx.Base85EncodeFile(src, dst, shell) } -// Base85DecodeFile 用base85方法解码src文件到dst文件中去,shell传入当前进度 func Base85DecodeFile(src, dst string, shell func(float64)) error { - fpsrc, err := os.Open(src) - if err != nil { - return err - } - defer fpsrc.Close() - stat, _ := os.Stat(src) - filebig := float64(stat.Size()) - var sum int64 - defer fpsrc.Close() - fpdst, err := os.Create(dst) - if err != nil { - return err - } - defer fpdst.Close() - b85 := ascii85.NewDecoder(fpsrc) - for { - buf := make([]byte, 1280000) - n, err := b85.Read(buf) - if err != nil { - if err == io.EOF { - break - } - return err - } - sum += int64(n) - per := float64(sum) / filebig * 100 / 4.0 * 5.0 - if per >= 100 { - per = 100 - } - go shell(per) - fpdst.Write(buf[0:n]) - } - return nil + return encodingx.Base85DecodeFile(src, dst, shell) } -// Base64EncodeFile 用base64方法编码src文件到dst文件中去,shell传入当前进度 func Base64EncodeFile(src, dst string, shell func(float64)) error { - fpsrc, err := os.Open(src) - if err != nil { - return err - } - defer fpsrc.Close() - stat, _ := os.Stat(src) - filebig := float64(stat.Size()) - var sum int64 = 0 - fpdst, err := os.Create(dst) - if err != nil { - return err - } - defer fpdst.Close() - b64 := base64.NewEncoder(base64.StdEncoding, fpdst) - defer b64.Close() - for { - buf := make([]byte, 1048575) - n, err := fpsrc.Read(buf) - if err != nil { - if err == io.EOF { - break - } - return err - } - sum += int64(n) - go shell(float64(sum) / filebig * 100) - b64.Write(buf[0:n]) - } - return nil + return encodingx.Base64EncodeFile(src, dst, shell) } -// Base64DecodeFile 用base64方法解码src文件到dst文件中去,shell传入当前进度 func Base64DecodeFile(src, dst string, shell func(float64)) error { - fpsrc, err := os.Open(src) - if err != nil { - return err - } - defer fpsrc.Close() - stat, _ := os.Stat(src) - filebig := float64(stat.Size()) - var sum int64 = 0 - defer fpsrc.Close() - fpdst, err := os.Create(dst) - if err != nil { - return err - } - defer fpdst.Close() - b64 := base64.NewDecoder(base64.StdEncoding, fpsrc) - for { - buf := make([]byte, 1048576) - n, err := b64.Read(buf) - if err != nil { - if err == io.EOF { - break - } - return err - } - sum += int64(n) - per := float64(sum) / filebig * 100 / 3.0 * 4.0 - if per >= 100 { - per = 100 - } - go shell(per) - fpdst.Write(buf[0:n]) - } - return nil + return encodingx.Base64DecodeFile(src, dst, shell) } diff --git a/chacha20.go b/chacha20.go new file mode 100644 index 0000000..4fc6fe4 --- /dev/null +++ b/chacha20.go @@ -0,0 +1,31 @@ +package starcrypto + +import ( + "io" + + "b612.me/starcrypto/symm" +) + +func EncryptChaCha20(data, key, nonce []byte) ([]byte, error) { + return symm.EncryptChaCha20(data, key, nonce) +} + +func DecryptChaCha20(src, key, nonce []byte) ([]byte, error) { + return symm.DecryptChaCha20(src, key, nonce) +} + +func EncryptChaCha20Stream(dst io.Writer, src io.Reader, key, nonce []byte) error { + return symm.EncryptChaCha20Stream(dst, src, key, nonce) +} + +func DecryptChaCha20Stream(dst io.Writer, src io.Reader, key, nonce []byte) error { + return symm.DecryptChaCha20Stream(dst, src, key, nonce) +} + +func EncryptChaCha20Poly1305(plain, key, nonce, aad []byte) ([]byte, error) { + return symm.EncryptChaCha20Poly1305(plain, key, nonce, aad) +} + +func DecryptChaCha20Poly1305(ciphertext, key, nonce, aad []byte) ([]byte, error) { + return symm.DecryptChaCha20Poly1305(ciphertext, key, nonce, aad) +} diff --git a/crc32.go b/crc32.go index 7aac14e..f721043 100644 --- a/crc32.go +++ b/crc32.go @@ -1,117 +1,23 @@ package starcrypto -import ( - "encoding/binary" - "encoding/hex" - "hash/crc32" -) +import "b612.me/starcrypto/hashx" -// CheckCRC32A calculates CRC32A (ITU I.363.5 algorithm, popularized by BZIP2) checksum. -// This function will produce the same results as following PHP code: -// -// hexdec(hash('crc32', $data)) func CheckCRC32A(data []byte) uint32 { - b := digest(data) - - return binary.BigEndian.Uint32(b) + return hashx.CheckCRC32A(data) } func Crc32Str(bstr []byte) string { - return String(Crc32(bstr)) + return hashx.Crc32Str(bstr) } -// CRC32 输出CRC32校验值 func Crc32(bstr []byte) []byte { - crcsum := crc32.NewIEEE() - crcsum.Write(bstr) - return crcsum.Sum(nil) + return hashx.Crc32(bstr) } func Crc32A(data []byte) []byte { - return digest(data) + return hashx.Crc32A(data) } -// Crc32AStr is a convenience function that outputs CRC32A (ITU I.363.5 algorithm, popularized by BZIP2) checksum as a hex string. -// This function will produce the same results as following PHP code: -// -// hash('crc32', $data) func Crc32AStr(data []byte) string { - b := digest(data) - - return hex.EncodeToString(b) -} - -// digest performs checksum calculation for each byte of provided data and returns digest in form of byte array. -func digest(data []byte) []byte { - var crc uint32 - var digest = make([]byte, 4) - - crc = ^crc - for i := 0; i < len(data); i++ { - crc = (crc << 8) ^ table[(crc>>24)^(uint32(data[i])&0xff)] - } - crc = ^crc - - digest[3] = byte((crc >> 24) & 0xff) - digest[2] = byte((crc >> 16) & 0xff) - digest[1] = byte((crc >> 8) & 0xff) - digest[0] = byte(crc & 0xff) - - return digest -} - -// table is the pre-generated 0x04C11DB7 polynominal used for CRC32A. -var table = [256]uint32{ - 0x0, - 0x04c11db7, 0x09823b6e, 0x0d4326d9, 0x130476dc, 0x17c56b6b, - 0x1a864db2, 0x1e475005, 0x2608edb8, 0x22c9f00f, 0x2f8ad6d6, - 0x2b4bcb61, 0x350c9b64, 0x31cd86d3, 0x3c8ea00a, 0x384fbdbd, - 0x4c11db70, 0x48d0c6c7, 0x4593e01e, 0x4152fda9, 0x5f15adac, - 0x5bd4b01b, 0x569796c2, 0x52568b75, 0x6a1936c8, 0x6ed82b7f, - 0x639b0da6, 0x675a1011, 0x791d4014, 0x7ddc5da3, 0x709f7b7a, - 0x745e66cd, 0x9823b6e0, 0x9ce2ab57, 0x91a18d8e, 0x95609039, - 0x8b27c03c, 0x8fe6dd8b, 0x82a5fb52, 0x8664e6e5, 0xbe2b5b58, - 0xbaea46ef, 0xb7a96036, 0xb3687d81, 0xad2f2d84, 0xa9ee3033, - 0xa4ad16ea, 0xa06c0b5d, 0xd4326d90, 0xd0f37027, 0xddb056fe, - 0xd9714b49, 0xc7361b4c, 0xc3f706fb, 0xceb42022, 0xca753d95, - 0xf23a8028, 0xf6fb9d9f, 0xfbb8bb46, 0xff79a6f1, 0xe13ef6f4, - 0xe5ffeb43, 0xe8bccd9a, 0xec7dd02d, 0x34867077, 0x30476dc0, - 0x3d044b19, 0x39c556ae, 0x278206ab, 0x23431b1c, 0x2e003dc5, - 0x2ac12072, 0x128e9dcf, 0x164f8078, 0x1b0ca6a1, 0x1fcdbb16, - 0x018aeb13, 0x054bf6a4, 0x0808d07d, 0x0cc9cdca, 0x7897ab07, - 0x7c56b6b0, 0x71159069, 0x75d48dde, 0x6b93dddb, 0x6f52c06c, - 0x6211e6b5, 0x66d0fb02, 0x5e9f46bf, 0x5a5e5b08, 0x571d7dd1, - 0x53dc6066, 0x4d9b3063, 0x495a2dd4, 0x44190b0d, 0x40d816ba, - 0xaca5c697, 0xa864db20, 0xa527fdf9, 0xa1e6e04e, 0xbfa1b04b, - 0xbb60adfc, 0xb6238b25, 0xb2e29692, 0x8aad2b2f, 0x8e6c3698, - 0x832f1041, 0x87ee0df6, 0x99a95df3, 0x9d684044, 0x902b669d, - 0x94ea7b2a, 0xe0b41de7, 0xe4750050, 0xe9362689, 0xedf73b3e, - 0xf3b06b3b, 0xf771768c, 0xfa325055, 0xfef34de2, 0xc6bcf05f, - 0xc27dede8, 0xcf3ecb31, 0xcbffd686, 0xd5b88683, 0xd1799b34, - 0xdc3abded, 0xd8fba05a, 0x690ce0ee, 0x6dcdfd59, 0x608edb80, - 0x644fc637, 0x7a089632, 0x7ec98b85, 0x738aad5c, 0x774bb0eb, - 0x4f040d56, 0x4bc510e1, 0x46863638, 0x42472b8f, 0x5c007b8a, - 0x58c1663d, 0x558240e4, 0x51435d53, 0x251d3b9e, 0x21dc2629, - 0x2c9f00f0, 0x285e1d47, 0x36194d42, 0x32d850f5, 0x3f9b762c, - 0x3b5a6b9b, 0x0315d626, 0x07d4cb91, 0x0a97ed48, 0x0e56f0ff, - 0x1011a0fa, 0x14d0bd4d, 0x19939b94, 0x1d528623, 0xf12f560e, - 0xf5ee4bb9, 0xf8ad6d60, 0xfc6c70d7, 0xe22b20d2, 0xe6ea3d65, - 0xeba91bbc, 0xef68060b, 0xd727bbb6, 0xd3e6a601, 0xdea580d8, - 0xda649d6f, 0xc423cd6a, 0xc0e2d0dd, 0xcda1f604, 0xc960ebb3, - 0xbd3e8d7e, 0xb9ff90c9, 0xb4bcb610, 0xb07daba7, 0xae3afba2, - 0xaafbe615, 0xa7b8c0cc, 0xa379dd7b, 0x9b3660c6, 0x9ff77d71, - 0x92b45ba8, 0x9675461f, 0x8832161a, 0x8cf30bad, 0x81b02d74, - 0x857130c3, 0x5d8a9099, 0x594b8d2e, 0x5408abf7, 0x50c9b640, - 0x4e8ee645, 0x4a4ffbf2, 0x470cdd2b, 0x43cdc09c, 0x7b827d21, - 0x7f436096, 0x7200464f, 0x76c15bf8, 0x68860bfd, 0x6c47164a, - 0x61043093, 0x65c52d24, 0x119b4be9, 0x155a565e, 0x18197087, - 0x1cd86d30, 0x029f3d35, 0x065e2082, 0x0b1d065b, 0x0fdc1bec, - 0x3793a651, 0x3352bbe6, 0x3e119d3f, 0x3ad08088, 0x2497d08d, - 0x2056cd3a, 0x2d15ebe3, 0x29d4f654, 0xc5a92679, 0xc1683bce, - 0xcc2b1d17, 0xc8ea00a0, 0xd6ad50a5, 0xd26c4d12, 0xdf2f6bcb, - 0xdbee767c, 0xe3a1cbc1, 0xe760d676, 0xea23f0af, 0xeee2ed18, - 0xf0a5bd1d, 0xf464a0aa, 0xf9278673, 0xfde69bc4, 0x89b8fd09, - 0x8d79e0be, 0x803ac667, 0x84fbdbd0, 0x9abc8bd5, 0x9e7d9662, - 0x933eb0bb, 0x97ffad0c, 0xafb010b1, 0xab710d06, 0xa6322bdf, - 0xa2f33668, 0xbcb4666d, 0xb8757bda, 0xb5365d03, 0xb1f740b4, + return hashx.Crc32AStr(data) } diff --git a/crypto.go b/crypto.go index 1c8e69e..a7eb4d7 100644 --- a/crypto.go +++ b/crypto.go @@ -1,171 +1,26 @@ package starcrypto import ( - "crypto/rand" - "encoding/binary" - "encoding/hex" - "io" - "os" + "b612.me/starcrypto/hashx" + "b612.me/starcrypto/legacy" ) func String(bstr []byte) string { - return hex.EncodeToString(bstr) + return hashx.HexString(bstr) } func VicqueEncodeV1(srcdata []byte, key string) []byte { - var keys []int - var randCode1, randCode2 uint8 - data := make([]byte, len(srcdata)) - copy(data, srcdata) - binary.Read(rand.Reader, binary.LittleEndian, &randCode1) - binary.Read(rand.Reader, binary.LittleEndian, &randCode2) - keys = append(keys, len(key)+int(randCode1)) - lens := len(data) - for _, v := range key { - keys = append(keys, int(byte(v))+int(randCode1)-int(randCode2)) - } - lenkey := len(keys) - for k, v := range data { - if k == lens/2 { - break - } - nv := int(v) - t := 0 - if k%2 == 0 { - nv += keys[k%lenkey] - if nv > 255 { - nv -= 256 - } - t = int(data[lens-1-k]) - t += keys[k%lenkey] - if t > 255 { - t -= 256 - } - } else { - nv -= keys[k%lenkey] - if nv < 0 { - nv += 256 - } - t = int(data[lens-1-k]) - t -= keys[k%lenkey] - if t > 255 { - t += 256 - } - } - data[k] = byte(t) - data[lens-1-k] = byte(nv) - } - data = append(data, byte(randCode1), byte(randCode2)) - return data + return legacy.VicqueEncodeV1(srcdata, key) } func VicqueDecodeV1(srcdata []byte, key string) []byte { - var keys []int - var randCode1, randCode2 int - data := make([]byte, len(srcdata)) - copy(data, srcdata) - lens := len(data) - randCode1 = int(data[lens-2]) - randCode2 = int(data[lens-1]) - keys = append(keys, len(key)+int(randCode1)) - for _, v := range key { - keys = append(keys, int(byte(v))+int(randCode1)-int(randCode2)) - } - lenkey := len(keys) - lens -= 2 - for k, v := range data { - if k == lens/2 { - break - } - nv := int(v) - t := 0 - if k%2 == 0 { - nv -= keys[k%lenkey] - if nv < 0 { - nv += 256 - } - t = int(data[lens-1-k]) - t -= keys[k%lenkey] - if t > 255 { - t += 256 - } - } else { - nv += keys[k%lenkey] - if nv > 255 { - nv -= 256 - } - t = int(data[lens-1-k]) - t += keys[k%lenkey] - if t > 255 { - t -= 256 - } - } - data[k] = byte(t) - data[lens-1-k] = byte(nv) - } - return data[:lens] + return legacy.VicqueDecodeV1(srcdata, key) } func VicqueEncodeV1File(src, dst, pwd string, shell func(float64)) error { - fpsrc, err := os.Open(src) - if err != nil { - return err - } - defer fpsrc.Close() - stat, _ := os.Stat(src) - filebig := float64(stat.Size()) - var sum int64 - defer fpsrc.Close() - fpdst, err := os.Create(dst) - if err != nil { - return err - } - defer fpdst.Close() - for { - buf := make([]byte, 1048576) - n, err := fpsrc.Read(buf) - if err != nil { - if err == io.EOF { - break - } - return err - } - sum += int64(n) - go shell(float64(sum) / filebig * 100) - data := VicqueEncodeV1(buf[0:n], pwd) - fpdst.Write(data) - } - return nil + return legacy.VicqueEncodeV1File(src, dst, pwd, shell) } func VicqueDecodeV1File(src, dst, pwd string, shell func(float64)) error { - fpsrc, err := os.Open(src) - if err != nil { - return err - } - defer fpsrc.Close() - stat, _ := os.Stat(src) - filebig := float64(stat.Size()) - var sum int64 - defer fpsrc.Close() - fpdst, err := os.Create(dst) - if err != nil { - return err - } - defer fpdst.Close() - for { - buf := make([]byte, 1048578) - n, err := fpsrc.Read(buf) - if err != nil { - if err == io.EOF { - break - } - return err - } - sum += int64(n) - go shell(float64(sum) / filebig * 100) - data := VicqueDecodeV1(buf[0:n], pwd) - fpdst.Write(data) - } - return nil + return legacy.VicqueDecodeV1File(src, dst, pwd, shell) } diff --git a/des.go b/des.go new file mode 100644 index 0000000..a94e559 --- /dev/null +++ b/des.go @@ -0,0 +1,39 @@ +package starcrypto + +import ( + "io" + + "b612.me/starcrypto/symm" +) + +func EncryptDESCBC(data, key, iv []byte, paddingType string) ([]byte, error) { + return symm.EncryptDESCBC(data, key, iv, paddingType) +} + +func DecryptDESCBC(src, key, iv []byte, paddingType string) ([]byte, error) { + return symm.DecryptDESCBC(src, key, iv, paddingType) +} + +func EncryptDESCBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + return symm.EncryptDESCBCStream(dst, src, key, iv, paddingType) +} + +func DecryptDESCBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + return symm.DecryptDESCBCStream(dst, src, key, iv, paddingType) +} + +func Encrypt3DESCBC(data, key, iv []byte, paddingType string) ([]byte, error) { + return symm.Encrypt3DESCBC(data, key, iv, paddingType) +} + +func Decrypt3DESCBC(src, key, iv []byte, paddingType string) ([]byte, error) { + return symm.Decrypt3DESCBC(src, key, iv, paddingType) +} + +func Encrypt3DESCBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + return symm.Encrypt3DESCBCStream(dst, src, key, iv, paddingType) +} + +func Decrypt3DESCBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + return symm.Decrypt3DESCBCStream(dst, src, key, iv, paddingType) +} diff --git a/ecdsa.go b/ecdsa.go index 3e59be0..29504a9 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -1,117 +1,35 @@ package starcrypto import ( + "b612.me/starcrypto/asymm" "crypto/ecdsa" "crypto/elliptic" - "crypto/rand" - "crypto/x509" - "encoding/pem" - "errors" - "golang.org/x/crypto/ssh" ) func GenerateEcdsaKey(pubkeyCurve elliptic.Curve) (*ecdsa.PrivateKey, *ecdsa.PublicKey, error) { - // 随机挑选基点,生成私钥 - priv, err := ecdsa.GenerateKey(pubkeyCurve, rand.Reader) - if err != nil { - return nil, nil, err - } - return priv, &priv.PublicKey, nil - + return asymm.GenerateEcdsaKey(pubkeyCurve) } func EncodeEcdsaPrivateKey(private *ecdsa.PrivateKey, secret string) ([]byte, error) { - b, err := x509.MarshalECPrivateKey(private) - if err != nil { - return nil, err - } - if secret == "" { - return pem.EncodeToMemory(&pem.Block{ - Bytes: b, - Type: "EC PRIVATE KEY", - }), err - } - chiper := x509.PEMCipherAES256 - blk, err := x509.EncryptPEMBlock(rand.Reader, "EC PRIVATE KEY", b, []byte(secret), chiper) - if err != nil { - return nil, err - } - return pem.EncodeToMemory(blk), err + return asymm.EncodeEcdsaPrivateKey(private, secret) } func EncodeEcdsaPublicKey(public *ecdsa.PublicKey) ([]byte, error) { - publicBytes, err := x509.MarshalPKIXPublicKey(public) - if err != nil { - return nil, err - } - return pem.EncodeToMemory(&pem.Block{ - Bytes: publicBytes, - Type: "PUBLIC KEY", - }), nil + return asymm.EncodeEcdsaPublicKey(public) } func DecodeEcdsaPrivateKey(private []byte, password string) (*ecdsa.PrivateKey, error) { - var prikey *ecdsa.PrivateKey - var err error - var bytes []byte - blk, _ := pem.Decode(private) - if blk == nil { - return nil, errors.New("private key error!") - } - if password != "" { - tmp, err := x509.DecryptPEMBlock(blk, []byte(password)) - if err != nil { - return nil, err - } - bytes = tmp - } else { - bytes = blk.Bytes - } - prikey, err = x509.ParseECPrivateKey(bytes) - if err != nil { - tmp, err := x509.ParsePKCS8PrivateKey(bytes) - if err != nil { - return nil, err - } - prikey = tmp.(*ecdsa.PrivateKey) - } - return prikey, err + return asymm.DecodeEcdsaPrivateKey(private, password) } func DecodeEcdsaPublicKey(pubStr []byte) (*ecdsa.PublicKey, error) { - blk, _ := pem.Decode(pubStr) - if blk == nil { - return nil, errors.New("public key error") - } - pub, err := x509.ParsePKIXPublicKey(blk.Bytes) - if err != nil { - return nil, err - } - return pub.(*ecdsa.PublicKey), nil + return asymm.DecodeEcdsaPublicKey(pubStr) } func EncodeEcdsaSSHPublicKey(public *ecdsa.PublicKey) ([]byte, error) { - publicKey, err := ssh.NewPublicKey(public) - if err != nil { - return nil, err - } - return ssh.MarshalAuthorizedKey(publicKey), nil + return asymm.EncodeEcdsaSSHPublicKey(public) } func GenerateEcdsaSSHKeyPair(pubkeyCurve elliptic.Curve, secret string) (string, string, error) { - pkey, pubkey, err := GenerateEcdsaKey(pubkeyCurve) - if err != nil { - return "", "", err - } - - pub, err := EncodeEcdsaSSHPublicKey(pubkey) - if err != nil { - return "", "", err - } - - priv, err := EncodeEcdsaPrivateKey(pkey, secret) - if err != nil { - return "", "", err - } - return string(priv), string(pub), nil + return asymm.GenerateEcdsaSSHKeyPair(pubkeyCurve, secret) } diff --git a/encodingx/encoding.go b/encodingx/encoding.go new file mode 100644 index 0000000..f743c49 --- /dev/null +++ b/encodingx/encoding.go @@ -0,0 +1,387 @@ +package encodingx + +import ( + "encoding/ascii85" + "encoding/base64" + "errors" + "io" + "os" +) + +var ( + ErrLength = errors.New("base128: invalid length base128 string") + ErrBit = errors.New("base128: high bit set in base128 string") +) + +var enctab = []byte("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789!#$%&()*+,./:;<=>?@[]^_`{|}~'") + +var dectab = map[byte]byte{ + 'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, + 'I': 8, 'J': 9, 'K': 10, 'L': 11, 'M': 12, 'N': 13, 'O': 14, 'P': 15, + 'Q': 16, 'R': 17, 'S': 18, 'T': 19, 'U': 20, 'V': 21, 'W': 22, 'X': 23, + 'Y': 24, 'Z': 25, 'a': 26, 'b': 27, 'c': 28, 'd': 29, 'e': 30, 'f': 31, + 'g': 32, 'h': 33, 'i': 34, 'j': 35, 'k': 36, 'l': 37, 'm': 38, 'n': 39, + 'o': 40, 'p': 41, 'q': 42, 'r': 43, 's': 44, 't': 45, 'u': 46, 'v': 47, + 'w': 48, 'x': 49, 'y': 50, 'z': 51, '0': 52, '1': 53, '2': 54, '3': 55, + '4': 56, '5': 57, '6': 58, '7': 59, '8': 60, '9': 61, '!': 62, '#': 63, + '$': 64, '%': 65, '&': 66, '(': 67, ')': 68, '*': 69, '+': 70, ',': 71, + '.': 72, '/': 73, ':': 74, ';': 75, '<': 76, '=': 77, '>': 78, '?': 79, + '@': 80, '[': 81, ']': 82, '^': 83, '_': 84, '`': 85, '{': 86, '|': 87, + '}': 88, '~': 89, '\'': 90, +} + +func Base91EncodeToString(d []byte) string { + return string(Base91Encode(d)) +} + +func Base91Encode(d []byte) []byte { + var n, b uint + var o []byte + + for i := 0; i < len(d); i++ { + b |= uint(d[i]) << n + n += 8 + if n > 13 { + v := b & 8191 + if v > 88 { + b >>= 13 + n -= 13 + } else { + v = b & 16383 + b >>= 14 + n -= 14 + } + o = append(o, enctab[v%91], enctab[v/91]) + } + } + if n > 0 { + o = append(o, enctab[b%91]) + if n > 7 || b > 90 { + o = append(o, enctab[b/91]) + } + } + return o +} + +func Base91DecodeString(d string) []byte { + return Base91Decode([]byte(d)) +} + +func Base91Decode(d []byte) []byte { + var b, n uint + var o []byte + v := -1 + + for i := 0; i < len(d); i++ { + c, ok := dectab[d[i]] + if !ok { + continue + } + if v < 0 { + v = int(c) + } else { + v += int(c) * 91 + b |= uint(v) << n + if v&8191 > 88 { + n += 13 + } else { + n += 14 + } + o = append(o, byte(b&255)) + b >>= 8 + n -= 8 + for n > 7 { + o = append(o, byte(b&255)) + b >>= 8 + n -= 8 + } + v = -1 + } + } + if v+1 > 0 { + o = append(o, byte((b|uint(v)<>whichByte)) + bufByte = (v & byte((1< 1 { + buf = append(buf, bufByte|(v>>(8-whichByte))) + } + bufByte = v << whichByte + if whichByte == 8 { + whichByte = 0 + } + whichByte++ + } + return len(buf), nil +} + +func Base128DecodeString(s string) ([]byte, error) { + src := []byte(s) + dst := make([]byte, Base128DecodedLen(len(src))) + n, err := Base128Decode(dst, src) + if err != nil { + return nil, err + } + return dst[:n], nil +} + +func Base128DecodedLen(encLen int) int { + return encLen * 7 / 8 +} + +func Base128EncodedLen(dataLen int) int { + return ((dataLen * 8) + 6) / 7 +} + +func Base128EncodeToString(src []byte) string { + dst := make([]byte, Base128EncodedLen(len(src))) + Base128Encode(dst, src) + return string(dst) +} + +func Base64Encode(bstr []byte) string { + return base64.StdEncoding.EncodeToString(bstr) +} + +func Base64Decode(str string) ([]byte, error) { + return base64.StdEncoding.DecodeString(str) +} + +func Base85Encode(bstr []byte) string { + out := make([]byte, ascii85.MaxEncodedLen(len(bstr))) + n := ascii85.Encode(out, bstr) + return string(out[:n]) +} + +func Base85Decode(str string) ([]byte, error) { + out := make([]byte, len(str)) + n, _, err := ascii85.Decode(out, []byte(str), true) + if err != nil { + return nil, err + } + return out[:n], nil +} + +func Base85EncodeFile(src, dst string, progress func(float64)) error { + fpsrc, err := os.Open(src) + if err != nil { + return err + } + defer fpsrc.Close() + + stat, err := fpsrc.Stat() + if err != nil { + return err + } + + fpdst, err := os.Create(dst) + if err != nil { + return err + } + defer fpdst.Close() + + enc := ascii85.NewEncoder(fpdst) + defer enc.Close() + + var sum int64 + buf := make([]byte, 1000*1024) + for { + n, readErr := fpsrc.Read(buf) + if n > 0 { + sum += int64(n) + if _, err := enc.Write(buf[:n]); err != nil { + return err + } + reportProgress(progress, sum, stat.Size()) + } + if readErr != nil { + if readErr == io.EOF { + break + } + return readErr + } + } + return nil +} + +func Base85DecodeFile(src, dst string, progress func(float64)) error { + fpsrc, err := os.Open(src) + if err != nil { + return err + } + defer fpsrc.Close() + + stat, err := fpsrc.Stat() + if err != nil { + return err + } + + counter := &countingReader{r: fpsrc} + dec := ascii85.NewDecoder(counter) + + fpdst, err := os.Create(dst) + if err != nil { + return err + } + defer fpdst.Close() + + buf := make([]byte, 1250*1024) + for { + n, readErr := dec.Read(buf) + if n > 0 { + if _, err := fpdst.Write(buf[:n]); err != nil { + return err + } + reportProgress(progress, counter.n, stat.Size()) + } + if readErr != nil { + if readErr == io.EOF { + break + } + return readErr + } + } + return nil +} + +func Base64EncodeFile(src, dst string, progress func(float64)) error { + fpsrc, err := os.Open(src) + if err != nil { + return err + } + defer fpsrc.Close() + + stat, err := fpsrc.Stat() + if err != nil { + return err + } + + fpdst, err := os.Create(dst) + if err != nil { + return err + } + defer fpdst.Close() + + enc := base64.NewEncoder(base64.StdEncoding, fpdst) + defer enc.Close() + + var sum int64 + buf := make([]byte, 1024*1024) + for { + n, readErr := fpsrc.Read(buf) + if n > 0 { + sum += int64(n) + if _, err := enc.Write(buf[:n]); err != nil { + return err + } + reportProgress(progress, sum, stat.Size()) + } + if readErr != nil { + if readErr == io.EOF { + break + } + return readErr + } + } + return nil +} + +func Base64DecodeFile(src, dst string, progress func(float64)) error { + fpsrc, err := os.Open(src) + if err != nil { + return err + } + defer fpsrc.Close() + + stat, err := fpsrc.Stat() + if err != nil { + return err + } + + counter := &countingReader{r: fpsrc} + dec := base64.NewDecoder(base64.StdEncoding, counter) + + fpdst, err := os.Create(dst) + if err != nil { + return err + } + defer fpdst.Close() + + buf := make([]byte, 1024*1024) + for { + n, readErr := dec.Read(buf) + if n > 0 { + if _, err := fpdst.Write(buf[:n]); err != nil { + return err + } + reportProgress(progress, counter.n, stat.Size()) + } + if readErr != nil { + if readErr == io.EOF { + break + } + return readErr + } + } + return nil +} + +type countingReader struct { + r io.Reader + n int64 +} + +func (c *countingReader) Read(p []byte) (int, error) { + n, err := c.r.Read(p) + c.n += int64(n) + return n, err +} + +func reportProgress(progress func(float64), current, total int64) { + if progress == nil { + return + } + if total <= 0 { + progress(100) + return + } + progress(float64(current) / float64(total) * 100) +} diff --git a/encodingx/encoding_test.go b/encodingx/encoding_test.go new file mode 100644 index 0000000..c016ca9 --- /dev/null +++ b/encodingx/encoding_test.go @@ -0,0 +1,98 @@ +package encodingx + +import ( + "bytes" + "os" + "path/filepath" + "testing" +) + +func TestBase64AndBase85RoundTrip(t *testing.T) { + plain := []byte("encoding-roundtrip") + + b64 := Base64Encode(plain) + d64, err := Base64Decode(b64) + if err != nil { + t.Fatalf("Base64Decode failed: %v", err) + } + if !bytes.Equal(d64, plain) { + t.Fatalf("base64 mismatch") + } + + b85 := Base85Encode(plain) + d85, err := Base85Decode(b85) + if err != nil { + t.Fatalf("Base85Decode failed: %v", err) + } + if !bytes.Equal(d85, plain) { + t.Fatalf("base85 mismatch") + } +} + +func TestBase91AndBase128RoundTrip(t *testing.T) { + plain := []byte("base91-base128") + + e91 := Base91Encode(plain) + d91 := Base91Decode(e91) + if !bytes.Equal(d91, plain) { + t.Fatalf("base91 mismatch") + } + + e128 := Base128EncodeToString(plain) + d128, err := Base128DecodeString(e128) + if err != nil { + t.Fatalf("Base128DecodeString failed: %v", err) + } + if !bytes.Equal(d128, plain) { + t.Fatalf("base128 mismatch") + } +} + +func TestBase128DecodeInvalid(t *testing.T) { + _, err := Base128DecodeString(string([]byte{0x80})) + if err == nil { + t.Fatalf("expected base128 decode error") + } +} + +func TestBase64AndBase85FileRoundTrip(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "src.bin") + b64 := filepath.Join(dir, "src.b64") + dst64 := filepath.Join(dir, "src.64.out") + b85 := filepath.Join(dir, "src.b85") + dst85 := filepath.Join(dir, "src.85.out") + + plain := []byte("file-roundtrip-encoding") + if err := os.WriteFile(src, plain, 0o644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + + if err := Base64EncodeFile(src, b64, nil); err != nil { + t.Fatalf("Base64EncodeFile failed: %v", err) + } + if err := Base64DecodeFile(b64, dst64, nil); err != nil { + t.Fatalf("Base64DecodeFile failed: %v", err) + } + got64, err := os.ReadFile(dst64) + if err != nil { + t.Fatalf("ReadFile dst64 failed: %v", err) + } + if !bytes.Equal(got64, plain) { + t.Fatalf("base64 file roundtrip mismatch") + } + + if err := Base85EncodeFile(src, b85, nil); err != nil { + t.Fatalf("Base85EncodeFile failed: %v", err) + } + if err := Base85DecodeFile(b85, dst85, nil); err != nil { + t.Fatalf("Base85DecodeFile failed: %v", err) + } + got85, err := os.ReadFile(dst85) + if err != nil { + t.Fatalf("ReadFile dst85 failed: %v", err) + } + if !bytes.Equal(got85, plain) { + t.Fatalf("base85 file roundtrip mismatch") + } +} diff --git a/encodingx/fuzz_test.go b/encodingx/fuzz_test.go new file mode 100644 index 0000000..e5f3c39 --- /dev/null +++ b/encodingx/fuzz_test.go @@ -0,0 +1,35 @@ +package encodingx + +import ( + "bytes" + "testing" +) + +func FuzzBase128RoundTrip(f *testing.F) { + f.Add([]byte("base128")) + f.Add([]byte{}) + + f.Fuzz(func(t *testing.T, data []byte) { + e := Base128EncodeToString(data) + d, err := Base128DecodeString(e) + if err != nil { + t.Fatalf("Base128DecodeString failed: %v", err) + } + if !bytes.Equal(d, data) { + t.Fatalf("base128 roundtrip mismatch") + } + }) +} + +func FuzzBase91RoundTrip(f *testing.F) { + f.Add([]byte("base91")) + f.Add([]byte{}) + + f.Fuzz(func(t *testing.T, data []byte) { + e := Base91Encode(data) + d := Base91Decode(e) + if !bytes.Equal(d, data) { + t.Fatalf("base91 roundtrip mismatch") + } + }) +} diff --git a/file.go b/file.go index 1b1c276..f1b46a5 100644 --- a/file.go +++ b/file.go @@ -1,246 +1,23 @@ package starcrypto -import ( - "bufio" - "errors" - "fmt" - "io" - "io/ioutil" - "math/rand" - "os" - "path/filepath" - "regexp" - "strconv" - "strings" - "time" -) +import "b612.me/starcrypto/filex" -// Attach 合并src与dst文件并输出到output中 func Attach(src, dst, output string) error { - fpsrc, err := os.Open(src) - if err != nil { - return err - } - defer fpsrc.Close() - fpdst, err := os.Open(dst) - if err != nil { - return err - } - defer fpdst.Close() - fpout, err := os.Create(output) - if err != nil { - return err - } - defer fpout.Close() - if _, err := io.Copy(fpout, fpsrc); err != nil { - return err - } - for { - buf := make([]byte, 1048574) - n, err := fpdst.Read(buf) - if err != nil { - if err == io.EOF { - break - } - return err - } - fpout.Write(buf[0:n]) - } - return nil + return filex.Attach(src, dst, output) } -// Detach 按bytenum字节大小分割src文件到dst1与dst2两个新文件中去 func Detach(src string, bytenum int, dst1, dst2 string) error { - fpsrc, err := os.Open(src) - if err != nil { - return err - } - defer fpsrc.Close() - fpdst1, err := os.Create(dst1) - if err != nil { - return err - } - defer fpdst1.Close() - fpdst2, err := os.Create(dst2) - if err != nil { - return err - } - defer fpdst2.Close() - sumall := 0 - var buf []byte - for { - if bytenum-sumall < 1048576 { - buf = make([]byte, bytenum-sumall) - } else { - buf = make([]byte, 1048576) - } - n, err := fpsrc.Read(buf) - if err != nil { - return err - } - sumall += n - fpdst1.Write(buf[0:n]) - if sumall == bytenum { - break - } - } - for { - buf = make([]byte, 1048576) - n, err := fpsrc.Read(buf) - if err != nil { - if err == io.EOF { - break - } - return err - } - fpdst2.Write(buf[0:n]) - } - return nil + return filex.Detach(src, bytenum, dst1, dst2) } -// SplitFile 把src文件按要求分割到dst中去,dst应传入带*号字符串 -// 如果bynum=true 则把文件分割成num份 -// 如果bynum=false 则把文件按num字节分成多份 func SplitFile(src, dst string, num int, bynum bool, shell func(float64)) error { - fpsrc, err := os.Open(src) - if err != nil { - return err - } - defer fpsrc.Close() - stat, _ := os.Stat(src) - filebig := float64(stat.Size()) - if bynum { - if int(filebig) < num { - return errors.New("file is too small to split") - } - } - balance := int(filebig/float64(num)) + 1 - if !bynum { - balance = num - } - nownum := 0 - fpdst, err := os.Create(strings.Replace(dst, "*", fmt.Sprint(nownum), -1)) - if err != nil { - return err - } - defer fpdst.Close() - var sum, tsum int = 0, 0 - var buf []byte - for { - if balance-sum < 1048576 { - buf = make([]byte, balance-sum) - } else { - buf = make([]byte, 1048576) - } - n, err := fpsrc.Read(buf) - if err != nil { - if err == io.EOF { - break - } - return err - } - sum += n - tsum += n - fpdst.Write(buf[0:n]) - go shell(float64(tsum) / filebig * 100) - if sum == balance { - fpdst.Close() - nownum++ - fpdst, err = os.Create(strings.Replace(dst, "*", fmt.Sprint(nownum), -1)) - if err != nil { - return err - } - sum = 0 - } - } - return nil + return filex.SplitFile(src, dst, num, bynum, shell) } -// MergeFile 合并src文件到dst文件中去,src文件应传入带*号字符串 func MergeFile(src, dst string, shell func(float64)) error { - tmp := strings.Replace(src, "*", "0", -1) - dir, err := ioutil.ReadDir(filepath.Dir(tmp)) - if err != nil { - return err - } - base := filepath.Base(src) - tmp = strings.Replace(base, "*", "(\\d+)", -1) - reg := regexp.MustCompile(tmp) - count := 0 - var filebig float64 - for _, v := range dir { - if reg.MatchString(v.Name()) { - count++ - filebig += float64(v.Size()) - } - } - fpdst, err := os.Create(dst) - defer fpdst.Close() - if err != nil { - return err - } - var sum int64 - for i := 0; i < count; i++ { - fpsrc, err := os.Open(strings.Replace(src, "*", strconv.Itoa(i), -1)) - if err != nil { - return err - } - for { - buf := make([]byte, 1048576) - n, err := fpsrc.Read(buf) - if err != nil { - if err == io.EOF { - break - } - return err - } - sum += int64(n) - go shell(float64(sum) / filebig * 100) - fpdst.Write(buf[0:n]) - } - fpsrc.Close() - } - return nil + return filex.MergeFile(src, dst, shell) } -// FillWithRandom 随机写filesize大小的文件,每次buf大小为bufcap,随机bufnum个字符 func FillWithRandom(filepath string, filesize int, bufcap int, bufnum int, shell func(float64)) error { - var buf [][]byte - var buftmp []byte - rand.Seed(time.Now().Unix()) - if bufnum <= 0 { - bufnum = 1 - } - if bufcap > filesize { - bufcap = filesize - } - myfile, err := os.Create(filepath) - if err != nil { - return err - } - defer myfile.Close() - writer := bufio.NewWriter(myfile) - for i := 0; i < bufnum; i++ { - buftmp = []byte{} - for j := 0; j < bufcap; j++ { - buftmp = append(buftmp, byte(rand.Intn(256))) - } - buf = append(buf, buftmp) - } - sum := 0 - for { - if filesize-sum < bufcap { - writer.Write(buf[rand.Intn(bufnum)][0 : filesize-sum]) - sum += filesize - sum - } else { - writer.Write(buf[rand.Intn(bufnum)]) - sum += bufcap - } - go shell(float64(sum) / float64(filesize) * 100) - if sum >= filesize { - break - } - } - writer.Flush() - return nil + return filex.FillWithRandom(filepath, filesize, bufcap, bufnum, shell) } diff --git a/filex/file.go b/filex/file.go new file mode 100644 index 0000000..c0a7160 --- /dev/null +++ b/filex/file.go @@ -0,0 +1,311 @@ +package filex + +import ( + "bufio" + "errors" + "fmt" + "io" + "math/rand" + "os" + "path/filepath" + "regexp" + "sort" + "strconv" + "strings" + "time" +) + +func Attach(src, dst, output string) error { + fpsrc, err := os.Open(src) + if err != nil { + return err + } + defer fpsrc.Close() + + fpdst, err := os.Open(dst) + if err != nil { + return err + } + defer fpdst.Close() + + fpout, err := os.Create(output) + if err != nil { + return err + } + defer fpout.Close() + + if _, err := io.Copy(fpout, fpsrc); err != nil { + return err + } + if _, err := io.Copy(fpout, fpdst); err != nil { + return err + } + return nil +} + +func Detach(src string, bytenum int, dst1, dst2 string) error { + if bytenum < 0 { + return errors.New("bytenum must be non-negative") + } + + fpsrc, err := os.Open(src) + if err != nil { + return err + } + defer fpsrc.Close() + + fpdst1, err := os.Create(dst1) + if err != nil { + return err + } + defer fpdst1.Close() + + fpdst2, err := os.Create(dst2) + if err != nil { + return err + } + defer fpdst2.Close() + + if bytenum > 0 { + if _, err := io.CopyN(fpdst1, fpsrc, int64(bytenum)); err != nil && err != io.EOF { + return err + } + } + if _, err := io.Copy(fpdst2, fpsrc); err != nil { + return err + } + return nil +} + +func SplitFile(src, dst string, num int, bynum bool, progress func(float64)) error { + if num <= 0 { + return errors.New("num must be greater than zero") + } + + fpsrc, err := os.Open(src) + if err != nil { + return err + } + defer fpsrc.Close() + + stat, err := fpsrc.Stat() + if err != nil { + return err + } + total := stat.Size() + if total == 0 { + return errors.New("file is empty") + } + + var sizes []int64 + if bynum { + if total < int64(num) { + return errors.New("file is too small to split") + } + base := total / int64(num) + rest := total % int64(num) + sizes = make([]int64, 0, num) + for i := 0; i < num; i++ { + sz := base + if int64(i) < rest { + sz++ + } + sizes = append(sizes, sz) + } + } else { + chunk := int64(num) + for remain := total; remain > 0; { + sz := chunk + if remain < chunk { + sz = remain + } + sizes = append(sizes, sz) + remain -= sz + } + } + + var copied int64 + buf := make([]byte, 1024*1024) + for i, partSize := range sizes { + name := strings.Replace(dst, "*", fmt.Sprint(i), -1) + fpdst, err := os.Create(name) + if err != nil { + return err + } + + remaining := partSize + for remaining > 0 { + readLen := int64(len(buf)) + if remaining < readLen { + readLen = remaining + } + n, readErr := fpsrc.Read(buf[:readLen]) + if n > 0 { + if _, err := fpdst.Write(buf[:n]); err != nil { + fpdst.Close() + return err + } + remaining -= int64(n) + copied += int64(n) + reportProgress(progress, copied, total) + } + if readErr != nil { + if readErr == io.EOF && remaining == 0 { + break + } + fpdst.Close() + return readErr + } + } + + if err := fpdst.Close(); err != nil { + return err + } + } + return nil +} + +func MergeFile(src, dst string, progress func(float64)) error { + tmp := strings.Replace(src, "*", "0", -1) + dirEntries, err := os.ReadDir(filepath.Dir(tmp)) + if err != nil { + return err + } + + base := filepath.Base(src) + pattern := strings.Replace(base, "*", "(\\d+)", -1) + reg := regexp.MustCompile("^" + pattern + "$") + + type indexedFile struct { + index int + name string + size int64 + } + files := make([]indexedFile, 0) + var total int64 + for _, entry := range dirEntries { + m := reg.FindStringSubmatch(entry.Name()) + if len(m) != 2 { + continue + } + idx, err := strconv.Atoi(m[1]) + if err != nil { + continue + } + info, err := entry.Info() + if err != nil { + return err + } + files = append(files, indexedFile{index: idx, name: entry.Name(), size: info.Size()}) + total += info.Size() + } + if len(files) == 0 { + return errors.New("no split files found") + } + + sort.Slice(files, func(i, j int) bool { return files[i].index < files[j].index }) + + fpdst, err := os.Create(dst) + if err != nil { + return err + } + defer fpdst.Close() + + var copied int64 + buf := make([]byte, 1024*1024) + for _, f := range files { + path := filepath.Join(filepath.Dir(tmp), f.name) + fpsrc, err := os.Open(path) + if err != nil { + return err + } + for { + n, readErr := fpsrc.Read(buf) + if n > 0 { + if _, err := fpdst.Write(buf[:n]); err != nil { + fpsrc.Close() + return err + } + copied += int64(n) + reportProgress(progress, copied, total) + } + if readErr != nil { + if readErr == io.EOF { + break + } + fpsrc.Close() + return readErr + } + } + if err := fpsrc.Close(); err != nil { + return err + } + } + return nil +} + +func FillWithRandom(path string, filesize, bufcap, bufnum int, progress func(float64)) error { + if filesize < 0 { + return errors.New("filesize must be non-negative") + } + if bufnum <= 0 { + bufnum = 1 + } + if bufcap <= 0 { + bufcap = 1 + } + if bufcap > filesize && filesize > 0 { + bufcap = filesize + } + + rand.Seed(time.Now().UnixNano()) + + fp, err := os.Create(path) + if err != nil { + return err + } + defer fp.Close() + + writer := bufio.NewWriter(fp) + defer writer.Flush() + + if filesize == 0 { + reportProgress(progress, 0, 0) + return nil + } + + pool := make([][]byte, 0, bufnum) + for i := 0; i < bufnum; i++ { + b := make([]byte, bufcap) + for j := 0; j < bufcap; j++ { + b[j] = byte(rand.Intn(256)) + } + pool = append(pool, b) + } + + written := 0 + for written < filesize { + chunk := bufcap + if filesize-written < chunk { + chunk = filesize - written + } + buf := pool[rand.Intn(len(pool))][:chunk] + if _, err := writer.Write(buf); err != nil { + return err + } + written += chunk + reportProgress(progress, int64(written), int64(filesize)) + } + return nil +} + +func reportProgress(progress func(float64), current, total int64) { + if progress == nil { + return + } + if total <= 0 { + progress(100) + return + } + progress(float64(current) / float64(total) * 100) +} diff --git a/go.mod b/go.mod index e441e45..693cc34 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,10 @@ module b612.me/starcrypto -go 1.16 +go 1.24.0 -require golang.org/x/crypto v0.21.0 +require ( + github.com/emmansun/gmsm v0.41.1 + golang.org/x/crypto v0.48.0 +) + +require golang.org/x/sys v0.41.0 // indirect diff --git a/go.sum b/go.sum index 78c3953..dc1cad3 100644 --- a/go.sum +++ b/go.sum @@ -1,45 +1,8 @@ -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= -golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= -golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +github.com/emmansun/gmsm v0.41.1 h1:mD1MqmaXTEqt+9UVmDpRYvcEMIa5vuslFEnw7IWp6/w= +github.com/emmansun/gmsm v0.41.1/go.mod h1:FD1EQk4XcSMkahZFzNwFoI/uXzAlODB9JVsJ9G5N7Do= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= +golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= diff --git a/hash.go b/hash.go index d9c7a86..988c12d 100644 --- a/hash.go +++ b/hash.go @@ -1,204 +1,15 @@ package starcrypto -import ( - "crypto/md5" - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "encoding/hex" - "errors" - "hash" - "hash/crc32" - "io" - "os" -) +import "b612.me/starcrypto/hashx" -// SumAll 可以对同一数据进行多种校验 func SumAll(data []byte, method []string) (map[string][]byte, error) { - result := make(map[string][]byte) - methods := make(map[string]hash.Hash) - var iscrc bool - if len(method) == 0 { - method = []string{"sha512", "sha256", "sha384", "sha224", "sha1", "crc32", "md5"} - } - sum512 := sha512.New() - sum384 := sha512.New384() - sum256 := sha256.New() - sum224 := sha256.New224() - sum1 := sha1.New() - crcsum := crc32.NewIEEE() - md5sum := md5.New() - for _, v := range method { - switch v { - case "md5": - methods["md5"] = md5sum - case "crc32": - iscrc = true - case "sha1": - methods["sha1"] = sum1 - case "sha224": - methods["sha224"] = sum224 - case "sha256": - methods["sha256"] = sum256 - case "sha384": - methods["sha384"] = sum384 - case "sha512": - methods["sha512"] = sum512 - } - } - for _, v := range methods { - v.Write(data) - } - if iscrc { - crcsum.Write(data) - } - - for k, v := range methods { - result[k] = v.Sum(nil) - } - if iscrc { - result["crc32"] = crcsum.Sum(nil) - } - return result, nil + return hashx.SumAll(data, method) } -// FileSum 输出文件内容校验值,method为单个校验方法,小写 -// 例:FileSum("./test.txt","md5",shell(pect float64){fmt.Sprintf("已完成 %f\r",pect)}) func FileSum(filepath, method string, shell func(float64)) (string, error) { - var sum hash.Hash - var sum32 hash.Hash32 - var issum32 bool - var result string - fp, err := os.Open(filepath) - if err != nil { - return "", err - } - switch method { - case "sha512": - sum = sha512.New() - case "sha384": - sum = sha512.New384() - case "sha256": - sum = sha256.New() - case "sha224": - sum = sha256.New224() - case "sha1": - sum = sha1.New() - case "crc32": - sum32 = crc32.NewIEEE() - issum32 = true - case "md5": - sum = md5.New() - default: - return "", errors.New("Cannot Recognize The Method:" + method) - } - writer := 0 - stat, _ := os.Stat(filepath) - filebig := float64(stat.Size()) - if !issum32 { - // if _, err := io.Copy(sum, fp); err != nil { - for { - buf := make([]byte, 1048574) - n, err := fp.Read(buf) - if err != nil { - if err == io.EOF { - break - } - return result, err - } - writer += n - pect := (float64(writer) / filebig) * 100 - go shell(pect) - sum.Write(buf[0:n]) - } - result = hex.EncodeToString(sum.Sum(nil)) - } else { - for { - buf := make([]byte, 1048574) - n, err := fp.Read(buf) - if err != nil { - if err == io.EOF { - break - } - return result, err - } - writer += n - pect := (float64(writer) / filebig) * 100 - go shell(pect) - sum32.Write(buf[0:n]) - } - result = hex.EncodeToString(sum32.Sum(nil)) - } - return result, nil + return hashx.FileSum(filepath, method, shell) } -// FileSumAll 可以对同一文件进行多种校验 func FileSumAll(filepath string, method []string, shell func(float64)) (map[string]string, error) { - result := make(map[string]string) - methods := make(map[string]hash.Hash) - var iscrc bool - - if len(method) == 0 { - method = []string{"sha512", "sha256", "sha384", "sha224", "sha1", "crc32", "md5"} - } - fp, err := os.Open(filepath) - defer fp.Close() - if err != nil { - return result, err - } - stat, _ := os.Stat(filepath) - filebig := float64(stat.Size()) - sum512 := sha512.New() - sum384 := sha512.New384() - sum256 := sha256.New() - sum224 := sha256.New224() - sum1 := sha1.New() - crcsum := crc32.NewIEEE() - md5sum := md5.New() - for _, v := range method { - switch v { - case "md5": - methods["md5"] = md5sum - case "crc32": - iscrc = true - case "sha1": - methods["sha1"] = sum1 - case "sha224": - methods["sha224"] = sum224 - case "sha256": - methods["sha256"] = sum256 - case "sha384": - methods["sha384"] = sum384 - case "sha512": - methods["sha512"] = sum512 - } - } - - writer := 0 - for { - buf := make([]byte, 1048574) - n, err := fp.Read(buf) - if err != nil { - if err == io.EOF { - break - } - return result, err - } - writer += n - pect := (float64(writer) / filebig) * 100 - go shell(pect) - for _, v := range methods { - v.Write(buf[0:n]) - } - if iscrc { - crcsum.Write(buf[0:n]) - } - } - for k, v := range methods { - result[k] = hex.EncodeToString(v.Sum(nil)) - } - if iscrc { - result["crc32"] = hex.EncodeToString(crcsum.Sum(nil)) - } - return result, nil + return hashx.FileSumAll(filepath, method, shell) } diff --git a/hashx/hashx.go b/hashx/hashx.go new file mode 100644 index 0000000..c0163a5 --- /dev/null +++ b/hashx/hashx.go @@ -0,0 +1,420 @@ +package hashx + +import ( + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/binary" + "encoding/hex" + "errors" + "hash" + "hash/crc32" + "io" + "os" + + gmsm3 "github.com/emmansun/gmsm/sm3" + "golang.org/x/crypto/md4" + "golang.org/x/crypto/ripemd160" +) + +var ( + // ErrUnsupportedMethod reports an unknown hash method string. + ErrUnsupportedMethod = errors.New("unsupported hash method") +) + +func HexString(b []byte) string { + return hex.EncodeToString(b) +} + +func Md5(b []byte) []byte { + sum := md5.New() + _, _ = sum.Write(b) + return sum.Sum(nil) +} + +func Md5Str(b []byte) string { + return HexString(Md5(b)) +} + +func Md4(b []byte) []byte { + sum := md4.New() + _, _ = sum.Write(b) + return sum.Sum(nil) +} + +func Md4Str(b []byte) string { + return HexString(Md4(b)) +} + +func Sha512(b []byte) []byte { + sum := sha512.New() + _, _ = sum.Write(b) + return sum.Sum(nil) +} + +func Sha512Str(b []byte) string { + return HexString(Sha512(b)) +} + +func Sha384(b []byte) []byte { + sum := sha512.New384() + _, _ = sum.Write(b) + return sum.Sum(nil) +} + +func Sha384Str(b []byte) string { + return HexString(Sha384(b)) +} + +func Sha256(b []byte) []byte { + sum := sha256.New() + _, _ = sum.Write(b) + return sum.Sum(nil) +} + +func Sha256Str(b []byte) string { + return HexString(Sha256(b)) +} + +func Sha224(b []byte) []byte { + sum := sha256.New224() + _, _ = sum.Write(b) + return sum.Sum(nil) +} + +func Sha224Str(b []byte) string { + return HexString(Sha224(b)) +} + +func Sha1(b []byte) []byte { + sum := sha1.New() + _, _ = sum.Write(b) + return sum.Sum(nil) +} + +func Sha1Str(b []byte) string { + return HexString(Sha1(b)) +} + +func RipeMd160(b []byte) []byte { + sum := ripemd160.New() + _, _ = sum.Write(b) + return sum.Sum(nil) +} + +func RipeMd160Str(b []byte) string { + return HexString(RipeMd160(b)) +} + +func SM3(b []byte) []byte { + sum := gmsm3.New() + _, _ = sum.Write(b) + return sum.Sum(nil) +} + +func SM3Str(b []byte) string { + return HexString(SM3(b)) +} + +// CheckCRC32A returns CRC32A as uint32 in big-endian view. +// CRC32A here is an MSB-first, non-reflected variant and is different +// from Go standard library CRC-32/IEEE. +func CheckCRC32A(data []byte) uint32 { + b := crc32aDigest(data) + return binary.BigEndian.Uint32(b) +} + +func Crc32Str(b []byte) string { + return HexString(Crc32(b)) +} + +func Crc32(b []byte) []byte { + sum := crc32.NewIEEE() + _, _ = sum.Write(b) + return sum.Sum(nil) +} + +// Crc32A computes CRC32A (MSB-first, non-reflected), which is not the +// same route as Go standard library CRC-32/IEEE. +func Crc32A(data []byte) []byte { + return crc32aDigest(data) +} + +// Crc32AStr returns hex string of Crc32A digest. +func Crc32AStr(data []byte) string { + return hex.EncodeToString(crc32aDigest(data)) +} + +func SumAll(data []byte, methods []string) (map[string][]byte, error) { + if len(methods) == 0 { + methods = []string{"sha512", "sha256", "sha384", "sha224", "sha1", "crc32", "md5"} + } + + result := make(map[string][]byte, len(methods)) + hashers := make(map[string]hash.Hash, len(methods)) + var crc hash.Hash32 + + for _, method := range methods { + switch method { + case "md5": + hashers[method] = md5.New() + case "sha1": + hashers[method] = sha1.New() + case "sha224": + hashers[method] = sha256.New224() + case "sha256": + hashers[method] = sha256.New() + case "sha384": + hashers[method] = sha512.New384() + case "sha512": + hashers[method] = sha512.New() + case "crc32": + if crc == nil { + crc = crc32.NewIEEE() + } + default: + // Keep compatibility with previous behavior: unknown methods are ignored. + } + } + + for _, h := range hashers { + _, _ = h.Write(data) + } + if crc != nil { + _, _ = crc.Write(data) + } + + for method, h := range hashers { + result[method] = h.Sum(nil) + } + if crc != nil { + result["crc32"] = crc.Sum(nil) + } + return result, nil +} + +func FileSum(filePath, method string, progress func(float64)) (string, error) { + fp, err := os.Open(filePath) + if err != nil { + return "", err + } + defer fp.Close() + + stat, err := fp.Stat() + if err != nil { + return "", err + } + + var ( + h hash.Hash + h32 hash.Hash32 + is32 bool + total int64 + size = stat.Size() + ) + + switch method { + case "sha512": + h = sha512.New() + case "sha384": + h = sha512.New384() + case "sha256": + h = sha256.New() + case "sha224": + h = sha256.New224() + case "sha1": + h = sha1.New() + case "md5": + h = md5.New() + case "crc32": + h32 = crc32.NewIEEE() + is32 = true + default: + return "", errors.New(ErrUnsupportedMethod.Error() + ": " + method) + } + + buf := make([]byte, 1024*1024) + for { + n, readErr := fp.Read(buf) + if n > 0 { + total += int64(n) + if is32 { + _, _ = h32.Write(buf[:n]) + } else { + _, _ = h.Write(buf[:n]) + } + reportProgress(progress, total, size) + } + if readErr != nil { + if readErr == io.EOF { + break + } + return "", readErr + } + } + + if is32 { + return hex.EncodeToString(h32.Sum(nil)), nil + } + return hex.EncodeToString(h.Sum(nil)), nil +} + +func FileSumAll(filePath string, methods []string, progress func(float64)) (map[string]string, error) { + if len(methods) == 0 { + methods = []string{"sha512", "sha256", "sha384", "sha224", "sha1", "crc32", "md5"} + } + + fp, err := os.Open(filePath) + if err != nil { + return nil, err + } + defer fp.Close() + + stat, err := fp.Stat() + if err != nil { + return nil, err + } + + hashers := make(map[string]hash.Hash, len(methods)) + var crc hash.Hash32 + for _, method := range methods { + switch method { + case "md5": + hashers[method] = md5.New() + case "sha1": + hashers[method] = sha1.New() + case "sha224": + hashers[method] = sha256.New224() + case "sha256": + hashers[method] = sha256.New() + case "sha384": + hashers[method] = sha512.New384() + case "sha512": + hashers[method] = sha512.New() + case "crc32": + if crc == nil { + crc = crc32.NewIEEE() + } + default: + // Keep compatibility with previous behavior: unknown methods are ignored. + } + } + + var total int64 + size := stat.Size() + buf := make([]byte, 1024*1024) + for { + n, readErr := fp.Read(buf) + if n > 0 { + total += int64(n) + chunk := buf[:n] + for _, h := range hashers { + _, _ = h.Write(chunk) + } + if crc != nil { + _, _ = crc.Write(chunk) + } + reportProgress(progress, total, size) + } + + if readErr != nil { + if readErr == io.EOF { + break + } + return nil, readErr + } + } + + result := make(map[string]string, len(hashers)+1) + for method, h := range hashers { + result[method] = hex.EncodeToString(h.Sum(nil)) + } + if crc != nil { + result["crc32"] = hex.EncodeToString(crc.Sum(nil)) + } + return result, nil +} + +func reportProgress(progress func(float64), current, total int64) { + if progress == nil { + return + } + if total <= 0 { + progress(100) + return + } + progress(float64(current) / float64(total) * 100) +} + +func crc32aDigest(data []byte) []byte { + var crc uint32 + digest := make([]byte, 4) + + crc = ^crc + for i := 0; i < len(data); i++ { + crc = (crc << 8) ^ crc32aTable[(crc>>24)^(uint32(data[i])&0xff)] + } + crc = ^crc + + digest[3] = byte((crc >> 24) & 0xff) + digest[2] = byte((crc >> 16) & 0xff) + digest[1] = byte((crc >> 8) & 0xff) + digest[0] = byte(crc & 0xff) + return digest +} + +var crc32aTable = [256]uint32{ + 0x0, + 0x04c11db7, 0x09823b6e, 0x0d4326d9, 0x130476dc, 0x17c56b6b, + 0x1a864db2, 0x1e475005, 0x2608edb8, 0x22c9f00f, 0x2f8ad6d6, + 0x2b4bcb61, 0x350c9b64, 0x31cd86d3, 0x3c8ea00a, 0x384fbdbd, + 0x4c11db70, 0x48d0c6c7, 0x4593e01e, 0x4152fda9, 0x5f15adac, + 0x5bd4b01b, 0x569796c2, 0x52568b75, 0x6a1936c8, 0x6ed82b7f, + 0x639b0da6, 0x675a1011, 0x791d4014, 0x7ddc5da3, 0x709f7b7a, + 0x745e66cd, 0x9823b6e0, 0x9ce2ab57, 0x91a18d8e, 0x95609039, + 0x8b27c03c, 0x8fe6dd8b, 0x82a5fb52, 0x8664e6e5, 0xbe2b5b58, + 0xbaea46ef, 0xb7a96036, 0xb3687d81, 0xad2f2d84, 0xa9ee3033, + 0xa4ad16ea, 0xa06c0b5d, 0xd4326d90, 0xd0f37027, 0xddb056fe, + 0xd9714b49, 0xc7361b4c, 0xc3f706fb, 0xceb42022, 0xca753d95, + 0xf23a8028, 0xf6fb9d9f, 0xfbb8bb46, 0xff79a6f1, 0xe13ef6f4, + 0xe5ffeb43, 0xe8bccd9a, 0xec7dd02d, 0x34867077, 0x30476dc0, + 0x3d044b19, 0x39c556ae, 0x278206ab, 0x23431b1c, 0x2e003dc5, + 0x2ac12072, 0x128e9dcf, 0x164f8078, 0x1b0ca6a1, 0x1fcdbb16, + 0x018aeb13, 0x054bf6a4, 0x0808d07d, 0x0cc9cdca, 0x7897ab07, + 0x7c56b6b0, 0x71159069, 0x75d48dde, 0x6b93dddb, 0x6f52c06c, + 0x6211e6b5, 0x66d0fb02, 0x5e9f46bf, 0x5a5e5b08, 0x571d7dd1, + 0x53dc6066, 0x4d9b3063, 0x495a2dd4, 0x44190b0d, 0x40d816ba, + 0xaca5c697, 0xa864db20, 0xa527fdf9, 0xa1e6e04e, 0xbfa1b04b, + 0xbb60adfc, 0xb6238b25, 0xb2e29692, 0x8aad2b2f, 0x8e6c3698, + 0x832f1041, 0x87ee0df6, 0x99a95df3, 0x9d684044, 0x902b669d, + 0x94ea7b2a, 0xe0b41de7, 0xe4750050, 0xe9362689, 0xedf73b3e, + 0xf3b06b3b, 0xf771768c, 0xfa325055, 0xfef34de2, 0xc6bcf05f, + 0xc27dede8, 0xcf3ecb31, 0xcbffd686, 0xd5b88683, 0xd1799b34, + 0xdc3abded, 0xd8fba05a, 0x690ce0ee, 0x6dcdfd59, 0x608edb80, + 0x644fc637, 0x7a089632, 0x7ec98b85, 0x738aad5c, 0x774bb0eb, + 0x4f040d56, 0x4bc510e1, 0x46863638, 0x42472b8f, 0x5c007b8a, + 0x58c1663d, 0x558240e4, 0x51435d53, 0x251d3b9e, 0x21dc2629, + 0x2c9f00f0, 0x285e1d47, 0x36194d42, 0x32d850f5, 0x3f9b762c, + 0x3b5a6b9b, 0x0315d626, 0x07d4cb91, 0x0a97ed48, 0x0e56f0ff, + 0x1011a0fa, 0x14d0bd4d, 0x19939b94, 0x1d528623, 0xf12f560e, + 0xf5ee4bb9, 0xf8ad6d60, 0xfc6c70d7, 0xe22b20d2, 0xe6ea3d65, + 0xeba91bbc, 0xef68060b, 0xd727bbb6, 0xd3e6a601, 0xdea580d8, + 0xda649d6f, 0xc423cd6a, 0xc0e2d0dd, 0xcda1f604, 0xc960ebb3, + 0xbd3e8d7e, 0xb9ff90c9, 0xb4bcb610, 0xb07daba7, 0xae3afba2, + 0xaafbe615, 0xa7b8c0cc, 0xa379dd7b, 0x9b3660c6, 0x9ff77d71, + 0x92b45ba8, 0x9675461f, 0x8832161a, 0x8cf30bad, 0x81b02d74, + 0x857130c3, 0x5d8a9099, 0x594b8d2e, 0x5408abf7, 0x50c9b640, + 0x4e8ee645, 0x4a4ffbf2, 0x470cdd2b, 0x43cdc09c, 0x7b827d21, + 0x7f436096, 0x7200464f, 0x76c15bf8, 0x68860bfd, 0x6c47164a, + 0x61043093, 0x65c52d24, 0x119b4be9, 0x155a565e, 0x18197087, + 0x1cd86d30, 0x029f3d35, 0x065e2082, 0x0b1d065b, 0x0fdc1bec, + 0x3793a651, 0x3352bbe6, 0x3e119d3f, 0x3ad08088, 0x2497d08d, + 0x2056cd3a, 0x2d15ebe3, 0x29d4f654, 0xc5a92679, 0xc1683bce, + 0xcc2b1d17, 0xc8ea00a0, 0xd6ad50a5, 0xd26c4d12, 0xdf2f6bcb, + 0xdbee767c, 0xe3a1cbc1, 0xe760d676, 0xea23f0af, 0xeee2ed18, + 0xf0a5bd1d, 0xf464a0aa, 0xf9278673, 0xfde69bc4, 0x89b8fd09, + 0x8d79e0be, 0x803ac667, 0x84fbdbd0, 0x9abc8bd5, 0x9e7d9662, + 0x933eb0bb, 0x97ffad0c, 0xafb010b1, 0xab710d06, 0xa6322bdf, + 0xa2f33668, 0xbcb4666d, 0xb8757bda, 0xb5365d03, 0xb1f740b4, +} diff --git a/hashx/hashx_test.go b/hashx/hashx_test.go new file mode 100644 index 0000000..a4ccda3 --- /dev/null +++ b/hashx/hashx_test.go @@ -0,0 +1,89 @@ +package hashx + +import ( + "crypto/md5" + "encoding/hex" + "os" + "path/filepath" + "testing" +) + +func TestSha1StrMatchesSha1(t *testing.T) { + got := Sha1Str([]byte("abc")) + const want = "a9993e364706816aba3e25717850c26c9cd0d89d" + if got != want { + t.Fatalf("Sha1Str mismatch, got %s want %s", got, want) + } +} + +func TestSM3AndCRC32A(t *testing.T) { + if len(SM3([]byte("abc"))) != 32 { + t.Fatalf("SM3 digest size must be 32") + } + if Crc32AStr([]byte("123456789")) != "181989fc" { + t.Fatalf("Crc32AStr mismatch") + } + if CheckCRC32A([]byte("123456789")) != 0x181989fc { + t.Fatalf("CheckCRC32A mismatch") + } +} + +func TestSumAllUnknownMethodIgnored(t *testing.T) { + res, err := SumAll([]byte("abc"), []string{"sha1", "unknown"}) + if err != nil { + t.Fatalf("SumAll returned error: %v", err) + } + if _, ok := res["sha1"]; !ok { + t.Fatalf("expected sha1 in result") + } + if _, ok := res["unknown"]; ok { + t.Fatalf("unknown method should be ignored") + } +} + +func TestFileSumAndFileSumAll(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "sum.txt") + data := []byte("hash-file-content") + if err := os.WriteFile(path, data, 0o644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + + calls := 0 + h, err := FileSum(path, "md5", func(float64) { calls++ }) + if err != nil { + t.Fatalf("FileSum failed: %v", err) + } + expected := md5.Sum(data) + if h != hex.EncodeToString(expected[:]) { + t.Fatalf("md5 mismatch, got %s want %s", h, hex.EncodeToString(expected[:])) + } + if calls == 0 { + t.Fatalf("progress callback should be called") + } + + all, err := FileSumAll(path, []string{"sha1", "crc32", "unknown"}, nil) + if err != nil { + t.Fatalf("FileSumAll failed: %v", err) + } + if _, ok := all["sha1"]; !ok { + t.Fatalf("expected sha1 in FileSumAll") + } + if _, ok := all["crc32"]; !ok { + t.Fatalf("expected crc32 in FileSumAll") + } + if _, ok := all["unknown"]; ok { + t.Fatalf("unknown method should not appear") + } +} + +func TestFileSumUnsupportedMethod(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "sum.txt") + if err := os.WriteFile(path, []byte("x"), 0o644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + if _, err := FileSum(path, "not-support", nil); err == nil { + t.Fatalf("expected unsupported method error") + } +} diff --git a/hmac.go b/hmac.go index bf77fa4..61eb237 100644 --- a/hmac.go +++ b/hmac.go @@ -1,87 +1,67 @@ package starcrypto -import ( - "crypto/hmac" - "crypto/md5" - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" - "encoding/hex" - "golang.org/x/crypto/md4" - "golang.org/x/crypto/ripemd160" - "hash" -) - -func chmac(message, key []byte, f func() hash.Hash) []byte { - h := hmac.New(f, []byte(key)) - h.Write([]byte(message)) - return h.Sum(nil) -} - -func chmacStr(message, key []byte, f func() hash.Hash) string { - return hex.EncodeToString(chmac(message, key, f)) -} +import "b612.me/starcrypto/macx" func HmacMd4(message, key []byte) []byte { - return chmac(message, key, md4.New) + return macx.HmacMd4(message, key) } func HmacMd4Str(message, key []byte) string { - return chmacStr(message, key, md4.New) + return macx.HmacMd4Str(message, key) } func HmacMd5(message, key []byte) []byte { - return chmac(message, key, md5.New) + return macx.HmacMd5(message, key) } func HmacMd5Str(message, key []byte) string { - return chmacStr(message, key, md5.New) + return macx.HmacMd5Str(message, key) } func HmacSHA1(message, key []byte) []byte { - return chmac(message, key, sha1.New) + return macx.HmacSHA1(message, key) } func HmacSHA1Str(message, key []byte) string { - return chmacStr(message, key, sha1.New) + return macx.HmacSHA1Str(message, key) } func HmacSHA256(message, key []byte) []byte { - return chmac(message, key, sha256.New) + return macx.HmacSHA256(message, key) } func HmacSHA256Str(message, key []byte) string { - return chmacStr(message, key, sha256.New) + return macx.HmacSHA256Str(message, key) } func HmacSHA384(message, key []byte) []byte { - return chmac(message, key, sha512.New384) + return macx.HmacSHA384(message, key) } func HmacSHA384Str(message, key []byte) string { - return chmacStr(message, key, sha512.New384) + return macx.HmacSHA384Str(message, key) } func HmacSHA512(message, key []byte) []byte { - return chmac(message, key, sha512.New) + return macx.HmacSHA512(message, key) } func HmacSHA512Str(message, key []byte) string { - return chmacStr(message, key, sha512.New) + return macx.HmacSHA512Str(message, key) } func HmacSHA224(message, key []byte) []byte { - return chmac(message, key, sha256.New224) + return macx.HmacSHA224(message, key) } func HmacSHA224Str(message, key []byte) string { - return chmacStr(message, key, sha256.New224) + return macx.HmacSHA224Str(message, key) } func HmacRipeMd160(message, key []byte) []byte { - return chmac(message, key, ripemd160.New) + return macx.HmacRipeMd160(message, key) } func HmacRipeMd160Str(message, key []byte) string { - return chmacStr(message, key, ripemd160.New) + return macx.HmacRipeMd160Str(message, key) } diff --git a/legacy/vicque.go b/legacy/vicque.go new file mode 100644 index 0000000..bcfc97f --- /dev/null +++ b/legacy/vicque.go @@ -0,0 +1,201 @@ +package legacy + +import ( + "crypto/rand" + "io" + "os" +) + +func VicqueEncodeV1(srcdata []byte, key string) []byte { + var keys []int + var randCode [2]byte + _, _ = io.ReadFull(rand.Reader, randCode[:]) + + data := make([]byte, len(srcdata)) + copy(data, srcdata) + + randCode1 := int(randCode[0]) + randCode2 := int(randCode[1]) + keys = append(keys, len(key)+randCode1) + for _, v := range key { + keys = append(keys, int(byte(v))+randCode1-randCode2) + } + + lens := len(data) + lenkey := len(keys) + for k, v := range data { + if k == lens/2 { + break + } + nv := int(v) + t := 0 + if k%2 == 0 { + nv += keys[k%lenkey] + if nv > 255 { + nv -= 256 + } + t = int(data[lens-1-k]) + t += keys[k%lenkey] + if t > 255 { + t -= 256 + } + } else { + nv -= keys[k%lenkey] + if nv < 0 { + nv += 256 + } + t = int(data[lens-1-k]) + t -= keys[k%lenkey] + if t < 0 { + t += 256 + } + } + data[k] = byte(t) + data[lens-1-k] = byte(nv) + } + data = append(data, randCode[0], randCode[1]) + return data +} + +func VicqueDecodeV1(srcdata []byte, key string) []byte { + if len(srcdata) < 2 { + return nil + } + + data := make([]byte, len(srcdata)) + copy(data, srcdata) + lens := len(data) + + randCode1 := int(data[lens-2]) + randCode2 := int(data[lens-1]) + + keys := []int{len(key) + randCode1} + for _, v := range key { + keys = append(keys, int(byte(v))+randCode1-randCode2) + } + + lenkey := len(keys) + lens -= 2 + for k, v := range data { + if k == lens/2 { + break + } + nv := int(v) + t := 0 + if k%2 == 0 { + nv -= keys[k%lenkey] + if nv < 0 { + nv += 256 + } + t = int(data[lens-1-k]) + t -= keys[k%lenkey] + if t < 0 { + t += 256 + } + } else { + nv += keys[k%lenkey] + if nv > 255 { + nv -= 256 + } + t = int(data[lens-1-k]) + t += keys[k%lenkey] + if t > 255 { + t -= 256 + } + } + data[k] = byte(t) + data[lens-1-k] = byte(nv) + } + return data[:lens] +} + +func VicqueEncodeV1File(src, dst, pwd string, progress func(float64)) error { + fpsrc, err := os.Open(src) + if err != nil { + return err + } + defer fpsrc.Close() + + stat, err := fpsrc.Stat() + if err != nil { + return err + } + + fpdst, err := os.Create(dst) + if err != nil { + return err + } + defer fpdst.Close() + + var sum int64 + buf := make([]byte, 1024*1024) + for { + n, readErr := fpsrc.Read(buf) + if n > 0 { + sum += int64(n) + data := VicqueEncodeV1(buf[:n], pwd) + if _, err := fpdst.Write(data); err != nil { + return err + } + reportProgress(progress, sum, stat.Size()) + } + if readErr != nil { + if readErr == io.EOF { + break + } + return readErr + } + } + return nil +} + +func VicqueDecodeV1File(src, dst, pwd string, progress func(float64)) error { + fpsrc, err := os.Open(src) + if err != nil { + return err + } + defer fpsrc.Close() + + stat, err := fpsrc.Stat() + if err != nil { + return err + } + + fpdst, err := os.Create(dst) + if err != nil { + return err + } + defer fpdst.Close() + + var sum int64 + buf := make([]byte, 1024*1024+2) + for { + n, readErr := fpsrc.Read(buf) + if n > 0 { + sum += int64(n) + data := VicqueDecodeV1(buf[:n], pwd) + if _, err := fpdst.Write(data); err != nil { + return err + } + reportProgress(progress, sum, stat.Size()) + } + if readErr != nil { + if readErr == io.EOF { + break + } + return readErr + } + } + return nil +} + +func reportProgress(progress func(float64), current, total int64) { + if progress == nil { + return + } + if total <= 0 { + progress(100) + return + } + progress(float64(current) / float64(total) * 100) +} diff --git a/legacy/vicque_test.go b/legacy/vicque_test.go new file mode 100644 index 0000000..a9a3b79 --- /dev/null +++ b/legacy/vicque_test.go @@ -0,0 +1,41 @@ +package legacy + +import ( + "bytes" + "os" + "path/filepath" + "testing" +) + +func TestVicqueRoundTrip(t *testing.T) { + plain := []byte("legacy-vicque-roundtrip") + enc := VicqueEncodeV1(plain, "secret") + dec := VicqueDecodeV1(enc, "secret") + if !bytes.Equal(dec, plain) { + t.Fatalf("vicque roundtrip mismatch") + } +} + +func TestVicqueFileRoundTrip(t *testing.T) { + dir := t.TempDir() + src := filepath.Join(dir, "src.bin") + enc := filepath.Join(dir, "src.enc") + out := filepath.Join(dir, "src.out") + plain := []byte("legacy-vicque-file-roundtrip") + if err := os.WriteFile(src, plain, 0o644); err != nil { + t.Fatalf("WriteFile failed: %v", err) + } + if err := VicqueEncodeV1File(src, enc, "pwd", nil); err != nil { + t.Fatalf("VicqueEncodeV1File failed: %v", err) + } + if err := VicqueDecodeV1File(enc, out, "pwd", nil); err != nil { + t.Fatalf("VicqueDecodeV1File failed: %v", err) + } + got, err := os.ReadFile(out) + if err != nil { + t.Fatalf("ReadFile failed: %v", err) + } + if !bytes.Equal(got, plain) { + t.Fatalf("vicque file roundtrip mismatch") + } +} diff --git a/macx/hmac.go b/macx/hmac.go new file mode 100644 index 0000000..a286ea6 --- /dev/null +++ b/macx/hmac.go @@ -0,0 +1,88 @@ +package macx + +import ( + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "encoding/hex" + "hash" + + "golang.org/x/crypto/md4" + "golang.org/x/crypto/ripemd160" +) + +func chmac(message, key []byte, f func() hash.Hash) []byte { + h := hmac.New(f, key) + _, _ = h.Write(message) + return h.Sum(nil) +} + +func chmacStr(message, key []byte, f func() hash.Hash) string { + return hex.EncodeToString(chmac(message, key, f)) +} + +func HmacMd4(message, key []byte) []byte { + return chmac(message, key, md4.New) +} + +func HmacMd4Str(message, key []byte) string { + return chmacStr(message, key, md4.New) +} + +func HmacMd5(message, key []byte) []byte { + return chmac(message, key, md5.New) +} + +func HmacMd5Str(message, key []byte) string { + return chmacStr(message, key, md5.New) +} + +func HmacSHA1(message, key []byte) []byte { + return chmac(message, key, sha1.New) +} + +func HmacSHA1Str(message, key []byte) string { + return chmacStr(message, key, sha1.New) +} + +func HmacSHA256(message, key []byte) []byte { + return chmac(message, key, sha256.New) +} + +func HmacSHA256Str(message, key []byte) string { + return chmacStr(message, key, sha256.New) +} + +func HmacSHA384(message, key []byte) []byte { + return chmac(message, key, sha512.New384) +} + +func HmacSHA384Str(message, key []byte) string { + return chmacStr(message, key, sha512.New384) +} + +func HmacSHA512(message, key []byte) []byte { + return chmac(message, key, sha512.New) +} + +func HmacSHA512Str(message, key []byte) string { + return chmacStr(message, key, sha512.New) +} + +func HmacSHA224(message, key []byte) []byte { + return chmac(message, key, sha256.New224) +} + +func HmacSHA224Str(message, key []byte) string { + return chmacStr(message, key, sha256.New224) +} + +func HmacRipeMd160(message, key []byte) []byte { + return chmac(message, key, ripemd160.New) +} + +func HmacRipeMd160Str(message, key []byte) string { + return chmacStr(message, key, ripemd160.New) +} diff --git a/md5.go b/md5.go index d7da41a..712cedb 100644 --- a/md5.go +++ b/md5.go @@ -1,27 +1,19 @@ package starcrypto -import ( - "crypto/md5" - "golang.org/x/crypto/md4" -) +import "b612.me/starcrypto/hashx" -// MD5 输出MD5校验值 func Md5(bstr []byte) []byte { - md5sum := md5.New() - md5sum.Write(bstr) - return md5sum.Sum(nil) + return hashx.Md5(bstr) } func Md5Str(bstr []byte) string { - return String(Md5(bstr)) + return hashx.Md5Str(bstr) } func Md4(bstr []byte) []byte { - md4sum := md4.New() - md4sum.Write(bstr) - return md4sum.Sum(nil) + return hashx.Md4(bstr) } func Md4Str(bstr []byte) string { - return String(Md4(bstr)) + return hashx.Md4Str(bstr) } diff --git a/paddingx/fuzz_test.go b/paddingx/fuzz_test.go new file mode 100644 index 0000000..1a613c7 --- /dev/null +++ b/paddingx/fuzz_test.go @@ -0,0 +1,25 @@ +package paddingx + +import "testing" + +func FuzzPadUnpadRoundTrip(f *testing.F) { + f.Add([]byte("abc")) + f.Add([]byte{}) + modes := []string{PKCS7, ZERO, ANSIX923} + + f.Fuzz(func(t *testing.T, data []byte) { + for _, mode := range modes { + padded, err := Pad(data, 16, mode) + if err != nil { + t.Fatalf("Pad failed: %v", err) + } + out, err := Unpad(padded, 16, mode) + if err != nil { + t.Fatalf("Unpad failed: %v", err) + } + if mode != ZERO && string(out) != string(data) { + t.Fatalf("roundtrip mismatch for mode %s", mode) + } + } + }) +} diff --git a/paddingx/padding.go b/paddingx/padding.go new file mode 100644 index 0000000..6a0b2b7 --- /dev/null +++ b/paddingx/padding.go @@ -0,0 +1,128 @@ +package paddingx + +import ( + "bytes" + "errors" + "strings" +) + +const ( + PKCS5 = "PKCS5" + PKCS7 = "PKCS7" + ZERO = "ZERO" + ANSIX923 = "ANSIX923" +) + +func Pad(data []byte, blockSize int, mode string) ([]byte, error) { + if blockSize <= 0 { + return nil, errors.New("block size must be greater than zero") + } + switch normalizeMode(mode) { + case "", PKCS7: + return PKCS7Padding(data, blockSize), nil + case PKCS5: + // Compatibility mode: historically PKCS5 was used generically in this project. + return PKCS7Padding(data, blockSize), nil + case ZERO: + return zeroPadding(data, blockSize), nil + case ANSIX923: + return ansiX923Padding(data, blockSize), nil + default: + return nil, errors.New("padding type not supported") + } +} + +func Unpad(data []byte, blockSize int, mode string) ([]byte, error) { + if blockSize <= 0 { + return nil, errors.New("block size must be greater than zero") + } + switch normalizeMode(mode) { + case "", PKCS7: + return PKCS7Unpadding(data, blockSize) + case PKCS5: + // Compatibility mode: historically PKCS5 was used generically in this project. + return PKCS7Unpadding(data, blockSize) + case ZERO: + return zeroUnpadding(data) + case ANSIX923: + return ansiX923Unpadding(data, blockSize) + default: + return nil, errors.New("padding type not supported") + } +} + +func PKCS7Padding(data []byte, blockSize int) []byte { + padding := blockSize - len(data)%blockSize + padText := bytes.Repeat([]byte{byte(padding)}, padding) + return append(data, padText...) +} + +func PKCS7Unpadding(data []byte, blockSize int) ([]byte, error) { + if len(data) == 0 || len(data)%blockSize != 0 { + return nil, errors.New("invalid PKCS7 padding") + } + padding := int(data[len(data)-1]) + if padding <= 0 || padding > blockSize || padding > len(data) { + return nil, errors.New("invalid PKCS7 padding") + } + for i := len(data) - padding; i < len(data); i++ { + if int(data[i]) != padding { + return nil, errors.New("invalid PKCS7 padding") + } + } + return data[:len(data)-padding], nil +} + +func PKCS5Padding(data []byte) []byte { + return PKCS7Padding(data, 8) +} + +func PKCS5Unpadding(data []byte) ([]byte, error) { + return PKCS7Unpadding(data, 8) +} + +func zeroPadding(data []byte, blockSize int) []byte { + padding := blockSize - len(data)%blockSize + if padding == blockSize { + return data + } + return append(data, bytes.Repeat([]byte{0x00}, padding)...) +} + +func zeroUnpadding(data []byte) ([]byte, error) { + idx := len(data) + for idx > 0 && data[idx-1] == 0x00 { + idx-- + } + return data[:idx], nil +} + +func ansiX923Padding(data []byte, blockSize int) []byte { + padding := blockSize - len(data)%blockSize + if padding == 0 { + padding = blockSize + } + pad := make([]byte, padding) + pad[len(pad)-1] = byte(padding) + return append(data, pad...) +} + +func ansiX923Unpadding(data []byte, blockSize int) ([]byte, error) { + if len(data) == 0 || len(data)%blockSize != 0 { + return nil, errors.New("invalid ANSI X9.23 padding") + } + padding := int(data[len(data)-1]) + if padding <= 0 || padding > blockSize || padding > len(data) { + return nil, errors.New("invalid ANSI X9.23 padding") + } + for i := len(data) - padding; i < len(data)-1; i++ { + if data[i] != 0x00 { + return nil, errors.New("invalid ANSI X9.23 padding") + } + } + return data[:len(data)-padding], nil +} + +func normalizeMode(mode string) string { + return strings.ToUpper(strings.TrimSpace(mode)) +} diff --git a/paddingx/padding_test.go b/paddingx/padding_test.go new file mode 100644 index 0000000..b8b7550 --- /dev/null +++ b/paddingx/padding_test.go @@ -0,0 +1,94 @@ +package paddingx + +import ( + "bytes" + "testing" +) + +func TestPadAndUnpadPKCS7(t *testing.T) { + plain := []byte("hello-world") + padded, err := Pad(plain, 16, PKCS7) + if err != nil { + t.Fatalf("Pad PKCS7 failed: %v", err) + } + if len(padded)%16 != 0 { + t.Fatalf("padded length should be block aligned, got %d", len(padded)) + } + got, err := Unpad(padded, 16, PKCS7) + if err != nil { + t.Fatalf("Unpad PKCS7 failed: %v", err) + } + if !bytes.Equal(got, plain) { + t.Fatalf("roundtrip mismatch, got %x want %x", got, plain) + } +} + +func TestPadAndUnpadPKCS5Compatibility(t *testing.T) { + plain := []byte("DES-plaintext") + padded, err := Pad(plain, 8, PKCS5) + if err != nil { + t.Fatalf("Pad PKCS5 failed: %v", err) + } + got, err := Unpad(padded, 8, PKCS5) + if err != nil { + t.Fatalf("Unpad PKCS5 failed: %v", err) + } + if !bytes.Equal(got, plain) { + t.Fatalf("roundtrip mismatch, got %x want %x", got, plain) + } +} + +func TestPadAndUnpadZero(t *testing.T) { + plain := []byte("abc\x00\x00") + padded, err := Pad(plain, 8, ZERO) + if err != nil { + t.Fatalf("Pad ZERO failed: %v", err) + } + got, err := Unpad(padded, 8, ZERO) + if err != nil { + t.Fatalf("Unpad ZERO failed: %v", err) + } + if !bytes.Equal(got, []byte("abc")) { + t.Fatalf("zero unpadding mismatch, got %q", got) + } +} + +func TestPadAndUnpadANSIX923(t *testing.T) { + plain := []byte("ansi-x923") + padded, err := Pad(plain, 16, ANSIX923) + if err != nil { + t.Fatalf("Pad ANSIX923 failed: %v", err) + } + got, err := Unpad(padded, 16, ANSIX923) + if err != nil { + t.Fatalf("Unpad ANSIX923 failed: %v", err) + } + if !bytes.Equal(got, plain) { + t.Fatalf("roundtrip mismatch, got %x want %x", got, plain) + } +} + +func TestPadUnsupportedMode(t *testing.T) { + if _, err := Pad([]byte("x"), 8, "UNKNOWN"); err == nil { + t.Fatalf("expected error for unsupported mode") + } +} + +func TestUnpadInvalidPKCS7(t *testing.T) { + _, err := Unpad([]byte{1, 2, 3, 4}, 4, PKCS7) + if err == nil { + t.Fatalf("expected invalid PKCS7 padding error") + } +} + +func TestPKCS5Helpers(t *testing.T) { + plain := []byte("1234567") + padded := PKCS5Padding(plain) + got, err := PKCS5Unpadding(padded) + if err != nil { + t.Fatalf("PKCS5Unpadding failed: %v", err) + } + if !bytes.Equal(got, plain) { + t.Fatalf("PKCS5 helper mismatch, got %x want %x", got, plain) + } +} diff --git a/ripe.go b/ripe.go index 3f9d708..daf0bfc 100644 --- a/ripe.go +++ b/ripe.go @@ -1,15 +1,11 @@ package starcrypto -import ( - "golang.org/x/crypto/ripemd160" -) +import "b612.me/starcrypto/hashx" func RipeMd160(bstr []byte) []byte { - ripe := ripemd160.New() - ripe.Write(bstr) - return ripe.Sum(nil) + return hashx.RipeMd160(bstr) } func RipeMd160Str(bstr []byte) string { - return String(RipeMd160(bstr)) + return hashx.RipeMd160Str(bstr) } diff --git a/rsa.go b/rsa.go index 4b6e99e..19ab67e 100644 --- a/rsa.go +++ b/rsa.go @@ -1,280 +1,59 @@ package starcrypto import ( + "b612.me/starcrypto/asymm" "crypto" - "crypto/rand" "crypto/rsa" - "crypto/x509" - "encoding/pem" - "errors" - "fmt" - "golang.org/x/crypto/ssh" - "math/big" ) func GenerateRsaKey(bits int) (*rsa.PrivateKey, *rsa.PublicKey, error) { - private, err := rsa.GenerateKey(rand.Reader, bits) - if err != nil { - return nil, nil, err - } - return private, &private.PublicKey, nil - + return asymm.GenerateRsaKey(bits) } func EncodeRsaPrivateKey(private *rsa.PrivateKey, secret string) ([]byte, error) { - if secret == "" { - return pem.EncodeToMemory(&pem.Block{ - Bytes: x509.MarshalPKCS1PrivateKey(private), - Type: "RSA PRIVATE KEY", - }), nil - } - chiper := x509.PEMCipherAES256 - blk, err := x509.EncryptPEMBlock(rand.Reader, "RSA PRIVATE KEY", x509.MarshalPKCS1PrivateKey(private), []byte(secret), chiper) - if err != nil { - return nil, err - } - return pem.EncodeToMemory(blk), err + return asymm.EncodeRsaPrivateKey(private, secret) } func EncodeRsaPublicKey(public *rsa.PublicKey) ([]byte, error) { - publicBytes, err := x509.MarshalPKIXPublicKey(public) - if err != nil { - return nil, err - } - return pem.EncodeToMemory(&pem.Block{ - Bytes: publicBytes, - Type: "PUBLIC KEY", - }), nil + return asymm.EncodeRsaPublicKey(public) } func DecodeRsaPrivateKey(private []byte, password string) (*rsa.PrivateKey, error) { - var prikey *rsa.PrivateKey - var err error - var bytes []byte - blk, _ := pem.Decode(private) - if blk == nil { - return nil, errors.New("private key error!") - } - if password != "" { - tmp, err := x509.DecryptPEMBlock(blk, []byte(password)) - if err != nil { - return nil, err - } - bytes = tmp - } else { - bytes = blk.Bytes - } - prikey, err = x509.ParsePKCS1PrivateKey(bytes) - if err != nil { - tmp, err := x509.ParsePKCS8PrivateKey(bytes) - if err != nil { - return nil, err - } - prikey = tmp.(*rsa.PrivateKey) - } - return prikey, err + return asymm.DecodeRsaPrivateKey(private, password) } func DecodeRsaPublicKey(pubStr []byte) (*rsa.PublicKey, error) { - blk, _ := pem.Decode(pubStr) - if blk == nil { - return nil, errors.New("public key error") - } - pub, err := x509.ParsePKIXPublicKey(blk.Bytes) - if err != nil { - return nil, err - } - return pub.(*rsa.PublicKey), nil + return asymm.DecodeRsaPublicKey(pubStr) } func EncodeRsaSSHPublicKey(public *rsa.PublicKey) ([]byte, error) { - publicKey, err := ssh.NewPublicKey(public) - if err != nil { - return nil, err - } - return ssh.MarshalAuthorizedKey(publicKey), nil + return asymm.EncodeRsaSSHPublicKey(public) } func GenerateRsaSSHKeyPair(bits int, secret string) (string, string, error) { - pkey, pubkey, err := GenerateRsaKey(bits) - if err != nil { - return "", "", err - } - - pub, err := EncodeRsaSSHPublicKey(pubkey) - if err != nil { - return "", "", err - } - - priv, err := EncodeRsaPrivateKey(pkey, secret) - if err != nil { - return "", "", err - } - return string(priv), string(pub), nil + return asymm.GenerateRsaSSHKeyPair(bits, secret) } -// RSAEncrypt RSA公钥加密 func RSAEncrypt(pub *rsa.PublicKey, data []byte) ([]byte, error) { - return rsa.EncryptPKCS1v15(rand.Reader, pub, data) + return asymm.RSAEncrypt(pub, data) } -// RSADecrypt RSA私钥解密 func RSADecrypt(prikey *rsa.PrivateKey, data []byte) ([]byte, error) { - return rsa.DecryptPKCS1v15(rand.Reader, prikey, data) + return asymm.RSADecrypt(prikey, data) } -// RSASign RSA私钥签名加密 func RSASign(msg, priKey []byte, password string, hashType crypto.Hash) ([]byte, error) { - var prikey *rsa.PrivateKey - var err error - var bytes []byte - blk, _ := pem.Decode(priKey) - if blk == nil { - return []byte{}, errors.New("private key error!") - } - if password != "" { - tmp, err := x509.DecryptPEMBlock(blk, []byte(password)) - if err != nil { - return []byte{}, err - } - bytes = tmp - } else { - bytes = blk.Bytes - } - prikey, err = x509.ParsePKCS1PrivateKey(bytes) - if err != nil { - tmp, err := x509.ParsePKCS8PrivateKey(bytes) - if err != nil { - return []byte{}, err - } - prikey = tmp.(*rsa.PrivateKey) - } - hashMethod := hashType.New() - _, err = hashMethod.Write(msg) - if err != nil { - return nil, err - } - return rsa.SignPKCS1v15(rand.Reader, prikey, hashType, hashMethod.Sum(nil)) + return asymm.RSASign(msg, priKey, password, hashType) } -// RSAVerify RSA公钥签名验证 func RSAVerify(data, msg, pubKey []byte, hashType crypto.Hash) error { - blk, _ := pem.Decode(pubKey) - if blk == nil { - return errors.New("public key error!") - } - pubkey, err := x509.ParsePKIXPublicKey(blk.Bytes) - if err != nil { - return err - } - hashMethod := hashType.New() - _, err = hashMethod.Write(msg) - if err != nil { - return err - } - return rsa.VerifyPKCS1v15(pubkey.(*rsa.PublicKey), hashType, hashMethod.Sum(nil), data) -} - -// copy from crypt/rsa/pkcs1v5.go -var hashPrefixes = map[crypto.Hash][]byte{ - crypto.MD5: {0x30, 0x20, 0x30, 0x0c, 0x06, 0x08, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x02, 0x05, 0x05, 0x00, 0x04, 0x10}, - crypto.SHA1: {0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2b, 0x0e, 0x03, 0x02, 0x1a, 0x05, 0x00, 0x04, 0x14}, - crypto.SHA224: {0x30, 0x2d, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04, 0x05, 0x00, 0x04, 0x1c}, - crypto.SHA256: {0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0x04, 0x20}, - crypto.SHA384: {0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, 0x00, 0x04, 0x30}, - crypto.SHA512: {0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, 0x00, 0x04, 0x40}, - crypto.MD5SHA1: {}, // A special TLS case which doesn't use an ASN1 prefix. - crypto.RIPEMD160: {0x30, 0x20, 0x30, 0x08, 0x06, 0x06, 0x28, 0xcf, 0x06, 0x03, 0x00, 0x31, 0x04, 0x14}, -} - -// copy from crypt/rsa/pkcs1v5.go -func encrypt(c *big.Int, pub *rsa.PublicKey, m *big.Int) *big.Int { - e := big.NewInt(int64(pub.E)) - c.Exp(m, e, pub.N) - return c -} - -// copy from crypt/rsa/pkcs1v5.go -func pkcs1v15HashInfo(hash crypto.Hash, inLen int) (hashLen int, prefix []byte, err error) { - // Special case: crypto.Hash(0) is used to indicate that the data is - // signed directly. - if hash == 0 { - return inLen, nil, nil - } - - hashLen = hash.Size() - if inLen != hashLen { - return 0, nil, errors.New("crypto/rsa: input must be hashed message") - } - prefix, ok := hashPrefixes[hash] - if !ok { - return 0, nil, errors.New("crypto/rsa: unsupported hash function") - } - return -} - -// copy from crypt/rsa/pkcs1v5.go -func leftPad(input []byte, size int) (out []byte) { - n := len(input) - if n > size { - n = size - } - out = make([]byte, size) - copy(out[len(out)-n:], input) - return -} -func unLeftPad(input []byte) (out []byte) { - n := len(input) - t := 2 - for i := 2; i < n; i++ { - if input[i] == 0xff { - t = t + 1 - } else { - if input[i] == input[0] { - t = t + int(input[1]) - } - break - } - } - out = make([]byte, n-t) - copy(out, input[t:]) - return -} - -// copy&modified from crypt/rsa/pkcs1v5.go -func publicDecrypt(pub *rsa.PublicKey, hash crypto.Hash, hashed []byte, sig []byte) (out []byte, err error) { - hashLen, prefix, err := pkcs1v15HashInfo(hash, len(hashed)) - if err != nil { - return nil, err - } - - tLen := len(prefix) + hashLen - k := (pub.N.BitLen() + 7) / 8 - if k < tLen+11 { - return nil, fmt.Errorf("length illegal") - } - - c := new(big.Int).SetBytes(sig) - m := encrypt(new(big.Int), pub, c) - em := leftPad(m.Bytes(), k) - out = unLeftPad(em) - - err = nil - return + return asymm.RSAVerify(data, msg, pubKey, hashType) } func RSAEncryptByPrivkey(privt *rsa.PrivateKey, data []byte) ([]byte, error) { - signData, err := rsa.SignPKCS1v15(nil, privt, crypto.Hash(0), data) - if err != nil { - return nil, err - } - return signData, nil + return asymm.RSAEncryptByPrivkey(privt, data) } func RSADecryptByPubkey(pub *rsa.PublicKey, data []byte) ([]byte, error) { - decData, err := publicDecrypt(pub, crypto.Hash(0), nil, data) - if err != nil { - return nil, err - } - return decData, nil + return asymm.RSADecryptByPubkey(pub, data) } diff --git a/sha.go b/sha.go index 408ede9..c8edba5 100644 --- a/sha.go +++ b/sha.go @@ -1,62 +1,43 @@ package starcrypto -import ( - "crypto/sha1" - "crypto/sha256" - "crypto/sha512" -) +import "b612.me/starcrypto/hashx" -// SHA512 输出SHA512校验值 func Sha512(bstr []byte) []byte { - shasum := sha512.New() - shasum.Write(bstr) - return shasum.Sum(nil) + return hashx.Sha512(bstr) } func Sha512Str(bstr []byte) string { - return String(Sha512(bstr)) + return hashx.Sha512Str(bstr) } -// SHA384 输出SHA384校验值 func Sha384(bstr []byte) []byte { - shasum := sha512.New384() - shasum.Write(bstr) - return shasum.Sum(nil) + return hashx.Sha384(bstr) } func Sha384Str(bstr []byte) string { - return String(Sha384(bstr)) + return hashx.Sha384Str(bstr) } -// SHA256 输出SHA256校验值 func Sha256(bstr []byte) []byte { - shasum := sha256.New() - shasum.Write(bstr) - return shasum.Sum(nil) + return hashx.Sha256(bstr) } func Sha256Str(bstr []byte) string { - return String(Sha256(bstr)) + return hashx.Sha256Str(bstr) } -// SHA224 输出SHA224校验值 func Sha224(bstr []byte) []byte { - shasum := sha256.New224() - shasum.Write(bstr) - return shasum.Sum(nil) + return hashx.Sha224(bstr) } func Sha224Str(bstr []byte) string { - return String(Sha224(bstr)) + return hashx.Sha224Str(bstr) } -// SHA1 输出SHA1校验值 func Sha1(bstr []byte) []byte { - shasum := sha1.New() - shasum.Write(bstr) - return shasum.Sum(nil) + return hashx.Sha1(bstr) } func Sha1Str(bstr []byte) string { - return String(Sha512(bstr)) + return hashx.Sha1Str(bstr) } diff --git a/sm2.go b/sm2.go new file mode 100644 index 0000000..e138270 --- /dev/null +++ b/sm2.go @@ -0,0 +1,57 @@ +package starcrypto + +import ( + "b612.me/starcrypto/asymm" + "crypto" + "crypto/ecdsa" + + "github.com/emmansun/gmsm/sm2" +) + +func GenerateSM2Key() (*sm2.PrivateKey, *ecdsa.PublicKey, error) { + return asymm.GenerateSM2Key() +} + +func EncodeSM2PrivateKey(private *sm2.PrivateKey, secret string) ([]byte, error) { + return asymm.EncodeSM2PrivateKey(private, secret) +} + +func EncodeSM2PublicKey(public *ecdsa.PublicKey) ([]byte, error) { + return asymm.EncodeSM2PublicKey(public) +} + +func DecodeSM2PrivateKey(private []byte, password string) (*sm2.PrivateKey, error) { + return asymm.DecodeSM2PrivateKey(private, password) +} + +func DecodeSM2PublicKey(pubStr []byte) (*ecdsa.PublicKey, error) { + return asymm.DecodeSM2PublicKey(pubStr) +} + +func SM2EncryptASN1(pub *ecdsa.PublicKey, data []byte) ([]byte, error) { + return asymm.SM2EncryptASN1(pub, data) +} + +func SM2DecryptASN1(priv *sm2.PrivateKey, data []byte) ([]byte, error) { + return asymm.SM2DecryptASN1(priv, data) +} + +func SM2Sign(priv *sm2.PrivateKey, msg, uid []byte) ([]byte, error) { + return asymm.SM2Sign(priv, msg, uid) +} + +func SM2Verify(pub *ecdsa.PublicKey, msg, sig, uid []byte) bool { + return asymm.SM2Verify(pub, msg, sig, uid) +} + +func SM2SignByPEM(msg, priKey []byte, password string, uid []byte) ([]byte, error) { + return asymm.SM2SignByPEM(msg, priKey, password, uid) +} + +func SM2VerifyByPEM(sig, msg, pubKey []byte, uid []byte) (bool, error) { + return asymm.SM2VerifyByPEM(sig, msg, pubKey, uid) +} + +func IsSM2PublicKey(public crypto.PublicKey) bool { + return asymm.IsSM2PublicKey(public) +} diff --git a/sm3.go b/sm3.go index b70841c..4fa63b8 100644 --- a/sm3.go +++ b/sm3.go @@ -1,15 +1,11 @@ package starcrypto -import ( - "b612.me/starcrypto/sm3" -) +import "b612.me/starcrypto/hashx" func SM3(bstr []byte) []byte { - sm3sum := sm3.New() - sm3sum.Write(bstr) - return sm3sum.Sum(nil) + return hashx.SM3(bstr) } func SM3Str(bstr []byte) string { - return String(SM3(bstr)) + return hashx.SM3Str(bstr) } diff --git a/sm3/sm3.go b/sm3/sm3.go index 7245dd6..7e5788c 100644 --- a/sm3/sm3.go +++ b/sm3/sm3.go @@ -1,259 +1,46 @@ -/* -Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - package sm3 import ( - "encoding/binary" "hash" + + gmsm3 "github.com/emmansun/gmsm/sm3" ) type SM3 struct { - digest [8]uint32 // digest represents the partial evaluation of V - length uint64 // length of the message - unhandleMsg []byte // uint8 // + h hash.Hash } -func (sm3 *SM3) ff0(x, y, z uint32) uint32 { return x ^ y ^ z } - -func (sm3 *SM3) ff1(x, y, z uint32) uint32 { return (x & y) | (x & z) | (y & z) } - -func (sm3 *SM3) gg0(x, y, z uint32) uint32 { return x ^ y ^ z } - -func (sm3 *SM3) gg1(x, y, z uint32) uint32 { return (x & y) | (^x & z) } - -func (sm3 *SM3) p0(x uint32) uint32 { return x ^ sm3.leftRotate(x, 9) ^ sm3.leftRotate(x, 17) } - -func (sm3 *SM3) p1(x uint32) uint32 { return x ^ sm3.leftRotate(x, 15) ^ sm3.leftRotate(x, 23) } - -func (sm3 *SM3) leftRotate(x uint32, i uint32) uint32 { return x<<(i%32) | x>>(32-i%32) } - -func (sm3 *SM3) pad() []byte { - msg := sm3.unhandleMsg - msg = append(msg, 0x80) // Append '1' - blockSize := 64 // Append until the resulting message length (in bits) is congruent to 448 (mod 512) - for len(msg)%blockSize != 56 { - msg = append(msg, 0x00) - } - // append message length - msg = append(msg, uint8(sm3.length>>56&0xff)) - msg = append(msg, uint8(sm3.length>>48&0xff)) - msg = append(msg, uint8(sm3.length>>40&0xff)) - msg = append(msg, uint8(sm3.length>>32&0xff)) - msg = append(msg, uint8(sm3.length>>24&0xff)) - msg = append(msg, uint8(sm3.length>>16&0xff)) - msg = append(msg, uint8(sm3.length>>8&0xff)) - msg = append(msg, uint8(sm3.length>>0&0xff)) - - if len(msg)%64 != 0 { - panic("------SM3 Pad: error msgLen =") - } - return msg -} - -func (sm3 *SM3) update(msg []byte) { - var w [68]uint32 - var w1 [64]uint32 - - a, b, c, d, e, f, g, h := sm3.digest[0], sm3.digest[1], sm3.digest[2], sm3.digest[3], sm3.digest[4], sm3.digest[5], sm3.digest[6], sm3.digest[7] - for len(msg) >= 64 { - for i := 0; i < 16; i++ { - w[i] = binary.BigEndian.Uint32(msg[4*i : 4*(i+1)]) - } - for i := 16; i < 68; i++ { - w[i] = sm3.p1(w[i-16]^w[i-9]^sm3.leftRotate(w[i-3], 15)) ^ sm3.leftRotate(w[i-13], 7) ^ w[i-6] - } - for i := 0; i < 64; i++ { - w1[i] = w[i] ^ w[i+4] - } - A, B, C, D, E, F, G, H := a, b, c, d, e, f, g, h - for i := 0; i < 16; i++ { - SS1 := sm3.leftRotate(sm3.leftRotate(A, 12)+E+sm3.leftRotate(0x79cc4519, uint32(i)), 7) - SS2 := SS1 ^ sm3.leftRotate(A, 12) - TT1 := sm3.ff0(A, B, C) + D + SS2 + w1[i] - TT2 := sm3.gg0(E, F, G) + H + SS1 + w[i] - D = C - C = sm3.leftRotate(B, 9) - B = A - A = TT1 - H = G - G = sm3.leftRotate(F, 19) - F = E - E = sm3.p0(TT2) - } - for i := 16; i < 64; i++ { - SS1 := sm3.leftRotate(sm3.leftRotate(A, 12)+E+sm3.leftRotate(0x7a879d8a, uint32(i)), 7) - SS2 := SS1 ^ sm3.leftRotate(A, 12) - TT1 := sm3.ff1(A, B, C) + D + SS2 + w1[i] - TT2 := sm3.gg1(E, F, G) + H + SS1 + w[i] - D = C - C = sm3.leftRotate(B, 9) - B = A - A = TT1 - H = G - G = sm3.leftRotate(F, 19) - F = E - E = sm3.p0(TT2) - } - a ^= A - b ^= B - c ^= C - d ^= D - e ^= E - f ^= F - g ^= G - h ^= H - msg = msg[64:] - } - sm3.digest[0], sm3.digest[1], sm3.digest[2], sm3.digest[3], sm3.digest[4], sm3.digest[5], sm3.digest[6], sm3.digest[7] = a, b, c, d, e, f, g, h -} -func (sm3 *SM3) update2(msg []byte) [8]uint32 { - var w [68]uint32 - var w1 [64]uint32 - - a, b, c, d, e, f, g, h := sm3.digest[0], sm3.digest[1], sm3.digest[2], sm3.digest[3], sm3.digest[4], sm3.digest[5], sm3.digest[6], sm3.digest[7] - for len(msg) >= 64 { - for i := 0; i < 16; i++ { - w[i] = binary.BigEndian.Uint32(msg[4*i : 4*(i+1)]) - } - for i := 16; i < 68; i++ { - w[i] = sm3.p1(w[i-16]^w[i-9]^sm3.leftRotate(w[i-3], 15)) ^ sm3.leftRotate(w[i-13], 7) ^ w[i-6] - } - for i := 0; i < 64; i++ { - w1[i] = w[i] ^ w[i+4] - } - A, B, C, D, E, F, G, H := a, b, c, d, e, f, g, h - for i := 0; i < 16; i++ { - SS1 := sm3.leftRotate(sm3.leftRotate(A, 12)+E+sm3.leftRotate(0x79cc4519, uint32(i)), 7) - SS2 := SS1 ^ sm3.leftRotate(A, 12) - TT1 := sm3.ff0(A, B, C) + D + SS2 + w1[i] - TT2 := sm3.gg0(E, F, G) + H + SS1 + w[i] - D = C - C = sm3.leftRotate(B, 9) - B = A - A = TT1 - H = G - G = sm3.leftRotate(F, 19) - F = E - E = sm3.p0(TT2) - } - for i := 16; i < 64; i++ { - SS1 := sm3.leftRotate(sm3.leftRotate(A, 12)+E+sm3.leftRotate(0x7a879d8a, uint32(i)), 7) - SS2 := SS1 ^ sm3.leftRotate(A, 12) - TT1 := sm3.ff1(A, B, C) + D + SS2 + w1[i] - TT2 := sm3.gg1(E, F, G) + H + SS1 + w[i] - D = C - C = sm3.leftRotate(B, 9) - B = A - A = TT1 - H = G - G = sm3.leftRotate(F, 19) - F = E - E = sm3.p0(TT2) - } - a ^= A - b ^= B - c ^= C - d ^= D - e ^= E - f ^= F - g ^= G - h ^= H - msg = msg[64:] - } - var digest [8]uint32 - digest[0], digest[1], digest[2], digest[3], digest[4], digest[5], digest[6], digest[7] = a, b, c, d, e, f, g, h - return digest -} - -// 创建哈希计算实例 func New() hash.Hash { - var sm3 SM3 - - sm3.Reset() - return &sm3 + s := &SM3{} + s.Reset() + return s } -// BlockSize returns the hash's underlying block size. -// The Write method must be able to accept any amount -// of data, but it may operate more efficiently if all writes -// are a multiple of the block size. -func (sm3 *SM3) BlockSize() int { return 64 } +func (sm3 *SM3) BlockSize() int { return gmsm3.BlockSize } -// Size returns the number of bytes Sum will return. -func (sm3 *SM3) Size() int { return 32 } +func (sm3 *SM3) Size() int { return gmsm3.Size } -// Reset clears the internal state by zeroing bytes in the state buffer. -// This can be skipped for a newly-created hash state; the default zero-allocated state is correct. func (sm3 *SM3) Reset() { - // Reset digest - sm3.digest[0] = 0x7380166f - sm3.digest[1] = 0x4914b2b9 - sm3.digest[2] = 0x172442d7 - sm3.digest[3] = 0xda8a0600 - sm3.digest[4] = 0xa96f30bc - sm3.digest[5] = 0x163138aa - sm3.digest[6] = 0xe38dee4d - sm3.digest[7] = 0xb0fb0e4e - - sm3.length = 0 // Reset numberic states - sm3.unhandleMsg = []byte{} + sm3.h = gmsm3.New() } -// Write (via the embedded io.Writer interface) adds more data to the running hash. -// It never returns an error. func (sm3 *SM3) Write(p []byte) (int, error) { - toWrite := len(p) - sm3.length += uint64(len(p) * 8) - msg := append(sm3.unhandleMsg, p...) - nblocks := len(msg) / sm3.BlockSize() - sm3.update(msg) - // Update unhandleMsg - sm3.unhandleMsg = msg[nblocks*sm3.BlockSize():] - - return toWrite, nil + if sm3.h == nil { + sm3.Reset() + } + return sm3.h.Write(p) } -// 返回SM3哈希算法摘要值 -// Sum appends the current hash to b and returns the resulting slice. -// It does not change the underlying hash state. func (sm3 *SM3) Sum(in []byte) []byte { - _, _ = sm3.Write(in) - msg := sm3.pad() - //Finalize - digest := sm3.update2(msg) - - // save hash to in - needed := sm3.Size() - if cap(in)-len(in) < needed { - newIn := make([]byte, len(in), len(in)+needed) - copy(newIn, in) - in = newIn + if sm3.h == nil { + sm3.Reset() } - out := in[len(in) : len(in)+needed] - for i := 0; i < 8; i++ { - binary.BigEndian.PutUint32(out[i*4:], digest[i]) - } - return out - + return sm3.h.Sum(in) } func Sm3Sum(data []byte) []byte { - var sm3 SM3 - - sm3.Reset() - _, _ = sm3.Write(data) - return sm3.Sum(nil) + sum := gmsm3.Sum(data) + out := make([]byte, len(sum)) + copy(out, sum[:]) + return out } diff --git a/sm3/sm3_test.go b/sm3/sm3_test.go index 16e0ee0..3d7453e 100644 --- a/sm3/sm3_test.go +++ b/sm3/sm3_test.go @@ -1,64 +1,36 @@ -/* -Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - package sm3 import ( - "fmt" - "io/ioutil" - "os" + "encoding/hex" "testing" ) -func byteToString(b []byte) string { - ret := "" - for i := 0; i < len(b); i++ { - ret += fmt.Sprintf("%02x", b[i]) - } - fmt.Println("ret = ", ret) - return ret -} -func TestSm3(t *testing.T) { - msg := []byte("test") - err := ioutil.WriteFile("ifile", msg, os.FileMode(0644)) // 生成测试文件 - if err != nil { - t.Fatal(err) - } - msg, err = ioutil.ReadFile("ifile") - if err != nil { - t.Fatal(err) - } - hw := New() - hw.Write(msg) - hash := hw.Sum(nil) - fmt.Println(hash) - fmt.Printf("hash = %d\n", len(hash)) - fmt.Printf("%s\n", byteToString(hash)) - hash1 := Sm3Sum(msg) - fmt.Println(hash1) - fmt.Printf("%s\n", byteToString(hash1)) - -} - -func BenchmarkSm3(t *testing.B) { - t.ReportAllocs() - msg := []byte("test") - hw := New() - for i := 0; i < t.N; i++ { - - hw.Sum(nil) - Sm3Sum(msg) +func TestSm3KnownVector(t *testing.T) { + got := Sm3Sum([]byte("abc")) + const want = "66c7f0f462eeedd9d1f2d46bdc10e4e24167c4875cf2f7a2297da02b8f4ba8e0" + if hex.EncodeToString(got) != want { + t.Fatalf("Sm3Sum mismatch, got %s want %s", hex.EncodeToString(got), want) + } +} + +func TestHashSumDoesNotMutateState(t *testing.T) { + h := New() + if _, err := h.Write([]byte("ab")); err != nil { + t.Fatalf("Write failed: %v", err) + } + a := h.Sum(nil) + if _, err := h.Write([]byte("c")); err != nil { + t.Fatalf("Write failed: %v", err) + } + b := h.Sum(nil) + if hex.EncodeToString(a) == hex.EncodeToString(b) { + t.Fatalf("hash state should evolve after further writes") + } +} + +func BenchmarkSm3Sum(b *testing.B) { + msg := []byte("benchmark") + for i := 0; i < b.N; i++ { + _ = Sm3Sum(msg) } } diff --git a/sm4.go b/sm4.go new file mode 100644 index 0000000..df7fb04 --- /dev/null +++ b/sm4.go @@ -0,0 +1,143 @@ +package starcrypto + +import ( + "io" + + "b612.me/starcrypto/symm" +) + +func EncryptSM4(data, key, iv []byte, mode, paddingType string) ([]byte, error) { + return symm.EncryptSM4(data, key, iv, mode, paddingType) +} + +func DecryptSM4(src, key, iv []byte, mode, paddingType string) ([]byte, error) { + return symm.DecryptSM4(src, key, iv, mode, paddingType) +} + +func EncryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { + return symm.EncryptSM4Stream(dst, src, key, iv, mode, paddingType) +} + +func DecryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { + return symm.DecryptSM4Stream(dst, src, key, iv, mode, paddingType) +} + +func EncryptSM4WithOptions(data, key []byte, opts *CipherOptions) ([]byte, error) { + return symm.EncryptSM4WithOptions(data, key, opts) +} + +func DecryptSM4WithOptions(src, key []byte, opts *CipherOptions) ([]byte, error) { + return symm.DecryptSM4WithOptions(src, key, opts) +} + +func EncryptSM4StreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts *CipherOptions) error { + return symm.EncryptSM4StreamWithOptions(dst, src, key, opts) +} + +func DecryptSM4StreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts *CipherOptions) error { + return symm.DecryptSM4StreamWithOptions(dst, src, key, opts) +} + +func EncryptSM4GCM(plain, key, nonce, aad []byte) ([]byte, error) { + return symm.EncryptSM4GCM(plain, key, nonce, aad) +} + +func DecryptSM4GCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { + return symm.DecryptSM4GCM(ciphertext, key, nonce, aad) +} + +func EncryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + return symm.EncryptSM4GCMStream(dst, src, key, nonce, aad) +} + +func DecryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + return symm.DecryptSM4GCMStream(dst, src, key, nonce, aad) +} + +func EncryptSM4CFB(origData, key []byte) ([]byte, error) { + return symm.EncryptSM4CFB(origData, key) +} + +func DecryptSM4CFB(encrypted, key []byte) ([]byte, error) { + return symm.DecryptSM4CFB(encrypted, key) +} + +func EncryptSM4CFBNoBlock(origData, key, iv []byte) ([]byte, error) { + return symm.EncryptSM4CFBNoBlock(origData, key, iv) +} + +func DecryptSM4CFBNoBlock(encrypted, key, iv []byte) ([]byte, error) { + return symm.DecryptSM4CFBNoBlock(encrypted, key, iv) +} + +func EncryptSM4ECB(data, key []byte, paddingType string) ([]byte, error) { + return symm.EncryptSM4ECB(data, key, paddingType) +} + +func DecryptSM4ECB(src, key []byte, paddingType string) ([]byte, error) { + return symm.DecryptSM4ECB(src, key, paddingType) +} + +func EncryptSM4CBC(data, key, iv []byte, paddingType string) ([]byte, error) { + return symm.EncryptSM4CBC(data, key, iv, paddingType) +} + +func DecryptSM4CBC(src, key, iv []byte, paddingType string) ([]byte, error) { + return symm.DecryptSM4CBC(src, key, iv, paddingType) +} + +func EncryptSM4OFB(data, key, iv []byte) ([]byte, error) { + return symm.EncryptSM4OFB(data, key, iv) +} + +func DecryptSM4OFB(src, key, iv []byte) ([]byte, error) { + return symm.DecryptSM4OFB(src, key, iv) +} + +func EncryptSM4CTR(data, key, iv []byte) ([]byte, error) { + return symm.EncryptSM4CTR(data, key, iv) +} + +func DecryptSM4CTR(src, key, iv []byte) ([]byte, error) { + return symm.DecryptSM4CTR(src, key, iv) +} + +func EncryptSM4ECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error { + return symm.EncryptSM4ECBStream(dst, src, key, paddingType) +} + +func DecryptSM4ECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error { + return symm.DecryptSM4ECBStream(dst, src, key, paddingType) +} + +func EncryptSM4CFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.EncryptSM4CFBStream(dst, src, key, iv) +} + +func DecryptSM4CFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.DecryptSM4CFBStream(dst, src, key, iv) +} + +func EncryptSM4CBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + return symm.EncryptSM4CBCStream(dst, src, key, iv, paddingType) +} + +func DecryptSM4CBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + return symm.DecryptSM4CBCStream(dst, src, key, iv, paddingType) +} + +func EncryptSM4OFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.EncryptSM4OFBStream(dst, src, key, iv) +} + +func DecryptSM4OFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.DecryptSM4OFBStream(dst, src, key, iv) +} + +func EncryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.EncryptSM4CTRStream(dst, src, key, iv) +} + +func DecryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return symm.DecryptSM4CTRStream(dst, src, key, iv) +} diff --git a/sm9.go b/sm9.go new file mode 100644 index 0000000..df5009f --- /dev/null +++ b/sm9.go @@ -0,0 +1,108 @@ +package starcrypto + +import ( + "b612.me/starcrypto/asymm" + + gmsm9 "github.com/emmansun/gmsm/sm9" +) + +const ( + SM9SignHID = asymm.SM9SignHID + SM9EncryptHID = asymm.SM9EncryptHID +) + +func GenerateSM9SignMasterKey() (*gmsm9.SignMasterPrivateKey, *gmsm9.SignMasterPublicKey, error) { + return asymm.GenerateSM9SignMasterKey() +} + +func GenerateSM9EncryptMasterKey() (*gmsm9.EncryptMasterPrivateKey, *gmsm9.EncryptMasterPublicKey, error) { + return asymm.GenerateSM9EncryptMasterKey() +} + +func GenerateSM9SignUserKey(master *gmsm9.SignMasterPrivateKey, uid []byte, hid byte) (*gmsm9.SignPrivateKey, error) { + return asymm.GenerateSM9SignUserKey(master, uid, hid) +} + +func GenerateSM9EncryptUserKey(master *gmsm9.EncryptMasterPrivateKey, uid []byte, hid byte) (*gmsm9.EncryptPrivateKey, error) { + return asymm.GenerateSM9EncryptUserKey(master, uid, hid) +} + +func EncodeSM9SignMasterPrivateKey(key *gmsm9.SignMasterPrivateKey) ([]byte, error) { + return asymm.EncodeSM9SignMasterPrivateKey(key) +} + +func DecodeSM9SignMasterPrivateKey(data []byte) (*gmsm9.SignMasterPrivateKey, error) { + return asymm.DecodeSM9SignMasterPrivateKey(data) +} + +func EncodeSM9SignMasterPublicKey(key *gmsm9.SignMasterPublicKey) ([]byte, error) { + return asymm.EncodeSM9SignMasterPublicKey(key) +} + +func DecodeSM9SignMasterPublicKey(data []byte) (*gmsm9.SignMasterPublicKey, error) { + return asymm.DecodeSM9SignMasterPublicKey(data) +} + +func EncodeSM9SignPrivateKey(key *gmsm9.SignPrivateKey) ([]byte, error) { + return asymm.EncodeSM9SignPrivateKey(key) +} + +func DecodeSM9SignPrivateKey(data []byte) (*gmsm9.SignPrivateKey, error) { + return asymm.DecodeSM9SignPrivateKey(data) +} + +func EncodeSM9EncryptMasterPrivateKey(key *gmsm9.EncryptMasterPrivateKey) ([]byte, error) { + return asymm.EncodeSM9EncryptMasterPrivateKey(key) +} + +func DecodeSM9EncryptMasterPrivateKey(data []byte) (*gmsm9.EncryptMasterPrivateKey, error) { + return asymm.DecodeSM9EncryptMasterPrivateKey(data) +} + +func EncodeSM9EncryptMasterPublicKey(key *gmsm9.EncryptMasterPublicKey) ([]byte, error) { + return asymm.EncodeSM9EncryptMasterPublicKey(key) +} + +func DecodeSM9EncryptMasterPublicKey(data []byte) (*gmsm9.EncryptMasterPublicKey, error) { + return asymm.DecodeSM9EncryptMasterPublicKey(data) +} + +func EncodeSM9EncryptPrivateKey(key *gmsm9.EncryptPrivateKey) ([]byte, error) { + return asymm.EncodeSM9EncryptPrivateKey(key) +} + +func DecodeSM9EncryptPrivateKey(data []byte) (*gmsm9.EncryptPrivateKey, error) { + return asymm.DecodeSM9EncryptPrivateKey(data) +} + +func SM9SignHashASN1(priv *gmsm9.SignPrivateKey, hash []byte) ([]byte, error) { + return asymm.SM9SignHashASN1(priv, hash) +} + +func SM9SignASN1(priv *gmsm9.SignPrivateKey, message []byte) ([]byte, error) { + return asymm.SM9SignASN1(priv, message) +} + +func SM9VerifyHashASN1(pub *gmsm9.SignMasterPublicKey, uid []byte, hid byte, hash, sig []byte) bool { + return asymm.SM9VerifyHashASN1(pub, uid, hid, hash, sig) +} + +func SM9VerifyASN1(pub *gmsm9.SignMasterPublicKey, uid []byte, hid byte, message, sig []byte) bool { + return asymm.SM9VerifyASN1(pub, uid, hid, message, sig) +} + +func SM9Encrypt(pub *gmsm9.EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte) ([]byte, error) { + return asymm.SM9Encrypt(pub, uid, hid, plaintext) +} + +func SM9Decrypt(priv *gmsm9.EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error) { + return asymm.SM9Decrypt(priv, uid, ciphertext) +} + +func SM9EncryptASN1(pub *gmsm9.EncryptMasterPublicKey, uid []byte, hid byte, plaintext []byte) ([]byte, error) { + return asymm.SM9EncryptASN1(pub, uid, hid, plaintext) +} + +func SM9DecryptASN1(priv *gmsm9.EncryptPrivateKey, uid, ciphertext []byte) ([]byte, error) { + return asymm.SM9DecryptASN1(priv, uid, ciphertext) +} diff --git a/symm/aes.go b/symm/aes.go new file mode 100644 index 0000000..1c4518b --- /dev/null +++ b/symm/aes.go @@ -0,0 +1,331 @@ +package symm + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "errors" + "io" + + "b612.me/starcrypto/paddingx" +) + +const ( + PKCS5PADDING = paddingx.PKCS5 + PKCS7PADDING = paddingx.PKCS7 + ZEROPADDING = paddingx.ZERO + ANSIX923PADDING = paddingx.ANSIX923 +) + +var ErrInvalidGCMNonceLength = errors.New("gcm nonce length must be 12 bytes") + +func EncryptAes(data, key, iv []byte, mode, paddingType string) ([]byte, error) { + normalizedMode := normalizeCipherMode(mode) + if normalizedMode == MODEGCM { + return EncryptAesGCM(data, key, iv, nil) + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return encryptWithBlockMode(block, data, iv, normalizedMode, paddingType, PKCS7PADDING) +} + +func DecryptAes(src, key, iv []byte, mode, paddingType string) ([]byte, error) { + normalizedMode := normalizeCipherMode(mode) + if normalizedMode == MODEGCM { + return DecryptAesGCM(src, key, iv, nil) + } + + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + return decryptWithBlockMode(block, src, iv, normalizedMode, paddingType, PKCS7PADDING) +} + +func EncryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { + normalizedMode := normalizeCipherMode(mode) + if normalizedMode == MODEGCM { + return EncryptAesGCMStream(dst, src, key, iv, nil) + } + + block, err := aes.NewCipher(key) + if err != nil { + return err + } + return encryptWithBlockModeStream(block, dst, src, iv, normalizedMode, paddingType, PKCS7PADDING) +} + +func DecryptAesStream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { + normalizedMode := normalizeCipherMode(mode) + if normalizedMode == MODEGCM { + return DecryptAesGCMStream(dst, src, key, iv, nil) + } + + block, err := aes.NewCipher(key) + if err != nil { + return err + } + return decryptWithBlockModeStream(block, dst, src, iv, normalizedMode, paddingType, PKCS7PADDING) +} + +func EncryptAesWithOptions(data, key []byte, opts *CipherOptions) ([]byte, error) { + cfg := normalizeCipherOptions(opts) + mode := normalizeCipherMode(cfg.Mode) + if mode == "" { + mode = MODEGCM + } + if mode == MODEGCM { + return EncryptAesGCM(data, key, nonceFromOptions(cfg), cfg.AAD) + } + return EncryptAes(data, key, cfg.IV, mode, cfg.Padding) +} + +func DecryptAesWithOptions(src, key []byte, opts *CipherOptions) ([]byte, error) { + cfg := normalizeCipherOptions(opts) + mode := normalizeCipherMode(cfg.Mode) + if mode == "" { + mode = MODEGCM + } + if mode == MODEGCM { + return DecryptAesGCM(src, key, nonceFromOptions(cfg), cfg.AAD) + } + return DecryptAes(src, key, cfg.IV, mode, cfg.Padding) +} + +func EncryptAesStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts *CipherOptions) error { + cfg := normalizeCipherOptions(opts) + mode := normalizeCipherMode(cfg.Mode) + if mode == "" { + mode = MODEGCM + } + if mode == MODEGCM { + return EncryptAesGCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) + } + return EncryptAesStream(dst, src, key, cfg.IV, mode, cfg.Padding) +} + +func DecryptAesStreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts *CipherOptions) error { + cfg := normalizeCipherOptions(opts) + mode := normalizeCipherMode(cfg.Mode) + if mode == "" { + mode = MODEGCM + } + if mode == MODEGCM { + return DecryptAesGCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) + } + return DecryptAesStream(dst, src, key, cfg.IV, mode, cfg.Padding) +} + +func EncryptAesGCM(plain, key, nonce, aad []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + if len(nonce) != gcm.NonceSize() { + return nil, ErrInvalidGCMNonceLength + } + return gcm.Seal(nil, nonce, plain, aad), nil +} + +func DecryptAesGCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + if len(nonce) != gcm.NonceSize() { + return nil, ErrInvalidGCMNonceLength + } + return gcm.Open(nil, nonce, ciphertext, aad) +} + +func EncryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + plain, err := io.ReadAll(src) + if err != nil { + return err + } + out, err := EncryptAesGCM(plain, key, nonce, aad) + if err != nil { + return err + } + _, err = dst.Write(out) + return err +} + +func DecryptAesGCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + enc, err := io.ReadAll(src) + if err != nil { + return err + } + out, err := DecryptAesGCM(enc, key, nonce, aad) + if err != nil { + return err + } + _, err = dst.Write(out) + return err +} + +func EncryptAesECB(data, key []byte, paddingType string) ([]byte, error) { + return EncryptAes(data, key, nil, MODEECB, paddingType) +} + +func DecryptAesECB(src, key []byte, paddingType string) ([]byte, error) { + return DecryptAes(src, key, nil, MODEECB, paddingType) +} + +func EncryptAesCBC(data, key, iv []byte, paddingType string) ([]byte, error) { + return EncryptAes(data, key, iv, MODECBC, paddingType) +} + +func DecryptAesCBC(src, key, iv []byte, paddingType string) ([]byte, error) { + return DecryptAes(src, key, iv, MODECBC, paddingType) +} + +func EncryptAesCFB(data, key, iv []byte) ([]byte, error) { + return EncryptAes(data, key, iv, MODECFB, "") +} + +func DecryptAesCFB(src, key, iv []byte) ([]byte, error) { + return DecryptAes(src, key, iv, MODECFB, "") +} + +func EncryptAesOFB(data, key, iv []byte) ([]byte, error) { + return EncryptAes(data, key, iv, MODEOFB, "") +} + +func DecryptAesOFB(src, key, iv []byte) ([]byte, error) { + return DecryptAes(src, key, iv, MODEOFB, "") +} + +func EncryptAesCTR(data, key, iv []byte) ([]byte, error) { + return EncryptAes(data, key, iv, MODECTR, "") +} + +func DecryptAesCTR(src, key, iv []byte) ([]byte, error) { + return DecryptAes(src, key, iv, MODECTR, "") +} + +func EncryptAesECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error { + return EncryptAesStream(dst, src, key, nil, MODEECB, paddingType) +} + +func DecryptAesECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error { + return DecryptAesStream(dst, src, key, nil, MODEECB, paddingType) +} + +func EncryptAesCBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + return EncryptAesStream(dst, src, key, iv, MODECBC, paddingType) +} + +func DecryptAesCBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + return DecryptAesStream(dst, src, key, iv, MODECBC, paddingType) +} + +func EncryptAesCFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return EncryptAesStream(dst, src, key, iv, MODECFB, "") +} + +func DecryptAesCFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return DecryptAesStream(dst, src, key, iv, MODECFB, "") +} + +func EncryptAesOFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return EncryptAesStream(dst, src, key, iv, MODEOFB, "") +} + +func DecryptAesOFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return DecryptAesStream(dst, src, key, iv, MODEOFB, "") +} + +func EncryptAesCTRStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return EncryptAesStream(dst, src, key, iv, MODECTR, "") +} + +func DecryptAesCTRStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return DecryptAesStream(dst, src, key, iv, MODECTR, "") +} + +func CustomEncryptAesCFB(origData, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + + encrypted := make([]byte, aes.BlockSize+len(origData)) + iv := encrypted[:block.BlockSize()] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + + body, err := EncryptAesCFB(origData, key, iv) + if err != nil { + return nil, err + } + copy(encrypted[block.BlockSize():], body) + return encrypted, nil +} + +func CustomDecryptAesCFB(encrypted, key []byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + if len(encrypted) < block.BlockSize() { + return nil, errors.New("ciphertext too short") + } + + iv := encrypted[:block.BlockSize()] + return DecryptAesCFB(encrypted[block.BlockSize():], key, iv) +} + +func CustomEncryptAesCFBNoBlock(origData, key, iv []byte) ([]byte, error) { + return EncryptAesCFB(origData, key, iv) +} + +func CustomDecryptAesCFBNoBlock(encrypted, key, iv []byte) ([]byte, error) { + return DecryptAesCFB(encrypted, key, iv) +} + +func PKCS5Padding(cipherText []byte, blockSize int) []byte { + out, _ := paddingx.Pad(cipherText, blockSize, PKCS5PADDING) + return out +} + +func PKCS5Trimming(encrypted []byte) []byte { + if len(encrypted) == 0 { + return nil + } + padding := int(encrypted[len(encrypted)-1]) + if padding <= 0 || padding > len(encrypted) { + return nil + } + for i := len(encrypted) - padding; i < len(encrypted); i++ { + if int(encrypted[i]) != padding { + return nil + } + } + return encrypted[:len(encrypted)-padding] +} + +func PKCS7Padding(cipherText []byte, blockSize int) []byte { + out, _ := paddingx.Pad(cipherText, blockSize, PKCS7PADDING) + return out +} + +func PKCS7Trimming(encrypted []byte, blockSize int) []byte { + out, err := paddingx.Unpad(encrypted, blockSize, PKCS7PADDING) + if err != nil { + return nil + } + return out +} diff --git a/symm/chacha20.go b/symm/chacha20.go new file mode 100644 index 0000000..2236900 --- /dev/null +++ b/symm/chacha20.go @@ -0,0 +1,65 @@ +package symm + +import ( + "crypto/cipher" + "errors" + "io" + + "golang.org/x/crypto/chacha20" + "golang.org/x/crypto/chacha20poly1305" +) + +var ErrInvalidChaCha20NonceLength = errors.New("chacha20 nonce length must be 12 or 24 bytes") + +func EncryptChaCha20(data, key, nonce []byte) ([]byte, error) { + stream, err := chacha20.NewUnauthenticatedCipher(key, nonce) + if err != nil { + return nil, err + } + out := make([]byte, len(data)) + stream.XORKeyStream(out, data) + return out, nil +} + +func DecryptChaCha20(src, key, nonce []byte) ([]byte, error) { + return EncryptChaCha20(src, key, nonce) +} + +func EncryptChaCha20Stream(dst io.Writer, src io.Reader, key, nonce []byte) error { + stream, err := chacha20.NewUnauthenticatedCipher(key, nonce) + if err != nil { + return err + } + return xorStreamCopy(dst, src, stream) +} + +func DecryptChaCha20Stream(dst io.Writer, src io.Reader, key, nonce []byte) error { + return EncryptChaCha20Stream(dst, src, key, nonce) +} + +func EncryptChaCha20Poly1305(plain, key, nonce, aad []byte) ([]byte, error) { + aead, err := newChaCha20AEAD(key, nonce) + if err != nil { + return nil, err + } + return aead.Seal(nil, nonce, plain, aad), nil +} + +func DecryptChaCha20Poly1305(ciphertext, key, nonce, aad []byte) ([]byte, error) { + aead, err := newChaCha20AEAD(key, nonce) + if err != nil { + return nil, err + } + return aead.Open(nil, nonce, ciphertext, aad) +} + +func newChaCha20AEAD(key, nonce []byte) (cipher.AEAD, error) { + switch len(nonce) { + case chacha20poly1305.NonceSize: + return chacha20poly1305.New(key) + case chacha20poly1305.NonceSizeX: + return chacha20poly1305.NewX(key) + default: + return nil, ErrInvalidChaCha20NonceLength + } +} diff --git a/symm/des.go b/symm/des.go new file mode 100644 index 0000000..2d6ffb0 --- /dev/null +++ b/symm/des.go @@ -0,0 +1,70 @@ +package symm + +import ( + "crypto/des" + "io" +) + +func EncryptDESCBC(data, key, iv []byte, paddingType string) ([]byte, error) { + block, err := des.NewCipher(key) + if err != nil { + return nil, err + } + return encryptWithBlockMode(block, data, iv, MODECBC, paddingType, PKCS5PADDING) +} + +func DecryptDESCBC(src, key, iv []byte, paddingType string) ([]byte, error) { + block, err := des.NewCipher(key) + if err != nil { + return nil, err + } + return decryptWithBlockMode(block, src, iv, MODECBC, paddingType, PKCS5PADDING) +} + +func EncryptDESCBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + block, err := des.NewCipher(key) + if err != nil { + return err + } + return encryptWithBlockModeStream(block, dst, src, iv, MODECBC, paddingType, PKCS5PADDING) +} + +func DecryptDESCBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + block, err := des.NewCipher(key) + if err != nil { + return err + } + return decryptWithBlockModeStream(block, dst, src, iv, MODECBC, paddingType, PKCS5PADDING) +} + +func Encrypt3DESCBC(data, key, iv []byte, paddingType string) ([]byte, error) { + block, err := des.NewTripleDESCipher(key) + if err != nil { + return nil, err + } + return encryptWithBlockMode(block, data, iv, MODECBC, paddingType, PKCS5PADDING) +} + +func Decrypt3DESCBC(src, key, iv []byte, paddingType string) ([]byte, error) { + block, err := des.NewTripleDESCipher(key) + if err != nil { + return nil, err + } + return decryptWithBlockMode(block, src, iv, MODECBC, paddingType, PKCS5PADDING) +} + +func Encrypt3DESCBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + block, err := des.NewTripleDESCipher(key) + if err != nil { + return err + } + return encryptWithBlockModeStream(block, dst, src, iv, MODECBC, paddingType, PKCS5PADDING) +} + +func Decrypt3DESCBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + block, err := des.NewTripleDESCipher(key) + if err != nil { + return err + } + return decryptWithBlockModeStream(block, dst, src, iv, MODECBC, paddingType, PKCS5PADDING) +} diff --git a/symm/fuzz_test.go b/symm/fuzz_test.go new file mode 100644 index 0000000..80bdf46 --- /dev/null +++ b/symm/fuzz_test.go @@ -0,0 +1,72 @@ +package symm + +import ( + "bytes" + "testing" +) + +func FuzzAesCBCRoundTrip(f *testing.F) { + f.Add([]byte("fuzz-aes-cbc")) + f.Add([]byte{}) + + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + + f.Fuzz(func(t *testing.T, data []byte) { + enc, err := EncryptAesCBC(data, key, iv, "") + if err != nil { + t.Fatalf("EncryptAesCBC failed: %v", err) + } + dec, err := DecryptAesCBC(enc, key, iv, "") + if err != nil { + t.Fatalf("DecryptAesCBC failed: %v", err) + } + if !bytes.Equal(dec, data) { + t.Fatalf("aes cbc fuzz roundtrip mismatch") + } + }) +} + +func FuzzChaCha20RoundTrip(f *testing.F) { + f.Add([]byte("fuzz-chacha20")) + f.Add([]byte{}) + + key := []byte("0123456789abcdef0123456789abcdef") + nonce := []byte("123456789012") + + f.Fuzz(func(t *testing.T, data []byte) { + enc, err := EncryptChaCha20(data, key, nonce) + if err != nil { + t.Fatalf("EncryptChaCha20 failed: %v", err) + } + dec, err := DecryptChaCha20(enc, key, nonce) + if err != nil { + t.Fatalf("DecryptChaCha20 failed: %v", err) + } + if !bytes.Equal(dec, data) { + t.Fatalf("chacha20 fuzz roundtrip mismatch") + } + }) +} + +func FuzzAesCBCStreamRoundTrip(f *testing.F) { + f.Add([]byte("fuzz-aes-stream")) + f.Add([]byte{}) + + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + + f.Fuzz(func(t *testing.T, data []byte) { + enc := &bytes.Buffer{} + dec := &bytes.Buffer{} + if err := EncryptAesCBCStream(enc, bytes.NewReader(data), key, iv, ""); err != nil { + t.Fatalf("EncryptAesCBCStream failed: %v", err) + } + if err := DecryptAesCBCStream(dec, bytes.NewReader(enc.Bytes()), key, iv, ""); err != nil { + t.Fatalf("DecryptAesCBCStream failed: %v", err) + } + if !bytes.Equal(dec.Bytes(), data) { + t.Fatalf("aes cbc stream fuzz roundtrip mismatch") + } + }) +} diff --git a/symm/mode.go b/symm/mode.go new file mode 100644 index 0000000..561e6cd --- /dev/null +++ b/symm/mode.go @@ -0,0 +1,325 @@ +package symm + +import ( + "crypto/cipher" + "errors" + "io" + "strings" + + "b612.me/starcrypto/paddingx" +) + +const ( + MODEECB = "ECB" + MODECBC = "CBC" + MODECFB = "CFB" + MODEOFB = "OFB" + MODECTR = "CTR" + MODEGCM = "GCM" +) + +var ErrUnsupportedCipherMode = errors.New("cipher mode not supported") + +func normalizeCipherMode(mode string) string { + return strings.ToUpper(strings.TrimSpace(mode)) +} + +func encryptWithBlockMode(block cipher.Block, data, iv []byte, mode, paddingType, defaultPadding string) ([]byte, error) { + mode = normalizeCipherMode(mode) + if mode == "" { + mode = MODECBC + } + + switch mode { + case MODEECB: + if paddingType == "" { + paddingType = defaultPadding + } + content, err := paddingx.Pad(data, block.BlockSize(), paddingType) + if err != nil { + return nil, err + } + out := make([]byte, len(content)) + ecbEncryptBlocks(block, out, content) + return out, nil + case MODECBC: + if len(iv) != block.BlockSize() { + return nil, errors.New("iv length must match block size") + } + if paddingType == "" { + paddingType = defaultPadding + } + content, err := paddingx.Pad(data, block.BlockSize(), paddingType) + if err != nil { + return nil, err + } + out := make([]byte, len(content)) + cipher.NewCBCEncrypter(block, iv).CryptBlocks(out, content) + return out, nil + case MODECFB, MODEOFB, MODECTR: + if len(iv) != block.BlockSize() { + return nil, errors.New("iv length must match block size") + } + stream, err := newCipherStream(block, iv, mode, false) + if err != nil { + return nil, err + } + out := make([]byte, len(data)) + stream.XORKeyStream(out, data) + return out, nil + default: + return nil, ErrUnsupportedCipherMode + } +} + +func decryptWithBlockMode(block cipher.Block, src, iv []byte, mode, paddingType, defaultPadding string) ([]byte, error) { + mode = normalizeCipherMode(mode) + if mode == "" { + mode = MODECBC + } + + switch mode { + case MODEECB: + if len(src) == 0 || len(src)%block.BlockSize() != 0 { + return nil, errors.New("ciphertext is not a full block size") + } + if paddingType == "" { + paddingType = defaultPadding + } + decrypted := make([]byte, len(src)) + ecbDecryptBlocks(block, decrypted, src) + return paddingx.Unpad(decrypted, block.BlockSize(), paddingType) + case MODECBC: + if len(iv) != block.BlockSize() { + return nil, errors.New("iv length must match block size") + } + if len(src) == 0 || len(src)%block.BlockSize() != 0 { + return nil, errors.New("ciphertext is not a full block size") + } + if paddingType == "" { + paddingType = defaultPadding + } + decrypted := make([]byte, len(src)) + cipher.NewCBCDecrypter(block, iv).CryptBlocks(decrypted, src) + return paddingx.Unpad(decrypted, block.BlockSize(), paddingType) + case MODECFB, MODEOFB, MODECTR: + if len(iv) != block.BlockSize() { + return nil, errors.New("iv length must match block size") + } + stream, err := newCipherStream(block, iv, mode, true) + if err != nil { + return nil, err + } + out := make([]byte, len(src)) + stream.XORKeyStream(out, src) + return out, nil + default: + return nil, ErrUnsupportedCipherMode + } +} + +func encryptWithBlockModeStream(block cipher.Block, dst io.Writer, src io.Reader, iv []byte, mode, paddingType, defaultPadding string) error { + mode = normalizeCipherMode(mode) + if mode == "" { + mode = MODECBC + } + + switch mode { + case MODEECB: + if paddingType == "" { + paddingType = defaultPadding + } + return encryptPaddedBlockStream(dst, src, block.BlockSize(), paddingType, func(out, in []byte) { + ecbEncryptBlocks(block, out, in) + }) + case MODECBC: + if len(iv) != block.BlockSize() { + return errors.New("iv length must match block size") + } + if paddingType == "" { + paddingType = defaultPadding + } + modeEnc := cipher.NewCBCEncrypter(block, iv) + return encryptPaddedBlockStream(dst, src, block.BlockSize(), paddingType, modeEnc.CryptBlocks) + case MODECFB, MODEOFB, MODECTR: + if len(iv) != block.BlockSize() { + return errors.New("iv length must match block size") + } + stream, err := newCipherStream(block, iv, mode, false) + if err != nil { + return err + } + return xorStreamCopy(dst, src, stream) + default: + return ErrUnsupportedCipherMode + } +} + +func decryptWithBlockModeStream(block cipher.Block, dst io.Writer, src io.Reader, iv []byte, mode, paddingType, defaultPadding string) error { + mode = normalizeCipherMode(mode) + if mode == "" { + mode = MODECBC + } + + switch mode { + case MODEECB: + if paddingType == "" { + paddingType = defaultPadding + } + return decryptPaddedBlockStream(dst, src, block.BlockSize(), paddingType, func(out, in []byte) { + ecbDecryptBlocks(block, out, in) + }) + case MODECBC: + if len(iv) != block.BlockSize() { + return errors.New("iv length must match block size") + } + if paddingType == "" { + paddingType = defaultPadding + } + modeDec := cipher.NewCBCDecrypter(block, iv) + return decryptPaddedBlockStream(dst, src, block.BlockSize(), paddingType, modeDec.CryptBlocks) + case MODECFB, MODEOFB, MODECTR: + if len(iv) != block.BlockSize() { + return errors.New("iv length must match block size") + } + stream, err := newCipherStream(block, iv, mode, true) + if err != nil { + return err + } + return xorStreamCopy(dst, src, stream) + default: + return ErrUnsupportedCipherMode + } +} + +func newCipherStream(block cipher.Block, iv []byte, mode string, decrypt bool) (cipher.Stream, error) { + switch mode { + case MODECFB: + if decrypt { + return cipher.NewCFBDecrypter(block, iv), nil + } + return cipher.NewCFBEncrypter(block, iv), nil + case MODEOFB: + return cipher.NewOFB(block, iv), nil + case MODECTR: + return cipher.NewCTR(block, iv), nil + default: + return nil, ErrUnsupportedCipherMode + } +} + +func xorStreamCopy(dst io.Writer, src io.Reader, stream cipher.Stream) error { + buf := make([]byte, 32*1024) + out := make([]byte, 32*1024) + for { + n, err := src.Read(buf) + if n > 0 { + stream.XORKeyStream(out[:n], buf[:n]) + if _, werr := dst.Write(out[:n]); werr != nil { + return werr + } + } + if err != nil { + if err == io.EOF { + return nil + } + return err + } + } +} + +func encryptPaddedBlockStream(dst io.Writer, src io.Reader, blockSize int, paddingType string, cryptBlocks func(dst, src []byte)) error { + pending := make([]byte, 0, blockSize*2) + buf := make([]byte, 32*1024) + + for { + n, err := src.Read(buf) + if n > 0 { + pending = append(pending, buf[:n]...) + processLen := len(pending) - blockSize + if processLen > 0 { + processLen -= processLen % blockSize + if processLen > 0 { + out := make([]byte, processLen) + cryptBlocks(out, pending[:processLen]) + if _, werr := dst.Write(out); werr != nil { + return werr + } + pending = append([]byte(nil), pending[processLen:]...) + } + } + } + if err != nil { + if err == io.EOF { + break + } + return err + } + } + + content, err := paddingx.Pad(pending, blockSize, paddingType) + if err != nil { + return err + } + out := make([]byte, len(content)) + cryptBlocks(out, content) + _, err = dst.Write(out) + return err +} + +func decryptPaddedBlockStream(dst io.Writer, src io.Reader, blockSize int, paddingType string, cryptBlocks func(dst, src []byte)) error { + pending := make([]byte, 0, blockSize*2) + buf := make([]byte, 32*1024) + + for { + n, err := src.Read(buf) + if n > 0 { + pending = append(pending, buf[:n]...) + processLen := len(pending) - blockSize + if processLen > 0 { + processLen -= processLen % blockSize + if processLen > 0 { + out := make([]byte, processLen) + cryptBlocks(out, pending[:processLen]) + if _, werr := dst.Write(out); werr != nil { + return werr + } + pending = append([]byte(nil), pending[processLen:]...) + } + } + } + if err != nil { + if err == io.EOF { + break + } + return err + } + } + + if len(pending) == 0 || len(pending)%blockSize != 0 { + return errors.New("ciphertext is not a full block size") + } + + decrypted := make([]byte, len(pending)) + cryptBlocks(decrypted, pending) + out, err := paddingx.Unpad(decrypted, blockSize, paddingType) + if err != nil { + return err + } + _, err = dst.Write(out) + return err +} + +func ecbEncryptBlocks(block cipher.Block, dst, src []byte) { + blockSize := block.BlockSize() + for i := 0; i < len(src); i += blockSize { + block.Encrypt(dst[i:i+blockSize], src[i:i+blockSize]) + } +} + +func ecbDecryptBlocks(block cipher.Block, dst, src []byte) { + blockSize := block.BlockSize() + for i := 0; i < len(src); i += blockSize { + block.Decrypt(dst[i:i+blockSize], src[i:i+blockSize]) + } +} diff --git a/symm/options.go b/symm/options.go new file mode 100644 index 0000000..6b03b1f --- /dev/null +++ b/symm/options.go @@ -0,0 +1,25 @@ +package symm + +// CipherOptions provides a unified configuration for symmetric APIs. +// For GCM mode, Nonce is used; if Nonce is empty, IV is used as fallback. +type CipherOptions struct { + Mode string + Padding string + IV []byte + Nonce []byte + AAD []byte +} + +func normalizeCipherOptions(opts *CipherOptions) CipherOptions { + if opts == nil { + return CipherOptions{} + } + return *opts +} + +func nonceFromOptions(opts CipherOptions) []byte { + if len(opts.Nonce) > 0 { + return opts.Nonce + } + return opts.IV +} diff --git a/symm/sm4.go b/symm/sm4.go new file mode 100644 index 0000000..cdfd877 --- /dev/null +++ b/symm/sm4.go @@ -0,0 +1,276 @@ +package symm + +import ( + "crypto/cipher" + "crypto/rand" + "errors" + "io" + + "github.com/emmansun/gmsm/sm4" +) + +func EncryptSM4(data, key, iv []byte, mode, paddingType string) ([]byte, error) { + normalizedMode := normalizeCipherMode(mode) + if normalizedMode == MODEGCM { + return EncryptSM4GCM(data, key, iv, nil) + } + + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + return encryptWithBlockMode(block, data, iv, normalizedMode, paddingType, PKCS7PADDING) +} + +func DecryptSM4(src, key, iv []byte, mode, paddingType string) ([]byte, error) { + normalizedMode := normalizeCipherMode(mode) + if normalizedMode == MODEGCM { + return DecryptSM4GCM(src, key, iv, nil) + } + + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + return decryptWithBlockMode(block, src, iv, normalizedMode, paddingType, PKCS7PADDING) +} + +func EncryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { + normalizedMode := normalizeCipherMode(mode) + if normalizedMode == MODEGCM { + return EncryptSM4GCMStream(dst, src, key, iv, nil) + } + + block, err := sm4.NewCipher(key) + if err != nil { + return err + } + return encryptWithBlockModeStream(block, dst, src, iv, normalizedMode, paddingType, PKCS7PADDING) +} + +func DecryptSM4Stream(dst io.Writer, src io.Reader, key, iv []byte, mode, paddingType string) error { + normalizedMode := normalizeCipherMode(mode) + if normalizedMode == MODEGCM { + return DecryptSM4GCMStream(dst, src, key, iv, nil) + } + + block, err := sm4.NewCipher(key) + if err != nil { + return err + } + return decryptWithBlockModeStream(block, dst, src, iv, normalizedMode, paddingType, PKCS7PADDING) +} + +func EncryptSM4WithOptions(data, key []byte, opts *CipherOptions) ([]byte, error) { + cfg := normalizeCipherOptions(opts) + mode := normalizeCipherMode(cfg.Mode) + if mode == "" { + mode = MODEGCM + } + if mode == MODEGCM { + return EncryptSM4GCM(data, key, nonceFromOptions(cfg), cfg.AAD) + } + return EncryptSM4(data, key, cfg.IV, mode, cfg.Padding) +} + +func DecryptSM4WithOptions(src, key []byte, opts *CipherOptions) ([]byte, error) { + cfg := normalizeCipherOptions(opts) + mode := normalizeCipherMode(cfg.Mode) + if mode == "" { + mode = MODEGCM + } + if mode == MODEGCM { + return DecryptSM4GCM(src, key, nonceFromOptions(cfg), cfg.AAD) + } + return DecryptSM4(src, key, cfg.IV, mode, cfg.Padding) +} + +func EncryptSM4StreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts *CipherOptions) error { + cfg := normalizeCipherOptions(opts) + mode := normalizeCipherMode(cfg.Mode) + if mode == "" { + mode = MODEGCM + } + if mode == MODEGCM { + return EncryptSM4GCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) + } + return EncryptSM4Stream(dst, src, key, cfg.IV, mode, cfg.Padding) +} + +func DecryptSM4StreamWithOptions(dst io.Writer, src io.Reader, key []byte, opts *CipherOptions) error { + cfg := normalizeCipherOptions(opts) + mode := normalizeCipherMode(cfg.Mode) + if mode == "" { + mode = MODEGCM + } + if mode == MODEGCM { + return DecryptSM4GCMStream(dst, src, key, nonceFromOptions(cfg), cfg.AAD) + } + return DecryptSM4Stream(dst, src, key, cfg.IV, mode, cfg.Padding) +} + +func EncryptSM4GCM(plain, key, nonce, aad []byte) ([]byte, error) { + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + if len(nonce) != gcm.NonceSize() { + return nil, ErrInvalidGCMNonceLength + } + return gcm.Seal(nil, nonce, plain, aad), nil +} + +func DecryptSM4GCM(ciphertext, key, nonce, aad []byte) ([]byte, error) { + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + if len(nonce) != gcm.NonceSize() { + return nil, ErrInvalidGCMNonceLength + } + return gcm.Open(nil, nonce, ciphertext, aad) +} + +func EncryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + plain, err := io.ReadAll(src) + if err != nil { + return err + } + out, err := EncryptSM4GCM(plain, key, nonce, aad) + if err != nil { + return err + } + _, err = dst.Write(out) + return err +} + +func DecryptSM4GCMStream(dst io.Writer, src io.Reader, key, nonce, aad []byte) error { + enc, err := io.ReadAll(src) + if err != nil { + return err + } + out, err := DecryptSM4GCM(enc, key, nonce, aad) + if err != nil { + return err + } + _, err = dst.Write(out) + return err +} + +func EncryptSM4CFB(origData, key []byte) ([]byte, error) { + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + out := make([]byte, block.BlockSize()+len(origData)) + iv := out[:block.BlockSize()] + if _, err := io.ReadFull(rand.Reader, iv); err != nil { + return nil, err + } + body, err := EncryptSM4CFBNoBlock(origData, key, iv) + if err != nil { + return nil, err + } + copy(out[block.BlockSize():], body) + return out, nil +} + +func DecryptSM4CFB(encrypted, key []byte) ([]byte, error) { + block, err := sm4.NewCipher(key) + if err != nil { + return nil, err + } + if len(encrypted) < block.BlockSize() { + return nil, errors.New("ciphertext too short") + } + iv := encrypted[:block.BlockSize()] + return DecryptSM4CFBNoBlock(encrypted[block.BlockSize():], key, iv) +} + +func EncryptSM4CFBNoBlock(origData, key, iv []byte) ([]byte, error) { + return EncryptSM4(origData, key, iv, MODECFB, "") +} + +func DecryptSM4CFBNoBlock(encrypted, key, iv []byte) ([]byte, error) { + return DecryptSM4(encrypted, key, iv, MODECFB, "") +} + +func EncryptSM4ECB(data, key []byte, paddingType string) ([]byte, error) { + return EncryptSM4(data, key, nil, MODEECB, paddingType) +} + +func DecryptSM4ECB(src, key []byte, paddingType string) ([]byte, error) { + return DecryptSM4(src, key, nil, MODEECB, paddingType) +} + +func EncryptSM4CBC(data, key, iv []byte, paddingType string) ([]byte, error) { + return EncryptSM4(data, key, iv, MODECBC, paddingType) +} + +func DecryptSM4CBC(src, key, iv []byte, paddingType string) ([]byte, error) { + return DecryptSM4(src, key, iv, MODECBC, paddingType) +} + +func EncryptSM4OFB(data, key, iv []byte) ([]byte, error) { + return EncryptSM4(data, key, iv, MODEOFB, "") +} + +func DecryptSM4OFB(src, key, iv []byte) ([]byte, error) { + return DecryptSM4(src, key, iv, MODEOFB, "") +} + +func EncryptSM4CTR(data, key, iv []byte) ([]byte, error) { + return EncryptSM4(data, key, iv, MODECTR, "") +} + +func DecryptSM4CTR(src, key, iv []byte) ([]byte, error) { + return DecryptSM4(src, key, iv, MODECTR, "") +} + +func EncryptSM4ECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error { + return EncryptSM4Stream(dst, src, key, nil, MODEECB, paddingType) +} + +func DecryptSM4ECBStream(dst io.Writer, src io.Reader, key []byte, paddingType string) error { + return DecryptSM4Stream(dst, src, key, nil, MODEECB, paddingType) +} + +func EncryptSM4CBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + return EncryptSM4Stream(dst, src, key, iv, MODECBC, paddingType) +} + +func DecryptSM4CBCStream(dst io.Writer, src io.Reader, key, iv []byte, paddingType string) error { + return DecryptSM4Stream(dst, src, key, iv, MODECBC, paddingType) +} + +func EncryptSM4CFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return EncryptSM4Stream(dst, src, key, iv, MODECFB, "") +} + +func DecryptSM4CFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return DecryptSM4Stream(dst, src, key, iv, MODECFB, "") +} + +func EncryptSM4OFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return EncryptSM4Stream(dst, src, key, iv, MODEOFB, "") +} + +func DecryptSM4OFBStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return DecryptSM4Stream(dst, src, key, iv, MODEOFB, "") +} + +func EncryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return EncryptSM4Stream(dst, src, key, iv, MODECTR, "") +} + +func DecryptSM4CTRStream(dst io.Writer, src io.Reader, key, iv []byte) error { + return DecryptSM4Stream(dst, src, key, iv, MODECTR, "") +} diff --git a/symm/symm_test.go b/symm/symm_test.go new file mode 100644 index 0000000..404598b --- /dev/null +++ b/symm/symm_test.go @@ -0,0 +1,608 @@ +package symm + +import ( + "bytes" + "encoding/hex" + "testing" +) + +func TestEncryptAesDefaultModeCBC(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := []byte("aes-default-mode-cbc") + + encDefault, err := EncryptAes(plain, key, iv, "", "") + if err != nil { + t.Fatalf("EncryptAes default failed: %v", err) + } + encCBC, err := EncryptAesCBC(plain, key, iv, "") + if err != nil { + t.Fatalf("EncryptAesCBC failed: %v", err) + } + if !bytes.Equal(encDefault, encCBC) { + t.Fatalf("default mode should match CBC mode") + } +} + +func TestAESCBCRoundTripDefaultPKCS7(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := []byte("aes-cbc-with-default-padding") + + enc, err := EncryptAesCBC(plain, key, iv, "") + if err != nil { + t.Fatalf("EncryptAesCBC failed: %v", err) + } + dec, err := DecryptAesCBC(enc, key, iv, "") + if err != nil { + t.Fatalf("DecryptAesCBC failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("aes cbc mismatch, got %q want %q", dec, plain) + } +} + +func TestAESCFBRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + plain := []byte("aes-cfb-roundtrip") + enc, err := CustomEncryptAesCFB(plain, key) + if err != nil { + t.Fatalf("CustomEncryptAesCFB failed: %v", err) + } + dec, err := CustomDecryptAesCFB(enc, key) + if err != nil { + t.Fatalf("CustomDecryptAesCFB failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("aes cfb mismatch, got %q want %q", dec, plain) + } +} + +func TestAesGenericModesRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := []byte("generic-aes-mode-roundtrip") + + modes := []string{MODEECB, MODECBC, MODECFB, MODEOFB, MODECTR} + for _, mode := range modes { + t.Run(mode, func(t *testing.T) { + useIV := iv + if mode == MODEECB { + useIV = nil + } + enc, err := EncryptAes(plain, key, useIV, mode, "") + if err != nil { + t.Fatalf("EncryptAes(%s) failed: %v", mode, err) + } + dec, err := DecryptAes(enc, key, useIV, mode, "") + if err != nil { + t.Fatalf("DecryptAes(%s) failed: %v", mode, err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("aes %s mismatch", mode) + } + }) + } +} + +func TestAesDerivedFunctionsRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := []byte("aes-derived-func-roundtrip") + + ecbEnc, err := EncryptAesECB(plain, key, "") + if err != nil { + t.Fatalf("EncryptAesECB failed: %v", err) + } + ecbDec, err := DecryptAesECB(ecbEnc, key, "") + if err != nil { + t.Fatalf("DecryptAesECB failed: %v", err) + } + if !bytes.Equal(ecbDec, plain) { + t.Fatalf("aes ecb mismatch") + } + + cfbEnc, err := EncryptAesCFB(plain, key, iv) + if err != nil { + t.Fatalf("EncryptAesCFB failed: %v", err) + } + cfbDec, err := DecryptAesCFB(cfbEnc, key, iv) + if err != nil { + t.Fatalf("DecryptAesCFB failed: %v", err) + } + if !bytes.Equal(cfbDec, plain) { + t.Fatalf("aes cfb mismatch") + } + + ofbEnc, err := EncryptAesOFB(plain, key, iv) + if err != nil { + t.Fatalf("EncryptAesOFB failed: %v", err) + } + ofbDec, err := DecryptAesOFB(ofbEnc, key, iv) + if err != nil { + t.Fatalf("DecryptAesOFB failed: %v", err) + } + if !bytes.Equal(ofbDec, plain) { + t.Fatalf("aes ofb mismatch") + } + + ctrEnc, err := EncryptAesCTR(plain, key, iv) + if err != nil { + t.Fatalf("EncryptAesCTR failed: %v", err) + } + ctrDec, err := DecryptAesCTR(ctrEnc, key, iv) + if err != nil { + t.Fatalf("DecryptAesCTR failed: %v", err) + } + if !bytes.Equal(ctrDec, plain) { + t.Fatalf("aes ctr mismatch") + } +} + +func TestAesStreamRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := []byte("streaming-aes-mode-roundtrip-content") + + modes := []string{MODEECB, MODECBC, MODECFB, MODEOFB, MODECTR} + for _, mode := range modes { + t.Run(mode, func(t *testing.T) { + encBuf := &bytes.Buffer{} + decBuf := &bytes.Buffer{} + useIV := iv + if mode == MODEECB { + useIV = nil + } + if err := EncryptAesStream(encBuf, bytes.NewReader(plain), key, useIV, mode, ""); err != nil { + t.Fatalf("EncryptAesStream(%s) failed: %v", mode, err) + } + if err := DecryptAesStream(decBuf, bytes.NewReader(encBuf.Bytes()), key, useIV, mode, ""); err != nil { + t.Fatalf("DecryptAesStream(%s) failed: %v", mode, err) + } + if !bytes.Equal(decBuf.Bytes(), plain) { + t.Fatalf("aes stream %s mismatch", mode) + } + }) + } +} + +func TestAesStreamInvalidMode(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + err := EncryptAesStream(&bytes.Buffer{}, bytes.NewReader([]byte("x")), key, iv, "BAD", "") + if err == nil { + t.Fatalf("expected invalid mode error") + } +} + +func TestSM4CBCRoundTripDefaultPKCS7(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := []byte("sm4-cbc-with-default-padding") + + enc, err := EncryptSM4CBC(plain, key, iv, "") + if err != nil { + t.Fatalf("EncryptSM4CBC failed: %v", err) + } + dec, err := DecryptSM4CBC(enc, key, iv, "") + if err != nil { + t.Fatalf("DecryptSM4CBC failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("sm4 cbc mismatch, got %q want %q", dec, plain) + } +} + +func TestSM4CFBRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + plain := []byte("sm4-cfb-roundtrip") + enc, err := EncryptSM4CFB(plain, key) + if err != nil { + t.Fatalf("EncryptSM4CFB failed: %v", err) + } + dec, err := DecryptSM4CFB(enc, key) + if err != nil { + t.Fatalf("DecryptSM4CFB failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("sm4 cfb mismatch, got %q want %q", dec, plain) + } +} + +func TestSM4StreamRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := []byte("sm4-stream-roundtrip-data") + + encCBC := &bytes.Buffer{} + if err := EncryptSM4CBCStream(encCBC, bytes.NewReader(plain), key, iv, ""); err != nil { + t.Fatalf("EncryptSM4CBCStream failed: %v", err) + } + decCBC := &bytes.Buffer{} + if err := DecryptSM4CBCStream(decCBC, bytes.NewReader(encCBC.Bytes()), key, iv, ""); err != nil { + t.Fatalf("DecryptSM4CBCStream failed: %v", err) + } + if !bytes.Equal(decCBC.Bytes(), plain) { + t.Fatalf("sm4 cbc stream mismatch") + } + + encCFB := &bytes.Buffer{} + if err := EncryptSM4CFBStream(encCFB, bytes.NewReader(plain), key, iv); err != nil { + t.Fatalf("EncryptSM4CFBStream failed: %v", err) + } + decCFB := &bytes.Buffer{} + if err := DecryptSM4CFBStream(decCFB, bytes.NewReader(encCFB.Bytes()), key, iv); err != nil { + t.Fatalf("DecryptSM4CFBStream failed: %v", err) + } + if !bytes.Equal(decCFB.Bytes(), plain) { + t.Fatalf("sm4 cfb stream mismatch") + } +} + +func TestDESCBCRoundTripDefaultPKCS5(t *testing.T) { + key := []byte("12345678") + iv := []byte("abcdefgh") + plain := []byte("des-cbc") + + enc, err := EncryptDESCBC(plain, key, iv, "") + if err != nil { + t.Fatalf("EncryptDESCBC failed: %v", err) + } + dec, err := DecryptDESCBC(enc, key, iv, "") + if err != nil { + t.Fatalf("DecryptDESCBC failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("des cbc mismatch, got %q want %q", dec, plain) + } +} + +func Test3DESCBCRoundTripDefaultPKCS5(t *testing.T) { + key := []byte("12345678abcdefgh87654321") + iv := []byte("12345678") + plain := []byte("3des-cbc-default-padding") + + enc, err := Encrypt3DESCBC(plain, key, iv, "") + if err != nil { + t.Fatalf("Encrypt3DESCBC failed: %v", err) + } + dec, err := Decrypt3DESCBC(enc, key, iv, "") + if err != nil { + t.Fatalf("Decrypt3DESCBC failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("3des cbc mismatch, got %q want %q", dec, plain) + } +} + +func TestDESStreamRoundTrip(t *testing.T) { + desKey := []byte("12345678") + desIV := []byte("abcdefgh") + desPlain := []byte("des-stream-roundtrip") + + desEnc := &bytes.Buffer{} + if err := EncryptDESCBCStream(desEnc, bytes.NewReader(desPlain), desKey, desIV, ""); err != nil { + t.Fatalf("EncryptDESCBCStream failed: %v", err) + } + desDec := &bytes.Buffer{} + if err := DecryptDESCBCStream(desDec, bytes.NewReader(desEnc.Bytes()), desKey, desIV, ""); err != nil { + t.Fatalf("DecryptDESCBCStream failed: %v", err) + } + if !bytes.Equal(desDec.Bytes(), desPlain) { + t.Fatalf("des cbc stream mismatch") + } + + key3des := []byte("12345678abcdefgh87654321") + iv3des := []byte("12345678") + plain3des := []byte("3des-stream-roundtrip") + + enc3des := &bytes.Buffer{} + if err := Encrypt3DESCBCStream(enc3des, bytes.NewReader(plain3des), key3des, iv3des, ""); err != nil { + t.Fatalf("Encrypt3DESCBCStream failed: %v", err) + } + dec3des := &bytes.Buffer{} + if err := Decrypt3DESCBCStream(dec3des, bytes.NewReader(enc3des.Bytes()), key3des, iv3des, ""); err != nil { + t.Fatalf("Decrypt3DESCBCStream failed: %v", err) + } + if !bytes.Equal(dec3des.Bytes(), plain3des) { + t.Fatalf("3des cbc stream mismatch") + } +} + +func TestCBCInvalidIVLength(t *testing.T) { + _, err := EncryptAesCBC([]byte("a"), []byte("0123456789abcdef"), []byte("short"), PKCS7PADDING) + if err == nil { + t.Fatalf("expected invalid IV length error") + } +} + +func TestCBCInvalidCiphertextLength(t *testing.T) { + _, err := DecryptSM4CBC([]byte("short"), []byte("0123456789abcdef"), []byte("abcdef9876543210"), PKCS7PADDING) + if err == nil { + t.Fatalf("expected invalid ciphertext length error") + } +} + +func TestCBCStreamInvalidCiphertextLength(t *testing.T) { + err := DecryptAesCBCStream(&bytes.Buffer{}, bytes.NewReader([]byte("short")), []byte("0123456789abcdef"), []byte("abcdef9876543210"), PKCS7PADDING) + if err == nil { + t.Fatalf("expected invalid ciphertext length error") + } +} + +func TestSM4DerivedModesRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := []byte("sm4-derived-mode-roundtrip") + + ecbEnc, err := EncryptSM4ECB(plain, key, "") + if err != nil { + t.Fatalf("EncryptSM4ECB failed: %v", err) + } + ecbDec, err := DecryptSM4ECB(ecbEnc, key, "") + if err != nil { + t.Fatalf("DecryptSM4ECB failed: %v", err) + } + if !bytes.Equal(ecbDec, plain) { + t.Fatalf("sm4 ecb mismatch") + } + + ofbEnc, err := EncryptSM4OFB(plain, key, iv) + if err != nil { + t.Fatalf("EncryptSM4OFB failed: %v", err) + } + ofbDec, err := DecryptSM4OFB(ofbEnc, key, iv) + if err != nil { + t.Fatalf("DecryptSM4OFB failed: %v", err) + } + if !bytes.Equal(ofbDec, plain) { + t.Fatalf("sm4 ofb mismatch") + } + + ctrEnc, err := EncryptSM4CTR(plain, key, iv) + if err != nil { + t.Fatalf("EncryptSM4CTR failed: %v", err) + } + ctrDec, err := DecryptSM4CTR(ctrEnc, key, iv) + if err != nil { + t.Fatalf("DecryptSM4CTR failed: %v", err) + } + if !bytes.Equal(ctrDec, plain) { + t.Fatalf("sm4 ctr mismatch") + } +} + +func TestSM4DerivedStreamRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef") + iv := []byte("abcdef9876543210") + plain := []byte("sm4-derived-stream-roundtrip") + + ecbEnc := &bytes.Buffer{} + if err := EncryptSM4ECBStream(ecbEnc, bytes.NewReader(plain), key, ""); err != nil { + t.Fatalf("EncryptSM4ECBStream failed: %v", err) + } + ecbDec := &bytes.Buffer{} + if err := DecryptSM4ECBStream(ecbDec, bytes.NewReader(ecbEnc.Bytes()), key, ""); err != nil { + t.Fatalf("DecryptSM4ECBStream failed: %v", err) + } + if !bytes.Equal(ecbDec.Bytes(), plain) { + t.Fatalf("sm4 ecb stream mismatch") + } + + ofbEnc := &bytes.Buffer{} + if err := EncryptSM4OFBStream(ofbEnc, bytes.NewReader(plain), key, iv); err != nil { + t.Fatalf("EncryptSM4OFBStream failed: %v", err) + } + ofbDec := &bytes.Buffer{} + if err := DecryptSM4OFBStream(ofbDec, bytes.NewReader(ofbEnc.Bytes()), key, iv); err != nil { + t.Fatalf("DecryptSM4OFBStream failed: %v", err) + } + if !bytes.Equal(ofbDec.Bytes(), plain) { + t.Fatalf("sm4 ofb stream mismatch") + } + + ctrEnc := &bytes.Buffer{} + if err := EncryptSM4CTRStream(ctrEnc, bytes.NewReader(plain), key, iv); err != nil { + t.Fatalf("EncryptSM4CTRStream failed: %v", err) + } + ctrDec := &bytes.Buffer{} + if err := DecryptSM4CTRStream(ctrDec, bytes.NewReader(ctrEnc.Bytes()), key, iv); err != nil { + t.Fatalf("DecryptSM4CTRStream failed: %v", err) + } + if !bytes.Equal(ctrDec.Bytes(), plain) { + t.Fatalf("sm4 ctr stream mismatch") + } +} + +func TestChaCha20RoundTrip(t *testing.T) { + key := []byte("0123456789abcdef0123456789abcdef") + nonce := []byte("123456789012") + plain := []byte("chacha20-roundtrip") + + enc, err := EncryptChaCha20(plain, key, nonce) + if err != nil { + t.Fatalf("EncryptChaCha20 failed: %v", err) + } + dec, err := DecryptChaCha20(enc, key, nonce) + if err != nil { + t.Fatalf("DecryptChaCha20 failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("chacha20 mismatch") + } +} + +func TestChaCha20StreamRoundTrip(t *testing.T) { + key := []byte("0123456789abcdef0123456789abcdef") + nonce := []byte("123456789012") + plain := []byte("chacha20-stream-roundtrip") + + enc := &bytes.Buffer{} + if err := EncryptChaCha20Stream(enc, bytes.NewReader(plain), key, nonce); err != nil { + t.Fatalf("EncryptChaCha20Stream failed: %v", err) + } + dec := &bytes.Buffer{} + if err := DecryptChaCha20Stream(dec, bytes.NewReader(enc.Bytes()), key, nonce); err != nil { + t.Fatalf("DecryptChaCha20Stream failed: %v", err) + } + if !bytes.Equal(dec.Bytes(), plain) { + t.Fatalf("chacha20 stream mismatch") + } +} + +func TestChaCha20Poly1305RoundTrip(t *testing.T) { + key := []byte("0123456789abcdef0123456789abcdef") + nonce := []byte("123456789012") + aad := []byte("aad") + plain := []byte("chacha20-poly1305-roundtrip") + + enc, err := EncryptChaCha20Poly1305(plain, key, nonce, aad) + if err != nil { + t.Fatalf("EncryptChaCha20Poly1305 failed: %v", err) + } + dec, err := DecryptChaCha20Poly1305(enc, key, nonce, aad) + if err != nil { + t.Fatalf("DecryptChaCha20Poly1305 failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("chacha20-poly1305 mismatch") + } +} + +func TestChaCha20Poly1305InvalidNonce(t *testing.T) { + key := []byte("0123456789abcdef0123456789abcdef") + _, err := EncryptChaCha20Poly1305([]byte("x"), key, []byte("short"), nil) + if err == nil { + t.Fatalf("expected invalid nonce error") + } +} + +func TestAESGCMNISTVectorEmpty(t *testing.T) { + key := mustHex(t, "00000000000000000000000000000000") + nonce := mustHex(t, "000000000000000000000000") + enc, err := EncryptAesGCM(nil, key, nonce, nil) + if err != nil { + t.Fatalf("EncryptAesGCM failed: %v", err) + } + want := mustHex(t, "58e2fccefa7e3061367f1d57a4e7455a") + if !bytes.Equal(enc, want) { + t.Fatalf("AES-GCM empty vector mismatch: got %x want %x", enc, want) + } +} + +func TestAESGCMNISTVectorOneBlock(t *testing.T) { + key := mustHex(t, "00000000000000000000000000000000") + nonce := mustHex(t, "000000000000000000000000") + plain := mustHex(t, "00000000000000000000000000000000") + enc, err := EncryptAesGCM(plain, key, nonce, nil) + if err != nil { + t.Fatalf("EncryptAesGCM failed: %v", err) + } + want := mustHex(t, "0388dace60b6a392f328c2b971b2fe78ab6e47d42cec13bdf53a67b21257bddf") + if !bytes.Equal(enc, want) { + t.Fatalf("AES-GCM one-block vector mismatch: got %x want %x", enc, want) + } +} + +func TestSM4ECBStandardVector(t *testing.T) { + key := mustHex(t, "0123456789abcdeffedcba9876543210") + plain := mustHex(t, "0123456789abcdeffedcba9876543210") + enc, err := EncryptSM4ECB(plain, key, ZEROPADDING) + if err != nil { + t.Fatalf("EncryptSM4ECB failed: %v", err) + } + want := mustHex(t, "681edf34d206965e86b3e94f536e4246") + if !bytes.Equal(enc, want) { + t.Fatalf("SM4 ECB vector mismatch: got %x want %x", enc, want) + } +} + +func TestChaCha20Poly1305RFCVector(t *testing.T) { + key := mustHex(t, "808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9f") + nonce := mustHex(t, "070000004041424344454647") + enc, err := EncryptChaCha20Poly1305(nil, key, nonce, nil) + if err != nil { + t.Fatalf("EncryptChaCha20Poly1305 failed: %v", err) + } + want := mustHex(t, "a0784d7a4716f3feb4f64e7f4b39bf04") + if !bytes.Equal(enc, want) { + t.Fatalf("ChaCha20-Poly1305 vector mismatch: got %x want %x", enc, want) + } +} + +func TestAesOptionsDefaultToGCM(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + plain := []byte("aes-options-default-gcm") + + enc, err := EncryptAesWithOptions(plain, key, &CipherOptions{Nonce: nonce}) + if err != nil { + t.Fatalf("EncryptAesWithOptions failed: %v", err) + } + dec, err := DecryptAesWithOptions(enc, key, &CipherOptions{Nonce: nonce}) + if err != nil { + t.Fatalf("DecryptAesWithOptions failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("aes options default gcm mismatch") + } +} + +func TestSM4OptionsDefaultToGCM(t *testing.T) { + key := []byte("0123456789abcdef") + nonce := []byte("123456789012") + plain := []byte("sm4-options-default-gcm") + + enc, err := EncryptSM4WithOptions(plain, key, &CipherOptions{Nonce: nonce}) + if err != nil { + t.Fatalf("EncryptSM4WithOptions failed: %v", err) + } + dec, err := DecryptSM4WithOptions(enc, key, &CipherOptions{Nonce: nonce}) + if err != nil { + t.Fatalf("DecryptSM4WithOptions failed: %v", err) + } + if !bytes.Equal(dec, plain) { + t.Fatalf("sm4 options default gcm mismatch") + } +} + +func TestLargeStreamRoundTrip(t *testing.T) { + large := bytes.Repeat([]byte("starcrypto-large-stream-data-0123456789"), 180000) + + aesKey := []byte("0123456789abcdef") + aesIV := []byte("abcdef9876543210") + aesEnc := &bytes.Buffer{} + if err := EncryptAesCBCStream(aesEnc, bytes.NewReader(large), aesKey, aesIV, ""); err != nil { + t.Fatalf("EncryptAesCBCStream large failed: %v", err) + } + aesDec := &bytes.Buffer{} + if err := DecryptAesCBCStream(aesDec, bytes.NewReader(aesEnc.Bytes()), aesKey, aesIV, ""); err != nil { + t.Fatalf("DecryptAesCBCStream large failed: %v", err) + } + if !bytes.Equal(aesDec.Bytes(), large) { + t.Fatalf("aes large stream mismatch") + } + + chachaKey := []byte("0123456789abcdef0123456789abcdef") + chachaNonce := []byte("123456789012") + chachaEnc := &bytes.Buffer{} + if err := EncryptChaCha20Stream(chachaEnc, bytes.NewReader(large), chachaKey, chachaNonce); err != nil { + t.Fatalf("EncryptChaCha20Stream large failed: %v", err) + } + chachaDec := &bytes.Buffer{} + if err := DecryptChaCha20Stream(chachaDec, bytes.NewReader(chachaEnc.Bytes()), chachaKey, chachaNonce); err != nil { + t.Fatalf("DecryptChaCha20Stream large failed: %v", err) + } + if !bytes.Equal(chachaDec.Bytes(), large) { + t.Fatalf("chacha20 large stream mismatch") + } +} + +func mustHex(t *testing.T, s string) []byte { + t.Helper() + b, err := hex.DecodeString(s) + if err != nil { + t.Fatalf("DecodeString failed: %v", err) + } + return b +}