Skip to content

Commit 96a4c31

Browse files
authored
[CINN] Add merge_all_horizontal_groups flag (#72775)
1 parent fc24bc7 commit 96a4c31

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

paddle/common/flags.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,16 @@ PHI_DEFINE_EXPORTED_bool(enable_fusion_result_check,
14021402
false,
14031403
"Whether enable fusion result check in cinn.");
14041404

1405+
/**
1406+
* CINN all horizontal groups merge FLAG
1407+
* Name: FLAGS_merge_all_horizontal_groups
1408+
* Since Version: 3.0
1409+
* Value Range: bool, default=false
1410+
*/
1411+
PHI_DEFINE_EXPORTED_bool(merge_all_horizontal_groups,
1412+
false,
1413+
"Whether enable merge all horizontal groups in cinn.");
1414+
14051415
/**
14061416
* Conv Search cache max number related FLAG
14071417
* Name: FLAGS_search_cache_max_number

paddle/fluid/pir/transforms/sub_graph_detector.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
#include "paddle/fluid/pir/dialect/operator/trait/onednn.h"
5252
#endif
5353

54+
COMMON_DECLARE_bool(merge_all_horizontal_groups);
5455
REGISTER_FILE_SYMBOLS(sub_graph_detector);
5556
namespace pir {
5657
std::vector<pir::Operation*> InverselyTopologicalSort(pir::Block* block) {
@@ -782,7 +783,7 @@ void SubgraphDetector::SubgraphFusion() {
782783
auto subgraph_list = GetSubgraphList();
783784
VLOG(4) << "Merge non-related subgraphs (size=" << subgraph_list.size()
784785
<< ")";
785-
if (subgraph_list.size() > 2048) return;
786+
if (subgraph_list.size() > 2048 && !FLAGS_merge_all_horizontal_groups) return;
786787
for (size_t i = 0; i < subgraph_list.size(); ++i) {
787788
auto lhs = subgraph_list[i];
788789
if (!lhs->substitute) continue;

0 commit comments

Comments
 (0)