package starnet

import (
	"context"
	"errors"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync/atomic"
	"testing"
	"time"
)

// TestWithRetrySmokeGet checks that a GET issued through Get with the
// WithRetry option is retried after two 503 responses and returns the 200
// produced by the third attempt.
func TestWithRetrySmokeGet(t *testing.T) {
	var hits int32
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		n := atomic.AddInt32(&hits, 1)
		if n <= 2 {
			w.WriteHeader(http.StatusServiceUnavailable)
			return
		}
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok")) // best-effort body; the test only inspects the status
	}))
	defer s.Close()

	resp, err := Get(s.URL,
		WithRetry(2,
			WithRetryBackoff(0, 0, 1),
			WithRetryJitter(0),
		),
	)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
	}
	// Report the atomically loaded value; printing hits directly would be a
	// plain read of a variable that is otherwise accessed atomically.
	if h := atomic.LoadInt32(&hits); h != 3 {
		t.Fatalf("hits=%d want=3", h)
	}
}

// TestWithRetryResponseRequestPointerStable checks that after a retried GET
// succeeds, the response's Request() returns the very request object the
// caller built and invoked Do on.
func TestWithRetryResponseRequestPointerStable(t *testing.T) {
	var hits int32
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		n := atomic.AddInt32(&hits, 1)
		if n == 1 {
			w.WriteHeader(http.StatusServiceUnavailable)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	req := NewSimpleRequest(s.URL, http.MethodGet).
		SetRetry(1, WithRetryBackoff(0, 0, 1), WithRetryJitter(0))

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Close()

	// Make sure a retry really happened; otherwise the pointer check below
	// would pass trivially on a single-attempt path.
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
	}
	if h := atomic.LoadInt32(&hits); h != 2 {
		t.Fatalf("hits=%d want=2", h)
	}
	if resp.Request() != req {
		t.Fatal("response request pointer should point to original request")
	}
}

// TestWithRetryNoRetryForNonReplayableBodyReader checks that a POST whose
// body was supplied via SetBodyReader is sent only once, even though
// retries are enabled for non-idempotent methods. The 503 from that single
// attempt is returned to the caller as a response, not as an error.
func TestWithRetryNoRetryForNonReplayableBodyReader(t *testing.T) {
	var hits int32
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&hits, 1)
		w.WriteHeader(http.StatusServiceUnavailable)
	}))
	defer s.Close()

	req := NewSimpleRequest(s.URL, http.MethodPost).
		SetBodyReader(strings.NewReader("payload")).
		SetRetry(3,
			WithRetryIdempotentOnly(false),
			WithRetryBackoff(0, 0, 1),
			WithRetryJitter(0),
		)

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Close()

	if resp.StatusCode != http.StatusServiceUnavailable {
		t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusServiceUnavailable)
	}
	if h := atomic.LoadInt32(&hits); h != 1 {
		t.Fatalf("hits=%d want=1", h)
	}
}

// TestWithRetryPostWhenIdempotentDisabled checks that a POST with a string
// body is retried once after a 503 when WithRetryIdempotentOnly(false) is
// set, and that the second attempt's 200 is returned.
func TestWithRetryPostWhenIdempotentDisabled(t *testing.T) {
	var hits int32
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		n := atomic.AddInt32(&hits, 1)
		_, _ = io.Copy(io.Discard, r.Body) // consume the body on every attempt
		if n == 1 {
			w.WriteHeader(http.StatusServiceUnavailable)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	req := NewSimpleRequest(s.URL, http.MethodPost).
		SetBodyString("hello").
		SetRetry(1,
			WithRetryIdempotentOnly(false),
			WithRetryBackoff(0, 0, 1),
			WithRetryJitter(0),
		)

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
	}
	if h := atomic.LoadInt32(&hits); h != 2 {
		t.Fatalf("hits=%d want=2", h)
	}
}

// TestWithRetryRawWithoutGetBodyNoRetry checks that a request wrapped via
// WithRawRequest whose *http.Request has no GetBody is not retried, and
// that the single 503 is returned to the caller.
func TestWithRetryRawWithoutGetBodyNoRetry(t *testing.T) {
	var hits int32
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&hits, 1)
		w.WriteHeader(http.StatusServiceUnavailable)
	}))
	defer s.Close()

	// http.NewRequest only fills in GetBody for *bytes.Buffer, *bytes.Reader
	// and *strings.Reader; wrapping the reader in io.MultiReader hides its
	// concrete type so GetBody stays nil.
	rawReq, err := http.NewRequest(http.MethodPost, s.URL, io.MultiReader(strings.NewReader("raw")))
	if err != nil {
		t.Fatalf("building raw request: %v", err)
	}
	if rawReq.GetBody != nil {
		t.Fatal("raw request GetBody should be nil in this test")
	}

	req := NewSimpleRequest("", http.MethodPost, WithRawRequest(rawReq)).
		SetRetry(2,
			WithRetryIdempotentOnly(false),
			WithRetryBackoff(0, 0, 1),
			WithRetryJitter(0),
		)

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Close()

	if resp.StatusCode != http.StatusServiceUnavailable {
		t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusServiceUnavailable)
	}
	if h := atomic.LoadInt32(&hits); h != 1 {
		t.Fatalf("hits=%d want=1", h)
	}
}

