notify/stream_flow.go

276 lines
5.2 KiB
Go
Raw Permalink Normal View History

package notify
import (
"context"
"sync"
"sync/atomic"
)
// streamFlowController meters outbound stream data. It tracks how many bytes
// and chunks are currently admitted, compares them against a configurable byte
// window and chunk cap, and keeps a FIFO queue of requests that could not be
// admitted immediately.
type streamFlowController struct {
// mu guards queue and the admitted flag of every queued request.
mu sync.Mutex
// queue holds waiting requests in arrival order; only the head is ever
// considered for admission.
queue []*streamFlowRequest
// inFlightBytes is the total size of all currently admitted requests.
inFlightBytes atomic.Int64
// inFlightChunks is the number of currently admitted requests.
inFlightChunks atomic.Int64
// windowBytes is the byte limit; the admission checks ignore it when it is
// not positive.
windowBytes atomic.Int64
// maxChunks is the chunk limit; the admission checks ignore it when it is
// not positive.
maxChunks atomic.Int64
// waiters counts queued requests so that lock-free paths can tell whether
// anyone is waiting without taking mu.
waiters atomic.Int32
}
// streamFlowRequest is one queued acquisition waiting for capacity.
type streamFlowRequest struct {
// size is the number of bytes the request wants to reserve.
size int
// ready is closed by drainLocked once the request has been admitted.
ready chan struct{}
// admitted is set by drainLocked (under the controller mutex) just before
// ready is closed; acquire reads it under the same mutex to resolve a
// cancel-versus-admit race.
admitted bool
}
// newStreamFlowController builds a flow controller whose byte window and
// in-flight chunk cap come from cfg after it has been passed through
// normalizeStreamConfig. The controller starts with nothing in flight and an
// empty wait queue.
func newStreamFlowController(cfg streamConfig) *streamFlowController {
	normalized := normalizeStreamConfig(cfg)

	fc := new(streamFlowController)
	fc.windowBytes.Store(int64(normalized.OutboundWindowBytes))
	fc.maxChunks.Store(int64(normalized.OutboundMaxInFlightChunks))
	return fc
}
// applyConfig replaces the byte window and chunk cap with the normalized
// values from cfg. When requests are queued it then drains the queue, since
// the new limits may allow some of them to be admitted. Calling it on a nil
// controller is a no-op.
func (c *streamFlowController) applyConfig(cfg streamConfig) {
	if c == nil {
		return
	}

	normalized := normalizeStreamConfig(cfg)
	c.windowBytes.Store(int64(normalized.OutboundWindowBytes))
	c.maxChunks.Store(int64(normalized.OutboundMaxInFlightChunks))

	// Only pay for the mutex when somebody is actually waiting.
	if c.waiters.Load() != 0 {
		c.mu.Lock()
		c.drainLocked()
		c.mu.Unlock()
	}
}
// acquire reserves one chunk slot and size bytes of the outbound window,
// blocking until the reservation is admitted or ctx is done.
//
// A nil controller or a non-positive size needs no reservation and yields a
// no-op release function. A nil ctx is replaced by context.Background().
//
// On success the returned function (built by releaseFunc) gives the
// reservation back. If ctx finishes before the request is admitted, the
// request is withdrawn from the queue and the result is (nil, ctx.Err()).
func (c *streamFlowController) acquire(ctx context.Context, size int) (func(), error) {
if c == nil || size <= 0 {
return func() {}, nil
}
if ctx == nil {
ctx = context.Background()
}
// Fast path: lock-free attempt; tryAcquire refuses while anyone is queued.
if c.tryAcquire(size) {
return c.releaseFunc(size), nil
}
req := &streamFlowRequest{
size: size,
ready: make(chan struct{}),
}
// Slow path: retry under the mutex, then join the queue.
c.mu.Lock()
if c.tryAcquireLocked(size) {
c.mu.Unlock()
return c.releaseFunc(size), nil
}
c.queue = append(c.queue, req)
c.waiters.Add(1)
// release skips draining while waiters is zero, so capacity freed just
// before the increment above would go unnoticed; drain now that the request
// is visible in the queue.
c.drainLocked()
c.mu.Unlock()
select {
case <-req.ready:
return c.releaseFunc(size), nil
case <-ctx.Done():
c.mu.Lock()
// select may choose this branch even if ready was closed as well. The
// admitted flag tells whether the reservation was already made, in which
// case it is handed to the caller instead of being lost.
if req.admitted {
c.mu.Unlock()
return c.releaseFunc(size), nil
}
c.removeLocked(req)
// Withdrawing this request may put an admissible one at the head.
c.drainLocked()
c.mu.Unlock()
return nil, ctx.Err()
}
}
// tryAcquire is the lock-free admission attempt. It reports true when no
// reservation is needed (nil controller or non-positive size) or when the
// reservation was made. It refuses without trying while any request is
// queued, so newcomers do not overtake waiters.
func (c *streamFlowController) tryAcquire(size int) bool {
	switch {
	case c == nil, size <= 0:
		return true
	case c.waiters.Load() != 0:
		return false
	default:
		return c.tryAcquireCAS(size)
	}
}
// tryAcquireLocked is the admission attempt used while the caller holds c.mu.
// It reports true when no reservation is needed (nil controller or
// non-positive size) or when the reservation was made, and it never reserves
// while the queue is non-empty.
func (c *streamFlowController) tryAcquireLocked(size int) bool {
	if c == nil || size <= 0 {
		return true
	}
	return len(c.queue) == 0 && c.tryAcquireCAS(size)
}
// tryAcquireCAS attempts to reserve size bytes and one chunk slot using
// compare-and-swap on the in-flight counters, without consulting the queue.
//
// It reports true when the reservation was made, or when none is needed (nil
// controller or non-positive size). A request larger than the remaining
// window is still accepted when nothing at all is in flight, so an oversized
// chunk cannot block forever.
func (c *streamFlowController) tryAcquireCAS(size int) bool {
	if c == nil || size <= 0 {
		return true
	}

	want := int64(size)
	for {
		window := c.windowBytes.Load()
		chunkCap := c.maxChunks.Load()
		usedBytes := c.inFlightBytes.Load()
		usedChunks := c.inFlightChunks.Load()

		if chunkCap > 0 && usedChunks >= chunkCap {
			return false
		}

		overWindow := window > 0 && usedBytes+want > window
		idle := usedBytes == 0 && usedChunks == 0
		if overWindow && !idle {
			return false
		}

		// Claim the bytes first; a lost race means the snapshot is stale, so
		// start over with fresh values.
		if c.inFlightBytes.CompareAndSwap(usedBytes, usedBytes+want) {
			if c.addChunksCAS(1, chunkCap) {
				return true
			}
			// The chunk slot was taken in the meantime: undo the byte claim.
			c.subBytesCAS(want)
			return false
		}
	}
}
// addChunksCAS raises the in-flight chunk counter by delta unless doing so
// would exceed maxChunks (a non-positive maxChunks means no cap). It reports
// whether the counter was raised; a nil controller or non-positive delta is
// treated as trivially successful.
func (c *streamFlowController) addChunksCAS(delta int64, maxChunks int64) bool {
	if c == nil || delta <= 0 {
		return true
	}

	for {
		have := c.inFlightChunks.Load()
		want := have + delta
		if maxChunks > 0 && want > maxChunks {
			return false
		}
		if c.inFlightChunks.CompareAndSwap(have, want) {
			return true
		}
		// Another goroutine changed the counter; retry with a fresh value.
	}
}
// subBytesCAS lowers the in-flight byte counter by delta, clamping the result
// at zero. A nil controller or non-positive delta is a no-op.
func (c *streamFlowController) subBytesCAS(delta int64) {
	if c == nil || delta <= 0 {
		return
	}

	counter := &c.inFlightBytes
	current := counter.Load()
	for {
		next := current - delta
		if next < 0 {
			next = 0
		}
		if counter.CompareAndSwap(current, next) {
			return
		}
		// Lost the race: reload and recompute.
		current = counter.Load()
	}
}
// subChunksCAS lowers the in-flight chunk counter by delta, clamping the
// result at zero. A nil controller or non-positive delta is a no-op.
func (c *streamFlowController) subChunksCAS(delta int64) {
	if c == nil || delta <= 0 {
		return
	}

	for swapped := false; !swapped; {
		have := c.inFlightChunks.Load()
		left := have - delta
		if left < 0 {
			left = 0
		}
		swapped = c.inFlightChunks.CompareAndSwap(have, left)
	}
}
// releaseFunc returns a closure that gives back one reservation of size bytes
// by calling c.release(size).
//
// The closure is idempotent and safe to call from several goroutines: only
// the first call releases, every later or concurrent call is a no-op. That
// keeps a double release (for example a deferred call plus an explicit one on
// another goroutine) from subtracting the reservation twice and eating into
// budget held by other requests.
func (c *streamFlowController) releaseFunc(size int) func() {
	var released atomic.Bool
	return func() {
		// CompareAndSwap lets exactly one caller win; a plain bool here would
		// be a data race when the closure is invoked concurrently.
		if !released.CompareAndSwap(false, true) {
			return
		}
		c.release(size)
	}
}
// release returns size bytes and one chunk slot to the controller and, when
// requests are queued, drains the queue so freed capacity is handed out. A
// nil controller or non-positive size is a no-op.
func (c *streamFlowController) release(size int) {
	if c == nil || size <= 0 {
		return
	}

	c.subBytesCAS(int64(size))
	c.subChunksCAS(1)

	// Skip the mutex entirely when nobody is waiting.
	if c.waiters.Load() != 0 {
		c.mu.Lock()
		c.drainLocked()
		c.mu.Unlock()
	}
}
// removeLocked deletes req from the wait queue, preserving the order of the
// remaining requests, and decrements the waiter count. It does nothing when
// req is not queued, or when c or req is nil. The caller must hold c.mu.
func (c *streamFlowController) removeLocked(req *streamFlowRequest) {
	if c == nil || req == nil {
		return
	}

	pos := -1
	for i := range c.queue {
		if c.queue[i] == req {
			pos = i
			break
		}
	}
	if pos < 0 {
		return
	}

	last := len(c.queue) - 1
	copy(c.queue[pos:], c.queue[pos+1:])
	// Clear the vacated tail slot so the backing array does not keep the
	// request alive.
	c.queue[last] = nil
	c.queue = c.queue[:last]
	c.waiters.Add(-1)
}
// drainLocked admits queued requests strictly from the head of the queue for
// as long as capacity allows. Each admitted request is removed from the
// queue, marked admitted and woken by closing its ready channel. Draining
// stops at the first request that cannot be admitted, so later requests never
// overtake it. The caller must hold c.mu.
func (c *streamFlowController) drainLocked() {
	if c == nil {
		return
	}

	for len(c.queue) > 0 {
		head := c.queue[0]
		if head == nil {
			c.queue = c.queue[1:]
			continue
		}

		// Check first, then reserve; the reservation can still fail if a
		// lock-free acquirer grabbed the capacity in between.
		if !c.canAdmitLocked(head.size) || !c.tryAcquireCAS(head.size) {
			return
		}

		last := len(c.queue) - 1
		copy(c.queue, c.queue[1:])
		c.queue[last] = nil
		c.queue = c.queue[:last]
		c.waiters.Add(-1)

		// Set the flag before closing so a canceller that takes the mutex
		// afterwards sees the admission.
		head.admitted = true
		close(head.ready)
	}
}
// canAdmitLocked reports whether a request of size bytes fits the current
// limits, without reserving anything. A nil controller or non-positive size
// always fits. Otherwise the chunk cap is checked first, then the byte
// window; a request exceeding the window is still allowed when nothing at all
// is in flight.
func (c *streamFlowController) canAdmitLocked(size int) bool {
	if c == nil || size <= 0 {
		return true
	}

	window := c.windowBytes.Load()
	usedChunks := c.inFlightChunks.Load()
	usedBytes := c.inFlightBytes.Load()
	chunkCap := c.maxChunks.Load()

	switch {
	case chunkCap > 0 && usedChunks >= chunkCap:
		return false
	case window <= 0:
		return true
	case usedBytes+int64(size) <= window:
		return true
	default:
		return usedBytes == 0 && usedChunks == 0
	}
}