package notify

import (
	"fmt"
	"net"
	"time"

	"b612.me/stario"
)

// serverInboundSource records where an inbound payload came from so it can
// later be resolved to a logical connection and its transport snapshot.
type serverInboundSource struct {
	// Source is a textual key for the origin; newServerInboundSource picks the
	// first non-empty candidate (see its doc comment for the order).
	Source string
	// Logical is the logical connection the payload arrived on, if known.
	Logical *LogicalConn
	// Conn is the runtime network connection the payload arrived on, if any.
	Conn net.Conn
	// RemoteAddr is the peer address; newServerInboundSource fills it from
	// Conn.RemoteAddr() when the caller passes nil.
	RemoteAddr net.Addr
	// TransportGeneration is forwarded to transportConnSnapshotForInbound.
	TransportGeneration uint64
	// HasRuntimeConn is set to (Conn != nil) by newServerInboundSource.
	HasRuntimeConn bool
}

// newServerInboundSource builds a serverInboundSource for a payload received
// on logical and/or conn with the given transport generation.
//
// If remoteAddr is nil it falls back to conn.RemoteAddr(). Source is the first
// non-empty value among: conn's remote address, logical.ID(), remoteAddr, and
// logical.RemoteAddr().
func newServerInboundSource(logical *LogicalConn, conn net.Conn, remoteAddr net.Addr, generation uint64) serverInboundSource {
	var connAddr net.Addr
	if conn != nil {
		connAddr = conn.RemoteAddr()
	}
	if remoteAddr == nil {
		remoteAddr = connAddr
	}

	source := ""
	if connAddr != nil {
		source = connAddr.String()
	}
	if source == "" && logical != nil {
		source = logical.ID()
	}
	if source == "" && remoteAddr != nil {
		source = remoteAddr.String()
	}
	if source == "" && logical != nil {
		if addr := logical.RemoteAddr(); addr != nil {
			source = addr.String()
		}
	}

	return serverInboundSource{
		Source:              source,
		Logical:             logical,
		Conn:                conn,
		RemoteAddr:          remoteAddr,
		TransportGeneration: generation,
		HasRuntimeConn:      conn != nil,
	}
}

// pushMessageSource feeds raw inbound bytes attributed to source into the
// server's queue. It first tries the dispatcher-backed fast path and, when that
// is unavailable, falls back to queue.ParseMessage. Empty data or a missing
// queue is ignored.
func (s *ServerCommon) pushMessageSource(data []byte, source interface{}) {
	queue := s.serverQueueSnapshot()
	if queue == nil || len(data) == 0 {
		return
	}
	if s.pushMessageSourceFast(queue, data, source) {
		return
	}
	// Report fallback parse failures the same way the other inbound paths do
	// rather than discarding them silently.
	if err := queue.ParseMessage(data, source); err != nil && (s.showError || s.debugMode) {
		fmt.Println("server parse inbound message error", err)
	}
}

// pushMessageSourceFast parses data into frames with queue.ParseMessageView and
// hands each frame payload to pushTransportPayloadSourceFast.
//
// It returns false, without touching data, when queue or the inbound
// dispatcher is nil or data is empty, so the caller can take the slow path.
// Otherwise it returns true, even if parsing reported an error (which is only
// printed when showError or debugMode is set).
func (s *ServerCommon) pushMessageSourceFast(queue *stario.StarQueue, data []byte, source interface{}) bool {
	dispatcher := s.serverInboundDispatcherSnapshot()
	if queue == nil || dispatcher == nil || len(data) == 0 {
		return false
	}
	err := queue.ParseMessageView(data, source, func(frame stario.FrameView) error {
		// Each frame is attributed to frame.Conn as reported by the queue
		// (presumably the source passed above — confirm in stario).
		s.pushTransportPayloadSourceFast(frame.Payload, nil, frame.Conn)
		return nil
	})
	if err != nil && (s.showError || s.debugMode) {
		fmt.Println("server parse inbound frame error", err)
	}
	return true
}

