diff --git a/errors.go b/errors.go index 582d51b..7707e7b 100644 --- a/errors.go +++ b/errors.go @@ -68,6 +68,12 @@ var ( // ErrNilConn indicates a nil net.Conn argument. ErrNilConn = errors.New("starnet: nil connection") + // ErrTLSSniffFailed indicates TLS sniffing/parsing failed before handshake setup. + ErrTLSSniffFailed = errors.New("starnet: tls sniff failed") + + // ErrTLSConfigSelectionFailed indicates dynamic TLS config selection failed. + ErrTLSConfigSelectionFailed = errors.New("starnet: tls config selection failed") + // ErrNonTLSNotAllowed indicates plain TCP was detected while non-TLS is forbidden. ErrNonTLSNotAllowed = errors.New("starnet: non-TLS connection not allowed") @@ -179,7 +185,8 @@ func IsTLS(err error) bool { if err == nil { return false } - if errors.Is(err, ErrNotTLS) || errors.Is(err, ErrNoTLSConfig) || errors.Is(err, ErrNonTLSNotAllowed) { + if errors.Is(err, ErrNotTLS) || errors.Is(err, ErrNoTLSConfig) || errors.Is(err, ErrNonTLSNotAllowed) || + errors.Is(err, ErrTLSSniffFailed) || errors.Is(err, ErrTLSConfigSelectionFailed) { return true } diff --git a/tlsconfig.go b/tlsconfig.go index 5d18e09..05a37d3 100644 --- a/tlsconfig.go +++ b/tlsconfig.go @@ -9,14 +9,49 @@ import ( // GetConfigForClientFunc selects TLS config by hostname/SNI. type GetConfigForClientFunc func(hostname string) (*tls.Config, error) +// ClientHelloMeta carries sniffed TLS metadata and connection context. +type ClientHelloMeta struct { + ServerName string + LocalAddr net.Addr + RemoteAddr net.Addr + SupportedProtos []string + SupportedVersions []uint16 + CipherSuites []uint16 +} + +// Clone returns a detached copy safe for callers to mutate. +func (m *ClientHelloMeta) Clone() *ClientHelloMeta { + if m == nil { + return nil + } + out := *m + if m.SupportedProtos != nil { + out.SupportedProtos = append([]string(nil), m.SupportedProtos...) + } + if m.SupportedVersions != nil { + out.SupportedVersions = append([]uint16(nil), m.SupportedVersions...) + } + if m.CipherSuites != nil { + out.CipherSuites = append([]uint16(nil), m.CipherSuites...) + } + return &out +} + +// GetConfigForClientHelloFunc selects TLS config by sniffed TLS metadata. +type GetConfigForClientHelloFunc func(hello *ClientHelloMeta) (*tls.Config, error) + // ListenerConfig controls listener behavior. type ListenerConfig struct { // BaseTLSConfig is used for TLS when dynamic selection returns nil. BaseTLSConfig *tls.Config - // GetConfigForClient selects TLS config for a hostname. + // GetConfigForClient selects TLS config for a hostname/SNI. + // Deprecated: prefer GetConfigForClientHello for richer context. GetConfigForClient GetConfigForClientFunc + // GetConfigForClientHello selects TLS config for sniffed TLS metadata. + GetConfigForClientHello GetConfigForClientHelloFunc + // AllowNonTLS allows plain TCP fallback. AllowNonTLS bool diff --git a/tlssniffer.go b/tlssniffer.go index 16497e5..acc58c7 100644 --- a/tlssniffer.go +++ b/tlssniffer.go @@ -4,6 +4,8 @@ import ( "bytes" "context" "crypto/tls" + "encoding/binary" + "errors" "io" "net" "sync" @@ -34,9 +36,9 @@ func (c *replayConn) SetWriteDeadline(t time.Time) error { return c.conn.SetWrit // SniffResult describes protocol sniffing result. type SniffResult struct { - IsTLS bool - Hostname string - Buffer *bytes.Buffer + IsTLS bool + ClientHello *ClientHelloMeta + Buffer *bytes.Buffer } // Sniffer detects protocol and metadata from initial bytes. @@ -55,44 +57,210 @@ func (s TLSSniffer) Sniff(conn net.Conn, maxBytes int) (SniffResult, error) { var buf bytes.Buffer limited := &io.LimitedReader{R: conn, N: int64(maxBytes)} - tee := io.TeeReader(limited, &buf) - - var hello *tls.ClientHelloInfo - _ = tls.Server(readOnlyConn{r: tee, raw: conn}, &tls.Config{ - GetConfigForClient: func(ch *tls.ClientHelloInfo) (*tls.Config, error) { - cp := *ch - hello = &cp - return nil, nil - }, - }).Handshake() - - peek := buf.Bytes() - isTLS := len(peek) >= 3 && peek[0] == 0x16 && peek[1] == 0x03 + meta, isTLS := sniffClientHello(limited, &buf, conn) out := SniffResult{ IsTLS: isTLS, - Buffer: bytes.NewBuffer(append([]byte(nil), peek...)), + Buffer: bytes.NewBuffer(append([]byte(nil), buf.Bytes()...)), } - if hello != nil { - out.Hostname = hello.ServerName + if isTLS { + out.ClientHello = meta } return out, nil } -// readOnlyConn rejects writes/close and reads from a reader. -type readOnlyConn struct { - r io.Reader - raw net.Conn +func sniffClientHello(r io.Reader, buf *bytes.Buffer, conn net.Conn) (*ClientHelloMeta, bool) { + meta := &ClientHelloMeta{ + LocalAddr: conn.LocalAddr(), + RemoteAddr: conn.RemoteAddr(), + } + + header, complete := readTLSRecordHeader(r, buf) + if len(header) < 3 { + return nil, false + } + isTLS := header[0] == 0x16 && header[1] == 0x03 + if !isTLS { + return nil, false + } + if len(header) < 5 || !complete { + return meta, true + } + + recordLen := int(binary.BigEndian.Uint16(header[3:5])) + recordBody, bodyOK := readBufferedBytes(r, buf, recordLen) + if !bodyOK { + return meta, true + } + if len(recordBody) < 4 || recordBody[0] != 0x01 { + return nil, false + } + + helloLen := int(recordBody[1])<<16 | int(recordBody[2])<<8 | int(recordBody[3]) + helloBytes := append([]byte(nil), recordBody[4:]...) + for len(helloBytes) < helloLen { + nextHeader, nextOK := readTLSRecordHeader(r, buf) + if len(nextHeader) < 5 || !nextOK { + return meta, true + } + if nextHeader[0] != 0x16 || nextHeader[1] != 0x03 { + return meta, true + } + nextLen := int(binary.BigEndian.Uint16(nextHeader[3:5])) + nextBody, nextBodyOK := readBufferedBytes(r, buf, nextLen) + if !nextBodyOK { + return meta, true + } + helloBytes = append(helloBytes, nextBody...) + } + + parseClientHelloBody(meta, helloBytes[:helloLen]) + return meta, true } -func (c readOnlyConn) Read(p []byte) (int, error) { return c.r.Read(p) } -func (c readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe } -func (c readOnlyConn) Close() error { return nil } -func (c readOnlyConn) LocalAddr() net.Addr { return c.raw.LocalAddr() } -func (c readOnlyConn) RemoteAddr() net.Addr { return c.raw.RemoteAddr() } -func (c readOnlyConn) SetDeadline(_ time.Time) error { return nil } -func (c readOnlyConn) SetReadDeadline(_ time.Time) error { return nil } -func (c readOnlyConn) SetWriteDeadline(_ time.Time) error { return nil } +func readTLSRecordHeader(r io.Reader, buf *bytes.Buffer) ([]byte, bool) { + return readBufferedBytes(r, buf, 5) +} + +func readBufferedBytes(r io.Reader, buf *bytes.Buffer, n int) ([]byte, bool) { + if n <= 0 { + return nil, true + } + tmp := make([]byte, n) + readN, err := io.ReadFull(r, tmp) + if readN > 0 { + buf.Write(tmp[:readN]) + } + return append([]byte(nil), tmp[:readN]...), err == nil +} + +func parseClientHelloBody(meta *ClientHelloMeta, body []byte) { + if meta == nil || len(body) < 34 { + return + } + + offset := 2 + 32 + sessionIDLen := int(body[offset]) + offset++ + if offset+sessionIDLen > len(body) { + return + } + offset += sessionIDLen + + if offset+2 > len(body) { + return + } + cipherSuitesLen := int(binary.BigEndian.Uint16(body[offset : offset+2])) + offset += 2 + if offset+cipherSuitesLen > len(body) { + return + } + for i := 0; i+1 < cipherSuitesLen; i += 2 { + meta.CipherSuites = append(meta.CipherSuites, binary.BigEndian.Uint16(body[offset+i:offset+i+2])) + } + offset += cipherSuitesLen + + if offset >= len(body) { + return + } + compressionMethodsLen := int(body[offset]) + offset++ + if offset+compressionMethodsLen > len(body) { + return + } + offset += compressionMethodsLen + + if offset+2 > len(body) { + return + } + extensionsLen := int(binary.BigEndian.Uint16(body[offset : offset+2])) + offset += 2 + if offset+extensionsLen > len(body) { + return + } + + parseClientHelloExtensions(meta, body[offset:offset+extensionsLen]) +} + +func parseClientHelloExtensions(meta *ClientHelloMeta, exts []byte) { + for offset := 0; offset+4 <= len(exts); { + extType := binary.BigEndian.Uint16(exts[offset : offset+2]) + extLen := int(binary.BigEndian.Uint16(exts[offset+2 : offset+4])) + offset += 4 + if offset+extLen > len(exts) { + return + } + extData := exts[offset : offset+extLen] + offset += extLen + + switch extType { + case 0: + parseServerNameExtension(meta, extData) + case 16: + parseALPNExtension(meta, extData) + case 43: + parseSupportedVersionsExtension(meta, extData) + } + } +} + +func parseServerNameExtension(meta *ClientHelloMeta, data []byte) { + if len(data) < 2 { + return + } + listLen := int(binary.BigEndian.Uint16(data[:2])) + if listLen == 0 || 2+listLen > len(data) { + return + } + list := data[2 : 2+listLen] + for offset := 0; offset+3 <= len(list); { + nameType := list[offset] + nameLen := int(binary.BigEndian.Uint16(list[offset+1 : offset+3])) + offset += 3 + if offset+nameLen > len(list) { + return + } + if nameType == 0 { + meta.ServerName = string(list[offset : offset+nameLen]) + return + } + offset += nameLen + } +} + +func parseALPNExtension(meta *ClientHelloMeta, data []byte) { + if len(data) < 2 { + return + } + listLen := int(binary.BigEndian.Uint16(data[:2])) + if listLen == 0 || 2+listLen > len(data) { + return + } + list := data[2 : 2+listLen] + for offset := 0; offset < len(list); { + nameLen := int(list[offset]) + offset++ + if offset+nameLen > len(list) { + return + } + meta.SupportedProtos = append(meta.SupportedProtos, string(list[offset:offset+nameLen])) + offset += nameLen + } +} + +func parseSupportedVersionsExtension(meta *ClientHelloMeta, data []byte) { + if len(data) < 1 { + return + } + listLen := int(data[0]) + if listLen == 0 || 1+listLen > len(data) { + return + } + list := data[1 : 1+listLen] + for offset := 0; offset+1 < len(list); offset += 2 { + meta.SupportedVersions = append(meta.SupportedVersions, binary.BigEndian.Uint16(list[offset:offset+2])) + } +} // Conn wraps net.Conn with lazy protocol initialization. type Conn struct { @@ -106,17 +274,18 @@ type Conn struct { tlsConn *tls.Conn plainConn net.Conn - hostname string + clientHello *ClientHelloMeta - baseTLSConfig *tls.Config - getConfigForClient GetConfigForClientFunc - allowNonTLS bool - sniffer Sniffer - sniffTimeout time.Duration - maxClientHello int - logger Logger - stats *Stats - skipSniff bool + baseTLSConfig *tls.Config + getConfigForClient GetConfigForClientFunc + getConfigForClientHello GetConfigForClientHelloFunc + allowNonTLS bool + sniffer Sniffer + sniffTimeout time.Duration + maxClientHello int + logger Logger + stats *Stats + skipSniff bool } func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn { @@ -125,6 +294,7 @@ func newConn(raw net.Conn, cfg ListenerConfig, stats *Stats) *Conn { plainConn: raw, baseTLSConfig: cfg.BaseTLSConfig, getConfigForClient: cfg.GetConfigForClient, + getConfigForClientHello: cfg.GetConfigForClientHello, allowNonTLS: cfg.AllowNonTLS, sniffer: TLSSniffer{}, sniffTimeout: cfg.SniffTimeout, @@ -139,7 +309,7 @@ func (c *Conn) init() { if c.skipSniff { return } - if c.baseTLSConfig == nil && c.getConfigForClient == nil { + if c.baseTLSConfig == nil && c.getConfigForClient == nil && c.getConfigForClientHello == nil { c.isTLS = false return } @@ -152,13 +322,13 @@ func (c *Conn) init() { _ = c.Conn.SetReadDeadline(time.Time{}) } if err != nil { - c.initErr = err - c.failAndClose("sniff failed: %v", err) + c.initErr = errors.Join(ErrTLSSniffFailed, err) + c.failSniff(err) return } c.isTLS = res.IsTLS - c.hostname = res.Hostname + c.clientHello = res.ClientHello if c.isTLS { if c.stats != nil { @@ -166,8 +336,8 @@ func (c *Conn) init() { } tlsCfg, errCfg := c.selectTLSConfig() if errCfg != nil { - c.initErr = errCfg - c.failAndClose("tls config select failed: %v", errCfg) + c.initErr = errors.Join(ErrTLSConfigSelectionFailed, errCfg) + c.failTLSConfigSelection(errCfg) return } rc := newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn) @@ -180,7 +350,7 @@ func (c *Conn) init() { } if !c.allowNonTLS { c.initErr = ErrNonTLSNotAllowed - c.failAndClose("plain tcp rejected") + c.failPlainRejected() return } c.plainConn = newReplayConn(bytes.NewBuffer(res.Buffer.Bytes()), c.Conn) @@ -188,27 +358,57 @@ func (c *Conn) init() { } func (c *Conn) failAndClose(format string, v ...interface{}) { - if c.stats != nil { - c.stats.incInitFailures() - } if c.logger != nil { c.logger.Printf("starnet: "+format, v...) } _ = c.Close() } +func (c *Conn) failSniff(err error) { + if c.stats != nil { + c.stats.incSniffFailures() + } + c.failAndClose("tls sniff failed: %v", err) +} + +func (c *Conn) failTLSConfigSelection(err error) { + if c.stats != nil { + c.stats.incTLSConfigFailures() + } + c.failAndClose("tls config selection failed: %v", err) +} + +func (c *Conn) failPlainRejected() { + if c.stats != nil { + c.stats.incPlainRejected() + } + c.failAndClose("plain tcp rejected") +} + func (c *Conn) selectTLSConfig() (*tls.Config, error) { - if c.getConfigForClient != nil { - cfg, err := c.getConfigForClient(c.hostname) + var selected *tls.Config + if c.getConfigForClientHello != nil { + cfg, err := c.getConfigForClientHello(c.clientHello.Clone()) if err != nil { return nil, err } if cfg != nil { - return cfg, nil + selected = cfg } } - if c.baseTLSConfig != nil { - return c.baseTLSConfig, nil + if selected == nil && c.getConfigForClient != nil { + cfg, err := c.getConfigForClient(c.serverName()) + if err != nil { + return nil, err + } + if cfg != nil { + selected = cfg + } + } + + composed := composeServerTLSConfig(c.baseTLSConfig, selected) + if composed != nil { + return composed, nil } return nil, ErrNoTLSConfig } @@ -216,7 +416,140 @@ func (c *Conn) selectTLSConfig() (*tls.Config, error) { // Hostname returns sniffed SNI hostname (if any). func (c *Conn) Hostname() string { c.init() - return c.hostname + return c.serverName() +} + +// ClientHello returns sniffed TLS metadata (if any). +func (c *Conn) ClientHello() *ClientHelloMeta { + c.init() + return c.clientHello.Clone() +} + +func (c *Conn) serverName() string { + if c.clientHello == nil { + return "" + } + return c.clientHello.ServerName +} + +func composeServerTLSConfig(base, selected *tls.Config) *tls.Config { + if base == nil { + return selected + } + if selected == nil { + return base + } + + out := base.Clone() + applyServerTLSOverrides(out, selected) + return out +} + +func applyServerTLSOverrides(dst, src *tls.Config) { + if dst == nil || src == nil { + return + } + + if src.Rand != nil { + dst.Rand = src.Rand + } + if src.Time != nil { + dst.Time = src.Time + } + if len(src.Certificates) > 0 { + dst.Certificates = append([]tls.Certificate(nil), src.Certificates...) + } + if len(src.NameToCertificate) > 0 { + m := make(map[string]*tls.Certificate, len(src.NameToCertificate)) + for k, v := range src.NameToCertificate { + m[k] = v + } + dst.NameToCertificate = m + } + if src.GetCertificate != nil { + dst.GetCertificate = src.GetCertificate + } + if src.GetClientCertificate != nil { + dst.GetClientCertificate = src.GetClientCertificate + } + if src.GetConfigForClient != nil { + dst.GetConfigForClient = src.GetConfigForClient + } + if src.VerifyPeerCertificate != nil { + dst.VerifyPeerCertificate = src.VerifyPeerCertificate + } + if src.VerifyConnection != nil { + dst.VerifyConnection = src.VerifyConnection + } + if src.RootCAs != nil { + dst.RootCAs = src.RootCAs + } + if len(src.NextProtos) > 0 { + dst.NextProtos = append([]string(nil), src.NextProtos...) + } + if src.ServerName != "" { + dst.ServerName = src.ServerName + } + if src.ClientAuth > dst.ClientAuth { + dst.ClientAuth = src.ClientAuth + } + if src.ClientCAs != nil { + dst.ClientCAs = src.ClientCAs + } + if src.InsecureSkipVerify { + dst.InsecureSkipVerify = true + } + if len(src.CipherSuites) > 0 { + dst.CipherSuites = append([]uint16(nil), src.CipherSuites...) + } + if src.PreferServerCipherSuites { + dst.PreferServerCipherSuites = true + } + if src.SessionTicketsDisabled { + dst.SessionTicketsDisabled = true + } + if src.SessionTicketKey != ([32]byte{}) { + dst.SessionTicketKey = src.SessionTicketKey + } + if src.ClientSessionCache != nil { + dst.ClientSessionCache = src.ClientSessionCache + } + if src.UnwrapSession != nil { + dst.UnwrapSession = src.UnwrapSession + } + if src.WrapSession != nil { + dst.WrapSession = src.WrapSession + } + if src.MinVersion != 0 && (dst.MinVersion == 0 || src.MinVersion > dst.MinVersion) { + dst.MinVersion = src.MinVersion + } + if src.MaxVersion != 0 && (dst.MaxVersion == 0 || src.MaxVersion < dst.MaxVersion) { + dst.MaxVersion = src.MaxVersion + } + if len(src.CurvePreferences) > 0 { + dst.CurvePreferences = append([]tls.CurveID(nil), src.CurvePreferences...) + } + if src.DynamicRecordSizingDisabled { + dst.DynamicRecordSizingDisabled = true + } + if src.Renegotiation != 0 { + dst.Renegotiation = src.Renegotiation + } + if src.KeyLogWriter != nil { + dst.KeyLogWriter = src.KeyLogWriter + } + if len(src.EncryptedClientHelloConfigList) > 0 { + dst.EncryptedClientHelloConfigList = append([]byte(nil), src.EncryptedClientHelloConfigList...) + } + if src.EncryptedClientHelloRejectionVerify != nil { + dst.EncryptedClientHelloRejectionVerify = src.EncryptedClientHelloRejectionVerify + } + if src.GetEncryptedClientHelloKeys != nil { + dst.GetEncryptedClientHelloKeys = src.GetEncryptedClientHelloKeys + } + if len(src.EncryptedClientHelloKeys) > 0 { + dst.EncryptedClientHelloKeys = append([]tls.EncryptedClientHelloKey(nil), src.EncryptedClientHelloKeys...) + } } func (c *Conn) IsTLS() bool { @@ -365,6 +698,7 @@ func normalizeConfig(cfg ListenerConfig) ListenerConfig { out.MaxClientHelloBytes = cfg.MaxClientHelloBytes out.BaseTLSConfig = cfg.BaseTLSConfig out.GetConfigForClient = cfg.GetConfigForClient + out.GetConfigForClientHello = cfg.GetConfigForClientHello out.Logger = cfg.Logger if out.MaxClientHelloBytes <= 0 { out.MaxClientHelloBytes = 64 * 1024 @@ -471,7 +805,6 @@ func DialTLSWithConfig(network, address string, tlsCfg *tls.Config, timeout time plainConn: raw, isTLS: true, tlsConn: tc, - hostname: "", initErr: nil, allowNonTLS: false, skipSniff: true, diff --git a/tlssniffer_test.go b/tlssniffer_test.go index 6b523fe..c70f81a 100644 --- a/tlssniffer_test.go +++ b/tlssniffer_test.go @@ -1,12 +1,14 @@ package starnet import ( + "bytes" "context" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" + "encoding/binary" "encoding/pem" "errors" "io" @@ -18,6 +20,280 @@ import ( "time" ) +type staticSniffer struct { + result SniffResult + err error +} + +type fixedAddr string + +func (a fixedAddr) Network() string { return "tcp" } +func (a fixedAddr) String() string { return string(a) } + +type recordingConn struct { + buf bytes.Buffer +} + +func (c *recordingConn) Read([]byte) (int, error) { return 0, io.EOF } +func (c *recordingConn) Write(p []byte) (int, error) { return c.buf.Write(p) } +func (c *recordingConn) Close() error { return nil } +func (c *recordingConn) LocalAddr() net.Addr { return fixedAddr("127.0.0.1:443") } +func (c *recordingConn) RemoteAddr() net.Addr { return fixedAddr("127.0.0.1:50000") } +func (c *recordingConn) SetDeadline(time.Time) error { return nil } +func (c *recordingConn) SetReadDeadline(time.Time) error { return nil } +func (c *recordingConn) SetWriteDeadline(time.Time) error { return nil } + +type bytesConn struct { + reader io.Reader + local net.Addr + remote net.Addr +} + +func (c *bytesConn) Read(p []byte) (int, error) { return c.reader.Read(p) } +func (c *bytesConn) Write(p []byte) (int, error) { return len(p), nil } +func (c *bytesConn) Close() error { return nil } +func (c *bytesConn) LocalAddr() net.Addr { return c.local } +func (c *bytesConn) RemoteAddr() net.Addr { return c.remote } +func (c *bytesConn) SetDeadline(time.Time) error { return nil } +func (c *bytesConn) SetReadDeadline(time.Time) error { return nil } +func (c *bytesConn) SetWriteDeadline(time.Time) error { return nil } + +func (s staticSniffer) Sniff(_ net.Conn, _ int) (SniffResult, error) { + return s.result, s.err +} + +func captureClientHelloBytes(t *testing.T, serverName string, nextProtos []string) []byte { + t.Helper() + + conn := &recordingConn{} + tc := tls.Client(conn, &tls.Config{ + InsecureSkipVerify: true, + ServerName: serverName, + NextProtos: nextProtos, + MinVersion: tls.VersionTLS12, + }) + _ = tc.Handshake() + _ = tc.Close() + + out := append([]byte(nil), conn.buf.Bytes()...) + if len(out) < 9 { + t.Fatalf("captured hello too short: %d", len(out)) + } + return out +} + +func prefixBytes(data []byte, n int) []byte { + if n < 0 { + n = 0 + } + if len(data) < n { + n = len(data) + } + return data[:n] +} + +func splitClientHelloRecord(t *testing.T, data []byte, splitBodyAt int) []byte { + t.Helper() + + if len(data) < 9 { + t.Fatalf("client hello too short: %d", len(data)) + } + if data[0] != 0x16 || data[1] != 0x03 { + t.Fatalf("unexpected tls record header: % x", prefixBytes(data, 5)) + } + recordLen := int(binary.BigEndian.Uint16(data[3:5])) + if len(data) < 5+recordLen { + t.Fatalf("short tls record: have=%d want=%d", len(data), 5+recordLen) + } + recordBody := append([]byte(nil), data[5:5+recordLen]...) + if len(recordBody) < 4 || recordBody[0] != 0x01 { + t.Fatalf("unexpected handshake record: % x", prefixBytes(recordBody, 4)) + } + if splitBodyAt <= 4 || splitBodyAt >= len(recordBody) { + t.Fatalf("invalid split point %d for body len %d", splitBodyAt, len(recordBody)) + } + + var out bytes.Buffer + writeRecord := func(body []byte) { + out.WriteByte(0x16) + out.WriteByte(data[1]) + out.WriteByte(data[2]) + header := []byte{0, 0} + binary.BigEndian.PutUint16(header, uint16(len(body))) + out.Write(header) + out.Write(body) + } + + writeRecord(recordBody[:splitBodyAt]) + writeRecord(recordBody[splitBodyAt:]) + return out.Bytes() +} + +func TestTLSSnifferParsesClientHello(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + done := make(chan error, 1) + go func() { + tc := tls.Client(client, &tls.Config{ + InsecureSkipVerify: true, + ServerName: "example.com", + NextProtos: []string{"h2", "http/1.1"}, + MinVersion: tls.VersionTLS12, + }) + defer tc.Close() + done <- tc.Handshake() + }() + + _ = server.SetReadDeadline(time.Now().Add(2 * time.Second)) + res, err := (TLSSniffer{}).Sniff(server, 64*1024) + if err != nil { + t.Fatalf("sniff: %v", err) + } + + if !res.IsTLS { + t.Fatalf("expected TLS") + } + if res.ClientHello == nil { + t.Fatalf("expected client hello metadata") + } + if res.ClientHello.ServerName != "example.com" { + t.Fatalf("unexpected server name: %q", res.ClientHello.ServerName) + } + if len(res.ClientHello.SupportedProtos) == 0 || res.ClientHello.SupportedProtos[0] != "h2" { + t.Fatalf("unexpected ALPN list: %v", res.ClientHello.SupportedProtos) + } + if len(res.ClientHello.SupportedVersions) == 0 { + t.Fatalf("expected supported versions") + } + if len(res.ClientHello.CipherSuites) == 0 { + t.Fatalf("expected cipher suites") + } + + _ = server.Close() + <-done +} + +func TestTLSSnifferParsesFragmentedClientHelloRecords(t *testing.T) { + rawHello := captureClientHelloBytes(t, "example.com", []string{"h2", "http/1.1"}) + fragmented := splitClientHelloRecord(t, rawHello, 32) + + conn := &bytesConn{ + reader: bytes.NewReader(fragmented), + local: fixedAddr("127.0.0.1:443"), + remote: fixedAddr("127.0.0.1:50000"), + } + res, err := (TLSSniffer{}).Sniff(conn, 64*1024) + if err != nil { + t.Fatalf("sniff: %v", err) + } + if !res.IsTLS { + t.Fatalf("expected TLS") + } + if res.ClientHello == nil { + t.Fatalf("expected client hello metadata") + } + if res.ClientHello.ServerName != "example.com" { + t.Fatalf("unexpected server name: %q", res.ClientHello.ServerName) + } + if len(res.ClientHello.SupportedProtos) == 0 || res.ClientHello.SupportedProtos[0] != "h2" { + t.Fatalf("unexpected ALPN list: %v", res.ClientHello.SupportedProtos) + } + if len(res.ClientHello.SupportedVersions) == 0 { + t.Fatalf("expected supported versions") + } + if len(res.ClientHello.CipherSuites) == 0 { + t.Fatalf("expected cipher suites") + } +} + +func TestTLSSnifferKeepsTruncatedTLSAsTLS(t *testing.T) { + rawHello := captureClientHelloBytes(t, "example.com", []string{"h2", "http/1.1"}) + truncated := append([]byte(nil), rawHello[:20]...) + + conn := &bytesConn{ + reader: bytes.NewReader(truncated), + local: fixedAddr("127.0.0.1:443"), + remote: fixedAddr("127.0.0.1:50000"), + } + res, err := (TLSSniffer{}).Sniff(conn, 64*1024) + if err != nil { + t.Fatalf("sniff: %v", err) + } + if !res.IsTLS { + t.Fatalf("expected truncated hello to stay classified as TLS") + } + if res.ClientHello == nil { + t.Fatalf("expected client hello metadata") + } + if res.ClientHello.ServerName != "" { + t.Fatalf("expected empty server name for truncated hello, got %q", res.ClientHello.ServerName) + } + if res.ClientHello.LocalAddr == nil || res.ClientHello.RemoteAddr == nil { + t.Fatalf("expected local/remote addr in metadata") + } + if !bytes.Equal(res.Buffer.Bytes(), truncated) { + t.Fatalf("buffer mismatch: got %d bytes want %d", res.Buffer.Len(), len(truncated)) + } +} + +func TestTLSSnifferRespectsMaxBytesWhileKeepingTLSClassification(t *testing.T) { + rawHello := captureClientHelloBytes(t, "example.com", []string{"h2", "http/1.1"}) + limit := 5 + + conn := &bytesConn{ + reader: bytes.NewReader(rawHello), + local: fixedAddr("127.0.0.1:443"), + remote: fixedAddr("127.0.0.1:50000"), + } + res, err := (TLSSniffer{}).Sniff(conn, limit) + if err != nil { + t.Fatalf("sniff: %v", err) + } + if !res.IsTLS { + t.Fatalf("expected TLS classification with header-sized limit") + } + if res.ClientHello == nil { + t.Fatalf("expected client hello metadata") + } + if res.ClientHello.ServerName != "" { + t.Fatalf("expected empty server name with header-sized limit, got %q", res.ClientHello.ServerName) + } + if res.Buffer.Len() != limit { + t.Fatalf("buffer length mismatch: got %d want %d", res.Buffer.Len(), limit) + } + if !bytes.Equal(res.Buffer.Bytes(), rawHello[:limit]) { + t.Fatalf("buffer content mismatch") + } +} + +func TestTLSSnifferRejectsFakeTLSClientRecord(t *testing.T) { + fake := []byte{ + 0x16, 0x03, 0x03, 0x00, 0x04, + 'G', 'E', 'T', ' ', + } + + conn := &bytesConn{ + reader: bytes.NewReader(fake), + local: fixedAddr("127.0.0.1:443"), + remote: fixedAddr("127.0.0.1:50000"), + } + res, err := (TLSSniffer{}).Sniff(conn, 64*1024) + if err != nil { + t.Fatalf("sniff: %v", err) + } + if res.IsTLS { + t.Fatalf("expected fake tls-looking record to be classified as non-tls") + } + if res.ClientHello != nil { + t.Fatalf("expected no client hello metadata for fake record") + } + if !bytes.Equal(res.Buffer.Bytes(), fake) { + t.Fatalf("buffer mismatch") + } +} + // ---------- cert helpers ---------- func genSelfSignedCertPEM(t *testing.T, dnsNames ...string) (certPEM, keyPEM []byte) { @@ -377,6 +653,115 @@ func TestGetConfigForClientError(t *testing.T) { } } +func TestGetConfigForClientHelloReceivesMetadata(t *testing.T) { + certA := genSelfSignedCert(t, "a.local") + certB := genSelfSignedCert(t, "b.local") + + base := TLSDefaults() + base.Certificates = []tls.Certificate{certA} + + cfg := DefaultListenerConfig() + cfg.AllowNonTLS = false + cfg.BaseTLSConfig = base + + metaCh := make(chan *ClientHelloMeta, 1) + cfg.GetConfigForClientHello = func(hello *ClientHelloMeta) (*tls.Config, error) { + metaCh <- hello.Clone() + if hello != nil && hello.ServerName == "b.local" { + b := TLSDefaults() + b.Certificates = []tls.Certificate{certB} + return b, nil + } + return nil, nil + } + + _, addr, cleanup := startEchoServer(t, cfg) + defer cleanup() + + tc, err := tls.Dial("tcp", addr, &tls.Config{ + InsecureSkipVerify: true, + ServerName: "b.local", + MinVersion: tls.VersionTLS12, + }) + if err != nil { + t.Fatalf("tls dial: %v", err) + } + defer tc.Close() + + select { + case hello := <-metaCh: + if hello == nil { + t.Fatalf("expected metadata") + } + if hello.ServerName != "b.local" { + t.Fatalf("unexpected server name: %q", hello.ServerName) + } + if hello.LocalAddr == nil { + t.Fatalf("expected LocalAddr") + } + if hello.RemoteAddr == nil { + t.Fatalf("expected RemoteAddr") + } + if len(hello.CipherSuites) == 0 { + t.Fatalf("expected cipher suites in metadata") + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting client hello metadata") + } +} + +func TestGetConfigForClientHelloWithoutSNIStillGetsLocalAddr(t *testing.T) { + cert := genSelfSignedCert(t, "localhost") + base := TLSDefaults() + base.Certificates = []tls.Certificate{cert} + + cfg := DefaultListenerConfig() + cfg.AllowNonTLS = false + cfg.BaseTLSConfig = base + + metaCh := make(chan *ClientHelloMeta, 1) + cfg.GetConfigForClientHello = func(hello *ClientHelloMeta) (*tls.Config, error) { + metaCh <- hello.Clone() + return nil, nil + } + + _, addr, cleanup := startEchoServer(t, cfg) + defer cleanup() + + raw, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer raw.Close() + + tc := tls.Client(raw, &tls.Config{ + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS12, + }) + if err := tc.Handshake(); err != nil { + t.Fatalf("tls handshake: %v", err) + } + defer tc.Close() + + select { + case hello := <-metaCh: + if hello == nil { + t.Fatalf("expected metadata") + } + if hello.ServerName != "" { + t.Fatalf("expected empty SNI, got %q", hello.ServerName) + } + if hello.LocalAddr == nil { + t.Fatalf("expected LocalAddr") + } + if hello.RemoteAddr == nil { + t.Fatalf("expected RemoteAddr") + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting client hello metadata") + } +} + func TestAcceptContextCancel(t *testing.T) { cfg := DefaultListenerConfig() cfg.AllowNonTLS = true @@ -418,6 +803,140 @@ func TestListenerStats(t *testing.T) { } } +func TestComposeServerTLSConfigInheritsBaseDefaults(t *testing.T) { + base := TLSDefaults() + base.MinVersion = tls.VersionTLS13 + base.NextProtos = []string{"h2", "http/1.1"} + base.ClientAuth = tls.RequireAndVerifyClientCert + base.DynamicRecordSizingDisabled = true + + selected := TLSDefaults() + selected.Certificates = []tls.Certificate{genSelfSignedCert(t, "b.local")} + + composed := composeServerTLSConfig(base, selected) + if composed == nil { + t.Fatalf("expected composed config") + } + if composed.MinVersion != tls.VersionTLS13 { + t.Fatalf("expected min version from base, got %v", composed.MinVersion) + } + if len(composed.NextProtos) != 2 { + t.Fatalf("expected next protos from base, got %v", composed.NextProtos) + } + if composed.ClientAuth != tls.RequireAndVerifyClientCert { + t.Fatalf("expected client auth from base, got %v", composed.ClientAuth) + } + if !composed.DynamicRecordSizingDisabled { + t.Fatalf("expected dynamic record sizing disabled to inherit") + } + if len(composed.Certificates) != 1 { + t.Fatalf("expected selected certificates to be preserved") + } +} + +func TestComposeServerTLSConfigKeepsSelectedOverrides(t *testing.T) { + base := TLSDefaults() + base.MinVersion = tls.VersionTLS12 + base.NextProtos = []string{"http/1.1"} + + selected := TLSDefaults() + selected.MinVersion = tls.VersionTLS13 + selected.NextProtos = []string{"h2"} + + composed := composeServerTLSConfig(base, selected) + if composed.MinVersion != tls.VersionTLS13 { + t.Fatalf("expected selected min version, got %v", composed.MinVersion) + } + if len(composed.NextProtos) != 1 || composed.NextProtos[0] != "h2" { + t.Fatalf("expected selected next protos, got %v", composed.NextProtos) + } +} + +func TestConnInitFailureStats(t *testing.T) { + t.Run("sniff failure", func(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + stats := &Stats{} + cfg := DefaultListenerConfig() + cfg.BaseTLSConfig = TLSDefaults() + conn := newConn(server, cfg, stats) + conn.sniffer = staticSniffer{err: io.ErrUnexpectedEOF} + + buf := make([]byte, 1) + _, err := conn.Read(buf) + if err == nil || !errors.Is(err, ErrTLSSniffFailed) { + t.Fatalf("expected tls sniff failure, got %v", err) + } + + snap := stats.Snapshot() + if snap.InitFailures != 1 || snap.SniffFailures != 1 { + t.Fatalf("unexpected sniff failure stats: %+v", snap) + } + }) + + t.Run("tls config selection failure", func(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + stats := &Stats{} + cfg := DefaultListenerConfig() + cfg.GetConfigForClientHello = func(*ClientHelloMeta) (*tls.Config, error) { + return nil, errors.New("boom") + } + conn := newConn(server, cfg, stats) + conn.sniffer = staticSniffer{ + result: SniffResult{ + IsTLS: true, + ClientHello: &ClientHelloMeta{}, + Buffer: bytes.NewBuffer(nil), + }, + } + + buf := make([]byte, 1) + _, err := conn.Read(buf) + if err == nil || !errors.Is(err, ErrTLSConfigSelectionFailed) { + t.Fatalf("expected tls config selection failure, got %v", err) + } + + snap := stats.Snapshot() + if snap.InitFailures != 1 || snap.TLSConfigFailures != 1 { + t.Fatalf("unexpected tls config failure stats: %+v", snap) + } + }) + + t.Run("plain rejected", func(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + stats := &Stats{} + cfg := DefaultListenerConfig() + cfg.BaseTLSConfig = TLSDefaults() + cfg.AllowNonTLS = false + conn := newConn(server, cfg, stats) + conn.sniffer = staticSniffer{ + result: SniffResult{ + IsTLS: false, + Buffer: bytes.NewBuffer(nil), + }, + } + + buf := make([]byte, 1) + _, err := conn.Read(buf) + if !errors.Is(err, ErrNonTLSNotAllowed) { + t.Fatalf("expected plain rejection, got %v", err) + } + + snap := stats.Snapshot() + if snap.InitFailures != 1 || snap.PlainRejected != 1 { + t.Fatalf("unexpected plain rejection stats: %+v", snap) + } + }) +} + func TestDialAndDialWithConfig(t *testing.T) { nl, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { diff --git a/tlsstats.go b/tlsstats.go index 01ad51e..e0a65cc 100644 --- a/tlsstats.go +++ b/tlsstats.go @@ -4,36 +4,48 @@ import "sync/atomic" // StatsSnapshot is a read-only copy of runtime counters. type StatsSnapshot struct { - Accepted uint64 - TLSDetected uint64 - PlainDetected uint64 - InitFailures uint64 - Closed uint64 + Accepted uint64 + TLSDetected uint64 + PlainDetected uint64 + InitFailures uint64 + SniffFailures uint64 + TLSConfigFailures uint64 + PlainRejected uint64 + Closed uint64 } // Stats provides lock-free counters. type Stats struct { - accepted uint64 - tlsDetected uint64 - plainDetected uint64 - initFailures uint64 - closed uint64 + accepted uint64 + tlsDetected uint64 + plainDetected uint64 + initFailures uint64 + sniffFailures uint64 + tlsConfigFailures uint64 + plainRejected uint64 + closed uint64 } -func (s *Stats) incAccepted() { atomic.AddUint64(&s.accepted, 1) } -func (s *Stats) incTLSDetected() { atomic.AddUint64(&s.tlsDetected, 1) } -func (s *Stats) incPlainDetected() { atomic.AddUint64(&s.plainDetected, 1) } -func (s *Stats) incInitFailures() { atomic.AddUint64(&s.initFailures, 1) } -func (s *Stats) incClosed() { atomic.AddUint64(&s.closed, 1) } +func (s *Stats) incAccepted() { atomic.AddUint64(&s.accepted, 1) } +func (s *Stats) incTLSDetected() { atomic.AddUint64(&s.tlsDetected, 1) } +func (s *Stats) incPlainDetected() { atomic.AddUint64(&s.plainDetected, 1) } +func (s *Stats) incInitFailures() { atomic.AddUint64(&s.initFailures, 1) } +func (s *Stats) incClosed() { atomic.AddUint64(&s.closed, 1) } +func (s *Stats) incSniffFailures() { atomic.AddUint64(&s.sniffFailures, 1); s.incInitFailures() } +func (s *Stats) incTLSConfigFailures() { atomic.AddUint64(&s.tlsConfigFailures, 1); s.incInitFailures() } +func (s *Stats) incPlainRejected() { atomic.AddUint64(&s.plainRejected, 1); s.incInitFailures() } // Snapshot returns a stable view of counters. func (s *Stats) Snapshot() StatsSnapshot { return StatsSnapshot{ - Accepted: atomic.LoadUint64(&s.accepted), - TLSDetected: atomic.LoadUint64(&s.tlsDetected), - PlainDetected: atomic.LoadUint64(&s.plainDetected), - InitFailures: atomic.LoadUint64(&s.initFailures), - Closed: atomic.LoadUint64(&s.closed), + Accepted: atomic.LoadUint64(&s.accepted), + TLSDetected: atomic.LoadUint64(&s.tlsDetected), + PlainDetected: atomic.LoadUint64(&s.plainDetected), + InitFailures: atomic.LoadUint64(&s.initFailures), + SniffFailures: atomic.LoadUint64(&s.sniffFailures), + TLSConfigFailures: atomic.LoadUint64(&s.tlsConfigFailures), + PlainRejected: atomic.LoadUint64(&s.plainRejected), + Closed: atomic.LoadUint64(&s.closed), } }