package notify

import (
	"context"
	"errors"
	"net"
	"sync/atomic"
	"time"
)

// LogicalConn is the logical, session-level view of a client connection.
//
// A LogicalConn is paired with a legacy *ClientConn (see bindLegacyClient).
// Once bound, both hold the same logical, runtime, transport and attachment
// state pointers (whichever of them exist at bind time), so state reached
// through those pointers is shared by the two views. ClientID and ClientAddr
// are compatibility mirrors of the peer identity kept in the logical state.
type LogicalConn struct {
	client *ClientConn
	server Server

	// ClientID mirrors the peer ID for callers that read the field directly.
	ClientID string
	// ClientAddr mirrors the peer remote address for callers that read the
	// field directly.
	ClientAddr net.Addr

	state          atomic.Pointer[logicalConnState]
	runtime        atomic.Pointer[logicalConnRuntimeState]
	transportState atomic.Pointer[clientConnTransportState]
	attachment     atomic.Pointer[clientConnAttachmentState]
}

// errLogicalConnClientNil is returned by transport operations invoked on a
// nil *LogicalConn.
var errLogicalConnClientNil = errors.New("logical conn is nil")

// logicalConnFromClient returns the LogicalConn bound to client, creating and
// publishing one in client.logicalView on first use. It returns nil for a nil
// client.
func logicalConnFromClient(client *ClientConn) *LogicalConn {
	if client == nil {
		return nil
	}
	if logical := client.logicalView.Load(); logical != nil {
		return logical.bindLegacyClient(client)
	}
	logical := (&LogicalConn{}).attachLegacyClient(client)
	if client.logicalView.CompareAndSwap(nil, logical) {
		return logical
	}
	// The CAS fails whenever the view is already published: normally by
	// bindLegacyClient inside attachLegacyClient above, or by a concurrent
	// caller that won the race. Either way, converge on the stored view.
	logical = client.logicalView.Load()
	return logical.bindLegacyClient(client)
}

// newServerLogicalConn creates a LogicalConn together with a fresh legacy
// ClientConn whose server field is set to server. A non-empty id and a non-nil
// addr are recorded as the peer's ID and remote address.
func newServerLogicalConn(server Server, id string, addr net.Addr) *LogicalConn {
	client := &ClientConn{
		server: server,
	}
	logical := (&LogicalConn{
		client: client,
		server: server,
	}).attachLegacyClient(client)
	client.logicalView.Store(logical)
	if id != "" {
		logical.setID(id)
	}
	if addr != nil {
		logical.setRemoteAddr(addr)
	}
	return logical
}

// attachLegacyClient binds client to c (see bindLegacyClient). It then
// refreshes the ClientID/ClientAddr compatibility fields from the shared
// logical state if one exists, or from client's own fields otherwise.
// It returns nil when c is nil.
func (c *LogicalConn) attachLegacyClient(client *ClientConn) *LogicalConn {
	c = c.bindLegacyClient(client)
	if c == nil {
		return nil
	}
	if state := c.state.Load(); state != nil {
		c.syncCompatibilityFieldsFromState(state)
	} else {
		c.syncCompatibilityFieldsFromClient(client)
	}
	return c
}

// bindLegacyClient links c with client so that the two views share state:
//   - c.client and c.server are filled from client when unset;
//   - every state pointer client already holds is adopted by c if c has none;
//   - c's (possibly just adopted) pointers are written back to client;
//   - c is published as client's logical view if no view is set yet.
//
// It is a no-op returning c when either c or client is nil.
func (c *LogicalConn) bindLegacyClient(client *ClientConn) *LogicalConn {
	if c == nil || client == nil {
		return c
	}
	if c.client == nil {
		c.client = client
	}
	if c.server == nil {
		c.server = client.server
	}

	// Adopt the client's state where c has none yet.
	if state := client.logicalState.Load(); state != nil {
		c.state.CompareAndSwap(nil, state)
	}
	if runtime := client.runtimeState.Load(); runtime != nil {
		c.runtime.CompareAndSwap(nil, runtime)
	}
	if transportState := client.transportState.Load(); transportState != nil {
		c.transportState.CompareAndSwap(nil, transportState)
	}
	if attachment := client.attachment.Load(); attachment != nil {
		c.attachment.CompareAndSwap(nil, attachment)
	}

	// Push c's state back so both views reference the same objects.
	if state := c.state.Load(); state != nil {
		client.logicalState.Store(state)
	}
	if runtime := c.runtime.Load(); runtime != nil {
		client.runtimeState.Store(runtime)
	}
	if transportState := c.transportState.Load(); transportState != nil {
		client.transportState.Store(transportState)
	}
	if attachment := c.attachment.Load(); attachment != nil {
		client.attachment.Store(attachment)
	}

	client.logicalView.CompareAndSwap(nil, c)
	return c
}

// clientConnFromLogical returns the legacy ClientConn behind logical, or nil.
func clientConnFromLogical(logical *LogicalConn) *ClientConn {
	if logical == nil {
		return nil
	}
	return logical.client
}

// logicalConnFromPeer resolves a peer value to its LogicalConn. It accepts a
// *LogicalConn (returned as is) or a *ClientConn (converted through
// logicalConnFromClient); any other value, including nil, yields nil.
func logicalConnFromPeer(peer any) *LogicalConn {
	switch data := peer.(type) {
	case nil:
		return nil
	case *LogicalConn:
		return data
	case *ClientConn:
		return logicalConnFromClient(data)
	default:
		return nil
	}
}

// LogicalConn returns the logical connection view bound to c, creating it on
// first use. It returns nil for a nil c.
func (c *ClientConn) LogicalConn() *LogicalConn {
	return logicalConnFromClient(c)
}

// compatClientConn returns the legacy ClientConn paired with c, or nil.
func (c *LogicalConn) compatClientConn() *ClientConn {
	if c == nil {
		return nil
	}
	return c.client
}

// logicalStateSnapshot returns c's logical state. When none is stored yet it
// defers to ensureState (presumably lazy creation; defined elsewhere).
// It returns nil for a nil c.
func (c *LogicalConn) logicalStateSnapshot() *logicalConnState {
	if c == nil {
		return nil
	}
	if state := c.state.Load(); state != nil {
		return state
	}
	return c.ensureState()
}

