Skip to content

Commit 30f6a8c

Browse files
committed
fix comment and test failure.
1 parent 815ba53 commit 30f6a8c

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

.circleci/common.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ function run_torch_xla_python_tests() {
131131
if [ -x "$(command -v nvidia-smi)" ]; then
132132
# These tests fail on CUDA with 03/30 TF-pin update (https://github.com/pytorch/xla/pull/4840)
133133
PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
134+
# TODO(xiowei replace gpu with cuda): remove the test below with PJRT_DEVICE=GPU because PJRT_DEVICE=GPU is being deprecated.
135+
PJRT_DEVICE=GPU python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
134136
PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --auto_wrap_policy type_based --use_small_fake_sample --num_epochs=1
135137
XLA_DISABLE_FUNCTIONALIZATION=1 PJRT_DEVICE=CUDA python test/test_train_mp_imagenet_fsdp.py --fake_data --use_nested_fsdp --use_small_fake_sample --num_epochs=1
136138
# Syncfree SGD optimizer tests

torch_xla/core/xla_model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,11 @@ def get_xla_supported_devices(devkind=None, max_devices=None):
8989
Returns:
9090
The list of device strings.
9191
"""
92-
# TODO(xiowei): Remove the below if statement after r2.2 release.
93-
if devkind.casefold() == 'gpu':
94-
warnings.warn("GPU as a device name is being deprecate. Please replace it with CUDA such as get_xla_supported_devices(devkind='CUDA'). Similarly, please replace PJRT_DEVICE=GPU with PJRT_DEVICE=CUDA.")
92+
# TODO(xiowei replace gpu with cuda): Remove the below if statement after r2.2 release.
93+
if devkind and devkind.casefold() == 'gpu':
94+
warnings.warn(
95+
"GPU as a device name is being deprecate. Please replace it with CUDA such as get_xla_supported_devices(devkind='CUDA'). Similarly, please replace PJRT_DEVICE=GPU with PJRT_DEVICE=CUDA."
96+
)
9597
devkind = 'CUDA'
9698

9799
xla_devices = _DEVICES.value

torch_xla/runtime.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,12 @@ def xla_device(n: Optional[int] = None,
107107
Returns:
108108
A `torch.device` representing an XLA device.
109109
"""
110-
# TODO(xiowei): Remove the warning message at r2.2 release.
110+
# TODO(xiowei replace gpu with cuda): Remove the warning message at r2.2 release.
111111
pjrt_device = xu.getenv_as(xenv.PJRT_DEVICE, str)
112112
if pjrt_device.casefold() == 'gpu':
113-
warnings.warn('PJRT_DEVICE=GPU is being deprecate. Please replace PJRT_DEVICE=GPU with PJRT_DEVICE=CUDA.')
113+
warnings.warn(
114+
'PJRT_DEVICE=GPU is being deprecate. Please replace PJRT_DEVICE=GPU with PJRT_DEVICE=CUDA.'
115+
)
114116

115117
if n is None:
116118
return torch.device(torch_xla._XLAC._xla_get_default_device())

0 commit comments

Comments
 (0)