stario/io_context.go

178 lines
4.0 KiB
Go
Raw Permalink Normal View History

package stario
import (
"context"
"errors"
"io"
)
// defaultCopyContextBufferSize is the maximum number of bytes CopyContext asks
// for in each individual read (32 KiB).
const defaultCopyContextBufferSize = 32 << 10

type (
	// readContextResult ferries the outcome of one background Read back to
	// readContext: the bytes actually read (nil when none) and Read's error.
	readContextResult struct {
		data []byte
		err  error
	}

	// writeContextResult ferries the outcome of one background Write back to
	// writeContext: the byte count Write reported and Write's error.
	writeContextResult struct {
		n   int
		err error
	}
)
// ReadFullContext reads exactly len(buf) bytes unless the context is canceled
// or the underlying reader returns an error.
//
// The result follows io.ReadFull:
//   - filling buf returns (len(buf), nil), even if the Read that completed it
//     also reported an error such as io.EOF;
//   - hitting EOF before any byte was read returns (0, io.EOF);
//   - hitting EOF after a partial read returns (n, io.ErrUnexpectedEOF);
//   - any other error is returned together with the bytes read so far.
//
// A nil reader yields io.ErrClosedPipe, and a nil ctx is treated as
// context.Background(). Reads that return (0, nil) are retried, but after 100
// of them in a row ReadFullContext gives up with io.ErrNoProgress.
//
// If ctx is canceled while the underlying Read call is blocked, ReadFullContext
// returns ctx.Err() without waiting for that call to finish. The underlying
// reader may still complete asynchronously afterwards; any bytes that late
// Read returns are discarded.
func ReadFullContext(ctx context.Context, reader io.Reader, buf []byte) (int, error) {
	if reader == nil {
		return 0, io.ErrClosedPipe
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if err := ctx.Err(); err != nil {
		return 0, err
	}

	// io.Reader allows (0, nil) and says it means "nothing happened", so
	// tolerate a bounded run of them (the same limit bufio.Reader uses) before
	// declaring the reader stuck.
	const maxConsecutiveEmptyReads = 100

	total := 0
	emptyReads := 0
	for total < len(buf) {
		chunk, err := readContext(ctx, reader, len(buf)-total)
		total += copy(buf[total:], chunk)
		if total == len(buf) {
			// A full buffer is success, matching io.ReadFull, even when the
			// final Read also returned an error (commonly io.EOF).
			return total, nil
		}
		if err != nil {
			if errors.Is(err, io.EOF) {
				if total > 0 {
					return total, io.ErrUnexpectedEOF
				}
				return total, io.EOF
			}
			return total, err
		}
		if len(chunk) == 0 {
			emptyReads++
			if emptyReads >= maxConsecutiveEmptyReads {
				return total, io.ErrNoProgress
			}
			continue
		}
		emptyReads = 0
	}
	return total, nil
}
// WriteFullContext writes the full payload unless the context is canceled or
// the underlying writer returns an error.
//
// Short writes that report no error are retried with the remaining bytes.
// Other outcomes:
//   - a Write that reports zero bytes and no error yields io.ErrNoProgress;
//   - a Write that reports a negative count, or more bytes than it was given,
//     yields an "invalid write result" error, unless the writer also returned
//     its own error, which is returned instead;
//   - a nil writer yields io.ErrClosedPipe;
//   - a nil ctx is treated as context.Background().
//
// The returned count is the number of bytes the writer acknowledged.
//
// If ctx is canceled while the underlying Write call is blocked,
// WriteFullContext returns ctx.Err() without waiting for that call to finish.
// The underlying writer may still complete asynchronously afterwards, so bytes
// can reach the writer without being included in the returned count.
func WriteFullContext(ctx context.Context, writer io.Writer, data []byte) (int, error) {
	if writer == nil {
		return 0, io.ErrClosedPipe
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if err := ctx.Err(); err != nil {
		return 0, err
	}
	total := 0
	for total < len(data) {
		remaining := data[total:]
		written, err := writeContext(ctx, writer, remaining)
		if written < 0 || written > len(remaining) {
			// A misbehaving writer. Trusting a negative count would rewrite
			// the same bytes forever; trusting an oversized one would
			// overstate progress. io.Copy rejects the same results.
			written = 0
			if err == nil {
				err = errors.New("invalid write result")
			}
		}
		total += written
		if err != nil {
			return total, err
		}
		if written == 0 {
			return total, io.ErrNoProgress
		}
	}
	return total, nil
}
// CopyContext copies from src to dst until EOF, context cancellation, or a
// non-EOF error occurs.
//
// It reads at most defaultCopyContextBufferSize bytes at a time and writes
// each chunk in full via WriteFullContext before reading again. It only uses
// plain Read and Write calls, never the io.WriterTo / io.ReaderFrom fast
// paths that io.Copy may take.
//
// Return values:
//   - reaching EOF returns the byte count with a nil error;
//   - a write error takes precedence over a read error from the same
//     iteration;
//   - a nil dst or src yields io.ErrClosedPipe;
//   - a nil ctx is treated as context.Background();
//   - after 100 consecutive (0, nil) reads, CopyContext gives up with
//     io.ErrNoProgress.
//
// If ctx is canceled while the current read or write is blocked, CopyContext
// returns ctx.Err() without waiting for that operation to finish.
func CopyContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
	if dst == nil || src == nil {
		return 0, io.ErrClosedPipe
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if err := ctx.Err(); err != nil {
		return 0, err
	}

	// io.Reader allows (0, nil) and says it means "nothing happened", so
	// tolerate a bounded run of them (the same limit bufio.Reader uses) before
	// declaring the source stuck.
	const maxConsecutiveEmptyReads = 100

	var copied int64
	emptyReads := 0
	for {
		chunk, readErr := readContext(ctx, src, defaultCopyContextBufferSize)
		if len(chunk) > 0 {
			// Forward the data before looking at readErr: a Read may return
			// bytes together with io.EOF or another error.
			written, writeErr := WriteFullContext(ctx, dst, chunk)
			copied += int64(written)
			if writeErr != nil {
				return copied, writeErr
			}
		}
		if readErr != nil {
			if errors.Is(readErr, io.EOF) {
				return copied, nil
			}
			return copied, readErr
		}
		if len(chunk) == 0 {
			emptyReads++
			if emptyReads >= maxConsecutiveEmptyReads {
				return copied, io.ErrNoProgress
			}
			continue
		}
		emptyReads = 0
	}
}
// readContext performs a single Read of at most size bytes from reader. If ctx
// is done before the Read completes, it returns early with ctx.Err().
//
// The Read runs in its own goroutine against a freshly allocated buffer, so an
// abandoned Read never writes into memory the caller still owns. Any bytes
// such an abandoned Read eventually returns are discarded.
//
// The returned slice is nil when the Read produced no bytes. A non-positive
// size returns (nil, nil) without calling Read. ctx must be non-nil.
func readContext(ctx context.Context, reader io.Reader, size int) ([]byte, error) {
	if size <= 0 {
		return nil, nil
	}
	// Do not start a Read for an already-canceled ctx. Once started it cannot
	// be interrupted, and whatever it consumes from reader would be lost when
	// the ctx.Done branch below returns ctx.Err().
	if err := ctx.Err(); err != nil {
		return nil, err
	}
	// Buffered so the goroutine can deliver its result and exit once Read
	// returns, even if nobody is receiving any more.
	resultCh := make(chan readContextResult, 1)
	go func() {
		tmp := make([]byte, size)
		n, err := reader.Read(tmp)
		if n > 0 {
			tmp = tmp[:n]
		} else {
			tmp = nil
		}
		resultCh <- readContextResult{data: tmp, err: err}
	}()
	select {
	case result := <-resultCh:
		return result.data, result.err
	case <-ctx.Done():
		// select picks randomly among ready cases; prefer a result that is
		// already available so completed work is not thrown away.
		select {
		case result := <-resultCh:
			return result.data, result.err
		default:
			return nil, ctx.Err()
		}
	}
}
// writeContext performs a single Write of data to writer. If ctx is done
// before the Write completes, it returns (0, ctx.Err()) early.
//
// The Write runs in its own goroutine on a private copy of data, so the caller
// may reuse data as soon as writeContext returns. An abandoned Write may still
// complete later; those bytes are not reflected in the returned count.
//
// Empty data returns (0, nil) without calling Write. ctx must be non-nil.
func writeContext(ctx context.Context, writer io.Writer, data []byte) (int, error) {
	if len(data) == 0 {
		return 0, nil
	}
	// Do not start a Write for an already-canceled ctx. Once started it cannot
	// be interrupted, and any bytes it lands would be reported as unwritten.
	if err := ctx.Err(); err != nil {
		return 0, err
	}
	// Copy so the background Write never reads the caller's buffer after we
	// have returned on cancellation.
	payload := append([]byte(nil), data...)
	// Buffered so the goroutine can deliver its result and exit once Write
	// returns, even if nobody is receiving any more.
	resultCh := make(chan writeContextResult, 1)
	go func() {
		n, err := writer.Write(payload)
		resultCh <- writeContextResult{n: n, err: err}
	}()
	select {
	case result := <-resultCh:
		return result.n, result.err
	case <-ctx.Done():
		// select picks randomly among ready cases; prefer a result that is
		// already available so an accurate count is reported.
		select {
		case result := <-resultCh:
			return result.n, result.err
		default:
			return 0, ctx.Err()
		}
	}
}