|
| 1 | +import numpy as np |
| 2 | +import matplotlib.pyplot as plt |
| 3 | + |
| 4 | +from typing import Sequence |
| 5 | +from ..utils.comp_utils import expected_calibration_error |
| 6 | +from ..utils.plot_utils import preprocess, add_titles_and_labels, add_metric, prettify_subplots |
| 7 | + |
| 8 | + |
| 9 | +def plot_calibration_curves( |
| 10 | + post_model_samples: dict[str, np.ndarray] | np.ndarray, |
| 11 | + true_model_samples: dict[str, np.ndarray] | np.ndarray, |
| 12 | + names: Sequence[str] = None, |
| 13 | + num_bins: int = 10, |
| 14 | + label_fontsize: int = 16, |
| 15 | + title_fontsize: int = 18, |
| 16 | + metric_fontsize: int = 14, |
| 17 | + tick_fontsize: int = 12, |
| 18 | + epsilon: float = 0.02, |
| 19 | + figsize: Sequence[int] = None, |
| 20 | + color: str = "#132a70", |
| 21 | + num_col: int = None, |
| 22 | + num_row: int = None, |
| 23 | +) -> plt.Figure: |
| 24 | + """Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities |
| 25 | + for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin. |
| 26 | + Depends on the ``expected_calibration_error`` function for computing the ECE. |
| 27 | +
|
| 28 | + Parameters |
| 29 | + ---------- |
| 30 | + true_model_samples : np.ndarray of shape (num_data_sets, num_models) |
| 31 | + The one-hot-encoded true model indices per data set. |
| 32 | + post_model_samples : np.ndarray of shape (num_data_sets, num_models) |
| 33 | + The predicted posterior model probabilities (PMPs) per data set. |
| 34 | + names : list or None, optional, default: None |
| 35 | + The model names for nice plot titles. Inferred if None. |
| 36 | + num_bins : int, optional, default: 10 |
| 37 | + The number of bins to use for the calibration curves (and marginal histograms). |
| 38 | + label_fontsize : int, optional, default: 16 |
| 39 | + The font size of the y-label and y-label texts |
| 40 | + legend_fontsize : int, optional, default: 14 |
| 41 | + The font size of the legend text (ECE value) |
| 42 | + title_fontsize : int, optional, default: 18 |
| 43 | + The font size of the title text. Only relevant if `stacked=False` |
| 44 | + tick_fontsize : int, optional, default: 12 |
| 45 | + The font size of the axis ticklabels |
| 46 | + epsilon : float, optional, default: 0.02 |
| 47 | + A small amount to pad the [0, 1]-bounded axes from both side. |
| 48 | + figsize : tuple or None, optional, default: None |
| 49 | + The figure size passed to the ``matplotlib`` constructor. Inferred if ``None`` |
| 50 | + color : str, optional, default: '#8f2727' |
| 51 | + The color of the calibration curves |
| 52 | + num_row : int, optional, default: None |
| 53 | + The number of rows for the subplots. Dynamically determined if None. |
| 54 | + num_col : int, optional, default: None |
| 55 | + The number of columns for the subplots. Dynamically determined if None. |
| 56 | +
|
| 57 | + Returns |
| 58 | + ------- |
| 59 | + fig : plt.Figure - the figure instance for optional saving |
| 60 | + """ |
| 61 | + |
| 62 | + plot_data = preprocess(post_model_samples, true_model_samples, names, num_col, num_row, figsize, context="M") |
| 63 | + |
| 64 | + # Compute calibration |
| 65 | + cal_errors, true_probs, pred_probs = expected_calibration_error( |
| 66 | + plot_data["prior_samples"], plot_data["post_samples"], num_bins |
| 67 | + ) |
| 68 | + |
| 69 | + for j, ax in enumerate(plot_data["axes"].flat): |
| 70 | + # Plot calibration curve |
| 71 | + ax[j].plot(pred_probs[j], true_probs[j], "o-", color=color) |
| 72 | + |
| 73 | + # Plot PMP distribution over bins |
| 74 | + uniform_bins = np.linspace(0.0, 1.0, num_bins + 1) |
| 75 | + norm_weights = np.ones_like(plot_data["post_samples"]) / len(plot_data["post_samples"]) |
| 76 | + ax[j].hist( |
| 77 | + plot_data["post_samples"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3 |
| 78 | + ) |
| 79 | + |
| 80 | + # Plot AB line |
| 81 | + ax[j].plot((0, 1), (0, 1), "--", color="black", alpha=0.9) |
| 82 | + |
| 83 | + # Tweak plot |
| 84 | + ax[j].set_xlim([0 - epsilon, 1 + epsilon]) |
| 85 | + ax[j].set_ylim([0 - epsilon, 1 + epsilon]) |
| 86 | + ax[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) |
| 87 | + ax[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) |
| 88 | + |
| 89 | + # Add ECE label |
| 90 | + add_metric( |
| 91 | + ax[j], |
| 92 | + metric_text=r"$\widehat{{\mathrm{{ECE}}}}$ = {0:.3f}", |
| 93 | + metric_value=cal_errors[j], |
| 94 | + metric_fontsize=metric_fontsize, |
| 95 | + ) |
| 96 | + |
| 97 | + # Prettify |
| 98 | + prettify_subplots(axes=plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize) |
| 99 | + |
| 100 | + # Only add x-labels to the bottom row |
| 101 | + add_titles_and_labels( |
| 102 | + axes=plot_data["axes"], |
| 103 | + num_row=plot_data["num_row"], |
| 104 | + num_col=plot_data["num_col"], |
| 105 | + title=plot_data["names"], |
| 106 | + xlabel="Predicted Probability", |
| 107 | + ylabel="True Probability", |
| 108 | + title_fontsize=title_fontsize, |
| 109 | + label_fontsize=label_fontsize, |
| 110 | + ) |
| 111 | + |
| 112 | + plot_data["fig"].tight_layout() |
| 113 | + return plot_data["fig"] |
0 commit comments