remove str option for quantization config in torchao #13291
howardzhang-cv wants to merge 4 commits into huggingface:main
Conversation
```python
if isinstance(quant_type, AOBaseConfig):
    # Extract size digit using fuzzy match on the class name
    config_name = quant_type.__class__.__name__
    size_digit = fuzzy_match_size(config_name)

    # Map the extracted digit to appropriate dtype
    if size_digit == "4":
        return CustomDtype.INT4
elif quant_type == "uintx_weight_only":
    return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
elif quant_type.startswith("uint"):
    return {
        1: torch.uint1,
        2: torch.uint2,
        3: torch.uint3,
        4: torch.uint4,
        5: torch.uint5,
        6: torch.uint6,
        7: torch.uint7,
    }[int(quant_type[4])]
elif quant_type.startswith("float") or quant_type.startswith("fp"):
    return torch.bfloat16

elif is_torchao_version(">", "0.9.0"):
    from torchao.core.config import AOBaseConfig

    quant_type = self.quantization_config.quant_type
    if isinstance(quant_type, AOBaseConfig):
        # Extract size digit using fuzzy match on the class name
        config_name = quant_type.__class__.__name__
        size_digit = fuzzy_match_size(config_name)

        # Map the extracted digit to appropriate dtype
        if size_digit == "4":
            return CustomDtype.INT4
        else:
            # Default to int8
            return torch.int8
    else:
        # Default to int8
        return torch.int8
```
this seems a bit fragile, I think it's from transformers originally, not sure if this is still needed in transformers though, it might have been refactored after 5.0 update
Sorry, this might be a dumb question: what part are you referring to that's from transformers originally? Is it the entire `adjust_target_dtype` function?
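For context on why the matching feels fragile: the helper only needs to recover a bit-width digit from a config class name such as `Int4WeightOnlyConfig`. A minimal sketch of that idea (the actual `fuzzy_match_size` implementation may differ):

```python
import re


def fuzzy_match_size(config_name: str):
    """Illustrative sketch only: pull the first digit out of a torchao
    config class name, e.g. "Int4WeightOnlyConfig" -> "4".

    Returns None when the name carries no bit-width digit, which is
    why callers need an explicit int8 fallback.
    """
    match = re.search(r"\d", config_name)
    return match.group(0) if match else None
```

Any class name without a digit falls through to the default-int8 branch in the hunk above, which is the fragility being discussed.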
```python
for pattern, target_dtype in map_to_target_dtype.items():
    if fnmatch(quant_type, pattern):
        return target_dtype
if isinstance(quant_type, AOBaseConfig):
```
same here, when do we use this? cc @sayakpaul
```python
    f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
    f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
)
if is_torchao_version("<=", "0.9.0"):
```
separate PR: I feel we should just have a single assertion that torchao is a relatively recent version (e.g. 0.15) and remove all these version checks
Actually, yeah, I was going to ask you about that as well. There are a couple of version checks scattered around right now. It would be cleaner to just remove all of them.
Yeah we should mandate a minimum version requirement here.
Do we want to wait and do this in a separate PR? I changed it to 0.9.0 because that's when `AOBaseConfig` was supported. Moving to 0.15.0 might be cleaner in a separate PR in case we need to revert for whatever reason.
Force-pushed from ed295a7 to d49eb2e
sayakpaul left a comment:
Thanks a lot for starting this work!
I think we can merge this fairly soon.
```python
    return torch.int8
elif quant_type.startswith("int4"):
quant_type = self.quantization_config.quant_type
if isinstance(quant_type, AOBaseConfig):
```
What's the fallback path? Would it be better to mandate that the users pass a valid AOBaseConfig instance?
```python
if not isinstance(self.quant_type, AOBaseConfig):
    raise TypeError(f"quant_type must be an AOBaseConfig instance, got {type(self.quant_type).__name__}")
```
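This guard is easy to exercise in isolation. A self-contained sketch using a stand-in base class (the real check imports `AOBaseConfig` from `torchao.core.config`, and the stand-in config names here are only illustrative):

```python
class AOBaseConfig:  # stand-in for torchao.core.config.AOBaseConfig
    pass


class Int8WeightOnlyConfig(AOBaseConfig):  # stand-in for a torchao config class
    pass


def validate_quant_type(quant_type):
    # Mirrors the PR's isinstance guard: string quant_types are no
    # longer accepted, only AOBaseConfig instances.
    if not isinstance(quant_type, AOBaseConfig):
        raise TypeError(
            f"quant_type must be an AOBaseConfig instance, got {type(quant_type).__name__}"
        )
    return quant_type
```

With this in place, the old string form like `"int8_weight_only"` fails fast with a clear error instead of silently hitting the fuzzy-matching paths.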
```python
# For now we assume there is 1 config per Transformer, however in the future
# we may want to support a config per fqn.
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
```
This would be really nice! I think we should also provide a reference link to the TorchAO docs to remind ourselves what that granularity means.
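The serialized shape this hunk produces is simple to sketch. Assuming `config_to_dict` flattens an `AOBaseConfig` into plain types, the `"default"` key applies one config to the whole Transformer, leaving room for per-fqn keys later (the fqn example key below is hypothetical):

```python
def serialize_quant_type(config_dict):
    # "default" applies a single config to the entire model; a future
    # per-fqn layout could add keys like "transformer_blocks.0.attn1"
    # alongside (or instead of) "default".
    return {"default": config_dict}


d = {}
d["quant_type"] = serialize_quant_type({"group_size": 128})
```

Keeping the dict nested one level from the start means the per-fqn extension would not be a breaking change to the serialized format.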
@bot /style

Style bot fixed some files and pushed the changes.
What does this PR do?
Removes the deprecated string-based `quant_type` path from `TorchAoConfig`, requiring `AOBaseConfig` instances instead.
Testing
python -m pytest tests/quantization/torchao/test_torchao.py -xvs

Who can review?
@sayakpaul