starnet/utils.go

213 lines
5.0 KiB
Go
Raw Permalink Normal View History

2026-03-08 20:19:40 +08:00
package starnet
import (
"context"
"crypto/tls"
"io"
"net/http"
"net/url"
"strings"
)
// validMethod reports whether method is a syntactically valid HTTP method
// name: a non-empty string in which no rune is rejected by isNotToken.
func validMethod(method string) bool {
	if method == "" {
		return false
	}
	// IndexFunc yields a negative index when every rune is a token rune.
	return strings.IndexFunc(method, isNotToken) < 0
}
// isNotToken is the negation of isTokenRune. It has the func(rune) bool
// shape expected by strings.IndexFunc, which lets validMethod find the first
// rune that is not a token character.
func isNotToken(r rune) bool {
	isToken := isTokenRune(r)
	return !isToken
}
// isTokenRune reports whether r is an HTTP token character according to
// isTokenTable. Runes outside the table's index range are never token
// characters. This includes every non-ASCII rune, and also negative values,
// which previously indexed the table out of range and panicked.
func isTokenRune(r rune) bool {
	return r >= 0 && int(r) < len(isTokenTable) && isTokenTable[r]
}
// isTokenTable flags the ASCII characters (indices 0-126) that may appear in
// an HTTP token. Those are the digits, the upper- and lower-case letters, and
// the punctuation !#$%&'*+-.^_`|~ (the RFC 7230 "tchar" set). Every other
// entry is false, including control characters and space.
var isTokenTable = func() [127]bool {
	var table [127]bool
	for c := '0'; c <= '9'; c++ {
		table[c] = true
	}
	for c := 'a'; c <= 'z'; c++ {
		table[c] = true
		table[c-'a'+'A'] = true
	}
	for _, c := range "!#$%&'*+-.^_`|~" {
		table[c] = true
	}
	return table
}()
// hasPort reports whether the host string s carries a ":port" suffix. That is
// the case when its last colon comes after its last ']', so the colons inside
// a bracketed IPv6 literal such as "[::1]" are not mistaken for a port
// separator. An unbracketed IPv6 address such as "::1" is reported as having
// a port.
func hasPort(s string) bool {
	lastColon := strings.LastIndexByte(s, ':')
	lastBracket := strings.LastIndexByte(s, ']')
	return lastColon > lastBracket
}
// removeEmptyPort strips a single trailing ':' (an empty port) from host, so
// "example.com:" becomes "example.com". Any other host is returned unchanged.
//
// No separate port check is needed: a host that ends in ':' always has its
// last colon after any ']', and TrimSuffix leaves every other host untouched.
func removeEmptyPort(host string) string {
	return strings.TrimSuffix(host, ":")
}
// UrlEncode escapes str for use as a URL query component, following
// url.QueryEscape. Spaces become "+" and reserved characters are
// percent-encoded.
func UrlEncode(str string) string {
	escaped := url.QueryEscape(str)
	return escaped
}
// UrlEncodeRaw escapes str like UrlEncode, except that spaces are encoded as
// "%20" instead of "+". Rewriting '+' is safe because url.QueryEscape already
// turns every literal '+' into "%2B", so any '+' left in its output stands for
// a space.
func UrlEncodeRaw(str string) string {
	escaped := url.QueryEscape(str)
	return strings.ReplaceAll(escaped, "+", "%20")
}
// UrlDecode reverses query-component escaping using url.QueryUnescape. "+"
// becomes a space and %XX sequences are decoded. A malformed escape yields
// "" and the error from url.QueryUnescape.
func UrlDecode(str string) (string, error) {
	decoded, err := url.QueryUnescape(str)
	return decoded, err
}
// BuildQuery encodes data as a URL query string such as "a=1&b=2".
// url.Values.Encode escapes every key and value and sorts the pairs by key, so
// the output is deterministic. A nil or empty map yields "".
func BuildQuery(data map[string]string) string {
	values := make(url.Values, len(data))
	for key, val := range data {
		// Map keys are unique, so Set is equivalent to Add here.
		values.Set(key, val)
	}
	return values.Encode()
}
// BuildPostForm returns data encoded by BuildQuery as a byte slice, suitable
// as a URL-encoded form body for a POST request.
func BuildPostForm(data map[string]string) []byte {
	body := BuildQuery(data)
	return []byte(body)
}
// cloneHeader returns an independent copy of h. Each value slice is
// duplicated, so edits to the copy never reach h. A nil h produces an empty,
// non-nil Header. A key whose value slice is empty maps to nil in the copy.
func cloneHeader(h http.Header) http.Header {
	out := make(http.Header, len(h))
	for key, vals := range h {
		out[key] = append([]string(nil), vals...)
	}
	return out
}
// cloneCookies returns an independent copy of cookies. Every *http.Cookie is
// copied into a new value and its Unparsed slice is duplicated, so mutating
// the result never affects the input. A nil input yields nil, and nil entries
// are kept as nil instead of being dereferenced.
//
// Each cookie is copied by value rather than field by field, so every field
// survives the clone. This includes fields added to http.Cookie in newer Go
// releases, such as Quoted and Partitioned (Go 1.23), which an explicit field
// list silently drops.
func cloneCookies(cookies []*http.Cookie) []*http.Cookie {
	if cookies == nil {
		return nil
	}
	out := make([]*http.Cookie, len(cookies))
	for i, c := range cookies {
		if c == nil {
			continue // keep the nil entry rather than panicking on dereference
		}
		dup := *c
		// Unparsed is the only reference-typed field that needs detaching.
		dup.Unparsed = append([]string(nil), c.Unparsed...)
		out[i] = &dup
	}
	return out
}
// cloneStringMapSlice returns an independent copy of m with every value slice
// duplicated. A nil m produces an empty, non-nil map. A key whose value slice
// is empty maps to nil in the copy.
func cloneStringMapSlice(m map[string][]string) map[string][]string {
	out := make(map[string][]string, len(m))
	for key, vals := range m {
		out[key] = append([]string(nil), vals...)
	}
	return out
}
// cloneBytes returns a copy of b backed by its own array, with capacity equal
// to its length. A nil slice stays nil. An empty non-nil slice yields an
// empty non-nil slice.
func cloneBytes(b []byte) []byte {
	if b == nil {
		return nil
	}
	return append(make([]byte, 0, len(b)), b...)
}
// cloneStringSlice returns a copy of s backed by its own array, with capacity
// equal to its length. A nil slice stays nil. An empty non-nil slice yields
// an empty non-nil slice.
func cloneStringSlice(s []string) []string {
	if s == nil {
		return nil
	}
	return append(make([]string, 0, len(s)), s...)
}
// cloneFiles returns a new slice holding copies of the RequestFile values in
// files. A nil slice stays nil. An empty non-nil slice yields an empty
// non-nil slice.
//
// NOTE(review): this is a shallow copy. Elements are copied by value, so any
// pointer, slice, map or interface fields inside RequestFile (its definition
// is not visible here) remain shared with the original. Confirm this is
// intended.
func cloneFiles(files []RequestFile) []RequestFile {
	if files == nil {
		return nil
	}
	return append(make([]RequestFile, 0, len(files)), files...)
}
// cloneTLSConfig returns a shallow clone of cfg, or nil when cfg is nil.
// (*tls.Config).Clone already returns nil for a nil receiver, so no separate
// nil check is required.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
	return cfg.Clone()
}
// copyWithProgress copies src to dst and, when progress is non-nil, reports
// the running byte count through it.
//
// With a nil progress callback it simply delegates to io.Copy, and ctx is not
// consulted at all in that mode. Otherwise the data is moved in 32 KiB chunks:
//   - ctx is checked before each read. A Read that is already blocked is not
//     interrupted.
//   - progress(filename, copied, total) is invoked after every write that
//     accepted at least one byte.
//   - progress is invoked once more when src returns io.EOF, which may repeat
//     the most recent count.
//
// filename and total are forwarded to the callback unchanged.
//
// It returns the number of bytes written to dst together with the first error
// encountered, which may come from:
//   - ctx;
//   - dst, reported as io.ErrShortWrite when the write count differs from the
//     read count without an error;
//   - src.
//
// io.EOF from src is treated as successful completion and returned as nil.
func copyWithProgress(ctx context.Context, dst io.Writer, src io.Reader, filename string, total int64, progress UploadProgressFunc) (int64, error) {
	if progress == nil {
		return io.Copy(dst, src)
	}

	var copied int64
	chunk := make([]byte, 32*1024)
	for {
		if err := ctx.Err(); err != nil {
			return copied, err
		}

		nRead, readErr := src.Read(chunk)
		if nRead > 0 {
			nWritten, writeErr := dst.Write(chunk[:nRead])
			if nWritten > 0 {
				copied += int64(nWritten)
				// Called synchronously on this goroutine (no goroutine is
				// spawned), so a slow callback slows the copy down.
				progress(filename, copied, total)
			}
			switch {
			case writeErr != nil:
				return copied, writeErr
			case nWritten != nRead:
				return copied, io.ErrShortWrite
			}
		}

		switch {
		case readErr == io.EOF:
			// Final report once the source is exhausted.
			progress(filename, copied, total)
			return copied, nil
		case readErr != nil:
			return copied, readErr
		}
	}
}