Skip to content

Commit 114bf00

Browse files
fix exampler and enable static test
1 parent 1dfb708 commit 114bf00

File tree

2 files changed

+28
-27
lines changed

2 files changed

+28
-27
lines changed

python/paddle/nn/layer/pooling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ class LPPool2D(Layer):
384384
385385
>>> # lp pool2d
386386
>>> input = paddle.uniform([1, 3, 32, 32], dtype="float32", min=-1, max=1)
387-
>>> AvgPool2D = nn.LPPool2D(norm_type=2, kernel_size=2, stride=2, padding=0)
387+
>>> LPPool2D = nn.LPPool2D(norm_type=2, kernel_size=2, stride=2, padding=0)
388388
>>> output = LPPool2D(input)
389389
>>> print(output.shape)
390390
[1, 3, 16, 16]

test/legacy_test/test_pool2d_api.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -357,32 +357,6 @@ def check_avg_divisor(self, place):
357357
result = avg_pool2d_dg(input)
358358
np.testing.assert_allclose(result.numpy(), result_np, rtol=1e-05)
359359

360-
def test_pool2d(self):
361-
for place in self.places:
362-
self.check_max_dygraph_results(place)
363-
self.check_avg_dygraph_results(place)
364-
self.check_max_dygraph_stride_is_none(place)
365-
self.check_avg_dygraph_stride_is_none(place)
366-
self.check_max_dygraph_padding(place)
367-
self.check_avg_divisor(place)
368-
self.check_max_dygraph_padding_results(place)
369-
self.check_max_dygraph_ceilmode_results(place)
370-
self.check_max_dygraph_nhwc_results(place)
371-
self.check_lp_dygraph_results(place)
372-
self.check_lp_dygraph_stride_is_none(place)
373-
self.check_lp_dygraph_ceilmode_results(place)
374-
self.check_lp_dygraph_nhwc_results(place)
375-
self.check_lp_dygraph_results_norm_type_is_inf(place)
376-
self.check_lp_dygraph_results_norm_type_is_negative_inf(place)
377-
378-
@test_with_pir_api
379-
def test_pool2d_static(self):
380-
paddle.enable_static()
381-
for place in self.places:
382-
self.check_max_static_results(place)
383-
self.check_avg_static_results(place)
384-
paddle.disable_static()
385-
386360
def check_lp_static_results(self, place):
387361
with paddle.static.program_guard(
388362
paddle.static.Program(), paddle.static.Program()
@@ -624,6 +598,33 @@ def check_lp_dygraph_stride_is_none(self, place):
624598
result = lp_pool2d_dg(input)
625599
np.testing.assert_allclose(result.numpy(), result_np, rtol=1e-05)
626600

601+
@test_with_pir_api
602+
def test_pool2d_static(self):
603+
paddle.enable_static()
604+
for place in self.places:
605+
self.check_max_static_results(place)
606+
self.check_avg_static_results(place)
607+
self.check_lp_static_results(place)
608+
paddle.disable_static()
609+
610+
def test_pool2d(self):
611+
for place in self.places:
612+
self.check_max_dygraph_results(place)
613+
self.check_avg_dygraph_results(place)
614+
self.check_max_dygraph_stride_is_none(place)
615+
self.check_avg_dygraph_stride_is_none(place)
616+
self.check_max_dygraph_padding(place)
617+
self.check_avg_divisor(place)
618+
self.check_max_dygraph_padding_results(place)
619+
self.check_max_dygraph_ceilmode_results(place)
620+
self.check_max_dygraph_nhwc_results(place)
621+
self.check_lp_dygraph_results(place)
622+
self.check_lp_dygraph_stride_is_none(place)
623+
self.check_lp_dygraph_ceilmode_results(place)
624+
self.check_lp_dygraph_nhwc_results(place)
625+
self.check_lp_dygraph_results_norm_type_is_inf(place)
626+
self.check_lp_dygraph_results_norm_type_is_negative_inf(place)
627+
627628

628629
class TestPool2DError_API(unittest.TestCase):
629630
def test_error_api(self):

0 commit comments

Comments
 (0)