357 lines
10 KiB
Go
357 lines
10 KiB
Go
|
|
package notify
|
||
|
|
|
||
|
|
import (
|
||
|
|
"errors"
|
||
|
|
"fmt"
|
||
|
|
"net"
|
||
|
|
"time"
|
||
|
|
)
|
||
|
|
|
||
|
|
// Fallback values applied when a signal reliability setting is left unset
// or is non-positive.
const (
	// defaultSignalAckTimeout is the default AckTimeout.
	defaultSignalAckTimeout = time.Second
	// defaultSignalSendRetry is the default SendRetry, i.e. the total number
	// of send attempts made by the retry helpers.
	defaultSignalSendRetry = 3
)
|
||
|
|
|
||
|
|
// signalReliabilityConfig is the internal form of the signal reliability
// settings held by ClientCommon and ServerCommon.
type signalReliabilityConfig struct {
	// Enabled selects the acknowledged, retried send path for signals.
	// EnableConfigured records that Enabled was set explicitly; once true,
	// applySignalReliabilityTransportDefault leaves Enabled alone.
	Enabled, EnableConfigured bool
	// AckTimeout is passed as the wait timeout for each ack.
	AckTimeout time.Duration
	// SendRetry is the total number of send attempts.
	// ReceiveCacheLimit is carried through unchanged here; presumably it
	// bounds the received-signal cache (consumer not visible in this file).
	SendRetry, ReceiveCacheLimit int
}
|
||
|
|
|
||
|
|
// SignalReliabilityOptions is the public configuration accepted by
// UseSignalReliabilityClient and UseSignalReliabilityServer.
//
// Non-positive AckTimeout, SendRetry and ReceiveCacheLimit values are replaced
// with the package defaults when the options are applied.
type SignalReliabilityOptions struct {
	// Enabled turns acknowledged signal delivery on or off.
	Enabled bool
	// AckTimeout is how long a send waits for its ack.
	AckTimeout time.Duration
	// SendRetry is the total number of send attempts.
	// ReceiveCacheLimit is copied into the internal config as-is; presumably
	// it bounds the received-signal cache (consumer not visible in this file).
	SendRetry, ReceiveCacheLimit int
}
|
||
|
|
|
||
|
|
// Errors returned by UseSignalReliabilityClient and UseSignalReliabilityServer.
var (
	// Client side: nil argument, or a value that cannot accept the config.
	errSignalReliabilityClientNil         = errors.New("signal reliability client is nil")
	errSignalReliabilityUnsupportedClient = errors.New("signal reliability client type is unsupported")

	// Server side: nil argument, or a value that cannot accept the config.
	errSignalReliabilityServerNil         = errors.New("signal reliability server is nil")
	errSignalReliabilityUnsupportedServer = errors.New("signal reliability server type is unsupported")
)
|
||
|
|
|
||
|
|
func defaultSignalReliabilityConfig() signalReliabilityConfig {
|
||
|
|
return signalReliabilityConfig{
|
||
|
|
Enabled: false,
|
||
|
|
EnableConfigured: false,
|
||
|
|
AckTimeout: defaultSignalAckTimeout,
|
||
|
|
SendRetry: defaultSignalSendRetry,
|
||
|
|
ReceiveCacheLimit: defaultReceivedSignalCacheLimit,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func DefaultSignalReliabilityOptions() SignalReliabilityOptions {
|
||
|
|
cfg := defaultSignalReliabilityConfig()
|
||
|
|
return SignalReliabilityOptions{
|
||
|
|
Enabled: true,
|
||
|
|
AckTimeout: cfg.AckTimeout,
|
||
|
|
SendRetry: cfg.SendRetry,
|
||
|
|
ReceiveCacheLimit: cfg.ReceiveCacheLimit,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func normalizeSignalReliabilityConfig(cfg signalReliabilityConfig) signalReliabilityConfig {
|
||
|
|
defaults := defaultSignalReliabilityConfig()
|
||
|
|
if cfg.AckTimeout <= 0 {
|
||
|
|
cfg.AckTimeout = defaults.AckTimeout
|
||
|
|
}
|
||
|
|
if cfg.SendRetry <= 0 {
|
||
|
|
cfg.SendRetry = defaults.SendRetry
|
||
|
|
}
|
||
|
|
if cfg.ReceiveCacheLimit <= 0 {
|
||
|
|
cfg.ReceiveCacheLimit = defaults.ReceiveCacheLimit
|
||
|
|
}
|
||
|
|
return cfg
|
||
|
|
}
|
||
|
|
|
||
|
|
func signalReliabilityConfigFromOptions(opts *SignalReliabilityOptions) signalReliabilityConfig {
|
||
|
|
cfg := defaultSignalReliabilityConfig()
|
||
|
|
if opts == nil {
|
||
|
|
cfg.Enabled = true
|
||
|
|
cfg.EnableConfigured = true
|
||
|
|
return cfg
|
||
|
|
}
|
||
|
|
return signalReliabilityConfig{
|
||
|
|
Enabled: opts.Enabled,
|
||
|
|
EnableConfigured: true,
|
||
|
|
AckTimeout: opts.AckTimeout,
|
||
|
|
SendRetry: opts.SendRetry,
|
||
|
|
ReceiveCacheLimit: opts.ReceiveCacheLimit,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
type signalReliabilityConfigurer interface {
|
||
|
|
setSignalReliabilityConfig(signalReliabilityConfig)
|
||
|
|
}
|
||
|
|
|
||
|
|
func UseSignalReliabilityClient(c Client, opts *SignalReliabilityOptions) error {
|
||
|
|
if c == nil {
|
||
|
|
return errSignalReliabilityClientNil
|
||
|
|
}
|
||
|
|
configurer, ok := any(c).(signalReliabilityConfigurer)
|
||
|
|
if !ok {
|
||
|
|
return errSignalReliabilityUnsupportedClient
|
||
|
|
}
|
||
|
|
configurer.setSignalReliabilityConfig(signalReliabilityConfigFromOptions(opts))
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func UseSignalReliabilityServer(s Server, opts *SignalReliabilityOptions) error {
|
||
|
|
if s == nil {
|
||
|
|
return errSignalReliabilityServerNil
|
||
|
|
}
|
||
|
|
configurer, ok := any(s).(signalReliabilityConfigurer)
|
||
|
|
if !ok {
|
||
|
|
return errSignalReliabilityUnsupportedServer
|
||
|
|
}
|
||
|
|
configurer.setSignalReliabilityConfig(signalReliabilityConfigFromOptions(opts))
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *ClientCommon) getSignalReliabilityConfig() signalReliabilityConfig {
|
||
|
|
c.mu.Lock()
|
||
|
|
defer c.mu.Unlock()
|
||
|
|
c.signalReliableCfg = normalizeSignalReliabilityConfig(c.signalReliableCfg)
|
||
|
|
return c.signalReliableCfg
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *ServerCommon) getSignalReliabilityConfig() signalReliabilityConfig {
|
||
|
|
s.mu.Lock()
|
||
|
|
defer s.mu.Unlock()
|
||
|
|
s.signalReliableCfg = normalizeSignalReliabilityConfig(s.signalReliableCfg)
|
||
|
|
return s.signalReliableCfg
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *ClientCommon) setSignalReliabilityConfig(cfg signalReliabilityConfig) {
|
||
|
|
cfg = normalizeSignalReliabilityConfig(cfg)
|
||
|
|
c.mu.Lock()
|
||
|
|
cfg.EnableConfigured = true
|
||
|
|
c.signalReliableCfg = cfg
|
||
|
|
state := c.logicalSession
|
||
|
|
c.mu.Unlock()
|
||
|
|
if state != nil {
|
||
|
|
state.applySignalReliabilityConfig(cfg)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *ServerCommon) setSignalReliabilityConfig(cfg signalReliabilityConfig) {
|
||
|
|
cfg = normalizeSignalReliabilityConfig(cfg)
|
||
|
|
s.mu.Lock()
|
||
|
|
cfg.EnableConfigured = true
|
||
|
|
s.signalReliableCfg = cfg
|
||
|
|
state := s.logicalSession
|
||
|
|
s.mu.Unlock()
|
||
|
|
if state != nil {
|
||
|
|
state.applySignalReliabilityConfig(cfg)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *ClientCommon) applySignalReliabilityTransportDefault(enabled bool) {
|
||
|
|
cfg := c.getSignalReliabilityConfig()
|
||
|
|
if cfg.EnableConfigured {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
cfg.Enabled = enabled
|
||
|
|
c.mu.Lock()
|
||
|
|
c.signalReliableCfg = cfg
|
||
|
|
c.mu.Unlock()
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *ServerCommon) applySignalReliabilityTransportDefault(enabled bool) {
|
||
|
|
cfg := s.getSignalReliabilityConfig()
|
||
|
|
if cfg.EnableConfigured {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
cfg.Enabled = enabled
|
||
|
|
s.mu.Lock()
|
||
|
|
s.signalReliableCfg = cfg
|
||
|
|
s.mu.Unlock()
|
||
|
|
}
|
||
|
|
|
||
|
|
func requiresSignalReplyWait(msg TransferMsg) bool {
|
||
|
|
return msg.Type == MSG_SYNC_ASK || msg.Type == MSG_KEY_CHANGE || msg.Type == MSG_SYS_WAIT
|
||
|
|
}
|
||
|
|
|
||
|
|
func signalCanUseTransportAck(msg TransferMsg) bool {
|
||
|
|
return msg.ID != 0
|
||
|
|
}
|
||
|
|
|
||
|
|
func retryReliableSignalSend(cfg signalReliabilityConfig, send func(signalReliabilityConfig) error) error {
|
||
|
|
return retryReliableSignalSendWithAttempt(cfg, func(cfg signalReliabilityConfig, _ int) error {
|
||
|
|
return send(cfg)
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
func retryReliableSignalSendWithAttempt(cfg signalReliabilityConfig, send func(signalReliabilityConfig, int) error) error {
|
||
|
|
cfg = normalizeSignalReliabilityConfig(cfg)
|
||
|
|
var lastErr error
|
||
|
|
for attempt := 1; attempt <= cfg.SendRetry; attempt++ {
|
||
|
|
lastErr = send(cfg, attempt)
|
||
|
|
if lastErr == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return lastErr
|
||
|
|
}
|
||
|
|
|
||
|
|
func sendSignalWithAck(scope string, signalID uint64, timeout time.Duration, pool *signalAckPool, send func() error) error {
|
||
|
|
return sendSignalWithAckTracked(nil, scope, signalID, timeout, pool, send)
|
||
|
|
}
|
||
|
|
|
||
|
|
func sendSignalWithAckTracked(state *signalReliabilityState, scope string, signalID uint64, timeout time.Duration, pool *signalAckPool, send func() error) error {
|
||
|
|
if state != nil {
|
||
|
|
state.incAckWait()
|
||
|
|
}
|
||
|
|
wait := pool.prepare(scope, signalID)
|
||
|
|
if err := send(); err != nil {
|
||
|
|
wait.cancel()
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
err := pool.waitPrepared(wait, timeout)
|
||
|
|
if state == nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if err == nil {
|
||
|
|
state.incAckDeliver()
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
if errors.Is(err, errSignalAckTimeout) {
|
||
|
|
state.incAckTimeout()
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
if errors.Is(err, errSignalAckCanceled) {
|
||
|
|
state.incAckCanceled()
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *ClientCommon) sendSignalEnvelopeMaybeReliable(env Envelope, msg TransferMsg) error {
|
||
|
|
state := c.getSignalReliabilityState()
|
||
|
|
state.incSignalSend()
|
||
|
|
cfg := c.getSignalReliabilityConfig()
|
||
|
|
if !cfg.Enabled || !signalCanUseTransportAck(msg) {
|
||
|
|
return c.sendEnvelope(env)
|
||
|
|
}
|
||
|
|
state.incReliableSend()
|
||
|
|
return retryReliableSignalSendWithAttempt(cfg, func(cfg signalReliabilityConfig, attempt int) error {
|
||
|
|
if attempt > 1 {
|
||
|
|
state.incRetry()
|
||
|
|
}
|
||
|
|
return sendSignalWithAckTracked(state, clientFileScope(), env.ID, cfg.AckTimeout, c.getSignalAckPool(), func() error {
|
||
|
|
return c.sendEnvelope(env)
|
||
|
|
})
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *ServerCommon) sendSignalEnvelopeMaybeReliable(logical *LogicalConn, env Envelope, msg TransferMsg) error {
|
||
|
|
if logical == nil {
|
||
|
|
return s.sendSignalEnvelopeMaybeReliableTransport(nil, env, msg)
|
||
|
|
}
|
||
|
|
return s.sendSignalEnvelopeMaybeReliableTransport(s.resolveOutboundTransport(logical), env, msg)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *ServerCommon) sendSignalEnvelopeMaybeReliableTransport(transport *TransportConn, env Envelope, msg TransferMsg) error {
|
||
|
|
state := s.getSignalReliabilityState()
|
||
|
|
state.incSignalSend()
|
||
|
|
cfg := s.getSignalReliabilityConfig()
|
||
|
|
if !cfg.Enabled || !signalCanUseTransportAck(msg) {
|
||
|
|
return s.sendEnvelopeTransport(transport, env)
|
||
|
|
}
|
||
|
|
state.incReliableSend()
|
||
|
|
scope := serverTransportScopeForTransport(transport)
|
||
|
|
return retryReliableSignalSendWithAttempt(cfg, func(cfg signalReliabilityConfig, attempt int) error {
|
||
|
|
if attempt > 1 {
|
||
|
|
state.incRetry()
|
||
|
|
}
|
||
|
|
return sendSignalWithAckTracked(state, scope, env.ID, cfg.AckTimeout, s.getSignalAckPool(), func() error {
|
||
|
|
return s.sendEnvelopeTransport(transport, env)
|
||
|
|
})
|
||
|
|
})
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *ClientCommon) sendSignalAck(signalID uint64) error {
|
||
|
|
return c.sendEnvelope(newSignalAckEnvelope(signalID))
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *ServerCommon) sendSignalAck(logical *LogicalConn, signalID uint64) error {
|
||
|
|
if logical == nil {
|
||
|
|
return s.sendSignalAckTransport(nil, signalID)
|
||
|
|
}
|
||
|
|
return s.sendSignalAckTransport(s.resolveOutboundTransport(logical), signalID)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *ServerCommon) sendSignalAckTransport(transport *TransportConn, signalID uint64) error {
|
||
|
|
return s.sendEnvelopeTransport(transport, newSignalAckEnvelope(signalID))
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *ServerCommon) sendSignalAckInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, signalID uint64) error {
|
||
|
|
if conn == nil {
|
||
|
|
return s.sendSignalAckTransport(transport, signalID)
|
||
|
|
}
|
||
|
|
return s.sendEnvelopeInboundTransport(logical, transport, conn, newSignalAckEnvelope(signalID))
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *ClientCommon) handleSignalAckEnvelope(env Envelope) bool {
|
||
|
|
return c.getSignalAckPool().deliver(clientFileScope(), env.ID)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *ServerCommon) handleSignalAckEnvelope(logical *LogicalConn, env Envelope) bool {
|
||
|
|
if logical == nil {
|
||
|
|
return s.handleSignalAckEnvelopeTransport(nil, env)
|
||
|
|
}
|
||
|
|
return s.handleSignalAckEnvelopeTransport(s.resolveOutboundTransport(logical), env)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *ServerCommon) handleSignalAckEnvelopeTransport(transport *TransportConn, env Envelope) bool {
|
||
|
|
return s.getSignalAckPool().deliverAny(serverTransportDeliveryScopesForTransport(transport), env.ID)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *ClientCommon) handleReceivedSignalReliability(msg TransferMsg) bool {
|
||
|
|
cfg := c.getSignalReliabilityConfig()
|
||
|
|
if !cfg.Enabled || !signalCanUseTransportAck(msg) {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
state := c.getSignalReliabilityState()
|
||
|
|
duplicate := c.getReceivedSignalCache().seenOrRemember(clientFileScope(), msg.ID)
|
||
|
|
if duplicate {
|
||
|
|
state.incDuplicateRecv()
|
||
|
|
}
|
||
|
|
state.incAckSend()
|
||
|
|
if err := c.sendSignalAck(msg.ID); err != nil {
|
||
|
|
state.incAckSendError()
|
||
|
|
if c.showError || c.debugMode {
|
||
|
|
fmt.Println("client send signal ack error", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return duplicate
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *ServerCommon) handleReceivedSignalReliability(logical *LogicalConn, msg TransferMsg) bool {
|
||
|
|
return s.handleReceivedSignalReliabilityTransport(logical, s.resolveOutboundTransport(logical), nil, msg)
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *ServerCommon) handleReceivedSignalReliabilityTransport(logical *LogicalConn, transport *TransportConn, conn net.Conn, msg TransferMsg) bool {
|
||
|
|
cfg := s.getSignalReliabilityConfig()
|
||
|
|
if !cfg.Enabled || !signalCanUseTransportAck(msg) {
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
state := s.getSignalReliabilityState()
|
||
|
|
scope := serverFileScope(logical)
|
||
|
|
duplicate := s.getReceivedSignalCache().seenOrRemember(scope, msg.ID)
|
||
|
|
if duplicate {
|
||
|
|
state.incDuplicateRecv()
|
||
|
|
}
|
||
|
|
state.incAckSend()
|
||
|
|
if err := s.sendSignalAckInbound(logical, transport, conn, msg.ID); err != nil {
|
||
|
|
state.incAckSendError()
|
||
|
|
if s.showError || s.debugMode {
|
||
|
|
fmt.Println("server send signal ack error", err)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return duplicate
|
||
|
|
}
|