update tls sniffer
This commit is contained in:
parent
44b807d3d1
commit
7a17672149
24
curl.go
24
curl.go
@ -125,6 +125,7 @@ func (r *Request) Clone() *Request {
|
||||
proxy: r.proxy,
|
||||
timeout: r.timeout,
|
||||
dialTimeout: r.dialTimeout,
|
||||
dialFn: r.dialFn,
|
||||
alreadyApply: r.alreadyApply,
|
||||
disableRedirect: r.disableRedirect,
|
||||
doRawRequest: r.doRawRequest,
|
||||
@ -382,6 +383,7 @@ type RequestOpts struct {
|
||||
proxy string
|
||||
timeout time.Duration
|
||||
dialTimeout time.Duration
|
||||
dialFn func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
headers http.Header
|
||||
cookies []*http.Cookie
|
||||
transport *http.Transport
|
||||
@ -404,6 +406,14 @@ type RequestOpts struct {
|
||||
autoCalcContentLength bool
|
||||
}
|
||||
|
||||
func (r *Request) DialFn() func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return r.dialFn
|
||||
}
|
||||
|
||||
func (r *Request) SetDialFn(dialFn func(ctx context.Context, network, addr string) (net.Conn, error)) {
|
||||
r.dialFn = dialFn
|
||||
}
|
||||
|
||||
func (r *Request) AutoCalcContentLength() bool {
|
||||
return r.autoCalcContentLength
|
||||
}
|
||||
@ -863,6 +873,14 @@ func WithDialTimeout(timeout time.Duration) RequestOpt {
|
||||
}
|
||||
}
|
||||
|
||||
// if doRawTransport is true, this function will nolonger work
|
||||
func WithDial(fn func(ctx context.Context, network string, addr string) (net.Conn, error)) RequestOpt {
|
||||
return func(opt *RequestOpts) error {
|
||||
opt.dialFn = fn
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// if doRawTransport is true, this function will nolonger work
|
||||
func WithTimeout(timeout time.Duration) RequestOpt {
|
||||
return func(opt *RequestOpts) error {
|
||||
@ -1450,6 +1468,9 @@ func newRequest(ctx context.Context, uri string, method string, opts ...RequestO
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
if r.dialFn != nil {
|
||||
r.transport.DialContext = r.dialFn
|
||||
}
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
@ -1595,6 +1616,9 @@ func applyOptions(r *Request) error {
|
||||
if r.tlsConfig != nil {
|
||||
r.transport.TLSClientConfig = r.tlsConfig
|
||||
}
|
||||
if r.dialFn != nil {
|
||||
r.transport.DialContext = r.dialFn
|
||||
}
|
||||
r.rawClient.Transport = r.transport
|
||||
if r.disableRedirect {
|
||||
r.rawClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
|
374
tlssniffer.go
Normal file
374
tlssniffer.go
Normal file
@ -0,0 +1,374 @@
|
||||
package starnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type myConn struct {
|
||||
reader io.Reader
|
||||
conn net.Conn
|
||||
isReadOnly bool
|
||||
multiReader io.Reader
|
||||
}
|
||||
|
||||
func (c *myConn) Read(p []byte) (int, error) {
|
||||
if c.isReadOnly {
|
||||
return c.reader.Read(p)
|
||||
}
|
||||
if c.multiReader == nil {
|
||||
c.multiReader = io.MultiReader(c.reader, c.conn)
|
||||
}
|
||||
return c.multiReader.Read(p)
|
||||
}
|
||||
|
||||
func (c *myConn) Write(p []byte) (int, error) {
|
||||
if c.isReadOnly {
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
return c.conn.Write(p)
|
||||
}
|
||||
func (c *myConn) Close() error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.Close()
|
||||
}
|
||||
func (c *myConn) LocalAddr() net.Addr {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.LocalAddr()
|
||||
}
|
||||
func (c *myConn) RemoteAddr() net.Addr {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.RemoteAddr()
|
||||
}
|
||||
func (c *myConn) SetDeadline(t time.Time) error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.SetDeadline(t)
|
||||
}
|
||||
func (c *myConn) SetReadDeadline(t time.Time) error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.SetReadDeadline(t)
|
||||
}
|
||||
func (c *myConn) SetWriteDeadline(t time.Time) error {
|
||||
if c.isReadOnly {
|
||||
return nil
|
||||
}
|
||||
return c.conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
net.Listener
|
||||
cfg *tls.Config
|
||||
getConfigForClient func(hostname string) *tls.Config
|
||||
allowNonTls bool
|
||||
}
|
||||
|
||||
func (l *Listener) GetConfigForClient() func(hostname string) *tls.Config {
|
||||
return l.getConfigForClient
|
||||
}
|
||||
|
||||
func (l *Listener) SetConfigForClient(getConfigForClient func(hostname string) *tls.Config) {
|
||||
l.getConfigForClient = getConfigForClient
|
||||
}
|
||||
|
||||
func Listen(network, address string) (*Listener, error) {
|
||||
listener, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{Listener: listener}, nil
|
||||
}
|
||||
|
||||
func ListenTLSWithConfig(network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) {
|
||||
listener, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Listener{
|
||||
Listener: listener,
|
||||
cfg: config,
|
||||
getConfigForClient: getConfigForClient,
|
||||
allowNonTls: allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func ListenTLS(network, address string, certFile, keyFile string, allowNonTls bool) (*Listener, error) {
|
||||
config, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{config},
|
||||
}
|
||||
|
||||
listener, err := net.Listen(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Listener{
|
||||
Listener: listener,
|
||||
cfg: tlsConfig,
|
||||
allowNonTls: allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l *Listener) Accept() (net.Conn, error) {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
tlsCfg: l.cfg,
|
||||
getConfigForClient: l.getConfigForClient,
|
||||
allowNonTls: l.allowNonTls,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
net.Conn
|
||||
once sync.Once
|
||||
initErr error
|
||||
isTLS bool
|
||||
tlsCfg *tls.Config
|
||||
tlsConn *tls.Conn
|
||||
buffer *bytes.Buffer
|
||||
noTlsReader io.Reader
|
||||
isOriginal bool
|
||||
getConfigForClient func(hostname string) *tls.Config
|
||||
hostname string
|
||||
allowNonTls bool
|
||||
}
|
||||
|
||||
func (c *Conn) Hostname() string {
|
||||
if c.hostname != "" {
|
||||
return c.hostname
|
||||
}
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
if c.tlsConn.ConnectionState().ServerName != "" {
|
||||
c.hostname = c.tlsConn.ConnectionState().ServerName
|
||||
return c.hostname
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (c *Conn) IsTLS() bool {
|
||||
return c.isTLS
|
||||
}
|
||||
|
||||
func (c *Conn) TlsConn() *tls.Conn {
|
||||
return c.tlsConn
|
||||
}
|
||||
|
||||
func (c *Conn) isTLSConnection() (bool, error) {
|
||||
if c.getConfigForClient == nil {
|
||||
peek := make([]byte, 5)
|
||||
n, err := io.ReadFull(c.Conn, peek)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
|
||||
|
||||
c.buffer = bytes.NewBuffer(peek[:n])
|
||||
return isTLS, nil
|
||||
}
|
||||
|
||||
c.buffer = new(bytes.Buffer)
|
||||
r := io.TeeReader(c.Conn, c.buffer)
|
||||
var hello *tls.ClientHelloInfo
|
||||
tls.Server(&myConn{reader: r, isReadOnly: true}, &tls.Config{
|
||||
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||
hello = new(tls.ClientHelloInfo)
|
||||
*hello = *argHello
|
||||
return nil, nil
|
||||
},
|
||||
}).Handshake()
|
||||
peek := c.buffer.Bytes()
|
||||
n := len(peek)
|
||||
isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03
|
||||
if hello == nil {
|
||||
return isTLS, nil
|
||||
}
|
||||
c.hostname = hello.ServerName
|
||||
if c.hostname == "" {
|
||||
c.hostname, _, _ = net.SplitHostPort(c.Conn.RemoteAddr().String())
|
||||
}
|
||||
return isTLS, nil
|
||||
}
|
||||
|
||||
func (c *Conn) init() {
|
||||
c.once.Do(func() {
|
||||
if c.isOriginal {
|
||||
return
|
||||
}
|
||||
if c.tlsCfg != nil {
|
||||
isTLS, err := c.isTLSConnection()
|
||||
if err != nil {
|
||||
c.initErr = err
|
||||
return
|
||||
}
|
||||
c.isTLS = isTLS
|
||||
}
|
||||
|
||||
if c.isTLS {
|
||||
var cfg = c.tlsCfg
|
||||
if c.getConfigForClient != nil {
|
||||
cfg = c.getConfigForClient(c.hostname)
|
||||
if cfg == nil {
|
||||
cfg = c.tlsCfg
|
||||
}
|
||||
}
|
||||
c.tlsConn = tls.Server(&myConn{
|
||||
reader: c.buffer,
|
||||
conn: c.Conn,
|
||||
isReadOnly: false,
|
||||
}, cfg)
|
||||
} else {
|
||||
if !c.allowNonTls {
|
||||
c.initErr = net.ErrClosed
|
||||
return
|
||||
}
|
||||
c.noTlsReader = io.MultiReader(c.buffer, c.Conn)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (c *Conn) Read(b []byte) (int, error) {
|
||||
c.init()
|
||||
if c.initErr != nil {
|
||||
return 0, c.initErr
|
||||
}
|
||||
if c.isTLS {
|
||||
return c.tlsConn.Read(b)
|
||||
}
|
||||
return c.noTlsReader.Read(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Write(b []byte) (int, error) {
|
||||
c.init()
|
||||
if c.initErr != nil {
|
||||
return 0, c.initErr
|
||||
}
|
||||
|
||||
if c.isTLS {
|
||||
return c.tlsConn.Write(b)
|
||||
}
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (c *Conn) Close() error {
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.Close()
|
||||
}
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
func (c *Conn) SetDeadline(t time.Time) error {
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.SetDeadline(t)
|
||||
}
|
||||
return c.Conn.SetDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) SetReadDeadline(t time.Time) error {
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.SetReadDeadline(t)
|
||||
}
|
||||
return c.Conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
if c.isTLS && c.tlsConn != nil {
|
||||
return c.tlsConn.SetWriteDeadline(t)
|
||||
}
|
||||
return c.Conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *Conn) TlsConnection() (*tls.Conn, error) {
|
||||
if c.initErr != nil {
|
||||
return nil, c.initErr
|
||||
}
|
||||
if !c.isTLS {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
return c.tlsConn, nil
|
||||
}
|
||||
|
||||
func NewClientTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
|
||||
if conn == nil {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
c := &Conn{
|
||||
Conn: conn,
|
||||
isTLS: true,
|
||||
tlsCfg: cfg,
|
||||
tlsConn: tls.Client(conn, cfg),
|
||||
isOriginal: true,
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func NewServerTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) {
|
||||
if conn == nil {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
c := &Conn{
|
||||
Conn: conn,
|
||||
isTLS: true,
|
||||
tlsCfg: cfg,
|
||||
tlsConn: tls.Server(conn, cfg),
|
||||
isOriginal: true,
|
||||
}
|
||||
c.init()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func Dial(network, address string) (*Conn, error) {
|
||||
conn, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
isTLS: false,
|
||||
tlsCfg: nil,
|
||||
tlsConn: nil,
|
||||
noTlsReader: conn,
|
||||
isOriginal: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func DialTLS(network, address string, certFile, keyFile string) (*Conn, error) {
|
||||
config, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{config},
|
||||
}
|
||||
|
||||
conn, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewClientTlsConn(conn, tlsConfig)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user