notify/release_p0_test.go

368 lines
12 KiB
Go
Raw Permalink Normal View History

package notify
import (
"context"
"errors"
"net"
"strings"
"sync/atomic"
"testing"
"time"
)
// releaseP0TestAddr is a fixed-text net.Addr used by the tests in this file.
// Its network is always reported as "tcp".
type releaseP0TestAddr string

// Network reports "tcp" regardless of the address text.
func (addr releaseP0TestAddr) Network() string {
	return "tcp"
}

// String returns the address text unchanged.
func (addr releaseP0TestAddr) String() string {
	return string(addr)
}
// closeInspectConn is a stub net.Conn for tests. Reads always fail with
// net.ErrClosed, writes report the whole buffer as written, deadlines are
// accepted and ignored, and the first Close runs closeFn so a test can
// capture state at the moment the conn is closed.
type closeInspectConn struct {
	closeFn func()      // run at most once, by the first Close call
	closed  atomic.Bool // flipped to true by the first Close call
}

var _ net.Conn = (*closeInspectConn)(nil)

// Read never yields data; it always reports net.ErrClosed.
func (c *closeInspectConn) Read([]byte) (int, error) {
	return 0, net.ErrClosed
}

// Write discards p and claims every byte was written.
func (c *closeInspectConn) Write(p []byte) (int, error) {
	return len(p), nil
}

// LocalAddr returns the fixed address "local".
func (c *closeInspectConn) LocalAddr() net.Addr {
	return releaseP0TestAddr("local")
}

// RemoteAddr returns the fixed address "remote".
func (c *closeInspectConn) RemoteAddr() net.Addr {
	return releaseP0TestAddr("remote")
}

// SetDeadline is a no-op.
func (c *closeInspectConn) SetDeadline(time.Time) error {
	return nil
}

