@@ -193,12 +193,13 @@ template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
+          int NumOuts,
           int VecSize,
           int Rank,
           bool IsBoundary = false>
 __device__ void ElementwiseBroadcastKernelImpl(
     const paddle::framework::Array<const InT *__restrict__, Arity> &ins,
-    OutT *out,
+    paddle::framework::Array<OutT *, NumOuts> outs,
     const paddle::framework::Array<bool, Arity> &use_broadcast,
     uint32_t numel,
     const paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
@@ -207,7 +208,7 @@ __device__ void ElementwiseBroadcastKernelImpl(
     int block_offset,
     Functor func) {
   InT args[Arity][VecSize];
-  OutT result[VecSize];
+  OutType<OutT, NumOuts> result[VecSize];
 
 #pragma unroll
   for (int i = 0; i < Arity; i++) {
@@ -220,28 +221,29 @@ __device__ void ElementwiseBroadcastKernelImpl(
                                              num,
                                              use_broadcast[i]);
   }
-
-  const bool kCallElementwiseAny =
+  constexpr bool kCallElementwiseAny =
       paddle::platform::FunctionTraits<Functor>::has_pointer_args;
   ElementwisePrimitiveCaller<InT,
-                             OutT,
+                             OutType<OutT, NumOuts>,
                              VecSize,
                              Functor,
                              Arity,
                              kCallElementwiseAny>()(func, args, result);
-  kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
-      out + block_offset, result, num);
+
+  ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
+      outs, result, block_offset, num);
 }
 
 template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
+          int NumOuts,
           int VecSize,
           int Rank>
 __global__ void ElementwiseBroadcastKernel(
     paddle::framework::Array<const InT *__restrict__, Arity> ins,
-    OutT *out,
+    paddle::framework::Array<OutT *, NumOuts> outs,
     paddle::framework::Array<bool, Arity> use_broadcast,
     uint32_t numel,
     paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
@@ -251,16 +253,18 @@ __global__ void ElementwiseBroadcastKernel(
     Functor func) {
   int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
   int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
+
 #ifdef PADDLE_WITH_XPU2
   for (; block_offset < main_offset; block_offset += stride) {
     ElementwiseBroadcastKernelImpl<InT,
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    false>(ins,
-                                          out,
+                                          outs,
                                           use_broadcast,
                                           numel,
                                           configs,
@@ -273,22 +277,23 @@ __global__ void ElementwiseBroadcastKernel(
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    true>(
-        ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
+        ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
   }
-
 #else
   if (block_offset < main_offset) {
     ElementwiseBroadcastKernelImpl<InT,
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    false>(ins,
-                                          out,
+                                          outs,
                                           use_broadcast,
                                           numel,
                                           configs,
@@ -300,10 +305,11 @@ __global__ void ElementwiseBroadcastKernel(
                                    OutT,
                                    Functor,
                                    Arity,
+                                   NumOuts,
                                    VecSize,
                                    Rank,
                                    true>(
-        ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
+        ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
   }
 #endif
 }
@@ -312,25 +318,30 @@ template <typename InT,
           typename OutT,
           typename Functor,
           int Arity,
+          int NumOuts,
           int VecSize,
           int Rank>
 void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
                   const std::vector<const DenseTensor *> &ins,
-                  DenseTensor *out,
+                  std::vector<DenseTensor *> *outs,
                   Functor func,
                   DimensionsTransform merge_dims) {
-  int numel = out->numel();
+  int numel = (*outs)[0]->numel();
   const int threads = 256;
   int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
 
   int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
   int tail_tid = numel % (VecSize * threads);
   auto stream = ctx.stream();
-  OutT *out_data = out->mutable_data<OutT>();
 
   paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
   paddle::framework::Array<bool, Arity> use_broadcast;
   paddle::framework::Array<const InT *__restrict__, Arity> ins_data;
+  paddle::framework::Array<OutT *, NumOuts> outs_data;
+
+  for (int i = 0; i < NumOuts; ++i) {
+    outs_data[i] = (*outs)[i]->mutable_data<OutT>();
+  }
 
   for (int i = 0; i < Arity; i++) {
     use_broadcast[i] = (ins[i]->numel() != numel);
@@ -343,6 +354,7 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
           merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
     }
   }
+
 #ifdef PADDLE_WITH_XPU2
   threads = 128;
   blocks = 8;
@@ -352,9 +364,10 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
                              OutT,
                              Functor,
                              Arity,
+                             NumOuts,
                              VecSize,
                              Rank><<<blocks, threads, stream>>>(ins_data,
-                                                                out_data,
+                                                                outs_data,
                                                                 use_broadcast,
                                                                 numel,
                                                                 configs,
@@ -366,10 +379,11 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
                              OutT,
                              Functor,
                              Arity,
+                             NumOuts,
                              VecSize,
                              Rank><<<blocks, threads, 0, stream>>>(
       ins_data,
-      out_data,
+      outs_data,
       use_broadcast,
       numel,
       configs,
@@ -379,19 +393,24 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
 #endif
 }
 
-template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
+template <typename InT,
+          typename OutT,
+          typename Functor,
+          int Arity,
+          int NumOuts,
+          int VecSize>
 void LaunchBroadcastKernelForDifferentVecSize(
     const paddle::platform::CUDADeviceContext &ctx,
     const std::vector<const DenseTensor *> &ins,
-    DenseTensor *out,
+    std::vector<DenseTensor *> *outs,
     int axis,
     Functor func) {
-  const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
+  const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
 
-#define CALL_BROADCAST_FOR_DIM_SIZE(rank)                             \
-  case rank: {                                                        \
-    LaunchKernel<InT, OutT, Functor, Arity, VecSize, rank>(           \
-        ctx, ins, out, func, merge_dims);                             \
+#define CALL_BROADCAST_FOR_DIM_SIZE(rank)                             \
+  case rank: {                                                        \
+    LaunchKernel<InT, OutT, Functor, Arity, NumOuts, VecSize, rank>(  \
+        ctx, ins, outs, func, merge_dims);                            \
   } break;
 
   switch (merge_dims.dim_size) {
@@ -414,7 +433,11 @@ void LaunchBroadcastKernelForDifferentVecSize(
 #undef CALL_BROADCAST_FOR_DIM_SIZE
 }
 
-template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
+template <ElementwiseType ET,
+          typename InT,
+          typename OutT,
+          typename Functor,
+          int NumOuts = 1>
 void LaunchBroadcastElementwiseCudaKernel(
     const paddle::platform::CUDADeviceContext &ctx,
     const std::vector<const DenseTensor *> &ins,
@@ -438,32 +461,68 @@ void LaunchBroadcastElementwiseCudaKernel(
                         "Currently only broadcast of binary is supported and "
                         "verified, but received %d.",
                         kArity));
-
+  PADDLE_ENFORCE_EQ(
+      outs->size(),
+      NumOuts,
+      paddle::platform::errors::InvalidArgument(
+          "Number of outputs shall equal to number of functions, "
+          "but number of outputs is %d, number of functions is %d.",
+          outs->size(),
+          NumOuts));
   int in_vec_size = 4;
-  DenseTensor *out = (*outs)[0];
+  int out_vec_size = 4;
+  if (NumOuts > 1) {
+    for (int i = 0; i < NumOuts; ++i) {
+      PADDLE_ENFORCE_EQ(
+          (*outs)[i]->dims(),
+          (*outs)[0]->dims(),
+          paddle::platform::errors::InvalidArgument(
+              "The shape of each output tensor shall be identical yet, but "
+              "%dth output tensor`s shape is not.",
+              i));
+      out_vec_size = std::min(
+          paddle::platform::GetVectorizedSize<OutT>((*outs)[i]->data<OutT>()),
+          out_vec_size);
+    }
+  } else {
+    out_vec_size =
+        paddle::platform::GetVectorizedSize<OutT>((*outs)[0]->data<OutT>());
+  }
+
   for (auto *in : ins) {
     auto temp_size = paddle::platform::GetVectorizedSize<InT>(in->data<InT>());
-    in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
-                                            : in_vec_size;
+    in_vec_size = in->dims() == (*outs)[0]->dims()
+                      ? std::min(temp_size, in_vec_size)
+                      : in_vec_size;
   }
-  int out_vec_size =
-      paddle::platform::GetVectorizedSize<OutT>(out->data<OutT>());
   int vec_size = std::min(out_vec_size, in_vec_size);
 
   switch (vec_size) {
     case 4: {
-      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 4>(
-          ctx, ins, out, axis, func);
+      LaunchBroadcastKernelForDifferentVecSize<InT,
+                                               OutT,
+                                               Functor,
+                                               kArity,
+                                               NumOuts,
+                                               4>(ctx, ins, outs, axis, func);
       break;
     }
     case 2: {
-      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 2>(
-          ctx, ins, out, axis, func);
+      LaunchBroadcastKernelForDifferentVecSize<InT,
+                                               OutT,
+                                               Functor,
+                                               kArity,
+                                               NumOuts,
+                                               2>(ctx, ins, outs, axis, func);
      break;
     }
     case 1: {
-      LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 1>(
-          ctx, ins, out, axis, func);
+      LaunchBroadcastKernelForDifferentVecSize<InT,
+                                               OutT,
+                                               Functor,
+                                               kArity,
+                                               NumOuts,
+                                               1>(ctx, ins, outs, axis, func);
       break;
     }
     default: {
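
The hunks above reference two helpers, OutType<OutT, NumOuts> and ElementwiseWriteDataCaller, without showing their definitions. The sketch below is a non-authoritative reading of what they could plausibly look like under this design, assuming a conditional alias that keeps a scalar OutT when NumOuts == 1 and packs one value per output otherwise, plus a writer that unzips the packed results and issues one kps::WriteData per output. The functor FakeAddSubFunctor and the commented launch at the end are purely illustrative and are not part of this commit.

// Assumed result type for `result[VecSize]` (requires <type_traits>):
// scalar OutT for a single output, an Array of OutT for multiple outputs.
template <typename OutT, int NumOuts>
using OutType = typename std::conditional<
    NumOuts == 1,
    OutT,
    paddle::framework::Array<OutT, NumOuts>>::type;

// Assumed multi-output writer. A NumOuts == 1 specialization would simply
// forward `src` to kps::WriteData, matching the call this commit removes.
template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
struct ElementwiseWriteDataCaller {
  __device__ __forceinline__ void operator()(
      paddle::framework::Array<OutT *, NumOuts> outs,
      OutType<OutT, NumOuts> src[VecSize],
      int block_offset,
      int num) {
    OutT dst[NumOuts][VecSize];
#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
#pragma unroll
      for (int j = 0; j < NumOuts; ++j) {
        dst[j][i] = (src[i])[j];  // unzip: i-th element of the j-th output
      }
    }
#pragma unroll
    for (int j = 0; j < NumOuts; ++j) {
      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
          outs[j] + block_offset, dst[j], num);
    }
  }
};

// Illustrative (hypothetical) caller: a binary functor producing two outputs,
// launched with NumOuts = 2 so that outs->size() matches the check added in
// LaunchBroadcastElementwiseCudaKernel.
template <typename T>
struct FakeAddSubFunctor {
  HOSTDEVICE paddle::framework::Array<T, 2> operator()(const T &a,
                                                       const T &b) const {
    paddle::framework::Array<T, 2> result;
    result[0] = a + b;
    result[1] = a - b;
    return result;
  }
};
// std::vector<DenseTensor *> outs = {&sum_tensor, &diff_tensor};
// LaunchBroadcastElementwiseCudaKernel<ElementwiseType::kBinary, T, T,
//                                      FakeAddSubFunctor<T>, 2>(
//     ctx, ins, &outs, axis, FakeAddSubFunctor<T>());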