Add the GeLU activation from pytorch with the tanh approximation #21345
| Thanks for working on this! Does the new implementation in PyTorch produce the exact same results as |
The results are similar, but there are still rounding errors; see my analysis in the related issue #21344. I would also be in favor of replacing the existing implementation or using the new one as the default, but that would introduce small numerical differences in some models. Is that a problem? |
| Ah yes, the difference is quite significant sadly, so this will probably introduce a difference that is too big :-/ |
Wouldn't it cause confusion with the default PyTorch implementation? That one is currently named "gelu" (alongside the one named "gelu_python"). Also, should I add an explicit PyTorch version check? |
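To illustrate the naming concern: PyTorch's default GELU is the erf-based form, while this PR adds the tanh approximation, and the two are close but not identical functions. Below is a minimal plain-Python sketch on scalar inputs. The names `gelu_erf` and `gelu_tanh` are illustrative only and do not come from this PR.

```python
import math


def gelu_erf(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def gelu_tanh(x):
    """Tanh approximation of GELU."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


# Compare the two variants on a few sample points.
for x in (-2.0, -0.5, 0.0, 0.5, 1.0, 2.0):
    gap = abs(gelu_erf(x) - gelu_tanh(x))
    print(f"x={x:+.1f}  erf={gelu_erf(x):+.6f}  tanh={gelu_tanh(x):+.6f}  gap={gap:.2e}")
```

The gap is small but nonzero, which is one reason to keep the two variants under distinct names.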
| OK for the name then. For the version check, you will need to create a function that returns the instance of GELU and raises an import error if the PyTorch version is too low, then put that function in the mapping. |
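The version-gated factory suggested above could look roughly like the sketch below. It is dependency-free so it can run without PyTorch: the installed version is passed in as a string, and a scalar function stands in for the activation module. The names `parse_release`, `get_tanh_gelu`, and `ACTIVATION_FACTORIES` are hypothetical and are not part of transformers.

```python
import math
import re


def parse_release(version_string):
    """Return (major, minor, patch) from strings such as '1.13.1+cu117'."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version_string)
    if match is None:
        raise ValueError(f"Unrecognised version string: {version_string!r}")
    return tuple(int(part or 0) for part in match.groups())


def gelu_tanh(x):
    """Scalar stand-in for the tanh-approximated GELU activation."""
    inner = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1.0 + math.tanh(inner))


def get_tanh_gelu(installed_version, minimum="1.12.0"):
    """Return the activation, or raise ImportError if the version is too old."""
    if parse_release(installed_version) < parse_release(minimum):
        raise ImportError(
            f"The tanh GELU needs version >= {minimum}, found {installed_version}."
        )
    return gelu_tanh


# Hypothetical mapping from activation name to factory (assumed layout).
ACTIVATION_FACTORIES = {"gelu_tanh_sketch": get_tanh_gelu}

activation = ACTIVATION_FACTORIES["gelu_tanh_sketch"]("1.13.1+cu117")
print(activation(1.0))
```

In real PyTorch code the factory would presumably compare `torch.__version__` and return `torch.nn.GELU(approximate="tanh")`, an argument available from PyTorch 1.12 onward.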
Made a class to match the other activations, and raising a |
sgugger left a comment
Thanks for iterating! I just have one last comment on the error raised.
src/transformers/activations.py Outdated
    def __init__(self):
        super().__init__()
        if version.parse(torch.__version__) < version.parse("1.12.0"):
            raise NotImplementedError(
Suggested change:
-            raise NotImplementedError(
+            raise ImportError(
All fixed, this should be ready to merge once the tests pass.
| Failure is unrelated so merging. Thanks again for your contribution! |
Fixes #21344. See that issue for more details.