starnet/tlssniffer.go
2026-03-08 20:19:40 +08:00

501 lines
12 KiB
Go

package starnet
import (
"bytes"
"context"
"crypto/tls"
"io"
"net"
"sync"
"time"
)
// replayConn is a net.Conn whose read side first drains a buffer of bytes that
// were already consumed from conn (for example while sniffing) and then
// continues with conn itself. Writes, Close, addresses and deadlines are all
// forwarded straight to conn.
type replayConn struct {
	reader io.Reader // buffered bytes followed by conn
	conn   net.Conn  // the live connection
}

// newReplayConn returns a replayConn that yields everything in buffered before
// reading from conn.
func newReplayConn(buffered io.Reader, conn net.Conn) *replayConn {
	rc := &replayConn{conn: conn}
	rc.reader = io.MultiReader(buffered, conn)
	return rc
}

// Read serves replayed bytes first, then live bytes from the connection.
func (rc *replayConn) Read(p []byte) (int, error) {
	return rc.reader.Read(p)
}

// Write goes directly to the live connection.
func (rc *replayConn) Write(p []byte) (int, error) {
	return rc.conn.Write(p)
}

// Close closes the live connection.
func (rc *replayConn) Close() error {
	return rc.conn.Close()
}

// LocalAddr reports the live connection's local address.
func (rc *replayConn) LocalAddr() net.Addr {
	return rc.conn.LocalAddr()
}

// RemoteAddr reports the live connection's remote address.
func (rc *replayConn) RemoteAddr() net.Addr {
	return rc.conn.RemoteAddr()
}

// SetDeadline forwards to the live connection.
func (rc *replayConn) SetDeadline(t time.Time) error {
	return rc.conn.SetDeadline(t)
}

// SetReadDeadline forwards to the live connection.
func (rc *replayConn) SetReadDeadline(t time.Time) error {
	return rc.conn.SetReadDeadline(t)
}

// SetWriteDeadline forwards to the live connection.
func (rc *replayConn) SetWriteDeadline(t time.Time) error {
	return rc.conn.SetWriteDeadline(t)
}
// SniffResult describes the outcome of protocol sniffing on a connection.
type SniffResult struct {
	// IsTLS reports whether the first bytes look like a TLS handshake record
	// (content type 0x16 followed by major version 0x03).
	IsTLS bool
	// Hostname is the SNI server name from the ClientHello, or "" if no
	// ClientHello could be parsed or it carried no server name.
	Hostname string
	// Buffer holds a copy of every byte consumed from the connection while
	// sniffing. Callers must replay it before reading further from the conn.
	Buffer *bytes.Buffer
}

// Sniffer detects the protocol and metadata of a connection from its initial
// bytes, reading at most maxBytes from conn.
type Sniffer interface {
	Sniff(conn net.Conn, maxBytes int) (SniffResult, error)
}

// TLSSniffer is the default Sniffer implementation. It lets crypto/tls parse
// the ClientHello so that SNI extraction follows the real TLS parser.
type TLSSniffer struct{}

// Sniff reads up to maxBytes (64 KiB when maxBytes <= 0) from conn and reports
// whether the stream looks like TLS and, if a ClientHello was parsed, its SNI
// hostname. All consumed bytes are returned in SniffResult.Buffer. Sniff itself
// never writes to conn and always returns a nil error.
//
// If the ClientHello is longer than maxBytes, it cannot be parsed. The result
// may then still report IsTLS (based on the record header) with an empty
// Hostname.
func (TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
	if maxBytes <= 0 {
		maxBytes = 64 * 1024
	}
	var buf bytes.Buffer
	limited := &io.LimitedReader{R: conn, N: int64(maxBytes)}
	tee := io.TeeReader(limited, &buf)

	var (
		sawHello   bool
		serverName string
	)
	// Run a throwaway server handshake purely so crypto/tls parses the
	// ClientHello. The callback returns an error, which makes crypto/tls abort
	// the handshake right after the ClientHello is received. Returning
	// (nil, nil) would instead continue into server-side handshake processing,
	// such as TLS 1.3 key exchange, before failing for lack of certificates.
	// That is wasted work on every accepted connection.
	// The error value is irrelevant: the handshake result is discarded.
	probe := tls.Server(readOnlyConn{r: tee, raw: conn}, &tls.Config{
		GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
			sawHello = true
			serverName = ch.ServerName
			return nil, io.EOF
		},
	})
	_ = probe.Handshake() // expected to fail; only the captured data matters

	peek := buf.Bytes()
	out := SniffResult{
		IsTLS:  len(peek) >= 3 && peek[0] == 0x16 && peek[1] == 0x03,
		Buffer: bytes.NewBuffer(append([]byte(nil), peek...)),
	}
	if sawHello {
		out.Hostname = serverName
	}
	return out, nil
}

