Skip to content

Commit 7f88656

Browse files
committed
Fix profiler
1 parent 850ac3c commit 7f88656

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

src/transformers/trainer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2140,6 +2140,12 @@ def _inner_training_loop(
21402140
_ = list(sampler)
21412141

21422142
total_batched_samples = 0
2143+
server = xp.start_server(9012)
2144+
logger.info(f'Profiling server started: {str(server)}')
2145+
profile_step = int(os.environ.get('PROFILE_STEP', -1))
2146+
profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
2147+
profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
2148+
profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
21432149
for epoch in range(epochs_trained, num_train_epochs):
21442150
epoch_iterator = train_dataloader
21452151
if hasattr(epoch_iterator, "set_epoch"):
@@ -2168,12 +2174,6 @@ def _inner_training_loop(
21682174
rng_to_sync = True
21692175

21702176
step = -1
2171-
server = xp.start_server(9012)
2172-
logger.info(f'Profiling server started: {str(server)}')
2173-
profile_step = int(os.environ.get('PROFILE_STEP', -1))
2174-
profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
2175-
profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
2176-
profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
21772177
for step, inputs in enumerate(epoch_iterator):
21782178
total_batched_samples += 1
21792179

0 commit comments

Comments
 (0)