Skip to content

Commit 42cf251

Browse files
committed
fix-setlimit
1 parent 0de741c commit 42cf251

File tree

2 files changed

+43
-27
lines changed

2 files changed

+43
-27
lines changed

errgroup/errgroup.go

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,40 @@ package errgroup
88

99
import (
1010
"context"
11-
"fmt"
1211
"sync"
1312
)
1413

1514
type token struct{}
1615

16+
type taskLimiter struct {
17+
sem chan token
18+
}
19+
20+
func (t taskLimiter) run() {
21+
if t.sem == nil {
22+
return
23+
}
24+
t.sem <- token{}
25+
}
26+
27+
func (t taskLimiter) try() bool {
28+
if t.sem == nil {
29+
return true
30+
}
31+
select {
32+
case t.sem <- token{}:
33+
return true
34+
default:
35+
return false
36+
}
37+
}
38+
39+
func (t taskLimiter) done() {
40+
if t.sem != nil {
41+
<-t.sem
42+
}
43+
}
44+
1745
// A Group is a collection of goroutines working on subtasks that are part of
1846
// the same overall task.
1947
//
@@ -30,10 +58,8 @@ type Group struct {
3058
err error
3159
}
3260

33-
func (g *Group) done() {
34-
if g.sem != nil {
35-
<-g.sem
36-
}
61+
func (g *Group) done(t taskLimiter) {
62+
t.done()
3763
g.wg.Done()
3864
}
3965

@@ -64,13 +90,14 @@ func (g *Group) Wait() error {
6490
// The first call to return a non-nil error cancels the group; its error will be
6591
// returned by Wait.
6692
func (g *Group) Go(f func() error) {
67-
if g.sem != nil {
68-
g.sem <- token{}
93+
limiter := taskLimiter{
94+
sem: g.sem,
6995
}
96+
limiter.run()
7097

7198
g.wg.Add(1)
7299
go func() {
73-
defer g.done()
100+
defer g.done(limiter)
74101

75102
if err := f(); err != nil {
76103
g.errOnce.Do(func() {
@@ -88,18 +115,13 @@ func (g *Group) Go(f func() error) {
88115
//
89116
// The return value reports whether the goroutine was started.
90117
func (g *Group) TryGo(f func() error) bool {
91-
if g.sem != nil {
92-
select {
93-
case g.sem <- token{}:
94-
// Note: this allows barging iff channels in general allow barging.
95-
default:
96-
return false
97-
}
118+
limiter := taskLimiter{sem: g.sem}
119+
if !limiter.try() {
120+
return false
98121
}
99-
100122
g.wg.Add(1)
101123
go func() {
102-
defer g.done()
124+
defer g.done(limiter)
103125

104126
if err := f(); err != nil {
105127
g.errOnce.Do(func() {
@@ -118,15 +140,6 @@ func (g *Group) TryGo(f func() error) bool {
118140
//
119141
// Any subsequent call to the Go method will block until it can add an active
120142
// goroutine without exceeding the configured limit.
121-
//
122-
// The limit must not be modified while any goroutines in the group are active.
123143
func (g *Group) SetLimit(n int) {
124-
if n < 0 {
125-
g.sem = nil
126-
return
127-
}
128-
if len(g.sem) != 0 {
129-
panic(fmt.Errorf("errgroup: modify limit while %v goroutines in the group are still active", len(g.sem)))
130-
}
131144
g.sem = make(chan token, n)
132145
}

errgroup/errgroup_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ func TestGoLimit(t *testing.T) {
234234
g := &errgroup.Group{}
235235
g.SetLimit(limit)
236236
var active int32
237-
for i := 0; i <= 1<<10; i++ {
237+
for i := 0; i <= 100; i++ {
238238
g.Go(func() error {
239239
n := atomic.AddInt32(&active, 1)
240240
if n > limit {
@@ -244,6 +244,9 @@ func TestGoLimit(t *testing.T) {
244244
atomic.AddInt32(&active, -1)
245245
return nil
246246
})
247+
if i%10 == 0 {
248+
g.SetLimit(2)
249+
}
247250
}
248251
if err := g.Wait(); err != nil {
249252
t.Fatal(err)

0 commit comments

Comments
 (0)