Skip to content

BUG: ZeroSumTransform fails with initvalues #7772

@velochy

Description

@velochy

Describe the issue:

Setting `initvals` for a variable that uses `ZeroSumTransform` fails with an `AttributeError` raised inside `ZeroSumTransform.extend_axis`.

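It seems to be caused by the initial value being passed in as a NumPy array rather than a PyTensor tensor. For a NumPy array, `array.shape[axis]` is a plain Python `int`, which has no `.astype` method.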

The fix seems simple; I will post a PR for it next.

Reproducible code example:

import pymc as pm, numpy as np with pm.Model() as model: pm.ZeroSumNormal('zsn',shape=(10,)) pm.Normal('n', shape=(10,), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[0])) mp = pm.find_MAP() pm.sample(initvals=mp)

Error message:

--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) /home/velochy/salk/sandbox/sandy.ipynb Cell 1 line 8 5 pm.Normal('n', shape=(10,), transform=pm.distributions.transforms.ZeroSumTransform(zerosum_axes=[0])) 6 mp = pm.find_MAP() ----> 8 pm.sample(initvals=mp) File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:832, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs) 830 [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()] 831 with joined_blas_limiter(): --> 832 initial_points, step = init_nuts( 833 init=init, 834 chains=chains, 835 n_init=n_init, 836 model=model, 837 random_seed=random_seed_list, 838 progressbar=progress_bool, 839 jitter_max_retries=jitter_max_retries, 840 tune=tune, 841 initvals=initvals, 842 compile_kwargs=compile_kwargs, 843 **kwargs, 844 ) 845 else: 846 # Get initial points 847 ipfns = make_initial_point_fns_per_chain( 848 model=model, 849 overrides=initvals, 850 jitter_rvs=set(), 851 chains=chains, 852 ) File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1605, in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, compile_kwargs, **kwargs) 1602 q, _ = DictToArrayBijection.map(ip) 1603 return logp_dlogp_func([q], extra_vars={})[0] -> 1605 initial_points = _init_jitter( 1606 model, 1607 initvals, 1608 seeds=random_seed_list, 1609 jitter="jitter" in init, 1610 jitter_max_retries=jitter_max_retries, 1611 logp_fn=model_logp_fn, 1612 ) 1614 apoints = [DictToArrayBijection.map(point) for point in initial_points] 1615 apoints_data = [apoint.data for apoint in apoints] File 
~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/sampling/mcmc.py:1462, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries, logp_fn) 1432 def _init_jitter( 1433 model: Model, 1434 initvals: StartDict | Sequence[StartDict | None] | None, (...) 1438 logp_fn: Callable[[PointType], np.ndarray] | None = None, 1439 ) -> list[PointType]: 1440 """Apply a uniform jitter in [-1, 1] to the test value as starting point in each chain.  1441   1442 ``model.check_start_vals`` is used to test whether the jittered starting  (...)  1460 List of starting points for the sampler  1461 """ -> 1462 ipfns = make_initial_point_fns_per_chain( 1463 model=model, 1464 overrides=initvals, 1465 jitter_rvs=set(model.free_RVs) if jitter else set(), 1466 chains=len(seeds), 1467 ) 1469 if not jitter: 1470 return [ipfn(seed) for ipfn, seed in zip(ipfns, seeds)] File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:101, in make_initial_point_fns_per_chain(model, overrides, jitter_rvs, chains) 72 """Create an initial point function for each chain, as defined by initvals.  73   74 If a single initval dictionary is passed, the function is replicated for each  (...)  95   96 """ 97 if isinstance(overrides, dict) or overrides is None: 98 # One strategy for all chains 99 # Only one function compilation is needed. 100 ipfns = [ --> 101 make_initial_point_fn( 102 model=model, 103 overrides=overrides, 104 jitter_rvs=jitter_rvs, 105 return_transformed=True, 106 ) 107 ] * chains 108 elif len(overrides) == chains: 109 ipfns = [ 110 make_initial_point_fn( 111 model=model, (...) 116 for chain_overrides in overrides 117 ] File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:152, in make_initial_point_fn(model, overrides, jitter_rvs, default_strategy, return_transformed) 126 def make_initial_point_fn( 127 *, 128 model, (...) 
132 return_transformed: bool = True, 133 ) -> Callable[[SeedSequenceSeed], PointType]: 134 """Create seeded function that computes initial values for all free model variables.  135   136 Parameters  (...)  150 initial_point_fn : Callable[[SeedSequenceSeed], dict[str, np.ndarray]]  151 """ --> 152 sdict_overrides = convert_str_to_rv_dict(model, overrides or {}) 153 initval_strats = { 154 **model.rvs_to_initial_values, 155 **sdict_overrides, 156 } 158 initial_values = make_initial_point_expression( 159 free_rvs=model.free_RVs, 160 rvs_to_transforms=model.rvs_to_transforms, (...) 164 return_transformed=return_transformed, 165 ) File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/initial_point.py:57, in convert_str_to_rv_dict(model, start) 55 if is_transformed_name(key): 56 rv = model[get_untransformed_name(key)] ---> 57 initvals[rv] = model.rvs_to_transforms[rv].backward(initval, *rv.owner.inputs) 58 else: 59 initvals[model[key]] = initval File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/distributions/transforms.py:309, in ZeroSumTransform.backward(self, value, *rv_inputs) 307 def backward(self, value, *rv_inputs): 308 for axis in self.zerosum_axes: --> 309 value = self.extend_axis(value, axis=axis) 310 return value File ~/miniconda3/envs/salk/lib/python3.12/site-packages/pymc/distributions/transforms.py:281, in ZeroSumTransform.extend_axis(array, axis) 279 @staticmethod 280 def extend_axis(array, axis): --> 281 n = (array.shape[axis] + 1).astype("floatX") 282 sum_vals = array.sum(axis, keepdims=True) 283 norm = sum_vals / (pt.sqrt(n) + n) AttributeError: 'int' object has no attribute 'astype'

PyMC version information:

pymc 5.22.0

Context for the issue:

I wanted to experiment with setting `initvals` from the results of MAP estimation (`pm.find_MAP`) and Pathfinder, and ran into this issue.

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions