diff --git a/curl.go b/curl.go index 5270d23..56c96a3 100644 --- a/curl.go +++ b/curl.go @@ -40,6 +40,7 @@ func (r *Request) Clone() *Request { bodyFileData: CloneFiles(r.bodyFileData), queries: CloneStringMapSlice(r.queries), bodyDataBytes: CloneByteSlice(r.bodyDataBytes), + customTransport: r.customTransport, proxy: r.proxy, timeout: r.timeout, dialTimeout: r.dialTimeout, diff --git a/curl_default.go b/curl_default.go index 5c14c6c..55a48ae 100644 --- a/curl_default.go +++ b/curl_default.go @@ -22,6 +22,7 @@ var ( DefaultDialTimeout = 5 * time.Second DefaultTimeout = 10 * time.Second DefaultFetchRespBody = false + DefaultHttpClient = NewHttpClientNoErr() ) func UrlEncodeRaw(str string) string { @@ -59,39 +60,39 @@ func BuildPostForm(queryMap map[string]string) []byte { } func Get(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequest(uri, "GET", opts...).Do() + return NewSimpleRequestWithClient(DefaultHttpClient, uri, "GET", opts...).Do() } func Post(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequest(uri, "POST", opts...).Do() + return NewSimpleRequestWithClient(DefaultHttpClient, uri, "POST", opts...).Do() } func Options(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequest(uri, "OPTIONS", opts...).Do() + return NewSimpleRequestWithClient(DefaultHttpClient, uri, "OPTIONS", opts...).Do() } func Put(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequest(uri, "PUT", opts...).Do() + return NewSimpleRequestWithClient(DefaultHttpClient, uri, "PUT", opts...).Do() } func Delete(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequest(uri, "DELETE", opts...).Do() + return NewSimpleRequestWithClient(DefaultHttpClient, uri, "DELETE", opts...).Do() } func Head(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequest(uri, "HEAD", opts...).Do() + return NewSimpleRequestWithClient(DefaultHttpClient, uri, "HEAD", opts...).Do() } func Patch(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequest(uri, "PATCH", opts...).Do() + return NewSimpleRequestWithClient(DefaultHttpClient, uri, "PATCH", opts...).Do() } func Trace(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequest(uri, "TRACE", opts...).Do() + return NewSimpleRequestWithClient(DefaultHttpClient, uri, "TRACE", opts...).Do() } func Connect(uri string, opts ...RequestOpt) (*Response, error) { - return NewSimpleRequest(uri, "CONNECT", opts...).Do() + return NewSimpleRequestWithClient(DefaultHttpClient, uri, "CONNECT", opts...).Do() } func DefaultCheckRedirectFunc(req *http.Request, via []*http.Request) error { diff --git a/curl_test.go b/curl_test.go index 37baec3..e34b79a 100644 --- a/curl_test.go +++ b/curl_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" ) func TestUrlEncodeRaw(t *testing.T) { @@ -567,7 +568,7 @@ func TestTlsConfig(t *testing.T) { if err != nil { t.Error(err) } - req := NewSimpleRequestWithClient(client, server.URL, "GET", WithHeader("hello", "world")) + req := client.NewSimpleRequest(server.URL, "GET", WithHeader("hello", "world")) //SetClientSkipVerify(client, true) //req.SetDoRawClient(false) //req.SetDoRawTransport(false) @@ -601,3 +602,50 @@ func TestTlsConfig(t *testing.T) { t.Error(err) } } + +func TestWithTimeout(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + time.Sleep(time.Second * 30) + rw.Write([]byte(`OK`)) + })) + funcList := []func(string, ...RequestOpt) (*Response, error){ + Get, + Post, + Put, + Delete, + Options, + Patch, + Head, + Trace, + Connect, + } + defer server.Close() + for i := 1; i < 30; i++ { + go func(i int) { + old := time.Now() + fn := funcList[i%len(funcList)] + resp, err := fn(server.URL, WithTimeout(time.Second*time.Duration(i))) + if time.Since(old) > time.Second*time.Duration(i+2) || time.Since(old) < time.Second*time.Duration(i) { + t.Errorf("timeout not work") + } + fmt.Println(time.Since(old)) + if err == nil { + t.Error(err) + resp.CloseAll() + } else { + fmt.Println(err) + } + }(i) + } + resp, err := Get(server.URL, WithTimeout(time.Second*60)) + if err != nil { + t.Error(err) + } else { + fmt.Println(resp.Body().String()) + if resp.StatusCode != 200 { + resp.CloseAll() + t.Errorf("status code is %d", resp.StatusCode) + } + resp.CloseAll() + } +} diff --git a/curl_transport.go b/curl_transport.go index 08ff99b..e14ae6d 100644 --- a/curl_transport.go +++ b/curl_transport.go @@ -27,6 +27,11 @@ func NewHttpClient(opts ...RequestOpt) (Client, error) { }, err } +func NewHttpClientNoErr(opts ...RequestOpt) Client { + c, _ := NewHttpClient(opts...) + return c +} + func NewClientFromHttpClient(httpClient *http.Client) (Client, error) { if httpClient == nil { return Client{}, fmt.Errorf("httpClient cannot be nil") @@ -109,6 +114,32 @@ func (c Client) SetDefaultTLSConfig(tlsConfig *tls.Config) { } } +func (c Client) NewRequest(url, method string, opts ...RequestOpt) (*Request, error) { + if c.Client == nil { + return nil, fmt.Errorf("http client is nil") + } + req, err := NewRequestWithContextWithClient(context.Background(), c, url, method, opts...) + return req, err +} + +func (c Client) NewRequestContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) { + if c.Client == nil { + return nil, fmt.Errorf("http client is nil") + } + req, err := NewRequestWithContextWithClient(ctx, c, url, method, opts...) + return req, err +} + +func (c Client) NewSimpleRequest(url, method string, opts ...RequestOpt) *Request { + req, _ := c.NewRequest(url, method, opts...) + return req +} + +func (c Client) NewSimpleRequestContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request { + req, _ := c.NewRequestContext(ctx, url, method, opts...) + return req +} + type Transport struct { base *http.Transport }