diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
index 98b49fa5..03986806 100644
--- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
+++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py
@@ -646,15 +646,24 @@ def recover_safetensors_from_dcp(
     # create switch based on state dict for future use
     new_state_dict = {}
     lora = False
+    lora_keys = {}
     for name, param in state_dict.items():
         # if lora weight, set lora switch to true
         if "lora_A" in name or "lora_B" in name:
             lora = True
         # if lora naming convention, convert to traditional
         if "base_model.model." in name:
+            v = name
             name = name.replace("base_model.model.", "", 1)
+            if "default." in name:
+                name = name.replace("default.", "", 1)
+            k = name
+            lora_keys[k] = v
         if "default." in name:
+            v = name
             name = name.replace("default.", "", 1)
+            k = name
+            lora_keys[k] = v
         new_state_dict[name] = param
 
     # recover the original state dict
@@ -662,9 +671,16 @@
         new_state_dict, _name_or_path
     )
 
+    new_state_dict = {}
+    # modify the state dict back to HF PEFT format
+    for name, param in state_dict.items():
+        if lora_keys.get(name, None):
+            name = lora_keys[name]
+        new_state_dict[name] = param
+
     # save it as a safetensors file
     save_sharded_safetensors(
-        {k: v.contiguous() for k, v in state_dict.items()},
+        {k: v.contiguous() for k, v in new_state_dict.items()},
         output_dir,
         metadata={"format": "pt"},
         lora=lora,
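
For reference, here is a minimal sketch of the key round trip this patch implements: `lora_keys` remembers each original HF PEFT key so that, after the recovery step has operated on the stripped names, the tensors can be saved back under their PEFT names. This is an illustrative simplification, not the patch itself: the example key strings are hypothetical, the tensor values are stand-in strings, and the two renaming branches are collapsed into unconditional stripping.

```python
# Hypothetical PEFT-style keys; real checkpoints hold tensors, not strings.
state_dict = {
    "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight": "A0",
    "base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight": "B0",
}

lora_keys = {}       # stripped name -> original HF PEFT name
new_state_dict = {}  # state dict under stripped names, fed to the recovery step
for name, param in state_dict.items():
    original = name
    if "base_model.model." in name:
        name = name.replace("base_model.model.", "", 1)
    if "default." in name:
        name = name.replace("default.", "", 1)
    if name != original:
        lora_keys[name] = original
    new_state_dict[name] = param

# ... in the patch, recover_original_state_dict_from_checkpoint() runs here ...

# restore the HF PEFT naming before saving the safetensors shards
restored = {lora_keys.get(k, k): v for k, v in new_state_dict.items()}
assert set(restored) == set(state_dict)
```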