package starnet import ( "bytes" "context" "crypto/tls" "encoding/binary" "errors" "io" "net" "sync" "time" ) // replayConn replays buffered bytes first, then reads from live conn. type replayConn struct { reader io.Reader conn net.Conn } func newReplayConn(buffered io.Reader, conn net.Conn) *replayConn { return &replayConn{ reader: io.MultiReader(buffered, conn), conn: conn, } } func (c *replayConn) Read(p []byte) (int, error) { return c.reader.Read(p) } func (c *replayConn) Write(p []byte) (int, error) { return c.conn.Write(p) } func (c *replayConn) Close() error { return c.conn.Close() } func (c *replayConn) LocalAddr() net.Addr { return c.conn.LocalAddr() } func (c *replayConn) RemoteAddr() net.Addr { return c.conn.RemoteAddr() } func (c *replayConn) SetDeadline(t time.Time) error { return c.conn.SetDeadline(t) } func (c *replayConn) SetReadDeadline(t time.Time) error { return c.conn.SetReadDeadline(t) } func (c *replayConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWriteDeadline(t) } // SniffResult describes protocol sniffing result. type SniffResult struct { IsTLS bool ClientHello *ClientHelloMeta Buffer *bytes.Buffer } // Sniffer detects protocol and metadata from initial bytes. type Sniffer interface { Sniff(conn net.Conn, maxBytes int) (SniffResult, error) } // TLSSniffer is the default sniffer implementation. type TLSSniffer struct{} // Sniff detects TLS and extracts SNI when possible. func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) { if maxBytes <= 0 { maxBytes = 64 * 1024 } var buf bytes.Buffer limited := &io.LimitedReader{R: conn, N: int64(maxBytes)} meta, isTLS := sniffClientHello(limited, &buf, conn) out := SniffResult{ IsTLS: isTLS, Buffer: bytes.NewBuffer(append([]byte(nil), buf.Bytes()...)), } if isTLS { out.ClientHello = meta } return out, nil } func sniffClientHello(r io.Reader, buf *bytes.Buffer, conn net.Conn) (*ClientHelloMeta, bool) { meta := &ClientHelloMeta{ LocalAddr: conn.LocalAddr(), RemoteAddr: conn.RemoteAddr(), } header, complete := readTLSRecordHeader(r, buf) if len(header) < 3 { return nil, false } isTLS := header[0] == 0x16 && header[1] == 0x03 if !isTLS { return nil, false } if len(header) < 5 || !complete { return meta, true } recordLen := int(binary.BigEndian.Uint16(header[3:5])) recordBody, bodyOK := readBufferedBytes(r, buf, recordLen) if !bodyOK { return meta, true } if len(recordBody) < 4 || recordBody[0] != 0x01 { return nil, false } helloLen := int(recordBody[1])<<16 | int(recordBody[2])<<8 | int(recordBody[3]) helloBytes := append([]byte(nil), recordBody[4:]...) for len(helloBytes) < helloLen { nextHeader, nextOK := readTLSRecordHeader(r, buf) if len(nextHeader) < 5 || !nextOK { return meta, true } if nextHeader[0] != 0x16 || nextHeader[1] != 0x03 { return meta, true } nextLen := int(binary.BigEndian.Uint16(nextHeader[3:5])) nextBody, nextBodyOK := readBufferedBytes(r, buf, nextLen) if !nextBodyOK { return meta, true } helloBytes = append(helloBytes, nextBody...) } parseClientHelloBody(meta, helloBytes[:helloLen]) return meta, true } func readTLSRecordHeader(r io.Reader, buf *bytes.Buffer) ([]byte, bool) { return readBufferedBytes(r, buf, 5) } func readBufferedBytes(r io.Reader, buf *bytes.Buffer, n int) ([]byte, bool) { if n <= 0 { return nil, true } tmp := make([]byte, n) readN, err := io.ReadFull(r, tmp) if readN > 0 { buf.Write(tmp[:readN]) } return append([]byte(nil), tmp[:readN]...), err == nil } func parseClientHelloBody(meta *ClientHelloMeta, body []byte) { if meta == nil || len(body) < 34 { return } offset := 2 + 32 sessionIDLen := int(body[offset]) offset++ if offset+sessionIDLen > len(body) { return } offset += sessionIDLen if offset+2 > len(body) { return } cipherSuitesLen := int(binary.BigEndian.Uint16(body[offset : offset+2])) offset += 2 if offset+cipherSuitesLen > len(body) { return } for i := 0; i+1 < cipherSuitesLen; i += 2 { meta.CipherSuites = append(meta.CipherSuites, binary.BigEndian.Uint16(body[offset+i:offset+i+2])) } offset += cipherSuitesLen if offset >= len(body) { return } compressionMethodsLen := int(body[offset]) offset++ if offset+compressionMethodsLen > len(body) { return } offset += compressionMethodsLen if offset+2 > len(body) { return } extensionsLen := int(binary.BigEndian.Uint16(body[offset : offset+2])) offset += 2 if offset+extensionsLen > len(body) { return } parseClientHelloExtensions(meta, body[offset:offset+extensionsLen]) } func parseClientHelloExtensions(meta *ClientHelloMeta, exts []byte) { for offset := 0; offset+4 <= len(exts); { extType := binary.BigEndian.Uint16(exts[offset : offset+2]) extLen := int(binary.BigEndian.Uint16(exts[offset+2 : offset+4])) offset += 4 if offset+extLen > len(exts) { return } extData := exts[offset : offset+extLen] offset += extLen switch extType { case 0: parseServerNameExtension(meta, extData) case 16: parseALPNExtension(meta, extData) case 43: parseSupportedVersionsExtension(meta, extData) } } } func parseServerNameExtension(meta *ClientHelloMeta, data []byte) { if len(data) < 2 { return } listLen := int(binary.BigEndian.Uint16(data[:2])) if listLen == 0 || 2+listLen > len(data) { return } list := data[2 : 2+listLen] for offset := 0; offset+3 <= len(list); { nameType := list[offset] nameLen := int(binary.BigEndian.Uint16(list[offset+1 : offset+3])) offset += 3 if offset+nameLen > len(list) { return } if nameType == 0 { meta.ServerName = string(list[offset : offset+nameLen]) return } offset += nameLen } } func parseALPNExtension(meta *ClientHelloMeta, data []byte) { if len(data) < 2 { return } listLen := int(binary.BigEndian.Uint16(data[:2])) if listLen == 0 || 2+listLen > len(data) { return } list := data[2 : 2+listLen] for offset := 0; offset < len(list); { nameLen := int(list[offset]) offset++ if offset+nameLen > len(list) { return } meta.SupportedProtos = append(meta.SupportedProtos, string(list[offset:offset+nameLen])) offset += nameLen } } func parseSupportedVersionsExtension(meta *ClientHelloMeta, data []byte) { if len(data) < 1 { return } listLen := int(data[0]) if listLen == 0 || 1+listLen > len(data) { return } list := data[1 : 1+listLen] for offset := 0; offset+1 < len(list); offset += 2 { meta.SupportedVersions = append(meta.SupportedVersions, binary.BigEndian.Uint16(list[offset:offset+2])) } } // Conn wraps net.Conn with lazy protocol initialization. type Conn struct { net.Conn once sync.Once initErr error closeOnce sync.Once isTLS bool tlsConn *tls.Conn plainConn net.Conn clientHello *ClientHelloMeta baseTLSConfig *tls.Config getConfigForClient GetConfigForClientFunc getConfigForClientHello GetConfigForClientHelloFunc allowNonTLS bool sniffer Sniffer sniffTimeout time.Duration maxClientHello int logger Logger stats *Stats skipSniff bool } func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn { return &Conn{ Conn: raw, plainConn: raw, baseTLSConfig: cfg.BaseTLSConfig, getConfigForClient: cfg.GetConfigForClient, getConfigForClientHello: cfg.GetConfigForClientHello, allowNonTLS: cfg.AllowNonTLS, sniffer: TLSSniffer{}, sniffTimeout: cfg.SniffTimeout, maxClientHello: cfg.MaxClientHelloBytes, logger: cfg.Logger, stats: stats, } } func (c *Conn) init() { c.once.Do(func() { if c.skipSniff { return } if c.baseTLSConfig == nil && c.getConfigForClient == nil && c.getConfigForClientHello == nil { c.isTLS = false return } if c.sniffTimeout > 0 { _ = c.Conn.SetReadDeadline(time.Now().Add(c.sniffTimeout)) } res, err := c.sniffer.Sniff(c.Conn, c.maxClientHello) if c.sniffTimeout > 0 { _ = c.Conn.SetReadDeadline(time.Time{}) } if err != nil { c.initErr = errors.Join(ErrTLSSniffFailed, err) c.failSniff(err) return } c.isTLS = res.IsTLS c.clientHello = res.ClientHello if c.isTLS { if c.stats != nil { c.stats.incTLSDetected() } tlsCfg, errCfg := c.selectTLSConfig() if errCfg != nil { c.initErr = errors.Join(ErrTLSConfigSelectionFailed, errCfg) c.failTLSConfigSelection(errCfg) return } rc := newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn) c.tlsConn = tls.Server(rc, tlsCfg) return } if c.stats != nil { c.stats.incPlainDetected() } if !c.allowNonTLS { c.initErr = ErrNonTLSNotAllowed c.failPlainRejected() return } c.plainConn = newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn) }) } func (c *Conn) failAndClose(format string, v ...interface{}) { if c.logger != nil { c.logger.Printf("starnet: "+format, v...) } _ = c.Close() } func (c *Conn) failSniff(err error) { if c.stats != nil { c.stats.incSniffFailures() } c.failAndClose("tls sniff failed: %v", err) } func (c *Conn) failTLSConfigSelection(err error) { if c.stats != nil { c.stats.incTLSConfigFailures() } c.failAndClose("tls config selection failed: %v", err) } func (c *Conn) failPlainRejected() { if c.stats != nil { c.stats.incPlainRejected() } c.failAndClose("plain tcp rejected") } func (c *Conn) selectTLSConfig() (*tls.Config, error) { var selected *tls.Config if c.getConfigForClientHello != nil { cfg, err := c.getConfigForClientHello(c.clientHello.Clone()) if err != nil { return nil, err } if cfg != nil { selected = cfg } } if selected == nil && c.getConfigForClient != nil { cfg, err := c.getConfigForClient(c.serverName()) if err != nil { return nil, err } if cfg != nil { selected = cfg } } composed := composeServerTLSConfig(c.baseTLSConfig, selected) if composed != nil { return composed, nil } return nil, ErrNoTLSConfig } // Hostname returns sniffed SNI hostname (if any). func (c *Conn) Hostname() string { c.init() return c.serverName() } // ClientHello returns sniffed TLS metadata (if any). func (c *Conn) ClientHello() *ClientHelloMeta { c.init() return c.clientHello.Clone() } func (c *Conn) serverName() string { if c.clientHello == nil { return "" } return c.clientHello.ServerName } func composeServerTLSConfig(base, selected *tls.Config) *tls.Config { if base == nil { return selected } if selected == nil { return base } out := base.Clone() applyServerTLSOverrides(out, selected) return out } func applyServerTLSOverrides(dst, src *tls.Config) { if dst == nil || src == nil { return } if src.Rand != nil { dst.Rand = src.Rand } if src.Time != nil { dst.Time = src.Time } if len(src.Certificates) > 0 { dst.Certificates = append([]tls.Certificate(nil), src.Certificates...) } if len(src.NameToCertificate) > 0 { m := make(map[string]*tls.Certificate, len(src.NameToCertificate)) for k, v := range src.NameToCertificate { m[k] = v } dst.NameToCertificate = m } if src.GetCertificate != nil { dst.GetCertificate = src.GetCertificate } if src.GetClientCertificate != nil { dst.GetClientCertificate = src.GetClientCertificate } if src.GetConfigForClient != nil { dst.GetConfigForClient = src.GetConfigForClient } if src.VerifyPeerCertificate != nil { dst.VerifyPeerCertificate = src.VerifyPeerCertificate } if src.VerifyConnection != nil { dst.VerifyConnection = src.VerifyConnection } if src.RootCAs != nil { dst.RootCAs = src.RootCAs } if len(src.NextProtos) > 0 { dst.NextProtos = append([]string(nil), src.NextProtos...) } if src.ServerName != "" { dst.ServerName = src.ServerName } if src.ClientAuth > dst.ClientAuth { dst.ClientAuth = src.ClientAuth } if src.ClientCAs != nil { dst.ClientCAs = src.ClientCAs } if src.InsecureSkipVerify { dst.InsecureSkipVerify = true } if len(src.CipherSuites) > 0 { dst.CipherSuites = append([]uint16(nil), src.CipherSuites...) } if src.PreferServerCipherSuites { dst.PreferServerCipherSuites = true } if src.SessionTicketsDisabled { dst.SessionTicketsDisabled = true } if src.SessionTicketKey != ([32]byte{}) { dst.SessionTicketKey = src.SessionTicketKey } if src.ClientSessionCache != nil { dst.ClientSessionCache = src.ClientSessionCache } if src.UnwrapSession != nil { dst.UnwrapSession = src.UnwrapSession } if src.WrapSession != nil { dst.WrapSession = src.WrapSession } if src.MinVersion != 0 && (dst.MinVersion == 0 || src.MinVersion > dst.MinVersion) { dst.MinVersion = src.MinVersion } if src.MaxVersion != 0 && (dst.MaxVersion == 0 || src.MaxVersion < dst.MaxVersion) { dst.MaxVersion = src.MaxVersion } if len(src.CurvePreferences) > 0 { dst.CurvePreferences = append([]tls.CurveID(nil), src.CurvePreferences...) } if src.DynamicRecordSizingDisabled { dst.DynamicRecordSizingDisabled = true } if src.Renegotiation != 0 { dst.Renegotiation = src.Renegotiation } if src.KeyLogWriter != nil { dst.KeyLogWriter = src.KeyLogWriter } if len(src.EncryptedClientHelloConfigList) > 0 { dst.EncryptedClientHelloConfigList = append([]byte(nil), src.EncryptedClientHelloConfigList...) } if src.EncryptedClientHelloRejectionVerify != nil { dst.EncryptedClientHelloRejectionVerify = src.EncryptedClientHelloRejectionVerify } if src.GetEncryptedClientHelloKeys != nil { dst.GetEncryptedClientHelloKeys = src.GetEncryptedClientHelloKeys } if len(src.EncryptedClientHelloKeys) > 0 { dst.EncryptedClientHelloKeys = append([]tls.EncryptedClientHelloKey(nil), src.EncryptedClientHelloKeys...) } } func (c *Conn) IsTLS() bool { c.init() return c.initErr == nil && c.isTLS } func (c *Conn) TLSConn() (*tls.Conn, error) { c.init() if c.initErr != nil { return nil, c.initErr } if !c.isTLS || c.tlsConn == nil { return nil, ErrNotTLS } return c.tlsConn, nil } func (c *Conn) Read(b []byte) (int, error) { c.init() if c.initErr != nil { return 0, c.initErr } if c.isTLS { return c.tlsConn.Read(b) } return c.plainConn.Read(b) } func (c *Conn) Write(b []byte) (int, error) { c.init() if c.initErr != nil { return 0, c.initErr } if c.isTLS { return c.tlsConn.Write(b) } return c.plainConn.Write(b) } func (c *Conn) Close() error { var err error c.closeOnce.Do(func() { if c.tlsConn != nil { err = c.tlsConn.Close() } else { err = c.Conn.Close() } if c.stats != nil { c.stats.incClosed() } }) return err } func (c *Conn) SetDeadline(t time.Time) error { c.init() if c.initErr != nil { return c.initErr } if c.isTLS && c.tlsConn != nil { return c.tlsConn.SetDeadline(t) } return c.plainConn.SetDeadline(t) } func (c *Conn) SetReadDeadline(t time.Time) error { c.init() if c.initErr != nil { return c.initErr } if c.isTLS && c.tlsConn != nil { return c.tlsConn.SetReadDeadline(t) } return c.plainConn.SetReadDeadline(t) } func (c *Conn) SetWriteDeadline(t time.Time) error { c.init() if c.initErr != nil { return c.initErr } if c.isTLS && c.tlsConn != nil { return c.tlsConn.SetWriteDeadline(t) } return c.plainConn.SetWriteDeadline(t) } // Listener wraps net.Listener and returns starnet.Conn from Accept. type Listener struct { net.Listener mu sync.RWMutex cfg ListenerConfig stats Stats } // Listen creates a plain listener config (no TLS detection). func Listen(network, address string) (*Listener, error) { ln, err := net.Listen(network, address) if err != nil { return nil, err } cfg := DefaultListenerConfig() cfg.AllowNonTLS = true cfg.BaseTLSConfig = nil cfg.GetConfigForClient = nil return &Listener{Listener: ln, cfg: cfg}, nil } // ListenWithConfig creates a listener with full config. func ListenWithConfig(network, address string, cfg ListenerConfig) (*Listener, error) { ln, err := net.Listen(network, address) if err != nil { return nil, err } return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil } // ListenWithListenConfig creates listener using net.ListenConfig. func ListenWithListenConfig(lc net.ListenConfig, network, address string, cfg ListenerConfig) (*Listener, error) { ln, err := lc.Listen(context.Background(), network, address) if err != nil { return nil, err } return &Listener{Listener: ln, cfg: normalizeConfig(cfg)}, nil } // ListenTLS creates TLS listener from cert/key paths. func ListenTLS(network, address, certFile, keyFile string, allowNonTLS bool) (*Listener, error) { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return nil, err } cfg := DefaultListenerConfig() cfg.AllowNonTLS = allowNonTLS cfg.BaseTLSConfig = TLSDefaults() cfg.BaseTLSConfig.Certificates = []tls.Certificate{cert} return ListenWithConfig(network, address, cfg) } func normalizeConfig(cfg ListenerConfig) ListenerConfig { out := DefaultListenerConfig() out.AllowNonTLS = cfg.AllowNonTLS out.SniffTimeout = cfg.SniffTimeout out.MaxClientHelloBytes = cfg.MaxClientHelloBytes out.BaseTLSConfig = cfg.BaseTLSConfig out.GetConfigForClient = cfg.GetConfigForClient out.GetConfigForClientHello = cfg.GetConfigForClientHello out.Logger = cfg.Logger if out.MaxClientHelloBytes <= 0 { out.MaxClientHelloBytes = 64 * 1024 } return out } // SetConfig atomically replaces listener config for new accepted connections. func (l *Listener) SetConfig(cfg ListenerConfig) { l.mu.Lock() l.cfg = normalizeConfig(cfg) l.mu.Unlock() } // Config returns a copy of current config. func (l *Listener) Config() ListenerConfig { l.mu.RLock() cfg := l.cfg l.mu.RUnlock() return cfg } // Stats returns current counters snapshot. func (l *Listener) Stats() StatsSnapshot { return l.stats.Snapshot() } func (l *Listener) Accept() (net.Conn, error) { raw, err := l.Listener.Accept() if err != nil { return nil, err } l.stats.incAccepted() l.mu.RLock() cfg := l.cfg l.mu.RUnlock() return newConn(raw, cfg, &l.stats), nil } // AcceptContext supports cancellation by closing accepted conn when ctx is done early. func (l *Listener) AcceptContext(ctx context.Context) (net.Conn, error) { type result struct { c net.Conn err error } ch := make(chan result, 1) go func() { c, err := l.Accept() ch <- result{c: c, err: err} }() select { case <-ctx.Done(): return nil, ctx.Err() case r := <-ch: return r.c, r.err } } // Dial creates a plain TCP starnet.Conn. func Dial(network, address string) (*Conn, error) { raw, err := net.Dial(network, address) if err != nil { return nil, err } cfg := DefaultListenerConfig() cfg.AllowNonTLS = true cfg.BaseTLSConfig = nil cfg.GetConfigForClient = nil c := newConn(raw, cfg, nil) c.isTLS = false return c, nil } // DialWithConfig dials with net.Dialer options. func DialWithConfig(network, address string, dc DialConfig) (*Conn, error) { d := net.Dialer{ Timeout: dc.Timeout, LocalAddr: dc.LocalAddr, } raw, err := d.Dial(network, address) if err != nil { return nil, err } cfg := DefaultListenerConfig() cfg.AllowNonTLS = true c := newConn(raw, cfg, nil) c.isTLS = false return c, nil } // DialTLSWithConfig creates a TLS client connection wrapper. func DialTLSWithConfig(network, address string, tlsCfg *tls.Config, timeout time.Duration) (*Conn, error) { d := net.Dialer{Timeout: timeout} raw, err := d.Dial(network, address) if err != nil { return nil, err } tc := tls.Client(raw, tlsCfg) return &Conn{ Conn: raw, plainConn: raw, isTLS: true, tlsConn: tc, initErr: nil, allowNonTLS: false, skipSniff: true, }, nil } // DialTLS creates TLS client conn from cert/key paths. func DialTLS(network, address, certFile, keyFile string) (*Conn, error) { cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return nil, err } cfg := TLSDefaults() cfg.Certificates = []tls.Certificate{cert} return DialTLSWithConfig(network, address, cfg, 0) } func WrapListener(listener net.Listener, cfg ListenerConfig) (*Listener, error) { if listener == nil { return nil, ErrNilConn } return &Listener{ Listener: listener, cfg: normalizeConfig(cfg), }, nil }