File tree Expand file tree Collapse file tree 5 files changed +45
-1
lines changed Expand file tree Collapse file tree 5 files changed +45
-1
lines changed Original file line number Diff line number Diff line change @@ -35,6 +35,7 @@ PD_DECLARE_string(cinn_nvcc_cmd_path);
3535PD_DECLARE_string (nvidia_package_dir);
3636PD_DECLARE_bool (nvrtc_compile_to_cubin);
3737PD_DECLARE_bool (cinn_nvrtc_cubin_with_fmad);
38+ PD_DECLARE_string (cuda_cccl_dir);
3839
3940namespace cinn {
4041namespace backends {
@@ -50,7 +51,8 @@ static std::vector<std::string> GetNvidiaAllIncludePath(
5051 std::vector<std::string> include_paths;
5152 const std::string delimiter = " /" ;
5253 // Expand this list if necessary.
53- const std::vector<std::string> sub_modules = {" cublas" ,
54+ const std::vector<std::string> sub_modules = {" cuda_cccl" ,
55+ " cublas" ,
5456 " cudnn" ,
5557 " cufft" ,
5658 " cusparse" ,
Original file line number Diff line number Diff line change @@ -1798,6 +1798,11 @@ PHI_DEFINE_EXPORTED_string(
17981798 " Specify root dir path for nvidia site-package, such as "
17991799 " python3.9/site-packages/nvidia" );
18001800
1801+ PHI_DEFINE_EXPORTED_string (cuda_cccl_dir, // NOLINT
1802+ " " ,
1803+ " Specify root dir path for nv/target, such as "
1804+ " python3.9/site-packages/nvidia/cuda_cccl/include/" );
1805+
18011806PHI_DEFINE_EXPORTED_string (
18021807 cudnn_dir, // NOLINT
18031808 " " ,
Original file line number Diff line number Diff line change 648648 cupti_dir_lib_path = package_dir + "/.." + "/nvidia/cuda_cupti/lib"
649649 set_flags ({"FLAGS_cupti_dir" : cupti_dir_lib_path })
650650
651+ if is_compiled_with_cinn ():
652+ cuda_cccl_path = package_dir + "/.." + "/nvidia/cuda_cccl/include/"
653+ set_flags ({"FLAGS_cuda_cccl_dir" : cuda_cccl_path })
654+
651655 elif (
652656 platform .system () == 'Windows'
653657 and platform .machine () in ('x86_64' , 'AMD64' )
Original file line number Diff line number Diff line change @@ -656,6 +656,22 @@ def get_paddle_extra_install_requirements():
656656 "nvidia-cufile-cu12==1.14.0.30; platform_system == 'Linux' and platform_machine == 'x86_64'"
657657 ),
658658 }
659+ if '@WITH_CINN@' == 'ON':
660+ PADDLE_CUDA_INSTALL_REQUIREMENTS["12.3"] += (
661+ " | nvidia-cuda-cccl-cu12==12.3.52;platform_system == 'Linux' and platform_machine == 'x86_64' "
662+ )
663+ PADDLE_CUDA_INSTALL_REQUIREMENTS["12.4"] += (
664+ " | nvidia-cuda-cccl-cu12==12.4.99;platform_system == 'Linux' and platform_machine == 'x86_64' "
665+ )
666+ PADDLE_CUDA_INSTALL_REQUIREMENTS["12.6"] += (
667+ " | nvidia-cuda-cccl-cu12==12.6.77;platform_system == 'Linux' and platform_machine == 'x86_64' "
668+ )
669+ PADDLE_CUDA_INSTALL_REQUIREMENTS["12.8"] += (
670+ " | nvidia-cuda-cccl-cu12==12.8.90;platform_system == 'Linux' and platform_machine == 'x86_64' "
671+ )
672+ PADDLE_CUDA_INSTALL_REQUIREMENTS["12.9"] += (
673+ " | nvidia-cuda-cccl-cu12==12.9.27;platform_system == 'Linux' and platform_machine == 'x86_64' "
674+ )
659675 elif platform.system() == 'Windows':
660676 PADDLE_CUDA_INSTALL_REQUIREMENTS = {
661677 "11.8": (
Original file line number Diff line number Diff line change @@ -1164,6 +1164,23 @@ def get_paddle_extra_install_requirements():
11641164 "nvidia-cufile-cu12==1.14.0.30; platform_system == 'Linux' and platform_machine == 'x86_64'"
11651165 ),
11661166 }
1167+ if env_dict .get ("WITH_CINN" ) == "ON" :
1168+ PADDLE_CUDA_INSTALL_REQUIREMENTS [
1169+ "12.3"
1170+ ] += " | nvidia-cuda-cccl-cu12==12.3.52;platform_system == 'Linux' and platform_machine == 'x86_64' "
1171+ PADDLE_CUDA_INSTALL_REQUIREMENTS [
1172+ "12.4"
1173+ ] += " | nvidia-cuda-cccl-cu12==12.4.99;platform_system == 'Linux' and platform_machine == 'x86_64' "
1174+ PADDLE_CUDA_INSTALL_REQUIREMENTS [
1175+ "12.6"
1176+ ] += " | nvidia-cuda-cccl-cu12==12.6.77;platform_system == 'Linux' and platform_machine == 'x86_64' "
1177+ PADDLE_CUDA_INSTALL_REQUIREMENTS [
1178+ "12.8"
1179+ ] += " | nvidia-cuda-cccl-cu12==12.8.90;platform_system == 'Linux' and platform_machine == 'x86_64' "
1180+ PADDLE_CUDA_INSTALL_REQUIREMENTS [
1181+ "12.9"
1182+ ] += " | nvidia-cuda-cccl-cu12==12.9.27;platform_system == 'Linux' and platform_machine == 'x86_64' "
1183+
11671184 elif platform .system () == 'Windows' :
11681185 PADDLE_CUDA_INSTALL_REQUIREMENTS = {
11691186 "11.8" : (
You can’t perform that action at this time.
0 commit comments