// notify/logical_conn.go
//
// 1125 lines
// 28 KiB
// Go
// Raw Normal View History

package notify
import (
"context"
"errors"
"net"
"sync/atomic"
"time"
)
// LogicalConn is the logical view of a connected peer. It is paired with a
// legacy *ClientConn, and its shared state pointers (state, runtime,
// transportState, attachment) are mirrored onto the same-purpose pointers of
// that ClientConn by bindLegacyClient.
type LogicalConn struct {
// client is the legacy connection object this view is bound to.
client *ClientConn
// server is the owning server; Server() falls back to client.server when nil.
server Server
// ClientID and ClientAddr are compatibility copies of the peer identity,
// refreshed by syncCompatibilityFieldsFromState / syncCompatibilityFieldsFromClient.
ClientID string
ClientAddr net.Addr
// state holds the logical peer/status state (mirrors client.logicalState).
state atomic.Pointer[logicalConnState]
// runtime holds the session runtime state (mirrors client.runtimeState).
runtime atomic.Pointer[logicalConnRuntimeState]
// transportState holds transport counters and detach info (mirrors client.transportState).
transportState atomic.Pointer[clientConnTransportState]
// attachment holds timeouts, codecs and key material (mirrors client.attachment).
attachment atomic.Pointer[clientConnAttachmentState]
}
// errLogicalConnClientNil is returned by methods that need a non-nil
// *LogicalConn receiver (attachSessionTransport, detachTransportForTransfer).
var errLogicalConnClientNil = errors.New("logical conn is nil")
// logicalConnFromClient returns the LogicalConn bound to client, creating one
// on first use. A nil client yields nil.
//
// When a view is already published on client.logicalView it is re-bound and
// returned. Otherwise a fresh view is attached. attachLegacyClient itself tries
// to publish the view via bindLegacyClient's CompareAndSwap, so the CAS below
// normally fails. The published view (this one, or a concurrent winner) is then
// re-bound and returned.
func logicalConnFromClient(client *ClientConn) *LogicalConn {
	if client == nil {
		return nil
	}
	if existing := client.logicalView.Load(); existing != nil {
		return existing.bindLegacyClient(client)
	}
	fresh := (&LogicalConn{}).attachLegacyClient(client)
	if !client.logicalView.CompareAndSwap(nil, fresh) {
		published := client.logicalView.Load()
		return published.bindLegacyClient(client)
	}
	return fresh
}
// newServerLogicalConn creates a server-side LogicalConn together with its
// backing legacy ClientConn. It publishes the view on the ClientConn and applies
// id and addr when they are non-empty / non-nil.
func newServerLogicalConn(server Server, id string, addr net.Addr) *LogicalConn {
	legacy := &ClientConn{server: server}
	view := &LogicalConn{client: legacy, server: server}
	view = view.attachLegacyClient(legacy)
	legacy.logicalView.Store(view)

	if len(id) > 0 {
		view.setID(id)
	}
	if addr != nil {
		view.setRemoteAddr(addr)
	}
	return view
}
// attachLegacyClient binds c to client and then refreshes c's compatibility
// fields. The fields come from the shared logical state when one exists,
// otherwise from the client. It returns nil only when c is nil.
func (c *LogicalConn) attachLegacyClient(client *ClientConn) *LogicalConn {
	c = c.bindLegacyClient(client)
	if c == nil {
		return nil
	}
	state := c.state.Load()
	if state == nil {
		c.syncCompatibilityFieldsFromClient(client)
		return c
	}
	c.syncCompatibilityFieldsFromState(state)
	return c
}
// bindLegacyClient links c with the legacy client and reconciles their shared
// state pointers. It returns c unchanged, doing nothing, when either side is nil.
//
// Reconciliation has two passes. First, each of c's pointers that is still nil
// adopts the client's pointer (CompareAndSwap from nil). Then c's resulting
// pointers are stored back onto the client. So when both sides already hold a
// pointer, c's wins. Finally c is published as the client's logical view,
// unless a view is already installed.
func (c *LogicalConn) bindLegacyClient(client *ClientConn) *LogicalConn {
if c == nil || client == nil {
return c
}
// Only fill in the back-reference and server if not already set.
if c.client == nil {
c.client = client
}
if c.server == nil {
c.server = client.server
}
// Pass 1: adopt the client's state for any pointer c does not have yet.
if state := client.logicalState.Load(); state != nil {
c.state.CompareAndSwap(nil, state)
}
if runtime := client.runtimeState.Load(); runtime != nil {
c.runtime.CompareAndSwap(nil, runtime)
}
if transportState := client.transportState.Load(); transportState != nil {
c.transportState.CompareAndSwap(nil, transportState)
}
if attachment := client.attachment.Load(); attachment != nil {
c.attachment.CompareAndSwap(nil, attachment)
}
// Pass 2: push c's (now authoritative) pointers back onto the client.
if state := c.state.Load(); state != nil {
client.logicalState.Store(state)
}
if runtime := c.runtime.Load(); runtime != nil {
client.runtimeState.Store(runtime)
}
if transportState := c.transportState.Load(); transportState != nil {
client.transportState.Store(transportState)
}
if attachment := c.attachment.Load(); attachment != nil {
client.attachment.Store(attachment)
}
// Publish c as the client's view only if none is installed yet.
client.logicalView.CompareAndSwap(nil, c)
return c
}
// clientConnFromLogical returns the legacy ClientConn behind logical, or nil
// when logical is nil.
func clientConnFromLogical(logical *LogicalConn) *ClientConn {
	if logical != nil {
		return logical.client
	}
	return nil
}
// logicalConnFromPeer resolves an untyped peer value to its LogicalConn.
// *LogicalConn values are returned as-is, and *ClientConn values go through
// logicalConnFromClient. Anything else, including nil, yields nil.
func logicalConnFromPeer(peer any) *LogicalConn {
	if logical, ok := peer.(*LogicalConn); ok {
		return logical
	}
	if legacy, ok := peer.(*ClientConn); ok {
		return logicalConnFromClient(legacy)
	}
	return nil
}
// LogicalConn returns the logical view bound to c, creating and binding one on
// first use. It returns nil when c is nil.
func (c *ClientConn) LogicalConn() *LogicalConn {
return logicalConnFromClient(c)
}
// compatClientConn returns the legacy ClientConn bound to c. It is nil-safe on
// the receiver.
func (c *LogicalConn) compatClientConn() *ClientConn {
	if c != nil {
		return c.client
	}
	return nil
}
// logicalStateSnapshot returns the shared logical state, lazily creating it via
// ensureState when none has been stored yet. A nil receiver yields nil.
func (c *LogicalConn) logicalStateSnapshot() *logicalConnState {
	if c == nil {
		return nil
	}
	current := c.state.Load()
	if current == nil {
		current = c.ensureState()
	}
	return current
}

