starnet/tlssniffer.go

834 lines
20 KiB
Go
Raw Permalink Normal View History

2025-06-12 16:50:47 +08:00
package starnet
import (
"bytes"
2025-06-17 12:09:12 +08:00
"context"
2025-06-12 16:50:47 +08:00
"crypto/tls"
"encoding/binary"
"errors"
2025-06-12 16:50:47 +08:00
"io"
"net"
"sync"
"time"
)
2026-03-08 20:19:40 +08:00
// replayConn is a net.Conn whose reads first drain a caller-supplied reader
// (typically the bytes already consumed while sniffing) and then continue
// from the underlying connection. Writes, Close, addresses and deadlines are
// forwarded to the underlying connection untouched.
type replayConn struct {
	reader io.Reader
	conn   net.Conn
}

// newReplayConn returns a replayConn that yields everything in buffered
// before any byte read from conn.
func newReplayConn(buffered io.Reader, conn net.Conn) *replayConn {
	rc := new(replayConn)
	rc.conn = conn
	rc.reader = io.MultiReader(buffered, conn)
	return rc
}

// Read serves replayed bytes first, then live bytes from the connection.
func (rc *replayConn) Read(p []byte) (int, error) {
	return rc.reader.Read(p)
}

// Write writes directly to the underlying connection.
func (rc *replayConn) Write(p []byte) (int, error) {
	return rc.conn.Write(p)
}

// Close closes the underlying connection.
func (rc *replayConn) Close() error {
	return rc.conn.Close()
}

// LocalAddr returns the underlying connection's local address.
func (rc *replayConn) LocalAddr() net.Addr {
	return rc.conn.LocalAddr()
}

// RemoteAddr returns the underlying connection's remote address.
func (rc *replayConn) RemoteAddr() net.Addr {
	return rc.conn.RemoteAddr()
}

// SetDeadline sets read and write deadlines on the underlying connection.
func (rc *replayConn) SetDeadline(t time.Time) error {
	return rc.conn.SetDeadline(t)
}

// SetReadDeadline sets the read deadline on the underlying connection; it
// does not affect bytes still waiting in the replay buffer.
func (rc *replayConn) SetReadDeadline(t time.Time) error {
	return rc.conn.SetReadDeadline(t)
}