// logicalRuntimeStateSnapshot returns c's runtime state. When none is stored
// yet it defers to ensureRuntimeState (defined elsewhere). It returns nil for
// a nil c.
func (c *LogicalConn) logicalRuntimeStateSnapshot() *logicalConnRuntimeState {
	if c == nil {
		return nil
	}
	if state := c.runtime.Load(); state != nil {
		return state
	}
	return c.ensureRuntimeState()
}

// ID returns the peer's client ID, or "" for a nil c.
func (c *LogicalConn) ID() string {
	return c.clientIDSnapshot()
}

// RemoteAddr returns the peer's remote address, or nil for a nil c.
func (c *LogicalConn) RemoteAddr() net.Addr {
	return c.clientRemoteAddrSnapshot()
}

// GetRemoteAddr is an alias of RemoteAddr kept for getter-style callers.
func (c *LogicalConn) GetRemoteAddr() net.Addr {
	return c.RemoteAddr()
}

// Status reports the session status recorded in the logical state. It
// returns the zero Status when no state is available (including a nil c).
func (c *LogicalConn) Status() Status {
	state := c.logicalStateSnapshot()
	if state != nil {
		return state.statusSnapshot()
	}
	return Status{}
}

// Server returns the Server associated with c. If c has no server of its
// own, it falls back to the legacy ClientConn's server. It returns nil for a
// nil c.
func (c *LogicalConn) Server() Server {
	if c == nil {
		return nil
	}
	if c.server != nil {
		return c.server
	}
	client := c.compatClientConn()
	if client == nil {
		return nil
	}
	return client.server
}

// setServer records server on c and on its legacy ClientConn. A nil c or a
// nil server makes this a no-op.
func (c *LogicalConn) setServer(server Server) {
	if c == nil || server == nil {
		return
	}
	c.server = server
	if client := c.compatClientConn(); client != nil {
		client.server = server
	}
}

// syncCompatibilityFieldsFromClient copies ClientID/ClientAddr from client.
// It also adopts client's server when c has none.
func (c *LogicalConn) syncCompatibilityFieldsFromClient(client *ClientConn) {
	if c == nil || client == nil {
		return
	}
	c.ClientID = client.ClientID
	c.ClientAddr = client.ClientAddr
	if c.server == nil {
		c.server = client.server
	}
}

// syncCompatibilityFieldsFromState refreshes ClientID/ClientAddr from the peer
// record in state. When state is nil it falls back to the legacy client's
// fields.
func (c *LogicalConn) syncCompatibilityFieldsFromState(state *logicalConnState) {
	if c == nil {
		return
	}
	if state == nil {
		c.syncCompatibilityFieldsFromClient(c.compatClientConn())
		return
	}
	peer := state.peerSnapshot()
	c.ClientID = peer.clientID
	c.ClientAddr = peer.clientAddr
}

// markSessionStarted marks the logical session as started. It then mirrors
// the updated state onto the legacy ClientConn.
func (c *LogicalConn) markSessionStarted() {
	state := c.logicalStateSnapshot()
	if state == nil {
		return
	}
	state.markStarted()
	if client := c.compatClientConn(); client != nil {
		client.syncLegacyLogicalFieldsFromState(state)
	}
}

// markSessionStopped records that the session stopped with reason and err.
// It passes along the current session stop func (how markStopped uses it is
// defined elsewhere), then mirrors the state onto the legacy ClientConn.
func (c *LogicalConn) markSessionStopped(reason string, err error) {
	state := c.logicalStateSnapshot()
	if state == nil {
		return
	}
	state.markStopped(reason, err, c.stopFuncSnapshot())
	if client := c.compatClientConn(); client != nil {
		client.syncLegacyLogicalFieldsFromState(state)
	}
}

// rsaDecode forwards message to the legacy ClientConn's rsaDecode, if a
// legacy client exists.
func (c *LogicalConn) rsaDecode(message Message) {
	if client := c.compatClientConn(); client != nil {
		client.rsaDecode(message)
	}
}

// sayGoodByeForTU delegates to the legacy ClientConn. It returns
// errTransportDetached when no legacy client is bound.
func (c *LogicalConn) sayGoodByeForTU() error {
	if client := c.compatClientConn(); client != nil {
		return client.sayGoodByeForTU()
	}
	return errTransportDetached
}

// setID records id as the peer's client ID in the logical state. It then
// refreshes the compatibility fields on c and on the legacy ClientConn. If no
// logical state can be obtained, the fields are written directly.
func (c *LogicalConn) setID(id string) {
	if c == nil {
		return
	}
	state := c.ensureState()
	if state == nil {
		c.ClientID = id
		if client := c.compatClientConn(); client != nil {
			client.ClientID = id
		}
		return
	}
	state.updatePeer(func(peer *logicalConnPeerState) {
		peer.clientID = id
	})
	c.syncCompatibilityFieldsFromState(state)
	if client := c.compatClientConn(); client != nil {
		client.syncLegacyLogicalFieldsFromState(state)
	}
}

// clientIDSnapshot returns the peer ID from the logical state, falling back to
// the ClientID field. It returns "" for a nil c instead of dereferencing it.
func (c *LogicalConn) clientIDSnapshot() string {
	if c == nil {
		return ""
	}
	state := c.logicalStateSnapshot()
	if state == nil {
		return c.ClientID
	}
	peer := state.peerSnapshot()
	return peer.clientID
}

// clientRemoteAddrSnapshot returns the peer address from the logical state,
// falling back to the ClientAddr field. It returns nil for a nil c instead of
// dereferencing it.
func (c *LogicalConn) clientRemoteAddrSnapshot() net.Addr {
	if c == nil {
		return nil
	}
	state := c.logicalStateSnapshot()
	if state == nil {
		return c.ClientAddr
	}
	peer := state.peerSnapshot()
	return peer.clientAddr
}

// setRemoteAddr records addr as the peer's remote address, mirroring setID.
// A nil c or a nil addr makes this a no-op.
func (c *LogicalConn) setRemoteAddr(addr net.Addr) {
	if c == nil || addr == nil {
		return
	}
	state := c.ensureState()
	if state == nil {
		c.ClientAddr = addr
		if client := c.compatClientConn(); client != nil {
			client.ClientAddr = addr
		}
		return
	}
	state.updatePeer(func(peer *logicalConnPeerState) {
		peer.clientAddr = addr
	})
	c.syncCompatibilityFieldsFromState(state)
	if client := c.compatClientConn(); client != nil {
		client.syncLegacyLogicalFieldsFromState(state)
	}
}

