notify/connection_retry.go

258 lines
6.3 KiB
Go
Raw Normal View History

package notify
import (
"context"
"errors"
"net"
"time"
)
// Fallback retry policy used by DefaultConnectRetryOptions, and by
// normalizeConnectRetryOptions for any field the caller leaves non-positive.
const (
	// defaultConnectRetryAttempts is the total attempt budget, first try included.
	defaultConnectRetryAttempts = 3
	// defaultConnectRetryBase is the wait after the first failed attempt.
	defaultConnectRetryBase time.Duration = 200 * time.Millisecond
	// defaultConnectRetryMax caps the doubling backoff delay.
	defaultConnectRetryMax time.Duration = 2 * time.Second
)
type (
	// ConnectRetryOptions configures RetryConnect and the *WithRetry helpers.
	// When passed through RetryConnect, non-positive numeric fields fall back
	// to the package defaults and MaxDelay is raised to BaseDelay if smaller.
	ConnectRetryOptions struct {
		// MaxAttempts is the total number of attempts, including the first.
		MaxAttempts int
		// BaseDelay is the wait after the first failure; each later failure
		// doubles it, up to MaxDelay.
		BaseDelay, MaxDelay time.Duration
		// ShouldRetry, when non-nil, is consulted after every failure; a false
		// result stops retrying and returns that error immediately.
		ShouldRetry func(error) bool
		// OnRetry, when non-nil, is called before each backoff wait.
		OnRetry func(ConnectRetryEvent)
	}

	// ConnectRetryEvent describes a failed attempt that is about to be retried.
	ConnectRetryEvent struct {
		Attempt     int           // 1-based number of the attempt that just failed
		MaxAttempts int           // configured attempt budget
		Err         error         // error returned by the failed attempt
		NextDelay   time.Duration // wait before the next attempt starts
	}
)
// Argument-validation errors returned by the retry helpers.
var (
	errConnectRetryClientNil = errors.New("connect retry client is nil")
	errConnectRetryServerNil = errors.New("connect retry server is nil")
	errConnectRetryFnNil     = errors.New("connect retry fn is nil")
	errConnectRetryDialFnNil = errors.New("connect retry dialFn is nil")
)