// SetWriteDeadline sets the write deadline on the underlying connection.
func (rc *replayConn) SetWriteDeadline(t time.Time) error {
	return rc.conn.SetWriteDeadline(t)
}
// SniffResult describes protocol sniffing result.
type SniffResult struct {
	// IsTLS reports whether the initial bytes were recognized as the start
	// of a TLS handshake.
	IsTLS bool
	// ClientHello holds parsed ClientHello metadata. TLSSniffer sets it only
	// when IsTLS is true, and its fields may be only partially filled when
	// the handshake bytes were truncated.
	ClientHello *ClientHelloMeta
	// Buffer holds every byte consumed from the connection while sniffing.
	// Conn replays it ahead of the live stream and calls Bytes() on it, so
	// it must not be nil.
	Buffer *bytes.Buffer
}
// Sniffer detects protocol and metadata from initial bytes.
//
// Sniff reads from conn, consuming at most maxBytes (TLSSniffer falls back to
// 64 KiB when maxBytes <= 0). It must return every consumed byte in
// SniffResult.Buffer, because Conn replays those bytes before the live
// stream. A non-nil error makes Conn fail initialization and close the
// connection.
type Sniffer interface {
	Sniff(conn net.Conn, maxBytes int) (SniffResult, error)
}
// TLSSniffer is the default sniffer implementation.
// It recognizes a TLS handshake record and, when enough bytes are available,
// parses the ClientHello (SNI, ALPN, cipher suites, supported versions).
type TLSSniffer struct{}
// Sniff detects TLS and extracts SNI when possible.
func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
if maxBytes <= 0 {
maxBytes = 64 * 1024
2025-06-12 16:50:47 +08:00
}
2026-03-08 20:19:40 +08:00
var buf bytes.Buffer
limited := &io.LimitedReader{R: conn, N: int64(maxBytes)}
meta, isTLS := sniffClientHello(limited, &buf, conn)
2025-06-12 16:50:47 +08:00
2026-03-08 20:19:40 +08:00
out := SniffResult{
IsTLS: isTLS,
Buffer: bytes.NewBuffer(append([]byte(nil), buf.Bytes()...)),
2025-06-12 16:50:47 +08:00
}
if isTLS {
out.ClientHello = meta
2025-06-12 16:50:47 +08:00
}
2026-03-08 20:19:40 +08:00
return out, nil
2025-06-12 16:50:47 +08:00
}
// sniffClientHello reads the first TLS record(s) from r, copying every byte it
// consumes into buf so the caller can replay them. It decodes the ClientHello
// when the whole handshake message could be read.
//
// conn is used only for its LocalAddr/RemoteAddr, which seed the returned
// metadata. Outcomes:
//   - (nil, false): fewer than 3 header bytes arrived, or the record is not a
//     handshake record with major version 3 (0x16 0x03 ...), or the first
//     complete record is too short for a handshake header or does not carry
//     a ClientHello (handshake type 0x01).
//   - (meta, true) with only the addresses set: the stream looked like TLS
//     but the header, the record body or a continuation record was cut short
//     (EOF, read error, or the caller's read cap), or a continuation record
//     was not a handshake record.
//   - (meta, true) with parsed fields: the ClientHello was reassembled,
//     possibly across several records, and passed to parseClientHelloBody.
//
// Read errors are never returned; they only shorten what gets parsed.
func sniffClientHello(r io.Reader, buf *bytes.Buffer, conn net.Conn) (*ClientHelloMeta, bool) {
	meta := &ClientHelloMeta{
		LocalAddr:  conn.LocalAddr(),
		RemoteAddr: conn.RemoteAddr(),
	}
	// Record header: content type (1), legacy version (2), length (2).
	header, complete := readTLSRecordHeader(r, buf)
	if len(header) < 3 {
		return nil, false
	}
	// 0x16 is the handshake content type; 0x03 is the major version byte
	// shared by SSL 3.0 and every TLS 1.x record.
	isTLS := header[0] == 0x16 && header[1] == 0x03
	if !isTLS {
		return nil, false
	}
	// Looks like TLS, but the length field is missing: report TLS without
	// parsed details.
	if len(header) < 5 || !complete {
		return meta, true
	}
	recordLen := int(binary.BigEndian.Uint16(header[3:5]))
	recordBody, bodyOK := readBufferedBytes(r, buf, recordLen)
	if !bodyOK {
		return meta, true
	}
	// Handshake header: message type (1) + 24-bit length (3).
	if len(recordBody) < 4 || recordBody[0] != 0x01 {
		return nil, false
	}
	helloLen := int(recordBody[1])<<16 | int(recordBody[2])<<8 | int(recordBody[3])
	helloBytes := append([]byte(nil), recordBody[4:]...)
	// The ClientHello may be fragmented across several handshake records.
	// Keep appending record bodies until the declared length is covered. The
	// total read is bounded by r (Sniff passes an io.LimitedReader).
	for len(helloBytes) < helloLen {
		nextHeader, nextOK := readTLSRecordHeader(r, buf)
		if len(nextHeader) < 5 || !nextOK {
			return meta, true
		}
		if nextHeader[0] != 0x16 || nextHeader[1] != 0x03 {
			return meta, true
		}
		nextLen := int(binary.BigEndian.Uint16(nextHeader[3:5]))
		nextBody, nextBodyOK := readBufferedBytes(r, buf, nextLen)
		if !nextBodyOK {
			return meta, true
		}
		helloBytes = append(helloBytes, nextBody...)
	}
	// Drop any bytes past the declared ClientHello length (e.g. a following
	// handshake message in the same record).
	parseClientHelloBody(meta, helloBytes[:helloLen])
	return meta, true
}
func readTLSRecordHeader(r io.Reader, buf *bytes.Buffer) ([]byte, bool) {
return readBufferedBytes(r, buf, 5)
2025-06-12 16:50:47 +08:00
}
// readBufferedBytes tries to read exactly n bytes from r. Whatever actually
// arrives, even on a short read, is appended to buf so it can be replayed
// later.
//
// It returns the bytes read (possibly fewer than n, nil when nothing was
// read) and whether all n bytes arrived. n <= 0 is a no-op that reports
// success.
func readBufferedBytes(r io.Reader, buf *bytes.Buffer, n int) ([]byte, bool) {
	if n <= 0 {
		return nil, true
	}
	chunk := make([]byte, n)
	readN, err := io.ReadFull(r, chunk)
	if readN == 0 {
		return nil, err == nil
	}
	// bytes.Buffer.Write copies its input, so chunk is not aliased by buf
	// and can be returned directly instead of being copied a second time.
	buf.Write(chunk[:readN])
	return chunk[:readN], err == nil
}
// parseClientHelloBody extracts the cipher suites and extensions from a
// ClientHello body, i.e. the bytes after the 4-byte handshake header.
// Parsing stops silently at the first truncated field; anything decoded up to
// that point stays in meta. A nil meta is ignored.
//
// Layout (RFC 8446 §4.1.2): legacy_version(2) random(32)
// legacy_session_id<0..32> cipher_suites<2..2^16-2>
// legacy_compression_methods<1..2^8-1> extensions<0..2^16-1>.
func parseClientHelloBody(meta *ClientHelloMeta, body []byte) {
	// The session-id length byte lives at index 34, right after
	// version(2)+random(32). The body must therefore be at least 35 bytes
	// long; a 34-byte body would otherwise index out of range and panic on
	// peer-controlled input.
	const sessionIDLenOffset = 2 + 32
	if meta == nil || len(body) <= sessionIDLenOffset {
		return
	}
	offset := sessionIDLenOffset
	sessionIDLen := int(body[offset])
	offset++
	if offset+sessionIDLen > len(body) {
		return
	}
	offset += sessionIDLen
	if offset+2 > len(body) {
		return
	}
	cipherSuitesLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
	offset += 2
	if offset+cipherSuitesLen > len(body) {
		return
	}
	// An odd trailing byte in the cipher-suite list is ignored.
	for i := 0; i+1 < cipherSuitesLen; i += 2 {
		meta.CipherSuites = append(meta.CipherSuites, binary.BigEndian.Uint16(body[offset+i:offset+i+2]))
	}
	offset += cipherSuitesLen
	if offset >= len(body) {
		return
	}
	compressionMethodsLen := int(body[offset])
	offset++
	if offset+compressionMethodsLen > len(body) {
		return
	}
	offset += compressionMethodsLen
	// A ClientHello may legitimately end here without an extensions block.
	if offset+2 > len(body) {
		return
	}
	extensionsLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
	offset += 2
	if offset+extensionsLen > len(body) {
		return
	}
	parseClientHelloExtensions(meta, body[offset:offset+extensionsLen])
}
// parseClientHelloExtensions walks the ClientHello extension list and hands
// the server_name (0), ALPN (16) and supported_versions (43) payloads to their
// parsers. Other extensions are skipped. The walk stops at the first entry
// whose declared length runs past the end of exts.
func parseClientHelloExtensions(meta *ClientHelloMeta, exts []byte) {
	rest := exts
	for len(rest) >= 4 {
		kind := binary.BigEndian.Uint16(rest[0:2])
		size := int(binary.BigEndian.Uint16(rest[2:4]))
		rest = rest[4:]
		if size > len(rest) {
			return
		}
		payload := rest[:size]
		rest = rest[size:]

		switch kind {
		case 0: // server_name
			parseServerNameExtension(meta, payload)
		case 16: // application_layer_protocol_negotiation
			parseALPNExtension(meta, payload)
		case 43: // supported_versions
			parseSupportedVersionsExtension(meta, payload)
		}
	}
}
// parseServerNameExtension decodes a server_name extension payload and stores
// the first host_name (name type 0) entry in meta.ServerName. Malformed or
// empty lists leave meta untouched.
func parseServerNameExtension(meta *ClientHelloMeta, data []byte) {
	if len(data) < 2 {
		return
	}
	listLen := int(binary.BigEndian.Uint16(data))
	if listLen == 0 || listLen > len(data)-2 {
		return
	}
	entries := data[2 : 2+listLen]
	for len(entries) >= 3 {
		nameType := entries[0]
		nameLen := int(binary.BigEndian.Uint16(entries[1:3]))
		entries = entries[3:]
		if nameLen > len(entries) {
			return
		}
		if nameType == 0 { // host_name
			meta.ServerName = string(entries[:nameLen])
			return
		}
		entries = entries[nameLen:]
	}
}
// parseALPNExtension decodes an ALPN extension payload, appending each
// advertised protocol name to meta.SupportedProtos. It stops at the first
// entry whose length overruns the list.
func parseALPNExtension(meta *ClientHelloMeta, data []byte) {
	if len(data) < 2 {
		return
	}
	listLen := int(binary.BigEndian.Uint16(data))
	if listLen == 0 || listLen > len(data)-2 {
		return
	}
	protos := data[2 : 2+listLen]
	for len(protos) > 0 {
		n := int(protos[0])
		protos = protos[1:]
		if n > len(protos) {
			return
		}
		meta.SupportedProtos = append(meta.SupportedProtos, string(protos[:n]))
		protos = protos[n:]
	}
}
func parseSupportedVersionsExtension(meta *ClientHelloMeta, data []byte) {
if len(data) < 1 {
return
}
listLen := int(data[0])
if listLen == 0 || 1+listLen > len(data) {
return
}
list := data[1 : 1+listLen]
for offset := 0; offset+1 < len(list); offset += 2 {
meta.SupportedVersions = append(meta.SupportedVersions, binary.BigEndian.Uint16(list[offset:offset+2]))
}
}
2026-03-08 20:19:40 +08:00
// Conn wraps net.Conn with lazy protocol initialization.
//
// The first call to Read, Write, a Set*Deadline method, IsTLS, TLSConn,
// Hostname or ClientHello runs init exactly once. init may sniff the first
// bytes of the stream and then route I/O either through a server-side
// *tls.Conn or through plainConn.
type Conn struct {
	// The raw connection as accepted or dialed.
	net.Conn

	once      sync.Once // guards init
	initErr   error     // set by init on sniff, config-selection or plain-TCP-policy failure
	closeOnce sync.Once // makes Close idempotent

	isTLS     bool
	tlsConn   *tls.Conn // set by init once TLS is detected and a config selected, or by DialTLSWithConfig
	plainConn net.Conn  // the raw conn, or a replayConn that replays the sniffed bytes

	clientHello *ClientHelloMeta // sniffed metadata; nil when no TLS was detected

	baseTLSConfig           *tls.Config
	getConfigForClient      GetConfigForClientFunc
	getConfigForClientHello GetConfigForClientHelloFunc
	allowNonTLS             bool // when false, detected plain traffic is rejected and the conn closed
	sniffer                 Sniffer
	sniffTimeout            time.Duration // read deadline applied only while sniffing; 0 means none
	maxClientHello          int           // byte cap handed to the sniffer
	logger                  Logger
	stats                   *Stats // may be nil (dialed connections pass nil)
	skipSniff               bool   // true when the mode is fixed at construction (DialTLSWithConfig)
}
func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn {
return &Conn{
Conn: raw,
plainConn: raw,
baseTLSConfig: cfg.BaseTLSConfig,
getConfigForClient: cfg.GetConfigForClient,
getConfigForClientHello: cfg.GetConfigForClientHello,
2026-03-08 20:19:40 +08:00
allowNonTLS: cfg.AllowNonTLS,
sniffer: TLSSniffer{},
sniffTimeout: cfg.SniffTimeout,
maxClientHello: cfg.MaxClientHelloBytes,
logger: cfg.Logger,
stats: stats,
2025-06-12 16:50:47 +08:00
}
}
// init performs the one-time, lazy protocol setup for the connection. Every
// I/O, deadline and metadata accessor calls it; only the first call does any
// work (guarded by c.once).
//
// Steps:
//  1. Return immediately when skipSniff is set, or when no TLS config source
//     is configured. In the latter case the conn stays plain, plainConn
//     remains the raw conn, and nothing is read.
//  2. Sniff the first bytes, under a read deadline when sniffTimeout > 0,
//     and clear that deadline afterwards.
//  3. TLS detected: select a config and wrap a replayConn (sniffed bytes
//     first, then the live conn) in tls.Server. The handshake itself is not
//     run here; crypto/tls performs it on the first Read/Write.
//  4. Plain traffic: reject it unless allowNonTLS is set; otherwise replay
//     the sniffed bytes through plainConn.
//
// Failures are stored in initErr, which later Read/Write/deadline calls
// return, and the connection is closed via the fail* helpers.
func (c *Conn) init() {
	c.once.Do(func() {
		if c.skipSniff {
			return
		}
		if c.baseTLSConfig == nil && c.getConfigForClient == nil && c.getConfigForClientHello == nil {
			c.isTLS = false
			return
		}
		// Best effort: if the deadline cannot be set, sniffing proceeds
		// without one.
		if c.sniffTimeout > 0 {
			_ = c.Conn.SetReadDeadline(time.Now().Add(c.sniffTimeout))
		}
		res, err := c.sniffer.Sniff(c.Conn, c.maxClientHello)
		// A zero time removes the sniff deadline so it does not apply to
		// later reads.
		if c.sniffTimeout > 0 {
			_ = c.Conn.SetReadDeadline(time.Time{})
		}
		if err != nil {
			c.initErr = errors.Join(ErrTLSSniffFailed, err)
			c.failSniff(err)
			return
		}
		c.isTLS = res.IsTLS
		c.clientHello = res.ClientHello
		if c.isTLS {
			if c.stats != nil {
				c.stats.incTLSDetected()
			}
			tlsCfg, errCfg := c.selectTLSConfig()
			if errCfg != nil {
				c.initErr = errors.Join(ErrTLSConfigSelectionFailed, errCfg)
				c.failTLSConfigSelection(errCfg)
				return
			}
			// The TLS server must see the ClientHello bytes the sniffer
			// already consumed, so they are replayed ahead of the live conn.
			rc := newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
			c.tlsConn = tls.Server(rc, tlsCfg)
			return
		}
		if c.stats != nil {
			c.stats.incPlainDetected()
		}
		if !c.allowNonTLS {
			c.initErr = ErrNonTLSNotAllowed
			c.failPlainRejected()
			return
		}
		c.plainConn = newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
	})
}
func (c *Conn) failAndClose(format string, v ...interface{}) {
if c.logger != nil {
c.logger.Printf("starnet: "+format, v...)
}
_ = c.Close()
}
func (c *Conn) failSniff(err error) {
if c.stats != nil {
c.stats.incSniffFailures()
}
c.failAndClose("tls sniff failed: %v", err)
}
func (c *Conn) failTLSConfigSelection(err error) {
if c.stats != nil {
c.stats.incTLSConfigFailures()
}
c.failAndClose("tls config selection failed: %v", err)
}
func (c *Conn) failPlainRejected() {
if c.stats != nil {
c.stats.incPlainRejected()
}
c.failAndClose("plain tcp rejected")
}
2026-03-08 20:19:40 +08:00
func (c *Conn) selectTLSConfig() (*tls.Config, error) {
var selected *tls.Config
if c.getConfigForClientHello != nil {
cfg, err := c.getConfigForClientHello(c.clientHello.Clone())
if err != nil {
return nil, err
}
if cfg != nil {
selected = cfg
}
}
if selected == nil && c.getConfigForClient != nil {
cfg, err := c.getConfigForClient(c.serverName())
2026-03-08 20:19:40 +08:00
if err != nil {
return nil, err
}
if cfg != nil {
selected = cfg
2026-03-08 20:19:40 +08:00
}
}
composed := composeServerTLSConfig(c.baseTLSConfig, selected)
if composed != nil {
return composed, nil
2026-03-08 20:19:40 +08:00
}
return nil, ErrNoTLSConfig
}
// Hostname returns sniffed SNI hostname (if any).
// The first call triggers lazy initialization, which may block while the
// initial bytes are sniffed.
func (c *Conn) Hostname() string {
	c.init()
	return c.serverName()
}

