diff --git a/src/diffusers/quantizers/gguf/gguf_quantizer.py b/src/diffusers/quantizers/gguf/gguf_quantizer.py
index 15f39dd9605e..b71092ff77cf 100644
--- a/src/diffusers/quantizers/gguf/gguf_quantizer.py
+++ b/src/diffusers/quantizers/gguf/gguf_quantizer.py
@@ -29,6 +29,7 @@
     _dequantize_gguf_and_restore_linear,
     _quant_shape_from_byte_shape,
     _replace_with_gguf_linear,
+    dequantize_gguf_tensor,
 )
 
 
@@ -116,6 +117,17 @@ def create_quantized_param(
         if tensor_name not in module._parameters and tensor_name not in module._buffers:
             raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
 
+        # If the GGUFParameter should not be quantized (for example, it is a submodule of any excluded module),
+        # dequantize it and set the (dequantized) parameter to the proper dtype.
+        if isinstance(param_value, GGUFParameter) and any(
+            m in param_name.split(".") for m in self.modules_to_not_convert
+        ):
+            keep_in_fp32 = getattr(self, "keep_in_fp32_modules", [])
+            target_dtype = (
+                torch.float32 if any(m in param_name.split(".") for m in keep_in_fp32) else self.compute_dtype
+            )
+            param_value = dequantize_gguf_tensor(param_value).to(target_dtype)
+
         if tensor_name in module._parameters:
             module._parameters[tensor_name] = param_value.to(target_device)
         if tensor_name in module._buffers:
@@ -130,7 +142,8 @@ def _process_model_before_weight_loading(
     ):
         state_dict = kwargs.get("state_dict", None)
 
-        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.keep_in_fp32_modules = [module for module in keep_in_fp32_modules if module is not None]
+        self.modules_to_not_convert.extend(self.keep_in_fp32_modules)
         self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
 
         _replace_with_gguf_linear(
diff --git a/src/diffusers/quantizers/gguf/utils.py b/src/diffusers/quantizers/gguf/utils.py
index e0ad0e1cce42..c7d9ec89bee6 100644
--- a/src/diffusers/quantizers/gguf/utils.py
+++ b/src/diffusers/quantizers/gguf/utils.py
@@ -80,7 +80,7 @@ def _fused_mul_mat_gguf(x: torch.Tensor, qweight: torch.Tensor, qweight_type: in
     # there is no need to call any kernel for fp16/bf16
     if qweight_type in UNQUANTIZED_TYPES:
         weight = dequantize_gguf_tensor(qweight)
-        return x @ weight.T
+        return x @ weight.to(x.dtype).T
 
     # TODO(Isotr0py): GGUF's MMQ and MMVQ implementation are designed for
     # contiguous batching and inefficient with diffusers' batching,
@@ -134,6 +134,8 @@ def _should_convert_to_gguf(state_dict, prefix):
         return
 
     for name, module in model.named_children():
+        if name in modules_to_not_convert:
+            continue
         module_prefix = prefix + name + "."
         _replace_with_gguf_linear(module, compute_dtype, state_dict, module_prefix, modules_to_not_convert)
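
For context, the new branch in `create_quantized_param` boils down to a small dtype-selection rule for parameters that live under an excluded module. The sketch below is not diffusers code: the helper name `resolve_excluded_param_dtype` and the example module names (`proj_out`, `time_embed`) are made up for illustration, but the matching logic mirrors the patch. A parameter counts as excluded when any component of its dotted name appears in `modules_to_not_convert`; it is then dequantized to `torch.float32` if it also matches `keep_in_fp32_modules`, otherwise to `compute_dtype`.

```python
from typing import List, Optional

import torch


def resolve_excluded_param_dtype(
    param_name: str,
    modules_to_not_convert: List[str],
    keep_in_fp32_modules: List[str],
    compute_dtype: torch.dtype,
) -> Optional[torch.dtype]:
    """Return the dtype an excluded parameter should be dequantized to,
    or None if it is not excluded and stays as a quantized GGUF weight."""
    parts = param_name.split(".")
    if not any(m in parts for m in modules_to_not_convert):
        return None  # not excluded: keep the quantized GGUF weight as-is
    # keep_in_fp32_modules wins over compute_dtype for excluded parameters
    if any(m in parts for m in keep_in_fp32_modules):
        return torch.float32
    return compute_dtype


# Hypothetical exclusion lists: "proj_out" is excluded and pinned to fp32,
# "time_embed" is excluded but follows the compute dtype.
assert resolve_excluded_param_dtype(
    "proj_out.weight", ["proj_out", "time_embed"], ["proj_out"], torch.bfloat16
) == torch.float32
assert resolve_excluded_param_dtype(
    "time_embed.linear_1.weight", ["proj_out", "time_embed"], ["proj_out"], torch.bfloat16
) == torch.bfloat16
assert resolve_excluded_param_dtype(
    "blocks.0.attn.to_q.weight", ["proj_out", "time_embed"], ["proj_out"], torch.bfloat16
) is None
```

The two `utils.py` hunks complement this: the `continue` stops `_replace_with_gguf_linear` from recursing into excluded submodules, so their linear layers are never swapped for GGUF-dequantizing ones, and the `weight.to(x.dtype)` cast in `_fused_mul_mat_gguf` presumably guards against a dtype mismatch when an unquantized (e.g. fp32) GGUF tensor is multiplied against fp16/bf16 activations.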