notify/signal_reliable.go

357 lines
10 KiB
Go
Raw Permalink Normal View History

package notify
import (
"errors"
"fmt"
"net"
"time"
)
// Fallback reliability settings, used wherever a configured value is missing
// or non-positive.
const (
	// defaultSignalAckTimeout is how long one send attempt waits for its ack.
	defaultSignalAckTimeout = time.Second
	// defaultSignalSendRetry is the total number of send attempts per signal,
	// the first attempt included.
	defaultSignalSendRetry = 3
)
// signalReliabilityConfig is the internal form of the signal reliability
// settings kept on clients and servers.
type signalReliabilityConfig struct {
	// Enabled selects the ack-tracked (reliable) signal path.
	Enabled bool
	// EnableConfigured records that Enabled was chosen explicitly, so a
	// transport-level default must not overwrite it.
	EnableConfigured bool
	// AckTimeout bounds how long a single send attempt waits for its ack.
	AckTimeout time.Duration
	// SendRetry is the total number of send attempts, the first included.
	SendRetry int
	// ReceiveCacheLimit is the configured limit for the received-signal cache.
	// NOTE(review): the code consuming this value is not in this section.
	ReceiveCacheLimit int
}
// SignalReliabilityOptions is the public set of signal reliability settings
// accepted by UseSignalReliabilityClient and UseSignalReliabilityServer.
//
// Non-positive AckTimeout, SendRetry or ReceiveCacheLimit values are replaced
// by the package defaults when the options are applied.
type SignalReliabilityOptions struct {
	// Enabled turns the ack-tracked (reliable) signal path on or off.
	Enabled bool
	// AckTimeout bounds how long a single send attempt waits for its ack.
	AckTimeout time.Duration
	// SendRetry is the total number of send attempts, the first included.
	SendRetry int
	// ReceiveCacheLimit is the configured limit for the received-signal cache.
	// NOTE(review): the code consuming this value is not in this section.
	ReceiveCacheLimit int
}
// Sentinel errors returned by UseSignalReliabilityClient and
// UseSignalReliabilityServer.
var (
	// errSignalReliabilityClientNil reports a nil client argument.
	errSignalReliabilityClientNil = errors.New("signal reliability client is nil")
	// errSignalReliabilityServerNil reports a nil server argument.
	errSignalReliabilityServerNil = errors.New("signal reliability server is nil")
	// errSignalReliabilityUnsupportedClient reports a client that does not
	// implement the internal configuration hook.
	errSignalReliabilityUnsupportedClient = errors.New("signal reliability client type is unsupported")
	// errSignalReliabilityUnsupportedServer reports a server that does not
	// implement the internal configuration hook.
	errSignalReliabilityUnsupportedServer = errors.New("signal reliability server type is unsupported")
)
// defaultSignalReliabilityConfig returns the baseline internal configuration:
// reliability disabled, not explicitly configured, and every numeric setting
// at its package default.
func defaultSignalReliabilityConfig() signalReliabilityConfig {
	// The zero value already has Enabled and EnableConfigured set to false.
	var base signalReliabilityConfig
	base.AckTimeout = defaultSignalAckTimeout
	base.SendRetry = defaultSignalSendRetry
	base.ReceiveCacheLimit = defaultReceivedSignalCacheLimit
	return base
}
// DefaultSignalReliabilityOptions returns options with reliability enabled and
// the ack timeout, send retry count and receive cache limit taken from the
// package defaults.
func DefaultSignalReliabilityOptions() SignalReliabilityOptions {
	base := defaultSignalReliabilityConfig()

	// Unlike the internal baseline, the public defaults switch reliability on.
	opts := SignalReliabilityOptions{Enabled: true}
	opts.AckTimeout = base.AckTimeout
	opts.SendRetry = base.SendRetry
	opts.ReceiveCacheLimit = base.ReceiveCacheLimit
	return opts
}
// normalizeSignalReliabilityConfig returns cfg with every non-positive numeric
// setting replaced by its package default. Enabled and EnableConfigured are
// passed through untouched.
func normalizeSignalReliabilityConfig(cfg signalReliabilityConfig) signalReliabilityConfig {
	fallback := defaultSignalReliabilityConfig()

	if !(cfg.ReceiveCacheLimit > 0) {
		cfg.ReceiveCacheLimit = fallback.ReceiveCacheLimit
	}
	if !(cfg.SendRetry > 0) {
		cfg.SendRetry = fallback.SendRetry
	}
	if !(cfg.AckTimeout > 0) {
		cfg.AckTimeout = fallback.AckTimeout
	}
	return cfg
}
// signalReliabilityConfigFromOptions converts public options into the internal
// configuration, always marking the enable flag as explicitly configured.
//
// A nil opts means "enable with defaults". Non-nil opts are copied verbatim;
// numeric values are not normalized here.
func signalReliabilityConfigFromOptions(opts *SignalReliabilityOptions) signalReliabilityConfig {
	if opts != nil {
		return signalReliabilityConfig{
			Enabled:           opts.Enabled,
			EnableConfigured:  true,
			AckTimeout:        opts.AckTimeout,
			SendRetry:         opts.SendRetry,
			ReceiveCacheLimit: opts.ReceiveCacheLimit,
		}
	}

	enabledDefaults := defaultSignalReliabilityConfig()
	enabledDefaults.Enabled, enabledDefaults.EnableConfigured = true, true
	return enabledDefaults
}
// signalReliabilityConfigurer is the internal hook through which
// UseSignalReliabilityClient and UseSignalReliabilityServer install a
// configuration. *ClientCommon and *ServerCommon both provide the method.
type signalReliabilityConfigurer interface {
	setSignalReliabilityConfig(cfg signalReliabilityConfig)
}
// UseSignalReliabilityClient applies signal reliability options to c.
//
// A nil opts enables reliability with the default settings. It returns
// errSignalReliabilityClientNil when c is nil and
// errSignalReliabilityUnsupportedClient when c does not implement the internal
// configuration hook.
func UseSignalReliabilityClient(c Client, opts *SignalReliabilityOptions) error {
	if c == nil {
		return errSignalReliabilityClientNil
	}
	if target, supported := any(c).(signalReliabilityConfigurer); supported {
		target.setSignalReliabilityConfig(signalReliabilityConfigFromOptions(opts))
		return nil
	}
	return errSignalReliabilityUnsupportedClient
}
// UseSignalReliabilityServer applies signal reliability options to s.
//
// A nil opts enables reliability with the default settings. It returns
// errSignalReliabilityServerNil when s is nil and
// errSignalReliabilityUnsupportedServer when s does not implement the internal
// configuration hook.
func UseSignalReliabilityServer(s Server, opts *SignalReliabilityOptions) error {
	if s == nil {
		return errSignalReliabilityServerNil
	}
	if target, supported := any(s).(signalReliabilityConfigurer); supported {
		target.setSignalReliabilityConfig(signalReliabilityConfigFromOptions(opts))
		return nil
	}
	return errSignalReliabilityUnsupportedServer
}
// getSignalReliabilityConfig returns the client's reliability configuration
// after normalization. The normalized value is also written back to the
// client, so this getter mutates state while holding c.mu.
func (c *ClientCommon) getSignalReliabilityConfig() signalReliabilityConfig {
	c.mu.Lock()
	defer c.mu.Unlock()

	normalized := normalizeSignalReliabilityConfig(c.signalReliableCfg)
	c.signalReliableCfg = normalized
	return normalized
}
// getSignalReliabilityConfig returns the server's reliability configuration
// after normalization. The normalized value is also written back to the
// server, so this getter mutates state while holding s.mu.
func (s *ServerCommon) getSignalReliabilityConfig() signalReliabilityConfig {
	s.mu.Lock()
	defer s.mu.Unlock()

	normalized := normalizeSignalReliabilityConfig(s.signalReliableCfg)
	s.signalReliableCfg = normalized
	return normalized
}
// setSignalReliabilityConfig stores a normalized copy of cfg on the client,
// marks the enable flag as explicitly configured, and forwards the result to
// the current logical session when there is one.
func (c *ClientCommon) setSignalReliabilityConfig(cfg signalReliabilityConfig) {
	applied := normalizeSignalReliabilityConfig(cfg)
	applied.EnableConfigured = true

	c.mu.Lock()
	c.signalReliableCfg = applied
	session := c.logicalSession
	c.mu.Unlock()

	// The session is notified only after c.mu has been released.
	if session == nil {
		return
	}
	session.applySignalReliabilityConfig(applied)
}
// setSignalReliabilityConfig stores a normalized copy of cfg on the server,
// marks the enable flag as explicitly configured, and forwards the result to
// the current logical session when there is one.
func (s *ServerCommon) setSignalReliabilityConfig(cfg signalReliabilityConfig) {
	applied := normalizeSignalReliabilityConfig(cfg)
	applied.EnableConfigured = true

	s.mu.Lock()
	s.signalReliableCfg = applied
	session := s.logicalSession
	s.mu.Unlock()

	// The session is notified only after s.mu has been released.
	if session == nil {
		return
	}
	session.applySignalReliabilityConfig(applied)
}
// applySignalReliabilityTransportDefault sets Enabled to the transport-level
// default unless the enable flag has been configured explicitly
// (EnableConfigured). The stored configuration is normalized either way.
//
// The check and the write share one hold of c.mu. Reading the configuration,
// dropping the lock and re-acquiring it to store the result would let a
// concurrent setSignalReliabilityConfig land in between and then be
// overwritten by the stale copy, losing the explicit settings.
func (c *ClientCommon) applySignalReliabilityTransportDefault(enabled bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	cfg := normalizeSignalReliabilityConfig(c.signalReliableCfg)
	if !cfg.EnableConfigured {
		cfg.Enabled = enabled
	}
	c.signalReliableCfg = cfg
}
// applySignalReliabilityTransportDefault sets Enabled to the transport-level
// default unless the enable flag has been configured explicitly
// (EnableConfigured). The stored configuration is normalized either way.
//
// The check and the write share one hold of s.mu. Reading the configuration,
// dropping the lock and re-acquiring it to store the result would let a
// concurrent setSignalReliabilityConfig land in between and then be
// overwritten by the stale copy, losing the explicit settings.
func (s *ServerCommon) applySignalReliabilityTransportDefault(enabled bool) {
	s.mu.Lock()
	defer s.mu.Unlock()

	cfg := normalizeSignalReliabilityConfig(s.signalReliableCfg)
	if !cfg.EnableConfigured {
		cfg.Enabled = enabled
	}
	s.signalReliableCfg = cfg
}
// requiresSignalReplyWait reports whether msg has one of the types
// MSG_SYNC_ASK, MSG_KEY_CHANGE or MSG_SYS_WAIT.
// NOTE(review): going by the name, these are the types whose sender waits for
// a reply; confirm against the callers.
func requiresSignalReplyWait(msg TransferMsg) bool {
	if msg.Type == MSG_SYNC_ASK {
		return true
	}
	if msg.Type == MSG_KEY_CHANGE {
		return true
	}
	return msg.Type == MSG_SYS_WAIT
}
// signalCanUseTransportAck reports whether msg carries a non-zero ID, which
// the reliable path requires before it tracks an ack for the signal.
func signalCanUseTransportAck(msg TransferMsg) bool {
	hasID := msg.ID != 0
	return hasID
}
// retryReliableSignalSend runs send until it succeeds or the configured number
// of attempts is used up, returning nil or the last error. It is the variant
// of retryReliableSignalSendWithAttempt for callbacks that do not need the
// attempt number.
func retryReliableSignalSend(cfg signalReliabilityConfig, send func(signalReliabilityConfig) error) error {
	withoutAttempt := func(attemptCfg signalReliabilityConfig, _ int) error {
		return send(attemptCfg)
	}
	return retryReliableSignalSendWithAttempt(cfg, withoutAttempt)
}
// retryReliableSignalSendWithAttempt calls send up to cfg.SendRetry times
// (after normalizing cfg), passing the normalized configuration and the
// 1-based attempt number. It stops at the first success and otherwise returns
// the error from the final attempt. There is no delay between attempts here.
func retryReliableSignalSendWithAttempt(cfg signalReliabilityConfig, send func(signalReliabilityConfig, int) error) error {
	cfg = normalizeSignalReliabilityConfig(cfg)
	maxAttempts := cfg.SendRetry

	var err error
	for n := 1; n <= maxAttempts; n++ {
		if err = send(cfg, n); err == nil {
			break
		}
	}
	return err
}
// sendSignalWithAck sends a signal and waits for its ack without recording any
// reliability counters. It is sendSignalWithAckTracked with a nil state.
func sendSignalWithAck(scope string, signalID uint64, timeout time.Duration, pool *signalAckPool, send func() error) error {
	var untracked *signalReliabilityState
	return sendSignalWithAckTracked(untracked, scope, signalID, timeout, pool, send)
}
// sendSignalWithAckTracked performs one ack-tracked send attempt.
//
// It registers a pending ack for (scope, signalID) in pool before calling
// send, so the wait is already in place by the time send runs. If send fails,
// the pending ack is canceled and the send error is returned. Otherwise it
// waits up to timeout for the ack and returns the wait result.
//
// When state is non-nil, the wait and its outcome (delivered, timed out,
// canceled) are counted on it. A failed send records only the wait counter.
func sendSignalWithAckTracked(state *signalReliabilityState, scope string, signalID uint64, timeout time.Duration, pool *signalAckPool, send func() error) error {
	tracked := state != nil
	if tracked {
		state.incAckWait()
	}

	pending := pool.prepare(scope, signalID)
	if sendErr := send(); sendErr != nil {
		pending.cancel()
		return sendErr
	}

	ackErr := pool.waitPrepared(pending, timeout)
	if !tracked {
		return ackErr
	}

	switch {
	case ackErr == nil:
		state.incAckDeliver()
	case errors.Is(ackErr, errSignalAckTimeout):
		state.incAckTimeout()
	case errors.Is(ackErr, errSignalAckCanceled):
		state.incAckCanceled()
	}
	return ackErr
}
// sendSignalEnvelopeMaybeReliable sends env for the signal msg.
//
// With reliability disabled, or when msg has no ID, the envelope is sent once
// with no ack tracking. Otherwise each attempt sends the envelope and waits
// for an ack keyed by env.ID in the client scope, retrying up to the
// configured attempt count. Send, reliable-send and retry counters are
// recorded on the client's reliability state.
func (c *ClientCommon) sendSignalEnvelopeMaybeReliable(env Envelope, msg TransferMsg) error {
	stats := c.getSignalReliabilityState()
	stats.incSignalSend()

	cfg := c.getSignalReliabilityConfig()
	reliable := cfg.Enabled && signalCanUseTransportAck(msg)
	if !reliable {
		return c.sendEnvelope(env)
	}

	stats.incReliableSend()
	transmit := func() error {
		return c.sendEnvelope(env)
	}
	return retryReliableSignalSendWithAttempt(cfg, func(attemptCfg signalReliabilityConfig, attempt int) error {
		// Only attempts after the first count as retries.
		if attempt > 1 {
			stats.incRetry()
		}
		return sendSignalWithAckTracked(stats, clientFileScope(), env.ID, attemptCfg.AckTimeout, c.getSignalAckPool(), transmit)
	})
}
// sendSignalEnvelopeMaybeReliable resolves the outbound transport for logical
// and hands off to sendSignalEnvelopeMaybeReliableTransport. A nil logical
// connection is passed on as a nil transport.
func (s *ServerCommon) sendSignalEnvelopeMaybeReliable(logical *LogicalConn, env Envelope, msg TransferMsg) error {
	var transport *TransportConn
	if logical != nil {
		transport = s.resolveOutboundTransport(logical)
	}
	return s.sendSignalEnvelopeMaybeReliableTransport(transport, env, msg)
}
// sendSignalEnvelopeMaybeReliableTransport sends env for the signal msg over
// transport.
//
// With reliability disabled, or when msg has no ID, the envelope is sent once
// with no ack tracking. Otherwise each attempt sends the envelope and waits
// for an ack keyed by env.ID in the scope derived from transport, retrying up
// to the configured attempt count. Send, reliable-send and retry counters are
// recorded on the server's reliability state.
func (s *ServerCommon) sendSignalEnvelopeMaybeReliableTransport(transport *TransportConn, env Envelope, msg TransferMsg) error {
	stats := s.getSignalReliabilityState()
	stats.incSignalSend()

	cfg := s.getSignalReliabilityConfig()
	reliable := cfg.Enabled && signalCanUseTransportAck(msg)
	if !reliable {
		return s.sendEnvelopeTransport(transport, env)
	}

	stats.incReliableSend()
	// The ack scope is computed once and shared by every attempt.
	ackScope := serverTransportScopeForTransport(transport)
	transmit := func() error {
		return s.sendEnvelopeTransport(transport, env)
	}
	return retryReliableSignalSendWithAttempt(cfg, func(attemptCfg signalReliabilityConfig, attempt int) error {
		// Only attempts after the first count as retries.
		if attempt > 1 {
			stats.incRetry()
		}
		return sendSignalWithAckTracked(stats, ackScope, env.ID, attemptCfg.AckTimeout, s.getSignalAckPool(), transmit)
	})
}
// sendSignalAck sends an ack envelope for signalID to the peer.
func (c *ClientCommon) sendSignalAck(signalID uint64) error {
	ack := newSignalAckEnvelope(signalID)
	return c.sendEnvelope(ack)
}
// sendSignalAck sends an ack for signalID over the outbound transport resolved
// from logical. A nil logical connection is passed on as a nil transport.
func (s *ServerCommon) sendSignalAck(logical *LogicalConn, signalID uint64) error {
	var transport *TransportConn
	if logical != nil {
		transport = s.resolveOutboundTransport(logical)
	}
	return s.sendSignalAckTransport(transport, signalID)
}
// sendSignalAckTransport sends an ack envelope for signalID over transport.
func (s *ServerCommon) sendSignalAckTransport(transport *TransportConn, signalID uint64) error {
	ack := newSignalAckEnvelope(signalID)
	return s.sendEnvelopeTransport(transport, ack)
}
// sendSignalAckInbound acks signalID on the inbound path. When conn is
// non-nil the ack goes through sendEnvelopeInboundTransport; with a nil conn
// it falls back to sendSignalAckTransport on transport.
func (s *ServerCommon) sendSignalAckInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, signalID uint64) error {
	if conn != nil {
		return s.sendEnvelopeInboundTransport(logical, transport, conn, newSignalAckEnvelope(signalID))
	}
	return s.sendSignalAckTransport(transport, signalID)
}
// handleSignalAckEnvelope delivers the ack carried by env (keyed by env.ID in
// the client scope) to the client's ack pool and returns the pool's delivery
// result.
func (c *ClientCommon) handleSignalAckEnvelope(env Envelope) bool {
	pool := c.getSignalAckPool()
	return pool.deliver(clientFileScope(), env.ID)
}
// handleSignalAckEnvelope resolves the outbound transport for logical and
// hands the ack in env to handleSignalAckEnvelopeTransport. A nil logical
// connection is passed on as a nil transport.
func (s *ServerCommon) handleSignalAckEnvelope(logical *LogicalConn, env Envelope) bool {
	var transport *TransportConn
	if logical != nil {
		transport = s.resolveOutboundTransport(logical)
	}
	return s.handleSignalAckEnvelopeTransport(transport, env)
}
// handleSignalAckEnvelopeTransport delivers the ack carried by env (keyed by
// env.ID) to the server's ack pool, trying the delivery scopes derived from
// transport, and returns the pool's delivery result.
func (s *ServerCommon) handleSignalAckEnvelopeTransport(transport *TransportConn, env Envelope) bool {
	pool := s.getSignalAckPool()
	scopes := serverTransportDeliveryScopesForTransport(transport)
	return pool.deliverAny(scopes, env.ID)
}
// handleReceivedSignalReliability does the receive-side reliability work for
// msg and reports whether msg is a duplicate of a signal already seen.
//
// It returns false without doing anything when reliability is disabled or msg
// has no ID. Otherwise it records the ID in the received-signal cache and
// sends an ack, for duplicates as well as first deliveries. An ack send
// failure is counted and, when showError or debugMode is set, printed; it is
// not returned to the caller.
func (c *ClientCommon) handleReceivedSignalReliability(msg TransferMsg) bool {
	cfg := c.getSignalReliabilityConfig()
	if !(cfg.Enabled && signalCanUseTransportAck(msg)) {
		return false
	}

	stats := c.getSignalReliabilityState()
	seenBefore := c.getReceivedSignalCache().seenOrRemember(clientFileScope(), msg.ID)
	if seenBefore {
		stats.incDuplicateRecv()
	}

	stats.incAckSend()
	err := c.sendSignalAck(msg.ID)
	if err == nil {
		return seenBefore
	}

	stats.incAckSendError()
	if c.showError || c.debugMode {
		fmt.Println("client send signal ack error", err)
	}
	return seenBefore
}
// handleReceivedSignalReliability does the receive-side reliability work for
// msg on the outbound transport resolved from logical, with no inbound
// connection, and reports whether msg is a duplicate.
//
// A nil logical connection is passed on as a nil transport rather than handed
// to resolveOutboundTransport, matching the other server entry points that
// take a *LogicalConn (sendSignalEnvelopeMaybeReliable, sendSignalAck,
// handleSignalAckEnvelope).
func (s *ServerCommon) handleReceivedSignalReliability(logical *LogicalConn, msg TransferMsg) bool {
	var transport *TransportConn
	if logical != nil {
		transport = s.resolveOutboundTransport(logical)
	}
	return s.handleReceivedSignalReliabilityTransport(logical, transport, nil, msg)
}
// handleReceivedSignalReliabilityTransport does the receive-side reliability
// work for msg and reports whether msg is a duplicate of a signal already
// seen in the scope derived from logical.
//
// It returns false without doing anything when reliability is disabled or msg
// has no ID. Otherwise it records the ID in the received-signal cache and
// sends an ack via sendSignalAckInbound, for duplicates as well as first
// deliveries. An ack send failure is counted and, when showError or debugMode
// is set, printed; it is not returned to the caller.
func (s *ServerCommon) handleReceivedSignalReliabilityTransport(logical *LogicalConn, transport *TransportConn, conn net.Conn, msg TransferMsg) bool {
	cfg := s.getSignalReliabilityConfig()
	if !(cfg.Enabled && signalCanUseTransportAck(msg)) {
		return false
	}

	stats := s.getSignalReliabilityState()
	cacheScope := serverFileScope(logical)
	seenBefore := s.getReceivedSignalCache().seenOrRemember(cacheScope, msg.ID)
	if seenBefore {
		stats.incDuplicateRecv()
	}

	stats.incAckSend()
	err := s.sendSignalAckInbound(logical, transport, conn, msg.ID)
	if err == nil {
		return seenBefore
	}

	stats.incAckSendError()
	if s.showError || s.debugMode {
		fmt.Println("server send signal ack error", err)
	}
	return seenBefore
}