// logicalRuntimeStateSnapshot returns the shared runtime state, lazily creating
// it via ensureRuntimeState when none has been stored yet. A nil receiver
// yields nil.
func (c *LogicalConn) logicalRuntimeStateSnapshot() *logicalConnRuntimeState {
	if c == nil {
		return nil
	}
	current := c.runtime.Load()
	if current == nil {
		current = c.ensureRuntimeState()
	}
	return current
}
// ID returns the peer's client ID from the logical state. It falls back to the
// ClientID compatibility field when no state is available.
func (c *LogicalConn) ID() string {
return c.clientIDSnapshot()
}
// RemoteAddr returns the peer's remote address from the logical state. It falls
// back to the ClientAddr compatibility field when no state is available.
func (c *LogicalConn) RemoteAddr() net.Addr {
return c.clientRemoteAddrSnapshot()
}
// GetRemoteAddr is an alias of RemoteAddr kept for API compatibility.
func (c *LogicalConn) GetRemoteAddr() net.Addr {
return c.RemoteAddr()
}
// Status reports the connection status recorded in the logical state. It
// returns the zero Status when no state is available (e.g. nil receiver).
func (c *LogicalConn) Status() Status {
	if state := c.logicalStateSnapshot(); state != nil {
		return state.statusSnapshot()
	}
	return Status{}
}
// Server returns the owning server. It prefers the value stored on the logical
// view and falls back to the legacy client's server. A nil receiver yields nil.
func (c *LogicalConn) Server() Server {
	switch {
	case c == nil:
		return nil
	case c.server != nil:
		return c.server
	}
	if legacy := c.compatClientConn(); legacy != nil {
		return legacy.server
	}
	return nil
}
// setServer records server on the logical view and on the legacy client, if
// any. Nil receivers and nil servers are ignored.
func (c *LogicalConn) setServer(server Server) {
	if c == nil || server == nil {
		return
	}
	c.server = server
	legacy := c.compatClientConn()
	if legacy == nil {
		return
	}
	legacy.server = server
}
// syncCompatibilityFieldsFromClient copies the client's ID and address into
// c's compatibility fields. It also adopts the client's server when c has none.
// Nil arguments are ignored.
func (c *LogicalConn) syncCompatibilityFieldsFromClient(client *ClientConn) {
	if c == nil || client == nil {
		return
	}
	c.ClientID, c.ClientAddr = client.ClientID, client.ClientAddr
	if c.server == nil {
		c.server = client.server
	}
}

// syncCompatibilityFieldsFromState refreshes c's compatibility fields from the
// peer snapshot in state. With a nil state it falls back to the legacy client.
func (c *LogicalConn) syncCompatibilityFieldsFromState(state *logicalConnState) {
	switch {
	case c == nil:
		return
	case state == nil:
		c.syncCompatibilityFieldsFromClient(c.compatClientConn())
		return
	}
	peer := state.peerSnapshot()
	c.ClientID, c.ClientAddr = peer.clientID, peer.clientAddr
}
// markSessionStarted flags the logical state as started and mirrors the change
// onto the legacy client. It is a no-op when no state is available.
func (c *LogicalConn) markSessionStarted() {
	state := c.logicalStateSnapshot()
	if state == nil {
		return
	}
	state.markStarted()
	legacy := c.compatClientConn()
	if legacy == nil {
		return
	}
	legacy.syncLegacyLogicalFieldsFromState(state)
}

// markSessionStopped records the stop reason/error, together with the current
// session stop func, in the logical state. It then mirrors the change onto the
// legacy client. It is a no-op when no state is available.
func (c *LogicalConn) markSessionStopped(reason string, err error) {
	state := c.logicalStateSnapshot()
	if state == nil {
		return
	}
	stopFn := c.stopFuncSnapshot()
	state.markStopped(reason, err, stopFn)
	legacy := c.compatClientConn()
	if legacy == nil {
		return
	}
	legacy.syncLegacyLogicalFieldsFromState(state)
}
// rsaDecode forwards message to the legacy client's rsaDecode. It does nothing
// when no client is bound.
func (c *LogicalConn) rsaDecode(message Message) {
	legacy := c.compatClientConn()
	if legacy == nil {
		return
	}
	legacy.rsaDecode(message)
}

// sayGoodByeForTU forwards to the legacy client's sayGoodByeForTU. It returns
// errTransportDetached when no client is bound.
func (c *LogicalConn) sayGoodByeForTU() error {
	legacy := c.compatClientConn()
	if legacy == nil {
		return errTransportDetached
	}
	return legacy.sayGoodByeForTU()
}
// setID records id as the peer's client ID.
//
// With shared logical state available, the state is updated and both c's
// compatibility fields and the legacy client are re-synced from it. Otherwise
// the compatibility fields on c and the legacy client are written directly.
func (c *LogicalConn) setID(id string) {
	if c == nil {
		return
	}
	if state := c.ensureState(); state != nil {
		state.updatePeer(func(peer *logicalConnPeerState) {
			peer.clientID = id
		})
		c.syncCompatibilityFieldsFromState(state)
		if legacy := c.compatClientConn(); legacy != nil {
			legacy.syncLegacyLogicalFieldsFromState(state)
		}
		return
	}
	c.ClientID = id
	if legacy := c.compatClientConn(); legacy != nil {
		legacy.ClientID = id
	}
}
// clientIDSnapshot returns the peer's client ID from the logical state. It
// falls back to the ClientID compatibility field when no state is available.
//
// A nil receiver returns "". Without this guard, logicalStateSnapshot returns
// nil for a nil receiver and the fallback read of c.ClientID panics, unlike the
// other nil-safe accessors such as Status.
func (c *LogicalConn) clientIDSnapshot() string {
	if c == nil {
		return ""
	}
	state := c.logicalStateSnapshot()
	if state == nil {
		return c.ClientID
	}
	return state.peerSnapshot().clientID
}
// clientRemoteAddrSnapshot returns the peer's remote address from the logical
// state. It falls back to the ClientAddr compatibility field when no state is
// available.
//
// A nil receiver returns nil instead of panicking on the c.ClientAddr fallback
// (logicalStateSnapshot returns nil for a nil receiver).
func (c *LogicalConn) clientRemoteAddrSnapshot() net.Addr {
	if c == nil {
		return nil
	}
	state := c.logicalStateSnapshot()
	if state == nil {
		return c.ClientAddr
	}
	return state.peerSnapshot().clientAddr
}
// setRemoteAddr records addr as the peer's remote address. Nil receivers and
// nil addresses are ignored.
//
// With shared logical state available, the state is updated and both c's
// compatibility fields and the legacy client are re-synced from it. Otherwise
// the compatibility fields are written directly.
func (c *LogicalConn) setRemoteAddr(addr net.Addr) {
	if c == nil || addr == nil {
		return
	}
	if state := c.ensureState(); state != nil {
		state.updatePeer(func(peer *logicalConnPeerState) {
			peer.clientAddr = addr
		})
		c.syncCompatibilityFieldsFromState(state)
		if legacy := c.compatClientConn(); legacy != nil {
			legacy.syncLegacyLogicalFieldsFromState(state)
		}
		return
	}
	c.ClientAddr = addr
	if legacy := c.compatClientConn(); legacy != nil {
		legacy.ClientAddr = addr
	}
}
// transportGenerationSnapshot returns the current transport generation
// counter, or 0 when no transport state is available.
func (c *LogicalConn) transportGenerationSnapshot() uint64 {
	if state := c.ensureTransportState(); state != nil {
		return state.transportGen.Load()
	}
	return 0
}

