paddle/phi/kernels/gpudnn/softmax_gpudnn.h (22 changes: 13 additions & 9 deletions)

@@ -646,7 +646,9 @@ __global__ void WarpSoftmaxBackward(T* dst,
   constexpr IndexType kBatchSize = (kDimCeil <= 128) ? 2 : 1;
   constexpr IndexType kLoopsV = (kLoops >= kVSize) ? (kLoops / kVSize) : 1;
   IndexType element_count_v = element_count / kVSize;
-  IndexType first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
+  IndexType first_batch =
+      (static_cast<int64_t>(blockDim.y) * blockIdx.x + threadIdx.y) *
+      kBatchSize;
   IndexType local_batches = min(batch_size - first_batch, kBatchSize);

   // max index to read
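Note on this hunk: `blockDim.y` and `blockIdx.x` are both `unsigned int`, so without the cast their product is evaluated in 32-bit arithmetic and wraps before being widened to `IndexType`, even when `IndexType` is `int64_t`. A minimal host-side sketch of the failure mode; the launch values are hypothetical, chosen only to force the wrap:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical launch values for a tensor with more than 2^32 elements.
  unsigned int blockDim_y = 8, blockIdx_x = 700000000, threadIdx_y = 0;
  const int64_t kBatchSize = 2;

  // Old expression: 8 * 700000000 is computed as unsigned int and wraps
  // modulo 2^32 before any widening to int64_t happens.
  int64_t wrapped = (blockDim_y * blockIdx_x + threadIdx_y) * kBatchSize;

  // Patched expression: widening one operand promotes the whole chain.
  int64_t exact =
      (static_cast<int64_t>(blockDim_y) * blockIdx_x + threadIdx_y) *
      kBatchSize;

  printf("wrapped = %lld\nexact   = %lld\n",
         static_cast<long long>(wrapped), static_cast<long long>(exact));
  return 0;
}
```

Widening a single operand is enough: once one side of the multiply is `int64_t`, the usual arithmetic conversions carry the rest of the expression to 64 bits.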
@@ -844,8 +846,8 @@ static void GetGridDim(int64_t high_dim,
   int max_mp = phi::backends::gpu::GetGPUMultiProcessors(device_id);
   int max_threads_per_mp =
       phi::backends::gpu::GetGPUMaxThreadsPerMultiProcessor(device_id);
-  int max_threads = max_threads_per_mp * max_mp;
-  int num_threads = block.x * block.y;
+  int64_t max_threads = max_threads_per_mp * max_mp;
+  int64_t num_threads = static_cast<int64_t>(block.x) * block.y;
   int64_t max_num_blocks = max_threads / num_threads;

   int64_t grid_x = (low_dim + block.x - 1) / block.x;
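In `GetGridDim` the products `max_threads_per_mp * max_mp` and `block.x * block.y` are small in practice, but widening them keeps the `max_num_blocks` division and the grid sizing that follows entirely in 64-bit arithmetic. A compilable host-side sketch of that pattern (the `Dim3` struct and `MaxNumBlocks` helper are illustrative stand-ins, not Paddle APIs):

```cpp
#include <cstdint>

// Illustrative stand-in for CUDA's dim3, so the sketch compiles on the host.
struct Dim3 {
  unsigned int x, y, z;
};

// Occupancy arithmetic in the style of the patched GetGridDim: every
// intermediate product and the final division stay in int64_t.
int64_t MaxNumBlocks(int max_threads_per_mp, int max_mp, Dim3 block) {
  int64_t max_threads = static_cast<int64_t>(max_threads_per_mp) * max_mp;
  int64_t num_threads = static_cast<int64_t>(block.x) * block.y;
  return max_threads / num_threads;
}
```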
@@ -889,9 +891,10 @@ __global__ void NormalSoftmaxForward(T* output,
   const IndexType mid_stride = low_dim;
   for (IndexType high_id = blockIdx.y; high_id < high_dim;
        high_id += gridDim.y) {
-    for (IndexType low_id = blockIdx.x * blockDim.x + threadIdx.x;
+    for (IndexType low_id =
+             static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
          low_id < low_dim;
-         low_id += blockDim.x * gridDim.x) {
+         low_id += static_cast<int64_t>(blockDim.x) * gridDim.x) {
       const IndexType input_offset = high_id * high_stride + low_id;

       // 1. reduce max
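The same wrap applies to grid-stride loops: both the starting index and the stride must be widened once `low_dim` can exceed `INT32_MAX`. A standalone kernel sketch in the style of the patched loops (`FillKernel` is a hypothetical example, not part of the patch):

```cpp
#include <cstdint>

// Grid-stride loop with 64-bit-safe indexing, in the same style as the
// patched NormalSoftmax kernels: both the starting index and the stride
// are widened before the unsigned 32-bit multiply can wrap.
__global__ void FillKernel(float* out, int64_t n, float value) {
  for (int64_t i =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < n;
       i += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    out[i] = value;
  }
}
```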
@@ -948,9 +951,10 @@ __global__ void NormalSoftmaxBackward(T* input_grad,
   const IndexType mid_stride = low_dim;
   for (IndexType high_id = blockIdx.y; high_id < high_dim;
        high_id += gridDim.y) {
-    for (IndexType low_id = blockIdx.x * blockDim.x + threadIdx.x;
+    for (IndexType low_id =
+             static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
          low_id < low_dim;
-         low_id += blockDim.x * gridDim.x) {
+         low_id += static_cast<int64_t>(blockDim.x) * gridDim.x) {
       const IndexType grad_offset = high_id * high_stride + low_id;

       // 1. reduce sum
@@ -1111,7 +1115,7 @@ void LaunchSoftmaxForwardCudnnKernel(const GPUContext& dev_ctx,
   int64_t remaining = tensor_dims[0];
   int dim = tensor_dims[1];
   int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
-  int offset = batch_size * dim;
+  int64_t offset = batch_size * dim;
   while (remaining > 0) {
     tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
     SoftmaxForwardCudnnKernel<T>(
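Context for this hunk: cuDNN tensor descriptors carry 32-bit extents, so the launcher walks the batch in chunks of at most `std::numeric_limits<int32_t>::max() / dim` rows. `batch_size * dim` is already an `int64_t` product; declaring `offset` as `int` forced a narrowing conversion, and keeping it `int64_t` keeps the per-chunk pointer advance 64-bit-safe. A self-contained host sketch of the chunking pattern (`ProcessChunk` and `LaunchChunked` are hypothetical stand-ins for the cuDNN call, not Paddle APIs):

```cpp
#include <algorithm>
#include <cstdint>
#include <limits>

// Hypothetical stand-in for the real cuDNN softmax call: just copies the
// chunk through so the sketch stays self-contained and runnable.
void ProcessChunk(const float* x, float* y, int64_t rows, int dim) {
  for (int64_t i = 0; i < rows * dim; ++i) y[i] = x[i];
}

// Chunked launch in the style of LaunchSoftmaxForwardCudnnKernel: at most
// INT32_MAX elements per chunk, with all size arithmetic kept in int64_t.
void LaunchChunked(const float* x, float* y, int64_t total_rows, int dim) {
  int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
  int64_t offset = batch_size * dim;  // 64-bit product, as in the patch
  for (int64_t remaining = total_rows; remaining > 0;
       remaining -= batch_size) {
    int64_t rows = std::min<int64_t>(remaining, batch_size);
    ProcessChunk(x, y, rows, dim);
    x += offset;  // pointer advance uses the 64-bit offset
    y += offset;
  }
}
```

The backward launcher below applies the identical fix.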
@@ -1189,7 +1193,7 @@ void LaunchSoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
   int64_t remaining = tensor_dims[0];
   int dim = tensor_dims[1];
   int64_t batch_size = std::numeric_limits<int32_t>::max() / dim;
-  int offset = batch_size * dim;
+  int64_t offset = batch_size * dim;
   while (remaining > 0) {
     tensor_dims[0] = std::min<int64_t>(remaining, batch_size);
     SoftmaxBackwardCudnnKernel<T>(dev_ctx,