package starnet

import (
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"errors"
	"io"
	"math/big"
	"net"
	"os"
	"path/filepath"
	"sync"
	"testing"
	"time"
)

// ---------- cert helpers ----------

// genSelfSignedCertPEM creates a self-signed RSA-2048 certificate for the given
// DNS names, valid from one hour ago until 24 hours from now, usable for both
// server and client auth. It returns the PEM-encoded certificate and PKCS#1
// private key, failing the test on any error.
func genSelfSignedCertPEM(t *testing.T, dnsNames ...string) (certPEM, keyPEM []byte) {
	t.Helper()

	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("GenerateKey: %v", err)
	}
	serial, err := rand.Int(rand.Reader, big.NewInt(1<<62))
	if err != nil {
		t.Fatalf("serial: %v", err)
	}

	tpl := &x509.Certificate{
		SerialNumber: serial,
		Subject: pkix.Name{
			CommonName: "starnet-test",
		},
		NotBefore:   time.Now().Add(-time.Hour),
		NotAfter:    time.Now().Add(24 * time.Hour),
		KeyUsage:    x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
		DNSNames:    dnsNames,
	}

	// Self-signed: the template acts as both subject and issuer.
	der, err := x509.CreateCertificate(rand.Reader, tpl, tpl, &priv.PublicKey, priv)
	if err != nil {
		t.Fatalf("CreateCertificate: %v", err)
	}

	certPEM = pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
	keyPEM = pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
	return certPEM, keyPEM
}

// genSelfSignedCert is genSelfSignedCertPEM parsed into a tls.Certificate.
func genSelfSignedCert(t *testing.T, dnsNames ...string) tls.Certificate {
	t.Helper()

	certPEM, keyPEM := genSelfSignedCertPEM(t, dnsNames...)
	cert, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		t.Fatalf("X509KeyPair: %v", err)
	}
	return cert
}

// writeTempCertFiles writes a fresh self-signed certificate/key pair (see
// genSelfSignedCertPEM) into a per-test temporary directory and returns the
// two file paths. The directory is removed automatically when the test ends.
// The returned cleanup func removes the files earlier and is always safe to call.
func writeTempCertFiles(t *testing.T, dnsNames ...string) (certFile, keyFile string, cleanup func()) {
	t.Helper()

	certPEM, keyPEM := genSelfSignedCertPEM(t, dnsNames...)

	// t.TempDir is cleaned up by the testing package even when a later Fatalf
	// fires, so a failed write can no longer leak files into the system temp dir.
	dir := t.TempDir()
	certFile = filepath.Join(dir, "starnet-cert.pem")
	keyFile = filepath.Join(dir, "starnet-key.pem")

	if err := os.WriteFile(certFile, certPEM, 0o600); err != nil {
		t.Fatalf("write cert: %v", err)
	}
	if err := os.WriteFile(keyFile, keyPEM, 0o600); err != nil {
		t.Fatalf("write key: %v", err)
	}

	return certFile, keyFile, func() {
		_ = os.Remove(certFile)
		_ = os.Remove(keyFile)
	}
}

// ---------- server helpers ----------

