vllm/model_executor/layers/quantization/utils: 1 file changed, +5 -1 lines

@@ -366,7 +366,7 @@ def per_token_group_quant_fp8(
     dtype: Optional[torch.dtype] = None,
     column_major_scales: bool = False,
     out_q: Optional[torch.Tensor] = None,
-    use_ue8m0: bool = is_blackwell_deep_gemm_used(),
+    use_ue8m0: Optional[bool] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
@@ -383,6 +383,10 @@ def per_token_group_quant_fp8(
         tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
         scaling factor.
     """
+    # TODO(wentao): refactor this
+    # use_ue8m0 should be a global flag that the user can set
+    if use_ue8m0 is None:
+        use_ue8m0 = is_blackwell_deep_gemm_used()
     dtype = current_platform.fp8_dtype() if dtype is None else dtype
     assert (x.shape[-1] % group_size == 0), (
         f"the last dimension of `x` {x.shape[-1]} must be divisible "