@@ -2812,6 +2812,113 @@ TEST(MeanAll, Ctor) {
28122812 check_dim_mapping (backward_info.first [1 ], {});
28132813 check_dim_mapping (backward_info.second [0 ], {-1 , -1 });
28142814}
2815+ TEST (BatchNorm, Ctor) {
2816+ std::vector<int64_t > mesh_shape = {2 , 2 };
2817+ std::vector<int64_t > process_ids = {0 , 1 , 2 , 3 };
2818+ std::vector<std::string> dim_names = {" x" , " y" };
2819+ ProcessMesh process_mesh (mesh_shape, process_ids, dim_names);
2820+
2821+ // test forward
2822+ // data_format = NCHW
2823+ // [0, 1, -1, -1],[-1],[-1],[-1],[-1] ->[-1 , 1, -1, -1],[1],[1],[1],[1],[-1]
2824+ auto x_dist_attr = TensorDistAttr ();
2825+ x_dist_attr.set_process_mesh (process_mesh);
2826+ x_dist_attr.set_dims_mapping ({0 , 1 , -1 , -1 });
2827+ x_dist_attr.set_dynamic_dims ({false , false , false , false });
2828+ auto one_dim_dist_attr = TensorDistAttr ();
2829+ one_dim_dist_attr.set_process_mesh (process_mesh);
2830+ one_dim_dist_attr.set_dims_mapping ({-1 });
2831+ one_dim_dist_attr.set_dynamic_dims ({false });
2832+
2833+ phi::distributed::DistMetaTensor x = phi::distributed::DistMetaTensor (
2834+ common::make_ddim ({16 , 16 , 16 , 16 }), x_dist_attr);
2835+ phi::distributed::DistMetaTensor mean = phi::distributed::DistMetaTensor (
2836+ common::make_ddim ({16 }), one_dim_dist_attr);
2837+ phi::distributed::DistMetaTensor variance = phi::distributed::DistMetaTensor (
2838+ common::make_ddim ({16 }), one_dim_dist_attr);
2839+ phi::distributed::DistMetaTensor scale = phi::distributed::DistMetaTensor (
2840+ common::make_ddim ({16 }), one_dim_dist_attr);
2841+ phi::distributed::DistMetaTensor bias = phi::distributed::DistMetaTensor (
2842+ common::make_ddim ({16 }), one_dim_dist_attr);
2843+ phi::distributed::SpmdInfo forward_info =
2844+ phi::distributed::BatchNormInferSpmdStatic (
2845+ x, mean, variance, scale, bias);
2846+
2847+ EXPECT_EQ (forward_info.first .size (), 5UL );
2848+ EXPECT_EQ (forward_info.second .size (), 6UL );
2849+ check_dim_mapping (forward_info.first [0 ], {-1 , 1 , -1 , -1 });
2850+ check_dim_mapping (forward_info.first [1 ], {1 });
2851+ check_dim_mapping (forward_info.first [2 ], {1 });
2852+ check_dim_mapping (forward_info.first [3 ], {-1 });
2853+ check_dim_mapping (forward_info.first [4 ], {-1 });
2854+ check_dim_mapping (forward_info.second [0 ], {-1 , 1 , -1 , -1 });
2855+ check_dim_mapping (forward_info.second [1 ], {1 });
2856+ check_dim_mapping (forward_info.second [2 ], {1 });
2857+ check_dim_mapping (forward_info.second [3 ], {1 });
2858+ check_dim_mapping (forward_info.second [4 ], {1 });
2859+ check_dim_mapping (forward_info.second [5 ], {-1 });
2860+
2861+ // test backward
2862+ // data_format = NCHW
2863+ // [0, 1, -1, -1],[-1],[-1],[-1],[-1],[-1],[-1],[-1],[0, 1, -1, -1]
2864+ // ->[-1,1,-1,-1],[-1],[-1]
2865+ // dst_input: [-1, 1, -1, -1],[-1],[-1],[1],[1],[1],[1],[-1],[-1, 1, -1, -1]
2866+
2867+ x = phi::distributed::DistMetaTensor (common::make_ddim ({16 , 16 , 16 , 16 }),
2868+ x_dist_attr);
2869+ phi::distributed::DistMetaTensor out_grad = phi::distributed::DistMetaTensor (
2870+ common::make_ddim ({16 , 16 , 16 , 16 }), x_dist_attr);
2871+ phi::distributed::DistMetaTensor mean_out = phi::distributed::DistMetaTensor (
2872+ common::make_ddim ({16 }), one_dim_dist_attr);
2873+ phi::distributed::DistMetaTensor variance_out =
2874+ phi::distributed::DistMetaTensor (common::make_ddim ({16 }),
2875+ one_dim_dist_attr);
2876+ scale = phi::distributed::DistMetaTensor (common::make_ddim ({16 }),
2877+ one_dim_dist_attr);
2878+ bias = phi::distributed::DistMetaTensor (common::make_ddim ({16 }),
2879+ one_dim_dist_attr);
2880+ phi::distributed::DistMetaTensor saved_mean =
2881+ phi::distributed::DistMetaTensor (common::make_ddim ({16 }),
2882+ one_dim_dist_attr);
2883+ phi::distributed::DistMetaTensor saved_variance =
2884+ phi::distributed::DistMetaTensor (common::make_ddim ({16 }),
2885+ one_dim_dist_attr);
2886+ phi::distributed::DistMetaTensor reserve_space =
2887+ phi::distributed::DistMetaTensor (common::make_ddim ({16 }),
2888+ one_dim_dist_attr);
2889+ phi::distributed::SpmdInfo backward_info =
2890+ phi::distributed::BatchNormGradInferSpmd (x,
2891+ scale,
2892+ bias,
2893+ mean_out,
2894+ variance_out,
2895+ saved_mean,
2896+ saved_variance,
2897+ reserve_space,
2898+ out_grad,
2899+ 0.9 ,
2900+ 0.1 ,
2901+ " NCHW" ,
2902+ false ,
2903+ false ,
2904+ false );
2905+
2906+ EXPECT_EQ (backward_info.first .size (), 9UL );
2907+ EXPECT_EQ (backward_info.second .size (), 3UL );
2908+ check_dim_mapping (backward_info.first [0 ], {-1 , 1 , -1 , -1 });
2909+ check_dim_mapping (backward_info.first [1 ], {-1 });
2910+ check_dim_mapping (backward_info.first [2 ], {-1 });
2911+ check_dim_mapping (backward_info.first [3 ], {1 });
2912+ check_dim_mapping (backward_info.first [4 ], {1 });
2913+ check_dim_mapping (backward_info.first [5 ], {1 });
2914+ check_dim_mapping (backward_info.first [6 ], {1 });
2915+ check_dim_mapping (backward_info.first [7 ], {-1 });
2916+ check_dim_mapping (backward_info.first [8 ], {-1 , 1 , -1 , -1 });
2917+
2918+ check_dim_mapping (backward_info.second [0 ], {-1 , 1 , -1 , -1 });
2919+ check_dim_mapping (backward_info.second [1 ], {-1 });
2920+ check_dim_mapping (backward_info.second [2 ], {-1 });
2921+ }
28152922TEST (Topk, Ctor) {
28162923 std::vector<int64_t > mesh_shape = {2 , 2 };
28172924 std::vector<int64_t > process_ids = {0 , 1 , 2 , 3 };
0 commit comments