140 changes: 73 additions & 67 deletions paddle/phi/kernels/gpu/rms_norm_funcs.h
@@ -516,31 +516,31 @@ __global__ void cuApplyRMSNorm(T* __restrict__ output_vals,
 }
 
 template <typename T, typename U>
-__device__ void cuLoadWriteStridedInputs(const int i1_block,
-                                         const int thr_load_row_off,
-                                         const int thr_load_col_off,
-                                         const int i2_off,
-                                         const int row_stride,
+__device__ void cuLoadWriteStridedInputs(const int64_t i1_block,
+                                         const int64_t thr_load_row_off,
+                                         const int64_t thr_load_col_off,
+                                         const int64_t i2_off,
+                                         const int64_t row_stride,
                                          U* warp_buf1,
                                          U* warp_buf2,
                                          const T* input,
                                          const T* dout,
-                                         const int i1_end,
-                                         const int n2,
+                                         const int64_t i1_end,
+                                         const int64_t n2,
                                          const U* __restrict__ mean,
                                          const U* __restrict__ invvar,
                                          bool rms_only) {
-  int i1 = i1_block + thr_load_row_off;
+  int64_t i1 = i1_block + thr_load_row_off;
   if (i1 < i1_end) {
     U curr_mean;
     if (!rms_only) {
       curr_mean = mean[i1];
     }
     U curr_invvar = invvar[i1];
-    for (int k = 0; k < blockDim.y; ++k) {
-      int i2 = i2_off + k;
-      int load_idx = i1 * n2 + i2;
-      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
+    for (int64_t k = 0; k < blockDim.y; ++k) {
+      int64_t i2 = i2_off + k;
+      int64_t load_idx = i1 * n2 + i2;
+      int64_t write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
       if (i2 < n2) {
         U curr_input = static_cast<U>(input[load_idx]);
         U curr_dout = static_cast<U>(dout[load_idx]);
@@ -559,8 +559,8 @@ __device__ void cuLoadWriteStridedInputs(const int i1_block,
       }
     }
   } else {
-    for (int k = 0; k < blockDim.y; ++k) {
-      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
+    for (int64_t k = 0; k < blockDim.y; ++k) {
+      int64_t write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
       if (!rms_only) {
         warp_buf1[write_idx] = U(0);
       }
@@ -570,31 +570,31 @@ __device__ void cuLoadWriteStridedInputs(const int i1_block,
 }
 
 template <typename T, typename U>
-__device__ void cuLoadAddStridedInputs(const int i1_block,
-                                       const int thr_load_row_off,
-                                       const int thr_load_col_off,
-                                       const int i2_off,
-                                       const int row_stride,
+__device__ void cuLoadAddStridedInputs(const int64_t i1_block,
+                                       const int64_t thr_load_row_off,
+                                       const int64_t thr_load_col_off,
+                                       const int64_t i2_off,
+                                       const int64_t row_stride,
                                        U* warp_buf1,
                                        U* warp_buf2,
                                        const T* input,
                                        const T* dout,
-                                       const int i1_end,
-                                       const int n2,
+                                       const int64_t i1_end,
+                                       const int64_t n2,
                                        const U* __restrict__ mean,
                                        const U* __restrict__ invvar,
                                        bool rms_only) {
-  int i1 = i1_block + thr_load_row_off;
+  int64_t i1 = i1_block + thr_load_row_off;
   if (i1 < i1_end) {
     U curr_mean;
     if (!rms_only) {
       curr_mean = mean[i1];
     }
     U curr_invvar = invvar[i1];
-    for (int k = 0; k < blockDim.y; ++k) {
-      int i2 = i2_off + k;
-      int load_idx = i1 * n2 + i2;
-      int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
+    for (int64_t k = 0; k < blockDim.y; ++k) {
+      int64_t i2 = i2_off + k;
+      int64_t load_idx = i1 * n2 + i2;
+      int64_t write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
       if (i2 < n2) {
         U curr_input = static_cast<U>(input[load_idx]);
         U curr_dout = static_cast<U>(dout[load_idx]);
@@ -613,26 +613,29 @@ __device__ void cuLoadAddStridedInputs(const int i1_block,
 template <typename T, typename U>
 __global__ void cuComputePartGradGammaBeta(const T* __restrict__ dout,
                                            const T* __restrict__ input,
-                                           const int n1,
-                                           const int n2,
+                                           const int64_t n1,
+                                           const int64_t n2,
                                            const U* __restrict__ mean,
                                            const U* __restrict__ invvar,
                                            U epsilon,
                                            U* part_grad_gamma,
                                            U* part_grad_beta,
                                            bool rms_only) {
-  const int numsegs_n1 =
+  const int64_t numsegs_n1 =
       (n1 + blockDim.y * blockDim.y - 1) / (blockDim.y * blockDim.y);
-  const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
-  const int i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
-  const int i1_beg_plus_one =
+  const int64_t segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
+  const int64_t i1_beg = blockIdx.y * segs_per_block * blockDim.y * blockDim.y;
+  const int64_t i1_beg_plus_one =
       (blockIdx.y + 1) * segs_per_block * blockDim.y * blockDim.y;
-  const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
-  const int row_stride = blockDim.x + 1;
-  const int thr_load_col_off = (threadIdx.x * blockDim.y) & (blockDim.x - 1);
-  const int thr_load_row_off =
+  const int64_t i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
+
+  const int64_t row_stride = blockDim.x + 1;
+  const int64_t thr_load_col_off =
+      (threadIdx.x * blockDim.y) & (blockDim.x - 1);
+  const int64_t thr_load_row_off =
       (threadIdx.x * blockDim.y) / blockDim.x + threadIdx.y * blockDim.y;
-  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
+  const int64_t i2_off =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + thr_load_col_off;
   SharedMemory<U> shared;
   U* buf = shared.getPointer();  // buf has at least blockDim.x * blockDim.y *
                                  // blockDim.y + (blockDim.y -
@@ -655,7 +658,7 @@ __global__ void cuComputePartGradGammaBeta(const T* __restrict__ dout,
                            mean,
                            invvar,
                            rms_only);
-  for (int i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
+  for (int64_t i1_block = i1_beg + blockDim.y * blockDim.y; i1_block < i1_end;
        i1_block += blockDim.y * blockDim.y) {
     cuLoadAddStridedInputs(i1_block,
                            thr_load_row_off,
@@ -677,9 +680,9 @@ __global__ void cuComputePartGradGammaBeta(const T* __restrict__ dout,
   // sum within each warp
   U acc1 = U(0);
   U acc2 = U(0);
-  for (int k = 0; k < blockDim.y; ++k) {
-    int row1 = threadIdx.y + k * blockDim.y;
-    int idx1 = row1 * row_stride + threadIdx.x;
+  for (int64_t k = 0; k < blockDim.y; ++k) {
+    int64_t row1 = threadIdx.y + k * blockDim.y;
+    int64_t idx1 = row1 * row_stride + threadIdx.x;
     if (!rms_only) {
       acc1 += warp_buf1[idx1];
     }
@@ -692,25 +695,25 @@ __global__ void cuComputePartGradGammaBeta(const T* __restrict__ dout,
   warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
   __syncthreads();
   // sum all warps
-  for (int offset = blockDim.y / 2; offset > 1; offset /= 2) {
+  for (int64_t offset = blockDim.y / 2; offset > 1; offset /= 2) {
     if (threadIdx.y < offset) {
-      int row1 = threadIdx.y;
-      int row2 = threadIdx.y + offset;
-      int idx1 = row1 * row_stride + threadIdx.x;
-      int idx2 = row2 * row_stride + threadIdx.x;
+      int64_t row1 = threadIdx.y;
+      int64_t row2 = threadIdx.y + offset;
+      int64_t idx1 = row1 * row_stride + threadIdx.x;
+      int64_t idx2 = row2 * row_stride + threadIdx.x;
       if (!rms_only) {
         warp_buf1[idx1] += warp_buf1[idx2];
       }
       warp_buf2[idx1] += warp_buf2[idx2];
     }
     __syncthreads();
   }
-  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t i2 = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (threadIdx.y == 0 && i2 < n2) {
-    int row1 = threadIdx.y;
-    int row2 = threadIdx.y + 1;
-    int idx1 = row1 * row_stride + threadIdx.x;
-    int idx2 = row2 * row_stride + threadIdx.x;
+    int64_t row1 = threadIdx.y;
+    int64_t row2 = threadIdx.y + 1;
+    int64_t idx1 = row1 * row_stride + threadIdx.x;
+    int64_t idx2 = row2 * row_stride + threadIdx.x;
     if (!rms_only) {
       part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
     }
@@ -722,15 +725,15 @@ template <typename U, typename V>
 __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
                                        const U* part_grad_beta,
                                        const int part_size,
-                                       const int n1,
-                                       const int n2,
+                                       const int64_t n1,
+                                       const int64_t n2,
                                        V* grad_gamma,
                                        V* grad_beta,
                                        bool rms_only) {
   // sum partial gradients for gamma and beta
   SharedMemory<U> shared;
   U* buf = shared.getPointer();
-  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
+  int64_t i2 = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
   if (i2 < n2) {
     // each warp does sequential reductions until reduced part_size is
     // num_warps
@@ -749,11 +752,13 @@ __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
       }
     }
     // inter-warp reductions
-    const int nbsize3 = blockDim.x * blockDim.y / 2;
-    for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) {
+    const int64_t nbsize3 = blockDim.x * blockDim.y / 2;
+    for (int64_t offset = blockDim.y / 2; offset >= 1; offset /= 2) {
       // top half write to shared memory
       if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
-        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
+        const int64_t write_idx =
+            static_cast<int64_t>(threadIdx.y - offset) * blockDim.x +
+            threadIdx.x;
         buf[write_idx] = sum_gamma;
         if (!rms_only) {
           buf[write_idx + nbsize3] = sum_beta;
@@ -762,7 +767,8 @@ __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
       __syncthreads();
       // bottom half sums
       if (threadIdx.y < offset) {
-        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
+        const int64_t read_idx =
+            static_cast<int64_t>(threadIdx.y) * blockDim.x + threadIdx.x;
         sum_gamma += buf[read_idx];
         if (!rms_only) {
           sum_beta += buf[read_idx + nbsize3];
@@ -783,15 +789,15 @@ __global__ void cuComputeGradGammaBeta(const U* part_grad_gamma,
 template <typename T, typename U, typename V>
 __global__ void cuComputeGradInput(const T* __restrict__ dout,
                                    const T* __restrict__ input,
-                                   const int n1,
-                                   const int n2,
+                                   const int64_t n1,
+                                   const int64_t n2,
                                    const U* __restrict__ mean,
                                    const U* __restrict__ invvar,
                                    U epsilon,
                                    const V* gamma,
                                    T* grad_input,
                                    bool rms_only) {
-  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
+  for (int64_t i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
     U sum_loss1 = U(0);
     U sum_loss2 = U(0);
     U c_mean;
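
One change in this hunk is more than a width change: `for (auto i1 = blockIdx.y; ...)` deduced `unsigned int` (the type of `blockIdx.y`), so in a grid-stride loop over more than 2^32 rows the counter would wrap and `i1 < n1` could never become false for some threads. A reduced host-side illustration of the difference (the bounds are hypothetical, not from this PR):

#include <cstdint>

int64_t count_rows(int64_t n1, unsigned grid_dim_y) {
  // With 'auto i1 = 0u' the counter would wrap modulo 2^32 and, for
  // n1 > UINT32_MAX, never reach n1; the int64_t counter terminates.
  int64_t visited = 0;
  for (int64_t i1 = 0; i1 < n1; i1 += grid_dim_y) {
    ++visited;
  }
  return visited;
}
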
@@ -804,9 +810,9 @@ __global__ void cuComputeGradInput(const T* __restrict__ dout,
     const int numx = blockDim.x * blockDim.y;
     const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
     if (gamma != NULL) {
-      int l = 4 * thrx;
+      int64_t l = 4 * thrx;
       for (; l + 3 < n2; l += 4 * numx) {
-        for (int k = 0; k < 4; ++k) {
+        for (int64_t k = 0; k < 4; ++k) {
           const U c_h = static_cast<U>(k_input[l + k]);
           const U c_loss = static_cast<U>(k_dout[l + k]);
           const U gamma_tmp = static_cast<U>(gamma[l + k]);
@@ -830,9 +836,9 @@ __global__ void cuComputeGradInput(const T* __restrict__ dout,
         }
       }
     } else {
-      int l = 4 * thrx;
+      int64_t l = 4 * thrx;
       for (; l + 3 < n2; l += 4 * numx) {
-        for (int k = 0; k < 4; ++k) {
+        for (int64_t k = 0; k < 4; ++k) {
           const U c_h = static_cast<U>(k_input[l + k]);
           const U c_loss = static_cast<U>(k_dout[l + k]);
           if (!rms_only) {
@@ -904,7 +910,7 @@ __global__ void cuComputeGradInput(const T* __restrict__ dout,
     U term1 = (U(1) / fH) * c_invvar;
     T* k_grad_input = grad_input + i1 * n2;
     if (gamma != NULL) {
-      for (int l = thrx; l < n2; l += numx) {
+      for (int64_t l = thrx; l < n2; l += numx) {
         const U c_h = static_cast<U>(k_input[l]);
         const U c_loss = static_cast<U>(k_dout[l]);
         U f_grad_input = fH * c_loss * static_cast<U>(gamma[l]);
@@ -918,7 +924,7 @@ __global__ void cuComputeGradInput(const T* __restrict__ dout,
         k_grad_input[l] = static_cast<T>(f_grad_input);
       }
     } else {
-      for (int l = thrx; l < n2; l += numx) {
+      for (int64_t l = thrx; l < n2; l += numx) {
         const U c_h = static_cast<U>(k_input[l]);
         const U c_loss = static_cast<U>(k_dout[l]);
         U f_grad_input = fH * c_loss;
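
Why the widening matters: with 32-bit `int` indices, a flattened offset such as `i1 * n2 + i2` is multiplied in 32-bit arithmetic and wraps once the tensor passes 2^31 - 1 elements; storing the result into an `int64_t` afterwards does not help. The promotion has to happen before the multiply, which is why expressions whose operands are still 32-bit (e.g. `blockIdx.x * blockDim.x`) gain an explicit `static_cast<int64_t>` on the first operand. A minimal host-side C++ sketch of the failure mode (the sizes are hypothetical, not taken from this PR):

#include <cstdint>
#include <cstdio>

int main() {
  // A hypothetical large activation: 60000 rows of 40000 columns
  // (2.4e9 elements, past the 2^31 - 1 limit of a 32-bit int).
  int i1 = 60000;  // row index
  int n2 = 40000;  // row width
  int i2 = 0;      // column index

  // Signed 32-bit overflow (undefined behavior); wraps in practice.
  int bad = i1 * n2 + i2;
  // Still wrong: the multiply wraps in 32 bits, then widens.
  int64_t still_bad = i1 * n2 + i2;
  // Correct: widen one operand first, so the multiply is 64-bit.
  int64_t good = static_cast<int64_t>(i1) * n2 + i2;

  printf("bad=%d still_bad=%lld good=%lld\n",
         bad, static_cast<long long>(still_bad), static_cast<long long>(good));
  return 0;
}

Once the parameters themselves are `int64_t`, as in this diff, mixed expressions like `i1 * n2 + i2` promote their remaining `int` operands to 64-bit, so most sites need only the type change rather than a cast.
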
17 changes: 11 additions & 6 deletions paddle/phi/kernels/gpu/rms_norm_grad_kernel.cu
@@ -47,8 +47,8 @@ void HostRMSNormGradient(const Context& dev_ctx,
                          const T* dout,
                          const U* invvar,
                          const DenseTensor& input,
-                         int n1,
-                         int n2,
+                         int64_t n1,
+                         int64_t n2,
                          const V* gamma,
                          double epsilon,
                          T* grad_input,
@@ -125,10 +125,15 @@ void cuda_rms_norm_gradient(const Context& dev_ctx,
                             DenseTensor* grad_x,
                             DenseTensor* grad_scale,
                             const int begin_norm_axis) {
-  const auto x_dims = x.dims();
-  auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
-  int rows = static_cast<int>(matrix_dim[0]);
-  int cols = static_cast<int>(matrix_dim[1]);
+  int64_t rows = 1;
+  int64_t cols = 1;
+  for (int i = 0; i < begin_norm_axis; i++) {
+    rows *= x.dims()[i];
+  }
+
+  for (int i = begin_norm_axis; i < x.dims().size(); i++) {
+    cols *= x.dims()[i];
+  }
   dev_ctx.template Alloc<T>(grad_x);
 
   DISPATCH_SCALE_TYPE(T,
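
The new `rows`/`cols` computation folds the leading `begin_norm_axis` dimensions into `rows` and the trailing ones into `cols`, accumulating in `int64_t` instead of truncating through `static_cast<int>` as the old `phi::flatten_to_2d` path did. A standalone sketch of the same logic (the helper name `flatten_rows_cols` and the shape are illustrative, not part of the PR):

#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the flattening added in cuda_rms_norm_gradient: dimensions
// before begin_norm_axis multiply into rows, the rest into cols.
void flatten_rows_cols(const std::vector<int64_t>& dims,
                       int begin_norm_axis,
                       int64_t* rows,
                       int64_t* cols) {
  *rows = 1;
  *cols = 1;
  for (int i = 0; i < begin_norm_axis; i++) {
    *rows *= dims[i];
  }
  for (int i = begin_norm_axis; i < static_cast<int>(dims.size()); i++) {
    *cols *= dims[i];
  }
}

int main() {
  int64_t rows = 0;
  int64_t cols = 0;
  // e.g. an [8, 32768, 16384] activation normalized over the last axis:
  flatten_rows_cols({8, 32768, 16384}, 2, &rows, &cols);
  printf("rows=%lld cols=%lld total=%lld\n",
         static_cast<long long>(rows), static_cast<long long>(cols),
         static_cast<long long>(rows * cols));  // total = 2^32: overflows int32
  return 0;
}

The widened `rows` and `cols` then feed the `int64_t n1, n2` parameters of `HostRMSNormGradient` above without truncation.
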