@@ -27,6 +27,8 @@ namespace cub = hipcub;
 namespace paddle {
 namespace operators {
 
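+// Assumed intent: 16 bytes matches the float4-width vectorized loads used
+// below (vec_size = sizeof(float4) / sizeof(T)).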
+#define ALIGN_BYTES 16
+
 using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
 using DataLayout = platform::DataLayout;
 using Tensor = framework::Tensor;
@@ -47,6 +49,18 @@ static __device__ __forceinline__ T Exp(T x) {
   return math::TolerableValue<T>()(static_cast<T>(expx));
 }
 
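+// Reduction functor that accumulates exp(x - max) into a running sum; with
+// `max` being the row maximum, this builds the softmax denominator.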
+template <typename Tx, typename Ty = Tx>
+struct ExpAddFunctor {
+  HOSTDEVICE inline ExpAddFunctor(Tx max) : max(max) {}
+
+  HOSTDEVICE inline Ty operator()(const Tx& sum, const Tx& x) const {
+    return static_cast<Ty>(sum + std::exp(x - max));
+  }
+
+ private:
+  Tx max;
+};
+
 // log2(value)
 static inline int Log2Ceil(int value) {
   int log2_value = 0;
@@ -419,10 +433,272 @@ void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src,
   }
 }
 
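+// Writes loss[label_id] once the thread/offset pair that maps to the labeled
+// class is reached; with IgnoreIndex, labels equal to ignore_index produce a
+// zero loss instead.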
+template <typename T, bool IgnoreIndex>
+__device__ __forceinline__ void ComputeLoss(T* loss, const T loss_value,
+                                            const int label_id,
+                                            const int64_t label_value,
+                                            const int tid, const int vec_size,
+                                            const int offset,
+                                            const int ignore_index) {
+  int loss_id = vec_size * tid + offset;
+  if (IgnoreIndex) {
+    if (label_value == loss_id) {
+      if (label_value == ignore_index) {
+        loss[label_id] = static_cast<T>(0.0f);
+      } else {
+        loss[label_id] = loss_value;
+      }
+    }
+  } else {
+    if (label_value == loss_id) {
+      loss[label_id] = loss_value;
+    }
+  }
+}
+
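+// Per-thread reduction over one row: a scalar prologue consumes the
+// misaligned head (`offset` elements), the main loop uses VecSize-wide
+// vector loads, and a scalar tail handles the remainder.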
+template <typename T, typename AccT, int VecSize, class ReduceFunctor>
+__device__ __forceinline__ AccT ThreadReduce(const T* input, int size,
+                                             const int offset, AccT init,
+                                             ReduceFunctor reducer) {
+  using VecT = kps::details::VectorType<T, VecSize>;
+  int tid = threadIdx.x;
+  AccT val = init;
+
+  if (offset > 0) {
+    input -= offset;
+    size += offset;
+    if (tid >= offset) {
+      val = reducer(val, input[tid]);
+    }
+    size -= blockDim.x;
+    input += blockDim.x;
+  }
+  int remain = size % (VecSize * blockDim.x);
+
+  T ins[VecSize];
+  VecT* ins_vec = reinterpret_cast<VecT*>(&ins);
+
+  // vector part
+  for (; VecSize * tid < (size - remain); tid += blockDim.x) {
+    *ins_vec = reinterpret_cast<const VecT*>(input)[tid];
+
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+      val = reducer(val, ins[i]);
+    }
+  }
+
+  // scalar part
+  tid = size - remain + threadIdx.x;
+  for (; tid < size; tid += blockDim.x) {
+    val = reducer(val, input[tid]);
+  }
+  return val;
+}
+
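+// Second pass over one row when input and output share the same alignment:
+// applies the log-softmax functor with vectorized loads/stores, writes the
+// softmax output, and emits the hard-label loss via ComputeLoss.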
+template <typename T, typename AccT, typename LabelT, int VecSize,
+          bool IgnoreIndex>
+__device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
+    T* loss, T* softmax, const T* logits, const LabelT* label, int size,
+    const int offset, const LogSoftmaxForwardFunctor<AccT>& func,
+    const int ignore_index) {
+  using VecT = kps::details::VectorType<T, VecSize>;
+  int tid = threadIdx.x;
+  int label_id = blockIdx.x;
+  auto label_value = static_cast<int64_t>(label[label_id]);
+  const bool label_valid = label_value >= 0 && label_value < size;
+  int loss_id_offset = 0;
+
+  if (offset > 0) {
+    logits -= offset;
+    softmax -= offset;
+    size += offset;
+    loss_id_offset -= offset;
+    if (tid >= offset) {
+      AccT log_softmax = func(static_cast<AccT>(logits[tid]));
+      softmax[tid] = static_cast<T>(std::exp(log_softmax));
+      // loss
+      if (label_valid) {
+        ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
+                                    label_id, label_value, tid, 1,
+                                    loss_id_offset, ignore_index);
+      }
+    }
+    size -= blockDim.x;
+    logits += blockDim.x;
+    softmax += blockDim.x;
+    loss_id_offset += blockDim.x;
+  }
+  int remain = size % (VecSize * blockDim.x);
+
+  T ins[VecSize];
+  T outs[VecSize];
+  VecT* ins_vec = reinterpret_cast<VecT*>(&ins);
+  VecT* outs_vec = reinterpret_cast<VecT*>(&outs);
+
+  // vector part
+  for (; VecSize * tid < (size - remain); tid += blockDim.x) {
+    // read
+    *ins_vec = reinterpret_cast<const VecT*>(logits)[tid];
+
+    // compute
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+      AccT log_softmax = func(static_cast<AccT>(ins[i]));
+      outs[i] = static_cast<T>(std::exp(log_softmax));
+
+      // loss
+      if (label_valid) {
+        ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
+                                    label_id, label_value, tid, VecSize,
+                                    loss_id_offset + i, ignore_index);
+      }
+    }
+
+    // write
+    reinterpret_cast<VecT*>(softmax)[tid] = *outs_vec;
+  }
+
+  // scalar part
+  tid = size - remain + threadIdx.x;
+  for (; tid < size; tid += blockDim.x) {
+    AccT log_softmax = func(static_cast<AccT>(logits[tid]));
+    softmax[tid] = static_cast<T>(std::exp(log_softmax));
+
+    // loss
+    if (label_valid) {
+      ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax), label_id,
+                                  label_value, tid, 1, loss_id_offset,
+                                  ignore_index);
+    }
+  }
+
+  // invalid label, write once
+  if (!label_valid && threadIdx.x == 0) {
+    loss[label_id] = static_cast<T>(0.0f);
+  }
+}
+
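+// Fallback second pass when the input and output rows are misaligned
+// relative to each other: same computation, but with scalar strided accesses.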
+template <typename T, typename AccT, typename LabelT, int VecSize,
+          bool IgnoreIndex>
+__device__ __forceinline__ void ScalarSoftmaxForwardImpl(
+    T* loss, T* softmax, const T* logits, const LabelT* label, const int size,
+    const LogSoftmaxForwardFunctor<AccT>& func, const int ignore_index) {
+  int tid = threadIdx.x;
+  int remain = size % (VecSize * blockDim.x);
+  int label_id = blockIdx.x;
+  auto label_value = static_cast<int64_t>(label[label_id]);
+  const bool label_valid = label_value >= 0 && label_value < size;
+
+  // main part
+  for (; tid < (size - remain); tid += VecSize * blockDim.x) {
+    T ins[VecSize];
+
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+      ins[i] = logits[tid + i * blockDim.x];
+    }
+#pragma unroll
+    for (int i = 0; i < VecSize; ++i) {
+      AccT log_softmax = func(static_cast<AccT>(ins[i]));
+      softmax[tid + i * blockDim.x] = static_cast<T>(std::exp(log_softmax));
+      // loss: the element handled here is tid + i * blockDim.x, so pass that
+      // index directly (vec_size = 1, offset = 0) so that ComputeLoss's
+      // loss_id = vec_size * tid + offset names the right element.
+      if (label_valid) {
+        ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
+                                    label_id, label_value,
+                                    tid + i * blockDim.x, 1, 0, ignore_index);
+      }
+    }
+  }
+
+  // tail part
+  for (; tid < size; tid += blockDim.x) {
+    AccT log_softmax = func(static_cast<AccT>(logits[tid]));
+    softmax[tid] = static_cast<T>(std::exp(log_softmax));
+    // loss
+    if (label_valid) {
+      ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax), label_id,
+                                  label_value, tid, 1, 0, ignore_index);
+    }
+  }
+
+  // invalid label, write once
+  if (!label_valid && threadIdx.x == 0) {
+    loss[label_id] = static_cast<T>(0.0f);
+  }
+}
+
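+// One block per row: reduce the row max, reduce the sum of exp(x - max),
+// then run the log-softmax/loss pass, picking the vectorized or scalar
+// implementation depending on whether input and output alignments agree.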
+template <typename T, typename AccT, typename LabelT, int VecSize,
+          bool IgnoreIndex>
+__global__ void VectorizedSoftmaxForward(T* loss, T* softmax, const T* logits,
+                                         const LabelT* label,
+                                         const int high_dim, const int mid_dim,
+                                         const int ignore_index) {
+  using VecT = kps::details::VectorType<T, VecSize>;
+
+  // each block deals with one row of the batch
+  logits += blockIdx.x * mid_dim;
+  softmax += blockIdx.x * mid_dim;
+
+  const int input_offset = ((uint64_t)logits) % ALIGN_BYTES / sizeof(T);
+  const int output_offset = ((uint64_t)softmax) % ALIGN_BYTES / sizeof(T);
+
+  // 1. reduce max
+  AccT max = ThreadReduce<T, AccT, VecSize, kps::MaxFunctor<AccT>>(
+      logits, mid_dim, input_offset, -std::numeric_limits<AccT>::infinity(),
+      kps::MaxFunctor<AccT>());
+  max = kps::details::BlockXReduce<AccT, kps::MaxFunctor<AccT>>(
+      max, kps::MaxFunctor<AccT>());
+
+  // 2. reduce sum
+  AccT sum = ThreadReduce<T, AccT, VecSize, ExpAddFunctor<AccT>>(
+      logits, mid_dim, input_offset, static_cast<AccT>(0),
+      ExpAddFunctor<AccT>(max));
+  sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
+      sum, kps::AddFunctor<AccT>());
+
+  // 3. softmax
+  LogSoftmaxForwardFunctor<AccT> func(max, sum);
+  if (input_offset == output_offset) {
+    VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
+        loss, softmax, logits, label, mid_dim, input_offset, func,
+        ignore_index);
+  } else {
+    ScalarSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
+        loss, softmax, logits, label, mid_dim, func, ignore_index);
+  }
+}
+
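+// Launch helper: rounds the block size up to a power of two, capped at 1024
+// threads (halved when vector loads are in play) and floored at one warp.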
+template <typename T, typename LabelT, bool IgnoreIndex>
+void LaunchVectorizedSoftmaxForward(T* loss, T* softmax, const T* logits,
+                                    const LabelT* label, const int high_dim,
+                                    const int mid_dim, const int ignore_index,
+                                    gpuStream_t stream) {
+  using AccT = typename details::MPTypeTrait<T>::Type;
+  constexpr int vec_size = sizeof(float4) / sizeof(T);
+  const int max_num_threads = 1024;
+  int max_block_size = std::min(mid_dim / vec_size, max_num_threads);
+  if (vec_size > 1) {
+    max_block_size /= 2;
+  }
+
+  int block_size = 1;
+  while (block_size < max_block_size) {
+    block_size *= 2;
+  }
+  block_size = std::max(block_size, kps::details::kWarpSize);
+  dim3 grids(high_dim);
+  dim3 blocks(block_size);
+  VectorizedSoftmaxForward<T, AccT, LabelT, vec_size,
+                           IgnoreIndex><<<grids, blocks, 0, stream>>>(
+      loss, softmax, logits, label, high_dim, mid_dim, ignore_index);
+}
+
 /*
   Wrapper of softmax with cross entropy hard label.
-  - SwitchWarpSoftmaxForward for small size
-  - cudnn function for large size
+  - SwitchWarpSoftmaxForward for small size when axis == -1
+  - LaunchVectorizedSoftmaxForward for large size when axis == -1
+  - cudnn function for axis != -1
 */
 template <typename T, typename LabelT, bool IgnoreIndex>
 static void SoftmaxWithCrossEntropyHardLabel(
@@ -431,11 +707,17 @@ static void SoftmaxWithCrossEntropyHardLabel(
     T* softmax_data, int N, int dim, int D, const int ignore_index) {
   auto stream = ctx.stream();
   constexpr int max_dim = 320;
-  if (D == 1 && dim <= max_dim) {  // small size
-    const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
-    SwitchWarpSoftmaxForward<T, LabelT, mode, IgnoreIndex>(
-        loss_data, softmax_data, logits_data, labels_data, N, dim, dim,
-        ignore_index, stream);
+  if (D == 1) {
+    if (dim <= max_dim) {  // small size
+      const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
+      SwitchWarpSoftmaxForward<T, LabelT, mode, IgnoreIndex>(
+          loss_data, softmax_data, logits_data, labels_data, N, dim, dim,
+          ignore_index, stream);
+    } else {  // large size
+      LaunchVectorizedSoftmaxForward<T, LabelT, IgnoreIndex>(
+          loss_data, softmax_data, logits_data, labels_data, N, dim,
+          ignore_index, stream);
+    }
   } else {
     ScopedTensorDescriptor desc;
     std::vector<int> tensor_dims = {N, dim, D, 1};