package notify

import (
	"context"
	"errors"
	"net"
	"time"
)

// Defaults applied by normalizeConnectRetryOptions when a caller leaves the
// corresponding ConnectRetryOptions field zero or negative.
const (
	defaultConnectRetryAttempts = 3
	defaultConnectRetryBase     = 200 * time.Millisecond
	defaultConnectRetryMax      = 2 * time.Second
)

// ConnectRetryOptions configures RetryConnect and the *WithRetry helpers.
//
// Zero or negative MaxAttempts, BaseDelay and MaxDelay fall back to the values
// returned by DefaultConnectRetryOptions. If the resulting MaxDelay is smaller
// than BaseDelay it is raised to BaseDelay.
type ConnectRetryOptions struct {
	// MaxAttempts is the total number of attempts, including the first one.
	MaxAttempts int

	// BaseDelay is the wait after the first failed attempt; each later wait
	// doubles, capped at MaxDelay.
	BaseDelay time.Duration

	// MaxDelay caps the wait between attempts.
	MaxDelay time.Duration

	// ShouldRetry, if non-nil, is consulted after every failed attempt;
	// returning false stops retrying and returns that error immediately.
	ShouldRetry func(error) bool

	// OnRetry, if non-nil, is called after a failed attempt that is going to
	// be retried, before waiting NextDelay.
	OnRetry func(ConnectRetryEvent)
}

// ConnectRetryEvent describes a failed attempt that is about to be retried.
type ConnectRetryEvent struct {
	// Attempt is the 1-based number of the attempt that just failed.
	Attempt int

	// MaxAttempts is the configured attempt limit.
	MaxAttempts int

	// Err is the error returned by the failed attempt.
	Err error

	// NextDelay is how long RetryConnect waits before the next attempt.
	NextDelay time.Duration
}

var (
	errConnectRetryClientNil      = errors.New("connect retry client is nil")
	errConnectRetryServerNil      = errors.New("connect retry server is nil")
	errConnectRetryFnNil          = errors.New("connect retry fn is nil")
	errConnectRetryDialFnNil      = errors.New("connect retry dialFn is nil")
	errClientReconnectNil         = errors.New("client reconnect target is nil")
	errClientReconnectUnsupported = errors.New("client reconnect target type is unsupported")
	errClientReconnectActive      = errors.New("client reconnect requires an inactive session")
)

// DefaultConnectRetryOptions returns the retry policy used when no options are
// supplied: 3 attempts, with backoff starting at 200ms and capped at 2s.
func DefaultConnectRetryOptions() ConnectRetryOptions {
	return ConnectRetryOptions{
		MaxAttempts: defaultConnectRetryAttempts,
		BaseDelay:   defaultConnectRetryBase,
		MaxDelay:    defaultConnectRetryMax,
	}
}

// normalizeConnectRetryOptions merges opts over the defaults. Non-positive
// numeric fields keep their default; the callbacks are copied as-is. A nil
// opts yields the defaults unchanged.
func normalizeConnectRetryOptions(opts *ConnectRetryOptions) ConnectRetryOptions {
	cfg := DefaultConnectRetryOptions()
	if opts == nil {
		return cfg
	}
	if opts.MaxAttempts > 0 {
		cfg.MaxAttempts = opts.MaxAttempts
	}
	if opts.BaseDelay > 0 {
		cfg.BaseDelay = opts.BaseDelay
	}
	if opts.MaxDelay > 0 {
		cfg.MaxDelay = opts.MaxDelay
	}
	cfg.ShouldRetry = opts.ShouldRetry
	cfg.OnRetry = opts.OnRetry
	// Keep the cap meaningful: a cap below the starting delay would make every
	// wait shorter than the configured base.
	if cfg.MaxDelay < cfg.BaseDelay {
		cfg.MaxDelay = cfg.BaseDelay
	}
	return cfg
}

