Feature request
pytorch just merged pytorch/torchdistx#52, which adds AnyPrecisionAdamW (bf16 support, plus future new dtypes).
We should add it to our HF Trainer arsenal.
This is open to the community - it shouldn't be too difficult to add by just checking the existing optimizers. Here are some pointers to start unraveling:
`transformers/src/transformers/training_args.py`, lines 393 to 394 in `e88e9ff`:

```
        optim (`str` or [`training_args.OptimizerNames`], *optional*, defaults to `"adamw_hf"`):
            The optimizer to use: adamw_hf, adamw_torch, adamw_apex_fused, or adafactor.
```
and
`transformers/src/transformers/training_args.py`, lines 94 to 106 in `e88e9ff`:

```python
class OptimizerNames(ExplicitEnum):
    """
    Stores the acceptable string identifiers for optimizers.
    """

    ADAMW_HF = "adamw_hf"
    ADAMW_TORCH = "adamw_torch"
    ADAMW_TORCH_XLA = "adamw_torch_xla"
    ADAMW_APEX_FUSED = "adamw_apex_fused"
    ADAFACTOR = "adafactor"
    ADAMW_BNB = "adamw_bnb_8bit"
    SGD = "sgd"
    ADAGRAD = "adagrad"
```
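To give a feel for the first step, here is a minimal sketch of the enum extension, assuming the string identifier ends up as `adamw_anyprecision` and that `ExplicitEnum` is importable from `transformers.utils` (both are assumptions, not decisions made in this issue):

```python
from transformers.utils import ExplicitEnum


class OptimizerNames(ExplicitEnum):
    """
    Stores the acceptable string identifiers for optimizers.
    """

    ADAMW_HF = "adamw_hf"
    ADAMW_TORCH = "adamw_torch"
    ADAMW_TORCH_XLA = "adamw_torch_xla"
    ADAMW_APEX_FUSED = "adamw_apex_fused"
    ADAFACTOR = "adafactor"
    ADAMW_BNB = "adamw_bnb_8bit"
    SGD = "sgd"
    ADAGRAD = "adagrad"
    # hypothetical new entry for AnyPrecisionAdamW - the final name is up to the PR author
    ADAMW_ANYPRECISION = "adamw_anyprecision"
```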
The key, of course, is the documentation and tests - checking the existing tests and working from there is what's needed.
One would start by mimicking the integration of the other optimizers, so in this case it'd follow the path of `adamw_torch`, as it's the nearest similar optimizer.
It might also help to look at the previous PRs that added new optimizers, e.g. find the PR that added `adamw_bnb_8bit` - that could be a good model to copy from, and it shows the scope of work that needs to be done. This one should be simpler than `adamw_bnb_8bit`, though, as it just plugs in a core pytorch optimizer - that's why I said `adamw_torch` is another good model.
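As a rough illustration of what the `adamw_torch`-style branch might look like, here is a standalone sketch; the `torchdistx.optimizers` import path, the helper name, and the exposed kwargs are all assumptions to be verified against pytorch/torchdistx#52:

```python
from typing import Any, Dict, Tuple


def get_anyprecision_adamw_cls_and_kwargs(adam_kwargs: Dict[str, Any]) -> Tuple[type, Dict[str, Any]]:
    """Sketch of the branch Trainer.get_optimizer_cls_and_kwargs would need (names here are illustrative)."""
    try:
        # import path assumed from pytorch/torchdistx#52 - verify against the actual package layout
        from torchdistx.optimizers import AnyPrecisionAdamW
    except ImportError:
        raise ImportError(
            "AnyPrecisionAdamW requires torchdistx / pytorch-nightly - "
            "see https://pytorch.org/get-started/locally/"
        )
    optimizer_kwargs = dict(adam_kwargs)
    # AnyPrecisionAdamW also takes dtype-related arguments (e.g. bf16 optimizer state);
    # deciding which of those to expose via TrainingArguments is part of the work.
    return AnyPrecisionAdamW, optimizer_kwargs
```

In the real integration this would presumably live as another branch next to the existing optimizers inside `Trainer.get_optimizer_cls_and_kwargs`, rather than as a separate helper.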
Please remember that this requires pytorch-nightly, as this new feature hasn't yet made it into pytorch-1.13. So you will need to install it from https://pytorch.org/get-started/locally/ (choose Preview (Nightly)).
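For the tests, since the optimizer is only available via pytorch-nightly / torchdistx for now, they would likely need an availability check and a skip decorator in the spirit of the existing `require_*` test helpers; the names below are hypothetical sketches, not existing transformers utilities:

```python
import importlib.util
import unittest


def is_anyprecision_adamw_available() -> bool:
    # Assumes AnyPrecisionAdamW ships with the torchdistx package (per pytorch/torchdistx#52);
    # adjust the check if it lands directly in core pytorch instead.
    return importlib.util.find_spec("torchdistx") is not None


def require_anyprecision_adamw(test_case):
    # Hypothetical decorator mirroring transformers' require_* test helpers:
    # skips the test when the optimizer (and its nightly dependency) is not installed.
    return unittest.skipUnless(is_anyprecision_adamw_available(), "test requires AnyPrecisionAdamW")(test_case)
```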
Thank you!