// transportGenerationSnapshot returns the current transport generation. The
// generation is incremented by markTransportAttached; 0 means no state.
func (c *LogicalConn) transportGenerationSnapshot() uint64 {
	state := c.ensureTransportState()
	if state == nil {
		return 0
	}
	return state.transportGen.Load()
}

// lastHeartbeatUnixSnapshot returns the last heartbeat as Unix seconds, as
// written by markHeartbeatNow. 0 is treated as "never" by runtimeSnapshot.
func (c *LogicalConn) lastHeartbeatUnixSnapshot() int64 {
	return c.attachmentStateSnapshot().lastHeartBeat
}

// transportAttachedSnapshot reports whether the session runtime currently
// marks a transport as attached.
func (c *LogicalConn) transportAttachedSnapshot() bool {
	rt := c.sessionRuntimeSnapshot()
	if rt == nil {
		return false
	}
	return rt.transportAttached
}

// usesStreamTransportSnapshot reports whether markStreamTransport was called.
func (c *LogicalConn) usesStreamTransportSnapshot() bool {
	state := c.ensureTransportState()
	if state == nil {
		return false
	}
	return state.streamTransport.Load()
}

// logicalTransportDetachedSnapshot reports whether c is logically detached.
// That means: identity-bound, using a stream transport, still alive, and
// with no transport attached.
func (c *LogicalConn) logicalTransportDetachedSnapshot() bool {
	if c == nil {
		return false
	}
	if !c.clientConnIdentityBoundSnapshot() || !c.usesStreamTransportSnapshot() {
		return false
	}
	if !c.clientConnAliveSnapshot() {
		return false
	}
	return !c.transportAttachedSnapshot()
}

// shouldPreserveLogicalPeerOnTransportLoss reports whether the peer is both
// identity-bound and using a stream transport.
func (c *LogicalConn) shouldPreserveLogicalPeerOnTransportLoss() bool {
	return c.clientConnIdentityBoundSnapshot() && c.usesStreamTransportSnapshot()
}

// markIdentityBound flags the peer as identity-bound. It then mirrors the
// state onto the legacy ClientConn.
func (c *LogicalConn) markIdentityBound() {
	state := c.logicalStateSnapshot()
	if state == nil {
		return
	}
	state.updatePeer(func(peer *logicalConnPeerState) {
		peer.identityBound = true
	})
	if client := c.compatClientConn(); client != nil {
		client.syncLegacyLogicalFieldsFromState(state)
	}
}

// markHeartbeatNow records the current wall-clock time (Unix seconds) as the
// last heartbeat.
func (c *LogicalConn) markHeartbeatNow() {
	c.setClientConnLastHeartbeatUnix(time.Now().Unix())
}

// markStreamTransport flags that this connection uses a stream transport.
func (c *LogicalConn) markStreamTransport() {
	state := c.ensureTransportState()
	if state == nil {
		return
	}
	state.streamTransport.Store(true)
}

// markTransportAttached starts a new transport generation. It bumps the
// generation and the attach count and records the attach time (UnixNano).
// It returns the new generation, or 0 when no transport state is available.
func (c *LogicalConn) markTransportAttached() uint64 {
	state := c.ensureTransportState()
	if state == nil {
		return 0
	}
	gen := state.transportGen.Add(1)
	state.attachCount.Add(1)
	state.lastAttachAt.Store(time.Now().UnixNano())
	return gen
}

// applyAttachmentProfile replaces every attachment setting, including the
// fast encoders. The key slices go through cloneClientConnAttachmentBytes
// (presumably defensive copies).
func (c *LogicalConn) applyAttachmentProfile(maxReadTimeout time.Duration, maxWriteTimeout time.Duration, msgEn func([]byte, []byte) []byte, msgDe func([]byte, []byte) []byte, fastStreamEncode transportFastStreamEncoder, fastBulkEncode transportFastBulkEncoder, fastPlainEncode transportFastPlainEncoder, handshakeRsaKey []byte, secretKey []byte) {
	c.updateAttachmentState(func(state *clientConnAttachmentState) {
		state.maxReadTimeout = maxReadTimeout
		state.maxWriteTimeout = maxWriteTimeout
		state.msgEn = msgEn
		state.msgDe = msgDe
		state.fastStreamEncode = fastStreamEncode
		state.fastBulkEncode = fastBulkEncode
		state.fastPlainEncode = fastPlainEncode
		state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey)
		state.secretKey = cloneClientConnAttachmentBytes(secretKey)
	})
}

// msgEnSnapshot returns the configured message encoder.
func (c *LogicalConn) msgEnSnapshot() func([]byte, []byte) []byte {
	return c.attachmentStateSnapshot().msgEn
}

// msgDeSnapshot returns the configured message decoder.
func (c *LogicalConn) msgDeSnapshot() func([]byte, []byte) []byte {
	return c.attachmentStateSnapshot().msgDe
}

// secretKeySnapshot returns the configured secret key.
func (c *LogicalConn) secretKeySnapshot() []byte {
	return c.attachmentStateSnapshot().secretKey
}

// fastStreamEncodeSnapshot returns the configured fast stream encoder.
func (c *LogicalConn) fastStreamEncodeSnapshot() transportFastStreamEncoder {
	return c.attachmentStateSnapshot().fastStreamEncode
}

// fastBulkEncodeSnapshot returns the configured fast bulk encoder.
func (c *LogicalConn) fastBulkEncodeSnapshot() transportFastBulkEncoder {
	return c.attachmentStateSnapshot().fastBulkEncode
}

// fastPlainEncodeSnapshot returns the configured fast plain encoder.
func (c *LogicalConn) fastPlainEncodeSnapshot() transportFastPlainEncoder {
	return c.attachmentStateSnapshot().fastPlainEncode
}

// inheritAttachmentProfile installs src's attachment state on c. A nil c or
// a nil src makes this a no-op.
func (c *LogicalConn) inheritAttachmentProfile(src *LogicalConn) {
	if c == nil || src == nil {
		return
	}
	c.setAttachmentState(src.attachmentStateSnapshot())
}

