Skip to content

Commit 5ee9fcb

Browse files
faaany authored and amyeroberts committed
Fix wrong xpu device in DistributedType.MULTI_XPU mode (#28386)
* remove elif xpu
* remove redundant code
1 parent e156abd commit 5ee9fcb

File tree

1 file changed

+0
-11
lines changed

1 file changed

+0
-11
lines changed

src/transformers/training_args.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1844,11 +1844,6 @@ def _setup_devices(self) -> "torch.device":
18441844
device = torch.device("cuda", local_rank)
18451845
self._n_gpu = 1
18461846
torch.cuda.set_device(device)
1847-
elif is_torch_xpu_available() and "ACCELERATE_USE_XPU" not in os.environ:
1848-
os.environ["ACCELERATE_USE_XPU"] = "true"
1849-
self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
1850-
device = torch.device("xpu:0")
1851-
self._n_gpu = 1
18521847
elif is_sagemaker_dp_enabled():
18531848
self.distributed_state = PartialState(_use_sagemaker_dp=True)
18541849
self._n_gpu = 1
@@ -1877,12 +1872,6 @@ def _setup_devices(self) -> "torch.device":
18771872
elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
18781873
# Already set _n_gpu
18791874
pass
1880-
elif self.distributed_state.distributed_type == DistributedType.MULTI_XPU:
1881-
if "ACCELERATE_USE_XPU" not in os.environ:
1882-
os.environ["ACCELERATE_USE_XPU"] = "true"
1883-
self._n_gpu = 1
1884-
device = torch.device("xpu:0")
1885-
torch.xpu.set_device(device)
18861875
elif self.distributed_state.distributed_type == DistributedType.NO:
18871876
if self.use_mps_device:
18881877
warnings.warn(

0 commit comments

Comments
 (0)