package starnotify import ( "context" "errors" "net" "sync" "testing" "time" "b612.me/notify" ) var errSingleConnListenerClosed = errors.New("single conn listener closed") type singleConnListener struct { conn net.Conn used bool mu sync.Mutex closed chan struct{} once sync.Once } func newSingleConnListener(conn net.Conn) *singleConnListener { return &singleConnListener{ conn: conn, closed: make(chan struct{}), } } func (l *singleConnListener) Accept() (net.Conn, error) { l.mu.Lock() if !l.used && l.conn != nil { conn := l.conn l.used = true l.mu.Unlock() return conn, nil } l.mu.Unlock() <-l.closed return nil, errSingleConnListenerClosed } func (l *singleConnListener) Close() error { l.once.Do(func() { close(l.closed) }) return nil } func (l *singleConnListener) Addr() net.Addr { return singleConnAddr("starnotify-single-listener") } type singleConnAddr string func (a singleConnAddr) Network() string { return "single-conn" } func (a singleConnAddr) String() string { return string(a) } func TestGetRuntimeSnapshotByKeyDefaults(t *testing.T) { const clientKey = "runtime-snapshot-client" const serverKey = "runtime-snapshot-server" _ = DeleteClient(clientKey) _ = DeleteServer(serverKey) defer DeleteClient(clientKey) defer DeleteServer(serverKey) NewClient(clientKey) NewServer(serverKey) clientSnapshot, err := GetClientRuntimeSnapshot(clientKey) if err != nil { t.Fatalf("GetClientRuntimeSnapshot failed: %v", err) } if got, want := clientSnapshot.OwnerState, "idle"; got != want { t.Fatalf("client OwnerState mismatch: got %q want %q", got, want) } if clientSnapshot.Alive { t.Fatalf("client Alive mismatch: got %v want false", clientSnapshot.Alive) } if !clientSnapshot.HasRuntimeStopCtx { t.Fatalf("client HasRuntimeStopCtx mismatch: got %v want true", clientSnapshot.HasRuntimeStopCtx) } if clientSnapshot.Retry != (notify.ConnectionRetrySnapshot{}) { t.Fatalf("client Retry snapshot mismatch: %+v", clientSnapshot.Retry) } serverSnapshot, err := GetServerRuntimeSnapshot(serverKey) if err != nil { t.Fatalf("GetServerRuntimeSnapshot failed: %v", err) } if got, want := serverSnapshot.OwnerState, "idle"; got != want { t.Fatalf("server OwnerState mismatch: got %q want %q", got, want) } if serverSnapshot.Alive { t.Fatalf("server Alive mismatch: got %v want false", serverSnapshot.Alive) } if !serverSnapshot.HasRuntimeStopCtx { t.Fatalf("server HasRuntimeStopCtx mismatch: got %v want true", serverSnapshot.HasRuntimeStopCtx) } if serverSnapshot.Retry != (notify.ConnectionRetrySnapshot{}) { t.Fatalf("server Retry snapshot mismatch: %+v", serverSnapshot.Retry) } } func TestGetRuntimeSnapshotMissingKey(t *testing.T) { if _, err := GetClientRuntimeSnapshot("missing-client"); err == nil { t.Fatal("GetClientRuntimeSnapshot should fail for missing key") } if _, err := GetServerRuntimeSnapshot("missing-server"); err == nil { t.Fatal("GetServerRuntimeSnapshot should fail for missing key") } if _, err := GetServerClientRuntimeSnapshot("missing-server", "peer"); err == nil { t.Fatal("GetServerClientRuntimeSnapshot should fail for missing server key") } if _, err := GetServerLogicalConn("missing-server", "peer"); err == nil { t.Fatal("GetServerLogicalConn should fail for missing server key") } if _, ok, err := GetServerCurrentTransportConn("missing-server", "peer"); err == nil || ok { t.Fatal("GetServerCurrentTransportConn should fail for missing server key") } if _, ok, err := GetServerClientTransportRuntimeSnapshot("missing-server", "peer"); err == nil || ok { t.Fatal("GetServerClientTransportRuntimeSnapshot should fail for missing server key") } if _, err := GetServerDetachedClientRuntimeSnapshots("missing-server"); err == nil { t.Fatal("GetServerDetachedClientRuntimeSnapshots should fail for missing server key") } } func TestGetRuntimeSnapshotExposesRetryState(t *testing.T) { const clientKey = "runtime-retry-client" const serverKey = "runtime-retry-server" _ = DeleteClient(clientKey) _ = DeleteServer(serverKey) defer DeleteClient(clientKey) defer DeleteServer(serverKey) NewLegacySecurityClient(clientKey) NewServer(serverKey) clientRetryErr := errors.New("dial failed") err := ConnectClientFactoryWithRetry(clientKey, func(context.Context) (net.Conn, error) { return nil, clientRetryErr }, ¬ify.ConnectRetryOptions{ MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: time.Millisecond, }) if !errors.Is(err, clientRetryErr) { t.Fatalf("ConnectClientFactoryWithRetry error = %v, want %v", err, clientRetryErr) } clientSnapshot, err := GetClientRuntimeSnapshot(clientKey) if err != nil { t.Fatalf("GetClientRuntimeSnapshot failed: %v", err) } if got, want := clientSnapshot.Retry.RetryEventTotal, uint64(1); got != want { t.Fatalf("client RetryEventTotal mismatch: got %d want %d", got, want) } if got, want := clientSnapshot.Retry.LastRetryAttempt, 1; got != want { t.Fatalf("client LastRetryAttempt mismatch: got %d want %d", got, want) } if got, want := clientSnapshot.Retry.LastRetryError, clientRetryErr.Error(); got != want { t.Fatalf("client LastRetryError mismatch: got %q want %q", got, want) } if got, want := clientSnapshot.Retry.LastResultError, clientRetryErr.Error(); got != want { t.Fatalf("client LastResultError mismatch: got %q want %q", got, want) } if clientSnapshot.Retry.LastRetryAt.IsZero() { t.Fatal("client LastRetryAt should be recorded") } if clientSnapshot.Retry.LastResultAt.IsZero() { t.Fatal("client LastResultAt should be recorded") } serverErr := ListenServerWithRetry(serverKey, "tcp", "127.0.0.1:0", ¬ify.ConnectRetryOptions{ MaxAttempts: 2, BaseDelay: time.Millisecond, MaxDelay: time.Millisecond, }) if serverErr == nil { t.Fatal("ListenServerWithRetry should fail without security configuration") } serverSnapshot, err := GetServerRuntimeSnapshot(serverKey) if err != nil { t.Fatalf("GetServerRuntimeSnapshot failed: %v", err) } if got, want := serverSnapshot.Retry.RetryEventTotal, uint64(1); got != want { t.Fatalf("server RetryEventTotal mismatch: got %d want %d", got, want) } if got, want := serverSnapshot.Retry.LastRetryAttempt, 1; got != want { t.Fatalf("server LastRetryAttempt mismatch: got %d want %d", got, want) } if got, want := serverSnapshot.Retry.LastRetryError, serverErr.Error(); got != want { t.Fatalf("server LastRetryError mismatch: got %q want %q", got, want) } if got, want := serverSnapshot.Retry.LastResultError, serverErr.Error(); got != want { t.Fatalf("server LastResultError mismatch: got %q want %q", got, want) } if serverSnapshot.Retry.LastRetryAt.IsZero() { t.Fatal("server LastRetryAt should be recorded") } if serverSnapshot.Retry.LastResultAt.IsZero() { t.Fatal("server LastResultAt should be recorded") } } func TestGetServerClientRuntimeSnapshotByKey(t *testing.T) { const serverKey = "runtime-peer-server" _ = DeleteServer(serverKey) defer DeleteServer(serverKey) server := NewServer(serverKey) client := notify.NewClient() secret := []byte("0123456789abcdef0123456789abcdef") server.SetSecretKey(secret) client.SetSecretKey(secret) serverConn, clientConn := net.Pipe() defer clientConn.Close() listener := newSingleConnListener(serverConn) defer listener.Close() if err := server.ListenByListener(listener); err != nil { t.Fatalf("ListenByListener failed: %v", err) } defer server.Stop() if err := client.ConnectByConn(clientConn); err != nil { t.Fatalf("ConnectByConn failed: %v", err) } defer client.Stop() srv, err := Server(serverKey) if err != nil { t.Fatalf("Server lookup failed: %v", err) } var clientID string deadline := time.Now().Add(time.Second) for time.Now().Before(deadline) { list := srv.GetLogicalConnList() if len(list) == 1 && list[0] != nil { clientID = list[0].ClientID break } time.Sleep(10 * time.Millisecond) } if clientID == "" { t.Fatal("server did not expose accepted client in time") } snapshot, err := GetServerClientRuntimeSnapshot(serverKey, clientID) if err != nil { t.Fatalf("GetServerClientRuntimeSnapshot failed: %v", err) } if got, want := snapshot.ClientID, clientID; got != want { t.Fatalf("ClientID mismatch: got %q want %q", got, want) } if !snapshot.Alive { t.Fatalf("Alive mismatch: got %v want true", snapshot.Alive) } if !snapshot.IdentityBound { t.Fatal("IdentityBound mismatch: got false want true") } if !snapshot.TransportAttached { t.Fatalf("TransportAttached mismatch: got %v want true", snapshot.TransportAttached) } logicalSnapshot, err := GetServerLogicalRuntimeSnapshot(serverKey, clientID) if err != nil { t.Fatalf("GetServerLogicalRuntimeSnapshot failed: %v", err) } if logicalSnapshot.ClientID != snapshot.ClientID || logicalSnapshot.TransportGeneration != snapshot.TransportGeneration { t.Fatalf("logical runtime snapshot mismatch: got %+v want %+v", logicalSnapshot, snapshot) } } func TestGetServerClientTransportRuntimeSnapshotByKey(t *testing.T) { const serverKey = "runtime-transport-server" _ = DeleteServer(serverKey) defer DeleteServer(serverKey) server := NewServer(serverKey) client := notify.NewClient() secret := []byte("0123456789abcdef0123456789abcdef") server.SetSecretKey(secret) client.SetSecretKey(secret) serverConn, clientConn := net.Pipe() defer clientConn.Close() listener := newSingleConnListener(serverConn) defer listener.Close() if err := server.ListenByListener(listener); err != nil { t.Fatalf("ListenByListener failed: %v", err) } defer server.Stop() if err := client.ConnectByConn(clientConn); err != nil { t.Fatalf("ConnectByConn failed: %v", err) } defer client.Stop() srv, err := Server(serverKey) if err != nil { t.Fatalf("Server lookup failed: %v", err) } var clientID string deadline := time.Now().Add(time.Second) for time.Now().Before(deadline) { list := srv.GetLogicalConnList() if len(list) == 1 && list[0] != nil { clientID = list[0].ClientID if clientID != "" { break } } time.Sleep(10 * time.Millisecond) } if clientID == "" { t.Fatal("server did not expose accepted client in time") } snapshot, ok, err := GetServerClientTransportRuntimeSnapshot(serverKey, clientID) if err != nil { t.Fatalf("GetServerClientTransportRuntimeSnapshot failed: %v", err) } if !ok { t.Fatal("GetServerClientTransportRuntimeSnapshot should report current transport") } if got, want := snapshot.ClientID, clientID; got != want { t.Fatalf("ClientID mismatch: got %q want %q", got, want) } if !snapshot.Attached { t.Fatal("Attached mismatch: got false want true") } if !snapshot.HasRuntimeConn { t.Fatal("HasRuntimeConn mismatch: got false want true") } if !snapshot.Current { t.Fatal("Current mismatch: got false want true") } transportSnapshot, ok, err := GetServerTransportRuntimeSnapshot(serverKey, clientID) if err != nil { t.Fatalf("GetServerTransportRuntimeSnapshot failed: %v", err) } if !ok { t.Fatal("GetServerTransportRuntimeSnapshot should report current transport") } if transportSnapshot.ClientID != snapshot.ClientID || transportSnapshot.TransportGeneration != snapshot.TransportGeneration { t.Fatalf("transport runtime snapshot mismatch: got %+v want %+v", transportSnapshot, snapshot) } } func TestGetServerLogicalAndTransportConnByKey(t *testing.T) { const serverKey = "runtime-conn-object-server" _ = DeleteServer(serverKey) defer DeleteServer(serverKey) server := NewServer(serverKey) client := notify.NewClient() secret := []byte("0123456789abcdef0123456789abcdef") server.SetSecretKey(secret) client.SetSecretKey(secret) serverConn, clientConn := net.Pipe() defer clientConn.Close() listener := newSingleConnListener(serverConn) defer listener.Close() if err := server.ListenByListener(listener); err != nil { t.Fatalf("ListenByListener failed: %v", err) } defer server.Stop() if err := client.ConnectByConn(clientConn); err != nil { t.Fatalf("ConnectByConn failed: %v", err) } defer client.Stop() srv, err := Server(serverKey) if err != nil { t.Fatalf("Server lookup failed: %v", err) } var clientID string deadline := time.Now().Add(time.Second) for time.Now().Before(deadline) { list := srv.GetLogicalConnList() if len(list) == 1 && list[0] != nil && list[0].ClientID != "" { clientID = list[0].ClientID break } time.Sleep(10 * time.Millisecond) } if clientID == "" { t.Fatal("server did not expose accepted logical conn in time") } logical, err := GetServerLogicalConn(serverKey, clientID) if err != nil { t.Fatalf("GetServerLogicalConn failed: %v", err) } if logical == nil || logical.ClientID != clientID { t.Fatalf("logical conn mismatch: %+v", logical) } transport, ok, err := GetServerCurrentTransportConn(serverKey, clientID) if err != nil { t.Fatalf("GetServerCurrentTransportConn failed: %v", err) } if !ok { t.Fatal("GetServerCurrentTransportConn should report current transport") } if transport == nil || transport.ClientID() != clientID || !transport.IsCurrent() { t.Fatalf("transport conn mismatch: %+v", transport) } } func TestGetServerDetachedClientRuntimeSnapshotsByKey(t *testing.T) { const serverKey = "runtime-detached-server" const clientKey = "runtime-detached-client" _ = DeleteClient(clientKey) _ = DeleteServer(serverKey) defer DeleteClient(clientKey) defer DeleteServer(serverKey) server, err := NewModernPSKServer(serverKey, []byte("shared-secret"), testModernPSKOptions()) if err != nil { t.Fatalf("NewModernPSKServer failed: %v", err) } server.SetDetachedClientKeepSec(30) client, err := NewModernPSKClient(clientKey, []byte("shared-secret"), testModernPSKOptions()) if err != nil { t.Fatalf("NewModernPSKClient failed: %v", err) } serverConn, clientConn := net.Pipe() listener := newSingleConnListener(serverConn) defer listener.Close() defer clientConn.Close() if err := server.ListenByListener(listener); err != nil { t.Fatalf("ListenByListener failed: %v", err) } defer server.Stop() if err := client.ConnectByConn(clientConn); err != nil { t.Fatalf("ConnectByConn failed: %v", err) } defer client.Stop() srv, err := Server(serverKey) if err != nil { t.Fatalf("Server lookup failed: %v", err) } deadline := time.Now().Add(time.Second) var boundClientID string for time.Now().Before(deadline) { list := srv.GetClientLists() if len(list) == 1 && list[0] != nil { snapshot, snapErr := notify.GetClientConnRuntimeSnapshot(list[0]) if snapErr == nil && snapshot.IdentityBound { boundClientID = snapshot.ClientID break } } time.Sleep(10 * time.Millisecond) } if boundClientID == "" { t.Fatal("server did not bind accepted client identity in time") } if err := clientConn.Close(); err != nil { t.Fatalf("close client conn failed: %v", err) } var snapshots []notify.ClientConnRuntimeSnapshot deadline = time.Now().Add(time.Second) for time.Now().Before(deadline) { snapshots, err = GetServerDetachedClientRuntimeSnapshots(serverKey) if err != nil { t.Fatalf("GetServerDetachedClientRuntimeSnapshots failed: %v", err) } if len(snapshots) == 1 { break } time.Sleep(10 * time.Millisecond) } if len(snapshots) != 1 { t.Fatalf("detached snapshot count mismatch: got %d want 1", len(snapshots)) } snapshot := snapshots[0] if got, want := snapshot.ClientID, boundClientID; got != want { t.Fatalf("detached snapshot ClientID mismatch: got %q want %q", got, want) } if snapshot.TransportAttached { t.Fatalf("detached snapshot TransportAttached mismatch: got %v want false", snapshot.TransportAttached) } if !snapshot.IdentityBound { t.Fatal("detached snapshot should remain identity bound") } if got, want := snapshot.DetachedClientKeepSec, int64(30); got != want { t.Fatalf("detached snapshot keep seconds mismatch: got %d want %d", got, want) } if snapshot.TransportDetachedAt.IsZero() { t.Fatal("detached snapshot should expose detach time") } } func TestGetTransferSnapshotsByKeyDefaults(t *testing.T) { const clientKey = "transfer-snapshot-client" const serverKey = "transfer-snapshot-server" _ = DeleteClient(clientKey) _ = DeleteServer(serverKey) defer DeleteClient(clientKey) defer DeleteServer(serverKey) NewClient(clientKey) NewServer(serverKey) clientSnapshots, err := GetClientTransferSnapshots(clientKey) if err != nil { t.Fatalf("GetClientTransferSnapshots failed: %v", err) } if len(clientSnapshots) != 0 { t.Fatalf("client transfer snapshots count = %d, want 0", len(clientSnapshots)) } serverSnapshots, err := GetServerTransferSnapshots(serverKey) if err != nil { t.Fatalf("GetServerTransferSnapshots failed: %v", err) } if len(serverSnapshots) != 0 { t.Fatalf("server transfer snapshots count = %d, want 0", len(serverSnapshots)) } if _, ok, err := GetClientTransferSnapshotByID(clientKey, "missing"); err != nil || ok { t.Fatalf("GetClientTransferSnapshotByID = (%v, %v), want (nil, false)", err, ok) } if _, ok, err := GetClientTransferSnapshotByIDScope(clientKey, "missing", "scope-a"); err != nil || ok { t.Fatalf("GetClientTransferSnapshotByIDScope = (%v, %v), want (nil, false)", err, ok) } if _, ok, err := GetClientTransferSnapshotByIDQuery(clientKey, "missing", notify.TransferSnapshotQuery{Scope: "scope-a"}); err != nil || ok { t.Fatalf("GetClientTransferSnapshotByIDQuery = (%v, %v), want (nil, false)", err, ok) } if _, ok, err := GetServerTransferSnapshotByID(serverKey, "missing"); err != nil || ok { t.Fatalf("GetServerTransferSnapshotByID = (%v, %v), want (nil, false)", err, ok) } if _, ok, err := GetServerTransferSnapshotByIDScope(serverKey, "missing", "scope-a"); err != nil || ok { t.Fatalf("GetServerTransferSnapshotByIDScope = (%v, %v), want (nil, false)", err, ok) } if _, ok, err := GetServerTransferSnapshotByIDQuery(serverKey, "missing", notify.TransferSnapshotQuery{Scope: "scope-a"}); err != nil || ok { t.Fatalf("GetServerTransferSnapshotByIDQuery = (%v, %v), want (nil, false)", err, ok) } }