Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 13 additions & 18 deletions google/cloud/aiplatform/metadata/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from typing import Any, Dict, Optional, Sequence, Union

from google.auth import credentials as auth_credentials
from google.cloud import storage
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from google.cloud.aiplatform import explain
Expand Down Expand Up @@ -371,6 +370,7 @@ def save_model(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
staging_bucket: Optional[str] = None,
) -> google_artifact_schema.ExperimentModel:
"""Saves a ML model into a MLMD artifact.

Expand Down Expand Up @@ -418,12 +418,18 @@ def save_model(
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to create this Artifact. Overrides
credentials set in aiplatform.init.
staging_bucket (str):
Optional. The staging bucket used to save the model. If not provided,
the staging bucket set in aiplatform.init will be used. A staging
bucket or uri is required for saving a model.

Returns:
An ExperimentModel instance.

Raises:
ValueError: if model type is not supported.
RuntimeError: If staging bucket was not set using aiplatform.init
and a staging bucket or uri was not passed in.
"""
framework_name = framework_version = ""
try:
Expand Down Expand Up @@ -476,24 +482,13 @@ def save_model(
model_file = _FRAMEWORK_SPECS[framework_name]["model_file"]

if not uri:
staging_bucket = initializer.global_config.staging_bucket
# TODO(b/264196887)
staging_bucket = staging_bucket or initializer.global_config.staging_bucket

if not staging_bucket:
project = project or initializer.global_config.project
location = location or initializer.global_config.location
credentials = credentials or initializer.global_config.credentials

staging_bucket_name = project + "-vertex-staging-" + location
client = storage.Client(project=project, credentials=credentials)
staging_bucket = storage.Bucket(client=client, name=staging_bucket_name)
if not staging_bucket.exists():
_LOGGER.info(f'Creating staging bucket "{staging_bucket_name}"')
staging_bucket = client.create_bucket(
bucket_or_name=staging_bucket,
project=project,
location=location,
)
staging_bucket = f"gs://{staging_bucket_name}"
raise RuntimeError(
"staging_bucket should be passed to save_model constructor or "
"should be set using aiplatform.init(staging_bucket='gs://my-bucket')"
)

unique_name = utils.timestamped_unique_name()
uri = f"{staging_bucket}/{unique_name}-{framework_name}-model"
Expand Down
8 changes: 8 additions & 0 deletions google/cloud/aiplatform/metadata/experiment_run_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,7 @@ def log_model(
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
staging_bucket: Optional[str] = None,
) -> google_artifact_schema.ExperimentModel:
"""Saves a ML model into a MLMD artifact and log it to this ExperimentRun.

Expand Down Expand Up @@ -1245,12 +1246,18 @@ def log_model(
credentials (auth_credentials.Credentials):
Optional. Custom credentials used to create this Artifact. Overrides
credentials set in aiplatform.init.
staging_bucket (str):
Optional. The staging bucket used to save the model. If not provided,
the staging bucket set in aiplatform.init will be used. A staging
bucket or uri is required for saving a model.

Returns:
An ExperimentModel instance.

Raises:
ValueError: if model type is not supported.
RuntimeError: If staging bucket was not set using aiplatform.init
and a staging bucket or uri was not passed in.
"""
experiment_model = _models.save_model(
model=model,
Expand All @@ -1262,6 +1269,7 @@ def log_model(
project=project,
location=location,
credentials=credentials,
staging_bucket=staging_bucket,
)

self._metadata_node.add_artifacts_and_executions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def save_model_sample(
Union[list, dict, "pd.DataFrame", "np.ndarray"] # noqa: F821
] = None,
display_name: Optional[str] = None,
staging_bucket: Optional[str] = None,
) -> None:
aiplatform.init(project=project, location=location)

Expand All @@ -39,6 +40,7 @@ def save_model_sample(
uri=uri,
input_example=input_example,
display_name=display_name,
staging_bucket=staging_bucket,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def test_save_model_sample(mock_save_model):
uri=constants.MODEL_ARTIFACT_URI,
input_example=constants.EXPERIMENT_MODEL_INPUT_EXAMPLE,
display_name=constants.DISPLAY_NAME,
staging_bucket=constants.STAGING_BUCKET,
)

mock_save_model.assert_called_once_with(
Expand All @@ -37,4 +38,5 @@ def test_save_model_sample(mock_save_model):
uri=constants.MODEL_ARTIFACT_URI,
input_example=constants.EXPERIMENT_MODEL_INPUT_EXAMPLE,
display_name=constants.DISPLAY_NAME,
staging_bucket=constants.STAGING_BUCKET,
)
Loading