// lastHeartbeatUnixSnapshot returns the last heartbeat time (Unix seconds)
// stored in the attachment state.
func (c *LogicalConn) lastHeartbeatUnixSnapshot() int64 {
	attachment := c.attachmentStateSnapshot()
	return attachment.lastHeartBeat
}

// transportAttachedSnapshot reports whether the current session runtime has a
// transport attached. It is false when there is no runtime.
func (c *LogicalConn) transportAttachedSnapshot() bool {
	if rt := c.sessionRuntimeSnapshot(); rt != nil {
		return rt.transportAttached
	}
	return false
}

// usesStreamTransportSnapshot reports whether this connection has been marked
// as using a stream transport.
func (c *LogicalConn) usesStreamTransportSnapshot() bool {
	if state := c.ensureTransportState(); state != nil {
		return state.streamTransport.Load()
	}
	return false
}
// logicalTransportDetachedSnapshot reports whether the logical peer is alive but
// has lost its transport. This applies only to identity-bound peers on a stream
// transport (see shouldPreserveLogicalPeerOnTransportLoss). The checks run in
// order: preserve-eligible, alive, then not attached.
func (c *LogicalConn) logicalTransportDetachedSnapshot() bool {
	if c == nil || !c.shouldPreserveLogicalPeerOnTransportLoss() {
		return false
	}
	return c.clientConnAliveSnapshot() && !c.transportAttachedSnapshot()
}

// shouldPreserveLogicalPeerOnTransportLoss reports whether the logical peer is
// identity-bound and uses a stream transport.
func (c *LogicalConn) shouldPreserveLogicalPeerOnTransportLoss() bool {
	if !c.clientConnIdentityBoundSnapshot() {
		return false
	}
	return c.usesStreamTransportSnapshot()
}
// markIdentityBound sets the identityBound flag on the logical peer state and
// mirrors it onto the legacy client. It is a no-op without logical state.
func (c *LogicalConn) markIdentityBound() {
	state := c.logicalStateSnapshot()
	if state == nil {
		return
	}
	state.updatePeer(func(peer *logicalConnPeerState) { peer.identityBound = true })
	legacy := c.compatClientConn()
	if legacy == nil {
		return
	}
	legacy.syncLegacyLogicalFieldsFromState(state)
}
// markHeartbeatNow stores the current time (Unix seconds) as the last heartbeat.
func (c *LogicalConn) markHeartbeatNow() {
	nowUnix := time.Now().Unix()
	c.setClientConnLastHeartbeatUnix(nowUnix)
}

// markStreamTransport flags this connection as using a stream transport.
func (c *LogicalConn) markStreamTransport() {
	if state := c.ensureTransportState(); state != nil {
		state.streamTransport.Store(true)
	}
}

// markTransportAttached bumps the transport generation and attach count and
// records the attach time. It returns the new generation, or 0 when no
// transport state is available.
func (c *LogicalConn) markTransportAttached() uint64 {
	state := c.ensureTransportState()
	if state == nil {
		return 0
	}
	nextGen := state.transportGen.Add(1)
	state.attachCount.Add(1)
	state.lastAttachAt.Store(time.Now().UnixNano())
	return nextGen
}
// applyAttachmentProfile overwrites the attachment settings in a single
// updateAttachmentState call. The settings are the read/write timeouts, the
// message encode/decode funcs, the fast stream/bulk/plain encoders, and the
// handshake RSA key and secret key. The key slices pass through
// cloneClientConnAttachmentBytes (presumably a defensive copy — confirm).
func (c *LogicalConn) applyAttachmentProfile(maxReadTimeout time.Duration, maxWriteTimeout time.Duration, msgEn func([]byte, []byte) []byte, msgDe func([]byte, []byte) []byte, fastStreamEncode transportFastStreamEncoder, fastBulkEncode transportFastBulkEncoder, fastPlainEncode transportFastPlainEncoder, handshakeRsaKey []byte, secretKey []byte) {
c.updateAttachmentState(func(state *clientConnAttachmentState) {
state.maxReadTimeout = maxReadTimeout
state.maxWriteTimeout = maxWriteTimeout
state.msgEn = msgEn
state.msgDe = msgDe
state.fastStreamEncode = fastStreamEncode
state.fastBulkEncode = fastBulkEncode
state.fastPlainEncode = fastPlainEncode
state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey)
state.secretKey = cloneClientConnAttachmentBytes(secretKey)
})
}
// msgEnSnapshot returns the message-encode func from the current attachment state.
func (c *LogicalConn) msgEnSnapshot() func([]byte, []byte) []byte {
return c.attachmentStateSnapshot().msgEn
}
// msgDeSnapshot returns the message-decode func from the current attachment state.
func (c *LogicalConn) msgDeSnapshot() func([]byte, []byte) []byte {
return c.attachmentStateSnapshot().msgDe
}
// secretKeySnapshot returns the secret key from the current attachment state.
// NOTE(review): this is the stored slice itself, not a copy; callers must not modify it.
func (c *LogicalConn) secretKeySnapshot() []byte {
return c.attachmentStateSnapshot().secretKey
}
// fastStreamEncodeSnapshot returns the fast stream encoder from the attachment state.
func (c *LogicalConn) fastStreamEncodeSnapshot() transportFastStreamEncoder {
return c.attachmentStateSnapshot().fastStreamEncode
}
// fastBulkEncodeSnapshot returns the fast bulk encoder from the attachment state.
func (c *LogicalConn) fastBulkEncodeSnapshot() transportFastBulkEncoder {
return c.attachmentStateSnapshot().fastBulkEncode
}
// fastPlainEncodeSnapshot returns the fast plain encoder from the attachment state.
func (c *LogicalConn) fastPlainEncodeSnapshot() transportFastPlainEncoder {
return c.attachmentStateSnapshot().fastPlainEncode
}
// inheritAttachmentProfile installs src's current attachment state on c.
// It is a no-op when either side is nil.
func (c *LogicalConn) inheritAttachmentProfile(src *LogicalConn) {
if c == nil || src == nil {
return
}
c.setAttachmentState(src.attachmentStateSnapshot())
}
// transportBindingSnapshot returns the session's transport binding. When the
// runtime only carries a raw tuConn, it returns a fresh binding wrapping that
// conn. It returns nil when there is no runtime or no transport at all.
func (c *LogicalConn) transportBindingSnapshot() *transportBinding {
	rt := c.sessionRuntimeSnapshot()
	switch {
	case rt == nil:
		return nil
	case rt.transport != nil:
		return rt.transport
	case rt.tuConn == nil:
		return nil
	default:
		return newTransportBinding(rt.tuConn, nil)
	}
}
// maxWriteTimeoutSnapshot returns the write timeout from the attachment state.
func (c *LogicalConn) maxWriteTimeoutSnapshot() time.Duration {
	attachment := c.attachmentStateSnapshot()
	return attachment.maxWriteTimeout
}