// SetReadDeadline is a no-op.
func (c *closeInspectConn) SetReadDeadline(time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op.
func (c *closeInspectConn) SetWriteDeadline(time.Time) error {
	return nil
}

// Close is nil-receiver safe and idempotent. Only the call that flips closed
// from false to true runs closeFn (when set). It always returns nil.
func (c *closeInspectConn) Close() error {
	switch {
	case c == nil:
		return nil
	case !c.closed.CompareAndSwap(false, true):
		return nil
	}
	if fn := c.closeFn; fn != nil {
		fn()
	}
	return nil
}
// TestGetLogicalConnRuntimeSnapshotWithoutCompatClient checks that
// GetLogicalConnRuntimeSnapshot works on a LogicalConn that was never given a
// compatibility client. It also checks that the snapshot reflects the identity,
// session, transport-attach and transport-detach state recorded on that conn.
func TestGetLogicalConnRuntimeSnapshotWithoutCompatClient(t *testing.T) {
	srv := NewServer().(*ServerCommon)
	conn := &LogicalConn{server: srv}
	conn.setID("logical-only")
	conn.setRemoteAddr(releaseP0TestAddr("127.0.0.1:28080"))

	// Drive the conn through one attach followed by a detach carrying a cause.
	conn.markSessionStarted()
	conn.markIdentityBound()
	conn.markStreamTransport()
	conn.markTransportAttached()
	conn.setClientConnLastHeartbeatUnix(time.Now().Unix())
	conn.markTransportDetached("read error", errors.New("boom"))

	snap, err := GetLogicalConnRuntimeSnapshot(conn)
	if err != nil {
		t.Fatalf("GetLogicalConnRuntimeSnapshot failed: %v", err)
	}

	textFields := []struct{ name, got, want string }{
		{"ClientID", snap.ClientID, "logical-only"},
		{"RemoteAddress", snap.RemoteAddress, "127.0.0.1:28080"},
		{"TransportDetachReason", snap.TransportDetachReason, "read error"},
		{"TransportDetachError", snap.TransportDetachError, "boom"},
	}
	for _, f := range textFields {
		if f.got != f.want {
			t.Fatalf("%s = %q, want %q", f.name, f.got, f.want)
		}
	}

	flags := []struct {
		ok  bool
		msg string
	}{
		{snap.Alive, "Alive should be true"},
		{snap.IdentityBound, "IdentityBound should be true"},
		{snap.UsesStreamTransport, "UsesStreamTransport should be true"},
		{snap.ReattachEligible, "ReattachEligible should be true"},
	}
	for _, f := range flags {
		if !f.ok {
			t.Fatal(f.msg)
		}
	}

	// A single attach is expected to have produced generation 1.
	if gen := snap.TransportGeneration; gen != uint64(1) {
		t.Fatalf("TransportGeneration = %d, want %d", gen, uint64(1))
	}
}
// TestPendingWaitClosedErrorWithTransportDetail checks that the error produced
// for a closed pending wait, built from a logical conn's transport-detached
// error, matches errTransportDetached. The message must also carry both the
// detach reason ("read error") and the detach cause ("boom").
func TestPendingWaitClosedErrorWithTransportDetail(t *testing.T) {
	conn := &LogicalConn{}
	conn.markSessionStarted()
	conn.markStreamTransport()
	conn.markTransportAttached()
	conn.markTransportDetached("read error", errors.New("boom"))

	detached := transportDetachedErrorForLogical(conn)
	err := pendingWaitClosedErrorWith(nil, detached)
	if !errors.Is(err, errTransportDetached) {
		t.Fatalf("pendingWaitClosedErrorWith = %v, want transport detached", err)
	}

	msg := err.Error()
	for _, fragment := range []string{"read error", "boom"} {
		if !strings.Contains(msg, fragment) {
			t.Fatalf("pendingWaitClosedErrorWith detail = %q, want read error and boom", msg)
		}
	}
}
func TestHandleDedicatedBulkReadErrorPreservesUnderlyingCause(t *testing.T) {
runtime := newBulkRuntime("dedicated-read-error")
bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{
BulkID: "dedicated-read-error",
DataID: 1,
Dedicated: true,
Range: BulkRange{
Length: 1,
},
}, 0, nil, nil, 0, nil, nil, nil, nil, nil)
if err := runtime.register(clientFileScope(), bulk); err != nil {
t.Fatalf("register bulk failed: %v", err)
}
handleDedicatedBulkReadError(bulk, errors.New("boom read"))
resetErr := bulk.resetErrSnapshot()
if !errors.Is(resetErr, errTransportDetached) {
t.Fatalf("resetErr = %v, want transport detached", resetErr)
}
if !strings.Contains(resetErr.Error(), "dedicated bulk read error") || !strings.Contains(resetErr.Error(), "boom read") {
t.Fatalf("resetErr detail = %q, want dedicated read detail", resetErr.Error())
}
}
// TestHandleClientDedicatedSidecarFailureMarksBulkBeforeClosingConn verifies
// that when a client dedicated sidecar fails, its connection gets closed. At
// the moment Close runs, the bulk attached to that connection must already
// carry a reset error matching errTransportDetached. That error's message must
// include "dedicated bulk read error" and the failure cause.
func TestHandleClientDedicatedSidecarFailureMarksBulkBeforeClosingConn(t *testing.T) {
	client := NewClient().(*ClientCommon)
	runtime := client.getBulkRuntime()
	if runtime == nil {
		t.Fatal("client bulk runtime should not be nil")
	}
	bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{
		BulkID:    "sidecar-order",
		DataID:    7,
		Dedicated: true,
		Range: BulkRange{
			Length: 1,
		},
	}, 0, nil, nil, 0, nil, nil, nil, nil, nil)

	// Capture the bulk's reset error from inside Close, so the assertions see
	// the state at close time rather than after the failure handler returns.
	var closeObservedErr error
	conn := &closeInspectConn{
		closeFn: func() {
			closeObservedErr = bulk.resetErrSnapshot()
		},
	}
	if err := bulk.attachDedicatedConnShared(conn); err != nil {
		t.Fatalf("attachDedicatedConnShared failed: %v", err)
	}
	if err := runtime.register(clientFileScope(), bulk); err != nil {
		t.Fatalf("register bulk failed: %v", err)
	}
	sidecar := newBulkDedicatedSidecar(conn, 1)
	client.installClientDedicatedSidecar(1, sidecar)

	client.handleClientDedicatedSidecarFailure(sidecar, errors.New("boom sidecar"))

	// Without this check, a conn that was never closed would surface below as
	// a misleading "closeObservedErr = <nil>" mismatch.
	if !conn.closed.Load() {
		t.Fatal("handleClientDedicatedSidecarFailure should close the dedicated conn")
	}
	if !errors.Is(closeObservedErr, errTransportDetached) {
		t.Fatalf("closeObservedErr = %v, want transport detached", closeObservedErr)
	}
	if !strings.Contains(closeObservedErr.Error(), "dedicated bulk read error") || !strings.Contains(closeObservedErr.Error(), "boom sidecar") {
		t.Fatalf("closeObservedErr detail = %q, want dedicated read error and cause", closeObservedErr.Error())
	}
}
// TestCleanupClientSessionResourcesMarksBulkBeforeClosingSidecar verifies that
// cleanupClientSessionResources closes an installed dedicated sidecar's
// connection. At the moment Close runs, the bulk attached to that connection
// must already carry a reset error matching errServiceShutdown.
func TestCleanupClientSessionResourcesMarksBulkBeforeClosingSidecar(t *testing.T) {
	client := NewClient().(*ClientCommon)
	runtime := client.getBulkRuntime()
	if runtime == nil {
		t.Fatal("client bulk runtime should not be nil")
	}
	bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{
		BulkID:    "cleanup-order",
		DataID:    9,
		Dedicated: true,
		Range: BulkRange{
			Length: 1,
		},
	}, 0, nil, nil, 0, nil, nil, nil, nil, nil)

	// Capture the bulk's reset error from inside Close, so the assertion sees
	// the state at close time rather than after cleanup returns.
	var closeObservedErr error
	conn := &closeInspectConn{
		closeFn: func() {
			closeObservedErr = bulk.resetErrSnapshot()
		},
	}
	if err := bulk.attachDedicatedConnShared(conn); err != nil {
		t.Fatalf("attachDedicatedConnShared failed: %v", err)
	}
	if err := runtime.register(clientFileScope(), bulk); err != nil {
		t.Fatalf("register bulk failed: %v", err)
	}
	client.installClientDedicatedSidecar(1, newBulkDedicatedSidecar(conn, 1))

	client.cleanupClientSessionResources()

	// Without this check, a conn that was never closed would surface below as
	// a misleading "closeObservedErr = <nil>" mismatch.
	if !conn.closed.Load() {
		t.Fatal("cleanupClientSessionResources should close the dedicated sidecar conn")
	}
	if !errors.Is(closeObservedErr, errServiceShutdown) {
		t.Fatalf("closeObservedErr = %v, want %v", closeObservedErr, errServiceShutdown)
	}
}
func TestBestEffortRejectInboundDedicatedDataUsesDedicatedResetRecord(t *testing.T) {
server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server)
logical := server.bootstrapAcceptedLogical("dedicated-reject", nil, nil)
if logical == nil {
t.Fatal("bootstrapAcceptedLogical should return logical")
}
conn := newBulkAttachScriptConn(nil)
server.bestEffortRejectInboundDedicatedData(logical, conn, 42, "unknown data id")
recordConn := newBulkAttachScriptConn(conn.writtenBytes())
payload, err := readBulkDedicatedRecord(recordConn)
if err != nil {
t.Fatalf("readBulkDedicatedRecord failed: %v", err)
}
plain, err := server.decryptTransportPayloadLogical(logical, payload)
if err != nil {
t.Fatalf("decryptTransportPayloadLogical failed: %v", err)
}
items, err := decodeDedicatedBulkInboundItems(42, plain)
if err != nil {
t.Fatalf("decodeDedicatedBulkInboundItems failed: %v", err)
}
if len(items) != 1 {
t.Fatalf("decoded items = %d, want 1", len(items))
}
if items[0].Type != bulkFastPayloadTypeReset {
t.Fatalf("reset item type = %d, want %d", items[0].Type, bulkFastPayloadTypeReset)
}
if got, want := string(items[0].Payload), "unknown data id"; got != want {
t.Fatalf("reset payload = %q, want %q", got, want)
}
}
func TestRegisterAcceptedLogicalWithoutCompatClient(t *testing.T) {
server := NewServer().(*ServerCommon)
UseLegacySecurityServer(server)
logical := &LogicalConn{}
logical.setID("logical-only")
logical.setRemoteAddr(releaseP0TestAddr("127.0.0.1:28081"))
got := server.registerAcceptedLogical(logical)
if got != logical {
t.Fatalf("registerAcceptedLogical returned %p, want %p", got, logical)
}
if logical.compatClientConn() != nil {
t.Fatal("logical-only peer should not grow a compatibility client")
}
if logical.Server() != server {
t.Fatal("logical-only peer should inherit server owner")
}
if logical.msgEnSnapshot() == nil || logical.msgDeSnapshot() == nil {
t.Fatal("logical-only peer should inherit transport codec profile")
}
if found := server.GetLogicalConn("logical-only"); found != logical {
t.Fatalf("GetLogicalConn returned %p, want %p", found, logical)
}
if err := server.renameAcceptedLogical(logical, "logical-only-renamed"); err != nil {
t.Fatalf("renameAcceptedLogical failed: %v", err)
}
if found := server.GetLogicalConn("logical-only"); found != nil {
t.Fatalf("old logical id should be removed, got %p", found)
}
if found := server.GetLogicalConn("logical-only-renamed"); found != logical {
t.Fatalf("renamed logical lookup returned %p, want %p", found, logical)
}
server.removeLogical(logical)
if found := server.GetLogicalConn("logical-only-renamed"); found != nil {
t.Fatalf("removeLogical should delete logical-only peer, got %p", found)
}
}
// TestEncodeDecodeEnvelopeLogicalWithoutCompatClient round-trips a signal-ack
// envelope through the server's logical-conn encoder and decoder for a peer
// with no compatibility client. It checks that Kind and ID survive the round trip.
func TestEncodeDecodeEnvelopeLogicalWithoutCompatClient(t *testing.T) {
	srv := NewServer().(*ServerCommon)
	UseLegacySecurityServer(srv)
	peer := &LogicalConn{}
	peer.setID("logical-codec")
	srv.registerAcceptedLogical(peer)

	want := newSignalAckEnvelope(42)
	wire, err := srv.encodeEnvelopePayloadLogical(peer, want)
	if err != nil {
		t.Fatalf("encodeEnvelopePayloadLogical failed: %v", err)
	}
	got, err := srv.decodeEnvelopeLogical(peer, wire)
	if err != nil {
		t.Fatalf("decodeEnvelopeLogical failed: %v", err)
	}

	if gotKind, wantKind := got.Kind, want.Kind; gotKind != wantKind {
		t.Fatalf("decoded Kind = %v, want %v", gotKind, wantKind)
	}
	if gotID, wantID := got.ID, want.ID; gotID != wantID {
		t.Fatalf("decoded ID = %d, want %d", gotID, wantID)
	}
}
// TestAttachAcceptedLogicalTransportWithoutCompatClient attaches one end of an
// in-memory pipe to a logical-only peer. It checks that the server owner is
// bound, that the current transport snapshot is attached and holds a runtime
// conn, and that an inbound transport snapshot can be built for the same
// generation without a compatibility client.
func TestAttachAcceptedLogicalTransportWithoutCompatClient(t *testing.T) {
	server := NewServer().(*ServerCommon)
	UseLegacySecurityServer(server)
	logical := &LogicalConn{}
	logical.setID("logical-transport")
	left, right := net.Pipe()
	defer left.Close()
	defer right.Close()
	// Run the logical conn's stop function (if one is installed) on every exit
	// path, including t.Fatal failures below. Previously it only ran when all
	// assertions passed. Deferred calls run LIFO, so this still executes before
	// the pipe ends are closed, matching the original ordering.
	defer func() {
		if stopFn := logical.stopFuncSnapshot(); stopFn != nil {
			stopFn()
		}
	}()

	if err := server.attachAcceptedLogicalTransport(logical, releaseP0TestAddr("127.0.0.1:28082"), left); err != nil {
		t.Fatalf("attachAcceptedLogicalTransport failed: %v", err)
	}
	if logical.Server() != server {
		t.Fatal("attachAcceptedLogicalTransport should bind server owner")
	}
	transport := logical.CurrentTransportConn()
	if transport == nil {
		t.Fatal("CurrentTransportConn should expose attached transport")
	}
	if !transport.Attached() || !transport.HasRuntimeConn() {
		t.Fatalf("transport snapshot mismatch: %+v", transport)
	}
	inbound := logical.transportConnSnapshotForInbound(left, nil, transport.TransportGeneration(), true)
	if inbound == nil {
		t.Fatal("transportConnSnapshotForInbound should work without compatibility client")
	}
	if !inbound.Attached() {
		t.Fatalf("inbound transport should be attached: %+v", inbound)
	}
}
// TestResolveInboundSourceValueWithoutCompatClient checks that resolving an
// inbound source for a registered logical-only peer returns that peer. It also
// checks that a transport snapshot comes back which points at the peer and
// reports the peer's remote address.
func TestResolveInboundSourceValueWithoutCompatClient(t *testing.T) {
	srv := NewServer().(*ServerCommon)
	UseLegacySecurityServer(srv)
	peer := &LogicalConn{}
	peer.setID("logical-source")
	peer.setRemoteAddr(releaseP0TestAddr("127.0.0.1:28083"))
	srv.registerAcceptedLogical(peer)

	src := serverInboundSource{
		Source:              peer.ID(),
		Logical:             peer,
		RemoteAddr:          peer.RemoteAddr(),
		TransportGeneration: 1,
	}
	resolved, transport := srv.resolveInboundSourceValue(src)

	if resolved != peer {
		t.Fatalf("resolved logical = %p, want %p", resolved, peer)
	}
	if transport == nil {
		t.Fatal("resolveInboundSourceValue should return transport snapshot for logical-only peer")
	}
	if owner := transport.LogicalConn(); owner != peer {
		t.Fatalf("transport logical = %p, want %p", owner, peer)
	}

	gotAddr := transportConnAddrString(transport.RemoteAddr())
	wantAddr := transportConnAddrString(peer.RemoteAddr())
	if gotAddr != wantAddr {
		t.Fatalf("transport remote addr = %q, want %q", gotAddr, wantAddr)
	}
}