// startEchoServer starts a Listener on 127.0.0.1 (ephemeral port) with cfg.
// It serves every accepted connection by echoing its bytes back.
// It returns the listener, its address, and a cleanup func. The cleanup func
// closes the listener and every live echo connection, then waits for all
// server goroutines to exit.
func startEchoServer(t *testing.T, cfg ListenerConfig) (*Listener, string, func()) {
	t.Helper()

	ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
	if err != nil {
		t.Fatalf("ListenWithConfig: %v", err)
	}

	var (
		wg     sync.WaitGroup
		mu     sync.Mutex
		conns  = make(map[net.Conn]struct{})
		closed bool
	)

	// track registers c so cleanup can close it. If cleanup has already run,
	// it closes c itself and reports false.
	track := func(c net.Conn) bool {
		mu.Lock()
		defer mu.Unlock()
		if closed {
			_ = c.Close()
			return false
		}
		conns[c] = struct{}{}
		return true
	}
	untrack := func(c net.Conn) {
		mu.Lock()
		delete(conns, c)
		mu.Unlock()
	}

	wg.Add(1)
	go func() {
		defer wg.Done()
		for {
			c, err := ln.Accept()
			if err != nil {
				// Any Accept error, including the one caused by cleanup
				// closing the listener, ends the loop.
				return
			}
			if !track(c) {
				return
			}
			// This goroutine still holds its own wg count, so this Add can
			// never race with a Wait that has observed a zero counter.
			wg.Add(1)
			go func(conn net.Conn) {
				defer wg.Done()
				defer untrack(conn)
				defer conn.Close()
				_, _ = io.Copy(conn, conn)
			}(c)
		}
	}()

	cleanup := func() {
		mu.Lock()
		closed = true
		for c := range conns {
			_ = c.Close()
		}
		mu.Unlock()
		_ = ln.Close()
		wg.Wait()
	}
	return ln, ln.Addr().String(), cleanup
}

// ---------- tests ----------

// TestListen checks that Listen yields a working plain-TCP listener by
// round-tripping one byte through an echo handler.
func TestListen(t *testing.T) {
	ln, err := Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Listen: %v", err)
	}
	defer ln.Close()

	go func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		defer c.Close()
		_, _ = io.Copy(c, c)
	}()

	c, err := net.DialTimeout("tcp", ln.Addr().String(), 3*time.Second)
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	defer c.Close()
	_ = c.SetDeadline(time.Now().Add(3 * time.Second))

	msg := []byte("x")
	if _, err := c.Write(msg); err != nil {
		t.Fatalf("write: %v", err)
	}
	buf := make([]byte, len(msg))
	if _, err := io.ReadFull(c, buf); err != nil {
		t.Fatalf("read: %v", err)
	}
	if string(buf) != string(msg) {
		t.Fatalf("echo mismatch: got=%q want=%q", string(buf), string(msg))
	}
}

// TestListenWithListenConfig checks that a caller-supplied net.ListenConfig
// produces a working listener (plain TCP allowed).
func TestListenWithListenConfig(t *testing.T) {
	lc := net.ListenConfig{}
	cfg := DefaultListenerConfig()
	cfg.AllowNonTLS = true

	ln, err := ListenWithListenConfig(lc, "tcp", "127.0.0.1:0", cfg)
	if err != nil {
		t.Fatalf("ListenWithListenConfig: %v", err)
	}
	defer ln.Close()

	go func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		defer c.Close()
		_, _ = io.Copy(c, c)
	}()

	c, err := net.DialTimeout("tcp", ln.Addr().String(), 3*time.Second)
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	defer c.Close()
	_ = c.SetDeadline(time.Now().Add(3 * time.Second))

	msg := []byte("ok")
	if _, err := c.Write(msg); err != nil {
		t.Fatalf("write: %v", err)
	}
	buf := make([]byte, len(msg))
	if _, err := io.ReadFull(c, buf); err != nil {
		t.Fatalf("read: %v", err)
	}
	if string(buf) != string(msg) {
		t.Fatalf("echo mismatch: got=%q want=%q", string(buf), string(msg))
	}
}

// TestListenerSetConfig checks that SetConfig replaces the listener's config
// as observed through Config.
func TestListenerSetConfig(t *testing.T) {
	cfg := DefaultListenerConfig()
	cfg.AllowNonTLS = true

	ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	defer ln.Close()

	cfg2 := cfg
	cfg2.SniffTimeout = time.Second
	ln.SetConfig(cfg2)

	if got := ln.Config(); got.SniffTimeout != time.Second {
		t.Fatalf("SetConfig not applied")
	}
}