// transportDetachSnapshot returns a copy of the last recorded transport-detach
// info (via cloneClientConnTransportDetachState). It returns nil when no
// transport state is available.
func (c *LogicalConn) transportDetachSnapshot() *clientConnTransportDetachState {
	if state := c.ensureTransportState(); state != nil {
		return cloneClientConnTransportDetachState(state.transportDetach.Load())
	}
	return nil
}
// markTransportDetached records that the transport was detached, and why. The
// record holds the current generation, reason, time and error text. It also
// increments the detach counter. It is a no-op without transport state.
func (c *LogicalConn) markTransportDetached(reason string, err error) {
	transport := c.ensureTransportState()
	if transport == nil {
		return
	}
	record := &clientConnTransportDetachState{
		Generation: c.transportGenerationSnapshot(),
		Reason:     reason,
		At:         time.Now(),
	}
	if err != nil {
		record.Err = err.Error()
	}
	transport.detachCount.Add(1)
	c.setTransportDetachState(record)
}

// clearTransportDetachState drops any recorded transport-detach info.
func (c *LogicalConn) clearTransportDetachState() {
	c.setTransportDetachState(nil)
}
// transportDetachExpiredSnapshot reports whether a logically detached peer has
// passed its detach expiry at time now. It is false when the peer is not
// detached or no expiry applies.
func (c *LogicalConn) transportDetachExpiredSnapshot(now time.Time) bool {
	if !c.logicalTransportDetachedSnapshot() {
		return false
	}
	if expiry, ok := c.clientConnTransportDetachExpirySnapshot(); ok {
		return !now.Before(expiry)
	}
	return false
}

// reattachEligibleSnapshot reports whether a new transport may be re-attached
// at time now. The peer must be logically detached, alive, not currently
// attached, and not yet expired.
func (c *LogicalConn) reattachEligibleSnapshot(now time.Time) bool {
	return c.logicalTransportDetachedSnapshot() &&
		c.clientConnAliveSnapshot() &&
		!c.transportAttachedSnapshot() &&
		!c.transportDetachExpiredSnapshot(now)
}
// runtimeSnapshot assembles a diagnostic ClientConnRuntimeSnapshot for this
// connection. A nil receiver yields the zero snapshot.
//
// The runtime-derived fields (TransportAttached, HasRuntimeConn,
// HasRuntimeStopCtx) all come from the single session runtime loaded here. The
// detach generation comes from the detach record loaded here. Previously these
// were re-read through helper accessors, so under a concurrent runtime swap one
// snapshot could mix fields from different runtimes.
func (c *LogicalConn) runtimeSnapshot() ClientConnRuntimeSnapshot {
	if c == nil {
		return ClientConnRuntimeSnapshot{}
	}
	status := c.Status()
	now := time.Now()
	snapshot := ClientConnRuntimeSnapshot{
		ClientID:              c.clientIDSnapshot(),
		Alive:                 status.Alive,
		Reason:                status.Reason,
		IdentityBound:         c.clientConnIdentityBoundSnapshot(),
		UsesStreamTransport:   c.usesStreamTransportSnapshot(),
		TransportGeneration:   c.transportGenerationSnapshot(),
		TransportAttachCount:  c.clientConnTransportAttachCountSnapshot(),
		TransportDetachCount:  c.clientConnTransportDetachCountSnapshot(),
		LastTransportAttachAt: c.clientConnLastTransportAttachedAtSnapshot(),
	}
	if status.Err != nil {
		snapshot.Error = status.Err.Error()
	}
	if addr := c.RemoteAddr(); addr != nil {
		snapshot.RemoteAddress = addr.String()
	}
	if lastHeartbeat := c.lastHeartbeatUnixSnapshot(); lastHeartbeat != 0 {
		snapshot.LastHeartbeatAt = time.Unix(lastHeartbeat, 0)
	}
	if server := c.Server(); server != nil {
		snapshot.DetachedClientKeepSec = server.DetachedClientKeepSec()
	}
	if rt := c.sessionRuntimeSnapshot(); rt != nil {
		// Same conn resolution as transportSnapshot, but on this rt.
		var conn net.Conn
		if rt.transport != nil {
			conn = rt.transport.connSnapshot()
		} else {
			conn = rt.tuConn
		}
		snapshot.TransportAttached = rt.transportAttached
		snapshot.HasRuntimeConn = conn != nil
		snapshot.HasRuntimeStopCtx = rt.stopCtx != nil
	}
	if detach := c.transportDetachSnapshot(); detach != nil {
		snapshot.TransportDetachReason = detach.Reason
		snapshot.TransportDetachKind = classifyClientConnTransportDetachReason(detach.Reason)
		// Same fallback as clientConnTransportDetachGenerationSnapshot: a zero
		// generation in the record means "use the current generation".
		snapshot.TransportDetachGeneration = detach.Generation
		if snapshot.TransportDetachGeneration == 0 {
			snapshot.TransportDetachGeneration = snapshot.TransportGeneration
		}
		snapshot.TransportDetachError = detach.Err
		snapshot.TransportDetachedAt = detach.At
		snapshot.TransportDetachExpiry, snapshot.TransportDetachHasExpiry = c.clientConnTransportDetachExpirySnapshot()
		snapshot.TransportDetachRemaining = c.clientConnTransportDetachRemainingSnapshot(now)
		snapshot.TransportDetachExpired = c.clientConnTransportDetachExpiredSnapshot(now)
	}
	snapshot.ReattachEligible = c.clientConnReattachEligibleSnapshot(now)
	return snapshot
}
// clientConnLogicalPeerStateSnapshot returns the peer state from the logical
// state. When no state is available, it builds one from the compatibility
// fields.
//
// A nil receiver returns the zero peer state instead of panicking on the
// c.ClientID / c.ClientAddr fallback reads.
func (c *LogicalConn) clientConnLogicalPeerStateSnapshot() logicalConnPeerState {
	if c == nil {
		return logicalConnPeerState{}
	}
	state := c.logicalStateSnapshot()
	if state == nil {
		return logicalConnPeerState{
			clientID:   c.ClientID,
			clientAddr: c.ClientAddr,
		}
	}
	return state.peerSnapshot()
}
// clientConnIDSnapshot is a ClientConn-compatible alias of clientIDSnapshot.
func (c *LogicalConn) clientConnIDSnapshot() string {
	return c.clientIDSnapshot()
}