// Errors returned by ReconnectClient and the client reconnect path.
var (
	errClientReconnectNil         = errors.New("client reconnect target is nil")
	errClientReconnectUnsupported = errors.New("client reconnect target type is unsupported")
	errClientReconnectActive      = errors.New("client reconnect requires an inactive session")
)
// DefaultConnectRetryOptions returns the baseline retry policy:
// defaultConnectRetryAttempts attempts, with backoff starting at
// defaultConnectRetryBase and capped at defaultConnectRetryMax. The
// ShouldRetry and OnRetry hooks are left nil.
func DefaultConnectRetryOptions() ConnectRetryOptions {
	var cfg ConnectRetryOptions
	cfg.MaxAttempts = defaultConnectRetryAttempts
	cfg.BaseDelay = defaultConnectRetryBase
	cfg.MaxDelay = defaultConnectRetryMax
	return cfg
}
// normalizeConnectRetryOptions overlays caller options on the defaults.
//
// A nil opts yields DefaultConnectRetryOptions() unchanged. Otherwise each of
// MaxAttempts, BaseDelay and MaxDelay is taken from opts only when positive.
// ShouldRetry and OnRetry are copied as-is, including nil. Finally, MaxDelay is
// raised to BaseDelay so the backoff cap is never below the starting delay.
func normalizeConnectRetryOptions(opts *ConnectRetryOptions) ConnectRetryOptions {
	cfg := DefaultConnectRetryOptions()
	if opts == nil {
		return cfg
	}
	user := *opts
	if user.MaxAttempts > 0 {
		cfg.MaxAttempts = user.MaxAttempts
	}
	if user.BaseDelay > 0 {
		cfg.BaseDelay = user.BaseDelay
	}
	if user.MaxDelay > 0 {
		cfg.MaxDelay = user.MaxDelay
	}
	cfg.ShouldRetry, cfg.OnRetry = user.ShouldRetry, user.OnRetry
	if cfg.BaseDelay > cfg.MaxDelay {
		cfg.MaxDelay = cfg.BaseDelay
	}
	return cfg
}
// RetryConnect calls fn until one of the following happens:
//   - fn succeeds;
//   - the attempt budget is exhausted;
//   - ShouldRetry rejects an error;
//   - ctx is done.
//
// opts may be nil and is normalized via normalizeConnectRetryOptions. A nil
// ctx is replaced with context.Background(). After each failure that will be
// retried, OnRetry (if set) is told about the failure, and then RetryConnect
// waits connectRetryBackoffDelay before the next attempt.
//
// It returns nil on success, ctx.Err() if the context ends before an attempt
// or during a wait, and otherwise the error from the last attempt made. A nil
// fn yields errConnectRetryFnNil.
func RetryConnect(ctx context.Context, opts *ConnectRetryOptions, fn func(context.Context) error) error {
	if fn == nil {
		return errConnectRetryFnNil
	}
	if ctx == nil {
		ctx = context.Background()
	}
	cfg := normalizeConnectRetryOptions(opts)
	retryable := func(err error) bool {
		return cfg.ShouldRetry == nil || cfg.ShouldRetry(err)
	}

	var err error
	for attempt := 1; attempt <= cfg.MaxAttempts; attempt++ {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
		if err = fn(ctx); err == nil {
			return nil
		}
		// ShouldRetry is consulted even on the final attempt, before the budget check.
		if !retryable(err) || attempt >= cfg.MaxAttempts {
			return err
		}

		delay := connectRetryBackoffDelay(cfg, attempt)
		if hook := cfg.OnRetry; hook != nil {
			hook(ConnectRetryEvent{
				Attempt:     attempt,
				MaxAttempts: cfg.MaxAttempts,
				Err:         err,
				NextDelay:   delay,
			})
		}
		if waitErr := waitConnectRetryDelay(ctx, delay); waitErr != nil {
			return waitErr
		}
	}
	return err
}
// ConnectClientWithRetry connects client to network/addr via client.Connect,
// retrying according to opts (see RetryConnect).
//
// If client also implements connectionRetryRecorder, opts is wrapped with
// wrapConnectRetryOptionsWithRecorder. The final result is then reported
// through recordConnectionRetryResult. A nil client yields
// errConnectRetryClientNil.
func ConnectClientWithRetry(ctx context.Context, client Client, network string, addr string, opts *ConnectRetryOptions) error {
	if client == nil {
		return errConnectRetryClientNil
	}
	var recorder connectionRetryRecorder
	if r, ok := any(client).(connectionRetryRecorder); ok {
		recorder = r
	}
	dial := func(context.Context) error {
		// Connect takes no context; ctx is only observed by RetryConnect between attempts.
		return client.Connect(network, addr)
	}
	result := RetryConnect(ctx, wrapConnectRetryOptionsWithRecorder(opts, recorder), dial)
	if recorder != nil {
		recorder.recordConnectionRetryResult(result)
	}
	return result
}
// ConnectClientFactoryWithRetry connects client through
// client.ConnectByFactory using dialFn, retrying according to opts (see
// RetryConnect). The context RetryConnect hands to each attempt is forwarded
// to ConnectByFactory.
//
// If client implements connectionRetryRecorder, opts is wrapped with it, and
// the final result is recorded. It returns errConnectRetryClientNil for a nil
// client and errConnectRetryDialFnNil for a nil dialFn.
func ConnectClientFactoryWithRetry(ctx context.Context, client Client, dialFn func(context.Context) (net.Conn, error), opts *ConnectRetryOptions) error {
	switch {
	case client == nil:
		return errConnectRetryClientNil
	case dialFn == nil:
		return errConnectRetryDialFnNil
	}
	var recorder connectionRetryRecorder
	if r, ok := any(client).(connectionRetryRecorder); ok {
		recorder = r
	}
	connect := func(attemptCtx context.Context) error {
		return client.ConnectByFactory(attemptCtx, dialFn)
	}
	result := RetryConnect(ctx, wrapConnectRetryOptionsWithRecorder(opts, recorder), connect)
	if recorder != nil {
		recorder.recordConnectionRetryResult(result)
	}
	return result
}
// clientReconnecter is the internal hook ReconnectClient looks for on a
// Client; *ClientCommon's reconnect method has this signature.
type clientReconnecter interface {
	reconnect(context.Context) error
}

