@@ -240,6 +240,22 @@ class PyInterpreterHolder {
240240};
241241PyInterpreterHolder self_interpreter;
242242
243+ c10::TensorImpl::SizesStridesPolicy parseSizesStridesPolicyArgument (
244+ c10::string_view arg) {
245+ if (arg == " strides" ) {
246+ return c10::TensorImpl::SizesStridesPolicy::CustomStrides;
247+ }
248+
249+ if (arg == " sizes" ) {
250+ return c10::TensorImpl::SizesStridesPolicy::CustomSizes;
251+ }
252+
253+ TORCH_CHECK_VALUE (
254+ false ,
255+ " Unknown sizes_strides_policy: " ,
256+ arg,
257+ " ; expected 'strides' or 'sizes'" );
258+ }
243259} // anonymous namespace
244260
245261c10::impl::PyInterpreter* getPyInterpreter () {
@@ -538,7 +554,7 @@ static PyObject* THPVariable_as_subclass(PyObject* _self, PyObject* args, PyObje
538554static PyObject* THPVariable_make_subclass (PyObject* _ignored, PyObject* args, PyObject* kwargs) {
539555 HANDLE_TH_ERRORS
540556 static PythonArgParser parser ({
541- " _make_subclass(PyObject* cls, Tensor data, bool require_grad=False, *, bool dispatch_strides=False , bool dispatch_device=False)" ,
557+ " _make_subclass(PyObject* cls, Tensor data, bool require_grad=False, *, c10::string_view? dispatch_sizes_strides_policy=None , bool dispatch_device=False)" ,
542558 });
543559 ParsedArgs<5 > parsed_args{};
544560 auto r = parser.parse (args, kwargs, parsed_args);
@@ -559,8 +575,10 @@ static PyObject* THPVariable_make_subclass(PyObject* _ignored, PyObject* args, P
559575 // ```
560576 data.unsafeGetTensorImpl ()->set_allow_tensor_metadata_change (true );
561577 data.set_requires_grad (r.toBool (2 ));
562- if (r.toBool (3 )) {
563- data.unsafeGetTensorImpl ()->set_sizes_strides_policy (c10::TensorImpl::SizesStridesPolicy::CustomStrides);
578+ const auto sizes_strides_policy = r.stringViewOptional (3 );
579+ if (sizes_strides_policy.has_value ()) {
580+ data.unsafeGetTensorImpl ()->set_sizes_strides_policy (
581+ parseSizesStridesPolicyArgument (*sizes_strides_policy));
564582 }
565583 if (r.toBool (4 )) {
566584 data.unsafeGetTensorImpl ()->set_custom_device (true );
@@ -577,7 +595,10 @@ static PyObject* THPVariable_make_wrapper_subclass(PyObject*, PyObject* args, Py
577595 // NB: pin_memory doesn't actually do anything
578596 // TODO: strides variant?
579597 static PythonArgParser parser ({
580- " _make_wrapper_subclass(PyObject* cls, IntArrayRef size, *, IntArrayRef? strides=None, int64_t? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, bool dispatch_strides=False, bool dispatch_device=False)" ,
598+ " _make_wrapper_subclass(PyObject* cls, IntArrayRef size, *, IntArrayRef? strides=None, "
599+ " int64_t? storage_offset=None, MemoryFormat? memory_format=None, ScalarType dtype=None, "
600+ " Layout layout=torch.strided, Device device=None, bool pin_memory=False, bool requires_grad=False, "
601+ " c10::string_view? dispatch_sizes_strides_policy=None, bool dispatch_device=False)" ,
581602 });
582603 ParsedArgs<12 > parsed_args{};
583604 auto r = parser.parse (args, kwargs, parsed_args);
@@ -619,8 +640,10 @@ static PyObject* THPVariable_make_wrapper_subclass(PyObject*, PyObject* args, Py
619640 .make_tensor ();
620641 data.set_requires_grad (r.toBool (9 ));
621642
622- if (r.toBool (10 )) {
623- data.unsafeGetTensorImpl ()->set_sizes_strides_policy (c10::TensorImpl::SizesStridesPolicy::CustomStrides);
643+ const auto sizes_strides_policy = r.stringViewOptional (10 );
644+ if (sizes_strides_policy.has_value ()) {
645+ data.unsafeGetTensorImpl ()->set_sizes_strides_policy (
646+ parseSizesStridesPolicyArgument (*sizes_strides_policy));
624647 }
625648 if (r.toBool (11 )) {
626649 data.unsafeGetTensorImpl ()->set_custom_device (true );
0 commit comments