Conversation

@jonb377 (Collaborator) commented Feb 20, 2024

See also: #6546

The optimizer state must be primed before it can be restored: optimizer state isn't materialized until the first `optim.step` call, so a dummy step is needed before optimizer state can be restored when resuming training.

This PR introduces the prime_optimizer API, which will run a dummy optimizer step with zeroed gradients. The gradient sharding is copied from the parameters to ensure the resulting sharding is the same.
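The idea can be illustrated with a dependency-free toy (all names here are hypothetical; the real API is `torch_xla.experimental.distributed_checkpoint.prime_optimizer`, which additionally copies gradient sharding from the parameters — that XLA-specific part is not modeled below). Like `torch.optim` optimizers, the toy only creates per-parameter state on the first `step`, so a zero-gradient step materializes the state without perturbing the parameters:

```python
# Toy sketch of the "prime before restore" pattern. Hypothetical names;
# the real torch_xla API also handles gradient sharding, not shown here.
class LazyStateOptimizer:
    """Mimics torch.optim: per-parameter state is created on first step()."""
    def __init__(self, params):
        self.params = params
        self.state = {}  # empty until the first step, like torch.optim

    def step(self, grads):
        for i, (p, g) in enumerate(zip(self.params, grads)):
            st = self.state.setdefault(i, {"momentum": 0.0})
            st["momentum"] = 0.9 * st["momentum"] + g
            self.params[i] = p - 0.1 * st["momentum"]


def prime_optimizer(optim):
    # A dummy step with zeroed gradients materializes the state
    # without changing the parameters (momentum stays 0.0 when
    # every gradient is 0.0).
    optim.step([0.0] * len(optim.params))
    return optim


opt = LazyStateOptimizer([1.0, 2.0])
assert opt.state == {}           # nothing to load a checkpoint into yet
prime_optimizer(opt)
assert set(opt.state) == {0, 1}  # state now exists and can be restored
assert opt.params == [1.0, 2.0]  # zero gradients leave parameters unchanged
```

After priming, a saved optimizer state dict can be loaded into the now-existing state entries; attempting the restore before priming would find no state to populate.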

@alanwaketan (Collaborator) left a comment


LGTM! Thanks, Jon.

@dasoto commented May 28, 2024

Any idea when this will be merged? I tried to implement the prime_optimizer function, but it looks like it requires some of the changes on the torch_xla/csrc/init_python_bindings.cpp side.

@jonb377 (Collaborator, Author) commented May 28, 2024

Hi @dasoto, I've found that this approach does not guarantee that the optimizer sharding matches what a full training step would produce, due to sharding propagation decisions in the compiler. For example, I believe the adagrad unit test broke after an openxla pin update.

Since this is an experimental feature, I would be OK to merge after a rebase. cc @JackCaoG @alanwaketan

@jonb377 jonb377 merged commit 8fd051f into master May 30, 2024
@jonb377 jonb377 deleted the jonbolin/prime branch May 30, 2024 20:29