// TestPlainAllowed checks that plain TCP traffic is echoed when AllowNonTLS
// is set and no TLS config is present.
func TestPlainAllowed(t *testing.T) {
	cfg := DefaultListenerConfig()
	cfg.AllowNonTLS = true
	cfg.BaseTLSConfig = nil

	_, addr, cleanup := startEchoServer(t, cfg)
	defer cleanup()

	c, err := net.DialTimeout("tcp", addr, 3*time.Second)
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer c.Close()
	_ = c.SetDeadline(time.Now().Add(3 * time.Second))

	msg := []byte("hello-plain")
	if _, err := c.Write(msg); err != nil {
		t.Fatalf("write: %v", err)
	}
	buf := make([]byte, len(msg))
	if _, err := io.ReadFull(c, buf); err != nil {
		t.Fatalf("read: %v", err)
	}
	if string(buf) != string(msg) {
		t.Fatalf("echo mismatch: got=%q want=%q", string(buf), string(msg))
	}
}

// TestPlainRejectedWhenNonTLSDisabled checks that a TLS-only listener does
// not echo plain bytes back: the client's read must end in an error.
func TestPlainRejectedWhenNonTLSDisabled(t *testing.T) {
	cert := genSelfSignedCert(t, "localhost")
	base := TLSDefaults()
	base.Certificates = []tls.Certificate{cert}

	cfg := DefaultListenerConfig()
	cfg.AllowNonTLS = false
	cfg.BaseTLSConfig = base

	_, addr, cleanup := startEchoServer(t, cfg)
	defer cleanup()

	c, err := net.DialTimeout("tcp", addr, 3*time.Second)
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer c.Close()

	// The write may legitimately fail if the server has already dropped us.
	_, _ = c.Write([]byte("plain"))
	_ = c.SetReadDeadline(time.Now().Add(500 * time.Millisecond))

	b := make([]byte, 1)
	if _, err := c.Read(b); err == nil {
		t.Fatalf("expected read error due to non-tls rejection")
	}
}

// TestTLSHandshakeAndEcho checks a full TLS handshake and an echoed payload
// through a TLS-only listener.
func TestTLSHandshakeAndEcho(t *testing.T) {
	cert := genSelfSignedCert(t, "localhost")
	base := TLSDefaults()
	base.Certificates = []tls.Certificate{cert}

	cfg := DefaultListenerConfig()
	cfg.AllowNonTLS = false
	cfg.BaseTLSConfig = base

	_, addr, cleanup := startEchoServer(t, cfg)
	defer cleanup()

	// The dialer timeout bounds both the TCP connect and the TLS handshake.
	d := &net.Dialer{Timeout: 3 * time.Second}
	tc, err := tls.DialWithDialer(d, "tcp", addr, &tls.Config{
		InsecureSkipVerify: true, // test only: self-signed server certificate
		ServerName:         "localhost",
		MinVersion:         tls.VersionTLS12,
	})
	if err != nil {
		t.Fatalf("tls dial: %v", err)
	}
	defer tc.Close()
	_ = tc.SetDeadline(time.Now().Add(3 * time.Second))

	msg := []byte("hello-tls")
	if _, err := tc.Write(msg); err != nil {
		t.Fatalf("tls write: %v", err)
	}
	buf := make([]byte, len(msg))
	if _, err := io.ReadFull(tc, buf); err != nil {
		t.Fatalf("tls read: %v", err)
	}
	if string(buf) != string(msg) {
		t.Fatalf("tls echo mismatch: got=%q want=%q", string(buf), string(msg))
	}
}

