
Commit 35cf380

Improve output wrapping logic in autograd
1 parent 3a7e068 commit 35cf380

File tree: 5 files changed (93 additions, 60 deletions)


test/test_autograd.py

Lines changed: 7 additions & 0 deletions
@@ -902,6 +902,13 @@ def test_leaf_assignment(self):
         self.assertEqual(y.grad.data, torch.ones(5))
         self.assertEqual(z.grad.data, torch.ones(5) * 2)
 
+    def test_volatile_assignment(self):
+        x = Variable(torch.randn(5, 5))
+        y = Variable(torch.randn(5), volatile=True)
+
+        x[0] = y
+        self.assertTrue(x.volatile)
+
     def test_backward_copy(self):
         # This tests checks backward engine for a very subtle bug that appreared
         # in one of the initial versions of autograd. Gradients tensors were
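
The added test pins down the behaviour the reworked wrapping logic is meant to guarantee: writing a volatile Variable into a non-volatile one in place should leave the target volatile. A minimal sketch of that expectation against the Variable-era API this commit targets (the `volatile` flag was removed in later PyTorch releases), shown only for illustration:

```python
import torch
from torch.autograd import Variable

x = Variable(torch.randn(5, 5))              # ordinary Variable that tracks history
y = Variable(torch.randn(5), volatile=True)  # volatile: inference only, no graph

x[0] = y           # in-place assignment; autograd marks x as dirty
print(x.volatile)  # expected to print True, as asserted by test_volatile_assignment
```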

torch/autograd/_functions/pointwise.py

Lines changed: 2 additions & 0 deletions
@@ -367,6 +367,7 @@ def __init__(self, scale=1, inplace=False):
     def forward(self, add_tensor, mul_tensor1, mul_tensor2):
         self.save_for_backward(mul_tensor1, mul_tensor2)
         if self.inplace:
+            self.mark_dirty(add_tensor)
             return add_tensor.addcmul_(self.scale, mul_tensor1, mul_tensor2)
         else:
             return add_tensor.addcmul(self.scale, mul_tensor1, mul_tensor2)
@@ -396,6 +397,7 @@ def __init__(self, scale=1, inplace=False):
     def forward(self, add_tensor, div_tensor1, div_tensor2):
         self.save_for_backward(div_tensor1, div_tensor2)
         if self.inplace:
+            self.mark_dirty(add_tensor)
             return add_tensor.addcdiv_(self.scale, div_tensor1, div_tensor2)
         else:
             return add_tensor.addcdiv(self.scale, div_tensor1, div_tensor2)
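
For context, `mark_dirty()` is how an old-style `Function.forward` declares the tensors it modified in place; after this commit `_wrap_outputs` only transplants a returned input onto the new function when it appears in that dirty set, so the in-place branches of `Addcmul` and `Addcdiv` have to report `add_tensor`. A hedged sketch of the same pattern with a toy Function (the class and names below are illustrative, not part of this commit):

```python
import torch
from torch.autograd import Function, Variable

class AddOneInplace(Function):
    """Toy old-style Function that mutates its input in place."""

    def forward(self, tensor):
        # Tell the wrapping logic that `tensor` was modified in place, so the
        # returned Variable is transplanted onto this Function in the graph.
        self.mark_dirty(tensor)
        return tensor.add_(1)

    def backward(self, grad_output):
        return grad_output

v = Variable(torch.zeros(3))   # plain Variable (requires_grad defaults to False)
out = AddOneInplace()(v)       # shares storage with v; both now hold ones
print(out.data)
```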

torch/csrc/autograd/python_function.cpp

Lines changed: 79 additions & 53 deletions
@@ -8,6 +8,7 @@
 #include <THPP/THPP.h>
 
 #include "THP.h"
+#include "torch/csrc/autograd/functions/accumulate_grad.h"
 #include "torch/csrc/autograd/functions/basic_ops.h"
 #include "torch/csrc/autograd/functions/utils.h"
 #include "torch/csrc/autograd/python_cpp_function.h"
@@ -330,33 +331,64 @@ static void _mark_dirty(THPFunction *self, t2var_type &t2var,
   self->dirty_tensors = NULL;
 }
 
+static void _transplant_var(Variable& var, const std::shared_ptr<Function>& fn, int output_nr, bool is_volatile)
+{
+  if (is_volatile) {
+    var.grad_fn = nullptr;
+    var.requires_grad = false;
+    var.is_volatile = true;
+    var.output_nr = 0;
+  } else {
+    var.grad_fn = fn;
+    var.requires_grad = fn->is_executable;
+    var.is_volatile = is_volatile;
+    var.output_nr = output_nr;
+  }
+  var.grad = nullptr;
+  var.hooks.clear();
+  if (auto grad_acc_fn = var.grad_accumulator.lock()) {
+    auto grad_acc = dynamic_cast<AccumulateGrad*>(grad_acc_fn.get());
+    grad_acc->variable.reset();
+    grad_acc->variable_grad.reset();
+  }
+}
+
 static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
     std::unordered_set<PyObject *> &dirty_inputs, PyObject *raw_output,
-    PyObject *outputs)
+    PyObject *outputs, bool is_volatile)
 {
   // Wrap outputs in Variables
+  auto cdata = is_volatile ? nullptr : THPFunction_asFunction(self);
   Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
-  self->output_info = new std::vector<output_info_type>(num_outputs);
-  auto &output_info = *self->output_info;
+  if (self->cdata.is_executable) {
+    self->output_info = new std::vector<output_info_type>();
+    self->output_info->reserve(num_outputs);
+  }
   for (int i = 0; i < num_outputs; i++) {
     PyObject *output = PyTuple_GET_ITEM(raw_output, i);
     THPVariable *output_var;
     auto it = t2var.find(output);
     if (it == t2var.end()) {
       // A completely new tensor - just wrap it and continue
-      output_var = (THPVariable*)THPVariable_New(output, (PyObject*)self);
+      if (is_volatile) {
+        output_var = (THPVariable*)THPVariable_NewVolatile(output);
+      } else {
+        output_var = (THPVariable*)THPVariable_NewWithFunction(output, cdata);
+      }
     } else {
       // If one of the outputs was also an input tensor it's a bit more complicated.
       THPVariable *input_var = it->second;
       auto& input_var_ = *input_var->cdata;
       if (input_var_.grad_fn) {
-        // If it's not a leaf we want to move it in the graph so backprop
-        // will be computed correctly:
-        // grad_fn <- variable <- self  ==>  grad_fn <- self <- variable
        Py_INCREF(input_var);
         output_var = input_var;
-        input_var_.grad_fn = THPFunction_asFunction(self);
-        input_var_.requires_grad = self->cdata.is_executable;
+        // If it's not a leaf we want to move it in the graph so backprop
+        // will be computed correctly, but only if it was modified. Otherwise
+        // it's better to minimize the number of operations that mutate the graph.
+        // grad_fn <- variable <- self  ==>  grad_fn <- self <- variable
+        if (dirty_inputs.count(output) > 0) {
+          _transplant_var(input_var_, cdata, i, is_volatile);
+        }
       } else {
         // If the leaf Variable has been returned, we have to move it after the
         // current function to ensure the gradient is computed correctly.
@@ -370,8 +402,7 @@ static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
           if (!input_var_.requires_grad) {
             Py_INCREF(input_var);
             output_var = input_var;
-            input_var_.grad_fn = THPFunction_asFunction(self);
-            input_var_.requires_grad = self->cdata.is_executable;
+            _transplant_var(input_var_, cdata, i, is_volatile);
          } else { // input_var_.requires_grad
            throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
          }
@@ -386,20 +417,26 @@ static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
          // a side-effect of making in-place ops on any of these Variables an
          // immediate error, but it would be raised anyway once someone
          // calls backward.
-          output_var = (THPVariable*)THPVariable_New(output, (PyObject*)self);
+          if (is_volatile) {
+            output_var = (THPVariable*)THPVariable_NewVolatile(output);
+          } else {
+            output_var = (THPVariable*)THPVariable_NewWithFunction(output, cdata);
+          }
          if (!output_var) throw python_error();
          output_var->cdata->version_counter->join_with(*input_var->cdata->version_counter);
        }
      }
    }
    if (!output_var) throw python_error();
 
-    auto& output_tensor = *output_var->cdata->data;
-    output_info[i] = std::make_tuple(
-      (PyObject *)getPyTypeObject(output_tensor),
-      output_tensor.getDevice(),
-      output_tensor.sizes()
-    );
+    if (self->output_info) {
+      auto& output_tensor = *output_var->cdata->data;
+      self->output_info->emplace_back(
+        (PyObject *)getPyTypeObject(output_tensor),
+        output_tensor.getDevice(),
+        output_tensor.sizes()
+      );
+    }
    t2var[output] = output_var;
    output_var->cdata->output_nr = i;
    PyTuple_SET_ITEM(outputs, i, (PyObject*)output_var);
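
The `runtime_error` kept above ("a leaf Variable that requires grad has been used in an in-place operation.") is the user-facing guard; this refactor only changes how the non-erroring cases are re-wired, via `_transplant_var` and the `dirty_inputs` check. A small sketch of when that guard is expected to fire, using the Variable API of this era (illustrative, not a test from this commit):

```python
import torch
from torch.autograd import Variable

w = Variable(torch.randn(3), requires_grad=True)  # leaf Variable that requires grad
try:
    w.add_(1)  # in-place op: forward marks w dirty, _wrap_outputs rejects the leaf
except RuntimeError as err:
    print(err)  # expected: "a leaf Variable that requires grad has been used in an in-place operation."
```
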
@@ -562,47 +599,36 @@ std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) {
   return std::make_pair(std::move(unpacked), std::move(flags));
 }
 
-PyObject* process_outputs(THPFunction* grad_fn, const UnpackedInput& unpacked, THPObjectPtr raw_output) {
+PyObject* process_outputs(THPFunction* grad_fn, const UnpackedInput& unpacked, THPObjectPtr raw_output, bool is_volatile) {
   bool unpack_output = _ensure_tuple(raw_output);
 
   auto num_outputs = PyTuple_GET_SIZE(raw_output.get());
 
   THPObjectPtr outputs = PyTuple_New(num_outputs);
   if (!outputs) throw python_error();
-  if (!grad_fn) { // if volatile
-    // If one of the inputs is volatile let's take a fast path - we want
-    // minimize the overhead of inference
-    for (int i = 0; i < num_outputs; i++) {
-      PyObject *output = PyTuple_GET_ITEM(raw_output.get(), i);
-      THPVariable *output_var = (THPVariable*)THPVariable_NewVolatile(output);
-      if (!output_var) throw python_error();
-      output_var->cdata->output_nr = i;
-      PyTuple_SET_ITEM(outputs.get(), i, (PyObject*)output_var);
-    }
-  } else {
-    grad_fn->cdata.num_inputs = num_outputs;
 
-    // Initialize t2var map
-    t2var_type t2var;
-    for (auto& c_var : unpacked.input_vars) {
-      THPVariable* py_var = (THPVariable*)c_var->pyobj;
-      t2var.emplace(py_var->data, py_var);
-    }
+  grad_fn->cdata.num_inputs = num_outputs;
 
-    std::unordered_set<PyObject *> dirty_inputs;
-    _mark_dirty(grad_fn, t2var, dirty_inputs);
-    _wrap_outputs(grad_fn, t2var, dirty_inputs, raw_output, outputs);
-    _join_version_counters(grad_fn, t2var);
-    if (grad_fn->cdata.is_executable) {
-      _mark_non_differentiable(grad_fn, t2var);
-      _save_variables(grad_fn, t2var);
-    } else {
-      // Remove unnecessary attributes
-      Py_XDECREF(grad_fn->to_save);
-      grad_fn->to_save = NULL;
-      Py_XDECREF(grad_fn->non_differentiable);
-      grad_fn->non_differentiable = NULL;
-    }
+  // Initialize t2var map
+  t2var_type t2var;
+  for (auto& c_var : unpacked.input_vars) {
+    THPVariable* py_var = (THPVariable*)c_var->pyobj;
+    t2var.emplace(py_var->data, py_var);
+  }
+
+  std::unordered_set<PyObject *> dirty_inputs;
+  _mark_dirty(grad_fn, t2var, dirty_inputs);
+  _wrap_outputs(grad_fn, t2var, dirty_inputs, raw_output, outputs, is_volatile);
+  _join_version_counters(grad_fn, t2var);
+  if (grad_fn->cdata.is_executable) {
+    _mark_non_differentiable(grad_fn, t2var);
+    _save_variables(grad_fn, t2var);
+  } else {
+    // Remove unnecessary attributes
+    Py_XDECREF(grad_fn->to_save);
+    grad_fn->to_save = NULL;
+    Py_XDECREF(grad_fn->non_differentiable);
+    grad_fn->non_differentiable = NULL;
   }
 
   // Unpack the output, unless .forward() returned a tuple
@@ -631,7 +657,7 @@ PyObject *THPFunction_do_forward(THPFunction *self, PyObject *_inputs)
   THPObjectPtr raw_output = PyObject_CallObject(forward_fn, unpacked_input.tensor_input);
   if (!raw_output) return NULL;
 
-  return process_outputs(is_volatile ? NULL : self, unpacked_input, std::move(raw_output));
+  return process_outputs(self, unpacked_input, std::move(raw_output), is_volatile);
   END_HANDLE_TH_ERRORS
 }
 
@@ -670,7 +696,7 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *_inputs)
   THPObjectPtr tensor_outputs = PyObject_CallObject(forward_fn, ctx_tensor_input);
   if (!tensor_outputs) return NULL;
 
-  return process_outputs(is_volatile ? NULL : ctx, unpacked_input, std::move(tensor_outputs));
+  return process_outputs(ctx, unpacked_input, std::move(tensor_outputs), is_volatile);
   END_HANDLE_TH_ERRORS
 }
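
`process_outputs` used to receive a NULL `grad_fn` for volatile calls and wrap every output itself with `THPVariable_NewVolatile`; it now always gets the function object plus an explicit `is_volatile` flag and defers the choice between volatile and graph-connected wrapping to `_wrap_outputs`. The user-visible invariant should be unchanged; a brief sketch with the old `volatile` flag (removed in later releases):

```python
import torch
from torch.autograd import Variable

a = Variable(torch.randn(2, 2), volatile=True)  # inference-mode input
b = a * 2 + 1                                   # outputs wrapped through _wrap_outputs

print(b.volatile)       # True: volatility still propagates to every output
print(b.requires_grad)  # False: no graph is recorded, so backward is unavailable
```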

torch/csrc/autograd/python_variable.cpp

Lines changed: 4 additions & 6 deletions
@@ -41,17 +41,15 @@ PyObject * THPVariable_Wrap(const std::shared_ptr<Variable>& var)
   return var->pyobj;
 }
 
-// This function DOES NOT steal a reference to data and grad_fn
-PyObject * THPVariable_New(PyObject *data, PyObject *_grad_fn)
+// This function DOES NOT steal a reference to data
+PyObject * THPVariable_NewWithFunction(PyObject *data, const std::shared_ptr<torch::autograd::Function>& grad_fn)
 {
   THPUtils_assert(THPModule_isTensor(data), "data must be a Tensor");
-  THPUtils_assert(THPFunction_Check(_grad_fn), "grad_fn must be a Function");
-  THPFunction *grad_fn = (THPFunction*)_grad_fn;
-  auto v = std::make_shared<Variable>(torch::createTensor(data), grad_fn->cdata.is_executable, false);
+  auto v = std::make_shared<Variable>(torch::createTensor(data), grad_fn->is_executable, false);
   PyObject* obj = THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, v);
   if (obj) {
     v->pyobj = obj;
-    v->grad_fn = THPFunction_asFunction((THPFunction*)grad_fn);
+    v->grad_fn = grad_fn;
     ((THPVariable*)obj)->data = data;
     Py_INCREF(data);
   }

torch/csrc/autograd/python_variable.h

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ extern PyObject *THPVariableClass;
 bool THPVariable_initModule(PyObject *module);
 PyObject * THPVariable_NewVolatile(PyObject *data);
 PyObject * THPVariable_NewLeaf(PyObject *data);
-PyObject * THPVariable_New(PyObject *data, PyObject *grad_fn);
+PyObject * THPVariable_NewWithFunction(PyObject *data, const std::shared_ptr<torch::autograd::Function>& var);
 PyObject * THPVariable_Wrap(const std::shared_ptr<torch::autograd::Variable>& var);
 PyObject * THPVariable_get_data(THPVariable *self);
