package zstd

/*
#include "zstd.h"

typedef struct compressStream2_result_s {
	size_t return_code;
	size_t bytes_consumed;
	size_t bytes_written;
} compressStream2_result;

static void ZSTD_compressStream2_wrapper(compressStream2_result* result, ZSTD_CCtx* ctx,
		void* dst, size_t maxDstSize, const void* src, size_t srcSize) {
	ZSTD_outBuffer outBuffer = { dst, maxDstSize, 0 };
	ZSTD_inBuffer inBuffer = { src, srcSize, 0 };
	size_t retCode = ZSTD_compressStream2(ctx, &outBuffer, &inBuffer, ZSTD_e_continue);

	result->return_code = retCode;
	result->bytes_consumed = inBuffer.pos;
	result->bytes_written = outBuffer.pos;
}

static void ZSTD_compressStream2_flush(compressStream2_result* result, ZSTD_CCtx* ctx,
		void* dst, size_t maxDstSize, const void* src, size_t srcSize) {
	ZSTD_outBuffer outBuffer = { dst, maxDstSize, 0 };
	ZSTD_inBuffer inBuffer = { src, srcSize, 0 };
	size_t retCode = ZSTD_compressStream2(ctx, &outBuffer, &inBuffer, ZSTD_e_flush);

	result->return_code = retCode;
	result->bytes_consumed = inBuffer.pos;
	result->bytes_written = outBuffer.pos;
}

static void ZSTD_compressStream2_finish(compressStream2_result* result, ZSTD_CCtx* ctx,
		void* dst, size_t maxDstSize, const void* src, size_t srcSize) {
	ZSTD_outBuffer outBuffer = { dst, maxDstSize, 0 };
	ZSTD_inBuffer inBuffer = { src, srcSize, 0 };
	size_t retCode = ZSTD_compressStream2(ctx, &outBuffer, &inBuffer, ZSTD_e_end);

	result->return_code = retCode;
	result->bytes_consumed = inBuffer.pos;
	result->bytes_written = outBuffer.pos;
}

// decompressStream2_result is identical to compressStream2_result, but is kept
// as a separate struct so the two can evolve independently.
typedef struct decompressStream2_result_s {
	size_t return_code;
	size_t bytes_consumed;
	size_t bytes_written;
} decompressStream2_result;

static void ZSTD_decompressStream_wrapper(decompressStream2_result* result, ZSTD_DCtx* ctx,
		void* dst, size_t maxDstSize, const void* src, size_t srcSize) {
	ZSTD_outBuffer outBuffer = { dst, maxDstSize, 0 };
	ZSTD_inBuffer inBuffer = { src, srcSize, 0 };
	size_t retCode = ZSTD_decompressStream(ctx, &outBuffer, &inBuffer);

	result->return_code = retCode;
	result->bytes_consumed = inBuffer.pos;
	result->bytes_written = outBuffer.pos;
}
*/
import "C"
import (
	"errors"
	"fmt"
	"io"
	"runtime"
	"sync"
	"unsafe"
)

var errShortRead = errors.New("short read")
var errReaderClosed = errors.New("Reader is closed")

// Writer is an io.WriteCloser that zstd-compresses its input.
type Writer struct {
	CompressionLevel int

	ctx              *C.ZSTD_CCtx
	dict             []byte
	srcBuffer        []byte
	dstBuffer        []byte
	firstError       error
	underlyingWriter io.Writer
	resultBuffer     *C.compressStream2_result
}

// resize returns a slice of length newSize, reusing in's backing array when
// its capacity is sufficient and allocating a larger one otherwise.
func resize(in []byte, newSize int) []byte {
	if in == nil {
		return make([]byte, newSize)
	}
	if newSize <= cap(in) {
		return in[:newSize]
	}
	toAdd := newSize - len(in)
	return append(in, make([]byte, toAdd)...)
}

// NewWriter creates a new Writer with default compression options. Writes to
// the writer will be written in compressed form to w.
func NewWriter(w io.Writer) *Writer {
	return NewWriterLevelDict(w, DefaultCompression, nil)
}
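
// A minimal usage sketch for the compression side (illustrative only: `out`
// and `data` are placeholder variables, and real code should handle errors):
//
//	w := NewWriter(out)
//	if _, err := w.Write(data); err != nil {
//		// handle the write error
//	}
//	if err := w.Close(); err != nil { // Close flushes and frees the C context
//		// handle the close error
//	}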

// NewWriterLevel is like NewWriter but specifies the compression level instead
// of assuming default compression.
//
// The level can be DefaultCompression or any integer value between BestSpeed
// and BestCompression inclusive.
func NewWriterLevel(w io.Writer, level int) *Writer {
	return NewWriterLevelDict(w, level, nil)
}

// NewWriterLevelDict is like NewWriterLevel but specifies a dictionary to
// compress with. If the dictionary is empty or nil it is ignored. The dictionary
// should not be modified until the writer is closed.
func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer {
	var err error
	ctx := C.ZSTD_createCStream()

	// Load the dictionary, if any. Checking the length (not just nil-ness)
	// also skips empty non-nil dictionaries, matching the doc comment and
	// avoiding an out-of-range access on &dict[0].
	if len(dict) > 0 {
		err = getError(int(C.ZSTD_CCtx_loadDictionary(ctx,
			unsafe.Pointer(&dict[0]),
			C.size_t(len(dict)),
		)))
	}

	if err == nil {
		// Only set the level if the ctx is not already in an error state
		err = getError(int(C.ZSTD_CCtx_setParameter(ctx, C.ZSTD_c_compressionLevel, C.int(level))))
	}

	return &Writer{
		CompressionLevel: level,
		ctx:              ctx,
		dict:             dict,
		srcBuffer:        make([]byte, 0),
		dstBuffer:        make([]byte, CompressBound(1024)),
		firstError:       err,
		underlyingWriter: w,
		resultBuffer:     new(C.compressStream2_result),
	}
}
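
// A hedged sketch of dictionary usage (illustrative: `out`, `dict`, and
// `data` are placeholders). The decompressing side must be given the same
// dictionary bytes, e.g. via NewReaderDict:
//
//	w := NewWriterLevelDict(out, DefaultCompression, dict)
//	_, err := w.Write(data)
//	// ... the stream can later be decoded with NewReaderDict(in, dict)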

// Write writes a compressed form of p to the underlying io.Writer.
func (w *Writer) Write(p []byte) (int, error) {
	if w.firstError != nil {
		return 0, w.firstError
	}
	if len(p) == 0 {
		return 0, nil
	}
	// Make sure dstBuffer is large enough
	w.dstBuffer = w.dstBuffer[0:cap(w.dstBuffer)]
	if len(w.dstBuffer) < CompressBound(len(p)) {
		w.dstBuffer = make([]byte, CompressBound(len(p)))
	}

	// Avoid an extra memcopy when zstd ingests all of the input data
	srcData := p
	fastPath := len(w.srcBuffer) == 0
	if !fastPath {
		w.srcBuffer = append(w.srcBuffer, p...)
		srcData = w.srcBuffer
	}

	if len(srcData) == 0 {
		// This is technically unnecessary: srcData is p or w.srcBuffer, and
		// len(p) > 0 was checked above. But it guards future changes against
		// dereferencing srcData[0] on an empty slice.
		return 0, nil
	}
	C.ZSTD_compressStream2_wrapper(
		w.resultBuffer,
		w.ctx,
		unsafe.Pointer(&w.dstBuffer[0]),
		C.size_t(len(w.dstBuffer)),
		unsafe.Pointer(&srcData[0]),
		C.size_t(len(srcData)),
	)
	ret := int(w.resultBuffer.return_code)
	if err := getError(ret); err != nil {
		return 0, err
	}

	consumed := int(w.resultBuffer.bytes_consumed)
	if !fastPath {
		w.srcBuffer = w.srcBuffer[consumed:]
	} else {
		remaining := len(p) - consumed
		if remaining > 0 {
			// We still have some unconsumed data: copy it to srcBuffer,
			// avoiding a reallocation if it already has enough space.
			if cap(w.srcBuffer) >= remaining {
				w.srcBuffer = w.srcBuffer[0:remaining]
			} else {
				w.srcBuffer = make([]byte, remaining)
			}
			copy(w.srcBuffer, p[consumed:])
		}
	}

	written := int(w.resultBuffer.bytes_written)
	// Write to the underlying writer
	_, err := w.underlyingWriter.Write(w.dstBuffer[:written])

	// Same behaviour as zlib: we can't know how much data was actually
	// written, only whether there was an error
	if err != nil {
		return 0, err
	}
	return len(p), err
}

// Flush writes any unwritten data to the underlying io.Writer.
func (w *Writer) Flush() error {
	if w.firstError != nil {
		return w.firstError
	}

	ret := 1 // Ensures we loop at least once
	for ret > 0 {
		var srcPtr *byte // Stays nil if srcBuffer is empty
		if len(w.srcBuffer) > 0 {
			srcPtr = &w.srcBuffer[0]
		}

		C.ZSTD_compressStream2_flush(
			w.resultBuffer,
			w.ctx,
			unsafe.Pointer(&w.dstBuffer[0]),
			C.size_t(len(w.dstBuffer)),
			unsafe.Pointer(srcPtr),
			C.size_t(len(w.srcBuffer)),
		)
		ret = int(w.resultBuffer.return_code)
		if err := getError(ret); err != nil {
			return err
		}
		w.srcBuffer = w.srcBuffer[w.resultBuffer.bytes_consumed:]
		written := int(w.resultBuffer.bytes_written)
		_, err := w.underlyingWriter.Write(w.dstBuffer[:written])
		if err != nil {
			return err
		}

		if ret > 0 { // A non-zero return code hints at how much dst space to provide next
			w.dstBuffer = w.dstBuffer[:cap(w.dstBuffer)]
			if len(w.dstBuffer) < ret {
				w.dstBuffer = make([]byte, ret)
			}
		}
	}

	return nil
}
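
// A hedged sketch of interleaving Flush with writes (illustrative: `conn` and
// `msg` are placeholders). After Flush returns, everything written so far can
// be decompressed by the receiving side without waiting for Close:
//
//	w := NewWriter(conn)
//	_, _ = w.Write(msg)
//	_ = w.Flush() // the peer can now decode msg in full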

// Close closes the Writer, flushing any unwritten data to the underlying
// io.Writer and freeing objects, but does not close the underlying io.Writer.
func (w *Writer) Close() error {
	if w.firstError != nil {
		return w.firstError
	}

	ret := 1 // Ensures we loop at least once
	for ret > 0 {
		var srcPtr *byte // Stays nil if srcBuffer is empty
		if len(w.srcBuffer) > 0 {
			srcPtr = &w.srcBuffer[0]
		}

		C.ZSTD_compressStream2_finish(
			w.resultBuffer,
			w.ctx,
			unsafe.Pointer(&w.dstBuffer[0]),
			C.size_t(len(w.dstBuffer)),
			unsafe.Pointer(srcPtr),
			C.size_t(len(w.srcBuffer)),
		)
		ret = int(w.resultBuffer.return_code)
		if err := getError(ret); err != nil {
			return err
		}
		w.srcBuffer = w.srcBuffer[w.resultBuffer.bytes_consumed:]
		written := int(w.resultBuffer.bytes_written)
		_, err := w.underlyingWriter.Write(w.dstBuffer[:written])
		if err != nil {
			C.ZSTD_freeCStream(w.ctx)
			return err
		}

		if ret > 0 { // A non-zero return code hints at how much dst space to provide next
			w.dstBuffer = w.dstBuffer[:cap(w.dstBuffer)]
			if len(w.dstBuffer) < ret {
				w.dstBuffer = make([]byte, ret)
			}
		}
	}

	return getError(int(C.ZSTD_freeCStream(w.ctx)))
}

// cSize is the recommended size of reader.compressionBuffer. This func and
// invocation allow for a one-time check for validity.
var cSize = func() int {
	v := int(C.ZSTD_DStreamInSize())
	if v <= 0 {
		panic(fmt.Errorf("ZSTD_DStreamInSize() returned invalid size: %v", v))
	}
	return v
}()

// dSize is the recommended size of reader.decompressionBuffer. This func and
// invocation allow for a one-time check for validity.
var dSize = func() int {
	v := int(C.ZSTD_DStreamOutSize())
	if v <= 0 {
		panic(fmt.Errorf("ZSTD_DStreamOutSize() returned invalid size: %v", v))
	}
	return v
}()

// cPool is a pool of buffers for use in reader.compressionBuffer. Buffers are
// taken from the pool in NewReaderDict, returned in reader.Close(). Returns a
// pointer to a slice to avoid the extra allocation of returning the slice as a
// value.
var cPool = sync.Pool{
	New: func() interface{} {
		buff := make([]byte, cSize)
		return &buff
	},
}

// dPool is a pool of buffers for use in reader.decompressionBuffer. Buffers are
// taken from the pool in NewReaderDict, returned in reader.Close(). Returns a
// pointer to a slice to avoid the extra allocation of returning the slice as a
// value.
var dPool = sync.Pool{
	New: func() interface{} {
		buff := make([]byte, dSize)
		return &buff
	},
}
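
// Design note (inferred from the comments above, not stated elsewhere in the
// source): pooling spares a short-lived reader two fresh allocations of
// roughly ZSTD_DStreamInSize and ZSTD_DStreamOutSize bytes each, and storing
// *[]byte rather than []byte avoids an extra allocation when the slice header
// is boxed into the pool's interface{}.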

// reader is an io.ReadCloser that decompresses when read from.
type reader struct {
	ctx                 *C.ZSTD_DCtx
	compressionBuffer   []byte
	compressionLeft     int
	decompressionBuffer []byte
	decompOff           int
	decompSize          int
	dict                []byte
	firstError          error
	recommendedSrcSize  int
	resultBuffer        *C.decompressStream2_result
	underlyingReader    io.Reader
}

// NewReader creates a new io.ReadCloser. Reads from the returned ReadCloser
// read and decompress data from r. It is the caller's responsibility to call
// Close on the ReadCloser when done. If this is not done, underlying objects
// in the zstd library will not be freed.
func NewReader(r io.Reader) io.ReadCloser {
	return NewReaderDict(r, nil)
}
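
// A minimal usage sketch for the decompression side (illustrative: `src` and
// `dst` are placeholders, and errors are elided):
//
//	r := NewReader(src)
//	defer r.Close() // frees the C decompression context
//	_, _ = io.Copy(dst, r)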

// NewReaderDict is like NewReader but uses a preset dictionary. NewReaderDict
// ignores the dictionary if it is nil.
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
	var err error
	ctx := C.ZSTD_createDStream()
	if len(dict) == 0 {
		err = getError(int(C.ZSTD_initDStream(ctx)))
	} else {
		err = getError(int(C.ZSTD_DCtx_reset(ctx, C.ZSTD_reset_session_only)))
		if err == nil {
			// Only load the dictionary if the context was successfully initialized
			err = getError(int(C.ZSTD_DCtx_loadDictionary(
				ctx,
				unsafe.Pointer(&dict[0]),
				C.size_t(len(dict)))))
		}
	}
	compressionBufferP := cPool.Get().(*[]byte)
	decompressionBufferP := dPool.Get().(*[]byte)
	return &reader{
		ctx:                 ctx,
		dict:                dict,
		compressionBuffer:   *compressionBufferP,
		decompressionBuffer: *decompressionBufferP,
		firstError:          err,
		recommendedSrcSize:  cSize,
		resultBuffer:        new(C.decompressStream2_result),
		underlyingReader:    r,
	}
}

// Close frees the allocated C objects.
func (r *reader) Close() error {
	if r.firstError != nil {
		return r.firstError
	}

	cb := r.compressionBuffer
	db := r.decompressionBuffer
	// Ensure that we won't reuse the buffers after this point
	r.firstError = errReaderClosed
	r.compressionBuffer = nil
	r.decompressionBuffer = nil

	cPool.Put(&cb)
	dPool.Put(&db)
	return getError(int(C.ZSTD_freeDStream(r.ctx)))
}

func (r *reader) Read(p []byte) (int, error) {
	if r.firstError != nil {
		return 0, r.firstError
	}

	if len(p) == 0 {
		return 0, nil
	}

	// If we already have some uncompressed bytes, return without blocking
	if r.decompSize > r.decompOff {
		if r.decompSize-r.decompOff > len(p) {
			copy(p, r.decompressionBuffer[r.decompOff:])
			r.decompOff += len(p)
			return len(p), nil
		}
		// From https://golang.org/pkg/io/#Reader
		// > Read conventionally returns what is available instead of waiting for more.
		copy(p, r.decompressionBuffer[r.decompOff:r.decompSize])
		got := r.decompSize - r.decompOff
		r.decompOff = r.decompSize
		return got, nil
	}

	// Repeatedly read from the underlying reader until we get
	// at least one zstd block, so that we don't block if the
	// other end has flushed a block.
	for {
		// - If the last decompression didn't entirely fill the decompression buffer,
		//   zstd flushed all it could and needs new data: do one Read.
		// - If the last decompression did entirely fill the decompression buffer,
		//   it might have needed more room to decompress the input: in that case,
		//   don't do any unnecessary Read that might block.
		needsData := r.decompSize < len(r.decompressionBuffer)

		var src []byte
		if !needsData {
			src = r.compressionBuffer[:r.compressionLeft]
		} else {
			src = r.compressionBuffer
			var n int
			var err error
			// Read until data arrives or an error occurs.
			for n == 0 && err == nil {
				n, err = r.underlyingReader.Read(src[r.compressionLeft:])
			}
			if err != nil && err != io.EOF { // Handle underlying reader errors first
				return 0, fmt.Errorf("failed to read from underlying reader: %s", err)
			}
			if n == 0 {
				// Ideally, we'd return io.ErrUnexpectedEOF whenever the stream ends
				// mid-block or mid-frame, i.e. while there is incomplete, pending
				// compressed data. However, that's hard to detect with zstd: there
				// is no way to know how much compressed data is buffered inside the
				// zstd stream's internal buffers.
				// Best effort: return io.ErrUnexpectedEOF if we still hold buffered
				// compressed data that zstd refuses to accept. If we hold none but
				// zstd still has some internally, return io.EOF instead.
				if r.compressionLeft > 0 {
					return 0, io.ErrUnexpectedEOF
				}
				return 0, io.EOF
			}
			src = src[:r.compressionLeft+n]
		}

		// C code
		var srcPtr *byte // Stays nil if src is empty
		if len(src) > 0 {
			srcPtr = &src[0]
		}

		C.ZSTD_decompressStream_wrapper(
			r.resultBuffer,
			r.ctx,
			unsafe.Pointer(&r.decompressionBuffer[0]),
			C.size_t(len(r.decompressionBuffer)),
			unsafe.Pointer(srcPtr),
			C.size_t(len(src)),
		)
		retCode := int(r.resultBuffer.return_code)

		// Keep src alive across the cgo call even though the call itself uses it;
		// this guards against future refactors deleting the last reference.
		runtime.KeepAlive(src)
		if err := getError(retCode); err != nil {
			return 0, fmt.Errorf("failed to decompress: %s", err)
		}

		// Carry over whatever zstd did not consume
		bytesConsumed := int(r.resultBuffer.bytes_consumed)
		if bytesConsumed < len(src) {
			left := src[bytesConsumed:]
			copy(r.compressionBuffer, left)
		}
		r.compressionLeft = len(src) - bytesConsumed
		r.decompSize = int(r.resultBuffer.bytes_written)
		r.decompOff = copy(p, r.decompressionBuffer[:r.decompSize])

		// Resize buffers
		nsize := retCode // Hint for the next src buffer size
		if nsize <= 0 {
			// Reset to the recommended size
			nsize = r.recommendedSrcSize
		}
		if nsize < r.compressionLeft {
			nsize = r.compressionLeft
		}
		r.compressionBuffer = resize(r.compressionBuffer, nsize)

		if r.decompOff > 0 {
			return r.decompOff, nil
		}
	}
}