// RetryConnect calls fn until it succeeds, the attempt budget is exhausted,
// ShouldRetry rejects an error, or ctx is done.
//
// A nil opts uses DefaultConnectRetryOptions and a nil ctx is treated as
// context.Background. Between attempts it waits with exponential backoff (see
// connectRetryBackoffDelay). It returns nil on success, ctx.Err() if the
// context ends before or between attempts, and otherwise the error from the
// last attempt made.
func RetryConnect(ctx context.Context, opts *ConnectRetryOptions, fn func(context.Context) error) error {
	if fn == nil {
		return errConnectRetryFnNil
	}
	if ctx == nil {
		ctx = context.Background()
	}
	cfg := normalizeConnectRetryOptions(opts)

	var lastErr error
	for attempt := 1; attempt <= cfg.MaxAttempts; attempt++ {
		if err := ctx.Err(); err != nil {
			return err
		}

		lastErr = fn(ctx)
		if lastErr == nil {
			return nil
		}
		if cfg.ShouldRetry != nil && !cfg.ShouldRetry(lastErr) {
			return lastErr
		}
		if attempt >= cfg.MaxAttempts {
			break
		}

		// If ctx ended while the attempt was running, no retry will happen.
		// Stop here instead of reporting a retry through OnRetry that never
		// runs; the return value is the same ctx.Err() the wait would yield.
		if err := ctx.Err(); err != nil {
			return err
		}

		delay := connectRetryBackoffDelay(cfg, attempt)
		if cfg.OnRetry != nil {
			cfg.OnRetry(ConnectRetryEvent{
				Attempt:     attempt,
				MaxAttempts: cfg.MaxAttempts,
				Err:         lastErr,
				NextDelay:   delay,
			})
		}
		if err := waitConnectRetryDelay(ctx, delay); err != nil {
			return err
		}
	}
	return lastErr
}

// retryWithConnectionRecorder runs fn under RetryConnect on behalf of target.
//
// If target implements connectionRetryRecorder, opts are passed through
// wrapConnectRetryOptionsWithRecorder, and the final result is reported to
// recordConnectionRetryResult. Otherwise the wrap helper receives a nil
// recorder, exactly as the individual *WithRetry helpers did before.
func retryWithConnectionRecorder(ctx context.Context, target any, opts *ConnectRetryOptions, fn func(context.Context) error) error {
	recorder, _ := target.(connectionRetryRecorder)
	retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder)
	err := RetryConnect(ctx, retryOpts, fn)
	if recorder != nil {
		recorder.recordConnectionRetryResult(err)
	}
	return err
}

// ConnectClientWithRetry calls client.Connect(network, addr) under
// RetryConnect with the given options.
//
// client.Connect takes no context, so ctx is checked before each attempt and
// during backoff waits but cannot interrupt an in-flight Connect call.
func ConnectClientWithRetry(ctx context.Context, client Client, network string, addr string, opts *ConnectRetryOptions) error {
	if client == nil {
		return errConnectRetryClientNil
	}
	return retryWithConnectionRecorder(ctx, client, opts, func(context.Context) error {
		return client.Connect(network, addr)
	})
}

// ConnectClientFactoryWithRetry calls client.ConnectByFactory(ctx, dialFn)
// under RetryConnect with the given options. The retry context is forwarded to
// every attempt.
func ConnectClientFactoryWithRetry(ctx context.Context, client Client, dialFn func(context.Context) (net.Conn, error), opts *ConnectRetryOptions) error {
	if client == nil {
		return errConnectRetryClientNil
	}
	if dialFn == nil {
		return errConnectRetryDialFnNil
	}
	return retryWithConnectionRecorder(ctx, client, opts, func(ctx context.Context) error {
		return client.ConnectByFactory(ctx, dialFn)
	})
}

// clientReconnecter is implemented by clients that can re-establish a
// transport on their own (for example *ClientCommon, below).
type clientReconnecter interface {
	reconnect(context.Context) error
}

