package notify import ( "b612.me/stario" "bytes" "context" "errors" "io" "net" "sync" "sync/atomic" "testing" "time" ) type serializedWriteTestConn struct { activeWrites int32 concurrent int32 writeCount int32 } func (c *serializedWriteTestConn) Read([]byte) (int, error) { return 0, net.ErrClosed } func (c *serializedWriteTestConn) Close() error { return nil } func (c *serializedWriteTestConn) LocalAddr() net.Addr { return nil } func (c *serializedWriteTestConn) RemoteAddr() net.Addr { return nil } func (c *serializedWriteTestConn) SetDeadline(time.Time) error { return nil } func (c *serializedWriteTestConn) SetReadDeadline(time.Time) error { return nil } func (c *serializedWriteTestConn) SetWriteDeadline(time.Time) error { return nil } func (c *serializedWriteTestConn) Write(p []byte) (int, error) { if !atomic.CompareAndSwapInt32(&c.activeWrites, 0, 1) { atomic.StoreInt32(&c.concurrent, 1) return len(p), nil } time.Sleep(10 * time.Millisecond) atomic.AddInt32(&c.writeCount, 1) atomic.StoreInt32(&c.activeWrites, 0) return len(p), nil } func TestWriteFullToConnSerializesConcurrentWriters(t *testing.T) { conn := &serializedWriteTestConn{} payload := []byte("payload") var wg sync.WaitGroup for index := 0; index < 4; index++ { wg.Add(1) go func() { defer wg.Done() if err := writeFullToConn(conn, payload); err != nil { t.Errorf("writeFullToConn failed: %v", err) } }() } wg.Wait() if atomic.LoadInt32(&conn.concurrent) != 0 { t.Fatal("detected concurrent conn.Write execution") } if got, want := atomic.LoadInt32(&conn.writeCount), int32(4); got != want { t.Fatalf("write count = %d, want %d", got, want) } } func newTestBulkBatchSender(binding *transportBinding) *bulkBatchSender { return newTestBulkBatchSenderWithWriteTimeout(binding, nil) } func newTestBulkBatchSenderWithWriteTimeout(binding *transportBinding, writeTimeout func() time.Duration) *bulkBatchSender { return newBulkBatchSender(binding, bulkBatchCodec{ encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) { return append([]byte(nil), frame.Payload...), nil, nil }, encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) { payload, err := encodeBulkFastBatchPlain(frames) return payload, nil, err }, }, writeTimeout) } func newTestStreamBatchSender(binding *transportBinding, writeTimeout func() time.Duration) *streamBatchSender { return newStreamBatchSender(binding, streamBatchCodec{ encodeSingle: func(frame streamFastDataFrame) ([]byte, error) { return encodeStreamFastFramePayload(frame) }, encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) { return encodeStreamFastBatchPlain(frames) }, }, writeTimeout) } func TestBulkBatchSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) { left, right := net.Pipe() defer left.Close() defer right.Close() sender := newTestBulkBatchSenderWithWriteTimeout(newTransportBinding(left, stario.NewQueue()), func() time.Duration { return 50 * time.Millisecond }) errCh := make(chan error, 1) go func() { errCh <- sender.submitData(context.Background(), 1, 1, bulkFastPathVersionV1, []byte("payload")) }() select { case err := <-errCh: if err == nil { t.Fatal("sender.submit should fail when receiver stalls") } if !isTimeoutLikeError(err) { t.Fatalf("sender.submit error = %v, want timeout-like error", err) } case <-time.After(time.Second): t.Fatal("sender.submit should not hang when receiver stalls") } } func TestStreamBatchSenderRespectsBindingWriteDeadlineWhenReceiverStalls(t *testing.T) { left, right := net.Pipe() defer left.Close() defer right.Close() sender := newTestStreamBatchSender(newTransportBinding(left, stario.NewQueue()), func() time.Duration { return 50 * time.Millisecond }) errCh := make(chan error, 1) go func() { errCh <- sender.submitData(context.Background(), 1, 1, streamFastPathVersionV2, []byte("payload")) }() select { case err := <-errCh: if err == nil { t.Fatal("sender.submit should fail when receiver stalls") } if !isTimeoutLikeError(err) { t.Fatalf("sender.submit error = %v, want timeout-like error", err) } case <-time.After(time.Second): t.Fatal("sender.submit should not hang when receiver stalls") } } func TestControlBatchSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) { left, right := net.Pipe() defer left.Close() defer right.Close() sender := newControlBatchSender(newTransportBinding(left, stario.NewQueue())) deadline := time.Now().Add(50 * time.Millisecond) errCh := make(chan error, 1) go func() { errCh <- sender.submit([]byte("payload"), deadline) }() select { case err := <-errCh: if err == nil { t.Fatal("sender.submit should fail when receiver stalls") } if !isTimeoutLikeError(err) { t.Fatalf("sender.submit error = %v, want timeout-like error", err) } case <-time.After(time.Second): t.Fatal("sender.submit should not hang when receiver stalls") } } type blockingPacketWriteConn struct { startCh chan struct{} unblockCh chan struct{} writeCount atomic.Int32 } func newBlockingPacketWriteConn() *blockingPacketWriteConn { return &blockingPacketWriteConn{ startCh: make(chan struct{}), unblockCh: make(chan struct{}), } } func (c *blockingPacketWriteConn) Read([]byte) (int, error) { return 0, net.ErrClosed } func (c *blockingPacketWriteConn) Close() error { return nil } func (c *blockingPacketWriteConn) LocalAddr() net.Addr { return &net.UDPAddr{IP: net.IPv4zero, Port: 1} } func (c *blockingPacketWriteConn) RemoteAddr() net.Addr { return &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 2} } func (c *blockingPacketWriteConn) SetDeadline(time.Time) error { return nil } func (c *blockingPacketWriteConn) SetReadDeadline(time.Time) error { return nil } func (c *blockingPacketWriteConn) SetWriteDeadline(time.Time) error { return nil } func (c *blockingPacketWriteConn) Write(p []byte) (int, error) { if c.writeCount.Add(1) == 1 { close(c.startCh) <-c.unblockCh } return len(p), nil } type vectoredShortWriteConn struct { steps []int64 idx int buf bytes.Buffer writes int writev int } func (c *vectoredShortWriteConn) Read([]byte) (int, error) { return 0, io.EOF } func (c *vectoredShortWriteConn) Write(p []byte) (int, error) { c.writes++ return c.buf.Write(p) } func (c *vectoredShortWriteConn) Close() error { return nil } func (c *vectoredShortWriteConn) LocalAddr() net.Addr { return nil } func (c *vectoredShortWriteConn) RemoteAddr() net.Addr { return nil } func (c *vectoredShortWriteConn) SetDeadline(time.Time) error { return nil } func (c *vectoredShortWriteConn) SetReadDeadline(time.Time) error { return nil } func (c *vectoredShortWriteConn) SetWriteDeadline(time.Time) error { return nil } func (c *vectoredShortWriteConn) WriteBuffers(bufs *net.Buffers) (int64, error) { c.writev++ if c.idx >= len(c.steps) { return 0, io.ErrNoProgress } remaining := c.steps[c.idx] c.idx++ written := int64(0) for len(*bufs) > 0 && remaining > 0 { part := (*bufs)[0] if len(part) == 0 { (*bufs)[0] = nil *bufs = (*bufs)[1:] continue } n := int64(len(part)) if n > remaining { n = remaining } _, _ = c.buf.Write(part[:n]) written += n remaining -= n if n == int64(len(part)) { (*bufs)[0] = nil *bufs = (*bufs)[1:] continue } (*bufs)[0] = part[n:] break } return written, nil } type unwrapVectoredConn struct { inner net.Conn } func (c *unwrapVectoredConn) Read(p []byte) (int, error) { return c.inner.Read(p) } func (c *unwrapVectoredConn) Write(p []byte) (int, error) { return c.inner.Write(p) } func (c *unwrapVectoredConn) Close() error { return c.inner.Close() } func (c *unwrapVectoredConn) LocalAddr() net.Addr { return c.inner.LocalAddr() } func (c *unwrapVectoredConn) RemoteAddr() net.Addr { return c.inner.RemoteAddr() } func (c *unwrapVectoredConn) SetDeadline(t time.Time) error { return c.inner.SetDeadline(t) } func (c *unwrapVectoredConn) SetReadDeadline(t time.Time) error { return c.inner.SetReadDeadline(t) } func (c *unwrapVectoredConn) SetWriteDeadline(t time.Time) error { return c.inner.SetWriteDeadline(t) } func (c *unwrapVectoredConn) UnwrapConn() net.Conn { return c.inner } func TestWriteNetBuffersFullUnlockedFallsBackToDirectWritesAfterFirstPartialVectoredWrite(t *testing.T) { conn := &vectoredShortWriteConn{steps: []int64{3}} header := []byte("head") payload := []byte("payload") if err := writeNetBuffersFullUnlocked(conn, net.Buffers{header, payload}); err != nil { t.Fatalf("writeNetBuffersFullUnlocked failed: %v", err) } if got, want := conn.writev, 1; got != want { t.Fatalf("vectored write calls = %d, want %d", got, want) } if got, want := conn.writes, 2; got != want { t.Fatalf("fallback direct writes = %d, want %d", got, want) } if got, want := conn.buf.String(), "headpayload"; got != want { t.Fatalf("written bytes = %q, want %q", got, want) } } func TestWriteNetBuffersFullUnlockedReturnsNoProgressWhenVectoredWriteDoesNotAdvance(t *testing.T) { conn := &vectoredShortWriteConn{steps: []int64{0}} err := writeNetBuffersFullUnlocked(conn, net.Buffers{[]byte("head"), []byte("payload")}) if !errors.Is(err, io.ErrNoProgress) { t.Fatalf("writeNetBuffersFullUnlocked error = %v, want %v", err, io.ErrNoProgress) } } func TestWriteNetBuffersFullUnlockedUsesUnwrappedVectoredConn(t *testing.T) { inner := &vectoredShortWriteConn{steps: []int64{100}} conn := &unwrapVectoredConn{inner: inner} if err := writeNetBuffersFullUnlocked(conn, net.Buffers{[]byte("head"), []byte("payload")}); err != nil { t.Fatalf("writeNetBuffersFullUnlocked failed: %v", err) } if got, want := inner.writev, 1; got != want { t.Fatalf("unwrapped vectored write calls = %d, want %d", got, want) } if got := inner.writes; got != 0 { t.Fatalf("unexpected fallback direct writes = %d, want 0", got) } } func TestBulkBatchSenderSkipsQueuedCanceledRequest(t *testing.T) { conn := newBlockingPacketWriteConn() binding := newTransportBinding(conn, stario.NewQueue()) sender := newTestBulkBatchSender(binding) defer sender.stop() firstErrCh := make(chan error, 1) go func() { firstErrCh <- sender.submitData(context.Background(), 1, 1, bulkFastPathVersionV1, []byte("first")) }() select { case <-conn.startCh: case <-time.After(time.Second): t.Fatal("first shared bulk write did not start") } ctx, cancel := context.WithCancel(context.Background()) secondErrCh := make(chan error, 1) go func() { secondErrCh <- sender.submitData(ctx, 1, 2, bulkFastPathVersionV1, []byte("second")) }() time.Sleep(20 * time.Millisecond) cancel() select { case err := <-secondErrCh: if !errors.Is(err, context.Canceled) { t.Fatalf("second shared bulk submit error = %v, want %v", err, context.Canceled) } case <-time.After(time.Second): t.Fatal("second shared bulk submit did not return after cancel") } close(conn.unblockCh) select { case err := <-firstErrCh: if err != nil { t.Fatalf("first shared bulk submit failed: %v", err) } case <-time.After(time.Second): t.Fatal("first shared bulk submit did not finish") } time.Sleep(50 * time.Millisecond) if got, want := conn.writeCount.Load(), int32(1); got != want { t.Fatalf("shared bulk write count = %d, want %d", got, want) } } func TestBulkBatchSenderDoesNotDirectSubmitShareableV2Data(t *testing.T) { sender := &bulkBatchSender{} req := bulkBatchRequest{ frames: []bulkFastFrame{{ Type: bulkFastPayloadTypeData, DataID: 1, Seq: 1, Payload: make([]byte, 256*1024), }}, fastPathVersion: bulkFastPathVersionV2, } if sender.shouldDirectSubmit(req) { t.Fatal("shareable v2 shared bulk data should queue for super-batch instead of direct submit") } } func TestBulkBatchSenderDirectSubmitsUnbatchableRequest(t *testing.T) { sender := &bulkBatchSender{} req := bulkBatchRequest{ frames: []bulkFastFrame{{ Type: bulkFastPayloadTypeRelease, DataID: 1, Seq: 0, Payload: []byte("rel"), }}, fastPathVersion: bulkFastPathVersionV2, } if !sender.shouldDirectSubmit(req) { t.Fatal("unbatchable shared bulk control request should still direct submit") } } func TestBulkBatchSenderReturnsFlushResultAfterStartedContextCancel(t *testing.T) { conn := newBlockingPacketWriteConn() binding := newTransportBinding(conn, stario.NewQueue()) sender := newTestBulkBatchSender(binding) defer sender.stop() ctx, cancel := context.WithCancel(context.Background()) errCh := make(chan error, 1) go func() { errCh <- sender.submitData(ctx, 1, 1, bulkFastPathVersionV1, []byte("payload")) }() select { case <-conn.startCh: case <-time.After(time.Second): t.Fatal("shared bulk write did not start") } cancel() select { case err := <-errCh: t.Fatalf("sender.submit returned before flush completed: %v", err) case <-time.After(50 * time.Millisecond): } close(conn.unblockCh) select { case err := <-errCh: if err != nil { t.Fatalf("sender.submit failed after started flush: %v", err) } case <-time.After(time.Second): t.Fatal("sender.submit did not return after started flush completed") } } func TestTransportBindingStopBackgroundWorkersStopsSharedSender(t *testing.T) { binding := newTransportBinding(newBlockingPacketWriteConn(), stario.NewQueue()) sender := binding.bulkBatchSenderSnapshotWithCodec(bulkBatchCodec{ encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) { return append([]byte(nil), frame.Payload...), nil, nil }, encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) { payload, err := encodeBulkFastBatchPlain(frames) return payload, nil, err }, }, nil) binding.stopBackgroundWorkers() err := sender.submitData(context.Background(), 1, 1, bulkFastPathVersionV1, []byte("payload")) if !errors.Is(err, errTransportDetached) { t.Fatalf("sender.submit after stop = %v, want %v", err, errTransportDetached) } } func TestTransportBindingStopBackgroundWorkersStopsControlSender(t *testing.T) { binding := newTransportBinding(&serializedWriteTestConn{}, stario.NewQueue()) sender := binding.controlBatchSenderSnapshot() binding.stopBackgroundWorkers() err := sender.submit([]byte("payload"), time.Time{}) if !errors.Is(err, errTransportDetached) { t.Fatalf("sender.submit after stop = %v, want %v", err, errTransportDetached) } }