package notify

import (
	"context"
	"net"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"b612.me/stario"
)

// Compile-time checks that the test doubles satisfy the interfaces they stand in for.
var (
	_ net.Conn = (*countingConn)(nil)
	_ net.Addr = countingAddr("")
)

// countingConn is a net.Conn test double that records how many times Write is
// called and whether Close has been called. Read always fails with
// net.ErrClosed, and the deadline setters are no-ops.
type countingConn struct {
	mu         sync.Mutex   // NOTE(review): unused by the methods in this file; confirm no other test relies on it before removing
	writeCount atomic.Int32 // number of Write calls observed
	closed     atomic.Bool  // set by Close
	localAddr  net.Addr     // LocalAddr override; nil falls back to countingAddr("local")
	remoteAddr net.Addr     // RemoteAddr override; nil falls back to countingAddr("remote")
}

// Read always fails with net.ErrClosed.
func (c *countingConn) Read(_ []byte) (int, error) {
	return 0, net.ErrClosed
}

// Write counts the call and reports the whole buffer as written without storing it.
func (c *countingConn) Write(p []byte) (int, error) {
	c.writeCount.Add(1)
	return len(p), nil
}

// Close marks the connection as closed and never returns an error.
func (c *countingConn) Close() error {
	c.closed.Store(true)
	return nil
}

// LocalAddr returns the configured local address, or countingAddr("local") when unset.
func (c *countingConn) LocalAddr() net.Addr {
	if c.localAddr != nil {
		return c.localAddr
	}
	return countingAddr("local")
}

// RemoteAddr returns the configured remote address, or countingAddr("remote") when unset.
func (c *countingConn) RemoteAddr() net.Addr {
	if c.remoteAddr != nil {
		return c.remoteAddr
	}
	return countingAddr("remote")
}

// SetDeadline is a no-op.
func (c *countingConn) SetDeadline(time.Time) error { return nil }

// SetReadDeadline is a no-op.
func (c *countingConn) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline is a no-op.
func (c *countingConn) SetWriteDeadline(time.Time) error { return nil }

// countingAddr is a net.Addr whose String form is the underlying string value.
type countingAddr string

// Network reports the fixed network name "counting".
func (a countingAddr) Network() string { return "counting" }

// String returns the address text.
func (a countingAddr) String() string { return string(a) }

// TestStartClientWithConnResetsByeFromServer checks that starting a session over
// a caller-supplied connection clears a previously set bye-from-server flag and
// records the direct-conn connect source.
func TestStartClientWithConnResetsByeFromServer(t *testing.T) {
	server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
		if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
			t.Fatalf("UseModernPSKServer failed: %v", err)
		}
	})

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	// Simulate state left over from an earlier session; starting a new one must clear it.
	client.setByeFromServer(true)

	left, right := net.Pipe()
	defer func() {
		_ = left.Close()
		_ = right.Close()
	}()
	bootstrapPeerAttachLogicalForTest(t, server, right)

	if err := client.startClientWithConn(left); err != nil {
		t.Fatalf("startClientWithConn failed: %v", err)
	}
	if !client.shouldSayGoodByeOnStop() {
		t.Fatal("new session should reset bye-from-server state")
	}
	if source := client.clientConnectSourceSnapshot(); source == nil || source.kind != clientConnectSourceConn {
		t.Fatalf("client connect source should record direct conn start, got %+v", source)
	}

	// NOTE(review): presumably set again so Stop skips the goodbye write on the
	// synchronous pipe — confirm against Stop's implementation.
	client.setByeFromServer(true)
	if err := client.Stop(); err != nil {
		t.Fatalf("client Stop failed: %v", err)
	}
}

// TestServerCleanupLostHeartbeatClientsStopsExpiredOnly checks that heartbeat
// cleanup removes a logical client whose last heartbeat is older than the
// configured timeout and marks it not alive with reason "heartbeat timeout",
// while leaving a client with a fresh heartbeat in the pool.
func TestServerCleanupLostHeartbeatClientsStopsExpiredOnly(t *testing.T) {
	server := NewServer().(*ServerCommon)
	server.SetHeartbeatTimeoutSec(10)
	now := time.Now().Unix()

	staleStopCtx, staleStopFn := context.WithCancel(context.Background())
	defer staleStopFn()
	stale, _, _ := newRegisteredServerLogicalForTest(t, server, "stale-client", nil, staleStopCtx, staleStopFn)
	// 20s since the last heartbeat: past the 10s timeout configured above.
	stale.setClientConnLastHeartbeatUnix(now - 20)

	activeStopCtx, activeStopFn := context.WithCancel(context.Background())
	defer activeStopFn()
	active, _, _ := newRegisteredServerLogicalForTest(t, server, "active-client", nil, activeStopCtx, activeStopFn)
	active.setClientConnLastHeartbeatUnix(now)

	server.cleanupLostHeartbeatClients(time.Unix(now, 0))

	if got := server.GetLogicalConn(stale.ClientID); got != nil {
		t.Fatalf("stale client should be removed, got %+v", got)
	}
	staleStatus := stale.Status()
	if staleStatus.Alive || staleStatus.Reason != "heartbeat timeout" {
		t.Fatalf("stale client status mismatch: %+v", staleStatus)
	}
	if got := server.GetLogicalConn(active.ClientID); got == nil {
		t.Fatal("active client should remain in pool")
	}
}

// TestRetireClientSessionRuntimeSuppressesGoodByeOnStop checks that retiring an
// old session runtime makes its loadMessageLoop exit and closes its transport.
// It also checks that nothing is written through the client's current runtime.
func TestRetireClientSessionRuntimeSuppressesGoodByeOnStop(t *testing.T) {
	client := NewClient().(*ClientCommon)
	client.markSessionStarted()

	currentConn := &countingConn{}
	currentCtx, currentCancel := context.WithCancel(context.Background())
	defer currentCancel()
	client.setClientSessionRuntime(newClientSessionRuntime(currentConn, currentCtx, currentCancel, stario.NewQueueCtx(currentCtx, 4, 16), 2))

	oldConn := &countingConn{}
	oldCtx, oldCancel := context.WithCancel(context.Background())
	// Release oldCtx on every exit path, including early t.Fatal. Otherwise the
	// context (and possibly the loop goroutine) leaks. Retiring the runtime is
	// expected to cancel it first; calling a CancelFunc again does nothing.
	defer oldCancel()
	oldRT := newClientSessionRuntime(oldConn, oldCtx, oldCancel, stario.NewQueueCtx(oldCtx, 4, 16), 1)

	done := make(chan struct{})
	go func() {
		client.loadMessageLoop(oldRT)
		close(done)
	}()

	client.retireClientSessionRuntime(oldRT, true)

	select {
	case <-done:
	case <-time.After(time.Second):
		t.Fatal("loadMessageLoop should exit after runtime retire")
	}
	if got := currentConn.writeCount.Load(); got != 0 {
		t.Fatalf("retired runtime should not send goodbye through current runtime, got %d writes", got)
	}
	if !oldConn.closed.Load() {
		t.Fatal("retired runtime transport should be closed")
	}
}