fix:修复一些边界性问题

This commit is contained in:
兔子 2026-03-21 19:15:43 +08:00
parent 3c83d1d39f
commit 3add9183b3
Signed by: b612
GPG Key ID: 99DD2222B612B612
7 changed files with 737 additions and 242 deletions

7
.gitignore vendored Normal file
View File

@ -0,0 +1,7 @@
.sentrux/
agent_readme.md
target.md
.gocache
.gocache/
.tmp_*/
.idea

234
circle.go
View File

@ -2,151 +2,191 @@ package stario
import (
"errors"
"fmt"
"io"
"os"
"runtime"
"sync"
"sync/atomic"
"time"
)
var ErrStarBufferInvalidCapacity = errors.New("star buffer capacity must be greater than zero")
var ErrStarBufferClosed = errors.New("star buffer closed")
var ErrStarBufferWriteClosed = errors.New("star buffer write closed")
type StarBuffer struct {
io.Reader
io.Writer
io.Closer
datas []byte
pStart uint64
pEnd uint64
cap uint64
isClose atomic.Value
isEnd atomic.Value
rmu sync.Mutex
wmu sync.Mutex
datas []byte
pStart uint64
pEnd uint64
size uint64
cap uint64
isClose bool
isWriteEnd bool
mu sync.Mutex
notEmpty *sync.Cond
notFull *sync.Cond
}
func NewStarBuffer(cap uint64) *StarBuffer {
rtnBuffer := new(StarBuffer)
rtnBuffer.cap = cap
rtnBuffer.datas = make([]byte, cap)
rtnBuffer.isClose.Store(false)
rtnBuffer.isEnd.Store(false)
return rtnBuffer
func NewStarBuffer(cap uint64) (*StarBuffer, error) {
if cap == 0 {
return nil, ErrStarBufferInvalidCapacity
}
rtnBuffer := &StarBuffer{
cap: cap,
datas: make([]byte, cap),
}
rtnBuffer.notEmpty = sync.NewCond(&rtnBuffer.mu)
rtnBuffer.notFull = sync.NewCond(&rtnBuffer.mu)
return rtnBuffer, nil
}
func (star *StarBuffer) Free() uint64 {
return star.cap - star.Len()
star.mu.Lock()
defer star.mu.Unlock()
return star.cap - star.size
}
func (star *StarBuffer) Cap() uint64 {
star.mu.Lock()
defer star.mu.Unlock()
return star.cap
}
func (star *StarBuffer) Len() uint64 {
if star.pEnd >= star.pStart {
return star.pEnd - star.pStart
}
return star.pEnd - star.pStart + star.cap
star.mu.Lock()
defer star.mu.Unlock()
return star.size
}
func (star *StarBuffer) getByte() (byte, error) {
if star.isClose.Load().(bool) || (star.Len() == 0 && star.isEnd.Load().(bool)) {
star.mu.Lock()
defer star.mu.Unlock()
for star.size == 0 && !star.isWriteEnd && !star.isClose {
star.notEmpty.Wait()
}
if star.size == 0 {
return 0, io.EOF
}
if star.Len() == 0 {
return 0, os.ErrNotExist
}
nowPtr := star.pStart
nextPtr := star.pStart + 1
if nextPtr >= star.cap {
nextPtr = 0
}
data := star.datas[nowPtr]
ok := atomic.CompareAndSwapUint64(&star.pStart, nowPtr, nextPtr)
if !ok {
return 0, os.ErrInvalid
}
data := star.datas[star.pStart]
star.pStart = (star.pStart + 1) % star.cap
star.size--
star.notFull.Broadcast()
return data, nil
}
func (star *StarBuffer) putByte(data byte) error {
if star.isClose.Load().(bool) {
return io.EOF
star.mu.Lock()
defer star.mu.Unlock()
for star.size == star.cap && !star.isClose && !star.isWriteEnd {
star.notFull.Wait()
}
nowPtr := star.pEnd
kariEnd := nowPtr + 1
if kariEnd == star.cap {
kariEnd = 0
if star.isClose {
return ErrStarBufferClosed
}
if kariEnd == atomic.LoadUint64(&star.pStart) {
for {
time.Sleep(time.Microsecond)
runtime.Gosched()
if kariEnd != atomic.LoadUint64(&star.pStart) {
break
}
}
if star.isWriteEnd {
return ErrStarBufferWriteClosed
}
star.datas[nowPtr] = data
if ok := atomic.CompareAndSwapUint64(&star.pEnd, nowPtr, kariEnd); !ok {
return os.ErrInvalid
star.datas[star.pEnd] = data
star.pEnd = (star.pEnd + 1) % star.cap
star.size++
star.notEmpty.Broadcast()
return nil
}
func (star *StarBuffer) EndWrite() error {
star.mu.Lock()
defer star.mu.Unlock()
if star.isClose {
return ErrStarBufferClosed
}
star.isWriteEnd = true
star.notEmpty.Broadcast()
star.notFull.Broadcast()
return nil
}
func (star *StarBuffer) Close() error {
star.isClose.Store(true)
star.mu.Lock()
defer star.mu.Unlock()
star.isClose = true
star.isWriteEnd = true
star.notEmpty.Broadcast()
star.notFull.Broadcast()
return nil
}
func (star *StarBuffer) Read(buf []byte) (int, error) {
if star.isClose.Load().(bool) || (star.Len() == 0 && star.isEnd.Load().(bool)) {
return 0, io.EOF
}
if buf == nil {
return 0, errors.New("buffer is nil")
}
star.rmu.Lock()
defer star.rmu.Unlock()
var sum int = 0
for i := 0; i < len(buf); i++ {
data, err := star.getByte()
if err != nil {
if err == io.EOF {
if sum == 0 {
return sum, err
}
return sum, nil
}
if err == os.ErrNotExist {
i--
continue
}
return sum, nil
}
buf[i] = data
sum++
if len(buf) == 0 {
return 0, nil
}
star.mu.Lock()
defer star.mu.Unlock()
for star.size == 0 && !star.isWriteEnd && !star.isClose {
star.notEmpty.Wait()
}
if star.size == 0 {
return 0, io.EOF
}
sum := minInt(len(buf), int(star.size))
first := minInt(sum, int(star.cap-star.pStart))
copy(buf, star.datas[star.pStart:star.pStart+uint64(first)])
second := sum - first
if second > 0 {
copy(buf[first:], star.datas[:second])
}
star.pStart = (star.pStart + uint64(sum)) % star.cap
star.size -= uint64(sum)
star.notFull.Broadcast()
return sum, nil
}
func (star *StarBuffer) Write(bts []byte) (int, error) {
if bts == nil && !star.isEnd.Load().(bool) {
star.isEnd.Store(true)
if bts == nil {
return 0, star.EndWrite()
}
if len(bts) == 0 {
return 0, nil
}
if bts == nil || star.isClose.Load().(bool) {
return 0, io.EOF
}
star.wmu.Lock()
defer star.wmu.Unlock()
var sum = 0
for i := 0; i < len(bts); i++ {
err := star.putByte(bts[i])
if err != nil {
fmt.Println("Write bts err:", err)
return sum, err
star.mu.Lock()
defer star.mu.Unlock()
sum := 0
for sum < len(bts) {
for star.size == star.cap && !star.isClose && !star.isWriteEnd {
star.notFull.Wait()
}
sum++
if star.isClose {
if sum == 0 {
return 0, ErrStarBufferClosed
}
return sum, ErrStarBufferClosed
}
if star.isWriteEnd {
if sum == 0 {
return 0, ErrStarBufferWriteClosed
}
return sum, ErrStarBufferWriteClosed
}
space := int(star.cap - star.size)
if space == 0 {
continue
}
n := minInt(len(bts)-sum, space)
first := minInt(n, int(star.cap-star.pEnd))
copy(star.datas[star.pEnd:star.pEnd+uint64(first)], bts[sum:sum+first])
second := n - first
if second > 0 {
copy(star.datas[:second], bts[sum+first:sum+n])
}
star.pEnd = (star.pEnd + uint64(n)) % star.cap
star.size += uint64(n)
sum += n
star.notEmpty.Broadcast()
}
return sum, nil
}
func minInt(a int, b int) int {
if a < b {
return a
}
return b
}

View File

@ -1,14 +1,91 @@
package stario
import (
"bytes"
"fmt"
"io"
"sync/atomic"
"testing"
"time"
)
func TestNewStarBufferRejectsZeroCapacity(t *testing.T) {
buf, err := NewStarBuffer(0)
if err != ErrStarBufferInvalidCapacity {
t.Fatalf("unexpected error: %v", err)
}
if buf != nil {
t.Fatal("expected nil buffer when capacity is invalid")
}
}
func TestStarBufferEndWriteDrainsThenEOF(t *testing.T) {
buf, err := NewStarBuffer(4)
if err != nil {
t.Fatal(err)
}
if _, err := buf.Write([]byte("abcd")); err != nil {
t.Fatal(err)
}
if err := buf.EndWrite(); err != nil {
t.Fatal(err)
}
got := make([]byte, 4)
n, err := buf.Read(got)
if err != nil {
t.Fatal(err)
}
if n != 4 || !bytes.Equal(got[:n], []byte("abcd")) {
t.Fatalf("unexpected payload: n=%d data=%q", n, got[:n])
}
n, err = buf.Read(got)
if n != 0 || err != io.EOF {
t.Fatalf("expected EOF after draining buffer, got n=%d err=%v", n, err)
}
if _, err := buf.Write([]byte("x")); err != ErrStarBufferWriteClosed {
t.Fatalf("unexpected write error after EndWrite: %v", err)
}
}
func TestStarBufferCloseAllowsDrain(t *testing.T) {
buf, err := NewStarBuffer(4)
if err != nil {
t.Fatal(err)
}
if _, err := buf.Write([]byte("ab")); err != nil {
t.Fatal(err)
}
if err := buf.Close(); err != nil {
t.Fatal(err)
}
got := make([]byte, 2)
n, err := buf.Read(got)
if err != nil {
t.Fatal(err)
}
if n != 2 || !bytes.Equal(got[:n], []byte("ab")) {
t.Fatalf("unexpected payload after close: n=%d data=%q", n, got[:n])
}
n, err = buf.Read(got)
if n != 0 || err != io.EOF {
t.Fatalf("expected EOF after draining closed buffer, got n=%d err=%v", n, err)
}
if _, err := buf.Write([]byte("x")); err != ErrStarBufferClosed {
t.Fatalf("unexpected write error after Close: %v", err)
}
}
func Test_Circle(t *testing.T) {
buf := NewStarBuffer(2048)
buf, err := NewStarBuffer(2048)
if err != nil {
t.Fatal(err)
}
go func() {
for {
//fmt.Println("write start")
@ -37,7 +114,10 @@ func Test_Circle(t *testing.T) {
}
func Test_Circle_Speed(t *testing.T) {
buf := NewStarBuffer(1048976)
buf, err := NewStarBuffer(1048976)
if err != nil {
t.Fatal(err)
}
count := uint64(0)
for i := 1; i <= 10; i++ {
go func() {
@ -61,7 +141,10 @@ func Test_Circle_Speed(t *testing.T) {
}
func Test_Circle_Speed2(t *testing.T) {
buf := NewStarBuffer(8192)
buf, err := NewStarBuffer(8192)
if err != nil {
t.Fatal(err)
}
count := uint64(0)
for i := 1; i <= 10; i++ {
go func() {

312
io.go
View File

@ -16,12 +16,20 @@ type InputMsg struct {
skipSliceSigErr bool
}
type rawTerminalSession struct {
fd int
state *terminal.State
reader *bufio.Reader
redrawHint string
printNewline bool
}
func Passwd(hint string, defaultVal string) InputMsg {
return passwd(hint, defaultVal, "", false)
return passwd(hint, defaultVal, "", true)
}
func PasswdWithMask(hint string, defaultVal string, mask string) InputMsg {
return passwd(hint, defaultVal, mask, false)
return passwd(hint, defaultVal, mask, true)
}
func PasswdResponseSignal(hint string, defaultVal string) InputMsg {
@ -36,54 +44,162 @@ func MessageBoxRaw(hint string, defaultVal string) InputMsg {
return messageBox(hint, defaultVal)
}
func messageBox(hint string, defaultVal string) InputMsg {
var ioBuf []rune
func newRawTerminalSession(hint string, printNewline bool) (*rawTerminalSession, error) {
if hint != "" {
fmt.Print(hint)
}
if strings.Index(hint, "\n") >= 0 {
hint = strings.TrimSpace(hint[strings.LastIndex(hint, "\n"):])
}
fd := int(os.Stdin.Fd())
state, err := terminal.MakeRaw(fd)
if err != nil {
return InputMsg{msg: "", err: err}
return nil, err
}
defer fmt.Println()
defer terminal.Restore(fd, state)
inputReader := bufio.NewReader(os.Stdin)
for {
b, _, err := inputReader.ReadRune()
if err != nil {
return InputMsg{msg: "", err: err}
}
if b == 0x0d {
strValue := strings.TrimSpace(string(ioBuf))
if len(strValue) == 0 {
strValue = defaultVal
}
return InputMsg{msg: strValue, err: err}
}
if b == 0x08 || b == 0x7F {
if len(ioBuf) > 0 {
ioBuf = ioBuf[:len(ioBuf)-1]
}
fmt.Print("\r")
for i := 0; i < len(ioBuf)+2+len(hint); i++ {
fmt.Print(" ")
}
} else {
ioBuf = append(ioBuf, b)
}
fmt.Print("\r")
if hint != "" {
fmt.Print(hint)
}
fmt.Print(string(ioBuf))
return &rawTerminalSession{
fd: fd,
state: state,
reader: bufio.NewReader(os.Stdin),
redrawHint: promptRedrawHint(hint),
printNewline: printNewline,
}, nil
}
func (session *rawTerminalSession) Close() {
if session == nil || session.state == nil {
return
}
_ = terminal.Restore(session.fd, session.state)
if session.printNewline {
fmt.Println()
}
}
func isSiganl(s rune) bool {
func (session *rawTerminalSession) Restore() error {
if session == nil || session.state == nil {
return nil
}
return terminal.Restore(session.fd, session.state)
}
func promptRedrawHint(hint string) string {
if strings.Index(hint, "\n") >= 0 {
hint = hint[strings.LastIndex(hint, "\n")+1:]
}
return strings.TrimSpace(hint)
}
func finalizeInputValue(raw string, defaultVal string) string {
raw = strings.TrimSpace(raw)
if len(raw) == 0 {
return defaultVal
}
return raw
}
func renderRawEcho(ioBuf []rune, mask string) string {
if mask == "" {
return string(ioBuf)
}
return strings.Repeat(mask, len(ioBuf))
}
func redrawPromptLine(hint string, echo string, lastWidth int) int {
nowWidth := stringDisplayWidth(hint) + stringDisplayWidth(echo)
clearWidth := lastWidth
if nowWidth > clearWidth {
clearWidth = nowWidth
}
fmt.Print("\r")
if clearWidth > 0 {
fmt.Print(strings.Repeat(" ", clearWidth))
fmt.Print("\r")
}
if hint != "" {
fmt.Print(hint)
}
if echo != "" {
fmt.Print(echo)
}
return nowWidth
}
func stringDisplayWidth(text string) int {
width := 0
for _, r := range text {
width += runeDisplayWidth(r)
}
return width
}
func runeDisplayWidth(r rune) int {
switch {
case r == '\t':
return 4
case r < 0x20 || (r >= 0x7f && r < 0xa0):
return 0
case isWideRune(r):
return 2
default:
return 1
}
}
func isWideRune(r rune) bool {
return r >= 0x1100 && (r <= 0x115f ||
r == 0x2329 || r == 0x232a ||
(r >= 0x2e80 && r <= 0xa4cf && r != 0x303f) ||
(r >= 0xac00 && r <= 0xd7a3) ||
(r >= 0xf900 && r <= 0xfaff) ||
(r >= 0xfe10 && r <= 0xfe19) ||
(r >= 0xfe30 && r <= 0xfe6f) ||
(r >= 0xff00 && r <= 0xff60) ||
(r >= 0xffe0 && r <= 0xffe6) ||
(r >= 0x1f300 && r <= 0x1f64f) ||
(r >= 0x1f900 && r <= 0x1f9ff) ||
(r >= 0x20000 && r <= 0x3fffd))
}
func rawLineInput(hint string, defaultVal string, mask string, handleSignal bool) InputMsg {
session, err := newRawTerminalSession(hint, true)
if err != nil {
return InputMsg{msg: "", err: err}
}
defer session.Close()
ioBuf := make([]rune, 0, 16)
lastWidth := 0
for {
b, _, err := session.reader.ReadRune()
if err != nil {
return InputMsg{msg: "", err: err}
}
if handleSignal && isSignal(b) {
if runtime.GOOS != "windows" {
if err := session.Restore(); err != nil {
return InputMsg{msg: "", err: err}
}
}
if err := signal(b); err != nil {
return InputMsg{msg: "", err: err}
}
continue
}
switch b {
case 0x0d, 0x0a:
return InputMsg{msg: finalizeInputValue(string(ioBuf), defaultVal), err: nil}
case 0x08, 0x7F:
if len(ioBuf) > 0 {
ioBuf = ioBuf[:len(ioBuf)-1]
}
default:
ioBuf = append(ioBuf, b)
}
lastWidth = redrawPromptLine(session.redrawHint, renderRawEcho(ioBuf, mask), lastWidth)
}
}
func messageBox(hint string, defaultVal string) InputMsg {
return rawLineInput(hint, defaultVal, "", false)
}
func isSignal(s rune) bool {
switch s {
case 0x03, 0x1a, 0x1c:
return true
@ -93,63 +209,7 @@ func isSiganl(s rune) bool {
}
func passwd(hint string, defaultVal string, mask string, handleSignal bool) InputMsg {
var ioBuf []rune
if hint != "" {
fmt.Print(hint)
}
if strings.Index(hint, "\n") >= 0 {
hint = strings.TrimSpace(hint[strings.LastIndex(hint, "\n"):])
}
fd := int(os.Stdin.Fd())
state, err := terminal.MakeRaw(fd)
if err != nil {
return InputMsg{msg: "", err: err}
}
defer fmt.Println()
defer terminal.Restore(fd, state)
inputReader := bufio.NewReader(os.Stdin)
for {
b, _, err := inputReader.ReadRune()
if err != nil {
return InputMsg{msg: "", err: err}
}
if handleSignal && isSiganl(b) {
if runtime.GOOS != "windows" {
terminal.Restore(fd, state)
}
if err := signal(b); err != nil {
return InputMsg{
msg: "",
err: err,
}
}
}
if b == 0x0d {
strValue := strings.TrimSpace(string(ioBuf))
if len(strValue) == 0 {
strValue = defaultVal
}
return InputMsg{msg: strValue, err: err}
}
if b == 0x08 || b == 0x7F {
if len(ioBuf) > 0 {
ioBuf = ioBuf[:len(ioBuf)-1]
}
fmt.Print("\r")
for i := 0; i < len(ioBuf)+2+len(hint); i++ {
fmt.Print(" ")
}
} else {
ioBuf = append(ioBuf, b)
}
fmt.Print("\r")
if hint != "" {
fmt.Print(hint)
}
for i := 0; i < len(ioBuf); i++ {
fmt.Print(mask)
}
}
return rawLineInput(hint, defaultVal, mask, handleSignal)
}
func MessageBox(hint string, defaultVal string) InputMsg {
@ -161,10 +221,7 @@ func MessageBox(hint string, defaultVal string) InputMsg {
if err != nil {
return InputMsg{msg: str, err: err}
}
str = strings.TrimSpace(str)
if len(str) == 0 {
str = defaultVal
}
str = finalizeInputValue(str, defaultVal)
return InputMsg{msg: str, err: err}
}
@ -386,46 +443,61 @@ func (im InputMsg) MustSliceFloat32(sep string) []float32 {
}
func YesNo(hint string, defaults bool) bool {
res, err := YesNoE(hint, defaults)
if err != nil {
return defaults
}
return res
}
func parseYesNoValue(raw string, defaults bool) (bool, bool) {
raw = strings.TrimSpace(strings.ToUpper(raw))
if raw == "" {
return defaults, true
}
switch []rune(raw)[0] {
case 'Y':
return true, true
case 'N':
return false, true
default:
return false, false
}
}
func YesNoE(hint string, defaults bool) (bool, error) {
for {
res := strings.ToUpper(MessageBox(hint, "").MustString())
if res == "" {
return defaults
res, err := MessageBox(hint, "").String()
if err != nil {
return false, err
}
res = res[0:1]
if res == "Y" {
return true
} else if res == "N" {
return false
if answer, ok := parseYesNoValue(res, defaults); ok {
return answer, nil
}
}
}
func StopUntil(hint string, trigger string, repeat bool) error {
pressLen := len([]rune(trigger))
triggerRunes := []rune(trigger)
pressLen := len(triggerRunes)
if trigger == "" {
pressLen = 1
}
fd := int(os.Stdin.Fd())
if hint != "" {
fmt.Print(hint)
}
state, err := terminal.MakeRaw(fd)
session, err := newRawTerminalSession(hint, false)
if err != nil {
return err
}
defer terminal.Restore(fd, state)
inputReader := bufio.NewReader(os.Stdin)
//ioBuf := make([]byte, pressLen)
defer session.Close()
i := 0
for {
b, _, err := inputReader.ReadRune()
b, _, err := session.reader.ReadRune()
if err != nil {
return err
}
if trigger == "" {
break
}
if b == []rune(trigger)[i] {
if b == triggerRunes[i] {
i++
if i == pressLen {
break

View File

@ -5,6 +5,41 @@ import (
"testing"
)
func TestPromptRedrawHint(t *testing.T) {
got := promptRedrawHint("头部提示\n 中文确认: ")
if got != "中文确认:" {
t.Fatalf("unexpected redraw hint: got %q", got)
}
}
func TestStringDisplayWidth(t *testing.T) {
got := stringDisplayWidth("中a[]")
if got != 5 {
t.Fatalf("unexpected display width: got %d want 5", got)
}
}
func TestParseYesNoValue(t *testing.T) {
cases := []struct {
name string
input string
defaults bool
want bool
ok bool
}{
{name: "default", input: " ", defaults: true, want: true, ok: true},
{name: "yes", input: "yes", defaults: false, want: true, ok: true},
{name: "no", input: "No", defaults: true, want: false, ok: true},
{name: "invalid", input: "maybe", defaults: false, want: false, ok: false},
}
for _, tc := range cases {
got, ok := parseYesNoValue(tc.input, tc.defaults)
if got != tc.want || ok != tc.ok {
t.Fatalf("%s: got (%v, %v) want (%v, %v)", tc.name, got, ok, tc.want, tc.ok)
}
}
}
func Test_Slice(t *testing.T) {
var data = InputMsg{
msg: "true,false,true,true,false,0,1,hello",

131
sync.go
View File

@ -1,55 +1,142 @@
package stario
import (
"fmt"
"sync"
"sync/atomic"
"time"
)
type waitGroupAddMode uint8
const (
waitGroupAddModeStrict waitGroupAddMode = iota
waitGroupAddModeLoose
)
type WaitGroup struct {
wg *sync.WaitGroup
maxCount uint32
allCount uint32
wg sync.WaitGroup
mu sync.Mutex
cond *sync.Cond
initOnce sync.Once
maxCount int
running int
addMode waitGroupAddMode
}
func NewWaitGroup(maxCount int) WaitGroup {
return WaitGroup{wg: &sync.WaitGroup{}, maxCount: uint32(maxCount)}
if maxCount < 0 {
panic("stario: negative max wait count")
}
return WaitGroup{
maxCount: maxCount,
addMode: waitGroupAddModeStrict,
}
}
func (w *WaitGroup) init() {
w.initOnce.Do(func() {
w.cond = sync.NewCond(&w.mu)
})
}
func (w *WaitGroup) Add(delta int) {
var Udelta uint32
w.init()
if delta == 0 {
return
}
if delta < 0 {
Udelta = uint32(-delta - 1)
} else {
Udelta = uint32(delta)
w.release(-delta)
return
}
for {
allC := atomic.LoadUint32(&w.allCount)
if atomic.LoadUint32(&w.maxCount) == 0 || atomic.LoadUint32(&w.maxCount) >= allC+uint32(delta) {
if delta < 0 {
atomic.AddUint32(&w.allCount, ^uint32(Udelta))
} else {
atomic.AddUint32(&w.allCount, uint32(Udelta))
}
break
w.acquire(delta)
}
func (w *WaitGroup) acquire(delta int) {
w.mu.Lock()
defer w.mu.Unlock()
if w.maxCount <= 0 {
w.running += delta
w.wg.Add(delta)
return
}
if delta == 1 {
w.wg.Add(1)
for w.maxCount > 0 && w.running >= w.maxCount {
w.cond.Wait()
}
time.Sleep(time.Microsecond)
w.running++
return
}
if w.running+delta > w.maxCount {
if w.addMode == waitGroupAddModeStrict {
panic(fmt.Sprintf("stario: WaitGroup.Add(%d) exceeds max limit %d with %d running", delta, w.maxCount, w.running))
}
w.maxCount = w.running + delta
}
w.running += delta
w.wg.Add(delta)
}
func (w *WaitGroup) release(delta int) {
w.mu.Lock()
defer w.mu.Unlock()
if delta > w.running {
panic(fmt.Sprintf("stario: WaitGroup.Done releases %d tasks but only %d running", delta, w.running))
}
w.wg.Add(-delta)
w.running -= delta
w.cond.Broadcast()
}
func (w *WaitGroup) Done() {
w.Add(-1)
}
func (w *WaitGroup) Go(fn func()) {
w.Add(1)
go func() {
defer w.Done()
fn()
}()
}
func (w *WaitGroup) Wait() {
w.init()
w.wg.Wait()
}
func (w *WaitGroup) GetMaxWaitNum() int {
return int(atomic.LoadUint32(&w.maxCount))
w.init()
w.mu.Lock()
defer w.mu.Unlock()
return w.maxCount
}
func (w *WaitGroup) SetMaxWaitNum(num int) {
atomic.AddUint32(&w.maxCount, uint32(num))
if num < 0 {
panic("stario: negative max wait count")
}
w.init()
w.mu.Lock()
w.maxCount = num
w.mu.Unlock()
w.cond.Broadcast()
}
func (w *WaitGroup) SetStrictAddMode(strict bool) {
w.init()
w.mu.Lock()
if strict {
w.addMode = waitGroupAddModeStrict
} else {
w.addMode = waitGroupAddModeLoose
}
w.mu.Unlock()
w.cond.Broadcast()
}
func (w *WaitGroup) StrictAddMode() bool {
w.init()
w.mu.Lock()
defer w.mu.Unlock()
return w.addMode == waitGroupAddModeStrict
}

171
sync_test.go Normal file
View File

@ -0,0 +1,171 @@
package stario
import (
"testing"
"time"
)
func TestWaitGroupAddBlocksAtLimit(t *testing.T) {
wg := NewWaitGroup(1)
wg.Add(1)
unblocked := make(chan struct{})
go func() {
wg.Add(1)
close(unblocked)
wg.Done()
}()
select {
case <-unblocked:
t.Fatal("Add(1) should block while running count is at the limit")
case <-time.After(30 * time.Millisecond):
}
wg.Done()
select {
case <-unblocked:
case <-time.After(200 * time.Millisecond):
t.Fatal("Add(1) did not resume after capacity was released")
}
wg.Wait()
}
func TestWaitGroupStrictAddPanicsWhenBatchExceedsLimit(t *testing.T) {
wg := NewWaitGroup(1)
wg.Add(1)
defer wg.Done()
defer func() {
if recover() == nil {
t.Fatal("expected Add(2) to panic in strict mode when it exceeds the limit")
}
}()
wg.Add(2)
}
func TestWaitGroupLooseAddExpandsLimit(t *testing.T) {
wg := NewWaitGroup(1)
wg.SetStrictAddMode(false)
wg.Add(1)
wg.Add(2)
if got := wg.GetMaxWaitNum(); got != 3 {
t.Fatalf("unexpected max count: got %d want 3", got)
}
for i := 0; i < 3; i++ {
wg.Done()
}
wg.Wait()
}
func TestWaitGroupSetMaxWaitNumOverridesValue(t *testing.T) {
wg := NewWaitGroup(1)
wg.SetMaxWaitNum(4)
if got := wg.GetMaxWaitNum(); got != 4 {
t.Fatalf("unexpected max count: got %d want 4", got)
}
}
func TestWaitGroupSetMaxWaitNumZeroUnblocksWaitingAdd(t *testing.T) {
wg := NewWaitGroup(1)
wg.Add(1)
unblocked := make(chan struct{})
go func() {
wg.Add(1)
close(unblocked)
wg.Done()
}()
select {
case <-unblocked:
t.Fatal("Add(1) should still be blocked before removing the limit")
case <-time.After(30 * time.Millisecond):
}
wg.SetMaxWaitNum(0)
select {
case <-unblocked:
case <-time.After(200 * time.Millisecond):
t.Fatal("Add(1) did not resume after the limit was removed")
}
wg.Done()
wg.Wait()
}
func TestWaitGroupGoWaitsForCompletion(t *testing.T) {
wg := NewWaitGroup(1)
done := make(chan struct{})
wg.Go(func() {
close(done)
})
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Fatal("Go(fn) did not run the function")
}
finished := make(chan struct{})
go func() {
wg.Wait()
close(finished)
}()
select {
case <-finished:
case <-time.After(200 * time.Millisecond):
t.Fatal("Wait did not observe Go(fn) completion")
}
}
func TestWaitGroupGoRespectsLimit(t *testing.T) {
wg := NewWaitGroup(1)
releaseFirst := make(chan struct{})
firstStarted := make(chan struct{})
secondStarted := make(chan struct{})
secondReturned := make(chan struct{})
wg.Go(func() {
close(firstStarted)
<-releaseFirst
})
select {
case <-firstStarted:
case <-time.After(200 * time.Millisecond):
t.Fatal("first Go(fn) did not start")
}
go func() {
wg.Go(func() {
close(secondStarted)
})
close(secondReturned)
}()
select {
case <-secondReturned:
t.Fatal("second Go(fn) should block while the limit is full")
case <-time.After(30 * time.Millisecond):
}
close(releaseFirst)
select {
case <-secondStarted:
case <-time.After(200 * time.Millisecond):
t.Fatal("second Go(fn) did not start after capacity was released")
}
wg.Wait()
}