264 lines
8.3 KiB
Go
264 lines
8.3 KiB
Go
|
|
package symm
|
||
|
|
|
||
|
|
import (
|
||
|
|
"crypto/aes"
|
||
|
|
"crypto/cipher"
|
||
|
|
"errors"
|
||
|
|
"io"
|
||
|
|
|
||
|
|
"github.com/emmansun/gmsm/sm4"
|
||
|
|
"golang.org/x/crypto/xts"
|
||
|
|
)
|
||
|
|
|
||
|
|
const (
	// xtsBlockSize is the granularity, in bytes, that XTS data lengths and
	// data unit sizes are validated against in this file.
	xtsBlockSize = 16
)
|
||
|
|
|
||
|
|
// Sentinel errors returned by the XTS helpers in this package.
var (
	// ErrInvalidXTSDataUnitSize is returned when a data unit size is zero,
	// negative, or not a multiple of 16.
	ErrInvalidXTSDataUnitSize error = errors.New("xts data unit size must be a positive multiple of 16")

	// ErrInvalidXTSDataLength is returned when the input length is not a
	// multiple of 16.
	ErrInvalidXTSDataLength error = errors.New("xts data length must be a multiple of 16")

	// ErrInvalidXTSKeyLength is returned when key1 is empty or key1 and key2
	// differ in length.
	ErrInvalidXTSKeyLength error = errors.New("xts key lengths must be non-empty and equal")

	// ErrInvalidXTSMasterKeyLength is returned when a master key is empty or
	// has an odd length.
	ErrInvalidXTSMasterKeyLength error = errors.New("xts master key length must be non-empty and even")

	// ErrInvalidAESXTSMasterKeyLength is returned when an AES-XTS master key
	// is not 32, 48, or 64 bytes long.
	ErrInvalidAESXTSMasterKeyLength error = errors.New("aes xts master key length must be 32, 48, or 64 bytes")

	// ErrInvalidSM4XTSMasterKeyLength is returned when an SM4-XTS master key
	// is not exactly 32 bytes long.
	ErrInvalidSM4XTSMasterKeyLength error = errors.New("sm4 xts master key length must be 32 bytes")
)
|
||
|
|
|
||
|
|
// xtsCipherFactory constructs an XTS cipher from two key halves, passed in
// the order they are combined (first, then second).
type xtsCipherFactory func(first []byte, second []byte) (*xts.Cipher, error)
|
||
|
|
|
||
|
|
func combineXTSKeys(key1, key2 []byte) ([]byte, error) {
|
||
|
|
if len(key1) == 0 || len(key1) != len(key2) {
|
||
|
|
return nil, ErrInvalidXTSKeyLength
|
||
|
|
}
|
||
|
|
out := make([]byte, len(key1)+len(key2))
|
||
|
|
copy(out, key1)
|
||
|
|
copy(out[len(key1):], key2)
|
||
|
|
return out, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func splitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
|
||
|
|
if len(masterKey) == 0 || len(masterKey)%2 != 0 {
|
||
|
|
return nil, nil, ErrInvalidXTSMasterKeyLength
|
||
|
|
}
|
||
|
|
half := len(masterKey) / 2
|
||
|
|
k1 := make([]byte, half)
|
||
|
|
k2 := make([]byte, half)
|
||
|
|
copy(k1, masterKey[:half])
|
||
|
|
copy(k2, masterKey[half:])
|
||
|
|
return k1, k2, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// SplitXTSMasterKey splits a master key into two equal XTS keys.
|
||
|
|
func SplitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
|
||
|
|
return splitXTSMasterKey(masterKey)
|
||
|
|
}
|
||
|
|
|
||
|
|
// SplitAesXTSMasterKey splits AES-XTS master key and validates length (32/48/64 bytes).
|
||
|
|
func SplitAesXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
|
||
|
|
switch len(masterKey) {
|
||
|
|
case 32, 48, 64:
|
||
|
|
return splitXTSMasterKey(masterKey)
|
||
|
|
default:
|
||
|
|
return nil, nil, ErrInvalidAESXTSMasterKeyLength
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// SplitSM4XTSMasterKey splits SM4-XTS master key and validates length (32 bytes).
|
||
|
|
func SplitSM4XTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
|
||
|
|
if len(masterKey) != 32 {
|
||
|
|
return nil, nil, ErrInvalidSM4XTSMasterKeyLength
|
||
|
|
}
|
||
|
|
return splitXTSMasterKey(masterKey)
|
||
|
|
}
|
||
|
|
|
||
|
|
func validateXTSDataUnitSize(dataUnitSize int) error {
|
||
|
|
if dataUnitSize <= 0 || dataUnitSize%xtsBlockSize != 0 {
|
||
|
|
return ErrInvalidXTSDataUnitSize
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func validateXTSDataLength(data []byte) error {
|
||
|
|
if len(data)%xtsBlockSize != 0 {
|
||
|
|
return ErrInvalidXTSDataLength
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func cryptXTSAt(c *xts.Cipher, in []byte, dataUnitSize int, dataUnitIndex uint64, decrypt bool) ([]byte, error) {
|
||
|
|
if err := validateXTSDataUnitSize(dataUnitSize); err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
if err := validateXTSDataLength(in); err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
if len(in) == 0 {
|
||
|
|
return []byte{}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
out := make([]byte, len(in))
|
||
|
|
off := 0
|
||
|
|
unit := dataUnitIndex
|
||
|
|
for off < len(in) {
|
||
|
|
chunkLen := dataUnitSize
|
||
|
|
if remain := len(in) - off; remain < chunkLen {
|
||
|
|
chunkLen = remain
|
||
|
|
}
|
||
|
|
if decrypt {
|
||
|
|
c.Decrypt(out[off:off+chunkLen], in[off:off+chunkLen], unit)
|
||
|
|
} else {
|
||
|
|
c.Encrypt(out[off:off+chunkLen], in[off:off+chunkLen], unit)
|
||
|
|
}
|
||
|
|
off += chunkLen
|
||
|
|
unit++
|
||
|
|
}
|
||
|
|
return out, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func cryptXTSStreamAt(dst io.Writer, src io.Reader, c *xts.Cipher, dataUnitSize int, dataUnitIndex uint64, decrypt bool) error {
|
||
|
|
if err := validateXTSDataUnitSize(dataUnitSize); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
buf := make([]byte, 32*1024)
|
||
|
|
pending := make([]byte, 0, dataUnitSize*2)
|
||
|
|
unit := dataUnitIndex
|
||
|
|
|
||
|
|
for {
|
||
|
|
n, err := src.Read(buf)
|
||
|
|
if n > 0 {
|
||
|
|
pending = append(pending, buf[:n]...)
|
||
|
|
processLen := len(pending) / dataUnitSize * dataUnitSize
|
||
|
|
if processLen > 0 {
|
||
|
|
out := make([]byte, processLen)
|
||
|
|
for off := 0; off < processLen; off += dataUnitSize {
|
||
|
|
if decrypt {
|
||
|
|
c.Decrypt(out[off:off+dataUnitSize], pending[off:off+dataUnitSize], unit)
|
||
|
|
} else {
|
||
|
|
c.Encrypt(out[off:off+dataUnitSize], pending[off:off+dataUnitSize], unit)
|
||
|
|
}
|
||
|
|
unit++
|
||
|
|
}
|
||
|
|
if _, werr := dst.Write(out); werr != nil {
|
||
|
|
return werr
|
||
|
|
}
|
||
|
|
pending = append([]byte(nil), pending[processLen:]...)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
if err == io.EOF {
|
||
|
|
break
|
||
|
|
}
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if len(pending) == 0 {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
if err := validateXTSDataLength(pending); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
out := make([]byte, len(pending))
|
||
|
|
if decrypt {
|
||
|
|
c.Decrypt(out, pending, unit)
|
||
|
|
} else {
|
||
|
|
c.Encrypt(out, pending, unit)
|
||
|
|
}
|
||
|
|
_, err := dst.Write(out)
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
func newXTSCipher(newBlock func([]byte) (cipher.Block, error), key1, key2 []byte) (*xts.Cipher, error) {
|
||
|
|
key, err := combineXTSKeys(key1, key2)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
return xts.NewCipher(newBlock, key)
|
||
|
|
}
|
||
|
|
|
||
|
|
func newAesXTS(key1, key2 []byte) (*xts.Cipher, error) {
|
||
|
|
return newXTSCipher(aes.NewCipher, key1, key2)
|
||
|
|
}
|
||
|
|
|
||
|
|
func newSM4XTS(key1, key2 []byte) (*xts.Cipher, error) {
|
||
|
|
return newXTSCipher(sm4.NewCipher, key1, key2)
|
||
|
|
}
|
||
|
|
|
||
|
|
func cryptXTSAtWithFactory(factory xtsCipherFactory, in []byte, dataUnitSize int, dataUnitIndex uint64, decrypt bool, key1, key2 []byte) ([]byte, error) {
|
||
|
|
c, err := factory(key1, key2)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
return cryptXTSAt(c, in, dataUnitSize, dataUnitIndex, decrypt)
|
||
|
|
}
|
||
|
|
|
||
|
|
func cryptXTSStreamAtWithFactory(factory xtsCipherFactory, dst io.Writer, src io.Reader, dataUnitSize int, dataUnitIndex uint64, decrypt bool, key1, key2 []byte) error {
|
||
|
|
c, err := factory(key1, key2)
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
return cryptXTSStreamAt(dst, src, c, dataUnitSize, dataUnitIndex, decrypt)
|
||
|
|
}
|
||
|
|
|
||
|
|
func EncryptAesXTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
|
||
|
|
return EncryptAesXTSAt(plain, key1, key2, dataUnitSize, 0)
|
||
|
|
}
|
||
|
|
|
||
|
|
func DecryptAesXTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
|
||
|
|
return DecryptAesXTSAt(ciphertext, key1, key2, dataUnitSize, 0)
|
||
|
|
}
|
||
|
|
|
||
|
|
func EncryptAesXTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
|
||
|
|
return cryptXTSAtWithFactory(newAesXTS, plain, dataUnitSize, dataUnitIndex, false, key1, key2)
|
||
|
|
}
|
||
|
|
|
||
|
|
func DecryptAesXTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
|
||
|
|
return cryptXTSAtWithFactory(newAesXTS, ciphertext, dataUnitSize, dataUnitIndex, true, key1, key2)
|
||
|
|
}
|
||
|
|
|
||
|
|
func EncryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
|
||
|
|
return EncryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
|
||
|
|
}
|
||
|
|
|
||
|
|
func DecryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
|
||
|
|
return DecryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
|
||
|
|
}
|
||
|
|
|
||
|
|
func EncryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
|
||
|
|
return cryptXTSStreamAtWithFactory(newAesXTS, dst, src, dataUnitSize, dataUnitIndex, false, key1, key2)
|
||
|
|
}
|
||
|
|
|
||
|
|
func DecryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
|
||
|
|
return cryptXTSStreamAtWithFactory(newAesXTS, dst, src, dataUnitSize, dataUnitIndex, true, key1, key2)
|
||
|
|
}
|
||
|
|
|
||
|
|
func EncryptSM4XTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
|
||
|
|
return EncryptSM4XTSAt(plain, key1, key2, dataUnitSize, 0)
|
||
|
|
}
|
||
|
|
|
||
|
|
func DecryptSM4XTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
|
||
|
|
return DecryptSM4XTSAt(ciphertext, key1, key2, dataUnitSize, 0)
|
||
|
|
}
|
||
|
|
|
||
|
|
func EncryptSM4XTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
|
||
|
|
return cryptXTSAtWithFactory(newSM4XTS, plain, dataUnitSize, dataUnitIndex, false, key1, key2)
|
||
|
|
}
|
||
|
|
|
||
|
|
func DecryptSM4XTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
|
||
|
|
return cryptXTSAtWithFactory(newSM4XTS, ciphertext, dataUnitSize, dataUnitIndex, true, key1, key2)
|
||
|
|
}
|
||
|
|
|
||
|
|
func EncryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
|
||
|
|
return EncryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
|
||
|
|
}
|
||
|
|
|
||
|
|
func DecryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
|
||
|
|
return DecryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, 0)
|
||
|
|
}
|
||
|
|
|
||
|
|
func EncryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
|
||
|
|
return cryptXTSStreamAtWithFactory(newSM4XTS, dst, src, dataUnitSize, dataUnitIndex, false, key1, key2)
|
||
|
|
}
|
||
|
|
|
||
|
|
func DecryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
|
||
|
|
return cryptXTSStreamAtWithFactory(newSM4XTS, dst, src, dataUnitSize, dataUnitIndex, true, key1, key2)
|
||
|
|
}
|