package notify

import (
	"context"
	"errors"
	"net"
	"strings"
	"sync/atomic"
	"testing"
	"time"
)

// releaseP0TestAddr is a fixed net.Addr used by the tests in this file. It
// always reports the "tcp" network and renders as its own string value.
type releaseP0TestAddr string

// Compile-time check that releaseP0TestAddr satisfies net.Addr.
var _ net.Addr = releaseP0TestAddr("")

// Network implements net.Addr; it always reports "tcp".
func (a releaseP0TestAddr) Network() string { return "tcp" }

// String implements net.Addr; it returns the address text unchanged.
func (a releaseP0TestAddr) String() string { return string(a) }

// closeInspectConn is a stub net.Conn that lets a test observe state at the
// moment the connection is closed. Reads always fail with net.ErrClosed,
// writes report the full length as written, deadlines are accepted and
// ignored, and Close invokes closeFn at most once.
type closeInspectConn struct {
	// closeFn, if non-nil, runs on the first successful Close.
	closeFn func()
	// closed flips to true on the first Close, before closeFn runs.
	closed atomic.Bool
}

// Compile-time check that *closeInspectConn satisfies net.Conn.
var _ net.Conn = (*closeInspectConn)(nil)

// Read always fails with net.ErrClosed.
func (c *closeInspectConn) Read([]byte) (int, error) { return 0, net.ErrClosed }

// Write discards p and reports it as fully written.
func (c *closeInspectConn) Write(p []byte) (int, error) { return len(p), nil }

// LocalAddr returns a fixed placeholder address.
func (c *closeInspectConn) LocalAddr() net.Addr { return releaseP0TestAddr("local") }

// RemoteAddr returns a fixed placeholder address.
func (c *closeInspectConn) RemoteAddr() net.Addr { return releaseP0TestAddr("remote") }

// SetDeadline is a no-op.
func (c *closeInspectConn) SetDeadline(time.Time) error { return nil }

// SetReadDeadline is a no-op.
func (c *closeInspectConn) SetReadDeadline(time.Time) error { return nil }

// SetWriteDeadline is a no-op.
func (c *closeInspectConn) SetWriteDeadline(time.Time) error { return nil }

// Close marks the conn closed and runs closeFn on the first call only. Later
// calls, and calls on a nil receiver, do nothing. It always returns nil.
func (c *closeInspectConn) Close() error {
	if c == nil {
		return nil
	}
	// CompareAndSwap guarantees closeFn fires at most once even if Close is
	// called repeatedly.
	if c.closed.CompareAndSwap(false, true) && c.closeFn != nil {
		c.closeFn()
	}
	return nil
}

// TestGetLogicalConnRuntimeSnapshotWithoutCompatClient checks that a runtime
// snapshot can be taken from a LogicalConn with no compatibility client, and
// that the snapshot reports the identity, transport and detach state set on it.
func TestGetLogicalConnRuntimeSnapshotWithoutCompatClient(t *testing.T) {
	server := NewServer().(*ServerCommon)
	logical := &LogicalConn{server: server}
	logical.setID("logical-only")
	logical.setRemoteAddr(releaseP0TestAddr("127.0.0.1:28080"))
	logical.markSessionStarted()
	logical.markIdentityBound()
	logical.markStreamTransport()
	logical.markTransportAttached()
	logical.setClientConnLastHeartbeatUnix(time.Now().Unix())
	logical.markTransportDetached("read error", errors.New("boom"))

	snapshot, err := GetLogicalConnRuntimeSnapshot(logical)
	if err != nil {
		t.Fatalf("GetLogicalConnRuntimeSnapshot failed: %v", err)
	}
	if got, want := snapshot.ClientID, "logical-only"; got != want {
		t.Fatalf("ClientID = %q, want %q", got, want)
	}
	if got, want := snapshot.RemoteAddress, "127.0.0.1:28080"; got != want {
		t.Fatalf("RemoteAddress = %q, want %q", got, want)
	}
	if !snapshot.Alive {
		t.Fatal("Alive should be true")
	}
	if !snapshot.IdentityBound {
		t.Fatal("IdentityBound should be true")
	}
	if !snapshot.UsesStreamTransport {
		t.Fatal("UsesStreamTransport should be true")
	}
	if got, want := snapshot.TransportGeneration, uint64(1); got != want {
		t.Fatalf("TransportGeneration = %d, want %d", got, want)
	}
	if got, want := snapshot.TransportDetachReason, "read error"; got != want {
		t.Fatalf("TransportDetachReason = %q, want %q", got, want)
	}
	if got, want := snapshot.TransportDetachError, "boom"; got != want {
		t.Fatalf("TransportDetachError = %q, want %q", got, want)
	}
	if !snapshot.ReattachEligible {
		t.Fatal("ReattachEligible should be true")
	}
}

