stario/io_context_test.go

154 lines
3.5 KiB
Go
Raw Permalink Normal View History

package stario
import (
"bytes"
"context"
"errors"
"io"
"sync"
"testing"
"time"
)
// chunkedWriter is a test io.Writer that accepts at most max bytes per
// Write call and buffers what it accepts. A non-positive max disables the
// limit.
//
// When it truncates, it returns n < len(p) with a nil error. That
// deliberately breaks the io.Writer contract so that callers which must
// loop on short writes (such as WriteFullContext) are exercised.
type chunkedWriter struct {
	max int
	buf bytes.Buffer
}

// Write appends up to w.max bytes of p to w.buf and returns the number of
// bytes accepted along with any error from the underlying buffer.
func (w *chunkedWriter) Write(p []byte) (int, error) {
	chunk := p
	if w.max > 0 && len(chunk) > w.max {
		chunk = chunk[:w.max]
	}
	return w.buf.Write(chunk)
}
// blockingReader is a test io.Reader used to simulate a Read that is stuck.
// The first Read closes started. Every Read then waits for release to
// become readable (typically by closing it). After that it copies data
// into p and returns a nil error. It never reports io.EOF.
type blockingReader struct {
	data    []byte
	started chan struct{}
	release chan struct{}
	once    sync.Once
}

// Read signals started exactly once, blocks on release, then fills p from
// r.data and returns the number of bytes copied.
func (r *blockingReader) Read(p []byte) (int, error) {
	markStarted := func() { close(r.started) }
	r.once.Do(markStarted)

	<-r.release
	return copy(p, r.data), nil
}
// blockingWriter is a test io.Writer used to simulate a Write that is stuck.
// The first Write closes started. Every Write then waits for release to
// become readable (typically by closing it). After that it reports all of
// p as written with a nil error. It does not store p anywhere.
type blockingWriter struct {
	started chan struct{}
	release chan struct{}
	once    sync.Once
}

// Write signals started exactly once, blocks on release, then claims the
// full length of p.
func (w *blockingWriter) Write(p []byte) (int, error) {
	markStarted := func() { close(w.started) }
	w.once.Do(markStarted)

	<-w.release
	written := len(p)
	return written, nil
}
// TestReadFullContext checks the happy path: with a background context and a
// source holding exactly len(buf) bytes, ReadFullContext fills the buffer
// and returns no error.
func TestReadFullContext(t *testing.T) {
	ctx := context.Background()
	src := bytes.NewBufferString("hello")
	got := make([]byte, 5)

	n, err := ReadFullContext(ctx, src, got)
	switch {
	case err != nil:
		t.Fatalf("ReadFullContext returned error: %v", err)
	case n != 5 || string(got) != "hello":
		t.Fatalf("unexpected payload: n=%d data=%q", n, got)
	}
}
// TestReadFullContextReturnsUnexpectedEOF checks that a source shorter than
// the destination buffer yields an error matching io.ErrUnexpectedEOF. It
// also checks that the bytes that were available are reported in n and
// copied into the buffer.
func TestReadFullContextReturnsUnexpectedEOF(t *testing.T) {
	ctx := context.Background()
	src := bytes.NewBufferString("hey")
	dst := make([]byte, 5)

	n, err := ReadFullContext(ctx, src, dst)
	switch {
	case !errors.Is(err, io.ErrUnexpectedEOF):
		t.Fatalf("expected unexpected EOF, got %v", err)
	case n != 3 || string(dst[:n]) != "hey":
		t.Fatalf("unexpected payload: n=%d data=%q", n, dst[:n])
	}
}
// TestReadFullContextCanceledWhileBlocked verifies that canceling ctx while
// the underlying Read is blocked makes ReadFullContext return an error
// matching context.Canceled within 200ms.
func TestReadFullContextCanceledWhileBlocked(t *testing.T) {
	reader := &blockingReader{
		started: make(chan struct{}),
		release: make(chan struct{}),
		data:    []byte("hello"),
	}
	// Release the reader on every exit path. t.Fatal runs only deferred
	// calls, so a plain close at the end of the test would be skipped on
	// failure. The goroutine stuck inside Read would then outlive the test.
	defer close(reader.release)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	done := make(chan error, 1)
	go func() {
		buf := make([]byte, 5)
		_, err := ReadFullContext(ctx, reader, buf)
		done <- err
	}()

	// Wait until Read is actually blocked, but fail rather than hang if
	// ReadFullContext returns early or never reaches Read.
	select {
	case <-reader.started:
	case err := <-done:
		t.Fatalf("ReadFullContext returned before blocking in Read: %v", err)
	case <-time.After(time.Second):
		t.Fatal("ReadFullContext never called Read")
	}

	cancel()

	select {
	case err := <-done:
		if !errors.Is(err, context.Canceled) {
			t.Fatalf("expected context canceled, got %v", err)
		}
	case <-time.After(200 * time.Millisecond):
		t.Fatal("ReadFullContext did not return after cancel")
	}
}
// TestWriteFullContext checks that WriteFullContext keeps writing until the
// whole payload is delivered. The sink is a chunkedWriter that accepts only
// two bytes per call.
func TestWriteFullContext(t *testing.T) {
	const payload = "hello"
	sink := &chunkedWriter{max: 2}

	n, err := WriteFullContext(context.Background(), sink, []byte(payload))
	if err != nil {
		t.Fatalf("WriteFullContext returned error: %v", err)
	}
	if got := sink.buf.String(); n != 5 || got != payload {
		t.Fatalf("unexpected write result: n=%d data=%q", n, got)
	}
}
// TestWriteFullContextCanceledWhileBlocked verifies that canceling ctx while
// the underlying Write is blocked makes WriteFullContext return an error
// matching context.Canceled within 200ms.
func TestWriteFullContextCanceledWhileBlocked(t *testing.T) {
	writer := &blockingWriter{
		started: make(chan struct{}),
		release: make(chan struct{}),
	}
	// Release the writer on every exit path. t.Fatal runs only deferred
	// calls, so a plain close at the end of the test would be skipped on
	// failure. The goroutine stuck inside Write would then outlive the test.
	defer close(writer.release)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	done := make(chan error, 1)
	go func() {
		_, err := WriteFullContext(ctx, writer, []byte("hello"))
		done <- err
	}()

	// Wait until Write is actually blocked, but fail rather than hang if
	// WriteFullContext returns early or never reaches Write.
	select {
	case <-writer.started:
	case err := <-done:
		t.Fatalf("WriteFullContext returned before blocking in Write: %v", err)
	case <-time.After(time.Second):
		t.Fatal("WriteFullContext never called Write")
	}

	cancel()

	select {
	case err := <-done:
		if !errors.Is(err, context.Canceled) {
			t.Fatalf("expected context canceled, got %v", err)
		}
	case <-time.After(200 * time.Millisecond):
		t.Fatal("WriteFullContext did not return after cancel")
	}
}
// TestCopyContext checks that CopyContext streams an entire in-memory source
// into the destination. It expects the reported byte count to equal the
// payload length.
func TestCopyContext(t *testing.T) {
	const payload = "hello world"
	var dst bytes.Buffer
	src := bytes.NewBufferString(payload)

	written, err := CopyContext(context.Background(), &dst, src)
	if err != nil {
		t.Fatalf("CopyContext returned error: %v", err)
	}
	if want := int64(len(payload)); written != want || dst.String() != payload {
		t.Fatalf("unexpected copy result: written=%d data=%q", written, dst.String())
	}
}