package notify

import (
	"context"
	"errors"
	"net"
	"os"
	"time"
)

// serverLogicalTransportDetacher is an optional server capability. When a
// transport read fails, the read handlers below use it to detach the logical
// session from its transport instead of stopping the session outright.
type serverLogicalTransportDetacher interface {
	detachLogicalSessionTransport(logical *LogicalConn, reason string, err error)
}

// serverInboundSourcePusher is an optional server capability. The push helpers
// below prefer it over pushMessage so inbound bytes are delivered together with
// a source value built by newServerInboundSource.
type serverInboundSourcePusher interface {
	pushMessageSource([]byte, interface{})
}

// readTUMessage runs the transport read loop for the session runtime that is
// current at call time. It returns immediately when there is no runtime.
func (c *LogicalConn) readTUMessage() {
	rt := c.clientConnSessionRuntimeSnapshot()
	if rt == nil {
		return
	}
	c.readTUMessageLoop(rt)
}

// readTUMessageLoop repeatedly reads from rt.tuConn and hands each result to
// handleTUTransportReadResultWithSession. It stops when that handler returns
// false or when the stop context is done. On stop, the connection is closed only
// if it is still the session's current transport.
//
// The stop context is rt.transportStopCtx, falling back to rt.stopCtx. With
// neither set the loop returns without reading.
//
// NOTE(review): closeClientConnSessionRuntimeTransportDone is deferred only
// after the stop-context checks. A runtime without any stop context therefore
// never reaches it — confirm that is intended.
//
// NOTE(review): a single buffer is reused for every Read, so slices pushed to
// the server alias it. This assumes the server consumes or copies pushed bytes
// before the next Read — verify for pushMessage/pushMessageSource.
func (c *LogicalConn) readTUMessageLoop(rt *clientConnSessionRuntime) {
	if rt == nil {
		return
	}
	stopCtx := rt.transportStopCtx
	if stopCtx == nil {
		stopCtx = rt.stopCtx
	}
	if stopCtx == nil {
		return
	}
	conn := rt.tuConn
	generation := rt.transportGeneration
	defer closeClientConnSessionRuntimeTransportDone(rt)

	buf := streamReadBuffer()
	for {
		// Non-blocking stop check before every read.
		select {
		case <-sessionStopChan(stopCtx):
			if c.shouldCloseTransportOnStop(conn) {
				_ = conn.Close()
			}
			return
		default:
		}
		num, data, err := c.readFromTUTransportConnWithBuffer(conn, buf)
		if !c.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err) {
			return
		}
	}
}

// readFromTUTransportConnWithBuffer performs a single Read from conn into data.
//
// A nil conn yields (0, nil, net.ErrClosed). An empty data is replaced with a
// fresh buffer from streamReadBuffer. When the max read timeout is positive, a
// read deadline that far in the future is set first; a SetReadDeadline failure
// is ignored (best-effort). The buffer actually read into is returned so the
// caller can slice it with the returned byte count.
func (c *LogicalConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byte) (int, []byte, error) {
	if conn == nil {
		return 0, nil, net.ErrClosed
	}
	if len(data) == 0 {
		data = streamReadBuffer()
	}
	if timeout := c.clientConnMaxReadTimeoutSnapshot(); timeout > 0 {
		_ = conn.SetReadDeadline(time.Now().Add(timeout))
	}
	num, err := conn.Read(data)
	return num, data, err
}

