From 1de78f2f06b8b60b2c4bafbddfd556de8f3b5cd6 Mon Sep 17 00:00:00 2001 From: starainrt Date: Thu, 8 Aug 2024 22:03:10 +0800 Subject: [PATCH] rewrite curl.go --- curl.go | 1597 ++++++++++++++++++++++++++++++++++++++++---------- curl_test.go | 464 +++++++++++++++ go.mod | 2 - go.sum | 47 -- httpguts.go | 120 ++++ 5 files changed, 1876 insertions(+), 354 deletions(-) create mode 100644 curl_test.go create mode 100644 httpguts.go diff --git a/curl.go b/curl.go index 96d7502..0f6a4ea 100644 --- a/curl.go +++ b/curl.go @@ -3,19 +3,18 @@ package starnet import ( "bytes" "context" - "crypto/rand" "crypto/tls" - "errors" + "encoding/json" "fmt" "io" + "mime/multipart" "net" "net/http" "net/url" "os" + "strconv" "strings" "time" - - "b612.me/stario" ) const ( @@ -25,439 +24,1427 @@ const ( HEADER_PLAIN = `text/plain` ) -type RequestFile struct { - UploadFile string - UploadForm map[string]string - UploadName string +var ( + DefaultDialTimeout = 5 * time.Second + DefaultTimeout = 10 * time.Second + DefaultFetchRespBody = false +) + +func UrlEncodeRaw(str string) string { + strs := strings.Replace(url.QueryEscape(str), "+", "%20", -1) + return strs +} + +func UrlEncode(str string) string { + return url.QueryEscape(str) +} + +func UrlDecode(str string) (string, error) { + return url.QueryUnescape(str) +} + +func BuildQuery(queryData map[string]string) string { + query := url.Values{} + for k, v := range queryData { + query.Add(k, v) + } + return query.Encode() +} + +// BuildPostForm takes a map of string keys and values, converts it into a URL-encoded query string, +// and then converts that string into a byte slice. This function is useful for preparing data for HTTP POST requests, +// where the server expects the request body to be URL-encoded form data. +// +// Parameters: +// queryMap: A map where the key-value pairs represent the form data to be sent in the HTTP POST request. +// +// Returns: +// A byte slice representing the URL-encoded form data. +func BuildPostForm(queryMap map[string]string) []byte { + return []byte(BuildQuery(queryMap)) +} + +func Get(uri string, opts ...RequestOpt) (*Response, error) { + return NewSimpleRequest(uri, "GET", opts...).Do() +} + +func Post(uri string, opts ...RequestOpt) (*Response, error) { + return NewSimpleRequest(uri, "POST", opts...).Do() +} + +func Options(uri string, opts ...RequestOpt) (*Response, error) { + return NewSimpleRequest(uri, "OPTIONS", opts...).Do() +} + +func Put(uri string, opts ...RequestOpt) (*Response, error) { + return NewSimpleRequest(uri, "PUT", opts...).Do() +} + +func Delete(uri string, opts ...RequestOpt) (*Response, error) { + return NewSimpleRequest(uri, "DELETE", opts...).Do() +} + +func Head(uri string, opts ...RequestOpt) (*Response, error) { + return NewSimpleRequest(uri, "HEAD", opts...).Do() +} + +func Patch(uri string, opts ...RequestOpt) (*Response, error) { + return NewSimpleRequest(uri, "PATCH", opts...).Do() +} + +func Trace(uri string, opts ...RequestOpt) (*Response, error) { + return NewSimpleRequest(uri, "TRACE", opts...).Do() +} + +func Connect(uri string, opts ...RequestOpt) (*Response, error) { + return NewSimpleRequest(uri, "CONNECT", opts...).Do() } type Request struct { - Url string - RespURL string - Method string - RecvData []byte - RecvContentLength int64 - RecvIo io.Writer - RespHeader http.Header - RespCookies []*http.Cookie - RespHttpCode int - Location *url.URL - CircleBuffer *stario.StarBuffer - respReader io.ReadCloser - respOrigin *http.Response - reqOrigin *http.Request + ctx context.Context + uri string + method string + errInfo error RequestOpts } +func (r *Request) Method() string { + return r.method +} + +func (r *Request) SetMethod(method string) error { + method = strings.ToUpper(method) + if !validMethod(method) { + return fmt.Errorf("invalid method: %s", method) + } + r.method = method + r.rawRequest.Method = method + return nil +} + +func (r *Request) SetMethodNoError(method string) *Request { + r.SetMethod(method) + return r +} + +func (r *Request) Uri() string { + return r.uri +} + +func (r *Request) SetUri(uri string) error { + u, err := url.Parse(uri) + if err != nil { + return fmt.Errorf("parse uri error: %s", err) + } + r.uri = uri + u.Host = removeEmptyPort(u.Host) + r.rawRequest.Host = u.Host + r.rawRequest.URL = u + return nil +} + +func (r *Request) SetUriNoError(uri string) *Request { + r.SetUri(uri) + return r +} + +func (r *Request) RawRequest() *http.Request { + return r.rawRequest +} + +func (r *Request) SetRawRequest(rawRequest *http.Request) *Request { + r.rawRequest = rawRequest + return r +} + +func (r *Request) RawClient() *http.Client { + return r.rawClient +} + +func (r *Request) SetRawClient(rawClient *http.Client) *Request { + r.rawClient = rawClient + return r +} + +func (r *Request) Do() (*Response, error) { + return Curl(r) +} + +func (r *Request) Get() (*Response, error) { + err := r.SetMethod("GET") + if err != nil { + return nil, err + } + return Curl(r) +} + +func (r *Request) Post(data []byte) (*Response, error) { + err := r.SetMethod("POST") + if err != nil { + return nil, err + } + r.bodyDataBytes = data + r.bodyDataReader = nil + return Curl(r) +} + type RequestOpts struct { - RequestFile - PostBuffer io.Reader - Process func(float64) - Proxy string - Timeout time.Duration - DialTimeout time.Duration - ReqHeader http.Header - ReqCookies []*http.Cookie - WriteRecvData bool - SkipTLSVerify bool - CustomTransport *http.Transport - Queries map[string]string - DisableRedirect bool - TlsConfig *tls.Config -} - -type RequestOpt func(opt *RequestOpts) + rawRequest *http.Request + rawClient *http.Client + + alreadyApply bool + bodyDataBytes []byte + bodyDataReader io.Reader + bodyFormData map[string][]string + bodyFileData []RequestFile + //以上优先度为 bodyDataReader> bodyDataBytes > bodyFormData > bodyFileData + FileUploadRecallFn func(filename string, upPos int64, total int64) + proxy string + timeout time.Duration + dialTimeout time.Duration + headers http.Header + cookies []*http.Cookie + transport *http.Transport + queries map[string][]string + disableRedirect bool + //doRawRequest=true 不对request修改,直接发送 + doRawRequest bool + //doRawClient=true 不对http client修改,直接发送 + doRawClient bool + //doRawTransPort=true 不对http transport修改,直接发送 + doRawTransport bool + skipTLSVerify bool + tlsConfig *tls.Config + autoFetchRespBody bool + customIP []string + alreadySetLookUpIPfn bool + lookUpIPfn func(ctx context.Context, host string) ([]net.IPAddr, error) + customDNS []string + basicAuth [2]string + autoCalcContentLength bool +} + +func (r *Request) AutoCalcContentLength() bool { + return r.autoCalcContentLength +} + +// SetAutoCalcContentLength sets whether to automatically calculate the Content-Length header based on the request body. +// WARN: If set to true, the Content-Length header will be set to the length of the request body, which may cause issues with chunked transfer encoding. +// also the memory usage will be higher +// Note that this function will not work if doRawRequest is true +func (r *Request) SetAutoCalcContentLength(autoCalcContentLength bool) *Request { + r.autoCalcContentLength = autoCalcContentLength + return r +} + +// BasicAuth returns the username and password provided in the request's Authorization header. +func (r *RequestOpts) BasicAuth() (string, string) { + return r.basicAuth[0], r.basicAuth[1] +} + +// SetBasicAuth sets the request's Authorization header to use HTTP Basic Authentication with the provided username and password. +// Note: If doRawRequest is true, this function will nolonger work +func (r *RequestOpts) SetBasicAuth(username, password string) *RequestOpts { + r.basicAuth = [2]string{username, password} + return r +} + +func (r *Request) DoRawTransport() bool { + return r.doRawTransport +} + +func (r *Request) SetDoRawTransport(doRawTransport bool) *Request { + r.doRawTransport = doRawTransport + return r +} + +func (r *Request) CustomDNS() []string { + return r.customDNS +} + +// Note: if LookUpIPfn is set, this function will not be used +func (r *Request) SetCustomDNS(customDNS []string) error { + for _, v := range customDNS { + if net.ParseIP(v) == nil { + return fmt.Errorf("invalid custom dns: %s", v) + } + } + r.customDNS = customDNS + return nil +} + +// Note: if LookUpIPfn is set, this function will not be used +func (r *Request) SetCustomDNSNoError(customDNS []string) *Request { + r.SetCustomDNS(customDNS) + return r +} + +// Note: if LookUpIPfn is set, this function will not be used +func (r *Request) AddCustomDNS(customDNS []string) error { + for _, v := range customDNS { + if net.ParseIP(v) == nil { + return fmt.Errorf("invalid custom dns: %s", v) + } + } + r.customDNS = customDNS + return nil +} + +// Note: if LookUpIPfn is set, this function will not be used +func (r *Request) AddCustomDNSNoError(customDNS []string) *Request { + r.AddCustomDNS(customDNS) + return r +} + +func (r *Request) LookUpIPfn() func(ctx context.Context, host string) ([]net.IPAddr, error) { + return r.lookUpIPfn +} + +func (r *Request) SetLookUpIPfn(lookUpIPfn func(ctx context.Context, host string) ([]net.IPAddr, error)) *Request { + if lookUpIPfn == nil { + r.alreadySetLookUpIPfn = false + r.lookUpIPfn = net.DefaultResolver.LookupIPAddr + return r + } + r.lookUpIPfn = lookUpIPfn + r.alreadySetLookUpIPfn = true + return r +} + +func (r *Request) CustomHostIP() []string { + return r.customIP +} + +func (r *Request) SetCustomHostIP(customIP []string) *Request { + r.customIP = customIP + return r +} + +func (r *Request) AddCustomHostIP(customIP string) *Request { + r.customIP = append(r.customIP, customIP) + return r +} + +func (r *Request) BodyDataBytes() []byte { + return r.bodyDataBytes +} + +func (r *Request) SetBodyDataBytes(bodyDataBytes []byte) *Request { + r.bodyDataBytes = bodyDataBytes + return r +} + +func (r *Request) BodyDataReader() io.Reader { + return r.bodyDataReader +} + +func (r *Request) SetBodyDataReader(bodyDataReader io.Reader) *Request { + r.bodyDataReader = bodyDataReader + return r +} + +func (r *Request) BodyFormData() map[string][]string { + return r.bodyFormData +} + +func (r *Request) SetBodyFormData(bodyFormData map[string][]string) *Request { + r.bodyFormData = bodyFormData + return r +} + +func (r *Request) SetFormData(bodyFormData map[string][]string) *Request { + return r.SetBodyFormData(bodyFormData) +} + +func (r *Request) AddFormMapData(bodyFormData map[string]string) *Request { + for k, v := range bodyFormData { + r.bodyFormData[k] = append(r.bodyFormData[k], v) + } + return r +} + +func (r *Request) AddFormData(k, v string) *Request { + r.bodyFormData[k] = append(r.bodyFormData[k], v) + return r +} + +func (r *Request) BodyFileData() []RequestFile { + return r.bodyFileData +} + +func (r *Request) SetBodyFileData(bodyFileData []RequestFile) *Request { + r.bodyFileData = bodyFileData + return r +} + +func (r *Request) Proxy() string { + return r.proxy +} + +func (r *Request) SetProxy(proxy string) *Request { + r.proxy = proxy + return r +} + +func (r *Request) Timeout() time.Duration { + return r.timeout +} + +func (r *Request) SetTimeout(timeout time.Duration) *Request { + r.timeout = timeout + return r +} + +func (r *Request) DialTimeout() time.Duration { + return r.dialTimeout +} + +func (r *Request) SetDialTimeout(dialTimeout time.Duration) *Request { + r.dialTimeout = dialTimeout + return r +} + +func (r *Request) Headers() http.Header { + return r.headers +} + +func (r *Request) SetHeaders(headers http.Header) *Request { + r.headers = headers + return r +} + +func (r *Request) AddHeader(key, val string) *Request { + r.headers.Add(key, val) + return r +} + +func (r *Request) SetHeader(key, val string) *Request { + r.headers.Set(key, val) + return r +} + +func (r *Request) SetContentType(ct string) *Request { + r.headers.Set("Content-Type", ct) + return r +} + +func (r *Request) SetUserAgent(ua string) *Request { + r.headers.Set("User-Agent", ua) + return r +} + +func (r *Request) DeleteHeader(key string) *Request { + r.headers.Del(key) + return r +} + +func (r *Request) Cookies() []*http.Cookie { + return r.cookies +} + +func (r *Request) SetCookies(cookies []*http.Cookie) *Request { + r.cookies = cookies + return r +} + +func (r *Request) Transport() *http.Transport { + return r.transport +} + +func (r *Request) SetTransport(transport *http.Transport) *Request { + r.transport = transport + return r +} + +func (r *Request) Queries() map[string][]string { + return r.queries +} + +func (r *Request) SetQueries(queries map[string][]string) *Request { + r.queries = queries + return r +} + +func (r *Request) AddQueries(queries map[string]string) *Request { + for k, v := range queries { + r.queries[k] = append(r.queries[k], v) + } + return r +} + +func (r *Request) AddQuery(key, value string) *Request { + r.queries[key] = append(r.queries[key], value) + return r +} + +func (r *Request) DelQueryKv(key, value string) *Request { + if _, ok := r.queries[key]; !ok { + return r + } + for i, v := range r.queries[key] { + if v == value { + r.queries[key] = append(r.queries[key][:i], r.queries[key][i+1:]...) + } + } + return r +} + +func (r *Request) DelQuery(key string) *Request { + if _, ok := r.queries[key]; !ok { + return r + } + delete(r.queries, key) + return r +} + +func (r *Request) DisableRedirect() bool { + return r.disableRedirect +} + +func (r *Request) SetDisableRedirect(disableRedirect bool) *Request { + r.disableRedirect = disableRedirect + return r +} + +func (r *Request) DoRawRequest() bool { + return r.doRawRequest +} + +func (r *Request) SetDoRawRequest(doRawRequest bool) *Request { + r.doRawRequest = doRawRequest + return r +} + +func (r *Request) DoRawClient() bool { + return r.doRawClient +} + +func (r *Request) SetDoRawClient(doRawClient bool) *Request { + r.doRawClient = doRawClient + return r +} + +func (r *RequestOpts) SkipTLSVerify() bool { + return r.skipTLSVerify +} + +func (r *Request) SetSkipTLSVerify(skipTLSVerify bool) *Request { + r.skipTLSVerify = skipTLSVerify + return r +} + +func (r *Request) TlsConfig() *tls.Config { + return r.tlsConfig +} + +func (r *Request) SetTlsConfig(tlsConfig *tls.Config) *Request { + r.tlsConfig = tlsConfig + return r +} + +func (r *Request) AutoFetchRespBody() bool { + return r.autoFetchRespBody +} + +func (r *Request) SetAutoFetchRespBody(autoFetchRespBody bool) *Request { + r.autoFetchRespBody = autoFetchRespBody + return r +} + +func (r *Request) ResetReqHeader() *Request { + r.headers = make(http.Header) + return r +} + +func (r *Request) ResetReqCookies() *Request { + r.cookies = []*http.Cookie{} + return r +} + +func (r *Request) AddSimpleCookie(key, value string) *Request { + r.cookies = append(r.cookies, &http.Cookie{Name: key, Value: value, Path: "/"}) + return r +} + +func (r *Request) AddCookie(key, value, path string) *Request { + r.cookies = append(r.cookies, &http.Cookie{Name: key, Value: value, Path: path}) + return r +} + +func (r *Request) AddFile(formName, filepath string) error { + f, err := os.Open(filepath) + if err != nil { + return err + } + stat, err := f.Stat() + if err != nil { + return err + } + r.bodyFileData = append(r.bodyFileData, RequestFile{ + FormName: formName, + FileName: stat.Name(), + FileData: f, + FileSize: stat.Size(), + FileType: "application/octet-stream", + }) + return nil +} + +func (r *Request) AddFileWithName(formName, filepath, filename string) error { + f, err := os.Open(filepath) + if err != nil { + return err + } + stat, err := f.Stat() + if err != nil { + return err + } + r.bodyFileData = append(r.bodyFileData, RequestFile{ + FormName: formName, + FileName: filename, + FileData: f, + FileSize: stat.Size(), + FileType: "application/octet-stream", + }) + return nil +} + +func (r *Request) AddFileWithType(formName, filepath, filetype string) error { + f, err := os.Open(filepath) + if err != nil { + return err + } + stat, err := f.Stat() + if err != nil { + return err + } + r.bodyFileData = append(r.bodyFileData, RequestFile{ + FormName: formName, + FileName: stat.Name(), + FileData: f, + FileSize: stat.Size(), + FileType: filetype, + }) + return nil +} +func (r *Request) AddFileWithNameAndType(formName, filepath, filename, filetype string) error { + f, err := os.Open(filepath) + if err != nil { + return err + } + stat, err := f.Stat() + if err != nil { + return err + } + r.bodyFileData = append(r.bodyFileData, RequestFile{ + FormName: formName, + FileName: filename, + FileData: f, + FileSize: stat.Size(), + FileType: filetype, + }) + return nil +} + +func (r *Request) AddFileNoError(formName, filepath string) *Request { + r.AddFile(formName, filepath) + return r +} + +func (r *Request) AddFileWithNameNoError(formName, filepath, filename string) *Request { + r.AddFileWithName(formName, filepath, filename) + return r +} + +func (r *Request) AddFileWithTypeNoError(formName, filepath, filetype string) *Request { + r.AddFileWithType(formName, filepath, filetype) + return r +} +func (r *Request) AddFileWithNameAndTypeNoError(formName, filepath, filename, filetype string) *Request { + r.AddFileWithNameAndType(formName, filepath, filename, filetype) + return r +} + +type RequestFile struct { + FormName string + FileName string + FileData io.Reader + FileSize int64 + FileType string +} + +type RequestOpt func(opt *RequestOpts) error +// if doRawTransport is true, this function will nolonger work func WithDialTimeout(timeout time.Duration) RequestOpt { - return func(opt *RequestOpts) { - opt.DialTimeout = timeout + return func(opt *RequestOpts) error { + opt.dialTimeout = timeout + return nil } } +// if doRawTransport is true, this function will nolonger work func WithTimeout(timeout time.Duration) RequestOpt { - return func(opt *RequestOpts) { - opt.Timeout = timeout + return func(opt *RequestOpts) error { + opt.timeout = timeout + return nil } } -func WithHeader(key, val string) RequestOpt { - return func(opt *RequestOpts) { - opt.ReqHeader.Set(key, val) +// if doRawTransport is true, this function will nolonger work +func WithTlsConfig(tlscfg *tls.Config) RequestOpt { + return func(opt *RequestOpts) error { + opt.tlsConfig = tlscfg + return nil } } -func WithTlsConfig(tlscfg *tls.Config) RequestOpt { - return func(opt *RequestOpts) { - opt.TlsConfig = tlscfg +// if doRawRequest is true, this function will nolonger work +func WithHeader(key, val string) RequestOpt { + return func(opt *RequestOpts) error { + opt.headers.Set(key, val) + return nil } } +// if doRawRequest is true, this function will nolonger work func WithHeaderMap(header map[string]string) RequestOpt { - return func(opt *RequestOpts) { + return func(opt *RequestOpts) error { for key, val := range header { - opt.ReqHeader.Set(key, val) + opt.headers.Set(key, val) } + return nil } } -func WithHeaderAdd(key, val string) RequestOpt { - return func(opt *RequestOpts) { - opt.ReqHeader.Add(key, val) +// if doRawRequest is true, this function will nolonger work +func WithReader(r io.Reader) RequestOpt { + return func(opt *RequestOpts) error { + opt.bodyDataReader = r + return nil } } -func WithReader(r io.Reader) RequestOpt { - return func(opt *RequestOpts) { - opt.PostBuffer = r +// if doRawRequest is true, this function will nolonger work +func WithBytes(r []byte) RequestOpt { + return func(opt *RequestOpts) error { + opt.bodyDataBytes = r + return nil + } +} + +// if doRawRequest is true, this function will nolonger work +func WithFormData(data map[string][]string) RequestOpt { + return func(opt *RequestOpts) error { + opt.bodyFormData = data + return nil + } +} + +// if doRawRequest is true, this function will nolonger work +func WithFileDatas(data []RequestFile) RequestOpt { + return func(opt *RequestOpts) error { + opt.bodyFileData = data + return nil + } +} + +// if doRawRequest is true, this function will nolonger work +func WithFileData(data RequestFile) RequestOpt { + return func(opt *RequestOpts) error { + opt.bodyFileData = append(opt.bodyFileData, data) + return nil + } +} + +// if doRawRequest is true, this function will nolonger work +func WithAddFile(formName, filepath string) RequestOpt { + return func(opt *RequestOpts) error { + f, err := os.Open(filepath) + if err != nil { + return err + } + stat, err := f.Stat() + if err != nil { + return err + } + opt.bodyFileData = append(opt.bodyFileData, RequestFile{ + FormName: formName, + FileName: stat.Name(), + FileData: f, + FileSize: stat.Size(), + FileType: "application/octet-stream", + }) + return nil + } +} + +func WithAddFileWithName(formName, filepath, filename string) RequestOpt { + return func(opt *RequestOpts) error { + f, err := os.Open(filepath) + if err != nil { + return err + } + stat, err := f.Stat() + if err != nil { + return err + } + opt.bodyFileData = append(opt.bodyFileData, RequestFile{ + FormName: formName, + FileName: filename, + FileData: f, + FileSize: stat.Size(), + FileType: "application/octet-stream", + }) + return nil + } +} + +func WithAddFileWithType(formName, filepath, filetype string) RequestOpt { + return func(opt *RequestOpts) error { + f, err := os.Open(filepath) + if err != nil { + return nil + } + stat, err := f.Stat() + if err != nil { + return nil + } + opt.bodyFileData = append(opt.bodyFileData, RequestFile{ + FormName: formName, + FileName: stat.Name(), + FileData: f, + FileSize: stat.Size(), + FileType: filetype, + }) + return nil + } +} + +func WithAddFileWithNameAndType(formName, filepath, filename, filetype string) RequestOpt { + return func(opt *RequestOpts) error { + f, err := os.Open(filepath) + if err != nil { + return err + } + stat, err := f.Stat() + if err != nil { + return err + } + opt.bodyFileData = append(opt.bodyFileData, RequestFile{ + FormName: formName, + FileName: filename, + FileData: f, + FileSize: stat.Size(), + FileType: filetype, + }) + return nil } } func WithFetchRespBody(fetch bool) RequestOpt { - return func(opt *RequestOpts) { - opt.WriteRecvData = fetch + return func(opt *RequestOpts) error { + opt.autoFetchRespBody = fetch + return nil } } func WithCookies(ck []*http.Cookie) RequestOpt { - return func(opt *RequestOpts) { - opt.ReqCookies = ck + return func(opt *RequestOpts) error { + opt.cookies = ck + return nil } } func WithCookie(key, val, path string) RequestOpt { - return func(opt *RequestOpts) { - opt.ReqCookies = append(opt.ReqCookies, &http.Cookie{Name: key, Value: val, Path: path}) + return func(opt *RequestOpts) error { + opt.cookies = append(opt.cookies, &http.Cookie{Name: key, Value: val, Path: path}) + return nil + } +} + +func WithSimpleCookie(key, val string) RequestOpt { + return func(opt *RequestOpts) error { + opt.cookies = append(opt.cookies, &http.Cookie{Name: key, Value: val, Path: "/"}) + return nil } } func WithCookieMap(header map[string]string, path string) RequestOpt { - return func(opt *RequestOpts) { + return func(opt *RequestOpts) error { for key, val := range header { - opt.ReqCookies = append(opt.ReqCookies, &http.Cookie{Name: key, Value: val, Path: path}) + opt.cookies = append(opt.cookies, &http.Cookie{Name: key, Value: val, Path: path}) } + return nil } } -func WithQueries(queries map[string]string) RequestOpt { - return func(opt *RequestOpts) { - opt.Queries = queries +func WithQueries(queries map[string][]string) RequestOpt { + return func(opt *RequestOpts) error { + opt.queries = queries + return nil + } +} + +func WithAddQueries(queries map[string]string) RequestOpt { + return func(opt *RequestOpts) error { + for k, v := range queries { + opt.queries[k] = append(opt.queries[k], v) + } + return nil + } +} + +func WithAddQuery(key, val string) RequestOpt { + return func(opt *RequestOpts) error { + opt.queries[key] = append(opt.queries[key], val) + return nil } } func WithProxy(proxy string) RequestOpt { - return func(opt *RequestOpts) { - opt.Proxy = proxy + return func(opt *RequestOpts) error { + opt.proxy = proxy + return nil } } -func WithProcess(fn func(float64)) RequestOpt { - return func(opt *RequestOpts) { - opt.Process = fn +func WithProcess(fn func(string, int64, int64)) RequestOpt { + return func(opt *RequestOpts) error { + opt.FileUploadRecallFn = fn + return nil } } func WithContentType(ct string) RequestOpt { - return func(opt *RequestOpts) { - opt.ReqHeader.Set("Content-Type", ct) + return func(opt *RequestOpts) error { + opt.headers.Set("Content-Type", ct) + return nil } } func WithUserAgent(ua string) RequestOpt { - return func(opt *RequestOpts) { - opt.ReqHeader.Set("User-Agent", ua) + return func(opt *RequestOpts) error { + opt.headers.Set("User-Agent", ua) + return nil } } -func WithCustomTransport(hs *http.Transport) RequestOpt { - return func(opt *RequestOpts) { - opt.CustomTransport = hs +func WithSkipTLSVerify(skip bool) RequestOpt { + return func(opt *RequestOpts) error { + opt.skipTLSVerify = skip + return nil } } -func WithSkipTLSVerify(skip bool) RequestOpt { - return func(opt *RequestOpts) { - opt.SkipTLSVerify = skip +func WithDisableRedirect(disable bool) RequestOpt { + return func(opt *RequestOpts) error { + opt.disableRedirect = disable + return nil } } -func WithDisableRedirect(disable bool) RequestOpt { - return func(opt *RequestOpts) { - opt.DisableRedirect = disable +func WithDoRawRequest(doRawRequest bool) RequestOpt { + return func(opt *RequestOpts) error { + opt.doRawRequest = doRawRequest + return nil } } -func NewRequests(url string, rawdata []byte, method string, opts ...RequestOpt) Request { - req := Request{ - RequestOpts: RequestOpts{ - Timeout: 30 * time.Second, - DialTimeout: 15 * time.Second, - WriteRecvData: true, - }, - Url: url, - Method: method, +func WithDoRawClient(doRawClient bool) RequestOpt { + return func(opt *RequestOpts) error { + opt.doRawClient = doRawClient + return nil } - if rawdata != nil { - req.PostBuffer = bytes.NewBuffer(rawdata) +} + +func WithDoRawTransport(doRawTrans bool) RequestOpt { + return func(opt *RequestOpts) error { + opt.doRawTransport = doRawTrans + return nil } - req.ReqHeader = make(http.Header) - if strings.ToUpper(method) == "POST" { - req.ReqHeader.Set("Content-Type", HEADER_FORM_URLENCODE) +} + +func WithTransport(hs *http.Transport) RequestOpt { + return func(opt *RequestOpts) error { + opt.transport = hs + return nil } - req.ReqHeader.Set("User-Agent", "B612 / 1.1.0") - for _, v := range opts { - v(&req.RequestOpts) +} + +func WithRawRequest(req *http.Request) RequestOpt { + return func(opt *RequestOpts) error { + opt.rawRequest = req + return nil } - if req.CustomTransport == nil { - req.CustomTransport = &http.Transport{} +} + +func WithRawClient(hc *http.Client) RequestOpt { + return func(opt *RequestOpts) error { + opt.rawClient = hc + return nil } - if req.SkipTLSVerify { - if req.CustomTransport.TLSClientConfig == nil { - req.CustomTransport.TLSClientConfig = &tls.Config{} +} + +func WithCustomHostIP(ip []string) RequestOpt { + return func(opt *RequestOpts) error { + if len(ip) == 0 { + return nil + } + for _, v := range ip { + if net.ParseIP(v) == nil { + return fmt.Errorf("invalid custom ip: %s", v) + } } - req.CustomTransport.TLSClientConfig.InsecureSkipVerify = true + opt.customIP = ip + return nil } - if req.TlsConfig != nil { - req.CustomTransport.TLSClientConfig = req.TlsConfig +} + +func WithAddCustomHostIP(ip string) RequestOpt { + return func(opt *RequestOpts) error { + if net.ParseIP(ip) == nil { + return fmt.Errorf("invalid custom ip: %s", ip) + } + opt.customIP = append(opt.customIP, ip) + return nil } - req.CustomTransport.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { - c, err := net.DialTimeout(netw, addr, req.DialTimeout) - if err != nil { - return nil, err +} + +func WithLookUpFn(lookUpIPfn func(ctx context.Context, host string) ([]net.IPAddr, error)) RequestOpt { + return func(opt *RequestOpts) error { + if lookUpIPfn == nil { + opt.alreadySetLookUpIPfn = false + opt.lookUpIPfn = net.DefaultResolver.LookupIPAddr + return nil } - if req.Timeout != 0 { - c.SetDeadline(time.Now().Add(req.Timeout)) + opt.lookUpIPfn = lookUpIPfn + opt.alreadySetLookUpIPfn = true + return nil + } +} + +// WithCustomDNS will use custom dns to resolve the host +// Note: if LookUpIPfn is set, this function will not be used +func WithCustomDNS(customDNS []string) RequestOpt { + return func(opt *RequestOpts) error { + for _, v := range customDNS { + if net.ParseIP(v) == nil { + return fmt.Errorf("invalid custom dns: %s", v) + } } - return c, nil + opt.customDNS = customDNS + return nil } - return req } -func (curl *Request) ResetReqHeader() { - curl.ReqHeader = make(http.Header) +// WithAddCustomDNS will use a custom dns to resolve the host +// Note: if LookUpIPfn is set, this function will not be used +func WithAddCustomDNS(customDNS string) RequestOpt { + return func(opt *RequestOpts) error { + if net.ParseIP(customDNS) == nil { + return fmt.Errorf("invalid custom dns: %s", customDNS) + } + opt.customDNS = append(opt.customDNS, customDNS) + return nil + } } -func (curl *Request) ResetReqCookies() { - curl.ReqCookies = []*http.Cookie{} +// WithAutoCalcContentLength sets whether to automatically calculate the Content-Length header based on the request body. +// WARN: If set to true, the Content-Length header will be set to the length of the request body, which may cause issues with chunked transfer encoding. +// also the memory usage will be higher +// Note that this function will not work if doRawRequest is true +func (r *RequestOpts) WithAutoCalcContentLength(autoCalcContentLength bool) RequestOpt { + return func(opt *RequestOpts) error { + r.autoCalcContentLength = autoCalcContentLength + return nil + } } -func (curl *Request) AddSimpleCookie(key, value string) { - curl.ReqCookies = append(curl.ReqCookies, &http.Cookie{Name: key, Value: value, Path: "/"}) +type Response struct { + *http.Response + req Request + data *Body } -func (curl *Request) AddCookie(key, value, path string) { - curl.ReqCookies = append(curl.ReqCookies, &http.Cookie{Name: key, Value: value, Path: path}) + +type Body struct { + full []byte + raw io.ReadCloser + isFull bool } -func randomBoundary() string { - var buf [30]byte - _, err := io.ReadFull(rand.Reader, buf[:]) - if err != nil { - panic(err) +func (b *Body) readAll() { + if !b.isFull { + b.full, _ = io.ReadAll(b.raw) + b.isFull = true + b.raw.Close() } - return fmt.Sprintf("%x", buf[:]) } -func Curl(curl Request) (resps Request, err error) { - var fpsrc *os.File - if curl.RequestFile.UploadFile != "" { - fpsrc, err = os.Open(curl.UploadFile) - if err != nil { - return - } - defer fpsrc.Close() - boundary := randomBoundary() - boundarybytes := []byte("\r\n--" + boundary + "\r\n") - endbytes := []byte("\r\n--" + boundary + "--\r\n") - fpstat, _ := fpsrc.Stat() - filebig := float64(fpstat.Size()) - sum, n := 0, 0 - fpdst := stario.NewStarBuffer(1048576) - if curl.UploadForm != nil { - for k, v := range curl.UploadForm { - header := fmt.Sprintf("Content-Disposition: form-data; name=\"%s\";\r\nContent-Type: x-www-form-urlencoded \r\n\r\n", k) - fpdst.Write(boundarybytes) - fpdst.Write([]byte(header)) - fpdst.Write([]byte(v)) - } - } - header := fmt.Sprintf("Content-Disposition: form-data; name=\"%s\"; filename=\"%s\"\r\nContent-Type: application/octet-stream\r\n\r\n", curl.UploadName, fpstat.Name()) - fpdst.Write(boundarybytes) - fpdst.Write([]byte(header)) - go func() { - for { - bufs := make([]byte, 393213) - n, err = fpsrc.Read(bufs) - if err != nil { - if err == io.EOF { - if n != 0 { - fpdst.Write(bufs[0:n]) - if curl.Process != nil { - go curl.Process(float64(sum+n) / filebig * 100) - } - } - break - } - return - } - sum += n - if curl.Process != nil { - go curl.Process(float64(sum+n) / filebig * 100) - } - fpdst.Write(bufs[0:n]) - } - fpdst.Write(endbytes) - fpdst.Write(nil) - }() - curl.CircleBuffer = fpdst - curl.ReqHeader.Set("Content-Type", "multipart/form-data;boundary="+boundary) +func (b *Body) String() string { + b.readAll() + return string(b.full) +} + +func (b *Body) Bytes() []byte { + b.readAll() + return b.full +} + +func (b *Body) Unmarshal(u interface{}) error { + b.readAll() + return json.Unmarshal(b.full, u) +} + +// Reader returns a reader for the body +// if this function is called, other functions like String, Bytes, Unmarshal may not work +func (b *Body) Reader() io.ReadCloser { + if b.isFull { + return io.NopCloser(bytes.NewReader(b.full)) } - req, resp, err := netcurl(curl) + b.isFull = true + return b.raw +} + +func (b *Body) Close() error { + return b.Close() +} + +func (r *Response) GetRequest() Request { + return r.req +} + +func (r *Response) Body() *Body { + return r.data +} + +func Curl(r *Request) (*Response, error) { + r.errInfo = nil + err := applyOptions(r) if err != nil { - return Request{}, err - } - if resp.Request != nil && resp.Request.URL != nil { - curl.RespURL = resp.Request.URL.String() - } - curl.reqOrigin = req - curl.respOrigin = resp - curl.Location, _ = resp.Location() - curl.RespHttpCode = resp.StatusCode - curl.RespHeader = resp.Header - curl.RespCookies = resp.Cookies() - curl.RecvContentLength = resp.ContentLength - readFunc := func(reader io.ReadCloser, writer io.Writer) error { - lengthall := resp.ContentLength - defer reader.Close() - var lengthsum int - buf := make([]byte, 65535) - for { - n, err := reader.Read(buf) - if n != 0 { - _, err := writer.Write(buf[:n]) - lengthsum += n - if curl.Process != nil { - go curl.Process(float64(lengthsum) / float64(lengthall) * 100.00) - } - if err != nil { - return err - } - } - if err != nil && err != io.EOF { - return err - } else if err == io.EOF { - return nil - } - } + return nil, fmt.Errorf("apply options error: %s", err) } - if curl.WriteRecvData { - buf := bytes.NewBuffer([]byte{}) - err = readFunc(resp.Body, buf) - if err != nil { - return - } - curl.RecvData = buf.Bytes() - } else { - curl.respReader = resp.Body + resp, err := r.rawClient.Do(r.rawRequest) + var res = Response{ + Response: resp, + req: *r, + data: new(Body), } - if curl.RecvIo != nil { - if curl.WriteRecvData { - _, err = curl.RecvIo.Write(curl.RecvData) - } else { - err = readFunc(resp.Body, curl.RecvIo) - if err != nil { - return - } - } + if err != nil { + res.Response = &http.Response{} + return &res, fmt.Errorf("do request error: %s", err) + } + res.data.raw = resp.Body + if r.autoFetchRespBody { + res.data.full, _ = io.ReadAll(resp.Body) + res.data.isFull = true + resp.Body.Close() } - return curl, err + return &res, r.errInfo +} + +func NewReq(uri string, opts ...RequestOpt) *Request { + return NewSimpleRequest(uri, "GET", opts...) +} + +func NewReqWithContext(ctx context.Context, uri string, opts ...RequestOpt) *Request { + return NewSimpleRequestWithContext(ctx, uri, "GET", opts...) } -// RespBodyReader Only works when WriteRecvData set to false -func (curl *Request) RespBodyReader() io.ReadCloser { - return curl.respReader +func NewSimpleRequest(uri string, method string, opts ...RequestOpt) *Request { + r, _ := newRequest(context.Background(), uri, method, opts...) + return r } -func netcurl(curl Request) (*http.Request, *http.Response, error) { +func NewRequest(uri string, method string, opts ...RequestOpt) (*Request, error) { + return newRequest(context.Background(), uri, method, opts...) +} + +func NewSimpleRequestWithContext(ctx context.Context, uri string, method string, opts ...RequestOpt) *Request { + r, _ := newRequest(ctx, uri, method, opts...) + return r +} + +func NewRequestWithContext(ctx context.Context, uri string, method string, opts ...RequestOpt) (*Request, error) { + return newRequest(ctx, uri, method, opts...) +} + +func newRequest(ctx context.Context, uri string, method string, opts ...RequestOpt) (*Request, error) { var req *http.Request var err error - if curl.Method == "" { - return nil, nil, errors.New("Error Method Not Entered") - } - if curl.PostBuffer != nil { - req, err = http.NewRequest(curl.Method, curl.Url, curl.PostBuffer) - } else if curl.CircleBuffer != nil && curl.CircleBuffer.Len() > 0 { - req, err = http.NewRequest(curl.Method, curl.Url, curl.CircleBuffer) - } else { - req, err = http.NewRequest(curl.Method, curl.Url, nil) - } - if curl.Queries != nil { - sid := req.URL.Query() - for k, v := range curl.Queries { - sid.Add(k, v) - } - req.URL.RawQuery = sid.Encode() + if method == "" { + method = "GET" } + method = strings.ToUpper(method) + req, err = http.NewRequestWithContext(ctx, method, uri, nil) if err != nil { - return nil, nil, err + return nil, err } - req.Header = curl.ReqHeader - if len(curl.ReqCookies) != 0 { - for _, v := range curl.ReqCookies { - req.AddCookie(v) - } + var r = &Request{ + ctx: ctx, + uri: uri, + method: method, + RequestOpts: RequestOpts{ + rawRequest: req, + rawClient: new(http.Client), + timeout: DefaultTimeout, + dialTimeout: DefaultDialTimeout, + autoFetchRespBody: DefaultFetchRespBody, + lookUpIPfn: net.DefaultResolver.LookupIPAddr, + bodyFormData: make(map[string][]string), + queries: make(map[string][]string), + }, } - if curl.Proxy != "" { - purl, err := url.Parse(curl.Proxy) - if err != nil { - return nil, nil, err + + r.headers = make(http.Header) + if strings.ToUpper(method) == "POST" { + r.headers.Set("Content-Type", HEADER_FORM_URLENCODE) + } + r.headers.Set("User-Agent", "B612 / 1.2.0") + for _, v := range opts { + if v != nil { + err = v(&r.RequestOpts) + if err != nil { + return nil, err + } } - curl.CustomTransport.Proxy = http.ProxyURL(purl) } - client := &http.Client{ - Transport: curl.CustomTransport, + if r.transport == nil { + r.transport = &http.Transport{} } - if curl.DisableRedirect { - client.CheckRedirect = func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse + if r.doRawTransport { + if r.skipTLSVerify { + if r.transport.TLSClientConfig == nil { + r.transport.TLSClientConfig = &tls.Config{} + } + r.transport.TLSClientConfig.InsecureSkipVerify = true + } + if r.tlsConfig != nil { + r.transport.TLSClientConfig = r.tlsConfig + } + r.transport.DialContext = func(ctx context.Context, netType, addr string) (net.Conn, error) { + var lastErr error + var addrs []string + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + if len(r.customIP) > 0 { + for _, v := range r.customIP { + ipAddr := net.ParseIP(v) + if ipAddr == nil { + return nil, fmt.Errorf("invalid custom ip: %s", r.customIP) + } + tmpAddr := net.JoinHostPort(v, port) + addrs = append(addrs, tmpAddr) + } + } else { + ipLists, err := r.lookUpIPfn(ctx, host) + if err != nil { + return nil, err + } + for _, v := range ipLists { + tmpAddr := net.JoinHostPort(v.String(), port) + addrs = append(addrs, tmpAddr) + } + } + for _, addr := range addrs { + c, err := net.DialTimeout(netType, addr, r.dialTimeout) + if err != nil { + lastErr = err + continue + } + if r.timeout != 0 { + err = c.SetDeadline(time.Now().Add(r.timeout)) + } + return c, nil + } + return nil, lastErr } } - resp, err := client.Do(req) - - return req, resp, err + return r, nil } -func UrlEncodeRaw(str string) string { - strs := strings.Replace(url.QueryEscape(str), "+", "%20", -1) - return strs -} +func applyDataReader(r *Request) error { + // 优先度为:bodyDataReader > bodyDataBytes > bodyFormData > bodyFileData + if r.bodyDataReader != nil { + r.rawRequest.Body = io.NopCloser(r.bodyDataReader) + return nil + } + if len(r.bodyDataBytes) != 0 { + r.rawRequest.Body = io.NopCloser(bytes.NewReader(r.bodyDataBytes)) + return nil + } + if len(r.bodyFormData) != 0 && len(r.bodyFileData) == 0 { + var body = url.Values{} + for k, v := range r.bodyFormData { + for _, vv := range v { + body.Add(k, vv) + } + } + r.rawRequest.Body = io.NopCloser(strings.NewReader(body.Encode())) + return nil + } + if len(r.bodyFileData) != 0 { + var pr, pw = io.Pipe() + var w = multipart.NewWriter(pw) + r.rawRequest.Header.Set("Content-Type", w.FormDataContentType()) + go func() { + defer pw.Close() // ensure pipe writer is closed -func UrlEncode(str string) string { - return url.QueryEscape(str) -} + if len(r.bodyFormData) != 0 { + for k, v := range r.bodyFormData { + for _, vv := range v { + if err := w.WriteField(k, vv); err != nil { + r.errInfo = err + pw.CloseWithError(err) // close pipe with error + return + } + } + } + } + for _, v := range r.bodyFileData { + var fw, err = w.CreateFormFile(v.FormName, v.FileName) + if err != nil { + r.errInfo = err + pw.CloseWithError(err) // close pipe with error + return + } + if _, err := copyWithContext(r.ctx, r.FileUploadRecallFn, v.FileName, v.FileSize, fw, v.FileData); err != nil { + r.errInfo = err + pw.CloseWithError(err) // close pipe with error + return + } + } -func UrlDecode(str string) (string, error) { - return url.QueryUnescape(str) -} + if err := w.Close(); err != nil { + pw.CloseWithError(err) // close pipe with error if writer close fails + } + }() -func BuildQuery(queryData map[string]string) string { - query := url.Values{} - for k, v := range queryData { - query.Add(k, v) + r.rawRequest.Body = pr + return nil } - return query.Encode() + return nil } -func BuildPostForm(queryMap map[string]string) []byte { - query := url.Values{} - for k, v := range queryMap { - query.Add(k, v) +func applyOptions(r *Request) error { + defer func() { + r.alreadyApply = true + }() + var req = r.rawRequest + if !r.doRawRequest { + if r.queries != nil { + sid := req.URL.Query() + for k, v := range r.queries { + for _, vv := range v { + sid.Add(k, vv) + } + } + req.URL.RawQuery = sid.Encode() + } + for k, v := range r.headers { + for _, vv := range v { + req.Header.Add(k, vv) + } + } + if len(r.cookies) != 0 { + for _, v := range r.cookies { + req.AddCookie(v) + } + } + if r.basicAuth[0] != "" || r.basicAuth[1] != "" { + req.SetBasicAuth(r.basicAuth[0], r.basicAuth[1]) + } + err := applyDataReader(r) + if err != nil { + return fmt.Errorf("apply data reader error: %s", err) + } + if r.autoCalcContentLength { + if req.Body != nil { + data, err := io.ReadAll(req.Body) + if err != nil { + return fmt.Errorf("read data error: %s", err) + } + req.Header.Set("Content-Length", strconv.Itoa(len(data))) + req.Body = io.NopCloser(bytes.NewReader(data)) + } + } } - return []byte(query.Encode()) + if r.proxy != "" { + purl, err := url.Parse(r.proxy) + if err != nil { + return fmt.Errorf("parse proxy url error: %s", err) + } + r.transport.Proxy = http.ProxyURL(purl) + } + if !r.doRawClient { + if !r.doRawTransport { + if !r.alreadySetLookUpIPfn && len(r.customIP) > 0 { + resolver := net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (conn net.Conn, err error) { + for _, addr := range r.customIP { + if conn, err = net.Dial("udp", addr+":53"); err != nil { + continue + } else { + return conn, nil + } + } + return + }, + } + r.lookUpIPfn = resolver.LookupIPAddr + } + } + r.rawClient.Transport = r.transport + if r.disableRedirect { + r.rawClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + } + } + return nil } -func (r Request) Resopnse() *http.Response { - return r.respOrigin -} +func copyWithContext(ctx context.Context, recall func(string, int64, int64), filename string, total int64, dst io.Writer, src io.Reader) (written int64, err error) { + pr, pw := io.Pipe() + defer pr.Close() -func (r Request) Request() *http.Request { - return r.reqOrigin + go func() { + defer pw.Close() + _, err := io.Copy(pw, src) + if err != nil { + pw.CloseWithError(err) + } + }() + var count int64 + buf := make([]byte, 4096) + for { + select { + case <-ctx.Done(): + return written, ctx.Err() + default: + nr, err := pr.Read(buf) + if err != nil { + if err == io.EOF { + go recall(filename, count, total) + return written, nil + } + return written, err + } + count += int64(nr) + if recall != nil { + go recall(filename, count, total) + } + nw, err := dst.Write(buf[:nr]) + if err != nil { + return written, err + } + if nr != nw { + return written, io.ErrShortWrite + } + written += int64(nr) + } + } } diff --git a/curl_test.go b/curl_test.go new file mode 100644 index 0000000..29c21b4 --- /dev/null +++ b/curl_test.go @@ -0,0 +1,464 @@ +package starnet + +import ( + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func TestUrlEncodeRaw(t *testing.T) { + input := "hello world!@#$%^&*()_+-=~`" + expected := "hello%20world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60" + result := UrlEncodeRaw(input) + if result != expected { + t.Errorf("UrlEncodeRaw(%q) = %q; want %q", input, result, expected) + } +} + +func TestUrlEncode(t *testing.T) { + input := "hello world!@#$%^&*()_+-=~`" + expected := `hello+world%21%40%23%24%25%5E%26%2A%28%29_%2B-%3D~%60` + result := UrlEncode(input) + if result != expected { + t.Errorf("UrlEncode(%q) = %q; want %q", input, result, expected) + } +} + +func TestUrlDecode(t *testing.T) { + input := "hello%20world%21%40%23%24%25%5E%26*%28%29_%2B-%3D~%60" + expected := "hello world!@#$%^&*()_+-=~`" + result, err := UrlDecode(input) + if err != nil { + t.Errorf("UrlDecode(%q) returned error: %v", input, err) + } + if result != expected { + t.Errorf("UrlDecode(%q) = %q; want %q", input, result, expected) + } + + // Test for error case + invalidInput := "%zz" + _, err = UrlDecode(invalidInput) + if err == nil { + t.Errorf("UrlDecode(%q) expected error, got nil", invalidInput) + } +} + +func TestBuildPostForm_WithValidInput(t *testing.T) { + input := map[string]string{ + "key1": "value1", + "key2": "value2", + } + + expected := []byte("key1=value1&key2=value2") + + result := BuildPostForm(input) + + if string(result) != string(expected) { + t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected) + } +} + +func TestBuildPostForm_WithEmptyInput(t *testing.T) { + input := map[string]string{} + + expected := []byte("") + + result := BuildPostForm(input) + + if string(result) != string(expected) { + t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected) + } +} + +func TestBuildPostForm_WithNilInput(t *testing.T) { + var input map[string]string + + expected := []byte("") + + result := BuildPostForm(input) + + if string(result) != string(expected) { + t.Errorf("BuildPostForm(%v) = %v; want %v", input, result, expected) + } +} + +func TestGetRequest(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + rw.Write([]byte(`OK`)) + })) + defer server.Close() + + resp, err := Get(server.URL) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + body := resp.Body().String() + if body != "OK" { + t.Errorf("Expected OK, got %v", body) + } +} + +func TestPostRequest(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + t.Errorf("Expected 'POST', got %v", req.Method) + } + rw.Write([]byte(`OK`)) + })) + defer server.Close() + + resp, err := Post(server.URL) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + body := resp.Body().String() + if body != "OK" { + t.Errorf("Expected OK, got %v", body) + } +} + +func TestOptionsRequestWithValidInput(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodOptions { + t.Errorf("Expected 'OPTIONS', got %v", req.Method) + } + rw.Write([]byte(`OK`)) + })) + defer server.Close() + + resp, err := Options(server.URL) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + body := resp.Body().String() + if body != "OK" { + t.Errorf("Expected OK, got %v", body) + } +} + +func TestPutRequestWithValidInput(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPut { + t.Errorf("Expected 'PUT', got %v", req.Method) + } + rw.Write([]byte(`OK`)) + })) + defer server.Close() + + resp, err := Put(server.URL) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + body := resp.Body().String() + if body != "OK" { + t.Errorf("Expected OK, got %v", body) + } +} + +func TestDeleteRequestWithValidInput(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodDelete { + t.Errorf("Expected 'DELETE', got %v", req.Method) + } + rw.Write([]byte(`OK`)) + })) + defer server.Close() + + resp, err := Delete(server.URL) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + body := resp.Body().String() + if body != "OK" { + t.Errorf("Expected OK, got %v", body) + } +} + +func TestHeadRequestWithValidInput(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodHead { + t.Errorf("Expected 'HEAD', got %v", req.Method) + } + rw.Write([]byte(`OK`)) + })) + defer server.Close() + + resp, err := Head(server.URL) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + body := resp.Body().String() + if body == "OK" { + t.Errorf("Expected , got %v", body) + } +} + +func TestPatchRequestWithValidInput(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPatch { + t.Errorf("Expected 'PATCH', got %v", req.Method) + } + rw.Write([]byte(`OK`)) + })) + defer server.Close() + + resp, err := Patch(server.URL) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + body := resp.Body().String() + if body != "OK" { + t.Errorf("Expected OK, got %v", body) + } +} + +func TestTraceRequestWithValidInput(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodTrace { + t.Errorf("Expected 'TRACE', got %v", req.Method) + } + rw.Write([]byte(`OK`)) + })) + defer server.Close() + + resp, err := Trace(server.URL) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + body := resp.Body().String() + if body != "OK" { + t.Errorf("Expected OK, got %v", body) + } +} + +func TestConnectRequestWithValidInput(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodConnect { + t.Errorf("Expected 'CONNECT', got %v", req.Method) + } + rw.Write([]byte(`OK`)) + })) + defer server.Close() + + resp, err := Connect(server.URL) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + body := resp.Body().String() + if body != "OK" { + t.Errorf("Expected OK, got %v", body) + } +} +func TestMethodReturnsCorrectValue(t *testing.T) { + req := NewReq("https://example.com") + req.SetMethodNoError("GET") + if req.Method() != "GET" { + t.Errorf("Expected 'GET', got %v", req.Method()) + } +} + +func TestSetMethodHandlesInvalidInput(t *testing.T) { + req := NewReq("https://example.com") + err := req.SetMethod("我是谁") + if err == nil { + t.Errorf("Expected error, got nil") + } +} + +func TestSetMethodNoErrorSetsMethodCorrectly(t *testing.T) { + req := NewReq("https://example.com") + req.SetMethodNoError("POST") + if req.Method() != "POST" { + t.Errorf("Expected 'POST', got %v", req.Method()) + } +} + +func TestSetMethodNoErrorIgnoresInvalidInput(t *testing.T) { + req := NewReq("https://example.com") + req.SetMethodNoError("你是谁") + if req.Method() != "GET" { + t.Errorf("Expected '', got %v", req.Method()) + } +} + +func TestUriReturnsCorrectValue(t *testing.T) { + req := NewReq("https://example.com") + if req.Uri() != "https://example.com" { + t.Errorf("Expected 'https://example.com', got %v", req.Uri()) + } +} + +func TestSetUriHandlesValidInput(t *testing.T) { + req := NewReq("https://example.com") + err := req.SetUri("https://newexample.com") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if req.Uri() != "https://newexample.com" { + t.Errorf("Expected 'https://newexample.com', got %v", req.Uri()) + } +} + +func TestSetUriHandlesInvalidInput(t *testing.T) { + req := NewReq("https://example.com") + err := req.SetUri("://invalidurl") + if err == nil { + t.Errorf("Expected error, got nil") + } +} + +func TestSetUriNoErrorSetsUriCorrectly(t *testing.T) { + req := NewReq("https://example.com") + req.SetUriNoError("https://newexample.com") + if req.Uri() != "https://newexample.com" { + t.Errorf("Expected 'https://newexample.com', got %v", req.Uri()) + } +} + +func TestSetUriNoErrorIgnoresInvalidInput(t *testing.T) { + req := NewReq("https://example.com") + req.SetUriNoError("://invalidurl") + if req.Uri() != "https://example.com" { + t.Errorf("Expected 'https://example.com', got %v", req.Uri()) + } +} + +type postmanReply struct { + Args struct { + } `json:"args"` + Form map[string]string `json:"form"` + Headers map[string]string `json:"headers"` + Url string `json:"url"` +} + +func TestGet(t *testing.T) { + var reply postmanReply + resp, err := NewReq("https://postman-echo.com/get"). + AddHeader("hello", "nononmo"). + SetAutoCalcContentLength(true). + Do() + if err != nil { + t.Error(err) + } + err = resp.Body().Unmarshal(&reply) + if err != nil { + t.Error(err) + } + fmt.Println(resp.Body().String()) + fmt.Println(reply.Headers) + fmt.Println(resp.Cookies()) +} + +type testData struct { + name string + args *Request + want func(*Response) error + wantErr bool +} + +func headerTestData() []testData { + return []testData{ + { + name: "addHeader", + args: NewReq("https://postman-echo.com/get"). + AddHeader("b612", "test-data"). + AddHeader("b612", "test-header"). + AddSimpleCookie("b612", "test-cookie"). + SetHeader("User-Agent", "starnet test"), + want: func(resp *Response) error { + //fmt.Println(resp.Body().String()) + if resp == nil { + return fmt.Errorf("response is nil") + } + if resp.StatusCode != 200 { + return fmt.Errorf("status code is %d", resp.StatusCode) + } + var reply postmanReply + err := resp.Body().Unmarshal(&reply) + if err != nil { + return err + } + if reply.Headers["b612"] != "test-data, test-header" { + return fmt.Errorf("header not found") + } + if reply.Headers["user-agent"] != "starnet test" { + return fmt.Errorf("user-agent not found") + } + if reply.Headers["cookie"] != "b612=test-cookie" { + return fmt.Errorf("cookie not found") + } + return nil + }, + wantErr: false, + }, + { + name: "postForm", + args: NewSimpleRequest("https://postman-echo.com/post", "POST"). + AddHeader("b612", "test-data"). + AddHeader("b612", "test-header"). + AddSimpleCookie("b612", "test-cookie"). + SetHeader("User-Agent", "starnet test"). + //SetHeader("Content-Type", "application/x-www-form-urlencoded"). + AddFormData("hello", "world"). + AddFormData("hello2", "world2"). + SetMethodNoError("POST"), + want: func(resp *Response) error { + //fmt.Println(resp.Body().String()) + if resp == nil { + return fmt.Errorf("response is nil") + } + if resp.StatusCode != 200 { + return fmt.Errorf("status code is %d", resp.StatusCode) + } + var reply postmanReply + err := resp.Body().Unmarshal(&reply) + if err != nil { + return err + } + if reply.Headers["b612"] != "test-data, test-header" { + return fmt.Errorf("header not found") + } + if reply.Headers["user-agent"] != "starnet test" { + return fmt.Errorf("user-agent not found") + } + if reply.Headers["cookie"] != "b612=test-cookie" { + return fmt.Errorf("cookie not found") + } + if reply.Form["hello"] != "world" { + return fmt.Errorf("form data not found") + } + if reply.Form["hello2"] != "world2" { + return fmt.Errorf("form data not found") + } + return nil + }, + wantErr: false, + }, + } +} +func TestCurl(t *testing.T) { + for _, tt := range headerTestData() { + t.Run(tt.name, func(t *testing.T) { + got, err := Curl(tt.args) + if (err != nil) != tt.wantErr { + t.Errorf("Curl() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.want != nil { + if err := tt.want(got); err != nil { + t.Errorf("Curl() = %v", err) + } + } + }) + } +} diff --git a/go.mod b/go.mod index b75fccc..4ac747d 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,3 @@ module b612.me/starnet go 1.16 - -require b612.me/stario v0.0.9 diff --git a/go.sum b/go.sum index 631e800..e69de29 100644 --- a/go.sum +++ b/go.sum @@ -1,47 +0,0 @@ -b612.me/stario v0.0.9 h1:bFDlejUJMwZ12a09snZJspQsOlkqpDAl9qKPEYOGWCk= -b612.me/stario v0.0.9/go.mod h1:x4D/x8zA5SC0pj/uJAi4FyG5p4j5UZoMEZfvuRR6VNw= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= -golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= -golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= -golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= -golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= -golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= -golang.org/x/term v0.18.0 h1:FcHjZXDMxI8mM3nwhX9HlKop4C0YQvCVCdwYl2wOtE8= -golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/httpguts.go b/httpguts.go new file mode 100644 index 0000000..baf8d38 --- /dev/null +++ b/httpguts.go @@ -0,0 +1,120 @@ +package starnet + +import "strings" + +var isTokenTable = [127]bool{ + '!': true, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '*': true, + '+': true, + '-': true, + '.': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'W': true, + 'V': true, + 'X': true, + 'Y': true, + 'Z': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '|': true, + '~': true, +} + +func IsTokenRune(r rune) bool { + i := int(r) + return i < len(isTokenTable) && isTokenTable[i] +} + +func validMethod(method string) bool { + /* + Method = "OPTIONS" ; Section 9.2 + | "GET" ; Section 9.3 + | "HEAD" ; Section 9.4 + | "POST" ; Section 9.5 + | "PUT" ; Section 9.6 + | "DELETE" ; Section 9.7 + | "TRACE" ; Section 9.8 + | "CONNECT" ; Section 9.9 + | extension-method + extension-method = token + token = 1* + */ + return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1 +} + +func isNotToken(r rune) bool { + return !IsTokenRune(r) +} + +func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") } + +// removeEmptyPort strips the empty port in ":port" to "" +// as mandated by RFC 3986 Section 6.2.3. +func removeEmptyPort(host string) string { + if hasPort(host) { + return strings.TrimSuffix(host, ":") + } + return host +}