Skip to content

Commit c1c5c1f

Browse files
authored
adaptive pool2d pass fix (#39600)
* first commit * teller fix * bug fix * enable for pool2d only * fix global_pooling issue * pooling_type * fix test
1 parent db43b54 commit c1c5c1f

File tree

4 files changed

+27
-26
lines changed

4 files changed

+27
-26
lines changed

paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,18 @@ void AdaptivePool2dConvertGlobalPass::ApplyImpl(ir::Graph* graph) const {
7272
for (const Node* n : graph->Nodes()) {
7373
if (n->IsOp()) {
7474
auto* op = n->Op();
75-
if (op->HasAttr("adaptive") && op->HasAttr("ksize")) {
75+
if (op->Type() == "pool2d" && op->HasAttr("adaptive") &&
76+
op->HasAttr("ksize")) {
77+
if (op->HasAttr("global_pooling")) {
78+
bool global_pooling =
79+
BOOST_GET_CONST(bool, op->GetAttr("global_pooling"));
80+
if (global_pooling) return;
81+
}
82+
if (!op->HasAttr("pooling_type")) return;
83+
std::string type =
84+
BOOST_GET_CONST(std::string, op->GetAttr("pooling_type"));
85+
// adaptive has no effect on max pooling
86+
if (type == "max") return;
7687
bool adaptive = BOOST_GET_CONST(bool, op->GetAttr("adaptive"));
7788
std::vector<int> ksize =
7889
BOOST_GET_CONST(std::vector<int>, op->GetAttr("ksize"));

paddle/fluid/framework/ir/adaptive_pool2d_convert_global_pass_tester.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ TEST(AdaptivePool2dConvertGlobalPass, basic) {
2929
AttributeMap attrs;
3030
attrs["adaptive"] = true;
3131
attrs["ksize"] = std::vector<int>{1, 1};
32+
attrs["pooling_type"] =
33+
std::string("avg"); // adaptive has no effect on max pooling
3234
layers.pool2d(x, false, &attrs);
3335

3436
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));

paddle/fluid/inference/tensorrt/op_teller.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,13 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
225225
<< desc.Output("Out").size();
226226
return false;
227227
}
228+
if (desc.HasAttr("data_format")) {
229+
std::string data_format =
230+
BOOST_GET_CONST(std::string, desc.GetAttr("data_format"));
231+
if (data_format == "NHWC" || data_format == "NDHWC") {
232+
return false;
233+
}
234+
}
228235
if (!desc.HasAttr("pooling_type")) {
229236
return false;
230237
} else {

python/paddle/fluid/tests/unittests/ir/inference/test_adaptive_pool2d_convert_global_pass_autoscan.py

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,14 @@ def sample_program_config(self, draw):
4242
st.integers(
4343
min_value=1, max_value=4), min_size=2, max_size=2))
4444

45-
paddings = [0, 0] # only 0 0 is right
45+
paddings = draw(
46+
st.lists(
47+
st.integers(
48+
min_value=1, max_value=4), min_size=2, max_size=2))
49+
4650
ceil_mode = draw(st.booleans())
4751
exclusive = draw(st.booleans())
48-
global_pooling = False #only false is right
52+
global_pooling = draw(st.booleans())
4953
padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VAILD"]))
5054

5155
pool_op = OpConfig(
@@ -83,29 +87,6 @@ def sample_predictor_configs(self, program_config):
8387
use_calib_mode=False)
8488
yield config, ['pool2d'], (1e-5, 1e-5)
8589

86-
def add_ignore_pass_case(self):
87-
# Here we put some skip rules to avoid known bugs
88-
def teller1(program_config, predictor_config):
89-
if program_config.ops[0].attrs["pooling_type"] == "max":
90-
x_shape = list(program_config.inputs["input_data"].shape)
91-
if x_shape[-1] != 1 or x_shape[-2] != 1:
92-
return True
93-
return False
94-
95-
def teller2(program_config, predictor_config):
96-
if program_config.ops[0].attrs["padding_algorithm"] == "SAME":
97-
return True
98-
return False
99-
100-
self.add_ignore_check_case(
101-
teller1,
102-
IgnoreReasons.PASS_ACCURACY_ERROR,
103-
"max pooling has diff if H or W is not equals to 1", )
104-
self.add_ignore_check_case(
105-
teller2,
106-
IgnoreReasons.PASS_ACCURACY_ERROR,
107-
"output has wrong result if padding_algorithm equals to SAME", )
108-
10990
def test(self):
11091
self.run_and_statis(
11192
quant=False,

0 commit comments

Comments
 (0)