Skip to content

Commit fe76af9

Browse files
rohan-varma authored and facebook-github-bot committed
fix test_process_group_debug_info flaky test (pytorch#31533)
Summary:
Pull Request resolved: pytorch#31533

Fixes this test that was flaky and has been disabled (see pytorch#31112).

ghstack-source-id: 96038999

Test Plan: Run the test 1000 times and ensure that it passes.

Differential Revision: D19203366

fbshipit-source-id: 7978cbb8ca0989a0a370a36349cdd4db3bb8345b
1 parent cc2d5ca commit fe76af9

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

test/rpc_test.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1203,11 +1203,17 @@ def test_rref_context_debug_info(self):
12031203
# barrier after check 3
12041204
dist.barrier()
12051205

1206-
@unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/31112")
12071206
@dist_init
12081207
@requires_process_group_agent("PROCESS_GROUP rpc backend specific test, skip")
12091208
def test_process_group_debug_info(self):
12101209
from torch.distributed.rpc.api import _agent
1210+
if not dist.is_initialized():
1211+
dist.init_process_group(
1212+
backend="gloo",
1213+
init_method=self.init_method,
1214+
rank=self.rank,
1215+
world_size=self.world_size,
1216+
)
12111217

12121218
NUM_THREAD = self.rpc_backend_options.num_send_recv_threads
12131219

@@ -1218,7 +1224,10 @@ def test_process_group_debug_info(self):
12181224
self.assertEqual(int(info["num_pending_requests"]), 0)
12191225
self.assertEqual(int(info["thread_pool_size"]), NUM_THREAD)
12201226
self.assertEqual(int(info["num_idle_threads"]), NUM_THREAD)
1221-
1227+
# for the above check, add a barrier to ensure that another worker
1228+
# cannot send a request before we check num_idle_threads, since we'd
1229+
# use up an idle thread if we start processing that request.
1230+
dist.barrier()
12221231
dst_rank = (self.rank + 1) % self.world_size
12231232
fut = rpc.rpc_async(
12241233
"worker{}".format(dst_rank),
@@ -1239,14 +1248,6 @@ def test_process_group_debug_info(self):
12391248
# might be either 1 or 2 busy threads
12401249
self.assertTrue(num_idle_threads in [NUM_THREAD - 1, NUM_THREAD - 2])
12411250

1242-
if not dist.is_initialized():
1243-
dist.init_process_group(
1244-
backend="gloo",
1245-
init_method=self.init_method,
1246-
rank=self.rank,
1247-
world_size=self.world_size,
1248-
)
1249-
12501251
# add a barrier to make sure the request is not finished before checking
12511252
# num_pending_requests
12521253
dist.barrier()

0 commit comments

Comments
 (0)