Skip to content

Commit b9b6b9a

Browse files
Diagnostics module (#235)
* Migrate diagnostics module * Add confusion_matrix * Remove duplicate code for layout configuration * Simplify pre-processing for plot_recovery and plot_z_score_contraction * Simplify prettify process for z_score and recovery * Simplify preprocessing for plot_sbc_ecdf and plot_sbc_histograms * Simplify labeling * Reformat * Make plot_distribution_2d more compatible * Update quickstart notebook with working prior checks * Improve compatibility for plot_losses and plot_prior_2d * Update quickstart notebook with loss trajectory * Minor fix of plot_utils, start adding tests for diagnostics * Pre-final version WIP * Minor changes in 2d plots; update and test plot_z_score_contraction * Update and test plot_sbc_histograms * Update plot_calibration_curves (WIP) * Minor refactors: change global color schemes, complete type casts, further simplify plot_losses * Add detailed callback for loss trajectory * Generalize preprocessing utilities from samples to variables * Generalize add_metric * Remove redundant code segments related to prettify * Include add_titles and add_titles_and_labels; propagate variables as samples * Interim cleanup * Add typing and fix plot_sbc_ecdf * Minor fix of plot_samples_2d * Minor fix of plot_prior_2d * Remove redundant code for axes flattening * Ensure consistent color scheme; incorporate sequence of labels * Bug fix for plot_losses * Cleanup --------- Co-authored-by: stefanradev93 <stefan.radev93@gmail.com>
1 parent 06678e8 commit b9b6b9a

29 files changed

+3222
-312
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ First, install one machine learning backend of choice. Note that BayesFlow **wil
4242
- [Install TensorFlow](https://www.tensorflow.org/install)
4343

4444
If you don't know which backend to use, we recommend JAX to get started.
45-
It is the fastest backend and already works pretty reliably with the current
45+
It is the fastest backend and already works pretty reliably with the current
4646
dev version of bayesflow.
4747

4848
Once installed, [set the backend environment variable as required by keras](https://keras.io/getting_started/#configuring-your-backend). For example, inside your Python script write:

bayesflow/diagnostics/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from .plot_losses import plot_losses
2+
from .plot_recovery import plot_recovery
3+
from .plot_sbc_ecdf import plot_sbc_ecdf
4+
from .plot_sbc_histograms import plot_sbc_histograms
5+
from .plot_samples_2d import plot_samples_2d
6+
from .plot_z_score_contraction import plot_z_score_contraction
7+
from .plot_prior_2d import plot_prior_2d
8+
from .plot_posterior_2d import plot_posterior_2d
9+
from .plot_calibration_curves import plot_calibration_curves
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
4+
from typing import Sequence
5+
from ..utils.comp_utils import expected_calibration_error
6+
from ..utils.plot_utils import preprocess, add_titles_and_labels, add_metric, prettify_subplots
7+
8+
9+
def plot_calibration_curves(
    post_model_samples: dict[str, np.ndarray] | np.ndarray,
    true_model_samples: dict[str, np.ndarray] | np.ndarray,
    names: Sequence[str] = None,
    num_bins: int = 10,
    label_fontsize: int = 16,
    title_fontsize: int = 18,
    metric_fontsize: int = 14,
    tick_fontsize: int = 12,
    epsilon: float = 0.02,
    figsize: Sequence[int] = None,
    color: str = "#132a70",
    num_col: int = None,
    num_row: int = None,
) -> plt.Figure:
    """Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities
    for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin.
    Depends on the ``expected_calibration_error`` function for computing the ECE.

    Parameters
    ----------
    post_model_samples : np.ndarray of shape (num_data_sets, num_models)
        The predicted posterior model probabilities (PMPs) per data set.
    true_model_samples : np.ndarray of shape (num_data_sets, num_models)
        The one-hot-encoded true model indices per data set.
    names             : list or None, optional, default: None
        The model names for nice plot titles. Inferred if None.
    num_bins          : int, optional, default: 10
        The number of bins to use for the calibration curves (and marginal histograms).
    label_fontsize    : int, optional, default: 16
        The font size of the x-label and y-label texts
    title_fontsize    : int, optional, default: 18
        The font size of the title text
    metric_fontsize   : int, optional, default: 14
        The font size of the metric annotation (ECE value)
    tick_fontsize     : int, optional, default: 12
        The font size of the axis ticklabels
    epsilon           : float, optional, default: 0.02
        A small amount to pad the [0, 1]-bounded axes from both sides.
    figsize           : tuple or None, optional, default: None
        The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
    color             : str, optional, default: '#132a70'
        The color of the calibration curves
    num_col           : int, optional, default: None
        The number of columns for the subplots. Dynamically determined if None.
    num_row           : int, optional, default: None
        The number of rows for the subplots. Dynamically determined if None.

    Returns
    -------
    fig : plt.Figure - the figure instance for optional saving
    """

    plot_data = preprocess(post_model_samples, true_model_samples, names, num_col, num_row, figsize, context="M")

    # Compute calibration
    cal_errors, true_probs, pred_probs = expected_calibration_error(
        plot_data["prior_samples"], plot_data["post_samples"], num_bins
    )

    # Histogram bins and weights do not depend on the model index, so compute them once.
    # The weights turn the bin counts into fractions of data sets.
    uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)
    norm_weights = np.ones_like(plot_data["post_samples"]) / len(plot_data["post_samples"])

    for j, ax in enumerate(plot_data["axes"].flat):
        # The subplot grid may hold more axes than there are models; surplus axes
        # have no calibration data. NOTE(review): assumes ``num_variables`` is the
        # number of models, as implied by its use as ``num_subplots`` below.
        if j >= plot_data["num_variables"]:
            break

        # Plot calibration curve (``ax`` is already the j-th Axes, so it must not be indexed)
        ax.plot(pred_probs[j], true_probs[j], "o-", color=color)

        # Plot PMP distribution over bins
        ax.hist(plot_data["post_samples"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)

        # Plot AB line
        ax.plot((0, 1), (0, 1), "--", color="black", alpha=0.9)

        # Tweak plot
        ax.set_xlim([0 - epsilon, 1 + epsilon])
        ax.set_ylim([0 - epsilon, 1 + epsilon])
        ax.set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
        ax.set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])

        # Add ECE label
        add_metric(
            ax,
            metric_text=r"$\widehat{{\mathrm{{ECE}}}}$ = {0:.3f}",
            metric_value=cal_errors[j],
            metric_fontsize=metric_fontsize,
        )

    # Prettify
    prettify_subplots(axes=plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)

    # Only add x-labels to the bottom row
    add_titles_and_labels(
        axes=plot_data["axes"],
        num_row=plot_data["num_row"],
        num_col=plot_data["num_col"],
        title=plot_data["names"],
        xlabel="Predicted Probability",
        ylabel="True Probability",
        title_fontsize=title_fontsize,
        label_fontsize=label_fontsize,
    )

    plot_data["fig"].tight_layout()
    return plot_data["fig"]
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import matplotlib.colors
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
5+
from typing import Sequence
6+
7+
from keras import ops
8+
from sklearn.metrics import confusion_matrix
9+
from matplotlib.colors import LinearSegmentedColormap
10+
11+
from bayesflow.utils.plot_utils import make_figure
12+
13+
14+
def plot_confusion_matrix(
    true_models: dict[str, np.ndarray] | np.ndarray,
    pred_models: dict[str, np.ndarray] | np.ndarray,
    model_names: Sequence[str] = None,
    fig_size: tuple = (5, 5),
    label_fontsize: int = 16,
    title_fontsize: int = 18,
    value_fontsize: int = 10,
    tick_fontsize: int = 12,
    xtick_rotation: int = None,
    ytick_rotation: int = None,
    normalize: bool = True,
    cmap: matplotlib.colors.Colormap | str = None,
    title: bool = True,
) -> plt.Figure:
    """Plots a confusion matrix for validating a neural network trained for Bayesian model comparison.

    Parameters
    ----------
    true_models    : np.ndarray of shape (num_data_sets, num_models)
        The one-hot-encoded true model indices per data set.
    pred_models    : np.ndarray of shape (num_data_sets, num_models)
        The predicted posterior model probabilities (PMPs) per data set.
    model_names    : list or None, optional, default: None
        The model names for nice plot titles. Inferred if None.
    fig_size       : tuple or None, optional, default: (5, 5)
        The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
    label_fontsize : int, optional, default: 16
        The font size of the x-label and y-label texts
    title_fontsize : int, optional, default: 18
        The font size of the title text.
    value_fontsize : int, optional, default: 10
        The font size of the text annotations and the colorbar tick labels.
    tick_fontsize  : int, optional, default: 12
        The font size of the axis tick labels (model names).
    xtick_rotation : int, optional, default: None
        Rotation of x-axis tick labels (helps with long model names).
    ytick_rotation : int, optional, default: None
        Rotation of y-axis tick labels (helps with long model names).
    normalize      : bool, optional, default: True
        A flag for normalization of the confusion matrix.
        If True, each row of the confusion matrix is normalized to sum to 1.
        Rows without any true instances are left as zeros.
    cmap           : matplotlib.colors.Colormap or str, optional, default: None
        Colormap to be used for the cells. If a str, it should be the name of a registered colormap,
        e.g., 'viridis'. The default colormap ranges from white to dark blue ('#132a70').
    title          : bool, optional, default True
        A flag for adding 'Confusion Matrix' above the matrix.

    Returns
    -------
    fig : plt.Figure - the figure instance for optional saving
    """

    if model_names is None:
        num_models = true_models.shape[-1]
        model_names = [rf"$M_{{{m}}}$" for m in range(1, num_models + 1)]

    if cmap is None:
        cmap = LinearSegmentedColormap.from_list("", ["white", "#132a70"])

    # Collapse one-hot / probability vectors to class indices. Convert to NumPy so that
    # scikit-learn and matplotlib never receive backend-specific tensors.
    true_idx = ops.convert_to_numpy(ops.argmax(true_models, axis=1))
    pred_idx = ops.convert_to_numpy(ops.argmax(pred_models, axis=1))

    # Pass the full label set explicitly: otherwise a model that never occurs in either
    # input would be dropped and the matrix would no longer line up with ``model_names``.
    cm = confusion_matrix(true_idx, pred_idx, labels=np.arange(len(model_names)))

    if normalize:
        # Row-normalize so each row (true model) sums to 1; rows with no samples stay zero
        # instead of producing NaNs from a division by zero.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm, row_sums, out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)

    # Initialize figure
    fig, ax = make_figure(1, 1, figsize=fig_size)
    im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
    cbar = ax.figure.colorbar(im, ax=ax, shrink=0.75)

    cbar.ax.tick_params(labelsize=value_fontsize)

    ax.set(xticks=np.arange(cm.shape[1]), yticks=np.arange(cm.shape[0]))
    ax.set_xticklabels(model_names, fontsize=tick_fontsize)
    if xtick_rotation:
        # Rotate the labels of this Axes explicitly rather than relying on pyplot's current axes
        plt.setp(ax.get_xticklabels(), rotation=xtick_rotation, ha="right")
    ax.set_yticklabels(model_names, fontsize=tick_fontsize)
    if ytick_rotation:
        plt.setp(ax.get_yticklabels(), rotation=ytick_rotation)
    ax.set_xlabel("Predicted model", fontsize=label_fontsize)
    ax.set_ylabel("True model", fontsize=label_fontsize)

    # Loop over data dimensions and create text annotations
    fmt = ".2f" if normalize else "d"
    # Cells darker than half the maximum get white text for contrast
    thresh = cm.max() / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(
                j,
                i,
                format(cm[i, j], fmt),
                fontsize=value_fontsize,
                ha="center",
                va="center",
                color="white" if cm[i, j] > thresh else "black",
            )
    if title:
        ax.set_title("Confusion Matrix", fontsize=title_fontsize)
    return fig
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import numpy as np
2+
import pandas as pd
3+
import seaborn as sns
4+
import matplotlib.pyplot as plt
5+
6+
from typing import Sequence
7+
from ..utils.plot_utils import make_figure, add_titles_and_labels
8+
9+
10+
def plot_losses(
    train_losses: pd.DataFrame | np.ndarray,
    val_losses: pd.DataFrame | np.ndarray = None,
    moving_average: bool = False,
    per_training_step: bool = False,
    ma_window_fraction: float = 0.01,
    figsize: Sequence[float] = None,
    train_color: str = "#132a70",
    val_color: str = "black",
    lw_train: float = 2.0,
    lw_val: float = 3.0,
    legend_fontsize: int = 14,
    label_fontsize: int = 14,
    title_fontsize: int = 16,
) -> plt.Figure:
    """
    A generic helper function to plot the losses of a series of training epochs
    and runs.

    Parameters
    ----------

    train_losses       : pd.DataFrame or np.ndarray
        The (plottable) history as returned by a train_[...] method of a
        ``Trainer`` instance. One subplot is drawn per column.
        Alternatively, you can just pass a data frame of validation losses
        instead of train losses, if you only want to plot the validation loss.
    val_losses         : pd.DataFrame, np.ndarray or None, optional, default: None
        The (plottable) validation history as returned by a train_[...] method
        of a ``Trainer`` instance.
        If left ``None``, only train losses are plotted. Should have the same
        number of columns as ``train_losses``; train columns without a
        validation counterpart are plotted without a validation curve.
    moving_average     : bool, optional, default: False
        A flag for adding a moving average line of the train_losses.
        The line is only drawn for a column named ``"Loss"``, so it is not
        drawn for array inputs, whose columns get integer names.
    per_training_step  : bool, optional, default: False
        A flag for making loss trajectory detailed (to training steps) rather than per epoch.
        Only affects the x-axis label.
    ma_window_fraction : float, optional, default: 0.01
        Window size for the moving average as a fraction of total
        training steps. The resulting window is at least one step wide.
    figsize            : tuple or None, optional, default: None
        The figure size passed to the ``matplotlib`` constructor.
        If ``None``, ``(16, 4 * number_of_columns)`` is used.
    train_color        : str, optional, default: '#132a70'
        The color for the train loss trajectory
    val_color          : str, optional, default: black
        The color for the optional validation loss trajectory
    lw_train           : float, optional, default: 2
        The linewidth for the training loss curve
    lw_val             : float, optional, default: 3
        The linewidth for the validation loss curve
    legend_fontsize    : int, optional, default: 14
        The font size of the legend text
    label_fontsize     : int, optional, default: 14
        The font size of the y-label text
    title_fontsize     : int, optional, default: 16
        The font size of the title text

    Returns
    -------
    f : plt.Figure - the figure instance for optional saving
    """
    if isinstance(train_losses, np.ndarray):
        train_losses = pd.DataFrame(train_losses)

    if isinstance(val_losses, np.ndarray):
        val_losses = pd.DataFrame(val_losses)

    # Determine the number of rows for plot
    num_row = len(train_losses.columns)

    # Initialize figure; a user-supplied figsize replaces the whole default tuple
    resolved_figsize = (16, int(4 * num_row)) if figsize is None else figsize
    fig, axes = make_figure(num_row=num_row, num_col=1, figsize=resolved_figsize)

    # Get the number of steps as an array
    train_step_index = np.arange(1, len(train_losses) + 1)
    if val_losses is not None:
        # Spread the validation points evenly over the training steps:
        # every ``val_step``-th training step carries one validation value.
        val_step = int(np.floor(len(train_losses) / len(val_losses)))
        val_step_index = train_step_index[(val_step - 1) :: val_step]

        # If unequal length due to some reason, attempt a fix
        if val_step_index.shape[0] > val_losses.shape[0]:
            val_step_index = val_step_index[: val_losses.shape[0]]

    # Loop through loss entries and populate plot
    looper = [axes] if num_row == 1 else axes.flat
    for i, ax in enumerate(looper):
        # Plot train curve
        ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
        if moving_average and train_losses.columns[i] == "Loss":
            # Clamp to at least one step: short histories would otherwise give a
            # zero-width window and an all-NaN (invisible) curve.
            moving_average_window = max(1, int(train_losses.shape[0] * ma_window_fraction))
            smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean()
            ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)")

        # Plot optional val curve
        if val_losses is not None:
            if i < val_losses.shape[1]:
                ax.plot(
                    val_step_index,
                    val_losses.iloc[:, i],
                    linestyle="--",
                    marker="o",
                    color=val_color,
                    lw=lw_val,
                    label="Validation",
                )

        sns.despine(ax=ax)
        ax.grid(alpha=0.5)

        # Only add a legend if a validation curve or a moving average was requested
        if val_losses is not None or moving_average:
            ax.legend(fontsize=legend_fontsize)

    # Schmuck
    add_titles_and_labels(
        axes=np.atleast_1d(axes),
        num_row=num_row,
        num_col=1,
        title=train_losses.columns if num_row > 1 else ["Training Loss"],
        xlabel="Training step #" if per_training_step else "Training epoch #",
        ylabel="Value",
        title_fontsize=title_fontsize,
        label_fontsize=label_fontsize,
    )

    fig.tight_layout()
    return fig

0 commit comments

Comments
 (0)