Description
The requirement to upcast `sum(x)` to the default floating-point dtype with the default `dtype=None` currently reads (from the `sum` spec):

> If `x` has a real-valued floating-point data type, the returned array must have the default real-valued floating-point data type.
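For comparison, under this wording a conforming `sum` of `float32` input would have to return the default real-valued floating-point dtype, which in NumPy is `float64`. The minimal sketch below assumes NumPy is installed. It contrasts what NumPy returns today with `dtype=None` against the result the spec wording requires, reproduced here by passing `dtype` explicitly:

```python
import numpy as np

x = np.ones(3, dtype=np.float32)

# NumPy's behavior with dtype=None: the float32 input dtype is kept.
actual = np.sum(x)

# The current spec wording requires the default real floating-point dtype,
# which is float64 in NumPy; passing dtype explicitly gives that result.
spec_required = np.sum(x, dtype=np.float64)

print(actual.dtype, spec_required.dtype)
```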
The rationale given is that this "keyword argument is intended to help prevent data type overflows." This came up again in the review of NEP 56 (numpy/numpy#25542), where it was basically the only part of the standard that was flagged as problematic and explicitly rejected.
I agree that the standard's choice here is problematic, at least from a practical perspective: no array library does this, and none are planning to implement it. The rationale is also pretty weak: it just does not apply to floating-point dtypes to the same extent as it does to integer dtypes (and for integers, array libraries do implement the upcasting). Examples:
```python
>>> import numpy as np
>>> import torch
>>> import jax.numpy as jnp
>>> import cupy as cp
>>> import dask.array as da
>>> # NumPy:
>>> np.sum(np.ones(3, dtype=np.float32)).dtype
dtype('float32')
>>> np.sum(np.ones(3, dtype=np.int32)).dtype
dtype('int64')
>>> # PyTorch:
>>> torch.sum(torch.ones(2, dtype=torch.bfloat16)).dtype
torch.bfloat16
>>> torch.sum(torch.ones(2, dtype=torch.int16)).dtype
torch.int64
>>> # JAX:
>>> jnp.sum(jnp.ones(4, dtype=jnp.float16)).dtype
dtype('float16')
>>> jnp.sum(jnp.ones(4, dtype=jnp.int16)).dtype
dtype('int32')
>>> # CuPy:
>>> cp.sum(cp.ones(5, dtype=cp.float16)).dtype
dtype('float16')
>>> cp.sum(cp.ones(5, dtype=cp.int32)).dtype
dtype('int64')
>>> # Dask:
>>> da.sum(da.ones(6, dtype=np.float32)).dtype
dtype('float32')
>>> da.sum(da.ones(6, dtype=np.int32)).dtype
dtype('int64')
```
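To make the integer/floating-point asymmetry behind the rationale concrete, here is a minimal NumPy sketch. It assumes NumPy is installed, and the `int8` dtype and the count of 200 elements are illustrative choices, not taken from the discussion. Accumulating in a narrow integer dtype wraps around silently after only a couple of hundred additions, while the same sum in `float32` is exact and far below that dtype's maximum of roughly 3.4e38:

```python
import numpy as np

x_int = np.ones(200, dtype=np.int8)

# Forcing accumulation in int8 (max 127) wraps around: 200 - 256 = -56.
wrapped = x_int.sum(dtype=np.int8)

# With dtype=None, NumPy upcasts narrow integers before accumulating.
upcast = x_int.sum()

# The same sum in float32 stays exact and nowhere near float32's range limit.
x_f32 = np.ones(200, dtype=np.float32)
f32_total = x_f32.sum()

print(np.iinfo(np.int8).max, int(wrapped), int(upcast))
print(np.finfo(np.float32).max, f32_total, f32_total.dtype)
```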
The most relevant conversation is #238 (comment). There were some further minor tweaks (without much discussion) in gh-666.
Proposed resolution: align the standard with what all known array libraries implement today.
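As a rough sketch of what this would mean for the result dtype when `dtype=None`, the hypothetical helper `proposed_sum_dtype` below (not part of any library or of the standard) keeps real floating-point dtypes unchanged and upcasts signed integers narrower than the default integer dtype. It assumes NumPy dtype objects. It takes the default integer dtype as a parameter because this differs between libraries (`int64` for NumPy and `int32` for JAX in the examples above), and it deliberately leaves unsigned, boolean, and complex inputs out of scope:

```python
import numpy as np


def proposed_sum_dtype(x_dtype, default_int=np.int64):
    """Hypothetical helper: result dtype of sum(x) with dtype=None under the
    proposed resolution (keep real floating dtypes, upcast narrow signed ints)."""
    x_dtype = np.dtype(x_dtype)
    default_int = np.dtype(default_int)
    if np.issubdtype(x_dtype, np.floating):
        # Real floating-point input keeps its own dtype (no upcast to the default float).
        return x_dtype
    if np.issubdtype(x_dtype, np.signedinteger):
        # Signed integers narrower than the default integer dtype are upcast,
        # as the libraries in the examples already do.
        return x_dtype if x_dtype.itemsize >= default_int.itemsize else default_int
    raise NotImplementedError(f"not covered by this sketch: {x_dtype}")


# Cross-check the floating-point case against NumPy's actual behavior.
print(proposed_sum_dtype(np.float32), np.sum(np.ones(3, dtype=np.float32)).dtype)
```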