// TestDynamicConfigBySNI checks that GetConfigForClient selects a per-SNI
// config. A client asking for "b.local" must be served certB, not the base
// certificate.
func TestDynamicConfigBySNI(t *testing.T) {
	certA := genSelfSignedCert(t, "a.local")
	certB := genSelfSignedCert(t, "b.local")

	base := TLSDefaults()
	base.Certificates = []tls.Certificate{certA}

	cfg := DefaultListenerConfig()
	cfg.AllowNonTLS = false
	cfg.BaseTLSConfig = base
	cfg.GetConfigForClient = func(host string) (*tls.Config, error) {
		if host == "b.local" {
			b := TLSDefaults()
			b.Certificates = []tls.Certificate{certB}
			return b, nil
		}
		// nil, nil: fall back to the base config.
		return nil, nil
	}

	_, addr, cleanup := startEchoServer(t, cfg)
	defer cleanup()

	d := &net.Dialer{Timeout: 3 * time.Second}
	tc, err := tls.DialWithDialer(d, "tcp", addr, &tls.Config{
		InsecureSkipVerify: true, // test only: self-signed server certificate
		ServerName:         "b.local",
		MinVersion:         tls.VersionTLS12,
	})
	if err != nil {
		t.Fatalf("tls dial: %v", err)
	}
	defer tc.Close()

	state := tc.ConnectionState()
	if !state.HandshakeComplete {
		t.Fatalf("handshake not complete")
	}
	if len(state.PeerCertificates) == 0 {
		t.Fatalf("server presented no certificate")
	}
	// Without this check the test would pass even if the selector were ignored
	// and certA were served.
	if names := state.PeerCertificates[0].DNSNames; len(names) != 1 || names[0] != "b.local" {
		t.Fatalf("served certificate DNSNames = %v, want [b.local]", names)
	}
}

// TestGetConfigForClientError checks that a selector error aborts the
// handshake from the client's point of view.
func TestGetConfigForClientError(t *testing.T) {
	cert := genSelfSignedCert(t, "localhost")
	base := TLSDefaults()
	base.Certificates = []tls.Certificate{cert}

	cfg := DefaultListenerConfig()
	cfg.AllowNonTLS = false
	cfg.BaseTLSConfig = base
	cfg.GetConfigForClient = func(string) (*tls.Config, error) {
		return nil, errors.New("boom")
	}

	_, addr, cleanup := startEchoServer(t, cfg)
	defer cleanup()

	d := &net.Dialer{Timeout: 3 * time.Second}
	tc, err := tls.DialWithDialer(d, "tcp", addr, &tls.Config{
		InsecureSkipVerify: true, // test only: self-signed server certificate
		ServerName:         "localhost",
	})
	if err == nil {
		_ = tc.Close()
		t.Fatalf("expected tls dial failure due to selector error")
	}
}

// TestAcceptContextCancel checks that AcceptContext returns an error once its
// context expires with no incoming connection.
func TestAcceptContextCancel(t *testing.T) {
	cfg := DefaultListenerConfig()
	cfg.AllowNonTLS = true

	ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	defer ln.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	if _, err := ln.AcceptContext(ctx); err == nil {
		t.Fatalf("expected context timeout/cancel")
	}
}

// TestListenerStats checks that an accepted connection is counted in Stats.
func TestListenerStats(t *testing.T) {
	cfg := DefaultListenerConfig()
	cfg.AllowNonTLS = true

	ln, addr, cleanup := startEchoServer(t, cfg)
	defer cleanup()

	c, err := net.DialTimeout("tcp", addr, 3*time.Second)
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	defer c.Close()
	_ = c.SetDeadline(time.Now().Add(3 * time.Second))

	if _, err := c.Write([]byte("x")); err != nil {
		t.Fatalf("write: %v", err)
	}
	// Receiving the echo proves the server has accepted the connection. This
	// replaces a fixed sleep, which was both slow and flaky.
	buf := make([]byte, 1)
	if _, err := io.ReadFull(c, buf); err != nil {
		t.Fatalf("read: %v", err)
	}

	if s := ln.Stats(); s.Accepted == 0 {
		t.Fatalf("expected accepted > 0")
	}
}