// clientConnRemoteAddrSnapshot is a ClientConn-compatible alias of clientRemoteAddrSnapshot.
func (c *LogicalConn) clientConnRemoteAddrSnapshot() net.Addr {
	return c.clientRemoteAddrSnapshot()
}

// clientConnAliveSnapshot reports whether the logical state marks the peer as
// alive. It is false without logical state.
func (c *LogicalConn) clientConnAliveSnapshot() bool {
	if state := c.logicalStateSnapshot(); state != nil {
		return state.aliveSnapshot()
	}
	return false
}

// clientConnStatusSnapshot is an alias of Status.
func (c *LogicalConn) clientConnStatusSnapshot() Status {
	return c.Status()
}

// clientConnIdentityBoundSnapshot reports whether the peer identity is bound.
func (c *LogicalConn) clientConnIdentityBoundSnapshot() bool {
	peer := c.clientConnLogicalPeerStateSnapshot()
	return peer.identityBound
}

// clientConnUsesStreamTransportSnapshot is an alias of usesStreamTransportSnapshot.
func (c *LogicalConn) clientConnUsesStreamTransportSnapshot() bool {
	return c.usesStreamTransportSnapshot()
}

// clientConnTransportGenerationSnapshot is an alias of transportGenerationSnapshot.
func (c *LogicalConn) clientConnTransportGenerationSnapshot() uint64 {
	return c.transportGenerationSnapshot()
}
// clientConnTransportAttachCountSnapshot returns how many times a transport has
// been attached, or 0 without transport state.
func (c *LogicalConn) clientConnTransportAttachCountSnapshot() uint64 {
	if state := c.ensureTransportState(); state != nil {
		return state.attachCount.Load()
	}
	return 0
}

// clientConnTransportDetachCountSnapshot returns how many times a transport has
// been detached, or 0 without transport state.
func (c *LogicalConn) clientConnTransportDetachCountSnapshot() uint64 {
	if state := c.ensureTransportState(); state != nil {
		return state.detachCount.Load()
	}
	return 0
}
// The clientConn* methods below are ClientConn-compatible aliases that
// delegate directly to the LogicalConn accessor named in their body.

// clientConnTransportSnapshot is an alias of transportSnapshot.
func (c *LogicalConn) clientConnTransportSnapshot() net.Conn {
return c.transportSnapshot()
}
// clientConnTransportBindingSnapshot is an alias of transportBindingSnapshot.
func (c *LogicalConn) clientConnTransportBindingSnapshot() *transportBinding {
return c.transportBindingSnapshot()
}
// clientConnSessionRuntimeSnapshot is an alias of sessionRuntimeSnapshot.
func (c *LogicalConn) clientConnSessionRuntimeSnapshot() *clientConnSessionRuntime {
return c.sessionRuntimeSnapshot()
}
// clientConnStopContextSnapshot is an alias of stopContextSnapshot.
func (c *LogicalConn) clientConnStopContextSnapshot() context.Context {
return c.stopContextSnapshot()
}
// clientConnStopFuncSnapshot is an alias of stopFuncSnapshot.
func (c *LogicalConn) clientConnStopFuncSnapshot() context.CancelFunc {
return c.stopFuncSnapshot()
}
// clientConnTransportStopContextSnapshot is an alias of transportStopContextSnapshot.
func (c *LogicalConn) clientConnTransportStopContextSnapshot() context.Context {
return c.transportStopContextSnapshot()
}
// clientConnTransportAttachedSnapshot is an alias of transportAttachedSnapshot.
func (c *LogicalConn) clientConnTransportAttachedSnapshot() bool {
return c.transportAttachedSnapshot()
}
// clientConnLogicalTransportDetachedSnapshot is an alias of logicalTransportDetachedSnapshot.
func (c *LogicalConn) clientConnLogicalTransportDetachedSnapshot() bool {
return c.logicalTransportDetachedSnapshot()
}
// clientConnLastTransportAttachedAtSnapshot returns when a transport was last
// attached. It returns the zero time when nothing was ever attached or no
// transport state exists.
func (c *LogicalConn) clientConnLastTransportAttachedAtSnapshot() time.Time {
	state := c.ensureTransportState()
	if state == nil {
		return time.Time{}
	}
	if attachedAt := state.lastAttachAt.Load(); attachedAt != 0 {
		return time.Unix(0, attachedAt)
	}
	return time.Time{}
}
// clientConnTransportDetachSnapshot is an alias of transportDetachSnapshot.
func (c *LogicalConn) clientConnTransportDetachSnapshot() *clientConnTransportDetachState {
	return c.transportDetachSnapshot()
}

// clientConnTransportDetachKindSnapshot classifies the recorded detach reason.
// It returns "" when nothing is recorded.
func (c *LogicalConn) clientConnTransportDetachKindSnapshot() string {
	if detach := c.transportDetachSnapshot(); detach != nil {
		return classifyClientConnTransportDetachReason(detach.Reason)
	}
	return ""
}

// clientConnTransportDetachGenerationSnapshot returns the generation stored in
// the detach record. It falls back to the current generation when the record
// holds 0, and returns 0 when there is no record.
func (c *LogicalConn) clientConnTransportDetachGenerationSnapshot() uint64 {
	detach := c.transportDetachSnapshot()
	switch {
	case detach == nil:
		return 0
	case detach.Generation != 0:
		return detach.Generation
	default:
		return c.transportGenerationSnapshot()
	}
}
// clientConnTransportDetachExpirySnapshot returns when the current detach
// expires: the detach time plus the server's DetachedClientKeepSec seconds. The
// bool is false when there is no timestamped detach record, no server, or a
// non-positive keep duration.
func (c *LogicalConn) clientConnTransportDetachExpirySnapshot() (time.Time, bool) {
	detach := c.transportDetachSnapshot()
	if detach == nil || detach.At.IsZero() {
		return time.Time{}, false
	}
	server := c.Server()
	if server == nil {
		return time.Time{}, false
	}
	if keepSec := server.DetachedClientKeepSec(); keepSec > 0 {
		return detach.At.Add(time.Duration(keepSec) * time.Second), true
	}
	return time.Time{}, false
}
// clientConnTransportDetachExpiredSnapshot is an alias of transportDetachExpiredSnapshot.
func (c *LogicalConn) clientConnTransportDetachExpiredSnapshot(now time.Time) bool {
	return c.transportDetachExpiredSnapshot(now)
}

