Commit 85fe498

SigureMo and Copilot authored
[SOT] Non-break support for paddle.get_device (#72004)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
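What this enables, in a minimal sketch that mirrors the new test added below: a function compiled under SOT can now branch on paddle.get_device() without falling back to eager execution (a graph break), because the underlying Place accessor calls are simulated.

import paddle
from paddle.jit.sot.psdb import check_no_breakgraph

@check_no_breakgraph  # asserts the body is simulated without a graph break
def branch_on_device(x: paddle.Tensor):
    if "cpu" in paddle.get_device():
        return x + 1
    return x + 2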
1 parent e8e5427 commit 85fe498

File tree: 6 files changed, +183 -4 lines changed


python/paddle/device/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -342,7 +342,6 @@ def get_device() -> str:
     elif isinstance(place, core.IPUPlace):
         num_devices = core.get_ipu_device_count()
         device = f"ipus:{{0-{num_devices - 1}}}"
-        device = f"ipus:{{0-{num_devices - 1}}}"
     elif isinstance(place, core.CustomPlace):
         device_id = place.get_device_id()
         device_type = place.get_device_type()
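The deleted line was a verbatim duplicate of the preceding assignment, so get_device() behaves the same. For context, the CustomPlace branch visible above is where the accessors targeted by this commit appear; a sketch of its equivalent logic follows (only the two accessor calls are visible in the diff, so the exact "<type>:<id>" format string is an assumption):

def format_custom_device(place) -> str:
    # Equivalent of the CustomPlace branch of paddle.get_device():
    # these two accessors are what SOT must simulate without breaking.
    device_id = place.get_device_id()
    device_type = place.get_device_type()
    return f"{device_type}:{device_id}"  # e.g. "npu:0" (assumed format)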

python/paddle/jit/sot/opcode_translator/executor/dispatch_functions.py

Lines changed: 8 additions & 0 deletions
@@ -59,3 +59,11 @@ def tensor_dim(x):
 
 def generator_send(x):
     pass
+
+
+def place_get_device_id():
+    pass
+
+
+def place_get_device_type():
+    pass
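The bodies above are intentionally empty: these functions are never executed. They act purely as identity keys that the dispatcher matches on (see the registrations in variable_dispatch.py below). A self-contained toy of the pattern, using illustrative names rather than Paddle's internals:

# Handlers are keyed by the function object itself, not by its body.
_HANDLERS = {}

def register(key_fn, handler):
    _HANDLERS[key_fn] = handler

def dispatch(key_fn, *args):
    return _HANDLERS[key_fn](*args)

def toy_get_device_id():  # placeholder dispatch key, never called directly
    pass

class FakePlace:
    def get_device_id(self):
        return 0

register(toy_get_device_id, lambda place: place.get_device_id())
print(dispatch(toy_get_device_id, FakePlace()))  # -> 0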

python/paddle/jit/sot/opcode_translator/executor/variable_dispatch.py

Lines changed: 14 additions & 0 deletions
@@ -47,6 +47,8 @@
     operator_is_none,
     operator_is_not_none,
     operator_not_in,
+    place_get_device_id,
+    place_get_device_type,
     tensor_dim,
 )
 from .dispatcher import Dispatcher, optional
@@ -1586,3 +1588,15 @@ def dispatch_all(var: ContainerVariable | IterVariable):
         ufunc,
     ),
 )
+
+# place
+Dispatcher.register(
+    place_get_device_id,
+    ("PlaceVariable",),
+    lambda var: var.get_device_id(),
+)
+Dispatcher.register(
+    place_get_device_type,
+    ("PlaceVariable",),
+    lambda var: var.get_device_type(),
+)
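With these registrations in place, a simulated call to one of the placeholder keys on a PlaceVariable argument resolves to the corresponding lambda, which forwards to PlaceVariable.get_device_id / get_device_type (added in basic.py below); those methods run the real accessor eagerly and wrap the constant result as a tracked variable.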

python/paddle/jit/sot/opcode_translator/executor/variables/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -32,6 +32,7 @@
     NumpyVariable,
     ObjectVariable,
     ParameterVariable,
+    PlaceVariable,
     SliceVariable,
     SuperVariable,
     SymbolicVariable,

python/paddle/jit/sot/opcode_translator/executor/variables/basic.py

Lines changed: 56 additions & 3 deletions
@@ -23,6 +23,7 @@
 import numpy as np
 
 import paddle
+from paddle._typing import unreached
 from paddle.framework import core
 
 from ....infer_meta import (
@@ -54,7 +55,11 @@
     InnerError,
     UnsupportedPaddleAPIBreak,
 )
-from ..dispatch_functions import tensor_dim
+from ..dispatch_functions import (
+    place_get_device_id,
+    place_get_device_type,
+    tensor_dim,
+)
 from ..guard import (
     FasterStringifiedExpression,
     StringifiedExpression,
@@ -1428,7 +1433,7 @@ def make_stringified_guard(self) -> None:
 
     @VariableFactory.register_from_value()
     def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
-        if isinstance(value, (np.ndarray)):
+        if isinstance(value, np.ndarray):
             return NumpyArrayVariable(value, graph, tracker)
         return None
 
@@ -1483,7 +1488,7 @@ def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
 class NumpyBoolVariable(NumpyNumberVariable):
     @VariableFactory.register_from_value()
     def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
-        if isinstance(value, (np.bool_)):
+        if isinstance(value, np.bool_):
             return NumpyBoolVariable(value, graph, tracker)
         return None
 
@@ -1519,6 +1524,54 @@ def make_stringified_guard(self) -> list[StringifiedExpression]:
         ]
 
 
+class PlaceVariable(ObjectVariable):
+    def __init__(self, obj, graph, tracker):
+        super().__init__(obj, graph, tracker)
+
+    def getattr(self, name: str, default=None):
+        if default is not None:
+            raise FallbackError(
+                "default argument for getattr is not implemented"
+            )
+        if name not in ["get_device_id", "get_device_type"]:
+            return super().getattr(name, default)
+        from .callable import BuiltinVariable
+
+        if name == "get_device_id":
+            return BuiltinVariable(
+                place_get_device_id, self.graph, DanglingTracker()
+            ).bind_dangling_fn(self, name)
+        elif name == "get_device_type":
+            return BuiltinVariable(
+                place_get_device_type, self.graph, DanglingTracker()
+            ).bind_dangling_fn(self, name)
+        unreached()
+
+    def get_device_id(self):
+        return VariableFactory.from_value(
+            self.value.get_device_id(), self.graph, DummyTracker([self])
+        )
+
+    def get_device_type(self):
+        return VariableFactory.from_value(
+            self.value.get_device_type(), self.graph, DummyTracker([self])
+        )
+
+    @VariableFactory.register_from_value()
+    def from_value(value: Any, graph: FunctionGraph, tracker: Tracker):
+        if paddle.is_compiled_with_cuda() and isinstance(
+            value, (paddle.CUDAPlace, paddle.CUDAPinnedPlace)
+        ):
+            return PlaceVariable(value, graph, tracker)
+        if paddle.is_compiled_with_xpu() and isinstance(
+            value, (paddle.XPUPlace, paddle.XPUPinnedPlace)
+        ):
+            return PlaceVariable(value, graph, tracker)
+        if isinstance(value, paddle.CustomPlace):
+            return PlaceVariable(value, graph, tracker)
+        return None
+
+
 class NullVariable(VariableBase):
     """
     NullVariable is a subclass of VariableBase used to represent a placeholder variable that has no value or reference associated with it.
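PlaceVariable intercepts only the two accessors it has dispatch keys for and routes every other attribute through the generic ObjectVariable lookup. A self-contained toy of this interception pattern (illustrative names, not Paddle's API):

class ToyPlaceProxy:
    # Only these lookups are simulated; everything else falls through.
    _SUPPORTED = ("get_device_id", "get_device_type")

    def __init__(self, place):
        self._place = place

    def __getattr__(self, name):
        if name not in self._SUPPORTED:
            raise AttributeError(name)  # stand-in for the generic fallback
        # Return a bound callable, mirroring bind_dangling_fn on the real class.
        return lambda: getattr(self._place, name)()


class FakeCustomPlace:
    def get_device_id(self):
        return 0

    def get_device_type(self):
        return "npu"


print(ToyPlaceProxy(FakeCustomPlace()).get_device_type())  # -> "npu"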

test/sot/test_sot_place.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from contextlib import contextmanager
+
+from test_case_base import (
+    TestCaseBase,
+    test_instruction_translator_cache_context,
+)
+
+import paddle
+from paddle.jit.sot.psdb import check_no_breakgraph
+
+
+@contextmanager
+def device_guard(place: str):
+    original_place = paddle.get_device()
+    try:
+        paddle.set_device(place)
+        yield
+    finally:
+        paddle.set_device(original_place)
+
+
+@check_no_breakgraph
+def run_diff_logic_by_check_expected_place(x: paddle.Tensor):
+    expected_place_str = paddle.get_device()
+    if "cpu" in expected_place_str:
+        return x + 1
+    elif "gpu" in expected_place_str:
+        return x + 2
+    elif "xpu" in expected_place_str:
+        return x + 3
+    elif "npu" in expected_place_str:
+        return x + 4
+    return x
+
+
+class TestCheckExpectedPlace(TestCaseBase):
+    def test_check_cpu(self):
+        x = paddle.to_tensor(0.0)
+        with device_guard("cpu"):
+            self.assert_results(run_diff_logic_by_check_expected_place, x.cpu())
+
+    @unittest.skipUnless(
+        paddle.is_compiled_with_cuda(),
+        "This test case needs to be compiled with CUDA",
+    )
+    def test_check_gpu(self):
+        x = paddle.to_tensor(0.0)
+        with device_guard("gpu"):
+            self.assert_results(
+                run_diff_logic_by_check_expected_place, x.cuda()
+            )
+
+    @unittest.skipUnless(
+        paddle.is_compiled_with_xpu(),
+        "This test case needs to be compiled with XPU",
+    )
+    def test_check_xpu(self):
+        x = paddle.to_tensor(0.0)
+        with device_guard("xpu"):
+            self.assert_results(
+                run_diff_logic_by_check_expected_place, x.to("xpu")
+            )
+
+
+class TestExpectedPlaceGuard(TestCaseBase):
+    @unittest.skipUnless(
+        paddle.is_compiled_with_cuda(),
+        "This test case needs to be compiled with cuda",
+    )
+    def test_expected_place_guard(self):
+        x = paddle.to_tensor(0.0)
+        with test_instruction_translator_cache_context() as ctx:
+            self.assertEqual(ctx.translate_count, 0)
+            with device_guard("cpu"):
+                self.assert_results(
+                    run_diff_logic_by_check_expected_place, x.cpu()
+                )
+            self.assertEqual(ctx.translate_count, 1)
+            with device_guard("gpu"):
+                self.assert_results(
+                    run_diff_logic_by_check_expected_place, x.cuda()
+                )
+            self.assertEqual(ctx.translate_count, 2)
+
+
+if __name__ == "__main__":
+    unittest.main()
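Note the second test class: the translated function is guarded on the expected place, so switching from cpu to gpu invalidates the cached translation and bumps translate_count from 1 to 2 rather than silently reusing the cpu-specialized trace.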
