
Commit b8ebe3e

add count trained tokens (#9800)
* add count
* fix
* fix
1 parent 347d77c commit b8ebe3e

File tree

2 files changed: +35 -0 lines changed


paddlenlp/trainer/trainer.py

Lines changed: 27 additions & 0 deletions
@@ -466,6 +466,9 @@ def fn(layer):
 
         # very last
         self._memory_tracker.stop_and_update_metrics()
+        if self.args.count_trained_tokens:
+            self.trained_effective_tokens = 0
+            self.trained_tokens = 0
 
     def _wrap_amp_model(self, args, model):
         logger.info("Using half precision")
@@ -1122,6 +1125,9 @@ def _inner_training_loop(
                 is_no_sync = True
 
                 sync_context = model.no_sync() if is_no_sync else contextlib.nullcontext()
+                if self.args.count_trained_tokens:
+                    self.trained_effective_tokens += (inputs["input_ids"] != self.args.pad_token_id).sum()
+                    self.trained_tokens += inputs["input_ids"].numel()
                 with sync_context:
                     if "step_control" in inspect.signature(self.training_step).parameters:
                         tr_loss_step = self.training_step(model, inputs, step_control=step_control)
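The per-step bookkeeping above tracks two quantities: effective tokens (positions that are not padding) and total tokens (every position in the batch). A minimal sketch of what the two expressions compute, using a made-up two-sequence batch and assuming pad_token_id is 0:

import paddle

# Hypothetical batch: two sequences padded to length 5 with pad_token_id = 0.
input_ids = paddle.to_tensor([[101, 7, 9, 0, 0],
                              [101, 4, 0, 0, 0]])
pad_token_id = 0

effective = (input_ids != pad_token_id).sum()  # non-padding positions -> 5
total = input_ids.numel()                      # all positions         -> 10
print(int(effective), int(total))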
@@ -1570,6 +1576,27 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
             self._save_checkpoint(model, metrics=metrics)
             logger.info(f"{self.runtime_timer.log()}")
             self.control = self.callback_handler.on_save(self.args, self.state, self.control)
+            self.log_trained_tokens()
+
+    def log_trained_tokens(self):
+        if self.args.count_trained_tokens:
+            token_list = []
+            for token_num in [self.trained_effective_tokens, self.trained_tokens]:
+                tensors = token_num.reshape([1])
+                if self.hcg._sharding_degree > 1:
+                    output_tensors = []
+                    paddle.distributed.all_gather(output_tensors, tensors, group=self.hcg._sharding_comm_group)
+                    tensors = paddle.concat(output_tensors).sum().reshape([1])
+                if self.hcg._dp_degree > 1:
+                    output_tensors = []
+                    paddle.distributed.all_gather(output_tensors, tensors, group=self.hcg._dp_comm_group)
+                    tensors = paddle.concat(output_tensors).sum().reshape([1])
+                token_list.append(tensors.item())
+            if self.is_local_process_zero():
+
+                logger.info(
+                    f"Update to now, trained_effective_tokens: {token_list[0]}, trained_tokens: {token_list[1]}."
+                )
 
     def _get_learning_rate(self):
         return self.optimizer.get_lr()
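log_trained_tokens reduces each counter first across the sharding group and then across the data-parallel group, so every rank ends up holding the global sum and the local rank zero logs it. The all_gather + concat + sum pattern used in the commit is equivalent to a SUM all_reduce; an untested sketch of the same reduction written that way (the function name and group handling are illustrative, not part of the commit):

import paddle.distributed as dist

def global_token_count(local_count, sharding_group=None, dp_group=None):
    # Copy so the caller's counter is not modified in place by all_reduce.
    tensor = local_count.reshape([1]).clone()
    # Sum across the sharding group, then across the data-parallel group.
    if sharding_group is not None and sharding_group.nranks > 1:
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=sharding_group)
    if dp_group is not None and dp_group.nranks > 1:
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=dp_group)
    return int(tensor.item())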

paddlenlp/trainer/training_args.py

Lines changed: 8 additions & 0 deletions
@@ -978,6 +978,14 @@ class TrainingArguments:
         default=300,
         metadata={"help": "Timeout seconds for downloading checkpoint from remote cluster."},
     )
+    count_trained_tokens: bool = field(
+        default=False,
+        metadata={"help": "Whether to count trained tokens."},
+    )
+    pad_token_id: int = field(
+        default=0,
+        metadata={"help": "The id of the padding token."},
+    )
 
     def __post_init__(self):
         if in_auto_parallel_align_mode():
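The two new arguments are plain dataclass fields, so they can be set from the command line like the other TrainingArguments options or passed in code. A minimal sketch of the in-code path, with a placeholder output_dir:

from paddlenlp.trainer import TrainingArguments

# Enabling the options added by this commit.
args = TrainingArguments(
    output_dir="./checkpoints",     # placeholder path
    count_trained_tokens=True,      # track effective/total trained tokens
    pad_token_id=0,                 # id excluded from the effective-token count
)
print(args.count_trained_tokens, args.pad_token_id)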
