package notify import ( "context" "fmt" "math/rand" "net" "os" "sync/atomic" "time" ) type serverOutboundRoute struct { logical *LogicalConn transport *TransportConn } func (s *ServerCommon) resolveOutboundRoute(logical *LogicalConn) serverOutboundRoute { if logical == nil { return serverOutboundRoute{} } return serverOutboundRoute{ logical: logical, transport: logical.CurrentTransportConn(), } } func (s *ServerCommon) resolveOutboundTransport(logical *LogicalConn) *TransportConn { return s.resolveOutboundRoute(logical).transport } func (s *ServerCommon) send(c *ClientConn, msg TransferMsg) (WaitMsg, error) { return s.sendLogical(logicalConnFromClient(c), msg) } func (s *ServerCommon) sendLogical(logical *LogicalConn, msg TransferMsg) (WaitMsg, error) { if logical == nil { return s.sendTransport(nil, msg) } return s.sendTransport(s.resolveOutboundTransport(logical), msg) } func (s *ServerCommon) sendTransport(transport *TransportConn, msg TransferMsg) (WaitMsg, error) { if err := s.ensureServerTransportSendReady(transport); err != nil { return WaitMsg{}, err } if s.serverUDPListenerSnapshot() != nil { return s.sendUDPTransport(transport, msg) } return s.sendTUTransport(transport, msg) } func (s *ServerCommon) sendTU(c *ClientConn, msg TransferMsg) (WaitMsg, error) { return s.sendTULogical(logicalConnFromClient(c), msg) } func (s *ServerCommon) sendTULogical(logical *LogicalConn, msg TransferMsg) (WaitMsg, error) { if logical == nil { return s.sendTransport(nil, msg) } return s.sendTUTransport(s.resolveOutboundTransport(logical), msg) } func (s *ServerCommon) sendTUTransport(transport *TransportConn, msg TransferMsg) (WaitMsg, error) { var wait WaitMsg if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 { msg.ID = atomic.AddUint64(&s.msgID, 1) } logical := transport.logicalConnSnapshot() if logical == nil { return WaitMsg{}, transportDetachedErrorForTransport(transport) } env, err := wrapTransferMsgEnvelope(msg, s.sequenceEn) if err != nil { return WaitMsg{}, err } if requiresSignalReplyWait(msg) { wait = s.getPendingWaitPool().createAndStoreWithScope(msg, serverTransportScopeForTransport(transport)) } err = s.sendSignalEnvelopeMaybeReliableTransport(transport, env, msg) if err != nil { if requiresSignalReplyWait(msg) { s.getPendingWaitPool().removeAndClose(msg.ID) } return WaitMsg{}, err } return wait, err } func (s *ServerCommon) SendLogical(c *LogicalConn, key string, value MsgVal) error { _, err := s.sendLogical(c, TransferMsg{ Key: key, Value: value, Type: MSG_ASYNC, }) return err } func (s *ServerCommon) SendTransport(t *TransportConn, key string, value MsgVal) error { _, err := s.sendTransport(t, TransferMsg{ Key: key, Value: value, Type: MSG_ASYNC, }) return err } func (s *ServerCommon) Send(c *ClientConn, key string, value MsgVal) error { return s.SendLogical(logicalConnFromClient(c), key, value) } func (s *ServerCommon) sendWait(c *ClientConn, msg TransferMsg, timeout time.Duration) (Message, error) { return s.sendWaitLogical(logicalConnFromClient(c), msg, timeout) } func (s *ServerCommon) sendWaitLogical(logical *LogicalConn, msg TransferMsg, timeout time.Duration) (Message, error) { if logical == nil { return s.sendTransportWait(nil, msg, timeout) } return s.sendTransportWait(s.resolveOutboundTransport(logical), msg, timeout) } func (s *ServerCommon) sendTransportWait(transport *TransportConn, msg TransferMsg, timeout time.Duration) (Message, error) { data, err := s.sendTransport(transport, msg) if err != nil { return Message{}, err } stopCh := sessionStopChan(s.serverStopContextSnapshot()) if timeout.Seconds() == 0 { msg, ok := <-data.Reply if !ok { return msg, pendingWaitClosedErrorWith(stopCh, transportDetachedErrorForTransport(transport)) } return msg, nil } select { case <-time.After(timeout): s.getPendingWaitPool().removeAndClose(data.TransferMsg.ID) return Message{}, os.ErrDeadlineExceeded case <-stopCh: return Message{}, errServiceShutdown case msg, ok := <-data.Reply: if !ok { return msg, pendingWaitClosedErrorWith(stopCh, transportDetachedErrorForTransport(transport)) } return msg, nil } } func (s *ServerCommon) SendWaitLogical(c *LogicalConn, key string, value MsgVal, timeout time.Duration) (Message, error) { return s.sendWaitLogical(c, TransferMsg{ Key: key, Value: value, Type: MSG_SYNC_ASK, }, timeout) } func (s *ServerCommon) SendWaitTransport(t *TransportConn, key string, value MsgVal, timeout time.Duration) (Message, error) { return s.sendTransportWait(t, TransferMsg{ Key: key, Value: value, Type: MSG_SYNC_ASK, }, timeout) } func (s *ServerCommon) SendCtxLogical(ctx context.Context, c *LogicalConn, key string, value MsgVal) (Message, error) { return s.sendCtxLogical(c, TransferMsg{ Key: key, Value: value, Type: MSG_SYNC_ASK, }, ctx) } func (s *ServerCommon) sendCtx(c *ClientConn, msg TransferMsg, ctx context.Context) (Message, error) { return s.sendCtxLogical(logicalConnFromClient(c), msg, ctx) } func (s *ServerCommon) sendCtxLogical(logical *LogicalConn, msg TransferMsg, ctx context.Context) (Message, error) { if logical == nil { return s.sendCtxTransport(nil, msg, ctx) } return s.sendCtxTransport(s.resolveOutboundTransport(logical), msg, ctx) } func (s *ServerCommon) SendCtxTransport(ctx context.Context, t *TransportConn, key string, value MsgVal) (Message, error) { return s.sendCtxTransport(t, TransferMsg{ Key: key, Value: value, Type: MSG_SYNC_ASK, }, ctx) } func (s *ServerCommon) sendCtxTransport(t *TransportConn, msg TransferMsg, ctx context.Context) (Message, error) { data, err := s.sendTransport(t, msg) if err != nil { return Message{}, err } stopCh := sessionStopChan(s.serverStopContextSnapshot()) if ctx == nil { ctx = context.Background() } select { case <-ctx.Done(): s.getPendingWaitPool().removeAndClose(data.TransferMsg.ID) return Message{}, normalizeStreamDeadlineError(ctx.Err()) case <-stopCh: return Message{}, errServiceShutdown case msg, ok := <-data.Reply: if !ok { return msg, pendingWaitClosedErrorWith(stopCh, transportDetachedErrorForTransport(t)) } return msg, nil } } func (s *ServerCommon) SendCtx(ctx context.Context, c *ClientConn, key string, value MsgVal) (Message, error) { return s.SendCtxLogical(ctx, logicalConnFromClient(c), key, value) } func (s *ServerCommon) SendWait(c *ClientConn, key string, value MsgVal, timeout time.Duration) (Message, error) { return s.SendWaitLogical(logicalConnFromClient(c), key, value, timeout) } func (s *ServerCommon) SendWaitObjLogical(c *LogicalConn, key string, value interface{}, timeout time.Duration) (Message, error) { data, err := s.sequenceEn(value) if err != nil { return Message{}, err } return s.SendWaitLogical(c, key, data, timeout) } func (s *ServerCommon) SendWaitObjTransport(t *TransportConn, key string, value interface{}, timeout time.Duration) (Message, error) { data, err := s.sequenceEn(value) if err != nil { return Message{}, err } return s.SendWaitTransport(t, key, data, timeout) } func (s *ServerCommon) SendWaitObj(c *ClientConn, key string, value interface{}, timeout time.Duration) (Message, error) { return s.SendWaitObjLogical(logicalConnFromClient(c), key, value, timeout) } func (s *ServerCommon) SendObjCtxLogical(ctx context.Context, c *LogicalConn, key string, val interface{}) (Message, error) { data, err := s.sequenceEn(val) if err != nil { return Message{}, err } return s.SendCtxLogical(ctx, c, key, data) } func (s *ServerCommon) SendObjCtxTransport(ctx context.Context, t *TransportConn, key string, val interface{}) (Message, error) { data, err := s.sequenceEn(val) if err != nil { return Message{}, err } return s.SendCtxTransport(ctx, t, key, data) } func (s *ServerCommon) SendObjCtx(ctx context.Context, c *ClientConn, key string, val interface{}) (Message, error) { return s.SendObjCtxLogical(ctx, logicalConnFromClient(c), key, val) } func (s *ServerCommon) SendObjLogical(c *LogicalConn, key string, val interface{}) error { data, err := encode(val) if err != nil { return err } _, err = s.sendLogical(c, TransferMsg{ Key: key, Value: data, Type: MSG_ASYNC, }) return err } func (s *ServerCommon) SendObjTransport(t *TransportConn, key string, val interface{}) error { data, err := encode(val) if err != nil { return err } _, err = s.sendTransport(t, TransferMsg{ Key: key, Value: data, Type: MSG_ASYNC, }) return err } func (s *ServerCommon) SendObj(c *ClientConn, key string, val interface{}) error { return s.SendObjLogical(logicalConnFromClient(c), key, val) } func (s *ServerCommon) Reply(m Message, value MsgVal) error { return m.Reply(value) } func (s *ServerCommon) sendUDP(c *ClientConn, msg TransferMsg) (WaitMsg, error) { return s.sendUDPLogical(logicalConnFromClient(c), msg) } func (s *ServerCommon) sendUDPLogical(logical *LogicalConn, msg TransferMsg) (WaitMsg, error) { if logical == nil { return s.sendTransport(nil, msg) } return s.sendUDPTransport(s.resolveOutboundTransport(logical), msg) } func (s *ServerCommon) sendUDPTransport(transport *TransportConn, msg TransferMsg) (WaitMsg, error) { var wait WaitMsg if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 { msg.ID = uint64(time.Now().UnixNano()) + rand.Uint64() + rand.Uint64() } env, err := wrapTransferMsgEnvelope(msg, s.sequenceEn) if err != nil { return WaitMsg{}, err } if requiresSignalReplyWait(msg) { wait = s.getPendingWaitPool().createAndStoreWithScope(msg, serverTransportScopeForTransport(transport)) } err = s.sendSignalEnvelopeMaybeReliableTransport(transport, env, msg) if err != nil { if requiresSignalReplyWait(msg) { s.getPendingWaitPool().removeAndClose(msg.ID) } return WaitMsg{}, err } return wait, err } func (s *ServerCommon) sendEnvelope(c *ClientConn, env Envelope) error { return s.sendEnvelopeLogical(logicalConnFromClient(c), env) } func (s *ServerCommon) sendEnvelopeLogical(logical *LogicalConn, env Envelope) error { if logical == nil { return s.sendEnvelopeTransport(nil, env) } return s.sendEnvelopeTransport(s.resolveOutboundTransport(logical), env) } func (s *ServerCommon) sendEnvelopeTransport(transport *TransportConn, env Envelope) error { if err := s.ensureServerTransportSendReady(transport); err != nil { return err } logical := transport.logicalConnSnapshot() if logical == nil { return transportDetachedErrorForTransport(transport) } payload, err := s.encodeEnvelopePayloadLogical(logical, env) if err != nil { return err } if batchedControlEnvelope(env) { return s.writeControlEnvelopePayload(logical, transport, nil, payload) } return s.writeEnvelopePayload(logical, transport, nil, payload) } func (s *ServerCommon) sendEnvelopeInboundTransport(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope) error { if logical == nil && transport != nil { logical = transport.logicalConnSnapshot() } if logical == nil { return transportDetachedErrorForPeer(logical, transport) } if logical.msgEnSnapshot() == nil { return transportDetachedErrorForPeer(logical, transport) } payload, err := s.encodeEnvelopePayloadLogical(logical, env) if err != nil { return err } if batchedControlEnvelope(env) { return s.writeControlEnvelopePayload(logical, transport, conn, payload) } return s.writeEnvelopePayload(logical, transport, conn, payload) } func (s *ServerCommon) writeControlEnvelopePayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte) error { if logical == nil { return transportDetachedErrorForPeer(logical, transport) } if s.serverUDPListenerSnapshot() != nil { return s.writeEnvelopePayload(logical, transport, conn, payload) } binding := logical.transportBindingSnapshot() if binding == nil || binding.queueSnapshot() == nil { return s.writeEnvelopePayload(logical, transport, conn, payload) } boundConn := binding.connSnapshot() if boundConn == nil || isPacketTransportConn(boundConn) { return s.writeEnvelopePayload(logical, transport, conn, payload) } if conn != nil && conn != boundConn { return s.writeEnvelopePayload(logical, transport, conn, payload) } sender := binding.controlBatchSenderSnapshot() if sender == nil { return s.writeEnvelopePayload(logical, transport, conn, payload) } return sender.submit(payload, writeDeadlineFromTimeout(logical.maxWriteTimeoutSnapshot())) } func (s *ServerCommon) sendTransferInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, msg TransferMsg) error { if logical == nil && transport != nil { logical = transport.logicalConnSnapshot() } if logical == nil { return transportDetachedErrorForPeer(logical, transport) } env, err := wrapTransferMsgEnvelope(msg, s.sequenceEn) if err != nil { return err } return s.sendEnvelopeInboundTransport(logical, transport, conn, env) } func (s *ServerCommon) writeEnvelopePayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte) error { udpListener := s.serverUDPListenerSnapshot() queue := s.serverQueueSnapshot() if queue == nil { return errServerSessionQueueUnavailable } if udpListener != nil { if transport == nil || transport.RemoteAddr() == nil { return transportDetachedErrorForTransport(transport) } if timeout := logical.maxWriteTimeoutSnapshot(); timeout > 0 { _ = udpListener.SetWriteDeadline(time.Now().Add(timeout)) } data := queue.BuildMessage(payload) _, err := udpListener.WriteTo(data, transport.RemoteAddr()) return err } var binding *transportBinding if logical != nil { binding = logical.transportBindingSnapshot() } if conn == nil { if binding == nil { return os.ErrClosed } return binding.withConnWriteLock(func(conn net.Conn) error { if timeout := logical.maxWriteTimeoutSnapshot(); timeout > 0 { if err := conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil { return err } } return writeFramedPayloadUnlocked(conn, queue, payload) }) } if binding != nil && binding.connSnapshot() == conn { return binding.withConnWriteLock(func(conn net.Conn) error { if timeout := logical.maxWriteTimeoutSnapshot(); timeout > 0 { if err := conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil { return err } } return writeFramedPayloadUnlocked(conn, queue, payload) }) } return withRawConnWriteLock(conn, func(conn net.Conn) error { if timeout := logical.maxWriteTimeoutSnapshot(); timeout > 0 { if err := conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil { return err } } return writeFramedPayloadUnlocked(conn, queue, payload) }) } func (s *ServerCommon) dispatchEnvelope(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope, now time.Time) { if transport == nil && logical != nil { transport = logical.CurrentTransportConn() } switch env.Kind { case EnvelopeSignalAck: if s.handleSignalAckEnvelopeTransport(transport, env) { return } case EnvelopeStreamData: s.dispatchStreamEnvelope(logical, transport, conn, env) return case EnvelopeSignal: transfer, err := unwrapTransferMsgEnvelope(env, s.sequenceDe) if err != nil { if s.showError || s.debugMode { fmt.Println("server unwrap signal envelope error", err) } return } if s.handleReceivedSignalReliabilityTransport(logical, transport, conn, transfer) { return } message := Message{ LogicalConn: logical, NetType: NET_SERVER, TransportConn: transport, inboundConn: conn, TransferMsg: transfer, Time: now, } s.dispatchMsg(hydrateServerMessagePeerFields(message)) case EnvelopeFileMeta, EnvelopeFileChunk, EnvelopeFileEnd, EnvelopeFileAbort, EnvelopeAck: s.dispatchFileEnvelope(logical, transport, conn, env, now) default: } }