
Commit 753798e

Support initializing specific grad tensors to zero for selected operators
1 parent: 9f0bf2b

File tree: 12 files changed, +62 −25 lines
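The core of the change: every grad node's functor now takes its incoming grads by mutable reference (`std::vector<std::vector<paddle::experimental::Tensor>>&` instead of `const ...&`), so that for operators registered in `ops_to_fill_zero_for_empty_grads` (currently only `split`), the generated backward function can replace empty incoming grad slots with zero tensors in place by calling `egr::EagerUtils::FillZeroForEmptyGradInputs(&grads)` before invoking the grad kernel. The helper's implementation is not part of this diff; the following is only a conceptual sketch of what it is expected to do, with `MakeZeroTensorForSlot` a hypothetical stand-in for however the real code recovers shape and dtype:

// Conceptual sketch -- NOT the actual egr::EagerUtils implementation.
// Walk every grad slot and replace uninitialized tensors (e.g. the grad of
// an unused `split` output) with zero-filled ones, so grad kernels never
// receive an empty input.
void FillZeroForEmptyGradInputs(
    std::vector<std::vector<paddle::experimental::Tensor>>* in_grads) {
  for (auto& slot : *in_grads) {
    for (auto& grad : slot) {
      if (!grad.initialized()) {
        // Hypothetical helper: the real code must recover shape/dtype from
        // metadata kept on the grad node.
        grad = MakeZeroTensorForSlot(grad);
      }
    }
  }
}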

paddle/fluid/eager/accumulation/accumulation_node.cc

Lines changed: 1 addition & 2 deletions

@@ -39,8 +39,7 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
 }
 
 std::vector<std::vector<paddle::experimental::Tensor>> GradNodeAccumulation::
-operator()(
-    const std::vector<std::vector<paddle::experimental::Tensor>>& grads) {
+operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads) {
   VLOG(3) << "Running Eager Backward Node: GradNodeAccumulation";
   PADDLE_ENFORCE(grads.size() == 1,
                  paddle::platform::errors::Fatal(

paddle/fluid/eager/accumulation/accumulation_node.h

Lines changed: 1 addition & 2 deletions

@@ -32,8 +32,7 @@ class GradNodeAccumulation : public GradNodeBase {
 
   // Functor: perform backward computations
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      const std::vector<std::vector<paddle::experimental::Tensor>>& grads)
-      override;
+      std::vector<std::vector<paddle::experimental::Tensor>>& grads) override;
 
   std::string name() { return "GradNodeAccumulation"; }

paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc

Lines changed: 1 addition & 2 deletions

@@ -145,8 +145,7 @@ void GradNodeScale::SetTensorWrappers_X(
 void GradNodeScale::SetAttributes_scale(float scale) { scale_ = scale; }
 
 std::vector<std::vector<paddle::experimental::Tensor>> GradNodeScale::
-operator()(
-    const std::vector<std::vector<paddle::experimental::Tensor>>& grads) {
+operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads) {
   // 1. Check Output Size
   PADDLE_ENFORCE(
       ((grads.size() == 1) && (grads[0].size() == 1)),

paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.h

Lines changed: 1 addition & 2 deletions

@@ -39,8 +39,7 @@ class GradNodeScale : public GradNodeBase {
 
   // Functor: perform backward computations
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      const std::vector<std::vector<paddle::experimental::Tensor>>& grads)
-      override;
+      std::vector<std::vector<paddle::experimental::Tensor>>& grads) override;
 
   void SetTensorWrappers_X(
       const std::vector<paddle::experimental::Tensor>& tensors);

paddle/fluid/eager/auto_code_generator/eager_generator.cc

Lines changed: 16 additions & 5 deletions

@@ -47,6 +47,9 @@ std::unordered_map<std::string, std::vector<std::string>>
 static std::unordered_map<std::string, paddle::framework::AttributeMap>
     operators_with_attrs = {};
 
+static std::unordered_set<std::string> ops_to_fill_zero_for_empty_grads = {
+    "split"};
+
 static std::string LegalizeVariableName(const std::string& var_name) {
   std::string ret = var_name;
   std::replace(ret.begin(), ret.end(), '-', '_');  // replace all '-' to '_'

@@ -2053,10 +2056,18 @@ static std::string GenerateGradNodeCCContents(
   // [Generation] Get Full Grad Function
   const char* GRAD_FUNCTION_TEMPLATE =
       "std::vector<std::vector<paddle::experimental::Tensor>> "
-      "GradNode%s::operator()(const "
-      "std::vector<std::vector<paddle::experimental::Tensor>>& grads) {\n%s\n}";
-  std::string grad_function_str = paddle::string::Sprintf(
-      GRAD_FUNCTION_TEMPLATE, fwd_op_type, generated_grad_function_body);
+      "GradNode%s::operator()("
+      "std::vector<std::vector<paddle::experimental::Tensor>>& grads) {\n"
+      "%s"
+      "%s"
+      "\n}";
+  std::string fill_zero_str = "";
+  if (ops_to_fill_zero_for_empty_grads.count(fwd_op_type)) {
+    fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads);\n";
+  }
+  std::string grad_function_str =
+      paddle::string::Sprintf(GRAD_FUNCTION_TEMPLATE, fwd_op_type,
+                              fill_zero_str, generated_grad_function_body);
 
   VLOG(6) << "Generated returns";
 
@@ -2086,7 +2097,7 @@ static std::string GenerateGradNodeHeaderContents(
       "  ~GradNode%s() override = default;\n"
       "\n"
       "  virtual std::vector<std::vector<paddle::experimental::Tensor>> "
-      "operator()(const "
+      "operator()("
       "std::vector<std::vector<paddle::experimental::Tensor>>& grads) "
      "override;\n"
      "\n"

paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py

Lines changed: 16 additions & 5 deletions

@@ -17,6 +17,8 @@
 import argparse
 import os
 
+ops_to_fill_zero_for_empty_grads = set()
+
 # For API dispatch used at python-level
 # { op_name : [arg_name, ...] }
 core_ops_returns_info = {}

@@ -513,7 +515,7 @@ class {} : public egr::GradNodeBase {{
   ~{}() override = default;
 
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      const std::vector<std::vector<paddle::experimental::Tensor>>& grads) override;
+      std::vector<std::vector<paddle::experimental::Tensor>>& grads) override;
 
   // SetTensorWrapperX, SetTensorWrapperY, ...
   {}

@@ -558,10 +560,11 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
     for _, (ttype, fwd_position,
             grad_api_position) in backward_grad_input_map.items():
         if IsPlainTensorType(ttype):
-            grad_api_args[grad_api_position] = f"grads[{fwd_position}][0]"
+            grad_api_args[
+                grad_api_position] = f"hooked_grads[{fwd_position}][0]"
         else:
             assert IsVectorTensorType(ttype)
-            grad_api_args[grad_api_position] = f"grads[{fwd_position}]"
+            grad_api_args[grad_api_position] = f"hooked_grads[{fwd_position}]"
 
     for name, _, _, grad_api_position in backward_attrs_list:
         saved_attribute_name = GetSavedName(name)

@@ -588,16 +591,24 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
     returns_str += f"return returns;\n"
 
     grad_node_name = GetGradNodeName(fwd_api_name)
+    fill_zero_str = ""
+    if fwd_api_name in ops_to_fill_zero_for_empty_grads:
+        fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads);\n"
+
     FUNCTION_TEMPLATE = """
-std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads) {{
+std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads) {{
+    {}
+    auto hooked_grads = ApplyGradientHooks(grads);
+
 
     // Call grad_api function
     auto grad_api_returns = paddle::experimental::{}({});
     {}
 }}
 """
 
     node_definition_str = FUNCTION_TEMPLATE.format(
-        grad_node_name, bwd_api_name, grad_api_args_str, returns_str)
+        grad_node_name, fill_zero_str, bwd_api_name, grad_api_args_str,
+        returns_str)
 
     return node_definition_str
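Filled in, `FUNCTION_TEMPLATE` yields C++ of roughly this shape; the grad API arguments now read from `hooked_grads` (the result of `ApplyGradientHooks`) rather than the raw `grads`. A sketch with placeholder names (`GradNodeMatmul` and `matmul_grad` are illustrative only), noting that `ops_to_fill_zero_for_empty_grads` starts empty in this generator, so `fill_zero_str` is blank until ops are registered:

std::vector<std::vector<paddle::experimental::Tensor>> GradNodeMatmul::operator()(
    std::vector<std::vector<paddle::experimental::Tensor>>& grads) {
  egr::EagerUtils::FillZeroForEmptyGradInputs(&grads);  // only for registered ops
  auto hooked_grads = ApplyGradientHooks(grads);

  // Call grad_api function
  auto grad_api_returns = paddle::experimental::matmul_grad(
      hooked_grads[0][0], /* ...remaining tensor args and attributes... */);
  // ... build and return `returns` from grad_api_returns ...
}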

paddle/fluid/eager/grad_node_info.h

Lines changed: 1 addition & 1 deletion

@@ -103,7 +103,7 @@ class GradNodeBase {
    * is better choice to fit this format.
    * **/
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      const std::vector<std::vector<paddle::experimental::Tensor>>& grads) = 0;
+      std::vector<std::vector<paddle::experimental::Tensor>>& grads) = 0;
 
   /**
    * AddEdges is designed to set input tensors' backward Node as current
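This base-class change is why the edit fans out across so many files: a C++ `override` must match the base's parameter types exactly, so once the pure virtual drops `const`, every generated and hand-written node must follow suit or fail to compile. A standalone illustration (not Paddle code):

#include <vector>

struct Base {
  // Parameter changed from `const std::vector<int>&` to `std::vector<int>&`.
  virtual int Run(std::vector<int>& grads) = 0;
  virtual ~Base() = default;
};

struct Node : Base {
  // Declaring `Run(const std::vector<int>&)` here would no longer override
  // the base virtual; with `override`, the compiler rejects the mismatch.
  int Run(std::vector<int>& grads) override {
    return static_cast<int>(grads.size());
  }
};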

paddle/fluid/eager/grad_tensor_holder.h

Lines changed: 1 addition & 1 deletion

@@ -52,7 +52,7 @@ class GradTensorHolder {
     return buffer_[pos];
   }
 
-  const std::vector<std::vector<paddle::experimental::Tensor>>& Buffers() {
+  std::vector<std::vector<paddle::experimental::Tensor>>& Buffers() {
     return buffer_;
   }
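`Buffers()` loses the `const` on its return type for the same reason: the backward engine forwards the holder's buffers into `GradNodeBase::operator()`, which now requires a mutable reference. A hedged sketch of such a call site (the engine code itself is not in this diff; `RunNode` is hypothetical):

// Sketch, assuming the engine invokes a node on its accumulated input buffer.
// A const reference returned from Buffers() could not bind to the functor's
// non-const parameter.
std::vector<std::vector<paddle::experimental::Tensor>> RunNode(
    egr::GradNodeBase* grad_node, egr::GradTensorHolder* input_buffer) {
  return (*grad_node)(input_buffer->Buffers());
}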

paddle/fluid/eager/tests/data_structure_tests/accumulation_node_test.cc

Lines changed: 5 additions & 3 deletions

@@ -80,13 +80,15 @@ TEST(AccumulationNode, Tensor) {
   grad_meta->SetStopGradient(false);
 
   // operator()
-  paddle::experimental::Tensor ret_et0 = node->operator()({{et0}})[0][0];
+  std::vector<std::vector<paddle::experimental::Tensor>> et0_vec = {{et0}};
+  paddle::experimental::Tensor ret_et0 = node->operator()(et0_vec)[0][0];
   auto* ret_et0_ptr =
       std::dynamic_pointer_cast<phi::DenseTensor>(ret_et0.impl())
           ->data<paddle::platform::float16>();
   CHECK_EQ(ret_et0_ptr[0], paddle::platform::float16(10.0f));
 
-  paddle::experimental::Tensor ret_et1 = node->operator()({{et1}})[0][0];
+  std::vector<std::vector<paddle::experimental::Tensor>> et1_vec = {{et1}};
+  paddle::experimental::Tensor ret_et1 = node->operator()(et1_vec)[0][0];
 
   auto* ret_et1_ptr =
       std::dynamic_pointer_cast<phi::DenseTensor>(ret_et1.impl())

@@ -121,7 +123,7 @@ TEST(AccumulationNode, Tensor) {
       std::make_shared<egr::CppTensorVoidHook>(reduce_hook_1));
 
   // operator()
-  paddle::experimental::Tensor _ret = node->operator()({{et0}})[0][0];
+  paddle::experimental::Tensor _ret = node->operator()(et0_vec)[0][0];
 
   // Check operator() result, should be 36.0
   auto* _ret_ptr = std::dynamic_pointer_cast<phi::DenseTensor>(_ret.impl())
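The test edits follow mechanically from the new signature: a braced temporary such as `{{et0}}` is an rvalue and cannot bind to the non-const lvalue reference `operator()` now takes, so the grads are materialized as named variables (`et0_vec`, `et1_vec`) first. A standalone illustration:

#include <vector>

void TakesMutableRef(std::vector<std::vector<int>>& grads) {}

int main() {
  // TakesMutableRef({{1}});  // error: rvalue cannot bind to a non-const T&
  std::vector<std::vector<int>> grads = {{1}};
  TakesMutableRef(grads);  // OK: a named lvalue binds
  return 0;
}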

paddle/fluid/eager/tests/data_structure_tests/grad_node_test.h

Lines changed: 1 addition & 2 deletions

@@ -31,8 +31,7 @@ class GradTestNode : public egr::GradNodeBase {
       : GradNodeBase(in_num, out_num), val_(val) {}
   GradTestNode() : GradNodeBase() { val_ = 1.0; }
   std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      const std::vector<std::vector<paddle::experimental::Tensor>>& grads)
-      override {
+      std::vector<std::vector<paddle::experimental::Tensor>>& grads) override {
     val_ = std::dynamic_pointer_cast<phi::DenseTensor>(grads[0][0].impl())
                ->data<float>()[0];
     phi::DenseTensorMeta meta =
