From f038a897716b440264948a2f2c7ce1af48f2e6c9 Mon Sep 17 00:00:00 2001 From: starainrt Date: Sat, 18 Apr 2026 16:05:57 +0800 Subject: [PATCH] fix: close stream adaptive gaps and switch notify to stario v0.1.1 - make stream fast path honor adaptive soft payload limits end-to-end - split oversized fast-stream payloads into sequential frames before batching - use adaptive soft cap when encoding stream batch payloads - move timeout-like error detection into production code for adaptive tx - tune notify FrameReader read size explicitly to avoid throughput regression - drop local stario replace and depend on released b612.me/stario v0.1.1 --- benchmark_listen_test.go | 16 + benchmark_tcp_proxy_test.go | 378 ++++++ bulk.go | 1233 ++++++++++++++++++-- bulk_batch_sender.go | 548 ++++++++- bulk_benchmark_test.go | 8 +- bulk_buffer_release_test.go | 119 ++ bulk_control.go | 413 ++++++- bulk_dedicated.go | 1102 +++++++++++++++--- bulk_dedicated_attach_test.go | 250 ++++ bulk_dedicated_batch.go | 426 ++++++- bulk_dedicated_lane_sender.go | 639 ++++++++++ bulk_dedicated_lane_sender_test.go | 145 +++ bulk_dedicated_sidecar.go | 451 +++++++ bulk_dispatcher.go | 109 +- bulk_e2e_benchmark_test.go | 9 +- bulk_fastpath.go | 329 +++++- bulk_runtime.go | 171 ++- bulk_shared_batch.go | 228 ++++ bulk_shared_batch_test.go | 458 ++++++++ bulk_snapshot.go | 93 +- bulk_stack_benchmark_test.go | 4 +- bulk_test.go | 323 +++++- bulk_transport_guard_test.go | 1292 +++++++++++++++++++++ client.go | 75 +- client_bulk.go | 175 ++- client_bulk_config.go | 364 ++++++ client_config.go | 9 + client_conn.go | 17 + client_conn_attachment.go | 62 + client_conn_transport.go | 169 +++ client_runtime.go | 15 + client_session_runtime_test.go | 4 +- client_stream.go | 18 +- client_transport.go | 89 +- clienttype.go | 10 + default.go | 2 + diagnostics_snapshot.go | 11 +- go.mod | 2 +- go.sum | 4 +- logical_conn.go | 8 + raw_tcp_benchmark_test.go | 4 +- release_p0_test.go | 134 +++ security_psk.go | 187 ++- 
security_psk_test.go | 29 + server.go | 6 + server_bulk.go | 199 +++- server_config.go | 9 + server_inbound_source.go | 75 +- server_listen.go | 1 + server_session.go | 3 + server_stream.go | 23 +- servertype.go | 2 + session_runtime_snapshot.go | 129 +- session_runtime_snapshot_test.go | 172 +++ session_state.go | 2 + signal_benchmark_test.go | 4 +- snapshot_binding.go | 44 +- stream.go | 100 +- stream_batch_codec.go | 6 + stream_batch_sender.go | 582 ++++++++++ stream_benchmark_test.go | 8 +- stream_control.go | 18 +- stream_fastpath.go | 199 +++- stream_fastpath_test.go | 414 +++++++ stream_flow.go | 220 +++- stream_flow_test.go | 48 + stream_runtime.go | 50 +- stream_shared_batch.go | 145 +++ stream_snapshot.go | 88 +- timeout_error_test.go => timeout_error.go | 0 transport_binding.go | 86 +- transport_binding_adaptive.go | 383 ++++++ transport_binding_adaptive_test.go | 160 +++ transport_codec.go | 124 +- transport_conn.go | 14 +- transport_write_test.go | 116 +- 76 files changed, 12656 insertions(+), 906 deletions(-) create mode 100644 benchmark_listen_test.go create mode 100644 benchmark_tcp_proxy_test.go create mode 100644 bulk_buffer_release_test.go create mode 100644 bulk_dedicated_lane_sender.go create mode 100644 bulk_dedicated_lane_sender_test.go create mode 100644 bulk_dedicated_sidecar.go create mode 100644 bulk_shared_batch.go create mode 100644 bulk_shared_batch_test.go create mode 100644 client_bulk_config.go create mode 100644 stream_batch_codec.go create mode 100644 stream_batch_sender.go create mode 100644 stream_shared_batch.go rename timeout_error_test.go => timeout_error.go (100%) create mode 100644 transport_binding_adaptive.go create mode 100644 transport_binding_adaptive_test.go diff --git a/benchmark_listen_test.go b/benchmark_listen_test.go new file mode 100644 index 0000000..7021e88 --- /dev/null +++ b/benchmark_listen_test.go @@ -0,0 +1,16 @@ +package notify + +import ( + "os" + "testing" +) + +const benchmarkTCPListenAddrEnv = 
"NOTIFY_BENCH_TCP_LISTEN_ADDR" + +func benchmarkTCPListenAddr(tb testing.TB) string { + tb.Helper() + if addr := os.Getenv(benchmarkTCPListenAddrEnv); addr != "" { + return addr + } + return "127.0.0.1:0" +} diff --git a/benchmark_tcp_proxy_test.go b/benchmark_tcp_proxy_test.go new file mode 100644 index 0000000..68ca97e --- /dev/null +++ b/benchmark_tcp_proxy_test.go @@ -0,0 +1,378 @@ +package notify + +import ( + "context" + "errors" + "io" + "math/rand" + "net" + "os" + "strconv" + "sync" + "sync/atomic" + "testing" + "time" +) + +const ( + benchmarkTCPProxyEnableEnv = "NOTIFY_BENCH_TCP_PROXY_ENABLE" + benchmarkTCPProxyListenAddrEnv = "NOTIFY_BENCH_TCP_PROXY_LISTEN_ADDR" + benchmarkTCPProxyRateMbitEnv = "NOTIFY_BENCH_TCP_PROXY_RATE_MBIT" + benchmarkTCPProxyDelayMsEnv = "NOTIFY_BENCH_TCP_PROXY_DELAY_MS" + benchmarkTCPProxyLossPctEnv = "NOTIFY_BENCH_TCP_PROXY_LOSS_PCT" + benchmarkTCPProxyLossPenaltyMsEnv = "NOTIFY_BENCH_TCP_PROXY_LOSS_PENALTY_MS" + benchmarkTCPProxySeedEnv = "NOTIFY_BENCH_TCP_PROXY_SEED" + benchmarkTCPProxyDefaultListenAddr = "127.0.0.1:0" +) + +const benchmarkTCPProxyDefaultLossPenalty = 120 * time.Millisecond + +type benchmarkTCPProxyConfig struct { + Enabled bool + ListenAddr string + RateBytesPS float64 + Delay time.Duration + LossPct float64 + LossPenalty time.Duration + Seed int64 +} + +type benchmarkTCPProxy struct { + listener net.Listener + target string + cfg benchmarkTCPProxyConfig + + closed atomic.Bool + wg sync.WaitGroup +} + +func benchmarkTCPDialAddr(tb testing.TB, targetAddr string) string { + tb.Helper() + if targetAddr == "" { + return targetAddr + } + cfg := benchmarkTCPProxyConfigFromEnv(tb) + if !cfg.Enabled { + return targetAddr + } + proxy, err := startBenchmarkTCPProxy(targetAddr, cfg) + if err != nil { + tb.Fatalf("start benchmark tcp proxy failed: %v", err) + } + tb.Cleanup(proxy.stop) + if proxy.listener == nil || proxy.listener.Addr() == nil { + tb.Fatalf("benchmark tcp proxy listener is nil") + } + return 
proxy.listener.Addr().String() +} + +func benchmarkTCPProxyConfigFromEnv(tb testing.TB) benchmarkTCPProxyConfig { + tb.Helper() + enabled, ok, err := parseBoolEnv(benchmarkTCPProxyEnableEnv) + if err != nil { + tb.Fatalf("invalid %s: %v", benchmarkTCPProxyEnableEnv, err) + } + if !ok || !enabled { + return benchmarkTCPProxyConfig{} + } + cfg := benchmarkTCPProxyConfig{ + Enabled: true, + ListenAddr: benchmarkTCPProxyDefaultListenAddr, + Seed: 1, + } + if listen := os.Getenv(benchmarkTCPProxyListenAddrEnv); listen != "" { + cfg.ListenAddr = listen + } + rateMbit, ok, err := parseFloatEnv(benchmarkTCPProxyRateMbitEnv) + if err != nil { + tb.Fatalf("invalid %s: %v", benchmarkTCPProxyRateMbitEnv, err) + } + if ok && rateMbit > 0 { + cfg.RateBytesPS = rateMbit * 1000 * 1000 / 8 + } + delayMs, ok, err := parseIntEnv(benchmarkTCPProxyDelayMsEnv) + if err != nil { + tb.Fatalf("invalid %s: %v", benchmarkTCPProxyDelayMsEnv, err) + } + if ok && delayMs > 0 { + cfg.Delay = time.Duration(delayMs) * time.Millisecond + } + lossPct, ok, err := parseFloatEnv(benchmarkTCPProxyLossPctEnv) + if err != nil { + tb.Fatalf("invalid %s: %v", benchmarkTCPProxyLossPctEnv, err) + } + if ok { + if lossPct < 0 || lossPct > 100 { + tb.Fatalf("%s must be in [0, 100], got %.4f", benchmarkTCPProxyLossPctEnv, lossPct) + } + cfg.LossPct = lossPct + } + lossPenaltyMs, ok, err := parseIntEnv(benchmarkTCPProxyLossPenaltyMsEnv) + if err != nil { + tb.Fatalf("invalid %s: %v", benchmarkTCPProxyLossPenaltyMsEnv, err) + } + if ok && lossPenaltyMs > 0 { + cfg.LossPenalty = time.Duration(lossPenaltyMs) * time.Millisecond + } + if cfg.LossPct > 0 && cfg.LossPenalty <= 0 { + cfg.LossPenalty = benchmarkTCPProxyDefaultLossPenalty + } + seed, ok, err := parseInt64Env(benchmarkTCPProxySeedEnv) + if err != nil { + tb.Fatalf("invalid %s: %v", benchmarkTCPProxySeedEnv, err) + } + if ok { + cfg.Seed = seed + } + return cfg +} + +func parseBoolEnv(name string) (value bool, ok bool, err error) { + raw := os.Getenv(name) + 
if raw == "" { + return false, false, nil + } + value, err = strconv.ParseBool(raw) + return value, true, err +} + +func parseFloatEnv(name string) (value float64, ok bool, err error) { + raw := os.Getenv(name) + if raw == "" { + return 0, false, nil + } + value, err = strconv.ParseFloat(raw, 64) + return value, true, err +} + +func parseIntEnv(name string) (value int, ok bool, err error) { + raw := os.Getenv(name) + if raw == "" { + return 0, false, nil + } + value, err = strconv.Atoi(raw) + return value, true, err +} + +func parseInt64Env(name string) (value int64, ok bool, err error) { + raw := os.Getenv(name) + if raw == "" { + return 0, false, nil + } + value, err = strconv.ParseInt(raw, 10, 64) + return value, true, err +} + +func startBenchmarkTCPProxy(targetAddr string, cfg benchmarkTCPProxyConfig) (*benchmarkTCPProxy, error) { + if targetAddr == "" { + return nil, errors.New("benchmark tcp proxy target addr is empty") + } + listenAddr := cfg.ListenAddr + if listenAddr == "" { + listenAddr = benchmarkTCPProxyDefaultListenAddr + } + listener, err := net.Listen("tcp", listenAddr) + if err != nil { + return nil, err + } + proxy := &benchmarkTCPProxy{ + listener: listener, + target: targetAddr, + cfg: cfg, + } + proxy.wg.Add(1) + go proxy.runAcceptLoop() + return proxy, nil +} + +func (p *benchmarkTCPProxy) runAcceptLoop() { + defer p.wg.Done() + for { + conn, err := p.listener.Accept() + if err != nil { + if p.closed.Load() || errors.Is(err, net.ErrClosed) { + return + } + continue + } + p.wg.Add(1) + go func(clientConn net.Conn) { + defer p.wg.Done() + p.handleConn(clientConn) + }(conn) + } +} + +func (p *benchmarkTCPProxy) handleConn(clientConn net.Conn) { + if p == nil || clientConn == nil { + return + } + targetConn, err := net.Dial("tcp", p.target) + if err != nil { + _ = clientConn.Close() + return + } + if tcpConn, ok := clientConn.(*net.TCPConn); ok { + _ = tcpConn.SetNoDelay(true) + } + if tcpConn, ok := targetConn.(*net.TCPConn); ok { + _ = 
tcpConn.SetNoDelay(true) + } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + errCh := make(chan error, 2) + leftRng := rand.New(rand.NewSource(p.cfg.Seed + 101)) + rightRng := rand.New(rand.NewSource(p.cfg.Seed + 202)) + go func() { + errCh <- benchmarkRelayStream(ctx, targetConn, clientConn, p.cfg, leftRng) + }() + go func() { + errCh <- benchmarkRelayStream(ctx, clientConn, targetConn, p.cfg, rightRng) + }() + + _ = <-errCh + cancel() + _ = clientConn.Close() + _ = targetConn.Close() + <-errCh +} + +func (p *benchmarkTCPProxy) stop() { + if p == nil { + return + } + if p.closed.CompareAndSwap(false, true) { + if p.listener != nil { + _ = p.listener.Close() + } + } + p.wg.Wait() +} + +func benchmarkRelayStream(ctx context.Context, dst net.Conn, src net.Conn, cfg benchmarkTCPProxyConfig, rng *rand.Rand) error { + if dst == nil || src == nil { + return net.ErrClosed + } + chunkCh := make(chan benchmarkRelayChunk, 128) + readerErrCh := make(chan error, 1) + go benchmarkRelayReader(ctx, src, chunkCh, readerErrCh) + + scheduler := benchmarkRelayScheduler{rateBytesPS: cfg.RateBytesPS} + for chunk := range chunkCh { + sendAt := scheduler.schedule(chunk.readAt, len(chunk.payload), cfg.Delay) + if cfg.LossPct > 0 && cfg.LossPenalty > 0 && rng != nil && rng.Float64()*100 < cfg.LossPct { + sendAt = sendAt.Add(cfg.LossPenalty) + } + if waitErr := benchmarkSleepUntilContext(ctx, sendAt); waitErr != nil { + return waitErr + } + if writeErr := writeAllToConn(dst, chunk.payload); writeErr != nil { + return writeErr + } + } + readerErr := <-readerErrCh + if errors.Is(readerErr, io.EOF) { + return nil + } + return readerErr +} + +type benchmarkRelayChunk struct { + readAt time.Time + payload []byte +} + +func benchmarkRelayReader(ctx context.Context, src net.Conn, chunkCh chan<- benchmarkRelayChunk, errCh chan<- error) { + defer close(chunkCh) + if src == nil { + errCh <- net.ErrClosed + return + } + buf := make([]byte, 64*1024) + for { + n, err := 
src.Read(buf) + if n > 0 { + chunk := benchmarkRelayChunk{ + readAt: time.Now(), + payload: append([]byte(nil), buf[:n]...), + } + select { + case <-ctx.Done(): + errCh <- ctx.Err() + return + case chunkCh <- chunk: + } + } + if err != nil { + errCh <- err + return + } + } +} + +type benchmarkRelayScheduler struct { + rateBytesPS float64 + nextAt time.Time +} + +func (s *benchmarkRelayScheduler) schedule(readAt time.Time, bytes int, delay time.Duration) time.Time { + if s == nil { + return readAt + } + sendAt := readAt + if delay > 0 { + sendAt = sendAt.Add(delay) + } + if !s.nextAt.IsZero() && s.nextAt.After(sendAt) { + sendAt = s.nextAt + } + if s.rateBytesPS > 0 && bytes > 0 { + duration := time.Duration(float64(bytes) / s.rateBytesPS * float64(time.Second)) + if duration <= 0 { + duration = time.Nanosecond + } + s.nextAt = sendAt.Add(duration) + } else { + s.nextAt = sendAt + } + return sendAt +} + +func benchmarkSleepUntilContext(ctx context.Context, at time.Time) error { + wait := time.Until(at) + if wait <= 0 { + return nil + } + return benchmarkSleepContext(ctx, wait) +} + +func benchmarkSleepContext(ctx context.Context, d time.Duration) error { + if d <= 0 { + return nil + } + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +func writeAllToConn(conn net.Conn, payload []byte) error { + for len(payload) > 0 { + n, err := conn.Write(payload) + if n > 0 { + payload = payload[n:] + } + if err != nil { + return err + } + if n == 0 { + return io.ErrNoProgress + } + } + return nil +} diff --git a/bulk.go b/bulk.go index 2def21d..89e8168 100644 --- a/bulk.go +++ b/bulk.go @@ -7,6 +7,7 @@ import ( "net" "strings" "sync" + "sync/atomic" "time" ) @@ -14,6 +15,7 @@ const ( BulkOpenSignalKey = "notify.bulk.open" BulkCloseSignalKey = "notify.bulk.close" BulkResetSignalKey = "notify.bulk.reset" + BulkReadySignalKey = "notify.bulk.ready" BulkReleaseSignalKey = "notify.bulk.release" 
defaultBulkChunkSize = 1024 * 1024 @@ -23,6 +25,7 @@ const ( defaultBulkOpenMaxInFlight = 32 defaultBulkControlReadTimeout = 0 defaultBulkControlWriteTimeout = 0 + defaultBulkAcceptReadyTimeout = 10 * time.Second ) type BulkMetadata map[string]string @@ -32,13 +35,79 @@ type BulkRange struct { Length int64 } +type BulkOpenMode uint8 + +const ( + // BulkOpenModeDefault keeps legacy behavior: + // Dedicated=true -> dedicated, otherwise shared. + BulkOpenModeDefault BulkOpenMode = iota + // BulkOpenModeAuto prefers dedicated and falls back to shared. + BulkOpenModeAuto + // BulkOpenModeShared forces shared transport path. + BulkOpenModeShared + // BulkOpenModeDedicated forces dedicated transport path. + BulkOpenModeDedicated +) + +type BulkNetworkProfile uint8 + +const ( + // BulkNetworkProfileDefault keeps legacy defaults. + BulkNetworkProfileDefault BulkNetworkProfile = iota + // BulkNetworkProfileLAN is optimized for low-latency/local links. + BulkNetworkProfileLAN + // BulkNetworkProfileWAN is tuned for moderate RTT and occasional loss. + BulkNetworkProfileWAN + // BulkNetworkProfileConstrained is tuned for low bandwidth and unstable links. + BulkNetworkProfileConstrained +) + +type BulkDedicatedAttachConfig struct { + // AttachLimit limits concurrent dedicated attach handshakes per client session. + // 0 means unlimited. + AttachLimit int + // ActiveLimit limits active logical dedicated bulks per client session. + // Physical sidecars remain bounded by LaneLimit. + // 0 means unlimited. + ActiveLimit int + // LaneLimit limits dedicated physical sidecar lanes per client session. + // 0 means unlimited. + LaneLimit int + // Retry controls extra retries after the first attach attempt. + Retry int + // Backoff is the base retry backoff. + Backoff time.Duration + // DialTimeout is used for dedicated sidecar dialing. + DialTimeout time.Duration + // HelloTimeout is used for dedicated attach request/response handshake. 
+ HelloTimeout time.Duration +} + +type BulkOpenTuning struct { + ChunkSize int + WindowBytes int + MaxInFlight int +} + +type bulkDedicatedAttachState uint8 + +const ( + bulkDedicatedAttachStatePending bulkDedicatedAttachState = iota + bulkDedicatedAttachStateAttached + bulkDedicatedAttachStateDegraded + bulkDedicatedAttachStateClosed +) + type BulkOpenOptions struct { ID string Range BulkRange Metadata BulkMetadata ReadTimeout time.Duration WriteTimeout time.Duration - Dedicated bool + Mode BulkOpenMode + // Deprecated: Dedicated is kept for backward compatibility. + // Prefer Mode. + Dedicated bool ChunkSize int WindowBytes int @@ -92,6 +161,7 @@ var ( errBulkRangeInvalid = errors.New("bulk range is invalid") errBulkBackpressureExceeded = errors.New("bulk inbound backpressure exceeded") errBulkDedicatedStreamOnly = errors.New("dedicated bulk requires stream transport") + errBulkDedicatedActiveLimit = errors.New("dedicated bulk active limit reached") ) func clientDedicatedBulkSupportError(c *ClientCommon) error { @@ -136,14 +206,77 @@ func transportDedicatedBulkSupportError(transport *TransportConn) error { type bulkCloseSender func(context.Context, *bulkHandle, bool) error type bulkResetSender func(context.Context, *bulkHandle, string) error type bulkDataSender func(context.Context, *bulkHandle, []byte) error -type bulkWriteSender func(context.Context, *bulkHandle, []byte) (int, error) +type bulkWriteSender func(context.Context, *bulkHandle, uint64, []byte, bool) (int, error) type bulkReleaseSender func(*bulkHandle, int64, int) error +type bulkAsyncWriteRequest struct { + startSeq uint64 + payload []byte + chunks int +} + +var bulkAsyncWritePayloadPool sync.Pool + +type bulkReadChunk struct { + data []byte + release func() +} + +func (c *bulkReadChunk) clear() { + if c == nil { + return + } + if c.release != nil { + c.release() + } + c.data = nil + c.release = nil +} + +type bulkReadPayloadOwner struct { + refs atomic.Int32 + release func() +} + +func 
newBulkReadPayloadOwner(release func()) *bulkReadPayloadOwner { + if release == nil { + return nil + } + owner := &bulkReadPayloadOwner{release: release} + owner.refs.Store(1) + return owner +} + +func (o *bulkReadPayloadOwner) retainChunk() func() { + if o == nil { + return nil + } + o.refs.Add(1) + return o.releaseChunk +} + +func (o *bulkReadPayloadOwner) releaseChunk() { + if o == nil { + return + } + if o.refs.Add(-1) == 0 && o.release != nil { + o.release() + } +} + +func (o *bulkReadPayloadOwner) done() { + if o == nil { + return + } + o.releaseChunk() +} + type bulkHandle struct { runtime *bulkRuntime runtimeScope string id string dataID uint64 + fastPathVersion uint8 outboundSeq uint64 rangeSpec BulkRange metadata BulkMetadata @@ -155,6 +288,7 @@ type bulkHandle struct { readTimeout time.Duration writeTimeout time.Duration dedicated bool + dedicatedLaneID uint32 dedicatedAttachToken string chunkSize int windowBytes int @@ -168,21 +302,31 @@ type bulkHandle struct { releaseFn bulkReleaseSender ctx context.Context cancel context.CancelFunc + writeCtx context.Context + writeCtxCancel context.CancelFunc createdAt time.Time writeMu sync.Mutex mu sync.Mutex + writeQueue chan bulkAsyncWriteRequest + writeWorkerDone chan struct{} + writeDrain chan struct{} + pendingAsyncWrites int localClosed bool localReadClosed bool remoteClosed bool peerReadClosed bool resetErr error - readQueue [][]byte - readBuf []byte + readQueue []bulkReadChunk + readBuf bulkReadChunk bufferedBytes int readNotify chan struct{} flowNotify chan struct{} + writeStateDone chan struct{} + writeStateClosed bool + releaseNotify chan struct{} + releaseWorkerDone chan struct{} pendingReleaseBytes int64 pendingReleaseChunks int outboundAvailBytes int64 @@ -196,12 +340,23 @@ type bulkHandle struct { dedicatedMu sync.Mutex dedicatedConn net.Conn + dedicatedConnOwned bool dedicatedSender *bulkDedicatedSender dedicatedReady chan struct{} dedicatedWriteClosed bool + dedicatedActiveLease bool + 
dedicatedState bulkDedicatedAttachState + dedicatedAttempts uint32 + dedicatedLastCode string + dedicatedDataStarted bool acceptMu sync.Mutex acceptDispatched bool + acceptReady chan struct{} + acceptReadyDone bool + acceptReadyErr error + acceptNotifyFn func(error) + acceptNotifySent bool } func newBulkHandle(parent context.Context, runtime *bulkRuntime, runtimeScope string, req BulkOpenRequest, sessionEpoch uint64, logical *LogicalConn, transport *TransportConn, transportGeneration uint64, closeFn bulkCloseSender, resetFn bulkResetSender, sendDataFn bulkDataSender, sendWriteFn bulkWriteSender, releaseFn bulkReleaseSender) *bulkHandle { @@ -216,11 +371,12 @@ func newBulkHandle(parent context.Context, runtime *bulkRuntime, runtimeScope st transportGeneration = logical.transportGenerationSnapshot() } req = normalizeBulkOpenRequest(req) - return &bulkHandle{ + handle := &bulkHandle{ runtime: runtime, runtimeScope: runtimeScope, id: req.BulkID, dataID: req.DataID, + fastPathVersion: normalizeBulkFastPathVersion(req.FastPathVersion), rangeSpec: req.Range, metadata: cloneBulkMetadata(req.Metadata), sessionEpoch: sessionEpoch, @@ -230,6 +386,7 @@ func newBulkHandle(parent context.Context, runtime *bulkRuntime, runtimeScope st readTimeout: req.ReadTimeout, writeTimeout: req.WriteTimeout, dedicated: req.Dedicated, + dedicatedLaneID: req.DedicatedLaneID, dedicatedAttachToken: req.AttachToken, chunkSize: req.ChunkSize, windowBytes: req.WindowBytes, @@ -246,9 +403,34 @@ func newBulkHandle(parent context.Context, runtime *bulkRuntime, runtimeScope st createdAt: time.Now(), readNotify: make(chan struct{}, 1), flowNotify: make(chan struct{}, 1), + writeStateDone: make(chan struct{}), dedicatedReady: make(chan struct{}), + dedicatedState: initialBulkDedicatedAttachState(req.Dedicated), + acceptReady: make(chan struct{}), outboundAvailBytes: int64(req.WindowBytes), } + drain := make(chan struct{}) + close(drain) + handle.writeDrain = drain + if sendWriteFn != nil { + 
handle.writeCtx, handle.writeCtxCancel = context.WithCancel(ctx) + handle.writeQueue = make(chan bulkAsyncWriteRequest, bulkAsyncWriteQueueSize(req.MaxInFlight)) + handle.writeWorkerDone = make(chan struct{}) + go func(parentDone <-chan struct{}, writeDone <-chan struct{}, cancel context.CancelFunc) { + select { + case <-parentDone: + case <-writeDone: + cancel() + } + }(ctx.Done(), handle.writeStateDone, handle.writeCtxCancel) + go handle.runAsyncWriteLoop() + } + if handle.flowControlEnabled() { + handle.releaseNotify = make(chan struct{}, 1) + handle.releaseWorkerDone = make(chan struct{}) + go handle.runWindowReleaseLoop() + } + return handle } func (b *bulkHandle) ID() string { @@ -258,6 +440,19 @@ func (b *bulkHandle) ID() string { return b.id } +func (b *bulkHandle) fastPathVersionSnapshot() uint8 { + if b == nil { + return bulkFastPathVersionV1 + } + b.mu.Lock() + defer b.mu.Unlock() + return normalizeBulkFastPathVersion(b.fastPathVersion) +} + +func (b *bulkHandle) FastPathVersion() uint8 { + return b.fastPathVersionSnapshot() +} + func (b *bulkHandle) Range() BulkRange { if b == nil { return BulkRange{} @@ -307,6 +502,15 @@ func (b *bulkHandle) Dedicated() bool { return b.dedicated } +func (b *bulkHandle) dedicatedLaneIDSnapshot() uint32 { + if b == nil { + return 0 + } + b.dedicatedMu.Lock() + defer b.dedicatedMu.Unlock() + return b.dedicatedLaneID +} + func (b *bulkHandle) dedicatedAttachTokenSnapshot() string { if b == nil { return "" @@ -325,6 +529,61 @@ func (b *bulkHandle) setDedicatedAttachToken(token string) { b.mu.Unlock() } +func (b *bulkHandle) markDedicatedAttachAttempt() { + if b == nil { + return + } + b.dedicatedMu.Lock() + b.dedicatedAttempts++ + b.dedicatedMu.Unlock() +} + +func (b *bulkHandle) setDedicatedAttachLastCode(code string) { + if b == nil { + return + } + b.dedicatedMu.Lock() + b.dedicatedLastCode = code + b.dedicatedMu.Unlock() +} + +func (b *bulkHandle) markDedicatedAttachDegraded(code string) { + if b == nil { + return + } + 
b.dedicatedMu.Lock() + b.dedicatedState = bulkDedicatedAttachStateDegraded + b.dedicatedLastCode = code + b.dedicatedMu.Unlock() +} + +func (b *bulkHandle) markDedicatedAttachClosed() { + if b == nil { + return + } + b.dedicatedMu.Lock() + b.dedicatedState = bulkDedicatedAttachStateClosed + b.dedicatedMu.Unlock() +} + +func (b *bulkHandle) dedicatedAttachStateSnapshot() bulkDedicatedAttachState { + if b == nil { + return bulkDedicatedAttachStateClosed + } + b.dedicatedMu.Lock() + defer b.dedicatedMu.Unlock() + return b.dedicatedState +} + +func (b *bulkHandle) dedicatedAttachDiagnosticsSnapshot() (state bulkDedicatedAttachState, attempts uint32, lastCode string, dataStarted bool) { + if b == nil { + return bulkDedicatedAttachStateClosed, 0, "", false + } + b.dedicatedMu.Lock() + defer b.dedicatedMu.Unlock() + return b.dedicatedState, b.dedicatedAttempts, b.dedicatedLastCode, b.dedicatedDataStarted +} + func (b *bulkHandle) dedicatedConnSnapshot() net.Conn { if b == nil { return nil @@ -371,6 +630,24 @@ func (b *bulkHandle) dedicatedAttachedSnapshot() bool { return b.dedicatedConnSnapshot() != nil } +func (b *bulkHandle) dedicatedDataStartedSnapshot() bool { + if b == nil { + return false + } + b.dedicatedMu.Lock() + defer b.dedicatedMu.Unlock() + return b.dedicatedDataStarted +} + +func (b *bulkHandle) markDedicatedDataStarted() { + if b == nil { + return + } + b.dedicatedMu.Lock() + b.dedicatedDataStarted = true + b.dedicatedMu.Unlock() +} + func (b *bulkHandle) waitDedicatedReady(ctx context.Context) error { if b == nil || !b.Dedicated() || b.dedicatedAttachedSnapshot() { return nil @@ -378,6 +655,9 @@ func (b *bulkHandle) waitDedicatedReady(ctx context.Context) error { if ctx == nil { ctx = context.Background() } + if err := b.writeStateErrorSnapshot(); err != nil { + return err + } select { case <-b.dedicatedReady: return nil @@ -404,7 +684,10 @@ func (b *bulkHandle) attachDedicatedConn(conn net.Conn) error { return errors.New("bulk dedicated conn already 
attached") } b.dedicatedConn = conn + b.dedicatedConnOwned = true b.dedicatedWriteClosed = false + b.dedicatedState = bulkDedicatedAttachStateAttached + b.dedicatedLastCode = "" ready := b.dedicatedReady b.dedicatedMu.Unlock() if ready != nil { @@ -417,15 +700,116 @@ func (b *bulkHandle) attachDedicatedConn(conn net.Conn) error { return nil } +func (b *bulkHandle) attachDedicatedConnShared(conn net.Conn) error { + if b == nil { + return io.ErrClosedPipe + } + if conn == nil { + return net.ErrClosed + } + b.dedicatedMu.Lock() + if b.dedicatedConn != nil { + if b.dedicatedConn == conn { + b.dedicatedConnOwned = false + b.dedicatedState = bulkDedicatedAttachStateAttached + b.dedicatedLastCode = "" + b.dedicatedMu.Unlock() + return nil + } + b.dedicatedMu.Unlock() + return errors.New("bulk dedicated conn already attached") + } + b.dedicatedConn = conn + b.dedicatedConnOwned = false + b.dedicatedWriteClosed = false + b.dedicatedState = bulkDedicatedAttachStateAttached + b.dedicatedLastCode = "" + ready := b.dedicatedReady + b.dedicatedMu.Unlock() + if ready != nil { + select { + case <-ready: + default: + close(ready) + } + } + return nil +} + +func (b *bulkHandle) replaceDedicatedConn(conn net.Conn) (net.Conn, *bulkDedicatedSender, error) { + if b == nil { + return nil, nil, io.ErrClosedPipe + } + if conn == nil { + return nil, nil, net.ErrClosed + } + b.dedicatedMu.Lock() + oldConn := b.dedicatedConn + oldOwned := b.dedicatedConnOwned + oldSender := b.dedicatedSender + b.dedicatedConn = conn + b.dedicatedConnOwned = true + b.dedicatedSender = nil + b.dedicatedWriteClosed = false + b.dedicatedState = bulkDedicatedAttachStateAttached + b.dedicatedLastCode = "" + ready := b.dedicatedReady + b.dedicatedMu.Unlock() + if ready != nil { + select { + case <-ready: + default: + close(ready) + } + } + if !oldOwned { + oldConn = nil + } + return oldConn, oldSender, nil +} + +func (b *bulkHandle) replaceDedicatedConnShared(conn net.Conn) (net.Conn, *bulkDedicatedSender, error) { 
+ if b == nil { + return nil, nil, io.ErrClosedPipe + } + if conn == nil { + return nil, nil, net.ErrClosed + } + b.dedicatedMu.Lock() + oldConn := b.dedicatedConn + oldOwned := b.dedicatedConnOwned + oldSender := b.dedicatedSender + b.dedicatedConn = conn + b.dedicatedConnOwned = false + b.dedicatedSender = nil + b.dedicatedWriteClosed = false + b.dedicatedState = bulkDedicatedAttachStateAttached + b.dedicatedLastCode = "" + ready := b.dedicatedReady + b.dedicatedMu.Unlock() + if ready != nil { + select { + case <-ready: + default: + close(ready) + } + } + if !oldOwned { + oldConn = nil + } + return oldConn, oldSender, nil +} + func (b *bulkHandle) bestEffortCloseDedicatedWriteHalf() { if b == nil || !b.dedicated { return } b.dedicatedMu.Lock() conn := b.dedicatedConn + owned := b.dedicatedConnOwned alreadyClosed := b.dedicatedWriteClosed b.dedicatedMu.Unlock() - if conn == nil || alreadyClosed { + if conn == nil || alreadyClosed || !owned { return } type closeWriter interface { @@ -458,16 +842,40 @@ func (b *bulkHandle) setClientSnapshotOwner(client *ClientCommon) { b.client = client } -func (b *bulkHandle) clearDedicatedConn() net.Conn { +func (b *bulkHandle) clearDedicatedConn() (net.Conn, bool) { if b == nil { - return nil + return nil, false } b.dedicatedMu.Lock() conn := b.dedicatedConn + owned := b.dedicatedConnOwned b.dedicatedConn = nil + b.dedicatedConnOwned = false b.dedicatedWriteClosed = false b.dedicatedMu.Unlock() - return conn + return conn, owned +} + +func (b *bulkHandle) markDedicatedActiveReserved() { + if b == nil { + return + } + b.dedicatedMu.Lock() + b.dedicatedActiveLease = true + b.dedicatedMu.Unlock() +} + +func (b *bulkHandle) releaseDedicatedActiveReserved() bool { + if b == nil { + return false + } + b.dedicatedMu.Lock() + defer b.dedicatedMu.Unlock() + if !b.dedicatedActiveLease { + return false + } + b.dedicatedActiveLease = false + return true } func (b *bulkHandle) markAcceptDispatched() bool { @@ -483,6 +891,135 @@ func (b 
*bulkHandle) markAcceptDispatched() bool { return true } +func (b *bulkHandle) markAcceptHandled() { + if b == nil { + return + } + b.acceptMu.Lock() + b.acceptDispatched = true + b.acceptMu.Unlock() +} + +func (b *bulkHandle) setAcceptNotify(fn func(error)) { + if b == nil { + return + } + b.acceptMu.Lock() + b.acceptNotifyFn = fn + b.acceptMu.Unlock() +} + +func (b *bulkHandle) clearAcceptNotify() { + if b == nil { + return + } + b.acceptMu.Lock() + b.acceptNotifyFn = nil + b.acceptMu.Unlock() +} + +func (b *bulkHandle) markAcceptReady(err error) { + if b == nil { + return + } + b.acceptMu.Lock() + if b.acceptReadyDone { + b.acceptMu.Unlock() + return + } + b.acceptReadyDone = true + b.acceptReadyErr = err + ch := b.acceptReady + b.acceptMu.Unlock() + if ch != nil { + select { + case <-ch: + default: + close(ch) + } + } +} + +func (b *bulkHandle) waitAcceptReady(ctx context.Context) error { + if b == nil { + return io.ErrClosedPipe + } + if ctx == nil { + ctx = context.Background() + } + b.acceptMu.Lock() + if b.acceptReadyDone { + err := b.acceptReadyErr + b.acceptMu.Unlock() + return err + } + ready := b.acceptReady + b.acceptMu.Unlock() + select { + case <-ctx.Done(): + return ctx.Err() + case <-b.Context().Done(): + if err := b.acceptReadyErrorSnapshot(); err != nil { + return err + } + if err := b.resetErrSnapshot(); err != nil { + return err + } + return context.Canceled + case <-ready: + return b.acceptReadyErrorSnapshot() + } +} + +func (b *bulkHandle) acceptReadyErrorSnapshot() error { + if b == nil { + return io.ErrClosedPipe + } + b.acceptMu.Lock() + defer b.acceptMu.Unlock() + return b.acceptReadyErr +} + +func (b *bulkHandle) notifyAcceptStarted() { + if b == nil { + return + } + var notify func(error) + b.acceptMu.Lock() + if b.acceptNotifySent { + b.acceptMu.Unlock() + return + } + notify = b.acceptNotifyFn + b.acceptNotifyFn = nil + if notify == nil { + b.acceptMu.Unlock() + return + } + b.acceptNotifySent = true + b.acceptMu.Unlock() + go 
notify(nil) +} + +func (b *bulkHandle) finishAcceptDispatch(err error) { + if b == nil { + return + } + var notify func(error) + b.acceptMu.Lock() + if b.acceptNotifySent { + b.acceptMu.Unlock() + return + } + notify = b.acceptNotifyFn + b.acceptNotifyFn = nil + b.acceptNotifySent = true + b.acceptMu.Unlock() + if notify != nil { + go notify(err) + } +} + func (b *bulkHandle) SessionEpoch() uint64 { if b == nil { return 0 @@ -518,13 +1055,18 @@ func (b *bulkHandle) dataIDSnapshot() uint64 { } func (b *bulkHandle) nextOutboundDataSeq() uint64 { - if b == nil { + return b.reserveOutboundDataSeqs(1) +} + +func (b *bulkHandle) reserveOutboundDataSeqs(count int) uint64 { + if b == nil || count <= 0 { return 0 } b.mu.Lock() defer b.mu.Unlock() - b.outboundSeq++ - return b.outboundSeq + start := b.outboundSeq + 1 + b.outboundSeq += uint64(count) + return start } func (b *bulkHandle) Read(p []byte) (int, error) { @@ -534,16 +1076,20 @@ func (b *bulkHandle) Read(p []byte) (int, error) { if b == nil { return 0, io.ErrClosedPipe } + b.notifyAcceptStarted() for { b.mu.Lock() localReadClosed := b.localReadClosed - if len(b.readBuf) > 0 { - n := copy(p, b.readBuf) - b.readBuf = b.readBuf[n:] + if len(b.readBuf.data) > 0 { + n := copy(p, b.readBuf.data) + b.readBuf.data = b.readBuf.data[n:] b.bufferedBytes -= n if b.bufferedBytes < 0 { b.bufferedBytes = 0 } + if len(b.readBuf.data) == 0 { + b.readBuf.clear() + } b.recordReadLocked(n, time.Now()) b.mu.Unlock() b.maybeSendWindowRelease(n, false) @@ -551,7 +1097,7 @@ func (b *bulkHandle) Read(p []byte) (int, error) { } if len(b.readQueue) > 0 { b.readBuf = b.readQueue[0] - b.readQueue[0] = nil + b.readQueue[0] = bulkReadChunk{} b.readQueue = b.readQueue[1:] b.mu.Unlock() continue @@ -588,6 +1134,7 @@ func (b *bulkHandle) Write(p []byte) (int, error) { if b == nil { return 0, io.ErrClosedPipe } + b.notifyAcceptStarted() b.writeMu.Lock() defer b.writeMu.Unlock() b.mu.Lock() @@ -609,55 +1156,53 @@ func (b *bulkHandle) Write(p []byte) 
(int, error) { if sendDataFn == nil { return 0, errBulkDataPathNotReady } - if b.dedicated && sendWriteFn != nil { + if sendWriteFn != nil { written := 0 for written < len(p) { end := len(p) if b.windowBytes > 0 && end-written > b.windowBytes { end = written + b.windowBytes } - part := p[written:end] - sendCtx, cancel, err := bulkWriteContext(bulkCtx, writeTimeout) - if err != nil { - if written > 0 { - b.recordWrite(written, time.Now()) + if chunkSize > 0 && b.maxInFlight > 0 { + maxPartBytes := chunkSize * b.maxInFlight + if maxPartBytes > 0 && end-written > maxPartBytes { + end = written + maxPartBytes } + } + part := p[written:end] + partChunks := bulkPayloadChunkCount(len(part), chunkSize) + sendCtx, cancel, err := b.newWriteContext(bulkCtx, writeTimeout) + if err != nil { return written, err } - if err := b.acquireOutboundWindow(sendCtx, len(part)); err != nil { + if err := b.acquireOutboundWindow(sendCtx, len(part), partChunks); err != nil { cancel() - if written > 0 { - b.recordWrite(written, time.Now()) - } return written, b.normalizeWriteError(err) } - partWritten, err := sendWriteFn(sendCtx, b, part) + startSeq := b.reserveOutboundDataSeqs(partChunks) + if b.dedicated { + partWritten, err := b.executeSendWrite(sendCtx, startSeq, part, partChunks, false) + cancel() + written += partWritten + if err != nil { + return written, err + } + continue + } + owned := getBulkAsyncWritePayload(len(part)) + copy(owned, part) + err = b.enqueueAsyncWrite(sendCtx, bulkAsyncWriteRequest{ + startSeq: startSeq, + payload: owned, + chunks: partChunks, + }) cancel() - if partWritten < 0 { - partWritten = 0 - } - if partWritten > len(part) { - partWritten = len(part) - } - if partWritten < len(part) { - b.rollbackOutboundWindow(len(part) - partWritten) - } - written += partWritten if err != nil { - if written > 0 { - b.recordWrite(written, time.Now()) - } + putBulkAsyncWritePayload(owned) + b.rollbackOutboundWindow(len(part), partChunks) return written, 
b.normalizeWriteError(err) } - if partWritten != len(part) { - if written > 0 { - b.recordWrite(written, time.Now()) - } - return written, io.ErrShortWrite - } - } - if written > 0 { - b.recordWrite(written, time.Now()) + written += len(part) } return written, nil } @@ -671,14 +1216,14 @@ func (b *bulkHandle) Write(p []byte) (int, error) { end = len(p) } chunk := p[written:end] - sendCtx, cancel, err := bulkWriteContext(bulkCtx, writeTimeout) + sendCtx, cancel, err := b.newWriteContext(bulkCtx, writeTimeout) if err != nil { if written > 0 { b.recordWrite(written, time.Now()) } return written, err } - if err := b.acquireOutboundWindow(sendCtx, len(chunk)); err != nil { + if err := b.acquireOutboundWindow(sendCtx, len(chunk), 1); err != nil { cancel() if written > 0 { b.recordWrite(written, time.Now()) @@ -688,7 +1233,7 @@ func (b *bulkHandle) Write(p []byte) (int, error) { err = sendDataFn(sendCtx, b, chunk) cancel() if err != nil { - b.rollbackOutboundWindow(len(chunk)) + b.rollbackOutboundWindow(len(chunk), 1) if written > 0 { b.recordWrite(written, time.Now()) } @@ -714,6 +1259,7 @@ func (b *bulkHandle) close(full bool) error { if b == nil { return nil } + b.notifyAcceptStarted() b.writeMu.Lock() defer b.writeMu.Unlock() b.mu.Lock() @@ -742,6 +1288,7 @@ func (b *bulkHandle) close(full bool) error { } b.localReadClosed = true b.clearBufferedDataLocked() + b.closeWriteStateLocked() shouldFinalize := b.shouldFinalizeLocked() b.mu.Unlock() b.notifyReadable() @@ -752,6 +1299,9 @@ func (b *bulkHandle) close(full bool) error { } closeFn := b.closeFn b.mu.Unlock() + if err := b.waitPendingAsyncWrites(context.Background()); err != nil { + return err + } if closeFn != nil { if err := closeFn(context.Background(), b, full); err != nil && !errors.Is(err, errBulkNotFound) && !b.canIgnoreDedicatedCloseSendError(err) { return err @@ -764,6 +1314,7 @@ func (b *bulkHandle) close(full bool) error { return nil } b.localClosed = true + b.closeWriteStateLocked() if full { 
b.localReadClosed = true b.clearBufferedDataLocked() @@ -783,6 +1334,7 @@ func (b *bulkHandle) Reset(err error) error { if b == nil { return nil } + b.notifyAcceptStarted() resetErr := bulkResetError(err) b.mu.Lock() if b.resetErr != nil { @@ -826,6 +1378,7 @@ func (b *bulkHandle) markPeerClosed() { b.mu.Lock() b.remoteClosed = true b.peerReadClosed = true + b.closeWriteStateLocked() shouldFinalize := b.shouldFinalizeLocked() b.notifyFlowLocked() b.mu.Unlock() @@ -839,13 +1392,16 @@ func (b *bulkHandle) markReset(err error) { if b == nil { return } + resetErr := bulkResetError(err) b.mu.Lock() if b.resetErr == nil { - b.resetErr = bulkResetError(err) + b.resetErr = resetErr b.clearBufferedDataLocked() + b.closeWriteStateLocked() } b.notifyFlowLocked() b.mu.Unlock() + b.markAcceptReady(resetErr) b.notifyReadable() b.finalize() } @@ -862,51 +1418,71 @@ func (b *bulkHandle) pushOwnedChunkNoReset(chunk []byte) error { return b.pushChunkWithOwnershipOptions(chunk, true, false) } +func (b *bulkHandle) pushOwnedChunkWithReleaseNoReset(chunk []byte, release func()) error { + return b.pushChunkWithOwnershipOptionsAndRelease(chunk, true, false, release) +} + func (b *bulkHandle) pushChunkWithOwnership(chunk []byte, owned bool) error { return b.pushChunkWithOwnershipOptions(chunk, owned, true) } func (b *bulkHandle) pushChunkWithOwnershipOptions(chunk []byte, owned bool, resetOnOverflow bool) error { + return b.pushChunkWithOwnershipOptionsAndRelease(chunk, owned, resetOnOverflow, nil) +} + +func (b *bulkHandle) pushChunkWithOwnershipOptionsAndRelease(chunk []byte, owned bool, resetOnOverflow bool, release func()) error { if b == nil { return io.ErrClosedPipe } if len(chunk) == 0 { + if release != nil { + release() + } return nil } - stored := chunk + stored := bulkReadChunk{data: chunk, release: release} if !owned { - stored = append([]byte(nil), chunk...) + stored.data = append([]byte(nil), chunk...) 
+ if stored.release != nil { + stored.release() + stored.release = nil + } } b.mu.Lock() if b.resetErr != nil { err := b.resetErr b.mu.Unlock() + stored.clear() return err } if b.inboundQueueLimit > 0 && b.bufferedChunkCountLocked() >= b.inboundQueueLimit { if !resetOnOverflow { b.mu.Unlock() + stored.clear() return errBulkBackpressureExceeded } err := b.markResetLocked(errBulkBackpressureExceeded) b.mu.Unlock() + stored.clear() b.notifyReadable() b.finalize() return err } - if b.inboundBytesLimit > 0 && b.bufferedBytes+len(stored) > b.inboundBytesLimit { + if b.inboundBytesLimit > 0 && b.bufferedBytes+len(stored.data) > b.inboundBytesLimit { if !resetOnOverflow { b.mu.Unlock() + stored.clear() return errBulkBackpressureExceeded } err := b.markResetLocked(errBulkBackpressureExceeded) b.mu.Unlock() + stored.clear() b.notifyReadable() b.finalize() return err } b.readQueue = append(b.readQueue, stored) - b.bufferedBytes += len(stored) + b.bufferedBytes += len(stored.data) b.notifyReadableLocked() b.mu.Unlock() return nil @@ -919,6 +1495,7 @@ func (b *bulkHandle) markResetLocked(err error) error { if b.resetErr == nil { b.resetErr = bulkResetError(err) b.clearBufferedDataLocked() + b.closeWriteStateLocked() } return b.resetErr } @@ -927,11 +1504,12 @@ func (b *bulkHandle) clearBufferedDataLocked() { if b == nil { return } + b.readBuf.clear() for i := range b.readQueue { - b.readQueue[i] = nil + b.readQueue[i].clear() } b.readQueue = nil - b.readBuf = nil + b.readBuf = bulkReadChunk{} b.bufferedBytes = 0 } @@ -963,11 +1541,6 @@ func (b *bulkHandle) maybeSendWindowRelease(consumed int, force bool) { if b == nil || !b.flowControlEnabled() { return } - var ( - bytes int64 - chunks int - release bulkReleaseSender - ) b.mu.Lock() if consumed > 0 { b.pendingReleaseBytes += int64(consumed) @@ -977,18 +1550,56 @@ func (b *bulkHandle) maybeSendWindowRelease(consumed int, force bool) { b.mu.Unlock() return } - bytes = b.pendingReleaseBytes - chunks = b.pendingReleaseChunks - 
release = b.releaseFn - b.pendingReleaseBytes = 0 - b.pendingReleaseChunks = 0 b.mu.Unlock() - if release != nil && (bytes > 0 || chunks > 0) { - _ = release(b, bytes, chunks) + b.scheduleWindowRelease() +} + +func (b *bulkHandle) scheduleWindowRelease() { + if b == nil || b.releaseNotify == nil { + return + } + select { + case b.releaseNotify <- struct{}{}: + default: } } -func (b *bulkHandle) acquireOutboundWindow(ctx context.Context, size int) error { +func (b *bulkHandle) takePendingWindowRelease() (int64, int, bulkReleaseSender) { + if b == nil { + return 0, 0, nil + } + b.mu.Lock() + defer b.mu.Unlock() + bytes := b.pendingReleaseBytes + chunks := b.pendingReleaseChunks + release := b.releaseFn + b.pendingReleaseBytes = 0 + b.pendingReleaseChunks = 0 + return bytes, chunks, release +} + +func (b *bulkHandle) runWindowReleaseLoop() { + if b == nil { + return + } + defer close(b.releaseWorkerDone) + for { + select { + case <-b.Context().Done(): + return + case <-b.releaseNotify: + } + for { + bytes, chunks, release := b.takePendingWindowRelease() + if release == nil || (bytes <= 0 && chunks <= 0) { + break + } + _ = release(b, bytes, chunks) + } + } +} + +func (b *bulkHandle) acquireOutboundWindow(ctx context.Context, size int, chunks int) error { if b == nil || size <= 0 || !b.flowControlEnabled() { return nil } @@ -996,6 +1607,9 @@ func (b *bulkHandle) acquireOutboundWindow(ctx context.Context, size int) error ctx = context.Background() } need := int64(size) + if chunks <= 0 { + chunks = 1 + } for { b.mu.Lock() if b.resetErr != nil { @@ -1016,14 +1630,14 @@ func (b *bulkHandle) acquireOutboundWindow(ctx context.Context, size int) error } chunksOK := true if b.maxInFlight > 0 { - chunksOK = b.outboundInFlight < b.maxInFlight + chunksOK = b.outboundInFlight+chunks <= b.maxInFlight } if bytesOK && chunksOK { if b.windowBytes > 0 { b.outboundAvailBytes -= need } if b.maxInFlight > 0 { - b.outboundInFlight++ + b.outboundInFlight += chunks } b.mu.Unlock() return 
nil @@ -1041,10 +1655,13 @@ func (b *bulkHandle) acquireOutboundWindow(ctx context.Context, size int) error } } -func (b *bulkHandle) rollbackOutboundWindow(size int) { +func (b *bulkHandle) rollbackOutboundWindow(size int, chunks int) { if b == nil || size <= 0 || !b.flowControlEnabled() { return } + if chunks <= 0 { + chunks = 1 + } b.mu.Lock() if b.windowBytes > 0 { b.outboundAvailBytes += int64(size) @@ -1054,7 +1671,10 @@ func (b *bulkHandle) rollbackOutboundWindow(size int) { } } if b.maxInFlight > 0 && b.outboundInFlight > 0 { - b.outboundInFlight-- + b.outboundInFlight -= chunks + if b.outboundInFlight < 0 { + b.outboundInFlight = 0 + } } b.notifyFlowLocked() b.mu.Unlock() @@ -1087,7 +1707,7 @@ func (b *bulkHandle) bufferedChunkCountLocked() int { return 0 } count := len(b.readQueue) - if len(b.readBuf) > 0 { + if len(b.readBuf.data) > 0 { count++ } return count @@ -1101,7 +1721,7 @@ func (b *bulkHandle) shouldFinalizeLocked() bool { return true } if b.dedicated { - return (b.peerReadClosed && b.remoteClosed) || (b.localClosed && b.remoteClosed) + return b.localClosed && b.remoteClosed } return b.localReadClosed || (b.peerReadClosed && b.remoteClosed) || (b.localClosed && b.remoteClosed) } @@ -1111,36 +1731,45 @@ func (b *bulkHandle) snapshot() BulkSnapshot { return BulkSnapshot{} } dedicatedAttached := b.dedicatedAttachedSnapshot() + dedicatedState, dedicatedAttempts, dedicatedLastCode, dedicatedDataStarted := b.dedicatedAttachDiagnosticsSnapshot() b.mu.Lock() defer b.mu.Unlock() snapshot := BulkSnapshot{ - ID: b.id, - DataID: b.dataID, - Scope: normalizeFileScope(b.runtimeScope), - Range: b.rangeSpec, - Metadata: cloneBulkMetadata(b.metadata), - Dedicated: b.dedicated, - DedicatedAttached: dedicatedAttached, - SessionEpoch: b.sessionEpoch, - TransportGeneration: b.transportGeneration, - LocalClosed: b.localClosed, - LocalReadClosed: b.localReadClosed, - RemoteClosed: b.remoteClosed, - PeerReadClosed: b.peerReadClosed, - BufferedChunks: 
b.bufferedChunkCountLocked(), - BufferedBytes: b.bufferedBytes, - ReadTimeout: b.readTimeout, - WriteTimeout: b.writeTimeout, - ChunkSize: b.chunkSize, - WindowBytes: b.windowBytes, - MaxInFlight: b.maxInFlight, - BytesRead: b.bytesRead, - BytesWritten: b.bytesWritten, - ReadCalls: b.readCalls, - WriteCalls: b.writeCalls, - OpenedAt: b.createdAt, - LastReadAt: b.lastReadAt, - LastWriteAt: b.lastWriteAt, + ID: b.id, + DataID: b.dataID, + FastPathVersion: normalizeBulkFastPathVersion(b.fastPathVersion), + Scope: normalizeFileScope(b.runtimeScope), + Range: b.rangeSpec, + Metadata: cloneBulkMetadata(b.metadata), + Dedicated: b.dedicated, + DedicatedLaneID: b.dedicatedLaneID, + DedicatedAttached: dedicatedAttached, + DedicatedAttachState: bulkDedicatedAttachStateName( + dedicatedState, + ), + DedicatedAttachAttempts: dedicatedAttempts, + DedicatedAttachLastCode: dedicatedLastCode, + DedicatedDataStarted: dedicatedDataStarted, + SessionEpoch: b.sessionEpoch, + TransportGeneration: b.transportGeneration, + LocalClosed: b.localClosed, + LocalReadClosed: b.localReadClosed, + RemoteClosed: b.remoteClosed, + PeerReadClosed: b.peerReadClosed, + BufferedChunks: b.bufferedChunkCountLocked(), + BufferedBytes: b.bufferedBytes, + ReadTimeout: b.readTimeout, + WriteTimeout: b.writeTimeout, + ChunkSize: b.chunkSize, + WindowBytes: b.windowBytes, + MaxInFlight: b.maxInFlight, + BytesRead: b.bytesRead, + BytesWritten: b.bytesWritten, + ReadCalls: b.readCalls, + WriteCalls: b.writeCalls, + OpenedAt: b.createdAt, + LastReadAt: b.lastReadAt, + LastWriteAt: b.lastWriteAt, } if b.logical != nil { snapshot.LogicalClientID = b.logical.ID() @@ -1160,6 +1789,7 @@ func (b *bulkHandle) snapshot() BulkSnapshot { snapshot.BindingCurrent = diag.BindingCurrent snapshot.BindingReason = diag.BindingReason snapshot.BindingError = diag.BindingError + snapshot.BindingBulkAdaptiveSoftPayloadBytes = diag.BindingBulkAdaptiveSoftPayloadBytes snapshot.TransportAttached = diag.TransportAttached 
snapshot.TransportHasRuntimeConn = diag.TransportHasRuntimeConn snapshot.TransportCurrent = diag.TransportCurrent @@ -1176,14 +1806,24 @@ func (b *bulkHandle) finalize() { if b == nil { return } + b.markDedicatedAttachClosed() b.maybeSendWindowRelease(0, true) if b.cancel != nil { b.cancel() } + if b.writeCtxCancel != nil { + b.writeCtxCancel() + } if sender := b.clearDedicatedSender(); sender != nil { sender.stop() } - if conn := b.clearDedicatedConn(); conn != nil { + if b.client != nil && b.releaseDedicatedActiveReserved() { + b.client.releaseBulkDedicatedActiveSlot() + } + if b.client != nil { + b.client.releaseBulkDedicatedLane(b.dedicatedLaneIDSnapshot()) + } + if conn, owned := b.clearDedicatedConn(); conn != nil && owned { _ = conn.Close() } if b.runtime != nil { @@ -1337,6 +1977,252 @@ func (b *bulkHandle) notifyFlowLocked() { } } +func (b *bulkHandle) closeWriteStateLocked() { + if b == nil || b.writeStateClosed { + return + } + b.writeStateClosed = true + if b.writeStateDone != nil { + close(b.writeStateDone) + } +} + +func bulkAsyncWriteQueueSize(maxInFlight int) int { + if maxInFlight <= 0 { + maxInFlight = defaultBulkOpenMaxInFlight + } + if maxInFlight < 8 { + return 8 + } + if maxInFlight > 128 { + return 128 + } + return maxInFlight +} + +func getBulkAsyncWritePayload(size int) []byte { + if size <= 0 { + return nil + } + if pooled, ok := bulkAsyncWritePayloadPool.Get().([]byte); ok && cap(pooled) >= size { + return pooled[:size] + } + return make([]byte, size) +} + +func putBulkAsyncWritePayload(buf []byte) { + if cap(buf) == 0 || cap(buf) > 8*1024*1024 { + return + } + bulkAsyncWritePayloadPool.Put(buf[:0]) +} + +func (b *bulkHandle) beginPendingAsyncWriteLocked() { + if b == nil { + return + } + if b.pendingAsyncWrites == 0 { + b.writeDrain = make(chan struct{}) + } + b.pendingAsyncWrites++ +} + +func (b *bulkHandle) finishPendingAsyncWrite() { + if b == nil { + return + } + b.mu.Lock() + if b.pendingAsyncWrites > 0 { + b.pendingAsyncWrites-- + 
if b.pendingAsyncWrites == 0 && b.writeDrain != nil { + close(b.writeDrain) + } + } + b.mu.Unlock() +} + +func (b *bulkHandle) waitPendingAsyncWrites(ctx context.Context) error { + if b == nil { + return io.ErrClosedPipe + } + if ctx == nil { + ctx = context.Background() + } + b.mu.Lock() + if b.pendingAsyncWrites == 0 { + b.mu.Unlock() + return nil + } + drain := b.writeDrain + b.mu.Unlock() + select { + case <-ctx.Done(): + if err := b.writeStateErrorSnapshot(); err != nil { + return err + } + return normalizeStreamDeadlineError(ctx.Err()) + case <-b.Context().Done(): + if err := b.writeStateErrorSnapshot(); err != nil { + return err + } + return context.Canceled + case <-drain: + return b.writeStateErrorSnapshot() + } +} + +func (b *bulkHandle) enqueueAsyncWrite(ctx context.Context, req bulkAsyncWriteRequest) error { + if b == nil { + return io.ErrClosedPipe + } + if ctx == nil { + ctx = context.Background() + } + b.mu.Lock() + if b.resetErr != nil { + err := b.resetErr + b.mu.Unlock() + return err + } + if b.localClosed || b.peerReadClosed { + b.mu.Unlock() + return io.ErrClosedPipe + } + queue := b.writeQueue + if queue == nil { + b.mu.Unlock() + return errBulkDataPathNotReady + } + b.beginPendingAsyncWriteLocked() + b.mu.Unlock() + select { + case queue <- req: + return nil + case <-ctx.Done(): + b.finishPendingAsyncWrite() + return normalizeStreamDeadlineError(ctx.Err()) + case <-b.Context().Done(): + b.finishPendingAsyncWrite() + if err := b.writeStateErrorSnapshot(); err != nil { + return err + } + return context.Canceled + } +} + +func (b *bulkHandle) runAsyncWriteLoop() { + if b == nil { + return + } + defer close(b.writeWorkerDone) + for { + select { + case <-b.Context().Done(): + b.drainPendingAsyncWrites() + return + case req := <-b.writeQueue: + b.processAsyncWrite(req) + } + } +} + +func (b *bulkHandle) drainPendingAsyncWrites() { + if b == nil || b.writeQueue == nil { + return + } + for { + select { + case req := <-b.writeQueue: + 
b.rollbackOutboundWindow(len(req.payload), req.chunks) + putBulkAsyncWritePayload(req.payload) + b.finishPendingAsyncWrite() + default: + return + } + } +} + +func (b *bulkHandle) processAsyncWrite(req bulkAsyncWriteRequest) { + if b == nil { + return + } + defer putBulkAsyncWritePayload(req.payload) + defer b.finishPendingAsyncWrite() + if len(req.payload) == 0 { + return + } + if err := b.writeStateErrorSnapshot(); err != nil { + b.rollbackOutboundWindow(len(req.payload), req.chunks) + return + } + b.notifyAcceptStarted() + b.mu.Lock() + writeTimeout := b.writeTimeout + bulkCtx := b.ctx + b.mu.Unlock() + sendCtx, cancel, err := b.newWriteContext(bulkCtx, writeTimeout) + if err != nil { + b.rollbackOutboundWindow(len(req.payload), req.chunks) + if stateErr := b.writeStateErrorSnapshot(); stateErr == nil { + b.markReset(err) + } + return + } + _, writeErr := b.executeSendWrite(sendCtx, req.startSeq, req.payload, req.chunks, true) + cancel() + if writeErr != nil { + b.markReset(writeErr) + } +} + +func (b *bulkHandle) executeSendWrite(ctx context.Context, startSeq uint64, payload []byte, chunks int, payloadOwned bool) (int, error) { + if b == nil { + return 0, io.ErrClosedPipe + } + if len(payload) == 0 { + return 0, nil + } + if chunks <= 0 { + chunks = 1 + } + if err := b.writeStateErrorSnapshot(); err != nil { + b.rollbackOutboundWindow(len(payload), chunks) + return 0, err + } + b.mu.Lock() + sendWriteFn := b.sendWriteFn + chunkSize := b.chunkSize + b.mu.Unlock() + if sendWriteFn == nil { + b.rollbackOutboundWindow(len(payload), chunks) + return 0, errBulkDataPathNotReady + } + written, err := sendWriteFn(ctx, b, startSeq, payload, payloadOwned) + if written < 0 { + written = 0 + } + if written > len(payload) { + written = len(payload) + } + if written > 0 { + b.recordWrite(written, time.Now()) + } + if written < len(payload) { + remaining := len(payload) - written + b.rollbackOutboundWindow(remaining, bulkPayloadChunkCount(remaining, chunkSize)) + } + if err != 
nil { + if b.canIgnoreDedicatedCloseSendError(err) { + return written, nil + } + return written, b.normalizeWriteError(err) + } + if written != len(payload) { + return written, io.ErrShortWrite + } + return written, nil +} + func (b *bulkHandle) normalizeWriteError(err error) error { if err == nil { return nil @@ -1347,6 +2233,34 @@ func (b *bulkHandle) normalizeWriteError(err error) error { return normalizeStreamDeadlineError(err) } +func (b *bulkHandle) writeStateDoneSnapshot() <-chan struct{} { + if b == nil { + return nil + } + b.mu.Lock() + defer b.mu.Unlock() + return b.writeStateDone +} + +func (b *bulkHandle) newWriteContext(parent context.Context, timeout time.Duration) (context.Context, func(), error) { + baseParent := parent + if b != nil && parent == b.ctx && b.writeCtx != nil { + baseParent = b.writeCtx + } + ctx, cancel, err := bulkWriteContext(baseParent, timeout) + if err != nil { + return nil, func() {}, err + } + if b == nil { + return ctx, cancel, nil + } + if stateErr := b.writeStateErrorSnapshot(); stateErr != nil { + cancel() + return nil, func() {}, stateErr + } + return ctx, cancel, nil +} + func (b *bulkHandle) canIgnoreDedicatedCloseSendError(err error) bool { if b == nil || !b.dedicated || err == nil { return false @@ -1372,8 +2286,7 @@ func bulkWriteContext(parent context.Context, timeout time.Duration) (context.Co return nil, func() {}, normalizeStreamDeadlineError(context.DeadlineExceeded) } if deadline.IsZero() { - ctx, cancel := context.WithCancel(parent) - return ctx, cancel, nil + return parent, func() {}, nil } ctx, cancel := context.WithDeadline(parent, deadline) return ctx, cancel, nil @@ -1382,6 +2295,10 @@ func bulkWriteContext(parent context.Context, timeout time.Duration) (context.Co func normalizeBulkOpenRequest(req BulkOpenRequest) BulkOpenRequest { req.Range = normalizeBulkRange(req.Range) req.Metadata = cloneBulkMetadata(req.Metadata) + req.FastPathVersion = normalizeBulkFastPathVersion(req.FastPathVersion) + if 
req.Dedicated && req.DedicatedLaneID == 0 { + req.DedicatedLaneID = 1 + } if req.ChunkSize <= 0 { req.ChunkSize = defaultBulkChunkSize } @@ -1401,27 +2318,91 @@ func normalizeBulkOpenRequest(req BulkOpenRequest) BulkOpenRequest { } func normalizeBulkOpenOptions(opt BulkOpenOptions) BulkOpenOptions { - req := normalizeBulkOpenRequest(BulkOpenRequest{ - BulkID: opt.ID, - Range: opt.Range, - Metadata: opt.Metadata, - ReadTimeout: opt.ReadTimeout, - WriteTimeout: opt.WriteTimeout, - Dedicated: opt.Dedicated, + mode := normalizeBulkOpenMode(opt.Mode) + switch mode { + case BulkOpenModeDefault: + // Preserve legacy behavior when Mode is not explicitly set. + if opt.Dedicated { + mode = BulkOpenModeDedicated + } else { + mode = BulkOpenModeShared + } + } + readTimeout := opt.ReadTimeout + if readTimeout < 0 { + readTimeout = defaultBulkControlReadTimeout + } + writeTimeout := opt.WriteTimeout + if writeTimeout < 0 { + writeTimeout = defaultBulkControlWriteTimeout + } + return BulkOpenOptions{ + ID: opt.ID, + Range: normalizeBulkRange(opt.Range), + Metadata: cloneBulkMetadata(opt.Metadata), + ReadTimeout: readTimeout, + WriteTimeout: writeTimeout, + Mode: mode, + Dedicated: mode == BulkOpenModeDedicated, ChunkSize: opt.ChunkSize, WindowBytes: opt.WindowBytes, MaxInFlight: opt.MaxInFlight, - }) - return BulkOpenOptions{ - ID: req.BulkID, - Range: req.Range, - Metadata: req.Metadata, - ReadTimeout: req.ReadTimeout, - WriteTimeout: req.WriteTimeout, - Dedicated: req.Dedicated, - ChunkSize: req.ChunkSize, - WindowBytes: req.WindowBytes, - MaxInFlight: req.MaxInFlight, + } +} + +func normalizeBulkOpenMode(mode BulkOpenMode) BulkOpenMode { + switch mode { + case BulkOpenModeDefault, BulkOpenModeAuto, BulkOpenModeShared, BulkOpenModeDedicated: + return mode + default: + return BulkOpenModeDefault + } +} + +func initialBulkDedicatedAttachState(dedicated bool) bulkDedicatedAttachState { + if dedicated { + return bulkDedicatedAttachStatePending + } + return 
bulkDedicatedAttachStateAttached +} + +func bulkPayloadChunkCount(payloadLen int, chunkSize int) int { + if payloadLen <= 0 { + return 0 + } + if chunkSize <= 0 { + chunkSize = defaultBulkChunkSize + } + return (payloadLen + chunkSize - 1) / chunkSize +} + +func bulkDedicatedAttachStateName(state bulkDedicatedAttachState) string { + switch state { + case bulkDedicatedAttachStatePending: + return "pending" + case bulkDedicatedAttachStateAttached: + return "attached" + case bulkDedicatedAttachStateDegraded: + return "degraded" + case bulkDedicatedAttachStateClosed: + return "closed" + default: + return "unknown" + } +} + +func bulkOpenModeName(mode BulkOpenMode) string { + switch normalizeBulkOpenMode(mode) { + case BulkOpenModeAuto: + return "auto" + case BulkOpenModeShared: + return "shared" + case BulkOpenModeDedicated: + return "dedicated" + case BulkOpenModeDefault: + fallthrough + default: + return "default" } } diff --git a/bulk_batch_sender.go b/bulk_batch_sender.go index 5b00fef..02e697d 100644 --- a/bulk_batch_sender.go +++ b/bulk_batch_sender.go @@ -9,7 +9,9 @@ import ( ) const ( - bulkBatchMaxPayloads = 16 + bulkBatchMaxPayloads = 64 + bulkBatchMaxPayloadBytes = bulkFastBatchMaxPlainBytes + bulkBatchMaxFlushDelay = 50 * time.Microsecond ) const ( @@ -22,48 +24,161 @@ type bulkBatchRequestState struct { value atomic.Int32 } +type bulkBatchCodec struct { + encodeSingle func(bulkFastFrame) ([]byte, func(), error) + encodeBatch func([]bulkFastFrame) ([]byte, func(), error) +} + type bulkBatchRequest struct { - ctx context.Context - payload []byte - deadline time.Time - done chan error - state *bulkBatchRequestState + ctx context.Context + frames []bulkFastFrame + fastPathVersion uint8 + payloadOwned bool + deadline time.Time + done chan error + state *bulkBatchRequestState + release func() +} + +type bulkBatchEncodedPayload struct { + payload []byte + release func() +} + +func (p *bulkBatchEncodedPayload) done() { + if p == nil || p.release == nil { + return 
+ } + p.release() + p.release = nil } type bulkBatchSender struct { - binding *transportBinding - reqCh chan bulkBatchRequest - stopCh chan struct{} - doneCh chan struct{} + binding *transportBinding + codec bulkBatchCodec + writeTimeoutProvider func() time.Duration + reqCh chan bulkBatchRequest + stopCh chan struct{} + doneCh chan struct{} stopOnce sync.Once + flushMu sync.Mutex + queued atomic.Int64 errMu sync.Mutex err error } -func newBulkBatchSender(binding *transportBinding) *bulkBatchSender { +func newBulkBatchSender(binding *transportBinding, codec bulkBatchCodec, writeTimeoutProvider func() time.Duration) *bulkBatchSender { sender := &bulkBatchSender{ - binding: binding, - reqCh: make(chan bulkBatchRequest, bulkBatchMaxPayloads*4), - stopCh: make(chan struct{}), - doneCh: make(chan struct{}), + binding: binding, + codec: codec, + writeTimeoutProvider: writeTimeoutProvider, + reqCh: make(chan bulkBatchRequest, bulkBatchMaxPayloads*4), + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), } go sender.run() return sender } -func (s *bulkBatchSender) submit(ctx context.Context, payload []byte) error { +func (s *bulkBatchSender) submitData(ctx context.Context, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error { + return s.submitFramesOwned(ctx, []bulkFastFrame{{ + Type: bulkFastPayloadTypeData, + DataID: dataID, + Seq: seq, + Payload: payload, + }}, fastPathVersion, false) +} + +func (s *bulkBatchSender) submitControl(ctx context.Context, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error { + return s.submitFramesOwned(ctx, []bulkFastFrame{{ + Type: frameType, + Flags: flags, + DataID: dataID, + Seq: seq, + Payload: payload, + }}, fastPathVersion, false) +} + +func (s *bulkBatchSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, fastPathVersion uint8, payload []byte, chunkSize int, payloadOwned bool) (int, error) { + if s == nil { + return 0, 
errTransportDetached + } + if len(payload) == 0 { + return 0, nil + } + if chunkSize <= 0 { + chunkSize = defaultBulkChunkSize + } + written := 0 + seq := startSeq + for written < len(payload) { + var batch [bulkFastBatchMaxItems]bulkFastFrame + frames := batch[:0] + batchBytes := bulkFastBatchHeaderLen + start := written + for written < len(payload) && len(frames) < bulkFastBatchMaxItems { + end := written + chunkSize + if end > len(payload) { + end = len(payload) + } + frame := bulkFastFrame{ + Type: bulkFastPayloadTypeData, + DataID: dataID, + Seq: seq, + Payload: payload[written:end], + } + frameLen := bulkFastBatchFrameLen(frame) + if len(frames) > 0 && batchBytes+frameLen > bulkFastBatchMaxPlainBytes { + break + } + frames = append(frames, frame) + batchBytes += frameLen + seq++ + written = end + } + if len(frames) == 0 { + end := written + chunkSize + if end > len(payload) { + end = len(payload) + } + frames = append(frames, bulkFastFrame{ + Type: bulkFastPayloadTypeData, + DataID: dataID, + Seq: seq, + Payload: payload[written:end], + }) + seq++ + written = end + } + if err := s.submitFramesOwned(ctx, frames, fastPathVersion, payloadOwned); err != nil { + return start, err + } + } + return written, nil +} + +func (s *bulkBatchSender) submitFrames(ctx context.Context, frames []bulkFastFrame, fastPathVersion uint8) error { + return s.submitFramesOwned(ctx, frames, fastPathVersion, false) +} + +func (s *bulkBatchSender) submitFramesOwned(ctx context.Context, frames []bulkFastFrame, fastPathVersion uint8, payloadOwned bool) error { if s == nil { return errTransportDetached } if ctx == nil { ctx = context.Background() } + if len(frames) == 0 { + return nil + } req := bulkBatchRequest{ - ctx: ctx, - payload: payload, - done: make(chan error, 1), - state: &bulkBatchRequestState{}, + ctx: ctx, + frames: frames, + fastPathVersion: normalizeBulkFastPathVersion(fastPathVersion), + payloadOwned: payloadOwned, + done: make(chan error, 1), + state: 
&bulkBatchRequestState{}, } if deadline, ok := ctx.Deadline(); ok { req.deadline = deadline @@ -71,10 +186,25 @@ func (s *bulkBatchSender) submit(ctx context.Context, payload []byte) error { if err := s.errSnapshot(); err != nil { return err } + if s.shouldDirectSubmit(req) { + if submitted, err := s.tryDirectSubmit(req); submitted { + return err + } + } + req = cloneQueuedBulkBatchRequest(req) + s.queued.Add(1) select { case <-ctx.Done(): + s.queued.Add(-1) + if req.release != nil { + req.release() + } return normalizeStreamDeadlineError(ctx.Err()) case <-s.stopCh: + s.queued.Add(-1) + if req.release != nil { + req.release() + } return s.stoppedErr() case s.reqCh <- req: } @@ -89,6 +219,55 @@ func (s *bulkBatchSender) submit(ctx context.Context, payload []byte) error { } } +func (s *bulkBatchSender) shouldDirectSubmit(req bulkBatchRequest) bool { + if len(req.frames) == 0 { + return false + } + return !bulkBatchRequestSupportsSharedSuperBatch(req) +} + +func (s *bulkBatchSender) tryDirectSubmit(req bulkBatchRequest) (bool, error) { + if s == nil { + return true, errTransportDetached + } + if err := s.errSnapshot(); err != nil { + return true, err + } + select { + case <-req.ctx.Done(): + return true, normalizeStreamDeadlineError(req.ctx.Err()) + case <-s.stopCh: + return true, s.stoppedErr() + default: + } + if s.queued.Load() != 0 { + return false, nil + } + if !s.flushMu.TryLock() { + return false, nil + } + defer s.flushMu.Unlock() + if s.queued.Load() != 0 { + return false, nil + } + if err := s.errSnapshot(); err != nil { + return true, err + } + if !req.tryStart() { + return true, req.canceledErr() + } + if err := req.contextErr(); err != nil { + return true, err + } + err := s.flush([]bulkBatchRequest{req}) + if err != nil { + s.setErr(err) + s.failPending(err) + return true, err + } + return true, nil +} + func (s *bulkBatchSender) run() { defer close(s.doneCh) for { @@ -97,34 +276,83 @@ func (s *bulkBatchSender) run() { return } batch := 
[]bulkBatchRequest{req} + batchBytes := bulkBatchRequestApproxBytes(req) + timer := (*time.Timer)(nil) + timerCh := (<-chan time.Time)(nil) + if bulkBatchShouldWaitForMore(batch, batchBytes) { + timer = time.NewTimer(bulkBatchMaxFlushDelay) + timerCh = timer.C + } drain: - for len(batch) < bulkBatchMaxPayloads { + for len(batch) < bulkBatchMaxPayloads && batchBytes < bulkBatchMaxPayloadBytes { + if timerCh == nil { + select { + case <-s.stopCh: + s.failPending(s.stoppedErr()) + return + case next := <-s.reqCh: + batch = append(batch, next) + batchBytes += bulkBatchRequestApproxBytes(next) + default: + break drain + } + continue + } select { case <-s.stopCh: + if timer != nil { + timer.Stop() + } s.failPending(s.stoppedErr()) return case next := <-s.reqCh: batch = append(batch, next) - default: + batchBytes += bulkBatchRequestApproxBytes(next) + case <-timerCh: + timerCh = nil break drain } } - active, payloads := activeBulkBatchRequests(batch) + if timer != nil { + if !timer.Stop() && timerCh != nil { + select { + case <-timer.C: + default: + } + } + } + s.flushMu.Lock() + err := s.errSnapshot() + active := make([]bulkBatchRequest, 0, len(batch)) + for _, item := range batch { + if !item.tryStart() { + s.finishRequest(item, item.canceledErr()) + continue + } + if itemErr := item.contextErr(); itemErr != nil { + s.finishRequest(item, itemErr) + continue + } + active = append(active, item) + } if len(active) == 0 { + s.flushMu.Unlock() continue } - deadline := bulkBatchRequestsEarliestDeadline(active) - err := s.flush(payloads, deadline) + if err == nil { + err = s.flush(active) + } + s.flushMu.Unlock() if err != nil { s.setErr(err) for _, item := range active { - item.done <- err + s.finishRequest(item, err) } s.failPending(err) return } for _, item := range active { - item.done <- err + s.finishRequest(item, nil) } } } @@ -139,37 +367,6 @@ func (s *bulkBatchSender) nextRequest() (bulkBatchRequest, bool) { } } -func activeBulkBatchRequests(batch []bulkBatchRequest) 
([]bulkBatchRequest, [][]byte) { - active := make([]bulkBatchRequest, 0, len(batch)) - payloads := make([][]byte, 0, len(batch)) - for _, item := range batch { - if !item.tryStart() { - item.done <- item.canceledErr() - continue - } - if err := item.contextErr(); err != nil { - item.done <- err - continue - } - active = append(active, item) - payloads = append(payloads, item.payload) - } - return active, payloads -} - -func bulkBatchRequestsEarliestDeadline(batch []bulkBatchRequest) time.Time { - var deadline time.Time - for _, item := range batch { - if item.deadline.IsZero() { - continue - } - if deadline.IsZero() || item.deadline.Before(deadline) { - deadline = item.deadline - } - } - return deadline -} - func (r bulkBatchRequest) contextErr() error { if r.ctx == nil { return nil @@ -203,7 +400,7 @@ func (r bulkBatchRequest) canceledErr() error { return context.Canceled } -func (s *bulkBatchSender) flush(payloads [][]byte, deadline time.Time) error { +func (s *bulkBatchSender) flush(requests []bulkBatchRequest) error { if s == nil || s.binding == nil { return errTransportDetached } @@ -211,9 +408,196 @@ func (s *bulkBatchSender) flush(payloads [][]byte, deadline time.Time) error { if queue == nil { return errTransportFrameQueueUnavailable } - return s.binding.withConnWriteLockDeadline(deadline, func(conn net.Conn) error { - return writeFramedPayloadBatchUnlocked(conn, queue, payloads) - }) + payloads, err := s.encodeRequests(requests) + if err != nil { + return err + } + defer func() { + for index := range payloads { + payloads[index].done() + } + }() + writeTimeout := s.transportWriteTimeout() + for _, payload := range payloads { + frame := payload.payload + started := time.Now() + err := s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error { + return writeFramedPayloadUnlocked(conn, queue, frame) + }) + s.binding.observeBulkAdaptivePayloadWrite(len(frame), time.Since(started), writeTimeout, err) + if err != nil 
{ + return err + } + } + return nil +} + +func (s *bulkBatchSender) encodeRequests(requests []bulkBatchRequest) ([]bulkBatchEncodedPayload, error) { + if len(requests) == 0 { + return nil, nil + } + payloads := make([]bulkBatchEncodedPayload, 0, len(requests)) + batch := make([]bulkFastFrame, 0, minInt(len(requests), bulkFastBatchMaxItems)) + mixedBatchLimit := s.sharedMixedPayloadLimit() + batchRequestIndex := -1 + batchDataID := uint64(0) + batchMixed := false + flushBatch := func() error { + if len(batch) == 0 { + return nil + } + payload, release, err := s.encodeBatch(batch) + if err != nil { + return err + } + payloads = append(payloads, bulkBatchEncodedPayload{payload: payload, release: release}) + batch = batch[:0] + batchRequestIndex = -1 + batchDataID = 0 + batchMixed = false + return nil + } + batchBytes := bulkFastBatchHeaderLen + for reqIndex, req := range requests { + for _, frame := range req.frames { + if !bulkFastPathSupportsSharedBatch(req.fastPathVersion) { + if err := flushBatch(); err != nil { + return nil, err + } + batchBytes = bulkFastBatchHeaderLen + payload, release, err := s.encodeSingle(frame) + if err != nil { + return nil, err + } + payloads = append(payloads, bulkBatchEncodedPayload{payload: payload, release: release}) + continue + } + frameLen := bulkFastBatchFrameLen(frame) + if frameLen+bulkFastBatchHeaderLen > bulkFastBatchMaxPlainBytes { + if err := flushBatch(); err != nil { + return nil, err + } + batchBytes = bulkFastBatchHeaderLen + payload, release, err := s.encodeSingle(frame) + if err != nil { + return nil, err + } + payloads = append(payloads, bulkBatchEncodedPayload{payload: payload, release: release}) + continue + } + nextMixed := batchMixed + if len(batch) > 0 && (batchRequestIndex != reqIndex || (batchDataID != 0 && batchDataID != frame.DataID)) { + nextMixed = true + } + batchLimit := bulkFastBatchMaxPlainBytes + if nextMixed && mixedBatchLimit > 0 && mixedBatchLimit < batchLimit { + batchLimit = mixedBatchLimit + } + 
if len(batch) > 0 && (len(batch) >= bulkFastBatchMaxItems || batchBytes+frameLen > batchLimit) { + if err := flushBatch(); err != nil { + return nil, err + } + batchBytes = bulkFastBatchHeaderLen + nextMixed = false + } + if len(batch) == 0 { + batchRequestIndex = reqIndex + batchDataID = frame.DataID + batchMixed = false + } else if batchRequestIndex != reqIndex || (batchDataID != 0 && batchDataID != frame.DataID) { + batchMixed = true + } + batch = append(batch, frame) + batchBytes += frameLen + } + } + if err := flushBatch(); err != nil { + return nil, err + } + return payloads, nil +} + +func bulkBatchRequestApproxBytes(req bulkBatchRequest) int { + total := 0 + for _, frame := range req.frames { + total += bulkFastBatchFrameLen(frame) + } + return total +} + +func bulkBatchRequestSupportsSharedSuperBatch(req bulkBatchRequest) bool { + if len(req.frames) == 0 || !bulkFastPathSupportsSharedBatch(req.fastPathVersion) { + return false + } + for _, frame := range req.frames { + switch frame.Type { + case bulkFastPayloadTypeData: + default: + return false + } + } + return true +} + +func bulkBatchShouldWaitForMore(batch []bulkBatchRequest, batchBytes int) bool { + if bulkBatchMaxFlushDelay <= 0 || len(batch) == 0 { + return false + } + if len(batch) >= bulkBatchMaxPayloads || batchBytes >= bulkBatchMaxPayloadBytes { + return false + } + for _, req := range batch { + if !bulkBatchRequestSupportsSharedSuperBatch(req) { + return false + } + } + return true +} + +func cloneQueuedBulkBatchRequest(req bulkBatchRequest) bulkBatchRequest { + if len(req.frames) == 0 || req.payloadOwned { + return req + } + clonedFrames := make([]bulkFastFrame, len(req.frames)) + totalPayload := 0 + for _, frame := range req.frames { + totalPayload += len(frame.Payload) + } + var payloadBuf []byte + if totalPayload > 0 { + payloadBuf = getBulkAsyncWritePayload(totalPayload) + req.release = func() { + putBulkAsyncWritePayload(payloadBuf) + } + } + offset := 0 + for index, frame := range 
req.frames { + clonedFrames[index] = frame + if len(frame.Payload) == 0 { + clonedFrames[index].Payload = nil + continue + } + next := offset + len(frame.Payload) + clonedFrames[index].Payload = payloadBuf[offset:next] + copy(clonedFrames[index].Payload, frame.Payload) + offset = next + } + req.frames = clonedFrames + return req +} + +func (s *bulkBatchSender) encodeSingle(frame bulkFastFrame) ([]byte, func(), error) { + if s == nil || s.codec.encodeSingle == nil { + return nil, nil, errTransportDetached + } + return s.codec.encodeSingle(frame) +} + +func (s *bulkBatchSender) encodeBatch(frames []bulkFastFrame) ([]byte, func(), error) { + if len(frames) == 1 || s.codec.encodeBatch == nil { + return s.encodeSingle(frames[0]) + } + return s.codec.encodeBatch(frames) } func (s *bulkBatchSender) stop() { @@ -231,13 +615,23 @@ func (s *bulkBatchSender) failPending(err error) { for { select { case item := <-s.reqCh: - item.done <- err + s.finishRequest(item, err) default: return } } } +func (s *bulkBatchSender) finishRequest(req bulkBatchRequest, err error) { + if s != nil { + s.queued.Add(-1) + } + if req.release != nil { + req.release() + } + req.done <- err +} + func (s *bulkBatchSender) setErr(err error) { if s == nil || err == nil { return @@ -264,3 +658,31 @@ func (s *bulkBatchSender) stoppedErr() error { } return errTransportDetached } + +func (s *bulkBatchSender) transportWriteDeadline() time.Time { + if s == nil || s.writeTimeoutProvider == nil { + return time.Time{} + } + return writeDeadlineFromTimeout(s.writeTimeoutProvider()) +} + +func (s *bulkBatchSender) transportWriteTimeout() time.Duration { + if s == nil || s.writeTimeoutProvider == nil { + return 0 + } + return s.writeTimeoutProvider() +} + +func (s *bulkBatchSender) sharedMixedPayloadLimit() int { + if s == nil || s.binding == nil { + return bulkAdaptiveSoftPayloadFallbackBytes + } + return s.binding.bulkAdaptiveSoftPayloadBytesSnapshot() +} + +func minInt(a int, b int) int { + if a < b { + return a 
+ } + return b +} diff --git a/bulk_benchmark_test.go b/bulk_benchmark_test.go index 2539500..ad10758 100644 --- a/bulk_benchmark_test.go +++ b/bulk_benchmark_test.go @@ -161,7 +161,7 @@ func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) { return nil }) - if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + if err := server.Listen("tcp", benchmarkTCPListenAddr(b)); err != nil { b.Fatalf("server Listen failed: %v", err) } b.Cleanup(func() { @@ -172,7 +172,7 @@ func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) { if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { b.Fatalf("UseModernPSKClient failed: %v", err) } - if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil { b.Fatalf("client Connect failed: %v", err) } b.Cleanup(func() { @@ -258,7 +258,7 @@ func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurr return nil }) - if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + if err := server.Listen("tcp", benchmarkTCPListenAddr(b)); err != nil { b.Fatalf("server Listen failed: %v", err) } b.Cleanup(func() { @@ -269,7 +269,7 @@ func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurr if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { b.Fatalf("UseModernPSKClient failed: %v", err) } - if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil { b.Fatalf("client Connect failed: %v", err) } b.Cleanup(func() { diff --git a/bulk_buffer_release_test.go b/bulk_buffer_release_test.go new file mode 100644 index 0000000..67c098a --- /dev/null +++ b/bulk_buffer_release_test.go @@ -0,0 +1,119 @@ 
+package notify + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestBulkOwnedChunkReleaseAfterRead(t *testing.T) { + bulk := newBulkHandle(context.Background(), newBulkRuntime("buffer-release-read"), clientFileScope(), BulkOpenRequest{ + BulkID: "buffer-release-read", + DataID: 1, + }, 0, nil, nil, 0, nil, nil, nil, nil, nil) + + released := 0 + if err := bulk.pushOwnedChunkWithReleaseNoReset([]byte("hello"), func() { + released++ + }); err != nil { + t.Fatalf("pushOwnedChunkWithReleaseNoReset failed: %v", err) + } + + buf := make([]byte, 5) + n, err := bulk.Read(buf) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + if n != 5 || string(buf[:n]) != "hello" { + t.Fatalf("Read = %d %q, want 5 hello", n, string(buf[:n])) + } + if released != 1 { + t.Fatalf("release count = %d, want 1", released) + } +} + +func TestBulkOwnedChunkReleaseOnReset(t *testing.T) { + bulk := newBulkHandle(context.Background(), newBulkRuntime("buffer-release-reset"), clientFileScope(), BulkOpenRequest{ + BulkID: "buffer-release-reset", + DataID: 1, + }, 0, nil, nil, 0, nil, nil, nil, nil, nil) + + released := 0 + if err := bulk.pushOwnedChunkWithReleaseNoReset([]byte("hello"), func() { + released++ + }); err != nil { + t.Fatalf("pushOwnedChunkWithReleaseNoReset failed: %v", err) + } + + bulk.markReset(errors.New("boom")) + if released != 1 { + t.Fatalf("release count = %d, want 1", released) + } +} + +func TestBulkReadDoesNotBlockOnAsyncWindowRelease(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + releaseStarted := make(chan struct{}) + releaseUnblock := make(chan struct{}) + bulk := newBulkHandle(ctx, newBulkRuntime("buffer-release-async"), clientFileScope(), BulkOpenRequest{ + BulkID: "buffer-release-async", + DataID: 1, + ChunkSize: 4, + WindowBytes: 4, + MaxInFlight: 1, + }, 0, nil, nil, 0, nil, nil, nil, nil, func(_ *bulkHandle, bytes int64, chunks int) error { + if bytes != 4 || chunks != 1 { + t.Fatalf("release 
= (%d,%d), want (4,1)", bytes, chunks) + } + close(releaseStarted) + <-releaseUnblock + return nil + }) + + if err := bulk.pushOwnedChunk([]byte("ping")); err != nil { + t.Fatalf("pushOwnedChunk failed: %v", err) + } + + buf := make([]byte, 4) + doneCh := make(chan error, 1) + go func() { + n, err := bulk.Read(buf) + if err != nil { + doneCh <- err + return + } + if got, want := n, 4; got != want { + doneCh <- errors.New("unexpected read size") + return + } + doneCh <- nil + }() + + select { + case err := <-doneCh: + if err != nil { + t.Fatalf("Read failed: %v", err) + } + case <-time.After(200 * time.Millisecond): + t.Fatal("Read should not block on async release sender") + } + + select { + case <-releaseStarted: + case <-time.After(time.Second): + t.Fatal("window release sender did not start") + } + + close(releaseUnblock) + cancel() + if bulk.releaseWorkerDone != nil { + select { + case <-bulk.releaseWorkerDone: + case <-time.After(time.Second): + t.Fatal("release worker did not exit") + } + } +} diff --git a/bulk_control.go b/bulk_control.go index 9001e01..61ed606 100644 --- a/bulk_control.go +++ b/bulk_control.go @@ -7,22 +7,25 @@ import ( ) type BulkOpenRequest struct { - BulkID string - DataID uint64 - Range BulkRange - Metadata BulkMetadata - ReadTimeout time.Duration - WriteTimeout time.Duration - Dedicated bool - AttachToken string - ChunkSize int - WindowBytes int - MaxInFlight int + BulkID string + DataID uint64 + FastPathVersion uint8 + Range BulkRange + Metadata BulkMetadata + ReadTimeout time.Duration + WriteTimeout time.Duration + Dedicated bool + DedicatedLaneID uint32 + AttachToken string + ChunkSize int + WindowBytes int + MaxInFlight int } type BulkOpenResponse struct { BulkID string DataID uint64 + FastPathVersion uint8 Accepted bool Dedicated bool AttachToken string @@ -53,6 +56,18 @@ type BulkResetResponse struct { Error string } +type BulkReadyRequest struct { + BulkID string + DataID uint64 + Error string +} + +type BulkReadyResponse struct 
{ + BulkID string + Accepted bool + Error string +} + type BulkReleaseRequest struct { BulkID string DataID uint64 @@ -73,6 +88,9 @@ func bindClientBulkControl(c *ClientCommon) { c.SetLink(BulkResetSignalKey, func(msg *Message) { c.handleInboundBulkReset(msg) }) + c.SetLink(BulkReadySignalKey, func(msg *Message) { + c.handleInboundBulkReady(msg) + }) c.SetLink(BulkReleaseSignalKey, func(msg *Message) { c.handleInboundBulkRelease(msg) }) @@ -91,14 +109,188 @@ func bindServerBulkControl(s *ServerCommon) { s.SetLink(BulkResetSignalKey, func(msg *Message) { s.handleInboundBulkReset(msg) }) + s.SetLink(BulkReadySignalKey, func(msg *Message) { + s.handleInboundBulkReady(msg) + }) s.SetLink(BulkReleaseSignalKey, func(msg *Message) { s.handleInboundBulkRelease(msg) }) } +func bulkAcceptInfoFromClientBulk(bulk *bulkHandle) BulkAcceptInfo { + if bulk == nil { + return BulkAcceptInfo{} + } + return BulkAcceptInfo{ + ID: bulk.ID(), + Range: bulk.Range(), + Metadata: bulk.Metadata(), + Dedicated: bulk.Dedicated(), + TransportGeneration: bulk.TransportGeneration(), + Bulk: bulk, + } +} + +func bulkAcceptInfoFromServerBulk(bulk *bulkHandle, logical *LogicalConn, transport *TransportConn) BulkAcceptInfo { + if bulk == nil { + return BulkAcceptInfo{} + } + if logical == nil { + logical = bulk.LogicalConn() + } + if transport == nil { + transport = bulk.TransportConn() + } + return BulkAcceptInfo{ + ID: bulk.ID(), + Range: bulk.Range(), + Metadata: bulk.Metadata(), + Dedicated: bulk.Dedicated(), + LogicalConn: logical, + TransportConn: transport, + TransportGeneration: bulk.TransportGeneration(), + Bulk: bulk, + } +} + +func dispatchBulkAccept(handler func(BulkAcceptInfo) error, bulk *bulkHandle, info BulkAcceptInfo) error { + if bulk == nil { + return errBulkNotFound + } + if !bulk.markAcceptDispatched() { + return nil + } + var dispatchErr error + defer func() { + bulk.finishAcceptDispatch(dispatchErr) + }() + if handler == nil { + dispatchErr = errBulkHandlerNotConfigured + 
bulk.markReset(dispatchErr) + return dispatchErr + } + if err := handler(info); err != nil { + dispatchErr = err + bulk.markReset(dispatchErr) + return dispatchErr + } + return nil +} + +func (c *ClientCommon) clientBulkAcceptReadyNotifier(bulk *bulkHandle) func(error) { + return func(readyErr error) { + if c == nil || bulk == nil { + return + } + req := BulkReadyRequest{ + BulkID: bulk.ID(), + DataID: bulk.dataIDSnapshot(), + } + if readyErr != nil { + req.Error = readyErr.Error() + } + ctx, cancel := context.WithTimeout(context.Background(), defaultBulkAcceptReadyTimeout) + defer cancel() + if _, err := sendBulkReadyClient(ctx, c, req); err != nil && bulk.Context().Err() == nil { + bulk.markReset(err) + } + } +} + +func sendBulkReadyServer(ctx context.Context, s *ServerCommon, logical *LogicalConn, transport *TransportConn, req BulkReadyRequest) error { + if s == nil { + return errBulkServerNil + } + if transport != nil { + if _, err := sendBulkReadyServerTransport(ctx, s, transport, req); err == nil { + return nil + } else if !errors.Is(err, errTransportDetached) && !errors.Is(err, errBulkTransportNil) { + return err + } + } + if logical == nil { + return errBulkLogicalConnNil + } + _, err := sendBulkReadyServerLogical(ctx, s, logical, req) + return err +} + +func (s *ServerCommon) serverBulkAcceptReadyNotifier(bulk *bulkHandle, logical *LogicalConn, transport *TransportConn) func(error) { + return func(readyErr error) { + if s == nil || bulk == nil { + return + } + req := BulkReadyRequest{ + BulkID: bulk.ID(), + DataID: bulk.dataIDSnapshot(), + } + if readyErr != nil { + req.Error = readyErr.Error() + } + ctx, cancel := context.WithTimeout(context.Background(), defaultBulkAcceptReadyTimeout) + defer cancel() + if err := sendBulkReadyServer(ctx, s, logical, transport, req); err != nil && bulk.Context().Err() == nil { + bulk.markReset(err) + } + } +} + +func (c *ClientCommon) startClientBulkAcceptDispatch(bulk *bulkHandle) { + if c == nil || bulk == nil { + 
return + } + bulk.setAcceptNotify(c.clientBulkAcceptReadyNotifier(bulk)) + go func() { + _ = c.dispatchClientBulkAccept(bulk) + }() +} + +func (s *ServerCommon) startServerBulkAcceptDispatch(bulk *bulkHandle, logical *LogicalConn, transport *TransportConn) { + if s == nil || bulk == nil { + return + } + if logical == nil { + logical = bulk.LogicalConn() + } + if transport == nil { + transport = bulk.TransportConn() + } + bulk.setAcceptNotify(s.serverBulkAcceptReadyNotifier(bulk, logical, transport)) + go func() { + _ = s.dispatchServerBulkAccept(bulk, logical, transport) + }() +} + +func (c *ClientCommon) dispatchClientBulkAccept(bulk *bulkHandle) error { + if c == nil { + return errBulkClientNil + } + runtime := c.getBulkRuntime() + if runtime == nil { + return errBulkRuntimeNil + } + return dispatchBulkAccept(runtime.handlerSnapshot(), bulk, bulkAcceptInfoFromClientBulk(bulk)) +} + +func (s *ServerCommon) dispatchServerBulkAccept(bulk *bulkHandle, logical *LogicalConn, transport *TransportConn) error { + if s == nil { + return errBulkServerNil + } + runtime := s.getBulkRuntime() + if runtime == nil { + return errBulkRuntimeNil + } + return dispatchBulkAccept(runtime.handlerSnapshot(), bulk, bulkAcceptInfoFromServerBulk(bulk, logical, transport)) +} + func (c *ClientCommon) handleInboundBulkOpen(msg *Message) { req, err := decodeBulkOpenRequest(msg) - resp := BulkOpenResponse{BulkID: req.BulkID, DataID: req.DataID, Dedicated: req.Dedicated} + resp := BulkOpenResponse{ + BulkID: req.BulkID, + DataID: req.DataID, + FastPathVersion: negotiateBulkFastPathVersion(req.FastPathVersion), + Dedicated: req.Dedicated, + } if err != nil { resp.Error = err.Error() replyBulkControlIfNeeded(msg, resp) @@ -133,13 +325,6 @@ func (c *ClientCommon) handleInboundBulkOpen(msg *Message) { replyBulkControlIfNeeded(msg, resp) return } - handler := runtime.handlerSnapshot() - if handler == nil { - bulk.markReset(errBulkHandlerNotConfigured) - resp.Error = 
errBulkHandlerNotConfigured.Error() - replyBulkControlIfNeeded(msg, resp) - return - } if req.Dedicated { if err := c.attachDedicatedBulkSidecar(context.Background(), bulk); err != nil { bulk.markReset(err) @@ -147,17 +332,14 @@ func (c *ClientCommon) handleInboundBulkOpen(msg *Message) { replyBulkControlIfNeeded(msg, resp) return } + resp.Accepted = true + resp.DataID = bulk.dataIDSnapshot() + resp.TransportGeneration = bulk.TransportGeneration() + replyBulkControlIfNeeded(msg, resp) + c.startClientBulkAcceptDispatch(bulk) + return } - info := BulkAcceptInfo{ - ID: bulk.ID(), - Range: bulk.Range(), - Metadata: bulk.Metadata(), - Dedicated: bulk.Dedicated(), - TransportGeneration: bulk.TransportGeneration(), - Bulk: bulk, - } - if err := handler(info); err != nil { - bulk.markReset(err) + if err := c.dispatchClientBulkAccept(bulk); err != nil { resp.Error = err.Error() replyBulkControlIfNeeded(msg, resp) return @@ -170,7 +352,12 @@ func (c *ClientCommon) handleInboundBulkOpen(msg *Message) { func (s *ServerCommon) handleInboundBulkOpen(msg *Message) { req, err := decodeBulkOpenRequest(msg) - resp := BulkOpenResponse{BulkID: req.BulkID, DataID: req.DataID, Dedicated: req.Dedicated} + resp := BulkOpenResponse{ + BulkID: req.BulkID, + DataID: req.DataID, + FastPathVersion: negotiateBulkFastPathVersion(req.FastPathVersion), + Dedicated: req.Dedicated, + } if err != nil { resp.Error = err.Error() replyBulkControlIfNeeded(msg, resp) @@ -218,25 +405,24 @@ func (s *ServerCommon) handleInboundBulkOpen(msg *Message) { replyBulkControlIfNeeded(msg, resp) return } - handler := runtime.handlerSnapshot() - if handler == nil { + s.attachServerDedicatedSidecarIfExists(logical, bulk) + if runtime.handlerSnapshot() == nil { bulk.markReset(errBulkHandlerNotConfigured) resp.Error = errBulkHandlerNotConfigured.Error() replyBulkControlIfNeeded(msg, resp) return } - info := BulkAcceptInfo{ - ID: bulk.ID(), - Range: bulk.Range(), - Metadata: bulk.Metadata(), - Dedicated: bulk.Dedicated(), 
- LogicalConn: logical, - TransportConn: transport, - TransportGeneration: bulk.TransportGeneration(), - Bulk: bulk, + if req.Dedicated { + resp.Accepted = true + resp.DataID = bulk.dataIDSnapshot() + resp.TransportGeneration = bulk.TransportGeneration() + replyBulkControlIfNeeded(msg, resp) + if bulk.dedicatedAttachedSnapshot() { + s.startServerBulkAcceptDispatch(bulk, logical, messageTransportConnSnapshot(msg)) + } + return } - if err := handler(info); err != nil { - bulk.markReset(err) + if err := s.dispatchServerBulkAccept(bulk, logical, transport); err != nil { resp.Error = err.Error() replyBulkControlIfNeeded(msg, resp) return @@ -357,6 +543,41 @@ func (c *ClientCommon) handleInboundBulkRelease(msg *Message) { bulk.releaseOutboundWindow(req.Bytes, req.Chunks) } +func (c *ClientCommon) handleInboundBulkReady(msg *Message) { + req, err := decodeBulkReadyRequest(msg) + resp := BulkReadyResponse{BulkID: req.BulkID} + if err != nil { + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + runtime := c.getBulkRuntime() + if runtime == nil { + resp.Error = errBulkRuntimeNil.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + bulk, ok := runtime.lookup(clientFileScope(), req.BulkID) + if !ok && req.DataID != 0 { + bulk, ok = runtime.lookupByDataID(clientFileScope(), req.DataID) + } + if !ok { + resp.Error = errBulkNotFound.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + if resp.BulkID == "" { + resp.BulkID = bulk.ID() + } + readyErr := bulkReadyRemoteError(req.Error) + bulk.markAcceptReady(readyErr) + if readyErr != nil { + bulk.markReset(readyErr) + } + resp.Accepted = true + replyBulkControlIfNeeded(msg, resp) +} + func (s *ServerCommon) handleInboundBulkReset(msg *Message) { req, err := decodeBulkResetRequest(msg) resp := BulkResetResponse{BulkID: req.BulkID} @@ -390,6 +611,43 @@ func (s *ServerCommon) handleInboundBulkReset(msg *Message) { replyBulkControlIfNeeded(msg, resp) } +func (s *ServerCommon) 
handleInboundBulkReady(msg *Message) { + req, err := decodeBulkReadyRequest(msg) + resp := BulkReadyResponse{BulkID: req.BulkID} + if err != nil { + resp.Error = err.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + runtime := s.getBulkRuntime() + if runtime == nil { + resp.Error = errBulkRuntimeNil.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + logical := messageLogicalConnSnapshot(msg) + scope := serverFileScope(logical) + bulk, ok := runtime.lookup(scope, req.BulkID) + if !ok && req.DataID != 0 { + bulk, ok = runtime.lookupByDataID(scope, req.DataID) + } + if !ok { + resp.Error = errBulkNotFound.Error() + replyBulkControlIfNeeded(msg, resp) + return + } + if resp.BulkID == "" { + resp.BulkID = bulk.ID() + } + readyErr := bulkReadyRemoteError(req.Error) + bulk.markAcceptReady(readyErr) + if readyErr != nil { + bulk.markReset(readyErr) + } + resp.Accepted = true + replyBulkControlIfNeeded(msg, resp) +} + func (s *ServerCommon) handleInboundBulkRelease(msg *Message) { req, err := decodeBulkReleaseRequest(msg) if err != nil { @@ -625,11 +883,26 @@ func decodeBulkReleaseRequest(msg *Message) (BulkReleaseRequest, error) { return req, nil } +func decodeBulkReadyRequest(msg *Message) (BulkReadyRequest, error) { + var req BulkReadyRequest + if msg == nil { + return BulkReadyRequest{}, errBulkIDEmpty + } + if err := msg.Value.Orm(&req); err != nil { + return BulkReadyRequest{}, err + } + if req.BulkID == "" && req.DataID == 0 { + return BulkReadyRequest{}, errBulkIDEmpty + } + return req, nil +} + func decodeBulkOpenResponse(msg Message) (BulkOpenResponse, error) { var resp BulkOpenResponse if err := msg.Value.Orm(&resp); err != nil { return BulkOpenResponse{}, err } + resp.FastPathVersion = normalizeBulkFastPathVersion(resp.FastPathVersion) return resp, bulkControlResultError("open", resp.Accepted, resp.Error, nil) } @@ -649,6 +922,14 @@ func decodeBulkResetResponse(msg Message) (BulkResetResponse, error) { return resp, 
bulkControlResultError("reset", resp.Accepted, resp.Error, nil) } +func decodeBulkReadyResponse(msg Message) (BulkReadyResponse, error) { + var resp BulkReadyResponse + if err := msg.Value.Orm(&resp); err != nil { + return BulkReadyResponse{}, err + } + return resp, bulkControlResultError("ready", resp.Accepted, resp.Error, nil) +} + func bulkControlResultError(op string, accepted bool, message string, callErr error) error { if callErr != nil { return callErr @@ -697,6 +978,52 @@ func bulkRemoteResetError(message string) error { return errors.New(message) } +func bulkReadyRemoteError(message string) error { + if message == "" { + return nil + } + return bulkControlMessageError(message) +} + +func sendBulkReadyClient(ctx context.Context, c Client, req BulkReadyRequest) (BulkReadyResponse, error) { + if c == nil { + return BulkReadyResponse{}, errBulkClientNil + } + msg, err := c.SendObjCtx(ctx, BulkReadySignalKey, req) + if err != nil { + return BulkReadyResponse{}, err + } + return decodeBulkReadyResponse(msg) +} + +func sendBulkReadyServerLogical(ctx context.Context, s Server, logical *LogicalConn, req BulkReadyRequest) (BulkReadyResponse, error) { + if s == nil { + return BulkReadyResponse{}, errBulkServerNil + } + if logical == nil { + return BulkReadyResponse{}, errBulkLogicalConnNil + } + msg, err := s.SendObjCtxLogical(ctx, logical, BulkReadySignalKey, req) + if err != nil { + return BulkReadyResponse{}, err + } + return decodeBulkReadyResponse(msg) +} + +func sendBulkReadyServerTransport(ctx context.Context, s Server, transport *TransportConn, req BulkReadyRequest) (BulkReadyResponse, error) { + if s == nil { + return BulkReadyResponse{}, errBulkServerNil + } + if transport == nil { + return BulkReadyResponse{}, errBulkTransportNil + } + msg, err := s.SendObjCtxTransport(ctx, transport, BulkReadySignalKey, req) + if err != nil { + return BulkReadyResponse{}, err + } + return decodeBulkReadyResponse(msg) +} + func bulkTransportGeneration(logical *LogicalConn, 
transport *TransportConn) uint64 { return streamTransportGeneration(logical, transport) } diff --git a/bulk_dedicated.go b/bulk_dedicated.go index d3bf44d..8f63ca2 100644 --- a/bulk_dedicated.go +++ b/bulk_dedicated.go @@ -3,7 +3,6 @@ package notify import ( "b612.me/notify/internal/transport" "b612.me/stario" - "bytes" "context" cryptorand "crypto/rand" "encoding/binary" @@ -12,6 +11,7 @@ import ( "fmt" "io" "net" + "strings" "sync/atomic" "time" ) @@ -20,18 +20,16 @@ const ( systemBulkAttachKey = "_notify_bulk_attach" bulkDedicatedRecordMagic = "NBR1" bulkDedicatedRecordHeaderLen = 8 - bulkDedicatedAttachTimeout = 5 * time.Second - bulkDedicatedAttachFrameMagicSize = 8 - bulkDedicatedAttachFrameHeaderLen = 14 - bulkDedicatedAttachFrameVersionOffset = 12 - bulkDedicatedAttachFrameFlagsOffset = 13 - bulkDedicatedAttachFrameVersionV1 = 1 - bulkDedicatedAttachFrameFlagsNone = 0 + defaultBulkDedicatedAttachLimit = 16 + defaultBulkDedicatedActiveLimit = 4096 + defaultBulkDedicatedLaneLimit = 4 + defaultBulkDedicatedAttachRetry = 2 + defaultBulkDedicatedAttachBackoff = 150 * time.Millisecond + defaultBulkDedicatedDialTimeout = 5 * time.Second + defaultBulkDedicatedHelloTimeout = 10 * time.Second ) -var bulkDedicatedAttachFrameMagic = [bulkDedicatedAttachFrameMagicSize]byte{11, 27, 19, 96, 12, 25, 2, 20} - type bulkAttachRequest struct { PeerID string BulkID string @@ -39,8 +37,84 @@ type bulkAttachRequest struct { } type bulkAttachResponse struct { - Accepted bool - Error string + Accepted bool + Error string + Code string + Retryable bool + FailedSeq uint64 + FailedBulk string +} + +type bulkAttachErrorCode string + +const ( + bulkAttachErrorCodeInvalidRequest bulkAttachErrorCode = "invalid_request" + bulkAttachErrorCodeServerUnavailable bulkAttachErrorCode = "server_unavailable" + bulkAttachErrorCodePeerNotFound bulkAttachErrorCode = "peer_not_found" + bulkAttachErrorCodeBulkNotFound bulkAttachErrorCode = "bulk_not_found" + bulkAttachErrorCodeBulkNotDedicated 
bulkAttachErrorCode = "bulk_not_dedicated" + bulkAttachErrorCodeTokenMismatch bulkAttachErrorCode = "token_mismatch" + bulkAttachErrorCodeAlreadyAttached bulkAttachErrorCode = "already_attached" + bulkAttachErrorCodeAttachFailed bulkAttachErrorCode = "attach_failed" + bulkAttachErrorCodeInternal bulkAttachErrorCode = "internal_error" +) + +type bulkAttachError struct { + Code bulkAttachErrorCode + Retryable bool + Message string + FailedSeq uint64 + FailedBulk string +} + +func (e *bulkAttachError) Error() string { + if e == nil { + return "" + } + if e.Message != "" { + return e.Message + } + if e.Code != "" { + return string(e.Code) + } + return "bulk attach failed" +} + +func newBulkAttachError(code bulkAttachErrorCode, retryable bool, message string) *bulkAttachError { + return &bulkAttachError{ + Code: code, + Retryable: retryable, + Message: message, + } +} + +func toBulkAttachResponseError(err error, bulkID string) bulkAttachResponse { + resp := bulkAttachResponse{ + Accepted: false, + FailedBulk: bulkID, + } + if err == nil { + resp.Error = "bulk attach failed" + resp.Code = string(bulkAttachErrorCodeInternal) + return resp + } + var attachErr *bulkAttachError + if errors.As(err, &attachErr) && attachErr != nil { + resp.Error = attachErr.Error() + resp.Code = string(attachErr.Code) + resp.Retryable = attachErr.Retryable + resp.FailedSeq = attachErr.FailedSeq + if attachErr.FailedBulk != "" { + resp.FailedBulk = attachErr.FailedBulk + } + if resp.Code == "" { + resp.Code = string(bulkAttachErrorCodeInternal) + } + return resp + } + resp.Error = err.Error() + resp.Code = string(bulkAttachErrorCodeInternal) + return resp } func newBulkAttachToken() string { @@ -136,29 +210,7 @@ func readDirectSignalFramePayload(conn net.Conn) ([]byte, error) { if conn == nil { return nil, net.ErrClosed } - var header [bulkDedicatedAttachFrameHeaderLen]byte - if _, err := io.ReadFull(conn, header[:]); err != nil { - return nil, err - } - if 
!bytes.Equal(header[:bulkDedicatedAttachFrameMagicSize], bulkDedicatedAttachFrameMagic[:]) { - return nil, stario.ErrQueueDataFormat - } - if got := header[bulkDedicatedAttachFrameVersionOffset]; got != bulkDedicatedAttachFrameVersionV1 { - return nil, stario.ErrQueueUnsupportedVersion - } - if got := header[bulkDedicatedAttachFrameFlagsOffset]; got != bulkDedicatedAttachFrameFlagsNone { - return nil, stario.ErrQueueUnsupportedFlags - } - length := binary.BigEndian.Uint32(header[bulkDedicatedAttachFrameMagicSize : bulkDedicatedAttachFrameMagicSize+4]) - maxInt := int(^uint(0) >> 1) - if uint64(length) > uint64(maxInt) { - return nil, stario.ErrQueueMessageTooLarge - } - payload := make([]byte, int(length)) - if _, err := io.ReadFull(conn, payload); err != nil { - return nil, err - } - return payload, nil + return newTransportFrameReader(conn, stario.NewQueue()).Next() } func writeBulkDedicatedRecord(conn net.Conn, payload []byte) error { @@ -178,36 +230,61 @@ func writeBulkDedicatedRecordWithDeadline(conn net.Conn, payload []byte, deadlin } func readBulkDedicatedRecord(conn net.Conn) ([]byte, error) { + payload, release, err := readBulkDedicatedRecordPooled(conn) + if err != nil { + return nil, err + } + if release != nil { + defer release() + } + return append([]byte(nil), payload...), nil +} + +func readBulkDedicatedRecordPooled(conn net.Conn) ([]byte, func(), error) { if conn == nil { - return nil, net.ErrClosed + return nil, nil, net.ErrClosed } var header [bulkDedicatedRecordHeaderLen]byte if _, err := io.ReadFull(conn, header[:]); err != nil { - return nil, err + return nil, nil, err } if string(header[:4]) != bulkDedicatedRecordMagic { - return nil, fmt.Errorf("%w: record magic=%x", errBulkFastPayloadInvalid, header[:4]) + return nil, nil, fmt.Errorf("%w: record magic=%x", errBulkFastPayloadInvalid, header[:4]) } size := int(binary.BigEndian.Uint32(header[4:8])) if size < 0 { - return nil, errBulkFastPayloadInvalid + return nil, nil, 
errBulkFastPayloadInvalid } - payload := make([]byte, size) + payload := getModernPSKPayloadBuffer(size) if _, err := io.ReadFull(conn, payload); err != nil { - return nil, err + putModernPSKPayloadBuffer(payload) + return nil, nil, err } - return payload, nil + return payload, func() { + putModernPSKPayloadBuffer(payload) + }, nil } -func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context) (net.Conn, error) { +func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context, timeout time.Duration) (net.Conn, error) { source := c.clientConnectSourceSnapshot() - if source != nil && source.canReconnect() { - return source.dial(ctx) + if source != nil { + if source.network != "" && source.addr != "" { + if timeout > 0 { + return transport.DialTimeout(source.network, source.addr, timeout) + } + return transport.Dial(source.network, source.addr) + } + if source.canReconnect() { + return source.dial(ctx) + } } conn := c.clientTransportConnSnapshot() if conn == nil || conn.RemoteAddr() == nil { return nil, errClientReconnectSourceUnavailable } + if timeout > 0 { + return transport.DialTimeout(conn.RemoteAddr().Network(), conn.RemoteAddr().String(), timeout) + } return transport.Dial(conn.RemoteAddr().Network(), conn.RemoteAddr().String()) } @@ -218,30 +295,346 @@ func (c *ClientCommon) attachDedicatedBulkSidecar(ctx context.Context, bulk *bul if ctx == nil { ctx = context.Background() } - ctx, cancel := context.WithTimeout(ctx, bulkDedicatedAttachTimeout) - defer cancel() - conn, err := c.dialDedicatedBulkConn(ctx) + laneID := bulk.dedicatedLaneIDSnapshot() + releaseActiveSlot, err := c.acquireBulkDedicatedActiveSlot(ctx) if err != nil { return err } - resp, err := c.sendDedicatedBulkAttachRequest(ctx, conn, bulk) - if err != nil { - _ = conn.Close() - return err - } - if !resp.Accepted { - _ = conn.Close() - if resp.Error != "" { - return errors.New(resp.Error) + needReleaseActive := true + defer func() { + if needReleaseActive { + releaseActiveSlot() + } + }() 
+ if sidecar := c.clientDedicatedSidecarSnapshotForLane(laneID); sidecar != nil && sidecar.conn != nil { + if err := bulk.attachDedicatedConnShared(sidecar.conn); err == nil { + bulk.markDedicatedActiveReserved() + needReleaseActive = false + return nil } - return errors.New("bulk attach rejected") } - if err := bulk.attachDedicatedConn(conn); err != nil { - _ = conn.Close() + _, flight, leader := c.beginClientDedicatedSidecarAttach(laneID) + if !leader { + if flight == nil { + return errTransportDetached + } + if err := flight.wait(ctx); err != nil { + return err + } + sidecar := c.clientDedicatedSidecarSnapshotForLane(laneID) + if sidecar == nil || sidecar.conn == nil { + return errTransportDetached + } + if err := bulk.attachDedicatedConnShared(sidecar.conn); err != nil { + return err + } + bulk.markDedicatedActiveReserved() + needReleaseActive = false + return nil + } + if flight == nil { + return errTransportDetached + } + var flightErr error + defer func() { + c.finishClientDedicatedSidecarAttach(laneID, flight, flightErr) + }() + releaseAttachSlot, err := c.acquireBulkDedicatedAttachSlot(ctx) + if err != nil { + flightErr = err return err } - go c.readDedicatedBulkLoop(bulk, conn) - return nil + defer releaseAttachSlot() + if sidecar := c.clientDedicatedSidecarSnapshotForLane(laneID); sidecar != nil && sidecar.conn != nil { + if err := bulk.attachDedicatedConnShared(sidecar.conn); err == nil { + bulk.markDedicatedActiveReserved() + needReleaseActive = false + flightErr = nil + return nil + } + } + retry, backoff, dialTimeout, helloTimeout := c.bulkDedicatedAttachConfigSnapshot() + attempts := retry + 1 + if attempts <= 0 { + attempts = 1 + } + var lastErr error + for attempt := 1; attempt <= attempts; attempt++ { + c.bulkAttachAttemptCount.Add(1) + if attempt > 1 { + delay := backoff * time.Duration(1<<(attempt-2)) + if delay > 3*time.Second { + delay = 3 * time.Second + } + if err := waitDedicatedAttachBackoff(ctx, delay); err != nil { + flightErr = err + 
return err + } + } + dialCtx := ctx + dialCancel := func() {} + if dialTimeout > 0 { + dialCtx, dialCancel = context.WithTimeout(ctx, dialTimeout) + } + conn, err := c.dialDedicatedBulkConn(dialCtx, dialTimeout) + dialCancel() + if err != nil { + lastErr = err + if attempt < attempts && isRetryableDedicatedAttachError(err) { + flightErr = err + c.bulkAttachRetryCount.Add(1) + continue + } + bulk.markDedicatedAttachDegraded(string(bulkAttachErrorCodeAttachFailed)) + flightErr = err + return err + } + helloCtx := ctx + helloCancel := func() {} + if helloTimeout > 0 { + helloCtx, helloCancel = context.WithTimeout(ctx, helloTimeout) + } + resp, err := c.sendDedicatedBulkAttachRequest(helloCtx, conn, bulk) + helloCancel() + if err != nil { + _ = conn.Close() + lastErr = err + if attempt < attempts && isRetryableDedicatedAttachError(err) { + flightErr = err + c.bulkAttachRetryCount.Add(1) + continue + } + bulk.setDedicatedAttachLastCode(string(bulkAttachErrorCodeAttachFailed)) + bulk.markDedicatedAttachDegraded(string(bulkAttachErrorCodeAttachFailed)) + flightErr = err + return err + } + if !resp.Accepted { + _ = conn.Close() + rejectedErr := &bulkAttachError{ + Code: bulkAttachErrorCode(resp.Code), + Retryable: resp.Retryable, + Message: resp.Error, + FailedSeq: resp.FailedSeq, + FailedBulk: resp.FailedBulk, + } + if rejectedErr.Code == "" { + rejectedErr.Code = bulkAttachErrorCodeAttachFailed + } + lastErr = rejectedErr + bulk.setDedicatedAttachLastCode(string(rejectedErr.Code)) + if attempt < attempts && rejectedErr.Retryable { + flightErr = rejectedErr + c.bulkAttachRetryCount.Add(1) + continue + } + bulk.markDedicatedAttachDegraded(string(rejectedErr.Code)) + flightErr = rejectedErr + return rejectedErr + } + sidecar := newBulkDedicatedSidecar(conn, laneID) + activeSidecar, installed := c.installClientDedicatedSidecar(laneID, sidecar) + if !installed { + sidecar.close() + sidecar = activeSidecar + } + if sidecar == nil || sidecar.conn == nil { + 
bulk.setDedicatedAttachLastCode(string(bulkAttachErrorCodeAttachFailed)) + flightErr = errTransportDetached + return errTransportDetached + } + if err := bulk.attachDedicatedConnShared(sidecar.conn); err != nil { + if installed && c.clearClientDedicatedSidecar(laneID, sidecar) { + sidecar.close() + } + bulk.setDedicatedAttachLastCode(string(bulkAttachErrorCodeAttachFailed)) + if installed { + flightErr = err + } else { + flightErr = nil + } + return err + } + c.bulkAttachSuccessCount.Add(1) + if installed { + go c.readDedicatedSidecarLoop(sidecar) + } + bulk.markDedicatedActiveReserved() + needReleaseActive = false + flightErr = nil + return nil + } + if lastErr != nil { + flightErr = lastErr + return lastErr + } + flightErr = errors.New("bulk attach failed") + return flightErr +} + +func (c *ClientCommon) bulkDedicatedAttachConfigSnapshot() (int, time.Duration, time.Duration, time.Duration) { + if c == nil { + return defaultBulkDedicatedAttachRetry, defaultBulkDedicatedAttachBackoff, defaultBulkDedicatedDialTimeout, defaultBulkDedicatedHelloTimeout + } + c.mu.Lock() + defer c.mu.Unlock() + retry := c.bulkDedicatedAttachRetry + if retry < 0 { + retry = 0 + } + backoff := c.bulkDedicatedAttachBackoff + if backoff <= 0 { + backoff = defaultBulkDedicatedAttachBackoff + } + dialTimeout := c.bulkDedicatedDialTimeout + if dialTimeout <= 0 { + dialTimeout = defaultBulkDedicatedDialTimeout + } + helloTimeout := c.bulkDedicatedHelloTimeout + if helloTimeout <= 0 { + helloTimeout = defaultBulkDedicatedHelloTimeout + } + return retry, backoff, dialTimeout, helloTimeout +} + +func (c *ClientCommon) acquireBulkDedicatedAttachSlot(ctx context.Context) (func(), error) { + sem := c.bulkDedicatedAttachSemaphoreSnapshot() + if c == nil || sem == nil { + return func() {}, nil + } + if ctx == nil { + ctx = context.Background() + } + select { + case sem <- struct{}{}: + return func() { + select { + case <-sem: + default: + } + }, nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} 
+ +func (c *ClientCommon) reserveBulkDedicatedActiveSlot() bool { + if c == nil { + return false + } + limit := c.bulkDedicatedActiveLimitSnapshot() + if limit <= 0 { + return true + } + for { + current := c.bulkDedicatedActive.Load() + if int(current) >= limit { + return false + } + if c.bulkDedicatedActive.CompareAndSwap(current, current+1) { + return true + } + } +} + +func (c *ClientCommon) acquireBulkDedicatedActiveSlot(ctx context.Context) (func(), error) { + if c == nil { + return func() {}, errBulkClientNil + } + if ctx == nil { + ctx = context.Background() + } + for { + if c.reserveBulkDedicatedActiveSlot() { + return c.releaseBulkDedicatedActiveSlot, nil + } + waitCh := c.bulkDedicatedActiveWaitSnapshot() + if c.reserveBulkDedicatedActiveSlot() { + return c.releaseBulkDedicatedActiveSlot, nil + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-waitCh: + } + } +} + +func (c *ClientCommon) releaseBulkDedicatedActiveSlot() { + if c == nil { + return + } + for { + current := c.bulkDedicatedActive.Load() + if current <= 0 { + return + } + if c.bulkDedicatedActive.CompareAndSwap(current, current-1) { + c.notifyBulkDedicatedActiveWaiters() + return + } + } +} + +func (c *ClientCommon) notifyBulkDedicatedActiveWaiters() { + if c == nil { + return + } + c.mu.Lock() + c.notifyBulkDedicatedActiveWaitersLocked() + c.mu.Unlock() +} + +func waitDedicatedAttachBackoff(ctx context.Context, delay time.Duration) error { + if delay <= 0 { + return nil + } + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} + +func isRetryableDedicatedAttachError(err error) bool { + if err == nil { + return false + } + var attachErr *bulkAttachError + if errors.As(err, &attachErr) && attachErr != nil { + return attachErr.Retryable + } + if errors.Is(err, context.Canceled) { + return false + } + if errors.Is(err, context.DeadlineExceeded) { + return true + } + var netErr net.Error + if 
errors.As(err, &netErr) && (netErr.Timeout() || netErr.Temporary()) { + return true + } + message := strings.ToLower(err.Error()) + for _, pattern := range []string{ + "timeout", + "timed out", + "deadline", + "connection reset", + "connection refused", + "connectex", + "broken pipe", + "no route", + "host unreachable", + "transport detached", + } { + if strings.Contains(message, pattern) { + return true + } + } + return false } func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn net.Conn, bulk *bulkHandle) (bulkAttachResponse, error) { @@ -292,37 +685,84 @@ func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn return decodeBulkAttachResponse(c.sequenceDe, transfer.Value) } -func (c *ClientCommon) readDedicatedBulkLoop(bulk *bulkHandle, conn net.Conn) { +func (c *ClientCommon) readDedicatedSidecarLoop(sidecar *bulkDedicatedSidecar) { + if c == nil || sidecar == nil || sidecar.conn == nil { + return + } for { - payload, err := readBulkDedicatedRecord(conn) + payload, payloadRelease, err := readBulkDedicatedRecordPooled(sidecar.conn) if err != nil { - handleDedicatedBulkReadError(bulk, err) + c.handleClientDedicatedSidecarFailure(sidecar, err) return } - plain, err := c.decryptTransportPayload(payload) + plain, plainRelease, err := decryptTransportPayloadCodecPooled(c.modernPSKRuntime, c.msgDe, c.SecretKey, payload, payloadRelease) if err != nil { - _ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error()) - bulk.markReset(err) + c.handleClientDedicatedSidecarFailure(sidecar, err) return } - items, err := decodeDedicatedBulkInboundItems(bulk.dataIDSnapshot(), plain) - if err != nil { - _ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error()) - bulk.markReset(err) - return + owner := newBulkReadPayloadOwner(plainRelease) + runtime := c.getBulkRuntime() + if runtime == nil { + owner.done() + continue } - for _, item := range items { - if err := dispatchDedicatedBulkInboundItem(bulk, item); 
err != nil { - if !errors.Is(err, io.EOF) { - _ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error()) - bulk.markReset(err) + var ( + currentDataID uint64 + currentBulk *bulkHandle + skipDataID bool + ) + err = walkDedicatedBulkInboundPayload(plain, func(dataID uint64, item bulkDedicatedBatchItem) error { + if dataID != currentDataID { + currentDataID = dataID + currentBulk = nil + skipDataID = false + bulk, ok := runtime.lookupByDataID(clientFileScope(), dataID) + if !ok { + c.bestEffortRejectInboundBulkData("", dataID, errBulkNotFound.Error()) + skipDataID = true + return nil } - return + if !bulk.acceptsClientSessionEpoch(c.currentClientSessionEpoch()) { + detachErr := transportDetachedSessionEpochError() + bulk.markReset(detachErr) + c.bestEffortRejectInboundBulkData(bulk.ID(), dataID, detachErr.Error()) + skipDataID = true + return nil + } + bulk.markDedicatedDataStarted() + currentBulk = bulk } - if bulk.Context().Err() != nil { - return + if skipDataID || currentBulk == nil { + return nil } + var release func() + if item.Type == bulkFastPayloadTypeData { + release = owner.retainChunk() + } + dispatchErr := dispatchDedicatedBulkInboundItemWithRelease(currentBulk, item, release) + if dispatchErr != nil { + if !errors.Is(dispatchErr, io.EOF) { + _ = c.sendDedicatedBulkReset(context.Background(), currentBulk, dispatchErr.Error()) + currentBulk.markReset(dispatchErr) + } + currentBulk = nil + skipDataID = true + return nil + } + if currentBulk.Context().Err() != nil { + currentBulk = nil + skipDataID = true + } + return nil + }) + if err != nil { + if plainRelease != nil { + plainRelease() + } + c.handleClientDedicatedSidecarFailure(sidecar, err) + return } + owner.done() } } @@ -331,7 +771,6 @@ func (s *ServerCommon) handleBulkAttachSystemMessage(message Message) bool { return false } current := messageLogicalConnSnapshot(&message) - resp := bulkAttachResponse{} var ( req bulkAttachRequest logical *LogicalConn @@ -343,9 +782,8 @@ func (s 
*ServerCommon) handleBulkAttachSystemMessage(message Message) bool { logical, bulk, err = s.resolveInboundDedicatedBulk(current, req) } if err != nil { - resp.Error = err.Error() if current != nil { - _ = s.replyDedicatedBulkAttach(current, message, resp) + _ = s.replyDedicatedBulkAttach(current, message, toBulkAttachResponseError(err, req.BulkID)) } return true } @@ -359,63 +797,187 @@ func (s *ServerCommon) handleBulkAttachSystemMessage(message Message) bool { func (s *ServerCommon) resolveInboundDedicatedBulk(current *LogicalConn, req bulkAttachRequest) (*LogicalConn, *bulkHandle, error) { if s == nil { - return nil, nil, errBulkServerNil + return nil, nil, newBulkAttachError(bulkAttachErrorCodeServerUnavailable, true, errBulkServerNil.Error()) } if current == nil { - return nil, nil, errBulkLogicalConnNil + return nil, nil, newBulkAttachError(bulkAttachErrorCodeInvalidRequest, false, errBulkLogicalConnNil.Error()) } if req.PeerID == "" || req.BulkID == "" || req.AttachToken == "" { - return nil, nil, errBulkIDEmpty + return nil, nil, newBulkAttachError(bulkAttachErrorCodeInvalidRequest, false, errBulkIDEmpty.Error()) } logical := s.GetLogicalConn(req.PeerID) if logical == nil { - return nil, nil, errBulkLogicalConnNil + return nil, nil, newBulkAttachError(bulkAttachErrorCodePeerNotFound, true, errBulkLogicalConnNil.Error()) } runtime := s.getBulkRuntime() if runtime == nil { - return nil, nil, errBulkRuntimeNil + return nil, nil, newBulkAttachError(bulkAttachErrorCodeServerUnavailable, true, errBulkRuntimeNil.Error()) } bulk, ok := runtime.lookup(serverFileScope(logical), req.BulkID) if !ok { - return nil, nil, errBulkNotFound + return nil, nil, &bulkAttachError{ + Code: bulkAttachErrorCodeBulkNotFound, + Retryable: false, + Message: errBulkNotFound.Error(), + FailedBulk: req.BulkID, + } } + bulk.markDedicatedAttachAttempt() if !bulk.Dedicated() { - return nil, nil, errors.New("bulk is not dedicated") + 
bulk.setDedicatedAttachLastCode(string(bulkAttachErrorCodeBulkNotDedicated)) + return nil, nil, &bulkAttachError{ + Code: bulkAttachErrorCodeBulkNotDedicated, + Retryable: false, + Message: "bulk is not dedicated", + FailedBulk: req.BulkID, + } } if bulk.dedicatedAttachTokenSnapshot() != req.AttachToken { - return nil, nil, errors.New("bulk attach token mismatch") + bulk.setDedicatedAttachLastCode(string(bulkAttachErrorCodeTokenMismatch)) + return nil, nil, &bulkAttachError{ + Code: bulkAttachErrorCodeTokenMismatch, + Retryable: false, + Message: "bulk attach token mismatch", + FailedBulk: req.BulkID, + } + } + switch bulk.dedicatedAttachStateSnapshot() { + case bulkDedicatedAttachStateClosed: + bulk.setDedicatedAttachLastCode(string(bulkAttachErrorCodeAttachFailed)) + return nil, nil, &bulkAttachError{ + Code: bulkAttachErrorCodeAttachFailed, + Retryable: false, + Message: "bulk dedicated attach closed", + FailedBulk: req.BulkID, + } + case bulkDedicatedAttachStateDegraded: + bulk.setDedicatedAttachLastCode(string(bulkAttachErrorCodeAttachFailed)) + return nil, nil, &bulkAttachError{ + Code: bulkAttachErrorCodeAttachFailed, + Retryable: true, + Message: "bulk dedicated attach degraded", + FailedBulk: req.BulkID, + } + case bulkDedicatedAttachStateAttached: + if bulk.dedicatedDataStartedSnapshot() { + bulk.setDedicatedAttachLastCode(string(bulkAttachErrorCodeAlreadyAttached)) + return nil, nil, &bulkAttachError{ + Code: bulkAttachErrorCodeAlreadyAttached, + Retryable: false, + Message: "bulk dedicated already attached", + FailedBulk: req.BulkID, + } + } } return logical, bulk, nil } func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, logical *LogicalConn, bulk *bulkHandle, message Message) error { if current == nil || logical == nil || bulk == nil { - return errBulkLogicalConnNil + return newBulkAttachError(bulkAttachErrorCodeInvalidRequest, false, errBulkLogicalConnNil.Error()) } + scope := serverFileScope(logical) + laneID := 
bulk.dedicatedLaneIDSnapshot() conn, err := current.detachTransportForTransfer() if err != nil { - return err + return newBulkAttachError(bulkAttachErrorCodeAttachFailed, true, err.Error()) + } + stopCurrent := func(reason string, err error) { + current.markSessionStopped(reason, err) + s.removeLogical(current) } fail := func(reason string, err error) error { if conn != nil { _ = conn.Close() } - current.markSessionStopped(reason, err) - s.removeLogical(current) - return err + bulk.markDedicatedAttachDegraded(string(bulkAttachErrorCodeAttachFailed)) + stopCurrent(reason, err) + return newBulkAttachError(bulkAttachErrorCodeAttachFailed, true, err.Error()) } - if err := s.replyDedicatedBulkAttachDetached(current, conn, message, bulkAttachResponse{Accepted: true}); err != nil { - return fail("bulk dedicated attach reply failed", err) + if bulk.dedicatedAttachedSnapshot() { + if bulk.dedicatedDataStartedSnapshot() { + rejected := &bulkAttachError{ + Code: bulkAttachErrorCodeAlreadyAttached, + Retryable: false, + Message: "bulk dedicated already attached", + FailedBulk: bulk.ID(), + } + _ = s.replyDedicatedBulkAttachDetached(current, conn, message, toBulkAttachResponseError(rejected, bulk.ID())) + _ = conn.Close() + stopCurrent("bulk dedicated attach duplicate rejected", nil) + bulk.setDedicatedAttachLastCode(string(bulkAttachErrorCodeAlreadyAttached)) + return nil + } } - if err := bulk.attachDedicatedConn(conn); err != nil { + sidecar := newBulkDedicatedSidecar(conn, laneID) + if sidecar == nil { + bulk.setDedicatedAttachLastCode(string(bulkAttachErrorCodeAttachFailed)) + return fail("bulk dedicated attach failed", errTransportDetached) + } + var ( + oldConn net.Conn + oldSender *bulkDedicatedSender + ) + if bulk.dedicatedAttachedSnapshot() { + var replaceErr error + oldConn, oldSender, replaceErr = bulk.replaceDedicatedConnShared(conn) + if replaceErr != nil { + sidecar.close() + bulk.setDedicatedAttachLastCode(string(bulkAttachErrorCodeAttachFailed)) + return 
fail("bulk dedicated reattach failed", replaceErr) + } + } else if err := bulk.attachDedicatedConnShared(conn); err != nil { + sidecar.close() + bulk.setDedicatedAttachLastCode(string(bulkAttachErrorCodeAttachFailed)) return fail("bulk dedicated attach failed", err) } - go s.readDedicatedBulkLoop(logical, bulk, conn) - current.markSessionStopped("bulk dedicated attach", nil) - s.removeLogical(current) + if err := s.replyDedicatedBulkAttachDetached(current, conn, message, bulkAttachResponse{Accepted: true}); err != nil { + bulk.setDedicatedAttachLastCode(string(bulkAttachErrorCodeAttachFailed)) + sidecar.close() + if runtime := s.getBulkRuntime(); runtime != nil { + runtime.resetDedicatedByConn(scope, conn, transportDetachedError("bulk dedicated attach reply failed", err)) + } + stopCurrent("bulk dedicated attach reply failed", err) + return nil + } + oldSidecar := s.installServerDedicatedSidecar(logical, laneID, sidecar) + if runtime := s.getBulkRuntime(); runtime != nil { + runtime.attachSharedDedicatedConn(scope, laneID, conn) + } + if oldSender != nil { + oldSender.stop() + } + if oldConn != nil { + _ = oldConn.Close() + } + if oldSidecar != nil && oldSidecar != sidecar { + oldSidecar.close() + } + go s.readDedicatedSidecarLoop(logical, sidecar) + s.startServerBulkAcceptDispatch(bulk, logical, messageTransportConnSnapshot(&message)) + if runtime := s.getBulkRuntime(); runtime != nil { + s.dispatchPendingServerBulkAccepts(scope, conn, bulk, logical) + } + stopCurrent("bulk dedicated attach", nil) return nil } +func (s *ServerCommon) dispatchPendingServerBulkAccepts(scope string, conn net.Conn, current *bulkHandle, logical *LogicalConn) { + if s == nil || conn == nil { + return + } + runtime := s.getBulkRuntime() + if runtime == nil { + return + } + for _, pending := range runtime.dedicatedBulksForConn(scope, conn) { + if pending == nil || pending == current { + continue + } + s.startServerBulkAcceptDispatch(pending, logical, pending.TransportConn()) + } +} + func 
(s *ServerCommon) replyDedicatedBulkAttachDetached(client *LogicalConn, conn net.Conn, message Message, resp bulkAttachResponse) error { if s == nil || client == nil { return errBulkServerNil @@ -467,38 +1029,93 @@ func (s *ServerCommon) replyDedicatedBulkAttach(client *LogicalConn, message Mes return err } -func (s *ServerCommon) readDedicatedBulkLoop(logical *LogicalConn, bulk *bulkHandle, conn net.Conn) { - for { - payload, err := readBulkDedicatedRecord(conn) - if err != nil { - handleDedicatedBulkReadError(bulk, err) - return - } - plain, err := s.decryptTransportPayloadLogical(logical, payload) - if err != nil { - _ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error()) - bulk.markReset(err) - return - } - items, err := decodeDedicatedBulkInboundItems(bulk.dataIDSnapshot(), plain) - if err != nil { - _ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error()) - bulk.markReset(err) - return - } - for _, item := range items { - if err := dispatchDedicatedBulkInboundItem(bulk, item); err != nil { - if !errors.Is(err, io.EOF) { - _ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error()) - bulk.markReset(err) - } - return - } - if bulk.Context().Err() != nil { - return - } - } +func (s *ServerCommon) readDedicatedSidecarLoop(logical *LogicalConn, sidecar *bulkDedicatedSidecar) { + if s == nil || logical == nil || sidecar == nil || sidecar.conn == nil { + return } + runtime := s.getBulkRuntime() + scope := serverFileScope(logical) + for { + payload, payloadRelease, err := readBulkDedicatedRecordPooled(sidecar.conn) + if err != nil { + s.handleServerDedicatedSidecarFailure(logical, sidecar, err) + return + } + plain, plainRelease, err := decryptTransportPayloadCodecPooled(logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload, payloadRelease) + if err != nil { + s.handleServerDedicatedSidecarFailure(logical, sidecar, err) + return + } + owner := 
newBulkReadPayloadOwner(plainRelease) + if runtime == nil { + owner.done() + continue + } + var ( + currentDataID uint64 + currentBulk *bulkHandle + skipDataID bool + ) + err = walkDedicatedBulkInboundPayload(plain, func(dataID uint64, item bulkDedicatedBatchItem) error { + if dataID != currentDataID { + currentDataID = dataID + currentBulk = nil + skipDataID = false + bulk, ok := runtime.lookupByDataID(scope, dataID) + if !ok { + s.bestEffortRejectInboundDedicatedData(logical, sidecar.conn, dataID, errBulkNotFound.Error()) + skipDataID = true + return nil + } + bulk.markDedicatedDataStarted() + currentBulk = bulk + } + if skipDataID || currentBulk == nil { + return nil + } + var release func() + if item.Type == bulkFastPayloadTypeData { + release = owner.retainChunk() + } + dispatchErr := dispatchDedicatedBulkInboundItemWithRelease(currentBulk, item, release) + if dispatchErr != nil { + if !errors.Is(dispatchErr, io.EOF) { + _ = s.sendDedicatedBulkReset(context.Background(), logical, currentBulk, dispatchErr.Error()) + currentBulk.markReset(dispatchErr) + } + currentBulk = nil + skipDataID = true + return nil + } + if currentBulk.Context().Err() != nil { + currentBulk = nil + skipDataID = true + } + return nil + }) + if err != nil { + if plainRelease != nil { + plainRelease() + } + s.handleServerDedicatedSidecarFailure(logical, sidecar, err) + return + } + owner.done() + } +} + +func (s *ServerCommon) bestEffortRejectInboundDedicatedData(logical *LogicalConn, conn net.Conn, dataID uint64, message string) { + if s == nil || logical == nil || conn == nil || dataID == 0 { + return + } + frame, err := s.encodeDedicatedBulkBatchPayload(logical, dataID, []bulkDedicatedSendRequest{{ + Type: bulkFastPayloadTypeReset, + Payload: []byte(message), + }}) + if err != nil { + return + } + _ = writeBulkDedicatedRecordWithDeadline(conn, frame, writeDeadlineFromTimeout(logical.maxWriteTimeoutSnapshot())) } func handleDedicatedBulkReadError(bulk *bulkHandle, err error) { @@ -508,7 
+1125,8 @@ func handleDedicatedBulkReadError(bulk *bulkHandle, err error) { if bulk.Context().Err() != nil || bulk.remoteClosedSnapshot() { return } - if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) { + message := strings.ToLower(err.Error()) + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || strings.Contains(message, "use of closed network connection") { if bulk.Dedicated() || bulk.localClosedSnapshot() { bulk.markRemoteClosed() return @@ -531,6 +1149,9 @@ func (c *ClientCommon) dedicatedBulkSender(bulk *bulkHandle) (*bulkDedicatedSend sender := newBulkDedicatedSender(conn, bulk.dataIDSnapshot(), c.encryptTransportPayload, func(items []bulkDedicatedSendRequest) ([]byte, error) { return c.encodeDedicatedBulkBatchPayload(bulk.dataIDSnapshot(), items) }, func(err error) { + if bulk.canIgnoreDedicatedCloseSendError(err) { + return + } bulk.markReset(err) }) actual := bulk.installDedicatedSender(sender) @@ -540,26 +1161,52 @@ func (c *ClientCommon) dedicatedBulkSender(bulk *bulkHandle) (*bulkDedicatedSend return actual, nil } +func (c *ClientCommon) dedicatedBulkLaneSender(bulk *bulkHandle) (*bulkDedicatedLaneSender, error) { + if c == nil || bulk == nil { + return nil, errBulkClientNil + } + sidecar := c.clientDedicatedSidecarSnapshotForLane(bulk.dedicatedLaneIDSnapshot()) + conn := bulk.dedicatedConnSnapshot() + if sidecar == nil || sidecar.conn == nil || conn == nil || sidecar.conn != conn { + return nil, transportDetachedError("dedicated bulk sidecar not attached", nil) + } + sender := sidecar.laneSenderWithFactory(func(conn net.Conn) *bulkDedicatedLaneSender { + laneRuntime := c.modernPSKRuntime + if forked, err := forkDedicatedLaneModernPSKRuntime(laneRuntime); err == nil && forked != nil { + laneRuntime = forked + } + return newBulkDedicatedLaneSender(conn, func(batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) { + return c.encodeDedicatedBulkBatchesPayloadPooledWithRuntime(laneRuntime, batches) + }, func(err error) { + 
c.handleClientDedicatedSidecarFailure(sidecar, err) + }) + }) + if sender == nil { + return nil, transportDetachedError("dedicated bulk sidecar not attached", nil) + } + return sender, nil +} + func (c *ClientCommon) sendDedicatedBulkData(ctx context.Context, bulk *bulkHandle, chunk []byte) error { if c == nil || bulk == nil { return errBulkClientNil } - sender, err := c.dedicatedBulkSender(bulk) + sender, err := c.dedicatedBulkLaneSender(bulk) if err != nil { return err } - return sender.submitData(ctx, bulk.nextOutboundDataSeq(), chunk) + return sender.submitData(ctx, bulk.dataIDSnapshot(), bulk.nextOutboundDataSeq(), chunk) } -func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) { +func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte) (int, error) { if c == nil || bulk == nil { return 0, errBulkClientNil } - sender, err := c.dedicatedBulkSender(bulk) + sender, err := c.dedicatedBulkLaneSender(bulk) if err != nil { return 0, err } - return sender.submitWrite(ctx, bulk.nextOutboundDataSeq(), payload, bulk.chunkSize) + return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize) } func (c *ClientCommon) sendDedicatedBulkClose(ctx context.Context, bulk *bulkHandle, full bool) error { @@ -575,11 +1222,11 @@ func (c *ClientCommon) sendDedicatedBulkClose(ctx context.Context, bulk *bulkHan if full { flags = bulkFastPayloadFlagFullClose } - sender, err := c.dedicatedBulkSender(bulk) + sender, err := c.dedicatedBulkLaneSender(bulk) if err != nil { return err } - return sender.submitControl(sendCtx, bulkFastPayloadTypeClose, flags, 0, nil) + return sender.submitControl(sendCtx, bulk.dataIDSnapshot(), bulkFastPayloadTypeClose, flags, 0, nil) } func (c *ClientCommon) sendDedicatedBulkReset(ctx context.Context, bulk *bulkHandle, message string) error { @@ -591,11 +1238,11 @@ func (c *ClientCommon) sendDedicatedBulkReset(ctx 
context.Context, bulk *bulkHan return err } defer cancel() - sender, err := c.dedicatedBulkSender(bulk) + sender, err := c.dedicatedBulkLaneSender(bulk) if err != nil { return err } - return sender.submitControl(sendCtx, bulkFastPayloadTypeReset, 0, 0, []byte(message)) + return sender.submitControl(sendCtx, bulk.dataIDSnapshot(), bulkFastPayloadTypeReset, 0, 0, []byte(message)) } func (c *ClientCommon) sendDedicatedBulkRelease(ctx context.Context, bulk *bulkHandle, bytes int64, chunks int) error { @@ -614,19 +1261,11 @@ func (c *ClientCommon) sendDedicatedBulkRelease(ctx context.Context, bulk *bulkH return err } defer cancel() - frame, err := c.encodeDedicatedBulkBatchPayload(bulk.dataIDSnapshot(), []bulkDedicatedSendRequest{{ - Type: bulkFastPayloadTypeRelease, - Payload: payload, - }}) + sender, err := c.dedicatedBulkLaneSender(bulk) if err != nil { return err } - conn := bulk.dedicatedConnSnapshot() - if conn == nil { - return transportDetachedError("dedicated bulk sidecar not attached", nil) - } - deadline, _ := sendCtx.Deadline() - return writeBulkDedicatedRecordWithDeadline(conn, frame, deadline) + return sender.submitControl(sendCtx, bulk.dataIDSnapshot(), bulkFastPayloadTypeRelease, 0, 0, payload) } func (s *ServerCommon) dedicatedBulkSender(logical *LogicalConn, bulk *bulkHandle) (*bulkDedicatedSender, error) { @@ -651,6 +1290,9 @@ func (s *ServerCommon) dedicatedBulkSender(logical *LogicalConn, bulk *bulkHandl }, func(items []bulkDedicatedSendRequest) ([]byte, error) { return s.encodeDedicatedBulkBatchPayload(logical, bulk.dataIDSnapshot(), items) }, func(err error) { + if bulk.canIgnoreDedicatedCloseSendError(err) { + return + } bulk.markReset(err) }) actual := bulk.installDedicatedSender(sender) @@ -660,26 +1302,58 @@ func (s *ServerCommon) dedicatedBulkSender(logical *LogicalConn, bulk *bulkHandl return actual, nil } +func (s *ServerCommon) dedicatedBulkLaneSender(logical *LogicalConn, bulk *bulkHandle) (*bulkDedicatedLaneSender, error) { + if s == nil 
|| bulk == nil { + return nil, errBulkServerNil + } + if logical == nil { + logical = bulk.LogicalConn() + } + if logical == nil { + return nil, errBulkLogicalConnNil + } + sidecar := s.serverDedicatedSidecarSnapshotForLane(logical, bulk.dedicatedLaneIDSnapshot()) + conn := bulk.dedicatedConnSnapshot() + if sidecar == nil || sidecar.conn == nil || conn == nil || sidecar.conn != conn { + return nil, transportDetachedError("dedicated bulk sidecar not attached", nil) + } + sender := sidecar.laneSenderWithFactory(func(conn net.Conn) *bulkDedicatedLaneSender { + laneRuntime := logical.modernPSKRuntimeSnapshot() + if forked, err := forkDedicatedLaneModernPSKRuntime(laneRuntime); err == nil && forked != nil { + laneRuntime = forked + } + return newBulkDedicatedLaneSender(conn, func(batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) { + return s.encodeDedicatedBulkBatchesPayloadPooledWithRuntime(logical, laneRuntime, batches) + }, func(err error) { + s.handleServerDedicatedSidecarFailure(logical, sidecar, err) + }) + }) + if sender == nil { + return nil, transportDetachedError("dedicated bulk sidecar not attached", nil) + } + return sender, nil +} + func (s *ServerCommon) sendDedicatedBulkData(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, chunk []byte) error { if s == nil || bulk == nil { return errBulkServerNil } - sender, err := s.dedicatedBulkSender(logical, bulk) + sender, err := s.dedicatedBulkLaneSender(logical, bulk) if err != nil { return err } - return sender.submitData(ctx, bulk.nextOutboundDataSeq(), chunk) + return sender.submitData(ctx, bulk.dataIDSnapshot(), bulk.nextOutboundDataSeq(), chunk) } -func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, payload []byte) (int, error) { +func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, startSeq uint64, payload []byte) (int, error) { if s == nil || bulk == nil { return 0, 
errBulkServerNil } - sender, err := s.dedicatedBulkSender(logical, bulk) + sender, err := s.dedicatedBulkLaneSender(logical, bulk) if err != nil { return 0, err } - return sender.submitWrite(ctx, bulk.nextOutboundDataSeq(), payload, bulk.chunkSize) + return sender.submitWrite(ctx, bulk.dataIDSnapshot(), startSeq, payload, bulk.chunkSize) } func (s *ServerCommon) sendDedicatedBulkClose(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, full bool) error { @@ -695,11 +1369,11 @@ func (s *ServerCommon) sendDedicatedBulkClose(ctx context.Context, logical *Logi if full { flags = bulkFastPayloadFlagFullClose } - sender, err := s.dedicatedBulkSender(logical, bulk) + sender, err := s.dedicatedBulkLaneSender(logical, bulk) if err != nil { return err } - return sender.submitControl(sendCtx, bulkFastPayloadTypeClose, flags, 0, nil) + return sender.submitControl(sendCtx, bulk.dataIDSnapshot(), bulkFastPayloadTypeClose, flags, 0, nil) } func (s *ServerCommon) sendDedicatedBulkReset(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, message string) error { @@ -711,11 +1385,11 @@ func (s *ServerCommon) sendDedicatedBulkReset(ctx context.Context, logical *Logi return err } defer cancel() - sender, err := s.dedicatedBulkSender(logical, bulk) + sender, err := s.dedicatedBulkLaneSender(logical, bulk) if err != nil { return err } - return sender.submitControl(sendCtx, bulkFastPayloadTypeReset, 0, 0, []byte(message)) + return sender.submitControl(sendCtx, bulk.dataIDSnapshot(), bulkFastPayloadTypeReset, 0, 0, []byte(message)) } func (s *ServerCommon) sendDedicatedBulkRelease(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, bytes int64, chunks int) error { @@ -734,25 +1408,22 @@ func (s *ServerCommon) sendDedicatedBulkRelease(ctx context.Context, logical *Lo return err } defer cancel() - frame, err := s.encodeDedicatedBulkBatchPayload(logical, bulk.dataIDSnapshot(), []bulkDedicatedSendRequest{{ - Type: bulkFastPayloadTypeRelease, - Payload: payload, - }}) 
+ sender, err := s.dedicatedBulkLaneSender(logical, bulk) if err != nil { return err } - conn := bulk.dedicatedConnSnapshot() - if conn == nil { - return transportDetachedError("dedicated bulk sidecar not attached", nil) - } - deadline, _ := sendCtx.Deadline() - return writeBulkDedicatedRecordWithDeadline(conn, frame, deadline) + return sender.submitControl(sendCtx, bulk.dataIDSnapshot(), bulkFastPayloadTypeRelease, 0, 0, payload) } func (c *ClientCommon) encodeDedicatedBulkBatchPayload(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) { if c == nil { return nil, errBulkClientNil } + if runtime := c.modernPSKRuntime; runtime != nil { + return runtime.sealFilledPayload(bulkDedicatedBatchPlainLen(items), func(dst []byte) error { + return writeBulkDedicatedBatchPlain(dst, dataID, items) + }) + } if c.fastPlainEncode != nil { return encodeBulkDedicatedBatchPayloadFast(c.fastPlainEncode, c.SecretKey, dataID, items) } @@ -763,6 +1434,48 @@ func (c *ClientCommon) encodeDedicatedBulkBatchPayload(dataID uint64, items []bu return c.encryptTransportPayload(plain) } +func forkDedicatedLaneModernPSKRuntime(base *modernPSKCodecRuntime) (*modernPSKCodecRuntime, error) { + if base == nil { + return nil, nil + } + return base.fork() +} + +func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooledWithRuntime(runtime *modernPSKCodecRuntime, batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) { + if c == nil { + return nil, nil, errBulkClientNil + } + if len(batches) == 0 { + return nil, nil, errBulkFastPayloadInvalid + } + if runtime != nil { + return runtime.sealFilledPayloadPooled(bulkDedicatedBatchesPlainLen(batches), func(dst []byte) error { + return writeBulkDedicatedBatchesPlain(dst, batches) + }) + } + if c.fastPlainEncode != nil { + payload, err := encodeBulkDedicatedBatchesPayloadFast(c.fastPlainEncode, c.SecretKey, batches) + return payload, nil, err + } + plain, err := encodeBulkDedicatedBatchesPlain(batches) + if err != nil { + return nil, 
nil, err + } + payload, err := c.encryptTransportPayload(plain) + return payload, nil, err +} + +func (c *ClientCommon) encodeDedicatedBulkBatchesPayloadPooled(batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) { + return c.encodeDedicatedBulkBatchesPayloadPooledWithRuntime(c.modernPSKRuntime, batches) +} + +func (c *ClientCommon) encodeDedicatedBulkBatchPayloadPooled(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, func(), error) { + return c.encodeDedicatedBulkBatchesPayloadPooled([]bulkDedicatedOutboundBatch{{ + DataID: dataID, + Items: items, + }}) +} + func (s *ServerCommon) encodeDedicatedBulkBatchPayload(logical *LogicalConn, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) { if s == nil { return nil, errBulkServerNil @@ -770,6 +1483,11 @@ func (s *ServerCommon) encodeDedicatedBulkBatchPayload(logical *LogicalConn, dat if logical == nil { return nil, errBulkLogicalConnNil } + if runtime := logical.modernPSKRuntimeSnapshot(); runtime != nil { + return runtime.sealFilledPayload(bulkDedicatedBatchPlainLen(items), func(dst []byte) error { + return writeBulkDedicatedBatchPlain(dst, dataID, items) + }) + } if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { return encodeBulkDedicatedBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), dataID, items) } @@ -779,3 +1497,41 @@ func (s *ServerCommon) encodeDedicatedBulkBatchPayload(logical *LogicalConn, dat } return s.encryptTransportPayloadLogical(logical, plain) } + +func (s *ServerCommon) encodeDedicatedBulkBatchesPayloadPooledWithRuntime(logical *LogicalConn, runtime *modernPSKCodecRuntime, batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) { + if s == nil { + return nil, nil, errBulkServerNil + } + if logical == nil { + return nil, nil, errBulkLogicalConnNil + } + if len(batches) == 0 { + return nil, nil, errBulkFastPayloadInvalid + } + if runtime != nil { + return 
runtime.sealFilledPayloadPooled(bulkDedicatedBatchesPlainLen(batches), func(dst []byte) error { + return writeBulkDedicatedBatchesPlain(dst, batches) + }) + } + if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { + payload, err := encodeBulkDedicatedBatchesPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), batches) + return payload, nil, err + } + plain, err := encodeBulkDedicatedBatchesPlain(batches) + if err != nil { + return nil, nil, err + } + payload, err := s.encryptTransportPayloadLogical(logical, plain) + return payload, nil, err +} + +func (s *ServerCommon) encodeDedicatedBulkBatchesPayloadPooled(logical *LogicalConn, batches []bulkDedicatedOutboundBatch) ([]byte, func(), error) { + return s.encodeDedicatedBulkBatchesPayloadPooledWithRuntime(logical, logical.modernPSKRuntimeSnapshot(), batches) +} + +func (s *ServerCommon) encodeDedicatedBulkBatchPayloadPooled(logical *LogicalConn, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, func(), error) { + return s.encodeDedicatedBulkBatchesPayloadPooled(logical, []bulkDedicatedOutboundBatch{{ + DataID: dataID, + Items: items, + }}) +} diff --git a/bulk_dedicated_attach_test.go b/bulk_dedicated_attach_test.go index 80f01ad..643029a 100644 --- a/bulk_dedicated_attach_test.go +++ b/bulk_dedicated_attach_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/binary" + "errors" "net" "testing" "time" @@ -136,6 +137,7 @@ func TestHandleBulkAttachSystemMessageAcceptedWritesDirectReplyBeforeDedicatedHa Dedicated: true, AttachToken: "attach-token", }, 0, target, nil, 0, nil, nil, nil, nil, nil) + bulk.markAcceptHandled() if err := server.getBulkRuntime().register(serverFileScope(target), bulk); err != nil { t.Fatalf("register bulk runtime failed: %v", err) } @@ -215,3 +217,251 @@ func TestHandleBulkAttachSystemMessageAcceptedWritesDirectReplyBeforeDedicatedHa t.Fatalf("attach sidecar logical should be removed after handoff, got %+v", got) } } + +func 
TestHandleBulkAttachSystemMessageDoesNotExposeSharedSidecarBeforeReplyCompletes(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + sidecarLeft, sidecarRight := net.Pipe() + defer sidecarRight.Close() + + current := server.bootstrapAcceptedLogical("dedicated-attach-current-blocked", nil, sidecarLeft) + if current == nil { + t.Fatal("bootstrapAcceptedLogical(current) should return logical") + } + target := server.bootstrapAcceptedLogical("dedicated-attach-target-blocked", nil, nil) + if target == nil { + t.Fatal("bootstrapAcceptedLogical(target) should return logical") + } + + currentBulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{ + BulkID: "server-dedicated-current", + DataID: 17, + Dedicated: true, + AttachToken: "attach-token", + }, 0, target, nil, 0, nil, nil, nil, nil, nil) + currentBulk.markAcceptHandled() + if err := server.getBulkRuntime().register(serverFileScope(target), currentBulk); err != nil { + t.Fatalf("register current bulk runtime failed: %v", err) + } + + pendingBulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{ + BulkID: "server-dedicated-pending", + DataID: 18, + Dedicated: true, + AttachToken: "attach-token-2", + }, 0, target, nil, 0, nil, nil, nil, nil, nil) + pendingBulk.markAcceptHandled() + if err := server.getBulkRuntime().register(serverFileScope(target), pendingBulk); err != nil { + t.Fatalf("register pending bulk runtime failed: %v", err) + } + + reqPayload, err := server.sequenceEn(bulkAttachRequest{ + PeerID: target.ID(), + BulkID: currentBulk.ID(), + AttachToken: "attach-token", + }) + if err != nil { + t.Fatalf("encode bulkAttachRequest failed: %v", err) + } + msg := Message{ + NetType: NET_SERVER, + LogicalConn: current, + ClientConn: current.compatClientConn(), + TransferMsg: TransferMsg{ + ID: 77, + Key: systemBulkAttachKey, + Value: reqPayload, + Type: MSG_SYS_WAIT, 
+ }, + inboundConn: sidecarLeft, + Time: time.Now(), + } + + done := make(chan struct{}) + go func() { + defer close(done) + _ = server.handleBulkAttachSystemMessage(msg) + }() + + time.Sleep(50 * time.Millisecond) + + if got := pendingBulk.dedicatedConnSnapshot(); got != nil { + t.Fatal("pending dedicated bulk should not observe shared sidecar before attach reply is fully sent") + } + if got := server.serverDedicatedSidecarSnapshot(target); got != nil { + t.Fatal("server dedicated sidecar should not be published before attach reply is fully sent") + } + + type attachReplyResult struct { + transfer TransferMsg + resp bulkAttachResponse + err error + } + replyCh := make(chan attachReplyResult, 1) + go func() { + _ = sidecarRight.SetReadDeadline(time.Now().Add(time.Second)) + replyPayload, err := readDirectSignalFramePayload(sidecarRight) + if err != nil { + replyCh <- attachReplyResult{err: err} + return + } + transfer, err := decodeDirectSignalPayload(server.sequenceDe, current.msgDeSnapshot(), current.secretKeySnapshot(), replyPayload) + if err != nil { + replyCh <- attachReplyResult{err: err} + return + } + resp, err := decodeBulkAttachResponse(server.sequenceDe, transfer.Value) + replyCh <- attachReplyResult{transfer: transfer, resp: resp, err: err} + }() + + select { + case result := <-replyCh: + if result.err != nil { + t.Fatalf("read direct attach reply failed: %v", result.err) + } + if !result.resp.Accepted { + t.Fatalf("bulk attach response = %+v, want accepted", result.resp) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for direct attach reply") + } + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for handleBulkAttachSystemMessage to finish") + } + + if got := pendingBulk.dedicatedConnSnapshot(); got != sidecarLeft { + t.Fatalf("pending dedicated bulk conn mismatch after reply: got %v want %v", got, sidecarLeft) + } + if got := server.serverDedicatedSidecarSnapshot(target); got == nil 
|| got.conn != sidecarLeft { + t.Fatal("server dedicated sidecar should be published after attach reply completes") + } +} + +func TestBulkAttachResponseErrorCarriesStructuredFields(t *testing.T) { + resp := toBulkAttachResponseError(&bulkAttachError{ + Code: bulkAttachErrorCodeTokenMismatch, + Retryable: false, + Message: "bulk attach token mismatch", + FailedSeq: 7, + FailedBulk: "bulk-1", + }, "fallback-bulk") + + if resp.Accepted { + t.Fatalf("Accepted = %v, want false", resp.Accepted) + } + if got, want := resp.Code, string(bulkAttachErrorCodeTokenMismatch); got != want { + t.Fatalf("Code = %q, want %q", got, want) + } + if got, want := resp.Retryable, false; got != want { + t.Fatalf("Retryable = %v, want %v", got, want) + } + if got, want := resp.FailedSeq, uint64(7); got != want { + t.Fatalf("FailedSeq = %d, want %d", got, want) + } + if got, want := resp.FailedBulk, "bulk-1"; got != want { + t.Fatalf("FailedBulk = %q, want %q", got, want) + } +} + +func TestResolveInboundDedicatedBulkRejectsAlreadyAttachedAfterDataStarted(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + currentLeft, currentRight := net.Pipe() + defer currentRight.Close() + current := server.bootstrapAcceptedLogical("dedicated-attach-current", nil, currentLeft) + if current == nil { + t.Fatal("bootstrapAcceptedLogical(current) should return logical") + } + target := server.bootstrapAcceptedLogical("dedicated-attach-target", nil, nil) + if target == nil { + t.Fatal("bootstrapAcceptedLogical(target) should return logical") + } + + bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{ + BulkID: "server-dedicated-attach-test", + DataID: 7, + Dedicated: true, + AttachToken: "attach-token", + }, 0, target, nil, 0, nil, nil, nil, nil, nil) + bulk.markAcceptHandled() + if err := server.getBulkRuntime().register(serverFileScope(target), bulk); err != nil { + t.Fatalf("register bulk runtime failed: 
%v", err) + } + attachedLeft, attachedRight := net.Pipe() + defer attachedRight.Close() + if err := bulk.attachDedicatedConn(attachedLeft); err != nil { + t.Fatalf("attachDedicatedConn failed: %v", err) + } + bulk.markDedicatedDataStarted() + + _, _, err := server.resolveInboundDedicatedBulk(current, bulkAttachRequest{ + PeerID: target.ID(), + BulkID: bulk.ID(), + AttachToken: "attach-token", + }) + if err == nil { + t.Fatal("resolveInboundDedicatedBulk should reject duplicate attach after data started") + } + var attachErr *bulkAttachError + if !errors.As(err, &attachErr) || attachErr == nil { + t.Fatalf("resolveInboundDedicatedBulk error type = %T, want *bulkAttachError", err) + } + if got, want := attachErr.Code, bulkAttachErrorCodeAlreadyAttached; got != want { + t.Fatalf("attach error code = %q, want %q", got, want) + } + if attachErr.Retryable { + t.Fatalf("attach error retryable = true, want false") + } +} + +func TestResolveInboundDedicatedBulkAllowsReattachBeforeDataStarts(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + currentLeft, currentRight := net.Pipe() + defer currentRight.Close() + current := server.bootstrapAcceptedLogical("dedicated-attach-current", nil, currentLeft) + if current == nil { + t.Fatal("bootstrapAcceptedLogical(current) should return logical") + } + target := server.bootstrapAcceptedLogical("dedicated-attach-target", nil, nil) + if target == nil { + t.Fatal("bootstrapAcceptedLogical(target) should return logical") + } + + bulk := newBulkHandle(context.Background(), server.getBulkRuntime(), serverFileScope(target), BulkOpenRequest{ + BulkID: "server-dedicated-attach-test", + DataID: 7, + Dedicated: true, + AttachToken: "attach-token", + }, 0, target, nil, 0, nil, nil, nil, nil, nil) + bulk.markAcceptHandled() + if err := server.getBulkRuntime().register(serverFileScope(target), bulk); err != nil { + t.Fatalf("register bulk runtime failed: %v", err) + } + attachedLeft, attachedRight := 
net.Pipe() + defer attachedRight.Close() + if err := bulk.attachDedicatedConn(attachedLeft); err != nil { + t.Fatalf("attachDedicatedConn failed: %v", err) + } + + resolvedLogical, resolvedBulk, err := server.resolveInboundDedicatedBulk(current, bulkAttachRequest{ + PeerID: target.ID(), + BulkID: bulk.ID(), + AttachToken: "attach-token", + }) + if err != nil { + t.Fatalf("resolveInboundDedicatedBulk failed: %v", err) + } + if resolvedLogical != target { + t.Fatalf("resolved logical mismatch: got %v want %v", resolvedLogical, target) + } + if resolvedBulk != bulk { + t.Fatalf("resolved bulk mismatch: got %v want %v", resolvedBulk, bulk) + } +} diff --git a/bulk_dedicated_batch.go b/bulk_dedicated_batch.go index 50490c0..5738c88 100644 --- a/bulk_dedicated_batch.go +++ b/bulk_dedicated_batch.go @@ -12,14 +12,18 @@ import ( ) const ( - bulkDedicatedBatchMagic = "NBD2" - bulkDedicatedBatchVersion = 1 - bulkDedicatedBatchHeaderLen = 20 - bulkDedicatedBatchItemHeaderLen = 16 - bulkDedicatedBatchMaxItems = 32 - bulkDedicatedBatchMaxPlainBytes = 8 * 1024 * 1024 - bulkDedicatedSendQueueSize = bulkDedicatedBatchMaxItems - bulkDedicatedReleasePayloadLen = 12 + bulkDedicatedBatchMagic = "NBD2" + bulkDedicatedBatchVersion = 1 + bulkDedicatedBatchHeaderLen = 20 + bulkDedicatedBatchItemHeaderLen = 16 + bulkDedicatedSuperBatchMagic = "NBD3" + bulkDedicatedSuperBatchVersion = 1 + bulkDedicatedSuperBatchHeaderLen = 12 + bulkDedicatedSuperBatchGroupHeaderLen = 12 + bulkDedicatedBatchMaxItems = 64 + bulkDedicatedBatchMaxPlainBytes = 16 * 1024 * 1024 + bulkDedicatedSendQueueSize = bulkDedicatedBatchMaxItems + bulkDedicatedReleasePayloadLen = 12 ) const ( @@ -46,6 +50,16 @@ type bulkDedicatedSendRequest struct { Payload []byte } +type bulkDedicatedOutboundBatch struct { + DataID uint64 + Items []bulkDedicatedSendRequest +} + +type bulkDedicatedInboundBatch struct { + DataID uint64 + Items []bulkDedicatedBatchItem +} + type bulkDedicatedBatchRequest struct { Ctx context.Context Items 
[]bulkDedicatedSendRequest @@ -165,7 +179,7 @@ func (s *bulkDedicatedSender) submitWriteBatch(ctx context.Context, items []bulk if submitted, err := s.tryDirectSubmitBatch(ctx, items); submitted { return err } - queuedItems := cloneBulkDedicatedSendRequests(items) + queuedItems := copyBulkDedicatedSendRequests(items) return s.submitBatch(ctx, queuedItems, true) } @@ -471,6 +485,23 @@ func cloneBulkDedicatedSendRequests(items []bulkDedicatedSendRequest) []bulkDedi return cloned } +func copyBulkDedicatedSendRequests(items []bulkDedicatedSendRequest) []bulkDedicatedSendRequest { + if len(items) == 0 { + return nil + } + copied := make([]bulkDedicatedSendRequest, len(items)) + copy(copied, items) + return copied +} + +func bulkDedicatedSendRequestsLen(items []bulkDedicatedSendRequest) int { + total := 0 + for _, item := range items { + total += bulkDedicatedSendRequestLen(item) + } + return total +} + func bulkDedicatedSendRequestLen(req bulkDedicatedSendRequest) int { return bulkDedicatedSendRequestLenFromPayloadLen(len(req.Payload)) } @@ -479,6 +510,83 @@ func bulkDedicatedSendRequestLenFromPayloadLen(payloadLen int) int { return bulkDedicatedBatchItemHeaderLen + payloadLen } +func bulkDedicatedBatchesPlainLen(batches []bulkDedicatedOutboundBatch) int { + switch len(batches) { + case 0: + return 0 + case 1: + return bulkDedicatedBatchPlainLen(batches[0].Items) + default: + total := bulkDedicatedSuperBatchHeaderLen + for _, batch := range batches { + total += bulkDedicatedSuperBatchGroupHeaderLen + bulkDedicatedSendRequestsLen(batch.Items) + } + return total + } +} + +func encodeBulkDedicatedBatchesPlain(batches []bulkDedicatedOutboundBatch) ([]byte, error) { + if len(batches) == 0 { + return nil, errBulkFastPayloadInvalid + } + total := bulkDedicatedBatchesPlainLen(batches) + buf := make([]byte, total) + if err := writeBulkDedicatedBatchesPlain(buf, batches); err != nil { + return nil, err + } + return buf, nil +} + +func writeBulkDedicatedBatchesPlain(buf []byte, 
batches []bulkDedicatedOutboundBatch) error { + switch len(batches) { + case 0: + return errBulkFastPayloadInvalid + case 1: + return writeBulkDedicatedBatchPlain(buf, batches[0].DataID, batches[0].Items) + default: + return writeBulkDedicatedSuperBatchPlain(buf, batches) + } +} + +func writeBulkDedicatedSuperBatchPlain(buf []byte, batches []bulkDedicatedOutboundBatch) error { + if len(batches) <= 1 { + return errBulkFastPayloadInvalid + } + if len(buf) != bulkDedicatedBatchesPlainLen(batches) || len(buf) > bulkDedicatedBatchMaxPlainBytes { + return errBulkFastPayloadInvalid + } + copy(buf[:4], bulkDedicatedSuperBatchMagic) + buf[4] = bulkDedicatedSuperBatchVersion + binary.BigEndian.PutUint32(buf[8:12], uint32(len(batches))) + offset := bulkDedicatedSuperBatchHeaderLen + totalItems := 0 + for _, batch := range batches { + if batch.DataID == 0 || len(batch.Items) == 0 { + return errBulkFastPayloadInvalid + } + totalItems += len(batch.Items) + if totalItems > bulkDedicatedBatchMaxItems { + return errBulkFastPayloadInvalid + } + binary.BigEndian.PutUint64(buf[offset:offset+8], batch.DataID) + binary.BigEndian.PutUint32(buf[offset+8:offset+12], uint32(len(batch.Items))) + offset += bulkDedicatedSuperBatchGroupHeaderLen + for _, item := range batch.Items { + buf[offset] = item.Type + buf[offset+1] = item.Flags + binary.BigEndian.PutUint64(buf[offset+4:offset+12], item.Seq) + binary.BigEndian.PutUint32(buf[offset+12:offset+16], uint32(len(item.Payload))) + offset += bulkDedicatedBatchItemHeaderLen + copy(buf[offset:offset+len(item.Payload)], item.Payload) + offset += len(item.Payload) + } + } + if offset != len(buf) { + return errBulkFastPayloadInvalid + } + return nil +} + func encodeBulkDedicatedReleasePayload(bytes int64, chunks int) ([]byte, error) { if bytes <= 0 && chunks <= 0 { return nil, errBulkFastPayloadInvalid @@ -505,24 +613,26 @@ func decodeBulkDedicatedReleasePayload(payload []byte) (int64, int, error) { } func encodeBulkDedicatedBatchPlain(dataID uint64, 
items []bulkDedicatedSendRequest) ([]byte, error) { - if dataID == 0 || len(items) == 0 { - return nil, errBulkFastPayloadInvalid - } - total := bulkDedicatedBatchPlainLen(items) - buf := make([]byte, total) - if err := writeBulkDedicatedBatchPlain(buf, dataID, items); err != nil { - return nil, err - } - return buf, nil + return encodeBulkDedicatedBatchesPlain([]bulkDedicatedOutboundBatch{{ + DataID: dataID, + Items: items, + }}) } func encodeBulkDedicatedBatchPayloadFast(encode transportFastPlainEncoder, secretKey []byte, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) { + return encodeBulkDedicatedBatchesPayloadFast(encode, secretKey, []bulkDedicatedOutboundBatch{{ + DataID: dataID, + Items: items, + }}) +} + +func encodeBulkDedicatedBatchesPayloadFast(encode transportFastPlainEncoder, secretKey []byte, batches []bulkDedicatedOutboundBatch) ([]byte, error) { if encode == nil { return nil, errTransportPayloadEncryptFailed } - plainLen := bulkDedicatedBatchPlainLen(items) + plainLen := bulkDedicatedBatchesPlainLen(batches) return encode(secretKey, plainLen, func(dst []byte) error { - return writeBulkDedicatedBatchPlain(dst, dataID, items) + return writeBulkDedicatedBatchesPlain(dst, batches) }) } @@ -606,29 +716,268 @@ func decodeBulkDedicatedBatchPlain(payload []byte) (uint64, []bulkDedicatedBatch return dataID, items, true, nil } -func decodeDedicatedBulkInboundItems(expectedDataID uint64, plain []byte) ([]bulkDedicatedBatchItem, error) { +func decodeBulkDedicatedSuperBatchPlain(payload []byte) ([]bulkDedicatedInboundBatch, bool, error) { + if len(payload) < 4 || string(payload[:4]) != bulkDedicatedSuperBatchMagic { + return nil, false, nil + } + if len(payload) < bulkDedicatedSuperBatchHeaderLen { + return nil, true, errBulkFastPayloadInvalid + } + if payload[4] != bulkDedicatedSuperBatchVersion { + return nil, true, errBulkFastPayloadInvalid + } + groupCount := int(binary.BigEndian.Uint32(payload[8:12])) + if groupCount <= 0 { + return nil, 
true, errBulkFastPayloadInvalid + } + batches := make([]bulkDedicatedInboundBatch, 0, groupCount) + offset := bulkDedicatedSuperBatchHeaderLen + totalItems := 0 + for i := 0; i < groupCount; i++ { + if len(payload)-offset < bulkDedicatedSuperBatchGroupHeaderLen { + return nil, true, errBulkFastPayloadInvalid + } + dataID := binary.BigEndian.Uint64(payload[offset : offset+8]) + count := int(binary.BigEndian.Uint32(payload[offset+8 : offset+12])) + offset += bulkDedicatedSuperBatchGroupHeaderLen + if dataID == 0 || count <= 0 { + return nil, true, errBulkFastPayloadInvalid + } + totalItems += count + if totalItems > bulkDedicatedBatchMaxItems { + return nil, true, errBulkFastPayloadInvalid + } + items := make([]bulkDedicatedBatchItem, 0, count) + for j := 0; j < count; j++ { + if len(payload)-offset < bulkDedicatedBatchItemHeaderLen { + return nil, true, errBulkFastPayloadInvalid + } + itemType := payload[offset] + switch itemType { + case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease: + default: + return nil, true, errBulkFastPayloadInvalid + } + flags := payload[offset+1] + seq := binary.BigEndian.Uint64(payload[offset+4 : offset+12]) + dataLen := int(binary.BigEndian.Uint32(payload[offset+12 : offset+16])) + offset += bulkDedicatedBatchItemHeaderLen + if dataLen < 0 || len(payload)-offset < dataLen { + return nil, true, errBulkFastPayloadInvalid + } + items = append(items, bulkDedicatedBatchItem{ + Type: itemType, + Flags: flags, + Seq: seq, + Payload: payload[offset : offset+dataLen], + }) + offset += dataLen + } + batches = append(batches, bulkDedicatedInboundBatch{ + DataID: dataID, + Items: items, + }) + } + if offset != len(payload) { + return nil, true, errBulkFastPayloadInvalid + } + return batches, true, nil +} + +func decodeDedicatedBulkInboundBatches(plain []byte) ([]bulkDedicatedInboundBatch, error) { + if batches, matched, err := decodeBulkDedicatedSuperBatchPlain(plain); matched { + if err != 
nil { + return nil, err + } + return batches, nil + } if dataID, items, matched, err := decodeBulkDedicatedBatchPlain(plain); matched { if err != nil { return nil, err } - if expectedDataID == 0 || dataID != expectedDataID { - return nil, errBulkFastPayloadInvalid - } - return items, nil + return []bulkDedicatedInboundBatch{{ + DataID: dataID, + Items: items, + }}, nil } frame, matched, err := decodeBulkFastFrame(plain) if err != nil { return nil, err } - if !matched || expectedDataID == 0 || frame.DataID != expectedDataID { + if !matched || frame.DataID == 0 { return nil, errBulkFastPayloadInvalid } - return []bulkDedicatedBatchItem{{ + return []bulkDedicatedInboundBatch{{ + DataID: frame.DataID, + Items: []bulkDedicatedBatchItem{{ + Type: frame.Type, + Flags: frame.Flags, + Seq: frame.Seq, + Payload: frame.Payload, + }}, + }}, nil +} + +func decodeDedicatedBulkInboundPayload(plain []byte) (uint64, []bulkDedicatedBatchItem, error) { + batches, err := decodeDedicatedBulkInboundBatches(plain) + if err != nil { + return 0, nil, err + } + if len(batches) != 1 { + return 0, nil, errBulkFastPayloadInvalid + } + return batches[0].DataID, batches[0].Items, nil +} + +func decodeDedicatedBulkInboundItems(expectedDataID uint64, plain []byte) ([]bulkDedicatedBatchItem, error) { + dataID, items, err := decodeDedicatedBulkInboundPayload(plain) + if err != nil { + return nil, err + } + if expectedDataID == 0 || dataID != expectedDataID { + return nil, errBulkFastPayloadInvalid + } + return items, nil +} + +func hasBulkDedicatedMagic(payload []byte, magic string) bool { + return len(payload) >= 4 && + payload[0] == magic[0] && + payload[1] == magic[1] && + payload[2] == magic[2] && + payload[3] == magic[3] +} + +func walkDedicatedBulkInboundPayload(plain []byte, visit func(dataID uint64, item bulkDedicatedBatchItem) error) error { + if visit == nil { + return errBulkFastPayloadInvalid + } + if hasBulkDedicatedMagic(plain, bulkDedicatedSuperBatchMagic) { + return 
walkDedicatedBulkInboundSuperBatchPlain(plain, visit) + } + if hasBulkDedicatedMagic(plain, bulkDedicatedBatchMagic) { + return walkDedicatedBulkInboundBatchPlain(plain, visit) + } + frame, matched, err := decodeBulkFastFrame(plain) + if err != nil { + return err + } + if !matched || frame.DataID == 0 { + return errBulkFastPayloadInvalid + } + return visit(frame.DataID, bulkDedicatedBatchItem{ Type: frame.Type, Flags: frame.Flags, Seq: frame.Seq, Payload: frame.Payload, - }}, nil + }) +} + +func walkDedicatedBulkInboundBatchPlain(payload []byte, visit func(dataID uint64, item bulkDedicatedBatchItem) error) error { + if len(payload) < bulkDedicatedBatchHeaderLen { + return errBulkFastPayloadInvalid + } + if payload[4] != bulkDedicatedBatchVersion { + return errBulkFastPayloadInvalid + } + dataID := binary.BigEndian.Uint64(payload[8:16]) + count := int(binary.BigEndian.Uint32(payload[16:20])) + if dataID == 0 || count <= 0 { + return errBulkFastPayloadInvalid + } + offset := bulkDedicatedBatchHeaderLen + for i := 0; i < count; i++ { + if len(payload)-offset < bulkDedicatedBatchItemHeaderLen { + return errBulkFastPayloadInvalid + } + itemType := payload[offset] + switch itemType { + case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease: + default: + return errBulkFastPayloadInvalid + } + flags := payload[offset+1] + seq := binary.BigEndian.Uint64(payload[offset+4 : offset+12]) + dataLen := int(binary.BigEndian.Uint32(payload[offset+12 : offset+16])) + offset += bulkDedicatedBatchItemHeaderLen + if dataLen < 0 || len(payload)-offset < dataLen { + return errBulkFastPayloadInvalid + } + if err := visit(dataID, bulkDedicatedBatchItem{ + Type: itemType, + Flags: flags, + Seq: seq, + Payload: payload[offset : offset+dataLen], + }); err != nil { + return err + } + offset += dataLen + } + if offset != len(payload) { + return errBulkFastPayloadInvalid + } + return nil +} + +func 
walkDedicatedBulkInboundSuperBatchPlain(payload []byte, visit func(dataID uint64, item bulkDedicatedBatchItem) error) error { + if len(payload) < bulkDedicatedSuperBatchHeaderLen { + return errBulkFastPayloadInvalid + } + if payload[4] != bulkDedicatedSuperBatchVersion { + return errBulkFastPayloadInvalid + } + groupCount := int(binary.BigEndian.Uint32(payload[8:12])) + if groupCount <= 0 { + return errBulkFastPayloadInvalid + } + offset := bulkDedicatedSuperBatchHeaderLen + totalItems := 0 + for i := 0; i < groupCount; i++ { + if len(payload)-offset < bulkDedicatedSuperBatchGroupHeaderLen { + return errBulkFastPayloadInvalid + } + dataID := binary.BigEndian.Uint64(payload[offset : offset+8]) + count := int(binary.BigEndian.Uint32(payload[offset+8 : offset+12])) + offset += bulkDedicatedSuperBatchGroupHeaderLen + if dataID == 0 || count <= 0 { + return errBulkFastPayloadInvalid + } + totalItems += count + if totalItems > bulkDedicatedBatchMaxItems { + return errBulkFastPayloadInvalid + } + for j := 0; j < count; j++ { + if len(payload)-offset < bulkDedicatedBatchItemHeaderLen { + return errBulkFastPayloadInvalid + } + itemType := payload[offset] + switch itemType { + case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease: + default: + return errBulkFastPayloadInvalid + } + flags := payload[offset+1] + seq := binary.BigEndian.Uint64(payload[offset+4 : offset+12]) + dataLen := int(binary.BigEndian.Uint32(payload[offset+12 : offset+16])) + offset += bulkDedicatedBatchItemHeaderLen + if dataLen < 0 || len(payload)-offset < dataLen { + return errBulkFastPayloadInvalid + } + if err := visit(dataID, bulkDedicatedBatchItem{ + Type: itemType, + Flags: flags, + Seq: seq, + Payload: payload[offset : offset+dataLen], + }); err != nil { + return err + } + offset += dataLen + } + } + if offset != len(payload) { + return errBulkFastPayloadInvalid + } + return nil } func normalizeDedicatedBulkSendError(err error) error { @@ 
-643,13 +992,23 @@ func normalizeDedicatedBulkSendError(err error) error { } func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchItem) error { + return dispatchDedicatedBulkInboundItemWithRelease(bulk, item, nil) +} + +func dispatchDedicatedBulkInboundItemWithRelease(bulk *bulkHandle, item bulkDedicatedBatchItem, release func()) error { if bulk == nil { + if release != nil { + release() + } return io.ErrClosedPipe } switch item.Type { case bulkFastPayloadTypeData: - return bulk.pushOwnedChunkNoReset(item.Payload) + return bulk.pushOwnedChunkWithReleaseNoReset(item.Payload, release) case bulkFastPayloadTypeClose: + if release != nil { + release() + } if item.Flags&bulkFastPayloadFlagFullClose != 0 { bulk.markPeerClosed() return nil @@ -657,6 +1016,9 @@ func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchI bulk.markRemoteClosed() return nil case bulkFastPayloadTypeReset: + if release != nil { + release() + } resetErr := errBulkReset if len(item.Payload) > 0 { resetErr = bulkRemoteResetError(string(item.Payload)) @@ -664,6 +1026,9 @@ func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchI bulk.markReset(bulkResetError(resetErr)) return nil case bulkFastPayloadTypeRelease: + if release != nil { + release() + } bytes, chunks, err := decodeBulkDedicatedReleasePayload(item.Payload) if err != nil { return err @@ -671,6 +1036,9 @@ func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchI bulk.releaseOutboundWindow(bytes, chunks) return nil default: + if release != nil { + release() + } return errBulkFastPayloadInvalid } } diff --git a/bulk_dedicated_lane_sender.go b/bulk_dedicated_lane_sender.go new file mode 100644 index 0000000..b563bc6 --- /dev/null +++ b/bulk_dedicated_lane_sender.go @@ -0,0 +1,639 @@ +package notify + +import ( + "context" + "net" + "sync" + "sync/atomic" + "time" +) + +const bulkDedicatedLaneMicroBatchWait = 50 * time.Microsecond + +type 
bulkDedicatedLaneBatchRequest struct { + Ctx context.Context + DataID uint64 + Items []bulkDedicatedSendRequest + Deadline time.Time + wait bool + resultCh chan error + state bulkDedicatedRequestState + aborted atomic.Bool +} + +var bulkDedicatedLaneBatchRequestPool sync.Pool + +type bulkDedicatedLaneSender struct { + conn net.Conn + encode func([]bulkDedicatedOutboundBatch) ([]byte, func(), error) + fail func(error) + + reqCh chan *bulkDedicatedLaneBatchRequest + stopCh chan struct{} + doneCh chan struct{} + stopOnce sync.Once + flushMu sync.Mutex + queued atomic.Int64 + + errMu sync.Mutex + err error +} + +func newBulkDedicatedLaneSender(conn net.Conn, encode func([]bulkDedicatedOutboundBatch) ([]byte, func(), error), fail func(error)) *bulkDedicatedLaneSender { + sender := &bulkDedicatedLaneSender{ + conn: conn, + encode: encode, + fail: fail, + reqCh: make(chan *bulkDedicatedLaneBatchRequest, bulkDedicatedSendQueueSize), + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + go sender.run() + return sender +} + +func getBulkDedicatedLaneBatchRequest() *bulkDedicatedLaneBatchRequest { + if pooled, ok := bulkDedicatedLaneBatchRequestPool.Get().(*bulkDedicatedLaneBatchRequest); ok && pooled != nil { + pooled.reset() + return pooled + } + req := &bulkDedicatedLaneBatchRequest{ + resultCh: make(chan error, 1), + } + req.reset() + return req +} + +func (r *bulkDedicatedLaneBatchRequest) reset() { + if r == nil { + return + } + r.Ctx = nil + r.DataID = 0 + r.Deadline = time.Time{} + r.wait = false + r.Items = r.Items[:0] + r.state.value.Store(bulkDedicatedRequestQueued) + r.aborted.Store(false) + if r.resultCh != nil { + select { + case <-r.resultCh: + default: + } + } +} + +func (r *bulkDedicatedLaneBatchRequest) prepare(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool) { + if r == nil { + return + } + r.reset() + r.Ctx = ctx + r.DataID = dataID + r.wait = wait + if deadline, ok := ctx.Deadline(); ok { + r.Deadline = 
deadline + } + if cap(r.Items) < len(items) { + r.Items = make([]bulkDedicatedSendRequest, len(items)) + } else { + r.Items = r.Items[:len(items)] + } + copy(r.Items, items) +} + +func (r *bulkDedicatedLaneBatchRequest) recycle() { + if r == nil { + return + } + r.reset() + bulkDedicatedLaneBatchRequestPool.Put(r) +} + +func (s *bulkDedicatedLaneSender) submitData(ctx context.Context, dataID uint64, seq uint64, payload []byte) error { + if s == nil { + return errTransportDetached + } + items := []bulkDedicatedSendRequest{{ + Type: bulkFastPayloadTypeData, + Seq: seq, + Payload: append([]byte(nil), payload...), + }} + return s.submitBatch(ctx, dataID, items, false) +} + +func (s *bulkDedicatedLaneSender) submitWrite(ctx context.Context, dataID uint64, startSeq uint64, payload []byte, chunkSize int) (int, error) { + if s == nil { + return 0, errTransportDetached + } + if len(payload) == 0 { + return 0, nil + } + if chunkSize <= 0 { + chunkSize = defaultBulkChunkSize + } + written := 0 + seq := startSeq + for written < len(payload) { + var itemBuf [bulkDedicatedBatchMaxItems]bulkDedicatedSendRequest + items := itemBuf[:0] + batchBytes := bulkDedicatedBatchHeaderLen + start := written + for written < len(payload) && len(items) < bulkDedicatedBatchMaxItems { + end := written + chunkSize + if end > len(payload) { + end = len(payload) + } + itemLen := bulkDedicatedSendRequestLenFromPayloadLen(end - written) + if len(items) > 0 && batchBytes+itemLen > bulkDedicatedBatchMaxPlainBytes { + break + } + items = append(items, bulkDedicatedSendRequest{ + Type: bulkFastPayloadTypeData, + Seq: seq, + Payload: payload[written:end], + }) + batchBytes += itemLen + seq++ + written = end + } + if len(items) == 0 { + end := written + chunkSize + if end > len(payload) { + end = len(payload) + } + items = append(items, bulkDedicatedSendRequest{ + Type: bulkFastPayloadTypeData, + Seq: seq, + Payload: payload[written:end], + }) + seq++ + written = end + } + if err := s.submitWriteBatch(ctx, 
dataID, items); err != nil { + return start, err + } + start = written + } + return written, nil +} + +func (s *bulkDedicatedLaneSender) submitWriteBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest) error { + if s == nil { + return errTransportDetached + } + if len(items) == 0 { + return nil + } + if submitted, err := s.tryDirectSubmitBatch(ctx, dataID, items); submitted { + return err + } + queuedItems := copyBulkDedicatedSendRequests(items) + return s.submitBatch(ctx, dataID, queuedItems, true) +} + +func (s *bulkDedicatedLaneSender) submitControl(ctx context.Context, dataID uint64, frameType uint8, flags uint8, seq uint64, payload []byte) error { + if s == nil { + return errTransportDetached + } + items := []bulkDedicatedSendRequest{{ + Type: frameType, + Flags: flags, + Seq: seq, + }} + if len(payload) > 0 { + items[0].Payload = append([]byte(nil), payload...) + } + return s.submitBatch(ctx, dataID, items, true) +} + +func (s *bulkDedicatedLaneSender) submitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest, wait bool) error { + if s == nil { + return errTransportDetached + } + if ctx == nil { + ctx = context.Background() + } + if err := s.errSnapshot(); err != nil { + return err + } + req := getBulkDedicatedLaneBatchRequest() + req.prepare(ctx, dataID, items, wait) + s.queued.Add(1) + select { + case <-ctx.Done(): + s.queued.Add(-1) + req.recycle() + return normalizeStreamDeadlineError(ctx.Err()) + case <-s.stopCh: + s.queued.Add(-1) + req.recycle() + return s.stoppedErr() + case s.reqCh <- req: + if !wait { + return nil + } + return s.waitAck(req) + } +} + +func (s *bulkDedicatedLaneSender) tryDirectSubmitBatch(ctx context.Context, dataID uint64, items []bulkDedicatedSendRequest) (bool, error) { + if s == nil { + return true, errTransportDetached + } + if ctx == nil { + ctx = context.Background() + } + if len(items) == 0 { + return true, nil + } + if err := s.errSnapshot(); err != nil { + return true, err + } + 
select { + case <-ctx.Done(): + return true, normalizeStreamDeadlineError(ctx.Err()) + case <-s.stopCh: + return true, s.stoppedErr() + default: + } + if s.queued.Load() != 0 { + return false, nil + } + if !s.flushMu.TryLock() { + return false, nil + } + defer s.flushMu.Unlock() + if s.queued.Load() != 0 { + return false, nil + } + if err := s.errSnapshot(); err != nil { + return true, err + } + select { + case <-ctx.Done(): + return true, normalizeStreamDeadlineError(ctx.Err()) + case <-s.stopCh: + return true, s.stoppedErr() + default: + } + deadline, _ := ctx.Deadline() + if err := s.flush([]bulkDedicatedOutboundBatch{{ + DataID: dataID, + Items: items, + }}, deadline); err != nil { + err = normalizeDedicatedBulkSendError(err) + s.setErr(err) + s.failPending(err) + if s.fail != nil { + go s.fail(err) + } + return true, err + } + return true, nil +} + +func (s *bulkDedicatedLaneSender) waitAck(req *bulkDedicatedLaneBatchRequest) error { + if s == nil { + return errTransportDetached + } + if req == nil { + return errTransportDetached + } + ctx := req.Ctx + if ctx == nil { + ctx = context.Background() + } + select { + case err := <-req.resultCh: + req.recycle() + return normalizeDedicatedBulkSendError(err) + case <-ctx.Done(): + if req.tryCancel() { + req.aborted.Store(true) + return normalizeStreamDeadlineError(ctx.Err()) + } + err := <-req.resultCh + req.recycle() + return normalizeDedicatedBulkSendError(err) + } +} + +func (s *bulkDedicatedLaneSender) stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + s.setErr(errTransportDetached) + close(s.stopCh) + }) + <-s.doneCh +} + +func (s *bulkDedicatedLaneSender) run() { + defer close(s.doneCh) + var carry *bulkDedicatedLaneBatchRequest + for { + req, ok := s.nextRequest(carry) + carry = nil + if !ok { + return + } + if !req.tryStart() { + s.finishRequest(req, req.canceledErr()) + continue + } + if err := req.contextErr(); err != nil { + s.finishRequest(req, err) + continue + } + batchReqs := 
[]*bulkDedicatedLaneBatchRequest{req} + batches := []bulkDedicatedOutboundBatch{{ + DataID: req.DataID, + Items: req.Items, + }} + batchBytes := bulkDedicatedBatchesPlainLen(batches) + deadline := req.Deadline + s.flushMu.Lock() + err := s.errSnapshot() + if err == nil { + carry, err = s.collectBatchRequests(&batchReqs, &batches, &batchBytes, &deadline) + if err == nil { + err = s.flush(batches, deadline) + } + } + s.flushMu.Unlock() + if err != nil { + err = normalizeDedicatedBulkSendError(err) + s.setErr(err) + s.finishBatchRequests(batchReqs, err) + if carry != nil { + s.finishRequest(carry, err) + carry = nil + } + s.failPending(err) + if s.fail != nil { + go s.fail(err) + } + return + } + s.finishBatchRequests(batchReqs, nil) + } +} + +func appendBulkDedicatedLaneBatch(batches *[]bulkDedicatedOutboundBatch, req *bulkDedicatedLaneBatchRequest) { + if req == nil { + return + } + if len(*batches) > 0 && (*batches)[len(*batches)-1].DataID == req.DataID { + last := &(*batches)[len(*batches)-1] + last.Items = append(last.Items, req.Items...) 
+ return + } + *batches = append(*batches, bulkDedicatedOutboundBatch{ + DataID: req.DataID, + Items: req.Items, + }) +} + +func bulkDedicatedLaneBatchItemCount(batches []bulkDedicatedOutboundBatch) int { + total := 0 + for _, batch := range batches { + total += len(batch.Items) + } + return total +} + +func bulkDedicatedLaneNextBatchBytes(batches []bulkDedicatedOutboundBatch, req *bulkDedicatedLaneBatchRequest, currentBytes int) int { + reqBytes := bulkDedicatedSendRequestsLen(req.Items) + if len(batches) == 0 { + return bulkDedicatedBatchHeaderLen + reqBytes + } + last := batches[len(batches)-1] + if last.DataID == req.DataID { + return currentBytes + reqBytes + } + if len(batches) == 1 { + return bulkDedicatedSuperBatchHeaderLen + + bulkDedicatedSuperBatchGroupHeaderLen + bulkDedicatedSendRequestsLen(batches[0].Items) + + bulkDedicatedSuperBatchGroupHeaderLen + reqBytes + } + return currentBytes + bulkDedicatedSuperBatchGroupHeaderLen + reqBytes +} + +func (r *bulkDedicatedLaneBatchRequest) contextErr() error { + if r.Ctx == nil { + return nil + } + select { + case <-r.Ctx.Done(): + return normalizeStreamDeadlineError(r.Ctx.Err()) + default: + return nil + } +} + +func (r *bulkDedicatedLaneBatchRequest) tryStart() bool { + if r == nil { + return false + } + return r.state.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestStarted) +} + +func (r *bulkDedicatedLaneBatchRequest) tryCancel() bool { + if r == nil { + return false + } + return r.state.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestCanceled) +} + +func (r *bulkDedicatedLaneBatchRequest) canceledErr() error { + if r == nil { + return context.Canceled + } + if err := r.contextErr(); err != nil { + return err + } + return context.Canceled +} + +func (s *bulkDedicatedLaneSender) nextRequest(carry *bulkDedicatedLaneBatchRequest) (*bulkDedicatedLaneBatchRequest, bool) { + if carry != nil { + select { + case <-s.stopCh: + err := s.stoppedErr() + s.finishRequest(carry, 
err) + s.failPending(err) + return nil, false + default: + return carry, true + } + } + select { + case <-s.stopCh: + s.failPending(s.stoppedErr()) + return nil, false + case req := <-s.reqCh: + return req, true + } +} + +func (s *bulkDedicatedLaneSender) collectBatchRequests(batchReqs *[]*bulkDedicatedLaneBatchRequest, batches *[]bulkDedicatedOutboundBatch, batchBytes *int, deadline *time.Time) (*bulkDedicatedLaneBatchRequest, error) { + if s == nil || bulkDedicatedLaneBatchItemCount(*batches) >= bulkDedicatedBatchMaxItems || *batchBytes >= bulkDedicatedBatchMaxPlainBytes { + return nil, nil + } + wait := bulkDedicatedLaneMicroBatchWait + var ( + timer *time.Timer + timerCh <-chan time.Time + waited bool + ) + if wait > 0 { + timer = time.NewTimer(wait) + timerCh = timer.C + defer timer.Stop() + } + for { + if bulkDedicatedLaneBatchItemCount(*batches) >= bulkDedicatedBatchMaxItems || *batchBytes >= bulkDedicatedBatchMaxPlainBytes { + return nil, nil + } + var ( + req *bulkDedicatedLaneBatchRequest + ok bool + ) + select { + case <-s.stopCh: + return nil, s.stoppedErr() + case req = <-s.reqCh: + ok = true + default: + } + if !ok { + if waited || timerCh == nil { + return nil, nil + } + waited = true + select { + case <-s.stopCh: + return nil, s.stoppedErr() + case req = <-s.reqCh: + case <-timerCh: + return nil, nil + } + } + if err := req.contextErr(); err != nil { + if req.tryCancel() { + s.finishRequest(req, err) + continue + } + s.finishRequest(req, req.canceledErr()) + continue + } + nextItems := bulkDedicatedLaneBatchItemCount(*batches) + len(req.Items) + if nextItems > bulkDedicatedBatchMaxItems { + return req, nil + } + nextBytes := bulkDedicatedLaneNextBatchBytes(*batches, req, *batchBytes) + if nextBytes > bulkDedicatedBatchMaxPlainBytes { + return req, nil + } + if !req.tryStart() { + s.finishRequest(req, req.canceledErr()) + continue + } + if err := req.contextErr(); err != nil { + s.finishRequest(req, err) + continue + } + *batchReqs = 
append(*batchReqs, req) + appendBulkDedicatedLaneBatch(batches, req) + *batchBytes = nextBytes + *deadline = earlierDeadline(*deadline, req.Deadline) + } +} + +func earlierDeadline(current time.Time, next time.Time) time.Time { + if current.IsZero() { + return next + } + if next.IsZero() || current.Before(next) { + return current + } + return next +} + +func (s *bulkDedicatedLaneSender) flush(batches []bulkDedicatedOutboundBatch, deadline time.Time) error { + if s == nil || s.conn == nil { + return errTransportDetached + } + payload, release, err := s.encode(batches) + if err != nil { + return err + } + if release != nil { + defer release() + } + return writeBulkDedicatedRecordWithDeadline(s.conn, payload, deadline) +} + +func (s *bulkDedicatedLaneSender) finishRequest(req *bulkDedicatedLaneBatchRequest, err error) { + if s != nil { + s.queued.Add(-1) + } + if req == nil { + return + } + if req.wait && req.resultCh != nil { + if req.aborted.Load() { + req.recycle() + return + } + req.resultCh <- err + return + } + req.recycle() +} + +func (s *bulkDedicatedLaneSender) finishBatchRequests(reqs []*bulkDedicatedLaneBatchRequest, err error) { + for _, req := range reqs { + s.finishRequest(req, err) + } +} + +func (s *bulkDedicatedLaneSender) failPending(err error) { + for { + select { + case item := <-s.reqCh: + s.finishRequest(item, err) + default: + return + } + } +} + +func (s *bulkDedicatedLaneSender) setErr(err error) { + if s == nil || err == nil { + return + } + s.errMu.Lock() + if s.err == nil { + s.err = err + } + s.errMu.Unlock() +} + +func (s *bulkDedicatedLaneSender) errSnapshot() error { + if s == nil { + return errTransportDetached + } + s.errMu.Lock() + defer s.errMu.Unlock() + return s.err +} + +func (s *bulkDedicatedLaneSender) stoppedErr() error { + if err := s.errSnapshot(); err != nil { + return err + } + return errTransportDetached +} diff --git a/bulk_dedicated_lane_sender_test.go b/bulk_dedicated_lane_sender_test.go new file mode 100644 index 
0000000..b29d3a6 --- /dev/null +++ b/bulk_dedicated_lane_sender_test.go @@ -0,0 +1,145 @@ +package notify + +import ( + "bytes" + "context" + "testing" + "time" +) + +func TestBulkDedicatedLaneSenderCollectBatchRequestsBatchesAcrossDataIDs(t *testing.T) { + sender := &bulkDedicatedLaneSender{ + reqCh: make(chan *bulkDedicatedLaneBatchRequest, 3), + stopCh: make(chan struct{}), + } + + first := &bulkDedicatedLaneBatchRequest{ + Ctx: context.Background(), + DataID: 7, + Items: []bulkDedicatedSendRequest{{ + Type: bulkFastPayloadTypeData, + Seq: 1, + Payload: []byte("a"), + }}, + Deadline: time.Now().Add(2 * time.Second), + } + if !first.tryStart() { + t.Fatal("first request should start") + } + + second := &bulkDedicatedLaneBatchRequest{ + Ctx: context.Background(), + DataID: 7, + Items: []bulkDedicatedSendRequest{{ + Type: bulkFastPayloadTypeData, + Seq: 2, + Payload: []byte("b"), + }}, + Deadline: time.Now().Add(time.Second), + } + third := &bulkDedicatedLaneBatchRequest{ + Ctx: context.Background(), + DataID: 8, + Items: []bulkDedicatedSendRequest{{ + Type: bulkFastPayloadTypeData, + Seq: 3, + Payload: []byte("c"), + }}, + Deadline: time.Now().Add(3 * time.Second), + } + + sender.reqCh <- second + sender.reqCh <- third + + batchReqs := []*bulkDedicatedLaneBatchRequest{first} + batches := []bulkDedicatedOutboundBatch{{ + DataID: first.DataID, + Items: first.Items, + }} + batchBytes := bulkDedicatedBatchesPlainLen(batches) + deadline := first.Deadline + + carry, err := sender.collectBatchRequests(&batchReqs, &batches, &batchBytes, &deadline) + if err != nil { + t.Fatalf("collectBatchRequests error = %v", err) + } + if carry != nil { + t.Fatalf("carry request = %+v, want nil", carry) + } + if len(batchReqs) != 3 { + t.Fatalf("batched request count = %d, want 3", len(batchReqs)) + } + if len(batches) != 2 { + t.Fatalf("batched group count = %d, want 2", len(batches)) + } + if got, want := batches[0].DataID, uint64(7); got != want { + t.Fatalf("first batch dataID = %d, 
want %d", got, want) + } + if got, want := len(batches[0].Items), 2; got != want { + t.Fatalf("first batch item count = %d, want %d", got, want) + } + if got, want := batches[1].DataID, uint64(8); got != want { + t.Fatalf("second batch dataID = %d, want %d", got, want) + } + if got, want := len(batches[1].Items), 1; got != want { + t.Fatalf("second batch item count = %d, want %d", got, want) + } + if got, want := batches[0].Items[0].Seq, uint64(1); got != want { + t.Fatalf("first batch seq = %d, want %d", got, want) + } + if got, want := batches[0].Items[1].Seq, uint64(2); got != want { + t.Fatalf("second batch seq = %d, want %d", got, want) + } + if got, want := batches[1].Items[0].Seq, uint64(3); got != want { + t.Fatalf("third batch seq = %d, want %d", got, want) + } + if !deadline.Equal(second.Deadline) { + t.Fatalf("merged deadline = %v, want %v", deadline, second.Deadline) + } +} + +func TestDedicatedBulkSuperBatchRoundTrip(t *testing.T) { + batches := []bulkDedicatedOutboundBatch{ + { + DataID: 7, + Items: []bulkDedicatedSendRequest{ + {Type: bulkFastPayloadTypeData, Seq: 1, Payload: []byte("alpha")}, + {Type: bulkFastPayloadTypeData, Seq: 2, Payload: []byte("beta")}, + }, + }, + { + DataID: 8, + Items: []bulkDedicatedSendRequest{ + {Type: bulkFastPayloadTypeClose, Flags: bulkFastPayloadFlagFullClose, Seq: 3}, + {Type: bulkFastPayloadTypeRelease, Seq: 4, Payload: []byte{1, 2, 3, 4}}, + }, + }, + } + + plain, err := encodeBulkDedicatedBatchesPlain(batches) + if err != nil { + t.Fatalf("encodeBulkDedicatedBatchesPlain error = %v", err) + } + got, err := decodeDedicatedBulkInboundBatches(plain) + if err != nil { + t.Fatalf("decodeDedicatedBulkInboundBatches error = %v", err) + } + if len(got) != 2 { + t.Fatalf("decoded batch count = %d, want 2", len(got)) + } + if got[0].DataID != 7 || got[1].DataID != 8 { + t.Fatalf("decoded dataIDs = %d,%d, want 7,8", got[0].DataID, got[1].DataID) + } + if len(got[0].Items) != 2 || len(got[1].Items) != 2 { + t.Fatalf("decoded 
item counts = %d,%d, want 2,2", len(got[0].Items), len(got[1].Items)) + } + if !bytes.Equal(got[0].Items[0].Payload, []byte("alpha")) || !bytes.Equal(got[0].Items[1].Payload, []byte("beta")) { + t.Fatalf("decoded first batch payloads mismatch: %+v", got[0].Items) + } + if got[1].Items[0].Type != bulkFastPayloadTypeClose || got[1].Items[0].Flags != bulkFastPayloadFlagFullClose { + t.Fatalf("decoded close item mismatch: %+v", got[1].Items[0]) + } + if got[1].Items[1].Type != bulkFastPayloadTypeRelease || !bytes.Equal(got[1].Items[1].Payload, []byte{1, 2, 3, 4}) { + t.Fatalf("decoded release item mismatch: %+v", got[1].Items[1]) + } +} diff --git a/bulk_dedicated_sidecar.go b/bulk_dedicated_sidecar.go new file mode 100644 index 0000000..28d0ad5 --- /dev/null +++ b/bulk_dedicated_sidecar.go @@ -0,0 +1,451 @@ +package notify + +import ( + "context" + "net" + "sync" +) + +type bulkDedicatedSidecar struct { + laneID uint32 + conn net.Conn + closeOnce sync.Once + + senderMu sync.Mutex + sender *bulkDedicatedLaneSender +} + +type bulkDedicatedLane struct { + id uint32 + activeBulks int + sidecar *bulkDedicatedSidecar + attachFlight *bulkDedicatedAttachFlight +} + +type bulkDedicatedAttachFlight struct { + done chan struct{} + once sync.Once + err error +} + +func normalizeBulkDedicatedLaneID(laneID uint32) uint32 { + if laneID == 0 { + return 1 + } + return laneID +} + +func newBulkDedicatedSidecar(conn net.Conn, laneID uint32) *bulkDedicatedSidecar { + if conn == nil { + return nil + } + return &bulkDedicatedSidecar{ + laneID: normalizeBulkDedicatedLaneID(laneID), + conn: conn, + } +} + +func newBulkDedicatedAttachFlight() *bulkDedicatedAttachFlight { + return &bulkDedicatedAttachFlight{ + done: make(chan struct{}), + } +} + +func (f *bulkDedicatedAttachFlight) finish(err error) { + if f == nil { + return + } + f.once.Do(func() { + f.err = err + close(f.done) + }) +} + +func (f *bulkDedicatedAttachFlight) wait(ctx context.Context) error { + if f == nil { + return nil + } + 
if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-f.done: + return f.err + } +} + +func (s *bulkDedicatedSidecar) close() { + if s == nil { + return + } + s.closeOnce.Do(func() { + if sender := s.laneSenderSnapshot(); sender != nil { + sender.stop() + } + if s.conn != nil { + _ = s.conn.Close() + } + }) +} + +func (s *bulkDedicatedSidecar) laneSenderSnapshot() *bulkDedicatedLaneSender { + if s == nil { + return nil + } + s.senderMu.Lock() + defer s.senderMu.Unlock() + return s.sender +} + +func (s *bulkDedicatedSidecar) laneSenderWithFactory(factory func(net.Conn) *bulkDedicatedLaneSender) *bulkDedicatedLaneSender { + if s == nil || factory == nil { + return nil + } + s.senderMu.Lock() + defer s.senderMu.Unlock() + if s.sender != nil { + return s.sender + } + if s.conn == nil { + return nil + } + s.sender = factory(s.conn) + return s.sender +} + +func (c *ClientCommon) clientDedicatedSidecarSnapshot() *bulkDedicatedSidecar { + if c == nil { + return nil + } + c.bulkDedicatedSidecarMu.Lock() + defer c.bulkDedicatedSidecarMu.Unlock() + return firstClientDedicatedSidecarLocked(c.bulkDedicatedLanes) +} + +func firstClientDedicatedSidecarLocked(lanes map[uint32]*bulkDedicatedLane) *bulkDedicatedSidecar { + var ( + selected *bulkDedicatedSidecar + bestID uint32 + ) + for laneID, lane := range lanes { + if lane == nil || lane.sidecar == nil { + continue + } + if selected == nil || laneID < bestID { + selected = lane.sidecar + bestID = laneID + } + } + return selected +} + +func (c *ClientCommon) reserveBulkDedicatedLane() uint32 { + if c == nil { + return normalizeBulkDedicatedLaneID(0) + } + c.bulkDedicatedSidecarMu.Lock() + defer c.bulkDedicatedSidecarMu.Unlock() + if c.bulkDedicatedLanes == nil { + c.bulkDedicatedLanes = make(map[uint32]*bulkDedicatedLane) + } + limit := c.bulkDedicatedLaneLimitSnapshot() + var best *bulkDedicatedLane + for _, lane := range c.bulkDedicatedLanes { + if lane == nil { + continue + 
} + if best == nil || lane.activeBulks < best.activeBulks || (lane.activeBulks == best.activeBulks && lane.id < best.id) { + best = lane + } + } + if best == nil || ((limit <= 0 || len(c.bulkDedicatedLanes) < limit) && best.activeBulks > 0) { + c.bulkDedicatedNextLaneID++ + laneID := normalizeBulkDedicatedLaneID(c.bulkDedicatedNextLaneID) + best = &bulkDedicatedLane{id: laneID} + c.bulkDedicatedLanes[laneID] = best + } + best.activeBulks++ + return best.id +} + +func (c *ClientCommon) releaseBulkDedicatedLane(laneID uint32) { + if c == nil { + return + } + laneID = normalizeBulkDedicatedLaneID(laneID) + c.bulkDedicatedSidecarMu.Lock() + defer c.bulkDedicatedSidecarMu.Unlock() + lane := c.bulkDedicatedLanes[laneID] + if lane == nil { + return + } + if lane.activeBulks > 0 { + lane.activeBulks-- + } + if lane.activeBulks == 0 && lane.sidecar == nil && lane.attachFlight == nil { + delete(c.bulkDedicatedLanes, laneID) + } +} + +func (c *ClientCommon) clientDedicatedSidecarSnapshotForLane(laneID uint32) *bulkDedicatedSidecar { + if c == nil { + return nil + } + laneID = normalizeBulkDedicatedLaneID(laneID) + c.bulkDedicatedSidecarMu.Lock() + defer c.bulkDedicatedSidecarMu.Unlock() + if lane := c.bulkDedicatedLanes[laneID]; lane != nil { + return lane.sidecar + } + return nil +} + +func (c *ClientCommon) beginClientDedicatedSidecarAttach(laneID uint32) (*bulkDedicatedSidecar, *bulkDedicatedAttachFlight, bool) { + if c == nil { + return nil, nil, false + } + laneID = normalizeBulkDedicatedLaneID(laneID) + c.bulkDedicatedSidecarMu.Lock() + defer c.bulkDedicatedSidecarMu.Unlock() + if c.bulkDedicatedLanes == nil { + c.bulkDedicatedLanes = make(map[uint32]*bulkDedicatedLane) + } + lane := c.bulkDedicatedLanes[laneID] + if lane == nil { + lane = &bulkDedicatedLane{id: laneID} + c.bulkDedicatedLanes[laneID] = lane + } + if lane.sidecar != nil { + return lane.sidecar, nil, false + } + if lane.attachFlight != nil { + return nil, lane.attachFlight, false + } + flight := 
newBulkDedicatedAttachFlight() + lane.attachFlight = flight + return nil, flight, true +} + +func (c *ClientCommon) finishClientDedicatedSidecarAttach(laneID uint32, flight *bulkDedicatedAttachFlight, err error) { + if c == nil || flight == nil { + return + } + laneID = normalizeBulkDedicatedLaneID(laneID) + c.bulkDedicatedSidecarMu.Lock() + if lane := c.bulkDedicatedLanes[laneID]; lane != nil && lane.attachFlight == flight { + lane.attachFlight = nil + if lane.activeBulks == 0 && lane.sidecar == nil { + delete(c.bulkDedicatedLanes, laneID) + } + } + c.bulkDedicatedSidecarMu.Unlock() + flight.finish(err) +} + +func (c *ClientCommon) installClientDedicatedSidecar(laneID uint32, sidecar *bulkDedicatedSidecar) (*bulkDedicatedSidecar, bool) { + if c == nil || sidecar == nil { + return nil, false + } + laneID = normalizeBulkDedicatedLaneID(laneID) + c.bulkDedicatedSidecarMu.Lock() + defer c.bulkDedicatedSidecarMu.Unlock() + if c.bulkDedicatedLanes == nil { + c.bulkDedicatedLanes = make(map[uint32]*bulkDedicatedLane) + } + lane := c.bulkDedicatedLanes[laneID] + if lane == nil { + lane = &bulkDedicatedLane{id: laneID} + c.bulkDedicatedLanes[laneID] = lane + } + if lane.sidecar != nil { + return lane.sidecar, false + } + lane.sidecar = sidecar + return sidecar, true +} + +func (c *ClientCommon) clearClientDedicatedSidecar(laneID uint32, sidecar *bulkDedicatedSidecar) bool { + if c == nil || sidecar == nil { + return false + } + laneID = normalizeBulkDedicatedLaneID(laneID) + c.bulkDedicatedSidecarMu.Lock() + defer c.bulkDedicatedSidecarMu.Unlock() + lane := c.bulkDedicatedLanes[laneID] + if lane == nil || lane.sidecar != sidecar { + return false + } + lane.sidecar = nil + if lane.activeBulks == 0 && lane.attachFlight == nil { + delete(c.bulkDedicatedLanes, laneID) + } + return true +} + +func (c *ClientCommon) closeClientDedicatedSidecar() { + if c == nil { + return + } + c.bulkDedicatedSidecarMu.Lock() + lanes := c.bulkDedicatedLanes + c.bulkDedicatedLanes = 
make(map[uint32]*bulkDedicatedLane) + c.bulkDedicatedSidecarMu.Unlock() + for _, lane := range lanes { + if lane == nil { + continue + } + if lane.sidecar != nil { + lane.sidecar.close() + } + if lane.attachFlight != nil { + lane.attachFlight.finish(errServiceShutdown) + } + } +} + +func (c *ClientCommon) handleClientDedicatedSidecarFailure(sidecar *bulkDedicatedSidecar, err error) { + if c == nil || sidecar == nil { + return + } + if !c.clearClientDedicatedSidecar(sidecar.laneID, sidecar) { + return + } + runtime := c.getBulkRuntime() + if runtime != nil && sidecar.conn != nil { + runtime.handleDedicatedReadErrorByConn(clientFileScope(), sidecar.conn, err) + } + sidecar.close() +} + +func (s *ServerCommon) serverDedicatedSidecarSnapshot(logical *LogicalConn) *bulkDedicatedSidecar { + if s == nil || logical == nil { + return nil + } + s.bulkDedicatedSidecarMu.Lock() + defer s.bulkDedicatedSidecarMu.Unlock() + return firstServerDedicatedSidecarLocked(s.bulkDedicatedSidecars[logical]) +} + +func firstServerDedicatedSidecarLocked(lanes map[uint32]*bulkDedicatedSidecar) *bulkDedicatedSidecar { + var ( + selected *bulkDedicatedSidecar + bestID uint32 + ) + for laneID, sidecar := range lanes { + if sidecar == nil { + continue + } + if selected == nil || laneID < bestID { + selected = sidecar + bestID = laneID + } + } + return selected +} + +func (s *ServerCommon) serverDedicatedSidecarSnapshotForLane(logical *LogicalConn, laneID uint32) *bulkDedicatedSidecar { + if s == nil || logical == nil { + return nil + } + laneID = normalizeBulkDedicatedLaneID(laneID) + s.bulkDedicatedSidecarMu.Lock() + defer s.bulkDedicatedSidecarMu.Unlock() + if lanes := s.bulkDedicatedSidecars[logical]; lanes != nil { + return lanes[laneID] + } + return nil +} + +func (s *ServerCommon) installServerDedicatedSidecar(logical *LogicalConn, laneID uint32, sidecar *bulkDedicatedSidecar) *bulkDedicatedSidecar { + if s == nil || logical == nil || sidecar == nil { + return nil + } + laneID = 
normalizeBulkDedicatedLaneID(laneID) + s.bulkDedicatedSidecarMu.Lock() + defer s.bulkDedicatedSidecarMu.Unlock() + lanes := s.bulkDedicatedSidecars[logical] + if lanes == nil { + lanes = make(map[uint32]*bulkDedicatedSidecar) + s.bulkDedicatedSidecars[logical] = lanes + } + prev := lanes[laneID] + lanes[laneID] = sidecar + return prev +} + +func (s *ServerCommon) clearServerDedicatedSidecar(logical *LogicalConn, laneID uint32, sidecar *bulkDedicatedSidecar) bool { + if s == nil || logical == nil || sidecar == nil { + return false + } + laneID = normalizeBulkDedicatedLaneID(laneID) + s.bulkDedicatedSidecarMu.Lock() + defer s.bulkDedicatedSidecarMu.Unlock() + lanes := s.bulkDedicatedSidecars[logical] + if lanes == nil || lanes[laneID] != sidecar { + return false + } + delete(lanes, laneID) + if len(lanes) == 0 { + delete(s.bulkDedicatedSidecars, logical) + } + return true +} + +func (s *ServerCommon) closeServerDedicatedSidecar(logical *LogicalConn) { + if s == nil || logical == nil { + return + } + s.bulkDedicatedSidecarMu.Lock() + lanes := s.bulkDedicatedSidecars[logical] + delete(s.bulkDedicatedSidecars, logical) + s.bulkDedicatedSidecarMu.Unlock() + for _, sidecar := range lanes { + if sidecar != nil { + sidecar.close() + } + } +} + +func (s *ServerCommon) closeAllServerDedicatedSidecars() { + if s == nil { + return + } + s.bulkDedicatedSidecarMu.Lock() + lanesByLogical := s.bulkDedicatedSidecars + s.bulkDedicatedSidecars = make(map[*LogicalConn]map[uint32]*bulkDedicatedSidecar) + s.bulkDedicatedSidecarMu.Unlock() + for _, lanes := range lanesByLogical { + for _, sidecar := range lanes { + if sidecar != nil { + sidecar.close() + } + } + } +} + +func (s *ServerCommon) handleServerDedicatedSidecarFailure(logical *LogicalConn, sidecar *bulkDedicatedSidecar, err error) { + if s == nil || logical == nil || sidecar == nil { + return + } + if !s.clearServerDedicatedSidecar(logical, sidecar.laneID, sidecar) { + return + } + runtime := s.getBulkRuntime() + if runtime != 
nil && sidecar.conn != nil { + runtime.handleDedicatedReadErrorByConn(serverFileScope(logical), sidecar.conn, err) + } + sidecar.close() +} + +func (s *ServerCommon) attachServerDedicatedSidecarIfExists(logical *LogicalConn, bulk *bulkHandle) { + if s == nil || logical == nil || bulk == nil || !bulk.Dedicated() { + return + } + sidecar := s.serverDedicatedSidecarSnapshotForLane(logical, bulk.dedicatedLaneIDSnapshot()) + if sidecar == nil || sidecar.conn == nil { + return + } + _ = bulk.attachDedicatedConnShared(sidecar.conn) +} diff --git a/bulk_dispatcher.go b/bulk_dispatcher.go index 0bfff5b..f3e2aa3 100644 --- a/bulk_dispatcher.go +++ b/bulk_dispatcher.go @@ -12,6 +12,10 @@ import ( const bulkDispatchRejectTimeout = 300 * time.Millisecond func (c *ClientCommon) dispatchFastBulkFrame(frame bulkFastFrame) { + c.dispatchFastBulkFrameWithOwner(frame, nil) +} + +func (c *ClientCommon) dispatchFastBulkFrameWithOwner(frame bulkFastFrame, owner *bulkReadPayloadOwner) { if frame.DataID == 0 { return } @@ -38,7 +42,13 @@ func (c *ClientCommon) dispatchFastBulkFrame(frame bulkFastFrame) { } switch frame.Type { case bulkFastPayloadTypeData: - if err := bulk.pushOwnedChunk(frame.Payload); err != nil { + var err error + if owner != nil { + err = bulk.pushChunkWithOwnershipOptionsAndRelease(frame.Payload, true, true, owner.retainChunk()) + } else { + err = bulk.pushOwnedChunk(frame.Payload) + } + if err != nil { if c.showError || c.debugMode { fmt.Println("client bulk push chunk error", err) } @@ -58,14 +68,28 @@ func (c *ClientCommon) dispatchFastBulkFrame(frame bulkFastFrame) { resetErr = bulkRemoteResetError(string(frame.Payload)) } bulk.markReset(bulkResetError(resetErr)) + case bulkFastPayloadTypeRelease: + bytes, chunks, err := decodeBulkDedicatedReleasePayload(frame.Payload) + if err != nil { + if c.showError || c.debugMode { + fmt.Println("client bulk release decode error", err) + } + c.bestEffortRejectInboundBulkData(bulk.ID(), frame.DataID, err.Error()) + return + } 
+ bulk.releaseOutboundWindow(bytes, chunks) } } func (c *ClientCommon) dispatchFastBulkData(frame bulkFastDataFrame) { - c.dispatchFastBulkFrame(frame) + c.dispatchFastBulkFrameWithOwner(frame, nil) } func (s *ServerCommon) dispatchFastBulkFrame(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastFrame) { + s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, nil) +} + +func (s *ServerCommon) dispatchFastBulkFrameWithOwner(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastFrame, owner *bulkReadPayloadOwner) { if logical == nil || frame.DataID == 0 { return } @@ -91,7 +115,13 @@ func (s *ServerCommon) dispatchFastBulkFrame(logical *LogicalConn, transport *Tr } switch frame.Type { case bulkFastPayloadTypeData: - if err := bulk.pushOwnedChunk(frame.Payload); err != nil { + var err error + if owner != nil { + err = bulk.pushChunkWithOwnershipOptionsAndRelease(frame.Payload, true, true, owner.retainChunk()) + } else { + err = bulk.pushOwnedChunk(frame.Payload) + } + if err != nil { if s.showError || s.debugMode { fmt.Println("server bulk push chunk error", err) } @@ -111,11 +141,82 @@ func (s *ServerCommon) dispatchFastBulkFrame(logical *LogicalConn, transport *Tr resetErr = bulkRemoteResetError(string(frame.Payload)) } bulk.markReset(bulkResetError(resetErr)) + case bulkFastPayloadTypeRelease: + bytes, chunks, err := decodeBulkDedicatedReleasePayload(frame.Payload) + if err != nil { + if s.showError || s.debugMode { + fmt.Println("server bulk release decode error", err) + } + s.bestEffortRejectInboundBulkData(logical, transport, conn, bulk.ID(), frame.DataID, err.Error()) + return + } + bulk.releaseOutboundWindow(bytes, chunks) } } func (s *ServerCommon) dispatchFastBulkData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastDataFrame) { - s.dispatchFastBulkFrame(logical, transport, conn, frame) + s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, nil) +} + +func 
(c *ClientCommon) tryDispatchBorrowedBulkTransportPayload(payload []byte) bool { + if c == nil || len(payload) == 0 { + return false + } + plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(c.modernPSKRuntime, c.msgDe, c.SecretKey, payload) + if err != nil { + if c.showError || c.debugMode { + fmt.Println("client decode transport payload error", err) + } + return true + } + owner := newBulkReadPayloadOwner(plainRelease) + matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error { + c.dispatchFastBulkFrameWithOwner(frame, owner) + return nil + }) + if owner != nil { + owner.done() + } + if !matched { + return false + } + if walkErr != nil && (c.showError || c.debugMode) { + fmt.Println("client decode bulk fast payload error", walkErr) + } + return true +} + +func (s *ServerCommon) tryDispatchBorrowedBulkTransportPayload(source interface{}, payload []byte) bool { + if s == nil || len(payload) == 0 { + return false + } + logical, transport := s.resolveInboundSource(source) + if logical == nil { + return false + } + plain, plainRelease, err := decryptTransportPayloadCodecOwnedPooled(logical.modernPSKRuntimeSnapshot(), logical.msgDeSnapshot(), logical.secretKeySnapshot(), payload) + if err != nil { + if s.showError || s.debugMode { + fmt.Println("server decode transport payload error", err) + } + return true + } + conn := serverInboundConn(source) + owner := newBulkReadPayloadOwner(plainRelease) + matched, walkErr := walkBulkFastFrames(plain, func(frame bulkFastFrame) error { + s.dispatchFastBulkFrameWithOwner(logical, transport, conn, frame, owner) + return nil + }) + if owner != nil { + owner.done() + } + if !matched { + return false + } + if walkErr != nil && (s.showError || s.debugMode) { + fmt.Println("server decode bulk fast payload error", walkErr) + } + return true } func (c *ClientCommon) bestEffortRejectInboundBulkData(bulkID string, dataID uint64, message string) { diff --git a/bulk_e2e_benchmark_test.go 
b/bulk_e2e_benchmark_test.go index 7ee92b3..c7ea4d3 100644 --- a/bulk_e2e_benchmark_test.go +++ b/bulk_e2e_benchmark_test.go @@ -308,7 +308,11 @@ func newBulkBenchmarkClient(tb testing.TB, network string, server bulkBenchmarkS tb.Fatalf("UseSignalReliabilityClient failed: %v", err) } } - if err := client.Connect(network, server.addr); err != nil { + dialAddr := server.addr + if network == "tcp" { + dialAddr = benchmarkTCPDialAddr(tb, dialAddr) + } + if err := client.Connect(network, dialAddr); err != nil { tb.Fatalf("client Connect failed: %v", err) } tb.Cleanup(func() { @@ -333,6 +337,9 @@ func bulkBenchmarkListenAddr(tb testing.TB, network string) string { case "unix": return filepath.Join(tb.TempDir(), "notify-bulk.sock") case "udp", "tcp": + if network == "tcp" { + return benchmarkTCPListenAddr(tb) + } return "127.0.0.1:0" default: tb.Fatalf("unsupported benchmark network %q", network) diff --git a/bulk_fastpath.go b/bulk_fastpath.go index 482ce0d..f64175c 100644 --- a/bulk_fastpath.go +++ b/bulk_fastpath.go @@ -120,32 +120,145 @@ func decodeBulkFastDataFrame(payload []byte) (bulkFastDataFrame, bool, error) { } func (c *ClientCommon) encodeFastBulkDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) { - if c != nil && c.fastBulkEncode != nil { - return c.fastBulkEncode(c.SecretKey, dataID, seq, chunk) - } - scratch := getBulkFastFrameScratch(len(chunk)) - defer putBulkFastFrameScratch(scratch) - frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)] - if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil { - return nil, err - } - copy(frame[bulkFastPayloadHeaderLen:], chunk) - return c.encryptTransportPayload(frame) + return c.encodeBulkFastPayload(bulkFastFrame{ + Type: bulkFastPayloadTypeData, + DataID: dataID, + Seq: seq, + Payload: chunk, + }) } -func (c *ClientCommon) sendFastBulkData(ctx context.Context, dataID uint64, seq uint64, chunk []byte) error { +func (c *ClientCommon) encodeBulkFastPayload(frame 
bulkFastFrame) ([]byte, error) { + if c == nil { + return nil, errBulkClientNil + } + if c.fastPlainEncode != nil { + return encodeBulkFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame) + } + plain, err := encodeBulkFastFramePayload(frame) + if err != nil { + return nil, err + } + return c.encryptTransportPayload(plain) +} + +func (c *ClientCommon) encodeBulkFastBatchPayload(frames []bulkFastFrame) ([]byte, error) { + if c == nil { + return nil, errBulkClientNil + } + if c.fastPlainEncode != nil { + return encodeBulkFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames) + } + plain, err := encodeBulkFastBatchPlain(frames) + if err != nil { + return nil, err + } + return c.encryptTransportPayload(plain) +} + +func (c *ClientCommon) encodeBulkFastPayloadPooled(frame bulkFastFrame) ([]byte, func(), error) { + if c == nil { + return nil, nil, errBulkClientNil + } + if runtime := c.modernPSKRuntime; runtime != nil { + return encodeBulkFastFramePayloadPooled(runtime, frame) + } + if c.fastPlainEncode != nil { + payload, err := encodeBulkFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame) + return payload, nil, err + } + plain, err := encodeBulkFastFramePayload(frame) + if err != nil { + return nil, nil, err + } + payload, err := c.encryptTransportPayload(plain) + return payload, nil, err +} + +func (c *ClientCommon) encodeBulkFastBatchPayloadPooled(frames []bulkFastFrame) ([]byte, func(), error) { + if c == nil { + return nil, nil, errBulkClientNil + } + if runtime := c.modernPSKRuntime; runtime != nil { + return encodeBulkFastBatchPayloadPooled(runtime, frames) + } + if c.fastPlainEncode != nil { + payload, err := encodeBulkFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames) + return payload, nil, err + } + plain, err := encodeBulkFastBatchPlain(frames) + if err != nil { + return nil, nil, err + } + payload, err := c.encryptTransportPayload(plain) + return payload, nil, err +} + +func (c *ClientCommon) sendFastBulkData(ctx 
context.Context, dataID uint64, seq uint64, chunk []byte, fastPathVersion uint8) error { + binding := c.clientTransportBindingSnapshot() + if binding == nil { + return net.ErrClosed + } + if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil { + return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk) + } payload, err := c.encodeFastBulkDataPayload(dataID, seq, chunk) if err != nil { return err } + return c.writePayloadToTransport(payload) +} + +func (c *ClientCommon) sendFastBulkWrite(ctx context.Context, dataID uint64, startSeq uint64, chunkSize int, fastPathVersion uint8, payload []byte, payloadOwned bool) (int, error) { + if len(payload) == 0 { + return 0, nil + } + binding := c.clientTransportBindingSnapshot() + if binding == nil { + return 0, net.ErrClosed + } + if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil { + return sender.submitWrite(ctx, dataID, startSeq, fastPathVersion, payload, chunkSize, payloadOwned) + } + if chunkSize <= 0 { + chunkSize = defaultBulkChunkSize + } + written := 0 + seq := startSeq + for written < len(payload) { + end := written + chunkSize + if end > len(payload) { + end = len(payload) + } + if err := c.sendFastBulkData(ctx, dataID, seq, payload[written:end], fastPathVersion); err != nil { + return written, err + } + seq++ + written = end + } + return written, nil +} + +func (c *ClientCommon) sendFastBulkControl(ctx context.Context, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error { + frame := bulkFastFrame{ + Type: frameType, + Flags: flags, + DataID: dataID, + Seq: seq, + Payload: payload, + } binding := c.clientTransportBindingSnapshot() if binding == nil { return net.ErrClosed } - if sender := binding.bulkBatchSenderSnapshot(); sender != nil { - return sender.submit(ctx, payload) + if sender := binding.clientBulkBatchSenderSnapshot(c); sender != nil { + return sender.submitControl(ctx, frameType, flags, dataID, seq, 
fastPathVersion, payload) } - return c.writePayloadToTransport(payload) + encoded, err := c.encodeBulkFastPayload(frame) + if err != nil { + return err + } + return c.writePayloadToTransport(encoded) } func (c *ClientCommon) encodeBulkFastControlPayload(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) { @@ -157,22 +270,81 @@ func (c *ClientCommon) encodeBulkFastControlPayload(frameType uint8, flags uint8 } func (s *ServerCommon) encodeFastBulkDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) { - if logical != nil { - if fastBulkEncode := logical.fastBulkEncodeSnapshot(); fastBulkEncode != nil { - return fastBulkEncode(logical.secretKeySnapshot(), dataID, seq, chunk) - } - } - scratch := getBulkFastFrameScratch(len(chunk)) - defer putBulkFastFrameScratch(scratch) - frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)] - if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil { - return nil, err - } - copy(frame[bulkFastPayloadHeaderLen:], chunk) - return s.encryptTransportPayloadLogical(logical, frame) + return s.encodeBulkFastPayloadLogical(logical, bulkFastFrame{ + Type: bulkFastPayloadTypeData, + DataID: dataID, + Seq: seq, + Payload: chunk, + }) } -func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte) error { +func (s *ServerCommon) encodeBulkFastPayloadLogical(logical *LogicalConn, frame bulkFastFrame) ([]byte, error) { + if logical == nil { + return nil, errTransportDetached + } + if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { + return encodeBulkFastFramePayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frame) + } + plain, err := encodeBulkFastFramePayload(frame) + if err != nil { + return nil, err + } + return s.encryptTransportPayloadLogical(logical, plain) +} + +func (s *ServerCommon) 
encodeBulkFastBatchPayloadLogical(logical *LogicalConn, frames []bulkFastFrame) ([]byte, error) { + if logical == nil { + return nil, errTransportDetached + } + if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { + return encodeBulkFastBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frames) + } + plain, err := encodeBulkFastBatchPlain(frames) + if err != nil { + return nil, err + } + return s.encryptTransportPayloadLogical(logical, plain) +} + +func (s *ServerCommon) encodeBulkFastPayloadLogicalPooled(logical *LogicalConn, frame bulkFastFrame) ([]byte, func(), error) { + if logical == nil { + return nil, nil, errTransportDetached + } + if runtime := logical.modernPSKRuntimeSnapshot(); runtime != nil { + return encodeBulkFastFramePayloadPooled(runtime, frame) + } + if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { + payload, err := encodeBulkFastFramePayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frame) + return payload, nil, err + } + plain, err := encodeBulkFastFramePayload(frame) + if err != nil { + return nil, nil, err + } + payload, err := s.encryptTransportPayloadLogical(logical, plain) + return payload, nil, err +} + +func (s *ServerCommon) encodeBulkFastBatchPayloadLogicalPooled(logical *LogicalConn, frames []bulkFastFrame) ([]byte, func(), error) { + if logical == nil { + return nil, nil, errTransportDetached + } + if runtime := logical.modernPSKRuntimeSnapshot(); runtime != nil { + return encodeBulkFastBatchPayloadPooled(runtime, frames) + } + if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { + payload, err := encodeBulkFastBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frames) + return payload, nil, err + } + plain, err := encodeBulkFastBatchPlain(frames) + if err != nil { + return nil, nil, err + } + payload, err := s.encryptTransportPayloadLogical(logical, plain) + return payload, nil, err +} + +func (s *ServerCommon) 
sendFastBulkDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte, fastPathVersion uint8) error { if err := s.ensureServerTransportSendReady(transport); err != nil { return err } @@ -182,18 +354,87 @@ func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *L if logical == nil { return errTransportDetached } + if binding := logical.transportBindingSnapshot(); binding != nil { + if binding.queueSnapshot() != nil { + if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil { + return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk) + } + } + } payload, err := s.encodeFastBulkDataPayloadLogical(logical, dataID, seq, chunk) if err != nil { return err } + return s.writeEnvelopePayload(logical, transport, nil, payload) +} + +func (s *ServerCommon) sendFastBulkWriteTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, startSeq uint64, chunkSize int, fastPathVersion uint8, payload []byte, payloadOwned bool) (int, error) { + if len(payload) == 0 { + return 0, nil + } + if err := s.ensureServerTransportSendReady(transport); err != nil { + return 0, err + } + if logical == nil && transport != nil { + logical = transport.logicalConnSnapshot() + } + if logical == nil { + return 0, errTransportDetached + } if binding := logical.transportBindingSnapshot(); binding != nil { if binding.queueSnapshot() != nil { - if sender := binding.bulkBatchSenderSnapshot(); sender != nil { - return sender.submit(ctx, payload) + if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil { + return sender.submitWrite(ctx, dataID, startSeq, fastPathVersion, payload, chunkSize, payloadOwned) } } } - return s.writeEnvelopePayload(logical, transport, nil, payload) + if chunkSize <= 0 { + chunkSize = defaultBulkChunkSize + } + written := 0 + seq := startSeq + for written < len(payload) { + end := written + chunkSize + if end > 
len(payload) { + end = len(payload) + } + if err := s.sendFastBulkDataTransport(ctx, logical, transport, dataID, seq, payload[written:end], fastPathVersion); err != nil { + return written, err + } + seq++ + written = end + } + return written, nil +} + +func (s *ServerCommon) sendFastBulkControlTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, frameType uint8, flags uint8, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error { + if err := s.ensureServerTransportSendReady(transport); err != nil { + return err + } + if logical == nil && transport != nil { + logical = transport.logicalConnSnapshot() + } + if logical == nil { + return errTransportDetached + } + if binding := logical.transportBindingSnapshot(); binding != nil { + if binding.queueSnapshot() != nil { + if sender := binding.serverBulkBatchSenderSnapshot(logical); sender != nil { + return sender.submitControl(ctx, frameType, flags, dataID, seq, fastPathVersion, payload) + } + } + } + encoded, err := s.encodeBulkFastPayloadLogical(logical, bulkFastFrame{ + Type: frameType, + Flags: flags, + DataID: dataID, + Seq: seq, + Payload: payload, + }) + if err != nil { + return err + } + return s.writeEnvelopePayload(logical, transport, nil, encoded) } func (s *ServerCommon) encodeBulkFastControlPayloadLogical(logical *LogicalConn, frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) { @@ -224,18 +465,22 @@ func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time. 
if err != nil { return err } - if frame, matched, err := decodeBulkFastFrame(plain); matched { + if frames, matched, err := decodeBulkFastFrames(plain); matched { if err != nil { return err } - c.dispatchFastBulkFrame(frame) + for _, frame := range frames { + c.dispatchFastBulkFrame(frame) + } return nil } - if frame, matched, err := decodeStreamFastDataFrame(plain); matched { + if frames, matched, err := decodeStreamFastDataFrames(plain); matched { if err != nil { return err } - c.dispatchFastStreamData(frame) + for _, frame := range frames { + c.dispatchFastStreamData(frame) + } return nil } env, err := c.decodeEnvelopePlain(plain) @@ -257,18 +502,22 @@ func (s *ServerCommon) dispatchInboundTransportPayload(logical *LogicalConn, tra if err != nil { return err } - if frame, matched, err := decodeBulkFastFrame(plain); matched { + if frames, matched, err := decodeBulkFastFrames(plain); matched { if err != nil { return err } - s.dispatchFastBulkFrame(logical, transport, conn, frame) + for _, frame := range frames { + s.dispatchFastBulkFrame(logical, transport, conn, frame) + } return nil } - if frame, matched, err := decodeStreamFastDataFrame(plain); matched { + if frames, matched, err := decodeStreamFastDataFrames(plain); matched { if err != nil { return err } - s.dispatchFastStreamData(logical, transport, conn, frame) + for _, frame := range frames { + s.dispatchFastStreamData(logical, transport, conn, frame) + } return nil } env, err := s.decodeEnvelopePlain(plain) diff --git a/bulk_runtime.go b/bulk_runtime.go index 6ad1c6d..9ed8bed 100644 --- a/bulk_runtime.go +++ b/bulk_runtime.go @@ -2,7 +2,7 @@ package notify import ( "fmt" - "strconv" + "net" "strings" "sync" "sync/atomic" @@ -16,14 +16,14 @@ type bulkRuntime struct { mu sync.RWMutex handler func(BulkAcceptInfo) error bulks map[string]*bulkHandle - data map[string]*bulkHandle + data map[string]map[uint64]*bulkHandle } func newBulkRuntime(rolePrefix string) *bulkRuntime { return &bulkRuntime{ rolePrefix: 
rolePrefix, bulks: make(map[string]*bulkHandle), - data: make(map[string]*bulkHandle), + data: make(map[string]map[uint64]*bulkHandle), } } @@ -66,8 +66,8 @@ func (r *bulkRuntime) register(scope string, bulk *bulkHandle) error { if bulk == nil || bulk.id == "" { return errBulkIDEmpty } + scope = normalizeFileScope(scope) key := bulkRuntimeKey(scope, bulk.id) - dataKey := bulkRuntimeDataKey(scope, bulk.dataID) r.mu.Lock() defer r.mu.Unlock() if _, ok := r.bulks[key]; ok { @@ -76,11 +76,16 @@ func (r *bulkRuntime) register(scope string, bulk *bulkHandle) error { if bulk.dataID == 0 { return errBulkDataIDEmpty } - if _, ok := r.data[dataKey]; ok { + dataScope := r.data[scope] + if dataScope == nil { + dataScope = make(map[uint64]*bulkHandle) + r.data[scope] = dataScope + } + if _, ok := dataScope[bulk.dataID]; ok { return errBulkAlreadyExists } r.bulks[key] = bulk - r.data[dataKey] = bulk + dataScope[bulk.dataID] = bulk return nil } @@ -99,10 +104,14 @@ func (r *bulkRuntime) lookupByDataID(scope string, dataID uint64) (*bulkHandle, if r == nil || dataID == 0 { return nil, false } - key := bulkRuntimeDataKey(scope, dataID) + scope = normalizeFileScope(scope) r.mu.RLock() defer r.mu.RUnlock() - bulk, ok := r.data[key] + dataScope := r.data[scope] + if dataScope == nil { + return nil, false + } + bulk, ok := dataScope[dataID] return bulk, ok } @@ -110,11 +119,17 @@ func (r *bulkRuntime) remove(scope string, bulkID string) { if r == nil || bulkID == "" { return } + scope = normalizeFileScope(scope) key := bulkRuntimeKey(scope, bulkID) r.mu.Lock() defer r.mu.Unlock() if bulk := r.bulks[key]; bulk != nil && bulk.dataID != 0 { - delete(r.data, bulkRuntimeDataKey(scope, bulk.dataID)) + if dataScope := r.data[scope]; dataScope != nil { + delete(dataScope, bulk.dataID) + if len(dataScope) == 0 { + delete(r.data, scope) + } + } } delete(r.bulks, key) } @@ -149,6 +164,140 @@ func (r *bulkRuntime) closeMatching(match func(string) bool, err error) { } } +func (r *bulkRuntime) 
resetDedicatedByConn(scope string, conn net.Conn, err error) { + if r == nil || conn == nil { + return + } + scope = normalizeFileScope(scope) + resetErr := bulkRuntimeCloseError(err) + prefix := scope + "\x00" + r.mu.RLock() + bulks := make([]*bulkHandle, 0, len(r.bulks)) + for key, bulk := range r.bulks { + if bulk == nil { + continue + } + if !strings.HasPrefix(key, prefix) { + continue + } + if !bulk.Dedicated() { + continue + } + if bulk.dedicatedConnSnapshot() != conn { + continue + } + bulks = append(bulks, bulk) + } + r.mu.RUnlock() + for _, bulk := range bulks { + bulk.markReset(resetErr) + } +} + +func (r *bulkRuntime) handleDedicatedReadErrorByConn(scope string, conn net.Conn, err error) { + if r == nil || conn == nil { + return + } + scope = normalizeFileScope(scope) + prefix := scope + "\x00" + r.mu.RLock() + bulks := make([]*bulkHandle, 0, len(r.bulks)) + for key, bulk := range r.bulks { + if bulk == nil { + continue + } + if !strings.HasPrefix(key, prefix) { + continue + } + if !bulk.Dedicated() { + continue + } + if bulk.dedicatedConnSnapshot() != conn { + continue + } + bulks = append(bulks, bulk) + } + r.mu.RUnlock() + for _, bulk := range bulks { + handleDedicatedBulkReadError(bulk, err) + } +} + +func (r *bulkRuntime) attachSharedDedicatedConn(scope string, laneID uint32, conn net.Conn) { + if r == nil || conn == nil { + return + } + scope = normalizeFileScope(scope) + laneID = normalizeBulkDedicatedLaneID(laneID) + prefix := scope + "\x00" + r.mu.RLock() + bulks := make([]*bulkHandle, 0, len(r.bulks)) + for key, bulk := range r.bulks { + if bulk == nil { + continue + } + if !strings.HasPrefix(key, prefix) { + continue + } + if !bulk.Dedicated() { + continue + } + if bulk.dedicatedLaneIDSnapshot() != laneID { + continue + } + bulks = append(bulks, bulk) + } + r.mu.RUnlock() + for _, bulk := range bulks { + current := bulk.dedicatedConnSnapshot() + switch { + case current == conn: + _ = bulk.attachDedicatedConnShared(conn) + case current == nil: 
+ _ = bulk.attachDedicatedConnShared(conn) + default: + oldConn, oldSender, err := bulk.replaceDedicatedConnShared(conn) + if err != nil { + bulk.markReset(err) + continue + } + if oldSender != nil { + oldSender.stop() + } + if oldConn != nil { + _ = oldConn.Close() + } + } + } +} + +func (r *bulkRuntime) dedicatedBulksForConn(scope string, conn net.Conn) []*bulkHandle { + if r == nil || conn == nil { + return nil + } + scope = normalizeFileScope(scope) + prefix := scope + "\x00" + r.mu.RLock() + bulks := make([]*bulkHandle, 0, len(r.bulks)) + for key, bulk := range r.bulks { + if bulk == nil { + continue + } + if !strings.HasPrefix(key, prefix) { + continue + } + if !bulk.Dedicated() { + continue + } + if bulk.dedicatedConnSnapshot() != conn { + continue + } + bulks = append(bulks, bulk) + } + r.mu.RUnlock() + return bulks +} + func (r *bulkRuntime) snapshots() []BulkSnapshot { if r == nil { return nil @@ -170,10 +319,6 @@ func bulkRuntimeKey(scope string, bulkID string) string { return normalizeFileScope(scope) + "\x00" + bulkID } -func bulkRuntimeDataKey(scope string, dataID uint64) string { - return normalizeFileScope(scope) + "\x01" + strconv.FormatUint(dataID, 10) -} - func bulkRuntimeCloseError(err error) error { if err != nil { return err diff --git a/bulk_shared_batch.go b/bulk_shared_batch.go new file mode 100644 index 0000000..3f8e484 --- /dev/null +++ b/bulk_shared_batch.go @@ -0,0 +1,228 @@ +package notify + +import ( + "encoding/binary" +) + +const ( + bulkFastPathVersionV1 = 1 + bulkFastPathVersionV2 = 2 + bulkFastPathVersionCurrent = bulkFastPathVersionV2 +) + +const ( + bulkFastBatchMagic = "NBF2" + bulkFastBatchVersion = 1 + bulkFastBatchHeaderLen = 12 + bulkFastBatchItemHeaderLen = 24 + bulkFastBatchMaxItems = 64 + bulkFastBatchMaxPlainBytes = 8 * 1024 * 1024 +) + +func normalizeBulkFastPathVersion(version uint8) uint8 { + if version < bulkFastPathVersionV1 { + return bulkFastPathVersionV1 + } + if version > bulkFastPathVersionCurrent { + return 
bulkFastPathVersionCurrent + } + return version +} + +func negotiateBulkFastPathVersion(version uint8) uint8 { + return normalizeBulkFastPathVersion(version) +} + +func bulkFastPathSupportsSharedBatch(version uint8) bool { + return normalizeBulkFastPathVersion(version) >= bulkFastPathVersionV2 +} + +func bulkFastBatchFrameLen(frame bulkFastFrame) int { + return bulkFastBatchItemHeaderLen + len(frame.Payload) +} + +func bulkFastBatchPlainLen(frames []bulkFastFrame) int { + total := bulkFastBatchHeaderLen + for _, frame := range frames { + total += bulkFastBatchFrameLen(frame) + } + return total +} + +func encodeBulkFastFramePayload(frame bulkFastFrame) ([]byte, error) { + return encodeBulkFastControlFrame(frame.Type, frame.Flags, frame.DataID, frame.Seq, frame.Payload) +} + +func encodeBulkFastFramePayloadFast(encode transportFastPlainEncoder, secretKey []byte, frame bulkFastFrame) ([]byte, error) { + if encode == nil { + return nil, errTransportPayloadEncryptFailed + } + plainLen := bulkFastPayloadHeaderLen + len(frame.Payload) + return encode(secretKey, plainLen, func(dst []byte) error { + if err := encodeBulkFastFrameHeader(dst, frame.Type, frame.Flags, frame.DataID, frame.Seq, len(frame.Payload)); err != nil { + return err + } + copy(dst[bulkFastPayloadHeaderLen:], frame.Payload) + return nil + }) +} + +func encodeBulkFastFramePayloadPooled(runtime *modernPSKCodecRuntime, frame bulkFastFrame) ([]byte, func(), error) { + if runtime == nil { + return nil, nil, errTransportPayloadEncryptFailed + } + plainLen := bulkFastPayloadHeaderLen + len(frame.Payload) + return runtime.sealFilledPayloadPooled(plainLen, func(dst []byte) error { + if err := encodeBulkFastFrameHeader(dst, frame.Type, frame.Flags, frame.DataID, frame.Seq, len(frame.Payload)); err != nil { + return err + } + copy(dst[bulkFastPayloadHeaderLen:], frame.Payload) + return nil + }) +} + +func encodeBulkFastBatchPlain(frames []bulkFastFrame) ([]byte, error) { + if len(frames) == 0 { + return nil, 
errBulkFastPayloadInvalid + } + buf := make([]byte, bulkFastBatchPlainLen(frames)) + if err := writeBulkFastBatchPlain(buf, frames); err != nil { + return nil, err + } + return buf, nil +} + +func encodeBulkFastBatchPayloadFast(encode transportFastPlainEncoder, secretKey []byte, frames []bulkFastFrame) ([]byte, error) { + if encode == nil { + return nil, errTransportPayloadEncryptFailed + } + plainLen := bulkFastBatchPlainLen(frames) + return encode(secretKey, plainLen, func(dst []byte) error { + return writeBulkFastBatchPlain(dst, frames) + }) +} + +func encodeBulkFastBatchPayloadPooled(runtime *modernPSKCodecRuntime, frames []bulkFastFrame) ([]byte, func(), error) { + if runtime == nil { + return nil, nil, errTransportPayloadEncryptFailed + } + return runtime.sealFilledPayloadPooled(bulkFastBatchPlainLen(frames), func(dst []byte) error { + return writeBulkFastBatchPlain(dst, frames) + }) +} + +func writeBulkFastBatchPlain(dst []byte, frames []bulkFastFrame) error { + if len(frames) == 0 || len(dst) != bulkFastBatchPlainLen(frames) { + return errBulkFastPayloadInvalid + } + copy(dst[:4], bulkFastBatchMagic) + dst[4] = bulkFastBatchVersion + binary.BigEndian.PutUint32(dst[8:12], uint32(len(frames))) + offset := bulkFastBatchHeaderLen + for _, frame := range frames { + if frame.DataID == 0 { + return errBulkFastPayloadInvalid + } + dst[offset] = frame.Type + dst[offset+1] = frame.Flags + binary.BigEndian.PutUint64(dst[offset+4:offset+12], frame.DataID) + binary.BigEndian.PutUint64(dst[offset+12:offset+20], frame.Seq) + binary.BigEndian.PutUint32(dst[offset+20:offset+24], uint32(len(frame.Payload))) + offset += bulkFastBatchItemHeaderLen + copy(dst[offset:offset+len(frame.Payload)], frame.Payload) + offset += len(frame.Payload) + } + return nil +} + +func walkBulkFastBatchPlain(payload []byte, fn func(bulkFastFrame) error) (bool, error) { + if len(payload) < 4 || string(payload[:4]) != bulkFastBatchMagic { + return false, nil + } + if len(payload) < 
bulkFastBatchHeaderLen { + return true, errBulkFastPayloadInvalid + } + if payload[4] != bulkFastBatchVersion { + return true, errBulkFastPayloadInvalid + } + count := int(binary.BigEndian.Uint32(payload[8:12])) + if count <= 0 { + return true, errBulkFastPayloadInvalid + } + offset := bulkFastBatchHeaderLen + for index := 0; index < count; index++ { + if len(payload)-offset < bulkFastBatchItemHeaderLen { + return true, errBulkFastPayloadInvalid + } + frameType := payload[offset] + switch frameType { + case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease: + default: + return true, errBulkFastPayloadInvalid + } + flags := payload[offset+1] + dataID := binary.BigEndian.Uint64(payload[offset+4 : offset+12]) + seq := binary.BigEndian.Uint64(payload[offset+12 : offset+20]) + payloadLen := int(binary.BigEndian.Uint32(payload[offset+20 : offset+24])) + offset += bulkFastBatchItemHeaderLen + if dataID == 0 || payloadLen < 0 || len(payload)-offset < payloadLen { + return true, errBulkFastPayloadInvalid + } + if fn != nil { + if err := fn(bulkFastFrame{ + Type: frameType, + Flags: flags, + DataID: dataID, + Seq: seq, + Payload: payload[offset : offset+payloadLen], + }); err != nil { + return true, err + } + } + offset += payloadLen + } + if offset != len(payload) { + return true, errBulkFastPayloadInvalid + } + return true, nil +} + +func decodeBulkFastBatchPlain(payload []byte) ([]bulkFastFrame, bool, error) { + frames := make([]bulkFastFrame, 0, 1) + matched, err := walkBulkFastBatchPlain(payload, func(frame bulkFastFrame) error { + frames = append(frames, frame) + return nil + }) + if !matched || err != nil { + return nil, matched, err + } + return frames, true, nil +} + +func walkBulkFastFrames(payload []byte, fn func(bulkFastFrame) error) (bool, error) { + if matched, err := walkBulkFastBatchPlain(payload, fn); matched { + return true, err + } + frame, matched, err := decodeBulkFastFrame(payload) + if !matched || 
err != nil { + return matched, err + } + if fn != nil { + if err := fn(frame); err != nil { + return true, err + } + } + return true, nil +} + +func decodeBulkFastFrames(payload []byte) ([]bulkFastFrame, bool, error) { + frames := make([]bulkFastFrame, 0, 1) + matched, err := walkBulkFastFrames(payload, func(frame bulkFastFrame) error { + frames = append(frames, frame) + return nil + }) + if !matched || err != nil { + return nil, matched, err + } + return frames, true, nil +} diff --git a/bulk_shared_batch_test.go b/bulk_shared_batch_test.go new file mode 100644 index 0000000..f4dcf6e --- /dev/null +++ b/bulk_shared_batch_test.go @@ -0,0 +1,458 @@ +package notify + +import ( + "context" + "testing" + "time" +) + +func TestBulkFastBatchPlainRoundTrip(t *testing.T) { + releasePayload, err := encodeBulkDedicatedReleasePayload(4096, 2) + if err != nil { + t.Fatalf("encodeBulkDedicatedReleasePayload failed: %v", err) + } + frames := []bulkFastFrame{ + { + Type: bulkFastPayloadTypeData, + DataID: 11, + Seq: 7, + Payload: []byte("alpha"), + }, + { + Type: bulkFastPayloadTypeRelease, + DataID: 12, + Seq: 0, + Payload: releasePayload, + }, + } + wire, err := encodeBulkFastBatchPlain(frames) + if err != nil { + t.Fatalf("encodeBulkFastBatchPlain failed: %v", err) + } + decoded, matched, err := decodeBulkFastBatchPlain(wire) + if err != nil { + t.Fatalf("decodeBulkFastBatchPlain failed: %v", err) + } + if !matched { + t.Fatal("decodeBulkFastBatchPlain should match encoded batch") + } + if got, want := len(decoded), len(frames); got != want { + t.Fatalf("decoded frame count = %d, want %d", got, want) + } + for index := range frames { + if got, want := decoded[index].Type, frames[index].Type; got != want { + t.Fatalf("frame %d type = %d, want %d", index, got, want) + } + if got, want := decoded[index].DataID, frames[index].DataID; got != want { + t.Fatalf("frame %d dataID = %d, want %d", index, got, want) + } + if got, want := decoded[index].Seq, frames[index].Seq; got != want 
{ + t.Fatalf("frame %d seq = %d, want %d", index, got, want) + } + if got, want := string(decoded[index].Payload), string(frames[index].Payload); got != want { + t.Fatalf("frame %d payload = %q, want %q", index, got, want) + } + } +} + +func TestBulkBatchSenderEncodeRequestsCoalescesSharedFastV2Frames(t *testing.T) { + var ( + singleCalls int + batchCalls [][]bulkFastFrame + ) + sender := &bulkBatchSender{ + codec: bulkBatchCodec{ + encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) { + singleCalls++ + return []byte("single"), nil, nil + }, + encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) { + cloned := make([]bulkFastFrame, len(frames)) + copy(cloned, frames) + batchCalls = append(batchCalls, cloned) + return []byte("batch"), nil, nil + }, + }, + } + payloads, err := sender.encodeRequests([]bulkBatchRequest{ + { + frames: []bulkFastFrame{{ + Type: bulkFastPayloadTypeData, + DataID: 101, + Seq: 1, + Payload: []byte("a"), + }}, + fastPathVersion: bulkFastPathVersionV2, + }, + { + frames: []bulkFastFrame{{ + Type: bulkFastPayloadTypeRelease, + DataID: 101, + Seq: 0, + Payload: []byte("rel"), + }}, + fastPathVersion: bulkFastPathVersionV2, + }, + }) + if err != nil { + t.Fatalf("encodeRequests failed: %v", err) + } + if got, want := len(payloads), 1; got != want { + t.Fatalf("payload count = %d, want %d", got, want) + } + if singleCalls != 0 { + t.Fatalf("single encode calls = %d, want 0", singleCalls) + } + if got, want := len(batchCalls), 1; got != want { + t.Fatalf("batch encode calls = %d, want %d", got, want) + } + if got, want := len(batchCalls[0]), 2; got != want { + t.Fatalf("batched item count = %d, want %d", got, want) + } +} + +func TestBulkBatchSenderEncodeRequestsCoalescesAcrossRequestsAndBulks(t *testing.T) { + var ( + singleCalls int + batchCalls [][]bulkFastFrame + ) + sender := &bulkBatchSender{ + codec: bulkBatchCodec{ + encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) { + singleCalls++ + return 
[]byte("single"), nil, nil + }, + encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) { + cloned := make([]bulkFastFrame, len(frames)) + copy(cloned, frames) + batchCalls = append(batchCalls, cloned) + return []byte("batch"), nil, nil + }, + }, + } + payloads, err := sender.encodeRequests([]bulkBatchRequest{ + { + frames: []bulkFastFrame{{ + Type: bulkFastPayloadTypeData, + DataID: 101, + Seq: 1, + Payload: []byte("bulk-a"), + }}, + fastPathVersion: bulkFastPathVersionV2, + }, + { + frames: []bulkFastFrame{{ + Type: bulkFastPayloadTypeData, + DataID: 202, + Seq: 1, + Payload: []byte("bulk-b"), + }}, + fastPathVersion: bulkFastPathVersionV2, + }, + }) + if err != nil { + t.Fatalf("encodeRequests failed: %v", err) + } + if got, want := len(payloads), 1; got != want { + t.Fatalf("payload count = %d, want %d", got, want) + } + if singleCalls != 0 { + t.Fatalf("single encode calls = %d, want 0", singleCalls) + } + if got, want := len(batchCalls), 1; got != want { + t.Fatalf("batch encode calls = %d, want %d", got, want) + } + if got, want := len(batchCalls[0]), 2; got != want { + t.Fatalf("batched item count = %d, want %d", got, want) + } + if got, want := batchCalls[0][0].DataID, uint64(101); got != want { + t.Fatalf("first batched dataID = %d, want %d", got, want) + } + if got, want := batchCalls[0][1].DataID, uint64(202); got != want { + t.Fatalf("second batched dataID = %d, want %d", got, want) + } +} + +func TestBulkBatchSenderEncodeRequestsSplitsLargeCrossBulkSuperBatch(t *testing.T) { + var batchCalls [][]bulkFastFrame + sender := &bulkBatchSender{ + codec: bulkBatchCodec{ + encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) { + return []byte("single"), nil, nil + }, + encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) { + cloned := make([]bulkFastFrame, len(frames)) + copy(cloned, frames) + batchCalls = append(batchCalls, cloned) + return []byte("batch"), nil, nil + }, + }, + } + payload := make([]byte, 768*1024) + 
payloads, err := sender.encodeRequests([]bulkBatchRequest{ + { + frames: []bulkFastFrame{ + { + Type: bulkFastPayloadTypeData, + DataID: 101, + Seq: 1, + Payload: payload, + }, + { + Type: bulkFastPayloadTypeData, + DataID: 101, + Seq: 2, + Payload: payload, + }, + }, + fastPathVersion: bulkFastPathVersionV2, + }, + { + frames: []bulkFastFrame{ + { + Type: bulkFastPayloadTypeData, + DataID: 202, + Seq: 1, + Payload: payload, + }, + { + Type: bulkFastPayloadTypeData, + DataID: 202, + Seq: 2, + Payload: payload, + }, + }, + fastPathVersion: bulkFastPathVersionV2, + }, + }) + if err != nil { + t.Fatalf("encodeRequests failed: %v", err) + } + if got, want := len(payloads), 2; got != want { + t.Fatalf("payload count = %d, want %d", got, want) + } + if got, want := len(batchCalls), 2; got != want { + t.Fatalf("batch encode calls = %d, want %d", got, want) + } + if got, want := len(batchCalls[0]), 2; got != want { + t.Fatalf("first payload frame count = %d, want %d", got, want) + } + if got, want := len(batchCalls[1]), 2; got != want { + t.Fatalf("second payload frame count = %d, want %d", got, want) + } + if got, want := batchCalls[0][0].DataID, uint64(101); got != want { + t.Fatalf("first payload dataID = %d, want %d", got, want) + } + if got, want := batchCalls[1][0].DataID, uint64(202); got != want { + t.Fatalf("second payload dataID = %d, want %d", got, want) + } +} + +func TestBulkReleaseFastRoundTripTransport(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + acceptCh := make(chan BulkAcceptInfo, 1) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := 
UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + const chunkSize = 64 * 1024 + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Range: BulkRange{Offset: 0, Length: chunkSize}, + ChunkSize: chunkSize, + WindowBytes: chunkSize, + MaxInFlight: 1, + }) + if err != nil { + t.Fatalf("client OpenBulk failed: %v", err) + } + + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + clientHandle := bulk.(*bulkHandle) + serverHandle := accepted.Bulk.(*bulkHandle) + if got, want := clientHandle.FastPathVersion(), uint8(bulkFastPathVersionV2); got != want { + t.Fatalf("client fast path version = %d, want %d", got, want) + } + if got, want := serverHandle.FastPathVersion(), uint8(bulkFastPathVersionV2); got != want { + t.Fatalf("server fast path version = %d, want %d", got, want) + } + + clientHandle.mu.Lock() + clientHandle.outboundAvailBytes = 0 + clientHandle.outboundInFlight = 1 + clientHandle.mu.Unlock() + + releaseFn := serverBulkReleaseSender(server, accepted.LogicalConn, accepted.TransportConn) + if err := releaseFn(serverHandle, chunkSize, 1); err != nil { + t.Fatalf("server bulk release sender failed: %v", err) + } + + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + clientHandle.mu.Lock() + avail := clientHandle.outboundAvailBytes + inFlight := clientHandle.outboundInFlight + clientHandle.mu.Unlock() + if avail == chunkSize && inFlight == 0 { + return + } + time.Sleep(10 * time.Millisecond) + } + + clientHandle.mu.Lock() + avail := clientHandle.outboundAvailBytes + inFlight := clientHandle.outboundInFlight + clientHandle.mu.Unlock() + t.Fatalf("client outbound window not released by fast path: avail=%d inFlight=%d", avail, inFlight) +} + 
+func TestBulkBatchSenderEncodeRequestsResetsBatchBytesAfterFlushBoundary(t *testing.T) { + var ( + singleCalls int + batchCalls [][]bulkFastFrame + ) + sender := &bulkBatchSender{ + codec: bulkBatchCodec{ + encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) { + singleCalls++ + return []byte("single"), nil, nil + }, + encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) { + cloned := make([]bulkFastFrame, len(frames)) + copy(cloned, frames) + batchCalls = append(batchCalls, cloned) + return []byte("batch"), nil, nil + }, + }, + } + largePayload := make([]byte, bulkFastBatchMaxPlainBytes-bulkFastBatchHeaderLen-bulkFastBatchItemHeaderLen-128) + payloads, err := sender.encodeRequests([]bulkBatchRequest{ + { + frames: []bulkFastFrame{{ + Type: bulkFastPayloadTypeData, + DataID: 101, + Seq: 1, + Payload: largePayload, + }}, + fastPathVersion: bulkFastPathVersionV2, + }, + { + frames: []bulkFastFrame{{ + Type: bulkFastPayloadTypeData, + DataID: 101, + Seq: 2, + Payload: []byte("sep"), + }}, + fastPathVersion: bulkFastPathVersionV1, + }, + { + frames: []bulkFastFrame{{ + Type: bulkFastPayloadTypeData, + DataID: 202, + Seq: 1, + Payload: []byte("a"), + }}, + fastPathVersion: bulkFastPathVersionV2, + }, + { + frames: []bulkFastFrame{{ + Type: bulkFastPayloadTypeData, + DataID: 202, + Seq: 2, + Payload: []byte("b"), + }}, + fastPathVersion: bulkFastPathVersionV2, + }, + }) + if err != nil { + t.Fatalf("encodeRequests failed: %v", err) + } + if got, want := len(payloads), 3; got != want { + t.Fatalf("payload count = %d, want %d", got, want) + } + if got, want := singleCalls, 2; got != want { + t.Fatalf("single encode calls = %d, want %d", got, want) + } + if got, want := len(batchCalls), 1; got != want { + t.Fatalf("batch encode calls = %d, want %d", got, want) + } + if got, want := len(batchCalls[0]), 2; got != want { + t.Fatalf("post-flush batched frame count = %d, want %d", got, want) + } +} + +func 
TestBulkBatchSenderEncodeRequestsUsesBindingAdaptiveSoftLimit(t *testing.T) { + binding := &transportBinding{} + binding.observeBulkAdaptivePayloadWrite(8*1024*1024, 640*time.Millisecond, 0, nil) + var batchCalls [][]bulkFastFrame + sender := &bulkBatchSender{ + binding: binding, + codec: bulkBatchCodec{ + encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) { + return []byte("single"), nil, nil + }, + encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) { + cloned := make([]bulkFastFrame, len(frames)) + copy(cloned, frames) + batchCalls = append(batchCalls, cloned) + return []byte("batch"), nil, nil + }, + }, + } + payload := make([]byte, 96*1024) + payloads, err := sender.encodeRequests([]bulkBatchRequest{ + { + frames: []bulkFastFrame{ + {Type: bulkFastPayloadTypeData, DataID: 101, Seq: 1, Payload: payload}, + {Type: bulkFastPayloadTypeData, DataID: 101, Seq: 2, Payload: payload}, + }, + fastPathVersion: bulkFastPathVersionV2, + }, + { + frames: []bulkFastFrame{ + {Type: bulkFastPayloadTypeData, DataID: 202, Seq: 1, Payload: payload}, + {Type: bulkFastPayloadTypeData, DataID: 202, Seq: 2, Payload: payload}, + }, + fastPathVersion: bulkFastPathVersionV2, + }, + }) + if err != nil { + t.Fatalf("encodeRequests failed: %v", err) + } + if got, want := len(payloads), 2; got != want { + t.Fatalf("payload count = %d, want %d", got, want) + } + if got, want := len(batchCalls), 2; got != want { + t.Fatalf("batch encode calls = %d, want %d", got, want) + } + for index, want := range []uint64{101, 202} { + if got := batchCalls[index][0].DataID; got != want { + t.Fatalf("payload %d first dataID = %d, want %d", index, got, want) + } + if got, wantItems := len(batchCalls[index]), 2; got != wantItems { + t.Fatalf("payload %d item count = %d, want %d", index, got, wantItems) + } + } +} diff --git a/bulk_snapshot.go b/bulk_snapshot.go index c0f0813..29b168f 100644 --- a/bulk_snapshot.go +++ b/bulk_snapshot.go @@ -7,49 +7,56 @@ import ( ) type BulkSnapshot 
struct { - ID string - DataID uint64 - Scope string - Range BulkRange - Metadata BulkMetadata - BindingOwner string - BindingAlive bool - BindingCurrent bool - BindingReason string - BindingError string - Dedicated bool - DedicatedAttached bool - SessionEpoch uint64 - LogicalClientID string - TransportGeneration uint64 - TransportAttached bool - TransportHasRuntimeConn bool - TransportCurrent bool - TransportDetachReason string - TransportDetachKind string - TransportDetachGeneration uint64 - TransportDetachError string - TransportDetachedAt time.Time - ReattachEligible bool - LocalClosed bool - LocalReadClosed bool - RemoteClosed bool - PeerReadClosed bool - BufferedChunks int - BufferedBytes int - ReadTimeout time.Duration - WriteTimeout time.Duration - ChunkSize int - WindowBytes int - MaxInFlight int - BytesRead int64 - BytesWritten int64 - ReadCalls int64 - WriteCalls int64 - OpenedAt time.Time - LastReadAt time.Time - LastWriteAt time.Time - ResetError string + ID string + DataID uint64 + FastPathVersion uint8 + Scope string + Range BulkRange + Metadata BulkMetadata + BindingOwner string + BindingAlive bool + BindingCurrent bool + BindingReason string + BindingError string + BindingBulkAdaptiveSoftPayloadBytes int + Dedicated bool + DedicatedLaneID uint32 + DedicatedAttached bool + DedicatedAttachState string + DedicatedAttachAttempts uint32 + DedicatedAttachLastCode string + DedicatedDataStarted bool + SessionEpoch uint64 + LogicalClientID string + TransportGeneration uint64 + TransportAttached bool + TransportHasRuntimeConn bool + TransportCurrent bool + TransportDetachReason string + TransportDetachKind string + TransportDetachGeneration uint64 + TransportDetachError string + TransportDetachedAt time.Time + ReattachEligible bool + LocalClosed bool + LocalReadClosed bool + RemoteClosed bool + PeerReadClosed bool + BufferedChunks int + BufferedBytes int + ReadTimeout time.Duration + WriteTimeout time.Duration + ChunkSize int + WindowBytes int + MaxInFlight 
int + BytesRead int64 + BytesWritten int64 + ReadCalls int64 + WriteCalls int64 + OpenedAt time.Time + LastReadAt time.Time + LastWriteAt time.Time + ResetError string } type clientBulkSnapshotReader interface { diff --git a/bulk_stack_benchmark_test.go b/bulk_stack_benchmark_test.go index 24d2348..321f04a 100644 --- a/bulk_stack_benchmark_test.go +++ b/bulk_stack_benchmark_test.go @@ -88,7 +88,7 @@ func BenchmarkDedicatedWireLocalhostThroughput(b *testing.B) { func benchmarkDedicatedWireLocalhostThroughput(b *testing.B, key []byte, encode transportFastPlainEncoder, payloadSize int) { b.Helper() - listener, err := net.Listen("tcp", "127.0.0.1:0") + listener, err := net.Listen("tcp", benchmarkTCPListenAddr(b)) if err != nil { b.Fatalf("net.Listen failed: %v", err) } @@ -107,7 +107,7 @@ func benchmarkDedicatedWireLocalhostThroughput(b *testing.B, key []byte, encode acceptCh <- conn }() - clientConn, err := net.Dial("tcp", listener.Addr().String()) + clientConn, err := net.Dial("tcp", benchmarkTCPDialAddr(b, listener.Addr().String())) if err != nil { b.Fatalf("net.Dial failed: %v", err) } diff --git a/bulk_test.go b/bulk_test.go index 1a5674c..61308e7 100644 --- a/bulk_test.go +++ b/bulk_test.go @@ -109,6 +109,9 @@ func TestBulkOpenRoundTripTCP(t *testing.T) { if !clientSnapshots[0].BindingAlive || !clientSnapshots[0].BindingCurrent || !clientSnapshots[0].TransportAttached || !clientSnapshots[0].TransportCurrent { t.Fatalf("client bulk binding snapshot mismatch: %+v", clientSnapshots[0]) } + if got, want := clientSnapshots[0].BindingBulkAdaptiveSoftPayloadBytes, bulkAdaptiveSoftPayloadStartBytes; got != want { + t.Fatalf("client bulk BindingBulkAdaptiveSoftPayloadBytes = %d, want %d", got, want) + } serverSnapshots, err := GetServerBulkSnapshots(server) if err != nil { t.Fatalf("GetServerBulkSnapshots failed: %v", err) @@ -119,6 +122,9 @@ func TestBulkOpenRoundTripTCP(t *testing.T) { if got, want := serverSnapshots[0].BindingOwner, "server-transport"; got != want { 
t.Fatalf("server bulk BindingOwner = %q, want %q", got, want) } + if got, want := serverSnapshots[0].BindingBulkAdaptiveSoftPayloadBytes, bulkAdaptiveSoftPayloadStartBytes; got != want { + t.Fatalf("server bulk BindingBulkAdaptiveSoftPayloadBytes = %d, want %d", got, want) + } if !serverSnapshots[0].BindingAlive || !serverSnapshots[0].BindingCurrent || !serverSnapshots[0].TransportAttached || !serverSnapshots[0].TransportCurrent { t.Fatalf("server bulk binding snapshot mismatch: %+v", serverSnapshots[0]) } @@ -135,6 +141,314 @@ func TestBulkOpenRoundTripTCP(t *testing.T) { waitForBulkContextDone(t, bulk.Context(), 2*time.Second) } +func TestDedicatedBulkOpenUnblocksSynchronousReadHandler(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + readCh := make(chan string, 1) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + defer func() { + _ = info.Bulk.Close() + }() + buf := make([]byte, 5) + if _, err := io.ReadFull(info.Bulk, buf); err != nil { + return err + } + readCh <- string(buf) + return nil + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + type openResult struct { + bulk Bulk + err error + } + openCh := make(chan openResult, 1) + go func() { + bulk, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{ + ID: "sync-read-handler", + Range: BulkRange{ + Offset: 0, + Length: 5, + }, + }) + 
openCh <- openResult{bulk: bulk, err: err} + }() + + var bulk Bulk + select { + case result := <-openCh: + if result.err != nil { + t.Fatalf("client OpenDedicatedBulk failed: %v", result.err) + } + bulk = result.bulk + case <-time.After(2 * time.Second): + t.Fatal("client OpenDedicatedBulk timed out while remote handler was synchronously reading") + } + defer func() { + if bulk != nil { + _ = bulk.Close() + } + }() + + if _, err := bulk.Write([]byte("hello")); err != nil { + t.Fatalf("client dedicated bulk Write failed: %v", err) + } + select { + case got := <-readCh: + if got != "hello" { + t.Fatalf("server handler read %q, want %q", got, "hello") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for synchronous handler read") + } +} + +func TestDedicatedBulkOpenUnblocksOnBlockingFirstWrite(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + payload := strings.Repeat("w", 16) + writeDone := make(chan error, 1) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + _, err := io.WriteString(info.Bulk, payload) + if err == nil { + err = info.Bulk.CloseWrite() + } + writeDone <- err + return err + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + bulk, err := client.OpenDedicatedBulk(ctx, BulkOpenOptions{ + 
ID: "blocking-write-ready", + Range: BulkRange{ + Offset: 0, + Length: int64(len(payload)), + }, + ChunkSize: 4, + WindowBytes: 4, + MaxInFlight: 1, + }) + if err != nil { + t.Fatalf("client OpenDedicatedBulk failed: %v", err) + } + + readBulkExactly(t, bulk, payload, 2*time.Second) + waitForBulkReadEOF(t, bulk, 2*time.Second) + + select { + case err := <-writeDone: + if err != nil { + t.Fatalf("server handler write failed: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for blocking first write to finish") + } + + if err := bulk.Close(); err != nil { + t.Fatalf("client dedicated bulk Close failed: %v", err) + } + waitForBulkContextDone(t, bulk.Context(), 2*time.Second) +} + +func TestServerOpenBulkLogicalDedicatedUnblocksOnBlockingFirstRead(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + readCh := make(chan string, 1) + client.SetBulkHandler(func(info BulkAcceptInfo) error { + defer func() { + _ = info.Bulk.Close() + }() + buf := make([]byte, 5) + if _, err := io.ReadFull(info.Bulk, buf); err != nil { + return err + } + readCh <- string(buf) + return nil + }) + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + logical := waitForTransferControlLogicalConn(t, server, 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + bulk, err := 
server.OpenBulkLogical(ctx, logical, BulkOpenOptions{ + ID: "server-blocking-read-ready", + Range: BulkRange{ + Offset: 0, + Length: 5, + }, + Dedicated: true, + }) + if err != nil { + t.Fatalf("server OpenBulkLogical dedicated failed: %v", err) + } + + if _, err := bulk.Write([]byte("hello")); err != nil { + t.Fatalf("server dedicated bulk Write failed: %v", err) + } + + select { + case got := <-readCh: + if got != "hello" { + t.Fatalf("client handler read %q, want %q", got, "hello") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for blocking first read to finish") + } + + if err := bulk.Close(); err != nil { + t.Fatalf("server dedicated bulk Close failed: %v", err) + } + waitForBulkContextDone(t, bulk.Context(), 2*time.Second) +} + +func TestDedicatedBulkOpenReturnsHandlerFailureAfterAccepted(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + server.SetBulkHandler(func(info BulkAcceptInfo) error { + time.Sleep(80 * time.Millisecond) + return errors.New("dedicated handler failed after accept") + }) + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, err := client.OpenDedicatedBulk(ctx, BulkOpenOptions{ + ID: "accepted-then-fail", + Range: BulkRange{ + Offset: 0, + Length: 1, + }, + }) + if err == nil || 
!strings.Contains(err.Error(), "dedicated handler failed after accept") { + t.Fatalf("client OpenDedicatedBulk error = %v, want dedicated handler failure after accept", err) + } +} + +func TestServerOpenBulkLogicalDedicatedReturnsHandlerFailureAfterAccepted(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + client.SetBulkHandler(func(info BulkAcceptInfo) error { + time.Sleep(80 * time.Millisecond) + return errors.New("client dedicated handler failed after accept") + }) + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + logical := waitForTransferControlLogicalConn(t, server, 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, err := server.OpenBulkLogical(ctx, logical, BulkOpenOptions{ + ID: "server-accepted-then-fail", + Range: BulkRange{ + Offset: 0, + Length: 1, + }, + Dedicated: true, + }) + if err == nil || !strings.Contains(err.Error(), "client dedicated handler failed after accept") { + t.Fatalf("server OpenBulkLogical dedicated error = %v, want client dedicated handler failure after accept", err) + } +} + func TestBulkOpenRoundTripServerLogicalTCP(t *testing.T) { server := NewServer().(*ServerCommon) if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { @@ -532,15 +846,18 @@ func 
TestDedicatedBulkWritePrefersClosedPipeOverContextCanceled(t *testing.T) { MaxInFlight: 4, }, 0, nil, nil, 0, nil, nil, func(context.Context, *bulkHandle, []byte) error { return nil - }, func(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) { + }, func(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) { bulk.markPeerClosed() <-ctx.Done() return 0, ctx.Err() }, nil) _, err := bulk.Write([]byte("abcdefgh")) - if !errors.Is(err, io.ErrClosedPipe) { - t.Fatalf("bulk Write error = %v, want %v", err, io.ErrClosedPipe) + if err != nil && !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("bulk Write error = %v, want nil or %v", err, io.ErrClosedPipe) + } + if err := bulk.waitPendingAsyncWrites(context.Background()); err != nil && !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("bulk waitPendingAsyncWrites error = %v, want nil or %v", err, io.ErrClosedPipe) } } diff --git a/bulk_transport_guard_test.go b/bulk_transport_guard_test.go index 5bf5e1d..bd099ea 100644 --- a/bulk_transport_guard_test.go +++ b/bulk_transport_guard_test.go @@ -3,6 +3,9 @@ package notify import ( "context" "errors" + "net" + "strings" + "sync" "testing" "time" ) @@ -45,6 +48,1276 @@ func TestBulkOpenDedicatedUDPRejected(t *testing.T) { } } +func TestBulkOpenAutoUDPFallsBackToShared(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + acceptCh := make(chan BulkAcceptInfo, 2) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + if err := server.Listen("udp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + 
t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("udp", signalRoundTripServerAddr(server, "")); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Mode: BulkOpenModeAuto, + Range: BulkRange{ + Offset: 0, + Length: 128, + }, + }) + if err != nil { + t.Fatalf("client OpenBulk auto over udp failed: %v", err) + } + if bulk.Snapshot().Dedicated { + t.Fatal("client OpenBulk auto over udp should fall back to shared") + } + defer func() { + _ = bulk.Close() + }() + + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + if accepted.Dedicated { + t.Fatal("server accepted bulk should be shared for auto over udp") + } + _ = accepted.Bulk.Close() +} + +func TestOpenDedicatedBulkWaitsForActiveSlotUntilContextDeadline(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + server.SetBulkHandler(func(info BulkAcceptInfo) error { + return nil + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + client.bulkDedicatedActiveLimit = 1 + client.bulkDedicatedActive.Store(1) + + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) + defer cancel() + _, err := client.OpenDedicatedBulk(ctx, BulkOpenOptions{ + Range: BulkRange{ + Offset: 0, + Length: 128, + }, + }) + if 
!errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("client OpenDedicatedBulk error = %v, want %v", err, context.DeadlineExceeded) + } +} + +func TestOpenBulkAutoWaitsForActiveSlotAndKeepsDedicated(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + acceptCh := make(chan BulkAcceptInfo, 4) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + client.bulkDedicatedActiveLimit = 1 + client.bulkDedicatedActive.Store(1) + time.AfterFunc(40*time.Millisecond, func() { + client.releaseBulkDedicatedActiveSlot() + }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + bulk, err := client.OpenBulk(ctx, BulkOpenOptions{ + Mode: BulkOpenModeAuto, + Range: BulkRange{ + Offset: 0, + Length: 128, + }, + }) + if err != nil { + t.Fatalf("client OpenBulk auto failed: %v", err) + } + if !bulk.Snapshot().Dedicated { + t.Fatal("client OpenBulk auto should wait for active slot and stay dedicated") + } + defer func() { + _ = bulk.Close() + }() + + accepted := waitAcceptedBulkByID(t, acceptCh, bulk.ID(), 2*time.Second) + if !accepted.Dedicated { + t.Fatal("server accepted bulk should stay dedicated after active slot wait") + } + + if _, err := bulk.Write([]byte("auto-fallback")); err != nil { + t.Fatalf("client bulk 
Write failed: %v", err) + } + readBulkExactly(t, accepted.Bulk, "auto-fallback", 2*time.Second) + _ = accepted.Bulk.Close() +} + +func TestOpenSharedBulkForcesShared(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + acceptCh := make(chan BulkAcceptInfo, 2) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + bulk, err := client.OpenSharedBulk(context.Background(), BulkOpenOptions{ + Mode: BulkOpenModeDedicated, + Dedicated: true, + Range: BulkRange{ + Offset: 0, + Length: 256, + }, + }) + if err != nil { + t.Fatalf("client OpenSharedBulk failed: %v", err) + } + if bulk.Snapshot().Dedicated { + t.Fatal("OpenSharedBulk should force shared mode") + } + defer func() { + _ = bulk.Close() + }() + + accepted := waitAcceptedBulkByID(t, acceptCh, bulk.ID(), 2*time.Second) + if accepted.Dedicated { + t.Fatal("server accepted bulk should be shared for OpenSharedBulk") + } + _ = accepted.Bulk.Close() +} + +func TestOpenBulkDefaultModeUsesClientSetting(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + server.SetBulkHandler(func(info BulkAcceptInfo) error { + return nil + }) + if 
err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + client.SetBulkDefaultOpenMode(BulkOpenModeDedicated) + cfg := client.BulkDedicatedAttachConfig() + cfg.ActiveLimit = 1 + client.SetBulkDedicatedAttachConfig(cfg) + client.bulkDedicatedActive.Store(1) + + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) + defer cancel() + _, err := client.OpenBulk(ctx, BulkOpenOptions{ + Range: BulkRange{ + Offset: 0, + Length: 128, + }, + }) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("client OpenBulk default-mode dedicated error = %v, want %v", err, context.DeadlineExceeded) + } +} + +func TestBulkNetworkProfileWANUsesAutoFallback(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + acceptCh := make(chan BulkAcceptInfo, 2) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer 
func() { + _ = client.Stop() + }() + + client.SetBulkNetworkProfile(BulkNetworkProfileWAN) + if got, want := client.BulkNetworkProfile(), BulkNetworkProfileWAN; got != want { + t.Fatalf("BulkNetworkProfile = %v, want %v", got, want) + } + if got, want := client.BulkDefaultOpenMode(), BulkOpenModeAuto; got != want { + t.Fatalf("BulkDefaultOpenMode = %v, want %v", got, want) + } + cfg := client.BulkDedicatedAttachConfig() + if got, want := cfg.ActiveLimit, 4096; got != want { + t.Fatalf("WAN ActiveLimit = %d, want %d", got, want) + } + client.setClientConnectSource(newClientFactoryConnectSource(func(context.Context) (net.Conn, error) { + return nil, errors.New("forced attach dial failure") + })) + + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Range: BulkRange{ + Offset: 0, + Length: 128, + }, + }) + if err != nil { + t.Fatalf("client OpenBulk with WAN profile failed: %v", err) + } + if bulk.Snapshot().Dedicated { + t.Fatal("client OpenBulk with WAN profile should fallback to shared when dedicated attach fails") + } + defer func() { + _ = bulk.Close() + }() + + accepted := waitAcceptedBulkByID(t, acceptCh, bulk.ID(), 2*time.Second) + if accepted.Dedicated { + t.Fatal("server accepted bulk should be shared after WAN auto fallback") + } + _ = accepted.Bulk.Close() +} + +func TestOpenBulkAutoAttachFailureTriggersSingleServerAccept(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + acceptCh := make(chan BulkAcceptInfo, 2) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, 
integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + client.setClientConnectSource(newClientFactoryConnectSource(func(context.Context) (net.Conn, error) { + return nil, errors.New("forced attach dial failure") + })) + + bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{ + Mode: BulkOpenModeAuto, + Range: BulkRange{ + Offset: 0, + Length: 128, + }, + }) + if err != nil { + t.Fatalf("client OpenBulk auto failed: %v", err) + } + if bulk.Snapshot().Dedicated { + t.Fatal("client OpenBulk auto should fallback to shared after attach dial failure") + } + defer func() { + _ = bulk.Close() + }() + + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + if accepted.Dedicated { + t.Fatal("server should only dispatch shared accept after dedicated attach failure") + } + select { + case extra := <-acceptCh: + t.Fatalf("unexpected extra server bulk accept: %+v", extra) + case <-time.After(300 * time.Millisecond): + } + _ = accepted.Bulk.Close() +} + +func TestDedicatedBulkReusesSessionSidecarAcrossBulks(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + acceptCh := make(chan BulkAcceptInfo, 4) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", 
server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + cfg := client.BulkDedicatedAttachConfig() + cfg.LaneLimit = 1 + client.SetBulkDedicatedAttachConfig(cfg) + + bulk1, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{ + ID: "reuse-1", + Range: BulkRange{ + Offset: 0, + Length: 16, + }, + }) + if err != nil { + t.Fatalf("client OpenDedicatedBulk #1 failed: %v", err) + } + accepted1 := waitAcceptedBulkByID(t, acceptCh, bulk1.ID(), 2*time.Second) + + bulk2, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{ + ID: "reuse-2", + Range: BulkRange{ + Offset: 16, + Length: 16, + }, + }) + if err != nil { + t.Fatalf("client OpenDedicatedBulk #2 failed: %v", err) + } + accepted2 := waitAcceptedBulkByID(t, acceptCh, bulk2.ID(), 2*time.Second) + + clientHandle1 := bulk1.(*bulkHandle) + clientHandle2 := bulk2.(*bulkHandle) + serverHandle1 := accepted1.Bulk.(*bulkHandle) + serverHandle2 := accepted2.Bulk.(*bulkHandle) + + clientConn1 := clientHandle1.dedicatedConnSnapshot() + clientConn2 := clientHandle2.dedicatedConnSnapshot() + if clientConn1 == nil || clientConn2 == nil || clientConn1 != clientConn2 { + t.Fatal("client dedicated bulks should reuse the same sidecar conn") + } + serverConn1 := serverHandle1.dedicatedConnSnapshot() + serverConn2 := serverHandle2.dedicatedConnSnapshot() + if serverConn1 == nil || serverConn2 == nil || serverConn1 != serverConn2 { + t.Fatal("server dedicated bulks should reuse the same sidecar conn") + } + if got := client.bulkAttachAttemptCount.Load(); got != 1 { + t.Fatalf("client bulkAttachAttemptCount = %d, want 1", got) + } + if got := client.bulkAttachSuccessCount.Load(); got != 1 { + t.Fatalf("client bulkAttachSuccessCount = %d, want 1", got) + } + + if err := bulk1.Close(); err != nil { + t.Fatalf("client bulk1 Close failed: %v", err) + } + if err := accepted1.Bulk.Close(); err != nil { + t.Fatalf("server bulk1 
Close failed: %v", err) + } + waitForBulkContextDone(t, bulk1.Context(), 2*time.Second) + waitForBulkContextDone(t, accepted1.Bulk.Context(), 2*time.Second) + + if client.clientDedicatedSidecarSnapshot() == nil { + t.Fatal("shared dedicated sidecar should stay alive after closing only one bulk") + } + if _, err := bulk2.Write([]byte("reuse-ok")); err != nil { + t.Fatalf("client bulk2 Write failed after bulk1 close: %v", err) + } + readBulkExactly(t, accepted2.Bulk, "reuse-ok", 2*time.Second) + + if err := bulk2.Close(); err != nil { + t.Fatalf("client bulk2 Close failed: %v", err) + } + if err := accepted2.Bulk.Close(); err != nil { + t.Fatalf("server bulk2 Close failed: %v", err) + } +} + +func TestSharedDedicatedSidecarStaleDataIDRejectDoesNotBreakOtherBulks(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + acceptCh := make(chan BulkAcceptInfo, 4) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + cfg := client.BulkDedicatedAttachConfig() + cfg.LaneLimit = 1 + client.SetBulkDedicatedAttachConfig(cfg) + + bulk1, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{ + ID: "stale-sidecar-1", + Range: BulkRange{ + Offset: 0, + Length: 16, + }, + }) + if err != nil { + t.Fatalf("client OpenDedicatedBulk #1 failed: %v", 
err) + } + accepted1 := waitAcceptedBulkByID(t, acceptCh, bulk1.ID(), 2*time.Second) + + bulk2, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{ + ID: "stale-sidecar-2", + Range: BulkRange{ + Offset: 16, + Length: 16, + }, + }) + if err != nil { + t.Fatalf("client OpenDedicatedBulk #2 failed: %v", err) + } + accepted2 := waitAcceptedBulkByID(t, acceptCh, bulk2.ID(), 2*time.Second) + + clientHandle1 := bulk1.(*bulkHandle) + clientHandle2 := bulk2.(*bulkHandle) + clientConn1 := clientHandle1.dedicatedConnSnapshot() + clientConn2 := clientHandle2.dedicatedConnSnapshot() + if clientConn1 == nil || clientConn2 == nil || clientConn1 != clientConn2 { + t.Fatal("client dedicated bulks should share the same sidecar conn") + } + + staleDataID := clientHandle1.dataIDSnapshot() + if staleDataID == 0 { + t.Fatal("stale data id should not be zero") + } + + if err := bulk1.Close(); err != nil { + t.Fatalf("client bulk1 Close failed: %v", err) + } + if err := accepted1.Bulk.Close(); err != nil { + t.Fatalf("server bulk1 Close failed: %v", err) + } + waitForBulkContextDone(t, bulk1.Context(), 2*time.Second) + waitForBulkContextDone(t, accepted1.Bulk.Context(), 2*time.Second) + + stalePayload, err := client.encodeDedicatedBulkBatchPayload(staleDataID, []bulkDedicatedSendRequest{{ + Type: bulkFastPayloadTypeData, + Seq: 1, + Payload: []byte("stale-data"), + }}) + if err != nil { + t.Fatalf("encodeDedicatedBulkBatchPayload failed: %v", err) + } + if err := writeBulkDedicatedRecord(clientConn2, stalePayload); err != nil { + t.Fatalf("writeBulkDedicatedRecord stale payload failed: %v", err) + } + + deadline := time.Now().Add(200 * time.Millisecond) + for time.Now().Before(deadline) { + sidecar := client.clientDedicatedSidecarSnapshot() + if sidecar == nil || sidecar.conn != clientConn2 { + t.Fatalf("shared dedicated sidecar should remain alive after stale data reject, got %+v", sidecar) + } + time.Sleep(10 * time.Millisecond) + } + + if _, err := 
bulk2.Write([]byte("still-alive")); err != nil { + t.Fatalf("client bulk2 Write failed after stale data reject: %v", err) + } + readBulkExactly(t, accepted2.Bulk, "still-alive", 2*time.Second) + + if err := bulk2.Close(); err != nil { + t.Fatalf("client bulk2 Close failed: %v", err) + } + if err := accepted2.Bulk.Close(); err != nil { + t.Fatalf("server bulk2 Close failed: %v", err) + } +} + +func TestDedicatedBulkConcurrentOpenSingleflightAttach(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + acceptCh := make(chan BulkAcceptInfo, 4) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + server.SetLink(systemBulkAttachKey, func(msg *Message) { + time.Sleep(120 * time.Millisecond) + _ = server.handleBulkAttachSystemMessage(*msg) + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + cfg := client.BulkDedicatedAttachConfig() + cfg.AttachLimit = 4 + cfg.LaneLimit = 1 + client.SetBulkDedicatedAttachConfig(cfg) + + results := make([]Bulk, 2) + errs := make([]error, 2) + ids := []string{"concurrent-1", "concurrent-2"} + start := make(chan struct{}) + var wg sync.WaitGroup + for i := range ids { + wg.Add(1) + go func(index int) { + defer wg.Done() + <-start + results[index], errs[index] = client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{ + ID: ids[index], + Range: 
BulkRange{ + Offset: int64(index * 16), + Length: 16, + }, + }) + }(i) + } + close(start) + wg.Wait() + + for i, err := range errs { + if err != nil { + t.Fatalf("client OpenDedicatedBulk #%d failed: %v", i+1, err) + } + } + + acceptedByID := make(map[string]BulkAcceptInfo, 2) + for len(acceptedByID) < 2 { + info := waitAcceptedBulk(t, acceptCh, 2*time.Second) + acceptedByID[info.ID] = info + } + accepted1 := acceptedByID[ids[0]] + accepted2 := acceptedByID[ids[1]] + + clientHandle1 := results[0].(*bulkHandle) + clientHandle2 := results[1].(*bulkHandle) + serverHandle1 := accepted1.Bulk.(*bulkHandle) + serverHandle2 := accepted2.Bulk.(*bulkHandle) + + if got := client.bulkAttachAttemptCount.Load(); got != 1 { + t.Fatalf("client bulkAttachAttemptCount = %d, want 1", got) + } + if got := client.bulkAttachSuccessCount.Load(); got != 1 { + t.Fatalf("client bulkAttachSuccessCount = %d, want 1", got) + } + + clientConn1 := clientHandle1.dedicatedConnSnapshot() + clientConn2 := clientHandle2.dedicatedConnSnapshot() + if clientConn1 == nil || clientConn2 == nil || clientConn1 != clientConn2 { + t.Fatal("concurrent dedicated bulks should share one client sidecar conn") + } + serverConn1 := serverHandle1.dedicatedConnSnapshot() + serverConn2 := serverHandle2.dedicatedConnSnapshot() + if serverConn1 == nil || serverConn2 == nil || serverConn1 != serverConn2 { + t.Fatal("concurrent dedicated bulks should share one server sidecar conn") + } + + if _, err := results[0].Write([]byte("c1")); err != nil { + t.Fatalf("client bulk1 Write failed: %v", err) + } + if _, err := results[1].Write([]byte("c2")); err != nil { + t.Fatalf("client bulk2 Write failed: %v", err) + } + readBulkExactly(t, accepted1.Bulk, "c1", 2*time.Second) + readBulkExactly(t, accepted2.Bulk, "c2", 2*time.Second) + + for _, bulk := range results { + if bulk != nil { + _ = bulk.Close() + } + } + _ = accepted1.Bulk.Close() + _ = accepted2.Bulk.Close() +} + +func 
TestDedicatedBulkConcurrentPendingRejectPropagatesOpenError(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + acceptCh := make(chan string, 2) + badDispatchCh := make(chan struct{}, 1) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + if info.ID == "pending-reject-bad" { + select { + case badDispatchCh <- struct{}{}: + default: + } + return errors.New("server rejected pending dedicated bulk") + } + acceptCh <- info.ID + return nil + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + transportConn := client.clientTransportConnSnapshot() + if transportConn == nil || transportConn.RemoteAddr() == nil { + t.Fatal("client transport snapshot should not be nil") + } + network := transportConn.RemoteAddr().Network() + addr := transportConn.RemoteAddr().String() + client.setClientConnectSource(newClientFactoryConnectSource(func(ctx context.Context) (net.Conn, error) { + timer := time.NewTimer(120 * time.Millisecond) + defer timer.Stop() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + } + var dialer net.Dialer + return dialer.DialContext(ctx, network, addr) + })) + + cfg := client.BulkDedicatedAttachConfig() + cfg.AttachLimit = 4 + cfg.LaneLimit = 1 + client.SetBulkDedicatedAttachConfig(cfg) + + type openResult struct { + id string + bulk Bulk + err error + } + resultCh := make(chan openResult, 
2) + start := make(chan struct{}) + for _, id := range []string{"pending-reject-good", "pending-reject-bad"} { + go func(id string) { + <-start + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + bulk, err := client.OpenDedicatedBulk(ctx, BulkOpenOptions{ + ID: id, + Range: BulkRange{ + Offset: 0, + Length: 16, + }, + }) + resultCh <- openResult{id: id, bulk: bulk, err: err} + }(id) + } + close(start) + + results := make(map[string]openResult, 2) + for len(results) < 2 { + result := <-resultCh + results[result.id] = result + } + + good := results["pending-reject-good"] + if good.err != nil { + t.Fatalf("good dedicated open failed: %v", good.err) + } + defer func() { + if good.bulk != nil { + _ = good.bulk.Close() + } + }() + + bad := results["pending-reject-bad"] + if bad.err == nil { + dispatched := false + select { + case <-badDispatchCh: + dispatched = true + default: + } + if bad.bulk == nil { + t.Fatalf("bad dedicated open should fail after remote pending accept rejection: dispatched=%v", dispatched) + } + if handle, ok := bad.bulk.(*bulkHandle); ok { + t.Fatalf("bad dedicated open should fail after remote pending accept rejection: dispatched=%v snapshot=%+v resetErr=%v", dispatched, bad.bulk.Snapshot(), handle.resetErrSnapshot()) + } + t.Fatalf("bad dedicated open should fail after remote pending accept rejection: dispatched=%v snapshot=%+v", dispatched, bad.bulk.Snapshot()) + } + if !strings.Contains(bad.err.Error(), "server rejected pending dedicated bulk") { + t.Fatalf("bad dedicated open error = %v, want remote reject detail", bad.err) + } + + if got := client.bulkAttachAttemptCount.Load(); got != 1 { + t.Fatalf("client bulkAttachAttemptCount = %d, want 1", got) + } + if got := client.bulkAttachSuccessCount.Load(); got != 1 { + t.Fatalf("client bulkAttachSuccessCount = %d, want 1", got) + } + + select { + case acceptedID := <-acceptCh: + if acceptedID != "pending-reject-good" { + t.Fatalf("accepted bulk id = %q, 
want %q", acceptedID, "pending-reject-good") + } + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for accepted pending dedicated bulk") + } + select { + case extra := <-acceptCh: + t.Fatalf("unexpected extra accepted bulk id: %q", extra) + case <-time.After(300 * time.Millisecond): + } +} + +func TestDedicatedBulkLanePoolSpreadsAcrossConfiguredLanes(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + acceptCh := make(chan BulkAcceptInfo, 6) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + cfg := client.BulkDedicatedAttachConfig() + cfg.AttachLimit = 4 + cfg.LaneLimit = 2 + client.SetBulkDedicatedAttachConfig(cfg) + + ids := []string{"lane-1", "lane-2", "lane-3"} + results := make([]Bulk, 0, len(ids)) + acceptedByID := make(map[string]BulkAcceptInfo, len(ids)) + for index, id := range ids { + bulk, err := client.OpenDedicatedBulk(context.Background(), BulkOpenOptions{ + ID: id, + Range: BulkRange{ + Offset: int64(index * 16), + Length: 16, + }, + }) + if err != nil { + t.Fatalf("client OpenDedicatedBulk %q failed: %v", id, err) + } + results = append(results, bulk) + acceptedByID[id] = waitAcceptedBulkByID(t, acceptCh, id, 2*time.Second) + } + defer func() { + for _, bulk := range results { + if bulk != nil { + _ = 
bulk.Close() + } + } + for _, accepted := range acceptedByID { + _ = accepted.Bulk.Close() + } + }() + + clientConns := make(map[net.Conn]struct{}) + serverConns := make(map[net.Conn]struct{}) + laneIDs := make([]uint32, 0, len(results)) + for _, bulk := range results { + handle := bulk.(*bulkHandle) + conn := handle.dedicatedConnSnapshot() + if conn == nil { + t.Fatal("client dedicated conn should not be nil") + } + clientConns[conn] = struct{}{} + laneIDs = append(laneIDs, handle.dedicatedLaneIDSnapshot()) + } + for _, accepted := range acceptedByID { + handle := accepted.Bulk.(*bulkHandle) + conn := handle.dedicatedConnSnapshot() + if conn == nil { + t.Fatal("server dedicated conn should not be nil") + } + serverConns[conn] = struct{}{} + } + + if got, want := len(clientConns), 2; got != want { + t.Fatalf("client dedicated lane conn count = %d, want %d", got, want) + } + if got, want := len(serverConns), 2; got != want { + t.Fatalf("server dedicated lane conn count = %d, want %d", got, want) + } + if got, want := client.bulkAttachAttemptCount.Load(), int64(2); got != want { + t.Fatalf("client bulkAttachAttemptCount = %d, want %d", got, want) + } + if got, want := client.bulkAttachSuccessCount.Load(), int64(2); got != want { + t.Fatalf("client bulkAttachSuccessCount = %d, want %d", got, want) + } + if laneIDs[0] == 0 || laneIDs[1] == 0 || laneIDs[2] == 0 { + t.Fatalf("dedicated lane ids should be assigned, got %v", laneIDs) + } + if laneIDs[0] == laneIDs[1] { + t.Fatalf("first two bulks should spread across lanes, got %v", laneIDs[:2]) + } + if laneIDs[2] != laneIDs[0] && laneIDs[2] != laneIDs[1] { + t.Fatalf("third bulk should reuse an existing lane, got %v", laneIDs) + } +} + +func TestServerOpenBulkLogicalAutoUDPFallsBackToShared(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + if err := 
server.Listen("udp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + acceptCh := make(chan BulkAcceptInfo, 2) + client.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + if err := client.Connect("udp", signalRoundTripServerAddr(server, "")); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + logical := waitForTransferControlLogicalConn(t, server, 2*time.Second) + bulk, err := server.OpenBulkLogical(context.Background(), logical, BulkOpenOptions{ + Mode: BulkOpenModeAuto, + Range: BulkRange{ + Offset: 0, + Length: 128, + }, + }) + if err != nil { + t.Fatalf("server OpenBulkLogical auto over udp failed: %v", err) + } + if bulk.Snapshot().Dedicated { + t.Fatal("server OpenBulkLogical auto over udp should use shared") + } + defer func() { + _ = bulk.Close() + }() + + accepted := waitAcceptedBulkByID(t, acceptCh, bulk.ID(), 2*time.Second) + if accepted.Dedicated { + t.Fatal("client accepted bulk should be shared for server auto over udp") + } + _ = accepted.Bulk.Close() +} + +func TestBulkDedicatedAttachConfigNormalize(t *testing.T) { + client := NewClient().(*ClientCommon) + client.SetBulkDedicatedAttachConfig(BulkDedicatedAttachConfig{ + AttachLimit: -1, + ActiveLimit: -2, + LaneLimit: -3, + Retry: -3, + Backoff: 0, + DialTimeout: 0, + HelloTimeout: 0, + }) + cfg := client.BulkDedicatedAttachConfig() + if cfg.AttachLimit != 0 { + t.Fatalf("AttachLimit = %d, want 0", cfg.AttachLimit) + } + if cfg.ActiveLimit != 0 { + t.Fatalf("ActiveLimit = %d, want 0", cfg.ActiveLimit) + } + if cfg.LaneLimit != 0 { + t.Fatalf("LaneLimit = %d, want 0", cfg.LaneLimit) + } + if cfg.Retry != 0 { + t.Fatalf("Retry = %d, 
want 0", cfg.Retry) + } + if cfg.Backoff <= 0 || cfg.DialTimeout <= 0 || cfg.HelloTimeout <= 0 { + t.Fatalf("normalized timeouts should be >0: %+v", cfg) + } + if client.bulkDedicatedAttachSem != nil { + t.Fatal("bulk dedicated attach semaphore should be nil when AttachLimit=0") + } +} + +func TestBulkDedicatedAttachDefaults(t *testing.T) { + client := NewClient().(*ClientCommon) + cfg := client.BulkDedicatedAttachConfig() + if got, want := cfg.AttachLimit, 16; got != want { + t.Fatalf("default AttachLimit = %d, want %d", got, want) + } + if got, want := cfg.ActiveLimit, 4096; got != want { + t.Fatalf("default ActiveLimit = %d, want %d", got, want) + } + if got, want := cfg.LaneLimit, 4; got != want { + t.Fatalf("default LaneLimit = %d, want %d", got, want) + } + + client.SetBulkNetworkProfile(BulkNetworkProfileWAN) + cfg = client.BulkDedicatedAttachConfig() + if got, want := cfg.AttachLimit, 2; got != want { + t.Fatalf("WAN AttachLimit = %d, want %d", got, want) + } + if got, want := cfg.ActiveLimit, 4096; got != want { + t.Fatalf("WAN ActiveLimit = %d, want %d", got, want) + } + if got, want := cfg.LaneLimit, 2; got != want { + t.Fatalf("WAN LaneLimit = %d, want %d", got, want) + } +} + +func TestBulkDedicatedDefaultsSupportHighLogicalConcurrency(t *testing.T) { + client := NewClient().(*ClientCommon) + cfg := client.BulkDedicatedAttachConfig() + if cfg.ActiveLimit < 2048 { + t.Fatalf("default ActiveLimit = %d, want >= 2048", cfg.ActiveLimit) + } + if cfg.LaneLimit != 4 { + t.Fatalf("default LaneLimit = %d, want 4", cfg.LaneLimit) + } + + const logicalBulks = 2048 + for i := 0; i < logicalBulks; i++ { + if !client.reserveBulkDedicatedActiveSlot() { + t.Fatalf("reserveBulkDedicatedActiveSlot failed at %d/%d", i+1, logicalBulks) + } + } + for i := logicalBulks; i < cfg.ActiveLimit; i++ { + if !client.reserveBulkDedicatedActiveSlot() { + t.Fatalf("reserveBulkDedicatedActiveSlot failed before reaching configured limit at %d/%d", i+1, cfg.ActiveLimit) + } + } + if 
client.reserveBulkDedicatedActiveSlot() { + t.Fatal("reserveBulkDedicatedActiveSlot should stop at configured logical limit") + } + for i := 0; i < cfg.ActiveLimit; i++ { + client.releaseBulkDedicatedActiveSlot() + } + + laneCounts := make(map[uint32]int, cfg.LaneLimit) + for i := 0; i < logicalBulks; i++ { + laneID := client.reserveBulkDedicatedLane() + laneCounts[laneID]++ + } + if got, want := len(laneCounts), cfg.LaneLimit; got != want { + t.Fatalf("logical lane spread count = %d, want %d", got, want) + } + minCount, maxCount := logicalBulks, 0 + for laneID, count := range laneCounts { + if laneID == 0 { + t.Fatal("lane id should be normalized") + } + if count < minCount { + minCount = count + } + if count > maxCount { + maxCount = count + } + } + if maxCount-minCount > 1 { + t.Fatalf("logical lane spread too uneven: min=%d max=%d counts=%v", minCount, maxCount, laneCounts) + } + for laneID, count := range laneCounts { + for i := 0; i < count; i++ { + client.releaseBulkDedicatedLane(laneID) + } + } + client.bulkDedicatedSidecarMu.Lock() + remaining := len(client.bulkDedicatedLanes) + client.bulkDedicatedSidecarMu.Unlock() + if remaining != 0 { + t.Fatalf("bulk dedicated lanes should be released, remaining=%d", remaining) + } +} + +func TestBulkOpenTuningNormalize(t *testing.T) { + client := NewClient().(*ClientCommon) + client.SetBulkOpenTuning(BulkOpenTuning{ + ChunkSize: -1, + WindowBytes: 1, + MaxInFlight: 0, + }) + tuning := client.BulkOpenTuning() + if got, want := tuning.ChunkSize, defaultBulkChunkSize; got != want { + t.Fatalf("ChunkSize = %d, want %d", got, want) + } + if got, want := tuning.WindowBytes, defaultBulkChunkSize; got != want { + t.Fatalf("WindowBytes = %d, want %d", got, want) + } + if got, want := tuning.MaxInFlight, defaultBulkOpenMaxInFlight; got != want { + t.Fatalf("MaxInFlight = %d, want %d", got, want) + } +} + +func TestBulkNetworkProfileAppliesOpenTuning(t *testing.T) { + client := NewClient().(*ClientCommon) + tuning := 
client.BulkOpenTuning() + if got, want := tuning.ChunkSize, defaultBulkChunkSize; got != want { + t.Fatalf("default ChunkSize = %d, want %d", got, want) + } + if got, want := tuning.WindowBytes, defaultBulkOpenWindowBytes; got != want { + t.Fatalf("default WindowBytes = %d, want %d", got, want) + } + if got, want := tuning.MaxInFlight, defaultBulkOpenMaxInFlight; got != want { + t.Fatalf("default MaxInFlight = %d, want %d", got, want) + } + + client.SetBulkNetworkProfile(BulkNetworkProfileWAN) + tuning = client.BulkOpenTuning() + if got, want := tuning.ChunkSize, 512*1024; got != want { + t.Fatalf("WAN ChunkSize = %d, want %d", got, want) + } + if got, want := tuning.WindowBytes, 8*1024*1024; got != want { + t.Fatalf("WAN WindowBytes = %d, want %d", got, want) + } + if got, want := tuning.MaxInFlight, 16; got != want { + t.Fatalf("WAN MaxInFlight = %d, want %d", got, want) + } + + client.SetBulkNetworkProfile(BulkNetworkProfileConstrained) + tuning = client.BulkOpenTuning() + if got, want := tuning.ChunkSize, 128*1024; got != want { + t.Fatalf("Constrained ChunkSize = %d, want %d", got, want) + } + if got, want := tuning.WindowBytes, 1024*1024; got != want { + t.Fatalf("Constrained WindowBytes = %d, want %d", got, want) + } + if got, want := tuning.MaxInFlight, 8; got != want { + t.Fatalf("Constrained MaxInFlight = %d, want %d", got, want) + } +} + +func TestOpenBulkAppliesClientOpenTuningDefaults(t *testing.T) { + server := NewServer().(*ServerCommon) + if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKServer failed: %v", err) + } + acceptCh := make(chan BulkAcceptInfo, 1) + server.SetBulkHandler(func(info BulkAcceptInfo) error { + acceptCh <- info + return nil + }) + if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + t.Fatalf("server Listen failed: %v", err) + } + defer func() { + _ = server.Stop() + }() + + client := NewClient().(*ClientCommon) + if err := 
UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + t.Fatalf("client Connect failed: %v", err) + } + defer func() { + _ = client.Stop() + }() + + client.SetBulkOpenTuning(BulkOpenTuning{ + ChunkSize: 256 * 1024, + WindowBytes: 2 * 1024 * 1024, + MaxInFlight: 7, + }) + tuning := client.BulkOpenTuning() + if got, want := tuning.ChunkSize, 256*1024; got != want { + t.Fatalf("client tuning ChunkSize = %d, want %d", got, want) + } + if got, want := tuning.WindowBytes, 2*1024*1024; got != want { + t.Fatalf("client tuning WindowBytes = %d, want %d", got, want) + } + if got, want := tuning.MaxInFlight, 7; got != want { + t.Fatalf("client tuning MaxInFlight = %d, want %d", got, want) + } + + bulk, err := client.OpenSharedBulk(context.Background(), BulkOpenOptions{ + Range: BulkRange{ + Offset: 0, + Length: 512 * 1024, + }, + }) + if err != nil { + t.Fatalf("client OpenSharedBulk failed: %v", err) + } + defer func() { + _ = bulk.Close() + }() + + accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second) + defer func() { + _ = accepted.Bulk.Close() + }() + + clientSnapshot := bulk.Snapshot() + if got, want := clientSnapshot.ChunkSize, 256*1024; got != want { + t.Fatalf("client ChunkSize = %d, want %d", got, want) + } + if got, want := clientSnapshot.WindowBytes, 2*1024*1024; got != want { + t.Fatalf("client WindowBytes = %d, want %d", got, want) + } + if got, want := clientSnapshot.MaxInFlight, 7; got != want { + t.Fatalf("client MaxInFlight = %d, want %d", got, want) + } + + serverSnapshot := accepted.Bulk.Snapshot() + if got, want := serverSnapshot.ChunkSize, 256*1024; got != want { + t.Fatalf("server ChunkSize = %d, want %d", got, want) + } + if got, want := serverSnapshot.WindowBytes, 2*1024*1024; got != want { + t.Fatalf("server WindowBytes = %d, want %d", got, want) + } + if got, want := 
serverSnapshot.MaxInFlight, 7; got != want { + t.Fatalf("server MaxInFlight = %d, want %d", got, want) + } +} + func TestServerOpenBulkLogicalDedicatedUDPRejected(t *testing.T) { server := NewServer().(*ServerCommon) if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil { @@ -83,3 +1356,22 @@ func TestServerOpenBulkLogicalDedicatedUDPRejected(t *testing.T) { t.Fatalf("server OpenBulkLogical dedicated over udp error = %v, want %v", err, errBulkDedicatedStreamOnly) } } + +func waitAcceptedBulkByID(t *testing.T, ch <-chan BulkAcceptInfo, id string, timeout time.Duration) BulkAcceptInfo { + t.Helper() + deadline := time.Now().Add(timeout) + for { + remain := time.Until(deadline) + if remain <= 0 { + t.Fatalf("timed out waiting for accepted bulk id=%q", id) + } + select { + case info := <-ch: + if info.ID == id { + return info + } + case <-time.After(remain): + t.Fatalf("timed out waiting for accepted bulk id=%q", id) + } + } +} diff --git a/client.go b/client.go index f0e67e2..5d4b8fc 100644 --- a/client.go +++ b/client.go @@ -35,6 +35,7 @@ type ClientCommon struct { fastStreamEncode transportFastStreamEncoder fastBulkEncode transportFastBulkEncoder fastPlainEncode transportFastPlainEncoder + modernPSKRuntime *modernPSKCodecRuntime handshakeRsaPubKey []byte SecretKey []byte noFinSyncMsgMaxKeepSeconds int @@ -55,6 +56,26 @@ type ClientCommon struct { streamRuntime *streamRuntime recordRuntime *recordRuntime bulkRuntime *bulkRuntime + bulkDefaultOpenMode BulkOpenMode + bulkNetworkProfile BulkNetworkProfile + bulkOpenTuning BulkOpenTuning + bulkDedicatedAttachLimit int + bulkDedicatedAttachSem chan struct{} + bulkDedicatedAttachRetry int + bulkDedicatedAttachBackoff time.Duration + bulkDedicatedDialTimeout time.Duration + bulkDedicatedHelloTimeout time.Duration + bulkDedicatedActiveLimit int + bulkDedicatedActive atomic.Int32 + bulkDedicatedActiveWait chan struct{} + bulkDedicatedLaneLimit int + bulkDedicatedSidecarMu 
sync.Mutex + bulkDedicatedLanes map[uint32]*bulkDedicatedLane + bulkDedicatedNextLaneID uint32 + bulkAttachAttemptCount atomic.Int64 + bulkAttachRetryCount atomic.Int64 + bulkAttachSuccessCount atomic.Int64 + bulkAttachFallbackCount atomic.Int64 connectionRetryState *connectionRetryState securityReadyCheck bool debugMode bool @@ -63,21 +84,32 @@ type ClientCommon struct { func NewClient() Client { transport := defaultModernPSKTransportBundle() var client = ClientCommon{ - maxReadTimeout: 0, - maxWriteTimeout: 0, - peerIdentity: newClientPeerIdentity(), - sequenceEn: encode, - sequenceDe: Decode, - keyExchangeFn: aesRsaHello, - SecretKey: nil, - handshakeRsaPubKey: defaultRsaPubKey, - msgEn: transport.msgEn, - msgDe: transport.msgDe, - fastStreamEncode: transport.fastStreamEncode, - fastBulkEncode: transport.fastBulkEncode, - fastPlainEncode: transport.fastPlainEncode, - skipKeyExchange: true, - securityReadyCheck: true, + maxReadTimeout: 0, + maxWriteTimeout: 0, + peerIdentity: newClientPeerIdentity(), + sequenceEn: encode, + sequenceDe: Decode, + keyExchangeFn: aesRsaHello, + SecretKey: nil, + handshakeRsaPubKey: defaultRsaPubKey, + msgEn: transport.msgEn, + msgDe: transport.msgDe, + fastStreamEncode: transport.fastStreamEncode, + fastBulkEncode: transport.fastBulkEncode, + fastPlainEncode: transport.fastPlainEncode, + skipKeyExchange: true, + securityReadyCheck: true, + bulkDefaultOpenMode: BulkOpenModeShared, + bulkNetworkProfile: BulkNetworkProfileDefault, + bulkOpenTuning: defaultBulkOpenTuning(), + bulkDedicatedAttachLimit: defaultBulkDedicatedAttachLimit, + bulkDedicatedAttachRetry: defaultBulkDedicatedAttachRetry, + bulkDedicatedAttachBackoff: defaultBulkDedicatedAttachBackoff, + bulkDedicatedDialTimeout: defaultBulkDedicatedDialTimeout, + bulkDedicatedHelloTimeout: defaultBulkDedicatedHelloTimeout, + bulkDedicatedActiveLimit: defaultBulkDedicatedActiveLimit, + bulkDedicatedActiveWait: make(chan struct{}), + bulkDedicatedLaneLimit: 
defaultBulkDedicatedLaneLimit, } client.alive.Store(false) client.useHeartBeat = true @@ -93,6 +125,10 @@ func NewClient() Client { client.streamRuntime = newStreamRuntime("cstrm") client.recordRuntime = newRecordRuntime() client.bulkRuntime = newBulkRuntime("cblk") + client.bulkDedicatedLanes = make(map[uint32]*bulkDedicatedLane) + if client.bulkDedicatedAttachLimit > 0 { + client.bulkDedicatedAttachSem = make(chan struct{}, client.bulkDedicatedAttachLimit) + } client.connectionRetryState = newConnectionRetryState() client.onFileEvent = normalizeFileEventCallback(nil) client.fileEventObserver = normalizeFileEventCallback(nil) @@ -103,3 +139,12 @@ func NewClient() Client { client.getTransferState().setBuiltinHandler(client.builtinFileTransferHandler) return &client } + +func (c *ClientCommon) maxWriteTimeoutSnapshot() time.Duration { + if c == nil { + return 0 + } + c.mu.Lock() + defer c.mu.Unlock() + return c.maxWriteTimeout +} diff --git a/client_bulk.go b/client_bulk.go index 4190c4a..81a3467 100644 --- a/client_bulk.go +++ b/client_bulk.go @@ -1,6 +1,9 @@ package notify -import "context" +import ( + "context" + "errors" +) func (c *ClientCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) { runtime := c.getBulkRuntime() @@ -10,10 +13,67 @@ func (c *ClientCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) { runtime.setHandler(fn) } +func (c *ClientCommon) OpenSharedBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) { + opt.Mode = BulkOpenModeShared + opt.Dedicated = false + return c.OpenBulk(ctx, opt) +} + +func (c *ClientCommon) OpenDedicatedBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) { + opt.Mode = BulkOpenModeDedicated + opt.Dedicated = true + return c.OpenBulk(ctx, opt) +} + func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) { + if normalizeBulkOpenMode(opt.Mode) == BulkOpenModeDefault && !opt.Dedicated { + opt.Mode = c.bulkDefaultOpenModeSnapshot() + } + opt = normalizeBulkOpenOptions(opt) 
+ switch opt.Mode { + case BulkOpenModeDedicated: + opt.Dedicated = true + return c.openBulkWithDedicatedMode(ctx, opt) + case BulkOpenModeAuto: + // Auto mode prefers dedicated path and falls back to shared if dedicated fails. + if err := clientDedicatedBulkSupportError(c); err == nil { + dedicatedOpt := opt + dedicatedOpt.Mode = BulkOpenModeDedicated + dedicatedOpt.Dedicated = true + bulk, dedicatedErr := c.openBulkWithDedicatedMode(ctx, dedicatedOpt) + if dedicatedErr == nil { + return bulk, nil + } + sharedOpt := opt + sharedOpt.Mode = BulkOpenModeShared + sharedOpt.Dedicated = false + sharedBulk, sharedErr := c.openBulkWithDedicatedMode(ctx, sharedOpt) + if sharedErr == nil { + c.bulkAttachFallbackCount.Add(1) + return sharedBulk, nil + } + return nil, errors.Join(dedicatedErr, sharedErr) + } + opt.Mode = BulkOpenModeShared + opt.Dedicated = false + c.bulkAttachFallbackCount.Add(1) + return c.openBulkWithDedicatedMode(ctx, opt) + case BulkOpenModeShared, BulkOpenModeDefault: + opt.Mode = BulkOpenModeShared + opt.Dedicated = false + return c.openBulkWithDedicatedMode(ctx, opt) + default: + opt.Mode = BulkOpenModeShared + opt.Dedicated = false + return c.openBulkWithDedicatedMode(ctx, opt) + } +} + +func (c *ClientCommon) openBulkWithDedicatedMode(ctx context.Context, opt BulkOpenOptions) (Bulk, error) { if c == nil { return nil, errBulkClientNil } + opt = applyBulkOpenTuningDefaults(opt, c.bulkOpenTuningSnapshot()) runtime := c.getBulkRuntime() if runtime == nil { return nil, errBulkRuntimeNil @@ -33,6 +93,66 @@ func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, if _, exists := runtime.lookup(clientFileScope(), req.BulkID); exists { return nil, errBulkAlreadyExists } + if req.Dedicated { + req.DedicatedLaneID = c.reserveBulkDedicatedLane() + if req.DataID == 0 { + req.DataID = runtime.nextDataID() + } + if req.AttachToken == "" { + req.AttachToken = newBulkAttachToken() + } + bulk := newBulkHandle(c.clientStopContextSnapshot(), 
runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, 0, clientBulkCloseSender(c), clientBulkResetSender(c), clientBulkDataSender(c, c.currentClientSessionEpoch()), clientBulkWriteSender(c, c.currentClientSessionEpoch()), clientBulkReleaseSender(c)) + bulk.setClientSnapshotOwner(c) + bulk.markAcceptHandled() + if err := runtime.register(clientFileScope(), bulk); err != nil { + c.releaseBulkDedicatedLane(req.DedicatedLaneID) + return nil, err + } + resp, err := sendBulkOpenClient(ctx, c, req) + if err != nil { + bulk.markReset(err) + return nil, err + } + if resp.DataID != 0 && resp.DataID != req.DataID { + err = errBulkAlreadyExists + _, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{ + BulkID: req.BulkID, + DataID: req.DataID, + Error: "bulk dedicated data id mismatch", + }) + bulk.markReset(err) + return nil, err + } + if resp.TransportGeneration != 0 { + bulk.transportGeneration = resp.TransportGeneration + } + if resp.FastPathVersion != 0 { + bulk.fastPathVersion = normalizeBulkFastPathVersion(resp.FastPathVersion) + } + if resp.AttachToken != "" { + req.AttachToken = resp.AttachToken + bulk.setDedicatedAttachToken(resp.AttachToken) + } + if err := c.attachDedicatedBulkSidecar(ctx, bulk); err != nil { + _, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{ + BulkID: req.BulkID, + DataID: req.DataID, + Error: err.Error(), + }) + bulk.markReset(err) + return nil, err + } + if err := bulk.waitAcceptReady(ctx); err != nil { + _, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{ + BulkID: req.BulkID, + DataID: req.DataID, + Error: err.Error(), + }) + bulk.markReset(err) + return nil, err + } + return bulk, nil + } resp, err := sendBulkOpenClient(ctx, c, req) if err != nil { return nil, err @@ -40,6 +160,9 @@ func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, if resp.DataID != 0 { req.DataID = resp.DataID } + if resp.FastPathVersion != 0 { + req.FastPathVersion = 
resp.FastPathVersion + } req.Dedicated = resp.Dedicated if resp.AttachToken != "" { req.AttachToken = resp.AttachToken @@ -49,7 +172,9 @@ func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, } bulk := newBulkHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientBulkCloseSender(c), clientBulkResetSender(c), clientBulkDataSender(c, c.currentClientSessionEpoch()), clientBulkWriteSender(c, c.currentClientSessionEpoch()), clientBulkReleaseSender(c)) bulk.setClientSnapshotOwner(c) + bulk.markAcceptHandled() if err := runtime.register(clientFileScope(), bulk); err != nil { + c.releaseBulkDedicatedLane(req.DedicatedLaneID) _, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{ BulkID: req.BulkID, DataID: req.DataID, @@ -78,15 +203,16 @@ func clientBulkRequest(runtime *bulkRuntime, opt BulkOpenOptions) BulkOpenReques id = runtime.nextID() } return normalizeBulkOpenRequest(BulkOpenRequest{ - BulkID: id, - Range: opt.Range, - Metadata: cloneBulkMetadata(opt.Metadata), - ReadTimeout: opt.ReadTimeout, - WriteTimeout: opt.WriteTimeout, - Dedicated: opt.Dedicated, - ChunkSize: opt.ChunkSize, - WindowBytes: opt.WindowBytes, - MaxInFlight: opt.MaxInFlight, + BulkID: id, + FastPathVersion: bulkFastPathVersionCurrent, + Range: opt.Range, + Metadata: cloneBulkMetadata(opt.Metadata), + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + Dedicated: opt.Dedicated, + ChunkSize: opt.ChunkSize, + WindowBytes: opt.WindowBytes, + MaxInFlight: opt.MaxInFlight, }) } @@ -148,12 +274,12 @@ func clientBulkDataSender(c *ClientCommon, epoch uint64) bulkDataSender { if dataID == 0 { return errBulkDataPathNotReady } - return c.sendFastBulkData(ctx, dataID, bulk.nextOutboundDataSeq(), chunk) + return c.sendFastBulkData(ctx, dataID, bulk.nextOutboundDataSeq(), chunk, bulk.fastPathVersionSnapshot()) } } func clientBulkWriteSender(c *ClientCommon, epoch uint64) 
bulkWriteSender { - return func(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) { + return func(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) { if c == nil { return 0, errBulkClientNil } @@ -168,12 +294,19 @@ func clientBulkWriteSender(c *ClientCommon, epoch uint64) bulkWriteSender { if err := bulk.waitDedicatedReady(ctx); err != nil { return 0, err } - return c.sendDedicatedBulkWrite(ctx, bulk, payload) + return c.sendDedicatedBulkWrite(ctx, bulk, startSeq, payload) } if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) { return 0, errTransportDetached } - return 0, nil + if bulk == nil { + return 0, errBulkRuntimeNil + } + dataID := bulk.dataIDSnapshot() + if dataID == 0 { + return 0, errBulkDataPathNotReady + } + return c.sendFastBulkWrite(ctx, dataID, startSeq, bulk.chunkSize, bulk.fastPathVersionSnapshot(), payload, payloadOwned) } } @@ -185,8 +318,20 @@ func clientBulkReleaseSender(c *ClientCommon) bulkReleaseSender { if bytes <= 0 && chunks <= 0 { return nil } + ctx, cancel, err := bulk.newWriteContext(bulk.Context(), bulk.writeTimeout) + if err != nil { + return err + } + defer cancel() if bulk.Dedicated() { - return c.sendDedicatedBulkRelease(context.Background(), bulk, bytes, chunks) + return c.sendDedicatedBulkRelease(ctx, bulk, bytes, chunks) + } + if bulk.fastPathVersionSnapshot() >= bulkFastPathVersionV2 { + payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks) + if err != nil { + return err + } + return c.sendFastBulkControl(ctx, bulkFastPayloadTypeRelease, 0, bulk.dataIDSnapshot(), 0, bulk.fastPathVersionSnapshot(), payload) } return sendBulkReleaseClient(c, BulkReleaseRequest{ BulkID: bulk.ID(), diff --git a/client_bulk_config.go b/client_bulk_config.go new file mode 100644 index 0000000..69251ee --- /dev/null +++ b/client_bulk_config.go @@ -0,0 +1,364 @@ +package notify + +import "time" + +func defaultBulkOpenTuning() BulkOpenTuning { + return 
normalizeBulkOpenTuning(BulkOpenTuning{ + ChunkSize: defaultBulkChunkSize, + WindowBytes: defaultBulkOpenWindowBytes, + MaxInFlight: defaultBulkOpenMaxInFlight, + }) +} + +func normalizeBulkOpenTuning(tuning BulkOpenTuning) BulkOpenTuning { + if tuning.ChunkSize <= 0 { + tuning.ChunkSize = defaultBulkChunkSize + } + if tuning.WindowBytes <= 0 { + tuning.WindowBytes = defaultBulkOpenWindowBytes + } + if tuning.WindowBytes < tuning.ChunkSize { + tuning.WindowBytes = tuning.ChunkSize + } + if tuning.MaxInFlight <= 0 { + tuning.MaxInFlight = defaultBulkOpenMaxInFlight + } + return tuning +} + +func applyBulkOpenTuningDefaults(opt BulkOpenOptions, tuning BulkOpenTuning) BulkOpenOptions { + tuning = normalizeBulkOpenTuning(tuning) + if opt.ChunkSize <= 0 { + opt.ChunkSize = tuning.ChunkSize + } + if opt.WindowBytes <= 0 { + opt.WindowBytes = tuning.WindowBytes + } + if opt.MaxInFlight <= 0 { + opt.MaxInFlight = tuning.MaxInFlight + } + return opt +} + +func normalizeBulkNetworkProfile(profile BulkNetworkProfile) BulkNetworkProfile { + switch profile { + case BulkNetworkProfileDefault, BulkNetworkProfileLAN, BulkNetworkProfileWAN, BulkNetworkProfileConstrained: + return profile + default: + return BulkNetworkProfileDefault + } +} + +func bulkNetworkProfileName(profile BulkNetworkProfile) string { + switch normalizeBulkNetworkProfile(profile) { + case BulkNetworkProfileLAN: + return "lan" + case BulkNetworkProfileWAN: + return "wan" + case BulkNetworkProfileConstrained: + return "constrained" + case BulkNetworkProfileDefault: + fallthrough + default: + return "default" + } +} + +func bulkNetworkProfilePreset(profile BulkNetworkProfile) (BulkOpenMode, BulkDedicatedAttachConfig, BulkOpenTuning) { + switch normalizeBulkNetworkProfile(profile) { + case BulkNetworkProfileLAN: + return BulkOpenModeAuto, BulkDedicatedAttachConfig{ + AttachLimit: defaultBulkDedicatedAttachLimit, + ActiveLimit: defaultBulkDedicatedActiveLimit, + LaneLimit: 4, + Retry: 
defaultBulkDedicatedAttachRetry, + Backoff: defaultBulkDedicatedAttachBackoff, + DialTimeout: defaultBulkDedicatedDialTimeout, + HelloTimeout: defaultBulkDedicatedHelloTimeout, + }, BulkOpenTuning{ + ChunkSize: 1024 * 1024, + WindowBytes: 32 * 1024 * 1024, + MaxInFlight: 64, + } + case BulkNetworkProfileWAN: + return BulkOpenModeAuto, BulkDedicatedAttachConfig{ + AttachLimit: 2, + ActiveLimit: defaultBulkDedicatedActiveLimit, + LaneLimit: 2, + Retry: 4, + Backoff: 250 * time.Millisecond, + DialTimeout: 15 * time.Second, + HelloTimeout: 20 * time.Second, + }, BulkOpenTuning{ + ChunkSize: 512 * 1024, + WindowBytes: 8 * 1024 * 1024, + MaxInFlight: 16, + } + case BulkNetworkProfileConstrained: + return BulkOpenModeShared, BulkDedicatedAttachConfig{ + AttachLimit: 1, + ActiveLimit: 2, + LaneLimit: 1, + Retry: 5, + Backoff: 400 * time.Millisecond, + DialTimeout: 20 * time.Second, + HelloTimeout: 30 * time.Second, + }, BulkOpenTuning{ + ChunkSize: 128 * 1024, + WindowBytes: 1 * 1024 * 1024, + MaxInFlight: 8, + } + default: + return BulkOpenModeShared, BulkDedicatedAttachConfig{ + AttachLimit: defaultBulkDedicatedAttachLimit, + ActiveLimit: defaultBulkDedicatedActiveLimit, + LaneLimit: defaultBulkDedicatedLaneLimit, + Retry: defaultBulkDedicatedAttachRetry, + Backoff: defaultBulkDedicatedAttachBackoff, + DialTimeout: defaultBulkDedicatedDialTimeout, + HelloTimeout: defaultBulkDedicatedHelloTimeout, + }, defaultBulkOpenTuning() + } +} + +func normalizeBulkDedicatedAttachConfig(cfg BulkDedicatedAttachConfig) BulkDedicatedAttachConfig { + if cfg.AttachLimit < 0 { + cfg.AttachLimit = 0 + } + if cfg.ActiveLimit < 0 { + cfg.ActiveLimit = 0 + } + if cfg.LaneLimit < 0 { + cfg.LaneLimit = 0 + } + if cfg.Retry < 0 { + cfg.Retry = 0 + } + if cfg.Backoff <= 0 { + cfg.Backoff = defaultBulkDedicatedAttachBackoff + } + if cfg.DialTimeout <= 0 { + cfg.DialTimeout = defaultBulkDedicatedDialTimeout + } + if cfg.HelloTimeout <= 0 { + cfg.HelloTimeout = defaultBulkDedicatedHelloTimeout + } + 
return cfg +} + +func (c *ClientCommon) setBulkDedicatedAttachSemaphoreLocked(limit int) { + if limit <= 0 { + c.bulkDedicatedAttachSem = nil + return + } + c.bulkDedicatedAttachSem = make(chan struct{}, limit) +} + +func (c *ClientCommon) applyBulkDedicatedAttachConfigLocked(cfg BulkDedicatedAttachConfig) { + cfg = normalizeBulkDedicatedAttachConfig(cfg) + c.bulkDedicatedAttachLimit = cfg.AttachLimit + c.setBulkDedicatedAttachSemaphoreLocked(cfg.AttachLimit) + c.bulkDedicatedActiveLimit = cfg.ActiveLimit + c.notifyBulkDedicatedActiveWaitersLocked() + c.bulkDedicatedLaneLimit = cfg.LaneLimit + c.bulkDedicatedAttachRetry = cfg.Retry + c.bulkDedicatedAttachBackoff = cfg.Backoff + c.bulkDedicatedDialTimeout = cfg.DialTimeout + c.bulkDedicatedHelloTimeout = cfg.HelloTimeout +} + +func (c *ClientCommon) ensureBulkDedicatedActiveWaitLocked() chan struct{} { + if c == nil { + return nil + } + if c.bulkDedicatedActiveWait == nil { + c.bulkDedicatedActiveWait = make(chan struct{}) + } + return c.bulkDedicatedActiveWait +} + +func (c *ClientCommon) notifyBulkDedicatedActiveWaitersLocked() { + if c == nil { + return + } + waitCh := c.ensureBulkDedicatedActiveWaitLocked() + close(waitCh) + c.bulkDedicatedActiveWait = make(chan struct{}) +} + +func (c *ClientCommon) applyBulkOpenTuningLocked(tuning BulkOpenTuning) { + c.bulkOpenTuning = normalizeBulkOpenTuning(tuning) +} + +func (c *ClientCommon) bulkDefaultOpenModeSnapshot() BulkOpenMode { + if c == nil { + return BulkOpenModeShared + } + c.mu.Lock() + defer c.mu.Unlock() + mode := normalizeBulkOpenMode(c.bulkDefaultOpenMode) + if mode == BulkOpenModeDefault { + mode = BulkOpenModeShared + } + return mode +} + +func (c *ClientCommon) bulkDedicatedAttachSemaphoreSnapshot() chan struct{} { + if c == nil { + return nil + } + c.mu.Lock() + defer c.mu.Unlock() + return c.bulkDedicatedAttachSem +} + +func (c *ClientCommon) bulkDedicatedActiveLimitSnapshot() int { + if c == nil { + return 0 + } + c.mu.Lock() + defer c.mu.Unlock() + 
if c.bulkDedicatedActiveLimit < 0 { + return 0 + } + return c.bulkDedicatedActiveLimit +} + +func (c *ClientCommon) bulkDedicatedActiveWaitSnapshot() chan struct{} { + if c == nil { + return nil + } + c.mu.Lock() + defer c.mu.Unlock() + return c.ensureBulkDedicatedActiveWaitLocked() +} + +func (c *ClientCommon) bulkDedicatedLaneLimitSnapshot() int { + if c == nil { + return defaultBulkDedicatedLaneLimit + } + c.mu.Lock() + defer c.mu.Unlock() + limit := c.bulkDedicatedLaneLimit + if limit < 0 { + return 0 + } + if limit == 0 { + return 0 + } + return limit +} + +func (c *ClientCommon) SetBulkDefaultOpenMode(mode BulkOpenMode) { + if c == nil { + return + } + mode = normalizeBulkOpenMode(mode) + if mode == BulkOpenModeDefault { + mode = BulkOpenModeShared + } + c.mu.Lock() + c.bulkDefaultOpenMode = mode + c.mu.Unlock() +} + +func (c *ClientCommon) BulkDefaultOpenMode() BulkOpenMode { + return c.bulkDefaultOpenModeSnapshot() +} + +func (c *ClientCommon) bulkOpenTuningSnapshot() BulkOpenTuning { + if c == nil { + return defaultBulkOpenTuning() + } + c.mu.Lock() + defer c.mu.Unlock() + return normalizeBulkOpenTuning(c.bulkOpenTuning) +} + +func (c *ClientCommon) SetBulkOpenTuning(tuning BulkOpenTuning) { + if c == nil { + return + } + c.mu.Lock() + c.applyBulkOpenTuningLocked(tuning) + c.mu.Unlock() +} + +func (c *ClientCommon) BulkOpenTuning() BulkOpenTuning { + return c.bulkOpenTuningSnapshot() +} + +func (c *ClientCommon) SetBulkDedicatedAttachConfig(cfg BulkDedicatedAttachConfig) { + if c == nil { + return + } + c.mu.Lock() + c.applyBulkDedicatedAttachConfigLocked(cfg) + c.mu.Unlock() +} + +func (c *ClientCommon) BulkDedicatedAttachConfig() BulkDedicatedAttachConfig { + if c == nil { + return normalizeBulkDedicatedAttachConfig(BulkDedicatedAttachConfig{}) + } + c.mu.Lock() + defer c.mu.Unlock() + return normalizeBulkDedicatedAttachConfig(BulkDedicatedAttachConfig{ + AttachLimit: c.bulkDedicatedAttachLimit, + ActiveLimit: c.bulkDedicatedActiveLimit, + LaneLimit: 
c.bulkDedicatedLaneLimit, + Retry: c.bulkDedicatedAttachRetry, + Backoff: c.bulkDedicatedAttachBackoff, + DialTimeout: c.bulkDedicatedDialTimeout, + HelloTimeout: c.bulkDedicatedHelloTimeout, + }) +} + +func (c *ClientCommon) SetBulkNetworkProfile(profile BulkNetworkProfile) { + if c == nil { + return + } + profile = normalizeBulkNetworkProfile(profile) + mode, cfg, tuning := bulkNetworkProfilePreset(profile) + c.mu.Lock() + c.bulkNetworkProfile = profile + c.bulkDefaultOpenMode = mode + c.applyBulkDedicatedAttachConfigLocked(cfg) + c.applyBulkOpenTuningLocked(tuning) + c.mu.Unlock() +} + +func (c *ClientCommon) BulkNetworkProfile() BulkNetworkProfile { + if c == nil { + return BulkNetworkProfileDefault + } + c.mu.Lock() + defer c.mu.Unlock() + return normalizeBulkNetworkProfile(c.bulkNetworkProfile) +} + +func (s *ServerCommon) applyBulkOpenTuningLocked(tuning BulkOpenTuning) { + s.bulkOpenTuning = normalizeBulkOpenTuning(tuning) +} + +func (s *ServerCommon) bulkOpenTuningSnapshot() BulkOpenTuning { + if s == nil { + return defaultBulkOpenTuning() + } + s.mu.RLock() + defer s.mu.RUnlock() + return normalizeBulkOpenTuning(s.bulkOpenTuning) +} + +func (s *ServerCommon) SetBulkOpenTuning(tuning BulkOpenTuning) { + if s == nil { + return + } + s.mu.Lock() + s.applyBulkOpenTuningLocked(tuning) + s.mu.Unlock() +} + +func (s *ServerCommon) BulkOpenTuning() BulkOpenTuning { + return s.bulkOpenTuningSnapshot() +} diff --git a/client_config.go b/client_config.go index 1fe30ae..a0aef26 100644 --- a/client_config.go +++ b/client_config.go @@ -75,6 +75,7 @@ func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) { c.fastStreamEncode = nil c.fastBulkEncode = nil c.fastPlainEncode = nil + c.modernPSKRuntime = nil c.securityReadyCheck = false } @@ -89,6 +90,7 @@ func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) { c.fastStreamEncode = nil c.fastBulkEncode = nil c.fastPlainEncode = nil + c.modernPSKRuntime = nil c.securityReadyCheck = false } @@ -108,6 +110,13 
@@ func (c *ClientCommon) GetSecretKey() []byte { // Prefer UseModernPSKClient or UseLegacySecurityClient. func (c *ClientCommon) SetSecretKey(key []byte) { c.SecretKey = key + if len(key) == 0 { + c.modernPSKRuntime = nil + } else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil { + c.modernPSKRuntime = runtime + } else { + c.modernPSKRuntime = nil + } c.securityReadyCheck = len(key) == 0 c.skipKeyExchange = true } diff --git a/client_conn.go b/client_conn.go index 9a8e753..1939f4b 100644 --- a/client_conn.go +++ b/client_conn.go @@ -72,6 +72,23 @@ func (c *ClientConn) readTUMessageLoop(rt *clientConnSessionRuntime) { conn := rt.tuConn generation := rt.transportGeneration defer closeClientConnSessionRuntimeTransportDone(rt) + if conn != nil && !isPacketTransportConn(conn) { + reader := newTransportFrameReader(conn, nil) + for { + select { + case <-sessionStopChan(stopCtx): + if c.shouldCloseClientConnTransportOnStop(conn) { + _ = conn.Close() + } + return + default: + } + payload, release, err := c.readTUTransportPayloadPooled(conn, reader) + if !c.handleTUTransportPayloadReadResultWithSessionPooled(stopCtx, conn, generation, payload, release, err) { + return + } + } + } buf := streamReadBuffer() for { select { diff --git a/client_conn_attachment.go b/client_conn_attachment.go index 01de0d3..103a28c 100644 --- a/client_conn_attachment.go +++ b/client_conn_attachment.go @@ -13,6 +13,7 @@ type clientConnAttachmentState struct { fastStreamEncode transportFastStreamEncoder fastBulkEncode transportFastBulkEncoder fastPlainEncode transportFastPlainEncoder + modernPSKRuntime *modernPSKCodecRuntime handshakeRsaKey []byte secretKey []byte lastHeartBeat int64 @@ -154,6 +155,7 @@ func (c *ClientConn) applyClientConnAttachmentProfile(maxReadTimeout time.Durati state.maxWriteTimeout = maxWriteTimeout state.msgEn = msgEn state.msgDe = msgDe + state.modernPSKRuntime = nil state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey) 
state.secretKey = cloneClientConnAttachmentBytes(secretKey) }) @@ -208,6 +210,7 @@ func (c *ClientConn) setClientConnMsgEn(fn func([]byte, []byte) []byte) { state.fastStreamEncode = nil state.fastBulkEncode = nil state.fastPlainEncode = nil + state.modernPSKRuntime = nil }) } @@ -224,6 +227,7 @@ func (c *ClientConn) setClientConnMsgDe(fn func([]byte, []byte) []byte) { state.fastStreamEncode = nil state.fastBulkEncode = nil state.fastPlainEncode = nil + state.modernPSKRuntime = nil }) } @@ -286,6 +290,64 @@ func (c *ClientConn) setClientConnSecretKey(key []byte) { }) } +func (c *LogicalConn) attachmentStateRaw() *clientConnAttachmentState { + if c == nil { + return nil + } + if state := c.attachment.Load(); state != nil { + if client := c.compatClientConn(); client != nil { + client.attachment.Store(state) + } + return state + } + if client := c.compatClientConn(); client != nil { + if state := client.attachment.Load(); state != nil { + if c.attachment.CompareAndSwap(nil, state) { + client.attachment.Store(state) + return state + } + return c.attachmentStateRaw() + } + } + return nil +} + +func (c *LogicalConn) modernPSKRuntimeSnapshot() *modernPSKCodecRuntime { + if state := c.attachmentStateRaw(); state != nil { + return state.modernPSKRuntime + } + return nil +} + +func (c *LogicalConn) setModernPSKRuntime(runtime *modernPSKCodecRuntime) { + c.updateAttachmentState(func(state *clientConnAttachmentState) { + state.modernPSKRuntime = runtime + }) +} + +func (c *ClientConn) clientConnAttachmentStateRaw() *clientConnAttachmentState { + if c == nil { + return nil + } + if logical := c.logicalView.Load(); logical != nil { + return logical.attachmentStateRaw() + } + return c.attachment.Load() +} + +func (c *ClientConn) clientConnModernPSKRuntimeSnapshot() *modernPSKCodecRuntime { + if state := c.clientConnAttachmentStateRaw(); state != nil { + return state.modernPSKRuntime + } + return nil +} + +func (c *ClientConn) setClientConnModernPSKRuntime(runtime 
*modernPSKCodecRuntime) { + c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) { + state.modernPSKRuntime = runtime + }) +} + func (c *ClientConn) clientConnLastHeartbeatUnixSnapshot() int64 { if c == nil { return 0 diff --git a/client_conn_transport.go b/client_conn_transport.go index fd967a0..c94f724 100644 --- a/client_conn_transport.go +++ b/client_conn_transport.go @@ -1,6 +1,7 @@ package notify import ( + "b612.me/stario" "context" "net" "os" @@ -15,6 +16,10 @@ type serverInboundSourcePusher interface { pushMessageSource([]byte, interface{}) } +type serverInboundSourceFastPusher interface { + pushTransportPayloadSourceFast([]byte, func(), interface{}) bool +} + func (c *LogicalConn) readTUMessage() { rt := c.clientConnSessionRuntimeSnapshot() if rt == nil { @@ -37,6 +42,23 @@ func (c *LogicalConn) readTUMessageLoop(rt *clientConnSessionRuntime) { conn := rt.tuConn generation := rt.transportGeneration defer closeClientConnSessionRuntimeTransportDone(rt) + if conn != nil && !isPacketTransportConn(conn) { + reader := newTransportFrameReader(conn, nil) + for { + select { + case <-sessionStopChan(stopCtx): + if c.shouldCloseTransportOnStop(conn) { + _ = conn.Close() + } + return + default: + } + payload, release, err := c.readTUTransportPayloadPooled(conn, reader) + if !c.handleTUTransportPayloadReadResultWithSessionPooled(stopCtx, conn, generation, payload, release, err) { + return + } + } + } buf := streamReadBuffer() for { select { @@ -54,6 +76,55 @@ func (c *LogicalConn) readTUMessageLoop(rt *clientConnSessionRuntime) { } } +func (c *LogicalConn) readTUTransportPayloadPooled(conn net.Conn, reader *stario.FrameReader) ([]byte, func(), error) { + if reader == nil { + return nil, nil, net.ErrClosed + } + if conn == nil { + return nil, nil, net.ErrClosed + } + if timeout := c.clientConnMaxReadTimeoutSnapshot(); timeout > 0 { + _ = conn.SetReadDeadline(time.Now().Add(timeout)) + } + return reader.NextPooled() +} + +func (c *LogicalConn) 
handleTUTransportPayloadReadResultWithSessionPooled(stopCtx context.Context, conn net.Conn, generation uint64, payload []byte, release func(), err error) bool { + if transportReadShouldStop(stopCtx) || !c.ownsTransportRead(conn, generation) { + if release != nil { + release() + } + if c.shouldCloseTransportOnStop(conn) { + _ = conn.Close() + } + return false + } + if err == os.ErrDeadlineExceeded { + return true + } + if err != nil { + if release != nil { + release() + } + select { + case <-sessionStopChan(stopCtx): + if c.shouldCloseTransportOnStop(conn) { + _ = conn.Close() + } + return false + default: + } + if detacher, ok := c.Server().(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() { + detacher.detachLogicalSessionTransport(c, "read error", err) + return false + } + c.stopServerOwnedSession("read error", err) + return false + } + c.pushServerOwnedTransportPayload(payload, release, conn, generation) + return true +} + func (c *LogicalConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byte) (int, []byte, error) { if len(data) == 0 { data = streamReadBuffer() @@ -140,6 +211,29 @@ func (c *LogicalConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn server.pushMessage(data, c.clientConnIDSnapshot()) } +func (c *LogicalConn) pushServerOwnedTransportPayload(payload []byte, release func(), conn net.Conn, generation uint64) { + if c == nil || len(payload) == 0 { + if release != nil { + release() + } + return + } + server := c.Server() + if server == nil { + if release != nil { + release() + } + return + } + if pusher, ok := server.(serverInboundSourceFastPusher); ok { + pusher.pushTransportPayloadSourceFast(payload, release, newServerInboundSource(c, conn, nil, generation)) + return + } + if release != nil { + release() + } +} + func (c *LogicalConn) shouldCloseTransportOnStop(conn net.Conn) bool { if c == nil || conn == nil { return false @@ -185,10 +279,65 @@ func (c *ClientConn) 
readFromTUTransportConnWithBuffer(conn net.Conn, data []byt return num, data, err } +func (c *ClientConn) readTUTransportPayloadPooled(conn net.Conn, reader *stario.FrameReader) ([]byte, func(), error) { + if logical := c.LogicalConn(); logical != nil { + return logical.readTUTransportPayloadPooled(conn, reader) + } + if reader == nil { + return nil, nil, net.ErrClosed + } + if conn == nil { + return nil, nil, net.ErrClosed + } + if timeout := c.clientConnMaxReadTimeoutSnapshot(); timeout > 0 { + _ = conn.SetReadDeadline(time.Now().Add(timeout)) + } + return reader.NextPooled() +} + func (c *ClientConn) handleTUTransportReadResult(num int, data []byte, err error) bool { return c.handleTUTransportReadResultWithSession(c.clientConnTransportStopContextSnapshot(), c.clientConnTransportSnapshot(), c.clientConnTransportGenerationSnapshot(), num, data, err) } +func (c *ClientConn) handleTUTransportPayloadReadResultWithSessionPooled(stopCtx context.Context, conn net.Conn, generation uint64, payload []byte, release func(), err error) bool { + if logical := c.LogicalConn(); logical != nil { + return logical.handleTUTransportPayloadReadResultWithSessionPooled(stopCtx, conn, generation, payload, release, err) + } + if transportReadShouldStop(stopCtx) || !c.ownsTransportRead(conn, generation) { + if release != nil { + release() + } + if c.shouldCloseClientConnTransportOnStop(conn) { + _ = conn.Close() + } + return false + } + if err == os.ErrDeadlineExceeded { + return true + } + if err != nil { + if release != nil { + release() + } + select { + case <-sessionStopChan(stopCtx): + if c.shouldCloseClientConnTransportOnStop(conn) { + _ = conn.Close() + } + return false + default: + } + if detacher, ok := c.server.(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() { + detacher.detachLogicalSessionTransport(logicalConnFromClient(c), "read error", err) + return false + } + c.stopServerOwnedSession("read error", err) + return false + } + 
c.pushServerOwnedTransportPayload(payload, release, conn, generation) + return true +} + func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool { if logical := c.LogicalConn(); logical != nil { return logical.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err) @@ -255,6 +404,26 @@ func (c *ClientConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, c.server.pushMessage(data, c.clientConnIDSnapshot()) } +func (c *ClientConn) pushServerOwnedTransportPayload(payload []byte, release func(), conn net.Conn, generation uint64) { + if logical := c.LogicalConn(); logical != nil { + logical.pushServerOwnedTransportPayload(payload, release, conn, generation) + return + } + if c == nil || c.server == nil || len(payload) == 0 { + if release != nil { + release() + } + return + } + if pusher, ok := c.server.(serverInboundSourceFastPusher); ok { + pusher.pushTransportPayloadSourceFast(payload, release, newServerInboundSource(logicalConnFromClient(c), conn, nil, generation)) + return + } + if release != nil { + release() + } +} + func (c *ClientConn) shouldCloseClientConnTransportOnStop(conn net.Conn) bool { if logical := c.LogicalConn(); logical != nil { return logical.shouldCloseTransportOnStop(conn) diff --git a/client_runtime.go b/client_runtime.go index 355a3d0..cdd3a13 100644 --- a/client_runtime.go +++ b/client_runtime.go @@ -433,6 +433,21 @@ func (c *ClientCommon) readMessageLoop(stopCtx context.Context, conn net.Conn, q } binding := newTransportBinding(conn, queue) dispatcher := c.clientInboundDispatcherSnapshot() + if conn != nil && queue != nil && !isPacketTransportConn(conn) { + reader := newTransportFrameReader(conn, queue) + for { + select { + case <-stopCtx.Done(): + c.closeClientTransportBinding(binding) + return + default: + } + payload, release, err := c.readTransportPayloadPooled(conn, reader) + if 
!c.handleTransportPayloadReadResultWithSession(stopCtx, binding, payload, release, err, epoch, dispatcher) { + return + } + } + } buf := streamReadBuffer() for { select { diff --git a/client_session_runtime_test.go b/client_session_runtime_test.go index ea2440d..9933a0b 100644 --- a/client_session_runtime_test.go +++ b/client_session_runtime_test.go @@ -320,7 +320,7 @@ func TestSetClientSessionRuntimeStopsOldBindingWorkersOnReattach(t *testing.T) { defer oldLeft.Close() defer oldRight.Close() oldBinding := newTransportBinding(oldLeft, queue) - oldSender := oldBinding.bulkBatchSenderSnapshot() + oldSender := oldBinding.clientBulkBatchSenderSnapshot(client) client.setClientSessionRuntime(&clientSessionRuntime{ transport: oldBinding, @@ -345,7 +345,7 @@ func TestSetClientSessionRuntimeStopsOldBindingWorkersOnReattach(t *testing.T) { epoch: 2, }) - err := oldSender.submit(context.Background(), []byte("payload")) + err := oldSender.submitData(context.Background(), 1, 1, bulkFastPathVersionV1, []byte("payload")) if err != errTransportDetached { t.Fatalf("old sender submit after reattach = %v, want %v", err, errTransportDetached) } diff --git a/client_stream.go b/client_stream.go index 567ea34..906b942 100644 --- a/client_stream.go +++ b/client_stream.go @@ -35,6 +35,11 @@ func (c *ClientCommon) OpenStream(ctx context.Context, opt StreamOpenOptions) (S if resp.DataID != 0 { req.DataID = resp.DataID } + if resp.FastPathVersion != 0 { + req.FastPathVersion = resp.FastPathVersion + } else { + req.FastPathVersion = streamFastPathVersionV1 + } req.Metadata = mergeStreamMetadata(req.Metadata, resp.Metadata) stream := newStreamHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientStreamCloseSender(c), clientStreamResetSender(c), clientStreamDataSender(c, c.currentClientSessionEpoch()), runtime.configSnapshot()) stream.setClientSnapshotOwner(c) @@ -66,11 +71,12 @@ func 
clientStreamRequest(runtime *streamRuntime, opt StreamOpenOptions) StreamOp id = runtime.nextID() } return normalizeStreamOpenRequest(StreamOpenRequest{ - StreamID: id, - Channel: opt.Channel, - Metadata: cloneStreamMetadata(opt.Metadata), - ReadTimeout: opt.ReadTimeout, - WriteTimeout: opt.WriteTimeout, + StreamID: id, + FastPathVersion: streamFastPathVersionCurrent, + Channel: opt.Channel, + Metadata: cloneStreamMetadata(opt.Metadata), + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, }) } @@ -110,7 +116,7 @@ func clientStreamDataSender(c *ClientCommon, epoch uint64) streamDataSender { } } if dataID := stream.dataIDSnapshot(); dataID != 0 { - return c.sendFastStreamData(dataID, stream.nextOutboundDataSeq(), chunk) + return c.sendFastStreamData(ctx, stream, chunk) } return c.sendEnvelope(newStreamDataEnvelope(stream.ID(), chunk)) } diff --git a/client_transport.go b/client_transport.go index 729f4ee..35ab280 100644 --- a/client_transport.go +++ b/client_transport.go @@ -130,24 +130,85 @@ func (c *ClientCommon) handleTransportReadResultWithSessionDispatcher(stopCtx co return true } +func (c *ClientCommon) readTransportPayloadPooled(conn net.Conn, reader *stario.FrameReader) ([]byte, func(), error) { + if reader == nil { + return nil, nil, net.ErrClosed + } + if conn == nil { + return nil, nil, net.ErrClosed + } + if c.maxReadTimeout.Seconds() != 0 { + _ = conn.SetReadDeadline(time.Now().Add(c.maxReadTimeout)) + } + return reader.NextPooled() +} + +func (c *ClientCommon) handleTransportPayloadReadResultWithSession(stopCtx context.Context, binding *transportBinding, payload []byte, release func(), err error, epoch uint64, dispatcher *inboundDispatcher) bool { + if err == os.ErrDeadlineExceeded { + return true + } + if err != nil { + if release != nil { + release() + } + if c.showError || c.debugMode { + fmt.Println("client read error", err) + } + select { + case <-sessionStopChan(stopCtx): + c.closeClientTransportBinding(binding) + return false + 
default: + } + c.stopClientSessionIfCurrent(epoch, "client read error", err) + return false + } + c.dispatchTransportPayloadFast(payload, release, dispatcher) + return true +} + +func (c *ClientCommon) dispatchTransportPayloadFast(payload []byte, release func(), dispatcher *inboundDispatcher) { + if len(payload) == 0 { + if release != nil { + release() + } + return + } + if c.tryDispatchBorrowedBulkTransportPayload(payload) { + if release != nil { + release() + } + return + } + owned := append([]byte(nil), payload...) + if release != nil { + release() + } + if dispatcher == nil { + now := time.Now() + if err := c.dispatchInboundTransportPayload(owned, now); err != nil && (c.showError || c.debugMode) { + fmt.Println("client decode envelope error", err) + } + return + } + c.wg.Add(1) + if !dispatcher.Dispatch(clientInboundDispatchSource(), func() { + defer c.wg.Done() + now := time.Now() + if err := c.dispatchInboundTransportPayload(owned, now); err != nil && (c.showError || c.debugMode) { + fmt.Println("client decode envelope error", err) + } + }) { + c.wg.Done() + } +} + func (c *ClientCommon) pushMessageFast(queue *stario.StarQueue, data []byte, dispatcher *inboundDispatcher) bool { if queue == nil || dispatcher == nil || len(data) == 0 { return false } - if err := queue.ParseMessageOwned(data, "b612", func(msg stario.MsgQueue) error { - payload := msg.Msg - c.wg.Add(1) - if !dispatcher.Dispatch(clientInboundDispatchSource(), func() { - defer c.wg.Done() - now := time.Now() - if err := c.dispatchInboundTransportPayload(payload, now); err != nil { - if c.showError || c.debugMode { - fmt.Println("client decode envelope error", err) - } - } - }) { - c.wg.Done() - } + if err := queue.ParseMessageView(data, "b612", func(frame stario.FrameView) error { + c.dispatchTransportPayloadFast(frame.Payload, nil, dispatcher) return nil }); err != nil && (c.showError || c.debugMode) { fmt.Println("client parse inbound frame error", err) diff --git a/clienttype.go b/clienttype.go 
index bd84f93..bace9a6 100644 --- a/clienttype.go +++ b/clienttype.go @@ -18,6 +18,14 @@ type Client interface { SetStreamConfig(StreamConfig) SetTransferResumeStore(TransferResumeStore) RecoverTransferSnapshots(context.Context) error + SetBulkNetworkProfile(BulkNetworkProfile) + BulkNetworkProfile() BulkNetworkProfile + SetBulkDefaultOpenMode(BulkOpenMode) + BulkDefaultOpenMode() BulkOpenMode + SetBulkOpenTuning(BulkOpenTuning) + BulkOpenTuning() BulkOpenTuning + SetBulkDedicatedAttachConfig(BulkDedicatedAttachConfig) + BulkDedicatedAttachConfig() BulkDedicatedAttachConfig SetFileReceiveDir(dir string) error send(msg TransferMsg) (WaitMsg, error) sendEnvelope(env Envelope) error @@ -77,6 +85,8 @@ type Client interface { OpenStream(ctx context.Context, opt StreamOpenOptions) (Stream, error) OpenRecordStream(ctx context.Context, opt RecordOpenOptions) (RecordStream, error) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) + OpenSharedBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) + OpenDedicatedBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) SendTransfer(ctx context.Context, opt TransferSendOptions) (TransferHandle, error) SendFile(ctx context.Context, filePath string) error } diff --git a/default.go b/default.go index ffa557d..3a32212 100644 --- a/default.go +++ b/default.go @@ -126,6 +126,8 @@ func init() { RegisterName("b612.me/notify.BulkCloseResponse", BulkCloseResponse{}) RegisterName("b612.me/notify.BulkResetRequest", BulkResetRequest{}) RegisterName("b612.me/notify.BulkResetResponse", BulkResetResponse{}) + RegisterName("b612.me/notify.BulkReadyRequest", BulkReadyRequest{}) + RegisterName("b612.me/notify.BulkReadyResponse", BulkReadyResponse{}) RegisterName("b612.me/notify.BulkReleaseRequest", BulkReleaseRequest{}) RegisterName("b612.me/notify.bulkAttachRequest", bulkAttachRequest{}) RegisterName("b612.me/notify.bulkAttachResponse", bulkAttachResponse{}) diff --git a/diagnostics_snapshot.go b/diagnostics_snapshot.go 
index bc177ea..a2789cd 100644 --- a/diagnostics_snapshot.go +++ b/diagnostics_snapshot.go @@ -59,6 +59,11 @@ type DiagnosticsSummary struct { LogicalCount int CurrentTransportCount int + BulkAttachAttempts int64 + BulkAttachRetries int64 + BulkAttachSuccess int64 + BulkAutoFallbacks int64 + StreamCount int ActiveStreamCount int StaleStreamCount int @@ -236,7 +241,11 @@ func serverCurrentTransportRuntimeSnapshots(s Server) ([]TransportConnRuntimeSna func summarizeClientDiagnosticsSnapshot(snapshot ClientDiagnosticsSnapshot) DiagnosticsSummary { summary := DiagnosticsSummary{ - LogicalCount: diagnosticsLogicalCountFromClientRuntime(snapshot.Runtime), + LogicalCount: diagnosticsLogicalCountFromClientRuntime(snapshot.Runtime), + BulkAttachAttempts: snapshot.Runtime.BulkAttachAttempts, + BulkAttachRetries: snapshot.Runtime.BulkAttachRetries, + BulkAttachSuccess: snapshot.Runtime.BulkAttachSuccess, + BulkAutoFallbacks: snapshot.Runtime.BulkAutoFallbacks, } if snapshot.Runtime.TransportAttached { summary.CurrentTransportCount = 1 diff --git a/go.mod b/go.mod index 8d5cbc0..6f1b171 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.24.0 require ( b612.me/starcrypto v1.0.2 - b612.me/stario v0.1.0 + b612.me/stario v0.1.1 github.com/Microsoft/go-winio v0.6.2 ) diff --git a/go.sum b/go.sum index 3f3df03..4afe4fd 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ b612.me/starcrypto v1.0.2 h1:6f8YHNMHZPwxDSRxY2OJeMP4ExKa/cakLIO04f0gLhE= b612.me/starcrypto v1.0.2/go.mod h1:I7oYTmQgnVPj5S5yKwoTyqkItq1HgF9XdJT/v3qs5QE= -b612.me/stario v0.1.0 h1:V1uA7fLYzgTadOXpnyPaFC3z0MAKFIM/RKXzZUDXvL4= -b612.me/stario v0.1.0/go.mod h1:7kjE69oFqNrca0P72L5+ZbTV09QGJ2N3bBY3qeFXOGc= +b612.me/stario v0.1.1 h1:WIQy5DdK2Tkk+PIRORaVb76f4KY+64UvWChXNI7hSVY= +b612.me/stario v0.1.1/go.mod h1:qMMqjaMhKmdfFn5T0oleL5L4FpFWEHsuIrT878hyo7I= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod 
h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/logical_conn.go b/logical_conn.go index 0c4ca00..00cc98e 100644 --- a/logical_conn.go +++ b/logical_conn.go @@ -409,6 +409,7 @@ func (c *LogicalConn) applyAttachmentProfile(maxReadTimeout time.Duration, maxWr state.fastStreamEncode = fastStreamEncode state.fastBulkEncode = fastBulkEncode state.fastPlainEncode = fastPlainEncode + state.modernPSKRuntime = nil state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey) state.secretKey = cloneClientConnAttachmentBytes(secretKey) }) @@ -553,6 +554,12 @@ func (c *LogicalConn) runtimeSnapshot() ClientConnRuntimeSnapshot { snapshot.HasRuntimeConn = c.transportSnapshot() != nil snapshot.HasRuntimeStopCtx = rt.stopCtx != nil } + if binding := c.transportBindingSnapshot(); binding != nil { + snapshot.TransportBulkAdaptiveSoftPayloadBytes = binding.bulkAdaptiveSoftPayloadBytesSnapshot() + snapshot.TransportStreamAdaptiveSoftPayloadBytes = binding.streamAdaptiveSoftPayloadBytesSnapshot() + snapshot.TransportStreamAdaptiveWaitThresholdBytes = binding.streamAdaptiveWaitThresholdBytesSnapshot() + snapshot.TransportStreamAdaptiveFlushDelay = binding.streamAdaptiveFlushDelaySnapshot() + } if detach := c.transportDetachSnapshot(); detach != nil { snapshot.TransportDetachReason = detach.Reason snapshot.TransportDetachKind = classifyClientConnTransportDetachReason(detach.Reason) @@ -816,6 +823,7 @@ func (c *LogicalConn) applyClientConnAttachmentProfile(maxReadTimeout time.Durat state.maxWriteTimeout = maxWriteTimeout state.msgEn = msgEn state.msgDe = msgDe + state.modernPSKRuntime = nil state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey) state.secretKey = cloneClientConnAttachmentBytes(secretKey) }) diff --git a/raw_tcp_benchmark_test.go b/raw_tcp_benchmark_test.go index 96a5b84..74e745d 100644 --- a/raw_tcp_benchmark_test.go +++ b/raw_tcp_benchmark_test.go @@ 
-41,7 +41,7 @@ func BenchmarkRawTCPLocalhostThroughput(b *testing.B) { func benchmarkRawTCPLocalhostThroughput(b *testing.B, payloadSize int) { b.Helper() - listener, err := net.Listen("tcp", "127.0.0.1:0") + listener, err := net.Listen("tcp", benchmarkTCPListenAddr(b)) if err != nil { b.Fatalf("net.Listen failed: %v", err) } @@ -60,7 +60,7 @@ func benchmarkRawTCPLocalhostThroughput(b *testing.B, payloadSize int) { acceptCh <- conn }() - clientConn, err := net.Dial("tcp", listener.Addr().String()) + clientConn, err := net.Dial("tcp", benchmarkTCPDialAddr(b, listener.Addr().String())) if err != nil { b.Fatalf("net.Dial failed: %v", err) } diff --git a/release_p0_test.go b/release_p0_test.go index 0c7b0d4..8b5eb0f 100644 --- a/release_p0_test.go +++ b/release_p0_test.go @@ -5,6 +5,7 @@ import ( "errors" "net" "strings" + "sync/atomic" "testing" "time" ) @@ -14,6 +15,29 @@ type releaseP0TestAddr string func (a releaseP0TestAddr) Network() string { return "tcp" } func (a releaseP0TestAddr) String() string { return string(a) } +type closeInspectConn struct { + closeFn func() + closed atomic.Bool +} + +func (c *closeInspectConn) Read([]byte) (int, error) { return 0, net.ErrClosed } +func (c *closeInspectConn) Write(p []byte) (int, error) { return len(p), nil } +func (c *closeInspectConn) LocalAddr() net.Addr { return releaseP0TestAddr("local") } +func (c *closeInspectConn) RemoteAddr() net.Addr { return releaseP0TestAddr("remote") } +func (c *closeInspectConn) SetDeadline(time.Time) error { return nil } +func (c *closeInspectConn) SetReadDeadline(time.Time) error { return nil } +func (c *closeInspectConn) SetWriteDeadline(time.Time) error { return nil } + +func (c *closeInspectConn) Close() error { + if c == nil { + return nil + } + if c.closed.CompareAndSwap(false, true) && c.closeFn != nil { + c.closeFn() + } + return nil +} + func TestGetLogicalConnRuntimeSnapshotWithoutCompatClient(t *testing.T) { server := NewServer().(*ServerCommon) logical := &LogicalConn{server: 
server} @@ -101,6 +125,116 @@ func TestHandleDedicatedBulkReadErrorPreservesUnderlyingCause(t *testing.T) { } } +func TestHandleClientDedicatedSidecarFailureMarksBulkBeforeClosingConn(t *testing.T) { + client := NewClient().(*ClientCommon) + runtime := client.getBulkRuntime() + if runtime == nil { + t.Fatal("client bulk runtime should not be nil") + } + bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{ + BulkID: "sidecar-order", + DataID: 7, + Dedicated: true, + Range: BulkRange{ + Length: 1, + }, + }, 0, nil, nil, 0, nil, nil, nil, nil, nil) + var closeObservedErr error + conn := &closeInspectConn{ + closeFn: func() { + closeObservedErr = bulk.resetErrSnapshot() + }, + } + if err := bulk.attachDedicatedConnShared(conn); err != nil { + t.Fatalf("attachDedicatedConnShared failed: %v", err) + } + if err := runtime.register(clientFileScope(), bulk); err != nil { + t.Fatalf("register bulk failed: %v", err) + } + sidecar := newBulkDedicatedSidecar(conn, 1) + client.installClientDedicatedSidecar(1, sidecar) + + client.handleClientDedicatedSidecarFailure(sidecar, errors.New("boom sidecar")) + + if !errors.Is(closeObservedErr, errTransportDetached) { + t.Fatalf("closeObservedErr = %v, want transport detached", closeObservedErr) + } + if !strings.Contains(closeObservedErr.Error(), "dedicated bulk read error") || !strings.Contains(closeObservedErr.Error(), "boom sidecar") { + t.Fatalf("closeObservedErr detail = %q, want dedicated read error and cause", closeObservedErr.Error()) + } +} + +func TestCleanupClientSessionResourcesMarksBulkBeforeClosingSidecar(t *testing.T) { + client := NewClient().(*ClientCommon) + runtime := client.getBulkRuntime() + if runtime == nil { + t.Fatal("client bulk runtime should not be nil") + } + bulk := newBulkHandle(context.Background(), runtime, clientFileScope(), BulkOpenRequest{ + BulkID: "cleanup-order", + DataID: 9, + Dedicated: true, + Range: BulkRange{ + Length: 1, + }, + }, 0, nil, nil, 0, nil, nil, 
nil, nil, nil) + var closeObservedErr error + conn := &closeInspectConn{ + closeFn: func() { + closeObservedErr = bulk.resetErrSnapshot() + }, + } + if err := bulk.attachDedicatedConnShared(conn); err != nil { + t.Fatalf("attachDedicatedConnShared failed: %v", err) + } + if err := runtime.register(clientFileScope(), bulk); err != nil { + t.Fatalf("register bulk failed: %v", err) + } + client.installClientDedicatedSidecar(1, newBulkDedicatedSidecar(conn, 1)) + + client.cleanupClientSessionResources() + + if !errors.Is(closeObservedErr, errServiceShutdown) { + t.Fatalf("closeObservedErr = %v, want %v", closeObservedErr, errServiceShutdown) + } +} + +func TestBestEffortRejectInboundDedicatedDataUsesDedicatedResetRecord(t *testing.T) { + server := NewServer().(*ServerCommon) + UseLegacySecurityServer(server) + + logical := server.bootstrapAcceptedLogical("dedicated-reject", nil, nil) + if logical == nil { + t.Fatal("bootstrapAcceptedLogical should return logical") + } + conn := newBulkAttachScriptConn(nil) + + server.bestEffortRejectInboundDedicatedData(logical, conn, 42, "unknown data id") + + recordConn := newBulkAttachScriptConn(conn.writtenBytes()) + payload, err := readBulkDedicatedRecord(recordConn) + if err != nil { + t.Fatalf("readBulkDedicatedRecord failed: %v", err) + } + plain, err := server.decryptTransportPayloadLogical(logical, payload) + if err != nil { + t.Fatalf("decryptTransportPayloadLogical failed: %v", err) + } + items, err := decodeDedicatedBulkInboundItems(42, plain) + if err != nil { + t.Fatalf("decodeDedicatedBulkInboundItems failed: %v", err) + } + if len(items) != 1 { + t.Fatalf("decoded items = %d, want 1", len(items)) + } + if items[0].Type != bulkFastPayloadTypeReset { + t.Fatalf("reset item type = %d, want %d", items[0].Type, bulkFastPayloadTypeReset) + } + if got, want := string(items[0].Payload), "unknown data id"; got != want { + t.Fatalf("reset payload = %q, want %q", got, want) + } +} + func 
TestRegisterAcceptedLogicalWithoutCompatClient(t *testing.T) { server := NewServer().(*ServerCommon) UseLegacySecurityServer(server) diff --git a/security_psk.go b/security_psk.go index a90d189..0c0c1a3 100644 --- a/security_psk.go +++ b/security_psk.go @@ -40,6 +40,8 @@ type modernPSKTransportBundle struct { fastPlainEncode transportFastPlainEncoder } +var modernPSKPayloadPool sync.Pool + // ModernPSKOptions configures the modern PSK transport profile. // // The current profile derives a 32-byte transport key with Argon2id and uses @@ -81,6 +83,10 @@ func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) e return err } transport := buildModernPSKTransportBundle(aad) + runtime, err := newModernPSKCodecRuntime(key, aad) + if err != nil { + return err + } c.SetSecretKey(key) c.SetMsgEn(transport.msgEn) c.SetMsgDe(transport.msgDe) @@ -88,6 +94,7 @@ func UseModernPSKClient(c Client, sharedSecret []byte, opts *ModernPSKOptions) e client.fastStreamEncode = transport.fastStreamEncode client.fastBulkEncode = transport.fastBulkEncode client.fastPlainEncode = transport.fastPlainEncode + client.modernPSKRuntime = runtime } c.SetSkipExchangeKey(true) return nil @@ -104,6 +111,10 @@ func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) e return err } transport := buildModernPSKTransportBundle(aad) + runtime, err := newModernPSKCodecRuntime(key, aad) + if err != nil { + return err + } s.SetSecretKey(key) s.SetDefaultCommEncode(transport.msgEn) s.SetDefaultCommDecode(transport.msgDe) @@ -111,6 +122,7 @@ func UseModernPSKServer(s Server, sharedSecret []byte, opts *ModernPSKOptions) e server.defaultFastStreamEncode = transport.fastStreamEncode server.defaultFastBulkEncode = transport.fastBulkEncode server.defaultFastPlainEncode = transport.fastPlainEncode + server.defaultModernPSKRuntime = runtime } return nil } @@ -127,6 +139,7 @@ func UseLegacySecurityClient(c Client) { client.fastStreamEncode = nil client.fastBulkEncode = nil 
client.fastPlainEncode = nil + client.modernPSKRuntime = nil } c.SetSkipExchangeKey(false) c.SetRsaPubKey(bytes.Clone(defaultRsaPubKey)) @@ -144,6 +157,7 @@ func UseLegacySecurityServer(s Server) { server.defaultFastStreamEncode = nil server.defaultFastBulkEncode = nil server.defaultFastPlainEncode = nil + server.defaultModernPSKRuntime = nil } s.SetRsaPrivKey(bytes.Clone(defaultRsaKey)) } @@ -185,14 +199,14 @@ func buildModernPSKCodecs(aad []byte) (func([]byte, []byte) []byte, func([]byte, func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle { aadCopy := bytes.Clone(aad) - cache := &modernPSKCodecCache{} + cache := newModernPSKCodecCache(aadCopy) msgEn := func(key []byte, plain []byte) []byte { runtime, err := cache.runtimeForKey(key) if err != nil { log.Print(err) return nil } - out, err := runtime.sealPlainPayload(aadCopy, plain) + out, err := runtime.sealPlainPayload(plain) if err != nil { log.Print(err) return nil @@ -214,9 +228,7 @@ func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle { log.Print(err) return nil } - nonce := encrypted[len(modernPSKMagic):headerLen] - ciphertext := encrypted[headerLen:] - plain, err := runtime.aead.Open(make([]byte, 0, len(ciphertext)), nonce, ciphertext, aadCopy) + plain, err := runtime.openPayload(encrypted) if err != nil { log.Print(err) return nil @@ -228,21 +240,21 @@ func buildModernPSKTransportBundle(aad []byte) modernPSKTransportBundle { if err != nil { return nil, err } - return runtime.sealStreamFastPayload(aadCopy, dataID, seq, payload) + return runtime.sealStreamFastPayload(dataID, seq, payload) } fastBulkEncode := func(key []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { runtime, err := cache.runtimeForKey(key) if err != nil { return nil, err } - return runtime.sealBulkFastPayload(aadCopy, dataID, seq, payload) + return runtime.sealBulkFastPayload(dataID, seq, payload) } fastPlainEncode := func(key []byte, plainLen int, fill func([]byte) error) ([]byte, 
error) { runtime, err := cache.runtimeForKey(key) if err != nil { return nil, err } - return runtime.sealFilledPayload(aadCopy, plainLen, fill) + return runtime.sealFilledPayload(plainLen, fill) } return modernPSKTransportBundle{ msgEn: msgEn, @@ -269,16 +281,23 @@ func (s *ServerCommon) validateSecurityConfiguration() error { type modernPSKCodecCache struct { mu sync.Mutex + aad []byte key []byte runtime *modernPSKCodecRuntime } type modernPSKCodecRuntime struct { aead cipher.AEAD + key []byte + aad []byte prefix [modernPSKNonceSize - 8]byte seq atomic.Uint64 } +func newModernPSKCodecCache(aad []byte) *modernPSKCodecCache { + return &modernPSKCodecCache{aad: bytes.Clone(aad)} +} + func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime, error) { if c == nil { return nil, errModernPSKSecretEmpty @@ -288,7 +307,7 @@ func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime, if c.runtime != nil && bytes.Equal(c.key, key) { return c.runtime, nil } - runtime, err := newModernPSKCodecRuntime(key) + runtime, err := newModernPSKCodecRuntime(key, c.aad) if err != nil { return nil, err } @@ -297,7 +316,7 @@ func (c *modernPSKCodecCache) runtimeForKey(key []byte) (*modernPSKCodecRuntime, return runtime, nil } -func newModernPSKCodecRuntime(key []byte) (*modernPSKCodecRuntime, error) { +func newModernPSKCodecRuntime(key []byte, aad []byte) (*modernPSKCodecRuntime, error) { if len(key) == 0 { return nil, errModernPSKSecretEmpty } @@ -311,6 +330,8 @@ func newModernPSKCodecRuntime(key []byte) (*modernPSKCodecRuntime, error) { } runtime := &modernPSKCodecRuntime{ aead: aead, + key: bytes.Clone(key), + aad: bytes.Clone(aad), } if _, err := cryptorand.Read(runtime.prefix[:]); err != nil { return nil, err @@ -318,6 +339,13 @@ func newModernPSKCodecRuntime(key []byte) (*modernPSKCodecRuntime, error) { return runtime, nil } +func (r *modernPSKCodecRuntime) fork() (*modernPSKCodecRuntime, error) { + if r == nil { + return nil, 
errModernPSKSecretEmpty + } + return newModernPSKCodecRuntime(r.key, r.aad) +} + func (r *modernPSKCodecRuntime) nextNonce() [modernPSKNonceSize]byte { var nonce [modernPSKNonceSize]byte if r == nil { @@ -328,8 +356,8 @@ func (r *modernPSKCodecRuntime) nextNonce() [modernPSKNonceSize]byte { return nonce } -func (r *modernPSKCodecRuntime) sealStreamFastPayload(aad []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { - return r.sealFilledPayload(aad, streamFastPayloadHeaderLen+len(payload), func(frame []byte) error { +func (r *modernPSKCodecRuntime) sealStreamFastPayload(dataID uint64, seq uint64, payload []byte) ([]byte, error) { + return r.sealFilledPayload(streamFastPayloadHeaderLen+len(payload), func(frame []byte) error { if err := encodeStreamFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil { return err } @@ -338,11 +366,11 @@ func (r *modernPSKCodecRuntime) sealStreamFastPayload(aad []byte, dataID uint64, }) } -func (r *modernPSKCodecRuntime) sealBulkFastPayload(aad []byte, dataID uint64, seq uint64, payload []byte) ([]byte, error) { +func (r *modernPSKCodecRuntime) sealBulkFastPayload(dataID uint64, seq uint64, payload []byte) ([]byte, error) { if r == nil { return nil, errTransportPayloadEncryptFailed } - return r.sealFilledPayload(aad, bulkFastPayloadHeaderLen+len(payload), func(frame []byte) error { + return r.sealFilledPayload(bulkFastPayloadHeaderLen+len(payload), func(frame []byte) error { if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil { return err } @@ -351,14 +379,14 @@ func (r *modernPSKCodecRuntime) sealBulkFastPayload(aad []byte, dataID uint64, s }) } -func (r *modernPSKCodecRuntime) sealPlainPayload(aad []byte, plain []byte) ([]byte, error) { - return r.sealFilledPayload(aad, len(plain), func(dst []byte) error { +func (r *modernPSKCodecRuntime) sealPlainPayload(plain []byte) ([]byte, error) { + return r.sealFilledPayload(len(plain), func(dst []byte) error { copy(dst, plain) 
return nil }) } -func (r *modernPSKCodecRuntime) sealFilledPayload(aad []byte, plainLen int, fill func([]byte) error) ([]byte, error) { +func (r *modernPSKCodecRuntime) sealFilledPayload(plainLen int, fill func([]byte) error) ([]byte, error) { if r == nil { return nil, errTransportPayloadEncryptFailed } @@ -368,6 +396,35 @@ func (r *modernPSKCodecRuntime) sealFilledPayload(aad []byte, plainLen int, fill nonce := r.nextNonce() headerLen := len(modernPSKMagic) + modernPSKNonceSize out := make([]byte, headerLen+plainLen+r.aead.Overhead()) + sealed, err := r.sealInto(out, headerLen, nonce, plainLen, fill) + if err != nil { + return nil, err + } + return out[:headerLen+len(sealed)], nil +} + +func (r *modernPSKCodecRuntime) sealFilledPayloadPooled(plainLen int, fill func([]byte) error) ([]byte, func(), error) { + if r == nil { + return nil, nil, errTransportPayloadEncryptFailed + } + if plainLen < 0 { + return nil, nil, errTransportPayloadEncryptFailed + } + nonce := r.nextNonce() + headerLen := len(modernPSKMagic) + modernPSKNonceSize + totalLen := headerLen + plainLen + r.aead.Overhead() + out := getModernPSKPayloadBuffer(totalLen) + sealed, err := r.sealInto(out, headerLen, nonce, plainLen, fill) + if err != nil { + putModernPSKPayloadBuffer(out) + return nil, nil, err + } + return out[:headerLen+len(sealed)], func() { + putModernPSKPayloadBuffer(out) + }, nil +} + +func (r *modernPSKCodecRuntime) sealInto(out []byte, headerLen int, nonce [modernPSKNonceSize]byte, plainLen int, fill func([]byte) error) ([]byte, error) { copy(out[:len(modernPSKMagic)], modernPSKMagic) copy(out[len(modernPSKMagic):headerLen], nonce[:]) frame := out[headerLen : headerLen+plainLen] @@ -376,6 +433,98 @@ func (r *modernPSKCodecRuntime) sealFilledPayload(aad []byte, plainLen int, fill return nil, err } } - sealed := r.aead.Seal(frame[:0], nonce[:], frame, aad) - return out[:headerLen+len(sealed)], nil + return r.aead.Seal(frame[:0], nonce[:], frame, r.aad), nil +} + +func (r 
*modernPSKCodecRuntime) openPayload(encrypted []byte) ([]byte, error) { + if r == nil { + return nil, errTransportPayloadDecryptFailed + } + headerLen := len(modernPSKMagic) + modernPSKNonceSize + if len(encrypted) < headerLen { + return nil, errModernPSKPayload + } + if !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) { + return nil, errModernPSKPayload + } + nonce := encrypted[len(modernPSKMagic):headerLen] + ciphertext := encrypted[headerLen:] + return r.aead.Open(make([]byte, 0, len(ciphertext)), nonce, ciphertext, r.aad) +} + +func (r *modernPSKCodecRuntime) openPayloadPooled(encrypted []byte, release func()) ([]byte, func(), error) { + if r == nil { + if release != nil { + release() + } + return nil, nil, errTransportPayloadDecryptFailed + } + headerLen := len(modernPSKMagic) + modernPSKNonceSize + if len(encrypted) < headerLen { + if release != nil { + release() + } + return nil, nil, errModernPSKPayload + } + if !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) { + if release != nil { + release() + } + return nil, nil, errModernPSKPayload + } + nonce := encrypted[len(modernPSKMagic):headerLen] + ciphertext := encrypted[headerLen:] + plain, err := r.aead.Open(ciphertext[:0], nonce, ciphertext, r.aad) + if err != nil { + if release != nil { + release() + } + return nil, nil, err + } + return plain, release, nil +} + +func (r *modernPSKCodecRuntime) openPayloadOwnedPooled(encrypted []byte) ([]byte, func(), error) { + if r == nil { + return nil, nil, errTransportPayloadDecryptFailed + } + headerLen := len(modernPSKMagic) + modernPSKNonceSize + if len(encrypted) < headerLen { + return nil, nil, errModernPSKPayload + } + if !bytes.Equal(encrypted[:len(modernPSKMagic)], modernPSKMagic) { + return nil, nil, errModernPSKPayload + } + nonce := encrypted[len(modernPSKMagic):headerLen] + ciphertext := encrypted[headerLen:] + plainLen := len(ciphertext) - r.aead.Overhead() + if plainLen < 0 { + return nil, nil, errModernPSKPayload + } + out := 
getModernPSKPayloadBuffer(plainLen) + plain, err := r.aead.Open(out[:0], nonce, ciphertext, r.aad) + if err != nil { + putModernPSKPayloadBuffer(out) + return nil, nil, err + } + return plain, func() { + putModernPSKPayloadBuffer(out) + }, nil +} + +func getModernPSKPayloadBuffer(size int) []byte { + if size <= 0 { + return nil + } + if pooled, ok := modernPSKPayloadPool.Get().([]byte); ok && cap(pooled) >= size { + return pooled[:size] + } + return make([]byte, size) +} + +func putModernPSKPayloadBuffer(buf []byte) { + if cap(buf) == 0 || cap(buf) > 32*1024*1024 { + return + } + modernPSKPayloadPool.Put(buf[:0]) } diff --git a/security_psk_test.go b/security_psk_test.go index 4fa7eee..ca6f67b 100644 --- a/security_psk_test.go +++ b/security_psk_test.go @@ -83,6 +83,12 @@ func TestDefaultConstructorsUseModernTransportAfterSetSecretKey(t *testing.T) { sharedKey := []byte("0123456789abcdef0123456789abcdef") client.SetSecretKey(sharedKey) server.SetSecretKey(sharedKey) + if client.modernPSKRuntime == nil { + t.Fatal("client modernPSKRuntime should be installed after SetSecretKey") + } + if server.defaultModernPSKRuntime == nil { + t.Fatal("server defaultModernPSKRuntime should be installed after SetSecretKey") + } plain := []byte("notify default modern transport") wire := client.msgEn(client.SecretKey, plain) @@ -92,6 +98,29 @@ func TestDefaultConstructorsUseModernTransportAfterSetSecretKey(t *testing.T) { } } +func TestCustomCodecOverridesClearModernRuntime(t *testing.T) { + client := NewClient().(*ClientCommon) + server := NewServer().(*ServerCommon) + sharedKey := []byte("0123456789abcdef0123456789abcdef") + client.SetSecretKey(sharedKey) + server.SetSecretKey(sharedKey) + if client.modernPSKRuntime == nil || server.defaultModernPSKRuntime == nil { + t.Fatal("modern runtimes should be installed before override") + } + + client.SetMsgEn(defaultMsgEn) + client.SetMsgDe(defaultMsgDe) + server.SetDefaultCommEncode(defaultMsgEn) + 
server.SetDefaultCommDecode(defaultMsgDe) + + if client.modernPSKRuntime != nil { + t.Fatal("client modernPSKRuntime should be cleared after custom codec override") + } + if server.defaultModernPSKRuntime != nil { + t.Fatal("server defaultModernPSKRuntime should be cleared after custom codec override") + } +} + func TestDefaultConstructorsDecodeSignalEnvelopeWithModernTransport(t *testing.T) { client := NewClient().(*ClientCommon) server := NewServer().(*ServerCommon) diff --git a/server.go b/server.go index 49b1f95..94c4f4f 100644 --- a/server.go +++ b/server.go @@ -33,6 +33,7 @@ type ServerCommon struct { defaultFastStreamEncode transportFastStreamEncoder defaultFastBulkEncode transportFastBulkEncoder defaultFastPlainEncode transportFastPlainEncoder + defaultModernPSKRuntime *modernPSKCodecRuntime linkFns map[string]func(message *Message) defaultFns func(message *Message) noFinSyncMsgMaxKeepSeconds int64 @@ -47,6 +48,9 @@ type ServerCommon struct { streamRuntime *streamRuntime recordRuntime *recordRuntime bulkRuntime *bulkRuntime + bulkOpenTuning BulkOpenTuning + bulkDedicatedSidecarMu sync.Mutex + bulkDedicatedSidecars map[*LogicalConn]map[uint32]*bulkDedicatedSidecar connectionRetryState *connectionRetryState detachedClientKeepSeconds int64 securityReadyCheck bool @@ -81,6 +85,8 @@ func NewServer() Server { server.streamRuntime = newStreamRuntime("sstrm") server.recordRuntime = newRecordRuntime() server.bulkRuntime = newBulkRuntime("sblk") + server.bulkOpenTuning = defaultBulkOpenTuning() + server.bulkDedicatedSidecars = make(map[*LogicalConn]map[uint32]*bulkDedicatedSidecar) server.connectionRetryState = newConnectionRetryState() server.onFileEvent = normalizeFileEventCallback(nil) server.fileEventObserver = normalizeFileEventCallback(nil) diff --git a/server_bulk.go b/server_bulk.go index 2634698..75854ec 100644 --- a/server_bulk.go +++ b/server_bulk.go @@ -1,6 +1,9 @@ package notify -import "context" +import ( + "context" + "errors" +) func (s *ServerCommon) 
SetBulkHandler(fn func(BulkAcceptInfo) error) { runtime := s.getBulkRuntime() @@ -11,9 +14,48 @@ func (s *ServerCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) { } func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn, opt BulkOpenOptions) (Bulk, error) { + opt = normalizeBulkOpenOptions(opt) + switch opt.Mode { + case BulkOpenModeDedicated: + opt.Dedicated = true + return s.openBulkLogicalWithMode(ctx, logical, opt) + case BulkOpenModeAuto: + if err := logicalDedicatedBulkSupportError(logical); err == nil { + dedicatedOpt := opt + dedicatedOpt.Mode = BulkOpenModeDedicated + dedicatedOpt.Dedicated = true + bulk, dedicatedErr := s.openBulkLogicalWithMode(ctx, logical, dedicatedOpt) + if dedicatedErr == nil { + return bulk, nil + } + sharedOpt := opt + sharedOpt.Mode = BulkOpenModeShared + sharedOpt.Dedicated = false + sharedBulk, sharedErr := s.openBulkLogicalWithMode(ctx, logical, sharedOpt) + if sharedErr == nil { + return sharedBulk, nil + } + return nil, errors.Join(dedicatedErr, sharedErr) + } + opt.Mode = BulkOpenModeShared + opt.Dedicated = false + return s.openBulkLogicalWithMode(ctx, logical, opt) + case BulkOpenModeShared, BulkOpenModeDefault: + opt.Mode = BulkOpenModeShared + opt.Dedicated = false + return s.openBulkLogicalWithMode(ctx, logical, opt) + default: + opt.Mode = BulkOpenModeShared + opt.Dedicated = false + return s.openBulkLogicalWithMode(ctx, logical, opt) + } +} + +func (s *ServerCommon) openBulkLogicalWithMode(ctx context.Context, logical *LogicalConn, opt BulkOpenOptions) (Bulk, error) { if s == nil { return nil, errBulkServerNil } + opt = applyBulkOpenTuningDefaults(opt, s.bulkOpenTuningSnapshot()) if logical == nil { return nil, errBulkLogicalConnNil } @@ -42,17 +84,44 @@ func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn req.AttachToken = newBulkAttachToken() } bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, 
logical.CurrentTransportConn(), logical.transportGenerationSnapshot(), serverBulkCloseSender(s, logical, nil), serverBulkResetSender(s, logical, nil), serverBulkDataSender(s, logical.CurrentTransportConn()), serverBulkWriteSender(s, logical, logical.CurrentTransportConn()), serverBulkReleaseSender(s, logical, logical.CurrentTransportConn())) + bulk.markAcceptHandled() if err := runtime.register(scope, bulk); err != nil { return nil, err } + s.attachServerDedicatedSidecarIfExists(logical, bulk) resp, err := sendBulkOpenServerLogical(ctx, s, logical, req) if err != nil { - runtime.remove(scope, bulk.ID()) + bulk.markReset(err) + return nil, err + } + if resp.DataID != 0 && resp.DataID != req.DataID { + err = errBulkAlreadyExists + _, _ = sendBulkResetServerLogical(context.Background(), s, logical, BulkResetRequest{ + BulkID: req.BulkID, + DataID: req.DataID, + Error: "bulk dedicated data id mismatch", + }) + bulk.markReset(err) return nil, err } if resp.TransportGeneration != 0 { bulk.transportGeneration = resp.TransportGeneration } + if resp.FastPathVersion != 0 { + bulk.fastPathVersion = normalizeBulkFastPathVersion(resp.FastPathVersion) + } + if resp.AttachToken != "" { + bulk.setDedicatedAttachToken(resp.AttachToken) + } + if err := bulk.waitAcceptReady(ctx); err != nil { + _, _ = sendBulkResetServerLogical(context.Background(), s, logical, BulkResetRequest{ + BulkID: req.BulkID, + DataID: req.DataID, + Error: err.Error(), + }) + bulk.markReset(err) + return nil, err + } return bulk, nil } resp, err := sendBulkOpenServerLogical(ctx, s, logical, req) @@ -62,6 +131,9 @@ func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn if resp.DataID != 0 { req.DataID = resp.DataID } + if resp.FastPathVersion != 0 { + req.FastPathVersion = resp.FastPathVersion + } req.Dedicated = resp.Dedicated if resp.AttachToken != "" { req.AttachToken = resp.AttachToken @@ -71,6 +143,7 @@ func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical 
*LogicalConn } transport := logical.CurrentTransportConn() bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverBulkCloseSender(s, logical, nil), serverBulkResetSender(s, logical, nil), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport)) + bulk.markAcceptHandled() if err := runtime.register(scope, bulk); err != nil { _, _ = sendBulkResetServerLogical(context.Background(), s, logical, BulkResetRequest{ BulkID: req.BulkID, @@ -79,13 +152,53 @@ func (s *ServerCommon) OpenBulkLogical(ctx context.Context, logical *LogicalConn }) return nil, err } + s.attachServerDedicatedSidecarIfExists(logical, bulk) return bulk, nil } func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *TransportConn, opt BulkOpenOptions) (Bulk, error) { + opt = normalizeBulkOpenOptions(opt) + switch opt.Mode { + case BulkOpenModeDedicated: + opt.Dedicated = true + return s.openBulkTransportWithMode(ctx, transport, opt) + case BulkOpenModeAuto: + if err := transportDedicatedBulkSupportError(transport); err == nil { + dedicatedOpt := opt + dedicatedOpt.Mode = BulkOpenModeDedicated + dedicatedOpt.Dedicated = true + bulk, dedicatedErr := s.openBulkTransportWithMode(ctx, transport, dedicatedOpt) + if dedicatedErr == nil { + return bulk, nil + } + sharedOpt := opt + sharedOpt.Mode = BulkOpenModeShared + sharedOpt.Dedicated = false + sharedBulk, sharedErr := s.openBulkTransportWithMode(ctx, transport, sharedOpt) + if sharedErr == nil { + return sharedBulk, nil + } + return nil, errors.Join(dedicatedErr, sharedErr) + } + opt.Mode = BulkOpenModeShared + opt.Dedicated = false + return s.openBulkTransportWithMode(ctx, transport, opt) + case BulkOpenModeShared, BulkOpenModeDefault: + opt.Mode = BulkOpenModeShared + opt.Dedicated = false + return s.openBulkTransportWithMode(ctx, transport, opt) + default: + opt.Mode = BulkOpenModeShared + 
opt.Dedicated = false + return s.openBulkTransportWithMode(ctx, transport, opt) + } +} + +func (s *ServerCommon) openBulkTransportWithMode(ctx context.Context, transport *TransportConn, opt BulkOpenOptions) (Bulk, error) { if s == nil { return nil, errBulkServerNil } + opt = applyBulkOpenTuningDefaults(opt, s.bulkOpenTuningSnapshot()) if transport == nil { return nil, errBulkTransportNil } @@ -118,17 +231,44 @@ func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *Transpo req.AttachToken = newBulkAttachToken() } bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, transport.TransportGeneration(), serverBulkCloseSender(s, logical, transport), serverBulkResetSender(s, logical, transport), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport)) + bulk.markAcceptHandled() if err := runtime.register(scope, bulk); err != nil { return nil, err } + s.attachServerDedicatedSidecarIfExists(logical, bulk) resp, err := sendBulkOpenServerTransport(ctx, s, transport, req) if err != nil { - runtime.remove(scope, bulk.ID()) + bulk.markReset(err) + return nil, err + } + if resp.DataID != 0 && resp.DataID != req.DataID { + err = errBulkAlreadyExists + _, _ = sendBulkResetServerTransport(context.Background(), s, transport, BulkResetRequest{ + BulkID: req.BulkID, + DataID: req.DataID, + Error: "bulk dedicated data id mismatch", + }) + bulk.markReset(err) return nil, err } if resp.TransportGeneration != 0 { bulk.transportGeneration = resp.TransportGeneration } + if resp.FastPathVersion != 0 { + bulk.fastPathVersion = normalizeBulkFastPathVersion(resp.FastPathVersion) + } + if resp.AttachToken != "" { + bulk.setDedicatedAttachToken(resp.AttachToken) + } + if err := bulk.waitAcceptReady(ctx); err != nil { + _, _ = sendBulkResetServerTransport(context.Background(), s, transport, BulkResetRequest{ + BulkID: req.BulkID, + DataID: req.DataID, + Error: 
err.Error(), + }) + bulk.markReset(err) + return nil, err + } return bulk, nil } resp, err := sendBulkOpenServerTransport(ctx, s, transport, req) @@ -138,6 +278,9 @@ func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *Transpo if resp.DataID != 0 { req.DataID = resp.DataID } + if resp.FastPathVersion != 0 { + req.FastPathVersion = resp.FastPathVersion + } req.Dedicated = resp.Dedicated if resp.AttachToken != "" { req.AttachToken = resp.AttachToken @@ -146,6 +289,7 @@ func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *Transpo return nil, errBulkDataIDEmpty } bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverBulkCloseSender(s, logical, transport), serverBulkResetSender(s, logical, transport), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport)) + bulk.markAcceptHandled() if err := runtime.register(scope, bulk); err != nil { _, _ = sendBulkResetServerTransport(context.Background(), s, transport, BulkResetRequest{ BulkID: req.BulkID, @@ -154,6 +298,7 @@ func (s *ServerCommon) OpenBulkTransport(ctx context.Context, transport *Transpo }) return nil, err } + s.attachServerDedicatedSidecarIfExists(logical, bulk) return bulk, nil } @@ -164,15 +309,16 @@ func serverBulkRequest(runtime *bulkRuntime, opt BulkOpenOptions) BulkOpenReques id = runtime.nextID() } return normalizeBulkOpenRequest(BulkOpenRequest{ - BulkID: id, - Range: opt.Range, - Metadata: cloneBulkMetadata(opt.Metadata), - ReadTimeout: opt.ReadTimeout, - WriteTimeout: opt.WriteTimeout, - Dedicated: opt.Dedicated, - ChunkSize: opt.ChunkSize, - WindowBytes: opt.WindowBytes, - MaxInFlight: opt.MaxInFlight, + BulkID: id, + FastPathVersion: bulkFastPathVersionCurrent, + Range: opt.Range, + Metadata: cloneBulkMetadata(opt.Metadata), + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, + Dedicated: opt.Dedicated, 
+ ChunkSize: opt.ChunkSize, + WindowBytes: opt.WindowBytes, + MaxInFlight: opt.MaxInFlight, }) } @@ -247,12 +393,12 @@ func serverBulkDataSender(s *ServerCommon, transport *TransportConn) bulkDataSen if dataID == 0 { return errBulkDataPathNotReady } - return s.sendFastBulkDataTransport(ctx, bulk.LogicalConn(), transport, dataID, bulk.nextOutboundDataSeq(), chunk) + return s.sendFastBulkDataTransport(ctx, bulk.LogicalConn(), transport, dataID, bulk.nextOutboundDataSeq(), chunk, bulk.fastPathVersionSnapshot()) } } func serverBulkWriteSender(s *ServerCommon, logical *LogicalConn, transport *TransportConn) bulkWriteSender { - return func(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) { + return func(ctx context.Context, bulk *bulkHandle, startSeq uint64, payload []byte, payloadOwned bool) (int, error) { if s == nil { return 0, errBulkServerNil } @@ -267,7 +413,7 @@ func serverBulkWriteSender(s *ServerCommon, logical *LogicalConn, transport *Tra if err := bulk.waitDedicatedReady(ctx); err != nil { return 0, err } - return s.sendDedicatedBulkWrite(ctx, bulk.LogicalConn(), bulk, payload) + return s.sendDedicatedBulkWrite(ctx, bulk.LogicalConn(), bulk, startSeq, payload) } if transport == nil { return 0, errBulkTransportNil @@ -275,7 +421,14 @@ func serverBulkWriteSender(s *ServerCommon, logical *LogicalConn, transport *Tra if !transport.IsCurrent() { return 0, errTransportDetached } - return 0, nil + if bulk == nil { + return 0, errBulkRuntimeNil + } + dataID := bulk.dataIDSnapshot() + if dataID == 0 { + return 0, errBulkDataPathNotReady + } + return s.sendFastBulkWriteTransport(ctx, bulk.LogicalConn(), transport, dataID, startSeq, bulk.chunkSize, bulk.fastPathVersionSnapshot(), payload, payloadOwned) } } @@ -287,8 +440,20 @@ func serverBulkReleaseSender(s *ServerCommon, logical *LogicalConn, transport *T if bytes <= 0 && chunks <= 0 { return nil } + ctx, cancel, err := bulk.newWriteContext(bulk.Context(), bulk.writeTimeout) + if err != nil { + return 
err + } + defer cancel() if bulk.Dedicated() { - return s.sendDedicatedBulkRelease(context.Background(), logical, bulk, bytes, chunks) + return s.sendDedicatedBulkRelease(ctx, logical, bulk, bytes, chunks) + } + if transport != nil && transport.IsCurrent() && bulk.fastPathVersionSnapshot() >= bulkFastPathVersionV2 { + payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks) + if err != nil { + return err + } + return s.sendFastBulkControlTransport(ctx, logical, transport, bulkFastPayloadTypeRelease, 0, bulk.dataIDSnapshot(), 0, bulk.fastPathVersionSnapshot(), payload) } req := BulkReleaseRequest{ BulkID: bulk.ID(), diff --git a/server_config.go b/server_config.go index c5c202f..f3f6866 100644 --- a/server_config.go +++ b/server_config.go @@ -35,6 +35,7 @@ func (s *ServerCommon) SetDefaultCommEncode(fn func([]byte, []byte) []byte) { s.defaultFastStreamEncode = nil s.defaultFastBulkEncode = nil s.defaultFastPlainEncode = nil + s.defaultModernPSKRuntime = nil s.securityReadyCheck = false } @@ -45,6 +46,7 @@ func (s *ServerCommon) SetDefaultCommDecode(fn func([]byte, []byte) []byte) { s.defaultFastStreamEncode = nil s.defaultFastBulkEncode = nil s.defaultFastPlainEncode = nil + s.defaultModernPSKRuntime = nil s.securityReadyCheck = false } @@ -97,6 +99,13 @@ func (s *ServerCommon) GetSecretKey() []byte { // Prefer UseModernPSKServer or UseLegacySecurityServer. 
func (s *ServerCommon) SetSecretKey(key []byte) { s.SecretKey = key + if len(key) == 0 { + s.defaultModernPSKRuntime = nil + } else if runtime, err := newModernPSKCodecRuntime(key, defaultModernPSKAAD); err == nil { + s.defaultModernPSKRuntime = runtime + } else { + s.defaultModernPSKRuntime = nil + } s.securityReadyCheck = len(key) == 0 } diff --git a/server_inbound_source.go b/server_inbound_source.go index 3f455ca..a205710 100644 --- a/server_inbound_source.go +++ b/server_inbound_source.go @@ -59,26 +59,8 @@ func (s *ServerCommon) pushMessageSourceFast(queue *stario.StarQueue, data []byt if queue == nil || dispatcher == nil || len(data) == 0 { return false } - if err := queue.ParseMessageOwned(data, source, func(msg stario.MsgQueue) error { - payload := msg.Msg - source := msg.Conn - s.wg.Add(1) - if !dispatcher.Dispatch(serverInboundDispatchSource(source), func() { - defer s.wg.Done() - logical, transport := s.resolveInboundSource(source) - if logical == nil { - return - } - now := time.Now() - inboundConn := serverInboundConn(source) - if err := s.dispatchInboundTransportPayload(logical, transport, inboundConn, payload, now); err != nil { - if s.showError || s.debugMode { - fmt.Println("server decode envelope error", err) - } - } - }) { - s.wg.Done() - } + if err := queue.ParseMessageView(data, source, func(frame stario.FrameView) error { + s.pushTransportPayloadSourceFast(frame.Payload, nil, frame.Conn) return nil }); err != nil && (s.showError || s.debugMode) { fmt.Println("server parse inbound frame error", err) @@ -86,6 +68,59 @@ func (s *ServerCommon) pushMessageSourceFast(queue *stario.StarQueue, data []byt return true } +func (s *ServerCommon) pushTransportPayloadSourceFast(payload []byte, release func(), source interface{}) bool { + dispatcher := s.serverInboundDispatcherSnapshot() + if len(payload) == 0 { + if release != nil { + release() + } + return false + } + if dispatcher == nil { + queue := s.serverQueueSnapshot() + if queue == nil { + if 
release != nil { + release() + } + return false + } + frame := queue.BuildMessage(payload) + if release != nil { + release() + } + if err := queue.ParseMessage(frame, source); err != nil && (s.showError || s.debugMode) { + fmt.Println("server enqueue inbound frame error", err) + } + return true + } + if s.tryDispatchBorrowedBulkTransportPayload(source, payload) { + if release != nil { + release() + } + return true + } + owned := append([]byte(nil), payload...) + if release != nil { + release() + } + s.wg.Add(1) + if !dispatcher.Dispatch(serverInboundDispatchSource(source), func() { + defer s.wg.Done() + logical, transport := s.resolveInboundSource(source) + if logical == nil { + return + } + now := time.Now() + inboundConn := serverInboundConn(source) + if err := s.dispatchInboundTransportPayload(logical, transport, inboundConn, owned, now); err != nil && (s.showError || s.debugMode) { + fmt.Println("server decode envelope error", err) + } + }) { + s.wg.Done() + } + return true +} + func serverInboundConn(source interface{}) net.Conn { switch data := source.(type) { case net.Conn: diff --git a/server_listen.go b/server_listen.go index f754095..db92a8c 100644 --- a/server_listen.go +++ b/server_listen.go @@ -130,6 +130,7 @@ func (s *ServerCommon) removeLogical(logical *LogicalConn) { s.getFileAckPool().closeScopeFamily(scope) s.getSignalAckPool().closeScopeFamily(scope) s.getReceivedSignalCache().closeScope(scope) + s.closeServerDedicatedSidecar(logical) s.getPeerRegistry().removeLogical(logical) } diff --git a/server_session.go b/server_session.go index 02703fe..177e565 100644 --- a/server_session.go +++ b/server_session.go @@ -57,6 +57,7 @@ func (s *ServerCommon) detachClientSessionTransport(client *ClientConn, reason s if runtime := s.getBulkRuntime(); runtime != nil { runtime.closeScope(serverFileScope(client), errTransportDetached) } + s.closeServerDedicatedSidecar(logicalConnFromClient(client)) client.detachServerOwnedTransport() } @@ -80,6 +81,7 @@ func (s 
*ServerCommon) detachLogicalSessionTransport(logical *LogicalConn, reaso if runtime := s.getBulkRuntime(); runtime != nil { runtime.closeScope(serverFileScope(logical), errTransportDetached) } + s.closeServerDedicatedSidecar(logical) logical.detachServerOwnedTransport() } @@ -108,6 +110,7 @@ func (s *ServerCommon) registerAcceptedLogical(logical *LogicalConn) *LogicalCon } logical.setServer(s) logical.applyAttachmentProfile(s.maxReadTimeout, s.maxWriteTimeout, s.defaultMsgEn, s.defaultMsgDe, s.defaultFastStreamEncode, s.defaultFastBulkEncode, s.defaultFastPlainEncode, s.handshakeRsaKey, s.SecretKey) + logical.setModernPSKRuntime(s.defaultModernPSKRuntime) logical.markHeartbeatNow() return s.getPeerRegistry().registerLogical(logical) } diff --git a/server_stream.go b/server_stream.go index 68ac9c9..c053e2c 100644 --- a/server_stream.go +++ b/server_stream.go @@ -33,6 +33,11 @@ func (s *ServerCommon) OpenStreamLogical(ctx context.Context, logical *LogicalCo if resp.DataID != 0 { req.DataID = resp.DataID } + if resp.FastPathVersion != 0 { + req.FastPathVersion = resp.FastPathVersion + } else { + req.FastPathVersion = streamFastPathVersionV1 + } req.Metadata = mergeStreamMetadata(req.Metadata, resp.Metadata) transport := logical.CurrentTransportConn() stream := newStreamHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, resp.TransportGeneration, serverStreamCloseSender(s, logical, nil), serverStreamResetSender(s, logical, nil), serverStreamDataSender(s, transport), runtime.configSnapshot()) @@ -73,6 +78,11 @@ func (s *ServerCommon) OpenStreamTransport(ctx context.Context, transport *Trans if resp.DataID != 0 { req.DataID = resp.DataID } + if resp.FastPathVersion != 0 { + req.FastPathVersion = resp.FastPathVersion + } else { + req.FastPathVersion = streamFastPathVersionV1 + } req.Metadata = mergeStreamMetadata(req.Metadata, resp.Metadata) stream := newStreamHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, 
resp.TransportGeneration, serverStreamCloseSender(s, logical, transport), serverStreamResetSender(s, logical, transport), serverStreamDataSender(s, transport), runtime.configSnapshot()) if err := runtime.register(scope, stream); err != nil { @@ -91,11 +101,12 @@ func serverStreamRequest(runtime *streamRuntime, opt StreamOpenOptions) StreamOp id = runtime.nextID() } return normalizeStreamOpenRequest(StreamOpenRequest{ - StreamID: id, - Channel: opt.Channel, - Metadata: cloneStreamMetadata(opt.Metadata), - ReadTimeout: opt.ReadTimeout, - WriteTimeout: opt.WriteTimeout, + StreamID: id, + FastPathVersion: streamFastPathVersionCurrent, + Channel: opt.Channel, + Metadata: cloneStreamMetadata(opt.Metadata), + ReadTimeout: opt.ReadTimeout, + WriteTimeout: opt.WriteTimeout, }) } @@ -148,7 +159,7 @@ func serverStreamDataSender(s *ServerCommon, transport *TransportConn) streamDat } } if dataID := stream.dataIDSnapshot(); dataID != 0 { - return s.sendFastStreamDataTransport(stream.LogicalConn(), transport, dataID, stream.nextOutboundDataSeq(), chunk) + return s.sendFastStreamDataTransport(ctx, stream.LogicalConn(), transport, stream, chunk) } return s.sendEnvelopeTransport(transport, newStreamDataEnvelope(stream.ID(), chunk)) } diff --git a/servertype.go b/servertype.go index 22d4124..f1c6ac3 100644 --- a/servertype.go +++ b/servertype.go @@ -24,6 +24,8 @@ type Server interface { SetStreamConfig(StreamConfig) SetTransferResumeStore(TransferResumeStore) RecoverTransferSnapshots(context.Context) error + SetBulkOpenTuning(BulkOpenTuning) + BulkOpenTuning() BulkOpenTuning SetFileReceiveDir(dir string) error send(c *ClientConn, msg TransferMsg) (WaitMsg, error) sendEnvelope(c *ClientConn, env Envelope) error diff --git a/session_runtime_snapshot.go b/session_runtime_snapshot.go index 0c69c9f..03c6ee7 100644 --- a/session_runtime_snapshot.go +++ b/session_runtime_snapshot.go @@ -6,18 +6,34 @@ import ( ) type ClientRuntimeSnapshot struct { - OwnerState string - Alive bool - 
SessionEpoch uint64 - TransportAttached bool - HasRuntimeConn bool - HasRuntimeQueue bool - HasRuntimeStopCtx bool - ConnectSource string - ConnectNetwork string - ConnectAddress string - CanReconnect bool - Retry ConnectionRetrySnapshot + OwnerState string + Alive bool + SessionEpoch uint64 + TransportAttached bool + HasRuntimeConn bool + HasRuntimeQueue bool + HasRuntimeStopCtx bool + ConnectSource string + ConnectNetwork string + ConnectAddress string + CanReconnect bool + BulkNetworkProfile string + BulkDefaultMode string + BulkChunkSize int + BulkWindowBytes int + BulkMaxInFlight int + BulkAttachLimit int + BulkActiveLimit int + BulkLaneLimit int + BulkAttachAttempts int64 + BulkAttachRetries int64 + BulkAttachSuccess int64 + BulkAutoFallbacks int64 + TransportBulkAdaptiveSoftPayloadBytes int + TransportStreamAdaptiveSoftPayloadBytes int + TransportStreamAdaptiveWaitThresholdBytes int + TransportStreamAdaptiveFlushDelay time.Duration + Retry ConnectionRetrySnapshot } type ServerRuntimeSnapshot struct { @@ -33,36 +49,43 @@ type ServerRuntimeSnapshot struct { HasRuntimeUDPListener bool HasRuntimeQueue bool HasRuntimeStopCtx bool + BulkChunkSize int + BulkWindowBytes int + BulkMaxInFlight int Retry ConnectionRetrySnapshot } type ClientConnRuntimeSnapshot struct { - ClientID string - RemoteAddress string - Alive bool - Reason string - Error string - IdentityBound bool - UsesStreamTransport bool - TransportGeneration uint64 - TransportAttached bool - HasRuntimeConn bool - HasRuntimeStopCtx bool - TransportAttachCount uint64 - TransportDetachCount uint64 - LastTransportAttachAt time.Time - DetachedClientKeepSec int64 - LastHeartbeatAt time.Time - TransportDetachReason string - TransportDetachKind string - TransportDetachGeneration uint64 - TransportDetachError string - TransportDetachedAt time.Time - TransportDetachHasExpiry bool - TransportDetachExpiry time.Time - TransportDetachRemaining time.Duration - TransportDetachExpired bool - ReattachEligible bool + 
ClientID string + RemoteAddress string + Alive bool + Reason string + Error string + IdentityBound bool + UsesStreamTransport bool + TransportGeneration uint64 + TransportAttached bool + HasRuntimeConn bool + HasRuntimeStopCtx bool + TransportAttachCount uint64 + TransportDetachCount uint64 + LastTransportAttachAt time.Time + DetachedClientKeepSec int64 + LastHeartbeatAt time.Time + TransportDetachReason string + TransportDetachKind string + TransportDetachGeneration uint64 + TransportDetachError string + TransportDetachedAt time.Time + TransportDetachHasExpiry bool + TransportDetachExpiry time.Time + TransportDetachRemaining time.Duration + TransportDetachExpired bool + ReattachEligible bool + TransportBulkAdaptiveSoftPayloadBytes int + TransportStreamAdaptiveSoftPayloadBytes int + TransportStreamAdaptiveWaitThresholdBytes int + TransportStreamAdaptiveFlushDelay time.Duration } func (c *ClientCommon) clientRuntimeSnapshot() ClientRuntimeSnapshot { @@ -85,6 +108,26 @@ func (c *ClientCommon) clientRuntimeSnapshot() ClientRuntimeSnapshot { snapshot.ConnectAddress = source.addr snapshot.CanReconnect = source.canReconnect() } + snapshot.BulkNetworkProfile = bulkNetworkProfileName(c.BulkNetworkProfile()) + snapshot.BulkDefaultMode = bulkOpenModeName(c.BulkDefaultOpenMode()) + tuning := c.BulkOpenTuning() + cfg := c.BulkDedicatedAttachConfig() + snapshot.BulkChunkSize = tuning.ChunkSize + snapshot.BulkWindowBytes = tuning.WindowBytes + snapshot.BulkMaxInFlight = tuning.MaxInFlight + snapshot.BulkAttachLimit = cfg.AttachLimit + snapshot.BulkActiveLimit = cfg.ActiveLimit + snapshot.BulkLaneLimit = cfg.LaneLimit + snapshot.BulkAttachAttempts = c.bulkAttachAttemptCount.Load() + snapshot.BulkAttachRetries = c.bulkAttachRetryCount.Load() + snapshot.BulkAttachSuccess = c.bulkAttachSuccessCount.Load() + snapshot.BulkAutoFallbacks = c.bulkAttachFallbackCount.Load() + if binding := c.clientTransportBindingSnapshot(); binding != nil { + 
snapshot.TransportBulkAdaptiveSoftPayloadBytes = binding.bulkAdaptiveSoftPayloadBytesSnapshot() + snapshot.TransportStreamAdaptiveSoftPayloadBytes = binding.streamAdaptiveSoftPayloadBytesSnapshot() + snapshot.TransportStreamAdaptiveWaitThresholdBytes = binding.streamAdaptiveWaitThresholdBytesSnapshot() + snapshot.TransportStreamAdaptiveFlushDelay = binding.streamAdaptiveFlushDelaySnapshot() + } snapshot.Retry = c.connectionRetrySnapshot() return snapshot } @@ -118,6 +161,10 @@ func (s *ServerCommon) serverRuntimeSnapshot() ServerRuntimeSnapshot { snapshot.HasRuntimeQueue = rt.queue != nil snapshot.HasRuntimeStopCtx = rt.stopCtx != nil } + tuning := s.BulkOpenTuning() + snapshot.BulkChunkSize = tuning.ChunkSize + snapshot.BulkWindowBytes = tuning.WindowBytes + snapshot.BulkMaxInFlight = tuning.MaxInFlight snapshot.Retry = s.connectionRetrySnapshot() return snapshot } @@ -153,6 +200,12 @@ func (c *ClientConn) clientConnRuntimeSnapshot() ClientConnRuntimeSnapshot { snapshot.HasRuntimeConn = c.clientConnTransportSnapshot() != nil snapshot.HasRuntimeStopCtx = rt.stopCtx != nil } + if binding := c.clientConnTransportBindingSnapshot(); binding != nil { + snapshot.TransportBulkAdaptiveSoftPayloadBytes = binding.bulkAdaptiveSoftPayloadBytesSnapshot() + snapshot.TransportStreamAdaptiveSoftPayloadBytes = binding.streamAdaptiveSoftPayloadBytesSnapshot() + snapshot.TransportStreamAdaptiveWaitThresholdBytes = binding.streamAdaptiveWaitThresholdBytesSnapshot() + snapshot.TransportStreamAdaptiveFlushDelay = binding.streamAdaptiveFlushDelaySnapshot() + } if detach := c.clientConnTransportDetachSnapshot(); detach != nil { snapshot.TransportDetachReason = detach.Reason snapshot.TransportDetachKind = c.clientConnTransportDetachKindSnapshot() diff --git a/session_runtime_snapshot_test.go b/session_runtime_snapshot_test.go index d7e532e..0a07d3a 100644 --- a/session_runtime_snapshot_test.go +++ b/session_runtime_snapshot_test.go @@ -37,6 +37,45 @@ func 
TestGetClientRuntimeSnapshotDefaults(t *testing.T) { if snapshot.ConnectSource != "" || snapshot.ConnectNetwork != "" || snapshot.ConnectAddress != "" || snapshot.CanReconnect { t.Fatalf("unexpected default connect source snapshot: %+v", snapshot) } + if got, want := snapshot.BulkNetworkProfile, "default"; got != want { + t.Fatalf("BulkNetworkProfile mismatch: got %q want %q", got, want) + } + if got, want := snapshot.BulkDefaultMode, "shared"; got != want { + t.Fatalf("BulkDefaultMode mismatch: got %q want %q", got, want) + } + if got, want := snapshot.BulkChunkSize, defaultBulkChunkSize; got != want { + t.Fatalf("BulkChunkSize mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkWindowBytes, defaultBulkOpenWindowBytes; got != want { + t.Fatalf("BulkWindowBytes mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkMaxInFlight, defaultBulkOpenMaxInFlight; got != want { + t.Fatalf("BulkMaxInFlight mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkAttachLimit, defaultBulkDedicatedAttachLimit; got != want { + t.Fatalf("BulkAttachLimit mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkActiveLimit, defaultBulkDedicatedActiveLimit; got != want { + t.Fatalf("BulkActiveLimit mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkLaneLimit, defaultBulkDedicatedLaneLimit; got != want { + t.Fatalf("BulkLaneLimit mismatch: got %d want %d", got, want) + } + if snapshot.BulkAttachAttempts != 0 || snapshot.BulkAttachRetries != 0 || snapshot.BulkAttachSuccess != 0 || snapshot.BulkAutoFallbacks != 0 { + t.Fatalf("unexpected default bulk attach counters: %+v", snapshot) + } + if snapshot.TransportBulkAdaptiveSoftPayloadBytes != 0 { + t.Fatalf("TransportBulkAdaptiveSoftPayloadBytes mismatch: got %d want 0", snapshot.TransportBulkAdaptiveSoftPayloadBytes) + } + if snapshot.TransportStreamAdaptiveSoftPayloadBytes != 0 { + t.Fatalf("TransportStreamAdaptiveSoftPayloadBytes mismatch: got 
%d want 0", snapshot.TransportStreamAdaptiveSoftPayloadBytes) + } + if snapshot.TransportStreamAdaptiveWaitThresholdBytes != 0 { + t.Fatalf("TransportStreamAdaptiveWaitThresholdBytes mismatch: got %d want 0", snapshot.TransportStreamAdaptiveWaitThresholdBytes) + } + if snapshot.TransportStreamAdaptiveFlushDelay != 0 { + t.Fatalf("TransportStreamAdaptiveFlushDelay mismatch: got %s want 0", snapshot.TransportStreamAdaptiveFlushDelay) + } if snapshot.Retry != (ConnectionRetrySnapshot{}) { t.Fatalf("Retry snapshot mismatch: %+v", snapshot.Retry) } @@ -78,11 +117,144 @@ func TestGetServerRuntimeSnapshotDefaults(t *testing.T) { if !snapshot.HasRuntimeStopCtx { t.Fatalf("HasRuntimeStopCtx mismatch: got %v want true", snapshot.HasRuntimeStopCtx) } + if got, want := snapshot.BulkChunkSize, defaultBulkChunkSize; got != want { + t.Fatalf("BulkChunkSize mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkWindowBytes, defaultBulkOpenWindowBytes; got != want { + t.Fatalf("BulkWindowBytes mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkMaxInFlight, defaultBulkOpenMaxInFlight; got != want { + t.Fatalf("BulkMaxInFlight mismatch: got %d want %d", got, want) + } if snapshot.Retry != (ConnectionRetrySnapshot{}) { t.Fatalf("Retry snapshot mismatch: %+v", snapshot.Retry) } } +func TestGetClientRuntimeSnapshotIncludesBulkAttachStats(t *testing.T) { + client := NewClient().(*ClientCommon) + client.SetBulkNetworkProfile(BulkNetworkProfileWAN) + client.SetBulkDefaultOpenMode(BulkOpenModeAuto) + client.bulkAttachAttemptCount.Store(11) + client.bulkAttachRetryCount.Store(5) + client.bulkAttachSuccessCount.Store(6) + client.bulkAttachFallbackCount.Store(3) + + snapshot, err := GetClientRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientRuntimeSnapshot failed: %v", err) + } + if got, want := snapshot.BulkNetworkProfile, "wan"; got != want { + t.Fatalf("BulkNetworkProfile mismatch: got %q want %q", got, want) + } + if got, want := 
snapshot.BulkDefaultMode, "auto"; got != want { + t.Fatalf("BulkDefaultMode mismatch: got %q want %q", got, want) + } + if got, want := snapshot.BulkChunkSize, 512*1024; got != want { + t.Fatalf("BulkChunkSize mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkWindowBytes, 8*1024*1024; got != want { + t.Fatalf("BulkWindowBytes mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkMaxInFlight, 16; got != want { + t.Fatalf("BulkMaxInFlight mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkAttachLimit, 2; got != want { + t.Fatalf("BulkAttachLimit mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkActiveLimit, 4096; got != want { + t.Fatalf("BulkActiveLimit mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkLaneLimit, 2; got != want { + t.Fatalf("BulkLaneLimit mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkAttachAttempts, int64(11); got != want { + t.Fatalf("BulkAttachAttempts mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkAttachRetries, int64(5); got != want { + t.Fatalf("BulkAttachRetries mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkAttachSuccess, int64(6); got != want { + t.Fatalf("BulkAttachSuccess mismatch: got %d want %d", got, want) + } + if got, want := snapshot.BulkAutoFallbacks, int64(3); got != want { + t.Fatalf("BulkAutoFallbacks mismatch: got %d want %d", got, want) + } +} + +func TestGetClientRuntimeSnapshotIncludesAdaptiveBulkSoftPayload(t *testing.T) { + client := NewClient().(*ClientCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + binding := newTransportBinding(left, queue) + binding.observeBulkAdaptivePayloadWrite(8*1024*1024, 640*time.Millisecond, 0, nil) + 
client.setClientSessionRuntime(&clientSessionRuntime{ + transport: binding, + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + epoch: 1, + }) + client.markSessionStarted() + + snapshot, err := GetClientRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientRuntimeSnapshot failed: %v", err) + } + if got, want := snapshot.TransportBulkAdaptiveSoftPayloadBytes, bulkAdaptiveSoftPayloadMinBytes; got != want { + t.Fatalf("TransportBulkAdaptiveSoftPayloadBytes mismatch: got %d want %d", got, want) + } + if got, want := snapshot.TransportStreamAdaptiveSoftPayloadBytes, streamAdaptiveSoftPayloadStartBytes; got != want { + t.Fatalf("TransportStreamAdaptiveSoftPayloadBytes mismatch: got %d want %d", got, want) + } + if got, want := snapshot.TransportStreamAdaptiveWaitThresholdBytes, streamBatchWaitThreshold; got != want { + t.Fatalf("TransportStreamAdaptiveWaitThresholdBytes mismatch: got %d want %d", got, want) + } + if got, want := snapshot.TransportStreamAdaptiveFlushDelay, streamBatchMaxFlushDelay; got != want { + t.Fatalf("TransportStreamAdaptiveFlushDelay mismatch: got %s want %s", got, want) + } +} + +func TestGetClientRuntimeSnapshotIncludesAdaptiveStreamTuning(t *testing.T) { + client := NewClient().(*ClientCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32) + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + binding := newTransportBinding(left, queue) + binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil) + client.setClientSessionRuntime(&clientSessionRuntime{ + transport: binding, + stopCtx: stopCtx, + stopFn: stopFn, + queue: queue, + epoch: 1, + }) + client.markSessionStarted() + + snapshot, err := GetClientRuntimeSnapshot(client) + if err != nil { + t.Fatalf("GetClientRuntimeSnapshot failed: %v", err) + } + if got, want := snapshot.TransportStreamAdaptiveSoftPayloadBytes, 
streamAdaptiveSoftPayloadMinBytes; got != want { + t.Fatalf("TransportStreamAdaptiveSoftPayloadBytes mismatch: got %d want %d", got, want) + } + if got, want := snapshot.TransportStreamAdaptiveWaitThresholdBytes, streamAdaptiveWaitThresholdMinBytes; got != want { + t.Fatalf("TransportStreamAdaptiveWaitThresholdBytes mismatch: got %d want %d", got, want) + } + if got := snapshot.TransportStreamAdaptiveFlushDelay; got != 0 { + t.Fatalf("TransportStreamAdaptiveFlushDelay mismatch: got %s want 0", got) + } +} + func TestGetRuntimeSnapshotRejectsNil(t *testing.T) { if _, err := GetClientRuntimeSnapshot(nil); !errors.Is(err, errClientRuntimeSnapshotNil) { t.Fatalf("GetClientRuntimeSnapshot nil error = %v, want %v", err, errClientRuntimeSnapshotNil) diff --git a/session_state.go b/session_state.go index f428de2..3a1c892 100644 --- a/session_state.go +++ b/session_state.go @@ -140,6 +140,7 @@ func (c *ClientCommon) cleanupClientSessionResources() { if runtime := c.getBulkRuntime(); runtime != nil { runtime.closeAll(errServiceShutdown) } + c.closeClientDedicatedSidecar() } func (s *ServerCommon) cleanupServerSessionResources() { @@ -158,4 +159,5 @@ func (s *ServerCommon) cleanupServerSessionResources() { if runtime := s.getBulkRuntime(); runtime != nil { runtime.closeAll(errServiceShutdown) } + s.closeAllServerDedicatedSidecars() } diff --git a/signal_benchmark_test.go b/signal_benchmark_test.go index eebd0db..a28d5f2 100644 --- a/signal_benchmark_test.go +++ b/signal_benchmark_test.go @@ -76,7 +76,7 @@ func startSignalRoundTripServerForBenchmark(b *testing.B) (*ServerCommon, string server.SetLink("signal-roundtrip", func(msg *Message) { _ = msg.Reply([]byte("ack:" + string(msg.Value))) }) - if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + if err := server.Listen("tcp", benchmarkTCPListenAddr(b)); err != nil { if benchmarkListenPermissionDenied(err) { b.Skipf("tcp benchmark requires local listen permission: %v", err) } @@ -102,7 +102,7 @@ func 
newSignalRoundTripBenchmarkClient(b *testing.B, addr string) *ClientCommon if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { b.Fatalf("UseModernPSKClient failed: %v", err) } - if err := client.Connect("tcp", addr); err != nil { + if err := client.Connect("tcp", benchmarkTCPDialAddr(b, addr)); err != nil { b.Fatalf("client Connect failed: %v", err) } return client diff --git a/snapshot_binding.go b/snapshot_binding.go index 09fdf72..f7b4065 100644 --- a/snapshot_binding.go +++ b/snapshot_binding.go @@ -3,20 +3,24 @@ package notify import "time" type snapshotBindingDiagnostics struct { - BindingOwner string - BindingAlive bool - BindingCurrent bool - BindingReason string - BindingError string - TransportAttached bool - TransportHasRuntimeConn bool - TransportCurrent bool - TransportDetachReason string - TransportDetachKind string - TransportDetachError string - TransportDetachGeneration uint64 - TransportDetachedAt time.Time - ReattachEligible bool + BindingOwner string + BindingAlive bool + BindingCurrent bool + BindingReason string + BindingError string + BindingBulkAdaptiveSoftPayloadBytes int + BindingStreamAdaptiveSoftPayloadBytes int + BindingStreamAdaptiveWaitThresholdBytes int + BindingStreamAdaptiveFlushDelay time.Duration + TransportAttached bool + TransportHasRuntimeConn bool + TransportCurrent bool + TransportDetachReason string + TransportDetachKind string + TransportDetachError string + TransportDetachGeneration uint64 + TransportDetachedAt time.Time + ReattachEligible bool } func snapshotBindingDiagnosticsFromClient(c *ClientCommon, sessionEpoch uint64) snapshotBindingDiagnostics { @@ -36,6 +40,12 @@ func snapshotBindingDiagnosticsFromClient(c *ClientCommon, sessionEpoch uint64) diag.TransportAttached = c.clientTransportAttachedSnapshot() diag.TransportHasRuntimeConn = c.clientTransportConnSnapshot() != nil diag.TransportCurrent = diag.BindingCurrent && diag.TransportAttached + if binding := 
c.clientTransportBindingSnapshot(); binding != nil { + diag.BindingBulkAdaptiveSoftPayloadBytes = binding.bulkAdaptiveSoftPayloadBytesSnapshot() + diag.BindingStreamAdaptiveSoftPayloadBytes = binding.streamAdaptiveSoftPayloadBytesSnapshot() + diag.BindingStreamAdaptiveWaitThresholdBytes = binding.streamAdaptiveWaitThresholdBytesSnapshot() + diag.BindingStreamAdaptiveFlushDelay = binding.streamAdaptiveFlushDelaySnapshot() + } return diag } @@ -72,5 +82,11 @@ func snapshotBindingDiagnosticsFromLogical(logical *LogicalConn, transport *Tran diag.TransportCurrent = runtime.TransportAttached } diag.BindingCurrent = diag.BindingAlive && diag.TransportCurrent + if binding := logical.transportBindingSnapshot(); binding != nil { + diag.BindingBulkAdaptiveSoftPayloadBytes = binding.bulkAdaptiveSoftPayloadBytesSnapshot() + diag.BindingStreamAdaptiveSoftPayloadBytes = binding.streamAdaptiveSoftPayloadBytesSnapshot() + diag.BindingStreamAdaptiveWaitThresholdBytes = binding.streamAdaptiveWaitThresholdBytesSnapshot() + diag.BindingStreamAdaptiveFlushDelay = binding.streamAdaptiveFlushDelaySnapshot() + } return diag } diff --git a/stream.go b/stream.go index 592c785..a607e03 100644 --- a/stream.go +++ b/stream.go @@ -7,6 +7,7 @@ import ( "net" "os" "sync" + "sync/atomic" "time" ) @@ -93,7 +94,8 @@ type streamHandle struct { runtimeScope string id string dataID uint64 - outboundSeq uint64 + fastPathVersion uint8 + outboundSeq atomic.Uint64 channel StreamChannel metadata StreamMetadata sessionEpoch uint64 @@ -132,6 +134,9 @@ type streamHandle struct { writeDeadlineOverride bool readDeadlineNotify chan struct{} writeDeadlineNotify chan struct{} + writeWaitSeq uint64 + writeWaitCancel context.CancelFunc + writeWaitChanged chan struct{} bytesRead int64 bytesWritten int64 readCalls int64 @@ -157,6 +162,7 @@ func newStreamHandle(parent context.Context, runtime *streamRuntime, runtimeScop runtimeScope: runtimeScope, id: req.StreamID, dataID: req.DataID, + fastPathVersion: 
normalizeStreamFastPathVersion(req.FastPathVersion), channel: normalizeStreamChannel(req.Channel), metadata: cloneStreamMetadata(req.Metadata), sessionEpoch: sessionEpoch, @@ -224,13 +230,25 @@ func (s *streamHandle) dataIDSnapshot() uint64 { } func (s *streamHandle) nextOutboundDataSeq() uint64 { + return s.reserveOutboundDataSeqs(1) +} + +func (s *streamHandle) reserveOutboundDataSeqs(count int) uint64 { if s == nil { return 0 } - s.mu.Lock() - defer s.mu.Unlock() - s.outboundSeq++ - return s.outboundSeq + if count <= 0 { + count = 1 + } + end := s.outboundSeq.Add(uint64(count)) + return end - uint64(count) + 1 +} + +func (s *streamHandle) fastPathVersionSnapshot() uint8 { + if s == nil { + return streamFastPathVersionV1 + } + return normalizeStreamFastPathVersion(s.fastPathVersion) } func (s *streamHandle) Channel() StreamChannel { @@ -377,6 +395,7 @@ func (s *streamHandle) Write(p []byte) (int, error) { sendDataFn := s.sendDataFn chunkSize := s.chunkSize writeTimeout := s.writeTimeout + writeDeadlineOverride := s.writeDeadlineOverride streamCtx := s.ctx runtime := s.runtime s.mu.Unlock() @@ -399,6 +418,20 @@ func (s *streamHandle) Write(p []byte) (int, error) { end = len(p) } chunk := p[written:end] + if !writeDeadlineOverride && writeTimeout <= 0 { + if tryAcquireStreamOutboundBudget(runtime, len(chunk)) { + err := sendDataFn(streamCtx, s, chunk) + releaseStreamOutboundBudget(runtime, len(chunk)) + if err != nil { + if written > 0 { + s.recordWrite(written, time.Now()) + } + return written, s.normalizeWriteError(err) + } + written = end + continue + } + } sendCtx, cancel, deadlineChanged, err := s.newWriteContext(streamCtx, writeTimeout) if err != nil { if written > 0 { @@ -464,7 +497,15 @@ func (s *streamHandle) SetWriteDeadline(deadline time.Time) error { s.writeDeadline = deadline s.writeDeadlineOverride = true signalStreamDeadlineChangeLocked(&s.writeDeadlineNotify) + waitCancel := s.writeWaitCancel + if s.writeWaitChanged != nil { + 
close(s.writeWaitChanged) + s.writeWaitChanged = nil + } s.mu.Unlock() + if waitCancel != nil { + waitCancel() + } return nil } @@ -535,7 +576,6 @@ func (s *streamHandle) newWriteContext(parent context.Context, writeTimeout time } s.mu.Lock() deadline := s.effectiveWriteDeadlineLocked(time.Now(), writeTimeout) - deadlineNotify := s.writeDeadlineNotify s.mu.Unlock() if !deadline.IsZero() && !deadline.After(time.Now()) { return nil, func() {}, nil, os.ErrDeadlineExceeded @@ -548,19 +588,20 @@ func (s *streamHandle) newWriteContext(parent context.Context, writeTimeout time baseCtx, baseCancel = context.WithCancel(parent) } changed := make(chan struct{}) - done := make(chan struct{}) - go func() { - defer close(done) - select { - case <-baseCtx.Done(): - case <-deadlineNotify: - close(changed) - baseCancel() - } - }() + s.mu.Lock() + s.writeWaitSeq++ + waitSeq := s.writeWaitSeq + s.writeWaitCancel = baseCancel + s.writeWaitChanged = changed + s.mu.Unlock() cancel := func() { baseCancel() - <-done + s.mu.Lock() + if s.writeWaitSeq == waitSeq { + s.writeWaitCancel = nil + s.writeWaitChanged = nil + } + s.mu.Unlock() } return baseCtx, cancel, changed, nil } @@ -814,7 +855,11 @@ func (s *streamHandle) pushChunkWithOwnership(chunk []byte, owned bool) error { s.finalize() return err } - s.readQueue = append(s.readQueue, stored) + if len(s.readBuf) == 0 && len(s.readQueue) == 0 { + s.readBuf = stored + } else { + s.readQueue = append(s.readQueue, stored) + } s.bufferedBytes += len(stored) s.notifyReadableLocked() s.mu.Unlock() @@ -917,6 +962,10 @@ func (s *streamHandle) snapshot() StreamSnapshot { snapshot.BindingCurrent = diag.BindingCurrent snapshot.BindingReason = diag.BindingReason snapshot.BindingError = diag.BindingError + snapshot.BindingBulkAdaptiveSoftPayloadBytes = diag.BindingBulkAdaptiveSoftPayloadBytes + snapshot.BindingStreamAdaptiveSoftPayloadBytes = diag.BindingStreamAdaptiveSoftPayloadBytes + snapshot.BindingStreamAdaptiveWaitThresholdBytes = 
diag.BindingStreamAdaptiveWaitThresholdBytes + snapshot.BindingStreamAdaptiveFlushDelay = diag.BindingStreamAdaptiveFlushDelay snapshot.TransportAttached = diag.TransportAttached snapshot.TransportHasRuntimeConn = diag.TransportHasRuntimeConn snapshot.TransportCurrent = diag.TransportCurrent @@ -1057,8 +1106,23 @@ func acquireStreamOutboundBudget(runtime *streamRuntime, ctx context.Context, si return runtime.acquireOutbound(ctx, size) } +func tryAcquireStreamOutboundBudget(runtime *streamRuntime, size int) bool { + if runtime == nil { + return true + } + return runtime.tryAcquireOutbound(size) +} + +func releaseStreamOutboundBudget(runtime *streamRuntime, size int) { + if runtime == nil { + return + } + runtime.releaseOutbound(size) +} + func normalizeStreamOpenRequest(req StreamOpenRequest) StreamOpenRequest { req.Channel = normalizeStreamChannel(req.Channel) + req.FastPathVersion = normalizeStreamFastPathVersion(req.FastPathVersion) req.Metadata = cloneStreamMetadata(req.Metadata) return req } diff --git a/stream_batch_codec.go b/stream_batch_codec.go new file mode 100644 index 0000000..3858a90 --- /dev/null +++ b/stream_batch_codec.go @@ -0,0 +1,6 @@ +package notify + +type streamBatchCodec struct { + encodeSingle func(streamFastDataFrame) ([]byte, error) + encodeBatch func([]streamFastDataFrame) ([]byte, error) +} diff --git a/stream_batch_sender.go b/stream_batch_sender.go new file mode 100644 index 0000000..7a8130d --- /dev/null +++ b/stream_batch_sender.go @@ -0,0 +1,582 @@ +package notify + +import ( + "context" + "net" + "sync" + "sync/atomic" + "time" +) + +const ( + streamBatchMaxPayloads = 64 + streamBatchMaxPayloadBytes = 2 * 1024 * 1024 + streamBatchMaxFlushDelay = 50 * time.Microsecond + streamBatchWaitThreshold = 128 * 1024 +) + +const ( + streamBatchRequestQueued int32 = iota + streamBatchRequestStarted + streamBatchRequestCanceled +) + +type streamBatchRequestState struct { + value atomic.Int32 +} + +type streamBatchRequest struct { + ctx 
context.Context + frame streamFastDataFrame + hasFrame bool + encodedPayload []byte + hasEncoded bool + frames []streamFastDataFrame + fastPathVersion uint8 + deadline time.Time + done chan error + state *streamBatchRequestState +} + +type streamBatchSender struct { + binding *transportBinding + codec streamBatchCodec + writeTimeoutProvider func() time.Duration + reqCh chan streamBatchRequest + stopCh chan struct{} + doneCh chan struct{} + + stopOnce sync.Once + flushMu sync.Mutex + queued atomic.Int64 + errMu sync.Mutex + err error +} + +func newStreamBatchSender(binding *transportBinding, codec streamBatchCodec, writeTimeoutProvider func() time.Duration) *streamBatchSender { + sender := &streamBatchSender{ + binding: binding, + codec: codec, + writeTimeoutProvider: writeTimeoutProvider, + reqCh: make(chan streamBatchRequest, streamBatchMaxPayloads*4), + stopCh: make(chan struct{}), + doneCh: make(chan struct{}), + } + go sender.run() + return sender +} + +func (s *streamBatchSender) submitData(ctx context.Context, dataID uint64, seq uint64, fastPathVersion uint8, payload []byte) error { + if s == nil { + return errTransportDetached + } + if len(payload) == 0 { + return nil + } + return s.submitRequest(streamBatchRequest{ + ctx: ctx, + frame: streamFastDataFrame{ + DataID: dataID, + Seq: seq, + Payload: payload, + }, + hasFrame: true, + fastPathVersion: normalizeStreamFastPathVersion(fastPathVersion), + }) +} + +func (s *streamBatchSender) submitEncoded(ctx context.Context, fastPathVersion uint8, payload []byte) error { + if s == nil { + return errTransportDetached + } + if len(payload) == 0 { + return nil + } + return s.submitRequest(streamBatchRequest{ + ctx: ctx, + encodedPayload: payload, + hasEncoded: true, + fastPathVersion: normalizeStreamFastPathVersion(fastPathVersion), + }) +} + +func (s *streamBatchSender) submitFrames(ctx context.Context, fastPathVersion uint8, frames []streamFastDataFrame) error { + if s == nil { + return errTransportDetached + } + if 
len(frames) == 0 { + return nil + } + queuedFrames := append([]streamFastDataFrame(nil), frames...) + req := streamBatchRequest{ + ctx: ctx, + frames: queuedFrames, + fastPathVersion: normalizeStreamFastPathVersion(fastPathVersion), + } + if len(queuedFrames) == 1 { + req.frame = queuedFrames[0] + req.frames = nil + req.hasFrame = true + } + return s.submitRequest(req) +} + +func (s *streamBatchSender) submitRequest(req streamBatchRequest) error { + if s == nil { + return errTransportDetached + } + if req.ctx == nil { + req.ctx = context.Background() + } + if !req.hasFrame && !req.hasEncoded && len(req.frames) == 0 { + return nil + } + req.fastPathVersion = normalizeStreamFastPathVersion(req.fastPathVersion) + req.done = make(chan error, 1) + req.state = &streamBatchRequestState{} + if deadline, ok := req.ctx.Deadline(); ok { + req.deadline = deadline + } + if err := s.errSnapshot(); err != nil { + return err + } + if s.shouldDirectSubmit(req) { + if submitted, err := s.tryDirectSubmit(req); submitted { + return err + } + } + s.queued.Add(1) + select { + case <-req.ctx.Done(): + s.queued.Add(-1) + return normalizeStreamDeadlineError(req.ctx.Err()) + case <-s.stopCh: + s.queued.Add(-1) + return s.stoppedErr() + case s.reqCh <- req: + } + select { + case err := <-req.done: + return err + case <-req.ctx.Done(): + if req.tryCancel() { + return normalizeStreamDeadlineError(req.ctx.Err()) + } + return <-req.done + } +} + +func (s *streamBatchSender) shouldDirectSubmit(req streamBatchRequest) bool { + if req.hasEncoded { + return false + } + if !req.hasFrame && len(req.frames) == 0 { + return false + } + return !streamFastPathSupportsBatch(req.fastPathVersion) +} + +func (s *streamBatchSender) tryDirectSubmit(req streamBatchRequest) (bool, error) { + if s == nil { + return true, errTransportDetached + } + if err := s.errSnapshot(); err != nil { + return true, err + } + select { + case <-req.ctx.Done(): + return true, normalizeStreamDeadlineError(req.ctx.Err()) + case 
<-s.stopCh: + return true, s.stoppedErr() + default: + } + if s.queued.Load() != 0 { + return false, nil + } + if !s.flushMu.TryLock() { + return false, nil + } + defer s.flushMu.Unlock() + if s.queued.Load() != 0 { + return false, nil + } + if err := s.errSnapshot(); err != nil { + return true, err + } + if !req.tryStart() { + return true, req.canceledErr() + } + if err := req.contextErr(); err != nil { + return true, err + } + if err := s.flush([]streamBatchRequest{req}); err != nil { + s.setErr(err) + s.failPending(err) + return true, err + } + return true, nil +} + +func (s *streamBatchSender) run() { + defer close(s.doneCh) + for { + req, ok := s.nextRequest() + if !ok { + return + } + batch := []streamBatchRequest{req} + batchBytes := streamBatchRequestApproxBytes(req) + softPayloadLimit := s.batchSoftPayloadLimit() + waitThreshold := s.batchWaitThreshold() + flushDelay := s.batchFlushDelay() + timer := (*time.Timer)(nil) + timerCh := (<-chan time.Time)(nil) + if flushDelay > 0 && batchBytes < waitThreshold && batchBytes < softPayloadLimit && len(batch) < streamBatchMaxPayloads { + timer = time.NewTimer(flushDelay) + timerCh = timer.C + } + drain: + for len(batch) < streamBatchMaxPayloads && batchBytes < softPayloadLimit { + if timerCh == nil { + select { + case <-s.stopCh: + s.failPending(s.stoppedErr()) + return + case next := <-s.reqCh: + batch = append(batch, next) + batchBytes += streamBatchRequestApproxBytes(next) + default: + break drain + } + continue + } + select { + case <-s.stopCh: + if timer != nil { + timer.Stop() + } + s.failPending(s.stoppedErr()) + return + case next := <-s.reqCh: + batch = append(batch, next) + batchBytes += streamBatchRequestApproxBytes(next) + case <-timerCh: + timerCh = nil + break drain + } + } + if timer != nil { + if !timer.Stop() && timerCh != nil { + select { + case <-timer.C: + default: + } + } + } + s.flushMu.Lock() + err := s.errSnapshot() + active := make([]streamBatchRequest, 0, len(batch)) + for _, item := range 
batch { + if !item.tryStart() { + s.finishRequest(item, item.canceledErr()) + continue + } + if itemErr := item.contextErr(); itemErr != nil { + s.finishRequest(item, itemErr) + continue + } + active = append(active, item) + } + if len(active) == 0 { + s.flushMu.Unlock() + continue + } + if err == nil { + err = s.flush(active) + } + s.flushMu.Unlock() + if err != nil { + s.setErr(err) + for _, item := range active { + s.finishRequest(item, err) + } + s.failPending(err) + return + } + for _, item := range active { + s.finishRequest(item, nil) + } + } +} + +func (s *streamBatchSender) nextRequest() (streamBatchRequest, bool) { + select { + case <-s.stopCh: + s.failPending(s.stoppedErr()) + return streamBatchRequest{}, false + case req := <-s.reqCh: + return req, true + } +} + +func (s *streamBatchSender) flush(requests []streamBatchRequest) error { + if s == nil || s.binding == nil { + return errTransportDetached + } + queue := s.binding.queueSnapshot() + if queue == nil { + return errTransportFrameQueueUnavailable + } + payloads, err := s.encodeRequests(requests) + if err != nil { + return err + } + writeTimeout := s.transportWriteTimeout() + payloadBytes := 0 + for _, payload := range payloads { + payloadBytes += len(payload) + } + started := time.Now() + err = s.binding.withConnWriteLockDeadline(writeDeadlineFromTimeout(writeTimeout), func(conn net.Conn) error { + return writeFramedPayloadBatchUnlocked(conn, queue, payloads) + }) + s.binding.observeStreamAdaptivePayloadWrite(payloadBytes, time.Since(started), writeTimeout, err) + return err +} + +func (s *streamBatchSender) transportWriteTimeout() time.Duration { + if s == nil || s.writeTimeoutProvider == nil { + return 0 + } + return s.writeTimeoutProvider() +} + +func (s *streamBatchSender) batchSoftPayloadLimit() int { + if s == nil || s.binding == nil { + return streamAdaptiveSoftPayloadFallbackBytes + } + return s.binding.streamAdaptiveSoftPayloadBytesSnapshot() +} + +func (s *streamBatchSender) 
batchWaitThreshold() int { + if s == nil || s.binding == nil { + return streamBatchWaitThreshold + } + return s.binding.streamAdaptiveWaitThresholdBytesSnapshot() +} + +func (s *streamBatchSender) batchFlushDelay() time.Duration { + if s == nil || s.binding == nil { + return streamBatchMaxFlushDelay + } + return s.binding.streamAdaptiveFlushDelaySnapshot() +} + +func (s *streamBatchSender) batchPlainPayloadLimit() int { + limit := s.batchSoftPayloadLimit() + if limit <= streamFastBatchHeaderLen { + return streamFastBatchHeaderLen + 1 + } + return minInt(limit, streamFastBatchMaxPlainBytes) +} + +func (s *streamBatchSender) encodeRequests(requests []streamBatchRequest) ([][]byte, error) { + if len(requests) == 0 { + return nil, nil + } + payloads := make([][]byte, 0, len(requests)) + batchPlainLimit := s.batchPlainPayloadLimit() + var batch []streamFastDataFrame + flushBatch := func() error { + if len(batch) == 0 { + return nil + } + payload, err := s.encodeBatch(batch) + if err != nil { + return err + } + payloads = append(payloads, payload) + batch = batch[:0] + return nil + } + batchBytes := streamFastBatchHeaderLen + appendFrame := func(frame streamFastDataFrame, fastPathVersion uint8) error { + if !streamFastPathSupportsBatch(fastPathVersion) { + if err := flushBatch(); err != nil { + return err + } + batchBytes = streamFastBatchHeaderLen + payload, err := s.encodeSingle(frame) + if err != nil { + return err + } + payloads = append(payloads, payload) + return nil + } + frameLen := streamFastBatchFrameLen(frame) + if frameLen+streamFastBatchHeaderLen > batchPlainLimit { + if err := flushBatch(); err != nil { + return err + } + batchBytes = streamFastBatchHeaderLen + payload, err := s.encodeSingle(frame) + if err != nil { + return err + } + payloads = append(payloads, payload) + return nil + } + if len(batch) > 0 && (len(batch) >= streamFastBatchMaxItems || batchBytes+frameLen > batchPlainLimit) { + if err := flushBatch(); err != nil { + return err + } + 
batchBytes = streamFastBatchHeaderLen + } + if batch == nil { + batch = make([]streamFastDataFrame, 0, minInt(len(requests), streamFastBatchMaxItems)) + } + batch = append(batch, frame) + batchBytes += frameLen + return nil + } + for _, req := range requests { + if req.hasFrame { + if err := appendFrame(req.frame, req.fastPathVersion); err != nil { + return nil, err + } + } + for _, frame := range req.frames { + if err := appendFrame(frame, req.fastPathVersion); err != nil { + return nil, err + } + } + if req.hasEncoded { + if err := flushBatch(); err != nil { + return nil, err + } + batchBytes = streamFastBatchHeaderLen + payloads = append(payloads, req.encodedPayload) + } + } + if err := flushBatch(); err != nil { + return nil, err + } + return payloads, nil +} + +func streamBatchRequestApproxBytes(req streamBatchRequest) int { + total := 0 + if req.hasFrame { + total += streamFastBatchFrameLen(req.frame) + } + for _, frame := range req.frames { + total += streamFastBatchFrameLen(frame) + } + if req.hasEncoded { + total += len(req.encodedPayload) + } + return total +} + +func (s *streamBatchSender) encodeSingle(frame streamFastDataFrame) ([]byte, error) { + if s == nil || s.codec.encodeSingle == nil { + return nil, errTransportDetached + } + return s.codec.encodeSingle(frame) +} + +func (s *streamBatchSender) encodeBatch(frames []streamFastDataFrame) ([]byte, error) { + if len(frames) == 1 || s.codec.encodeBatch == nil { + return s.encodeSingle(frames[0]) + } + return s.codec.encodeBatch(frames) +} + +func (s *streamBatchSender) finishRequest(req streamBatchRequest, err error) { + if s != nil { + s.queued.Add(-1) + } + req.done <- err +} + +func (s *streamBatchSender) stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + s.setErr(errTransportDetached) + close(s.stopCh) + }) + <-s.doneCh +} + +func (s *streamBatchSender) failPending(err error) { + for { + select { + case item := <-s.reqCh: + s.finishRequest(item, err) + default: + return + } + } +} + 
+func (s *streamBatchSender) setErr(err error) { + if s == nil || err == nil { + return + } + s.errMu.Lock() + if s.err == nil { + s.err = err + } + s.errMu.Unlock() +} + +func (s *streamBatchSender) errSnapshot() error { + if s == nil { + return errTransportDetached + } + s.errMu.Lock() + defer s.errMu.Unlock() + return s.err +} + +func (s *streamBatchSender) stoppedErr() error { + if err := s.errSnapshot(); err != nil { + return err + } + return errTransportDetached +} + +func (r streamBatchRequest) contextErr() error { + if r.ctx == nil { + return nil + } + select { + case <-r.ctx.Done(): + return normalizeStreamDeadlineError(r.ctx.Err()) + default: + return nil + } +} + +func (r streamBatchRequest) tryStart() bool { + if r.state == nil { + return true + } + return r.state.value.CompareAndSwap(streamBatchRequestQueued, streamBatchRequestStarted) +} + +func (r streamBatchRequest) tryCancel() bool { + if r.state == nil { + return false + } + return r.state.value.CompareAndSwap(streamBatchRequestQueued, streamBatchRequestCanceled) +} + +func (r streamBatchRequest) canceledErr() error { + if err := r.contextErr(); err != nil { + return err + } + return context.Canceled +} diff --git a/stream_benchmark_test.go b/stream_benchmark_test.go index 2d5c278..fd7118a 100644 --- a/stream_benchmark_test.go +++ b/stream_benchmark_test.go @@ -128,7 +128,7 @@ func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfi return nil }) - if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + if err := server.Listen("tcp", benchmarkTCPListenAddr(b)); err != nil { b.Fatalf("server Listen failed: %v", err) } b.Cleanup(func() { @@ -140,7 +140,7 @@ func benchmarkStreamTCPThroughput(b *testing.B, payloadSize int, cfg StreamConfi if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { b.Fatalf("UseModernPSKClient failed: %v", err) } - if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { 
+ if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil { b.Fatalf("client Connect failed: %v", err) } b.Cleanup(func() { @@ -216,7 +216,7 @@ func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concu return nil }) - if err := server.Listen("tcp", "127.0.0.1:0"); err != nil { + if err := server.Listen("tcp", benchmarkTCPListenAddr(b)); err != nil { b.Fatalf("server Listen failed: %v", err) } b.Cleanup(func() { @@ -228,7 +228,7 @@ func benchmarkStreamTCPThroughputConcurrent(b *testing.B, payloadSize int, concu if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { b.Fatalf("UseModernPSKClient failed: %v", err) } - if err := client.Connect("tcp", server.listener.Addr().String()); err != nil { + if err := client.Connect("tcp", benchmarkTCPDialAddr(b, server.listener.Addr().String())); err != nil { b.Fatalf("client Connect failed: %v", err) } b.Cleanup(func() { diff --git a/stream_control.go b/stream_control.go index 2002bfa..677536d 100644 --- a/stream_control.go +++ b/stream_control.go @@ -7,17 +7,19 @@ import ( ) type StreamOpenRequest struct { - StreamID string - DataID uint64 - Channel StreamChannel - Metadata StreamMetadata - ReadTimeout time.Duration - WriteTimeout time.Duration + StreamID string + DataID uint64 + FastPathVersion uint8 + Channel StreamChannel + Metadata StreamMetadata + ReadTimeout time.Duration + WriteTimeout time.Duration } type StreamOpenResponse struct { StreamID string DataID uint64 + FastPathVersion uint8 Accepted bool TransportGeneration uint64 Metadata StreamMetadata @@ -92,6 +94,8 @@ func (c *ClientCommon) handleInboundStreamOpen(msg *Message) { return } scope := clientFileScope() + req.FastPathVersion = negotiateStreamFastPathVersion(req.FastPathVersion) + resp.FastPathVersion = req.FastPathVersion if req.DataID == 0 { req.DataID = runtime.nextDataID() resp.DataID = req.DataID @@ -178,6 +182,8 @@ func (s 
*ServerCommon) handleInboundStreamOpen(msg *Message) { } transport := messageTransportConnSnapshot(msg) scope := serverFileScope(logical) + req.FastPathVersion = negotiateStreamFastPathVersion(req.FastPathVersion) + resp.FastPathVersion = req.FastPathVersion if req.DataID == 0 { req.DataID = runtime.nextDataID() resp.DataID = req.DataID diff --git a/stream_fastpath.go b/stream_fastpath.go index b24c7d2..c234d1c 100644 --- a/stream_fastpath.go +++ b/stream_fastpath.go @@ -1,8 +1,10 @@ package notify import ( + "context" "encoding/binary" "errors" + "io" ) var ( @@ -15,6 +17,7 @@ const ( streamFastPayloadVersion = 1 streamFastPayloadTypeData = 1 streamFastPayloadHeaderLen = 28 + streamFastBatchDirectLimit = 512 * 1024 ) type streamFastDataFrame struct { @@ -24,6 +27,56 @@ type streamFastDataFrame struct { Payload []byte } +func streamAdaptiveFramePayloadLimit(binding *transportBinding) int { + if binding == nil { + return 0 + } + limit := binding.streamAdaptiveSoftPayloadBytesSnapshot() - streamFastPayloadHeaderLen + if limit <= 0 { + return 1 + } + maxPayload := streamFastBatchMaxPlainBytes - streamFastPayloadHeaderLen + if limit > maxPayload { + return maxPayload + } + return limit +} + +func streamFastSplitFrameCount(size int, maxPayload int) int { + if size <= 0 || maxPayload <= 0 { + return 1 + } + return (size + maxPayload - 1) / maxPayload +} + +func buildStreamFastSplitFrames(dataID uint64, startSeq uint64, chunk []byte, maxPayload int) []streamFastDataFrame { + if len(chunk) == 0 { + return nil + } + if maxPayload <= 0 || len(chunk) <= maxPayload { + return []streamFastDataFrame{{ + DataID: dataID, + Seq: startSeq, + Payload: chunk, + }} + } + frames := make([]streamFastDataFrame, 0, streamFastSplitFrameCount(len(chunk), maxPayload)) + seq := startSeq + for offset := 0; offset < len(chunk); offset += maxPayload { + end := offset + maxPayload + if end > len(chunk) { + end = len(chunk) + } + frames = append(frames, streamFastDataFrame{ + DataID: dataID, + Seq: 
seq, + Payload: chunk[offset:end], + }) + seq++ + } + return frames +} + func encodeStreamFastDataFrameHeader(dst []byte, dataID uint64, seq uint64, payloadLen int) error { if dataID == 0 { return errStreamFastDataIDEmpty @@ -51,6 +104,31 @@ func encodeStreamFastDataFrame(dataID uint64, seq uint64, payload []byte) ([]byt return frame, nil } +func encodeStreamFastFramePayload(frame streamFastDataFrame) ([]byte, error) { + framePayload := make([]byte, streamFastPayloadHeaderLen+len(frame.Payload)) + if err := encodeStreamFastDataFrameHeader(framePayload, frame.DataID, frame.Seq, len(frame.Payload)); err != nil { + return nil, err + } + framePayload[6] = frame.Flags + copy(framePayload[streamFastPayloadHeaderLen:], frame.Payload) + return framePayload, nil +} + +func encodeStreamFastFramePayloadFast(encode transportFastPlainEncoder, secretKey []byte, frame streamFastDataFrame) ([]byte, error) { + if encode == nil { + return nil, errTransportPayloadEncryptFailed + } + plainLen := streamFastPayloadHeaderLen + len(frame.Payload) + return encode(secretKey, plainLen, func(dst []byte) error { + if err := encodeStreamFastDataFrameHeader(dst, frame.DataID, frame.Seq, len(frame.Payload)); err != nil { + return err + } + dst[6] = frame.Flags + copy(dst[streamFastPayloadHeaderLen:], frame.Payload) + return nil + }) +} + func decodeStreamFastDataFrame(payload []byte) (streamFastDataFrame, bool, error) { if len(payload) < 4 || string(payload[:4]) != streamFastPayloadMagic { return streamFastDataFrame{}, false, nil @@ -77,18 +155,66 @@ func decodeStreamFastDataFrame(payload []byte) (streamFastDataFrame, bool, error }, true, nil } -func (c *ClientCommon) encodeFastStreamDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) { - if c != nil && c.fastStreamEncode != nil { - return c.fastStreamEncode(c.SecretKey, dataID, seq, chunk) +func (c *ClientCommon) encodeFastStreamPayload(frame streamFastDataFrame) ([]byte, error) { + if c != nil && c.fastStreamEncode != nil && 
frame.Flags == 0 { + return c.fastStreamEncode(c.SecretKey, frame.DataID, frame.Seq, frame.Payload) } - plain, err := encodeStreamFastDataFrame(dataID, seq, chunk) + if c != nil && c.fastPlainEncode != nil { + return encodeStreamFastFramePayloadFast(c.fastPlainEncode, c.SecretKey, frame) + } + plain, err := encodeStreamFastFramePayload(frame) if err != nil { return nil, err } return c.encryptTransportPayload(plain) } -func (c *ClientCommon) sendFastStreamData(dataID uint64, seq uint64, chunk []byte) error { +func (c *ClientCommon) encodeFastStreamDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) { + return c.encodeFastStreamPayload(streamFastDataFrame{ + DataID: dataID, + Seq: seq, + Payload: chunk, + }) +} + +func (c *ClientCommon) encodeFastStreamBatchPayload(frames []streamFastDataFrame) ([]byte, error) { + if c == nil { + return nil, errStreamClientNil + } + if c.fastPlainEncode != nil { + return encodeStreamFastBatchPayloadFast(c.fastPlainEncode, c.SecretKey, frames) + } + plain, err := encodeStreamFastBatchPlain(frames) + if err != nil { + return nil, err + } + return c.encryptTransportPayload(plain) +} + +func (c *ClientCommon) sendFastStreamData(ctx context.Context, stream *streamHandle, chunk []byte) error { + if stream == nil { + return io.ErrClosedPipe + } + dataID := stream.dataIDSnapshot() + fastPathVersion := stream.fastPathVersionSnapshot() + if binding := c.clientTransportBindingSnapshot(); binding != nil && streamFastPathSupportsBatch(fastPathVersion) { + if sender := binding.clientStreamBatchSenderSnapshot(c); sender != nil { + if maxPayload := streamAdaptiveFramePayloadLimit(binding); maxPayload > 0 && len(chunk) > maxPayload { + startSeq := stream.reserveOutboundDataSeqs(streamFastSplitFrameCount(len(chunk), maxPayload)) + return sender.submitFrames(ctx, fastPathVersion, buildStreamFastSplitFrames(dataID, startSeq, chunk, maxPayload)) + } + seq := stream.reserveOutboundDataSeqs(1) + if len(chunk) < streamFastBatchDirectLimit { 
+ return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk) + } + payload, err := c.encodeFastStreamDataPayload(dataID, seq, chunk) + if err != nil { + return err + } + return sender.submitEncoded(ctx, fastPathVersion, payload) + } + } + seq := stream.reserveOutboundDataSeqs(1) payload, err := c.encodeFastStreamDataPayload(dataID, seq, chunk) if err != nil { return err @@ -96,29 +222,78 @@ func (c *ClientCommon) sendFastStreamData(dataID uint64, seq uint64, chunk []byt return c.writePayloadToTransport(payload) } -func (s *ServerCommon) encodeFastStreamDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) { - if logical != nil { - if fastStreamEncode := logical.fastStreamEncodeSnapshot(); fastStreamEncode != nil { - return fastStreamEncode(logical.secretKeySnapshot(), dataID, seq, chunk) - } +func (s *ServerCommon) encodeFastStreamPayloadLogical(logical *LogicalConn, frame streamFastDataFrame) ([]byte, error) { + if logical == nil { + return nil, errTransportDetached } - plain, err := encodeStreamFastDataFrame(dataID, seq, chunk) + if fastStreamEncode := logical.fastStreamEncodeSnapshot(); fastStreamEncode != nil && frame.Flags == 0 { + return fastStreamEncode(logical.secretKeySnapshot(), frame.DataID, frame.Seq, frame.Payload) + } + if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { + return encodeStreamFastFramePayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frame) + } + plain, err := encodeStreamFastFramePayload(frame) if err != nil { return nil, err } return s.encryptTransportPayloadLogical(logical, plain) } -func (s *ServerCommon) sendFastStreamDataTransport(logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte) error { +func (s *ServerCommon) encodeFastStreamDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) { + return s.encodeFastStreamPayloadLogical(logical, streamFastDataFrame{ + 
DataID: dataID, + Seq: seq, + Payload: chunk, + }) +} + +func (s *ServerCommon) encodeFastStreamBatchPayloadLogical(logical *LogicalConn, frames []streamFastDataFrame) ([]byte, error) { + if logical == nil { + return nil, errTransportDetached + } + if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil { + return encodeStreamFastBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), frames) + } + plain, err := encodeStreamFastBatchPlain(frames) + if err != nil { + return nil, err + } + return s.encryptTransportPayloadLogical(logical, plain) +} + +func (s *ServerCommon) sendFastStreamDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, stream *streamHandle, chunk []byte) error { if err := s.ensureServerTransportSendReady(transport); err != nil { return err } + if stream == nil { + return io.ErrClosedPipe + } if logical == nil && transport != nil { logical = transport.logicalConnSnapshot() } if logical == nil { return errTransportDetached } + dataID := stream.dataIDSnapshot() + fastPathVersion := stream.fastPathVersionSnapshot() + if binding := logical.transportBindingSnapshot(); binding != nil && binding.queueSnapshot() != nil && streamFastPathSupportsBatch(fastPathVersion) { + if sender := binding.serverStreamBatchSenderSnapshot(logical); sender != nil { + if maxPayload := streamAdaptiveFramePayloadLimit(binding); maxPayload > 0 && len(chunk) > maxPayload { + startSeq := stream.reserveOutboundDataSeqs(streamFastSplitFrameCount(len(chunk), maxPayload)) + return sender.submitFrames(ctx, fastPathVersion, buildStreamFastSplitFrames(dataID, startSeq, chunk, maxPayload)) + } + seq := stream.reserveOutboundDataSeqs(1) + if len(chunk) < streamFastBatchDirectLimit { + return sender.submitData(ctx, dataID, seq, fastPathVersion, chunk) + } + payload, err := s.encodeFastStreamDataPayloadLogical(logical, dataID, seq, chunk) + if err != nil { + return err + } + return sender.submitEncoded(ctx, fastPathVersion, 
payload) + } + } + seq := stream.reserveOutboundDataSeqs(1) payload, err := s.encodeFastStreamDataPayloadLogical(logical, dataID, seq, chunk) if err != nil { return err diff --git a/stream_fastpath_test.go b/stream_fastpath_test.go index 18dfab6..ea535ac 100644 --- a/stream_fastpath_test.go +++ b/stream_fastpath_test.go @@ -4,6 +4,8 @@ import ( "b612.me/stario" "context" "math" + "sync" + "sync/atomic" "testing" "time" ) @@ -31,6 +33,46 @@ func TestStreamFastDataFrameRoundTrip(t *testing.T) { } } +func TestStreamFastBatchPlainRoundTrip(t *testing.T) { + frames := []streamFastDataFrame{ + { + DataID: 11, + Seq: 7, + Payload: []byte("alpha"), + }, + { + DataID: 12, + Seq: 8, + Payload: []byte("beta"), + }, + } + wire, err := encodeStreamFastBatchPlain(frames) + if err != nil { + t.Fatalf("encodeStreamFastBatchPlain failed: %v", err) + } + decoded, matched, err := decodeStreamFastBatchPlain(wire) + if err != nil { + t.Fatalf("decodeStreamFastBatchPlain failed: %v", err) + } + if !matched { + t.Fatal("decodeStreamFastBatchPlain should match encoded batch") + } + if got, want := len(decoded), len(frames); got != want { + t.Fatalf("decoded frame count = %d, want %d", got, want) + } + for index := range frames { + if got, want := decoded[index].DataID, frames[index].DataID; got != want { + t.Fatalf("frame %d data id = %d, want %d", index, got, want) + } + if got, want := decoded[index].Seq, frames[index].Seq; got != want { + t.Fatalf("frame %d seq = %d, want %d", index, got, want) + } + if got, want := string(decoded[index].Payload), string(frames[index].Payload); got != want { + t.Fatalf("frame %d payload = %q, want %q", index, got, want) + } + } +} + func TestClientDispatchInboundTransportPayloadFastStream(t *testing.T) { client := NewClient().(*ClientCommon) if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { @@ -60,6 +102,165 @@ func TestClientDispatchInboundTransportPayloadFastStream(t *testing.T) { 
readStreamExactly(t, stream, "fast-payload", 2*time.Second) } +func TestClientDispatchInboundTransportPayloadFastStreamBatch(t *testing.T) { + client := NewClient().(*ClientCommon) + if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { + t.Fatalf("UseModernPSKClient failed: %v", err) + } + runtime := client.getStreamRuntime() + if runtime == nil { + t.Fatal("client stream runtime should not be nil") + } + streamA := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{ + StreamID: "fast-client-a", + DataID: 23, + FastPathVersion: streamFastPathVersionCurrent, + Channel: StreamDataChannel, + }, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot()) + streamB := newStreamHandle(context.Background(), runtime, clientFileScope(), StreamOpenRequest{ + StreamID: "fast-client-b", + DataID: 24, + FastPathVersion: streamFastPathVersionCurrent, + Channel: StreamDataChannel, + }, 0, nil, nil, 0, nil, nil, nil, runtime.configSnapshot()) + if err := runtime.register(clientFileScope(), streamA); err != nil { + t.Fatalf("register streamA failed: %v", err) + } + if err := runtime.register(clientFileScope(), streamB); err != nil { + t.Fatalf("register streamB failed: %v", err) + } + + payload, err := client.encodeFastStreamBatchPayload([]streamFastDataFrame{ + {DataID: 23, Seq: 1, Payload: []byte("fast-a")}, + {DataID: 24, Seq: 2, Payload: []byte("fast-b")}, + }) + if err != nil { + t.Fatalf("encodeFastStreamBatchPayload failed: %v", err) + } + if err := client.dispatchInboundTransportPayload(payload, time.Now()); err != nil { + t.Fatalf("dispatchInboundTransportPayload failed: %v", err) + } + + readStreamExactly(t, streamA, "fast-a", 2*time.Second) + readStreamExactly(t, streamB, "fast-b", 2*time.Second) +} + +func TestStreamBatchSenderEncodeRequestsCoalescesFastV2Frames(t *testing.T) { + var ( + singleCalls int + batchCalls [][]streamFastDataFrame + ) + sender := &streamBatchSender{ + codec: 
streamBatchCodec{ + encodeSingle: func(frame streamFastDataFrame) ([]byte, error) { + singleCalls++ + return []byte("single"), nil + }, + encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) { + cloned := make([]streamFastDataFrame, len(frames)) + copy(cloned, frames) + batchCalls = append(batchCalls, cloned) + return []byte("batch"), nil + }, + }, + } + payloads, err := sender.encodeRequests([]streamBatchRequest{ + { + frames: []streamFastDataFrame{{ + DataID: 101, + Seq: 1, + Payload: []byte("a"), + }}, + fastPathVersion: streamFastPathVersionV2, + }, + { + frames: []streamFastDataFrame{{ + DataID: 102, + Seq: 2, + Payload: []byte("b"), + }}, + fastPathVersion: streamFastPathVersionV2, + }, + }) + if err != nil { + t.Fatalf("encodeRequests failed: %v", err) + } + if got, want := len(payloads), 1; got != want { + t.Fatalf("payload count = %d, want %d", got, want) + } + if singleCalls != 0 { + t.Fatalf("single encode calls = %d, want 0", singleCalls) + } + if got, want := len(batchCalls), 1; got != want { + t.Fatalf("batch encode calls = %d, want %d", got, want) + } + if got, want := len(batchCalls[0]), 2; got != want { + t.Fatalf("batched frame count = %d, want %d", got, want) + } +} + +func TestStreamBatchSenderEncodeRequestsFlushesBeforePreEncodedPayload(t *testing.T) { + var ( + singleCalls int + batchCalls [][]streamFastDataFrame + ) + sender := &streamBatchSender{ + codec: streamBatchCodec{ + encodeSingle: func(frame streamFastDataFrame) ([]byte, error) { + singleCalls++ + return append([]byte("single-"), frame.Payload...), nil + }, + encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) { + cloned := make([]streamFastDataFrame, len(frames)) + copy(cloned, frames) + batchCalls = append(batchCalls, cloned) + return []byte("batch-2"), nil + }, + }, + } + payloads, err := sender.encodeRequests([]streamBatchRequest{ + { + frames: []streamFastDataFrame{ + {DataID: 101, Seq: 1, Payload: []byte("a")}, + {DataID: 102, Seq: 2, Payload: 
[]byte("b")}, + }, + fastPathVersion: streamFastPathVersionV2, + }, + { + encodedPayload: []byte("raw"), + hasEncoded: true, + fastPathVersion: streamFastPathVersionV2, + }, + { + frames: []streamFastDataFrame{ + {DataID: 103, Seq: 3, Payload: []byte("c")}, + }, + fastPathVersion: streamFastPathVersionV2, + }, + }) + if err != nil { + t.Fatalf("encodeRequests failed: %v", err) + } + if got, want := len(payloads), 3; got != want { + t.Fatalf("payload count = %d, want %d", got, want) + } + if got, want := string(payloads[0]), "batch-2"; got != want { + t.Fatalf("first payload = %q, want %q", got, want) + } + if got, want := string(payloads[1]), "raw"; got != want { + t.Fatalf("second payload = %q, want %q", got, want) + } + if got, want := string(payloads[2]), "single-c"; got != want { + t.Fatalf("third payload = %q, want %q", got, want) + } + if got, want := len(batchCalls), 1; got != want { + t.Fatalf("batch encode calls = %d, want %d", got, want) + } + if singleCalls != 1 { + t.Fatalf("single encode calls = %d, want %d", singleCalls, 1) + } +} + func TestClientPushMessageFastDispatchesDirectWithRuntimeDispatcher(t *testing.T) { client := NewClient().(*ClientCommon) if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil { @@ -110,3 +311,216 @@ func TestClientPushMessageFastDispatchesDirectWithRuntimeDispatcher(t *testing.T default: } } + +func TestStreamBatchSenderEncodeRequestsResetsBatchBytesAfterFlushBoundary(t *testing.T) { + var ( + singleCalls int + batchCalls [][]streamFastDataFrame + ) + sender := &streamBatchSender{ + codec: streamBatchCodec{ + encodeSingle: func(frame streamFastDataFrame) ([]byte, error) { + singleCalls++ + return []byte("single"), nil + }, + encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) { + cloned := make([]streamFastDataFrame, len(frames)) + copy(cloned, frames) + batchCalls = append(batchCalls, cloned) + return []byte("batch"), nil + }, + }, + } + largePayload := 
make([]byte, streamFastBatchMaxPlainBytes-streamFastBatchHeaderLen-streamFastBatchItemHeaderLen-128) + payloads, err := sender.encodeRequests([]streamBatchRequest{ + { + frames: []streamFastDataFrame{{ + DataID: 101, + Seq: 1, + Payload: largePayload, + }}, + fastPathVersion: streamFastPathVersionV2, + }, + { + encodedPayload: []byte("raw"), + hasEncoded: true, + fastPathVersion: streamFastPathVersionV2, + }, + { + frames: []streamFastDataFrame{{ + DataID: 202, + Seq: 1, + Payload: []byte("a"), + }}, + fastPathVersion: streamFastPathVersionV2, + }, + { + frames: []streamFastDataFrame{{ + DataID: 202, + Seq: 2, + Payload: []byte("b"), + }}, + fastPathVersion: streamFastPathVersionV2, + }, + }) + if err != nil { + t.Fatalf("encodeRequests failed: %v", err) + } + if got, want := len(payloads), 3; got != want { + t.Fatalf("payload count = %d, want %d", got, want) + } + if got, want := singleCalls, 1; got != want { + t.Fatalf("single encode calls = %d, want %d", got, want) + } + if got, want := len(batchCalls), 1; got != want { + t.Fatalf("batch encode calls = %d, want %d", got, want) + } + if got, want := len(batchCalls[0]), 2; got != want { + t.Fatalf("post-flush batched frame count = %d, want %d", got, want) + } +} + +func TestStreamBatchSenderEncodeRequestsUsesAdaptiveSoftLimit(t *testing.T) { + binding := &transportBinding{} + binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil) + + var ( + singleCalls int + batchCalls int + ) + sender := &streamBatchSender{ + binding: binding, + codec: streamBatchCodec{ + encodeSingle: func(frame streamFastDataFrame) ([]byte, error) { + singleCalls++ + return []byte("single"), nil + }, + encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) { + batchCalls++ + return []byte("batch"), nil + }, + }, + } + payload := make([]byte, 160*1024) + payloads, err := sender.encodeRequests([]streamBatchRequest{ + { + frames: []streamFastDataFrame{{ + DataID: 101, + Seq: 1, + Payload: payload, + }}, + 
fastPathVersion: streamFastPathVersionV2, + }, + { + frames: []streamFastDataFrame{{ + DataID: 102, + Seq: 2, + Payload: payload, + }}, + fastPathVersion: streamFastPathVersionV2, + }, + }) + if err != nil { + t.Fatalf("encodeRequests failed: %v", err) + } + if got, want := len(payloads), 2; got != want { + t.Fatalf("payload count = %d, want %d", got, want) + } + if got, want := singleCalls, 2; got != want { + t.Fatalf("single encode calls = %d, want %d", got, want) + } + if batchCalls != 0 { + t.Fatalf("batch encode calls = %d, want 0", batchCalls) + } +} + +func TestClientSendFastStreamDataSplitsLargeChunkWhenAdaptiveSoftLimitShrinks(t *testing.T) { + binding := newTransportBinding(&delayedWriteConn{}, stario.NewQueue()) + binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil) + + var ( + mu sync.Mutex + singleFrames []streamFastDataFrame + batchFrames [][]streamFastDataFrame + ) + binding.streamSender = newStreamBatchSender(binding, streamBatchCodec{ + encodeSingle: func(frame streamFastDataFrame) ([]byte, error) { + mu.Lock() + singleFrames = append(singleFrames, streamFastDataFrame{ + DataID: frame.DataID, + Seq: frame.Seq, + Payload: append([]byte(nil), frame.Payload...), + }) + mu.Unlock() + return []byte{1}, nil + }, + encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) { + cloned := make([]streamFastDataFrame, len(frames)) + for index := range frames { + cloned[index] = streamFastDataFrame{ + DataID: frames[index].DataID, + Seq: frames[index].Seq, + Payload: append([]byte(nil), frames[index].Payload...), + } + } + mu.Lock() + batchFrames = append(batchFrames, cloned) + mu.Unlock() + return []byte{2}, nil + }, + }, nil) + defer binding.stopBackgroundWorkers() + + client := NewClient().(*ClientCommon) + stopCtx, stopFn := context.WithCancel(context.Background()) + defer stopFn() + client.setClientSessionRuntime(&clientSessionRuntime{ + transport: binding, + transportAttached: true, + stopCtx: stopCtx, + stopFn: 
stopFn, + queue: binding.queueSnapshot(), + inboundDispatcher: newInboundDispatcher(), + suppressGoodByeOnStop: &atomic.Bool{}, + }) + + stream := &streamHandle{ + dataID: 41, + fastPathVersion: streamFastPathVersionV2, + } + chunk := make([]byte, streamFastBatchDirectLimit+128*1024) + for index := range chunk { + chunk[index] = byte(index) + } + + if err := client.sendFastStreamData(context.Background(), stream, chunk); err != nil { + t.Fatalf("sendFastStreamData failed: %v", err) + } + + expectedFrames := streamFastSplitFrameCount(len(chunk), streamAdaptiveFramePayloadLimit(binding)) + if got, want := int(stream.outboundSeq.Load()), expectedFrames; got != want { + t.Fatalf("reserved seq count = %d, want %d", got, want) + } + + mu.Lock() + defer mu.Unlock() + if len(batchFrames) != 0 { + t.Fatalf("batch encode calls = %d, want 0", len(batchFrames)) + } + if got, want := len(singleFrames), expectedFrames; got != want { + t.Fatalf("single frame count = %d, want %d", got, want) + } + rebuilt := make([]byte, 0, len(chunk)) + for index, frame := range singleFrames { + if got, want := frame.DataID, uint64(41); got != want { + t.Fatalf("frame %d data id = %d, want %d", index, got, want) + } + if got, want := frame.Seq, uint64(index+1); got != want { + t.Fatalf("frame %d seq = %d, want %d", index, got, want) + } + rebuilt = append(rebuilt, frame.Payload...) 
+ } + if string(rebuilt) != string(chunk) { + t.Fatal("rebuilt payload does not match original chunk") + } +} diff --git a/stream_flow.go b/stream_flow.go index 4342f19..80ed75c 100644 --- a/stream_flow.go +++ b/stream_flow.go @@ -3,15 +3,19 @@ package notify import ( "context" "sync" + "sync/atomic" ) type streamFlowController struct { - mu sync.Mutex - queue []*streamFlowRequest - inFlightBytes int - inFlightChunks int - windowBytes int - maxChunks int + mu sync.Mutex + + queue []*streamFlowRequest + + inFlightBytes atomic.Int64 + inFlightChunks atomic.Int64 + windowBytes atomic.Int64 + maxChunks atomic.Int64 + waiters atomic.Int32 } type streamFlowRequest struct { @@ -22,10 +26,10 @@ type streamFlowRequest struct { func newStreamFlowController(cfg streamConfig) *streamFlowController { cfg = normalizeStreamConfig(cfg) - return &streamFlowController{ - windowBytes: cfg.OutboundWindowBytes, - maxChunks: cfg.OutboundMaxInFlightChunks, - } + controller := &streamFlowController{} + controller.windowBytes.Store(int64(cfg.OutboundWindowBytes)) + controller.maxChunks.Store(int64(cfg.OutboundMaxInFlightChunks)) + return controller } func (c *streamFlowController) applyConfig(cfg streamConfig) { @@ -33,9 +37,12 @@ func (c *streamFlowController) applyConfig(cfg streamConfig) { return } cfg = normalizeStreamConfig(cfg) + c.windowBytes.Store(int64(cfg.OutboundWindowBytes)) + c.maxChunks.Store(int64(cfg.OutboundMaxInFlightChunks)) + if c.waiters.Load() == 0 { + return + } c.mu.Lock() - c.windowBytes = cfg.OutboundWindowBytes - c.maxChunks = cfg.OutboundMaxInFlightChunks c.drainLocked() c.mu.Unlock() } @@ -47,58 +54,32 @@ func (c *streamFlowController) acquire(ctx context.Context, size int) (func(), e if ctx == nil { ctx = context.Background() } + if c.tryAcquire(size) { + return c.releaseFunc(size), nil + } req := &streamFlowRequest{ size: size, ready: make(chan struct{}), } c.mu.Lock() + if c.tryAcquireLocked(size) { + c.mu.Unlock() + return c.releaseFunc(size), nil + } 
c.queue = append(c.queue, req) + c.waiters.Add(1) c.drainLocked() c.mu.Unlock() select { case <-req.ready: - released := false - return func() { - c.mu.Lock() - if released { - c.mu.Unlock() - return - } - released = true - c.inFlightBytes -= size - if c.inFlightBytes < 0 { - c.inFlightBytes = 0 - } - if c.inFlightChunks > 0 { - c.inFlightChunks-- - } - c.drainLocked() - c.mu.Unlock() - }, nil + return c.releaseFunc(size), nil case <-ctx.Done(): c.mu.Lock() if req.admitted { c.mu.Unlock() - released := false - return func() { - c.mu.Lock() - if released { - c.mu.Unlock() - return - } - released = true - c.inFlightBytes -= size - if c.inFlightBytes < 0 { - c.inFlightBytes = 0 - } - if c.inFlightChunks > 0 { - c.inFlightChunks-- - } - c.drainLocked() - c.mu.Unlock() - }, nil + return c.releaseFunc(size), nil } c.removeLocked(req) c.drainLocked() @@ -107,6 +88,128 @@ func (c *streamFlowController) acquire(ctx context.Context, size int) (func(), e } } +func (c *streamFlowController) tryAcquire(size int) bool { + if c == nil || size <= 0 { + return true + } + if c.waiters.Load() != 0 { + return false + } + return c.tryAcquireCAS(size) +} + +func (c *streamFlowController) tryAcquireLocked(size int) bool { + if c == nil || size <= 0 { + return true + } + if len(c.queue) != 0 { + return false + } + return c.tryAcquireCAS(size) +} + +func (c *streamFlowController) tryAcquireCAS(size int) bool { + if c == nil || size <= 0 { + return true + } + size64 := int64(size) + for { + window := c.windowBytes.Load() + maxChunks := c.maxChunks.Load() + inFlightBytes := c.inFlightBytes.Load() + inFlightChunks := c.inFlightChunks.Load() + + if maxChunks > 0 && inFlightChunks >= maxChunks { + return false + } + if window > 0 && inFlightBytes+size64 > window { + if !(inFlightBytes == 0 && inFlightChunks == 0) { + return false + } + } + if !c.inFlightBytes.CompareAndSwap(inFlightBytes, inFlightBytes+size64) { + continue + } + if c.addChunksCAS(1, maxChunks) { + return true + } + 
c.subBytesCAS(size64) + return false + } +} + +func (c *streamFlowController) addChunksCAS(delta int64, maxChunks int64) bool { + if c == nil || delta <= 0 { + return true + } + for { + current := c.inFlightChunks.Load() + if maxChunks > 0 && current+delta > maxChunks { + return false + } + if c.inFlightChunks.CompareAndSwap(current, current+delta) { + return true + } + } +} + +func (c *streamFlowController) subBytesCAS(delta int64) { + if c == nil || delta <= 0 { + return + } + for { + current := c.inFlightBytes.Load() + next := current - delta + if next < 0 { + next = 0 + } + if c.inFlightBytes.CompareAndSwap(current, next) { + return + } + } +} + +func (c *streamFlowController) subChunksCAS(delta int64) { + if c == nil || delta <= 0 { + return + } + for { + current := c.inFlightChunks.Load() + next := current - delta + if next < 0 { + next = 0 + } + if c.inFlightChunks.CompareAndSwap(current, next) { + return + } + } +} + +func (c *streamFlowController) releaseFunc(size int) func() { + released := false + return func() { + if released { + return + } + released = true + c.release(size) + } +} + +func (c *streamFlowController) release(size int) { + if c == nil || size <= 0 { + return + } + c.subBytesCAS(int64(size)) + c.subChunksCAS(1) + if c.waiters.Load() == 0 { + return + } + c.mu.Lock() + c.drainLocked() + c.mu.Unlock() +} + func (c *streamFlowController) removeLocked(req *streamFlowRequest) { if c == nil || req == nil { return @@ -118,6 +221,7 @@ func (c *streamFlowController) removeLocked(req *streamFlowRequest) { copy(c.queue[i:], c.queue[i+1:]) c.queue[len(c.queue)-1] = nil c.queue = c.queue[:len(c.queue)-1] + c.waiters.Add(-1) return } } @@ -132,18 +236,17 @@ func (c *streamFlowController) drainLocked() { c.queue = c.queue[1:] continue } - if c.maxChunks > 0 && c.inFlightChunks >= c.maxChunks { + if !c.canAdmitLocked(req.size) { return } - if !c.canAdmitLocked(req.size) { + if !c.tryAcquireCAS(req.size) { return } copy(c.queue[0:], c.queue[1:]) 
c.queue[len(c.queue)-1] = nil c.queue = c.queue[:len(c.queue)-1] + c.waiters.Add(-1) req.admitted = true - c.inFlightBytes += req.size - c.inFlightChunks++ close(req.ready) } } @@ -155,11 +258,18 @@ func (c *streamFlowController) canAdmitLocked(size int) bool { if size <= 0 { return true } - if c.windowBytes <= 0 { + window := c.windowBytes.Load() + chunks := c.inFlightChunks.Load() + bytes := c.inFlightBytes.Load() + maxChunks := c.maxChunks.Load() + if maxChunks > 0 && chunks >= maxChunks { + return false + } + if window <= 0 { return true } - if c.inFlightBytes+size <= c.windowBytes { + if bytes+int64(size) <= window { return true } - return c.inFlightBytes == 0 && c.inFlightChunks == 0 + return bytes == 0 && chunks == 0 } diff --git a/stream_flow_test.go b/stream_flow_test.go index da532ac..d9ca1c7 100644 --- a/stream_flow_test.go +++ b/stream_flow_test.go @@ -105,3 +105,51 @@ func TestStreamFlowControllerAdmitsRequestsFIFO(t *testing.T) { t.Fatalf("second admitted request = %d, want 3", second) } } + +func TestStreamFlowControllerTryAcquireDoesNotBypassQueuedWaiter(t *testing.T) { + controller := newStreamFlowController(streamConfig{ + ChunkSize: 4, + InboundQueueLimit: 1, + InboundBufferedBytesLimit: 4, + OutboundWindowBytes: 4, + OutboundMaxInFlightChunks: 1, + }) + + releaseFirst, err := controller.acquire(context.Background(), 4) + if err != nil { + t.Fatalf("first acquire failed: %v", err) + } + defer releaseFirst() + + waiterReady := make(chan struct{}) + waiterAcquired := make(chan func(), 1) + go func() { + close(waiterReady) + releaseSecond, err := controller.acquire(context.Background(), 4) + if err != nil { + t.Errorf("waiter acquire failed: %v", err) + return + } + waiterAcquired <- releaseSecond + }() + + <-waiterReady + time.Sleep(20 * time.Millisecond) + + if controller.tryAcquire(4) { + t.Fatal("tryAcquire should not bypass queued waiter") + } + + releaseFirst() + + var releaseSecond func() + select { + case releaseSecond = <-waiterAcquired: + 
case <-time.After(time.Second): + t.Fatal("timed out waiting for queued waiter to acquire") + } + if releaseSecond == nil { + t.Fatal("queued waiter returned nil release func") + } + releaseSecond() +} diff --git a/stream_runtime.go b/stream_runtime.go index 39bea38..a530a29 100644 --- a/stream_runtime.go +++ b/stream_runtime.go @@ -3,7 +3,6 @@ package notify import ( "context" "fmt" - "strconv" "strings" "sync" "sync/atomic" @@ -17,7 +16,7 @@ type streamRuntime struct { mu sync.RWMutex handler func(StreamAcceptInfo) error streams map[string]*streamHandle - data map[string]*streamHandle + data map[string]map[uint64]*streamHandle cfg streamConfig flow *streamFlowController } @@ -27,7 +26,7 @@ func newStreamRuntime(rolePrefix string) *streamRuntime { return &streamRuntime{ rolePrefix: rolePrefix, streams: make(map[string]*streamHandle), - data: make(map[string]*streamHandle), + data: make(map[string]map[uint64]*streamHandle), cfg: cfg, flow: newStreamFlowController(cfg), } @@ -72,18 +71,23 @@ func (r *streamRuntime) register(scope string, stream *streamHandle) error { if stream == nil || stream.id == "" { return errStreamIDEmpty } + scope = normalizeFileScope(scope) key := streamRuntimeKey(scope, stream.id) - dataKey := streamRuntimeDataKey(scope, stream.dataID) r.mu.Lock() defer r.mu.Unlock() if _, ok := r.streams[key]; ok { return errStreamAlreadyExists } if stream.dataID != 0 { - if _, ok := r.data[dataKey]; ok { + dataScope := r.data[scope] + if dataScope == nil { + dataScope = make(map[uint64]*streamHandle) + r.data[scope] = dataScope + } + if _, ok := dataScope[stream.dataID]; ok { return errStreamAlreadyExists } - r.data[dataKey] = stream + dataScope[stream.dataID] = stream } r.streams[key] = stream return nil @@ -104,10 +108,14 @@ func (r *streamRuntime) lookupByDataID(scope string, dataID uint64) (*streamHand if r == nil || dataID == 0 { return nil, false } - key := streamRuntimeDataKey(scope, dataID) + scope = normalizeFileScope(scope) r.mu.RLock() defer 
r.mu.RUnlock() - stream, ok := r.data[key] + dataScope := r.data[scope] + if dataScope == nil { + return nil, false + } + stream, ok := dataScope[dataID] return stream, ok } @@ -115,11 +123,17 @@ func (r *streamRuntime) remove(scope string, streamID string) { if r == nil || streamID == "" { return } + scope = normalizeFileScope(scope) key := streamRuntimeKey(scope, streamID) r.mu.Lock() defer r.mu.Unlock() if stream := r.streams[key]; stream != nil && stream.dataID != 0 { - delete(r.data, streamRuntimeDataKey(scope, stream.dataID)) + if dataScope := r.data[scope]; dataScope != nil { + delete(dataScope, stream.dataID) + if len(dataScope) == 0 { + delete(r.data, scope) + } + } } delete(r.streams, key) } @@ -131,6 +145,20 @@ func (r *streamRuntime) acquireOutbound(ctx context.Context, size int) (func(), return r.flow.acquire(ctx, size) } +func (r *streamRuntime) tryAcquireOutbound(size int) bool { + if r == nil || r.flow == nil { + return true + } + return r.flow.tryAcquire(size) +} + +func (r *streamRuntime) releaseOutbound(size int) { + if r == nil || r.flow == nil { + return + } + r.flow.release(size) +} + func (r *streamRuntime) snapshots() []StreamSnapshot { if r == nil { return nil @@ -182,10 +210,6 @@ func streamRuntimeKey(scope string, streamID string) string { return normalizeFileScope(scope) + "\x00" + streamID } -func streamRuntimeDataKey(scope string, dataID uint64) string { - return normalizeFileScope(scope) + "\x01" + strconv.FormatUint(dataID, 10) -} - func (c *ClientCommon) getStreamRuntime() *streamRuntime { if c == nil { return nil diff --git a/stream_shared_batch.go b/stream_shared_batch.go new file mode 100644 index 0000000..5b0fcac --- /dev/null +++ b/stream_shared_batch.go @@ -0,0 +1,145 @@ +package notify + +import "encoding/binary" + +const ( + streamFastPathVersionV1 = 1 + streamFastPathVersionV2 = 2 + streamFastPathVersionCurrent = streamFastPathVersionV2 +) + +const ( + streamFastBatchMagic = "NSB1" + streamFastBatchVersion = 1 + 
streamFastBatchHeaderLen = 12 + streamFastBatchItemHeaderLen = 24 + streamFastBatchMaxItems = 64 + streamFastBatchMaxPlainBytes = 8 * 1024 * 1024 +) + +func normalizeStreamFastPathVersion(version uint8) uint8 { + if version < streamFastPathVersionV1 { + return streamFastPathVersionV1 + } + if version > streamFastPathVersionCurrent { + return streamFastPathVersionCurrent + } + return version +} + +func negotiateStreamFastPathVersion(version uint8) uint8 { + return normalizeStreamFastPathVersion(version) +} + +func streamFastPathSupportsBatch(version uint8) bool { + return normalizeStreamFastPathVersion(version) >= streamFastPathVersionV2 +} + +func streamFastBatchFrameLen(frame streamFastDataFrame) int { + return streamFastBatchItemHeaderLen + len(frame.Payload) +} + +func streamFastBatchPlainLen(frames []streamFastDataFrame) int { + total := streamFastBatchHeaderLen + for _, frame := range frames { + total += streamFastBatchFrameLen(frame) + } + return total +} + +func encodeStreamFastBatchPlain(frames []streamFastDataFrame) ([]byte, error) { + if len(frames) == 0 { + return nil, errStreamFastPayloadInvalid + } + buf := make([]byte, streamFastBatchPlainLen(frames)) + if err := writeStreamFastBatchPlain(buf, frames); err != nil { + return nil, err + } + return buf, nil +} + +func encodeStreamFastBatchPayloadFast(encode transportFastPlainEncoder, secretKey []byte, frames []streamFastDataFrame) ([]byte, error) { + if encode == nil { + return nil, errTransportPayloadEncryptFailed + } + plainLen := streamFastBatchPlainLen(frames) + return encode(secretKey, plainLen, func(dst []byte) error { + return writeStreamFastBatchPlain(dst, frames) + }) +} + +func writeStreamFastBatchPlain(dst []byte, frames []streamFastDataFrame) error { + if len(frames) == 0 || len(dst) != streamFastBatchPlainLen(frames) { + return errStreamFastPayloadInvalid + } + copy(dst[:4], streamFastBatchMagic) + dst[4] = streamFastBatchVersion + binary.BigEndian.PutUint32(dst[8:12], uint32(len(frames))) + 
offset := streamFastBatchHeaderLen
+	for _, frame := range frames {
+		if frame.DataID == 0 {
+			return errStreamFastPayloadInvalid
+		}
+		dst[offset] = frame.Flags
+		binary.BigEndian.PutUint64(dst[offset+4:offset+12], frame.DataID)
+		binary.BigEndian.PutUint64(dst[offset+12:offset+20], frame.Seq)
+		binary.BigEndian.PutUint32(dst[offset+20:offset+24], uint32(len(frame.Payload)))
+		offset += streamFastBatchItemHeaderLen
+		copy(dst[offset:offset+len(frame.Payload)], frame.Payload)
+		offset += len(frame.Payload)
+	}
+	return nil
+}
+
+// decodeStreamFastBatchPlain parses a plaintext stream batch payload.
+//
+// The second result reports whether payload starts with streamFastBatchMagic.
+// When it does not, (nil, false, nil) is returned so the caller can try another
+// format. When it does, any structural problem (short header, unexpected
+// version, zero or impossible item count, truncated item, trailing bytes)
+// yields (nil, true, errStreamFastPayloadInvalid).
+//
+// Header bytes 5..7 and item-header bytes 1..3 are not inspected. The Payload
+// of every returned frame is a sub-slice of payload, not a copy, so the frames
+// are only valid for as long as the caller keeps payload unchanged.
+func decodeStreamFastBatchPlain(payload []byte) ([]streamFastDataFrame, bool, error) {
+	if len(payload) < 4 || string(payload[:4]) != streamFastBatchMagic {
+		return nil, false, nil
+	}
+	if len(payload) < streamFastBatchHeaderLen {
+		return nil, true, errStreamFastPayloadInvalid
+	}
+	if payload[4] != streamFastBatchVersion {
+		return nil, true, errStreamFastPayloadInvalid
+	}
+	// The count comes straight from the wire. Every item needs at least one item
+	// header, so a count above what the remaining bytes can hold is invalid.
+	// Rejecting it here, before make, keeps a tiny payload that claims billions
+	// of items from forcing a huge allocation. The comparison is done in uint64
+	// so the value cannot wrap where int is 32 bits wide.
+	rawCount := binary.BigEndian.Uint32(payload[8:12])
+	maxCount := (len(payload) - streamFastBatchHeaderLen) / streamFastBatchItemHeaderLen
+	if rawCount == 0 || uint64(rawCount) > uint64(maxCount) {
+		return nil, true, errStreamFastPayloadInvalid
+	}
+	count := int(rawCount)
+	frames := make([]streamFastDataFrame, 0, count)
+	offset := streamFastBatchHeaderLen
+	for index := 0; index < count; index++ {
+		if len(payload)-offset < streamFastBatchItemHeaderLen {
+			return nil, true, errStreamFastPayloadInvalid
+		}
+		flags := payload[offset]
+		dataID := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
+		seq := binary.BigEndian.Uint64(payload[offset+12 : offset+20])
+		payloadLen := int(binary.BigEndian.Uint32(payload[offset+20 : offset+24]))
+		offset += streamFastBatchItemHeaderLen
+		// payloadLen < 0 can only happen where int is 32 bits and the wire value
+		// has its top bit set.
+		if dataID == 0 || payloadLen < 0 || len(payload)-offset < payloadLen {
+			return nil, true, errStreamFastPayloadInvalid
+		}
+		frames = append(frames, streamFastDataFrame{
+			Flags:   flags,
+			DataID:  dataID,
+			Seq:     seq,
+			Payload: payload[offset : offset+payloadLen],
+		})
+		offset += payloadLen
+	}
+	// The items must consume the payload exactly; trailing bytes are an error.
+	if offset != len(payload) {
+		return nil, true, errStreamFastPayloadInvalid
+	}
+	return frames, true, nil
+}
+
+func
decodeStreamFastDataFrames(payload []byte) ([]streamFastDataFrame, bool, error) { + if frames, matched, err := decodeStreamFastBatchPlain(payload); matched { + return frames, true, err + } + frame, matched, err := decodeStreamFastDataFrame(payload) + if !matched || err != nil { + return nil, matched, err + } + return []streamFastDataFrame{frame}, true, nil +} diff --git a/stream_snapshot.go b/stream_snapshot.go index abb5594..2cb9c6b 100644 --- a/stream_snapshot.go +++ b/stream_snapshot.go @@ -7,48 +7,52 @@ import ( ) type StreamSnapshot struct { - ID string - DataID uint64 - Scope string - Channel StreamChannel - Metadata StreamMetadata - BindingOwner string - BindingAlive bool - BindingCurrent bool - BindingReason string - BindingError string - SessionEpoch uint64 - LogicalClientID string - LocalAddress string - RemoteAddress string - TransportGeneration uint64 - TransportAttached bool - TransportHasRuntimeConn bool - TransportCurrent bool - TransportDetachReason string - TransportDetachKind string - TransportDetachGeneration uint64 - TransportDetachError string - TransportDetachedAt time.Time - ReattachEligible bool - LocalClosed bool - LocalReadClosed bool - RemoteClosed bool - PeerReadClosed bool - BufferedChunks int - BufferedBytes int - ReadTimeout time.Duration - WriteTimeout time.Duration - BytesRead int64 - BytesWritten int64 - ReadCalls int64 - WriteCalls int64 - OpenedAt time.Time - LastReadAt time.Time - LastWriteAt time.Time - ReadDeadline time.Time - WriteDeadline time.Time - ResetError string + ID string + DataID uint64 + Scope string + Channel StreamChannel + Metadata StreamMetadata + BindingOwner string + BindingAlive bool + BindingCurrent bool + BindingReason string + BindingError string + BindingBulkAdaptiveSoftPayloadBytes int + BindingStreamAdaptiveSoftPayloadBytes int + BindingStreamAdaptiveWaitThresholdBytes int + BindingStreamAdaptiveFlushDelay time.Duration + SessionEpoch uint64 + LogicalClientID string + LocalAddress string + 
RemoteAddress string + TransportGeneration uint64 + TransportAttached bool + TransportHasRuntimeConn bool + TransportCurrent bool + TransportDetachReason string + TransportDetachKind string + TransportDetachGeneration uint64 + TransportDetachError string + TransportDetachedAt time.Time + ReattachEligible bool + LocalClosed bool + LocalReadClosed bool + RemoteClosed bool + PeerReadClosed bool + BufferedChunks int + BufferedBytes int + ReadTimeout time.Duration + WriteTimeout time.Duration + BytesRead int64 + BytesWritten int64 + ReadCalls int64 + WriteCalls int64 + OpenedAt time.Time + LastReadAt time.Time + LastWriteAt time.Time + ReadDeadline time.Time + WriteDeadline time.Time + ResetError string } type clientStreamSnapshotReader interface { diff --git a/timeout_error_test.go b/timeout_error.go similarity index 100% rename from timeout_error_test.go rename to timeout_error.go diff --git a/transport_binding.go b/transport_binding.go index 4995381..6c80a45 100644 --- a/transport_binding.go +++ b/transport_binding.go @@ -15,9 +15,14 @@ type transportBinding struct { queue *stario.StarQueue writeMu sync.Mutex + adaptiveTx adaptiveTxState + controlMu sync.Mutex controlSender *controlBatchSender + streamMu sync.Mutex + streamSender *streamBatchSender + bulkMu sync.Mutex bulkSender *bulkBatchSender } @@ -71,7 +76,7 @@ func (b *transportBinding) withConnWriteLockDeadline(deadline time.Time, fn func return fn(conn) } -func (b *transportBinding) bulkBatchSenderSnapshot() *bulkBatchSender { +func (b *transportBinding) bulkBatchSenderSnapshotWithCodec(codec bulkBatchCodec, writeTimeout func() time.Duration) *bulkBatchSender { if b == nil { return nil } @@ -80,10 +85,39 @@ func (b *transportBinding) bulkBatchSenderSnapshot() *bulkBatchSender { if b.bulkSender != nil { return b.bulkSender } - b.bulkSender = newBulkBatchSender(b) + b.bulkSender = newBulkBatchSender(b, codec, writeTimeout) return b.bulkSender } +func (b *transportBinding) clientBulkBatchSenderSnapshot(c 
*ClientCommon) *bulkBatchSender { + if b == nil || c == nil { + return nil + } + return b.bulkBatchSenderSnapshotWithCodec(bulkBatchCodec{ + encodeSingle: c.encodeBulkFastPayloadPooled, + encodeBatch: c.encodeBulkFastBatchPayloadPooled, + }, c.maxWriteTimeoutSnapshot) +} + +func (b *transportBinding) serverBulkBatchSenderSnapshot(logical *LogicalConn) *bulkBatchSender { + if b == nil || logical == nil { + return nil + } + server := logical.Server() + common, ok := server.(*ServerCommon) + if !ok || common == nil { + return nil + } + return b.bulkBatchSenderSnapshotWithCodec(bulkBatchCodec{ + encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) { + return common.encodeBulkFastPayloadLogicalPooled(logical, frame) + }, + encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) { + return common.encodeBulkFastBatchPayloadLogicalPooled(logical, frames) + }, + }, logical.maxWriteTimeoutSnapshot) +} + func (b *transportBinding) controlBatchSenderSnapshot() *controlBatchSender { if b == nil { return nil @@ -97,6 +131,48 @@ func (b *transportBinding) controlBatchSenderSnapshot() *controlBatchSender { return b.controlSender } +func (b *transportBinding) streamBatchSenderSnapshotWithCodec(codec streamBatchCodec, writeTimeout func() time.Duration) *streamBatchSender { + if b == nil { + return nil + } + b.streamMu.Lock() + defer b.streamMu.Unlock() + if b.streamSender != nil { + return b.streamSender + } + b.streamSender = newStreamBatchSender(b, codec, writeTimeout) + return b.streamSender +} + +func (b *transportBinding) clientStreamBatchSenderSnapshot(c *ClientCommon) *streamBatchSender { + if b == nil || c == nil { + return nil + } + return b.streamBatchSenderSnapshotWithCodec(streamBatchCodec{ + encodeSingle: c.encodeFastStreamPayload, + encodeBatch: c.encodeFastStreamBatchPayload, + }, c.maxWriteTimeoutSnapshot) +} + +func (b *transportBinding) serverStreamBatchSenderSnapshot(logical *LogicalConn) *streamBatchSender { + if b == nil || logical == nil { 
+ return nil + } + server := logical.Server() + common, ok := server.(*ServerCommon) + if !ok || common == nil { + return nil + } + return b.streamBatchSenderSnapshotWithCodec(streamBatchCodec{ + encodeSingle: func(frame streamFastDataFrame) ([]byte, error) { + return common.encodeFastStreamPayloadLogical(logical, frame) + }, + encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) { + return common.encodeFastStreamBatchPayloadLogical(logical, frames) + }, + }, logical.maxWriteTimeoutSnapshot) +} + func (b *transportBinding) stopBackgroundWorkers() { if b == nil { return @@ -104,12 +180,18 @@ func (b *transportBinding) stopBackgroundWorkers() { b.controlMu.Lock() controlSender := b.controlSender b.controlMu.Unlock() + b.streamMu.Lock() + streamSender := b.streamSender + b.streamMu.Unlock() b.bulkMu.Lock() bulkSender := b.bulkSender b.bulkMu.Unlock() if controlSender != nil { controlSender.stop() } + if streamSender != nil { + streamSender.stop() + } if bulkSender != nil { bulkSender.stop() } diff --git a/transport_binding_adaptive.go b/transport_binding_adaptive.go new file mode 100644 index 0000000..7562407 --- /dev/null +++ b/transport_binding_adaptive.go @@ -0,0 +1,383 @@ +package notify + +import ( + "sync" + "time" +) + +const ( + bulkAdaptiveSoftPayloadMinBytes = 256 * 1024 + bulkAdaptiveSoftPayloadFallbackBytes = 2 * 1024 * 1024 + bulkAdaptiveSoftPayloadStartBytes = bulkFastBatchMaxPlainBytes + bulkAdaptiveSoftPayloadTargetFlush = 4 * time.Millisecond + bulkAdaptiveSoftPayloadSlowFlush = 16 * time.Millisecond + bulkAdaptiveSoftPayloadMinSampleBytes = 256 * 1024 + bulkAdaptiveSoftPayloadGrowSuccesses = 8 + + streamAdaptiveSoftPayloadMinBytes = 256 * 1024 + streamAdaptiveSoftPayloadFallbackBytes = streamBatchMaxPayloadBytes + streamAdaptiveSoftPayloadStartBytes = streamBatchMaxPayloadBytes + streamAdaptiveSoftPayloadTargetFlush = 4 * time.Millisecond + streamAdaptiveSoftPayloadSlowFlush = 16 * time.Millisecond + streamAdaptiveSoftPayloadMinSampleBytes 
= 64 * 1024 + streamAdaptiveSoftPayloadGrowSuccesses = 8 + + streamAdaptiveWaitThresholdMinBytes = 32 * 1024 + streamAdaptiveFlushDelayMid = 25 * time.Microsecond +) + +var bulkAdaptiveSoftPayloadSteps = [...]int{ + 256 * 1024, + 512 * 1024, + 1024 * 1024, + 2 * 1024 * 1024, + 4 * 1024 * 1024, + bulkFastBatchMaxPlainBytes, +} + +var streamAdaptiveSoftPayloadSteps = [...]int{ + 256 * 1024, + 512 * 1024, + 1024 * 1024, + streamBatchMaxPayloadBytes, +} + +type adaptiveTxState struct { + mu sync.Mutex + + bulkSoftPayloadBytes int + bulkGoodputBytesPerS float64 + bulkGrowStreak int + + streamSoftPayloadBytes int + streamGoodputBytesPerS float64 + streamGrowStreak int +} + +func (b *transportBinding) bulkAdaptiveSoftPayloadBytesSnapshot() int { + if b == nil { + return bulkAdaptiveSoftPayloadFallbackBytes + } + return b.adaptiveTx.bulkSoftPayloadBytesSnapshot() +} + +func (b *transportBinding) observeBulkAdaptivePayloadWrite(payloadBytes int, elapsed time.Duration, timeout time.Duration, err error) { + if b == nil { + return + } + b.adaptiveTx.observeBulkPayloadWrite(payloadBytes, elapsed, timeout, err) +} + +func (b *transportBinding) streamAdaptiveSoftPayloadBytesSnapshot() int { + if b == nil { + return streamAdaptiveSoftPayloadFallbackBytes + } + return b.adaptiveTx.streamSoftPayloadBytesSnapshot() +} + +func (b *transportBinding) streamAdaptiveWaitThresholdBytesSnapshot() int { + if b == nil { + return streamBatchWaitThreshold + } + return b.adaptiveTx.streamWaitThresholdBytesSnapshot() +} + +func (b *transportBinding) streamAdaptiveFlushDelaySnapshot() time.Duration { + if b == nil { + return streamBatchMaxFlushDelay + } + return b.adaptiveTx.streamFlushDelaySnapshot() +} + +func (b *transportBinding) observeStreamAdaptivePayloadWrite(payloadBytes int, elapsed time.Duration, timeout time.Duration, err error) { + if b == nil { + return + } + b.adaptiveTx.observeStreamPayloadWrite(payloadBytes, elapsed, timeout, err) +} + +func (s *adaptiveTxState) 
bulkSoftPayloadBytesSnapshot() int { + if s == nil { + return bulkAdaptiveSoftPayloadStartBytes + } + s.mu.Lock() + defer s.mu.Unlock() + return s.bulkSoftPayloadBytesLocked() +} + +func (s *adaptiveTxState) observeBulkPayloadWrite(payloadBytes int, elapsed time.Duration, timeout time.Duration, err error) { + if s == nil || payloadBytes <= 0 { + return + } + s.mu.Lock() + defer s.mu.Unlock() + + current := s.bulkSoftPayloadBytesLocked() + target, hasSample := s.observeBulkGoodputLocked(payloadBytes, elapsed) + nearTimeout := timeout > 0 && elapsed >= (timeout*3)/4 + if isTimeoutLikeError(err) || nearTimeout { + s.bulkGrowStreak = 0 + if hasSample && target < current { + s.bulkSoftPayloadBytes = target + return + } + s.bulkSoftPayloadBytes = previousBulkAdaptiveSoftPayloadStep(current) + return + } + if err != nil { + s.bulkGrowStreak = 0 + return + } + if !hasSample { + return + } + if elapsed >= bulkAdaptiveSoftPayloadSlowFlush { + s.bulkGrowStreak = 0 + if target < current { + s.bulkSoftPayloadBytes = target + return + } + s.bulkSoftPayloadBytes = previousBulkAdaptiveSoftPayloadStep(current) + return + } + if target > current { + s.bulkGrowStreak++ + if s.bulkGrowStreak >= bulkAdaptiveSoftPayloadGrowSuccesses { + s.bulkSoftPayloadBytes = nextBulkAdaptiveSoftPayloadStep(current, target) + s.bulkGrowStreak = 0 + } + return + } + if target < current && elapsed >= bulkAdaptiveSoftPayloadTargetFlush*2 { + s.bulkSoftPayloadBytes = target + s.bulkGrowStreak = 0 + return + } + s.bulkGrowStreak = 0 +} + +func (s *adaptiveTxState) bulkSoftPayloadBytesLocked() int { + if s.bulkSoftPayloadBytes == 0 { + s.bulkSoftPayloadBytes = bulkAdaptiveSoftPayloadStartBytes + } + return normalizeBulkAdaptiveSoftPayloadBytes(s.bulkSoftPayloadBytes) +} + +func (s *adaptiveTxState) streamSoftPayloadBytesSnapshot() int { + if s == nil { + return streamAdaptiveSoftPayloadStartBytes + } + s.mu.Lock() + defer s.mu.Unlock() + return s.streamSoftPayloadBytesLocked() +} + +func (s 
*adaptiveTxState) streamWaitThresholdBytesSnapshot() int { + if s == nil { + return streamBatchWaitThreshold + } + s.mu.Lock() + defer s.mu.Unlock() + return streamAdaptiveWaitThresholdBytesForSoftPayload(s.streamSoftPayloadBytesLocked()) +} + +func (s *adaptiveTxState) streamFlushDelaySnapshot() time.Duration { + if s == nil { + return streamBatchMaxFlushDelay + } + s.mu.Lock() + defer s.mu.Unlock() + return streamAdaptiveFlushDelayForSoftPayload(s.streamSoftPayloadBytesLocked()) +} + +func (s *adaptiveTxState) observeStreamPayloadWrite(payloadBytes int, elapsed time.Duration, timeout time.Duration, err error) { + if s == nil || payloadBytes <= 0 { + return + } + s.mu.Lock() + defer s.mu.Unlock() + + current := s.streamSoftPayloadBytesLocked() + target, hasSample := s.observeStreamGoodputLocked(payloadBytes, elapsed) + nearTimeout := timeout > 0 && elapsed >= (timeout*3)/4 + if isTimeoutLikeError(err) || nearTimeout { + s.streamGrowStreak = 0 + if hasSample && target < current { + s.streamSoftPayloadBytes = target + return + } + s.streamSoftPayloadBytes = previousStreamAdaptiveSoftPayloadStep(current) + return + } + if err != nil { + s.streamGrowStreak = 0 + return + } + if !hasSample { + return + } + if elapsed >= streamAdaptiveSoftPayloadSlowFlush { + s.streamGrowStreak = 0 + if target < current { + s.streamSoftPayloadBytes = target + return + } + s.streamSoftPayloadBytes = previousStreamAdaptiveSoftPayloadStep(current) + return + } + if target > current { + s.streamGrowStreak++ + if s.streamGrowStreak >= streamAdaptiveSoftPayloadGrowSuccesses { + s.streamSoftPayloadBytes = nextStreamAdaptiveSoftPayloadStep(current, target) + s.streamGrowStreak = 0 + } + return + } + if target < current && elapsed >= streamAdaptiveSoftPayloadTargetFlush*2 { + s.streamSoftPayloadBytes = target + s.streamGrowStreak = 0 + return + } + s.streamGrowStreak = 0 +} + +func (s *adaptiveTxState) streamSoftPayloadBytesLocked() int { + if s.streamSoftPayloadBytes == 0 { + 
s.streamSoftPayloadBytes = streamAdaptiveSoftPayloadStartBytes + } + return normalizeStreamAdaptiveSoftPayloadBytes(s.streamSoftPayloadBytes) +} + +func (s *adaptiveTxState) observeBulkGoodputLocked(payloadBytes int, elapsed time.Duration) (int, bool) { + if payloadBytes < bulkAdaptiveSoftPayloadMinSampleBytes || elapsed <= 0 { + return 0, false + } + sample := float64(payloadBytes) / elapsed.Seconds() + if sample <= 0 { + return 0, false + } + if s.bulkGoodputBytesPerS <= 0 { + s.bulkGoodputBytesPerS = sample + } else { + const alpha = 0.25 + s.bulkGoodputBytesPerS = s.bulkGoodputBytesPerS*(1-alpha) + sample*alpha + } + target := int(s.bulkGoodputBytesPerS * bulkAdaptiveSoftPayloadTargetFlush.Seconds()) + return normalizeBulkAdaptiveSoftPayloadBytes(target), true +} + +func (s *adaptiveTxState) observeStreamGoodputLocked(payloadBytes int, elapsed time.Duration) (int, bool) { + if payloadBytes < streamAdaptiveSoftPayloadMinSampleBytes || elapsed <= 0 { + return 0, false + } + sample := float64(payloadBytes) / elapsed.Seconds() + if sample <= 0 { + return 0, false + } + if s.streamGoodputBytesPerS <= 0 { + s.streamGoodputBytesPerS = sample + } else { + const alpha = 0.25 + s.streamGoodputBytesPerS = s.streamGoodputBytesPerS*(1-alpha) + sample*alpha + } + target := int(s.streamGoodputBytesPerS * streamAdaptiveSoftPayloadTargetFlush.Seconds()) + return normalizeStreamAdaptiveSoftPayloadBytes(target), true +} + +func normalizeBulkAdaptiveSoftPayloadBytes(size int) int { + if size <= bulkAdaptiveSoftPayloadMinBytes { + return bulkAdaptiveSoftPayloadMinBytes + } + for _, step := range bulkAdaptiveSoftPayloadSteps { + if size <= step { + return step + } + } + return bulkAdaptiveSoftPayloadStartBytes +} + +func previousBulkAdaptiveSoftPayloadStep(current int) int { + current = normalizeBulkAdaptiveSoftPayloadBytes(current) + for index := len(bulkAdaptiveSoftPayloadSteps) - 1; index >= 0; index-- { + step := bulkAdaptiveSoftPayloadSteps[index] + if current > step { + return 
step + } + } + return bulkAdaptiveSoftPayloadMinBytes +} + +func nextBulkAdaptiveSoftPayloadStep(current int, target int) int { + current = normalizeBulkAdaptiveSoftPayloadBytes(current) + target = normalizeBulkAdaptiveSoftPayloadBytes(target) + for _, step := range bulkAdaptiveSoftPayloadSteps { + if step > current { + if step > target { + return target + } + return step + } + } + return bulkAdaptiveSoftPayloadStartBytes +} + +func normalizeStreamAdaptiveSoftPayloadBytes(size int) int { + if size <= streamAdaptiveSoftPayloadMinBytes { + return streamAdaptiveSoftPayloadMinBytes + } + for _, step := range streamAdaptiveSoftPayloadSteps { + if size <= step { + return step + } + } + return streamAdaptiveSoftPayloadStartBytes +} + +func previousStreamAdaptiveSoftPayloadStep(current int) int { + current = normalizeStreamAdaptiveSoftPayloadBytes(current) + for index := len(streamAdaptiveSoftPayloadSteps) - 1; index >= 0; index-- { + step := streamAdaptiveSoftPayloadSteps[index] + if current > step { + return step + } + } + return streamAdaptiveSoftPayloadMinBytes +} + +func nextStreamAdaptiveSoftPayloadStep(current int, target int) int { + current = normalizeStreamAdaptiveSoftPayloadBytes(current) + target = normalizeStreamAdaptiveSoftPayloadBytes(target) + for _, step := range streamAdaptiveSoftPayloadSteps { + if step > current { + if step > target { + return target + } + return step + } + } + return streamAdaptiveSoftPayloadStartBytes +} + +func streamAdaptiveWaitThresholdBytesForSoftPayload(size int) int { + size = normalizeStreamAdaptiveSoftPayloadBytes(size) + threshold := size / 16 + if threshold < streamAdaptiveWaitThresholdMinBytes { + return streamAdaptiveWaitThresholdMinBytes + } + if threshold > streamBatchWaitThreshold { + return streamBatchWaitThreshold + } + return threshold +} + +func streamAdaptiveFlushDelayForSoftPayload(size int) time.Duration { + size = normalizeStreamAdaptiveSoftPayloadBytes(size) + switch { + case size >= 
streamAdaptiveSoftPayloadStartBytes: + return streamBatchMaxFlushDelay + case size >= 1024*1024: + return streamAdaptiveFlushDelayMid + default: + return 0 + } +} diff --git a/transport_binding_adaptive_test.go b/transport_binding_adaptive_test.go new file mode 100644 index 0000000..f5ac4e4 --- /dev/null +++ b/transport_binding_adaptive_test.go @@ -0,0 +1,160 @@ +package notify + +import ( + "b612.me/stario" + "net" + "testing" + "time" +) + +func TestTransportBindingAdaptiveBulkSoftPayloadStartsAggressive(t *testing.T) { + binding := &transportBinding{} + if got, want := binding.bulkAdaptiveSoftPayloadBytesSnapshot(), bulkAdaptiveSoftPayloadStartBytes; got != want { + t.Fatalf("adaptive bulk soft payload = %d, want %d", got, want) + } +} + +func TestTransportBindingAdaptiveBulkSoftPayloadShrinksAfterSlowWrite(t *testing.T) { + binding := &transportBinding{} + binding.observeBulkAdaptivePayloadWrite(8*1024*1024, 640*time.Millisecond, 0, nil) + if got, want := binding.bulkAdaptiveSoftPayloadBytesSnapshot(), bulkAdaptiveSoftPayloadMinBytes; got != want { + t.Fatalf("adaptive bulk soft payload = %d, want %d", got, want) + } +} + +func TestTransportBindingAdaptiveBulkSoftPayloadRecoversAfterGoodWrites(t *testing.T) { + binding := &transportBinding{} + binding.observeBulkAdaptivePayloadWrite(8*1024*1024, 640*time.Millisecond, 0, nil) + samples := bulkAdaptiveSoftPayloadGrowSuccesses * (len(bulkAdaptiveSoftPayloadSteps) - 1) + for i := 0; i < samples; i++ { + binding.observeBulkAdaptivePayloadWrite(8*1024*1024, 6*time.Millisecond, 0, nil) + } + if got, want := binding.bulkAdaptiveSoftPayloadBytesSnapshot(), bulkAdaptiveSoftPayloadStartBytes; got != want { + t.Fatalf("adaptive bulk soft payload = %d, want %d", got, want) + } +} + +type delayedWriteConn struct { + delay time.Duration +} + +func (c *delayedWriteConn) Read([]byte) (int, error) { return 0, net.ErrClosed } +func (c *delayedWriteConn) Write(p []byte) (int, error) { time.Sleep(c.delay); return len(p), nil } 
+func (c *delayedWriteConn) Close() error { return nil } +func (c *delayedWriteConn) LocalAddr() net.Addr { return nil } +func (c *delayedWriteConn) RemoteAddr() net.Addr { return nil } +func (c *delayedWriteConn) SetDeadline(time.Time) error { return nil } +func (c *delayedWriteConn) SetReadDeadline(time.Time) error { return nil } +func (c *delayedWriteConn) SetWriteDeadline(time.Time) error { return nil } + +func TestBulkBatchSenderFlushAdaptsToSlowWrites(t *testing.T) { + binding := newTransportBinding(&delayedWriteConn{delay: 20 * time.Millisecond}, stario.NewQueue()) + sender := newTestBulkBatchSender(binding) + defer sender.stop() + + payload := make([]byte, 512*1024) + err := sender.flush([]bulkBatchRequest{ + { + frames: []bulkFastFrame{ + {Type: bulkFastPayloadTypeData, DataID: 101, Seq: 1, Payload: payload}, + {Type: bulkFastPayloadTypeData, DataID: 101, Seq: 2, Payload: payload}, + }, + fastPathVersion: bulkFastPathVersionV2, + }, + { + frames: []bulkFastFrame{ + {Type: bulkFastPayloadTypeData, DataID: 202, Seq: 1, Payload: payload}, + {Type: bulkFastPayloadTypeData, DataID: 202, Seq: 2, Payload: payload}, + }, + fastPathVersion: bulkFastPathVersionV2, + }, + }) + if err != nil { + t.Fatalf("flush failed: %v", err) + } + if got := binding.bulkAdaptiveSoftPayloadBytesSnapshot(); got >= bulkAdaptiveSoftPayloadStartBytes { + t.Fatalf("adaptive bulk soft payload = %d, want smaller than %d after slow write", got, bulkAdaptiveSoftPayloadStartBytes) + } +} + +func TestTransportBindingAdaptiveStreamStartsAggressive(t *testing.T) { + binding := &transportBinding{} + if got, want := binding.streamAdaptiveSoftPayloadBytesSnapshot(), streamAdaptiveSoftPayloadStartBytes; got != want { + t.Fatalf("adaptive stream soft payload = %d, want %d", got, want) + } + if got, want := binding.streamAdaptiveWaitThresholdBytesSnapshot(), streamBatchWaitThreshold; got != want { + t.Fatalf("adaptive stream wait threshold = %d, want %d", got, want) + } + if got, want := 
binding.streamAdaptiveFlushDelaySnapshot(), streamBatchMaxFlushDelay; got != want { + t.Fatalf("adaptive stream flush delay = %s, want %s", got, want) + } +} + +func TestTransportBindingAdaptiveStreamShrinksAfterSlowWrite(t *testing.T) { + binding := &transportBinding{} + binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil) + if got, want := binding.streamAdaptiveSoftPayloadBytesSnapshot(), streamAdaptiveSoftPayloadMinBytes; got != want { + t.Fatalf("adaptive stream soft payload = %d, want %d", got, want) + } + if got, want := binding.streamAdaptiveWaitThresholdBytesSnapshot(), streamAdaptiveWaitThresholdMinBytes; got != want { + t.Fatalf("adaptive stream wait threshold = %d, want %d", got, want) + } + if got := binding.streamAdaptiveFlushDelaySnapshot(); got != 0 { + t.Fatalf("adaptive stream flush delay = %s, want 0", got) + } +} + +func TestTransportBindingAdaptiveStreamRecoversAfterGoodWrites(t *testing.T) { + binding := &transportBinding{} + binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 640*time.Millisecond, 0, nil) + samples := streamAdaptiveSoftPayloadGrowSuccesses * (len(streamAdaptiveSoftPayloadSteps) - 1) + for i := 0; i < samples; i++ { + binding.observeStreamAdaptivePayloadWrite(2*1024*1024, 6*time.Millisecond, 0, nil) + } + if got, want := binding.streamAdaptiveSoftPayloadBytesSnapshot(), streamAdaptiveSoftPayloadStartBytes; got != want { + t.Fatalf("adaptive stream soft payload = %d, want %d", got, want) + } + if got, want := binding.streamAdaptiveWaitThresholdBytesSnapshot(), streamBatchWaitThreshold; got != want { + t.Fatalf("adaptive stream wait threshold = %d, want %d", got, want) + } + if got, want := binding.streamAdaptiveFlushDelaySnapshot(), streamBatchMaxFlushDelay; got != want { + t.Fatalf("adaptive stream flush delay = %s, want %s", got, want) + } +} + +func TestStreamBatchSenderFlushAdaptsToSlowWrites(t *testing.T) { + binding := newTransportBinding(&delayedWriteConn{delay: 20 * time.Millisecond}, 
stario.NewQueue()) + sender := newTestStreamBatchSender(binding, nil) + defer sender.stop() + + payload := make([]byte, 512*1024) + err := sender.flush([]streamBatchRequest{ + { + frame: streamFastDataFrame{ + DataID: 101, + Seq: 1, + Payload: payload, + }, + hasFrame: true, + fastPathVersion: streamFastPathVersionV2, + }, + { + frame: streamFastDataFrame{ + DataID: 202, + Seq: 1, + Payload: payload, + }, + hasFrame: true, + fastPathVersion: streamFastPathVersionV2, + }, + }) + if err != nil { + t.Fatalf("flush failed: %v", err) + } + if got := binding.streamAdaptiveSoftPayloadBytesSnapshot(); got >= streamAdaptiveSoftPayloadStartBytes { + t.Fatalf("adaptive stream soft payload = %d, want smaller than %d after slow write", got, streamAdaptiveSoftPayloadStartBytes) + } + if got := binding.streamAdaptiveWaitThresholdBytesSnapshot(); got >= streamBatchWaitThreshold { + t.Fatalf("adaptive stream wait threshold = %d, want smaller than %d after slow write", got, streamBatchWaitThreshold) + } +} diff --git a/transport_codec.go b/transport_codec.go index 2895b1d..e810a08 100644 --- a/transport_codec.go +++ b/transport_codec.go @@ -9,12 +9,93 @@ var ( errTransportPayloadDecryptFailed = errors.New("transport payload decrypt failed") ) +func encryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgEn func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) { + if runtime != nil { + encoded, err := runtime.sealPlainPayload(data) + if err != nil { + return nil, errTransportPayloadEncryptFailed + } + return encoded, nil + } + if msgEn == nil { + return nil, errTransportPayloadEncryptFailed + } + encoded := msgEn(secretKey, data) + if encoded == nil && len(data) != 0 { + return nil, errTransportPayloadEncryptFailed + } + return encoded, nil +} + +func decryptTransportPayloadCodec(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, error) { + if runtime != nil { + plain, err := 
runtime.openPayload(data) + if err != nil { + return nil, errTransportPayloadDecryptFailed + } + return plain, nil + } + if msgDe == nil { + return nil, errTransportPayloadDecryptFailed + } + plain := msgDe(secretKey, data) + if plain == nil && len(data) != 0 { + return nil, errTransportPayloadDecryptFailed + } + return plain, nil +} + +func decryptTransportPayloadCodecPooled(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte, release func()) ([]byte, func(), error) { + if runtime != nil { + plain, plainRelease, err := runtime.openPayloadPooled(data, release) + if err != nil { + return nil, nil, errTransportPayloadDecryptFailed + } + return plain, plainRelease, nil + } + if msgDe == nil { + if release != nil { + release() + } + return nil, nil, errTransportPayloadDecryptFailed + } + plain := msgDe(secretKey, data) + if release != nil { + release() + } + if plain == nil && len(data) != 0 { + return nil, nil, errTransportPayloadDecryptFailed + } + return plain, nil, nil +} + +func decryptTransportPayloadCodecOwnedPooled(runtime *modernPSKCodecRuntime, msgDe func([]byte, []byte) []byte, secretKey []byte, data []byte) ([]byte, func(), error) { + if runtime != nil { + plain, plainRelease, err := runtime.openPayloadOwnedPooled(data) + if err != nil { + return nil, nil, errTransportPayloadDecryptFailed + } + return plain, plainRelease, nil + } + if msgDe == nil { + return nil, nil, errTransportPayloadDecryptFailed + } + plain := msgDe(secretKey, data) + if plain == nil && len(data) != 0 { + return nil, nil, errTransportPayloadDecryptFailed + } + return plain, nil, nil +} + func (c *ClientCommon) encodeTransferMsg(msg TransferMsg) ([]byte, error) { data, err := c.sequenceEn(msg) if err != nil { return nil, err } - data = c.msgEn(c.SecretKey, data) + data, err = c.encryptTransportPayload(data) + if err != nil { + return nil, err + } queue := c.clientQueueSnapshot() if queue == nil { return nil, errClientSessionQueueUnavailable 
@@ -23,7 +104,11 @@ func (c *ClientCommon) encodeTransferMsg(msg TransferMsg) ([]byte, error) { } func (c *ClientCommon) decodeTransferMsg(data []byte) (TransferMsg, error) { - msg, err := c.sequenceDe(c.msgDe(c.SecretKey, data)) + plain, err := c.decryptTransportPayload(data) + if err != nil { + return TransferMsg{}, err + } + msg, err := c.sequenceDe(plain) if err != nil { return TransferMsg{}, err } @@ -41,7 +126,10 @@ func (s *ServerCommon) encodeTransferMsg(c *ClientConn, msg TransferMsg) ([]byte } msgEn := c.clientConnMsgEnSnapshot() secretKey := c.clientConnSecretKeySnapshot() - data = msgEn(secretKey, data) + data, err = encryptTransportPayloadCodec(c.clientConnModernPSKRuntimeSnapshot(), msgEn, secretKey, data) + if err != nil { + return nil, err + } queue := s.serverQueueSnapshot() if queue == nil { return nil, errServerSessionQueueUnavailable @@ -52,7 +140,11 @@ func (s *ServerCommon) encodeTransferMsg(c *ClientConn, msg TransferMsg) ([]byte func (s *ServerCommon) decodeTransferMsg(c *ClientConn, data []byte) (TransferMsg, error) { msgDe := c.clientConnMsgDeSnapshot() secretKey := c.clientConnSecretKeySnapshot() - msg, err := s.sequenceDe(msgDe(secretKey, data)) + plain, err := decryptTransportPayloadCodec(c.clientConnModernPSKRuntimeSnapshot(), msgDe, secretKey, data) + if err != nil { + return TransferMsg{}, err + } + msg, err := s.sequenceDe(plain) if err != nil { return TransferMsg{}, err } @@ -80,11 +172,7 @@ func (c *ClientCommon) encodeEnvelopePlain(env Envelope) ([]byte, error) { } func (c *ClientCommon) encryptTransportPayload(data []byte) ([]byte, error) { - encoded := c.msgEn(c.SecretKey, data) - if encoded == nil && len(data) != 0 { - return nil, errTransportPayloadEncryptFailed - } - return encoded, nil + return encryptTransportPayloadCodec(c.modernPSKRuntime, c.msgEn, c.SecretKey, data) } func (c *ClientCommon) encodeEnvelope(env Envelope) ([]byte, error) { @@ -108,11 +196,7 @@ func (c *ClientCommon) decodeEnvelope(data []byte) (Envelope, 
error) { } func (c *ClientCommon) decryptTransportPayload(data []byte) ([]byte, error) { - plain := c.msgDe(c.SecretKey, data) - if plain == nil && len(data) != 0 { - return nil, errTransportPayloadDecryptFailed - } - return plain, nil + return decryptTransportPayloadCodec(c.modernPSKRuntime, c.msgDe, c.SecretKey, data) } func (c *ClientCommon) decodeEnvelopePlain(data []byte) (Envelope, error) { @@ -167,11 +251,7 @@ func (s *ServerCommon) encryptTransportPayloadLogical(logical *LogicalConn, data if msgEn == nil { return nil, errTransportDetached } - encoded := msgEn(secretKey, data) - if encoded == nil && len(data) != 0 { - return nil, errTransportPayloadEncryptFailed - } - return encoded, nil + return encryptTransportPayloadCodec(logical.modernPSKRuntimeSnapshot(), msgEn, secretKey, data) } func (s *ServerCommon) encodeEnvelopeLogical(logical *LogicalConn, env Envelope) ([]byte, error) { @@ -210,11 +290,7 @@ func (s *ServerCommon) decryptTransportPayloadLogical(logical *LogicalConn, data if msgDe == nil { return nil, errTransportDetached } - plain := msgDe(secretKey, data) - if plain == nil && len(data) != 0 { - return nil, errTransportPayloadDecryptFailed - } - return plain, nil + return decryptTransportPayloadCodec(logical.modernPSKRuntimeSnapshot(), msgDe, secretKey, data) } func (s *ServerCommon) decodeEnvelopePlain(data []byte) (Envelope, error) { diff --git a/transport_conn.go b/transport_conn.go index 60c6edc..439083e 100644 --- a/transport_conn.go +++ b/transport_conn.go @@ -1,6 +1,7 @@ package notify import ( + "b612.me/stario" "context" "errors" "net" @@ -16,7 +17,7 @@ type TransportConn struct { } const ( - transportStreamReadBufferSize = 256 * 1024 + transportStreamReadBufferSize = 1024 * 1024 transportPacketReadBufferSize = 64 * 1024 ) @@ -28,6 +29,17 @@ func packetReadBuffer() []byte { return make([]byte, transportPacketReadBufferSize) } +func newTransportFrameReader(conn net.Conn, queue *stario.StarQueue) *stario.FrameReader { + reader := 
stario.NewFrameReader(conn, queue) + if reader == nil { + return nil + } + if transportStreamReadBufferSize > stario.DefaultFrameReaderBufferSize { + reader.SetReadBufferSize(transportStreamReadBufferSize) + } + return reader +} + type TransportConnRuntimeSnapshot struct { ClientID string RemoteAddress string diff --git a/transport_write_test.go b/transport_write_test.go index 0764185..d541988 100644 --- a/transport_write_test.go +++ b/transport_write_test.go @@ -62,18 +62,72 @@ func TestWriteFullToConnSerializesConcurrentWriters(t *testing.T) { } } +func newTestBulkBatchSender(binding *transportBinding) *bulkBatchSender { + return newTestBulkBatchSenderWithWriteTimeout(binding, nil) +} + +func newTestBulkBatchSenderWithWriteTimeout(binding *transportBinding, writeTimeout func() time.Duration) *bulkBatchSender { + return newBulkBatchSender(binding, bulkBatchCodec{ + encodeSingle: func(frame bulkFastFrame) ([]byte, func(), error) { + return append([]byte(nil), frame.Payload...), nil, nil + }, + encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) { + payload, err := encodeBulkFastBatchPlain(frames) + return payload, nil, err + }, + }, writeTimeout) +} + +func newTestStreamBatchSender(binding *transportBinding, writeTimeout func() time.Duration) *streamBatchSender { + return newStreamBatchSender(binding, streamBatchCodec{ + encodeSingle: func(frame streamFastDataFrame) ([]byte, error) { + return encodeStreamFastFramePayload(frame) + }, + encodeBatch: func(frames []streamFastDataFrame) ([]byte, error) { + return encodeStreamFastBatchPlain(frames) + }, + }, writeTimeout) +} + func TestBulkBatchSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) { left, right := net.Pipe() defer left.Close() defer right.Close() - sender := newBulkBatchSender(newTransportBinding(left, stario.NewQueue())) - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() + sender := 
newTestBulkBatchSenderWithWriteTimeout(newTransportBinding(left, stario.NewQueue()), func() time.Duration { + return 50 * time.Millisecond + }) errCh := make(chan error, 1) go func() { - errCh <- sender.submit(ctx, []byte("payload")) + errCh <- sender.submitData(context.Background(), 1, 1, bulkFastPathVersionV1, []byte("payload")) + }() + + select { + case err := <-errCh: + if err == nil { + t.Fatal("sender.submit should fail when receiver stalls") + } + if !isTimeoutLikeError(err) { + t.Fatalf("sender.submit error = %v, want timeout-like error", err) + } + case <-time.After(time.Second): + t.Fatal("sender.submit should not hang when receiver stalls") + } +} + +func TestStreamBatchSenderRespectsBindingWriteDeadlineWhenReceiverStalls(t *testing.T) { + left, right := net.Pipe() + defer left.Close() + defer right.Close() + + sender := newTestStreamBatchSender(newTransportBinding(left, stario.NewQueue()), func() time.Duration { + return 50 * time.Millisecond + }) + + errCh := make(chan error, 1) + go func() { + errCh <- sender.submitData(context.Background(), 1, 1, streamFastPathVersionV2, []byte("payload")) }() select { @@ -258,12 +312,12 @@ func TestWriteNetBuffersFullUnlockedUsesUnwrappedVectoredConn(t *testing.T) { func TestBulkBatchSenderSkipsQueuedCanceledRequest(t *testing.T) { conn := newBlockingPacketWriteConn() binding := newTransportBinding(conn, stario.NewQueue()) - sender := newBulkBatchSender(binding) + sender := newTestBulkBatchSender(binding) defer sender.stop() firstErrCh := make(chan error, 1) go func() { - firstErrCh <- sender.submit(context.Background(), []byte("first")) + firstErrCh <- sender.submitData(context.Background(), 1, 1, bulkFastPathVersionV1, []byte("first")) }() select { @@ -275,7 +329,7 @@ func TestBulkBatchSenderSkipsQueuedCanceledRequest(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) secondErrCh := make(chan error, 1) go func() { - secondErrCh <- sender.submit(ctx, []byte("second")) + secondErrCh <- 
sender.submitData(ctx, 1, 2, bulkFastPathVersionV1, []byte("second")) }() time.Sleep(20 * time.Millisecond) cancel() @@ -306,16 +360,48 @@ func TestBulkBatchSenderSkipsQueuedCanceledRequest(t *testing.T) { } } +func TestBulkBatchSenderDoesNotDirectSubmitShareableV2Data(t *testing.T) { + sender := &bulkBatchSender{} + req := bulkBatchRequest{ + frames: []bulkFastFrame{{ + Type: bulkFastPayloadTypeData, + DataID: 1, + Seq: 1, + Payload: make([]byte, 256*1024), + }}, + fastPathVersion: bulkFastPathVersionV2, + } + if sender.shouldDirectSubmit(req) { + t.Fatal("shareable v2 shared bulk data should queue for super-batch instead of direct submit") + } +} + +func TestBulkBatchSenderDirectSubmitsUnbatchableRequest(t *testing.T) { + sender := &bulkBatchSender{} + req := bulkBatchRequest{ + frames: []bulkFastFrame{{ + Type: bulkFastPayloadTypeRelease, + DataID: 1, + Seq: 0, + Payload: []byte("rel"), + }}, + fastPathVersion: bulkFastPathVersionV2, + } + if !sender.shouldDirectSubmit(req) { + t.Fatal("unbatchable shared bulk control request should still direct submit") + } +} + func TestBulkBatchSenderReturnsFlushResultAfterStartedContextCancel(t *testing.T) { conn := newBlockingPacketWriteConn() binding := newTransportBinding(conn, stario.NewQueue()) - sender := newBulkBatchSender(binding) + sender := newTestBulkBatchSender(binding) defer sender.stop() ctx, cancel := context.WithCancel(context.Background()) errCh := make(chan error, 1) go func() { - errCh <- sender.submit(ctx, []byte("payload")) + errCh <- sender.submitData(ctx, 1, 1, bulkFastPathVersionV1, []byte("payload")) }() select { @@ -346,10 +432,18 @@ func TestBulkBatchSenderReturnsFlushResultAfterStartedContextCancel(t *testing.T func TestTransportBindingStopBackgroundWorkersStopsSharedSender(t *testing.T) { binding := newTransportBinding(newBlockingPacketWriteConn(), stario.NewQueue()) - sender := binding.bulkBatchSenderSnapshot() + sender := binding.bulkBatchSenderSnapshotWithCodec(bulkBatchCodec{ + encodeSingle: 
func(frame bulkFastFrame) ([]byte, func(), error) { + return append([]byte(nil), frame.Payload...), nil, nil + }, + encodeBatch: func(frames []bulkFastFrame) ([]byte, func(), error) { + payload, err := encodeBulkFastBatchPlain(frames) + return payload, nil, err + }, + }, nil) binding.stopBackgroundWorkers() - err := sender.submit(context.Background(), []byte("payload")) + err := sender.submitData(context.Background(), 1, 1, bulkFastPathVersionV1, []byte("payload")) if !errors.Is(err, errTransportDetached) { t.Fatalf("sender.submit after stop = %v, want %v", err, errTransportDetached) }