notify/signal_roundtrip_test.go

200 lines
5.7 KiB
Go
Raw Permalink Normal View History

package notify
import (
	"fmt"
	"os"
	"path/filepath"
	"runtime"
	"sync"
	"testing"
	"time"
)
// TestSignalRoundTripConcurrentTCP runs 24 concurrent PSK-secured clients
// against a server listening on an ephemeral loopback TCP port.
func TestSignalRoundTripConcurrentTCP(t *testing.T) {
	const clients = 24
	srv, endpoint := startSignalRoundTripServer(t, "tcp", "127.0.0.1:0")
	defer func() {
		// Stop errors are irrelevant once the test body has finished.
		_ = srv.Stop()
	}()
	runConcurrentSignalRoundTripClients(t, "tcp", endpoint, clients)
}
// TestSignalRoundTripConcurrentUnix runs 24 concurrent PSK-secured clients
// against a server listening on a Unix domain socket.
//
// The socket is placed in a directory created by os.MkdirTemp with a short
// prefix instead of t.TempDir(). t.TempDir embeds the full test name (plus a
// random suffix and a sequence subdirectory), and on macOS, where $TMPDIR is
// typically a long /var/folders/... path, the resulting socket path exceeds
// the 104-byte sun_path limit, so Listen would fail on that platform.
func TestSignalRoundTripConcurrentUnix(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("unix socket is not available on windows")
	}
	dir, err := os.MkdirTemp("", "nsig")
	if err != nil {
		t.Fatalf("create unix socket dir: %v", err)
	}
	t.Cleanup(func() {
		// Best-effort removal; a leftover temp dir must not fail the test.
		_ = os.RemoveAll(dir)
	})
	addr := filepath.Join(dir, "notify-signal.sock")
	server, endpoint := startSignalRoundTripServer(t, "unix", addr)
	// The deferred Stop runs before the Cleanup above, so the socket is
	// closed before its directory is removed.
	defer func() {
		_ = server.Stop()
	}()
	runConcurrentSignalRoundTripClients(t, "unix", endpoint, 24)
}
// TestSignalRoundTripConcurrentUDP runs 24 concurrent PSK-secured clients
// against a server listening on an ephemeral loopback UDP port.
func TestSignalRoundTripConcurrentUDP(t *testing.T) {
	const clients = 24
	srv, endpoint := startSignalRoundTripServer(t, "udp", "127.0.0.1:0")
	defer func() {
		// Stop errors are irrelevant once the test body has finished.
		_ = srv.Stop()
	}()
	runConcurrentSignalRoundTripClients(t, "udp", endpoint, clients)
}
// TestSignalReliabilityStatsRoundTripTCP checks the signal reliability
// counters after one acknowledged round trip over loopback TCP.
func TestSignalReliabilityStatsRoundTripTCP(t *testing.T) {
	const network, listenAddr = "tcp", "127.0.0.1:0"
	runSignalReliabilityStatsRoundTrip(t, network, listenAddr)
}
// TestSignalReliabilityStatsRoundTripUDP checks the signal reliability
// counters after one acknowledged round trip over loopback UDP.
func TestSignalReliabilityStatsRoundTripUDP(t *testing.T) {
	const network, listenAddr = "udp", "127.0.0.1:0"
	runSignalReliabilityStatsRoundTrip(t, network, listenAddr)
}
// runSignalReliabilityStatsRoundTrip performs a single reliable SendWait on
// the "signal-roundtrip" link over network/addr and then asserts that both
// endpoints' signal reliability statistics recorded the exchange.
//
// Server and client share the same reliability options (2s ack timeout,
// 4 send retries, receive cache limit 64). After the reply arrives, the stats
// are polled every 10ms for up to 2s until the client reports an ack delivery
// and the server reports an ack send. This is presumably because the counters
// can be updated asynchronously after SendWait returns. The test then
// asserts the client's send/reliable-send/ack-wait/ack-deliver totals and the
// server's ack-send total are each at least 1.
func runSignalReliabilityStatsRoundTrip(t *testing.T, network string, addr string) {
	t.Helper()
	srv, endpoint := startSignalRoundTripServer(t, network, addr)
	defer func() {
		_ = srv.Stop()
	}()

	opts := &SignalReliabilityOptions{
		Enabled:           true,
		AckTimeout:        2 * time.Second,
		SendRetry:         4,
		ReceiveCacheLimit: 64,
	}
	if err := UseSignalReliabilityServer(srv, opts); err != nil {
		t.Fatalf("UseSignalReliabilityServer failed: %v", err)
	}

	cli := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(cli, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := UseSignalReliabilityClient(cli, opts); err != nil {
		t.Fatalf("UseSignalReliabilityClient failed: %v", err)
	}
	if err := cli.Connect(network, endpoint); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() {
		_ = cli.Stop()
	}()

	reply, sendErr := cli.SendWait("signal-roundtrip", []byte("stats-check"), 6*time.Second)
	if sendErr != nil {
		// Dump both sides' counters to make a failed send easier to diagnose.
		cs, csErr := GetSignalReliabilityStatsClient(cli)
		ss, ssErr := GetSignalReliabilityStatsServer(srv)
		t.Fatalf("SendWait failed: %v (clientStats=%+v clientStatsErr=%v serverStats=%+v serverStatsErr=%v)", sendErr, cs, csErr, ss, ssErr)
	}
	const wantReply = "ack:stats-check"
	if got := string(reply.Value); got != wantReply {
		t.Fatalf("reply mismatch: got %q want %q", got, wantReply)
	}

	// fetch reads the client's and then the server's counters, aborting the
	// test if either lookup fails.
	fetch := func() (SignalReliabilityStats, SignalReliabilityStats) {
		t.Helper()
		c, err := GetSignalReliabilityStatsClient(cli)
		if err != nil {
			t.Fatalf("GetSignalReliabilityStatsClient failed: %v", err)
		}
		s, err := GetSignalReliabilityStatsServer(srv)
		if err != nil {
			t.Fatalf("GetSignalReliabilityStatsServer failed: %v", err)
		}
		return c, s
	}

	deadline := time.Now().Add(2 * time.Second)
	clientStats, serverStats := fetch()
	for (clientStats.AckDeliverTotal < 1 || serverStats.AckSendTotal < 1) && !time.Now().After(deadline) {
		time.Sleep(10 * time.Millisecond)
		clientStats, serverStats = fetch()
	}

	if clientStats.SignalSendTotal < 1 || clientStats.ReliableSendTotal < 1 || clientStats.AckWaitTotal < 1 || clientStats.AckDeliverTotal < 1 {
		t.Fatalf("client signal reliability stats mismatch: %+v", clientStats)
	}
	if serverStats.AckSendTotal < 1 {
		t.Fatalf("server signal reliability stats mismatch: %+v", serverStats)
	}
}
// startSignalRoundTripServer builds a ServerCommon secured with the shared
// integration PSK and registers a handler on the "signal-roundtrip" link.
// The handler replies with "ack:" followed by the request payload. It then
// listens on network/addr.
//
// It returns the server together with the address clients should dial. That
// is the bound listener address when one is available (resolving ephemeral
// ports such as "127.0.0.1:0"), otherwise addr itself. Any setup failure
// aborts the test.
func startSignalRoundTripServer(t *testing.T, network string, addr string) (*ServerCommon, string) {
	t.Helper()
	srv := NewServer().(*ServerCommon)
	if err := UseModernPSKServer(srv, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}
	echo := func(msg *Message) {
		// Reply errors are deliberately ignored here; a lost or wrong reply
		// surfaces on the client side through SendWait.
		answer := "ack:" + string(msg.Value)
		_ = msg.Reply([]byte(answer))
	}
	srv.SetLink("signal-roundtrip", echo)
	if err := srv.Listen(network, addr); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	endpoint := signalRoundTripServerAddr(srv, addr)
	return srv, endpoint
}
// signalRoundTripServerAddr returns the address server is actually bound to.
// It prefers the stream listener's address and falls back to the UDP
// listener's local address. fallback is returned when server is nil or when
// neither listener yields a non-nil address with a non-empty string form.
func signalRoundTripServerAddr(server *ServerCommon, fallback string) string {
	if server == nil {
		return fallback
	}
	if ln := server.listener; ln != nil {
		if bound := ln.Addr(); bound != nil {
			if text := bound.String(); text != "" {
				return text
			}
		}
	}
	if conn := server.udpListener; conn != nil {
		if bound := conn.LocalAddr(); bound != nil {
			if text := bound.String(); text != "" {
				return text
			}
		}
	}
	return fallback
}
// runConcurrentSignalRoundTripClients launches total goroutines, each creating
// its own PSK-secured client, connecting to addr over network, and performing
// one SendWait on the "signal-roundtrip" link. Each client sends the payload
// "hello-<i>" and expects the reply "ack:hello-<i>" within 3 seconds.
//
// The goroutines never call t.Fatal themselves, because FailNow must run on
// the test goroutine. They report through a buffered channel instead. Every
// failing client is reported with t.Error before the test is stopped, so a
// single run shows all failures rather than only the first one drained.
// A non-positive total is rejected, because it would otherwise either panic
// (negative channel size) or pass without exercising anything.
func runConcurrentSignalRoundTripClients(t *testing.T, network string, addr string, total int) {
	t.Helper()
	if total < 1 {
		t.Fatalf("runConcurrentSignalRoundTripClients: total must be positive, got %d", total)
	}

	// roundTrip runs the full connect/send/verify cycle for client i and
	// returns a descriptive error on the first failing step.
	roundTrip := func(i int) error {
		client := NewClient()
		if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
			return fmt.Errorf("client %d configure security: %w", i, err)
		}
		if err := client.Connect(network, addr); err != nil {
			return fmt.Errorf("client %d connect: %w", i, err)
		}
		defer func() {
			_ = client.Stop()
		}()
		payload := fmt.Sprintf("hello-%d", i)
		reply, err := client.SendWait("signal-roundtrip", []byte(payload), 3*time.Second)
		if err != nil {
			return fmt.Errorf("client %d SendWait: %w", i, err)
		}
		if got, want := string(reply.Value), "ack:"+payload; got != want {
			return fmt.Errorf("client %d reply mismatch: got %q want %q", i, got, want)
		}
		return nil
	}

	var wg sync.WaitGroup
	// Each goroutine sends at most one error, so a buffer of total never blocks.
	errCh := make(chan error, total)
	for i := 0; i < total; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			if err := roundTrip(i); err != nil {
				errCh <- err
			}
		}(i)
	}
	wg.Wait()
	close(errCh)

	failed := 0
	for err := range errCh {
		t.Error(err)
		failed++
	}
	if failed > 0 {
		t.Fatalf("%d of %d concurrent %s clients failed", failed, total, network)
	}
}