starnet/request_body.go
2026-03-08 20:19:40 +08:00

464 lines
9.7 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package starnet
import (
"bytes"
"encoding/json"
"io"
"mime/multipart"
"net/http"
"net/url"
"os"
"strings"
)
// SetBody uses body as the raw request payload and discards any reader set
// earlier via SetBodyReader. It does nothing when the request already carries
// an error or is in raw mode.
func (r *Request) SetBody(body []byte) *Request {
	if r.err == nil && !r.doRaw {
		r.config.Body.Reader = nil
		r.config.Body.Bytes = body
	}
	return r
}
// SetBodyReader streams the request payload from reader and discards any byte
// body set earlier via SetBody. The reader is consumed when the body is sent;
// for *bytes.Buffer, *bytes.Reader and *strings.Reader the Content-Length is
// derived from it. It does nothing on a prior error or in raw mode.
func (r *Request) SetBodyReader(reader io.Reader) *Request {
	if r.err == nil && !r.doRaw {
		r.config.Body.Bytes = nil
		r.config.Body.Reader = reader
	}
	return r
}
// SetBodyString is SetBody for a string payload.
func (r *Request) SetBodyString(body string) *Request {
	raw := []byte(body)
	return r.SetBody(raw)
}
// SetJSON marshals v with encoding/json, sets the Content-Type to
// ContentTypeJSON and uses the encoded bytes as the body. A marshal failure is
// recorded as the request error ("marshal json").
func (r *Request) SetJSON(v interface{}) *Request {
	if r.err != nil {
		return r
	}
	payload, marshalErr := json.Marshal(v)
	if marshalErr != nil {
		r.err = wrapError(marshalErr, "marshal json")
		return r
	}
	return r.SetContentType(ContentTypeJSON).SetBody(payload)
}
// SetFormData replaces all form fields with data. The map is stored as-is (not
// copied), so later AddFormData/AddFormDataMap calls write into it. The fields
// are sent urlencoded, or as extra multipart fields when files are attached.
// It does nothing on a prior error or in raw mode.
func (r *Request) SetFormData(data map[string][]string) *Request {
	if r.err == nil && !r.doRaw {
		r.config.Body.FormData = data
	}
	return r
}
// AddFormData appends value to the form field key and keeps any values already
// present. It does nothing on a prior error or in raw mode.
func (r *Request) AddFormData(key, value string) *Request {
	if r.err != nil || r.doRaw {
		return r
	}
	// The form map may be nil (for example after SetFormData(nil)); assigning
	// into a nil map panics, so allocate it on first use.
	if r.config.Body.FormData == nil {
		r.config.Body.FormData = make(map[string][]string)
	}
	r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
	return r
}
// AddFormDataMap appends one value per key from data to the form fields and
// keeps existing values. It does nothing on a prior error or in raw mode.
func (r *Request) AddFormDataMap(data map[string]string) *Request {
	if r.err != nil || r.doRaw {
		return r
	}
	// Assigning into a nil map panics; allocate lazily (e.g. after
	// SetFormData(nil)). Only allocate when there is something to add.
	if r.config.Body.FormData == nil && len(data) > 0 {
		r.config.Body.FormData = make(map[string][]string, len(data))
	}
	for k, v := range data {
		r.config.Body.FormData[k] = append(r.config.Body.FormData[k], v)
	}
	return r
}
// AddFile attaches the local file at filePath as a multipart/form-data file
// part under the form field formName. The part is named after the file's base
// name and typed ContentTypeOctetStream. The path must exist and must not be a
// directory; otherwise the request error is set (wrapping ErrFileNotFound).
// The file itself is opened only when the body is written.
func (r *Request) AddFile(formName, filePath string) *Request {
	return r.AddFileWithType(formName, filePath, ContentTypeOctetStream)
}

// AddFileWithName is like AddFile but sends fileName as the part's file name
// instead of the base name of filePath.
func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request {
	stat, ok := r.statUploadFile(filePath)
	if !ok {
		return r
	}
	r.config.Body.Files = append(r.config.Body.Files, RequestFile{
		FormName: formName,
		FileName: fileName,
		FilePath: filePath,
		FileSize: stat.Size(),
		FileType: ContentTypeOctetStream,
	})
	return r
}

// AddFileWithType is like AddFile but sends fileType as the part's MIME type.
func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request {
	stat, ok := r.statUploadFile(filePath)
	if !ok {
		return r
	}
	r.config.Body.Files = append(r.config.Body.Files, RequestFile{
		FormName: formName,
		FileName: stat.Name(),
		FilePath: filePath,
		FileSize: stat.Size(),
		FileType: fileType,
	})
	return r
}

