
Conversation

jcaip
Contributor

@jcaip jcaip commented Sep 26, 2025

This PR adds support for quantizing nn.Parameter with quantize_ by adding a new config, ModuleOrParamFqnToConfig. This new config is very similar to ModuleFqnToConfig except it also accepts nn.Parameter FQNs.

It also enables ModuleOrParamFqnToConfig for Float8DynamicActivationFloat8WeightConfig. Other configs will raise a NotImplementedError.

I've decided to remove the top-level quantize_(param, config) -> new_param functionality; I will instead expose this as quantize_tensor in a subsequent PR.

API examples

For example, given a toy nn.Linear model:

```python
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(128, 128),
    nn.Linear(128, 128),
)
```

We can quantize the weight of the first linear as follows:

```python
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    ModuleOrParamFqnToConfig,
    PerRow,
)

quant_config = ModuleOrParamFqnToConfig({
    "0.weight": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```

We can quantize all parameters that match "weight" with a regex by prepending re: to the string:

```python
quant_config = ModuleOrParamFqnToConfig({
    "re:.*weight": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
})
```

When both a module and parameters match a regex, the module configs take precedence. We skip swapping parameters for modules that already contain an instance of TorchAOBaseTensor. Below, model[0].weight will be a PerRow-quantized float8 tensor, and we will not try to replace the bias, as module 0 has already been transformed, even though it matches the regex.

```python
quant_config = ModuleOrParamFqnToConfig({
    # replace using module
    "0": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerRow(),
    ),
    # matches 0 (nn.Linear), 0.weight (nn.Param), and 0.bias (nn.Param)
    "re:0": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerTensor(),
    ),
})
```

The config below would quantize both MyBlock.weight and MyBlock.bias, as MyBlock is not an instance of nn.Linear and therefore will not be modified by the module-config flow first.

```python
import torch
import torch.nn as nn

class MyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(128, 128))
        self.bias = nn.Parameter(torch.randn(128))

model = nn.Sequential(
    MyBlock(),
    MyBlock(),
)

quant_config = ModuleOrParamFqnToConfig({
    # matches 0.weight (nn.Param) and 0.bias (nn.Param)
    "re:0": Float8DynamicActivationFloat8WeightConfig(
        granularity=PerTensor(),
    ),
})
```
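In all of the examples above, the config is applied through the usual quantize_ entry point (a minimal usage sketch):

```python
from torchao.quantization import quantize_

quantize_(model, quant_config)
```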

Note that filter_fn is ignored for parameter quantization. It would be possible to enable support for this, but I chose not to since it's kind of a footgun IMO: the default filter_fn is _is_linear, and all MoE model definitions would fail this check.

Test Plan

1. Unit tests for the new config:
   `pytest test/quantization/test_quant_api.py::TestModuleOrParamFqnToConfig`
2. Regression test for ModuleFqnToConfig:
   `pytest test/quantization/test_quant_api.py -k test_module_fqn_to_config`
3. Make sure that we can load old HF checkpoints to maintain BC, run this

How do our configs translate for MoEs?

Currently, we define a bunch of configs that are meant for dense nn.Linear modules; how do these configs translate in the case of MoE inference?

Some background on MoE inference

There are two ways that forward is implemented for MoEs; see the sketch after this list.

  • For loop of nn.Linear - In this case, we break down the 3d weight x activation matmul into a for loop of 2d weight x activation matmuls. This can be seen here.

In this case, I argue that the semantics of the configs do not change at all from the normal nn.Linear case, as we are just doing a bunch of normal 2d linear matmuls.

  • bmm/grouped mm on the 3d weights / activations directly.

For this case, we'd need to add additional op support (bmm) for the forward pass. Depending on whether the subclass is an AQT subclass or a non-AQT subclass, this will be added differently.
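A toy sketch of the two forward styles, to make the shapes concrete (the names and dimensions here are illustrative, not taken from the PR):

```python
import torch
import torch.nn.functional as F

E, T, N, K = 4, 32, 256, 128   # experts, tokens per expert, out_features, in_features
w = torch.randn(E, N, K)       # 3d expert weights
x = torch.randn(E, T, K)       # activations already routed per expert

# style 1: for-loop of ordinary 2d nn.Linear-style matmuls, one per expert
out_loop = torch.stack([F.linear(x[e], w[e]) for e in range(E)])

# style 2: a single bmm over the 3d weights/activations
out_bmm = torch.bmm(x, w.transpose(-2, -1))

assert torch.allclose(out_loop, out_bmm, atol=1e-4)
```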

I plan to only support parameter quantization for non-AQT subclasses, my reasoning being that those are the most popular / important configs anyway (Float8Dynamic, Int4WeightOnly).

Below is a breakdown of what Configs map to AQT / non-AQT subclasses:

| not using AQT | AffineQuantizedTensor |
| --- | --- |
| Float8DynamicActivationFloat8WeightConfig | FPXWeightOnlyConfig |
| Float8DynamicActivationInt4WeightConfig | Float8WeightOnlyConfig |
| Float8StaticActivationFloat8WeightConfig | Float8DynamicActivationFloat8SemiSparseWeightConfig |
| Int4WeightOnlyConfig (v2) | GemliteUIntXWeightOnlyConfig |
| | Int4DynamicActivationInt4WeightConfig |
| | Int8DynamicActivationInt4WeightConfig |
| | Int8DynamicActivationInt8WeightConfig |
| | Int8WeightOnlyConfig |
| | IntxWeightOnlyConfig |
| | UIntXWeightOnlyConfig |

For these, the majority of the semantics remain the same; the only one that really changes is PerRow granularity, and there's a very natural extension of PerRow to the 3d case (apply it along the last dimension).

I took a look at the keys of the non-AQT configs below and what they would mean for MoEs.

Float8DynamicActivationFloat8WeightConfig

```
[('activation_dtype', <class 'torch.dtype'>),
 ('weight_dtype', <class 'torch.dtype'>),
 ('granularity', typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow'), typing.List[typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')]], NoneType]),
 ('mm_config', typing.Optional[torchao.float8.inference.Float8MMConfig]),
 ('activation_value_lb', typing.Optional[float]),
 ('activation_value_ub', typing.Optional[float]),
 ('kernel_preference', <enum 'KernelPreference'>),
 ('set_inductor_config', <class 'bool'>),
 ('version', <class 'int'>)]
```

• activation_dtype, weight_dtype, activation_value_lb, and activation_value_ub do not change meaning semantically.
• granularity=PerTensor() does not change semantic meaning - we still use a single scale for the entire weight tensor.
• granularity=PerRow() does change meaning - we now calculate a scale along the last dimension (dim=-1), i.e. for a weight of shape (E, N, K) we would expect PerRow to create scales with block size (1, 1, K).
• mm_config, kernel_preference, and set_inductor_config stay the same as well.
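A minimal sketch of what PerRow scaling along the last dim of a 3d weight would look like (illustrative only, not the torchao kernel):

```python
import torch

E, N, K = 8, 256, 128
w = torch.randn(E, N, K, dtype=torch.bfloat16)

# block size (1, 1, K): one scale per (expert, row) pair -> scales of shape (E, N, 1)
amax = w.abs().amax(dim=-1, keepdim=True).to(torch.float32)
scale = amax / torch.finfo(torch.float8_e4m3fn).max
w_fp8 = (w.to(torch.float32) / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
```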

Float8StaticActivationFloat8WeightConfig

```
[('scale', <class 'torch.Tensor'>),
 ('activation_dtype', <class 'torch.dtype'>),
 ('weight_dtype', <class 'torch.dtype'>),
 ('granularity', typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow'), typing.Tuple[typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')], typing.Union[ForwardRef('PerTensor'), ForwardRef('PerRow')]], NoneType]),
 ('mm_config', typing.Optional[torchao.float8.inference.Float8MMConfig]),
 ('set_inductor_config', <class 'bool'>)]
```

scale should be passed in as a 3d tensor instead of a 2d tensor in the case of PerRow granularity

Float8DynamicActivationInt4WeightConfig

```
[('int4_packing_format', <enum 'Int4PackingFormat'>)]
```

int4_packing_format - Only "preshuffled" is supported and Int4PreshuffledTensor supports 3d weights.

Int4WeightOnlyConfig

```
[('group_size', <class 'int'>),
 ('layout', typing.Optional[torchao.dtypes.uintx.tensor_core_tiled_layout.TensorCoreTiledLayout]),
 ('use_hqq', <class 'bool'>),
 ('zero_point_domain', typing.Optional[torchao.quantization.quant_primitives.ZeroPointDomain]),
 ('set_inductor_config', <class 'bool'>),
 ('preserve_zero', typing.Optional[bool]),
 ('int4_packing_format', <enum 'Int4PackingFormat'>),
 ('int4_choose_qparams_algorithm', <enum 'Int4ChooseQParamsAlgorithm'>),
 ('version', <class 'int'>)]
```

group_size, int4_packing_format, int4_choose_qparams_algorithm, and set_inductor_config are the only fields used for the v2 config.

I don't think the semantics of these change, although there are some packing formats that do not support 3d weights (it looks like Int4PackingFormat.PLAIN_INT32 and Int4PackingFormat.MARLIN_SPARSE).

Summary: This PR adds in a simple 2d and 3d MoE implementation and tests `quantize_` on them to see if we get the same results.

Test Plan:
```
pytest test/prototype/test_parameter.py -k test_quantize_parameter
```

pytorch-bot bot commented Sep 26, 2025

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3083

As of commit 829d31f with merge base 5346f0e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed label Sep 26, 2025
@jcaip jcaip requested review from jerryzh168 and vkuzo September 26, 2025 21:00
@jerryzh168
Contributor

current AOBaseConfig is more for linear weights, can it be extended to param config cleanly?

@vkuzo
Contributor

vkuzo commented Sep 29, 2025

Add in ParamFqnToConfig config
This new config is very similar to ModuleFqnToConfig except it takes in nn.Parameter FQNs and also supports regexs.

Would it work to stick with ModuleFqnToConfig and update its meaning, to avoid introducing a new object with a lot of similarities with the old object? Pseudocode of what it could do:

```python
def handle_module(model, fqn, config):
    if has_parameter(model, fqn):
        ...  # new behavior for parameters, apply parameter swap config
    elif has_parameter(model, fqn + '.weight'):
        ...  # old behavior, apply parameter swap config
    elif has_module(model, fqn):
        ...  # old behavior, apply module swap
```
@jcaip
Contributor Author

jcaip commented Sep 29, 2025

Would it work to stick with ModuleFqnToConfig and update its meaning, to avoid introducing a new object with a lot of similarities with the old object?

Yeah, we can do this. Do you think we should keep the ModuleFqnToConfig name? It's a little confusing I feel to pass in parameter fqn but it's also being used by huggingface and vllm so I think it would be better to keep it as is.

@jcaip
Contributor Author

jcaip commented Sep 29, 2025

current AOBaseConfig is more for linear weights, can it be extended to param config cleanly?

Yes I believe so, especially in the case of the Config object itself. We attach everything to the weight parameter for nn.Linear, so this allows us to specify the parameter name instead of assuming it's "weight".

The only thing that does not map cleanly IMO is the module_registration:

```python
# non user facing code
@register_quantize_module_handler(WorkflowFooConfig)
def _transform(
    mod: torch.nn.Module,
    config: WorkflowFooConfig,
) -> torch.nn.Module:
    # the transform is implemented here, usually a tensor subclass
    # weight swap or a module swap
    ...
```

I think we should define the transform for parameters as the base case (aka @register_quantize_handler), and use that for the module flow (assuming the parameter is module.weight), since it's the more general case.

@vkuzo
Contributor

vkuzo commented Sep 29, 2025

Do you think we should keep the ModuleFqnToConfig name? It's a little confusing I feel to pass in parameter fqn but it's also being used by huggingface and vllm so I think it would be better to keep it as is.

IMO we should change the current name and keep the old name for BC:

```python
ParamOrModuleFqnToConfig = ...

# for bc
ModuleFqnToConfig = ParamOrModuleFqnToConfig
```
@vkuzo
Contributor

vkuzo commented Sep 29, 2025

I think we should define the transform for parameters as the base case

To me it seems that the transform has to be for modules, because it is inplace. User can target a parameter if they want to, but the transform function always runs on a module that owns the parameter.

# skip if not direct child
if "." not in name:
    for pattern in config.param_fqn_to_config:
        if re.match(pattern, f"{fqn}.{name}"):
Contributor

so it applies to all params, regardless of what it is? e.g. bias? should we be more specific in what people are configuring?

Contributor Author

I think we should consider the regex syntax separately; I can remove it from this PR.

One thing I would like would be for quantize_ to log the modules/params it's swapping, so it's easy to see what the difference is.

@andrewor14
Contributor

Does this mean we need to refactor all supported configs to use this structure?

```python
@register_quantized_param_handler(config)
def _float8_dynamic_activation_float8_weight_quantize_tensor(...):
    # returns quantized tensor
    ...

def _float8_dynamic_activation_float8_weight_transform(...):
    module.weight = _float8_dynamic_activation_float8_weight_quantize_tensor(...)
    return module
```


@dataclass
class ModuleOrParamFqnToConfig(AOBaseConfig):
Contributor

how about just adding the logic to ModuleFqnToConfig, as I suggested in one of my previous comments

Contributor Author

We'll have some merge conflicts with #3084, do you want to land yours first @jerryzh168 and then I can rebase?

Contributor

@jerryzh168 jerryzh168 Oct 3, 2025

yeah sure, I'll update the PR now

Contributor

@vkuzo vkuzo Oct 6, 2025

ping on my original comment

Contributor Author

@jcaip jcaip Oct 6, 2025

Do you think we should keep the ModuleFqnToConfig name? It's a little confusing I feel to pass in parameter fqn but it's also being used by huggingface and vllm so I think it would be better to keep it as is.

IMO we should change the current name and keep the old name for BC:

```python
ParamOrModuleFqnToConfig = ...

# for bc
ModuleFqnToConfig = ParamOrModuleFqnToConfig
```

I thought we agreed on renaming ModuleFqnToConfig but unifying on a single object? Is there another comment you're referring to?

Contributor

you are right, I misread the code, sorry

Contributor

This PR adds support for quantizing nn.Parameter with quantize_ by adding a new config, ModuleOrParamFqnToConfig. This new config is very similar to ModuleFqnToConfig except it also accepts nn.Parameter FQNs.

I missed it because the PR summary still said we are adding a new object; can we update the summary to reflect the current state of the PR?

Contributor

how about FqnToConfig, and assert inline that the thing pointed to by FQN is a module or a parameter, throw an exception on other attributes? IMO simpler name that will cover the known use cases.

Contributor Author

That makes sense to me.

for pattern, param_config in config.module_or_param_fqn_to_config.items():
    full_param_fqn = f"{fqn}.{name}"
    if (pattern == full_param_fqn) or (
        pattern[:3] == "re:" and re.search(pattern[3:], f"{fqn}.{name}")
Contributor

should we use re.fullmatch? since that's the behavior we want right?
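For context, a toy illustration of the difference (assuming a pattern like "0.weight"):

```python
import re

# re.search matches anywhere in the string; re.fullmatch requires the whole FQN to match
assert re.search(r"0\.weight", "layers.10.weight")             # unintended partial match
assert re.fullmatch(r"0\.weight", "layers.10.weight") is None  # no match
assert re.fullmatch(r"0\.weight", "0.weight")                  # intended exact match
```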

Contributor Author

Yeah I'm onboard, will update this PR.

Contributor

I also updated https://github.com/pytorch/ao/pull/3084/files; we can align the implementation to check regex as well.

@jcaip jcaip requested review from jerryzh168 and vkuzo October 6, 2025 19:04
class ModuleOrParamFqnToConfig(AOBaseConfig):
"""Configuration class for applying different quantization configs to modules or parameters based on their fully qualified names (FQNs).
This extends the functionality of ModuleFqnToConfig to support parameter-level quantization configurations
Contributor

nit: comment seems stale

@vkuzo
Contributor

vkuzo commented Oct 7, 2025

Add in ParamFqnToConfig config
This new config is very similar to ModuleFqnToConfig except it takes in nn.Parameter FQNs and also supports regexs.

Would it work to stick with ModuleFqnToConfig and update its meaning, to avoid introducing a new object with a lot of similarities with the old object? Pseudocode of what it could do:

```python
def handle_module(model, fqn, config):
    if has_parameter(model, fqn):
        ...  # new behavior for parameters, apply parameter swap config
    elif has_parameter(model, fqn + '.weight'):
        ...  # old behavior, apply parameter swap config
    elif has_module(model, fqn):
        ...  # old behavior, apply module swap
```

@jcaip would this be simpler than having two transform registration systems?

@jcaip
Contributor Author

jcaip commented Oct 7, 2025

Would it work to stick with ModuleFqnToConfig and update its meaning, to avoid introducing a new object with a lot of similarities with the old object? Pseudocode of what it could do:

```python
def handle_module(model, fqn, config):
    if has_parameter(model, fqn):
        ...  # new behavior for parameters, apply parameter swap config
    elif has_parameter(model, fqn + '.weight'):
        ...  # old behavior, apply parameter swap config
    elif has_module(model, fqn):
        ...  # old behavior, apply module swap
```

@jcaip would this be simpler than having two transform registration systems?

cc @vkuzo

Hmm, I think the pseudocode mentioned here vs. the logic in the PR is somewhat orthogonal to having two transform registration systems; it's possible to have one registration system with the logic in the PR as well. I'm assuming your main concern is with having two registration systems? Let me know if that's not the case.

IMO it's about the same complexity to have one registration system vs two. My main preference for having two registration systems is that it reduces the amount of work we have to do to enable other Configs for parameter quantization - we just need to add the decorator to our from_hp or from_float class function. In the case of having a shared registration system, we'd need to modify each existing transform function manually to add non-weight param support.

@vkuzo
Contributor

vkuzo commented Oct 7, 2025

I'm assuming your main concern is with having two registration systems?

yes, and even further IMO we should have a single "modify module inplace" paradigm instead of having one paradigm for modules and one for parameters

My main preference for having two registration systems is that it reduces the amount of work we have to do to enable other Configs for parameter quantization

IMO we should go for the solution where the resulting code is the simplest, if that involves manual work that seems OK to me, and we can parallelize the conversions if you don't want to do them alone. Reducing the work to convert but ending up with two systems seems like trading dev time now for increased system complexity later.

@jcaip
Contributor Author

jcaip commented Oct 7, 2025

OK I'll update the PR to use a single registration system.

yes, and even further IMO we should have a single "modify module inplace" paradigm instead of having one paradigm for modules and one for parameters

One thing I want to point out is that it's difficult to support stuff like our vLLM integration, where we pass in a parameter that's not tied to any module, with a single "modify module inplace" paradigm.

@vkuzo
Contributor

vkuzo commented Oct 7, 2025

One thing I want to point out is that it's difficult to support stuff like our vLLM integration, where we pass in a parameter that's not tied to any module, with a single "modify module inplace" paradigm.

I think "everything is parameters" is also a valid solution, I just don't think we should have both - let's pick one?

`module_fqn_to_config`: typing.OrderedDict[str, Optional[AOBaseConfig]]: an
ordered dictionary from
(1). fully qualified name (fqn) of module or
module_fqn_to_config (OrderedDict[str, Optional[AOBaseConfig]]): An ordered dictionary mapping
Contributor

nit: use typing.OrderedDict since it's different from collections.OrderedDict

Raises:
NotImplementedError: If a configuration type doesn't have a registered parameter handler.
"""
top_level_named_parameters_list = [
Contributor

nit: is this the same as list(dict(mod_containing_param.named_parameters()).items())

for name, param in top_level_named_parameters_list:
    for pattern, param_config in config.module_or_param_fqn_to_config.items():
        full_param_fqn = f"{fqn}.{name}"
        if (pattern == full_param_fqn) or (
Contributor

@jerryzh168 jerryzh168 Oct 7, 2025

btw, if we want exact match (==) to take precedence, I think it has to be a separate check:

```python
if pattern == full_param_fqn:
    ...
elif pattern.startswith("re:") and ...:
    ...
```

A test of

```
model: with linear1 module
config: {"re:linear.*": config1, "linear1": config2}
```

where linear1 should be quantized with config2 instead of config1 should catch it.

"0": Float8DynamicActivationFloat8WeightConfig(
granularity=PerRow(),
),
"re:.*weight": Float8DynamicActivationFloat8WeightConfig(
Contributor

@jerryzh168 jerryzh168 Oct 7, 2025

we should test the reverse order I think, to make sure 0 takes precedence

quantize_(
    model,
    quant_config,
)
Contributor

checks?

`module_fqn_to_config`: typing.OrderedDict[str, Optional[AOBaseConfig]]: an
ordered dictionary from
(1). fully qualified name (fqn) of module or
module_fqn_to_config (OrderedDict[str, Optional[AOBaseConfig]]): An ordered dictionary mapping
Contributor

@jerryzh168 jerryzh168 Oct 7, 2025

also to correct the naming, we can add a module_or_param_fqn_to_config field and use that for version 2, and go through the normal version update path like other configs as well I think

Contributor

how about just fqn_to_config

Contributor

yeah sounds good

Contributor

@jerryzh168 jerryzh168 left a comment

@jcaip can you add ModuleOrParamFqnToConfig to torchao docs as well? I would like to link to it in transformer docs

Comment on lines 519 to +526
_replace_with_custom_fn_if_matches_filter_with_name(
    model,
    _module_fqn_to_config_handler,
    filter_fn,
    device=device,
    extra_args=(config,),
)
_replace_with_custom_fn_if_matches_filter_with_name(
Contributor

@jerryzh168 jerryzh168 Oct 8, 2025

also can we just do a single replacement? that might be simpler

the older functionality is a special case of the new one with param_name="weight" so seems like we can use the same code path for everything?

Contributor Author

Yeah, before I was doing this in two passes because we were doing re.search, so something like re:linear would match the fqn of the module, the weight, and the bias. But now that we're using re.fullmatch I think we can unify this.

Note:
- The order of patterns in the OrderedDict may matter as only the first matching pattern is applied
- Parameters that are already TorchAOBaseTensor instances are skipped to avoid double quantization
- "_default" is ignored for parameter replacement.
Contributor

OK if we merge the replacement code, I guess this option has to be valid for params as well, or we need to rename this to something else? or just remove this option?

Contributor Author

This is really "default for linear", since we assume filter_fn is _is_linear if not specified. I think it makes the most sense to remove it; users could just use the regex support to apply a config to all linears explicitly by their FQN.

Contributor

yeah makes sense, can you keep BC in this PR, and we can change all the published checkpoints and deprecate this separately

Labels: CLA Signed, topic: new feature