package stario

import (
	"testing"
	"time"
)

// TestWaitGroupAddBlocksAtLimit expects a second Add(1) on a group built with
// NewWaitGroup(1) to stay blocked while the first Add(1) is outstanding, and
// to return once Done is called for that first Add.
func TestWaitGroupAddBlocksAtLimit(t *testing.T) {
	const (
		quietPeriod = 30 * time.Millisecond  // how long the blocked call is observed
		deadline    = 200 * time.Millisecond // upper bound for the call to resume
	)

	group := NewWaitGroup(1)
	group.Add(1)

	admitted := make(chan struct{})
	go func() {
		group.Add(1)
		close(admitted)
		group.Done()
	}()

	// With one Add(1) outstanding the goroutine must not get past Add yet.
	select {
	case <-time.After(quietPeriod):
	case <-admitted:
		t.Fatal("Add(1) should block while running count is at the limit")
	}

	// Balancing the first Add(1) should let the goroutine through.
	group.Done()

	select {
	case <-time.After(deadline):
		t.Fatal("Add(1) did not resume after capacity was released")
	case <-admitted:
	}

	group.Wait()
}

// TestWaitGroupStrictAddPanicsWhenBatchExceedsLimit expects Add(2) to panic on
// a group built with NewWaitGroup(1) that already has one Add(1) outstanding.
//
// The test never calls SetStrictAddMode, so it relies on whatever mode
// NewWaitGroup selects by default (presumably strict; confirm against the
// constructor).
func TestWaitGroupStrictAddPanicsWhenBatchExceedsLimit(t *testing.T) {
	group := NewWaitGroup(1)
	group.Add(1)

	// Deferred calls run last-in-first-out: the recover check registered below
	// fires first, then this Done balances the Add(1) above.
	defer group.Done()

	defer func() {
		if r := recover(); r == nil {
			t.Fatal("expected Add(2) to panic in strict mode when it exceeds the limit")
		}
	}()

	group.Add(2)
}

// TestWaitGroupLooseAddExpandsLimit expects that, after SetStrictAddMode(false),
// Add(1) followed by Add(2) on a group built with NewWaitGroup(1) both return
// and GetMaxWaitNum then reports 3.
func TestWaitGroupLooseAddExpandsLimit(t *testing.T) {
	group := NewWaitGroup(1)
	group.SetStrictAddMode(false)

	group.Add(1)
	group.Add(2)

	got := group.GetMaxWaitNum()
	if got != 3 {
		t.Fatalf("unexpected max count: got %d want 3", got)
	}

	// Balance the three units added above so Wait can return.
	for pending := 3; pending > 0; pending-- {
		group.Done()
	}
	group.Wait()
}

// TestWaitGroupSetMaxWaitNumOverridesValue expects GetMaxWaitNum to report 4
// after SetMaxWaitNum(4) is called on a group built with NewWaitGroup(1).
func TestWaitGroupSetMaxWaitNumOverridesValue(t *testing.T) {
	group := NewWaitGroup(1)
	group.SetMaxWaitNum(4)

	got := group.GetMaxWaitNum()
	if got != 4 {
		t.Fatalf("unexpected max count: got %d want 4", got)
	}
}

// TestWaitGroupSetMaxWaitNumZeroUnblocksWaitingAdd expects an Add(1) that is
// blocked on a full group to return once SetMaxWaitNum(0) is called, without
// any Done having been issued for the outstanding Add.
func TestWaitGroupSetMaxWaitNumZeroUnblocksWaitingAdd(t *testing.T) {
	const (
		quietPeriod = 30 * time.Millisecond  // how long the blocked call is observed
		deadline    = 200 * time.Millisecond // upper bound for the call to resume
	)

	group := NewWaitGroup(1)
	group.Add(1)

	admitted := make(chan struct{})
	go func() {
		group.Add(1)
		close(admitted)
		group.Done()
	}()

	select {
	case <-time.After(quietPeriod):
	case <-admitted:
		t.Fatal("Add(1) should still be blocked before removing the limit")
	}

	// The outstanding Add(1) is deliberately not balanced yet: only the
	// change of the maximum is supposed to release the waiting goroutine.
	group.SetMaxWaitNum(0)

	select {
	case <-time.After(deadline):
		t.Fatal("Add(1) did not resume after the limit was removed")
	case <-admitted:
	}

	group.Done()
	group.Wait()
}

// TestWaitGroupGoWaitsForCompletion expects the function handed to Go to run,
// and a Wait issued afterwards to return within the deadline.
func TestWaitGroupGoWaitsForCompletion(t *testing.T) {
	const deadline = 200 * time.Millisecond

	group := NewWaitGroup(1)

	ran := make(chan struct{})
	group.Go(func() {
		close(ran)
	})

	select {
	case <-time.After(deadline):
		t.Fatal("Go(fn) did not run the function")
	case <-ran:
	}

	// Wait is called from a separate goroutine so the test can bound it with
	// a timeout instead of hanging.
	waitReturned := make(chan struct{})
	go func() {
		group.Wait()
		close(waitReturned)
	}()

	select {
	case <-time.After(deadline):
		t.Fatal("Wait did not observe Go(fn) completion")
	case <-waitReturned:
	}
}

// TestWaitGroupGoRespectsLimit expects a second Go call on a group built with
// NewWaitGroup(1) not to return while the first function is still running,
// and the second function to start once the first one is allowed to finish.
func TestWaitGroupGoRespectsLimit(t *testing.T) {
	const (
		quietPeriod = 30 * time.Millisecond  // how long the blocked call is observed
		deadline    = 200 * time.Millisecond // upper bound for each expected event
	)

	group := NewWaitGroup(1)

	var (
		firstStarted   = make(chan struct{}) // closed when the first function begins
		releaseFirst   = make(chan struct{}) // closed to let the first function return
		secondStarted  = make(chan struct{}) // closed when the second function begins
		secondReturned = make(chan struct{}) // closed when the second Go call returns
	)

	group.Go(func() {
		close(firstStarted)
		<-releaseFirst
	})

	select {
	case <-time.After(deadline):
		t.Fatal("first Go(fn) did not start")
	case <-firstStarted:
	}

	// The second Go call is issued from its own goroutine because it is
	// expected to block until the first function finishes.
	go func() {
		group.Go(func() {
			close(secondStarted)
		})
		close(secondReturned)
	}()

	select {
	case <-time.After(quietPeriod):
	case <-secondReturned:
		t.Fatal("second Go(fn) should block while the limit is full")
	}

	close(releaseFirst)

	select {
	case <-time.After(deadline):
		t.Fatal("second Go(fn) did not start after capacity was released")
	case <-secondStarted:
	}

	group.Wait()
}