// statUploadFile stats filePath for an upload. It reports ok == false without
// touching r when the request already carries an error, and records an error
// wrapping ErrFileNotFound when the path cannot be stat'ed or is a directory.
func (r *Request) statUploadFile(filePath string) (os.FileInfo, bool) {
	if r.err != nil {
		return nil, false
	}
	info, err := os.Stat(filePath)
	if err != nil {
		r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
		return nil, false
	}
	// A directory passes os.Stat but would only fail later, while the body is
	// being streamed, with a far less helpful error.
	if info.IsDir() {
		r.err = wrapError(ErrFileNotFound, "file: %s is a directory", filePath)
		return nil, false
	}
	return info, true
}
// AddFileStream attaches reader as a multipart file part named fileName under
// the form field formName, typed ContentTypeOctetStream. size is stored as the
// part's FileSize, which is used when reporting upload progress. A nil reader
// sets the request error to ErrNilReader.
func (r *Request) AddFileStream(formName, fileName string, size int64, reader io.Reader) *Request {
	return r.AddFileStreamWithType(formName, fileName, ContentTypeOctetStream, size, reader)
}

// AddFileStreamWithType is like AddFileStream but sends fileType as the
// part's MIME type.
func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, size int64, reader io.Reader) *Request {
	switch {
	case r.err != nil:
		return r
	case reader == nil:
		r.err = ErrNilReader
		return r
	}
	entry := RequestFile{
		FormName: formName,
		FileName: fileName,
		FileData: reader,
		FileSize: size,
		FileType: fileType,
	}
	r.config.Body.Files = append(r.config.Body.Files, entry)
	return r
}
// applyBody installs the configured payload on r.httpReq.
//
// Exactly one source is used, in priority order:
//  1. Body.Reader: streamed as-is; ContentLength is filled in only for
//     *bytes.Buffer, *bytes.Reader and *strings.Reader.
//  2. Body.Bytes, when non-empty.
//  3. Body.Files: multipart/form-data via applyMultipartBody (FormData is sent
//     as additional fields).
//  4. Body.FormData: application/x-www-form-urlencoded. The Content-Type is set
//     only if the request does not already have one.
//
// For in-memory payloads GetBody is also set, as http.NewRequest does. That
// lets net/http re-send the body when following 307/308 redirects or retrying
// on a new connection. With no source configured the request is left as-is.
func (r *Request) applyBody() error {
	// useBytes sends data as the body and lets net/http replay it on demand.
	useBytes := func(data []byte) {
		r.httpReq.Body = io.NopCloser(bytes.NewReader(data))
		r.httpReq.ContentLength = int64(len(data))
		r.httpReq.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(bytes.NewReader(data)), nil
		}
	}

	// 1. Caller-supplied reader.
	if reader := r.config.Body.Reader; reader != nil {
		r.httpReq.Body = io.NopCloser(reader)
		// An arbitrary reader cannot be rewound; drop any stale GetBody so a
		// replay can never resend a different payload.
		r.httpReq.GetBody = nil
		switch v := reader.(type) {
		case *bytes.Buffer:
			r.httpReq.ContentLength = int64(v.Len())
			snapshot := v.Bytes()
			r.httpReq.GetBody = func() (io.ReadCloser, error) {
				return io.NopCloser(bytes.NewReader(snapshot)), nil
			}
		case *bytes.Reader:
			r.httpReq.ContentLength = int64(v.Len())
			snapshot := *v
			r.httpReq.GetBody = func() (io.ReadCloser, error) {
				rd := snapshot
				return io.NopCloser(&rd), nil
			}
		case *strings.Reader:
			r.httpReq.ContentLength = int64(v.Len())
			snapshot := *v
			r.httpReq.GetBody = func() (io.ReadCloser, error) {
				rd := snapshot
				return io.NopCloser(&rd), nil
			}
		}
		return nil
	}

	// 2. Raw bytes.
	if len(r.config.Body.Bytes) > 0 {
		useBytes(r.config.Body.Bytes)
		return nil
	}

	// 3. Files: multipart/form-data.
	if len(r.config.Body.Files) > 0 {
		return r.applyMultipartBody()
	}

	// 4. FormData: application/x-www-form-urlencoded.
	if len(r.config.Body.FormData) > 0 {
		values := url.Values{}
		for k, vs := range r.config.Body.FormData {
			for _, v := range vs {
				values.Add(k, v)
			}
		}
		useBytes([]byte(values.Encode()))
		// Servers (e.g. net/http's ParseForm) only decode a urlencoded body when
		// it is labeled as such; keep any Content-Type the caller chose.
		if r.httpReq.Header.Get("Content-Type") == "" {
			r.httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		}
		return nil
	}

	return nil
}
// applyMultipartBody streams Body.FormData and Body.Files as a
// multipart/form-data body.
//
// The Content-Type header, including the generated boundary, is set right
// away. The body is the read side of an io.Pipe fed by a goroutine, so the
// payload is not assembled up front and its length is not known in advance.
// Any failure is delivered to whoever reads the body, through the pipe. This
// includes errors from writing the closing boundary in writer.Close.
//
// NOTE(review): the goroutine blocks until the pipe's read side is drained or
// closed. If the request is never sent and its body never closed, it leaks.
func (r *Request) applyMultipartBody() error {
	pr, pw := io.Pipe()
	writer := multipart.NewWriter(pw)
	r.httpReq.Header.Set("Content-Type", writer.FormDataContentType())
	r.httpReq.Body = pr

	// writeAll emits every form field and file, then the closing boundary.
	writeAll := func() error {
		for k, vs := range r.config.Body.FormData {
			for _, v := range vs {
				if err := writer.WriteField(k, v); err != nil {
					return wrapError(err, "write form field")
				}
			}
		}
		for _, file := range r.config.Body.Files {
			if err := r.writeFile(writer, file); err != nil {
				return err
			}
		}
		// Without the trailing boundary the server sees a truncated body, so
		// this error must not be dropped.
		if err := writer.Close(); err != nil {
			return wrapError(err, "close multipart writer")
		}
		return nil
	}

	go func() {
		// CloseWithError(nil) behaves like Close: the reader then sees io.EOF.
		pw.CloseWithError(writeAll())
	}()
	return nil
}
// writeFile streams one upload into writer as a multipart file part.
//
// The part's Content-Disposition carries file.FormName and file.FileName. Its
// Content-Type is file.FileType, or application/octet-stream when that is
// empty. Data comes from file.FileData when set. Otherwise it comes from the
// file at file.FilePath, which is opened here and closed on return.
// ErrNilReader is returned when neither source is present. When an upload
// progress callback is configured, the copy goes through copyWithProgress.
func (r *Request) writeFile(writer *multipart.Writer, file RequestFile) error {
	// Resolve the data source first, so a missing file fails before any part
	// header is emitted.
	var reader io.Reader
	switch {
	case file.FileData != nil:
		reader = file.FileData
	case file.FilePath != "":
		f, err := os.Open(file.FilePath)
		if err != nil {
			return wrapError(err, "open file")
		}
		defer f.Close()
		reader = f
	default:
		return ErrNilReader
	}

	// multipart.Writer.CreateFormFile always uses "application/octet-stream",
	// which would silently discard the type passed to AddFileWithType or
	// AddFileStreamWithType. The header is therefore built by hand, escaping
	// backslashes and double quotes the same way CreateFormFile does.
	quote := strings.NewReplacer(`\`, `\\`, `"`, `\"`)
	contentType := file.FileType
	if contentType == "" {
		contentType = "application/octet-stream"
	}
	disposition := `form-data; name="` + quote.Replace(file.FormName) +
		`"; filename="` + quote.Replace(file.FileName) + `"`
	// A plain map literal is assignable to textproto.MIMEHeader.
	header := map[string][]string{
		"Content-Disposition": {disposition},
		"Content-Type":        {contentType},
	}
	part, err := writer.CreatePart(header)
	if err != nil {
		return wrapError(err, "create form file")
	}

	// Copy the file data, reporting progress when a callback is configured.
	if r.config.UploadProgress != nil {
		_, err = copyWithProgress(r.ctx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress)
	} else {
		_, err = io.Copy(part, reader)
	}
	if err != nil {
		return wrapError(err, "copy file data")
	}
	return nil
}
// prepare applies the request configuration to r.httpReq once and makes sure
// an *http.Client is available.
//
// The steps run in this order:
//   - Build r.httpClient if it is missing (this happens in raw mode too).
//   - Stop here in raw mode.
//   - Merge the configured queries into the URL.
//   - Add the configured headers, cookies and Basic Auth.
//   - Install the body (applyBody).
//   - Apply the Content-Length override.
//   - Optionally buffer the whole body to measure it.
//   - Set the TLS ServerName.
//   - Attach the request config to the request context.
//
// NOTE(review): r.applied is set by a deferred func, so it becomes true even
// when prepare returns an error. A second call then returns nil without
// completing the remaining steps. Confirm that callers never retry after an
// error.
func (r *Request) prepare() error {
	if r.applied {
		return nil
	}
	defer func() { r.applied = true }()
	// Make sure an httpClient exists, even in raw mode.
	if r.httpClient == nil {
		var err error
		r.httpClient, err = r.buildHTTPClient()
		if err != nil {
			return err
		}
	}
	// Raw mode sends the request exactly as built; do not modify it.
	if r.doRaw {
		return nil
	}
	// Apply query parameters. They are added to any query already present in
	// the URL, and Encode re-serializes the whole query sorted by key.
	if len(r.config.Queries) > 0 {
		q := r.httpReq.URL.Query()
		for k, values := range r.config.Queries {
			for _, v := range values {
				q.Add(k, v)
			}
		}
		r.httpReq.URL.RawQuery = q.Encode()
	}
	// Apply headers with Add, so configured values are appended to any that
	// are already on the request. applyBody runs afterwards and may Set
	// Content-Type (e.g. the multipart boundary), overriding a configured one.
	for k, values := range r.config.Headers {
		for _, v := range values {
			r.httpReq.Header.Add(k, v)
		}
	}
	// Apply cookies.
	for _, cookie := range r.config.Cookies {
		r.httpReq.AddCookie(cookie)
	}
	// Apply Basic Auth when either the username or the password is non-empty.
	if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
		r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
	}
	// Install the request body.
	if err := r.applyBody(); err != nil {
		return err
	}
	// Explicit Content-Length: a positive value overrides the length computed
	// by applyBody, and a negative value forces 0.
	// NOTE(review): for client requests, net/http treats ContentLength 0 with a
	// non-nil Body as "unknown length". A negative setting therefore does not
	// necessarily suppress the body; confirm this is the intended semantics.
	if r.config.ContentLength > 0 {
		r.httpReq.ContentLength = r.config.ContentLength
	} else if r.config.ContentLength < 0 {
		r.httpReq.ContentLength = 0
	}
	// Auto-compute Content-Length by reading the entire body into memory and
	// replacing it with an in-memory buffer. This includes a streamed multipart
	// body. It runs after the override above, so it takes precedence.
	if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
		data, err := io.ReadAll(r.httpReq.Body)
		if err != nil {
			return wrapError(err, "read body for content length")
		}
		r.httpReq.ContentLength = int64(len(data))
		r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data))
	}
	// Set the TLS ServerName from the request host when a TLS config is present.
	// NOTE(review): this mutates r.config.TLS.Config in place and overwrites any
	// explicitly configured ServerName. If that config (presumably a
	// *tls.Config) is shared with other requests, this is also a data race.
	// Consider cloning it and only filling in an empty ServerName.
	if r.config.TLS.Config != nil && r.httpReq.URL != nil {
		r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname()
	}
	// Inject the config into the context. WithContext returns a shallow copy of
	// the request, which replaces r.httpReq.
	r.execCtx = injectRequestConfig(r.ctx, r.config)
	r.httpReq = r.httpReq.WithContext(r.execCtx)
	return nil
}
// buildHTTPClient picks the *http.Client used to send this request.
//
// The base client is chosen in this order:
//   - the owning Client's HTTPClient();
//   - a new client around the custom Transport, when CustomTransport is
//     enabled and a Transport is configured;
//   - DefaultHTTPClient().
//
// The per-request Network.Timeout is then applied:
//   - 0 keeps the base client (and its timeout);
//   - a negative value gives a copy with no timeout;
//   - a positive value gives a copy with that timeout.
//
// A copy shares the base's Transport, CheckRedirect and Jar. The returned
// error is always nil.
func (r *Request) buildHTTPClient() (*http.Client, error) {
	var base *http.Client
	switch {
	case r.client != nil:
		base = r.client.HTTPClient()
	case r.config.CustomTransport && r.config.Transport != nil:
		base = &http.Client{
			Transport: &Transport{base: r.config.Transport},
			Timeout:   0,
		}
	default:
		base = DefaultHTTPClient()
	}

	// Guard against a nil base client.
	if base == nil {
		base = &http.Client{}
	}

	timeout := r.config.Network.Timeout
	if timeout == 0 {
		return base, nil
	}
	if timeout < 0 {
		timeout = 0
	}
	return &http.Client{
		Transport:     base.Transport,
		CheckRedirect: base.CheckRedirect,
		Jar:           base.Jar,
		Timeout:       timeout,
	}, nil
}