notify/benchmark_tcp_proxy_test.go

379 lines
8.8 KiB
Go
Raw Normal View History

package notify
import (
"context"
"errors"
"io"
"math/rand"
"net"
"os"
"strconv"
"sync"
"sync/atomic"
"testing"
"time"
)
// Environment variable names that configure the optional benchmark TCP proxy,
// plus the listen address used when none is supplied.
const (
// benchmarkTCPProxyEnableEnv switches the proxy on; its value is parsed as a
// boolean and an unset/empty value leaves the proxy disabled.
benchmarkTCPProxyEnableEnv = "NOTIFY_BENCH_TCP_PROXY_ENABLE"
// benchmarkTCPProxyListenAddrEnv overrides the address the proxy listens on.
benchmarkTCPProxyListenAddrEnv = "NOTIFY_BENCH_TCP_PROXY_LISTEN_ADDR"
// benchmarkTCPProxyRateMbitEnv is the rate limit in megabits per second
// (converted to bytes per second as value*1000*1000/8); non-positive values
// leave the rate unlimited.
benchmarkTCPProxyRateMbitEnv = "NOTIFY_BENCH_TCP_PROXY_RATE_MBIT"
// benchmarkTCPProxyDelayMsEnv is the delay added to every relayed chunk, in
// whole milliseconds; non-positive values are ignored.
benchmarkTCPProxyDelayMsEnv = "NOTIFY_BENCH_TCP_PROXY_DELAY_MS"
// benchmarkTCPProxyLossPctEnv is the simulated loss percentage, validated to
// lie within [0, 100].
benchmarkTCPProxyLossPctEnv = "NOTIFY_BENCH_TCP_PROXY_LOSS_PCT"
// benchmarkTCPProxyLossPenaltyMsEnv is the extra delay, in whole milliseconds,
// applied to a chunk selected as "lost"; non-positive values are ignored.
benchmarkTCPProxyLossPenaltyMsEnv = "NOTIFY_BENCH_TCP_PROXY_LOSS_PENALTY_MS"
// benchmarkTCPProxySeedEnv is the base-10 int64 seed for the loss RNGs.
benchmarkTCPProxySeedEnv = "NOTIFY_BENCH_TCP_PROXY_SEED"
// benchmarkTCPProxyDefaultListenAddr is the loopback address with port 0 used
// when no listen address is configured.
benchmarkTCPProxyDefaultListenAddr = "127.0.0.1:0"
)
// benchmarkTCPProxyDefaultLossPenalty is the loss penalty used when a loss
// percentage is configured but no positive penalty was supplied.
const benchmarkTCPProxyDefaultLossPenalty = 120 * time.Millisecond
// benchmarkTCPProxyConfig holds the impairment settings of the benchmark TCP
// proxy. The zero value describes a disabled proxy.
type benchmarkTCPProxyConfig struct {
// Enabled reports whether traffic should be routed through the proxy.
Enabled bool
// ListenAddr is the address the proxy listens on; empty selects the default.
ListenAddr string
// RateBytesPS is the per-direction pacing rate in bytes per second; values
// that are not positive disable pacing.
RateBytesPS float64
// Delay is added to the read time of every chunk before it is forwarded.
Delay time.Duration
// LossPct is the percentage of chunks that receive LossPenalty.
LossPct float64
// LossPenalty is the extra delay applied to a chunk selected as "lost".
LossPenalty time.Duration
// Seed is the base seed for the per-direction random generators.
Seed int64
}
// benchmarkTCPProxy is a TCP relay that accepts connections on listener and
// forwards them to target while applying the impairments described by cfg.
type benchmarkTCPProxy struct {
// listener accepts the incoming client connections.
listener net.Listener
// target is the address dialed for every accepted connection.
target string
// cfg is the impairment configuration applied to each relayed stream.
cfg benchmarkTCPProxyConfig
// closed is set once by stop so the accept loop can tell shutdown from failure.
closed atomic.Bool
// wg tracks the accept loop and every per-connection handler goroutine.
wg sync.WaitGroup
}
// benchmarkTCPDialAddr returns the address a benchmark should dial to reach
// targetAddr.
//
// An empty targetAddr, or a proxy that is not enabled through the environment,
// yields targetAddr unchanged. Otherwise a proxy in front of targetAddr is
// started, its shutdown is registered with tb.Cleanup, and the proxy's own
// listen address is returned. Any failure aborts the test via tb.Fatalf.
func benchmarkTCPDialAddr(tb testing.TB, targetAddr string) string {
	tb.Helper()

	if targetAddr == "" {
		return targetAddr
	}

	proxyCfg := benchmarkTCPProxyConfigFromEnv(tb)
	if !proxyCfg.Enabled {
		return targetAddr
	}

	p, startErr := startBenchmarkTCPProxy(targetAddr, proxyCfg)
	if startErr != nil {
		tb.Fatalf("start benchmark tcp proxy failed: %v", startErr)
	}
	// Register shutdown before any further check so the proxy never outlives
	// the test.
	tb.Cleanup(p.stop)

	if l := p.listener; l == nil || l.Addr() == nil {
		tb.Fatalf("benchmark tcp proxy listener is nil")
	}
	return p.listener.Addr().String()
}
// benchmarkTCPProxyConfigFromEnv builds the proxy configuration from the
// NOTIFY_BENCH_TCP_PROXY_* environment variables.
//
// When the enable variable is unset, empty, or false, the zero (disabled)
// configuration is returned and no other variable is consulted. Otherwise the
// configuration starts from the default listen address and seed 1, and each
// remaining variable overrides its field when set. Non-positive rate, delay and
// loss-penalty values are ignored. A loss percentage without a positive penalty
// falls back to benchmarkTCPProxyDefaultLossPenalty.
//
// Any value that fails to parse, or a loss percentage outside [0, 100]
// (including NaN), aborts the test via tb.Fatalf.
func benchmarkTCPProxyConfigFromEnv(tb testing.TB) benchmarkTCPProxyConfig {
	tb.Helper()

	enabled, ok, err := parseBoolEnv(benchmarkTCPProxyEnableEnv)
	if err != nil {
		tb.Fatalf("invalid %s: %v", benchmarkTCPProxyEnableEnv, err)
	}
	if !ok || !enabled {
		return benchmarkTCPProxyConfig{}
	}

	cfg := benchmarkTCPProxyConfig{
		Enabled:    true,
		ListenAddr: benchmarkTCPProxyDefaultListenAddr,
		Seed:       1,
	}
	if listen := os.Getenv(benchmarkTCPProxyListenAddrEnv); listen != "" {
		cfg.ListenAddr = listen
	}

	rateMbit, ok, err := parseFloatEnv(benchmarkTCPProxyRateMbitEnv)
	if err != nil {
		tb.Fatalf("invalid %s: %v", benchmarkTCPProxyRateMbitEnv, err)
	}
	if ok && rateMbit > 0 {
		// Megabits per second (decimal) to bytes per second.
		cfg.RateBytesPS = rateMbit * 1000 * 1000 / 8
	}

	delayMs, ok, err := parseIntEnv(benchmarkTCPProxyDelayMsEnv)
	if err != nil {
		tb.Fatalf("invalid %s: %v", benchmarkTCPProxyDelayMsEnv, err)
	}
	if ok && delayMs > 0 {
		cfg.Delay = time.Duration(delayMs) * time.Millisecond
	}

	lossPct, ok, err := parseFloatEnv(benchmarkTCPProxyLossPctEnv)
	if err != nil {
		tb.Fatalf("invalid %s: %v", benchmarkTCPProxyLossPctEnv, err)
	}
	if ok {
		// Deliberately a negated in-range test: ParseFloat accepts "NaN", and
		// NaN fails every comparison, so "< 0 || > 100" would let it through.
		if !(lossPct >= 0 && lossPct <= 100) {
			tb.Fatalf("%s must be in [0, 100], got %.4f", benchmarkTCPProxyLossPctEnv, lossPct)
		}
		cfg.LossPct = lossPct
	}

	lossPenaltyMs, ok, err := parseIntEnv(benchmarkTCPProxyLossPenaltyMsEnv)
	if err != nil {
		tb.Fatalf("invalid %s: %v", benchmarkTCPProxyLossPenaltyMsEnv, err)
	}
	if ok && lossPenaltyMs > 0 {
		cfg.LossPenalty = time.Duration(lossPenaltyMs) * time.Millisecond
	}
	// Loss without a penalty would be a no-op, so supply the default.
	if cfg.LossPct > 0 && cfg.LossPenalty <= 0 {
		cfg.LossPenalty = benchmarkTCPProxyDefaultLossPenalty
	}

	seed, ok, err := parseInt64Env(benchmarkTCPProxySeedEnv)
	if err != nil {
		tb.Fatalf("invalid %s: %v", benchmarkTCPProxySeedEnv, err)
	}
	if ok {
		cfg.Seed = seed
	}
	return cfg
}
// parseBoolEnv reads the environment variable name as a boolean.
//
// ok is false (with a nil error) when the variable is unset or empty. Otherwise
// ok is true and value/err are whatever strconv.ParseBool produced.
func parseBoolEnv(name string) (value bool, ok bool, err error) {
	if text := os.Getenv(name); text != "" {
		parsed, parseErr := strconv.ParseBool(text)
		return parsed, true, parseErr
	}
	return false, false, nil
}
// parseFloatEnv reads the environment variable name as a 64-bit float.
//
// ok is false (with a nil error) when the variable is unset or empty. Otherwise
// ok is true and value/err are whatever strconv.ParseFloat produced.
func parseFloatEnv(name string) (value float64, ok bool, err error) {
	if text := os.Getenv(name); text != "" {
		parsed, parseErr := strconv.ParseFloat(text, 64)
		return parsed, true, parseErr
	}
	return 0, false, nil
}
// parseIntEnv reads the environment variable name as a base-10 int.
//
// ok is false (with a nil error) when the variable is unset or empty. Otherwise
// ok is true and value/err are whatever strconv.Atoi produced.
func parseIntEnv(name string) (value int, ok bool, err error) {
	if text := os.Getenv(name); text != "" {
		parsed, parseErr := strconv.Atoi(text)
		return parsed, true, parseErr
	}
	return 0, false, nil
}
// parseInt64Env reads the environment variable name as a base-10 int64.
//
// ok is false (with a nil error) when the variable is unset or empty. Otherwise
// ok is true and value/err are whatever strconv.ParseInt produced.
func parseInt64Env(name string) (value int64, ok bool, err error) {
	if text := os.Getenv(name); text != "" {
		parsed, parseErr := strconv.ParseInt(text, 10, 64)
		return parsed, true, parseErr
	}
	return 0, false, nil
}
// startBenchmarkTCPProxy opens a TCP listener and starts relaying accepted
// connections to targetAddr using cfg.
//
// The listener binds to cfg.ListenAddr, or to the default listen address when
// that field is empty. It returns an error when targetAddr is empty or the
// listener cannot be created. On success the accept loop is already running in
// its own goroutine; the caller ends it with stop.
func startBenchmarkTCPProxy(targetAddr string, cfg benchmarkTCPProxyConfig) (*benchmarkTCPProxy, error) {
	if targetAddr == "" {
		return nil, errors.New("benchmark tcp proxy target addr is empty")
	}

	bindAddr := benchmarkTCPProxyDefaultListenAddr
	if cfg.ListenAddr != "" {
		bindAddr = cfg.ListenAddr
	}

	ln, listenErr := net.Listen("tcp", bindAddr)
	if listenErr != nil {
		return nil, listenErr
	}

	p := &benchmarkTCPProxy{listener: ln, target: targetAddr, cfg: cfg}
	// Account for the accept loop before launching it so stop can wait on it.
	p.wg.Add(1)
	go p.runAcceptLoop()
	return p, nil
}
// runAcceptLoop accepts client connections until the listener is closed and
// hands each one to handleConn in its own goroutine tracked by p.wg.
//
// The loop exits when Accept fails after stop has marked the proxy closed, or
// when the error is net.ErrClosed. Any other Accept error is retried after a
// short pause; without the pause a persistent failure (for example running out
// of file descriptors) would spin this goroutine at full CPU.
func (p *benchmarkTCPProxy) runAcceptLoop() {
	defer p.wg.Done()

	// Pause between retries of a failing Accept; short enough not to delay
	// shutdown noticeably.
	const acceptRetryDelay = 5 * time.Millisecond

	for {
		conn, err := p.listener.Accept()
		if err != nil {
			if p.closed.Load() || errors.Is(err, net.ErrClosed) {
				return
			}
			time.Sleep(acceptRetryDelay)
			continue
		}
		// Adding here is safe with respect to stop's Wait: this goroutine still
		// holds its own count, so the counter cannot be zero yet.
		p.wg.Add(1)
		go func(clientConn net.Conn) {
			defer p.wg.Done()
			p.handleConn(clientConn)
		}(conn)
	}
}
// handleConn relays one accepted client connection to the proxy target.
//
// It dials the target (closing the client if that fails), disables Nagle on
// both sides when they are TCP connections, and runs one relay goroutine per
// direction, each with its own generator seeded from cfg.Seed. As soon as
// either direction finishes, the shared context is cancelled and both
// connections are closed; the method returns after the second direction has
// reported as well.
func (p *benchmarkTCPProxy) handleConn(clientConn net.Conn) {
	if p == nil || clientConn == nil {
		return
	}

	upstream, dialErr := net.Dial("tcp", p.target)
	if dialErr != nil {
		_ = clientConn.Close()
		return
	}

	for _, c := range []net.Conn{clientConn, upstream} {
		if tc, isTCP := c.(*net.TCPConn); isTCP {
			_ = tc.SetNoDelay(true)
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Buffered for both directions so neither relay goroutine blocks on send.
	results := make(chan error, 2)
	startRelay := func(dst, src net.Conn, seedOffset int64) {
		rng := rand.New(rand.NewSource(p.cfg.Seed + seedOffset))
		go func() {
			results <- benchmarkRelayStream(ctx, dst, src, p.cfg, rng)
		}()
	}
	startRelay(upstream, clientConn, 101) // client -> target
	startRelay(clientConn, upstream, 202) // target -> client

	// The first direction to finish tears the whole connection pair down.
	<-results
	cancel()
	_ = clientConn.Close()
	_ = upstream.Close()
	<-results
}
// stop shuts the proxy down and blocks until the accept loop and every
// connection handler tracked by p.wg have returned.
//
// Only the first call closes the listener; later calls just wait. Calling stop
// on a nil proxy is a no-op.
func (p *benchmarkTCPProxy) stop() {
	if p == nil {
		return
	}
	firstCall := p.closed.CompareAndSwap(false, true)
	if firstCall && p.listener != nil {
		_ = p.listener.Close()
	}
	p.wg.Wait()
}
// benchmarkRelayStream copies data from src to dst while applying the
// impairments in cfg.
//
// A reader goroutine pulls chunks from src; each chunk is given a send time by
// a rate/delay scheduler, optionally pushed back by cfg.LossPenalty when rng
// selects it as "lost", and written to dst once that time arrives.
//
// It returns net.ErrClosed when either connection is nil, the context error
// when ctx is cancelled while waiting, the write error when forwarding fails,
// nil when src ends with io.EOF, and otherwise the reader's error.
func benchmarkRelayStream(ctx context.Context, dst net.Conn, src net.Conn, cfg benchmarkTCPProxyConfig, rng *rand.Rand) error {
	if dst == nil || src == nil {
		return net.ErrClosed
	}

	chunks := make(chan benchmarkRelayChunk, 128)
	readDone := make(chan error, 1)
	go benchmarkRelayReader(ctx, src, chunks, readDone)

	pacer := benchmarkRelayScheduler{rateBytesPS: cfg.RateBytesPS}
	// Loop-invariant part of the loss test; the generator is only consulted
	// when loss simulation is fully configured.
	lossActive := cfg.LossPct > 0 && cfg.LossPenalty > 0 && rng != nil

	for c := range chunks {
		deadline := pacer.schedule(c.readAt, len(c.payload), cfg.Delay)
		if lossActive && rng.Float64()*100 < cfg.LossPct {
			deadline = deadline.Add(cfg.LossPenalty)
		}
		if err := benchmarkSleepUntilContext(ctx, deadline); err != nil {
			return err
		}
		if err := writeAllToConn(dst, c.payload); err != nil {
			return err
		}
	}

	// The chunk channel is closed only after the reader has reported why it
	// stopped, so this receive does not block.
	if err := <-readDone; !errors.Is(err, io.EOF) {
		return err
	}
	return nil
}
// benchmarkRelayChunk is one piece of data read from a source connection and
// queued for forwarding.
type benchmarkRelayChunk struct {
// readAt is the time recorded right after the read that produced the chunk.
readAt time.Time
// payload is a private copy of the bytes that were read.
payload []byte
}
// benchmarkRelayReader reads from src in blocks of up to 64 KiB and sends each
// non-empty read to chunkCh as a timestamped copy.
//
// It stops when a read returns an error (including io.EOF) or when ctx is
// cancelled while a chunk is waiting to be queued. Exactly one value is sent on
// errCh before returning: net.ErrClosed for a nil src, the context error, or
// the read error. chunkCh is closed on return.
func benchmarkRelayReader(ctx context.Context, src net.Conn, chunkCh chan<- benchmarkRelayChunk, errCh chan<- error) {
	defer close(chunkCh)

	if src == nil {
		errCh <- net.ErrClosed
		return
	}

	scratch := make([]byte, 64*1024)
	for {
		n, readErr := src.Read(scratch)
		if n > 0 {
			stamp := time.Now()
			// scratch is reused by the next Read, so the chunk needs its own copy.
			data := make([]byte, n)
			copy(data, scratch[:n])

			select {
			case chunkCh <- benchmarkRelayChunk{readAt: stamp, payload: data}:
			case <-ctx.Done():
				errCh <- ctx.Err()
				return
			}
		}
		// Data returned together with an error is queued above before the
		// error ends the loop.
		if readErr != nil {
			errCh <- readErr
			return
		}
	}
}
// benchmarkRelayScheduler assigns send times to chunks so that a stream
// respects a fixed delay and an optional byte rate.
type benchmarkRelayScheduler struct {
	// rateBytesPS is the pacing rate in bytes per second; values that are not
	// positive disable pacing.
	rateBytesPS float64
	// nextAt is the earliest time the next chunk may be sent; the zero value
	// means nothing has been scheduled yet.
	nextAt time.Time
}

// schedule returns the time at which a chunk of the given size, read at readAt,
// should be sent, and records when the following chunk may go out.
//
// The send time is readAt plus delay (when delay is positive), pushed back to
// nextAt if the previous chunk still occupies the link. With a positive rate
// and size, the link is then considered busy for bytes/rate seconds (at least
// one nanosecond); otherwise it is free again at the send time. A nil receiver
// returns readAt unchanged.
func (s *benchmarkRelayScheduler) schedule(readAt time.Time, bytes int, delay time.Duration) time.Time {
	if s == nil {
		return readAt
	}

	when := readAt
	if delay > 0 {
		when = when.Add(delay)
	}
	if busyUntil := s.nextAt; !busyUntil.IsZero() && busyUntil.After(when) {
		when = busyUntil
	}

	s.nextAt = when
	if s.rateBytesPS > 0 && bytes > 0 {
		txTime := time.Duration(float64(bytes) / s.rateBytesPS * float64(time.Second))
		if txTime <= 0 {
			// Sub-nanosecond transmissions still have to advance the clock.
			txTime = time.Nanosecond
		}
		s.nextAt = when.Add(txTime)
	}
	return when
}
// benchmarkSleepUntilContext blocks until the wall-clock time at is reached or
// ctx is cancelled, whichever happens first.
//
// It returns nil immediately when at is not in the future; otherwise it returns
// the result of benchmarkSleepContext for the remaining duration.
func benchmarkSleepUntilContext(ctx context.Context, at time.Time) error {
	if remaining := time.Until(at); remaining > 0 {
		return benchmarkSleepContext(ctx, remaining)
	}
	return nil
}
// benchmarkSleepContext sleeps for d unless ctx is cancelled first.
//
// It returns nil once the duration has elapsed, or immediately when d is not
// positive (without looking at ctx). If ctx is done before the timer fires it
// returns ctx.Err().
func benchmarkSleepContext(ctx context.Context, d time.Duration) error {
	if d > 0 {
		t := time.NewTimer(d)
		defer t.Stop()

		select {
		case <-t.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	return nil
}
// writeAllToConn writes payload to conn, repeating the call until every byte
// has been accepted.
//
// It returns the first error reported by conn.Write, io.ErrNoProgress when a
// write accepts zero bytes without reporting an error, and nil once the whole
// payload (possibly empty) has been written.
func writeAllToConn(conn net.Conn, payload []byte) error {
	remaining := payload
	for len(remaining) != 0 {
		written, err := conn.Write(remaining)
		if written > 0 {
			remaining = remaining[written:]
		}
		switch {
		case err != nil:
			return err
		case written == 0:
			// Guard against a writer that would otherwise loop forever.
			return io.ErrNoProgress
		}
	}
	return nil
}