Skip to content

Commit 815ba53

Browse files
committed
add warning message.
1 parent cadccba commit 815ba53

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

torch_xla/core/xla_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import re
66
import threading
77
import time
8+
import warnings
89
from typing import List, Optional
910
import torch
1011
import torch.distributed._functional_collectives
@@ -88,6 +89,11 @@ def get_xla_supported_devices(devkind=None, max_devices=None):
8889
Returns:
8990
The list of device strings.
9091
"""
92+
# TODO(xiowei): Remove the below if statement after r2.2 release.
93+
if devkind.casefold() == 'gpu':
94+
warnings.warn("GPU as a device name is being deprecate. Please replace it with CUDA such as get_xla_supported_devices(devkind='CUDA'). Similarly, please replace PJRT_DEVICE=GPU with PJRT_DEVICE=CUDA.")
95+
devkind = 'CUDA'
96+
9197
xla_devices = _DEVICES.value
9298
devkind = [devkind] if devkind else [
9399
'TPU', 'GPU', 'XPU', 'NEURON', 'CPU', 'CUDA', 'ROCM'

torch_xla/runtime.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ def xla_device(n: Optional[int] = None,
107107
Returns:
108108
A `torch.device` representing an XLA device.
109109
"""
110+
# TODO(xiowei): Remove the warning message at r2.2 release.
111+
pjrt_device = xu.getenv_as(xenv.PJRT_DEVICE, str)
112+
if pjrt_device.casefold() == 'gpu':
113+
warnings.warn('PJRT_DEVICE=GPU is being deprecate. Please replace PJRT_DEVICE=GPU with PJRT_DEVICE=CUDA.')
114+
110115
if n is None:
111116
return torch.device(torch_xla._XLAC._xla_get_default_device())
112117

0 commit comments

Comments
 (0)