From b0741bbe506ef1fea7cf8904843121395b0fbd4e Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Tue, 6 May 2025 15:46:36 -0700 Subject: [PATCH 1/3] fix: Map llama models to correct script --- src/sagemaker/modules/train/sm_recipes/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/modules/train/sm_recipes/utils.py b/src/sagemaker/modules/train/sm_recipes/utils.py index 549645cbe2..6b39add6cd 100644 --- a/src/sagemaker/modules/train/sm_recipes/utils.py +++ b/src/sagemaker/modules/train/sm_recipes/utils.py @@ -129,7 +129,7 @@ def _get_trainining_recipe_gpu_model_name_and_script(model_type: str): """Get the model base name and script for the training recipe.""" model_type_to_script = { - "llama_v3": ("llama", "llama_pretrain.py"), + "llama": ("llama", "llama_pretrain.py"), "mistral": ("mistral", "mistral_pretrain.py"), "mixtral": ("mixtral", "mixtral_pretrain.py"), "deepseek": ("deepseek", "deepseek_pretrain.py"), From fdc00c4cf8963f09427111a307d3a3fbcf2d2ca5 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Tue, 6 May 2025 16:17:28 -0700 Subject: [PATCH 2/3] add test --- .../modules/train/sm_recipes/test_utils.py | 69 ++++++++++--------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py index f5f7ceb083..cdc5a654ab 100644 --- a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py +++ b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py @@ -180,36 +180,41 @@ def test_get_args_from_recipe_compute( assert mock_trainium_args.call_count == 0 assert args is None - @pytest.mark.parametrize( - "test_case", - [ - { - "model_type": "llama_v3", - "script": "llama_pretrain.py", - "model_base_name": "llama_v3", - }, - { - "model_type": "mistral", - "script": "mistral_pretrain.py", - "model_base_name": "mistral", - }, - { - "model_type": "deepseek_llamav3", - "script": "deepseek_pretrain.py", - "model_base_name": "deepseek", - }, - { - "model_type": "deepseek_qwenv2", - "script": "deepseek_pretrain.py", - "model_base_name": "deepseek", - }, - ], +@pytest.mark.parametrize( + "test_case", + [ + { + "model_type": "llama_v4", + "script": "llama_pretrain.py", + "model_base_name": "llama" + }, + { + "model_type": "llama_v3", + "script": "llama_pretrain.py", + "model_base_name": "llama", + }, + { + "model_type": "mistral", + "script": "mistral_pretrain.py", + "model_base_name": "mistral", + }, + { + "model_type": "deepseek_llamav3", + "script": "deepseek_pretrain.py", + "model_base_name": "deepseek", + }, + { + "model_type": "deepseek_qwenv2", + "script": "deepseek_pretrain.py", + "model_base_name": "deepseek", + }, + ], +) +def test_get_trainining_recipe_gpu_model_name_and_script(test_case): + model_type = test_case["model_type"] + script = test_case["script"] + model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script( + model_type ) - def test_get_trainining_recipe_gpu_model_name_and_script(test_case): - model_type = test_case["model_type"] - script = test_case["script"] - model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script( - model_type, script - ) - assert model_base_name == test_case["model_base_name"] - assert script == test_case["script"] + assert model_base_name == test_case["model_base_name"] + assert script == test_case["script"] From 07b589ed978ff55e1c48c11342a3b008e403a3d8 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Thu, 8 May 2025 14:37:50 -0700 Subject: [PATCH 3/3] fix formatting --- .../sagemaker/modules/train/sm_recipes/test_utils.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py index cdc5a654ab..585a4d2745 100644 --- a/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py +++ b/tests/unit/sagemaker/modules/train/sm_recipes/test_utils.py @@ -180,14 +180,11 @@ def test_get_args_from_recipe_compute( assert mock_trainium_args.call_count == 0 assert args is None + @pytest.mark.parametrize( "test_case", [ - { - "model_type": "llama_v4", - "script": "llama_pretrain.py", - "model_base_name": "llama" - }, + {"model_type": "llama_v4", "script": "llama_pretrain.py", "model_base_name": "llama"}, { "model_type": "llama_v3", "script": "llama_pretrain.py", @@ -213,8 +210,6 @@ def test_get_args_from_recipe_compute( def test_get_trainining_recipe_gpu_model_name_and_script(test_case): model_type = test_case["model_type"] script = test_case["script"] - model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script( - model_type - ) + model_base_name, script = _get_trainining_recipe_gpu_model_name_and_script(model_type) assert model_base_name == test_case["model_base_name"] assert script == test_case["script"]