Skip to content

Commit 75a8c8a

Browse files
msaroufimpytorchmergebot
authored andcommitted
softshrink lowering (pytorch#105603)
Fixes pytorch#105563 Pull Request resolved: pytorch#105603 Approved by: https://github.com/Chillee
1 parent 6560750 commit 75a8c8a

File tree

2 files changed

+13
-1
lines changed

2 files changed

+13
-1
lines changed

test/inductor/test_torchinductor.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2035,6 +2035,19 @@ def fn(a):
20352035
with self.assertRaisesRegex(RuntimeError, ""):
20362036
fn(torch.randn(1, 5))
20372037

2038+
def test_softshrink_backward(self):
2039+
grad_output = torch.randn(1)
2040+
lambd = 0.5
2041+
2042+
def fn(a, grad_output, lambd):
2043+
a = a.cos()
2044+
return torch.ops.aten.softshrink_backward(grad_output, a, lambd)
2045+
2046+
self.common(
2047+
fn,
2048+
(torch.randn(10), grad_output, lambd),
2049+
)
2050+
20382051
def test_inductor_assert(self):
20392052
@torch._dynamo.optimize("inductor", dynamic=True)
20402053
def fn(a):

torch/_inductor/lowering.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1801,7 +1801,6 @@ def apply_constraint(arg, fx_arg):
18011801
make_fallback(aten.reflection_pad1d_backward)
18021802
make_fallback(aten.replication_pad1d_backward)
18031803
make_fallback(aten.soft_margin_loss_backward, warn=False)
1804-
make_fallback(aten.softshrink_backward, warn=False)
18051804
make_fallback(aten.linalg_pinv.atol_rtol_tensor)
18061805
make_fallback(aten.segment_reduce.default)
18071806
make_fallback(aten._segment_reduce_backward.default)

0 commit comments

Comments
 (0)