@@ -4234,14 +4234,14 @@ def test_future_empty_dim(self, device, dtype, op):
42344234class TestSparseMeta (TestCase ):
42354235 exact_dtype = True
42364236
4237- def test_basic (self ):
4238- r = torch .empty (4 , 4 , layout = torch .sparse_coo , device = 'meta' )
4237+ def _test_basic_coo (self , dtype ):
4238+ r = torch .empty (4 , 4 , dtype = dtype , layout = torch .sparse_coo , device = 'meta' )
42394239 self .assertTrue (r .is_meta )
42404240 self .assertEqual (r .device .type , "meta" )
42414241 r2 = torch .empty_like (r )
42424242 self .assertTrue (r2 .is_meta )
42434243 self .assertEqual (r , r2 )
4244- r3 = torch .sparse_coo_tensor (size = (4 , 4 ), device = 'meta' )
4244+ r3 = torch .sparse_coo_tensor (size = (4 , 4 ), dtype = dtype , device = 'meta' )
42454245 self .assertTrue (r3 .is_meta )
42464246 self .assertEqual (r , r3 )
42474247 r .sparse_resize_ ((4 , 4 ), 1 , 1 )
@@ -4260,9 +4260,67 @@ def test_basic(self):
42604260 # TODO: this sort of aliasing will need to be handled by
42614261 # functionalization
42624262 self .assertEqual (r ._indices (), torch .empty (2 , 0 , device = 'meta' , dtype = torch .int64 ))
4263- self .assertEqual (r ._values (), torch .empty (0 , 4 , device = 'meta' ))
4263+ self .assertEqual (r ._values (), torch .empty (0 , 4 , dtype = dtype , device = 'meta' ))
42644264 self .assertEqual (r .indices (), torch .empty (2 , 0 , device = 'meta' , dtype = torch .int64 ))
4265- self .assertEqual (r .values (), torch .empty (0 , 4 , device = 'meta' ))
4265+ self .assertEqual (r .values (), torch .empty (0 , 4 , dtype = dtype , device = 'meta' ))
4266+
4267+ def _test_basic_sparse_compressed (self , dtype , layout , batch_shape , dense_shape ):
4268+ index_dtype = torch .int64
4269+ blocksize = (2 , 3 ) if layout in {torch .sparse_bsr , torch .sparse_bsc } else ()
4270+ sparse_shape = (4 , 6 )
4271+ nnz = 0
4272+
4273+ shape = (* batch_shape , * sparse_shape , * dense_shape )
4274+ compressed_dim = 0 if layout in {torch .sparse_csr , torch .sparse_bsr } else 1
4275+ nof_compressed_indices = (sparse_shape [compressed_dim ] // blocksize [compressed_dim ] + 1 if blocksize
4276+ else sparse_shape [compressed_dim ] + 1 )
4277+ compressed_indices = torch .empty ((* batch_shape , nof_compressed_indices ), device = 'meta' , dtype = index_dtype )
4278+ plain_indices = torch .empty ((* batch_shape , nnz ), device = 'meta' , dtype = index_dtype )
4279+
4280+ values = torch .empty ((* batch_shape , nnz , * blocksize , * dense_shape ), device = 'meta' , dtype = dtype )
4281+ r = torch .sparse_compressed_tensor (
4282+ compressed_indices ,
4283+ plain_indices ,
4284+ values ,
4285+ shape ,
4286+ layout = layout
4287+ )
4288+ self .assertTrue (r .is_meta )
4289+ self .assertEqual (r .device .type , "meta" )
4290+
4291+ self .assertEqual (r .sparse_dim (), 2 )
4292+ self .assertEqual (r .dense_dim (), len (dense_shape ))
4293+ self .assertEqual (r ._nnz (), nnz )
4294+ batch_dims = r .ndim - r .sparse_dim () - r .dense_dim ()
4295+ r_blocksize = r .values ().shape [batch_dims + 1 : batch_dims + 1 + len (blocksize )]
4296+ self .assertEqual (r_blocksize , blocksize )
4297+
4298+ r_compressed_indices = r .crow_indices () if layout in {torch .sparse_csr , torch .sparse_bsr } else r .ccol_indices ()
4299+ r_plain_indices = r .col_indices () if layout in {torch .sparse_csr , torch .sparse_bsr } else r .row_indices ()
4300+
4301+ self .assertEqual (r_compressed_indices ,
4302+ torch .empty ((* batch_shape , nof_compressed_indices ), device = 'meta' , dtype = index_dtype ))
4303+ self .assertEqual (r_plain_indices , torch .empty ((* batch_shape , nnz ), device = 'meta' , dtype = index_dtype ))
4304+ self .assertEqual (r .values (), torch .empty ((* batch_shape , nnz , * blocksize , * dense_shape ), device = 'meta' , dtype = dtype ))
4305+
4306+ r2 = torch .empty_like (r )
4307+ self .assertTrue (r2 .is_meta )
4308+ self .assertEqual (r2 , r )
4309+
4310+ if layout in {torch .sparse_csr , torch .sparse_csc }:
4311+ r3 = torch .empty ((* batch_shape , * sparse_shape ), dtype = dtype , layout = layout , device = "meta" )
4312+ self .assertTrue (r3 .is_meta )
4313+ if not dense_shape :
4314+ self .assertEqual (r3 , r )
4315+
4316+ @all_sparse_layouts ('layout' , include_strided = False )
4317+ @parametrize ("dtype" , [torch .float64 ])
4318+ def test_basic (self , dtype , layout ):
4319+ if layout is torch .sparse_coo :
4320+ self ._test_basic_coo (dtype )
4321+ else :
4322+ for batch_shape , dense_shape in itertools .product ([(), (2 ,)], [(), (3 ,)]):
4323+ self ._test_basic_sparse_compressed (dtype , layout , batch_shape , dense_shape )
42664324
42674325
42684326class _SparseDataset (torch .utils .data .Dataset ):
@@ -5125,6 +5183,8 @@ def test_invalid_blocksize(self):
51255183
51265184instantiate_device_type_tests (TestSparseAny , globals (), except_for = 'meta' )
51275185
5186+ instantiate_parametrized_tests (TestSparseMeta )
5187+
51285188instantiate_parametrized_tests (TestSparseLegacyAndDeprecation )
51295189
51305190if __name__ == '__main__' :
0 commit comments