package notify

import (
	"errors"
	"fmt"
	"net"
	"time"
)

const (
	// defaultSignalAckTimeout is the per-attempt timeout handed to the ack
	// pool while waiting for the peer to acknowledge a reliable signal.
	defaultSignalAckTimeout = time.Second
	// defaultSignalSendRetry is the total number of send attempts (the first
	// attempt included) made by a reliable signal send.
	defaultSignalSendRetry = 3
)

// signalReliabilityConfig is the internal, per-endpoint reliability setting.
type signalReliabilityConfig struct {
	// Enabled turns on ack-tracked delivery for signals with a non-zero ID.
	Enabled bool
	// EnableConfigured is true once the config was set explicitly via
	// setSignalReliabilityConfig; applySignalReliabilityTransportDefault
	// then leaves Enabled untouched.
	EnableConfigured bool
	// AckTimeout bounds each wait for an ack.
	AckTimeout time.Duration
	// SendRetry is the total number of send attempts.
	SendRetry int
	// ReceiveCacheLimit is a limit for the received-signal cache
	// (presumably its capacity for duplicate detection — confirm in
	// applySignalReliabilityConfig).
	ReceiveCacheLimit int
}

// SignalReliabilityOptions configures reliable (acknowledged, retried)
// signal delivery for UseSignalReliabilityClient and
// UseSignalReliabilityServer. Non-positive AckTimeout, SendRetry and
// ReceiveCacheLimit values are replaced with the package defaults.
type SignalReliabilityOptions struct {
	Enabled           bool
	AckTimeout        time.Duration
	SendRetry         int
	ReceiveCacheLimit int
}

var (
	errSignalReliabilityClientNil         = errors.New("signal reliability client is nil")
	errSignalReliabilityServerNil         = errors.New("signal reliability server is nil")
	errSignalReliabilityUnsupportedClient = errors.New("signal reliability client type is unsupported")
	errSignalReliabilityUnsupportedServer = errors.New("signal reliability server type is unsupported")
)

// defaultSignalReliabilityConfig returns the built-in configuration:
// reliability disabled and not explicitly configured, with default timeout,
// retry count and receive cache limit.
func defaultSignalReliabilityConfig() signalReliabilityConfig {
	return signalReliabilityConfig{
		Enabled:           false,
		EnableConfigured:  false,
		AckTimeout:        defaultSignalAckTimeout,
		SendRetry:         defaultSignalSendRetry,
		ReceiveCacheLimit: defaultReceivedSignalCacheLimit,
	}
}

// DefaultSignalReliabilityOptions returns options with reliability enabled
// and the package-default ack timeout, retry count and receive cache limit.
func DefaultSignalReliabilityOptions() SignalReliabilityOptions {
	cfg := defaultSignalReliabilityConfig()
	return SignalReliabilityOptions{
		Enabled:           true,
		AckTimeout:        cfg.AckTimeout,
		SendRetry:         cfg.SendRetry,
		ReceiveCacheLimit: cfg.ReceiveCacheLimit,
	}
}

// normalizeSignalReliabilityConfig replaces every non-positive numeric
// field of cfg with its default. Enabled and EnableConfigured are kept.
func normalizeSignalReliabilityConfig(cfg signalReliabilityConfig) signalReliabilityConfig {
	defaults := defaultSignalReliabilityConfig()
	if cfg.AckTimeout <= 0 {
		cfg.AckTimeout = defaults.AckTimeout
	}
	if cfg.SendRetry <= 0 {
		cfg.SendRetry = defaults.SendRetry
	}
	if cfg.ReceiveCacheLimit <= 0 {
		cfg.ReceiveCacheLimit = defaults.ReceiveCacheLimit
	}
	return cfg
}

