package notify

import (
	"context"
	"errors"
	"net"
	"testing"
	"time"
)

// TestRetryConnectSucceedsAfterRetries verifies that RetryConnect keeps calling
// the connect function while it fails and returns nil as soon as an attempt
// succeeds (here the third of four allowed attempts).
func TestRetryConnectSucceedsAfterRetries(t *testing.T) {
	var attempts int
	wantErr := errors.New("dial failed")
	err := RetryConnect(context.Background(), &ConnectRetryOptions{
		MaxAttempts: 4,
		BaseDelay:   time.Millisecond,
		MaxDelay:    2 * time.Millisecond,
	}, func(context.Context) error {
		attempts++
		if attempts < 3 {
			return wantErr
		}
		return nil
	})
	if err != nil {
		t.Fatalf("RetryConnect failed: %v", err)
	}
	if got, want := attempts, 3; got != want {
		t.Fatalf("attempts mismatch: got %d want %d", got, want)
	}
}

// TestRetryConnectReturnsLastError verifies that when every attempt fails,
// RetryConnect makes exactly MaxAttempts calls and returns an error matching
// the connect function's error under errors.Is.
func TestRetryConnectReturnsLastError(t *testing.T) {
	var attempts int
	wantErr := errors.New("connect failed")
	err := RetryConnect(context.Background(), &ConnectRetryOptions{
		MaxAttempts: 3,
		BaseDelay:   time.Millisecond,
		MaxDelay:    time.Millisecond,
	}, func(context.Context) error {
		attempts++
		return wantErr
	})
	if !errors.Is(err, wantErr) {
		t.Fatalf("RetryConnect error = %v, want %v", err, wantErr)
	}
	if got, want := attempts, 3; got != want {
		t.Fatalf("attempts mismatch: got %d want %d", got, want)
	}
}

// TestRetryConnectContextCanceled verifies that RetryConnect honours context
// cancellation. The connect function cancels the context on its first call,
// so RetryConnect must surface context.Canceled without a second attempt.
func TestRetryConnectContextCanceled(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	var attempts int
	err := RetryConnect(ctx, &ConnectRetryOptions{
		MaxAttempts: 3,
		BaseDelay:   100 * time.Millisecond,
		MaxDelay:    100 * time.Millisecond,
	}, func(context.Context) error {
		attempts++
		cancel()
		return errors.New("fail")
	})
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("RetryConnect error = %v, want context canceled", err)
	}
	if got, want := attempts, 1; got != want {
		t.Fatalf("attempts mismatch: got %d want %d", got, want)
	}
}

// TestConnectRetryRejectsNilInputs verifies that each retry entry point rejects
// a nil required argument with its sentinel error. The arguments checked are a
// nil connect function, a nil client, a nil dial function and a nil server.
func TestConnectRetryRejectsNilInputs(t *testing.T) {
	if err := RetryConnect(context.Background(), nil, nil); !errors.Is(err, errConnectRetryFnNil) {
		t.Fatalf("RetryConnect nil fn error = %v, want %v", err, errConnectRetryFnNil)
	}
	if err := ConnectClientWithRetry(context.Background(), nil, "tcp", "127.0.0.1:1", nil); !errors.Is(err, errConnectRetryClientNil) {
		t.Fatalf("ConnectClientWithRetry nil client error = %v, want %v", err, errConnectRetryClientNil)
	}
	if err := ConnectClientFactoryWithRetry(context.Background(), nil, nil, nil); !errors.Is(err, errConnectRetryClientNil) {
		t.Fatalf("ConnectClientFactoryWithRetry nil client error = %v, want %v", err, errConnectRetryClientNil)
	}
	if err := ConnectClientFactoryWithRetry(context.Background(), NewClient(), nil, nil); !errors.Is(err, errConnectRetryDialFnNil) {
		t.Fatalf("ConnectClientFactoryWithRetry nil dialFn error = %v, want %v", err, errConnectRetryDialFnNil)
	}
	if err := ListenServerWithRetry(context.Background(), nil, "tcp", "127.0.0.1:1", nil); !errors.Is(err, errConnectRetryServerNil) {
		t.Fatalf("ListenServerWithRetry nil server error = %v, want %v", err, errConnectRetryServerNil)
	}
}

// TestConnectRetryBackoffDelayCapped verifies the backoff schedule produced by
// connectRetryBackoffDelay. With BaseDelay=10ms and MaxDelay=30ms, the delays
// are 10ms and 20ms for the first two attempts, then stay capped at 30ms.
func TestConnectRetryBackoffDelayCapped(t *testing.T) {
	cfg := normalizeConnectRetryOptions(&ConnectRetryOptions{
		MaxAttempts: 5,
		BaseDelay:   10 * time.Millisecond,
		MaxDelay:    30 * time.Millisecond,
	})
	if got, want := connectRetryBackoffDelay(cfg, 1), 10*time.Millisecond; got != want {
		t.Fatalf("delay attempt1 mismatch: got %v want %v", got, want)
	}
	if got, want := connectRetryBackoffDelay(cfg, 2), 20*time.Millisecond; got != want {
		t.Fatalf("delay attempt2 mismatch: got %v want %v", got, want)
	}
	if got, want := connectRetryBackoffDelay(cfg, 3), 30*time.Millisecond; got != want {
		t.Fatalf("delay attempt3 mismatch: got %v want %v", got, want)
	}
	if got, want := connectRetryBackoffDelay(cfg, 4), 30*time.Millisecond; got != want {
		t.Fatalf("delay attempt4 mismatch: got %v want %v", got, want)
	}
}

