重构http Client部分
This commit is contained in:
parent
d260181adf
commit
260ceb90ed
197
curl_default.go
Normal file
197
curl_default.go
Normal file
@ -0,0 +1,197 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
HEADER_FORM_URLENCODE = `application/x-www-form-urlencoded`
|
||||
HEADER_FORM_DATA = `multipart/form-data`
|
||||
HEADER_JSON = `application/json`
|
||||
HEADER_PLAIN = `text/plain`
|
||||
)
|
||||
|
||||
var (
|
||||
DefaultDialTimeout = 5 * time.Second
|
||||
DefaultTimeout = 10 * time.Second
|
||||
DefaultFetchRespBody = false
|
||||
)
|
||||
|
||||
func UrlEncodeRaw(str string) string {
|
||||
strs := strings.Replace(url.QueryEscape(str), "+", "%20", -1)
|
||||
return strs
|
||||
}
|
||||
|
||||
func UrlEncode(str string) string {
|
||||
return url.QueryEscape(str)
|
||||
}
|
||||
|
||||
func UrlDecode(str string) (string, error) {
|
||||
return url.QueryUnescape(str)
|
||||
}
|
||||
|
||||
func BuildQuery(queryData map[string]string) string {
|
||||
query := url.Values{}
|
||||
for k, v := range queryData {
|
||||
query.Add(k, v)
|
||||
}
|
||||
return query.Encode()
|
||||
}
|
||||
|
||||
// BuildPostForm takes a map of string keys and values, converts it into a URL-encoded query string,
|
||||
// and then converts that string into a byte slice. This function is useful for preparing data for HTTP POST requests,
|
||||
// where the server expects the request body to be URL-encoded form data.
|
||||
//
|
||||
// Parameters:
|
||||
// queryMap: A map where the key-value pairs represent the form data to be sent in the HTTP POST request.
|
||||
//
|
||||
// Returns:
|
||||
// A byte slice representing the URL-encoded form data.
|
||||
func BuildPostForm(queryMap map[string]string) []byte {
|
||||
return []byte(BuildQuery(queryMap))
|
||||
}
|
||||
|
||||
func Get(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequest(uri, "GET", opts...).Do()
|
||||
}
|
||||
|
||||
func Post(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequest(uri, "POST", opts...).Do()
|
||||
}
|
||||
|
||||
func Options(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequest(uri, "OPTIONS", opts...).Do()
|
||||
}
|
||||
|
||||
func Put(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequest(uri, "PUT", opts...).Do()
|
||||
}
|
||||
|
||||
func Delete(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequest(uri, "DELETE", opts...).Do()
|
||||
}
|
||||
|
||||
func Head(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequest(uri, "HEAD", opts...).Do()
|
||||
}
|
||||
|
||||
func Patch(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequest(uri, "PATCH", opts...).Do()
|
||||
}
|
||||
|
||||
func Trace(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequest(uri, "TRACE", opts...).Do()
|
||||
}
|
||||
|
||||
func Connect(uri string, opts ...RequestOpt) (*Response, error) {
|
||||
return NewSimpleRequest(uri, "CONNECT", opts...).Do()
|
||||
}
|
||||
|
||||
func DefaultCheckRedirectFunc(req *http.Request, via []*http.Request) error {
|
||||
return http.ErrUseLastResponse
|
||||
}
|
||||
|
||||
func DefaultDialFunc(ctx context.Context, netType, addr string) (net.Conn, error) {
|
||||
var lastErr error
|
||||
var addrs []string
|
||||
if dialFn, ok := ctx.Value("dialFunc").(func(context.Context, string, string) (net.Conn, error)); ok {
|
||||
if dialFn != nil {
|
||||
return dialFn(ctx, netType, addr)
|
||||
}
|
||||
}
|
||||
customIP, ok := ctx.Value("customIP").([]string)
|
||||
if !ok {
|
||||
customIP = nil
|
||||
}
|
||||
dialTimeout, ok := ctx.Value("dialTimeout").(time.Duration)
|
||||
if !ok {
|
||||
dialTimeout = DefaultDialTimeout
|
||||
}
|
||||
timeout, ok := ctx.Value("timeout").(time.Duration)
|
||||
if !ok {
|
||||
timeout = DefaultTimeout
|
||||
}
|
||||
lookUpIPfn, ok := ctx.Value("lookUpIP").(func(context.Context, string) ([]net.IPAddr, error))
|
||||
if !ok {
|
||||
lookUpIPfn = net.DefaultResolver.LookupIPAddr
|
||||
}
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
proxy, ok := ctx.Value("proxy").(string)
|
||||
if !ok {
|
||||
proxy = ""
|
||||
}
|
||||
if proxy == "" && len(customIP) > 0 {
|
||||
for _, v := range customIP {
|
||||
ipAddr := net.ParseIP(v)
|
||||
if ipAddr == nil {
|
||||
return nil, fmt.Errorf("invalid custom ip: %s", customIP)
|
||||
}
|
||||
tmpAddr := net.JoinHostPort(v, port)
|
||||
addrs = append(addrs, tmpAddr)
|
||||
}
|
||||
} else {
|
||||
ipLists, err := lookUpIPfn(ctx, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, v := range ipLists {
|
||||
tmpAddr := net.JoinHostPort(v.String(), port)
|
||||
addrs = append(addrs, tmpAddr)
|
||||
}
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
c, err := net.DialTimeout(netType, addr, dialTimeout)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
if timeout != 0 {
|
||||
err = c.SetDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func DefaultDialTlsFunc(ctx context.Context, netType, addr string) (net.Conn, error) {
|
||||
conn, err := DefaultDialFunc(ctx, netType, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tlsConfig, ok := ctx.Value("tlsConfig").(*tls.Config)
|
||||
if !ok || tlsConfig == nil {
|
||||
return nil, fmt.Errorf("tlsConfig is not set in context")
|
||||
}
|
||||
tlsConn := tls.Client(conn, tlsConfig)
|
||||
if err := tlsConn.Handshake(); err != nil {
|
||||
return nil, fmt.Errorf("tls handshake failed: %w", err)
|
||||
}
|
||||
return tlsConn, nil
|
||||
}
|
||||
|
||||
func DefaultProxyURL() func(*http.Request) (*url.URL, error) {
|
||||
return func(req *http.Request) (*url.URL, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("request is nil")
|
||||
}
|
||||
proxyURL, ok := req.Context().Value("proxy").(string)
|
||||
if !ok || proxyURL == "" {
|
||||
return nil, nil
|
||||
}
|
||||
parsedURL, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse proxy URL: %w", err)
|
||||
}
|
||||
return parsedURL, nil
|
||||
}
|
||||
}
|
18
curl_test.go
18
curl_test.go
@ -345,11 +345,11 @@ func TestGet(t *testing.T) {
|
||||
var reply postmanReply
|
||||
resp, err := NewReq("https://postman-echo.com/get").
|
||||
AddHeader("hello", "nononmo").
|
||||
SetAutoCalcContentLength(true).
|
||||
Do()
|
||||
SetAutoCalcContentLengthNoError(true).Do()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
fmt.Println(resp.Proto)
|
||||
err = resp.Body().Unmarshal(&reply)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@ -474,7 +474,7 @@ func TestReqClone(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := NewSimpleRequestWithClient(http.DefaultClient, server.URL, "GET", WithHeader("hello", "world"))
|
||||
req := NewSimpleRequestWithClient(NewClientFromHttpClientNoError(http.DefaultClient), server.URL, "GET", WithHeader("hello", "world"))
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
@ -526,7 +526,7 @@ func TestUploadFile(t *testing.T) {
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
req := NewSimpleRequestWithClient(http.DefaultClient, server.URL, "GET", WithHeader("hello", "world"))
|
||||
req := NewSimpleRequestWithClient(NewClientFromHttpClientNoError(http.DefaultClient), server.URL, "GET", WithHeader("hello", "world"))
|
||||
req.AddFileWithName("666", "./curl.go", "curl.go")
|
||||
req.AddFile("777", "./go.mod")
|
||||
req.AddFileWithNameAndType("888", "./ping.go", "ping.go", "html")
|
||||
@ -569,13 +569,15 @@ func TestTlsConfig(t *testing.T) {
|
||||
}
|
||||
req := NewSimpleRequestWithClient(client, server.URL, "GET", WithHeader("hello", "world"))
|
||||
//SetClientSkipVerify(client, true)
|
||||
req.SetDoRawClient(false)
|
||||
//req.SetDoRawClient(false)
|
||||
//req.SetDoRawTransport(false)
|
||||
req.SetSkipTLSVerify(true)
|
||||
req.SetProxy("http://127.0.0.1:29992")
|
||||
resp, err := req.Do()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
fmt.Println(resp.Proto)
|
||||
if resp.StatusCode != 200 {
|
||||
resp.CloseAll()
|
||||
t.Errorf("status code is %d", resp.StatusCode)
|
||||
@ -592,4 +594,10 @@ func TestTlsConfig(t *testing.T) {
|
||||
t.Errorf("status code is %d", resp.StatusCode)
|
||||
}
|
||||
resp.CloseAll()
|
||||
req = req.Clone()
|
||||
req.SetSkipTLSVerify(false)
|
||||
resp, err = req.Do()
|
||||
if err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
134
curl_transport.go
Normal file
134
curl_transport.go
Normal file
@ -0,0 +1,134 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
*http.Client
|
||||
}
|
||||
|
||||
// NewHttpClient creates a new http.Client with the specified options.
|
||||
func NewHttpClient(opts ...RequestOpt) (Client, error) {
|
||||
req, err := newRequest(context.Background(), "", "", opts...)
|
||||
if err != nil {
|
||||
return Client{}, err
|
||||
}
|
||||
defer func() {
|
||||
req = nil
|
||||
}()
|
||||
cl, err := req.HttpClient()
|
||||
return Client{
|
||||
Client: cl,
|
||||
}, err
|
||||
}
|
||||
|
||||
func NewClientFromHttpClient(httpClient *http.Client) (Client, error) {
|
||||
if httpClient == nil {
|
||||
return Client{}, fmt.Errorf("httpClient cannot be nil")
|
||||
}
|
||||
|
||||
if httpClient.Transport == nil {
|
||||
httpClient.Transport = &Transport{
|
||||
base: &http.Transport{},
|
||||
}
|
||||
} else {
|
||||
switch t := httpClient.Transport.(type) {
|
||||
case *Transport:
|
||||
if t.base == nil {
|
||||
t.base = &http.Transport{}
|
||||
}
|
||||
case *http.Transport:
|
||||
httpClient.Transport = &Transport{
|
||||
base: t,
|
||||
}
|
||||
default:
|
||||
return Client{}, fmt.Errorf("unsupported transport type: %T", t)
|
||||
}
|
||||
}
|
||||
return Client{
|
||||
Client: httpClient,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewClientFromHttpClientNoError(httpClient *http.Client) Client {
|
||||
return Client{Client: httpClient}
|
||||
}
|
||||
|
||||
// DisableRedirect returns whether the request will disable HTTP redirects.
|
||||
// if true, the request will not follow redirects automatically.
|
||||
// for example, if the server responds with a 301 or 302 status code, the request will not automatically follow the redirect.
|
||||
// you will get the original response with the redirect status code and Location header.
|
||||
func (c Client) DisableRedirect() bool {
|
||||
return reflect.ValueOf(c.Client.CheckRedirect).Pointer() == reflect.ValueOf(DefaultCheckRedirectFunc).Pointer()
|
||||
}
|
||||
|
||||
// SetDisableRedirect sets whether the request will disable HTTP redirects.
|
||||
// if true, the request will not follow redirects automatically.
|
||||
// for example, if the server responds with a 301 or 302 status code, the request will not automatically follow the redirect.
|
||||
// you will get the original response with the redirect status code and Location header.
|
||||
func (c Client) SetDisableRedirect(disableRedirect bool) {
|
||||
if disableRedirect {
|
||||
c.Client.CheckRedirect = DefaultCheckRedirectFunc
|
||||
}
|
||||
}
|
||||
|
||||
func (c Client) SetDefaultSkipTLSVerify(skip bool) {
|
||||
if c.Client.Transport == nil {
|
||||
c.Client.Transport = &Transport{
|
||||
base: &http.Transport{},
|
||||
}
|
||||
}
|
||||
if transport, ok := c.Client.Transport.(*Transport); ok {
|
||||
if transport.base.TLSClientConfig == nil {
|
||||
transport.base.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
transport.base.TLSClientConfig.InsecureSkipVerify = skip
|
||||
} else if transport, ok := c.Client.Transport.(*http.Transport); ok {
|
||||
if transport.TLSClientConfig == nil {
|
||||
transport.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
transport.TLSClientConfig.InsecureSkipVerify = skip
|
||||
}
|
||||
}
|
||||
|
||||
func (c Client) SetDefaultTLSConfig(tlsConfig *tls.Config) {
|
||||
if c.Client.Transport == nil {
|
||||
c.Client.Transport = &Transport{
|
||||
base: &http.Transport{},
|
||||
}
|
||||
}
|
||||
if transport, ok := c.Client.Transport.(*Transport); ok {
|
||||
transport.base.TLSClientConfig = tlsConfig
|
||||
} else if transport, ok := c.Client.Transport.(*http.Transport); ok {
|
||||
transport.TLSClientConfig = tlsConfig
|
||||
}
|
||||
}
|
||||
|
||||
type Transport struct {
|
||||
base *http.Transport
|
||||
}
|
||||
|
||||
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if t.base == nil {
|
||||
t.base = &http.Transport{}
|
||||
}
|
||||
transport, ok := req.Context().Value("transport").(*http.Transport)
|
||||
if ok && transport != nil {
|
||||
return transport.RoundTrip(req)
|
||||
}
|
||||
proxy, ok := req.Context().Value("proxy").(string)
|
||||
if ok && proxy != "" {
|
||||
tlsConfig, ok := req.Context().Value("tlsConfig").(*tls.Config)
|
||||
if ok && tlsConfig != nil {
|
||||
tmpTransport := t.base.Clone()
|
||||
tmpTransport.TLSClientConfig = tlsConfig
|
||||
return tmpTransport.RoundTrip(req)
|
||||
}
|
||||
}
|
||||
return t.base.RoundTrip(req)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user