- Notifications
You must be signed in to change notification settings - Fork 2.2k
Open
Labels
Description
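Describe the issue:
Setting `initvals` on variables that use `ZeroSumTransform` fails during a type cast with `AttributeError: 'int' object has no attribute 'astype'` (see the traceback below).
It seems to be caused by the input being a NumPy array rather than a PyTensor tensor: in `ZeroSumTransform.extend_axis`, `array.shape[axis]` is then a plain Python `int`, which has no `.astype` method.
The fix seems simple; I will post a PR for it next.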
Reproducible code example:
```python
import pymc as pm, numpy as np

with pm.Model() as model:
    pm.ZeroSumNormal('zsn',shape=(10,))
    pm.Normal('n', shape=(10,), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[0]))
    mp = pm.find_MAP()

    pm.sample(initvals=mp)
```
Error message:
--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) /home/velochy/salk/sandbox/sandy.ipynb Cell 1 line 8 5 pm.Normal('n', shape=(10,), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[0])) 6 mp = pm.find_MAP() ----> 8 pm.sample(initvals=mp) File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:832, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs) 830 [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()] 831 with joined_blas_limiter(): --> 832 initial_points, step = init_nuts( 833 init=init, 834 chains=chains, 835 n_init=n_init, 836 model=model, 837 random_seed=random_seed_list, 838 progressbar=progress_bool, 839 jitter_max_retries=jitter_max_retries, 840 tune=tune, 841 initvals=initvals, 842 compile_kwargs=compile_kwargs, 843 **kwargs, 844 ) 845 else: 846 # Get initial points 847 ipfns = make_initial_point_fns_per_chain( 848 model=model, 849 overrides=initvals, 850 jitter_rvs=set(), 851 chains=chains, 852 ) File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1605, in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, compile_kwargs, **kwargs) 1602 q, _ = DictToArrayBijection.map(ip) 1603 return logp_dlogp_func([q], extra_vars={})[0] -> 1605 initial_points = _init_jitter( 1606 model, 1607 initvals, 1608 seeds=random_seed_list, 1609 jitter="jitter" in init, 1610 jitter_max_retries=jitter_max_retries, 1611 logp_fn=model_logp_fn, 1612 ) 1614 apoints = [DictToArrayBijection.map(point) for point in initial_points] 1615 apoints_data = [apoint.data for apoint in apoints] File 
~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1462, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries, logp_fn) 1432 def _init_jitter( 1433 model: Model, 1434 initvals: StartDict | Sequence[StartDict | None] | None, (...) 1438 logp_fn: Callable[[PointType], np.ndarray] | None = None, 1439 ) -> list[PointType]: 1440 """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain. 1441 1442 ``model.check_start_vals`` is used to test whether the jittered starting (...) 1460 List of starting points for the sampler 1461 """ -> 1462 ipfns = make_initial_point_fns_per_chain( 1463 model=model, 1464 overrides=initvals, 1465 jitter_rvs=set(model.free_RVs) if jitter else set(), 1466 chains=len(seeds), 1467 ) 1469 if not jitter: 1470 return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)] File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:101, in make_initial_point_fns_per_chain(model, overrides, jitter_rvs, chains) 72 """Create an initial point function for each chain, as defined by initvals. 73 74 If a single initval dictionary is passed, the function is replicated for each (...) 95 96 """ 97 if isinstance(overrides, dict) or overrides is None: 98 # One strategy for all chains 99 # Only one function compilation is needed. 100 ipfns = [ --> 101 make_initial_point_fn( 102 model=model, 103 overrides=overrides, 104 jitter_rvs=jitter_rvs, 105 return_transformed=True, 106 ) 107 ] * chains 108 elif len(overrides) == chains: 109 ipfns = [ 110 make_initial_point_fn( 111 model=model, (...) 116 for chain_overrides in overrides 117 ] File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:152, in make_initial_point_fn(model, overrides, jitter_rvs, default_strategy, return_transformed) 126 def make_initial_point_fn( 127 *, 128 model, (...) 
132 return_transformed: bool = True, 133 ) -> Callable[[SeedSequenceSeed], PointType]: 134 """Create seeded function that computes initial values for all free model variables. 135 136 Parameters (...) 150 initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]] 151 """ --> 152 sdict_overrides = convert_str_to_rv_dict(model, overrides or {}) 153 initval_strats = { 154 **model.rvs_to_initial_values, 155 **sdict_overrides, 156 } 158 initial_values = make_initial_point_expression( 159 free_rvs=model.free_RVs, 160 rvs_to_transforms=model.rvs_to_transforms, (...) 164 return_transformed=return_transformed, 165 ) File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:57, in convert_str_to_rv_dict(model, start) 55 if is_transformed_name(key): 56 rv = model[get_untransformed_name(key)] ---> 57 initvals[rv] = model.rvs_to_transforms[rv].backward(initval, *rv.owner.inputs) 58 else: 59 initvals[model[key]] = initval File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/distributions/transforms.py:309, in ZeroSumTransform.backward(self, value, *rv_inputs) 307 def backward(self, value, *rv_inputs): 308 for axis in self.zerosum_axes: --> 309 value = self.extend_axis(value, axis=axis) 310 return value File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/distributions/transforms.py:281, in ZeroSumTransform.extend_axis(array, axis) 279 @staticmethod 280 def extend_axis(array, axis): --> 281 n = (array.shape[axis] + 1).astype("floatX") 282 sum_vals = array.sum(axis, keepdims=True) 283 norm = sum_vals / (pt.sqrt(n) + n) AttributeError: 'int' object has no attribute 'astype'
PyMC version information:
pymc 5.22.0
Context for the issue:
I wanted to experiment with setting `initvals` from MAP estimates and Pathfinder results, and ran into this issue.