Commit 486fe0d

[sparse] Migrate Float8SemiSparseTensor off of AQT (#3361)

This PR migrates `Float8DynamicActivationFloat8SemiSparseWeightConfig` off of the AQT `CutlassSemiSparseLayout` subclass. The old AQT flow can still be used by passing `version=1` into the config.

Testing:
```
pytest test/quantization/quantize_/workflows/float8/test_float8_semi_sparse_tensor.py
```
1 parent: 7035fb7 · commit: 486fe0d
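For orientation, a minimal sketch of the new flow, mirroring the test added in this commit; the layer shape and the fake-sparsity helper are taken from that test, and the `version=1` comment reflects the commit message rather than code shown in this diff:

```python
from torch import nn

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import quantize_
from torchao.quantization.quantize_.workflows import Float8PackingFormat
from torchao.sparsity import apply_fake_sparsity

model = nn.Linear(256, 1024).bfloat16().cuda().eval()
apply_fake_sparsity(model)  # weight needs a 2:4 sparsity pattern before packing

# New flow: FP8 dynamic activation + FP8 2:4 sparse weight via the CUTLASS kernel path.
quantize_(
    model,
    Float8DynamicActivationFloat8WeightConfig(
        version=2,
        packing_format=Float8PackingFormat.SPARSE_CUTLASS,
        granularity=PerRow(),
    ),
)

# Per the commit message, passing version=1 into the semi-sparse config keeps
# the old AQT CutlassSemiSparseLayout flow.
```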

5 files changed: +438 additions, −15 deletions
Lines changed: 126 additions & 0 deletions (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
import unittest

import torch
from torch import nn
from torch.testing._internal import common_utils

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
)
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import (
    quantize_,
)
from torchao.quantization.quantize_.workflows import (
    Float8PackingFormat,
)
from torchao.quantization.utils import compute_error
from torchao.sparsity import apply_fake_sparsity
from torchao.utils import is_sm_at_least_90

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class TestSparse2x4Float8Tensor(common_utils.TestCase):
    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @common_utils.parametrize("compile", [True, False])
    def test_fp8_cutlass_sparse(self, compile):
        with torch.inference_mode():
            input = torch.rand((256, 256), dtype=torch.bfloat16, device="cuda")
            model = (
                nn.Sequential(
                    nn.Linear(256, 1024),
                    nn.Linear(1024, 256),
                )
                .bfloat16()
                .cuda()
                .eval()
            )

            apply_fake_sparsity(model)
            baseline_result = model(input)
            model_copy = copy.deepcopy(model)

            # Quantized
            quantize_(model_copy, Float8DynamicActivationFloat8WeightConfig())
            dense_result = model_copy(input)
            dense_sqnr = compute_error(baseline_result, dense_result)

            # Sparse + quantized
            quantize_(
                model,
                Float8DynamicActivationFloat8WeightConfig(
                    version=2,
                    packing_format=Float8PackingFormat.SPARSE_CUTLASS,
                    granularity=PerRow(),
                ),
            )
            if compile:
                model = torch.compile(model)
            sparse_result = model(input)
            sparse_sqnr = compute_error(baseline_result, sparse_result)

            self.assertEqual(dense_sqnr, sparse_sqnr)

    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_fp8_cutlass_sparse_lowering_op_clone(self):
        with torch.inference_mode():
            model = nn.Linear(256, 1024).half().cuda().eval()
            apply_fake_sparsity(model)
            quantize_(
                model,
                Float8DynamicActivationFloat8WeightConfig(
                    version=2,
                    packing_format=Float8PackingFormat.SPARSE_CUTLASS,
                    granularity=PerRow(),
                ),
            )

            original = model.weight.dequantize()
            cloned = model.weight.clone().dequantize()

            for o, c in zip(original, cloned):
                self.assertEqual(o, c)

    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_fp8_cutlass_sparse_lowering_op_to(self):
        # Need to run with inference mode to avoid dispatching to `aten.to_copy`
        with torch.inference_mode():
            model = nn.Linear(256, 1024).half().cuda().eval()
            apply_fake_sparsity(model)
            model_copy = copy.deepcopy(model)
            expected = model_copy.weight.to(dtype=torch.float)

            quantize_(
                model,
                Float8DynamicActivationFloat8WeightConfig(
                    version=2,
                    packing_format=Float8PackingFormat.SPARSE_CUTLASS,
                    granularity=PerRow(),
                ),
            )

            original = torch.ops.aten.to.dtype_layout(
                model.weight,
                dtype=torch.float,
                layout=torch.strided,
            )
            torch.testing.assert_close(expected, original, atol=1e-1, rtol=1e-1)


common_utils.instantiate_parametrized_tests(TestSparse2x4Float8Tensor)

if __name__ == "__main__":
    unittest.main()
```
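`compute_error` in these tests reports the mismatch as an SQNR. A minimal sketch of that metric, assuming the usual signal-to-quantization-noise-ratio-in-dB definition; the helper name here is illustrative, not torchao's API:

```python
import torch


def sqnr_db(reference: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    # 20 * log10(||reference|| / ||reference - candidate||): higher means closer.
    return 20 * torch.log10(
        torch.linalg.norm(reference) / torch.linalg.norm(reference - candidate)
    )
```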

torchao/quantization/quant_api.py

Lines changed: 29 additions & 15 deletions

```diff
@@ -81,6 +81,7 @@
     KernelPreference,
 )
 from torchao.quantization.quantize_.workflows import (
+    Float8PackingFormat,
     Float8Tensor,
     Int4ChooseQParamsAlgorithm,
     Int4MarlinSparseTensor,
@@ -96,6 +97,7 @@
     IntxUnpackedToInt8Tensor,
     QuantizeTensorToFloat8Kwargs,
     QuantizeTensorToInt8Kwargs,
+    Sparse2x4CUTLASSFloat8Tensor,
 )
 from torchao.quantization.transform_module import (
     _QUANTIZE_CONFIG_HANDLER,
@@ -1588,6 +1590,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig):
     activation_dtype: torch.dtype = e4m3_dtype
     weight_dtype: torch.dtype = e4m3_dtype
     granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None
+    packing_format: Optional[Float8PackingFormat] = Float8PackingFormat.PLAIN
     mm_config: Optional[Float8MMConfig] = None
     activation_value_lb: Optional[float] = None
     activation_value_ub: Optional[float] = None
@@ -1625,6 +1628,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
     activation_value_lb = config.activation_value_lb
     activation_value_ub = config.activation_value_ub
     kernel_preference = config.kernel_preference
+    packing_format = config.packing_format

     # Ensure works on device
     _check_hardware_support(granularity)
@@ -1651,31 +1655,41 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
         # TODO(future PR): this should really throw an exception instead of silently
         # not doing what the user asked
         return weight
-
-    if isinstance(weight_granularity, PerRow):
+    assert config.version == 2, f"Unexpected version: {config.version}"
+    if packing_format == Float8PackingFormat.PLAIN and isinstance(
+        weight_granularity, PerRow
+    ):
         assert weight.dtype == torch.bfloat16, (
             "PerRow quantization only works for bfloat16 precision input weight"
         )
-
-    assert config.version == 2, f"Unexpected version: {config.version}"
     act_quant_kwargs = QuantizeTensorToFloat8Kwargs(
         activation_dtype,
         activation_granularity,
         hp_value_lb=activation_value_lb,
         hp_value_ub=activation_value_ub,
         kernel_preference=kernel_preference,
     )
-
-    quantized_weight = Float8Tensor.from_hp(
-        weight,
-        float8_dtype=weight_dtype,
-        granularity=weight_granularity,
-        mm_config=mm_config,
-        kernel_preference=kernel_preference,
-        act_quant_kwargs=act_quant_kwargs,
-    )
-
-    return quantized_weight
+    if packing_format == Float8PackingFormat.PLAIN:
+        quantized_weight = Float8Tensor.from_hp(
+            weight,
+            float8_dtype=weight_dtype,
+            granularity=weight_granularity,
+            mm_config=mm_config,
+            kernel_preference=kernel_preference,
+            act_quant_kwargs=act_quant_kwargs,
+        )
+        return quantized_weight
+    elif packing_format == Float8PackingFormat.SPARSE_CUTLASS:
+        assert isinstance(weight_granularity, PerRow), (
+            "Sparse packing format only supports per-row quantization"
+        )
+        quantized_weight = Sparse2x4CUTLASSFloat8Tensor.from_hp(
+            weight,
+            float8_dtype=weight_dtype,
+            granularity=weight_granularity,
+            act_quant_kwargs=act_quant_kwargs,
+        )
+        return quantized_weight


 @register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig)
```
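Based on this dispatch, the packing format selects which tensor subclass backs the linear weight. A minimal sketch of checking that, with layer sizes as placeholders and the `isinstance` checks as an assumption about how the subclasses surface on `module.weight`:

```python
from torch import nn

from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import quantize_
from torchao.quantization.quantize_.workflows import (
    Float8PackingFormat,
    Float8Tensor,
    Sparse2x4CUTLASSFloat8Tensor,
)
from torchao.sparsity import apply_fake_sparsity

# PLAIN (default) packing format -> Float8Tensor weight.
dense = nn.Linear(256, 1024).bfloat16().cuda().eval()
quantize_(
    dense,
    Float8DynamicActivationFloat8WeightConfig(version=2, granularity=PerRow()),
)
assert isinstance(dense.weight, Float8Tensor)

# SPARSE_CUTLASS packing format -> Sparse2x4CUTLASSFloat8Tensor weight.
sparse = nn.Linear(256, 1024).bfloat16().cuda().eval()
apply_fake_sparsity(sparse)  # weight must be 2:4 sparse before packing
quantize_(
    sparse,
    Float8DynamicActivationFloat8WeightConfig(
        version=2,
        packing_format=Float8PackingFormat.SPARSE_CUTLASS,
        granularity=PerRow(),
    ),
)
assert isinstance(sparse.weight, Sparse2x4CUTLASSFloat8Tensor)
```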

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -1,7 +1,13 @@
+from .float8.float8_packing_format import (
+    Float8PackingFormat,
+)
 from .float8.float8_tensor import (
     Float8Tensor,
     QuantizeTensorToFloat8Kwargs,
 )
+from .float8.sparse_2x4_cutlass_float8_tensor import (
+    Sparse2x4CUTLASSFloat8Tensor,
+)
 from .int4.int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm
 from .int4.int4_marlin_sparse_tensor import (
     Int4MarlinSparseTensor,
@@ -41,6 +47,8 @@
     "Int8Tensor",
     "QuantizeTensorToInt8Kwargs",
     "Float8Tensor",
+    "Sparse2x4CUTLASSFloat8Tensor",
+    "Float8PackingFormat",
     "QuantizeTensorToFloat8Kwargs",
     "Int8Tensor",
     "QuantizeTensorToInt8Kwargs",
```
Lines changed: 38 additions & 0 deletions (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.


from enum import Enum

import torch

__all__ = [
    "Float8PackingFormat",
]


# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum)
# after python 3.10 is end of life (https://devguide.python.org/versions/)
class Float8PackingFormat(str, Enum):
    """
    plain packing format for Float8Tensor will lay out elements in Tensor sequentially,
    for example: for a Tensor of shape (4, 6):
    a_0_0, a_0_1, ..., a_0_5,
    ...
    a_3_0, a_3_1, ..., a_3_5
    """

    PLAIN = "plain"
    """
    Sparse packing format for 2:4 sparsity + FP8 quantization

    SPARSE_CUTLASS will pack the quantized_data into two tensors, qdata and sparse_metadata, for the specified values and metadata respectively.
    This packing format will dispatch to `rowwise_scaled_linear_sparse_cutlass_f8f8`, which will fuse the per-row scaling into the sparse matmul.
    """
    SPARSE_CUTLASS = "sparse_cutlass"


torch.serialization.add_safe_globals([Float8PackingFormat])
```
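Because `Float8PackingFormat` mixes in `str`, members compare equal to their plain string values, and the `add_safe_globals` call above registers the enum so checkpoints containing it can load under `torch.load(..., weights_only=True)`. A small sketch of the string-enum behavior:

```python
from torchao.quantization.quantize_.workflows import Float8PackingFormat

# str-mixin Enum members compare equal to their string values,
# so configs can be built from either the member or the string.
assert Float8PackingFormat.SPARSE_CUTLASS == "sparse_cutlass"
assert Float8PackingFormat("plain") is Float8PackingFormat.PLAIN
```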
