@@ -357,32 +357,6 @@ def check_avg_divisor(self, place):
357357 result = avg_pool2d_dg (input )
358358 np .testing .assert_allclose (result .numpy (), result_np , rtol = 1e-05 )
359359
360- def test_pool2d (self ):
361- for place in self .places :
362- self .check_max_dygraph_results (place )
363- self .check_avg_dygraph_results (place )
364- self .check_max_dygraph_stride_is_none (place )
365- self .check_avg_dygraph_stride_is_none (place )
366- self .check_max_dygraph_padding (place )
367- self .check_avg_divisor (place )
368- self .check_max_dygraph_padding_results (place )
369- self .check_max_dygraph_ceilmode_results (place )
370- self .check_max_dygraph_nhwc_results (place )
371- self .check_lp_dygraph_results (place )
372- self .check_lp_dygraph_stride_is_none (place )
373- self .check_lp_dygraph_ceilmode_results (place )
374- self .check_lp_dygraph_nhwc_results (place )
375- self .check_lp_dygraph_results_norm_type_is_inf (place )
376- self .check_lp_dygraph_results_norm_type_is_negative_inf (place )
377-
378- @test_with_pir_api
379- def test_pool2d_static (self ):
380- paddle .enable_static ()
381- for place in self .places :
382- self .check_max_static_results (place )
383- self .check_avg_static_results (place )
384- paddle .disable_static ()
385-
386360 def check_lp_static_results (self , place ):
387361 with paddle .static .program_guard (
388362 paddle .static .Program (), paddle .static .Program ()
@@ -624,6 +598,33 @@ def check_lp_dygraph_stride_is_none(self, place):
624598 result = lp_pool2d_dg (input )
625599 np .testing .assert_allclose (result .numpy (), result_np , rtol = 1e-05 )
626600
601+ @test_with_pir_api
602+ def test_pool2d_static (self ):
603+ paddle .enable_static ()
604+ for place in self .places :
605+ self .check_max_static_results (place )
606+ self .check_avg_static_results (place )
607+ self .check_lp_static_results (place )
608+ paddle .disable_static ()
609+
610+ def test_pool2d (self ):
611+ for place in self .places :
612+ self .check_max_dygraph_results (place )
613+ self .check_avg_dygraph_results (place )
614+ self .check_max_dygraph_stride_is_none (place )
615+ self .check_avg_dygraph_stride_is_none (place )
616+ self .check_max_dygraph_padding (place )
617+ self .check_avg_divisor (place )
618+ self .check_max_dygraph_padding_results (place )
619+ self .check_max_dygraph_ceilmode_results (place )
620+ self .check_max_dygraph_nhwc_results (place )
621+ self .check_lp_dygraph_results (place )
622+ self .check_lp_dygraph_stride_is_none (place )
623+ self .check_lp_dygraph_ceilmode_results (place )
624+ self .check_lp_dygraph_nhwc_results (place )
625+ self .check_lp_dygraph_results_norm_type_is_inf (place )
626+ self .check_lp_dygraph_results_norm_type_is_negative_inf (place )
627+
627628
628629class TestPool2DError_API (unittest .TestCase ):
629630 def test_error_api (self ):
0 commit comments