Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ dependencies = [
"simpleeval>=0.9.13,<2.0",
"pillow>=11.0.0,<12.0",
"kernels<=0.9.0",
"recommender @ git+https://github.com/foundation-model-stack/tuning-config-recommender.git",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we not have a package for the recommender?

Without a PyPI package, this dependency cannot be included in the release.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we do not have a release for the recommender yet!
We can skip this PR for the current minor release if it is not urgent or required.

]

[project.optional-dependencies]
Expand Down
156 changes: 156 additions & 0 deletions tuning/fms-recommender.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
#!/usr/bin/env python3
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The placement of this wrapper inside tuning is a bit unclear.

Do we maintain separate packages with main inside the same folder?
Can we not place this outside?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @ashokponkumar any suggestions

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My reasoning was to keep it at the same level as sft_trainer.py so that the user can simply switch between the recommender wrapper and using sft_trainer.py directly.
The goal is to eventually abstract away the use of sft_trainer.py inside the wrapper, so that over time the user would only use the wrapper; in future updates the wrapper would toggle between recommendation mode and no-recommendation mode.

# Standard
from pathlib import Path
import argparse
import json
import os
import shlex
import subprocess

# Third Party
from recommender.adapters import FMSAdapter
import yaml

# Maps a flag-name prefix to the nested config section that flags carrying
# that prefix are grouped under (e.g. "fsdp_*" flags go under "fsdp_config").
ACCEL_NESTED_PREFIXES = {"fsdp_": "fsdp_config"}

# Names of data-related config keys. Not referenced by the code visible in
# this file -- presumably kept for callers or future use; confirm before
# removing.
DATA_KEYS = {"dataset", "training_data_path", "validation_data_path"}


def grab_flags(tokens, start, end):
    """Collect ``--flag`` style options from ``tokens[start:end]`` into a dict.

    Recognised forms:

    * ``--key=value`` -> ``{"key": "value"}``
    * ``--key value`` -> ``{"key": "value"}``
    * ``--key``       -> ``{"key": True}`` (no token follows, or the next
      token is itself a ``--`` option)

    Tokens that do not start with ``--`` and were not consumed as the value
    of the preceding flag are skipped. Leading/trailing double quotes are
    stripped from values. When a flag repeats, the last occurrence wins.
    Tokens with an empty flag name (a bare ``--`` separator, or ``--=x``)
    are ignored instead of being recorded under the key ``""``.

    Args:
        tokens: Sequence of command-line tokens (strings).
        start: Index of the first token to inspect.
        end: Index one past the last token to inspect; clamped to
            ``len(tokens)``.

    Returns:
        Dict mapping flag names (without the leading ``--``) to either a
        string value or ``True``.
    """
    flags = {}
    # Clamp so an over-long ``end`` cannot index past the token list.
    stop = min(end, len(tokens))
    pos = start
    while pos < stop:
        token = tokens[pos]
        pos += 1
        if not token.startswith("--"):
            continue

        name, has_inline, inline_value = token[2:].partition("=")
        if has_inline:
            value = inline_value.strip('"')
        elif pos < stop and not tokens[pos].startswith("--"):
            # The next token is this flag's value; consume it.
            value = tokens[pos].strip('"')
            pos += 1
        else:
            # Valueless switch.
            value = True

        # A nameless token ("--" / "--=x") is not a real flag; drop it so an
        # empty key never leaks into the resulting config.
        if name:
            flags[name] = value
    return flags


def load_yaml(path):
    """Best-effort load of a YAML mapping from ``path``.

    Args:
        path: Filesystem path (``str``, ``bytes`` or ``os.PathLike``). Any
            other value -- e.g. ``None``, or ``True`` produced by a flag
            that was given without a value -- is treated as "no file".

    Returns:
        The parsed document when it is a ``dict``; otherwise ``{}``. ``{}``
        is also returned when the path is empty or does not exist, or when
        the file cannot be read, decoded or parsed.
    """
    # Reject non-path values up front: a bool/int would be interpreted by
    # os.path.exists()/open() as a file descriptor (True == fd 1, stdout).
    if not isinstance(path, (str, bytes, os.PathLike)):
        return {}
    if not path or not os.path.exists(path):
        return {}
    try:
        # Explicit UTF-8 so the result does not depend on the locale encoding.
        with open(path, "r", encoding="utf-8") as stream:
            loaded = yaml.safe_load(stream)
    except (OSError, UnicodeDecodeError, yaml.YAMLError):
        # Deliberately best-effort: an unreadable config counts as empty.
        return {}
    return loaded if isinstance(loaded, dict) else {}


def nest_accelerate_flags(flat_dist):
    """Group prefixed flags under their nested config section.

    Every key of ``flat_dist`` that starts with a prefix listed in
    ``ACCEL_NESTED_PREFIXES`` is moved, with its full name intact, into a
    sub-dict named after the matching section. All other keys stay at the
    top level. Sections that end up empty are omitted.

    Args:
        flat_dist: Flat mapping of flag name to value.

    Returns:
        A new dict with the top-level keys first, followed by the populated
        sections. A section overrides a top-level key of the same name.
    """
    top_level = {}
    sections = {name: {} for name in ACCEL_NESTED_PREFIXES.values()}

    for flag, value in flat_dist.items():
        # First matching prefix wins; None means the flag stays top-level.
        target = next(
            (
                section
                for prefix, section in ACCEL_NESTED_PREFIXES.items()
                if flag.startswith(prefix)
            ),
            None,
        )
        if target is None:
            top_level[flag] = value
        else:
            sections[target][flag] = value

    populated = {name: body for name, body in sections.items() if body}
    return {**top_level, **populated}


