starnet/request_mutation.go

327 lines
6.8 KiB
Go
Raw Permalink Normal View History

package starnet
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"time"
)
// requestMutation is one change to apply to a Request. Returning a non-nil
// error rejects the change; applyMutation stores that error on the Request.
type requestMutation func(req *Request) error
// applyMutation runs mutation against r and returns r so calls can be chained.
//
// It is a no-op when r is nil or r already carries an error from an earlier
// step. When mutation fails, its error is recorded in r.err and
// invalidatePreparedState is NOT called. When it succeeds,
// invalidatePreparedState is called.
func (r *Request) applyMutation(mutation requestMutation) *Request {
	if r != nil && r.err == nil {
		if err := mutation(r); err != nil {
			r.err = err
		} else {
			r.invalidatePreparedState()
		}
	}
	return r
}
// requestOptFromMutation adapts mutation into a RequestOpt.
//
// The returned option does nothing (and reports no error) for a nil *Request.
// Otherwise it returns mutation's result unchanged. Unlike applyMutation, it
// neither records the error on the request nor calls invalidatePreparedState.
func requestOptFromMutation(mutation requestMutation) RequestOpt {
	return func(req *Request) error {
		if req != nil {
			return mutation(req)
		}
		return nil
	}
}
// validateCustomIPs checks that every entry of ips is an IP literal accepted by
// net.ParseIP. It returns an error built by wrapError from ErrInvalidIP for the
// first rejected entry, or nil if all entries parse (including an empty slice).
func validateCustomIPs(ips []string) error {
	for i := range ips {
		if candidate := ips[i]; net.ParseIP(candidate) == nil {
			return wrapError(ErrInvalidIP, "ip: %s", candidate)
		}
	}
	return nil
}
// validateCustomDNS checks that every DNS server entry is a bare IP literal
// accepted by net.ParseIP. Entries with a port, such as "1.1.1.1:53", are
// rejected because net.ParseIP rejects them. It returns an error built by
// wrapError from ErrInvalidDNS for the first bad entry, or nil otherwise.
func validateCustomDNS(dnsServers []string) error {
	for i := range dnsServers {
		if server := dnsServers[i]; net.ParseIP(server) == nil {
			return wrapError(ErrInvalidDNS, "dns: %s", server)
		}
	}
	return nil
}
// parseProxyURL turns a proxy string into a *url.URL.
//
// An empty string means "no proxy" and yields (nil, nil). Any other value must
// be accepted by url.Parse and must have both a non-empty Scheme and a
// non-empty Host; otherwise an error is returned and the URL is nil.
func parseProxyURL(proxy string) (*url.URL, error) {
	if len(proxy) == 0 {
		return nil, nil
	}
	parsed, err := url.Parse(proxy)
	// Cases are evaluated top to bottom, so parsed is only dereferenced
	// once err is known to be nil.
	switch {
	case err != nil:
		return nil, wrapError(err, "parse proxy url")
	case parsed.Scheme == "":
		return nil, fmt.Errorf("proxy scheme is required: %s", proxy)
	case parsed.Host == "":
		return nil, fmt.Errorf("proxy host is required: %s", proxy)
	default:
		return parsed, nil
	}
}
// mutateTimeout returns a mutation that stores timeout in
// config.Network.Timeout. The value is not validated and the mutation never
// fails.
func mutateTimeout(timeout time.Duration) requestMutation {
	return func(req *Request) error {
		req.config.Network.Timeout = timeout
		return nil
	}
}
// mutateDialTimeout returns a mutation that stores timeout in
// config.Network.DialTimeout. The value is not validated and the mutation
// never fails.
func mutateDialTimeout(timeout time.Duration) requestMutation {
	return func(req *Request) error {
		req.config.Network.DialTimeout = timeout
		return nil
	}
}
// mutateProxy returns a mutation that validates proxy with parseProxyURL and,
// if it is acceptable (including the empty "no proxy" string), stores the raw
// string in config.Network.Proxy. On a validation error the config is left
// untouched and the error is returned.
func mutateProxy(proxy string) requestMutation {
	return func(req *Request) error {
		_, err := parseProxyURL(proxy)
		if err != nil {
			return err
		}
		req.config.Network.Proxy = proxy
		return nil
	}
}
// mutateDialFunc returns a mutation that stores fn in config.Network.DialFunc.
// A nil fn is stored as-is; the mutation never fails.
func mutateDialFunc(fn func(ctx context.Context, network, addr string) (net.Conn, error)) requestMutation {
	return func(req *Request) error {
		req.config.Network.DialFunc = fn
		return nil
	}
}
// mutateTLSConfig returns a mutation that stores the tlsConfig pointer in
// config.TLS.Config. The pointer is stored without cloning, so later changes
// the caller makes to *tlsConfig remain visible. It never fails.
func mutateTLSConfig(tlsConfig *tls.Config) requestMutation {
	return func(req *Request) error {
		req.config.TLS.Config = tlsConfig
		return nil
	}
}
// mutateTLSServerName returns a mutation that stores serverName in
// config.TLS.ServerName. It never fails.
func mutateTLSServerName(serverName string) requestMutation {
	return func(req *Request) error {
		req.config.TLS.ServerName = serverName
		return nil
	}
}
// mutateTraceHooks returns a mutation that stores the hooks pointer (not a
// copy) in the request's traceHooks field. A nil pointer is accepted. It
// never fails.
func mutateTraceHooks(hooks *TraceHooks) requestMutation {
	return func(req *Request) error {
		req.traceHooks = hooks
		return nil
	}
}
// mutateSkipTLSVerify returns a mutation that stores skip in
// config.TLS.SkipVerify. It never fails.
func mutateSkipTLSVerify(skip bool) requestMutation {
	return func(req *Request) error {
		req.config.TLS.SkipVerify = skip
		return nil
	}
}
// mutateCustomIP returns a mutation that replaces config.DNS.CustomIP with a
// clone (via cloneStringSlice) of ips. Validation with validateCustomIPs
// happens when the mutation is applied. If any entry is invalid, the existing
// list is kept and the validation error is returned.
func mutateCustomIP(ips []string) requestMutation {
	return func(req *Request) error {
		err := validateCustomIPs(ips)
		if err == nil {
			req.config.DNS.CustomIP = cloneStringSlice(ips)
		}
		return err
	}
}
// mutateAddCustomIP returns a mutation that appends ip to config.DNS.CustomIP
// after checking it with validateCustomIPs. An invalid ip leaves the list
// unchanged and its validation error is returned.
func mutateAddCustomIP(ip string) requestMutation {
	return func(req *Request) error {
		err := validateCustomIPs([]string{ip})
		if err == nil {
			req.config.DNS.CustomIP = append(req.config.DNS.CustomIP, ip)
		}
		return err
	}
}
// mutateCustomDNS returns a mutation that replaces config.DNS.CustomDNS with a
// clone (via cloneStringSlice) of dnsServers. Validation with
// validateCustomDNS happens when the mutation is applied. If any entry is
// invalid, the existing list is kept and the validation error is returned.
func mutateCustomDNS(dnsServers []string) requestMutation {
	return func(req *Request) error {
		err := validateCustomDNS(dnsServers)
		if err == nil {
			req.config.DNS.CustomDNS = cloneStringSlice(dnsServers)
		}
		return err
	}
}
// mutateAddCustomDNS returns a mutation that appends dns to
// config.DNS.CustomDNS after checking it with validateCustomDNS. An invalid
// entry leaves the list unchanged and its validation error is returned.
func mutateAddCustomDNS(dns string) requestMutation {
	return func(req *Request) error {
		err := validateCustomDNS([]string{dns})
		if err == nil {
			req.config.DNS.CustomDNS = append(req.config.DNS.CustomDNS, dns)
		}
		return err
	}
}
// mutateLookupFunc returns a mutation that stores fn in config.DNS.LookupFunc.
// A nil fn is stored as-is; the mutation never fails.
func mutateLookupFunc(fn func(ctx context.Context, host string) ([]net.IPAddr, error)) requestMutation {
	return func(req *Request) error {
		req.config.DNS.LookupFunc = fn
		return nil
	}
}
// mutateBasicAuth returns a mutation that stores the credentials in
// config.BasicAuth as [username, password]. Empty values are stored as given.
// It never fails.
func mutateBasicAuth(username, password string) requestMutation {
	return func(req *Request) error {
		var creds [2]string
		creds[0], creds[1] = username, password
		req.config.BasicAuth = creds
		return nil
	}
}
// mutateContentLength returns a mutation that stores length in
// config.ContentLength without validation. It never fails.
func mutateContentLength(length int64) requestMutation {
	return func(req *Request) error {
		req.config.ContentLength = length
		return nil
	}
}
// mutateAutoCalcContentLength returns a mutation that stores auto in
// config.AutoCalcContentLength. It is rejected with an error while the
// request is in raw mode (doRaw), leaving the setting unchanged.
func mutateAutoCalcContentLength(auto bool) requestMutation {
	return func(req *Request) error {
		if !req.doRaw {
			req.config.AutoCalcContentLength = auto
			return nil
		}
		return fmt.Errorf("cannot set auto calc content length in raw mode")
	}
}
// mutateTransport returns a mutation that stores transport in config.Transport
// and sets config.CustomTransport to true.
//
// NOTE(review): CustomTransport is set to true even when transport is nil.
// Confirm that the code reading these fields handles a nil custom transport.
func mutateTransport(transport *http.Transport) requestMutation {
	return func(req *Request) error {
		cfg := &req.config
		cfg.Transport = transport
		cfg.CustomTransport = true
		return nil
	}
}
// mutateUploadProgress returns a mutation that stores fn in
// config.UploadProgress. A nil fn is stored as-is; the mutation never fails.
func mutateUploadProgress(fn UploadProgressFunc) requestMutation {
	return func(req *Request) error {
		req.config.UploadProgress = fn
		return nil
	}
}
// mutateAutoFetch returns a mutation that stores auto in the request's
// autoFetch field. It never fails.
func mutateAutoFetch(auto bool) requestMutation {
	return func(req *Request) error {
		req.autoFetch = auto
		return nil
	}
}
// mutateMaxRespBodyBytes returns a mutation that stores maxBytes in
// config.MaxRespBodyBytes. Negative values are rejected with an error and
// leave the setting unchanged. Zero is accepted; its meaning is defined by the
// reader of MaxRespBodyBytes (presumably "no limit" — confirm).
func mutateMaxRespBodyBytes(maxBytes int64) requestMutation {
	return func(req *Request) error {
		if maxBytes >= 0 {
			req.config.MaxRespBodyBytes = maxBytes
			return nil
		}
		return fmt.Errorf("max response body bytes must be >= 0")
	}
}
// mutateContext returns a mutation that binds ctx, after passing it through
// normalizeContext, to the request.
//
// The normalized context is stored in r.ctx. It is also propagated with
// (*http.Request).WithContext, which returns a shallow copy, in two places:
//   - the raw template, when the request is in raw mode and a template exists;
//   - any already-built r.httpReq.
//
// Both pointers are replaced by the context-carrying copies. The mutation never
// fails.
func mutateContext(ctx context.Context) requestMutation {
	return func(r *Request) error {
		// Normalize into a local instead of reassigning the captured ctx
		// parameter. The returned closure may be applied to several requests,
		// possibly from different goroutines. Writing to captured state would
		// make it stateful and racy.
		normalized := normalizeContext(ctx)
		r.ctx = normalized
		if r.doRaw && r.rawTemplate != nil {
			r.rawTemplate = r.rawTemplate.WithContext(normalized)
		}
		if r.httpReq != nil {
			r.httpReq = r.httpReq.WithContext(normalized)
		}
		return nil
	}
}
// mutateRawRequest returns a mutation that switches the request to raw mode,
// backed by the caller-supplied httpReq.
//
// The httpReq pointer (not a copy) becomes both r.httpReq and r.rawTemplate.
// The request's method, normalized context and URL string are copied from it.
// doRaw and rawSourceExternal are set to true. A nil httpReq is rejected with
// an error before any field changes. If httpReq.URL is nil, the request's
// previous url value is left as it was.
func mutateRawRequest(httpReq *http.Request) requestMutation {
	return func(req *Request) error {
		if httpReq == nil {
			return fmt.Errorf("httpReq cannot be nil")
		}
		req.doRaw = true
		req.rawSourceExternal = true
		req.httpReq, req.rawTemplate = httpReq, httpReq
		req.method = httpReq.Method
		req.ctx = normalizeContext(httpReq.Context())
		if u := httpReq.URL; u != nil {
			req.url = u.String()
		}
		return nil
	}
}
// mutateAddQuery returns a mutation that appends value to the query values
// stored under key in config.Queries. Existing values for key are kept.
//
// config.Queries is allocated on demand. It may be nil, for example after
// mutateSetQueries with a nil map if cloneStringMapSlice preserves nil.
// Writing to a nil map panics. The mutation never fails.
func mutateAddQuery(key, value string) requestMutation {
	return func(r *Request) error {
		if r.config.Queries == nil {
			r.config.Queries = make(map[string][]string)
		}
		r.config.Queries[key] = append(r.config.Queries[key], value)
		return nil
	}
}
// mutateSetQuery returns a mutation that replaces all query values stored
// under key in config.Queries with the single value.
//
// config.Queries is allocated on demand when it is nil, because writing to a
// nil map panics. The mutation never fails.
func mutateSetQuery(key, value string) requestMutation {
	return func(r *Request) error {
		if r.config.Queries == nil {
			r.config.Queries = make(map[string][]string)
		}
		r.config.Queries[key] = []string{value}
		return nil
	}
}
// mutateSetQueries returns a mutation that replaces config.Queries wholesale
// with the result of cloneStringMapSlice(queries). The clone is taken when the
// mutation is applied, not when it is created. It never fails.
func mutateSetQueries(queries map[string][]string) requestMutation {
	return func(req *Request) error {
		replacement := cloneStringMapSlice(queries)
		req.config.Queries = replacement
		return nil
	}
}
// mutateAddQueries returns a mutation that, for each key/value pair in queries,
// appends value to the query values stored under key in config.Queries. Map
// iteration order is random, but each key is appended independently, so the
// result does not depend on it.
//
// An empty or nil queries map is a no-op. Otherwise config.Queries is
// allocated on demand when nil, because writing to a nil map panics. The
// mutation never fails.
func mutateAddQueries(queries map[string]string) requestMutation {
	return func(r *Request) error {
		if len(queries) == 0 {
			return nil
		}
		if r.config.Queries == nil {
			r.config.Queries = make(map[string][]string, len(queries))
		}
		for key, value := range queries {
			r.config.Queries[key] = append(r.config.Queries[key], value)
		}
		return nil
	}
}
// mutateDeleteQuery returns a mutation that removes every value stored under
// key in config.Queries. Deleting a missing key, or deleting from a nil map,
// is a no-op. It never fails.
func mutateDeleteQuery(key string) requestMutation {
	return func(req *Request) error {
		queries := req.config.Queries
		delete(queries, key)
		return nil
	}
}
// mutateDeleteQueryValue returns a mutation that removes every occurrence of
// value from the query values stored under key in config.Queries.
//
// The surviving values are written to a freshly allocated slice, so the
// previous slice's backing array is never modified. If no values remain, key
// is deleted entirely. A missing key is a no-op. The mutation never fails.
func mutateDeleteQueryValue(key, value string) requestMutation {
	return func(req *Request) error {
		existing, found := req.config.Queries[key]
		if !found {
			return nil
		}
		kept := make([]string, 0, len(existing))
		for i := range existing {
			if existing[i] == value {
				continue
			}
			kept = append(kept, existing[i])
		}
		if len(kept) > 0 {
			req.config.Queries[key] = kept
		} else {
			delete(req.config.Queries, key)
		}
		return nil
	}
}