
Commit 86644be

MekkCyber and sywangyi authored
[Quantization] FBgemm FP8 for XPU (#42773)
* enable xpu in fp8_gemm
* refine the code
* updated
* fix
* style
* small fix

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
Co-authored-by: Wang, Yi <yi.a.wang@intel.com>
1 parent a8f32a0 commit 86644be
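For orientation, this commit extends the existing FBGEMM FP8 quantization path so it also runs on Intel XPU devices. Below is a minimal usage sketch of how the feature would be exercised from user code; the checkpoint id and prompt are placeholders, not taken from this commit.

# Minimal usage sketch; "some-org/some-model" is a hypothetical checkpoint id.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, FbgemmFp8Config

model_id = "some-org/some-model"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# On XPU the quantize op is fetched from the `kernels` hub and the matmuls go
# through torch._scaled_mm; on CUDA the fbgemm-gpu kernels are used as before.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=FbgemmFp8Config(),
    device_map="auto",  # resolves to "xpu" or "cuda" when one is available
)

inputs = tokenizer("What are we having for dinner?", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=16, do_sample=False)
print(tokenizer.decode(output[0], skip_special_tokens=True))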

File tree

3 files changed: +135 additions, −63 deletions

src/transformers/integrations/fbgemm_fp8.py

Lines changed: 88 additions & 38 deletions
@@ -12,12 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from functools import lru_cache
 from typing import Optional
 
 from ..activations import ACT2FN
 from ..core_model_loading import ConversionOps
 from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
-from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
+from ..utils import (
+    is_accelerate_available,
+    is_fbgemm_gpu_available,
+    is_torch_available,
+    is_torch_xpu_available,
+    logging,
+)
 
 
 if is_torch_available():
@@ -27,7 +34,9 @@
 if is_accelerate_available():
     from accelerate import init_empty_weights
 
-if is_fbgemm_gpu_available():
+_is_torch_xpu_available = is_torch_xpu_available()
+
+if is_fbgemm_gpu_available() and not _is_torch_xpu_available:
     import fbgemm_gpu.experimental.gen_ai  # noqa: F401
 
 logger = logging.get_logger(__name__)
@@ -61,7 +70,7 @@ def convert(
             flattened_param = transposed_param.reshape(-1, original_shape[-1])
 
             # Quantize using per row instead of per column
-            new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
+            new_value_flat, weight_scale_flat = quantize_fp8_per_row(flattened_param)
 
             # Reshape back to original dimensions
             new_value = new_value_flat.reshape(original_shape)
@@ -77,14 +86,14 @@ def convert(
             flattened_param = transposed_param.reshape(-1, original_shape[-1])
 
             # Quantize using per column
-            new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
+            new_value_flat, weight_scale_flat = quantize_fp8_per_row(flattened_param)
 
             # Reshape back to original dimensions
             new_value = new_value_flat.reshape(original_shape)
             new_value = new_value.transpose(1, 2)
             weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
         else:
-            new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(value)
+            new_value, weight_scale = quantize_fp8_per_row(value)
             weight_scale = torch.nn.Parameter(weight_scale.view(weight_scale.shape[0], 1))
 
         return {target_key: torch.nn.Parameter(new_value), f"{target_key}_scale": weight_scale}
@@ -110,18 +119,26 @@ def forward(self, x):
         output_shape = (*x.shape[:-1], -1)
         # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
         # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
-        x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-            x.view(-1, x.shape[-1]).contiguous(), scale_ub=self.input_scale_ub
-        )
+        x_quantized, x_scale = quantize_fp8_per_row(x.view(-1, x.shape[-1]).contiguous(), scale_ub=self.input_scale_ub)
         # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
         # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
 
         # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
         weight_scale_float32 = self.weight_scale.to(torch.float32)
-        output = torch.ops.fbgemm.f8f8bf16_rowwise(
-            x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
-        )
-        output = output + self.bias if self.bias is not None else output
+        if _is_torch_xpu_available:
+            output = torch._scaled_mm(
+                x_quantized,
+                self.weight.t(),
+                scale_a=x_scale.unsqueeze(-1),
+                scale_b=weight_scale_float32.t(),
+                out_dtype=x.dtype,
+                bias=self.bias,
+            )
+        else:
+            output = torch.ops.fbgemm.f8f8bf16_rowwise(
+                x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
+            )
+            output = output + self.bias if self.bias is not None else output
        # Hacky for now, we have the output to the device of x
        output = output.to(x.device)
        output = output.reshape(output_shape)
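As a reading aid for the hunk above: per-row FP8 quantization keeps one scale per row (scale = row amax / FP8 max), and the row-wise scaled matmul is mathematically a dequantize-then-GEMM. The sketch below is CPU-only reference math under that assumption; the helper names are invented here, and it is not the fused quantize_fp8_per_row / f8f8bf16_rowwise / torch._scaled_mm kernels.

# Reference-only sketch of per-row FP8 quantization and a row-wise scaled matmul.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def quantize_fp8_per_row_reference(x: torch.Tensor):
    # One scale per row so that x / scale fits into the FP8 dynamic range.
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = (amax / FP8_MAX).to(torch.float32)
    return (x / scale).to(torch.float8_e4m3fn), scale.squeeze(-1)

def rowwise_scaled_matmul_reference(x_fp8, x_scale, w_fp8, w_scale, bias=None):
    # Equivalent math: out = (x_fp8 * x_scale) @ (w_fp8 * w_scale).T [+ bias]
    x = x_fp8.to(torch.float32) * x_scale[:, None]
    w = w_fp8.to(torch.float32) * w_scale[:, None]
    out = x @ w.T
    if bias is not None:
        out = out + bias
    return out.to(torch.bfloat16)

x = torch.randn(4, 16)   # activations: 4 tokens, hidden size 16
w = torch.randn(8, 16)   # weight: 8 output features
x_q, x_s = quantize_fp8_per_row_reference(x)
w_q, w_s = quantize_fp8_per_row_reference(w)
print(rowwise_scaled_matmul_reference(x_q, x_s, w_q, w_s).shape)  # torch.Size([4, 8])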
@@ -173,48 +190,79 @@ def forward(self, hidden_states):
             expert_hidden = hidden_states[i]
             expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size)
             # Quantize for this expert
-            expert_quantized, expert_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+            expert_quantized, expert_scale = quantize_fp8_per_row(
                 expert_hidden_reshaped, num_tokens, self.input_scale_ub
             )
             sharded_expert_dim = self.gate_up_proj.shape[-1] // 2
             gate_up_proj_scale_float32 = self.gate_up_proj_scale.to(torch.float32)
+            if _is_torch_xpu_available:
+                gate = torch._scaled_mm(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous().t(),
+                    scale_a=expert_scale.unsqueeze(-1),
+                    scale_b=gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous().t(),
+                    out_dtype=hidden_states.dtype,
+                )
+                up = torch._scaled_mm(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous().t(),
+                    scale_a=expert_scale.unsqueeze(-1),
+                    scale_b=gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous().t(),
+                    out_dtype=hidden_states.dtype,
+                )
+            else:
+                gate = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
+                    expert_scale,
+                    gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
+                    use_fast_accum=True,
+                )
 
-            gate = torch.ops.fbgemm.f8f8bf16_rowwise(
-                expert_quantized,
-                self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
-                expert_scale,
-                gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
-                use_fast_accum=True,
-            )
-
-            up = torch.ops.fbgemm.f8f8bf16_rowwise(
-                expert_quantized,
-                self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
-                expert_scale,
-                gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
-                use_fast_accum=True,
-            )
+                up = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
+                    expert_scale,
+                    gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
+                    use_fast_accum=True,
+                )
 
             activated = up * self.act_fn(gate)
 
-            activated_quantized, activated_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-                activated, num_tokens, self.input_scale_ub
-            )
+            activated_quantized, activated_scale = quantize_fp8_per_row(activated, num_tokens, self.input_scale_ub)
 
             down_proj_scale_float32 = self.down_proj_scale.to(torch.float32)
-            expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
-                activated_quantized,
-                self.down_proj[i].transpose(0, 1).contiguous(),
-                activated_scale,
-                down_proj_scale_float32[i].view(-1, 1).contiguous(),
-                use_fast_accum=True,
-            )
+            if _is_torch_xpu_available:
+                expert_output = torch._scaled_mm(
+                    activated_quantized,
+                    self.down_proj[i].transpose(0, 1).contiguous(),
+                    scale_a=activated_scale.unsqueeze(-1),
+                    scale_b=down_proj_scale_float32[i].view(-1, 1).contiguous().t(),
+                    out_dtype=hidden_states.dtype,
+                )
+            else:
+                expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    activated_quantized,
+                    self.down_proj[i].transpose(0, 1).contiguous(),
+                    activated_scale,
+                    down_proj_scale_float32[i].view(-1, 1).contiguous(),
+                    use_fast_accum=True,
+                )
 
             next_states[i] = expert_output
         next_states = next_states.to(hidden_states.device)
         return next_states.view(-1, self.hidden_size)
 
 
+@lru_cache(maxsize=1)
+def get_quantize_fp8_per_row():
+    if _is_torch_xpu_available:
+        from kernels import get_kernel
+
+        return get_kernel("kernels-community/fp8-fbgemm").quantize_fp8_per_row
+    return torch.ops.fbgemm.quantize_fp8_per_row
+
+
 def replace_with_fbgemm_fp8_linear(
     model, modules_to_not_convert: list[str] | None = None, quantization_config=None, pre_quantized=False, tp_plan=None
 ):
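Stripped of quantization, the per-expert math in the hunk above is a fused gate/up projection split in half at sharded_expert_dim, a gated activation (up * act_fn(gate)), and a down projection. Below is a plain-float sketch of that structure; the shapes and the SiLU activation are illustrative assumptions, and the diff splits the fused weight before the matmul, which is equivalent to splitting the fused output as done here.

import torch
import torch.nn.functional as F

hidden_size, intermediate = 16, 32
x = torch.randn(4, hidden_size)                             # tokens routed to one expert
gate_up_proj = torch.randn(hidden_size, 2 * intermediate)   # fused [gate | up] weight
down_proj = torch.randn(intermediate, hidden_size)

gate, up = torch.chunk(x @ gate_up_proj, 2, dim=-1)  # split at sharded_expert_dim
activated = up * F.silu(gate)                        # `up * self.act_fn(gate)` in the diff
out = activated @ down_proj
print(out.shape)  # torch.Size([4, 16])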
@@ -232,6 +280,8 @@ def replace_with_fbgemm_fp8_linear(
         pre_quantized (`book`, defaults to `False`):
             Whether the model is pre-quantized or not
     """
+    global quantize_fp8_per_row
+    quantize_fp8_per_row = get_quantize_fp8_per_row()
 
     has_been_replaced = False
     module_kwargs = {} if pre_quantized else {"dtype": None}
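The new get_quantize_fp8_per_row() helper and the global binding above form a resolve-once dispatch: the backend is decided when replace_with_fbgemm_fp8_linear() runs, and the quantize op is looked up once rather than re-resolved on every call. A standalone sketch of the same pattern follows; the names resolve_quantize_op and use_xpu are invented, and each branch assumes its backend package is installed.

from functools import lru_cache

import torch

@lru_cache(maxsize=1)
def resolve_quantize_op(use_xpu: bool):
    if use_xpu:
        # XPU: fetch the op from the Hugging Face kernel hub (`pip install kernels`).
        from kernels import get_kernel

        return get_kernel("kernels-community/fp8-fbgemm").quantize_fp8_per_row
    # CUDA: fbgemm-gpu registers the op under torch.ops.fbgemm at import time.
    return torch.ops.fbgemm.quantize_fp8_per_row

use_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
quantize_fp8_per_row = resolve_quantize_op(use_xpu)  # resolved once, cached afterwards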

src/transformers/quantizers/quantizer_fbgemm_fp8.py

Lines changed: 24 additions & 13 deletions
@@ -19,14 +19,21 @@
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
 
-from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
+from ..utils import (
+    is_accelerate_available,
+    is_fbgemm_gpu_available,
+    is_kernels_available,
+    is_torch_available,
+    is_torch_cuda_available,
+    is_torch_xpu_available,
+    logging,
+)
 from .quantizers_utils import get_module_from_name
 
 
 if is_torch_available():
     import torch
 
-
 logger = logging.get_logger(__name__)
 
 
@@ -41,27 +48,32 @@ def __init__(self, quantization_config, **kwargs):
         super().__init__(quantization_config, **kwargs)
 
     def validate_environment(self, *args, **kwargs):
-        if not is_fbgemm_gpu_available():
+        if not is_torch_cuda_available() and not is_torch_xpu_available():
+            raise ImportError("Using fbgemm fp8 quantization requires a GPU or XPU")
+        if is_torch_xpu_available() and not is_kernels_available():
+            raise ImportError("Using FP8 fbgemm on XPU requires kernels (`pip install kernels`)")
+        if is_torch_cuda_available() and not is_fbgemm_gpu_available():
             raise ImportError(
-                "Using fbgemm fp8 quantization requires fbgemm-gpu library"
+                "Loading an FP8 fbgemm quantized model on CUDA requires fbgemm-gpu library"
                 "Please install the latest version of fbgemm-gpu library by following : https://pytorch.org/FBGEMM/fbgemm_gpu-development/InstallationInstructions.html#fbgemm-gpu-install-libraries"
             )
         if not is_accelerate_available():
             raise ImportError(
                 "Loading an FP8 quantized model requires accelerate (`pip install --upgrade accelerate`)"
             )
-        compute_capability = torch.cuda.get_device_capability()
-        major, _ = compute_capability
-        if major < 9:
-            raise ValueError(
-                "FP8 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)"
-            )
+        if is_torch_cuda_available():
+            compute_capability = torch.cuda.get_device_capability()
+            major, _ = compute_capability
+            if major < 9:
+                raise ValueError(
+                    "FP8 quantized models is only supported on GPUs with compute capability >= 9.0 (e.g H100)"
+                )
 
         device_map = kwargs.get("device_map")
         if device_map is None:
             logger.warning_once(
-                "You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set "
-                "your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. "
+                "You have loaded an FP8 model on CPU and have a CUDA/XPU device available, make sure to set "
+                "your model on a GPU/XPU device in order to run your model. To remove this warning, pass device_map = 'cuda' or 'xpu' or 'auto'. "
             )
         elif isinstance(device_map, dict):
             if not self.pre_quantized and ("cpu" in device_map.values() or "disk" in device_map.values()):
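A quick way to check the prerequisites that the reworked validate_environment() above enforces, using the same public helpers it imports (the standalone script itself is illustrative, not part of the commit):

from transformers.utils import (
    is_accelerate_available,
    is_fbgemm_gpu_available,
    is_kernels_available,
    is_torch_cuda_available,
    is_torch_xpu_available,
)

if is_torch_xpu_available():
    assert is_kernels_available(), "XPU path needs `pip install kernels`"
elif is_torch_cuda_available():
    assert is_fbgemm_gpu_available(), "CUDA path needs the fbgemm-gpu library"
else:
    raise RuntimeError("FBGEMM FP8 quantization requires a CUDA GPU or an Intel XPU")
assert is_accelerate_available(), "FP8 loading needs `pip install accelerate`"
print("environment looks good for FBGEMM FP8")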
@@ -121,7 +133,6 @@ def _process_model_before_weight_loading(
             modules_to_not_convert=self.modules_to_not_convert,
             quantization_config=self.quantization_config,
             pre_quantized=self.pre_quantized,
-            config=model.config,
             tp_plan=model._tp_plan,
         )
 
tests/quantization/fbgemm_fp8/test_fbgemm_fp8.py

Lines changed: 23 additions & 12 deletions
@@ -21,14 +21,19 @@
 from transformers.testing_utils import (
     backend_empty_cache,
     require_accelerate,
-    require_fbgemm_gpu,
+    require_deterministic_for_xpu,
     require_read_token,
-    require_torch_gpu,
-    require_torch_multi_gpu,
+    require_torch_accelerator,
+    require_torch_multi_accelerator,
     slow,
     torch_device,
 )
-from transformers.utils import is_accelerate_available, is_torch_available
+from transformers.utils import (
+    is_accelerate_available,
+    is_fbgemm_gpu_available,
+    is_torch_available,
+    is_torch_xpu_available,
+)
 
 
 if is_torch_available():
@@ -38,7 +43,7 @@
     from accelerate import init_empty_weights
 
 
-@require_torch_gpu
+@require_torch_accelerator
 class FbgemmFp8ConfigTest(unittest.TestCase):
     def test_to_dict(self):
         """
@@ -62,8 +67,8 @@ def test_from_dict(self):
 
 
 @slow
-@require_torch_gpu
-@require_fbgemm_gpu
+@require_torch_accelerator
+@unittest.skipIf(not is_torch_xpu_available() and not is_fbgemm_gpu_available(), "test requires fbgemm-gpu or xpu")
 @require_accelerate
 @require_read_token
 class FbgemmFp8Test(unittest.TestCase):
@@ -76,10 +81,11 @@ class FbgemmFp8Test(unittest.TestCase):
         [
            "What are we having for dinner?\nI'm having a steak and a salad",
            "What are we having for dinner? I don’t know. What are we having",
+           "What are we having for dinner? I don’t know, what are you having",
         ]
     )
 
-    device_map = "cuda"
+    device_map = "xpu" if is_torch_xpu_available() else "cuda"
 
     offload_device_map = {
         "model.embed_tokens": 0,
@@ -176,6 +182,7 @@ def test_quantized_model_conversion(self):
 
         self.assertEqual(nb_linears - 24, nb_fbgemm_linear)
 
+    @require_deterministic_for_xpu
     def test_quantized_model(self):
         """
         Simple test that checks if the quantized model is working properly
@@ -185,6 +192,7 @@ def test_quantized_model(self):
         output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
         self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)
 
+    @require_deterministic_for_xpu
     def test_save_pretrained(self):
         """
         Simple test that checks if the quantized model is working properly after being saved and loaded
@@ -219,7 +227,8 @@ def test_change_loading_attributes(self):
         output = model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
         self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)
 
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
+    @require_deterministic_for_xpu
     def test_quantized_model_multi_gpu(self):
         """
         Simple test that checks if the quantized model is working properly with multiple GPUs
@@ -248,6 +257,7 @@ def test_quantized_model_offload(self):
             self.model_name, device_map=self.offload_device_map, quantization_config=quantization_config
         )
 
+    @require_deterministic_for_xpu
     def test_save_pretrained_offload(self):
         """
         Simple test that checks if the saved quantized model is working properly cpu/disk offload
@@ -261,7 +271,8 @@ def test_save_pretrained_offload(self):
         output = quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens, do_sample=False)
         self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)
 
-    @require_torch_multi_gpu
+    @require_torch_multi_accelerator
+    @require_deterministic_for_xpu
     def test_save_pretrained_multi_gpu(self):
         """
         Simple test that checks if the quantized model is working properly after being saved and loaded
@@ -278,9 +289,9 @@ def test_save_pretrained_multi_gpu(self):
         self.assertTrue(self.tokenizer.decode(output[0], skip_special_tokens=True) in self.EXPECTED_OUTPUT)
 
 
-@require_torch_gpu
+@require_torch_accelerator
 @require_accelerate
-@require_fbgemm_gpu
+@unittest.skipIf(not is_torch_xpu_available() and not is_fbgemm_gpu_available(), "test requires fbgemm-gpu or xpu")
 class FbgemmFp8LinearTest(unittest.TestCase):
     def test_linear_preserves_shape(self):
         """
