package notify

import (
	"fmt"
	"path/filepath"
	"runtime"
	"sync"
	"testing"
	"time"
)

// TestSignalRoundTripConcurrentTCP runs 24 concurrent clients against a TCP
// signal round-trip server and checks each receives its own "ack:" reply.
func TestSignalRoundTripConcurrentTCP(t *testing.T) {
	server, addr := startSignalRoundTripServer(t, "tcp", "127.0.0.1:0")
	// Stop error is irrelevant once the test body has finished.
	defer func() { _ = server.Stop() }()
	runConcurrentSignalRoundTripClients(t, "tcp", addr, 24)
}

// TestSignalRoundTripConcurrentUnix is the unix-socket variant of the
// concurrent round-trip test; it is skipped on Windows.
func TestSignalRoundTripConcurrentUnix(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("unix socket is not available on windows")
	}
	addr := filepath.Join(t.TempDir(), "notify-signal.sock")
	server, endpoint := startSignalRoundTripServer(t, "unix", addr)
	defer func() { _ = server.Stop() }()
	runConcurrentSignalRoundTripClients(t, "unix", endpoint, 24)
}

// TestSignalRoundTripConcurrentUDP is the UDP variant of the concurrent
// round-trip test.
func TestSignalRoundTripConcurrentUDP(t *testing.T) {
	server, addr := startSignalRoundTripServer(t, "udp", "127.0.0.1:0")
	defer func() { _ = server.Stop() }()
	runConcurrentSignalRoundTripClients(t, "udp", addr, 24)
}

// TestSignalReliabilityStatsRoundTripTCP checks reliability counters over TCP.
func TestSignalReliabilityStatsRoundTripTCP(t *testing.T) {
	runSignalReliabilityStatsRoundTrip(t, "tcp", "127.0.0.1:0")
}

// TestSignalReliabilityStatsRoundTripUDP checks reliability counters over UDP.
func TestSignalReliabilityStatsRoundTripUDP(t *testing.T) {
	runSignalReliabilityStatsRoundTrip(t, "udp", "127.0.0.1:0")
}

// runSignalReliabilityStatsRoundTrip enables signal reliability on both a
// server and a client, performs one SendWait round trip, and then asserts
// that the client recorded at least one signal send, reliable send, ack wait
// and ack delivery, and that the server recorded at least one ack send.
func runSignalReliabilityStatsRoundTrip(t *testing.T, network string, addr string) {
	t.Helper()

	// addr is re-assigned to the concrete listen address (e.g. the resolved
	// ephemeral port for "127.0.0.1:0").
	server, addr := startSignalRoundTripServer(t, network, addr)
	defer func() { _ = server.Stop() }()

	reliableOpts := &SignalReliabilityOptions{
		Enabled:           true,
		AckTimeout:        2 * time.Second,
		SendRetry:         4,
		ReceiveCacheLimit: 64,
	}
	if err := UseSignalReliabilityServer(server, reliableOpts); err != nil {
		t.Fatalf("UseSignalReliabilityServer failed: %v", err)
	}

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := UseSignalReliabilityClient(client, reliableOpts); err != nil {
		t.Fatalf("UseSignalReliabilityClient failed: %v", err)
	}
	if err := client.Connect(network, addr); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() { _ = client.Stop() }()

	reply, err := client.SendWait("signal-roundtrip", []byte("stats-check"), 6*time.Second)
	if err != nil {
		// Include both sides' counters to make a failed round trip diagnosable.
		clientStats, clientStatsErr := GetSignalReliabilityStatsClient(client)
		serverStats, serverStatsErr := GetSignalReliabilityStatsServer(server)
		t.Fatalf("SendWait failed: %v (clientStats=%+v clientStatsErr=%v serverStats=%+v serverStatsErr=%v)",
			err, clientStats, clientStatsErr, serverStats, serverStatsErr)
	}
	if got, want := string(reply.Value), "ack:stats-check"; got != want {
		t.Fatalf("reply mismatch: got %q want %q", got, want)
	}

	// Poll for up to 2s: nothing visible here guarantees the ack counters are
	// updated by the time SendWait returns, so give them a short window.
	var clientStats SignalReliabilityStats
	var serverStats SignalReliabilityStats
	deadline := time.Now().Add(2 * time.Second)
	for {
		clientStats, err = GetSignalReliabilityStatsClient(client)
		if err != nil {
			t.Fatalf("GetSignalReliabilityStatsClient failed: %v", err)
		}
		serverStats, err = GetSignalReliabilityStatsServer(server)
		if err != nil {
			t.Fatalf("GetSignalReliabilityStatsServer failed: %v", err)
		}
		if clientStats.AckDeliverTotal >= 1 && serverStats.AckSendTotal >= 1 {
			break
		}
		if time.Now().After(deadline) {
			break
		}
		time.Sleep(10 * time.Millisecond)
	}

	if clientStats.SignalSendTotal < 1 ||
		clientStats.ReliableSendTotal < 1 ||
		clientStats.AckWaitTotal < 1 ||
		clientStats.AckDeliverTotal < 1 {
		t.Fatalf("client signal reliability stats mismatch: %+v", clientStats)
	}
	if serverStats.AckSendTotal < 1 {
		t.Fatalf("server signal reliability stats mismatch: %+v", serverStats)
	}
}

// startSignalRoundTripServer creates a server configured with the modern PSK
// integration settings, registers a "signal-roundtrip" link that replies with
// "ack:" followed by the received payload, and listens on network/addr.
// It returns the server and the address clients should dial (see
// signalRoundTripServerAddr). The caller is responsible for stopping it.
func startSignalRoundTripServer(t *testing.T, network string, addr string) (*ServerCommon, string) {
	t.Helper()

	server := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	server.SetLink("signal-roundtrip", func(msg *Message) {
		// The handler has no *testing.T-safe way to report; a failed reply
		// surfaces as a client-side SendWait error instead.
		_ = msg.Reply([]byte("ack:" + string(msg.Value)))
	})
	if err := server.Listen(network, addr); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	return server, signalRoundTripServerAddr(server, addr)
}

// signalRoundTripServerAddr returns the address the server is actually bound
// to: the stream listener's address if present, otherwise the UDP listener's
// local address, otherwise fallback. This resolves ":0" ephemeral ports.
func signalRoundTripServerAddr(server *ServerCommon, fallback string) string {
	if server == nil {
		return fallback
	}
	if server.listener != nil && server.listener.Addr() != nil {
		if value := server.listener.Addr().String(); value != "" {
			return value
		}
	}
	if server.udpListener != nil && server.udpListener.LocalAddr() != nil {
		if value := server.udpListener.LocalAddr().String(); value != "" {
			return value
		}
	}
	return fallback
}

// runConcurrentSignalRoundTripClients starts total clients in parallel; each
// connects to addr over network, sends "hello-<i>" on "signal-roundtrip" and
// expects "ack:hello-<i>" back within 3s. Every failing client is reported
// individually, then the test is failed with a summary.
func runConcurrentSignalRoundTripClients(t *testing.T, network string, addr string, total int) {
	t.Helper()

	var wg sync.WaitGroup
	// Buffered to total so every goroutine can report without blocking; the
	// channel is only drained after wg.Wait.
	errCh := make(chan error, total)
	for i := 0; i < total; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()

			client := NewClient()
			if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
				errCh <- fmt.Errorf("client %d configure security: %w", i, err)
				return
			}
			if err := client.Connect(network, addr); err != nil {
				errCh <- fmt.Errorf("client %d connect: %w", i, err)
				return
			}
			defer func() { _ = client.Stop() }()

			payload := fmt.Sprintf("hello-%d", i)
			reply, err := client.SendWait("signal-roundtrip", []byte(payload), 3*time.Second)
			if err != nil {
				errCh <- fmt.Errorf("client %d SendWait: %w", i, err)
				return
			}
			want := "ack:" + payload
			if got := string(reply.Value); got != want {
				errCh <- fmt.Errorf("client %d reply mismatch: got %q want %q", i, got, want)
			}
		}(i)
	}
	wg.Wait()
	close(errCh)

	// Report every failing client rather than stopping at the first one, so
	// a partial failure under concurrency is visible in full.
	failed := 0
	for err := range errCh {
		if err == nil {
			continue
		}
		t.Error(err)
		failed++
	}
	if failed > 0 {
		t.Fatalf("%d of %d concurrent %s clients failed", failed, total, network)
	}
}