Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
360df9f
Add shape testing for vector_norm()
asmeurer Feb 17, 2022
e34f492
Test the dtype and stacks in the vector_norm() test
asmeurer Feb 18, 2022
111c237
Remove an ununsed variable
asmeurer Feb 18, 2022
4f3aa54
Use a simpler strategy for ord in test_vector_norm
asmeurer Feb 19, 2022
979b81b
Skip the test_vector_norm test on the NumPy CI
asmeurer Feb 25, 2022
8df237a
Fix syntax error
asmeurer Feb 25, 2022
d11a685
Fix the input strategies for test_tensordot()
asmeurer Feb 26, 2022
a776cd4
Add a test for the tensordot result shape
asmeurer Apr 9, 2022
45b36d6
Test stacking for tensordot
asmeurer Apr 12, 2022
414b322
Add allclose() and assert_allclose() helper functions
asmeurer Apr 25, 2022
9bb8c7a
Use assert_allclose() in the linalg tests for float inputs
asmeurer Apr 25, 2022
b3fb4ec
Remove skip from test_eigh
asmeurer Apr 26, 2022
241220e
Disable eigenvectors stack test
asmeurer May 6, 2022
ca70fbe
Reduce the relative tolerance in assert_allclose
asmeurer May 6, 2022
720b309
Sort the eigenvalues when testing stacks
asmeurer May 6, 2022
17d93bf
Merge branch 'more-linalg2' of github.com:asmeurer/array-api-tests in…
asmeurer May 9, 2022
f439259
Sort the results in eigvalsh before comparing
asmeurer Jun 3, 2022
75ca73a
Remove the allclose testing in linalg
asmeurer Jun 13, 2022
d86a0a1
Add (commented out) stacking tests for solve()
asmeurer Jun 16, 2022
9bccfa5
Remove unused none standin in the linalg tests
asmeurer Jun 16, 2022
f494b45
Don't compare float elements in test_tensordot
asmeurer Jun 16, 2022
74add08
Fix test_vecdot
asmeurer Jun 24, 2022
f12be47
Fix typo in test_vecdot
asmeurer Jul 5, 2022
d41d0bd
Expand vecdot tests
asmeurer Jul 5, 2022
1220d6e
Merge branch 'master' into more-linalg2
asmeurer Sep 27, 2022
a96a5df
Merge branch 'master' into more-linalg2
asmeurer Oct 20, 2022
48a8442
Check specially that the result of linalg functions is not a unnamed …
asmeurer Nov 29, 2022
fd6367f
Use a more robust fallback helper for matrix_transpose
asmeurer Mar 17, 2023
7017797
Be more constrained about constructing symmetric matrices
asmeurer Mar 20, 2023
335574e
Merge branch 'more-linalg2' of github.com:asmeurer/array-api-tests in…
asmeurer Mar 21, 2023
246e38a
Don't require the arguments to assert_keepdimable_shape to be positio…
asmeurer Mar 23, 2023
02542ff
Show the arrays in the error message for assert_exactly_equal
asmeurer Mar 29, 2023
72974e0
Allow passing an extra assertion message to assert_equal in linalg an…
asmeurer Mar 29, 2023
1daba5d
Fix the true_value check for test_vecdot
asmeurer Mar 29, 2023
bbfe50f
Fix the test_diagonal true value check
asmeurer Mar 29, 2023
64b0342
Use a function instead of operation
asmeurer Mar 29, 2023
9cb58a1
Add a comment
asmeurer Apr 18, 2023
0b3e170
Merge branch 'master' into more-linalg2
asmeurer Feb 3, 2024
c51216b
Remove flaky skips from linalg tests
asmeurer Feb 3, 2024
cffd076
Fix some issues in linalg tests from recent merge
asmeurer Feb 3, 2024
3501116
Fix vector_norm to not use our custom arrays strategy
asmeurer Feb 3, 2024
5c1aa45
Update _test_stacks to use updated ndindex behavior
asmeurer Feb 3, 2024
7a46e6b
Further limit the size of n in test_matrix_power
asmeurer Feb 3, 2024
6d154f2
Fix test_trace
asmeurer Feb 3, 2024
257aa13
Fix test_vecdot to only generate axis in [-min(x1.ndim, x2.ndim), -1]
asmeurer Feb 3, 2024
afc8a25
Update test_cross to test broadcastable shapes
asmeurer Feb 3, 2024
3cb9912
Fix test_cross to use assert_dtype and assert_shape helpers
asmeurer Feb 3, 2024
012ca19
Remove some completed TODO comments
asmeurer Feb 3, 2024
5ceb81d
Update linalg tests to test complex dtypes
asmeurer Feb 3, 2024
a4d419f
Update linalg tests to use assert_dtype and assert_shape helpers
asmeurer Feb 3, 2024
6f9db94
Factor out dtype logic from test_sum() and test_prod() and apply it t…
asmeurer Feb 3, 2024
5aa9083
Remove unused allclose and assert_allclose helpers
asmeurer Feb 7, 2024
938f086
Update ndindex version requirement
asmeurer Feb 16, 2024
3856b8f
Fix linting issue
asmeurer Feb 16, 2024
ccc6ca3
Skip `test_cross` in CI
honno Feb 20, 2024
3092422
Test matmul, matrix_transpose, tensordot, and vecdot for the main and…
asmeurer Feb 23, 2024
2d918e4
Merge branch 'more-linalg2' of github.com:asmeurer/array-api-tests in…
asmeurer Feb 23, 2024
3fefd20
Remove need for filtering in `invertible_matrices()`
honno Feb 26, 2024
a76e051
Merge branch 'master' into more-linalg2
honno Feb 26, 2024
268682d
Skip flaky `test_reshape`
honno Feb 26, 2024
0ddb0cd
Less filtering in `positive_definitive_matrices`
honno Feb 26, 2024
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Factor out dtype logic from test_sum() and test_prod() and apply it t…
…o test_trace()
  • Loading branch information
