@@ -2140,6 +2140,12 @@ def _inner_training_loop(
                 _ = list(sampler)
 
         total_batched_samples = 0
+        server = xp.start_server(9012)
+        logger.info(f'Profiling server started: {str(server)}')
+        profile_step = int(os.environ.get('PROFILE_STEP', -1))
+        profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
+        profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
+        profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader
             if hasattr(epoch_iterator, "set_epoch"):
@@ -2168,12 +2174,6 @@ def _inner_training_loop(
                 rng_to_sync = True
 
             step = -1
-            server = xp.start_server(9012)
-            logger.info(f'Profiling server started: {str(server)}')
-            profile_step = int(os.environ.get('PROFILE_STEP', -1))
-            profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
-            profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
-            profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
             for step, inputs in enumerate(epoch_iterator):
                 total_batched_samples += 1
 
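The commit hoists the profiler setup out of the per-epoch loop, so `xp.start_server(9012)` runs once per training run instead of once per epoch (re-running it each epoch would attempt to start a new server on the same port). Below is a minimal, self-contained sketch of how these environment-variable knobs are typically consumed, assuming torch_xla's `xp.trace_detached` client API and a hypothetical `maybe_capture_profile` helper; none of this appears in the diff itself.

import os

import torch_xla.debug.profiler as xp

# Start the profiling server once, before the training loop, mirroring the diff.
server = xp.start_server(9012)

# Same environment-variable knobs as the diff: which epoch/step to profile,
# how long to record, and where to write the trace.
profile_step = int(os.environ.get('PROFILE_STEP', -1))
profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
profile_logdir = os.environ.get('PROFILE_LOGDIR', None)

def maybe_capture_profile(epoch: int, step: int) -> None:
    # Hypothetical helper (not in the commit): once training reaches the
    # configured epoch/step, request a trace from the server started above.
    # trace_detached records in a background thread for `duration_ms`, so
    # the training loop keeps running while the profile is captured.
    if epoch == profile_epoch and step == profile_step and profile_logdir:
        xp.trace_detached('localhost:9012', profile_logdir,
                          duration_ms=profile_duration)

A helper like this would be called at the top of the `for step, inputs in enumerate(epoch_iterator)` loop; with `PROFILE_STEP`/`PROFILE_EPOCH` defaulting to -1, profiling stays disabled unless the user opts in through the environment.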