The new xm.optimization_barrier API breaks the gradient flow #3486

@ronghanghu

Description

🐛 Bug

The new xm.optimization_barrier API introduced in #3482 provides a great feature to avoid XLA compiler fusion between different parts of the graph (e.g. the forward and backward passes) -- very useful for gradient checkpointing applications such as the one in #3455.

However, applying the xm.optimization_barrier API leads to incorrect results in many cases, so further inspection seems needed here.

For example, it breaks the MNIST example. In a correct training run, MNIST should reach 98%+ accuracy in 2 epochs. However, when output, = xm.optimization_barrier([output]) is applied to the model output with this API, MNIST training does not converge. In fact, no training happens at all: every model parameter's .grad is always None in this case.
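As an aside, the symptom (every parameter's .grad staying None after loss.backward()) is exactly what one would observe if the barrier effectively detached its output from the autograd graph. Below is a minimal pure-PyTorch sketch of that failure mode, using .detach() as a hypothetical stand-in for the barrier's observed behavior -- this is an illustration of the symptom, not the actual implementation of xm.optimization_barrier:

```python
import torch

model = torch.nn.Linear(4, 2)
x = torch.randn(3, 4)

# Normal case: gradients flow back to the parameters.
out = model(x)
out.sum().backward()
assert model.weight.grad is not None

# Simulated symptom: an op that cuts the output off from the autograd
# graph (here .detach(), a hypothetical stand-in for the reported
# barrier behavior) leaves every parameter's .grad as None.
model.zero_grad(set_to_none=True)
out = model(x).detach()  # gradient flow is cut here
# Calling out.sum().backward() would now raise, since `out` no longer
# requires grad; the parameters' .grad stays None, as in the bug report.
assert out.requires_grad is False
assert model.weight.grad is None
```

This matches the observation below that the parameters never move from their initialized values.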

To Reproduce

  1. Get a v3-8 TPU VM with the tpu-vm-pt-1.10 runtime environment.
  2. Install the nightly PyTorch XLA build containing this API:
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-nightly+20220408-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torchvision-nightly+20220408-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-nightly+20220408-cp38-cp38-linux_x86_64.whl
sudo pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20220408-py3-none-any.whl
  3. Clone the PyTorch XLA repo containing the official PyTorch XLA MNIST example test_train_mp_mnist.py and download the new API example test_train_mp_mnist_with_optimization_barrier.py:
git clone https://github.com/pytorch/xla.git
cd xla/test
# test_train_mp_mnist_with_optimization_barrier.py -- new API example
wget https://gist.githubusercontent.com/ronghanghu/74f103f79df3c5c6df2807d12506d6c7/raw/265f4377fb051ba7f799184e9a70bcffac8a1cde/test_train_mp_mnist_with_optimization_barrier.py

Note: the only difference between the two scripts is that test_train_mp_mnist_with_optimization_barrier.py adds output, = xm.optimization_barrier([output]) on the model output.

ronghanghu@t1v-n-d5308e1f-w-0:~$ diff -p test_train_mp_mnist.py test_train_mp_mnist_with_optimization_barrier.py
*** test_train_mp_mnist.py	2022-04-09 06:49:39.712446418 +0000
--- test_train_mp_mnist_with_optimization_barrier.py	2022-04-09 06:49:44.728872258 +0000
*************** def train_mnist(flags, **kwargs):
*** 125,130 ****
--- 125,131 ----
    for step, (data, target) in enumerate(loader):
      optimizer.zero_grad()
      output = model(data)
+     output, = xm.optimization_barrier([output])
      loss = loss_fn(output, target)
      loss.backward()
      xm.optimizer_step(optimizer)
  4. Run the two examples above with --batch_size 16 --drop_last --num_epochs 2 and compare their training accuracies.

The official PyTorch XLA MNIST example with 2 training epochs

python3 -u test_train_mp_mnist.py --batch_size 16 --drop_last --num_epochs 2 

gives

...
Epoch 1 test end 07:03:06, Accuracy=98.68
...
Epoch 2 test end 06:59:32, Accuracy=98.94
Max Accuracy: 98.94%

as expected.

The new API example with 2 training epochs

python3 -u test_train_mp_mnist_with_optimization_barrier.py --batch_size 16 --drop_last --num_epochs 2 

gives

...
Epoch 1 test end 07:00:38, Accuracy=3.58
...
Epoch 2 test end 07:00:48, Accuracy=3.58
Max Accuracy: 3.58%
Accuracy 3.58375 is below target 98.0

which shows that the model doesn't converge.

It seems that the new xm.optimization_barrier API breaks the gradient flow -- the accuracies at epoch 1 and epoch 2 are both exactly 3.58. Further inspection shows that all the model parameters remained at their initialized values and that their .grad is always None.

Expected behavior

The training accuracy should be the same between the two cases since xm.optimization_barrier should not change the computational results.

Environment

  • Reproducible on XLA backend [CPU/TPU]: v3-8 TPU VM
  • torch_xla version: nightly+20220408 (see details above)

cc: @JackCaoG @ultrons
