Skip to content

Commit 4e130cd

Browse files
committed
Enhance loss plot, including viridisifying the val loss
1 parent 2bedf47 commit 4e130cd

File tree

2 files changed

+120
-20
lines changed

2 files changed

+120
-20
lines changed

bayesflow/diagnostics/plots/loss.py

Lines changed: 49 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,26 @@
77

88
import keras.src.callbacks
99

10-
from ...utils.plot_utils import make_figure, add_titles_and_labels
10+
from matplotlib.colors import Normalize
11+
from ...utils.plot_utils import make_figure, add_titles_and_labels, gradient_line, gradient_legend
1112

1213

1314
def loss(
1415
history: keras.callbacks.History,
1516
train_key: str = "loss",
1617
val_key: str = "val_loss",
17-
moving_average: bool = False,
18+
moving_average: bool = True,
1819
per_training_step: bool = False,
19-
ma_window_fraction: float = 0.01,
20+
moving_average_span: int = 10,
2021
figsize: Sequence[float] = None,
2122
train_color: str = "#132a70",
22-
val_color: str = "black",
23+
val_color: str = None,
24+
val_colormap: str = 'viridis',
2325
lw_train: float = 2.0,
2426
lw_val: float = 3.0,
27+
val_marker_type: str = "o",
28+
val_marker_size: int = 34,
29+
grid_alpha: float = 0.2,
2530
legend_fontsize: int = 14,
2631
label_fontsize: int = 14,
2732
title_fontsize: int = 16,
@@ -39,7 +44,7 @@ def loss(
3944
val_key : str, optional, default: "val_loss"
4045
The validation loss key to look for in the history
4146
moving_average : bool, optional, default: False
42-
A flag for adding a moving average line of the train_losses.
47+
A flag for adding an exponential moving average line of the train_losses.
4348
per_training_step : bool, optional, default: False
4449
A flag for making loss trajectory detailed (to training steps) rather than per epoch.
4550
ma_window_fraction : int, optional, default: 0.01
@@ -99,27 +104,51 @@ def loss(
99104
# Loop through loss entries and populate plot
100105
for i, ax in enumerate(axes.flat):
101106
# Plot train curve
102-
ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
103-
if moving_average and train_losses.columns[i] == "Loss":
104-
moving_average_window = int(train_losses.shape[0] * ma_window_fraction)
105-
smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean()
107+
ax.plot(train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.2, label="Training")
108+
if moving_average:
109+
smoothed_loss = train_losses.iloc[:, 0].ewm(span=moving_average_span, adjust=True).mean()
106110
ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)")
107111

108112
# Plot optional val curve
109113
if val_losses is not None:
110-
if i < val_losses.shape[1]:
111-
ax.plot(
112-
val_step_index,
113-
val_losses.iloc[:, i],
114-
linestyle="--",
115-
marker="o",
116-
color=val_color,
117-
lw=lw_val,
118-
label="Validation",
119-
)
114+
if val_color is not None:
115+
ax.plot(
116+
val_step_index,
117+
val_losses.iloc[:, 0],
118+
linestyle="--",
119+
marker=val_marker_type,
120+
color=val_color,
121+
lw=lw_val,
122+
label="Validation",
123+
)
124+
else:
125+
# Create line segments between each epoch
126+
points = np.array([val_step_index, val_losses.iloc[:,0]]).T.reshape(-1, 1, 2)
127+
segments = np.concatenate([points[:-1], points[1:]], axis=1)
128+
129+
# Normalize color based on loss values
130+
lc = gradient_line(
131+
val_step_index,
132+
val_losses.iloc[:,0],
133+
c=val_step_index,
134+
cmap=val_colormap,
135+
lw=lw_val,
136+
ax=ax
137+
)
138+
scatter = ax.scatter(
139+
val_step_index,
140+
val_losses.iloc[:,0],
141+
c=val_step_index,
142+
cmap=val_colormap,
143+
marker=val_marker_type,
144+
s=val_marker_size,
145+
zorder=10,
146+
edgecolors='none',
147+
label='Validation'
148+
)
120149

121150
sns.despine(ax=ax)
122-
ax.grid(alpha=0.5)
151+
ax.grid(alpha=grid_alpha)
123152

124153
# Only add legend if there is a validation curve
125154
if val_losses is not None or moving_average:

bayesflow/utils/plot_utils.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
import matplotlib.pyplot as plt
55
import seaborn as sns
66

7+
from matplotlib.collections import LineCollection
8+
from matplotlib.colors import Normalize
9+
from matplotlib.patches import Rectangle
10+
from matplotlib.legend_handler import HandlerPatch
11+
712
from .validators import check_estimates_prior_shapes
813
from .dict_utils import dicts_to_arrays
914

@@ -260,3 +265,69 @@ def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray):
260265
alpha=0.9,
261266
linestyle="dashed",
262267
)
268+
269+
270+
def gradient_line(x, y, c=None, cmap='viridis', lw=2, ax=None):
271+
"""
272+
Plot a 1D line with color gradient determined by `c` (same shape as x and y).
273+
"""
274+
if ax is None:
275+
ax = plt.gca()
276+
277+
# Default color value = y
278+
if c is None:
279+
c = y
280+
281+
# Create segments for LineCollection
282+
points = np.array([x, y]).T.reshape(-1, 1, 2)
283+
segments = np.concatenate([points[:-1], points[1:]], axis=1)
284+
285+
norm = Normalize(np.min(c), np.max(c))
286+
lc = LineCollection(segments, array=c, cmap=cmap, norm=norm, linewidth=lw)
287+
288+
ax.add_collection(lc)
289+
ax.set_xlim(np.min(x), np.max(x))
290+
ax.set_ylim(np.min(y), np.max(y))
291+
return lc
292+
293+
294+
def gradient_legend(ax, label, cmap, norm, loc='upper right'):
295+
"""
296+
Adds a single gradient swatch to the legend of the given Axes.
297+
298+
Parameters:
299+
- ax: matplotlib Axes
300+
- label: str, label to display in the legend
301+
- cmap: matplotlib colormap
302+
- norm: matplotlib Normalize object
303+
- loc: legend location (default 'upper right')
304+
"""
305+
306+
# Custom dummy handle to represent the gradient
307+
class _GradientSwatch(Rectangle): pass
308+
309+
# Custom legend handler that draws a horizontal gradient
310+
class _HandlerGradient(HandlerPatch):
311+
def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans):
312+
gradient = np.linspace(0, 1, 256).reshape(1, -1)
313+
im = ax.imshow(
314+
gradient,
315+
aspect='auto',
316+
extent=[xdescent, xdescent + width, ydescent, ydescent + height],
317+
transform=trans,
318+
cmap=cmap,
319+
norm=norm
320+
)
321+
return [im]
322+
323+
# Add to existing legend entries
324+
handles, labels = ax.get_legend_handles_labels()
325+
handles.append(_GradientSwatch((0, 0), 1, 1))
326+
labels.append(label)
327+
328+
ax.legend(
329+
handles=handles,
330+
labels=labels,
331+
loc=loc,
332+
handler_map={_GradientSwatch: _HandlerGradient()}
333+
)

0 commit comments

Comments
 (0)