1414
1515#pragma once
1616
17+ #include < cublasXt.h>
1718#include < cublas_v2.h>
1819#include < cuda.h>
1920#include < dlfcn.h>
2021#include < mutex> // NOLINT
22+ #include < type_traits>
2123#include " paddle/fluid/platform/dynload/dynamic_loader.h"
2224
2325namespace paddle {
@@ -37,14 +39,14 @@ extern void *cublas_dso_handle;
3739#ifdef PADDLE_USE_DSO
// Declares a callable wrapper object named `__name` that forwards to the
// cuBLAS routine of the same name, with the routine resolved at runtime via
// dlsym instead of being linked directly.
//
// The generated struct DynLoad__<name>:
//   * sets cublas_dso_handle from GetCublasDsoHandle() at most once, guarded
//     by cublas_dso_flag;
//   * looks the symbol up with dlsym and casts it to FUNC_TYPE, the exact
//     signature taken from the routine's declaration (decltype(&::__name)),
//     so arguments are converted to the real parameter types;
//   * forwards the caller's arguments and returns the cublasStatus_t result.
//
// The dlsym result is kept in a function-local static, so the lookup runs
// once per operator() instantiation rather than on every call.
//
// NOTE(review): the dlsym result is not checked; if the symbol is absent the
// wrapper calls through a null pointer. Confirm whether callers rely on that.
//
// Only /* */ comments may appear inside the macro body: a // comment would
// swallow the continuation lines that follow it.
#define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name)                             \
  struct DynLoad__##__name {                                                 \
    using FUNC_TYPE = decltype(&::__name);                                   \
    template <typename... Args>                                              \
    inline cublasStatus_t operator()(Args... args) {                         \
      std::call_once(cublas_dso_flag, []() {                                 \
        cublas_dso_handle = paddle::platform::dynload::GetCublasDsoHandle(); \
      });                                                                    \
      /* Initialised after call_once, so the handle is already set. */       \
      static void *p_##__name = dlsym(cublas_dso_handle, #__name);           \
      return reinterpret_cast<FUNC_TYPE>(p_##__name)(args...);               \
    }                                                                        \
  };                                                                         \
  extern DynLoad__##__name __name
@@ -71,8 +73,8 @@ extern void *cublas_dso_handle;
7173 __macro (cublasDgemm_v2); \
7274 __macro (cublasHgemm); \
7375 __macro (cublasSgemmEx); \
74- __macro (cublasSgeam_v2); \
75- __macro (cublasDgeam_v2); \
76+ __macro (cublasSgeam); \
77+ __macro (cublasDgeam); \
7678 __macro (cublasCreate_v2); \
7779 __macro (cublasDestroy_v2); \
7880 __macro (cublasSetStream_v2); \
0 commit comments