asmeurer committed Feb 3, 2024
commit 6f9db94db029a851b567f7a53fc92cb9efc04868
41 changes: 41 additions & 0 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,47 @@ def as_real_dtype(dtype):
else:
raise ValueError("as_real_dtype requires a floating-point dtype")

def accumulation_result_dtype(x_dtype, dtype_kwarg):
"""
Result dtype logic for sum(), prod(), and trace()

Note: may return None if a default uint cannot exist (e.g., for pytorch
which doesn't support uint32 or uint64). See https://github.com/data-apis/array-api-tests/issues/106

"""
if dtype_kwarg is None:
if is_int_dtype(x_dtype):
if x_dtype in uint_dtypes:
default_dtype = default_uint
else:
default_dtype = default_int
if default_dtype is None:
_dtype = None
else:
m, M = dtype_ranges[x_dtype]
d_m, d_M = dtype_ranges[default_dtype]
if m < d_m or M > d_M:
_dtype = x_dtype
else:
_dtype = default_dtype
elif is_float_dtype(x_dtype, include_complex=False):
if dtype_nbits[x_dtype] > dtype_nbits[default_float]:
_dtype = x_dtype
else:
_dtype = default_float
elif api_version > "2021.12":
# Complex dtype
if dtype_nbits[x_dtype] > dtype_nbits[default_complex]:
_dtype = x_dtype
else:
_dtype = default_complex
else:
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
else:
_dtype = dtype_kwarg

return _dtype

if not hasattr(xp, "asarray"):
default_int = xp.int32
default_float = xp.float32
Expand Down
15 changes: 10 additions & 5 deletions array_api_tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,11 +795,16 @@ def test_tensordot(x1, x2, kw):
def test_trace(x, kw):
res = linalg.trace(x, **kw)

