package starnet import ( "context" "crypto/tls" "fmt" "net/http" "reflect" ) type Client struct { *http.Client } // NewHttpClient creates a new http.Client with the specified options. func NewHttpClient(opts ...RequestOpt) (Client, error) { req, err := newRequest(context.Background(), "", "", opts...) if err != nil { return Client{}, err } defer func() { req = nil }() cl, err := req.HttpClient() return Client{ Client: cl, }, err } func NewHttpClientNoErr(opts ...RequestOpt) Client { c, _ := NewHttpClient(opts...) return c } func NewClientFromHttpClient(httpClient *http.Client) (Client, error) { if httpClient == nil { return Client{}, fmt.Errorf("httpClient cannot be nil") } if httpClient.Transport == nil { httpClient.Transport = &Transport{ base: &http.Transport{}, } } else { switch t := httpClient.Transport.(type) { case *Transport: if t.base == nil { t.base = &http.Transport{} } case *http.Transport: httpClient.Transport = &Transport{ base: t, } default: return Client{}, fmt.Errorf("unsupported transport type: %T", t) } } return Client{ Client: httpClient, }, nil } func NewClientFromHttpClientNoError(httpClient *http.Client) Client { return Client{Client: httpClient} } // DisableRedirect returns whether the request will disable HTTP redirects. // if true, the request will not follow redirects automatically. // for example, if the server responds with a 301 or 302 status code, the request will not automatically follow the redirect. // you will get the original response with the redirect status code and Location header. func (c Client) DisableRedirect() bool { return reflect.ValueOf(c.Client.CheckRedirect).Pointer() == reflect.ValueOf(DefaultCheckRedirectFunc).Pointer() } // SetDisableRedirect sets whether the request will disable HTTP redirects. // if true, the request will not follow redirects automatically. // for example, if the server responds with a 301 or 302 status code, the request will not automatically follow the redirect. // you will get the original response with the redirect status code and Location header. func (c Client) SetDisableRedirect(disableRedirect bool) { if disableRedirect { c.Client.CheckRedirect = DefaultCheckRedirectFunc } } func (c Client) SetDefaultSkipTLSVerify(skip bool) { if c.Client.Transport == nil { c.Client.Transport = &Transport{ base: &http.Transport{}, } } if transport, ok := c.Client.Transport.(*Transport); ok { if transport.base.TLSClientConfig == nil { transport.base.TLSClientConfig = &tls.Config{} } transport.base.TLSClientConfig.InsecureSkipVerify = skip } else if transport, ok := c.Client.Transport.(*http.Transport); ok { if transport.TLSClientConfig == nil { transport.TLSClientConfig = &tls.Config{} } transport.TLSClientConfig.InsecureSkipVerify = skip } } func (c Client) SetDefaultTLSConfig(tlsConfig *tls.Config) { if c.Client.Transport == nil { c.Client.Transport = &Transport{ base: &http.Transport{}, } } if transport, ok := c.Client.Transport.(*Transport); ok { transport.base.TLSClientConfig = tlsConfig } else if transport, ok := c.Client.Transport.(*http.Transport); ok { transport.TLSClientConfig = tlsConfig } } func (c Client) NewRequest(url, method string, opts ...RequestOpt) (*Request, error) { if c.Client == nil { return nil, fmt.Errorf("http client is nil") } req, err := NewRequestWithContextWithClient(context.Background(), c, url, method, opts...) return req, err } func (c Client) NewRequestContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) { if c.Client == nil { return nil, fmt.Errorf("http client is nil") } req, err := NewRequestWithContextWithClient(ctx, c, url, method, opts...) return req, err } func (c Client) NewSimpleRequest(url, method string, opts ...RequestOpt) *Request { req, _ := c.NewRequest(url, method, opts...) return req } func (c Client) NewSimpleRequestContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request { req, _ := c.NewRequestContext(ctx, url, method, opts...) return req } type Transport struct { base *http.Transport } func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { if t.base == nil { t.base = &http.Transport{} } transport, ok := req.Context().Value("transport").(*http.Transport) if ok && transport != nil { return transport.RoundTrip(req) } proxy, ok := req.Context().Value("proxy").(string) if ok && proxy != "" { tlsConfig, ok := req.Context().Value("tlsConfig").(*tls.Config) if ok && tlsConfig != nil { tmpTransport := t.base.Clone() tmpTransport.TLSClientConfig = tlsConfig return tmpTransport.RoundTrip(req) } } return t.base.RoundTrip(req) }