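package stario

import (
	"fmt"
	"sync"
)

// waitGroupAddMode selects how Add treats a batch (delta > 1) that would
// push the running count past the limit.
type waitGroupAddMode uint8

const (
	// waitGroupAddModeStrict panics when a batch Add exceeds the limit.
	waitGroupAddModeStrict waitGroupAddMode = iota
	// waitGroupAddModeLoose raises the limit so that the batch fits.
	waitGroupAddModeLoose
)

// WaitGroup wraps sync.WaitGroup with an optional cap on the number of
// concurrently running tasks. A maxCount of 0 means no limit.
type WaitGroup struct {
	wg       sync.WaitGroup
	mu       sync.Mutex
	cond     *sync.Cond
	initOnce sync.Once
	maxCount int
	running  int
	addMode  waitGroupAddMode
}

// NewWaitGroup returns a WaitGroup limited to maxCount running tasks.
// It panics if maxCount is negative.
func NewWaitGroup(maxCount int) WaitGroup {
	if maxCount < 0 {
		panic("stario: negative max wait count")
	}
	return WaitGroup{
		maxCount: maxCount,
		addMode:  waitGroupAddModeStrict,
	}
}

// init lazily creates the condition variable so the zero value is usable.
func (w *WaitGroup) init() {
	w.initOnce.Do(func() {
		w.cond = sync.NewCond(&w.mu)
	})
}

// Add adjusts the task count by delta. Add(1) blocks while the limit is
// reached. A larger positive delta never blocks: if it would exceed the
// limit, it panics in strict mode or raises the limit in loose mode.
// A negative delta releases tasks.
func (w *WaitGroup) Add(delta int) {
	w.init()
	if delta == 0 {
		return
	}
	if delta < 0 {
		w.release(-delta)
		return
	}
	w.acquire(delta)
}

func (w *WaitGroup) acquire(delta int) {
	w.mu.Lock()
	defer w.mu.Unlock()

	// No limit configured: behave like a plain sync.WaitGroup.
	if w.maxCount <= 0 {
		w.running += delta
		w.wg.Add(delta)
		return
	}

	if delta == 1 {
		// wg is incremented before waiting, so Wait also blocks on
		// callers that are still queued here.
		w.wg.Add(1)
		for w.maxCount > 0 && w.running >= w.maxCount {
			w.cond.Wait()
		}
		w.running++
		return
	}

	if w.running+delta > w.maxCount {
		if w.addMode == waitGroupAddModeStrict {
			panic(fmt.Sprintf(
				"stario: WaitGroup.Add(%d) exceeds max limit %d with %d running",
				delta, w.maxCount, w.running,
			))
		}
		// Loose mode: grow the limit to fit the batch.
		w.maxCount = w.running + delta
	}
	w.running += delta
	w.wg.Add(delta)
}

// release marks delta tasks as finished and wakes any blocked Add calls.
// It panics if delta exceeds the number of running tasks.
func (w *WaitGroup) release(delta int) {
	w.mu.Lock()
	defer w.mu.Unlock()

	if delta > w.running {
		panic(fmt.Sprintf(
			"stario: WaitGroup.Done releases %d tasks but only %d running",
			delta, w.running,
		))
	}
	w.wg.Add(-delta)
	w.running -= delta
	w.cond.Broadcast()
}

// Done marks one task as finished.
func (w *WaitGroup) Done() {
	w.Add(-1)
}

// Go runs fn in a new goroutine, blocking first if the limit is reached.
func (w *WaitGroup) Go(fn func()) {
	w.Add(1)
	go func() {
		defer w.Done()
		fn()
	}()
}

// Wait blocks until the task count drops to zero.
func (w *WaitGroup) Wait() {
	w.init()
	w.wg.Wait()
}

// GetMaxWaitNum returns the current limit.
func (w *WaitGroup) GetMaxWaitNum() int {
	w.init()
	w.mu.Lock()
	defer w.mu.Unlock()
	return w.maxCount
}

// SetMaxWaitNum sets the limit and wakes blocked Add calls so they can
// re-check it. It panics if num is negative.
func (w *WaitGroup) SetMaxWaitNum(num int) {
	if num < 0 {
		panic("stario: negative max wait count")
	}
	w.init()
	w.mu.Lock()
	w.maxCount = num
	w.mu.Unlock()
	w.cond.Broadcast()
}

// SetStrictAddMode switches between strict (true) and loose (false) mode.
func (w *WaitGroup) SetStrictAddMode(strict bool) {
	w.init()
	w.mu.Lock()
	if strict {
		w.addMode = waitGroupAddModeStrict
	} else {
		w.addMode = waitGroupAddModeLoose
	}
	w.mu.Unlock()
	w.cond.Broadcast()
}

// StrictAddMode reports whether strict mode is active.
func (w *WaitGroup) StrictAddMode() bool {
	w.init()
	w.mu.Lock()
	defer w.mu.Unlock()
	return w.addMode == waitGroupAddModeStrict
}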