@@ -52,11 +52,11 @@ typedef struct {
5252// The traversal order also affects the lifecycles, so different sort_kind is
5353// used.
5454void MemoryOptimizePass::CollectLifeCycle (
55- std::unordered_map<std::string, lifecycle_t >* lifecycles,
55+ Graph* graph, std::unordered_map<std::string, lifecycle_t >* lifecycles,
5656 int sort_kind) const {
57- max_lifecycle_ = 0 ;
57+ int max_lifecycle = 0 ;
5858 for (auto * op_node : framework::ir::TopologyVarientSort (
59- *graph_ , static_cast <framework::ir::SortKind>(sort_kind))) {
59+ *graph , static_cast <framework::ir::SortKind>(sort_kind))) {
6060 if (!op_node->IsOp ()) continue ;
6161 auto reads = op_node->inputs ;
6262 auto writes = op_node->outputs ;
@@ -77,20 +77,20 @@ void MemoryOptimizePass::CollectLifeCycle(
7777 if (node->Var ()->Persistable ()) continue ;
7878 std::string var = node->Name ();
7979 if (!lifecycles->count (var)) {
80- (*lifecycles)[var] = std::make_pair (max_lifecycle_, max_lifecycle_ );
80+ (*lifecycles)[var] = std::make_pair (max_lifecycle, max_lifecycle );
8181 } else {
8282 (*lifecycles)[var].second =
83- std::max (max_lifecycle_ , lifecycles->at (var).second ); // max()
83+ std::max (max_lifecycle , lifecycles->at (var).second ); // max()
8484 }
8585 }
8686 }
8787
88- ++max_lifecycle_ ;
88+ ++max_lifecycle ;
8989 }
9090}
9191
9292void MemoryOptimizePass::CollectVarMemorySize (
93- space_table_t * space_table) const {
93+ Graph* graph, space_table_t * space_table) const {
9494 const int fake_batch_size = 1 ;
9595
9696 auto valid_var = [&](framework::ir::Node* node) -> bool {
@@ -130,7 +130,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
130130 // although it's not always the case. so black list is the best compromise
131131 // between performance and underlying principle.
132132 std::unordered_set<std::string> black_list;
133- for (auto * node : graph_ ->Nodes ()) {
133+ for (auto * node : graph ->Nodes ()) {
134134 if (node->IsVar () &&
135135 node->Var ()->GetType () ==
136136 framework::proto::VarType::Type::VarType_Type_LOD_TENSOR) {
@@ -141,7 +141,7 @@ void MemoryOptimizePass::CollectVarMemorySize(
141141 }
142142
143143 // Collect tensors from graph.
144- for (auto * node : graph_ ->Nodes ()) {
144+ for (auto * node : graph ->Nodes ()) {
145145 if (node->IsVar () &&
146146 node->Var ()->GetType () ==
147147 framework::proto::VarType::Type::VarType_Type_LOD_TENSOR &&
@@ -304,18 +304,21 @@ void MemoryOptimizePass::RunImpl(Argument* argument) {
304304 // 3. Perform reuse plan: Replace all var's name in the model according to the
305305 // mapping table.
306306 if (!argument->enable_memory_optim ()) return ;
307- graph_ = argument->main_graph_ptr ();
307+ // Because the pass is a singleton, the graph cannot be kept as a member
308+ // variable; otherwise, errors would be caused under multithreaded
309+ // conditions.
310+ auto graph = argument->main_graph_ptr ();
308311
309312 int sort_kind = 0 ;
310313 std::unordered_map<std::string, lifecycle_t > lifecycles;
311314 space_table_t space_table;
312315 std::unordered_map<std::string, std::string> node2cluster;
313316 std::unordered_map<std::string, int > cluster_size;
314317
315- CollectLifeCycle (&lifecycles, sort_kind);
316- CollectVarMemorySize (&space_table);
318+ CollectLifeCycle (graph, &lifecycles, sort_kind);
319+ CollectVarMemorySize (graph, &space_table);
317320 MakeSimpleReusePlan (lifecycles, space_table, &node2cluster, &cluster_size);
318- UpdateOpDescsByReuse (graph_ , node2cluster, sort_kind);
321+ UpdateOpDescsByReuse (graph , node2cluster, sort_kind);
319322 return ;
320323}
321324
0 commit comments