// transportBindingSnapshot returns the session's transport binding. If only
// a raw tuConn is recorded, it builds a binding on the fly; that binding is
// not stored.
func (c *LogicalConn) transportBindingSnapshot() *transportBinding {
	rt := c.sessionRuntimeSnapshot()
	if rt == nil {
		return nil
	}
	if rt.transport != nil {
		return rt.transport
	}
	if rt.tuConn == nil {
		return nil
	}
	return newTransportBinding(rt.tuConn, nil)
}

// maxWriteTimeoutSnapshot returns the configured maximum write timeout.
func (c *LogicalConn) maxWriteTimeoutSnapshot() time.Duration {
	return c.attachmentStateSnapshot().maxWriteTimeout
}

// transportDetachSnapshot returns the last detach record, passed through
// cloneClientConnTransportDetachState (presumably a copy). It returns nil when
// there is no record.
func (c *LogicalConn) transportDetachSnapshot() *clientConnTransportDetachState {
	state := c.ensureTransportState()
	if state == nil {
		return nil
	}
	return cloneClientConnTransportDetachState(state.transportDetach.Load())
}

// markTransportDetached records a detach event for the current generation.
// The record holds reason, the current time and err's text, and the detach
// count is incremented.
func (c *LogicalConn) markTransportDetached(reason string, err error) {
	state := c.ensureTransportState()
	if state == nil {
		return
	}
	detachState := &clientConnTransportDetachState{
		Generation: c.transportGenerationSnapshot(),
		Reason:     reason,
		At:         time.Now(),
	}
	if err != nil {
		detachState.Err = err.Error()
	}
	state.detachCount.Add(1)
	c.setTransportDetachState(detachState)
}

// clearTransportDetachState drops the current detach record.
func (c *LogicalConn) clearTransportDetachState() {
	c.setTransportDetachState(nil)
}

// transportDetachExpiredSnapshot reports whether the detach keep window has
// elapsed at now. It requires c to be logically detached and an expiry to be
// known.
func (c *LogicalConn) transportDetachExpiredSnapshot(now time.Time) bool {
	if !c.logicalTransportDetachedSnapshot() {
		return false
	}
	expiry, ok := c.clientConnTransportDetachExpirySnapshot()
	if !ok {
		return false
	}
	return !now.Before(expiry)
}

// reattachEligibleSnapshot reports whether a new transport may be attached at
// now: c is logically detached, alive, not attached, and not expired. The
// alive/attached checks are already implied by logicalTransportDetachedSnapshot
// and are kept for explicitness.
func (c *LogicalConn) reattachEligibleSnapshot(now time.Time) bool {
	if !c.logicalTransportDetachedSnapshot() {
		return false
	}
	if !c.clientConnAliveSnapshot() {
		return false
	}
	if c.transportAttachedSnapshot() {
		return false
	}
	if c.transportDetachExpiredSnapshot(now) {
		return false
	}
	return true
}

// runtimeSnapshot assembles an exported diagnostic view of c's session,
// transport and detach state. All time-dependent fields use a single "now".
func (c *LogicalConn) runtimeSnapshot() ClientConnRuntimeSnapshot {
	if c == nil {
		return ClientConnRuntimeSnapshot{}
	}
	status := c.Status()
	now := time.Now()
	snapshot := ClientConnRuntimeSnapshot{
		ClientID:              c.clientIDSnapshot(),
		Alive:                 status.Alive,
		Reason:                status.Reason,
		IdentityBound:         c.clientConnIdentityBoundSnapshot(),
		UsesStreamTransport:   c.usesStreamTransportSnapshot(),
		TransportGeneration:   c.transportGenerationSnapshot(),
		TransportAttachCount:  c.clientConnTransportAttachCountSnapshot(),
		TransportDetachCount:  c.clientConnTransportDetachCountSnapshot(),
		LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(),
	}
	if status.Err != nil {
		snapshot.Error = status.Err.Error()
	}
	if addr := c.RemoteAddr(); addr != nil {
		snapshot.RemoteAddress = addr.String()
	}
	if lastHeartbeat := c.lastHeartbeatUnixSnapshot(); lastHeartbeat != 0 {
		snapshot.LastHeartbeatAt = time.Unix(lastHeartbeat, 0)
	}
	if server := c.Server(); server != nil {
		snapshot.DetachedClientKeepSec = server.DetachedClientKeepSec()
	}
	if rt := c.sessionRuntimeSnapshot(); rt != nil {
		snapshot.TransportAttached = c.transportAttachedSnapshot()
		snapshot.HasRuntimeConn = c.transportSnapshot() != nil
		snapshot.HasRuntimeStopCtx = rt.stopCtx != nil
	}
	if detach := c.transportDetachSnapshot(); detach != nil {
		snapshot.TransportDetachReason = detach.Reason
		snapshot.TransportDetachKind = classifyClientConnTransportDetachReason(detach.Reason)
		snapshot.TransportDetachGeneration = c.clientConnTransportDetachGenerationSnapshot()
		snapshot.TransportDetachError = detach.Err
		snapshot.TransportDetachedAt = detach.At
		snapshot.TransportDetachExpiry, snapshot.TransportDetachHasExpiry = c.clientConnTransportDetachExpirySnapshot()
		snapshot.TransportDetachRemaining = c.clientConnTransportDetachRemainingSnapshot(now)
		snapshot.TransportDetachExpired = c.clientConnTransportDetachExpiredSnapshot(now)
	}
	snapshot.ReattachEligible = c.clientConnReattachEligibleSnapshot(now)
	return snapshot
}

// clientConnLogicalPeerStateSnapshot returns the peer record from the logical
// state. When no state exists it builds one from the compatibility fields.
// A nil c yields the zero record instead of dereferencing c.
func (c *LogicalConn) clientConnLogicalPeerStateSnapshot() logicalConnPeerState {
	if c == nil {
		return logicalConnPeerState{}
	}
	state := c.logicalStateSnapshot()
	if state == nil {
		return logicalConnPeerState{
			clientID:   c.ClientID,
			clientAddr: c.ClientAddr,
		}
	}
	return state.peerSnapshot()
}

