package notify import ( "context" "sync" "sync/atomic" ) type streamFlowController struct { mu sync.Mutex queue []*streamFlowRequest inFlightBytes atomic.Int64 inFlightChunks atomic.Int64 windowBytes atomic.Int64 maxChunks atomic.Int64 waiters atomic.Int32 } type streamFlowRequest struct { size int ready chan struct{} admitted bool } func newStreamFlowController(cfg streamConfig) *streamFlowController { cfg = normalizeStreamConfig(cfg) controller := &streamFlowController{} controller.windowBytes.Store(int64(cfg.OutboundWindowBytes)) controller.maxChunks.Store(int64(cfg.OutboundMaxInFlightChunks)) return controller } func (c *streamFlowController) applyConfig(cfg streamConfig) { if c == nil { return } cfg = normalizeStreamConfig(cfg) c.windowBytes.Store(int64(cfg.OutboundWindowBytes)) c.maxChunks.Store(int64(cfg.OutboundMaxInFlightChunks)) if c.waiters.Load() == 0 { return } c.mu.Lock() c.drainLocked() c.mu.Unlock() } func (c *streamFlowController) acquire(ctx context.Context, size int) (func(), error) { if c == nil || size <= 0 { return func() {}, nil } if ctx == nil { ctx = context.Background() } if c.tryAcquire(size) { return c.releaseFunc(size), nil } req := &streamFlowRequest{ size: size, ready: make(chan struct{}), } c.mu.Lock() if c.tryAcquireLocked(size) { c.mu.Unlock() return c.releaseFunc(size), nil } c.queue = append(c.queue, req) c.waiters.Add(1) c.drainLocked() c.mu.Unlock() select { case <-req.ready: return c.releaseFunc(size), nil case <-ctx.Done(): c.mu.Lock() if req.admitted { c.mu.Unlock() return c.releaseFunc(size), nil } c.removeLocked(req) c.drainLocked() c.mu.Unlock() return nil, ctx.Err() } } func (c *streamFlowController) tryAcquire(size int) bool { if c == nil || size <= 0 { return true } if c.waiters.Load() != 0 { return false } return c.tryAcquireCAS(size) } func (c *streamFlowController) tryAcquireLocked(size int) bool { if c == nil || size <= 0 { return true } if len(c.queue) != 0 { return false } return c.tryAcquireCAS(size) } func (c *streamFlowController) tryAcquireCAS(size int) bool { if c == nil || size <= 0 { return true } size64 := int64(size) for { window := c.windowBytes.Load() maxChunks := c.maxChunks.Load() inFlightBytes := c.inFlightBytes.Load() inFlightChunks := c.inFlightChunks.Load() if maxChunks > 0 && inFlightChunks >= maxChunks { return false } if window > 0 && inFlightBytes+size64 > window { if !(inFlightBytes == 0 && inFlightChunks == 0) { return false } } if !c.inFlightBytes.CompareAndSwap(inFlightBytes, inFlightBytes+size64) { continue } if c.addChunksCAS(1, maxChunks) { return true } c.subBytesCAS(size64) return false } } func (c *streamFlowController) addChunksCAS(delta int64, maxChunks int64) bool { if c == nil || delta <= 0 { return true } for { current := c.inFlightChunks.Load() if maxChunks > 0 && current+delta > maxChunks { return false } if c.inFlightChunks.CompareAndSwap(current, current+delta) { return true } } } func (c *streamFlowController) subBytesCAS(delta int64) { if c == nil || delta <= 0 { return } for { current := c.inFlightBytes.Load() next := current - delta if next < 0 { next = 0 } if c.inFlightBytes.CompareAndSwap(current, next) { return } } } func (c *streamFlowController) subChunksCAS(delta int64) { if c == nil || delta <= 0 { return } for { current := c.inFlightChunks.Load() next := current - delta if next < 0 { next = 0 } if c.inFlightChunks.CompareAndSwap(current, next) { return } } } func (c *streamFlowController) releaseFunc(size int) func() { released := false return func() { if released { return } released = true c.release(size) } } func (c *streamFlowController) release(size int) { if c == nil || size <= 0 { return } c.subBytesCAS(int64(size)) c.subChunksCAS(1) if c.waiters.Load() == 0 { return } c.mu.Lock() c.drainLocked() c.mu.Unlock() } func (c *streamFlowController) removeLocked(req *streamFlowRequest) { if c == nil || req == nil { return } for i, item := range c.queue { if item != req { continue } copy(c.queue[i:], c.queue[i+1:]) c.queue[len(c.queue)-1] = nil c.queue = c.queue[:len(c.queue)-1] c.waiters.Add(-1) return } } func (c *streamFlowController) drainLocked() { if c == nil { return } for len(c.queue) > 0 { req := c.queue[0] if req == nil { c.queue = c.queue[1:] continue } if !c.canAdmitLocked(req.size) { return } if !c.tryAcquireCAS(req.size) { return } copy(c.queue[0:], c.queue[1:]) c.queue[len(c.queue)-1] = nil c.queue = c.queue[:len(c.queue)-1] c.waiters.Add(-1) req.admitted = true close(req.ready) } } func (c *streamFlowController) canAdmitLocked(size int) bool { if c == nil { return true } if size <= 0 { return true } window := c.windowBytes.Load() chunks := c.inFlightChunks.Load() bytes := c.inFlightBytes.Load() maxChunks := c.maxChunks.Load() if maxChunks > 0 && chunks >= maxChunks { return false } if window <= 0 { return true } if bytes+int64(size) <= window { return true } return bytes == 0 && chunks == 0 }