diff --git a/sagemaker-train/src/sagemaker/train/constants.py b/sagemaker-train/src/sagemaker/train/constants.py
index 309265d659..68b0f6c474 100644
--- a/sagemaker-train/src/sagemaker/train/constants.py
+++ b/sagemaker-train/src/sagemaker/train/constants.py
@@ -62,3 +62,7 @@
     "mistral.mistral-large-2402-v1:0": ["us-west-2", "us-east-1", "eu-west-1"],
     "amazon.nova-pro-v1:0": ["us-east-1"]
 }
+
+SM_RECIPE = "recipe"
+SM_RECIPE_YAML = "recipe.yaml"
+SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/recipe/{SM_RECIPE_YAML}"
\ No newline at end of file
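These constants encode SageMaker's input-channel layout: in File mode, a channel named `X` is mounted at `/opt/ml/input/data/X` inside the training container, so a `recipe.yaml` shipped on the reserved `recipe` channel (registered in `model_trainer.py` below) lands exactly at `SM_RECIPE_CONTAINER_PATH`. A minimal sketch of that relationship:

```python
SM_RECIPE = "recipe"
SM_RECIPE_YAML = "recipe.yaml"

# Composing SageMaker's channel mount point with the file name reproduces the constant.
recipe_path = f"/opt/ml/input/data/{SM_RECIPE}/{SM_RECIPE_YAML}"
assert recipe_path == "/opt/ml/input/data/recipe/recipe.yaml"
```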
diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py
index 1a1fcab410..8849a53545 100644
--- a/sagemaker-train/src/sagemaker/train/model_trainer.py
+++ b/sagemaker-train/src/sagemaker/train/model_trainer.py
@@ -83,6 +83,9 @@
     SM_CODE_CONTAINER_PATH,
     SM_DRIVERS,
     SM_DRIVERS_LOCAL_PATH,
+    SM_RECIPE,
+    SM_RECIPE_YAML,
+    SM_RECIPE_CONTAINER_PATH,
     TRAIN_SCRIPT,
     DEFAULT_CONTAINER_ENTRYPOINT,
     DEFAULT_CONTAINER_ARGUMENTS,
@@ -100,7 +103,12 @@
 from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
 from sagemaker.core.telemetry.constants import Feature
 from sagemaker.train import logger
-from sagemaker.train.sm_recipes.utils import _get_args_from_recipe, _determine_device_type
+from sagemaker.train.sm_recipes.utils import (
+    _get_args_from_recipe,
+    _determine_device_type,
+    _is_nova_recipe,
+    _load_base_recipe,
+)
 from sagemaker.core.jumpstart.configs import JumpStartConfig
 from sagemaker.core.jumpstart.document import get_hub_content_and_document
@@ -249,6 +257,7 @@ class ModelTrainer(BaseModel):
     _remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None)
     _metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None)
+    _is_nova_recipe: Optional[bool] = PrivateAttr(default=None)
 
     # Private Attributes for Recipes
     _temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)
@@ -573,6 +582,22 @@ def _create_training_job_args(
 
         final_input_data_config = list(existing_channels.values()) + new_channels
 
+        if self._is_nova_recipe:
+            for input_data in final_input_data_config:
+                if input_data.channel_name == SM_RECIPE:
+                    raise ValueError(
+                        "Cannot use reserved channel name 'recipe' as an input channel name "
+                        "for Nova Recipe"
+                    )
+            recipe_file_path = os.path.join(self._temp_recipe_train_dir.name, SM_RECIPE_YAML)
+            recipe_channel = self.create_input_data_channel(
+                channel_name=SM_RECIPE,
+                data_source=recipe_file_path,
+                key_prefix=input_data_key_prefix,
+            )
+            final_input_data_config.append(recipe_channel)
+            self.hyperparameters.update({"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH})
+
         if final_input_data_config:
             final_input_data_config = self._get_input_data_config(
                 final_input_data_config, input_data_key_prefix
@@ -1039,6 +1064,7 @@ def from_recipe(
         checkpoint_config: Optional[shapes.CheckpointConfig] = None,
         training_input_mode: Optional[str] = "File",
         environment: Optional[Dict[str, str]] = None,
+        hyperparameters: Optional[Union[Dict[str, Any], str]] = None,
         tags: Optional[List[Tag]] = None,
         sagemaker_session: Optional[Session] = None,
         role: Optional[str] = None,
@@ -1136,12 +1162,19 @@ def from_recipe(
         if compute.instance_type is None:
             raise ValueError("Must set ``instance_type`` in Compute when using training recipes.")
         device_type = _determine_device_type(compute.instance_type)
-        if device_type == "cpu":
+        recipe = _load_base_recipe(
+            training_recipe=training_recipe, recipe_overrides=recipe_overrides
+        )
+        is_nova = _is_nova_recipe(recipe=recipe)
+        if device_type == "cpu" and not is_nova:
             raise ValueError(
                 "Training recipes are not supported for CPU instances. "
                 "Please provide a GPU or Trainium instance type."
             )
 
+        if training_image is None and is_nova:
+            raise ValueError("training_image must be provided when using recipe for Nova.")
+
         if training_image_config and training_image is None:
             raise ValueError("training_image must be provided when using training_image_config.")
@@ -1154,16 +1187,29 @@ def from_recipe(
         # - distributed
         # - compute
         # - hyperparameters
-        model_trainer_args, recipe_train_dir = _get_args_from_recipe(
-            training_recipe=training_recipe,
+        model_trainer_args, tmp_dir = _get_args_from_recipe(
+            training_recipe=recipe,
             recipe_overrides=recipe_overrides,
             requirements=requirements,
             compute=compute,
             region_name=sagemaker_session.boto_region_name,
+            role=role,
         )
         if training_image is not None:
             model_trainer_args["training_image"] = training_image
 
+        if hyperparameters and not is_nova:
+            logger.warning(
+                "Hyperparameters are not supported for general training recipes. "
+                "Ignoring hyperparameters input."
+            )
+        if is_nova:
+            if hyperparameters and isinstance(hyperparameters, str):
+                hyperparameters = cls._validate_and_load_hyperparameters_file(hyperparameters)
+                model_trainer_args["hyperparameters"].update(hyperparameters)
+            elif hyperparameters and isinstance(hyperparameters, dict):
+                model_trainer_args["hyperparameters"].update(hyperparameters)
+
         model_trainer = cls(
             sagemaker_session=sagemaker_session,
             role=role,
@@ -1180,7 +1226,8 @@ def from_recipe(
             **model_trainer_args,
         )
 
-        model_trainer._temp_recipe_train_dir = recipe_train_dir
+        model_trainer._is_nova_recipe = is_nova
+        model_trainer._temp_recipe_train_dir = tmp_dir
         return model_trainer
 
     @classmethod
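Taken together, these changes let a caller hand `from_recipe` a Nova recipe plus optional hyperparameter overrides. A minimal usage sketch; the recipe path, image URI, role ARN, instance type, and `max_steps` value below are illustrative placeholders, not values from this change:

```python
from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import Compute

trainer = ModelTrainer.from_recipe(
    training_recipe="my-nova-recipe.yaml",  # local path to a Nova recipe
    training_image="<ecr-image-uri>",       # required when the recipe is a Nova recipe
    role="arn:aws:iam::111122223333:role/MyRole",
    compute=Compute(instance_type="ml.p5.48xlarge", instance_count=2),
    hyperparameters={"max_steps": 100},     # a dict, or a path to a JSON file
)
trainer.train()
```

Hyperparameters passed this way are merged into the recipe-derived hyperparameters for Nova recipes only; for other recipes they are ignored with a warning, as the code above shows.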
diff --git a/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py b/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py
index f234d12d20..4e5b08c447 100644
--- a/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py
+++ b/sagemaker-train/src/sagemaker/train/sm_recipes/utils.py
@@ -19,7 +19,7 @@
 import shutil
 import tempfile
 from urllib.request import urlretrieve
-from typing import Dict, Any, Optional, Tuple
+from typing import Dict, Any, Optional, Tuple, Union
 
 import omegaconf
 from omegaconf import OmegaConf, dictconfig
@@ -30,6 +30,7 @@
 from sagemaker.train.utils import _run_clone_command_silent
 from sagemaker.train.configs import Compute, SourceCode
 from sagemaker.train.distributed import Torchrun, SMP
+from sagemaker.train.constants import SM_RECIPE_YAML
 
 
 def _try_resolve_recipe(recipe, key=None):
@@ -86,6 +87,8 @@ def _load_base_recipe(
         )
     else:
         recipe_launcher_dir = tempfile.TemporaryDirectory(prefix="launcher_")
+        if training_recipes_cfg is None:
+            training_recipes_cfg = _load_recipes_cfg()
 
         launcher_repo = os.environ.get("TRAINING_LAUNCHER_GIT", None) or training_recipes_cfg.get(
             "launcher_repo"
@@ -149,7 +152,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str):
 def _configure_gpu_args(
     training_recipes_cfg: Dict[str, Any],
     region_name: str,
-    recipe: OmegaConf,
+    recipe: dictconfig.DictConfig,
     recipe_train_dir: tempfile.TemporaryDirectory,
 ) -> Dict[str, Any]:
     """Configure arguments specific to GPU."""
@@ -234,11 +237,12 @@ def _configure_trainium_args(
 
 
 def _get_args_from_recipe(
-    training_recipe: str,
+    training_recipe: Union[str, dictconfig.DictConfig],
     compute: Compute,
     region_name: str,
     recipe_overrides: Optional[Dict[str, Any]],
     requirements: Optional[str],
+    role: Optional[str] = None,
 ) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]:
     """Get arguments for ModelTrainer from a training recipe.
 
@@ -254,8 +258,8 @@ def _get_args_from_recipe(
     ```
 
     Args:
-        training_recipe (str):
-            Name of the training recipe or path to the recipe file.
+        training_recipe (Union[str, DictConfig]):
+            Name of the training recipe, path to the recipe file, or a loaded recipe DictConfig.
         compute (Compute):
             Compute configuration for training.
         region_name (str):
@@ -269,7 +273,13 @@ def _get_args_from_recipe(
         raise ValueError("Must set `instance_type` in compute when using training recipes.")
 
     training_recipes_cfg = _load_recipes_cfg()
-    recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg)
+    if isinstance(training_recipe, str):
+        recipe = _load_base_recipe(training_recipe, recipe_overrides, training_recipes_cfg)
+    else:
+        recipe = training_recipe
+    if _is_nova_recipe(recipe):
+        args, recipe_local_dir = _get_args_from_nova_recipe(recipe, compute, role=role)
+        return args, recipe_local_dir
 
     if "trainer" not in recipe:
         raise ValueError("Supplied recipe does not contain required field trainer.")
@@ -283,7 +293,7 @@ def _get_args_from_recipe(
     if compute.instance_count is None:
         if "num_nodes" not in recipe["trainer"]:
             raise ValueError(
-                "Must provide Compute with instance_count or" " set trainer -> num_nodes in recipe."
+                "Must provide Compute with instance_count or set trainer -> num_nodes in recipe."
             )
         compute.instance_count = recipe["trainer"]["num_nodes"]
 
@@ -313,7 +323,7 @@ def _get_args_from_recipe(
 
     # Save Final Recipe to source_dir
     OmegaConf.save(
-        config=final_recipe, f=os.path.join(args["source_code"].source_dir, "recipe.yaml")
+        config=final_recipe, f=os.path.join(args["source_code"].source_dir, SM_RECIPE_YAML)
     )
 
     # If recipe_requirements is provided, copy it to source_dir
@@ -322,7 +332,7 @@ def _get_args_from_recipe(
         args["source_code"].requirements = os.path.basename(requirements)
 
     # Update args with compute and hyperparameters
-    hyperparameters = {"config-path": ".", "config-name": "recipe.yaml"}
+    hyperparameters = {"config-path": ".", "config-name": SM_RECIPE_YAML}
 
     # Handle eval custom lambda configuration
     if recipe.get("evaluation", {}):
@@ -339,3 +349,113 @@ def _get_args_from_recipe(
     )
 
     return args, recipe_train_dir
+
+
+def _is_nova_recipe(
+    recipe: dictconfig.DictConfig,
+) -> bool:
+    """Check if the recipe is a Nova recipe.
+
+    A recipe is considered a Nova recipe if it meets either of the following conditions:
+
+    1. It has a run section with:
+       - A model_type that includes "amazon.nova"
+       - A model_name_or_path field
+
+    OR
+
+    2. It has a training_config section with:
+       - A distillation_data field
+
+    Args:
+        recipe (DictConfig): The loaded recipe configuration.
+
+    Returns:
+        bool: True if the recipe is a Nova recipe, False otherwise.
+    """
+    run_config = recipe.get("run", {})
+    model_type = run_config.get("model_type", "").lower()
+    has_nova_model = (
+        model_type and "amazon.nova" in model_type and "model_name_or_path" in run_config
+    )
+
+    # Check for distillation data
+    training_config = recipe.get("training_config", {})
+    has_distillation = training_config.get("distillation_data") is not None
+    return bool(has_nova_model) or bool(has_distillation)
+
+
+def _get_args_from_nova_recipe(
+    recipe: dictconfig.DictConfig,
+    compute: Compute,
+    role: Optional[str] = None,
+) -> Tuple[Dict[str, Any], tempfile.TemporaryDirectory]:
+    if not compute.instance_count and not recipe.get("run", {}).get("replicas", None):
+        raise ValueError("Must set ``instance_count`` in compute or ``replicas`` in recipe.")
+    compute.instance_count = compute.instance_count or recipe.get("run", {}).get("replicas")
+
+    args = dict()
+    args.update({"hyperparameters": {}})
+
+    run_config = recipe.get("run", {})
+    model_name_or_path = run_config.get("model_name_or_path")
+    if model_name_or_path:
+        if model_name_or_path.startswith("s3://"):
+            args["hyperparameters"]["base_model_location"] = model_name_or_path
+        else:
+            args["hyperparameters"]["base_model"] = model_name_or_path
+
+    # Handle distillation configuration
+    training_config = recipe.get("training_config", {})
+    distillation_data = training_config.get("distillation_data")
+    if distillation_data:
+        args["hyperparameters"]["distillation_data"] = distillation_data
+        if not role:
+            raise ValueError("Must provide 'role' parameter when using Nova distillation")
+        args["hyperparameters"]["role_arn"] = role
+
+        kms_key = training_config.get("kms_key")
+        if kms_key is None:
+            raise ValueError(
+                'Nova distillation job recipe requires "kms_key" field in "training_config"'
+            )
+        args["hyperparameters"]["kms_key"] = kms_key
+
+    # Handle eval custom lambda configuration
+    if recipe.get("evaluation", {}):
+        processor = recipe.get("processor", {})
+        lambda_arn = processor.get("lambda_arn", "")
+        if lambda_arn:
+            args["hyperparameters"]["eval_lambda_arn"] = lambda_arn
+
+    # Handle reward lambda configuration
+    run_config = recipe.get("run", {})
+    reward_lambda_arn = run_config.get("reward_lambda_arn", "")
+    if reward_lambda_arn:
+        args["hyperparameters"]["reward_lambda_arn"] = reward_lambda_arn
+
+    _register_custom_resolvers()
+
+    # Resolve Final Recipe
+    final_recipe = _try_resolve_recipe(recipe)
+    if final_recipe is None:
+        final_recipe = _try_resolve_recipe(recipe, "recipes")
+    if final_recipe is None:
+        final_recipe = _try_resolve_recipe(recipe, "training")
+    if final_recipe is None:
+        raise RuntimeError("Could not resolve provided recipe.")
+
+    # Save Final Recipe to tmp dir
+    recipe_local_dir = tempfile.TemporaryDirectory(prefix="recipe_")
+    final_recipe_path = os.path.join(recipe_local_dir.name, SM_RECIPE_YAML)
+    OmegaConf.save(config=final_recipe, f=final_recipe_path)
+
+    args.update(
+        {
+            "compute": compute,
+            "training_image": None,
+            "source_code": None,
+            "distributed": None,
+        }
+    )
+    return args, recipe_local_dir
\ No newline at end of file
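The test_utils.py hunk below only imports the new helpers; as a sketch of the behavior they would be exercised for, `_get_args_from_nova_recipe` maps recipe fields to hyperparameters roughly like this (assumed expectations based on the implementation above, not tests copied from the PR):

```python
from omegaconf import OmegaConf

from sagemaker.train.configs import Compute
from sagemaker.train.sm_recipes.utils import _get_args_from_nova_recipe

recipe = OmegaConf.create(
    {"run": {"model_type": "amazon.nova", "model_name_or_path": "dummy-model", "replicas": 2}}
)
args, tmp_dir = _get_args_from_nova_recipe(recipe, Compute(instance_type="ml.p5.48xlarge"))

# A non-S3 model_name_or_path becomes "base_model"; an s3:// URI would become
# "base_model_location" instead. "replicas" backfills a missing instance_count.
assert args["hyperparameters"] == {"base_model": "dummy-model"}
assert args["compute"].instance_count == 2
tmp_dir.cleanup()
```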
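To make the detection rule concrete, here is a small sketch of `_is_nova_recipe` against the two qualifying shapes; the recipe contents are illustrative, not taken from this PR:

```python
from omegaconf import OmegaConf

from sagemaker.train.sm_recipes.utils import _is_nova_recipe

# Condition 1: run.model_type contains "amazon.nova" and model_name_or_path is present.
nova_ft = OmegaConf.create(
    {"run": {"model_type": "amazon.nova", "model_name_or_path": "s3://bucket/base"}}
)

# Condition 2: training_config.distillation_data is present.
nova_distill = OmegaConf.create({"training_config": {"distillation_data": "s3://bucket/data"}})

# A conventional recipe matches neither condition.
generic = OmegaConf.create({"run": {"model_type": "llama_v3"}, "trainer": {"num_nodes": 2}})

assert _is_nova_recipe(nova_ft)
assert _is_nova_recipe(nova_distill)
assert not _is_nova_recipe(generic)
```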
diff --git a/sagemaker-train/tests/unit/train/sm_recipes/test_utils.py b/sagemaker-train/tests/unit/train/sm_recipes/test_utils.py
index b85c138b94..d5b140d20e 100644
--- a/sagemaker-train/tests/unit/train/sm_recipes/test_utils.py
+++ b/sagemaker-train/tests/unit/train/sm_recipes/test_utils.py
@@ -27,6 +27,8 @@
     _configure_gpu_args,
     _configure_trainium_args,
     _get_trainining_recipe_gpu_model_name_and_script,
+    _is_nova_recipe,
+    _get_args_from_nova_recipe,
 )
 from sagemaker.train.utils import _run_clone_command_silent
 from sagemaker.train.configs import Compute
diff --git a/sagemaker-train/tests/unit/train/test_model_trainer.py b/sagemaker-train/tests/unit/train/test_model_trainer.py
index 73b9fa48c2..945ad4bf69 100644
--- a/sagemaker-train/tests/unit/train/test_model_trainer.py
+++ b/sagemaker-train/tests/unit/train/test_model_trainer.py
@@ -18,9 +18,11 @@
 import json
 import os
 import yaml
+from omegaconf import OmegaConf
 import pytest
 from pydantic import ValidationError
 from unittest.mock import patch, MagicMock, ANY, mock_open
+from tempfile import NamedTemporaryFile
 
 from sagemaker.core.resources import TrainingJob
 from sagemaker.core.shapes import (
@@ -43,6 +45,7 @@
     DISTRIBUTED_JSON,
     SOURCE_CODE_JSON,
     TRAIN_SCRIPT,
+    SM_RECIPE_CONTAINER_PATH,
 )
 from sagemaker.train.configs import (
     Compute,
@@ -67,7 +70,11 @@
     MetricDefinition,
 )
 from sagemaker.train.distributed import Torchrun, SMP, MPI
-from sagemaker.train.sm_recipes.utils import _load_recipes_cfg
+from sagemaker.train.sm_recipes.utils import (
+    _load_recipes_cfg,
+    _is_nova_recipe,
+    _get_args_from_nova_recipe,
+)
 from sagemaker.train.templates import EXEUCTE_DISTRIBUTED_DRIVER
 from tests.unit import DATA_DIR
 
@@ -1347,3 +1350,93 @@ def test_metric_definitions(mock_training_job, modules_session):
         mock_training_job.create.call_args.kwargs["algorithm_specification"].metric_definitions
         == metric_definitions
     )
+
+
+@patch("sagemaker.train.model_trainer._get_unique_name")
+@patch("sagemaker.core.resources.TrainingJob")
+def test_nova_recipe(mock_training_job, mock_unique_name, modules_session):
+    def mock_upload_data(path, bucket, key_prefix):
+        if os.path.isfile(path):
+            file_name = os.path.basename(path)
+            return f"s3://{bucket}/{key_prefix}/{file_name}"
+        else:
+            return f"s3://{bucket}/{key_prefix}"
+
+    unique_name = "base-job-0123456789"
+    base_name = "base-job"
+
+    modules_session.upload_data.side_effect = mock_upload_data
+    mock_unique_name.return_value = unique_name
+
+    recipe_data = {
+        "run": {
+            "name": "dummy-model",
+            "model_type": "amazon.nova",
+            "model_name_or_path": "dummy-model",
+        }
+    }
+    with NamedTemporaryFile(suffix=".yaml", delete=False) as recipe:
+        with open(recipe.name, "w") as file:
+            yaml.dump(recipe_data, file)
+
+    # Patch TrainingJob.create to avoid Pydantic validation on session
+    with patch.object(TrainingJob, "create", return_value=mock_training_job) as mock_create:
+        trainer = ModelTrainer.from_recipe(
+            training_recipe=recipe.name,
+            role=DEFAULT_ROLE,
+            sagemaker_session=modules_session,
+            compute=DEFAULT_COMPUTE_CONFIG,
+            training_image=DEFAULT_IMAGE,
+            base_job_name=base_name,
+        )
+
+        assert trainer._is_nova_recipe
+
+        trainer.train()
+        mock_create.assert_called_once()
+        assert mock_create.call_args.kwargs["hyper_parameters"] == {
+            "base_model": "dummy-model",
+            "sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH,
+        }
+
+        default_base_path = f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{base_name}"
+        assert mock_create.call_args.kwargs["input_data_config"] == [
+            Channel(
+                channel_name="recipe",
+                data_source=DataSource(
+                    s3_data_source=S3DataSource(
+                        s3_data_type="S3Prefix",
+                        s3_uri=f"{default_base_path}/{unique_name}/input/recipe/recipe.yaml",
+                        s3_data_distribution_type="FullyReplicated",
+                    )
+                ),
+                input_mode="File",
+            )
+        ]
+
+
+def test_nova_recipe_with_distillation(modules_session):
+    recipe_data = {"training_config": {"distillation_data": "true", "kms_key": "alias/my-kms-key"}}
+
+    with NamedTemporaryFile(suffix=".yaml", delete=False) as recipe:
+        with open(recipe.name, "w") as file:
+            yaml.dump(recipe_data, file)
+
+    # Create ModelTrainer from recipe
+    trainer = ModelTrainer.from_recipe(
+        training_recipe=recipe.name,
+        role=DEFAULT_ROLE,
+        sagemaker_session=modules_session,
+        compute=DEFAULT_COMPUTE_CONFIG,
+        training_image=DEFAULT_IMAGE,
+    )
+
+    # Verify that the hyperparameters were set correctly
+    assert trainer.hyperparameters == {
+        "distillation_data": "true",
+        "role_arn": DEFAULT_ROLE,
+        "kms_key": "alias/my-kms-key",
+    }
+
+    # Clean up the temporary file
+    os.unlink(recipe.name)
\ No newline at end of file
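As the distillation test implies, `_get_args_from_nova_recipe` treats `role` and `training_config.kms_key` as mandatory once `distillation_data` is present. A short sketch of the two failure modes (the recipe values are illustrative):

```python
import pytest
from omegaconf import OmegaConf

from sagemaker.train.configs import Compute
from sagemaker.train.sm_recipes.utils import _get_args_from_nova_recipe

compute = Compute(instance_type="ml.p5.48xlarge", instance_count=1)

# Missing role: rejected before any recipe resolution happens.
no_role = OmegaConf.create(
    {"run": {"replicas": 1}, "training_config": {"distillation_data": "s3://bucket/data"}}
)
with pytest.raises(ValueError, match="role"):
    _get_args_from_nova_recipe(no_role, compute)

# Missing kms_key: rejected even when a role is supplied.
with pytest.raises(ValueError, match="kms_key"):
    _get_args_from_nova_recipe(no_role, compute, role="arn:aws:iam::111122223333:role/MyRole")
```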