154 lines
3.5 KiB
Go
154 lines
3.5 KiB
Go
|
|
package stario
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"io"
|
||
|
|
"sync"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
// chunkedWriter is a test double for io.Writer that accepts at most max
// bytes per Write call and accumulates everything it accepts in buf.
type chunkedWriter struct {
	buf bytes.Buffer // bytes accepted so far, in order
	max int          // per-call limit; zero or negative means no limit
}

// Write appends p to the internal buffer, first truncating it to the
// per-call limit when a positive limit is set and p exceeds it. The returned
// count and error are whatever bytes.Buffer.Write reports for the accepted
// portion, so the count can be smaller than len(p).
func (w *chunkedWriter) Write(p []byte) (int, error) {
	accepted := p
	if w.max > 0 && len(accepted) > w.max {
		accepted = accepted[:w.max]
	}
	return w.buf.Write(accepted)
}
|
||
|
|
|
||
|
|
// blockingReader is a test double for io.Reader whose Read parks until the
// release channel is closed (or receives a value), letting a test observe a
// read that is in flight.
type blockingReader struct {
	started chan struct{} // closed on the first Read call
	release chan struct{} // Read blocks receiving from this channel
	data    []byte        // bytes copied into the caller's buffer on every Read
	once    sync.Once     // ensures started is closed only once
}

// Read signals that a read has begun, waits on release, then copies data
// into p. It always returns a nil error and does not consume data, so every
// call yields the same leading bytes.
func (r *blockingReader) Read(p []byte) (int, error) {
	signalStart := func() { close(r.started) }
	r.once.Do(signalStart)

	<-r.release

	return copy(p, r.data), nil
}
|
||
|
|
|
||
|
|
// blockingWriter is a test double for io.Writer whose Write parks until the
// release channel is closed (or receives a value), letting a test observe a
// write that is in flight.
type blockingWriter struct {
	started chan struct{} // closed on the first Write call
	release chan struct{} // Write blocks receiving from this channel
	once    sync.Once     // ensures started is closed only once
}

// Write signals that a write has begun, waits on release, then reports the
// whole of p as written with a nil error. The bytes themselves are discarded.
func (w *blockingWriter) Write(p []byte) (int, error) {
	signalStart := func() { close(w.started) }
	w.once.Do(signalStart)

	<-w.release

	written := len(p)
	return written, nil
}
|
||
|
|
|
||
|
|
func TestReadFullContext(t *testing.T) {
|
||
|
|
buf := make([]byte, 5)
|
||
|
|
n, err := ReadFullContext(context.Background(), bytes.NewBufferString("hello"), buf)
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("ReadFullContext returned error: %v", err)
|
||
|
|
}
|
||
|
|
if n != 5 || string(buf) != "hello" {
|
||
|
|
t.Fatalf("unexpected payload: n=%d data=%q", n, buf)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestReadFullContextReturnsUnexpectedEOF(t *testing.T) {
|
||
|
|
buf := make([]byte, 5)
|
||
|
|
n, err := ReadFullContext(context.Background(), bytes.NewBufferString("hey"), buf)
|
||
|
|
if !errors.Is(err, io.ErrUnexpectedEOF) {
|
||
|
|
t.Fatalf("expected unexpected EOF, got %v", err)
|
||
|
|
}
|
||
|
|
if n != 3 || string(buf[:n]) != "hey" {
|
||
|
|
t.Fatalf("unexpected payload: n=%d data=%q", n, buf[:n])
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestReadFullContextCanceledWhileBlocked(t *testing.T) {
|
||
|
|
reader := &blockingReader{
|
||
|
|
started: make(chan struct{}),
|
||
|
|
release: make(chan struct{}),
|
||
|
|
data: []byte("hello"),
|
||
|
|
}
|
||
|
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
done := make(chan error, 1)
|
||
|
|
go func() {
|
||
|
|
buf := make([]byte, 5)
|
||
|
|
_, err := ReadFullContext(ctx, reader, buf)
|
||
|
|
done <- err
|
||
|
|
}()
|
||
|
|
|
||
|
|
<-reader.started
|
||
|
|
cancel()
|
||
|
|
|
||
|
|
select {
|
||
|
|
case err := <-done:
|
||
|
|
if !errors.Is(err, context.Canceled) {
|
||
|
|
t.Fatalf("expected context canceled, got %v", err)
|
||
|
|
}
|
||
|
|
case <-time.After(200 * time.Millisecond):
|
||
|
|
t.Fatal("ReadFullContext did not return after cancel")
|
||
|
|
}
|
||
|
|
|
||
|
|
close(reader.release)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWriteFullContext(t *testing.T) {
|
||
|
|
writer := &chunkedWriter{max: 2}
|
||
|
|
n, err := WriteFullContext(context.Background(), writer, []byte("hello"))
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("WriteFullContext returned error: %v", err)
|
||
|
|
}
|
||
|
|
if n != 5 || writer.buf.String() != "hello" {
|
||
|
|
t.Fatalf("unexpected write result: n=%d data=%q", n, writer.buf.String())
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestWriteFullContextCanceledWhileBlocked(t *testing.T) {
|
||
|
|
writer := &blockingWriter{
|
||
|
|
started: make(chan struct{}),
|
||
|
|
release: make(chan struct{}),
|
||
|
|
}
|
||
|
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
done := make(chan error, 1)
|
||
|
|
go func() {
|
||
|
|
_, err := WriteFullContext(ctx, writer, []byte("hello"))
|
||
|
|
done <- err
|
||
|
|
}()
|
||
|
|
|
||
|
|
<-writer.started
|
||
|
|
cancel()
|
||
|
|
|
||
|
|
select {
|
||
|
|
case err := <-done:
|
||
|
|
if !errors.Is(err, context.Canceled) {
|
||
|
|
t.Fatalf("expected context canceled, got %v", err)
|
||
|
|
}
|
||
|
|
case <-time.After(200 * time.Millisecond):
|
||
|
|
t.Fatal("WriteFullContext did not return after cancel")
|
||
|
|
}
|
||
|
|
|
||
|
|
close(writer.release)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestCopyContext(t *testing.T) {
|
||
|
|
var dst bytes.Buffer
|
||
|
|
written, err := CopyContext(context.Background(), &dst, bytes.NewBufferString("hello world"))
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("CopyContext returned error: %v", err)
|
||
|
|
}
|
||
|
|
if written != int64(len("hello world")) || dst.String() != "hello world" {
|
||
|
|
t.Fatalf("unexpected copy result: written=%d data=%q", written, dst.String())
|
||
|
|
}
|
||
|
|
}
|