@@ -52,49 +52,44 @@ void RegisterReduceHookForTensor(const paddle::experimental::Tensor& tensor,
   }
 }
 
-static void RetainGradForRegularNode(
-    const paddle::experimental::Tensor& tensor) {
-  AutogradMeta* meta = EagerUtils::unsafe_autograd_meta(tensor);
-  if (meta->RetainGrads()) {
+void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
+  if (IsLeafTensor(tensor)) {
+    // Leaf tensor's grad will always be retained
+    // Refer to implementation of AccumulationNode for more details
     return;
   } else {
-    meta->SetRetainGrads(true);
-  }
+    AutogradMeta* meta = EagerUtils::unsafe_autograd_meta(tensor);
+    if (meta->RetainGrads()) {
+      return;
+    } else {
+      meta->SetRetainGrads(true);
+    }
 
-  std::weak_ptr<paddle::experimental::Tensor> weak_grad_tensor =
-      meta->WeakGrad();
+    std::weak_ptr<paddle::experimental::Tensor> weak_grad_tensor =
+        meta->WeakGrad();
 
-  // Define Hook
-  auto hook = [weak_grad_tensor](const paddle::experimental::Tensor& t) {
-    if (!weak_grad_tensor.expired()) {
-      auto grad_tensor = weak_grad_tensor.lock();
-      if (t.defined()) {
-        VLOG(7) << "Set impl for RetainGrad Hook for tensor: " << t.name();
-        // Simply Copy impl() to grad_tensor
-        grad_tensor->set_impl(t.impl());
-        return *grad_tensor.get();
+    // Define Hook
+    auto hook = [weak_grad_tensor](const paddle::experimental::Tensor& t) {
+      if (!weak_grad_tensor.expired()) {
+        auto grad_tensor = weak_grad_tensor.lock();
+        if (t.defined()) {
+          VLOG(7) << "Set impl for RetainGrad Hook for tensor: " << t.name();
+          // Simply Copy impl() to grad_tensor
+          grad_tensor->set_impl(t.impl());
+          return *grad_tensor.get();
+        } else {
+          VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
+          return paddle::experimental::Tensor();
+        }
       } else {
         VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
         return paddle::experimental::Tensor();
       }
-    } else {
-      VLOG(7) << "Retain NULL paddle::experimental::Tensor in Grad Hook";
-      return paddle::experimental::Tensor();
-    }
-  };
+    };
 
-  // Append to GradientHooks
-  RegisterGradientHookForTensor(tensor,
-                                std::make_shared<egr::CppTensorHook>(hook));
-}
-
-void RetainGradForTensor(const paddle::experimental::Tensor& tensor) {
-  if (IsLeafTensor(tensor)) {
-    // Leaf tensor's grad will always be retained
-    // Refer to implementation of AccumulationNode for more details
-    return;
-  } else {
-    RetainGradForRegularNode(tensor);
+    // Append to GradientHooks
+    RegisterGradientHookForTensor(tensor,
+                                  std::make_shared<egr::CppTensorHook>(hook));
   }
 }
 
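For reviewers skimming the hook logic: below is a minimal, self-contained sketch of the `std::weak_ptr` capture pattern the inlined hook relies on. The `Tensor` struct and `main` here are stand-ins for illustration only; the real code operates on `paddle::experimental::Tensor` through `AutogradMeta::WeakGrad()` and `set_impl()`.

```cpp
#include <iostream>
#include <memory>
#include <string>

// Stand-in for paddle::experimental::Tensor (illustration only).
struct Tensor {
  std::string impl;  // stands in for the shared tensor implementation
};

int main() {
  // Grad buffer owned elsewhere (by AutogradMeta in the real code).
  auto grad = std::make_shared<Tensor>();

  // Capture a weak_ptr so the hook does not extend the grad's lifetime
  // and no shared_ptr cycle forms between the hook and the tensor.
  std::weak_ptr<Tensor> weak_grad = grad;
  auto hook = [weak_grad](const Tensor& t) {
    if (auto g = weak_grad.lock()) {
      g->impl = t.impl;  // copy the incoming grad's impl into the buffer
      return *g;
    }
    // Grad buffer already destroyed: return an empty tensor, mirroring
    // the "Retain NULL ... in Grad Hook" branch above.
    return Tensor{};
  };

  hook(Tensor{"grad_value"});
  std::cout << grad->impl << std::endl;  // prints "grad_value"
}
```

The weak capture is the key design choice: if the tensor (and with it the grad buffer) is freed before backward runs, `lock()` fails and the hook degrades gracefully instead of dereferencing a dangling pointer.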