@@ -33,19 +33,24 @@ namespace py = pybind11;
3333namespace paddle {
3434namespace platform {
3535#ifdef PADDLE_WITH_XPU
36- XPUStream get_current_stream (int device_id) {
37- if (device_id == -1 ) {
38- device_id = phi::backends::xpu::GetXPUCurrentDeviceId ();
39- }
36+ phi::XPUStreamHandle *get_current_stream (int device_id) {
37+ auto handle = new phi::XPUStreamHandle ();
38+ return handle;
39+ }
40+
41+ phi::XPUStreamHandle *set_current_stream (int idx) {
42+ int device_id = phi::backends::xpu::GetXPUCurrentDeviceId ();
43+ auto original_stream = get_current_stream (device_id);
4044 auto place = phi::XPUPlace (device_id);
4145 auto *dev_ctx = static_cast <phi::XPUContext *>(
4246 phi::DeviceContextPool::Instance ().Get (place));
43- dev_ctx->Wait ( );
44- return dev_ctx-> stream () ;
47+ dev_ctx->SetCurrentStream (idx );
48+ return original_stream ;
4549}
4650
4751#endif
4852} // namespace platform
53+
4954namespace pybind {
5055void BindXpuStream (py::module *m_ptr) {
5156 auto &m = *m_ptr;
@@ -69,7 +74,7 @@ void BindXpuStream(py::module *m_ptr) {
6974#endif
7075 });
7176 m.def (
72- " _get_current_stream " ,
77+ " _xpu_get_current_stream " ,
7378 [](int device_id) {
7479#ifdef PADDLE_WITH_XPU
7580 if (device_id == -1 ) {
@@ -79,7 +84,19 @@ void BindXpuStream(py::module *m_ptr) {
7984 return platform::get_current_stream (device_id);
8085#else
8186 PADDLE_THROW (
82- common::errors::Unavailable (" Paddle is not compiled with CUDA. "
87+ common::errors::Unavailable (" Paddle is not compiled with XPU. "
88+ " Cannot visit device synchronize." ));
89+ #endif
90+ },
91+ py::return_value_policy::reference);
92+ m.def (
93+ " _xpu_set_current_stream" ,
94+ [](int stream_id) {
95+ #ifdef PADDLE_WITH_XPU
96+ return platform::set_current_stream (stream_id);
97+ #else
98+ PADDLE_THROW (
99+ common::errors::Unavailable (" Paddle is not compiled with XPU. "
83100 " Cannot visit device synchronize." ));
84101#endif
85102 },
@@ -101,11 +118,11 @@ void BindXpuStream(py::module *m_ptr) {
101118 });
102119
103120#ifdef PADDLE_WITH_XPU
104- py::class_<XPUStream >(m, " XPUStream" , R"DOC(
121+ py::class_<phi::XPUStreamHandle >(m, " XPUStream" , R"DOC(
105122 The handle of the CUDA stream.
106123
107124 Parameters:
108- device(paddle.CUDAPlace ()|int|None, optional): The device which wanted to allocate the stream.
125+ device(paddle.XPUPlace ()|int|None, optional): The device which wanted to allocate the stream.
109126 If device is None or negative integer, device will be the current device.
110127 If device is positive integer, it must less than the device count. Default: None.
111128 priority(int|None, optional): The priority of stream. The priority can be 1(high) or 2(normal).
@@ -116,14 +133,115 @@ void BindXpuStream(py::module *m_ptr) {
116133
117134 >>> # doctest: +REQUIRES(env:GPU)
118135 >>> import paddle
119- >>> s1 = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
120- >>> s2 = paddle.device.cuda.Stream(0, 1)
121- >>> s3 = paddle.device.cuda.Stream()
136+ >>> s1 = paddle.device.xpu.Stream(paddle.XPUPlace(0))
137+ >>> s2 = paddle.device.xpu.Stream(0)
138+ >>> s3 = paddle.device.xpu.Stream()
139+
140+ )DOC" )
141+ .def (" __init__" ,
142+ [](phi::XPUStreamHandle &self) {
143+ new (&self) phi::XPUStreamHandle ();
144+ })
145+ .def_property_readonly (
146+ " xpu_stream" ,
147+ [](phi::XPUStreamHandle &self) {
148+ return reinterpret_cast <std::uintptr_t >(self.raw_stream ());
149+ })
150+ .def (" wait_stream" ,
151+ [](phi::XPUStreamHandle &self, phi::XPUStreamHandle &other) {
152+ auto *dev_ctx = phi::get_xpu_context ();
153+ dev_ctx->StreamWaitStreamInPool (self.id (), other.id ());
154+ })
155+ .def (" wait_event" ,
156+ [](phi::XPUStreamHandle &self, phi::XPUEventHandle &other) {
157+ self.wait_event (other.get_event ());
158+ })
159+ .def (" __init__" ,
160+ [](phi::XPUStreamHandle &self, phi::XPUPlace *place) {
161+ if (place == nullptr ) {
162+ int curr_device_id = platform::GetXPUCurrentDeviceId ();
163+ auto place_tmp = phi::XPUPlace (curr_device_id);
164+ new (&self) phi::XPUStreamHandle (place_tmp);
165+ } else {
166+ new (&self) phi::XPUStreamHandle (*place);
167+ }
168+ })
169+ .def (
170+ " __init__" ,
171+ [](phi::XPUStreamHandle &self, int device) {
172+ if (device < 0 ) {
173+ device = platform::GetXPUCurrentDeviceId ();
174+ }
175+ auto place_tmp = phi::XPUPlace (device);
176+ new (&self) phi::XPUStreamHandle (place_tmp);
177+ },
178+ py::arg (" device" ) = -1 )
179+ .def_property_readonly (
180+ " place" ,
181+ [](phi::XPUStreamHandle &self) {
182+ return phi::XPUPlace (platform::GetXPUCurrentDeviceId ());
183+ })
184+ .def_property_readonly (
185+ " idx" , [](phi::XPUStreamHandle &self) { return self.id (); });
186+ py::class_<phi::XPUEventHandle>(m, " XPUEvent" , R"DOC(
187+ The handle of the XPU event.
188+
189+ Parameters:
190+ enable_timing(bool, optional): Whether the event will measure time. Default: False.
191+ blocking(bool, optional): Whether the wait() func will be blocking. Default: False;
192+ interprocess(bool, optional): Whether the event can be shared between processes. Default: False.
193+
194+ Examples:
195+ .. code-block:: python
196+
197+ >>> # doctest: +REQUIRES(env:XPU)
198+ >>> import paddle
199+ >>> event = paddle.device.xpu.Event()
200+
201+ )DOC" )
202+ .def (" __init__" ,
203+ [](phi::XPUEventHandle &self) { new (&self) phi::XPUEventHandle (); })
204+ .def (
205+ " record" ,
206+ [](phi::XPUEventHandle &self, phi::XPUStreamHandle *stream) {
207+ if (stream == nullptr ) {
208+ auto stream_handle = phi::get_current_stream_handle ();
209+ self.record (stream_handle.raw_stream ());
210+ } else {
211+ self.record (stream->raw_stream ());
212+ }
213+ },
214+ py::arg (" stream" ) = nullptr )
215+ .def (" query" , [](phi::XPUEventHandle &self) { return self.query (); })
216+ .def (" elapsed_time" ,
217+ [](phi::XPUEventHandle &self) {
218+ PADDLE_THROW (common::errors::Unavailable (
219+ " XPUEvent elapsed_time is not supported now" ));
220+ })
221+ .def (" synchronize" ,
222+ [](phi::XPUEventHandle &self) { self.synchronize (); });
223+
224+ py::class_<phi::XPUCUDAStream>(m, " XPUCUDAStream" , R"DOC(
225+ The handle of the XPU stream.
226+
227+ Parameters:
228+ device(paddle.XPUPlace()|int|None, optional): The device which wanted to allocate the stream.
229+ If device is None or negative integer, device will be the current device.
230+ If device is positive integer, it must less than the device count. Default: None.
231+
232+ Examples:
233+ .. code-block:: python
234+
235+ >>> # doctest: +REQUIRES(env:GPU)
236+ >>> import paddle
237+ >>> s1 = paddle.device.xpu.Stream(paddle.XPUPlace(0), 1)
238+ >>> s2 = paddle.device.xpu.Stream(0, 1)
239+ >>> s3 = paddle.device.xpu.Stream()
122240
123241 )DOC" )
124242 .def (
125243 " synchronize" ,
126- [](XPUStream &self) { xpu_wait ( self); },
244+ [](phi::XPUCUDAStream &self) { self. Synchronize ( ); },
127245 R"DOC(
128246 Waits for stream tasks to complete.
129247
@@ -135,7 +253,25 @@ void BindXpuStream(py::module *m_ptr) {
135253 >>> s = paddle.device.cuda.Stream(paddle.CUDAPlace(0), 1)
136254 >>> s.synchronize()
137255
138- )DOC" );
256+ )DOC" )
257+ .def (" __init__" ,
258+ [](phi::XPUCUDAStream &self, phi::XPUPlace *place, int priority) {
259+ if (priority != 1 && priority != 2 ) {
260+ PADDLE_THROW (common::errors::InvalidArgument (
261+ " Priority should be 1(high) or 2(normal) " ));
262+ }
263+ auto stream_flag =
264+ phi::XPUCUDAStream::StreamFlag::kStreamNonBlocking ;
265+ if (place == nullptr ) {
266+ int curr_device_id = platform::GetXPUCurrentDeviceId ();
267+ auto place_tmp = phi::XPUPlace (curr_device_id);
268+ new (&self)
269+ phi::XPUCUDAStream (place_tmp, priority - 2 , stream_flag);
270+ } else {
271+ new (&self)
272+ phi::XPUCUDAStream (*place, priority - 2 , stream_flag);
273+ }
274+ });
139275#endif
140276}
141277} // namespace pybind
0 commit comments