starnet/request_prepare.go

315 lines
7.0 KiB
Go
Raw Permalink Normal View History

package starnet
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/http/httptrace"
"net/url"
"strings"
)
// setReplayableRequestBodyBytes attaches data to httpReq as an in-memory body.
//
// Body, ContentLength and GetBody are all set from the same slice, and every
// GetBody call hands out a fresh reader positioned at the start of data.
// The slice is not copied. A nil httpReq is ignored.
func setReplayableRequestBodyBytes(httpReq *http.Request, data []byte) {
	if httpReq == nil {
		return
	}

	// One factory backs both the initial body and every replay.
	open := func() io.ReadCloser {
		return io.NopCloser(bytes.NewReader(data))
	}

	httpReq.ContentLength = int64(len(data))
	httpReq.GetBody = func() (io.ReadCloser, error) {
		return open(), nil
	}
	httpReq.Body = open()
}
// clearSimpleBodyState drops the raw-bytes and reader payloads held by body.
// A nil body is a no-op.
func clearSimpleBodyState(body *BodyConfig) {
	if body != nil {
		body.Bytes, body.Reader = nil, nil
	}
}
// resetFormBodyState replaces body.FormData with a fresh, empty map so the
// field is never left nil. A nil body is a no-op.
func resetFormBodyState(body *BodyConfig) {
	if body != nil {
		body.FormData = map[string][]string{}
	}
}
// resetMultipartBodyState forgets any files attached to body.
// A nil body is a no-op.
func resetMultipartBodyState(body *BodyConfig) {
	if body != nil {
		body.Files = nil
	}
}
// setBytesBodyConfig switches body to bytes mode.
//
// The reader payload, form data and files are reset, and body.Bytes is set to
// the result of cloneBytes(data). A nil body is a no-op.
func setBytesBodyConfig(body *BodyConfig, data []byte) {
	if body == nil {
		return
	}

	// Wipe every other payload first, then install the new one.
	clearSimpleBodyState(body)
	resetFormBodyState(body)
	resetMultipartBodyState(body)

	body.Bytes = cloneBytes(data)
	body.Mode = bodyModeBytes
}
// setReaderBodyConfig switches body to reader mode.
//
// The bytes payload, form data and files are reset, and reader is stored as-is
// (it is neither read nor wrapped here). A nil body is a no-op.
func setReaderBodyConfig(body *BodyConfig, reader io.Reader) {
	if body == nil {
		return
	}

	// Wipe every other payload first, then install the new one.
	clearSimpleBodyState(body)
	resetFormBodyState(body)
	resetMultipartBodyState(body)

	body.Reader = reader
	body.Mode = bodyModeReader
}
// setFormBodyConfig switches body to form mode.
//
// The bytes/reader payloads and files are cleared, and body.FormData is set to
// the result of cloneStringMapSlice(data). A nil body is a no-op.
func setFormBodyConfig(body *BodyConfig, data map[string][]string) {
	if body == nil {
		return
	}

	fields := cloneStringMapSlice(data)

	resetMultipartBodyState(body)
	clearSimpleBodyState(body)

	body.FormData = fields
	body.Mode = bodyModeForm
}
// ensureFormMode makes body ready to receive form fields.
//
// Form and multipart modes are kept as they are; only a nil FormData map is
// initialised. Any other mode is converted to form mode: bytes/reader payloads
// and files are discarded and FormData starts out empty. A nil body is a no-op.
func ensureFormMode(body *BodyConfig) {
	if body == nil {
		return
	}

	switch body.Mode {
	case bodyModeForm, bodyModeMultipart:
		// Already field-capable: keep existing fields.
		if body.FormData == nil {
			body.FormData = map[string][]string{}
		}
	default:
		clearSimpleBodyState(body)
		resetMultipartBodyState(body)
		body.FormData = map[string][]string{}
		body.Mode = bodyModeForm
	}
}
// ensureMultipartMode switches body to multipart mode.
//
// Fields already collected in form or multipart mode are kept. Coming from any
// other mode, the bytes/reader payloads are cleared and FormData starts out
// empty. FormData is never left nil on return. A nil body is a no-op.
func ensureMultipartMode(body *BodyConfig) {
	if body == nil {
		return
	}

	switch body.Mode {
	case bodyModeMultipart:
		// Nothing to convert.
	case bodyModeForm:
		// Promote form to multipart, carrying the fields over.
		body.Mode = bodyModeMultipart
	default:
		clearSimpleBodyState(body)
		body.FormData = map[string][]string{}
		body.Mode = bodyModeMultipart
	}

	if body.FormData == nil {
		body.FormData = map[string][]string{}
	}
}
// snapshotBytesReader returns a copy of the bytes that are still unread in
// reader without moving its read position.
//
// A nil reader yields (nil, nil). io.EOF from the underlying ReadAt is treated
// as success; any other error is returned with a nil slice.
func snapshotBytesReader(reader *bytes.Reader) ([]byte, error) {
	if reader == nil {
		return nil, nil
	}

	remaining := reader.Len()
	// Size is the full length, so Size-Len is the current read offset.
	offset := reader.Size() - int64(remaining)
	buf := make([]byte, remaining)

	switch _, err := reader.ReadAt(buf, offset); err {
	case nil, io.EOF:
		return buf, nil
	default:
		return nil, err
	}
}
// snapshotStringReader returns a copy of the bytes that are still unread in
// reader without moving its read position.
//
// A nil reader yields (nil, nil). io.EOF from the underlying ReadAt is treated
// as success; any other error is returned with a nil slice.
func snapshotStringReader(reader *strings.Reader) ([]byte, error) {
	if reader == nil {
		return nil, nil
	}

	unread := reader.Len()
	// Size is the full length, so Size-Len is the current read offset.
	start := reader.Size() - int64(unread)
	out := make([]byte, unread)

	if _, err := reader.ReadAt(out, start); err != nil && err != io.EOF {
		return nil, err
	}
	return out, nil
}
// applyBody installs the configured request body on r.httpReq.
//
// Any body previously attached to the request is discarded first. What gets
// installed depends on r.config.Body.Mode:
//   - bodyModeReader: for *bytes.Buffer, *bytes.Reader and *strings.Reader the
//     unread bytes are copied (the source is not advanced) and installed as a
//     replayable body. Any other reader is attached as a one-shot stream with
//     ContentLength left at 0 and GetBody left nil. A nil Reader installs
//     nothing.
//   - bodyModeBytes: the configured bytes become a replayable body.
//   - bodyModeMultipart: delegated to applyMultipartBody.
//   - bodyModeForm: FormData is URL-encoded into a replayable body.
//
// Any other mode leaves the request without a body. execCtx is only forwarded
// to applyMultipartBody. A non-nil error comes from snapshotting a reader or
// from applyMultipartBody.
func (r *Request) applyBody(execCtx context.Context) error {
	r.httpReq.Body = nil
	r.httpReq.GetBody = nil
	r.httpReq.ContentLength = 0

	switch r.config.Body.Mode {
	case bodyModeReader:
		if r.config.Body.Reader == nil {
			return nil
		}
		// setReplayableRequestBodyBytes records ContentLength from the
		// snapshot, so the length is not re-derived from the reader. That also
		// keeps typed-nil readers (a nil pointer stored in the interface) from
		// being dereferenced: they are sent as an empty body.
		switch reader := r.config.Body.Reader.(type) {
		case *bytes.Buffer:
			var data []byte
			if reader != nil {
				data = append(data, reader.Bytes()...)
			}
			setReplayableRequestBodyBytes(r.httpReq, data)
		case *bytes.Reader:
			data, err := snapshotBytesReader(reader)
			if err != nil {
				return wrapError(err, "snapshot bytes reader")
			}
			setReplayableRequestBodyBytes(r.httpReq, data)
		case *strings.Reader:
			data, err := snapshotStringReader(reader)
			if err != nil {
				return wrapError(err, "snapshot strings reader")
			}
			setReplayableRequestBodyBytes(r.httpReq, data)
		default:
			// Unknown reader: stream it once; its length is not known here.
			r.httpReq.Body = io.NopCloser(r.config.Body.Reader)
		}
		return nil

	case bodyModeBytes:
		setReplayableRequestBodyBytes(r.httpReq, r.config.Body.Bytes)
		return nil

	case bodyModeMultipart:
		return r.applyMultipartBody(execCtx)

	case bodyModeForm:
		// url.Values.Encode sorts by key, so the output is deterministic.
		encoded := url.Values(r.config.Body.FormData).Encode()
		setReplayableRequestBodyBytes(r.httpReq, []byte(encoded))
		return nil
	}

	return nil
}
// prepare applies the request configuration to the underlying *http.Request.
//
// It runs at most once per Request: when r.applied is already set it returns
// nil immediately. Otherwise it
//   - builds the execution context from r.ctx (or context.Background), passing
//     it through injectRequestConfig, the optional trace hooks and, when
//     Network.Timeout is positive, context.WithTimeout;
//   - resolves the HTTP client via buildHTTPClient when none is set;
//   - unless doRaw is set, copies queries, headers (Host-like keys are
//     skipped), cookies, basic auth and the body onto the request, then applies
//     the ContentLength / AutoCalcContentLength overrides and calls
//     syncRequestHost.
//
// On success the context, trace state and cancel func are stored on r and the
// request is rebound to the new context. On failure the timeout context, if
// one was created, is cancelled before returning.
func (r *Request) prepare() (err error) {
	if r.applied {
		return nil
	}
	if r.httpReq == nil {
		return fmt.Errorf("http request is nil")
	}

	execCtx := r.ctx
	if execCtx == nil {
		execCtx = context.Background()
	}

	// Only https requests carry a default TLS server name.
	defaultTLSServerName := ""
	if r.httpReq.URL != nil && r.httpReq.URL.Scheme == "https" {
		defaultTLSServerName = r.httpReq.URL.Hostname()
	}
	execCtx = injectRequestConfig(execCtx, r.config, defaultTLSServerName)

	var state *traceState
	if r.traceHooks != nil {
		state = newTraceState(r.traceHooks)
		execCtx = withTraceState(execCtx, state)
		if clientTrace := state.clientTrace(); clientTrace != nil {
			execCtx = httptrace.WithClientTrace(execCtx, clientTrace)
		}
	}

	var cancel context.CancelFunc
	if r.config.Network.Timeout > 0 {
		execCtx, cancel = context.WithTimeout(execCtx, r.config.Network.Timeout)
	}
	// On success the cancel func is handed to r.cancel; on any failure below
	// nobody else would release the timer, so release it here.
	defer func() {
		if err != nil && cancel != nil {
			cancel()
		}
	}()

	if r.httpClient == nil {
		r.httpClient, err = r.buildHTTPClient()
		if err != nil {
			return err
		}
	}

	if !r.doRaw {
		if len(r.config.Queries) > 0 {
			// URL may be nil (see the guard above); fail cleanly instead of
			// dereferencing it.
			if r.httpReq.URL == nil {
				return fmt.Errorf("http request url is nil")
			}
			query := r.httpReq.URL.Query()
			for key, values := range r.config.Queries {
				for _, value := range values {
					query.Add(key, value)
				}
			}
			r.httpReq.URL.RawQuery = query.Encode()
		}

		for key, values := range r.config.Headers {
			if isHostHeaderKey(key) {
				continue
			}
			for _, value := range values {
				r.httpReq.Header.Add(key, value)
			}
		}

		for _, cookie := range r.config.Cookies {
			r.httpReq.AddCookie(cookie)
		}

		if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
			r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
		}

		if err = r.applyBody(execCtx); err != nil {
			return err
		}

		// An explicit positive length overrides whatever applyBody set; a
		// negative one resets the length to 0.
		if r.config.ContentLength > 0 {
			r.httpReq.ContentLength = r.config.ContentLength
		} else if r.config.ContentLength < 0 {
			r.httpReq.ContentLength = 0
		}

		// Buffer the whole body so its exact length is known. This runs last
		// and therefore wins over the explicit ContentLength above.
		if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
			source := r.httpReq.Body
			data, readErr := io.ReadAll(source)
			// The drained source is replaced below, so release it here. Its
			// Close error is ignored: the read result decides success.
			_ = source.Close()
			if readErr != nil {
				return wrapError(readErr, "read body for content length")
			}
			setReplayableRequestBodyBytes(r.httpReq, data)
		}

		r.syncRequestHost()
	}

	r.execCtx = execCtx
	r.traceState = state
	r.cancel = cancel
	r.httpReq = r.httpReq.WithContext(r.execCtx)
	r.applied = true
	return nil
}
// buildHTTPClient picks the *http.Client used to send this request.
//
// Precedence: the client owned by r.client when one is attached; otherwise a
// new client wrapping r.config.Transport in a Transport when CustomTransport is
// enabled and a transport is configured; otherwise DefaultHTTPClient(). The
// custom-transport client has no client-level timeout. The returned error is
// always nil in this implementation.
func (r *Request) buildHTTPClient() (*http.Client, error) {
	switch {
	case r.client != nil:
		return r.client.HTTPClient(), nil
	case r.config.CustomTransport && r.config.Transport != nil:
		wrapped := &Transport{base: r.config.Transport}
		return &http.Client{Transport: wrapped}, nil
	default:
		return DefaultHTTPClient(), nil
	}
}