394 lines
11 KiB
Go
394 lines
11 KiB
Go
|
|
package notify
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"context"
|
||
|
|
"errors"
|
||
|
|
"io"
|
||
|
|
"net"
|
||
|
|
"sync"
|
||
|
|
"testing"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
// streamHelperMock is an in-memory stream double. Read serves bytes from a
// fixed slice, Write appends to writeBuf, and Close/CloseWrite invocations are
// counted so tests can assert on how the code under test shuts the stream down.
type streamHelperMock struct {
	readBuf         *bytes.Reader // source for Read; nil makes Read report io.EOF
	writeBuf        bytes.Buffer  // everything passed to Write
	closeWriteCalls int           // number of CloseWrite invocations
	closeCalls      int           // number of Close invocations
}

// newStreamHelperMock returns a mock whose Read yields readData and then io.EOF.
func newStreamHelperMock(readData []byte) *streamHelperMock {
	m := new(streamHelperMock)
	m.readBuf = bytes.NewReader(readData)
	return m
}

// Read drains the configured read data. A nil mock, or one built without a
// reader, reports io.EOF straight away.
func (s *streamHelperMock) Read(p []byte) (int, error) {
	if s != nil && s.readBuf != nil {
		return s.readBuf.Read(p)
	}
	return 0, io.EOF
}

// Write appends p to writeBuf. Writing through a nil mock fails with
// io.ErrClosedPipe.
func (s *streamHelperMock) Write(p []byte) (int, error) {
	if s != nil {
		return s.writeBuf.Write(p)
	}
	return 0, io.ErrClosedPipe
}

