200 lines
5.7 KiB
Go
200 lines
5.7 KiB
Go
|
|
package notify
|
||
|
|
|
||
|
|
import (
	"fmt"
	"os"
	"path/filepath"
	"runtime"
	"sync"
	"testing"
	"time"
)
|
||
|
|
|
||
|
|
func TestSignalRoundTripConcurrentTCP(t *testing.T) {
|
||
|
|
server, addr := startSignalRoundTripServer(t, "tcp", "127.0.0.1:0")
|
||
|
|
defer func() {
|
||
|
|
_ = server.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
runConcurrentSignalRoundTripClients(t, "tcp", addr, 24)
|
||
|
|
}
|
||
|
|
|
||
|
|
// TestSignalRoundTripConcurrentUnix starts a signal round-trip server on a
// unix domain socket and drives 24 concurrent clients against it, each
// expecting its own payload echoed back. Skipped on Windows.
func TestSignalRoundTripConcurrentUnix(t *testing.T) {
	if runtime.GOOS == "windows" {
		t.Skip("unix socket is not available on windows")
	}

	// t.TempDir() embeds the full test name (plus random suffixes) in the
	// path. Together with the socket file name, that can exceed the sun_path
	// limit (104 bytes on macOS, 108 on Linux) and make Listen fail. A short
	// directory under the system temp dir keeps the socket path well below it.
	dir, err := os.MkdirTemp("", "nsig")
	if err != nil {
		t.Fatalf("create socket dir: %v", err)
	}
	// Best-effort removal; a leftover temp dir must not fail the test.
	t.Cleanup(func() { _ = os.RemoveAll(dir) })

	addr := filepath.Join(dir, "notify-signal.sock")
	server, endpoint := startSignalRoundTripServer(t, "unix", addr)
	defer func() {
		_ = server.Stop()
	}()

	runConcurrentSignalRoundTripClients(t, "unix", endpoint, 24)
}
|
||
|
|
|
||
|
|
func TestSignalRoundTripConcurrentUDP(t *testing.T) {
|
||
|
|
server, addr := startSignalRoundTripServer(t, "udp", "127.0.0.1:0")
|
||
|
|
defer func() {
|
||
|
|
_ = server.Stop()
|
||
|
|
}()
|
||
|
|
|
||
|
|
runConcurrentSignalRoundTripClients(t, "udp", addr, 24)
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestSignalReliabilityStatsRoundTripTCP(t *testing.T) {
|
||
|
|
runSignalReliabilityStatsRoundTrip(t, "tcp", "127.0.0.1:0")
|
||
|
|
}
|
||
|
|
|
||
|
|
func TestSignalReliabilityStatsRoundTripUDP(t *testing.T) {
|
||
|
|
runSignalReliabilityStatsRoundTrip(t, "udp", "127.0.0.1:0")
|
||
|
|
}
|
||
|
|
|
||
|
|
// runSignalReliabilityStatsRoundTrip performs one SendWait on the
// "signal-roundtrip" link with signal reliability enabled on both client and
// server. It checks the echoed reply and then the reliability counters.
//
// addr is the listen address handed to the server; the client dials whatever
// address startSignalRoundTripServer reports back. The reply must be
// "ack:stats-check".
//
// Afterwards the counters are polled every 10ms, for roughly two seconds,
// until both of these are at least 1:
//   - the client's AckDeliverTotal
//   - the server's AckSendTotal
//
// NOTE(review): the polling suggests these counters may be updated
// asynchronously relative to the reply — confirm.
//
// The test then requires the client's SignalSendTotal, ReliableSendTotal,
// AckWaitTotal and AckDeliverTotal to be non-zero, and the server's
// AckSendTotal to be non-zero.
func runSignalReliabilityStatsRoundTrip(t *testing.T, network string, addr string) {
	t.Helper()

	server, endpoint := startSignalRoundTripServer(t, network, addr)
	defer func() {
		_ = server.Stop()
	}()

	// The same option set is installed on both ends.
	opts := &SignalReliabilityOptions{
		Enabled:           true,
		AckTimeout:        2 * time.Second,
		SendRetry:         4,
		ReceiveCacheLimit: 64,
	}
	if err := UseSignalReliabilityServer(server, opts); err != nil {
		t.Fatalf("UseSignalReliabilityServer failed: %v", err)
	}

	client := NewClient().(*ClientCommon)
	if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
		t.Fatalf("UseModernPSKClient failed: %v", err)
	}
	if err := UseSignalReliabilityClient(client, opts); err != nil {
		t.Fatalf("UseSignalReliabilityClient failed: %v", err)
	}
	if err := client.Connect(network, endpoint); err != nil {
		t.Fatalf("client Connect failed: %v", err)
	}
	defer func() {
		_ = client.Stop()
	}()

	reply, err := client.SendWait("signal-roundtrip", []byte("stats-check"), 6*time.Second)
	if err != nil {
		// Attach whatever counters are readable to make the failure diagnosable.
		cs, csErr := GetSignalReliabilityStatsClient(client)
		ss, ssErr := GetSignalReliabilityStatsServer(server)
		t.Fatalf("SendWait failed: %v (clientStats=%+v clientStatsErr=%v serverStats=%+v serverStatsErr=%v)", err, cs, csErr, ss, ssErr)
	}
	if got, want := string(reply.Value), "ack:stats-check"; got != want {
		t.Fatalf("reply mismatch: got %q want %q", got, want)
	}

	var clientStats, serverStats SignalReliabilityStats
	deadline := time.Now().Add(2 * time.Second)
	for {
		if clientStats, err = GetSignalReliabilityStatsClient(client); err != nil {
			t.Fatalf("GetSignalReliabilityStatsClient failed: %v", err)
		}
		if serverStats, err = GetSignalReliabilityStatsServer(server); err != nil {
			t.Fatalf("GetSignalReliabilityStatsServer failed: %v", err)
		}
		acked := clientStats.AckDeliverTotal >= 1 && serverStats.AckSendTotal >= 1
		if acked || time.Now().After(deadline) {
			break
		}
		time.Sleep(10 * time.Millisecond)
	}

	switch {
	case clientStats.SignalSendTotal < 1,
		clientStats.ReliableSendTotal < 1,
		clientStats.AckWaitTotal < 1,
		clientStats.AckDeliverTotal < 1:
		t.Fatalf("client signal reliability stats mismatch: %+v", clientStats)
	case serverStats.AckSendTotal < 1:
		t.Fatalf("server signal reliability stats mismatch: %+v", serverStats)
	}
}
|
||
|
|
|
||
|
|
// startSignalRoundTripServer builds a ServerCommon secured with the shared
// integration PSK. It registers a "signal-roundtrip" link whose handler replies
// to each message with its payload prefixed by "ack:", then starts listening on
// network/addr. It returns the server and the address clients should dial, as
// computed by signalRoundTripServerAddr with addr as the fallback. Any setup
// error fails the test immediately.
func startSignalRoundTripServer(t *testing.T, network string, addr string) (*ServerCommon, string) {
	t.Helper()

	srv := NewServer().(*ServerCommon)
	err := UseModernPSKServer(srv, integrationSharedSecret, integrationModernPSKOptions())
	if err != nil {
		t.Fatalf("UseModernPSKServer failed: %v", err)
	}

	echo := func(msg *Message) {
		// The reply error is discarded here; a missing or wrong reply is
		// detected on the client side instead.
		_ = msg.Reply([]byte("ack:" + string(msg.Value)))
	}
	srv.SetLink("signal-roundtrip", echo)

	if err = srv.Listen(network, addr); err != nil {
		t.Fatalf("server Listen failed: %v", err)
	}
	return srv, signalRoundTripServerAddr(srv, addr)
}
|
||
|
|
|
||
|
|
func signalRoundTripServerAddr(server *ServerCommon, fallback string) string {
|
||
|
|
if server == nil {
|
||
|
|
return fallback
|
||
|
|
}
|
||
|
|
if server.listener != nil && server.listener.Addr() != nil {
|
||
|
|
if value := server.listener.Addr().String(); value != "" {
|
||
|
|
return value
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if server.udpListener != nil && server.udpListener.LocalAddr() != nil {
|
||
|
|
if value := server.udpListener.LocalAddr().String(); value != "" {
|
||
|
|
return value
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return fallback
|
||
|
|
}
|
||
|
|
|
||
|
|
// runConcurrentSignalRoundTripClients starts total clients in parallel against
// addr on the given network. Each client:
//   - configures modern PSK security,
//   - connects,
//   - sends "hello-<i>" on the "signal-roundtrip" link,
//   - expects "ack:hello-<i>" back within 3 seconds.
//
// Failures inside goroutines are sent over a channel and reported from the test
// goroutine, because t.Fatal must not be called from other goroutines. Every
// failure is reported, not just the first, so a partially broken run shows how
// many clients failed and why. A negative total fails the test immediately;
// zero clients is a no-op.
func runConcurrentSignalRoundTripClients(t *testing.T, network string, addr string, total int) {
	t.Helper()

	if total < 0 {
		// make(chan error, total) would panic on a negative size.
		t.Fatalf("runConcurrentSignalRoundTripClients: negative client count %d", total)
	}

	var wg sync.WaitGroup
	// Buffered to total so no goroutine blocks on send: each sends at most once.
	errCh := make(chan error, total)

	for i := 0; i < total; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			client := NewClient()
			if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
				errCh <- fmt.Errorf("client %d configure security: %w", i, err)
				return
			}
			if err := client.Connect(network, addr); err != nil {
				errCh <- fmt.Errorf("client %d connect: %w", i, err)
				return
			}
			defer func() {
				_ = client.Stop()
			}()

			payload := fmt.Sprintf("hello-%d", i)
			reply, err := client.SendWait("signal-roundtrip", []byte(payload), 3*time.Second)
			if err != nil {
				errCh <- fmt.Errorf("client %d SendWait: %w", i, err)
				return
			}
			want := "ack:" + payload
			if got := string(reply.Value); got != want {
				errCh <- fmt.Errorf("client %d reply mismatch: got %q want %q", i, got, want)
			}
		}(i)
	}

	wg.Wait()
	close(errCh)
	// Only non-nil errors are ever sent, so each received value is a failure.
	for err := range errCh {
		t.Error(err)
	}
}
|