Skip to content

Commit 876c359

Browse files
suopytorchmergebot
authored andcommitted
Generalize sizes and strides policy on _make_wrapper_subclass
Previously, there was a `dispatch_strides` boolean arg. Change this to a string argument that directly maps onto `SizesStridesPolicy`. Pull Request resolved: pytorch#78646 Approved by: https://github.com/ezyang
1 parent 64a01f1 commit 876c359

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

test/test_python_dispatch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1347,7 +1347,7 @@ def test_is_contiguous_slow_path(self):
13471347
class ExampleTensor1(torch.Tensor):
13481348
@staticmethod
13491349
def __new__(cls, data, wrapper):
1350-
return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_strides=True)
1350+
return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="strides")
13511351

13521352
@classmethod
13531353
def __torch_dispatch__(cls, func, types, args, kwargs):
@@ -1356,7 +1356,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
13561356
class ExampleTensor2(torch.Tensor):
13571357
@staticmethod
13581358
def __new__(cls, data, wrapper):
1359-
return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_strides=True)
1359+
return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="strides")
13601360

13611361
@classmethod
13621362
def __torch_dispatch__(cls, func, types, args, kwargs):
@@ -1367,7 +1367,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
13671367
class ExampleTensor3(torch.Tensor):
13681368
@staticmethod
13691369
def __new__(cls, data, wrapper):
1370-
return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_strides=True)
1370+
return TestPythonDispatch.subclass_helper(cls, data, wrapper, dispatch_sizes_strides_policy="strides")
13711371

13721372
@classmethod
13731373
def __torch_dispatch__(cls, func, types, args, kwargs):

torch/csrc/autograd/python_variable.cpp

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,22 @@ class PyInterpreterHolder {
240240
};
241241
PyInterpreterHolder self_interpreter;
242242

243+
c10::TensorImpl::SizesStridesPolicy parseSizesStridesPolicyArgument(
244+
c10::string_view arg) {
245+
if (arg == "strides") {
246+
return c10::TensorImpl::SizesStridesPolicy::CustomStrides;
247+
}
248+
249+
if (arg == "sizes") {
250+
return c10::TensorImpl::SizesStridesPolicy::CustomSizes;
251+
}
252+
253+
TORCH_CHECK_VALUE(
254+
false,
255+
"Unknown sizes_strides_policy: ",
256+
arg,
257+
"; expected 'strides' or 'sizes'");
258+
}
243259
} // anonymous namespace
244260

245261
c10::impl::PyInterpreter* getPyInterpreter() {
@@ -538,7 +554,7 @@ static PyObject* THPVariable_as_subclass(PyObject* _self, PyObject* args, PyObje
538554
static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, PyObject* kwargs) {
539555
HANDLE_TH_ERRORS
540556
static PythonArgParser parser({
541-
"_make_subclass(PyObject* cls, Tensor data, bool require_grad=False, *, bool dispatch_strides=False, bool dispatch_device=False)",
557+
"_make_subclass(PyObject* cls, Tensor data, bool require_grad=False, *, c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False)",
542558
});
543559
ParsedArgs<5> parsed_args{};
544560
auto r = parser.parse(args, kwargs, parsed_args);
@@ -559,8 +575,10 @@ static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, P
559575
// ```
560576
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
561577
data.set_requires_grad(r.toBool(2));
562-
if (r.toBool(3)) {
563-
data.unsafeGetTensorImpl()->set_sizes_strides_policy(c10::TensorImpl::SizesStridesPolicy::CustomStrides);
578+
const auto sizes_strides_policy = r.stringViewOptional(3);
579+
if (sizes_strides_policy.has_value()) {
580+
data.unsafeGetTensorImpl()->set_sizes_strides_policy(
581+
parseSizesStridesPolicyArgument(*sizes_strides_policy));
564582
}
565583
if (r.toBool(4)) {
566584
data.unsafeGetTensorImpl()->set_custom_device(true);
@@ -577,7 +595,10 @@ static PyObject* THPVariable_make_wrapper_subclass(PyObject*, PyObject* args, Py
577595
// NB: pin_memory doesn't actually do anything
578596
// TODO: strides variant?
579597
static PythonArgParser parser({
580-
"_make_wrapper_subclass(PyObject* cls, IntArrayRef size, *, IntArrayRef? strides=None, int64_t? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, bool dispatch_strides=False, bool dispatch_device=False)",
598+
"_make_wrapper_subclass(PyObject* cls, IntArrayRef size, *, IntArrayRef? strides=None, "
599+
"int64_t? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, "
600+
"Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, "
601+
"c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False)",
581602
});
582603
ParsedArgs<12> parsed_args{};
583604
auto r = parser.parse(args, kwargs, parsed_args);
@@ -619,8 +640,10 @@ static PyObject* THPVariable_make_wrapper_subclass(PyObject*, PyObject* args, Py
619640
.make_tensor();
620641
data.set_requires_grad(r.toBool(9));
621642

622-
if (r.toBool(10)) {
623-
data.unsafeGetTensorImpl()->set_sizes_strides_policy(c10::TensorImpl::SizesStridesPolicy::CustomStrides);
643+
const auto sizes_strides_policy = r.stringViewOptional(10);
644+
if (sizes_strides_policy.has_value()) {
645+
data.unsafeGetTensorImpl()->set_sizes_strides_policy(
646+
parseSizesStridesPolicyArgument(*sizes_strides_policy));
624647
}
625648
if (r.toBool(11)) {
626649
data.unsafeGetTensorImpl()->set_custom_device(true);

0 commit comments

Comments
 (0)