[RFC] Flatten basic APIs #6399

@will-cromar

TL;DR

Move some basic functionality to the top-level torch_xla module, including:

  • Getting and counting addressable XLA devices
  • Launching distributed training in various environments
  • Synchronizing the graph

Introduction

The ergonomics of PyTorch/XLA could use some work.

Reviewing our Getting Started diffs shows that converting even the most basic PyTorch code to use XLA requires 3 imports (without torch.distributed) or 4 imports (with torch.distributed), and the fully qualified names of even basic functions are excessively verbose. For example, just getting the XLA device requires a name that mentions XLA three times: torch_xla.core.xla_model.xla_device.
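For reference, this is roughly what the import boilerplate looks like today (a sketch based on the current examples; the exact set of imports varies by example):

```python
# Typical imports in today's multiprocess example (exact set varies by example):
import torch_xla.core.xla_model as xm                    # xm.xla_device(), xm.mark_step(), ...
import torch_xla.distributed.parallel_loader as pl       # pl.MpDeviceLoader
import torch_xla.distributed.xla_multiprocessing as xmp  # xmp.spawn()

device = xm.xla_device()  # "xla" appears three times in the fully qualified name
```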

This is most noticeable to me when I'm prototyping a new script or debugging in an interactive REPL or notebook. Although we have limited concrete feedback from new users, I expect the verbosity is a deterrent to someone seeing PyTorch/XLA for the first time.

Objectives

My proposal is simple: pull fundamental functionality to the top-level torch_xla package. I'll define fundamental functionality as the following:

  • The basic functions that you could not reasonably use PyTorch/XLA without
  • The first functions you run when you open a new REPL or notebook

We can leave more advanced functions organized into lower-level packages. The discussion of usability within the PyTorch/XLA team more or less comes down to how we expose complexity gradually to users, rather than dumping all of the complexity at once. Users can then dig deeper as needed to improve performance (e.g. data preloading, SPMD, etc) once they're up and running. There are more proposals coming in that vein. The goal of this proposal specifically is to make the basics clear and accessible.

I'm intentionally taking a conservative approach here. It's much easier to add a visible function to our public API than remove it. For that reason, we should start with the APIs we're most confident will remain consistent in the long-term.

Proposal

The easy one: torch_xla.device()

There is no way to construct a PyTorch/XLA program without actually moving data and modules to the XLA device. So, there's no reason to bury this function in our API. torch_xla.device would effectively just be an alias to the existing xm.xla_device(), minus the deprecated devkind argument.
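As a rough sketch (not the final implementation), the alias could be as thin as the following, living in torch_xla/__init__.py:

```python
import torch
import torch_xla.core.xla_model as xm


def device() -> torch.device:
  # Thin wrapper around xm.xla_device(), without the deprecated `devkind` argument.
  return xm.xla_device()
```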

Listing and counting devices

Unlike xm.xla_device(), these functions are not strictly required to write a PyTorch/XLA program, but they are a first-line debugging tool that users can use to inspect their environment. I propose the following new functions (a rough sketch in terms of today's APIs follows the list):

  • torch_xla.devices(): returns a list of local torch.devices, e.g. ['xla:0', ..., 'xla:3']. Again, this is effectively a more convenient alias to xm.get_xla_supported_devices().
  • torch_xla.real_devices(): returns a list of strings representing the "real" underlying devices available, e.g. ['TPU:0', ..., 'TPU:3']. This would be similar to xm.xla_real_devices(xm.get_xla_supported_devices()), except that it would return the actual hardware devices even in the SPMD case, i.e. a list of TPU:i instead of just SPMD:0. In practice, I view this as a debugging tool to check which devices are addressable in the current process.
  • torch_xla.device_count(): Essentially len(torch_xla.real_devices()). There are a ton of different counts and indices available in torch_xla.runtime; this one corresponds to the "addressable runtime device" count, and I believe it is the most intuitive one to expose as torch_xla.device_count(). If most users actually have to learn the difference between "local devices", "addressable devices", etc., I believe we have made a mistake.
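Here is a rough sketch of the three helpers in terms of today's APIs (the SPMD behavior of real_devices() described above would need additional handling, which is not shown):

```python
from typing import List

import torch
import torch_xla.core.xla_model as xm


def devices() -> List[torch.device]:
  # Local XLA devices addressable by this process, e.g. [xla:0, ..., xla:3].
  return [torch.device(d) for d in xm.get_xla_supported_devices()]


def real_devices() -> List[str]:
  # Underlying hardware devices, e.g. ['TPU:0', ..., 'TPU:3'].
  return xm.xla_real_devices(xm.get_xla_supported_devices())


def device_count() -> int:
  # Number of addressable runtime devices in this process.
  return len(real_devices())
```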

torch_xla.launch(): a better xmp.spawn()

Two years ago, there was only one supported way to launch PyTorch/XLA: xmp.spawn. Now, there are three ways:

  • xmp.spawn for automatic multiprocess execution from a Python process.
  • torchrun for multiprocess execution from a shell (required for multi-host GPU). Requires slightly more configuration than the above, such as setting the master address and the number of devices manually.
  • Single-process for SPMD

Each of these ways requires a slightly different script, making it difficult to switch between them. You can't, for example, use xmp.spawn in conjunction with torchrun. This makes it harder than it should be to experiment with different modes of execution.

I propose to wrap all three cases under an optional utility function, torch_xla.launch (a rough dispatch sketch follows the list). Specifically:

  • Default to emulating xmp.spawn, including automatic device discovery and environment configuration, plus some improvements:
    • If in a notebook environment like Kaggle, detect that and use spawn_method='fork' automatically.
    • Add a flag to launch single-process execution with portable XLA executables for debugging. This would be a better version of the nprocs=1 case in xmp.spawn, one where collectives do not hang indefinitely on TPU v2 and v3. (Footnote: today we hack around that for all_reduce by inserting a no-op.)
  • When using SPMD and/or torchrun, don't launch any child processes, and assume environment configuration is already done.
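A rough dispatch sketch, using today's APIs; the environment checks (xr.is_spmd(), torchrun's TORCHELASTIC_RUN_ID variable, the IPython heuristic) and the flag name are purely illustrative, and the real detection logic and signature are open questions:

```python
import os
import sys

import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.runtime as xr


def launch(fn, args=(), debug_single_process=False):
  if xr.is_spmd() or 'TORCHELASTIC_RUN_ID' in os.environ:
    # SPMD or torchrun: the process layout and environment are already set up,
    # so just run the function in the current process.
    fn(0, *args)
  elif debug_single_process:
    # Hypothetical flag for single-process debugging; the portable-executable
    # behavior proposed above is not shown here.
    xmp.spawn(fn, args=args, nprocs=1)
  else:
    # Notebook environments (e.g. Kaggle) can't re-spawn the interpreter,
    # so fall back to fork there. This detection heuristic is illustrative only.
    start_method = 'fork' if 'IPython' in sys.modules else 'spawn'
    xmp.spawn(fn, args=args, start_method=start_method)
```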

Register the torch.distributed backend

Right now, registering the xla backend and init_method for torch.distributed requires an extra import. Since registration is not destructive, we can safely perform it when torch_xla is imported, saving users an import.
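Concretely, this is the import the proposal would make unnecessary (a sketch of the user-facing change, not of the registration code itself):

```python
# Today: an extra import is needed purely for its registration side effect.
import torch.distributed as dist
import torch_xla.distributed.xla_backend  # noqa: F401  registers the "xla" backend and "xla://"

dist.init_process_group("xla", init_method="xla://")

# Proposed: `import torch_xla` would perform the same registration, so the
# torch_xla.distributed.xla_backend import above is no longer needed.
```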

Graph synchronization

This is the proposal I'm least confident about. We've had many discussions within the team about how we can hide the details of lazy execution and synchronization from users in the most common cases. That is still one of my top usability goals. But I don't think we can avoid users having to occasionally trigger a sync manually, and adding syncs will continue to be a primary way we debug misbehaving code. Even if I am wrong about that (and I really hope I am), mark_step will continue to be a fundamental function for framework authors, including us as well as wrappers like PyTorch Lightning and HuggingFace Accelerate.

Although I begrudgingly accept that we must add mark_step to the top-level API, we can do a much better job with the name. It was not at all clear to me at least what mark_step was supposed to mean when I started with PyTorch/XLA. I propose we name the top-level function torch_xla.sync(). It's the term we use for the operation internally within our code, and I think it's a term that is much more familiar to software developers in general.

I also propose one additional behavior change between torch_xla.sync() and xm.mark_step(): set wait to True by default, plus implement @yeounoh's suggestion to call wait_device_ops in that case. Actually blocking the current thread until all operations are complete is a much more intuitive default behavior in my opinion.

Updated 2024/1/30: wait should not be True by default. See the discussion below.
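A minimal sketch of the proposed wrapper, reflecting the updated default (wait=False) and @yeounoh's suggestion for the blocking case:

```python
import torch_xla.core.xla_model as xm


def sync(wait: bool = False):
  # Cut the current lazy trace and dispatch it for execution.
  xm.mark_step(wait=wait)
  if wait:
    # Block the calling thread until all dispatched device work has finished.
    xm.wait_device_ops()
```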

New convention: import torch_xla as xla

We can update our docs and examples to start using xla as an alias for our top-level module. In context, the torch part is implicit anyway, and functions read more naturally this way, e.g. xla.device(), xla.launch(), xla.sync().

Example

Using the changes proposed above, let's update the torch.distributed-based "Getting started" example from our README:

```diff
 import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla as xla

 def _mp_fn(rank):
   ...
-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  dist.init_process_group("xla", init_method='xla://')

-  model = model.to(rank)
+  model.to(xla.device())
-  ddp_model = DDP(model, device_ids=[rank])
+  # `gradient_as_bucket_view=True` required for XLA
+  ddp_model = DDP(model, gradient_as_bucket_view=True)

   for inputs, labels in train_loader:
     optimizer.zero_grad()
-    outputs = ddp_model(inputs)
+    outputs = ddp_model(inputs.to(xla.device()))
-    loss = loss_fn(outputs, labels)
+    loss = loss_fn(outputs, labels.to(xla.device()))
     loss.backward()
     optimizer.step()
+    xla.sync()

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  xla.launch(_mp_fn)
```

Note that with the proposed changes to xla.launch, this would work with SPMD, torchrun, interactive notebooks, and "normal" multiprocessing without changes.

Final notes

Although some of the usability problems in PyTorch/XLA run deep, cleaning up our basic APIs is an easy way we can provide a better impression to new and potential users.

I intentionally omitted two APIs that are present in our other "Getting started" example: pl.MpDeviceLoader and xm.optimizer_step(). I did this for two reasons:

  • I believe we should be leaning into torch.distributed, which can wrap the low-level details of collective operations, including gradient collection. Users should be able to pick up PyTorch/XLA without learning the details of XLA.
  • The PyTorch/XLA team has been having active discussions about how we can re-architect our basic training loop to be more intuitive and less error-prone. MpDeviceLoader is tightly coupled to this topic, because it is the current synchronization point in our loop. Since this may change in the near future, I don't want to solidify the API for that yet. But, I do believe that data pre-loading should be part of the torch_xla module since it's absolutely necessary for good performance on TPUs.

cc @JackCaoG @miladm @alanwaketan
