1 change: 1 addition & 0 deletions python/paddle/__init__.py
@@ -137,6 +137,7 @@
set_grad_enabled,
)
from .device import ( # noqa: F401
device_guard,
Member

Is the intent here to expose this as paddle.device_guard? Which one exactly is the public API: paddle.device.device_guard or paddle.device_guard?

If this is confirmed to be a public API, it needs to be added to the corresponding module's __all__.

Contributor Author

Added.
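For reference, a minimal usage sketch of the intended exposure, assuming the re-export in python/paddle/__init__.py stays as in this diff and `device_guard` is listed in `paddle.device.__all__`; the `"cpu"` device string and tensor shape are illustrative only:

```python
import paddle

# Reachable through the submodule that defines it ...
with paddle.device.device_guard("cpu"):
    x = paddle.randn([2, 2])  # created on CPU while the guard is active

# ... and, via the re-export in python/paddle/__init__.py, from the top level.
with paddle.device_guard("cpu"):
    y = paddle.randn([2, 2])
```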

get_cudnn_version,
get_device,
is_compiled_with_cinn,
68 changes: 68 additions & 0 deletions python/paddle/device/__init__.py
@@ -41,6 +41,7 @@

from paddle import IPUPlace as _IPUPlace, XPUPlace as _XPUPlace
from paddle._typing.device_like import PlaceLike
from paddle.base.core import Place

_InitStreamBase = Union[core.CUDAStream, core.CustomDeviceStream]
_InitEventBase = Union[core.CUDAEvent, core.CustomDeviceEvent]
@@ -77,6 +78,7 @@
'current_stream',
'set_stream',
'stream_guard',
'device_guard',
'synchronize',
]

@@ -1109,6 +1111,72 @@ def __exit__(
set_stream(self.src_prev_stream)


class device_guard:
'''

    A context manager that temporarily switches the current device to the given device
    and restores the previous device on exit.

    Notes:
        This API only supports dynamic graph mode currently.

    Args:
        device (PlaceLike): The target device. It can be a device string such as
            ``"cpu"`` or ``"gpu:0"``, or a ``Place`` object.

Examples:
.. code-block:: python

>>> # doctest: +REQUIRES(env:GPU)
>>> import paddle

>>> # Set the global default device to CPU
>>> paddle.set_device("cpu")
>>> # Temporarily switch to GPU:0 using device_guard with string input
>>> with paddle.device.device_guard("gpu:0"):
... x = paddle.randn([4, 4]) # Create a Tensor on GPU:0
... x = x.tanh() * 2 # Perform computation on GPU:0
... print(x.place) # Check the device of the Tensor
Place(gpu:0)

>>> # Set the global default device to GPU:0
>>> paddle.set_device("gpu:0")
Member

Once this is added as a public API, this line is expected to fail; #73982 has confirmed that the current CI environment has a problem, and #73992 is fixing it.

Contributor Author

OK.

>>> # Temporarily switch to CPU using device_guard with Place object (CPUPlace)
>>> cpu_place = paddle.CPUPlace()
>>> with paddle.device.device_guard(cpu_place):
... x = paddle.randn([4, 4]) # Create a Tensor on CPU
... x = x.tanh() * 2 # Perform computation on CPU
... print(x.place)
Place(cpu)
'''

_target_place: Place
_original_place: Place

def __init__(self, device: PlaceLike) -> None:
if isinstance(device, str):
self._target_place = paddle.device._convert_to_place(device)
elif isinstance(device, paddle.base.libpaddle.Place):
self._target_place = device
else:
raise ValueError(
"'device' must be a string or an instance of a subclass of "
f"paddle.base.libpaddle.Place, but got {type(device)}"
)

def __enter__(self) -> None:
self._original_place = paddle.framework._current_expected_place_()
if self._original_place != self._target_place:
paddle.framework._set_expected_place(self._target_place)

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self._original_place != self._target_place:
paddle.framework._set_expected_place(self._original_place)


def synchronize(device: PlaceLike | None = None) -> None:
"""

136 changes: 136 additions & 0 deletions test/legacy_test/test_place_guard.py
@@ -0,0 +1,136 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from utils import dygraph_guard

import paddle


class TestPlaceGuard(unittest.TestCase):
def test_str_place_obj_consistency(self):
places = [
["cpu", paddle.CPUPlace()],
]
if paddle.device.is_compiled_with_cuda():
places.append(["gpu", paddle.CUDAPlace(0)])
places.append(["gpu:0", paddle.CUDAPlace(0)])
elif paddle.device.is_compiled_with_ipu():
places.append(["ipu", paddle.IPUPlace()])
elif paddle.device.is_compiled_with_xpu():
places.append(["xpu:0", paddle.XPUPlace(0)])

with dygraph_guard():
for place_str, place_obj in places:
with paddle.device.device_guard(place_str):
x = paddle.randn([2, 2])
x = x.tanh() ** 2
self.assertEqual(x.place, place_obj)

def test_str_place_obj_scope_in_device(self):
places = []
if paddle.device.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
places.append(paddle.CUDAPlace(0))
elif paddle.device.is_compiled_with_ipu():
places.append(paddle.IPUPlace())
elif paddle.device.is_compiled_with_xpu():
places.append(paddle.XPUPlace(0))
places.append(paddle.XPUPlace(0))

with dygraph_guard():
for place_obj in places:
x = paddle.randn([2, 2]) # create on default place
with paddle.device.device_guard("cpu"):
x = (
x.tanh() ** 2
                    )  # should stay on the original place rather than CPU
self.assertNotEqual(x.place, paddle.CPUPlace())
self.assertEqual(x.place, place_obj)

def test_wrong_device_name(self):
with (
dygraph_guard(),
self.assertRaisesRegex(
ValueError,
"The device must be a string which is like 'cpu', 'gpu', 'gpu:x',",
),
paddle.device.device_guard("xxx"),
):
pass

def test_wrong_device_type(self):
with (
dygraph_guard(),
self.assertRaisesRegex(
ValueError,
"'device' must be a string or an instance of a subclass of",
),
paddle.device.device_guard(paddle.randn([2])),
):
pass

def test_str_place_obj_nested(self):
places = [paddle.CPUPlace()]
if paddle.device.is_compiled_with_cuda():
places.append(paddle.CUDAPlace(0))
places.append(paddle.CUDAPlace(0))
elif paddle.device.is_compiled_with_ipu():
places.append(paddle.IPUPlace())
elif paddle.device.is_compiled_with_xpu():
places.append(paddle.XPUPlace(0))
places.append(paddle.XPUPlace(0))

if len(places) >= 2:
place_obj1, place_obj2 = places[:2]
else:
            self.skipTest("Not compiled with any accelerator (GPU/IPU/XPU) support.")

with dygraph_guard():
with paddle.device.device_guard(place_obj1):
x = paddle.randn([2, 2]) # create on place1
self.assertEqual(x.place, place_obj1)
self.assertNotEqual(x.place, place_obj2)

with paddle.device.device_guard(place_obj2):
                    xx = paddle.randn([2, 2])  # create on place2
self.assertEqual(xx.place, place_obj2)
self.assertNotEqual(xx.place, place_obj1)

with paddle.device.device_guard(place_obj1):
xxx = paddle.randn([2, 2]) # create on place1
self.assertEqual(xxx.place, place_obj1)
self.assertNotEqual(xxx.place, place_obj2)

with paddle.device.device_guard(place_obj2):
                            xxxx = paddle.randn([2, 2])  # create on place2
self.assertEqual(xxxx.place, place_obj2)
self.assertNotEqual(xxxx.place, place_obj1)

self.assertEqual(xxxx.place, place_obj2)
self.assertNotEqual(xxxx.place, place_obj1)

self.assertEqual(xxx.place, place_obj1)
self.assertNotEqual(xxx.place, place_obj2)

self.assertEqual(xx.place, place_obj2)
self.assertNotEqual(xx.place, place_obj1)

self.assertEqual(x.place, place_obj1)
self.assertNotEqual(x.place, place_obj2)


if __name__ == "__main__":
unittest.main()