No JAX dispatch for mul_without_zeros #526

@jessegrabowski

Description

Describe the issue:

The ProdWithoutZeros Op arises in the gradient of pt.prod. That gradient graph currently cannot be compiled in JAX mode unless we explicitly pass no_zeros_in_input=True. I guess we would just need a JAX dispatch for this Op? Or maybe a mapping to the correct jax.lax function?
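For reference, here is a minimal sketch of what such a dispatch might compute, written in plain JAX. It assumes that the reduction should treat zeros as the multiplicative identity (i.e. skip them); the function name `prod_without_zeros` is hypothetical and is not an existing PyTensor or JAX API.

```python
import jax.numpy as jnp


def prod_without_zeros(x, axis=None):
    """Product of the entries of `x`, ignoring zeros.

    Assumption: zeros are replaced by 1 (the multiplicative identity)
    before reducing, so they do not affect the result.
    """
    x = jnp.asarray(x)
    # Swap every zero for 1 so it drops out of the product.
    safe_x = jnp.where(x == 0, jnp.ones_like(x), x)
    return jnp.prod(safe_x, axis=axis)


result = prod_without_zeros(jnp.array([1.0, 0.0, 3.0, 4.0]))
```

Whether an actual dispatch should be built this way, or through `jax.lax.reduce` with a custom binary function, is an open design question.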

Reproducible code example:

import pytensor
import pytensor.tensor as pt

x = pt.dvector('x')
z = pt.prod(x, no_zeros_in_input=False)
gz = pytensor.grad(z, x)
f_gz = pytensor.function([x], gz, mode='JAX')
f_gz([1, 2, 3, 4])

Error message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:198, in streamline.<locals>.streamline_default_f()
    195 for thunk, node, old_storage in zip(
    196 thunks, order, post_thunk_old_storage
    197 ):
--> 198 thunk()
    199 for old_s in old_storage:

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    654 def thunk(
    655 fgraph=self.fgraph,
    656 fgraph_jit=fgraph_jit,
    657 thunk_inputs=thunk_inputs,
    658 thunk_outputs=thunk_outputs,
    659 ):
--> 660 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    662 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):

[... skipping hidden 12 frame]

File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpdtqmle3i:13, in jax_funcified_fgraph(x)
     12 # ProdWithoutZeros{axes=None}(Mul.0)
---> 13 tensor_variable_5 = careduce_1(tensor_variable_4)
     14 # ExpandDims{axis=0}(ProdWithoutZeros{axes=None}.0)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/jax/dispatch/elemwise.py:57, in jax_funcify_CAReduce.<locals>.careduce(x)
     54 if to_reduce:
     55 # In this case, we need to use the `jax.lax` function (if there
     56 # is one), and not the `jnp` version.
---> 57 jax_op = getattr(jax.lax, scalar_fn_name)
     58 init_value = jnp.array(scalar_op_identity, dtype=acc_dtype)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
     52 return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.lax' has no attribute 'mul_without_zeros'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
Cell In[61], line 1
----> 1 f_z([1, 2, 3, 4])

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/compile/function/types.py:970, in Function.__call__(self, *args, **kwargs)
    967 t0_fn = time.perf_counter()
    968 try:
    969 outputs = (
--> 970 self.vm()
    971 if output_subset is None
    972 else self.vm(output_subset=output_subset)
    973 )
    974 except Exception:
    975 restore_defaults()

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:202, in streamline.<locals>.streamline_default_f()
    200 old_s[0] = None
    201 except Exception:
--> 202 raise_with_op(fgraph, node, thunk)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:531, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    526 warnings.warn(
    527 f"{exc_type} error does not allow us to add an extra error message"
    528 )
    529 # Some exception need extra parameter in inputs. So forget the
    530 # extra long error message in that case.
--> 531 raise exc_value.with_traceback(exc_trace)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/utils.py:198, in streamline.<locals>.streamline_default_f()
    194 try:
    195 for thunk, node, old_storage in zip(
    196 thunks, order, post_thunk_old_storage
    197 ):
--> 198 thunk()
    199 for old_s in old_storage:
    200 old_s[0] = None

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/basic.py:660, in JITLinker.create_jitable_thunk.<locals>.thunk(fgraph, fgraph_jit, thunk_inputs, thunk_outputs)
    654 def thunk(
    655 fgraph=self.fgraph,
    656 fgraph_jit=fgraph_jit,
    657 thunk_inputs=thunk_inputs,
    658 thunk_outputs=thunk_outputs,
    659 ):
--> 660 outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
    662 for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
    663 compute_map[o_var][0] = True

[... skipping hidden 12 frame]

File /var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/tmpdtqmle3i:13, in jax_funcified_fgraph(x)
     11 tensor_variable_4 = elemwise_fn_2(tensor_variable_3, x)
     12 # ProdWithoutZeros{axes=None}(Mul.0)
---> 13 tensor_variable_5 = careduce_1(tensor_variable_4)
     14 # ExpandDims{axis=0}(ProdWithoutZeros{axes=None}.0)
     15 tensor_variable_6 = dimshuffle_1(tensor_variable_5)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/link/jax/dispatch/elemwise.py:57, in jax_funcify_CAReduce.<locals>.careduce(x)
     52 to_reduce = sorted(axis, reverse=True)
     54 if to_reduce:
     55 # In this case, we need to use the `jax.lax` function (if there
     56 # is one), and not the `jnp` version.
---> 57 jax_op = getattr(jax.lax, scalar_fn_name)
     58 init_value = jnp.array(scalar_op_identity, dtype=acc_dtype)
     59 return jax.lax.reduce(x, init_value, jax_op, to_reduce).astype(acc_dtype)

File ~/mambaforge/envs/cge-dev/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
     51 warnings.warn(message, DeprecationWarning, stacklevel=2)
     52 return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.lax' has no attribute 'mul_without_zeros'
Apply node that caused the error: Switch(Eq.0, True_div.0, Switch.0)
Toposort index: 13
Inputs types: [TensorType(bool, shape=(1,)), TensorType(float64, shape=(None,)), TensorType(float64, shape=(None,))]
Inputs shapes: [(4,)]
Inputs strides: [(8,)]
Inputs values: [array([1., 2., 3., 4.])]
Outputs clients: [['output']]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3488, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3548, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/var/folders/7b/rzxy96cj0w751_6td3g2yss00000gn/T/ipykernel_27218/3109327815.py", line 5, in <module>
    gz = pytensor.grad(z, x)
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 607, in grad
    _rval: Sequence[Variable] = _populate_grad_dict(
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1407, in _populate_grad_dict
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1407, in <listcomp>
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1362, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/Users/jessegrabowski/mambaforge/envs/cge-dev/lib/python3.11/site-packages/pytensor/gradient.py", line 1192, in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
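The failing step in the traceback is the name-based lookup `getattr(jax.lax, scalar_fn_name)`. A small check, independent of PyTensor, illustrates why it works for an ordinary product but not here: `jax.lax` exposes `mul`, while no function named `mul_without_zeros` exists there. The variable names below are only illustrative.

```python
import jax

# The CAReduce dispatch resolves the scalar op's name on jax.lax.
# "mul" resolves; "mul_without_zeros" does not, hence the AttributeError.
has_mul = hasattr(jax.lax, "mul")
has_mul_without_zeros = hasattr(jax.lax, "mul_without_zeros")

print(has_mul, has_mul_without_zeros)
```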

PyTensor version information:

PyTensor 2.17.4

Context for the issue:

I want to compute the gradient of a product in JAX mode.