// TestPendingWaitClosedErrorWithTransportDetail checks that the pending-wait
// close error wraps errTransportDetached and carries both the detach reason
// and the underlying cause in its message.
func TestPendingWaitClosedErrorWithTransportDetail(t *testing.T) {
	logical := &LogicalConn{}
	logical.markSessionStarted()
	logical.markStreamTransport()
	logical.markTransportAttached()
	logical.markTransportDetached("read error", errors.New("boom"))

	err := pendingWaitClosedErrorWith(nil, transportDetachedErrorForLogical(logical))
	if !errors.Is(err, errTransportDetached) {
		t.Fatalf("pendingWaitClosedErrorWith = %v, want transport detached", err)
	}
	if !strings.Contains(err.Error(), "read error") || !strings.Contains(err.Error(), "boom") {
		t.Fatalf("pendingWaitClosedErrorWith detail = %q, want read error and boom", err.Error())
	}
}

// TestHandleDedicatedBulkReadErrorPreservesUnderlyingCause checks that a read
// error on a dedicated bulk resets it with an error that wraps
// errTransportDetached and keeps the original cause text.
func TestHandleDedicatedBulkReadErrorPreservesUnderlyingCause(t *testing.T) {
	runtime := newBulkRuntime("dedicated-read-error")
	bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{
		BulkID:    "dedicated-read-error",
		DataID:    1,
		Dedicated: true,
		Range: BulkRange{
			Length: 1,
		},
	}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
	if err := runtime.register(clientFileScope(), bulk); err != nil {
		t.Fatalf("register bulk failed: %v", err)
	}

	handleDedicatedBulkReadError(bulk, errors.New("boom read"))

	resetErr := bulk.resetErrSnapshot()
	if !errors.Is(resetErr, errTransportDetached) {
		t.Fatalf("resetErr = %v, want transport detached", resetErr)
	}
	if !strings.Contains(resetErr.Error(), "dedicated bulk read error") || !strings.Contains(resetErr.Error(), "boom read") {
		t.Fatalf("resetErr detail = %q, want dedicated read detail", resetErr.Error())
	}
}

// TestHandleClientDedicatedSidecarFailureMarksBulkBeforeClosingConn checks the
// ordering guarantee of sidecar failure handling: by the time the dedicated
// conn is closed, the bulk already carries a transport-detached reset error.
func TestHandleClientDedicatedSidecarFailureMarksBulkBeforeClosingConn(t *testing.T) {
	client := NewClient().(*ClientCommon)
	runtime := client.getBulkRuntime()
	if runtime == nil {
		t.Fatal("client bulk runtime should not be nil")
	}
	bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{
		BulkID:    "sidecar-order",
		DataID:    7,
		Dedicated: true,
		Range: BulkRange{
			Length: 1,
		},
	}, 0, nil, nil, 0, nil, nil, nil, nil, nil)

	// Capture the bulk's reset error at the instant the conn is closed. The
	// assertions below rely on this running synchronously inside the handler.
	var closeObservedErr error
	conn := &closeInspectConn{
		closeFn: func() {
			closeObservedErr = bulk.resetErrSnapshot()
		},
	}
	if err := bulk.attachDedicatedConnShared(conn); err != nil {
		t.Fatalf("attachDedicatedConnShared failed: %v", err)
	}
	if err := runtime.register(clientFileScope(), bulk); err != nil {
		t.Fatalf("register bulk failed: %v", err)
	}
	sidecar := newBulkDedicatedSidecar(conn, 1)
	client.installClientDedicatedSidecar(1, sidecar)

	client.handleClientDedicatedSidecarFailure(sidecar, errors.New("boom sidecar"))

	// Report a missing Close distinctly from a wrong reset error.
	if !conn.closed.Load() {
		t.Fatal("handleClientDedicatedSidecarFailure should close the dedicated conn")
	}
	if !errors.Is(closeObservedErr, errTransportDetached) {
		t.Fatalf("closeObservedErr = %v, want transport detached", closeObservedErr)
	}
	if !strings.Contains(closeObservedErr.Error(), "dedicated bulk read error") || !strings.Contains(closeObservedErr.Error(), "boom sidecar") {
		t.Fatalf("closeObservedErr detail = %q, want dedicated read error and cause", closeObservedErr.Error())
	}
}

// TestCleanupClientSessionResourcesMarksBulkBeforeClosingSidecar checks that
// session cleanup marks dedicated bulks with errServiceShutdown before it
// closes their sidecar conn.
func TestCleanupClientSessionResourcesMarksBulkBeforeClosingSidecar(t *testing.T) {
	client := NewClient().(*ClientCommon)
	runtime := client.getBulkRuntime()
	if runtime == nil {
		t.Fatal("client bulk runtime should not be nil")
	}
	bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{
		BulkID:    "cleanup-order",
		DataID:    9,
		Dedicated: true,
		Range: BulkRange{
			Length: 1,
		},
	}, 0, nil, nil, 0, nil, nil, nil, nil, nil)

	// Capture the bulk's reset error at the instant the conn is closed. The
	// assertions below rely on this running synchronously inside cleanup.
	var closeObservedErr error
	conn := &closeInspectConn{
		closeFn: func() {
			closeObservedErr = bulk.resetErrSnapshot()
		},
	}
	if err := bulk.attachDedicatedConnShared(conn); err != nil {
		t.Fatalf("attachDedicatedConnShared failed: %v", err)
	}
	if err := runtime.register(clientFileScope(), bulk); err != nil {
		t.Fatalf("register bulk failed: %v", err)
	}
	client.installClientDedicatedSidecar(1, newBulkDedicatedSidecar(conn, 1))

	client.cleanupClientSessionResources()

	// Report a missing Close distinctly from a wrong reset error.
	if !conn.closed.Load() {
		t.Fatal("cleanupClientSessionResources should close the dedicated sidecar conn")
	}
	if !errors.Is(closeObservedErr, errServiceShutdown) {
		t.Fatalf("closeObservedErr = %v, want %v", closeObservedErr, errServiceShutdown)
	}
}

// TestBestEffortRejectInboundDedicatedDataUsesDedicatedResetRecord checks the
// reject path for unknown inbound dedicated data. It expects exactly one
// dedicated record that decrypts to a single reset item carrying the reason.
func TestBestEffortRejectInboundDedicatedDataUsesDedicatedResetRecord(t *testing.T) {
	server := NewServer().(*ServerCommon)
	UseLegacySecurityServer(server)
	logical := server.bootstrapAcceptedLogical("dedicated-reject", nil, nil)
	if logical == nil {
		t.Fatal("bootstrapAcceptedLogical should return logical")
	}
	conn := newBulkAttachScriptConn(nil)

	server.bestEffortRejectInboundDedicatedData(logical, conn, 42, "unknown data id")

	// Replay the bytes written by the server through a fresh script conn so
	// they can be parsed as a dedicated record.
	recordConn := newBulkAttachScriptConn(conn.writtenBytes())
	payload, err := readBulkDedicatedRecord(recordConn)
	if err != nil {
		t.Fatalf("readBulkDedicatedRecord failed: %v", err)
	}
	plain, err := server.decryptTransportPayloadLogical(logical, payload)
	if err != nil {
		t.Fatalf("decryptTransportPayloadLogical failed: %v", err)
	}
	items, err := decodeDedicatedBulkInboundItems(42, plain)
	if err != nil {
		t.Fatalf("decodeDedicatedBulkInboundItems failed: %v", err)
	}
	if len(items) != 1 {
		t.Fatalf("decoded items = %d, want 1", len(items))
	}
	if items[0].Type != bulkFastPayloadTypeReset {
		t.Fatalf("reset item type = %d, want %d", items[0].Type, bulkFastPayloadTypeReset)
	}
	if got, want := string(items[0].Payload), "unknown data id"; got != want {
		t.Fatalf("reset payload = %q, want %q", got, want)
	}
}

// TestRegisterAcceptedLogicalWithoutCompatClient checks that a logical-only
// peer can be registered, looked up, renamed and removed, and that
// registration binds the server owner and codec profile without creating a
// compatibility client.
func TestRegisterAcceptedLogicalWithoutCompatClient(t *testing.T) {
	server := NewServer().(*ServerCommon)
	UseLegacySecurityServer(server)
	logical := &LogicalConn{}
	logical.setID("logical-only")
	logical.setRemoteAddr(releaseP0TestAddr("127.0.0.1:28081"))

	got := server.registerAcceptedLogical(logical)
	if got != logical {
		t.Fatalf("registerAcceptedLogical returned %p, want %p", got, logical)
	}
	if logical.compatClientConn() != nil {
		t.Fatal("logical-only peer should not grow a compatibility client")
	}
	if logical.Server() != server {
		t.Fatal("logical-only peer should inherit server owner")
	}
	if logical.msgEnSnapshot() == nil || logical.msgDeSnapshot() == nil {
		t.Fatal("logical-only peer should inherit transport codec profile")
	}
	if found := server.GetLogicalConn("logical-only"); found != logical {
		t.Fatalf("GetLogicalConn returned %p, want %p", found, logical)
	}

	if err := server.renameAcceptedLogical(logical, "logical-only-renamed"); err != nil {
		t.Fatalf("renameAcceptedLogical failed: %v", err)
	}
	if found := server.GetLogicalConn("logical-only"); found != nil {
		t.Fatalf("old logical id should be removed, got %p", found)
	}
	if found := server.GetLogicalConn("logical-only-renamed"); found != logical {
		t.Fatalf("renamed logical lookup returned %p, want %p", found, logical)
	}

	server.removeLogical(logical)
	if found := server.GetLogicalConn("logical-only-renamed"); found != nil {
		t.Fatalf("removeLogical should delete logical-only peer, got %p", found)
	}
}

