From 3fc7ca8b764a1f5bd4522339666d811c25156fdf Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Wed, 10 Dec 2025 17:08:00 +0530
Subject: [PATCH 1/5] fix: lora checkpoint

Signed-off-by: Mehant Kammakomati
---
 .../utils/checkpoint_utils.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index 98b49fa5..7b5fecba 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -646,15 +646,22 @@ def recover_safetensors_from_dcp(
     # create switch based on state dict for future use
     new_state_dict = {}
     lora = False
+    lora_keys = {}
     for name, param in state_dict.items():
         # if lora weight, set lora switch to true
         if "lora_A" in name or "lora_B" in name:
             lora = True
         # if lora naming convention, convert to traditional
         if "base_model.model." in name:
+            v = name
             name = name.replace("base_model.model.", "", 1)
+            k = name
+            lora_keys[k] = v
         if "default." in name:
+            v = name
             name = name.replace("default.", "", 1)
+            k = name
+            lora_keys[k] = v
         new_state_dict[name] = param
 
     # recover the original state dict
@@ -662,9 +669,16 @@ def recover_safetensors_from_dcp(
         new_state_dict, _name_or_path
     )
 
+    new_state_dict = {}
+    # modify the state dict back to HF PEFT format
+    for name, param in state_dict.items():
+        if name in lora_keys.keys():
+            name = lora_keys[name]
+        new_state_dict[name] = param
+
     # save it as a safetensors file
     save_sharded_safetensors(
-        {k: v.contiguous() for k, v in state_dict.items()},
+        {k: v.contiguous() for k, v in new_state_dict.items()},
         output_dir,
         metadata={"format": "pt"},
         lora=lora,

From adfc3c249b26b9b0f9941c2674d4d05caef5143c Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Wed, 10 Dec 2025 17:18:33 +0530
Subject: [PATCH 2/5] fix: lora checkpoint

Signed-off-by: Mehant Kammakomati
---
 .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index 7b5fecba..6c8ecbc1 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -654,7 +654,9 @@ def recover_safetensors_from_dcp(
         # if lora naming convention, convert to traditional
         if "base_model.model." in name:
             v = name
-            name = name.replace("base_model.model.", "", 1)
+            if "default." in name:
+                name = name.replace("base_model.model.", "", 1)
+                name = name.replace("default.", "", 1)
             k = name
             lora_keys[k] = v
         if "default." in name:

From 3bf6a672cf2eaf39b7673ca21901b03b9eb544e2 Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Wed, 10 Dec 2025 17:31:14 +0530
Subject: [PATCH 3/5] fix: lora checkpoint

Signed-off-by: Mehant Kammakomati
---
 .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index 6c8ecbc1..1c3a1b55 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -654,8 +654,8 @@ def recover_safetensors_from_dcp(
         # if lora naming convention, convert to traditional
         if "base_model.model." in name:
             v = name
+            name = name.replace("base_model.model.", "", 1)
             if "default." in name:
-                name = name.replace("base_model.model.", "", 1)
                 name = name.replace("default.", "", 1)
             k = name
             lora_keys[k] = v

From 85c1e42c29acd2bff220a17d5b2469178924c72a Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Wed, 10 Dec 2025 17:32:13 +0530
Subject: [PATCH 4/5] fix: lora checkpoint

Signed-off-by: Mehant Kammakomati
---
 .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index 1c3a1b55..008bcad3 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -677,7 +677,7 @@ def recover_safetensors_from_dcp(
         if name in lora_keys.keys():
             name = lora_keys[name]
         new_state_dict[name] = param
-
+
     # save it as a safetensors file
     save_sharded_safetensors(
         {k: v.contiguous() for k, v in new_state_dict.items()},

From af233acf798117afd0b404ab781c52f29d8a2a4e Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Wed, 10 Dec 2025 17:37:42 +0530
Subject: [PATCH 5/5] fix: lora checkpoint

Signed-off-by: Mehant Kammakomati
---
 .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index 008bcad3..03986806 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -674,7 +674,7 @@ def recover_safetensors_from_dcp(
     new_state_dict = {}
     # modify the state dict back to HF PEFT format
     for name, param in state_dict.items():
-        if name in lora_keys.keys():
+        if lora_keys.get(name, None):
             name = lora_keys[name]
         new_state_dict[name] = param
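
Note (not part of the patch series): the five patches converge on a key round-trip: pass 1 strips the HF PEFT prefixes ("base_model.model." and "default.") while recording a stripped-name -> original-name map in lora_keys, and pass 2 uses that map to restore the PEFT names before the safetensors save. The standalone sketch below illustrates that round-trip under assumptions: the toy state dict and its tensor values are illustrative, the recovery call between the two passes is elided, and the loop is condensed to the common case where both prefixes appear together.

# Minimal standalone sketch of the key round-trip the patch series
# converges on. The toy state dict is illustrative, not a real checkpoint;
# the original-state-dict recovery between the passes is elided.
import torch

state_dict = {
    "base_model.model.layers.0.mlp.lora_A.default.weight": torch.zeros(2, 2),
    "model.embed_tokens.weight": torch.zeros(2, 2),
}

# Pass 1: strip the PEFT prefixes, remembering stripped -> original names
# (final form of the loop, after PATCH 3/5).
new_state_dict, lora_keys, lora = {}, {}, False
for name, param in state_dict.items():
    if "lora_A" in name or "lora_B" in name:
        lora = True
    if "base_model.model." in name:
        v = name
        name = name.replace("base_model.model.", "", 1)
        if "default." in name:
            name = name.replace("default.", "", 1)
        lora_keys[name] = v
    new_state_dict[name] = param

# ... the recovery step would run on new_state_dict here ...

# Pass 2: map recovered keys back to HF PEFT naming before saving
# (final form of the rename, after PATCH 5/5).
final_state_dict = {}
for name, param in new_state_dict.items():
    if lora_keys.get(name, None):
        name = lora_keys[name]
    final_state_dict[name] = param

assert lora
assert "base_model.model.layers.0.mlp.lora_A.default.weight" in final_state_dict
assert "model.embed_tokens.weight" in final_state_dict

One design note: PATCH 5/5 swaps the `name in lora_keys.keys()` membership test for a single `lora_keys.get(name, None)` lookup; since the stored values are non-empty name strings, truthiness and membership agree here, and the rename then reuses the same lookup.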