// ClientHello returns sniffed TLS metadata (if any), as produced by
// ClientHelloMeta.Clone on the stored value. The first call triggers lazy
// initialization.
func (c *Conn) ClientHello() *ClientHelloMeta {
	c.init()
	stored := c.clientHello
	return stored.Clone()
}

// serverName returns the sniffed SNI value, or "" when no ClientHello
// metadata was captured.
func (c *Conn) serverName() string {
	if hello := c.clientHello; hello != nil {
		return hello.ServerName
	}
	return ""
}
// composeServerTLSConfig layers selected on top of base.
//
// When only one of them is non-nil, that config is returned as-is (not
// cloned); when both are nil the result is nil. When both are set, base is
// cloned and selected's settings are merged into the clone with
// applyServerTLSOverrides, so neither input is modified.
func composeServerTLSConfig(base, selected *tls.Config) *tls.Config {
	switch {
	case base == nil:
		return selected
	case selected == nil:
		return base
	}
	merged := base.Clone()
	applyServerTLSOverrides(merged, selected)
	return merged
}
func applyServerTLSOverrides(dst, src *tls.Config) {
if dst == nil || src == nil {
return
}
if src.Rand != nil {
dst.Rand = src.Rand
}
if src.Time != nil {
dst.Time = src.Time
}
if len(src.Certificates) > 0 {
dst.Certificates = append([]tls.Certificate(nil), src.Certificates...)
}
if len(src.NameToCertificate) > 0 {
m := make(map[string]*tls.Certificate, len(src.NameToCertificate))
for k, v := range src.NameToCertificate {
m[k] = v
}
dst.NameToCertificate = m
}
if src.GetCertificate != nil {
dst.GetCertificate = src.GetCertificate
}
if src.GetClientCertificate != nil {
dst.GetClientCertificate = src.GetClientCertificate
}
if src.GetConfigForClient != nil {
dst.GetConfigForClient = src.GetConfigForClient
}
if src.VerifyPeerCertificate != nil {
dst.VerifyPeerCertificate = src.VerifyPeerCertificate
}
if src.VerifyConnection != nil {
dst.VerifyConnection = src.VerifyConnection
}
if src.RootCAs != nil {
dst.RootCAs = src.RootCAs
}
if len(src.NextProtos) > 0 {
dst.NextProtos = append([]string(nil), src.NextProtos...)
}
if src.ServerName != "" {
dst.ServerName = src.ServerName
}
if src.ClientAuth > dst.ClientAuth {
dst.ClientAuth = src.ClientAuth
}
if src.ClientCAs != nil {
dst.ClientCAs = src.ClientCAs
}
if src.InsecureSkipVerify {
dst.InsecureSkipVerify = true
}
if len(src.CipherSuites) > 0 {
dst.CipherSuites = append([]uint16(nil), src.CipherSuites...)
}
if src.PreferServerCipherSuites {
dst.PreferServerCipherSuites = true
}
if src.SessionTicketsDisabled {
dst.SessionTicketsDisabled = true
}
if src.SessionTicketKey != ([32]byte{}) {
dst.SessionTicketKey = src.SessionTicketKey
}
if src.ClientSessionCache != nil {
dst.ClientSessionCache = src.ClientSessionCache
}
if src.UnwrapSession != nil {
dst.UnwrapSession = src.UnwrapSession
}
if src.WrapSession != nil {
dst.WrapSession = src.WrapSession
}
if src.MinVersion != 0 && (dst.MinVersion == 0 || src.MinVersion > dst.MinVersion) {
dst.MinVersion = src.MinVersion
}
if src.MaxVersion != 0 && (dst.MaxVersion == 0 || src.MaxVersion < dst.MaxVersion) {
dst.MaxVersion = src.MaxVersion
}
if len(src.CurvePreferences) > 0 {
dst.CurvePreferences = append([]tls.CurveID(nil), src.CurvePreferences...)
}
if src.DynamicRecordSizingDisabled {
dst.DynamicRecordSizingDisabled = true
}
if src.Renegotiation != 0 {
dst.Renegotiation = src.Renegotiation
}
if src.KeyLogWriter != nil {
dst.KeyLogWriter = src.KeyLogWriter
}
if len(src.EncryptedClientHelloConfigList) > 0 {
dst.EncryptedClientHelloConfigList = append([]byte(nil), src.EncryptedClientHelloConfigList...)
}
if src.EncryptedClientHelloRejectionVerify != nil {
dst.EncryptedClientHelloRejectionVerify = src.EncryptedClientHelloRejectionVerify
}
if src.GetEncryptedClientHelloKeys != nil {
dst.GetEncryptedClientHelloKeys = src.GetEncryptedClientHelloKeys
}
if len(src.EncryptedClientHelloKeys) > 0 {
dst.EncryptedClientHelloKeys = append([]tls.EncryptedClientHelloKey(nil), src.EncryptedClientHelloKeys...)
}
2026-03-08 20:19:40 +08:00
}
func (c *Conn) IsTLS() bool {
c.init()
return c.initErr == nil && c.isTLS
}
func (c *Conn) TLSConn() (*tls.Conn, error) {
c.init()
if c.initErr != nil {
return nil, c.initErr
}
if !c.isTLS || c.tlsConn == nil {
return nil, ErrNotTLS
}
return c.tlsConn, nil
}
2025-06-12 16:50:47 +08:00
func (c *Conn) Read(b []byte) (int, error) {
c.init()
if c.initErr != nil {
return 0, c.initErr
}
if c.isTLS {
return c.tlsConn.Read(b)
}
2026-03-08 20:19:40 +08:00
return c.plainConn.Read(b)
2025-06-12 16:50:47 +08:00
}
func (c *Conn) Write(b []byte) (int, error) {
c.init()
if c.initErr != nil {
return 0, c.initErr
}
if c.isTLS {
return c.tlsConn.Write(b)
}
2026-03-08 20:19:40 +08:00
return c.plainConn.Write(b)
2025-06-12 16:50:47 +08:00
}
func (c *Conn) Close() error {
2026-03-08 20:19:40 +08:00
var err error
c.closeOnce.Do(func() {
if c.tlsConn != nil {
err = c.tlsConn.Close()
} else {
err = c.Conn.Close()
}
if c.stats != nil {
c.stats.incClosed()
}
})
return err
2025-06-12 16:50:47 +08:00
}
func (c *Conn) SetDeadline(t time.Time) error {
2026-03-08 20:19:40 +08:00
c.init()
if c.initErr != nil {
return c.initErr
}
2025-06-12 16:50:47 +08:00
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetDeadline(t)
}
2026-03-08 20:19:40 +08:00
return c.plainConn.SetDeadline(t)
2025-06-12 16:50:47 +08:00
}
func (c *Conn) SetReadDeadline(t time.Time) error {
2026-03-08 20:19:40 +08:00
c.init()
if c.initErr != nil {
return c.initErr
}
2025-06-12 16:50:47 +08:00
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetReadDeadline(t)
}
2026-03-08 20:19:40 +08:00
return c.plainConn.SetReadDeadline(t)
2025-06-12 16:50:47 +08:00
}
func (c *Conn) SetWriteDeadline(t time.Time) error {
2026-03-08 20:19:40 +08:00
c.init()
if c.initErr != nil {
return c.initErr
}
2025-06-12 16:50:47 +08:00
if c.isTLS && c.tlsConn != nil {
return c.tlsConn.SetWriteDeadline(t)
}
2026-03-08 20:19:40 +08:00
return c.plainConn.SetWriteDeadline(t)
2025-06-12 16:50:47 +08:00
}
2026-03-08 20:19:40 +08:00
// Listener wraps net.Listener and returns starnet.Conn from Accept.
//
// mu guards cfg, which SetConfig may replace while Accept runs. Each accepted
// Conn captures the config current at accept time. stats accumulates counters
// for every connection this listener accepts. Because the struct holds a
// mutex, a Listener must not be copied after first use.
type Listener struct {
	net.Listener
	mu    sync.RWMutex
	cfg   ListenerConfig
	stats Stats
}
// Listen creates a plain listener config (no TLS detection).
func Listen(network, address string) (*Listener, error) {
ln, err := net.Listen(network, address)
if err != nil {
return nil, err
2025-06-12 16:50:47 +08:00
}
2026-03-08 20:19:40 +08:00
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
cfg.BaseTLSConfig = nil
cfg.GetConfigForClient = nil
return &Listener{Listener: ln, cfg: cfg}, nil
}
// ListenWithConfig creates a listener with full config.
func ListenWithConfig(network, address string, cfg ListenerConfig) (*Listener, error) {
ln, err := net.Listen(network, address)
if err != nil {
return nil, err
2025-06-12 16:50:47 +08:00
}
2026-03-08 20:19:40 +08:00
return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil
2025-06-12 16:50:47 +08:00
}
2026-03-08 20:19:40 +08:00
// ListenWithListenConfig creates listener using net.ListenConfig.
func ListenWithListenConfig(lc net.ListenConfig, network, address string, cfg ListenerConfig) (*Listener, error) {
ln, err := lc.Listen(context.Background(), network, address)
if err != nil {
return nil, err
}
return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil
2025-06-17 12:36:57 +08:00
}
2026-03-08 20:19:40 +08:00
// ListenTLS creates TLS listener from cert/key paths.
func ListenTLS(network, address, certFile, keyFile string, allowNonTLS bool) (*Listener, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
2025-06-12 16:50:47 +08:00
}
2026-03-08 20:19:40 +08:00
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = allowNonTLS
cfg.BaseTLSConfig = TLSDefaults()
cfg.BaseTLSConfig.Certificates = []tls.Certificate{cert}
return ListenWithConfig(network, address, cfg)
}
func normalizeConfig(cfg ListenerConfig) ListenerConfig {
out := DefaultListenerConfig()
out.AllowNonTLS = cfg.AllowNonTLS
out.SniffTimeout = cfg.SniffTimeout
out.MaxClientHelloBytes = cfg.MaxClientHelloBytes
out.BaseTLSConfig = cfg.BaseTLSConfig
out.GetConfigForClient = cfg.GetConfigForClient
out.GetConfigForClientHello = cfg.GetConfigForClientHello
2026-03-08 20:19:40 +08:00
out.Logger = cfg.Logger
if out.MaxClientHelloBytes <= 0 {
out.MaxClientHelloBytes = 64 * 1024
2025-06-12 16:50:47 +08:00
}
2026-03-08 20:19:40 +08:00
return out
}
// SetConfig atomically replaces listener config for new accepted connections.
// cfg is normalized first. Connections that were already accepted keep the
// config they captured.
func (l *Listener) SetConfig(cfg ListenerConfig) {
	normalized := normalizeConfig(cfg)
	l.mu.Lock()
	defer l.mu.Unlock()
	l.cfg = normalized
}