def parse(cmd: str):
    """Split a launch command string into train, distributed and data configs.

    The command is tokenised with ``shlex.split``. If it contains ``-m``,
    trainer flags are read from the tokens after ``-m <module>``; when the
    command also contains both ``accelerate`` and ``launch``, the flags
    before ``-m`` are read as distributed (launcher) flags. Without ``-m``,
    every flag in the command is treated as a trainer flag.

    ``--data_config`` (trainer side) and ``--config_file`` (launcher side)
    are removed from the flag dicts and the YAML files they point to are
    loaded instead.

    Args:
        cmd: The full launch command as a single shell-style string.

    Returns:
        Tuple ``(train, dist, data)`` of dicts: trainer flags, launcher
        config (YAML contents overlaid with nested command-line flags), and
        the contents of the data-config YAML.
    """
    tokens = shlex.split(cmd)
    total = len(tokens)

    if "-m" in tokens:
        module_at = tokens.index("-m")
        # Trainer flags start after "-m <module>".
        train = grab_flags(tokens, module_at + 2, total)
        via_accelerate = "accelerate" in tokens and "launch" in tokens
        dist_flat = grab_flags(tokens, 0, module_at) if via_accelerate else {}
    else:
        train = grab_flags(tokens, 0, total)
        dist_flat = {}

    data_yaml_path = train.pop("data_config", None)
    data = load_yaml(data_yaml_path) if data_yaml_path else {}

    accel_yaml_path = dist_flat.pop("config_file", None)
    accel_yaml = load_yaml(accel_yaml_path) if accel_yaml_path else {}

    # Shallow merge: command-line values win over the YAML, and a section
    # built from flags replaces the YAML's section of the same name wholesale.
    dist = {**accel_yaml, **nest_accelerate_flags(dist_flat)}

    train.pop("config_file", None)
    return train, dist, data


def main():
    """Command-line entry point: wrap a training launch command.

    Usage: ``fms-recommender.py [--debug | --preview] <command ...>``

    The wrapped command is parsed into train/dist/data configs. Then:

    * ``--debug``: print the parsed configs as JSON and return before the
      adapter is created.
    * ``--preview``: run the adapter, print the launch command it produces
      and return without executing it.
    * otherwise: print the launch command and execute it through the shell.

    Exits with an argparse usage error (status 2) when no command is given.

    Raises:
        subprocess.CalledProcessError: If the executed launch command exits
            with a non-zero status.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print parsed configs and exit (no adapter, no execution).",
    )
    parser.add_argument(
        "--preview",
        action="store_true",
        help="Run adapter and show launch command but DO NOT execute it.",
    )
    parser.add_argument("command", nargs=argparse.REMAINDER)
    args = parser.parse_args()

    command = list(args.command)
    # argparse.REMAINDER may keep a leading "--" separator; it is not part of
    # the wrapped command.
    if command and command[0] == "--":
        command = command[1:]
    if not command:
        # Report on stderr with a non-zero exit status instead of printing
        # to stdout and exiting successfully.
        parser.error("No command provided.")

    # shlex.join quotes every argument, so the shlex.split inside parse()
    # recovers the original argv even when an argument contains spaces.
    cmd = shlex.join(command)
    train_cfg, dist_cfg, data_cfg = parse(cmd)
    train_cfg.pop("config_file", None)
    dist_cfg.pop("config_file", None)

    if args.debug:
        print("\n[dist_config]\n", json.dumps(dist_cfg, indent=2))
        print("\n[train_config]\n", json.dumps(train_cfg, indent=2))
        print("\n[data_config]\n", json.dumps(data_cfg, indent=2))
        return

    # NOTE(review): "ouput" in the directory name looks like a typo for
    # "output"; left unchanged because renaming moves where results land.
    adapter = FMSAdapter(base_dir=Path("fms_recommender_ouput/final"))
    ir, patches = adapter.execute(
        train_config=train_cfg,
        dist_config=dist_cfg,
        compute_config={},
        data_config=data_cfg,
        unique_tag="fms-recommender",
    )
    # NOTE(review): _to_target is a private adapter method; prefer a public
    # API once the recommender package exposes one.
    out = adapter._to_target(ir, patches, tag="fms-recommender")
    launch_cmd = out["launch_command"]

    if args.preview:
        print("\n[LAUNCH COMMAND — PREVIEW ONLY]\n")
        print(launch_cmd)
        return

    print("\n[EXECUTING launch command]\n")
    print(launch_cmd)
    # NOTE(review): launch_cmd is a string built by the adapter from
    # user-supplied flags and is run through the shell; confirm the adapter
    # quotes values safely, or switch to an argv list with shell=False.
    subprocess.run(launch_cmd, shell=True, check=True)


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()