package starnet

import (
	"bytes"
	"encoding/json"
	"io"
	"mime/multipart"
	"net/http"
	"net/url"
	"os"
	"strings"
)

// SetBody sets the request body to the given bytes.
//
// It clears any reader previously set with SetBodyReader. It is a no-op when
// the request already carries an error or is in raw mode.
func (r *Request) SetBody(body []byte) *Request {
	if r.err != nil || r.doRaw {
		return r
	}
	r.config.Body.Bytes = body
	r.config.Body.Reader = nil
	return r
}

// SetBodyReader sets the request body to be streamed from reader.
//
// It clears any bytes previously set with SetBody. It is a no-op when the
// request already carries an error or is in raw mode.
func (r *Request) SetBodyReader(reader io.Reader) *Request {
	if r.err != nil || r.doRaw {
		return r
	}
	r.config.Body.Reader = reader
	r.config.Body.Bytes = nil
	return r
}

// SetBodyString sets the request body from a string (see SetBody).
func (r *Request) SetBodyString(body string) *Request {
	return r.SetBody([]byte(body))
}

// SetJSON marshals v as JSON, sets the Content-Type to ContentTypeJSON and
// uses the encoded bytes as the request body.
//
// A marshalling failure is recorded in the request error and short-circuits
// later builder calls.
func (r *Request) SetJSON(v interface{}) *Request {
	if r.err != nil {
		return r
	}
	data, err := json.Marshal(v)
	if err != nil {
		r.err = wrapError(err, "marshal json")
		return r
	}
	return r.SetContentType(ContentTypeJSON).SetBody(data)
}

// SetFormData replaces the form data with a copy of data.
//
// The map and its value slices are copied so that subsequent AddFormData /
// AddFormDataMap calls never mutate the caller's map or its slices.
func (r *Request) SetFormData(data map[string][]string) *Request {
	if r.err != nil || r.doRaw {
		return r
	}
	copied := make(map[string][]string, len(data))
	for k, vs := range data {
		copied[k] = append([]string(nil), vs...)
	}
	r.config.Body.FormData = copied
	return r
}

// AddFormData appends value to the form field key.
//
// The form data map is created on demand, so this is safe even if no form
// data was configured before.
func (r *Request) AddFormData(key, value string) *Request {
	if r.err != nil || r.doRaw {
		return r
	}
	if r.config.Body.FormData == nil {
		r.config.Body.FormData = make(map[string][]string)
	}
	r.config.Body.FormData[key] = append(r.config.Body.FormData[key], value)
	return r
}

// AddFormDataMap appends every key/value pair of data to the form data.
//
// The form data map is created on demand, so this is safe even if no form
// data was configured before.
func (r *Request) AddFormDataMap(data map[string]string) *Request {
	if r.err != nil || r.doRaw {
		return r
	}
	if r.config.Body.FormData == nil {
		r.config.Body.FormData = make(map[string][]string, len(data))
	}
	for k, v := range data {
		r.config.Body.FormData[k] = append(r.config.Body.FormData[k], v)
	}
	return r
}

// AddFile adds the file at filePath to the multipart body under formName.
//
// The part's file name is the on-disk base name and its type is
// ContentTypeOctetStream. If the file cannot be stat'ed, the request error is
// set to ErrFileNotFound (wrapped with the path).
// NOTE(review): the underlying os.Stat error (e.g. permission denied) is not
// preserved; callers currently rely on ErrFileNotFound for every stat failure.
func (r *Request) AddFile(formName, filePath string) *Request {
	if r.err != nil {
		return r
	}
	stat, err := os.Stat(filePath)
	if err != nil {
		r.err = wrapError(ErrFileNotFound, "file: %s", filePath)
		return r
	}
	r.config.Body.Files = append(r.config.Body.Files, RequestFile{
		FormName: formName,
		FileName: stat.Name(),
		FilePath: filePath,
		FileSize: stat.Size(),
		FileType: ContentTypeOctetStream,
	})
	return r
}