// The clientConn*-named accessors below use the same naming as the
// *ClientConn accessors (e.g. clientConnAttachmentStateSnapshot). Most simply
// delegate to the corresponding LogicalConn method.

// clientConnIDSnapshot returns the peer ID (see clientIDSnapshot).
func (c *LogicalConn) clientConnIDSnapshot() string {
	return c.clientIDSnapshot()
}

// clientConnRemoteAddrSnapshot returns the peer address (see
// clientRemoteAddrSnapshot).
func (c *LogicalConn) clientConnRemoteAddrSnapshot() net.Addr {
	return c.clientRemoteAddrSnapshot()
}

// clientConnAliveSnapshot reports whether the logical session is alive;
// false when no logical state exists.
func (c *LogicalConn) clientConnAliveSnapshot() bool {
	state := c.logicalStateSnapshot()
	if state == nil {
		return false
	}
	return state.aliveSnapshot()
}

// clientConnStatusSnapshot returns Status().
func (c *LogicalConn) clientConnStatusSnapshot() Status {
	return c.Status()
}

// clientConnIdentityBoundSnapshot reports whether the peer is identity-bound.
func (c *LogicalConn) clientConnIdentityBoundSnapshot() bool {
	return c.clientConnLogicalPeerStateSnapshot().identityBound
}

// clientConnUsesStreamTransportSnapshot delegates to usesStreamTransportSnapshot.
func (c *LogicalConn) clientConnUsesStreamTransportSnapshot() bool {
	return c.usesStreamTransportSnapshot()
}

// clientConnTransportGenerationSnapshot delegates to transportGenerationSnapshot.
func (c *LogicalConn) clientConnTransportGenerationSnapshot() uint64 {
	return c.transportGenerationSnapshot()
}

// clientConnTransportAttachCountSnapshot returns how many times a transport
// has been attached.
func (c *LogicalConn) clientConnTransportAttachCountSnapshot() uint64 {
	state := c.ensureTransportState()
	if state == nil {
		return 0
	}
	return state.attachCount.Load()
}

// clientConnTransportDetachCountSnapshot returns how many detach events have
// been recorded.
func (c *LogicalConn) clientConnTransportDetachCountSnapshot() uint64 {
	state := c.ensureTransportState()
	if state == nil {
		return 0
	}
	return state.detachCount.Load()
}

// clientConnTransportSnapshot delegates to transportSnapshot.
func (c *LogicalConn) clientConnTransportSnapshot() net.Conn {
	return c.transportSnapshot()
}

// clientConnTransportBindingSnapshot delegates to transportBindingSnapshot.
func (c *LogicalConn) clientConnTransportBindingSnapshot() *transportBinding {
	return c.transportBindingSnapshot()
}

// clientConnSessionRuntimeSnapshot delegates to sessionRuntimeSnapshot.
func (c *LogicalConn) clientConnSessionRuntimeSnapshot() *clientConnSessionRuntime {
	return c.sessionRuntimeSnapshot()
}

// clientConnStopContextSnapshot delegates to stopContextSnapshot.
func (c *LogicalConn) clientConnStopContextSnapshot() context.Context {
	return c.stopContextSnapshot()
}

// clientConnStopFuncSnapshot delegates to stopFuncSnapshot.
func (c *LogicalConn) clientConnStopFuncSnapshot() context.CancelFunc {
	return c.stopFuncSnapshot()
}

// clientConnTransportStopContextSnapshot delegates to
// transportStopContextSnapshot.
func (c *LogicalConn) clientConnTransportStopContextSnapshot() context.Context {
	return c.transportStopContextSnapshot()
}

// clientConnTransportAttachedSnapshot delegates to transportAttachedSnapshot.
func (c *LogicalConn) clientConnTransportAttachedSnapshot() bool {
	return c.transportAttachedSnapshot()
}

// clientConnLogicalTransportDetachedSnapshot delegates to
// logicalTransportDetachedSnapshot.
func (c *LogicalConn) clientConnLogicalTransportDetachedSnapshot() bool {
	return c.logicalTransportDetachedSnapshot()
}

// clientConnLastTransportAttachedAtSnapshot returns the last attach time, or
// the zero time when no attach was recorded.
func (c *LogicalConn) clientConnLastTransportAttachedAtSnapshot() time.Time {
	state := c.ensureTransportState()
	if state == nil {
		return time.Time{}
	}
	unixNano := state.lastAttachAt.Load()
	if unixNano == 0 {
		return time.Time{}
	}
	return time.Unix(0, unixNano)
}

// clientConnTransportDetachSnapshot delegates to transportDetachSnapshot.
func (c *LogicalConn) clientConnTransportDetachSnapshot() *clientConnTransportDetachState {
	return c.transportDetachSnapshot()
}

// clientConnTransportDetachKindSnapshot classifies the last detach reason,
// or returns "" when there is no detach record.
func (c *LogicalConn) clientConnTransportDetachKindSnapshot() string {
	detach := c.transportDetachSnapshot()
	if detach == nil {
		return ""
	}
	return classifyClientConnTransportDetachReason(detach.Reason)
}

// clientConnTransportDetachGenerationSnapshot returns the generation stored
// in the detach record. It falls back to the current generation when the
// record carries 0, and returns 0 without a record.
func (c *LogicalConn) clientConnTransportDetachGenerationSnapshot() uint64 {
	detach := c.transportDetachSnapshot()
	if detach == nil {
		return 0
	}
	if detach.Generation == 0 {
		return c.transportGenerationSnapshot()
	}
	return detach.Generation
}

// clientConnTransportDetachExpirySnapshot returns the detach time plus the
// server's DetachedClientKeepSec. ok is false when there is no usable detach
// time, no server, or a non-positive keep window.
func (c *LogicalConn) clientConnTransportDetachExpirySnapshot() (time.Time, bool) {
	detach := c.transportDetachSnapshot()
	if detach == nil || detach.At.IsZero() {
		return time.Time{}, false
	}
	server := c.Server()
	if server == nil {
		return time.Time{}, false
	}
	keepSec := server.DetachedClientKeepSec()
	if keepSec <= 0 {
		return time.Time{}, false
	}
	return detach.At.Add(time.Duration(keepSec) * time.Second), true
}

