2026-03-08 20:19:40 +08:00
|
|
|
package starnet
|
|
|
|
|
|
|
|
|
|
import (
|
2026-03-27 12:05:23 +08:00
|
|
|
"bytes"
|
2026-03-08 20:19:40 +08:00
|
|
|
"context"
|
|
|
|
|
"crypto/rand"
|
|
|
|
|
"crypto/rsa"
|
|
|
|
|
"crypto/tls"
|
|
|
|
|
"crypto/x509"
|
|
|
|
|
"crypto/x509/pkix"
|
2026-03-27 12:05:23 +08:00
|
|
|
"encoding/binary"
|
2026-03-08 20:19:40 +08:00
|
|
|
"encoding/pem"
|
|
|
|
|
"errors"
|
|
|
|
|
"io"
|
|
|
|
|
"math/big"
|
|
|
|
|
"net"
|
|
|
|
|
"os"
|
|
|
|
|
"sync"
|
|
|
|
|
"testing"
|
|
|
|
|
"time"
|
|
|
|
|
)
|
|
|
|
|
|
2026-03-27 12:05:23 +08:00
|
|
|
type staticSniffer struct {
|
|
|
|
|
result SniffResult
|
|
|
|
|
err error
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// fixedAddr is a minimal net.Addr stub: its network is always "tcp" and its
// textual form is the string value itself.
type fixedAddr string

// Network always reports "tcp".
func (a fixedAddr) Network() string {
	return "tcp"
}

// String returns the stored address text unchanged.
func (a fixedAddr) String() string {
	s := string(a)
	return s
}
|
|
|
|
|
|
|
|
|
|
type recordingConn struct {
|
|
|
|
|
buf bytes.Buffer
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (c *recordingConn) Read([]byte) (int, error) { return 0, io.EOF }
|
|
|
|
|
func (c *recordingConn) Write(p []byte) (int, error) { return c.buf.Write(p) }
|
|
|
|
|
func (c *recordingConn) Close() error { return nil }
|
|
|
|
|
func (c *recordingConn) LocalAddr() net.Addr { return fixedAddr("127.0.0.1:443") }
|
|
|
|
|
func (c *recordingConn) RemoteAddr() net.Addr { return fixedAddr("127.0.0.1:50000") }
|
|
|
|
|
func (c *recordingConn) SetDeadline(time.Time) error { return nil }
|
|
|
|
|
func (c *recordingConn) SetReadDeadline(time.Time) error { return nil }
|
|
|
|
|
func (c *recordingConn) SetWriteDeadline(time.Time) error { return nil }
|
|
|
|
|
|
|
|
|
|
// bytesConn is a read-only net.Conn stub. Reads are served from reader, Writes
// are accepted and discarded, and the address methods return the configured
// values. Close and the deadline setters are no-ops.
type bytesConn struct {
	reader io.Reader // source of all Read data
	local  net.Addr  // returned by LocalAddr
	remote net.Addr  // returned by RemoteAddr
}

// Read delegates directly to the wrapped reader.
func (c *bytesConn) Read(p []byte) (int, error) {
	return c.reader.Read(p)
}

// Write pretends the whole buffer was sent.
func (c *bytesConn) Write(p []byte) (int, error) {
	written := len(p)
	return written, nil
}

// Close is a no-op.
func (c *bytesConn) Close() error {
	return nil
}

// LocalAddr returns the configured local address (nil if unset).
func (c *bytesConn) LocalAddr() net.Addr {
	return c.local
}

// RemoteAddr returns the configured remote address (nil if unset).
func (c *bytesConn) RemoteAddr() net.Addr {
	return c.remote
}

// The deadline setters accept and discard their argument.
func (c *bytesConn) SetDeadline(time.Time) error {
	return nil
}

func (c *bytesConn) SetReadDeadline(time.Time) error {
	return nil
}

func (c *bytesConn) SetWriteDeadline(time.Time) error {
	return nil
}
|
|
|
|
|
|
|
|
|
|
func (s staticSniffer) Sniff(_ net.Conn, _ int) (SniffResult, error) {
|
|
|
|
|
return s.result, s.err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func captureClientHelloBytes(t *testing.T, serverName string, nextProtos []string) []byte {
|
|
|
|
|
t.Helper()
|
|
|
|
|
|
|
|
|
|
conn := &recordingConn{}
|
|
|
|
|
tc := tls.Client(conn, &tls.Config{
|
|
|
|
|
InsecureSkipVerify: true,
|
|
|
|
|
ServerName: serverName,
|
|
|
|
|
NextProtos: nextProtos,
|
|
|
|
|
MinVersion: tls.VersionTLS12,
|
|
|
|
|
})
|
|
|
|
|
_ = tc.Handshake()
|
|
|
|
|
_ = tc.Close()
|
|
|
|
|
|
|
|
|
|
out := append([]byte(nil), conn.buf.Bytes()...)
|
|
|
|
|
if len(out) < 9 {
|
|
|
|
|
t.Fatalf("captured hello too short: %d", len(out))
|
|
|
|
|
}
|
|
|
|
|
return out
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// prefixBytes returns at most the first n bytes of data, clamping n into
// [0, len(data)] so it never panics. It is used to build safe hex dumps for
// error messages. The result aliases data.
func prefixBytes(data []byte, n int) []byte {
	switch {
	case n < 0:
		return data[:0]
	case n > len(data):
		return data
	default:
		return data[:n]
	}
}
|
|
|
|
|
|
|
|
|
|
func splitClientHelloRecord(t *testing.T, data []byte, splitBodyAt int) []byte {
|
|
|
|
|
t.Helper()
|
|
|
|
|
|
|
|
|
|
if len(data) < 9 {
|
|
|
|
|
t.Fatalf("client hello too short: %d", len(data))
|
|
|
|
|
}
|
|
|
|
|
if data[0] != 0x16 || data[1] != 0x03 {
|
|
|
|
|
t.Fatalf("unexpected tls record header: % x", prefixBytes(data, 5))
|
|
|
|
|
}
|
|
|
|
|
recordLen := int(binary.BigEndian.Uint16(data[3:5]))
|
|
|
|
|
if len(data) < 5+recordLen {
|
|
|
|
|
t.Fatalf("short tls record: have=%d want=%d", len(data), 5+recordLen)
|
|
|
|
|
}
|
|
|
|
|
recordBody := append([]byte(nil), data[5:5+recordLen]...)
|
|
|
|
|
if len(recordBody) < 4 || recordBody[0] != 0x01 {
|
|
|
|
|
t.Fatalf("unexpected handshake record: % x", prefixBytes(recordBody, 4))
|
|
|
|
|
}
|
|
|
|
|
if splitBodyAt <= 4 || splitBodyAt >= len(recordBody) {
|
|
|
|
|
t.Fatalf("invalid split point %d for body len %d", splitBodyAt, len(recordBody))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var out bytes.Buffer
|
|
|
|
|
writeRecord := func(body []byte) {
|
|
|
|
|
out.WriteByte(0x16)
|
|
|
|
|
out.WriteByte(data[1])
|
|
|
|
|
out.WriteByte(data[2])
|
|
|
|
|
header := []byte{0, 0}
|
|
|
|
|
binary.BigEndian.PutUint16(header, uint16(len(body)))
|
|
|
|
|
out.Write(header)
|
|
|
|
|
out.Write(body)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
writeRecord(recordBody[:splitBodyAt])
|
|
|
|
|
writeRecord(recordBody[splitBodyAt:])
|
|
|
|
|
return out.Bytes()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestTLSSnifferParsesClientHello(t *testing.T) {
|
|
|
|
|
client, server := net.Pipe()
|
|
|
|
|
defer client.Close()
|
|
|
|
|
defer server.Close()
|
|
|
|
|
|
|
|
|
|
done := make(chan error, 1)
|
|
|
|
|
go func() {
|
|
|
|
|
tc := tls.Client(client, &tls.Config{
|
|
|
|
|
InsecureSkipVerify: true,
|
|
|
|
|
ServerName: "example.com",
|
|
|
|
|
NextProtos: []string{"h2", "http/1.1"},
|
|
|
|
|
MinVersion: tls.VersionTLS12,
|
|
|
|
|
})
|
|
|
|
|
defer tc.Close()
|
|
|
|
|
done <- tc.Handshake()
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
_ = server.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
|
|
|
res, err := (TLSSniffer{}).Sniff(server, 64*1024)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("sniff: %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if !res.IsTLS {
|
|
|
|
|
t.Fatalf("expected TLS")
|
|
|
|
|
}
|
|
|
|
|
if res.ClientHello == nil {
|
|
|
|
|
t.Fatalf("expected client hello metadata")
|
|
|
|
|
}
|
|
|
|
|
if res.ClientHello.ServerName != "example.com" {
|
|
|
|
|
t.Fatalf("unexpected server name: %q", res.ClientHello.ServerName)
|
|
|
|
|
}
|
|
|
|
|
if len(res.ClientHello.SupportedProtos) == 0 || res.ClientHello.SupportedProtos[0] != "h2" {
|
|
|
|
|
t.Fatalf("unexpected ALPN list: %v", res.ClientHello.SupportedProtos)
|
|
|
|
|
}
|
|
|
|
|
if len(res.ClientHello.SupportedVersions) == 0 {
|
|
|
|
|
t.Fatalf("expected supported versions")
|
|
|
|
|
}
|
|
|
|
|
if len(res.ClientHello.CipherSuites) == 0 {
|
|
|
|
|
t.Fatalf("expected cipher suites")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_ = server.Close()
|
|
|
|
|
<-done
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestTLSSnifferParsesFragmentedClientHelloRecords(t *testing.T) {
|
|
|
|
|
rawHello := captureClientHelloBytes(t, "example.com", []string{"h2", "http/1.1"})
|
|
|
|
|
fragmented := splitClientHelloRecord(t, rawHello, 32)
|
|
|
|
|
|
|
|
|
|
conn := &bytesConn{
|
|
|
|
|
reader: bytes.NewReader(fragmented),
|
|
|
|
|
local: fixedAddr("127.0.0.1:443"),
|
|
|
|
|
remote: fixedAddr("127.0.0.1:50000"),
|
|
|
|
|
}
|
|
|
|
|
res, err := (TLSSniffer{}).Sniff(conn, 64*1024)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("sniff: %v", err)
|
|
|
|
|
}
|
|
|
|
|
if !res.IsTLS {
|
|
|
|
|
t.Fatalf("expected TLS")
|
|
|
|
|
}
|
|
|
|
|
if res.ClientHello == nil {
|
|
|
|
|
t.Fatalf("expected client hello metadata")
|
|
|
|
|
}
|
|
|
|
|
if res.ClientHello.ServerName != "example.com" {
|
|
|
|
|
t.Fatalf("unexpected server name: %q", res.ClientHello.ServerName)
|
|
|
|
|
}
|
|
|
|
|
if len(res.ClientHello.SupportedProtos) == 0 || res.ClientHello.SupportedProtos[0] != "h2" {
|
|
|
|
|
t.Fatalf("unexpected ALPN list: %v", res.ClientHello.SupportedProtos)
|
|
|
|
|
}
|
|
|
|
|
if len(res.ClientHello.SupportedVersions) == 0 {
|
|
|
|
|
t.Fatalf("expected supported versions")
|
|
|
|
|
}
|
|
|
|
|
if len(res.ClientHello.CipherSuites) == 0 {
|
|
|
|
|
t.Fatalf("expected cipher suites")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestTLSSnifferKeepsTruncatedTLSAsTLS(t *testing.T) {
|
|
|
|
|
rawHello := captureClientHelloBytes(t, "example.com", []string{"h2", "http/1.1"})
|
|
|
|
|
truncated := append([]byte(nil), rawHello[:20]...)
|
|
|
|
|
|
|
|
|
|
conn := &bytesConn{
|
|
|
|
|
reader: bytes.NewReader(truncated),
|
|
|
|
|
local: fixedAddr("127.0.0.1:443"),
|
|
|
|
|
remote: fixedAddr("127.0.0.1:50000"),
|
|
|
|
|
}
|
|
|
|
|
res, err := (TLSSniffer{}).Sniff(conn, 64*1024)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("sniff: %v", err)
|
|
|
|
|
}
|
|
|
|
|
if !res.IsTLS {
|
|
|
|
|
t.Fatalf("expected truncated hello to stay classified as TLS")
|
|
|
|
|
}
|
|
|
|
|
if res.ClientHello == nil {
|
|
|
|
|
t.Fatalf("expected client hello metadata")
|
|
|
|
|
}
|
|
|
|
|
if res.ClientHello.ServerName != "" {
|
|
|
|
|
t.Fatalf("expected empty server name for truncated hello, got %q", res.ClientHello.ServerName)
|
|
|
|
|
}
|
|
|
|
|
if res.ClientHello.LocalAddr == nil || res.ClientHello.RemoteAddr == nil {
|
|
|
|
|
t.Fatalf("expected local/remote addr in metadata")
|
|
|
|
|
}
|
|
|
|
|
if !bytes.Equal(res.Buffer.Bytes(), truncated) {
|
|
|
|
|
t.Fatalf("buffer mismatch: got %d bytes want %d", res.Buffer.Len(), len(truncated))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestTLSSnifferRespectsMaxBytesWhileKeepingTLSClassification(t *testing.T) {
|
|
|
|
|
rawHello := captureClientHelloBytes(t, "example.com", []string{"h2", "http/1.1"})
|
|
|
|
|
limit := 5
|
|
|
|
|
|
|
|
|
|
conn := &bytesConn{
|
|
|
|
|
reader: bytes.NewReader(rawHello),
|
|
|
|
|
local: fixedAddr("127.0.0.1:443"),
|
|
|
|
|
remote: fixedAddr("127.0.0.1:50000"),
|
|
|
|
|
}
|
|
|
|
|
res, err := (TLSSniffer{}).Sniff(conn, limit)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("sniff: %v", err)
|
|
|
|
|
}
|
|
|
|
|
if !res.IsTLS {
|
|
|
|
|
t.Fatalf("expected TLS classification with header-sized limit")
|
|
|
|
|
}
|
|
|
|
|
if res.ClientHello == nil {
|
|
|
|
|
t.Fatalf("expected client hello metadata")
|
|
|
|
|
}
|
|
|
|
|
if res.ClientHello.ServerName != "" {
|
|
|
|
|
t.Fatalf("expected empty server name with header-sized limit, got %q", res.ClientHello.ServerName)
|
|
|
|
|
}
|
|
|
|
|
if res.Buffer.Len() != limit {
|
|
|
|
|
t.Fatalf("buffer length mismatch: got %d want %d", res.Buffer.Len(), limit)
|
|
|
|
|
}
|
|
|
|
|
if !bytes.Equal(res.Buffer.Bytes(), rawHello[:limit]) {
|
|
|
|
|
t.Fatalf("buffer content mismatch")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestTLSSnifferRejectsFakeTLSClientRecord(t *testing.T) {
|
|
|
|
|
fake := []byte{
|
|
|
|
|
0x16, 0x03, 0x03, 0x00, 0x04,
|
|
|
|
|
'G', 'E', 'T', ' ',
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
conn := &bytesConn{
|
|
|
|
|
reader: bytes.NewReader(fake),
|
|
|
|
|
local: fixedAddr("127.0.0.1:443"),
|
|
|
|
|
remote: fixedAddr("127.0.0.1:50000"),
|
|
|
|
|
}
|
|
|
|
|
res, err := (TLSSniffer{}).Sniff(conn, 64*1024)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("sniff: %v", err)
|
|
|
|
|
}
|
|
|
|
|
if res.IsTLS {
|
|
|
|
|
t.Fatalf("expected fake tls-looking record to be classified as non-tls")
|
|
|
|
|
}
|
|
|
|
|
if res.ClientHello != nil {
|
|
|
|
|
t.Fatalf("expected no client hello metadata for fake record")
|
|
|
|
|
}
|
|
|
|
|
if !bytes.Equal(res.Buffer.Bytes(), fake) {
|
|
|
|
|
t.Fatalf("buffer mismatch")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
// ---------- cert helpers ----------
|
|
|
|
|
|
|
|
|
|
// genSelfSignedCertPEM creates a throwaway 2048-bit RSA key and a self-signed
// certificate for dnsNames. The certificate has CN "starnet-test", is valid
// from one hour ago until 24 hours from now, and allows server and client auth.
// It returns the certificate as a "CERTIFICATE" PEM block and the key as a
// PKCS#1 "RSA PRIVATE KEY" PEM block, failing the test on any generation error.
func genSelfSignedCertPEM(t *testing.T, dnsNames ...string) (certPEM, keyPEM []byte) {
	t.Helper()

	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		t.Fatalf("GenerateKey: %v", err)
	}

	serial, err := rand.Int(rand.Reader, big.NewInt(1<<62))
	if err != nil {
		t.Fatalf("serial: %v", err)
	}

	template := &x509.Certificate{
		SerialNumber: serial,
		Subject:      pkix.Name{CommonName: "starnet-test"},
		NotBefore:    time.Now().Add(-time.Hour),
		NotAfter:     time.Now().Add(24 * time.Hour),
		KeyUsage:     x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
		DNSNames:     dnsNames,
	}

	// Self-signed: the template is both subject and issuer.
	der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("CreateCertificate: %v", err)
	}

	certBlock := &pem.Block{Type: "CERTIFICATE", Bytes: der}
	keyBlock := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}
	return pem.EncodeToMemory(certBlock), pem.EncodeToMemory(keyBlock)
}
|
|
|
|
|
|
|
|
|
|
func genSelfSignedCert(t *testing.T, dnsNames ...string) tls.Certificate {
|
|
|
|
|
t.Helper()
|
|
|
|
|
certPEM, keyPEM := genSelfSignedCertPEM(t, dnsNames...)
|
|
|
|
|
cert, err := tls.X509KeyPair(certPEM, keyPEM)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("X509KeyPair: %v", err)
|
|
|
|
|
}
|
|
|
|
|
return cert
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func writeTempCertFiles(t *testing.T, dnsNames ...string) (certFile, keyFile string, cleanup func()) {
|
|
|
|
|
t.Helper()
|
|
|
|
|
|
|
|
|
|
certPEM, keyPEM := genSelfSignedCertPEM(t, dnsNames...)
|
|
|
|
|
|
|
|
|
|
cf, err := os.CreateTemp("", "starnet-cert-*.pem")
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("CreateTemp cert: %v", err)
|
|
|
|
|
}
|
|
|
|
|
kf, err := os.CreateTemp("", "starnet-key-*.pem")
|
|
|
|
|
if err != nil {
|
|
|
|
|
_ = cf.Close()
|
|
|
|
|
_ = os.Remove(cf.Name())
|
|
|
|
|
t.Fatalf("CreateTemp key: %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if _, err := cf.Write(certPEM); err != nil {
|
|
|
|
|
t.Fatalf("write cert: %v", err)
|
|
|
|
|
}
|
|
|
|
|
if _, err := kf.Write(keyPEM); err != nil {
|
|
|
|
|
t.Fatalf("write key: %v", err)
|
|
|
|
|
}
|
|
|
|
|
_ = cf.Close()
|
|
|
|
|
_ = kf.Close()
|
|
|
|
|
|
|
|
|
|
return cf.Name(), kf.Name(), func() {
|
|
|
|
|
_ = os.Remove(cf.Name())
|
|
|
|
|
_ = os.Remove(kf.Name())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ---------- server helpers ----------
|
|
|
|
|
|
|
|
|
|
func startEchoServer(t *testing.T, cfg ListenerConfig) (*Listener, string, func()) {
|
|
|
|
|
t.Helper()
|
|
|
|
|
|
|
|
|
|
ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("ListenWithConfig: %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var wg sync.WaitGroup
|
|
|
|
|
stop := make(chan struct{})
|
|
|
|
|
|
|
|
|
|
wg.Add(1)
|
|
|
|
|
go func() {
|
|
|
|
|
defer wg.Done()
|
|
|
|
|
for {
|
|
|
|
|
c, err := ln.Accept()
|
|
|
|
|
if err != nil {
|
|
|
|
|
select {
|
|
|
|
|
case <-stop:
|
|
|
|
|
return
|
|
|
|
|
default:
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
go func(conn net.Conn) {
|
|
|
|
|
defer conn.Close()
|
|
|
|
|
_, _ = io.Copy(conn, conn)
|
|
|
|
|
}(c)
|
|
|
|
|
}
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
cleanup := func() {
|
|
|
|
|
close(stop)
|
|
|
|
|
_ = ln.Close()
|
|
|
|
|
wg.Wait()
|
|
|
|
|
}
|
|
|
|
|
return ln, ln.Addr().String(), cleanup
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ---------- tests ----------
|
|
|
|
|
|
|
|
|
|
func TestListen(t *testing.T) {
|
|
|
|
|
ln, err := Listen("tcp", "127.0.0.1:0")
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("Listen: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer ln.Close()
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
c, err := ln.Accept()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer c.Close()
|
|
|
|
|
_, _ = io.Copy(c, c)
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
c, err := net.Dial("tcp", ln.Addr().String())
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("dial: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer c.Close()
|
|
|
|
|
|
|
|
|
|
msg := []byte("x")
|
|
|
|
|
if _, err := c.Write(msg); err != nil {
|
|
|
|
|
t.Fatalf("write: %v", err)
|
|
|
|
|
}
|
|
|
|
|
buf := make([]byte, 1)
|
|
|
|
|
if _, err := io.ReadFull(c, buf); err != nil {
|
|
|
|
|
t.Fatalf("read: %v", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestListenWithListenConfig(t *testing.T) {
|
|
|
|
|
lc := net.ListenConfig{}
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.AllowNonTLS = true
|
|
|
|
|
|
|
|
|
|
ln, err := ListenWithListenConfig(lc, "tcp", "127.0.0.1:0", cfg)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("ListenWithListenConfig: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer ln.Close()
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
c, err := ln.Accept()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer c.Close()
|
|
|
|
|
_, _ = io.Copy(c, c)
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
c, err := net.Dial("tcp", ln.Addr().String())
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("dial: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer c.Close()
|
|
|
|
|
|
|
|
|
|
msg := []byte("ok")
|
|
|
|
|
if _, err := c.Write(msg); err != nil {
|
|
|
|
|
t.Fatalf("write: %v", err)
|
|
|
|
|
}
|
|
|
|
|
buf := make([]byte, 2)
|
|
|
|
|
if _, err := io.ReadFull(c, buf); err != nil {
|
|
|
|
|
t.Fatalf("read: %v", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestListenerSetConfig(t *testing.T) {
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.AllowNonTLS = true
|
|
|
|
|
|
|
|
|
|
ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("listen: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer ln.Close()
|
|
|
|
|
|
|
|
|
|
cfg2 := cfg
|
|
|
|
|
cfg2.SniffTimeout = time.Second
|
|
|
|
|
ln.SetConfig(cfg2)
|
|
|
|
|
|
|
|
|
|
got := ln.Config()
|
|
|
|
|
if got.SniffTimeout != time.Second {
|
|
|
|
|
t.Fatalf("SetConfig not applied")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestPlainAllowed(t *testing.T) {
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.AllowNonTLS = true
|
|
|
|
|
cfg.BaseTLSConfig = nil
|
|
|
|
|
|
|
|
|
|
_, addr, cleanup := startEchoServer(t, cfg)
|
|
|
|
|
defer cleanup()
|
|
|
|
|
|
|
|
|
|
c, err := net.Dial("tcp", addr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("Dial: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer c.Close()
|
|
|
|
|
|
|
|
|
|
msg := []byte("hello-plain")
|
|
|
|
|
if _, err := c.Write(msg); err != nil {
|
|
|
|
|
t.Fatalf("write: %v", err)
|
|
|
|
|
}
|
|
|
|
|
buf := make([]byte, len(msg))
|
|
|
|
|
if _, err := io.ReadFull(c, buf); err != nil {
|
|
|
|
|
t.Fatalf("read: %v", err)
|
|
|
|
|
}
|
|
|
|
|
if string(buf) != string(msg) {
|
|
|
|
|
t.Fatalf("echo mismatch: got=%q want=%q", string(buf), string(msg))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestPlainRejectedWhenNonTLSDisabled(t *testing.T) {
|
|
|
|
|
cert := genSelfSignedCert(t, "localhost")
|
|
|
|
|
base := TLSDefaults()
|
|
|
|
|
base.Certificates = []tls.Certificate{cert}
|
|
|
|
|
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.AllowNonTLS = false
|
|
|
|
|
cfg.BaseTLSConfig = base
|
|
|
|
|
|
|
|
|
|
_, addr, cleanup := startEchoServer(t, cfg)
|
|
|
|
|
defer cleanup()
|
|
|
|
|
|
|
|
|
|
c, err := net.Dial("tcp", addr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("Dial: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer c.Close()
|
|
|
|
|
|
|
|
|
|
_, _ = c.Write([]byte("plain"))
|
|
|
|
|
_ = c.SetReadDeadline(time.Now().Add(500 * time.Millisecond))
|
|
|
|
|
b := make([]byte, 1)
|
|
|
|
|
_, err = c.Read(b)
|
|
|
|
|
if err == nil {
|
|
|
|
|
t.Fatalf("expected read error due to non-tls rejection")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestTLSHandshakeAndEcho(t *testing.T) {
|
|
|
|
|
cert := genSelfSignedCert(t, "localhost")
|
|
|
|
|
base := TLSDefaults()
|
|
|
|
|
base.Certificates = []tls.Certificate{cert}
|
|
|
|
|
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.AllowNonTLS = false
|
|
|
|
|
cfg.BaseTLSConfig = base
|
|
|
|
|
|
|
|
|
|
_, addr, cleanup := startEchoServer(t, cfg)
|
|
|
|
|
defer cleanup()
|
|
|
|
|
|
|
|
|
|
tc, err := tls.Dial("tcp", addr, &tls.Config{
|
|
|
|
|
InsecureSkipVerify: true,
|
|
|
|
|
ServerName: "localhost",
|
|
|
|
|
MinVersion: tls.VersionTLS12,
|
|
|
|
|
})
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("tls dial: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer tc.Close()
|
|
|
|
|
|
|
|
|
|
msg := []byte("hello-tls")
|
|
|
|
|
if _, err := tc.Write(msg); err != nil {
|
|
|
|
|
t.Fatalf("tls write: %v", err)
|
|
|
|
|
}
|
|
|
|
|
buf := make([]byte, len(msg))
|
|
|
|
|
if _, err := io.ReadFull(tc, buf); err != nil {
|
|
|
|
|
t.Fatalf("tls read: %v", err)
|
|
|
|
|
}
|
|
|
|
|
if string(buf) != string(msg) {
|
|
|
|
|
t.Fatalf("tls echo mismatch: got=%q want=%q", string(buf), string(msg))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestDynamicConfigBySNI(t *testing.T) {
|
|
|
|
|
certA := genSelfSignedCert(t, "a.local")
|
|
|
|
|
certB := genSelfSignedCert(t, "b.local")
|
|
|
|
|
|
|
|
|
|
base := TLSDefaults()
|
|
|
|
|
base.Certificates = []tls.Certificate{certA}
|
|
|
|
|
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.AllowNonTLS = false
|
|
|
|
|
cfg.BaseTLSConfig = base
|
|
|
|
|
cfg.GetConfigForClient = func(host string) (*tls.Config, error) {
|
|
|
|
|
if host == "b.local" {
|
|
|
|
|
b := TLSDefaults()
|
|
|
|
|
b.Certificates = []tls.Certificate{certB}
|
|
|
|
|
return b, nil
|
|
|
|
|
}
|
|
|
|
|
return nil, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, addr, cleanup := startEchoServer(t, cfg)
|
|
|
|
|
defer cleanup()
|
|
|
|
|
|
|
|
|
|
tc, err := tls.Dial("tcp", addr, &tls.Config{
|
|
|
|
|
InsecureSkipVerify: true,
|
|
|
|
|
ServerName: "b.local",
|
|
|
|
|
MinVersion: tls.VersionTLS12,
|
|
|
|
|
})
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("tls dial: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer tc.Close()
|
|
|
|
|
|
|
|
|
|
if !tc.ConnectionState().HandshakeComplete {
|
|
|
|
|
t.Fatalf("handshake not complete")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestGetConfigForClientError(t *testing.T) {
|
|
|
|
|
cert := genSelfSignedCert(t, "localhost")
|
|
|
|
|
base := TLSDefaults()
|
|
|
|
|
base.Certificates = []tls.Certificate{cert}
|
|
|
|
|
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.AllowNonTLS = false
|
|
|
|
|
cfg.BaseTLSConfig = base
|
|
|
|
|
cfg.GetConfigForClient = func(host string) (*tls.Config, error) {
|
|
|
|
|
return nil, errors.New("boom")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, addr, cleanup := startEchoServer(t, cfg)
|
|
|
|
|
defer cleanup()
|
|
|
|
|
|
|
|
|
|
_, err := tls.Dial("tcp", addr, &tls.Config{
|
|
|
|
|
InsecureSkipVerify: true,
|
|
|
|
|
ServerName: "localhost",
|
|
|
|
|
})
|
|
|
|
|
if err == nil {
|
|
|
|
|
t.Fatalf("expected tls dial failure due to selector error")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2026-03-27 12:05:23 +08:00
|
|
|
func TestGetConfigForClientHelloReceivesMetadata(t *testing.T) {
|
|
|
|
|
certA := genSelfSignedCert(t, "a.local")
|
|
|
|
|
certB := genSelfSignedCert(t, "b.local")
|
|
|
|
|
|
|
|
|
|
base := TLSDefaults()
|
|
|
|
|
base.Certificates = []tls.Certificate{certA}
|
|
|
|
|
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.AllowNonTLS = false
|
|
|
|
|
cfg.BaseTLSConfig = base
|
|
|
|
|
|
|
|
|
|
metaCh := make(chan *ClientHelloMeta, 1)
|
|
|
|
|
cfg.GetConfigForClientHello = func(hello *ClientHelloMeta) (*tls.Config, error) {
|
|
|
|
|
metaCh <- hello.Clone()
|
|
|
|
|
if hello != nil && hello.ServerName == "b.local" {
|
|
|
|
|
b := TLSDefaults()
|
|
|
|
|
b.Certificates = []tls.Certificate{certB}
|
|
|
|
|
return b, nil
|
|
|
|
|
}
|
|
|
|
|
return nil, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, addr, cleanup := startEchoServer(t, cfg)
|
|
|
|
|
defer cleanup()
|
|
|
|
|
|
|
|
|
|
tc, err := tls.Dial("tcp", addr, &tls.Config{
|
|
|
|
|
InsecureSkipVerify: true,
|
|
|
|
|
ServerName: "b.local",
|
|
|
|
|
MinVersion: tls.VersionTLS12,
|
|
|
|
|
})
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("tls dial: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer tc.Close()
|
|
|
|
|
|
|
|
|
|
select {
|
|
|
|
|
case hello := <-metaCh:
|
|
|
|
|
if hello == nil {
|
|
|
|
|
t.Fatalf("expected metadata")
|
|
|
|
|
}
|
|
|
|
|
if hello.ServerName != "b.local" {
|
|
|
|
|
t.Fatalf("unexpected server name: %q", hello.ServerName)
|
|
|
|
|
}
|
|
|
|
|
if hello.LocalAddr == nil {
|
|
|
|
|
t.Fatalf("expected LocalAddr")
|
|
|
|
|
}
|
|
|
|
|
if hello.RemoteAddr == nil {
|
|
|
|
|
t.Fatalf("expected RemoteAddr")
|
|
|
|
|
}
|
|
|
|
|
if len(hello.CipherSuites) == 0 {
|
|
|
|
|
t.Fatalf("expected cipher suites in metadata")
|
|
|
|
|
}
|
|
|
|
|
case <-time.After(2 * time.Second):
|
|
|
|
|
t.Fatalf("timeout waiting client hello metadata")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestGetConfigForClientHelloWithoutSNIStillGetsLocalAddr(t *testing.T) {
|
|
|
|
|
cert := genSelfSignedCert(t, "localhost")
|
|
|
|
|
base := TLSDefaults()
|
|
|
|
|
base.Certificates = []tls.Certificate{cert}
|
|
|
|
|
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.AllowNonTLS = false
|
|
|
|
|
cfg.BaseTLSConfig = base
|
|
|
|
|
|
|
|
|
|
metaCh := make(chan *ClientHelloMeta, 1)
|
|
|
|
|
cfg.GetConfigForClientHello = func(hello *ClientHelloMeta) (*tls.Config, error) {
|
|
|
|
|
metaCh <- hello.Clone()
|
|
|
|
|
return nil, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
_, addr, cleanup := startEchoServer(t, cfg)
|
|
|
|
|
defer cleanup()
|
|
|
|
|
|
|
|
|
|
raw, err := net.Dial("tcp", addr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("dial: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer raw.Close()
|
|
|
|
|
|
|
|
|
|
tc := tls.Client(raw, &tls.Config{
|
|
|
|
|
InsecureSkipVerify: true,
|
|
|
|
|
MinVersion: tls.VersionTLS12,
|
|
|
|
|
})
|
|
|
|
|
if err := tc.Handshake(); err != nil {
|
|
|
|
|
t.Fatalf("tls handshake: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer tc.Close()
|
|
|
|
|
|
|
|
|
|
select {
|
|
|
|
|
case hello := <-metaCh:
|
|
|
|
|
if hello == nil {
|
|
|
|
|
t.Fatalf("expected metadata")
|
|
|
|
|
}
|
|
|
|
|
if hello.ServerName != "" {
|
|
|
|
|
t.Fatalf("expected empty SNI, got %q", hello.ServerName)
|
|
|
|
|
}
|
|
|
|
|
if hello.LocalAddr == nil {
|
|
|
|
|
t.Fatalf("expected LocalAddr")
|
|
|
|
|
}
|
|
|
|
|
if hello.RemoteAddr == nil {
|
|
|
|
|
t.Fatalf("expected RemoteAddr")
|
|
|
|
|
}
|
|
|
|
|
case <-time.After(2 * time.Second):
|
|
|
|
|
t.Fatalf("timeout waiting client hello metadata")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
func TestAcceptContextCancel(t *testing.T) {
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.AllowNonTLS = true
|
|
|
|
|
|
|
|
|
|
ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("listen: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer ln.Close()
|
|
|
|
|
|
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
|
|
|
defer cancel()
|
|
|
|
|
|
|
|
|
|
_, err = ln.AcceptContext(ctx)
|
|
|
|
|
if err == nil {
|
|
|
|
|
t.Fatalf("expected context timeout/cancel")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestListenerStats(t *testing.T) {
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.AllowNonTLS = true
|
|
|
|
|
|
|
|
|
|
ln, addr, cleanup := startEchoServer(t, cfg)
|
|
|
|
|
defer cleanup()
|
|
|
|
|
|
|
|
|
|
c, err := net.Dial("tcp", addr)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("dial: %v", err)
|
|
|
|
|
}
|
|
|
|
|
_, _ = c.Write([]byte("x"))
|
|
|
|
|
_ = c.Close()
|
|
|
|
|
|
|
|
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
|
|
|
|
|
|
s := ln.Stats()
|
|
|
|
|
if s.Accepted == 0 {
|
|
|
|
|
t.Fatalf("expected accepted > 0")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2026-03-27 12:05:23 +08:00
|
|
|
func TestComposeServerTLSConfigInheritsBaseDefaults(t *testing.T) {
|
|
|
|
|
base := TLSDefaults()
|
|
|
|
|
base.MinVersion = tls.VersionTLS13
|
|
|
|
|
base.NextProtos = []string{"h2", "http/1.1"}
|
|
|
|
|
base.ClientAuth = tls.RequireAndVerifyClientCert
|
|
|
|
|
base.DynamicRecordSizingDisabled = true
|
|
|
|
|
|
|
|
|
|
selected := TLSDefaults()
|
|
|
|
|
selected.Certificates = []tls.Certificate{genSelfSignedCert(t, "b.local")}
|
|
|
|
|
|
|
|
|
|
composed := composeServerTLSConfig(base, selected)
|
|
|
|
|
if composed == nil {
|
|
|
|
|
t.Fatalf("expected composed config")
|
|
|
|
|
}
|
|
|
|
|
if composed.MinVersion != tls.VersionTLS13 {
|
|
|
|
|
t.Fatalf("expected min version from base, got %v", composed.MinVersion)
|
|
|
|
|
}
|
|
|
|
|
if len(composed.NextProtos) != 2 {
|
|
|
|
|
t.Fatalf("expected next protos from base, got %v", composed.NextProtos)
|
|
|
|
|
}
|
|
|
|
|
if composed.ClientAuth != tls.RequireAndVerifyClientCert {
|
|
|
|
|
t.Fatalf("expected client auth from base, got %v", composed.ClientAuth)
|
|
|
|
|
}
|
|
|
|
|
if !composed.DynamicRecordSizingDisabled {
|
|
|
|
|
t.Fatalf("expected dynamic record sizing disabled to inherit")
|
|
|
|
|
}
|
|
|
|
|
if len(composed.Certificates) != 1 {
|
|
|
|
|
t.Fatalf("expected selected certificates to be preserved")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestComposeServerTLSConfigKeepsSelectedOverrides(t *testing.T) {
|
|
|
|
|
base := TLSDefaults()
|
|
|
|
|
base.MinVersion = tls.VersionTLS12
|
|
|
|
|
base.NextProtos = []string{"http/1.1"}
|
|
|
|
|
|
|
|
|
|
selected := TLSDefaults()
|
|
|
|
|
selected.MinVersion = tls.VersionTLS13
|
|
|
|
|
selected.NextProtos = []string{"h2"}
|
|
|
|
|
|
|
|
|
|
composed := composeServerTLSConfig(base, selected)
|
|
|
|
|
if composed.MinVersion != tls.VersionTLS13 {
|
|
|
|
|
t.Fatalf("expected selected min version, got %v", composed.MinVersion)
|
|
|
|
|
}
|
|
|
|
|
if len(composed.NextProtos) != 1 || composed.NextProtos[0] != "h2" {
|
|
|
|
|
t.Fatalf("expected selected next protos, got %v", composed.NextProtos)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestConnInitFailureStats(t *testing.T) {
|
|
|
|
|
t.Run("sniff failure", func(t *testing.T) {
|
|
|
|
|
client, server := net.Pipe()
|
|
|
|
|
defer client.Close()
|
|
|
|
|
defer server.Close()
|
|
|
|
|
|
|
|
|
|
stats := &Stats{}
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.BaseTLSConfig = TLSDefaults()
|
|
|
|
|
conn := newConn(server, cfg, stats)
|
|
|
|
|
conn.sniffer = staticSniffer{err: io.ErrUnexpectedEOF}
|
|
|
|
|
|
|
|
|
|
buf := make([]byte, 1)
|
|
|
|
|
_, err := conn.Read(buf)
|
|
|
|
|
if err == nil || !errors.Is(err, ErrTLSSniffFailed) {
|
|
|
|
|
t.Fatalf("expected tls sniff failure, got %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
snap := stats.Snapshot()
|
|
|
|
|
if snap.InitFailures != 1 || snap.SniffFailures != 1 {
|
|
|
|
|
t.Fatalf("unexpected sniff failure stats: %+v", snap)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
t.Run("tls config selection failure", func(t *testing.T) {
|
|
|
|
|
client, server := net.Pipe()
|
|
|
|
|
defer client.Close()
|
|
|
|
|
defer server.Close()
|
|
|
|
|
|
|
|
|
|
stats := &Stats{}
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.GetConfigForClientHello = func(*ClientHelloMeta) (*tls.Config, error) {
|
|
|
|
|
return nil, errors.New("boom")
|
|
|
|
|
}
|
|
|
|
|
conn := newConn(server, cfg, stats)
|
|
|
|
|
conn.sniffer = staticSniffer{
|
|
|
|
|
result: SniffResult{
|
|
|
|
|
IsTLS: true,
|
|
|
|
|
ClientHello: &ClientHelloMeta{},
|
|
|
|
|
Buffer: bytes.NewBuffer(nil),
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
buf := make([]byte, 1)
|
|
|
|
|
_, err := conn.Read(buf)
|
|
|
|
|
if err == nil || !errors.Is(err, ErrTLSConfigSelectionFailed) {
|
|
|
|
|
t.Fatalf("expected tls config selection failure, got %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
snap := stats.Snapshot()
|
|
|
|
|
if snap.InitFailures != 1 || snap.TLSConfigFailures != 1 {
|
|
|
|
|
t.Fatalf("unexpected tls config failure stats: %+v", snap)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
t.Run("plain rejected", func(t *testing.T) {
|
|
|
|
|
client, server := net.Pipe()
|
|
|
|
|
defer client.Close()
|
|
|
|
|
defer server.Close()
|
|
|
|
|
|
|
|
|
|
stats := &Stats{}
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.BaseTLSConfig = TLSDefaults()
|
|
|
|
|
cfg.AllowNonTLS = false
|
|
|
|
|
conn := newConn(server, cfg, stats)
|
|
|
|
|
conn.sniffer = staticSniffer{
|
|
|
|
|
result: SniffResult{
|
|
|
|
|
IsTLS: false,
|
|
|
|
|
Buffer: bytes.NewBuffer(nil),
|
|
|
|
|
},
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
buf := make([]byte, 1)
|
|
|
|
|
_, err := conn.Read(buf)
|
|
|
|
|
if !errors.Is(err, ErrNonTLSNotAllowed) {
|
|
|
|
|
t.Fatalf("expected plain rejection, got %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
snap := stats.Snapshot()
|
|
|
|
|
if snap.InitFailures != 1 || snap.PlainRejected != 1 {
|
|
|
|
|
t.Fatalf("unexpected plain rejection stats: %+v", snap)
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
2026-03-08 20:19:40 +08:00
|
|
|
func TestDialAndDialWithConfig(t *testing.T) {
|
|
|
|
|
nl, err := net.Listen("tcp", "127.0.0.1:0")
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("listen: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer nl.Close()
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
c, err := nl.Accept()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer c.Close()
|
|
|
|
|
_, _ = io.Copy(c, c)
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
c1, err := Dial("tcp", nl.Addr().String())
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("Dial: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer c1.Close()
|
|
|
|
|
|
|
|
|
|
msg := []byte("abc")
|
|
|
|
|
if _, err := c1.Write(msg); err != nil {
|
|
|
|
|
t.Fatalf("c1 write: %v", err)
|
|
|
|
|
}
|
|
|
|
|
got := make([]byte, 3)
|
|
|
|
|
if _, err := io.ReadFull(c1, got); err != nil {
|
|
|
|
|
t.Fatalf("c1 read: %v", err)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
c2, err := DialWithConfig("tcp", nl.Addr().String(), DialConfig{Timeout: time.Second})
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("DialWithConfig: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer c2.Close()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestListenTLS_FileAPI(t *testing.T) {
|
|
|
|
|
certFile, keyFile, cleanupFiles := writeTempCertFiles(t, "localhost")
|
|
|
|
|
defer cleanupFiles()
|
|
|
|
|
|
|
|
|
|
ln, err := ListenTLS("tcp", "127.0.0.1:0", certFile, keyFile, false)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("ListenTLS: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer ln.Close()
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
c, err := ln.Accept()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer c.Close()
|
|
|
|
|
_, _ = io.Copy(c, c)
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
tc, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
|
|
|
|
|
InsecureSkipVerify: true,
|
|
|
|
|
ServerName: "localhost",
|
|
|
|
|
})
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("tls dial: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer tc.Close()
|
|
|
|
|
|
|
|
|
|
msg := []byte("hi")
|
|
|
|
|
if _, err := tc.Write(msg); err != nil {
|
|
|
|
|
t.Fatalf("tls write: %v", err)
|
|
|
|
|
}
|
|
|
|
|
out := make([]byte, 2)
|
|
|
|
|
if _, err := io.ReadFull(tc, out); err != nil {
|
|
|
|
|
t.Fatalf("tls read: %v", err)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestDialTLSWithConfig(t *testing.T) {
|
|
|
|
|
cert := genSelfSignedCert(t, "localhost")
|
|
|
|
|
base := TLSDefaults()
|
|
|
|
|
base.Certificates = []tls.Certificate{cert}
|
|
|
|
|
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.BaseTLSConfig = base
|
|
|
|
|
cfg.AllowNonTLS = false
|
|
|
|
|
|
|
|
|
|
ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("listen: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer ln.Close()
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
c, err := ln.Accept()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer c.Close()
|
|
|
|
|
_, _ = io.Copy(c, c)
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
clientCfg := &tls.Config{
|
|
|
|
|
InsecureSkipVerify: true,
|
|
|
|
|
ServerName: "localhost",
|
|
|
|
|
}
|
|
|
|
|
c, err := DialTLSWithConfig("tcp", ln.Addr().String(), clientCfg, time.Second)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("DialTLSWithConfig: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer c.Close()
|
|
|
|
|
|
|
|
|
|
if !c.IsTLS() {
|
|
|
|
|
t.Fatalf("expected IsTLS true")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestDialTLS_FileAPI(t *testing.T) {
|
|
|
|
|
cert := genSelfSignedCert(t, "localhost")
|
|
|
|
|
base := TLSDefaults()
|
|
|
|
|
base.Certificates = []tls.Certificate{cert}
|
|
|
|
|
cfg := DefaultListenerConfig()
|
|
|
|
|
cfg.BaseTLSConfig = base
|
|
|
|
|
cfg.AllowNonTLS = false
|
|
|
|
|
|
|
|
|
|
ln, err := ListenWithConfig("tcp", "127.0.0.1:0", cfg)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("listen: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer ln.Close()
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
c, err := ln.Accept()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
defer c.Close()
|
|
|
|
|
_, _ = io.Copy(c, c)
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
clientCertFile, clientKeyFile, cleanupFiles := writeTempCertFiles(t, "localhost")
|
|
|
|
|
defer cleanupFiles()
|
|
|
|
|
|
|
|
|
|
c, err := DialTLS("tcp", ln.Addr().String(), clientCertFile, clientKeyFile)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("DialTLS: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer c.Close()
|
|
|
|
|
|
|
|
|
|
if !c.IsTLS() {
|
|
|
|
|
t.Fatalf("expected IsTLS true")
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func TestConnIsTLS_PlainAndTLS(t *testing.T) {
|
|
|
|
|
// ---- plain case ----
|
|
|
|
|
plainCfg := DefaultListenerConfig()
|
|
|
|
|
plainCfg.AllowNonTLS = true
|
|
|
|
|
|
|
|
|
|
ln1, err := ListenWithConfig("tcp", "127.0.0.1:0", plainCfg)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("listen1: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer ln1.Close()
|
|
|
|
|
|
|
|
|
|
plainDone := make(chan *Conn, 1)
|
|
|
|
|
plainErr := make(chan error, 1)
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
nc, err := ln1.Accept()
|
|
|
|
|
if err != nil {
|
|
|
|
|
plainErr <- err
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
sc, ok := nc.(*Conn)
|
|
|
|
|
if !ok {
|
|
|
|
|
_ = nc.Close()
|
|
|
|
|
plainErr <- errors.New("accepted conn is not *Conn")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
plainDone <- sc
|
|
|
|
|
|
|
|
|
|
// block until client sends one byte, then close
|
|
|
|
|
buf := make([]byte, 1)
|
|
|
|
|
_, _ = sc.Read(buf)
|
|
|
|
|
_ = sc.Close()
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
c1, err := net.Dial("tcp", ln1.Addr().String())
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("dial1: %v", err)
|
|
|
|
|
}
|
|
|
|
|
if _, err := c1.Write([]byte("p")); err != nil {
|
|
|
|
|
_ = c1.Close()
|
|
|
|
|
t.Fatalf("plain client write: %v", err)
|
|
|
|
|
}
|
|
|
|
|
_ = c1.Close()
|
|
|
|
|
|
|
|
|
|
select {
|
|
|
|
|
case err := <-plainErr:
|
|
|
|
|
t.Fatalf("plain server error: %v", err)
|
|
|
|
|
case sc1 := <-plainDone:
|
|
|
|
|
if sc1.IsTLS() {
|
|
|
|
|
t.Fatalf("plain conn should not be TLS")
|
|
|
|
|
}
|
|
|
|
|
case <-time.After(2 * time.Second):
|
|
|
|
|
t.Fatalf("timeout waiting plain side")
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ---- tls case ----
|
|
|
|
|
cert := genSelfSignedCert(t, "localhost")
|
|
|
|
|
tlsBase := TLSDefaults()
|
|
|
|
|
tlsBase.Certificates = []tls.Certificate{cert}
|
|
|
|
|
|
|
|
|
|
tlsCfg := DefaultListenerConfig()
|
|
|
|
|
tlsCfg.BaseTLSConfig = tlsBase
|
|
|
|
|
tlsCfg.AllowNonTLS = false
|
|
|
|
|
|
|
|
|
|
ln2, err := ListenWithConfig("tcp", "127.0.0.1:0", tlsCfg)
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("listen2: %v", err)
|
|
|
|
|
}
|
|
|
|
|
defer ln2.Close()
|
|
|
|
|
|
|
|
|
|
tlsDone := make(chan *Conn, 1)
|
|
|
|
|
tlsErr := make(chan error, 1)
|
|
|
|
|
|
|
|
|
|
go func() {
|
|
|
|
|
nc, err := ln2.Accept()
|
|
|
|
|
if err != nil {
|
|
|
|
|
tlsErr <- err
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
sc, ok := nc.(*Conn)
|
|
|
|
|
if !ok {
|
|
|
|
|
_ = nc.Close()
|
|
|
|
|
tlsErr <- errors.New("accepted conn is not *Conn")
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
tlsDone <- sc
|
|
|
|
|
|
|
|
|
|
// key point: wait for real data to ensure TLS handshake/path is executed
|
|
|
|
|
buf := make([]byte, 1)
|
|
|
|
|
_, _ = sc.Read(buf)
|
|
|
|
|
_ = sc.Close()
|
|
|
|
|
}()
|
|
|
|
|
|
|
|
|
|
d := &net.Dialer{Timeout: 2 * time.Second}
|
|
|
|
|
tc, err := tls.DialWithDialer(d, "tcp", ln2.Addr().String(), &tls.Config{
|
|
|
|
|
InsecureSkipVerify: true, // test only
|
|
|
|
|
ServerName: "localhost",
|
|
|
|
|
MinVersion: tls.VersionTLS12,
|
|
|
|
|
})
|
|
|
|
|
if err != nil {
|
|
|
|
|
t.Fatalf("tls dial: %v", err)
|
|
|
|
|
}
|
|
|
|
|
if _, err := tc.Write([]byte("t")); err != nil {
|
|
|
|
|
_ = tc.Close()
|
|
|
|
|
t.Fatalf("tls client write: %v", err)
|
|
|
|
|
}
|
|
|
|
|
_ = tc.Close()
|
|
|
|
|
|
|
|
|
|
select {
|
|
|
|
|
case err := <-tlsErr:
|
|
|
|
|
t.Fatalf("tls server error: %v", err)
|
|
|
|
|
case sc2 := <-tlsDone:
|
|
|
|
|
if !sc2.IsTLS() {
|
|
|
|
|
t.Fatalf("tls conn should be TLS")
|
|
|
|
|
}
|
|
|
|
|
case <-time.After(3 * time.Second):
|
|
|
|
|
t.Fatalf("timeout waiting tls side")
|
|
|
|
|
}
|
|
|
|
|
}
|