// TestDialAndDialWithConfig checks that Dial round-trips data against a plain
// net.Listener echo server and that DialWithConfig can connect.
func TestDialAndDialWithConfig(t *testing.T) {
	nl, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	defer nl.Close()

	go func() {
		c, err := nl.Accept()
		if err != nil {
			return
		}
		defer c.Close()
		_, _ = io.Copy(c, c)
	}()

	c1, err := Dial("tcp", nl.Addr().String())
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	defer c1.Close()

	msg := []byte("abc")
	if _, err := c1.Write(msg); err != nil {
		t.Fatalf("c1 write: %v", err)
	}
	got := make([]byte, len(msg))
	if _, err := io.ReadFull(c1, got); err != nil {
		t.Fatalf("c1 read: %v", err)
	}
	if string(got) != string(msg) {
		t.Fatalf("c1 echo mismatch: got=%q want=%q", string(got), string(msg))
	}

	// Only the connect is checked here: the kernel backlog completes the TCP
	// handshake even though the echo goroutine accepts just one connection.
	c2, err := DialWithConfig("tcp", nl.Addr().String(), DialConfig{Timeout: time.Second})
	if err != nil {
		t.Fatalf("DialWithConfig: %v", err)
	}
	defer c2.Close()
}

// TestListenTLS_FileAPI checks ListenTLS with certificate/key files on disk.
func TestListenTLS_FileAPI(t *testing.T) {
	certFile, keyFile, cleanupFiles := writeTempCertFiles(t, "localhost")
	defer cleanupFiles()

	ln, err := ListenTLS("tcp", "127.0.0.1:0", certFile, keyFile, false)
	if err != nil {
		t.Fatalf("ListenTLS: %v", err)
	}
	defer ln.Close()

	go func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		defer c.Close()
		_, _ = io.Copy(c, c)
	}()

	d := &net.Dialer{Timeout: 3 * time.Second}
	tc, err := tls.DialWithDialer(d, "tcp", ln.Addr().String(), &tls.Config{
		InsecureSkipVerify: true, // test only: self-signed server certificate
		ServerName:         "localhost",
	})
	if err != nil {
		t.Fatalf("tls dial: %v", err)
	}
	defer tc.Close()
	_ = tc.SetDeadline(time.Now().Add(3 * time.Second))

	msg := []byte("hi")
	if _, err := tc.Write(msg); err != nil {
		t.Fatalf("tls write: %v", err)
	}
	out := make([]byte, len(msg))
	if _, err := io.ReadFull(tc, out); err != nil {
		t.Fatalf("tls read: %v", err)
	}
	if string(out) != string(msg) {
		t.Fatalf("tls echo mismatch: got=%q want=%q", string(out), string(msg))
	}
}

// TestDialTLSWithConfig checks that DialTLSWithConfig yields a connection
// that reports IsTLS.
func TestDialTLSWithConfig(t *testing.T) {
	cert := genSelfSignedCert(t, "localhost")
	base := TLSDefaults()
	base.Certificates = []tls.Certificate{cert}

	cfg := DefaultListenerConfig()
	cfg.BaseTLSConfig = base
	cfg.AllowNonTLS = false

	ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	defer ln.Close()

	go func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		defer c.Close()
		_, _ = io.Copy(c, c)
	}()

	clientCfg := &tls.Config{
		InsecureSkipVerify: true, // test only: self-signed server certificate
		ServerName:         "localhost",
	}
	c, err := DialTLSWithConfig("tcp", ln.Addr().String(), clientCfg, time.Second)
	if err != nil {
		t.Fatalf("DialTLSWithConfig: %v", err)
	}
	defer c.Close()

	if !c.IsTLS() {
		t.Fatalf("expected IsTLS true")
	}
}

// TestDialTLS_FileAPI checks DialTLS with client certificate/key files on disk.
func TestDialTLS_FileAPI(t *testing.T) {
	cert := genSelfSignedCert(t, "localhost")
	base := TLSDefaults()
	base.Certificates = []tls.Certificate{cert}

	cfg := DefaultListenerConfig()
	cfg.BaseTLSConfig = base
	cfg.AllowNonTLS = false

	ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	defer ln.Close()

	go func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		defer c.Close()
		_, _ = io.Copy(c, c)
	}()

	clientCertFile, clientKeyFile, cleanupFiles := writeTempCertFiles(t, "localhost")
	defer cleanupFiles()

	c, err := DialTLS("tcp", ln.Addr().String(), clientCertFile, clientKeyFile)
	if err != nil {
		t.Fatalf("DialTLS: %v", err)
	}
	defer c.Close()

	if !c.IsTLS() {
		t.Fatalf("expected IsTLS true")
	}
}

