1010from torch .autograd import gradcheck
1111from torch .autograd .function import once_differentiable
1212
13- from common import TestCase , run_tests
13+ from common import TestCase , run_tests , skipIfNoLapack
1414from torch .autograd ._functions import *
1515from torch .autograd import Variable , Function
1616
@@ -1415,6 +1415,7 @@ class dont_convert(tuple):
14151415 (Trace , (), ((S , S ),)),
14161416 (Cross , (), ((S , 3 ), (S , 3 ))),
14171417 (Cross , (1 ,), ((S , 3 , S ), (S , 3 , S )), 'dim' ),
1418+ (Inverse , (), ((S , S ),), '' , (), [skipIfNoLapack ]),
14181419 (Clone , (), ((S , M , S ),)),
14191420 (Squeeze , (), ((S , 1 , M , 1 ),)),
14201421 # TODO: enable neg dim checks
@@ -1542,6 +1543,7 @@ class dont_convert(tuple):
15421543 ('trace' , (M , M ), ()),
15431544 ('cross' , (S , 3 ), ((S , 3 ),)),
15441545 ('cross' , (S , 3 , S ), ((S , 3 , S ), 1 ), 'dim' ),
1546+ ('inverse' , (S , S ), (), '' , (), [skipIfNoLapack ]),
15451547 ('clone' , (S , M , S ), ()),
15461548 ('eq' , (S , S , S ), ((S , S , S ),)),
15471549 ('ne' , (S , S , S ), ((S , S , S ),)),
@@ -1617,6 +1619,8 @@ def unpack_variables(args):
16171619
16181620 dim_args_idx = test [4 ] if len (test ) == 5 else []
16191621
1622+ skipTestIf = test [5 ] if len (test ) == 6 else []
1623+
16201624 for dim_perm in product ([- 1 , 1 ], repeat = len (dim_args_idx )):
16211625 test_name = basic_test_name
16221626 new_constructor_args = [arg * dim_perm [dim_args_idx .index (i )] if i in dim_args_idx else arg
@@ -1671,6 +1675,10 @@ def apply_inplace_fn(*input):
16711675 self .assertEqual (inp_i .grad , i .grad )
16721676
16731677 assert not hasattr (TestAutograd , test_name ), 'Two tests have the same name: ' + test_name
1678+
1679+ for skip in skipTestIf :
1680+ do_test = skip (do_test )
1681+
16741682 setattr (TestAutograd , test_name , do_test )
16751683
16761684
@@ -1687,6 +1695,8 @@ def apply_inplace_fn(*input):
16871695
16881696 dim_args_idx = test [4 ] if len (test ) == 5 else []
16891697
1698+ skipTestIf = test [5 ] if len (test ) == 6 else []
1699+
16901700 for dim_perm in product ([- 1 , 1 ], repeat = len (dim_args_idx )):
16911701 test_name = basic_test_name
16921702 new_args = [arg * dim_perm [dim_args_idx .index (i )] if i in dim_args_idx else arg for i , arg in enumerate (args )]
@@ -1726,6 +1736,10 @@ def check(name):
17261736 raise
17271737
17281738 assert not hasattr (TestAutograd , test_name ), 'Two tests have the same name: ' + test_name
1739+
1740+ for skip in skipTestIf :
1741+ do_test = skip (do_test )
1742+
17291743 setattr (TestAutograd , test_name , do_test )
17301744
17311745
0 commit comments