package stario

import (
	"context"
	"errors"
	"io"
)

// defaultCopyContextBufferSize is the maximum number of bytes CopyContext
// requests from its source reader in a single Read call.
const defaultCopyContextBufferSize = 32 * 1024

// readContextResult carries the outcome of one Read performed on a helper
// goroutine: the bytes that were read (nil when none) and the reader's error.
type readContextResult struct {
	data []byte
	err  error
}

// writeContextResult carries the outcome of one Write performed on a helper
// goroutine: the count reported by the writer and its error.
type writeContextResult struct {
	n   int
	err error
}

// ReadFullContext reads exactly len(buf) bytes from reader into buf unless the
// context is canceled or the underlying reader returns an error.
//
// Error semantics follow io.ReadFull:
//   - On success it returns len(buf) and a nil error, even if the reader
//     reported an error (such as io.EOF) together with the bytes that filled
//     buf.
//   - It returns io.EOF only if no bytes were read.
//   - It returns io.ErrUnexpectedEOF if EOF occurs after a partial read.
//   - A Read that returns zero bytes with a nil error is reported as
//     io.ErrNoProgress.
//   - A nil reader yields io.ErrClosedPipe, and a nil ctx is treated as
//     context.Background().
//
// ctx is checked before every Read, so no new Read is started once
// cancellation has been observed. If ctx is canceled while the underlying Read
// call is blocked, ReadFullContext returns ctx.Err() without waiting for that
// call to finish. The underlying reader may still complete asynchronously
// afterwards; any bytes it delivers then are discarded and never written into
// buf.
func ReadFullContext(ctx context.Context, reader io.Reader, buf []byte) (int, error) {
	if reader == nil {
		return 0, io.ErrClosedPipe
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if err := ctx.Err(); err != nil {
		return 0, err
	}

	total := 0
	for total < len(buf) {
		// Re-check between reads: after a partial read, a cancellation must not
		// start another Read whose data would be read and then thrown away.
		if err := ctx.Err(); err != nil {
			return total, err
		}

		chunk, err := readContext(ctx, reader, len(buf)-total)
		total += copy(buf[total:], chunk)
		if total == len(buf) {
			// Like io.ReadFull, a full buffer is success even when the reader
			// returned an error (typically io.EOF) alongside the final bytes.
			return total, nil
		}
		if err != nil {
			if errors.Is(err, io.EOF) {
				if total > 0 {
					return total, io.ErrUnexpectedEOF
				}
				return total, io.EOF
			}
			return total, err
		}
		if len(chunk) == 0 {
			// Treat a (0, nil) Read as a stall instead of retrying indefinitely.
			return total, io.ErrNoProgress
		}
	}
	return total, nil
}

// WriteFullContext writes all of data to writer unless the context is canceled
// or the underlying writer returns an error, and reports how many bytes were
// written.
//
// Short writes are retried with the remaining bytes. A Write that accepts zero
// bytes without an error is reported as io.ErrNoProgress. A nil writer yields
// io.ErrClosedPipe, and a nil ctx is treated as context.Background().
//
// If ctx is canceled while the underlying Write call is blocked,
// WriteFullContext returns ctx.Err() without waiting for that call to finish.
// The underlying writer may still complete asynchronously afterwards, so the
// returned count may understate what the writer eventually accepted.
func WriteFullContext(ctx context.Context, writer io.Writer, data []byte) (int, error) { if writer == nil { return 0, io.ErrClosedPipe } if ctx == nil { ctx = context.Background() } if err := ctx.Err(); err != nil { return 0, err } total := 0 for total < len(data) { written, err := writeContext(ctx, writer, data[total:]) if written > 0 { total += written } if err != nil { return total, err } if written == 0 { return total, io.ErrNoProgress } } return total, nil } // CopyContext copies from src to dst until EOF, context cancellation, or a // non-EOF error occurs. // // If ctx is canceled while the current read or write is blocked, CopyContext // returns ctx.Err() without waiting for that operation to finish. func CopyContext(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) { if dst == nil || src == nil { return 0, io.ErrClosedPipe } if ctx == nil { ctx = context.Background() } if err := ctx.Err(); err != nil { return 0, err } var copied int64 for { chunk, readErr := readContext(ctx, src, defaultCopyContextBufferSize) if len(chunk) > 0 { written, writeErr := WriteFullContext(ctx, dst, chunk) copied += int64(written) if writeErr != nil { return copied, writeErr } } if readErr != nil { if errors.Is(readErr, io.EOF) { return copied, nil } return copied, readErr } if len(chunk) == 0 { return copied, io.ErrNoProgress } } } func readContext(ctx context.Context, reader io.Reader, size int) ([]byte, error) { if size <= 0 { return nil, nil } resultCh := make(chan readContextResult, 1) go func() { tmp := make([]byte, size) n, err := reader.Read(tmp) if n > 0 { tmp = tmp[:n] } else { tmp = nil } resultCh <- readContextResult{data: tmp, err: err} }() select { case result := <-resultCh: return result.data, result.err case <-ctx.Done(): select { case result := <-resultCh: return result.data, result.err default: return nil, ctx.Err() } } } func writeContext(ctx context.Context, writer io.Writer, data []byte) (int, error) { if len(data) == 0 { return 0, nil } 
payload := append([]byte(nil), data...) resultCh := make(chan writeContextResult, 1) go func() { n, err := writer.Write(payload) resultCh <- writeContextResult{n: n, err: err} }() select { case result := <-resultCh: return result.n, result.err case <-ctx.Done(): select { case result := <-resultCh: return result.n, result.err default: return 0, ctx.Err() } } }