
Commit 75e0df2

fmassa authored and soumith committed

Add Inverse to autograd (pytorch#1670)

* Add Inverse to autograd
* Add SkipTest to autograd tests

1 parent 565bf71 commit 75e0df2

File tree: 3 files changed (+32, -1 lines)

* test/test_autograd.py
* torch/autograd/_functions/linalg.py
* torch/autograd/variable.py


test/test_autograd.py (15 additions, 1 deletion)

@@ -10,7 +10,7 @@
 from torch.autograd import gradcheck
 from torch.autograd.function import once_differentiable

-from common import TestCase, run_tests
+from common import TestCase, run_tests, skipIfNoLapack
 from torch.autograd._functions import *
 from torch.autograd import Variable, Function

@@ -1415,6 +1415,7 @@ class dont_convert(tuple):
     (Trace, (), ((S, S),)),
     (Cross, (), ((S, 3), (S, 3))),
     (Cross, (1,), ((S, 3, S), (S, 3, S)), 'dim'),
+    (Inverse, (), ((S, S),), '', (), [skipIfNoLapack]),
     (Clone, (), ((S, M, S),)),
     (Squeeze, (), ((S, 1, M, 1),)),
     # TODO: enable neg dim checks
@@ -1542,6 +1543,7 @@ class dont_convert(tuple):
     ('trace', (M, M), ()),
     ('cross', (S, 3), ((S, 3),)),
     ('cross', (S, 3, S), ((S, 3, S), 1), 'dim'),
+    ('inverse', (S, S), (), '', (), [skipIfNoLapack]),
     ('clone', (S, M, S), ()),
     ('eq', (S, S, S), ((S, S, S),)),
     ('ne', (S, S, S), ((S, S, S),)),
@@ -1617,6 +1619,8 @@ def unpack_variables(args):

     dim_args_idx = test[4] if len(test) == 5 else []

+    skipTestIf = test[5] if len(test) == 6 else []
+
     for dim_perm in product([-1, 1], repeat=len(dim_args_idx)):
         test_name = basic_test_name
         new_constructor_args = [arg * dim_perm[dim_args_idx.index(i)] if i in dim_args_idx else arg
@@ -1671,6 +1675,10 @@ def apply_inplace_fn(*input):
             self.assertEqual(inp_i.grad, i.grad)

         assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name
+
+        for skip in skipTestIf:
+            do_test = skip(do_test)
+
         setattr(TestAutograd, test_name, do_test)


@@ -1687,6 +1695,8 @@ def apply_inplace_fn(*input):

     dim_args_idx = test[4] if len(test) == 5 else []

+    skipTestIf = test[5] if len(test) == 6 else []
+
     for dim_perm in product([-1, 1], repeat=len(dim_args_idx)):
         test_name = basic_test_name
         new_args = [arg * dim_perm[dim_args_idx.index(i)] if i in dim_args_idx else arg for i, arg in enumerate(args)]
@@ -1726,6 +1736,10 @@ def check(name):
             raise

         assert not hasattr(TestAutograd, test_name), 'Two tests have the same name: ' + test_name
+
+        for skip in skipTestIf:
+            do_test = skip(do_test)
+
         setattr(TestAutograd, test_name, do_test)
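The test tuples gain an optional sixth element: a list of decorators that the test generator applies to the generated test function before registering it on TestAutograd, which is how the Inverse entries get skipped on builds without LAPACK. Below is a minimal sketch of that wrapping mechanism, not part of the commit; the decorator name check_lapack_or_skip is a hypothetical stand-in for the real skipIfNoLapack defined in test/common.py, which is not shown in this diff.

import unittest

def check_lapack_or_skip(fn):
    # Hypothetical stand-in for skipIfNoLapack from test/common.py:
    # skip the wrapped test when a LAPACK-backed op is unavailable.
    def wrapper(*args, **kwargs):
        try:
            import torch
            torch.inverse(torch.eye(2))  # any LAPACK-backed op works as a probe
        except Exception:
            raise unittest.SkipTest('PyTorch compiled without LAPACK')
        return fn(*args, **kwargs)
    return wrapper

# Mirrors the generator loop in test_autograd.py: decorators come from test[5] when present.
test = ('inverse', (5, 5), (), '', (), [check_lapack_or_skip])
skipTestIf = test[5] if len(test) == 6 else []

def do_test(self):
    pass  # the generated gradcheck body would go here

for skip in skipTestIf:
    do_test = skip(do_test)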

torch/autograd/_functions/linalg.py (14 additions, 0 deletions)

@@ -69,3 +69,17 @@ def backward(self, grad_output):
         grad_input = torch.cross(other, grad_output, self.dim)
         grad_other = torch.cross(grad_output, input, self.dim)
         return grad_input, grad_other
+
+
+class Inverse(Function):
+
+    @staticmethod
+    def forward(ctx, input):
+        inverse = torch.inverse(input)
+        ctx.save_for_backward(inverse)
+        return inverse
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        inverse, = ctx.saved_variables
+        return -torch.mm(inverse.t(), torch.mm(grad_output, inverse.t()))
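The backward pass reuses the inverse saved in forward: from the identity d(A^-1) = -A^-1 dA A^-1, the gradient of a scalar loss with upstream gradient G with respect to A is -A^-T G A^-T, which is exactly the -torch.mm(inverse.t(), torch.mm(grad_output, inverse.t())) expression above. Below is a standalone finite-difference sanity check of that formula; it is not part of the commit and uses plain tensors with today's torch API rather than the Variable-era Function API in this diff.

import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.double)
A = A @ A.t() + 4 * torch.eye(4, dtype=torch.double)  # keep A well conditioned
G = torch.randn(4, 4, dtype=torch.double)              # plays the role of grad_output
inv = torch.inverse(A)

# Analytic gradient of L = sum(G * A^-1) w.r.t. A, per the formula in backward()
grad_analytic = -inv.t() @ G @ inv.t()

# Finite-difference estimate for a single entry of A
eps = 1e-6
i, j = 1, 2
A_pert = A.clone()
A_pert[i, j] += eps
grad_numeric = ((torch.inverse(A_pert) - inv) * G).sum() / eps

print(float(grad_numeric), float(grad_analytic[i, j]))  # the two values should agree closely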

torch/autograd/variable.py (3 additions, 0 deletions)

@@ -716,6 +716,9 @@ def trace(self):
     def cross(self, other, dim=-1):
         return Cross(dim)(self, other)

+    def inverse(self):
+        return Inverse.apply(self)
+
     def multinomial(self, num_samples=1, with_replacement=False):
         return Multinomial(num_samples, with_replacement)(self)

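With the new method in place, Variable.inverse() routes through Inverse.apply, so the result takes part in autograd like any other op. A minimal usage sketch follows; it is written against the Variable API of this era (Variable still exists in current PyTorch as a thin wrapper around Tensor, so the same lines should still run).

import torch
from torch.autograd import Variable

# Double precision and a diagonal shift keep the inverse numerically well behaved.
A = Variable(torch.randn(3, 3).double() + 3 * torch.eye(3).double(), requires_grad=True)

B = A.inverse()   # dispatches to Inverse.apply(A)
loss = B.sum()
loss.backward()

print(A.grad)     # equals -A^-T @ ones(3, 3) @ A^-T, per the backward formula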
