From f8b17af01ad4b38af3416affcda685731cf2cf70 Mon Sep 17 00:00:00 2001 From: s-zx <2575376715@qq.com> Date: Tue, 10 Mar 2026 23:08:24 +0100 Subject: [PATCH] fix: handle LoKr format keys in Z-image LoRA conversion The _convert_non_diffusers_z_image_lora_to_diffusers function did not consume LoKr-format keys (.lokr_w1, .lokr_w2, .alpha) from external Z-image LoRA checkpoints (e.g. Kohya/LyCORIS), causing a ValueError when the state_dict was not empty after conversion. Add handling to convert LoKr decomposition to standard lora_A/lora_B format: for linear layers, lokr_w1 @ lokr_w2 maps to lora_B @ lora_A with the same alpha scaling used by other formats. Fixes #13221 Signed-off-by: s-zx --- .../loaders/lora_conversion_utils.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 0895d5223e13..52c5a3bae26e 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2624,6 +2624,38 @@ def get_alpha_scales(down_weight, alpha_key): converted_state_dict[diffusers_down] = down_weight * scale_down converted_state_dict[diffusers_up] = up_weight * scale_up + # Handle LoKr format: .alpha, .lokr_w1, .lokr_w2 (e.g. from Kohya/LyCORIS Z-image trainers). + # Assumed decomposition: delta = alpha * (lokr_w1 @ lokr_w2), mapped to LoRA as lora_B @ lora_A. NOTE(review): LyCORIS defines LoKr as a Kronecker product, kron(w1, w2) - confirm this matmul mapping. 
+ lokr_w1_key = ".lokr_w1" + lokr_w2_key = ".lokr_w2" + has_lokr_format = any(lokr_w1_key in k for k in state_dict) + + if has_lokr_format: + lokr_keys = [k for k in list(state_dict.keys()) if lokr_w1_key in k] + for k in lokr_keys: + if k not in state_dict: + continue + if not k.endswith(lokr_w1_key): + continue + + base = k[: -len(lokr_w1_key)] + lokr_w2_key_full = base + lokr_w2_key + alpha_key = base + ".alpha" + + if lokr_w2_key_full not in state_dict or alpha_key not in state_dict: + continue + + lokr_w1 = state_dict.pop(k) + lokr_w2 = state_dict.pop(lokr_w2_key_full) + scale_down, scale_up = get_alpha_scales(lokr_w2, alpha_key) + + # lora_A = lokr_w2 (r, in), lora_B = lokr_w1 (out, r) + diffusers_a_key = base + ".lora_A.weight" + diffusers_b_key = base + ".lora_B.weight" + converted_state_dict[diffusers_a_key] = lokr_w2 * scale_down + converted_state_dict[diffusers_b_key] = lokr_w1 * scale_up + state_dict.pop(alpha_key, None) + if len(state_dict) > 0: raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")