diff --git a/sagemaker-core/src/sagemaker/core/remote_function/job.py b/sagemaker-core/src/sagemaker/core/remote_function/job.py index 435062db57..6e727d4b9c 100644 --- a/sagemaker-core/src/sagemaker/core/remote_function/job.py +++ b/sagemaker-core/src/sagemaker/core/remote_function/job.py @@ -175,12 +175,12 @@ fi printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n" - printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function \\n" - $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function "$@" + printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function \\n" + $conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function "$@" else printf "INFO: No conda env provided. Invoking remote function\\n" - printf "INFO: python -m sagemaker.train.remote_function.invoke_function \\n" - python -m sagemaker.train.remote_function.invoke_function "$@" + printf "INFO: python -m sagemaker.core.remote_function.invoke_function \\n" + python -m sagemaker.core.remote_function.invoke_function "$@" fi """ @@ -234,14 +234,14 @@ -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ - python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n" + python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n" $conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ -mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \ -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ - python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@" + python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@" python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 else @@ -259,7 +259,7 @@ -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ - python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n" + python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n" mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \ --allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \ @@ -267,7 +267,7 @@ -mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \ -x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \ $SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \ - python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@" + python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@" python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1 else @@ -320,18 +320,18 @@ printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n" printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ - -m sagemaker.train.remote_function.invoke_function \\n" + -m sagemaker.core.remote_function.invoke_function \\n" $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \ --master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \ - -m sagemaker.train.remote_function.invoke_function "$@" + -m sagemaker.core.remote_function.invoke_function "$@" else printf "INFO: No conda env provided. Invoking remote function with torchrun\\n" printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ - --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function \\n" + --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function \\n" torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \ - --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function "$@" + --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function "$@" fi """ @@ -1259,7 +1259,215 @@ def _prepare_and_upload_runtime_scripts( return upload_path -def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): +def _decrement_version(version_str: str) -> str: + """Decrement a version string by one minor or patch version. + + Rules: + - If patch version is 0 (e.g., 3.2.0), decrement minor: 3.2.0 -> 3.1.0 + - If patch version is not 0 (e.g., 3.1.2), decrement patch: 3.1.2 -> 3.1.1 + + Args: + version_str: Version string (e.g., "3.2.0") + + Returns: + Decremented version string + """ + from packaging import version as pkg_version + + try: + parsed = pkg_version.parse(version_str) + major = parsed.major + minor = parsed.minor + patch = parsed.micro + + if patch == 0: + # Decrement minor version + minor = max(0, minor - 1) + else: + # Decrement patch version + patch = max(0, patch - 1) + + return f"{major}.{minor}.{patch}" + except Exception: + return version_str + + +def _resolve_version_from_specifier(specifier_str: str) -> str: + """Resolve the version to check based on upper bounds. + + Upper bounds take priority. If upper bound is <4.0.0, it's safe (V3 only). + If no upper bound exists, it's safe (unbounded). + If the decremented upper bound is less than a lower bound, use the lower bound. + + Args: + specifier_str: Version specifier string (e.g., ">=3.2.0", "<3.2.0", "==3.1.0") + + Returns: + The resolved version string to check, or None if safe + """ + import re + from packaging import version as pkg_version + + # Handle exact version pinning (==) + match = re.search(r'==\s*([\d.]+)', specifier_str) + if match: + return match.group(1) + + # Extract lower bounds for comparison + lower_bounds = [] + for match in re.finditer(r'>=\s*([\d.]+)', specifier_str): + lower_bounds.append(match.group(1)) + + # Handle upper bounds - find the most restrictive one + upper_bounds = [] + + # Find all <= bounds + for match in re.finditer(r'<=\s*([\d.]+)', specifier_str): + upper_bounds.append(('<=', match.group(1))) + + # Find all < bounds + for match in re.finditer(r'<\s*([\d.]+)', specifier_str): + upper_bounds.append(('<', match.group(1))) + + if upper_bounds: + # Sort by version to find the most restrictive (lowest) upper bound + upper_bounds.sort(key=lambda x: pkg_version.parse(x[1])) + operator, version = upper_bounds[0] + + # Special case: if upper bound is <4.0.0, it's safe (V3 only) + try: + parsed_upper = pkg_version.parse(version) + if operator == '<' and parsed_upper.major == 4 and parsed_upper.minor == 0 and parsed_upper.micro == 0: + # <4.0.0 means V3 only, which is safe + return None + except Exception: + pass + + resolved_version = version + if operator == '<': + resolved_version = _decrement_version(version) + + # If we have a lower bound and the resolved version is less than it, use the lower bound + if lower_bounds: + try: + resolved_parsed = pkg_version.parse(resolved_version) + for lower_bound_str in lower_bounds: + lower_parsed = pkg_version.parse(lower_bound_str) + if resolved_parsed < lower_parsed: + resolved_version = lower_bound_str + except Exception: + pass + + return resolved_version + + # For lower bounds only (>=, >), we don't check + return None + + +def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None: + """Check if the sagemaker version requirement uses incompatible hashing. + + Raises ValueError if the requirement would install a version that uses HMAC hashing + (which is incompatible with the current SHA256-based integrity checks). + + Args: + sagemaker_requirement: The sagemaker requirement string (e.g., "sagemaker>=3.2.0") + + Raises: + ValueError: If the requirement would install a version using HMAC hashing + """ + import re + from packaging import version as pkg_version + + match = re.search(r'sagemaker\s*(.+)$', sagemaker_requirement.strip(), re.IGNORECASE) + if not match: + return + + specifier_str = match.group(1).strip() + + # Resolve the version that would be installed + resolved_version_str = _resolve_version_from_specifier(specifier_str) + if not resolved_version_str: + # No upper bound or exact version, so we can't determine if it's bad + return + + try: + resolved_version = pkg_version.parse(resolved_version_str) + except Exception: + return + + # Define HMAC thresholds for each major version + v2_hmac_threshold = pkg_version.parse("2.256.0") + v3_hmac_threshold = pkg_version.parse("3.2.0") + + # Check if the resolved version uses HMAC hashing + uses_hmac = False + if resolved_version.major == 2 and resolved_version < v2_hmac_threshold: + uses_hmac = True + elif resolved_version.major == 3 and resolved_version < v3_hmac_threshold: + uses_hmac = True + + if uses_hmac: + raise ValueError( + f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) " + f"could install a version using HMAC-based integrity checks which are incompatible " + f"with the current SHA256-based integrity checks. Please update to " + f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)." + ) + + +def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str: + """Ensure sagemaker>=3.2.0 is in the dependencies. + + This function ensures that the remote environment has a compatible version of sagemaker + that includes the fix for the HMAC key security issue. Versions < 3.2.0 use HMAC-based + integrity checks which require the REMOTE_FUNCTION_SECRET_KEY environment variable. + Versions >= 3.2.0 use SHA256-based integrity checks which are secure and don't require + the secret key. + + If no dependencies are provided, creates a temporary requirements.txt with sagemaker. + If dependencies are provided, appends sagemaker if not already present. + + Args: + local_dependencies_path: Path to user's dependencies file or None + + Returns: + Path to the dependencies file (created or modified) + + Raises: + ValueError: If user has pinned sagemaker to a version using HMAC hashing + """ + import tempfile + + SAGEMAKER_MIN_VERSION = "sagemaker>=3.2.0,<4.0.0" + + if local_dependencies_path is None: + fd, req_file = tempfile.mkstemp(suffix=".txt", prefix="sagemaker_requirements_") + os.close(fd) + + with open(req_file, "w") as f: + f.write(f"{SAGEMAKER_MIN_VERSION}\n") + logger.info("Created temporary requirements.txt at %s with %s", req_file, SAGEMAKER_MIN_VERSION) + return req_file + + if local_dependencies_path.endswith(".txt"): + with open(local_dependencies_path, "r") as f: + content = f.read() + + if "sagemaker" in content.lower(): + for line in content.split('\n'): + if 'sagemaker' in line.lower(): + _check_sagemaker_version_compatibility(line.strip()) + break + else: + with open(local_dependencies_path, "a") as f: + f.write(f"\n{SAGEMAKER_MIN_VERSION}\n") + logger.info("Appended %s to requirements.txt", SAGEMAKER_MIN_VERSION) + + return local_dependencies_path + + +def _generate_input_data_config(job_settings, s3_base_uri): """Generates input data config""" from sagemaker.core.workflow.utilities import load_step_compilation_context @@ -1288,6 +1496,11 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str): local_dependencies_path = RuntimeEnvironmentManager().snapshot(job_settings.dependencies) + # Ensure sagemaker dependency is included to prevent version mismatch issues + # Resolves issue where computing hash for integrity check changed in 3.2.0 + local_dependencies_path = _ensure_sagemaker_dependency(local_dependencies_path) + job_settings.dependencies = local_dependencies_path + if step_compilation_context: with _tmpdir() as tmp_dir: script_and_dependencies_s3uri = _prepare_dependencies_and_pre_execution_scripts( diff --git a/sagemaker-core/tests/integ/remote_function/test_sagemaker_dependency_injection.py b/sagemaker-core/tests/integ/remote_function/test_sagemaker_dependency_injection.py new file mode 100644 index 0000000000..b3d38c32a4 --- /dev/null +++ b/sagemaker-core/tests/integ/remote_function/test_sagemaker_dependency_injection.py @@ -0,0 +1,137 @@ +"""Integration tests for sagemaker dependency injection in remote functions. + +These tests verify that the sagemaker>=3.2.0 dependency is properly injected +into remote function jobs, preventing version mismatch issues. +""" + +import os +import sys +import tempfile +import pytest + +# Skip decorator for AWS configuration +# skip_if_no_aws_region = pytest.mark.skipif( +# not os.environ.get('AWS_DEFAULT_REGION'), +# reason="AWS credentials not configured" +# ) + +# Add src to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../src')) + +from sagemaker.core.remote_function import remote + + +class TestRemoteFunctionDependencyInjection: + """Integration tests for dependency injection in remote functions.""" + + @pytest.mark.integ + # @skip_if_no_aws_region + def test_remote_function_without_dependencies(self): + """Test remote function execution without explicit dependencies. + + This test verifies that when no dependencies are provided, the remote + function still executes successfully because sagemaker>=3.2.0 is + automatically injected. + """ + @remote( + instance_type="ml.m5.large", + # No dependencies specified - sagemaker should be injected automatically + ) + def simple_add(x, y): + """Simple function that adds two numbers.""" + return x + y + + # Execute the function + result = simple_add(5, 3) + + # Verify result + assert result == 8, f"Expected 8, got {result}" + print("✓ Remote function without dependencies executed successfully") + + @pytest.mark.integ + # @skip_if_no_aws_region + def test_remote_function_with_user_dependencies_no_sagemaker(self): + """Test remote function with user dependencies but no sagemaker. + + This test verifies that when user provides dependencies without sagemaker, + sagemaker>=3.2.0 is automatically appended. + """ + # Create a temporary requirements.txt without sagemaker + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + f.write("numpy>=1.20.0\npandas>=1.3.0\n") + req_file = f.name + + try: + @remote( + instance_type="ml.m5.large", + dependencies=req_file, + ) + def compute_with_numpy(x): + """Function that uses numpy.""" + import numpy as np + return np.array([x, x*2, x*3]).sum() + + # Execute the function + result = compute_with_numpy(5) + + # Verify result (5 + 10 + 15 = 30) + assert result == 30, f"Expected 30, got {result}" + print("✓ Remote function with user dependencies executed successfully") + finally: + os.remove(req_file) + + +class TestRemoteFunctionVersionCompatibility: + """Tests for version compatibility between local and remote environments.""" + + @pytest.mark.integ + # @skip_if_no_aws_region + def test_deserialization_with_injected_sagemaker(self): + """Test that deserialization works with injected sagemaker dependency. + + This test verifies that the remote environment can properly deserialize + functions when sagemaker>=3.2.0 is available. + """ + @remote( + instance_type="ml.m5.large", + ) + def complex_computation(data): + """Function that performs complex computation.""" + result = sum(data) * len(data) + return result + + # Execute with various data types + test_data = [1, 2, 3, 4, 5] + result = complex_computation(test_data) + + # Verify result (sum=15, len=5, 15*5=75) + assert result == 75, f"Expected 75, got {result}" + print("✓ Deserialization with injected sagemaker works correctly") + + @pytest.mark.integ + # @skip_if_no_aws_region + def test_multiple_remote_functions_with_dependencies(self): + """Test multiple remote functions with different dependency configurations. + + This test verifies that the dependency injection works correctly + when multiple remote functions are defined and executed. + """ + @remote(instance_type="ml.m5.large") + def func1(x): + return x + 1 + + @remote(instance_type="ml.m5.large") + def func2(x): + return x * 2 + + # Execute both functions + result1 = func1(5) + result2 = func2(5) + + assert result1 == 6, f"func1: Expected 6, got {result1}" + assert result2 == 10, f"func2: Expected 10, got {result2}" + print("✓ Multiple remote functions with dependencies executed successfully") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-m", "integ"]) diff --git a/sagemaker-core/tests/unit/remote_function/test_ensure_sagemaker_dependency.py b/sagemaker-core/tests/unit/remote_function/test_ensure_sagemaker_dependency.py new file mode 100644 index 0000000000..6d89bd6e5c --- /dev/null +++ b/sagemaker-core/tests/unit/remote_function/test_ensure_sagemaker_dependency.py @@ -0,0 +1,318 @@ +"""Unit tests for _ensure_sagemaker_dependency function. + +Tests the logic that ensures sagemaker>=3.2.0 is included in remote function dependencies +to prevent version mismatch issues with HMAC key integrity checks. +""" + +import os +import tempfile +import unittest +from unittest.mock import patch, MagicMock + +# Add src to path +import sys +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../src')) + +from sagemaker.core.remote_function.job import _ensure_sagemaker_dependency, _check_sagemaker_version_compatibility + + +class TestEnsureSagemakerDependency(unittest.TestCase): + """Test cases for _ensure_sagemaker_dependency function.""" + + def test_no_dependencies_creates_temp_requirements_file(self): + """Test that a temp requirements.txt is created when no dependencies provided.""" + result = _ensure_sagemaker_dependency(None) + + # Verify file was created + self.assertTrue(os.path.exists(result), f"Requirements file not created at {result}") + + # Verify it's in temp directory + self.assertIn(tempfile.gettempdir(), result) + + # Verify content + with open(result, "r") as f: + content = f.read() + self.assertIn("sagemaker>=3.2.0,<4.0.0", content) + + # Cleanup + os.remove(result) + + def test_no_dependencies_file_has_correct_format(self): + """Test that created requirements.txt has correct format.""" + result = _ensure_sagemaker_dependency(None) + + with open(result, "r") as f: + lines = f.readlines() + + # Should have exactly one line with sagemaker dependency + self.assertEqual(len(lines), 1) + self.assertEqual(lines[0].strip(), "sagemaker>=3.2.0,<4.0.0") + + # Cleanup + os.remove(result) + + def test_appends_sagemaker_to_existing_requirements(self): + """Test that sagemaker is appended to existing requirements.txt.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + f.write("numpy>=1.20.0\npandas>=1.3.0\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + # Should return the same file + self.assertEqual(result, temp_file) + + # Verify content + with open(result, "r") as f: + content = f.read() + + self.assertIn("numpy>=1.20.0", content) + self.assertIn("pandas>=1.3.0", content) + self.assertIn("sagemaker>=3.2.0,<4.0.0", content) + finally: + os.remove(temp_file) + + def test_does_not_duplicate_sagemaker_if_already_present(self): + """Test that sagemaker is not duplicated if already in requirements.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + f.write("numpy>=1.20.0\nsagemaker>=3.2.0,<4.0.0\npandas>=1.3.0\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + with open(result, "r") as f: + content = f.read() + + # Count occurrences of sagemaker + sagemaker_count = content.lower().count("sagemaker") + self.assertEqual(sagemaker_count, 1, "sagemaker should appear exactly once") + + # Verify user's version is preserved + self.assertIn("sagemaker>=3.2.0,<4.0.0", content) + finally: + os.remove(temp_file) + + def test_preserves_user_dependencies(self): + """Test that user's existing dependencies are preserved.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + f.write("torch>=1.9.0\ntorchvision>=0.10.0\nscikit-learn>=0.24.0\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + with open(result, "r") as f: + content = f.read() + + # All user dependencies should be present + self.assertIn("torch>=1.9.0", content) + self.assertIn("torchvision>=0.10.0", content) + self.assertIn("scikit-learn>=0.24.0", content) + self.assertIn("sagemaker>=3.2.0,<4.0.0", content) + finally: + os.remove(temp_file) + + def test_handles_yml_files_gracefully(self): + """Test that yml files are returned unchanged.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f: + f.write("name: test-env\nchannels:\n - conda-forge\ndependencies:\n - numpy\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + # Should return the same file + self.assertEqual(result, temp_file) + + # Content should be unchanged (yml files are not modified) + with open(result, "r") as f: + content = f.read() + + self.assertNotIn("sagemaker", content.lower()) + finally: + os.remove(temp_file) + + def test_handles_yaml_files_gracefully(self): + """Test that yaml files are returned unchanged.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as f: + f.write("name: test-env\nchannels:\n - conda-forge\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + # Should return the same file + self.assertEqual(result, temp_file) + + # Content should be unchanged + with open(result, "r") as f: + content = f.read() + + self.assertNotIn("sagemaker", content.lower()) + finally: + os.remove(temp_file) + + def test_case_insensitive_sagemaker_detection(self): + """Test that sagemaker detection is case-insensitive.""" + with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f: + f.write("numpy>=1.20.0\nSAGEMAKER>=3.2.0,<4.0.0\n") + temp_file = f.name + + try: + result = _ensure_sagemaker_dependency(temp_file) + + with open(result, "r") as f: + content = f.read() + + # Should not duplicate even with different case + sagemaker_count = content.lower().count("sagemaker") + self.assertEqual(sagemaker_count, 1) + finally: + os.remove(temp_file) + + def test_temp_file_location(self): + """Test that temp file is created in system temp directory.""" + result = _ensure_sagemaker_dependency(None) + + # Should be in system temp directory + temp_dir = tempfile.gettempdir() + self.assertTrue(result.startswith(temp_dir)) + + # Should have correct prefix + self.assertIn("sagemaker_requirements_", result) + + # Cleanup + os.remove(result) + + def test_version_constraint_format(self): + """Test that version constraint has correct format.""" + result = _ensure_sagemaker_dependency(None) + + with open(result, "r") as f: + content = f.read().strip() + + # Should have both lower and upper bounds + self.assertIn(">=3.2.0", content) + self.assertIn("<4.0.0", content) + + # Cleanup + os.remove(result) + + +class TestCheckSagemakerVersionCompatibility(unittest.TestCase): + """Test cases for _check_sagemaker_version_compatibility function.""" + + # ===== GOOD CASES (should NOT raise ValueError) ===== + + def test_v3_good_exact_version_32(self): + """Test V3 exact version 3.2.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker==3.2.0") + + def test_v3_good_greater_equal_32(self): + """Test V3 greater or equal 3.2.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker>=3.2.0") + + def test_v3_good_range_32_to_40(self): + """Test V3 range 3.2.0 to 4.0.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker>=3.2.0,<4.0.0") + + def test_v2_good_exact_version_256(self): + """Test V2 exact version 2.256.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker==2.256.0") + + def test_v2_good_range_256_to_300(self): + """Test V2 range 2.256.0 to 2.300.0 (good - SHA256).""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker>=2.256.0,<2.300.0") + + def test_unparseable_requirement_no_error(self): + """Test that unparseable requirements don't raise (let pip handle it).""" + # Should not raise - let pip handle invalid syntax + _check_sagemaker_version_compatibility("sagemaker") + _check_sagemaker_version_compatibility("invalid-requirement") + + # ===== BAD CASES (should raise ValueError) ===== + + def test_v3_bad_exact_version_31(self): + """Test V3 exact version 3.1.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker==3.1.0") + + def test_v3_bad_exact_version_300(self): + """Test V3 exact version 3.0.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker==3.0.0") + + def test_v3_bad_less_than_32(self): + """Test V3 less than 3.2.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker<3.2.0") + + def test_v3_bad_less_equal_31(self): + """Test V3 less or equal 3.1.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker<=3.1.0") + + def test_v3_bad_range_300_to_31(self): + """Test V3 range 3.0.0 to 3.1.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker>=3.0.0,<3.2.0") + + def test_v2_bad_exact_version_255(self): + """Test V2 exact version 2.255.0 (bad - HMAC).""" + with self.assertRaises(ValueError) as context: + _check_sagemaker_version_compatibility("sagemaker==2.255.0") + self.assertIn("HMAC-based integrity checks", str(context.exception)) + + def test_v2_bad_exact_version_200(self): + """Test V2 exact version 2.200.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker==2.200.0") + + def test_v2_bad_less_than_256(self): + """Test V2 less than 2.256.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker<2.256.0") + + def test_v2_bad_less_equal_255(self): + """Test V2 less or equal 2.255.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker<=2.255.0") + + def test_v2_bad_range_200_to_255(self): + """Test V2 range 2.200.0 to 2.255.0 (bad - HMAC).""" + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker>=2.200.0,<2.256.0") + + # ===== EDGE CASES ===== + + def test_multiple_version_specifiers_good(self): + """Test multiple version specifiers that are good.""" + # Should not raise + _check_sagemaker_version_compatibility("sagemaker>=3.2.0,<4.0.0") + + def test_multiple_version_specifiers_good_with_lower_bound(self): + """Test multiple version specifiers that are good (upper bound resolves to good version).""" + # Should not raise - <3.300.0 decrements to 3.299.0 which is >= 3.2.0 + _check_sagemaker_version_compatibility("sagemaker>=3.0.0,<3.300.0") + + def test_multiple_version_specifiers_bad(self): + """Test multiple version specifiers that are bad.""" + # Should raise - <3.2.0 decrements to 3.1.0 which is < 3.2.0 (HMAC) + with self.assertRaises(ValueError): + _check_sagemaker_version_compatibility("sagemaker>=3.0.0,<3.2.0") + + def test_v3_good_greater_than_31(self): + """Test V3 greater than 3.1.0 (not checked - treat as lower bound only).""" + # Should not raise - > is treated as a lower bound, we don't check those + _check_sagemaker_version_compatibility("sagemaker>3.1.0") + + +if __name__ == "__main__": + unittest.main()