143 lines
2.3 KiB
Go
143 lines
2.3 KiB
Go
package stario
|
|
|
|
import (
|
|
"fmt"
|
|
"sync"
|
|
)
|
|
|
|
// waitGroupAddMode selects how WaitGroup.Add reacts when a bulk Add
// (delta > 1) would push the number of running tasks past the limit.
type waitGroupAddMode uint8

// Add modes understood by WaitGroup.
const (
	// waitGroupAddModeStrict makes an over-limit bulk Add panic. It is the
	// zero value, so a zero WaitGroup starts in strict mode.
	waitGroupAddModeStrict waitGroupAddMode = 0
	// waitGroupAddModeLoose lets an over-limit bulk Add proceed instead of
	// panicking.
	waitGroupAddModeLoose waitGroupAddMode = 1
)
|
|
|
|
type WaitGroup struct {
|
|
wg sync.WaitGroup
|
|
mu sync.Mutex
|
|
cond *sync.Cond
|
|
initOnce sync.Once
|
|
maxCount int
|
|
running int
|
|
addMode waitGroupAddMode
|
|
}
|
|
|
|
func NewWaitGroup(maxCount int) WaitGroup {
|
|
if maxCount < 0 {
|
|
panic("stario: negative max wait count")
|
|
}
|
|
return WaitGroup{
|
|
maxCount: maxCount,
|
|
addMode: waitGroupAddModeStrict,
|
|
}
|
|
}
|
|
|
|
func (w *WaitGroup) init() {
|
|
w.initOnce.Do(func() {
|
|
w.cond = sync.NewCond(&w.mu)
|
|
})
|
|
}
|
|
|
|
func (w *WaitGroup) Add(delta int) {
|
|
w.init()
|
|
if delta == 0 {
|
|
return
|
|
}
|
|
if delta < 0 {
|
|
w.release(-delta)
|
|
return
|
|
}
|
|
w.acquire(delta)
|
|
}
|
|
|
|
func (w *WaitGroup) acquire(delta int) {
|
|
w.mu.Lock()
|
|
defer w.mu.Unlock()
|
|
if w.maxCount <= 0 {
|
|
w.running += delta
|
|
w.wg.Add(delta)
|
|
return
|
|
}
|
|
if delta == 1 {
|
|
w.wg.Add(1)
|
|
for w.maxCount > 0 && w.running >= w.maxCount {
|
|
w.cond.Wait()
|
|
}
|
|
w.running++
|
|
return
|
|
}
|
|
if w.running+delta > w.maxCount {
|
|
if w.addMode == waitGroupAddModeStrict {
|
|
panic(fmt.Sprintf("stario: WaitGroup.Add(%d) exceeds max limit %d with %d running", delta, w.maxCount, w.running))
|
|
}
|
|
w.maxCount = w.running + delta
|
|
}
|
|
w.running += delta
|
|
w.wg.Add(delta)
|
|
}
|
|
|
|
func (w *WaitGroup) release(delta int) {
|
|
w.mu.Lock()
|
|
defer w.mu.Unlock()
|
|
if delta > w.running {
|
|
panic(fmt.Sprintf("stario: WaitGroup.Done releases %d tasks but only %d running", delta, w.running))
|
|
}
|
|
w.wg.Add(-delta)
|
|
w.running -= delta
|
|
w.cond.Broadcast()
|
|
}
|
|
|
|
func (w *WaitGroup) Done() {
|
|
w.Add(-1)
|
|
}
|
|
|
|
func (w *WaitGroup) Go(fn func()) {
|
|
w.Add(1)
|
|
go func() {
|
|
defer w.Done()
|
|
fn()
|
|
}()
|
|
}
|
|
|
|
func (w *WaitGroup) Wait() {
|
|
w.init()
|
|
w.wg.Wait()
|
|
}
|
|
|
|
func (w *WaitGroup) GetMaxWaitNum() int {
|
|
w.init()
|
|
w.mu.Lock()
|
|
defer w.mu.Unlock()
|
|
return w.maxCount
|
|
}
|
|
|
|
func (w *WaitGroup) SetMaxWaitNum(num int) {
|
|
if num < 0 {
|
|
panic("stario: negative max wait count")
|
|
}
|
|
w.init()
|
|
w.mu.Lock()
|
|
w.maxCount = num
|
|
w.mu.Unlock()
|
|
w.cond.Broadcast()
|
|
}
|
|
|
|
func (w *WaitGroup) SetStrictAddMode(strict bool) {
|
|
w.init()
|
|
w.mu.Lock()
|
|
if strict {
|
|
w.addMode = waitGroupAddModeStrict
|
|
} else {
|
|
w.addMode = waitGroupAddModeLoose
|
|
}
|
|
w.mu.Unlock()
|
|
w.cond.Broadcast()
|
|
}
|
|
|
|
func (w *WaitGroup) StrictAddMode() bool {
|
|
w.init()
|
|
w.mu.Lock()
|
|
defer w.mu.Unlock()
|
|
return w.addMode == waitGroupAddModeStrict
|
|
}
|