import pytest
import torch
from utils.util import skip_pre_hopper

from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig


@skip_pre_hopper
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_fp8_rowwise_linear(dtype):
    SEQ_LEN = 10
    HIDDEN_SIZE = 128
    torch.manual_seed(0)
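    # Quantize a random activation tensor to FP8 (e4m3); the returned scale is per-token.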
    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
    x_fp8, x_scale = torch.ops.tensorrt_llm.quantize_e4m3_activation(x)
    x_fp8 = x_fp8.view(torch.float8_e4m3fn)
    x_scale = x_scale.float().squeeze()
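    # Quantize the weight the same way, giving a per-row (per-output-channel) scale.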
    w = torch.randn((HIDDEN_SIZE, HIDDEN_SIZE), dtype=dtype).cuda()
    w_fp8, w_scale = torch.ops.tensorrt_llm.quantize_e4m3_activation(w)
    w_fp8 = w_fp8.view(torch.float8_e4m3fn)
    w_scale = w_scale.float().squeeze()

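    # Build a Linear module configured for FP8 per-channel / per-token (rowwise) quantization.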
    qc = QuantConfig(quant_algo=QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN)
    l0 = Linear(in_features=HIDDEN_SIZE,
                out_features=HIDDEN_SIZE,
                bias=False,
                dtype=dtype,
                quant_config=qc)
    assert l0.weight.dtype == torch.float8_e4m3fn
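    # Load the pre-quantized weight and its scale; the module should store both unchanged.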
    l0.load_weights([{
        'weight': w_fp8,
        'weight_scale': w_scale,
    }])
    l0.cuda()
    torch.testing.assert_close(l0.weight, w_fp8)
    torch.testing.assert_close(l0.weight_scale, w_scale)

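    # Run the quantized linear layer on the original (unquantized) input.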
    with torch.inference_mode():
        output = l0.forward(x)

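    # Reference: dequantize activations and weights, then matmul in the original dtype.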
    with torch.inference_mode():
        x_dq = x_fp8.to(x_scale.dtype) * x_scale.view(-1, 1)
        w_dq = w_fp8.to(w_scale.dtype).t() * w_scale.view(1, -1)
        ref_output = x_dq.to(dtype) @ w_dq.to(dtype)

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(output, ref_output, atol=5e-2, rtol=1e-2)


if __name__ == '__main__':
    test_fp8_rowwise_linear(torch.float16)