package starnet import ( "bytes" "context" "fmt" "io" "net/http" "net/http/httptrace" "net/url" "strings" ) func setReplayableRequestBodyBytes(httpReq *http.Request, data []byte) { if httpReq == nil { return } httpReq.Body = io.NopCloser(bytes.NewReader(data)) httpReq.ContentLength = int64(len(data)) httpReq.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(data)), nil } } func clearSimpleBodyState(body *BodyConfig) { if body == nil { return } body.Bytes = nil body.Reader = nil } func resetFormBodyState(body *BodyConfig) { if body == nil { return } body.FormData = make(map[string][]string) } func resetMultipartBodyState(body *BodyConfig) { if body == nil { return } body.Files = nil } func setBytesBodyConfig(body *BodyConfig, data []byte) { if body == nil { return } body.Mode = bodyModeBytes body.Bytes = cloneBytes(data) body.Reader = nil resetFormBodyState(body) resetMultipartBodyState(body) } func setReaderBodyConfig(body *BodyConfig, reader io.Reader) { if body == nil { return } body.Mode = bodyModeReader body.Reader = reader body.Bytes = nil resetFormBodyState(body) resetMultipartBodyState(body) } func setFormBodyConfig(body *BodyConfig, data map[string][]string) { if body == nil { return } body.Mode = bodyModeForm clearSimpleBodyState(body) resetMultipartBodyState(body) body.FormData = cloneStringMapSlice(data) } func ensureFormMode(body *BodyConfig) { if body == nil { return } if body.Mode == bodyModeForm || body.Mode == bodyModeMultipart { if body.FormData == nil { body.FormData = make(map[string][]string) } return } clearSimpleBodyState(body) resetMultipartBodyState(body) body.FormData = make(map[string][]string) body.Mode = bodyModeForm } func ensureMultipartMode(body *BodyConfig) { if body == nil { return } if body.Mode == bodyModeMultipart { if body.FormData == nil { body.FormData = make(map[string][]string) } return } if body.Mode != bodyModeForm { clearSimpleBodyState(body) body.FormData = make(map[string][]string) } body.Mode = bodyModeMultipart if body.FormData == nil { body.FormData = make(map[string][]string) } } func snapshotBytesReader(reader *bytes.Reader) ([]byte, error) { if reader == nil { return nil, nil } data := make([]byte, reader.Len()) _, err := reader.ReadAt(data, reader.Size()-int64(reader.Len())) if err != nil && err != io.EOF { return nil, err } return data, nil } func snapshotStringReader(reader *strings.Reader) ([]byte, error) { if reader == nil { return nil, nil } data := make([]byte, reader.Len()) _, err := reader.ReadAt(data, reader.Size()-int64(reader.Len())) if err != nil && err != io.EOF { return nil, err } return data, nil } // applyBody 应用请求体 func (r *Request) applyBody(execCtx context.Context) error { r.httpReq.Body = nil r.httpReq.GetBody = nil r.httpReq.ContentLength = 0 switch r.config.Body.Mode { case bodyModeReader: if r.config.Body.Reader == nil { return nil } switch reader := r.config.Body.Reader.(type) { case *bytes.Buffer: setReplayableRequestBodyBytes(r.httpReq, append([]byte(nil), reader.Bytes()...)) case *bytes.Reader: data, err := snapshotBytesReader(reader) if err != nil { return wrapError(err, "snapshot bytes reader") } setReplayableRequestBodyBytes(r.httpReq, data) case *strings.Reader: data, err := snapshotStringReader(reader) if err != nil { return wrapError(err, "snapshot strings reader") } setReplayableRequestBodyBytes(r.httpReq, data) default: r.httpReq.Body = io.NopCloser(r.config.Body.Reader) } switch reader := r.config.Body.Reader.(type) { case *bytes.Buffer: r.httpReq.ContentLength = int64(reader.Len()) case *bytes.Reader: r.httpReq.ContentLength = int64(reader.Len()) case *strings.Reader: r.httpReq.ContentLength = int64(reader.Len()) } return nil case bodyModeBytes: setReplayableRequestBodyBytes(r.httpReq, r.config.Body.Bytes) return nil case bodyModeMultipart: return r.applyMultipartBody(execCtx) case bodyModeForm: values := url.Values{} for key, items := range r.config.Body.FormData { for _, value := range items { values.Add(key, value) } } encoded := values.Encode() setReplayableRequestBodyBytes(r.httpReq, []byte(encoded)) return nil } return nil } // prepare 准备请求(应用配置) func (r *Request) prepare() (err error) { if r.applied { return nil } if r.httpReq == nil { return fmt.Errorf("http request is nil") } execCtx := r.ctx if execCtx == nil { execCtx = context.Background() } defaultTLSServerName := "" if r.httpReq.URL != nil && r.httpReq.URL.Scheme == "https" { defaultTLSServerName = r.httpReq.URL.Hostname() } execCtx = injectRequestConfig(execCtx, r.config, defaultTLSServerName) var traceState *traceState if r.traceHooks != nil { traceState = newTraceState(r.traceHooks) execCtx = withTraceState(execCtx, traceState) if clientTrace := traceState.clientTrace(); clientTrace != nil { execCtx = httptrace.WithClientTrace(execCtx, clientTrace) } } var cancel context.CancelFunc if r.config.Network.Timeout > 0 { execCtx, cancel = context.WithTimeout(execCtx, r.config.Network.Timeout) } defer func() { if err != nil && cancel != nil { cancel() } }() if r.httpClient == nil { r.httpClient, err = r.buildHTTPClient() if err != nil { return err } } if !r.doRaw { if len(r.config.Queries) > 0 { query := r.httpReq.URL.Query() for key, values := range r.config.Queries { for _, value := range values { query.Add(key, value) } } r.httpReq.URL.RawQuery = query.Encode() } for key, values := range r.config.Headers { if isHostHeaderKey(key) { continue } for _, value := range values { r.httpReq.Header.Add(key, value) } } for _, cookie := range r.config.Cookies { r.httpReq.AddCookie(cookie) } if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" { r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1]) } if err := r.applyBody(execCtx); err != nil { return err } if r.config.ContentLength > 0 { r.httpReq.ContentLength = r.config.ContentLength } else if r.config.ContentLength < 0 { r.httpReq.ContentLength = 0 } if r.config.AutoCalcContentLength && r.httpReq.Body != nil { data, err := io.ReadAll(r.httpReq.Body) if err != nil { return wrapError(err, "read body for content length") } setReplayableRequestBodyBytes(r.httpReq, data) } r.syncRequestHost() } r.execCtx = execCtx r.traceState = traceState r.cancel = cancel r.httpReq = r.httpReq.WithContext(r.execCtx) r.applied = true return nil } // buildHTTPClient 构建 HTTP Client func (r *Request) buildHTTPClient() (*http.Client, error) { if r.client != nil { return r.client.HTTPClient(), nil } if r.config.CustomTransport && r.config.Transport != nil { return &http.Client{ Transport: &Transport{base: r.config.Transport}, Timeout: 0, }, nil } return DefaultHTTPClient(), nil }