Skip to content

Commit df276c5

Browse files
committed
fix requests tracker concurrency
1 parent b5fc001 commit df276c5

File tree

1 file changed

+55
-56
lines changed

1 file changed

+55
-56
lines changed

pkg/controller/statefulset/stateful_set_control_test.go

Lines changed: 55 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -2423,51 +2423,53 @@ type requestTracker struct {
24232423
err error
24242424
after int
24252425

2426-
parallelLock sync.Mutex
2427-
parallel int
2428-
maxParallel int
2429-
2430-
delay time.Duration
2426+
// this block should be updated consistently
2427+
parallelLock sync.Mutex
2428+
shouldTrackParallelRequests bool
2429+
parallelRequests int
2430+
maxParallelRequests int
2431+
parallelRequestDelay time.Duration
24312432
}
24322433

2433-
func (rt *requestTracker) errorReady() bool {
2434-
rt.Lock()
2435-
defer rt.Unlock()
2436-
return rt.err != nil && rt.requests >= rt.after
2437-
}
2438-
2439-
func (rt *requestTracker) inc() {
2440-
rt.parallelLock.Lock()
2441-
rt.parallel++
2442-
if rt.maxParallel < rt.parallel {
2443-
rt.maxParallel = rt.parallel
2434+
func (rt *requestTracker) trackParallelRequests() {
2435+
if !rt.shouldTrackParallelRequests {
2436+
// do not track parallel requests unless specifically enabled
2437+
return
24442438
}
2445-
rt.parallelLock.Unlock()
2446-
2447-
rt.Lock()
2448-
defer rt.Unlock()
2449-
rt.requests++
2450-
if rt.delay != 0 {
2451-
time.Sleep(rt.delay)
2439+
if rt.parallelLock.TryLock() {
2440+
// lock acquired: we are the only or the first concurrent request
2441+
// initialize the next set of parallel requests
2442+
rt.parallelRequests = 1
2443+
} else {
2444+
// lock is held by other requests
2445+
// now wait for the lock to increase the parallelRequests
2446+
rt.parallelLock.Lock()
2447+
rt.parallelRequests++
2448+
}
2449+
defer rt.parallelLock.Unlock()
2450+
// update the local maximum of parallel collisions
2451+
if rt.maxParallelRequests < rt.parallelRequests {
2452+
rt.maxParallelRequests = rt.parallelRequests
2453+
}
2454+
// increase the chance of collisions
2455+
if rt.parallelRequestDelay > 0 {
2456+
time.Sleep(rt.parallelRequestDelay)
24522457
}
24532458
}
24542459

2455-
func (rt *requestTracker) reset() {
2456-
rt.parallelLock.Lock()
2457-
rt.parallel = 0
2458-
rt.parallelLock.Unlock()
2459-
2460-
rt.Lock()
2461-
defer rt.Unlock()
2462-
rt.err = nil
2463-
rt.after = 0
2464-
rt.delay = 0
2465-
}
2466-
2467-
func (rt *requestTracker) getErr() error {
2460+
func (rt *requestTracker) incWithOptionalError() error {
24682461
rt.Lock()
24692462
defer rt.Unlock()
2470-
return rt.err
2463+
rt.requests++
2464+
if rt.err != nil && rt.requests >= rt.after {
2465+
// reset and pass the error
2466+
defer func() {
2467+
rt.err = nil
2468+
rt.after = 0
2469+
}()
2470+
return rt.err
2471+
}
2472+
return nil
24712473
}
24722474

24732475
func newRequestTracker(requests int, err error, after int) requestTracker {
@@ -2512,10 +2514,9 @@ func newFakeObjectManager(informerFactory informers.SharedInformerFactory) *fake
25122514
}
25132515

25142516
func (om *fakeObjectManager) CreatePod(ctx context.Context, pod *v1.Pod) error {
2515-
defer om.createPodTracker.inc()
2516-
if om.createPodTracker.errorReady() {
2517-
defer om.createPodTracker.reset()
2518-
return om.createPodTracker.getErr()
2517+
defer om.createPodTracker.trackParallelRequests()
2518+
if err := om.createPodTracker.incWithOptionalError(); err != nil {
2519+
return err
25192520
}
25202521
pod.SetUID(types.UID(pod.Name + "-uid"))
25212522
return om.podsIndexer.Update(pod)
@@ -2526,19 +2527,17 @@ func (om *fakeObjectManager) GetPod(namespace, podName string) (*v1.Pod, error)
25262527
}
25272528

25282529
func (om *fakeObjectManager) UpdatePod(pod *v1.Pod) error {
2529-
defer om.updatePodTracker.inc()
2530-
if om.updatePodTracker.errorReady() {
2531-
defer om.updatePodTracker.reset()
2532-
return om.updatePodTracker.getErr()
2530+
defer om.updatePodTracker.trackParallelRequests()
2531+
if err := om.updatePodTracker.incWithOptionalError(); err != nil {
2532+
return err
25332533
}
25342534
return om.podsIndexer.Update(pod)
25352535
}
25362536

25372537
func (om *fakeObjectManager) DeletePod(pod *v1.Pod) error {
2538-
defer om.deletePodTracker.inc()
2539-
if om.deletePodTracker.errorReady() {
2540-
defer om.deletePodTracker.reset()
2541-
return om.deletePodTracker.getErr()
2538+
defer om.deletePodTracker.trackParallelRequests()
2539+
if err := om.deletePodTracker.incWithOptionalError(); err != nil {
2540+
return err
25422541
}
25432542
if key, err := controller.KeyFunc(pod); err != nil {
25442543
return err
@@ -2733,10 +2732,9 @@ func newFakeStatefulSetStatusUpdater(setInformer appsinformers.StatefulSetInform
27332732
}
27342733

27352734
func (ssu *fakeStatefulSetStatusUpdater) UpdateStatefulSetStatus(ctx context.Context, set *apps.StatefulSet, status *apps.StatefulSetStatus) error {
2736-
defer ssu.updateStatusTracker.inc()
2737-
if ssu.updateStatusTracker.errorReady() {
2738-
defer ssu.updateStatusTracker.reset()
2739-
return ssu.updateStatusTracker.err
2735+
defer ssu.updateStatusTracker.trackParallelRequests()
2736+
if err := ssu.updateStatusTracker.incWithOptionalError(); err != nil {
2737+
return err
27402738
}
27412739
set.Status = *status
27422740
ssu.setsIndexer.Update(set)
@@ -2985,7 +2983,8 @@ func parallelScale(t *testing.T, set *apps.StatefulSet, replicas, desiredReplica
29852983
diff := desiredReplicas - replicas
29862984
client := fake.NewSimpleClientset(set)
29872985
om, _, ssc := setupController(client)
2988-
om.createPodTracker.delay = time.Millisecond
2986+
om.createPodTracker.shouldTrackParallelRequests = true
2987+
om.createPodTracker.parallelRequestDelay = time.Millisecond
29892988

29902989
*set.Spec.Replicas = replicas
29912990
if err := parallelScaleUpStatefulSetControl(set, ssc, om, invariants); err != nil {
@@ -3017,8 +3016,8 @@ func parallelScale(t *testing.T, set *apps.StatefulSet, replicas, desiredReplica
30173016
t.Errorf("Failed to scale statefulset to %v replicas, got %v replicas", desiredReplicas, set.Status.Replicas)
30183017
}
30193018

3020-
if (diff < -1 || diff > 1) && om.createPodTracker.maxParallel <= 1 {
3021-
t.Errorf("want max parallel requests > 1, got %v", om.createPodTracker.maxParallel)
3019+
if (diff < -1 || diff > 1) && om.createPodTracker.maxParallelRequests <= 1 {
3020+
t.Errorf("want max parallel requests > 1, got %v", om.createPodTracker.maxParallelRequests)
30223021
}
30233022
}
30243023

0 commit comments

Comments
 (0)