TL;DR
- In most cases, replace `mark_step` with a context manager/decorator `torch_xla.step` to explicitly mark code that should be traced and then executed.
- This both ensures that `mark_step` is always executed upon exit and provides an explicit context in which to raise errors.
Introduction
Users should be able to pick up and use PyTorch/XLA without having to understand the execution model. As much as possible, we should bury the implementation details such that you can wrap existing PyTorch code and have it "just work".
xm.mark_step is the most obvious example where we directly expose implementation details of how PyTorch/XLA works to the user: we actually require users to decide where to place synchronization barriers. These manual barriers are not required in JAX and TensorFlow, even though both of them implement a similar lazy execution model.
In practice, our current solution is to hide xm.mark_step in our preloading data loader implementation that calls it 1) after each batch is loaded and 2) at the end of iteration. If a user is sufficiently careful, they don't have to see this implementation detail at all. Take the following training loop for example:
```python
for batch in loader:
    # Run model
    xm.optimizer_step()  # Note: this may be optimizer.step() when using DDP
```

Training loops like the one above will run without issue. However, even slight deviations from this pattern can cause inscrutable problems such as the ones below.
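For reference, what the loader does on the user's behalf is roughly the following. This is a simplified sketch, not `MpDeviceLoader`'s actual implementation, which also overlaps host-to-device transfers:

```python
import torch_xla.core.xla_model as xm


class SimpleXlaLoader:
    """Illustrative only: wraps a CPU loader and inserts mark_step for the user."""

    def __init__(self, loader, device):
        self._loader = loader
        self._device = device

    def __iter__(self):
        for data, target in self._loader:
            # 1) Barrier after each batch is loaded: flush whatever graph the
            #    user traced during the previous step.
            xm.mark_step()
            yield data.to(self._device), target.to(self._device)
        # 2) Final barrier at the end of iteration.
        xm.mark_step()
```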
Weird xm.mark_step behaviors
Checkpointing before xm.mark_step
Take the following common pattern, where the master saves a checkpoint every n steps:
```python
for batch in loader:
    # Run model
    xm.optimizer_step()
    if is_master and step % 100 == 0:
        xm.save(model)
```

Can you spot the error? The above code will hang forever without printing an error at all, because the execution diverges between the master and non-master replicas. Compare the order of operations:
| Replica 0 | Others |
| --- | --- |
| load data<br>mark step<br>run step<br>checkpoint, hang forever waiting for other replicas | load data<br>mark step<br>run step<br>load data<br>mark step, hang forever waiting for replica 0 |
Concretely, a version of this same bug took hours to debug when it appeared in both Accelerate and Lightning. The solution in each case was to add an additional xm.mark_step on replica 0 before the checkpoint.
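Concretely, the workaround in both cases amounted to something like this (a sketch mirroring the pseudocode above, not the exact Accelerate/Lightning patch):

```python
for batch in loader:
    # Run model
    xm.optimizer_step()
    if is_master and step % 100 == 0:
        # Force replica 0 to execute its pending graph here so it stays in
        # lockstep with the other replicas before it blocks on the save.
        xm.mark_step()
        xm.save(model)
```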
Logging before xm.mark_step
Let's take a similar example where we log the loss of a model every n steps (either to the terminal or to TensorBoard):
```python
for batch in loader:
    # Run model
    loss = ...
    xm.optimizer_step()
    if is_master and step % 100 == 0:
        print(step, loss)
```

Although the above code looks innocent, it will (at the very least) lead to a substantial performance degradation: the forward pass will actually run twice on the master. When you move a tensor to the CPU before a mark step, we immediately run any pending operations on that tensor, but we don't cache the result.[^1] The forward pass (along with the backward pass) will run again when the loader calls mark_step.
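The safe way to do this today is to defer the print with a step closure (the same mechanism used in the ResNet example later in this proposal), so the value is only read after mark_step has executed the graph. A sketch:

```python
def _log(step, loss):
    # Step closures run after mark_step, once `loss` has been materialized.
    print(step, loss.item())

for batch in loader:
    # Run model
    loss = ...
    xm.optimizer_step()
    if is_master and step % 100 == 0:
        xm.add_step_closure(_log, args=(step, loss))
```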
Returning early before xm.mark_step
Because we rely on the MpDeviceLoader iterator for inserting mark_steps, breaking early or raising an exception becomes a dangerous operation. Take the following example with early stopping:
```python
for batch in loader:
    # Run model
    loss = ...
    if loss < target:
        break
```

When the loop exits, the graph will include the entire forward pass of the model, which will get run the next time the user calls mark_step. This becomes problematic if the user then adds another large executable to the graph before that mark_step runs. This was a latent bug with the HuggingFace transformers.Trainer and was only fixed recently; in that case, it led to OOMs during checkpointing.
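Until there is an explicit context whose exit handles this, the defensive fix is to cut the graph before leaving the loop:

```python
for batch in loader:
    # Run model
    loss = ...
    if loss < target:
        # Execute the pending forward pass now instead of leaving it in the
        # graph for the next, unrelated mark_step to pick up.
        xm.mark_step()
        break
```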
xm.mark_step inside profile scope
The above examples deal with cases where we effectively require a mark_step. So let's say the user does add the mark_steps they need manually, but they are also profiling their code. In this case, they are going to be susceptible to the mistake below:
```python
with xp.Trace('loss'):
    loss = ...
    xm.mark_step()
    if is_master and step % 100 == 0:
        print(step, loss)
```

Running code like this will result in an error like this one: `RuntimeError: Expecting scope to be empty but it is train_loop`.
We're essentially taking an indentation mistake (putting the log and the mark_step inside of a Trace) and turning it into a flat-out error with an unclear error message.
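The fix is a one-line indentation change: close the Trace scope before calling mark_step.

```python
with xp.Trace('loss'):
    loss = ...

# All profiling scopes are closed, so the barrier is safe here.
xm.mark_step()
if is_master and step % 100 == 0:
    print(step, loss)
```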
Proposal
The "correct" way to avoid most of the above errors is to use a step closure to defer an operation until after the xm.mark_step runs. This again exposes implementation details directly to the user. We should not be framing our API to the user such that they have to carefully think through the execution order of their code, particularly because our execution model diverges from what most people are used to in torch.
Fundamentally, the problem here is that what you can and cannot call at a given line depends on global state that is effectively invisible. Because of our API's imperative structure, we cannot tell what code falls between mark_steps, so we can neither select the correct behavior nor raise a clear error message. Python gives us two tools to mark a limited context explicitly: context managers and decorators.[^2]
My proposal is simple: create a decorator/context manager[^3] named something like torch_xla.step that explicitly marks out a segment of code that should be traced. If a user accidentally does something that they shouldn't (like moving a tensor to the CPU) in this context, we can loudly raise an error.
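A minimal sketch of the shape of that API, built on contextlib.ContextDecorator (see footnote 3). The strict-mode behavior is left as a comment because the mechanism for detecting unsafe operations is exactly the open design question:

```python
import contextlib

import torch_xla.core.xla_model as xm


class step(contextlib.ContextDecorator):
    """Sketch only: trace the enclosed code, then execute it on exit."""

    def __enter__(self):
        # Hypothetical: flip internal state so that unsafe operations (e.g.
        # moving an XLA tensor to the CPU) raise a loud error instead of
        # silently executing the partial graph.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Runs even on break/return/exception, so the barrier is never skipped.
        xm.mark_step()
        return False
```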
Example code
The following example code is modified from our ResNet example code.
Before:
```python
for step, (data, target) in enumerate(loader):
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)
    xm.add_step_closure(
        train_update, args=(device, step, loss, tracker, epoch, writer))
```

After:
```python
for step, (data, target) in enumerate(loader):
    with torch_xla.step():
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)

    train_update(device, step, loss, tracker, epoch, writer)
```

If the user puts train_update inside of torch_xla.step, we can raise an error early when moving the XLA tensor for loss to the CPU. The same would be true if they tried checkpointing there. Likewise, if they return early from the loop, the context manager's `__exit__()` will still be called.
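Because ContextDecorator also works as a decorator, the same proposed API could wrap a step function directly; a hypothetical usage:

```python
@torch_xla.step()
def train_step(data, target):
    optimizer.zero_grad()
    output = model(data)
    loss = loss_fn(output, target)
    loss.backward()
    xm.optimizer_step(optimizer)
    return loss  # mark_step runs as the decorated call returns

for step, (data, target) in enumerate(loader):
    loss = train_step(data, target)
    train_update(device, step, loss, tracker, epoch, writer)
```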
Prior work: xp.StepTrace
We already have a similar context manager in our API: StepTrace, which calls mark_step when it exits. Since it's focused on profiling, StepTrace is rarely used in our documentation and examples. I want to take this idea further. Most importantly, torch_xla.step should be used as a context to proactively raise errors for incorrect or dangerous operations, and we should start using it in documentation that uses custom training loops.
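For reference, StepTrace is typically used roughly like this in our profiling examples:

```python
import torch_xla.debug.profiler as xp

for step, (data, target) in enumerate(loader):
    with xp.StepTrace('train_loop', step_num=step):
        # Run model; StepTrace calls mark_step when this scope exits.
        ...
```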
We still need xm.mark_step
We can't eliminate mark_step entirely. torch_xla.step will be useful for training loops that are written with PyTorch/XLA in mind. It may be less useful for use with higher-level frameworks that control the training loop. It's generally harder to insert a new context manager without modifying code than to insert a single call to xm.mark_step. Take our HuggingFace Diffusers example:
```python
image = pipeline(prompt, callback=lambda *args: xm.mark_step(), generator=generator)
```

mark_step is a useful callback to pass in this case.
Interaction with MpDeviceLoader
We use MpDeviceLoader extensively throughout our documentation, so torch_xla.step needs to be compatible with it.[^4] Since mark_step doesn't affect the host-to-device transfers that MpDeviceLoader starts, the effect of adding an extra mark_step will be minimal. If there were an unexpected interaction, there is already a mechanism to prevent MpDeviceLoader from dispatching computations, although I'd like to avoid that. We should reduce the coupling between our features such that they can be added gradually as needed.
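In practice, that means a loop like the following should keep working. This is a sketch assuming torch_xla.step as proposed above; its extra barrier leaves little for the loader's own mark_step to do:

```python
import torch_xla.distributed.parallel_loader as pl

mp_loader = pl.MpDeviceLoader(loader, device)

for step, (data, target) in enumerate(mp_loader):
    with torch_xla.step():
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        xm.optimizer_step(optimizer)
```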
Closing thoughts
This is not a perfect solution. The problem as I see it is that the user must keep in mind some context about the XLA state, but we make that state invisible. My proposal here is to instead make that state visible. Explicit is better than implicit.
I'd rather the user not have to think about XLA at all, but I don't see a good way to do that entirely. I'm absolutely open to other ideas and ways of thinking about the problems above.
- Can this be combined with our `torch.compile` backend somehow?
- Can we come up with a better name? I'll happily update this proposal. Naming is hard.
- What other `mark_step` horror stories did I miss?
Footnotes
[^1]: There are good reasons for this: caching intermediate results will make it harder to effectively cache the whole executable.
[^2]: TensorFlow and JAX have a similar challenge of translating an eager-looking API based on NumPy into an effectively lazy one, and decorators are the path that both of them take. See `jax.jit` and `tensorflow.function`.
[^3]: Python makes it easy to combine these two into one implementation: https://docs.python.org/3/library/contextlib.html#contextlib.ContextDecorator
[^4]: In my experimentation, `MpDeviceLoader` doesn't make much (if any) difference to MP workloads because transfers are now async anyway. That's a can of worms for another issue/proposal.