// clientConnTransportDetachExpiredSnapshot delegates to
// transportDetachExpiredSnapshot.
func (c *LogicalConn) clientConnTransportDetachExpiredSnapshot(now time.Time) bool {
	return c.transportDetachExpiredSnapshot(now)
}

// clientConnTransportDetachRemainingSnapshot returns the time left before the
// detach window expires. It returns 0 when c is not logically detached, no
// expiry is known, or the window has already passed.
func (c *LogicalConn) clientConnTransportDetachRemainingSnapshot(now time.Time) time.Duration {
	if !c.clientConnLogicalTransportDetachedSnapshot() {
		return 0
	}
	expiry, ok := c.clientConnTransportDetachExpirySnapshot()
	if !ok || !now.Before(expiry) {
		return 0
	}
	return expiry.Sub(now)
}

// clientConnReattachEligibleSnapshot delegates to reattachEligibleSnapshot.
func (c *LogicalConn) clientConnReattachEligibleSnapshot(now time.Time) bool {
	return c.reattachEligibleSnapshot(now)
}

// clientConnRuntimeSnapshot delegates to runtimeSnapshot.
func (c *LogicalConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot {
	return c.runtimeSnapshot()
}

// clientConnAttachmentStateSnapshot delegates to attachmentStateSnapshot.
func (c *LogicalConn) clientConnAttachmentStateSnapshot() *clientConnAttachmentState {
	return c.attachmentStateSnapshot()
}

// clientConnMaxReadTimeoutSnapshot returns the configured maximum read timeout.
func (c *LogicalConn) clientConnMaxReadTimeoutSnapshot() time.Duration {
	return c.attachmentStateSnapshot().maxReadTimeout
}

// clientConnMaxWriteTimeoutSnapshot delegates to maxWriteTimeoutSnapshot.
func (c *LogicalConn) clientConnMaxWriteTimeoutSnapshot() time.Duration {
	return c.maxWriteTimeoutSnapshot()
}

// clientConnMsgEnSnapshot delegates to msgEnSnapshot.
func (c *LogicalConn) clientConnMsgEnSnapshot() func([]byte, []byte) []byte {
	return c.msgEnSnapshot()
}

// clientConnMsgDeSnapshot delegates to msgDeSnapshot.
func (c *LogicalConn) clientConnMsgDeSnapshot() func([]byte, []byte) []byte {
	return c.msgDeSnapshot()
}

// clientConnFastStreamEncodeSnapshot delegates to fastStreamEncodeSnapshot.
func (c *LogicalConn) clientConnFastStreamEncodeSnapshot() transportFastStreamEncoder {
	return c.fastStreamEncodeSnapshot()
}

// clientConnFastBulkEncodeSnapshot delegates to fastBulkEncodeSnapshot.
func (c *LogicalConn) clientConnFastBulkEncodeSnapshot() transportFastBulkEncoder {
	return c.fastBulkEncodeSnapshot()
}

// clientConnFastPlainEncodeSnapshot delegates to fastPlainEncodeSnapshot.
func (c *LogicalConn) clientConnFastPlainEncodeSnapshot() transportFastPlainEncoder {
	return c.fastPlainEncodeSnapshot()
}

// clientConnHandshakeRsaKeySnapshot returns the configured handshake RSA key.
func (c *LogicalConn) clientConnHandshakeRsaKeySnapshot() []byte {
	return c.attachmentStateSnapshot().handshakeRsaKey
}

// clientConnSecretKeySnapshot delegates to secretKeySnapshot.
func (c *LogicalConn) clientConnSecretKeySnapshot() []byte {
	return c.secretKeySnapshot()
}

// clientConnLastHeartbeatUnixSnapshot delegates to lastHeartbeatUnixSnapshot.
func (c *LogicalConn) clientConnLastHeartbeatUnixSnapshot() int64 {
	return c.lastHeartbeatUnixSnapshot()
}

// setClientConnID delegates to setID.
func (c *LogicalConn) setClientConnID(id string) {
	c.setID(id)
}

// setClientConnRemoteAddr delegates to setRemoteAddr.
func (c *LogicalConn) setClientConnRemoteAddr(addr net.Addr) {
	c.setRemoteAddr(addr)
}

// setClientConnLastHeartbeatUnix records unix (seconds) as the last heartbeat.
func (c *LogicalConn) setClientConnLastHeartbeatUnix(unix int64) {
	c.updateAttachmentState(func(state *clientConnAttachmentState) {
		state.lastHeartBeat = unix
	})
}

// startClientConnSession delegates to startSession.
func (c *LogicalConn) startClientConnSession(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) {
	return c.startSession(tuConn, stopCtx, stopFn)
}

// startClientConnSessionTransport delegates to startSessionTransport.
func (c *LogicalConn) startClientConnSessionTransport(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) {
	return c.startSessionTransport(tuConn, stopCtx, stopFn)
}

// attachClientConnSessionTransport delegates to attachSessionTransport.
func (c *LogicalConn) attachClientConnSessionTransport(tuConn net.Conn) error {
	return c.attachSessionTransport(tuConn)
}

// detachClientConnTransportForTransfer delegates to detachTransportForTransfer.
func (c *LogicalConn) detachClientConnTransportForTransfer() (net.Conn, error) {
	return c.detachTransportForTransfer()
}

// applyClientConnAttachmentProfile sets timeouts, codecs and keys. When a
// legacy ClientConn is bound, the call is delegated to it. Otherwise the
// attachment state is updated directly; unlike applyAttachmentProfile, the
// fast encoders are left untouched.
func (c *LogicalConn) applyClientConnAttachmentProfile(maxReadTimeout time.Duration, maxWriteTimeout time.Duration, msgEn func([]byte, []byte) []byte, msgDe func([]byte, []byte) []byte, handshakeRsaKey []byte, secretKey []byte) {
	if client := c.compatClientConn(); client != nil {
		client.applyClientConnAttachmentProfile(maxReadTimeout, maxWriteTimeout, msgEn, msgDe, handshakeRsaKey, secretKey)
		return
	}
	c.updateAttachmentState(func(state *clientConnAttachmentState) {
		state.maxReadTimeout = maxReadTimeout
		state.maxWriteTimeout = maxWriteTimeout
		state.msgEn = msgEn
		state.msgDe = msgDe
		state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey)
		state.secretKey = cloneClientConnAttachmentBytes(secretKey)
	})
}

