@@ -13,13 +13,16 @@ See the License for the specific language governing permissions and
1313limitations under the License. */
1414
1515#include " paddle/fluid/operators/batch_norm_op.h"
16+ #include " paddle/fluid/operators/amp/fp16_type_traits.h"
1617#include " paddle/fluid/operators/mlu/mlu_baseop.h"
1718
1819namespace paddle {
1920namespace operators {
2021
2122template <typename T>
2223class MLUBatchNormOpKernel : public framework ::OpKernel<T> {
24+ using MPDType = typename details::MPTypeTrait<T>::Type;
25+
2326 public:
2427 void Compute (const framework::ExecutionContext &ctx) const override {
2528 const auto &place = ctx.GetPlace ();
@@ -68,10 +71,10 @@ class MLUBatchNormOpKernel : public framework::OpKernel<T> {
6871
6972 // alloc memory
7073 y->mutable_data <T>(place);
71- mean_out->mutable_data <T >(place);
72- variance_out->mutable_data <T >(place);
73- saved_mean->mutable_data <T >(place);
74- saved_variance->mutable_data <T >(place);
74+ mean_out->mutable_data <MPDType >(place);
75+ variance_out->mutable_data <MPDType >(place);
76+ saved_mean->mutable_data <MPDType >(place);
77+ saved_variance->mutable_data <MPDType >(place);
7578
7679 Tensor transformed_x;
7780 Tensor transformed_y;
@@ -132,6 +135,8 @@ class MLUBatchNormOpKernel : public framework::OpKernel<T> {
132135
133136template <typename T>
134137class MLUBatchNormGradOpKernel : public framework ::OpKernel<T> {
138+ using MPDType = typename details::MPTypeTrait<T>::Type;
139+
135140 public:
136141 void Compute (const framework::ExecutionContext &ctx) const override {
137142 const auto *x = ctx.Input <Tensor>(" X" );
@@ -154,10 +159,10 @@ class MLUBatchNormGradOpKernel : public framework::OpKernel<T> {
154159 auto &dev_ctx = ctx.template device_context <MLUDeviceContext>();
155160 auto d_x_tmp =
156161 ctx.AllocateTmpTensor <T, MLUDeviceContext>(x->dims (), dev_ctx);
157- auto scale_grad_tmp =
158- ctx. AllocateTmpTensor <T, MLUDeviceContext>( scale->dims (), dev_ctx);
162+ auto scale_grad_tmp = ctx. AllocateTmpTensor <MPDType, MLUDeviceContext>(
163+ scale->dims (), dev_ctx);
159164 auto bias_grad_tmp =
160- ctx.AllocateTmpTensor <T , MLUDeviceContext>(bias->dims (), dev_ctx);
165+ ctx.AllocateTmpTensor <MPDType , MLUDeviceContext>(bias->dims (), dev_ctx);
161166
162167 if (d_x == nullptr ) {
163168 d_x = &d_x_tmp;
@@ -171,8 +176,8 @@ class MLUBatchNormGradOpKernel : public framework::OpKernel<T> {
171176
172177 const auto &place = ctx.GetPlace ();
173178 d_x->mutable_data <T>(place);
174- d_scale->mutable_data <T >(place);
175- d_bias->mutable_data <T >(place);
179+ d_scale->mutable_data <MPDType >(place);
180+ d_bias->mutable_data <MPDType >(place);
176181
177182 use_global_stats = is_test || use_global_stats;
178183
0 commit comments