// signalReliabilityConfigFromOptions converts public options into an
// explicitly configured internal config. A nil opts means "enabled with
// defaults". Numeric fields are copied as-is; callers normalize later.
func signalReliabilityConfigFromOptions(opts *SignalReliabilityOptions) signalReliabilityConfig {
	cfg := defaultSignalReliabilityConfig()
	if opts == nil {
		cfg.Enabled = true
		cfg.EnableConfigured = true
		return cfg
	}
	return signalReliabilityConfig{
		Enabled:           opts.Enabled,
		EnableConfigured:  true,
		AckTimeout:        opts.AckTimeout,
		SendRetry:         opts.SendRetry,
		ReceiveCacheLimit: opts.ReceiveCacheLimit,
	}
}

// signalReliabilityConfigurer is implemented by endpoints whose signal
// reliability settings can be changed at runtime.
type signalReliabilityConfigurer interface {
	setSignalReliabilityConfig(signalReliabilityConfig)
}

// UseSignalReliabilityClient applies opts to c. A nil opts enables
// reliability with default settings. It returns an error when c is nil or
// its concrete type does not support reliability configuration.
//
// NOTE(review): a non-nil Client interface holding a typed nil pointer
// passes the nil check below.
func UseSignalReliabilityClient(c Client, opts *SignalReliabilityOptions) error {
	if c == nil {
		return errSignalReliabilityClientNil
	}
	configurer, ok := any(c).(signalReliabilityConfigurer)
	if !ok {
		return errSignalReliabilityUnsupportedClient
	}
	configurer.setSignalReliabilityConfig(signalReliabilityConfigFromOptions(opts))
	return nil
}

// UseSignalReliabilityServer applies opts to s. A nil opts enables
// reliability with default settings. It returns an error when s is nil or
// its concrete type does not support reliability configuration.
//
// NOTE(review): a non-nil Server interface holding a typed nil pointer
// passes the nil check below.
func UseSignalReliabilityServer(s Server, opts *SignalReliabilityOptions) error {
	if s == nil {
		return errSignalReliabilityServerNil
	}
	configurer, ok := any(s).(signalReliabilityConfigurer)
	if !ok {
		return errSignalReliabilityUnsupportedServer
	}
	configurer.setSignalReliabilityConfig(signalReliabilityConfigFromOptions(opts))
	return nil
}

// getSignalReliabilityConfig returns the normalized client config. It also
// stores the normalized form back so later readers see defaults filled in.
func (c *ClientCommon) getSignalReliabilityConfig() signalReliabilityConfig {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.signalReliableCfg = normalizeSignalReliabilityConfig(c.signalReliableCfg)
	return c.signalReliableCfg
}

// getSignalReliabilityConfig returns the normalized server config. It also
// stores the normalized form back so later readers see defaults filled in.
func (s *ServerCommon) getSignalReliabilityConfig() signalReliabilityConfig {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.signalReliableCfg = normalizeSignalReliabilityConfig(s.signalReliableCfg)
	return s.signalReliableCfg
}

// setSignalReliabilityConfig stores cfg (normalized, marked as explicitly
// configured) and pushes it to the current logical session, if any. The
// push happens after the lock is released.
func (c *ClientCommon) setSignalReliabilityConfig(cfg signalReliabilityConfig) {
	cfg = normalizeSignalReliabilityConfig(cfg)
	cfg.EnableConfigured = true
	c.mu.Lock()
	c.signalReliableCfg = cfg
	state := c.logicalSession
	c.mu.Unlock()
	if state != nil {
		state.applySignalReliabilityConfig(cfg)
	}
}

// setSignalReliabilityConfig stores cfg (normalized, marked as explicitly
// configured) and pushes it to the current logical session, if any. The
// push happens after the lock is released.
func (s *ServerCommon) setSignalReliabilityConfig(cfg signalReliabilityConfig) {
	cfg = normalizeSignalReliabilityConfig(cfg)
	cfg.EnableConfigured = true
	s.mu.Lock()
	s.signalReliableCfg = cfg
	state := s.logicalSession
	s.mu.Unlock()
	if state != nil {
		state.applySignalReliabilityConfig(cfg)
	}
}

