
Commit 358b7ec

Start calculating metrics after profile step assuming profiling happens after compile
1 parent: c905572

File tree

2 files changed: +33 -54 lines changed

examples/pytorch/language-modeling/run_clm.py

Lines changed: 0 additions & 5 deletions
@@ -687,11 +687,6 @@ def compute_metrics(eval_preds):
 
         metrics = train_result.metrics
 
-        max_train_samples = (
-            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
-        )
-        metrics["train_samples"] = min(max_train_samples, len(train_dataset))
-
         trainer.log_metrics("train", metrics)
         trainer.save_metrics("train", metrics)
         trainer.save_state()
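
For reference, the deleted block is what populated the `train_samples` entry in the saved train metrics: the dataset size, optionally capped by `--max_train_samples`. A standalone sketch of that computation (the function name and example values are illustrative, not from the commit):

import run_clm_sketch  # not a real module; the snippet below is self-contained

# Sketch of the removed bookkeeping: the reported sample count is the dataset
# size, capped by --max_train_samples when that flag is set.
def reported_train_samples(dataset_len, max_train_samples=None):
    cap = max_train_samples if max_train_samples is not None else dataset_len
    return min(cap, dataset_len)

assert reported_train_samples(10_000) == 10_000      # no cap set
assert reported_train_samples(10_000, 512) == 512    # capped by the flag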

src/transformers/trainer.py

Lines changed: 33 additions & 49 deletions
@@ -1910,46 +1910,32 @@ def _inner_training_loop(
         # number of training epochs: num_train_epochs
         # number of training steps per epoch: num_update_steps_per_epoch
         # total number of training steps to execute: max_steps
-        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size
-
-        len_dataloader = None
-        num_train_tokens = None
-        if has_length(train_dataloader):
-            len_dataloader = len(train_dataloader)
-            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
-            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
-            num_examples = self.num_examples(train_dataloader)
-            if args.max_steps > 0:
-                max_steps = args.max_steps
-                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
-                    args.max_steps % num_update_steps_per_epoch > 0
-                )
-                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
-                # the best we can do.
-                num_train_samples = args.max_steps * total_train_batch_size
-                if args.include_tokens_per_second:
-                    num_train_tokens = (
-                        self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
-                    )
-            else:
-                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
-                num_train_epochs = math.ceil(args.num_train_epochs)
-                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
-                if args.include_tokens_per_second:
-                    num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
-        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
-            max_steps = args.max_steps
-            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
-            num_train_epochs = sys.maxsize
-            num_update_steps_per_epoch = max_steps
-            num_examples = total_train_batch_size * args.max_steps
-            num_train_samples = args.max_steps * total_train_batch_size
-            if args.include_tokens_per_second:
-                num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
-        else:
-            raise ValueError(
-                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
-                f" {args.max_steps}"
+        profile_step = int(os.environ.get('PROFILE_STEP', -1))
+        profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
+        profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
+        profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
+        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps
+        assert args.max_steps > 0
+        max_steps = args.max_steps
+        len_dataloader = len(train_dataloader)
+        num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
+        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
+        steps_for_counting_metrics = max_steps - num_update_steps_per_epoch * profile_epoch - profile_step
+        num_examples = self.num_examples(train_dataloader)
+        num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
+            args.max_steps % num_update_steps_per_epoch > 0
+        )
+        # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
+        # the best we can do.
+        num_train_samples = args.max_steps * total_train_batch_size
+        metrics_num_train_samples = steps_for_counting_metrics * total_train_batch_size
+        metrics_num_train_tokens = None
+        if args.include_tokens_per_second:
+            num_train_tokens = (
+                self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
+            )
+            metrics_num_train_tokens = (
+                self.num_tokens(train_dataloader, steps_for_counting_metrics) * args.gradient_accumulation_steps
             )
 
         if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
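
The new `steps_for_counting_metrics` is the number of optimizer steps remaining after the profiling trigger (epoch `PROFILE_EPOCH`, step `PROFILE_STEP`), and it is what later feeds the throughput numbers. A small worked sketch with made-up values (the numbers are illustrative only, not from the commit):

import os

# Illustrative values: 200 total optimizer steps, 50 per epoch, profile
# triggered at step 10 of epoch 1 (indices come from the PROFILE_* env vars).
os.environ["PROFILE_EPOCH"] = "1"
os.environ["PROFILE_STEP"] = "10"

max_steps = 200
num_update_steps_per_epoch = 50
profile_epoch = int(os.environ.get("PROFILE_EPOCH", -1))
profile_step = int(os.environ.get("PROFILE_STEP", -1))

# Only the steps after the profiling trigger are attributed to the metrics window.
steps_for_counting_metrics = max_steps - num_update_steps_per_epoch * profile_epoch - profile_step
assert steps_for_counting_metrics == 140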
@@ -2153,10 +2139,6 @@ def _inner_training_loop(
         total_batched_samples = 0
         server = xp.start_server(9012)
         logger.info(f'Profiling server started: {str(server)}')
-        profile_step = int(os.environ.get('PROFILE_STEP', -1))
-        profile_epoch = int(os.environ.get('PROFILE_EPOCH', -1))
-        profile_duration = int(os.environ.get('PROFILE_DURATION_MS', 20000))
-        profile_logdir = os.environ.get('PROFILE_LOGDIR', None)
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader
             if hasattr(epoch_iterator, "set_epoch"):
@@ -2309,6 +2291,8 @@ def _inner_training_loop(
                     xm.wait_device_ops()
                     import tempfile
                     xp.trace_detached('127.0.0.1:9012', profile_logdir or tempfile.mkdtemp(), profile_duration or 20000)
+                    # Assuming that the profiles start after model compilation is done.
+                    after_compile_start_time = time.time()
 
                 if self.control.should_epoch_stop or self.control.should_training_stop:
                     # PyTorch/XLA relies on the data loader to insert the mark_step for
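
The new timestamp is taken immediately after the detached profile is kicked off, so wall-clock time before the trigger, including XLA compilation under the commit's stated assumption, never enters the throughput window. A minimal standalone sketch of the pattern (the loop shape, placeholder functions, and default step/epoch values are assumptions, not the trainer's real code):

import os
import time

def train_one_step():
    """Placeholder for one real optimizer step."""
    time.sleep(0.01)

def start_profile():
    """Placeholder for xp.trace_detached(...) in the actual trainer."""
    pass

num_train_epochs, num_update_steps_per_epoch = 2, 5
profile_epoch = int(os.environ.get("PROFILE_EPOCH", "0"))
profile_step = int(os.environ.get("PROFILE_STEP", "2"))

after_compile_start_time = None
for epoch in range(num_train_epochs):
    for step in range(num_update_steps_per_epoch):
        train_one_step()
        if step == profile_step and epoch == profile_epoch:
            start_profile()
            # From this point on, compilation is assumed to be finished, so the
            # throughput clock starts here rather than at trainer start-up.
            after_compile_start_time = time.time()

print(f"post-compile window: {time.time() - after_compile_start_time:.3f}s")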
@@ -2360,13 +2344,13 @@ def _inner_training_loop(
         self._total_loss_scalar += tr_loss.item()
         effective_global_step = max(self.state.global_step, 0.001)  # Avoid ZeroDivisionError
         train_loss = self._total_loss_scalar / effective_global_step
-
+        xm.wait_device_ops()
         metrics = speed_metrics(
             "train",
-            start_time,
-            num_samples=num_train_samples,
-            num_steps=self.state.max_steps,
-            num_tokens=num_train_tokens,
+            after_compile_start_time,
+            num_samples=metrics_num_train_samples,
+            num_steps=steps_for_counting_metrics,
+            num_tokens=metrics_num_train_tokens,
         )
         self.store_flos()
         metrics["total_flos"] = self.state.total_flos
