1- from collections . abc import Sequence
1+ from typing import Sequence
22
33import numpy as np
44import pandas as pd
77
88import keras .src .callbacks
99
10- from ...utils .plot_utils import make_figure , add_titles_and_labels , gradient_line
10+ from ...utils .plot_utils import make_figure , add_titles_and_labels , add_gradient_plot
1111
1212
1313def loss (
1414 history : keras .callbacks .History ,
1515 train_key : str = "loss" ,
1616 val_key : str = "val_loss" ,
1717 moving_average : bool = True ,
18- moving_average_span : int = 10 ,
18+ moving_average_alpha : float = 0.8 ,
1919 figsize : Sequence [float ] = None ,
2020 train_color : str = "#132a70" ,
2121 val_color : str = None ,
2222 val_colormap : str = "viridis" ,
2323 lw_train : float = 2.0 ,
2424 lw_val : float = 3.0 ,
25- val_marker_type : str = "o" ,
25+ marker : bool = True ,
26+ val_marker_type : str = "." ,
2627 val_marker_size : int = 34 ,
2728 grid_alpha : float = 0.2 ,
2829 legend_fontsize : int = 14 ,
@@ -43,9 +44,8 @@ def loss(
4344 The validation loss key to look for in the history
4445 moving_average : bool, optional, default: False
4546 A flag for adding an exponential moving average line of the train_losses.
46- moving_average_span : int, optional, default: 0.01
47- Window size for the moving average as a fraction of total
48- training steps.
47+ moving_average_alpha : int, optional, default: 0.8
48+ Smoothing factor for the moving average.
4949 figsize : tuple or None, optional, default: None
5050 The figure size passed to the ``matplotlib`` constructor.
5151 Inferred if ``None``
@@ -54,11 +54,13 @@ def loss(
5454 val_color : str, optional, default: None
5555 The color for the optional validation loss trajectory
5656 val_colormap : str, optional, default: "viridis"
57- The color for the optional validation loss trajectory
57+ The colormap for the optional validation loss trajectory
5858 lw_train : int, optional, default: 2
5959 The linewidth for the training loss curve
6060 lw_val : int, optional, default: 3
6161 The linewidth for the validation loss curve
62+ marker : bool, optional, default: False
63+ A flag for whether marker should be added in the validation loss trajectory
6264 val_marker_type : str, optional, default: o
6365 The marker type for the validation loss curve
6466 val_marker_size : int, optional, default: 34
@@ -108,10 +110,10 @@ def loss(
108110 # Loop through loss entries and populate plot
109111 for i , ax in enumerate (axes .flat ):
110112 # Plot train curve
111- ax .plot (train_step_index , train_losses .iloc [:, 0 ], color = train_color , lw = lw_train , alpha = 0.2 , label = "Training" )
113+ ax .plot (train_step_index , train_losses .iloc [:, 0 ], color = train_color , lw = lw_train , alpha = 0.05 , label = "Training" )
112114 if moving_average :
113- smoothed_loss = train_losses .iloc [:, 0 ].ewm (span = moving_average_span , adjust = True ).mean ()
114- ax .plot (train_step_index , smoothed_loss , color = "grey" , lw = lw_train , label = "Training (Moving Average)" )
115+ smoothed_train_loss = train_losses .iloc [:, 0 ].ewm (alpha = moving_average_alpha , adjust = True ).mean ()
116+ ax .plot (train_step_index , smoothed_train_loss , color = "grey" , lw = lw_train , label = "Training (Moving Average)" )
115117
116118 # Plot optional val curve
117119 if val_losses is not None :
@@ -120,27 +122,49 @@ def loss(
120122 val_step_index ,
121123 val_losses .iloc [:, 0 ],
122124 linestyle = "--" ,
123- marker = val_marker_type ,
125+ marker = val_marker_type if marker else None ,
124126 color = val_color ,
125127 lw = lw_val ,
128+ alpha = 0.2 ,
126129 label = "Validation" ,
127130 )
131+ if moving_average :
132+ smoothed_val_loss = val_losses .iloc [:, 0 ].ewm (alpha = moving_average_alpha , adjust = True ).mean ()
133+ ax .plot (
134+ val_step_index ,
135+ smoothed_val_loss ,
136+ color = val_color ,
137+ lw = lw_val ,
138+ label = "Validation (Moving Average)" ,
139+ )
128140 else :
129141 # Make gradient lines
130- gradient_line (
131- val_step_index , val_losses .iloc [:, 0 ], c = val_step_index , cmap = val_colormap , lw = lw_val , ax = ax
132- )
133- ax .scatter (
142+ add_gradient_plot (
134143 val_step_index ,
135144 val_losses .iloc [:, 0 ],
136- c = val_step_index ,
137- cmap = val_colormap ,
138- marker = val_marker_type ,
139- s = val_marker_size ,
140- zorder = 10 ,
141- edgecolors = "none" ,
145+ ax ,
146+ val_colormap ,
147+ lw_val ,
148+ marker ,
149+ val_marker_type ,
150+ val_marker_size ,
151+ alpha = 0.05 ,
142152 label = "Validation" ,
143153 )
154+ if moving_average :
155+ smoothed_val_loss = val_losses .iloc [:, 0 ].ewm (alpha = moving_average_alpha , adjust = True ).mean ()
156+ add_gradient_plot (
157+ val_step_index ,
158+ smoothed_val_loss ,
159+ ax ,
160+ val_colormap ,
161+ lw_val ,
162+ marker ,
163+ val_marker_type ,
164+ val_marker_size ,
165+ alpha = 1 ,
166+ label = "Validation (Moving Average)" ,
167+ )
144168
145169 sns .despine (ax = ax )
146170 ax .grid (alpha = grid_alpha )
0 commit comments