@@ -48,13 +48,13 @@ struct MergedMomentumKernelParam
4848 T *PADDLE_RESTRICT params[N];
4949 const T *PADDLE_RESTRICT grads[N];
5050 MT *PADDLE_RESTRICT velocitys[N];
51- const MT *PADDLE_RESTRICT lr;
51+ const MultiPrecisionType<MT> *PADDLE_RESTRICT lr;
5252 MT mu;
5353 MT rescale_grad;
5454 uint32_t param_num;
5555
5656 HOSTDEVICE void operator ()(size_t i) const {
57- const auto lr_val = *lr;
57+ const MT lr_val = static_cast <MT>( *lr) ;
5858 for (uint32_t idx = 0 ; idx < param_num; ++idx) {
5959 auto size = sizes[idx];
6060 if (i >= size) continue ;
@@ -81,8 +81,22 @@ struct MergedMomentumKernelParam
8181
8282template <typename DeviceContext, typename T>
8383class MergedMomentumOpKernel : public framework ::OpKernel<T> {
84+ using MPType = typename operators::details::MPTypeTrait<T>::Type;
85+
8486 public:
// Kernel entry point. Reads the boolean multi-precision attribute from the
// execution context and forwards to InnerCompute, selecting the math type MT
// at compile time:
//   - attribute true  -> MT = MPType (i.e. MPTypeTrait<T>::Type)
//   - attribute false -> MT = T
// The flag value is passed along as the second argument to InnerCompute.
8587 void Compute (const framework::ExecutionContext &ctx) const override {
88+ const bool multi_precision = ctx.Attr <bool >(" multi_precision" );
89+ if (multi_precision) {
90+ InnerCompute<MPType>(ctx, multi_precision);
91+ } else {
92+ InnerCompute<T>(ctx, multi_precision);
93+ }
94+ }
95+
96+ private:
97+ template <typename MT>
98+ void InnerCompute (const framework::ExecutionContext &ctx,
99+ const bool multi_precision) const {
86100 auto params = ctx.MultiInput <framework::Tensor>(" Param" );
87101 auto params_out = ctx.MultiOutput <framework::Tensor>(" ParamOut" );
88102 size_t n = params.size ();
@@ -133,7 +147,6 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
133147 auto master_params = ctx.MultiInput <framework::Tensor>(" MasterParam" );
134148 auto master_params_out =
135149 ctx.MultiOutput <framework::Tensor>(" MasterParamOut" );
136- auto multi_precision = ctx.Attr <bool >(" multi_precision" );
137150 if (multi_precision) {
138151 PADDLE_ENFORCE_EQ (
139152 n, master_params.size (),
@@ -206,39 +219,37 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
206219 << " , regularization_coeffs.size(): "
207220 << regularization_coeffs.size ();
208221
209- using MPType = typename operators::details::MPTypeTrait<T>::Type;
210-
211222 auto &dev_ctx = ctx.template device_context <DeviceContext>();
212223
213224 if (lrs.size () == 1 && use_nesterov == false &&
214225 regularization_methods.size () == 0 ) {
215- #define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL (kMultiPrecision ) \
216- MergedMomentumKernelParam<T, MPType , kMultiPrecision > kernel_params; \
217- constexpr auto kMaxMergedNum = decltype (kernel_params)::N; \
218- size_t kernel_num = (n + kMaxMergedNum - 1 ) / kMaxMergedNum ; \
219- kernel_params.mu = static_cast <MPType >(mu); \
220- kernel_params.rescale_grad = static_cast <MPType >(rescale_grad); \
221- kernel_params.lr = lrs[0 ]->data <MPType>(); \
222- for (size_t i = 0 ; i < kernel_num; ++i) { \
223- size_t start = i * kMaxMergedNum ; \
224- size_t end = std::min ((i + 1 ) * kMaxMergedNum , n); \
225- kernel_params.param_num = static_cast <uint32_t >(end - start); \
226- size_t max_size = 0 ; \
227- for (size_t j = 0 ; j < kernel_params.param_num ; ++j) { \
228- auto size = static_cast <size_t >(params_out[j + start]->numel ()); \
229- max_size = std::max (max_size, size); \
230- kernel_params.sizes [j] = size; \
231- kernel_params.params [j] = params_out[j + start]->data <T>(); \
232- kernel_params.grads [j] = grads[j + start]->data <T>(); \
233- kernel_params.velocitys [j] = velocitys_out[j + start]->data <MPType >(); \
234- kernel_params.SetMasterParam ( \
235- j, kMultiPrecision ? master_params_out[j + start]->data <MPType >() \
236- : nullptr ); \
237- } \
238- platform::ForRange<DeviceContext> for_range (dev_ctx, max_size); \
239- for_range (kernel_params); \
240- VLOG (10 ) << " Launch MergedMomentum kernel " << i << " " \
241- << kernel_params.param_num ; \
226+ #define PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL (kMultiPrecision ) \
227+ MergedMomentumKernelParam<T, MT , kMultiPrecision > kernel_params; \
228+ constexpr auto kMaxMergedNum = decltype (kernel_params)::N; \
229+ size_t kernel_num = (n + kMaxMergedNum - 1 ) / kMaxMergedNum ; \
230+ kernel_params.mu = static_cast <MT >(mu); \
231+ kernel_params.rescale_grad = static_cast <MT >(rescale_grad); \
232+ kernel_params.lr = lrs[0 ]->data <MPType>(); \
233+ for (size_t i = 0 ; i < kernel_num; ++i) { \
234+ size_t start = i * kMaxMergedNum ; \
235+ size_t end = std::min ((i + 1 ) * kMaxMergedNum , n); \
236+ kernel_params.param_num = static_cast <uint32_t >(end - start); \
237+ size_t max_size = 0 ; \
238+ for (size_t j = 0 ; j < kernel_params.param_num ; ++j) { \
239+ auto size = static_cast <size_t >(params_out[j + start]->numel ()); \
240+ max_size = std::max (max_size, size); \
241+ kernel_params.sizes [j] = size; \
242+ kernel_params.params [j] = params_out[j + start]->data <T>(); \
243+ kernel_params.grads [j] = grads[j + start]->data <T>(); \
244+ kernel_params.velocitys [j] = velocitys_out[j + start]->data <MT >(); \
245+ kernel_params.SetMasterParam ( \
246+ j, kMultiPrecision ? master_params_out[j + start]->data <MT >() \
247+ : nullptr ); \
248+ } \
249+ platform::ForRange<DeviceContext> for_range (dev_ctx, max_size); \
250+ for_range (kernel_params); \
251+ VLOG (10 ) << " Launch MergedMomentum kernel " << i << " " \
252+ << kernel_params.param_num ; \
242253 }
243254 if (multi_precision) {
244255 PADDLE_LAUNCH_MERGED_MOMENTUM_KERNEL (true );
@@ -254,34 +265,33 @@ class MergedMomentumOpKernel : public framework::OpKernel<T> {
254265 ? RegularizationType::kL2DECAY
255266 : RegularizationType::kNONE ;
256267
257- MPType regularization_coeff = static_cast <MPType >(0.0 );
268+ MT regularization_coeff = static_cast <MT >(0.0 );
258269 if (regularization_coeffs.size () != 0 ) {
259- regularization_coeff =
260- static_cast <MPType>(regularization_coeffs[idx]);
270+ regularization_coeff = static_cast <MT>(regularization_coeffs[idx]);
261271 }
262272 auto lr_temp = lrs.size () > 1 ? lrs[idx] : lrs[0 ];
263273
264- const MPType *master_in_data =
265- multi_precision ? master_params[idx]->data <MPType >() : nullptr ;
266- MPType *master_out_data =
267- multi_precision ? master_params_out[idx]->data <MPType >() : nullptr ;
274+ const MT *master_in_data =
275+ multi_precision ? master_params[idx]->data <MT >() : nullptr ;
276+ MT *master_out_data =
277+ multi_precision ? master_params_out[idx]->data <MT >() : nullptr ;
268278 if (platform::is_cpu_place (ctx.GetPlace ())) {
269- CPUDenseMomentumFunctor<MPType > functor;
270- functor (params[idx], grads[idx], velocitys[idx], lr_temp, mu,
271- use_nesterov, regularization_flag, regularization_coeff ,
272- params_out[idx], velocitys_out[idx]);
279+ CPUDenseMomentumFunctor<MT > functor;
280+ functor (params[idx], grads[idx], velocitys[idx], lr_temp,
281+ static_cast <MT>(mu), use_nesterov, regularization_flag ,
282+ regularization_coeff, params_out[idx], velocitys_out[idx]);
273283 VLOG (10 ) << " Launch MergedMomentum cpu kernel." ;
274284 } else if (platform::is_gpu_place (ctx.GetPlace ())) {
275285 platform::ForRange<DeviceContext> for_range (
276286 static_cast <const DeviceContext &>(ctx.device_context ()),
277287 params[idx]->numel ());
278- #define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL (__nesterov, __reg_type ) \
279- DenseMomentumFunctor<T, MPType , __reg_type, __nesterov> functor ( \
280- params[idx]->data <T>(), grads[idx]->data <T>(), \
281- velocitys[idx]->data <MPType >(), lr_temp->data <MPType>(), master_in_data, \
282- mu, rescale_grad, params[idx]-> numel ( ), regularization_coeff, \
283- params_out [idx]->data <T> (), velocitys_out [idx]->data <MPType >(), \
284- master_out_data); \
288+ #define PADDLE_LAUNCH_DENSE_MTMOMENTUM_KERNEL (__nesterov, __reg_type ) \
289+ DenseMomentumFunctor<T, MT , __reg_type, __nesterov> functor ( \
290+ params[idx]->data <T>(), grads[idx]->data <T>(), \
291+ velocitys[idx]->data <MT >(), lr_temp->data <MPType>(), master_in_data, \
292+ static_cast <MT>(mu), static_cast <MT>(rescale_grad ), \
293+ params [idx]->numel (), regularization_coeff, params_out [idx]->data <T >(), \
294+ velocitys_out[idx]-> data <MT>(), master_out_data); \
285295 for_range (functor);
286296 if (use_nesterov) {
287297 if (regularization_flag == RegularizationType::kL2DECAY ) {
0 commit comments