// inheritClientConnAttachmentProfile copies src's attachment state onto c.
// It delegates to the legacy ClientConn when one is bound. A nil src makes
// this a no-op.
func (c *LogicalConn) inheritClientConnAttachmentProfile(src *ClientConn) {
	if src == nil {
		return
	}
	if client := c.compatClientConn(); client != nil {
		client.inheritClientConnAttachmentProfile(src)
		return
	}
	c.setAttachmentState(src.clientConnAttachmentStateSnapshot())
}

// sessionRuntimeSnapshot returns the current session runtime, or nil.
func (c *LogicalConn) sessionRuntimeSnapshot() *clientConnSessionRuntime {
	state := c.logicalRuntimeStateSnapshot()
	if state == nil {
		return nil
	}
	return state.sessionRuntimeSnapshot()
}

// setSessionRuntime installs rt as the session runtime.
//
// rt is modified in place: when it carries a tuConn but no binding, one is
// created, and the normalize/ensure helpers (defined elsewhere) are applied.
// After rt is stored and mirrored onto the legacy ClientConn, the previous
// binding's background workers are stopped if that binding was replaced.
func (c *LogicalConn) setSessionRuntime(rt *clientConnSessionRuntime) {
	if c == nil || rt == nil {
		return
	}
	// Compare against rt.transport before it may be filled in below.
	var oldBinding *transportBinding
	if prev := c.sessionRuntimeSnapshot(); prev != nil && prev.transport != nil && prev.transport != rt.transport {
		oldBinding = prev.transport
	}
	if rt.transport == nil && rt.tuConn != nil {
		rt.transport = newTransportBinding(rt.tuConn, nil)
	}
	normalizeClientConnSessionRuntimeTransportState(rt)
	ensureClientConnSessionRuntimeTransportLifecycle(rt)
	ensureClientConnSessionRuntimeTransportDone(rt)
	state := c.logicalRuntimeStateSnapshot()
	if state == nil {
		client := c.compatClientConn()
		if client != nil {
			client.sessionRuntime.Store(rt)
		}
		return
	}
	state.setSessionRuntime(rt)
	client := c.compatClientConn()
	if client != nil {
		client.syncLegacySessionRuntimeFromState(state)
	}
	if oldBinding != nil {
		oldBinding.stopBackgroundWorkers()
	}
}

// clearSessionRuntimeTransport cancels the current transport's stop func. It
// then installs a copy of the runtime with every transport field cleared;
// session-level stopCtx/stopFn are preserved.
func (c *LogicalConn) clearSessionRuntimeTransport() {
	if c == nil {
		return
	}
	rt := c.sessionRuntimeSnapshot()
	if rt == nil {
		return
	}
	if rt.transportStopFn != nil {
		rt.transportStopFn()
	}
	next := *rt
	next.transport = nil
	next.transportAttached = false
	next.transportGeneration = 0
	next.tuConn = nil
	next.transportStopCtx = nil
	next.transportStopFn = nil
	next.transportDone = nil
	c.setSessionRuntime(&next)
}

// transportSnapshot returns the current transport conn, preferring the
// binding's conn over the raw tuConn.
func (c *LogicalConn) transportSnapshot() net.Conn {
	rt := c.sessionRuntimeSnapshot()
	if rt == nil {
		return nil
	}
	if rt.transport != nil {
		return rt.transport.connSnapshot()
	}
	return rt.tuConn
}

// stopContextSnapshot returns the session stop context, or nil.
func (c *LogicalConn) stopContextSnapshot() context.Context {
	rt := c.sessionRuntimeSnapshot()
	if rt == nil {
		return nil
	}
	return rt.stopCtx
}

// stopFuncSnapshot returns the session stop func, or nil.
func (c *LogicalConn) stopFuncSnapshot() context.CancelFunc {
	rt := c.sessionRuntimeSnapshot()
	if rt == nil {
		return nil
	}
	return rt.stopFn
}

// transportStopContextSnapshot returns the transport-level stop context,
// falling back to the session stop context.
func (c *LogicalConn) transportStopContextSnapshot() context.Context {
	rt := c.sessionRuntimeSnapshot()
	if rt == nil {
		return nil
	}
	if rt.transportStopCtx != nil {
		return rt.transportStopCtx
	}
	return rt.stopCtx
}

// closeTransport closes the current transport conn, if any. It then stops
// the current binding's background workers, if a binding exists.
func (c *LogicalConn) closeTransport() {
	var binding *transportBinding
	if rt := c.sessionRuntimeSnapshot(); rt != nil {
		binding = rt.transport
	}
	if conn := c.transportSnapshot(); conn != nil {
		// Best-effort teardown: a Close error leaves nothing to recover.
		_ = conn.Close()
	}
	if binding != nil {
		binding.stopBackgroundWorkers()
	}
}

// startSession installs a new session runtime around tuConn (which may be
// nil) and marks the session started.
//
// If either stopCtx or stopFn is nil, a fresh cancelable context derived from
// context.Background replaces both. When tuConn is non-nil, the remote
// address is filled in if unknown, the connection is flagged as a stream
// transport, a new transport generation starts, and any detach record is
// cleared. It returns the effective stop context and func; for a nil c the
// inputs are returned unchanged.
func (c *LogicalConn) startSession(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) {
	if c == nil {
		return stopCtx, stopFn
	}
	if stopCtx == nil || stopFn == nil {
		// NOTE(review): a non-nil stopCtx passed without stopFn is discarded
		// rather than used as the parent; confirm callers never rely on it.
		stopCtx, stopFn = context.WithCancel(context.Background())
	}
	if c.RemoteAddr() == nil && tuConn != nil {
		c.setRemoteAddr(tuConn.RemoteAddr())
	}
	transportGeneration := uint64(0)
	if tuConn != nil {
		c.markStreamTransport()
		transportGeneration = c.markTransportAttached()
		c.clearTransportDetachState()
	}
	c.setSessionRuntime(&clientConnSessionRuntime{
		transport:           newTransportBinding(tuConn, nil),
		transportAttached:   tuConn != nil,
		transportGeneration: transportGeneration,
		tuConn:              tuConn,
		stopCtx:             stopCtx,
		stopFn:              stopFn,
	})
	c.markSessionStarted()
	return stopCtx, stopFn
}

