Skip to content

Commit bbe5228

Browse files
authored
Optimize perf of softmax_with_cross_entropy (PaddlePaddle#39553)
* Optimize perf of softmax_with_cross_entropy * fix * fix * fix accuracy error
1 parent 2fedd39 commit bbe5228

File tree

1 file changed

+289
-7
lines changed

1 file changed

+289
-7
lines changed

paddle/fluid/operators/softmax_with_cross_entropy_op.cu

Lines changed: 289 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ namespace cub = hipcub;
2727
namespace paddle {
2828
namespace operators {
2929

30+
#define ALIGN_BYTES 16
31+
3032
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
3133
using DataLayout = platform::DataLayout;
3234
using Tensor = framework::Tensor;
@@ -47,6 +49,18 @@ static __device__ __forceinline__ T Exp(T x) {
4749
return math::TolerableValue<T>()(static_cast<T>(expx));
4850
}
4951

52+
// Binary reduction functor that accumulates shifted exponentials.
//
// operator()(sum, x) returns `sum + exp(x - max)` converted to Ty, where
// `max` is the value captured at construction time.
template <typename Tx, typename Ty = Tx>
struct ExpAddFunctor {
  HOSTDEVICE inline ExpAddFunctor(Tx max) : max_(max) {}

  HOSTDEVICE inline Ty operator()(const Tx& sum, const Tx& x) const {
    // Shift by the stored maximum before exponentiating.
    const auto shifted = x - max_;
    return static_cast<Ty>(sum + std::exp(shifted));
  }

 private:
  Tx max_;  // value subtracted from every operand before exp()
};
63+
5064
// log2(value)
5165
static inline int Log2Ceil(int value) {
5266
int log2_value = 0;
@@ -419,10 +433,272 @@ void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src,
419433
}
420434
}
421435

436+
// Stores the loss of one sample if the calling thread owns the labelled class.
//
// The class index handled by the caller is `vec_size * tid + offset`. Nothing
// is written unless that index equals `label_value`. When it does,
// loss[label_id] receives `loss_value`, except that with IgnoreIndex == true
// and label_value == ignore_index it receives zero instead.
template <typename T, bool IgnoreIndex>
__device__ __forceinline__ void ComputeLoss(T* loss, const T loss_value,
                                            const int label_id,
                                            const int64_t label_value,
                                            const int tid, const int vec_size,
                                            const int offset,
                                            const int ignore_index) {
  const int loss_id = vec_size * tid + offset;
  if (label_value != loss_id) {
    return;  // this thread does not hold the labelled class
  }

  // IgnoreIndex is a template constant, so the test folds away when false.
  const bool ignored = IgnoreIndex && (label_value == ignore_index);
  loss[label_id] = ignored ? static_cast<T>(0.0f) : loss_value;
}
458+
459+
// Per-thread partial reduction over `size` contiguous elements of `input`.
//
// Every thread of the block folds a disjoint subset of the elements into an
// accumulator that starts at `init`, using `reducer(acc, element)`. The
// caller still has to combine the per-thread results.
//
//   input   first element to reduce
//   size    number of elements
//   offset  number of elements by which `input` lies past the boundary the
//           vector loads need; 0 means already aligned. The first
//           (blockDim.x - offset) elements are then consumed one per thread
//           so that the vector loop starts on the boundary.
//   init    initial accumulator value
//   reducer binary functor, called as reducer(acc, element)
//
// NOTE(review): the vector loads assume blockDim.x * sizeof(T) is a multiple
// of sizeof(VecT), otherwise `input + blockDim.x` is not aligned - confirm
// against the launch configuration.
template <typename T, typename AccT, int VecSize, class ReduceFunctor>
__device__ __forceinline__ AccT ThreadReduce(const T* input, int size,
                                             const int offset, AccT init,
                                             ReduceFunctor reducer) {
  using VecT = kps::details::VectorType<T, VecSize>;
  // Signed copy of the block width. blockDim.x is unsigned, and mixing it
  // with `size` would silently switch the arithmetic below to unsigned.
  const int block_dim = static_cast<int>(blockDim.x);
  int tid = threadIdx.x;
  AccT val = init;

  if (offset > 0) {
    // Rewind so that input[offset] is the first real element.
    input -= offset;
    size += offset;
    // The upper bound keeps short rows (size < blockDim.x) inside the buffer.
    if (tid >= offset && tid < size) {
      val = reducer(val, input[tid]);
    }
    size -= block_dim;
    input += block_dim;
    if (size < 0) {
      size = 0;  // the head already consumed everything
    }
  }
  const int remain = size % (VecSize * block_dim);

  T ins[VecSize];
  VecT* ins_vec = reinterpret_cast<VecT*>(&ins);

  // vector part: `tid` indexes VecT-sized packets here
  for (; VecSize * tid < (size - remain); tid += block_dim) {
    *ins_vec = reinterpret_cast<const VecT*>(input)[tid];

#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      val = reducer(val, ins[i]);
    }
  }

  // scalar part: the `remain` trailing elements, one per thread per step
  tid = size - remain + threadIdx.x;
  for (; tid < size; tid += block_dim) {
    val = reducer(val, input[tid]);
  }
  return val;
}
498+
499+
// Writes softmax and the hard-label cross-entropy loss for one row, using
// vectorized loads and stores.
//
//   loss         one value per row; this block writes loss[blockIdx.x]
//   softmax      output row, `size` elements
//   logits       input row, `size` elements
//   label        per-row class index; label[blockIdx.x] is read
//   size         number of classes in the row
//   offset       number of elements by which BOTH `logits` and `softmax` lie
//                past the boundary the vector accesses need (0 = aligned)
//   func         maps a logit to its log-softmax value
//   ignore_index label value whose loss is forced to zero when IgnoreIndex
//
// softmax[i] = exp(func(logits[i])). The loss is -func(logits[label]) and is
// stored by the thread that processes the labelled element. If the label is
// outside [0, size), thread 0 stores a zero loss instead.
//
// NOTE(review): the vector accesses assume blockDim.x * sizeof(T) is a
// multiple of sizeof(VecT) - confirm against the launch configuration.
template <typename T, typename AccT, typename LabelT, int VecSize,
          bool IgnoreIndex>
__device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
    T* loss, T* softmax, const T* logits, const LabelT* label, int size,
    const int offset, const LogSoftmaxForwardFunctor<AccT>& func,
    const int ignore_index) {
  using VecT = kps::details::VectorType<T, VecSize>;
  // Signed copy of the block width. blockDim.x is unsigned, and mixing it
  // with `size` would silently switch the arithmetic below to unsigned.
  const int block_dim = static_cast<int>(blockDim.x);
  int tid = threadIdx.x;
  int label_id = blockIdx.x;
  auto label_value = static_cast<int64_t>(label[label_id]);
  // Evaluated before `size` is adjusted, i.e. against the real row length.
  const bool label_valid = label_value >= 0 && label_value < size;
  // Added to a shifted index to recover the class index within the row.
  int loss_id_offset = 0;

  if (offset > 0) {
    // Rewind so that index `offset` is the first real element.
    logits -= offset;
    softmax -= offset;
    size += offset;
    loss_id_offset -= offset;
    // The upper bound keeps short rows (size < blockDim.x) inside the buffers.
    if (tid >= offset && tid < size) {
      AccT log_softmax = func(static_cast<AccT>(logits[tid]));
      softmax[tid] = static_cast<T>(std::exp(log_softmax));
      // loss
      if (label_valid) {
        ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
                                    label_id, label_value, tid, 1,
                                    loss_id_offset, ignore_index);
      }
    }
    size -= block_dim;
    logits += block_dim;
    softmax += block_dim;
    loss_id_offset += block_dim;
    if (size < 0) {
      size = 0;  // the head already consumed everything
    }
  }
  const int remain = size % (VecSize * block_dim);

  T ins[VecSize];
  T outs[VecSize];
  VecT* ins_vec = reinterpret_cast<VecT*>(&ins);
  VecT* outs_vec = reinterpret_cast<VecT*>(&outs);

  // vector part: `tid` indexes VecT-sized packets here
  for (; VecSize * tid < (size - remain); tid += block_dim) {
    // read
    *ins_vec = reinterpret_cast<const VecT*>(logits)[tid];

    // compute
#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      AccT log_softmax = func(static_cast<AccT>(ins[i]));
      outs[i] = static_cast<T>(std::exp(log_softmax));

      // loss: lane i of packet tid is class VecSize * tid + i + loss_id_offset
      if (label_valid) {
        ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
                                    label_id, label_value, tid, VecSize,
                                    loss_id_offset + i, ignore_index);
      }
    }

    // write
    reinterpret_cast<VecT*>(softmax)[tid] = *outs_vec;
  }

  // scalar part: the `remain` trailing elements, one per thread per step
  tid = size - remain + threadIdx.x;
  for (; tid < size; tid += block_dim) {
    AccT log_softmax = func(static_cast<AccT>(logits[tid]));
    softmax[tid] = static_cast<T>(std::exp(log_softmax));

    // loss
    if (label_valid) {
      ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax), label_id,
                                  label_value, tid, 1, loss_id_offset,
                                  ignore_index);
    }
  }

  // invalid label, write once
  if (!label_valid && threadIdx.x == 0) {
    loss[label_id] = static_cast<T>(0.0f);
  }
}
581+
582+
// Writes softmax and the hard-label cross-entropy loss for one row using
// scalar loads and stores only, so it has no alignment requirement.
//
//   loss         one value per row; this block writes loss[blockIdx.x]
//   softmax      output row, `size` elements
//   logits       input row, `size` elements
//   label        per-row class index; label[blockIdx.x] is read
//   size         number of classes in the row
//   func         maps a logit to its log-softmax value
//   ignore_index label value whose loss is forced to zero when IgnoreIndex
//
// softmax[i] = exp(func(logits[i])). The loss is -func(logits[label]) and is
// stored by the thread that processes the labelled element. If the label is
// outside [0, size), thread 0 stores a zero loss instead.
template <typename T, typename AccT, typename LabelT, int VecSize,
          bool IgnoreIndex>