// TestWithRetryRespectsTotalTimeoutBudget checks that SetTimeout bounds the
// whole retry sequence. Each attempt takes about 80ms against a 120ms
// budget, so Do must fail with context.DeadlineExceeded after at most two
// attempts instead of running all four.
func TestWithRetryRespectsTotalTimeoutBudget(t *testing.T) {
	var hits int32
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt32(&hits, 1)
		// Simulate a slow upstream, but stop as soon as the client goes away
		// so that s.Close (which waits for outstanding handlers) is not held
		// up by the full delay.
		select {
		case <-time.After(80 * time.Millisecond):
		case <-r.Context().Done():
			return
		}
		w.WriteHeader(http.StatusServiceUnavailable)
	}))
	defer s.Close()

	req := NewSimpleRequest(s.URL, http.MethodGet).
		SetTimeout(120*time.Millisecond).
		SetRetry(3,
			WithRetryBackoff(0, 0, 1),
			WithRetryJitter(0),
		)

	_, err := req.Do()
	if err == nil {
		t.Fatal("expected timeout error")
	}
	if !errors.Is(err, context.DeadlineExceeded) {
		t.Fatalf("expected context deadline exceeded, got: %v", err)
	}
	if h := atomic.LoadInt32(&hits); h > 2 {
		t.Fatalf("hits=%d want<=2 under tight timeout budget", h)
	}
}

// TestSetRetryInvalidMax checks that a negative retry count is recorded as
// a request error.
func TestSetRetryInvalidMax(t *testing.T) {
	req := NewSimpleRequest("http://example.com", http.MethodGet).SetRetry(-1)
	if req.Err() == nil {
		t.Fatal("expected error for negative retry max")
	}
}

// TestSetRetrySeriesSmokeGet exercises the chained SetRetry* setters. With
// 429 configured as a retryable status, two 429 responses are retried and
// the third attempt's 200 is returned.
func TestSetRetrySeriesSmokeGet(t *testing.T) {
	var hits int32
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		n := atomic.AddInt32(&hits, 1)
		if n <= 2 {
			w.WriteHeader(http.StatusTooManyRequests)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	req := NewSimpleRequest(s.URL, http.MethodGet).
		SetRetry(2).
		SetRetryBackoff(0, 0, 1).
		SetRetryJitter(0).
		SetRetryStatuses(http.StatusTooManyRequests)

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
	}
	if h := atomic.LoadInt32(&hits); h != 3 {
		t.Fatalf("hits=%d want=3", h)
	}
}

// TestSetRetryIdempotentOnlyWithPost checks that SetRetryIdempotentOnly(false)
// lets a POST with a string body be retried once after a 503.
func TestSetRetryIdempotentOnlyWithPost(t *testing.T) {
	var hits int32
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		n := atomic.AddInt32(&hits, 1)
		_, _ = io.Copy(io.Discard, r.Body) // consume the body on every attempt
		if n == 1 {
			w.WriteHeader(http.StatusServiceUnavailable)
			return
		}
		w.WriteHeader(http.StatusOK)
	}))
	defer s.Close()

	req := NewSimpleRequest(s.URL, http.MethodPost).
		SetBodyString("hello").
		SetRetry(1).
		SetRetryIdempotentOnly(false).
		SetRetryBackoff(0, 0, 1).
		SetRetryJitter(0)

	resp, err := req.Do()
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	defer resp.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status=%d want=%d", resp.StatusCode, http.StatusOK)
	}
	if h := atomic.LoadInt32(&hits); h != 2 {
		t.Fatalf("hits=%d want=2", h)
	}
}

// TestSetRetryOnErrorOverridesDefault checks that a caller-supplied
// retry-on-error predicate is used. With a dial function that always fails,
// a predicate that always returns true, and SetRetry(1), Do dials exactly
// twice and still returns an error.
func TestSetRetryOnErrorOverridesDefault(t *testing.T) {
	var dials int32
	dialErr := errors.New("dial failed")

	req := NewSimpleRequest("http://example.com", http.MethodGet).
		SetDialFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
			atomic.AddInt32(&dials, 1)
			return nil, dialErr
		}).
		SetRetry(1).
		SetRetryBackoff(0, 0, 1).
		SetRetryJitter(0).
		SetRetryOnError(func(err error) bool { return true })

	_, err := req.Do()
	if err == nil {
		t.Fatal("expected error")
	}
	if h := atomic.LoadInt32(&dials); h != 2 {
		t.Fatalf("dial attempts=%d want=2", h)
	}
}

// TestSetRetryOptionRequireEnableRetry checks that calling a retry option
// setter before SetRetry records an error telling the caller to call
// SetRetry first.
func TestSetRetryOptionRequireEnableRetry(t *testing.T) {
	req := NewSimpleRequest("http://example.com", http.MethodGet).
		SetRetryBackoff(10*time.Millisecond, 100*time.Millisecond, 2)

	err := req.Err()
	if err == nil {
		t.Fatal("expected error when setting retry options before SetRetry")
	}
	if !strings.Contains(err.Error(), "call SetRetry first") {
		t.Fatalf("unexpected error: %v", err)
	}
}