Closed
Description
🐛 Bug
`torch_xla.runtime.global_runtime_device_count()` always returns 1, even though `GPU_NUM_DEVICES` is set to 4.
To Reproduce
Run `PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python test.py`, where `test.py` contains:

```python
import torch_xla
from torch_xla import runtime as xr

print('global_runtime_device_count: ', xr.global_runtime_device_count())
```

The output is:

```
global_runtime_device_count: 1
```

Expected behavior
I expect `xr.global_runtime_device_count()` to return the value of `GPU_NUM_DEVICES` (4 in this case). Otherwise, my understanding of what `xr.global_runtime_device_count()` reports is mistaken.
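To help triage, a minimal diagnostic sketch could print what `GPU_NUM_DEVICES` requests next to what `torch_xla.runtime` reports. The helper names (`read_requested_gpus`, `query_runtime`, `report`) are hypothetical. Apart from `global_runtime_device_count`, the queried functions (`device_type`, `addressable_runtime_device_count`, `process_count`) are assumed to exist in this torch_xla version, so each one is looked up with `getattr` and skipped if it is absent.

```python
import os


def read_requested_gpus(env=None):
    """Parse GPU_NUM_DEVICES from the environment; None if unset or not an integer."""
    env = os.environ if env is None else env
    raw = env.get("GPU_NUM_DEVICES")
    if raw is None:
        return None
    try:
        return int(raw)
    except ValueError:
        return None


def query_runtime():
    """Collect what torch_xla.runtime reports; empty dict if torch_xla is missing."""
    try:
        from torch_xla import runtime as xr
    except ImportError:
        return {}
    results = {}
    # Assumed query names: call each one only if this torch_xla version defines it.
    for name in ("device_type", "global_runtime_device_count",
                 "addressable_runtime_device_count", "process_count"):
        fn = getattr(xr, name, None)
        if fn is not None:
            results[name] = fn()
    return results


def report(requested, reported):
    """Format a one-line comparison of requested vs. reported device counts."""
    return f"GPU_NUM_DEVICES={requested}, global_runtime_device_count={reported}"


if __name__ == "__main__":
    requested = read_requested_gpus()
    info = query_runtime()
    for key, value in info.items():
        print(f"{key}: {value}")
    print(report(requested, info.get("global_runtime_device_count")))
```

It is meant to be run the same way as `test.py`, e.g. `PJRT_DEVICE=CUDA GPU_NUM_DEVICES=4 python diag.py` (`diag.py` is a hypothetical filename). Seeing `process_count` next to the global count may show whether the runtime exposes only the current process's devices, though the sketch itself cannot confirm the cause.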
Environment
- Reproducible on XLA backend [CPU/TPU]: CUDA
- torch_xla version: 2.2.0+git4038f8e
Additional context
None