// clientConnTransportDetachRemainingSnapshot returns how long a logically
// detached peer has left before expiry at time now. It returns 0 when the peer
// is not detached, no expiry applies, or the expiry has already passed.
func (c *LogicalConn) clientConnTransportDetachRemainingSnapshot(now time.Time) time.Duration {
	if !c.clientConnLogicalTransportDetachedSnapshot() {
		return 0
	}
	if expiry, ok := c.clientConnTransportDetachExpirySnapshot(); ok && now.Before(expiry) {
		return expiry.Sub(now)
	}
	return 0
}

// clientConnReattachEligibleSnapshot is an alias of reattachEligibleSnapshot.
func (c *LogicalConn) clientConnReattachEligibleSnapshot(now time.Time) bool {
	return c.reattachEligibleSnapshot(now)
}

// clientConnRuntimeSnapshot is an alias of runtimeSnapshot.
func (c *LogicalConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot {
	return c.runtimeSnapshot()
}
// clientConnAttachmentStateSnapshot is an alias of attachmentStateSnapshot.
func (c *LogicalConn) clientConnAttachmentStateSnapshot() *clientConnAttachmentState {
return c.attachmentStateSnapshot()
}
// clientConnMaxReadTimeoutSnapshot returns the read timeout from the attachment state.
func (c *LogicalConn) clientConnMaxReadTimeoutSnapshot() time.Duration {
return c.attachmentStateSnapshot().maxReadTimeout
}
// clientConnMaxWriteTimeoutSnapshot is an alias of maxWriteTimeoutSnapshot.
func (c *LogicalConn) clientConnMaxWriteTimeoutSnapshot() time.Duration {
return c.maxWriteTimeoutSnapshot()
}
// clientConnMsgEnSnapshot is an alias of msgEnSnapshot.
func (c *LogicalConn) clientConnMsgEnSnapshot() func([]byte, []byte) []byte {
return c.msgEnSnapshot()
}
// clientConnMsgDeSnapshot is an alias of msgDeSnapshot.
func (c *LogicalConn) clientConnMsgDeSnapshot() func([]byte, []byte) []byte {
return c.msgDeSnapshot()
}
// clientConnFastStreamEncodeSnapshot is an alias of fastStreamEncodeSnapshot.
func (c *LogicalConn) clientConnFastStreamEncodeSnapshot() transportFastStreamEncoder {
return c.fastStreamEncodeSnapshot()
}
// clientConnFastBulkEncodeSnapshot is an alias of fastBulkEncodeSnapshot.
func (c *LogicalConn) clientConnFastBulkEncodeSnapshot() transportFastBulkEncoder {
return c.fastBulkEncodeSnapshot()
}
// clientConnFastPlainEncodeSnapshot is an alias of fastPlainEncodeSnapshot.
func (c *LogicalConn) clientConnFastPlainEncodeSnapshot() transportFastPlainEncoder {
return c.fastPlainEncodeSnapshot()
}
// clientConnHandshakeRsaKeySnapshot returns the handshake RSA key from the attachment state.
// NOTE(review): this is the stored slice, not a copy; callers must not modify it.
func (c *LogicalConn) clientConnHandshakeRsaKeySnapshot() []byte {
return c.attachmentStateSnapshot().handshakeRsaKey
}
// clientConnSecretKeySnapshot is an alias of secretKeySnapshot.
func (c *LogicalConn) clientConnSecretKeySnapshot() []byte {
return c.secretKeySnapshot()
}
// clientConnLastHeartbeatUnixSnapshot is an alias of lastHeartbeatUnixSnapshot.
func (c *LogicalConn) clientConnLastHeartbeatUnixSnapshot() int64 {
return c.lastHeartbeatUnixSnapshot()
}
// setClientConnID is a ClientConn-compatible alias of setID.
func (c *LogicalConn) setClientConnID(id string) {
c.setID(id)
}
// setClientConnRemoteAddr is a ClientConn-compatible alias of setRemoteAddr.
func (c *LogicalConn) setClientConnRemoteAddr(addr net.Addr) {
c.setRemoteAddr(addr)
}
// setClientConnLastHeartbeatUnix stores unix (seconds) as the last heartbeat
// time in the attachment state.
func (c *LogicalConn) setClientConnLastHeartbeatUnix(unix int64) {
c.updateAttachmentState(func(state *clientConnAttachmentState) {
state.lastHeartBeat = unix
})
}
// startClientConnSession is a ClientConn-compatible alias of startSession.
func (c *LogicalConn) startClientConnSession(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) {
return c.startSession(tuConn, stopCtx, stopFn)
}
// startClientConnSessionTransport is a ClientConn-compatible alias of startSessionTransport.
func (c *LogicalConn) startClientConnSessionTransport(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) {
return c.startSessionTransport(tuConn, stopCtx, stopFn)
}
// attachClientConnSessionTransport is a ClientConn-compatible alias of attachSessionTransport.
func (c *LogicalConn) attachClientConnSessionTransport(tuConn net.Conn) error {
return c.attachSessionTransport(tuConn)
}
// detachClientConnTransportForTransfer is a ClientConn-compatible alias of detachTransportForTransfer.
func (c *LogicalConn) detachClientConnTransportForTransfer() (net.Conn, error) {
return c.detachTransportForTransfer()
}
// applyClientConnAttachmentProfile sets timeouts, codecs and key material. When
// a legacy client is bound, the call is delegated to it. Otherwise the
// attachment state is updated directly. The fast encoders are left untouched in
// that case.
func (c *LogicalConn) applyClientConnAttachmentProfile(maxReadTimeout time.Duration, maxWriteTimeout time.Duration, msgEn func([]byte, []byte) []byte, msgDe func([]byte, []byte) []byte, handshakeRsaKey []byte, secretKey []byte) {
	legacy := c.compatClientConn()
	if legacy != nil {
		legacy.applyClientConnAttachmentProfile(maxReadTimeout, maxWriteTimeout, msgEn, msgDe, handshakeRsaKey, secretKey)
		return
	}
	c.updateAttachmentState(func(attachment *clientConnAttachmentState) {
		attachment.maxReadTimeout, attachment.maxWriteTimeout = maxReadTimeout, maxWriteTimeout
		attachment.msgEn, attachment.msgDe = msgEn, msgDe
		attachment.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey)
		attachment.secretKey = cloneClientConnAttachmentBytes(secretKey)
	})
}
// inheritClientConnAttachmentProfile copies src's attachment profile. It
// delegates to the bound legacy client when present; otherwise it installs
// src's attachment snapshot directly. A nil src is ignored.
func (c *LogicalConn) inheritClientConnAttachmentProfile(src *ClientConn) {
	if src == nil {
		return
	}
	legacy := c.compatClientConn()
	if legacy == nil {
		c.setAttachmentState(src.clientConnAttachmentStateSnapshot())
		return
	}
	legacy.inheritClientConnAttachmentProfile(src)
}
// sessionRuntimeSnapshot returns the current session runtime from the runtime
// state, or nil when no runtime state is available.
func (c *LogicalConn) sessionRuntimeSnapshot() *clientConnSessionRuntime {
	if runtimeState := c.logicalRuntimeStateSnapshot(); runtimeState != nil {
		return runtimeState.sessionRuntimeSnapshot()
	}
	return nil
}
// setSessionRuntime installs rt as the current session runtime. Nil receiver
// or nil rt is a no-op.
//
// rt is modified in place before being stored. If it has a tuConn but no
// binding, a binding is created. It is then passed through the
// normalize/ensure helpers defined elsewhere. It is stored in the logical
// runtime state and mirrored onto the legacy client; without runtime state it
// is stored on the client directly. After the new runtime is visible, the
// previous binding (when different from rt's) has its background workers
// stopped.
func (c *LogicalConn) setSessionRuntime(rt *clientConnSessionRuntime) {
if c == nil || rt == nil {
return
}
// Remember the outgoing binding so its workers can be stopped after the swap.
var oldBinding *transportBinding
if prev := c.sessionRuntimeSnapshot(); prev != nil && prev.transport != nil && prev.transport != rt.transport {
oldBinding = prev.transport
}
if rt.transport == nil && rt.tuConn != nil {
rt.transport = newTransportBinding(rt.tuConn, nil)
}
normalizeClientConnSessionRuntimeTransportState(rt)
ensureClientConnSessionRuntimeTransportLifecycle(rt)
ensureClientConnSessionRuntimeTransportDone(rt)
state := c.logicalRuntimeStateSnapshot()
if state == nil {
// NOTE(review): this path returns without stopping oldBinding. prev was read
// through the same runtime state, so oldBinding should already be nil here — confirm.
client := c.compatClientConn()
if client != nil {
client.sessionRuntime.Store(rt)
}
return
}
state.setSessionRuntime(rt)
client := c.compatClientConn()
if client != nil {
client.syncLegacySessionRuntimeFromState(state)
}
if oldBinding != nil {
oldBinding.stopBackgroundWorkers()
}
}
// clearSessionRuntimeTransport drops the transport from the current session
// runtime and keeps the rest of the session.
//
// It first calls the transport stop func, if any. It then installs a copy of
// the runtime with every transport field zeroed. setSessionRuntime stops the
// old binding's background workers. It does nothing without a runtime.
func (c *LogicalConn) clearSessionRuntimeTransport() {
	if c == nil {
		return
	}
	current := c.sessionRuntimeSnapshot()
	if current == nil {
		return
	}
	if stop := current.transportStopFn; stop != nil {
		stop()
	}
	cleared := *current
	cleared.transport, cleared.tuConn = nil, nil
	cleared.transportAttached = false
	cleared.transportGeneration = 0
	cleared.transportStopCtx, cleared.transportStopFn, cleared.transportDone = nil, nil, nil
	c.setSessionRuntime(&cleared)
}
// transportSnapshot returns the session's live net.Conn. It prefers the
// binding's conn and falls back to the raw tuConn. It returns nil without a
// runtime.
func (c *LogicalConn) transportSnapshot() net.Conn {
	rt := c.sessionRuntimeSnapshot()
	switch {
	case rt == nil:
		return nil
	case rt.transport != nil:
		return rt.transport.connSnapshot()
	default:
		return rt.tuConn
	}
}

