Description
Feature request
Add support for the PyTorch implementation of OpenAI's approximation of the GeLU function, added in PyTorch 1.12. This implementation is equivalent to gelu_new or gelu_fast but much faster. It could be added as a separate activation function, for example gelu_new_python, to avoid disrupting existing models (see the sketch below).
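A minimal sketch of what such a drop-in activation could look like (the class name and docstring below are illustrative, not a final API):

```python
import torch
from torch import nn


class PytorchGELUTanh(nn.Module):
    """GeLU using OpenAI's tanh approximation, dispatched to PyTorch's fused kernel.

    Illustrative sketch only; requires torch >= 1.12.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Single fused C/CUDA kernel instead of the ~8-9 elementwise ops
        # launched by gelu_new / gelu_fast.
        return nn.functional.gelu(input, approximate="tanh")
```

It could then be exposed under a new key (e.g. gelu_new_python) in transformers.activations.ACT2FN, so configs can opt in explicitly while gelu_new and gelu_fast remain untouched.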
Motivation
Many transformer models use OpenAI's approximation (tanh) for the GeLU, through the activation function gelu_new or gelu_fast. These implementations are extremely slow (despite their name) because they consist of multiple operations/kernels (8 and 9 respectively).
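For reference, the tanh approximation roughly as gelu_new computes it; each multiply, add, pow and tanh is a separate elementwise kernel launch, which is why it is slow despite looking simple:

```python
import math

import torch


def gelu_new_reference(x: torch.Tensor) -> torch.Tensor:
    # Paraphrase of transformers' NewGELUActivation (OpenAI tanh approximation).
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
```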
Since version 1.12, PyTorch supports a single-kernel C/CUDA implementation through the argument approximate='tanh' (https://pytorch.org/docs/stable/generated/torch.nn.GELU.html). This implementation is 6-10x faster than what currently exists in transformers, and is numerically equal up to rounding errors.
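Minimal usage of the fused op, in both its module and functional forms (assuming torch >= 1.12):

```python
import torch
from torch import nn

x = torch.randn(4, 8)

# Module form.
y_module = nn.GELU(approximate="tanh")(x)

# Functional form, as used in the benchmark below.
y_functional = nn.functional.gelu(x, approximate="tanh")

assert torch.allclose(y_module, y_functional)
```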
When benchmarking the inference speed of the SantaCoder models, I found that using the pytorch implementation allowed for an end-to-end speedup of ~15-20%.
I also benchmarked the speed and accuracy using the following code (on an A100-80GB):
```python
import time

import torch

from transformers.activations import NewGELUActivation, FastGELUActivation

dtype = torch.float32
eps = torch.finfo(dtype).eps
x = torch.empty([2**30], device="cuda", dtype=dtype).normal_()
torch.cuda.synchronize()

# Time each implementation on the same 2**30-element tensor.
t0 = time.perf_counter()
y0 = torch.nn.functional.gelu(x, approximate="tanh")
torch.cuda.synchronize()
t1 = time.perf_counter()
y1 = NewGELUActivation()(x)
torch.cuda.synchronize()
t2 = time.perf_counter()
y2 = FastGELUActivation()(x)
torch.cuda.synchronize()
t3 = time.perf_counter()
y3 = torch.nn.functional.gelu(x)
torch.cuda.synchronize()
t4 = time.perf_counter()

print(f"Torch tanh: {1000*(t1-t0):.3f} ms")
print(f"New: {1000*(t2-t1):.3f} ms")
print(f"Fast: {1000*(t3-t2):.3f} ms")
print(f"Torch orig: {1000*(t4-t3):.3f} ms")

# Pairwise differences, expressed in units of the dtype's machine epsilon.
print(f"Torch tanh vs new: {(y1-y0).float().std().cpu().item()/eps:.3f}")
print(f"Torch tanh vs fast: {(y2-y0).float().std().cpu().item()/eps:.3f}")
print(f"New vs fast: {(y2-y1).float().std().cpu().item()/eps:.3f}")
print(f"Torch tanh vs torch orig: {(y3-y0).float().std().cpu().item()/eps:.3f}")
```

With output:
```
Torch tanh: 4.921 ms
New: 43.253 ms
Fast: 50.269 ms
Torch orig: 4.989 ms
Torch tanh vs new: 0.042
Torch tanh vs fast: 0.147
New vs fast: 0.147
Torch tanh vs torch orig: 971.960
```

I.e., the torch tanh version matches the new and fast GeLU within epsilon while being 8.8x/10.2x faster, but differs from the original (exact) version.
With dtype=torch.float16:
```
Torch tanh: 3.342 ms
New: 22.667 ms
Fast: 26.104 ms
Torch orig: 3.395 ms
Torch tanh vs new: 0.244
Torch tanh vs fast: 0.243
New vs fast: 0.143
Torch tanh vs torch orig: 0.216
```

I.e., it's 6.8x/7.8x faster, and the choice of implementation doesn't matter because rounding errors dominate.
On CPU (float32), with size 2**28 (~268M elements):
```
Torch tanh: 182.575 ms
New: 1683.934 ms
Fast: 1925.547 ms
Torch orig: 141.410 ms
Torch tanh vs new: 0.043
Torch tanh vs fast: 0.144
New vs fast: 0.144
Torch tanh vs torch orig: 971.852
```

I.e., same accuracy and a similar speedup (9.2x/10.5x faster).
Your contribution
Opened a draft PR (#21345)