1+ import  pytest 
2+ 
13import  array_api_compat 
24from  array_api_compat  import  array_namespace 
35
4- from  ._helpers  import  import_ 
5- 
6- import  pytest 
76
87@pytest .mark .parametrize ("library" , ["cupy" , "numpy" , "torch" , "dask.array" ]) 
9- @pytest .mark .parametrize ("api_version" , [None , ' 2021.12'  ]) 
8+ @pytest .mark .parametrize ("api_version" , [None , " 2021.12"  ]) 
109def  test_array_namespace (library , api_version ):
11-  lib  =  import_ (library )
10+  lib  =  pytest . importorskip (library )
1211
1312 array  =  lib .asarray ([1.0 , 2.0 , 3.0 ])
1413 namespace  =  array_api_compat .array_namespace (array , api_version = api_version )
1514
16-  if  ' array_api'   in  library :
15+  if  " array_api"   in  library :
1716 assert  namespace  ==  lib 
1817 else :
1918 if  library  ==  "dask.array" :
@@ -23,21 +22,22 @@ def test_array_namespace(library, api_version):
2322
2423
2524def  test_array_namespace_errors ():
25+  np  =  pytest .importorskip ("numpy" )
26+ 
2627 pytest .raises (TypeError , lambda : array_namespace ([1 ]))
2728 pytest .raises (TypeError , lambda : array_namespace ())
2829
29-  import  numpy  as  np 
3030 x  =  np .asarray ([1 , 2 ])
31- 
3231 pytest .raises (TypeError , lambda : array_namespace ((x , x )))
3332 pytest .raises (TypeError , lambda : array_namespace (x , (x , x )))
3433
35-  import  torch 
36-  y  =  torch .asarray ([1 , 2 ])
3734
38-  pytest .raises (TypeError , lambda : array_namespace (x , y ))
35+ def  test_array_namespace_errors_torch ():
36+  torch  =  pytest .importorskip ("torch" )
3937
40-  pytest .raises (ValueError , lambda : array_namespace (x , api_version = '2022.12' ))
38+  y  =  torch .asarray ([1 , 2 ])
39+  pytest .raises (TypeError , lambda : array_namespace (x , y ))
40+  pytest .raises (ValueError , lambda : array_namespace (x , api_version = "2022.12" ))
4141
4242
4343def  test_get_namespace ():
0 commit comments