Closed
Description
The dtype promotion check in _fix_promotion does not correctly identify scalar inputs, and unconditionally accesses .dtype.
This breaks binary operators with float scalar inputs.
This can be fixed by reading dtype via getattr with a None default, or by checking that neither input is a scalar before accessing .dtype.
Happy to provide a PR.
Minimal reproduction, in version 1.4:
```python
import torch
import numpy
import array_api_compat as aac

print(aac.__version__)  # 1.4 in this report

t = torch.arange(10)
n = numpy.arange(10)

numpy.add(n, 1.0)
torch.add(t, 1.0)
aac.get_namespace(n).add(n, 1.0)
aac.get_namespace(t).add(t, 1.0)  # raises AttributeError
```

Raises:
```
      9 torch.add(t, 1.0)
     11 aac.get_namespace(n).add(n, 1.0)
---> 12 aac.get_namespace(t).add(t, 1.0)

File ~/ab/main/.conda/lib/python3.10/site-packages/array_api_compat/torch/_aliases.py:91, in _two_arg.<locals>._f(x1, x2, **kwargs)
     89 @wraps(f)
     90 def _f(x1, x2, /, **kwargs):
---> 91     x1, x2 = _fix_promotion(x1, x2)
     92     return f(x1, x2, **kwargs)

File ~/ab/main/.conda/lib/python3.10/site-packages/array_api_compat/torch/_aliases.py:104, in _fix_promotion(x1, x2, only_scalar)
    103 def _fix_promotion(x1, x2, only_scalar=True):
--> 104     if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes:
    105         return x1, x2
    106     # If an argument is 0-D pytorch downcasts the other argument

AttributeError: 'float' object has no attribute 'dtype'
```

I would expect behavior equivalent to torch.add.
See:
https://gist.github.com/asford/ee688d59f0747a6507b9670a83fa7c47