// ReconnectClient performs a single reconnect attempt on client.
//
// client must implement clientReconnecter; otherwise
// errClientReconnectUnsupported is returned. A nil ctx is treated as
// context.Background, matching RetryConnect.
func ReconnectClient(ctx context.Context, client Client) error {
	if client == nil {
		return errClientReconnectNil
	}
	reconnecter, ok := any(client).(clientReconnecter)
	if !ok {
		return errClientReconnectUnsupported
	}
	if ctx == nil {
		ctx = context.Background()
	}
	return reconnecter.reconnect(ctx)
}

// ReconnectClientWithRetry calls ReconnectClient under RetryConnect with the
// given options.
func ReconnectClientWithRetry(ctx context.Context, client Client, opts *ConnectRetryOptions) error {
	if client == nil {
		return errConnectRetryClientNil
	}
	return retryWithConnectionRecorder(ctx, client, opts, func(ctx context.Context) error {
		return ReconnectClient(ctx, client)
	})
}

// ListenServerWithRetry calls server.Listen(network, addr) under RetryConnect
// with the given options.
//
// server.Listen takes no context, so ctx is checked before each attempt and
// during backoff waits but cannot interrupt an in-flight Listen call.
func ListenServerWithRetry(ctx context.Context, server Server, network string, addr string, opts *ConnectRetryOptions) error {
	if server == nil {
		return errConnectRetryServerNil
	}
	return retryWithConnectionRecorder(ctx, server, opts, func(context.Context) error {
		return server.Listen(network, addr)
	})
}

// reconnect re-establishes c's transport from the source returned by
// clientConnectSourceSnapshot.
//
// It refuses while the session is still alive (errClientReconnectActive) or
// when no reconnectable source is available
// (errClientReconnectSourceUnavailable). Otherwise it begins a connect
// attempt, validates the security configuration, closes the old transport,
// dials a fresh connection from the source and starts the client on it. The
// deferred finish call is told whether the client was started.
func (c *ClientCommon) reconnect(ctx context.Context) error {
	if c == nil {
		return errClientReconnectNil
	}
	if sessionIsAlive(&c.alive) {
		return errClientReconnectActive
	}
	source := c.clientConnectSourceSnapshot()
	if source == nil || !source.canReconnect() {
		return errClientReconnectSourceUnavailable
	}
	finish, err := c.beginClientConnectAttempt()
	if err != nil {
		return err
	}
	started := false
	defer func() { finish(started) }()

	if err := c.validateSecurityConfiguration(); err != nil {
		return err
	}
	c.closeClientTransport()
	c.applySignalReliabilityTransportDefault(source.isUDP())

	conn, err := source.dial(ctx)
	if err != nil {
		return err
	}
	if conn == nil {
		return errors.New("conn is nil")
	}
	// NOTE(review): conn is not closed here if startClientWithConnSource
	// fails; confirm that function takes ownership of conn on error.
	if err := c.startClientWithConnSource(conn, source); err != nil {
		return err
	}
	started = true
	return nil
}

// connectRetryBackoffDelay returns the wait after the failedAttempt-th failed
// attempt: BaseDelay doubled (failedAttempt-1) times, capped at MaxDelay. It
// returns 0 when BaseDelay is not positive. The cap is checked before each
// doubling, so the multiplication cannot overflow time.Duration.
func connectRetryBackoffDelay(cfg ConnectRetryOptions, failedAttempt int) time.Duration {
	delay := cfg.BaseDelay
	if delay <= 0 {
		return 0
	}
	for i := 1; i < failedAttempt; i++ {
		if delay >= cfg.MaxDelay/2 {
			return cfg.MaxDelay
		}
		delay *= 2
	}
	if delay > cfg.MaxDelay {
		return cfg.MaxDelay
	}
	return delay
}

// waitConnectRetryDelay blocks for delay or until ctx is done, whichever comes
// first, returning ctx.Err() in the latter case. A non-positive delay returns
// nil immediately.
func waitConnectRetryDelay(ctx context.Context, delay time.Duration) error {
	if delay <= 0 {
		return nil
	}
	timer := time.NewTimer(delay)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}