starnet/client.go
2026-03-08 20:19:40 +08:00

325 lines
7.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package starnet
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"sync"
"time"
)
// Client wraps an *http.Client together with a list of default request
// options that are applied to every request created through it.
type Client struct {
	// client is the underlying HTTP client that requests are sent with.
	client *http.Client

	// opts are the client-level default request options. NewRequestWithContext
	// places them ahead of the per-request options.
	opts []RequestOpt

	// mu guards opts. Methods that modify client (for example DisableRedirect
	// or SetDefaultTLSConfig) do not acquire it.
	mu sync.RWMutex
}
// NewClient creates a Client whose requests go through the package's
// Transport wrapper around a pooled *http.Transport. The pool attempts
// HTTP/2 and allows 100 idle connections overall and 10 per host. It uses a
// 90s idle timeout, a 10s TLS handshake timeout and a 1s Expect-Continue
// timeout.
//
// opts become the client's default request options. They are stored as a
// private copy, so later changes to (or appends on) the caller's slice
// cannot alter the client's defaults.
//
// When opts is non-empty, the options are also evaluated once against a
// throwaway GET request. If that fails, the error is wrapped with
// "create client" and returned. If the evaluated config requests a custom
// transport (CustomTransport set and Transport non-nil), that transport
// replaces the default one.
//
// No client-wide Timeout is set on the http.Client.
func NewClient(opts ...RequestOpt) (*Client, error) {
	base := &http.Transport{
		ForceAttemptHTTP2:     true,
		MaxIdleConns:          100,
		MaxIdleConnsPerHost:   10,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}
	httpClient := &http.Client{
		Transport: &Transport{base: base},
	}

	if len(opts) > 0 {
		// Evaluate the options once so that a transport-level setting can be
		// installed on the shared http.Client instead of per request.
		req, err := newRequest(context.Background(), "", http.MethodGet, opts...)
		if err != nil {
			return nil, wrapError(err, "create client")
		}
		if req.config.CustomTransport && req.config.Transport != nil {
			httpClient.Transport = &Transport{base: req.config.Transport}
		}
	}

	return &Client{
		client: httpClient,
		// Copy the options. A variadic call such as NewClient(s...) passes s
		// itself, so storing it would share a backing array with the caller.
		// A later AddOptions append could then overwrite the caller's
		// elements, or be overwritten through them.
		opts: append([]RequestOpt(nil), opts...),
	}, nil
}
// NewClientNoErr is like NewClient but never returns an error.
//
// If NewClient fails, the error is discarded and a fallback Client is
// returned instead. The fallback wraps a plain &http.Client{} with a nil
// Transport and keeps a private copy of opts as its defaults.
//
// NOTE(review): NewClient can only fail while evaluating opts, and the same
// options are evaluated again for every request. The failure will most
// likely resurface from the fallback client's requests, so confirm that this
// is the intended way of reporting it.
func NewClientNoErr(opts ...RequestOpt) *Client {
	// The error is ignored on purpose: this constructor is the
	// "no error handling" variant and always hands back a usable *Client.
	client, _ := NewClient(opts...)
	if client != nil {
		return client
	}
	return &Client{
		client: &http.Client{},
		// Copy so the client never shares a backing array with the caller's
		// variadic slice; otherwise appends by either side could overwrite
		// the other's elements.
		opts: append([]RequestOpt(nil), opts...),
	}
}
// NewClientFromHTTP wraps an existing *http.Client in a Client.
//
// httpClient is modified in place so that its Transport becomes the
// package's *Transport wrapper:
//   - a nil Transport is replaced by a wrapper around a zero-value *http.Transport;
//   - a *Transport whose base is nil gets a zero-value *http.Transport as base;
//   - a *http.Transport is wrapped;
//   - any other RoundTripper is rejected with an "unsupported transport type" error.
//
// It returns ErrNilClient when httpClient is nil. The returned Client has no
// default request options.
func NewClientFromHTTP(httpClient *http.Client) (*Client, error) {
	if httpClient == nil {
		return nil, ErrNilClient
	}

	switch rt := httpClient.Transport.(type) {
	case nil:
		httpClient.Transport = &Transport{base: &http.Transport{}}
	case *Transport:
		if rt.base == nil {
			rt.base = &http.Transport{}
		}
	case *http.Transport:
		httpClient.Transport = &Transport{base: rt}
	default:
		return nil, fmt.Errorf("unsupported transport type: %T", rt)
	}

	return &Client{client: httpClient}, nil
}
// HTTPClient returns the underlying *http.Client. It is the pointer this
// Client sends requests with, not a copy, so changes made through it affect
// this Client.
func (c *Client) HTTPClient() *http.Client {
	underlying := c.client
	return underlying
}
// RequestOptions returns a snapshot of the client's default request options.
// The result is always a non-nil slice, and modifying it does not affect the
// client.
func (c *Client) RequestOptions() []RequestOpt {
	c.mu.RLock()
	snapshot := append(make([]RequestOpt, 0, len(c.opts)), c.opts...)
	c.mu.RUnlock()
	return snapshot
}
// SetOptions replaces the client's default request options with a copy of
// opts and returns c for chaining. Calling it with no arguments clears the
// defaults.
func (c *Client) SetOptions(opts ...RequestOpt) *Client {
	// Copy before storing. A variadic call such as SetOptions(s...) passes s
	// itself, so keeping it would share a backing array with the caller. A
	// later AddOptions append (or the caller's own append) could then
	// overwrite the other side's elements.
	owned := append([]RequestOpt(nil), opts...)

	c.mu.Lock()
	c.opts = owned
	c.mu.Unlock()
	return c
}
// AddOptions appends opts to the client's default request options and
// returns c for chaining. The update happens while holding the client's
// write lock.
func (c *Client) AddOptions(opts ...RequestOpt) *Client {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.opts = append(c.opts, opts...)
	return c
}
// Clone returns a new Client derived from c. Some state is copied and some
// is shared:
//   - default request options are copied into a new slice;
//   - a *Transport wrapper is rebuilt around a clone of its *http.Transport
//     (only the base is carried over);
//   - a bare *http.Transport is cloned;
//   - any other RoundTripper is shared as-is, and a nil transport stays nil;
//   - CheckRedirect, Jar and Timeout are copied field by field, so a cookie
//     jar (if any) is shared with c rather than duplicated.
func (c *Client) Clone() *Client {
	c.mu.RLock()
	defer c.mu.RUnlock()

	src := c.client

	var rt http.RoundTripper
	switch t := src.Transport.(type) {
	case nil:
		// No transport configured: the clone keeps a nil transport too.
	case *Transport:
		rt = &Transport{base: t.base.Clone()}
	case *http.Transport:
		rt = t.Clone()
	default:
		rt = src.Transport
	}

	dup := &http.Client{
		Transport:     rt,
		CheckRedirect: src.CheckRedirect,
		Jar:           src.Jar,
		Timeout:       src.Timeout,
	}
	return &Client{
		client: dup,
		opts:   append([]RequestOpt(nil), c.opts...),
	}
}
// SetDefaultTLSConfig installs tlsConfig as the TLS client configuration of
// the underlying *http.Transport and returns c for chaining. The pointer is
// stored as-is (not copied), and passing nil clears the configuration. The
// call has no effect unless the client's transport is the package's
// *Transport wrapper. The write happens under the wrapper's mutex.
//
// NOTE(review): connections already pooled by the transport were dialed
// with the previous configuration and keep it.
func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
	transport, ok := c.client.Transport.(*Transport)
	if !ok {
		return c
	}
	transport.mu.Lock()
	transport.base.TLSClientConfig = tlsConfig
	transport.mu.Unlock()
	return c
}
// SetDefaultSkipTLSVerify sets InsecureSkipVerify on the TLS client
// configuration of the underlying *http.Transport and returns c for chaining.
// The call has no effect unless the client's transport is the package's
// *Transport wrapper.
//
// The current tls.Config is cloned and the clone is modified; the existing
// value is never written to. That value may be a config the caller passed
// to SetDefaultTLSConfig and still uses elsewhere. Also, crypto/tls requires
// that a Config not be modified once it has been handed to TLS functions.
func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client {
	transport, ok := c.client.Transport.(*Transport)
	if !ok {
		return c
	}
	transport.mu.Lock()
	defer transport.mu.Unlock()

	// (*tls.Config).Clone returns nil for a nil receiver.
	cfg := transport.base.TLSClientConfig.Clone()
	if cfg == nil {
		cfg = &tls.Config{}
	}
	cfg.InsecureSkipVerify = skip
	transport.base.TLSClientConfig = cfg
	return c
}
// DisableRedirect stops the underlying http.Client from following
// redirects. Its CheckRedirect hook returns http.ErrUseLastResponse, so the
// redirect response itself is handed back to the caller. Returns c for
// chaining.
func (c *Client) DisableRedirect() *Client {
	stopAtFirst := func(*http.Request, []*http.Request) error {
		return http.ErrUseLastResponse
	}
	c.client.CheckRedirect = stopAtFirst
	return c
}
// EnableRedirect clears CheckRedirect on the underlying http.Client, which
// restores net/http's default redirect policy. Returns c for chaining.
func (c *Client) EnableRedirect() *Client {
	hc := c.client
	hc.CheckRedirect = nil
	return c
}
// NewRequest creates a request bound to context.Background(). It is
// NewRequestWithContext without an explicit context.
func (c *Client) NewRequest(url, method string, opts ...RequestOpt) (*Request, error) {
	ctx := context.Background()
	return c.NewRequestWithContext(ctx, url, method, opts...)
}
// NewRequestWithContext creates a request bound to ctx.
//
// The client's default options (read under the read lock) come first in
// the merged option list, followed by opts. NOTE(review): this assumes
// newRequest applies options in order, so per-request options win over
// defaults.
//
// On success the Request is linked to this Client and its underlying
// http.Client. Errors from newRequest are returned unchanged.
func (c *Client) NewRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) {
	var merged []RequestOpt
	c.mu.RLock()
	merged = append(merged, c.opts...)
	c.mu.RUnlock()
	merged = append(merged, opts...)

	req, err := newRequest(ctx, url, method, merged...)
	if err != nil {
		return nil, err
	}
	req.client, req.httpClient = c, c.client
	return req, nil
}
// sendMethod is the shared body of the HTTP verb helpers below. It builds a
// request with NewRequest (client defaults plus opts) and executes it with
// Do, returning the first error encountered.
func (c *Client) sendMethod(method, url string, opts []RequestOpt) (*Response, error) {
	req, err := c.NewRequest(url, method, opts...)
	if err != nil {
		return nil, err
	}
	return req.Do()
}

