diff --git a/client.go b/client.go index 14aed5b..3bc1eb4 100644 --- a/client.go +++ b/client.go @@ -325,3 +325,21 @@ func (c *Client) NewSimpleRequestWithContext(ctx context.Context, url, method st } return req } + +// Trace 发送 TRACE 请求 +func (c *Client) Trace(url string, opts ...RequestOpt) (*Response, error) { + req, err := c.NewRequest(url, http.MethodTrace, opts...) + if err != nil { + return nil, err + } + return req.Do() +} + +// Connect 发送 CONNECT 请求 +func (c *Client) Connect(url string, opts ...RequestOpt) (*Response, error) { + req, err := c.NewRequest(url, http.MethodConnect, opts...) + if err != nil { + return nil, err + } + return req.Do() +} diff --git a/dialer.go b/dialer.go index 6fb64ec..0e32b0d 100644 --- a/dialer.go +++ b/dialer.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "fmt" "net" + "strings" "time" ) @@ -119,8 +120,11 @@ func defaultDialTLSFunc(ctx context.Context, network, addr string) (net.Conn, er if tlsConfig.ServerName == "" && !tlsConfig.InsecureSkipVerify { host, _, err := net.SplitHostPort(addr) if err != nil { - // addr 可能没有端口,直接用 addr - host = addr + if idx := strings.LastIndex(addr, ":"); idx > 0 { + host = addr[:idx] + } else { + host = addr + } } tlsConfig = tlsConfig.Clone() // 避免修改原 config tlsConfig.ServerName = host diff --git a/request.go b/request.go index 8c48e66..486b9ef 100644 --- a/request.go +++ b/request.go @@ -84,12 +84,27 @@ func newRequest(ctx context.Context, urlStr string, method string, opts ...Reque // NewRequest 创建新请求 func NewRequest(url, method string, opts ...RequestOpt) (*Request, error) { - return newRequest(context.Background(), url, method, opts...) + req, err := newRequest(context.Background(), url, method, opts...) + if err != nil { + return nil, err + } + if req.err != nil { + return nil, req.err + } + return req, nil } // NewRequestWithContext 创建新请求(带 context) func NewRequestWithContext(ctx context.Context, url, method string, opts ...RequestOpt) (*Request, error) { - return newRequest(ctx, url, method, opts...) + req, err := newRequest(ctx, url, method, opts...) + if err != nil { + return nil, err + } + // 新增 + if req.err != nil { + return nil, req.err + } + return req, nil } // NewSimpleRequest 创建新请求(忽略错误,支持链式调用) @@ -190,7 +205,6 @@ func (r *Request) SetMethod(method string) *Request { if r.err != nil { return r } - method = strings.ToUpper(method) if !validMethod(method) { r.err = wrapError(ErrInvalidMethod, "method: %s", method) @@ -249,6 +263,10 @@ func (r *Request) SetRawRequest(httpReq *http.Request) *Request { } r.httpReq = httpReq r.doRaw = true + if httpReq == nil { + r.err = fmt.Errorf("httpReq cannot be nil") + return r + } return r } @@ -270,6 +288,29 @@ func (r *Request) SetAutoFetch(auto bool) *Request { return r } +// HTTPClient 获取底层 http.Client(只读) +func (r *Request) HTTPClient() (*http.Client, error) { + if r.err != nil { + return nil, r.err + } + + if r.httpClient != nil { + return r.httpClient, nil + } + + // 如果还没构建,先准备 + if err := r.prepare(); err != nil { + return nil, err + } + + return r.httpClient, nil +} + +// Client 获取关联的 Client(只读) +func (r *Request) Client() *Client { + return r.client +} + // Do 执行请求 func (r *Request) Do() (*Response, error) { // 检查累积的错误 diff --git a/request_body.go b/request_body.go index fdba83d..120d399 100644 --- a/request_body.go +++ b/request_body.go @@ -336,17 +336,18 @@ func (r *Request) prepare() error { if r.applied { return nil } - defer func() { r.applied = true }() + // 即使 raw 模式也要确保有 httpClient if r.httpClient == nil { var err error r.httpClient, err = r.buildHTTPClient() if err != nil { - return err + return err // ← 失败时不设置 applied } } // 原始模式不修改请求内容 if r.doRaw { + r.applied = true return nil } @@ -408,7 +409,8 @@ func (r *Request) prepare() error { // 注入配置到 context r.execCtx = injectRequestConfig(r.ctx, r.config) r.httpReq = r.httpReq.WithContext(r.execCtx) - + + r.applied = true return nil } diff --git a/response.go b/response.go index 1de8db4..d43d1ee 100644 --- a/response.go +++ b/response.go @@ -36,6 +36,9 @@ func (r *Response) Body() *Body { // Close 关闭响应体 func (r *Response) Close() error { + if r == nil { + return nil + } if r.body != nil && r.body.raw != nil { return r.body.raw.Close() } @@ -44,6 +47,9 @@ func (r *Response) Close() error { // CloseWithClient 关闭响应体并关闭空闲连接 func (r *Response) CloseWithClient() error { + if r == nil { + return nil + } if r.httpClient != nil { r.httpClient.CloseIdleConnections() }