package starnet

import (
	"context"
	"crypto/tls"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"time"
)

// requestMutation is a single change applied to a Request. A non-nil error
// rejects the change.
type requestMutation func(*Request) error

// applyMutation runs mutation against r and returns r so calls can be chained.
//
// It does nothing when r is nil or already carries an error in r.err, so the
// first failure in a chain is the one that is kept. If mutation fails, its
// error is stored in r.err and prepared state is left untouched. On success,
// r.invalidatePreparedState is called so previously prepared state is discarded.
func (r *Request) applyMutation(mutation requestMutation) *Request {
	if r == nil || r.err != nil {
		return r
	}
	if err := mutation(r); err != nil {
		r.err = err
		return r
	}
	r.invalidatePreparedState()
	return r
}

// requestOptFromMutation adapts a requestMutation to a RequestOpt.
//
// A nil *Request is ignored (nil is returned) instead of being passed to
// mutation. Unlike applyMutation, the returned option neither consults r.err
// nor calls invalidatePreparedState.
// NOTE(review): presumably whoever applies RequestOpts handles both; confirm.
func requestOptFromMutation(mutation requestMutation) RequestOpt {
	return func(r *Request) error {
		if r == nil {
			return nil
		}
		return mutation(r)
	}
}

// validateCustomIPs checks that every entry is an IP literal accepted by
// net.ParseIP. For the first entry that is not, it returns ErrInvalidIP
// wrapped (via wrapError) with the offending value. An empty slice is valid.
func validateCustomIPs(ips []string) error {
	for _, ip := range ips {
		if net.ParseIP(ip) == nil {
			return wrapError(ErrInvalidIP, "ip: %s", ip)
		}
	}
	return nil
}

// validateCustomDNS checks that every DNS server entry is a bare IP literal
// accepted by net.ParseIP. Host names and "host:port" forms are therefore
// rejected. For the first invalid entry it returns ErrInvalidDNS wrapped (via
// wrapError) with the offending value. An empty slice is valid.
func validateCustomDNS(dnsServers []string) error {
	for _, dns := range dnsServers {
		if net.ParseIP(dns) == nil {
			return wrapError(ErrInvalidDNS, "dns: %s", dns)
		}
	}
	return nil
}

// parseProxyURL parses a proxy address.
//
// An empty string means "no proxy" and yields (nil, nil). Otherwise the value
// must parse with url.Parse and carry both a scheme and a host; a value such
// as "127.0.0.1:8080" without a scheme is rejected.
func parseProxyURL(proxy string) (*url.URL, error) {
	if proxy == "" {
		return nil, nil
	}
	proxyURL, err := url.Parse(proxy)
	if err != nil {
		return nil, wrapError(err, "parse proxy url")
	}
	if proxyURL.Scheme == "" {
		return nil, fmt.Errorf("proxy scheme is required: %s", proxy)
	}
	if proxyURL.Host == "" {
		return nil, fmt.Errorf("proxy host is required: %s", proxy)
	}
	return proxyURL, nil
}

// mutateTimeout sets r.config.Network.Timeout. No range check is done here.
func mutateTimeout(timeout time.Duration) requestMutation {
	return func(r *Request) error {
		r.config.Network.Timeout = timeout
		return nil
	}
}

// mutateDialTimeout sets r.config.Network.DialTimeout. No range check is done here.
func mutateDialTimeout(timeout time.Duration) requestMutation {
	return func(r *Request) error {
		r.config.Network.DialTimeout = timeout
		return nil
	}
}

// mutateProxy validates proxy with parseProxyURL and, if valid, stores the
// original string (not the parsed URL) in r.config.Network.Proxy. An empty
// string passes validation and is stored as-is.
func mutateProxy(proxy string) requestMutation {
	return func(r *Request) error {
		if _, err := parseProxyURL(proxy); err != nil {
			return err
		}
		r.config.Network.Proxy = proxy
		return nil
	}
}

// mutateDialFunc stores a custom dial function in r.config.Network.DialFunc.
func mutateDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) requestMutation {
	return func(r *Request) error {
		r.config.Network.DialFunc = fn
		return nil
	}
}

// mutateTLSConfig stores tlsConfig in r.config.TLS.Config. The pointer is
// kept as-is (not cloned), so later changes to *tlsConfig by the caller remain
// visible.
func mutateTLSConfig(tlsConfig *tls.Config) requestMutation {
	return func(r *Request) error {
		r.config.TLS.Config = tlsConfig
		return nil
	}
}

// mutateTLSServerName sets r.config.TLS.ServerName.
func mutateTLSServerName(serverName string) requestMutation {
	return func(r *Request) error {
		r.config.TLS.ServerName = serverName
		return nil
	}
}

