starnet/tlssniffer_test.go

692 lines
14 KiB
Go
Raw Normal View History

2026-03-08 20:19:40 +08:00
package starnet
import (
"context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"io"
"math/big"
"net"
"os"
"sync"
"testing"
"time"
)
// ---------- cert helpers ----------
// genSelfSignedCertPEM creates a throwaway self-signed RSA-2048 certificate
// covering dnsNames. The certificate is valid from one hour ago until 24 hours
// from now and is marked for both server and client authentication.
//
// It returns the PEM-encoded "CERTIFICATE" block and the PKCS#1
// "RSA PRIVATE KEY" block. Any failure aborts the calling test.
func genSelfSignedCertPEM(t *testing.T, dnsNames ...string) (certPEM, keyPEM []byte) {
	t.Helper()

	key, genErr := rsa.GenerateKey(rand.Reader, 2048)
	if genErr != nil {
		t.Fatalf("GenerateKey: %v", genErr)
	}

	// Random serial in [0, 2^62) so separate calls yield distinct certificates.
	limit := big.NewInt(1 << 62)
	sn, snErr := rand.Int(rand.Reader, limit)
	if snErr != nil {
		t.Fatalf("serial: %v", snErr)
	}

	now := time.Now()
	template := x509.Certificate{
		SerialNumber: sn,
		Subject:      pkix.Name{CommonName: "starnet-test"},
		NotBefore:    now.Add(-time.Hour),
		NotAfter:     now.Add(24 * time.Hour),
		KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
		ExtKeyUsage: []x509.ExtKeyUsage{
			x509.ExtKeyUsageServerAuth,
			x509.ExtKeyUsageClientAuth,
		},
		DNSNames: dnsNames,
	}

	// Self-signed: the same template acts as both subject and issuer.
	derBytes, createErr := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key)
	if createErr != nil {
		t.Fatalf("CreateCertificate: %v", createErr)
	}

	certBlock := pem.Block{Type: "CERTIFICATE", Bytes: derBytes}
	keyBlock := pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}
	return pem.EncodeToMemory(&certBlock), pem.EncodeToMemory(&keyBlock)
}
// genSelfSignedCert wraps genSelfSignedCertPEM and returns the generated
// certificate/key pair already parsed into a tls.Certificate, ready for use
// in tls.Config.Certificates. Any failure aborts the calling test.
func genSelfSignedCert(t *testing.T, dnsNames ...string) tls.Certificate {
	t.Helper()
	pemCert, pemKey := genSelfSignedCertPEM(t, dnsNames...)
	pair, parseErr := tls.X509KeyPair(pemCert, pemKey)
	if parseErr != nil {
		t.Fatalf("X509KeyPair: %v", parseErr)
	}
	return pair
}
// writeTempCertFiles generates a self-signed certificate for dnsNames and
// writes the certificate and key PEM data to two files, returning their paths.
//
// The files are created inside t.TempDir(), so the testing package removes
// them when the test finishes, even if a t.Fatalf aborts the test before the
// caller's deferred cleanup runs. The returned cleanup removes the two files
// early. It is kept for callers that already defer it and is safe to call more
// than once.
func writeTempCertFiles(t *testing.T, dnsNames ...string) (certFile, keyFile string, cleanup func()) {
	t.Helper()
	certPEM, keyPEM := genSelfSignedCertPEM(t, dnsNames...)
	dir := t.TempDir()

	// writeFile creates a uniquely named file in dir, writes data and closes
	// it, failing the test on any error. The close error is checked because a
	// failed close can mean the data never reached the file.
	writeFile := func(label, pattern string, data []byte) string {
		t.Helper()
		f, err := os.CreateTemp(dir, pattern)
		if err != nil {
			t.Fatalf("CreateTemp %s: %v", label, err)
		}
		if _, err := f.Write(data); err != nil {
			_ = f.Close()
			t.Fatalf("write %s: %v", label, err)
		}
		if err := f.Close(); err != nil {
			t.Fatalf("close %s: %v", label, err)
		}
		return f.Name()
	}

	certFile = writeFile("cert", "starnet-cert-*.pem", certPEM)
	keyFile = writeFile("key", "starnet-key-*.pem", keyPEM)
	cleanup = func() {
		// Remove errors (e.g. file already gone) are irrelevant: t.TempDir's
		// own cleanup deletes whatever remains.
		_ = os.Remove(certFile)
		_ = os.Remove(keyFile)
	}
	return certFile, keyFile, cleanup
}
// ---------- server helpers ----------
// startEchoServer listens on an ephemeral loopback TCP port using cfg and
// runs an echo server on it: each accepted connection has whatever it reads
// written straight back.
//
// It returns the listener, its address, and a cleanup func. cleanup closes
// the listener, closes every connection still being served, and waits for the
// accept loop and all per-connection goroutines to exit, so none of them
// outlive the test. cleanup is safe to call more than once.
func startEchoServer(t *testing.T, cfg ListenerConfig) (*Listener, string, func()) {
	t.Helper()
	ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
	if err != nil {
		t.Fatalf("ListenWithConfig: %v", err)
	}

	var (
		wg     sync.WaitGroup
		mu     sync.Mutex
		closed bool
		active = make(map[net.Conn]struct{})
	)

	// track registers c as being served. If cleanup has already started it
	// closes c and reports false, so late accepts are never served.
	track := func(c net.Conn) bool {
		mu.Lock()
		defer mu.Unlock()
		if closed {
			_ = c.Close()
			return false
		}
		active[c] = struct{}{}
		// Add is safe here: the accept-loop goroutine still holds its own
		// count, and cleanup only calls Wait after setting closed under mu.
		wg.Add(1)
		return true
	}
	untrack := func(c net.Conn) {
		mu.Lock()
		delete(active, c)
		mu.Unlock()
	}

	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			c, err := ln.Accept()
			if err != nil {
				// Either cleanup closed the listener or Accept failed for
				// another reason; in both cases stop serving.
				return
			}
			if !track(c) {
				continue
			}
			go func(conn net.Conn) {
				defer wg.Done()
				defer untrack(conn)
				defer conn.Close()
				_, _ = io.Copy(conn, conn)
			}(c)
		}
	}()

	var once sync.Once
	cleanup := func() {
		once.Do(func() {
			_ = ln.Close()
			mu.Lock()
			closed = true
			// Closing the conns unblocks the io.Copy calls so wg.Wait returns.
			for c := range active {
				_ = c.Close()
			}
			mu.Unlock()
			wg.Wait()
		})
	}
	return ln, ln.Addr().String(), cleanup
}
// ---------- tests ----------
// TestListen checks that a listener created with Listen accepts a plain TCP
// connection and that a byte written by the client comes back unchanged
// through a one-shot echo handler.
func TestListen(t *testing.T) {
	ln, err := Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Listen: %v", err)
	}
	defer ln.Close()

	// Echo exactly one connection. The goroutine ends when the client closes,
	// or when ln.Close makes Accept fail.
	go func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		defer c.Close()
		_, _ = io.Copy(c, c)
	}()

	c, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	defer c.Close()
	// Bound the round trip so a broken listener fails the test instead of
	// hanging it until the global go test timeout.
	_ = c.SetDeadline(time.Now().Add(5 * time.Second))

	msg := []byte("x")
	if _, err := c.Write(msg); err != nil {
		t.Fatalf("write: %v", err)
	}
	buf := make([]byte, len(msg))
	if _, err := io.ReadFull(c, buf); err != nil {
		t.Fatalf("read: %v", err)
	}
	if string(buf) != string(msg) {
		t.Fatalf("echo mismatch: got=%q want=%q", string(buf), string(msg))
	}
}
// TestListenWithListenConfig checks that ListenWithListenConfig builds a
// working listener from a caller-supplied net.ListenConfig plus a
// ListenerConfig that allows plaintext. It does so by round-tripping a short
// message through a one-shot echo handler.
func TestListenWithListenConfig(t *testing.T) {
	lc := net.ListenConfig{}
	cfg := DefaultListenerConfig()
	cfg.AllowNonTLS = true
	ln, err := ListenWithListenConfig(lc, "tcp", "127.0.0.1:0", cfg)
	if err != nil {
		t.Fatalf("ListenWithListenConfig: %v", err)
	}
	defer ln.Close()

	// Echo exactly one connection.
	go func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		defer c.Close()
		_, _ = io.Copy(c, c)
	}()

	c, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	defer c.Close()
	// Fail instead of hanging if the listener never echoes.
	_ = c.SetDeadline(time.Now().Add(5 * time.Second))

	msg := []byte("ok")
	if _, err := c.Write(msg); err != nil {
		t.Fatalf("write: %v", err)
	}
	buf := make([]byte, len(msg))
	if _, err := io.ReadFull(c, buf); err != nil {
		t.Fatalf("read: %v", err)
	}
	if string(buf) != string(msg) {
		t.Fatalf("echo mismatch: got=%q want=%q", string(buf), string(msg))
	}
}
// TestListenerSetConfig verifies that a config passed to SetConfig on a live
// listener is what Config subsequently reports.
func TestListenerSetConfig(t *testing.T) {
	initial := DefaultListenerConfig()
	initial.AllowNonTLS = true

	ln, err := ListenWithConfig("tcp", "127.0.0.1:0", initial)
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	defer ln.Close()

	// Start from a (shallow) copy of the initial config and change only the
	// sniff timeout.
	updated := initial
	updated.SniffTimeout = time.Second
	ln.SetConfig(updated)

	if ln.Config().SniffTimeout != time.Second {
		t.Fatalf("SetConfig not applied")
	}
}
// TestPlainAllowed confirms that, with AllowNonTLS enabled and no base TLS
// config, a raw TCP client is served as plaintext and its payload is echoed
// back unchanged.
func TestPlainAllowed(t *testing.T) {
	lcfg := DefaultListenerConfig()
	lcfg.AllowNonTLS = true
	lcfg.BaseTLSConfig = nil

	_, addr, stop := startEchoServer(t, lcfg)
	defer stop()

	conn, dialErr := net.Dial("tcp", addr)
	if dialErr != nil {
		t.Fatalf("Dial: %v", dialErr)
	}
	defer conn.Close()

	const payload = "hello-plain"
	if _, err := conn.Write([]byte(payload)); err != nil {
		t.Fatalf("write: %v", err)
	}
	echoed := make([]byte, len(payload))
	if _, err := io.ReadFull(conn, echoed); err != nil {
		t.Fatalf("read: %v", err)
	}
	if got := string(echoed); got != payload {
		t.Fatalf("echo mismatch: got=%q want=%q", got, payload)
	}
}
// TestPlainRejectedWhenNonTLSDisabled checks that a listener requiring TLS
// refuses a plaintext client. The server must actively end the connection
// (EOF or reset) rather than echo the data or leave the client hanging.
func TestPlainRejectedWhenNonTLSDisabled(t *testing.T) {
	cert := genSelfSignedCert(t, "localhost")
	base := TLSDefaults()
	base.Certificates = []tls.Certificate{cert}
	cfg := DefaultListenerConfig()
	cfg.AllowNonTLS = false
	cfg.BaseTLSConfig = base
	_, addr, cleanup := startEchoServer(t, cfg)
	defer cleanup()

	c, err := net.Dial("tcp", addr)
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer c.Close()
	// The write may itself fail if the server has already reset the
	// connection. That still counts as a rejection, so the error is ignored.
	_, _ = c.Write([]byte("plain"))
	_ = c.SetReadDeadline(time.Now().Add(2 * time.Second))
	b := make([]byte, 1)
	n, err := c.Read(b)
	if err == nil {
		t.Fatalf("expected read error due to non-tls rejection, got %d byte(s)", n)
	}
	// A deadline timeout means the server neither echoed nor closed the
	// connection, i.e. it did not actively reject it. Without this check the
	// test would pass on a server that just hangs.
	var ne net.Error
	if errors.As(err, &ne) && ne.Timeout() {
		t.Fatalf("expected server to close the connection, got timeout: %v", err)
	}
}
// TestTLSHandshakeAndEcho performs a real TLS handshake against a TLS-only
// listener and verifies that application data round-trips through the echo
// server.
func TestTLSHandshakeAndEcho(t *testing.T) {
	serverCert := genSelfSignedCert(t, "localhost")
	tlsBase := TLSDefaults()
	tlsBase.Certificates = []tls.Certificate{serverCert}

	lcfg := DefaultListenerConfig()
	lcfg.AllowNonTLS = false
	lcfg.BaseTLSConfig = tlsBase

	_, addr, stop := startEchoServer(t, lcfg)
	defer stop()

	clientCfg := &tls.Config{
		InsecureSkipVerify: true, // self-signed test certificate
		ServerName:         "localhost",
		MinVersion:         tls.VersionTLS12,
	}
	conn, dialErr := tls.Dial("tcp", addr, clientCfg)
	if dialErr != nil {
		t.Fatalf("tls dial: %v", dialErr)
	}
	defer conn.Close()

	want := []byte("hello-tls")
	if _, err := conn.Write(want); err != nil {
		t.Fatalf("tls write: %v", err)
	}
	got := make([]byte, len(want))
	if _, err := io.ReadFull(conn, got); err != nil {
		t.Fatalf("tls read: %v", err)
	}
	if string(got) != string(want) {
		t.Fatalf("tls echo mismatch: got=%q want=%q", string(got), string(want))
	}
}
// TestDynamicConfigBySNI checks that GetConfigForClient can swap in a
// per-host TLS config. A client sending SNI "b.local" must receive certB
// rather than the base certificate certA.
func TestDynamicConfigBySNI(t *testing.T) {
	certA := genSelfSignedCert(t, "a.local")
	certB := genSelfSignedCert(t, "b.local")
	base := TLSDefaults()
	base.Certificates = []tls.Certificate{certA}
	cfg := DefaultListenerConfig()
	cfg.AllowNonTLS = false
	cfg.BaseTLSConfig = base
	// Return a config carrying certB for SNI "b.local"; (nil, nil) otherwise.
	cfg.GetConfigForClient = func(host string) (*tls.Config, error) {
		if host == "b.local" {
			b := TLSDefaults()
			b.Certificates = []tls.Certificate{certB}
			return b, nil
		}
		return nil, nil
	}
	_, addr, cleanup := startEchoServer(t, cfg)
	defer cleanup()

	tc, err := tls.Dial("tcp", addr, &tls.Config{
		InsecureSkipVerify: true,
		ServerName:         "b.local",
		MinVersion:         tls.VersionTLS12,
	})
	if err != nil {
		t.Fatalf("tls dial: %v", err)
	}
	defer tc.Close()

	state := tc.ConnectionState()
	if !state.HandshakeComplete {
		t.Fatalf("handshake not complete")
	}
	// A completed handshake alone does not prove the selector ran: because
	// the client skips verification, serving the base certA would also
	// succeed. Compare the presented leaf with certB's DER bytes.
	if len(state.PeerCertificates) == 0 {
		t.Fatalf("server presented no certificate")
	}
	leaf := state.PeerCertificates[0]
	if string(leaf.Raw) != string(certB.Certificate[0]) {
		t.Fatalf("server presented certificate for %v, want the b.local certificate", leaf.DNSNames)
	}
}
// TestGetConfigForClientError ensures that when the per-client config
// selector returns an error, the client's TLS handshake fails rather than
// completing with the base configuration.
func TestGetConfigForClientError(t *testing.T) {
	serverCert := genSelfSignedCert(t, "localhost")
	tlsBase := TLSDefaults()
	tlsBase.Certificates = []tls.Certificate{serverCert}

	lcfg := DefaultListenerConfig()
	lcfg.AllowNonTLS = false
	lcfg.BaseTLSConfig = tlsBase
	lcfg.GetConfigForClient = func(string) (*tls.Config, error) {
		return nil, errors.New("boom")
	}

	_, addr, stop := startEchoServer(t, lcfg)
	defer stop()

	_, dialErr := tls.Dial("tcp", addr, &tls.Config{
		InsecureSkipVerify: true,
		ServerName:         "localhost",
	})
	if dialErr == nil {
		t.Fatalf("expected tls dial failure due to selector error")
	}
}
// TestAcceptContextCancel verifies that AcceptContext gives up with an error
// once its context expires while no client ever connects.
func TestAcceptContextCancel(t *testing.T) {
	lcfg := DefaultListenerConfig()
	lcfg.AllowNonTLS = true

	ln, listenErr := ListenWithConfig("tcp", "127.0.0.1:0", lcfg)
	if listenErr != nil {
		t.Fatalf("listen: %v", listenErr)
	}
	defer ln.Close()

	// Nobody dials this listener, so AcceptContext is expected to return only
	// because the 100ms deadline fires.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	if _, acceptErr := ln.AcceptContext(ctx); acceptErr == nil {
		t.Fatalf("expected context timeout/cancel")
	}
}
// TestListenerStats checks that the listener's Stats().Accepted counter
// records a connection that the echo server accepted.
func TestListenerStats(t *testing.T) {
	cfg := DefaultListenerConfig()
	cfg.AllowNonTLS = true
	ln, addr, cleanup := startEchoServer(t, cfg)
	defer cleanup()

	c, err := net.Dial("tcp", addr)
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	if _, err := c.Write([]byte("x")); err != nil {
		_ = c.Close()
		t.Fatalf("write: %v", err)
	}
	_ = c.Close()

	// The accept happens asynchronously in the echo server's goroutine. Poll
	// the counter with an upper bound instead of sleeping a fixed 100ms, which
	// is flaky on slow or loaded machines and wasteful on fast ones.
	deadline := time.Now().Add(2 * time.Second)
	for ln.Stats().Accepted == 0 {
		if time.Now().After(deadline) {
			t.Fatalf("expected accepted > 0")
		}
		time.Sleep(10 * time.Millisecond)
	}
}
// TestDialAndDialWithConfig exercises both Dial and DialWithConfig against a
// plain TCP echo server. Each resulting connection must carry data in both
// directions.
func TestDialAndDialWithConfig(t *testing.T) {
	nl, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	defer nl.Close()

	// Serve every incoming connection. Accepting only one would leave the
	// second (DialWithConfig) connection unserved. The loop exits once
	// nl.Close makes Accept fail.
	go func() {
		for {
			c, err := nl.Accept()
			if err != nil {
				return
			}
			go func(c net.Conn) {
				defer c.Close()
				_, _ = io.Copy(c, c)
			}(c)
		}
	}()

	// roundTrip writes msg on c and requires the same bytes to come back
	// within a bounded time.
	roundTrip := func(name string, c net.Conn, msg string) {
		t.Helper()
		_ = c.SetDeadline(time.Now().Add(5 * time.Second))
		if _, err := c.Write([]byte(msg)); err != nil {
			t.Fatalf("%s write: %v", name, err)
		}
		got := make([]byte, len(msg))
		if _, err := io.ReadFull(c, got); err != nil {
			t.Fatalf("%s read: %v", name, err)
		}
		if string(got) != msg {
			t.Fatalf("%s echo mismatch: got=%q want=%q", name, string(got), msg)
		}
	}

	c1, err := Dial("tcp", nl.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer c1.Close()
	roundTrip("c1", c1, "abc")

	c2, err := DialWithConfig("tcp", nl.Addr().String(), DialConfig{Timeout: time.Second})
	if err != nil {
		t.Fatalf("DialWithConfig: %v", err)
	}
	defer c2.Close()
	roundTrip("c2", c2, "def")
}
// TestListenTLS_FileAPI covers ListenTLS, which takes the server certificate
// and key as file paths. A TLS client's data must be echoed back unchanged.
func TestListenTLS_FileAPI(t *testing.T) {
	certFile, keyFile, cleanupFiles := writeTempCertFiles(t, "localhost")
	defer cleanupFiles()
	// NOTE(review): the trailing false is presumably the "allow non-TLS"
	// switch; confirm against ListenTLS's signature.
	ln, err := ListenTLS("tcp", "127.0.0.1:0", certFile, keyFile, false)
	if err != nil {
		t.Fatalf("ListenTLS: %v", err)
	}
	defer ln.Close()

	// Echo exactly one connection.
	go func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		defer c.Close()
		_, _ = io.Copy(c, c)
	}()

	tc, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		ServerName:         "localhost",
	})
	if err != nil {
		t.Fatalf("tls dial: %v", err)
	}
	defer tc.Close()
	// Fail instead of hanging if the server never echoes.
	_ = tc.SetDeadline(time.Now().Add(5 * time.Second))

	msg := []byte("hi")
	if _, err := tc.Write(msg); err != nil {
		t.Fatalf("tls write: %v", err)
	}
	out := make([]byte, len(msg))
	if _, err := io.ReadFull(tc, out); err != nil {
		t.Fatalf("tls read: %v", err)
	}
	if string(out) != string(msg) {
		t.Fatalf("tls echo mismatch: got=%q want=%q", string(out), string(msg))
	}
}
// TestDialTLSWithConfig dials a TLS-only listener through DialTLSWithConfig
// and checks that the returned connection reports itself as TLS.
func TestDialTLSWithConfig(t *testing.T) {
	serverCert := genSelfSignedCert(t, "localhost")
	tlsBase := TLSDefaults()
	tlsBase.Certificates = []tls.Certificate{serverCert}

	lcfg := DefaultListenerConfig()
	lcfg.BaseTLSConfig = tlsBase
	lcfg.AllowNonTLS = false

	ln, listenErr := ListenWithConfig("tcp", "127.0.0.1:0", lcfg)
	if listenErr != nil {
		t.Fatalf("listen: %v", listenErr)
	}
	defer ln.Close()

	// One-shot echo peer so the dialing side has a live server to talk to.
	go func() {
		srv, acceptErr := ln.Accept()
		if acceptErr != nil {
			return
		}
		defer srv.Close()
		_, _ = io.Copy(srv, srv)
	}()

	conn, dialErr := DialTLSWithConfig("tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true,
		ServerName:         "localhost",
	}, time.Second)
	if dialErr != nil {
		t.Fatalf("DialTLSWithConfig: %v", dialErr)
	}
	defer conn.Close()

	if !conn.IsTLS() {
		t.Fatalf("expected IsTLS true")
	}
}
// TestDialTLS_FileAPI covers DialTLS, which takes a certificate and key as
// file paths, against a TLS-only listener. The resulting connection must
// report IsTLS.
func TestDialTLS_FileAPI(t *testing.T) {
	serverCert := genSelfSignedCert(t, "localhost")
	tlsBase := TLSDefaults()
	tlsBase.Certificates = []tls.Certificate{serverCert}

	lcfg := DefaultListenerConfig()
	lcfg.BaseTLSConfig = tlsBase
	lcfg.AllowNonTLS = false

	ln, listenErr := ListenWithConfig("tcp", "127.0.0.1:0", lcfg)
	if listenErr != nil {
		t.Fatalf("listen: %v", listenErr)
	}
	defer ln.Close()

	// One-shot echo peer for the client to connect to.
	go func() {
		srv, acceptErr := ln.Accept()
		if acceptErr != nil {
			return
		}
		defer srv.Close()
		_, _ = io.Copy(srv, srv)
	}()

	certPath, keyPath, removeFiles := writeTempCertFiles(t, "localhost")
	defer removeFiles()

	conn, dialErr := DialTLS("tcp", ln.Addr().String(), certPath, keyPath)
	if dialErr != nil {
		t.Fatalf("DialTLS: %v", dialErr)
	}
	defer conn.Close()

	if !conn.IsTLS() {
		t.Fatalf("expected IsTLS true")
	}
}
func TestConnIsTLS_PlainAndTLS(t *testing.T) {
// ---- plain case ----
plainCfg := DefaultListenerConfig()
plainCfg.AllowNonTLS = true
ln1, err := ListenWithConfig("tcp", "127.0.0.1:0", plainCfg)
if err != nil {
t.Fatalf("listen1: %v", err)
}
defer ln1.Close()
plainDone := make(chan *Conn, 1)
plainErr := make(chan error, 1)
go func() {
nc, err := ln1.Accept()
if err != nil {
plainErr <- err
return
}
sc, ok := nc.(*Conn)
if !ok {
_ = nc.Close()
plainErr <- errors.New("accepted conn is not *Conn")
return
}
plainDone <- sc
// block until client sends one byte, then close
buf := make([]byte, 1)
_, _ = sc.Read(buf)
_ = sc.Close()
}()
c1, err := net.Dial("tcp", ln1.Addr().String())
if err != nil {
t.Fatalf("dial1: %v", err)
}
if _, err := c1.Write([]byte("p")); err != nil {
_ = c1.Close()
t.Fatalf("plain client write: %v", err)
}
_ = c1.Close()
select {
case err := <-plainErr:
t.Fatalf("plain server error: %v", err)
case sc1 := <-plainDone:
if sc1.IsTLS() {
t.Fatalf("plain conn should not be TLS")
}
case <-time.After(2 * time.Second):
t.Fatalf("timeout waiting plain side")
}
// ---- tls case ----
cert := genSelfSignedCert(t, "localhost")
tlsBase := TLSDefaults()
tlsBase.Certificates = []tls.Certificate{cert}
tlsCfg := DefaultListenerConfig()
tlsCfg.BaseTLSConfig = tlsBase
tlsCfg.AllowNonTLS = false
ln2, err := ListenWithConfig("tcp", "127.0.0.1:0", tlsCfg)
if err != nil {
t.Fatalf("listen2: %v", err)
}
defer ln2.Close()
tlsDone := make(chan *Conn, 1)
tlsErr := make(chan error, 1)
go func() {
nc, err := ln2.Accept()
if err != nil {
tlsErr <- err
return
}
sc, ok := nc.(*Conn)
if !ok {
_ = nc.Close()
tlsErr <- errors.New("accepted conn is not *Conn")
return
}
tlsDone <- sc
// key point: wait for real data to ensure TLS handshake/path is executed
buf := make([]byte, 1)
_, _ = sc.Read(buf)
_ = sc.Close()
}()
d := &net.Dialer{Timeout: 2 * time.Second}
tc, err := tls.DialWithDialer(d, "tcp", ln2.Addr().String(), &tls.Config{
InsecureSkipVerify: true, // test only
ServerName: "localhost",
MinVersion: tls.VersionTLS12,
})
if err != nil {
t.Fatalf("tls dial: %v", err)
}
if _, err := tc.Write([]byte("t")); err != nil {
_ = tc.Close()
t.Fatalf("tls client write: %v", err)
}
_ = tc.Close()
select {
case err := <-tlsErr:
t.Fatalf("tls server error: %v", err)
case sc2 := <-tlsDone:
if !sc2.IsTLS() {
t.Fatalf("tls conn should be TLS")
}
case <-time.After(3 * time.Second):
t.Fatalf("timeout waiting tls side")
}
}