diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..723ef36 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.idea \ No newline at end of file diff --git a/curl.go b/curl.go index 4015a77..95f820c 100644 --- a/curl.go +++ b/curl.go @@ -14,6 +14,7 @@ import ( "os" "strconv" "strings" + "sync" "time" ) @@ -108,6 +109,185 @@ type Request struct { RequestOpts } +func (r *Request) Clone() *Request { + clonedRequest := &Request{ + ctx: r.ctx, + uri: r.uri, + method: r.method, + errInfo: r.errInfo, + RequestOpts: RequestOpts{ + headers: CloneHeader(r.headers), + cookies: CloneCookies(r.cookies), + bodyFormData: CloneStringMapSlice(r.bodyFormData), + bodyFileData: CloneFiles(r.bodyFileData), + queries: CloneStringMapSlice(r.queries), + bodyDataBytes: CloneByteSlice(r.bodyDataBytes), + proxy: r.proxy, + timeout: r.timeout, + dialTimeout: r.dialTimeout, + alreadyApply: r.alreadyApply, + disableRedirect: r.disableRedirect, + doRawRequest: r.doRawRequest, + doRawClient: r.doRawClient, + doRawTransport: r.doRawTransport, + skipTLSVerify: r.skipTLSVerify, + autoFetchRespBody: r.autoFetchRespBody, + customIP: CloneStringSlice(r.customIP), + alreadySetLookUpIPfn: r.alreadySetLookUpIPfn, + lookUpIPfn: r.lookUpIPfn, + customDNS: CloneStringSlice(r.customDNS), + basicAuth: r.basicAuth, + autoCalcContentLength: r.autoCalcContentLength, + }, + } + + // 手动深拷贝嵌套引用类型 + if r.bodyDataReader != nil { + clonedRequest.bodyDataReader = r.bodyDataReader + } + + if r.FileUploadRecallFn != nil { + clonedRequest.FileUploadRecallFn = r.FileUploadRecallFn + } + + // 对于 tlsConfig 类型,需要手动复制 + if r.tlsConfig != nil { + clonedRequest.tlsConfig = CloneTLSConfig(r.tlsConfig) + } + + // 对于 http.Transport,需要进行手动复制 + if r.transport != nil { + clonedRequest.transport = CloneTransport(r.transport) + } + + return clonedRequest +} + +// CloneHeader 复制 http.Header +func CloneHeader(original http.Header) http.Header { + newHeader := make(http.Header) + for key, values := range original { + copiedValues := make([]string, len(values)) + copy(copiedValues, values) + newHeader[key] = copiedValues + } + return newHeader +} + +// CloneCookies 复制 []*http.Cookie +func CloneCookies(original []*http.Cookie) []*http.Cookie { + cloned := make([]*http.Cookie, len(original)) + for i, cookie := range original { + cloned[i] = &http.Cookie{ + Name: cookie.Name, + Value: cookie.Value, + Path: cookie.Path, + Domain: cookie.Domain, + Expires: cookie.Expires, + RawExpires: cookie.RawExpires, + MaxAge: cookie.MaxAge, + Secure: cookie.Secure, + HttpOnly: cookie.HttpOnly, + SameSite: cookie.SameSite, + Raw: cookie.Raw, + Unparsed: append([]string(nil), cookie.Unparsed...), + } + } + return cloned +} + +// CloneStringMapSlice 复制 map[string][]string +func CloneStringMapSlice(original map[string][]string) map[string][]string { + newMap := make(map[string][]string) + for key, values := range original { + copiedValues := make([]string, len(values)) + copy(copiedValues, values) + newMap[key] = copiedValues + } + return newMap +} + +// CloneFiles 复制 []RequestFile +func CloneFiles(original []RequestFile) []RequestFile { + newFiles := make([]RequestFile, len(original)) + copy(newFiles, original) + return newFiles +} + +// CloneByteSlice 复制 []byte +func CloneByteSlice(original []byte) []byte { + if original == nil { + return nil + } + newSlice := make([]byte, len(original)) + copy(newSlice, original) + return newSlice +} + +// CloneStringSlice 复制 []string +func CloneStringSlice(original []string) []string { + newSlice := make([]string, len(original)) + copy(newSlice, original) + return newSlice +} + +// CloneTLSConfig 复制 tls.Config +func CloneTLSConfig(original *tls.Config) *tls.Config { + newConfig := &tls.Config{ + Rand: original.Rand, + Time: original.Time, + Certificates: append([]tls.Certificate(nil), original.Certificates...), + NameToCertificate: original.NameToCertificate, + GetCertificate: original.GetCertificate, + GetClientCertificate: original.GetClientCertificate, + GetConfigForClient: original.GetConfigForClient, + VerifyPeerCertificate: original.VerifyPeerCertificate, + VerifyConnection: original.VerifyConnection, + RootCAs: original.RootCAs, + NextProtos: append([]string(nil), original.NextProtos...), + ServerName: original.ServerName, + ClientAuth: original.ClientAuth, + ClientCAs: original.ClientCAs, + InsecureSkipVerify: original.InsecureSkipVerify, + CipherSuites: append([]uint16(nil), original.CipherSuites...), + PreferServerCipherSuites: original.PreferServerCipherSuites, + SessionTicketsDisabled: original.SessionTicketsDisabled, + SessionTicketKey: original.SessionTicketKey, + ClientSessionCache: original.ClientSessionCache, + MinVersion: original.MinVersion, + MaxVersion: original.MaxVersion, + CurvePreferences: append([]tls.CurveID(nil), original.CurvePreferences...), + DynamicRecordSizingDisabled: original.DynamicRecordSizingDisabled, + Renegotiation: original.Renegotiation, + KeyLogWriter: original.KeyLogWriter, + } + return newConfig +} + +// CloneTransport 复制 http.Transport +func CloneTransport(original *http.Transport) *http.Transport { + newTransport := &http.Transport{ + Proxy: original.Proxy, + DialContext: original.DialContext, + Dial: original.Dial, + DialTLS: original.DialTLS, + TLSClientConfig: original.TLSClientConfig, + TLSHandshakeTimeout: original.TLSHandshakeTimeout, + DisableKeepAlives: original.DisableKeepAlives, + DisableCompression: original.DisableCompression, + MaxIdleConns: original.MaxIdleConns, + MaxIdleConnsPerHost: original.MaxIdleConnsPerHost, + IdleConnTimeout: original.IdleConnTimeout, + ResponseHeaderTimeout: original.ResponseHeaderTimeout, + ExpectContinueTimeout: original.ExpectContinueTimeout, + TLSNextProto: original.TLSNextProto, + ProxyConnectHeader: original.ProxyConnectHeader, + MaxResponseHeaderBytes: original.MaxResponseHeaderBytes, + WriteBufferSize: original.WriteBufferSize, + ReadBufferSize: original.ReadBufferSize, + } + return newTransport +} func (r *Request) Method() string { return r.method } @@ -1071,10 +1251,17 @@ type Body struct { full []byte raw io.ReadCloser isFull bool + sync.Mutex } func (b *Body) readAll() { + b.Lock() + defer b.Unlock() if !b.isFull { + if b.raw == nil { + b.isFull = true + return + } b.full, _ = io.ReadAll(b.raw) b.isFull = true b.raw.Close() @@ -1099,6 +1286,8 @@ func (b *Body) Unmarshal(u interface{}) error { // Reader returns a reader for the body // if this function is called, other functions like String, Bytes, Unmarshal may not work func (b *Body) Reader() io.ReadCloser { + b.Lock() + defer b.Unlock() if b.isFull { return io.NopCloser(bytes.NewReader(b.full)) }