Skip to content

Commit 25f5c64

Browse files
committed
Update dispatching distributions
1 parent 206d706 commit 25f5c64

File tree

3 files changed

+27
-6
lines changed

3 files changed

+27
-6
lines changed

bayesflow/networks/inference_network.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import keras
22

3-
from bayesflow.distributions import find_distribution
43
from bayesflow.types import Shape, Tensor
5-
from bayesflow.utils import layer_kwargs
4+
from bayesflow.utils import layer_kwargs, find_distribution
65
from bayesflow.utils.decorators import allow_batch_size
76

87

bayesflow/networks/summary_network.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import keras
22

3-
from bayesflow.distributions import find_distribution
43
from bayesflow.metrics.functional import maximum_mean_discrepancy
54
from bayesflow.types import Tensor
6-
from bayesflow.utils import layer_kwargs
5+
from bayesflow.utils import layer_kwargs, find_distribution
76
from bayesflow.utils.decorators import sanitize_input_shape
87
from bayesflow.utils.serialization import deserialize
98

bayesflow/utils/__init__.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@
77
logging,
88
numpy_utils,
99
)
10+
1011
from .callbacks import detailed_loss_callback
12+
1113
from .devices import devices
14+
1215
from .dict_utils import (
1316
convert_args,
1417
convert_kwargs,
@@ -20,30 +23,48 @@
2023
split_arrays,
2124
squeeze_inner_estimates_dict,
2225
)
23-
from .dispatch import find_network, find_permutation, find_pooling, find_recurrent_net
26+
27+
from .dispatch import (
28+
find_network,
29+
find_permutation,
30+
find_pooling,
31+
find_recurrent_net,
32+
find_summary_network,
33+
find_inference_network,
34+
find_distribution,
35+
)
36+
2437
from .ecdf import simultaneous_ecdf_bands, ranks
38+
2539
from .functional import batched_call
40+
2641
from .git import (
2742
issue_url,
2843
pull_url,
2944
repo_url,
3045
)
46+
3147
from .hparam_utils import find_batch_size, find_memory_budget
48+
3249
from .integrate import (
3350
integrate,
3451
)
52+
3553
from .io import (
3654
pickle_load,
3755
format_bytes,
3856
parse_bytes,
3957
)
58+
4059
from .jacobian import (
4160
jacobian,
4261
jacobian_trace,
4362
jvp,
4463
vjp,
4564
)
65+
4666
from .optimal_transport import optimal_transport
67+
4768
from .plot_utils import (
4869
check_estimates_prior_shapes,
4970
prepare_plot_data,
@@ -53,6 +74,7 @@
5374
add_metric,
5475
)
5576
from .serialization import serialize_value_or_type, deserialize_value_or_type
77+
5678
from .tensor_utils import (
5779
concatenate_valid,
5880
expand,
@@ -75,9 +97,10 @@
7597
fill_triangular_matrix,
7698
weighted_mean,
7799
)
100+
78101
from .classification import calibration_curve, confusion_matrix
102+
79103
from .validators import check_lengths_same
80-
from .workflow_utils import find_inference_network, find_summary_network
81104

82105
from ._docs import _add_imports_to_all
83106

0 commit comments

Comments
 (0)