starnet/request_body.go
starainrt b5bd7595a1
1. 优化ping功能
2. 新增重试机制
3. 优化错误处理逻辑
2026-03-19 16:42:45 +08:00

449 lines
9.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package starnet
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/url"
"os"
"strings"
)
// SetBody installs a raw byte slice as the request body.
//
// The call does nothing when the builder already carries an error or when
// the request is in raw mode (r.doRaw). Otherwise the slice is stored as
// given (it is not copied) and any previously configured body reader is
// cleared.
func (r *Request) SetBody(body []byte) *Request {
	if r.err != nil || r.doRaw {
		return r
	}

	r.config.Body.Reader = nil
	r.config.Body.Bytes = body
	return r
}
// SetBodyReader installs an io.Reader as the request body.
//
// The call does nothing when the builder already carries an error or when
// the request is in raw mode (r.doRaw). Otherwise the reader is stored and
// any previously configured byte body is cleared.
func (r *Request) SetBodyReader(reader io.Reader) *Request {
	if r.err != nil || r.doRaw {
		return r
	}

	r.config.Body.Bytes = nil
	r.config.Body.Reader = reader
	return r
}
// SetBodyString installs a string as the request body by converting it to
// bytes and delegating to SetBody.
func (r *Request) SetBodyString(body string) *Request {
	payload := []byte(body)
	return r.SetBody(payload)
}
// SetJSON marshals v with encoding/json and uses the result as the request
// body, setting the content type to ContentTypeJSON.
//
// When the builder already carries an error the call does nothing. A
// marshalling failure is recorded on the builder (wrapped with the message
// "marshal json") and the body is left untouched.
func (r *Request) SetJSON(v interface{}) *Request {
	if r.err != nil {
		return r
	}

	encoded, marshalErr := json.Marshal(v)
	if marshalErr == nil {
		return r.SetContentType(ContentTypeJSON).SetBody(encoded)
	}

	r.err = wrapError(marshalErr, "marshal json")
	return r
}
// SetFormData replaces the request's form data with a clone of data
// (produced by cloneStringMapSlice).
//
// The call does nothing when the builder already carries an error or when
// the request is in raw mode (r.doRaw).
func (r *Request) SetFormData(data map[string][]string) *Request {
	if r.err == nil && !r.doRaw {
		r.config.Body.FormData = cloneStringMapSlice(data)
	}
	return r
}
// AddFormData appends value to the list of form values stored under key.
//
// The call does nothing when the builder already carries an error or when
// the request is in raw mode (r.doRaw).
func (r *Request) AddFormData(key, value string) *Request {
	if r.err != nil || r.doRaw {
		return r
	}

	// Writing to a nil map panics; create the map on first use so the
	// method is safe even if the form data was never initialised.
	if r.config.Body.FormData == nil {
		r.config.Body.FormData = make(map[string][]string)
	}

	r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
	return r
}
// AddFormDataMap appends every key/value pair of data to the request's form
// data. Existing values under the same key are kept; the new value is added
// after them.
//
// The call does nothing when the builder already carries an error or when
// the request is in raw mode (r.doRaw).
func (r *Request) AddFormDataMap(data map[string]string) *Request {
	if r.err != nil || r.doRaw {
		return r
	}

	// Writing to a nil map panics; create the map on first use so the
	// method is safe even if the form data was never initialised.
	if r.config.Body.FormData == nil {
		r.config.Body.FormData = make(map[string][]string, len(data))
	}

	for k, v := range data {
		r.config.Body.FormData[k] = append(r.config.Body.FormData[k], v)
	}
	return r
}
// AddFile queues the file at filePath for upload under the form field
// formName.
//
// The file is stat'ed immediately: its base name and size are recorded and
// the part type is set to ContentTypeOctetStream. Any Stat failure is
// recorded on the builder as ErrFileNotFound (annotated with the path). The
// call does nothing when the builder already carries an error.
func (r *Request) AddFile(formName, filePath string) *Request {
	if r.err != nil {
		return r
	}

	info, statErr := os.Stat(filePath)
	if statErr != nil {
		r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
		return r
	}

	entry := RequestFile{
		FormName: formName,
		FileName: info.Name(),
		FilePath: filePath,
		FileSize: info.Size(),
		FileType: ContentTypeOctetStream,
	}
	r.config.Body.Files = append(r.config.Body.Files, entry)
	return r
}
// AddFileWithName queues the file at filePath for upload under the form
// field formName, reporting fileName instead of the file's own base name.
//
// The file is stat'ed immediately to record its size; the part type is set
// to ContentTypeOctetStream. Any Stat failure is recorded on the builder as
// ErrFileNotFound (annotated with the path). The call does nothing when the
// builder already carries an error.
func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request {
	if r.err != nil {
		return r
	}

	info, statErr := os.Stat(filePath)
	if statErr != nil {
		r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
		return r
	}

	entry := RequestFile{
		FormName: formName,
		FileName: fileName,
		FilePath: filePath,
		FileSize: info.Size(),
		FileType: ContentTypeOctetStream,
	}
	r.config.Body.Files = append(r.config.Body.Files, entry)
	return r
}
// AddFileWithType queues the file at filePath for upload under the form
// field formName and records fileType as its MIME type.
//
// The file is stat'ed immediately to record its base name and size. Any
// Stat failure is recorded on the builder as ErrFileNotFound (annotated
// with the path). The call does nothing when the builder already carries an
// error.
func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request {
	if r.err != nil {
		return r
	}

	info, statErr := os.Stat(filePath)
	if statErr != nil {
		r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
		return r
	}

	entry := RequestFile{
		FormName: formName,
		FileName: info.Name(),
		FilePath: filePath,
		FileSize: info.Size(),
		FileType: fileType,
	}
	r.config.Body.Files = append(r.config.Body.Files, entry)
	return r
}
// AddFileStream queues an upload whose content comes from reader rather
// than from a path on disk.
//
// fileName and size are recorded as supplied; the part type is set to
// ContentTypeOctetStream. A nil reader is rejected by recording
// ErrNilReader on the builder. The call does nothing when the builder
// already carries an error.
func (r *Request) AddFileStream(formName, fileName string, size int64, reader io.Reader) *Request {
	switch {
	case r.err != nil:
		return r
	case reader == nil:
		r.err = ErrNilReader
		return r
	}

	entry := RequestFile{
		FormName: formName,
		FileName: fileName,
		FileData: reader,
		FileSize: size,
		FileType: ContentTypeOctetStream,
	}
	r.config.Body.Files = append(r.config.Body.Files, entry)
	return r
}
// AddFileStreamWithType queues an upload whose content comes from reader
// and records fileType as its MIME type.
//
// fileName and size are recorded as supplied. A nil reader is rejected by
// recording ErrNilReader on the builder. The call does nothing when the
// builder already carries an error.
func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, size int64, reader io.Reader) *Request {
	switch {
	case r.err != nil:
		return r
	case reader == nil:
		r.err = ErrNilReader
		return r
	}

	entry := RequestFile{
		FormName: formName,
		FileName: fileName,
		FileData: reader,
		FileSize: size,
		FileType: fileType,
	}
	r.config.Body.Files = append(r.config.Body.Files, entry)
	return r
}
// applyBody attaches the configured body to r.httpReq.
//
// Exactly one source is used, chosen in this priority order:
//
//  1. Body.Reader - streamed as-is; the length is filled in when the reader
//     is a *bytes.Buffer, *bytes.Reader or *strings.Reader.
//  2. Body.Bytes - used when non-empty.
//  3. Body.Files - encoded as multipart/form-data (see applyMultipartBody).
//  4. Body.FormData - URL-encoded into the body.
//
// For every in-memory body GetBody is set as well, the same way
// http.NewRequest does it, so that net/http can obtain a fresh copy of the
// body when it has to send the request again (for example on a 307/308
// redirect). If nothing is configured the request is left without a body.
// The current implementation always returns nil.
func (r *Request) applyBody() error {
	// 1. Reader
	if src := r.config.Body.Reader; src != nil {
		r.httpReq.Body = io.NopCloser(src)
		// An arbitrary reader cannot be replayed; make sure no GetBody left
		// over from an earlier body can hand out stale content.
		r.httpReq.GetBody = nil

		// Known in-memory readers expose their remaining length and can be
		// snapshotted for replay.
		switch v := src.(type) {
		case *bytes.Buffer:
			r.httpReq.ContentLength = int64(v.Len())
			unread := v.Bytes()
			r.httpReq.GetBody = func() (io.ReadCloser, error) {
				return io.NopCloser(bytes.NewReader(unread)), nil
			}
		case *bytes.Reader:
			r.httpReq.ContentLength = int64(v.Len())
			snapshot := *v
			r.httpReq.GetBody = func() (io.ReadCloser, error) {
				clone := snapshot
				return io.NopCloser(&clone), nil
			}
		case *strings.Reader:
			r.httpReq.ContentLength = int64(v.Len())
			snapshot := *v
			r.httpReq.GetBody = func() (io.ReadCloser, error) {
				clone := snapshot
				return io.NopCloser(&clone), nil
			}
		}
		return nil
	}

	// 2. Bytes
	if payload := r.config.Body.Bytes; len(payload) > 0 {
		r.httpReq.Body = io.NopCloser(bytes.NewReader(payload))
		r.httpReq.ContentLength = int64(len(payload))
		r.httpReq.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(bytes.NewReader(payload)), nil
		}
		return nil
	}

	// 3. Files (multipart/form-data)
	if len(r.config.Body.Files) > 0 {
		return r.applyMultipartBody()
	}

	// 4. FormData (application/x-www-form-urlencoded)
	// NOTE(review): no Content-Type header is set for this body here;
	// presumably it is supplied elsewhere - confirm.
	if len(r.config.Body.FormData) > 0 {
		values := url.Values{}
		for k, vs := range r.config.Body.FormData {
			for _, v := range vs {
				values.Add(k, v)
			}
		}
		encoded := values.Encode()
		r.httpReq.Body = io.NopCloser(strings.NewReader(encoded))
		r.httpReq.ContentLength = int64(len(encoded))
		r.httpReq.GetBody = func() (io.ReadCloser, error) {
			return io.NopCloser(strings.NewReader(encoded)), nil
		}
		return nil
	}

	return nil
}
// applyMultipartBody streams the configured form fields and files to the
// request as a multipart/form-data body.
//
// The request's Content-Type header is set to the writer's boundary-bearing
// value and the read end of an io.Pipe becomes the request body. A goroutine
// writes all form fields, then all files, into the pipe. The first failure
// is propagated to the reader through CloseWithError. The multipart writer
// and then the pipe writer are closed when the goroutine finishes. The
// method itself always returns nil.
func (r *Request) applyMultipartBody() error {
	pipeReader, pipeWriter := io.Pipe()
	mw := multipart.NewWriter(pipeWriter)

	r.httpReq.Header.Set("Content-Type", mw.FormDataContentType())
	r.httpReq.Body = pipeReader

	// produce emits every part and stops at the first error.
	produce := func() error {
		for name, values := range r.config.Body.FormData {
			for _, value := range values {
				if err := mw.WriteField(name, value); err != nil {
					return wrapError(err, "write form field")
				}
			}
		}
		for _, upload := range r.config.Body.Files {
			if err := r.writeFile(mw, upload); err != nil {
				return err
			}
		}
		return nil
	}

	go func() {
		// Deferred calls run last-in first-out: the multipart writer is
		// closed before the pipe.
		defer pipeWriter.Close()
		defer mw.Close()

		if err := produce(); err != nil {
			pipeWriter.CloseWithError(err)
		}
	}()

	return nil
}
// writeFile writes one file part (headers and content) to the multipart
// writer.
//
// The part's Content-Type is taken from file.FileType, falling back to
// application/octet-stream when it is empty. multipart.Writer.CreateFormFile
// always emits application/octet-stream, so the header is built explicitly
// and passed to CreatePart in order to honour the configured type.
//
// Content is read from file.FileData when set, otherwise from the file at
// file.FilePath. ErrNilReader is returned when neither is available. When an
// upload-progress callback is configured the copy goes through
// copyWithProgress, otherwise through io.Copy. Failures to create the part,
// open the file or copy the data are returned wrapped with a short
// description of the step.
func (r *Request) writeFile(writer *multipart.Writer, file RequestFile) error {
	contentType := file.FileType
	if contentType == "" {
		contentType = "application/octet-stream"
	}
	// A header value must stay on one line.
	contentType = strings.NewReplacer("\r", "", "\n", "").Replace(contentType)

	// Escape the disposition parameters the way mime/multipart does for
	// CreateFormFile: backslash-escape quotes and backslashes, and
	// percent-encode CR/LF.
	paramEscaper := strings.NewReplacer("\\", "\\\\", `"`, "\\\"", "\r", "%0D", "\n", "%0A")
	disposition := fmt.Sprintf(`form-data; name="%s"; filename="%s"`,
		paramEscaper.Replace(file.FormName), paramEscaper.Replace(file.FileName))

	// A plain map literal is assignable to textproto.MIMEHeader, which
	// avoids importing net/textproto just to name the type.
	header := map[string][]string{
		"Content-Disposition": {disposition},
		"Content-Type":        {contentType},
	}
	part, err := writer.CreatePart(header)
	if err != nil {
		return wrapError(err, "create form file")
	}

	// Pick the data source: an explicit stream wins over a path.
	var reader io.Reader
	switch {
	case file.FileData != nil:
		reader = file.FileData
	case file.FilePath != "":
		f, err := os.Open(file.FilePath)
		if err != nil {
			return wrapError(err, "open file")
		}
		defer f.Close()
		reader = f
	default:
		return ErrNilReader
	}

	// Copy the content, reporting progress when a callback is configured.
	if r.config.UploadProgress != nil {
		_, err = copyWithProgress(r.ctx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress)
	} else {
		_, err = io.Copy(part, reader)
	}
	if err != nil {
		return wrapError(err, "copy file data")
	}
	return nil
}
// prepare applies the accumulated configuration to r.httpReq and builds the
// execution context. It runs its work once: after a successful call
// r.applied is set and later calls return nil immediately.
//
// In raw mode (r.doRaw) the request content (query, headers, cookies, auth,
// body, content length, TLS server name) is left untouched; only the HTTP
// client, the context and the timeout are set up.
//
// It returns an error when the HTTP client cannot be built, when r.httpReq
// is nil, or when applying/reading the body fails. r.applied stays false on
// every error path.
func (r *Request) prepare() error {
if r.applied {
return nil
}
// An HTTP client is required even in raw mode.
if r.httpClient == nil {
var err error
r.httpClient, err = r.buildHTTPClient()
if err != nil {
return err // applied is deliberately not set on failure
}
}
if r.httpReq == nil {
return fmt.Errorf("http request is nil")
}
// Raw mode does not modify the request content.
if !r.doRaw {
// Apply query parameters: configured values are added to whatever the
// URL already carries, then the query string is re-encoded.
if len(r.config.Queries) > 0 {
q := r.httpReq.URL.Query()
for k, values := range r.config.Queries {
for _, v := range values {
q.Add(k, v)
}
}
r.httpReq.URL.RawQuery = q.Encode()
}
// Apply headers (added to, not replacing, existing values).
for k, values := range r.config.Headers {
for _, v := range values {
r.httpReq.Header.Add(k, v)
}
}
// Apply cookies.
for _, cookie := range r.config.Cookies {
r.httpReq.AddCookie(cookie)
}
// Apply Basic Auth when a user name or a password is configured.
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
}
// Apply the request body.
if err := r.applyBody(); err != nil {
return err
}
// Apply an explicit Content-Length: a positive value overrides the length
// set by applyBody, a negative value resets it to 0.
if r.config.ContentLength > 0 {
r.httpReq.ContentLength = r.config.ContentLength
} else if r.config.ContentLength < 0 {
r.httpReq.ContentLength = 0
}
// Automatically compute Content-Length: the whole body is read into
// memory, measured, and replaced by a buffer holding the same bytes.
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
data, err := io.ReadAll(r.httpReq.Body)
if err != nil {
return wrapError(err, "read body for content length")
}
r.httpReq.ContentLength = int64(len(data))
r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data))
}
// Set the TLS ServerName to the URL's host name when a TLS config is
// present. This overwrites any ServerName already in that config.
if r.config.TLS.Config != nil && r.httpReq.URL != nil {
r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname()
}
}
execCtx := r.ctx
if !r.doRaw {
// In raw mode the request-level network configuration is not injected;
// only the context/timeout below is applied.
execCtx = injectRequestConfig(execCtx, r.config)
}
// The request-level total timeout is enforced through the context so the
// shared http.Client is not modified.
if r.config.Network.Timeout > 0 {
execCtx, r.cancel = context.WithTimeout(execCtx, r.config.Network.Timeout)
}
r.execCtx = execCtx
r.httpReq = r.httpReq.WithContext(r.execCtx)
r.applied = true
return nil
}
// buildHTTPClient selects the *http.Client used to execute the request.
//
// Selection order:
//  1. the client attached to the request (r.client), if any;
//  2. a new client wrapping the configured custom transport, when
//     CustomTransport is enabled and a transport is set;
//  3. the package default client from DefaultHTTPClient.
//
// The current implementation always returns a nil error.
func (r *Request) buildHTTPClient() (*http.Client, error) {
	switch {
	case r.client != nil:
		return r.client.HTTPClient(), nil

	case r.config.CustomTransport && r.config.Transport != nil:
		// Timeout is left at its zero value: no client-level timeout.
		custom := &http.Client{
			Transport: &Transport{base: r.config.Transport},
		}
		return custom, nil

	default:
		return DefaultHTTPClient(), nil
	}
}