135 lines
3.7 KiB
Go
135 lines
3.7 KiB
Go
package starnet
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"net/http"
|
|
"reflect"
|
|
)
|
|
|
|
// Client wraps an *http.Client by embedding it, so every http.Client field and
// method is reachable directly on a Client value. Because the embedded field
// is a pointer, copies of a Client share the same underlying *http.Client.
type Client struct {
	*http.Client
}
|
|
|
|
// NewHttpClient creates a new http.Client with the specified options.
|
|
func NewHttpClient(opts ...RequestOpt) (Client, error) {
|
|
req, err := newRequest(context.Background(), "", "", opts...)
|
|
if err != nil {
|
|
return Client{}, err
|
|
}
|
|
defer func() {
|
|
req = nil
|
|
}()
|
|
cl, err := req.HttpClient()
|
|
return Client{
|
|
Client: cl,
|
|
}, err
|
|
}
|
|
|
|
func NewClientFromHttpClient(httpClient *http.Client) (Client, error) {
|
|
if httpClient == nil {
|
|
return Client{}, fmt.Errorf("httpClient cannot be nil")
|
|
}
|
|
|
|
if httpClient.Transport == nil {
|
|
httpClient.Transport = &Transport{
|
|
base: &http.Transport{},
|
|
}
|
|
} else {
|
|
switch t := httpClient.Transport.(type) {
|
|
case *Transport:
|
|
if t.base == nil {
|
|
t.base = &http.Transport{}
|
|
}
|
|
case *http.Transport:
|
|
httpClient.Transport = &Transport{
|
|
base: t,
|
|
}
|
|
default:
|
|
return Client{}, fmt.Errorf("unsupported transport type: %T", t)
|
|
}
|
|
}
|
|
return Client{
|
|
Client: httpClient,
|
|
}, nil
|
|
}
|
|
|
|
func NewClientFromHttpClientNoError(httpClient *http.Client) Client {
|
|
return Client{Client: httpClient}
|
|
}
|
|
|
|
// DisableRedirect returns whether the request will disable HTTP redirects.
|
|
// if true, the request will not follow redirects automatically.
|
|
// for example, if the server responds with a 301 or 302 status code, the request will not automatically follow the redirect.
|
|
// you will get the original response with the redirect status code and Location header.
|
|
func (c Client) DisableRedirect() bool {
|
|
return reflect.ValueOf(c.Client.CheckRedirect).Pointer() == reflect.ValueOf(DefaultCheckRedirectFunc).Pointer()
|
|
}
|
|
|
|
// SetDisableRedirect sets whether the request will disable HTTP redirects.
|
|
// if true, the request will not follow redirects automatically.
|
|
// for example, if the server responds with a 301 or 302 status code, the request will not automatically follow the redirect.
|
|
// you will get the original response with the redirect status code and Location header.
|
|
func (c Client) SetDisableRedirect(disableRedirect bool) {
|
|
if disableRedirect {
|
|
c.Client.CheckRedirect = DefaultCheckRedirectFunc
|
|
}
|
|
}
|
|
|
|
func (c Client) SetDefaultSkipTLSVerify(skip bool) {
|
|
if c.Client.Transport == nil {
|
|
c.Client.Transport = &Transport{
|
|
base: &http.Transport{},
|
|
}
|
|
}
|
|
if transport, ok := c.Client.Transport.(*Transport); ok {
|
|
if transport.base.TLSClientConfig == nil {
|
|
transport.base.TLSClientConfig = &tls.Config{}
|
|
}
|
|
transport.base.TLSClientConfig.InsecureSkipVerify = skip
|
|
} else if transport, ok := c.Client.Transport.(*http.Transport); ok {
|
|
if transport.TLSClientConfig == nil {
|
|
transport.TLSClientConfig = &tls.Config{}
|
|
}
|
|
transport.TLSClientConfig.InsecureSkipVerify = skip
|
|
}
|
|
}
|
|
|
|
func (c Client) SetDefaultTLSConfig(tlsConfig *tls.Config) {
|
|
if c.Client.Transport == nil {
|
|
c.Client.Transport = &Transport{
|
|
base: &http.Transport{},
|
|
}
|
|
}
|
|
if transport, ok := c.Client.Transport.(*Transport); ok {
|
|
transport.base.TLSClientConfig = tlsConfig
|
|
} else if transport, ok := c.Client.Transport.(*http.Transport); ok {
|
|
transport.TLSClientConfig = tlsConfig
|
|
}
|
|
}
|
|
|
|
// Transport is the http.RoundTripper installed on clients built by this
// package. It forwards requests to base, but a request can override that
// choice through values stored in its context (see RoundTrip).
type Transport struct {
	base *http.Transport
}

// RoundTrip implements http.RoundTripper.
//
// Dispatch order:
//  1. If the request context holds a non-nil *http.Transport under the key
//     "transport", that transport handles the request.
//  2. If the context holds a non-empty string under "proxy" and a non-nil
//     *tls.Config under "tlsConfig", a one-off clone of base using that TLS
//     configuration handles the request.
//  3. Otherwise base handles it.
//
// A nil base is replaced with a zero http.Transport on first use. The context
// keys are plain strings, presumably because other code in the package stores
// values under exactly these keys, so they must not change.
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
	// NOTE(review): this lazy init writes t.base without synchronization; it
	// races if a Transport with nil base is first used from several goroutines.
	if t.base == nil {
		t.base = &http.Transport{}
	}

	ctx := req.Context()

	if override, ok := ctx.Value("transport").(*http.Transport); ok && override != nil {
		return override.RoundTrip(req)
	}

	if proxy, ok := ctx.Value("proxy").(string); ok && proxy != "" {
		if tlsConfig, ok := ctx.Value("tlsConfig").(*tls.Config); ok && tlsConfig != nil {
			oneOff := t.base.Clone()
			oneOff.TLSClientConfig = tlsConfig
			// oneOff is discarded after this call, so its idle pool can never
			// be reused. With keep-alives on, every connection (and its
			// goroutines) would linger in that unreachable pool until
			// IdleConnTimeout — forever when it is zero, as for a zero
			// http.Transport. Closing the connection after the response body
			// is done prevents that leak.
			oneOff.DisableKeepAlives = true
			return oneOff.RoundTrip(req)
		}
	}

	return t.base.RoundTrip(req)
}
|