Skip to content

Commit e1d257b

Browse files
colesburysoumith
authored andcommitted
Fix segfault in autograd: (pytorch#1644)
* Fix segfault in autograd: 1) Every "output" variable must have a grad_fn or grad_accumulator 2) compute_partial_exec_callbacks uses Python errors * assertRaisesRegexp was renamed assertRaisesRegex in 3.2 * Use HANDLE_TH_ERRORS macro
1 parent 3d38e4f commit e1d257b

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

test/common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,10 @@ def assertObjectIn(self, obj, iterable):
248248
return
249249
raise AssertionError("object not found in iterable")
250250

251+
if sys.version_info < (3, 2):
252+
# assertRaisesRegexp renamed assertRaisesRegex in 3.2
253+
assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
254+
251255

252256
def download_file(url, path, binary=True):
253257
if sys.version_info < (3,):

test/test_autograd.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,16 +117,12 @@ def backward(ctx, grad_output):
117117
grad_output * ctx.scalar + grad_output * t1)
118118

119119
x, y = self._function_test(MyFunction)
120-
x_grad_desc = graph_desc(x.grad.grad_fn)
121-
y_grad_desc = graph_desc(y.grad.grad_fn)
122120
self.assertEqual(graph_desc(x.grad.grad_fn),
123121
'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
124122
self.assertEqual(graph_desc(y.grad.grad_fn),
125123
'Identity(Error(AccumulateGrad(), None, AccumulateGrad()))')
126124

127125
def test_accumulate_grad(self):
128-
import sys
129-
130126
grad_output = Variable(torch.ones(5, 5))
131127
for start_volatile, end_volatile in product((True, False), repeat=2):
132128
go1 = grad_output.data if start_volatile else grad_output
@@ -248,6 +244,20 @@ def hook(*grads):
248244
self.assertFalse(hook_called[0])
249245
self.assertIsNone(x.grad)
250246

247+
def test_grad_badcalls(self):
248+
x = Variable(torch.ones(1))
249+
y = x ** 2
250+
with self.assertRaisesRegex(RuntimeError, 'does not require grad'):
251+
torch.autograd.grad(x, y)
252+
with self.assertRaisesRegex(RuntimeError, 'not have been used in the graph'):
253+
torch.autograd.grad(y, x)
254+
255+
x = Variable(torch.ones(1), requires_grad=True)
256+
y = x ** 2
257+
torch.autograd.grad(y, x) # this should succeed now
258+
with self.assertRaisesRegex(RuntimeError, 'unreachable'):
259+
torch.autograd.grad(x, y)
260+
251261
def test_hooks(self):
252262
x = Variable(torch.ones(5, 5), requires_grad=True)
253263
y = Variable(torch.ones(5, 5) * 4, requires_grad=True)

torch/csrc/autograd/python_engine.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ void compute_partial_exec_callbacks(const function_list& roots,
106106

107107
PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwargs)
108108
{
109+
HANDLE_TH_ERRORS
109110
PyObject *variables = NULL;
110111
PyObject *grad_variables = NULL;
111112
unsigned char keep_graph = 0;
@@ -137,6 +138,7 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
137138
THPUtils_assert(!variable->is_volatile,
138139
"element %d of variables tuple is volatile", i);
139140
auto grad_fn = variable->grad_fn ? variable->grad_fn : variable->get_grad_accumulator();
141+
THPUtils_assert(grad_fn, "element %d of variables tuple does not require grad", i);
140142
int output_nr = variable->grad_fn ? variable->output_nr : 0;
141143
roots[i] = std::make_pair<>(std::move(grad_fn), output_nr);
142144

@@ -201,16 +203,14 @@ PyObject *THPEngine_run_backward(THPEngine *self, PyObject *args, PyObject *kwar
201203
} catch (python_error &e) {
202204
e.restore();
203205
return nullptr;
204-
} catch (const std::exception &e) {
205-
PyErr_SetString(PyExc_RuntimeError, e.what());
206-
return nullptr;
207206
}
208207

209208
if (ctx.outputs) {
210209
return ctx.outputs.release();
211210
} else {
212211
Py_RETURN_NONE;
213212
}
213+
END_HANDLE_TH_ERRORS
214214
}
215215

216216
PyObject *THPEngine_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)

0 commit comments

Comments
 (0)