Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ void CublasLtMatmulFP8(const phi::GPUContext& dev_ctx,

template <typename Context>
void cublaslt_fp8_fp8_fp16_gemm(
const Context& ctx,
const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const paddle::optional<DenseTensor>& bias,
Expand Down Expand Up @@ -339,18 +339,18 @@ void cublaslt_fp8_fp8_fp16_gemm(
common::errors::InvalidArgument(
"FP8 gemm need k % 16 = 0, but k = %d", k));

ctx.template Alloc<phi::dtype::float16>(out);
dev_ctx.template Alloc<phi::dtype::float16>(out);
int batch_count = 1;
for (size_t i = 0; i < rank - 2; ++i) {
batch_count *= x.dims()[i];
}
CublasLtMatmulFP8<phi::dtype::float16>(
ctx, batch_count, m, n, k, x, y, scale, bias, activation_type, out);
dev_ctx, batch_count, m, n, k, x, y, scale, bias, activation_type, out);
}

template <typename Context>
void cublaslt_fp8_fp8_bf16_gemm(
const Context& ctx,
const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const paddle::optional<DenseTensor>& bias,
Expand Down Expand Up @@ -396,13 +396,13 @@ void cublaslt_fp8_fp8_bf16_gemm(
common::errors::InvalidArgument(
"FP8 gemm need k % 16 = 0, but k = %d", k));

ctx.template Alloc<phi::dtype::bfloat16>(out);
dev_ctx.template Alloc<phi::dtype::bfloat16>(out);
int batch_count = 1;
for (size_t i = 0; i < rank - 2; ++i) {
batch_count *= x.dims()[i];
}
CublasLtMatmulFP8<phi::dtype::bfloat16>(
ctx, batch_count, m, n, k, x, y, scale, bias, activation_type, out);
dev_ctx, batch_count, m, n, k, x, y, scale, bias, activation_type, out);
}

} // namespace cutlass_internal
Expand Down
Loading