package notify import ( "bytes" "context" "errors" "io" "net" "sync" "testing" "time" ) type streamHelperMock struct { readBuf *bytes.Reader writeBuf bytes.Buffer closeWriteCalls int closeCalls int } func newStreamHelperMock(readData []byte) *streamHelperMock { return &streamHelperMock{readBuf: bytes.NewReader(readData)} } func (s *streamHelperMock) Read(p []byte) (int, error) { if s == nil || s.readBuf == nil { return 0, io.EOF } return s.readBuf.Read(p) } func (s *streamHelperMock) Write(p []byte) (int, error) { if s == nil { return 0, io.ErrClosedPipe } return s.writeBuf.Write(p) } func (s *streamHelperMock) Close() error { if s != nil { s.closeCalls++ } return nil } func (s *streamHelperMock) ID() string { return "helper-stream" } func (s *streamHelperMock) Channel() StreamChannel { return StreamDataChannel } func (s *streamHelperMock) Metadata() StreamMetadata { return nil } func (s *streamHelperMock) Context() context.Context { return context.Background() } func (s *streamHelperMock) LogicalConn() *LogicalConn { return nil } func (s *streamHelperMock) TransportConn() *TransportConn { return nil } func (s *streamHelperMock) TransportGeneration() uint64 { return 0 } func (s *streamHelperMock) LocalAddr() net.Addr { return nil } func (s *streamHelperMock) RemoteAddr() net.Addr { return nil } func (s *streamHelperMock) Reset(error) error { return nil } func (s *streamHelperMock) SetDeadline(time.Time) error { return nil } func (s *streamHelperMock) SetReadDeadline(time.Time) error { return nil } func (s *streamHelperMock) SetWriteDeadline(time.Time) error { return nil } func (s *streamHelperMock) CloseWrite() error { if s != nil { s.closeWriteCalls++ } return nil } func TestCopyToStreamClosesWriteAndCopiesPayload(t *testing.T) { stream := newStreamHelperMock(nil) payload := bytes.Repeat([]byte("helper-copy-to-stream-"), 32) n, err := CopyToStream(context.Background(), stream, bytes.NewReader(payload), StreamCopyOptions{ BufferSize: 17, CloseWrite: true, CloseStream: true, }) if err != nil { t.Fatalf("CopyToStream failed: %v", err) } if got, want := n, int64(len(payload)); got != want { t.Fatalf("copied bytes = %d, want %d", got, want) } if got := stream.writeBuf.Bytes(); !bytes.Equal(got, payload) { t.Fatalf("stream write payload mismatch: got %d want %d", len(got), len(payload)) } if got, want := stream.closeWriteCalls, 1; got != want { t.Fatalf("CloseWrite calls = %d, want %d", got, want) } if got := stream.closeCalls; got != 0 { t.Fatalf("Close calls = %d, want 0", got) } } func TestCopyFromStreamCopiesPayload(t *testing.T) { payload := bytes.Repeat([]byte("helper-copy-from-stream-"), 24) stream := newStreamHelperMock(payload) var dst bytes.Buffer n, err := CopyFromStream(context.Background(), &dst, stream, StreamCopyOptions{ BufferSize: 19, }) if err != nil { t.Fatalf("CopyFromStream failed: %v", err) } if got, want := n, int64(len(payload)); got != want { t.Fatalf("copied bytes = %d, want %d", got, want) } if got := dst.Bytes(); !bytes.Equal(got, payload) { t.Fatalf("copied payload mismatch: got %d want %d", len(got), len(payload)) } } func TestBridgeStreamCopiesBothDirections(t *testing.T) { stream := newStreamHelperMock([]byte("from-stream")) peer := newStreamHelperMock([]byte("from-peer")) if err := BridgeStream(context.Background(), stream, peer, StreamBridgeOptions{ BufferSize: 4, }); err != nil { t.Fatalf("BridgeStream failed: %v", err) } if got, want := stream.writeBuf.String(), "from-peer"; got != want { t.Fatalf("stream received payload = %q, want %q", got, want) } if got, want := peer.writeBuf.String(), "from-stream"; got != want { t.Fatalf("peer received payload = %q, want %q", got, want) } if got, want := stream.closeWriteCalls, 1; got != want { t.Fatalf("stream CloseWrite calls = %d, want %d", got, want) } if got, want := peer.closeWriteCalls, 1; got != want { t.Fatalf("peer CloseWrite calls = %d, want %d", got, want) } if got := peer.closeCalls; got != 0 { t.Fatalf("peer Close calls = %d, want 0", got) } } type blockingStreamHelperMock struct { readBuf *bytes.Reader writeBuf bytes.Buffer closeWriteCalls int closeCalls int resetCalls int resetErr error closedCh chan struct{} closeOnce sync.Once readStarted chan struct{} readStartOnce sync.Once } func newBlockingStreamHelperMock(readData []byte) *blockingStreamHelperMock { var reader *bytes.Reader if len(readData) > 0 { reader = bytes.NewReader(readData) } return &blockingStreamHelperMock{ readBuf: reader, closedCh: make(chan struct{}), readStarted: make(chan struct{}), } } func (s *blockingStreamHelperMock) Read(p []byte) (int, error) { s.readStartOnce.Do(func() { close(s.readStarted) }) if s == nil { return 0, io.EOF } if s.readBuf != nil && s.readBuf.Len() > 0 { return s.readBuf.Read(p) } <-s.closedCh if s.resetErr != nil { return 0, s.resetErr } return 0, io.EOF } func (s *blockingStreamHelperMock) Write(p []byte) (int, error) { if s == nil { return 0, io.ErrClosedPipe } return s.writeBuf.Write(p) } func (s *blockingStreamHelperMock) Close() error { if s != nil { s.closeCalls++ s.closeOnce.Do(func() { close(s.closedCh) }) } return nil } func (s *blockingStreamHelperMock) CloseWrite() error { if s != nil { s.closeWriteCalls++ } return nil } func (s *blockingStreamHelperMock) Reset(err error) error { if s != nil { s.resetCalls++ s.resetErr = err s.closeOnce.Do(func() { close(s.closedCh) }) } return nil } func (s *blockingStreamHelperMock) ID() string { return "blocking-helper-stream" } func (s *blockingStreamHelperMock) Channel() StreamChannel { return StreamDataChannel } func (s *blockingStreamHelperMock) Metadata() StreamMetadata { return nil } func (s *blockingStreamHelperMock) Context() context.Context { return context.Background() } func (s *blockingStreamHelperMock) LogicalConn() *LogicalConn { return nil } func (s *blockingStreamHelperMock) TransportConn() *TransportConn { return nil } func (s *blockingStreamHelperMock) TransportGeneration() uint64 { return 0 } func (s *blockingStreamHelperMock) LocalAddr() net.Addr { return nil } func (s *blockingStreamHelperMock) RemoteAddr() net.Addr { return nil } func (s *blockingStreamHelperMock) SetDeadline(time.Time) error { return nil } func (s *blockingStreamHelperMock) SetReadDeadline(time.Time) error { return nil } func (s *blockingStreamHelperMock) SetWriteDeadline(time.Time) error { return nil } type blockingPeerHelperMock struct { writeErr error closeCalls int closedCh chan struct{} closeOnce sync.Once readStarted chan struct{} readStartOnce sync.Once } func newBlockingPeerHelperMock(writeErr error) *blockingPeerHelperMock { return &blockingPeerHelperMock{ writeErr: writeErr, closedCh: make(chan struct{}), readStarted: make(chan struct{}), } } func (p *blockingPeerHelperMock) Read(buf []byte) (int, error) { p.readStartOnce.Do(func() { close(p.readStarted) }) <-p.closedCh return 0, io.EOF } func (p *blockingPeerHelperMock) Write(buf []byte) (int, error) { if p.writeErr != nil { return 0, p.writeErr } return len(buf), nil } func (p *blockingPeerHelperMock) Close() error { if p != nil { p.closeCalls++ p.closeOnce.Do(func() { close(p.closedCh) }) } return nil } func TestBridgeStreamResetOnCopyError(t *testing.T) { writeErr := errors.New("bridge-peer-write-failed") stream := newBlockingStreamHelperMock([]byte("from-stream")) peer := newBlockingPeerHelperMock(writeErr) errCh := make(chan error, 1) go func() { errCh <- BridgeStream(context.Background(), stream, peer, StreamBridgeOptions{ BufferSize: 4, ResetOnCopyError: true, }) }() select { case err := <-errCh: if !errors.Is(err, writeErr) { t.Fatalf("BridgeStream error = %v, want %v", err, writeErr) } case <-time.After(time.Second): t.Fatal("timed out waiting for BridgeStream write error") } if got, want := stream.resetCalls, 1; got != want { t.Fatalf("stream Reset calls = %d, want %d", got, want) } if !errors.Is(stream.resetErr, writeErr) { t.Fatalf("stream reset error = %v, want %v", stream.resetErr, writeErr) } if got := stream.closeCalls; got != 0 { t.Fatalf("stream Close calls = %d, want 0", got) } if got, want := peer.closeCalls, 1; got != want { t.Fatalf("peer Close calls = %d, want %d", got, want) } } func TestBridgeStreamCopyErrorClosesStreamWithoutReset(t *testing.T) { writeErr := errors.New("bridge-peer-write-failed") stream := newBlockingStreamHelperMock([]byte("from-stream")) peer := newBlockingPeerHelperMock(writeErr) errCh := make(chan error, 1) go func() { errCh <- BridgeStream(context.Background(), stream, peer, StreamBridgeOptions{ BufferSize: 4, }) }() select { case err := <-errCh: if !errors.Is(err, writeErr) { t.Fatalf("BridgeStream error = %v, want %v", err, writeErr) } case <-time.After(time.Second): t.Fatal("timed out waiting for BridgeStream write error") } if got := stream.resetCalls; got != 0 { t.Fatalf("stream Reset calls = %d, want 0", got) } if got, want := stream.closeCalls, 1; got != want { t.Fatalf("stream Close calls = %d, want %d", got, want) } if got, want := peer.closeCalls, 1; got != want { t.Fatalf("peer Close calls = %d, want %d", got, want) } } func TestBridgeStreamContextCancelUnblocksBlockedCopies(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() stream := newBlockingStreamHelperMock(nil) peer := newBlockingPeerHelperMock(nil) errCh := make(chan error, 1) go func() { errCh <- BridgeStream(ctx, stream, peer, StreamBridgeOptions{ BufferSize: 4, }) }() waitHelperReadStarted(t, stream.readStarted, time.Second) waitHelperReadStarted(t, peer.readStarted, time.Second) cancel() select { case err := <-errCh: if !errors.Is(err, context.Canceled) { t.Fatalf("BridgeStream error = %v, want %v", err, context.Canceled) } case <-time.After(time.Second): t.Fatal("timed out waiting for BridgeStream cancel") } if got := stream.resetCalls; got != 0 { t.Fatalf("stream Reset calls = %d, want 0", got) } if got, want := stream.closeCalls, 1; got != want { t.Fatalf("stream Close calls = %d, want %d", got, want) } if got, want := peer.closeCalls, 1; got != want { t.Fatalf("peer Close calls = %d, want %d", got, want) } } func waitHelperReadStarted(t *testing.T, started <-chan struct{}, timeout time.Duration) { t.Helper() select { case <-started: case <-time.After(timeout): t.Fatal("timed out waiting for helper read to start") } }