// ReconnectClient performs a single reconnect of client by calling its
// reconnect method. It returns errClientReconnectNil for a nil client and
// errClientReconnectUnsupported when client does not provide reconnect.
func ReconnectClient(ctx context.Context, client Client) error {
	if client == nil {
		return errClientReconnectNil
	}
	switch target := any(client).(type) {
	case clientReconnecter:
		return target.reconnect(ctx)
	default:
		return errClientReconnectUnsupported
	}
}
// ReconnectClientWithRetry repeats ReconnectClient according to opts (see
// RetryConnect). Each attempt receives the context RetryConnect passes in.
//
// If client implements connectionRetryRecorder, opts is wrapped with it, and
// the final result is recorded. A nil client yields errConnectRetryClientNil.
func ReconnectClientWithRetry(ctx context.Context, client Client, opts *ConnectRetryOptions) error {
	if client == nil {
		return errConnectRetryClientNil
	}
	var recorder connectionRetryRecorder
	if r, ok := any(client).(connectionRetryRecorder); ok {
		recorder = r
	}
	reconnectOnce := func(attemptCtx context.Context) error {
		return ReconnectClient(attemptCtx, client)
	}
	result := RetryConnect(ctx, wrapConnectRetryOptionsWithRecorder(opts, recorder), reconnectOnce)
	if recorder != nil {
		recorder.recordConnectionRetryResult(result)
	}
	return result
}
// ListenServerWithRetry starts server listening on network/addr via
// server.Listen, retrying according to opts (see RetryConnect).
//
// If server implements connectionRetryRecorder, opts is wrapped with it, and
// the final result is recorded. A nil server yields errConnectRetryServerNil.
func ListenServerWithRetry(ctx context.Context, server Server, network string, addr string, opts *ConnectRetryOptions) error {
	if server == nil {
		return errConnectRetryServerNil
	}
	var recorder connectionRetryRecorder
	if r, ok := any(server).(connectionRetryRecorder); ok {
		recorder = r
	}
	listen := func(context.Context) error {
		// Listen takes no context; ctx is only observed by RetryConnect between attempts.
		return server.Listen(network, addr)
	}
	result := RetryConnect(ctx, wrapConnectRetryOptionsWithRecorder(opts, recorder), listen)
	if recorder != nil {
		recorder.recordConnectionRetryResult(result)
	}
	return result
}
// reconnect rebuilds the client's transport from the snapshot returned by
// clientConnectSourceSnapshot. Its signature matches clientReconnecter, which
// is how ReconnectClient dispatches to it.
//
// It returns:
//   - errClientReconnectNil for a nil receiver;
//   - errClientReconnectActive while sessionIsAlive reports the session alive;
//   - errClientReconnectSourceUnavailable when there is no snapshot or the
//     snapshot cannot reconnect.
//
// Otherwise it runs these steps in order:
//  1. open a connect attempt with beginClientConnectAttempt;
//  2. validate the security configuration;
//  3. close the current client transport;
//  4. apply the signal-reliability transport default for source.isUDP();
//  5. dial a new conn from source;
//  6. start the client on that conn.
//
// Errors from any step are returned unchanged.
func (c *ClientCommon) reconnect(ctx context.Context) error {
	if c == nil {
		return errClientReconnectNil
	}
	// Reconnecting is refused while the current session is still alive.
	if sessionIsAlive(&c.alive) {
		return errClientReconnectActive
	}
	// NOTE(review): presumably the source recorded by an earlier connect —
	// confirm against clientConnectSourceSnapshot.
	source := c.clientConnectSourceSnapshot()
	if source == nil || !source.canReconnect() {
		return errClientReconnectSourceUnavailable
	}
	finish, err := c.beginClientConnectAttempt()
	if err != nil {
		return err
	}
	// finish runs on every exit path below. started tells it whether the
	// client was successfully started on the new conn.
	started := false
	defer func() {
		finish(started)
	}()
	if err := c.validateSecurityConfiguration(); err != nil {
		return err
	}
	// Close the current client transport before dialing its replacement.
	c.closeClientTransport()
	c.applySignalReliabilityTransportDefault(source.isUDP())
	conn, err := source.dial(ctx)
	if err != nil {
		return err
	}
	// Guard against a dialer that reports success without a connection.
	if conn == nil {
		return errors.New("conn is nil")
	}
	// NOTE(review): conn is not closed here if startClientWithConnSource
	// fails. Confirm that startClientWithConnSource takes ownership of conn on
	// error; otherwise the connection leaks.
	if err := c.startClientWithConnSource(conn, source); err != nil {
		return err
	}
	started = true
	return nil
}
// connectRetryBackoffDelay returns the wait after the failedAttempt-th attempt
// fails. The result is cfg.BaseDelay doubled once per earlier failure and
// capped at cfg.MaxDelay.
//
// A non-positive BaseDelay disables waiting and returns 0. A failedAttempt
// below 2 yields BaseDelay, still capped at MaxDelay.
func connectRetryBackoffDelay(cfg ConnectRetryOptions, failedAttempt int) time.Duration {
	if cfg.BaseDelay <= 0 {
		return 0
	}
	ceiling := cfg.MaxDelay
	next := cfg.BaseDelay
	for remaining := failedAttempt; remaining > 1; remaining-- {
		// Once the delay has reached half the cap, return the cap directly.
		// Checking before doubling also keeps next*2 from overflowing
		// time.Duration.
		if next >= ceiling/2 {
			return ceiling
		}
		next *= 2
	}
	if next > ceiling {
		next = ceiling
	}
	return next
}
// waitConnectRetryDelay blocks until delay elapses or ctx is done, whichever
// happens first.
//
// It returns nil when the delay elapses and ctx.Err() when the context ends
// first. A non-positive delay returns nil immediately without consulting ctx.
func waitConnectRetryDelay(ctx context.Context, delay time.Duration) error {
	if delay > 0 {
		timer := time.NewTimer(delay)
		defer timer.Stop()
		select {
		case <-timer.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	return nil
}