// applySignalReliabilityTransportDefault sets Enabled to the transport's
// default unless the user configured reliability explicitly. The check and
// the write share one critical section so a concurrent
// setSignalReliabilityConfig cannot be overwritten by a stale copy.
func (c *ClientCommon) applySignalReliabilityTransportDefault(enabled bool) {
	c.mu.Lock()
	defer c.mu.Unlock()
	cfg := normalizeSignalReliabilityConfig(c.signalReliableCfg)
	if !cfg.EnableConfigured {
		cfg.Enabled = enabled
	}
	c.signalReliableCfg = cfg
}

// applySignalReliabilityTransportDefault sets Enabled to the transport's
// default unless the user configured reliability explicitly. The check and
// the write share one critical section so a concurrent
// setSignalReliabilityConfig cannot be overwritten by a stale copy.
func (s *ServerCommon) applySignalReliabilityTransportDefault(enabled bool) {
	s.mu.Lock()
	defer s.mu.Unlock()
	cfg := normalizeSignalReliabilityConfig(s.signalReliableCfg)
	if !cfg.EnableConfigured {
		cfg.Enabled = enabled
	}
	s.signalReliableCfg = cfg
}

// requiresSignalReplyWait reports whether msg is one of the message types
// MSG_SYNC_ASK, MSG_KEY_CHANGE or MSG_SYS_WAIT.
func requiresSignalReplyWait(msg TransferMsg) bool {
	return msg.Type == MSG_SYNC_ASK || msg.Type == MSG_KEY_CHANGE || msg.Type == MSG_SYS_WAIT
}

// signalCanUseTransportAck reports whether msg has an ID that can be
// acknowledged. Zero IDs never go through the reliable path.
func signalCanUseTransportAck(msg TransferMsg) bool {
	return msg.ID != 0
}

// retryReliableSignalSend is retryReliableSignalSendWithAttempt for send
// functions that do not need the attempt number.
func retryReliableSignalSend(cfg signalReliabilityConfig, send func(signalReliabilityConfig) error) error {
	return retryReliableSignalSendWithAttempt(cfg, func(cfg signalReliabilityConfig, _ int) error {
		return send(cfg)
	})
}

// retryReliableSignalSendWithAttempt calls send up to cfg.SendRetry times.
// The attempt argument is 1-based. It stops at the first success and
// returns nil, or returns the last error once all attempts have failed.
func retryReliableSignalSendWithAttempt(cfg signalReliabilityConfig, send func(signalReliabilityConfig, int) error) error {
	cfg = normalizeSignalReliabilityConfig(cfg)
	var lastErr error
	for attempt := 1; attempt <= cfg.SendRetry; attempt++ {
		lastErr = send(cfg, attempt)
		if lastErr == nil {
			return nil
		}
	}
	return lastErr
}

// sendSignalWithAck is sendSignalWithAckTracked without statistics.
func sendSignalWithAck(scope string, signalID uint64, timeout time.Duration, pool *signalAckPool, send func() error) error {
	return sendSignalWithAckTracked(nil, scope, signalID, timeout, pool, send)
}

// sendSignalWithAckTracked registers an ack wait for (scope, signalID),
// runs send, then waits up to timeout for the ack. The wait is registered
// before sending (presumably so an ack that arrives immediately is not
// missed). If send fails, the wait is canceled and the send error is
// returned. When state is non-nil, the outcome is counted.
func sendSignalWithAckTracked(state *signalReliabilityState, scope string, signalID uint64, timeout time.Duration, pool *signalAckPool, send func() error) error {
	if state != nil {
		state.incAckWait()
	}
	wait := pool.prepare(scope, signalID)
	if err := send(); err != nil {
		wait.cancel()
		return err
	}
	err := pool.waitPrepared(wait, timeout)
	if state == nil {
		return err
	}
	switch {
	case err == nil:
		state.incAckDeliver()
	case errors.Is(err, errSignalAckTimeout):
		state.incAckTimeout()
	case errors.Is(err, errSignalAckCanceled):
		state.incAckCanceled()
	}
	return err
}

