notify/stream_helper_test.go

394 lines
11 KiB
Go
Raw Permalink Normal View History

package notify
import (
"bytes"
"context"
"errors"
"io"
"net"
"sync"
"testing"
"time"
)
// streamHelperMock is a non-blocking in-memory stream double: reads are served
// from readBuf, writes are captured in writeBuf, and the counters record how
// often CloseWrite and Close were invoked.
type streamHelperMock struct {
	readBuf  *bytes.Reader
	writeBuf bytes.Buffer

	closeWriteCalls int
	closeCalls      int
}

// newStreamHelperMock builds a mock whose Read side yields readData.
func newStreamHelperMock(readData []byte) *streamHelperMock {
	mock := new(streamHelperMock)
	mock.readBuf = bytes.NewReader(readData)
	return mock
}

// Read drains readBuf into p. A nil receiver or a mock without a read buffer
// reports io.EOF straight away.
func (s *streamHelperMock) Read(p []byte) (int, error) {
	if s != nil && s.readBuf != nil {
		return s.readBuf.Read(p)
	}
	return 0, io.EOF
}

// Write appends p to writeBuf. A nil receiver reports io.ErrClosedPipe.
func (s *streamHelperMock) Write(p []byte) (int, error) {
	if s != nil {
		return s.writeBuf.Write(p)
	}
	return 0, io.ErrClosedPipe
}

// Close counts the call on a non-nil receiver and always returns nil.
func (s *streamHelperMock) Close() error {
	if s == nil {
		return nil
	}
	s.closeCalls++
	return nil
}
// ID returns the fixed identifier "helper-stream".
func (s *streamHelperMock) ID() string {
	return "helper-stream"
}

// Channel always reports StreamDataChannel.
func (s *streamHelperMock) Channel() StreamChannel {
	return StreamDataChannel
}

// Metadata returns nil; the mock carries no metadata.
func (s *streamHelperMock) Metadata() StreamMetadata {
	return nil
}

// Context returns context.Background().
func (s *streamHelperMock) Context() context.Context {
	return context.Background()
}

// LogicalConn returns nil; the mock is not attached to a logical connection.
func (s *streamHelperMock) LogicalConn() *LogicalConn {
	return nil
}

// TransportConn returns nil; the mock is not attached to a transport.
func (s *streamHelperMock) TransportConn() *TransportConn {
	return nil
}

// TransportGeneration always reports 0.
func (s *streamHelperMock) TransportGeneration() uint64 {
	return 0
}

// LocalAddr returns nil.
func (s *streamHelperMock) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil.
func (s *streamHelperMock) RemoteAddr() net.Addr {
	return nil
}

// Reset ignores its argument and returns nil.
func (s *streamHelperMock) Reset(_ error) error {
	return nil
}

// SetDeadline ignores its argument and returns nil.
func (s *streamHelperMock) SetDeadline(_ time.Time) error {
	return nil
}

// SetReadDeadline ignores its argument and returns nil.
func (s *streamHelperMock) SetReadDeadline(_ time.Time) error { return nil }