// pushTransportPayloadSourceFast processes one transport payload from source.
//
// release, if non-nil, is the caller's hook for giving payload back. This
// function calls it directly on its early-exit paths and right after
// queue.BuildMessage in the no-dispatcher fallback; on the decrypt path it is
// handed to decryptTransportPayloadLogicalPooled (assumed to take ownership —
// TODO confirm it is released on every path there).
//
// It returns false only when payload is empty, or when there is neither a
// dispatcher nor a queue; in every other case the payload is considered
// consumed (possibly dropped) and true is returned.
func (s *ServerCommon) pushTransportPayloadSourceFast(payload []byte, release func(), source interface{}) bool {
	dispatcher := s.serverInboundDispatcherSnapshot()
	if len(payload) == 0 {
		if release != nil {
			release()
		}
		return false
	}

	if dispatcher == nil {
		// No dispatcher: re-frame the payload and push it through the queue.
		queue := s.serverQueueSnapshot()
		if queue == nil {
			if release != nil {
				release()
			}
			return false
		}
		frame := queue.BuildMessage(payload)
		if release != nil {
			release()
		}
		if err := queue.ParseMessage(frame, source); err != nil && (s.showError || s.debugMode) {
			fmt.Println("server enqueue inbound frame error", err)
		}
		return true
	}

	logical, transport := s.resolveInboundSource(source)
	if logical == nil {
		// Unknown origin: drop the payload.
		if release != nil {
			release()
		}
		return true
	}

	plain, plainRelease, err := s.decryptTransportPayloadLogicalPooled(logical, payload, release)
	if err != nil {
		if s.showError || s.debugMode {
			fmt.Println("server decode transport payload error", err)
		}
		return true
	}

	inboundConn := serverInboundConn(source)
	if s.tryDispatchBorrowedTransportPlain(logical, transport, inboundConn, plain, plainRelease) {
		return true
	}

	// plain may be backed by a pooled buffer (signalled by a non-nil
	// plainRelease); copy it before returning the buffer so the asynchronous
	// dispatch below owns its data.
	owned := plain
	if plainRelease != nil {
		owned = append([]byte(nil), plain...)
		plainRelease()
	}

	s.wg.Add(1)
	dispatched := dispatcher.Dispatch(serverInboundDispatchSource(source), func() {
		defer s.wg.Done()
		now := time.Now()
		if err := s.dispatchInboundTransportPlain(logical, transport, inboundConn, owned, now); err != nil && (s.showError || s.debugMode) {
			fmt.Println("server decode envelope error", err)
		}
	})
	if !dispatched {
		// The task will never run, so balance the Add above; the payload is
		// dropped.
		s.wg.Done()
	}
	return true
}

// serverInboundConn extracts the runtime net.Conn from an inbound source value
// (a net.Conn, a serverInboundSource, or a *serverInboundSource). It returns
// nil for any other type or a nil *serverInboundSource.
func serverInboundConn(source interface{}) net.Conn {
	switch data := source.(type) {
	case net.Conn:
		return data
	case serverInboundSource:
		return data.Conn
	case *serverInboundSource:
		if data != nil {
			return data.Conn
		}
	}
	return nil
}

// resolveInboundSource maps an inbound source value to its logical connection
// and transport snapshot. serverInboundSource values (by value or pointer) go
// through resolveInboundSourceValue; a string is looked up with
// resolveLogicalBySource and yields no transport. Any other type, or a nil
// pointer, resolves to (nil, nil).
func (s *ServerCommon) resolveInboundSource(source interface{}) (*LogicalConn, *TransportConn) {
	switch data := source.(type) {
	case serverInboundSource:
		return s.resolveInboundSourceValue(data)
	case *serverInboundSource:
		if data == nil {
			return nil, nil
		}
		return s.resolveInboundSourceValue(*data)
	case string:
		return s.resolveLogicalBySource(data), nil
	default:
		return nil, nil
	}
}

// resolveInboundSourceValue resolves a serverInboundSource to a logical
// connection and the transport snapshot for it.
//
//   - If source.Logical is nil, the logical connection is looked up by
//     source.Source; (nil, nil) is returned when that lookup fails.
//   - If source.Logical is set and the source carried a runtime conn whose
//     transport snapshot is missing or detached, a logical connection
//     re-resolved by source.Source is preferred; if none exists and the
//     original logical is not alive, (nil, nil) is returned.
//
// The returned transport is taken from the chosen logical connection.
func (s *ServerCommon) resolveInboundSourceValue(source serverInboundSource) (*LogicalConn, *TransportConn) {
	logical := source.Logical
	if logical == nil {
		logical = s.resolveLogicalBySource(source.Source)
		if logical == nil {
			// The lookup can fail (resolveInboundSource's callers treat a nil
			// logical as "drop"); never call a method on a nil *LogicalConn.
			return nil, nil
		}
	} else if source.HasRuntimeConn {
		transport := logical.transportConnSnapshotForInbound(source.Conn, source.RemoteAddr, source.TransportGeneration, source.HasRuntimeConn)
		if transport == nil || !transport.Attached() {
			if rebound := s.resolveLogicalBySource(source.Source); rebound != nil {
				logical = rebound
			} else if !logical.Status().Alive {
				return nil, nil
			}
		}
	}
	transport := logical.transportConnSnapshotForInbound(source.Conn, source.RemoteAddr, source.TransportGeneration, source.HasRuntimeConn)
	return logical, transport
}