PyTorch/XLA documentation
torch_xla is a Python package that implements XLA as a backend for PyTorch.
Familiar APIs: Create and train PyTorch models on TPUs with only minimal code changes.
High Performance: Scale training jobs across thousands of TPU cores while maintaining high model FLOPs utilization (MFU).
Cost Efficient: TPU hardware and the XLA compiler are optimized for cost-efficient training and inference.
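To make "only minimal changes" concrete, here is a minimal sketch of an ordinary PyTorch training loop that targets the XLA device. It assumes `torch_xla` is installed. It uses `torch.device("xla")`, the same device string as in the verification command under Getting Started, and `torch_xla.core.xla_model.mark_step()` to execute the lazily recorded graph at each step. Newer releases may prefer other entry points, so check the API reference for your version. The CPU fallback, the tiny linear model, and the random data are illustrative assumptions, not part of PyTorch/XLA.

```python
import torch
import torch.nn as nn

# Assumption: torch_xla is installed (e.g. `pip install torch torch_xla[tpu]`).
# If it is not, fall back to CPU so the same training code still runs.
try:
    import torch_xla  # noqa: F401  (importing it registers the 'xla' device with PyTorch)
    import torch_xla.core.xla_model as xm
    device = torch.device("xla")
except ImportError:
    xm = None
    device = torch.device("cpu")


def train(steps: int = 20) -> list[float]:
    """Fit a toy linear model on random data and return the loss at each step."""
    torch.manual_seed(0)
    model = nn.Linear(10, 2).to(device)  # move the model to the chosen device
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
    loss_fn = nn.MSELoss()
    # Generate data on CPU, then move it, so values match across devices.
    x = torch.randn(8, 10).to(device)
    y = torch.randn(8, 2).to(device)

    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        if xm is not None:
            # XLA records operations lazily; mark_step() compiles and runs the pending graph.
            xm.mark_step()
        losses.append(loss.item())
    return losses


if __name__ == "__main__":
    losses = train()
    print(f"device={device} first loss={losses[0]:.4f} last loss={losses[-1]:.4f}")
```

Apart from the device choice and the per-step `mark_step()` call, the loop is standard PyTorch. That is the sense in which existing models need only small changes to run on TPUs.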
Getting Started
Install with pip.
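pip install torch torch_xla[tpu]

Verify the installation:

python -c "import torch_xla; print(torch_xla.__version__)"
python -c "import torch; import torch_xla; print(torch.tensor(1.0, device='xla').device)"

The first command prints the installed torch_xla version. The second creates a tensor on the XLA device and should print xla:0.

Tutorials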
Learn the Basics
Distributed Training on TPU
Advanced Techniques
Troubleshooting