Add the pytorch implementation of the OpenAI GeLU approximation #21344

@jlamypoirier

Description

Feature request

Add support for the PyTorch implementation of OpenAI's approximation of the GeLU function, added in PyTorch 1.12. This implementation is equivalent to gelu_new or gelu_fast but much faster. It could be added as a separate activation function, for example gelu_new_python, to avoid disrupting existing models.
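A minimal sketch of what such an activation could look like (the class name PytorchGELUTanh and the registration shown are illustrative assumptions, not the final implementation):

import torch
from torch import nn


class PytorchGELUTanh(nn.Module):
    # Thin wrapper around PyTorch's fused tanh GeLU (requires torch >= 1.12).
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return nn.functional.gelu(input, approximate="tanh")


# Hypothetical registration under the proposed name, e.g. in transformers/activations.py:
# ACT2FN["gelu_new_python"] = PytorchGELUTanh()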

Motivation

Many transformer models use OpenAI's tanh approximation of the GeLU through the activation functions gelu_new or gelu_fast. These implementations are very slow (despite their name) because they consist of multiple elementwise operations/kernels (8 and 9, respectively).
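For reference, this is roughly what NewGELUActivation computes; in eager mode each elementwise operation below launches its own kernel:

import math

import torch


def gelu_tanh_reference(x: torch.Tensor) -> torch.Tensor:
    # OpenAI's tanh approximation of the GeLU, written out op by op.
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))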

Since version 1.12, PyTorch supports a single-kernel C++/CUDA implementation through the argument approximate='tanh' (https://pytorch.org/docs/stable/generated/torch.nn.GELU.html). This implementation is 6-10x faster than what currently exists in transformers and is numerically equal up to rounding errors.
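For illustration, both the module and functional forms accept the argument (this snippet assumes torch >= 1.12):

import torch
from torch import nn

x = torch.randn(4, 8)
y_module = nn.GELU(approximate="tanh")(x)                        # module form
y_functional = torch.nn.functional.gelu(x, approximate="tanh")   # functional form
assert torch.allclose(y_module, y_functional)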

When benchmarking the inference speed of the SantaCoder models, I found that using the PyTorch implementation allowed for an end-to-end speedup of ~15-20%.

I also benchmarked the speed and accuracy using the following code (on an A100-80GB):

import time

import torch

from transformers.activations import NewGELUActivation, FastGELUActivation

dtype = torch.float32
eps = torch.finfo(dtype).eps
x = torch.empty([2**30], device="cuda", dtype=dtype).normal_()

# Fused PyTorch tanh approximation
torch.cuda.synchronize()
t0 = time.perf_counter()
y0 = torch.nn.functional.gelu(x, approximate="tanh")
# Existing transformers implementations
torch.cuda.synchronize()
t1 = time.perf_counter()
y1 = NewGELUActivation()(x)
torch.cuda.synchronize()
t2 = time.perf_counter()
y2 = FastGELUActivation()(x)
# Exact (erf-based) PyTorch GeLU
torch.cuda.synchronize()
t3 = time.perf_counter()
y3 = torch.nn.functional.gelu(x)
torch.cuda.synchronize()
t4 = time.perf_counter()

print(f"Torch tanh: {1000*(t1-t0):.3f} ms")
print(f"New: {1000*(t2-t1):.3f} ms")
print(f"Fast: {1000*(t3-t2):.3f} ms")
print(f"Torch orig: {1000*(t4-t3):.3f} ms")
# Differences reported as standard deviations in units of machine epsilon
print(f"Torch tanh vs new: {(y1-y0).float().std().cpu().item()/eps:.3f}")
print(f"Torch tanh vs fast: {(y2-y0).float().std().cpu().item()/eps:.3f}")
print(f"New vs fast: {(y2-y1).float().std().cpu().item()/eps:.3f}")
print(f"Torch tanh vs torch orig: {(y3-y0).float().std().cpu().item()/eps:.3f}")

With output

Torch tanh: 4.921 ms
New: 43.253 ms
Fast: 50.269 ms
Torch orig: 4.989 ms
Torch tanh vs new: 0.042
Torch tanh vs fast: 0.147
New vs fast: 0.147
Torch tanh vs torch orig: 971.960

That is, torch's tanh version matches the new and fast GeLU implementations to well within machine epsilon while being 8.8x/10.2x faster, but it differs from the original (exact) GeLU.

With dtype=torch.float16:

Torch tanh: 3.342 ms
New: 22.667 ms
Fast: 26.104 ms
Torch orig: 3.395 ms
Torch tanh vs new: 0.244
Torch tanh vs fast: 0.243
New vs fast: 0.143
Torch tanh vs torch orig: 0.216

That is, it is 6.8x/7.8x faster, and the choice of implementation no longer matters for accuracy because float16 rounding errors dominate.

On CPU (float32), with a tensor of size 2**28 (268M elements):

Torch tanh: 182.575 ms
New: 1683.934 ms
Fast: 1925.547 ms
Torch orig: 141.410 ms
Torch tanh vs new: 0.043
Torch tanh vs fast: 0.144
New vs fast: 0.144
Torch tanh vs torch orig: 971.852

That is, the same accuracy and a similar speedup (9.2x/10.5x faster).

Your contribution

Opened a draft PR (#21345)
