diff --git a/curl.go b/curl.go index a994925..b346a89 100644 --- a/curl.go +++ b/curl.go @@ -125,6 +125,7 @@ func (r *Request) Clone() *Request { proxy: r.proxy, timeout: r.timeout, dialTimeout: r.dialTimeout, + dialFn: r.dialFn, alreadyApply: r.alreadyApply, disableRedirect: r.disableRedirect, doRawRequest: r.doRawRequest, @@ -382,6 +383,7 @@ type RequestOpts struct { proxy string timeout time.Duration dialTimeout time.Duration + dialFn func(ctx context.Context, network, addr string) (net.Conn, error) headers http.Header cookies []*http.Cookie transport *http.Transport @@ -404,6 +406,14 @@ type RequestOpts struct { autoCalcContentLength bool } +func (r *Request) DialFn() func(ctx context.Context, network, addr string) (net.Conn, error) { + return r.dialFn +} + +func (r *Request) SetDialFn(dialFn func(ctx context.Context, network, addr string) (net.Conn, error)) { + r.dialFn = dialFn +} + func (r *Request) AutoCalcContentLength() bool { return r.autoCalcContentLength } @@ -863,6 +873,14 @@ func WithDialTimeout(timeout time.Duration) RequestOpt { } } +// if doRawTransport is true, this function will nolonger work +func WithDial(fn func(ctx context.Context, network string, addr string) (net.Conn, error)) RequestOpt { + return func(opt *RequestOpts) error { + opt.dialFn = fn + return nil + } +} + // if doRawTransport is true, this function will nolonger work func WithTimeout(timeout time.Duration) RequestOpt { return func(opt *RequestOpts) error { @@ -1450,6 +1468,9 @@ func newRequest(ctx context.Context, uri string, method string, opts ...RequestO } return nil, lastErr } + if r.dialFn != nil { + r.transport.DialContext = r.dialFn + } } return r, nil } @@ -1595,6 +1616,9 @@ func applyOptions(r *Request) error { if r.tlsConfig != nil { r.transport.TLSClientConfig = r.tlsConfig } + if r.dialFn != nil { + r.transport.DialContext = r.dialFn + } r.rawClient.Transport = r.transport if r.disableRedirect { r.rawClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { diff --git a/tlssniffer.go b/tlssniffer.go new file mode 100644 index 0000000..89a1798 --- /dev/null +++ b/tlssniffer.go @@ -0,0 +1,374 @@ +package starnet + +import ( + "bytes" + "crypto/tls" + "io" + "net" + "sync" + "time" +) + +type myConn struct { + reader io.Reader + conn net.Conn + isReadOnly bool + multiReader io.Reader +} + +func (c *myConn) Read(p []byte) (int, error) { + if c.isReadOnly { + return c.reader.Read(p) + } + if c.multiReader == nil { + c.multiReader = io.MultiReader(c.reader, c.conn) + } + return c.multiReader.Read(p) +} + +func (c *myConn) Write(p []byte) (int, error) { + if c.isReadOnly { + return 0, io.ErrClosedPipe + } + return c.conn.Write(p) +} +func (c *myConn) Close() error { + if c.isReadOnly { + return nil + } + return c.conn.Close() +} +func (c *myConn) LocalAddr() net.Addr { + if c.isReadOnly { + return nil + } + return c.conn.LocalAddr() +} +func (c *myConn) RemoteAddr() net.Addr { + if c.isReadOnly { + return nil + } + return c.conn.RemoteAddr() +} +func (c *myConn) SetDeadline(t time.Time) error { + if c.isReadOnly { + return nil + } + return c.conn.SetDeadline(t) +} +func (c *myConn) SetReadDeadline(t time.Time) error { + if c.isReadOnly { + return nil + } + return c.conn.SetReadDeadline(t) +} +func (c *myConn) SetWriteDeadline(t time.Time) error { + if c.isReadOnly { + return nil + } + return c.conn.SetWriteDeadline(t) +} + +type Listener struct { + net.Listener + cfg *tls.Config + getConfigForClient func(hostname string) *tls.Config + allowNonTls bool +} + +func (l *Listener) GetConfigForClient() func(hostname string) *tls.Config { + return l.getConfigForClient +} + +func (l *Listener) SetConfigForClient(getConfigForClient func(hostname string) *tls.Config) { + l.getConfigForClient = getConfigForClient +} + +func Listen(network, address string) (*Listener, error) { + listener, err := net.Listen(network, address) + if err != nil { + return nil, err + } + return &Listener{Listener: listener}, nil +} + +func ListenTLSWithConfig(network, address string, config *tls.Config, getConfigForClient func(hostname string) *tls.Config, allowNonTls bool) (*Listener, error) { + listener, err := net.Listen(network, address) + if err != nil { + return nil, err + } + return &Listener{ + Listener: listener, + cfg: config, + getConfigForClient: getConfigForClient, + allowNonTls: allowNonTls, + }, nil +} + +func ListenTLS(network, address string, certFile, keyFile string, allowNonTls bool) (*Listener, error) { + config, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{config}, + } + + listener, err := net.Listen(network, address) + if err != nil { + return nil, err + } + + return &Listener{ + Listener: listener, + cfg: tlsConfig, + allowNonTls: allowNonTls, + }, nil +} + +func (l *Listener) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() + if err != nil { + return nil, err + } + return &Conn{ + Conn: conn, + tlsCfg: l.cfg, + getConfigForClient: l.getConfigForClient, + allowNonTls: l.allowNonTls, + }, nil +} + +type Conn struct { + net.Conn + once sync.Once + initErr error + isTLS bool + tlsCfg *tls.Config + tlsConn *tls.Conn + buffer *bytes.Buffer + noTlsReader io.Reader + isOriginal bool + getConfigForClient func(hostname string) *tls.Config + hostname string + allowNonTls bool +} + +func (c *Conn) Hostname() string { + if c.hostname != "" { + return c.hostname + } + if c.isTLS && c.tlsConn != nil { + if c.tlsConn.ConnectionState().ServerName != "" { + c.hostname = c.tlsConn.ConnectionState().ServerName + return c.hostname + } + } + return "" +} + +func (c *Conn) IsTLS() bool { + return c.isTLS +} + +func (c *Conn) TlsConn() *tls.Conn { + return c.tlsConn +} + +func (c *Conn) isTLSConnection() (bool, error) { + if c.getConfigForClient == nil { + peek := make([]byte, 5) + n, err := io.ReadFull(c.Conn, peek) + if err != nil { + return false, err + } + + isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03 + + c.buffer = bytes.NewBuffer(peek[:n]) + return isTLS, nil + } + + c.buffer = new(bytes.Buffer) + r := io.TeeReader(c.Conn, c.buffer) + var hello *tls.ClientHelloInfo + tls.Server(&myConn{reader: r, isReadOnly: true}, &tls.Config{ + GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { + hello = new(tls.ClientHelloInfo) + *hello = *argHello + return nil, nil + }, + }).Handshake() + peek := c.buffer.Bytes() + n := len(peek) + isTLS := n >= 3 && peek[0] == 0x16 && peek[1] == 0x03 + if hello == nil { + return isTLS, nil + } + c.hostname = hello.ServerName + if c.hostname == "" { + c.hostname, _, _ = net.SplitHostPort(c.Conn.RemoteAddr().String()) + } + return isTLS, nil +} + +func (c *Conn) init() { + c.once.Do(func() { + if c.isOriginal { + return + } + if c.tlsCfg != nil { + isTLS, err := c.isTLSConnection() + if err != nil { + c.initErr = err + return + } + c.isTLS = isTLS + } + + if c.isTLS { + var cfg = c.tlsCfg + if c.getConfigForClient != nil { + cfg = c.getConfigForClient(c.hostname) + if cfg == nil { + cfg = c.tlsCfg + } + } + c.tlsConn = tls.Server(&myConn{ + reader: c.buffer, + conn: c.Conn, + isReadOnly: false, + }, cfg) + } else { + if !c.allowNonTls { + c.initErr = net.ErrClosed + return + } + c.noTlsReader = io.MultiReader(c.buffer, c.Conn) + } + }) +} + +func (c *Conn) Read(b []byte) (int, error) { + c.init() + if c.initErr != nil { + return 0, c.initErr + } + if c.isTLS { + return c.tlsConn.Read(b) + } + return c.noTlsReader.Read(b) +} + +func (c *Conn) Write(b []byte) (int, error) { + c.init() + if c.initErr != nil { + return 0, c.initErr + } + + if c.isTLS { + return c.tlsConn.Write(b) + } + return c.Conn.Write(b) +} + +func (c *Conn) Close() error { + if c.isTLS && c.tlsConn != nil { + return c.tlsConn.Close() + } + return c.Conn.Close() +} + +func (c *Conn) SetDeadline(t time.Time) error { + if c.isTLS && c.tlsConn != nil { + return c.tlsConn.SetDeadline(t) + } + return c.Conn.SetDeadline(t) +} + +func (c *Conn) SetReadDeadline(t time.Time) error { + if c.isTLS && c.tlsConn != nil { + return c.tlsConn.SetReadDeadline(t) + } + return c.Conn.SetReadDeadline(t) +} + +func (c *Conn) SetWriteDeadline(t time.Time) error { + if c.isTLS && c.tlsConn != nil { + return c.tlsConn.SetWriteDeadline(t) + } + return c.Conn.SetWriteDeadline(t) +} + +func (c *Conn) TlsConnection() (*tls.Conn, error) { + if c.initErr != nil { + return nil, c.initErr + } + if !c.isTLS { + return nil, net.ErrClosed + } + return c.tlsConn, nil +} + +func NewClientTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) { + if conn == nil { + return nil, net.ErrClosed + } + c := &Conn{ + Conn: conn, + isTLS: true, + tlsCfg: cfg, + tlsConn: tls.Client(conn, cfg), + isOriginal: true, + } + return c, nil +} + +func NewServerTlsConn(conn net.Conn, cfg *tls.Config) (*Conn, error) { + if conn == nil { + return nil, net.ErrClosed + } + c := &Conn{ + Conn: conn, + isTLS: true, + tlsCfg: cfg, + tlsConn: tls.Server(conn, cfg), + isOriginal: true, + } + c.init() + return c, nil +} + +func Dial(network, address string) (*Conn, error) { + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + return &Conn{ + Conn: conn, + isTLS: false, + tlsCfg: nil, + tlsConn: nil, + noTlsReader: conn, + isOriginal: true, + }, nil +} + +func DialTLS(network, address string, certFile, keyFile string) (*Conn, error) { + config, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{config}, + } + + conn, err := net.Dial(network, address) + if err != nil { + return nil, err + } + + return NewClientTlsConn(conn, tlsConfig) +}