notify/server_inbound_source.go

179 lines
4.8 KiB
Go
Raw Permalink Normal View History

package notify
import (
"b612.me/stario"
"fmt"
"net"
"time"
)
// serverInboundSource describes where an inbound server-side payload came
// from. Values of this type are passed as the opaque `source interface{}`
// through the queue / dispatch path (see serverInboundConn and
// resolveInboundSource), and are mapped back to a logical connection and
// transport snapshot by resolveInboundSourceValue.
type serverInboundSource struct {
	// Source is the textual origin key, chosen by newServerInboundSource as the
	// first non-empty of: Conn's remote address, the logical connection ID,
	// RemoteAddr, and the logical connection's remote address. It is used to
	// look the logical connection up again via resolveLogicalBySource.
	Source string
	// Logical is the logical connection the payload is attributed to; may be nil.
	Logical *LogicalConn
	// Conn is the runtime network connection that delivered the payload; may be nil.
	Conn net.Conn
	// RemoteAddr is the peer address; newServerInboundSource defaults it to
	// Conn.RemoteAddr() when the caller passes nil.
	RemoteAddr net.Addr
	// TransportGeneration is forwarded to transportConnSnapshotForInbound.
	// NOTE(review): presumably identifies which transport attachment the
	// payload arrived on — confirm against LogicalConn.
	TransportGeneration uint64
	// HasRuntimeConn records whether Conn was non-nil at construction time.
	HasRuntimeConn bool
}
// newServerInboundSource builds the source descriptor for a payload delivered
// by conn on behalf of logical, tagged with the given transport generation.
//
// When remoteAddr is nil it falls back to conn.RemoteAddr() (if conn is
// non-nil). The Source key is the first non-empty value among: conn's remote
// address, logical.ID(), remoteAddr, and logical.RemoteAddr().
func newServerInboundSource(logical *LogicalConn, conn net.Conn, remoteAddr net.Addr, generation uint64) serverInboundSource {
	var connAddr net.Addr
	if conn != nil {
		connAddr = conn.RemoteAddr()
	}
	if remoteAddr == nil {
		remoteAddr = connAddr
	}

	var key string
	if connAddr != nil {
		key = connAddr.String()
	}
	if key == "" {
		switch {
		case logical != nil && logical.ID() != "":
			key = logical.ID()
		case remoteAddr != nil && remoteAddr.String() != "":
			key = remoteAddr.String()
		case logical != nil && logical.RemoteAddr() != nil:
			key = logical.RemoteAddr().String()
		}
	}

	return serverInboundSource{
		Source:              key,
		Logical:             logical,
		Conn:                conn,
		RemoteAddr:          remoteAddr,
		TransportGeneration: generation,
		HasRuntimeConn:      conn != nil,
	}
}
// pushMessageSource feeds raw inbound bytes received from source into the
// server's stario queue.
//
// It does nothing when there is no queue or data is empty. It first tries
// pushMessageSourceFast (frame-by-frame delivery through the inbound
// dispatcher). If that path is unavailable, it falls back to
// queue.ParseMessage. A parse error on the fallback path is reported the same
// way as on the other inbound paths: printed only when showError or debugMode
// is enabled.
func (s *ServerCommon) pushMessageSource(data []byte, source interface{}) {
	queue := s.serverQueueSnapshot()
	if queue == nil || len(data) == 0 {
		return
	}
	if s.pushMessageSourceFast(queue, data, source) {
		return
	}
	// Previously this error was discarded silently; surface it under the same
	// debug flags used everywhere else in the inbound path.
	if err := queue.ParseMessage(data, source); err != nil && (s.showError || s.debugMode) {
		fmt.Println("server parse inbound message error", err)
	}
}
// pushMessageSourceFast splits data into frames with queue.ParseMessageView
// and hands each frame's payload (with no release callback) to
// pushTransportPayloadSourceFast.
//
// It returns false without doing anything when the queue is nil, no inbound
// dispatcher is configured, or data is empty. Otherwise it returns true, even
// if parsing failed. A parse error is printed only when showError or
// debugMode is set.
func (s *ServerCommon) pushMessageSourceFast(queue *stario.StarQueue, data []byte, source interface{}) bool {
	hasDispatcher := s.serverInboundDispatcherSnapshot() != nil
	if !hasDispatcher || queue == nil || len(data) == 0 {
		return false
	}

	forwardFrame := func(frame stario.FrameView) error {
		s.pushTransportPayloadSourceFast(frame.Payload, nil, frame.Conn)
		return nil
	}
	parseErr := queue.ParseMessageView(data, source, forwardFrame)
	if parseErr != nil && (s.showError || s.debugMode) {
		fmt.Println("server parse inbound frame error", parseErr)
	}
	return true
}
// pushTransportPayloadSourceFast routes one inbound transport payload that
// arrived from source into the server's handling pipeline.
//
// release, when non-nil, frees payload's backing memory. This function either
// calls it itself or passes it on to decryptTransportPayloadLogicalPooled.
//
// Paths:
//   - empty payload: call release and return false;
//   - no inbound dispatcher: re-frame payload with queue.BuildMessage and push
//     it through queue.ParseMessage (return false if there is no queue either);
//   - otherwise: resolve the logical/transport connection for source, decrypt
//     the payload, try the borrowed-buffer dispatch, and fall back to handing
//     a closure to dispatcher.Dispatch that works on an owned plaintext copy.
//
// The result is false only when the payload was not accepted at all: it was
// empty, or neither a dispatcher nor a queue was available. Every other
// outcome returns true, including drops for unknown sources, decrypt errors
// and a rejected Dispatch.
func (s *ServerCommon) pushTransportPayloadSourceFast(payload []byte, release func(), source interface{}) bool {
	dispatcher := s.serverInboundDispatcherSnapshot()
	if len(payload) == 0 {
		if release != nil {
			release()
		}
		return false
	}
	if dispatcher == nil {
		// No dispatcher: fall back to the plain queue path.
		queue := s.serverQueueSnapshot()
		if queue == nil {
			if release != nil {
				release()
			}
			return false
		}
		frame := queue.BuildMessage(payload)
		// release runs before frame is parsed, so this relies on BuildMessage
		// producing a frame that does not alias payload.
		if release != nil {
			release()
		}
		if err := queue.ParseMessage(frame, source); err != nil && (s.showError || s.debugMode) {
			fmt.Println("server enqueue inbound frame error", err)
		}
		return true
	}
	logical, transport := s.resolveInboundSource(source)
	if logical == nil {
		// Unknown source: drop the payload but report it as handled.
		if release != nil {
			release()
		}
		return true
	}
	// release is passed on here and never called again below, so responsibility
	// for it presumably moves to decryptTransportPayloadLogicalPooled (including
	// on error) — TODO confirm.
	plain, plainRelease, err := s.decryptTransportPayloadLogicalPooled(logical, payload, release)
	if err != nil {
		if s.showError || s.debugMode {
			fmt.Println("server decode transport payload error", err)
		}
		return true
	}
	inboundConn := serverInboundConn(source)
	// Borrowed fast path. When it returns true, this function does nothing
	// further with plain or plainRelease; presumably the callee took them over.
	if s.tryDispatchBorrowedTransportPlain(logical, transport, inboundConn, plain, plainRelease) {
		return true
	}
	// The dispatched closure below may outlive this call, so pooled plaintext is
	// copied out and its buffer returned before dispatch.
	// NOTE(review): when plainRelease is nil, plain is used without copying.
	// That is only safe if it does not alias a reusable buffer (e.g.
	// frame.Payload from pushMessageSourceFast, which passes release == nil)
	// — verify.
	owned := plain
	if plainRelease != nil {
		owned = append([]byte(nil), plain...)
		plainRelease()
	}
	// Count the task before handing it off. If Dispatch reports false, the code
	// assumes the closure will not run and undoes the Add itself.
	s.wg.Add(1)
	if !dispatcher.Dispatch(serverInboundDispatchSource(source), func() {
		defer s.wg.Done()
		now := time.Now()
		if err := s.dispatchInboundTransportPlain(logical, transport, inboundConn, owned, now); err != nil && (s.showError || s.debugMode) {
			fmt.Println("server decode envelope error", err)
		}
	}) {
		s.wg.Done()
	}
	return true
}
// serverInboundConn extracts the runtime network connection carried by an
// inbound source value.
//
// It accepts a net.Conn directly, a serverInboundSource value, or a non-nil
// *serverInboundSource. Any other value, including nil, yields nil.
func serverInboundConn(source interface{}) net.Conn {
	if conn, ok := source.(net.Conn); ok {
		return conn
	}
	if desc, ok := source.(serverInboundSource); ok {
		return desc.Conn
	}
	if desc, ok := source.(*serverInboundSource); ok && desc != nil {
		return desc.Conn
	}
	return nil
}
// resolveInboundSource maps an opaque inbound source value to its logical
// connection and transport snapshot.
//
// Supported inputs:
//   - a serverInboundSource value or non-nil pointer, resolved with
//     resolveInboundSourceValue;
//   - a string source key, looked up with resolveLogicalBySource (no
//     transport).
//
// Anything else, including a nil *serverInboundSource, resolves to (nil, nil).
func (s *ServerCommon) resolveInboundSource(source interface{}) (*LogicalConn, *TransportConn) {
	if key, ok := source.(string); ok {
		return s.resolveLogicalBySource(key), nil
	}
	if desc, ok := source.(serverInboundSource); ok {
		return s.resolveInboundSourceValue(desc)
	}
	if desc, ok := source.(*serverInboundSource); ok && desc != nil {
		return s.resolveInboundSourceValue(*desc)
	}
	return nil, nil
}
// resolveInboundSourceValue maps a serverInboundSource back to the logical
// connection that should handle the payload, plus its transport snapshot.
//
// Resolution:
//   - no Logical recorded: look the logical connection up by source.Source;
//   - Logical recorded and a runtime conn was present: if that conn's
//     transport snapshot is missing or detached, prefer a logical connection
//     now bound to the same source key. If there is none and the original
//     logical connection is no longer alive, give up.
//
// It returns (nil, nil) when no logical connection can be determined. Callers
// treat that as "drop the payload".
func (s *ServerCommon) resolveInboundSourceValue(source serverInboundSource) (*LogicalConn, *TransportConn) {
	logical := source.Logical
	if logical == nil {
		logical = s.resolveLogicalBySource(source.Source)
	} else if source.HasRuntimeConn {
		current := logical.transportConnSnapshotForInbound(source.Conn, source.RemoteAddr, source.TransportGeneration, source.HasRuntimeConn)
		if current == nil || !current.Attached() {
			// The delivering transport is gone; the source key may have been
			// re-bound to another logical connection in the meantime.
			if rebound := s.resolveLogicalBySource(source.Source); rebound != nil {
				logical = rebound
			} else if !logical.Status().Alive {
				return nil, nil
			}
		}
	}
	if logical == nil {
		// Unknown source: return an empty pair instead of invoking a method on
		// a nil *LogicalConn below. This mirrors the string-key path in
		// resolveInboundSource.
		return nil, nil
	}
	transport := logical.transportConnSnapshotForInbound(source.Conn, source.RemoteAddr, source.TransportGeneration, source.HasRuntimeConn)
	return logical, transport
}