diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index 48c42c9093..15fce697ea 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -1061,7 +1061,7 @@ def _prepare_train_script( execute_driver=execute_driver, ) - with open(os.path.join(tmp_dir.name, TRAIN_SCRIPT), "w") as f: + with open(os.path.join(tmp_dir.name, TRAIN_SCRIPT), "w", newline="\n") as f: f.write(train_script) @classmethod diff --git a/sagemaker-train/tests/unit/train/test_model_trainer.py b/sagemaker-train/tests/unit/train/test_model_trainer.py index 220e0fb40f..2e7a7d73c6 100644 --- a/sagemaker-train/tests/unit/train/test_model_trainer.py +++ b/sagemaker-train/tests/unit/train/test_model_trainer.py @@ -559,9 +559,10 @@ def test_train_with_distributed_config( ) assert os.path.exists(expected_train_script_path) - with open(expected_train_script_path, "r") as f: + with open(expected_train_script_path, "rb") as f: train_script_content = f.read() - assert test_case["expected_template"] in train_script_content + assert test_case["expected_template"] in train_script_content.decode("utf-8") + assert b"\r\n" not in train_script_content assert os.path.exists(expected_runner_json_path) with open(expected_runner_json_path, "r") as f: @@ -1529,4 +1530,4 @@ def test_llmft_recipe_missing_training_image_error(modules_session): ) # Clean up the temporary file - os.unlink(recipe.name) \ No newline at end of file + os.unlink(recipe.name)