package symm

import (
	"crypto/aes"
	"crypto/cipher"
	"errors"
	"io"

	"github.com/emmansun/gmsm/sm4"
	"golang.org/x/crypto/xts"
)

// xtsBlockSize is the 16-byte block size XTS operates on; data units and
// data lengths are validated against it.
const xtsBlockSize = 16

// Sentinel errors returned by the XTS key-splitting, key-combining and
// encryption/decryption helpers in this package.
var (
	ErrInvalidXTSDataUnitSize       = errors.New("xts data unit size must be a positive multiple of 16")
	ErrInvalidXTSDataLength         = errors.New("xts data length must be a multiple of 16")
	ErrInvalidXTSKeyLength          = errors.New("xts key lengths must be non-empty and equal")
	ErrInvalidXTSMasterKeyLength    = errors.New("xts master key length must be non-empty and even")
	ErrInvalidAESXTSMasterKeyLength = errors.New("aes xts master key length must be 32, 48, or 64 bytes")
	ErrInvalidSM4XTSMasterKeyLength = errors.New("sm4 xts master key length must be 32 bytes")
)

// xtsCipherFactory builds an XTS cipher from a pair of equal-length keys.
type xtsCipherFactory func(key1, key2 []byte) (*xts.Cipher, error)

// combineXTSKeys returns a newly allocated slice containing key1 followed by
// key2. It fails with ErrInvalidXTSKeyLength when the keys are empty or their
// lengths differ.
func combineXTSKeys(key1, key2 []byte) ([]byte, error) {
	if len(key1) != len(key2) || len(key1) == 0 {
		return nil, ErrInvalidXTSKeyLength
	}
	combined := make([]byte, 0, len(key1)+len(key2))
	combined = append(combined, key1...)
	return append(combined, key2...), nil
}

// splitXTSMasterKey returns independent copies of the first and second halves
// of masterKey. It fails with ErrInvalidXTSMasterKeyLength when masterKey is
// empty or has an odd length.
func splitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
	size := len(masterKey)
	if size == 0 || size%2 == 1 {
		return nil, nil, ErrInvalidXTSMasterKeyLength
	}
	mid := size / 2
	first := append(make([]byte, 0, mid), masterKey[:mid]...)
	second := append(make([]byte, 0, mid), masterKey[mid:]...)
	return first, second, nil
}

// SplitXTSMasterKey splits masterKey into two equal-length halves, returned as
// copies. It fails with ErrInvalidXTSMasterKeyLength when masterKey is empty or
// has an odd length.
func SplitXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
	return splitXTSMasterKey(masterKey)
}

// SplitAesXTSMasterKey splits a 32-, 48- or 64-byte AES-XTS master key into two
// halves of 16, 24 or 32 bytes. Any other length yields
// ErrInvalidAESXTSMasterKeyLength.
func SplitAesXTSMasterKey(masterKey []byte) ([]byte, []byte, error) {
	if n := len(masterKey); n != 32 && n != 48 && n != 64 {
		return nil, nil, ErrInvalidAESXTSMasterKeyLength
	}
	return splitXTSMasterKey(masterKey)
}

