Skip to content

Conversation

@will-cromar
Copy link
Collaborator

@will-cromar will-cromar commented Jun 13, 2024

GOOGLE_CUDA is defined in bazel by OpenXLA. This will be unset by default. Therefore, we'll use dynamic plugins by default, enabling automatic discovery of the CUDA plugin if you're using the normal/tpu build.

Tested:

# python >>> import os >>> os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' >>> os.environ['TF_CPP_VMODULE'] = 'pjrt_registry=5' >>> import torch_xla 2024-06-13 17:39:34.217260: I torch_xla/csrc/runtime/pjrt_registry.cc:70] Registering PjRt plugin NEURON 2024-06-13 17:39:34.217363: I torch_xla/csrc/runtime/pjrt_registry.cc:70] Registering PjRt plugin TPU 2024-06-13 17:39:34.217400: I torch_xla/csrc/runtime/pjrt_registry.cc:70] Registering PjRt plugin XPU >>> torch_xla.device() WARNING:root:PJRT is now the default runtime. For more information, see https://github.com/pytorch/xla/blob/master/docs/pjrt.md WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU. 2024-06-13 17:39:39.318121: I torch_xla/csrc/runtime/pjrt_registry.cc:83] Initializing client for PjRt plugin TPU WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1718300379.343529 1209222 pjrt_api.cc:99] GetPjrtApi was found for tpu at /usr/local/lib/python3.10/site-packages/libtpu/libtpu.so I0000 00:00:1718300379.343613 1209222 pjrt_api.cc:78] PJRT_Api is set for device type tpu I0000 00:00:1718300379.343620 1209222 pjrt_api.cc:145] The PJRT plugin has PJRT API version 0.54. The framework PJRT API version is 0.54. 2024-06-13 17:39:42.233641: I external/xla/xla/pjrt/pjrt_c_api_client.cc:127] PjRtCApiClient created. device(type='xla', index=0) 

Enables #7249

Nothing changes with the old CUDA build.

# PJRT_DEVICE=CUDA python Python 3.10.14 (main, May 14 2024, 08:51:34) [GCC 10.2.1 20210110] on linux Type "help", "copyright", "credits" or "license" for more information. >>> os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' Traceback (most recent call last): File "<stdin>", line 1, in <module> NameError: name 'os' is not defined >>> import os >>> os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' >>> os.environ['TF_CPP_VMODULE'] = 'pjrt_registry=5' >>> import torch_xla >>> torch_xla.devices() 2024-06-13 19:55:16.605772: I torch_xla/csrc/runtime/pjrt_registry.cc:147] Initializing PjRt GPU client... 

See #6242

@will-cromar will-cromar added runtime usability Bugs/features related to improving the usability of PyTorch/XLA tpuci labels Jun 13, 2024
@will-cromar will-cromar requested a review from JackCaoG June 13, 2024 19:56
@will-cromar will-cromar marked this pull request as ready for review June 13, 2024 19:56
@will-cromar will-cromar merged commit 1cad403 into master Jun 13, 2024
will-cromar added a commit that referenced this pull request Jun 13, 2024
will-cromar added a commit that referenced this pull request Jun 13, 2024
bhavya01 pushed a commit that referenced this pull request Jun 14, 2024
Co-authored-by: Aman Gupta <4409685+aman2930@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

runtime usability Bugs/features related to improving the usability of PyTorch/XLA

3 participants