// SetWriteDeadline ignores its argument and returns nil.
func (s *streamHelperMock) SetWriteDeadline(_ time.Time) error { return nil }
// CloseWrite counts the call on a non-nil receiver and always returns nil.
func (s *streamHelperMock) CloseWrite() error {
	if s == nil {
		return nil
	}
	s.closeWriteCalls++
	return nil
}
// TestCopyToStreamClosesWriteAndCopiesPayload copies a payload into a mock
// stream using a deliberately small buffer and checks the byte count, the
// bytes captured by the stream, and that CloseWrite ran exactly once while
// Close was not called (both CloseWrite and CloseStream are requested).
func TestCopyToStreamClosesWriteAndCopiesPayload(t *testing.T) {
	payload := bytes.Repeat([]byte("helper-copy-to-stream-"), 32)
	stream := newStreamHelperMock(nil)
	opts := StreamCopyOptions{
		BufferSize:  17,
		CloseWrite:  true,
		CloseStream: true,
	}

	copied, err := CopyToStream(context.Background(), stream, bytes.NewReader(payload), opts)
	if err != nil {
		t.Fatalf("CopyToStream failed: %v", err)
	}
	if want := int64(len(payload)); copied != want {
		t.Fatalf("copied bytes = %d, want %d", copied, want)
	}
	if written := stream.writeBuf.Bytes(); !bytes.Equal(written, payload) {
		t.Fatalf("stream write payload mismatch: got %d want %d", len(written), len(payload))
	}
	if stream.closeWriteCalls != 1 {
		t.Fatalf("CloseWrite calls = %d, want %d", stream.closeWriteCalls, 1)
	}
	if stream.closeCalls != 0 {
		t.Fatalf("Close calls = %d, want 0", stream.closeCalls)
	}
}
// TestCopyFromStreamCopiesPayload drains a mock stream into a buffer using a
// small copy buffer and checks both the reported byte count and the bytes
// that arrived in the destination.
func TestCopyFromStreamCopiesPayload(t *testing.T) {
	payload := bytes.Repeat([]byte("helper-copy-from-stream-"), 24)
	source := newStreamHelperMock(payload)
	opts := StreamCopyOptions{BufferSize: 19}

	var sink bytes.Buffer
	copied, err := CopyFromStream(context.Background(), &sink, source, opts)
	if err != nil {
		t.Fatalf("CopyFromStream failed: %v", err)
	}
	if want := int64(len(payload)); copied != want {
		t.Fatalf("copied bytes = %d, want %d", copied, want)
	}
	if received := sink.Bytes(); !bytes.Equal(received, payload) {
		t.Fatalf("copied payload mismatch: got %d want %d", len(received), len(payload))
	}
}
// TestBridgeStreamCopiesBothDirections bridges two non-blocking mocks and
// checks that each side received the other's data, that CloseWrite ran once
// on each side, and that the peer was not closed.
func TestBridgeStreamCopiesBothDirections(t *testing.T) {
	const (
		streamData = "from-stream"
		peerData   = "from-peer"
	)
	stream := newStreamHelperMock([]byte(streamData))
	peer := newStreamHelperMock([]byte(peerData))

	err := BridgeStream(context.Background(), stream, peer, StreamBridgeOptions{BufferSize: 4})
	if err != nil {
		t.Fatalf("BridgeStream failed: %v", err)
	}

	if received := stream.writeBuf.String(); received != peerData {
		t.Fatalf("stream received payload = %q, want %q", received, peerData)
	}
	if received := peer.writeBuf.String(); received != streamData {
		t.Fatalf("peer received payload = %q, want %q", received, streamData)
	}
	if stream.closeWriteCalls != 1 {
		t.Fatalf("stream CloseWrite calls = %d, want %d", stream.closeWriteCalls, 1)
	}
	if peer.closeWriteCalls != 1 {
		t.Fatalf("peer CloseWrite calls = %d, want %d", peer.closeWriteCalls, 1)
	}
	if peer.closeCalls != 0 {
		t.Fatalf("peer Close calls = %d, want 0", peer.closeCalls)
	}
}
// blockingStreamHelperMock is a stream double whose Read blocks once its
// buffered data is exhausted, until closedCh is closed (Close and Reset both
// close it). It also records call counts and the error passed to Reset.
type blockingStreamHelperMock struct {
	readBuf         *bytes.Reader // data served by Read; nil means "no data, block immediately"
	writeBuf        bytes.Buffer  // captures everything passed to Write
	closeWriteCalls int
	closeCalls      int
	resetCalls      int
	resetErr        error         // last error handed to Reset; returned by an unblocked Read
	closedCh        chan struct{} // closed once by Close or Reset to release blocked reads
	closeOnce       sync.Once
	readStarted     chan struct{} // closed on the first Read call
	readStartOnce   sync.Once
}

// newBlockingStreamHelperMock builds a blocking mock. When readData is empty
// no reader is installed, so the very first Read blocks until the mock is
// closed or reset.
func newBlockingStreamHelperMock(readData []byte) *blockingStreamHelperMock {
	var reader *bytes.Reader
	if len(readData) > 0 {
		reader = bytes.NewReader(readData)
	}
	return &blockingStreamHelperMock{
		readBuf:     reader,
		closedCh:    make(chan struct{}),
		readStarted: make(chan struct{}),
	}
}

