Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 24 additions & 25 deletions paddle/phi/kernels/legacy/gpu/fp8_gemm_blockwise_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ cudaDataType_t ScalarTypeToCudaDataType(phi::DataType dtype) {
} \
} while (0)

void cublas_gemm_blockwise_impl(const DenseTensor& A,
template <typename Context>
void cublas_gemm_blockwise_impl(const Context& dev_ctx,
const DenseTensor& A,
const DenseTensor& A_decode_scale,
const DenseTensor& B,
const DenseTensor& B_decode_scale,
Expand Down Expand Up @@ -148,15 +150,13 @@ void cublas_gemm_blockwise_impl(const DenseTensor& A,
int lda = k, ldb = k, ldc = m, ldd = m;
float alpha = 1.0, beta = accumulate ? 1.0 : 0.0;

cublasLtHandle_t ltHandle;
PADDLE_CUDABLAS_CHECK(phi::dynload::cublasLtCreate(&ltHandle));

cublasLtHandle_t ltHandle = dev_ctx.cublaslt_handle();
// Create operation descriptor
cublasLtMatmulDesc_t operationDesc = nullptr;
PADDLE_CUDABLAS_CHECK(phi::dynload::cublasLtMatmulDescCreate(
&operationDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F));

#if CUBLAS_VERSION >= 120804 && CUDA_VERSION >= 12060
#if CUBLAS_VERSION >= 120805 && CUDA_VERSION >= 12080
// Setup scaling for A and B
cublasLtMatmulMatrixScale_t A_scale_mode, B_scale_mode;
// Note: in cuBLAS term, tensor name A and B are swapped.
Expand Down Expand Up @@ -187,7 +187,7 @@ void cublas_gemm_blockwise_impl(const DenseTensor& A,
sizeof(B_scale_mode)));
#else
PADDLE_THROW(phi::errors::InvalidArgument(
"Sub-channel FP8 GEMM requires CUDA 12.8 and cuBLAS 12.8.4 or later."));
"Sub-channel FP8 GEMM requires CUDA 12.8 and cuBLAS 12.8.5 or later."));
#endif

// setup transa and transb
Expand Down Expand Up @@ -285,7 +285,6 @@ void cublas_gemm_blockwise_impl(const DenseTensor& A,
workspace->data(),
workspace_size,
stream));

// Cleanup
if (preference)
PADDLE_CUDABLAS_CHECK(
Expand All @@ -301,7 +300,6 @@ void cublas_gemm_blockwise_impl(const DenseTensor& A,
if (operationDesc)
PADDLE_CUDABLAS_CHECK(
phi::dynload::cublasLtMatmulDescDestroy(operationDesc));
if (ltHandle) PADDLE_CUDABLAS_CHECK(phi::dynload::cublasLtDestroy(ltHandle));
}

} // anonymous namespace
Expand All @@ -327,23 +325,24 @@ void Fp8GemmBlockwiseKernel(const Context& dev_ctx,
DenseTensor* output,
DenseTensor* pre_gelu_out,
DenseTensor* workspace_out) {
cublas_gemm_blockwise_impl(A,
A_scale,
B,
B_scale,
output,
bias,
pre_gelu_out,
transa,
transb,
grad,
workspace_out,
accumulate,
use_split_accumulator,
math_sm_count,
is_A_1d_scaled,
is_B_1d_scaled,
dev_ctx.stream());
cublas_gemm_blockwise_impl<Context>(dev_ctx,
A,
A_scale,
B,
B_scale,
output,
bias,
pre_gelu_out,
transa,
transb,
grad,
workspace_out,
accumulate,
use_split_accumulator,
math_sm_count,
is_A_1d_scaled,
is_B_1d_scaled,
dev_ctx.stream());
}

} // namespace phi
Expand Down
Loading