// TestEncodeDecodeEnvelopeLogicalWithoutCompatClient checks that an envelope
// encoded for a logical-only peer decodes back to the same Kind and ID.
func TestEncodeDecodeEnvelopeLogicalWithoutCompatClient(t *testing.T) {
	server := NewServer().(*ServerCommon)
	UseLegacySecurityServer(server)
	logical := &LogicalConn{}
	logical.setID("logical-codec")
	server.registerAcceptedLogical(logical)

	env := newSignalAckEnvelope(42)
	payload, err := server.encodeEnvelopePayloadLogical(logical, env)
	if err != nil {
		t.Fatalf("encodeEnvelopePayloadLogical failed: %v", err)
	}
	decoded, err := server.decodeEnvelopeLogical(logical, payload)
	if err != nil {
		t.Fatalf("decodeEnvelopeLogical failed: %v", err)
	}
	if got, want := decoded.Kind, env.Kind; got != want {
		t.Fatalf("decoded Kind = %v, want %v", got, want)
	}
	if got, want := decoded.ID, env.ID; got != want {
		t.Fatalf("decoded ID = %d, want %d", got, want)
	}
}

// TestAttachAcceptedLogicalTransportWithoutCompatClient checks that a transport
// can be attached to a logical-only peer over an in-memory pipe. It expects
// the peer to expose an attached transport snapshot and to build an inbound
// snapshot for that connection.
func TestAttachAcceptedLogicalTransportWithoutCompatClient(t *testing.T) {
	server := NewServer().(*ServerCommon)
	UseLegacySecurityServer(server)
	logical := &LogicalConn{}
	logical.setID("logical-transport")
	left, right := net.Pipe()
	defer left.Close()
	defer right.Close()

	if err := server.attachAcceptedLogicalTransport(logical, releaseP0TestAddr("127.0.0.1:28082"), left); err != nil {
		t.Fatalf("attachAcceptedLogicalTransport failed: %v", err)
	}
	// Stop whatever the attach started even if a later assertion fails.
	// Deferred last, this runs first: before the pipe ends are closed.
	defer func() {
		if stopFn := logical.stopFuncSnapshot(); stopFn != nil {
			stopFn()
		}
	}()

	if logical.Server() != server {
		t.Fatal("attachAcceptedLogicalTransport should bind server owner")
	}
	transport := logical.CurrentTransportConn()
	if transport == nil {
		t.Fatal("CurrentTransportConn should expose attached transport")
	}
	if !transport.Attached() || !transport.HasRuntimeConn() {
		t.Fatalf("transport snapshot mismatch: %+v", transport)
	}
	inbound := logical.transportConnSnapshotForInbound(left, nil, transport.TransportGeneration(), true)
	if inbound == nil {
		t.Fatal("transportConnSnapshotForInbound should work without compatibility client")
	}
	if !inbound.Attached() {
		t.Fatalf("inbound transport should be attached: %+v", inbound)
	}
}

// TestResolveInboundSourceValueWithoutCompatClient checks that resolving an
// inbound source for a registered logical-only peer returns that peer. It
// also expects a transport snapshot that points back to the peer and shares
// its remote address.
func TestResolveInboundSourceValueWithoutCompatClient(t *testing.T) {
	server := NewServer().(*ServerCommon)
	UseLegacySecurityServer(server)
	logical := &LogicalConn{}
	logical.setID("logical-source")
	logical.setRemoteAddr(releaseP0TestAddr("127.0.0.1:28083"))
	server.registerAcceptedLogical(logical)

	resolved, transport := server.resolveInboundSourceValue(serverInboundSource{
		Source:              logical.ID(),
		Logical:             logical,
		RemoteAddr:          logical.RemoteAddr(),
		TransportGeneration: 1,
	})
	if resolved != logical {
		t.Fatalf("resolved logical = %p, want %p", resolved, logical)
	}
	if transport == nil {
		t.Fatal("resolveInboundSourceValue should return transport snapshot for logical-only peer")
	}
	if transport.LogicalConn() != logical {
		t.Fatalf("transport logical = %p, want %p", transport.LogicalConn(), logical)
	}
	if got, want := transportConnAddrString(transport.RemoteAddr()), transportConnAddrString(logical.RemoteAddr()); got != want {
		t.Fatalf("transport remote addr = %q, want %q", got, want)
	}
}