package starnet import ( "bytes" "context" "crypto/rand" "crypto/rsa" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/binary" "encoding/pem" "errors" "io" "math/big" "net" "os" "sync" "testing" "time" ) type staticSniffer struct { result SniffResult err error } type fixedAddr string func (a fixedAddr) Network() string { return "tcp" } func (a fixedAddr) String() string { return string(a) } type recordingConn struct { buf bytes.Buffer } func (c *recordingConn) Read([]byte) (int, error) { return 0, io.EOF } func (c *recordingConn) Write(p []byte) (int, error) { return c.buf.Write(p) } func (c *recordingConn) Close() error { return nil } func (c *recordingConn) LocalAddr() net.Addr { return fixedAddr("127.0.0.1:443") } func (c *recordingConn) RemoteAddr() net.Addr { return fixedAddr("127.0.0.1:50000") } func (c *recordingConn) SetDeadline(time.Time) error { return nil } func (c *recordingConn) SetReadDeadline(time.Time) error { return nil } func (c *recordingConn) SetWriteDeadline(time.Time) error { return nil } type bytesConn struct { reader io.Reader local net.Addr remote net.Addr } func (c *bytesConn) Read(p []byte) (int, error) { return c.reader.Read(p) } func (c *bytesConn) Write(p []byte) (int, error) { return len(p), nil } func (c *bytesConn) Close() error { return nil } func (c *bytesConn) LocalAddr() net.Addr { return c.local } func (c *bytesConn) RemoteAddr() net.Addr { return c.remote } func (c *bytesConn) SetDeadline(time.Time) error { return nil } func (c *bytesConn) SetReadDeadline(time.Time) error { return nil } func (c *bytesConn) SetWriteDeadline(time.Time) error { return nil } func (s staticSniffer) Sniff(_ net.Conn, _ int) (SniffResult, error) { return s.result, s.err } func captureClientHelloBytes(t *testing.T, serverName string, nextProtos []string) []byte { t.Helper() conn := &recordingConn{} tc := tls.Client(conn, &tls.Config{ InsecureSkipVerify: true, ServerName: serverName, NextProtos: nextProtos, MinVersion: tls.VersionTLS12, }) _ = tc.Handshake() _ = tc.Close() out := append([]byte(nil), conn.buf.Bytes()...) if len(out) < 9 { t.Fatalf("captured hello too short: %d", len(out)) } return out } func prefixBytes(data []byte, n int) []byte { if n < 0 { n = 0 } if len(data) < n { n = len(data) } return data[:n] } func splitClientHelloRecord(t *testing.T, data []byte, splitBodyAt int) []byte { t.Helper() if len(data) < 9 { t.Fatalf("client hello too short: %d", len(data)) } if data[0] != 0x16 || data[1] != 0x03 { t.Fatalf("unexpected tls record header: % x", prefixBytes(data, 5)) } recordLen := int(binary.BigEndian.Uint16(data[3:5])) if len(data) < 5+recordLen { t.Fatalf("short tls record: have=%d want=%d", len(data), 5+recordLen) } recordBody := append([]byte(nil), data[5:5+recordLen]...) if len(recordBody) < 4 || recordBody[0] != 0x01 { t.Fatalf("unexpected handshake record: % x", prefixBytes(recordBody, 4)) } if splitBodyAt <= 4 || splitBodyAt >= len(recordBody) { t.Fatalf("invalid split point %d for body len %d", splitBodyAt, len(recordBody)) } var out bytes.Buffer writeRecord := func(body []byte) { out.WriteByte(0x16) out.WriteByte(data[1]) out.WriteByte(data[2]) header := []byte{0, 0} binary.BigEndian.PutUint16(header, uint16(len(body))) out.Write(header) out.Write(body) } writeRecord(recordBody[:splitBodyAt]) writeRecord(recordBody[splitBodyAt:]) return out.Bytes() } func TestTLSSnifferParsesClientHello(t *testing.T) { client, server := net.Pipe() defer client.Close() defer server.Close() done := make(chan error, 1) go func() { tc := tls.Client(client, &tls.Config{ InsecureSkipVerify: true, ServerName: "example.com", NextProtos: []string{"h2", "http/1.1"}, MinVersion: tls.VersionTLS12, }) defer tc.Close() done <- tc.Handshake() }() _ = server.SetReadDeadline(time.Now().Add(2 * time.Second)) res, err := (TLSSniffer{}).Sniff(server, 64*1024) if err != nil { t.Fatalf("sniff: %v", err) } if !res.IsTLS { t.Fatalf("expected TLS") } if res.ClientHello == nil { t.Fatalf("expected client hello metadata") } if res.ClientHello.ServerName != "example.com" { t.Fatalf("unexpected server name: %q", res.ClientHello.ServerName) } if len(res.ClientHello.SupportedProtos) == 0 || res.ClientHello.SupportedProtos[0] != "h2" { t.Fatalf("unexpected ALPN list: %v", res.ClientHello.SupportedProtos) } if len(res.ClientHello.SupportedVersions) == 0 { t.Fatalf("expected supported versions") } if len(res.ClientHello.CipherSuites) == 0 { t.Fatalf("expected cipher suites") } _ = server.Close() <-done } func TestTLSSnifferParsesFragmentedClientHelloRecords(t *testing.T) { rawHello := captureClientHelloBytes(t, "example.com", []string{"h2", "http/1.1"}) fragmented := splitClientHelloRecord(t, rawHello, 32) conn := &bytesConn{ reader: bytes.NewReader(fragmented), local: fixedAddr("127.0.0.1:443"), remote: fixedAddr("127.0.0.1:50000"), } res, err := (TLSSniffer{}).Sniff(conn, 64*1024) if err != nil { t.Fatalf("sniff: %v", err) } if !res.IsTLS { t.Fatalf("expected TLS") } if res.ClientHello == nil { t.Fatalf("expected client hello metadata") } if res.ClientHello.ServerName != "example.com" { t.Fatalf("unexpected server name: %q", res.ClientHello.ServerName) } if len(res.ClientHello.SupportedProtos) == 0 || res.ClientHello.SupportedProtos[0] != "h2" { t.Fatalf("unexpected ALPN list: %v", res.ClientHello.SupportedProtos) } if len(res.ClientHello.SupportedVersions) == 0 { t.Fatalf("expected supported versions") } if len(res.ClientHello.CipherSuites) == 0 { t.Fatalf("expected cipher suites") } } func TestTLSSnifferKeepsTruncatedTLSAsTLS(t *testing.T) { rawHello := captureClientHelloBytes(t, "example.com", []string{"h2", "http/1.1"}) truncated := append([]byte(nil), rawHello[:20]...) conn := &bytesConn{ reader: bytes.NewReader(truncated), local: fixedAddr("127.0.0.1:443"), remote: fixedAddr("127.0.0.1:50000"), } res, err := (TLSSniffer{}).Sniff(conn, 64*1024) if err != nil { t.Fatalf("sniff: %v", err) } if !res.IsTLS { t.Fatalf("expected truncated hello to stay classified as TLS") } if res.ClientHello == nil { t.Fatalf("expected client hello metadata") } if res.ClientHello.ServerName != "" { t.Fatalf("expected empty server name for truncated hello, got %q", res.ClientHello.ServerName) } if res.ClientHello.LocalAddr == nil || res.ClientHello.RemoteAddr == nil { t.Fatalf("expected local/remote addr in metadata") } if !bytes.Equal(res.Buffer.Bytes(), truncated) { t.Fatalf("buffer mismatch: got %d bytes want %d", res.Buffer.Len(), len(truncated)) } } func TestTLSSnifferRespectsMaxBytesWhileKeepingTLSClassification(t *testing.T) { rawHello := captureClientHelloBytes(t, "example.com", []string{"h2", "http/1.1"}) limit := 5 conn := &bytesConn{ reader: bytes.NewReader(rawHello), local: fixedAddr("127.0.0.1:443"), remote: fixedAddr("127.0.0.1:50000"), } res, err := (TLSSniffer{}).Sniff(conn, limit) if err != nil { t.Fatalf("sniff: %v", err) } if !res.IsTLS { t.Fatalf("expected TLS classification with header-sized limit") } if res.ClientHello == nil { t.Fatalf("expected client hello metadata") } if res.ClientHello.ServerName != "" { t.Fatalf("expected empty server name with header-sized limit, got %q", res.ClientHello.ServerName) } if res.Buffer.Len() != limit { t.Fatalf("buffer length mismatch: got %d want %d", res.Buffer.Len(), limit) } if !bytes.Equal(res.Buffer.Bytes(), rawHello[:limit]) { t.Fatalf("buffer content mismatch") } } func TestTLSSnifferRejectsFakeTLSClientRecord(t *testing.T) { fake := []byte{ 0x16, 0x03, 0x03, 0x00, 0x04, 'G', 'E', 'T', ' ', } conn := &bytesConn{ reader: bytes.NewReader(fake), local: fixedAddr("127.0.0.1:443"), remote: fixedAddr("127.0.0.1:50000"), } res, err := (TLSSniffer{}).Sniff(conn, 64*1024) if err != nil { t.Fatalf("sniff: %v", err) } if res.IsTLS { t.Fatalf("expected fake tls-looking record to be classified as non-tls") } if res.ClientHello != nil { t.Fatalf("expected no client hello metadata for fake record") } if !bytes.Equal(res.Buffer.Bytes(), fake) { t.Fatalf("buffer mismatch") } } // ---------- cert helpers ---------- func genSelfSignedCertPEM(t *testing.T, dnsNames ...string) (certPEM, keyPEM []byte) { t.Helper() priv, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { t.Fatalf("GenerateKey: %v", err) } serial, err := rand.Int(rand.Reader, big.NewInt(1<<62)) if err != nil { t.Fatalf("serial: %v", err) } tpl := &x509.Certificate{ SerialNumber: serial, Subject: pkix.Name{ CommonName: "starnet-test", }, NotBefore: time.Now().Add(-time.Hour), NotAfter: time.Now().Add(24 * time.Hour), KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, DNSNames: dnsNames, } der, err := x509.CreateCertificate(rand.Reader, tpl, tpl, &priv.PublicKey, priv) if err != nil { t.Fatalf("CreateCertificate: %v", err) } certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) return certPEM, keyPEM } func genSelfSignedCert(t *testing.T, dnsNames ...string) tls.Certificate { t.Helper() certPEM, keyPEM := genSelfSignedCertPEM(t, dnsNames...) cert, err := tls.X509KeyPair(certPEM, keyPEM) if err != nil { t.Fatalf("X509KeyPair: %v", err) } return cert } func writeTempCertFiles(t *testing.T, dnsNames ...string) (certFile, keyFile string, cleanup func()) { t.Helper() certPEM, keyPEM := genSelfSignedCertPEM(t, dnsNames...) cf, err := os.CreateTemp("", "starnet-cert-*.pem") if err != nil { t.Fatalf("CreateTemp cert: %v", err) } kf, err := os.CreateTemp("", "starnet-key-*.pem") if err != nil { _ = cf.Close() _ = os.Remove(cf.Name()) t.Fatalf("CreateTemp key: %v", err) } if _, err := cf.Write(certPEM); err != nil { t.Fatalf("write cert: %v", err) } if _, err := kf.Write(keyPEM); err != nil { t.Fatalf("write key: %v", err) } _ = cf.Close() _ = kf.Close() return cf.Name(), kf.Name(), func() { _ = os.Remove(cf.Name()) _ = os.Remove(kf.Name()) } } // ---------- server helpers ---------- func startEchoServer(t *testing.T, cfg ListenerConfig) (*Listener, string, func()) { t.Helper() ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg) if err != nil { t.Fatalf("ListenWithConfig: %v", err) } var wg sync.WaitGroup stop := make(chan struct{}) wg.Add(1) go func() { defer wg.Done() for { c, err := ln.Accept() if err != nil { select { case <-stop: return default: return } } go func(conn net.Conn) { defer conn.Close() _, _ = io.Copy(conn, conn) }(c) } }() cleanup := func() { close(stop) _ = ln.Close() wg.Wait() } return ln, ln.Addr().String(), cleanup } // ---------- tests ---------- func TestListen(t *testing.T) { ln, err := Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("Listen: %v", err) } defer ln.Close() go func() { c, err := ln.Accept() if err != nil { return } defer c.Close() _, _ = io.Copy(c, c) }() c, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatalf("dial: %v", err) } defer c.Close() msg := []byte("x") if _, err := c.Write(msg); err != nil { t.Fatalf("write: %v", err) } buf := make([]byte, 1) if _, err := io.ReadFull(c, buf); err != nil { t.Fatalf("read: %v", err) } } func TestListenWithListenConfig(t *testing.T) { lc := net.ListenConfig{} cfg := DefaultListenerConfig() cfg.AllowNonTLS = true ln, err := ListenWithListenConfig(lc, "tcp", "127.0.0.1:0", cfg) if err != nil { t.Fatalf("ListenWithListenConfig: %v", err) } defer ln.Close() go func() { c, err := ln.Accept() if err != nil { return } defer c.Close() _, _ = io.Copy(c, c) }() c, err := net.Dial("tcp", ln.Addr().String()) if err != nil { t.Fatalf("dial: %v", err) } defer c.Close() msg := []byte("ok") if _, err := c.Write(msg); err != nil { t.Fatalf("write: %v", err) } buf := make([]byte, 2) if _, err := io.ReadFull(c, buf); err != nil { t.Fatalf("read: %v", err) } } func TestListenerSetConfig(t *testing.T) { cfg := DefaultListenerConfig() cfg.AllowNonTLS = true ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg) if err != nil { t.Fatalf("listen: %v", err) } defer ln.Close() cfg2 := cfg cfg2.SniffTimeout = time.Second ln.SetConfig(cfg2) got := ln.Config() if got.SniffTimeout != time.Second { t.Fatalf("SetConfig not applied") } } func TestPlainAllowed(t *testing.T) { cfg := DefaultListenerConfig() cfg.AllowNonTLS = true cfg.BaseTLSConfig = nil _, addr, cleanup := startEchoServer(t, cfg) defer cleanup() c, err := net.Dial("tcp", addr) if err != nil { t.Fatalf("Dial: %v", err) } defer c.Close() msg := []byte("hello-plain") if _, err := c.Write(msg); err != nil { t.Fatalf("write: %v", err) } buf := make([]byte, len(msg)) if _, err := io.ReadFull(c, buf); err != nil { t.Fatalf("read: %v", err) } if string(buf) != string(msg) { t.Fatalf("echo mismatch: got=%q want=%q", string(buf), string(msg)) } } func TestPlainRejectedWhenNonTLSDisabled(t *testing.T) { cert := genSelfSignedCert(t, "localhost") base := TLSDefaults() base.Certificates = []tls.Certificate{cert} cfg := DefaultListenerConfig() cfg.AllowNonTLS = false cfg.BaseTLSConfig = base _, addr, cleanup := startEchoServer(t, cfg) defer cleanup() c, err := net.Dial("tcp", addr) if err != nil { t.Fatalf("Dial: %v", err) } defer c.Close() _, _ = c.Write([]byte("plain")) _ = c.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) b := make([]byte, 1) _, err = c.Read(b) if err == nil { t.Fatalf("expected read error due to non-tls rejection") } } func TestTLSHandshakeAndEcho(t *testing.T) { cert := genSelfSignedCert(t, "localhost") base := TLSDefaults() base.Certificates = []tls.Certificate{cert} cfg := DefaultListenerConfig() cfg.AllowNonTLS = false cfg.BaseTLSConfig = base _, addr, cleanup := startEchoServer(t, cfg) defer cleanup() tc, err := tls.Dial("tcp", addr, &tls.Config{ InsecureSkipVerify: true, ServerName: "localhost", MinVersion: tls.VersionTLS12, }) if err != nil { t.Fatalf("tls dial: %v", err) } defer tc.Close() msg := []byte("hello-tls") if _, err := tc.Write(msg); err != nil { t.Fatalf("tls write: %v", err) } buf := make([]byte, len(msg)) if _, err := io.ReadFull(tc, buf); err != nil { t.Fatalf("tls read: %v", err) } if string(buf) != string(msg) { t.Fatalf("tls echo mismatch: got=%q want=%q", string(buf), string(msg)) } } func TestDynamicConfigBySNI(t *testing.T) { certA := genSelfSignedCert(t, "a.local") certB := genSelfSignedCert(t, "b.local") base := TLSDefaults() base.Certificates = []tls.Certificate{certA} cfg := DefaultListenerConfig() cfg.AllowNonTLS = false cfg.BaseTLSConfig = base cfg.GetConfigForClient = func(host string) (*tls.Config, error) { if host == "b.local" { b := TLSDefaults() b.Certificates = []tls.Certificate{certB} return b, nil } return nil, nil } _, addr, cleanup := startEchoServer(t, cfg) defer cleanup() tc, err := tls.Dial("tcp", addr, &tls.Config{ InsecureSkipVerify: true, ServerName: "b.local", MinVersion: tls.VersionTLS12, }) if err != nil { t.Fatalf("tls dial: %v", err) } defer tc.Close() if !tc.ConnectionState().HandshakeComplete { t.Fatalf("handshake not complete") } } func TestGetConfigForClientError(t *testing.T) { cert := genSelfSignedCert(t, "localhost") base := TLSDefaults() base.Certificates = []tls.Certificate{cert} cfg := DefaultListenerConfig() cfg.AllowNonTLS = false cfg.BaseTLSConfig = base cfg.GetConfigForClient = func(host string) (*tls.Config, error) { return nil, errors.New("boom") } _, addr, cleanup := startEchoServer(t, cfg) defer cleanup() _, err := tls.Dial("tcp", addr, &tls.Config{ InsecureSkipVerify: true, ServerName: "localhost", }) if err == nil { t.Fatalf("expected tls dial failure due to selector error") } } func TestGetConfigForClientHelloReceivesMetadata(t *testing.T) { certA := genSelfSignedCert(t, "a.local") certB := genSelfSignedCert(t, "b.local") base := TLSDefaults() base.Certificates = []tls.Certificate{certA} cfg := DefaultListenerConfig() cfg.AllowNonTLS = false cfg.BaseTLSConfig = base metaCh := make(chan *ClientHelloMeta, 1) cfg.GetConfigForClientHello = func(hello *ClientHelloMeta) (*tls.Config, error) { metaCh <- hello.Clone() if hello != nil && hello.ServerName == "b.local" { b := TLSDefaults() b.Certificates = []tls.Certificate{certB} return b, nil } return nil, nil } _, addr, cleanup := startEchoServer(t, cfg) defer cleanup() tc, err := tls.Dial("tcp", addr, &tls.Config{ InsecureSkipVerify: true, ServerName: "b.local", MinVersion: tls.VersionTLS12, }) if err != nil { t.Fatalf("tls dial: %v", err) } defer tc.Close() select { case hello := <-metaCh: if hello == nil { t.Fatalf("expected metadata") } if hello.ServerName != "b.local" { t.Fatalf("unexpected server name: %q", hello.ServerName) } if hello.LocalAddr == nil { t.Fatalf("expected LocalAddr") } if hello.RemoteAddr == nil { t.Fatalf("expected RemoteAddr") } if len(hello.CipherSuites) == 0 { t.Fatalf("expected cipher suites in metadata") } case <-time.After(2 * time.Second): t.Fatalf("timeout waiting client hello metadata") } } func TestGetConfigForClientHelloWithoutSNIStillGetsLocalAddr(t *testing.T) { cert := genSelfSignedCert(t, "localhost") base := TLSDefaults() base.Certificates = []tls.Certificate{cert} cfg := DefaultListenerConfig() cfg.AllowNonTLS = false cfg.BaseTLSConfig = base metaCh := make(chan *ClientHelloMeta, 1) cfg.GetConfigForClientHello = func(hello *ClientHelloMeta) (*tls.Config, error) { metaCh <- hello.Clone() return nil, nil } _, addr, cleanup := startEchoServer(t, cfg) defer cleanup() raw, err := net.Dial("tcp", addr) if err != nil { t.Fatalf("dial: %v", err) } defer raw.Close() tc := tls.Client(raw, &tls.Config{ InsecureSkipVerify: true, MinVersion: tls.VersionTLS12, }) if err := tc.Handshake(); err != nil { t.Fatalf("tls handshake: %v", err) } defer tc.Close() select { case hello := <-metaCh: if hello == nil { t.Fatalf("expected metadata") } if hello.ServerName != "" { t.Fatalf("expected empty SNI, got %q", hello.ServerName) } if hello.LocalAddr == nil { t.Fatalf("expected LocalAddr") } if hello.RemoteAddr == nil { t.Fatalf("expected RemoteAddr") } case <-time.After(2 * time.Second): t.Fatalf("timeout waiting client hello metadata") } } func TestAcceptContextCancel(t *testing.T) { cfg := DefaultListenerConfig() cfg.AllowNonTLS = true ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg) if err != nil { t.Fatalf("listen: %v", err) } defer ln.Close() ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() _, err = ln.AcceptContext(ctx) if err == nil { t.Fatalf("expected context timeout/cancel") } } func TestListenerStats(t *testing.T) { cfg := DefaultListenerConfig() cfg.AllowNonTLS = true ln, addr, cleanup := startEchoServer(t, cfg) defer cleanup() c, err := net.Dial("tcp", addr) if err != nil { t.Fatalf("dial: %v", err) } _, _ = c.Write([]byte("x")) _ = c.Close() time.Sleep(100 * time.Millisecond) s := ln.Stats() if s.Accepted == 0 { t.Fatalf("expected accepted > 0") } } func TestComposeServerTLSConfigInheritsBaseDefaults(t *testing.T) { base := TLSDefaults() base.MinVersion = tls.VersionTLS13 base.NextProtos = []string{"h2", "http/1.1"} base.ClientAuth = tls.RequireAndVerifyClientCert base.DynamicRecordSizingDisabled = true selected := TLSDefaults() selected.Certificates = []tls.Certificate{genSelfSignedCert(t, "b.local")} composed := composeServerTLSConfig(base, selected) if composed == nil { t.Fatalf("expected composed config") } if composed.MinVersion != tls.VersionTLS13 { t.Fatalf("expected min version from base, got %v", composed.MinVersion) } if len(composed.NextProtos) != 2 { t.Fatalf("expected next protos from base, got %v", composed.NextProtos) } if composed.ClientAuth != tls.RequireAndVerifyClientCert { t.Fatalf("expected client auth from base, got %v", composed.ClientAuth) } if !composed.DynamicRecordSizingDisabled { t.Fatalf("expected dynamic record sizing disabled to inherit") } if len(composed.Certificates) != 1 { t.Fatalf("expected selected certificates to be preserved") } } func TestComposeServerTLSConfigKeepsSelectedOverrides(t *testing.T) { base := TLSDefaults() base.MinVersion = tls.VersionTLS12 base.NextProtos = []string{"http/1.1"} selected := TLSDefaults() selected.MinVersion = tls.VersionTLS13 selected.NextProtos = []string{"h2"} composed := composeServerTLSConfig(base, selected) if composed.MinVersion != tls.VersionTLS13 { t.Fatalf("expected selected min version, got %v", composed.MinVersion) } if len(composed.NextProtos) != 1 || composed.NextProtos[0] != "h2" { t.Fatalf("expected selected next protos, got %v", composed.NextProtos) } } func TestConnInitFailureStats(t *testing.T) { t.Run("sniff failure", func(t *testing.T) { client, server := net.Pipe() defer client.Close() defer server.Close() stats := &Stats{} cfg := DefaultListenerConfig() cfg.BaseTLSConfig = TLSDefaults() conn := newConn(server, cfg, stats) conn.sniffer = staticSniffer{err: io.ErrUnexpectedEOF} buf := make([]byte, 1) _, err := conn.Read(buf) if err == nil || !errors.Is(err, ErrTLSSniffFailed) { t.Fatalf("expected tls sniff failure, got %v", err) } snap := stats.Snapshot() if snap.InitFailures != 1 || snap.SniffFailures != 1 { t.Fatalf("unexpected sniff failure stats: %+v", snap) } }) t.Run("tls config selection failure", func(t *testing.T) { client, server := net.Pipe() defer client.Close() defer server.Close() stats := &Stats{} cfg := DefaultListenerConfig() cfg.GetConfigForClientHello = func(*ClientHelloMeta) (*tls.Config, error) { return nil, errors.New("boom") } conn := newConn(server, cfg, stats) conn.sniffer = staticSniffer{ result: SniffResult{ IsTLS: true, ClientHello: &ClientHelloMeta{}, Buffer: bytes.NewBuffer(nil), }, } buf := make([]byte, 1) _, err := conn.Read(buf) if err == nil || !errors.Is(err, ErrTLSConfigSelectionFailed) { t.Fatalf("expected tls config selection failure, got %v", err) } snap := stats.Snapshot() if snap.InitFailures != 1 || snap.TLSConfigFailures != 1 { t.Fatalf("unexpected tls config failure stats: %+v", snap) } }) t.Run("plain rejected", func(t *testing.T) { client, server := net.Pipe() defer client.Close() defer server.Close() stats := &Stats{} cfg := DefaultListenerConfig() cfg.BaseTLSConfig = TLSDefaults() cfg.AllowNonTLS = false conn := newConn(server, cfg, stats) conn.sniffer = staticSniffer{ result: SniffResult{ IsTLS: false, Buffer: bytes.NewBuffer(nil), }, } buf := make([]byte, 1) _, err := conn.Read(buf) if !errors.Is(err, ErrNonTLSNotAllowed) { t.Fatalf("expected plain rejection, got %v", err) } snap := stats.Snapshot() if snap.InitFailures != 1 || snap.PlainRejected != 1 { t.Fatalf("unexpected plain rejection stats: %+v", snap) } }) } func TestDialAndDialWithConfig(t *testing.T) { nl, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } defer nl.Close() go func() { c, err := nl.Accept() if err != nil { return } defer c.Close() _, _ = io.Copy(c, c) }() c1, err := Dial("tcp", nl.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) } defer c1.Close() msg := []byte("abc") if _, err := c1.Write(msg); err != nil { t.Fatalf("c1 write: %v", err) } got := make([]byte, 3) if _, err := io.ReadFull(c1, got); err != nil { t.Fatalf("c1 read: %v", err) } c2, err := DialWithConfig("tcp", nl.Addr().String(), DialConfig{Timeout: time.Second}) if err != nil { t.Fatalf("DialWithConfig: %v", err) } defer c2.Close() } func TestListenTLS_FileAPI(t *testing.T) { certFile, keyFile, cleanupFiles := writeTempCertFiles(t, "localhost") defer cleanupFiles() ln, err := ListenTLS("tcp", "127.0.0.1:0", certFile, keyFile, false) if err != nil { t.Fatalf("ListenTLS: %v", err) } defer ln.Close() go func() { c, err := ln.Accept() if err != nil { return } defer c.Close() _, _ = io.Copy(c, c) }() tc, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{ InsecureSkipVerify: true, ServerName: "localhost", }) if err != nil { t.Fatalf("tls dial: %v", err) } defer tc.Close() msg := []byte("hi") if _, err := tc.Write(msg); err != nil { t.Fatalf("tls write: %v", err) } out := make([]byte, 2) if _, err := io.ReadFull(tc, out); err != nil { t.Fatalf("tls read: %v", err) } } func TestDialTLSWithConfig(t *testing.T) { cert := genSelfSignedCert(t, "localhost") base := TLSDefaults() base.Certificates = []tls.Certificate{cert} cfg := DefaultListenerConfig() cfg.BaseTLSConfig = base cfg.AllowNonTLS = false ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg) if err != nil { t.Fatalf("listen: %v", err) } defer ln.Close() go func() { c, err := ln.Accept() if err != nil { return } defer c.Close() _, _ = io.Copy(c, c) }() clientCfg := &tls.Config{ InsecureSkipVerify: true, ServerName: "localhost", } c, err := DialTLSWithConfig("tcp", ln.Addr().String(), clientCfg, time.Second) if err != nil { t.Fatalf("DialTLSWithConfig: %v", err) } defer c.Close() if !c.IsTLS() { t.Fatalf("expected IsTLS true") } } func TestDialTLS_FileAPI(t *testing.T) { cert := genSelfSignedCert(t, "localhost") base := TLSDefaults() base.Certificates = []tls.Certificate{cert} cfg := DefaultListenerConfig() cfg.BaseTLSConfig = base cfg.AllowNonTLS = false ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg) if err != nil { t.Fatalf("listen: %v", err) } defer ln.Close() go func() { c, err := ln.Accept() if err != nil { return } defer c.Close() _, _ = io.Copy(c, c) }() clientCertFile, clientKeyFile, cleanupFiles := writeTempCertFiles(t, "localhost") defer cleanupFiles() c, err := DialTLS("tcp", ln.Addr().String(), clientCertFile, clientKeyFile) if err != nil { t.Fatalf("DialTLS: %v", err) } defer c.Close() if !c.IsTLS() { t.Fatalf("expected IsTLS true") } } func TestConnIsTLS_PlainAndTLS(t *testing.T) { // ---- plain case ---- plainCfg := DefaultListenerConfig() plainCfg.AllowNonTLS = true ln1, err := ListenWithConfig("tcp", "127.0.0.1:0", plainCfg) if err != nil { t.Fatalf("listen1: %v", err) } defer ln1.Close() plainDone := make(chan *Conn, 1) plainErr := make(chan error, 1) go func() { nc, err := ln1.Accept() if err != nil { plainErr <- err return } sc, ok := nc.(*Conn) if !ok { _ = nc.Close() plainErr <- errors.New("accepted conn is not *Conn") return } plainDone <- sc // block until client sends one byte, then close buf := make([]byte, 1) _, _ = sc.Read(buf) _ = sc.Close() }() c1, err := net.Dial("tcp", ln1.Addr().String()) if err != nil { t.Fatalf("dial1: %v", err) } if _, err := c1.Write([]byte("p")); err != nil { _ = c1.Close() t.Fatalf("plain client write: %v", err) } _ = c1.Close() select { case err := <-plainErr: t.Fatalf("plain server error: %v", err) case sc1 := <-plainDone: if sc1.IsTLS() { t.Fatalf("plain conn should not be TLS") } case <-time.After(2 * time.Second): t.Fatalf("timeout waiting plain side") } // ---- tls case ---- cert := genSelfSignedCert(t, "localhost") tlsBase := TLSDefaults() tlsBase.Certificates = []tls.Certificate{cert} tlsCfg := DefaultListenerConfig() tlsCfg.BaseTLSConfig = tlsBase tlsCfg.AllowNonTLS = false ln2, err := ListenWithConfig("tcp", "127.0.0.1:0", tlsCfg) if err != nil { t.Fatalf("listen2: %v", err) } defer ln2.Close() tlsDone := make(chan *Conn, 1) tlsErr := make(chan error, 1) go func() { nc, err := ln2.Accept() if err != nil { tlsErr <- err return } sc, ok := nc.(*Conn) if !ok { _ = nc.Close() tlsErr <- errors.New("accepted conn is not *Conn") return } tlsDone <- sc // key point: wait for real data to ensure TLS handshake/path is executed buf := make([]byte, 1) _, _ = sc.Read(buf) _ = sc.Close() }() d := &net.Dialer{Timeout: 2 * time.Second} tc, err := tls.DialWithDialer(d, "tcp", ln2.Addr().String(), &tls.Config{ InsecureSkipVerify: true, // test only ServerName: "localhost", MinVersion: tls.VersionTLS12, }) if err != nil { t.Fatalf("tls dial: %v", err) } if _, err := tc.Write([]byte("t")); err != nil { _ = tc.Close() t.Fatalf("tls client write: %v", err) } _ = tc.Close() select { case err := <-tlsErr: t.Fatalf("tls server error: %v", err) case sc2 := <-tlsDone: if !sc2.IsTLS() { t.Fatalf("tls conn should be TLS") } case <-time.After(3 * time.Second): t.Fatalf("timeout waiting tls side") } }