@@ -11,28 +11,28 @@ import (
1111// time, Wait can be used to block until all goroutines have finished or the
1212// given context is done.
1313type WaitGroupContext struct {
14- ctx context.Context
15- done chan struct {}
16- counter atomic.Int32
17- state atomic.Int32
14+ ctx context.Context
15+ sem chan struct {}
16+ state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
1817}
1918
2019// NewWaitGroupContext returns a new WaitGroupContext with Context ctx.
2120func NewWaitGroupContext (ctx context.Context ) * WaitGroupContext {
2221return & WaitGroupContext {
23- ctx : ctx ,
24- done : make (chan struct {}),
22+ ctx : ctx ,
23+ sem : make (chan struct {}),
2524}
2625}
2726
2827// Add adds delta, which may be negative, to the WaitGroupContext counter.
2928// If the counter becomes zero, all goroutines blocked on Wait are released.
3029// If the counter goes negative, Add panics.
3130func (wgc * WaitGroupContext ) Add (delta int ) {
32- counter := wgc .counter .Add (int32 (delta ))
33- if counter == 0 && wgc .state .CompareAndSwap (0 , 1 ) {
34- wgc .release ()
35- } else if counter < 0 && wgc .state .Load () == 0 {
31+ state := wgc .state .Add (uint64 (delta ) << 32 )
32+ counter := int32 (state >> 32 )
33+ if counter == 0 {
34+ wgc .notifyAll ()
35+ } else if counter < 0 {
3636panic ("async: negative WaitGroupContext counter" )
3737}
3838}
@@ -44,12 +44,36 @@ func (wgc *WaitGroupContext) Done() {
4444
4545// Wait blocks until the wait group counter is zero or ctx is done.
4646func (wgc * WaitGroupContext ) Wait () {
47- select {
48- case <- wgc .ctx .Done ():
49- case <- wgc .done :
47+ for {
48+ state := wgc .state .Load ()
49+ counter := int32 (state >> 32 )
50+ if counter == 0 {
51+ return
52+ }
53+ if wgc .state .CompareAndSwap (state , state + 1 ) {
54+ select {
55+ case <- wgc .sem :
56+ if wgc .state .Load () != 0 {
57+ panic ("async: WaitGroupContext is reused before " +
58+ "previous Wait has returned" )
59+ }
60+ case <- wgc .ctx .Done ():
61+ }
62+ return
63+ }
5064}
5165}
5266
53- func (wgc * WaitGroupContext ) release () {
54- close (wgc .done )
67+ // notifyAll releases all goroutines blocked in Wait and resets
68+ // the wait group state.
69+ func (wgc * WaitGroupContext ) notifyAll () {
70+ state := wgc .state .Load ()
71+ waiting := uint32 (state )
72+ wgc .state .Store (0 )
73+ for ; waiting != 0 ; waiting -- {
74+ select {
75+ case wgc .sem <- struct {}{}:
76+ case <- wgc .ctx .Done ():
77+ }
78+ }
5579}
0 commit comments