// Read signals readStarted on its first call, serves any buffered data, and
// then blocks until closedCh is closed. Once released it returns the error
// recorded by Reset, or io.EOF when no reset error is set. A nil receiver
// returns io.EOF immediately.
func (s *blockingStreamHelperMock) Read(p []byte) (int, error) {
	// The nil guard has to come before any field access: previously
	// readStartOnce was touched first, so a nil receiver panicked and the
	// guard below it could never take effect.
	if s == nil {
		return 0, io.EOF
	}
	s.readStartOnce.Do(func() {
		close(s.readStarted)
	})
	if s.readBuf != nil && s.readBuf.Len() > 0 {
		return s.readBuf.Read(p)
	}
	// Out of data: park until Close or Reset releases the reader.
	<-s.closedCh
	if s.resetErr != nil {
		return 0, s.resetErr
	}
	return 0, io.EOF
}
// Write appends p to writeBuf. A nil receiver reports io.ErrClosedPipe.
func (s *blockingStreamHelperMock) Write(p []byte) (int, error) {
	if s != nil {
		return s.writeBuf.Write(p)
	}
	return 0, io.ErrClosedPipe
}
// Close counts the call and closes closedCh (at most once across Close and
// Reset, via closeOnce). It always returns nil; a nil receiver is a no-op.
func (s *blockingStreamHelperMock) Close() error {
	if s == nil {
		return nil
	}
	s.closeCalls++
	s.closeOnce.Do(func() { close(s.closedCh) })
	return nil
}
// CloseWrite counts the call on a non-nil receiver and always returns nil.
// It does not close closedCh.
func (s *blockingStreamHelperMock) CloseWrite() error {
	if s == nil {
		return nil
	}
	s.closeWriteCalls++
	return nil
}
// Reset counts the call, records err in resetErr, and closes closedCh (at
// most once across Close and Reset, via closeOnce). It always returns nil; a
// nil receiver is a no-op.
func (s *blockingStreamHelperMock) Reset(err error) error {
	if s == nil {
		return nil
	}
	s.resetCalls++
	// Store the error before closing closedCh so it is set by the time
	// closedCh is closed.
	s.resetErr = err
	s.closeOnce.Do(func() { close(s.closedCh) })
	return nil
}
// ID returns the fixed identifier "blocking-helper-stream".
func (s *blockingStreamHelperMock) ID() string {
	return "blocking-helper-stream"
}

// Channel always reports StreamDataChannel.
func (s *blockingStreamHelperMock) Channel() StreamChannel {
	return StreamDataChannel
}

// Metadata returns nil; the mock carries no metadata.
func (s *blockingStreamHelperMock) Metadata() StreamMetadata {
	return nil
}

// Context returns context.Background().
func (s *blockingStreamHelperMock) Context() context.Context {
	return context.Background()
}

// LogicalConn returns nil; the mock is not attached to a logical connection.
func (s *blockingStreamHelperMock) LogicalConn() *LogicalConn {
	return nil
}

// TransportConn returns nil; the mock is not attached to a transport.
func (s *blockingStreamHelperMock) TransportConn() *TransportConn {
	return nil
}

// TransportGeneration always reports 0.
func (s *blockingStreamHelperMock) TransportGeneration() uint64 {
	return 0
}

// LocalAddr returns nil.
func (s *blockingStreamHelperMock) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil.
func (s *blockingStreamHelperMock) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline ignores its argument and returns nil.
func (s *blockingStreamHelperMock) SetDeadline(_ time.Time) error {
	return nil
}

// SetReadDeadline ignores its argument and returns nil.
func (s *blockingStreamHelperMock) SetReadDeadline(_ time.Time) error { return nil }

// SetWriteDeadline ignores its argument and returns nil.
func (s *blockingStreamHelperMock) SetWriteDeadline(_ time.Time) error { return nil }
// blockingPeerHelperMock is a peer double whose Read blocks until the mock is
// closed and whose Write either succeeds trivially or fails with writeErr.
type blockingPeerHelperMock struct {
	writeErr   error
	closeCalls int

	closedCh  chan struct{}
	closeOnce sync.Once

	readStarted   chan struct{}
	readStartOnce sync.Once
}

// newBlockingPeerHelperMock builds a peer mock; a non-nil writeErr makes every
// Write fail with that error.
func newBlockingPeerHelperMock(writeErr error) *blockingPeerHelperMock {
	peer := new(blockingPeerHelperMock)
	peer.writeErr = writeErr
	peer.closedCh = make(chan struct{})
	peer.readStarted = make(chan struct{})
	return peer
}

// Read signals readStarted on its first call, then blocks until closedCh is
// closed and reports io.EOF. It never produces data.
func (p *blockingPeerHelperMock) Read(buf []byte) (int, error) {
	p.readStartOnce.Do(func() { close(p.readStarted) })
	<-p.closedCh
	return 0, io.EOF
}

// Write discards buf and reports it fully written, unless writeErr is set, in
// which case it reports zero bytes and that error.
func (p *blockingPeerHelperMock) Write(buf []byte) (int, error) {
	if p.writeErr == nil {
		return len(buf), nil
	}
	return 0, p.writeErr
}

