update tls sniffer

This commit is contained in:
兔子 2025-06-12 16:50:47 +08:00
parent 44b807d3d1
commit 7a17672149
2 changed files with 398 additions and 0 deletions

24
curl.go
View File

@ -125,6 +125,7 @@ func (r *Request) Clone() *Request {
proxy: r.proxy, proxy: r.proxy,
timeout: r.timeout, timeout: r.timeout,
dialTimeout: r.dialTimeout, dialTimeout: r.dialTimeout,
dialFn: r.dialFn,
alreadyApply: r.alreadyApply, alreadyApply: r.alreadyApply,
disableRedirect: r.disableRedirect, disableRedirect: r.disableRedirect,
doRawRequest: r.doRawRequest, doRawRequest: r.doRawRequest,
@ -382,6 +383,7 @@ type RequestOpts struct {
proxy string proxy string
timeout time.Duration timeout time.Duration
dialTimeout time.Duration dialTimeout time.Duration
dialFn func(ctx context.Context, network, addr string) (net.Conn, error)
headers http.Header headers http.Header
cookies []*http.Cookie cookies []*http.Cookie
transport *http.Transport transport *http.Transport
@ -404,6 +406,14 @@ type RequestOpts struct {
autoCalcContentLength bool autoCalcContentLength bool
} }
func (r *Request) DialFn() func(ctx context.Context, network, addr string) (net.Conn, error) {
return r.dialFn
}
func (r *Request) SetDialFn(dialFn func(ctx context.Context, network, addr string) (net.Conn, error)) {
r.dialFn = dialFn
}
func (r *Request) AutoCalcContentLength() bool { func (r *Request) AutoCalcContentLength() bool {
return r.autoCalcContentLength return r.autoCalcContentLength
} }
@ -863,6 +873,14 @@ func WithDialTimeout(timeout time.Duration) RequestOpt {
} }
} }
// if doRawTransport is true, this function will nolonger work
func WithDial(fn func(ctx context.Context, network string, addr string) (net.Conn, error)) RequestOpt {
return func(opt *RequestOpts) error {
opt.dialFn = fn
return nil
}
}
// if doRawTransport is true, this function will nolonger work // if doRawTransport is true, this function will nolonger work
func WithTimeout(timeout time.Duration) RequestOpt { func WithTimeout(timeout time.Duration) RequestOpt {
return func(opt *RequestOpts) error { return func(opt *RequestOpts) error {
@ -1450,6 +1468,9 @@ func newRequest(ctx context.Context, uri string, method string, opts ...RequestO
} }
return nil, lastErr return nil, lastErr
} }
if r.dialFn != nil {
r.transport.DialContext = r.dialFn
}
} }
return r, nil return r, nil
} }
@ -1595,6 +1616,9 @@ func applyOptions(r *Request) error {
if r.tlsConfig != nil { if r.tlsConfig != nil {
r.transport.TLSClientConfig = r.tlsConfig r.transport.TLSClientConfig = r.tlsConfig
} }
if r.dialFn != nil {
r.transport.DialContext = r.dialFn
}
r.rawClient.Transport = r.transport r.rawClient.Transport = r.transport
if r.disableRedirect { if r.disableRedirect {
r.rawClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { r.rawClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {

374
tlssniffer.go Normal file
View File

@ -0,0 +1,374 @@
package starnet
import (
"bytes"
"crypto/tls"
"io"
"net"
"sync"
"time"
)
type myConn struct {
reader io.Reader
conn net.Conn
isReadOnly bool
multiReader io.Reader
}
func (c *myConn) Read(p []byte) (int, error) {
if c.isReadOnly {
return c.reader.Read(p)
}
if c.multiReader == nil {
c.multiReader = io.MultiReader(c.reader, c.conn)
}
return c.multiReader.Read(p)
}
func (c *myConn) Write(p []byte) (int, error) {
if c.isReadOnly {
return 0, io.ErrClosedPipe
}
return c.conn.Write(p)
}
func (c *myConn) Close() error {
if c.isReadOnly {
return nil
}
return c.conn.Close()
}
func (c *myConn) LocalAddr() net.Addr {
if c.isReadOnly {
return nil
}
return c.conn.LocalAddr()
}
func (c *myConn) RemoteAddr() net.Addr {
if c.isReadOnly {
return nil
}
return c.conn.RemoteAddr()
}
func (c *myConn) SetDeadline(t time.Time) error {
if c.isReadOnly {
return nil
}
return c.conn.SetDeadline(t)
}
func (c *myConn) SetReadDeadline(t time.Time) error {
if c.isReadOnly {
return nil
}
return c.conn.SetReadDeadline(t)
}
func (c *myConn) SetWriteDeadline(t time.Time) error {
if c.isReadOnly {
return nil
}
return c.conn.SetWriteDeadline(t)
}
type Listener struct {
net.Listener
cfg *tls.Config
getConfigForClient func(hostname string) *tls.Config
allowNonTls bool
}
func (l *Listener) GetConfigForClient() func(hostname string) *tls.Config {
return l.getConfigForClient
}
func (l *Listener) SetConfigForClient(getConfigForClient func(hostname string) *tls.Config) {
l.getConfigForClient = getConfigForClient
}
func Listen(network, address string) (*Listener, error) {
listener, err := net.Listen(network, address)
if err != nil {
return nil, err
}
return &Listener{Listener: listener}, nil
}
func ListenTLSWithConfig(network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
listener, err := net.Listen(network, address)
if err != nil {
return nil, err
}
return &Listener{
Listener: listener,
cfg: config,
getConfigForClient: getConfigForClient,
allowNonTls: allowNonTls,
}, nil
}
func ListenTLS(network, address string, certFile, keyFile string, allowNonTls bool) (*Listener, error) {
config, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{config},
}
listener, err := net.Listen(network, address)
if err != nil {
return nil, err
}
return &Listener{
Listener: listener,
cfg: tlsConfig,
allowNonTls: allowNonTls,
}, nil
}
func (l *Listener) Accept() (net.Conn, error) {
conn, err := l.Listener.Accept()
if err != nil {
return nil, err
}
return &Conn{
Conn: conn,
tlsCfg: l.cfg,
getConfigForClient: l.getConfigForClient,
allowNonTls: l.allowNonTls,
}, nil
}
type Conn struct {
net.Conn
once sync.Once
initErr error
isTLS bool
tlsCfg *tls.Config
tlsConn *tls.Conn
buffer *bytes.Buffer
noTlsReader io.Reader
isOriginal bool
getConfigForClient func(hostname string) *tls.Config
hostname string
allowNonTls bool
}
func (c *Conn) Hostname() string {
if c.hostname != "" {
return c.hostname
}
if c.isTLS && c.tlsConn != nil {
if c.tlsConn.ConnectionState().ServerName != "" {
c.hostname = c.tlsConn.ConnectionState().ServerName
return c.hostname
}
}
return ""
}
func (c *Conn) IsTLS() bool {
return c.isTLS
}
func (c *Conn) TlsConn() *tls.Conn {
return c.tlsConn
}
func (c *Conn) isTLSConnection() (bool, error) {
if c.getConfigForClient == nil {
peek := make([]byte, 5)
n, err := io.ReadFull(c.Conn, peek)
if err != nil {
return false, err
}
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
c.buffer = bytes.NewBuffer(peek[:n])
return isTLS, nil
}
c.buffer = new(bytes.Buffer)
r := io.TeeReader(c.Conn, c.buffer)
var hello *tls.ClientHelloInfo
tls.Server(&myConn{reader: r, isReadOnly: true}, &tls.Config{
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
hello = new(tls.ClientHelloInfo)
*hello = *argHello
return nil, nil
},
}).Handshake()
peek := c.buffer.Bytes()
n := len(peek)
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
if hello == nil {
return isTLS, nil
}
c.hostname = hello.ServerName
if c.hostname == "" {
c.hostname, _, _ = net.SplitHostPort(c.Conn.RemoteAddr().String())
}
return isTLS, nil
}
func (c *Conn) init() {
c.once.Do(func() {
if c.isOriginal {
return
}
if c.tlsCfg != nil {
isTLS, err := c.isTLSConnection()
if err != nil {
c.initErr = err
return
}
c.isTLS = isTLS
}
if c.isTLS {
var cfg = c.tlsCfg
if c.getConfigForClient != nil {
cfg = c.getConfigForClient(c.hostname)
if cfg == nil {
cfg = c.tlsCfg
}
}
c.tlsConn = tls.Server(&myConn{
reader: c.buffer,
conn: c.Conn,
isReadOnly: false,
}, cfg)
} else {
if !c.allowNonTls {
c.initErr = net.ErrClosed
return
}
c.noTlsReader = io.MultiReader(c.buffer, c.Conn)
}
})
}
func (c *Conn) Read(b []byte) (int, error) {
c.init()
if c.initErr != nil {
return 0, c.initErr
}
if c.isTLS {
return c.tlsConn.Read(b)
}
return c.noTlsReader.Read(b)
}
func (c *Conn) Write(b []byte) (int, error) {
c.init()
if c.initErr != nil {
return 0, c.initErr
}
if c.isTLS {
return c.tlsConn.Write(b)
}
return c.Conn.Write(b)
}
func (c *Conn) Close() error {
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.Close()
}
return c.Conn.Close()
}
func (c *Conn) SetDeadline(t time.Time) error {
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetDeadline(t)
}
return c.Conn.SetDeadline(t)
}
func (c *Conn) SetReadDeadline(t time.Time) error {
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetReadDeadline(t)
}
return c.Conn.SetReadDeadline(t)
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetWriteDeadline(t)
}
return c.Conn.SetWriteDeadline(t)
}
func (c *Conn) TlsConnection() (*tls.Conn, error) {
if c.initErr != nil {
return nil, c.initErr
}
if !c.isTLS {
return nil, net.ErrClosed
}
return c.tlsConn, nil
}
func NewClientTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
if conn == nil {
return nil, net.ErrClosed
}
c := &Conn{
Conn: conn,
isTLS: true,
tlsCfg: cfg,
tlsConn: tls.Client(conn, cfg),
isOriginal: true,
}
return c, nil
}
func NewServerTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
if conn == nil {
return nil, net.ErrClosed
}
c := &Conn{
Conn: conn,
isTLS: true,
tlsCfg: cfg,
tlsConn: tls.Server(conn, cfg),
isOriginal: true,
}
c.init()
return c, nil
}
func Dial(network, address string) (*Conn, error) {
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return &Conn{
Conn: conn,
isTLS: false,
tlsCfg: nil,
tlsConn: nil,
noTlsReader: conn,
isOriginal: true,
}, nil
}
func DialTLS(network, address string, certFile, keyFile string) (*Conn, error) {
config, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{config},
}
conn, err := net.Dial(network, address)
if err != nil {
return nil, err
}
return NewClientTlsConn(conn, tlsConfig)
}