@@ -516,31 +516,31 @@ __global__ void cuApplyRMSNorm(T* __restrict__ output_vals,
 }
 
 template <typename T, typename U>
-__device__ void cuLoadWriteStridedInputs(const int i1_block,
-                                         const int thr_load_row_off,
-                                         const int thr_load_col_off,
-                                         const int i2_off,
-                                         const int row_stride,
+__device__ void cuLoadWriteStridedInputs(const int64_t i1_block,
+                                         const int64_t thr_load_row_off,
+                                         const int64_t thr_load_col_off,
+                                         const int64_t i2_off,
+                                         const int64_t row_stride,
                                          U* warp_buf1,
                                          U* warp_buf2,
                                          const T* input,
                                          const T* dout,
-                                         const int i1_end,
-                                         const int n2,
+                                         const int64_t i1_end,
+                                         const int64_t n2,
                                          const U* __restrict__ mean,
                                          const U* __restrict__ invvar,
                                          bool rms_only) {
-  int i1 = i1_block + thr_load_row_off;
+  int64_t i1 = i1_block + thr_load_row_off;
   if (i1 < i1_end) {
     U curr_mean;
     if (!rms_only) {
       curr_mean = mean[i1];
     }
     U curr_invvar = invvar[i1];
-    for (int k = 0; k < blockDim.y; ++k) {
-      int i2 = i2_off + k;
-      int load_idx = i1 * n2 + i2;
-      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
+    for (int64_t k = 0; k < blockDim.y; ++k) {
+      int64_t i2 = i2_off + k;
+      int64_t load_idx = i1 * n2 + i2;
+      int64_t write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
       if (i2 < n2) {
         U curr_input = static_cast<U>(input[load_idx]);
         U curr_dout = static_cast<U>(dout[load_idx]);
@@ -559,8 +559,8 @@ __device__ void cuLoadWriteStridedInputs(const int i1_block,
       }
     }
   } else {
-    for (int k = 0; k < blockDim.y; ++k) {
-      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
+    for (int64_t k = 0; k < blockDim.y; ++k) {
+      int64_t write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
       if (!rms_only) {
         warp_buf1[write_idx] = U(0);
       }
@@ -570,31 +570,31 @@ __device__ void cuLoadWriteStridedInputs(const int i1_block,
 }
 
 template <typename T, typename U>
-__device__ void cuLoadAddStridedInputs(const int i1_block,
-                                       const int thr_load_row_off,
-                                       const int thr_load_col_off,
-                                       const int i2_off,
-                                       const int row_stride,
+__device__ void cuLoadAddStridedInputs(const int64_t i1_block,
+                                       const int64_t thr_load_row_off,
+                                       const int64_t thr_load_col_off,
+                                       const int64_t i2_off,
+                                       const int64_t row_stride,
                                        U* warp_buf1,
                                        U* warp_buf2,
                                        const T* input,
                                        const T* dout,
-                                       const int i1_end,
-                                       const int n2,
+                                       const int64_t i1_end,
+                                       const int64_t n2,
                                        const U* __restrict__ mean,
                                        const U* __restrict__ invvar,
                                        bool rms_only) {
-  int i1 = i1_block + thr_load_row_off;
+  int64_t i1 = i1_block + thr_load_row_off;
   if (i1 < i1_end) {
     U curr_mean;
     if (!rms_only) {
       curr_mean = mean[i1];
     }
     U curr_invvar = invvar[i1];
-    for (int k = 0; k < blockDim.y; ++k) {
-      int i2 = i2_off + k;
-      int load_idx = i1 * n2 + i2;
-      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
+    for (int64_t k = 0; k < blockDim.y; ++k) {
+      int64_t i2 = i2_off + k;
+      int64_t load_idx = i1 * n2 + i2;
+      int64_t write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
       if (i2 < n2) {
         U curr_input = static_cast<U>(input[load_idx]);
         U curr_dout = static_cast<U>(dout[load_idx]);
@@ -613,26 +613,29 @@ __device__ void cuLoadAddStridedInputs(const int i1_block,
 template <typename T, typename U>
 __global__ void cuComputePartGradGammaBeta(const T* __restrict__ dout,
                                            const T* __restrict__ input,
-                                           const int n1,
-                                           const int n2,
+                                           const int64_t n1,
+                                           const int64_t n2,
                                            const U* __restrict__ mean,
                                            const U* __restrict__ invvar,
                                            U epsilon,
                                            U* part_grad_gamma,
                                            U* part_grad_beta,
                                            bool rms_only) {
-  const int numsegs_n1 =
+  const int64_t numsegs_n1 =
       (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
-  const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
-  const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
-  const int i1_beg_plus_one =
+  const int64_t segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
+  const int64_t i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
+  const int64_t i1_beg_plus_one =
       (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
-  const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
-  const int row_stride = blockDim.x + 1;
-  const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
-  const int thr_load_row_off =
+  const int64_t i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
+
+  const int64_t row_stride = blockDim.x + 1;
+  const int64_t thr_load_col_off =
+      (threadIdx.x * blockDim.y) & (blockDim.x - 1);
+  const int64_t thr_load_row_off =
       (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
-  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
+  const int64_t i2_off =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + thr_load_col_off;
   SharedMemory<U> shared;
   U* buf = shared.getPointer();  // buf has at least blockDim.x * blockDim.y *
                                  // blockDim.y + (blockDim.y -
@@ -655,7 +658,7 @@ __global__ void cuComputePartGradGammaBeta(const T* __restrict__ dout,
                            mean,
                            invvar,
                            rms_only);
-  for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
+  for (int64_t i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
        i1_block += blockDim.y * blockDim.y) {
     cuLoadAddStridedInputs(i1_block,
                            thr_load_row_off,
@@ -677,9 +680,9 @@ __global__ void cuComputePartGradGammaBeta(const T* __restrict__ dout,
   // sum within each warp
   U acc1 = U(0);
   U acc2 = U(0);
-  for (int k = 0; k < blockDim.y; ++k) {
-    int row1 = threadIdx.y + k * blockDim.y;
-    int idx1 = row1 * row_stride + threadIdx.x;
+  for (int64_t k = 0; k < blockDim.y; ++k) {
+    int64_t row1 = threadIdx.y + k * blockDim.y;
+    int64_t idx1 = row1 * row_stride + threadIdx.x;
     if (!rms_only) {
       acc1 += warp_buf1[idx1];
     }
@@ -692,25 +695,25 @@ __global__ void cuComputePartGradGammaBeta(const T* __restrict__ dout,
   warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
   __syncthreads();
   // sum all warps
-  for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
+  for (int64_t offset = blockDim.y / 2; offset > 1; offset /= 2) {
     if (threadIdx.y < offset) {
-      int row1 = threadIdx.y;
-      int row2 = threadIdx.y + offset;
-      int idx1 = row1 * row_stride + threadIdx.x;
-      int idx2 = row2 * row_stride + threadIdx.x;
+      int64_t row1 = threadIdx.y;
+      int64_t row2 = threadIdx.y + offset;
+      int64_t idx1 = row1 * row_stride + threadIdx.x;
+      int64_t idx2 = row2 * row_stride + threadIdx.x;
       if (!rms_only) {
         warp_buf1[idx1] += warp_buf1[idx2];
       }
       warp_buf2[idx1] += warp_buf2[idx2];
     }
     __syncthreads();
   }
-  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t i2 = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (threadIdx.y == 0 && i2 < n2) {
-    int row1 = threadIdx.y;
-    int row2 = threadIdx.y + 1;
-    int idx1 = row1 * row_stride + threadIdx.x;
-    int idx2 = row2 * row_stride + threadIdx.x;
+    int64_t row1 = threadIdx.y;
+    int64_t row2 = threadIdx.y + 1;
+    int64_t idx1 = row1 * row_stride + threadIdx.x;
+    int64_t idx2 = row2 * row_stride + threadIdx.x;
     if (!rms_only) {
       part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
     }
@@ -722,15 +725,15 @@ template <typename U, typename V>
 __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
                                        const U* part_grad_beta,
                                        const int part_size,
-                                       const int n1,
-                                       const int n2,
+                                       const int64_t n1,
+                                       const int64_t n2,
                                        V* grad_gamma,
                                        V* grad_beta,
                                        bool rms_only) {
   // sum partial gradients for gamma and beta
   SharedMemory<U> shared;
   U* buf = shared.getPointer();
-  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t i2 = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (i2 < n2) {
     // each warp does sequential reductions until reduced part_size is
     // num_warps
@@ -749,11 +752,13 @@ __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
       }
     }
     // inter-warp reductions
-    const int nbsize3 = blockDim.x * blockDim.y / 2;
-    for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
+    const int64_t nbsize3 = blockDim.x * blockDim.y / 2;
+    for (int64_t offset = blockDim.y / 2; offset >= 1; offset /= 2) {
       // top half write to shared memory
       if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
-        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
+        const int64_t write_idx =
+            static_cast<int64_t>(threadIdx.y - offset) * blockDim.x +
+            threadIdx.x;
         buf[write_idx] = sum_gamma;
         if (!rms_only) {
           buf[write_idx + nbsize3] = sum_beta;
@@ -762,7 +767,8 @@ __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
       __syncthreads();
       // bottom half sums
       if (threadIdx.y < offset) {
-        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
+        const int64_t read_idx =
+            static_cast<int64_t>(threadIdx.y) * blockDim.x + threadIdx.x;
         sum_gamma += buf[read_idx];
         if (!rms_only) {
           sum_beta += buf[read_idx + nbsize3];
@@ -783,15 +789,15 @@ __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
 template <typename T, typename U, typename V>
 __global__ void cuComputeGradInput(const T* __restrict__ dout,
                                    const T* __restrict__ input,
-                                   const int n1,
-                                   const int n2,
+                                   const int64_t n1,
+                                   const int64_t n2,
                                    const U* __restrict__ mean,
                                    const U* __restrict__ invvar,
                                    U epsilon,
                                    const V* gamma,
                                    T* grad_input,
                                    bool rms_only) {
-  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
+  for (int64_t i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
     U sum_loss1 = U(0);
     U sum_loss2 = U(0);
     U c_mean;
@@ -804,9 +810,9 @@ __global__ void cuComputeGradInput(const T* __restrict__ dout,
     const int numx = blockDim.x * blockDim.y;
     const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
     if (gamma != NULL) {
-      int l = 4 * thrx;
+      int64_t l = 4 * thrx;
       for (; l + 3 < n2; l += 4 * numx) {
-        for (int k = 0; k < 4; ++k) {
+        for (int64_t k = 0; k < 4; ++k) {
           const U c_h = static_cast<U>(k_input[l + k]);
           const U c_loss = static_cast<U>(k_dout[l + k]);
           const U gamma_tmp = static_cast<U>(gamma[l + k]);
@@ -830,9 +836,9 @@ __global__ void cuComputeGradInput(const T* __restrict__ dout,
         }
       }
     } else {
-      int l = 4 * thrx;
+      int64_t l = 4 * thrx;
       for (; l + 3 < n2; l += 4 * numx) {
-        for (int k = 0; k < 4; ++k) {
+        for (int64_t k = 0; k < 4; ++k) {
           const U c_h = static_cast<U>(k_input[l + k]);
           const U c_loss = static_cast<U>(k_dout[l + k]);
           if (!rms_only) {
@@ -904,7 +910,7 @@ __global__ void cuComputeGradInput(const T* __restrict__ dout,
     U term1 = (U(1) / fH) * c_invvar;
     T* k_grad_input = grad_input + i1 * n2;
     if (gamma != NULL) {
-      for (int l = thrx; l < n2; l += numx) {
+      for (int64_t l = thrx; l < n2; l += numx) {
         const U c_h = static_cast<U>(k_input[l]);
         const U c_loss = static_cast<U>(k_dout[l]);
         U f_grad_input = fH * c_loss * static_cast<U>(gamma[l]);
@@ -918,7 +924,7 @@ __global__ void cuComputeGradInput(const T* __restrict__ dout,
         k_grad_input[l] = static_cast<T>(f_grad_input);
       }
     } else {
-      for (int l = thrx; l < n2; l += numx) {
+      for (int64_t l = thrx; l < n2; l += numx) {
         const U c_h = static_cast<U>(k_input[l]);
         const U c_loss = static_cast<U>(k_dout[l]);
         U f_grad_input = fH * c_loss;
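
The pattern throughout this change: index arithmetic is widened from 32-bit int to int64_t so that flattened offsets such as i1 * n2 + i2 cannot overflow once a tensor holds more than 2^31 elements. The explicit static_cast<int64_t>(blockIdx.x) (and likewise for threadIdx.y) matters because CUDA's built-in coordinates are unsigned 32-bit values, so a product like blockIdx.x * blockDim.x would otherwise be evaluated entirely in 32-bit unsigned arithmetic and wrap before the widening assignment. Below is a minimal, self-contained sketch of the same idiom; the scale kernel, its launch configuration, and the buffer size are hypothetical, chosen only to illustrate the cast placement, not code from this patch.

#include <cuda_runtime.h>
#include <cstdint>

// Hypothetical kernel showing the idiom used above: cast blockIdx.x to
// int64_t BEFORE the multiply, so the flat index is computed in 64 bits.
// Written as blockIdx.x * blockDim.x + threadIdx.x, the expression would
// be evaluated as unsigned 32-bit and wrap at 2^32 (and a plain int
// result would overflow at 2^31).
__global__ void scale(float* data, int64_t n, float alpha) {
  int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (idx < n) {
    data[idx] *= alpha;
  }
}

int main() {
  // Kept small so the sketch runs on modest GPUs; the overflow guarded
  // against above only appears once n1 * n2 exceeds 2^31 - 1 elements
  // (e.g. n1 = 70000 rows of n2 = 32768 gives 2,293,760,000).
  const int64_t n = int64_t(1) << 20;
  float* d = nullptr;
  cudaMalloc(&d, n * sizeof(float));
  cudaMemset(d, 0, n * sizeof(float));
  const int threads = 256;
  const int blocks = static_cast<int>((n + threads - 1) / threads);
  scale<<<blocks, threads>>>(d, n, 2.0f);
  cudaDeviceSynchronize();
  cudaFree(d);
  return 0;
}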