2025-06-12 16:50:47 +08:00
|
|
|
package starnet
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"bytes"
|
2025-06-17 12:09:12 +08:00
|
|
|
"context"
|
2025-06-12 16:50:47 +08:00
|
|
|
"crypto/tls"
|
2026-03-27 12:05:23 +08:00
|
|
|
"errors"
|
2025-06-12 16:50:47 +08:00
|
|
|
"io"
|
|
|
|
|
"net"
|
|
|
|
|
"sync"
|
|
|
|
|
"time"
|
2026-04-19 15:39:51 +08:00
|
|
|
|
|
|
|
|
"b612.me/starnet/internal/tlssniffercore"
|
2025-06-12 16:50:47 +08:00
|
|
|
)
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
// replayConn replays buffered bytes first, then reads from live conn.
|
|
|
|
|
type replayConn struct {
|
|
|
|
|
reader io.Reader
|
|
|
|
|
conn net.Conn
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
func newReplayConn(buffered io.Reader, conn net.Conn) *replayConn {
|
|
|
|
|
return &replayConn{
|
|
|
|
|
reader: io.MultiReader(buffered, conn),
|
|
|
|
|
conn: conn,
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
func (c *replayConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
|
|
|
|
|
func (c *replayConn) Write(p []byte) (int, error) { return c.conn.Write(p) }
|
|
|
|
|
func (c *replayConn) Close() error { return c.conn.Close() }
|
|
|
|
|
func (c *replayConn) LocalAddr() net.Addr { return c.conn.LocalAddr() }
|
|
|
|
|
func (c *replayConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() }
|
|
|
|
|
func (c *replayConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) }
|
|
|
|
|
func (c *replayConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) }
|
|
|
|
|
func (c *replayConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) }
|
2025-06-12 16:50:47 +08:00
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
// SniffResult describes protocol sniffing result.
|
|
|
|
|
type SniffResult struct {
|
2026-03-27 12:05:23 +08:00
|
|
|
IsTLS bool
|
|
|
|
|
ClientHello *ClientHelloMeta
|
|
|
|
|
Buffer *bytes.Buffer
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
// Sniffer detects protocol and metadata from initial bytes.
|
|
|
|
|
type Sniffer interface {
|
|
|
|
|
Sniff(conn net.Conn, maxBytes int) (SniffResult, error)
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
// TLSSniffer is the default sniffer implementation.
|
|
|
|
|
type TLSSniffer struct{}
|
2025-06-12 16:50:47 +08:00
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
// Sniff detects TLS and extracts SNI when possible.
|
|
|
|
|
func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
|
2026-04-19 15:39:51 +08:00
|
|
|
res, err := (tlssniffercore.Sniffer{}).Sniff(conn, maxBytes)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return SniffResult{}, err
|
2026-03-27 12:05:23 +08:00
|
|
|
}
|
2026-04-19 15:39:51 +08:00
|
|
|
return convertCoreSniffResult(res), nil
|
2026-03-27 12:05:23 +08:00
|
|
|
}
|
|
|
|
|
|
2026-04-19 15:39:51 +08:00
|
|
|
func convertCoreSniffResult(res tlssniffercore.SniffResult) SniffResult {
|
|
|
|
|
out := SniffResult{
|
|
|
|
|
IsTLS: res.IsTLS,
|
|
|
|
|
Buffer: res.Buffer,
|
2026-03-27 12:05:23 +08:00
|
|
|
}
|
2026-04-19 15:39:51 +08:00
|
|
|
if res.ClientHello != nil {
|
|
|
|
|
out.ClientHello = convertCoreClientHelloMeta(res.ClientHello)
|
2026-03-27 12:05:23 +08:00
|
|
|
}
|
2026-04-19 15:39:51 +08:00
|
|
|
return out
|
2026-03-27 12:05:23 +08:00
|
|
|
}
|
|
|
|
|
|
2026-04-19 15:39:51 +08:00
|
|
|
func convertCoreClientHelloMeta(meta *tlssniffercore.ClientHelloMeta) *ClientHelloMeta {
|
|
|
|
|
if meta == nil {
|
|
|
|
|
return nil
|
2026-03-27 12:05:23 +08:00
|
|
|
}
|
2026-04-19 15:39:51 +08:00
|
|
|
return &ClientHelloMeta{
|
|
|
|
|
ServerName: meta.ServerName,
|
|
|
|
|
LocalAddr: meta.LocalAddr,
|
|
|
|
|
RemoteAddr: meta.RemoteAddr,
|
|
|
|
|
SupportedProtos: append([]string(nil), meta.SupportedProtos...),
|
|
|
|
|
SupportedVersions: append([]uint16(nil), meta.SupportedVersions...),
|
|
|
|
|
CipherSuites: append([]uint16(nil), meta.CipherSuites...),
|
2026-03-27 12:05:23 +08:00
|
|
|
}
|
|
|
|
|
}
|
2026-03-08 20:19:40 +08:00
|
|
|
|
|
|
|
|
// Conn wraps net.Conn with lazy protocol initialization.
|
2025-06-12 16:50:47 +08:00
|
|
|
type Conn struct {
|
|
|
|
|
net.Conn
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
once sync.Once
|
|
|
|
|
initErr error
|
|
|
|
|
closeOnce sync.Once
|
2025-06-12 16:50:47 +08:00
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
isTLS bool
|
|
|
|
|
tlsConn *tls.Conn
|
|
|
|
|
plainConn net.Conn
|
2025-06-12 16:50:47 +08:00
|
|
|
|
2026-03-27 12:05:23 +08:00
|
|
|
clientHello *ClientHelloMeta
|
2025-06-12 16:50:47 +08:00
|
|
|
|
2026-03-27 12:05:23 +08:00
|
|
|
baseTLSConfig *tls.Config
|
|
|
|
|
getConfigForClient GetConfigForClientFunc
|
|
|
|
|
getConfigForClientHello GetConfigForClientHelloFunc
|
|
|
|
|
allowNonTLS bool
|
|
|
|
|
sniffer Sniffer
|
|
|
|
|
sniffTimeout time.Duration
|
|
|
|
|
maxClientHello int
|
|
|
|
|
logger Logger
|
|
|
|
|
stats *Stats
|
|
|
|
|
skipSniff bool
|
2026-03-08 20:19:40 +08:00
|
|
|
}
|
2025-06-12 16:50:47 +08:00
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn {
|
|
|
|
|
return &Conn{
|
2026-04-19 15:39:51 +08:00
|
|
|
Conn: raw,
|
|
|
|
|
plainConn: raw,
|
|
|
|
|
baseTLSConfig: cfg.BaseTLSConfig,
|
|
|
|
|
getConfigForClient: cfg.GetConfigForClient,
|
2026-03-27 12:05:23 +08:00
|
|
|
getConfigForClientHello: cfg.GetConfigForClientHello,
|
2026-04-19 15:39:51 +08:00
|
|
|
allowNonTLS: cfg.AllowNonTLS,
|
|
|
|
|
sniffer: TLSSniffer{},
|
|
|
|
|
sniffTimeout: cfg.SniffTimeout,
|
|
|
|
|
maxClientHello: cfg.MaxClientHelloBytes,
|
|
|
|
|
logger: cfg.Logger,
|
|
|
|
|
stats: stats,
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// init performs one-time protocol detection for the connection. It is called
// lazily from the I/O, deadline and metadata methods and runs at most once
// (guarded by c.once). Outcomes:
//   - skipSniff set: nothing to do; the constructor preset the protocol fields.
//   - no TLS config source configured: served as plain TCP without sniffing.
//   - sniff error: initErr wraps ErrTLSSniffFailed and the conn is closed.
//   - TLS detected: a tls.Server is created over a replay of the sniffed
//     bytes; if no config can be selected, initErr wraps
//     ErrTLSConfigSelectionFailed and the conn is closed.
//   - plain detected: rejected with ErrNonTLSNotAllowed (conn closed) unless
//     allowNonTLS, otherwise served through a replay of the sniffed bytes.
func (c *Conn) init() {
	c.once.Do(func() {
		if c.skipSniff {
			return
		}
		// Without any way to obtain a TLS config the conn can only be plain;
		// skip sniffing entirely so no bytes are consumed.
		if c.baseTLSConfig == nil && c.getConfigForClient == nil && c.getConfigForClientHello == nil {
			c.isTLS = false
			return
		}
		if c.sniffer == nil {
			// Defensive: a Conn not built by newConn would otherwise panic here.
			c.sniffer = TLSSniffer{}
		}

		// Bound the sniff read so a silent peer cannot stall init forever,
		// then clear the deadline so it does not leak into user I/O.
		if c.sniffTimeout > 0 {
			_ = c.Conn.SetReadDeadline(time.Now().Add(c.sniffTimeout))
		}
		res, err := c.sniffer.Sniff(c.Conn, c.maxClientHello)
		if c.sniffTimeout > 0 {
			_ = c.Conn.SetReadDeadline(time.Time{})
		}
		if err != nil {
			c.initErr = errors.Join(ErrTLSSniffFailed, err)
			c.failSniff(err)
			return
		}

		c.isTLS = res.IsTLS
		c.clientHello = res.ClientHello

		// The sniffer consumed bytes from the socket; they must be replayed
		// ahead of the live stream. A nil buffer means nothing to replay
		// (previously this dereferenced nil and panicked).
		var sniffed []byte
		if res.Buffer != nil {
			sniffed = res.Buffer.Bytes()
		}

		if c.isTLS {
			if c.stats != nil {
				c.stats.incTLSDetected()
			}
			tlsCfg, errCfg := c.selectTLSConfig()
			if errCfg != nil {
				c.initErr = errors.Join(ErrTLSConfigSelectionFailed, errCfg)
				c.failTLSConfigSelection(errCfg)
				return
			}
			rc := newReplayConn(bytes.NewBuffer(sniffed), c.Conn)
			c.tlsConn = tls.Server(rc, tlsCfg)
			return
		}

		if c.stats != nil {
			c.stats.incPlainDetected()
		}
		if !c.allowNonTLS {
			c.initErr = ErrNonTLSNotAllowed
			c.failPlainRejected()
			return
		}
		c.plainConn = newReplayConn(bytes.NewBuffer(sniffed), c.Conn)
	})
}
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
func (c *Conn) failAndClose(format string, v ...interface{}) {
|
|
|
|
|
if c.logger != nil {
|
|
|
|
|
c.logger.Printf("starnet: "+format, v...)
|
|
|
|
|
}
|
|
|
|
|
_ = c.Close()
|
|
|
|
|
}
|
|
|
|
|
|
2026-03-27 12:05:23 +08:00
|
|
|
func (c *Conn) failSniff(err error) {
|
|
|
|
|
if c.stats != nil {
|
|
|
|
|
c.stats.incSniffFailures()
|
|
|
|
|
}
|
|
|
|
|
c.failAndClose("tls sniff failed: %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Conn) failTLSConfigSelection(err error) {
|
|
|
|
|
if c.stats != nil {
|
|
|
|
|
c.stats.incTLSConfigFailures()
|
|
|
|
|
}
|
|
|
|
|
c.failAndClose("tls config selection failed: %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Conn) failPlainRejected() {
|
|
|
|
|
if c.stats != nil {
|
|
|
|
|
c.stats.incPlainRejected()
|
|
|
|
|
}
|
|
|
|
|
c.failAndClose("plain tcp rejected")
|
|
|
|
|
}
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
// selectTLSConfig picks the server TLS config for this connection.
//
// getConfigForClientHello (given a clone of the sniffed ClientHello) is
// consulted first. If it is absent or returns nil, getConfigForClient (given
// the SNI server name) is tried next. The result is merged with
// baseTLSConfig. Callback errors are returned as-is. If nothing usable
// results, ErrNoTLSConfig is returned.
func (c *Conn) selectTLSConfig() (*tls.Config, error) {
	var chosen *tls.Config

	if fn := c.getConfigForClientHello; fn != nil {
		cfg, err := fn(c.clientHello.Clone())
		if err != nil {
			return nil, err
		}
		chosen = cfg
	}

	if fn := c.getConfigForClient; chosen == nil && fn != nil {
		cfg, err := fn(c.serverName())
		if err != nil {
			return nil, err
		}
		chosen = cfg
	}

	if composed := composeServerTLSConfig(c.baseTLSConfig, chosen); composed != nil {
		return composed, nil
	}
	return nil, ErrNoTLSConfig
}
|
|
|
|
|
|
|
|
|
|
// Hostname returns sniffed SNI hostname (if any).
|
|
|
|
|
func (c *Conn) Hostname() string {
|
|
|
|
|
c.init()
|
2026-03-27 12:05:23 +08:00
|
|
|
return c.serverName()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ClientHello returns sniffed TLS metadata (if any).
|
|
|
|
|
func (c *Conn) ClientHello() *ClientHelloMeta {
|
|
|
|
|
c.init()
|
|
|
|
|
return c.clientHello.Clone()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Conn) serverName() string {
|
|
|
|
|
if c.clientHello == nil {
|
|
|
|
|
return ""
|
|
|
|
|
}
|
|
|
|
|
return c.clientHello.ServerName
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func composeServerTLSConfig(base, selected *tls.Config) *tls.Config {
|
2026-04-19 15:39:51 +08:00
|
|
|
return tlssniffercore.ComposeServerTLSConfig(base, selected)
|
2026-03-27 12:05:23 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func applyServerTLSOverrides(dst, src *tls.Config) {
|
2026-04-19 15:39:51 +08:00
|
|
|
tlssniffercore.ApplyServerTLSOverrides(dst, src)
|
2026-03-08 20:19:40 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// IsTLS reports whether the connection is TLS and initialisation succeeded.
func (c *Conn) IsTLS() bool {
	c.init()
	if c.initErr != nil {
		return false
	}
	return c.isTLS
}

// TLSConn returns the underlying *tls.Conn. It returns the init error if
// detection failed, or ErrNotTLS for non-TLS connections.
func (c *Conn) TLSConn() (*tls.Conn, error) {
	c.init()
	switch {
	case c.initErr != nil:
		return nil, c.initErr
	case !c.isTLS || c.tlsConn == nil:
		return nil, ErrNotTLS
	default:
		return c.tlsConn, nil
	}
}
|
|
|
|
|
|
2025-06-12 16:50:47 +08:00
|
|
|
func (c *Conn) Read(b []byte) (int, error) {
|
|
|
|
|
c.init()
|
|
|
|
|
if c.initErr != nil {
|
|
|
|
|
return 0, c.initErr
|
|
|
|
|
}
|
|
|
|
|
if c.isTLS {
|
|
|
|
|
return c.tlsConn.Read(b)
|
|
|
|
|
}
|
2026-03-08 20:19:40 +08:00
|
|
|
return c.plainConn.Read(b)
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Conn) Write(b []byte) (int, error) {
|
|
|
|
|
c.init()
|
|
|
|
|
if c.initErr != nil {
|
|
|
|
|
return 0, c.initErr
|
|
|
|
|
}
|
|
|
|
|
if c.isTLS {
|
|
|
|
|
return c.tlsConn.Write(b)
|
|
|
|
|
}
|
2026-03-08 20:19:40 +08:00
|
|
|
return c.plainConn.Write(b)
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Close closes the connection; only the first call does any work and later
// calls return nil. When a TLS view exists it is closed (tls.Conn.Close also
// closes what it wraps); otherwise the raw conn is closed. The first call
// bumps the closed counter when stats are attached.
func (c *Conn) Close() error {
	var err error
	c.closeOnce.Do(func() {
		closer := io.Closer(c.Conn)
		if c.tlsConn != nil {
			closer = c.tlsConn
		}
		err = closer.Close()
		if c.stats != nil {
			c.stats.incClosed()
		}
	})
	return err
}
|
|
|
|
|
|
|
|
|
|
func (c *Conn) SetDeadline(t time.Time) error {
|
2026-03-08 20:19:40 +08:00
|
|
|
c.init()
|
|
|
|
|
if c.initErr != nil {
|
|
|
|
|
return c.initErr
|
|
|
|
|
}
|
2025-06-12 16:50:47 +08:00
|
|
|
if c.isTLS && c.tlsConn != nil {
|
|
|
|
|
return c.tlsConn.SetDeadline(t)
|
|
|
|
|
}
|
2026-03-08 20:19:40 +08:00
|
|
|
return c.plainConn.SetDeadline(t)
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Conn) SetReadDeadline(t time.Time) error {
|
2026-03-08 20:19:40 +08:00
|
|
|
c.init()
|
|
|
|
|
if c.initErr != nil {
|
|
|
|
|
return c.initErr
|
|
|
|
|
}
|
2025-06-12 16:50:47 +08:00
|
|
|
if c.isTLS && c.tlsConn != nil {
|
|
|
|
|
return c.tlsConn.SetReadDeadline(t)
|
|
|
|
|
}
|
2026-03-08 20:19:40 +08:00
|
|
|
return c.plainConn.SetReadDeadline(t)
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *Conn) SetWriteDeadline(t time.Time) error {
|
2026-03-08 20:19:40 +08:00
|
|
|
c.init()
|
|
|
|
|
if c.initErr != nil {
|
|
|
|
|
return c.initErr
|
|
|
|
|
}
|
2025-06-12 16:50:47 +08:00
|
|
|
if c.isTLS && c.tlsConn != nil {
|
|
|
|
|
return c.tlsConn.SetWriteDeadline(t)
|
|
|
|
|
}
|
2026-03-08 20:19:40 +08:00
|
|
|
return c.plainConn.SetWriteDeadline(t)
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
// Listener wraps net.Listener and returns starnet.Conn from Accept.
|
|
|
|
|
type Listener struct {
|
|
|
|
|
net.Listener
|
|
|
|
|
|
|
|
|
|
mu sync.RWMutex
|
|
|
|
|
cfg ListenerConfig
|
|
|
|
|
stats Stats
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Listen creates a plain listener config (no TLS detection).
|
|
|
|
|
func Listen(network, address string) (*Listener, error) {
|
|
|
|
|
ln, err := net.Listen(network, address)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
2026-03-08 20:19:40 +08:00
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.AllowNonTLS = true
|
|
|
|
|
cfg.BaseTLSConfig = nil
|
|
|
|
|
cfg.GetConfigForClient = nil
|
|
|
|
|
return &Listener{Listener: ln, cfg: cfg}, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ListenWithConfig creates a listener with full config.
|
|
|
|
|
func ListenWithConfig(network, address string, cfg ListenerConfig) (*Listener, error) {
|
|
|
|
|
ln, err := net.Listen(network, address)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
2026-03-08 20:19:40 +08:00
|
|
|
return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
// ListenWithListenConfig creates listener using net.ListenConfig.
|
|
|
|
|
func ListenWithListenConfig(lc net.ListenConfig, network, address string, cfg ListenerConfig) (*Listener, error) {
|
|
|
|
|
ln, err := lc.Listen(context.Background(), network, address)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil
|
2025-06-17 12:36:57 +08:00
|
|
|
}
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
// ListenTLS creates TLS listener from cert/key paths.
|
|
|
|
|
func ListenTLS(network, address, certFile, keyFile string, allowNonTLS bool) (*Listener, error) {
|
|
|
|
|
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
2026-03-08 20:19:40 +08:00
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.AllowNonTLS = allowNonTLS
|
|
|
|
|
cfg.BaseTLSConfig = TLSDefaults()
|
|
|
|
|
cfg.BaseTLSConfig.Certificates = []tls.Certificate{cert}
|
|
|
|
|
return ListenWithConfig(network, address, cfg)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func normalizeConfig(cfg ListenerConfig) ListenerConfig {
|
|
|
|
|
out := DefaultListenerConfig()
|
|
|
|
|
out.AllowNonTLS = cfg.AllowNonTLS
|
|
|
|
|
out.SniffTimeout = cfg.SniffTimeout
|
|
|
|
|
out.MaxClientHelloBytes = cfg.MaxClientHelloBytes
|
|
|
|
|
out.BaseTLSConfig = cfg.BaseTLSConfig
|
|
|
|
|
out.GetConfigForClient = cfg.GetConfigForClient
|
2026-03-27 12:05:23 +08:00
|
|
|
out.GetConfigForClientHello = cfg.GetConfigForClientHello
|
2026-03-08 20:19:40 +08:00
|
|
|
out.Logger = cfg.Logger
|
|
|
|
|
if out.MaxClientHelloBytes <= 0 {
|
|
|
|
|
out.MaxClientHelloBytes = 64 * 1024
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
2026-03-08 20:19:40 +08:00
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// SetConfig atomically replaces listener config for new accepted connections.
|
|
|
|
|
func (l *Listener) SetConfig(cfg ListenerConfig) {
|
|
|
|
|
l.mu.Lock()
|
|
|
|
|
l.cfg = normalizeConfig(cfg)
|
|
|
|
|
l.mu.Unlock()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Config returns a copy of current config.
|
|
|
|
|
func (l *Listener) Config() ListenerConfig {
|
|
|
|
|
l.mu.RLock()
|
|
|
|
|
cfg := l.cfg
|
|
|
|
|
l.mu.RUnlock()
|
|
|
|
|
return cfg
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Stats returns current counters snapshot.
|
|
|
|
|
func (l *Listener) Stats() StatsSnapshot {
|
|
|
|
|
return l.stats.Snapshot()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (l *Listener) Accept() (net.Conn, error) {
|
|
|
|
|
raw, err := l.Listener.Accept()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
l.stats.incAccepted()
|
|
|
|
|
|
|
|
|
|
l.mu.RLock()
|
|
|
|
|
cfg := l.cfg
|
|
|
|
|
l.mu.RUnlock()
|
|
|
|
|
|
|
|
|
|
return newConn(raw, cfg, &l.stats), nil
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
// AcceptContext is like Accept but returns ctx.Err() once ctx is done.
//
// net.Listener.Accept cannot be interrupted, so the underlying Accept keeps
// running in a goroutine until a connection arrives or the listener is
// closed. If a connection arrives after ctx was cancelled, nobody will
// receive it, so it is closed instead of being leaked.
func (l *Listener) AcceptContext(ctx context.Context) (net.Conn, error) {
	// Fail fast without starting an Accept that nobody will wait for.
	if err := ctx.Err(); err != nil {
		return nil, err
	}

	type result struct {
		c   net.Conn
		err error
	}
	// Buffered so the Accept goroutine never blocks on send.
	ch := make(chan result, 1)
	go func() {
		c, err := l.Accept()
		ch <- result{c: c, err: err}
	}()

	select {
	case <-ctx.Done():
		// Reap the late result and close any connection it carries.
		go func() {
			if r := <-ch; r.c != nil {
				_ = r.c.Close()
			}
		}()
		return nil, ctx.Err()
	case r := <-ch:
		return r.c, r.err
	}
}
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
// Dial creates a plain TCP starnet.Conn.
|
2025-06-12 16:50:47 +08:00
|
|
|
func Dial(network, address string) (*Conn, error) {
|
2026-03-08 20:19:40 +08:00
|
|
|
raw, err := net.Dial(network, address)
|
2025-06-12 16:50:47 +08:00
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
2026-03-08 20:19:40 +08:00
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.AllowNonTLS = true
|
|
|
|
|
cfg.BaseTLSConfig = nil
|
|
|
|
|
cfg.GetConfigForClient = nil
|
|
|
|
|
c := newConn(raw, cfg, nil)
|
|
|
|
|
c.isTLS = false
|
|
|
|
|
return c, nil
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
// DialWithConfig dials with net.Dialer options.
|
|
|
|
|
func DialWithConfig(network, address string, dc DialConfig) (*Conn, error) {
|
|
|
|
|
d := net.Dialer{
|
|
|
|
|
Timeout: dc.Timeout,
|
|
|
|
|
LocalAddr: dc.LocalAddr,
|
|
|
|
|
}
|
|
|
|
|
raw, err := d.Dial(network, address)
|
2025-06-12 16:50:47 +08:00
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
2026-03-08 20:19:40 +08:00
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.AllowNonTLS = true
|
|
|
|
|
c := newConn(raw, cfg, nil)
|
|
|
|
|
c.isTLS = false
|
|
|
|
|
return c, nil
|
|
|
|
|
}
|
2025-06-12 16:50:47 +08:00
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
// DialTLSWithConfig creates a TLS client connection wrapper.
|
|
|
|
|
func DialTLSWithConfig(network, address string, tlsCfg *tls.Config, timeout time.Duration) (*Conn, error) {
|
|
|
|
|
d := net.Dialer{Timeout: timeout}
|
|
|
|
|
raw, err := d.Dial(network, address)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|
2026-03-08 20:19:40 +08:00
|
|
|
tc := tls.Client(raw, tlsCfg)
|
|
|
|
|
return &Conn{
|
|
|
|
|
Conn: raw,
|
|
|
|
|
plainConn: raw,
|
|
|
|
|
isTLS: true,
|
|
|
|
|
tlsConn: tc,
|
|
|
|
|
initErr: nil,
|
|
|
|
|
allowNonTLS: false,
|
|
|
|
|
skipSniff: true,
|
|
|
|
|
}, nil
|
|
|
|
|
}
|
2025-06-12 16:50:47 +08:00
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
// DialTLS dials a TLS connection on top of TLSDefaults(). It presents the
// certificate pair loaded from certFile/keyFile as the client certificate
// and passes a zero dial timeout.
func DialTLS(network, address, certFile, keyFile string) (*Conn, error) {
	pair, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return nil, err
	}
	clientCfg := TLSDefaults()
	clientCfg.Certificates = []tls.Certificate{pair}
	const noTimeout = 0
	return DialTLSWithConfig(network, address, clientCfg, noTimeout)
}
|
2025-06-12 16:50:47 +08:00
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
func WrapListener(listener net.Listener, cfg ListenerConfig) (*Listener, error) {
|
|
|
|
|
if listener == nil {
|
|
|
|
|
return nil, ErrNilConn
|
|
|
|
|
}
|
|
|
|
|
return &Listener{
|
|
|
|
|
Listener: listener,
|
|
|
|
|
cfg: normalizeConfig(cfg),
|
|
|
|
|
}, nil
|
2025-06-12 16:50:47 +08:00
|
|
|
}
|