30 changes: 15 additions & 15 deletions paddle/phi/kernels/funcs/gather.cu.h
@@ -77,11 +77,11 @@ __global__ void GatherNdCUDAKernel(const T* input,
 }
 
 template <typename T, typename IndexT = int>
-void GPUGatherNd(const phi::GPUContext& ctx,
+void GPUGatherNd(const phi::GPUContext& dev_ctx,
                  const DenseTensor& input,
                  const DenseTensor& index,
                  DenseTensor* output) {
-  const auto gplace = ctx.GetPlace();
+  const auto gplace = dev_ctx.GetPlace();
   auto cplace = phi::CPUPlace();
 
   auto index_dims = index.dims();
@@ -118,9 +118,9 @@ void GPUGatherNd(const phi::GPUContext& ctx,
 
   constexpr int loop_count = 4;
   auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
-      ctx, remain_numel * slice_size, vec_size * loop_count);
+      dev_ctx, remain_numel * slice_size, vec_size * loop_count);
 
-  auto stream = ctx.stream();
+  auto stream = dev_ctx.stream();
 
   switch (vec_size) {
 #define CASE_VEC_SIZE(__Sz)                                                 \
@@ -217,7 +217,7 @@ void GatherV2CUDAFunction(const DenseTensor* input,
                           const DenseTensor* index,
                           const int axis,
                           DenseTensor* out,
-                          const phi::GPUContext& ctx) {
+                          const phi::GPUContext& dev_ctx) {
   int64_t index_size = index->numel();
   int64_t input_size = input->numel();
   auto input_dim = input->dims();
@@ -245,7 +245,7 @@ void GatherV2CUDAFunction(const DenseTensor* input,
   auto out_dim = common::make_ddim(out_dim_vec);
 
   out->Resize(out_dim);
-  auto* out_data = ctx.Alloc<T>(out);
+  auto* out_data = dev_ctx.Alloc<T>(out);
   int64_t out_size = out->numel();
   if (out_size == 0) return;
 
@@ -258,8 +258,8 @@ void GatherV2CUDAFunction(const DenseTensor* input,
 
   constexpr int loop_count = 4;
   auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
-      ctx, out_size, vec_size * loop_count);
-  auto stream = ctx.stream();
+      dev_ctx, out_size, vec_size * loop_count);
+  auto stream = dev_ctx.stream();
   switch (vec_size) {
 #define CASE_VEC_SIZE(__Sz)                                                 \
   case __Sz:                                                                \
@@ -292,19 +292,19 @@ void GatherV2CUDAFunction(const DenseTensor* input,
  * return: output tensor
  */
 template <typename T, typename IndexT = int>
-void GPUGather(const phi::GPUContext& ctx,
+void GPUGather(const phi::GPUContext& dev_ctx,
                const DenseTensor& src,
                const DenseTensor& index,
                DenseTensor* output) {
-  GatherV2CUDAFunction<T, IndexT>(&src, &index, /* axis= */ 0, output, ctx);
+  GatherV2CUDAFunction<T, IndexT>(&src, &index, /* axis= */ 0, output, dev_ctx);
 }
 
 template <typename T, typename U>
 void GatherV2GradCUDAFunction(const DenseTensor* input,
                               const DenseTensor* index,
                               const int axis,
                               DenseTensor* out,
-                              const phi::GPUContext& ctx) {
+                              const phi::GPUContext& dev_ctx) {
   auto* index_data = index->data<U>();
   int64_t index_size = index->numel();
   int64_t input_size = input->numel();
@@ -326,13 +326,13 @@ void GatherV2GradCUDAFunction(const DenseTensor* input,
     outer_dim_size *= input_dim[i];
   }
 
-  auto* out_data = ctx.Alloc<T>(out);
+  auto* out_data = dev_ctx.Alloc<T>(out);
   auto out_dim = out->dims();
   int64_t out_index_dim_size = out_dim[axis_index];
-  phi::funcs::set_constant(ctx, out, static_cast<float>(0.0));
+  phi::funcs::set_constant(dev_ctx, out, static_cast<float>(0.0));
 
-  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, input_size);
-  auto stream = ctx.stream();
+  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, input_size);
+  auto stream = dev_ctx.stream();
   GatherGradGPUKernel<T, U>
       <<<config.block_per_grid, config.thread_per_block, 0, stream>>>(
           input_data,
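Note: the gather changes above are a pure rename of the device-context parameter (`ctx` → `dev_ctx`); signatures and launch logic are otherwise untouched. As a minimal caller sketch (not part of this diff; the kernel name and the `int` index type are illustrative assumptions), a phi GPU kernel would invoke the renamed helper like this:

```cpp
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"

namespace phi {

// Hypothetical caller: gathers rows of x along axis 0 as selected by index.
// Per the diff above, GPUGather resizes and allocates `out` internally
// through dev_ctx.Alloc<T>, so the caller only passes the output tensor.
template <typename T>
void ExampleGatherKernel(const phi::GPUContext& dev_ctx,
                         const DenseTensor& x,
                         const DenseTensor& index,
                         DenseTensor* out) {
  // After this PR the context argument is consistently named dev_ctx.
  phi::funcs::GPUGather<T, int>(dev_ctx, x, index, out);
}

}  // namespace phi
```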
39 changes: 21 additions & 18 deletions paddle/phi/kernels/funcs/scatter.cu.h
@@ -159,7 +159,7 @@ __global__ void ScatterNdCUDAKernel(const T* update,
  * return: output tensor
  */
 template <typename T, typename IndexT = int>
-void GPUScatterAssign(const phi::GPUContext& ctx,
+void GPUScatterAssign(const phi::GPUContext& dev_ctx,
                       const DenseTensor& src,
                       const DenseTensor& index,
                       DenseTensor* output,
@@ -204,15 +204,16 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
   int block = 512;
   int64_t n = slice_size * index_size;
   dim3 grid = dim3((n + block - 1) / block);
-  phi::backends::gpu::LimitGridDim(ctx, &grid);
+  phi::backends::gpu::LimitGridDim(dev_ctx, &grid);
 
   // if not overwrite mode, init data
   if (!overwrite) {
-    ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
+    ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, dev_ctx.stream()>>>(
         p_index, p_output, output_dims[0], index_size, slice_size);
 
-    ScatterCUDAKernel<T, IndexT, false, 1><<<grid, block, 0, ctx.stream()>>>(
-        p_src, p_index, p_output, output_dims[0], index_size, slice_size);
+    ScatterCUDAKernel<T, IndexT, false, 1>
+        <<<grid, block, 0, dev_ctx.stream()>>>(
+            p_src, p_index, p_output, output_dims[0], index_size, slice_size);
     return;
   }
 
@@ -225,14 +225,16 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
   }
 
   constexpr int loop_count = 4;
-  auto config =
-      phi::backends::gpu::GetGpuLaunchConfig1D(ctx, n, vec_size * loop_count);
+  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+      dev_ctx, n, vec_size * loop_count);
   switch (vec_size) {
-#define CASE_VEC_SIZE(__Sz)                                                    \
-  case __Sz:                                                                   \
-    ScatterCUDAKernel<T, IndexT, true, __Sz>                                   \
-        <<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( \
-            p_src, p_index, p_output, output_dims[0], index_size, slice_size); \
+#define CASE_VEC_SIZE(__Sz)                                                  \
+  case __Sz:                                                                 \
+    ScatterCUDAKernel<T, IndexT, true, __Sz><<<config.block_per_grid,        \
+                                               config.thread_per_block,      \
+                                               0,                            \
+                                               dev_ctx.stream()>>>(          \
+        p_src, p_index, p_output, output_dims[0], index_size, slice_size);   \
   break
     CASE_VEC_SIZE(8);
     CASE_VEC_SIZE(4);
@@ -248,7 +248,7 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
 // The function is only for scatter grad x,
 // however update grad use gather
 template <typename T, typename IndexT = int>
-void GPUScatterGradForX(const phi::GPUContext& ctx,
+void GPUScatterGradForX(const phi::GPUContext& dev_ctx,
                         const DenseTensor& index,
                         DenseTensor* output) {
   int64_t index_size = index.dims().size() == 0 ? 1 : index.dims()[0];
@@ -265,14 +265,14 @@ void GPUScatterGradForX(const phi::GPUContext& ctx,
   int64_t n = slice_size * index_size;
   int64_t height = (n + block - 1) / block;
   dim3 grid = dim3((n + block - 1) / block);
-  phi::backends::gpu::LimitGridDim(ctx, &grid);
+  phi::backends::gpu::LimitGridDim(dev_ctx, &grid);
 
-  ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
+  ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, dev_ctx.stream()>>>(
       p_index, p_output, dst_dims[0], index_size, slice_size);
 }
 
 template <typename T, typename IndexT = int>
-void GPUScatterNdAdd(const phi::GPUContext& ctx,
+void GPUScatterNdAdd(const phi::GPUContext& dev_ctx,
                      const DenseTensor& update,
                      const DenseTensor& index,
                      DenseTensor* output) {
@@ -312,9 +312,9 @@ void GPUScatterNdAdd(const phi::GPUContext& ctx,
 
   constexpr int loop_count = 4;
   auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
-      ctx, remain_numel * slice_size, vec_size * loop_count);
+      dev_ctx, remain_numel * slice_size, vec_size * loop_count);
 
-  auto stream = ctx.stream();
+  auto stream = dev_ctx.stream();
   switch (vec_size) {
 #define CASE_VEC_SIZE(__Sz)                                                 \
   case __Sz:                                                                \
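Note: the reflowed `CASE_VEC_SIZE` macros above all follow the same pattern: a `switch` maps the runtime `vec_size` onto a compile-time template parameter, so each vector width gets its own kernel instantiation. A standalone sketch of that dispatch pattern (toy kernel and simplified launch math, assumptions only; this is not Paddle's actual kernel):

```cpp
#include <cuda_runtime.h>

// Toy kernel: each thread copies VecSize consecutive elements.
template <typename T, int VecSize>
__global__ void DemoCopyKernel(const T* in, T* out, int n) {
  int base = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
  for (int v = 0; v < VecSize; ++v) {
    if (base + v < n) out[base + v] = in[base + v];
  }
}

// Runtime vec_size -> compile-time VecSize via a local macro, mirroring the
// CASE_VEC_SIZE dispatch in gather.cu.h and scatter.cu.h.
template <typename T>
void LaunchDemoCopy(const T* in, T* out, int n, int vec_size,
                    cudaStream_t stream) {
  const int block = 256;
  switch (vec_size) {
#define CASE_VEC_SIZE(__Sz)                                          \
  case __Sz: {                                                       \
    int grid = (n + block * __Sz - 1) / (block * __Sz);              \
    DemoCopyKernel<T, __Sz><<<grid, block, 0, stream>>>(in, out, n); \
  } break
    CASE_VEC_SIZE(8);
    CASE_VEC_SIZE(4);
    CASE_VEC_SIZE(2);
    CASE_VEC_SIZE(1);
#undef CASE_VEC_SIZE
    default:  // unsupported width: fall back to the scalar instantiation
      DemoCopyKernel<T, 1><<<(n + block - 1) / block, block, 0, stream>>>(
          in, out, n);
  }
}
```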
6 changes: 3 additions & 3 deletions paddle/phi/kernels/funcs/search_compute.h
@@ -59,7 +59,7 @@ void call_gemm(const Context& dev_ctx,
                T* C) {
   int lda = (TransA == CblasNoTrans) ? K : M;
   int ldb = (TransB == CblasNoTrans) ? N : K;
-  // auto& dev_ctx = ctx.template device_context<phi::CPUContext>();
+  // auto& dev_ctx = dev_ctx.template device_context<phi::CPUContext>();
   auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx);
   blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N);
 }
@@ -83,7 +83,7 @@ void call_gemm_with_lda(const phi::funcs::BlasT<DeviceContext, T>& blas,
 }
 
 template <typename T, typename Context>
-void call_gemm_batched(const Context& ctx,
+void call_gemm_batched(const Context& dev_ctx,
                        const CBLAS_TRANSPOSE TransA,
                        const CBLAS_TRANSPOSE TransB,
                        const int M,
@@ -96,7 +96,7 @@ void call_gemm_batched(const Context& ctx,
                        T** C,
                        const int batch) {
   for (int i = 0; i < batch; ++i) {
-    call_gemm(ctx, TransA, TransB, M, N, K, alpha, A[i], B[i], beta, C[i]);
+    call_gemm(dev_ctx, TransA, TransB, M, N, K, alpha, A[i], B[i], beta, C[i]);
   }
 }
 
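Note: `call_gemm_batched` simply forwards to `call_gemm` once per batch entry, so the rename keeps its context parameter consistent with the call sites above. A hedged usage sketch (the wrapper name is hypothetical; it assumes the header's enclosing namespace is in scope and a live `phi::CPUContext` is available, and the pointer-array parameter types follow the loop body shown in the diff):

```cpp
// Computes C[i] = alpha * A[i] * B[i] + beta * C[i] for i in [0, batch),
// where each A[i] is M x K and each B[i] is K x N (row-major).
template <typename T>
void BatchedGemmExample(const phi::CPUContext& dev_ctx,
                        const T** A,
                        const T** B,
                        T** C,
                        int M,
                        int N,
                        int K,
                        int batch) {
  // Plain GEMMs with alpha = 1, beta = 0, no transposition.
  call_gemm_batched<T, phi::CPUContext>(dev_ctx,
                                        CblasNoTrans,
                                        CblasNoTrans,
                                        M,
                                        N,
                                        K,
                                        static_cast<T>(1),
                                        A,
                                        B,
                                        static_cast<T>(0),
                                        C,
                                        batch);
}
```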