// handleTUTransportReadResultWithSession processes the outcome of one transport
// Read and reports whether the read loop should keep going.
//
//   - Read-deadline expiry is not fatal: any bytes read are pushed and the loop
//     continues.
//   - Any other error ends the loop:
//     1. If stopCtx is already done, conn is closed when it is still the
//        current transport.
//     2. Otherwise, if the server implements serverLogicalTransportDetacher and
//        the logical peer should be preserved, the transport is detached.
//     3. Otherwise the session is stopped with reason "read error".
//   - On success the bytes read are pushed and the loop continues.
func (c *LogicalConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool {
	// net.Conn implementations wrap os.ErrDeadlineExceeded (TCP returns a
	// *net.OpError), so it must be matched with errors.Is, not ==.
	if errors.Is(err, os.ErrDeadlineExceeded) {
		if num != 0 {
			c.pushServerOwnedTransportMessage(data[:num], conn, generation)
		}
		return true
	}
	if err != nil {
		select {
		case <-sessionStopChan(stopCtx):
			if c.shouldCloseTransportOnStop(conn) {
				_ = conn.Close()
			}
			return false
		default:
		}
		if detacher, ok := c.Server().(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() {
			detacher.detachLogicalSessionTransport(c, "read error", err)
			return false
		}
		c.stopServerOwnedSession("read error", err)
		return false
	}
	c.pushServerOwnedTransportMessage(data[:num], conn, generation)
	return true
}

// pushServerOwnedTransportMessage forwards inbound bytes to the owning server.
//
// It does nothing for a nil receiver, empty data, or a nil server. Servers that
// implement serverInboundSourcePusher also receive a source value describing
// this connection, conn and generation. Other servers get the bytes keyed by
// the client-connection ID.
func (c *LogicalConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) {
	if c == nil || len(data) == 0 {
		return
	}
	server := c.Server()
	if server == nil {
		return
	}
	if pusher, ok := server.(serverInboundSourcePusher); ok {
		pusher.pushMessageSource(data, newServerInboundSource(c, conn, nil, generation))
		return
	}
	server.pushMessage(data, c.clientConnIDSnapshot())
}

// shouldCloseTransportOnStop reports whether conn is still the session's
// attached transport, and therefore owned by the stopping read loop.
//
// The current transport is the binding's connection when it has one, else the
// runtime's tuConn. It returns false for a nil receiver, a nil conn, a missing
// runtime, or a detached transport.
func (c *LogicalConn) shouldCloseTransportOnStop(conn net.Conn) bool {
	if c == nil || conn == nil {
		return false
	}
	rt := c.clientConnSessionRuntimeSnapshot()
	if rt == nil || !rt.transportAttached {
		return false
	}
	current := rt.tuConn
	if rt.transport != nil {
		if bound := rt.transport.connSnapshot(); bound != nil {
			current = bound
		}
	}
	return current == conn
}

// readFromTUTransport performs one Read from the connection of the current
// transport binding. It returns net.ErrClosed when there is no binding.
func (c *ClientConn) readFromTUTransport() (int, []byte, error) {
	binding := c.clientConnTransportBindingSnapshot()
	if binding == nil {
		return 0, nil, net.ErrClosed
	}
	conn := binding.connSnapshot()
	return c.readFromTUTransportConn(conn)
}

// readFromTUTransportConn performs one Read from conn into a fresh stream
// buffer.
func (c *ClientConn) readFromTUTransportConn(conn net.Conn) (int, []byte, error) {
	return c.readFromTUTransportConnWithBuffer(conn, streamReadBuffer())
}

// readFromTUTransportConnWithBuffer performs one Read from conn into data.
//
// When a LogicalConn is attached, the call is delegated to it. Otherwise the
// behaviour matches LogicalConn.readFromTUTransportConnWithBuffer:
//   - a nil conn yields net.ErrClosed;
//   - an empty data is replaced with a fresh buffer;
//   - a positive max read timeout sets a best-effort read deadline.
func (c *ClientConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byte) (int, []byte, error) {
	if logical := c.LogicalConn(); logical != nil {
		return logical.readFromTUTransportConnWithBuffer(conn, data)
	}
	if conn == nil {
		return 0, nil, net.ErrClosed
	}
	if len(data) == 0 {
		data = streamReadBuffer()
	}
	if timeout := c.clientConnMaxReadTimeoutSnapshot(); timeout > 0 {
		_ = conn.SetReadDeadline(time.Now().Add(timeout))
	}
	num, err := conn.Read(data)
	return num, data, err
}

// handleTUTransportReadResult processes one Read result. It uses the
// connection's transport stop-context, transport and generation snapshots and
// reports whether reading should continue.
func (c *ClientConn) handleTUTransportReadResult(num int, data []byte, err error) bool {
	return c.handleTUTransportReadResultWithSession(c.clientConnTransportStopContextSnapshot(), c.clientConnTransportSnapshot(), c.clientConnTransportGenerationSnapshot(), num, data, err)
}

// handleTUTransportReadResultWithSession processes one Read result and reports
// whether reading should continue.
//
// When a LogicalConn is attached, the call is delegated to it. Otherwise it
// mirrors LogicalConn.handleTUTransportReadResultWithSession:
//   - deadline expiry is non-fatal;
//   - other errors close, detach, or stop the session;
//   - successful reads are pushed to the server.
func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool {
	if logical := c.LogicalConn(); logical != nil {
		return logical.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err)
	}
	// Wrapped deadline errors (e.g. *net.OpError) need errors.Is, not ==.
	if errors.Is(err, os.ErrDeadlineExceeded) {
		if num != 0 {
			c.pushServerOwnedTransportMessage(data[:num], conn, generation)
		}
		return true
	}
	if err != nil {
		select {
		case <-sessionStopChan(stopCtx):
			if c.shouldCloseClientConnTransportOnStop(conn) {
				_ = conn.Close()
			}
			return false
		default:
		}
		if detacher, ok := c.server.(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() {
			detacher.detachLogicalSessionTransport(logicalConnFromClient(c), "read error", err)
			return false
		}
		c.stopServerOwnedSession("read error", err)
		return false
	}
	c.pushServerOwnedTransportMessage(data[:num], conn, generation)
	return true
}

// pushServerOwnedTransportMessage forwards inbound bytes to the owning server.
//
// When a LogicalConn is attached, the call is delegated to it. It does nothing
// for a nil receiver, a nil server, or empty data. Servers implementing
// serverInboundSourcePusher also receive a source value; others get the bytes
// keyed by the client-connection ID.
func (c *ClientConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) {
	// Guard the receiver before any method call on it.
	if c == nil {
		return
	}
	if logical := c.LogicalConn(); logical != nil {
		logical.pushServerOwnedTransportMessage(data, conn, generation)
		return
	}
	if c.server == nil || len(data) == 0 {
		return
	}
	if pusher, ok := c.server.(serverInboundSourcePusher); ok {
		pusher.pushMessageSource(data, newServerInboundSource(logicalConnFromClient(c), conn, nil, generation))
		return
	}
	c.server.pushMessage(data, c.clientConnIDSnapshot())
}

// shouldCloseClientConnTransportOnStop reports whether conn is still this
// connection's attached transport.
//
// When a LogicalConn is attached, the call is delegated to it. It returns false
// for a nil receiver, a nil conn, a missing runtime, or a detached transport.
func (c *ClientConn) shouldCloseClientConnTransportOnStop(conn net.Conn) bool {
	// Guard the receiver before any method call on it. A nil conn is also
	// rejected by the LogicalConn path, so checking it first is equivalent.
	if c == nil || conn == nil {
		return false
	}
	if logical := c.LogicalConn(); logical != nil {
		return logical.shouldCloseTransportOnStop(conn)
	}
	rt := c.clientConnSessionRuntimeSnapshot()
	if rt == nil || !rt.transportAttached {
		return false
	}
	current := rt.tuConn
	if rt.transport != nil {
		if bound := rt.transport.connSnapshot(); bound != nil {
			current = bound
		}
	}
	return current == conn
}