File tree Expand file tree Collapse file tree 1 file changed +24
-0
lines changed
Expand file tree Collapse file tree 1 file changed +24
-0
lines changed Original file line number Diff line number Diff line change @@ -309,6 +309,30 @@ def test_from_cuda_array_interface(self):
309309 torch_ary += 42
310310 self .assertEqual (torch_ary .cpu ().data .numpy (), numpy .asarray (numba_ary ) + 42 )
311311
312+ @unittest .skipIf (not TEST_NUMPY , "No numpy" )
313+ @unittest .skipIf (not TEST_CUDA , "No cuda" )
314+ @unittest .skipIf (not TEST_NUMBA_CUDA , "No numba.cuda" )
315+ def test_from_cuda_array_interface_inferred_strides (self ):
316+ """torch.as_tensor(numba_ary) should have correct inferred (contiguous) strides"""
317+ # This could, in theory, be combined with test_from_cuda_array_interface but that test
318+ # is overly strict: it checks that the exported protocols are exactly the same, which
319+ # cannot handle differing exported protocol versions.
320+ dtypes = [
321+ numpy .float64 ,
322+ numpy .float32 ,
323+ numpy .int64 ,
324+ numpy .int32 ,
325+ numpy .int16 ,
326+ numpy .int8 ,
327+ numpy .uint8 ,
328+ ]
329+ for dtype in dtypes :
330+ numpy_ary = numpy .arange (6 ).reshape (2 , 3 ).astype (dtype ),
331+ numba_ary = numba .cuda .to_device (numpy_ary )
332+ self .assertTrue (numba_ary .is_c_contiguous ())
333+ torch_ary = torch .as_tensor (numba_ary , device = "cuda" )
334+ self .assertTrue (torch_ary .is_contiguous ())
335+
312336 @unittest .skipIf (not TEST_NUMPY , "No numpy" )
313337 @unittest .skipIf (not TEST_CUDA , "No cuda" )
314338 @unittest .skipIf (not TEST_NUMBA_CUDA , "No numba.cuda" )
You can’t perform that action at this time.
0 commit comments