Skip to content

Commit 89d38f5

Browse files
authored
Support multi-outputs feature for broadcast ops (#38329)
* No harm to KP * Pass the compile stage * change the WriteData function * fix template bugs and pass ctest of current elementwise * for passing partial template specialization of template function in CI-ROCm * To make 'WriteData' function flexible. * a less harmful way to support multi-output * a less harmful way to support multi-output
1 parent f1d56b7 commit 89d38f5

File tree

3 files changed

+142
-42
lines changed

3 files changed

+142
-42
lines changed

paddle/fluid/operators/kernel_primitives/datamover_primitives.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,8 @@ __device__ __forceinline__ void ReadData(T* dst, const T* __restrict__ src,
254254
}
255255
}
256256
} else { // blockDim,x * NX < num
257-
const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
258-
const int kVectorsPerThread = NX / kVectorSize;
257+
constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
258+
constexpr int kVectorsPerThread = NX / kVectorSize;
259259
int thread_offset = threadIdx.x * kVectorsPerThread;
260260

261261
using VecType = details::VectorType<T, kVectorSize>;
@@ -441,8 +441,8 @@ __device__ __forceinline__ void WriteData(T* dst, T* __restrict__ src,
441441
}
442442
} else {
443443
// Vector type
444-
const int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
445-
const int kVectorsPerThread = NX / kVectorSize;
444+
constexpr int kVectorSize = (NX % 4 == 0) ? 4 : (NX % 2 == 0) ? 2 : 1;
445+
constexpr int kVectorsPerThread = NX / kVectorSize;
446446

447447
int thread_offset = threadIdx.x * kVectorsPerThread;
448448
using VecType = details::VectorType<T, kVectorSize>;

paddle/pten/kernels/hybird/cuda/elementwise/elementwise_broadcast.cu.h

Lines changed: 97 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,13 @@ template <typename InT,
193193
typename OutT,
194194
typename Functor,
195195
int Arity,
196+
int NumOuts,
196197
int VecSize,
197198
int Rank,
198199
bool IsBoundary = false>
199200
__device__ void ElementwiseBroadcastKernelImpl(
200201
const paddle::framework::Array<const InT *__restrict__, Arity> &ins,
201-
OutT *out,
202+
paddle::framework::Array<OutT *, NumOuts> outs,
202203
const paddle::framework::Array<bool, Arity> &use_broadcast,
203204
uint32_t numel,
204205
const paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
@@ -207,7 +208,7 @@ __device__ void ElementwiseBroadcastKernelImpl(
207208
int block_offset,
208209
Functor func) {
209210
InT args[Arity][VecSize];
210-
OutT result[VecSize];
211+
OutType<OutT, NumOuts> result[VecSize];
211212

212213
#pragma unroll
213214
for (int i = 0; i < Arity; i++) {
@@ -220,28 +221,29 @@ __device__ void ElementwiseBroadcastKernelImpl(
220221
num,
221222
use_broadcast[i]);
222223
}
223-
224-
const bool kCallElementwiseAny =
224+
constexpr bool kCallElementwiseAny =
225225
paddle::platform::FunctionTraits<Functor>::has_pointer_args;
226226
ElementwisePrimitiveCaller<InT,
227-
OutT,
227+
OutType<OutT, NumOuts>,
228228
VecSize,
229229
Functor,
230230
Arity,
231231
kCallElementwiseAny>()(func, args, result);
232-
kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
233-
out + block_offset, result, num);
232+
233+
ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, NumOuts>()(
234+
outs, result, block_offset, num);
234235
}
235236

236237
template <typename InT,
237238
typename OutT,
238239
typename Functor,
239240
int Arity,
241+
int NumOuts,
240242
int VecSize,
241243
int Rank>
242244
__global__ void ElementwiseBroadcastKernel(
243245
paddle::framework::Array<const InT *__restrict__, Arity> ins,
244-
OutT *out,
246+
paddle::framework::Array<OutT *, NumOuts> outs,
245247
paddle::framework::Array<bool, Arity> use_broadcast,
246248
uint32_t numel,
247249
paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity>
@@ -251,16 +253,18 @@ __global__ void ElementwiseBroadcastKernel(
251253
Functor func) {
252254
int block_offset = BLOCK_ID_X * BLOCK_NUM_X * VecSize;
253255
int stride = BLOCK_NUM_X * GRID_NUM_X * VecSize;
256+
254257
#ifdef PADDLE_WITH_XPU2
255258
for (; block_offset < main_offset; block_offset += stride) {
256259
ElementwiseBroadcastKernelImpl<InT,
257260
OutT,
258261
Functor,
259262
Arity,
263+
NumOuts,
260264
VecSize,
261265
Rank,
262266
false>(ins,
263-
out,
267+
outs,
264268
use_broadcast,
265269
numel,
266270
configs,
@@ -273,22 +277,23 @@ __global__ void ElementwiseBroadcastKernel(
273277
OutT,
274278
Functor,
275279
Arity,
280+
NumOuts,
276281
VecSize,
277282
Rank,
278283
true>(
279-
ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
284+
ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
280285
}
281-
282286
#else
283287
if (block_offset < main_offset) {
284288
ElementwiseBroadcastKernelImpl<InT,
285289
OutT,
286290
Functor,
287291
Arity,
292+
NumOuts,
288293
VecSize,
289294
Rank,
290295
false>(ins,
291-
out,
296+
outs,
292297
use_broadcast,
293298
numel,
294299
configs,
@@ -300,10 +305,11 @@ __global__ void ElementwiseBroadcastKernel(
300305
OutT,
301306
Functor,
302307
Arity,
308+
NumOuts,
303309
VecSize,
304310
Rank,
305311
true>(
306-
ins, out, use_broadcast, numel, configs, tail_tid, block_offset, func);
312+
ins, outs, use_broadcast, numel, configs, tail_tid, block_offset, func);
307313
}
308314
#endif
309315
}
@@ -312,25 +318,30 @@ template <typename InT,
312318
typename OutT,
313319
typename Functor,
314320
int Arity,
321+
int NumOuts,
315322
int VecSize,
316323
int Rank>
317324
void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
318325
const std::vector<const DenseTensor *> &ins,
319-
DenseTensor *out,
326+
std::vector<DenseTensor *> *outs,
320327
Functor func,
321328
DimensionsTransform merge_dims) {
322-
int numel = out->numel();
329+
int numel = (*outs)[0]->numel();
323330
const int threads = 256;
324331
int blocks = ((numel + VecSize - 1) / VecSize + threads - 1) / threads;
325332

326333
int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
327334
int tail_tid = numel % (VecSize * threads);
328335
auto stream = ctx.stream();
329-
OutT *out_data = out->mutable_data<OutT>();
330336

331337
paddle::framework::Array<kps::details::BroadcastConfig<Rank>, Arity> configs;
332338
paddle::framework::Array<bool, Arity> use_broadcast;
333339
paddle::framework::Array<const InT *__restrict__, Arity> ins_data;
340+
paddle::framework::Array<OutT *, NumOuts> outs_data;
341+
342+
for (int i = 0; i < NumOuts; ++i) {
343+
outs_data[i] = (*outs)[i]->mutable_data<OutT>();
344+
}
334345

335346
for (int i = 0; i < Arity; i++) {
336347
use_broadcast[i] = (ins[i]->numel() != numel);
@@ -343,6 +354,7 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
343354
merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
344355
}
345356
}
357+
346358
#ifdef PADDLE_WITH_XPU2
347359
threads = 128;
348360
blocks = 8;
@@ -352,9 +364,10 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
352364
OutT,
353365
Functor,
354366
Arity,
367+
NumOuts,
355368
VecSize,
356369
Rank><<<blocks, threads, stream>>>(ins_data,
357-
out_data,
370+
outs_data,
358371
use_broadcast,
359372
numel,
360373
configs,
@@ -366,10 +379,11 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
366379
OutT,
367380
Functor,
368381
Arity,
382+
NumOuts,
369383
VecSize,
370384
Rank><<<blocks, threads, 0, stream>>>(
371385
ins_data,
372-
out_data,
386+
outs_data,
373387
use_broadcast,
374388
numel,
375389
configs,
@@ -379,19 +393,24 @@ void LaunchKernel(const paddle::platform::CUDADeviceContext &ctx,
379393
#endif
380394
}
381395

382-
template <typename InT, typename OutT, typename Functor, int Arity, int VecSize>
396+
template <typename InT,
397+
typename OutT,
398+
typename Functor,
399+
int Arity,
400+
int NumOuts,
401+
int VecSize>
383402
void LaunchBroadcastKernelForDifferentVecSize(
384403
const paddle::platform::CUDADeviceContext &ctx,
385404
const std::vector<const DenseTensor *> &ins,
386-
DenseTensor *out,
405+
std::vector<DenseTensor *> *outs,
387406
int axis,
388407
Functor func) {
389-
const auto merge_dims = DimensionsTransform(ins, out->dims(), axis);
408+
const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
390409

391-
#define CALL_BROADCAST_FOR_DIM_SIZE(rank) \
392-
case rank: { \
393-
LaunchKernel<InT, OutT, Functor, Arity, VecSize, rank>( \
394-
ctx, ins, out, func, merge_dims); \
410+
#define CALL_BROADCAST_FOR_DIM_SIZE(rank) \
411+
case rank: { \
412+
LaunchKernel<InT, OutT, Functor, Arity, NumOuts, VecSize, rank>( \
413+
ctx, ins, outs, func, merge_dims); \
395414
} break;
396415

397416
switch (merge_dims.dim_size) {
@@ -414,7 +433,11 @@ void LaunchBroadcastKernelForDifferentVecSize(
414433
#undef CALL_BROADCAST_FOR_DIM_SIZE
415434
}
416435

417-
template <ElementwiseType ET, typename InT, typename OutT, typename Functor>
436+
template <ElementwiseType ET,
437+
typename InT,
438+
typename OutT,
439+
typename Functor,
440+
int NumOuts = 1>
418441
void LaunchBroadcastElementwiseCudaKernel(
419442
const paddle::platform::CUDADeviceContext &ctx,
420443
const std::vector<const DenseTensor *> &ins,
@@ -438,32 +461,68 @@ void LaunchBroadcastElementwiseCudaKernel(
438461
"Currently only broadcast of binary is supported and "
439462
"verified, but received %d.",
440463
kArity));
441-
464+
PADDLE_ENFORCE_EQ(
465+
outs->size(),
466+
NumOuts,
467+
paddle::platform::errors::InvalidArgument(
468+
"Number of outputs shall equal to number of functions, "
469+
"but number of outputs is %d, number of functions is %d.",
470+
outs->size(),
471+
NumOuts));
442472
int in_vec_size = 4;
443-
DenseTensor *out = (*outs)[0];
473+
int out_vec_size = 4;
474+
if (NumOuts > 1) {
475+
for (int i = 0; i < NumOuts; ++i) {
476+
PADDLE_ENFORCE_EQ(
477+
(*outs)[i]->dims(),
478+
(*outs)[0]->dims(),
479+
paddle::platform::errors::InvalidArgument(
480+
"The shape of each output tensor shall be identical yet, but "
481+
"%dth output tensor`s shape is not.",
482+
i));
483+
out_vec_size = std::min(
484+
paddle::platform::GetVectorizedSize<OutT>((*outs)[i]->data<OutT>()),
485+
out_vec_size);
486+
}
487+
} else {
488+
out_vec_size =
489+
paddle::platform::GetVectorizedSize<OutT>((*outs)[0]->data<OutT>());
490+
}
491+
444492
for (auto *in : ins) {
445493
auto temp_size = paddle::platform::GetVectorizedSize<InT>(in->data<InT>());
446-
in_vec_size = in->dims() == out->dims() ? std::min(temp_size, in_vec_size)
447-
: in_vec_size;
494+
in_vec_size = in->dims() == (*outs)[0]->dims()
495+
? std::min(temp_size, in_vec_size)
496+
: in_vec_size;
448497
}
449-
int out_vec_size =
450-
paddle::platform::GetVectorizedSize<OutT>(out->data<OutT>());
451498
int vec_size = std::min(out_vec_size, in_vec_size);
452499

453500
switch (vec_size) {
454501
case 4: {
455-
LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 4>(
456-
ctx, ins, out, axis, func);
502+
LaunchBroadcastKernelForDifferentVecSize<InT,
503+
OutT,
504+
Functor,
505+
kArity,
506+
NumOuts,
507+
4>(ctx, ins, outs, axis, func);
457508
break;
458509
}
459510
case 2: {
460-
LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 2>(
461-
ctx, ins, out, axis, func);
511+
LaunchBroadcastKernelForDifferentVecSize<InT,
512+
OutT,
513+
Functor,
514+
kArity,
515+
NumOuts,
516+
2>(ctx, ins, outs, axis, func);
462517
break;
463518
}
464519
case 1: {
465-
LaunchBroadcastKernelForDifferentVecSize<InT, OutT, Functor, kArity, 1>(
466-
ctx, ins, out, axis, func);
520+
LaunchBroadcastKernelForDifferentVecSize<InT,
521+
OutT,
522+
Functor,
523+
kArity,
524+
NumOuts,
525+
1>(ctx, ins, outs, axis, func);
467526
break;
468527
}
469528
default: {

paddle/pten/kernels/hybird/cuda/elementwise/elementwise_common.cu.h

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@ namespace pten {
2424
namespace kps = paddle::operators::kernel_primitives;
2525
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
2626

27+
/* Packs a scalar type T (float, int, etc.) into Array<T, Num> when a
   kernel produces multiple outputs, and collapses back to plain T for the
   single-output (Num == 1) case. Used by the multi-output elementwise
   system to keep the single-output fast path free of array wrapping. */
template <class T, int Num>
using OutType =
    std::conditional_t<Num == 1, T, paddle::framework::Array<T, Num>>;
32+
2733
template <typename InT,
2834
typename OutT,
2935
int VecSize,
@@ -76,4 +82,39 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 3, false> {
7682
}
7783
};
7884

85+
// Writes the packed per-lane results of a multi-output elementwise kernel
// into NumOuts separate output buffers.
//
// src holds VecSize elements, each an Array<OutT, NumOuts> (one value per
// output); it is first transposed into per-output contiguous vectors so
// each buffer can be written with a single kps::WriteData call.
// IsBoundary selects the bounds-checked write path; num is the number of
// valid elements for this (possibly partial) tile.
template <typename OutT, int VecSize, bool IsBoundary, int NumOuts>
struct ElementwiseWriteDataCaller {
  __device__ __forceinline__ void operator()(
      paddle::framework::Array<OutT *, NumOuts> outs,
      OutType<OutT, NumOuts> src[VecSize],
      int block_offset,
      int num) {
    // Transposed staging: dst[out_idx] is a contiguous VecSize-wide vector
    // destined for output out_idx.
    OutT dst[NumOuts][VecSize];
#pragma unroll
    for (int out_idx = 0; out_idx < NumOuts; ++out_idx) {
#pragma unroll
      for (int lane = 0; lane < VecSize; ++lane) {
        dst[out_idx][lane] = (src[lane])[out_idx];
      }
    }
    // One vectorized store per output buffer.
#pragma unroll
    for (int out_idx = 0; out_idx < NumOuts; ++out_idx) {
      kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(
          outs[out_idx] + block_offset, dst[out_idx], num);
    }
  }
};
107+
108+
// Single-output specialization: src is already a plain OutT vector (OutType
// collapses to OutT when NumOuts == 1), so no transpose staging is needed —
// the results are forwarded directly to the sole output buffer.
template <typename OutT, int VecSize, bool IsBoundary>
struct ElementwiseWriteDataCaller<OutT, VecSize, IsBoundary, 1> {
  __device__ __forceinline__ void operator()(
      paddle::framework::Array<OutT *, 1> outs,
      OutT src[VecSize],
      int block_offset,
      int num) {
    OutT *dst = outs[0] + block_offset;
    kps::WriteData<OutT, VecSize, 1, 1, IsBoundary>(dst, src, num);
  }
};
119+
79120
} // namespace pten

0 commit comments

Comments
 (0)