__device__ __forceinline__ void ScalarSoftmaxForwardImpl(
    T* loss, T* softmax, const T* logits, const LabelT* label, const int size,
    const LogSoftmaxForwardFunctor<AccT>& func, const int ignore_index) {
  // Signed copy of the block width keeps the index arithmetic in `int`.
  const int block_dim = static_cast<int>(blockDim.x);
  int tid = threadIdx.x;
  const int remain = size % (VecSize * block_dim);
  int label_id = blockIdx.x;
  auto label_value = static_cast<int64_t>(label[label_id]);
  const bool label_valid = label_value >= 0 && label_value < size;

  // main part: each step covers VecSize * blockDim.x elements, and slot i of
  // a thread is element tid + i * blockDim.x
  for (; tid < (size - remain); tid += VecSize * block_dim) {
    T ins[VecSize];

#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      ins[i] = logits[tid + i * block_dim];
    }
#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      AccT log_softmax = func(static_cast<AccT>(ins[i]));
      softmax[tid + i * block_dim] = static_cast<T>(std::exp(log_softmax));
      // loss: compare the label with the index that was actually read,
      // tid + i * blockDim.x. Passing (tid, VecSize, i) would test
      // VecSize * tid + i, which is a different class.
      if (label_valid) {
        ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
                                    label_id, label_value, tid, 1,
                                    i * block_dim, ignore_index);
      }
    }
  }

  // tail part: the `remain` trailing elements, one per thread per step
  for (; tid < size; tid += block_dim) {
    AccT log_softmax = func(static_cast<AccT>(logits[tid]));
    softmax[tid] = static_cast<T>(std::exp(log_softmax));
    // loss
    if (label_valid) {
      ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax), label_id,
                                  label_value, tid, 1, 0, ignore_index);
    }
  }

  // invalid label, write once
  if (!label_valid && threadIdx.x == 0) {
    loss[label_id] = static_cast<T>(0.0f);
  }
}
630+
631+
// Fused softmax + hard-label cross-entropy over the last axis.
//
// Launch layout: blockIdx.x selects the row, so each block handles one row of
// `mid_dim` elements. The helpers index with threadIdx.x / blockDim.x only.
//
//   loss         one value per row
//   softmax      output, laid out like `logits`
//   logits       input, rows of `mid_dim` contiguous elements
//   label        one class index per row
//   high_dim     not read by the kernel body
//   mid_dim      row length (number of classes)
//   ignore_index label value whose loss is forced to zero when IgnoreIndex
//
// NOTE(review): BlockXReduce is presumed to return the block-wide result to
// every thread, since all threads use `max` and `sum` afterwards - confirm.
template <typename T, typename AccT, typename LabelT, int VecSize,
          bool IgnoreIndex>
