diff --git a/curl.go b/curl.go index 65d64f7..70761d7 100644 --- a/curl.go +++ b/curl.go @@ -143,6 +143,8 @@ func (r *Request) Clone() *Request { } if r.doRawClient { clonedRequest.rawClient = r.rawClient + } else { + clonedRequest.rawClient = new(http.Client) } if r.doRawRequest { clonedRequest.rawRequest = r.rawRequest @@ -726,6 +728,7 @@ func (r *RequestOpts) SkipTLSVerify() bool { return r.skipTLSVerify } +// SetSkipTLSVerify This function will Not Work when use rawClient,use SetClientSkipVerify instead func (r *Request) SetSkipTLSVerify(skipTLSVerify bool) *Request { r.skipTLSVerify = skipTLSVerify return r @@ -1793,3 +1796,21 @@ func NewRequestWithContextWithClient(ctx context.Context, client *http.Client, u req.SetDoRawTransport(true) return req, err } + +func SetClientSkipVerify(c *http.Client, val bool) error { + switch tp := c.Transport.(type) { + case *http.Transport: + if tp.TLSClientConfig == nil { + tp.TLSClientConfig = &tls.Config{} + } + tp.TLSClientConfig.InsecureSkipVerify = val + case nil: + transport := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: val}, + } + c.Transport = transport + default: + return fmt.Errorf("unsupported transport type: %T", tp) + } + return nil +} diff --git a/curl_test.go b/curl_test.go index feec396..c06266b 100644 --- a/curl_test.go +++ b/curl_test.go @@ -552,3 +552,44 @@ func TestUploadFile(t *testing.T) { } resp.CloseAll() } + +func TestTlsConfig(t *testing.T) { + server := httptest.NewTLSServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Header.Get("hello") != "world" { + rw.WriteHeader(http.StatusBadRequest) + rw.Write([]byte("hello world failed")) + return + } + rw.Write([]byte(`OK`)) + })) + defer server.Close() + client, err := NewHttpClient(WithSkipTLSVerify(false)) + if err != nil { + t.Error(err) + } + req := NewSimpleRequestWithClient(client, server.URL, "GET", WithHeader("hello", "world")) + //SetClientSkipVerify(client, true) + req.SetDoRawClient(false) + //req.SetDoRawTransport(false) + req.SetSkipTLSVerify(true) + resp, err := req.Do() + if err != nil { + t.Error(err) + } + if resp.StatusCode != 200 { + resp.CloseAll() + t.Errorf("status code is %d", resp.StatusCode) + } + resp.CloseAll() + req = req.Clone() + req.AddHeader("ok", "good") + resp, err = req.Do() + if err != nil { + t.Error(err) + } + if resp.StatusCode != 200 { + resp.CloseAll() + t.Errorf("status code is %d", resp.StatusCode) + } + resp.CloseAll() +}