Skip to content

Commit d48d3a3

Browse files
authored
Fix cant find nv/targe bug (#73360)
* refine cpp extension * fix can't find nv/target bug * fix can't find nv/target bug
1 parent 68d0bb6 commit d48d3a3

File tree

5 files changed

+45
-1
lines changed

5 files changed

+45
-1
lines changed

paddle/cinn/backends/nvrtc/nvrtc_util.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ PD_DECLARE_string(cinn_nvcc_cmd_path);
3535
PD_DECLARE_string(nvidia_package_dir);
3636
PD_DECLARE_bool(nvrtc_compile_to_cubin);
3737
PD_DECLARE_bool(cinn_nvrtc_cubin_with_fmad);
38+
PD_DECLARE_string(cuda_cccl_dir);
3839

3940
namespace cinn {
4041
namespace backends {
@@ -50,7 +51,8 @@ static std::vector<std::string> GetNvidiaAllIncludePath(
5051
std::vector<std::string> include_paths;
5152
const std::string delimiter = "/";
5253
// Expand this list if necessary.
53-
const std::vector<std::string> sub_modules = {"cublas",
54+
const std::vector<std::string> sub_modules = {"cuda_cccl",
55+
"cublas",
5456
"cudnn",
5557
"cufft",
5658
"cusparse",

paddle/common/flags.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1798,6 +1798,11 @@ PHI_DEFINE_EXPORTED_string(
17981798
"Specify root dir path for nvidia site-package, such as "
17991799
"python3.9/site-packages/nvidia");
18001800

1801+
PHI_DEFINE_EXPORTED_string(cuda_cccl_dir, // NOLINT
1802+
"",
1803+
"Specify root dir path for nv/target, such as "
1804+
"python3.9/site-packages/nvidia/cuda_cccl/include/");
1805+
18011806
PHI_DEFINE_EXPORTED_string(
18021807
cudnn_dir, // NOLINT
18031808
"",

python/paddle/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,10 @@
648648
cupti_dir_lib_path = package_dir + "/.." + "/nvidia/cuda_cupti/lib"
649649
set_flags({"FLAGS_cupti_dir": cupti_dir_lib_path})
650650

651+
if is_compiled_with_cinn():
652+
cuda_cccl_path = package_dir + "/.." + "/nvidia/cuda_cccl/include/"
653+
set_flags({"FLAGS_cuda_cccl_dir": cuda_cccl_path})
654+
651655
elif (
652656
platform.system() == 'Windows'
653657
and platform.machine() in ('x86_64', 'AMD64')

python/setup.py.in

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,22 @@ def get_paddle_extra_install_requirements():
656656
"nvidia-cufile-cu12==1.14.0.30; platform_system == 'Linux' and platform_machine == 'x86_64'"
657657
),
658658
}
659+
if '@WITH_CINN@' == 'ON':
660+
PADDLE_CUDA_INSTALL_REQUIREMENTS["12.3"] += (
661+
" | nvidia-cuda-cccl-cu12==12.3.52;platform_system == 'Linux' and platform_machine == 'x86_64' "
662+
)
663+
PADDLE_CUDA_INSTALL_REQUIREMENTS["12.4"] += (
664+
" | nvidia-cuda-cccl-cu12==12.4.99;platform_system == 'Linux' and platform_machine == 'x86_64' "
665+
)
666+
PADDLE_CUDA_INSTALL_REQUIREMENTS["12.6"] += (
667+
" | nvidia-cuda-cccl-cu12==12.6.77;platform_system == 'Linux' and platform_machine == 'x86_64' "
668+
)
669+
PADDLE_CUDA_INSTALL_REQUIREMENTS["12.8"] += (
670+
" | nvidia-cuda-cccl-cu12==12.8.90;platform_system == 'Linux' and platform_machine == 'x86_64' "
671+
)
672+
PADDLE_CUDA_INSTALL_REQUIREMENTS["12.9"] += (
673+
" | nvidia-cuda-cccl-cu12==12.9.27;platform_system == 'Linux' and platform_machine == 'x86_64' "
674+
)
659675
elif platform.system() == 'Windows':
660676
PADDLE_CUDA_INSTALL_REQUIREMENTS = {
661677
"11.8": (

setup.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,6 +1164,23 @@ def get_paddle_extra_install_requirements():
11641164
"nvidia-cufile-cu12==1.14.0.30; platform_system == 'Linux' and platform_machine == 'x86_64'"
11651165
),
11661166
}
1167+
if env_dict.get("WITH_CINN") == "ON":
1168+
PADDLE_CUDA_INSTALL_REQUIREMENTS[
1169+
"12.3"
1170+
] += " | nvidia-cuda-cccl-cu12==12.3.52;platform_system == 'Linux' and platform_machine == 'x86_64' "
1171+
PADDLE_CUDA_INSTALL_REQUIREMENTS[
1172+
"12.4"
1173+
] += " | nvidia-cuda-cccl-cu12==12.4.99;platform_system == 'Linux' and platform_machine == 'x86_64' "
1174+
PADDLE_CUDA_INSTALL_REQUIREMENTS[
1175+
"12.6"
1176+
] += " | nvidia-cuda-cccl-cu12==12.6.77;platform_system == 'Linux' and platform_machine == 'x86_64' "
1177+
PADDLE_CUDA_INSTALL_REQUIREMENTS[
1178+
"12.8"
1179+
] += " | nvidia-cuda-cccl-cu12==12.8.90;platform_system == 'Linux' and platform_machine == 'x86_64' "
1180+
PADDLE_CUDA_INSTALL_REQUIREMENTS[
1181+
"12.9"
1182+
] += " | nvidia-cuda-cccl-cu12==12.9.27;platform_system == 'Linux' and platform_machine == 'x86_64' "
1183+
11671184
elif platform.system() == 'Windows':
11681185
PADDLE_CUDA_INSTALL_REQUIREMENTS = {
11691186
"11.8": (

0 commit comments

Comments
 (0)