464 lines
9.7 KiB
Go
464 lines
9.7 KiB
Go
|
|
package starnet
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"bytes"
|
|||
|
|
"encoding/json"
|
|||
|
|
"io"
|
|||
|
|
"mime/multipart"
|
|||
|
|
"net/http"
|
|||
|
|
"net/url"
|
|||
|
|
"os"
|
|||
|
|
"strings"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// SetBody 设置请求体(字节)
|
|||
|
|
func (r *Request) SetBody(body []byte) *Request {
|
|||
|
|
if r.err != nil {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
if r.doRaw {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
r.config.Body.Bytes = body
|
|||
|
|
r.config.Body.Reader = nil
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// SetBodyReader 设置请求体(Reader)
|
|||
|
|
func (r *Request) SetBodyReader(reader io.Reader) *Request {
|
|||
|
|
if r.err != nil {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
if r.doRaw {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
r.config.Body.Reader = reader
|
|||
|
|
r.config.Body.Bytes = nil
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// SetBodyString 设置请求体(字符串)
|
|||
|
|
func (r *Request) SetBodyString(body string) *Request {
|
|||
|
|
return r.SetBody([]byte(body))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// SetJSON 设置 JSON 请求体
|
|||
|
|
func (r *Request) SetJSON(v interface{}) *Request {
|
|||
|
|
if r.err != nil {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
data, err := json.Marshal(v)
|
|||
|
|
if err != nil {
|
|||
|
|
r.err = wrapError(err, "marshal json")
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return r.SetContentType(ContentTypeJSON).SetBody(data)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// SetFormData 设置表单数据(覆盖)
|
|||
|
|
func (r *Request) SetFormData(data map[string][]string) *Request {
|
|||
|
|
if r.err != nil {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
if r.doRaw {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
r.config.Body.FormData = data
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// AddFormData 添加表单数据
|
|||
|
|
func (r *Request) AddFormData(key, value string) *Request {
|
|||
|
|
if r.err != nil {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
if r.doRaw {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// AddFormDataMap 批量添加表单数据
|
|||
|
|
func (r *Request) AddFormDataMap(data map[string]string) *Request {
|
|||
|
|
if r.err != nil {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
if r.doRaw {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
for k, v := range data {
|
|||
|
|
r.config.Body.FormData[k] = append(r.config.Body.FormData[k], v)
|
|||
|
|
}
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// AddFile 添加文件(从路径)
|
|||
|
|
func (r *Request) AddFile(formName, filePath string) *Request {
|
|||
|
|
if r.err != nil {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
stat, err := os.Stat(filePath)
|
|||
|
|
if err != nil {
|
|||
|
|
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
|||
|
|
FormName: formName,
|
|||
|
|
FileName: stat.Name(),
|
|||
|
|
FilePath: filePath,
|
|||
|
|
FileSize: stat.Size(),
|
|||
|
|
FileType: ContentTypeOctetStream,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// AddFileWithName 添加文件(指定文件名)
|
|||
|
|
func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request {
|
|||
|
|
if r.err != nil {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
stat, err := os.Stat(filePath)
|
|||
|
|
if err != nil {
|
|||
|
|
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
|||
|
|
FormName: formName,
|
|||
|
|
FileName: fileName,
|
|||
|
|
FilePath: filePath,
|
|||
|
|
FileSize: stat.Size(),
|
|||
|
|
FileType: ContentTypeOctetStream,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// AddFileWithType 添加文件(指定 MIME 类型)
|
|||
|
|
func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request {
|
|||
|
|
if r.err != nil {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
stat, err := os.Stat(filePath)
|
|||
|
|
if err != nil {
|
|||
|
|
r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
|||
|
|
FormName: formName,
|
|||
|
|
FileName: stat.Name(),
|
|||
|
|
FilePath: filePath,
|
|||
|
|
FileSize: stat.Size(),
|
|||
|
|
FileType: fileType,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// AddFileStream 添加文件流
|
|||
|
|
func (r *Request) AddFileStream(formName, fileName string, size int64, reader io.Reader) *Request {
|
|||
|
|
if r.err != nil {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if reader == nil {
|
|||
|
|
r.err = ErrNilReader
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
|||
|
|
FormName: formName,
|
|||
|
|
FileName: fileName,
|
|||
|
|
FileData: reader,
|
|||
|
|
FileSize: size,
|
|||
|
|
FileType: ContentTypeOctetStream,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// AddFileStreamWithType 添加文件流(指定 MIME 类型)
|
|||
|
|
func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, size int64, reader io.Reader) *Request {
|
|||
|
|
if r.err != nil {
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if reader == nil {
|
|||
|
|
r.err = ErrNilReader
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
r.config.Body.Files = append(r.config.Body.Files, RequestFile{
|
|||
|
|
FormName: formName,
|
|||
|
|
FileName: fileName,
|
|||
|
|
FileData: reader,
|
|||
|
|
FileSize: size,
|
|||
|
|
FileType: fileType,
|
|||
|
|
})
|
|||
|
|
|
|||
|
|
return r
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// applyBody 应用请求体
|
|||
|
|
func (r *Request) applyBody() error {
|
|||
|
|
// 优先级:Reader > Bytes > Files > FormData
|
|||
|
|
|
|||
|
|
// 1. Reader
|
|||
|
|
if r.config.Body.Reader != nil {
|
|||
|
|
r.httpReq.Body = io.NopCloser(r.config.Body.Reader)
|
|||
|
|
|
|||
|
|
// 尝试获取长度
|
|||
|
|
switch v := r.config.Body.Reader.(type) {
|
|||
|
|
case *bytes.Buffer:
|
|||
|
|
r.httpReq.ContentLength = int64(v.Len())
|
|||
|
|
case *bytes.Reader:
|
|||
|
|
r.httpReq.ContentLength = int64(v.Len())
|
|||
|
|
case *strings.Reader:
|
|||
|
|
r.httpReq.ContentLength = int64(v.Len())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 2. Bytes
|
|||
|
|
if len(r.config.Body.Bytes) > 0 {
|
|||
|
|
r.httpReq.Body = io.NopCloser(bytes.NewReader(r.config.Body.Bytes))
|
|||
|
|
r.httpReq.ContentLength = int64(len(r.config.Body.Bytes))
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 3. Files(multipart/form-data)
|
|||
|
|
if len(r.config.Body.Files) > 0 {
|
|||
|
|
return r.applyMultipartBody()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 4. FormData(application/x-www-form-urlencoded)
|
|||
|
|
if len(r.config.Body.FormData) > 0 {
|
|||
|
|
values := url.Values{}
|
|||
|
|
for k, vs := range r.config.Body.FormData {
|
|||
|
|
for _, v := range vs {
|
|||
|
|
values.Add(k, v)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
encoded := values.Encode()
|
|||
|
|
r.httpReq.Body = io.NopCloser(strings.NewReader(encoded))
|
|||
|
|
r.httpReq.ContentLength = int64(len(encoded))
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// applyMultipartBody 应用 multipart 请求体
|
|||
|
|
func (r *Request) applyMultipartBody() error {
|
|||
|
|
pr, pw := io.Pipe()
|
|||
|
|
writer := multipart.NewWriter(pw)
|
|||
|
|
|
|||
|
|
// 设置 Content-Type
|
|||
|
|
r.httpReq.Header.Set("Content-Type", writer.FormDataContentType())
|
|||
|
|
r.httpReq.Body = pr
|
|||
|
|
|
|||
|
|
// 在 goroutine 中写入数据
|
|||
|
|
go func() {
|
|||
|
|
defer pw.Close()
|
|||
|
|
defer writer.Close()
|
|||
|
|
|
|||
|
|
// 写入表单字段
|
|||
|
|
for k, vs := range r.config.Body.FormData {
|
|||
|
|
for _, v := range vs {
|
|||
|
|
if err := writer.WriteField(k, v); err != nil {
|
|||
|
|
pw.CloseWithError(wrapError(err, "write form field"))
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 写入文件
|
|||
|
|
for _, file := range r.config.Body.Files {
|
|||
|
|
if err := r.writeFile(writer, file); err != nil {
|
|||
|
|
pw.CloseWithError(err)
|
|||
|
|
return
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
}()
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// writeFile 写入文件到 multipart writer
|
|||
|
|
func (r *Request) writeFile(writer *multipart.Writer, file RequestFile) error {
|
|||
|
|
// 创建文件字段
|
|||
|
|
part, err := writer.CreateFormFile(file.FormName, file.FileName)
|
|||
|
|
if err != nil {
|
|||
|
|
return wrapError(err, "create form file")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 获取文件数据源
|
|||
|
|
var reader io.Reader
|
|||
|
|
if file.FileData != nil {
|
|||
|
|
reader = file.FileData
|
|||
|
|
} else if file.FilePath != "" {
|
|||
|
|
f, err := os.Open(file.FilePath)
|
|||
|
|
if err != nil {
|
|||
|
|
return wrapError(err, "open file")
|
|||
|
|
}
|
|||
|
|
defer f.Close()
|
|||
|
|
reader = f
|
|||
|
|
} else {
|
|||
|
|
return ErrNilReader
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 复制文件数据(带进度)
|
|||
|
|
if r.config.UploadProgress != nil {
|
|||
|
|
_, err = copyWithProgress(r.ctx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress)
|
|||
|
|
} else {
|
|||
|
|
_, err = io.Copy(part, reader)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if err != nil {
|
|||
|
|
return wrapError(err, "copy file data")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// prepare 准备请求(应用配置)
|
|||
|
|
func (r *Request) prepare() error {
|
|||
|
|
if r.applied {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
defer func() { r.applied = true }()
|
|||
|
|
// 即使 raw 模式也要确保有 httpClient
|
|||
|
|
if r.httpClient == nil {
|
|||
|
|
var err error
|
|||
|
|
r.httpClient, err = r.buildHTTPClient()
|
|||
|
|
if err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
// 原始模式不修改请求内容
|
|||
|
|
if r.doRaw {
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 应用查询参数
|
|||
|
|
if len(r.config.Queries) > 0 {
|
|||
|
|
q := r.httpReq.URL.Query()
|
|||
|
|
for k, values := range r.config.Queries {
|
|||
|
|
for _, v := range values {
|
|||
|
|
q.Add(k, v)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
r.httpReq.URL.RawQuery = q.Encode()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 应用 Headers
|
|||
|
|
for k, values := range r.config.Headers {
|
|||
|
|
for _, v := range values {
|
|||
|
|
r.httpReq.Header.Add(k, v)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 应用 Cookies
|
|||
|
|
for _, cookie := range r.config.Cookies {
|
|||
|
|
r.httpReq.AddCookie(cookie)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 应用 Basic Auth
|
|||
|
|
if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" {
|
|||
|
|
r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1])
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 应用请求体
|
|||
|
|
if err := r.applyBody(); err != nil {
|
|||
|
|
return err
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 应用 Content-Length
|
|||
|
|
if r.config.ContentLength > 0 {
|
|||
|
|
r.httpReq.ContentLength = r.config.ContentLength
|
|||
|
|
} else if r.config.ContentLength < 0 {
|
|||
|
|
r.httpReq.ContentLength = 0
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 自动计算 Content-Length
|
|||
|
|
if r.config.AutoCalcContentLength && r.httpReq.Body != nil {
|
|||
|
|
data, err := io.ReadAll(r.httpReq.Body)
|
|||
|
|
if err != nil {
|
|||
|
|
return wrapError(err, "read body for content length")
|
|||
|
|
}
|
|||
|
|
r.httpReq.ContentLength = int64(len(data))
|
|||
|
|
r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 设置 TLS ServerName(如果有 TLS Config)
|
|||
|
|
if r.config.TLS.Config != nil && r.httpReq.URL != nil {
|
|||
|
|
r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname()
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 注入配置到 context
|
|||
|
|
r.execCtx = injectRequestConfig(r.ctx, r.config)
|
|||
|
|
r.httpReq = r.httpReq.WithContext(r.execCtx)
|
|||
|
|
|
|||
|
|
return nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// buildHTTPClient 构建 HTTP Client
|
|||
|
|
func (r *Request) buildHTTPClient() (*http.Client, error) {
|
|||
|
|
applyTimeoutOverride := func(base *http.Client) *http.Client {
|
|||
|
|
// 没有 base 时兜底
|
|||
|
|
if base == nil {
|
|||
|
|
base = &http.Client{}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
rt := r.config.Network.Timeout
|
|||
|
|
|
|||
|
|
// 语义:
|
|||
|
|
// rt < 0 : 本次请求禁用超时(Timeout = 0)
|
|||
|
|
// rt = 0 : 沿用 base.Timeout
|
|||
|
|
// rt > 0 : 本次请求超时覆盖
|
|||
|
|
if rt == 0 {
|
|||
|
|
return base
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
clone := &http.Client{
|
|||
|
|
Transport: base.Transport,
|
|||
|
|
CheckRedirect: base.CheckRedirect,
|
|||
|
|
Jar: base.Jar,
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if rt < 0 {
|
|||
|
|
clone.Timeout = 0
|
|||
|
|
} else {
|
|||
|
|
clone.Timeout = rt
|
|||
|
|
}
|
|||
|
|
return clone
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 优先使用请求关联的 Client
|
|||
|
|
if r.client != nil {
|
|||
|
|
return applyTimeoutOverride(r.client.HTTPClient()), nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 自定义 Transport
|
|||
|
|
if r.config.CustomTransport && r.config.Transport != nil {
|
|||
|
|
base := &http.Client{
|
|||
|
|
Transport: &Transport{base: r.config.Transport},
|
|||
|
|
Timeout: 0,
|
|||
|
|
}
|
|||
|
|
return applyTimeoutOverride(base), nil
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 默认全局 client
|
|||
|
|
return applyTimeoutOverride(DefaultHTTPClient()), nil
|
|||
|
|
}
|