# TODO: trace() should promote in some cases. See
# https://github.com/data-apis/array-api/issues/202. See also the dtype
# argument to sum() below.

# assert res.dtype == x.dtype, "trace() returned the wrong dtype"
dtype = kw.get("dtype", None)
expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
if expected_dtype is None:
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
# uint32 or uint64), we skip testing the output dtype.
# See https://github.com/data-apis/array-api-tests/issues/160
if x.dtype in dh.uint_dtypes:
assert dh.is_int_dtype(res.dtype) # sanity check
else:
ph.assert_dtype("trace", in_dtype=x.dtype, out_dtype=res.dtype, expected=expected_dtype)

n, m = x.shape[-2:]
ph.assert_result_shape('trace', x.shape, res.shape, expected=x.shape[:-2])
Expand Down
70 changes: 6 additions & 64 deletions array_api_tests/test_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,44 +130,15 @@ def test_prod(x, data):
out = xp.prod(x, **kw)

dtype = kw.get("dtype", None)
if dtype is None:
if dh.is_int_dtype(x.dtype):
if x.dtype in dh.uint_dtypes:
default_dtype = dh.default_uint
else:
default_dtype = dh.default_int
if default_dtype is None:
_dtype = None
else:
m, M = dh.dtype_ranges[x.dtype]
d_m, d_M = dh.dtype_ranges[default_dtype]
if m < d_m or M > d_M:
_dtype = x.dtype
else:
_dtype = default_dtype
elif dh.is_float_dtype(x.dtype, include_complex=False):
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
_dtype = x.dtype
else:
_dtype = dh.default_float
elif api_version > "2021.12":
# Complex dtype
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]:
_dtype = x.dtype
else:
_dtype = dh.default_complex
else:
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
else:
_dtype = dtype
if _dtype is None:
expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
if expected_dtype is None:
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
# uint32 or uint64), we skip testing the output dtype.
# See https://github.com/data-apis/array-api-tests/issues/106
if x.dtype in dh.uint_dtypes:
assert dh.is_int_dtype(out.dtype) # sanity check
else:
ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=_dtype)
ph.assert_dtype("prod", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
ph.assert_keepdimable_shape(
"prod", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
Expand Down Expand Up @@ -246,44 +217,15 @@ def test_sum(x, data):
out = xp.sum(x, **kw)

dtype = kw.get("dtype", None)
if dtype is None:
if dh.is_int_dtype(x.dtype):
if x.dtype in dh.uint_dtypes:
default_dtype = dh.default_uint
else:
default_dtype = dh.default_int
if default_dtype is None:
_dtype = None
else:
m, M = dh.dtype_ranges[x.dtype]
d_m, d_M = dh.dtype_ranges[default_dtype]
if m < d_m or M > d_M:
_dtype = x.dtype
else:
_dtype = default_dtype
elif dh.is_float_dtype(x.dtype, include_complex=False):
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
_dtype = x.dtype
else:
_dtype = dh.default_float
elif api_version > "2021.12":
# Complex dtype
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]:
_dtype = x.dtype
else:
_dtype = dh.default_complex
else:
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
else:
_dtype = dtype
if _dtype is None:
expected_dtype = dh.accumulation_result_dtype(x.dtype, dtype)
if expected_dtype is None:
# If a default uint cannot exist (i.e. in PyTorch which doesn't support
# uint32 or uint64), we skip testing the output dtype.
# See https://github.com/data-apis/array-api-tests/issues/160
if x.dtype in dh.uint_dtypes:
assert dh.is_int_dtype(out.dtype) # sanity check
else:
ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=_dtype)
ph.assert_dtype("sum", in_dtype=x.dtype, out_dtype=out.dtype, expected=expected_dtype)
_axes = sh.normalise_axis(kw.get("axis", None), x.ndim)
ph.assert_keepdimable_shape(
"sum", in_shape=x.shape, out_shape=out.shape, axes=_axes, keepdims=keepdims, kw=kw
Expand Down