|  | 
| 10 | 10 |  is_dask_array, is_jax_array, is_pydata_sparse_array, | 
| 11 | 11 |  is_numpy_namespace, is_cupy_namespace, is_torch_namespace, | 
| 12 | 12 |  is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace, | 
|  | 13 | + is_array_api_strict_namespace, | 
| 13 | 14 | ) | 
| 14 | 15 | 
 | 
| 15 | 16 | from array_api_compat import ( | 
|  | 
| 33 | 34 |  'dask.array': 'is_dask_namespace', | 
| 34 | 35 |  'jax.numpy': 'is_jax_namespace', | 
| 35 | 36 |  'sparse': 'is_pydata_sparse_namespace', | 
|  | 37 | + 'array_api_strict': 'is_array_api_strict_namespace', | 
| 36 | 38 | } | 
| 37 | 39 | 
 | 
| 38 | 40 | 
 | 
| @@ -74,7 +76,12 @@ def test_xp_is_array_generics(library): | 
| 74 | 76 |  is_func = globals()[func] | 
| 75 | 77 |  if is_func(x0): | 
| 76 | 78 |  matches.append(library2) | 
| 77 |  | - assert matches in ([library], ["numpy"]) | 
|  | 79 | + | 
|  | 80 | + if library == "array_api_strict": | 
|  | 81 | + # There is no is_array_api_strict_array() function | 
|  | 82 | + assert matches == [] | 
|  | 83 | + else: | 
|  | 84 | + assert matches in ([library], ["numpy"]) | 
| 78 | 85 | 
 | 
| 79 | 86 | 
 | 
| 80 | 87 | @pytest.mark.parametrize("library", all_libraries) | 
| @@ -213,26 +220,33 @@ def test_to_device_host(library): | 
| 213 | 220 | @pytest.mark.parametrize("target_library", is_array_functions.keys()) | 
| 214 | 221 | @pytest.mark.parametrize("source_library", is_array_functions.keys()) | 
| 215 | 222 | def test_asarray_cross_library(source_library, target_library, request): | 
| 216 |  | - if source_library == "dask.array" and target_library == "torch": | 
|  | 223 | + def _xfail(reason: str) -> None: | 
| 217 | 224 |  # Allow rest of test to execute instead of immediately xfailing | 
| 218 | 225 |  # xref https://github.com/pandas-dev/pandas/issues/38902 | 
|  | 226 | + request.node.add_marker(pytest.mark.xfail(reason=reason)) | 
| 219 | 227 | 
 | 
|  | 228 | + if source_library == "dask.array" and target_library == "torch": | 
| 220 | 229 |  # TODO: remove xfail once | 
| 221 | 230 |  # https://github.com/dask/dask/issues/8260 is resolved | 
| 222 |  | - request.node.add_marker(pytest.mark.xfail(reason="Bug in dask raising error on conversion")) | 
| 223 |  | - if source_library == "cupy" and target_library != "cupy": | 
|  | 231 | + _xfail(reason="Bug in dask raising error on conversion") | 
|  | 232 | + elif source_library == "jax.numpy" and target_library == "torch": | 
|  | 233 | + _xfail(reason="casts int to float") | 
|  | 234 | + elif source_library == "cupy" and target_library != "cupy": | 
| 224 | 235 |  # cupy explicitly disallows implicit conversions to CPU | 
| 225 | 236 |  pytest.skip(reason="cupy does not support implicit conversion to CPU") | 
| 226 | 237 |  elif source_library == "sparse" and target_library != "sparse": | 
| 227 | 238 |  pytest.skip(reason="`sparse` does not allow implicit densification") | 
|  | 239 | + | 
| 228 | 240 |  src_lib = import_(source_library, wrapper=True) | 
| 229 | 241 |  tgt_lib = import_(target_library, wrapper=True) | 
| 230 | 242 |  is_tgt_type = globals()[is_array_functions[target_library]] | 
| 231 | 243 | 
 | 
| 232 |  | - a = src_lib.asarray([1, 2, 3]) | 
|  | 244 | + a = src_lib.asarray([1, 2, 3], dtype=src_lib.int32) | 
| 233 | 245 |  b = tgt_lib.asarray(a) | 
| 234 | 246 | 
 | 
| 235 | 247 |  assert is_tgt_type(b), f"Expected {b} to be a {tgt_lib.ndarray}, but was {type(b)}" | 
|  | 248 | + assert b.dtype == tgt_lib.int32 | 
|  | 249 | + | 
| 236 | 250 | 
 | 
| 237 | 251 | 
 | 
| 238 | 252 | @pytest.mark.parametrize("library", wrapped_libraries) | 
|  | 
0 commit comments