// TestRetryConnectShouldRetryCanStopEarly verifies that RetryConnect stops
// retrying after the first attempt when ShouldRetry returns false for the
// error, and that it returns that error.
func TestRetryConnectShouldRetryCanStopEarly(t *testing.T) {
	var attempts int
	wantErr := errors.New("not retriable")
	err := RetryConnect(context.Background(), &ConnectRetryOptions{
		MaxAttempts: 5,
		BaseDelay:   time.Millisecond,
		MaxDelay:    2 * time.Millisecond,
		ShouldRetry: func(err error) bool {
			return !errors.Is(err, wantErr)
		},
	}, func(context.Context) error {
		attempts++
		return wantErr
	})
	if !errors.Is(err, wantErr) {
		t.Fatalf("RetryConnect error = %v, want %v", err, wantErr)
	}
	if got, want := attempts, 1; got != want {
		t.Fatalf("attempts mismatch: got %d want %d", got, want)
	}
}

// TestRetryConnectOnRetryHook verifies that OnRetry fires once per retry.
// With three failing attempts that means two events, none after the final
// attempt. Each event carries the failed attempt number, MaxAttempts, the
// attempt's error and the delay before the next attempt.
func TestRetryConnectOnRetryHook(t *testing.T) {
	var events []ConnectRetryEvent
	wantErr := errors.New("dial failed")
	err := RetryConnect(context.Background(), &ConnectRetryOptions{
		MaxAttempts: 3,
		BaseDelay:   time.Millisecond,
		MaxDelay:    2 * time.Millisecond,
		OnRetry: func(event ConnectRetryEvent) {
			events = append(events, event)
		},
	}, func(context.Context) error {
		return wantErr
	})
	if !errors.Is(err, wantErr) {
		t.Fatalf("RetryConnect error = %v, want %v", err, wantErr)
	}
	if got, want := len(events), 2; got != want {
		t.Fatalf("retry events mismatch: got %d want %d", got, want)
	}
	if got, want := events[0].Attempt, 1; got != want {
		t.Fatalf("event[0] attempt mismatch: got %d want %d", got, want)
	}
	if got, want := events[0].MaxAttempts, 3; got != want {
		t.Fatalf("event[0] max attempts mismatch: got %d want %d", got, want)
	}
	if !errors.Is(events[0].Err, wantErr) {
		t.Fatalf("event[0] err mismatch: got %v want %v", events[0].Err, wantErr)
	}
	if got, want := events[0].NextDelay, time.Millisecond; got != want {
		t.Fatalf("event[0] next delay mismatch: got %v want %v", got, want)
	}
	if got, want := events[1].Attempt, 2; got != want {
		t.Fatalf("event[1] attempt mismatch: got %d want %d", got, want)
	}
	if got, want := events[1].MaxAttempts, 3; got != want {
		t.Fatalf("event[1] max attempts mismatch: got %d want %d", got, want)
	}
	if !errors.Is(events[1].Err, wantErr) {
		t.Fatalf("event[1] err mismatch: got %v want %v", events[1].Err, wantErr)
	}
	if got, want := events[1].NextDelay, 2*time.Millisecond; got != want {
		t.Fatalf("event[1] next delay mismatch: got %v want %v", got, want)
	}
}

