391 lines
8.8 KiB
Go
391 lines
8.8 KiB
Go
package starnet
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"encoding/json"
|
||
"io"
|
||
"net"
|
||
"net/http"
|
||
"os"
|
||
"time"
|
||
)
|
||
|
||
// WithTimeout 设置请求总超时时间
|
||
// timeout > 0: 使用该超时
|
||
// timeout = 0: 使用 Client 默认超时
|
||
// timeout < 0: 禁用本次请求超时(覆盖 Client.Timeout=0)
|
||
func WithTimeout(timeout time.Duration) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Network.Timeout = timeout
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithDialTimeout 设置连接超时时间
|
||
func WithDialTimeout(timeout time.Duration) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Network.DialTimeout = timeout
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithProxy 设置代理
|
||
func WithProxy(proxy string) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Network.Proxy = proxy
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithDialFunc 设置自定义 Dial 函数
|
||
func WithDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Network.DialFunc = fn
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithTLSConfig 设置 TLS 配置
|
||
func WithTLSConfig(tlsConfig *tls.Config) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.TLS.Config = tlsConfig
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithSkipTLSVerify 设置是否跳过 TLS 验证
|
||
func WithSkipTLSVerify(skip bool) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.TLS.SkipVerify = skip
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithCustomIP 设置自定义 IP
|
||
func WithCustomIP(ips []string) RequestOpt {
|
||
return func(r *Request) error {
|
||
for _, ip := range ips {
|
||
if net.ParseIP(ip) == nil {
|
||
return wrapError(ErrInvalidIP, "ip: %s", ip)
|
||
}
|
||
}
|
||
r.config.DNS.CustomIP = ips
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithAddCustomIP 添加自定义 IP
|
||
func WithAddCustomIP(ip string) RequestOpt {
|
||
return func(r *Request) error {
|
||
if net.ParseIP(ip) == nil {
|
||
return wrapError(ErrInvalidIP, "ip: %s", ip)
|
||
}
|
||
r.config.DNS.CustomIP = append(r.config.DNS.CustomIP, ip)
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithCustomDNS 设置自定义 DNS 服务器
|
||
func WithCustomDNS(dnsServers []string) RequestOpt {
|
||
return func(r *Request) error {
|
||
for _, dns := range dnsServers {
|
||
if net.ParseIP(dns) == nil {
|
||
return wrapError(ErrInvalidDNS, "dns: %s", dns)
|
||
}
|
||
}
|
||
r.config.DNS.CustomDNS = dnsServers
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithAddCustomDNS 添加自定义 DNS 服务器
|
||
func WithAddCustomDNS(dns string) RequestOpt {
|
||
return func(r *Request) error {
|
||
if net.ParseIP(dns) == nil {
|
||
return wrapError(ErrInvalidDNS, "dns: %s", dns)
|
||
}
|
||
r.config.DNS.CustomDNS = append(r.config.DNS.CustomDNS, dns)
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithLookupFunc 设置自定义 DNS 解析函数
|
||
func WithLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.DNS.LookupFunc = fn
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithHeader 设置 Header
|
||
func WithHeader(key, value string) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Headers.Set(key, value)
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithHeaders 批量设置 Headers
|
||
func WithHeaders(headers map[string]string) RequestOpt {
|
||
return func(r *Request) error {
|
||
for k, v := range headers {
|
||
r.config.Headers.Set(k, v)
|
||
}
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithContentType 设置 Content-Type
|
||
func WithContentType(contentType string) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Headers.Set("Content-Type", contentType)
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithUserAgent 设置 User-Agent
|
||
func WithUserAgent(userAgent string) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Headers.Set("User-Agent", userAgent)
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithBearerToken 设置 Bearer Token
|
||
func WithBearerToken(token string) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Headers.Set("Authorization", "Bearer "+token)
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithBasicAuth 设置 Basic 认证
|
||
func WithBasicAuth(username, password string) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.BasicAuth = [2]string{username, password}
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithCookie 添加 Cookie
|
||
func WithCookie(name, value, path string) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
||
Name: name,
|
||
Value: value,
|
||
Path: path,
|
||
})
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithSimpleCookie 添加简单 Cookie(path 为 /)
|
||
func WithSimpleCookie(name, value string) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
||
Name: name,
|
||
Value: value,
|
||
Path: "/",
|
||
})
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithCookies 批量添加 Cookies
|
||
func WithCookies(cookies map[string]string) RequestOpt {
|
||
return func(r *Request) error {
|
||
for name, value := range cookies {
|
||
r.config.Cookies = append(r.config.Cookies, &http.Cookie{
|
||
Name: name,
|
||
Value: value,
|
||
Path: "/",
|
||
})
|
||
}
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithBody 设置请求体(字节)
|
||
func WithBody(body []byte) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Body.Bytes = body
|
||
r.config.Body.Reader = nil
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithBodyString 设置请求体(字符串)
|
||
func WithBodyString(body string) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Body.Bytes = []byte(body)
|
||
r.config.Body.Reader = nil
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithBodyReader 设置请求体(Reader)
|
||
func WithBodyReader(reader io.Reader) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Body.Reader = reader
|
||
r.config.Body.Bytes = nil
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithJSON 设置 JSON 请求体
|
||
func WithJSON(v interface{}) RequestOpt {
|
||
return func(r *Request) error {
|
||
data, err := json.Marshal(v)
|
||
if err != nil {
|
||
return wrapError(err, "marshal json")
|
||
}
|
||
r.config.Headers.Set("Content-Type", ContentTypeJSON)
|
||
r.config.Body.Bytes = data
|
||
r.config.Body.Reader = nil
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithFormData 设置表单数据
|
||
func WithFormData(data map[string][]string) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Body.FormData = data
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithFormDataMap 设置表单数据(简化版)
|
||
func WithFormDataMap(data map[string]string) RequestOpt {
|
||
return func(r *Request) error {
|
||
for k, v := range data {
|
||
r.config.Body.FormData[k] = []string{v}
|
||
}
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithAddFormData 添加表单数据
|
||
func WithAddFormData(key, value string) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithFile 添加文件
|
||
func WithFile(formName, filePath string) RequestOpt {
|
||
return func(r *Request) error {
|
||
stat, err := os.Stat(filePath)
|
||
if err != nil {
|
||
return wrapError(ErrFileNotFound, "file: %s", filePath)
|
||
}
|
||
|
||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||
FormName: formName,
|
||
FileName: stat.Name(),
|
||
FilePath: filePath,
|
||
FileSize: stat.Size(),
|
||
FileType: ContentTypeOctetStream,
|
||
})
|
||
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithFileStream 添加文件流
|
||
func WithFileStream(formName, fileName string, size int64, reader io.Reader) RequestOpt {
|
||
return func(r *Request) error {
|
||
if reader == nil {
|
||
return ErrNilReader
|
||
}
|
||
|
||
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
||
FormName: formName,
|
||
FileName: fileName,
|
||
FileData: reader,
|
||
FileSize: size,
|
||
FileType: ContentTypeOctetStream,
|
||
})
|
||
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithQuery 添加查询参数
|
||
func WithQuery(key, value string) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Queries[key] = append(r.config.Queries[key], value)
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithQueries 批量添加查询参数
|
||
func WithQueries(queries map[string]string) RequestOpt {
|
||
return func(r *Request) error {
|
||
for k, v := range queries {
|
||
r.config.Queries[k] = append(r.config.Queries[k], v)
|
||
}
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithContentLength 设置 Content-Length
|
||
func WithContentLength(length int64) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.ContentLength = length
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithAutoCalcContentLength 设置是否自动计算 Content-Length
|
||
func WithAutoCalcContentLength(auto bool) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.AutoCalcContentLength = auto
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithUploadProgress 设置文件上传进度回调
|
||
func WithUploadProgress(fn UploadProgressFunc) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.UploadProgress = fn
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithTransport 设置自定义 Transport
|
||
func WithTransport(transport *http.Transport) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.config.Transport = transport
|
||
r.config.CustomTransport = true
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithAutoFetch 设置是否自动获取响应体
|
||
func WithAutoFetch(auto bool) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.autoFetch = auto
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithRawRequest 设置原始请求
|
||
func WithRawRequest(httpReq *http.Request) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.httpReq = httpReq
|
||
r.doRaw = true
|
||
return nil
|
||
}
|
||
}
|
||
|
||
// WithContext 设置 context
|
||
func WithContext(ctx context.Context) RequestOpt {
|
||
return func(r *Request) error {
|
||
r.ctx = ctx
|
||
r.httpReq = r.httpReq.WithContext(ctx)
|
||
return nil
|
||
}
|
||
}
|