stario/sync.go

160 lines
3.2 KiB
Go
Raw Permalink Normal View History

2021-11-12 16:01:35 +08:00
package stario
import (
2026-03-21 19:15:43 +08:00
"fmt"
2021-11-12 16:01:35 +08:00
"sync"
2026-03-21 19:15:43 +08:00
)
// waitGroupAddMode selects how Add reacts to a batch (delta > 1) that does
// not fit under the current concurrency limit.
type waitGroupAddMode uint8

const (
	// waitGroupAddModeStrict makes an oversized batch Add panic. It is the
	// zero value, so a zero WaitGroup starts in strict mode.
	waitGroupAddModeStrict waitGroupAddMode = 0
	// waitGroupAddModeLoose raises the limit so an oversized batch fits.
	waitGroupAddModeLoose waitGroupAddMode = 1
)
// WaitGroup is a concurrency-limited sync.WaitGroup variant.
//
// A zero or negative limit means unlimited concurrency. WaitGroup must not be
// copied after first use.
2021-11-12 16:01:35 +08:00
type WaitGroup struct {
2026-03-21 19:15:43 +08:00
wg sync.WaitGroup
mu sync.Mutex
cond *sync.Cond
initOnce sync.Once
maxCount int
running int
addMode waitGroupAddMode
2021-11-12 16:01:35 +08:00
}
// NewWaitGroup creates a WaitGroup with the provided concurrency limit.
2021-11-12 16:01:35 +08:00
func NewWaitGroup(maxCount int) WaitGroup {
2026-03-21 19:15:43 +08:00
if maxCount < 0 {
panic("stario: negative max wait count")
}
return WaitGroup{
maxCount: maxCount,
addMode: waitGroupAddModeStrict,
}
}
func (w *WaitGroup) init() {
w.initOnce.Do(func() {
w.cond = sync.NewCond(&w.mu)
})
2021-11-12 16:01:35 +08:00
}
// Add adjusts the running task count.
//
// Positive deltas may block when the concurrency limit is reached. Negative
// deltas release running slots.
func (w *WaitGroup) Add(delta int) {
2026-03-21 19:15:43 +08:00
w.init()
if delta == 0 {
return
}
2021-11-12 16:01:35 +08:00
if delta < 0 {
2026-03-21 19:15:43 +08:00
w.release(-delta)
return
}
w.acquire(delta)
}
func (w *WaitGroup) acquire(delta int) {
w.mu.Lock()
defer w.mu.Unlock()
if w.maxCount <= 0 {
w.running += delta
w.wg.Add(delta)
return
}
if delta == 1 {
w.wg.Add(1)
for w.maxCount > 0 && w.running >= w.maxCount {
w.cond.Wait()
2021-11-12 16:01:35 +08:00
}
2026-03-21 19:15:43 +08:00
w.running++
return
2021-11-12 16:01:35 +08:00
}
2026-03-21 19:15:43 +08:00
if w.running+delta > w.maxCount {
if w.addMode == waitGroupAddModeStrict {
panic(fmt.Sprintf("stario: WaitGroup.Add(%d) exceeds max limit %d with %d running", delta, w.maxCount, w.running))
}
w.maxCount = w.running + delta
}
w.running += delta
w.wg.Add(delta)
2021-11-12 16:01:35 +08:00
}
2026-03-21 19:15:43 +08:00
func (w *WaitGroup) release(delta int) {
w.mu.Lock()
defer w.mu.Unlock()
if delta > w.running {
panic(fmt.Sprintf("stario: WaitGroup.Done releases %d tasks but only %d running", delta, w.running))
}
w.wg.Add(-delta)
w.running -= delta
w.cond.Broadcast()
}
// Done releases one running task slot.
func (w *WaitGroup) Done() {
w.Add(-1)
2021-11-12 16:01:35 +08:00
}
// Go runs fn in a goroutine while accounting for the concurrency limit.
2026-03-21 19:15:43 +08:00
func (w *WaitGroup) Go(fn func()) {
w.Add(1)
go func() {
defer w.Done()
fn()
}()
}
// Wait blocks until all added work has completed.
func (w *WaitGroup) Wait() {
2026-03-21 19:15:43 +08:00
w.init()
w.wg.Wait()
2021-11-12 16:01:35 +08:00
}
// GetMaxWaitNum returns the current concurrency limit.
func (w *WaitGroup) GetMaxWaitNum() int {
2026-03-21 19:15:43 +08:00
w.init()
w.mu.Lock()
defer w.mu.Unlock()
return w.maxCount
2021-11-12 16:01:35 +08:00
}
// SetMaxWaitNum updates the concurrency limit.
func (w *WaitGroup) SetMaxWaitNum(num int) {
2026-03-21 19:15:43 +08:00
if num < 0 {
panic("stario: negative max wait count")
}
w.init()
w.mu.Lock()
w.maxCount = num
w.mu.Unlock()
w.cond.Broadcast()
}
// SetStrictAddMode controls whether Add(n>1) panics or auto-expands the limit
// when the requested batch exceeds the current capacity.
2026-03-21 19:15:43 +08:00
func (w *WaitGroup) SetStrictAddMode(strict bool) {
w.init()
w.mu.Lock()
if strict {
w.addMode = waitGroupAddModeStrict
} else {
w.addMode = waitGroupAddModeLoose
}
w.mu.Unlock()
w.cond.Broadcast()
}
// StrictAddMode reports whether strict batch-add behavior is enabled.
2026-03-21 19:15:43 +08:00
func (w *WaitGroup) StrictAddMode() bool {
w.init()
w.mu.Lock()
defer w.mu.Unlock()
return w.addMode == waitGroupAddModeStrict
2021-11-12 16:01:35 +08:00
}