// TestConnIsTLS_PlainAndTLS checks that server-side *Conn values report
// IsTLS correctly for both plain-TCP and TLS clients.
func TestConnIsTLS_PlainAndTLS(t *testing.T) {
	// probeIsTLS accepts one connection from ln and waits for the client's
	// first byte, so any lazy protocol detection has already run. It then
	// samples IsTLS in the same goroutine, closes the connection, and reports
	// the result. Sampling here, instead of handing the *Conn to the test
	// goroutine, avoids racing with the concurrent Read/Close.
	probeIsTLS := func(ln *Listener) (<-chan bool, <-chan error) {
		result := make(chan bool, 1)
		errc := make(chan error, 1)
		go func() {
			nc, err := ln.Accept()
			if err != nil {
				errc <- err
				return
			}
			defer nc.Close()

			sc, ok := nc.(*Conn)
			if !ok {
				errc <- errors.New("accepted conn is not *Conn")
				return
			}
			// The read result is irrelevant; it only drives the data path.
			buf := make([]byte, 1)
			_, _ = sc.Read(buf)
			result <- sc.IsTLS()
		}()
		return result, errc
	}

	// ---- plain case ----
	plainCfg := DefaultListenerConfig()
	plainCfg.AllowNonTLS = true

	ln1, err := ListenWithConfig("tcp", "127.0.0.1:0", plainCfg)
	if err != nil {
		t.Fatalf("listen1: %v", err)
	}
	defer ln1.Close()

	plainRes, plainErr := probeIsTLS(ln1)

	c1, err := net.DialTimeout("tcp", ln1.Addr().String(), 2*time.Second)
	if err != nil {
		t.Fatalf("dial1: %v", err)
	}
	if _, err := c1.Write([]byte("p")); err != nil {
		_ = c1.Close()
		t.Fatalf("plain client write: %v", err)
	}
	_ = c1.Close()

	select {
	case err := <-plainErr:
		t.Fatalf("plain server error: %v", err)
	case isTLS := <-plainRes:
		if isTLS {
			t.Fatalf("plain conn should not be TLS")
		}
	case <-time.After(2 * time.Second):
		t.Fatalf("timeout waiting plain side")
	}

	// ---- tls case ----
	cert := genSelfSignedCert(t, "localhost")
	tlsBase := TLSDefaults()
	tlsBase.Certificates = []tls.Certificate{cert}

	tlsCfg := DefaultListenerConfig()
	tlsCfg.BaseTLSConfig = tlsBase
	tlsCfg.AllowNonTLS = false

	ln2, err := ListenWithConfig("tcp", "127.0.0.1:0", tlsCfg)
	if err != nil {
		t.Fatalf("listen2: %v", err)
	}
	defer ln2.Close()

	tlsRes, tlsErr := probeIsTLS(ln2)

	d := &net.Dialer{Timeout: 2 * time.Second}
	tc, err := tls.DialWithDialer(d, "tcp", ln2.Addr().String(), &tls.Config{
		InsecureSkipVerify: true, // test only
		ServerName:         "localhost",
		MinVersion:         tls.VersionTLS12,
	})
	if err != nil {
		t.Fatalf("tls dial: %v", err)
	}
	if _, err := tc.Write([]byte("t")); err != nil {
		_ = tc.Close()
		t.Fatalf("tls client write: %v", err)
	}
	_ = tc.Close()

	select {
	case err := <-tlsErr:
		t.Fatalf("tls server error: %v", err)
	case isTLS := <-tlsRes:
		if !isTLS {
			t.Fatalf("tls conn should be TLS")
		}
	case <-time.After(3 * time.Second):
		t.Fatalf("timeout waiting tls side")
	}
}