Commit 9dfce88

[XPU] support python streams api for xpu

1 parent cf30d11 commit 9dfce88

File tree

17 files changed: +578 -37 lines changed

paddle/fluid/pybind/tensor.cc

Lines changed: 1 addition & 1 deletion

@@ -902,7 +902,7 @@ void BindTensor(pybind11::module &m) {  // NOLINT
           const auto &device_id =
               paddle::platform::GetXPUCurrentDeviceId();
           auto stream = paddle::platform::get_current_stream(device_id);
-          xpu_wait(stream);
+          xpu_wait(stream->raw_stream());
           int type_idx = static_cast<int>(self.type());
           size_t data_size = self.numel() *
                              framework::SizeOfType(
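
This call site adapts to the signature change in xpu_streams_py.h below: get_current_stream now returns a phi::XPUStreamHandle* wrapper instead of a raw XPUStream, so the raw stream has to be unwrapped with raw_stream() before it is handed to xpu_wait. The wrapper is also what the Python side receives; a minimal sketch of inspecting it, assuming the private binding is reachable via paddle.base.core (the module path is an assumption; only the binding name appears in this commit):

    import paddle

    if paddle.is_compiled_with_xpu():
        from paddle.base import core  # assumed location of the private bindings

        handle = core._xpu_get_current_stream(0)  # phi::XPUStreamHandle wrapper
        print(handle.idx)              # pool index of the stream
        print(hex(handle.xpu_stream))  # address of the underlying raw XPUStream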

paddle/fluid/pybind/xpu_streams_py.cc

Lines changed: 151 additions & 15 deletions

@@ -33,19 +33,24 @@ namespace py = pybind11;
 namespace paddle {
 namespace platform {
 #ifdef PADDLE_WITH_XPU
-XPUStream get_current_stream(int device_id) {
-  if (device_id == -1) {
-    device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
-  }
+phi::XPUStreamHandle *get_current_stream(int device_id) {
+  auto handle = new phi::XPUStreamHandle();
+  return handle;
+}
+
+phi::XPUStreamHandle *set_current_stream(int idx) {
+  int device_id = phi::backends::xpu::GetXPUCurrentDeviceId();
+  auto original_stream = get_current_stream(device_id);
   auto place = phi::XPUPlace(device_id);
   auto *dev_ctx = static_cast<phi::XPUContext *>(
       phi::DeviceContextPool::Instance().Get(place));
-  dev_ctx->Wait();
-  return dev_ctx->stream();
+  dev_ctx->SetCurrentStream(idx);
+  return original_stream;
 }
 
 #endif
 }  // namespace platform
+
 namespace pybind {
 void BindXpuStream(py::module *m_ptr) {
   auto &m = *m_ptr;
@@ -69,7 +74,7 @@ void BindXpuStream(py::module *m_ptr) {
 #endif
   });
   m.def(
-      "_get_current_stream",
+      "_xpu_get_current_stream",
       [](int device_id) {
 #ifdef PADDLE_WITH_XPU
         if (device_id == -1) {
@@ -79,7 +84,19 @@ void BindXpuStream(py::module *m_ptr) {
         return platform::get_current_stream(device_id);
 #else
         PADDLE_THROW(
-            common::errors::Unavailable("Paddle is not compiled with CUDA. "
+            common::errors::Unavailable("Paddle is not compiled with XPU. "
+                                        "Cannot visit device synchronize."));
+#endif
+      },
+      py::return_value_policy::reference);
+  m.def(
+      "_xpu_set_current_stream",
+      [](int stream_id) {
+#ifdef PADDLE_WITH_XPU
+        return platform::set_current_stream(stream_id);
+#else
+        PADDLE_THROW(
+            common::errors::Unavailable("Paddle is not compiled with XPU. "
                                         "Cannot visit device synchronize."));
 #endif
       },
@@ -101,11 +118,11 @@ void BindXpuStream(py::module *m_ptr) {
   });
 
 #ifdef PADDLE_WITH_XPU
-  py::class_<XPUStream>(m, "XPUStream", R"DOC(
+  py::class_<phi::XPUStreamHandle>(m, "XPUStream", R"DOC(
       The handle of the CUDA stream.
 
       Parameters:
-        device(paddle.CUDAPlace()|int|None, optional): The device which wanted to allocate the stream.
+        device(paddle.XPUPlace()|int|None, optional): The device which wanted to allocate the stream.
         If device is None or negative integer, device will be the current device.
         If device is positive integer, it must less than the device count. Default: None.
         priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal).
@@ -116,14 +133,115 @@ void BindXpuStream(py::module *m_ptr) {
 
             >>> # doctest: +REQUIRES(env:GPU)
             >>> import paddle
-            >>> s1 = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
-            >>> s2 = paddle.device.cuda.Stream(0, 1)
-            >>> s3 = paddle.device.cuda.Stream()
+            >>> s1 = paddle.device.xpu.Stream(paddle.XPUPlace(0))
+            >>> s2 = paddle.device.xpu.Stream(0)
+            >>> s3 = paddle.device.xpu.Stream()
+
+  )DOC")
+      .def("__init__",
+           [](phi::XPUStreamHandle &self) {
+             new (&self) phi::XPUStreamHandle();
+           })
+      .def_property_readonly(
+          "xpu_stream",
+          [](phi::XPUStreamHandle &self) {
+            return reinterpret_cast<std::uintptr_t>(self.raw_stream());
+          })
+      .def("wait_stream",
+           [](phi::XPUStreamHandle &self, phi::XPUStreamHandle &other) {
+             auto *dev_ctx = phi::get_xpu_context();
+             dev_ctx->StreamWaitStreamInPool(self.id(), other.id());
+           })
+      .def("wait_event",
+           [](phi::XPUStreamHandle &self, phi::XPUEventHandle &other) {
+             self.wait_event(other.get_event());
+           })
+      .def("__init__",
+           [](phi::XPUStreamHandle &self, phi::XPUPlace *place) {
+             if (place == nullptr) {
+               int curr_device_id = platform::GetXPUCurrentDeviceId();
+               auto place_tmp = phi::XPUPlace(curr_device_id);
+               new (&self) phi::XPUStreamHandle(place_tmp);
+             } else {
+               new (&self) phi::XPUStreamHandle(*place);
+             }
+           })
+      .def(
+          "__init__",
+          [](phi::XPUStreamHandle &self, int device) {
+            if (device < 0) {
+              device = platform::GetXPUCurrentDeviceId();
+            }
+            auto place_tmp = phi::XPUPlace(device);
+            new (&self) phi::XPUStreamHandle(place_tmp);
+          },
+          py::arg("device") = -1)
+      .def_property_readonly(
+          "place",
+          [](phi::XPUStreamHandle &self) {
+            return phi::XPUPlace(platform::GetXPUCurrentDeviceId());
+          })
+      .def_property_readonly(
+          "idx", [](phi::XPUStreamHandle &self) { return self.id(); });
+  py::class_<phi::XPUEventHandle>(m, "XPUEvent", R"DOC(
+      The handle of the XPU event.
+
+      Parameters:
+        enable_timing(bool, optional): Whether the event will measure time. Default: False.
+        blocking(bool, optional): Whether the wait() func will be blocking. Default: False;
+        interprocess(bool, optional): Whether the event can be shared between processes. Default: False.
+
+      Examples:
+        .. code-block:: python
+
+            >>> # doctest: +REQUIRES(env:XPU)
+            >>> import paddle
+            >>> event = paddle.device.xpu.Event()
+
+  )DOC")
+      .def("__init__",
+           [](phi::XPUEventHandle &self) { new (&self) phi::XPUEventHandle(); })
+      .def(
+          "record",
+          [](phi::XPUEventHandle &self, phi::XPUStreamHandle *stream) {
+            if (stream == nullptr) {
+              auto stream_handle = phi::get_current_stream_handle();
+              self.record(stream_handle.raw_stream());
+            } else {
+              self.record(stream->raw_stream());
+            }
+          },
+          py::arg("stream") = nullptr)
+      .def("query", [](phi::XPUEventHandle &self) { return self.query(); })
+      .def("elapsed_time",
+           [](phi::XPUEventHandle &self) {
+             PADDLE_THROW(common::errors::Unavailable(
+                 "XPUEvent elapsed_time is not supported now"));
+           })
+      .def("synchronize",
+           [](phi::XPUEventHandle &self) { self.synchronize(); });
+
+  py::class_<phi::XPUCUDAStream>(m, "XPUCUDAStream", R"DOC(
+      The handle of the XPU stream.
+
+      Parameters:
+        device(paddle.XPUPlace()|int|None, optional): The device which wanted to allocate the stream.
+        If device is None or negative integer, device will be the current device.
+        If device is positive integer, it must less than the device count. Default: None.
+
+      Examples:
+        .. code-block:: python
+
+            >>> # doctest: +REQUIRES(env:GPU)
+            >>> import paddle
+            >>> s1 = paddle.device.xpu.Stream(paddle.XPUPlace(0), 1)
+            >>> s2 = paddle.device.xpu.Stream(0, 1)
+            >>> s3 = paddle.device.xpu.Stream()
 
   )DOC")
       .def(
           "synchronize",
-          [](XPUStream &self) { xpu_wait(self); },
+          [](phi::XPUCUDAStream &self) { self.Synchronize(); },
          R"DOC(
       Waits for stream tasks to complete.
 
@@ -135,7 +253,25 @@ void BindXpuStream(py::module *m_ptr) {
             >>> s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
             >>> s.synchronize()
 
-  )DOC");
+  )DOC")
+      .def("__init__",
+           [](phi::XPUCUDAStream &self, phi::XPUPlace *place, int priority) {
+             if (priority != 1 && priority != 2) {
+               PADDLE_THROW(common::errors::InvalidArgument(
+                   "Priority should be 1(high) or 2(normal) "));
+             }
+             auto stream_flag =
+                 phi::XPUCUDAStream::StreamFlag::kStreamNonBlocking;
+             if (place == nullptr) {
+               int curr_device_id = platform::GetXPUCurrentDeviceId();
+               auto place_tmp = phi::XPUPlace(curr_device_id);
+               new (&self)
+                   phi::XPUCUDAStream(place_tmp, priority - 2, stream_flag);
+             } else {
+               new (&self)
+                   phi::XPUCUDAStream(*place, priority - 2, stream_flag);
+             }
+           });
 #endif
 }
 }  // namespace pybind
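
Together these bindings give XPU a CUDA-style Python streams API: XPUStream wraps a stream-pool entry with wait_stream/wait_event, XPUEvent supports record/query/synchronize (with elapsed_time explicitly unsupported), and XPUCUDAStream keeps the priority-based constructor and synchronize of the CUDA binding. A minimal usage sketch, assuming the classes are exposed as paddle.device.xpu.Stream and paddle.device.xpu.Event as the docstrings above suggest (the Python wrapper layer is not part of this file):

    import paddle

    if paddle.is_compiled_with_xpu():
        paddle.set_device("xpu:0")

        s1 = paddle.device.xpu.Stream()   # stream on the current device
        s2 = paddle.device.xpu.Stream(0)  # stream on device 0
        e = paddle.device.xpu.Event()

        e.record(s1)        # record the event on s1 (defaults to the current stream)
        s2.wait_event(e)    # s2 will not run past this point until e completes
        s2.wait_stream(s1)  # or make s2 wait on all work already queued on s1
        e.synchronize()     # block the host until the recorded work finishes
        print(e.query())    # True once the event has completed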

paddle/fluid/pybind/xpu_streams_py.h

Lines changed: 4 additions & 1 deletion

@@ -18,9 +18,11 @@
 #include "pybind11/stl.h"
 
 #ifdef PADDLE_WITH_XPU
+#include "paddle/phi/backends/xpu/xpu_context.h"
 #include "paddle/phi/core/xpu_cuda_stream.h"
 #include "xpu/runtime.h"
 #include "xpu/runtime_ex.h"
+
 #else
 namespace phi {
 class XPUCUDAStream {};
@@ -32,7 +34,8 @@ namespace py = pybind11;
 namespace paddle {
 namespace platform {
 #ifdef PADDLE_WITH_XPU
-XPUStream get_current_stream(int device_id = -1);
+phi::XPUStreamHandle* get_current_stream(int device_id = -1);
+phi::XPUStreamHandle* set_current_stream(int idx);
 #endif
 }  // namespace platform
 namespace pybind {
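
The header now declares both halves of the current-stream protocol: get_current_stream(device_id) returns a heap-allocated handle, and set_current_stream(idx) installs the pool stream at index idx on the current device's XPUContext and returns the handle that was current before, so callers can restore it. A round-trip sketch against the private bindings, assuming they are importable from paddle.base.core (an assumption; only the binding names _xpu_get_current_stream and _xpu_set_current_stream appear in this commit):

    import paddle

    if paddle.is_compiled_with_xpu():
        from paddle.base import core  # assumed location of the private bindings

        s = paddle.device.xpu.Stream(0)             # fresh stream from the pool
        prev = core._xpu_set_current_stream(s.idx)  # install it, keep the old handle
        try:
            ...  # kernels launched here target the newly current stream
        finally:
            core._xpu_set_current_stream(prev.idx)  # restore the previous stream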

paddle/phi/api/include/tensor.h

Lines changed: 9 additions & 0 deletions

@@ -29,6 +29,11 @@ using gpuStream_t = cudaStream_t;
 using gpuStream_t = hipStream_t;
 #endif
 
+#ifdef PADDLE_WITH_XPU
+#include "xpu/runtime.h"
+#include "xpu/runtime_ex.h"
+#endif
+
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
 #include "paddle/phi/backends/stream.h"
 #endif
@@ -434,6 +439,10 @@ class PADDLE_API Tensor final {
    * @return gpuStream_t
    */
  gpuStream_t stream() const;
+#elif defined(PADDLE_WITH_XPU)
+
+  void record_stream(XPUStream stream) const;
+
 #elif defined(PADDLE_WITH_CUSTOM_DEVICE)
   /**
    * @brief Get the stream where the tensor is currently located

paddle/phi/api/lib/tensor.cc

Lines changed: 10 additions & 0 deletions

@@ -40,6 +40,8 @@ limitations under the License. */
 #include "paddle/phi/core/tensor_meta.h"
 #include "paddle/phi/core/tensor_utils.h"
 
+#include "paddle/phi/core/memory/malloc.h"
+
 namespace paddle {
 
 using DeviceContextPool = experimental::DeviceContextPool;
@@ -394,6 +396,14 @@ Tensor Tensor::slice(int64_t begin_idx, int64_t end_idx) const {
 
 const std::shared_ptr<phi::TensorBase> &Tensor::impl() const { return impl_; }
 
+#ifdef PADDLE_WITH_XPU
+
+void Tensor::record_stream(XPUStream stream) const {
+  paddle::memory::RecordStream(
+      std::dynamic_pointer_cast<phi::DenseTensor>(impl_)->Holder(), stream);
+}
+
+#endif
 void Tensor::set_impl(const std::shared_ptr<phi::TensorBase> &impl) {
   impl_ = impl;
 }
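
Tensor::record_stream ties the tensor's allocation to a stream through paddle::memory::RecordStream, so the allocator will not recycle the buffer until the work queued on that stream has finished, mirroring the CUDA record_stream semantics. Only the C++ entry point is added here; a hedged sketch of the hazard it guards against, with a hypothetical Python-level record_stream binding (not part of this commit):

    import paddle

    if paddle.is_compiled_with_xpu():
        paddle.set_device("xpu:0")
        s = paddle.device.xpu.Stream()
        x = paddle.randn([1024])
        # ... enqueue work on s that reads x asynchronously ...
        # x.record_stream(s)  # hypothetical binding: pins x's allocation
        #                     # until s finishes with it
        del x  # without record_stream, the pool could hand x's memory to a
               # new allocation while s is still reading it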
