starnet/request_config.go

270 lines
5.4 KiB
Go
Raw Permalink Normal View History

2026-03-08 20:19:40 +08:00
package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"time"
)
// SetTimeout sets the overall timeout for this request.
//
// The value is stored as-is in the network config. Per the original contract:
//   - timeout > 0: use this timeout
//   - timeout == 0: fall back to the client's default timeout
//   - timeout < 0: disable the timeout for this request (overriding
//     Client.Timeout, as if it were 0)
//
// It is a no-op if the request already carries an error.
func (r *Request) SetTimeout(timeout time.Duration) *Request {
	if r.err == nil {
		r.config.Network.Timeout = timeout
	}
	return r
}
// SetDialTimeout sets the timeout used when establishing the connection.
// It is a no-op if the request already carries an error.
func (r *Request) SetDialTimeout(timeout time.Duration) *Request {
	if r.err == nil {
		r.config.Network.DialTimeout = timeout
	}
	return r
}
// SetProxy sets the proxy for this request.
// The string is stored without validation; parsing happens elsewhere.
// It is a no-op if the request already carries an error.
func (r *Request) SetProxy(proxy string) *Request {
	if r.err == nil {
		r.config.Network.Proxy = proxy
	}
	return r
}
// SetDialFunc installs a custom dial function for this request.
// It is a no-op if the request already carries an error.
func (r *Request) SetDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) *Request {
	if r.err == nil {
		r.config.Network.DialFunc = fn
	}
	return r
}
// SetTLSConfig sets the TLS configuration for this request.
// The pointer is stored as given (not cloned), so later changes the caller
// makes to *tlsConfig are visible to this request.
// It is a no-op if the request already carries an error.
func (r *Request) SetTLSConfig(tlsConfig *tls.Config) *Request {
	if r.err == nil {
		r.config.TLS.Config = tlsConfig
	}
	return r
}
// SetSkipTLSVerify sets whether TLS verification should be skipped for this
// request. It is a no-op if the request already carries an error.
func (r *Request) SetSkipTLSVerify(skip bool) *Request {
	if r.err == nil {
		r.config.TLS.SkipVerify = skip
	}
	return r
}
// SetCustomIP replaces the list of custom IP addresses for this request
// (per the original description: connect to these IPs directly, skipping DNS).
//
// Every entry must parse with net.ParseIP. On the first invalid entry, r.err
// is set to an error wrapping ErrInvalidIP and the stored list is left
// unchanged. The slice is copied, so the request never shares a backing array
// with the caller. Later edits to ips, or AddCustomIP appends on another
// request built from the same slice, cannot leak in.
// It is a no-op if the request already carries an error.
func (r *Request) SetCustomIP(ips []string) *Request {
	if r.err != nil {
		return r
	}
	// Validate everything first so a failure leaves the config untouched.
	for _, ip := range ips {
		if net.ParseIP(ip) == nil {
			r.err = wrapError(ErrInvalidIP, "ip: %s", ip)
			return r
		}
	}
	// ips[:0:0] has zero capacity, so append always allocates a fresh array.
	// It also keeps a nil input nil and an empty input empty.
	r.config.DNS.CustomIP = append(ips[:0:0], ips...)
	return r
}
// AddCustomIP appends one custom IP address to this request's list.
// The address must parse with net.ParseIP; otherwise r.err is set to an error
// wrapping ErrInvalidIP and the list is left unchanged.
// It is a no-op if the request already carries an error.
func (r *Request) AddCustomIP(ip string) *Request {
	switch {
	case r.err != nil:
		// Keep the earlier error; nothing to do.
	case net.ParseIP(ip) == nil:
		r.err = wrapError(ErrInvalidIP, "ip: %s", ip)
	default:
		r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
	}
	return r
}
// SetCustomDNS replaces the list of custom DNS servers for this request.
//
// Each entry must be a bare IP address accepted by net.ParseIP, so a
// host:port form such as "8.8.8.8:53" is rejected. On the first invalid entry,
// r.err is set to an error wrapping ErrInvalidDNS and the stored list is left
// unchanged. The slice is copied, so the request never shares a backing array
// with the caller and later AddCustomDNS appends cannot overwrite other lists.
// It is a no-op if the request already carries an error.
func (r *Request) SetCustomDNS(dnsServers []string) *Request {
	if r.err != nil {
		return r
	}
	// Validate everything first so a failure leaves the config untouched.
	for _, server := range dnsServers {
		if net.ParseIP(server) == nil {
			r.err = wrapError(ErrInvalidDNS, "dns: %s", server)
			return r
		}
	}
	// dnsServers[:0:0] has zero capacity, so append always allocates a fresh
	// array. It also keeps a nil input nil and an empty input empty.
	r.config.DNS.CustomDNS = append(dnsServers[:0:0], dnsServers...)
	return r
}
// AddCustomDNS appends one custom DNS server to this request's list.
// The server must be a bare IP address accepted by net.ParseIP; otherwise
// r.err is set to an error wrapping ErrInvalidDNS and the list is unchanged.
// It is a no-op if the request already carries an error.
func (r *Request) AddCustomDNS(dns string) *Request {
	switch {
	case r.err != nil:
		// Keep the earlier error; nothing to do.
	case net.ParseIP(dns) == nil:
		r.err = wrapError(ErrInvalidDNS, "dns: %s", dns)
	default:
		r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
	}
	return r
}
// SetLookupFunc installs a custom DNS lookup function for this request.
// It is a no-op if the request already carries an error.
func (r *Request) SetLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request {
	if r.err == nil {
		r.config.DNS.LookupFunc = fn
	}
	return r
}
// SetBasicAuth stores the username and password for HTTP Basic
// authentication as a [username, password] pair.
// It is a no-op if the request already carries an error.
func (r *Request) SetBasicAuth(username, password string) *Request {
	if r.err == nil {
		r.config.BasicAuth = [2]string{0: username, 1: password}
	}
	return r
}
// SetContentLength sets the Content-Length value to use for this request.
// The value is stored as-is; it is not checked here.
// It is a no-op if the request already carries an error.
func (r *Request) SetContentLength(length int64) *Request {
	if r.err == nil {
		r.config.ContentLength = length
	}
	return r
}
// SetAutoCalcContentLength sets whether Content-Length should be calculated
// automatically.
//
// Warning: when enabled, the entire body is read into memory.
//
// The option is rejected in raw mode. Any call (even with auto == false) on a
// raw-mode request sets r.err. It is a no-op if the request already carries
// an error.
func (r *Request) SetAutoCalcContentLength(auto bool) *Request {
	switch {
	case r.err != nil:
		// Keep the earlier error; nothing to do.
	case r.doRaw:
		r.err = fmt.Errorf("cannot set auto calc content length in raw mode")
	default:
		r.config.AutoCalcContentLength = auto
	}
	return r
}
// SetTransport installs a caller-supplied *http.Transport for this request
// and marks the configuration as using a custom transport.
//
// NOTE(review): a nil transport still sets CustomTransport to true; confirm
// the code that consumes r.config handles that case.
// It is a no-op if the request already carries an error.
func (r *Request) SetTransport(transport *http.Transport) *Request {
	if r.err == nil {
		r.config.CustomTransport = true
		r.config.Transport = transport
	}
	return r
}
// SetUploadProgress registers the file-upload progress callback for this
// request. It is a no-op if the request already carries an error.
func (r *Request) SetUploadProgress(fn UploadProgressFunc) *Request {
	if r.err == nil {
		r.config.UploadProgress = fn
	}
	return r
}
// AddQuery appends value to the query parameter key, keeping any values
// already present for that key.
//
// If the query map is nil it is allocated first, because writing to a nil map
// panics. It is a no-op if the request already carries an error.
func (r *Request) AddQuery(key, value string) *Request {
	if r.err != nil {
		return r
	}
	if r.config.Queries == nil {
		r.config.Queries = make(map[string][]string)
	}
	r.config.Queries[key] = append(r.config.Queries[key], value)
	return r
}
// SetQuery sets the query parameter key to the single value, replacing any
// values previously stored for that key.
//
// If the query map is nil it is allocated first, because writing to a nil map
// panics. It is a no-op if the request already carries an error.
func (r *Request) SetQuery(key, value string) *Request {
	if r.err != nil {
		return r
	}
	if r.config.Queries == nil {
		r.config.Queries = make(map[string][]string)
	}
	r.config.Queries[key] = []string{value}
	return r
}
// SetQueries replaces all query parameters with a copy of queries.
//
// The map and each of its value slices are copied. This means that:
//   - a nil queries still leaves a non-nil, writable map, so later
//     AddQuery/SetQuery calls do not panic on a nil-map write;
//   - later changes on the request never modify the caller's map, and appends
//     never overwrite the caller's value slices.
//
// It is a no-op if the request already carries an error.
func (r *Request) SetQueries(queries map[string][]string) *Request {
	if r.err != nil {
		return r
	}
	cloned := make(map[string][]string, len(queries))
	for key, values := range queries {
		// values[:0:0] has zero capacity, so append always allocates a fresh
		// array while keeping nil/empty values as they were.
		cloned[key] = append(values[:0:0], values...)
	}
	r.config.Queries = cloned
	return r
}
// AddQueries appends one value per entry of queries to the corresponding
// query parameter, keeping values already present for each key.
//
// If there is something to add and the query map is nil, the map is allocated
// first, because writing to a nil map panics. It is a no-op if the request
// already carries an error.
func (r *Request) AddQueries(queries map[string]string) *Request {
	if r.err != nil || len(queries) == 0 {
		return r
	}
	if r.config.Queries == nil {
		r.config.Queries = make(map[string][]string, len(queries))
	}
	for k, v := range queries {
		r.config.Queries[k] = append(r.config.Queries[k], v)
	}
	return r
}
// DeleteQuery removes the query parameter key and all of its values.
// Deleting a missing key (or from a nil map) does nothing.
// It is a no-op if the request already carries an error.
func (r *Request) DeleteQuery(key string) *Request {
	if r.err == nil {
		delete(r.config.Queries, key)
	}
	return r
}
// DeleteQueryValue removes every occurrence of value from the query
// parameter key. If no values remain afterwards, the key itself is deleted.
//
// The filtered values go into a newly allocated slice, so the slice previously
// stored for key is not modified in place. Missing keys are ignored. It is a
// no-op if the request already carries an error.
func (r *Request) DeleteQueryValue(key, value string) *Request {
	if r.err != nil {
		return r
	}
	current, present := r.config.Queries[key]
	if !present {
		return r
	}
	kept := make([]string, 0, len(current))
	for i := range current {
		if current[i] == value {
			continue
		}
		kept = append(kept, current[i])
	}
	if len(kept) > 0 {
		r.config.Queries[key] = kept
		return r
	}
	delete(r.config.Queries, key)
	return r
}