// sendSignalEnvelopeMaybeReliable sends env and, when reliability is
// enabled and msg has a non-zero ID, waits for an ack and retries on
// failure. Otherwise it does a single plain send.
func (c *ClientCommon) sendSignalEnvelopeMaybeReliable(env Envelope, msg TransferMsg) error {
	state := c.getSignalReliabilityState()
	state.incSignalSend()
	cfg := c.getSignalReliabilityConfig()
	if !cfg.Enabled || !signalCanUseTransportAck(msg) {
		return c.sendEnvelope(env)
	}
	state.incReliableSend()
	return retryReliableSignalSendWithAttempt(cfg, func(cfg signalReliabilityConfig, attempt int) error {
		if attempt > 1 {
			state.incRetry()
		}
		return sendSignalWithAckTracked(state, clientFileScope(), env.ID, cfg.AckTimeout, c.getSignalAckPool(), func() error {
			return c.sendEnvelope(env)
		})
	})
}

// sendSignalEnvelopeMaybeReliable resolves logical's outbound transport
// (nil when logical is nil) and delegates to
// sendSignalEnvelopeMaybeReliableTransport.
func (s *ServerCommon) sendSignalEnvelopeMaybeReliable(logical *LogicalConn, env Envelope, msg TransferMsg) error {
	if logical == nil {
		return s.sendSignalEnvelopeMaybeReliableTransport(nil, env, msg)
	}
	return s.sendSignalEnvelopeMaybeReliableTransport(s.resolveOutboundTransport(logical), env, msg)
}

// sendSignalEnvelopeMaybeReliableTransport sends env over transport and,
// when reliability is enabled and msg has a non-zero ID, waits for an ack
// scoped to that transport and retries on failure.
func (s *ServerCommon) sendSignalEnvelopeMaybeReliableTransport(transport *TransportConn, env Envelope, msg TransferMsg) error {
	state := s.getSignalReliabilityState()
	state.incSignalSend()
	cfg := s.getSignalReliabilityConfig()
	if !cfg.Enabled || !signalCanUseTransportAck(msg) {
		return s.sendEnvelopeTransport(transport, env)
	}
	state.incReliableSend()
	scope := serverTransportScopeForTransport(transport)
	return retryReliableSignalSendWithAttempt(cfg, func(cfg signalReliabilityConfig, attempt int) error {
		if attempt > 1 {
			state.incRetry()
		}
		return sendSignalWithAckTracked(state, scope, env.ID, cfg.AckTimeout, s.getSignalAckPool(), func() error {
			return s.sendEnvelopeTransport(transport, env)
		})
	})
}

// sendSignalAck sends an ack envelope for signalID to the server.
func (c *ClientCommon) sendSignalAck(signalID uint64) error {
	return c.sendEnvelope(newSignalAckEnvelope(signalID))
}

// sendSignalAck sends an ack for signalID over logical's outbound
// transport (nil when logical is nil).
func (s *ServerCommon) sendSignalAck(logical *LogicalConn, signalID uint64) error {
	if logical == nil {
		return s.sendSignalAckTransport(nil, signalID)
	}
	return s.sendSignalAckTransport(s.resolveOutboundTransport(logical), signalID)
}

// sendSignalAckTransport sends an ack envelope for signalID over transport.
func (s *ServerCommon) sendSignalAckTransport(transport *TransportConn, signalID uint64) error {
	return s.sendEnvelopeTransport(transport, newSignalAckEnvelope(signalID))
}

// sendSignalAckInbound sends the ack on the inbound conn when one is given,
// falling back to the transport path when conn is nil.
func (s *ServerCommon) sendSignalAckInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, signalID uint64) error {
	if conn == nil {
		return s.sendSignalAckTransport(transport, signalID)
	}
	return s.sendEnvelopeInboundTransport(logical, transport, conn, newSignalAckEnvelope(signalID))
}

