starnet/client.go

325 lines
7.7 KiB
Go
Raw Permalink Normal View History

2026-03-08 20:19:40 +08:00
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net/http"
"sync"
"time"
)
// Client is an HTTP client wrapper. It holds the underlying *http.Client and
// a list of default RequestOpt values that NewRequestWithContext applies to
// every request before the per-request options.
type Client struct {
// client is the wrapped standard-library client used to send requests.
client *http.Client
// opts are the client-level default request options.
opts []RequestOpt
// mu guards opts (Clone also holds it for reading while copying the client).
mu sync.RWMutex
}
// NewClient creates a Client backed by a tuned *http.Transport wrapped in the
// package's *Transport type.
//
// When opts are supplied, they are first applied to a throwaway GET request so
// their effect on the request config can be inspected. Currently only a custom
// transport (config.CustomTransport with a non-nil config.Transport) is lifted
// onto the client. The options are also kept (as a private copy) as the
// client's default request options. An error from applying the options is
// returned wrapped with "create client".
//
// No client-wide Timeout is set: the code that would apply DefaultTimeout or
// config.Network.Timeout is currently disabled. NOTE(review): presumably
// timeouts are handled per request; confirm. Unlike http.DefaultTransport, the
// base transport sets no Proxy, so proxy environment variables are not honoured
// unless a custom transport is supplied.
func NewClient(opts ...RequestOpt) (*Client, error) {
	baseTransport := &http.Transport{
		ForceAttemptHTTP2:     true,
		MaxIdleConns:          100,
		MaxIdleConnsPerHost:   10,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}
	httpClient := &http.Client{
		Transport: &Transport{base: baseTransport},
	}

	if len(opts) > 0 {
		// Apply the options to a temporary request so the resulting config
		// can be read back.
		req, err := newRequest(context.Background(), "", http.MethodGet, opts...)
		if err != nil {
			return nil, wrapError(err, "create client")
		}
		if req.config.CustomTransport && req.config.Transport != nil {
			httpClient.Transport = &Transport{base: req.config.Transport}
		}
	}

	return &Client{
		client: httpClient,
		// Copy so the client never shares a backing array with the caller's
		// slice. AddOptions appends to this slice, which could otherwise
		// write into the caller's spare capacity.
		opts: append([]RequestOpt(nil), opts...),
	}, nil
}
// NewClientNoErr is like NewClient but never returns an error.
//
// If applying opts fails, it still returns a usable client. That client is
// built with NewClient's default transport, so SetDefaultTLSConfig and
// SetDefaultSkipTLSVerify keep working on it. opts are retained as its default
// request options. NOTE(review): presumably the option error then resurfaces
// when a request is built from this client; confirm.
func NewClientNoErr(opts ...RequestOpt) *Client {
	if client, err := NewClient(opts...); err == nil {
		return client
	}
	// Error deliberately ignored: with no options NewClient skips option
	// application entirely, so it cannot fail.
	client, _ := NewClient()
	client.opts = append([]RequestOpt(nil), opts...)
	return client
}
// NewClientFromHTTP wraps an existing *http.Client in a Client.
//
// The supplied client is modified in place so that its Transport is the
// package's *Transport:
//   - nil Transport: replaced by a *Transport around a zero-value *http.Transport;
//   - *Transport with a nil base: its base is set to a zero-value *http.Transport;
//   - *http.Transport: wrapped in a *Transport;
//   - any other RoundTripper: rejected with an "unsupported transport type" error.
//
// It returns ErrNilClient when httpClient is nil. The returned Client has no
// default request options.
func NewClientFromHTTP(httpClient *http.Client) (*Client, error) {
	if httpClient == nil {
		return nil, ErrNilClient
	}

	switch rt := httpClient.Transport.(type) {
	case nil:
		httpClient.Transport = &Transport{base: &http.Transport{}}
	case *Transport:
		if rt.base == nil {
			rt.base = &http.Transport{}
		}
	case *http.Transport:
		httpClient.Transport = &Transport{base: rt}
	default:
		return nil, fmt.Errorf("unsupported transport type: %T", rt)
	}

	return &Client{client: httpClient}, nil
}
// HTTPClient returns the underlying *http.Client. The pointer is shared, not
// copied, so changes made through it affect this Client.
func (c *Client) HTTPClient() *http.Client {
	underlying := c.client
	return underlying
}
// RequestOptions returns a copy of the client's default request options.
// Modifying the returned slice does not affect the client. The result is
// never nil, even when there are no options.
func (c *Client) RequestOptions() []RequestOpt {
	c.mu.RLock()
	snapshot := append(make([]RequestOpt, 0, len(c.opts)), c.opts...)
	c.mu.RUnlock()
	return snapshot
}
// SetOptions replaces the client's default request options with opts and
// returns c for chaining. The options are copied, so the client never shares
// a backing array with the caller's slice.
func (c *Client) SetOptions(opts ...RequestOpt) *Client {
	// Copy before storing. When called as SetOptions(s...), opts aliases s,
	// and a later AddOptions append could write into s's spare capacity.
	// The caller could also mutate s under the client's feet.
	owned := append([]RequestOpt(nil), opts...)
	c.mu.Lock()
	c.opts = owned
	c.mu.Unlock()
	return c
}
// AddOptions appends opts to the client's default request options and returns
// c for chaining.
func (c *Client) AddOptions(opts ...RequestOpt) *Client {
	c.mu.Lock()
	// Limit capacity to the current length so a non-empty append always
	// allocates a fresh array. c.opts may still alias a slice the caller
	// passed to NewClient or SetOptions, and appending in place would
	// overwrite that slice's spare capacity.
	n := len(c.opts)
	c.opts = append(c.opts[:n:n], opts...)
	c.mu.Unlock()
	return c
}
// Clone returns a new Client with its own copy of the default request options
// and, where possible, its own transport:
//   - *Transport: a new *Transport around a clone of the base *http.Transport
//     (a zero-value *http.Transport if the base is nil);
//   - *http.Transport: cloned;
//   - any other RoundTripper: shared with the original.
//
// CheckRedirect, Jar and Timeout are copied by value, so the cookie jar (if
// any) is shared between the two clients. This is not a fully deep copy.
func (c *Client) Clone() *Client {
	c.mu.RLock()
	defer c.mu.RUnlock()

	var transport http.RoundTripper
	if c.client.Transport != nil {
		switch t := c.client.Transport.(type) {
		case *Transport:
			// Hold the transport's lock. SetDefaultTLSConfig and
			// SetDefaultSkipTLSVerify mutate t.base under it, so cloning
			// without it is a data race.
			t.mu.Lock()
			var base *http.Transport
			if t.base != nil {
				base = t.base.Clone()
			} else {
				base = &http.Transport{}
			}
			t.mu.Unlock()
			transport = &Transport{base: base}
		case *http.Transport:
			transport = t.Clone()
		default:
			transport = c.client.Transport
		}
	}

	return &Client{
		client: &http.Client{
			Transport:     transport,
			CheckRedirect: c.client.CheckRedirect,
			Jar:           c.client.Jar,
			Timeout:       c.client.Timeout,
		},
		opts: append([]RequestOpt(nil), c.opts...),
	}
}
// SetDefaultTLSConfig sets the TLS configuration of the client's base
// transport and returns c for chaining.
//
// A clone of tlsConfig is stored (nil clears the setting), so the caller's
// config is never mutated by this client. For example, SetDefaultSkipTLSVerify
// modifies the stored config. Likewise, later changes to the caller's config do
// not leak in. It is a no-op when the client's transport is not the package's
// *Transport.
func (c *Client) SetDefaultTLSConfig(tlsConfig *tls.Config) *Client {
	transport, ok := c.client.Transport.(*Transport)
	if !ok {
		return c
	}
	// (*tls.Config).Clone returns nil for a nil receiver, preserving "nil
	// means default TLS settings".
	cfg := tlsConfig.Clone()
	transport.mu.Lock()
	transport.base.TLSClientConfig = cfg
	transport.mu.Unlock()
	return c
}
// SetDefaultSkipTLSVerify sets InsecureSkipVerify on the TLS configuration of
// the client's base transport and returns c for chaining.
//
// The existing config is cloned rather than modified in place. It may belong to
// the caller (e.g. a transport passed to NewClientFromHTTP) or already be in use
// by TLS handshakes, and crypto/tls forbids modifying a Config after it has
// been passed to a TLS function. It is a no-op when the client's transport is
// not the package's *Transport.
func (c *Client) SetDefaultSkipTLSVerify(skip bool) *Client {
	transport, ok := c.client.Transport.(*Transport)
	if !ok {
		return c
	}
	transport.mu.Lock()
	cfg := transport.base.TLSClientConfig.Clone()
	if cfg == nil {
		cfg = &tls.Config{}
	}
	cfg.InsecureSkipVerify = skip
	transport.base.TLSClientConfig = cfg
	transport.mu.Unlock()
	return c
}
// DisableRedirect stops the client from following redirects. The first
// redirect response is returned to the caller as-is (via
// http.ErrUseLastResponse). It returns c for chaining.
func (c *Client) DisableRedirect() *Client {
	stopAtFirst := func(*http.Request, []*http.Request) error {
		return http.ErrUseLastResponse
	}
	c.client.CheckRedirect = stopAtFirst
	return c
}
// EnableRedirect restores net/http's default redirect policy by clearing
// CheckRedirect, and returns c for chaining.
func (c *Client) EnableRedirect() *Client {
	var defaultPolicy func(*http.Request, []*http.Request) error
	c.client.CheckRedirect = defaultPolicy
	return c
}
// NewRequest builds a request for url and method using context.Background().
// Client default options are applied before opts (see NewRequestWithContext).
func (c *Client) NewRequest(url, method string, opts ...RequestOpt) (*Request, error) {
	ctx := context.Background()
	return c.NewRequestWithContext(ctx, url, method, opts...)
}
// NewRequestWithContext builds a request bound to ctx. The client's default
// options are applied first, followed by opts. The resulting request is
// attached to this Client and its underlying *http.Client. Errors from
// building the request are returned unchanged.
func (c *Client) NewRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) {
	// Snapshot the defaults into a fresh slice under the read lock so that
	// appending the per-request options never touches c.opts.
	var merged []RequestOpt
	c.mu.RLock()
	merged = append(merged, c.opts...)
	c.mu.RUnlock()
	merged = append(merged, opts...)

	req, err := newRequest(ctx, url, method, merged...)
	if err != nil {
		return nil, err
	}
	req.client, req.httpClient = c, c.client
	return req, nil
}
// doMethod builds a request for url with the given HTTP method via NewRequest
// (client defaults first, then opts) and sends it with Request.Do. It backs
// the verb helpers below, which previously duplicated this body seven times.
func (c *Client) doMethod(method, url string, opts []RequestOpt) (*Response, error) {
	req, err := c.NewRequest(url, method, opts...)
	if err != nil {
		return nil, err
	}
	return req.Do()
}

