@@ -1910,46 +1910,32 @@ def _inner_training_loop(
         # number of training epochs: num_train_epochs
         # number of training steps per epoch: num_update_steps_per_epoch
         # total number of training steps to execute: max_steps
-        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
-
-        len_dataloader = None
-        num_train_tokens = None
-        if has_length(train_dataloader):
-            len_dataloader = len(train_dataloader)
-            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
-            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
-            num_examples = self.num_examples(train_dataloader)
-            if args.max_steps > 0:
-                max_steps = args.max_steps
-                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
-                    args.max_steps % num_update_steps_per_epoch > 0
-                )
-                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
-                # the best we can do.
-                num_train_samples = args.max_steps * total_train_batch_size
-                if args.include_tokens_per_second:
-                    num_train_tokens = (
-                        self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
-                    )
-            else:
-                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
-                num_train_epochs = math.ceil(args.num_train_epochs)
-                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
-                if args.include_tokens_per_second:
-                    num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
-        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
-            max_steps = args.max_steps
-            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
-            num_train_epochs = sys.maxsize
-            num_update_steps_per_epoch = max_steps
-            num_examples = total_train_batch_size * args.max_steps
-            num_train_samples = args.max_steps * total_train_batch_size
-            if args.include_tokens_per_second:
-                num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
-        else:
-            raise ValueError(
-                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
-                f" {args.max_steps}"
+        profile_step = int(os.environ.get('PROFILE_STEP', -1))
+        profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
+        profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
+        profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
+        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps
+        assert args.max_steps > 0
+        max_steps = args.max_steps
+        len_dataloader = len(train_dataloader)
+        num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
+        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+        steps_for_counting_metrics = max_steps - num_update_steps_per_epoch * profile_epoch - profile_step
+        num_examples = self.num_examples(train_dataloader)
+        num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
+            args.max_steps % num_update_steps_per_epoch > 0
+        )
+        # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
+        # the best we can do.
+        num_train_samples = args.max_steps * total_train_batch_size
+        metrics_num_train_samples = steps_for_counting_metrics * total_train_batch_size
+        metrics_num_train_tokens = None
+        if args.include_tokens_per_second:
+            num_train_tokens = (
+                self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
+            )
+            metrics_num_train_tokens = (
+                self.num_tokens(train_dataloader, steps_for_counting_metrics) * args.gradient_accumulation_steps
             )

         if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
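The accounting introduced above excludes every step that runs before the profile is triggered (the window dominated by XLA compilation), so throughput is later measured only over the post-trigger steps. A minimal sketch with assumed values, not taken from the change itself:

# Hypothetical values for illustration only.
max_steps = 100                      # args.max_steps (required to be > 0 here)
num_update_steps_per_epoch = 50      # len(train_dataloader) // gradient_accumulation_steps
profile_epoch, profile_step = 0, 10  # PROFILE_EPOCH / PROFILE_STEP environment variables

# Steps that precede the profiling trigger are dropped from the throughput window.
steps_for_counting_metrics = max_steps - num_update_steps_per_epoch * profile_epoch - profile_step
assert steps_for_counting_metrics == 90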
@@ -2153,10 +2139,6 @@ def _inner_training_loop(
         total_batched_samples = 0
         server = xp.start_server(9012)
         logger.info(f'Profiling server started: {str(server)}')
-        profile_step = int(os.environ.get('PROFILE_STEP', -1))
-        profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
-        profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
-        profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader
             if hasattr(epoch_iterator, "set_epoch"):
@@ -2309,6 +2291,8 @@ def _inner_training_loop(
                     xm.wait_device_ops()
                     import tempfile
                     xp.trace_detached('127.0.0.1:9012', profile_logdir or tempfile.mkdtemp(), profile_duration or 20000)
+                    # Assuming that the profiles start after model compilation is done.
+                    after_compile_start_time = time.time()

                 if self.control.should_epoch_stop or self.control.should_training_stop:
                     # PyTorch/XLA relies on the data loader to insert the mark_step for
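For context, a self-contained sketch of the same capture flow outside the Trainer, assuming torch_xla is available; everything apart from the torch_xla calls (start_server, wait_device_ops, trace_detached) is illustrative:

import os
import tempfile
import time

import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp

server = xp.start_server(9012)  # serve profiler requests on this host
logdir = os.environ.get('PROFILE_LOGDIR') or tempfile.mkdtemp()
duration_ms = int(os.environ.get('PROFILE_DURATION_MS', 20000))

# ... run enough training steps for XLA compilation to finish ...

xm.wait_device_ops()  # let queued device work drain before tracing
xp.trace_detached('127.0.0.1:9012', logdir, duration_ms)
after_compile_start_time = time.time()  # throughput is timed from this point on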
@@ -2360,13 +2344,13 @@ def _inner_training_loop(
         self._total_loss_scalar += tr_loss.item()
         effective_global_step = max(self.state.global_step, 0.001)  # Avoid ZeroDivisionError
         train_loss = self._total_loss_scalar / effective_global_step
-
+        xm.wait_device_ops()
         metrics = speed_metrics(
             "train",
-            start_time,
-            num_samples=num_train_samples,
-            num_steps=self.state.max_steps,
-            num_tokens=num_train_tokens,
+            after_compile_start_time,
+            num_samples=metrics_num_train_samples,
+            num_steps=steps_for_counting_metrics,
+            num_tokens=metrics_num_train_tokens,
         )
         self.store_flos()
         metrics["total_flos"] = self.state.total_flos
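With these arguments, speed_metrics reports throughput over the post-compilation window only. A rough sketch of how the numbers are derived, mirroring the intent of transformers' speed_metrics (the exact keys and rounding in the library may differ):

import time

def post_compile_speed_metrics(start_time, num_samples, num_steps, num_tokens=None):
    # Wall-clock time elapsed since the profile trigger / end of compilation.
    runtime = time.time() - start_time
    metrics = {"train_runtime": round(runtime, 4)}
    if runtime > 0:
        metrics["train_samples_per_second"] = round(num_samples / runtime, 3)
        metrics["train_steps_per_second"] = round(num_steps / runtime, 3)
        if num_tokens is not None:
            metrics["train_tokens_per_second"] = round(num_tokens / runtime, 3)
    return metrics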