Merged
12 changes: 6 additions & 6 deletions paddle/phi/kernels/fusion/cutlass/fused_conv2d_add_act_kernel.cu
@@ -29,7 +29,7 @@ namespace cutlass_internal {
typedef bool (*func)(phi::fusion::cutlass_internal::ConvAllParams);

template <typename T, typename Context>
void FusedConv2dAddActKernel(const Context& ctx,
void FusedConv2dAddActKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& filter,
const DenseTensor& bias,
@@ -47,7 +47,7 @@ void FusedConv2dAddActKernel(const Context& ctx,
float fuse_alpha,
DenseTensor* output,
std::vector<DenseTensor*> outputs) {
ctx.template Alloc<T>(output);
dev_ctx.template Alloc<T>(output);
auto in_dims = x.dims();
auto filter_dims = filter.dims();
auto out_dims = output->dims();
@@ -136,7 +136,7 @@ void FusedConv2dAddActKernel(const Context& ctx,
const int oh = out_dims[1];
const int ow = out_dims[2];

int64_t device_id = ctx.GetPlace().GetDeviceId();
int64_t device_id = dev_ctx.GetPlace().GetDeviceId();
int sm_version = backends::gpu::GetGPUComputeCapability(device_id);

auto get_conv2d_dtype = [&](decltype(x.dtype()) x_type)
@@ -190,7 +190,7 @@ void FusedConv2dAddActKernel(const Context& ctx,
oh,
ow,
groups,
ctx.stream(),
dev_ctx.stream(),
0, // alpha
cutlass_dispatch_sm_version(sm_version),
get_conv2d_dtype(x.dtype()),
@@ -207,9 +207,9 @@ void FusedConv2dAddActKernel(const Context& ctx,
if (groups == ic && ic == oc) {
// conv2d_depthwise need a tmp workspace.
phi::Allocator::AllocationPtr tmp_ptr = phi::memory_utils::Alloc(
ctx.GetPlace(),
dev_ctx.GetPlace(),
oc * kh * kw * sizeof(T),
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
params.workspace = tmp_ptr->ptr();
// cutlass conv2d_depthwise not support residual
if (residual) {
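Note on the hunk above: for the depthwise case the kernel hands CUTLASS a temporary workspace allocated with `phi::memory_utils::Alloc`, tied to the context's CUDA stream. Below is a minimal sketch of that allocation pattern, assuming the same phi allocator headers these kernels already include; `AllocDepthwiseWorkspace` is an illustrative name, and the caller must keep the returned allocation alive until the convolution that uses it has finished on `dev_ctx.stream()`.

```cpp
// Sketch only: stream-ordered scratch buffer of oc * kh * kw elements of T,
// allocated on the device that dev_ctx targets. The AllocationPtr owns the
// memory; release it only after the consuming kernel completes.
template <typename T, typename Context>
phi::Allocator::AllocationPtr AllocDepthwiseWorkspace(const Context& dev_ctx,
                                                      int oc, int kh, int kw) {
  return phi::memory_utils::Alloc(
      dev_ctx.GetPlace(),
      oc * kh * kw * sizeof(T),
      phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
}
```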
@@ -36,7 +36,7 @@ using gemm_kernel_utils::getMaximumSharedMemoryPerBlockKb;

template <typename T, typename Context>
void MemoryEfficientAttentionGradKernel(
const Context& ctx,
const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
@@ -292,7 +292,7 @@ void MemoryEfficientAttentionGradKernel(

auto use_dropout = dropout_p != 0.0;
const auto maxK = std::max(q_dims[3], v_dims[3]);
int compute_capacity = ctx.GetComputeCapability();
int compute_capacity = dev_ctx.GetComputeCapability();
const auto max_shmem =
getMaximumSharedMemoryPerBlockKb(compute_capacity) * 1024;
using KernelType = decltype(k_);
@@ -327,36 +327,37 @@
DenseTensor delta;
if (KernelType::kKernelComputesDelta) {
phi::EmptyKernel<float, Context>(
ctx,
dev_ctx,
{output.dims()[0], output.dims()[2], output.dims()[1]},
output.dtype(),
&delta);
} else {
DenseTensor output_grad_tmp =
output_grad.dtype() == DataType::FLOAT32
? output_grad
: phi::Cast<T, Context>(ctx, output_grad, DataType::FLOAT32);
: phi::Cast<T, Context>(dev_ctx, output_grad, DataType::FLOAT32);
DenseTensor output_tmp =
output.dtype() == DataType::FLOAT32
? output
: phi::Cast<T, Context>(ctx, output, DataType::FLOAT32);
: phi::Cast<T, Context>(dev_ctx, output, DataType::FLOAT32);
DenseTensor delta_mul =
phi::Multiply<float, Context>(ctx, output_grad_tmp, output_tmp);
phi::Multiply<float, Context>(dev_ctx, output_grad_tmp, output_tmp);

DenseTensor delta_sum;
phi::EmptyKernel<float, Context>(
ctx,
dev_ctx,
{delta_mul.dims()[0], delta_mul.dims()[1], delta_mul.dims()[2]},
DataType::FLOAT32,
&delta_sum);
phi::SumKernel<float, Context>(
ctx, delta_mul, {-1}, delta_mul.dtype(), false, &delta_sum);
dev_ctx, delta_mul, {-1}, delta_mul.dtype(), false, &delta_sum);
phi::EmptyKernel<float, Context>(
ctx,
dev_ctx,
{delta_mul.dims()[0], delta_mul.dims()[2], delta_mul.dims()[1]},
DataType::FLOAT32,
&delta);
phi::TransposeKernel<float, Context>(ctx, delta_sum, {0, 2, 1}, &delta);
phi::TransposeKernel<float, Context>(
dev_ctx, delta_sum, {0, 2, 1}, &delta);
}
VLOG(3) << "p.output" << output.dtype();
VLOG(3) << "p.output_grad" << output_grad.dtype();
@@ -399,7 +399,7 @@ void MemoryEfficientAttentionGradKernel(
bool force_pad_inf = (compute_capacity == 75);
const std::string data_format = "NCHW";
DenseTensor padded_lse =
phi::funcs::get_pad_lse<float>(ctx,
phi::funcs::get_pad_lse<float>(dev_ctx,
const_cast<DenseTensor*>(&logsumexp),
static_cast<int>(output.dims()[1]),
32,
@@ -412,24 +412,26 @@

if (!has_query_grad) {
dq_tmp.clear();
dq_tmp = EmptyLike<T, Context>(ctx, query);
dq_tmp = EmptyLike<T, Context>(dev_ctx, query);
query_grad = &dq_tmp;
}
p.grad_query_ptr = phi::SafeAllocTensor<scalar_t, Context>(ctx, query_grad);
p.grad_query_ptr =
phi::SafeAllocTensor<scalar_t, Context>(dev_ctx, query_grad);

if (!has_key_grad) {
dk_tmp.clear();
dk_tmp = EmptyLike<T, Context>(ctx, key);
dk_tmp = EmptyLike<T, Context>(dev_ctx, key);
key_grad = &dk_tmp;
}
p.grad_key_ptr = phi::SafeAllocTensor<scalar_t, Context>(ctx, key_grad);
p.grad_key_ptr = phi::SafeAllocTensor<scalar_t, Context>(dev_ctx, key_grad);

if (!has_value_grad) {
dv_tmp.clear();
dv_tmp = EmptyLike<T, Context>(ctx, value);
dv_tmp = EmptyLike<T, Context>(dev_ctx, value);
value_grad = &dv_tmp;
}
p.grad_value_ptr = phi::SafeAllocTensor<scalar_t, Context>(ctx, value_grad);
p.grad_value_ptr =
phi::SafeAllocTensor<scalar_t, Context>(dev_ctx, value_grad);

p.delta_ptr = phi::SafeGetTensorPtr<float>(delta);
PD_MEA_CHECK_OVERFLOW(p.head_dim, q_dims[3]);
@@ -522,7 +525,7 @@ void MemoryEfficientAttentionGradKernel(
VLOG(3) << "p.bias_ptr" << p.bias_ptr;
if (bias_grad) {
p.grad_bias_ptr =
phi::SafeAllocTensor<scalar_t, Context>(ctx, bias_grad);
phi::SafeAllocTensor<scalar_t, Context>(dev_ctx, bias_grad);
PD_MEA_CHECK_OVERFLOW(p.gB_strideB, q_dims[2] * q_dims[1] * k_dims[1]);
PD_MEA_CHECK_OVERFLOW(p.gB_strideH, q_dims[1] * k_dims[1]);
PD_MEA_CHECK_OVERFLOW(p.gB_strideM, k_dims[1]);
@@ -549,9 +552,9 @@
phi::Allocator::AllocationPtr temp_workspace{nullptr};
VLOG(3) << "size_bytes " << size_bytes;
temp_workspace = phi::memory_utils::Alloc(
ctx.GetPlace(),
dev_ctx.GetPlace(),
size_bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream())));
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
if (size_bytes) {
p.workspace = reinterpret_cast<typename KernelType::output_accum_t*>(
temp_workspace->ptr());
@@ -574,9 +577,9 @@
kernel_fn<<<p.getBlocksGrid(),
p.getThreadsGrid(),
smem_bytes,
ctx.stream()>>>(p);
dev_ctx.stream()>>>(p);
};
dispatch_cutlass_backward<T>(ctx, launchKernel);
dispatch_cutlass_backward<T>(dev_ctx, launchKernel);
PADDLE_ENFORCE_EQ(
kernel_launched,
true,
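Note on the backward launch above: the `launchKernel` lambda configures the grid from the CUTLASS parameter struct and submits on `dev_ctx.stream()`; `dispatch_cutlass_backward` then decides which instantiation to hand that lambda. A hedged sketch of just the launch step follows; `LaunchOnContextStream` is an illustrative name, `getBlocksGrid`/`getThreadsGrid` are the methods used above, and the 48 KB opt-in reflects the general CUDA requirement for large dynamic shared memory rather than a line from this diff.

```cpp
// Sketch: launch a __global__ kernel on the phi device context's stream with
// the requested dynamic shared-memory size.
template <typename Params, typename Context>
void LaunchOnContextStream(void (*kernel_fn)(Params), Params p,
                           int smem_bytes, const Context& dev_ctx) {
  if (smem_bytes > 48 * 1024) {
    // Kernels needing more than 48 KB of dynamic shared memory must opt in.
    cudaFuncSetAttribute(reinterpret_cast<const void*>(kernel_fn),
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         smem_bytes);
  }
  kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes,
              dev_ctx.stream()>>>(p);
}
```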
@@ -29,7 +29,7 @@ using gemm_kernel_utils::getMaximumSharedMemoryPerBlockKb;

template <typename T, typename Context>
void MemoryEfficientAttentionForwardKernel(
const Context& ctx,
const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
@@ -47,7 +47,7 @@ void MemoryEfficientAttentionForwardKernel(
DenseTensor* output,
DenseTensor* logsumexp,
DenseTensor* seed_and_offset) {
int compute_capacity = ctx.GetComputeCapability();
int compute_capacity = dev_ctx.GetComputeCapability();
const auto max_shmem =
getMaximumSharedMemoryPerBlockKb(compute_capacity) * 1024;
bool kernel_launched = false;
@@ -122,7 +122,7 @@ void MemoryEfficientAttentionForwardKernel(
is_test ? 0 : (max_seqlen_q_tmp + kAlignLSE - 1) / kAlignLSE;
logsumexp_dims[2] *= kAlignLSE;
logsumexp->Resize(logsumexp_dims);
ctx.template Alloc<float>(logsumexp);
dev_ctx.template Alloc<float>(logsumexp);
VLOG(3) << "logsumexp dims" << logsumexp_dims;
VLOG(3) << "logsumexp" << logsumexp;
VLOG(3) << "kAlignLSE" << kAlignLSE;
@@ -139,13 +139,13 @@ void MemoryEfficientAttentionForwardKernel(
out_accum.Resize(output->dims());
p.output_accum_ptr =
phi::SafeAllocTensor<typename KernelType::output_accum_t, Context>(
ctx, &out_accum);
dev_ctx, &out_accum);
VLOG(3) << "output_accum_ptr " << p.output_accum_ptr;
} else {
p.output_accum_ptr = nullptr;
}
p.output_ptr = phi::SafeAllocTensor<typename KernelType::output_t, Context>(
ctx, output);
dev_ctx, output);
VLOG(3) << "output_ptr " << p.output_ptr;

if (cu_seqlens_q) {
@@ -221,11 +221,11 @@ void MemoryEfficientAttentionForwardKernel(
phi::Dim<1> seed_dims;
seed_dims[0] = 2;
seed_and_offset->Resize(seed_dims);
ctx.template HostAlloc<int64_t>(seed_and_offset);
dev_ctx.template HostAlloc<int64_t>(seed_and_offset);
int64_t* seed_and_offset_ptr =
phi::SafeGetTensorPtr<int64_t>(seed_and_offset);

auto gen = ctx.GetGenerator();
auto gen = dev_ctx.GetGenerator();
uint64_t inc = query.dims()[0] * query.dims()[2] * 32;
auto seed_offset_pair = gen->IncrementOffset(inc);
auto seed = (seed_offset_pair.first);
@@ -259,9 +259,9 @@ void MemoryEfficientAttentionForwardKernel(
kernel_fn<<<p.getBlocksGrid(),
p.getThreadsGrid(),
smem_bytes,
ctx.stream()>>>(p);
dev_ctx.stream()>>>(p);
};
dispatch_cutlass_forward<T>(ctx, launchKernel);
dispatch_cutlass_forward<T>(dev_ctx, launchKernel);
PADDLE_ENFORCE_EQ(
kernel_launched,
true,
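Note on the forward hunks above: the logsumexp extent is padded up to a multiple of `kAlignLSE` (and set to 0 entirely when `is_test` is true). The rounding is plain ceil-to-multiple integer arithmetic; a tiny sketch with an illustrative helper name:

```cpp
// (n + align - 1) / align rounds up in integer division; multiplying back by
// align gives the padded extent. Example: n = 100, align = 32 -> 128.
inline int64_t RoundUpToMultiple(int64_t n, int64_t align) {
  return (n + align - 1) / align * align;
}
```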
@@ -21,7 +21,7 @@ namespace fusion {

template <typename T, typename Context>
void MultiHeadAttentionVariableForwardKernel(
const Context& ctx,
const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
@@ -19,7 +19,7 @@ namespace fusion {

template <typename T, typename Context>
void MultiHeadAttentionVariableForwardKernel(
const Context& ctx,
const Context& dev_ctx,
const DenseTensor& query,
const DenseTensor& key,
const DenseTensor& value,
@@ -30,7 +30,7 @@ void MultiHeadAttentionVariableForwardKernel(
const bool causal,
const int pre_cache_length,
DenseTensor* output) {
ctx.template Alloc<T>(output);
dev_ctx.template Alloc<T>(output);
Params params{};
// [B, N, S, H]
params.seq_lens = seq_lens.data<int>();
@@ -109,9 +109,9 @@
return;
}
kernel_launched = true;
kernel_fn(k_, params, ctx);
kernel_fn(k_, params, dev_ctx);
};
dispatch_cutlass_forward<T, decltype(launchKernel)>(ctx, launchKernel);
dispatch_cutlass_forward<T, decltype(launchKernel)>(dev_ctx, launchKernel);
PADDLE_ENFORCE_EQ(
kernel_launched,
true,