__global__ void VectorizedSoftmaxForward(T* loss, T* softmax, const T* logits,
                                         const LabelT* label,
                                         const int high_dim, const int mid_dim,
                                         const int ignore_index) {
  // Each block deals with one row. The offset is computed in 64 bits because
  // blockIdx.x * mid_dim wraps in 32-bit arithmetic for large inputs.
  const int64_t row_offset = static_cast<int64_t>(blockIdx.x) * mid_dim;
  logits += row_offset;
  softmax += row_offset;

  // Number of elements by which each row pointer lies past an ALIGN_BYTES
  // boundary.
  const int input_offset =
      reinterpret_cast<uint64_t>(logits) % ALIGN_BYTES / sizeof(T);
  const int output_offset =
      reinterpret_cast<uint64_t>(softmax) % ALIGN_BYTES / sizeof(T);

  // 1. reduce max
  AccT max = ThreadReduce<T, AccT, VecSize, kps::MaxFunctor<AccT>>(
      logits, mid_dim, input_offset, -std::numeric_limits<AccT>::infinity(),
      kps::MaxFunctor<AccT>());
  max = kps::details::BlockXReduce<AccT, kps::MaxFunctor<AccT>>(
      max, kps::MaxFunctor<AccT>());

  // 2. reduce sum of exp(x - max)
  AccT sum = ThreadReduce<T, AccT, VecSize, ExpAddFunctor<AccT>>(
      logits, mid_dim, input_offset, static_cast<AccT>(0),
      ExpAddFunctor<AccT>(max));
  sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
      sum, kps::AddFunctor<AccT>());

  // 3. softmax and loss
  LogSoftmaxForwardFunctor<AccT> func(max, sum);
  if (input_offset == output_offset) {
    // Input and output share the same misalignment, so one scalar head can
    // realign both and the rest can use vector accesses.
    VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
        loss, softmax, logits, label, mid_dim, input_offset, func,
        ignore_index);
  } else {
    ScalarSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
        loss, softmax, logits, label, mid_dim, func, ignore_index);
  }
}
671+
672+
// Host-side launcher for VectorizedSoftmaxForward.
//
// Launches `high_dim` blocks on `stream`. The vector width is
// sizeof(float4) / sizeof(T) elements. The block size is the smallest power
// of two that is >= min(mid_dim / vec_size, 1024) (halved when vec_size > 1),
// raised to at least kps::details::kWarpSize.
template <typename T, typename LabelT, bool IgnoreIndex>
void LaunchVectorizedSoftmaxForward(T* loss, T* softmax, const T* logits,
                                    const LabelT* label, const int high_dim,
                                    const int mid_dim, const int ignore_index,
                                    gpuStream_t stream) {
  using AccT = typename details::MPTypeTrait<T>::Type;
  constexpr int kVecSize = sizeof(float4) / sizeof(T);
  const int thread_cap = 1024;

  // Upper bound on the useful number of threads for this row length.
  int block_limit = std::min(mid_dim / kVecSize, thread_cap);
  if (kVecSize > 1) {
    block_limit /= 2;
  }

  // Round up to a power of two, then enforce the warp-size minimum.
  int threads = 1;
  for (; threads < block_limit; threads *= 2) {
  }
  threads = std::max(threads, kps::details::kWarpSize);

  dim3 grid_dim(high_dim);
  dim3 block_dim(threads);
  VectorizedSoftmaxForward<T, AccT, LabelT, kVecSize,
                           IgnoreIndex><<<grid_dim, block_dim, 0, stream>>>(
      loss, softmax, logits, label, high_dim, mid_dim, ignore_index);
}
696+
422697
/*
423698
Wrapper of softmax with cross entropy hard label.
424-
- SwitchWarpSoftmaxForward for small size
425-
- cudnn function for large size
699+
- SwitchWarpSoftmaxForward for small size when axis == -1
700+
- LaunchVectorizedSoftmaxForward for large size when axis == -1
701+
- cudnn function for axis != -1
426702
*/
427703
template <typename T, typename LabelT, bool IgnoreIndex>
428704
static void SoftmaxWithCrossEntropyHardLabel(
@@ -431,11 +707,17 @@ static void SoftmaxWithCrossEntropyHardLabel(
431707
T* softmax_data, int N, int dim, int D, const int ignore_index) {
432708
auto stream = ctx.stream();
433709
constexpr int max_dim = 320;
434-
if (D == 1 && dim <= max_dim) { // small size
435-
const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
436-
SwitchWarpSoftmaxForward<T, LabelT, mode, IgnoreIndex>(
437-
loss_data, softmax_data, logits_data, labels_data, N, dim, dim,
438-
ignore_index, stream);
710+
if (D == 1) {
711+
if (dim <= max_dim) { // small size
712+
const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
713+
SwitchWarpSoftmaxForward<T, LabelT, mode, IgnoreIndex>(
714+
loss_data, softmax_data, logits_data, labels_data, N, dim, dim,
715+
ignore_index, stream);
716+
} else { // large size
717+
LaunchVectorizedSoftmaxForward<T, LabelT, IgnoreIndex>(
718+
loss_data, softmax_data, logits_data, labels_data, N, dim,
719+
ignore_index, stream);
720+
}
439721
} else {
440722
ScopedTensorDescriptor desc;
441723
std::vector<int> tensor_dims = {N, dim, D, 1};

0 commit comments

Comments
 (0)