// File: starcrypto/symm/xts.go
// (264 lines, 8.3 KiB, Go)

package symm
import (
"crypto/aes"
"crypto/cipher"
"errors"
"io"
"github.com/emmansun/gmsm/sm4"
"golang.org/x/crypto/xts"
)
// xtsBlockSize is the block size, in bytes, that XTS processing here works in.
// Data unit sizes and data lengths are validated against it.
const xtsBlockSize = 16

// Sentinel errors returned by the XTS helpers in this package; compare with
// errors.Is.
var (
	// ErrInvalidXTSDataUnitSize reports a data unit size that is zero,
	// negative, or not a multiple of 16 bytes.
	ErrInvalidXTSDataUnitSize = errors.New("xts data unit size must be a positive multiple of 16")
	// ErrInvalidXTSDataLength reports input whose length is not a multiple
	// of 16 bytes.
	ErrInvalidXTSDataLength = errors.New("xts data length must be a multiple of 16")
	// ErrInvalidXTSKeyLength reports key1/key2 that are empty or differ in
	// length.
	ErrInvalidXTSKeyLength = errors.New("xts key lengths must be non-empty and equal")
	// ErrInvalidXTSMasterKeyLength reports a master key that is empty or has
	// an odd length, so it cannot be split into two equal halves.
	ErrInvalidXTSMasterKeyLength = errors.New("xts master key length must be non-empty and even")
	// ErrInvalidAESXTSMasterKeyLength reports an AES-XTS master key whose
	// length is not 32, 48, or 64 bytes.
	ErrInvalidAESXTSMasterKeyLength = errors.New("aes xts master key length must be 32, 48, or 64 bytes")
	// ErrInvalidSM4XTSMasterKeyLength reports an SM4-XTS master key whose
	// length is not 32 bytes.
	ErrInvalidSM4XTSMasterKeyLength = errors.New("sm4 xts master key length must be 32 bytes")
)
// xtsCipherFactory constructs an *xts.Cipher from a pair of equal-length keys.
// The factories in this file (newAesXTS, newSM4XTS) pass the pair through to
// newXTSCipher with a specific block-cipher constructor.
type xtsCipherFactory func(firstKey, secondKey []byte) (*xts.Cipher, error)
// combineXTSKeys returns a newly allocated slice holding key1 followed by
// key2, the layout xts.NewCipher expects for its combined key.
// Neither input is retained or modified.
//
// It returns ErrInvalidXTSKeyLength if key1 is empty or the two keys differ
// in length.
func combineXTSKeys(key1, key2 []byte) ([]byte, error) {
	n := len(key1)
	switch {
	case n == 0, n != len(key2):
		return nil, ErrInvalidXTSKeyLength
	}
	combined := make([]byte, 0, 2*n)
	combined = append(combined, key1...)
	return append(combined, key2...), nil
}
// splitXTSMasterKey divides masterKey into its first and second halves.
// Each half is returned as an independent copy, so later changes to
// masterKey do not affect the results.
//
// It returns ErrInvalidXTSMasterKeyLength when masterKey is empty or has an
// odd length.
func splitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
	size := len(masterKey)
	if size == 0 || size&1 == 1 {
		return nil, nil, ErrInvalidXTSMasterKeyLength
	}
	mid := size / 2
	first, second := make([]byte, mid), make([]byte, mid)
	copy(first, masterKey) // copies exactly mid bytes: len(first) == mid
	copy(second, masterKey[mid:])
	return first, second, nil
}
// SplitXTSMasterKey splits masterKey into two equal-length halves suitable
// for the key1/key2 parameters of the XTS functions in this package.
// The halves are fresh copies of the corresponding parts of masterKey.
//
// It returns ErrInvalidXTSMasterKeyLength if masterKey is empty or has an
// odd length. No cipher-specific length check is done; see
// SplitAesXTSMasterKey and SplitSM4XTSMasterKey for that.
func SplitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
	key1, key2, err := splitXTSMasterKey(masterKey)
	return key1, key2, err
}
// SplitAesXTSMasterKey splits AES-XTS master key and validates length (32/48/64 bytes).
func SplitAesXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
switch len(masterKey) {
case 32, 48, 64:
return splitXTSMasterKey(masterKey)
default:
return nil, nil, ErrInvalidAESXTSMasterKeyLength
}
}
// SplitSM4XTSMasterKey splits SM4-XTS master key and validates length (32 bytes).
func SplitSM4XTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
if len(masterKey) != 32 {
return nil, nil, ErrInvalidSM4XTSMasterKeyLength
}
return splitXTSMasterKey(masterKey)
}
// validateXTSDataUnitSize accepts only data unit sizes that are positive
// multiples of xtsBlockSize. Otherwise it returns ErrInvalidXTSDataUnitSize.
//
// NOTE(review): golang.org/x/crypto/xts documents that sectors must be less
// than 2^24 bytes; this check does not enforce that upper bound.
func validateXTSDataUnitSize(dataUnitSize int) error {
	if dataUnitSize > 0 && dataUnitSize%xtsBlockSize == 0 {
		return nil
	}
	return ErrInvalidXTSDataUnitSize
}
// validateXTSDataLength reports ErrInvalidXTSDataLength unless len(data) is a
// multiple of xtsBlockSize. An empty slice is accepted.
func validateXTSDataLength(data []byte) error {
	switch len(data) % xtsBlockSize {
	case 0:
		return nil
	default:
		return ErrInvalidXTSDataLength
	}
}
// cryptXTSAt encrypts (decrypt == false) or decrypts (decrypt == true) in
// with c and returns the result in a newly allocated slice; in is not
// modified.
//
// in is treated as consecutive data units of dataUnitSize bytes. The first
// unit uses sector number dataUnitIndex and each following unit the next
// number. The final unit may be shorter than dataUnitSize; it is still a
// multiple of 16 bytes because len(in) is validated.
//
// Errors: ErrInvalidXTSDataUnitSize or ErrInvalidXTSDataLength. Empty input
// yields a non-nil empty slice.
func cryptXTSAt(c *xts.Cipher, in []byte, dataUnitSize int, dataUnitIndex uint64, decrypt bool) ([]byte, error) {
	if err := validateXTSDataUnitSize(dataUnitSize); err != nil {
		return nil, err
	}
	if err := validateXTSDataLength(in); err != nil {
		return nil, err
	}
	total := len(in)
	if total == 0 {
		return []byte{}, nil
	}

	crypt := c.Encrypt
	if decrypt {
		crypt = c.Decrypt
	}

	result := make([]byte, total)
	for start, sector := 0, dataUnitIndex; start < total; sector++ {
		// Compare the remaining length first so start+dataUnitSize is only
		// computed when it is known to stay below total (no int overflow).
		end := total
		if total-start > dataUnitSize {
			end = start + dataUnitSize
		}
		crypt(result[start:end], in[start:end], sector)
		start = end
	}
	return result, nil
}
// cryptXTSStreamAt encrypts (decrypt == false) or decrypts (decrypt == true)
// everything read from src with c and writes the result to dst.
//
// The stream is cut into consecutive data units of dataUnitSize bytes. The
// first unit uses sector number dataUnitIndex and each following unit the
// next number. Complete units are processed and written as soon as they are
// buffered. If the stream ends with a shorter tail, that tail is processed as
// one final short data unit, the same as in cryptXTSAt. The tail must be a
// multiple of 16 bytes, otherwise ErrInvalidXTSDataLength is returned; output
// for earlier units has already been written to dst by then.
//
// Errors: ErrInvalidXTSDataUnitSize, the first non-EOF error from src (complete
// units from that last read are still written first), the first error from
// dst, or ErrInvalidXTSDataLength.
func cryptXTSStreamAt(dst io.Writer, src io.Reader, c *xts.Cipher, dataUnitSize int, dataUnitIndex uint64, decrypt bool) error {
	if err := validateXTSDataUnitSize(dataUnitSize); err != nil {
		return err
	}
	crypt := c.Encrypt
	if decrypt {
		crypt = c.Decrypt
	}

	buf := make([]byte, 32*1024)
	// pending holds bytes read but not yet processed; out is scratch space for
	// results. Both grow on demand and are reused across reads.
	// The previous eager make(..., dataUnitSize*2) could overflow int and panic
	// for large, otherwise valid, unit sizes. It also allocated that much even
	// for tiny inputs.
	var pending, out []byte
	unit := dataUnitIndex

	for {
		n, rerr := src.Read(buf)
		if n > 0 {
			pending = append(pending, buf[:n]...)
			if processLen := len(pending) / dataUnitSize * dataUnitSize; processLen > 0 {
				if cap(out) < processLen {
					out = make([]byte, processLen)
				}
				out = out[:processLen]
				for off := 0; off < processLen; off += dataUnitSize {
					crypt(out[off:off+dataUnitSize], pending[off:off+dataUnitSize], unit)
					unit++
				}
				// io.Writer implementations must not retain p, so out can be
				// reused after Write returns.
				if _, werr := dst.Write(out); werr != nil {
					return werr
				}
				// Move the incomplete tail to the front, keeping pending's storage.
				pending = pending[:copy(pending, pending[processLen:])]
			}
		}
		if rerr == io.EOF {
			break
		}
		if rerr != nil {
			return rerr
		}
	}

	if len(pending) == 0 {
		return nil
	}
	if err := validateXTSDataLength(pending); err != nil {
		return err
	}
	if cap(out) < len(pending) {
		out = make([]byte, len(pending))
	}
	out = out[:len(pending)]
	crypt(out, pending, unit)
	_, err := dst.Write(out)
	return err
}
// newXTSCipher builds an XTS cipher from two equal-length keys, using
// newBlock as the underlying block-cipher constructor.
//
// key1 and key2 are concatenated (key1 first) and passed to xts.NewCipher,
// which splits the combined key back into halves and builds one block cipher
// from each.
//
// It returns ErrInvalidXTSKeyLength when the keys are empty or differ in
// length, or the error from xts.NewCipher, for example a key size newBlock
// rejects. On any error the returned *xts.Cipher is nil. xts.NewCipher may hand
// back a partially initialised, non-nil value together with its error, and
// that value must not leak to callers.
func newXTSCipher(newBlock func([]byte) (cipher.Block, error), key1, key2 []byte) (*xts.Cipher, error) {
	key, err := combineXTSKeys(key1, key2)
	if err != nil {
		return nil, err
	}
	c, err := xts.NewCipher(newBlock, key)
	if err != nil {
		return nil, err
	}
	return c, nil
}
// newAesXTS is the xtsCipherFactory for AES-XTS. key1 and key2 must have the
// same length and be valid AES key sizes (16, 24, or 32 bytes) for
// aes.NewCipher.
func newAesXTS(key1, key2 []byte) (*xts.Cipher, error) {
	blockCtor := aes.NewCipher
	return newXTSCipher(blockCtor, key1, key2)
}
// newSM4XTS is the xtsCipherFactory for SM4-XTS. key1 and key2 must have the
// same length and be accepted by sm4.NewCipher (SM4 uses 128-bit, i.e.
// 16-byte, keys).
func newSM4XTS(key1, key2 []byte) (*xts.Cipher, error) {
	blockCtor := sm4.NewCipher
	return newXTSCipher(blockCtor, key1, key2)
}
// cryptXTSAtWithFactory builds a cipher from key1/key2 with factory and then
// runs cryptXTSAt over in. Errors from the factory are returned unchanged,
// with a nil result.
func cryptXTSAtWithFactory(factory xtsCipherFactory, in []byte, dataUnitSize int, dataUnitIndex uint64, decrypt bool, key1, key2 []byte) ([]byte, error) {
	xc, err := factory(key1, key2)
	if err == nil {
		return cryptXTSAt(xc, in, dataUnitSize, dataUnitIndex, decrypt)
	}
	return nil, err
}
// cryptXTSStreamAtWithFactory builds a cipher from key1/key2 with factory and
// then streams src to dst through cryptXTSStreamAt. Errors from the factory
// are returned before anything is read or written.
func cryptXTSStreamAtWithFactory(factory xtsCipherFactory, dst io.Writer, src io.Reader, dataUnitSize int, dataUnitIndex uint64, decrypt bool, key1, key2 []byte) error {
	xc, err := factory(key1, key2)
	if err == nil {
		return cryptXTSStreamAt(dst, src, xc, dataUnitSize, dataUnitIndex, decrypt)
	}
	return err
}
// EncryptAesXTS encrypts plain with AES-XTS, numbering data units from 0.
// It is EncryptAesXTSAt with dataUnitIndex 0.
func EncryptAesXTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
	const firstDataUnit uint64 = 0
	return EncryptAesXTSAt(plain, key1, key2, dataUnitSize, firstDataUnit)
}
// DecryptAesXTS decrypts ciphertext with AES-XTS, numbering data units from 0.
// It is DecryptAesXTSAt with dataUnitIndex 0.
func DecryptAesXTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
	const firstDataUnit uint64 = 0
	return DecryptAesXTSAt(ciphertext, key1, key2, dataUnitSize, firstDataUnit)
}
// EncryptAesXTSAt encrypts plain with AES-XTS. plain is split into data units
// of dataUnitSize bytes, the first of which uses sector number dataUnitIndex.
// key1 and key2 must be equal-length AES keys, and len(plain) must be a
// multiple of 16.
func EncryptAesXTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
	const decrypt = false
	return cryptXTSAtWithFactory(newAesXTS, plain, dataUnitSize, dataUnitIndex, decrypt, key1, key2)
}
// DecryptAesXTSAt reverses EncryptAesXTSAt. The same keys, dataUnitSize and
// dataUnitIndex must be used as for encryption.
func DecryptAesXTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
	const decrypt = true
	return cryptXTSAtWithFactory(newAesXTS, ciphertext, dataUnitSize, dataUnitIndex, decrypt, key1, key2)
}
// EncryptAesXTSStream encrypts src into dst with AES-XTS, numbering data units
// from 0. It is EncryptAesXTSStreamAt with dataUnitIndex 0.
func EncryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
	const firstDataUnit uint64 = 0
	return EncryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, firstDataUnit)
}
// DecryptAesXTSStream decrypts src into dst with AES-XTS, numbering data units
// from 0. It is DecryptAesXTSStreamAt with dataUnitIndex 0.
func DecryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
	const firstDataUnit uint64 = 0
	return DecryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, firstDataUnit)
}
// EncryptAesXTSStreamAt streams src through AES-XTS encryption into dst.
// Data units are dataUnitSize bytes, numbered from dataUnitIndex. A trailing
// short unit is allowed if it is a multiple of 16 bytes. Output for earlier
// units may already be written when an error is returned.
func EncryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
	const decrypt = false
	return cryptXTSStreamAtWithFactory(newAesXTS, dst, src, dataUnitSize, dataUnitIndex, decrypt, key1, key2)
}
// DecryptAesXTSStreamAt streams src through AES-XTS decryption into dst. It
// reverses EncryptAesXTSStreamAt when given the same keys, dataUnitSize and
// dataUnitIndex.
func DecryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
	const decrypt = true
	return cryptXTSStreamAtWithFactory(newAesXTS, dst, src, dataUnitSize, dataUnitIndex, decrypt, key1, key2)
}
// EncryptSM4XTS encrypts plain with SM4-XTS, numbering data units from 0.
// It is EncryptSM4XTSAt with dataUnitIndex 0.
func EncryptSM4XTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
	const firstDataUnit uint64 = 0
	return EncryptSM4XTSAt(plain, key1, key2, dataUnitSize, firstDataUnit)
}
// DecryptSM4XTS decrypts ciphertext with SM4-XTS, numbering data units from 0.
// It is DecryptSM4XTSAt with dataUnitIndex 0.
func DecryptSM4XTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) {
	const firstDataUnit uint64 = 0
	return DecryptSM4XTSAt(ciphertext, key1, key2, dataUnitSize, firstDataUnit)
}
// EncryptSM4XTSAt encrypts plain with SM4-XTS. plain is split into data units
// of dataUnitSize bytes, the first of which uses sector number dataUnitIndex.
// key1 and key2 must be equal-length SM4 keys, and len(plain) must be a
// multiple of 16.
func EncryptSM4XTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
	const decrypt = false
	return cryptXTSAtWithFactory(newSM4XTS, plain, dataUnitSize, dataUnitIndex, decrypt, key1, key2)
}
// DecryptSM4XTSAt reverses EncryptSM4XTSAt. The same keys, dataUnitSize and
// dataUnitIndex must be used as for encryption.
func DecryptSM4XTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) {
	const decrypt = true
	return cryptXTSAtWithFactory(newSM4XTS, ciphertext, dataUnitSize, dataUnitIndex, decrypt, key1, key2)
}
// EncryptSM4XTSStream encrypts src into dst with SM4-XTS, numbering data units
// from 0. It is EncryptSM4XTSStreamAt with dataUnitIndex 0.
func EncryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
	const firstDataUnit uint64 = 0
	return EncryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, firstDataUnit)
}
// DecryptSM4XTSStream decrypts src into dst with SM4-XTS, numbering data units
// from 0. It is DecryptSM4XTSStreamAt with dataUnitIndex 0.
func DecryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error {
	const firstDataUnit uint64 = 0
	return DecryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, firstDataUnit)
}
// EncryptSM4XTSStreamAt streams src through SM4-XTS encryption into dst.
// Data units are dataUnitSize bytes, numbered from dataUnitIndex. A trailing
// short unit is allowed if it is a multiple of 16 bytes. Output for earlier
// units may already be written when an error is returned.
func EncryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
	const decrypt = false
	return cryptXTSStreamAtWithFactory(newSM4XTS, dst, src, dataUnitSize, dataUnitIndex, decrypt, key1, key2)
}
// DecryptSM4XTSStreamAt streams src through SM4-XTS decryption into dst. It
// reverses EncryptSM4XTSStreamAt when given the same keys, dataUnitSize and
// dataUnitIndex.
func DecryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error {
	const decrypt = true
	return cryptXTSStreamAtWithFactory(newSM4XTS, dst, src, dataUnitSize, dataUnitIndex, decrypt, key1, key2)
}