Description
```python
import array_api_strict as xp

a = xp.arange(10, device=xp.Device("device1"))
xp.searchsorted(a, 42)
```
raises:
```
Traceback (most recent call last):
  Cell In[5], line 5
    xp.searchsorted(a, 42)
  File ~/miniforge3/envs/dev/lib/python3.13/site-packages/array_api_strict/_flags.py:395 in wrapper
    return func(*args, **kwargs)
  File ~/miniforge3/envs/dev/lib/python3.13/site-packages/array_api_strict/_searching_functions.py:78 in searchsorted
    if x1.dtype not in _real_numeric_dtypes or x2.dtype not in _real_numeric_dtypes:
AttributeError: 'int' object has no attribute 'dtype'
```
This is a bit annoying because it forces you to write the following instead:

```python
xp.searchsorted(a, xp.asarray(42, device=a.device))
```

which feels unnecessarily verbose.
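Until the library or the spec changes, the workaround above can be packaged in a small helper. This is only a sketch: `searchsorted_scalar` is a hypothetical name, and the choice to give the scalar the dtype *and* device of `x1` is an assumption that mirrors how the standard treats Python scalars in binary operators, not something the spec currently says about `searchsorted`. The namespace is passed in explicitly so the helper does not depend on a particular library.

```python
def searchsorted_scalar(xp, x1, x2, /, *, side="left", sorter=None):
    """Hypothetical wrapper: like xp.searchsorted, but also accepts a Python scalar x2.

    Assumption: a Python bool/int/float x2 is wrapped in a 0-D array that reuses
    x1's dtype and device (mirroring the scalar rules for binary operators).
    """
    # bool is a subclass of int, so this check covers bool as well.
    if isinstance(x2, (int, float)):
        x2 = xp.asarray(x2, dtype=x1.dtype, device=x1.device)
    # Arrays (and anything else) are passed through unchanged.
    return xp.searchsorted(x1, x2, side=side, sorter=sorter)
```

With the array `a` from the report, `searchsorted_scalar(xp, a, 42)` should avoid the `AttributeError`, since `x2` reaches `xp.searchsorted` as a 0-D array on `a.device`. Mixing an integer `x1` with a float scalar is left to whatever `xp.asarray` does with that dtype request.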
PyTorch, for instance, accepts the following without complaint:

```python
import array_api_compat.torch as xp

a = xp.arange(10, device="mps")
xp.searchsorted(a, 42)
```
However, the spec does not explicitly mention Python scalar support for `searchsorted`, so perhaps it needs to be updated first?