package notify import ( "b612.me/stario" "context" "errors" "net" "sync" "sync/atomic" "testing" "time" ) type serializedWriteTestConn struct { activeWrites int32 concurrent int32 writeCount int32 } func (c *serializedWriteTestConn) Read([]byte) (int, error) { return 0, net.ErrClosed } func (c *serializedWriteTestConn) Close() error { return nil } func (c *serializedWriteTestConn) LocalAddr() net.Addr { return nil } func (c *serializedWriteTestConn) RemoteAddr() net.Addr { return nil } func (c *serializedWriteTestConn) SetDeadline(time.Time) error { return nil } func (c *serializedWriteTestConn) SetReadDeadline(time.Time) error { return nil } func (c *serializedWriteTestConn) SetWriteDeadline(time.Time) error { return nil } func (c *serializedWriteTestConn) Write(p []byte) (int, error) { if !atomic.CompareAndSwapInt32(&c.activeWrites, 0, 1) { atomic.StoreInt32(&c.concurrent, 1) return len(p), nil } time.Sleep(10 * time.Millisecond) atomic.AddInt32(&c.writeCount, 1) atomic.StoreInt32(&c.activeWrites, 0) return len(p), nil } func TestWriteFullToConnSerializesConcurrentWriters(t *testing.T) { conn := &serializedWriteTestConn{} payload := []byte("payload") var wg sync.WaitGroup for index := 0; index < 4; index++ { wg.Add(1) go func() { defer wg.Done() if err := writeFullToConn(conn, payload); err != nil { t.Errorf("writeFullToConn failed: %v", err) } }() } wg.Wait() if atomic.LoadInt32(&conn.concurrent) != 0 { t.Fatal("detected concurrent conn.Write execution") } if got, want := atomic.LoadInt32(&conn.writeCount), int32(4); got != want { t.Fatalf("write count = %d, want %d", got, want) } } func TestBulkBatchSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) { left, right := net.Pipe() defer left.Close() defer right.Close() sender := newBulkBatchSender(newTransportBinding(left, stario.NewQueue())) ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) defer cancel() errCh := make(chan error, 1) go func() { errCh <- sender.submit(ctx, []byte("payload")) }() select { case err := <-errCh: if err == nil { t.Fatal("sender.submit should fail when receiver stalls") } if !isTimeoutLikeError(err) { t.Fatalf("sender.submit error = %v, want timeout-like error", err) } case <-time.After(time.Second): t.Fatal("sender.submit should not hang when receiver stalls") } } func TestControlBatchSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) { left, right := net.Pipe() defer left.Close() defer right.Close() sender := newControlBatchSender(newTransportBinding(left, stario.NewQueue())) deadline := time.Now().Add(50 * time.Millisecond) errCh := make(chan error, 1) go func() { errCh <- sender.submit([]byte("payload"), deadline) }() select { case err := <-errCh: if err == nil { t.Fatal("sender.submit should fail when receiver stalls") } if !isTimeoutLikeError(err) { t.Fatalf("sender.submit error = %v, want timeout-like error", err) } case <-time.After(time.Second): t.Fatal("sender.submit should not hang when receiver stalls") } } type blockingPacketWriteConn struct { startCh chan struct{} unblockCh chan struct{} writeCount atomic.Int32 } func newBlockingPacketWriteConn() *blockingPacketWriteConn { return &blockingPacketWriteConn{ startCh: make(chan struct{}), unblockCh: make(chan struct{}), } } func (c *blockingPacketWriteConn) Read([]byte) (int, error) { return 0, net.ErrClosed } func (c *blockingPacketWriteConn) Close() error { return nil } func (c *blockingPacketWriteConn) LocalAddr() net.Addr { return &net.UDPAddr{IP: net.IPv4zero, Port: 1} } func (c *blockingPacketWriteConn) RemoteAddr() net.Addr { return &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 2} } func (c *blockingPacketWriteConn) SetDeadline(time.Time) error { return nil } func (c *blockingPacketWriteConn) SetReadDeadline(time.Time) error { return nil } func (c *blockingPacketWriteConn) SetWriteDeadline(time.Time) error { return nil } func (c *blockingPacketWriteConn) Write(p []byte) (int, error) { if c.writeCount.Add(1) == 1 { close(c.startCh) <-c.unblockCh } return len(p), nil } func TestBulkBatchSenderSkipsQueuedCanceledRequest(t *testing.T) { conn := newBlockingPacketWriteConn() binding := newTransportBinding(conn, stario.NewQueue()) sender := newBulkBatchSender(binding) defer sender.stop() firstErrCh := make(chan error, 1) go func() { firstErrCh <- sender.submit(context.Background(), []byte("first")) }() select { case <-conn.startCh: case <-time.After(time.Second): t.Fatal("first shared bulk write did not start") } ctx, cancel := context.WithCancel(context.Background()) secondErrCh := make(chan error, 1) go func() { secondErrCh <- sender.submit(ctx, []byte("second")) }() time.Sleep(20 * time.Millisecond) cancel() select { case err := <-secondErrCh: if !errors.Is(err, context.Canceled) { t.Fatalf("second shared bulk submit error = %v, want %v", err, context.Canceled) } case <-time.After(time.Second): t.Fatal("second shared bulk submit did not return after cancel") } close(conn.unblockCh) select { case err := <-firstErrCh: if err != nil { t.Fatalf("first shared bulk submit failed: %v", err) } case <-time.After(time.Second): t.Fatal("first shared bulk submit did not finish") } time.Sleep(50 * time.Millisecond) if got, want := conn.writeCount.Load(), int32(1); got != want { t.Fatalf("shared bulk write count = %d, want %d", got, want) } } func TestBulkBatchSenderReturnsFlushResultAfterStartedContextCancel(t *testing.T) { conn := newBlockingPacketWriteConn() binding := newTransportBinding(conn, stario.NewQueue()) sender := newBulkBatchSender(binding) defer sender.stop() ctx, cancel := context.WithCancel(context.Background()) errCh := make(chan error, 1) go func() { errCh <- sender.submit(ctx, []byte("payload")) }() select { case <-conn.startCh: case <-time.After(time.Second): t.Fatal("shared bulk write did not start") } cancel() select { case err := <-errCh: t.Fatalf("sender.submit returned before flush completed: %v", err) case <-time.After(50 * time.Millisecond): } close(conn.unblockCh) select { case err := <-errCh: if err != nil { t.Fatalf("sender.submit failed after started flush: %v", err) } case <-time.After(time.Second): t.Fatal("sender.submit did not return after started flush completed") } } func TestTransportBindingStopBackgroundWorkersStopsSharedSender(t *testing.T) { binding := newTransportBinding(newBlockingPacketWriteConn(), stario.NewQueue()) sender := binding.bulkBatchSenderSnapshot() binding.stopBackgroundWorkers() err := sender.submit(context.Background(), []byte("payload")) if !errors.Is(err, errTransportDetached) { t.Fatalf("sender.submit after stop = %v, want %v", err, errTransportDetached) } } func TestTransportBindingStopBackgroundWorkersStopsControlSender(t *testing.T) { binding := newTransportBinding(&serializedWriteTestConn{}, stario.NewQueue()) sender := binding.controlBatchSenderSnapshot() binding.stopBackgroundWorkers() err := sender.submit([]byte("payload"), time.Time{}) if !errors.Is(err, errTransportDetached) { t.Fatalf("sender.submit after stop = %v, want %v", err, errTransportDetached) } }