8 Commits

Author SHA1 Message Date
b612 4e17fee681 bug fix 2025-07-14 18:38:31 +08:00
b612 a8eed30db5 add http client control 2025-07-14 18:23:14 +08:00
b612 c1eaf43058 update 2025-06-17 12:36:57 +08:00
b612 9f5aca124d update 2025-06-17 12:09:12 +08:00
b612 54958724e7 bug fix 2025-06-13 17:16:38 +08:00
b612 7a17672149 update tls sniffer 2025-06-12 16:50:47 +08:00
b612 44b807d3d1 update 2025-06-06 15:43:38 +08:00
b612 0d847462b3 bug fix:nil pointer 2025-04-28 13:19:45 +08:00
3 changed files with 707 additions and 8 deletions
+1
View File
@@ -0,0 +1 @@
.idea
+300 -3
View File
@@ -14,6 +14,7 @@ import (
"os"
"strconv"
"strings"
"sync"
"time"
)
@@ -108,6 +109,186 @@ type Request struct {
RequestOpts
}
func (r *Request) Clone() *Request {
clonedRequest := &Request{
ctx: r.ctx,
uri: r.uri,
method: r.method,
errInfo: r.errInfo,
RequestOpts: RequestOpts{
headers: CloneHeader(r.headers),
cookies: CloneCookies(r.cookies),
bodyFormData: CloneStringMapSlice(r.bodyFormData),
bodyFileData: CloneFiles(r.bodyFileData),
queries: CloneStringMapSlice(r.queries),
bodyDataBytes: CloneByteSlice(r.bodyDataBytes),
proxy: r.proxy,
timeout: r.timeout,
dialTimeout: r.dialTimeout,
dialFn: r.dialFn,
alreadyApply: r.alreadyApply,
disableRedirect: r.disableRedirect,
doRawRequest: r.doRawRequest,
doRawClient: r.doRawClient,
doRawTransport: r.doRawTransport,
skipTLSVerify: r.skipTLSVerify,
autoFetchRespBody: r.autoFetchRespBody,
customIP: CloneStringSlice(r.customIP),
alreadySetLookUpIPfn: r.alreadySetLookUpIPfn,
lookUpIPfn: r.lookUpIPfn,
customDNS: CloneStringSlice(r.customDNS),
basicAuth: r.basicAuth,
autoCalcContentLength: r.autoCalcContentLength,
},
}
// 手动深拷贝嵌套引用类型
if r.bodyDataReader != nil {
clonedRequest.bodyDataReader = r.bodyDataReader
}
if r.FileUploadRecallFn != nil {
clonedRequest.FileUploadRecallFn = r.FileUploadRecallFn
}
// 对于 tlsConfig 类型,需要手动复制
if r.tlsConfig != nil {
clonedRequest.tlsConfig = CloneTLSConfig(r.tlsConfig)
}
// 对于 http.Transport,需要进行手动复制
if r.transport != nil {
clonedRequest.transport = CloneTransport(r.transport)
}
return clonedRequest
}
// CloneHeader 复制 http.Header
func CloneHeader(original http.Header) http.Header {
newHeader := make(http.Header)
for key, values := range original {
copiedValues := make([]string, len(values))
copy(copiedValues, values)
newHeader[key] = copiedValues
}
return newHeader
}
// CloneCookies 复制 []*http.Cookie
func CloneCookies(original []*http.Cookie) []*http.Cookie {
cloned := make([]*http.Cookie, len(original))
for i, cookie := range original {
cloned[i] = &http.Cookie{
Name: cookie.Name,
Value: cookie.Value,
Path: cookie.Path,
Domain: cookie.Domain,
Expires: cookie.Expires,
RawExpires: cookie.RawExpires,
MaxAge: cookie.MaxAge,
Secure: cookie.Secure,
HttpOnly: cookie.HttpOnly,
SameSite: cookie.SameSite,
Raw: cookie.Raw,
Unparsed: append([]string(nil), cookie.Unparsed...),
}
}
return cloned
}
// CloneStringMapSlice 复制 map[string][]string
func CloneStringMapSlice(original map[string][]string) map[string][]string {
newMap := make(map[string][]string)
for key, values := range original {
copiedValues := make([]string, len(values))
copy(copiedValues, values)
newMap[key] = copiedValues
}
return newMap
}
// CloneFiles 复制 []RequestFile
func CloneFiles(original []RequestFile) []RequestFile {
newFiles := make([]RequestFile, len(original))
copy(newFiles, original)
return newFiles
}
// CloneByteSlice 复制 []byte
func CloneByteSlice(original []byte) []byte {
if original == nil {
return nil
}
newSlice := make([]byte, len(original))
copy(newSlice, original)
return newSlice
}
// CloneStringSlice 复制 []string
func CloneStringSlice(original []string) []string {
newSlice := make([]string, len(original))
copy(newSlice, original)
return newSlice
}
// CloneTLSConfig 复制 tls.Config
func CloneTLSConfig(original *tls.Config) *tls.Config {
newConfig := &tls.Config{
Rand: original.Rand,
Time: original.Time,
Certificates: append([]tls.Certificate(nil), original.Certificates...),
NameToCertificate: original.NameToCertificate,
GetCertificate: original.GetCertificate,
GetClientCertificate: original.GetClientCertificate,
GetConfigForClient: original.GetConfigForClient,
VerifyPeerCertificate: original.VerifyPeerCertificate,
VerifyConnection: original.VerifyConnection,
RootCAs: original.RootCAs,
NextProtos: append([]string(nil), original.NextProtos...),
ServerName: original.ServerName,
ClientAuth: original.ClientAuth,
ClientCAs: original.ClientCAs,
InsecureSkipVerify: original.InsecureSkipVerify,
CipherSuites: append([]uint16(nil), original.CipherSuites...),
PreferServerCipherSuites: original.PreferServerCipherSuites,
SessionTicketsDisabled: original.SessionTicketsDisabled,
SessionTicketKey: original.SessionTicketKey,
ClientSessionCache: original.ClientSessionCache,
MinVersion: original.MinVersion,
MaxVersion: original.MaxVersion,
CurvePreferences: append([]tls.CurveID(nil), original.CurvePreferences...),
DynamicRecordSizingDisabled: original.DynamicRecordSizingDisabled,
Renegotiation: original.Renegotiation,
KeyLogWriter: original.KeyLogWriter,
}
return newConfig
}
// CloneTransport 复制 http.Transport
func CloneTransport(original *http.Transport) *http.Transport {
newTransport := &http.Transport{
Proxy: original.Proxy,
DialContext: original.DialContext,
Dial: original.Dial,
DialTLS: original.DialTLS,
TLSClientConfig: original.TLSClientConfig,
TLSHandshakeTimeout: original.TLSHandshakeTimeout,
DisableKeepAlives: original.DisableKeepAlives,
DisableCompression: original.DisableCompression,
MaxIdleConns: original.MaxIdleConns,
MaxIdleConnsPerHost: original.MaxIdleConnsPerHost,
IdleConnTimeout: original.IdleConnTimeout,
ResponseHeaderTimeout: original.ResponseHeaderTimeout,
ExpectContinueTimeout: original.ExpectContinueTimeout,
TLSNextProto: original.TLSNextProto,
ProxyConnectHeader: original.ProxyConnectHeader,
MaxResponseHeaderBytes: original.MaxResponseHeaderBytes,
WriteBufferSize: original.WriteBufferSize,
ReadBufferSize: original.ReadBufferSize,
}
return newTransport
}
func (r *Request) Method() string {
return r.method
}
@@ -202,6 +383,7 @@ type RequestOpts struct {
proxy string
timeout time.Duration
dialTimeout time.Duration
dialFn func(ctx context.Context, network, addr string) (net.Conn, error)
headers http.Header
cookies []*http.Cookie
transport *http.Transport
@@ -224,6 +406,14 @@ type RequestOpts struct {
autoCalcContentLength bool
}
func (r *Request) DialFn() func(ctx context.Context, network, addr string) (net.Conn, error) {
return r.dialFn
}
func (r *Request) SetDialFn(dialFn func(ctx context.Context, network, addr string) (net.Conn, error)) {
r.dialFn = dialFn
}
func (r *Request) AutoCalcContentLength() bool {
return r.autoCalcContentLength
}
@@ -665,6 +855,14 @@ func (r *Request) AddFileWithNameAndTypeNoError(formName, filepath, filename, fi
return r
}
func (r *Request) HttpClient() (*http.Client, error) {
err := applyOptions(r)
if err != nil {
return nil, err
}
return r.rawClient, nil
}
type RequestFile struct {
FormName string
FileName string
@@ -683,6 +881,14 @@ func WithDialTimeout(timeout time.Duration) RequestOpt {
}
}
// if doRawTransport is true, this function will nolonger work
func WithDial(fn func(ctx context.Context, network string, addr string) (net.Conn, error)) RequestOpt {
return func(opt *RequestOpts) error {
opt.dialFn = fn
return nil
}
}
// if doRawTransport is true, this function will nolonger work
func WithTimeout(timeout time.Duration) RequestOpt {
return func(opt *RequestOpts) error {
@@ -1065,16 +1271,24 @@ type Response struct {
*http.Response
req Request
data *Body
rawClient *http.Client
}
type Body struct {
full []byte
raw io.ReadCloser
isFull bool
sync.Mutex
}
func (b *Body) readAll() {
b.Lock()
defer b.Unlock()
if !b.isFull {
if b.raw == nil {
b.isFull = true
return
}
b.full, _ = io.ReadAll(b.raw)
b.isFull = true
b.raw.Close()
@@ -1099,6 +1313,8 @@ func (b *Body) Unmarshal(u interface{}) error {
// Reader returns a reader for the body
// if this function is called, other functions like String, Bytes, Unmarshal may not work
func (b *Body) Reader() io.ReadCloser {
b.Lock()
defer b.Unlock()
if b.isFull {
return io.NopCloser(bytes.NewReader(b.full))
}
@@ -1118,6 +1334,24 @@ func (r *Response) Body() *Body {
return r.data
}
func (r *Response) Close() error {
if r != nil && r.data != nil && r.data.raw != nil {
return r.Response.Body.Close()
}
return nil
}
func (r *Response) CloseAll() error {
if r.rawClient != nil {
r.rawClient.CloseIdleConnections()
}
return r.Close()
}
func (r *Response) HttpClient() *http.Client {
return r.rawClient
}
func Curl(r *Request) (*Response, error) {
r.errInfo = nil
err := applyOptions(r)
@@ -1129,6 +1363,7 @@ func Curl(r *Request) (*Response, error) {
Response: resp,
req: *r,
data: new(Body),
rawClient: r.rawClient,
}
if err != nil {
res.Response = &http.Response{}
@@ -1169,6 +1404,17 @@ func NewRequestWithContext(ctx context.Context, uri string, method string, opts
return newRequest(ctx, uri, method, opts...)
}
func NewHttpClient(opts ...RequestOpt) (*http.Client, error) {
req, err := newRequest(context.Background(), "", "", opts...)
if err != nil {
return nil, err
}
defer func() {
req = nil
}()
return req.HttpClient()
}
func newRequest(ctx context.Context, uri string, method string, opts ...RequestOpt) (*Request, error) {
var req *http.Request
var err error
@@ -1186,7 +1432,7 @@ func newRequest(ctx context.Context, uri string, method string, opts ...RequestO
method: method,
RequestOpts: RequestOpts{
rawRequest: req,
rawClient: new(http.Client),
rawClient: nil,
timeout: DefaultTimeout,
dialTimeout: DefaultDialTimeout,
autoFetchRespBody: DefaultFetchRespBody,
@@ -1209,10 +1455,13 @@ func newRequest(ctx context.Context, uri string, method string, opts ...RequestO
}
}
}
if r.transport == nil {
if r.transport == nil && !r.doRawTransport {
r.transport = &http.Transport{}
}
if r.doRawTransport {
if r.rawClient == nil && !r.doRawClient {
r.rawClient = new(http.Client)
}
if !r.doRawTransport {
if r.skipTLSVerify {
if r.transport.TLSClientConfig == nil {
r.transport.TLSClientConfig = &tls.Config{}
@@ -1261,6 +1510,9 @@ func newRequest(ctx context.Context, uri string, method string, opts ...RequestO
}
return nil, lastErr
}
if r.dialFn != nil {
r.transport.DialContext = r.dialFn
}
}
return r, nil
}
@@ -1397,6 +1649,18 @@ func applyOptions(r *Request) error {
r.lookUpIPfn = resolver.LookupIPAddr
}
}
if r.skipTLSVerify {
if r.transport.TLSClientConfig == nil {
r.transport.TLSClientConfig = &tls.Config{}
}
r.transport.TLSClientConfig.InsecureSkipVerify = true
}
if r.tlsConfig != nil {
r.transport.TLSClientConfig = r.tlsConfig
}
if r.dialFn != nil {
r.transport.DialContext = r.dialFn
}
r.rawClient.Transport = r.transport
if r.disableRedirect {
r.rawClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
@@ -1448,3 +1712,36 @@ func copyWithContext(ctx context.Context, recall func(string, int64, int64), fil
}
}
}
func NewReqWithClient(client *http.Client, uri string, opts ...RequestOpt) *Request {
return NewSimpleRequestWithClient(client, uri, "GET", opts...)
}
func NewReqWithContextWithClient(ctx context.Context, client *http.Client, uri string, opts ...RequestOpt) *Request {
return NewSimpleRequestWithContextWithClient(ctx, client, uri, "GET", opts...)
}
func NewSimpleRequestWithClient(client *http.Client, uri string, method string, opts ...RequestOpt) *Request {
r, _ := NewRequestWithContextWithClient(context.Background(), client, uri, method, opts...)
return r
}
func NewRequestWithClient(client *http.Client, uri string, method string, opts ...RequestOpt) (*Request, error) {
return NewRequestWithContextWithClient(context.Background(), client, uri, method, opts...)
}
func NewSimpleRequestWithContextWithClient(ctx context.Context, client *http.Client, uri string, method string, opts ...RequestOpt) *Request {
r, _ := NewRequestWithContextWithClient(ctx, client, uri, method, opts...)
return r
}
func NewRequestWithContextWithClient(ctx context.Context, client *http.Client, uri string, method string, opts ...RequestOpt) (*Request, error) {
req, err := newRequest(ctx, uri, method, opts...)
if err != nil {
return nil, err
}
req.rawClient = client
req.SetDoRawClient(true)
req.SetDoRawTransport(true)
return req, err
}
+401
View File
@@ -0,0 +1,401 @@
package starnet
import (
"bytes"
"context"
"crypto/tls"
"io"
"net"
"sync"
"time"
)
type myConn struct {
reader io.Reader
conn net.Conn
isReadOnly bool
multiReader io.Reader
}
func (c *myConn) Read(p []byte) (int, error) {
if c.isReadOnly {
return c.reader.Read(p)
}
if c.multiReader == nil {
c.multiReader = io.MultiReader(c.reader, c.conn)
}
return c.multiReader.Read(p)
}
func (c *myConn) Write(p []byte) (int, error) {
if c.isReadOnly {
return 0, io.ErrClosedPipe
}
return c.conn.Write(p)
}
func (c *myConn) Close() error {
if c.isReadOnly {
return nil
}
return c.conn.Close()
}
func (c *myConn) LocalAddr() net.Addr {
if c.isReadOnly {
return nil
}
return c.conn.LocalAddr()
}
func (c *myConn) RemoteAddr() net.Addr {
if c.isReadOnly {
return nil
}
return c.conn.RemoteAddr()
}
func (c *myConn) SetDeadline(t time.Time) error {
if c.isReadOnly {
return nil
}
return c.conn.SetDeadline(t)
}
func (c *myConn) SetReadDeadline(t time.Time) error {
if c.isReadOnly {
return nil
}
return c.conn.SetReadDeadline(t)
}
func (c *myConn) SetWriteDeadline(t time.Time) error {
if c.isReadOnly {
return nil
}
return c.conn.SetWriteDeadline(t)
}
type Listener struct {
net.Listener
cfg *tls.Config
getConfigForClient func(hostname string) *tls.Config
allowNonTls bool
}
func (l *Listener) GetConfigForClient() func(hostname string) *tls.Config {
return l.getConfigForClient
}
func (l *Listener) SetConfigForClient(getConfigForClient func(hostname string) *tls.Config) {
l.getConfigForClient = getConfigForClient
}
func Listen(network, address string) (*Listener, error) {
listener, err := net.Listen(network, address)
if err != nil {
return nil, err
}
return &Listener{Listener: listener}, nil
}
func ListenTLSWithListenConfig(liscfg net.ListenConfig, network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
listener, err := liscfg.Listen(context.Background(), network, address)
if err != nil {
return nil, err
}
return &Listener{
Listener: listener,
cfg: config,
getConfigForClient: getConfigForClient,
allowNonTls: allowNonTls,
}, nil
}
func ListenWithListener(listener net.Listener, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
return &Listener{
Listener: listener,
cfg: config,
getConfigForClient: getConfigForClient,
allowNonTls: allowNonTls,
}, nil
}
func ListenTLSWithConfig(network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
listener, err := net.Listen(network, address)
if err != nil {
return nil, err
}
return &Listener{
Listener: listener,
cfg: config,
getConfigForClient: getConfigForClient,
allowNonTls: allowNonTls,
}, nil
}
func ListenTLS(network, address string, certFile, keyFile string, allowNonTls bool) (*Listener, error) {
config, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{config},
}
listener, err := net.Listen(network, address)
if err != nil {
return nil, err
}
return &Listener{
Listener: listener,
cfg: tlsConfig,
allowNonTls: allowNonTls,
}, nil
}
func (l *Listener) Accept() (net.Conn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &Conn{
Conn: conn,
tlsCfg: l.cfg,
getConfigForClient: l.getConfigForClient,
allowNonTls: l.allowNonTls,
}, nil
}
type Conn struct {
net.Conn
once sync.Once
initErr error
isTLS bool
tlsCfg *tls.Config
tlsConn *tls.Conn
buffer *bytes.Buffer
noTlsReader io.Reader
isOriginal bool
getConfigForClient func(hostname string) *tls.Config
hostname string
allowNonTls bool
}
func (c *Conn) Hostname() string {
if c.hostname != "" {
return c.hostname
}
if c.isTLS && c.tlsConn != nil {
if c.tlsConn.ConnectionState().ServerName != "" {
c.hostname = c.tlsConn.ConnectionState().ServerName
return c.hostname
}
}
return ""
}
func (c *Conn) IsTLS() bool {
return c.isTLS
}
func (c *Conn) TlsConn() *tls.Conn {
return c.tlsConn
}
func (c *Conn) isTLSConnection() (bool, error) {
if c.getConfigForClient == nil {
peek := make([]byte, 5)
n, err := io.ReadFull(c.Conn, peek)
if err != nil {
return false, err
}
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
c.buffer = bytes.NewBuffer(peek[:n])
return isTLS, nil
}
c.buffer = new(bytes.Buffer)
r := io.TeeReader(c.Conn, c.buffer)
var hello *tls.ClientHelloInfo
tls.Server(&myConn{reader: r, isReadOnly: true}, &tls.Config{
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
hello = new(tls.ClientHelloInfo)
*hello = *argHello
return nil, nil
},
}).Handshake()
peek := c.buffer.Bytes()
n := len(peek)
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
if hello == nil {
return isTLS, nil
}
c.hostname = hello.ServerName
if c.hostname == "" {
c.hostname, _, _ = net.SplitHostPort(c.Conn.LocalAddr().String())
}
return isTLS, nil
}
func (c *Conn) init() {
c.once.Do(func() {
if c.isOriginal {
return
}
if c.tlsCfg != nil {
isTLS, err := c.isTLSConnection()
if err != nil {
c.initErr = err
return
}
c.isTLS = isTLS
}
if c.isTLS {
var cfg = c.tlsCfg
if c.getConfigForClient != nil {
cfg = c.getConfigForClient(c.hostname)
if cfg == nil {
cfg = c.tlsCfg
}
}
c.tlsConn = tls.Server(&myConn{
reader: c.buffer,
conn: c.Conn,
isReadOnly: false,
}, cfg)
} else {
if !c.allowNonTls {
c.initErr = net.ErrClosed
return
}
c.noTlsReader = io.MultiReader(c.buffer, c.Conn)
}
})
}
func (c *Conn) Read(b []byte) (int, error) {
c.init()
if c.initErr != nil {
return 0, c.initErr
}
if c.isTLS {
return c.tlsConn.Read(b)
}
return c.noTlsReader.Read(b)
}
func (c *Conn) Write(b []byte) (int, error) {
c.init()
if c.initErr != nil {
return 0, c.initErr
}
if c.isTLS {
return c.tlsConn.Write(b)
}
return c.Conn.Write(b)
}
func (c *Conn) Close() error {
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.Close()
}
return c.Conn.Close()
}
func (c *Conn) SetDeadline(t time.Time) error {
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetDeadline(t)
}
return c.Conn.SetDeadline(t)
}
func (c *Conn) SetReadDeadline(t time.Time) error {
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetReadDeadline(t)
}
return c.Conn.SetReadDeadline(t)
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetWriteDeadline(t)
}
return c.Conn.SetWriteDeadline(t)
}
func (c *Conn) TlsConnection() (*tls.Conn, error) {
if c.initErr != nil {
return nil, c.initErr
}
if !c.isTLS {
return nil, net.ErrClosed
}
return c.tlsConn, nil
}
func (c *Conn) OriginalConn() net.Conn {
return c.Conn
}
func NewClientTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
if conn == nil {
return nil, net.ErrClosed
}
c := &Conn{
Conn: conn,
isTLS: true,
tlsCfg: cfg,
tlsConn: tls.Client(conn, cfg),
isOriginal: true,
}
return c, nil
}
func NewServerTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
if conn == nil {
return nil, net.ErrClosed
}
c := &Conn{
Conn: conn,
isTLS: true,
tlsCfg: cfg,
tlsConn: tls.Server(conn, cfg),
isOriginal: true,
}
c.init()
return c, nil
}
func Dial(network, address string) (*Conn, error) {
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return &Conn{
Conn: conn,
isTLS: false,
tlsCfg: nil,
tlsConn: nil,
noTlsReader: conn,
isOriginal: true,
}, nil
}
func DialTLS(network, address string, certFile, keyFile string) (*Conn, error) {
config, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{config},
}
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return NewClientTlsConn(conn, tlsConfig)
}