Skip to content

Commit a86adf4

Browse files
committed
Fix comparison functions
1 parent 1c304a9 commit a86adf4

File tree

4 files changed

+39
-28
lines changed

4 files changed

+39
-28
lines changed

test/test_autograd.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,24 @@ def test_inplace(self):
682682
x.add_(2)
683683
self.assertRaises(RuntimeError, lambda: z.backward(torch.ones(5, 5)))
684684

685+
def test_mark_non_differentiable(self):
686+
class MyFunction(Function):
687+
@staticmethod
688+
def forward(ctx, input):
689+
output = input > 0
690+
ctx.mark_non_differentiable(output)
691+
return output
692+
693+
@staticmethod
694+
def backward(ctx, grad_output):
695+
return (grad_output * 0).type(torch.DoubleTensor)
696+
697+
x = Variable(torch.randn(5, 5), requires_grad=True)
698+
mask = MyFunction.apply(x)
699+
self.assertFalse(mask.requires_grad)
700+
y = x.masked_fill(mask, 0)
701+
y.sum().backward()
702+
685703
def test_shared_storage(self):
686704
x = Variable(torch.ones(5, 5))
687705
y = x.t()

torch/autograd/_functions/compare.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,22 @@
33
from ..function import Function
44

55

6+
# TODO: once Cpp-style functions are implemented we can detach a and b
7+
# before calling forward.
68
class _CompareOp(Function):
79

8-
def __init__(self, scalar=None):
9-
super(_CompareOp, self).__init__()
10-
self.scalar = scalar
11-
12-
def forward(self, tensor1, tensor2=None):
13-
other = tensor2 if tensor2 is not None else self.scalar
14-
mask = getattr(tensor1, self.fn_name)(other)
15-
self.mark_non_differentiable(mask)
10+
@classmethod
11+
def forward(cls, ctx, a, b):
12+
ctx.b_tensor = torch.is_tensor(b)
13+
mask = getattr(a, cls.fn_name)(b)
14+
ctx.mark_non_differentiable(mask)
1615
return mask
1716

17+
@staticmethod
18+
def backward(ctx, grad_output):
19+
grad_input = grad_output * 0
20+
return grad_input, (grad_input if ctx.b_tensor else None)
21+
1822

1923
class Eq(_CompareOp):
2024
fn_name = 'eq'

torch/autograd/function.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,9 @@ class FunctionMeta(type):
9494

9595
def __init__(cls, name, bases, attrs):
9696
for super_cls in cls.mro():
97-
if 'forward' in super_cls.__dict__:
98-
has_static_forward = isinstance(super_cls.__dict__['forward'], staticmethod)
97+
forward = super_cls.__dict__.get('forward')
98+
if forward is not None:
99+
has_static_forward = isinstance(forward, staticmethod) or isinstance(forward, classmethod)
99100
break
100101

101102
setattr(cls, '_is_legacy', not has_static_forward)

torch/autograd/variable.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -706,40 +706,28 @@ def bernoulli(self):
706706
return Bernoulli()(self)
707707

708708
def eq(self, other):
709-
if isinstance(other, Variable):
710-
return Eq()(self, other)
711709
assert not torch.is_tensor(other), "can't compare Variable and tensor"
712-
return Eq(other)(self)
710+
return Eq.apply(self, other)
713711

714712
def ne(self, other):
715-
if isinstance(other, Variable):
716-
return Ne()(self, other)
717713
assert not torch.is_tensor(other), "can't compare Variable and tensor"
718-
return Ne(other)(self)
714+
return Ne.apply(self, other)
719715

720716
def gt(self, other):
721-
if isinstance(other, Variable):
722-
return Gt()(self, other)
723717
assert not torch.is_tensor(other), "can't compare Variable and tensor"
724-
return Gt(other)(self)
718+
return Gt.apply(self, other)
725719

726720
def ge(self, other):
727-
if isinstance(other, Variable):
728-
return Ge()(self, other)
729721
assert not torch.is_tensor(other), "can't compare Variable and tensor"
730-
return Ge(other)(self)
722+
return Ge.apply(self, other)
731723

732724
def lt(self, other):
733-
if isinstance(other, Variable):
734-
return Lt()(self, other)
735725
assert not torch.is_tensor(other), "can't compare Variable and tensor"
736-
return Lt(other)(self)
726+
return Lt.apply(self, other)
737727

738728
def le(self, other):
739-
if isinstance(other, Variable):
740-
return Le()(self, other)
741729
assert not torch.is_tensor(other), "can't compare Variable and tensor"
742-
return Le(other)(self)
730+
return Le.apply(self, other)
743731

744732
def __add__(self, other):
745733
return self.add(other)

0 commit comments

Comments
 (0)