// Config returns a copy of current config.
func (l *Listener) Config() ListenerConfig {
	l.mu.RLock()
	defer l.mu.RUnlock()
	return l.cfg
}

// Stats returns current counters snapshot.
func (l *Listener) Stats() StatsSnapshot {
	snap := l.stats.Snapshot()
	return snap
}
func (l *Listener) Accept() (net.Conn, error) {
raw, err := l.Listener.Accept()
if err != nil {
return nil, err
}
l.stats.incAccepted()
l.mu.RLock()
cfg := l.cfg
l.mu.RUnlock()
return newConn(raw, cfg, &l.stats), nil
2025-06-12 16:50:47 +08:00
}
2026-03-08 20:19:40 +08:00
// AcceptContext supports cancellation by closing accepted conn when ctx is done early.
func (l *Listener) AcceptContext(ctx context.Context) (net.Conn, error) {
type result struct {
c net.Conn
err error
2025-06-12 16:50:47 +08:00
}
2026-03-08 20:19:40 +08:00
ch := make(chan result, 1)
go func() {
c, err := l.Accept()
ch <- result{c: c, err: err}
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case r := <-ch:
return r.c, r.err
2025-06-12 16:50:47 +08:00
}
}
2026-03-08 20:19:40 +08:00
// Dial creates a plain TCP starnet.Conn.
2025-06-12 16:50:47 +08:00
func Dial(network, address string) (*Conn, error) {
2026-03-08 20:19:40 +08:00
raw, err := net.Dial(network, address)
2025-06-12 16:50:47 +08:00
if err != nil {
return nil, err
}
2026-03-08 20:19:40 +08:00
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
cfg.BaseTLSConfig = nil
cfg.GetConfigForClient = nil
c := newConn(raw, cfg, nil)
c.isTLS = false
return c, nil
2025-06-12 16:50:47 +08:00
}
2026-03-08 20:19:40 +08:00
// DialWithConfig dials with net.Dialer options.
func DialWithConfig(network, address string, dc DialConfig) (*Conn, error) {
d := net.Dialer{
Timeout: dc.Timeout,
LocalAddr: dc.LocalAddr,
}
raw, err := d.Dial(network, address)
2025-06-12 16:50:47 +08:00
if err != nil {
return nil, err
}
2026-03-08 20:19:40 +08:00
cfg := DefaultListenerConfig()
cfg.AllowNonTLS = true
c := newConn(raw, cfg, nil)
c.isTLS = false
return c, nil
}
2025-06-12 16:50:47 +08:00
2026-03-08 20:19:40 +08:00
// DialTLSWithConfig creates a TLS client connection wrapper.
func DialTLSWithConfig(network, address string, tlsCfg *tls.Config, timeout time.Duration) (*Conn, error) {
d := net.Dialer{Timeout: timeout}
raw, err := d.Dial(network, address)
if err != nil {
return nil, err
2025-06-12 16:50:47 +08:00
}
2026-03-08 20:19:40 +08:00
tc := tls.Client(raw, tlsCfg)
return &Conn{
Conn: raw,
plainConn: raw,
isTLS: true,
tlsConn: tc,
initErr: nil,
allowNonTLS: false,
skipSniff: true,
}, nil
}
2025-06-12 16:50:47 +08:00
2026-03-08 20:19:40 +08:00
// DialTLS creates TLS client conn from cert/key paths.
func DialTLS(network, address, certFile, keyFile string) (*Conn, error) {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
2025-06-12 16:50:47 +08:00
if err != nil {
return nil, err
}
2026-03-08 20:19:40 +08:00
cfg := TLSDefaults()
cfg.Certificates = []tls.Certificate{cert}
return DialTLSWithConfig(network, address, cfg, 0)
}
2025-06-12 16:50:47 +08:00
2026-03-08 20:19:40 +08:00
func WrapListener(listener net.Listener, cfg ListenerConfig) (*Listener, error) {
if listener == nil {
return nil, ErrNilConn
}
return &Listener{
Listener: listener,
cfg: normalizeConfig(cfg),
}, nil
2025-06-12 16:50:47 +08:00
}