// mutateTraceHooks stores hooks on the request itself (r.traceHooks), not in
// r.config.
func mutateTraceHooks(hooks *TraceHooks) requestMutation {
	return func(r *Request) error {
		r.traceHooks = hooks
		return nil
	}
}

// mutateSkipTLSVerify sets r.config.TLS.SkipVerify.
func mutateSkipTLSVerify(skip bool) requestMutation {
	return func(r *Request) error {
		r.config.TLS.SkipVerify = skip
		return nil
	}
}

// mutateCustomIP replaces r.config.DNS.CustomIP with a copy of ips (via
// cloneStringSlice) after validating every entry. On a validation error the
// existing list is left unchanged.
func mutateCustomIP(ips []string) requestMutation {
	return func(r *Request) error {
		if err := validateCustomIPs(ips); err != nil {
			return err
		}
		r.config.DNS.CustomIP = cloneStringSlice(ips)
		return nil
	}
}

// mutateAddCustomIP validates ip and appends it to r.config.DNS.CustomIP.
func mutateAddCustomIP(ip string) requestMutation {
	return func(r *Request) error {
		if err := validateCustomIPs([]string{ip}); err != nil {
			return err
		}
		current := r.config.DNS.CustomIP
		// The full slice expression caps capacity at the length, so append always
		// allocates a new backing array. This keeps the add from writing into an
		// array that another holder of the old slice header might share. The
		// setters above also store defensive copies.
		r.config.DNS.CustomIP = append(current[:len(current):len(current)], ip)
		return nil
	}
}

// mutateCustomDNS replaces r.config.DNS.CustomDNS with a copy of dnsServers
// (via cloneStringSlice) after validating every entry. On a validation error
// the existing list is left unchanged.
func mutateCustomDNS(dnsServers []string) requestMutation {
	return func(r *Request) error {
		if err := validateCustomDNS(dnsServers); err != nil {
			return err
		}
		r.config.DNS.CustomDNS = cloneStringSlice(dnsServers)
		return nil
	}
}

// mutateAddCustomDNS validates dns and appends it to r.config.DNS.CustomDNS.
func mutateAddCustomDNS(dns string) requestMutation {
	return func(r *Request) error {
		if err := validateCustomDNS([]string{dns}); err != nil {
			return err
		}
		current := r.config.DNS.CustomDNS
		// Capped append: always reallocates so a possibly shared backing array is
		// never written through (see mutateAddCustomIP).
		r.config.DNS.CustomDNS = append(current[:len(current):len(current)], dns)
		return nil
	}
}

// mutateLookupFunc stores a custom host lookup function in
// r.config.DNS.LookupFunc.
func mutateLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) requestMutation {
	return func(r *Request) error {
		r.config.DNS.LookupFunc = fn
		return nil
	}
}

// mutateBasicAuth stores the credentials as a [username, password] pair in
// r.config.BasicAuth.
func mutateBasicAuth(username, password string) requestMutation {
	return func(r *Request) error {
		r.config.BasicAuth = [2]string{username, password}
		return nil
	}
}

// mutateContentLength sets r.config.ContentLength. No range check is done here.
func mutateContentLength(length int64) requestMutation {
	return func(r *Request) error {
		r.config.ContentLength = length
		return nil
	}
}

// mutateAutoCalcContentLength sets r.config.AutoCalcContentLength. The change
// is rejected when the request is in raw mode (r.doRaw).
func mutateAutoCalcContentLength(auto bool) requestMutation {
	return func(r *Request) error {
		if r.doRaw {
			return fmt.Errorf("cannot set auto calc content length in raw mode")
		}
		r.config.AutoCalcContentLength = auto
		return nil
	}
}

// mutateTransport stores transport in r.config.Transport and marks it as
// custom (r.config.CustomTransport = true).
// NOTE(review): a nil transport is also marked custom; confirm that downstream
// code treats a nil custom transport as intended.
func mutateTransport(transport *http.Transport) requestMutation {
	return func(r *Request) error {
		r.config.Transport = transport
		r.config.CustomTransport = true
		return nil
	}
}

// mutateUploadProgress sets the upload progress callback r.config.UploadProgress.
func mutateUploadProgress(fn UploadProgressFunc) requestMutation {
	return func(r *Request) error {
		r.config.UploadProgress = fn
		return nil
	}
}

// mutateAutoFetch sets r.autoFetch on the request itself.
func mutateAutoFetch(auto bool) requestMutation {
	return func(r *Request) error {
		r.autoFetch = auto
		return nil
	}
}

