 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from functools import lru_cache
 from typing import Optional
 
 from ..activations import ACT2FN
 from ..core_model_loading import ConversionOps
 from ..quantizers.quantizers_utils import get_module_from_name, should_convert_module
-from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging
+from ..utils import (
+    is_accelerate_available,
+    is_fbgemm_gpu_available,
+    is_torch_available,
+    is_torch_xpu_available,
+    logging,
+)
 
 
 if is_torch_available():
@@ -27,7 +34,9 @@
 if is_accelerate_available():
     from accelerate import init_empty_weights
 
-if is_fbgemm_gpu_available():
+_is_torch_xpu_available = is_torch_xpu_available()
+
+if is_fbgemm_gpu_available() and not _is_torch_xpu_available:
     import fbgemm_gpu.experimental.gen_ai  # noqa: F401
 
 logger = logging.get_logger(__name__)
@@ -61,7 +70,7 @@ def convert(
             flattened_param = transposed_param.reshape(-1, original_shape[-1])
 
             # Quantize using per row instead of per column
-            new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
+            new_value_flat, weight_scale_flat = quantize_fp8_per_row(flattened_param)
 
             # Reshape back to original dimensions
             new_value = new_value_flat.reshape(original_shape)
@@ -77,14 +86,14 @@ def convert(
             flattened_param = transposed_param.reshape(-1, original_shape[-1])
 
             # Quantize using per column
-            new_value_flat, weight_scale_flat = torch.ops.fbgemm.quantize_fp8_per_row(flattened_param)
+            new_value_flat, weight_scale_flat = quantize_fp8_per_row(flattened_param)
 
             # Reshape back to original dimensions
             new_value = new_value_flat.reshape(original_shape)
             new_value = new_value.transpose(1, 2)
             weight_scale = weight_scale_flat.reshape(original_shape[0], original_shape[1], 1)
         else:
-            new_value, weight_scale = torch.ops.fbgemm.quantize_fp8_per_row(value)
+            new_value, weight_scale = quantize_fp8_per_row(value)
             weight_scale = torch.nn.Parameter(weight_scale.view(weight_scale.shape[0], 1))
 
         return {target_key: torch.nn.Parameter(new_value), f"{target_key}_scale": weight_scale}
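
# A minimal pure-PyTorch sketch of what a per-row FP8 quantize step computes, for readers
# without fbgemm_gpu or the kernels-hub op installed. `reference_quantize_fp8_per_row` is
# illustrative only and is not part of this file; the real kernels may differ in clamping,
# scale_ub handling, and output layout.
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # largest representable e4m3 value


def reference_quantize_fp8_per_row(x: torch.Tensor, scale_ub: torch.Tensor | None = None):
    # One float32 scale per row so that row / scale fits the e4m3 range.
    row_amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    if scale_ub is not None:
        row_amax = torch.minimum(row_amax, scale_ub)
    scale = row_amax.to(torch.float32) / FP8_E4M3_MAX
    x_fp8 = (x / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
    return x_fp8, scale.squeeze(-1)  # fp8 rows plus one float32 scale per row
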
@@ -110,18 +119,26 @@ def forward(self, x):
         output_shape = (*x.shape[:-1], -1)
         # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
         # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
-        x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-            x.view(-1, x.shape[-1]).contiguous(), scale_ub=self.input_scale_ub
-        )
+        x_quantized, x_scale = quantize_fp8_per_row(x.view(-1, x.shape[-1]).contiguous(), scale_ub=self.input_scale_ub)
         # moving x_quantized, x_scale here creates gibberish output ... However, if we move the output, it works
         # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
 
         # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
         weight_scale_float32 = self.weight_scale.to(torch.float32)
-        output = torch.ops.fbgemm.f8f8bf16_rowwise(
-            x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
-        )
-        output = output + self.bias if self.bias is not None else output
+        if _is_torch_xpu_available:
+            output = torch._scaled_mm(
+                x_quantized,
+                self.weight.t(),
+                scale_a=x_scale.unsqueeze(-1),
+                scale_b=weight_scale_float32.t(),
+                out_dtype=x.dtype,
+                bias=self.bias,
+            )
+        else:
+            output = torch.ops.fbgemm.f8f8bf16_rowwise(
+                x_quantized, self.weight, x_scale, weight_scale_float32, use_fast_accum=True
+            )
+            output = output + self.bias if self.bias is not None else output
         # Hacky for now, we move the output to the device of x
         output = output.to(x.device)
         output = output.reshape(output_shape)
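
# The XPU branch above leans on torch._scaled_mm's row-wise scaling convention (a private
# PyTorch API whose details can change between releases). A standalone sketch of the layout
# it assumes, with illustrative names; scales must be float32, and fp8 matmul support on the
# target device is assumed:
import torch


def rowwise_scaled_mm_sketch(x_fp8, w_fp8, x_scale, w_scale, bias=None, out_dtype=torch.bfloat16):
    # x_fp8: (M, K) float8 activations with x_scale: (M,) per-row scales.
    # w_fp8: (N, K) float8 weight with w_scale: (N, 1) per-output-channel scales.
    return torch._scaled_mm(
        x_fp8,
        w_fp8.t(),                      # (K, N) column-major view, as _scaled_mm expects
        scale_a=x_scale.unsqueeze(-1),  # (M, 1), broadcast over rows of the output
        scale_b=w_scale.t(),            # (1, N), broadcast over columns of the output
        out_dtype=out_dtype,
        bias=bias,
    )
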
@@ -173,48 +190,79 @@ def forward(self, hidden_states):
             expert_hidden = hidden_states[i]
             expert_hidden_reshaped = expert_hidden.reshape(-1, self.hidden_size)
             # Quantize for this expert
-            expert_quantized, expert_scale = torch.ops.fbgemm.quantize_fp8_per_row(
+            expert_quantized, expert_scale = quantize_fp8_per_row(
                 expert_hidden_reshaped, num_tokens, self.input_scale_ub
             )
             sharded_expert_dim = self.gate_up_proj.shape[-1] // 2
             gate_up_proj_scale_float32 = self.gate_up_proj_scale.to(torch.float32)
+            if _is_torch_xpu_available:
+                gate = torch._scaled_mm(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous().t(),
+                    scale_a=expert_scale.unsqueeze(-1),
+                    scale_b=gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous().t(),
+                    out_dtype=hidden_states.dtype,
+                )
+                up = torch._scaled_mm(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous().t(),
+                    scale_a=expert_scale.unsqueeze(-1),
+                    scale_b=gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous().t(),
+                    out_dtype=hidden_states.dtype,
+                )
+            else:
+                gate = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
+                    expert_scale,
+                    gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
+                    use_fast_accum=True,
+                )
 
-            gate = torch.ops.fbgemm.f8f8bf16_rowwise(
-                expert_quantized,
-                self.gate_up_proj[i].transpose(0, 1)[:sharded_expert_dim].contiguous(),
-                expert_scale,
-                gate_up_proj_scale_float32[i][0][:sharded_expert_dim].view(-1, 1).contiguous(),
-                use_fast_accum=True,
-            )
-
-            up = torch.ops.fbgemm.f8f8bf16_rowwise(
-                expert_quantized,
-                self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
-                expert_scale,
-                gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
-                use_fast_accum=True,
-            )
+                up = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    expert_quantized,
+                    self.gate_up_proj[i].transpose(0, 1)[sharded_expert_dim:].contiguous(),
+                    expert_scale,
+                    gate_up_proj_scale_float32[i][0][sharded_expert_dim:].view(-1, 1).contiguous(),
+                    use_fast_accum=True,
+                )
 
             activated = up * self.act_fn(gate)
 
-            activated_quantized, activated_scale = torch.ops.fbgemm.quantize_fp8_per_row(
-                activated, num_tokens, self.input_scale_ub
-            )
+            activated_quantized, activated_scale = quantize_fp8_per_row(activated, num_tokens, self.input_scale_ub)
 
             down_proj_scale_float32 = self.down_proj_scale.to(torch.float32)
-            expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
-                activated_quantized,
-                self.down_proj[i].transpose(0, 1).contiguous(),
-                activated_scale,
-                down_proj_scale_float32[i].view(-1, 1).contiguous(),
-                use_fast_accum=True,
-            )
+            if _is_torch_xpu_available:
+                expert_output = torch._scaled_mm(
+                    activated_quantized,
+                    self.down_proj[i].transpose(0, 1).contiguous(),
+                    scale_a=activated_scale.unsqueeze(-1),
+                    scale_b=down_proj_scale_float32[i].view(-1, 1).contiguous().t(),
+                    out_dtype=hidden_states.dtype,
+                )
+            else:
+                expert_output = torch.ops.fbgemm.f8f8bf16_rowwise(
+                    activated_quantized,
+                    self.down_proj[i].transpose(0, 1).contiguous(),
+                    activated_scale,
+                    down_proj_scale_float32[i].view(-1, 1).contiguous(),
+                    use_fast_accum=True,
+                )
 
             next_states[i] = expert_output
         next_states = next_states.to(hidden_states.device)
         return next_states.view(-1, self.hidden_size)
 
 
+@lru_cache(maxsize=1)
+def get_quantize_fp8_per_row():
+    if _is_torch_xpu_available:
+        from kernels import get_kernel
+
+        return get_kernel("kernels-community/fp8-fbgemm").quantize_fp8_per_row
+    return torch.ops.fbgemm.quantize_fp8_per_row
+
+
 def replace_with_fbgemm_fp8_linear(
     model, modules_to_not_convert: list[str] | None = None, quantization_config=None, pre_quantized=False, tp_plan=None
 ):
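
# The new get_quantize_fp8_per_row() resolves the quantize op once (fbgemm's registered op on
# CUDA, the kernels-hub implementation on XPU), and replace_with_fbgemm_fp8_linear() binds the
# result to the module-level name used by the forward passes, so the hot loop never re-checks
# the backend. A stripped-down sketch of that pattern with stand-in callables (not the real
# lookups):
from functools import lru_cache


def _cuda_quantize(x):  # stand-in for torch.ops.fbgemm.quantize_fp8_per_row
    raise NotImplementedError


def _xpu_quantize(x):  # stand-in for the kernels-hub quantize_fp8_per_row
    raise NotImplementedError


@lru_cache(maxsize=1)
def resolve_quantize(use_xpu: bool):
    # Evaluated at most once per flag value; later calls return the cached callable.
    return _xpu_quantize if use_xpu else _cuda_quantize
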
@@ -232,6 +280,8 @@ def replace_with_fbgemm_fp8_linear(
         pre_quantized (`bool`, defaults to `False`):
             Whether the model is pre-quantized or not
     """
+    global quantize_fp8_per_row
+    quantize_fp8_per_row = get_quantize_fp8_per_row()
 
     has_been_replaced = False
     module_kwargs = {} if pre_quantized else {"dtype": None}
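
# End-to-end usage is unchanged by this diff; the backend (fbgemm_gpu kernels on CUDA,
# torch._scaled_mm plus the kernels-hub quantizer on XPU) is selected automatically.
# A minimal, illustrative example; the checkpoint name is a placeholder and fbgemm FP8
# support for the chosen model and device is assumed:
from transformers import AutoModelForCausalLM, AutoTokenizer, FbgemmFp8Config

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=FbgemmFp8Config(),  # routes Linear layers through replace_with_fbgemm_fp8_linear()
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=8)[0], skip_special_tokens=True))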