diff --git a/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py b/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py index 10f338c4b5..80f9c50e4b 100644 --- a/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py +++ b/tests/integ/sagemaker/serve/test_base_model_builder_deploy.py @@ -12,38 +12,72 @@ # language governing permissions and limitations under the License. from __future__ import absolute_import -import pytest - -from sagemaker import get_execution_role -from sklearn.datasets import load_iris -from sklearn.model_selection import train_test_split - import os +import uuid +from typing import Generator +import numpy as np +import pandas as pd +import pytest +from sagemaker_core.main.resources import TrainingJob from sagemaker_core.main.shapes import ( AlgorithmSpecification, Channel, DataSource, - S3DataSource, OutputDataConfig, ResourceConfig, + S3DataSource, StoppingCondition, ) -import uuid -from sagemaker.serve.builder.model_builder import ModelBuilder -import pandas as pd -import numpy as np -from sagemaker.serve import InferenceSpec, SchemaBuilder -from sagemaker_core.main.resources import TrainingJob +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split from xgboost import XGBClassifier -from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig - -from sagemaker.s3_utils import s3_path_join +from sagemaker import get_execution_role from sagemaker.async_inference import AsyncInferenceConfig +from sagemaker.s3_utils import s3_path_join +from sagemaker.serve import InferenceSpec, SchemaBuilder +from sagemaker.serve.builder.model_builder import ModelBuilder +from sagemaker.serverless.serverless_inference_config import ServerlessInferenceConfig from tests.integ.utils import cleanup_model_resources +@pytest.fixture(autouse=True) +def cleanup_endpoints(mb_sagemaker_session) -> Generator[None, None, None]: + """Clean up any existing endpoints before and after tests.""" + sagemaker_client = mb_sagemaker_session.sagemaker_client + + # Pre-test cleanup + try: + endpoints = sagemaker_client.list_endpoints() + for endpoint in endpoints["Endpoints"]: + try: + sagemaker_client.delete_endpoint(EndpointName=endpoint["EndpointName"]) + sagemaker_client.delete_endpoint_config( + EndpointConfigName=endpoint["EndpointConfigName"] + ) + except Exception as e: + print(f"Error cleaning up endpoint {endpoint['EndpointName']}: {e}") + except Exception as e: + print(f"Error listing endpoints: {e}") + + yield + + # Post-test cleanup + try: + endpoints = sagemaker_client.list_endpoints() + for endpoint in endpoints["Endpoints"]: + try: + sagemaker_client.delete_endpoint(EndpointName=endpoint["EndpointName"]) + sagemaker_client.delete_endpoint_config( + EndpointConfigName=endpoint["EndpointConfigName"] + ) + except Exception as e: + print(f"Error cleaning up endpoint {endpoint['EndpointName']}: {e}") + except Exception as e: + print(f"Error listing endpoints: {e}") + + @pytest.fixture(scope="module") def xgboost_model_builder(mb_sagemaker_session): sagemaker_session = mb_sagemaker_session