77
88import keras .src .callbacks
99
10- from ...utils .plot_utils import make_figure , add_titles_and_labels
10+ from matplotlib .colors import Normalize
11+ from ...utils .plot_utils import make_figure , add_titles_and_labels , gradient_line , gradient_legend
1112
1213
1314def loss (
1415 history : keras .callbacks .History ,
1516 train_key : str = "loss" ,
1617 val_key : str = "val_loss" ,
17- moving_average : bool = False ,
18+ moving_average : bool = True ,
1819 per_training_step : bool = False ,
19- ma_window_fraction : float = 0.01 ,
20+ moving_average_span : int = 10 ,
2021 figsize : Sequence [float ] = None ,
2122 train_color : str = "#132a70" ,
22- val_color : str = "black" ,
23+ val_color : str = None ,
24+ val_colormap : str = 'viridis' ,
2325 lw_train : float = 2.0 ,
2426 lw_val : float = 3.0 ,
27+ val_marker_type : str = "o" ,
28+ val_marker_size : int = 34 ,
29+ grid_alpha : float = 0.2 ,
2530 legend_fontsize : int = 14 ,
2631 label_fontsize : int = 14 ,
2732 title_fontsize : int = 16 ,
@@ -39,7 +44,7 @@ def loss(
3944 val_key : str, optional, default: "val_loss"
4045 The validation loss key to look for in the history
41- moving_average : bool, optional, default: False
42- A flag for adding a moving average line of the train_losses.
46+ moving_average : bool, optional, default: True
47+ A flag for adding an exponential moving average line of the train_losses.
4348 per_training_step : bool, optional, default: False
4449 A flag for making loss trajectory detailed (to training steps) rather than per epoch.
45- ma_window_fraction : int, optional, default: 0.01
50+ moving_average_span : int, optional, default: 10
51+ Span (in epochs) of the exponential moving average applied to the training loss.
@@ -99,27 +104,51 @@ def loss(
99104 # Loop through loss entries and populate plot
100105 for i , ax in enumerate (axes .flat ):
101106 # Plot train curve
102- ax .plot (train_step_index , train_losses .iloc [:, i ], color = train_color , lw = lw_train , alpha = 0.9 , label = "Training" )
103- if moving_average and train_losses .columns [i ] == "Loss" :
104- moving_average_window = int (train_losses .shape [0 ] * ma_window_fraction )
105- smoothed_loss = train_losses .iloc [:, i ].rolling (window = moving_average_window ).mean ()
107+ ax .plot (train_step_index , train_losses .iloc [:, 0 ], color = train_color , lw = lw_train , alpha = 0.2 , label = "Training" )
108+ if moving_average :
109+ smoothed_loss = train_losses .iloc [:, 0 ].ewm (span = moving_average_span , adjust = True ).mean ()
106110 ax .plot (train_step_index , smoothed_loss , color = "grey" , lw = lw_train , label = "Training (Moving Average)" )
107111
108112 # Plot optional val curve
109113 if val_losses is not None :
110- if i < val_losses .shape [1 ]:
111- ax .plot (
112- val_step_index ,
113- val_losses .iloc [:, i ],
114- linestyle = "--" ,
115- marker = "o" ,
116- color = val_color ,
117- lw = lw_val ,
118- label = "Validation" ,
119- )
114+ if val_color is not None :
115+ ax .plot (
116+ val_step_index ,
117+ val_losses .iloc [:, 0 ],
118+ linestyle = "--" ,
119+ marker = val_marker_type ,
120+ color = val_color ,
121+ lw = lw_val ,
122+ label = "Validation" ,
123+ )
124+ else :
125+ # Color the validation curve by epoch index with a gradient line
126+ lc = gradient_line (
127+ val_step_index ,
128+ val_losses .iloc [:, 0 ],
129+ c = val_step_index ,
130+ cmap = val_colormap ,
131+ lw = lw_val ,
132+ ax = ax ,
133+ )
138+ scatter = ax .scatter (
139+ val_step_index ,
140+ val_losses .iloc [:,0 ],
141+ c = val_step_index ,
142+ cmap = val_colormap ,
143+ marker = val_marker_type ,
144+ s = val_marker_size ,
145+ zorder = 10 ,
146+ edgecolors = 'none' ,
147+ label = 'Validation'
148+ )
120149
121150 sns .despine (ax = ax )
122- ax .grid (alpha = 0.5 )
151+ ax .grid (alpha = grid_alpha )
123152
124153 # Only add legend if there is a validation curve
125154 if val_losses is not None or moving_average :
0 commit comments