// TestConnectClientFactoryWithRetryRecoversFromFailedStart verifies that
// ConnectClientFactoryWithRetry redials after a failed client start. The
// stubbed key exchange fails on the first attempt only, so exactly two dials
// and two key exchanges happen and the client ends up alive. The test also
// checks the retry bookkeeping recorded in the client runtime snapshot.
func TestConnectClientFactoryWithRetryRecoversFromFailedStart(t *testing.T) {
	client := NewClient().(*ClientCommon)
	UseLegacySecurityClient(client)
	server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
		UseLegacySecurityServer(server)
	})

	wantErr := errors.New("key exchange failed on first attempt")
	keyExchangeAttempts := 0
	client.keyExchangeFn = func(Client) error {
		keyExchangeAttempts++
		if keyExchangeAttempts == 1 {
			return wantErr
		}
		return nil
	}

	dialAttempts := 0
	var peerConns []net.Conn
	// Close the server-side pipe ends even when an assertion aborts the test
	// early, so a failure does not leave them open. This is registered before
	// the client cleanup below and t.Cleanup runs LIFO, so the pipes are
	// closed after the client is stopped, matching the success path.
	t.Cleanup(func() {
		for _, conn := range peerConns {
			_ = conn.Close()
		}
	})
	dialFn := func(context.Context) (net.Conn, error) {
		dialAttempts++
		left, right := net.Pipe()
		peerConns = append(peerConns, right)
		bootstrapPeerAttachConnForTest(t, server, right)
		return left, nil
	}

	err := ConnectClientFactoryWithRetry(context.Background(), client, dialFn, &ConnectRetryOptions{
		MaxAttempts: 3,
		BaseDelay:   time.Millisecond,
		MaxDelay:    time.Millisecond,
	})
	if err != nil {
		t.Fatalf("ConnectClientFactoryWithRetry failed: %v", err)
	}

	// The client is connected from here on. Stop it during cleanup if a later
	// assertion aborts the test. The explicit stop at the end sets
	// clientStopped so the client is not stopped twice.
	clientStopped := false
	t.Cleanup(func() {
		if clientStopped {
			return
		}
		client.setByeFromServer(true)
		_ = client.Stop()
	})

	if got, want := dialAttempts, 2; got != want {
		t.Fatalf("dial attempts mismatch: got %d want %d", got, want)
	}
	if got, want := keyExchangeAttempts, 2; got != want {
		t.Fatalf("key exchange attempts mismatch: got %d want %d", got, want)
	}
	if status := client.Status(); !status.Alive {
		t.Fatalf("client should be alive after retry success: %+v", status)
	}

	runtimeSnapshot, err := GetClientRuntimeSnapshot(client)
	if err != nil {
		t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
	}
	if got, want := runtimeSnapshot.Retry.RetryEventTotal, uint64(1); got != want {
		t.Fatalf("client retry events mismatch: got %d want %d", got, want)
	}
	if got, want := runtimeSnapshot.Retry.LastRetryAttempt, 1; got != want {
		t.Fatalf("client last retry attempt mismatch: got %d want %d", got, want)
	}
	if got, want := runtimeSnapshot.Retry.LastRetryError, wantErr.Error(); got != want {
		t.Fatalf("client last retry error mismatch: got %q want %q", got, want)
	}
	if runtimeSnapshot.Retry.LastRetryAt.IsZero() {
		t.Fatal("client last retry time should be recorded")
	}
	if runtimeSnapshot.Retry.LastResultError != "" {
		t.Fatalf("client last result error should be empty on success, got %q", runtimeSnapshot.Retry.LastResultError)
	}
	if runtimeSnapshot.Retry.LastResultAt.IsZero() {
		t.Fatal("client last result time should be recorded")
	}

	client.setByeFromServer(true)
	clientStopped = true
	if err := client.Stop(); err != nil {
		t.Fatalf("client Stop failed: %v", err)
	}
}
func TestListenServerWithRetryRecoversFromFailedStart(t *testing.T) { server := NewServer().(*ServerCommon) var retryEvents []ConnectRetryEvent err := ListenServerWithRetry(context.Background(), server, "tcp", "127.0.0.1:0", &ConnectRetryOptions{ MaxAttempts: 3, BaseDelay: time.Millisecond, MaxDelay: time.Millisecond, OnRetry: func(event ConnectRetryEvent) { retryEvents = append(retryEvents, event) if event.Attempt == 1 { UseLegacySecurityServer(server) } }, }) if err != nil { t.Fatalf("ListenServerWithRetry failed: %v", err) } if status := server.Status(); !status.Alive { t.Fatalf("server should be alive after retry success: %+v", status) } if got := len(retryEvents); got < 1 { t.Fatal("OnRetry should be called at least once") } if got, want := retryEvents[0].Attempt, 1; got != want { t.Fatalf("retry event attempt mismatch: got %d want %d", got, want) } if !errors.Is(retryEvents[0].Err, errModernPSKRequired) { t.Fatalf("retry event err mismatch: got %v want %v", retryEvents[0].Err, errModernPSKRequired) } runtimeSnapshot, err := GetServerRuntimeSnapshot(server) if err != nil { t.Fatalf("GetServerRuntimeSnapshot failed: %v", err) } if got, want := runtimeSnapshot.Retry.RetryEventTotal, uint64(1); got != want { t.Fatalf("server retry events mismatch: got %d want %d", got, want) } if got, want := runtimeSnapshot.Retry.LastRetryAttempt, 1; got != want { t.Fatalf("server last retry attempt mismatch: got %d want %d", got, want) } if got, want := runtimeSnapshot.Retry.LastRetryError, errModernPSKRequired.Error(); got != want { t.Fatalf("server last retry error mismatch: got %q want %q", got, want) } if runtimeSnapshot.Retry.LastRetryAt.IsZero() { t.Fatal("server last retry time should be recorded") } if runtimeSnapshot.Retry.LastResultError != "" { t.Fatalf("server last result error should be empty on success, got %q", runtimeSnapshot.Retry.LastResultError) } if runtimeSnapshot.Retry.LastResultAt.IsZero() { t.Fatal("server last result time should be recorded") } if err 
:= server.Stop(); err != nil { t.Fatalf("server Stop failed: %v", err) } }