// Get sends a GET request to url.
func (c *Client) Get(url string, opts ...RequestOpt) (*Response, error) {
	return c.doMethod(http.MethodGet, url, opts)
}

// Post sends a POST request to url.
func (c *Client) Post(url string, opts ...RequestOpt) (*Response, error) {
	return c.doMethod(http.MethodPost, url, opts)
}

// Put sends a PUT request to url.
func (c *Client) Put(url string, opts ...RequestOpt) (*Response, error) {
	return c.doMethod(http.MethodPut, url, opts)
}

// Delete sends a DELETE request to url.
func (c *Client) Delete(url string, opts ...RequestOpt) (*Response, error) {
	return c.doMethod(http.MethodDelete, url, opts)
}

// Head sends a HEAD request to url.
func (c *Client) Head(url string, opts ...RequestOpt) (*Response, error) {
	return c.doMethod(http.MethodHead, url, opts)
}

// Patch sends a PATCH request to url.
func (c *Client) Patch(url string, opts ...RequestOpt) (*Response, error) {
	return c.doMethod(http.MethodPatch, url, opts)
}

// Options sends an OPTIONS request to url.
func (c *Client) Options(url string, opts ...RequestOpt) (*Response, error) {
	return c.doMethod(http.MethodOptions, url, opts)
}
// NewSimpleRequest is the error-free, chainable variant of NewRequest. It
// uses context.Background(), and any construction error is carried inside
// the returned *Request instead of being returned.
func (c *Client) NewSimpleRequest(url, method string, opts ...RequestOpt) *Request {
	background := context.Background()
	return c.NewSimpleRequestWithContext(background, url, method, opts...)
}
// NewSimpleRequestWithContext is the error-free, chainable variant of
// NewRequestWithContext. On success it returns the built request. On failure
// it returns a placeholder *Request that records the error in its err field.
// The placeholder has empty headers, queries and form data, is attached to
// this Client, and uses DefaultFetchRespBody for autoFetch. This is intended
// to match the package-level NewSimpleRequest.
func (c *Client) NewSimpleRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) *Request {
	req, err := c.NewRequestWithContext(ctx, url, method, opts...)
	if err == nil {
		return req
	}

	placeholderConfig := &RequestConfig{
		Headers: make(http.Header),
		Queries: make(map[string][]string),
		Body:    BodyConfig{FormData: make(map[string][]string)},
	}
	return &Request{
		ctx:        ctx,
		url:        url,
		method:     method,
		err:        err,
		config:     placeholderConfig,
		client:     c,
		httpClient: c.client,
		autoFetch:  DefaultFetchRespBody,
	}
}