178 lines
4.0 KiB
Go
178 lines
4.0 KiB
Go
|
|
package stario
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"io"
|
||
|
|
)
|
||
|
|
|
||
|
|
// defaultCopyContextBufferSize is the maximum number of bytes CopyContext
// requests from its source in a single Read (32 KiB).
const defaultCopyContextBufferSize = 32 << 10
|
||
|
|
|
||
|
|
type (
	// readContextResult carries the outcome of one background Read started by
	// readContext: the bytes that were read and the error Read returned.
	readContextResult struct {
		data []byte
		err  error
	}

	// writeContextResult carries the outcome of one background Write started
	// by writeContext: the byte count and the error Write returned.
	writeContextResult struct {
		n   int
		err error
	}
)
|
||
|
|
|
||
|
|
// ReadFullContext reads exactly len(buf) bytes unless the context is canceled
|
||
|
|
// or the underlying reader returns an error.
|
||
|
|
//
|
||
|
|
// If ctx is canceled while the underlying Read call is blocked, ReadFullContext
|
||
|
|
// returns ctx.Err() without waiting for that call to finish. The underlying
|
||
|
|
// reader may still complete asynchronously afterwards.
|
||
|
|
func ReadFullContext(ctx context.Context, reader io.Reader, buf []byte) (int, error) {
|
||
|
|
if reader == nil {
|
||
|
|
return 0, io.ErrClosedPipe
|
||
|
|
}
|
||
|
|
if ctx == nil {
|
||
|
|
ctx = context.Background()
|
||
|
|
}
|
||
|
|
if err := ctx.Err(); err != nil {
|
||
|
|
return 0, err
|
||
|
|
}
|
||
|
|
total := 0
|
||
|
|
for total < len(buf) {
|
||
|
|
chunk, err := readContext(ctx, reader, len(buf)-total)
|
||
|
|
if len(chunk) > 0 {
|
||
|
|
total += copy(buf[total:], chunk)
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
if errors.Is(err, io.EOF) {
|
||
|
|
if total > 0 {
|
||
|
|
return total, io.ErrUnexpectedEOF
|
||
|
|
}
|
||
|
|
return total, io.EOF
|
||
|
|
}
|
||
|
|
return total, err
|
||
|
|
}
|
||
|
|
if len(chunk) == 0 {
|
||
|
|
return total, io.ErrNoProgress
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return total, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// WriteFullContext writes the full payload unless the context is canceled or
|
||
|
|
// the underlying writer returns an error.
|
||
|
|
//
|
||
|
|
// If ctx is canceled while the underlying Write call is blocked,
|
||
|
|
// WriteFullContext returns ctx.Err() without waiting for that call to finish.
|
||
|
|
// The underlying writer may still complete asynchronously afterwards.
|
||
|
|
func WriteFullContext(ctx context.Context, writer io.Writer, data []byte) (int, error) {
|
||
|
|
if writer == nil {
|
||
|
|
return 0, io.ErrClosedPipe
|
||
|
|
}
|
||
|
|
if ctx == nil {
|
||
|
|
ctx = context.Background()
|
||
|
|
}
|
||
|
|
if err := ctx.Err(); err != nil {
|
||
|
|
return 0, err
|
||
|
|
}
|
||
|
|
total := 0
|
||
|
|
for total < len(data) {
|
||
|
|
written, err := writeContext(ctx, writer, data[total:])
|
||
|
|
if written > 0 {
|
||
|
|
total += written
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return total, err
|
||
|
|
}
|
||
|
|
if written == 0 {
|
||
|
|
return total, io.ErrNoProgress
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return total, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// CopyContext copies from src to dst until EOF, context cancellation, or a
|
||
|
|
// non-EOF error occurs.
|
||
|
|
//
|
||
|
|
// If ctx is canceled while the current read or write is blocked, CopyContext
|
||
|
|
// returns ctx.Err() without waiting for that operation to finish.
|
||
|
|
func CopyContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
|
||
|
|
if dst == nil || src == nil {
|
||
|
|
return 0, io.ErrClosedPipe
|
||
|
|
}
|
||
|
|
if ctx == nil {
|
||
|
|
ctx = context.Background()
|
||
|
|
}
|
||
|
|
if err := ctx.Err(); err != nil {
|
||
|
|
return 0, err
|
||
|
|
}
|
||
|
|
var copied int64
|
||
|
|
for {
|
||
|
|
chunk, readErr := readContext(ctx, src, defaultCopyContextBufferSize)
|
||
|
|
if len(chunk) > 0 {
|
||
|
|
written, writeErr := WriteFullContext(ctx, dst, chunk)
|
||
|
|
copied += int64(written)
|
||
|
|
if writeErr != nil {
|
||
|
|
return copied, writeErr
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if readErr != nil {
|
||
|
|
if errors.Is(readErr, io.EOF) {
|
||
|
|
return copied, nil
|
||
|
|
}
|
||
|
|
return copied, readErr
|
||
|
|
}
|
||
|
|
if len(chunk) == 0 {
|
||
|
|
return copied, io.ErrNoProgress
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func readContext(ctx context.Context, reader io.Reader, size int) ([]byte, error) {
|
||
|
|
if size <= 0 {
|
||
|
|
return nil, nil
|
||
|
|
}
|
||
|
|
resultCh := make(chan readContextResult, 1)
|
||
|
|
go func() {
|
||
|
|
tmp := make([]byte, size)
|
||
|
|
n, err := reader.Read(tmp)
|
||
|
|
if n > 0 {
|
||
|
|
tmp = tmp[:n]
|
||
|
|
} else {
|
||
|
|
tmp = nil
|
||
|
|
}
|
||
|
|
resultCh <- readContextResult{data: tmp, err: err}
|
||
|
|
}()
|
||
|
|
select {
|
||
|
|
case result := <-resultCh:
|
||
|
|
return result.data, result.err
|
||
|
|
case <-ctx.Done():
|
||
|
|
select {
|
||
|
|
case result := <-resultCh:
|
||
|
|
return result.data, result.err
|
||
|
|
default:
|
||
|
|
return nil, ctx.Err()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func writeContext(ctx context.Context, writer io.Writer, data []byte) (int, error) {
|
||
|
|
if len(data) == 0 {
|
||
|
|
return 0, nil
|
||
|
|
}
|
||
|
|
payload := append([]byte(nil), data...)
|
||
|
|
resultCh := make(chan writeContextResult, 1)
|
||
|
|
go func() {
|
||
|
|
n, err := writer.Write(payload)
|
||
|
|
resultCh <- writeContextResult{n: n, err: err}
|
||
|
|
}()
|
||
|
|
select {
|
||
|
|
case result := <-resultCh:
|
||
|
|
return result.n, result.err
|
||
|
|
case <-ctx.Done():
|
||
|
|
select {
|
||
|
|
case result := <-resultCh:
|
||
|
|
return result.n, result.err
|
||
|
|
default:
|
||
|
|
return 0, ctx.Err()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|