package stario import ( "bytes" "context" "errors" "io" "sync" "testing" "time" ) type chunkedWriter struct { buf bytes.Buffer max int } func (writer *chunkedWriter) Write(p []byte) (int, error) { if writer.max <= 0 || len(p) <= writer.max { return writer.buf.Write(p) } return writer.buf.Write(p[:writer.max]) } type blockingReader struct { started chan struct{} release chan struct{} data []byte once sync.Once } func (reader *blockingReader) Read(p []byte) (int, error) { reader.once.Do(func() { close(reader.started) }) <-reader.release n := copy(p, reader.data) return n, nil } type blockingWriter struct { started chan struct{} release chan struct{} once sync.Once } func (writer *blockingWriter) Write(p []byte) (int, error) { writer.once.Do(func() { close(writer.started) }) <-writer.release return len(p), nil } func TestReadFullContext(t *testing.T) { buf := make([]byte, 5) n, err := ReadFullContext(context.Background(), bytes.NewBufferString("hello"), buf) if err != nil { t.Fatalf("ReadFullContext returned error: %v", err) } if n != 5 || string(buf) != "hello" { t.Fatalf("unexpected payload: n=%d data=%q", n, buf) } } func TestReadFullContextReturnsUnexpectedEOF(t *testing.T) { buf := make([]byte, 5) n, err := ReadFullContext(context.Background(), bytes.NewBufferString("hey"), buf) if !errors.Is(err, io.ErrUnexpectedEOF) { t.Fatalf("expected unexpected EOF, got %v", err) } if n != 3 || string(buf[:n]) != "hey" { t.Fatalf("unexpected payload: n=%d data=%q", n, buf[:n]) } } func TestReadFullContextCanceledWhileBlocked(t *testing.T) { reader := &blockingReader{ started: make(chan struct{}), release: make(chan struct{}), data: []byte("hello"), } ctx, cancel := context.WithCancel(context.Background()) defer cancel() done := make(chan error, 1) go func() { buf := make([]byte, 5) _, err := ReadFullContext(ctx, reader, buf) done <- err }() <-reader.started cancel() select { case err := <-done: if !errors.Is(err, context.Canceled) { t.Fatalf("expected context canceled, got %v", err) } case <-time.After(200 * time.Millisecond): t.Fatal("ReadFullContext did not return after cancel") } close(reader.release) } func TestWriteFullContext(t *testing.T) { writer := &chunkedWriter{max: 2} n, err := WriteFullContext(context.Background(), writer, []byte("hello")) if err != nil { t.Fatalf("WriteFullContext returned error: %v", err) } if n != 5 || writer.buf.String() != "hello" { t.Fatalf("unexpected write result: n=%d data=%q", n, writer.buf.String()) } } func TestWriteFullContextCanceledWhileBlocked(t *testing.T) { writer := &blockingWriter{ started: make(chan struct{}), release: make(chan struct{}), } ctx, cancel := context.WithCancel(context.Background()) defer cancel() done := make(chan error, 1) go func() { _, err := WriteFullContext(ctx, writer, []byte("hello")) done <- err }() <-writer.started cancel() select { case err := <-done: if !errors.Is(err, context.Canceled) { t.Fatalf("expected context canceled, got %v", err) } case <-time.After(200 * time.Millisecond): t.Fatal("WriteFullContext did not return after cancel") } close(writer.release) } func TestCopyContext(t *testing.T) { var dst bytes.Buffer written, err := CopyContext(context.Background(), &dst, bytes.NewBufferString("hello world")) if err != nil { t.Fatalf("CopyContext returned error: %v", err) } if written != int64(len("hello world")) || dst.String() != "hello world" { t.Fatalf("unexpected copy result: written=%d data=%q", written, dst.String()) } }