Skip to content

remove str option for quantization config in torchao#13291

Open
howardzhang-cv wants to merge 4 commits intohuggingface:mainfrom
howardzhang-cv:update_torchao_test
Open

remove str option for quantization config in torchao#13291
howardzhang-cv wants to merge 4 commits intohuggingface:mainfrom
howardzhang-cv:update_torchao_test

Conversation

@howardzhang-cv
Copy link
Contributor

What does this PR do?

Remove the deprecated string-based quant_type path from TorchAoConfig, requiring AOBaseConfig instances instead.

  • TorchAoConfig.init now only accepts AOBaseConfig subclass instances (e.g. Int8WeightOnlyConfig()) and raises TypeError for strings
  • Deleted ~200 lines of dead code: _get_torchao_quant_type_to_method, _is_xpu_or_cuda_capability_atleast_8_9, TorchAoJSONEncoder, and all string-parsing branches in post_init, to_dict, from_dict, get_apply_tensor_subclass
  • Simplified torchao_quantizer.py: removed string-based branches in update_torch_dtype, adjust_target_dtype, get_cuda_warm_up_factor; fixed is_trainable which would crash on AOBaseConfig objects
  • Converted all test cases from string quant types to their AOBaseConfig equivalents; removed test_floatx_quantization (no replacement for fpx_weight_only)
  • Updated docs to show only AOBaseConfig-based usage

Testing

python -m pytest tests/quantization/torchao/test_torchao.py -xvs

Who can review?

@sayakpaul

Comment on lines +217 to +227
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)

# Map the extracted digit to appropriate dtype
if size_digit == "4":
return CustomDtype.INT4
elif quant_type == "uintx_weight_only":
return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
elif quant_type.startswith("uint"):
return {
1: torch.uint1,
2: torch.uint2,
3: torch.uint3,
4: torch.uint4,
5: torch.uint5,
6: torch.uint6,
7: torch.uint7,
}[int(quant_type[4])]
elif quant_type.startswith("float") or quant_type.startswith("fp"):
return torch.bfloat16

elif is_torchao_version(">", "0.9.0"):
from torchao.core.config import AOBaseConfig

quant_type = self.quantization_config.quant_type
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)

# Map the extracted digit to appropriate dtype
if size_digit == "4":
return CustomDtype.INT4
else:
# Default to int8
return torch.int8
else:
# Default to int8
return torch.int8

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems a bit fragile, I think it's from transformers originally, not sure if this is still needed in transformers though, it might have been refactored after 5.0 update

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry might be a dumb question, what part are you referring to that's from transformers originally? Is it the entire adjust_target_dtype function?

for pattern, target_dtype in map_to_target_dtype.items():
if fnmatch(quant_type, pattern):
return target_dtype
if isinstance(quant_type, AOBaseConfig):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, when do we use this? cc @sayakpaul

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do what?

f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
if is_torchao_version("<=", "0.9.0"):
Copy link

@jerryzh168 jerryzh168 Mar 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

separate PR: I feel we should just have a single assertion for torchao to be a relatively recent version (e.g. 0.15) and remove all these version checks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually yeah I was going to ask you about that as well. There's a couple version checks scattered around right now. Would be cleaner to just remove all of them.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we should mandate a minimum version requirement here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to wait to do this in a separate PR? I changed and set it to 0.9.0 because that's when AOBaseConfig was supported. Moving to 0.15.0 might be cleaner in a separate PR in case we need to revert for whatever reason?

@howardzhang-cv howardzhang-cv marked this pull request as ready for review March 19, 2026 21:43
Copy link
Member

@sayakpaul sayakpaul left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for starting this work!

I think we can merge this fairly soon.

return torch.int8
elif quant_type.startswith("int4"):
quant_type = self.quantization_config.quant_type
if isinstance(quant_type, AOBaseConfig):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the fallback path? Would it be better to mandate that the users pass a valid AOBaseConfig instance?

for pattern, target_dtype in map_to_target_dtype.items():
if fnmatch(quant_type, pattern):
return target_dtype
if isinstance(quant_type, AOBaseConfig):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do what?

f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
if is_torchao_version("<=", "0.9.0"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah we should mandate a minimum version requirement here.

Comment on lines +478 to +479
if not isinstance(self.quant_type, AOBaseConfig):
raise TypeError(f"quant_type must be an AOBaseConfig instance, got {type(self.quant_type).__name__}")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes cool!

Comment on lines +488 to +490
# For now we assume there is 1 config per Transformer, however in the future
# We may want to support a config per fqn.
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be really nice! I think we should also provide a reference link to the TorchAO docs to remind ourselves what that granularity means.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

@sayakpaul
Copy link
Member

@bot /style

@github-actions
Copy link
Contributor

github-actions bot commented Mar 20, 2026

Style bot fixed some files and pushed the changes.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants