
Commit 6adadba

jbschlosser authored and pytorchmergebot committed
Fix jagged NT softmax semantics (pytorch#119459)
Before: `softmax` definition uses `jagged_unary_pointwise()` (wrong)
After: `softmax` impl adjusts the `dim` arg to account for the difference in dimensionality between the outer NT and the NT's `_values`

Pull Request resolved: pytorch#119459
Approved by: https://github.com/soulitzer
1 parent 278a0e1 commit 6adadba
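Concretely, for a jagged NT of shape (B, *, D), the underlying `_values` tensor packs the batch and ragged dims into a single dim, so a softmax over the NT's dim=2 must run over dim=1 of `_values`. A minimal sketch of the fixed semantics, mirroring the test added below (hypothetical component shapes; assumes a PyTorch build with `layout=torch.jagged` support):

import torch

# Jagged NT with shape (3, *, 5): three components of ragged length
nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(4, 5), torch.randn(3, 5)],
    layout=torch.jagged,
)

# After this fix, softmax over the NT's dim=2 matches a per-component
# softmax over dim=1 (the batch dim disappears on unbind()).
out = nt.softmax(dim=2)
for comp_in, comp_out in zip(nt.unbind(), out.unbind()):
    torch.testing.assert_close(comp_in.softmax(dim=1), comp_out)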

File tree

2 files changed: +28 −2 lines changed


test/test_nestedtensor.py

Lines changed: 16 additions & 0 deletions
@@ -3182,6 +3182,22 @@ def test_split_with_sizes(self, device):
         ):
             torch.split(nt, [1, 2], 1)
 
+    def test_softmax(self, device):
+        nt = random_nt_from_dims(
+            [3, None, 5], device=device, dtype=torch.float32, layout=torch.jagged)
+
+        # operate on dim=2
+        output = nt.softmax(dim=2)
+        for in_component, out_component in zip(nt.unbind(), output.unbind()):
+            # dim=2 -> dim=1 after unbind
+            self.assertEqual(in_component.softmax(dim=1), out_component)
+
+        # operate on dim=-1
+        output2 = nt.softmax(dim=-1)
+        self.assertEqual(output, output2)
+        for in_component, out_component in zip(nt.unbind(), output2.unbind()):
+            self.assertEqual(in_component.softmax(dim=-1), out_component)
+
     def test_views_inherit_ragged_dim(self, device):
         # view
         nt = random_nt_from_dims(
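For context on why the old routing through `jagged_unary_pointwise` was wrong for a dim-ful op like softmax: forwarding the NT-level `dim` untouched to the 2D `_values` tensor puts it out of range. A hypothetical stand-in illustrating the failure mode (assumes the helper simply applied `func` to `_values` with unchanged kwargs):

import torch

# Stand-in for nt._values of a (3, *, 5) jagged NT: the batch and ragged
# dims are packed into one, so values is 2D even though the NT is 3D.
values = torch.randn(9, 5)
try:
    values.softmax(dim=2)  # the NT-level dim, forwarded unadjusted
except IndexError as e:
    print(e)  # Dimension out of range (expected to be in range of [-2, 1], but got 2)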

torch/nested/_internal/ops.py

Lines changed: 12 additions & 2 deletions
@@ -447,9 +447,19 @@ def to_copy_default(func, *args, **kwargs):
 )(jagged_unary_pointwise)
 
 
-register_jagged_func(
+@register_jagged_func(
     torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any"
-)(jagged_unary_pointwise)
+)
+def _softmax_default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    dim = new_kwargs["dim"]
+    new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), dim, "softmax")
+
+    return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
 
 
 @register_jagged_func(
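`_wrap_jagged_dim` itself is not shown in this diff; a simplified sketch of the dim arithmetic it presumably performs, assuming the ragged dim is dim 1 and that `_values` drops the batch dim (an illustration, not the actual PyTorch helper):

# Simplified sketch; the real _wrap_jagged_dim lives in
# torch/nested/_internal/ops.py and handles more cases.
def wrap_jagged_dim(ndim: int, dim: int, op_name: str) -> int:
    wrapped = dim if dim >= 0 else dim + ndim  # canonicalize negative dims
    # Outer dims 0 (batch) and 1 (ragged) are packed into values' dim 0,
    # so operating over them component-wise is rejected here.
    if not 2 <= wrapped < ndim:
        raise RuntimeError(f"{op_name}(): not supported for dim={dim}")
    return wrapped - 1  # values drops the batch dim: outer d -> values d - 1

# For a (B, *, 5) jagged NT (ndim=3), dim=2 and dim=-1 both map to 1,
# which is what makes output and output2 equal in the test above:
assert wrap_jagged_dim(3, 2, "softmax") == 1
assert wrap_jagged_dim(3, -1, "softmax") == 1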

0 commit comments