// Get sends a GET request to url.
func (c *Client) Get(url string, opts ...RequestOpt) (*Response, error) {
	return c.sendMethod(http.MethodGet, url, opts)
}

// Post sends a POST request to url.
func (c *Client) Post(url string, opts ...RequestOpt) (*Response, error) {
	return c.sendMethod(http.MethodPost, url, opts)
}

// Put sends a PUT request to url.
func (c *Client) Put(url string, opts ...RequestOpt) (*Response, error) {
	return c.sendMethod(http.MethodPut, url, opts)
}

// Delete sends a DELETE request to url.
func (c *Client) Delete(url string, opts ...RequestOpt) (*Response, error) {
	return c.sendMethod(http.MethodDelete, url, opts)
}

// Head sends a HEAD request to url.
func (c *Client) Head(url string, opts ...RequestOpt) (*Response, error) {
	return c.sendMethod(http.MethodHead, url, opts)
}

// Patch sends a PATCH request to url.
func (c *Client) Patch(url string, opts ...RequestOpt) (*Response, error) {
	return c.sendMethod(http.MethodPatch, url, opts)
}

// Options sends an OPTIONS request to url.
func (c *Client) Options(url string, opts ...RequestOpt) (*Response, error) {
	return c.sendMethod(http.MethodOptions, url, opts)
}
// NewSimpleRequest is NewSimpleRequestWithContext with context.Background().
// It never returns an error, so it can be used directly in call chains.
func (c *Client) NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
	ctx := context.Background()
	return c.NewSimpleRequestWithContext(ctx, url, method, opts...)
}
// NewSimpleRequestWithContext is like NewRequestWithContext but never
// returns an error, so it can be used directly in call chains.
//
// If building the request fails, it returns a Request with these properties:
//   - the error is stored in its err field;
//   - it has an empty config (empty headers, queries and form data);
//   - it is still bound to this Client and its http.Client;
//   - autoFetch is DefaultFetchRespBody.
//
// This keeps its behavior in line with the package-level NewSimpleRequest.
// NOTE(review): the stored error is presumably reported when the request is
// executed; confirm.
func (c *Client) NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
	req, err := c.NewRequestWithContext(ctx, url, method, opts...)
	if err == nil {
		return req
	}

	emptyConfig := &RequestConfig{
		Headers: make(http.Header),
		Queries: make(map[string][]string),
		Body:    BodyConfig{FormData: make(map[string][]string)},
	}
	return &Request{
		ctx:        ctx,
		url:        url,
		method:     method,
		err:        err,
		config:     emptyConfig,
		client:     c,
		httpClient: c.client,
		autoFetch:  DefaultFetchRespBody,
	}
}