Conversation

@jlamypoirier (Contributor) commented Jan 27, 2023

Fixes #21344. See that issue for more details.

@HuggingFaceDocBuilderDev commented Jan 27, 2023

The documentation is not available anymore as the PR was closed or merged.

@sgugger (Collaborator) commented Jan 30, 2023

Thanks for working on this! Does the new implementation in PyTorch produce the exact same results as gelu_fast? If that is the case, I would prefer we just replace the current gelu_fast with this when PyTorch is 1.12 or above.

@jlamypoirier (Contributor, Author) replied

> Thanks for working on this! Does the new implementation in PyTorch produce the exact same results as gelu_fast? If that is the case, I would prefer we just replace the current gelu_fast with this when PyTorch is 1.12 or above.

The results are similar, but there are still rounding errors; see my analysis in the related issue #21344. I would also be in favor of replacing the existing implementation / using it as the default, but that would introduce small numerical differences in some models. Is that a problem?
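
For illustration, a minimal sketch (not from the PR) of the kind of comparison being discussed, assuming PyTorch >= 1.12 for the `approximate="tanh"` argument:

```python
import math

import torch

x = torch.randn(100_000, dtype=torch.float32)

# Tanh approximation written out in Python, as in transformers' gelu_new
# (gelu_fast rearranges the same expression with a pre-baked constant)
manual = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

# PyTorch's built-in fused version of the same formula
fused = torch.nn.functional.gelu(x, approximate="tanh")

# The two agree up to float32 rounding, but not bit for bit
print((manual - fused).abs().max())
```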

@sgugger (Collaborator) commented Jan 30, 2023

Ah yes, the difference is quite significant sadly, so this will probably introduce a difference that is too big :-/
So let's go with a new activation. Maybe gelu_pytorch is a better name?

@jlamypoirier (Contributor, Author) replied

> Ah yes, the difference is quite significant sadly, so this will probably introduce a difference that is too big :-/ So let's go with a new activation. Maybe gelu_pytorch is a better name?

Wouldn't that cause confusion with the default PyTorch implementation? That one is currently named "gelu" (and there is also one named "gelu_python").

Also should I add an explicit pytorch version check?

@sgugger (Collaborator) commented Jan 31, 2023

Ok for the name then. For the version check, you will need to create a function that returns the instance of GELU and raises an import error if the PyTorch version is too low, then put that function in the mapping.
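
A hedged sketch of the factory-function approach described here (the function name and error message are illustrative, not taken from the final diff):

```python
import torch
from packaging import version


def gelu_pytorch():
    # Fail loudly when the installed torch predates the fused tanh
    # approximation (the `approximate` argument was added in 1.12.0)
    if version.parse(torch.__version__) < version.parse("1.12.0"):
        raise ImportError(
            f"torch>=1.12.0 is required for this activation, but torch=={torch.__version__} is installed."
        )
    return torch.nn.GELU(approximate="tanh")


# The factory would then be registered in the activation mapping, e.g.:
# ACT2FN["gelu_pytorch"] = gelu_pytorch
```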

@jlamypoirier (Contributor, Author) replied

> Ok for the name then. For the version check, you will need to create a function that returns the instance of GELU and raises an import error if the PyTorch version is too low, then put that function in the mapping.

Made a class to match the other activations, raising a NotImplementedError (I don't think an ImportError is best here, since the function exists in earlier versions). Also added it to test_get_activation.
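
Roughly the kind of check this adds to test_get_activation (a sketch only; the registry key is assumed from the name discussed above, not from the diff):

```python
import torch
from transformers.activations import get_activation

# The new key should resolve to a module that matches PyTorch's fused kernel
act = get_activation("gelu_pytorch")  # key name assumed for illustration
x = torch.randn(8)
torch.testing.assert_close(act(x), torch.nn.functional.gelu(x, approximate="tanh"))
```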

@jlamypoirier marked this pull request as ready for review on January 31, 2023, 21:20
@sgugger (Collaborator) left a comment

Thanks for iterating! I just have one last comment on the error raised.

```python
def __init__(self):
    super().__init__()
    if version.parse(torch.__version__) < version.parse("1.12.0"):
        raise NotImplementedError(
```
@sgugger (Collaborator) suggested a change:

```diff
-        raise NotImplementedError(
+        raise ImportError(
```
@jlamypoirier (Contributor, Author) replied
All fixed, this should be ready to merge once the tests pass.
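
For reference, a sketch of how the activation class plausibly looks once the suggested change is applied (the class name and docstring are assumptions, not copied from the diff):

```python
import torch
from packaging import version
from torch import nn


class PytorchGELUTanh(nn.Module):
    """
    Tanh approximation of GELU, delegating to PyTorch's fused native
    implementation (available from torch 1.12.0 onward).
    """

    def __init__(self):
        super().__init__()
        if version.parse(torch.__version__) < version.parse("1.12.0"):
            raise ImportError(
                f"torch>=1.12.0 is required to use this activation, but torch=={torch.__version__} is installed."
            )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return nn.functional.gelu(input, approximate="tanh")
```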

@sgugger (Collaborator) commented Feb 2, 2023

Failure is unrelated so merging. Thanks again for your contribution!

@sgugger merged commit e006ab5 into huggingface:main on Feb 2, 2023
@jlamypoirier deleted the gelu_new_python branch on February 2, 2023, 23:37