// Close counts the call and closes closedCh (at most once, via closeOnce).
// It always returns nil; a nil receiver is a no-op.
func (p *blockingPeerHelperMock) Close() error {
	if p == nil {
		return nil
	}
	p.closeCalls++
	p.closeOnce.Do(func() { close(p.closedCh) })
	return nil
}
func TestBridgeStreamResetOnCopyError(t *testing.T) {
writeErr := errors.New("bridge-peer-write-failed")
stream := newBlockingStreamHelperMock([]byte("from-stream"))
peer := newBlockingPeerHelperMock(writeErr)
errCh := make(chan error, 1)
go func() {
errCh <- BridgeStream(context.Background(), stream, peer, StreamBridgeOptions{
BufferSize: 4,
ResetOnCopyError: true,
})
}()
select {
case err := <-errCh:
if !errors.Is(err, writeErr) {
t.Fatalf("BridgeStream error = %v, want %v", err, writeErr)
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for BridgeStream write error")
}
if got, want := stream.resetCalls, 1; got != want {
t.Fatalf("stream Reset calls = %d, want %d", got, want)
}
if !errors.Is(stream.resetErr, writeErr) {
t.Fatalf("stream reset error = %v, want %v", stream.resetErr, writeErr)
}
if got := stream.closeCalls; got != 0 {
t.Fatalf("stream Close calls = %d, want 0", got)
}
if got, want := peer.closeCalls, 1; got != want {
t.Fatalf("peer Close calls = %d, want %d", got, want)
}
}
func TestBridgeStreamCopyErrorClosesStreamWithoutReset(t *testing.T) {
writeErr := errors.New("bridge-peer-write-failed")
stream := newBlockingStreamHelperMock([]byte("from-stream"))
peer := newBlockingPeerHelperMock(writeErr)
errCh := make(chan error, 1)
go func() {
errCh <- BridgeStream(context.Background(), stream, peer, StreamBridgeOptions{
BufferSize: 4,
})
}()
select {
case err := <-errCh:
if !errors.Is(err, writeErr) {
t.Fatalf("BridgeStream error = %v, want %v", err, writeErr)
}
case <-time.After(time.Second):
t.Fatal("timed out waiting for BridgeStream write error")
}
if got := stream.resetCalls; got != 0 {
t.Fatalf("stream Reset calls = %d, want 0", got)
}
if got, want := stream.closeCalls, 1; got != want {
t.Fatalf("stream Close calls = %d, want %d", got, want)
}
if got, want := peer.closeCalls, 1; got != want {
t.Fatalf("peer Close calls = %d, want %d", got, want)
}
}
// TestBridgeStreamContextCancelUnblocksBlockedCopies starts a bridge between
// two mocks whose reads block, waits until both reads have begun, then
// cancels the context. It expects BridgeStream to return context.Canceled,
// the stream and the peer each to be Closed once, and no Reset on the stream.
func TestBridgeStreamContextCancelUnblocksBlockedCopies(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	stream := newBlockingStreamHelperMock(nil)
	peer := newBlockingPeerHelperMock(nil)
	opts := StreamBridgeOptions{BufferSize: 4}

	done := make(chan error, 1)
	go func() {
		done <- BridgeStream(ctx, stream, peer, opts)
	}()

	// Only cancel once both directions have entered Read.
	waitHelperReadStarted(t, stream.readStarted, time.Second)
	waitHelperReadStarted(t, peer.readStarted, time.Second)
	cancel()

	select {
	case err := <-done:
		if !errors.Is(err, context.Canceled) {
			t.Fatalf("BridgeStream error = %v, want %v", err, context.Canceled)
		}
	case <-time.After(time.Second):
		t.Fatal("timed out waiting for BridgeStream cancel")
	}

	if stream.resetCalls != 0 {
		t.Fatalf("stream Reset calls = %d, want 0", stream.resetCalls)
	}
	if stream.closeCalls != 1 {
		t.Fatalf("stream Close calls = %d, want %d", stream.closeCalls, 1)
	}
	if peer.closeCalls != 1 {
		t.Fatalf("peer Close calls = %d, want %d", peer.closeCalls, 1)
	}
}
// waitHelperReadStarted blocks until started is closed (or receives a value)
// and fails the test if that does not happen within timeout.
func waitHelperReadStarted(t *testing.T, started <-chan struct{}, timeout time.Duration) {
	t.Helper()

	expired := time.After(timeout)
	select {
	case <-expired:
		t.Fatal("timed out waiting for helper read to start")
	case <-started:
		return
	}
}