Skip to content

Commit eff1e48

Browse files
pearupytorchmergebot
authored andcommitted
Add sparse COO/CSR/CSC/BSR/BSC meta tensor input support to torch.sum (pytorch#121673)
As in the title. Fixes an issue reported in pytorch#117907 (comment) Pull Request resolved: pytorch#121673 Approved by: https://github.com/cpuhrsch
1 parent 7ce42eb commit eff1e48

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

aten/src/ATen/native/ReduceOpsUtils.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,13 @@ static void resize_reduction(
368368
DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim());
369369
maybe_wrap_dims(dims_, self.dim());
370370
auto shape = get_reduction_shape(self, dims_, keepdim, allow_empty_dims);
371-
meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
371+
if (self.layout() == kStrided) {
372+
meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
373+
} else if (shape.size() == 0) {
374+
meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype).layout(kStrided));
375+
} else {
376+
TORCH_CHECK(false, "resize_reduction: support for output with ", self.layout(), " layout is not implemented yet");
377+
}
372378
namedinference::propagate_names_for_reduction(
373379
meta.maybe_get_output(), self, dims_, keepdim);
374380
}

aten/src/ATen/native/native_functions.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5695,8 +5695,8 @@
56955695
variants: function, method
56965696
dispatch:
56975697
CompositeExplicitAutograd: sum
5698-
SparseCPU, SparseCUDA: sum_coo
5699-
SparseCsrCPU, SparseCsrCUDA: sum_csr
5698+
SparseCPU, SparseCUDA, SparseMeta: sum_coo
5699+
SparseCsrCPU, SparseCsrCUDA, SparseCsrMeta: sum_csr
57005700
autogen: sum.out
57015701

57025702
- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor

test/test_sparse.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4446,6 +4446,18 @@ def test_zeros_like_fake(self, dtype, layout):
44464446
result = torch.zeros_like(f, device=f.fake_device)
44474447
self.assertEqual(result, expected)
44484448

4449+
@all_sparse_layouts('layout', include_strided=False)
4450+
@parametrize("dtype", [torch.float64])
4451+
def test_sum_meta(self, dtype, layout):
4452+
device = 'cpu'
4453+
index_dtype = torch.int64
4454+
for t in self.generate_simple_inputs(layout, device=device, dtype=dtype, index_dtype=index_dtype):
4455+
m = t.to(device='meta')
4456+
r = torch.sum(m)
4457+
self.assertEqual(r.layout, torch.strided)
4458+
self.assertTrue(r.is_meta)
4459+
self.assertEqual(r.shape, ())
4460+
44494461

44504462
class _SparseDataset(torch.utils.data.Dataset):
44514463
# An utility class used in TestSparseAny.test_dataloader method.

0 commit comments

Comments
 (0)