// mutateMaxRespBodyBytes sets r.config.MaxRespBodyBytes. Negative values are
// rejected; zero is accepted.
func mutateMaxRespBodyBytes(maxBytes int64) requestMutation {
	return func(r *Request) error {
		if maxBytes < 0 {
			return fmt.Errorf("max response body bytes must be >= 0")
		}
		r.config.MaxRespBodyBytes = maxBytes
		return nil
	}
}

// mutateContext replaces the request context with normalizeContext(ctx).
//
// The normalized context is stored in r.ctx. It is also propagated to the
// existing raw template (only in raw mode) and to r.httpReq when present,
// via http.Request.WithContext, which returns shallow copies.
func mutateContext(ctx context.Context) requestMutation {
	return func(r *Request) error {
		// Use a local so the closure's captured ctx is never reassigned. The
		// returned mutation may be applied more than once.
		normalized := normalizeContext(ctx)
		r.ctx = normalized
		if r.doRaw && r.rawTemplate != nil {
			r.rawTemplate = r.rawTemplate.WithContext(normalized)
		}
		if r.httpReq != nil {
			r.httpReq = r.httpReq.WithContext(normalized)
		}
		return nil
	}
}

// mutateRawRequest switches r into raw mode using the caller-supplied
// httpReq.
//
// The same *http.Request is stored as both r.httpReq and r.rawTemplate. Its
// context (normalized), method and URL string are copied onto r, and the
// request is flagged as externally sourced (rawSourceExternal). A nil httpReq
// is rejected. When httpReq.URL is nil, r.url keeps its previous value.
func mutateRawRequest(httpReq *http.Request) requestMutation {
	return func(r *Request) error {
		if httpReq == nil {
			return fmt.Errorf("httpReq cannot be nil")
		}
		r.httpReq = httpReq
		r.rawTemplate = httpReq
		r.ctx = normalizeContext(httpReq.Context())
		r.method = httpReq.Method
		if httpReq.URL != nil {
			r.url = httpReq.URL.String()
		}
		r.doRaw = true
		r.rawSourceExternal = true
		return nil
	}
}

// mutateAddQuery appends value to the query parameter key, keeping any
// existing values.
func mutateAddQuery(key, value string) requestMutation {
	return func(r *Request) error {
		// Writing to a nil map panics. Queries may be nil, e.g. if
		// cloneStringMapSlice(nil) yields nil after mutateSetQueries(nil).
		if r.config.Queries == nil {
			r.config.Queries = make(map[string][]string)
		}
		current := r.config.Queries[key]
		// Capped append: always reallocates so a possibly shared backing array is
		// never written through.
		r.config.Queries[key] = append(current[:len(current):len(current)], value)
		return nil
	}
}

// mutateSetQuery replaces all values of query parameter key with value.
func mutateSetQuery(key, value string) requestMutation {
	return func(r *Request) error {
		// Guard against a nil map; see mutateAddQuery.
		if r.config.Queries == nil {
			r.config.Queries = make(map[string][]string)
		}
		r.config.Queries[key] = []string{value}
		return nil
	}
}

// mutateSetQueries replaces the whole query map with a copy of queries made by
// cloneStringMapSlice.
func mutateSetQueries(queries map[string][]string) requestMutation {
	return func(r *Request) error {
		r.config.Queries = cloneStringMapSlice(queries)
		return nil
	}
}

// mutateAddQueries appends one value per key from queries, keeping existing
// values. Map iteration order does not matter here because each key is
// handled independently.
func mutateAddQueries(queries map[string]string) requestMutation {
	return func(r *Request) error {
		// Guard against a nil map; see mutateAddQuery.
		if r.config.Queries == nil {
			r.config.Queries = make(map[string][]string)
		}
		for key, value := range queries {
			current := r.config.Queries[key]
			r.config.Queries[key] = append(current[:len(current):len(current)], value)
		}
		return nil
	}
}

// mutateDeleteQuery removes query parameter key and all of its values.
// Deleting an absent key is a no-op.
func mutateDeleteQuery(key string) requestMutation {
	return func(r *Request) error {
		delete(r.config.Queries, key)
		return nil
	}
}

// mutateDeleteQueryValue removes every occurrence of value from query
// parameter key. The key itself is deleted when no values remain.
//
// A new slice is built instead of filtering in place. Presumably this avoids
// modifying a backing array that another holder of the old slice may see.
func mutateDeleteQueryValue(key, value string) requestMutation {
	return func(r *Request) error {
		values, ok := r.config.Queries[key]
		if !ok {
			return nil
		}
		newValues := make([]string, 0, len(values))
		for _, item := range values {
			if item != value {
				newValues = append(newValues, item)
			}
		}
		if len(newValues) == 0 {
			delete(r.config.Queries, key)
			return nil
		}
		r.config.Queries[key] = newValues
		return nil
	}
}