// readOnlyConn adapts a reader into a net.Conn for the throwaway sniffing
// handshake. Writes always fail with io.ErrClosedPipe, so nothing is ever
// sent to the peer. Close is a no-op, so the real connection stays open.
// Deadline setters are no-ops, so the probe cannot change deadlines on raw.
type readOnlyConn struct {
	r   io.Reader // source of handshake bytes
	raw net.Conn  // only used to report addresses
}

func (ro readOnlyConn) Read(p []byte) (int, error)         { return ro.r.Read(p) }
func (ro readOnlyConn) Write(p []byte) (int, error)        { return 0, io.ErrClosedPipe }
func (ro readOnlyConn) Close() error                       { return nil }
func (ro readOnlyConn) LocalAddr() net.Addr                { return ro.raw.LocalAddr() }
func (ro readOnlyConn) RemoteAddr() net.Addr               { return ro.raw.RemoteAddr() }
func (ro readOnlyConn) SetDeadline(_ time.Time) error      { return nil }
func (ro readOnlyConn) SetReadDeadline(_ time.Time) error  { return nil }
func (ro readOnlyConn) SetWriteDeadline(_ time.Time) error { return nil }
// Conn wraps a net.Conn and decides lazily, on first use, whether the peer
// speaks TLS or plain TCP. Every I/O and deadline method calls init, which runs
// detection exactly once.
type Conn struct {
	net.Conn // raw underlying connection

	once    sync.Once // guards the one-time detection in init
	initErr error     // set when detection fails; returned by I/O methods

	closeOnce sync.Once // makes Close idempotent

	isTLS     bool      // detected (or preset) protocol
	tlsConn   *tls.Conn // TLS wrapper used when isTLS is true
	plainConn net.Conn  // raw conn, or a replaying conn after plain TCP was sniffed
	hostname  string    // SNI hostname captured while sniffing, if any

	baseTLSConfig      *tls.Config            // fallback server TLS config
	getConfigForClient GetConfigForClientFunc // optional per-hostname config selector
	allowNonTLS        bool                   // whether plain TCP peers are accepted
	sniffer            Sniffer                // protocol detector used by init
	sniffTimeout       time.Duration          // read deadline while sniffing; 0 disables it
	maxClientHello     int                    // byte cap handed to the sniffer
	logger             Logger                 // optional logger for init failures
	stats              *Stats                 // optional shared counters
	skipSniff          bool                   // protocol already known (e.g. DialTLSWithConfig)
}
// newConn builds a Conn around raw using the settings from cfg. Protocol
// detection is deferred to the first I/O call. Until then plainConn is raw, and
// the default TLSSniffer is used for detection.
func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn {
	c := &Conn{
		Conn:      raw,
		plainConn: raw,
		sniffer:   TLSSniffer{},
		stats:     stats,
	}
	c.baseTLSConfig = cfg.BaseTLSConfig
	c.getConfigForClient = cfg.GetConfigForClient
	c.allowNonTLS = cfg.AllowNonTLS
	c.sniffTimeout = cfg.SniffTimeout
	c.maxClientHello = cfg.MaxClientHelloBytes
	c.logger = cfg.Logger
	return c
}
// init runs protocol detection exactly once.
//
// Outcomes:
//   - skipSniff set: keep the preset isTLS/tlsConn untouched.
//   - no TLS config at all: treat the connection as plain TCP without reading.
//   - TLS detected: build a tls.Server over a conn that replays the sniffed bytes.
//   - plain detected: reject it (ErrNonTLSNotAllowed) unless allowNonTLS, else
//     replace plainConn with a replaying conn.
//
// On any failure, initErr is set and the connection is closed via failAndClose.
func (c *Conn) init() {
	c.once.Do(func() {
		if c.skipSniff {
			return
		}
		if c.baseTLSConfig == nil && c.getConfigForClient == nil {
			c.isTLS = false
			return
		}

		bounded := c.sniffTimeout > 0
		if bounded {
			_ = c.Conn.SetReadDeadline(time.Now().Add(c.sniffTimeout))
		}
		result, sniffErr := c.sniffer.Sniff(c.Conn, c.maxClientHello)
		if bounded {
			// Clear the sniff deadline so it does not carry over into normal I/O.
			_ = c.Conn.SetReadDeadline(time.Time{})
		}
		if sniffErr != nil {
			c.initErr = sniffErr
			c.failAndClose("sniff failed: %v", sniffErr)
			return
		}

		c.isTLS = result.IsTLS
		c.hostname = result.Hostname

		if !c.isTLS {
			if c.stats != nil {
				c.stats.incPlainDetected()
			}
			if !c.allowNonTLS {
				c.initErr = ErrNonTLSNotAllowed
				c.failAndClose("plain tcp rejected")
				return
			}
			// Bytes consumed by the sniffer must be seen again by the reader.
			c.plainConn = newReplayConn(bytes.NewBuffer(result.Buffer.Bytes()), c.Conn)
			return
		}

		if c.stats != nil {
			c.stats.incTLSDetected()
		}
		serverCfg, cfgErr := c.selectTLSConfig()
		if cfgErr != nil {
			c.initErr = cfgErr
			c.failAndClose("tls config select failed: %v", cfgErr)
			return
		}
		// The TLS server must start from the ClientHello the sniffer consumed.
		c.tlsConn = tls.Server(newReplayConn(bytes.NewBuffer(result.Buffer.Bytes()), c.Conn), serverCfg)
	})
}
// failAndClose records an initialization failure in stats (if any) and logs
// it with a "starnet: " prefix (if a logger is set). It then closes the
// connection; the Close error is intentionally ignored.
func (c *Conn) failAndClose(format string, v ...interface{}) {
	if s := c.stats; s != nil {
		s.incInitFailures()
	}
	if lg := c.logger; lg != nil {
		lg.Printf("starnet: "+format, v...)
	}
	_ = c.Close()
}
// selectTLSConfig picks the server TLS config for this connection.
// getConfigForClient (keyed by the sniffed hostname) wins when it returns a
// non-nil config. Its error aborts selection. Otherwise baseTLSConfig is used.
// ErrNoTLSConfig is returned when neither source yields a config.
func (c *Conn) selectTLSConfig() (*tls.Config, error) {
	if pick := c.getConfigForClient; pick != nil {
		chosen, err := pick(c.hostname)
		switch {
		case err != nil:
			return nil, err
		case chosen != nil:
			return chosen, nil
		}
	}
	if base := c.baseTLSConfig; base != nil {
		return base, nil
	}
	return nil, ErrNoTLSConfig
}
// Hostname returns the SNI hostname captured during protocol detection, or ""
// if none was seen. The first call triggers detection.
func (c *Conn) Hostname() string {
	c.init()
	name := c.hostname
	return name
}
// IsTLS reports whether the connection was detected (or configured) as TLS.
// It returns false if detection failed. The first call triggers detection.
func (c *Conn) IsTLS() bool {
	c.init()
	if c.initErr != nil {
		return false
	}
	return c.isTLS
}
// TLSConn returns the underlying *tls.Conn. It returns the detection error if
// initialization failed, or ErrNotTLS for plain connections.
func (c *Conn) TLSConn() (*tls.Conn, error) {
	c.init()
	switch {
	case c.initErr != nil:
		return nil, c.initErr
	case c.isTLS && c.tlsConn != nil:
		return c.tlsConn, nil
	default:
		return nil, ErrNotTLS
	}
}
// Read triggers detection on first use. It then reads from the TLS conn or the
// plain conn, or returns the detection error.
func (c *Conn) Read(b []byte) (int, error) {
	c.init()
	if err := c.initErr; err != nil {
		return 0, err
	}
	if !c.isTLS {
		return c.plainConn.Read(b)
	}
	return c.tlsConn.Read(b)
}
// Write triggers detection on first use. It then writes to the TLS conn or the
// plain conn, or returns the detection error. Detection may read from the peer
// before anything is written.
func (c *Conn) Write(b []byte) (int, error) {
	c.init()
	if err := c.initErr; err != nil {
		return 0, err
	}
	if !c.isTLS {
		return c.plainConn.Write(b)
	}
	return c.tlsConn.Write(b)
}
// Close closes the connection once; later calls return nil. A TLS connection
// is closed through its *tls.Conn, otherwise the raw conn is closed directly.
// The closed counter in stats is incremented if stats is set.
//
// NOTE(review): Close does not call init and reads c.tlsConn without
// synchronization. If it runs concurrently with init (e.g. from another
// goroutine while a Read is blocked in sniffing), this looks like a data race.
// Confirm with -race.
func (c *Conn) Close() error {
	var closeErr error
	c.closeOnce.Do(func() {
		closer := io.Closer(c.Conn)
		if c.tlsConn != nil {
			closer = c.tlsConn
		}
		closeErr = closer.Close()
		if c.stats != nil {
			c.stats.incClosed()
		}
	})
	return closeErr
}
// SetDeadline triggers detection on first use. It then applies t to the active
// (TLS or plain) connection, or returns the detection error.
func (c *Conn) SetDeadline(t time.Time) error {
	c.init()
	switch {
	case c.initErr != nil:
		return c.initErr
	case c.isTLS && c.tlsConn != nil:
		return c.tlsConn.SetDeadline(t)
	default:
		return c.plainConn.SetDeadline(t)
	}
}
// SetReadDeadline triggers detection on first use. It then applies t to the
// active (TLS or plain) connection, or returns the detection error.
func (c *Conn) SetReadDeadline(t time.Time) error {
	c.init()
	switch {
	case c.initErr != nil:
		return c.initErr
	case c.isTLS && c.tlsConn != nil:
		return c.tlsConn.SetReadDeadline(t)
	default:
		return c.plainConn.SetReadDeadline(t)
	}
}
// SetWriteDeadline triggers detection on first use. It then applies t to the
// active (TLS or plain) connection, or returns the detection error.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	c.init()
	switch {
	case c.initErr != nil:
		return c.initErr
	case c.isTLS && c.tlsConn != nil:
		return c.tlsConn.SetWriteDeadline(t)
	default:
		return c.plainConn.SetWriteDeadline(t)
	}
}
// Listener wraps a net.Listener so that Accept returns *starnet.Conn values
// configured from the listener's current ListenerConfig. A Listener contains a
// mutex and must not be copied.
type Listener struct {
	net.Listener

	mu    sync.RWMutex   // guards cfg
	cfg   ListenerConfig // template applied to each newly accepted conn
	stats Stats          // counters shared by all accepted conns
}
// Listen announces on the local address and returns a Listener that performs
// no TLS detection. Accepted connections are always treated as plain TCP.
func Listen(network, address string) (*Listener, error) {
	inner, err := net.Listen(network, address)
	if err != nil {
		return nil, err
	}
	plain := DefaultListenerConfig()
	plain.AllowNonTLS = true
	plain.BaseTLSConfig, plain.GetConfigForClient = nil, nil
	return &Listener{Listener: inner, cfg: plain}, nil
}
// ListenWithConfig announces on the local address and returns a Listener using
// cfg (normalized via normalizeConfig).
func ListenWithConfig(network, address string, cfg ListenerConfig) (*Listener, error) {
	// net.Listen is a zero-value net.ListenConfig listening with a background
	// context, so delegating here is equivalent.
	return ListenWithListenConfig(net.ListenConfig{}, network, address, cfg)
}
// ListenWithListenConfig announces on the local address using lc (with a
// background context) and returns a Listener using the normalized cfg.
func ListenWithListenConfig(lc net.ListenConfig, network, address string, cfg ListenerConfig) (*Listener, error) {
	inner, err := lc.Listen(context.Background(), network, address)
	if err != nil {
		return nil, err
	}
	wrapped := &Listener{Listener: inner}
	wrapped.cfg = normalizeConfig(cfg)
	return wrapped, nil
}
// ListenTLS loads a certificate/key pair from disk and returns a Listener
// that terminates TLS with it. Plain TCP peers are accepted only when
// allowNonTLS is true.
func ListenTLS(network, address, certFile, keyFile string, allowNonTLS bool) (*Listener, error) {
	pair, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return nil, err
	}
	cfg := DefaultListenerConfig()
	cfg.AllowNonTLS = allowNonTLS
	serverTLS := TLSDefaults()
	serverTLS.Certificates = []tls.Certificate{pair}
	cfg.BaseTLSConfig = serverTLS
	return ListenWithConfig(network, address, cfg)
}
// normalizeConfig starts from DefaultListenerConfig and copies over the
// user-settable fields from cfg. A non-positive MaxClientHelloBytes becomes
// 64 KiB.
func normalizeConfig(cfg ListenerConfig) ListenerConfig {
	out := DefaultListenerConfig()
	out.AllowNonTLS = cfg.AllowNonTLS
	out.SniffTimeout = cfg.SniffTimeout
	out.BaseTLSConfig, out.GetConfigForClient = cfg.BaseTLSConfig, cfg.GetConfigForClient
	out.Logger = cfg.Logger
	if cfg.MaxClientHelloBytes > 0 {
		out.MaxClientHelloBytes = cfg.MaxClientHelloBytes
	} else {
		out.MaxClientHelloBytes = 64 * 1024
	}
	return out
}
// SetConfig replaces, under the listener's lock, the config used for
// connections accepted from now on. Already accepted connections keep theirs.
func (l *Listener) SetConfig(cfg ListenerConfig) {
	next := normalizeConfig(cfg)
	l.mu.Lock()
	defer l.mu.Unlock()
	l.cfg = next
}
// Config returns a copy of the current config. The copy is shallow: pointer
// fields such as BaseTLSConfig are shared with the listener.
func (l *Listener) Config() ListenerConfig {
	l.mu.RLock()
	defer l.mu.RUnlock()
	return l.cfg
}
// Stats returns a snapshot of the listener's connection counters.
func (l *Listener) Stats() StatsSnapshot {
	counters := &l.stats
	return counters.Snapshot()
}
// Accept waits for the next connection and returns it as a *starnet.Conn built
// from a snapshot of the current config. Protocol detection happens lazily on
// the returned conn's first use.
func (l *Listener) Accept() (net.Conn, error) {
	raw, err := l.Listener.Accept()
	if err != nil {
		return nil, err
	}
	l.stats.incAccepted()
	return newConn(raw, l.Config(), &l.stats), nil
}
// AcceptContext is like Accept, but it returns ctx.Err() as soon as ctx is done
// if no connection has been returned yet.
//
// net.Listener cannot interrupt a single Accept, so the underlying Accept keeps
// running in a background goroutine until a connection arrives or the listener
// is closed. If a connection arrives after ctx was done, it is closed instead of
// being leaked.
func (l *Listener) AcceptContext(ctx context.Context) (net.Conn, error) {
	// Don't start a background Accept for a context that is already done.
	if err := ctx.Err(); err != nil {
		return nil, err
	}
	type result struct {
		c   net.Conn
		err error
	}
	// Buffered so the accept goroutine can always deliver its result and exit.
	ch := make(chan result, 1)
	go func() {
		c, err := l.Accept()
		ch <- result{c: c, err: err}
	}()
	select {
	case r := <-ch:
		return r.c, r.err
	case <-ctx.Done():
		// Nobody will receive the pending connection any more; close it when
		// (if) it arrives so the socket is not leaked.
		go func() {
			if r := <-ch; r.c != nil {
				_ = r.c.Close()
			}
		}()
		return nil, ctx.Err()
	}
}
// Dial connects to address and returns a plain (non-TLS) *starnet.Conn. No
// TLS configuration is attached, so the conn is never sniffed.
func Dial(network, address string) (*Conn, error) {
	raw, err := net.Dial(network, address)
	if err != nil {
		return nil, err
	}
	plain := DefaultListenerConfig()
	plain.AllowNonTLS = true
	plain.BaseTLSConfig, plain.GetConfigForClient = nil, nil
	conn := newConn(raw, plain, nil)
	conn.isTLS = false
	return conn, nil
}
// DialWithConfig connects to address using a net.Dialer configured from dc
// (Timeout, LocalAddr) and returns a plain (non-TLS) *starnet.Conn.
func DialWithConfig(network, address string, dc DialConfig) (*Conn, error) {
	d := net.Dialer{
		Timeout:   dc.Timeout,
		LocalAddr: dc.LocalAddr,
	}
	raw, err := d.Dial(network, address)
	if err != nil {
		return nil, err
	}
	cfg := DefaultListenerConfig()
	cfg.AllowNonTLS = true
	// Clear the TLS settings exactly like Dial does. With any TLS config
	// present, init would sniff this client-side conn. The first Read or Write
	// would then block reading from the server before doing anything.
	cfg.BaseTLSConfig = nil
	cfg.GetConfigForClient = nil
	c := newConn(raw, cfg, nil)
	c.isTLS = false
	return c, nil
}
// DialTLSWithConfig connects to address and wraps the connection in a TLS
// client using tlsCfg. The handshake is not run here: crypto/tls performs it
// on the first Read or Write.
//
// Like crypto/tls.Dial, a nil tlsCfg is treated as an empty config. When
// ServerName is unset, it is derived from the host part of address, so SNI is
// sent and the server certificate can be verified. The caller's config is
// cloned, never mutated. timeout bounds only the TCP dial (0 = no dialer
// timeout).
func DialTLSWithConfig(network, address string, tlsCfg *tls.Config, timeout time.Duration) (*Conn, error) {
	d := net.Dialer{Timeout: timeout}
	raw, err := d.Dial(network, address)
	if err != nil {
		return nil, err
	}
	cfg := tlsCfg
	if cfg == nil {
		cfg = &tls.Config{}
	}
	if cfg.ServerName == "" {
		host, _, splitErr := net.SplitHostPort(address)
		if splitErr != nil {
			// No port in address: use it verbatim as the host.
			host = address
		}
		cfg = cfg.Clone()
		cfg.ServerName = host
	}
	return &Conn{
		Conn:        raw,
		plainConn:   raw,
		isTLS:       true,
		tlsConn:     tls.Client(raw, cfg),
		allowNonTLS: false,
		// The protocol is known up front, so init must not sniff.
		skipSniff: true,
	}, nil
}
// DialTLS loads a client certificate/key pair from disk and dials address over
// TLS, using TLSDefaults() plus that certificate and no dial timeout.
func DialTLS(network, address, certFile, keyFile string) (*Conn, error) {
	pair, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return nil, err
	}
	clientCfg := TLSDefaults()
	clientCfg.Certificates = []tls.Certificate{pair}
	const noTimeout = 0
	return DialTLSWithConfig(network, address, clientCfg, noTimeout)
}
// WrapListener adapts an existing net.Listener into a *Listener using the
// normalized cfg. It returns ErrNilConn if listener is nil.
func WrapListener(listener net.Listener, cfg ListenerConfig) (*Listener, error) {
	if listener != nil {
		return &Listener{Listener: listener, cfg: normalizeConfig(cfg)}, nil
	}
	return nil, ErrNilConn
}