
Commit 971eac1

Fix paddle.incubate.nn.functional.fused_rms_norm big Tensor (#74055)
* fix fused_rms_norm big shape
* fix codestyle
* fix codestyle
* fix
* fix blockIdx.x to int64
1 parent 4953c94 commit 971eac1
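
The whole change is about index width: once rows * cols of the flattened input exceeds 2^31 - 1, any offset held in a 32-bit int wraps. A minimal host-side sketch of that failure mode (the shape below is hypothetical, not taken from this commit):

    // Illustrative only: a flattened offset larger than INT_MAX cannot be
    // held in a 32-bit int; the commit switches such indices to int64_t.
    #include <cstdint>
    #include <cstdio>

    int main() {
      int64_t rows = 70000;  // hypothetical "big Tensor" shape
      int64_t cols = 40000;  // rows * cols = 2.8e9 > INT_MAX (2147483647)

      int64_t good_idx = rows * cols;               // exact in 64 bits
      int bad_idx = static_cast<int>(rows * cols);  // truncated to 32 bits

      std::printf("int64_t: %lld, int: %d\n",
                  static_cast<long long>(good_idx), bad_idx);
      return 0;
    }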

File tree

3 files changed: 124 additions, 104 deletions


paddle/phi/kernels/gpu/rms_norm_funcs.h

Lines changed: 73 additions & 67 deletions
@@ -516,31 +516,31 @@ __global__ void cuApplyRMSNorm(T* __restrict__ output_vals,
 }

 template <typename T, typename U>
-__device__ void cuLoadWriteStridedInputs(const int i1_block,
-                                         const int thr_load_row_off,
-                                         const int thr_load_col_off,
-                                         const int i2_off,
-                                         const int row_stride,
+__device__ void cuLoadWriteStridedInputs(const int64_t i1_block,
+                                         const int64_t thr_load_row_off,
+                                         const int64_t thr_load_col_off,
+                                         const int64_t i2_off,
+                                         const int64_t row_stride,
                                          U* warp_buf1,
                                          U* warp_buf2,
                                          const T* input,
                                          const T* dout,
-                                         const int i1_end,
-                                         const int n2,
+                                         const int64_t i1_end,
+                                         const int64_t n2,
                                          const U* __restrict__ mean,
                                          const U* __restrict__ invvar,
                                          bool rms_only) {
-  int i1 = i1_block + thr_load_row_off;
+  int64_t i1 = i1_block + thr_load_row_off;
   if (i1 < i1_end) {
     U curr_mean;
     if (!rms_only) {
       curr_mean = mean[i1];
     }
     U curr_invvar = invvar[i1];
-    for (int k = 0; k < blockDim.y; ++k) {
-      int i2 = i2_off + k;
-      int load_idx = i1 * n2 + i2;
-      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
+    for (int64_t k = 0; k < blockDim.y; ++k) {
+      int64_t i2 = i2_off + k;
+      int64_t load_idx = i1 * n2 + i2;
+      int64_t write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
       if (i2 < n2) {
         U curr_input = static_cast<U>(input[load_idx]);
         U curr_dout = static_cast<U>(dout[load_idx]);

@@ -559,8 +559,8 @@ __device__ void cuLoadWriteStridedInputs(const int i1_block,
       }
     }
   } else {
-    for (int k = 0; k < blockDim.y; ++k) {
-      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
+    for (int64_t k = 0; k < blockDim.y; ++k) {
+      int64_t write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
       if (!rms_only) {
         warp_buf1[write_idx] = U(0);
       }
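
The same pattern repeats through the rest of the file: the flat offset i1 * n2 + i2 used above is only safe when every operand is 64-bit, since with int operands the multiply wraps once the element offset passes 2^31 - 1. A reduced sketch of the widened indexing, with an illustrative helper name rather than the real kernel body:

    // Sketch, not the real kernel: only the index arithmetic is shown.
    template <typename T, typename U>
    __device__ void load_one_element(const T* input, U* dst,
                                     int64_t i1,    // row index
                                     int64_t i2,    // column index
                                     int64_t n2) {  // row width
      if (i2 < n2) {
        // All operands are int64_t, so i1 * n2 + i2 cannot wrap for any
        // allocatable tensor; with 32-bit int it wraps past 2^31 - 1.
        int64_t load_idx = i1 * n2 + i2;
        *dst = static_cast<U>(input[load_idx]);
      }
    }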
@@ -570,31 +570,31 @@ __device__ void cuLoadWriteStridedInputs(const int i1_block,
 }

 template <typename T, typename U>
-__device__ void cuLoadAddStridedInputs(const int i1_block,
-                                       const int thr_load_row_off,
-                                       const int thr_load_col_off,
-                                       const int i2_off,
-                                       const int row_stride,
+__device__ void cuLoadAddStridedInputs(const int64_t i1_block,
+                                       const int64_t thr_load_row_off,
+                                       const int64_t thr_load_col_off,
+                                       const int64_t i2_off,
+                                       const int64_t row_stride,
                                        U* warp_buf1,
                                        U* warp_buf2,
                                        const T* input,
                                        const T* dout,
-                                       const int i1_end,
-                                       const int n2,
+                                       const int64_t i1_end,
+                                       const int64_t n2,
                                        const U* __restrict__ mean,
                                        const U* __restrict__ invvar,
                                        bool rms_only) {
-  int i1 = i1_block + thr_load_row_off;
+  int64_t i1 = i1_block + thr_load_row_off;
   if (i1 < i1_end) {
     U curr_mean;
     if (!rms_only) {
       curr_mean = mean[i1];
     }
     U curr_invvar = invvar[i1];
-    for (int k = 0; k < blockDim.y; ++k) {
-      int i2 = i2_off + k;
-      int load_idx = i1 * n2 + i2;
-      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
+    for (int64_t k = 0; k < blockDim.y; ++k) {
+      int64_t i2 = i2_off + k;
+      int64_t load_idx = i1 * n2 + i2;
+      int64_t write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
       if (i2 < n2) {
         U curr_input = static_cast<U>(input[load_idx]);
         U curr_dout = static_cast<U>(dout[load_idx]);
@@ -613,26 +613,29 @@ __device__ void cuLoadAddStridedInputs(const int i1_block,
 template <typename T, typename U>
 __global__ void cuComputePartGradGammaBeta(const T* __restrict__ dout,
                                            const T* __restrict__ input,
-                                           const int n1,
-                                           const int n2,
+                                           const int64_t n1,
+                                           const int64_t n2,
                                            const U* __restrict__ mean,
                                            const U* __restrict__ invvar,
                                            U epsilon,
                                            U* part_grad_gamma,
                                            U* part_grad_beta,
                                            bool rms_only) {
-  const int numsegs_n1 =
+  const int64_t numsegs_n1 =
       (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
-  const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
-  const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
-  const int i1_beg_plus_one =
+  const int64_t segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
+  const int64_t i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
+  const int64_t i1_beg_plus_one =
       (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
-  const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
-  const int row_stride = blockDim.x + 1;
-  const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
-  const int thr_load_row_off =
+  const int64_t i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
+
+  const int64_t row_stride = blockDim.x + 1;
+  const int64_t thr_load_col_off =
+      (threadIdx.x * blockDim.y) & (blockDim.x - 1);
+  const int64_t thr_load_row_off =
       (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
-  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
+  const int64_t i2_off =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + thr_load_col_off;
   SharedMemory<U> shared;
   U* buf = shared.getPointer();  // buf has at least blockDim.x * blockDim.y *
                                  // blockDim.y + (blockDim.y -

@@ -655,7 +658,7 @@ __global__ void cuComputePartGradGammaBeta(const T* __restrict__ dout,
                            mean,
                            invvar,
                            rms_only);
-  for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
+  for (int64_t i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
        i1_block += blockDim.y * blockDim.y) {
     cuLoadAddStridedInputs(i1_block,
                            thr_load_row_off,

@@ -677,9 +680,9 @@ __global__ void cuComputePartGradGammaBeta(const T* __restrict__ dout,
   // sum within each warp
   U acc1 = U(0);
   U acc2 = U(0);
-  for (int k = 0; k < blockDim.y; ++k) {
-    int row1 = threadIdx.y + k * blockDim.y;
-    int idx1 = row1 * row_stride + threadIdx.x;
+  for (int64_t k = 0; k < blockDim.y; ++k) {
+    int64_t row1 = threadIdx.y + k * blockDim.y;
+    int64_t idx1 = row1 * row_stride + threadIdx.x;
     if (!rms_only) {
       acc1 += warp_buf1[idx1];
     }

@@ -692,25 +695,25 @@ __global__ void cuComputePartGradGammaBeta(const T* __restrict__ dout,
   warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
   __syncthreads();
   // sum all warps
-  for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
+  for (int64_t offset = blockDim.y / 2; offset > 1; offset /= 2) {
     if (threadIdx.y < offset) {
-      int row1 = threadIdx.y;
-      int row2 = threadIdx.y + offset;
-      int idx1 = row1 * row_stride + threadIdx.x;
-      int idx2 = row2 * row_stride + threadIdx.x;
+      int64_t row1 = threadIdx.y;
+      int64_t row2 = threadIdx.y + offset;
+      int64_t idx1 = row1 * row_stride + threadIdx.x;
+      int64_t idx2 = row2 * row_stride + threadIdx.x;
       if (!rms_only) {
         warp_buf1[idx1] += warp_buf1[idx2];
       }
       warp_buf2[idx1] += warp_buf2[idx2];
     }
     __syncthreads();
   }
-  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t i2 = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (threadIdx.y == 0 && i2 < n2) {
-    int row1 = threadIdx.y;
-    int row2 = threadIdx.y + 1;
-    int idx1 = row1 * row_stride + threadIdx.x;
-    int idx2 = row2 * row_stride + threadIdx.x;
+    int64_t row1 = threadIdx.y;
+    int64_t row2 = threadIdx.y + 1;
+    int64_t idx1 = row1 * row_stride + threadIdx.x;
+    int64_t idx2 = row2 * row_stride + threadIdx.x;
     if (!rms_only) {
       part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
     }
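
Note where the cast lands in the hunks above: blockIdx.x and blockDim.x are both unsigned int, so their product is evaluated in 32 bits no matter what type receives it, which is why static_cast<int64_t> is applied to blockIdx.x before the multiply. A sketch of the difference (the helper name is illustrative):

    // Sketch: why static_cast<int64_t> sits on blockIdx.x itself.
    __device__ int64_t global_column() {
      // 32-bit multiply, then widened: any high bits are already lost.
      int64_t wrapped = blockIdx.x * blockDim.x + threadIdx.x;
      // 64-bit multiply from the start: exact for any launch configuration.
      int64_t exact =
          static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
      (void)wrapped;  // shown only for contrast
      return exact;
    }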
@@ -722,15 +725,15 @@ template <typename U, typename V>
 __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
                                        const U* part_grad_beta,
                                        const int part_size,
-                                       const int n1,
-                                       const int n2,
+                                       const int64_t n1,
+                                       const int64_t n2,
                                        V* grad_gamma,
                                        V* grad_beta,
                                        bool rms_only) {
   // sum partial gradients for gamma and beta
   SharedMemory<U> shared;
   U* buf = shared.getPointer();
-  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t i2 = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (i2 < n2) {
     // each warp does sequential reductions until reduced part_size is
     // num_warps

@@ -749,11 +752,13 @@ __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
       }
     }
     // inter-warp reductions
-    const int nbsize3 = blockDim.x * blockDim.y / 2;
-    for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
+    const int64_t nbsize3 = blockDim.x * blockDim.y / 2;
+    for (int64_t offset = blockDim.y / 2; offset >= 1; offset /= 2) {
      // top half write to shared memory
      if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
-        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
+        const int64_t write_idx =
+            static_cast<int64_t>(threadIdx.y - offset) * blockDim.x +
+            threadIdx.x;
        buf[write_idx] = sum_gamma;
        if (!rms_only) {
          buf[write_idx + nbsize3] = sum_beta;

@@ -762,7 +767,8 @@ __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
      __syncthreads();
      // bottom half sums
      if (threadIdx.y < offset) {
-        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
+        const int64_t read_idx =
+            static_cast<int64_t>(threadIdx.y) * blockDim.x + threadIdx.x;
        sum_gamma += buf[read_idx];
        if (!rms_only) {
          sum_beta += buf[read_idx + nbsize3];
@@ -783,15 +789,15 @@ __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
 template <typename T, typename U, typename V>
 __global__ void cuComputeGradInput(const T* __restrict__ dout,
                                    const T* __restrict__ input,
-                                   const int n1,
-                                   const int n2,
+                                   const int64_t n1,
+                                   const int64_t n2,
                                    const U* __restrict__ mean,
                                    const U* __restrict__ invvar,
                                    U epsilon,
                                    const V* gamma,
                                    T* grad_input,
                                    bool rms_only) {
-  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
+  for (int64_t i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
     U sum_loss1 = U(0);
     U sum_loss2 = U(0);
     U c_mean;

@@ -804,9 +810,9 @@ __global__ void cuComputeGradInput(const T* __restrict__ dout,
     const int numx = blockDim.x * blockDim.y;
     const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
     if (gamma != NULL) {
-      int l = 4 * thrx;
+      int64_t l = 4 * thrx;
       for (; l + 3 < n2; l += 4 * numx) {
-        for (int k = 0; k < 4; ++k) {
+        for (int64_t k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          const U gamma_tmp = static_cast<U>(gamma[l + k]);

@@ -830,9 +836,9 @@ __global__ void cuComputeGradInput(const T* __restrict__ dout,
        }
      }
    } else {
-      int l = 4 * thrx;
+      int64_t l = 4 * thrx;
      for (; l + 3 < n2; l += 4 * numx) {
-        for (int k = 0; k < 4; ++k) {
+        for (int64_t k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l + k]);
          const U c_loss = static_cast<U>(k_dout[l + k]);
          if (!rms_only) {

@@ -904,7 +910,7 @@ __global__ void cuComputeGradInput(const T* __restrict__ dout,
     U term1 = (U(1) / fH) * c_invvar;
     T* k_grad_input = grad_input + i1 * n2;
     if (gamma != NULL) {
-      for (int l = thrx; l < n2; l += numx) {
+      for (int64_t l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * static_cast<U>(gamma[l]);

@@ -918,7 +924,7 @@ __global__ void cuComputeGradInput(const T* __restrict__ dout,
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
-      for (int l = thrx; l < n2; l += numx) {
+      for (int64_t l = thrx; l < n2; l += numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
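
The row loop in cuComputeGradInput previously declared its counter with auto, which deduces unsigned int from blockIdx.y, so the counter stayed 32-bit even after n1 became int64_t. A sketch of the corrected grid-stride loop, with the per-row gradient math elided:

    // Sketch of the grid-stride row loop; the real kernel does the RMSNorm
    // gradient math inside the loop body.
    __global__ void row_loop_sketch(int64_t n1) {
      // blockIdx.y / gridDim.y are unsigned int; with `auto` the counter
      // would wrap at 2^32 - 1 rows. int64_t keeps the row index exact.
      for (int64_t i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
        // ... per-row work ...
      }
    }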

paddle/phi/kernels/gpu/rms_norm_grad_kernel.cu

Lines changed: 11 additions & 6 deletions
@@ -48,8 +48,8 @@ void HostRMSNormGradient(const Context& dev_ctx,
                          const T* dout,
                          const U* invvar,
                          const DenseTensor& input,
-                         int n1,
-                         int n2,
+                         int64_t n1,
+                         int64_t n2,
                          const V* gamma,
                          double epsilon,
                          T* grad_input,
@@ -126,10 +126,15 @@ void cuda_rms_norm_gradient(const Context& dev_ctx,
                             DenseTensor* grad_x,
                             DenseTensor* grad_scale,
                             const int begin_norm_axis) {
-  const auto x_dims = x.dims();
-  auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
-  int rows = static_cast<int>(matrix_dim[0]);
-  int cols = static_cast<int>(matrix_dim[1]);
+  int64_t rows = 1;
+  int64_t cols = 1;
+  for (int i = 0; i < begin_norm_axis; i++) {
+    rows *= x.dims()[i];
+  }
+
+  for (int i = begin_norm_axis; i < x.dims().size(); i++) {
+    cols *= x.dims()[i];
+  }
   dev_ctx.template Alloc<T>(grad_x);
   if (x.numel() == 0) {
     if (grad_scale) {
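
On the host side, the old path flattened the shape with phi::flatten_to_2d and then narrowed both extents through static_cast<int>, silently truncating whenever either side of the 2-D view exceeded 2^31 - 1; the replacement accumulates the products directly in int64_t. A standalone sketch of that split, using a plain std::vector in place of phi::DDim (values and function name are illustrative):

    // Sketch of the rows/cols split at begin_norm_axis with 64-bit extents.
    #include <cstdint>
    #include <vector>

    void split_rows_cols(const std::vector<int64_t>& dims, int begin_norm_axis,
                         int64_t* rows, int64_t* cols) {
      *rows = 1;
      *cols = 1;
      for (int i = 0; i < begin_norm_axis; ++i) {
        *rows *= dims[i];  // leading (batch) dimensions
      }
      for (int i = begin_norm_axis; i < static_cast<int>(dims.size()); ++i) {
        *cols *= dims[i];  // dimensions being normalized
      }
    }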
