Checkpointing
When training a PyTorch model with Accelerate, you may often want to save the state of training and resume it later. Doing so requires saving and loading the model, the optimizer, the random number generator (RNG) states, and the GradScaler. Accelerate provides two convenience functions to do this quickly:
- Use save_state() to save everything mentioned above to a folder location
- Use load_state() to load everything stored by an earlier save_state() call
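RNG states are part of the saved state because restoring a generator's state makes the random draws that follow identical to the ones made right after the save point. The sketch below illustrates this with Python's standard random module only; it is an analogy, not Accelerate's implementation, and the helpers save_rng_state and load_rng_state are hypothetical.

```python
import pickle
import random
import tempfile
from pathlib import Path


def save_rng_state(path: Path) -> None:
    """Hypothetical helper: write the global `random` state to `path`."""
    with path.open("wb") as f:
        pickle.dump(random.getstate(), f)


def load_rng_state(path: Path) -> None:
    """Hypothetical helper: restore the global `random` state from `path`."""
    with path.open("rb") as f:
        random.setstate(pickle.load(f))


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "rng_state.pkl"
    random.seed(0)
    save_rng_state(ckpt)
    first_run = [random.random() for _ in range(3)]

    # Advance the generator further, as continued training would.
    random.random()

    # Restoring the saved state replays exactly the same draws.
    load_rng_state(ckpt)
    resumed_run = [random.random() for _ in range(3)]

print(first_run == resumed_run)  # True
```

Without the restore step, resumed_run would continue from wherever the generator happened to be, so shuffling, dropout masks, and augmentations after resuming would differ from an uninterrupted run.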
To further customize where and how states are saved through save_state(), use the ProjectConfiguration class. For example, if automatic_checkpoint_naming is enabled, each saved checkpoint will be located at Accelerator.project_dir/checkpoints/checkpoint_{checkpoint_number}.
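The folder layout described above can be sketched as a small path function. This is a minimal illustration of the documented naming scheme, assuming checkpoint numbers start at 0 and increase by one per save; checkpoint_dir is a hypothetical helper, not part of Accelerate. PurePosixPath is used only so the printed paths look the same on every operating system.

```python
from pathlib import PurePosixPath


def checkpoint_dir(project_dir: str, checkpoint_number: int) -> PurePosixPath:
    """Hypothetical helper returning the folder for one checkpoint.

    Follows the documented layout:
    project_dir/checkpoints/checkpoint_{checkpoint_number}
    """
    return PurePosixPath(project_dir) / "checkpoints" / f"checkpoint_{checkpoint_number}"


for n in range(3):
    print(checkpoint_dir("my/save/path", n))
# my/save/path/checkpoints/checkpoint_0
# my/save/path/checkpoints/checkpoint_1
# my/save/path/checkpoints/checkpoint_2
```

In Accelerate itself, automatic naming is switched on by passing ProjectConfiguration(project_dir="my/save/path", automatic_checkpoint_naming=True) to Accelerator(project_config=...), after which save_state() can be called without a folder argument.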
Note that the saved states are expected to come from the same training script; do not load states that were saved by a different script.
- By using register_for_checkpointing(), you can register custom objects to be automatically stored and loaded by the two functions above, as long as the object has state_dict and load_state_dict methods. This could include objects such as a learning rate scheduler.
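Any object exposing those two methods qualifies. The sketch below shows a hypothetical RunningLossTracker (not part of Accelerate or PyTorch) that keeps a running average of the loss; state_dict returns plain serializable values and load_state_dict restores them, which is the round trip a checkpoint relies on.

```python
from typing import Any, Dict


class RunningLossTracker:
    """Hypothetical custom object that tracks a running average of the loss.

    It exposes state_dict() and load_state_dict(), the two methods a custom
    object needs for register_for_checkpointing().
    """

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, loss: float) -> None:
        self.total += loss
        self.count += 1

    @property
    def average(self) -> float:
        return self.total / self.count if self.count else 0.0

    def state_dict(self) -> Dict[str, Any]:
        # Return only plain, serializable values.
        return {"total": self.total, "count": self.count}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.total = state["total"]
        self.count = state["count"]


tracker = RunningLossTracker()
for loss in (0.9, 0.7, 0.5):
    tracker.update(loss)

# Simulate a save/load round trip into a fresh object.
restored = RunningLossTracker()
restored.load_state_dict(tracker.state_dict())
print(restored.count, round(restored.average, 2))  # 3 0.7
```

In a training script, calling accelerator.register_for_checkpointing(tracker) before save_state() would include the tracker in every checkpoint.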
Below is a brief example using checkpointing to save and reload a state during training:
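```python
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
import torch

# Assumes my_model, my_optimizer, my_training_dataloader, my_loss_function
# and num_epochs are defined earlier in the script.

# Automatic checkpoint naming lets save_state() be called without a folder:
# each call writes to my/save/path/checkpoints/checkpoint_{n}
config = ProjectConfiguration(project_dir="my/save/path", automatic_checkpoint_naming=True)
accelerator = Accelerator(project_config=config)

my_scheduler = torch.optim.lr_scheduler.StepLR(my_optimizer, step_size=1, gamma=0.99)
my_model, my_optimizer, my_training_dataloader = accelerator.prepare(
    my_model, my_optimizer, my_training_dataloader
)

# Register the LR scheduler so save_state()/load_state() include it
accelerator.register_for_checkpointing(my_scheduler)

# Save the starting state (written to my/save/path/checkpoints/checkpoint_0)
accelerator.save_state()

# Perform training; prepare() already places the model and batches on accelerator.device
for epoch in range(num_epochs):
    for batch in my_training_dataloader:
        my_optimizer.zero_grad()
        inputs, targets = batch
        outputs = my_model(inputs)
        loss = my_loss_function(outputs, targets)
        accelerator.backward(loss)
        my_optimizer.step()
        my_scheduler.step()

# Restore the previous state
accelerator.load_state("my/save/path/checkpoints/checkpoint_0")
```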
Restoring the state of the DataLoader
After resuming from a checkpoint, you may also want to continue from a particular point in the active DataLoader if the state was saved in the middle of an epoch. You can use skip_first_batches() to do so.
```python
from accelerate import Accelerator

accelerator = Accelerator(project_dir="my/save/path")

# Assumes train_dataloader is defined earlier in the script
train_dataloader = accelerator.prepare(train_dataloader)
accelerator.load_state("my_state")

# Assume the checkpoint was saved 100 steps into the epoch
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, 100)

# After the first iteration, go back to `train_dataloader`

# First epoch
for batch in skipped_dataloader:
    # Do something
    pass

# Second epoch
for batch in train_dataloader:
    # Do something
    pass
```
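The point of the pattern above is that skipping happens only once: the first, partial epoch uses the skipped loader, and every later epoch iterates the full loader again. The stdlib-only sketch below illustrates that behavior with a plain list standing in for a DataLoader; skip_first is a hypothetical helper built on itertools.islice and is not Accelerate's implementation of skip_first_batches().

```python
from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def skip_first(batches: Iterable[T], num_batches: int) -> Iterator[T]:
    """Hypothetical helper: yield one pass over `batches`, dropping the first `num_batches`."""
    return islice(batches, num_batches, None)


# Stand-in "dataloader" with 5 batches per epoch.
loader: List[int] = [0, 1, 2, 3, 4]

# Resume 3 batches into the first epoch ...
first_epoch = list(skip_first(loader, 3))
# ... then later epochs iterate the full loader again.
second_epoch = list(loader)

print(first_epoch)   # [3, 4]
print(second_epoch)  # [0, 1, 2, 3, 4]
```

Note that islice still pulls and discards the skipped items one by one, so with a real DataLoader those batches would still be loaded; the sketch only models which batches the training loop sees.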