There was an error while loading. Please reload this page.
1 parent 8d71b11 commit 6c22d40Copy full SHA for 6c22d40
source/tests/pd/model/test_deeppot.py
@@ -42,7 +42,7 @@ def setUp(self) -> None:
42
trainer = get_trainer(deepcopy(self.config))
43
trainer.run()
44
45
- with paddle.device("cpu"):
+ with paddle.device.device_guard("cpu"):
46
input_dict, label_dict, _ = trainer.get_data(is_train=False)
47
trainer.wrapper(**input_dict, label=label_dict, cur_lr=1.0)
48
self.model = "model.pd"
0 commit comments