
Commit 722c861

Add smoothing for validation loss too
1 parent 06c5565 commit 722c861

File tree: 2 files changed, +80 -25 lines
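The commit applies the same exponential moving average (EMA) already drawn for the training loss to the validation loss, computed with pandas' ewm and a configurable smoothing factor. A minimal sketch of that smoothing on made-up values (only the pandas behavior is shown here; the series below is not from the repository):

import pandas as pd

# Toy validation losses, invented purely for illustration
val_losses = pd.Series([1.00, 0.80, 0.85, 0.60, 0.65, 0.50])

# The diff uses ewm(alpha=..., adjust=True).mean(): observations are weighted
# by powers of (1 - alpha), so older epochs contribute progressively less and
# a larger alpha keeps the smoothed curve closer to the raw values.
smoothed_val_loss = val_losses.ewm(alpha=0.8, adjust=True).mean()
print(smoothed_val_loss.round(3).tolist())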

bayesflow/diagnostics/plots/loss.py

Lines changed: 46 additions & 22 deletions
@@ -1,4 +1,4 @@
-from collections.abc import Sequence
+from typing import Sequence
 
 import numpy as np
 import pandas as pd
@@ -7,22 +7,23 @@
 
 import keras.src.callbacks
 
-from ...utils.plot_utils import make_figure, add_titles_and_labels, gradient_line
+from ...utils.plot_utils import make_figure, add_titles_and_labels, add_gradient_plot
 
 
 def loss(
     history: keras.callbacks.History,
     train_key: str = "loss",
     val_key: str = "val_loss",
     moving_average: bool = True,
-    moving_average_span: int = 10,
+    moving_average_alpha: float = 0.8,
     figsize: Sequence[float] = None,
     train_color: str = "#132a70",
     val_color: str = None,
     val_colormap: str = "viridis",
     lw_train: float = 2.0,
     lw_val: float = 3.0,
-    val_marker_type: str = "o",
+    marker: bool = True,
+    val_marker_type: str = ".",
     val_marker_size: int = 34,
     grid_alpha: float = 0.2,
     legend_fontsize: int = 14,
@@ -43,9 +44,8 @@ def loss(
         The validation loss key to look for in the history
     moving_average : bool, optional, default: False
         A flag for adding an exponential moving average line of the train_losses.
-    moving_average_span : int, optional, default: 0.01
-        Window size for the moving average as a fraction of total
-        training steps.
+    moving_average_alpha : float, optional, default: 0.8
+        Smoothing factor for the moving average.
     figsize : tuple or None, optional, default: None
         The figure size passed to the ``matplotlib`` constructor.
         Inferred if ``None``
@@ -54,11 +54,13 @@ def loss(
     val_color : str, optional, default: None
         The color for the optional validation loss trajectory
     val_colormap : str, optional, default: "viridis"
-        The color for the optional validation loss trajectory
+        The colormap for the optional validation loss trajectory
     lw_train : int, optional, default: 2
         The linewidth for the training loss curve
     lw_val : int, optional, default: 3
         The linewidth for the validation loss curve
+    marker : bool, optional, default: True
+        A flag for whether markers should be added to the validation loss trajectory
     val_marker_type : str, optional, default: o
         The marker type for the validation loss curve
     val_marker_size : int, optional, default: 34
@@ -108,10 +110,10 @@ def loss(
     # Loop through loss entries and populate plot
     for i, ax in enumerate(axes.flat):
         # Plot train curve
-        ax.plot(train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.2, label="Training")
+        ax.plot(train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.05, label="Training")
         if moving_average:
-            smoothed_loss = train_losses.iloc[:, 0].ewm(span=moving_average_span, adjust=True).mean()
-            ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)")
+            smoothed_train_loss = train_losses.iloc[:, 0].ewm(alpha=moving_average_alpha, adjust=True).mean()
+            ax.plot(train_step_index, smoothed_train_loss, color="grey", lw=lw_train, label="Training (Moving Average)")
 
         # Plot optional val curve
         if val_losses is not None:
@@ -120,27 +122,49 @@ def loss(
                     val_step_index,
                     val_losses.iloc[:, 0],
                     linestyle="--",
-                    marker=val_marker_type,
+                    marker=val_marker_type if marker else None,
                     color=val_color,
                     lw=lw_val,
+                    alpha=0.2,
                     label="Validation",
                 )
+                if moving_average:
+                    smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=moving_average_alpha, adjust=True).mean()
+                    ax.plot(
+                        val_step_index,
+                        smoothed_val_loss,
+                        color=val_color,
+                        lw=lw_val,
+                        label="Validation (Moving Average)",
+                    )
             else:
                 # Make gradient lines
-                gradient_line(
-                    val_step_index, val_losses.iloc[:, 0], c=val_step_index, cmap=val_colormap, lw=lw_val, ax=ax
-                )
-                ax.scatter(
+                add_gradient_plot(
                     val_step_index,
                     val_losses.iloc[:, 0],
-                    c=val_step_index,
-                    cmap=val_colormap,
-                    marker=val_marker_type,
-                    s=val_marker_size,
-                    zorder=10,
-                    edgecolors="none",
+                    ax,
+                    val_colormap,
+                    lw_val,
+                    marker,
+                    val_marker_type,
+                    val_marker_size,
+                    alpha=0.05,
                     label="Validation",
                 )
+                if moving_average:
+                    smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=moving_average_alpha, adjust=True).mean()
+                    add_gradient_plot(
+                        val_step_index,
+                        smoothed_val_loss,
+                        ax,
+                        val_colormap,
+                        lw_val,
+                        marker,
+                        val_marker_type,
+                        val_marker_size,
+                        alpha=1,
+                        label="Validation (Moving Average)",
+                    )
 
         sns.despine(ax=ax)
         ax.grid(alpha=grid_alpha)
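With the new parameters in place, a call to the diagnostic might look like the sketch below. The import path is inferred from the file location above, and the hand-built History object merely stands in for the one returned by training; treat both as assumptions for illustration rather than documented usage.

import keras

from bayesflow.diagnostics.plots import loss  # import path assumed from the file location

# Stand-in for the History returned by fit(); the numbers are invented
history = keras.callbacks.History()
history.history = {
    "loss": [1.20, 0.90, 0.70, 0.60, 0.55, 0.52],
    "val_loss": [1.30, 1.00, 0.80, 0.75, 0.70, 0.68],
}

loss(
    history,
    moving_average=True,       # EMA now drawn for train *and* validation loss
    moving_average_alpha=0.8,  # smoothing factor forwarded to pandas .ewm()
    marker=True,               # markers on the raw validation trajectory
)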

bayesflow/utils/plot_utils.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray):
267267
)
268268

269269

270-
def gradient_line(x, y, c=None, cmap="viridis", lw=2, ax=None):
270+
def gradient_line(x, y, c=None, cmap: str = "viridis", lw: float = 2.0, alpha: float = 1, ax=None):
271271
"""
272272
Plot a 1D line with color gradient determined by `c` (same shape as x and y).
273273
"""
@@ -283,7 +283,7 @@ def gradient_line(x, y, c=None, cmap="viridis", lw=2, ax=None):
     segments = np.concatenate([points[:-1], points[1:]], axis=1)
 
     norm = Normalize(np.min(c), np.max(c))
-    lc = LineCollection(segments, array=c, cmap=cmap, norm=norm, linewidth=lw)
+    lc = LineCollection(segments, array=c, cmap=cmap, norm=norm, linewidth=lw, alpha=alpha)
 
     ax.add_collection(lc)
     ax.set_xlim(np.min(x), np.max(x))
@@ -295,7 +295,8 @@ def gradient_legend(ax, label, cmap, norm, loc="upper right"):
     """
     Adds a single gradient swatch to the legend of the given Axes.
 
-    Parameters:
+    Parameters
+    ----------
     - ax: matplotlib Axes
     - label: str, label to display in the legend
     - cmap: matplotlib colormap
@@ -327,3 +328,33 @@ def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height,
     labels.append(label)
 
     ax.legend(handles=handles, labels=labels, loc=loc, handler_map={_GradientSwatch: _HandlerGradient()})
+
+
+def add_gradient_plot(
+    x,
+    y,
+    ax,
+    cmap: str = "viridis",
+    lw: float = 3.0,
+    marker: bool = True,
+    marker_type: str = "o",
+    marker_size: int = 34,
+    alpha: float = 1,
+    label: str = "Validation",
+):
+    gradient_line(x, y, c=x, cmap=cmap, lw=lw, alpha=alpha, ax=ax)
+
+    # Optionally add markers
+    if marker:
+        ax.scatter(
+            x,
+            y,
+            c=x,
+            cmap=cmap,
+            marker=marker_type,
+            s=marker_size,
+            zorder=10,
+            edgecolors="none",
+            label=label,
+            alpha=0.01,
+        )
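The new add_gradient_plot helper bundles the gradient line and the optional scatter markers behind one call. A standalone sketch on made-up data (the import path is assumed from the file location; nothing here comes from the repository's docs):

import numpy as np
import matplotlib.pyplot as plt

from bayesflow.utils.plot_utils import add_gradient_plot  # import path assumed

# Toy validation curve, invented for illustration
rng = np.random.default_rng(1)
steps = np.arange(20)
val_loss = np.exp(-steps / 8) + 0.05 * rng.standard_normal(steps.size)

fig, ax = plt.subplots()
add_gradient_plot(
    steps,
    val_loss,
    ax,
    cmap="viridis",  # color encodes position along the trajectory
    lw=3.0,
    marker=True,     # scatter markers drawn on top of the gradient line
    alpha=0.5,       # transparency forwarded to the underlying gradient_line
    label="Validation",
)
ax.set_ylim(val_loss.min() - 0.1, val_loss.max() + 0.1)  # keep the toy curve fully visible
plt.show()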
