@@ -140,6 +140,7 @@ def __init__(
140140 self .num_steps = training_params ["numb_steps" ]
141141 self .disp_file = training_params .get ("disp_file" , "lcurve.out" )
142142 self .disp_freq = training_params .get ("disp_freq" , 1000 )
143+ self .disp_avg = training_params .get ("disp_avg" , False )
143144 self .save_ckpt = training_params .get ("save_ckpt" , "model.ckpt" )
144145 self .save_freq = training_params .get ("save_freq" , 1000 )
145146 self .max_ckpt_keep = training_params .get ("max_ckpt_keep" , 5 )
@@ -808,23 +809,75 @@ def fake_model():
808809 else :
809810 raise ValueError (f"Not supported optimizer type '{ self .opt_type } '" )
810811
812+ if self .disp_avg :
813+ # Accumulate loss for averaging over display interval
814+ self .step_count_in_interval += 1
815+ if not self .multi_task :
816+ # Accumulate loss for single task
817+ if not self .train_loss_accu :
818+ # Initialize accumulator with current loss structure
819+ for item in more_loss :
820+ if "l2_" not in item :
821+ self .train_loss_accu [item ] = 0.0
822+ for item in more_loss :
823+ if "l2_" not in item :
824+ self .train_loss_accu [item ] += more_loss [item ]
825+ else :
826+ # Accumulate loss for multi-task
827+ if task_key not in self .train_loss_accu :
828+ self .train_loss_accu [task_key ] = {}
829+ if task_key not in self .step_count_per_task :
830+ self .step_count_per_task [task_key ] = 0
831+ self .step_count_per_task [task_key ] += 1
832+
833+ for item in more_loss :
834+ if "l2_" not in item :
835+ if item not in self .train_loss_accu [task_key ]:
836+ self .train_loss_accu [task_key ][item ] = 0.0
837+ self .train_loss_accu [task_key ][item ] += more_loss [item ]
838+
811839 # Log and persist
812840 display_step_id = _step_id + 1
813841 if self .display_in_training and (
814842 display_step_id % self .disp_freq == 0 or display_step_id == 1
815843 ):
816844 self .wrapper .eval () # Will set to train mode before fininshing validation
817845
818- def log_loss_train (_loss , _more_loss , _task_key = "Default" ):
819- results = {}
820- rmse_val = {
821- item : _more_loss [item ]
822- for item in _more_loss
823- if "l2_" not in item
824- }
825- for item in sorted (rmse_val .keys ()):
826- results [item ] = rmse_val [item ]
827- return results
846+ if self .disp_avg :
847+
def log_loss_train(_loss, _more_loss, _task_key="Default"):
    """Return interval-averaged training metrics for display.

    ``_loss`` and ``_more_loss`` are accepted for call-site compatibility
    but unused here: the averages come from the accumulators maintained
    on ``self`` over the display interval.
    """
    if not self.multi_task:
        # Single-task: divide each accumulated metric by the step count.
        return {
            name: total / self.step_count_in_interval
            for name, total in self.train_loss_accu.items()
        }
    # Multi-task: only report tasks that actually accumulated steps.
    if (
        _task_key not in self.train_loss_accu
        or _task_key not in self.step_count_per_task
    ):
        return {}
    steps = self.step_count_per_task[_task_key]
    return {
        name: total / steps
        for name, total in self.train_loss_accu[_task_key].items()
    }
869+ else :
870+
def log_loss_train(_loss, _more_loss, _task_key="Default"):
    """Return the current step's display metrics, sorted by name.

    Raw ``l2_``-prefixed loss terms are filtered out; only the
    human-readable metrics (e.g. RMSE values) are kept.
    """
    display_keys = sorted(key for key in _more_loss if "l2_" not in key)
    return {key: _more_loss[key] for key in display_keys}
828881
829882 def log_loss_valid (_task_key = "Default" ):
830883 single_results = {}
@@ -882,24 +935,31 @@ def log_loss_valid(_task_key="Default"):
882935 else :
883936 train_results = {_key : {} for _key in self .model_keys }
884937 valid_results = {_key : {} for _key in self .model_keys }
885- train_results [task_key ] = log_loss_train (
886- loss , more_loss , _task_key = task_key
887- )
888- for _key in self .model_keys :
889- if _key != task_key :
890- self .optimizer .zero_grad ()
891- input_dict , label_dict , _ = self .get_data (
892- is_train = True , task_key = _key
893- )
894- _ , loss , more_loss = self .wrapper (
895- ** input_dict ,
896- cur_lr = pref_lr ,
897- label = label_dict ,
898- task_key = _key ,
899- )
938+ if self .disp_avg :
939+ # For multi-task, use accumulated average loss for all tasks
940+ for _key in self .model_keys :
900941 train_results [_key ] = log_loss_train (
901942 loss , more_loss , _task_key = _key
902943 )
944+ else :
945+ train_results [task_key ] = log_loss_train (
946+ loss , more_loss , _task_key = task_key
947+ )
948+ for _key in self .model_keys :
949+ if _key != task_key :
950+ self .optimizer .zero_grad ()
951+ input_dict , label_dict , _ = self .get_data (
952+ is_train = True , task_key = _key
953+ )
954+ _ , loss , more_loss = self .wrapper (
955+ ** input_dict ,
956+ cur_lr = pref_lr ,
957+ label = label_dict ,
958+ task_key = _key ,
959+ )
960+ train_results [_key ] = log_loss_train (
961+ loss , more_loss , _task_key = _key
962+ )
903963 valid_results [_key ] = log_loss_valid (_task_key = _key )
904964 if self .rank == 0 :
905965 log .info (
@@ -921,6 +981,21 @@ def log_loss_valid(_task_key="Default"):
921981 )
922982 self .wrapper .train ()
923983
984+ if self .disp_avg :
985+ # Reset loss accumulators after display
986+ if not self .multi_task :
987+ for item in self .train_loss_accu :
988+ self .train_loss_accu [item ] = 0.0
989+ else :
990+ for task_key in self .model_keys :
991+ if task_key in self .train_loss_accu :
992+ for item in self .train_loss_accu [task_key ]:
993+ self .train_loss_accu [task_key ][item ] = 0.0
994+ if task_key in self .step_count_per_task :
995+ self .step_count_per_task [task_key ] = 0
996+ self .step_count_in_interval = 0
997+ self .last_display_step = display_step_id
998+
924999 current_time = time .time ()
9251000 train_time = current_time - self .t0
9261001 self .t0 = current_time
@@ -993,6 +1068,17 @@ def log_loss_valid(_task_key="Default"):
9931068 self .t0 = time .time ()
9941069 self .total_train_time = 0.0
9951070 self .timed_steps = 0
1071+
1072+ if self .disp_avg :
1073+ # Initialize loss accumulators
1074+ if not self .multi_task :
1075+ self .train_loss_accu = {}
1076+ else :
1077+ self .train_loss_accu = {key : {} for key in self .model_keys }
1078+ self .step_count_per_task = dict .fromkeys (self .model_keys , 0 )
1079+ self .step_count_in_interval = 0
1080+ self .last_display_step = 0
1081+
9961082 for step_id in range (self .start_step , self .num_steps ):
9971083 step (step_id )
9981084 if JIT :
0 commit comments