package starnet import ( "context" "crypto/tls" "encoding/json" "io" "net" "net/http" "os" "time" ) // WithTimeout 设置请求总超时时间 // timeout > 0: 使用该超时 // timeout = 0: 使用 Client 默认超时 // timeout < 0: 禁用本次请求超时(覆盖 Client.Timeout=0) func WithTimeout(timeout time.Duration) RequestOpt { return func(r *Request) error { r.config.Network.Timeout = timeout return nil } } // WithDialTimeout 设置连接超时时间 func WithDialTimeout(timeout time.Duration) RequestOpt { return func(r *Request) error { r.config.Network.DialTimeout = timeout return nil } } // WithProxy 设置代理 func WithProxy(proxy string) RequestOpt { return func(r *Request) error { r.config.Network.Proxy = proxy return nil } } // WithDialFunc 设置自定义 Dial 函数 func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt { return func(r *Request) error { r.config.Network.DialFunc = fn return nil } } // WithTLSConfig 设置 TLS 配置 func WithTLSConfig(tlsConfig *tls.Config) RequestOpt { return func(r *Request) error { r.config.TLS.Config = tlsConfig return nil } } // WithSkipTLSVerify 设置是否跳过 TLS 验证 func WithSkipTLSVerify(skip bool) RequestOpt { return func(r *Request) error { r.config.TLS.SkipVerify = skip return nil } } // WithCustomIP 设置自定义 IP func WithCustomIP(ips []string) RequestOpt { return func(r *Request) error { for _, ip := range ips { if net.ParseIP(ip) == nil { return wrapError(ErrInvalidIP, "ip: %s", ip) } } r.config.DNS.CustomIP = ips return nil } } // WithAddCustomIP 添加自定义 IP func WithAddCustomIP(ip string) RequestOpt { return func(r *Request) error { if net.ParseIP(ip) == nil { return wrapError(ErrInvalidIP, "ip: %s", ip) } r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip) return nil } } // WithCustomDNS 设置自定义 DNS 服务器 func WithCustomDNS(dnsServers []string) RequestOpt { return func(r *Request) error { for _, dns := range dnsServers { if net.ParseIP(dns) == nil { return wrapError(ErrInvalidDNS, "dns: %s", dns) } } r.config.DNS.CustomDNS = dnsServers return nil } } // WithAddCustomDNS 添加自定义 DNS 服务器 func WithAddCustomDNS(dns string) RequestOpt { return func(r *Request) error { if net.ParseIP(dns) == nil { return wrapError(ErrInvalidDNS, "dns: %s", dns) } r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns) return nil } } // WithLookupFunc 设置自定义 DNS 解析函数 func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt { return func(r *Request) error { r.config.DNS.LookupFunc = fn return nil } } // WithHeader 设置 Header func WithHeader(key, value string) RequestOpt { return func(r *Request) error { r.config.Headers.Set(key, value) return nil } } // WithHeaders 批量设置 Headers func WithHeaders(headers map[string]string) RequestOpt { return func(r *Request) error { for k, v := range headers { r.config.Headers.Set(k, v) } return nil } } // WithContentType 设置 Content-Type func WithContentType(contentType string) RequestOpt { return func(r *Request) error { r.config.Headers.Set("Content-Type", contentType) return nil } } // WithUserAgent 设置 User-Agent func WithUserAgent(userAgent string) RequestOpt { return func(r *Request) error { r.config.Headers.Set("User-Agent", userAgent) return nil } } // WithBearerToken 设置 Bearer Token func WithBearerToken(token string) RequestOpt { return func(r *Request) error { r.config.Headers.Set("Authorization", "Bearer "+token) return nil } } // WithBasicAuth 设置 Basic 认证 func WithBasicAuth(username, password string) RequestOpt { return func(r *Request) error { r.config.BasicAuth = [2]string{username, password} return nil } } // WithCookie 添加 Cookie func WithCookie(name, value, path string) RequestOpt { return func(r *Request) error { r.config.Cookies = append(r.config.Cookies, &http.Cookie{ Name: name, Value: value, Path: path, }) return nil } } // WithSimpleCookie 添加简单 Cookie(path 为 /) func WithSimpleCookie(name, value string) RequestOpt { return func(r *Request) error { r.config.Cookies = append(r.config.Cookies, &http.Cookie{ Name: name, Value: value, Path: "/", }) return nil } } // WithCookies 批量添加 Cookies func WithCookies(cookies map[string]string) RequestOpt { return func(r *Request) error { for name, value := range cookies { r.config.Cookies = append(r.config.Cookies, &http.Cookie{ Name: name, Value: value, Path: "/", }) } return nil } } // WithBody 设置请求体(字节) func WithBody(body []byte) RequestOpt { return func(r *Request) error { r.config.Body.Bytes = body r.config.Body.Reader = nil return nil } } // WithBodyString 设置请求体(字符串) func WithBodyString(body string) RequestOpt { return func(r *Request) error { r.config.Body.Bytes = []byte(body) r.config.Body.Reader = nil return nil } } // WithBodyReader 设置请求体(Reader) func WithBodyReader(reader io.Reader) RequestOpt { return func(r *Request) error { r.config.Body.Reader = reader r.config.Body.Bytes = nil return nil } } // WithJSON 设置 JSON 请求体 func WithJSON(v interface{}) RequestOpt { return func(r *Request) error { data, err := json.Marshal(v) if err != nil { return wrapError(err, "marshal json") } r.config.Headers.Set("Content-Type", ContentTypeJSON) r.config.Body.Bytes = data r.config.Body.Reader = nil return nil } } // WithFormData 设置表单数据 func WithFormData(data map[string][]string) RequestOpt { return func(r *Request) error { r.config.Body.FormData = data return nil } } // WithFormDataMap 设置表单数据(简化版) func WithFormDataMap(data map[string]string) RequestOpt { return func(r *Request) error { for k, v := range data { r.config.Body.FormData[k] = []string{v} } return nil } } // WithAddFormData 添加表单数据 func WithAddFormData(key, value string) RequestOpt { return func(r *Request) error { r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value) return nil } } // WithFile 添加文件 func WithFile(formName, filePath string) RequestOpt { return func(r *Request) error { stat, err := os.Stat(filePath) if err != nil { return wrapError(ErrFileNotFound, "file: %s", filePath) } r.config.Body.Files = append(r.config.Body.Files, RequestFile{ FormName: formName, FileName: stat.Name(), FilePath: filePath, FileSize: stat.Size(), FileType: ContentTypeOctetStream, }) return nil } } // WithFileStream 添加文件流 func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt { return func(r *Request) error { if reader == nil { return ErrNilReader } r.config.Body.Files = append(r.config.Body.Files, RequestFile{ FormName: formName, FileName: fileName, FileData: reader, FileSize: size, FileType: ContentTypeOctetStream, }) return nil } } // WithQuery 添加查询参数 func WithQuery(key, value string) RequestOpt { return func(r *Request) error { r.config.Queries[key] = append(r.config.Queries[key], value) return nil } } // WithQueries 批量添加查询参数 func WithQueries(queries map[string]string) RequestOpt { return func(r *Request) error { for k, v := range queries { r.config.Queries[k] = append(r.config.Queries[k], v) } return nil } } // WithContentLength 设置 Content-Length func WithContentLength(length int64) RequestOpt { return func(r *Request) error { r.config.ContentLength = length return nil } } // WithAutoCalcContentLength 设置是否自动计算 Content-Length func WithAutoCalcContentLength(auto bool) RequestOpt { return func(r *Request) error { r.config.AutoCalcContentLength = auto return nil } } // WithUploadProgress 设置文件上传进度回调 func WithUploadProgress(fn UploadProgressFunc) RequestOpt { return func(r *Request) error { r.config.UploadProgress = fn return nil } } // WithTransport 设置自定义 Transport func WithTransport(transport *http.Transport) RequestOpt { return func(r *Request) error { r.config.Transport = transport r.config.CustomTransport = true return nil } } // WithAutoFetch 设置是否自动获取响应体 func WithAutoFetch(auto bool) RequestOpt { return func(r *Request) error { r.autoFetch = auto return nil } } // WithRawRequest 设置原始请求 func WithRawRequest(httpReq *http.Request) RequestOpt { return func(r *Request) error { r.httpReq = httpReq r.doRaw = true return nil } } // WithContext 设置 context func WithContext(ctx context.Context) RequestOpt { return func(r *Request) error { r.ctx = ctx r.httpReq = r.httpReq.WithContext(ctx) return nil } }