
Commit 036c3f9

colesbury authored and soumith committed

Check for released variables in SavedVariable::unpack() (pytorch#1648)

Fixes pytorch#1288

1 parent 98581b9 commit 036c3f9
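The user-visible effect of this commit can be checked directly from Python: a second backward() through a graph whose buffers were freed by the first call now raises a RuntimeError instead of touching released storage. A minimal repro, mirroring the new test below (it assumes the Variable wrapper and retain_variables flag of this PyTorch generation; later releases renamed the flag to retain_graph):

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    conv = nn.Conv2d(3, 3, 3)
    out = conv(Variable(torch.randn(2, 3, 5, 5)))

    out.sum().backward()  # first backward frees the saved buffers
    out.sum().backward()  # now raises: "Trying to backward through the graph
                          # a second time ... Specify retain_variables=True ..."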

5 files changed, +26 -10 lines changed

test/test_nn.py

Lines changed: 8 additions & 0 deletions

@@ -1321,6 +1321,14 @@ def test_Conv2d_missing_argument(self):
         c = nn.Conv2d(3, 3, 3)
         self.assertRaises(RuntimeError, lambda: c(None))
 
+    def test_Conv2d_backward_twice(self):
+        input = Variable(torch.randn(2, 3, 5, 5))
+        c = nn.Conv2d(3, 3, 3)
+        o1 = c(input)
+        o1.sum().backward()
+        self.assertRaisesRegex(RuntimeError, 'Specify retain_variables=True',
+                               lambda: o1.sum().backward())
+
     @unittest.skipIf(not TEST_CUDA, 'CUDA not available')
     def test_Conv2d_large_workspace(self):
         # These sizes require huge cuDNN workspaces. Make sure we choose a
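As the asserted message says, the supported way to run backward twice over the same graph is to opt in on the first call. A short usage sketch under the same assumptions as the repro above:

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    conv = nn.Conv2d(3, 3, 3)
    out = conv(Variable(torch.randn(2, 3, 5, 5)))

    out.sum().backward(retain_variables=True)  # keep the saved buffers alive
    out.sum().backward()                       # second backward succeeds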

torch/csrc/autograd/functions/convolution.cpp

Lines changed: 3 additions & 3 deletions

@@ -204,9 +204,9 @@ auto ConvBackward::apply(const variable_list& grad_outputs) -> variable_list {
   if (is_padding_neg()) throw std::runtime_error("negative padding is not supported");
   if (is_output_padding_neg()) throw std::runtime_error("negative output_padding is not supported");
 
-  AutoGPU guard(input_.data->getDevice());
-
-  auto input = input_.unpack_data()->contiguous();
+  auto input = input_.unpack_data();
+  AutoGPU guard(input->getDevice());
+  input = input->contiguous();
   std::unique_ptr<Tensor> weight(weight_.unpack_data()->clone_shallow());
   auto bias = bias_.unpack_data();
   auto grad_output = grad_outputs[0]->data->contiguous();
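The reordering above is what makes the new check reachable here: the old code built the AutoGPU guard from input_.data directly, dereferencing a pointer that is null once the buffers have been freed, so a repeated backward could crash before any diagnostic fired. Fetching the tensor through unpack_data() first routes the access through SavedVariable's freed-buffer check (see the toy sketch after the variable.cpp diff below).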

torch/csrc/autograd/python_function.cpp

Lines changed: 2 additions & 6 deletions

@@ -849,9 +849,7 @@ PyObject* THPFunction_register_hook(THPFunction *self, PyObject *hook)
 
 PyObject *THPFunction_saved_tensors(THPFunction *self, void *_unused)
 {
-  THPUtils_assert(!self->has_freed_buffers, "Trying to backward through the "
-      "graph second time, but the buffers have already been freed. Please "
-      "specify retain_variables=True when calling backward for the first time.");
+  THPUtils_assert(!self->has_freed_buffers, ERR_BACKWARD_TWICE);
   if (!self->saved_variables)
     return PyTuple_New(0);

@@ -876,9 +874,7 @@ PyObject *THPFunction_saved_tensors(THPFunction *self, void *_unused)
 
 PyObject *THPFunction_saved_variables(THPFunction *self, void *_unused)
 {
-  THPUtils_assert(!self->has_freed_buffers, "Trying to backward through the "
-      "graph second time, but the buffers have already been freed. Please "
-      "specify retain_variables=True when calling backward for the first time.");
+  THPUtils_assert(!self->has_freed_buffers, ERR_BACKWARD_TWICE);
   if (!self->saved_variables)
     return PyTuple_New(0);

torch/csrc/autograd/variable.cpp

Lines changed: 11 additions & 1 deletion

@@ -60,7 +60,12 @@ auto Variable::get_grad_accumulator() -> std::shared_ptr<Function> {
 }
 
 auto SavedVariable::unpack() -> std::shared_ptr<Variable> {
-  if (!data) return nullptr;
+  if (!data) {
+    if (version) {
+      throw std::runtime_error(ERR_BACKWARD_TWICE);
+    }
+    return nullptr;
+  }
 
   int current_version = **version;
   if (expected_version != current_version) {

@@ -91,4 +96,9 @@ auto SavedVariable::unpack() -> std::shared_ptr<Variable> {
   return new_var;
 }
 
+const char* ERR_BACKWARD_TWICE =
+    "Trying to backward through the graph a second time, but the buffers have "
+    "already been freed. Specify retain_variables=True when calling backward "
+    "the first time.";
+
 }} // namespace torch::autograd
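The version pointer is what disambiguates the two ways data can be null in unpack(): a variable that was saved as empty has no version counter and still yields nullptr, while one whose storage was released by a prior backward keeps its version pointer, and that is the case that now throws ERR_BACKWARD_TWICE. A toy Python model of that branch (names hypothetical, not PyTorch API):

    class SavedTensor:
        """Toy stand-in for the disambiguation in SavedVariable::unpack()."""
        def __init__(self, data, version):
            self.data = data        # None if never saved, or after release
            self.version = version  # present iff something was actually saved

        def unpack(self):
            if self.data is None:
                if self.version is not None:
                    # Saved, then released by a prior backward: the new error path.
                    raise RuntimeError("Trying to backward through the graph "
                                       "a second time, but the buffers have "
                                       "already been freed.")
                return None         # nothing was ever saved: legal, stay silent
            return self.data

    assert SavedTensor(None, version=None).unpack() is None  # empty save: fine
    try:
        SavedTensor(None, version=7).unpack()                # released: raises
    except RuntimeError as e:
        print(e)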

torch/csrc/autograd/variable.h

Lines changed: 2 additions & 0 deletions

@@ -11,6 +11,8 @@
 
 namespace torch { namespace autograd {
 
+extern const char* ERR_BACKWARD_TWICE;
+
 struct Variable : std::enable_shared_from_this<Variable> {
 
 struct SavedVariable {
