stario/sync.go

143 lines
2.3 KiB
Go

package stario
import (
"fmt"
"sync"
)
// waitGroupAddMode selects how WaitGroup.Add reacts when a batch Add
// (delta > 1) would push the running count past the configured limit.
type waitGroupAddMode uint8

const (
	// waitGroupAddModeStrict makes such an Add panic. It is the zero value.
	waitGroupAddModeStrict waitGroupAddMode = 0
	// waitGroupAddModeLoose raises the limit so the batch fits instead.
	waitGroupAddModeLoose waitGroupAddMode = 1
)
// WaitGroup is a sync.WaitGroup that can optionally cap how many tasks are
// admitted as running at the same time. A maxCount of 0 means no cap. The
// zero value is usable: the condition variable is created lazily by init and
// the add mode defaults to strict. It contains a sync.Mutex, so it must not be
// copied after first use.
type WaitGroup struct {
	wg       sync.WaitGroup // counts outstanding tasks for Wait
	mu       sync.Mutex     // guards the fields below and backs cond
	cond     *sync.Cond     // wakes Add(1) callers blocked on the limit
	initOnce sync.Once      // guards lazy creation of cond

	maxCount, running int              // limit (0 = unlimited) and admitted tasks
	addMode           waitGroupAddMode // Add(n>1) policy when over the limit
}
// NewWaitGroup returns a WaitGroup that admits at most maxCount tasks at once
// (0 means unlimited) and starts in strict add mode. It panics if maxCount is
// negative.
func NewWaitGroup(maxCount int) WaitGroup {
	if maxCount >= 0 {
		return WaitGroup{
			addMode:  waitGroupAddModeStrict,
			maxCount: maxCount,
		}
	}
	panic("stario: negative max wait count")
}
// init lazily creates the condition variable bound to w.mu. It is guarded by
// a sync.Once, so repeated or concurrent calls have no further effect.
func (w *WaitGroup) init() {
	w.initOnce.Do(func() {
		var locker sync.Locker = &w.mu
		w.cond = sync.NewCond(locker)
	})
}
// Add adjusts the task count by delta. A positive delta admits tasks via
// acquire, and Add(1) may block while the limit is reached. A negative delta
// releases tasks via release. Zero is a no-op.
func (w *WaitGroup) Add(delta int) {
	w.init()
	switch {
	case delta > 0:
		w.acquire(delta)
	case delta < 0:
		w.release(-delta)
	}
}
// acquire admits delta (> 0) tasks. Add calls init before this, so w.cond is
// non-nil.
//
// Behavior depends on the configured limit:
//   - maxCount == 0 (unlimited): the tasks are admitted immediately.
//   - delta == 1: blocks until running drops below maxCount or the limit is
//     removed. The inner sync.WaitGroup is incremented before blocking, so a
//     concurrent Wait keeps waiting for callers still queued here.
//   - delta > 1: never blocks. If the batch does not fit under the limit,
//     strict mode panics and loose mode raises maxCount to exactly fit it.
//     The raised limit is kept (GetMaxWaitNum reports it) and is not lowered
//     again when the tasks finish.
//
// It panics if adding delta would overflow the running count.
func (w *WaitGroup) acquire(delta int) {
	w.mu.Lock()
	defer w.mu.Unlock()

	// delta > 0, so a wrapped sum is smaller than running. Reject it up front.
	// Otherwise the limit check below (running+delta > maxCount) could be
	// bypassed, and running would become negative.
	if w.running+delta < w.running {
		panic(fmt.Sprintf("stario: WaitGroup.Add(%d) overflows running count %d", delta, w.running))
	}

	if w.maxCount <= 0 {
		w.running += delta
		w.wg.Add(delta)
		return
	}

	if delta == 1 {
		// Count the task in wg before waiting, so Wait does not return while
		// this caller is still queued for a slot.
		w.wg.Add(1)
		// Re-check maxCount > 0 on every wakeup: SetMaxWaitNum(0) lifts the limit.
		for w.maxCount > 0 && w.running >= w.maxCount {
			w.cond.Wait()
		}
		w.running++
		return
	}

	if w.running+delta > w.maxCount {
		if w.addMode == waitGroupAddModeStrict {
			panic(fmt.Sprintf("stario: WaitGroup.Add(%d) exceeds max limit %d with %d running", delta, w.maxCount, w.running))
		}
		// Loose mode: grow the limit so this batch fits.
		w.maxCount = w.running + delta
	}
	w.running += delta
	w.wg.Add(delta)
}
// release gives back delta (> 0) task slots. It then broadcasts on w.cond so
// that every Add(1) caller blocked on the limit re-checks it. It panics if
// delta exceeds the number of tasks currently admitted. It uses w.cond, so
// init must already have run; Add guarantees this.
func (w *WaitGroup) release(delta int) {
	w.mu.Lock()
	defer w.mu.Unlock()

	if w.running < delta {
		panic(fmt.Sprintf("stario: WaitGroup.Done releases %d tasks but only %d running", delta, w.running))
	}

	w.wg.Add(-delta)
	w.running -= delta
	w.cond.Broadcast()
}
// Done marks one admitted task as finished, exactly like Add(-1). It panics
// if no task is currently running.
func (w *WaitGroup) Done() {
	w.init()
	w.release(1)
}
// Go admits one task and runs fn in a new goroutine. Admission works like
// Add(1), so it blocks while the limit is reached. The task is marked done
// when fn returns, including by panicking, because Done is deferred.
func (w *WaitGroup) Go(fn func()) {
	w.Add(1)
	task := func() {
		defer w.Done()
		fn()
	}
	go task()
}
// Wait blocks until the internal counter returns to zero. Add increments the
// counter and Done decrements it. It also covers Add(1) callers that are still
// blocked waiting for a slot, since they increment the counter before blocking.
func (w *WaitGroup) Wait() {
	w.init()
	inner := &w.wg
	inner.Wait()
}
// GetMaxWaitNum reports the current limit on concurrently admitted tasks, where
// 0 means unlimited. In loose add mode this can be larger than the value last
// set, because a batch Add may raise the limit.
func (w *WaitGroup) GetMaxWaitNum() int {
	w.init()
	w.mu.Lock()
	limit := w.maxCount
	w.mu.Unlock()
	return limit
}
// SetMaxWaitNum changes the limit on concurrently admitted tasks; 0 removes
// it. It panics if num is negative. Add(1) callers blocked on the old limit
// are woken so they re-check against the new one. Lowering the limit below the
// current running count leaves already-admitted tasks untouched.
func (w *WaitGroup) SetMaxWaitNum(num int) {
	if num < 0 {
		panic("stario: negative max wait count")
	}
	w.init()
	func() {
		w.mu.Lock()
		defer w.mu.Unlock()
		w.maxCount = num
	}()
	w.cond.Broadcast()
}
// SetStrictAddMode chooses what Add(n) with n > 1 does when the batch would
// exceed the limit. Strict mode, the default, panics. Loose mode raises the
// limit so the batch fits. The mode does not affect Add(1), which waits for a
// free slot in either mode.
func (w *WaitGroup) SetStrictAddMode(strict bool) {
	mode := waitGroupAddModeLoose
	if strict {
		mode = waitGroupAddModeStrict
	}

	w.init()
	w.mu.Lock()
	w.addMode = mode
	w.mu.Unlock()
	// No cond.Broadcast here. Blocked Add(1) callers wait only on running and
	// maxCount, never on addMode, so waking them would only cause spurious
	// wakeups.
}
// StrictAddMode reports whether an over-limit batch Add(n), n > 1, panics
// (true) or raises the limit (false).
func (w *WaitGroup) StrictAddMode() bool {
	w.init()
	w.mu.Lock()
	mode := w.addMode
	w.mu.Unlock()
	return mode == waitGroupAddModeStrict
}