@@ -19,6 +19,8 @@ def loss(
1919 figsize : Sequence [float ] = None ,
2020 train_color : str = "#132a70" ,
2121 val_color : str = "black" ,
22+ val_marker : str = "o" ,
23+ val_marker_size : float = 5 ,
2224 lw_train : float = 2.0 ,
2325 lw_val : float = 2.0 ,
2426 grid_alpha : float = 0.2 ,
@@ -49,6 +51,10 @@ def loss(
4951 The color for the train loss trajectory
5052 val_color : str, optional, default: None
5153 The color for the optional validation loss trajectory
54+ val_marker: str
55+ Marker style for the validation loss curve. Default is "o".
56+ val_marker_size: float
57+ Marker size for the validation loss curve. Default is 5.
5258 lw_train : int, optional, default: 2
5359 The linewidth for the training loss curve
5460 lw_val : int, optional, default: 2
@@ -130,6 +136,9 @@ def loss(
130136 color = val_color ,
131137 lw = lw_val ,
132138 alpha = alpha_unsmoothed ,
139+ linestyle = "--" ,
140+ marker = val_marker ,
141+ markersize = val_marker_size ,
133142 label = "Validation" ,
134143 )
135144
@@ -140,6 +149,7 @@ def loss(
140149 val_step_index ,
141150 smoothed_val_loss ,
142151 color = val_color ,
152+ linestyle = "--" ,
143153 lw = lw_val ,
144154 alpha = 0.8 ,
145155 label = "Validation (Moving Average)" ,
0 commit comments