 
 #pragma once
 
-#include "paddle/fluid/platform/enforce.h"
-#include "paddle/fluid/platform/place.h"
-#ifdef PADDLE_WITH_CUDA
-#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
-#endif
+#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/macros.h"
 
 namespace paddle {
 namespace platform {
 
-#ifdef PADDLE_WITH_CUDA
-#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond,                            \
-                                           __kernel_func,                     \
-                                           __grid,                            \
-                                           __block,                           \
-                                           __sm_size,                         \
-                                           __stream,                          \
-                                           __seed_inc,                        \
-                                           __seed_expr,                       \
-                                           __offset_expr,                     \
-                                           ...)                               \
-  do {                                                                        \
-    if (::paddle::platform::CUDAGraph::IsThisThreadCapturing() && (__cond)) { \
-      using __Helper =                                                        \
-          ::paddle::platform::IsSameKernelHelper<decltype(&__kernel_func),    \
-                                                 &__kernel_func>;             \
-      auto *dev_ctx =                                                         \
-          ::paddle::platform::DeviceContextPool::Instance().GetByPlace(       \
-              ::paddle::platform::CUDAGraph::CapturingPlace());               \
-      auto __set_seed_func =                                                  \
-          [=](::paddle::platform::CUDAKernelParams *__params,                 \
-              bool __check_only) -> bool {                                    \
-        if (__check_only) {                                                   \
-          return __params->func() == &__kernel_func &&                        \
-                 __Helper::Compare(*__params, __VA_ARGS__);                   \
-        }                                                                     \
-        auto &KERNEL_PARAMS = *__params;                                      \
-        uint64_t __seed, __offset;                                            \
-        ::paddle::operators::GetSeedDataAndIncrement(                         \
-            *dev_ctx, nullptr, false, 0, __seed_inc, &__seed, &__offset);     \
-        __seed_expr = static_cast<decltype(__seed_expr)>(__seed);             \
-        __offset_expr = static_cast<decltype(__offset_expr)>(__offset);       \
-        return true;                                                          \
-      };                                                                      \
-      ::paddle::platform::CUDAGraph::RecordRandomKernelInfo(__set_seed_func); \
-    }                                                                         \
-    __kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__);     \
-  } while (0)
-#else
-#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond,                            \
-                                           __kernel_func,                     \
-                                           __grid,                            \
-                                           __block,                           \
-                                           __sm_size,                         \
-                                           __stream,                          \
-                                           __seed_inc,                        \
-                                           __seed_expr,                       \
-                                           __offset_expr,                     \
-                                           ...)                               \
-  do {                                                                        \
-    __kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__);     \
-  } while (0)
-#endif
-
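The macro deleted above (relocated into phi) wraps a random kernel launch so that, while a CUDA graph is being captured, a callback is registered that re-derives the seed/offset arguments on every graph replay; outside capture it degenerates to a plain launch. A minimal caller-side sketch against the signature visible in the deletion; the kernel name and all argument names here are hypothetical, not taken from this patch:

    // Hedged sketch: DropoutKernel, grid/block/stream, and the seed/offset
    // lvalues are placeholders for a real random op's launch site.
    PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(
        /*__cond=*/!is_fix_seed,          // only patch seeds when not fixed
        DropoutKernel,                    // hypothetical __global__ kernel
        grid, block, /*__sm_size=*/0, stream,
        /*__seed_inc=*/increment,         // how far the offset advances per replay
        /*__seed_expr=*/seed_arg,         // lvalue rewritten at each graph replay
        /*__offset_expr=*/offset_arg,     // likewise for the offset
        seed_arg, offset_arg, dst, n);    // the kernel's own argument list
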
 // NOTE: These APIs are not thread-safe.
 #ifdef PADDLE_WITH_CUDA
-void BeginCUDAGraphCapture(platform::CUDAPlace place,
+using CUDAGraph = phi::backends::gpu::CUDAGraph;
+
+void BeginCUDAGraphCapture(phi::GPUPlace place,
                            cudaStreamCaptureMode mode,
                            int64_t pool_id = CUDAGraph::kInvalidPoolID);
 std::unique_ptr<CUDAGraph> EndCUDAGraphCapture();
 #endif
 
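For orientation, a minimal sketch of driving the capture pair declared above. LaunchModelKernels is a hypothetical helper, and the sketch assumes phi's CUDAGraph exposes Replay() for re-issuing the captured work:

    #ifdef PADDLE_WITH_CUDA
    void CaptureOnce(const phi::GPUPlace &place) {
      // Work enqueued between Begin/End is recorded into the graph,
      // not executed immediately.
      paddle::platform::BeginCUDAGraphCapture(place,
                                              cudaStreamCaptureModeThreadLocal);
      LaunchModelKernels();  // hypothetical: enqueue work on the capturing stream
      std::unique_ptr<paddle::platform::CUDAGraph> graph =
          paddle::platform::EndCUDAGraphCapture();
      graph->Replay();  // re-issue the whole captured sequence in one launch
    }
    #endif
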
-inline bool IsCUDAGraphCapturing() {
-#ifdef PADDLE_WITH_CUDA
-  return CUDAGraph::IsCapturing();
-#else
-  return false;
-#endif
-}
-
-inline platform::CUDAPlace CUDAGraphCapturingPlace() {
+inline phi::GPUPlace CUDAGraphCapturingPlace() {
 #ifdef PADDLE_WITH_CUDA
   return CUDAGraph::CapturingPlace();
 #else
-  PADDLE_THROW(platform::errors::Unimplemented(
+  PADDLE_THROW(phi::errors::Unimplemented(
       "CUDA Graph is only supported on NVIDIA GPU device."));
 #endif
 }
 
-// Add reset callback if CUDA Graph is capturing.
-// Otherwise, invoke callback directly.
-template <typename Callback>
-inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) {
-#ifdef PADDLE_WITH_CUDA
-  if (UNLIKELY(IsCUDAGraphCapturing())) {
-    return CUDAGraph::AddResetCallbackDuringCapturing(
-        std::forward<Callback>(callback));
-  }
-#endif
-  callback();
-}
+using phi::backends::gpu::IsCUDAGraphCapturing;
 
-template <typename T>
-inline T *RestoreHostMemIfCapturingCUDAGraph(T *host_mem, size_t size) {
-  static_assert(std::is_trivial<T>::value, "T must be trivial type");
-  static_assert(!std::is_same<T, void>::value, "T cannot be void");
-#ifdef PADDLE_WITH_CUDA
-  if (UNLIKELY(IsCUDAGraphCapturing())) {
-    size_t nbytes = size * sizeof(T);
-    void *new_host_mem = new uint8_t[nbytes];
-    std::memcpy(new_host_mem, host_mem, nbytes);
-    AddResetCallbackIfCapturingCUDAGraph(
-        [new_host_mem] { delete[] reinterpret_cast<uint8_t *>(new_host_mem); });
-    return reinterpret_cast<T *>(new_host_mem);
-  }
-#endif
-  return host_mem;
-}
+using phi::backends::gpu::AddResetCallbackIfCapturingCUDAGraph;
+
+using phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph;
 
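A hedged sketch of the two re-exported helpers, assuming the phi versions keep the signatures of the fluid ones deleted above: RestoreHostMemIfCapturingCUDAGraph snapshots trivially-copyable host memory so a captured kernel does not end up referring to freed storage, and AddResetCallbackIfCapturingCUDAGraph defers cleanup to graph reset while capturing (and runs it immediately otherwise):

    // host_params is a hypothetical host-side array read by a captured kernel.
    int *safe_params = paddle::platform::RestoreHostMemIfCapturingCUDAGraph(
        host_params, /*size=*/n);  // copies while capturing, pass-through otherwise
    paddle::platform::AddResetCallbackIfCapturingCUDAGraph([] {
      // Invoked when the captured graph is reset; immediately if not capturing.
      ReleaseScratchState();  // hypothetical cleanup
    });
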
 class SkipCUDAGraphCaptureGuard {
   DISABLE_COPY_AND_ASSIGN(SkipCUDAGraphCaptureGuard);