|
29 | 29 | from sagemaker.jumpstart.cache import ( |
30 | 30 | JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY, |
31 | 31 | JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY, |
| 32 | + DEFAULT_JUMPSTART_SAGEMAKER_SESSION, |
32 | 33 | JumpStartModelsCache, |
33 | 34 | ) |
34 | 35 | from sagemaker.jumpstart.constants import ( |
|
57 | 58 | from sagemaker.jumpstart.utils import get_jumpstart_content_bucket |
58 | 59 |
|
59 | 60 |
|
| 61 | +@patch("sagemaker.jumpstart.utils.get_region_fallback", lambda *args, **kwargs: "dummy-region") |
| 62 | +@patch( |
| 63 | + "sagemaker.jumpstart.utils.get_jumpstart_content_bucket", lambda *args, **kwargs: "dummy-bucket" |
| 64 | +) |
| 65 | +@patch("boto3.client") |
| 66 | +def test_jumpstart_cache_init(mock_boto3_client): |
| 67 | + cache = JumpStartModelsCache() |
| 68 | + assert cache._region == "dummy-region" |
| 69 | + assert cache.s3_bucket_name == "dummy-bucket" |
| 70 | + assert cache._manifest_file_s3_key == JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY |
| 71 | + assert cache._proprietary_manifest_s3_key == JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY |
| 72 | + assert cache._sagemaker_session == DEFAULT_JUMPSTART_SAGEMAKER_SESSION |
| 73 | + mock_boto3_client.assert_called_once_with("s3", region_name="dummy-region") |
| 74 | + |
| 75 | + # Some callers override the session to None, should still be set to default |
| 76 | + cache = JumpStartModelsCache(sagemaker_session=None) |
| 77 | + assert cache._sagemaker_session == DEFAULT_JUMPSTART_SAGEMAKER_SESSION |
| 78 | + |
| 79 | + |
60 | 80 | @patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function) |
61 | 81 | @patch("sagemaker.jumpstart.utils.get_sagemaker_version", lambda: "2.68.3") |
62 | 82 | def test_jumpstart_cache_get_header(): |
|
0 commit comments