// SplitSM4XTSMasterKey splits a 32-byte SM4-XTS master key into two 16-byte
// halves. Any other length yields ErrInvalidSM4XTSMasterKeyLength.
func SplitSM4XTSMasterKey(masterKey []byte) ([]byte, []byte, error) { if len(masterKey) != 32 { return nil, nil, ErrInvalidSM4XTSMasterKeyLength } return splitXTSMasterKey(masterKey) } func validateXTSDataUnitSize(dataUnitSize int) error { if dataUnitSize <= 0 || dataUnitSize%xtsBlockSize != 0 { return ErrInvalidXTSDataUnitSize } return nil } func validateXTSDataLength(data []byte) error { if len(data)%xtsBlockSize != 0 { return ErrInvalidXTSDataLength } return nil } func cryptXTSAt(c *xts.Cipher, in []byte, dataUnitSize int, dataUnitIndex uint64, decrypt bool) ([]byte, error) { if err := validateXTSDataUnitSize(dataUnitSize); err != nil { return nil, err } if err := validateXTSDataLength(in); err != nil { return nil, err } if len(in) == 0 { return []byte{}, nil } out := make([]byte, len(in)) off := 0 unit := dataUnitIndex for off < len(in) { chunkLen := dataUnitSize if remain := len(in) - off; remain < chunkLen { chunkLen = remain } if decrypt { c.Decrypt(out[off:off+chunkLen], in[off:off+chunkLen], unit) } else { c.Encrypt(out[off:off+chunkLen], in[off:off+chunkLen], unit) } off += chunkLen unit++ } return out, nil } func cryptXTSStreamAt(dst io.Writer, src io.Reader, c *xts.Cipher, dataUnitSize int, dataUnitIndex uint64, decrypt bool) error { if err := validateXTSDataUnitSize(dataUnitSize); err != nil { return err } buf := make([]byte, 32*1024) pending := make([]byte, 0, dataUnitSize*2) unit := dataUnitIndex for { n, err := src.Read(buf) if n > 0 { pending = append(pending, buf[:n]...) processLen := len(pending) / dataUnitSize * dataUnitSize if processLen > 0 { out := make([]byte, processLen) for off := 0; off < processLen; off += dataUnitSize { if decrypt { c.Decrypt(out[off:off+dataUnitSize], pending[off:off+dataUnitSize], unit) } else { c.Encrypt(out[off:off+dataUnitSize], pending[off:off+dataUnitSize], unit) } unit++ } if _, werr := dst.Write(out); werr != nil { return werr } pending = append([]byte(nil), pending[processLen:]...) 
} } if err != nil { if err == io.EOF { break } return err } } if len(pending) == 0 { return nil } if err := validateXTSDataLength(pending); err != nil { return err } out := make([]byte, len(pending)) if decrypt { c.Decrypt(out, pending, unit) } else { c.Encrypt(out, pending, unit) } _, err := dst.Write(out) return err } func newXTSCipher(newBlock func([]byte) (cipher.Block, error), key1, key2 []byte) (*xts.Cipher, error) { key, err := combineXTSKeys(key1, key2) if err != nil { return nil, err } return xts.NewCipher(newBlock, key) } func newAesXTS(key1, key2 []byte) (*xts.Cipher, error) { return newXTSCipher(aes.NewCipher, key1, key2) } func newSM4XTS(key1, key2 []byte) (*xts.Cipher, error) { return newXTSCipher(sm4.NewCipher, key1, key2) } func cryptXTSAtWithFactory(factory xtsCipherFactory, in []byte, dataUnitSize int, dataUnitIndex uint64, decrypt bool, key1, key2 []byte) ([]byte, error) { c, err := factory(key1, key2) if err != nil { return nil, err } return cryptXTSAt(c, in, dataUnitSize, dataUnitIndex, decrypt) } func cryptXTSStreamAtWithFactory(factory xtsCipherFactory, dst io.Writer, src io.Reader, dataUnitSize int, dataUnitIndex uint64, decrypt bool, key1, key2 []byte) error { c, err := factory(key1, key2) if err != nil { return err } return cryptXTSStreamAt(dst, src, c, dataUnitSize, dataUnitIndex, decrypt) } func EncryptAesXTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) { return EncryptAesXTSAt(plain, key1, key2, dataUnitSize, 0) } func DecryptAesXTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) { return DecryptAesXTSAt(ciphertext, key1, key2, dataUnitSize, 0) } func EncryptAesXTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) { return cryptXTSAtWithFactory(newAesXTS, plain, dataUnitSize, dataUnitIndex, false, key1, key2) } func DecryptAesXTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) { return cryptXTSAtWithFactory(newAesXTS, ciphertext, 
dataUnitSize, dataUnitIndex, true, key1, key2) } func EncryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error { return EncryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, 0) } func DecryptAesXTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error { return DecryptAesXTSStreamAt(dst, src, key1, key2, dataUnitSize, 0) } func EncryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error { return cryptXTSStreamAtWithFactory(newAesXTS, dst, src, dataUnitSize, dataUnitIndex, false, key1, key2) } func DecryptAesXTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error { return cryptXTSStreamAtWithFactory(newAesXTS, dst, src, dataUnitSize, dataUnitIndex, true, key1, key2) } func EncryptSM4XTS(plain, key1, key2 []byte, dataUnitSize int) ([]byte, error) { return EncryptSM4XTSAt(plain, key1, key2, dataUnitSize, 0) } func DecryptSM4XTS(ciphertext, key1, key2 []byte, dataUnitSize int) ([]byte, error) { return DecryptSM4XTSAt(ciphertext, key1, key2, dataUnitSize, 0) } func EncryptSM4XTSAt(plain, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) { return cryptXTSAtWithFactory(newSM4XTS, plain, dataUnitSize, dataUnitIndex, false, key1, key2) } func DecryptSM4XTSAt(ciphertext, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) ([]byte, error) { return cryptXTSAtWithFactory(newSM4XTS, ciphertext, dataUnitSize, dataUnitIndex, true, key1, key2) } func EncryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error { return EncryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, 0) } func DecryptSM4XTSStream(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int) error { return DecryptSM4XTSStreamAt(dst, src, key1, key2, dataUnitSize, 0) } func EncryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, 
dataUnitIndex uint64) error { return cryptXTSStreamAtWithFactory(newSM4XTS, dst, src, dataUnitSize, dataUnitIndex, false, key1, key2) } func DecryptSM4XTSStreamAt(dst io.Writer, src io.Reader, key1, key2 []byte, dataUnitSize int, dataUnitIndex uint64) error { return cryptXTSStreamAtWithFactory(newSM4XTS, dst, src, dataUnitSize, dataUnitIndex, true, key1, key2) }