// stopContextSnapshot returns the session stop context, or nil without a runtime.
func (c *LogicalConn) stopContextSnapshot() context.Context {
	if rt := c.sessionRuntimeSnapshot(); rt != nil {
		return rt.stopCtx
	}
	return nil
}

// stopFuncSnapshot returns the session cancel func, or nil without a runtime.
func (c *LogicalConn) stopFuncSnapshot() context.CancelFunc {
	if rt := c.sessionRuntimeSnapshot(); rt != nil {
		return rt.stopFn
	}
	return nil
}

// transportStopContextSnapshot returns the transport-level stop context. It
// falls back to the session stop context and returns nil without a runtime.
func (c *LogicalConn) transportStopContextSnapshot() context.Context {
	rt := c.sessionRuntimeSnapshot()
	if rt == nil {
		return nil
	}
	if transportCtx := rt.transportStopCtx; transportCtx != nil {
		return transportCtx
	}
	return rt.stopCtx
}
// closeTransport closes the session's live conn, if any, and then stops the
// current binding's background workers. The close error is ignored: this is a
// best-effort teardown.
func (c *LogicalConn) closeTransport() {
	var binding *transportBinding
	if rt := c.sessionRuntimeSnapshot(); rt != nil {
		binding = rt.transport
	}
	if conn := c.transportSnapshot(); conn != nil {
		_ = conn.Close()
	}
	if binding != nil {
		binding.stopBackgroundWorkers()
	}
}
// startSession installs a fresh session runtime around tuConn and marks the
// session started. It returns the session's stop context and cancel func.
//
// When stopCtx or stopFn is missing, a new cancellable context is created. If
// the caller supplied a stopCtx, the new context is derived from it, so
// cancelling the caller's context still stops the session. Previously the
// caller's stopCtx was discarded in favour of context.Background, silently
// dropping its cancellation.
//
// A non-nil tuConn also sets the remote address (if unknown), marks the
// transport as stream-based, bumps the transport generation, and clears any
// recorded detach state. A nil receiver returns the inputs unchanged.
func (c *LogicalConn) startSession(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) {
	if c == nil {
		return stopCtx, stopFn
	}
	if stopCtx == nil || stopFn == nil {
		parent := stopCtx
		if parent == nil {
			parent = context.Background()
		}
		stopCtx, stopFn = context.WithCancel(parent)
	}
	if c.RemoteAddr() == nil && tuConn != nil {
		c.setRemoteAddr(tuConn.RemoteAddr())
	}
	transportGeneration := uint64(0)
	if tuConn != nil {
		c.markStreamTransport()
		transportGeneration = c.markTransportAttached()
		c.clearTransportDetachState()
	}
	c.setSessionRuntime(&clientConnSessionRuntime{
		transport:           newTransportBinding(tuConn, nil),
		transportAttached:   tuConn != nil,
		transportGeneration: transportGeneration,
		tuConn:              tuConn,
		stopCtx:             stopCtx,
		stopFn:              stopFn,
	})
	c.markSessionStarted()
	return stopCtx, stopFn
}
// startSessionTransport starts a session (see startSession). If a session
// runtime is then available, it launches readTUMessageLoop on it in a new
// goroutine. A nil receiver returns the inputs unchanged.
func (c *LogicalConn) startSessionTransport(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) {
	if c == nil {
		return stopCtx, stopFn
	}
	sessionCtx, sessionCancel := c.startSession(tuConn, stopCtx, stopFn)
	if rt := c.sessionRuntimeSnapshot(); rt != nil {
		go c.readTUMessageLoop(rt)
	}
	return sessionCtx, sessionCancel
}
// attachSessionTransport swaps tuConn in as the session's transport and starts
// a new read loop on it.
//
// Sequence:
//  1. Call the current transport's stop func.
//  2. Install a copy of the runtime with a new binding and a bumped generation.
//     setSessionRuntime stops the old binding's background workers.
//  3. Refresh the remote address, mark the transport stream-based, and clear
//     any detach record.
//  4. Close the previous conn if it differs from tuConn.
//  5. Launch readTUMessageLoop on the newly installed runtime.
//
// It returns errLogicalConnClientNil for a nil receiver, and an error when
// tuConn or the current session runtime is nil.
func (c *LogicalConn) attachSessionTransport(tuConn net.Conn) error {
if c == nil {
return errLogicalConnClientNil
}
if tuConn == nil {
return errors.New("conn is nil")
}
rt := c.sessionRuntimeSnapshot()
if rt == nil {
return errors.New("client conn session runtime is nil")
}
oldBinding := rt.transport
if rt.transportStopFn != nil {
rt.transportStopFn()
}
next := *rt
next.transport = newTransportBinding(tuConn, nil)
next.transportAttached = true
next.transportGeneration = c.markTransportAttached()
next.tuConn = tuConn
next.transportStopCtx = nil
next.transportStopFn = nil
next.transportDone = nil
c.setSessionRuntime(&next)
if tuConn.RemoteAddr() != nil {
c.setRemoteAddr(tuConn.RemoteAddr())
}
c.markStreamTransport()
c.clearTransportDetachState()
// Close the replaced conn; skip it when the caller re-attached the same conn.
if oldBinding != nil {
if oldConn := oldBinding.connSnapshot(); oldConn != nil && oldConn != tuConn {
_ = oldConn.Close()
}
}
// Start the read loop on the runtime as actually stored (after normalization).
attached := c.sessionRuntimeSnapshot()
if attached == nil {
return nil
}
go c.readTUMessageLoop(attached)
return nil
}
// detachTransportForTransfer removes the current transport from the session
// and returns its net.Conn to the caller. The conn is nil if none was attached.
//
// The runtime is replaced first with a copy whose transport fields are
// cleared. Then the old transport's stop func is called, and the read deadline
// is set to now so that a blocked Read returns with a timeout. If the old
// runtime has a transportDone channel, it waits up to one second for it
// (presumably signalled when the read loop exits — confirm). On timeout the
// conn is closed and an error is returned. On success the read deadline is
// reset to none before the conn is returned.
//
// It returns errLogicalConnClientNil for a nil receiver, and an error when no
// session runtime exists.
func (c *LogicalConn) detachTransportForTransfer() (net.Conn, error) {
if c == nil {
return nil, errLogicalConnClientNil
}
rt := c.sessionRuntimeSnapshot()
if rt == nil {
return nil, errors.New("client conn session runtime is nil")
}
// Prefer the binding's conn over the raw tuConn.
conn := rt.tuConn
if rt.transport != nil && rt.transport.connSnapshot() != nil {
conn = rt.transport.connSnapshot()
}
next := *rt
next.transport = nil
next.transportAttached = false
next.transportGeneration = 0
next.tuConn = nil
next.transportStopCtx = nil
next.transportStopFn = nil
next.transportDone = nil
c.setSessionRuntime(&next)
if rt.transportStopFn != nil {
rt.transportStopFn()
}
// An immediate deadline makes any in-flight Read on conn fail with a timeout.
if conn != nil {
_ = conn.SetReadDeadline(time.Now())
}
if rt.transportDone != nil {
select {
case <-rt.transportDone:
case <-time.After(time.Second):
if conn != nil {
_ = conn.Close()
}
return nil, errors.New("timed out waiting for transport handoff")
}
}
// Clear the deadline so the new owner starts with a usable conn.
if conn != nil {
_ = conn.SetReadDeadline(time.Time{})
}
return conn, nil
}
// CurrentTransportConn returns a *TransportConn describing the connection's
// current transport, as produced by currentTransportConnSnapshot.
func (c *LogicalConn) CurrentTransportConn() *TransportConn {
return c.currentTransportConnSnapshot()
}
// transportConnSnapshotForInbound builds a *TransportConn for an inbound
// message. The message arrived on conn (possibly nil) at the given transport
// generation.
//
// remoteAddr defaults to conn.RemoteAddr(), then to c.RemoteAddr(). The result
// is nil for a nil receiver, or when no address can be found and
// hasRuntimeConn is false.
//
// The result is marked attached only if it still matches the live transport:
//   - with a conn: the current binding wraps that same conn, the transport is
//     attached, and generation equals the current generation;
//   - without a conn: CurrentTransportConn() exists, has the same generation,
//     and has the same remote-address string. Its Attached() value is then
//     used, and so is its HasRuntimeConn() when the caller passed false.
func (c *LogicalConn) transportConnSnapshotForInbound(conn net.Conn, remoteAddr net.Addr, generation uint64, hasRuntimeConn bool) *TransportConn {
if c == nil {
return nil
}
if remoteAddr == nil {
if conn != nil {
remoteAddr = conn.RemoteAddr()
}
if remoteAddr == nil {
remoteAddr = c.RemoteAddr()
}
}
if remoteAddr == nil && !hasRuntimeConn {
return nil
}
attached := false
currentGeneration := c.transportGenerationSnapshot()
if conn != nil {
binding := c.transportBindingSnapshot()
if binding != nil && binding.connSnapshot() == conn && c.transportAttachedSnapshot() && currentGeneration == generation {
attached = true
}
} else {
// No conn to compare: match against the current transport by generation and address.
current := c.CurrentTransportConn()
if current != nil && currentGeneration == generation && transportConnAddrString(current.RemoteAddr()) == transportConnAddrString(remoteAddr) {
attached = current.Attached()
if !hasRuntimeConn {
hasRuntimeConn = current.HasRuntimeConn()
}
}
}
return &TransportConn{
logical: c,
generation: generation,
remoteAddr: remoteAddr,
attached: attached,
hasRuntimeConn: hasRuntimeConn,
}
}