@@ -466,6 +466,9 @@ def fn(layer):
466466
467467 # very last
468468 self ._memory_tracker .stop_and_update_metrics ()
469+ if self .args .count_trained_tokens :
470+ self .trained_effective_tokens = 0
471+ self .trained_tokens = 0
469472
470473 def _wrap_amp_model (self , args , model ):
471474 logger .info ("Using half precision" )
@@ -1122,6 +1125,9 @@ def _inner_training_loop(
11221125 is_no_sync = True
11231126
11241127 sync_context = model .no_sync () if is_no_sync else contextlib .nullcontext ()
1128+ if self .args .count_trained_tokens :
1129+ self .trained_effective_tokens += (inputs ["input_ids" ] != self .args .pad_token_id ).sum ()
1130+ self .trained_tokens += inputs ["input_ids" ].numel ()
11251131 with sync_context :
11261132 if "step_control" in inspect .signature (self .training_step ).parameters :
11271133 tr_loss_step = self .training_step (model , inputs , step_control = step_control )
@@ -1570,6 +1576,27 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
15701576 self ._save_checkpoint (model , metrics = metrics )
15711577 logger .info (f"{ self .runtime_timer .log ()} " )
15721578 self .control = self .callback_handler .on_save (self .args , self .state , self .control )
1579+ self .log_trained_tokens ()
1580+
    def log_trained_tokens(self):
        """Aggregate and log the running token counters across all workers.

        Reduces ``self.trained_effective_tokens`` (tokens counted as non-padding
        by the training loop) and ``self.trained_tokens`` (all input tokens,
        padding included) first over the sharding communication group, then over
        the data-parallel group, and logs the two totals on the local rank-0
        process. No-op unless ``args.count_trained_tokens`` is enabled.
        """
        if self.args.count_trained_tokens:
            token_list = []
            for token_num in [self.trained_effective_tokens, self.trained_tokens]:
                # NOTE(review): assumes the counter is already a paddle Tensor
                # (it becomes one after the first `+=` with a Tensor in the
                # training loop); the initial plain-int 0 has no `.reshape`,
                # so a call before any training step would raise — confirm
                # this is only reached after at least one step.
                tensors = token_num.reshape([1])
                if self.hcg._sharding_degree > 1:
                    # Gather the per-shard partial counts and collapse them
                    # into a single scalar tensor of shape [1].
                    output_tensors = []
                    paddle.distributed.all_gather(output_tensors, tensors, group=self.hcg._sharding_comm_group)
                    tensors = paddle.concat(output_tensors).sum().reshape([1])
                if self.hcg._dp_degree > 1:
                    # Then reduce the (possibly sharding-reduced) count across
                    # the data-parallel replicas.
                    output_tensors = []
                    paddle.distributed.all_gather(output_tensors, tensors, group=self.hcg._dp_comm_group)
                    tensors = paddle.concat(output_tensors).sum().reshape([1])
                token_list.append(tensors.item())
            if self.is_local_process_zero():
                # Only the local main process emits the log line to avoid
                # duplicated output across ranks.
                logger.info(
                    f"Update to now, trained_effective_tokens: {token_list[0]}, trained_tokens: {token_list[1]}."
                )
15731600
15741601 def _get_learning_rate (self ):
15751602 return self .optimizer .get_lr ()
0 commit comments