diff --git a/google/cloud/aiplatform/metadata/_models.py b/google/cloud/aiplatform/metadata/_models.py index 0f006bcf2e..c5f8fecf07 100644 --- a/google/cloud/aiplatform/metadata/_models.py +++ b/google/cloud/aiplatform/metadata/_models.py @@ -22,7 +22,6 @@ from typing import Any, Dict, Optional, Sequence, Union from google.auth import credentials as auth_credentials -from google.cloud import storage from google.cloud import aiplatform from google.cloud.aiplatform import base from google.cloud.aiplatform import explain @@ -371,6 +370,7 @@ def save_model( project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, + staging_bucket: Optional[str] = None, ) -> google_artifact_schema.ExperimentModel: """Saves a ML model into a MLMD artifact. @@ -418,12 +418,18 @@ def save_model( credentials (auth_credentials.Credentials): Optional. Custom credentials used to create this Artifact. Overrides credentials set in aiplatform.init. + staging_bucket (str): + Optional. The staging bucket used to save the model. If not provided, + the staging bucket set in aiplatform.init will be used. A staging + bucket or uri is required for saving a model. Returns: An ExperimentModel instance. Raises: ValueError: if model type is not supported. + RuntimeError: If staging bucket was not set using aiplatform.init + and a staging bucket or uri was not passed in. """ framework_name = framework_version = "" try: @@ -476,24 +482,13 @@ def save_model( model_file = _FRAMEWORK_SPECS[framework_name]["model_file"] if not uri: - staging_bucket = initializer.global_config.staging_bucket - # TODO(b/264196887) + staging_bucket = staging_bucket or initializer.global_config.staging_bucket + if not staging_bucket: - project = project or initializer.global_config.project - location = location or initializer.global_config.location - credentials = credentials or initializer.global_config.credentials - - staging_bucket_name = project + "-vertex-staging-" + location - client = storage.Client(project=project, credentials=credentials) - staging_bucket = storage.Bucket(client=client, name=staging_bucket_name) - if not staging_bucket.exists(): - _LOGGER.info(f'Creating staging bucket "{staging_bucket_name}"') - staging_bucket = client.create_bucket( - bucket_or_name=staging_bucket, - project=project, - location=location, - ) - staging_bucket = f"gs://{staging_bucket_name}" + raise RuntimeError( + "staging_bucket should be passed to save_model constructor or " + "should be set using aiplatform.init(staging_bucket='gs://my-bucket')" + ) unique_name = utils.timestamped_unique_name() uri = f"{staging_bucket}/{unique_name}-{framework_name}-model" diff --git a/google/cloud/aiplatform/metadata/experiment_run_resource.py b/google/cloud/aiplatform/metadata/experiment_run_resource.py index 95a3d616b6..8da82c2f52 100644 --- a/google/cloud/aiplatform/metadata/experiment_run_resource.py +++ b/google/cloud/aiplatform/metadata/experiment_run_resource.py @@ -1196,6 +1196,7 @@ def log_model( project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, + staging_bucket: Optional[str] = None, ) -> google_artifact_schema.ExperimentModel: """Saves a ML model into a MLMD artifact and log it to this ExperimentRun. @@ -1245,12 +1246,18 @@ def log_model( credentials (auth_credentials.Credentials): Optional. Custom credentials used to create this Artifact. Overrides credentials set in aiplatform.init. + staging_bucket (str): + Optional. The staging bucket used to save the model. If not provided, + the staging bucket set in aiplatform.init will be used. A staging + bucket or uri is required for saving a model. Returns: An ExperimentModel instance. Raises: ValueError: if model type is not supported. + RuntimeError: If staging bucket was not set using aiplatform.init + and a staging bucket or uri was not passed in. """ experiment_model = _models.save_model( model=model, @@ -1262,6 +1269,7 @@ def log_model( project=project, location=location, credentials=credentials, + staging_bucket=staging_bucket, ) self._metadata_node.add_artifacts_and_executions( diff --git a/samples/model-builder/experiment_tracking/save_model_sample.py b/samples/model-builder/experiment_tracking/save_model_sample.py index d46a0bdbfa..e746c7dc12 100644 --- a/samples/model-builder/experiment_tracking/save_model_sample.py +++ b/samples/model-builder/experiment_tracking/save_model_sample.py @@ -30,6 +30,7 @@ def save_model_sample( Union[list, dict, "pd.DataFrame", "np.ndarray"] # noqa: F821 ] = None, display_name: Optional[str] = None, + staging_bucket: Optional[str] = None, ) -> None: aiplatform.init(project=project, location=location) @@ -39,6 +40,7 @@ def save_model_sample( uri=uri, input_example=input_example, display_name=display_name, + staging_bucket=staging_bucket, ) diff --git a/samples/model-builder/experiment_tracking/save_model_sample_test.py b/samples/model-builder/experiment_tracking/save_model_sample_test.py index fae5fc1abe..6eff403f13 100644 --- a/samples/model-builder/experiment_tracking/save_model_sample_test.py +++ b/samples/model-builder/experiment_tracking/save_model_sample_test.py @@ -29,6 +29,7 @@ def test_save_model_sample(mock_save_model): uri=constants.MODEL_ARTIFACT_URI, input_example=constants.EXPERIMENT_MODEL_INPUT_EXAMPLE, display_name=constants.DISPLAY_NAME, + staging_bucket=constants.STAGING_BUCKET, ) mock_save_model.assert_called_once_with( @@ -37,4 +38,5 @@ def test_save_model_sample(mock_save_model): uri=constants.MODEL_ARTIFACT_URI, input_example=constants.EXPERIMENT_MODEL_INPUT_EXAMPLE, display_name=constants.DISPLAY_NAME, + staging_bucket=constants.STAGING_BUCKET, )