315 lines
7.0 KiB
Go
315 lines
7.0 KiB
Go
|
|
package starnet
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bytes"
|
||
|
|
"context"
|
||
|
|
"fmt"
|
||
|
|
"io"
|
||
|
|
"net/http"
|
||
|
|
"net/http/httptrace"
|
||
|
|
"net/url"
|
||
|
|
"strings"
|
||
|
|
)
|
||
|
|
|
||
|
|
// setReplayableRequestBodyBytes installs data as the body of httpReq in a way
// that can be replayed: Body, ContentLength and GetBody are all set so the
// payload can be re-read (GetBody returns a fresh reader on every call).
//
// A nil httpReq is ignored. data is not copied; the caller must not mutate it
// while the request is in use.
//
// An empty payload is represented by http.NoBody, mirroring http.NewRequest.
// A non-nil, non-NoBody body combined with ContentLength 0 is treated by
// net/http as "length unknown", which makes POST/PUT/PATCH requests go out
// with "Transfer-Encoding: chunked" instead of "Content-Length: 0".
func setReplayableRequestBodyBytes(httpReq *http.Request, data []byte) {
	if httpReq == nil {
		return
	}

	httpReq.ContentLength = int64(len(data))

	if len(data) == 0 {
		httpReq.Body = http.NoBody
		httpReq.GetBody = func() (io.ReadCloser, error) {
			return http.NoBody, nil
		}
		return
	}

	httpReq.Body = io.NopCloser(bytes.NewReader(data))
	httpReq.GetBody = func() (io.ReadCloser, error) {
		return io.NopCloser(bytes.NewReader(data)), nil
	}
}
|
||
|
|
|
||
|
|
func clearSimpleBodyState(body *BodyConfig) {
|
||
|
|
if body == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
body.Bytes = nil
|
||
|
|
body.Reader = nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func resetFormBodyState(body *BodyConfig) {
|
||
|
|
if body == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
body.FormData = make(map[string][]string)
|
||
|
|
}
|
||
|
|
|
||
|
|
func resetMultipartBodyState(body *BodyConfig) {
|
||
|
|
if body == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
body.Files = nil
|
||
|
|
}
|
||
|
|
|
||
|
|
func setBytesBodyConfig(body *BodyConfig, data []byte) {
|
||
|
|
if body == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
body.Mode = bodyModeBytes
|
||
|
|
body.Bytes = cloneBytes(data)
|
||
|
|
body.Reader = nil
|
||
|
|
resetFormBodyState(body)
|
||
|
|
resetMultipartBodyState(body)
|
||
|
|
}
|
||
|
|
|
||
|
|
func setReaderBodyConfig(body *BodyConfig, reader io.Reader) {
|
||
|
|
if body == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
body.Mode = bodyModeReader
|
||
|
|
body.Reader = reader
|
||
|
|
body.Bytes = nil
|
||
|
|
resetFormBodyState(body)
|
||
|
|
resetMultipartBodyState(body)
|
||
|
|
}
|
||
|
|
|
||
|
|
func setFormBodyConfig(body *BodyConfig, data map[string][]string) {
|
||
|
|
if body == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
body.Mode = bodyModeForm
|
||
|
|
clearSimpleBodyState(body)
|
||
|
|
resetMultipartBodyState(body)
|
||
|
|
body.FormData = cloneStringMapSlice(data)
|
||
|
|
}
|
||
|
|
|
||
|
|
func ensureFormMode(body *BodyConfig) {
|
||
|
|
if body == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
if body.Mode == bodyModeForm || body.Mode == bodyModeMultipart {
|
||
|
|
if body.FormData == nil {
|
||
|
|
body.FormData = make(map[string][]string)
|
||
|
|
}
|
||
|
|
return
|
||
|
|
}
|
||
|
|
clearSimpleBodyState(body)
|
||
|
|
resetMultipartBodyState(body)
|
||
|
|
body.FormData = make(map[string][]string)
|
||
|
|
body.Mode = bodyModeForm
|
||
|
|
}
|
||
|
|
|
||
|
|
func ensureMultipartMode(body *BodyConfig) {
|
||
|
|
if body == nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
if body.Mode == bodyModeMultipart {
|
||
|
|
if body.FormData == nil {
|
||
|
|
body.FormData = make(map[string][]string)
|
||
|
|
}
|
||
|
|
return
|
||
|
|
}
|
||
|
|
if body.Mode != bodyModeForm {
|
||
|
|
clearSimpleBodyState(body)
|
||
|
|
body.FormData = make(map[string][]string)
|
||
|
|
}
|
||
|
|
body.Mode = bodyModeMultipart
|
||
|
|
if body.FormData == nil {
|
||
|
|
body.FormData = make(map[string][]string)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// snapshotBytesReader returns a copy of the unread portion of reader without
// advancing its read position. A nil reader yields (nil, nil); a fully
// consumed reader yields an empty, non-nil slice.
func snapshotBytesReader(reader *bytes.Reader) ([]byte, error) {
	if reader == nil {
		return nil, nil
	}

	remaining := reader.Len()
	offset := reader.Size() - int64(remaining)
	buf := make([]byte, remaining)

	// ReadAt does not move the reader's cursor; io.EOF only signals that the
	// end of the underlying data was reached.
	switch _, err := reader.ReadAt(buf, offset); err {
	case nil, io.EOF:
		return buf, nil
	default:
		return nil, err
	}
}
|
||
|
|
|
||
|
|
// snapshotStringReader returns a copy of the unread portion of reader without
// advancing its read position. A nil reader yields (nil, nil); a fully
// consumed reader yields an empty, non-nil slice.
func snapshotStringReader(reader *strings.Reader) ([]byte, error) {
	if reader == nil {
		return nil, nil
	}

	remaining := reader.Len()
	offset := reader.Size() - int64(remaining)
	buf := make([]byte, remaining)

	// ReadAt does not move the reader's cursor; io.EOF only signals that the
	// end of the underlying string was reached.
	switch _, err := reader.ReadAt(buf, offset); err {
	case nil, io.EOF:
		return buf, nil
	default:
		return nil, err
	}
}
|
||
|
|
|
||
|
|
// applyBody 应用请求体
|
||
|
|
func (r *Request) applyBody(execCtx context.Context) error {
|
||
|
|
r.httpReq.Body = nil
|
||
|
|
r.httpReq.GetBody = nil
|
||
|
|
r.httpReq.ContentLength = 0
|
||
|
|
|
||
|
|
switch r.config.Body.Mode {
|
||
|
|
case bodyModeReader:
|
||
|
|
if r.config.Body.Reader == nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
switch reader := r.config.Body.Reader.(type) {
|
||
|
|
case *bytes.Buffer:
|
||
|
|
setReplayableRequestBodyBytes(r.httpReq, append([]byte(nil), reader.Bytes()...))
|
||
|
|
case *bytes.Reader:
|
||
|
|
data, err := snapshotBytesReader(reader)
|
||
|
|
if err != nil {
|
||
|
|
return wrapError(err, "snapshot bytes reader")
|
||
|
|
}
|
||
|
|
setReplayableRequestBodyBytes(r.httpReq, data)
|
||
|
|
case *strings.Reader:
|
||
|
|
data, err := snapshotStringReader(reader)
|
||
|
|
if err != nil {
|
||
|
|
return wrapError(err, "snapshot strings reader")
|
||
|
|
}
|
||
|
|
setReplayableRequestBodyBytes(r.httpReq, data)
|
||
|
|
default:
|
||
|
|
r.httpReq.Body = io.NopCloser(r.config.Body.Reader)
|
||
|
|
}
|
||
|
|
switch reader := r.config.Body.Reader.(type) {
|
||
|
|
case *bytes.Buffer:
|
||
|
|
r.httpReq.ContentLength = int64(reader.Len())
|
||
|
|
case *bytes.Reader:
|
||
|
|
r.httpReq.ContentLength = int64(reader.Len())
|
||
|
|
case *strings.Reader:
|
||
|
|
r.httpReq.ContentLength = int64(reader.Len())
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
case bodyModeBytes:
|
||
|
|
setReplayableRequestBodyBytes(r.httpReq, r.config.Body.Bytes)
|
||
|
|
return nil
|
||
|
|
case bodyModeMultipart:
|
||
|
|
return r.applyMultipartBody(execCtx)
|
||
|
|
case bodyModeForm:
|
||
|
|
values := url.Values{}
|
||
|
|
for key, items := range r.config.Body.FormData {
|
||
|
|
for _, value := range items {
|
||
|
|
values.Add(key, value)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
encoded := values.Encode()
|
||
|
|
setReplayableRequestBodyBytes(r.httpReq, []byte(encoded))
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// prepare 准备请求(应用配置)
|
||
|
|
func (r *Request) prepare() (err error) {
|
||
|
|
if r.applied {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
if r.httpReq == nil {
|
||
|
|
return fmt.Errorf("http request is nil")
|
||
|
|
}
|
||
|
|
|
||
|
|
execCtx := r.ctx
|
||
|
|
if execCtx == nil {
|
||
|
|
execCtx = context.Background()
|
||
|
|
}
|
||
|
|
defaultTLSServerName := ""
|
||
|
|
if r.httpReq.URL != nil && r.httpReq.URL.Scheme == "https" {
|
||
|
|
defaultTLSServerName = r.httpReq.URL.Hostname()
|
||
|
|
}
|
||
|
|
execCtx = injectRequestConfig(execCtx, r.config, defaultTLSServerName)
|
||
|
|
|
||
|
|
var traceState *traceState
|
||
|
|
if r.traceHooks != nil {
|
||
|
|
traceState = newTraceState(r.traceHooks)
|
||
|
|
execCtx = withTraceState(execCtx, traceState)
|
||
|
|
if clientTrace := traceState.clientTrace(); clientTrace != nil {
|
||
|
|
execCtx = httptrace.WithClientTrace(execCtx, clientTrace)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
var cancel context.CancelFunc
|
||
|
|
if r.config.Network.Timeout > 0 {
|
||
|
|
execCtx, cancel = context.WithTimeout(execCtx, r.config.Network.Timeout)
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
if err != nil && cancel != nil {
|
||
|
|
cancel()
|
||
|
|
}
|
||
|
|
}()
|
||
|
|
|
||
|
|
if r.httpClient == nil {
|
||
|
|
r.httpClient, err = r.buildHTTPClient()
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
if !r.doRaw {
|
||
|
|
if len(r.config.Queries) > 0 {
|
||
|
|
query := r.httpReq.URL.Query()
|
||
|
|
for key, values := range r.config.Queries {
|
||
|
|
for _, value := range values {
|
||
|
|
query.Add(key, value)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
r.httpReq.URL.RawQuery = query.Encode()
|
||
|
|
}
|
||
|
|
|
||
|
|
for key, values := range r.config.Headers {
|
||
|
|
if isHostHeaderKey(key) {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
for _, value := range values {
|
||
|
|
r.httpReq.Header.Add(key, value)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
for _, cookie := range r.config.Cookies {
|
||
|
|
r.httpReq.AddCookie(cookie)
|
||
|
|
}
|
||
|
|
|
||
|
|
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
|
||
|
|
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
|
||
|
|
}
|
||
|
|
|
||
|
|
if err := r.applyBody(execCtx); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
if r.config.ContentLength > 0 {
|
||
|
|
r.httpReq.ContentLength = r.config.ContentLength
|
||
|
|
} else if r.config.ContentLength < 0 {
|
||
|
|
r.httpReq.ContentLength = 0
|
||
|
|
}
|
||
|
|
|
||
|
|
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
|
||
|
|
data, err := io.ReadAll(r.httpReq.Body)
|
||
|
|
if err != nil {
|
||
|
|
return wrapError(err, "read body for content length")
|
||
|
|
}
|
||
|
|
setReplayableRequestBodyBytes(r.httpReq, data)
|
||
|
|
}
|
||
|
|
|
||
|
|
r.syncRequestHost()
|
||
|
|
}
|
||
|
|
|
||
|
|
r.execCtx = execCtx
|
||
|
|
r.traceState = traceState
|
||
|
|
r.cancel = cancel
|
||
|
|
r.httpReq = r.httpReq.WithContext(r.execCtx)
|
||
|
|
r.applied = true
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// buildHTTPClient 构建 HTTP Client
|
||
|
|
func (r *Request) buildHTTPClient() (*http.Client, error) {
|
||
|
|
if r.client != nil {
|
||
|
|
return r.client.HTTPClient(), nil
|
||
|
|
}
|
||
|
|
|
||
|
|
if r.config.CustomTransport && r.config.Transport != nil {
|
||
|
|
return &http.Client{
|
||
|
|
Transport: &Transport{base: r.config.Transport},
|
||
|
|
Timeout: 0,
|
||
|
|
}, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
return DefaultHTTPClient(), nil
|
||
|
|
}
|