diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cd5e331 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +.sentrux/ +agent_readme.md +target.md +.gocache +.gocache/ +.tmp_*/ +.idea \ No newline at end of file diff --git a/circle.go b/circle.go index c8932ad..6d96c4f 100644 --- a/circle.go +++ b/circle.go @@ -2,151 +2,191 @@ package stario import ( "errors" - "fmt" "io" - "os" - "runtime" "sync" - "sync/atomic" - "time" ) +var ErrStarBufferInvalidCapacity = errors.New("star buffer capacity must be greater than zero") +var ErrStarBufferClosed = errors.New("star buffer closed") +var ErrStarBufferWriteClosed = errors.New("star buffer write closed") + type StarBuffer struct { - io.Reader - io.Writer - io.Closer - datas []byte - pStart uint64 - pEnd uint64 - cap uint64 - isClose atomic.Value - isEnd atomic.Value - rmu sync.Mutex - wmu sync.Mutex + datas []byte + pStart uint64 + pEnd uint64 + size uint64 + cap uint64 + isClose bool + isWriteEnd bool + mu sync.Mutex + notEmpty *sync.Cond + notFull *sync.Cond } -func NewStarBuffer(cap uint64) *StarBuffer { - rtnBuffer := new(StarBuffer) - rtnBuffer.cap = cap - rtnBuffer.datas = make([]byte, cap) - rtnBuffer.isClose.Store(false) - rtnBuffer.isEnd.Store(false) - return rtnBuffer +func NewStarBuffer(cap uint64) (*StarBuffer, error) { + if cap == 0 { + return nil, ErrStarBufferInvalidCapacity + } + rtnBuffer := &StarBuffer{ + cap: cap, + datas: make([]byte, cap), + } + rtnBuffer.notEmpty = sync.NewCond(&rtnBuffer.mu) + rtnBuffer.notFull = sync.NewCond(&rtnBuffer.mu) + return rtnBuffer, nil } func (star *StarBuffer) Free() uint64 { - return star.cap - star.Len() + star.mu.Lock() + defer star.mu.Unlock() + return star.cap - star.size } func (star *StarBuffer) Cap() uint64 { + star.mu.Lock() + defer star.mu.Unlock() return star.cap } func (star *StarBuffer) Len() uint64 { - if star.pEnd >= star.pStart { - return star.pEnd - star.pStart - } - return star.pEnd - star.pStart + star.cap + star.mu.Lock() + defer star.mu.Unlock() + return star.size } func (star *StarBuffer) getByte() (byte, error) { - if star.isClose.Load().(bool) || (star.Len() == 0 && star.isEnd.Load().(bool)) { + star.mu.Lock() + defer star.mu.Unlock() + for star.size == 0 && !star.isWriteEnd && !star.isClose { + star.notEmpty.Wait() + } + if star.size == 0 { return 0, io.EOF } - if star.Len() == 0 { - return 0, os.ErrNotExist - } - nowPtr := star.pStart - nextPtr := star.pStart + 1 - if nextPtr >= star.cap { - nextPtr = 0 - } - data := star.datas[nowPtr] - ok := atomic.CompareAndSwapUint64(&star.pStart, nowPtr, nextPtr) - if !ok { - return 0, os.ErrInvalid - } + data := star.datas[star.pStart] + star.pStart = (star.pStart + 1) % star.cap + star.size-- + star.notFull.Broadcast() return data, nil } func (star *StarBuffer) putByte(data byte) error { - if star.isClose.Load().(bool) { - return io.EOF + star.mu.Lock() + defer star.mu.Unlock() + for star.size == star.cap && !star.isClose && !star.isWriteEnd { + star.notFull.Wait() } - nowPtr := star.pEnd - kariEnd := nowPtr + 1 - if kariEnd == star.cap { - kariEnd = 0 + if star.isClose { + return ErrStarBufferClosed } - if kariEnd == atomic.LoadUint64(&star.pStart) { - for { - time.Sleep(time.Microsecond) - runtime.Gosched() - if kariEnd != atomic.LoadUint64(&star.pStart) { - break - } - } + if star.isWriteEnd { + return ErrStarBufferWriteClosed } - star.datas[nowPtr] = data - if ok := atomic.CompareAndSwapUint64(&star.pEnd, nowPtr, kariEnd); !ok { - return os.ErrInvalid + star.datas[star.pEnd] = data + star.pEnd = (star.pEnd + 1) % star.cap + star.size++ + star.notEmpty.Broadcast() + return nil +} + +func (star *StarBuffer) EndWrite() error { + star.mu.Lock() + defer star.mu.Unlock() + if star.isClose { + return ErrStarBufferClosed } + star.isWriteEnd = true + star.notEmpty.Broadcast() + star.notFull.Broadcast() return nil } func (star *StarBuffer) Close() error { - star.isClose.Store(true) + star.mu.Lock() + defer star.mu.Unlock() + star.isClose = true + star.isWriteEnd = true + star.notEmpty.Broadcast() + star.notFull.Broadcast() return nil } + func (star *StarBuffer) Read(buf []byte) (int, error) { - if star.isClose.Load().(bool) || (star.Len() == 0 && star.isEnd.Load().(bool)) { - return 0, io.EOF - } if buf == nil { return 0, errors.New("buffer is nil") } - star.rmu.Lock() - defer star.rmu.Unlock() - var sum int = 0 - for i := 0; i < len(buf); i++ { - data, err := star.getByte() - if err != nil { - if err == io.EOF { - if sum == 0 { - return sum, err - } - return sum, nil - } - if err == os.ErrNotExist { - i-- - continue - } - return sum, nil - } - buf[i] = data - sum++ + if len(buf) == 0 { + return 0, nil } + star.mu.Lock() + defer star.mu.Unlock() + for star.size == 0 && !star.isWriteEnd && !star.isClose { + star.notEmpty.Wait() + } + if star.size == 0 { + return 0, io.EOF + } + sum := minInt(len(buf), int(star.size)) + first := minInt(sum, int(star.cap-star.pStart)) + copy(buf, star.datas[star.pStart:star.pStart+uint64(first)]) + second := sum - first + if second > 0 { + copy(buf[first:], star.datas[:second]) + } + star.pStart = (star.pStart + uint64(sum)) % star.cap + star.size -= uint64(sum) + star.notFull.Broadcast() return sum, nil } func (star *StarBuffer) Write(bts []byte) (int, error) { - if bts == nil && !star.isEnd.Load().(bool) { - star.isEnd.Store(true) + if bts == nil { + return 0, star.EndWrite() + } + if len(bts) == 0 { return 0, nil } - if bts == nil || star.isClose.Load().(bool) { - return 0, io.EOF - } - star.wmu.Lock() - defer star.wmu.Unlock() - var sum = 0 - for i := 0; i < len(bts); i++ { - err := star.putByte(bts[i]) - if err != nil { - fmt.Println("Write bts err:", err) - return sum, err + star.mu.Lock() + defer star.mu.Unlock() + sum := 0 + for sum < len(bts) { + for star.size == star.cap && !star.isClose && !star.isWriteEnd { + star.notFull.Wait() } - sum++ + if star.isClose { + if sum == 0 { + return 0, ErrStarBufferClosed + } + return sum, ErrStarBufferClosed + } + if star.isWriteEnd { + if sum == 0 { + return 0, ErrStarBufferWriteClosed + } + return sum, ErrStarBufferWriteClosed + } + space := int(star.cap - star.size) + if space == 0 { + continue + } + n := minInt(len(bts)-sum, space) + first := minInt(n, int(star.cap-star.pEnd)) + copy(star.datas[star.pEnd:star.pEnd+uint64(first)], bts[sum:sum+first]) + second := n - first + if second > 0 { + copy(star.datas[:second], bts[sum+first:sum+n]) + } + star.pEnd = (star.pEnd + uint64(n)) % star.cap + star.size += uint64(n) + sum += n + star.notEmpty.Broadcast() } return sum, nil } + +func minInt(a int, b int) int { + if a < b { + return a + } + return b +} diff --git a/circle_test.go b/circle_test.go index 3b3bef9..cfec5bd 100644 --- a/circle_test.go +++ b/circle_test.go @@ -1,14 +1,91 @@ package stario import ( + "bytes" "fmt" + "io" "sync/atomic" "testing" "time" ) +func TestNewStarBufferRejectsZeroCapacity(t *testing.T) { + buf, err := NewStarBuffer(0) + if err != ErrStarBufferInvalidCapacity { + t.Fatalf("unexpected error: %v", err) + } + if buf != nil { + t.Fatal("expected nil buffer when capacity is invalid") + } +} + +func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) { + buf, err := NewStarBuffer(4) + if err != nil { + t.Fatal(err) + } + if _, err := buf.Write([]byte("abcd")); err != nil { + t.Fatal(err) + } + if err := buf.EndWrite(); err != nil { + t.Fatal(err) + } + + got := make([]byte, 4) + n, err := buf.Read(got) + if err != nil { + t.Fatal(err) + } + if n != 4 || !bytes.Equal(got[:n], []byte("abcd")) { + t.Fatalf("unexpected payload: n=%d data=%q", n, got[:n]) + } + + n, err = buf.Read(got) + if n != 0 || err != io.EOF { + t.Fatalf("expected EOF after draining buffer, got n=%d err=%v", n, err) + } + + if _, err := buf.Write([]byte("x")); err != ErrStarBufferWriteClosed { + t.Fatalf("unexpected write error after EndWrite: %v", err) + } +} + +func TestStarBufferCloseAllowsDrain(t *testing.T) { + buf, err := NewStarBuffer(4) + if err != nil { + t.Fatal(err) + } + if _, err := buf.Write([]byte("ab")); err != nil { + t.Fatal(err) + } + if err := buf.Close(); err != nil { + t.Fatal(err) + } + + got := make([]byte, 2) + n, err := buf.Read(got) + if err != nil { + t.Fatal(err) + } + if n != 2 || !bytes.Equal(got[:n], []byte("ab")) { + t.Fatalf("unexpected payload after close: n=%d data=%q", n, got[:n]) + } + + n, err = buf.Read(got) + if n != 0 || err != io.EOF { + t.Fatalf("expected EOF after draining closed buffer, got n=%d err=%v", n, err) + } + + if _, err := buf.Write([]byte("x")); err != ErrStarBufferClosed { + t.Fatalf("unexpected write error after Close: %v", err) + } +} + func Test_Circle(t *testing.T) { - buf := NewStarBuffer(2048) + buf, err := NewStarBuffer(2048) + if err != nil { + t.Fatal(err) + } go func() { for { //fmt.Println("write start") @@ -37,7 +114,10 @@ func Test_Circle(t *testing.T) { } func Test_Circle_Speed(t *testing.T) { - buf := NewStarBuffer(1048976) + buf, err := NewStarBuffer(1048976) + if err != nil { + t.Fatal(err) + } count := uint64(0) for i := 1; i <= 10; i++ { go func() { @@ -61,7 +141,10 @@ func Test_Circle_Speed(t *testing.T) { } func Test_Circle_Speed2(t *testing.T) { - buf := NewStarBuffer(8192) + buf, err := NewStarBuffer(8192) + if err != nil { + t.Fatal(err) + } count := uint64(0) for i := 1; i <= 10; i++ { go func() { diff --git a/io.go b/io.go index c397793..12ec3f4 100644 --- a/io.go +++ b/io.go @@ -16,12 +16,20 @@ type InputMsg struct { skipSliceSigErr bool } +type rawTerminalSession struct { + fd int + state *terminal.State + reader *bufio.Reader + redrawHint string + printNewline bool +} + func Passwd(hint string, defaultVal string) InputMsg { - return passwd(hint, defaultVal, "", false) + return passwd(hint, defaultVal, "", true) } func PasswdWithMask(hint string, defaultVal string, mask string) InputMsg { - return passwd(hint, defaultVal, mask, false) + return passwd(hint, defaultVal, mask, true) } func PasswdResponseSignal(hint string, defaultVal string) InputMsg { @@ -36,54 +44,162 @@ func MessageBoxRaw(hint string, defaultVal string) InputMsg { return messageBox(hint, defaultVal) } -func messageBox(hint string, defaultVal string) InputMsg { - var ioBuf []rune +func newRawTerminalSession(hint string, printNewline bool) (*rawTerminalSession, error) { if hint != "" { fmt.Print(hint) } - if strings.Index(hint, "\n") >= 0 { - hint = strings.TrimSpace(hint[strings.LastIndex(hint, "\n"):]) - } fd := int(os.Stdin.Fd()) state, err := terminal.MakeRaw(fd) if err != nil { - return InputMsg{msg: "", err: err} + return nil, err } - defer fmt.Println() - defer terminal.Restore(fd, state) - inputReader := bufio.NewReader(os.Stdin) - for { - b, _, err := inputReader.ReadRune() - if err != nil { - return InputMsg{msg: "", err: err} - } - if b == 0x0d { - strValue := strings.TrimSpace(string(ioBuf)) - if len(strValue) == 0 { - strValue = defaultVal - } - return InputMsg{msg: strValue, err: err} - } - if b == 0x08 || b == 0x7F { - if len(ioBuf) > 0 { - ioBuf = ioBuf[:len(ioBuf)-1] - } - fmt.Print("\r") - for i := 0; i < len(ioBuf)+2+len(hint); i++ { - fmt.Print(" ") - } - } else { - ioBuf = append(ioBuf, b) - } - fmt.Print("\r") - if hint != "" { - fmt.Print(hint) - } - fmt.Print(string(ioBuf)) + return &rawTerminalSession{ + fd: fd, + state: state, + reader: bufio.NewReader(os.Stdin), + redrawHint: promptRedrawHint(hint), + printNewline: printNewline, + }, nil +} + +func (session *rawTerminalSession) Close() { + if session == nil || session.state == nil { + return + } + _ = terminal.Restore(session.fd, session.state) + if session.printNewline { + fmt.Println() } } -func isSiganl(s rune) bool { +func (session *rawTerminalSession) Restore() error { + if session == nil || session.state == nil { + return nil + } + return terminal.Restore(session.fd, session.state) +} + +func promptRedrawHint(hint string) string { + if strings.Index(hint, "\n") >= 0 { + hint = hint[strings.LastIndex(hint, "\n")+1:] + } + return strings.TrimSpace(hint) +} + +func finalizeInputValue(raw string, defaultVal string) string { + raw = strings.TrimSpace(raw) + if len(raw) == 0 { + return defaultVal + } + return raw +} + +func renderRawEcho(ioBuf []rune, mask string) string { + if mask == "" { + return string(ioBuf) + } + return strings.Repeat(mask, len(ioBuf)) +} + +func redrawPromptLine(hint string, echo string, lastWidth int) int { + nowWidth := stringDisplayWidth(hint) + stringDisplayWidth(echo) + clearWidth := lastWidth + if nowWidth > clearWidth { + clearWidth = nowWidth + } + fmt.Print("\r") + if clearWidth > 0 { + fmt.Print(strings.Repeat(" ", clearWidth)) + fmt.Print("\r") + } + if hint != "" { + fmt.Print(hint) + } + if echo != "" { + fmt.Print(echo) + } + return nowWidth +} + +func stringDisplayWidth(text string) int { + width := 0 + for _, r := range text { + width += runeDisplayWidth(r) + } + return width +} + +func runeDisplayWidth(r rune) int { + switch { + case r == '\t': + return 4 + case r < 0x20 || (r >= 0x7f && r < 0xa0): + return 0 + case isWideRune(r): + return 2 + default: + return 1 + } +} + +func isWideRune(r rune) bool { + return r >= 0x1100 && (r <= 0x115f || + r == 0x2329 || r == 0x232a || + (r >= 0x2e80 && r <= 0xa4cf && r != 0x303f) || + (r >= 0xac00 && r <= 0xd7a3) || + (r >= 0xf900 && r <= 0xfaff) || + (r >= 0xfe10 && r <= 0xfe19) || + (r >= 0xfe30 && r <= 0xfe6f) || + (r >= 0xff00 && r <= 0xff60) || + (r >= 0xffe0 && r <= 0xffe6) || + (r >= 0x1f300 && r <= 0x1f64f) || + (r >= 0x1f900 && r <= 0x1f9ff) || + (r >= 0x20000 && r <= 0x3fffd)) +} + +func rawLineInput(hint string, defaultVal string, mask string, handleSignal bool) InputMsg { + session, err := newRawTerminalSession(hint, true) + if err != nil { + return InputMsg{msg: "", err: err} + } + defer session.Close() + ioBuf := make([]rune, 0, 16) + lastWidth := 0 + for { + b, _, err := session.reader.ReadRune() + if err != nil { + return InputMsg{msg: "", err: err} + } + if handleSignal && isSignal(b) { + if runtime.GOOS != "windows" { + if err := session.Restore(); err != nil { + return InputMsg{msg: "", err: err} + } + } + if err := signal(b); err != nil { + return InputMsg{msg: "", err: err} + } + continue + } + switch b { + case 0x0d, 0x0a: + return InputMsg{msg: finalizeInputValue(string(ioBuf), defaultVal), err: nil} + case 0x08, 0x7F: + if len(ioBuf) > 0 { + ioBuf = ioBuf[:len(ioBuf)-1] + } + default: + ioBuf = append(ioBuf, b) + } + lastWidth = redrawPromptLine(session.redrawHint, renderRawEcho(ioBuf, mask), lastWidth) + } +} + +func messageBox(hint string, defaultVal string) InputMsg { + return rawLineInput(hint, defaultVal, "", false) +} + +func isSignal(s rune) bool { switch s { case 0x03, 0x1a, 0x1c: return true @@ -93,63 +209,7 @@ func isSiganl(s rune) bool { } func passwd(hint string, defaultVal string, mask string, handleSignal bool) InputMsg { - var ioBuf []rune - if hint != "" { - fmt.Print(hint) - } - if strings.Index(hint, "\n") >= 0 { - hint = strings.TrimSpace(hint[strings.LastIndex(hint, "\n"):]) - } - fd := int(os.Stdin.Fd()) - state, err := terminal.MakeRaw(fd) - if err != nil { - return InputMsg{msg: "", err: err} - } - defer fmt.Println() - defer terminal.Restore(fd, state) - inputReader := bufio.NewReader(os.Stdin) - for { - b, _, err := inputReader.ReadRune() - if err != nil { - return InputMsg{msg: "", err: err} - } - if handleSignal && isSiganl(b) { - if runtime.GOOS != "windows" { - terminal.Restore(fd, state) - } - if err := signal(b); err != nil { - return InputMsg{ - msg: "", - err: err, - } - } - } - if b == 0x0d { - strValue := strings.TrimSpace(string(ioBuf)) - if len(strValue) == 0 { - strValue = defaultVal - } - return InputMsg{msg: strValue, err: err} - } - if b == 0x08 || b == 0x7F { - if len(ioBuf) > 0 { - ioBuf = ioBuf[:len(ioBuf)-1] - } - fmt.Print("\r") - for i := 0; i < len(ioBuf)+2+len(hint); i++ { - fmt.Print(" ") - } - } else { - ioBuf = append(ioBuf, b) - } - fmt.Print("\r") - if hint != "" { - fmt.Print(hint) - } - for i := 0; i < len(ioBuf); i++ { - fmt.Print(mask) - } - } + return rawLineInput(hint, defaultVal, mask, handleSignal) } func MessageBox(hint string, defaultVal string) InputMsg { @@ -161,10 +221,7 @@ func MessageBox(hint string, defaultVal string) InputMsg { if err != nil { return InputMsg{msg: str, err: err} } - str = strings.TrimSpace(str) - if len(str) == 0 { - str = defaultVal - } + str = finalizeInputValue(str, defaultVal) return InputMsg{msg: str, err: err} } @@ -386,46 +443,61 @@ func (im InputMsg) MustSliceFloat32(sep string) []float32 { } func YesNo(hint string, defaults bool) bool { + res, err := YesNoE(hint, defaults) + if err != nil { + return defaults + } + return res +} + +func parseYesNoValue(raw string, defaults bool) (bool, bool) { + raw = strings.TrimSpace(strings.ToUpper(raw)) + if raw == "" { + return defaults, true + } + switch []rune(raw)[0] { + case 'Y': + return true, true + case 'N': + return false, true + default: + return false, false + } +} + +func YesNoE(hint string, defaults bool) (bool, error) { for { - res := strings.ToUpper(MessageBox(hint, "").MustString()) - if res == "" { - return defaults + res, err := MessageBox(hint, "").String() + if err != nil { + return false, err } - res = res[0:1] - if res == "Y" { - return true - } else if res == "N" { - return false + if answer, ok := parseYesNoValue(res, defaults); ok { + return answer, nil } } } func StopUntil(hint string, trigger string, repeat bool) error { - pressLen := len([]rune(trigger)) + triggerRunes := []rune(trigger) + pressLen := len(triggerRunes) if trigger == "" { pressLen = 1 } - fd := int(os.Stdin.Fd()) - if hint != "" { - fmt.Print(hint) - } - state, err := terminal.MakeRaw(fd) + session, err := newRawTerminalSession(hint, false) if err != nil { return err } - defer terminal.Restore(fd, state) - inputReader := bufio.NewReader(os.Stdin) - //ioBuf := make([]byte, pressLen) + defer session.Close() i := 0 for { - b, _, err := inputReader.ReadRune() + b, _, err := session.reader.ReadRune() if err != nil { return err } if trigger == "" { break } - if b == []rune(trigger)[i] { + if b == triggerRunes[i] { i++ if i == pressLen { break diff --git a/io_test.go b/io_test.go index 5a5df1c..f42af27 100644 --- a/io_test.go +++ b/io_test.go @@ -5,6 +5,41 @@ import ( "testing" ) +func TestPromptRedrawHint(t *testing.T) { + got := promptRedrawHint("头部提示\n 中文确认: ") + if got != "中文确认:" { + t.Fatalf("unexpected redraw hint: got %q", got) + } +} + +func TestStringDisplayWidth(t *testing.T) { + got := stringDisplayWidth("中a[]") + if got != 5 { + t.Fatalf("unexpected display width: got %d want 5", got) + } +} + +func TestParseYesNoValue(t *testing.T) { + cases := []struct { + name string + input string + defaults bool + want bool + ok bool + }{ + {name: "default", input: " ", defaults: true, want: true, ok: true}, + {name: "yes", input: "yes", defaults: false, want: true, ok: true}, + {name: "no", input: "No", defaults: true, want: false, ok: true}, + {name: "invalid", input: "maybe", defaults: false, want: false, ok: false}, + } + for _, tc := range cases { + got, ok := parseYesNoValue(tc.input, tc.defaults) + if got != tc.want || ok != tc.ok { + t.Fatalf("%s: got (%v, %v) want (%v, %v)", tc.name, got, ok, tc.want, tc.ok) + } + } +} + func Test_Slice(t *testing.T) { var data = InputMsg{ msg: "true,false,true,true,false,0,1,hello", diff --git a/sync.go b/sync.go index 4c3594a..4ca92a0 100644 --- a/sync.go +++ b/sync.go @@ -1,55 +1,142 @@ package stario import ( + "fmt" "sync" - "sync/atomic" - "time" +) + +type waitGroupAddMode uint8 + +const ( + waitGroupAddModeStrict waitGroupAddMode = iota + waitGroupAddModeLoose ) type WaitGroup struct { - wg *sync.WaitGroup - maxCount uint32 - allCount uint32 + wg sync.WaitGroup + mu sync.Mutex + cond *sync.Cond + initOnce sync.Once + maxCount int + running int + addMode waitGroupAddMode } func NewWaitGroup(maxCount int) WaitGroup { - return WaitGroup{wg: &sync.WaitGroup{}, maxCount: uint32(maxCount)} + if maxCount < 0 { + panic("stario: negative max wait count") + } + return WaitGroup{ + maxCount: maxCount, + addMode: waitGroupAddModeStrict, + } +} + +func (w *WaitGroup) init() { + w.initOnce.Do(func() { + w.cond = sync.NewCond(&w.mu) + }) } func (w *WaitGroup) Add(delta int) { - var Udelta uint32 + w.init() + if delta == 0 { + return + } if delta < 0 { - Udelta = uint32(-delta - 1) - } else { - Udelta = uint32(delta) + w.release(-delta) + return } - for { - allC := atomic.LoadUint32(&w.allCount) - if atomic.LoadUint32(&w.maxCount) == 0 || atomic.LoadUint32(&w.maxCount) >= allC+uint32(delta) { - if delta < 0 { - atomic.AddUint32(&w.allCount, ^uint32(Udelta)) - } else { - atomic.AddUint32(&w.allCount, uint32(Udelta)) - } - break + w.acquire(delta) +} + +func (w *WaitGroup) acquire(delta int) { + w.mu.Lock() + defer w.mu.Unlock() + if w.maxCount <= 0 { + w.running += delta + w.wg.Add(delta) + return + } + if delta == 1 { + w.wg.Add(1) + for w.maxCount > 0 && w.running >= w.maxCount { + w.cond.Wait() } - time.Sleep(time.Microsecond) + w.running++ + return } + if w.running+delta > w.maxCount { + if w.addMode == waitGroupAddModeStrict { + panic(fmt.Sprintf("stario: WaitGroup.Add(%d) exceeds max limit %d with %d running", delta, w.maxCount, w.running)) + } + w.maxCount = w.running + delta + } + w.running += delta w.wg.Add(delta) } +func (w *WaitGroup) release(delta int) { + w.mu.Lock() + defer w.mu.Unlock() + if delta > w.running { + panic(fmt.Sprintf("stario: WaitGroup.Done releases %d tasks but only %d running", delta, w.running)) + } + w.wg.Add(-delta) + w.running -= delta + w.cond.Broadcast() +} + func (w *WaitGroup) Done() { w.Add(-1) } +func (w *WaitGroup) Go(fn func()) { + w.Add(1) + go func() { + defer w.Done() + fn() + }() +} + func (w *WaitGroup) Wait() { + w.init() w.wg.Wait() } func (w *WaitGroup) GetMaxWaitNum() int { - return int(atomic.LoadUint32(&w.maxCount)) + w.init() + w.mu.Lock() + defer w.mu.Unlock() + return w.maxCount } func (w *WaitGroup) SetMaxWaitNum(num int) { - atomic.AddUint32(&w.maxCount, uint32(num)) + if num < 0 { + panic("stario: negative max wait count") + } + w.init() + w.mu.Lock() + w.maxCount = num + w.mu.Unlock() + w.cond.Broadcast() +} + +func (w *WaitGroup) SetStrictAddMode(strict bool) { + w.init() + w.mu.Lock() + if strict { + w.addMode = waitGroupAddModeStrict + } else { + w.addMode = waitGroupAddModeLoose + } + w.mu.Unlock() + w.cond.Broadcast() +} + +func (w *WaitGroup) StrictAddMode() bool { + w.init() + w.mu.Lock() + defer w.mu.Unlock() + return w.addMode == waitGroupAddModeStrict } diff --git a/sync_test.go b/sync_test.go new file mode 100644 index 0000000..180808c --- /dev/null +++ b/sync_test.go @@ -0,0 +1,171 @@ +package stario + +import ( + "testing" + "time" +) + +func TestWaitGroupAddBlocksAtLimit(t *testing.T) { + wg := NewWaitGroup(1) + wg.Add(1) + + unblocked := make(chan struct{}) + go func() { + wg.Add(1) + close(unblocked) + wg.Done() + }() + + select { + case <-unblocked: + t.Fatal("Add(1) should block while running count is at the limit") + case <-time.After(30 * time.Millisecond): + } + + wg.Done() + + select { + case <-unblocked: + case <-time.After(200 * time.Millisecond): + t.Fatal("Add(1) did not resume after capacity was released") + } + + wg.Wait() +} + +func TestWaitGroupStrictAddPanicsWhenBatchExceedsLimit(t *testing.T) { + wg := NewWaitGroup(1) + wg.Add(1) + defer wg.Done() + + defer func() { + if recover() == nil { + t.Fatal("expected Add(2) to panic in strict mode when it exceeds the limit") + } + }() + + wg.Add(2) +} + +func TestWaitGroupLooseAddExpandsLimit(t *testing.T) { + wg := NewWaitGroup(1) + wg.SetStrictAddMode(false) + + wg.Add(1) + wg.Add(2) + + if got := wg.GetMaxWaitNum(); got != 3 { + t.Fatalf("unexpected max count: got %d want 3", got) + } + + for i := 0; i < 3; i++ { + wg.Done() + } + wg.Wait() +} + +func TestWaitGroupSetMaxWaitNumOverridesValue(t *testing.T) { + wg := NewWaitGroup(1) + wg.SetMaxWaitNum(4) + if got := wg.GetMaxWaitNum(); got != 4 { + t.Fatalf("unexpected max count: got %d want 4", got) + } +} + +func TestWaitGroupSetMaxWaitNumZeroUnblocksWaitingAdd(t *testing.T) { + wg := NewWaitGroup(1) + wg.Add(1) + + unblocked := make(chan struct{}) + go func() { + wg.Add(1) + close(unblocked) + wg.Done() + }() + + select { + case <-unblocked: + t.Fatal("Add(1) should still be blocked before removing the limit") + case <-time.After(30 * time.Millisecond): + } + + wg.SetMaxWaitNum(0) + + select { + case <-unblocked: + case <-time.After(200 * time.Millisecond): + t.Fatal("Add(1) did not resume after the limit was removed") + } + + wg.Done() + wg.Wait() +} + +func TestWaitGroupGoWaitsForCompletion(t *testing.T) { + wg := NewWaitGroup(1) + done := make(chan struct{}) + + wg.Go(func() { + close(done) + }) + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Fatal("Go(fn) did not run the function") + } + + finished := make(chan struct{}) + go func() { + wg.Wait() + close(finished) + }() + + select { + case <-finished: + case <-time.After(200 * time.Millisecond): + t.Fatal("Wait did not observe Go(fn) completion") + } +} + +func TestWaitGroupGoRespectsLimit(t *testing.T) { + wg := NewWaitGroup(1) + releaseFirst := make(chan struct{}) + firstStarted := make(chan struct{}) + secondStarted := make(chan struct{}) + secondReturned := make(chan struct{}) + + wg.Go(func() { + close(firstStarted) + <-releaseFirst + }) + + select { + case <-firstStarted: + case <-time.After(200 * time.Millisecond): + t.Fatal("first Go(fn) did not start") + } + + go func() { + wg.Go(func() { + close(secondStarted) + }) + close(secondReturned) + }() + + select { + case <-secondReturned: + t.Fatal("second Go(fn) should block while the limit is full") + case <-time.After(30 * time.Millisecond): + } + + close(releaseFirst) + + select { + case <-secondStarted: + case <-time.After(200 * time.Millisecond): + t.Fatal("second Go(fn) did not start after capacity was released") + } + + wg.Wait() +}