Skip to content

Commit 35b03e1

Browse files
authored
share MemOptVarInfos of external variables into cinn_launch subgraph (#39209)
* add a graph pass to share MemOptVarInfos of external variables into subgraph * update pass name * fix compile failed * add share_mem_opt_info_to_subgraph_pass test * share_mem_opt_info_to_subgraph_pass_test pass * modify some codes for better style and more robust * update cmake
1 parent 29d3160 commit 35b03e1

File tree

9 files changed

+360
-14
lines changed

9 files changed

+360
-14
lines changed

paddle/fluid/framework/details/eager_deletion_op_handle.cc

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,21 @@ void EagerDeletionOpHandle::CallOnce() {
107107

108108
std::string EagerDeletionOpHandle::Name() const { return "eager_deletion"; }
109109

110+
// Tells whether the variable described by `var_info` may be released now.
//
// Returns false as soon as the variable is excluded from all memory
// optimization, or when DecreaseRefCnt() reports false (in that order; the
// reference count is not touched for excluded variables).
// When built with PADDLE_WITH_CINN and a parent holder is attached, the answer
// is whatever the same check yields for the parent holder; the parent is only
// consulted after this variable's own check has passed.
static bool CanBeErased(ir::MemOptVarInfo *var_info) {
  if (var_info->IsSkippedAllMemoryOptimization()) {
    return false;
  }
  if (!var_info->DecreaseRefCnt()) {
    return false;
  }
#ifdef PADDLE_WITH_CINN
  // A parent holder, when present, has to meet the deletion condition as well.
  std::shared_ptr<ir::MemOptVarInfo> parent_holder = var_info->ParentHolder();
  if (parent_holder) {
    return CanBeErased(parent_holder.get());
  }
#endif
  return true;
}
124+
110125
void EagerDeletionOpHandle::RunImpl() {
111126
if (vars_.size() != var_infos_.size() || is_variant_scope_) {
112127
vars_.clear();
@@ -117,8 +132,7 @@ void EagerDeletionOpHandle::RunImpl() {
117132
std::deque<std::shared_ptr<memory::Allocation>> garbages;
118133
for (size_t i = 0; i < var_infos_.size(); ++i) {
119134
auto *var_info = var_infos_[i];
120-
if (var_info->IsSkippedAllMemoryOptimization() ||
121-
!var_info->DecreaseRefCnt()) {
135+
if (!CanBeErased(var_info)) {
122136
VLOG(4) << "skip memory optimization with var: " << var_info->Name();
123137
continue;
124138
}

paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,19 @@ cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pas
55
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle)
66
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
77

8-
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle
9-
eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper)
8+
SET(EAGER_DELETETION_PASS_DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass conditional_block_op_eager_deletion_pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper)
9+
if (WITH_CINN)
10+
cc_library(share_varinfo_into_cinn_pass SRCS share_varinfo_into_cinn_pass.cc DEPS pass enforce graph_helper computation_op_handle eager_deletion_op_handle cinn_compiler)
11+
cc_test(share_varinfo_into_cinn_pass_test SRCS share_varinfo_into_cinn_pass_test.cc DEPS share_varinfo_into_cinn_pass parallel_executor cinn_compiler elementwise_add_op mul_op cinn_launch_op)
12+
list(APPEND EAGER_DELETETION_PASS_DEPS share_varinfo_into_cinn_pass)
13+
endif()
1014

11-
cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle graph pass multi_devices_helper)
15+
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS ${EAGER_DELETETION_PASS_DEPS})
16+
17+
cc_library(memory_reuse_pass SRCS memory_reuse_pass.cc DEPS computation_op_handle reference_count_pass_helper share_tensor_buffer_op_handle graph pass multi_devices_helper)
1218

1319
cc_library(buffer_shared_inplace_op_pass SRCS buffer_shared_inplace_op_pass.cc DEPS memory_reuse_pass executor_gc_helper)
14-
cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass)
20+
cc_library(buffer_shared_cross_op_memory_reuse_pass SRCS buffer_shared_cross_op_memory_reuse_pass.cc DEPS memory_reuse_pass)
1521

1622
cc_library(inplace_addto_op_pass SRCS inplace_addto_op_pass.cc DEPS memory_reuse_pass)
1723

paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,13 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
285285
auto recurrent_op_eager_deletion_pass =
286286
ir::PassRegistry::Instance().Get("recurrent_op_eager_deletion_pass");
287287
recurrent_op_eager_deletion_pass->Apply(graph);
288+
289+
#ifdef PADDLE_WITH_CINN
290+
auto share_varinfo_into_cinn_pass =
291+
ir::PassRegistry::Instance().Get("share_varinfo_into_cinn_pass");
292+
share_varinfo_into_cinn_pass->SetNotOwned(kMemOptVarInfoMapList, &var_infos);
293+
share_varinfo_into_cinn_pass->Apply(graph);
294+
#endif
288295
}
289296

290297
} // namespace ir
@@ -300,3 +307,6 @@ REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass)
300307
USE_PASS(conditional_block_op_eager_deletion_pass);
301308
USE_PASS(while_op_eager_deletion_pass);
302309
USE_PASS(recurrent_op_eager_deletion_pass);
310+
#ifdef PADDLE_WITH_CINN
311+
USE_PASS(share_varinfo_into_cinn_pass);
312+
#endif

paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ class MemOptVarInfo {
6666
return skip_memory_reuse_ || skip_all_memory_optimization_;
6767
}
6868

69+
// Stores `parent` in parent_holder_, replacing any previously stored holder.
void SetParentHolder(std::shared_ptr<MemOptVarInfo> parent) {
70+
parent_holder_ = parent;
71+
}
72+
73+
// Returns a copy of the stored parent holder pointer (parent_holder_).
std::shared_ptr<MemOptVarInfo> ParentHolder() const { return parent_holder_; }
74+
6975
const std::string &Name() const { return name_; }
7076

7177
private:
@@ -88,6 +94,9 @@ class MemOptVarInfo {
8894
std::atomic<size_t> runtime_ref_cnt_;
8995
bool skip_memory_reuse_{false};
9096
bool skip_all_memory_optimization_{false};
97+
// points to the var info of the same variable in the main graph;
98+
// used for external (input/output) variables of a subgraph
99+
std::shared_ptr<MemOptVarInfo> parent_holder_{nullptr};
91100
};
92101

93102
using MemOptVarInfoMapList = std::vector<
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <algorithm>
16+
#include "paddle/fluid/framework/details/computation_op_handle.h"
17+
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
18+
#include "paddle/fluid/framework/ir/graph_helper.h"
19+
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
20+
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
21+
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
22+
#include "paddle/fluid/operators/cinn/cinn_launch_op.h"
23+
#include "paddle/fluid/platform/enforce.h"
24+
#include "paddle/fluid/string/string_helper.h"
25+
26+
namespace paddle::framework::ir {
27+
28+
// Lookup table from a variable name to its shared MemOptVarInfo.
using Name2VarInfoMap =
29+
std::unordered_map<std::string, std::shared_ptr<MemOptVarInfo>>;
30+
31+
// Looks through the control-variable outputs of `compute_op` and returns the
// first pending op that is an EagerDeletionOpHandle, or nullptr when no such
// op exists.
static details::EagerDeletionOpHandle* FindFollowedEagerDeletionOp(
    details::ComputationOpHandle* compute_op) {
  for (details::VarHandleBase* output : compute_op->Outputs()) {
    // Only control variables are inspected; every other output is ignored.
    if (output->Node()->IsCtrlVar()) {
      for (details::OpHandleBase* pending_op : output->PendingOps()) {
        if (auto* deletion_op =
                dynamic_cast<details::EagerDeletionOpHandle*>(pending_op)) {
          return deletion_op;
        }
      }
    }
  }
  return nullptr;
}
47+
48+
static void ShareVarInfoToCinnLaunch(
49+
const MemOptVarInfoMapList& varinfo_maps,
50+
details::ComputationOpHandle* cinn_launch_op) {
51+
details::EagerDeletionOpHandle* followed_eager_deletion_op =
52+
FindFollowedEagerDeletionOp(cinn_launch_op);
53+
if (!followed_eager_deletion_op) {
54+
VLOG(4) << "No eager_deletion op found after this cinn_launch op";
55+
return;
56+
}
57+
58+
std::vector<std::string> vars_to_delete =
59+
followed_eager_deletion_op->VarsToDelete();
60+
if (vars_to_delete.empty()) {
61+
VLOG(4) << "No var to be deleted after this cinn_launch op";
62+
return;
63+
}
64+
VLOG(4) << "Variables would be deleted by the eager_deletion_op"
65+
<< " following the cinn_launch:"
66+
<< paddle::string::join_strings(vars_to_delete, ',');
67+
68+
const Graph& subgraph = paddle2cinn::CinnCompiler::GetInstance()->FindGraph(
69+
cinn_launch_op->GetOp()->Attr<std::string>(operators::kCompilationKey));
70+
auto& dst_varinfo_map =
71+
subgraph.Get<Name2VarInfoMap>(paddle2cinn::kMemOptVarInfoFromMainGraph);
72+
const Name2VarInfoMap& src_varinfo_map =
73+
varinfo_maps.at(cinn_launch_op->GetScopeIdx());
74+
75+
// collect all MemOptVarInfos of external variables
76+
// that would be eager deleted after the cinn_launch subgraph executed,
77+
// and store them as attribute of the subgraph
78+
for (const auto& var_name : vars_to_delete) {
79+
auto it = src_varinfo_map.find(var_name);
80+
PADDLE_ENFORCE_NE(it, src_varinfo_map.end(),
81+
platform::errors::NotFound(
82+
"MemOptVarInfo of var[%s] not found", var_name));
83+
dst_varinfo_map.emplace(var_name, it->second);
84+
}
85+
}
86+
87+
static void TakeVarInfoFromMainGraph(
88+
const Name2VarInfoMap& src_varinfo_map,
89+
const MemOptVarInfoMapList& varinfo_maps,
90+
details::EagerDeletionOpHandle* eager_deletion_op) {
91+
const Name2VarInfoMap& dst_varinfo_map =
92+
varinfo_maps.at(eager_deletion_op->GetScopeIdx());
93+
for (auto&& var_name : eager_deletion_op->VarsToDelete()) {
94+
auto dst_it = dst_varinfo_map.find(var_name);
95+
PADDLE_ENFORCE_NE(dst_it, dst_varinfo_map.end(),
96+
platform::errors::NotFound(
97+
"MemOptVarInfo of var[%s] not found", var_name));
98+
auto src_it = src_varinfo_map.find(var_name);
99+
if (src_it != src_varinfo_map.end()) {
100+
VLOG(4) << "MemOptVarInfo of var[" << var_name << "] set parent holder";
101+
dst_it->second->SetParentHolder(src_it->second);
102+
}
103+
}
104+
}
105+
106+
// This pass will be applied on both the main graph and all cinn subgraphs,
107+
// and it distinguishs them according to whether the graph has the
108+
// kMemOptVarInfoFromMainGraph attribute or not.
109+
// On the main graph, it finds all cinn_launch ops and shares MemOptVarInfos
110+
// to their subgraphs.
111+
// On a cinn subgraph, it iterates each variable that will be deleted by a
112+
// eager_deletion op, and take the MemOptVarInfo from the main graph
113+
// if such one found.
114+
class ShareMemOptInfoToSubGraphPass : public ir::Pass {
115+
protected:
116+
void ApplyImpl(ir::Graph* graph) const override {
117+
auto all_ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
118+
const auto& varinfo_maps = Get<MemOptVarInfoMapList>(kMemOptVarInfoMapList);
119+
120+
// the main graph
121+
if (!graph->Has(paddle2cinn::kMemOptVarInfoFromMainGraph)) {
122+
for (details::OpHandleBase* op : all_ops) {
123+
auto compute_op = dynamic_cast<details::ComputationOpHandle*>(op);
124+
if (compute_op && compute_op->Name() == "cinn_launch") {
125+
ShareVarInfoToCinnLaunch(varinfo_maps, compute_op);
126+
}
127+
}
128+
} else { // a cinn subgraph
129+
const auto& parent_varinfo_map =
130+
graph->Get<Name2VarInfoMap>(paddle2cinn::kMemOptVarInfoFromMainGraph);
131+
for (details::OpHandleBase* op : all_ops) {
132+
auto eager_deletion_op =
133+
dynamic_cast<details::EagerDeletionOpHandle*>(op);
134+
if (eager_deletion_op) {
135+
TakeVarInfoFromMainGraph(parent_varinfo_map, varinfo_maps,
136+
eager_deletion_op);
137+
}
138+
}
139+
}
140+
}
141+
};
142+
143+
} // namespace paddle::framework::ir
144+
145+
// Registers ShareMemOptInfoToSubGraphPass under the name
// share_varinfo_into_cinn_pass and declares kMemOptVarInfoMapList as a
// required pass attribute.
REGISTER_PASS(share_varinfo_into_cinn_pass,
146+
paddle::framework::ir::ShareMemOptInfoToSubGraphPass)
147+
.RequirePassAttr(paddle::framework::ir::kMemOptVarInfoMapList);
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <memory>
16+
#include "gtest/gtest.h"
17+
#include "paddle/fluid/framework/details/computation_op_handle.h"
18+
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
19+
#include "paddle/fluid/framework/ir/graph.h"
20+
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
21+
#include "paddle/fluid/framework/ir/pass.h"
22+
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
23+
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
24+
#include "paddle/fluid/framework/parallel_executor.h"
25+
#include "paddle/fluid/framework/program_desc.h"
26+
27+
USE_OP(mul);
28+
USE_OP(cinn_launch);
29+
USE_OP(elementwise_add);
30+
namespace paddle::framework {
31+
32+
// Lookup table from a variable name to its shared ir::MemOptVarInfo.
using Name2VarInfoMap =
33+
std::unordered_map<std::string, std::shared_ptr<ir::MemOptVarInfo>>;
34+
35+
// Builds the program that plays the role of a cinn_launch subgraph in the
// tests: var3 = elementwise_add(var1, var2), then var5 = mul(var3, var4).
static ProgramDesc BuildProgramInsideCinnLaunchOp() {
  ProgramDesc program;
  auto* block = program.MutableBlock(0);
  const char* const var_names[] = {"var1", "var2", "var3", "var4", "var5"};
  for (const char* name : var_names) {
    block->Var(name);
  }

  std::unique_ptr<OpDesc> add_op(
      new OpDesc("elementwise_add", {{"X", {"var1"}}, {"Y", {"var2"}}},
                 {{"Out", {"var3"}}}, {}));
  block->AppendAllocatedOp(std::move(add_op));

  std::unique_ptr<OpDesc> mul_op(new OpDesc(
      "mul", {{"X", {"var3"}}, {"Y", {"var4"}}}, {{"Out", {"var5"}}}, {}));
  block->AppendAllocatedOp(std::move(mul_op));
  return program;
}
53+
54+
// Builds a main program consisting of a single cinn_launch op that reads
// var1, var2 and var4, writes var5, and carries `compilation_key` in its
// "compilation_key" attribute.
static ProgramDesc BuildProgramWithCinnLaunchOp(
    const std::string& compilation_key) {
  ProgramDesc program;
  auto* block = program.MutableBlock(0);
  const char* const var_names[] = {"var1", "var2", "var4", "var5"};
  for (const char* name : var_names) {
    block->Var(name);
  }

  // the only op of the program: cinn_launch
  std::unique_ptr<OpDesc> cinn_launch_op(
      new OpDesc("cinn_launch", {{"X", {"var1", "var2", "var4"}}},
                 {{"Out", {"var5"}}}, {{"compilation_key", compilation_key}}));
  block->AppendAllocatedOp(std::move(cinn_launch_op));
  return program;
}
70+
71+
struct TestPassContext {
72+
explicit TestPassContext(const ProgramDesc& program) {
73+
graph = std::make_unique<ir::Graph>(program);
74+
details::BuildStrategy build_strategy;
75+
details::ExecutionStrategy exec_strategy;
76+
exec_strategy.use_device_ = paddle::platform::kCUDA;
77+
executor.reset(new ParallelExecutor(platform::CUDAPlace(0), &scope,
78+
exec_strategy, build_strategy,
79+
graph.get()));
80+
}
81+
82+
Scope scope;
83+
std::unique_ptr<ir::Graph> graph;
84+
std::unique_ptr<ParallelExecutor> executor;
85+
};
86+
87+
// Verifies the main-graph side of the pass: after a ParallelExecutor has been
// built over a program containing a cinn_launch op, the subgraph registered in
// CinnCompiler holds the MemOptVarInfos of the main graph's variables.
TEST(ShareMemInfoToSubGraphPassTest, test_main_graph_share_varinfo) {
  // add a subgraph to CinnCompiler
  auto subgraph = std::make_unique<ir::Graph>(BuildProgramInsideCinnLaunchOp());
  subgraph->GetOrInit<Name2VarInfoMap>(
      paddle2cinn::kMemOptVarInfoFromMainGraph);
  std::string compilation_key =
      paddle2cinn::CinnCompiler::GetInstance()->AddGraph(std::move(subgraph));

  // build test data and apply pass (done while the context is constructed)
  auto context = std::make_unique<TestPassContext>(
      BuildProgramWithCinnLaunchOp(compilation_key));

  // check result
  const ir::Graph& result_subgraph =
      paddle2cinn::CinnCompiler::GetInstance()->FindGraph(compilation_key);
  const auto& dst_varinfo_map = result_subgraph.Get<Name2VarInfoMap>(
      paddle2cinn::kMemOptVarInfoFromMainGraph);
  // Unsigned literals keep the comparison with size()/count() free of
  // signed/unsigned mismatches.
  ASSERT_EQ(dst_varinfo_map.size(), 4u);
  // The main program declares exactly var1, var2, var4 and var5, so all four
  // must be the shared entries.
  EXPECT_EQ(dst_varinfo_map.count("var1"), 1u);
  EXPECT_EQ(dst_varinfo_map.count("var2"), 1u);
  EXPECT_EQ(dst_varinfo_map.count("var4"), 1u);
  EXPECT_EQ(dst_varinfo_map.count("var5"), 1u);
  // Each shared entry is co-owned by the main graph map and the subgraph map.
  EXPECT_EQ(dst_varinfo_map.at("var1").use_count(), 2);
  EXPECT_EQ(dst_varinfo_map.at("var5").use_count(), 2);
}
110+
111+
// Verifies the subgraph side of the pass: variables that have an entry in the
// kMemOptVarInfoFromMainGraph attribute (var1, var2) get a parent holder,
// while the remaining variables (var3, var4, var5) do not.
TEST(ShareMemInfoToSubGraphPassTest, test_subgraph_take_varinfo) {
  const auto make_info = [](const std::string& name, size_t ref_cnt) {
    return std::make_shared<ir::MemOptVarInfo>(name, ref_cnt);
  };

  // build test data and apply pass
  auto context =
      std::make_unique<TestPassContext>(BuildProgramInsideCinnLaunchOp());
  ir::Graph* graph = context->graph.get();
  auto& varinfo_map_shared = graph->GetOrInit<Name2VarInfoMap>(
      paddle2cinn::kMemOptVarInfoFromMainGraph);
  varinfo_map_shared = {
      {"var1", make_info("var1", 1)},
      {"var2", make_info("var2", 2)},
  };

  // one scope whose map holds a var info for each of the five variables
  ir::MemOptVarInfoMapList varinfo_maps(1);
  auto& dst_varinfo_map = varinfo_maps.front();
  const char* const var_names[] = {"var1", "var2", "var3", "var4", "var5"};
  for (const char* name : var_names) {
    dst_varinfo_map.emplace(name, make_info(name, 1));
  }

  auto share_pass =
      ir::PassRegistry::Instance().Get("share_varinfo_into_cinn_pass");
  share_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &varinfo_maps);
  share_pass->Apply(graph);

  // check result
  ASSERT_NE(dst_varinfo_map.at("var1")->ParentHolder(), nullptr);
  ASSERT_NE(dst_varinfo_map.at("var2")->ParentHolder(), nullptr);
  ASSERT_EQ(dst_varinfo_map.at("var3")->ParentHolder(), nullptr);
  ASSERT_EQ(dst_varinfo_map.at("var4")->ParentHolder(), nullptr);
  ASSERT_EQ(dst_varinfo_map.at("var5")->ParentHolder(), nullptr);
}
141+
142+
} // namespace paddle::framework

0 commit comments

Comments
 (0)