Commit 5d71ba6

pearu authored and pytorchmergebot committed
Add meta device support to sparse compressed tensors (pytorch#120498)
As in the title. Unblocks pytorch#117907 (comment)

Pull Request resolved: pytorch#120498
Approved by: https://github.com/ezyang
1 parent 834c7a1 commit 5d71ba6
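For orientation, here is a minimal usage sketch of what this change enables, mirroring the new TestSparseMeta cases in test/test_sparse.py below; the shapes, dtype, and variable names are illustrative, not part of the commit:

import torch

# Construct a CSR tensor with nnz == 0 directly on the meta device
# (before this commit, sparse compressed tensors had no meta support).
compressed_indices = torch.empty((5,), dtype=torch.int64, device="meta")
plain_indices = torch.empty((0,), dtype=torch.int64, device="meta")
values = torch.empty((0,), dtype=torch.float64, device="meta")

r = torch.sparse_compressed_tensor(
    compressed_indices, plain_indices, values, (4, 6), layout=torch.sparse_csr
)
assert r.is_meta and r.device.type == "meta"
assert r.layout == torch.sparse_csr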

10 files changed: +130 −40 lines

aten/src/ATen/SparseCsrTensorImpl.cpp

Lines changed: 2 additions & 1 deletion
@@ -55,7 +55,8 @@ SparseCsrTensorImpl::SparseCsrTensorImpl(
       "to https://github.com/pytorch/pytorch/issues.");
 
   TORCH_INTERNAL_ASSERT(((key_set.has(DispatchKey::SparseCsrCPU) && device().type() == kCPU)
-                      || (key_set.has(DispatchKey::SparseCsrCUDA) && device().type() == kCUDA)),
+                      || (key_set.has(DispatchKey::SparseCsrCUDA) && device().type() == kCUDA)
+                      || (key_set.has(DispatchKey::SparseCsrMeta) && device().type() == kMeta)),
       "Inconsistent key_set (=", key_set, ") and device (=", device(), ")");
 
   set_storage_access_should_throw();

aten/src/ATen/native/native_functions.yaml

Lines changed: 12 additions & 12 deletions
@@ -2370,7 +2370,7 @@
     Meta: empty_meta_symint
     MkldnnCPU: empty_mkldnn
     SparseCPU, SparseCUDA, SparseMeta: empty_sparse
-    SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: empty_sparse_compressed
     QuantizedCPU, QuantizedCUDA, QuantizedMeta: empty_unknown_quantized
   tags: core
 
@@ -2476,7 +2476,7 @@
     CompositeExplicitAutograd: empty_like
     QuantizedCPU, QuantizedCUDA: empty_like_quantized
     SparseCPU, SparseCUDA, SparseMeta: empty_like_sparse_coo
-    SparseCsrCPU, SparseCsrCUDA: empty_like_sparse_csr
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: empty_like_sparse_csr
     NestedTensorCPU, NestedTensorCUDA: empty_like_nested
   autogen: empty_like.out
 
@@ -6986,7 +6986,7 @@
 
 - func: sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
   dispatch:
-    CompositeExplicitAutograd: sparse_compressed_tensor
+    CompositeExplicitAutograd, SparseCsrMeta: sparse_compressed_tensor
 
 - func: sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
 - func: sparse_csc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
@@ -7003,7 +7003,7 @@
 
 - func: _sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
   dispatch:
-    CompositeImplicitAutograd: _sparse_compressed_tensor_unsafe_symint
+    CompositeImplicitAutograd, SparseCsrMeta: _sparse_compressed_tensor_unsafe_symint
 
 - func: _sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
 - func: _sparse_csc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@@ -7090,7 +7090,7 @@
   dispatch:
     CPU, CUDA: sparse_dim_strided
     SparseCPU, SparseCUDA, SparseMeta: sparse_dim_sparse
-    SparseCsrCPU, SparseCsrCUDA: sparse_dim_sparse_csr
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sparse_dim_sparse_csr
   device_check: NoCheck
   device_guard: False
 
@@ -7107,7 +7107,7 @@
   dispatch:
     CPU, CUDA: dense_dim_strided
     SparseCPU, SparseCUDA, SparseMeta: dense_dim_sparse
-    SparseCsrCPU, SparseCsrCUDA: dense_dim_sparse_csr
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: dense_dim_sparse_csr
   device_check: NoCheck
   device_guard: False
 
@@ -7123,7 +7123,7 @@
   variants: method
   dispatch:
     SparseCPU, SparseCUDA, SparseMeta: _nnz_sparse
-    SparseCsrCPU, SparseCsrCUDA: _nnz_sparse_csr
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: _nnz_sparse_csr
   device_check: NoCheck
   device_guard: False
 
@@ -7186,7 +7186,7 @@
   variants: method
   dispatch:
     SparseCPU, SparseCUDA, SparseMeta: values_sparse
-    SparseCsrCPU, SparseCsrCUDA: values_sparse_csr
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: values_sparse_csr
     NestedTensorCPU, NestedTensorCUDA: values_nested
     CompositeExplicitAutograd: values_default
   device_check: NoCheck
@@ -7195,31 +7195,31 @@
 - func: crow_indices(Tensor(a) self) -> Tensor(a)
   variants: method
   dispatch:
-    SparseCsrCPU, SparseCsrCUDA: crow_indices_sparse_csr
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: crow_indices_sparse_csr
     CompositeExplicitAutograd: crow_indices_default
   device_check: NoCheck
   device_guard: False
 
 - func: col_indices(Tensor(a) self) -> Tensor(a)
   variants: method
   dispatch:
-    SparseCsrCPU, SparseCsrCUDA: col_indices_sparse_csr
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: col_indices_sparse_csr
     CompositeExplicitAutograd: col_indices_default
   device_check: NoCheck
   device_guard: False
 
 - func: ccol_indices(Tensor(a) self) -> Tensor(a)
   variants: method
   dispatch:
-    SparseCsrCPU, SparseCsrCUDA: ccol_indices_sparse_csr
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: ccol_indices_sparse_csr
     CompositeExplicitAutograd: ccol_indices_default
   device_check: NoCheck
   device_guard: False
 
 - func: row_indices(Tensor(a) self) -> Tensor(a)
   variants: method
   dispatch:
-    SparseCsrCPU, SparseCsrCUDA: row_indices_sparse_csr
+    SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: row_indices_sparse_csr
     CompositeExplicitAutograd: row_indices_default
   device_check: NoCheck
   device_guard: False
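A hedged sketch of what these additional SparseCsrMeta registrations expose at the Python level, based on the kernels reused above (empty_sparse_compressed, _nnz_sparse_csr, crow_indices_sparse_csr, ...); the concrete shapes are illustrative:

import torch

# empty() with a sparse compressed layout now dispatches on the meta device too.
r = torch.empty((4, 6), dtype=torch.float64, layout=torch.sparse_csr, device="meta")

# The accessor methods registered above return meta tensors of the expected shapes.
assert r._nnz() == 0
assert r.crow_indices().is_meta and r.crow_indices().shape == (5,)
assert r.col_indices().is_meta and r.values().is_meta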

aten/src/ATen/native/sparse/SparseCsrTensor.cpp

Lines changed: 17 additions & 11 deletions
@@ -258,22 +258,24 @@ static void _validate_sparse_compressed_tensor_args_worker(const Tensor& compres
       compressed_indices_name, " and ", plain_indices_name, " dtype must be Int or Long, but got ",
       compressed_indices_type);
 
-  // Indices invariants
-  at::_validate_compressed_sparse_indices(
+  if (!values.is_meta()) {
+    // Indices invariants
+    at::_validate_compressed_sparse_indices(
       /*is_crow = */layout == kSparseCsr || layout == kSparseBsr,
       compressed_indices,
       plain_indices,
       compressed_dim_size,
       plain_dim_size,
       values_nnz);
+  }
 
   // Device Invariants
   // 4.1
   TORCH_CHECK(
-      values.device().type() == kCPU || values.device().type() == kCUDA,
+      values.device().type() == kCPU || values.device().type() == kCUDA || values.device().type() == kMeta,
       "device type of values (",
       values.device().type(),
-      ") must be CPU or CUDA");
+      ") must be CPU or CUDA or Meta");
   // 4.2, 4.3, 4.4
   TORCH_CHECK(
       compressed_indices.get_device() == values.get_device(),
@@ -333,14 +335,18 @@ static SparseCsrTensor new_compressed_tensor(const TensorOptions& options) {
   Layout layout = AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(options.layout(), "new_compressed_tensor", [&] { return the_layout; });
   DispatchKey dispatch_key;
 
-  TORCH_CHECK_NOT_IMPLEMENTED(
-    options.device().type() == kCPU || options.device().type() == kCUDA,
-    "Could not run 'new_compressed_tensor' from the '", options.device(), "' device.)");
-
-  if (options.device().is_cuda()) {
-    dispatch_key = DispatchKey::SparseCsrCUDA;
-  } else {
+  switch(options.device().type()) {
+  case kCPU:
     dispatch_key = DispatchKey::SparseCsrCPU;
+    break;
+  case kCUDA:
+    dispatch_key = DispatchKey::SparseCsrCUDA;
+    break;
+  case kMeta:
+    dispatch_key = DispatchKey::SparseCsrMeta;
+    break;
+  default:
+    TORCH_CHECK_NOT_IMPLEMENTED(false, "Could not run 'new_compressed_tensor' from the '", options.device(), "' device.)");
   }
 
   return detail::make_tensor<SparseCsrTensorImpl>(DispatchKeySet(dispatch_key), options.device(), layout, options.dtype());
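Roughly, the effect of the two changes above: the data-dependent index validation is skipped when the inputs are meta tensors (there is no data to inspect), and meta is accepted alongside CPU and CUDA, with new_compressed_tensor mapping it to the SparseCsrMeta dispatch key. A small sketch, assuming the existing torch._validate_sparse_compressed_tensor_args binding that reaches this worker:

import torch

compressed_indices = torch.empty((5,), dtype=torch.int64, device="meta")
plain_indices = torch.empty((0,), dtype=torch.int64, device="meta")
values = torch.empty((0,), dtype=torch.float64, device="meta")

# Passes the device invariant (CPU, CUDA, or Meta) and skips the
# _validate_compressed_sparse_indices checks for meta inputs.
torch._validate_sparse_compressed_tensor_args(
    compressed_indices, plain_indices, values, (4, 6), torch.sparse_csr
)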

c10/core/DispatchKey.cpp

Lines changed: 3 additions & 0 deletions
@@ -95,6 +95,8 @@ const char* toString(DispatchKey t) {
       return "SparseCsrCPU";
     case DispatchKey::SparseCsrCUDA:
       return "SparseCsrCUDA";
+    case DispatchKey::SparseCsrMeta:
+      return "SparseCsrMeta";
 
     case DispatchKey::NestedTensor:
       return "NestedTensor";
@@ -276,6 +278,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
       {"Sparse", c10::DispatchKey::Sparse},
       {"SparseCsrCPU", c10::DispatchKey::SparseCsrCPU},
       {"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA},
+      {"SparseCsrMeta", c10::DispatchKey::SparseCsrMeta},
       {"BackendSelect", c10::DispatchKey::BackendSelect},
       {"Python", c10::DispatchKey::Python},
       {"PythonTLSSnapshot", c10::DispatchKey::PythonTLSSnapshot},

c10/core/DispatchKey.h

Lines changed: 1 addition & 0 deletions
@@ -220,6 +220,7 @@ enum class DispatchKey : uint16_t {
   // TODO: Make SparseCsr a functionality key
   SparseCsrCPU,
   SparseCsrCUDA,
+  SparseCsrMeta,
 
   NestedTensor,

c10/core/DispatchKeySet.h

Lines changed: 4 additions & 2 deletions
@@ -685,8 +685,10 @@ constexpr DispatchKeySet python_ks = DispatchKeySet({
 
 constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse);
 
-constexpr DispatchKeySet sparse_csr_ks =
-    DispatchKeySet({DispatchKey::SparseCsrCPU, DispatchKey::SparseCsrCUDA});
+constexpr DispatchKeySet sparse_csr_ks = DispatchKeySet(
+    {DispatchKey::SparseCsrCPU,
+     DispatchKey::SparseCsrCUDA,
+     DispatchKey::SparseCsrMeta});
 
 constexpr DispatchKeySet mkldnn_ks = DispatchKeySet(DispatchKey::MkldnnCPU);

c10/core/TensorOptions.h

Lines changed: 3 additions & 0 deletions
@@ -700,6 +700,8 @@ inline DispatchKey computeDispatchKey(
           return DispatchKey::SparseCsrCPU;
         case c10::DeviceType::CUDA:
           return DispatchKey::SparseCsrCUDA;
+        case c10::DeviceType::Meta:
+          return DispatchKey::SparseCsrMeta;
         default:
           AT_ERROR(
               "Unsupported device type for ",
@@ -720,6 +722,7 @@ inline Layout dispatchKeyToLayout(DispatchKey dispatch_key) {
       return Layout::Sparse;
     case DispatchKey::SparseCsrCPU:
     case DispatchKey::SparseCsrCUDA:
+    case DispatchKey::SparseCsrMeta:
       TORCH_CHECK(
           false,
           "Cannot map DispatchKey ",

test/test_sparse.py

Lines changed: 65 additions & 5 deletions
@@ -4234,14 +4234,14 @@ def test_future_empty_dim(self, device, dtype, op):
 class TestSparseMeta(TestCase):
     exact_dtype = True
 
-    def test_basic(self):
-        r = torch.empty(4, 4, layout=torch.sparse_coo, device='meta')
+    def _test_basic_coo(self, dtype):
+        r = torch.empty(4, 4, dtype=dtype, layout=torch.sparse_coo, device='meta')
         self.assertTrue(r.is_meta)
         self.assertEqual(r.device.type, "meta")
         r2 = torch.empty_like(r)
         self.assertTrue(r2.is_meta)
         self.assertEqual(r, r2)
-        r3 = torch.sparse_coo_tensor(size=(4, 4), device='meta')
+        r3 = torch.sparse_coo_tensor(size=(4, 4), dtype=dtype, device='meta')
         self.assertTrue(r3.is_meta)
         self.assertEqual(r, r3)
         r.sparse_resize_((4, 4), 1, 1)
@@ -4260,9 +4260,67 @@ def test_basic(self):
         # TODO: this sort of aliasing will need to be handled by
         # functionalization
         self.assertEqual(r._indices(), torch.empty(2, 0, device='meta', dtype=torch.int64))
-        self.assertEqual(r._values(), torch.empty(0, 4, device='meta'))
+        self.assertEqual(r._values(), torch.empty(0, 4, dtype=dtype, device='meta'))
         self.assertEqual(r.indices(), torch.empty(2, 0, device='meta', dtype=torch.int64))
-        self.assertEqual(r.values(), torch.empty(0, 4, device='meta'))
+        self.assertEqual(r.values(), torch.empty(0, 4, dtype=dtype, device='meta'))
+
+    def _test_basic_sparse_compressed(self, dtype, layout, batch_shape, dense_shape):
+        index_dtype = torch.int64
+        blocksize = (2, 3) if layout in {torch.sparse_bsr, torch.sparse_bsc} else ()
+        sparse_shape = (4, 6)
+        nnz = 0
+
+        shape = (*batch_shape, *sparse_shape, *dense_shape)
+        compressed_dim = 0 if layout in {torch.sparse_csr, torch.sparse_bsr} else 1
+        nof_compressed_indices = (sparse_shape[compressed_dim] // blocksize[compressed_dim] + 1 if blocksize
+                                  else sparse_shape[compressed_dim] + 1)
+        compressed_indices = torch.empty((*batch_shape, nof_compressed_indices), device='meta', dtype=index_dtype)
+        plain_indices = torch.empty((*batch_shape, nnz), device='meta', dtype=index_dtype)
+
+        values = torch.empty((*batch_shape, nnz, *blocksize, *dense_shape), device='meta', dtype=dtype)
+        r = torch.sparse_compressed_tensor(
+            compressed_indices,
+            plain_indices,
+            values,
+            shape,
+            layout=layout
+        )
+        self.assertTrue(r.is_meta)
+        self.assertEqual(r.device.type, "meta")
+
+        self.assertEqual(r.sparse_dim(), 2)
+        self.assertEqual(r.dense_dim(), len(dense_shape))
+        self.assertEqual(r._nnz(), nnz)
+        batch_dims = r.ndim - r.sparse_dim() - r.dense_dim()
+        r_blocksize = r.values().shape[batch_dims + 1: batch_dims + 1 + len(blocksize)]
+        self.assertEqual(r_blocksize, blocksize)
+
+        r_compressed_indices = r.crow_indices() if layout in {torch.sparse_csr, torch.sparse_bsr} else r.ccol_indices()
+        r_plain_indices = r.col_indices() if layout in {torch.sparse_csr, torch.sparse_bsr} else r.row_indices()
+
+        self.assertEqual(r_compressed_indices,
+                         torch.empty((*batch_shape, nof_compressed_indices), device='meta', dtype=index_dtype))
+        self.assertEqual(r_plain_indices, torch.empty((*batch_shape, nnz), device='meta', dtype=index_dtype))
+        self.assertEqual(r.values(), torch.empty((*batch_shape, nnz, *blocksize, *dense_shape), device='meta', dtype=dtype))
+
+        r2 = torch.empty_like(r)
+        self.assertTrue(r2.is_meta)
+        self.assertEqual(r2, r)
+
+        if layout in {torch.sparse_csr, torch.sparse_csc}:
+            r3 = torch.empty((*batch_shape, *sparse_shape), dtype=dtype, layout=layout, device="meta")
+            self.assertTrue(r3.is_meta)
+            if not dense_shape:
+                self.assertEqual(r3, r)
+
+    @all_sparse_layouts('layout', include_strided=False)
+    @parametrize("dtype", [torch.float64])
+    def test_basic(self, dtype, layout):
+        if layout is torch.sparse_coo:
+            self._test_basic_coo(dtype)
+        else:
+            for batch_shape, dense_shape in itertools.product([(), (2,)], [(), (3,)]):
+                self._test_basic_sparse_compressed(dtype, layout, batch_shape, dense_shape)
 
 
 class _SparseDataset(torch.utils.data.Dataset):
@@ -5125,6 +5183,8 @@ def test_invalid_blocksize(self):
 
 instantiate_device_type_tests(TestSparseAny, globals(), except_for='meta')
 
+instantiate_parametrized_tests(TestSparseMeta)
+
 instantiate_parametrized_tests(TestSparseLegacyAndDeprecation)
 
 if __name__ == '__main__':
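The new parametrized cases can be exercised in the usual way, for example (assuming a pytest-style invocation):

pytest test/test_sparse.py -k "TestSparseMeta"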

torch/_tensor_str.py

Lines changed: 21 additions & 9 deletions
@@ -128,7 +128,7 @@ def __init__(self, tensor):
         with torch.no_grad():
             tensor_view = tensor.reshape(-1)
 
-        if not self.floating_dtype:
+        if not self.floating_dtype or tensor.is_meta:
             for value in tensor_view:
                 value_str = f"{value}"
                 self.max_width = max(self.max_width, len(value_str))
@@ -476,7 +476,8 @@ def _str_intern(inp, *, tensor_contents=None):
         torch.sparse_bsc,
     }:
         suffixes.append("size=" + str(tuple(self.shape)))
-        suffixes.append("nnz=" + str(self._nnz()))
+        if not self.is_meta:
+            suffixes.append("nnz=" + str(self._nnz()))
         if not has_default_dtype:
             suffixes.append("dtype=" + str(self.dtype))
         if not custom_contents_provided:
@@ -492,23 +493,34 @@
             cdimname, pdimname = "column", "row"
         compressed_indices_prefix = f"c{cdimname[:3]}_indices=tensor("
         compressed_indices = compressed_indices_method(self).detach()
-        compressed_indices_str = _tensor_str(
-            compressed_indices, indent + len(compressed_indices_prefix)
-        )
+        if compressed_indices.is_meta:
+            compressed_indices_str = "..."
+        else:
+            compressed_indices_str = _tensor_str(
+                compressed_indices, indent + len(compressed_indices_prefix)
+            )
         if compressed_indices.numel() == 0:
             compressed_indices_str += ", size=" + str(
                 tuple(compressed_indices.shape)
             )
+        if compressed_indices.is_meta:
+            compressed_indices_str += ", dtype=" + str(compressed_indices.dtype)
         plain_indices_prefix = f"{pdimname[:3]}_indices=tensor("
         plain_indices = plain_indices_method(self).detach()
-        plain_indices_str = _tensor_str(
-            plain_indices, indent + len(plain_indices_prefix)
-        )
+        if plain_indices.is_meta:
+            plain_indices_str = "..."
+        else:
+            plain_indices_str = _tensor_str(
+                plain_indices, indent + len(plain_indices_prefix)
+            )
         if plain_indices.numel() == 0:
             plain_indices_str += ", size=" + str(tuple(plain_indices.shape))
         values_prefix = "values=tensor("
         values = self.values().detach()
-        values_str = _tensor_str(values, indent + len(values_prefix))
+        if values.is_meta:
+            values_str = "..."
+        else:
+            values_str = _tensor_str(values, indent + len(values_prefix))
         if values.numel() == 0:
             values_str += ", size=" + str(tuple(values.shape))
         tensor_str = (
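A hedged sketch of the resulting printing behavior for meta sparse compressed tensors; with the changes above, index and value contents are rendered as "..." rather than materialized, and the nnz= suffix is omitted:

import torch

r = torch.empty((4, 6), dtype=torch.float64, layout=torch.sparse_csr, device="meta")
# The repr shows crow_indices/col_indices/values as "..." placeholders and
# drops "nnz=", since meta tensors carry no data to format or count.
print(r)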
