From e3b7369e12deff9d9eefc57a5ae04745a47f0feb Mon Sep 17 00:00:00 2001 From: starainrt Date: Wed, 13 Aug 2025 10:16:08 +0800 Subject: [PATCH] bug fix:nil pointer error --- curl.go | 102 +++++++++++++++++++++++++++++++++++++-------------- curl_test.go | 90 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 165 insertions(+), 27 deletions(-) diff --git a/curl.go b/curl.go index 34d42c5..65d64f7 100644 --- a/curl.go +++ b/curl.go @@ -141,7 +141,12 @@ func (r *Request) Clone() *Request { autoCalcContentLength: r.autoCalcContentLength, }, } - + if r.doRawClient { + clonedRequest.rawClient = r.rawClient + } + if r.doRawRequest { + clonedRequest.rawRequest = r.rawRequest + } // 手动深拷贝嵌套引用类型 if r.bodyDataReader != nil { clonedRequest.bodyDataReader = r.bodyDataReader @@ -161,6 +166,9 @@ func (r *Request) Clone() *Request { clonedRequest.transport = CloneTransport(r.transport) } + if clonedRequest.rawRequest == nil { + clonedRequest.rawRequest, _ = http.NewRequestWithContext(clonedRequest.ctx, clonedRequest.method, clonedRequest.uri, nil) + } return clonedRequest } @@ -762,76 +770,86 @@ func (r *Request) AddCookie(key, value, path string) *Request { } func (r *Request) AddFile(formName, filepath string) error { - f, err := os.Open(filepath) - if err != nil { - return err - } - stat, err := f.Stat() + stat, err := os.Stat(filepath) if err != nil { return err } r.bodyFileData = append(r.bodyFileData, RequestFile{ FormName: formName, FileName: stat.Name(), - FileData: f, + FileData: nil, FileSize: stat.Size(), FileType: "application/octet-stream", + FilePath: filepath, + }) + return nil +} + +func (r *Request) AddFileStream(formName, filename string, size int64, stream io.Reader) error { + r.bodyFileData = append(r.bodyFileData, RequestFile{ + FormName: formName, + FileName: filename, + FileData: stream, + FileSize: size, + FileType: "application/octet-stream", }) return nil } func (r *Request) AddFileWithName(formName, filepath, filename string) error { - f, err := os.Open(filepath) - if err != nil { - return err - } - stat, err := f.Stat() + stat, err := os.Stat(filepath) if err != nil { return err } r.bodyFileData = append(r.bodyFileData, RequestFile{ FormName: formName, FileName: filename, - FileData: f, + FileData: nil, FileSize: stat.Size(), FileType: "application/octet-stream", + FilePath: filepath, }) return nil } func (r *Request) AddFileWithType(formName, filepath, filetype string) error { - f, err := os.Open(filepath) - if err != nil { - return err - } - stat, err := f.Stat() + stat, err := os.Stat(filepath) if err != nil { return err } r.bodyFileData = append(r.bodyFileData, RequestFile{ FormName: formName, FileName: stat.Name(), - FileData: f, + FileData: nil, FileSize: stat.Size(), FileType: filetype, + FilePath: filepath, }) return nil } func (r *Request) AddFileWithNameAndType(formName, filepath, filename, filetype string) error { - f, err := os.Open(filepath) - if err != nil { - return err - } - stat, err := f.Stat() + stat, err := os.Stat(filepath) if err != nil { return err } r.bodyFileData = append(r.bodyFileData, RequestFile{ FormName: formName, FileName: filename, - FileData: f, + FileData: nil, FileSize: stat.Size(), FileType: filetype, + FilePath: filepath, + }) + return nil +} + +func (r *Request) AddFileStreamWithType(formName, filename, filetype string, size int64, stream io.Reader) error { + r.bodyFileData = append(r.bodyFileData, RequestFile{ + FormName: formName, + FileName: filename, + FileData: stream, + FileSize: size, + FileType: filetype, }) return nil } @@ -855,6 +873,16 @@ func (r *Request) AddFileWithNameAndTypeNoError(formName, filepath, filename, fi return r } +func (r *Request) AddFileStreamNoError(formName, filename string, size int64, stream io.Reader) *Request { + r.AddFileStream(formName, filename, size, stream) + return r +} + +func (r *Request) AddFileStreamWithTypeNoError(formName, filename, filetype string, size int64, stream io.Reader) *Request { + r.AddFileStreamWithType(formName, filename, filetype, size, stream) + return r +} + func (r *Request) HttpClient() (*http.Client, error) { err := applyOptions(r) if err != nil { @@ -869,6 +897,7 @@ type RequestFile struct { FileData io.Reader FileSize int64 FileType string + FilePath string } type RequestOpt func(opt *RequestOpts) error @@ -1555,15 +1584,32 @@ func applyDataReader(r *Request) error { } } } - for _, v := range r.bodyFileData { + for idx, v := range r.bodyFileData { var fw, err = w.CreateFormFile(v.FormName, v.FileName) if err != nil { r.errInfo = err pw.CloseWithError(err) // close pipe with error return } + if v.FileData == nil { + if v.FilePath != "" { + tmpFile, err := os.Open(v.FilePath) + if err != nil { + r.errInfo = err + pw.CloseWithError(err) // close pipe with error + return + } + defer tmpFile.Close() + v.FileData = tmpFile + } else { + r.errInfo = fmt.Errorf("io reader is nil") + pw.CloseWithError(fmt.Errorf("io reader is nil")) // close pipe with error + return + } + } if _, err := copyWithContext(r.ctx, r.FileUploadRecallFn, v.FileName, v.FileSize, fw, v.FileData); err != nil { r.errInfo = err + r.bodyFileData[idx] = v pw.CloseWithError(err) // close pipe with error return } @@ -1692,7 +1738,9 @@ func copyWithContext(ctx context.Context, recall func(string, int64, int64), fil nr, err := pr.Read(buf) if err != nil { if err == io.EOF { - go recall(filename, count, total) + if recall != nil { + go recall(filename, count, total) + } return written, nil } return written, err diff --git a/curl_test.go b/curl_test.go index 29c21b4..feec396 100644 --- a/curl_test.go +++ b/curl_test.go @@ -462,3 +462,93 @@ func TestCurl(t *testing.T) { }) } } + +func TestReqClone(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Header.Get("hello") != "world" { + rw.WriteHeader(http.StatusBadRequest) + rw.Write([]byte("hello world failed")) + return + } + rw.Write([]byte(`OK`)) + })) + defer server.Close() + + req := NewSimpleRequestWithClient(http.DefaultClient, server.URL, "GET", WithHeader("hello", "world")) + resp, err := req.Do() + if err != nil { + t.Error(err) + } + if resp.StatusCode != 200 { + resp.CloseAll() + t.Errorf("status code is %d", resp.StatusCode) + } + resp.CloseAll() + req = req.Clone() + req.AddHeader("ok", "good") + resp, err = req.Do() + if err != nil { + t.Error(err) + } + if resp.StatusCode != 200 { + resp.CloseAll() + t.Errorf("status code is %d", resp.StatusCode) + } + resp.CloseAll() +} + +func TestUploadFile(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Header.Get("hello") != "world" { + rw.WriteHeader(http.StatusBadRequest) + rw.Write([]byte("hello world failed")) + return + } + files, header, err := req.FormFile("666") + if err == nil { + fmt.Println(header.Filename) + fmt.Println(header.Size) + fmt.Println(files.Close()) + } + files, header, err = req.FormFile("777") + if err == nil { + fmt.Println(header.Filename) + fmt.Println(header.Size) + fmt.Println(files.Close()) + } + files, header, err = req.FormFile("888") + if err == nil { + fmt.Println(header.Filename) + fmt.Println(header.Size) + fmt.Println(files.Close()) + } + rw.Write([]byte(`OK`)) + })) + defer server.Close() + + req := NewSimpleRequestWithClient(http.DefaultClient, server.URL, "GET", WithHeader("hello", "world")) + req.AddFileWithName("666", "./curl.go", "curl.go") + req.AddFile("777", "./go.mod") + req.AddFileWithNameAndType("888", "./ping.go", "ping.go", "html") + resp, err := req.Do() + if err != nil { + t.Error(err) + } + if resp.StatusCode != 200 { + resp.CloseAll() + t.Errorf("status code is %d", resp.StatusCode) + } + resp.CloseAll() + req = req.Clone() + req.AddHeader("ok", "good") + + resp, err = req.Do() + if err != nil { + t.Error(err) + } + if resp.StatusCode != 200 { + resp.CloseAll() + t.Errorf("status code is %d", resp.StatusCode) + } + resp.CloseAll() +}