Skip to content

Commit 278a0e1

Browse files
davidberard98pytorchmergebot
authored andcommitted
[NestedTensor] Support binary pointwise ops with >2 inputs (if inputs are non-tensors) (pytorch#119419)
It should usually be safe to run pointwise binary ops with >2 inputs. e.g. threshold_backward(tensor, tensor, scalar): we just operate on the values of the nested tensors, and pass in the other args as-is. Pull Request resolved: pytorch#119419 Approved by: https://github.com/soulitzer
1 parent cd9a193 commit 278a0e1

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

test/test_nestedtensor.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3349,6 +3349,21 @@ def grad_test_func(t, *ts):
33493349
t = torch.rand(t_size, requires_grad=True, device=device, dtype=torch.float64)
33503350
gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False)
33513351

3352+
def test_threshold_backward(self, device):
3353+
ts1 = self._get_list_for_jagged_tensor(((2, 3, 4), 16), device=device, requires_grad=False)
3354+
ts2 = self._get_list_for_jagged_tensor(((2, 3, 4), 16), device=device, requires_grad=False)
3355+
3356+
nt1, offsets = jagged_from_list(ts1, None)
3357+
nt2, offsets = jagged_from_list(ts2, offsets)
3358+
buf1 = buffer_from_jagged(nt1).detach().clone()
3359+
buf2 = buffer_from_jagged(nt2).detach().clone()
3360+
3361+
res_nt = torch.ops.aten.threshold_backward(nt1, nt2, 0.0)
3362+
res_dense = torch.ops.aten.threshold_backward(buf1, buf2, 0.0)
3363+
3364+
self.assertEqual(res_dense, buffer_from_jagged(res_nt))
3365+
3366+
33523367
@parametrize("keepdim", [False, True])
33533368
def test_sum_int_DimList(self, device, keepdim):
33543369
# (B, j0, 3, 4)

torch/nested/_internal/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
194194
check_schema("self: jt_all, ...", func, *args, **kwargs)
195195
return functools.partial(jagged_unary_pointwise, func)
196196
elif num_tensor_args == 2:
197-
check_schema("lhs: any, rhs: any", func, *args, **kwargs)
197+
check_schema("lhs: any, rhs: any, ...", func, *args, **kwargs)
198198
return functools.partial(jagged_binary_pointwise, func)
199199

200200
return None

0 commit comments

Comments
 (0)