88#include < THPP/THPP.h>
99
1010#include " THP.h"
11+ #include " torch/csrc/autograd/functions/accumulate_grad.h"
1112#include " torch/csrc/autograd/functions/basic_ops.h"
1213#include " torch/csrc/autograd/functions/utils.h"
1314#include " torch/csrc/autograd/python_cpp_function.h"
@@ -330,33 +331,64 @@ static void _mark_dirty(THPFunction *self, t2var_type &t2var,
330331 self->dirty_tensors = NULL ;
331332}
332333
334+ static void _transplant_var (Variable& var, const std::shared_ptr<Function>& fn, int output_nr, bool is_volatile)
335+ {
336+ if (is_volatile) {
337+ var.grad_fn = nullptr ;
338+ var.requires_grad = false ;
339+ var.is_volatile = true ;
340+ var.output_nr = 0 ;
341+ } else {
342+ var.grad_fn = fn;
343+ var.requires_grad = fn->is_executable ;
344+ var.is_volatile = is_volatile;
345+ var.output_nr = output_nr;
346+ }
347+ var.grad = nullptr ;
348+ var.hooks .clear ();
349+ if (auto grad_acc_fn = var.grad_accumulator .lock ()) {
350+ auto grad_acc = dynamic_cast <AccumulateGrad*>(grad_acc_fn.get ());
351+ grad_acc->variable .reset ();
352+ grad_acc->variable_grad .reset ();
353+ }
354+ }
355+
333356static void _wrap_outputs (THPFunction *self, t2var_type &t2var,
334357 std::unordered_set<PyObject *> &dirty_inputs, PyObject *raw_output,
335- PyObject *outputs)
358+ PyObject *outputs, bool is_volatile )
336359{
337360 // Wrap outputs in Variables
361+ auto cdata = is_volatile ? nullptr : THPFunction_asFunction (self);
338362 Py_ssize_t num_outputs = PyTuple_GET_SIZE (raw_output);
339- self->output_info = new std::vector<output_info_type>(num_outputs);
340- auto &output_info = *self->output_info ;
363+ if (self->cdata .is_executable ) {
364+ self->output_info = new std::vector<output_info_type>();
365+ self->output_info ->reserve (num_outputs);
366+ }
341367 for (int i = 0 ; i < num_outputs; i++) {
342368 PyObject *output = PyTuple_GET_ITEM (raw_output, i);
343369 THPVariable *output_var;
344370 auto it = t2var.find (output);
345371 if (it == t2var.end ()) {
346372 // A completely new tensor - just wrap it and continue
347- output_var = (THPVariable*)THPVariable_New (output, (PyObject*)self);
373+ if (is_volatile) {
374+ output_var = (THPVariable*)THPVariable_NewVolatile (output);
375+ } else {
376+ output_var = (THPVariable*)THPVariable_NewWithFunction (output, cdata);
377+ }
348378 } else {
349379 // If one of the outputs was also an input tensor it's a bit more complicated.
350380 THPVariable *input_var = it->second ;
351381 auto & input_var_ = *input_var->cdata ;
352382 if (input_var_.grad_fn ) {
353- // If it's not a leaf we want to move it in the graph so backprop
354- // will be computed correctly:
355- // grad_fn <- variable <- self ==> grad_fn <- self <- variable
356383 Py_INCREF (input_var);
357384 output_var = input_var;
358- input_var_.grad_fn = THPFunction_asFunction (self);
359- input_var_.requires_grad = self->cdata .is_executable ;
385+ // If it's not a leaf we want to move it in the graph so backprop
386+ // will be computed correctly, but only if it was modified. Otherwise
387+ // it's better to minimize the number of operations that mutate the graph.
388+ // grad_fn <- variable <- self ==> grad_fn <- self <- variable
389+ if (dirty_inputs.count (output) > 0 ) {
390+ _transplant_var (input_var_, cdata, i, is_volatile);
391+ }
360392 } else {
361393 // If the leaf Variable has been returned, we have to move it after the
362394 // current function to ensure the gradient is computed correctly.
@@ -370,8 +402,7 @@ static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
370402 if (!input_var_.requires_grad ) {
371403 Py_INCREF (input_var);
372404 output_var = input_var;
373- input_var_.grad_fn = THPFunction_asFunction (self);
374- input_var_.requires_grad = self->cdata .is_executable ;
405+ _transplant_var (input_var_, cdata, i, is_volatile);
375406 } else { // input_var_.requires_grad
376407 throw std::runtime_error (" a leaf Variable that requires grad has been used in an in-place operation." );
377408 }
@@ -386,20 +417,26 @@ static void _wrap_outputs(THPFunction *self, t2var_type &t2var,
386417 // a side-effect of making in-place ops on any of these Variables an
387418 // immediate error, but it would be raised anyway once someone
388419 // calls backward.
389- output_var = (THPVariable*)THPVariable_New (output, (PyObject*)self);
420+ if (is_volatile) {
421+ output_var = (THPVariable*)THPVariable_NewVolatile (output);
422+ } else {
423+ output_var = (THPVariable*)THPVariable_NewWithFunction (output, cdata);
424+ }
390425 if (!output_var) throw python_error ();
391426 output_var->cdata ->version_counter ->join_with (*input_var->cdata ->version_counter );
392427 }
393428 }
394429 }
395430 if (!output_var) throw python_error ();
396431
397- auto & output_tensor = *output_var->cdata ->data ;
398- output_info[i] = std::make_tuple (
399- (PyObject *)getPyTypeObject (output_tensor),
400- output_tensor.getDevice (),
401- output_tensor.sizes ()
402- );
432+ if (self->output_info ) {
433+ auto & output_tensor = *output_var->cdata ->data ;
434+ self->output_info ->emplace_back (
435+ (PyObject *)getPyTypeObject (output_tensor),
436+ output_tensor.getDevice (),
437+ output_tensor.sizes ()
438+ );
439+ }
403440 t2var[output] = output_var;
404441 output_var->cdata ->output_nr = i;
405442 PyTuple_SET_ITEM (outputs, i, (PyObject*)output_var);
@@ -562,47 +599,36 @@ std::pair<UnpackedInput, InputFlags> unpack_input(PyObject *args) {
562599 return std::make_pair (std::move (unpacked), std::move (flags));
563600}
564601
565- PyObject* process_outputs (THPFunction* grad_fn, const UnpackedInput& unpacked, THPObjectPtr raw_output) {
602+ PyObject* process_outputs (THPFunction* grad_fn, const UnpackedInput& unpacked, THPObjectPtr raw_output, bool is_volatile ) {
566603 bool unpack_output = _ensure_tuple (raw_output);
567604
568605 auto num_outputs = PyTuple_GET_SIZE (raw_output.get ());
569606
570607 THPObjectPtr outputs = PyTuple_New (num_outputs);
571608 if (!outputs) throw python_error ();
572- if (!grad_fn) { // if volatile
573- // If one of the inputs is volatile let's take a fast path - we want
574- // minimize the overhead of inference
575- for (int i = 0 ; i < num_outputs; i++) {
576- PyObject *output = PyTuple_GET_ITEM (raw_output.get (), i);
577- THPVariable *output_var = (THPVariable*)THPVariable_NewVolatile (output);
578- if (!output_var) throw python_error ();
579- output_var->cdata ->output_nr = i;
580- PyTuple_SET_ITEM (outputs.get (), i, (PyObject*)output_var);
581- }
582- } else {
583- grad_fn->cdata .num_inputs = num_outputs;
584609
585- // Initialize t2var map
586- t2var_type t2var;
587- for (auto & c_var : unpacked.input_vars ) {
588- THPVariable* py_var = (THPVariable*)c_var->pyobj ;
589- t2var.emplace (py_var->data , py_var);
590- }
610+ grad_fn->cdata .num_inputs = num_outputs;
591611
592- std::unordered_set<PyObject *> dirty_inputs;
593- _mark_dirty (grad_fn, t2var, dirty_inputs);
594- _wrap_outputs (grad_fn, t2var, dirty_inputs, raw_output, outputs);
595- _join_version_counters (grad_fn, t2var);
596- if (grad_fn->cdata .is_executable ) {
597- _mark_non_differentiable (grad_fn, t2var);
598- _save_variables (grad_fn, t2var);
599- } else {
600- // Remove unnecessary attributes
601- Py_XDECREF (grad_fn->to_save );
602- grad_fn->to_save = NULL ;
603- Py_XDECREF (grad_fn->non_differentiable );
604- grad_fn->non_differentiable = NULL ;
605- }
612+ // Initialize t2var map
613+ t2var_type t2var;
614+ for (auto & c_var : unpacked.input_vars ) {
615+ THPVariable* py_var = (THPVariable*)c_var->pyobj ;
616+ t2var.emplace (py_var->data , py_var);
617+ }
618+
619+ std::unordered_set<PyObject *> dirty_inputs;
620+ _mark_dirty (grad_fn, t2var, dirty_inputs);
621+ _wrap_outputs (grad_fn, t2var, dirty_inputs, raw_output, outputs, is_volatile);
622+ _join_version_counters (grad_fn, t2var);
623+ if (grad_fn->cdata .is_executable ) {
624+ _mark_non_differentiable (grad_fn, t2var);
625+ _save_variables (grad_fn, t2var);
626+ } else {
627+ // Remove unnecessary attributes
628+ Py_XDECREF (grad_fn->to_save );
629+ grad_fn->to_save = NULL ;
630+ Py_XDECREF (grad_fn->non_differentiable );
631+ grad_fn->non_differentiable = NULL ;
606632 }
607633
608634 // Unpack the output, unless .forward() returned a tuple
@@ -631,7 +657,7 @@ PyObject *THPFunction_do_forward(THPFunction *self, PyObject *_inputs)
631657 THPObjectPtr raw_output = PyObject_CallObject (forward_fn, unpacked_input.tensor_input );
632658 if (!raw_output) return NULL ;
633659
634- return process_outputs (is_volatile ? NULL : self, unpacked_input, std::move (raw_output));
660+ return process_outputs (self, unpacked_input, std::move (raw_output), is_volatile );
635661 END_HANDLE_TH_ERRORS
636662}
637663
@@ -670,7 +696,7 @@ PyObject *THPFunction_apply(PyObject *cls, PyObject *_inputs)
670696 THPObjectPtr tensor_outputs = PyObject_CallObject (forward_fn, ctx_tensor_input);
671697 if (!tensor_outputs) return NULL ;
672698
673- return process_outputs (is_volatile ? NULL : ctx, unpacked_input, std::move (tensor_outputs));
699+ return process_outputs (ctx, unpacked_input, std::move (tensor_outputs), is_volatile );
674700 END_HANDLE_TH_ERRORS
675701}
676702
0 commit comments