
Commit c68bf05

Authored by Will Johnson
feat: lora for accelerated MoE - limited (#141)
* limit changes
* lora filtering
* naming
* fix: hardcodes
* fix: target modules
* remove lora config from expert weights
* fix: pass in lora config
* fix: requires grad
* fix: don't turn off requires grad

Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
1 parent: 2a3a10f · commit: c68bf05

5 files changed: 95 additions, 42 deletions


plugins/accelerated-moe/.pylintrc

Lines changed: 1 addition & 1 deletion
@@ -281,7 +281,7 @@ ignored-parents=
 max-args=5
 
 # Maximum number of attributes for a class (see R0902).
-max-attributes=7
+max-attributes=8
 
 # Maximum number of boolean expressions in an if statement (see R0916).
 max-bool-expr=5

plugins/accelerated-moe/src/fms_acceleration_moe/framework_plugin_scattermoe.py

Lines changed: 2 additions & 0 deletions
@@ -77,6 +77,7 @@ def augmentation(
         modifiable_args: Tuple[LoraConfig],
     ):
         rank, world_size = 0, 1
+        (peft_config,) = modifiable_args
         if torch.distributed.is_initialized():
             world_size = torch.distributed.get_world_size()
             # we do not need to use the fallback as this is wrapped in an `is_initialized` block
@@ -97,6 +98,7 @@ def augmentation(
             ep_degree=self._ep_degree,
             disable_distributed=self._disable_distributed,
             mixed_precision=False, # Currently this is hardcoded to OFF
+            lora_config=peft_config,
         )
         return model, modifiable_args
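
The new unpacking assumes the framework hands the PEFT LoraConfig to the plugin as a one-element tuple in modifiable_args. A minimal sketch, with illustrative values not taken from this PR, of the tuple shape that `(peft_config,) = modifiable_args` expects:

    from peft import LoraConfig

    # Hypothetical adapter settings; only the tuple shape matters here.
    peft_config = LoraConfig(r=8, lora_alpha=16, bias="none")
    modifiable_args = (peft_config,)        # one-element tuple, as typed in the signature above
    (unpacked_config,) = modifiable_args    # mirrors the unpacking in augmentation()
    assert unpacked_config is peft_config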

plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py

Lines changed: 90 additions & 32 deletions
@@ -53,6 +53,8 @@
 KEY_MODEL = "model"
 KEY_OPTIMIZER = "optimizer"
 
+ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
+
 # Below are rewrite of HF FSDP model saving functions to be able to handle
 # that the parameters are now a mixture of regular and Dtensors.
 # - these functions are found in accelerate.utils.fsdp_utils.py
@@ -110,16 +112,30 @@ def save_fsdp_optimizer(
     # get the state dicts for model and optimize
     (model_state_dict, optimizer_state_dict) = get_state_dict(model, optimizer)
 
+    # filter out lora state dict
+    lora_state_dict = {
+        k: v for k, v in model_state_dict.items() if "lora_A" in k or "lora_B" in k
+    }
+
     # - save model
-    ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
-    os.makedirs(ckpt_model, exist_ok=True)
-    logger.info(f"Saving model to {ckpt_model}")
-    dcp.save(
-        state_dict={KEY_MODEL: model_state_dict},
-        storage_writer=dcp.FileSystemWriter(ckpt_model),
-        planner=DefaultSavePlanner(),
-    )
-    logger.info(f"Model saved to {ckpt_model}")
+    if lora_state_dict:
+        ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
+        os.makedirs(ckpt_model, exist_ok=True)
+        logger.info(f"Saving lora model to {ckpt_model}")
+        dcp.save(
+            state_dict={KEY_MODEL: lora_state_dict},
+            storage_writer=dcp.FileSystemWriter(ckpt_model),
+            planner=DefaultSavePlanner(),
+        )
+    else:
+        ckpt_model = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{MODEL_INDEX}")
+        os.makedirs(ckpt_model, exist_ok=True)
+        logger.info(f"Saving ft model to {ckpt_model}")
+        dcp.save(
+            state_dict={KEY_MODEL: model_state_dict},
+            storage_writer=dcp.FileSystemWriter(ckpt_model),
+            planner=DefaultSavePlanner(),
+        )
 
     # - save optimizer
     ckpt_opt = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}")
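
A small standalone sketch of the lora_A/lora_B filter used above, run against a toy state dict with hypothetical key names (in save_fsdp_optimizer the real dict comes from get_state_dict(model, optimizer)):

    import torch

    # Toy state dict; the key names are illustrative only.
    model_state_dict = {
        "model.layers.0.mlp.w1.lora_A.weight": torch.zeros(8, 32),
        "model.layers.0.mlp.w1.lora_B.weight": torch.zeros(32, 8),
        "model.layers.0.mlp.w1.weight": torch.zeros(32, 32),
    }

    # Same comprehension as above: keep only the adapter tensors.
    lora_state_dict = {
        k: v for k, v in model_state_dict.items() if "lora_A" in k or "lora_B" in k
    }

    # Non-empty, so the lora branch is taken and only adapter weights reach dcp.save().
    print(sorted(lora_state_dict))  # the two lora_A / lora_B keys
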
@@ -467,30 +483,54 @@ def save_sharded_safetensors(
     save_directory: str,
     metadata: Dict,
     max_shard_size: Union[int, str] = "5GB",
+    lora: bool = False,
 ):
-    filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
-        ".safetensors", "{suffix}.safetensors"
-    )
-    state_dict_split = split_torch_state_dict_into_shards(
-        input_state_dict,
-        filename_pattern=filename_pattern,
-        max_shard_size=max_shard_size,
-    )
-    index = {
-        "metadata": state_dict_split.metadata,
-        "weight_map": state_dict_split.tensor_to_filename,
-    }
-    # Save the index
-    with open(
-        os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
-    ) as f:
-        content = json.dumps(index, indent=2, sort_keys=True) + "\n"
-        f.write(content)
+    if not lora:
+        filename_pattern = SAFE_WEIGHTS_NAME.replace(".bin", "{suffix}.bin").replace(
+            ".safetensors", "{suffix}.safetensors"
+        )
+        state_dict_split = split_torch_state_dict_into_shards(
+            input_state_dict,
+            filename_pattern=filename_pattern,
+            max_shard_size=max_shard_size,
+        )
 
-    filename_to_tensors = state_dict_split.filename_to_tensors.items()
-    for shard_file, tensors in filename_to_tensors:
-        shard = {tensor: input_state_dict[tensor].contiguous() for tensor in tensors}
-        save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
+        index = {
+            "metadata": state_dict_split.metadata,
+            "weight_map": state_dict_split.tensor_to_filename,
+        }
+        # Save the index
+        with open(
+            os.path.join(save_directory, SAFE_WEIGHTS_INDEX_NAME), "w", encoding="utf-8"
+        ) as f:
+            content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+            f.write(content)
+
+        filename_to_tensors = state_dict_split.filename_to_tensors.items()
+        for shard_file, tensors in filename_to_tensors:
+            shard = {
+                tensor: input_state_dict[tensor].contiguous() for tensor in tensors
+            }
+            save_file(
+                shard, os.path.join(save_directory, shard_file), metadata=metadata
+            )
+    else:
+        filename_pattern = ADAPTER_SAFE_WEIGHTS_NAME.replace(
+            ".bin", "{suffix}.bin"
+        ).replace(".safetensors", "{suffix}.safetensors")
+        state_dict_split = split_torch_state_dict_into_shards(
+            input_state_dict,
+            filename_pattern=filename_pattern,
+            max_shard_size=max_shard_size,
+        )
+        filename_to_tensors = state_dict_split.filename_to_tensors.items()
+        for shard_file, tensors in filename_to_tensors:
+            shard = {
+                tensor: input_state_dict[tensor].contiguous() for tensor in tensors
+            }
+            save_file(
+                shard, os.path.join(save_directory, shard_file), metadata=metadata
+            )
 
 
 # --------------------------- SCRIPT -------------------------
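
A sketch of how the adapter filename pattern in the else branch resolves, using plain strings; the "{suffix}" placeholder is filled by split_torch_state_dict_into_shards, and the sharded suffix below is only illustrative:

    ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"

    # Only the ".safetensors" replacement applies, since the name contains no ".bin".
    filename_pattern = ADAPTER_SAFE_WEIGHTS_NAME.replace(
        ".bin", "{suffix}.bin"
    ).replace(".safetensors", "{suffix}.safetensors")

    print(filename_pattern)                    # adapter_model{suffix}.safetensors
    print(filename_pattern.format(suffix=""))  # adapter_model.safetensors (single shard)
    print(filename_pattern.format(suffix="-00001-of-00002"))  # hypothetical sharded name

Note also, visible in the diff, that the lora branch skips writing the SAFE_WEIGHTS_INDEX_NAME index file, which fits the usual single-file adapter_model.safetensors layout that PEFT produces.
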
@@ -540,14 +580,32 @@ def recover_safetensors_from_dcp(
     # get the state_dict
     state_dict = loader(checkpoint_dir)
 
+    # filter out additional names created by lora tuning
+    # create switch based on state dict for future use
+    new_state_dict = {}
+    lora = False
+    for name, param in state_dict.items():
+        # if lora weight, set lora switch to true
+        if "lora_A" in name or "lora_B" in name:
+            lora = True
+        # if lora naming convention, convert to traditional
+        if "base_model.model." in name:
+            name = name.replace("base_model.model.", "", 1)
+        if "default." in name:
+            name = name.replace("default.", "", 1)
+        new_state_dict[name] = param
+
     # recover the original state dict
-    state_dict = recover_original_state_dict_from_checkpoint(state_dict, _name_or_path)
+    state_dict = recover_original_state_dict_from_checkpoint(
+        new_state_dict, _name_or_path
+    )
 
     # save it as a safetensors file
     save_sharded_safetensors(
         {k: v.contiguous() for k, v in state_dict.items()},
         output_dir,
         metadata={"format": "pt"},
+        lora=lora,
     )
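
A condensed sketch of the key normalization above, run on hypothetical checkpoint names; the real loop additionally guards each replace with an "in name" check and then feeds new_state_dict and the lora switch into the calls that follow:

    state_dict = {
        "base_model.model.model.layers.0.mlp.w1.lora_A.default.weight": 1,
        "model.layers.0.mlp.router.weight": 2,
    }

    new_state_dict, lora = {}, False
    for name, param in state_dict.items():
        if "lora_A" in name or "lora_B" in name:
            lora = True                                  # adapter weights are present
        name = name.replace("base_model.model.", "", 1)  # drop the peft wrapper prefix
        name = name.replace("default.", "", 1)           # drop the adapter name segment
        new_state_dict[name] = param

    print(lora)                    # True
    print(sorted(new_state_dict))  # keys are back in plain base-model naming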

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py

Lines changed: 0 additions & 8 deletions
@@ -17,7 +17,6 @@
 
 # Third Party
 from peft import LoraConfig
-from peft.utils import INCLUDE_LINEAR_LAYERS_SHORTHAND
 from torch.distributed._tensor import DTensor
 
 # pylint: disable=import-error
@@ -237,10 +236,6 @@ def __init__(
         assert (
             lora_config.bias == "none"
         ), "ScatterMoE currently unable to handle bias in the lora adapters"
-        assert (
-            lora_config.target_modules == INCLUDE_LINEAR_LAYERS_SHORTHAND
-            or INCLUDE_LINEAR_LAYERS_SHORTHAND in lora_config.target_modules
-        ), "ScatterMoe currently only handles lora adapters on all linears."
 
         assert lora_config.init_lora_weights in {
             True,
@@ -286,7 +281,6 @@ def __init__(
             grouped_out=True,
             dtype=dtype,
             device=device,
-            lora_config=lora_config,
         )
         self.w2 = ScatteredExperts(
             in_features=self.intermediate_size,
@@ -296,7 +290,6 @@ def __init__(
             grouped_in=True,
             dtype=dtype,
             device=device,
-            lora_config=lora_config,
         )
         if mlp_arch == SCATTERMOE_SPEC_HAS_GATE:
             self.w3 = ScatteredExperts(
@@ -307,7 +300,6 @@ def __init__(
                 grouped_out=True,
                 dtype=dtype,
                 device=device,
-                lora_config=lora_config,
             )
 
     # referenced from dolomite-engine
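
With the all-linear assertion removed, the checks that remain visible above are on bias and init_lora_weights. An illustrative LoraConfig, not taken from this PR, that would pass them:

    from peft import LoraConfig

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj"],  # no longer forced to the all-linear shorthand
        bias="none",                          # still asserted by ScatterMoE
        init_lora_weights=True,               # True is among the values the assert allows
    )

The same hunk set also drops lora_config from the ScatteredExperts constructors, matching the "remove lora config from expert weights" item in the commit message: the sharded expert weights are now built without it.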

plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_prepare.py

Lines changed: 2 additions & 1 deletion
@@ -92,7 +92,8 @@ def _hook(grad):
 
         # install gradient scaling hook
         if KEY_SCATTERMOE_ROUTER not in weight_name:
-            param.register_hook(_hook)
+            if param.requires_grad:
+                param.register_hook(_hook)
 
         # register the sharded parameter onto the megablocks.dmoe
         mod.register_parameter(name, param)
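
Why the requires_grad guard matters, as a small standalone sketch: with lora tuning the frozen base weights typically do not require grad, and Tensor.register_hook raises a RuntimeError for such tensors, so the hook is now attached only to trainable parameters (names here are illustrative):

    import torch

    def _hook(grad):
        # stand-in for the plugin's gradient scaling hook
        return grad * 0.5

    frozen = torch.nn.Parameter(torch.ones(4), requires_grad=False)  # e.g. a frozen base weight
    trainable = torch.nn.Parameter(torch.ones(4))                    # e.g. a lora_A / lora_B weight

    for param in (frozen, trainable):
        if param.requires_grad:          # the guard added in this commit
            param.register_hook(_hook)

    trainable.sum().backward()
    print(trainable.grad)  # tensor([0.5000, 0.5000, 0.5000, 0.5000]), scaled by the hook
    # frozen was skipped, so no RuntimeError from registering a hook on a
    # tensor that does not require grad.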

0 commit comments
