package notify import ( "context" "errors" "io" "math/rand" "net" "os" "strconv" "sync" "sync/atomic" "testing" "time" ) const ( benchmarkTCPProxyEnableEnv = "NOTIFY_BENCH_TCP_PROXY_ENABLE" benchmarkTCPProxyListenAddrEnv = "NOTIFY_BENCH_TCP_PROXY_LISTEN_ADDR" benchmarkTCPProxyRateMbitEnv = "NOTIFY_BENCH_TCP_PROXY_RATE_MBIT" benchmarkTCPProxyDelayMsEnv = "NOTIFY_BENCH_TCP_PROXY_DELAY_MS" benchmarkTCPProxyLossPctEnv = "NOTIFY_BENCH_TCP_PROXY_LOSS_PCT" benchmarkTCPProxyLossPenaltyMsEnv = "NOTIFY_BENCH_TCP_PROXY_LOSS_PENALTY_MS" benchmarkTCPProxySeedEnv = "NOTIFY_BENCH_TCP_PROXY_SEED" benchmarkTCPProxyDefaultListenAddr = "127.0.0.1:0" ) const benchmarkTCPProxyDefaultLossPenalty = 120 * time.Millisecond type benchmarkTCPProxyConfig struct { Enabled bool ListenAddr string RateBytesPS float64 Delay time.Duration LossPct float64 LossPenalty time.Duration Seed int64 } type benchmarkTCPProxy struct { listener net.Listener target string cfg benchmarkTCPProxyConfig closed atomic.Bool wg sync.WaitGroup } func benchmarkTCPDialAddr(tb testing.TB, targetAddr string) string { tb.Helper() if targetAddr == "" { return targetAddr } cfg := benchmarkTCPProxyConfigFromEnv(tb) if !cfg.Enabled { return targetAddr } proxy, err := startBenchmarkTCPProxy(targetAddr, cfg) if err != nil { tb.Fatalf("start benchmark tcp proxy failed: %v", err) } tb.Cleanup(proxy.stop) if proxy.listener == nil || proxy.listener.Addr() == nil { tb.Fatalf("benchmark tcp proxy listener is nil") } return proxy.listener.Addr().String() } func benchmarkTCPProxyConfigFromEnv(tb testing.TB) benchmarkTCPProxyConfig { tb.Helper() enabled, ok, err := parseBoolEnv(benchmarkTCPProxyEnableEnv) if err != nil { tb.Fatalf("invalid %s: %v", benchmarkTCPProxyEnableEnv, err) } if !ok || !enabled { return benchmarkTCPProxyConfig{} } cfg := benchmarkTCPProxyConfig{ Enabled: true, ListenAddr: benchmarkTCPProxyDefaultListenAddr, Seed: 1, } if listen := os.Getenv(benchmarkTCPProxyListenAddrEnv); listen != "" { cfg.ListenAddr = listen } rateMbit, ok, err := parseFloatEnv(benchmarkTCPProxyRateMbitEnv) if err != nil { tb.Fatalf("invalid %s: %v", benchmarkTCPProxyRateMbitEnv, err) } if ok && rateMbit > 0 { cfg.RateBytesPS = rateMbit * 1000 * 1000 / 8 } delayMs, ok, err := parseIntEnv(benchmarkTCPProxyDelayMsEnv) if err != nil { tb.Fatalf("invalid %s: %v", benchmarkTCPProxyDelayMsEnv, err) } if ok && delayMs > 0 { cfg.Delay = time.Duration(delayMs) * time.Millisecond } lossPct, ok, err := parseFloatEnv(benchmarkTCPProxyLossPctEnv) if err != nil { tb.Fatalf("invalid %s: %v", benchmarkTCPProxyLossPctEnv, err) } if ok { if lossPct < 0 || lossPct > 100 { tb.Fatalf("%s must be in [0, 100], got %.4f", benchmarkTCPProxyLossPctEnv, lossPct) } cfg.LossPct = lossPct } lossPenaltyMs, ok, err := parseIntEnv(benchmarkTCPProxyLossPenaltyMsEnv) if err != nil { tb.Fatalf("invalid %s: %v", benchmarkTCPProxyLossPenaltyMsEnv, err) } if ok && lossPenaltyMs > 0 { cfg.LossPenalty = time.Duration(lossPenaltyMs) * time.Millisecond } if cfg.LossPct > 0 && cfg.LossPenalty <= 0 { cfg.LossPenalty = benchmarkTCPProxyDefaultLossPenalty } seed, ok, err := parseInt64Env(benchmarkTCPProxySeedEnv) if err != nil { tb.Fatalf("invalid %s: %v", benchmarkTCPProxySeedEnv, err) } if ok { cfg.Seed = seed } return cfg } func parseBoolEnv(name string) (value bool, ok bool, err error) { raw := os.Getenv(name) if raw == "" { return false, false, nil } value, err = strconv.ParseBool(raw) return value, true, err } func parseFloatEnv(name string) (value float64, ok bool, err error) { raw := os.Getenv(name) if raw == "" { return 0, false, nil } value, err = strconv.ParseFloat(raw, 64) return value, true, err } func parseIntEnv(name string) (value int, ok bool, err error) { raw := os.Getenv(name) if raw == "" { return 0, false, nil } value, err = strconv.Atoi(raw) return value, true, err } func parseInt64Env(name string) (value int64, ok bool, err error) { raw := os.Getenv(name) if raw == "" { return 0, false, nil } value, err = strconv.ParseInt(raw, 10, 64) return value, true, err } func startBenchmarkTCPProxy(targetAddr string, cfg benchmarkTCPProxyConfig) (*benchmarkTCPProxy, error) { if targetAddr == "" { return nil, errors.New("benchmark tcp proxy target addr is empty") } listenAddr := cfg.ListenAddr if listenAddr == "" { listenAddr = benchmarkTCPProxyDefaultListenAddr } listener, err := net.Listen("tcp", listenAddr) if err != nil { return nil, err } proxy := &benchmarkTCPProxy{ listener: listener, target: targetAddr, cfg: cfg, } proxy.wg.Add(1) go proxy.runAcceptLoop() return proxy, nil } func (p *benchmarkTCPProxy) runAcceptLoop() { defer p.wg.Done() for { conn, err := p.listener.Accept() if err != nil { if p.closed.Load() || errors.Is(err, net.ErrClosed) { return } continue } p.wg.Add(1) go func(clientConn net.Conn) { defer p.wg.Done() p.handleConn(clientConn) }(conn) } } func (p *benchmarkTCPProxy) handleConn(clientConn net.Conn) { if p == nil || clientConn == nil { return } targetConn, err := net.Dial("tcp", p.target) if err != nil { _ = clientConn.Close() return } if tcpConn, ok := clientConn.(*net.TCPConn); ok { _ = tcpConn.SetNoDelay(true) } if tcpConn, ok := targetConn.(*net.TCPConn); ok { _ = tcpConn.SetNoDelay(true) } ctx, cancel := context.WithCancel(context.Background()) defer cancel() errCh := make(chan error, 2) leftRng := rand.New(rand.NewSource(p.cfg.Seed + 101)) rightRng := rand.New(rand.NewSource(p.cfg.Seed + 202)) go func() { errCh <- benchmarkRelayStream(ctx, targetConn, clientConn, p.cfg, leftRng) }() go func() { errCh <- benchmarkRelayStream(ctx, clientConn, targetConn, p.cfg, rightRng) }() _ = <-errCh cancel() _ = clientConn.Close() _ = targetConn.Close() <-errCh } func (p *benchmarkTCPProxy) stop() { if p == nil { return } if p.closed.CompareAndSwap(false, true) { if p.listener != nil { _ = p.listener.Close() } } p.wg.Wait() } func benchmarkRelayStream(ctx context.Context, dst net.Conn, src net.Conn, cfg benchmarkTCPProxyConfig, rng *rand.Rand) error { if dst == nil || src == nil { return net.ErrClosed } chunkCh := make(chan benchmarkRelayChunk, 128) readerErrCh := make(chan error, 1) go benchmarkRelayReader(ctx, src, chunkCh, readerErrCh) scheduler := benchmarkRelayScheduler{rateBytesPS: cfg.RateBytesPS} for chunk := range chunkCh { sendAt := scheduler.schedule(chunk.readAt, len(chunk.payload), cfg.Delay) if cfg.LossPct > 0 && cfg.LossPenalty > 0 && rng != nil && rng.Float64()*100 < cfg.LossPct { sendAt = sendAt.Add(cfg.LossPenalty) } if waitErr := benchmarkSleepUntilContext(ctx, sendAt); waitErr != nil { return waitErr } if writeErr := writeAllToConn(dst, chunk.payload); writeErr != nil { return writeErr } } readerErr := <-readerErrCh if errors.Is(readerErr, io.EOF) { return nil } return readerErr } type benchmarkRelayChunk struct { readAt time.Time payload []byte } func benchmarkRelayReader(ctx context.Context, src net.Conn, chunkCh chan<- benchmarkRelayChunk, errCh chan<- error) { defer close(chunkCh) if src == nil { errCh <- net.ErrClosed return } buf := make([]byte, 64*1024) for { n, err := src.Read(buf) if n > 0 { chunk := benchmarkRelayChunk{ readAt: time.Now(), payload: append([]byte(nil), buf[:n]...), } select { case <-ctx.Done(): errCh <- ctx.Err() return case chunkCh <- chunk: } } if err != nil { errCh <- err return } } } type benchmarkRelayScheduler struct { rateBytesPS float64 nextAt time.Time } func (s *benchmarkRelayScheduler) schedule(readAt time.Time, bytes int, delay time.Duration) time.Time { if s == nil { return readAt } sendAt := readAt if delay > 0 { sendAt = sendAt.Add(delay) } if !s.nextAt.IsZero() && s.nextAt.After(sendAt) { sendAt = s.nextAt } if s.rateBytesPS > 0 && bytes > 0 { duration := time.Duration(float64(bytes) / s.rateBytesPS * float64(time.Second)) if duration <= 0 { duration = time.Nanosecond } s.nextAt = sendAt.Add(duration) } else { s.nextAt = sendAt } return sendAt } func benchmarkSleepUntilContext(ctx context.Context, at time.Time) error { wait := time.Until(at) if wait <= 0 { return nil } return benchmarkSleepContext(ctx, wait) } func benchmarkSleepContext(ctx context.Context, d time.Duration) error { if d <= 0 { return nil } timer := time.NewTimer(d) defer timer.Stop() select { case <-ctx.Done(): return ctx.Err() case <-timer.C: return nil } } func writeAllToConn(conn net.Conn, payload []byte) error { for len(payload) > 0 { n, err := conn.Write(payload) if n > 0 { payload = payload[n:] } if err != nil { return err } if n == 0 { return io.ErrNoProgress } } return nil }