// handleSignalAckEnvelope delivers an incoming ack to the client's ack
// pool and reports the pool's result.
func (c *ClientCommon) handleSignalAckEnvelope(env Envelope) bool {
	return c.getSignalAckPool().deliver(clientFileScope(), env.ID)
}

// handleSignalAckEnvelope resolves logical's outbound transport (nil when
// logical is nil) and delegates to handleSignalAckEnvelopeTransport.
func (s *ServerCommon) handleSignalAckEnvelope(logical *LogicalConn, env Envelope) bool {
	if logical == nil {
		return s.handleSignalAckEnvelopeTransport(nil, env)
	}
	return s.handleSignalAckEnvelopeTransport(s.resolveOutboundTransport(logical), env)
}

// handleSignalAckEnvelopeTransport delivers an incoming ack to any of the
// delivery scopes associated with transport.
func (s *ServerCommon) handleSignalAckEnvelopeTransport(transport *TransportConn, env Envelope) bool {
	return s.getSignalAckPool().deliverAny(serverTransportDeliveryScopesForTransport(transport), env.ID)
}

// handleReceivedSignalReliability records a received signal in the dedup
// cache and acks it. It returns true when the signal was already seen
// (presumably so the caller drops it). Duplicates are still acked, since
// the sender retries until an ack arrives. It returns false without acking
// when reliability is off or msg has no ID. Ack send failures are counted
// and optionally printed, not returned.
func (c *ClientCommon) handleReceivedSignalReliability(msg TransferMsg) bool {
	cfg := c.getSignalReliabilityConfig()
	if !cfg.Enabled || !signalCanUseTransportAck(msg) {
		return false
	}
	state := c.getSignalReliabilityState()
	duplicate := c.getReceivedSignalCache().seenOrRemember(clientFileScope(), msg.ID)
	if duplicate {
		state.incDuplicateRecv()
	}
	state.incAckSend()
	if err := c.sendSignalAck(msg.ID); err != nil {
		state.incAckSendError()
		if c.showError || c.debugMode {
			fmt.Println("client send signal ack error", err)
		}
	}
	return duplicate
}

// handleReceivedSignalReliability resolves logical's outbound transport
// and delegates to handleReceivedSignalReliabilityTransport. A nil logical
// maps to a nil transport, matching the other server entry points, instead
// of passing nil into resolveOutboundTransport.
func (s *ServerCommon) handleReceivedSignalReliability(logical *LogicalConn, msg TransferMsg) bool {
	var transport *TransportConn
	if logical != nil {
		transport = s.resolveOutboundTransport(logical)
	}
	return s.handleReceivedSignalReliabilityTransport(logical, transport, nil, msg)
}

// handleReceivedSignalReliabilityTransport is the server-side counterpart
// of the client's handleReceivedSignalReliability. Deduplication is scoped
// per logical conn, and the ack goes out on conn when it is non-nil,
// otherwise over transport.
func (s *ServerCommon) handleReceivedSignalReliabilityTransport(logical *LogicalConn, transport *TransportConn, conn net.Conn, msg TransferMsg) bool {
	cfg := s.getSignalReliabilityConfig()
	if !cfg.Enabled || !signalCanUseTransportAck(msg) {
		return false
	}
	state := s.getSignalReliabilityState()
	scope := serverFileScope(logical)
	duplicate := s.getReceivedSignalCache().seenOrRemember(scope, msg.ID)
	if duplicate {
		state.incDuplicateRecv()
	}
	state.incAckSend()
	if err := s.sendSignalAckInbound(logical, transport, conn, msg.ID); err != nil {
		state.incAckSendError()
		if s.showError || s.debugMode {
			fmt.Println("server send signal ack error", err)
		}
	}
	return duplicate
}