paddle/phi/kernels/gpu/nll_loss.h (60 changes: 30 additions & 30 deletions)
@@ -28,9 +28,9 @@ namespace phi {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;
static const int NTHREADS = 32;
-static inline int NumBlocks(const int N) {
+static inline int64_t NumBlocks(const int64_t N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
-kNumMaximumNumBlocks);
+static_cast<int64_t>(kNumMaximumNumBlocks));
}

template <typename T>
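The widened NumBlocks above is the key change for large inputs: the block count is derived from flattened element counts (batch_size, or batch_size * in_dim2 * in_dim3 in nll_loss_kernel.cu below), which can exceed INT_MAX. A minimal standalone sketch of the failure mode and of the 64-bit fix; NumBlocks64, kThreads, and kMaxBlocks are invented names mirroring the constants above, and none of this is Paddle code:

// Sketch only: why the block-count math moved to int64_t.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static constexpr int64_t kThreads = 512;     // mirrors kNumCUDAThreads
static constexpr int64_t kMaxBlocks = 4096;  // mirrors kNumMaximumNumBlocks

static inline int64_t NumBlocks64(const int64_t n) {
  // Round-up division in 64 bits, clamped to the maximum grid size.
  return std::min((n + kThreads - 1) / kThreads, kMaxBlocks);
}

int main() {
  const int64_t n = 3'000'000'000LL;  // element count above INT_MAX
  // A 32-bit parameter would already truncate n at the call boundary
  // (wrapping to a negative value on typical builds) before any division.
  std::printf("n as int: %d\n", static_cast<int>(n));
  std::printf("blocks:   %lld\n", static_cast<long long>(NumBlocks64(n)));
  return 0;
}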
@@ -54,7 +54,7 @@ __global__ void GPUNLLLossForward1D_no_reduce(T* out_data,
}
}

-template <typename T>
+template <typename T, typename AccT>
__global__ void GPUNLLLossForward1D_with_reduce(T* out_data,
T* total_weight_data,
const T* x_data,
@@ -67,7 +67,7 @@ __global__ void GPUNLLLossForward1D_with_reduce(T* out_data,
__shared__ T sharedInputs[NTHREADS], sharedWeights[NTHREADS];
sharedInputs[threadIdx.x] = 0;
sharedWeights[threadIdx.x] = 0;
-int i;
+int64_t i;
for (i = threadIdx.x; i < batch_size; i += NTHREADS) {
const auto cur_label = label_data[i];
if (cur_label != ignore_index) {
@@ -83,8 +83,8 @@ __global__ void GPUNLLLossForward1D_with_reduce(T* out_data,

if (threadIdx.x == 0) {
*out_data = *total_weight_data = 0;
-T output_val = 0;
-T total_weight_val = 0;
+AccT output_val = 0;
+AccT total_weight_val = 0;
for (i = 0; i < NTHREADS; ++i) {
output_val += sharedInputs[i];
total_weight_val += sharedWeights[i];
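The T to AccT switch above is an accumulation-precision change: when T is a 16-bit float, adding many small per-sample terms directly in T rounds the low-order bits away, so the partial sums are kept in a wider accumulator and converted back to T only when the result is stored. A sketch of the kind of trait AccT can come from, modeled on the "using AccT = typename phi::dtype::MPTypeTrait<T>::Type;" line in nll_loss_kernel.cu below; AccTypeTrait is an invented stand-in, not the Paddle header:

// Illustrative trait: half precision accumulates in float, wider types in themselves.
#include <cuda_fp16.h>

template <typename T>
struct AccTypeTrait {
  using Type = T;  // float, double: accumulate in the input type
};

template <>
struct AccTypeTrait<__half> {
  using Type = float;  // fp16 accumulators lose small terms quickly
};

// With such a trait, the partial sums take the shape of the change above:
//   typename AccTypeTrait<T>::Type output_val = 0;
//   typename AccTypeTrait<T>::Type total_weight_val = 0;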
@@ -106,15 +106,15 @@ __global__ void GPUNLLLossForward1D_with_reduce(T* out_data,
// call. However, if smem will be used, e.g., this function is called in a loop,
// then __syncthreads is needed either before or afterwards to prevent non-0
// threads overriding smem in the next loop before num-0 thread reads from it.
-template <typename T, typename ReduceOp, int N>
+template <typename T, typename ReduceOp, int64_t N>
__device__ void reduceNValuesInBlock(T* smem,
T threadVals[N],
const unsigned int numVals,
ReduceOp reduceOp,
T init) {
if (numVals == 0) {
#pragma unroll
-for (int i = 0; i < N; ++i) {
+for (int64_t i = 0; i < N; ++i) {
threadVals[i] = init;
}
return;
@@ -125,7 +125,7 @@ __device__ void reduceNValuesInBlock(T* smem,
// all of the values for the second threadVal for each thread in the block
if (threadIdx.x < numVals) {
#pragma unroll
-for (int i = 0; i < N; ++i) {
+for (int64_t i = 0; i < N; ++i) {
smem[i * numVals + threadIdx.x] = threadVals[i];
}
}
@@ -139,19 +139,19 @@ __device__ void reduceNValuesInBlock(T* smem,

if (numVals > warpSize && ((threadIdx.x / warpSize) == 0)) {
#pragma unroll
-for (int i = 0; i < N; ++i) {
+for (int64_t i = 0; i < N; ++i) {
threadVals[i] = threadIdx.x < numVals ? threadVals[i] : init;
}

-for (int i = warpSize + threadIdx.x; i < numVals; i += warpSize) {
+for (int64_t i = warpSize + threadIdx.x; i < numVals; i += warpSize) {
#pragma unroll
-for (int j = 0; j < N; ++j) {
+for (int64_t j = 0; j < N; ++j) {
threadVals[j] = reduceOp(threadVals[j], smem[j * numVals + i]);
}
}

#pragma unroll
-for (int i = 0; i < N; ++i) {
+for (int64_t i = 0; i < N; ++i) {
smem[i * numLanesParticipating + threadIdx.x] = threadVals[i];
}
}
@@ -160,16 +160,16 @@ __device__ void reduceNValuesInBlock(T* smem,
if (threadIdx.x == 0) {
if (numLanesParticipating == 32) {
#pragma unroll
-for (int i = 0; i < N; ++i) {
+for (int64_t i = 0; i < N; ++i) {
#pragma unroll
-for (int j = 1; j < 32; ++j) {
+for (int64_t j = 1; j < 32; ++j) {
threadVals[i] = reduceOp(threadVals[i], smem[i * 32 + j]);
}
}
} else {
#pragma unroll
-for (int i = 0; i < N; ++i) {
-for (int j = 1; j < numLanesParticipating; ++j) {
+for (int64_t i = 0; i < N; ++i) {
+for (int64_t j = 1; j < numLanesParticipating; ++j) {
threadVals[i] = reduceOp(threadVals[i], smem[i * numVals + j]);
}
}
@@ -228,7 +228,7 @@ __global__ void GPUNLLLossForward2D_no_reduce(T* out_data,
}
}

-template <typename T>
+template <typename T, typename AccT>
__global__ void GPUNLLLossForward2D_with_reduce(T* out_data,
T* total_weight_data,
const T* x_data,
@@ -239,10 +239,10 @@ __global__ void GPUNLLLossForward2D_with_reduce(T* out_data,
const int64_t map_nelem,
const int64_t blocks_per_sample,
const int64_t ignore_index) {
-__shared__ T partial_sums[kNumCUDAThreads];
+__shared__ AccT partial_sums[kNumCUDAThreads];
int64_t i;
-T input_sum = 0;
-T acc_weight = 0;
+AccT input_sum = 0;
+AccT acc_weight = 0;
*out_data = 0;
*total_weight_data = 0;

@@ -257,17 +257,17 @@ __global__ void GPUNLLLossForward2D_with_reduce(T* out_data,
if (cur_label != ignore_index) {
PADDLE_ENFORCE(cur_label >= 0 && cur_label < n_classes,
"label should not be out of bounds.");
-const T cur_weight = weight_data ? weight_data[cur_label] : (T)1;
+const AccT cur_weight = weight_data ? weight_data[cur_label] : (T)1;
input_sum -= x_data[ioffset + i + map_nelem * cur_label] * cur_weight;
acc_weight += cur_weight;
}
}

-input_sum =
-reduceBlock(partial_sums, blockDim.x, input_sum, thrust::plus<T>(), (T)0);
+input_sum = reduceBlock(
+partial_sums, blockDim.x, input_sum, thrust::plus<AccT>(), (AccT)0);
__syncthreads();
acc_weight = reduceBlock(
-partial_sums, blockDim.x, acc_weight, thrust::plus<T>(), (T)0);
+partial_sums, blockDim.x, acc_weight, thrust::plus<AccT>(), (AccT)0);

if (threadIdx.x == 0) {
phi::CudaAtomicAdd(total_weight_data, acc_weight);
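The two reduceBlock calls above share one partial_sums buffer, which is exactly why the __syncthreads() between them is required; it is the same rule stated in the comment before reduceNValuesInBlock. A self-contained sketch of that rule, with invented names (BlockSum, TwoSums) rather than Paddle code:

#include <cuda_runtime.h>

// Block-wide sum through one shared buffer; launch with <= 256 threads.
__device__ float BlockSum(float* smem, float val) {
  smem[threadIdx.x] = val;
  __syncthreads();
  if (threadIdx.x == 0) {
    float s = 0.f;
    for (unsigned i = 0; i < blockDim.x; ++i) s += smem[i];
    smem[0] = s;
  }
  __syncthreads();
  return smem[0];  // every thread reads the result after the barrier
}

__global__ void TwoSums(const float* a, const float* b, int n, float* out) {
  __shared__ float smem[256];
  float pa = 0.f, pb = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    pa += a[i];
    pb += b[i];
  }
  const float sum_a = BlockSum(smem, pa);
  __syncthreads();  // without this, a fast thread begins the second sum and
                    // overwrites smem before slower threads have read sum_a
  const float sum_b = BlockSum(smem, pb);
  if (threadIdx.x == 0) {
    out[0] = sum_a;
    out[1] = sum_b;
  }
}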
@@ -313,7 +313,7 @@ __global__ void GPUNLLLossBackward1D_with_reduce(T* dx_data,
if (*total_weight_data <= 0) {
return;
}
-int i;
+int64_t i;
const T norm = size_average ? (T)(1 / *total_weight_data) : (T)1;
for (i = threadIdx.x; i < batch_size; i += NTHREADS) {
const int64_t cur_label = label_data[i];
@@ -370,10 +370,10 @@ __global__ void GPUNLLLossBackward2D_with_reduce(
}
int64_t i;
const T norm = size_average ? (T)(1 / *total_weight_data) : (T)1;
-int sample = blockIdx.x / blocks_per_sample;
-int step = blockDim.x * blocks_per_sample;
-int toffset = sample * map_nelem;
-int ioffset = sample * map_nelem * n_classes;
+int64_t sample = blockIdx.x / blocks_per_sample;
+int64_t step = blockDim.x * blocks_per_sample;
+int64_t toffset = sample * map_nelem;
+int64_t ioffset = sample * map_nelem * n_classes;
for (i = (blockIdx.x % blocks_per_sample) * blockDim.x + threadIdx.x;
i < map_nelem;
i += step) {
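The 64-bit locals above matter because toffset and ioffset are products of per-sample sizes, and those products outgrow a 32-bit int well before the tensors become exotic. A standalone arithmetic check with made-up but realistic sizes (not values taken from this PR):

#include <climits>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t map_nelem = 4096LL * 4096LL;  // one 4096 x 4096 spatial map
  const int64_t n_classes = 150;
  const int64_t sample = 1;                   // second sample in the batch
  const int64_t ioffset = sample * map_nelem * n_classes;
  // 2'516'582'400 > INT_MAX (2'147'483'647): a 32-bit ioffset wraps, and the
  // kernel would read from the wrong (or an invalid) location.
  std::printf("ioffset = %lld, INT_MAX = %d\n",
              static_cast<long long>(ioffset), INT_MAX);
  return 0;
}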
paddle/phi/kernels/gpu/nll_loss_kernel.cu (15 changes: 8 additions & 7 deletions)
@@ -15,6 +15,7 @@
#include "paddle/phi/kernels/nll_loss_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/nll_loss.h"

@@ -43,10 +44,10 @@ void NllLossRawKernel(const Context& dev_ctx,
auto x_dims = x->dims();
auto batch_size = x_dims[0];
auto n_classes = x_dims[1];
-int64_t size_average = (int64_t)(reduction == "mean");
-
+int size_average = static_cast<int>(reduction == "mean");
+using AccT = typename phi::dtype::MPTypeTrait<T>::Type;
if (x_dims.size() == 2) {
-int blocks = NumBlocks(batch_size);
+int64_t blocks = NumBlocks(batch_size);
int threads = kNumCUDAThreads;
if (reduction == "none") {
GPUNLLLossForward1D_no_reduce<T>
@@ -58,7 +59,7 @@ void NllLossRawKernel(const Context& dev_ctx,
n_classes,
ignore_index);
} else {
-GPUNLLLossForward1D_with_reduce<T>
+GPUNLLLossForward1D_with_reduce<T, AccT>
<<<1, NTHREADS, 0, dev_ctx.stream()>>>(out_data,
total_weight_data,
x_data,
@@ -74,7 +75,7 @@ void NllLossRawKernel(const Context& dev_ctx,
const auto in_dim3 = x_dims[3];
const auto map_size = in_dim2 * in_dim3;
const auto out_numel = batch_size * in_dim2 * in_dim3;
-int blocks = NumBlocks(out_numel);
+int64_t blocks = NumBlocks(out_numel);
int threads = kNumCUDAThreads;
if (reduction == "none") {
GPUNLLLossForward2D_no_reduce<T>
@@ -90,8 +91,8 @@ void NllLossRawKernel(const Context& dev_ctx,
} else {
int blocks_per_sample = NumBlocks(map_size) / 128;
blocks_per_sample = (blocks_per_sample == 0) ? 1 : blocks_per_sample;
-int total_blocks = blocks_per_sample * batch_size;
-GPUNLLLossForward2D_with_reduce<T>
+int64_t total_blocks = blocks_per_sample * batch_size;
+GPUNLLLossForward2D_with_reduce<T, AccT>
<<<total_blocks, threads, 0, dev_ctx.stream()>>>(out_data,
total_weight_data,
x_data,
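For reference, the AccT alias above is what the newly included amp_type_traits.h provides. A compile-time check of the mapping this relies on; the exact specializations are an assumption about MPTypeTrait rather than something shown in this diff:

#include <type_traits>

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"

using AccForHalf = typename phi::dtype::MPTypeTrait<phi::dtype::float16>::Type;
using AccForFloat = typename phi::dtype::MPTypeTrait<float>::Type;

static_assert(std::is_same<AccForHalf, float>::value,
              "float16 is expected to accumulate in float");
static_assert(std::is_same<AccForFloat, float>::value,
              "float accumulates in float");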