// Close only records the call. It never fails and is safe on a nil mock.
func (s *streamHelperMock) Close() error {
	if s == nil {
		return nil
	}
	s.closeCalls++
	return nil
}
|
||
|
|
|
||
|
|
func (s *streamHelperMock) ID() string { return "helper-stream" }
|
||
|
|
func (s *streamHelperMock) Channel() StreamChannel { return StreamDataChannel }
|
||
|
|
func (s *streamHelperMock) Metadata() StreamMetadata { return nil }
|
||
|
|
func (s *streamHelperMock) Context() context.Context { return context.Background() }
|
||
|
|
func (s *streamHelperMock) LogicalConn() *LogicalConn { return nil }
|
||
|
|
func (s *streamHelperMock) TransportConn() *TransportConn { return nil }
|
||
|
|
func (s *streamHelperMock) TransportGeneration() uint64 { return 0 }
|
||
|
|
func (s *streamHelperMock) LocalAddr() net.Addr { return nil }
|
||
|
|
func (s *streamHelperMock) RemoteAddr() net.Addr { return nil }
|
||
|
|
func (s *streamHelperMock) Reset(error) error { return nil }
|
||
|
|
func (s *streamHelperMock) SetDeadline(time.Time) error { return nil }
|
||
|
|
func (s *streamHelperMock) SetReadDeadline(time.Time) error {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
func (s *streamHelperMock) SetWriteDeadline(time.Time) error {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *streamHelperMock) CloseWrite() error {
|
||
|
|
if s != nil {
|
||
|
|
s.closeWriteCalls++
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestCopyToStreamClosesWriteAndCopiesPayload(t *testing.T) {
|
||
|
|
stream := newStreamHelperMock(nil)
|
||
|
|
payload := bytes.Repeat([]byte("helper-copy-to-stream-"), 32)
|
||
|
|
|
||
|
|
n, err := CopyToStream(context.Background(), stream, bytes.NewReader(payload), StreamCopyOptions{
|
||
|
|
BufferSize: 17,
|
||
|
|
CloseWrite: true,
|
||
|
|
CloseStream: true,
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("CopyToStream failed: %v", err)
|
||
|
|
}
|
||
|
|
if got, want := n, int64(len(payload)); got != want {
|
||
|
|
t.Fatalf("copied bytes = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got := stream.writeBuf.Bytes(); !bytes.Equal(got, payload) {
|
||
|
|
t.Fatalf("stream write payload mismatch: got %d want %d", len(got), len(payload))
|
||
|
|
}
|
||
|
|
if got, want := stream.closeWriteCalls, 1; got != want {
|
||
|
|
t.Fatalf("CloseWrite calls = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got := stream.closeCalls; got != 0 {
|
||
|
|
t.Fatalf("Close calls = %d, want 0", got)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestCopyFromStreamCopiesPayload(t *testing.T) {
|
||
|
|
payload := bytes.Repeat([]byte("helper-copy-from-stream-"), 24)
|
||
|
|
stream := newStreamHelperMock(payload)
|
||
|
|
var dst bytes.Buffer
|
||
|
|
|
||
|
|
n, err := CopyFromStream(context.Background(), &dst, stream, StreamCopyOptions{
|
||
|
|
BufferSize: 19,
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
t.Fatalf("CopyFromStream failed: %v", err)
|
||
|
|
}
|
||
|
|
if got, want := n, int64(len(payload)); got != want {
|
||
|
|
t.Fatalf("copied bytes = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got := dst.Bytes(); !bytes.Equal(got, payload) {
|
||
|
|
t.Fatalf("copied payload mismatch: got %d want %d", len(got), len(payload))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestBridgeStreamCopiesBothDirections(t *testing.T) {
|
||
|
|
stream := newStreamHelperMock([]byte("from-stream"))
|
||
|
|
peer := newStreamHelperMock([]byte("from-peer"))
|
||
|
|
|
||
|
|
if err := BridgeStream(context.Background(), stream, peer, StreamBridgeOptions{
|
||
|
|
BufferSize: 4,
|
||
|
|
}); err != nil {
|
||
|
|
t.Fatalf("BridgeStream failed: %v", err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if got, want := stream.writeBuf.String(), "from-peer"; got != want {
|
||
|
|
t.Fatalf("stream received payload = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := peer.writeBuf.String(), "from-stream"; got != want {
|
||
|
|
t.Fatalf("peer received payload = %q, want %q", got, want)
|
||
|
|
}
|
||
|
|
if got, want := stream.closeWriteCalls, 1; got != want {
|
||
|
|
t.Fatalf("stream CloseWrite calls = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got, want := peer.closeWriteCalls, 1; got != want {
|
||
|
|
t.Fatalf("peer CloseWrite calls = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got := peer.closeCalls; got != 0 {
|
||
|
|
t.Fatalf("peer Close calls = %d, want 0", got)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// blockingStreamHelperMock is a stream double whose Read first returns the
// configured data and then blocks until the mock is closed or reset. Once
// unblocked, Read reports the error recorded by Reset, or io.EOF if none was
// recorded. readStarted is closed on the first Read call so a test can wait
// until a reader has entered Read.
type blockingStreamHelperMock struct {
	readBuf         *bytes.Reader // initial data for Read; nil when none was given
	writeBuf        bytes.Buffer  // everything passed to Write
	closeWriteCalls int           // number of CloseWrite invocations
	closeCalls      int           // number of Close invocations
	resetCalls      int           // number of Reset invocations
	resetErr        error         // most recent error passed to Reset
	closedCh        chan struct{} // closed exactly once, by Close or Reset
	closeOnce       sync.Once     // guards close(closedCh)
	readStarted     chan struct{} // closed on the first Read call
	readStartOnce   sync.Once     // guards close(readStarted)
}

// newBlockingStreamHelperMock returns a mock whose Read yields readData (if
// any) and then blocks until Close or Reset is called.
func newBlockingStreamHelperMock(readData []byte) *blockingStreamHelperMock {
	var reader *bytes.Reader
	if len(readData) > 0 {
		reader = bytes.NewReader(readData)
	}
	return &blockingStreamHelperMock{
		readBuf:     reader,
		closedCh:    make(chan struct{}),
		readStarted: make(chan struct{}),
	}
}

// Read signals readStarted, serves any remaining initial data, and otherwise
// blocks until the mock is closed or reset. After unblocking it returns the
// recorded reset error, or io.EOF. A nil mock reports io.EOF.
func (s *blockingStreamHelperMock) Read(p []byte) (int, error) {
	// The nil guard must come first: readStartOnce.Do below dereferences s.
	if s == nil {
		return 0, io.EOF
	}
	s.readStartOnce.Do(func() {
		close(s.readStarted)
	})
	if s.readBuf != nil && s.readBuf.Len() > 0 {
		return s.readBuf.Read(p)
	}
	<-s.closedCh
	if s.resetErr != nil {
		return 0, s.resetErr
	}
	return 0, io.EOF
}

// Write appends p to writeBuf. Writing through a nil mock fails with
// io.ErrClosedPipe.
func (s *blockingStreamHelperMock) Write(p []byte) (int, error) {
	if s == nil {
		return 0, io.ErrClosedPipe
	}
	return s.writeBuf.Write(p)
}

// Close counts the call and unblocks any pending Read. It never fails and is
// safe on a nil mock. closeOnce makes repeated Close/Reset calls safe.
func (s *blockingStreamHelperMock) Close() error {
	if s != nil {
		s.closeCalls++
		s.closeOnce.Do(func() {
			close(s.closedCh)
		})
	}
	return nil
}

// CloseWrite records a half-close request. It never fails and is safe on a
// nil mock.
func (s *blockingStreamHelperMock) CloseWrite() error {
	if s != nil {
		s.closeWriteCalls++
	}
	return nil
}

// Reset counts the call, records err for Read to return, and unblocks any
// pending Read. It always returns nil and is safe on a nil mock.
func (s *blockingStreamHelperMock) Reset(err error) error {
	if s != nil {
		s.resetCalls++
		s.resetErr = err
		s.closeOnce.Do(func() {
			close(s.closedCh)
		})
	}
	return nil
}
|
||
|
|
|
||
|
|
func (s *blockingStreamHelperMock) ID() string { return "blocking-helper-stream" }
|
||
|
|
func (s *blockingStreamHelperMock) Channel() StreamChannel { return StreamDataChannel }
|
||
|
|
func (s *blockingStreamHelperMock) Metadata() StreamMetadata { return nil }
|
||
|
|
func (s *blockingStreamHelperMock) Context() context.Context { return context.Background() }
|
||
|
|
func (s *blockingStreamHelperMock) LogicalConn() *LogicalConn { return nil }
|
||
|
|
func (s *blockingStreamHelperMock) TransportConn() *TransportConn { return nil }
|
||
|
|
func (s *blockingStreamHelperMock) TransportGeneration() uint64 { return 0 }
|
||
|
|
func (s *blockingStreamHelperMock) LocalAddr() net.Addr { return nil }
|
||
|
|
func (s *blockingStreamHelperMock) RemoteAddr() net.Addr { return nil }
|
||
|
|
func (s *blockingStreamHelperMock) SetDeadline(time.Time) error { return nil }
|
||
|
|
func (s *blockingStreamHelperMock) SetReadDeadline(time.Time) error {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
func (s *blockingStreamHelperMock) SetWriteDeadline(time.Time) error {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// blockingPeerHelperMock is the peer-side double used in the BridgeStream
// tests. Read blocks until Close is called and then reports io.EOF. Write
// either fails with writeErr or accepts (and discards) every byte. readStarted
// is closed on the first Read call so a test can wait until a reader is parked.
type blockingPeerHelperMock struct {
	writeErr      error         // returned by every Write when non-nil
	closeCalls    int           // number of Close invocations
	closedCh      chan struct{} // closed exactly once, by Close
	closeOnce     sync.Once     // guards close(closedCh)
	readStarted   chan struct{} // closed on the first Read call
	readStartOnce sync.Once     // guards close(readStarted)
}

// newBlockingPeerHelperMock returns a peer whose Write fails with writeErr
// when it is non-nil.
func newBlockingPeerHelperMock(writeErr error) *blockingPeerHelperMock {
	return &blockingPeerHelperMock{
		writeErr:    writeErr,
		closedCh:    make(chan struct{}),
		readStarted: make(chan struct{}),
	}
}

// Read signals readStarted, then blocks until Close and reports io.EOF. A nil
// peer reports io.EOF immediately, matching the other helper mocks.
func (p *blockingPeerHelperMock) Read(buf []byte) (int, error) {
	if p == nil {
		return 0, io.EOF
	}
	p.readStartOnce.Do(func() {
		close(p.readStarted)
	})
	<-p.closedCh
	return 0, io.EOF
}

// Write returns writeErr when it is set; otherwise it claims to have written
// all of buf without storing it. A nil peer fails with io.ErrClosedPipe,
// matching the other helper mocks.
func (p *blockingPeerHelperMock) Write(buf []byte) (int, error) {
	if p == nil {
		return 0, io.ErrClosedPipe
	}
	if p.writeErr != nil {
		return 0, p.writeErr
	}
	return len(buf), nil
}

// Close counts the call and unblocks any pending Read. It never fails and is
// safe on a nil peer. closeOnce makes repeated Close calls safe.
func (p *blockingPeerHelperMock) Close() error {
	if p != nil {
		p.closeCalls++
		p.closeOnce.Do(func() {
			close(p.closedCh)
		})
	}
	return nil
}
|
||
|
|
|
||
|
|
func TestBridgeStreamResetOnCopyError(t *testing.T) {
|
||
|
|
writeErr := errors.New("bridge-peer-write-failed")
|
||
|
|
stream := newBlockingStreamHelperMock([]byte("from-stream"))
|
||
|
|
peer := newBlockingPeerHelperMock(writeErr)
|
||
|
|
|
||
|
|
errCh := make(chan error, 1)
|
||
|
|
go func() {
|
||
|
|
errCh <- BridgeStream(context.Background(), stream, peer, StreamBridgeOptions{
|
||
|
|
BufferSize: 4,
|
||
|
|
ResetOnCopyError: true,
|
||
|
|
})
|
||
|
|
}()
|
||
|
|
|
||
|
|
select {
|
||
|
|
case err := <-errCh:
|
||
|
|
if !errors.Is(err, writeErr) {
|
||
|
|
t.Fatalf("BridgeStream error = %v, want %v", err, writeErr)
|
||
|
|
}
|
||
|
|
case <-time.After(time.Second):
|
||
|
|
t.Fatal("timed out waiting for BridgeStream write error")
|
||
|
|
}
|
||
|
|
|
||
|
|
if got, want := stream.resetCalls, 1; got != want {
|
||
|
|
t.Fatalf("stream Reset calls = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if !errors.Is(stream.resetErr, writeErr) {
|
||
|
|
t.Fatalf("stream reset error = %v, want %v", stream.resetErr, writeErr)
|
||
|
|
}
|
||
|
|
if got := stream.closeCalls; got != 0 {
|
||
|
|
t.Fatalf("stream Close calls = %d, want 0", got)
|
||
|
|
}
|
||
|
|
if got, want := peer.closeCalls, 1; got != want {
|
||
|
|
t.Fatalf("peer Close calls = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestBridgeStreamCopyErrorClosesStreamWithoutReset(t *testing.T) {
|
||
|
|
writeErr := errors.New("bridge-peer-write-failed")
|
||
|
|
stream := newBlockingStreamHelperMock([]byte("from-stream"))
|
||
|
|
peer := newBlockingPeerHelperMock(writeErr)
|
||
|
|
|
||
|
|
errCh := make(chan error, 1)
|
||
|
|
go func() {
|
||
|
|
errCh <- BridgeStream(context.Background(), stream, peer, StreamBridgeOptions{
|
||
|
|
BufferSize: 4,
|
||
|
|
})
|
||
|
|
}()
|
||
|
|
|
||
|
|
select {
|
||
|
|
case err := <-errCh:
|
||
|
|
if !errors.Is(err, writeErr) {
|
||
|
|
t.Fatalf("BridgeStream error = %v, want %v", err, writeErr)
|
||
|
|
}
|
||
|
|
case <-time.After(time.Second):
|
||
|
|
t.Fatal("timed out waiting for BridgeStream write error")
|
||
|
|
}
|
||
|
|
|
||
|
|
if got := stream.resetCalls; got != 0 {
|
||
|
|
t.Fatalf("stream Reset calls = %d, want 0", got)
|
||
|
|
}
|
||
|
|
if got, want := stream.closeCalls, 1; got != want {
|
||
|
|
t.Fatalf("stream Close calls = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got, want := peer.closeCalls, 1; got != want {
|
||
|
|
t.Fatalf("peer Close calls = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestBridgeStreamContextCancelUnblocksBlockedCopies(t *testing.T) {
|
||
|
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
|
defer cancel()
|
||
|
|
|
||
|
|
stream := newBlockingStreamHelperMock(nil)
|
||
|
|
peer := newBlockingPeerHelperMock(nil)
|
||
|
|
|
||
|
|
errCh := make(chan error, 1)
|
||
|
|
go func() {
|
||
|
|
errCh <- BridgeStream(ctx, stream, peer, StreamBridgeOptions{
|
||
|
|
BufferSize: 4,
|
||
|
|
})
|
||
|
|
}()
|
||
|
|
|
||
|
|
waitHelperReadStarted(t, stream.readStarted, time.Second)
|
||
|
|
waitHelperReadStarted(t, peer.readStarted, time.Second)
|
||
|
|
cancel()
|
||
|
|
|
||
|
|
select {
|
||
|
|
case err := <-errCh:
|
||
|
|
if !errors.Is(err, context.Canceled) {
|
||
|
|
t.Fatalf("BridgeStream error = %v, want %v", err, context.Canceled)
|
||
|
|
}
|
||
|
|
case <-time.After(time.Second):
|
||
|
|
t.Fatal("timed out waiting for BridgeStream cancel")
|
||
|
|
}
|
||
|
|
|
||
|
|
if got := stream.resetCalls; got != 0 {
|
||
|
|
t.Fatalf("stream Reset calls = %d, want 0", got)
|
||
|
|
}
|
||
|
|
if got, want := stream.closeCalls, 1; got != want {
|
||
|
|
t.Fatalf("stream Close calls = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
if got, want := peer.closeCalls, 1; got != want {
|
||
|
|
t.Fatalf("peer Close calls = %d, want %d", got, want)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// waitHelperReadStarted blocks until started is closed (or delivers a value).
// It fails the calling test if that does not happen within timeout.
func waitHelperReadStarted(t *testing.T, started <-chan struct{}, timeout time.Duration) {
	t.Helper()

	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	select {
	case <-started:
	case <-deadline.C:
		t.Fatal("timed out waiting for helper read to start")
	}
}
|