Skip to content

Commit 0c1c9bc

Browse files
Diagnostics: partitioning filtering and naming (#260)
* Migrate diagnostics module * Add confusion_matrix * Remove duplicate code for layout configuration * Simplify pre-processing for plot_recovery and plot_z_score_contraction * Simplify prettify process for z_score and recovery * Simplify preprocessing for plot_sbc_ecdf and plot_sbc_histograms * Simplify labeling * Reformat * Make plot_distribution_2d more compatible * Update quickstart notebook with working prior checks * Improve compatibility for plot_losses and plot_prior_2d * Update quickstart notebook with loss trajectory * Minor fix of plot_utils, start adding tests for diagnostics * Pre-final version WIP * Minor changes in 2d plots; update and test plot_z_score_contraction * Update and test plot_sbc_histograms * Update plot_calibration_curves (WIP) * Minor refactors: change global color schemes, complete type casts, further simplify plot_losses * Add detailed callback for loss trajectory * Generalize preprocessing utilities from samples to variables * Generalize add_metric * Remove redundant code segments related to prettify * Include add_titles and add_titles_and_labels; propagate variables as samples * Interim cleanup * Add typing and fix plot_sbc_ecdf * Minor fix of plot_samples_2d * Minor fix of plot_prior_2d * Remove redundant code for axes flattening * Ensure consistent color scheme; incorporate sequence of labels * Bug fix for plot_losses * Cleanup * Partition filtering and renaming in dicts_to_arrays * Propagate filter keys and variable names * Rename all 'names' to 'variable_names' * Getting rid of test_diagnostics (for now) to make sure that tests are passing * Minor bugfix to plot_samples_2d and plot_z_score_contraction based on partitioning --------- Co-authored-by: stefanradev93 <stefan.radev93@gmail.com>
1 parent fcc78a8 commit 0c1c9bc

File tree

9 files changed

+455
-420
lines changed

9 files changed

+455
-420
lines changed

bayesflow/diagnostics/plot_calibration_curves.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def plot_calibration_curves(
102102
axes=plot_data["axes"],
103103
num_row=plot_data["num_row"],
104104
num_col=plot_data["num_col"],
105-
title=plot_data["names"],
105+
title=plot_data["variable_names"],
106106
xlabel="Predicted Probability",
107107
ylabel="True Probability",
108108
title_fontsize=title_fontsize,

bayesflow/diagnostics/plot_recovery.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
def plot_recovery(
1212
post_samples: dict[str, np.ndarray] | np.ndarray,
1313
prior_samples: dict[str, np.ndarray] | np.ndarray,
14+
filter_keys: Sequence[str] = None,
1415
variable_names: Sequence[str] = None,
1516
point_agg=np.median,
1617
uncertainty_agg=median_abs_deviation,
@@ -59,7 +60,7 @@ def plot_recovery(
5960
"""
6061

6162
# Gather plot data and metadata into a dictionary
62-
plot_data = preprocess(post_samples, prior_samples, variable_names, num_col, num_row, figsize)
63+
plot_data = preprocess(post_samples, prior_samples, filter_keys, variable_names, num_col, num_row, figsize)
6364
plot_data["post_samples"] = plot_data.pop("post_variables")
6465
plot_data["prior_samples"] = plot_data.pop("prior_variables")
6566

@@ -93,7 +94,7 @@ def plot_recovery(
9394
corr = np.corrcoef(plot_data["prior_samples"][:, i], point_estimate[:, i])[0, 1]
9495
add_metric(ax=ax, metric_text="$r$", metric_value=corr, metric_fontsize=metric_fontsize)
9596

96-
ax.set_title(plot_data["names"][i], fontsize=title_fontsize)
97+
ax.set_title(plot_data["variable_names"][i], fontsize=title_fontsize)
9798

9899
# Add custom schmuck
99100
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)

bayesflow/diagnostics/plot_samples_2d.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
import seaborn as sns
33
import pandas as pd
44

5+
from typing import Sequence
56
from bayesflow.utils import logging
7+
from bayesflow.utils.dict_utils import dicts_to_arrays
68

79

810
def plot_samples_2d(
9-
samples: np.ndarray = None,
11+
samples: dict[str, np.ndarray] | np.ndarray = None,
12+
filter_keys: Sequence[str] = None,
1013
context: str = None,
1114
variable_names: list = None,
1215
height: float = 2.5,
@@ -41,7 +44,11 @@ def plot_samples_2d(
4144
Additional keyword arguments passed to the sns.PairGrid constructor
4245
"""
4346

44-
dim = samples.shape[-1]
47+
plot_data = dicts_to_arrays(
48+
post_variables=samples, filter_keys=filter_keys, variable_names=variable_names, context=context
49+
)
50+
51+
dim = plot_data["post_variables"].shape[-1]
4552
if context is None:
4653
context = "Default"
4754

@@ -52,7 +59,10 @@ def plot_samples_2d(
5259
titles = [f"{context} {p}" for p in variable_names]
5360

5461
# Convert samples to pd.DataFrame
55-
data_to_plot = pd.DataFrame(samples, columns=titles)
62+
if context == "Posterior":
63+
data_to_plot = pd.DataFrame(plot_data["post_variables"][0], columns=titles)
64+
else:
65+
data_to_plot = pd.DataFrame(plot_data["post_variables"], columns=titles)
5666

5767
# Generate plots
5868
artist = sns.PairGrid(data_to_plot, height=height, **kwargs)

bayesflow/diagnostics/plot_sbc_ecdf.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
def plot_sbc_ecdf(
1010
post_samples: dict[str, np.ndarray] | np.ndarray,
1111
prior_samples: dict[str, np.ndarray] | np.ndarray,
12+
filter_keys: Sequence[str] = None,
1213
variable_names: Sequence[str] = None,
1314
difference: bool = False,
1415
stacked: bool = False,
@@ -92,7 +93,9 @@ def plot_sbc_ecdf(
9293
"""
9394

9495
# Preprocessing
95-
plot_data = preprocess(post_samples, prior_samples, variable_names, num_col, num_row, figsize, stacked=stacked)
96+
plot_data = preprocess(
97+
post_samples, prior_samples, filter_keys, variable_names, num_col, num_row, figsize, stacked=stacked
98+
)
9699
plot_data["post_samples"] = plot_data.pop("post_variables")
97100
plot_data["prior_samples"] = plot_data.pop("prior_variables")
98101

@@ -129,7 +132,7 @@ def plot_sbc_ecdf(
129132
ylab = "ECDF"
130133

131134
# Add simultaneous bounds
132-
titles = plot_data["names"] if not stacked else ["Stacked ECDFs"]
135+
titles = plot_data["variable_names"] if not stacked else ["Stacked ECDFs"]
133136
for ax, title in zip(plot_data["axes"].flat, titles):
134137
ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands")
135138
ax.legend(fontsize=legend_fontsize)

bayesflow/diagnostics/plot_sbc_histograms.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
def plot_sbc_histograms(
1313
post_samples: dict[str, np.ndarray] | np.ndarray,
1414
prior_samples: dict[str, np.ndarray] | np.ndarray,
15+
filter_keys: Sequence[str] = None,
1516
variable_names: Sequence[str] = None,
1617
figsize: Sequence[float] = None,
1718
num_bins: int = 10,
@@ -71,11 +72,11 @@ def plot_sbc_histograms(
7172
"""
7273

7374
# Preprocessing
74-
plot_data = preprocess(post_samples, prior_samples, num_col, num_row, variable_names, figsize)
75+
plot_data = preprocess(post_samples, prior_samples, filter_keys, variable_names, num_col, num_row, figsize=figsize)
7576
plot_data["post_samples"] = plot_data.pop("post_variables")
7677
plot_data["prior_samples"] = plot_data.pop("prior_variables")
7778

78-
# Determine the ratio of simulations to prior draws
79+
# Determine the ratio of simulations to prior draw
7980
# num_params = plot_data['num_variables']
8081
num_sims = plot_data["post_samples"].shape[0]
8182
num_draws = plot_data["post_samples"].shape[1]
@@ -119,7 +120,7 @@ def plot_sbc_histograms(
119120
axes=plot_data["axes"],
120121
num_row=plot_data["num_row"],
121122
num_col=plot_data["num_col"],
122-
title=plot_data["names"],
123+
title=plot_data["variable_names"],
123124
xlabel="Rank statistic",
124125
ylabel="",
125126
title_fontsize=title_fontsize,

bayesflow/diagnostics/plot_z_score_contraction.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
def plot_z_score_contraction(
1010
post_samples: dict[str, np.ndarray] | np.ndarray,
1111
prior_samples: dict[str, np.ndarray] | np.ndarray,
12+
filter_keys: Sequence[str] = None,
1213
variable_names: Sequence[str] = None,
1314
figsize: Sequence[int] = None,
1415
label_fontsize: int = 16,
@@ -84,7 +85,7 @@ def plot_z_score_contraction(
8485
"""
8586

8687
# Preprocessing
87-
plot_data = preprocess(post_samples, prior_samples, variable_names, num_col, num_row, figsize)
88+
plot_data = preprocess(post_samples, prior_samples, filter_keys, variable_names, num_col, num_row, figsize)
8889
plot_data["post_samples"] = plot_data.pop("post_variables")
8990
plot_data["prior_samples"] = plot_data.pop("prior_variables")
9091

@@ -98,7 +99,7 @@ def plot_z_score_contraction(
9899

99100
# Compute contraction and z-score
100101
contraction = 1 - (post_vars / prior_vars)
101-
z_score = (post_means - prior_samples) / post_stds
102+
z_score = (post_means - plot_data["prior_samples"]) / post_stds
102103

103104
# Loop and plot
104105
for i, ax in enumerate(plot_data["axes"].flat):
@@ -115,7 +116,7 @@ def plot_z_score_contraction(
115116
axes=plot_data["axes"],
116117
num_row=plot_data["num_row"],
117118
num_col=plot_data["num_col"],
118-
title=plot_data["names"],
119+
title=plot_data["variable_names"],
119120
xlabel="Posterior contraction",
120121
ylabel="Posterior z-score",
121122
title_fontsize=title_fontsize,

bayesflow/utils/dict_utils.py

Lines changed: 40 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -107,36 +107,56 @@ def split_tensors(data: Mapping[any, Tensor], axis: int = -1) -> Mapping[any, Te
107107

108108
def dicts_to_arrays(
109109
post_variables: dict[str, np.ndarray] | np.ndarray,
110-
prior_variables: dict[str, np.ndarray] | np.ndarray,
111-
names: Sequence[str] = None,
110+
prior_variables: dict[str, np.ndarray] | np.ndarray = None,
111+
filter_keys: Sequence[str] | None = None,
112+
variable_names: Sequence[str] = None,
112113
context: str = None,
113114
):
114-
"""Utility to optionally convert dicts as returned from approximators and adapters into arrays."""
115+
"""
116+
# TODO
117+
"""
115118

116-
if type(post_variables) is not type(prior_variables):
117-
raise ValueError("You should either use dicts or tensors, but not separate types for your inputs.")
119+
# Ensure that posterior and prior variables have the same type
120+
if prior_variables is not None:
121+
if type(post_variables) is not type(prior_variables):
122+
raise ValueError("You should either use dicts or tensors, but not separate types for your inputs.")
118123

124+
# Filtering
119125
if isinstance(post_variables, dict):
120-
if post_variables.keys() != prior_variables.keys():
121-
raise ValueError("Keys in your posterior / prior arrays should match.")
126+
# Ensure that the keys of selected posterior and prior variables match
127+
if prior_variables is not None:
128+
if not (set(post_variables) <= set(prior_variables)):
129+
raise ValueError("Keys in your posterior / prior arrays should match.")
122130

123-
# Use user-provided names instead of inferred ones
124-
names = list(post_variables.keys()) if names is None else names
131+
# If they match, users can further select the variables by using filter keys
132+
filter_keys = list(post_variables.keys()) if filter_keys is None else filter_keys
125133

126-
post_variables = np.concatenate([v for k, v in post_variables.items() if k in names], axis=-1)
127-
prior_variables = np.concatenate([v for k, v in prior_variables.items() if k in names], axis=-1)
134+
# The variables will then be overridden with the filtered keys
135+
post_variables = np.concatenate([v for k, v in post_variables.items() if k in filter_keys], axis=-1)
136+
if prior_variables is not None:
137+
prior_variables = np.concatenate([v for k, v in prior_variables.items() if k in filter_keys], axis=-1)
128138

139+
# Naming or Renaming
129140
elif isinstance(post_variables, np.ndarray):
130-
if names is not None:
131-
if post_variables.shape[-1] != len(names) or prior_variables.shape[-1] != len(names):
132-
raise ValueError("The length of the names list should match the number of target variables.")
133-
else:
134-
if context is not None:
135-
names = [f"${context}_{{{i}}}$" for i in range(post_variables.shape[-1])]
141+
# If there are filter_keys, check if their number is the same as that of the variables.
142+
# If it does, check if there are sufficient variable names.
143+
# If there are, then the variable names are adopted.
144+
if variable_names is not None:
145+
if post_variables.shape[-1] != len(variable_names) or prior_variables.shape[-1] != len(variable_names):
146+
raise ValueError("The number of variable names should match the number of target variables.")
147+
148+
else: # Otherwise, we would assume that all variables are used for plotting.
149+
if context is None:
150+
if variable_names is None:
151+
variable_names = [f"$\\theta_{{{i}}}$" for i in range(post_variables.shape[-1])]
136152
else:
137-
names = [f"$\\theta_{{{i}}}$" for i in range(post_variables.shape[-1])]
138-
153+
variable_names = [f"${context}_{{{i}}}$" for i in range(post_variables.shape[-1])]
139154
else:
140155
raise TypeError("Only dicts and tensors are supported as arguments.")
141156

142-
return dict(post_variables=post_variables, prior_variables=prior_variables, names=names, num_variables=len(names))
157+
return dict(
158+
post_variables=post_variables,
159+
prior_variables=prior_variables,
160+
variable_names=variable_names,
161+
num_variables=len(variable_names),
162+
)

bayesflow/utils/plot_utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99

1010

1111
def preprocess(
12-
post_variables: dict[str, np.ndarray],
13-
prior_variables: dict[str, np.ndarray],
14-
names: Sequence[str] = None,
12+
post_variables: dict[str, np.ndarray] | np.ndarray,
13+
prior_variables: dict[str, np.ndarray] | np.ndarray,
14+
filter_keys: Sequence[str] = None,
15+
variable_names: Sequence[str] = None,
1516
context: str = None,
1617
num_col: int = None,
1718
num_row: int = None,
@@ -43,7 +44,13 @@ def preprocess(
4344
Whether the plots are stacked horizontally
4445
"""
4546

46-
plot_data = dicts_to_arrays(post_variables, prior_variables, names, context)
47+
plot_data = dicts_to_arrays(
48+
post_variables=post_variables,
49+
prior_variables=prior_variables,
50+
filter_keys=filter_keys,
51+
variable_names=variable_names,
52+
context=context,
53+
)
4754
check_posterior_prior_shapes(plot_data["post_variables"], plot_data["prior_variables"])
4855

4956
# Configure layout

0 commit comments

Comments
 (0)