Fix GGUF to Work Better with modules_to_not_convert / keep_in_fp32_modules#13697
Fix GGUF to Work Better with modules_to_not_convert / keep_in_fp32_modules#13697
modules_to_not_convert / keep_in_fp32_modules#13697Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
| # If the GGUFParameter should not be quantized (for example, it is a submodule of any excluded module), | ||
| # dequantize it and set the (dequantized) parameter to the proper dtype. | ||
| if isinstance(param_value, GGUFParameter) and any( | ||
| m in param_name.split(".") for m in self.modules_to_not_convert | ||
| ): | ||
| keep_in_fp32 = getattr(self, "keep_in_fp32_modules", []) | ||
| target_dtype = ( | ||
| torch.float32 if any(m in param_name.split(".") for m in keep_in_fp32) else self.compute_dtype | ||
| ) | ||
| param_value = dequantize_gguf_tensor(param_value).to(target_dtype) |
There was a problem hiding this comment.
I am a bit confused. If a param is already GGUFParameter type, then I'd assume that it's already quantized. In that case, how come dequantize -> type upcasting is the right sequence of ops?
What am I missing?
There was a problem hiding this comment.
The idea is that the GGUF checkpoint might specify a quantization for a parameter that we do not want to be quantized, as expressed through either _keep_in_fp32_modules on the model: ModelMixin or modules_to_not_convert on GGUFQuantizationConfig.
When we load the GGUF state dict, these parameters will be placed into a GGUFParameter, and this happens before we load the weights into the model (e.g. in FromOriginalModelMixin.from_single_file). To respect modules_to_not_convert, we need to convert these back into normal (unquantized) parameters, which we do here at load time via dequantize_gguf_tensor. We then need to cast the parameter to the right compute dtype, which is torch.float32 for keep_in_fp32_modules and compute_dtype otherwise.
Currently, GGUFQuantizationConfig doesn't expose a modules_to_not_convert argument, but keep_in_fp32_modules are included in modules_to_not_convert:
So this change would affect only any specified _keep_in_fp32_modules right now.
| if qweight_type in UNQUANTIZED_TYPES: | ||
| weight = dequantize_gguf_tensor(qweight) | ||
| return x @ weight.T | ||
| return x @ weight.to(x.dtype).T |
There was a problem hiding this comment.
Would it break torch.compile compatibility for models that don't define modules_to_not_convert / keep_in_fp32_modules?
There was a problem hiding this comment.
I'm not sure how it will interact with torch.compile, but this change mirrors the implementation used for quantized weight types (qweight_type in DEQUANT_TYPES):
diffusers/src/diffusers/quantizers/gguf/utils.py
Lines 98 to 99 in d773308
So I think it should be fine? (I think this change isn't specific to modules_to_not_convert, as the GGUF checkpoint could store weights in e.g. BF16 even if modules_to_not_convert is empty, which would then go through this code path.)
sayakpaul
left a comment
There was a problem hiding this comment.
Thanks! I left some comments. I think there should be a test for this in
What does this PR do?
This PR contains several fixes so the GGUF loading and inference work better with
module_to_not_convertand_keep_in_fp32_modules.Changelist
src/diffusers/quantizers/gguf/utils.py_replace_with_gguf_linear: adds a check to see if any of the current module'snamed_childrenare inmodules_to_not_convert, and if so, skip it. This allows us skip containers, rather than just leaf-levelnn.Linearsubmodules as in the current code. For example,TimestepEmbeddingmodules are commonly added to_keep_in_fp32_modules(e.g.time_embedderinWanTransformer3DModel'sWanTimeTextImageEmbeddingcondition embedder), but since they themselves contain leafnn.Linearsubmodules such aslinear_1, the current code will only check against leaf modules such aslinear_1, and conclude incorrectly that they should be converted._fused_mul_mat_gguf: in theUNQUANTIZED_TYPEScase, also cast the dequantizedweightto the activationx'sdtypebefore performing the matrix multiplication, which should prevent dtype errors for BF16 weights.src/diffusers/quantizers/gguf/gguf_quantizer.pyGGUFQuantizer.create_quantized_param: handlesmodules_to_not_convertby dequantizing them, so that they end up in their original unquantized form. This is intended to handle the case where a module inself.modules_to_not_convert(or one of its children) is in the GGUF file. Since it is in the file, it will be converted to aGGUFParameter, but we don't want it to be quantized, so we convert it back here.Inspired by GGUF debugging in #13551, in particular #13551 (comment).
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@DN6
@sayakpaul