4 changes: 4 additions & 0 deletions sagemaker-train/src/sagemaker/train/constants.py
@@ -62,3 +62,7 @@
"mistral.mistral-large-2402-v1:0": ["us-west-2", "us-east-1", "eu-west-1"],
"amazon.nova-pro-v1:0": ["us-east-1"]
}

SM_RECIPE = "recipe"
SM_RECIPE_YAML = "recipe.yaml"
SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/recipe/{SM_RECIPE_YAML}"
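For context: SageMaker mounts each input channel under /opt/ml/input/data/<channel_name>, so a channel named "recipe" surfaces the uploaded recipe.yaml at SM_RECIPE_CONTAINER_PATH. A minimal sketch of reading it inside the training container (the container-side code is illustrative, not part of this PR):

```python
# Illustrative container-side read; assumes the standard SageMaker channel
# mount point /opt/ml/input/data/<channel_name>.
from omegaconf import OmegaConf

SM_RECIPE_CONTAINER_PATH = "/opt/ml/input/data/recipe/recipe.yaml"

recipe = OmegaConf.load(SM_RECIPE_CONTAINER_PATH)  # parse the mounted recipe
print(OmegaConf.to_yaml(recipe))                   # dump the resolved config
```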
57 changes: 52 additions & 5 deletions sagemaker-train/src/sagemaker/train/model_trainer.py
@@ -83,6 +83,9 @@
SM_CODE_CONTAINER_PATH,
SM_DRIVERS,
SM_DRIVERS_LOCAL_PATH,
SM_RECIPE,
SM_RECIPE_YAML,
SM_RECIPE_CONTAINER_PATH,
TRAIN_SCRIPT,
DEFAULT_CONTAINER_ENTRYPOINT,
DEFAULT_CONTAINER_ARGUMENTS,
Expand All @@ -100,7 +103,12 @@
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
from sagemaker.train import logger
-from sagemaker.train.sm_recipes.utils import _get_args_from_recipe, _determine_device_type
+from sagemaker.train.sm_recipes.utils import (
+    _get_args_from_recipe,
+    _determine_device_type,
+    _is_nova_recipe,
+    _load_base_recipe,
+)

from sagemaker.core.jumpstart.configs import JumpStartConfig
from sagemaker.core.jumpstart.document import get_hub_content_and_document
@@ -249,6 +257,7 @@ class ModelTrainer(BaseModel):
_remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None)
_metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None)

# Private Attributes for Recipes
_is_nova_recipe: Optional[bool] = PrivateAttr(default=None)
_temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)

@@ -573,6 +582,22 @@ def _create_training_job_args(

final_input_data_config = list(existing_channels.values()) + new_channels

if self._is_nova_recipe:
for input_data in final_input_data_config:
if input_data.channel_name == SM_RECIPE:
raise ValueError(
"Cannot use reserved channel name 'recipe' as an input channel name "
" for Nova Recipe"
)
recipe_file_path = os.path.join(self._temp_recipe_train_dir.name, SM_RECIPE_YAML)
recipe_channel = self.create_input_data_channel(
channel_name=SM_RECIPE,
data_source=recipe_file_path,
key_prefix=input_data_key_prefix,
)
final_input_data_config.append(recipe_channel)
self.hyperparameters.update({"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH})

if final_input_data_config:
final_input_data_config = self._get_input_data_config(
final_input_data_config, input_data_key_prefix
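A hedged sketch of how the channel injection above surfaces to users (the InputData import path, recipe file, and image URI are assumptions): for a Nova recipe the trainer appends its own "recipe" channel at job creation, so a user-supplied channel may not reuse that name.

```python
# Hedged usage sketch; InputData/Compute import paths are assumed.
from sagemaker.train.configs import Compute, InputData
from sagemaker.train.model_trainer import ModelTrainer

trainer = ModelTrainer.from_recipe(
    training_recipe="nova_sft_recipe.yaml",  # hypothetical Nova recipe file
    training_image="<account>.dkr.ecr.us-east-1.amazonaws.com/nova-train:latest",  # placeholder
    compute=Compute(instance_type="ml.p5.48xlarge"),
)
trainer.train(
    input_data_config=[
        InputData(channel_name="train", data_source="s3://my-bucket/train/"),
        # channel_name="recipe" would collide with the injected channel and raise
        # ValueError: Cannot use reserved channel name 'recipe' ...
    ]
)
```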
@@ -1039,6 +1064,7 @@ def from_recipe(
checkpoint_config: Optional[shapes.CheckpointConfig] = None,
training_input_mode: Optional[str] = "File",
environment: Optional[Dict[str, str]] = None,
hyperparameters: Optional[Union[Dict[str, Any], str]] = None,
tags: Optional[List[Tag]] = None,
sagemaker_session: Optional[Session] = None,
role: Optional[str] = None,
@@ -1136,12 +1162,19 @@
if compute.instance_type is None:
raise ValueError("Must set ``instance_type`` in Compute when using training recipes.")
device_type = _determine_device_type(compute.instance_type)
if device_type == "cpu":
recipe = _load_base_recipe(
training_recipe=training_recipe, recipe_overrides=recipe_overrides
)
is_nova = _is_nova_recipe(recipe=recipe)
if device_type == "cpu" and not is_nova:
Collaborator: nit: should we have only the `device_type == "cpu"` condition here? Not sure the second condition is required.

raise ValueError(
"Training recipes are not supported for CPU instances. "
"Please provide a GPU or Tranium instance type."
)

if training_image is None and is_nova:
raise ValueError("training_image must be provided when using recipe for Nova.")

if training_image_config and training_image is None:
raise ValueError("training_image must be provided when using training_image_config.")

@@ -1154,16 +1187,29 @@
# - distributed
# - compute
# - hyperparameters
-model_trainer_args, recipe_train_dir = _get_args_from_recipe(
-    training_recipe=training_recipe,
+model_trainer_args, tmp_dir = _get_args_from_recipe(
+    training_recipe=recipe,
recipe_overrides=recipe_overrides,
requirements=requirements,
compute=compute,
region_name=sagemaker_session.boto_region_name,
role=role,
)
if training_image is not None:
model_trainer_args["training_image"] = training_image

if hyperparameters and not is_nova:
logger.warning(
"Hyperparameters are not supported for general training recipes. "
+ "Ignoring hyperparameters input."
)
if is_nova:
if hyperparameters and isinstance(hyperparameters, str):
hyperparameters = cls._validate_and_load_hyperparameters_file(hyperparameters)
model_trainer_args["hyperparameters"].update(hyperparameters)
elif hyperparameters and isinstance(hyperparameters, dict):
model_trainer_args["hyperparameters"].update(hyperparameters)

model_trainer = cls(
sagemaker_session=sagemaker_session,
role=role,
@@ -1180,7 +1226,8 @@
**model_trainer_args,
)

-model_trainer._temp_recipe_train_dir = recipe_train_dir
+model_trainer._is_nova_recipe = is_nova
+model_trainer._temp_recipe_train_dir = tmp_dir
return model_trainer

@classmethod
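Pulling the from_recipe changes together, a hedged usage sketch (recipe path, image URI, and role ARN are placeholders): Nova recipes reject CPU instance types, require an explicit training_image, and accept hyperparameters either as a dict or as a path to a JSON file loaded via _validate_and_load_hyperparameters_file; for non-Nova recipes the hyperparameters argument is ignored with a warning.

```python
from sagemaker.train.configs import Compute
from sagemaker.train.model_trainer import ModelTrainer

trainer = ModelTrainer.from_recipe(
    training_recipe="recipes/nova_micro_sft.yaml",  # hypothetical Nova recipe
    training_image="<account>.dkr.ecr.us-east-1.amazonaws.com/nova-train:latest",  # required for Nova
    compute=Compute(instance_type="ml.p5.48xlarge"),  # CPU instance types raise ValueError
    hyperparameters={"max_steps": 100},  # dict, or a path to a JSON overrides file
    role="arn:aws:iam::111122223333:role/SageMakerExecutionRole",
)
```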
136 changes: 127 additions & 9 deletions sagemaker-train/src/sagemaker/train/sm_recipes/utils.py
@@ -19,7 +19,7 @@
import shutil
import tempfile
from urllib.request import urlretrieve
-from typing import Dict, Any, Optional, Tuple
+from typing import Dict, Any, Optional, Tuple, Union

import omegaconf
from omegaconf import OmegaConf, dictconfig
@@ -30,6 +30,7 @@
from sagemaker.train.utils import _run_clone_command_silent
from sagemaker.train.configs import Compute, SourceCode
from sagemaker.train.distributed import Torchrun, SMP
from sagemaker.train.constants import SM_RECIPE_YAML


def _try_resolve_recipe(recipe, key=None):
@@ -86,6 +87,8 @@ def _load_base_recipe(
)
else:
recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")
if training_recipes_cfg is None:
training_recipes_cfg = _load_recipes_cfg()

launcher_repo = os.environ.get("TRAINING_LAUNCHER_GIT", None) or training_recipes_cfg.get(
"launcher_repo"
@@ -149,7 +152,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str):
def _configure_gpu_args(
training_recipes_cfg: Dict[str, Any],
region_name: str,
-recipe: OmegaConf,
+recipe: dictconfig.DictConfig,
recipe_train_dir: tempfile.TemporaryDirectory,
) -> Dict[str, Any]:
"""Configure arguments specific to GPU."""
@@ -234,11 +237,12 @@ def _configure_trainium_args(


def _get_args_from_recipe(
-training_recipe: str,
+training_recipe: Union[str, dictconfig.DictConfig],
compute: Compute,
region_name: str,
recipe_overrides: Optional[Dict[str, Any]],
requirements: Optional[str],
role: Optional[str] = None,
) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]:
"""Get arguments for ModelTrainer from a training recipe.

@@ -254,8 +258,8 @@
```

Args:
-training_recipe (str):
-    Name of the training recipe or path to the recipe file.
+training_recipe (Union[str, dictconfig.DictConfig]):
+    Name of the training recipe, path to the recipe file, or a loaded recipe DictConfig.
compute (Compute):
Compute configuration for training.
region_name (str):
@@ -269,7 +273,13 @@
raise ValueError("Must set `instance_type` in compute when using training recipes.")

training_recipes_cfg = _load_recipes_cfg()
-recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg)
+if isinstance(training_recipe, str):
+    recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg)
+else:
+    recipe = training_recipe
if _is_nova_recipe(recipe):
args, recipe_local_dir = _get_args_from_nova_recipe(recipe, compute, role=role)
return args, recipe_local_dir

if "trainer" not in recipe:
raise ValueError("Supplied recipe does not contain required field trainer.")
@@ -283,7 +293,7 @@
if compute.instance_count is None:
if "num_nodes" not in recipe["trainer"]:
raise ValueError(
"Must provide Compute with instance_count or" " set trainer -> num_nodes in recipe."
"Must provide Compute with instance_count or set trainer -> num_nodes in recipe."
)
compute.instance_count = recipe["trainer"]["num_nodes"]

@@ -313,7 +323,7 @@

# Save Final Recipe to source_dir
OmegaConf.save(
config=final_recipe, f=os.path.join(args["source_code"].source_dir, "recipe.yaml")
config=final_recipe, f=os.path.join(args["source_code"].source_dir, SM_RECIPE_YAML)
)

# If recipe_requirements is provided, copy it to source_dir
@@ -322,7 +332,7 @@
args["source_code"].requirements = os.path.basename(requirements)

# Update args with compute and hyperparameters
hyperparameters = {"config-path": ".", "config-name": "recipe.yaml"}
hyperparameters = {"config-path": ".", "config-name": SM_RECIPE_YAML}

# Handle eval custom lambda configuration
if recipe.get("evaluation", {}):
@@ -339,3 +349,111 @@
)

return args, recipe_train_dir
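_get_args_from_recipe now accepts either a recipe name/path or an already-loaded DictConfig, which lets from_recipe load the recipe once for the Nova check and pass it through. A hedged sketch of both call styles (paths, region, and role ARN are placeholders):

```python
from sagemaker.train.configs import Compute
from sagemaker.train.sm_recipes.utils import _get_args_from_recipe, _load_base_recipe

compute = Compute(instance_type="ml.p5.48xlarge", instance_count=1)

# 1) By name or path: loading happens inside _get_args_from_recipe.
args, tmp_dir = _get_args_from_recipe(
    training_recipe="recipes/nova_micro_sft.yaml",  # hypothetical path
    compute=compute,
    region_name="us-east-1",
    recipe_overrides=None,
    requirements=None,
    role="arn:aws:iam::111122223333:role/SageMakerExecutionRole",
)

# 2) Pre-loaded DictConfig: avoids a second load; Nova recipes are routed
#    to _get_args_from_nova_recipe.
recipe = _load_base_recipe(
    training_recipe="recipes/nova_micro_sft.yaml", recipe_overrides=None
)
args, tmp_dir = _get_args_from_recipe(
    training_recipe=recipe,
    compute=compute,
    region_name="us-east-1",
    recipe_overrides=None,
    requirements=None,
)
```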

def _is_nova_recipe(
recipe: dictconfig.DictConfig,
) -> bool:
"""Check if the recipe is a Nova recipe.

A recipe is considered a Nova recipe if it meets either of the following conditions:

1. It has a run section with:
- A model_type that includes "amazon.nova"
- A model_name_or_path field

OR

2. It has a training_config section with:
- A distillation_data field

Args:
recipe (DictConfig): The loaded recipe configuration

Returns:
bool: True if the recipe is a Nova recipe, False otherwise
"""
run_config = recipe.get("run", {})
model_type = run_config.get("model_type", "").lower()
has_nova_model = (
model_type and "amazon.nova" in model_type and "model_name_or_path" in run_config
)

# Check for distillation data
training_config = recipe.get("training_config", {})
has_distillation = training_config.get("distillation_data") is not None
return bool(has_nova_model) or bool(has_distillation)
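A hedged sketch of the two detection paths (the model identifiers and S3 URIs are illustrative):

```python
from omegaconf import OmegaConf

from sagemaker.train.sm_recipes.utils import _is_nova_recipe

# Path 1: run.model_type mentions "amazon.nova" and model_name_or_path is set.
nova = OmegaConf.create(
    {"run": {"model_type": "amazon.nova-micro-v1:0:128k", "model_name_or_path": "nova-micro/prod"}}
)
# Path 2: training_config.distillation_data is present.
distill = OmegaConf.create(
    {"training_config": {"distillation_data": "s3://my-bucket/distill.jsonl"}}
)

assert _is_nova_recipe(nova) and _is_nova_recipe(distill)
assert not _is_nova_recipe(OmegaConf.create({"run": {"model_type": "llama3"}}))
```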

def _get_args_from_nova_recipe(
recipe: dictconfig.DictConfig,
compute: Compute,
role: Optional[str] = None,
) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]:
if not compute.instance_count and not recipe.get("run", {}).get("replicas", None):
raise ValueError("Must set ``instance_type`` in compute or ``replicas`` in recipe.")
compute.instance_count = compute.instance_count or recipe.get("run", {}).get("replicas")

args: Dict[str, Any] = {"hyperparameters": {}}

run_config = recipe.get("run", {})
model_name_or_path = run_config.get("model_name_or_path")
if model_name_or_path:
if model_name_or_path.startswith("s3://"):
args["hyperparameters"]["base_model_location"] = model_name_or_path
else:
args["hyperparameters"]["base_model"] = model_name_or_path

# Handle distillation configuration
training_config = recipe.get("training_config", {})
distillation_data = training_config.get("distillation_data")
if distillation_data:
args["hyperparameters"]["distillation_data"] = distillation_data
if not role:
raise ValueError("Must provide 'role' parameter when using Nova distillation")
args["hyperparameters"]["role_arn"] = role

kms_key = training_config.get("kms_key")
if kms_key is None:
raise ValueError(
'Nova distillation job recipe requires "kms_key" field in "training_config"'
)
args["hyperparameters"]["kms_key"] = kms_key

# Handle eval custom lambda configuration
if recipe.get("evaluation", {}):
processor = recipe.get("processor", {})
lambda_arn = processor.get("lambda_arn", "")
if lambda_arn:
args["hyperparameters"]["eval_lambda_arn"] = lambda_arn

# Handle reward lambda configuration
reward_lambda_arn = run_config.get("reward_lambda_arn", "")
if reward_lambda_arn:
args["hyperparameters"]["reward_lambda_arn"] = reward_lambda_arn

_register_custom_resolvers()

# Resolve Final Recipe
final_recipe = _try_resolve_recipe(recipe)
if final_recipe is None:
final_recipe = _try_resolve_recipe(recipe, "recipes")
if final_recipe is None:
final_recipe = _try_resolve_recipe(recipe, "training")
if final_recipe is None:
raise RuntimeError("Could not resolve provided recipe.")

# Save Final Recipe to tmp dir
recipe_local_dir = tempfile.TemporaryDirectory(prefix="recipe_")
final_recipe_path = os.path.join(recipe_local_dir.name, SM_RECIPE_YAML)
OmegaConf.save(config=final_recipe, f=final_recipe_path)

args.update(
{
"compute": compute,
"training_image": None,
"source_code": None,
"distributed": None,
}
)
return args, recipe_local_dir
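To make the hyperparameter mapping concrete, a hedged sketch (all values are placeholders): model_name_or_path maps to base_model_location when it is an S3 URI and to base_model otherwise, and distillation recipes additionally require the role argument plus a kms_key in training_config.

```python
from omegaconf import OmegaConf

from sagemaker.train.configs import Compute
from sagemaker.train.sm_recipes.utils import _get_args_from_nova_recipe

compute = Compute(instance_type="ml.p5.48xlarge", instance_count=2)

recipe = OmegaConf.create(
    {
        "run": {
            "model_type": "amazon.nova-micro-v1:0:128k",         # illustrative
            "model_name_or_path": "s3://my-bucket/base-model/",  # S3 URI
        }
    }
)
args, recipe_dir = _get_args_from_nova_recipe(recipe, compute)
# args["hyperparameters"] == {"base_model_location": "s3://my-bucket/base-model/"}
# args["training_image"] is None; from_recipe() requires it for Nova runs.
# recipe_dir holds the resolved recipe.yaml later shipped as the "recipe" channel.

# Distillation recipes must also supply role and training_config.kms_key:
distill = OmegaConf.create(
    {
        "run": {"model_type": "amazon.nova-lite-v1:0", "model_name_or_path": "nova-lite/prod"},
        "training_config": {
            "distillation_data": "s3://my-bucket/distill.jsonl",
            "kms_key": "alias/my-distillation-key",  # omitting this raises ValueError
        },
    }
)
args, recipe_dir = _get_args_from_nova_recipe(
    distill, compute, role="arn:aws:iam::111122223333:role/SageMakerExecutionRole"
)
```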
2 changes: 2 additions & 0 deletions sagemaker-train/tests/unit/train/sm_recipes/test_utils.py
@@ -27,6 +27,8 @@
_configure_gpu_args,
_configure_trainium_args,
_get_trainining_recipe_gpu_model_name_and_script,
_is_nova_recipe,
_get_args_from_nova_recipe,
)
from sagemaker.train.utils import _run_clone_command_silent
from sagemaker.train.configs import Compute