Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9ac9b65bc5 |
@ -68,6 +68,12 @@ var (
|
|||||||
// ErrNilConn indicates a nil net.Conn argument.
|
// ErrNilConn indicates a nil net.Conn argument.
|
||||||
ErrNilConn = errors.New("starnet: nil connection")
|
ErrNilConn = errors.New("starnet: nil connection")
|
||||||
|
|
||||||
|
// ErrTLSSniffFailed indicates TLS sniffing/parsing failed before handshake setup.
|
||||||
|
ErrTLSSniffFailed = errors.New("starnet: tls sniff failed")
|
||||||
|
|
||||||
|
// ErrTLSConfigSelectionFailed indicates dynamic TLS config selection failed.
|
||||||
|
ErrTLSConfigSelectionFailed = errors.New("starnet: tls config selection failed")
|
||||||
|
|
||||||
// ErrNonTLSNotAllowed indicates plain TCP was detected while non-TLS is forbidden.
|
// ErrNonTLSNotAllowed indicates plain TCP was detected while non-TLS is forbidden.
|
||||||
ErrNonTLSNotAllowed = errors.New("starnet: non-TLS connection not allowed")
|
ErrNonTLSNotAllowed = errors.New("starnet: non-TLS connection not allowed")
|
||||||
|
|
||||||
@ -179,7 +185,8 @@ func IsTLS(err error) bool {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if errors.Is(err, ErrNotTLS) || errors.Is(err, ErrNoTLSConfig) || errors.Is(err, ErrNonTLSNotAllowed) {
|
if errors.Is(err, ErrNotTLS) || errors.Is(err, ErrNoTLSConfig) || errors.Is(err, ErrNonTLSNotAllowed) ||
|
||||||
|
errors.Is(err, ErrTLSSniffFailed) || errors.Is(err, ErrTLSConfigSelectionFailed) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
37
tlsconfig.go
37
tlsconfig.go
@ -9,14 +9,49 @@ import (
|
|||||||
// GetConfigForClientFunc selects TLS config by hostname/SNI.
|
// GetConfigForClientFunc selects TLS config by hostname/SNI.
|
||||||
type GetConfigForClientFunc func(hostname string) (*tls.Config, error)
|
type GetConfigForClientFunc func(hostname string) (*tls.Config, error)
|
||||||
|
|
||||||
|
// ClientHelloMeta carries sniffed TLS metadata and connection context.
|
||||||
|
type ClientHelloMeta struct {
|
||||||
|
ServerName string
|
||||||
|
LocalAddr net.Addr
|
||||||
|
RemoteAddr net.Addr
|
||||||
|
SupportedProtos []string
|
||||||
|
SupportedVersions []uint16
|
||||||
|
CipherSuites []uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone returns a detached copy safe for callers to mutate.
|
||||||
|
func (m *ClientHelloMeta) Clone() *ClientHelloMeta {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := *m
|
||||||
|
if m.SupportedProtos != nil {
|
||||||
|
out.SupportedProtos = append([]string(nil), m.SupportedProtos...)
|
||||||
|
}
|
||||||
|
if m.SupportedVersions != nil {
|
||||||
|
out.SupportedVersions = append([]uint16(nil), m.SupportedVersions...)
|
||||||
|
}
|
||||||
|
if m.CipherSuites != nil {
|
||||||
|
out.CipherSuites = append([]uint16(nil), m.CipherSuites...)
|
||||||
|
}
|
||||||
|
return &out
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfigForClientHelloFunc selects TLS config by sniffed TLS metadata.
|
||||||
|
type GetConfigForClientHelloFunc func(hello *ClientHelloMeta) (*tls.Config, error)
|
||||||
|
|
||||||
// ListenerConfig controls listener behavior.
|
// ListenerConfig controls listener behavior.
|
||||||
type ListenerConfig struct {
|
type ListenerConfig struct {
|
||||||
// BaseTLSConfig is used for TLS when dynamic selection returns nil.
|
// BaseTLSConfig is used for TLS when dynamic selection returns nil.
|
||||||
BaseTLSConfig *tls.Config
|
BaseTLSConfig *tls.Config
|
||||||
|
|
||||||
// GetConfigForClient selects TLS config for a hostname.
|
// GetConfigForClient selects TLS config for a hostname/SNI.
|
||||||
|
// Deprecated: prefer GetConfigForClientHello for richer context.
|
||||||
GetConfigForClient GetConfigForClientFunc
|
GetConfigForClient GetConfigForClientFunc
|
||||||
|
|
||||||
|
// GetConfigForClientHello selects TLS config for sniffed TLS metadata.
|
||||||
|
GetConfigForClientHello GetConfigForClientHelloFunc
|
||||||
|
|
||||||
// AllowNonTLS allows plain TCP fallback.
|
// AllowNonTLS allows plain TCP fallback.
|
||||||
AllowNonTLS bool
|
AllowNonTLS bool
|
||||||
|
|
||||||
|
|||||||
427
tlssniffer.go
427
tlssniffer.go
@ -4,6 +4,8 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
@ -35,7 +37,7 @@ func (c *replayConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWrit
|
|||||||
// SniffResult describes protocol sniffing result.
|
// SniffResult describes protocol sniffing result.
|
||||||
type SniffResult struct {
|
type SniffResult struct {
|
||||||
IsTLS bool
|
IsTLS bool
|
||||||
Hostname string
|
ClientHello *ClientHelloMeta
|
||||||
Buffer *bytes.Buffer
|
Buffer *bytes.Buffer
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,44 +57,210 @@ func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) {
|
|||||||
|
|
||||||
var buf bytes.Buffer
|
var buf bytes.Buffer
|
||||||
limited := &io.LimitedReader{R: conn, N: int64(maxBytes)}
|
limited := &io.LimitedReader{R: conn, N: int64(maxBytes)}
|
||||||
tee := io.TeeReader(limited, &buf)
|
meta, isTLS := sniffClientHello(limited, &buf, conn)
|
||||||
|
|
||||||
var hello *tls.ClientHelloInfo
|
|
||||||
_ = tls.Server(readOnlyConn{r: tee, raw: conn}, &tls.Config{
|
|
||||||
GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
||||||
cp := *ch
|
|
||||||
hello = &cp
|
|
||||||
return nil, nil
|
|
||||||
},
|
|
||||||
}).Handshake()
|
|
||||||
|
|
||||||
peek := buf.Bytes()
|
|
||||||
isTLS := len(peek) >= 3 && peek[0] == 0x16 && peek[1] == 0x03
|
|
||||||
|
|
||||||
out := SniffResult{
|
out := SniffResult{
|
||||||
IsTLS: isTLS,
|
IsTLS: isTLS,
|
||||||
Buffer: bytes.NewBuffer(append([]byte(nil), peek...)),
|
Buffer: bytes.NewBuffer(append([]byte(nil), buf.Bytes()...)),
|
||||||
}
|
}
|
||||||
if hello != nil {
|
if isTLS {
|
||||||
out.Hostname = hello.ServerName
|
out.ClientHello = meta
|
||||||
}
|
}
|
||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// readOnlyConn rejects writes/close and reads from a reader.
|
func sniffClientHello(r io.Reader, buf *bytes.Buffer, conn net.Conn) (*ClientHelloMeta, bool) {
|
||||||
type readOnlyConn struct {
|
meta := &ClientHelloMeta{
|
||||||
r io.Reader
|
LocalAddr: conn.LocalAddr(),
|
||||||
raw net.Conn
|
RemoteAddr: conn.RemoteAddr(),
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c readOnlyConn) Read(p []byte) (int, error) { return c.r.Read(p) }
|
header, complete := readTLSRecordHeader(r, buf)
|
||||||
func (c readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe }
|
if len(header) < 3 {
|
||||||
func (c readOnlyConn) Close() error { return nil }
|
return nil, false
|
||||||
func (c readOnlyConn) LocalAddr() net.Addr { return c.raw.LocalAddr() }
|
}
|
||||||
func (c readOnlyConn) RemoteAddr() net.Addr { return c.raw.RemoteAddr() }
|
isTLS := header[0] == 0x16 && header[1] == 0x03
|
||||||
func (c readOnlyConn) SetDeadline(_ time.Time) error { return nil }
|
if !isTLS {
|
||||||
func (c readOnlyConn) SetReadDeadline(_ time.Time) error { return nil }
|
return nil, false
|
||||||
func (c readOnlyConn) SetWriteDeadline(_ time.Time) error { return nil }
|
}
|
||||||
|
if len(header) < 5 || !complete {
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
|
||||||
|
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
|
||||||
|
recordBody, bodyOK := readBufferedBytes(r, buf, recordLen)
|
||||||
|
if !bodyOK {
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
if len(recordBody) < 4 || recordBody[0] != 0x01 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
helloLen := int(recordBody[1])<<16 | int(recordBody[2])<<8 | int(recordBody[3])
|
||||||
|
helloBytes := append([]byte(nil), recordBody[4:]...)
|
||||||
|
for len(helloBytes) < helloLen {
|
||||||
|
nextHeader, nextOK := readTLSRecordHeader(r, buf)
|
||||||
|
if len(nextHeader) < 5 || !nextOK {
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
if nextHeader[0] != 0x16 || nextHeader[1] != 0x03 {
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
nextLen := int(binary.BigEndian.Uint16(nextHeader[3:5]))
|
||||||
|
nextBody, nextBodyOK := readBufferedBytes(r, buf, nextLen)
|
||||||
|
if !nextBodyOK {
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
helloBytes = append(helloBytes, nextBody...)
|
||||||
|
}
|
||||||
|
|
||||||
|
parseClientHelloBody(meta, helloBytes[:helloLen])
|
||||||
|
return meta, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func readTLSRecordHeader(r io.Reader, buf *bytes.Buffer) ([]byte, bool) {
|
||||||
|
return readBufferedBytes(r, buf, 5)
|
||||||
|
}
|
||||||
|
|
||||||
|
func readBufferedBytes(r io.Reader, buf *bytes.Buffer, n int) ([]byte, bool) {
|
||||||
|
if n <= 0 {
|
||||||
|
return nil, true
|
||||||
|
}
|
||||||
|
tmp := make([]byte, n)
|
||||||
|
readN, err := io.ReadFull(r, tmp)
|
||||||
|
if readN > 0 {
|
||||||
|
buf.Write(tmp[:readN])
|
||||||
|
}
|
||||||
|
return append([]byte(nil), tmp[:readN]...), err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseClientHelloBody(meta *ClientHelloMeta, body []byte) {
|
||||||
|
if meta == nil || len(body) < 34 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
offset := 2 + 32
|
||||||
|
sessionIDLen := int(body[offset])
|
||||||
|
offset++
|
||||||
|
if offset+sessionIDLen > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
offset += sessionIDLen
|
||||||
|
|
||||||
|
if offset+2 > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cipherSuitesLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
|
||||||
|
offset += 2
|
||||||
|
if offset+cipherSuitesLen > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for i := 0; i+1 < cipherSuitesLen; i += 2 {
|
||||||
|
meta.CipherSuites = append(meta.CipherSuites, binary.BigEndian.Uint16(body[offset+i:offset+i+2]))
|
||||||
|
}
|
||||||
|
offset += cipherSuitesLen
|
||||||
|
|
||||||
|
if offset >= len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
compressionMethodsLen := int(body[offset])
|
||||||
|
offset++
|
||||||
|
if offset+compressionMethodsLen > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
offset += compressionMethodsLen
|
||||||
|
|
||||||
|
if offset+2 > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
extensionsLen := int(binary.BigEndian.Uint16(body[offset : offset+2]))
|
||||||
|
offset += 2
|
||||||
|
if offset+extensionsLen > len(body) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
parseClientHelloExtensions(meta, body[offset:offset+extensionsLen])
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseClientHelloExtensions(meta *ClientHelloMeta, exts []byte) {
|
||||||
|
for offset := 0; offset+4 <= len(exts); {
|
||||||
|
extType := binary.BigEndian.Uint16(exts[offset : offset+2])
|
||||||
|
extLen := int(binary.BigEndian.Uint16(exts[offset+2 : offset+4]))
|
||||||
|
offset += 4
|
||||||
|
if offset+extLen > len(exts) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
extData := exts[offset : offset+extLen]
|
||||||
|
offset += extLen
|
||||||
|
|
||||||
|
switch extType {
|
||||||
|
case 0:
|
||||||
|
parseServerNameExtension(meta, extData)
|
||||||
|
case 16:
|
||||||
|
parseALPNExtension(meta, extData)
|
||||||
|
case 43:
|
||||||
|
parseSupportedVersionsExtension(meta, extData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseServerNameExtension(meta *ClientHelloMeta, data []byte) {
|
||||||
|
if len(data) < 2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
listLen := int(binary.BigEndian.Uint16(data[:2]))
|
||||||
|
if listLen == 0 || 2+listLen > len(data) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
list := data[2 : 2+listLen]
|
||||||
|
for offset := 0; offset+3 <= len(list); {
|
||||||
|
nameType := list[offset]
|
||||||
|
nameLen := int(binary.BigEndian.Uint16(list[offset+1 : offset+3]))
|
||||||
|
offset += 3
|
||||||
|
if offset+nameLen > len(list) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if nameType == 0 {
|
||||||
|
meta.ServerName = string(list[offset : offset+nameLen])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
offset += nameLen
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseALPNExtension(meta *ClientHelloMeta, data []byte) {
|
||||||
|
if len(data) < 2 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
listLen := int(binary.BigEndian.Uint16(data[:2]))
|
||||||
|
if listLen == 0 || 2+listLen > len(data) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
list := data[2 : 2+listLen]
|
||||||
|
for offset := 0; offset < len(list); {
|
||||||
|
nameLen := int(list[offset])
|
||||||
|
offset++
|
||||||
|
if offset+nameLen > len(list) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
meta.SupportedProtos = append(meta.SupportedProtos, string(list[offset:offset+nameLen]))
|
||||||
|
offset += nameLen
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSupportedVersionsExtension(meta *ClientHelloMeta, data []byte) {
|
||||||
|
if len(data) < 1 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
listLen := int(data[0])
|
||||||
|
if listLen == 0 || 1+listLen > len(data) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
list := data[1 : 1+listLen]
|
||||||
|
for offset := 0; offset+1 < len(list); offset += 2 {
|
||||||
|
meta.SupportedVersions = append(meta.SupportedVersions, binary.BigEndian.Uint16(list[offset:offset+2]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Conn wraps net.Conn with lazy protocol initialization.
|
// Conn wraps net.Conn with lazy protocol initialization.
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
@ -106,10 +274,11 @@ type Conn struct {
|
|||||||
tlsConn *tls.Conn
|
tlsConn *tls.Conn
|
||||||
plainConn net.Conn
|
plainConn net.Conn
|
||||||
|
|
||||||
hostname string
|
clientHello *ClientHelloMeta
|
||||||
|
|
||||||
baseTLSConfig *tls.Config
|
baseTLSConfig *tls.Config
|
||||||
getConfigForClient GetConfigForClientFunc
|
getConfigForClient GetConfigForClientFunc
|
||||||
|
getConfigForClientHello GetConfigForClientHelloFunc
|
||||||
allowNonTLS bool
|
allowNonTLS bool
|
||||||
sniffer Sniffer
|
sniffer Sniffer
|
||||||
sniffTimeout time.Duration
|
sniffTimeout time.Duration
|
||||||
@ -125,6 +294,7 @@ func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn {
|
|||||||
plainConn: raw,
|
plainConn: raw,
|
||||||
baseTLSConfig: cfg.BaseTLSConfig,
|
baseTLSConfig: cfg.BaseTLSConfig,
|
||||||
getConfigForClient: cfg.GetConfigForClient,
|
getConfigForClient: cfg.GetConfigForClient,
|
||||||
|
getConfigForClientHello: cfg.GetConfigForClientHello,
|
||||||
allowNonTLS: cfg.AllowNonTLS,
|
allowNonTLS: cfg.AllowNonTLS,
|
||||||
sniffer: TLSSniffer{},
|
sniffer: TLSSniffer{},
|
||||||
sniffTimeout: cfg.SniffTimeout,
|
sniffTimeout: cfg.SniffTimeout,
|
||||||
@ -139,7 +309,7 @@ func (c *Conn) init() {
|
|||||||
if c.skipSniff {
|
if c.skipSniff {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if c.baseTLSConfig == nil && c.getConfigForClient == nil {
|
if c.baseTLSConfig == nil && c.getConfigForClient == nil && c.getConfigForClientHello == nil {
|
||||||
c.isTLS = false
|
c.isTLS = false
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -152,13 +322,13 @@ func (c *Conn) init() {
|
|||||||
_ = c.Conn.SetReadDeadline(time.Time{})
|
_ = c.Conn.SetReadDeadline(time.Time{})
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.initErr = err
|
c.initErr = errors.Join(ErrTLSSniffFailed, err)
|
||||||
c.failAndClose("sniff failed: %v", err)
|
c.failSniff(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.isTLS = res.IsTLS
|
c.isTLS = res.IsTLS
|
||||||
c.hostname = res.Hostname
|
c.clientHello = res.ClientHello
|
||||||
|
|
||||||
if c.isTLS {
|
if c.isTLS {
|
||||||
if c.stats != nil {
|
if c.stats != nil {
|
||||||
@ -166,8 +336,8 @@ func (c *Conn) init() {
|
|||||||
}
|
}
|
||||||
tlsCfg, errCfg := c.selectTLSConfig()
|
tlsCfg, errCfg := c.selectTLSConfig()
|
||||||
if errCfg != nil {
|
if errCfg != nil {
|
||||||
c.initErr = errCfg
|
c.initErr = errors.Join(ErrTLSConfigSelectionFailed, errCfg)
|
||||||
c.failAndClose("tls config select failed: %v", errCfg)
|
c.failTLSConfigSelection(errCfg)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
rc := newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
|
rc := newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
|
||||||
@ -180,7 +350,7 @@ func (c *Conn) init() {
|
|||||||
}
|
}
|
||||||
if !c.allowNonTLS {
|
if !c.allowNonTLS {
|
||||||
c.initErr = ErrNonTLSNotAllowed
|
c.initErr = ErrNonTLSNotAllowed
|
||||||
c.failAndClose("plain tcp rejected")
|
c.failPlainRejected()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.plainConn = newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
|
c.plainConn = newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn)
|
||||||
@ -188,27 +358,57 @@ func (c *Conn) init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) failAndClose(format string, v ...interface{}) {
|
func (c *Conn) failAndClose(format string, v ...interface{}) {
|
||||||
if c.stats != nil {
|
|
||||||
c.stats.incInitFailures()
|
|
||||||
}
|
|
||||||
if c.logger != nil {
|
if c.logger != nil {
|
||||||
c.logger.Printf("starnet: "+format, v...)
|
c.logger.Printf("starnet: "+format, v...)
|
||||||
}
|
}
|
||||||
_ = c.Close()
|
_ = c.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conn) failSniff(err error) {
|
||||||
|
if c.stats != nil {
|
||||||
|
c.stats.incSniffFailures()
|
||||||
|
}
|
||||||
|
c.failAndClose("tls sniff failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) failTLSConfigSelection(err error) {
|
||||||
|
if c.stats != nil {
|
||||||
|
c.stats.incTLSConfigFailures()
|
||||||
|
}
|
||||||
|
c.failAndClose("tls config selection failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) failPlainRejected() {
|
||||||
|
if c.stats != nil {
|
||||||
|
c.stats.incPlainRejected()
|
||||||
|
}
|
||||||
|
c.failAndClose("plain tcp rejected")
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conn) selectTLSConfig() (*tls.Config, error) {
|
func (c *Conn) selectTLSConfig() (*tls.Config, error) {
|
||||||
if c.getConfigForClient != nil {
|
var selected *tls.Config
|
||||||
cfg, err := c.getConfigForClient(c.hostname)
|
if c.getConfigForClientHello != nil {
|
||||||
|
cfg, err := c.getConfigForClientHello(c.clientHello.Clone())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
return cfg, nil
|
selected = cfg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if c.baseTLSConfig != nil {
|
if selected == nil && c.getConfigForClient != nil {
|
||||||
return c.baseTLSConfig, nil
|
cfg, err := c.getConfigForClient(c.serverName())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if cfg != nil {
|
||||||
|
selected = cfg
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
composed := composeServerTLSConfig(c.baseTLSConfig, selected)
|
||||||
|
if composed != nil {
|
||||||
|
return composed, nil
|
||||||
}
|
}
|
||||||
return nil, ErrNoTLSConfig
|
return nil, ErrNoTLSConfig
|
||||||
}
|
}
|
||||||
@ -216,7 +416,140 @@ func (c *Conn) selectTLSConfig() (*tls.Config, error) {
|
|||||||
// Hostname returns sniffed SNI hostname (if any).
|
// Hostname returns sniffed SNI hostname (if any).
|
||||||
func (c *Conn) Hostname() string {
|
func (c *Conn) Hostname() string {
|
||||||
c.init()
|
c.init()
|
||||||
return c.hostname
|
return c.serverName()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClientHello returns sniffed TLS metadata (if any).
|
||||||
|
func (c *Conn) ClientHello() *ClientHelloMeta {
|
||||||
|
c.init()
|
||||||
|
return c.clientHello.Clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Conn) serverName() string {
|
||||||
|
if c.clientHello == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return c.clientHello.ServerName
|
||||||
|
}
|
||||||
|
|
||||||
|
func composeServerTLSConfig(base, selected *tls.Config) *tls.Config {
|
||||||
|
if base == nil {
|
||||||
|
return selected
|
||||||
|
}
|
||||||
|
if selected == nil {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
out := base.Clone()
|
||||||
|
applyServerTLSOverrides(out, selected)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyServerTLSOverrides(dst, src *tls.Config) {
|
||||||
|
if dst == nil || src == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if src.Rand != nil {
|
||||||
|
dst.Rand = src.Rand
|
||||||
|
}
|
||||||
|
if src.Time != nil {
|
||||||
|
dst.Time = src.Time
|
||||||
|
}
|
||||||
|
if len(src.Certificates) > 0 {
|
||||||
|
dst.Certificates = append([]tls.Certificate(nil), src.Certificates...)
|
||||||
|
}
|
||||||
|
if len(src.NameToCertificate) > 0 {
|
||||||
|
m := make(map[string]*tls.Certificate, len(src.NameToCertificate))
|
||||||
|
for k, v := range src.NameToCertificate {
|
||||||
|
m[k] = v
|
||||||
|
}
|
||||||
|
dst.NameToCertificate = m
|
||||||
|
}
|
||||||
|
if src.GetCertificate != nil {
|
||||||
|
dst.GetCertificate = src.GetCertificate
|
||||||
|
}
|
||||||
|
if src.GetClientCertificate != nil {
|
||||||
|
dst.GetClientCertificate = src.GetClientCertificate
|
||||||
|
}
|
||||||
|
if src.GetConfigForClient != nil {
|
||||||
|
dst.GetConfigForClient = src.GetConfigForClient
|
||||||
|
}
|
||||||
|
if src.VerifyPeerCertificate != nil {
|
||||||
|
dst.VerifyPeerCertificate = src.VerifyPeerCertificate
|
||||||
|
}
|
||||||
|
if src.VerifyConnection != nil {
|
||||||
|
dst.VerifyConnection = src.VerifyConnection
|
||||||
|
}
|
||||||
|
if src.RootCAs != nil {
|
||||||
|
dst.RootCAs = src.RootCAs
|
||||||
|
}
|
||||||
|
if len(src.NextProtos) > 0 {
|
||||||
|
dst.NextProtos = append([]string(nil), src.NextProtos...)
|
||||||
|
}
|
||||||
|
if src.ServerName != "" {
|
||||||
|
dst.ServerName = src.ServerName
|
||||||
|
}
|
||||||
|
if src.ClientAuth > dst.ClientAuth {
|
||||||
|
dst.ClientAuth = src.ClientAuth
|
||||||
|
}
|
||||||
|
if src.ClientCAs != nil {
|
||||||
|
dst.ClientCAs = src.ClientCAs
|
||||||
|
}
|
||||||
|
if src.InsecureSkipVerify {
|
||||||
|
dst.InsecureSkipVerify = true
|
||||||
|
}
|
||||||
|
if len(src.CipherSuites) > 0 {
|
||||||
|
dst.CipherSuites = append([]uint16(nil), src.CipherSuites...)
|
||||||
|
}
|
||||||
|
if src.PreferServerCipherSuites {
|
||||||
|
dst.PreferServerCipherSuites = true
|
||||||
|
}
|
||||||
|
if src.SessionTicketsDisabled {
|
||||||
|
dst.SessionTicketsDisabled = true
|
||||||
|
}
|
||||||
|
if src.SessionTicketKey != ([32]byte{}) {
|
||||||
|
dst.SessionTicketKey = src.SessionTicketKey
|
||||||
|
}
|
||||||
|
if src.ClientSessionCache != nil {
|
||||||
|
dst.ClientSessionCache = src.ClientSessionCache
|
||||||
|
}
|
||||||
|
if src.UnwrapSession != nil {
|
||||||
|
dst.UnwrapSession = src.UnwrapSession
|
||||||
|
}
|
||||||
|
if src.WrapSession != nil {
|
||||||
|
dst.WrapSession = src.WrapSession
|
||||||
|
}
|
||||||
|
if src.MinVersion != 0 && (dst.MinVersion == 0 || src.MinVersion > dst.MinVersion) {
|
||||||
|
dst.MinVersion = src.MinVersion
|
||||||
|
}
|
||||||
|
if src.MaxVersion != 0 && (dst.MaxVersion == 0 || src.MaxVersion < dst.MaxVersion) {
|
||||||
|
dst.MaxVersion = src.MaxVersion
|
||||||
|
}
|
||||||
|
if len(src.CurvePreferences) > 0 {
|
||||||
|
dst.CurvePreferences = append([]tls.CurveID(nil), src.CurvePreferences...)
|
||||||
|
}
|
||||||
|
if src.DynamicRecordSizingDisabled {
|
||||||
|
dst.DynamicRecordSizingDisabled = true
|
||||||
|
}
|
||||||
|
if src.Renegotiation != 0 {
|
||||||
|
dst.Renegotiation = src.Renegotiation
|
||||||
|
}
|
||||||
|
if src.KeyLogWriter != nil {
|
||||||
|
dst.KeyLogWriter = src.KeyLogWriter
|
||||||
|
}
|
||||||
|
if len(src.EncryptedClientHelloConfigList) > 0 {
|
||||||
|
dst.EncryptedClientHelloConfigList = append([]byte(nil), src.EncryptedClientHelloConfigList...)
|
||||||
|
}
|
||||||
|
if src.EncryptedClientHelloRejectionVerify != nil {
|
||||||
|
dst.EncryptedClientHelloRejectionVerify = src.EncryptedClientHelloRejectionVerify
|
||||||
|
}
|
||||||
|
if src.GetEncryptedClientHelloKeys != nil {
|
||||||
|
dst.GetEncryptedClientHelloKeys = src.GetEncryptedClientHelloKeys
|
||||||
|
}
|
||||||
|
if len(src.EncryptedClientHelloKeys) > 0 {
|
||||||
|
dst.EncryptedClientHelloKeys = append([]tls.EncryptedClientHelloKey(nil), src.EncryptedClientHelloKeys...)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) IsTLS() bool {
|
func (c *Conn) IsTLS() bool {
|
||||||
@ -365,6 +698,7 @@ func normalizeConfig(cfg ListenerConfig) ListenerConfig {
|
|||||||
out.MaxClientHelloBytes = cfg.MaxClientHelloBytes
|
out.MaxClientHelloBytes = cfg.MaxClientHelloBytes
|
||||||
out.BaseTLSConfig = cfg.BaseTLSConfig
|
out.BaseTLSConfig = cfg.BaseTLSConfig
|
||||||
out.GetConfigForClient = cfg.GetConfigForClient
|
out.GetConfigForClient = cfg.GetConfigForClient
|
||||||
|
out.GetConfigForClientHello = cfg.GetConfigForClientHello
|
||||||
out.Logger = cfg.Logger
|
out.Logger = cfg.Logger
|
||||||
if out.MaxClientHelloBytes <= 0 {
|
if out.MaxClientHelloBytes <= 0 {
|
||||||
out.MaxClientHelloBytes = 64 * 1024
|
out.MaxClientHelloBytes = 64 * 1024
|
||||||
@ -471,7 +805,6 @@ func DialTLSWithConfig(network, address string, tlsCfg *tls.Config, timeout time
|
|||||||
plainConn: raw,
|
plainConn: raw,
|
||||||
isTLS: true,
|
isTLS: true,
|
||||||
tlsConn: tc,
|
tlsConn: tc,
|
||||||
hostname: "",
|
|
||||||
initErr: nil,
|
initErr: nil,
|
||||||
allowNonTLS: false,
|
allowNonTLS: false,
|
||||||
skipSniff: true,
|
skipSniff: true,
|
||||||
|
|||||||
@ -1,12 +1,14 @@
|
|||||||
package starnet
|
package starnet
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
"crypto/rsa"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
|
"encoding/binary"
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
@ -18,6 +20,280 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type staticSniffer struct {
|
||||||
|
result SniffResult
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type fixedAddr string
|
||||||
|
|
||||||
|
func (a fixedAddr) Network() string { return "tcp" }
|
||||||
|
func (a fixedAddr) String() string { return string(a) }
|
||||||
|
|
||||||
|
type recordingConn struct {
|
||||||
|
buf bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordingConn) Read([]byte) (int, error) { return 0, io.EOF }
|
||||||
|
func (c *recordingConn) Write(p []byte) (int, error) { return c.buf.Write(p) }
|
||||||
|
func (c *recordingConn) Close() error { return nil }
|
||||||
|
func (c *recordingConn) LocalAddr() net.Addr { return fixedAddr("127.0.0.1:443") }
|
||||||
|
func (c *recordingConn) RemoteAddr() net.Addr { return fixedAddr("127.0.0.1:50000") }
|
||||||
|
func (c *recordingConn) SetDeadline(time.Time) error { return nil }
|
||||||
|
func (c *recordingConn) SetReadDeadline(time.Time) error { return nil }
|
||||||
|
func (c *recordingConn) SetWriteDeadline(time.Time) error { return nil }
|
||||||
|
|
||||||
|
type bytesConn struct {
|
||||||
|
reader io.Reader
|
||||||
|
local net.Addr
|
||||||
|
remote net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *bytesConn) Read(p []byte) (int, error) { return c.reader.Read(p) }
|
||||||
|
func (c *bytesConn) Write(p []byte) (int, error) { return len(p), nil }
|
||||||
|
func (c *bytesConn) Close() error { return nil }
|
||||||
|
func (c *bytesConn) LocalAddr() net.Addr { return c.local }
|
||||||
|
func (c *bytesConn) RemoteAddr() net.Addr { return c.remote }
|
||||||
|
func (c *bytesConn) SetDeadline(time.Time) error { return nil }
|
||||||
|
func (c *bytesConn) SetReadDeadline(time.Time) error { return nil }
|
||||||
|
func (c *bytesConn) SetWriteDeadline(time.Time) error { return nil }
|
||||||
|
|
||||||
|
func (s staticSniffer) Sniff(_ net.Conn, _ int) (SniffResult, error) {
|
||||||
|
return s.result, s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func captureClientHelloBytes(t *testing.T, serverName string, nextProtos []string) []byte {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
conn := &recordingConn{}
|
||||||
|
tc := tls.Client(conn, &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
ServerName: serverName,
|
||||||
|
NextProtos: nextProtos,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
})
|
||||||
|
_ = tc.Handshake()
|
||||||
|
_ = tc.Close()
|
||||||
|
|
||||||
|
out := append([]byte(nil), conn.buf.Bytes()...)
|
||||||
|
if len(out) < 9 {
|
||||||
|
t.Fatalf("captured hello too short: %d", len(out))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func prefixBytes(data []byte, n int) []byte {
|
||||||
|
if n < 0 {
|
||||||
|
n = 0
|
||||||
|
}
|
||||||
|
if len(data) < n {
|
||||||
|
n = len(data)
|
||||||
|
}
|
||||||
|
return data[:n]
|
||||||
|
}
|
||||||
|
|
||||||
|
func splitClientHelloRecord(t *testing.T, data []byte, splitBodyAt int) []byte {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if len(data) < 9 {
|
||||||
|
t.Fatalf("client hello too short: %d", len(data))
|
||||||
|
}
|
||||||
|
if data[0] != 0x16 || data[1] != 0x03 {
|
||||||
|
t.Fatalf("unexpected tls record header: % x", prefixBytes(data, 5))
|
||||||
|
}
|
||||||
|
recordLen := int(binary.BigEndian.Uint16(data[3:5]))
|
||||||
|
if len(data) < 5+recordLen {
|
||||||
|
t.Fatalf("short tls record: have=%d want=%d", len(data), 5+recordLen)
|
||||||
|
}
|
||||||
|
recordBody := append([]byte(nil), data[5:5+recordLen]...)
|
||||||
|
if len(recordBody) < 4 || recordBody[0] != 0x01 {
|
||||||
|
t.Fatalf("unexpected handshake record: % x", prefixBytes(recordBody, 4))
|
||||||
|
}
|
||||||
|
if splitBodyAt <= 4 || splitBodyAt >= len(recordBody) {
|
||||||
|
t.Fatalf("invalid split point %d for body len %d", splitBodyAt, len(recordBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
var out bytes.Buffer
|
||||||
|
writeRecord := func(body []byte) {
|
||||||
|
out.WriteByte(0x16)
|
||||||
|
out.WriteByte(data[1])
|
||||||
|
out.WriteByte(data[2])
|
||||||
|
header := []byte{0, 0}
|
||||||
|
binary.BigEndian.PutUint16(header, uint16(len(body)))
|
||||||
|
out.Write(header)
|
||||||
|
out.Write(body)
|
||||||
|
}
|
||||||
|
|
||||||
|
writeRecord(recordBody[:splitBodyAt])
|
||||||
|
writeRecord(recordBody[splitBodyAt:])
|
||||||
|
return out.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSSnifferParsesClientHello(t *testing.T) {
|
||||||
|
client, server := net.Pipe()
|
||||||
|
defer client.Close()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
tc := tls.Client(client, &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
ServerName: "example.com",
|
||||||
|
NextProtos: []string{"h2", "http/1.1"},
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
})
|
||||||
|
defer tc.Close()
|
||||||
|
done <- tc.Handshake()
|
||||||
|
}()
|
||||||
|
|
||||||
|
_ = server.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
res, err := (TLSSniffer{}).Sniff(server, 64*1024)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sniff: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !res.IsTLS {
|
||||||
|
t.Fatalf("expected TLS")
|
||||||
|
}
|
||||||
|
if res.ClientHello == nil {
|
||||||
|
t.Fatalf("expected client hello metadata")
|
||||||
|
}
|
||||||
|
if res.ClientHello.ServerName != "example.com" {
|
||||||
|
t.Fatalf("unexpected server name: %q", res.ClientHello.ServerName)
|
||||||
|
}
|
||||||
|
if len(res.ClientHello.SupportedProtos) == 0 || res.ClientHello.SupportedProtos[0] != "h2" {
|
||||||
|
t.Fatalf("unexpected ALPN list: %v", res.ClientHello.SupportedProtos)
|
||||||
|
}
|
||||||
|
if len(res.ClientHello.SupportedVersions) == 0 {
|
||||||
|
t.Fatalf("expected supported versions")
|
||||||
|
}
|
||||||
|
if len(res.ClientHello.CipherSuites) == 0 {
|
||||||
|
t.Fatalf("expected cipher suites")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = server.Close()
|
||||||
|
<-done
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSSnifferParsesFragmentedClientHelloRecords(t *testing.T) {
|
||||||
|
rawHello := captureClientHelloBytes(t, "example.com", []string{"h2", "http/1.1"})
|
||||||
|
fragmented := splitClientHelloRecord(t, rawHello, 32)
|
||||||
|
|
||||||
|
conn := &bytesConn{
|
||||||
|
reader: bytes.NewReader(fragmented),
|
||||||
|
local: fixedAddr("127.0.0.1:443"),
|
||||||
|
remote: fixedAddr("127.0.0.1:50000"),
|
||||||
|
}
|
||||||
|
res, err := (TLSSniffer{}).Sniff(conn, 64*1024)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sniff: %v", err)
|
||||||
|
}
|
||||||
|
if !res.IsTLS {
|
||||||
|
t.Fatalf("expected TLS")
|
||||||
|
}
|
||||||
|
if res.ClientHello == nil {
|
||||||
|
t.Fatalf("expected client hello metadata")
|
||||||
|
}
|
||||||
|
if res.ClientHello.ServerName != "example.com" {
|
||||||
|
t.Fatalf("unexpected server name: %q", res.ClientHello.ServerName)
|
||||||
|
}
|
||||||
|
if len(res.ClientHello.SupportedProtos) == 0 || res.ClientHello.SupportedProtos[0] != "h2" {
|
||||||
|
t.Fatalf("unexpected ALPN list: %v", res.ClientHello.SupportedProtos)
|
||||||
|
}
|
||||||
|
if len(res.ClientHello.SupportedVersions) == 0 {
|
||||||
|
t.Fatalf("expected supported versions")
|
||||||
|
}
|
||||||
|
if len(res.ClientHello.CipherSuites) == 0 {
|
||||||
|
t.Fatalf("expected cipher suites")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSSnifferKeepsTruncatedTLSAsTLS(t *testing.T) {
|
||||||
|
rawHello := captureClientHelloBytes(t, "example.com", []string{"h2", "http/1.1"})
|
||||||
|
truncated := append([]byte(nil), rawHello[:20]...)
|
||||||
|
|
||||||
|
conn := &bytesConn{
|
||||||
|
reader: bytes.NewReader(truncated),
|
||||||
|
local: fixedAddr("127.0.0.1:443"),
|
||||||
|
remote: fixedAddr("127.0.0.1:50000"),
|
||||||
|
}
|
||||||
|
res, err := (TLSSniffer{}).Sniff(conn, 64*1024)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sniff: %v", err)
|
||||||
|
}
|
||||||
|
if !res.IsTLS {
|
||||||
|
t.Fatalf("expected truncated hello to stay classified as TLS")
|
||||||
|
}
|
||||||
|
if res.ClientHello == nil {
|
||||||
|
t.Fatalf("expected client hello metadata")
|
||||||
|
}
|
||||||
|
if res.ClientHello.ServerName != "" {
|
||||||
|
t.Fatalf("expected empty server name for truncated hello, got %q", res.ClientHello.ServerName)
|
||||||
|
}
|
||||||
|
if res.ClientHello.LocalAddr == nil || res.ClientHello.RemoteAddr == nil {
|
||||||
|
t.Fatalf("expected local/remote addr in metadata")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(res.Buffer.Bytes(), truncated) {
|
||||||
|
t.Fatalf("buffer mismatch: got %d bytes want %d", res.Buffer.Len(), len(truncated))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSSnifferRespectsMaxBytesWhileKeepingTLSClassification(t *testing.T) {
|
||||||
|
rawHello := captureClientHelloBytes(t, "example.com", []string{"h2", "http/1.1"})
|
||||||
|
limit := 5
|
||||||
|
|
||||||
|
conn := &bytesConn{
|
||||||
|
reader: bytes.NewReader(rawHello),
|
||||||
|
local: fixedAddr("127.0.0.1:443"),
|
||||||
|
remote: fixedAddr("127.0.0.1:50000"),
|
||||||
|
}
|
||||||
|
res, err := (TLSSniffer{}).Sniff(conn, limit)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sniff: %v", err)
|
||||||
|
}
|
||||||
|
if !res.IsTLS {
|
||||||
|
t.Fatalf("expected TLS classification with header-sized limit")
|
||||||
|
}
|
||||||
|
if res.ClientHello == nil {
|
||||||
|
t.Fatalf("expected client hello metadata")
|
||||||
|
}
|
||||||
|
if res.ClientHello.ServerName != "" {
|
||||||
|
t.Fatalf("expected empty server name with header-sized limit, got %q", res.ClientHello.ServerName)
|
||||||
|
}
|
||||||
|
if res.Buffer.Len() != limit {
|
||||||
|
t.Fatalf("buffer length mismatch: got %d want %d", res.Buffer.Len(), limit)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(res.Buffer.Bytes(), rawHello[:limit]) {
|
||||||
|
t.Fatalf("buffer content mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSSnifferRejectsFakeTLSClientRecord(t *testing.T) {
|
||||||
|
fake := []byte{
|
||||||
|
0x16, 0x03, 0x03, 0x00, 0x04,
|
||||||
|
'G', 'E', 'T', ' ',
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &bytesConn{
|
||||||
|
reader: bytes.NewReader(fake),
|
||||||
|
local: fixedAddr("127.0.0.1:443"),
|
||||||
|
remote: fixedAddr("127.0.0.1:50000"),
|
||||||
|
}
|
||||||
|
res, err := (TLSSniffer{}).Sniff(conn, 64*1024)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sniff: %v", err)
|
||||||
|
}
|
||||||
|
if res.IsTLS {
|
||||||
|
t.Fatalf("expected fake tls-looking record to be classified as non-tls")
|
||||||
|
}
|
||||||
|
if res.ClientHello != nil {
|
||||||
|
t.Fatalf("expected no client hello metadata for fake record")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(res.Buffer.Bytes(), fake) {
|
||||||
|
t.Fatalf("buffer mismatch")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ---------- cert helpers ----------
|
// ---------- cert helpers ----------
|
||||||
|
|
||||||
func genSelfSignedCertPEM(t *testing.T, dnsNames ...string) (certPEM, keyPEM []byte) {
|
func genSelfSignedCertPEM(t *testing.T, dnsNames ...string) (certPEM, keyPEM []byte) {
|
||||||
@ -377,6 +653,115 @@ func TestGetConfigForClientError(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetConfigForClientHelloReceivesMetadata(t *testing.T) {
|
||||||
|
certA := genSelfSignedCert(t, "a.local")
|
||||||
|
certB := genSelfSignedCert(t, "b.local")
|
||||||
|
|
||||||
|
base := TLSDefaults()
|
||||||
|
base.Certificates = []tls.Certificate{certA}
|
||||||
|
|
||||||
|
cfg := DefaultListenerConfig()
|
||||||
|
cfg.AllowNonTLS = false
|
||||||
|
cfg.BaseTLSConfig = base
|
||||||
|
|
||||||
|
metaCh := make(chan *ClientHelloMeta, 1)
|
||||||
|
cfg.GetConfigForClientHello = func(hello *ClientHelloMeta) (*tls.Config, error) {
|
||||||
|
metaCh <- hello.Clone()
|
||||||
|
if hello != nil && hello.ServerName == "b.local" {
|
||||||
|
b := TLSDefaults()
|
||||||
|
b.Certificates = []tls.Certificate{certB}
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, addr, cleanup := startEchoServer(t, cfg)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
tc, err := tls.Dial("tcp", addr, &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
ServerName: "b.local",
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tls dial: %v", err)
|
||||||
|
}
|
||||||
|
defer tc.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case hello := <-metaCh:
|
||||||
|
if hello == nil {
|
||||||
|
t.Fatalf("expected metadata")
|
||||||
|
}
|
||||||
|
if hello.ServerName != "b.local" {
|
||||||
|
t.Fatalf("unexpected server name: %q", hello.ServerName)
|
||||||
|
}
|
||||||
|
if hello.LocalAddr == nil {
|
||||||
|
t.Fatalf("expected LocalAddr")
|
||||||
|
}
|
||||||
|
if hello.RemoteAddr == nil {
|
||||||
|
t.Fatalf("expected RemoteAddr")
|
||||||
|
}
|
||||||
|
if len(hello.CipherSuites) == 0 {
|
||||||
|
t.Fatalf("expected cipher suites in metadata")
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatalf("timeout waiting client hello metadata")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetConfigForClientHelloWithoutSNIStillGetsLocalAddr(t *testing.T) {
|
||||||
|
cert := genSelfSignedCert(t, "localhost")
|
||||||
|
base := TLSDefaults()
|
||||||
|
base.Certificates = []tls.Certificate{cert}
|
||||||
|
|
||||||
|
cfg := DefaultListenerConfig()
|
||||||
|
cfg.AllowNonTLS = false
|
||||||
|
cfg.BaseTLSConfig = base
|
||||||
|
|
||||||
|
metaCh := make(chan *ClientHelloMeta, 1)
|
||||||
|
cfg.GetConfigForClientHello = func(hello *ClientHelloMeta) (*tls.Config, error) {
|
||||||
|
metaCh <- hello.Clone()
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, addr, cleanup := startEchoServer(t, cfg)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
raw, err := net.Dial("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial: %v", err)
|
||||||
|
}
|
||||||
|
defer raw.Close()
|
||||||
|
|
||||||
|
tc := tls.Client(raw, &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
})
|
||||||
|
if err := tc.Handshake(); err != nil {
|
||||||
|
t.Fatalf("tls handshake: %v", err)
|
||||||
|
}
|
||||||
|
defer tc.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case hello := <-metaCh:
|
||||||
|
if hello == nil {
|
||||||
|
t.Fatalf("expected metadata")
|
||||||
|
}
|
||||||
|
if hello.ServerName != "" {
|
||||||
|
t.Fatalf("expected empty SNI, got %q", hello.ServerName)
|
||||||
|
}
|
||||||
|
if hello.LocalAddr == nil {
|
||||||
|
t.Fatalf("expected LocalAddr")
|
||||||
|
}
|
||||||
|
if hello.RemoteAddr == nil {
|
||||||
|
t.Fatalf("expected RemoteAddr")
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatalf("timeout waiting client hello metadata")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestAcceptContextCancel(t *testing.T) {
|
func TestAcceptContextCancel(t *testing.T) {
|
||||||
cfg := DefaultListenerConfig()
|
cfg := DefaultListenerConfig()
|
||||||
cfg.AllowNonTLS = true
|
cfg.AllowNonTLS = true
|
||||||
@ -418,6 +803,140 @@ func TestListenerStats(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestComposeServerTLSConfigInheritsBaseDefaults(t *testing.T) {
|
||||||
|
base := TLSDefaults()
|
||||||
|
base.MinVersion = tls.VersionTLS13
|
||||||
|
base.NextProtos = []string{"h2", "http/1.1"}
|
||||||
|
base.ClientAuth = tls.RequireAndVerifyClientCert
|
||||||
|
base.DynamicRecordSizingDisabled = true
|
||||||
|
|
||||||
|
selected := TLSDefaults()
|
||||||
|
selected.Certificates = []tls.Certificate{genSelfSignedCert(t, "b.local")}
|
||||||
|
|
||||||
|
composed := composeServerTLSConfig(base, selected)
|
||||||
|
if composed == nil {
|
||||||
|
t.Fatalf("expected composed config")
|
||||||
|
}
|
||||||
|
if composed.MinVersion != tls.VersionTLS13 {
|
||||||
|
t.Fatalf("expected min version from base, got %v", composed.MinVersion)
|
||||||
|
}
|
||||||
|
if len(composed.NextProtos) != 2 {
|
||||||
|
t.Fatalf("expected next protos from base, got %v", composed.NextProtos)
|
||||||
|
}
|
||||||
|
if composed.ClientAuth != tls.RequireAndVerifyClientCert {
|
||||||
|
t.Fatalf("expected client auth from base, got %v", composed.ClientAuth)
|
||||||
|
}
|
||||||
|
if !composed.DynamicRecordSizingDisabled {
|
||||||
|
t.Fatalf("expected dynamic record sizing disabled to inherit")
|
||||||
|
}
|
||||||
|
if len(composed.Certificates) != 1 {
|
||||||
|
t.Fatalf("expected selected certificates to be preserved")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestComposeServerTLSConfigKeepsSelectedOverrides(t *testing.T) {
|
||||||
|
base := TLSDefaults()
|
||||||
|
base.MinVersion = tls.VersionTLS12
|
||||||
|
base.NextProtos = []string{"http/1.1"}
|
||||||
|
|
||||||
|
selected := TLSDefaults()
|
||||||
|
selected.MinVersion = tls.VersionTLS13
|
||||||
|
selected.NextProtos = []string{"h2"}
|
||||||
|
|
||||||
|
composed := composeServerTLSConfig(base, selected)
|
||||||
|
if composed.MinVersion != tls.VersionTLS13 {
|
||||||
|
t.Fatalf("expected selected min version, got %v", composed.MinVersion)
|
||||||
|
}
|
||||||
|
if len(composed.NextProtos) != 1 || composed.NextProtos[0] != "h2" {
|
||||||
|
t.Fatalf("expected selected next protos, got %v", composed.NextProtos)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnInitFailureStats(t *testing.T) {
|
||||||
|
t.Run("sniff failure", func(t *testing.T) {
|
||||||
|
client, server := net.Pipe()
|
||||||
|
defer client.Close()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
stats := &Stats{}
|
||||||
|
cfg := DefaultListenerConfig()
|
||||||
|
cfg.BaseTLSConfig = TLSDefaults()
|
||||||
|
conn := newConn(server, cfg, stats)
|
||||||
|
conn.sniffer = staticSniffer{err: io.ErrUnexpectedEOF}
|
||||||
|
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
_, err := conn.Read(buf)
|
||||||
|
if err == nil || !errors.Is(err, ErrTLSSniffFailed) {
|
||||||
|
t.Fatalf("expected tls sniff failure, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
snap := stats.Snapshot()
|
||||||
|
if snap.InitFailures != 1 || snap.SniffFailures != 1 {
|
||||||
|
t.Fatalf("unexpected sniff failure stats: %+v", snap)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("tls config selection failure", func(t *testing.T) {
|
||||||
|
client, server := net.Pipe()
|
||||||
|
defer client.Close()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
stats := &Stats{}
|
||||||
|
cfg := DefaultListenerConfig()
|
||||||
|
cfg.GetConfigForClientHello = func(*ClientHelloMeta) (*tls.Config, error) {
|
||||||
|
return nil, errors.New("boom")
|
||||||
|
}
|
||||||
|
conn := newConn(server, cfg, stats)
|
||||||
|
conn.sniffer = staticSniffer{
|
||||||
|
result: SniffResult{
|
||||||
|
IsTLS: true,
|
||||||
|
ClientHello: &ClientHelloMeta{},
|
||||||
|
Buffer: bytes.NewBuffer(nil),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
_, err := conn.Read(buf)
|
||||||
|
if err == nil || !errors.Is(err, ErrTLSConfigSelectionFailed) {
|
||||||
|
t.Fatalf("expected tls config selection failure, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
snap := stats.Snapshot()
|
||||||
|
if snap.InitFailures != 1 || snap.TLSConfigFailures != 1 {
|
||||||
|
t.Fatalf("unexpected tls config failure stats: %+v", snap)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("plain rejected", func(t *testing.T) {
|
||||||
|
client, server := net.Pipe()
|
||||||
|
defer client.Close()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
stats := &Stats{}
|
||||||
|
cfg := DefaultListenerConfig()
|
||||||
|
cfg.BaseTLSConfig = TLSDefaults()
|
||||||
|
cfg.AllowNonTLS = false
|
||||||
|
conn := newConn(server, cfg, stats)
|
||||||
|
conn.sniffer = staticSniffer{
|
||||||
|
result: SniffResult{
|
||||||
|
IsTLS: false,
|
||||||
|
Buffer: bytes.NewBuffer(nil),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
_, err := conn.Read(buf)
|
||||||
|
if !errors.Is(err, ErrNonTLSNotAllowed) {
|
||||||
|
t.Fatalf("expected plain rejection, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
snap := stats.Snapshot()
|
||||||
|
if snap.InitFailures != 1 || snap.PlainRejected != 1 {
|
||||||
|
t.Fatalf("unexpected plain rejection stats: %+v", snap)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestDialAndDialWithConfig(t *testing.T) {
|
func TestDialAndDialWithConfig(t *testing.T) {
|
||||||
nl, err := net.Listen("tcp", "127.0.0.1:0")
|
nl, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
12
tlsstats.go
12
tlsstats.go
@ -8,6 +8,9 @@ type StatsSnapshot struct {
|
|||||||
TLSDetected uint64
|
TLSDetected uint64
|
||||||
PlainDetected uint64
|
PlainDetected uint64
|
||||||
InitFailures uint64
|
InitFailures uint64
|
||||||
|
SniffFailures uint64
|
||||||
|
TLSConfigFailures uint64
|
||||||
|
PlainRejected uint64
|
||||||
Closed uint64
|
Closed uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -17,6 +20,9 @@ type Stats struct {
|
|||||||
tlsDetected uint64
|
tlsDetected uint64
|
||||||
plainDetected uint64
|
plainDetected uint64
|
||||||
initFailures uint64
|
initFailures uint64
|
||||||
|
sniffFailures uint64
|
||||||
|
tlsConfigFailures uint64
|
||||||
|
plainRejected uint64
|
||||||
closed uint64
|
closed uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -25,6 +31,9 @@ func (s *Stats) incTLSDetected() { atomic.AddUint64(&s.tlsDetected, 1) }
|
|||||||
func (s *Stats) incPlainDetected() { atomic.AddUint64(&s.plainDetected, 1) }
|
func (s *Stats) incPlainDetected() { atomic.AddUint64(&s.plainDetected, 1) }
|
||||||
func (s *Stats) incInitFailures() { atomic.AddUint64(&s.initFailures, 1) }
|
func (s *Stats) incInitFailures() { atomic.AddUint64(&s.initFailures, 1) }
|
||||||
func (s *Stats) incClosed() { atomic.AddUint64(&s.closed, 1) }
|
func (s *Stats) incClosed() { atomic.AddUint64(&s.closed, 1) }
|
||||||
|
func (s *Stats) incSniffFailures() { atomic.AddUint64(&s.sniffFailures, 1); s.incInitFailures() }
|
||||||
|
func (s *Stats) incTLSConfigFailures() { atomic.AddUint64(&s.tlsConfigFailures, 1); s.incInitFailures() }
|
||||||
|
func (s *Stats) incPlainRejected() { atomic.AddUint64(&s.plainRejected, 1); s.incInitFailures() }
|
||||||
|
|
||||||
// Snapshot returns a stable view of counters.
|
// Snapshot returns a stable view of counters.
|
||||||
func (s *Stats) Snapshot() StatsSnapshot {
|
func (s *Stats) Snapshot() StatsSnapshot {
|
||||||
@ -33,6 +42,9 @@ func (s *Stats) Snapshot() StatsSnapshot {
|
|||||||
TLSDetected: atomic.LoadUint64(&s.tlsDetected),
|
TLSDetected: atomic.LoadUint64(&s.tlsDetected),
|
||||||
PlainDetected: atomic.LoadUint64(&s.plainDetected),
|
PlainDetected: atomic.LoadUint64(&s.plainDetected),
|
||||||
InitFailures: atomic.LoadUint64(&s.initFailures),
|
InitFailures: atomic.LoadUint64(&s.initFailures),
|
||||||
|
SniffFailures: atomic.LoadUint64(&s.sniffFailures),
|
||||||
|
TLSConfigFailures: atomic.LoadUint64(&s.tlsConfigFailures),
|
||||||
|
PlainRejected: atomic.LoadUint64(&s.plainRejected),
|
||||||
Closed: atomic.LoadUint64(&s.closed),
|
Closed: atomic.LoadUint64(&s.closed),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user