// startSessionTransport runs startSession. It then launches readTUMessageLoop
// in a new goroutine for the installed runtime; that goroutine's lifetime is
// governed by readTUMessageLoop (defined elsewhere).
func (c *LogicalConn) startSessionTransport(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) {
	if c == nil {
		return stopCtx, stopFn
	}
	stopCtx, stopFn = c.startSession(tuConn, stopCtx, stopFn)
	rt := c.sessionRuntimeSnapshot()
	if rt == nil {
		return stopCtx, stopFn
	}
	go c.readTUMessageLoop(rt)
	return stopCtx, stopFn
}

// attachSessionTransport swaps tuConn in as the transport of the existing
// session. It:
//   - cancels the previous transport's stop func;
//   - installs a new binding under a fresh generation;
//   - refreshes the remote address, flags a stream transport and clears any
//     detach record;
//   - closes the previous conn if it differs from tuConn;
//   - starts a read loop for the new runtime.
//
// It fails for a nil c, a nil tuConn, or when no session runtime exists.
func (c *LogicalConn) attachSessionTransport(tuConn net.Conn) error {
	if c == nil {
		return errLogicalConnClientNil
	}
	if tuConn == nil {
		return errors.New("conn is nil")
	}
	rt := c.sessionRuntimeSnapshot()
	if rt == nil {
		return errors.New("client conn session runtime is nil")
	}
	oldBinding := rt.transport
	if rt.transportStopFn != nil {
		rt.transportStopFn()
	}
	next := *rt
	next.transport = newTransportBinding(tuConn, nil)
	next.transportAttached = true
	next.transportGeneration = c.markTransportAttached()
	next.tuConn = tuConn
	next.transportStopCtx = nil
	next.transportStopFn = nil
	next.transportDone = nil
	c.setSessionRuntime(&next)
	if addr := tuConn.RemoteAddr(); addr != nil {
		c.setRemoteAddr(addr)
	}
	c.markStreamTransport()
	c.clearTransportDetachState()
	if oldBinding != nil {
		if oldConn := oldBinding.connSnapshot(); oldConn != nil && oldConn != tuConn {
			// Best-effort: the superseded conn is no longer referenced.
			_ = oldConn.Close()
		}
	}
	attached := c.sessionRuntimeSnapshot()
	if attached == nil {
		return nil
	}
	go c.readTUMessageLoop(attached)
	return nil
}

// detachTransportForTransfer removes the transport from the session without
// closing it, so the conn can be handed off elsewhere.
//
// It clears the runtime's transport fields and cancels the transport stop
// func. It then sets an immediate read deadline so a blocked Read returns,
// and waits up to one second for transportDone. On timeout the conn is closed
// and an error is returned. Otherwise the read deadline is reset and the conn
// is returned; the conn may be nil (with a nil error) when none was attached.
func (c *LogicalConn) detachTransportForTransfer() (net.Conn, error) {
	if c == nil {
		return nil, errLogicalConnClientNil
	}
	rt := c.sessionRuntimeSnapshot()
	if rt == nil {
		return nil, errors.New("client conn session runtime is nil")
	}
	conn := rt.tuConn
	if rt.transport != nil {
		// Read the binding's conn once so the check and the use agree.
		if boundConn := rt.transport.connSnapshot(); boundConn != nil {
			conn = boundConn
		}
	}
	next := *rt
	next.transport = nil
	next.transportAttached = false
	next.transportGeneration = 0
	next.tuConn = nil
	next.transportStopCtx = nil
	next.transportStopFn = nil
	next.transportDone = nil
	c.setSessionRuntime(&next)
	if rt.transportStopFn != nil {
		rt.transportStopFn()
	}
	if conn != nil {
		// Unblock any pending Read on the conn; errors are non-fatal here.
		_ = conn.SetReadDeadline(time.Now())
	}
	if rt.transportDone != nil {
		select {
		case <-rt.transportDone:
		case <-time.After(time.Second):
			if conn != nil {
				_ = conn.Close()
			}
			return nil, errors.New("timed out waiting for transport handoff")
		}
	}
	if conn != nil {
		_ = conn.SetReadDeadline(time.Time{})
	}
	return conn, nil
}

// CurrentTransportConn returns a TransportConn describing c's current
// transport, as produced by currentTransportConnSnapshot.
func (c *LogicalConn) CurrentTransportConn() *TransportConn {
	return c.currentTransportConnSnapshot()
}

// transportConnSnapshotForInbound builds a TransportConn for an inbound
// message observed on conn at generation.
//
// remoteAddr defaults to conn's address, then to c's. The result is nil when
// no address is known and hasRuntimeConn is false. With a conn, attached is
// true only if conn is the current binding's conn, a transport is attached,
// and generation is current. Without a conn, attached (and hasRuntimeConn,
// when unset) are taken from CurrentTransportConn if its generation and
// address match.
func (c *LogicalConn) transportConnSnapshotForInbound(conn net.Conn, remoteAddr net.Addr, generation uint64, hasRuntimeConn bool) *TransportConn {
	if c == nil {
		return nil
	}
	if remoteAddr == nil {
		if conn != nil {
			remoteAddr = conn.RemoteAddr()
		}
		if remoteAddr == nil {
			remoteAddr = c.RemoteAddr()
		}
	}
	if remoteAddr == nil && !hasRuntimeConn {
		return nil
	}
	attached := false
	currentGeneration := c.transportGenerationSnapshot()
	if conn != nil {
		binding := c.transportBindingSnapshot()
		if binding != nil && binding.connSnapshot() == conn && c.transportAttachedSnapshot() && currentGeneration == generation {
			attached = true
		}
	} else {
		current := c.CurrentTransportConn()
		if current != nil && currentGeneration == generation && transportConnAddrString(current.RemoteAddr()) == transportConnAddrString(remoteAddr) {
			attached = current.Attached()
			if !hasRuntimeConn {
				hasRuntimeConn = current.HasRuntimeConn()
			}
		}
	}
	return &TransportConn{
		logical:        c,
		generation:     generation,
		remoteAddr:     remoteAddr,
		attached:       attached,
		hasRuntimeConn: hasRuntimeConn,
	}
}