Commit 20840d9

add test
Signed-off-by: Dylan Chen <191843203+DylanChen-NV@users.noreply.github.com>
1 parent 7757de4 commit 20840d9

1 file changed: 53 additions, 0 deletions

@@ -0,0 +1,53 @@
import pytest
import torch
from utils.util import skip_pre_hopper

from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig


@skip_pre_hopper
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fp8_rowwise_linear(dtype):
    SEQ_LEN = 10
    HIDDEN_SIZE = 128
    torch.manual_seed(0)
    # Quantize the activation rowwise: one FP8 (E4M3) scale per token.
    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
    x_fp8, x_scale = torch.ops.tensorrt_llm.quantize_e4m3_activation(x)
    x_fp8 = x_fp8.view(torch.float8_e4m3fn)
    x_scale = x_scale.float().squeeze()
    # Quantize the weight rowwise: one FP8 (E4M3) scale per output channel.
    w = torch.randn((HIDDEN_SIZE, HIDDEN_SIZE), dtype=dtype).cuda()
    w_fp8, w_scale = torch.ops.tensorrt_llm.quantize_e4m3_activation(w)
    w_fp8 = w_fp8.view(torch.float8_e4m3fn)
    w_scale = w_scale.float().squeeze()

    # Build an FP8 per-channel / per-token quantized Linear and load the
    # pre-quantized weight together with its scale.
    qc = QuantConfig(quant_algo=QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN)
    l0 = Linear(in_features=HIDDEN_SIZE,
                out_features=HIDDEN_SIZE,
                bias=False,
                dtype=dtype,
                quant_config=qc)
    assert l0.weight.dtype == torch.float8_e4m3fn
    l0.load_weights([{
        'weight': w_fp8,
        'weight_scale': w_scale,
    }])
    l0.cuda()
    torch.testing.assert_close(l0.weight, w_fp8)
    torch.testing.assert_close(l0.weight_scale, w_scale)

    # Forward pass through the quantized Linear on the high-precision input.
    with torch.inference_mode():
        output = l0.forward(x)

    # Reference: dequantize activation and weight, then matmul in the original dtype.
    with torch.inference_mode():
        x_dq = x_fp8.to(x_scale.dtype) * x_scale.view(-1, 1)
        w_dq = w_fp8.to(w_scale.dtype).t() * w_scale.view(1, -1)
        ref_output = x_dq.to(dtype) @ w_dq.to(dtype)

    # Compare against the dequantized reference with loose tolerances to
    # absorb FP8 rounding error.
    torch.cuda.synchronize()
    torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2)


if __name__ == '__main__':
    test_fp8_rowwise_linear(torch.float16)
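
For context on the quantization op used above, the following is a rough pure-PyTorch sketch of per-row (rowwise) E4M3 quantization, i.e. what torch.ops.tensorrt_llm.quantize_e4m3_activation is conceptually expected to produce for this test: one scale per row chosen from the row's absolute maximum. The helper name quantize_e4m3_rowwise_ref is hypothetical, and the real fused CUDA kernel may differ in scale layout, rounding, and output dtype.

import torch

def quantize_e4m3_rowwise_ref(t: torch.Tensor):
    # Largest representable magnitude in float8_e4m3fn (448.0).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # One scale per row, chosen so the row's largest magnitude maps to fp8_max.
    amax = t.abs().amax(dim=-1, keepdim=True).float()
    scale = (amax / fp8_max).clamp(min=1e-12)
    # Scale, clamp to the representable range, and cast to FP8.
    t_fp8 = (t.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return t_fp8, scale.squeeze(-1)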