// AddFileWithName adds the file at filePath to the multipart body under
// formName, reporting fileName (instead of the on-disk name) in the part.
func (r *Request) AddFileWithName(formName, filePath, fileName string) *Request { if r.err != nil { return r } stat, err := os.Stat(filePath) if err != nil { r.err = wrapError(ErrFileNotFound, "file: %s", filePath) return r } r.config.Body.Files = append(r.config.Body.Files, RequestFile{ FormName: formName, FileName: fileName, FilePath: filePath, FileSize: stat.Size(), FileType: ContentTypeOctetStream, }) return r } // AddFileWithType 添加文件(指定 MIME 类型) func (r *Request) AddFileWithType(formName, filePath, fileType string) *Request { if r.err != nil { return r } stat, err := os.Stat(filePath) if err != nil { r.err = wrapError(ErrFileNotFound, "file: %s", filePath) return r } r.config.Body.Files = append(r.config.Body.Files, RequestFile{ FormName: formName, FileName: stat.Name(), FilePath: filePath, FileSize: stat.Size(), FileType: fileType, }) return r } // AddFileStream 添加文件流 func (r *Request) AddFileStream(formName, fileName string, size int64, reader io.Reader) *Request { if r.err != nil { return r } if reader == nil { r.err = ErrNilReader return r } r.config.Body.Files = append(r.config.Body.Files, RequestFile{ FormName: formName, FileName: fileName, FileData: reader, FileSize: size, FileType: ContentTypeOctetStream, }) return r } // AddFileStreamWithType 添加文件流(指定 MIME 类型) func (r *Request) AddFileStreamWithType(formName, fileName, fileType string, size int64, reader io.Reader) *Request { if r.err != nil { return r } if reader == nil { r.err = ErrNilReader return r } r.config.Body.Files = append(r.config.Body.Files, RequestFile{ FormName: formName, FileName: fileName, FileData: reader, FileSize: size, FileType: fileType, }) return r } // applyBody 应用请求体 func (r *Request) applyBody() error { // 优先级:Reader > Bytes > Files > FormData // 1. 
Reader if r.config.Body.Reader != nil { r.httpReq.Body = io.NopCloser(r.config.Body.Reader) // 尝试获取长度 switch v := r.config.Body.Reader.(type) { case *bytes.Buffer: r.httpReq.ContentLength = int64(v.Len()) case *bytes.Reader: r.httpReq.ContentLength = int64(v.Len()) case *strings.Reader: r.httpReq.ContentLength = int64(v.Len()) } return nil } // 2. Bytes if len(r.config.Body.Bytes) > 0 { r.httpReq.Body = io.NopCloser(bytes.NewReader(r.config.Body.Bytes)) r.httpReq.ContentLength = int64(len(r.config.Body.Bytes)) return nil } // 3. Files(multipart/form-data) if len(r.config.Body.Files) > 0 { return r.applyMultipartBody() } // 4. FormData(application/x-www-form-urlencoded) if len(r.config.Body.FormData) > 0 { values := url.Values{} for k, vs := range r.config.Body.FormData { for _, v := range vs { values.Add(k, v) } } encoded := values.Encode() r.httpReq.Body = io.NopCloser(strings.NewReader(encoded)) r.httpReq.ContentLength = int64(len(encoded)) return nil } return nil } // applyMultipartBody 应用 multipart 请求体 func (r *Request) applyMultipartBody() error { pr, pw := io.Pipe() writer := multipart.NewWriter(pw) // 设置 Content-Type r.httpReq.Header.Set("Content-Type", writer.FormDataContentType()) r.httpReq.Body = pr // 在 goroutine 中写入数据 go func() { defer pw.Close() defer writer.Close() // 写入表单字段 for k, vs := range r.config.Body.FormData { for _, v := range vs { if err := writer.WriteField(k, v); err != nil { pw.CloseWithError(wrapError(err, "write form field")) return } } } // 写入文件 for _, file := range r.config.Body.Files { if err := r.writeFile(writer, file); err != nil { pw.CloseWithError(err) return } } }() return nil } // writeFile 写入文件到 multipart writer func (r *Request) writeFile(writer *multipart.Writer, file RequestFile) error { // 创建文件字段 part, err := writer.CreateFormFile(file.FormName, file.FileName) if err != nil { return wrapError(err, "create form file") } // 获取文件数据源 var reader io.Reader if file.FileData != nil { reader = file.FileData } else if file.FilePath 
!= "" { f, err := os.Open(file.FilePath) if err != nil { return wrapError(err, "open file") } defer f.Close() reader = f } else { return ErrNilReader } // 复制文件数据(带进度) if r.config.UploadProgress != nil { _, err = copyWithProgress(r.ctx, part, reader, file.FileName, file.FileSize, r.config.UploadProgress) } else { _, err = io.Copy(part, reader) } if err != nil { return wrapError(err, "copy file data") } return nil } // prepare 准备请求(应用配置) func (r *Request) prepare() error { if r.applied { return nil } defer func() { r.applied = true }() // 即使 raw 模式也要确保有 httpClient if r.httpClient == nil { var err error r.httpClient, err = r.buildHTTPClient() if err != nil { return err } } // 原始模式不修改请求内容 if r.doRaw { return nil } // 应用查询参数 if len(r.config.Queries) > 0 { q := r.httpReq.URL.Query() for k, values := range r.config.Queries { for _, v := range values { q.Add(k, v) } } r.httpReq.URL.RawQuery = q.Encode() } // 应用 Headers for k, values := range r.config.Headers { for _, v := range values { r.httpReq.Header.Add(k, v) } } // 应用 Cookies for _, cookie := range r.config.Cookies { r.httpReq.AddCookie(cookie) } // 应用 Basic Auth if r.config.BasicAuth[0] != "" || r.config.BasicAuth[1] != "" { r.httpReq.SetBasicAuth(r.config.BasicAuth[0], r.config.BasicAuth[1]) } // 应用请求体 if err := r.applyBody(); err != nil { return err } // 应用 Content-Length if r.config.ContentLength > 0 { r.httpReq.ContentLength = r.config.ContentLength } else if r.config.ContentLength < 0 { r.httpReq.ContentLength = 0 } // 自动计算 Content-Length if r.config.AutoCalcContentLength && r.httpReq.Body != nil { data, err := io.ReadAll(r.httpReq.Body) if err != nil { return wrapError(err, "read body for content length") } r.httpReq.ContentLength = int64(len(data)) r.httpReq.Body = io.NopCloser(bytes.NewBuffer(data)) } // 设置 TLS ServerName(如果有 TLS Config) if r.config.TLS.Config != nil && r.httpReq.URL != nil { r.config.TLS.Config.ServerName = r.httpReq.URL.Hostname() } // 注入配置到 context r.execCtx = injectRequestConfig(r.ctx, 
r.config) r.httpReq = r.httpReq.WithContext(r.execCtx) return nil } // buildHTTPClient 构建 HTTP Client func (r *Request) buildHTTPClient() (*http.Client, error) { applyTimeoutOverride := func(base *http.Client) *http.Client { // 没有 base 时兜底 if base == nil { base = &http.Client{} } rt := r.config.Network.Timeout // 语义: // rt < 0 : 本次请求禁用超时(Timeout = 0) // rt = 0 : 沿用 base.Timeout // rt > 0 : 本次请求超时覆盖 if rt == 0 { return base } clone := &http.Client{ Transport: base.Transport, CheckRedirect: base.CheckRedirect, Jar: base.Jar, } if rt < 0 { clone.Timeout = 0 } else { clone.Timeout = rt } return clone } // 优先使用请求关联的 Client if r.client != nil { return applyTimeoutOverride(r.client.HTTPClient()), nil } // 自定义 Transport if r.config.CustomTransport && r.config.Transport != nil { base := &http.Client{ Transport: &Transport{base: r.config.Transport}, Timeout: 0, } return applyTimeoutOverride(base), nil } // 默认全局 client return applyTimeoutOverride(DefaultHTTPClient()), nil }