From 5788c50bb6dad16081ba0966ceba61b6e70f95cd Mon Sep 17 00:00:00 2001 From: Yvonne Yu Date: Fri, 1 May 2026 18:41:21 -0700 Subject: [PATCH] feat: migrate resources to agentplatform PiperOrigin-RevId: 909002274 --- agentplatform/resources/__init__.py | 180 ++ agentplatform/resources/preview/__init__.py | 89 + .../preview/feature_store/__init__.py | 68 + .../feature_store/_offline_store_impl.py | 188 ++ .../preview/feature_store/feature.py | 149 ++ .../preview/feature_store/feature_group.py | 590 ++++++ .../preview/feature_store/feature_monitor.py | 335 +++ .../feature_store/feature_online_store.py | 647 ++++++ .../preview/feature_store/feature_view.py | 537 +++++ .../preview/feature_store/offline_store.py | 289 +++ .../resources/preview/feature_store/utils.py | 231 ++ .../preview/ml_monitoring/__init__.py | 24 + .../preview/ml_monitoring/model_monitors.py | 1866 +++++++++++++++++ .../preview/ml_monitoring/spec/__init__.py | 44 + .../ml_monitoring/spec/notification.py | 72 + .../preview/ml_monitoring/spec/objective.py | 498 +++++ .../preview/ml_monitoring/spec/output.py | 46 + .../preview/ml_monitoring/spec/schema.py | 439 ++++ tests/unit/agentplatform/conftest.py | 283 +++ .../agentplatform/feature_store_constants.py | 498 +++++ .../test_feature.py | 8 +- .../test_feature_group.py | 12 +- .../test_feature_monitor.py | 6 +- .../test_feature_online_store.py | 8 +- .../test_feature_view.py | 8 +- .../unit/agentplatform/test_model_monitors.py | 1227 +++++++++++ tests/unit/vertexai/conftest.py | 2 +- tests/unit/vertexai/test_vertexai_feature.py | 291 +++ .../vertexai/test_vertexai_feature_group.py | 1027 +++++++++ .../vertexai/test_vertexai_feature_monitor.py | 370 ++++ .../test_vertexai_feature_online_store.py | 847 ++++++++ .../vertexai/test_vertexai_feature_view.py | 856 ++++++++ ...ors.py => test_vertexai_model_monitors.py} | 0 ...py => vertexai_feature_store_constants.py} | 0 34 files changed, 11712 insertions(+), 23 deletions(-) create mode 100644 
agentplatform/resources/__init__.py create mode 100644 agentplatform/resources/preview/__init__.py create mode 100644 agentplatform/resources/preview/feature_store/__init__.py create mode 100644 agentplatform/resources/preview/feature_store/_offline_store_impl.py create mode 100644 agentplatform/resources/preview/feature_store/feature.py create mode 100644 agentplatform/resources/preview/feature_store/feature_group.py create mode 100644 agentplatform/resources/preview/feature_store/feature_monitor.py create mode 100644 agentplatform/resources/preview/feature_store/feature_online_store.py create mode 100644 agentplatform/resources/preview/feature_store/feature_view.py create mode 100644 agentplatform/resources/preview/feature_store/offline_store.py create mode 100644 agentplatform/resources/preview/feature_store/utils.py create mode 100644 agentplatform/resources/preview/ml_monitoring/__init__.py create mode 100644 agentplatform/resources/preview/ml_monitoring/model_monitors.py create mode 100644 agentplatform/resources/preview/ml_monitoring/spec/__init__.py create mode 100644 agentplatform/resources/preview/ml_monitoring/spec/notification.py create mode 100644 agentplatform/resources/preview/ml_monitoring/spec/objective.py create mode 100644 agentplatform/resources/preview/ml_monitoring/spec/output.py create mode 100644 agentplatform/resources/preview/ml_monitoring/spec/schema.py create mode 100644 tests/unit/agentplatform/feature_store_constants.py rename tests/unit/{vertexai => agentplatform}/test_feature.py (98%) rename tests/unit/{vertexai => agentplatform}/test_feature_group.py (99%) rename tests/unit/{vertexai => agentplatform}/test_feature_monitor.py (98%) rename tests/unit/{vertexai => agentplatform}/test_feature_online_store.py (99%) rename tests/unit/{vertexai => agentplatform}/test_feature_view.py (99%) create mode 100644 tests/unit/agentplatform/test_model_monitors.py create mode 100644 tests/unit/vertexai/test_vertexai_feature.py create mode 100644 
tests/unit/vertexai/test_vertexai_feature_group.py create mode 100644 tests/unit/vertexai/test_vertexai_feature_monitor.py create mode 100644 tests/unit/vertexai/test_vertexai_feature_online_store.py create mode 100644 tests/unit/vertexai/test_vertexai_feature_view.py rename tests/unit/vertexai/{test_model_monitors.py => test_vertexai_model_monitors.py} (100%) rename tests/unit/vertexai/{feature_store_constants.py => vertexai_feature_store_constants.py} (100%) diff --git a/agentplatform/resources/__init__.py b/agentplatform/resources/__init__.py new file mode 100644 index 0000000000..db65f53403 --- /dev/null +++ b/agentplatform/resources/__init__.py @@ -0,0 +1,180 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""The agentplatform resources module."""

# This module defines no resources of its own: it re-exports classes and
# helpers from google.cloud.aiplatform under the agentplatform namespace.

from google.cloud.aiplatform import initializer

from google.cloud.aiplatform.datasets import (
    ImageDataset,
    TabularDataset,
    TextDataset,
    TimeSeriesDataset,
    VideoDataset,
)
from google.cloud.aiplatform import explain
from google.cloud.aiplatform import gapic
from google.cloud.aiplatform import hyperparameter_tuning
from google.cloud.aiplatform.featurestore import (
    EntityType,
    Feature,
    Featurestore,
)
from google.cloud.aiplatform.matching_engine import (
    MatchingEngineIndex,
    MatchingEngineIndexEndpoint,
)
from google.cloud.aiplatform import metadata
from google.cloud.aiplatform.tensorboard import uploader_tracker
from google.cloud.aiplatform.models import DeploymentResourcePool
from google.cloud.aiplatform.models import Endpoint
from google.cloud.aiplatform.models import PrivateEndpoint
from google.cloud.aiplatform.models import Model
from google.cloud.aiplatform.models import ModelRegistry
from google.cloud.aiplatform.model_evaluation import ModelEvaluation
from google.cloud.aiplatform.jobs import (
    BatchPredictionJob,
    CustomJob,
    HyperparameterTuningJob,
    ModelDeploymentMonitoringJob,
)
from google.cloud.aiplatform.pipeline_jobs import PipelineJob
from google.cloud.aiplatform.pipeline_job_schedules import (
    PipelineJobSchedule,
)
from google.cloud.aiplatform.tensorboard import (
    Tensorboard,
    TensorboardExperiment,
    TensorboardRun,
    TensorboardTimeSeries,
)
from google.cloud.aiplatform.training_jobs import (
    CustomTrainingJob,
    CustomContainerTrainingJob,
    CustomPythonPackageTrainingJob,
    AutoMLTabularTrainingJob,
    AutoMLForecastingTrainingJob,
    SequenceToSequencePlusForecastingTrainingJob,
    TemporalFusionTransformerForecastingTrainingJob,
    TimeSeriesDenseEncoderForecastingTrainingJob,
    AutoMLImageTrainingJob,
    AutoMLTextTrainingJob,
    AutoMLVideoTrainingJob,
)

from google.cloud.aiplatform import helpers

# NOTE: the string below is a bare expression statement (a usage note), not the
# module docstring; it has no runtime effect.
"""
Usage:
import agentplatform

agentplatform.init(project='my_project')
"""
# `init` is the bound `init` method of the shared global configuration object.
init = initializer.global_config.init

get_pipeline_df = metadata.metadata._LegacyExperimentService.get_pipeline_df

# Experiment-tracking helpers: bound methods of the module-level
# `_experiment_tracker` object in `metadata.metadata`.
log_params = metadata.metadata._experiment_tracker.log_params
log_metrics = metadata.metadata._experiment_tracker.log_metrics
log_classification_metrics = (
    metadata.metadata._experiment_tracker.log_classification_metrics
)
log_model = metadata.metadata._experiment_tracker.log_model
get_experiment_df = metadata.metadata._experiment_tracker.get_experiment_df
start_run = metadata.metadata._experiment_tracker.start_run
autolog = metadata.metadata._experiment_tracker.autolog
start_execution = metadata.metadata._experiment_tracker.start_execution
log = metadata.metadata._experiment_tracker.log
log_time_series_metrics = metadata.metadata._experiment_tracker.log_time_series_metrics
end_run = metadata.metadata._experiment_tracker.end_run

# TensorBoard log-upload helpers: bound methods of the module-level
# `_tensorboard_tracker` object in `uploader_tracker`.
upload_tb_log = uploader_tracker._tensorboard_tracker.upload_tb_log
start_upload_tb_log = uploader_tracker._tensorboard_tracker.start_upload_tb_log
end_upload_tb_log = uploader_tracker._tensorboard_tracker.end_upload_tb_log

save_model = metadata._models.save_model
get_experiment_model = metadata.schema.google.artifact_schema.ExperimentModel.get

# Metadata resource classes re-exported under short names.
Experiment = metadata.experiment_resources.Experiment
ExperimentRun = metadata.experiment_run_resource.ExperimentRun
Artifact = metadata.artifact.Artifact
Execution = metadata.execution.Execution
Context = metadata.context.Context


# Names exported by `from agentplatform.resources import *`.
# NOTE(review): `Context` is bound above but is not listed here — confirm
# whether that omission is intentional.
__all__ = (
    "end_run",
    "explain",
    "gapic",
    "init",
    "helpers",
    "hyperparameter_tuning",
    "log",
    "log_params",
    "log_metrics",
    "log_classification_metrics",
    "log_model",
    "log_time_series_metrics",
    "get_experiment_df",
    "get_pipeline_df",
    "start_run",
    "start_execution",
    "save_model",
    "get_experiment_model",
    "autolog",
    "upload_tb_log",
    "start_upload_tb_log",
    "end_upload_tb_log",
    "Artifact",
    "AutoMLImageTrainingJob",
    "AutoMLTabularTrainingJob",
    "AutoMLForecastingTrainingJob",
    "AutoMLTextTrainingJob",
    "AutoMLVideoTrainingJob",
    "BatchPredictionJob",
    "CustomJob",
    "CustomTrainingJob",
    "CustomContainerTrainingJob",
    "CustomPythonPackageTrainingJob",
    "DeploymentResourcePool",
    "Endpoint",
    "EntityType",
    "Execution",
    "Experiment",
    "ExperimentRun",
    "Feature",
    "Featurestore",
    "MatchingEngineIndex",
    "MatchingEngineIndexEndpoint",
    "ImageDataset",
    "HyperparameterTuningJob",
    "Model",
    "ModelRegistry",
    "ModelEvaluation",
    "ModelDeploymentMonitoringJob",
    "PipelineJob",
    "PipelineJobSchedule",
    "PrivateEndpoint",
    "SequenceToSequencePlusForecastingTrainingJob",
    "TabularDataset",
    "Tensorboard",
    "TensorboardExperiment",
    "TensorboardRun",
    "TensorboardTimeSeries",
    "TextDataset",
    "TemporalFusionTransformerForecastingTrainingJob",
    "TimeSeriesDataset",
    "TimeSeriesDenseEncoderForecastingTrainingJob",
    "VideoDataset",
)
+# +"""The agentplatform resources preview module.""" + +from google.cloud.aiplatform.preview.jobs import ( + CustomJob, + HyperparameterTuningJob, +) +from google.cloud.aiplatform.preview.models import ( + Prediction, + DeploymentResourcePool, + Endpoint, + Model, +) +from google.cloud.aiplatform.preview.featurestore.entity_type import ( + EntityType, +) +from google.cloud.aiplatform.preview.persistent_resource import ( + PersistentResource, +) +from google.cloud.aiplatform.preview.pipelinejobschedule.pipeline_job_schedules import ( + PipelineJobSchedule, +) + +from agentplatform.resources.preview.feature_store import ( + Feature, + FeatureGroup, + FeatureGroupBigQuerySource, + FeatureMonitor, + FeatureOnlineStore, + FeatureOnlineStoreType, + FeatureView, + FeatureViewBigQuerySource, + FeatureViewReadResponse, + FeatureViewRegistrySource, + FeatureViewVertexRagSource, + IndexConfig, + TreeAhConfig, + BruteForceConfig, + DistanceMeasureType, + AlgorithmConfig, +) + +from agentplatform.resources.preview.ml_monitoring import ( + ModelMonitor, + ModelMonitoringJob, +) + +__all__ = ( + "CustomJob", + "HyperparameterTuningJob", + "Prediction", + "DeploymentResourcePool", + "Endpoint", + "Model", + "PersistentResource", + "EntityType", + "PipelineJobSchedule", + "Feature", + "FeatureGroup", + "FeatureGroupBigQuerySource", + "FeatureMonitor", + "FeatureOnlineStoreType", + "FeatureOnlineStore", + "FeatureView", + "FeatureViewBigQuerySource", + "FeatureViewReadResponse", + "FeatureViewVertexRagSource", + "FeatureViewRegistrySource", + "IndexConfig", + "TreeAhConfig", + "BruteForceConfig", + "DistanceMeasureType", + "AlgorithmConfig", + "ModelMonitor", + "ModelMonitoringJob", +) diff --git a/agentplatform/resources/preview/feature_store/__init__.py b/agentplatform/resources/preview/feature_store/__init__.py new file mode 100644 index 0000000000..8a790eff46 --- /dev/null +++ b/agentplatform/resources/preview/feature_store/__init__.py @@ -0,0 +1,68 @@ +# Copyright 2026 Google 
LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""The agentplatform resources preview module.""" + +from agentplatform.resources.preview.feature_store.feature import ( + Feature, +) + +from agentplatform.resources.preview.feature_store.feature_group import ( + FeatureGroup, +) + +from agentplatform.resources.preview.feature_store.feature_monitor import ( + FeatureMonitor, +) + +from agentplatform.resources.preview.feature_store.feature_online_store import ( + FeatureOnlineStore, + FeatureOnlineStoreType, +) + +from agentplatform.resources.preview.feature_store.feature_view import ( + FeatureView, +) + +from agentplatform.resources.preview.feature_store.utils import ( + FeatureGroupBigQuerySource, + FeatureViewBigQuerySource, + FeatureViewReadResponse, + FeatureViewVertexRagSource, + FeatureViewRegistrySource, + IndexConfig, + TreeAhConfig, + BruteForceConfig, + DistanceMeasureType, + AlgorithmConfig, +) + +__all__ = ( + "Feature", + "FeatureGroup", + "FeatureGroupBigQuerySource", + "FeatureMonitor", + "FeatureOnlineStoreType", + "FeatureOnlineStore", + "FeatureView", + "FeatureViewBigQuerySource", + "FeatureViewReadResponse", + "FeatureViewVertexRagSource", + "FeatureViewRegistrySource", + "IndexConfig", + "TreeAhConfig", + "BruteForceConfig", + "DistanceMeasureType", + "AlgorithmConfig", +) diff --git a/agentplatform/resources/preview/feature_store/_offline_store_impl.py b/agentplatform/resources/preview/feature_store/_offline_store_impl.py new file mode 
import textwrap
from dataclasses import dataclass
from typing import Optional, List


@dataclass(init=False)
class DataSource:
    """A SQL data source used to build a point-in-time lookup (PITL) query.

    Represents either the entity DataFrame or a feature data source and
    exposes helpers for use with SQL templating.

    The attributes are declared as dataclass fields so that the generated
    ``__eq__`` and ``__repr__`` reflect the instance contents. ``init=False``
    keeps the hand-written constructor, whose ``sql`` argument is stored in
    ``_sql`` and exposed through the read-only ``sql`` property.
    """

    qualifying_name: str
    _sql: str
    data_columns: List[str]
    timestamp_column: str
    entity_id_columns: Optional[List[str]] = None

    def __init__(
        self,
        qualifying_name: str,
        sql: str,
        data_columns: List[str],
        timestamp_column: str,
        entity_id_columns: Optional[List[str]] = None,
    ):
        """Initialize DataSource object.

        Args:
            qualifying_name:
                A unique name used to qualify the data in the PITL query.
            sql:
                SQL query representing the data_source.
            data_columns:
                Columns other than entity ID column(s) and timestamp column.
            timestamp_column:
                The column that holds feature timestamp data.
            entity_id_columns:
                The column(s) that holds entity IDs. Shouldn't be populated for
                entity_df.
        """
        self.qualifying_name = qualifying_name
        self._sql = sql
        self.data_columns = data_columns
        self.timestamp_column = timestamp_column
        self.entity_id_columns = entity_id_columns

    def copy_with_pitl_suffix(self) -> "DataSource":
        """Return a shallow copy whose qualifying name has `_pitl` appended."""
        import copy

        data_source = copy.copy(self)
        data_source.qualifying_name += "_pitl"
        return data_source

    @property
    def sql(self) -> str:
        """The SQL query representing this data source."""
        return self._sql

    @property
    def comma_separated_qualified_data_columns(self) -> str:
        """Data columns as `name.col, name.col, ...` (no entity ID columns)."""
        return ", ".join(
            [self.qualifying_name + "." + col for col in self.data_columns]
        )

    @property
    def comma_separated_name_qualified_all_non_timestamp_columns(self) -> str:
        """Same as `comma_separated_qualified_data_columns` but including entity ID column."""
        # Copy first so appending the entity ID columns never mutates the
        # caller-provided `data_columns` list.
        all_columns = self.data_columns.copy()
        if self.entity_id_columns:
            all_columns += self.entity_id_columns
        return ", ".join([self.qualifying_name + "." + col for col in all_columns])

    @property
    def qualified_timestamp_column(self) -> str:
        """Returns name qualified timestamp column e.g. `name.feature_timestamp`."""
        return f"{self.qualifying_name}.{self.timestamp_column}"


def _generate_eid_check(entity_data: DataSource, feature: DataSource) -> str:
    """Generate equality check for entity columns of feature against matching columns in entity_data.

    Args:
        entity_data: The entity data source; its `data_columns` must contain
            every entity ID column of `feature`.
        feature: The feature data source; must have `entity_id_columns` set.

    Returns:
        One `entity.col = feature.col` comparison per entity ID column, joined
        with " AND" and a newline.

    Raises:
        ValueError: If `feature` has no entity ID columns, or if one of them is
            not a data column of `entity_data`.
    """
    f_cols = feature.entity_id_columns
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not f_cols:
        raise ValueError(
            f"Feature data source '{feature.qualifying_name}' should specify entity ID columns."
        )

    e_cols = set(entity_data.data_columns)

    equal_statements = []
    for col in f_cols:
        if col not in e_cols:
            raise ValueError(
                f"Feature entity ID column '{col}' should be a column in the entity DataFrame."
            )
        equal_statements.append(
            f"{entity_data.qualifying_name}.{col} = {feature.qualifying_name}.{col}"
        )

    return " AND\n".join(equal_statements)


# Args:
#   textwrap: Module
#   generate_eid_check: function (above)
#   entity_data: DataSource
#   feature_data: List[DataSource]
#
# The row-numbered entity CTE selects from
# `{{ entity_data.qualifying_name }}_without_row_num`, i.e. the CTE defined on
# the first line, so the query is valid for any qualifying name.
_PITL_QUERY_TEMPLATE_RAW = """WITH
  {{ entity_data.qualifying_name }}_without_row_num AS (
{{ textwrap.indent(entity_data.sql, ' ' * 4) }}
  ),
  {{ entity_data.qualifying_name }} AS (
    SELECT *, ROW_NUMBER() OVER() AS row_num,
    FROM {{ entity_data.qualifying_name }}_without_row_num
  ),

  # Features
  {% for feature_data_elem in feature_data %}
  {{ feature_data_elem.qualifying_name }} AS (
{{ textwrap.indent(feature_data_elem.sql, ' ' * 4) }}
  ),
  {% endfor %}

  # Features with PITL
  {% for feature_data_elem in feature_data %}
  {{ feature_data_elem.qualifying_name }}_pitl AS (
    SELECT
      {{ entity_data.qualifying_name }}.row_num,
      {{ feature_data_elem.comma_separated_qualified_data_columns }},
    FROM {{ entity_data.qualifying_name }}
    LEFT JOIN {{ feature_data_elem.qualifying_name }}
    ON (
{{ textwrap.indent(generate_eid_check(entity_data, feature_data_elem) + ' AND', ' ' * 6) }}
      CAST({{ feature_data_elem.qualified_timestamp_column }} AS TIMESTAMP) <= CAST({{ entity_data.qualified_timestamp_column }} AS TIMESTAMP)
    )
    QUALIFY ROW_NUMBER() OVER (PARTITION BY {{ entity_data.qualifying_name }}.row_num ORDER BY {{ feature_data_elem.qualified_timestamp_column }} DESC) = 1
  ){{ ',' if not loop.last else '' }}
  {% endfor %}


SELECT
  {{ entity_data.comma_separated_name_qualified_all_non_timestamp_columns }},
  {% for feature_data_elem in feature_data %}
  {% set feature_pitl = feature_data_elem.copy_with_pitl_suffix() %}
  {{ feature_pitl.comma_separated_qualified_data_columns }},
  {% endfor %}
  {{ entity_data.qualified_timestamp_column }}

FROM {{ entity_data.qualifying_name }}
{% for feature_data_elem in feature_data %}
JOIN {{ feature_data_elem.qualifying_name }}_pitl USING (row_num)
{% endfor %}
"""


def pitl_query_template():
    """Compile and return the PITL query as a Jinja2 template.

    Jinja2 is imported lazily so it is only required when this functionality
    is used.

    Raises:
        ImportError: If `jinja2` is not installed.
    """
    try:
        import jinja2
    except ImportError as exc:
        raise ImportError(
            "`Jinja2` is not installed but required for this functionality."
        ) from exc

    # The template comes from a string, so the loader is never consulted; an
    # instance (not the class) is passed because that is what Environment expects.
    return jinja2.Environment(
        loader=jinja2.BaseLoader(), lstrip_blocks=True, trim_blocks=True
    ).from_string(_PITL_QUERY_TEMPLATE_RAW)


def render_pitl_query(entity_data: DataSource, feature_data: List[DataSource]):
    """Render the PITL query for the given entity and feature data sources.

    Args:
        entity_data: The entity data(frame) as SQL source.
        feature_data: The feature data sources to join onto the entity data.

    Returns:
        The rendered query text. The template is also given the `textwrap`
        module and `_generate_eid_check` (as `generate_eid_check`).
    """
    return pitl_query_template().render(
        textwrap=textwrap,
        generate_eid_check=_generate_eid_check,
        entity_data=entity_data,
        feature_data=feature_data,
    )
import re
from typing import List, Optional
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import (
    feature as gca_feature,
    feature_monitor_v1beta1 as gca_feature_monitor,
    feature_v1beta1 as gca_feature_v1beta1,
    featurestore_service_v1beta1 as gca_featurestore_service_v1beta1,
)


class Feature(base.VertexAiResourceNounWithFutureManager):
    """Class for managing Feature resources."""

    client_class = utils.FeatureRegistryClientWithOverride

    _resource_noun = "features"
    _getter_method = "get_feature"
    _list_method = "list_features"
    _delete_method = "delete_feature"
    _parse_resource_name_method = "parse_feature_path"
    _format_resource_name_method = "feature_path"
    _gca_resource: gca_feature.Feature

    def __init__(
        self,
        name: str,
        feature_group_id: Optional[str] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
        latest_stats_count: Optional[int] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ):
        """Retrieves an existing managed feature.

        Args:
            name:
                The resource name
                (`projects/.../locations/.../featureGroups/.../features/...`) or
                ID.
            feature_group_id:
                The feature group ID. Must be passed in if name is an ID and not
                a resource path.
            project:
                Project to retrieve feature from. If not set, the project set in
                aiplatform.init will be used.
            location:
                Location to retrieve feature from. If not set, the location set
                in aiplatform.init will be used.
            latest_stats_count:
                The number of latest stats to request. When set, the feature is
                fetched through the v1beta1 API with a
                FeatureStatsAndAnomalySpec carrying this count; when None, the
                regular getter is used.
            credentials:
                Custom credentials to use to retrieve this feature. Overrides
                credentials set in aiplatform.init.

        Raises:
            ValueError: If `name` is a full resource path and `feature_group_id`
                is also given, or if `name` is an ID and `feature_group_id` is
                missing.
        """

        super().__init__(
            project=project,
            location=location,
            credentials=credentials,
            resource_name=name,
        )

        if re.fullmatch(
            r"projects/.+/locations/.+/featureGroups/.+/features/.+",
            name,
        ):
            if feature_group_id:
                raise ValueError(
                    f"Since feature '{name}' is provided as a path, feature_group_id should not be specified."
                )
            feature = name
        else:
            # Imported here rather than at module level because feature_group
            # itself imports Feature.
            from .feature_group import FeatureGroup

            # Construct the feature path using feature group ID if only the
            # feature group ID is provided.
            if not feature_group_id:
                raise ValueError(
                    f"Since feature '{name}' is not provided as a path, please specify feature_group_id."
                )

            feature_group_path = utils.full_resource_name(
                resource_name=feature_group_id,
                resource_noun=FeatureGroup._resource_noun,
                parse_resource_name_method=FeatureGroup._parse_resource_name,
                format_resource_name_method=FeatureGroup._format_resource_name,
            )

            feature = f"{feature_group_path}/features/{name}"

        if latest_stats_count is not None:
            # NOTE(review): this client is built from the raw `location`
            # argument; confirm it matches the location the base class resolves
            # when `name` is a full resource path and `location` is None.
            api_client = self.__class__._instantiate_client(
                location=location, credentials=credentials
            )

            feature_obj: gca_feature_v1beta1.Feature = api_client.select_version(
                "v1beta1"
            ).get_feature(
                request=gca_featurestore_service_v1beta1.GetFeatureRequest(
                    name=feature,
                    feature_stats_and_anomaly_spec=gca_feature_monitor.FeatureStatsAndAnomalySpec(
                        latest_stats_count=latest_stats_count
                    ),
                )
            )
            self._gca_resource = feature_obj
        else:
            self._gca_resource = self._get_gca_resource(resource_name=feature)

    @property
    def version_column_name(self) -> str:
        """The name of the BigQuery Table/View column hosting data for this version."""
        return self._gca_resource.version_column_name

    @property
    def description(self) -> str:
        """The description of the feature."""
        return self._gca_resource.description

    @property
    def point_of_contact(self) -> str:
        """The point of contact for the feature."""
        return self._gca_resource.point_of_contact

    @property
    def feature_stats_and_anomalies(
        self,
    ) -> List[gca_feature_monitor.FeatureStatsAndAnomaly]:
        """The feature stats and anomalies carried by the underlying resource.

        Stats are only requested when the feature was retrieved with
        `latest_stats_count` set; presumably the list is empty otherwise.
        """
        return self._gca_resource.feature_stats_and_anomaly
from typing import Dict, List, Optional, Sequence, Tuple
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base, initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import (
    feature as gca_feature,
    feature_group as gca_feature_group,
    io as gca_io,
    feature_monitor_v1beta1 as gca_feature_monitor,
)
from agentplatform.resources.preview.feature_store.utils import (
    FeatureGroupBigQuerySource,
)
from agentplatform.resources.preview.feature_store import (
    Feature,
)
from agentplatform.resources.preview.feature_store.feature_monitor import (
    FeatureMonitor,
)


_LOGGER = base.Logger(__name__)


class FeatureGroup(base.VertexAiResourceNounWithFutureManager):
    """Class for managing Feature Group resources."""

    client_class = utils.FeatureRegistryClientWithOverride

    _resource_noun = "feature_groups"
    _getter_method = "get_feature_group"
    _list_method = "list_feature_groups"
    _delete_method = "delete_feature_group"
    _parse_resource_name_method = "parse_feature_group_path"
    _format_resource_name_method = "feature_group_path"
    _gca_resource: gca_feature_group.FeatureGroup

    def __init__(
        self,
        name: str,
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
    ):
        """Retrieves an existing managed feature group.

        Args:
            name:
                The resource name
                (`projects/.../locations/.../featureGroups/...`) or ID.
            project:
                Project to retrieve feature group from. If unset, the
                project set in aiplatform.init will be used.
            location:
                Location to retrieve feature group from. If not set,
                location set in aiplatform.init will be used.
            credentials:
                Custom credentials to use to retrieve this feature group.
                Overrides credentials set in aiplatform.init.
        """

        super().__init__(
            project=project,
            location=location,
            credentials=credentials,
            resource_name=name,
        )

        self._gca_resource = self._get_gca_resource(resource_name=name)

    @classmethod
    def create(
        cls,
        name: str,
        source: Optional[FeatureGroupBigQuerySource] = None,
        labels: Optional[Dict[str, str]] = None,
        description: Optional[str] = None,
        project: Optional[str] = None,
        location: Optional[str] = None,
        credentials: Optional[auth_credentials.Credentials] = None,
        request_metadata: Optional[Sequence[Tuple[str, str]]] = None,
        create_request_timeout: Optional[float] = None,
        sync: bool = True,
    ) -> "FeatureGroup":
        """Creates a new feature group.

        Args:
            name: The name of the feature group. Also used as the feature
                group ID in the create request.
            source: The BigQuery source of the feature group. Required even
                though it defaults to None; a missing source raises ValueError.
            labels:
                The labels with user-defined metadata to organize your
                FeatureGroup.

                Label keys and values can be no longer than 64
                characters (Unicode codepoints), can only
                contain lowercase letters, numeric characters,
                underscores and dashes. International characters
                are allowed.

                See https://goo.gl/xmQnxf for more information
                on and examples of labels. No more than 64 user
                labels can be associated with one
                FeatureGroup (System labels are excluded).
                System reserved label keys are prefixed with
                "aiplatform.googleapis.com/" and are immutable.
            description: Description of the FeatureGroup.
            project:
                Project to create feature group in. If unset, the project set in
                aiplatform.init will be used.
            location:
                Location to create feature group in. If not set, location set in
                aiplatform.init will be used.
            credentials:
                Custom credentials to use to create this feature group.
                Overrides credentials set in aiplatform.init.
            request_metadata:
                Strings which should be sent along with the request as metadata.
            create_request_timeout:
                The timeout for the create request in seconds.
            sync:
                Accepted for interface compatibility. This method does not
                reference it: it always waits for the create operation to
                finish before returning.

        Returns:
            FeatureGroup - the FeatureGroup resource object.

        Raises:
            ValueError: If `source` is missing, is not a
                FeatureGroupBigQuerySource, or has no URI.
        """

        if not source:
            raise ValueError("Please specify a valid source.")

        # Only BigQuery source is supported right now.
        if not isinstance(source, FeatureGroupBigQuerySource):
            raise ValueError("Only FeatureGroupBigQuerySource is a supported source.")

        # BigQuery source validation.
        if not source.uri:
            raise ValueError("Please specify URI in BigQuery source.")

        if not source.entity_id_columns:
            _LOGGER.info(
                "No entity ID columns specified in BigQuery source. Defaulting to ['entity_id']."
            )
            entity_id_columns = ["entity_id"]
        else:
            entity_id_columns = source.entity_id_columns

        gapic_feature_group = gca_feature_group.FeatureGroup(
            big_query=gca_feature_group.FeatureGroup.BigQuery(
                big_query_source=gca_io.BigQuerySource(input_uri=source.uri),
                entity_id_columns=entity_id_columns,
            ),
            name=name,
            description=description,
        )

        if labels:
            utils.validate_labels(labels)
            gapic_feature_group.labels = labels

        if request_metadata is None:
            request_metadata = ()

        api_client = cls._instantiate_client(location=location, credentials=credentials)

        create_feature_group_lro = api_client.create_feature_group(
            parent=initializer.global_config.common_location_path(
                project=project, location=location
            ),
            feature_group=gapic_feature_group,
            feature_group_id=name,
            metadata=request_metadata,
            timeout=create_request_timeout,
        )

        _LOGGER.log_create_with_lro(cls, create_feature_group_lro)

        # Block until the long-running operation completes.
        created_feature_group = create_feature_group_lro.result()

        _LOGGER.log_create_complete(cls, created_feature_group, "feature_group")

        # Re-fetch through the constructor so the returned object is backed by
        # the server-side resource.
        feature_group_obj = cls(
            name=created_feature_group.name,
            project=project,
            location=location,
            credentials=credentials,
        )

        return feature_group_obj
@base.optional_sync() + def delete(self, force: bool = False, sync: bool = True) -> None: + """Deletes this feature group. + + WARNING: This deletion is permanent. + + Args: + force: + If set to True, any Features under this FeatureGroup will also + be deleted. Otherwise, the request will only work if the + FeatureGroup has no Features. + + The value is passed through unchanged as the `force` argument + of the delete call made on the API client, together with this + feature group's resource name. + sync: + Whether to execute this deletion synchronously. If False, this + method will be executed in concurrent Future and any downstream + object will be immediately returned and synced when the Future + has completed. + """ + + lro = getattr(self.api_client, self._delete_method)( + name=self.resource_name, + force=force, + ) + _LOGGER.log_delete_with_lro(self, lro) + lro.result() + _LOGGER.log_delete_complete(self) + + def get_feature( + self, + feature_id: str, + latest_stats_count: Optional[int] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> Feature: + """Retrieves an existing managed feature. + + Args: + feature_id: The ID of the feature. + latest_stats_count: + The number of latest stats to retrieve. Only returns stats if + Feature Monitor is created, and historical stats were generated. + credentials: + Custom credentials to use to retrieve the feature under this + feature group. The order of which credentials are used is as + follows: (1) this parameter (2) credentials passed to FeatureGroup + constructor (3) credentials set in aiplatform.init. + + Returns: + Feature - the Feature resource object under this feature group.
+ """ + credentials = ( + credentials or self.credentials or initializer.global_config.credentials + ) + if latest_stats_count is not None: + return Feature( + name=f"{self.resource_name}/features/{feature_id}", + latest_stats_count=latest_stats_count, + credentials=credentials, + ) + return Feature( + f"{self.resource_name}/features/{feature_id}", credentials=credentials + ) + + def create_feature( + self, + name: str, + version_column_name: Optional[str] = None, + description: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + point_of_contact: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = None, + create_request_timeout: Optional[float] = None, + sync: bool = True, + ) -> Feature: + """Creates a new feature. + + Args: + name: The name of the feature. + version_column_name: + The name of the BigQuery Table/View column hosting data for this + version. If no value is provided, will use feature_id. + description: Description of the feature. + labels: + The labels with user-defined metadata to organize your Features. + Label keys and values can be no longer than 64 characters + (Unicode codepoints), can only contain lowercase letters, + numeric characters, underscores and dashes. International + characters are allowed. + + See https://goo.gl/xmQnxf for more information on and examples + of labels. No more than 64 user labels can be associated with + one Feature (System labels are excluded)." System reserved label + keys are prefixed with "aiplatform.googleapis.com/" and are + immutable. + point_of_contact: + Entity responsible for maintaining this feature. Can be comma + separated list of email addresses or URIs. + project: + Project to create feature in. If unset, the project set in + aiplatform.init will be used. + location: + Location to create feature in. 
If not set, location set in + aiplatform.init will be used. + credentials: + Custom credentials to use to create this feature. Overrides + credentials set in aiplatform.init. + request_metadata: + Strings which should be sent along with the request as metadata. + create_request_timeout: + The timeout for the create request in seconds. + sync: + Whether to execute this creation synchronously. If False, this + method will be executed in concurrent Future and any downstream + object will be immediately returned and synced when the Future + has completed. + + Returns: + Feature - the Feature resource object. + """ + + gapic_feature = gca_feature.Feature() + + if version_column_name: + gapic_feature.version_column_name = version_column_name + + if description: + gapic_feature.description = description + + if labels: + utils.validate_labels(labels) + gapic_feature.labels = labels + + if point_of_contact: + gapic_feature.point_of_contact = point_of_contact + + if request_metadata is None: + request_metadata = () + + api_client = self.__class__._instantiate_client( + location=location, credentials=credentials + ) + + create_feature_lro = api_client.create_feature( + parent=self.resource_name, + feature=gapic_feature, + feature_id=name, + metadata=request_metadata, + timeout=create_request_timeout, + ) + + _LOGGER.log_create_with_lro(Feature, create_feature_lro) + + created_feature = create_feature_lro.result() + + _LOGGER.log_create_complete(Feature, created_feature, "feature") + + feature_obj = Feature( + name=created_feature.name, + project=project, + location=location, + credentials=credentials, + ) + + return feature_obj + + def list_features( + self, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[Feature]: + """Lists features under this feature group. + + Args: + project: + Project to list features in. If unset, the project set in + aiplatform.init will be used. 
+ location: + Location to list features in. If not set, location set in + aiplatform.init will be used. + credentials: + Custom credentials to use to list features. Overrides + credentials set in aiplatform.init. + + Returns: + List of features under this feature group. + """ + + return Feature.list( + parent=self.resource_name, + project=project, + location=location, + credentials=credentials, + ) + + def get_feature_monitor( + self, + feature_monitor_id: str, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> FeatureMonitor: + """Retrieves an existing feature monitor. + + Args: + feature_monitor_id: The ID of the feature monitor. + credentials: + Custom credentials to use to retrieve the feature monitor under this + feature group. The order of which credentials are used is as + follows: (1) this parameter (2) credentials passed to FeatureGroup + constructor (3) credentials set in aiplatform.init. + + Returns: + FeatureMonitor - the Feature Monitor resource object under this + feature group. + """ + credentials = ( + credentials or self.credentials or initializer.global_config.credentials + ) + return FeatureMonitor( + f"{self.resource_name}/featureMonitors/{feature_monitor_id}", + credentials=credentials, + ) + + def create_feature_monitor( + self, + name: str, + description: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + schedule_config: Optional[str] = None, + feature_selection_configs: Optional[List[Tuple[str, float]]] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = None, + create_request_timeout: Optional[float] = None, + ) -> FeatureMonitor: + """Creates a new feature monitor. + + Args: + name: The name of the feature monitor. + description: Description of the feature monitor. + labels: + The labels with user-defined metadata to organize your FeatureMonitors. 
+ Label keys and values can be no longer than 64 characters + (Unicode codepoints), can only contain lowercase letters, + numeric characters, underscores and dashes. International + characters are allowed. + + See https://goo.gl/xmQnxf for more information on and examples + of labels. No more than 64 user labels can be associated with + one FeatureMonitor (System labels are excluded). System reserved label + keys are prefixed with "aiplatform.googleapis.com/" and are + immutable. + schedule_config: + Configures when data is to be monitored for this + FeatureMonitor. At the end of the scheduled time, + the stats and drift are generated for the selected features. + Example format: "TZ=America/New_York 0 9 * * *" (monitors + daily at 9 AM EST). + feature_selection_configs: + List of tuples of feature id and drift threshold. Required: a + ValueError is raised when this is None. A threshold that is + None or 0 falls back to the default of 0.3. + project: + Project to create feature monitor in. If unset, the project set in + aiplatform.init will be used. + location: + Location to create feature monitor in. If not set, location set in + aiplatform.init will be used. + credentials: + Custom credentials to use to create this feature monitor. Overrides + credentials set in aiplatform.init. + request_metadata: + Strings which should be sent along with the request as metadata. + create_request_timeout: + The timeout for the create request in seconds. + + Returns: + FeatureMonitor - the FeatureMonitor resource object.
+ """ + + gapic_feature_monitor = gca_feature_monitor.FeatureMonitor() + + if description: + gapic_feature_monitor.description = description + + if labels: + utils.validate_labels(labels) + gapic_feature_monitor.labels = labels + + if request_metadata is None: + request_metadata = () + + if schedule_config: + gapic_feature_monitor.schedule_config = gca_feature_monitor.ScheduleConfig( + cron=schedule_config + ) + + if feature_selection_configs is None: + raise ValueError( + "Please specify feature_configs: features to be monitored and" + " their thresholds." + ) + + if feature_selection_configs is not None: + gapic_feature_monitor.feature_selection_config.feature_configs = [ + gca_feature_monitor.FeatureSelectionConfig.FeatureConfig( + feature_id=feature_id, + drift_threshold=threshold if threshold else 0.3, + ) + for feature_id, threshold in feature_selection_configs + ] + + api_client = self.__class__._instantiate_client( + location=location, credentials=credentials + ) + + create_feature_monitor_lro = api_client.select_version( + "v1beta1" + ).create_feature_monitor( + parent=self.resource_name, + feature_monitor=gapic_feature_monitor, + feature_monitor_id=name, + metadata=request_metadata, + timeout=create_request_timeout, + ) + + _LOGGER.log_create_with_lro(FeatureMonitor, create_feature_monitor_lro) + + created_feature_monitor = create_feature_monitor_lro.result() + + _LOGGER.log_create_complete( + FeatureMonitor, created_feature_monitor, "feature_monitor" + ) + + feature_monitor_obj = FeatureMonitor( + name=created_feature_monitor.name, + project=project, + location=location, + credentials=credentials, + ) + + return feature_monitor_obj + + def list_feature_monitors( + self, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[FeatureMonitor]: + """Lists features monitors under this feature group. + + Args: + project: + Project to list feature monitors in. 
If unset, the project set in + aiplatform.init will be used. + location: + Location to list feature monitors in. If not set, location set in + aiplatform.init will be used. + credentials: + Custom credentials to use to list feature monitors. Overrides + credentials set in aiplatform.init. + + Returns: + List of feature monitors under this feature group. + """ + + return FeatureMonitor.list( + parent=self.resource_name, + project=project, + location=location, + credentials=credentials, + ) + + @property + def source(self) -> FeatureGroupBigQuerySource: + return FeatureGroupBigQuerySource( + uri=self._gca_resource.big_query.big_query_source.input_uri, + entity_id_columns=self._gca_resource.big_query.entity_id_columns, + ) diff --git a/agentplatform/resources/preview/feature_store/feature_monitor.py b/agentplatform/resources/preview/feature_store/feature_monitor.py new file mode 100644 index 0000000000..f0172d22f1 --- /dev/null +++ b/agentplatform/resources/preview/feature_store/feature_monitor.py @@ -0,0 +1,335 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + import re + from typing import List, Dict, Optional, Tuple, Sequence + from google.auth import credentials as auth_credentials + from google.cloud.aiplatform import base, initializer + from google.cloud.aiplatform import utils + from google.cloud.aiplatform.compat.types import ( + feature_monitor_v1beta1 as gca_feature_monitor, + feature_monitor_job_v1beta1 as gca_feature_monitor_job, + ) + + _LOGGER = base.Logger(__name__) + + + class FeatureMonitor(base.VertexAiResourceNounWithFutureManager): + """Class for managing Feature Monitor resources.""" + + client_class = utils.FeatureRegistryClientV1Beta1WithOverride + + _resource_noun = "feature_monitors" + _getter_method = "get_feature_monitor" + _list_method = "list_feature_monitors" + _delete_method = "delete_feature_monitor" + _parse_resource_name_method = "parse_feature_monitor_path" + _format_resource_name_method = "feature_monitor_path" + _gca_resource: gca_feature_monitor.FeatureMonitor + + def __init__( + self, + name: str, + feature_group_id: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves an existing managed feature monitor. + + Args: + name: + The resource name + (`projects/.../locations/.../featureGroups/.../featureMonitors/...`) or + ID. + feature_group_id: + The feature group ID. Must be passed in if name is an ID and not + a resource path, and must be omitted if name is a resource path. + project: + Project to retrieve feature monitor from. If not set, the project set in + aiplatform.init will be used. + location: + Location to retrieve feature monitor from. If not set, the location set + in aiplatform.init will be used. + credentials: + Custom credentials to use to retrieve this feature monitor. Overrides + credentials set in aiplatform.init.
+ """ + + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=name, + ) + + if re.fullmatch( + r"projects/.+/locations/.+/featureGroups/.+/featureMonitors/.+", + name, + ): + if feature_group_id: + raise ValueError( + f"Since feature monitor '{name}' is provided as a path, feature_group_id should not be specified." + ) + feature_monitor = name + else: + from .feature_group import FeatureGroup + + # Construct the feature path using feature group ID if only the + # feature group ID is provided. + if not feature_group_id: + raise ValueError( + f"Since feature monitor '{name}' is not provided as a path, please specify feature_group_id." + ) + + feature_group_path = utils.full_resource_name( + resource_name=feature_group_id, + resource_noun=FeatureGroup._resource_noun, + parse_resource_name_method=FeatureGroup._parse_resource_name, + format_resource_name_method=FeatureGroup._format_resource_name, + ) + + feature_monitor = f"{feature_group_path}/featureMonitors/{name}" + + self._gca_resource = self._get_gca_resource(resource_name=feature_monitor) + + @property + def description(self) -> str: + """The description of the feature monitor.""" + return self._gca_resource.description + + @property + def schedule_config(self) -> str: + """The schedule config of the feature monitor.""" + return self._gca_resource.schedule_config.cron + + @property + def feature_selection_configs(self) -> List[Tuple[str, float]]: + """The feature and it's drift threshold configs of the feature monitor.""" + configs: List[Tuple[str, float]] = [] + for ( + feature_config + ) in self._gca_resource.feature_selection_config.feature_configs: + configs.append( + ( + feature_config.feature_id, + ( + feature_config.drift_threshold + if feature_config.drift_threshold + else 0.3 + ), + ) + ) + return configs + + class FeatureMonitorJob(base.VertexAiResourceNounWithFutureManager): + """Class for managing Feature Monitor Job resources.""" + + client_class = 
utils.FeatureRegistryClientV1Beta1WithOverride + + _resource_noun = "featureMonitorJobs" + _getter_method = "get_feature_monitor_job" + _list_method = "list_feature_monitor_jobs" + _delete_method = "delete_feature_monitor_job" + _parse_resource_name_method = "parse_feature_monitor_job_path" + _format_resource_name_method = "feature_monitor_job_path" + _gca_resource: gca_feature_monitor_job.FeatureMonitorJob + + def __init__( + self, + name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves an existing managed feature monitor job. + + Args: + name: The resource name + (`projects/.../locations/.../featureGroups/.../featureMonitors/.../featureMonitorJobs/...`) + project: Project to retrieve the feature monitor job from. If + unset, the project set in aiplatform.init will be used. + location: Location to retrieve the feature monitor job from. If + not set, location set in aiplatform.init will be used. + credentials: Custom credentials to use to retrieve this feature + monitor job. Overrides credentials set in aiplatform.init. + """ + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=name, + ) + + if not re.fullmatch( + r"projects/.+/locations/.+/featureGroups/.+/featureMonitors/.+/featureMonitorJobs/.+", + name, + ): + raise ValueError( + "name need to specify the fully qualified" + + " feature monitor job resource path." 
+ ) + + self._gca_resource = self._get_gca_resource(resource_name=name) + + @property + def description(self) -> str: + """The description of the feature monitor job.""" + return self._gca_resource.description + + @property + def feature_stats_and_anomalies( + self, + ) -> List[gca_feature_monitor.FeatureStatsAndAnomaly]: + """The feature stats and anomalies of the feature monitor job, or an empty list when it has no job summary.""" + if self._gca_resource.job_summary: + return self._gca_resource.job_summary.feature_stats_and_anomalies + return [] + + def create_feature_monitor_job( + self, + description: Optional[str] = None, + labels: Optional[Dict[str, str]] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = None, + create_request_timeout: Optional[float] = None, + ) -> FeatureMonitorJob: + """Creates a new feature monitor job. + + Args: + description: Description of the feature monitor job. + labels: + The labels with user-defined metadata to organize your + FeatureMonitorJobs. + Label keys and values can be no longer than 64 characters + (Unicode codepoints), can only contain lowercase letters, + numeric characters, underscores and dashes. International + characters are allowed. + + See https://goo.gl/xmQnxf for more information on and examples + of labels. No more than 64 user labels can be associated with + one FeatureMonitorJob (System labels are excluded). System reserved label + keys are prefixed with "aiplatform.googleapis.com/" and are + immutable. + project: + Project to create feature monitor job in. If unset, the project set in + aiplatform.init will be used. + location: + Location to create feature monitor job in. If not set, location set in + aiplatform.init will be used. + credentials: + Custom credentials to use to create this feature monitor job. Overrides + credentials set in aiplatform.init. + request_metadata: + Strings which should be sent along with the request as metadata.
+ create_request_timeout: + The timeout for the create request in seconds. + + Returns: + FeatureMonitorJob - the FeatureMonitorJob resource object. + """ + + gapic_feature_monitor_job = gca_feature_monitor_job.FeatureMonitorJob() + + if description: + gapic_feature_monitor_job.description = description + + if labels: + utils.validate_labels(labels) + gapic_feature_monitor_job.labels = labels + + if request_metadata is None: + request_metadata = () + + api_client = self.__class__._instantiate_client( + location=location, credentials=credentials + ) + + created_feature_monitor_job = api_client.select_version( + "v1beta1" + ).create_feature_monitor_job( + parent=self.resource_name, + feature_monitor_job=gapic_feature_monitor_job, + metadata=request_metadata, + timeout=create_request_timeout, + ) + + feature_monitor_job_obj = self.FeatureMonitorJob( + name=created_feature_monitor_job.name, + project=project, + location=location, + credentials=credentials, + ) + + return feature_monitor_job_obj + + def get_feature_monitor_job( + self, + feature_monitor_job_id: str, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> FeatureMonitorJob: + """Retrieves an existing feature monitor job. + + Args: + feature_monitor_job_id: The ID of the feature monitor job. + credentials: + Custom credentials to use to retrieve the feature monitor job under this + feature monitor. The order of which credentials are used is as + follows - (1) this parameter (2) credentials passed to FeatureMonitor + constructor (3) credentials set in aiplatform.init. + + Returns: + FeatureMonitorJob - the Feature Monitor Job resource object under this + feature monitor.
+ """ + credentials = ( + credentials or self.credentials or initializer.global_config.credentials + ) + return FeatureMonitor.FeatureMonitorJob( + f"{self.resource_name}/featureMonitorJobs/{feature_monitor_job_id}", + credentials=credentials, + ) + + def list_feature_monitor_jobs( + self, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> List[FeatureMonitorJob]: + """Lists features monitor jobs under this feature monitor. + + Args: + project: + Project to list feature monitors in. If unset, the project set in + aiplatform.init will be used. + location: + Location to list feature monitors in. If not set, location set in + aiplatform.init will be used. + credentials: + Custom credentials to use to list feature monitors. Overrides + credentials set in aiplatform.init. + + Returns: + List of feature monitor jobs under this feature monitor. + """ + + return FeatureMonitor.FeatureMonitorJob.list( + parent=self.resource_name, + project=project, + location=location, + credentials=credentials, + ) diff --git a/agentplatform/resources/preview/feature_store/feature_online_store.py b/agentplatform/resources/preview/feature_store/feature_online_store.py new file mode 100644 index 0000000000..c631494cb2 --- /dev/null +++ b/agentplatform/resources/preview/feature_store/feature_online_store.py @@ -0,0 +1,647 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import enum +from typing import ( + Dict, + List, + Optional, + Sequence, + Tuple, + Union, +) + +from google.auth import credentials as auth_credentials +from google.cloud.aiplatform import ( + base, + initializer, + utils, +) +from google.cloud.aiplatform.compat.types import ( + feature_online_store as gca_feature_online_store, + service_networking as gca_service_networking, + feature_view as gca_feature_view, +) +from agentplatform.resources.preview.feature_store.feature_view import ( + FeatureView, +) +from agentplatform.resources.preview.feature_store.utils import ( + IndexConfig, + FeatureViewBigQuerySource, + FeatureViewVertexRagSource, + FeatureViewRegistrySource, +) + + +_LOGGER = base.Logger(__name__) + + +@enum.unique +class FeatureOnlineStoreType(enum.Enum): + UNKNOWN = 0 + BIGTABLE = 1 + OPTIMIZED = 2 + + +class FeatureOnlineStore(base.VertexAiResourceNounWithFutureManager): + """Class for managing Feature Online Store resources.""" + + client_class = utils.FeatureOnlineStoreAdminClientWithOverride + + _resource_noun = "feature_online_stores" + _getter_method = "get_feature_online_store" + _list_method = "list_feature_online_stores" + _delete_method = "delete_feature_online_store" + _parse_resource_name_method = "parse_feature_online_store_path" + _format_resource_name_method = "feature_online_store_path" + _gca_resource: gca_feature_online_store.FeatureOnlineStore + + def __init__( + self, + name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves an existing managed feature online store. + + Args: + name: + The resource name + (`projects/.../locations/.../featureOnlineStores/...`) or ID. + project: + Project to retrieve feature online store from. If unset, the + project set in aiplatform.init will be used. + location: + Location to retrieve feature online store from. If not set, + location set in aiplatform.init will be used. 
+ credentials: + Custom credentials to use to retrieve this feature online store. + Overrides credentials set in aiplatform.init. + """ + + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=name, + ) + self._gca_resource = self._get_gca_resource(resource_name=name) + + @classmethod + @base.optional_sync() + def create_bigtable_store( + cls, + name: str, + min_node_count: Optional[int] = 1, + max_node_count: Optional[int] = 1, + cpu_utilization_target: Optional[int] = 50, + labels: Optional[Dict[str, str]] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = None, + create_request_timeout: Optional[float] = None, + sync: bool = True, + ) -> "FeatureOnlineStore": + """Creates a Bigtable online store. + + Example Usage: + + my_fos = agentplatform.preview.FeatureOnlineStore.create_bigtable_store('my_fos') + + Args: + name: The name of the feature online store. + min_node_count: + The minimum number of Bigtable nodes to scale down to. Must be + greater than or equal to 1. + max_node_count: + The maximum number of Bigtable nodes to scale up to. Must + satisfy min_node_count <= max_node_count <= (10 * + min_node_count). + cpu_utilization_target: + A percentage of the cluster's CPU capacity. Can be from 10% to + 80%. When a cluster's CPU utilization exceeds the target that + you have set, Bigtable immediately adds nodes to the cluster. + When CPU utilization is substantially lower than the target, + Bigtable removes nodes. If not set will default to 50%. + labels: + The labels with user-defined metadata to organize your feature + online store. Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only contain lowercase + letters, numeric characters, underscores and dashes. + International characters are allowed. 
See https://goo.gl/xmQnxf + for more information on and examples of labels. No more than 64 + user labels can be associated with one feature online store + (System labels are excluded)." System reserved label keys are + prefixed with "aiplatform.googleapis.com/" and are immutable. + project: + Project to create feature online store in. If unset, the project + set in aiplatform.init will be used. + location: + Location to create feature online store in. If not set, location + set in aiplatform.init will be used. + credentials: + Custom credentials to use to create this feature online store. + Overrides credentials set in aiplatform.init. + request_metadata: + Strings which should be sent along with the request as metadata. + create_request_timeout: + The timeout for the create request in seconds. + sync: + Whether to execute this creation synchronously. If False, this + method will be executed in concurrent Future and any downstream + object will be immediately returned and synced when the Future + has completed. + + Returns: + FeatureOnlineStore - the FeatureOnlineStore resource object. 
+        """
+
+        if min_node_count < 1:
+            raise ValueError("min_node_count must be greater than or equal to 1")
+
+        if max_node_count < min_node_count:
+            raise ValueError(
+                "max_node_count must be greater than or equal to min_node_count"
+            )
+        elif 10 * min_node_count < max_node_count:
+            raise ValueError(
+                "max_node_count must be less than or equal to 10 * min_node_count"
+            )
+
+        if cpu_utilization_target < 10 or cpu_utilization_target > 80:
+            raise ValueError("cpu_utilization_target must be between 10 and 80")
+
+        gapic_feature_online_store = gca_feature_online_store.FeatureOnlineStore(
+            bigtable=gca_feature_online_store.FeatureOnlineStore.Bigtable(
+                auto_scaling=gca_feature_online_store.FeatureOnlineStore.Bigtable.AutoScaling(
+                    min_node_count=min_node_count,
+                    max_node_count=max_node_count,
+                    cpu_utilization_target=cpu_utilization_target,
+                ),
+            ),
+        )
+
+        if labels:
+            utils.validate_labels(labels)
+            gapic_feature_online_store.labels = labels
+
+        if request_metadata is None:
+            request_metadata = ()
+
+        api_client = cls._instantiate_client(location=location, credentials=credentials)
+
+        create_online_store_lro = api_client.create_feature_online_store(
+            parent=initializer.global_config.common_location_path(
+                project=project, location=location
+            ),
+            feature_online_store=gapic_feature_online_store,
+            feature_online_store_id=name,
+            metadata=request_metadata,
+            timeout=create_request_timeout,
+        )
+
+        _LOGGER.log_create_with_lro(cls, create_online_store_lro)
+
+        created_online_store = create_online_store_lro.result()
+
+        _LOGGER.log_create_complete(cls, created_online_store, "feature_online_store")
+
+        online_store_obj = cls(
+            name=created_online_store.name,
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
+
+        return online_store_obj
+
+    @classmethod
+    @base.optional_sync()
+    def create_optimized_store(
+        cls,
+        name: str,
+        enable_private_service_connect: bool = False,
+        project_allowlist: Optional[Sequence[str]] = None,
+        labels: Optional[Dict[str, str]] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+        request_metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        create_request_timeout: Optional[float] = None,
+        sync: bool = True,
+    ) -> "FeatureOnlineStore":
+        """Creates an Optimized online store.
+
+        Example Usage:
+
+        ```
+        # Create optimized store with public endpoint.
+        my_fos = agentplatform.preview.FeatureOnlineStore.create_optimized_store(
+            'my_fos'
+        )
+        ```
+
+        ```
+        # Create optimized online store with private service connect.
+        my_fos = agentplatform.preview.FeatureOnlineStore.create_optimized_store(
+            'my_fos',
+            enable_private_service_connect=True,
+            project_allowlist=['my-project'],
+        )
+        ```
+
+        Args:
+            name: The name of the feature online store.
+            enable_private_service_connect:
+                Optional. If true, expose the optimized online store
+                via private service connect. Otherwise the optimized online
+                store will be accessible through public endpoint.
+            project_allowlist:
+                A list of Projects from which the forwarding
+                rule will target the service attachment. Only needed when
+                `enable_private_service_connect` is set to true.
+            labels:
+                The labels with user-defined metadata to organize your feature
+                online store. Label keys and values can be no longer than 64
+                characters (Unicode codepoints), can only contain lowercase
+                letters, numeric characters, underscores and dashes.
+                International characters are allowed. See https://goo.gl/xmQnxf
+                for more information on and examples of labels. No more than 64
+                user labels can be associated with one feature online store
+                (System labels are excluded). System reserved label keys are
+                prefixed with "aiplatform.googleapis.com/" and are immutable.
+            project:
+                Project to create feature online store in. If unset, the project
+                set in aiplatform.init will be used.
+            location:
+                Location to create feature online store in. If not set, location
+                set in aiplatform.init will be used.
+            credentials:
+                Custom credentials to use to create this feature online store.
+                Overrides credentials set in aiplatform.init.
+            request_metadata:
+                Strings which should be sent along with the request as metadata.
+            create_request_timeout:
+                The timeout for the create request in seconds.
+            sync:
+                Whether to execute this creation synchronously. If False, this
+                method will be executed in concurrent Future and any downstream
+                object will be immediately returned and synced when the Future
+                has completed.
+
+        Returns:
+            FeatureOnlineStore - the FeatureOnlineStore resource object.
+
+        Raises:
+            ValueError: If `enable_private_service_connect` is true but
+                `project_allowlist` is empty.
+        """
+        if enable_private_service_connect:
+            if not project_allowlist:
+                raise ValueError(
+                    "`project_allowlist` cannot be empty when `enable_private_service_connect` is set to true."
+                )
+
+            dedicated_serving_endpoint = gca_feature_online_store.FeatureOnlineStore.DedicatedServingEndpoint(
+                private_service_connect_config=gca_service_networking.PrivateServiceConnectConfig(
+                    enable_private_service_connect=True,
+                    project_allowlist=project_allowlist,
+                ),
+            )
+        else:
+            dedicated_serving_endpoint = (
+                gca_feature_online_store.FeatureOnlineStore.DedicatedServingEndpoint()
+            )
+
+        gapic_feature_online_store = gca_feature_online_store.FeatureOnlineStore(
+            optimized=gca_feature_online_store.FeatureOnlineStore.Optimized(),
+            dedicated_serving_endpoint=dedicated_serving_endpoint,
+        )
+
+        if labels:
+            utils.validate_labels(labels)
+            gapic_feature_online_store.labels = labels
+
+        if request_metadata is None:
+            request_metadata = ()
+
+        api_client = cls._instantiate_client(location=location, credentials=credentials)
+
+        create_online_store_lro = api_client.create_feature_online_store(
+            parent=initializer.global_config.common_location_path(
+                project=project, location=location
+            ),
+            feature_online_store=gapic_feature_online_store,
+            feature_online_store_id=name,
+            metadata=request_metadata,
+            timeout=create_request_timeout,
+        )
+
+        _LOGGER.log_create_with_lro(cls, create_online_store_lro)
+
+        created_online_store = create_online_store_lro.result()
+
+        _LOGGER.log_create_complete(cls, created_online_store, "feature_online_store")
+
+        online_store_obj = cls(
+            name=created_online_store.name,
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
+
+        return online_store_obj
+
+    @base.optional_sync()
+    def delete(self, force: bool = False, sync: bool = True) -> None:
+        """Deletes this online store.
+
+        WARNING: This deletion is permanent.
+
+        Args:
+            force:
+                If set to True, all feature views under this online store will
+                be deleted prior to online store deletion. Otherwise, deletion
+                will only succeed if the online store has no FeatureViews.
+            sync:
+                Whether to execute this deletion synchronously. If False, this
+                method will be executed in concurrent Future and any downstream
+                object will be immediately returned and synced when the Future
+                has completed.
+        """
+
+        lro = getattr(self.api_client, self._delete_method)(
+            name=self.resource_name,
+            force=force,
+        )
+        _LOGGER.log_delete_with_lro(self, lro)
+        lro.result()
+        _LOGGER.log_delete_complete(self)
+
+    @property
+    def feature_online_store_type(self) -> FeatureOnlineStoreType:
+        """Returns the storage type (Bigtable or Optimized) of this online store.
+
+        Raises:
+            ValueError: If the resource matches neither supported type.
+        """
+        if self._gca_resource.bigtable:
+            return FeatureOnlineStoreType.BIGTABLE
+        # Optimized is an empty proto, so self._gca_resource.optimized is always false.
+        # NOTE(review): hasattr() presumably succeeds for any declared proto
+        # field, which would make the `else` branch unreachable - confirm.
+        elif hasattr(self.gca_resource, "optimized"):
+            return FeatureOnlineStoreType.OPTIMIZED
+        else:
+            raise ValueError(
+                f"Online store does not have type or is unsupported by SDK: {self._gca_resource}."
+            )
+
+    @property
+    def labels(self) -> Dict[str, str]:
+        """Returns the user-defined labels stored on this online store."""
+        return self._gca_resource.labels
+
+    @base.optional_sync()
+    def create_feature_view(
+        self,
+        name: str,
+        source: Union[
+            FeatureViewBigQuerySource,
+            FeatureViewVertexRagSource,
+            FeatureViewRegistrySource,
+        ],
+        labels: Optional[Dict[str, str]] = None,
+        sync_config: Optional[str] = None,
+        index_config: Optional[IndexConfig] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+        request_metadata: Optional[Sequence[Tuple[str, str]]] = None,
+        create_request_timeout: Optional[float] = None,
+        sync: bool = True,
+    ) -> FeatureView:
+        """Creates a FeatureView under this online store.
+
+        The view can be backed by a BigQuery source, a Vertex RAG source or a
+        Feature Registry source.
+
+        Example Usage:
+        ```
+        existing_fos = FeatureOnlineStore('my_fos')
+        new_fv = existing_fos.create_feature_view(
+                'my_fv',
+                FeatureViewBigQuerySource(
+                    uri='bq://my-proj/dataset/table',
+                    entity_id_columns=['entity_id'],
+                )
+        )
+        # Example for how to create an embedding FeatureView.
+        embedding_fv = existing_fos.create_feature_view(
+                'my_embedding_fv',
+                FeatureViewBigQuerySource(
+                    uri='bq://my-proj/dataset/table',
+                    entity_id_columns=['entity_id'],
+                ),
+                index_config=IndexConfig(
+                    embedding_column="embedding",
+                    filter_column=["currency_code", "gender"],
+                    crowding_column="crowding",
+                    dimensions=1536,
+                    distance_measure_type=DistanceMeasureType.SQUARED_L2_DISTANCE,
+                    algorithm_config=TreeAhConfig(),
+                )
+            )
+        ```
+        Args:
+            name: The name of the feature view.
+            source:
+                The source to load data from when a feature view sync runs.
+                Currently supports a BigQuery source, Vertex RAG source, Registry source.
+            labels:
+                The labels with user-defined metadata to organize your
+                FeatureViews.
+
+                Label keys and values can be no longer than 64 characters
+                (Unicode codepoints), can only contain lowercase letters,
+                numeric characters, underscores and dashes. International
+                characters are allowed.
+
+                See https://goo.gl/xmQnxf for more information on and examples
+                of labels. No more than 64 user labels can be associated with
+                one FeatureOnlineStore (System labels are excluded). System
+                reserved label keys are prefixed with
+                "aiplatform.googleapis.com/" and are immutable.
+            sync_config:
+                Configures when data is to be synced/updated for this
+                FeatureView. At the end of the sync the latest feature values
+                for each entity ID of this FeatureView are made ready for online
+                serving. Example format: "TZ=America/New_York 0 9 * * *" (sync
+                daily at 9 AM EST).
+            index_config:
+                Configuration for index preparation for vector search. It
+                contains the required configurations to create an index from
+                source data, so that approximate nearest neighbor (a.k.a ANN)
+                algorithms search can be performed during online serving.
+            project:
+                Project to create feature view in. If unset, the project set in
+                aiplatform.init will be used.
+            location:
+                Location to create feature view in. If not set, location set in
+                aiplatform.init will be used.
+            credentials:
+                Custom credentials to use to create this feature view.
+                Overrides credentials set in aiplatform.init.
+            request_metadata:
+                Strings which should be sent along with the request as metadata.
+            create_request_timeout:
+                The timeout for the create request in seconds.
+            sync:
+                Whether to execute this creation synchronously. If False, this
+                method will be executed in concurrent Future and any downstream
+                object will be immediately returned and synced when the Future
+                has completed.
+
+        Returns:
+            FeatureView - the FeatureView resource object.
+
+        Raises:
+            ValueError: If `source` is missing, is of an unsupported type, or
+                lacks a required field; or if a registry feature is not
+                written as `<feature_group_id>.<feature_id>`.
+        """
+        if not source:
+            raise ValueError("Please specify a valid source.")
+
+        big_query_source = None
+        vertex_rag_source = None
+        feature_registry_source = None
+
+        if isinstance(source, FeatureViewBigQuerySource):
+            if not source.uri:
+                raise ValueError("Please specify URI in BigQuery source.")
+
+            if not source.entity_id_columns:
+                raise ValueError("Please specify entity ID columns in BigQuery source.")
+
+            big_query_source = gca_feature_view.FeatureView.BigQuerySource(
+                uri=source.uri,
+                entity_id_columns=source.entity_id_columns,
+            )
+        elif isinstance(source, FeatureViewVertexRagSource):
+            if not source.uri:
+                raise ValueError("Please specify URI in Vertex RAG source.")
+
+            vertex_rag_source = gca_feature_view.FeatureView.VertexRagSource(
+                uri=source.uri,
+                rag_corpus_id=source.rag_corpus_id or None,
+            )
+        elif isinstance(source, FeatureViewRegistrySource):
+            if not source.features:
+                raise ValueError(
+                    "Please specify features in Registry Source in format `<feature_group_id>.<feature_id>`."
+                )
+            # Group the feature IDs by their feature group, keeping input order.
+            feature_group_mappings = {}
+            for feature in source.features:
+                # Require exactly one "." between two non-empty IDs. Unpacking
+                # the split directly would surface a bare "not enough values
+                # to unpack" error for malformed input instead of this one.
+                parts = feature.split(".")
+                if len(parts) != 2 or not all(parts):
+                    raise ValueError(
+                        "Please specify features in Registry Source in format `<feature_group_id>.<feature_id>`."
+                    )
+                feature_group_id, feature_id = parts
+                feature_group_mappings.setdefault(feature_group_id, []).append(
+                    feature_id
+                )
+            feature_groups = [
+                gca_feature_view.FeatureView.FeatureRegistrySource.FeatureGroup(
+                    feature_group_id=feature_group_id,
+                    feature_ids=feature_ids,
+                )
+                for feature_group_id, feature_ids in feature_group_mappings.items()
+            ]
+            feature_registry_source = (
+                gca_feature_view.FeatureView.FeatureRegistrySource(
+                    feature_groups=feature_groups,
+                    project_number=source.project_number or None,
+                )
+            )
+        else:
+            raise ValueError(
+                "Only FeatureViewBigQuerySource, FeatureViewVertexRagSource and FeatureViewRegistrySource are supported sources."
+            )
+
+        gapic_feature_view = gca_feature_view.FeatureView(
+            big_query_source=big_query_source,
+            vertex_rag_source=vertex_rag_source,
+            feature_registry_source=feature_registry_source,
+            sync_config=(
+                gca_feature_view.FeatureView.SyncConfig(cron=sync_config)
+                if sync_config
+                else None
+            ),
+        )
+
+        if labels:
+            utils.validate_labels(labels)
+            gapic_feature_view.labels = labels
+
+        if request_metadata is None:
+            request_metadata = ()
+
+        if index_config:
+            gapic_feature_view.index_config = gca_feature_view.FeatureView.IndexConfig(
+                index_config.as_dict()
+            )
+
+        api_client = self.__class__._instantiate_client(
+            location=location, credentials=credentials
+        )
+
+        create_feature_view_lro = api_client.create_feature_view(
+            parent=self.resource_name,
+            feature_view=gapic_feature_view,
+            feature_view_id=name,
+            metadata=request_metadata,
+            timeout=create_request_timeout,
+        )
+
+        _LOGGER.log_create_with_lro(FeatureView, create_feature_view_lro)
+
+        created_feature_view = create_feature_view_lro.result()
+
+        _LOGGER.log_create_complete(FeatureView, created_feature_view, "feature_view")
+
+        feature_view_obj = FeatureView(
+            name=created_feature_view.name,
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
+
+        return feature_view_obj
+
+    def list_feature_views(
+        self,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List[FeatureView]:
+        """Lists feature views under this feature online store.
+
+        Args:
+            project:
+                Project to list feature views in. If unset, the project set in
+                aiplatform.init will be used.
+            location:
+                Location to list feature views in. If not set, location set in
+                aiplatform.init will be used.
+            credentials:
+                Custom credentials to use to list feature views. Overrides
+                credentials set in aiplatform.init.
+
+        Returns:
+            List of feature views under this feature online store.
+        """
+
+        return FeatureView.list(
+            feature_online_store_id=self.name,
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
diff --git a/agentplatform/resources/preview/feature_store/feature_view.py b/agentplatform/resources/preview/feature_store/feature_view.py
new file mode 100644
index 0000000000..84ba9a65c7
--- /dev/null
+++ b/agentplatform/resources/preview/feature_store/feature_view.py
@@ -0,0 +1,537 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import re
+from typing import List, Dict, Optional
+from google.cloud.aiplatform import initializer
+from google.auth import credentials as auth_credentials
+from google.cloud.aiplatform import base
+from google.cloud.aiplatform import utils
+from google.cloud.aiplatform.compat.types import (
+    feature_view_sync as gca_feature_view_sync,
+    feature_view as gca_feature_view,
+    feature_online_store_service as fos_service,
+)
+import agentplatform.resources.preview.feature_store.utils as fs_utils
+
+_LOGGER = base.Logger(__name__)
+
+
+class FeatureView(base.VertexAiResourceNounWithFutureManager):
+    """Class for managing Feature View resources."""
+
+    client_class = utils.FeatureOnlineStoreAdminClientWithOverride
+
+    _resource_noun = "featureViews"
+    _getter_method = "get_feature_view"
+    _list_method = "list_feature_views"
+    _delete_method = "delete_feature_view"
+    _parse_resource_name_method = "parse_feature_view_path"
+    _format_resource_name_method = "feature_view_path"
+    _gca_resource: gca_feature_view.FeatureView
+    # Default serving client; set lazily by `_get_online_store_client`.
+    _online_store_client: utils.FeatureOnlineStoreClientWithOverride
+
+    # Serving clients built for explicit connection options, keyed by those
+    # options; created lazily on first use.
+    _online_store_clients_with_connection_options: Dict[
+        fs_utils.ConnectionOptions, utils.FeatureOnlineStoreClientWithOverride
+    ] = None
+
+    def __init__(
+        self,
+        name: str,
+        feature_online_store_id: Optional[str] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ):
+        """Retrieves an existing managed feature view.
+
+        Args:
+            name:
+                The resource name
+                (`projects/.../locations/.../featureOnlineStores/.../featureViews/...`)
+                or ID.
+            feature_online_store_id:
+                The feature online store ID. Must be passed in if name is an ID
+                and not a resource path.
+            project:
+                Project to retrieve the feature view from. If unset, the project
+                set in aiplatform.init will be used.
+            location:
+                Location to retrieve the feature view from. If not set, location
+                set in aiplatform.init will be used.
+            credentials:
+                Custom credentials to use to retrieve this feature view.
+                Overrides credentials set in aiplatform.init.
+
+        Raises:
+            ValueError: If `name` is an ID and `feature_online_store_id` is
+                not provided.
+        """
+
+        super().__init__(
+            project=project,
+            location=location,
+            credentials=credentials,
+            resource_name=name,
+        )
+
+        if re.fullmatch(
+            r"projects/.+/locations/.+/featureOnlineStores/.+/featureViews/.+",
+            name,
+        ):
+            feature_view = name
+        else:
+            from .feature_online_store import FeatureOnlineStore
+
+            # Construct the feature view path using feature online store ID if
+            # only the feature view ID is provided.
+            if not feature_online_store_id:
+                raise ValueError(
+                    "Since feature view is not provided as a path, please specify"
+                    + " feature_online_store_id."
+                )
+
+            feature_online_store_path = utils.full_resource_name(
+                resource_name=feature_online_store_id,
+                resource_noun=FeatureOnlineStore._resource_noun,
+                parse_resource_name_method=FeatureOnlineStore._parse_resource_name,
+                format_resource_name_method=FeatureOnlineStore._format_resource_name,
+            )
+
+            feature_view = f"{feature_online_store_path}/featureViews/{name}"
+
+        self._gca_resource = self._get_gca_resource(resource_name=feature_view)
+
+    def _get_online_store_client(
+        self, connection_options: Optional[fs_utils.ConnectionOptions] = None
+    ) -> utils.FeatureOnlineStoreClientWithOverride:
+        """Return the online store client.
+
+        Also sets the `_online_store_client` attr if not set yet. Note that if
+        `connection_options` is passed in, the `_online_store_client` attr will
+        not be set - only the client will be returned. If the same
+        `connection_options` is passed in, this code will return the same
+        (cached) client as previously built.
+
+        Raises:
+            ValueError: If the transport in `connection_options` is not
+                supported, or if the online store uses private service
+                connect and no `connection_options` were given.
+            PublicEndpointNotFoundError: If the optimized online store has no
+                public endpoint yet.
+        """
+        # Explicit connection options are handled first so that a previously
+        # cached default client never shadows them, and so that no online
+        # store lookup is made when it is not needed.
+        if connection_options:
+            # Check if we have a previously client created for these
+            # connection_options.
+            if self._online_store_clients_with_connection_options is None:
+                self._online_store_clients_with_connection_options = {}
+            if connection_options in self._online_store_clients_with_connection_options:
+                return self._online_store_clients_with_connection_options[
+                    connection_options
+                ]
+            host = connection_options.host
+
+            if isinstance(
+                connection_options.transport,
+                fs_utils.ConnectionOptions.InsecureGrpcChannel,
+            ):
+                import grpc
+                from google.cloud.aiplatform_v1.services import (
+                    feature_online_store_service as feature_online_store_service_v1,
+                )
+                from google.cloud.aiplatform_v1beta1.services import (
+                    feature_online_store_service as feature_online_store_service_v1beta1,
+                )
+
+                gapic_client_class = (
+                    utils.FeatureOnlineStoreClientWithOverride.get_gapic_client_class()
+                )
+                gapic_client_class_to_transport_class = {
+                    feature_online_store_service_v1.client.FeatureOnlineStoreServiceClient: (
+                        feature_online_store_service_v1.transports.grpc.FeatureOnlineStoreServiceGrpcTransport
+                    ),
+                    feature_online_store_service_v1beta1.client.FeatureOnlineStoreServiceClient: (
+                        feature_online_store_service_v1beta1.transports.grpc.FeatureOnlineStoreServiceGrpcTransport
+                    ),
+                }
+                if gapic_client_class not in gapic_client_class_to_transport_class:
+                    raise ValueError(
+                        f"Unexpected gapic class '{gapic_client_class}' used by internal client."
+                    )
+
+                transport_class = gapic_client_class_to_transport_class[
+                    gapic_client_class
+                ]
+
+                client = gapic_client_class(
+                    transport=transport_class(
+                        channel=grpc.insecure_channel(host + ":10002")
+                    ),
+                )
+
+                self._online_store_clients_with_connection_options[
+                    connection_options
+                ] = client
+                return client
+            else:
+                raise ValueError(
+                    f"Unsupported connection transport type, got transport: {connection_options.transport}"
+                )
+
+        if getattr(self, "_online_store_client", None):
+            return self._online_store_client
+
+        fos_name = fs_utils.get_feature_online_store_name(self.resource_name)
+        from .feature_online_store import FeatureOnlineStore
+
+        fos = FeatureOnlineStore(name=fos_name)
+
+        if fos._gca_resource.bigtable.auto_scaling:
+            # This is Bigtable online store.
+            _LOGGER.info(f"Connecting to Bigtable online store name {fos_name}")
+            self._online_store_client = initializer.global_config.create_client(
+                client_class=utils.FeatureOnlineStoreClientWithOverride,
+                credentials=self.credentials,
+                location_override=self.location,
+            )
+            return self._online_store_client
+
+        if (
+            fos._gca_resource.dedicated_serving_endpoint.private_service_connect_config.enable_private_service_connect
+        ):
+            raise ValueError(
+                "Use `connection_options` to specify an IP address. Required for optimized online store with private service connect."
+            )
+
+        # From here, optimized serving with public endpoint.
+        if not fos._gca_resource.dedicated_serving_endpoint.public_endpoint_domain_name:
+            raise fs_utils.PublicEndpointNotFoundError(
+                "Public endpoint is not created yet for the optimized online store:"
+                f"{fos_name}. Please run sync and wait for it to complete."
+            )
+
+        _LOGGER.info(
+            f"Public endpoint for the optimized online store {fos_name} is"
+            f" {fos._gca_resource.dedicated_serving_endpoint.public_endpoint_domain_name}"
+        )
+        self._online_store_client = initializer.global_config.create_client(
+            client_class=utils.FeatureOnlineStoreClientWithOverride,
+            credentials=self.credentials,
+            location_override=self.location,
+            prediction_client=True,
+            api_path_override=fos._gca_resource.dedicated_serving_endpoint.public_endpoint_domain_name,
+        )
+        return self._online_store_client
+
+    @classmethod
+    def list(
+        cls,
+        feature_online_store_id: str,
+        filter: Optional[str] = None,
+        project: Optional[str] = None,
+        location: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List["FeatureView"]:
+        """Lists all feature views under feature_online_store_id.
+
+        Example Usage:
+        ```
+        feature_views = agentplatform.preview.FeatureView.list(
+            feature_online_store_id="my_fos",
+            filter="labels.label_key=label_value")
+        ```
+        Args:
+            feature_online_store_id:
+                Parent feature online store ID.
+            filter:
+                Filter to apply on the returned feature views.
+            project:
+                Project to use to get a list of feature views. If unset, the
+                project set in aiplatform.init will be used.
+            location:
+                Location to use to get a list of feature views. If not set,
+                location set in aiplatform.init will be used.
+            credentials:
+                Custom credentials to use to get a list of feature views.
+                Overrides credentials set in aiplatform.init.
+
+        Returns:
+            List[FeatureView] - list of FeatureView resource object.
+        """
+        from .feature_online_store import FeatureOnlineStore
+
+        fos = FeatureOnlineStore(
+            name=feature_online_store_id,
+            project=project,
+            location=location,
+            credentials=credentials,
+        )
+        return cls._list(
+            filter=filter, credentials=credentials, parent=fos.resource_name
+        )
+
+    @base.optional_sync()
+    def delete(self, sync: bool = True) -> None:
+        """Deletes this feature view.
+
+        WARNING: This deletion is permanent.
+
+        Args:
+            sync:
+                Whether to execute this deletion synchronously. If False, this
+                method will be executed in concurrent Future and any downstream
+                object will be immediately returned and synced when the Future
+                has completed.
+        """
+        lro = getattr(self.api_client, self._delete_method)(name=self.resource_name)
+        _LOGGER.log_delete_with_lro(self, lro)
+        lro.result()
+        _LOGGER.log_delete_complete(self)
+
+    def sync(self) -> "FeatureViewSync":
+        """Starts an on-demand Sync for the FeatureView.
+
+        Args: None
+
+        Returns:
+            "FeatureViewSync" - FeatureViewSync instance
+        """
+        sync_method = getattr(self.api_client, self.FeatureViewSync.sync_method())
+
+        sync_request = {
+            "feature_view": self.resource_name,
+        }
+        sync_response = sync_method(request=sync_request)
+
+        return self.FeatureViewSync(name=sync_response.feature_view_sync)
+
+    def get_sync(self, name: str) -> "FeatureViewSync":
+        """Gets the FeatureViewSync resource for the given name.
+
+        Args:
+            name: The resource ID
+
+        Returns:
+            "FeatureViewSync" - FeatureViewSync instance
+        """
+        feature_view_path = self.resource_name
+        feature_view_sync = f"{feature_view_path}/featureViewSyncs/{name}"
+        return self.FeatureViewSync(name=feature_view_sync)
+
+    def list_syncs(
+        self,
+        filter: Optional[str] = None,
+    ) -> List["FeatureViewSync"]:
+        """Lists all feature view syncs under this FeatureView.
+
+        Args:
+            filter: Filter to apply on the returned feature view syncs.
+
+        Returns:
+            List[FeatureViewSync] - list of FeatureViewSync resource object.
+        """
+
+        return self.FeatureViewSync._list(
+            filter=filter, credentials=self.credentials, parent=self.resource_name
+        )
+
+    def read(
+        self,
+        key: List[str],
+        connection_options: Optional[fs_utils.ConnectionOptions] = None,
+        request_timeout: Optional[float] = None,
+    ) -> fs_utils.FeatureViewReadResponse:
+        """Read the feature values from FeatureView.
+
+        Example Usage:
+        Read feature view. Use this for Bigtable online stores and for
+        Optimized online stores that use public endpoint.
+        ```
+        data = agentplatform.preview.FeatureView(
+            name='feature_view_name', feature_online_store_id='fos_name')
+          .read(key=['12345', '6789'])
+          .to_dict()
+        ```
+
+        Read feature view using IP with an insecure gRPC channel. Use this
+        for optimized online stores using private service connect.
+        ```
+        data = agentplatform.preview.FeatureView(
+            name='feature_view_name', feature_online_store_id='fos_name')
+          .read(
+            key=['12345', '6789'],
+            connection_options=fs_utils.ConnectionOptions(
+              host="<ip>",
+              transport=fs_utils.ConnectionOptions.InsecureGrpcChannel()))
+          .to_dict()
+        ```
+        Args:
+            key: The request key to read feature values for.
+            connection_options:
+                If specified, use these options to connect to a host for sending
+                requests instead of the default
+                `<region>-aiplatform.googleapis.com` or the feature online
+                store's public endpoint.
+            request_timeout:
+                Passed as the `timeout` of the fetch request.
+
+        Returns:
+            "FeatureViewReadResponse" - FeatureViewReadResponse object. It is
+            intermediate class that can be further converted by to_dict() or
+            to_proto().
+        """
+        self.wait()
+
+        online_store_client = self._get_online_store_client(
+            connection_options=connection_options
+        )
+
+        response = online_store_client.fetch_feature_values(
+            feature_view=self.resource_name,
+            data_key=fos_service.FeatureViewDataKey(
+                composite_key=fos_service.FeatureViewDataKey.CompositeKey(parts=key)
+            ),
+            timeout=request_timeout,
+        )
+        return fs_utils.FeatureViewReadResponse(response)
+
+    def search(
+        self,
+        entity_id: Optional[str] = None,
+        embedding_value: Optional[List[float]] = None,
+        neighbor_count: Optional[int] = None,
+        string_filters: Optional[
+            List[fos_service.NearestNeighborQuery.StringFilter]
+        ] = None,
+        per_crowding_attribute_neighbor_count: Optional[int] = None,
+        return_full_entity: bool = False,
+        approximate_neighbor_candidates: Optional[int] = None,
+        leaf_nodes_search_fraction: Optional[float] = None,
+        request_timeout: Optional[float] = None,
+    ) -> fs_utils.SearchNearestEntitiesResponse:
+        """Search the nearest entities from FeatureView.
+
+        Example Usage:
+        ```
+        data = agentplatform.preview.FeatureView(
+            name='feature_view_name', feature_online_store_id='fos_name')
+          .search(entity_id='sample_entity')
+          .to_dict()
+        ```
+        Args:
+            entity_id: The entity id whose similar entities should be searched
+                for. Takes precedence over `embedding_value` when both are
+                given.
+            embedding_value: The embedding vector that is used for similarity
+                search.
+            neighbor_count: The number of similar entities to be retrieved
+                from feature view for each query.
+            string_filters: The list of string filters.
+            per_crowding_attribute_neighbor_count: Crowding is a constraint on a
+                neighbor list produced by nearest neighbor search requiring that
+                no more than per_crowding_attribute_neighbor_count of the k
+                neighbors returned have the same value of crowding_attribute.
+                It's used for improving result diversity.
+            return_full_entity: If true, return full entities including the
+                features other than embeddings.
+            approximate_neighbor_candidates: The number of neighbors to find via
+                approximate search before exact reordering is performed; if set,
+                this value must be > neighbor_count.
+            leaf_nodes_search_fraction: The fraction of the number of leaves to
+                search, set at query time allows user to tune search performance.
+                Increasing this value increases both search accuracy and
+                latency. The value should be between 0.0 and 1.0.
+            request_timeout: Passed as the `timeout` of the search request.
+
+        Returns:
+            "SearchNearestEntitiesResponse" - SearchNearestEntitiesResponse
+            object. It is intermediate class that can be further converted by
+            to_dict() or to_proto()
+
+        Raises:
+            ValueError: If neither `entity_id` nor `embedding_value` is given.
+        """
+        self.wait()
+        if entity_id:
+            embedding = None
+        elif embedding_value:
+            embedding = fos_service.NearestNeighborQuery.Embedding(
+                value=embedding_value
+            )
+        else:
+            raise ValueError(
+                "Either entity_id or embedding_value needs to be provided for search."
+            )
+        response = self._get_online_store_client().search_nearest_entities(
+            request=fos_service.SearchNearestEntitiesRequest(
+                feature_view=self.resource_name,
+                query=fos_service.NearestNeighborQuery(
+                    entity_id=entity_id,
+                    embedding=embedding,
+                    neighbor_count=neighbor_count,
+                    string_filters=string_filters,
+                    per_crowding_attribute_neighbor_count=per_crowding_attribute_neighbor_count,  # pylint: disable=line-too-long
+                    parameters=fos_service.NearestNeighborQuery.Parameters(
+                        approximate_neighbor_candidates=approximate_neighbor_candidates,
+                        leaf_nodes_search_fraction=leaf_nodes_search_fraction,
+                    ),
+                ),
+                return_full_entity=return_full_entity,
+            ),
+            timeout=request_timeout,
+        )
+        return fs_utils.SearchNearestEntitiesResponse(response)
+
+    class FeatureViewSync(base.VertexAiResourceNounWithFutureManager):
+        """Class for managing Feature View Sync resources."""
+
+        client_class = utils.FeatureOnlineStoreAdminClientWithOverride
+
+        _resource_noun = "featureViewSyncs"
+        _getter_method = "get_feature_view_sync"
+        _list_method = "list_feature_view_syncs"
+        # NOTE(review): this is the FeatureView delete RPC name, not a
+        # sync-specific one - confirm that is intended.
+        _delete_method = "delete_feature_view"
+        _sync_method = "sync_feature_view"
+        _parse_resource_name_method = "parse_feature_view_sync_path"
+        _format_resource_name_method = "feature_view_sync_path"
+        _gca_resource: gca_feature_view_sync.FeatureViewSync
+
+        def __init__(
+            self,
+            name: str,
+            project: Optional[str] = None,
+            location: Optional[str] = None,
+            credentials: Optional[auth_credentials.Credentials] = None,
+        ):
+            """Retrieves an existing managed feature view sync.
+
+            Args:
+                name: The resource name
+                  (`projects/.../locations/.../featureOnlineStores/.../featureViews/.../featureViewSyncs/...`)
+                project: Project to retrieve the feature view from. If unset, the
+                  project set in aiplatform.init will be used.
+                location: Location to retrieve the feature view from. If not set,
+                  location set in aiplatform.init will be used.
+                credentials: Custom credentials to use to retrieve this feature view.
+                  Overrides credentials set in aiplatform.init.
+
+            Raises:
+                ValueError: If `name` is not a fully qualified feature view
+                  sync resource path.
+            """
+            super().__init__(
+                project=project,
+                location=location,
+                credentials=credentials,
+                resource_name=name,
+            )
+
+            if not re.fullmatch(
+                r"projects/.+/locations/.+/featureOnlineStores/.+/featureViews/.+/featureViewSyncs/.+",
+                name,
+            ):
+                raise ValueError(
+                    "name need to specify the fully qualified"
+                    + " feature_view_sync resource path."
+                )
+
+            self._gca_resource = getattr(self.api_client, self._getter_method)(
+                name=name, retry=base._DEFAULT_RETRY
+            )
+
+        @classmethod
+        def sync_method(cls) -> str:
+            """Returns the sync method."""
+            return cls._sync_method
diff --git a/agentplatform/resources/preview/feature_store/offline_store.py b/agentplatform/resources/preview/feature_store/offline_store.py
new file mode 100644
index 0000000000..4ba369d168
--- /dev/null
+++ b/agentplatform/resources/preview/feature_store/offline_store.py
@@ -0,0 +1,289 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+import re
+
+from typing import Optional, List, Tuple, Union, TYPE_CHECKING
+from google.auth import credentials as auth_credentials
+from agentplatform.resources.preview.feature_store import (
+    FeatureGroup,
+    Feature,
+)
+from google.cloud.aiplatform import initializer, __version__
+
+from . import _offline_store_impl as impl
+
+
+if TYPE_CHECKING:
+    try:
+        import bigframes
+    except ImportError:
+        bigframes = None
+
+    try:
+        import pandas as pd
+    except ImportError:
+        pd = None
+
+
+# Full resource name of a feature; only the `feature_group` group is read.
+_FEATURE_RESOURCE_NAME_RE = re.compile(
+    r"projects/(?P<project>.+)/locations/(?P<location>.+)/featureGroups/(?P<feature_group>.+)/features/.+"
+)
+
+# String form of a feature: 'feature_group.feature' or
+# 'project.feature_group.feature'.
+_FEATURE_STR_REPR_RE = re.compile(
+    r"((?P<project>.*)\.)?(?P<feature_group>.*)\.(?P<feature>.*)"
+)
+
+_BQ_URI_PREFIX = "bq://"
+
+
+def _try_import_bigframes():
+    """Try to import `bigframes` and return it if successful - otherwise raise an import error."""
+    try:
+        import bigframes
+        import bigframes.pandas
+
+        return bigframes
+    except ImportError as exc:
+        raise ImportError(
+            "`bigframes` is not installed but required for this functionality."
+        ) from exc
+
+
+def _get_feature_group_from_feature(
+    feature: Feature, credentials: auth_credentials.Credentials
+):
+    """Given a feature, return the feature group resource.
+
+    The feature group ID is parsed out of `feature.resource_name`; project
+    and location are taken from the feature object itself.
+
+    Raises:
+        ValueError: If the feature's resource name is not a feature path.
+    """
+    result = _FEATURE_RESOURCE_NAME_RE.fullmatch(feature.resource_name)
+
+    if not result:
+        raise ValueError("Couldn't find feature group in feature.")
+
+    project = feature.project
+    location = feature.location
+    feature_group = result.group("feature_group")
+
+    return FeatureGroup(
+        feature_group, project=project, location=location, credentials=credentials
+    )
+
+
+def _extract_feature_from_str_repr(
+    str_feature: str, credentials: auth_credentials.Credentials
+) -> Tuple[FeatureGroup, Feature]:
+    """Given a feature in string representation, return the feature and feature group.
+
+    Raises:
+        ValueError: If `str_feature` is neither 'feature_group.feature' nor
+            'project.feature_group.feature'.
+    """
+    result = _FEATURE_STR_REPR_RE.fullmatch(str_feature)
+    if not result:
+        raise ValueError(
+            f"Feature '{str_feature}' is a string but not in expected format 'feature_group.feature' or 'project.feature_group.feature'."
+        )
+
+    feature_group = FeatureGroup(
+        result.group("feature_group"),
+        project=result.group("project"),  # None if no match.
+        credentials=credentials,
+    )
+    feature = feature_group.get_feature(result.group("feature"))
+
+    return (feature_group, feature)
+
+
+def _feature_to_data_source(
+    feature_group: FeatureGroup, feature: Feature
+) -> impl.DataSource:
+    """Builds the data source (a SELECT over the group's table) for one feature.
+
+    The query selects the entity ID columns, the feature's version column
+    aliased to the feature name, and the `feature_timestamp` column.
+    """
+    qualifying_name = f"{feature_group.name}__{feature.name}"
+    gbq_column = feature.version_column_name
+    assert gbq_column
+
+    column_name = feature.name
+    assert column_name
+
+    timestamp_column = "feature_timestamp"
+
+    # TODO: Expose entity_id_columns as a property in FeatureGroup
+    entity_id_columns = feature_group._gca_resource.big_query.entity_id_columns
+    assert entity_id_columns
+
+    bq_uri = feature_group._gca_resource.big_query.big_query_source.input_uri
+    assert bq_uri
+
+    # Drop the scheme as a prefix. `str.lstrip("bq://")` would strip any
+    # leading run of the characters b, q, ':' and '/', which corrupts table
+    # names such as "bigquery-public-data...".
+    if bq_uri.startswith(_BQ_URI_PREFIX):
+        fully_qualified_table = bq_uri[len(_BQ_URI_PREFIX) :]
+    else:
+        fully_qualified_table = bq_uri
+    assert fully_qualified_table
+
+    query = (
+        f"SELECT\n"
+        f'  {", ".join(entity_id_columns)},\n'
+        f"  {gbq_column} AS {column_name},\n"
+        f"  {timestamp_column}\n"
+        f"FROM {fully_qualified_table}"
+    )
+
+    return impl.DataSource(
+        qualifying_name=qualifying_name,
+        sql=query,
+        data_columns=[column_name],
+        # TODO: this will be parameterized in the future
+        timestamp_column=timestamp_column,
+        entity_id_columns=entity_id_columns,
+    )
+
+
+class _DataFrameToBigQueryDataFramesConverter:
+    """Converts a pandas DataFrame into a BigQuery DataFrames DataFrame."""
+
+    @classmethod
+    def to_bigquery_dataframe(
+        cls, df: "pd.DataFrame", session: "Optional[bigframes.session.Session]" = None
+    ) -> "bigframes.pandas.DataFrame":
+        """Wraps `df` in a `bigframes.pandas.DataFrame`, importing bigframes lazily."""
+        bigframes = _try_import_bigframes()
+        return bigframes.pandas.DataFrame(data=df, session=session)
+
+
+def fetch_historical_feature_values(
+    entity_df: "bigframes.pandas.DataFrame",
+    # TODO: Add support for FeatureView | FeatureGroup | bigframes.pandas.DataFrame
+    features: List[Union[str, Feature]],
+    # TODO: Add support for feature_age_threshold
+    feature_age_threshold: Optional[datetime.timedelta] = None,
+    dry_run: bool = False,
+    project: Optional[str] = None,
+    location: Optional[str] = None,
+    credentials: Optional[auth_credentials.Credentials] = None,
+) ->
"Union[bigframes.pandas.DataFrame, None]": + """Fetch historical data at the timestamp specified for each entity. + + This runs a Point-In-Time Lookup (PITL) query in BigQuery across all + features and returns the historical feature values. Feature data will be + joined by matching their entity_id_column(s) with corresponding columns in + the entity data frame. + + Args: + entity_df: + An entity DataFrame where one/multiple columns have entity ID. + One column should have a timestamp (used for feature lookup). Other + columns may have feature data. Entity IDs may be repeated with + different timestamp values (in the timestamp column) to lookup data for + entities at different points in time. + features: + Feature data will be joined with the entity data frame. + * If `str` is given use `project.feature_group.feature` as the format. + `project_id.feature_group_id.feature_id` may be used if features are + in another project. + * If `FeatureView` is given, the *sources* of the FeatureView will be + used - but data will be read from the backing BigQuery table. + feature_age_threshold: + How far back from the timestamp to look for features values. If no + feature values are found, empty/null value will be populated. + dry_run: + Build the Point-In-Time Lookup (PITL) query but don't run it. The PITL + query will be printed to stdout. + project: + The project to use for feature lookup and running the Point-In-Time + Lookup (PITL) query in BigQuery. If unset, the project set in + aiplatform.init will be used. + location: + The location to use for feature lookup and running the Point-In-Time + Lookup (PITL) query in BigQuery. If unset, the project set in + aiplatform.init will be used. + credentials: + Custom credentials to use for feature lookup and running the + Point-In-Time Lookup (PITL) query in BigQuery. Overrides credentials + set in aiplatform.init. + + Returns: + A `bigframes.pandas.DataFrame` with the historical feature values. `None` + if in `dry_run` mode. 
+ """ + + bigframes = _try_import_bigframes() + project = project or initializer.global_config.project + location = location or initializer.global_config.location + credentials = credentials or initializer.global_config.credentials + application_name = ( + f"vertexai-offline-store/{__version__}+fetch-historical-feature-values" + ) + session_options = bigframes.BigQueryOptions( + credentials=credentials, + project=project, + location=location, + application_name=application_name, + ) + session = bigframes.connect(session_options) + + if feature_age_threshold is not None: + raise NotImplementedError("feature_age_threshold is not yet supported.") + + if not features: + raise ValueError("Please specify a non-empty list of features.") + + # Convert to bigframe if needed. + if not isinstance(entity_df, bigframes.pandas.DataFrame): + entity_df = _DataFrameToBigQueryDataFramesConverter.to_bigquery_dataframe( + df=entity_df, + session=session, + ) + + # Ensure one timestamp column is present in the entity DataFrame. + ts_cols = entity_df.select_dtypes(include=["datetime"]).columns + if len(ts_cols) > 1: + # TODO: Support multiple timestamp columns by specifying feature_timestamp column in an override. + raise ValueError( + 'Multiple timestamp columns ("datetime" dtype) found in entity DataFrame. ' + "Only one timestamp column is allowed. " + f"Timestamp columns: {', '.join([col for col in ts_cols])}" + ) + elif len(ts_cols) == 0: + raise ValueError( + 'No timestamp column ("datetime" dtype) found in entity DataFrame.' 
+ ) + entity_df_ts_col = ts_cols[0] + entity_df_non_ts_cols = [c for c in entity_df.columns if c != entity_df_ts_col] + entity_data_source = impl.DataSource( + qualifying_name="entity_df", + sql=entity_df.sql, + data_columns=entity_df_non_ts_cols, + timestamp_column=entity_df_ts_col, + ) + + feature_data: List[impl.DataSource] = [] + for feature in features: + if isinstance(feature, Feature): + feature_group = _get_feature_group_from_feature(feature, credentials) + feature_data.append(_feature_to_data_source(feature_group, feature)) + elif isinstance(feature, str): + feature_group, feature = _extract_feature_from_str_repr( + feature, credentials + ) + feature_data.append(_feature_to_data_source(feature_group, feature)) + else: + raise ValueError( + f"Unsupported feature type {type(feature)} found in feature list. Feature: {feature}" + ) + + # TODO: Verify `feature_data`. + # * Ensure that qualifying_names are not interfering. + # * Ensure that feature names are not interfering. + # * Ensure that entity id columns of all features are present in the entity DF. + + query = impl.render_pitl_query( + entity_data=entity_data_source, + feature_data=feature_data, + ) + + if dry_run: + print("--- Dry run mode: PITL QUERY BEGIN ---") + print(query) + print("--- Dry run mode: PITL QUERY END ---") + return None + + return session.read_gbq_query( + query, + index_col=bigframes.enums.DefaultIndexKind.NULL, + ) diff --git a/agentplatform/resources/preview/feature_store/utils.py b/agentplatform/resources/preview/feature_store/utils.py new file mode 100644 index 0000000000..a3f1e90fe2 --- /dev/null +++ b/agentplatform/resources/preview/feature_store/utils.py @@ -0,0 +1,231 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def get_feature_online_store_name(online_store_name: str) -> str:
    """Extract Feature Online Store's name from FeatureView's full resource name.

    Args:
        online_store_name: Full resource name is projects/project_number/
          locations/us-central1/featureOnlineStores/fos_name/featureViews/fv_name

    Returns:
        str: feature online store name.
    """
    # Path segments: projects/<p>/locations/<l>/featureOnlineStores/<fos>/...
    # so the online store ID is the sixth segment.
    return online_store_name.split("/")[5]


class PublicEndpointNotFoundError(RuntimeError):
    """Public endpoint has not been created yet."""


@dataclass
class FeatureViewBigQuerySource:
    """BigQuery source for a Feature View: a URI plus its entity ID columns."""

    uri: str
    entity_id_columns: List[str]


@dataclass
class FeatureViewVertexRagSource:
    """Vertex RAG source for a Feature View: a URI plus an optional RAG corpus ID."""

    uri: str
    rag_corpus_id: Optional[str] = None


@dataclass
class FeatureViewRegistrySource:
    """Configuration options for Feature View being registered with Feature Registry features.

    Attributes:
        features : Use `feature_group_id.feature_id` as the format for each
          feature.
        project_number : Optional. The project number of the project that owns the
          Feature Registry if in a different project.
    """

    features: List[str]
    project_number: Optional[int] = None


@dataclass(frozen=True)
class ConnectionOptions:
    """Represents connection options used for sending RPCs to the online store."""

    @dataclass(frozen=True)
    class InsecureGrpcChannel:
        """Use an insecure gRPC channel to connect to the host."""

        pass

    host: str  # IP address or DNS.
    transport: Union[
        InsecureGrpcChannel
    ]  # Currently only insecure gRPC channel is supported.

    def __eq__(self, other):
        """Compare host and transport with another `ConnectionOptions`.

        Returns `NotImplemented` for operands of any other type so Python falls
        back to its default comparison instead of failing on a missing `host`.

        Raises:
            ValueError: If either transport is not an `InsecureGrpcChannel`.
        """
        if not isinstance(other, ConnectionOptions):
            return NotImplemented

        if self.host != other.host:
            return False

        if isinstance(self.transport, ConnectionOptions.InsecureGrpcChannel):
            # Insecure grpc channel has no other parameters to check.
            if isinstance(other.transport, ConnectionOptions.InsecureGrpcChannel):
                return True

            # Otherwise, can't compare against a different transport type.
            raise ValueError(
                f"Transport '{self.transport}' cannot be compared to transport '{other.transport}'."
            )

        # Currently only InsecureGrpcChannel is supported.
        raise ValueError(f"Unsupported transport supplied: {self.transport}")


@dataclass
class FeatureViewReadResponse:
    """Wrapper around a `FetchFeatureValuesResponse` proto."""

    _response: "fos_service.FetchFeatureValuesResponse"

    def __init__(self, response: "fos_service.FetchFeatureValuesResponse"):
        self._response = response

    def to_dict(self) -> Dict[str, Any]:
        """Return the response's `key_values` converted to a dictionary."""
        return proto.Message.to_dict(self._response.key_values)

    def to_proto(self) -> "fos_service.FetchFeatureValuesResponse":
        """Return the wrapped proto response unchanged."""
        return self._response


@dataclass
class SearchNearestEntitiesResponse:
    """Wrapper around a `SearchNearestEntitiesResponse` proto."""

    _response: "fos_service.SearchNearestEntitiesResponse"

    def __init__(self, response: "fos_service.SearchNearestEntitiesResponse"):
        self._response = response

    def to_dict(self) -> Dict[str, Any]:
        """Return the response's `nearest_neighbors` converted to a dictionary."""
        return proto.Message.to_dict(self._response.nearest_neighbors)

    def to_proto(self) -> "fos_service.SearchNearestEntitiesResponse":
        """Return the wrapped proto response unchanged."""
        return self._response


class DistanceMeasureType(enum.Enum):
    """The distance measure used in nearest neighbor search."""

    DISTANCE_MEASURE_TYPE_UNSPECIFIED = 0
    # Euclidean (L_2) Distance.
    SQUARED_L2_DISTANCE = 1
    # Cosine Distance. Defined as 1 - cosine similarity.
    COSINE_DISTANCE = 2
    # Dot Product Distance. Defined as a negative of the dot product.
    DOT_PRODUCT_DISTANCE = 3


class AlgorithmConfig(abc.ABC):
    """Base class for configuration options for matching algorithm."""

    def as_dict(self) -> Dict:
        """Returns the configuration as a dictionary.

        The base implementation returns `None`; subclasses override it.

        Returns:
            Dict[str, Any]
        """
        pass


@dataclass
class TreeAhConfig(AlgorithmConfig):
    """Configuration options for using the tree-AH algorithm (Shallow tree + Asymmetric Hashing).

    Please refer to this paper for more details: https://arxiv.org/abs/1908.10396

    Args:
        leaf_node_embedding_count (int): Optional. Number of embeddings on each
          leaf node. The default value is 1000 if not set.
    """

    leaf_node_embedding_count: Optional[int] = None

    @override
    def as_dict(self) -> Dict:
        return {"leaf_node_embedding_count": self.leaf_node_embedding_count}


@dataclass
class BruteForceConfig(AlgorithmConfig):
    """Configuration options for using brute force search.

    It simply implements the standard linear search in the database for each
    query.
    """

    @override
    def as_dict(self) -> Dict[str, Any]:
        return {"bruteForceConfig": {}}


@dataclass
class IndexConfig:
    """Configuration options for the Vertex FeatureView for embedding."""

    embedding_column: str
    dimensions: int
    # `default_factory` must be a zero-argument callable: pass the class itself
    # so every IndexConfig gets its own fresh TreeAhConfig.
    algorithm_config: AlgorithmConfig = field(default_factory=TreeAhConfig)
    filter_columns: Optional[List[str]] = None
    crowding_column: Optional[str] = None
    distance_measure_type: Optional[DistanceMeasureType] = None

    def as_dict(self) -> Dict[str, Any]:
        """Returns the configuration as a dictionary.

        Optional settings are included only when they are not `None`. The
        algorithm config is stored under `tree_ah_config` for a `TreeAhConfig`
        and under `brute_force_config` for anything else.

        Returns:
            Dict[str, Any]
        """
        config = {
            "embedding_column": self.embedding_column,
            "embedding_dimension": self.dimensions,
        }
        if self.distance_measure_type is not None:
            config["distance_measure_type"] = self.distance_measure_type.value
        if self.filter_columns is not None:
            config["filter_columns"] = self.filter_columns
        if self.crowding_column is not None:
            config["crowding_column"] = self.crowding_column

        if isinstance(self.algorithm_config, TreeAhConfig):
            config["tree_ah_config"] = self.algorithm_config.as_dict()
        else:
            config["brute_force_config"] = self.algorithm_config.as_dict()
        return config


@dataclass
class FeatureGroupBigQuerySource:
    """BigQuery source for the Feature Group."""

    # The URI for the BigQuery table/view.
    uri: str
    # The entity ID columns. If not specified, defaults to ['entity_id'].
    entity_id_columns: Optional[List[str]] = None
+# + +from agentplatform.resources.preview.ml_monitoring.model_monitors import ( + ModelMonitor, + ModelMonitoringJob, +) + +__all__ = ( + "ModelMonitor", + "ModelMonitoringJob", +) diff --git a/agentplatform/resources/preview/ml_monitoring/model_monitors.py b/agentplatform/resources/preview/ml_monitoring/model_monitors.py new file mode 100644 index 0000000000..364b7f7c46 --- /dev/null +++ b/agentplatform/resources/preview/ml_monitoring/model_monitors.py @@ -0,0 +1,1866 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# Maps the dtype names accepted by `_transform_schema_pandas` to model
# monitoring data types.
_PANDAS_TO_MONITORING_TYPE = {
    **dict.fromkeys(
        (
            "string",
            "bytes",
            "date",
            "time",
            "datetime64",
            "datetime",
            "mixed-integer",
            "interval",
            "Interval",
        ),
        "string",
    ),
    **dict.fromkeys(("integer", "Int32", "Int64", "UInt32", "UInt64"), "integer"),
    **dict.fromkeys(
        ("floating", "decimal", "mixed-integer-float", "Float32", "Float64"),
        "float",
    ),
    "boolean": "boolean",
    "categorical": "categorical",
    "category": "categorical",
}


def _visualize_stats(
    baseline_stats_output: str, target_stats_output: str
) -> None:
    """Visualizes the model monitoring stats from output directory.

    When the baseline stats file exists, baseline and target are shown side by
    side; otherwise only the target stats are shown.

    Raises:
        TypeError: If `statistics_pb2` could not be imported.
        ValueError: If the target stats file does not exist.
    """
    import tensorflow as tf

    if not statistics_pb2:
        raise TypeError(
            "statistics_pb2 should be installed to visualize the results"
        )
    if not tf.io.gfile.exists(target_stats_output):
        raise ValueError("No stats were generated.")

    def _load_stats(path: str):
        # The `with` block closes the file; no explicit close() is needed.
        with tf.io.gfile.GFile(path, "rb") as stats_file:
            stats = statistics_pb2.DatasetFeatureStatisticsList()
            stats.ParseFromString(stats_file.read())
        return stats

    if tf.io.gfile.exists(baseline_stats_output):
        baseline_combined_stats = _load_stats(baseline_stats_output)
        target_combined_stats = _load_stats(target_stats_output)
        tfdv.visualize_statistics(
            lhs_statistics=baseline_combined_stats,
            rhs_statistics=target_combined_stats,
            lhs_name="Baseline Stats",
            rhs_name="Target Stats",
        )
    else:
        target_combined_stats = _load_stats(target_stats_output)
        tfdv.visualize_statistics(target_combined_stats)


def _visualize_anomalies(anomalies_output: str) -> None:
    """Visualizes the model monitoring anomalies from output directory.

    Raises:
        TypeError: If `anomalies_pb2` could not be imported.
    """
    import tensorflow as tf

    if not anomalies_pb2:
        raise TypeError(
            "anomalies_pb2 should be installed to visualize the results"
        )
    with tf.io.gfile.GFile(anomalies_output, "r") as f:
        anomalies = anomalies_pb2.Anomalies()
        text_format.Merge(f.read(), anomalies)
    tfdv.display_anomalies(anomalies)


def _visualize_feature_attribution(feature_attribution_output: str) -> None:
    """Visualizes the model monitoring feature attribution from output directory.

    The file is read as JSON and pretty-printed to stdout.
    """
    import tensorflow as tf

    with tf.io.gfile.GFile(feature_attribution_output, "r") as f:
        print(json.dumps(json.loads(f.read()), indent=4))


def _feature_drift_stats_output_path(
    output_directory: str, job_id: str
) -> "tuple[str, str]":
    """Returns the baseline and target output paths for the model monitoring feature drift stats."""
    job_root = f"{output_directory}/tabular/jobs/{job_id}/feature_drift"
    return (
        f"{job_root}/baseline/statistics",
        f"{job_root}/target/statistics",
    )


def _feature_drift_anomalies_output_path(
    output_directory: str, job_id: str
) -> str:
    """Returns the output path for the model monitoring feature drift anomalies."""
    return f"{output_directory}/tabular/jobs/{job_id}/feature_drift/anomalies.textproto"


def _prediction_output_stats_output_path(
    output_directory: str, job_id: str
) -> "tuple[str, str]":
    """Returns the baseline and target output paths for the model monitoring prediction output stats."""
    job_root = f"{output_directory}/tabular/jobs/{job_id}/output_drift"
    return (
        f"{job_root}/baseline/statistics",
        f"{job_root}/target/statistics",
    )


def _prediction_output_anomalies_output_path(
    output_directory: str, job_id: str
) -> str:
    """Returns the output path for the model monitoring prediction output anomalies."""
    return f"{output_directory}/tabular/jobs/{job_id}/output_drift/anomalies.textproto"


def _feature_attribution_target_stats_output_path(
    output_directory: str, job_id: str
) -> str:
    """Returns the output path for the target feature attribution scores."""
    return (
        f"{output_directory}/tabular/jobs/{job_id}/xai/target/feature_score.json"
    )


def _feature_attribution_baseline_stats_output_path(
    output_directory: str, job_id: str
) -> str:
    """Returns the output path for the baseline feature attribution scores."""
    return f"{output_directory}/tabular/jobs/{job_id}/xai/baseline/feature_score.json"


def _transform_schema_pandas(
    dataset: Dict[str, str],
    feature_fields: Optional[List[str]] = None,
    ground_truth_fields: Optional[List[str]] = None,
    prediction_fields: Optional[List[str]] = None,
) -> "schema.ModelMonitoringSchema":
    """Transforms the pandas schema to model monitoring schema.

    Args:
        dataset: Mapping of field name to its inferred pandas dtype name.
        feature_fields: Names of feature fields. When empty or `None`, every
            field that is not a ground truth or prediction field is treated as
            a feature.
        ground_truth_fields: Names of ground truth fields.
        prediction_fields: Names of prediction fields.

    Returns:
        The model monitoring schema. Ground truth and prediction fields are
        `None` when the corresponding argument was not given.

    Raises:
        ValueError: If a dtype name is not supported.
    """
    ground_truth_fields_list = []
    prediction_fields_list = []
    feature_fields_list = []

    for field_name, infer_type in dataset.items():
        data_type = _PANDAS_TO_MONITORING_TYPE.get(infer_type)
        if data_type is None:
            raise ValueError(f"Unsupported data type: {infer_type}")

        # Ground truth takes precedence over prediction, which takes
        # precedence over feature.
        if ground_truth_fields and field_name in ground_truth_fields:
            target_list = ground_truth_fields_list
        elif prediction_fields and field_name in prediction_fields:
            target_list = prediction_fields_list
        elif not feature_fields or field_name in feature_fields:
            target_list = feature_fields_list
        else:
            continue
        target_list.append(
            schema.FieldSchema(name=field_name, data_type=data_type, repeated=False)
        )

    return schema.ModelMonitoringSchema(
        ground_truth_fields=ground_truth_fields_list
        if ground_truth_fields
        else None,
        prediction_fields=prediction_fields_list if prediction_fields else None,
        feature_fields=feature_fields_list,
    )


def _transform_field_schema(
    field_schema: "gca_model_monitor_compat.ModelMonitoringSchema.FieldSchema",
) -> Dict[str, Any]:
    """Return the `name`, `data_type` and `repeated` of a field schema as a dict."""
    return {
        "name": field_schema.name,
        "data_type": field_schema.data_type,
        "repeated": field_schema.repeated,
    }


def _get_schedule_name(schedule_name: str) -> str:
    """Return the full schedule resource name for a resource name or numeric ID.

    A falsy `schedule_name` is returned unchanged. A numeric ID is expanded
    using the globally configured project and location.

    Raises:
        ValueError: If `schedule_name` is neither a schedule resource name nor
            a numeric ID.
    """
    if not schedule_name:
        return schedule_name

    client = initializer.global_config.create_client(
        client_class=utils.ScheduleClientWithOverride,
    )
    if client.parse_schedule_path(schedule_name):
        return schedule_name
    if re.match("^[0-9]{0,127}$", schedule_name):
        return client.schedule_path(
            project=initializer.global_config.project,
            location=initializer.global_config.location,
            schedule=schedule_name,
        )
    raise ValueError(
        "schedule name must be of the format"
        " `projects/{project}/locations/{location}/schedules/{schedule}` or"
        " `{schedule}`"
    )


def _get_model_monitoring_job_name(
    model_monitoring_job_name: str,
    model_monitor_name: str,
) -> str:
    """Return the full model monitoring job resource name for a name or numeric ID.

    A falsy `model_monitoring_job_name` is returned unchanged. A numeric ID is
    expanded using the globally configured project and location and the last
    path segment of `model_monitor_name`.

    Raises:
        ValueError: If `model_monitoring_job_name` is neither a job resource
            name nor a numeric ID.
    """
    if not model_monitoring_job_name:
        return model_monitoring_job_name

    client = initializer.global_config.create_client(
        client_class=utils.ModelMonitoringClientWithOverride,
    )
    if client.parse_model_monitoring_job_path(model_monitoring_job_name):
        return model_monitoring_job_name
    if re.match("^[0-9]{0,127}$", model_monitoring_job_name):
        return client.model_monitoring_job_path(
            project=initializer.global_config.project,
            location=initializer.global_config.location,
            model_monitor=model_monitor_name.split("/")[-1],
            model_monitoring_job=model_monitoring_job_name,
        )
    raise ValueError(
        "model monitoring job name must be of the format"
        " `projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}`"
        " or `{model_monitoring_job}`"
    )


@dataclasses.dataclass
class MetricsSearchResponse:
    """MetricsSearchResponse represents a response of the search metrics request.

    Attributes:
        monitoring_stats (List[model_monitoring_stats.ModelMonitoringStats]):
          Stats retrieved for requested objectives.
        next_page_token (str): The page token that can be used by the next call.
    """

    next_page_token: str
    _search_metrics_response: Any
    monitoring_stats: "List[model_monitoring_stats.ModelMonitoringStats]" = (
        dataclasses.field(default_factory=list)
    )

    @property
    def raw_search_metrics_response(
        self,
    ) -> "model_monitoring_service.SearchModelMonitoringStatsResponse":
        """Raw search metrics response."""
        return self._search_metrics_response


# TODO: b/307946658 - Return a dict or a new dataclass for search_alert
@dataclasses.dataclass
class AlertsSearchResponse:
    """AlertsSearchResponse represents a response of the search alerts request.

    Attributes:
        next_page_token (str): The page token that can be used by the next call.
        model_monitoring_alerts
          (List[model_monitoring_alert.ModelMonitoringAlert]): Alerts retrieved
          for requested objectives.
        total_alerts (int): Total number of alerts retrieved for requested
          objectives.
    """

    next_page_token: str
    _search_alerts_response: Any
    total_alerts: int
    model_monitoring_alerts: "List[model_monitoring_alert.ModelMonitoringAlert]" = (
        dataclasses.field(default_factory=list)
    )

    @property
    def raw_search_alerts_response(
        self,
    ) -> "model_monitoring_service.SearchModelMonitoringAlertsResponse":
        """Raw search alerts response."""
        return self._search_alerts_response


@dataclasses.dataclass
class ListJobsResponse:
    """ListJobsResponse represents a response of the list jobs request.

    Attributes:
        list_jobs (List[model_monitoring_job.ModelMonitoringJob]): Jobs retrieved
          for request.
        next_page_token (str): The page token that can be used by the next call.
    """

    next_page_token: str
    _list_jobs_response: Any
    list_jobs: "List[gca_model_monitoring_job_compat.ModelMonitoringJob]" = (
        dataclasses.field(default_factory=list)
    )

    @property
    def raw_list_jobs_response(
        self,
    ) -> "model_monitoring_service.ListModelMonitoringJobsResponse":
        """Raw list jobs response."""
        return self._list_jobs_response


@dataclasses.dataclass
class ListSchedulesResponse:
    """ListSchedulesResponse represents a response of the list schedules request.

    Attributes:
        list_schedules (List[schedule.Schedule]): Schedules retrieved for request.
        next_page_token (str): The page token that can be used by the next call.
    """

    next_page_token: str
    _list_schedules_response: Any
    list_schedules: "List[gca_schedule.Schedule]" = dataclasses.field(
        default_factory=list
    )

    @property
    def raw_list_schedules_response(
        self,
    ) -> "gca_schedule_service.ListSchedulesResponse":
        """Raw list schedules response."""
        return self._list_schedules_response
+ """ + + client_class = utils.ModelMonitoringClientWithOverride + _resource_noun = "modelMonitors" + _getter_method = "get_model_monitor" + _list_method = "list_model_monitors" + _delete_method = "delete_model_monitor" + _parse_resource_name_method = "parse_model_monitor_path" + _format_resource_name_method = "model_monitor_path" + + def __init__( + self, + model_monitor_name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=model_monitor_name, + ) + self._gca_resource = self._get_gca_resource( + resource_name=model_monitor_name + ) + + @classmethod + def create( + cls, + model_name: str, + model_version_id: str, + training_dataset: Optional[objective.MonitoringInput] = None, + display_name: Optional[str] = None, + model_monitoring_schema: Optional[schema.ModelMonitoringSchema] = None, + tabular_objective_spec: Optional[objective.TabularObjective] = None, + output_spec: Optional[output.OutputSpec] = None, + notification_spec: Optional[notification.NotificationSpec] = None, + explanation_spec: Optional[explanation.ExplanationSpec] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + model_monitor_id: Optional[str] = None, + ) -> "ModelMonitor": + """Creates a new ModelMonitor. + + Args: + model_name (str): Required. A model resource name as model monitoring + target. + Format: ``projects/{project}/locations/{location}/models/{model}`` + model_version_id (str): Required. Model version id. + training_dataset (objective.MonitoringInput): Optional. Training dataset + used to train the model. It can serve as a baseline dataset to + identify changes in production. + display_name (str): Optional. The user-defined name of the ModelMonitor. 
+ The name can be up to 128 characters long and can comprise any UTF-8 + character. Display name of the ModelMonitor. + model_monitoring_schema (schema.ModelMonitoringSchema): Required for + most models, but optional for Gemini Enterprise Agent Platform AutoML + Tables unless the schema information is not available. The Monitoring + Schema specifies the model's features, prediction outputs and ground + truth properties. It is used to extract pertinent data from the + dataset and to process features based on their properties. Make sure + that the schema aligns with your dataset, if it does not, Gemini + Enterprise Agent Platform will be unable to extract data form the + dataset. + tabular_objective_spec (objective.TabularObjective): Optional. The + default tabular monitoring objective spec for the model monitor. It + can be overriden in the ModelMonitoringJob objective spec. + output_spec (output.OutputSpec): Optional. The default monitoring + metrics/logs export spec, it can be overriden in the + ModelMonitoringJob output spec. If not specified, a default Google + Cloud Storage bucket will be created under your project. + notification_spec (notification.NotificationSpec): Optional. The default + notification spec for monitoring result. It can be overriden in the + ModelMonitoringJob notification spec. + explanation_spec (explanation.ExplanationSpec): Optional. The default + explanation spec for feature attribution monitoring. It can be + overriden in the ModelMonitoringJob explanation spec. + project (str): Optional. Project to retrieve model monitor from. If not + set, project set in aiplatform.init will be used. + location (str): Optional. Location to retrieve model monitor from. If + not set, location set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): Optional. Custom credentials + to use to create this model monitor. Overrides credentials set in + aiplatform.init. + model_monitor_id (str): Optional. 
The unique ID of the model monitor, + which will become the final component of the model monitor resource + name. If not specified, it will be generated by Gemini Enterprise + Agent Platform. + + Returns: + ModelMonitor: The model monitor that was created. + """ + api_client = initializer.global_config.create_client( + client_class=cls.client_class, + credentials=credentials, + location_override=location, + ) + + if display_name: + utils.validate_display_name(display_name) + else: + display_name = cls._generate_display_name() + + project = project or initializer.global_config.project + location = location or initializer.global_config.location + + user_monitoring_target = gca_model_monitor_compat.ModelMonitor.ModelMonitoringTarget( + vertex_model=gca_model_monitor_compat.ModelMonitor.ModelMonitoringTarget.VertexModelSource( + model=model_name, model_version_id=model_version_id + ) + ) + + operation_future = api_client.create_model_monitor( + request=model_monitoring_service.CreateModelMonitorRequest( + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + model_monitor=gca_model_monitor_compat.ModelMonitor( + display_name=display_name, + model_monitoring_target=user_monitoring_target, + training_dataset=( + training_dataset._as_proto() if training_dataset else None + ), + model_monitoring_schema=( + model_monitoring_schema._as_proto() + if model_monitoring_schema + else None + ), + tabular_objective=( + tabular_objective_spec._as_proto() + if tabular_objective_spec + else None + ), + notification_spec=( + notification_spec._as_proto() if notification_spec else None + ), + output_spec=output_spec._as_proto() if output_spec else None, + explanation_spec=explanation_spec, + ), + model_monitor_id=model_monitor_id, + ), + ) + _LOGGER.log_create_with_lro(cls, operation_future) + created_model_monitor = operation_future.result(timeout=None) + _LOGGER.log_create_complete(cls, created_model_monitor, "model_monitor") + self = 
cls._construct_sdk_resource_from_gapic( + gapic_resource=created_model_monitor, + project=project, + location=location, + credentials=credentials, + ) + model_monitor_id = self._gca_resource.name.split("/")[-1] + _LOGGER.info( + f"https://console.cloud.google.com/vertex-ai/model-monitoring/locations/{location}/model-monitors/{model_monitor_id}?project={project}" + ) + return self + + def update( + self, + display_name: Optional[str] = None, + training_dataset: Optional[objective.MonitoringInput] = None, + model_monitoring_schema: Optional[schema.ModelMonitoringSchema] = None, + tabular_objective_spec: Optional[objective.TabularObjective] = None, + output_spec: Optional[output.OutputSpec] = None, + notification_spec: Optional[notification.NotificationSpec] = None, + explanation_spec: Optional[explanation.ExplanationSpec] = None, + ) -> "ModelMonitor": + """Updates an existing ModelMonitor. + + Args: + display_name (str): Optional. The user-defined name of the ModelMonitor. + The name can be up to 128 characters long and can comprise any UTF-8 + character. Display name of the ModelMonitor. + training_dataset (objective.MonitoringInput): Optional. Training dataset + used to train the model. It can serve as a baseline dataset to + identify changes in production. + model_monitoring_schema (schema.ModelMonitoringSchema): Optional. The + Monitoring Schema specifies the model's features, prediction outputs + and ground truth properties. It is used to extract pertinent data from + the dataset and to process features based on their properties. Make + sure that the schema aligns with your dataset, if it does not, Gemini + Enterprise Agent Platform will be unable to extract data form the + dataset. + tabular_objective_spec (objective.TabularObjective): Optional. The + default tabular monitoring objective spec for the model monitor. It + can be overriden in the ModelMonitoringJob objective spec. + output_spec (output.OutputSpec): Optional. 
@base.optional_sync()
def delete(self, force: bool = False, sync: bool = True) -> None:
    """Deletes this model monitor and waits for the deletion to complete.

    Args:
        force (bool): Optional. If force is set to True, all schedules on
            this ModelMonitor will be deleted first. Default is False.
        sync (bool): Optional. Whether to execute this method synchronously.
            If False, this method will be executed in a concurrent Future
            and any downstream object will be immediately returned and
            synced when the Future has completed. Default is True.
    """
    _LOGGER.log_action_start_against_resource("Deleting", "", self)
    lro = self.api_client.delete_model_monitor(
        request=model_monitoring_service.DeleteModelMonitorRequest(
            name=self._gca_resource.name, force=force
        )
    )
    _LOGGER.log_action_started_against_resource_with_lro(
        "Delete", "", self.__class__, lro
    )
    # Wait for the long-running operation before reporting completion, so
    # that a failed deletion is raised to the caller instead of being lost
    # and the "deleted." message is only logged once it is true.
    lro.result()
    _LOGGER.log_action_completed_against_resource("deleted.", "", self)
The target dataset + for analysis. + display_name (str): Optional. The user-defined name of the Schedule. The + name can be up to 128 characters long and can consist of any UTF-8 + characters. Display name of the Schedule. + model_monitoring_job_display_name (str): Optional. The user-defined name + of the ModelMonitoringJob. The name can be up to 128 characters long + and can consist of any UTF-8 characters. Display name of the + ModelMonitoringJob. + start_time (timestamp_pb2.Timestamp): Optional. Timestamp after which + the first run can be scheduled. Defaults to Schedule create time if not + specified. + end_time (timestamp_pb2.Timestamp): Optional. Timestamp after which no + new runs can be scheduled. If specified, the schedule will be + completed when the end_time is reached. If not specified, new runs + will keep getting scheduled until this Schedule is paused or deleted. + Already scheduled runs will be allowed to complete. Unset if not + specified. + tabular_objective_spec (objective.TabularObjective): Optional. The + tabular monitoring objective spec. If not set, the default tabular + objective spec in ModelMonitor will be used. You must either set here + or set the default one in the ModelMonitor. + baseline_dataset (objective.MonitoringInput): Optional. The baseline + dataset for monitoring job. If not set, the training dataset in + ModelMonitor will be used as baseline dataset. + output_spec (output.OutputSpec): Optional. The monitoring metrics/logs + export spec. If not set, will use the default output_spec defined in + ModelMonitor. + notification_spec (notification.NotificationSpec): Optional. The + notification spec for monitoring result. If not set, will use the + default notification_spec defined in ModelMonitor. + explanation_spec (explanation.ExplanationSpec): Optional. The + explanation spec for feature attribution monitoring. If not set, will + use the default explanation_spec defined in ModelMonitor. 
+ + Returns: + Schedule: The created schedule. + """ + api_client = initializer.global_config.create_client( + client_class=utils.ScheduleClientWithOverride, + credentials=self.credentials, + location_override=self.location, + ) + + model_monitor_name = utils.full_resource_name( + resource_name=self._gca_resource.name, + resource_noun=self._resource_noun, + parse_resource_name_method=self._parse_resource_name, + format_resource_name_method=self._format_resource_name, + project=self.project, + location=self.location, + ) + + schedule_request = gca_schedule_service.CreateScheduleRequest( + parent=initializer.global_config.common_location_path( + project=self.project, location=self.location + ), + schedule=gca_schedule.Schedule( + display_name=display_name, + start_time=start_time, + end_time=end_time, + cron=cron, + create_model_monitoring_job_request=model_monitoring_service.CreateModelMonitoringJobRequest( + parent=model_monitor_name, + model_monitoring_job=gca_model_monitoring_job_compat.ModelMonitoringJob( + display_name=model_monitoring_job_display_name, + model_monitoring_spec=model_monitoring_spec.ModelMonitoringSpec( + objective_spec=model_monitoring_spec.ModelMonitoringObjectiveSpec( + tabular_objective=( + tabular_objective_spec._as_proto() + if tabular_objective_spec + else self._gca_resource.tabular_objective + ), + target_dataset=target_dataset._as_proto(), + baseline_dataset=( + baseline_dataset._as_proto() + if baseline_dataset + else self._gca_resource.training_dataset + ), + explanation_spec=( + explanation_spec + if explanation_spec + else self._gca_resource.explanation_spec + ), + ), + output_spec=( + output_spec._as_proto() + if output_spec + else self._gca_resource.output_spec + ), + notification_spec=( + notification_spec._as_proto() + if notification_spec + else self._gca_resource.notification_spec + ), + ), + ), + ), + max_concurrent_run_count=1, + ), + ) + created_schedule = api_client.select_version("v1beta1").create_schedule( + 
def update_schedule(
    self,
    schedule_name: str,
    display_name: Optional[str] = None,
    model_monitoring_job_display_name: Optional[str] = None,
    cron: Optional[str] = None,
    baseline_dataset: Optional[objective.MonitoringInput] = None,
    target_dataset: Optional[objective.MonitoringInput] = None,
    tabular_objective_spec: Optional[objective.TabularObjective] = None,
    output_spec: Optional[output.OutputSpec] = None,
    notification_spec: Optional[notification.NotificationSpec] = None,
    explanation_spec: Optional[explanation.ExplanationSpec] = None,
    end_time: Optional[timestamp_pb2.Timestamp] = None,
) -> "gca_schedule.Schedule":
    """Updates an existing Schedule.

    Only arguments that are not None are applied. ``display_name``, ``cron``
    and ``end_time`` are updated as individual fields. Supplying any
    job-level argument (job display name, datasets or specs) rewrites the
    whole ``create_model_monitoring_job_request`` of the schedule; values
    that were not supplied are carried over from the current schedule.

    Args:
        schedule_name (str): Required. The resource name of schedule that
            needs to be updated.
            Format:
            ``projects/{project}/locations/{location}/schedules/{schedule}``
            or ``{schedule}``
        display_name (str): Optional. The user-defined name of the Schedule.
            The name can be up to 128 characters long and can consist of
            any UTF-8 characters.
        model_monitoring_job_display_name (str): Optional. The user-defined
            display name of the ModelMonitoringJob that needs to be updated.
        cron (str): Optional. Cron schedule
            (https://en.wikipedia.org/wiki/Cron) to launch scheduled runs.
            To explicitly set a timezone to the cron tab, apply a prefix in
            the cron tab: "CRON_TZ=${IANA_TIME_ZONE}" or
            "TZ=${IANA_TIME_ZONE}". The ${IANA_TIME_ZONE} may only be a
            valid string from IANA time zone database. For example,
            "CRON_TZ=America/New_York 1 * * * *", or
            "TZ=America/New_York 1 * * * *".
        baseline_dataset (objective.MonitoringInput): Optional. The baseline
            dataset for monitoring job.
        target_dataset (objective.MonitoringInput): Optional. The target
            dataset for analysis.
        tabular_objective_spec (objective.TabularObjective): Optional. The
            tabular monitoring objective spec.
        output_spec (output.OutputSpec): Optional. The monitoring
            metrics/logs export spec.
        notification_spec (notification.NotificationSpec): Optional. The
            notification spec for monitoring result.
        explanation_spec (explanation.ExplanationSpec): Optional. The
            explanation spec for feature attribution monitoring.
        end_time (timestamp_pb2.Timestamp): Optional. Timestamp after which
            no new runs can be scheduled.

    Returns:
        Schedule: The updated schedule.
    """
    api_client = initializer.global_config.create_client(
        client_class=utils.ScheduleClientWithOverride,
        credentials=self.credentials,
        location_override=self.location,
    )

    model_monitor_name = utils.full_resource_name(
        resource_name=self._gca_resource.name,
        resource_noun=self._resource_noun,
        parse_resource_name_method=self._parse_resource_name,
        format_resource_name_method=self._format_resource_name,
        project=self.project,
        location=self.location,
    )
    schedule_name = _get_schedule_name(schedule_name)
    # Work on a copy so the fetched schedule is never mutated in place.
    current_schedule = copy.deepcopy(
        self.get_schedule(schedule_name=schedule_name)
    )

    update_mask = []
    if display_name is not None:
        update_mask.append("display_name")
        current_schedule.display_name = display_name
    if cron is not None:
        update_mask.append("cron")
        current_schedule.cron = cron
    if end_time is not None:
        update_mask.append("end_time")
        current_schedule.end_time = end_time

    current_job_request = current_schedule.create_model_monitoring_job_request
    current_job = current_job_request.model_monitoring_job
    current_spec = current_job.model_monitoring_spec
    current_objective = current_spec.objective_spec

    # Every job-level argument must trigger the request rewrite. The
    # explanation spec is included here; otherwise an update that only
    # changes explanation_spec would be silently ignored.
    job_level_arguments = (
        model_monitoring_job_display_name,
        baseline_dataset,
        target_dataset,
        tabular_objective_spec,
        output_spec,
        notification_spec,
        explanation_spec,
    )
    if any(argument is not None for argument in job_level_arguments):
        update_mask.append("create_model_monitoring_job_request")
        updated_model_monitoring_spec = model_monitoring_spec.ModelMonitoringSpec(
            objective_spec=model_monitoring_spec.ModelMonitoringObjectiveSpec(
                tabular_objective=(
                    tabular_objective_spec._as_proto()
                    if tabular_objective_spec
                    else current_objective.tabular_objective
                ),
                # Fall back to the fields under the same names they are
                # written with (baseline_dataset / target_dataset).
                baseline_dataset=(
                    baseline_dataset._as_proto()
                    if baseline_dataset
                    else current_objective.baseline_dataset
                ),
                target_dataset=(
                    target_dataset._as_proto()
                    if target_dataset
                    else current_objective.target_dataset
                ),
                explanation_spec=(
                    explanation_spec
                    if explanation_spec
                    else current_objective.explanation_spec
                ),
            ),
            output_spec=(
                output_spec._as_proto()
                if output_spec
                else current_spec.output_spec
            ),
            notification_spec=(
                notification_spec._as_proto()
                if notification_spec
                else current_spec.notification_spec
            ),
        )
        current_schedule.create_model_monitoring_job_request = model_monitoring_service.CreateModelMonitoringJobRequest(
            parent=model_monitor_name,
            model_monitoring_job=gca_model_monitoring_job_compat.ModelMonitoringJob(
                display_name=(
                    model_monitoring_job_display_name
                    if model_monitoring_job_display_name
                    else current_job.display_name
                ),
                model_monitoring_spec=updated_model_monitoring_spec,
            ),
        )
    return api_client.select_version("v1beta1").update_schedule(
        schedule=current_schedule,
        update_mask=field_mask_pb2.FieldMask(paths=update_mask),
    )
def pause_schedule(self, schedule_name: str) -> None:
    """Pauses an existing Schedule.

    Args:
        schedule_name (str): Required. The resource name of the schedule to
            pause.
            Format:
            ``projects/{project}/locations/{location}/schedules/{schedule}``
            or ``{schedule}``
    """
    schedule_client = initializer.global_config.create_client(
        client_class=utils.ScheduleClientWithOverride,
        credentials=self.credentials,
        location_override=self.location,
    )
    resolved_name = _get_schedule_name(schedule_name)
    versioned_client = schedule_client.select_version("v1beta1")
    return versioned_client.pause_schedule(name=resolved_name)


def resume_schedule(self, schedule_name: str) -> None:
    """Resumes an existing Schedule.

    Args:
        schedule_name (str): Required. The resource name of the schedule to
            resume.
            Format:
            ``projects/{project}/locations/{location}/schedules/{schedule}``
            or ``{schedule}``
    """
    schedule_client = initializer.global_config.create_client(
        client_class=utils.ScheduleClientWithOverride,
        credentials=self.credentials,
        location_override=self.location,
    )
    resolved_name = _get_schedule_name(schedule_name)
    versioned_client = schedule_client.select_version("v1beta1")
    return versioned_client.resume_schedule(name=resolved_name)
def list_schedules(
    self,
    filter: Optional[str] = None,
    page_size: Optional[int] = None,
    page_token: Optional[str] = None,
) -> "ListSchedulesResponse.list_schedules":
    """Lists the Schedules that belong to this model monitor.

    The request is always restricted to schedules whose
    ``create_model_monitoring_job_request.parent`` equals this model
    monitor's resource name; a caller-supplied filter is combined with that
    restriction using ``AND``.

    Args:
        filter (str): Optional. Lists the Schedules that match the filter
            expression. The following fields are supported:
            - ``display_name``: Supports ``=``, ``!=`` comparisons, and
              ``:`` wildcard.
            - ``state``: Supports ``=`` and ``!=`` comparisons.
            - ``request``: Supports existence of the check. (e.g.
              ``create_pipeline_job_request:*`` --> Schedule has
              create_pipeline_job_request).
            - ``create_time``, ``start_time``, ``next_run_time``: Support
              ``=``, ``!=``, ``<``, ``>``, ``<=``, and ``>=`` comparisons.
              Values must be in RFC 3339 format.
            - ``end_time``: Supports ``=``, ``!=``, ``<``, ``>``, ``<=``,
              ``>=`` comparisons and ``:*`` existence check. Values must be
              in RFC 3339 format.
            Filter expressions can be combined together using logical
            operators (``NOT``, ``AND`` & ``OR``). The syntax to define
            filter expression is based on https://google.aip.dev/160.
        page_size (int): Optional. The standard page list size.
        page_token (str): Optional. A page token received from a previous
            call.

    Returns:
        The ``list_schedules`` attribute of a ``ListSchedulesResponse``
        built from the schedules in the service response.
    """
    schedule_client = initializer.global_config.create_client(
        client_class=utils.ScheduleClientWithOverride,
        credentials=self.credentials,
        location_override=self.location,
    )

    # Scope the listing to schedules created for this model monitor.
    parent_clause = (
        f"create_model_monitoring_job_request.parent={self._gca_resource.name}"
    )
    if filter:
        filter = f"{filter} AND {parent_clause}"
    else:
        filter = parent_clause

    versioned_client = schedule_client.select_version("v1beta1")
    pager = versioned_client.list_schedules(
        request=gca_schedule_service.ListSchedulesRequest(
            parent=f"projects/{self.project}/locations/{self.location}",
            filter=filter,
            page_size=page_size,
            page_token=page_token,
        )
    )
    raw_response = pager._response
    wrapped_response = ListSchedulesResponse(
        list_schedules=raw_response.schedules,
        next_page_token=raw_response.next_page_token,
        _list_schedules_response=raw_response,
    )
    return wrapped_response.list_schedules
The maximum length is 63 characters, and + valid characters are /^[a-z]([a-z0-9-]{0,61}[a-z0-9])?$/. If not + specified, it will be generated by Gemini Enterprise Agent Platform. + sync (bool): Whether to execute this method synchronously. If False, + this method will be executed in a concurrent Future and any downstream + object will be immediately returned and synced when the Future has + completed. Default is False. + tabular_objective_spec (objective.TabularObjective): Optional. The + tabular monitoring objective spec for the model monitoring job. + baseline_dataset (objective.MonitoringInput): Optional. The baseline + dataset for monitoring job. If not set, the training dataset in + ModelMonitor will be used as baseline dataset. + output_spec (output.OutputSpec): Optional. The monitoring metrics/logs + export spec. If not set, will use the default output_spec defined in + ModelMonitor. + notification_spec (notification.NotificationSpec): Optional. The + notification spec for monitoring result. If not set, will use the + default notification_spec defined in ModelMonitor. + explanation_spec (explanation.ExplanationSpec): Optional. The + explanation spec for feature attribution monitoring. If not set, will + use the default explanation_spec defined in ModelMonitor. + + Returns: + ModelMonitoringJob: The model monitoring job that was created. 
def search_metrics(
    self,
    stats_name: Optional[str] = None,
    objective_type: Optional[str] = None,
    model_monitoring_job_name: Optional[str] = None,
    schedule_name: Optional[str] = None,
    algorithm: Optional[str] = None,
    start_time: Optional[timestamp_pb2.Timestamp] = None,
    end_time: Optional[timestamp_pb2.Timestamp] = None,
    page_size: Optional[int] = None,
    page_token: Optional[str] = None,
) -> "MetricsSearchResponse.monitoring_stats":
    """Searches the ModelMonitoringStats of this model monitor.

    Args:
        stats_name (str): Optional. The stats name filter for the search, if
            not set, all stats will be returned. For tabular model it's the
            feature name.
        objective_type (str): Optional. One of the supported monitoring
            objectives: `raw-feature-drift` `prediction-output-drift`
            `feature-attribution`
        model_monitoring_job_name (str): Optional. The resource name of a
            particular model monitoring job that the user wants to search
            metrics result from.
            Format:
            ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}``
        schedule_name (str): Optional. The resource name of a particular
            model monitoring schedule that the user wants to search metrics
            result from.
            Format:
            ``projects/{project}/locations/{location}/schedules/{schedule}``
        algorithm (str): Optional. The algorithm type filter for the search,
            eg: jensen_shannon_divergence, l_infinity.
        start_time (timestamp_pb2.Timestamp): Optional. Inclusive start of
            the time interval for which results should be returned.
        end_time (timestamp_pb2.Timestamp): Optional. Exclusive end of the
            time interval for which results should be returned.
        page_size (int): Optional. The standard page list size.
        page_token (str): Optional. A page token received from a previous
            call.

    Returns:
        The ``monitoring_stats`` attribute of a ``MetricsSearchResponse``
        built from the service response.
    """
    monitoring_client = initializer.global_config.create_client(
        client_class=utils.ModelMonitoringClientWithOverride,
        credentials=self.credentials,
        location_override=self.location,
    )

    # A time interval is only sent when at least one bound was provided.
    if start_time or end_time:
        time_window = interval_pb2.Interval(
            start_time=start_time, end_time=end_time
        )
    else:
        time_window = None

    tabular_filter = model_monitoring_stats.SearchModelMonitoringStatsFilter.TabularStatsFilter(
        stats_name=stats_name,
        objective_type=objective_type,
        model_monitoring_job=model_monitoring_job_name,
        model_monitoring_schedule=schedule_name,
        algorithm=algorithm,
    )
    search_request = model_monitoring_service.SearchModelMonitoringStatsRequest(
        model_monitor=self._gca_resource.name,
        stats_filter=model_monitoring_stats.SearchModelMonitoringStatsFilter(
            tabular_stats_filter=tabular_filter,
        ),
        time_interval=time_window,
        page_size=page_size,
        page_token=page_token,
    )
    raw_response = monitoring_client.search_model_monitoring_stats(
        request=search_request,
    )._response
    wrapped_response = MetricsSearchResponse(
        monitoring_stats=raw_response.monitoring_stats,
        next_page_token=raw_response.next_page_token,
        _search_metrics_response=raw_response,
    )
    return wrapped_response.monitoring_stats
search_alerts( + self, + stats_name: Optional[str] = None, + objective_type: Optional[str] = None, + model_monitoring_job_name: Optional[str] = None, + start_time: Optional[timestamp_pb2.Timestamp] = None, + end_time: Optional[timestamp_pb2.Timestamp] = None, + page_size: Optional[int] = None, + page_token: Optional[str] = None, + ) -> Dict[str, Any]: + """Search ModelMonitoringAlerts. + + Args: + stats_name (str): Optional. The stats name filter for the search, if not + set, all stats will be returned. For tabular models, provide the name + of the feature to return alerts from. + objective_type (str): Optional. Return alerts from one of the supported + monitoring + objectives: `raw-feature-drift` `prediction-output-drift` + `feature-attribution` + model_monitoring_job_name (str): Optional. The resource name of a + particular model monitoring job that the user wants to search metrics + result from. + Format: + ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}`` + start_time (timestamp_pb2.Timestamp): Optional. Inclusive start of the + time interval for which alerts should be returned. + end_time (timestamp_pb2.Timestamp): Optional. Exclusive end of the time + interval for which alerts should be returned. + page_size (int): Optional. The standard page list size. + page_token (str): Optional. A page token received from a previous call. + + Returns: + AlertsSearchResponse: The model monitoring alerts results. 
def list_jobs(
    self,
    page_size: Optional[int] = None,
    page_token: Optional[str] = None,
) -> "ListJobsResponse.list_jobs":
    """Lists the ModelMonitoringJobs of this model monitor.

    Args:
        page_size (int): Optional. The standard page list size.
        page_token (str): Optional. A page token received from a previous
            call.

    Returns:
        ListJobsResponse.list_jobs: The ``list_jobs`` attribute of a
        ``ListJobsResponse`` built from the service response.
    """
    monitoring_client = initializer.global_config.create_client(
        client_class=utils.ModelMonitoringClientWithOverride,
        credentials=self.credentials,
        location_override=self.location,
    )

    parent_name = utils.full_resource_name(
        resource_name=self._gca_resource.name,
        resource_noun=self._resource_noun,
        parse_resource_name_method=self._parse_resource_name,
        format_resource_name_method=self._format_resource_name,
        project=self.project,
        location=self.location,
    )

    jobs_request = model_monitoring_service.ListModelMonitoringJobsRequest(
        parent=parent_name,
        page_size=page_size,
        page_token=page_token,
    )
    raw_response = monitoring_client.list_model_monitoring_jobs(
        request=jobs_request
    )._response
    wrapped_response = ListJobsResponse(
        list_jobs=raw_response.model_monitoring_jobs,
        next_page_token=raw_response.next_page_token,
        _list_jobs_response=raw_response,
    )
    return wrapped_response.list_jobs


def delete_model_monitoring_job(self, model_monitoring_job_name: str) -> None:
    """Deletes a model monitoring job of this model monitor.

    Args:
        model_monitoring_job_name (str): Required. The resource name of the
            model monitoring job that needs to be deleted.
            Format:
            ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}``
            or ``{model_monitoring_job}``
    """
    monitoring_client = initializer.global_config.create_client(
        client_class=utils.ModelMonitoringClientWithOverride,
        credentials=self.credentials,
        location_override=self.location,
    )
    full_job_name = _get_model_monitoring_job_name(
        model_monitoring_job_name, self._gca_resource.name
    )
    monitoring_client.delete_model_monitoring_job(name=full_job_name)
+ Format: + ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}`` + or ``{model_monitoring_job}`` + + Returns: + ModelMonitoringJob: The model monitoring job get. + """ + self.wait() + job_resource_name = _get_model_monitoring_job_name( + model_monitoring_job_name, self._gca_resource.name + ) + return ModelMonitoringJob( + model_monitoring_job_name=job_resource_name, + project=self.project, + location=self.location, + credentials=self.credentials, + ) + + def show_feature_drift_stats(self, model_monitoring_job_name: str) -> None: + """The method to visualize the feature drift result from a model monitoring job as a histogram chart and a table. + + Args: + model_monitoring_job_name (str): Required. The resource name of model + monitoring job to show the drift stats from. + Format: + ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}`` + or ``{model_monitoring_job}`` + """ + api_client = initializer.global_config.create_client( + client_class=utils.ModelMonitoringClientWithOverride, + credentials=self.credentials, + location_override=self.location, + ) + if model_monitoring_job_name.startswith("projects/"): + job_resource_name = model_monitoring_job_name + job_id = model_monitoring_job_name.split("/")[-1] + else: + job_resource_name = f"{self._gca_resource.name}/modelMonitoringJobs/{model_monitoring_job_name}" + job_id = model_monitoring_job_name + job = api_client.get_model_monitoring_job(name=job_resource_name) + output_directory = ( + job.model_monitoring_spec.output_spec.gcs_base_directory.output_uri_prefix + ) + target_output, baseline_output = _feature_drift_stats_output_path( + output_directory, job_id + ) + anomoaly_output = _feature_drift_anomalies_output_path( + output_directory, job_id + ) + _visualize_stats(baseline_output, target_output) + _visualize_anomalies(anomoaly_output) + + def get_schema(self) -> 
gca_model_monitor_compat.ModelMonitoringSchema: + """Get the schema of the model monitor.""" + self._sync_gca_resource() + return self._gca_resource.model_monitoring_schema + + def show_output_drift_stats(self, model_monitoring_job_name: str) -> None: + """The method to visualize the prediction output drift result from a model monitoring job as a histogram chart and a table. + + Args: + model_monitoring_job_name (str): Required. The resource name of model + monitoring job to show the drift stats from. + Format: + ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}`` + or ``{model_monitoring_job}`` + """ + api_client = initializer.global_config.create_client( + client_class=utils.ModelMonitoringClientWithOverride, + credentials=self.credentials, + location_override=self.location, + ) + if model_monitoring_job_name.startswith("projects/"): + job_resource_name = model_monitoring_job_name + job_id = model_monitoring_job_name.split("/")[-1] + else: + job_resource_name = f"{self._gca_resource.name}/modelMonitoringJobs/{model_monitoring_job_name}" + job_id = model_monitoring_job_name + job = api_client.get_model_monitoring_job(name=job_resource_name) + output_directory = ( + job.model_monitoring_spec.output_spec.gcs_base_directory.output_uri_prefix + ) + target_output, baseline_output = _prediction_output_stats_output_path( + output_directory, job_id + ) + anomoaly_output = _prediction_output_anomalies_output_path( + output_directory, job_id + ) + _visualize_stats(baseline_output, target_output) + _visualize_anomalies(anomoaly_output) + + def show_feature_attribution_drift_stats( + self, model_monitoring_job_name: str + ) -> None: + """The method to visualize the feature attribution drift result from a model monitoring job as a histogram chart and a table. + + Args: + model_monitoring_job_name (str): Required. The resource name of model + monitoring job to show the feature attribution drift stats from. 
+ Format: + ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}`` + or ``{model_monitoring_job}`` + """ + api_client = initializer.global_config.create_client( + client_class=utils.ModelMonitoringClientWithOverride, + credentials=self.credentials, + location_override=self.location, + ) + if model_monitoring_job_name.startswith("projects/"): + job_resource_name = model_monitoring_job_name + job_id = model_monitoring_job_name.split("/")[-1] + else: + job_resource_name = f"{self._gca_resource.name}/modelMonitoringJobs/{model_monitoring_job_name}" + job_id = model_monitoring_job_name + job = api_client.get_model_monitoring_job(name=job_resource_name) + output_directory = ( + job.model_monitoring_spec.output_spec.gcs_base_directory.output_uri_prefix + ) + target_stats_output = _feature_attribution_target_stats_output_path( + output_directory, job_id + ) + baseline_stats_output = _feature_attribution_baseline_stats_output_path( + output_directory, job_id + ) + _visualize_feature_attribution(baseline_stats_output) + _visualize_feature_attribution(target_stats_output) + + +class ModelMonitoringJob(base.VertexAiStatefulResource): + r"""Initializer for ModelMonitoringJob. + + Example Usage: + + my_monitoring_job = aiplatform.ModelMonitoringJob( + model_monitoring_job_name='projects/123/locations/us-central1/modelMonitors/\ + my_model_monitor_id/modelMonitoringJobs/my_monitoring_job_id' + ) + or + my_monitoring_job = aiplatform.aiplatform.ModelMonitoringJob( + model_monitoring_job_name='my_monitoring_job_id', + model_monitor_id='my_model_monitor_id', + ) + Args: + model_monitoring_job_name (str): + Required. The resource name for the Model Monitoring Job if + provided alone, or the model monitoring job id if provided with + model_monitor_id. + model_monitor_id (str): + Optional. The model monitor id depends on the way of initializing + ModelMonitoringJob. + project (str): + Required. 
Project to retrieve endpoint from. If not set, project + set in aiplatform.init will be used. + location (str): + Required. Location to retrieve endpoint from. If not set, + location set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): + Optional. Custom credentials to use to init model monitoring job. + Overrides credentials set in aiplatform.init. + """ + + client_class = utils.ModelMonitoringClientWithOverride + _resource_noun = "modelMonitoringJobs" + _getter_method = "get_model_monitoring_job" + _list_method = "list_model_monitoring_jobs" + _delete_method = "delete_model_monitoring_job" + _parse_resource_name_method = "parse_model_monitoring_job_path" + _format_resource_name_method = "model_monitoring_job_path" + + # Required by the done() method + _valid_done_states = _JOB_COMPLETE_STATES + + def __init__( + self, + model_monitoring_job_name: str, + model_monitor_id: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=model_monitoring_job_name, + ) + self._gca_resource = self._get_gca_resource( + resource_name=model_monitoring_job_name, + parent_resource_name_fields=( + {ModelMonitor._resource_noun: model_monitor_id} + if model_monitor_id + else model_monitor_id + ), + ) + + @property + def state(self) -> gca_job_state.JobState: + """Fetch Job again and return the current JobState. + + Returns: + state (job_state.JobState): + Enum that describes the state of a Model Monitoring Job. 
+ """ + + # Fetch the Job again for most up-to-date job state + self._sync_gca_resource() + return self._gca_resource.state + + @classmethod + def _construct_sdk_resource_from_gapic( + cls, + gapic_resource: proto.Message, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "ModelMonitoringJob": + """Given a GAPIC ModelMonitoringJob object, return the SDK representation. + + Args: + gapic_resource (proto.Message): A GAPIC representation of a + ModelMonitoringJob resource, usually retrieved by a get_* or in a + list_* API call. + project (str): Optional. Project to construct ModelMonitoringJob object + from. If not set, project set in aiplatform.init will be used. + location (str): Optional. Location to construct ModelMonitoringJob + object from. If not set, location set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): Optional. Custom credentials + to use to construct ModelMonitoringJob. Overrides credentials set in + aiplatform.init. + + Returns: + ModelMonitoringJob: The model monitoring job that was created. 
+ """ + model_monitoring_job = super()._construct_sdk_resource_from_gapic( + gapic_resource=gapic_resource, + project=project, + location=location, + credentials=credentials, + ) + + return model_monitoring_job + + def _block_until_complete(self) -> None: + """Helper method to block and check on job until complete.""" + # Used these numbers so failures surface fast + wait = _JOB_WAIT_TIME # start at five seconds + log_wait = _LOG_WAIT_TIME + max_wait = _MAX_WAIT_TIME # 5 minute wait + multiplier = _WAIT_TIME_MULTIPLIER # scale wait by 2 every iteration + + previous_time = time.time() + while not self.done(): + current_time = time.time() + if current_time - previous_time >= log_wait: + _LOGGER.info( + "%s %s current state:\n%s" + % ( + self.__class__.__name__, + self._gca_resource.name, + self._gca_resource.state, + ) + ) + log_wait = min(log_wait * multiplier, max_wait) + previous_time = current_time + time.sleep(wait) + + # Error is only populated when the job state is JOB_STATE_FAILED. 
+ if self._gca_resource.state in _JOB_ERROR_STATES: + raise RuntimeError( + "Job failed with:\n%s" % self._gca_resource.job_execution_detail.error + ) + elif ( + self._gca_resource.state + == gca_job_state.JobState.JOB_STATE_PARTIALLY_SUCCEEDED + ): + obj_status_msg = "" + for ( + obj, + status, + ) in self._gca_resource.job_execution_detail.objective_status.items(): + obj_status_msg += f"{obj}: {status}\n" + _LOGGER.warning("Job partially succeeded:\n%s" % obj_status_msg) + else: + _LOGGER.log_action_completed_against_resource("run", "completed", self) + + @classmethod + def create( + cls, + model_monitor_name: str = None, + target_dataset: objective.MonitoringInput = None, + display_name: Optional[str] = None, + model_monitoring_job_id: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + baseline_dataset: Optional[objective.MonitoringInput] = None, + tabular_objective_spec: Optional[objective.TabularObjective] = None, + output_spec: Optional[output.OutputSpec] = None, + notification_spec: Optional[notification.NotificationSpec] = None, + explanation_spec: Optional[explanation.ExplanationSpec] = None, + sync: bool = False, + ) -> "ModelMonitoringJob": + """Creates a new ModelMonitoringJob. + + Args: + model_monitor_name (str): Required. The parent model monitor resource + name. Format: + ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}`` + target_dataset (objective.MonitoringInput): Required. The target dataset + for analysis. + display_name (str): Optional. The user-defined name of the + ModelMonitoringJob. The name can be up to 128 characters long and can + comprise any UTF-8 character. + model_monitoring_job_id (str): Optional. The unique ID of the model + monitoring job run, which will become the final component of the model + monitoring job resource name. 
The maximum length is 63 characters, and + valid characters are /^[a-z]([a-z0-9-]{0,61}[a-z0-9])?$/. If not + specified, it will be generated by Gemini Enterprise Agent Platform. + project (str): Optional. Project to retrieve endpoint from. If not set, + project set in aiplatform.init will be used. + location (str): Optional. Location to retrieve endpoint from. If not + set, location set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): Optional. Custom credentials + to use to create model monitoring job. Overrides credentials set in + aiplatform.init. + baseline_dataset (objective.MonitoringInput): Optional. The baseline + dataset for monitoring job. If not set, the training dataset in + ModelMonitor will be used as baseline dataset. + output_spec (output.OutputSpec): Optional. The monitoring metrics/logs + export spec. If not set, will use the default output_spec defined in + ModelMonitor. + notification_spec (notification.NotificationSpec): Optional. The + notification spec for monitoring result. If not set, will use the + default notification_spec defined in ModelMonitor. + explanation_spec (explanation.ExplanationSpec): Optional. The + explanation spec for feature attribution monitoring. If not set, will + use the default explanation_spec defined in ModelMonitor. + sync (bool): Required. Whether to execute this method synchronously. If + False, this method will be executed in concurrent Future and any + downstream object will be immediately returned and synced when the + Future has completed. Default is False. + + Returns: + ModelMonitoringJob: The model monitoring job that was created. 
+ """ + if not display_name: + display_name = cls._generate_display_name() + utils.validate_display_name(display_name) + + project = project or initializer.global_config.project + location = location or initializer.global_config.location + + parent = utils.full_resource_name( + resource_name=model_monitor_name, + resource_noun=ModelMonitor._resource_noun, + parse_resource_name_method=ModelMonitor._parse_resource_name, + format_resource_name_method=ModelMonitor._format_resource_name, + project=project, + location=location, + ) + + gca_model_monitoring_job = gca_model_monitoring_job_compat.ModelMonitoringJob( + display_name=display_name, + model_monitoring_spec=model_monitoring_spec.ModelMonitoringSpec( + objective_spec=model_monitoring_spec.ModelMonitoringObjectiveSpec( + tabular_objective=( + tabular_objective_spec._as_proto() + if tabular_objective_spec + else None + ), + baseline_dataset=( + baseline_dataset._as_proto() if baseline_dataset else None + ), + target_dataset=( + target_dataset._as_proto() if target_dataset else None + ), + explanation_spec=explanation_spec, + ), + output_spec=(output_spec._as_proto() if output_spec else None), + notification_spec=( + notification_spec._as_proto() if notification_spec else None + ), + ), + ) + empty_model_monitoring_job = cls._empty_constructor( + project=project, + location=location, + credentials=credentials, + ) + return cls._submit_job( + model_monitor_name=parent, + empty_model_monitoring_job=empty_model_monitoring_job, + gca_model_monitoring_job=gca_model_monitoring_job, + model_monitoring_job_id=model_monitoring_job_id, + sync=sync, + project=project, + location=location, + credentials=credentials, + ) + + @classmethod + @base.optional_sync(return_input_arg="empty_model_monitoring_job") + def _submit_job( + cls, + model_monitor_name: str, + empty_model_monitoring_job: "ModelMonitoringJob", + gca_model_monitoring_job: gca_model_monitoring_job_compat.ModelMonitoringJob, + sync: bool = False, + 
model_monitoring_job_id: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ) -> "ModelMonitoringJob": + """Submits a new ModelMonitoringJob. + + Args: + model_monitor_name (str): Required. The parent model monitor resource + name. Format: + ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}`` + empty_model_monitoring_job (ModelMonitoringJob): Required. + ModelMonitoringJob without _gca_resource populated. + gca_model_monitoring_job + (gca_model_monitoring_job_compat.ModelMonitoringJob): Required. a + model monitoring job proto for creating a model monitoring job on + Gemini Enterprise Agent Platform. + sync (bool): Required. Whether to execute this method synchronously. If + False, this method will be executed in concurrent Future and any + downstream object will be immediately returned and synced when the + Future has completed. Default is False. + model_monitoring_job_id (str): Optional. The unique ID of the model + monitoring job run, which will become the final component of the model + monitoring job resource name. The maximum length is 63 characters, and + valid characters are /^[a-z]([a-z0-9-]{0,61}[a-z0-9])?$/. If not + specified, it will be generated by Gemini Enterprise Agent Platform. + project (str): Optional. Project to retrieve endpoint from. If not set, + project set in aiplatform.init will be used. + location (str): Optional. Location to retrieve endpoint from. If not + set, location set in aiplatform.init will be used. + credentials (auth_credentials.Credentials): Optional. Custom credentials + to use to create model monitoring job. Overrides credentials set in + aiplatform.init. + + Returns: + ModelMonitoringJob: The model monitoring job that was created. 
+ """ + api_client = initializer.global_config.create_client( + client_class=cls.client_class, + credentials=credentials, + location_override=location, + ) + _LOGGER.log_create_with_lro(cls) + created_model_monitoring_job = api_client.create_model_monitoring_job( + request=model_monitoring_service.CreateModelMonitoringJobRequest( + parent=model_monitor_name, + model_monitoring_job=gca_model_monitoring_job, + model_monitoring_job_id=model_monitoring_job_id, + ), + ) + empty_model_monitoring_job._gca_resource = created_model_monitoring_job + model_monitoring_job = cls._construct_sdk_resource_from_gapic( + gapic_resource=created_model_monitoring_job, + project=project, + location=location, + credentials=credentials, + ) + _LOGGER.log_create_complete( + cls, created_model_monitoring_job, "model_monitoring_job" + ) + model_monitoring_job._block_until_complete() + return model_monitoring_job + + def delete(self) -> None: + """Deletes an Model Monitoring Job.""" + self.api_client.delete_model_monitoring_job(name=self._gca_resource.name) diff --git a/agentplatform/resources/preview/ml_monitoring/spec/__init__.py b/agentplatform/resources/preview/ml_monitoring/spec/__init__.py new file mode 100644 index 0000000000..f5fd1c3a73 --- /dev/null +++ b/agentplatform/resources/preview/ml_monitoring/spec/__init__.py @@ -0,0 +1,44 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class NotificationSpec:
    """Notification settings for model monitoring results.

    Args:
        user_emails (List[str]):
            Optional. The email addresses to send the alert to.
        notification_channels (List[str]):
            Optional. The notification channels to send the alert to.
            Format: ``projects/{project}/notificationChannels/{channel}``
        enable_cloud_logging (bool):
            Optional. Whether to dump the anomalies to Cloud Logging. The
            anomalies will be put into the json payload and can then be
            routed to Pub/Sub or any other service supported by Cloud
            Logging.
    """

    def __init__(
        self,
        user_emails: Optional[List[str]] = None,
        notification_channels: Optional[List[str]] = None,
        enable_cloud_logging: Optional[bool] = False,
    ):
        self.enable_cloud_logging = enable_cloud_logging
        self.notification_channels = notification_channels
        self.user_emails = user_emails

    def _as_proto(self) -> model_monitoring_spec.ModelMonitoringNotificationSpec:
        """Builds the proto message for this notification spec.

        Returns:
            The GAPIC representation of the notification alert config.
        """
        spec_cls = model_monitoring_spec.ModelMonitoringNotificationSpec

        # An email config is emitted whenever a list was supplied, even an
        # empty one; only None leaves it unset.
        if self.user_emails is None:
            email_config = None
        else:
            email_config = spec_cls.EmailConfig(user_emails=self.user_emails)

        # One channel config per channel; none or empty yields an empty list.
        channel_configs = [
            spec_cls.NotificationChannelConfig(notification_channel=channel)
            for channel in self.notification_channels or []
        ]

        return spec_cls(
            email_config=email_config,
            notification_channel_configs=channel_configs,
            enable_cloud_logging=self.enable_cloud_logging,
        )
class DataDriftSpec:
    """Data drift monitoring spec.

    Data drift measures the distribution distance between the current dataset
    and a baseline dataset. A typical use case is to detect data drift between
    the recent production serving dataset and the training dataset, or to
    compare the recent production dataset with a dataset from a previous
    period.

    Example:
        feature_drift_spec=DataDriftSpec(
            features=["feature1"],
            categorical_metric_type="l_infinity",
            numeric_metric_type="jensen_shannon_divergence",
            default_categorical_alert_threshold=0.01,
            default_numeric_alert_threshold=0.02,
            feature_alert_thresholds={"feature1":0.02, "feature2":0.01},
        )

    Attributes:
        features (List[str]): Optional. Feature names / Prediction output
            names interested in monitoring. These should be a subset of the
            input feature names or prediction output names specified in the
            monitoring schema. If not specified, all features / prediction
            outputs outlined in the monitoring schema will be used.
        categorical_metric_type (str): Optional. Supported metrics type:
            l_infinity, jensen_shannon_divergence
        numeric_metric_type (str): Optional. Supported metrics type:
            jensen_shannon_divergence
        default_categorical_alert_threshold (float): Optional. Default alert
            threshold for all the categorical features.
        default_numeric_alert_threshold (float): Optional. Default alert
            threshold for all the numeric features.
        feature_alert_thresholds (Dict[str, float]): Optional. Per feature
            alert threshold will override default alert threshold.
    """

    def __init__(
        self,
        features: Optional[List[str]] = None,
        categorical_metric_type: Optional[str] = L_INFINITY,
        numeric_metric_type: Optional[str] = JENSEN_SHANNON_DIVERGENCE,
        default_categorical_alert_threshold: Optional[float] = None,
        default_numeric_alert_threshold: Optional[float] = None,
        feature_alert_thresholds: Optional[Dict[str, float]] = None,
    ):
        self.features = features
        self.categorical_metric_type = categorical_metric_type
        self.numeric_metric_type = numeric_metric_type
        self.default_categorical_alert_threshold = (
            default_categorical_alert_threshold
        )
        self.default_numeric_alert_threshold = default_numeric_alert_threshold
        self.feature_alert_thresholds = feature_alert_thresholds

    def _as_proto(
        self,
    ) -> model_monitoring_spec.ModelMonitoringObjectiveSpec.DataDriftSpec:
        """Converts DataDriftSpec to a proto message.

        Returns:
            The GAPIC representation of the data drift spec.

        Raises:
            ValueError: If ``numeric_metric_type`` is not in
                ``SUPPORTED_NUMERIC_METRICS`` or ``categorical_metric_type``
                is not in ``SUPPORTED_CATEGORICAL_METRICS``.
        """
        if self.numeric_metric_type not in SUPPORTED_NUMERIC_METRICS:
            raise ValueError(
                f"The numeric metric type is not supported {self.numeric_metric_type}"
            )
        if self.categorical_metric_type not in SUPPORTED_CATEGORICAL_METRICS:
            raise ValueError(
                "The categorical metric type is not supported"
                f" {self.categorical_metric_type}"
            )

        alert_condition = model_monitoring_alert.ModelMonitoringAlertCondition

        # Compare against None rather than truthiness: an explicit threshold
        # of 0.0 is a value the caller supplied and must not be dropped. The
        # per-feature thresholds below already keep 0.0.
        default_categorical_condition = None
        if self.default_categorical_alert_threshold is not None:
            default_categorical_condition = alert_condition(
                threshold=self.default_categorical_alert_threshold
            )
        default_numeric_condition = None
        if self.default_numeric_alert_threshold is not None:
            default_numeric_condition = alert_condition(
                threshold=self.default_numeric_alert_threshold
            )

        # A missing or empty mapping leaves the per-feature conditions unset.
        feature_conditions = None
        if self.feature_alert_thresholds:
            feature_conditions = {
                feature: alert_condition(threshold=threshold)
                for feature, threshold in self.feature_alert_thresholds.items()
            }

        return model_monitoring_spec.ModelMonitoringObjectiveSpec.DataDriftSpec(
            default_categorical_alert_condition=default_categorical_condition,
            default_numeric_alert_condition=default_numeric_condition,
            categorical_metric_type=self.categorical_metric_type,
            numeric_metric_type=self.numeric_metric_type,
            feature_alert_conditions=feature_conditions,
            # An empty feature list is sent as unset.
            features=self.features or None,
        )
+ + Example: + feature_attribution_spec=FeatureAttributionSpec( + features=["feature1"] + default_alert_threshold=0.01, + feature_alert_thresholds={"feature1":0.02, "feature2":0.01}, + batch_dedicated_resources=BatchDedicatedResources( + starting_replica_count=1, + max_replica_count=2, + machine_spec=my_machine_spec, + ), + ) + + Attributes: + features (List[str]): Optional. Input feature names interested in + monitoring. These should be a subset of the input feature names + specified in the monitoring schema. If not specified, all features + outlied in the monitoring schema will be used. + default_alert_threshold (float): Optional. Default alert threshold for all + the features. + feature_alert_thresholds (Dict[str, float]): Optional. Per feature alert + threshold will override default alert threshold. + batch_dedicated_resources (machine_resources.BatchDedicatedResources): + Optional. The config of resources used by the Model Monitoring during + the batch explanation for non-AutoML models. If not set, `n1-standard-2` + machine type will be used by default. + """ + + def __init__( + self, + features: Optional[List[str]] = None, + default_alert_threshold: Optional[float] = None, + feature_alert_thresholds: Optional[Dict[str, float]] = None, + batch_dedicated_resources: Optional[ + machine_resources.BatchDedicatedResources + ] = None, + ): + self.features = features + self.default_alert_threshold = default_alert_threshold + self.feature_alert_thresholds = feature_alert_thresholds + self.batch_dedicated_resources = batch_dedicated_resources + + def _as_proto( + self, + ) -> ( + model_monitoring_spec.ModelMonitoringObjectiveSpec.FeatureAttributionSpec + ): + """Converts FeatureAttributionSpec to a proto message. + + Returns: + The GAPIC representation of the feature attribution spec. 
class MonitoringInput:
    """Model monitoring data input spec.

    One data source should be supplied. When the spec is converted, a
    columnized dataset (``vertex_dataset``, ``gcs_uri``, ``table_uri`` or
    ``query``) is looked at first, then ``batch_prediction_job``, then
    ``endpoints``.

    Attributes:
        vertex_dataset (str): Optional. Resource name of the Gemini Enterprise
            Agent Platform managed dataset.
            Format: ``projects/{project}/locations/{location}/datasets/{dataset}``
            At least one source of dataset should be provided, and if one of
            the fields is set, there is no need to set the other sources
            (vertex_dataset, gcs_uri, table_uri, query, batch_prediction_job,
            endpoints).
        gcs_uri (str): Optional. Google Cloud Storage URI to the input
            file(s). May contain wildcards.
        data_format (str): Optional. Data format of the Google Cloud Storage
            file(s). Must be provided if ``gcs_uri`` is set. Supported
            formats: "csv", "jsonl", "tf-record".
        table_uri (str): Optional. BigQuery URI to a table, up to 2000
            characters long. All the columns in the table will be selected.
            Accepted form: BigQuery path, for example
            ``bq://projectId.bqDatasetId.bqTableId``.
        query (str): Optional. Standard SQL for BigQuery to be used instead of
            the ``table_uri``.
        timestamp_field (str): Optional. The timestamp field in the dataset.
            It must be specified if you'd like to use ``start_time``,
            ``end_time``, ``offset`` or ``window``. If you use ``query`` to
            specify the dataset, make sure the ``timestamp_field`` is in the
            selection fields.
        batch_prediction_job (str): Optional. Gemini Enterprise Agent Platform
            Batch Prediction Job resource name.
            Format:
            ``projects/{project}/locations/{location}/batchPredictionJobs/{batch_prediction_job}``
        endpoints (List[str]): Optional. List of Gemini Enterprise Agent
            Platform Endpoint resource names.
            Format:
            ``projects/{project}/locations/{location}/endpoints/{endpoint}``
        start_time (timestamp_pb2.Timestamp): Optional. Inclusive start of the
            time interval for which results should be returned. Should be set
            together with ``end_time``.
        end_time (timestamp_pb2.Timestamp): Optional. Exclusive end of the
            time interval for which results should be returned. Should be set
            together with ``start_time``.
        offset (str): Optional. Time difference from the cut-off time. For
            scheduled jobs, the cut-off time is the scheduled time. For
            non-scheduled jobs, it's the time when the job was created.
            Supported units: 'w|W': Week, 'd|D': Day, 'h|H': Hour. E.g. '1h'
            stands for 1 hour, '2d' stands for 2 days.
        window (str): Optional. Scope of data selected for analysis: the data
            time window prior to the cut-off time, or prior to the cut-off
            time minus the offset. Supported units: 'w|W': Week, 'd|D': Day,
            'h|H': Hour. E.g. '1h' stands for 1 hour, '2d' stands for 2 days.
    """

    def __init__(
        self,
        vertex_dataset: Optional[str] = None,
        gcs_uri: Optional[str] = None,
        data_format: Optional[str] = None,
        table_uri: Optional[str] = None,
        query: Optional[str] = None,
        timestamp_field: Optional[str] = None,
        batch_prediction_job: Optional[str] = None,
        endpoints: Optional[List[str]] = None,
        start_time: Optional[timestamp_pb2.Timestamp] = None,
        end_time: Optional[timestamp_pb2.Timestamp] = None,
        offset: Optional[str] = None,
        window: Optional[str] = None,
    ):
        # Data sources.
        self.vertex_dataset = vertex_dataset
        self.gcs_uri = gcs_uri
        self.data_format = data_format
        self.table_uri = table_uri
        self.query = query
        self.batch_prediction_job = batch_prediction_job
        self.endpoints = endpoints
        # Time selection.
        self.timestamp_field = timestamp_field
        self.start_time = start_time
        self.end_time = end_time
        self.offset = offset
        self.window = window

    def _as_proto(self) -> model_monitoring_spec.ModelMonitoringInput:
        """Converts ModelMonitoringInput to a proto message.

        Returns:
            The GAPIC representation of the model monitoring input.

        Raises:
            ValueError: If ``gcs_uri`` is set without a supported
                ``data_format``, or if no data source is set at all.
        """
        input_cls = model_monitoring_spec.ModelMonitoringInput

        # A relative window (offset/window) takes precedence over an absolute
        # start/end interval; at most one of the two is populated.
        relative_window = None
        absolute_interval = None
        if self.offset or self.window:
            relative_window = input_cls.TimeOffset(
                offset=self.offset or None,
                window=self.window or None,
            )
        elif self.start_time or self.end_time:
            absolute_interval = interval_pb2.Interval(
                start_time=self.start_time or None,
                end_time=self.end_time or None,
            )

        if self.vertex_dataset or self.gcs_uri or self.table_uri or self.query:
            dataset_cls = input_cls.ModelMonitoringDataset
            managed_dataset = None
            gcs_source = None
            bigquery_source = None
            if self.vertex_dataset:
                managed_dataset = self.vertex_dataset
            elif self.gcs_uri:
                if not self.data_format:
                    raise ValueError("`data_format` must be provided with gcs uri.")
                gcs_cls = dataset_cls.ModelMonitoringGcsSource
                known_formats = (
                    (CSV, gcs_cls.DataFormat.CSV),
                    (JSONL, gcs_cls.DataFormat.JSONL),
                    (TF_RECORD, gcs_cls.DataFormat.TF_RECORD),
                )
                for format_name, proto_format in known_formats:
                    if self.data_format == format_name:
                        break
                else:
                    raise ValueError(
                        (
                            "Unsupported value in data format. `data_format` "
                            "must be one of %s, %s, or %s"
                        )
                        % (TF_RECORD, CSV, JSONL)
                    )
                gcs_source = gcs_cls(
                    gcs_uri=self.gcs_uri,
                    format_=proto_format,
                )
            elif self.table_uri or self.query:
                bigquery_source = dataset_cls.ModelMonitoringBigQuerySource(
                    table_uri=self.table_uri,
                    query=self.query,
                )
            else:
                raise ValueError("At least one source of dataset must be provided.")
            return input_cls(
                columnized_dataset=dataset_cls(
                    vertex_dataset=managed_dataset,
                    gcs_source=gcs_source,
                    bigquery_source=bigquery_source,
                    timestamp_field=self.timestamp_field,
                ),
                time_offset=relative_window,
                time_interval=absolute_interval,
            )

        if self.batch_prediction_job:
            return input_cls(
                batch_prediction_output=input_cls.BatchPredictionOutput(
                    batch_prediction_job=self.batch_prediction_job,
                ),
                time_offset=relative_window,
                time_interval=absolute_interval,
            )

        if self.endpoints:
            return input_cls(
                vertex_endpoint_logs=input_cls.VertexEndpointLogs(
                    endpoints=self.endpoints,
                ),
                time_offset=relative_window,
                time_interval=absolute_interval,
            )

        raise ValueError("At least one source of dataInput must be provided.")
class TabularObjective:
    """Initializer for TabularObjective.

    Attributes:
        feature_drift_spec (DataDriftSpec): Optional. Input feature
            distribution drift monitoring spec.
        prediction_output_drift_spec (DataDriftSpec): Optional. Prediction
            output distribution drift monitoring spec.
        feature_attribution_spec (FeatureAttributionSpec): Optional. Feature
            attribution monitoring spec.
    """

    def __init__(
        self,
        feature_drift_spec: Optional[DataDriftSpec] = None,
        prediction_output_drift_spec: Optional[DataDriftSpec] = None,
        feature_attribution_spec: Optional[FeatureAttributionSpec] = None,
    ):
        self.feature_drift_spec = feature_drift_spec
        self.prediction_output_drift_spec = prediction_output_drift_spec
        self.feature_attribution_spec = feature_attribution_spec

    def _as_proto(
        self,
    ) -> model_monitoring_spec.ModelMonitoringObjectiveSpec.TabularObjective:
        """Converts TabularObjective to a proto message.

        Each sub-spec that is set is converted with its own ``_as_proto``;
        unset sub-specs are passed through as ``None``.

        Returns:
            The GAPIC representation of the model monitoring tabular objective.
        """

        def _convert(spec):
            return spec._as_proto() if spec else None

        return model_monitoring_spec.ModelMonitoringObjectiveSpec.TabularObjective(
            feature_drift_spec=_convert(self.feature_drift_spec),
            prediction_output_drift_spec=_convert(self.prediction_output_drift_spec),
            feature_attribution_spec=_convert(self.feature_attribution_spec),
        )
class ObjectiveSpec:
    """Initializer for ObjectiveSpec.

    Args:
        baseline_dataset (MonitoringInput): Required. Baseline datasets that
            are used by all the monitoring objectives. It could be the
            training dataset or production serving dataset from a previous
            period.
        target_dataset (MonitoringInput): Required. Target dataset for
            monitoring analysis, it's used by all the monitoring objectives.
        tabular_objective (TabularObjective): Optional. The tabular monitoring
            objective.
        explanation_spec (explanation.ExplanationSpec): Optional. The
            explanation spec. This spec is required when the objectives spec
            includes feature attribution objectives.
    """

    def __init__(
        self,
        baseline_dataset: MonitoringInput,
        target_dataset: MonitoringInput,
        tabular_objective: Optional[TabularObjective] = None,
        explanation_spec: Optional[explanation.ExplanationSpec] = None,
    ):
        # Stored under shorter attribute names than the constructor arguments.
        self.baseline = baseline_dataset
        self.target = target_dataset
        self.tabular_objective = tabular_objective
        self.explanation_spec = explanation_spec

    def _as_proto(self) -> model_monitoring_spec.ModelMonitoringObjectiveSpec:
        """Converts ModelMonitoringObjectiveSpec to a proto message.

        Returns:
            The GAPIC representation of the model monitoring objective config.

        Raises:
            ValueError: If the baseline or the target dataset is missing.
        """
        if not self.baseline or not self.target:
            # The check is about the two datasets, so the message names them
            # rather than talking about objectives.
            raise ValueError(
                "Both `baseline_dataset` and `target_dataset` must be provided."
            )
        user_tabular_objective = (
            self.tabular_objective._as_proto() if self.tabular_objective else None
        )
        return model_monitoring_spec.ModelMonitoringObjectiveSpec(
            tabular_objective=user_tabular_objective,
            explanation_spec=self.explanation_spec or None,
            target_dataset=self.target._as_proto(),
            baseline_dataset=self.baseline._as_proto(),
        )
class OutputSpec:
    """Initializer for OutputSpec.

    Args:
        gcs_base_dir (str):
            Required. Google Cloud Storage base folder path for metrics, error
            logs, etc.
    """

    def __init__(
        self,
        gcs_base_dir: str,
    ):
        self.gcs_base_dir = gcs_base_dir

    def _as_proto(self) -> model_monitoring_spec.ModelMonitoringOutputSpec:
        """Converts ModelMonitoringOutputSpec to a proto message.

        Returns:
            The GAPIC representation of the model monitoring output spec, with
            ``gcs_base_dir`` used as the GCS destination prefix.
        """
        return model_monitoring_spec.ModelMonitoringOutputSpec(
            gcs_base_directory=io.GcsDestination(
                output_uri_prefix=self.gcs_base_dir
            ),
        )
class FieldSchema:
    """Schema of a single field.

    Describes the data type of one feature; a list of these makes up each
    field group of a ModelMonitoringSchema.

    Attributes:
        name (str):
            Required. Field name.
        data_type (str):
            Required. Supported data types are: ``float``, ``integer``,
            ``boolean``, ``string``, ``categorical``.
        repeated (bool):
            Optional. Describes if the schema field is an array of given data
            type.
    """

    def __init__(
        self,
        name: str,
        data_type: str,
        repeated: Optional[bool] = False,
    ):
        self.name, self.data_type, self.repeated = name, data_type, repeated

    def _as_proto(self) -> model_monitor.ModelMonitoringSchema.FieldSchema:
        """Converts ModelMonitoringSchema.FieldSchema to a proto message.

        Returns:
            The GAPIC representation of the model monitoring field schema.
        """
        proto_cls = model_monitor.ModelMonitoringSchema.FieldSchema
        return proto_cls(
            name=self.name,
            data_type=self.data_type,
            repeated=self.repeated,
        )
class ModelMonitoringSchema:
    """Initializer for ModelMonitoringSchema.

    Args:
        feature_fields (MutableSequence[FieldSchema]):
            Required. Feature names of the model. Gemini Enterprise Agent
            Platform will try to match the features from your dataset as
            follows:
            * For 'csv' files, the header names are required, and we will
              extract the corresponding feature values when the header names
              align with the feature names.
            * For 'jsonl' files, we will extract the corresponding feature
              values if the key names match the feature names. Note: Nested
              features are not supported, so please ensure your features are
              flattened. Ensure the feature values are scalar or an array of
              scalars.
            * For 'bigquery' dataset, we will extract the corresponding
              feature values if the column names match the feature names.
              Note: The column type can be a scalar or an array of scalars.
              STRUCT or JSON types are not supported. You may use SQL queries
              to select or aggregate the relevant features from your original
              table. However, ensure that the 'schema' of the query results
              meets our requirements.
            * For the Gemini Enterprise Agent Platform Endpoint Request
              Response Logging table or Gemini Enterprise Agent Platform Batch
              Prediction Job results: if the prediction instance format is an
              array, ensure that the sequence in ``feature_fields`` matches
              the order of features in the prediction instance. We will match
              the feature with the array in the order specified in
              ``feature_fields``.
        ground_truth_fields (MutableSequence[FieldSchema]):
            Optional. Target / ground truth names of the model.
        prediction_fields (MutableSequence[FieldSchema]):
            Optional. Prediction output names of the model. The requirements
            are the same as the ``feature_fields``.
            For AutoML Tables, the prediction output name presented in schema
            will be: `predicted_{target_column}`, the `target_column` is the
            one you specified when you train the model.
            For Prediction output drift analysis:
            * AutoML Classification, the distribution of the argmax label
              will be analyzed.
            * AutoML Regression, the distribution of the value will be
              analyzed.
    """

    def __init__(
        self,
        feature_fields: MutableSequence[FieldSchema],
        ground_truth_fields: Optional[MutableSequence[FieldSchema]] = None,
        prediction_fields: Optional[MutableSequence[FieldSchema]] = None,
    ):
        self.feature_fields = feature_fields
        self.prediction_fields = prediction_fields
        self.ground_truth_fields = ground_truth_fields

    def _as_proto(self) -> model_monitor.ModelMonitoringSchema:
        """Converts ModelMonitoringSchema to a proto message.

        Empty or unset prediction / ground-truth field lists are passed to the
        proto as ``None``.

        Returns:
            The GAPIC representation of the model monitoring schema.
        """
        return model_monitor.ModelMonitoringSchema(
            feature_fields=[field._as_proto() for field in self.feature_fields],
            prediction_fields=(
                [field._as_proto() for field in self.prediction_fields]
                if self.prediction_fields
                else None
            ),
            ground_truth_fields=(
                [field._as_proto() for field in self.ground_truth_fields]
                if self.ground_truth_fields
                else None
            ),
        )

    def to_json(self, output_dir: Optional[str] = None) -> str:
        """Transform ModelMonitoringSchema to json format.

        Args:
            output_dir (str):
                Optional. The output directory that the transformed json file
                would be put into. The file is named
                ``model_monitoring_schema.json``.

        Returns:
            The schema serialized as a JSON string.

        Raises:
            ImportError: If ``output_dir`` is given but TensorFlow, which is
                used to write the file, is not installed.
        """
        result = model_monitor.ModelMonitoringSchema.to_json(self._as_proto())
        if output_dir:
            if tf is None:
                raise ImportError(
                    "TensorFlow is required to write the schema to `output_dir`."
                )
            result_path = os.path.join(output_dir, "model_monitoring_schema.json")
            with tf.io.gfile.GFile(result_path, "w") as f:
                # ``result`` is already serialized JSON; running it through
                # json.dump again would store a quoted, escaped string instead
                # of a JSON object.
                f.write(result)
            logging.info("Transformed schema to json file: %s", result_path)
        return result
+ """ + result = model_monitor.ModelMonitoringSchema.to_json(self._as_proto()) + if output_dir: + result_path = os.path.join(output_dir, "model_monitoring_schema.json") + with tf.io.gfile.GFile(result_path, "w") as f: + json.dump(result, f) + f.close() + logging.info("Transformed schema to json file: %s", result_path) + return result + + +def _check_duplicate( + field: str, + feature_fields: Optional[List[str]] = None, + ground_truth_fields: Optional[List[str]] = None, + prediction_fields: Optional[List[str]] = None, +) -> bool: + """Check if a field appears in two field lists.""" + feature = True + ground_truth = True + prediction = True + if not feature_fields or field not in feature_fields: + feature = False + if not ground_truth_fields or field not in ground_truth_fields: + ground_truth = False + if not prediction_fields or field not in prediction_fields: + prediction = False + return feature if (feature == ground_truth) else prediction + + +def _transform_schema_pandas( + dataset: Dict[str, str], + feature_fields: Optional[List[str]] = None, + ground_truth_fields: Optional[List[str]] = None, + prediction_fields: Optional[List[str]] = None, +) -> ModelMonitoringSchema: + """Transforms the pandas schema to model monitoring schema.""" + ground_truth_fields_list = list() + prediction_fields_list = list() + feature_fields_list = list() + pandas_integer_types = ["integer", "Int32", "Int64", "UInt32", "UInt64"] + pandas_string_types = [ + "string", + "bytes", + "date", + "time", + "datetime64", + "datetime", + "mixed-integer", + "inteval", + "Interval", + ] + pandas_float_types = [ + "floating", + "decimal", + "mixed-integer-float", + "Float32", + "Float64", + ] + for field in dataset: + infer_type = dataset[field] + if infer_type in pandas_string_types: + data_type = "string" + elif infer_type in pandas_integer_types: + data_type = "integer" + elif infer_type in pandas_float_types: + data_type = "float" + elif infer_type == "boolean": + data_type = "boolean" + elif 
def _transform_schema_pandas(
    dataset: Dict[str, str],
    feature_fields: Optional[List[str]] = None,
    ground_truth_fields: Optional[List[str]] = None,
    prediction_fields: Optional[List[str]] = None,
) -> ModelMonitoringSchema:
    """Transforms the pandas schema to model monitoring schema.

    Args:
        dataset: Mapping from column name to that column's pandas type
            (a type name, or an object that compares equal to one).
        feature_fields: Optional. Columns to treat as features. When omitted,
            every column that is not a ground-truth or prediction field is a
            feature; when given, columns in none of the lists are dropped.
        ground_truth_fields: Optional. Columns to treat as ground truth.
        prediction_fields: Optional. Columns to treat as prediction outputs.

    Returns:
        The ModelMonitoringSchema built from the classified columns.

    Raises:
        ValueError: If a column has an unsupported type, or is listed in two
            or more of the field lists.
    """
    # Matching deliberately uses ``==`` against each name (sequence
    # membership) rather than hashing, so objects that merely compare equal
    # to these names keep matching.
    pandas_string_types = (
        "string",
        "bytes",
        "date",
        "time",
        "datetime64",
        "datetime",
        "mixed-integer",
        "interval",  # was misspelled "inteval", which could never match
        "Interval",
    )
    pandas_integer_types = ("integer", "Int32", "Int64", "UInt32", "UInt64")
    pandas_float_types = (
        "floating",
        "decimal",
        "mixed-integer-float",
        "Float32",
        "Float64",
    )
    type_groups = (
        (pandas_string_types, "string"),
        (pandas_integer_types, "integer"),
        (pandas_float_types, "float"),
        (("boolean",), "boolean"),
        (("categorical", "category"), "categorical"),
    )

    ground_truth_fields_list = []
    prediction_fields_list = []
    feature_fields_list = []
    for field, infer_type in dataset.items():
        for candidates, mapped_type in type_groups:
            if infer_type in candidates:
                data_type = mapped_type
                break
        else:
            raise ValueError(f"Unsupported data type: {infer_type}")
        if _check_duplicate(
            field, feature_fields, ground_truth_fields, prediction_fields
        ):
            raise ValueError(f"The field {field} specified in two or more field lists")

        # Ground truth wins over prediction, which wins over feature.
        if ground_truth_fields and field in ground_truth_fields:
            destination = ground_truth_fields_list
        elif prediction_fields and field in prediction_fields:
            destination = prediction_fields_list
        elif not feature_fields or field in feature_fields:
            destination = feature_fields_list
        else:
            continue
        destination.append(FieldSchema(name=field, data_type=data_type))

    return ModelMonitoringSchema(
        ground_truth_fields=ground_truth_fields_list if ground_truth_fields else None,
        prediction_fields=prediction_fields_list if prediction_fields else None,
        feature_fields=feature_fields_list,
    )
def transform_schema_from_bigquery(
    feature_fields: Optional[List[str]] = None,
    ground_truth_fields: Optional[List[str]] = None,
    prediction_fields: Optional[List[str]] = None,
    table: Optional[str] = None,
    query: Optional[str] = None,
) -> ModelMonitoringSchema:
    """Transform the existing dataset to ModelMonitoringSchema as model monitor
    could accept.

    The schema is read from ``table`` when given, otherwise from a dry run of
    ``query``.

    Args:
        feature_fields (List[str]):
            Optional. The input feature fields for given dataset.
            By default all columns that are not ground-truth or prediction
            fields are the input features.
        ground_truth_fields (List[str]):
            Optional. The ground truth fields for given dataset.
        prediction_fields (List[str]):
            Optional. The prediction output fields for given dataset.
        table (str):
            Optional. The BigQuery table uri. A leading ``bq://`` is stripped.
        query (str):
            Optional. The BigQuery query.

    Returns:
        The ModelMonitoringSchema built from the BigQuery schema.

    Raises:
        ValueError: If neither ``table`` nor ``query`` is given, the schema
            cannot be fetched, a column type is unsupported, or a column is
            listed in two or more of the field lists.
    """
    bq_string_types = [
        "STRING",
        "BYTES",
        "DATE",
        "TIME",
        "GEOGRAPHY",
        "DATETIME",
        "JSON",
        "INTERVAL",  # was misspelled "INTEVAL", so INTERVAL columns were rejected
        "RANGE",
    ]
    bq_integer_types = ["INTEGER", "INT64", "TIMESTAMP"]
    bq_float_types = ["FLOAT", "DOUBLE", "FLOAT64", "NUMERIC", "BIGNUMERIC"]
    bq_boolean_types = ["BOOLEAN", "BOOL"]
    data_type_by_bq_type = {}
    for bq_types, mapped_type in (
        (bq_string_types, "string"),
        (bq_integer_types, "integer"),
        (bq_float_types, "float"),
        (bq_boolean_types, "boolean"),
    ):
        data_type_by_bq_type.update(dict.fromkeys(bq_types, mapped_type))

    if table:
        table_id = table[len("bq://") :] if table.startswith("bq://") else table
        try:
            client = bigquery.Client()
            bq_schema = client.get_table(table_id).schema
        except Exception as e:
            raise ValueError("Failed to get table from bq address provided.") from e
    elif query:
        try:
            client = bigquery.Client()
            # A dry run yields the result schema without executing the query.
            bq_schema = client.query(
                query=query, job_config=bigquery.job.QueryJobConfig(dry_run=True)
            ).schema
        except Exception as e:
            raise ValueError("Failed to get query from bq address provided.") from e
    else:
        raise ValueError("Either table or query must be provided.")

    ground_truth_fields_list = []
    prediction_fields_list = []
    feature_fields_list = []
    for field in bq_schema:
        data_type = data_type_by_bq_type.get(field.field_type)
        if data_type is None:
            raise ValueError(f"Unsupported data type: {field.field_type}")
        if _check_duplicate(
            field.name, feature_fields, ground_truth_fields, prediction_fields
        ):
            raise ValueError(
                f"The field {field.name} specified in two or more field lists"
            )

        # Ground truth wins over prediction, which wins over feature.
        if ground_truth_fields and field.name in ground_truth_fields:
            destination = ground_truth_fields_list
        elif prediction_fields and field.name in prediction_fields:
            destination = prediction_fields_list
        elif not feature_fields or field.name in feature_fields:
            destination = feature_fields_list
        else:
            continue
        destination.append(
            FieldSchema(
                name=field.name,
                data_type=data_type,
                repeated=field.mode == "REPEATED",
            )
        )

    return ModelMonitoringSchema(
        ground_truth_fields=ground_truth_fields_list if ground_truth_fields else None,
        prediction_fields=prediction_fields_list if prediction_fields else None,
        feature_fields=feature_fields_list,
    )
def transform_schema_from_csv(
    file_path: str,
    feature_fields: Optional[List[str]] = None,
    ground_truth_fields: Optional[List[str]] = None,
    prediction_fields: Optional[List[str]] = None,
) -> ModelMonitoringSchema:
    """Transform the existing dataset to ModelMonitoringSchema as model monitor could accept.

    Args:
        file_path (str):
            Required. The dataset file path.
        feature_fields (List[str]):
            Optional. The input feature fields for given dataset.
            By default all columns that are not ground-truth or prediction
            fields are the input features.
        ground_truth_fields (List[str]):
            Optional. The ground truth fields for given dataset.
        prediction_fields (List[str]):
            Optional. The prediction output fields for given dataset.

    Returns:
        The ModelMonitoringSchema inferred from the CSV columns.

    Raises:
        ImportError: If pandas or TensorFlow is not installed.
    """
    if pd is None or tf is None:
        raise ImportError(
            "pandas and TensorFlow are required to infer a schema from a file."
        )
    with tf.io.gfile.GFile(file_path, "r") as f:
        input_dataset = pd.read_csv(f)
    # Convert once for the whole frame instead of once per column.
    converted_dtypes = input_dataset.convert_dtypes().dtypes
    dict_dataset = {
        column: converted_dtypes[column] for column in input_dataset.columns
    }
    return _transform_schema_pandas(
        dict_dataset, feature_fields, ground_truth_fields, prediction_fields
    )
def transform_schema_from_json(
    file_path: str,
    feature_fields: Optional[List[str]] = None,
    ground_truth_fields: Optional[List[str]] = None,
    prediction_fields: Optional[List[str]] = None,
) -> ModelMonitoringSchema:
    """Transform the existing dataset to ModelMonitoringSchema as model monitor
    could accept.

    The file is read as JSON Lines (one JSON record per line).

    Args:
        file_path (str):
            Required. The dataset file path.
        feature_fields (List[str]):
            Optional. The input feature fields for given dataset.
            By default all columns that are not ground-truth or prediction
            fields are the input features.
        ground_truth_fields (List[str]):
            Optional. The ground truth fields for given dataset.
        prediction_fields (List[str]):
            Optional. The prediction output fields for given dataset.

    Returns:
        The ModelMonitoringSchema inferred from the JSON Lines columns.

    Raises:
        ImportError: If pandas or TensorFlow is not installed.
    """
    if pd is None or tf is None:
        raise ImportError(
            "pandas and TensorFlow are required to infer a schema from a file."
        )
    with tf.io.gfile.GFile(file_path, "r") as f:
        input_dataset = pd.read_json(f, lines=True)
    # Convert once for the whole frame instead of once per column.
    converted_dtypes = input_dataset.convert_dtypes().dtypes
    dict_dataset = {
        column: converted_dtypes[column] for column in input_dataset.columns
    }
    return _transform_schema_pandas(
        dict_dataset, feature_fields, ground_truth_fields, prediction_fields
    )
# Client class whose methods most of the fixtures below patch.
_FOS_ADMIN_CLIENT = (
    feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient
)
_FEATURE_REGISTRY_CLIENT = feature_registry_service_client.FeatureRegistryServiceClient


def _patched_client_method(client_cls, method_name, result, as_lro=False):
    """Patch ``method_name`` on ``client_cls`` and yield the resulting mock.

    With ``as_lro`` the patched method returns a mocked long-running operation
    whose ``result()`` is ``result``; otherwise it returns ``result`` directly.
    The patch is undone when the generator is resumed after the yield.
    """
    with patch.object(client_cls, method_name) as method_mock:
        if as_lro:
            lro_mock = mock.Mock(ga_operation.Operation)
            lro_mock.result.return_value = result
            method_mock.return_value = lro_mock
        else:
            method_mock.return_value = result
        yield method_mock


# --- Feature online store: get ---


@pytest.fixture
def get_fos_mock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT, "get_feature_online_store", _TEST_BIGTABLE_FOS1
    )


@pytest.fixture
def get_esf_optimized_fos_mock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT, "get_feature_online_store", _TEST_ESF_OPTIMIZED_FOS
    )


@pytest.fixture
def get_psc_optimized_fos_mock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT, "get_feature_online_store", _TEST_PSC_OPTIMIZED_FOS
    )


@pytest.fixture
def get_esf_optimized_fos_no_endpoint_mock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT, "get_feature_online_store", _TEST_ESF_OPTIMIZED_FOS2
    )


# --- Feature online store: create (long-running operations) ---


@pytest.fixture
def create_bigtable_fos_mock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT,
        "create_feature_online_store",
        _TEST_BIGTABLE_FOS1,
        as_lro=True,
    )


@pytest.fixture
def create_esf_optimized_fos_mock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT,
        "create_feature_online_store",
        _TEST_ESF_OPTIMIZED_FOS,
        as_lro=True,
    )


@pytest.fixture
def create_psc_optimized_fos_mock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT,
        "create_feature_online_store",
        _TEST_PSC_OPTIMIZED_FOS,
        as_lro=True,
    )


# --- Feature view: get / list ---


@pytest.fixture
def get_fv_mock():
    yield from _patched_client_method(_FOS_ADMIN_CLIENT, "get_feature_view", _TEST_FV1)


@pytest.fixture
def get_rag_fv_mock():
    yield from _patched_client_method(_FOS_ADMIN_CLIENT, "get_feature_view", _TEST_FV3)


@pytest.fixture
def get_registry_fv_mock():
    yield from _patched_client_method(_FOS_ADMIN_CLIENT, "get_feature_view", _TEST_FV4)


@pytest.fixture
def list_fv_mock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT, "list_feature_views", _TEST_FV_LIST
    )


# --- Feature view: create (long-running operations) ---


@pytest.fixture
def create_bq_fv_mock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT, "create_feature_view", _TEST_FV1, as_lro=True
    )


@pytest.fixture
def create_rag_fv_mock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT, "create_feature_view", _TEST_FV3, as_lro=True
    )


@pytest.fixture
def create_registry_fv_mock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT, "create_feature_view", _TEST_FV4, as_lro=True
    )


@pytest.fixture
def create_embedding_fv_from_bq_mock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT,
        "create_feature_view",
        _TEST_OPTIMIZED_EMBEDDING_FV,
        as_lro=True,
    )


# --- Feature view: optimized / embedding variants ---


@pytest.fixture
def get_optimized_embedding_fv_mock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT, "get_feature_view", _TEST_OPTIMIZED_EMBEDDING_FV
    )


@pytest.fixture
def get_optimized_fv_mock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT, "get_feature_view", _TEST_OPTIMIZED_FV1
    )


@pytest.fixture
def get_embedding_fv_mock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT, "get_feature_view", _TEST_EMBEDDING_FV1
    )


@pytest.fixture
def get_optimized_fv_no_endpointmock():
    yield from _patched_client_method(
        _FOS_ADMIN_CLIENT, "get_feature_view", _TEST_OPTIMIZED_FV2
    )


# --- Feature registry ---


@pytest.fixture
def get_fg_mock():
    yield from _patched_client_method(
        _FEATURE_REGISTRY_CLIENT, "get_feature_group", _TEST_FG1
    )


@pytest.fixture
def get_feature_mock():
    yield from _patched_client_method(
        _FEATURE_REGISTRY_CLIENT, "get_feature", _TEST_FG1_F1
    )


@pytest.fixture
def get_feature_with_version_column_mock():
    yield from _patched_client_method(
        _FEATURE_REGISTRY_CLIENT, "get_feature", _TEST_FG1_F2
    )


@pytest.fixture
def get_feature_monitor_mock():
    # Patches the directly imported v1beta1 client class, as before.
    yield from _patched_client_method(
        FeatureRegistryServiceClient, "get_feature_monitor", _TEST_FG1_FM1
    )


@pytest.fixture
def base_logger_mock():
    # ``wraps`` keeps the real logger call while recording invocations.
    real_info = aiplatform_base._LOGGER.info
    with patch.object(aiplatform_base._LOGGER, "info", wraps=real_info) as logger_mock:
        yield logger_mock
a/tests/unit/agentplatform/feature_store_constants.py b/tests/unit/agentplatform/feature_store_constants.py new file mode 100644 index 0000000000..15e627f085 --- /dev/null +++ b/tests/unit/agentplatform/feature_store_constants.py @@ -0,0 +1,498 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.cloud.aiplatform.compat import types + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" + +# Test feature online store 1 +_TEST_BIGTABLE_FOS1_ID = "my_fos1" +_TEST_BIGTABLE_FOS1_PATH = ( + f"{_TEST_PARENT}/featureOnlineStores/{_TEST_BIGTABLE_FOS1_ID}" +) +_TEST_BIGTABLE_FOS1_LABELS = {"my_key": "my_fos1"} +_TEST_BIGTABLE_FOS1 = types.feature_online_store.FeatureOnlineStore( + name=_TEST_BIGTABLE_FOS1_PATH, + bigtable=types.feature_online_store.FeatureOnlineStore.Bigtable( + auto_scaling=types.feature_online_store.FeatureOnlineStore.Bigtable.AutoScaling( + min_node_count=1, + max_node_count=2, + cpu_utilization_target=50, + ) + ), + labels=_TEST_BIGTABLE_FOS1_LABELS, +) + +# Test feature online store 2 +_TEST_BIGTABLE_FOS2_ID = "my_fos2" +_TEST_BIGTABLE_FOS2_PATH = ( + f"{_TEST_PARENT}/featureOnlineStores/{_TEST_BIGTABLE_FOS2_ID}" +) +_TEST_BIGTABLE_FOS2_LABELS = {"my_key": "my_fos2"} +_TEST_BIGTABLE_FOS2 = types.feature_online_store.FeatureOnlineStore( + name=_TEST_BIGTABLE_FOS2_PATH, + 
bigtable=types.feature_online_store.FeatureOnlineStore.Bigtable( + auto_scaling=types.feature_online_store.FeatureOnlineStore.Bigtable.AutoScaling( + min_node_count=2, + max_node_count=3, + cpu_utilization_target=60, + ) + ), + labels=_TEST_BIGTABLE_FOS2_LABELS, +) + +# Test feature online store 3 +_TEST_BIGTABLE_FOS3_ID = "my_fos3" +_TEST_BIGTABLE_FOS3_PATH = ( + f"{_TEST_PARENT}/featureOnlineStores/{_TEST_BIGTABLE_FOS3_ID}" +) +_TEST_BIGTABLE_FOS3_LABELS = {"my_key": "my_fos3"} +_TEST_BIGTABLE_FOS3 = types.feature_online_store.FeatureOnlineStore( + name=_TEST_BIGTABLE_FOS3_PATH, + bigtable=types.feature_online_store.FeatureOnlineStore.Bigtable( + auto_scaling=types.feature_online_store.FeatureOnlineStore.Bigtable.AutoScaling( + min_node_count=3, + max_node_count=4, + cpu_utilization_target=70, + ) + ), + labels=_TEST_BIGTABLE_FOS3_LABELS, +) + +# Test feature online store for optimized with esf endpoint +_TEST_ESF_OPTIMIZED_FOS_ID = "my_esf_optimized_fos" +_TEST_ESF_OPTIMIZED_FOS_PATH = ( + f"{_TEST_PARENT}/featureOnlineStores/{_TEST_ESF_OPTIMIZED_FOS_ID}" +) +_TEST_ESF_OPTIMIZED_FOS_LABELS = {"my_key": "my_esf_optimized_fos"} +_TEST_ESF_OPTIMIZED_FOS = types.feature_online_store.FeatureOnlineStore( + name=_TEST_ESF_OPTIMIZED_FOS_PATH, + optimized=types.feature_online_store.FeatureOnlineStore.Optimized(), + dedicated_serving_endpoint=types.feature_online_store_v1.FeatureOnlineStore.DedicatedServingEndpoint( + public_endpoint_domain_name="test-esf-endpoint", + ), + labels=_TEST_ESF_OPTIMIZED_FOS_LABELS, +) + +# Test feature online store for optimized with psc endpoint +_TEST_PSC_OPTIMIZED_FOS_ID = "my_psc_optimized_fos" +_TEST_PSC_OPTIMIZED_FOS_PATH = ( + f"{_TEST_PARENT}/featureOnlineStores/{_TEST_PSC_OPTIMIZED_FOS_ID}" +) +_TEST_PSC_OPTIMIZED_FOS_LABELS = {"my_key": "my_psc_optimized_fos"} +_TEST_PSC_PROJECT_ALLOWLIST = ["project-1", "project-2"] +_TEST_PSC_OPTIMIZED_FOS = types.feature_online_store_v1.FeatureOnlineStore( + name=_TEST_PSC_OPTIMIZED_FOS_PATH, + 
optimized=types.feature_online_store_v1.FeatureOnlineStore.Optimized(), + dedicated_serving_endpoint=types.feature_online_store_v1.FeatureOnlineStore.DedicatedServingEndpoint( + private_service_connect_config=types.service_networking_v1.PrivateServiceConnectConfig( + enable_private_service_connect=True, + project_allowlist=_TEST_PSC_PROJECT_ALLOWLIST, + ), + ), + labels=_TEST_PSC_OPTIMIZED_FOS_LABELS, +) + +_TEST_FOS_LIST = [_TEST_BIGTABLE_FOS1, _TEST_BIGTABLE_FOS2, _TEST_BIGTABLE_FOS3] + +# Test feature online store for optimized with esf endpoint but sync has not run yet. +_TEST_ESF_OPTIMIZED_FOS2_ID = "my_esf_optimised_fos2" +_TEST_ESF_OPTIMIZED_FOS2_PATH = ( + f"{_TEST_PARENT}/featureOnlineStores/{_TEST_ESF_OPTIMIZED_FOS2_ID}" +) +_TEST_ESF_OPTIMIZED_FOS2_LABELS = {"my_key": "my_esf_optimized_fos2"} +_TEST_ESF_OPTIMIZED_FOS2 = types.feature_online_store_v1.FeatureOnlineStore( + name=_TEST_ESF_OPTIMIZED_FOS2_PATH, + optimized=types.feature_online_store_v1.FeatureOnlineStore.Optimized(), + dedicated_serving_endpoint=types.feature_online_store_v1.FeatureOnlineStore.DedicatedServingEndpoint(), + labels=_TEST_ESF_OPTIMIZED_FOS_LABELS, +) + + +# Test feature view 1 +_TEST_FV1_ID = "my_fv1" +_TEST_FV1_PATH = f"{_TEST_BIGTABLE_FOS1_PATH}/featureViews/my_fv1" +_TEST_FV1_LABELS = {"my_key": "my_fv1"} +_TEST_FV1_BQ_URI = f"bq://{_TEST_PROJECT}.my_dataset.my_table" +_TEST_FV1_ENTITY_ID_COLUMNS = ["entity_id"] +_TEST_FV1 = types.feature_view.FeatureView( + name=_TEST_FV1_PATH, + big_query_source=types.feature_view.FeatureView.BigQuerySource( + uri=_TEST_FV1_BQ_URI, + entity_id_columns=_TEST_FV1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FV1_LABELS, +) + +# Test feature view 2 +_TEST_FV2_ID = "my_fv2" +_TEST_FV2_PATH = f"{_TEST_BIGTABLE_FOS1_PATH}/featureViews/my_fv2" +_TEST_FV2_LABELS = {"my_key": "my_fv2"} +_TEST_FV2_BQ_URI = f"bq://{_TEST_PROJECT}.my_dataset.my_table" +_TEST_FV2_ENTITY_ID_COLUMNS = ["entity_id"] +_TEST_FV2 = types.feature_view.FeatureView( + 
name=_TEST_FV2_PATH, + big_query_source=types.feature_view.FeatureView.BigQuerySource( + uri=_TEST_FV2_BQ_URI, + entity_id_columns=_TEST_FV2_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FV2_LABELS, +) + +# Test feature view 3 +_TEST_FV3_ID = "my_fv3" +_TEST_FV3_PATH = f"{_TEST_BIGTABLE_FOS1_PATH}/featureViews/my_fv3" +_TEST_FV3_LABELS = {"my_key": "my_fv3"} +_TEST_FV3_BQ_URI = f"bq://{_TEST_PROJECT}.my_dataset.my_table" +_TEST_FV3 = types.feature_view.FeatureView( + name=_TEST_FV3_PATH, + vertex_rag_source=types.feature_view.FeatureView.VertexRagSource( + uri=_TEST_FV3_BQ_URI, + ), + labels=_TEST_FV3_LABELS, +) + + +# Test feature view sync 1 +_TEST_FV_SYNC1_ID = "my_fv_sync1" +_TEST_FV_SYNC1_PATH = f"{_TEST_FV1_PATH}/featureViewSyncs/my_fv_sync1" +_TEST_FV_SYNC1 = types.feature_view_sync.FeatureViewSync( + name=_TEST_FV_SYNC1_PATH, +) +_TEST_FV_SYNC1_RESPONSE = ( + types.feature_online_store_admin_service.SyncFeatureViewResponse( + feature_view_sync=_TEST_FV_SYNC1_PATH, + ) +) + +# Test feature view sync 2 +_TEST_FV_SYNC2_ID = "my_fv_sync2" +_TEST_FV_SYNC2_PATH = f"{_TEST_FV2_PATH}/featureViewSyncs/my_fv_sync2" +_TEST_FV_SYNC2 = types.feature_view_sync.FeatureViewSync( + name=_TEST_FV_SYNC2_PATH, +) + +_TEST_FV_SYNC_LIST = [_TEST_FV_SYNC1, _TEST_FV_SYNC2] + +# Test optimized feature view 1 +_TEST_OPTIMIZED_FV1_ID = "optimized_fv1" +_TEST_OPTIMIZED_FV1_PATH = f"{_TEST_ESF_OPTIMIZED_FOS_PATH}/featureViews/optimized_fv1" +_TEST_OPTIMIZED_FV1 = types.feature_view.FeatureView( + name=_TEST_OPTIMIZED_FV1_PATH, + big_query_source=types.feature_view.FeatureView.BigQuerySource( + uri=_TEST_FV1_BQ_URI, + entity_id_columns=_TEST_FV1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FV1_LABELS, +) + +# Test optimized feature view 2 +_TEST_OPTIMIZED_FV2_ID = "optimized_fv2" +_TEST_OPTIMIZED_FV2_PATH = f"{_TEST_ESF_OPTIMIZED_FOS2_PATH}/featureViews/optimized_fv2" +_TEST_OPTIMIZED_FV2 = types.feature_view.FeatureView( + name=_TEST_OPTIMIZED_FV2_PATH, + 
big_query_source=types.feature_view.FeatureView.BigQuerySource( + uri=_TEST_FV1_BQ_URI, + entity_id_columns=_TEST_FV1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FV1_LABELS, +) + +# Test embedding feature view 1 +_TEST_EMBEDDING_FV1_ID = "embedding_fv1" +_TEST_EMBEDDING_FV1_PATH = f"{_TEST_ESF_OPTIMIZED_FOS_PATH}/featureViews/embedding_fv1" +_TEST_EMBEDDING_FV1 = types.feature_view.FeatureView( + name=_TEST_EMBEDDING_FV1_PATH, + big_query_source=types.feature_view.FeatureView.BigQuerySource( + uri=_TEST_FV1_BQ_URI, + entity_id_columns=_TEST_FV1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FV1_LABELS, +) + +_TEST_STRING_FILTER = ( + types.feature_online_store_service.NearestNeighborQuery.StringFilter( + name="filter_name", + allow_tokens=["allow_token_1", "allow_token_2"], + ) +) + +# Test optimized embedding feature view +_TEST_OPTIMIZED_EMBEDDING_FV_ID = "optimized_embedding_fv" +_TEST_OPTIMIZED_EMBEDDING_FV_PATH = ( + f"{_TEST_ESF_OPTIMIZED_FOS_PATH}/featureViews/{_TEST_OPTIMIZED_EMBEDDING_FV_ID}" +) +_TEST_OPTIMIZED_EMBEDDING_FV = types.feature_view.FeatureView( + name=_TEST_OPTIMIZED_EMBEDDING_FV_PATH, + big_query_source=types.feature_view.FeatureView.BigQuerySource( + uri=_TEST_FV1_BQ_URI, + entity_id_columns=_TEST_FV1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FV1_LABELS, + index_config=types.feature_view.FeatureView.IndexConfig( + embedding_column="embedding_column", + filter_columns=["col1", "col2"], + crowding_column="crowding_column", + embedding_dimension=123, + distance_measure_type=types.feature_view.FeatureView.IndexConfig.DistanceMeasureType.DOT_PRODUCT_DISTANCE, + ), +) + +# Response for FetchFeatureValues +_TEST_FV_FETCH1 = types.feature_online_store_service_v1.FetchFeatureValuesResponse( + key_values=types.feature_online_store_service_v1.FetchFeatureValuesResponse.FeatureNameValuePairList( + features=[ + types.feature_online_store_service_v1.FetchFeatureValuesResponse.FeatureNameValuePairList.FeatureNameValuePair( + name="key1", + 
value=types.featurestore_online_service.FeatureValue( + string_value="value1", + ), + ), + ] + ) +) + +# Response for SearchNearestEntitiesResponse +_TEST_FV_SEARCH1 = types.feature_online_store_service_v1.SearchNearestEntitiesResponse( + nearest_neighbors=types.feature_online_store_service_v1.NearestNeighbors( + neighbors=[ + types.feature_online_store_service_v1.NearestNeighbors.Neighbor( + entity_id="neighbor_entity_id_1", + distance=0.1, + ), + ] + ) +) + +_TEST_FG1_ID = "my_fg1" +_TEST_FG1_PATH = f"{_TEST_PARENT}/featureGroups/{_TEST_FG1_ID}" +_TEST_FG1_BQ_URI = f"bq://{_TEST_PROJECT}.my_dataset.my_table_for_fg1" +_TEST_FG1_ENTITY_ID_COLUMNS = ["entity_id"] +_TEST_FG1_LABELS = {"my_key": "my_fg1"} +_TEST_FG1 = types.feature_group.FeatureGroup( + name=_TEST_FG1_PATH, + big_query=types.feature_group.FeatureGroup.BigQuery( + big_query_source=types.io.BigQuerySource( + input_uri=_TEST_FG1_BQ_URI, + ), + entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FG1_LABELS, +) + + +_TEST_FG2_ID = "my_fg2" +_TEST_FG2_F1_ID = "my_fg2_f1" +_TEST_FG2_F2_ID = "my_fg2_f2" +_TEST_FG2_PATH = f"{_TEST_PARENT}/featureGroups/{_TEST_FG2_ID}" +_TEST_FG2_BQ_URI = f"bq://{_TEST_PROJECT}.my_dataset.my_table_for_fg2" +_TEST_FG2_ENTITY_ID_COLUMNS = ["entity_id1", "entity_id2"] +_TEST_FG2_LABELS = {"my_key2": "my_fg2"} +_TEST_FG2 = types.feature_group.FeatureGroup( + name=_TEST_FG2_PATH, + big_query=types.feature_group.FeatureGroup.BigQuery( + big_query_source=types.io.BigQuerySource( + input_uri=_TEST_FG2_BQ_URI, + ), + entity_id_columns=_TEST_FG2_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FG2_LABELS, +) + + +_TEST_FG3_ID = "my_fg3" +_TEST_FG3_PATH = f"{_TEST_PARENT}/featureGroups/{_TEST_FG3_ID}" +_TEST_FG3_BQ_URI = f"bq://{_TEST_PROJECT}.my_dataset.my_table_for_fg3" +_TEST_FG3_ENTITY_ID_COLUMNS = ["entity_id1", "entity_id2", "entity_id3"] +_TEST_FG3_LABELS = {"my_key3": "my_fg3"} +_TEST_FG3 = types.feature_group.FeatureGroup( + name=_TEST_FG3_PATH, + 
big_query=types.feature_group.FeatureGroup.BigQuery( + big_query_source=types.io.BigQuerySource( + input_uri=_TEST_FG3_BQ_URI, + ), + entity_id_columns=_TEST_FG3_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FG3_LABELS, +) + +_TEST_FG_LIST = [_TEST_FG1, _TEST_FG2, _TEST_FG3] + +_TEST_FG1_F1_ID = "my_fg1_f1" +_TEST_FG1_F1_PATH = ( + f"{_TEST_PARENT}/featureGroups/{_TEST_FG1_ID}/features/{_TEST_FG1_F1_ID}" +) +_TEST_FG1_F1_DESCRIPTION = "My feature 1 in feature group 1" +_TEST_FG1_F1_LABELS = {"my_fg1_feature": "f1"} +_TEST_FG1_F1_POINT_OF_CONTACT = "fg1-f1-announce-list" +_TEST_FG1_F1 = types.feature.Feature( + name=_TEST_FG1_F1_PATH, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, +) + + +_TEST_FG1_F2_ID = "my_fg1_f2" +_TEST_FG1_F2_PATH = ( + f"{_TEST_PARENT}/featureGroups/{_TEST_FG1_ID}/features/{_TEST_FG1_F2_ID}" +) +_TEST_FG1_F2_DESCRIPTION = "My feature 2 in feature group 1" +_TEST_FG1_F2_LABELS = {"my_fg1_feature": "f2"} +_TEST_FG1_F2_POINT_OF_CONTACT = "fg1-f2-announce-list" +_TEST_FG1_F2_VERSION_COLUMN_NAME = "specific_column_for_feature_2" +_TEST_FG1_F2 = types.feature.Feature( + name=_TEST_FG1_F2_PATH, + version_column_name=_TEST_FG1_F2_VERSION_COLUMN_NAME, + description=_TEST_FG1_F2_DESCRIPTION, + labels=_TEST_FG1_F2_LABELS, + point_of_contact=_TEST_FG1_F2_POINT_OF_CONTACT, +) + +_TEST_FG1_FEATURE_LIST = [_TEST_FG1_F1, _TEST_FG1_F2] + +_TEST_FG1_FM1_ID = "my_fg1_fm1" +_TEST_FG1_FM1_PATH = ( + f"{_TEST_PARENT}/featureGroups/{_TEST_FG1_ID}/featureMonitors/{_TEST_FG1_FM1_ID}" +) +_TEST_FG1_FM1_DESCRIPTION = "My feature monitor 1 in feature group 1" +_TEST_FG1_FM1_LABELS = {"my_fg1_feature_monitor": "fm1"} +_TEST_FG1_FM1 = types.feature_monitor.FeatureMonitor( + name=_TEST_FG1_FM1_PATH, + description=_TEST_FG1_FM1_DESCRIPTION, + labels=_TEST_FG1_FM1_LABELS, + schedule_config=types.feature_monitor.ScheduleConfig(cron="0 0 * * *"), + 
feature_selection_config=types.feature_monitor.FeatureSelectionConfig( + feature_configs=[ + types.feature_monitor.FeatureSelectionConfig.FeatureConfig( + feature_id="my_fg1_f1", + drift_threshold=0.3, + ), + types.feature_monitor.FeatureSelectionConfig.FeatureConfig( + feature_id="my_fg1_f2", + drift_threshold=0.4, + ), + ] + ), +) +_TEST_FG1_FM1_FEATURE_SELECTION_CONFIGS = [("my_fg1_f1", 0.3), ("my_fg1_f2", 0.4)] +_TEST_FG1_FM1_SCHEDULE_CONFIG = "0 0 * * *" +_TEST_FG1_FM2_ID = "my_fg1_fm2" +_TEST_FG1_FM2_PATH = ( + f"{_TEST_PARENT}/featureGroups/{_TEST_FG1_ID}/featureMonitors/{_TEST_FG1_FM2_ID}" +) +_TEST_FG1_FM2_DESCRIPTION = "My feature monitor 2 in feature group 1" +_TEST_FG1_FM2_LABELS = {"my_fg1_feature_monitor": "fm2"} +_TEST_FG1_FM2_FEATURE_SELECTION_CONFIGS = [("my_fg1_f2", 0.5)] +_TEST_FG1_FM2_SCHEDULE_CONFIG = "8 0 * * *" +_TEST_FG1_FM2 = types.feature_monitor.FeatureMonitor( + name=_TEST_FG1_FM2_PATH, + description=_TEST_FG1_FM2_DESCRIPTION, + labels=_TEST_FG1_FM2_LABELS, + schedule_config=types.feature_monitor.ScheduleConfig(cron="8 0 * * *"), + feature_selection_config=types.feature_monitor.FeatureSelectionConfig( + feature_configs=[ + types.feature_monitor.FeatureSelectionConfig.FeatureConfig( + feature_id="my_fg1_f2", + drift_threshold=0.5, + ), + ] + ), +) +_TEST_FG1_FM_LIST = [_TEST_FG1_FM1, _TEST_FG1_FM2] + +_TEST_FG1_FMJ1_ID = "1234567890" +_TEST_FG1_FMJ1_PATH = f"{_TEST_PARENT}/featureGroups/{_TEST_FG1_ID}/featureMonitors/{_TEST_FG1_FM1_ID}/featureMonitorJobs/{_TEST_FG1_FMJ1_ID}" +_TEST_FG1_FMJ1_DESCRIPTION = "My feature monitor job 1 in feature monitor 1" +_TEST_FG1_FMJ1_LABELS = {"my_fg1_feature_monitor_job": "fmj1"} +_TEST_FG1_F1_FEATURE_STATS_AND_ANOMALY = types.feature_monitor.FeatureStatsAndAnomaly( + feature_id="my_fg1_f1", + distribution_deviation=0.5, + drift_detection_threshold=0.4, + drift_detected=True, + feature_monitor_job_id=_TEST_FG1_FMJ1_ID, + feature_monitor_id=_TEST_FG1_FM1_ID, +) +_TEST_FG1_F2_FEATURE_STATS_AND_ANOMALY = 
types.feature_monitor.FeatureStatsAndAnomaly( + feature_id="my_fg1_f2", + distribution_deviation=0.2, + drift_detection_threshold=0.4, + drift_detected=False, + feature_monitor_job_id=_TEST_FG1_FMJ1_ID, + feature_monitor_id=_TEST_FG1_FM1_ID, +) +_TEST_FG1_FMJ1_FEATURE_STATS_AND_ANOMALIES = [ + _TEST_FG1_F1_FEATURE_STATS_AND_ANOMALY, + _TEST_FG1_F2_FEATURE_STATS_AND_ANOMALY, +] +_TEST_FG1_FMJ1 = types.feature_monitor_job.FeatureMonitorJob( + name=_TEST_FG1_FMJ1_PATH, + description=_TEST_FG1_FMJ1_DESCRIPTION, + labels=_TEST_FG1_FMJ1_LABELS, + job_summary=types.feature_monitor_job.FeatureMonitorJob.JobSummary( + feature_stats_and_anomalies=_TEST_FG1_FMJ1_FEATURE_STATS_AND_ANOMALIES + ), +) +_TEST_FG1_FMJ2_ID = "1234567891" +_TEST_FG1_FMJ2_PATH = f"{_TEST_PARENT}/featureGroups/{_TEST_FG1_ID}/featureMonitors/{_TEST_FG1_FM1_ID}/featureMonitorJobs/{_TEST_FG1_FMJ2_ID}" +_TEST_FG1_FMJ2_DESCRIPTION = "My feature monitor job 2 in feature monitor 1" +_TEST_FG1_FMJ2_LABELS = {"my_fg1_feature_monitor_job": "fmj2"} +_TEST_FG1_FMJ2 = types.feature_monitor_job.FeatureMonitorJob( + name=_TEST_FG1_FMJ2_PATH, + description=_TEST_FG1_FMJ2_DESCRIPTION, + labels=_TEST_FG1_FMJ2_LABELS, +) +_TEST_FG1_FMJ_LIST = [_TEST_FG1_FMJ1, _TEST_FG1_FMJ2] + +_TEST_FG1_F1_FEATURE_STATS_AND_ANOMALY = types.feature_monitor.FeatureStatsAndAnomaly( + feature_id="my_fg1_f1", + distribution_deviation=0.5, + drift_detection_threshold=0.4, + drift_detected=True, + feature_monitor_job_id="1234567890", + feature_monitor_id="1234567891", +) +_TEST_FG1_F1_WITH_STATS = types.feature_v1beta1.Feature( + name=_TEST_FG1_F1_PATH, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + feature_stats_and_anomaly=[_TEST_FG1_F1_FEATURE_STATS_AND_ANOMALY], +) + +# Test feature view 4 +_TEST_FV4_ID = "my_fv4" +_TEST_FV4_PATH = f"{_TEST_BIGTABLE_FOS1_PATH}/featureViews/my_fv4" +_TEST_FV4_LABELS = {"my_key": "my_fv4"} +_TEST_FV4 = types.feature_view.FeatureView( 
+ name=_TEST_FV4_PATH, + feature_registry_source=types.feature_view.FeatureView.FeatureRegistrySource( + feature_groups=[ + types.feature_view.FeatureView.FeatureRegistrySource.FeatureGroup( + feature_group_id=_TEST_FG1_ID, + feature_ids=[_TEST_FG1_F1_ID, _TEST_FG1_F2_ID], + ), + types.feature_view.FeatureView.FeatureRegistrySource.FeatureGroup( + feature_group_id=_TEST_FG2_ID, + feature_ids=[_TEST_FG2_F1_ID, _TEST_FG2_F2_ID], + ), + ], + ), + labels=_TEST_FV4_LABELS, +) + +_TEST_FV_LIST = [_TEST_FV1, _TEST_FV2, _TEST_FV3, _TEST_FV4] diff --git a/tests/unit/vertexai/test_feature.py b/tests/unit/agentplatform/test_feature.py similarity index 98% rename from tests/unit/vertexai/test_feature.py rename to tests/unit/agentplatform/test_feature.py index d3f3e00b4f..80882278a4 100644 --- a/tests/unit/vertexai/test_feature.py +++ b/tests/unit/agentplatform/test_feature.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -30,7 +28,7 @@ from google.cloud.aiplatform_v1beta1.services.feature_registry_service import ( FeatureRegistryServiceClient, ) -from feature_store_constants import ( +from .feature_store_constants import ( _TEST_FG1_F1_DESCRIPTION, _TEST_FG1_F1_FEATURE_STATS_AND_ANOMALY, _TEST_FG1_F1_ID, @@ -48,7 +46,7 @@ _TEST_LOCATION, _TEST_PROJECT, ) -from vertexai.resources.preview import ( +from agentplatform.resources.preview import ( Feature, FeatureGroup, ) diff --git a/tests/unit/vertexai/test_feature_group.py b/tests/unit/agentplatform/test_feature_group.py similarity index 99% rename from tests/unit/vertexai/test_feature_group.py rename to tests/unit/agentplatform/test_feature_group.py index 4ef2a74c54..795d71f950 100644 --- a/tests/unit/vertexai/test_feature_group.py +++ b/tests/unit/agentplatform/test_feature_group.py @@ -27,13 +27,13 @@ from google.cloud.aiplatform_v1beta1.services.feature_registry_service import ( FeatureRegistryServiceClient, ) -from vertexai.resources.preview.feature_store import ( +from agentplatform.resources.preview.feature_store import ( feature_group, ) -from vertexai.resources.preview import ( +from agentplatform.resources.preview import ( FeatureGroup, ) -from vertexai.resources.preview.feature_store import ( +from agentplatform.resources.preview.feature_store import ( FeatureGroupBigQuerySource, ) import pytest @@ -43,7 +43,7 @@ from google.cloud.aiplatform.compat import types -from feature_store_constants import ( +from .feature_store_constants import ( _TEST_PARENT, _TEST_PROJECT, _TEST_LOCATION, @@ -95,8 +95,8 @@ _TEST_FG1_FM2_SCHEDULE_CONFIG, _TEST_FG1_FM_LIST, ) -from test_feature import feature_eq -from test_feature_monitor import ( +from .test_feature import feature_eq +from .test_feature_monitor import ( feature_monitor_eq, ) diff --git a/tests/unit/vertexai/test_feature_monitor.py b/tests/unit/agentplatform/test_feature_monitor.py similarity index 98% rename from tests/unit/vertexai/test_feature_monitor.py rename to 
tests/unit/agentplatform/test_feature_monitor.py index b5aaa5490e..9779cb5c36 100644 --- a/tests/unit/vertexai/test_feature_monitor.py +++ b/tests/unit/agentplatform/test_feature_monitor.py @@ -22,7 +22,7 @@ from google.cloud import aiplatform from google.cloud.aiplatform import base -from feature_store_constants import ( +from .feature_store_constants import ( _TEST_PROJECT, _TEST_LOCATION, _TEST_FG1_ID, @@ -43,12 +43,12 @@ _TEST_FG1_FMJ2_LABELS, _TEST_FG1_FMJ2_PATH, ) -from vertexai.resources.preview import FeatureMonitor +from agentplatform.resources.preview import FeatureMonitor from google.cloud.aiplatform_v1beta1.services.feature_registry_service import ( FeatureRegistryServiceClient, ) from google.cloud.aiplatform.compat import types -from vertexai.resources.preview.feature_store import ( +from agentplatform.resources.preview.feature_store import ( feature_monitor, ) import pytest diff --git a/tests/unit/vertexai/test_feature_online_store.py b/tests/unit/agentplatform/test_feature_online_store.py similarity index 99% rename from tests/unit/vertexai/test_feature_online_store.py rename to tests/unit/agentplatform/test_feature_online_store.py index fbe4cf4050..12f548c8db 100644 --- a/tests/unit/vertexai/test_feature_online_store.py +++ b/tests/unit/agentplatform/test_feature_online_store.py @@ -28,7 +28,7 @@ from google.cloud.aiplatform.compat.services import ( feature_online_store_admin_service_client, ) -from feature_store_constants import ( +from .feature_store_constants import ( _TEST_BIGTABLE_FOS1_ID, _TEST_BIGTABLE_FOS1_LABELS, _TEST_BIGTABLE_FOS1_PATH, @@ -74,8 +74,8 @@ _TEST_PSC_OPTIMIZED_FOS_PATH, _TEST_PSC_PROJECT_ALLOWLIST, ) -from test_feature_view import fv_eq -from vertexai.resources.preview import ( +from .test_feature_view import fv_eq +from agentplatform.resources.preview import ( DistanceMeasureType, FeatureOnlineStore, FeatureOnlineStoreType, @@ -85,7 +85,7 @@ IndexConfig, TreeAhConfig, ) -from vertexai.resources.preview.feature_store import 
( +from agentplatform.resources.preview.feature_store import ( feature_online_store, ) import pytest diff --git a/tests/unit/vertexai/test_feature_view.py b/tests/unit/agentplatform/test_feature_view.py similarity index 99% rename from tests/unit/vertexai/test_feature_view.py rename to tests/unit/agentplatform/test_feature_view.py index eacb4e011b..aa3c5da2a4 100644 --- a/tests/unit/vertexai/test_feature_view.py +++ b/tests/unit/agentplatform/test_feature_view.py @@ -23,20 +23,20 @@ from google.cloud import aiplatform from google.cloud.aiplatform import base -from vertexai.resources.preview import ( +from agentplatform.resources.preview import ( FeatureView, ) -import vertexai.resources.preview.feature_store.utils as fs_utils +import agentplatform.resources.preview.feature_store.utils as fs_utils import pytest from google.cloud.aiplatform.compat.services import ( feature_online_store_admin_service_client, feature_online_store_service_client, ) -from vertexai.resources.preview.feature_store import ( +from agentplatform.resources.preview.feature_store import ( feature_view, ) -from feature_store_constants import ( +from .feature_store_constants import ( _TEST_BIGTABLE_FOS1_ID, _TEST_BIGTABLE_FOS1_PATH, _TEST_EMBEDDING_FV1_PATH, diff --git a/tests/unit/agentplatform/test_model_monitors.py b/tests/unit/agentplatform/test_model_monitors.py new file mode 100644 index 0000000000..842f4ae7c3 --- /dev/null +++ b/tests/unit/agentplatform/test_model_monitors.py @@ -0,0 +1,1227 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +import importlib +import os +from unittest import mock + +from google import auth +from google.api_core import operation as ga_operation +from google.auth import credentials as auth_credentials +from google.cloud import aiplatform +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import utils +from google.cloud.aiplatform.compat.services import ( + model_monitoring_service_client_v1beta1 as model_monitoring_service_client, + schedule_service_client_v1beta1 as schedule_service_client, +) +from google.cloud.aiplatform.compat.types import ( + io_v1beta1 as io, + model_monitor_v1beta1 as gca_model_monitor, + model_monitoring_alert_v1beta1 as gca_model_monitoring_alert, + model_monitoring_job_v1beta1 as gca_model_monitoring_job, + model_monitoring_service_v1beta1 as gca_model_monitoring_service, + model_monitoring_spec_v1beta1 as gca_model_monitoring_spec, + model_monitoring_stats_v1beta1 as gca_model_monitoring_stats, + schedule_service_v1beta1 as gca_schedule_service, + schedule_v1beta1 as gca_schedule, + job_state_v1beta1 as gca_job_state, + explanation_v1beta1 as explanation, +) +from agentplatform.resources.preview import ( + ml_monitoring, + ModelMonitor, + ModelMonitoringJob, +) +import pytest + +from google.protobuf import empty_pb2 # type: ignore +from google.protobuf import field_mask_pb2 # type: ignore + + +# -*- coding: utf-8 -*- + +_TEST_CREDENTIALS = mock.Mock( + spec=auth_credentials.AnonymousCredentials(), + universe_domain="googleapis.com", +) +_TEST_DESCRIPTION = "test description" +_TEST_JSON_CONTENT_TYPE = "application/json" +_TEST_LOCATION = "us-central1" +_TEST_LOCATION_2 = "europe-west4" +_TEST_PROJECT = "test-project" +_TEST_REPLICA_COUNT = 1 +_TEST_MODEL_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/models/123" +_TEST_MODEL_VERSION_ID = "1" +_TEST_MODEL_MONITOR_APP = "ortools-on-vertex-v0.1" 
+_TEST_MODEL_MONITOR_DISPLAY_NAME = "model-monitor-display-name" +_TEST_MODEL_MONITOR_USER_ID = "user_456" +_TEST_MODEL_MONITOR_ID = "456" +_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME = "job-display-name" +_TEST_MODEL_MONITORING_JOB_USER_ID = "user_789" +_TEST_MODEL_MONITORING_JOB_ID = "789" +_TEST_SCHEDULE_NAME = "000" +_TEST_OUTPUT_PATH = "tests/output_path" +_TEST_NOTIFICATION_EMAIL = "123@test.com" +_TEST_BASELINE_RESOURCE = "tests/baseline" +_TEST_TARGET_RESOURCE = "tests/target" +_TEST_TRAINING_DATASET = gca_model_monitoring_spec.ModelMonitoringInput( + columnized_dataset=gca_model_monitoring_spec.ModelMonitoringInput.ModelMonitoringDataset( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), +) +_TESTDATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "testdata") +_TEST_MODEL_MONITOR_PARENT = initializer.global_config.common_location_path( + project=_TEST_PROJECT, location=_TEST_LOCATION +) +_TEST_MODEL_MONITORING_JOB_RESOURCE_NAME = model_monitoring_service_client.ModelMonitoringServiceClient.model_monitoring_job_path( + _TEST_PROJECT, + _TEST_LOCATION, + _TEST_MODEL_MONITOR_ID, + _TEST_MODEL_MONITORING_JOB_ID, +) +_TEST_MODEL_MONITOR_RESOURCE_NAME = ( + model_monitoring_service_client.ModelMonitoringServiceClient.model_monitor_path( + _TEST_PROJECT, _TEST_LOCATION, _TEST_MODEL_MONITOR_ID + ) +) +_TEST_MODEL_MONITORING_SCHEMA = ml_monitoring.spec.ModelMonitoringSchema( + feature_fields=[ + ml_monitoring.spec.FieldSchema( + name="feature1", + data_type="string", + repeated=False, + ) + ], +) +_TEST_CREATE_MODEL_MONITOR_OBJ = gca_model_monitor.ModelMonitor( + display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME, + model_monitoring_target=gca_model_monitor.ModelMonitor.ModelMonitoringTarget( + vertex_model=gca_model_monitor.ModelMonitor.ModelMonitoringTarget.VertexModelSource( + model=_TEST_MODEL_NAME, + model_version_id=_TEST_MODEL_VERSION_ID, + ) + ), + training_dataset=gca_model_monitoring_spec.ModelMonitoringInput( + 
columnized_dataset=gca_model_monitoring_spec.ModelMonitoringInput.ModelMonitoringDataset( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + ), + tabular_objective=gca_model_monitoring_spec.ModelMonitoringObjectiveSpec.TabularObjective( + feature_drift_spec=gca_model_monitoring_spec.ModelMonitoringObjectiveSpec.DataDriftSpec( + categorical_metric_type="l_infinity", + numeric_metric_type="jensen_shannon_divergence", + default_categorical_alert_condition=gca_model_monitoring_alert.ModelMonitoringAlertCondition( + threshold=0.1, + ), + default_numeric_alert_condition=gca_model_monitoring_alert.ModelMonitoringAlertCondition( + threshold=0.2, + ), + ) + ), + output_spec=gca_model_monitoring_spec.ModelMonitoringOutputSpec( + gcs_base_directory=io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_PATH) + ), + notification_spec=gca_model_monitoring_spec.ModelMonitoringNotificationSpec( + email_config=gca_model_monitoring_spec.ModelMonitoringNotificationSpec.EmailConfig( + user_emails=[_TEST_NOTIFICATION_EMAIL] + ), + ), +) +_TEST_MODEL_MONITOR_OBJ = gca_model_monitor.ModelMonitor( + name=_TEST_MODEL_MONITOR_RESOURCE_NAME, + display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME, + model_monitoring_target=gca_model_monitor.ModelMonitor.ModelMonitoringTarget( + vertex_model=gca_model_monitor.ModelMonitor.ModelMonitoringTarget.VertexModelSource( + model=_TEST_MODEL_NAME, + model_version_id=_TEST_MODEL_VERSION_ID, + ) + ), + training_dataset=gca_model_monitoring_spec.ModelMonitoringInput( + columnized_dataset=gca_model_monitoring_spec.ModelMonitoringInput.ModelMonitoringDataset( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + ), + model_monitoring_schema=gca_model_monitor.ModelMonitoringSchema( + feature_fields=[ + gca_model_monitor.ModelMonitoringSchema.FieldSchema( + name="feature1", + data_type="string", + repeated=False, + ) + ], + ), + tabular_objective=gca_model_monitoring_spec.ModelMonitoringObjectiveSpec.TabularObjective( + 
feature_drift_spec=gca_model_monitoring_spec.ModelMonitoringObjectiveSpec.DataDriftSpec( + categorical_metric_type="l_infinity", + numeric_metric_type="jensen_shannon_divergence", + default_categorical_alert_condition=gca_model_monitoring_alert.ModelMonitoringAlertCondition( + threshold=0.1, + ), + default_numeric_alert_condition=gca_model_monitoring_alert.ModelMonitoringAlertCondition( + threshold=0.2, + ), + ) + ), + output_spec=gca_model_monitoring_spec.ModelMonitoringOutputSpec( + gcs_base_directory=io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_PATH) + ), + notification_spec=gca_model_monitoring_spec.ModelMonitoringNotificationSpec( + email_config=gca_model_monitoring_spec.ModelMonitoringNotificationSpec.EmailConfig( + user_emails=[_TEST_NOTIFICATION_EMAIL] + ), + ), + explanation_spec=explanation.ExplanationSpec( + parameters=explanation.ExplanationParameters(top_k=10) + ), +) +_TEST_UPDATED_MODEL_MONITOR_OBJ = gca_model_monitor.ModelMonitor( + name=_TEST_MODEL_MONITOR_RESOURCE_NAME, + display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME, + model_monitoring_target=gca_model_monitor.ModelMonitor.ModelMonitoringTarget( + vertex_model=gca_model_monitor.ModelMonitor.ModelMonitoringTarget.VertexModelSource( + model=_TEST_MODEL_NAME, + model_version_id=_TEST_MODEL_VERSION_ID, + ) + ), + training_dataset=gca_model_monitoring_spec.ModelMonitoringInput( + columnized_dataset=gca_model_monitoring_spec.ModelMonitoringInput.ModelMonitoringDataset( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + ), + model_monitoring_schema=gca_model_monitor.ModelMonitoringSchema( + feature_fields=[ + gca_model_monitor.ModelMonitoringSchema.FieldSchema( + name="feature1", + data_type="string", + repeated=False, + ) + ], + ), + tabular_objective=gca_model_monitoring_spec.ModelMonitoringObjectiveSpec.TabularObjective( + feature_drift_spec=gca_model_monitoring_spec.ModelMonitoringObjectiveSpec.DataDriftSpec( + categorical_metric_type="l_infinity", + numeric_metric_type="jensen_shannon_divergence", 
+ default_categorical_alert_condition=gca_model_monitoring_spec.model_monitoring_alert.ModelMonitoringAlertCondition( + threshold=0.1, + ), + default_numeric_alert_condition=gca_model_monitoring_spec.model_monitoring_alert.ModelMonitoringAlertCondition( + threshold=0.2, + ), + ) + ), + output_spec=gca_model_monitoring_spec.ModelMonitoringOutputSpec( + gcs_base_directory=io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_PATH) + ), + notification_spec=gca_model_monitoring_spec.ModelMonitoringNotificationSpec( + email_config=gca_model_monitoring_spec.ModelMonitoringNotificationSpec.EmailConfig( + user_emails=[_TEST_NOTIFICATION_EMAIL, "456@test.com"] + ), + ), + explanation_spec=explanation.ExplanationSpec( + parameters=explanation.ExplanationParameters(top_k=10) + ), +) +_TEST_CREATE_MODEL_MONITORING_JOB_OBJ = gca_model_monitoring_job.ModelMonitoringJob( + display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME, + model_monitoring_spec=gca_model_monitoring_spec.ModelMonitoringSpec( + objective_spec=gca_model_monitoring_spec.ModelMonitoringObjectiveSpec( + tabular_objective=gca_model_monitoring_spec.ModelMonitoringObjectiveSpec.TabularObjective( + feature_drift_spec=gca_model_monitoring_spec.ModelMonitoringObjectiveSpec.DataDriftSpec( + categorical_metric_type="l_infinity", + numeric_metric_type="jensen_shannon_divergence", + default_categorical_alert_condition=gca_model_monitoring_alert.ModelMonitoringAlertCondition( + threshold=0.1, + ), + default_numeric_alert_condition=gca_model_monitoring_alert.ModelMonitoringAlertCondition( + threshold=0.2, + ), + ) + ), + baseline_dataset=gca_model_monitoring_spec.ModelMonitoringInput( + columnized_dataset=gca_model_monitoring_spec.ModelMonitoringInput.ModelMonitoringDataset( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + ), + target_dataset=gca_model_monitoring_spec.ModelMonitoringInput( + columnized_dataset=gca_model_monitoring_spec.ModelMonitoringInput.ModelMonitoringDataset( + vertex_dataset=_TEST_TARGET_RESOURCE + ) + ), + 
explanation_spec=explanation.ExplanationSpec( + parameters=explanation.ExplanationParameters(top_k=10) + ), + ), + output_spec=gca_model_monitoring_spec.ModelMonitoringOutputSpec( + gcs_base_directory=io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_PATH) + ), + notification_spec=gca_model_monitoring_spec.ModelMonitoringNotificationSpec( + email_config=gca_model_monitoring_spec.ModelMonitoringNotificationSpec.EmailConfig( + user_emails=[_TEST_NOTIFICATION_EMAIL] + ) + ), + ), +) +_TEST_MODEL_MONITORING_JOB_OBJ = gca_model_monitoring_job.ModelMonitoringJob( + name=_TEST_MODEL_MONITORING_JOB_RESOURCE_NAME, + display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME, + model_monitoring_spec=gca_model_monitoring_spec.ModelMonitoringSpec( + objective_spec=gca_model_monitoring_spec.ModelMonitoringObjectiveSpec( + tabular_objective=gca_model_monitoring_spec.ModelMonitoringObjectiveSpec.TabularObjective( + feature_drift_spec=gca_model_monitoring_spec.ModelMonitoringObjectiveSpec.DataDriftSpec( + categorical_metric_type="l_infinity", + numeric_metric_type="jensen_shannon_divergence", + default_categorical_alert_condition=gca_model_monitoring_alert.ModelMonitoringAlertCondition( + threshold=0.1, + ), + default_numeric_alert_condition=gca_model_monitoring_alert.ModelMonitoringAlertCondition( + threshold=0.2, + ), + ) + ), + baseline_dataset=gca_model_monitoring_spec.ModelMonitoringInput( + columnized_dataset=gca_model_monitoring_spec.ModelMonitoringInput.ModelMonitoringDataset( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + ), + target_dataset=gca_model_monitoring_spec.ModelMonitoringInput( + columnized_dataset=gca_model_monitoring_spec.ModelMonitoringInput.ModelMonitoringDataset( + vertex_dataset=_TEST_TARGET_RESOURCE + ) + ), + explanation_spec=explanation.ExplanationSpec( + parameters=explanation.ExplanationParameters(top_k=10) + ), + ), + output_spec=gca_model_monitoring_spec.ModelMonitoringOutputSpec( + 
gcs_base_directory=io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_PATH) + ), + notification_spec=gca_model_monitoring_spec.ModelMonitoringNotificationSpec( + email_config=gca_model_monitoring_spec.ModelMonitoringNotificationSpec.EmailConfig( + user_emails=[_TEST_NOTIFICATION_EMAIL] + ) + ), + ), + state=gca_job_state.JobState.JOB_STATE_SUCCEEDED, +) +_TEST_CRON = r"America/New_York 1 \* \* \* \*" +_TEST_SCHEDULE_OBJ = gca_schedule.Schedule( + display_name=_TEST_SCHEDULE_NAME, + cron=_TEST_CRON, + create_model_monitoring_job_request=gca_model_monitoring_service.CreateModelMonitoringJobRequest( + parent=_TEST_MODEL_MONITOR_RESOURCE_NAME, + model_monitoring_job=_TEST_MODEL_MONITORING_JOB_OBJ, + ), + max_concurrent_run_count=1, +) +_TEST_UPDATED_MODEL_MONITORING_JOB_OBJ = gca_model_monitoring_job.ModelMonitoringJob( + display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME, + model_monitoring_spec=gca_model_monitoring_spec.ModelMonitoringSpec( + objective_spec=gca_model_monitoring_spec.ModelMonitoringObjectiveSpec( + tabular_objective=gca_model_monitoring_spec.ModelMonitoringObjectiveSpec.TabularObjective( + feature_drift_spec=gca_model_monitoring_spec.ModelMonitoringObjectiveSpec.DataDriftSpec( + categorical_metric_type="l_infinity", + numeric_metric_type="jensen_shannon_divergence", + default_categorical_alert_condition=gca_model_monitoring_alert.ModelMonitoringAlertCondition( + threshold=0.1, + ), + default_numeric_alert_condition=gca_model_monitoring_alert.ModelMonitoringAlertCondition( + threshold=0.2, + ), + ) + ), + baseline_dataset=gca_model_monitoring_spec.ModelMonitoringInput( + columnized_dataset=gca_model_monitoring_spec.ModelMonitoringInput.ModelMonitoringDataset( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + ), + target_dataset=gca_model_monitoring_spec.ModelMonitoringInput( + columnized_dataset=gca_model_monitoring_spec.ModelMonitoringInput.ModelMonitoringDataset( + vertex_dataset=_TEST_TARGET_RESOURCE + ) + ), + 
explanation_spec=explanation.ExplanationSpec( + parameters=explanation.ExplanationParameters(top_k=10) + ), + ), + output_spec=gca_model_monitoring_spec.ModelMonitoringOutputSpec( + gcs_base_directory=io.GcsDestination(output_uri_prefix=_TEST_OUTPUT_PATH) + ), + notification_spec=gca_model_monitoring_spec.ModelMonitoringNotificationSpec( + email_config=gca_model_monitoring_spec.ModelMonitoringNotificationSpec.EmailConfig( + user_emails=[_TEST_NOTIFICATION_EMAIL] + ) + ), + ), +) +_TEST_UPDATED_SCHEDULE_OBJ = gca_schedule.Schedule( + display_name=_TEST_SCHEDULE_NAME, + cron=r"America/New_York 0 \* \* \* \*", + create_model_monitoring_job_request=gca_model_monitoring_service.CreateModelMonitoringJobRequest( + parent=_TEST_MODEL_MONITOR_RESOURCE_NAME, + model_monitoring_job=_TEST_UPDATED_MODEL_MONITORING_JOB_OBJ, + ), + max_concurrent_run_count=1, +) +_TEST_SEARCH_REQUEST = gca_model_monitoring_service.SearchModelMonitoringStatsRequest( + model_monitor=_TEST_MODEL_MONITOR_RESOURCE_NAME, + stats_filter=( + gca_model_monitoring_stats.SearchModelMonitoringStatsFilter( + tabular_stats_filter=( + gca_model_monitoring_stats.SearchModelMonitoringStatsFilter.TabularStatsFilter( + model_monitoring_job=_TEST_MODEL_MONITORING_JOB_RESOURCE_NAME, + ) + ) + ) + ), +) +_TEST_SEARCH_RESPONSE = ( + gca_model_monitoring_service.SearchModelMonitoringStatsResponse() +) +_TEST_SEARCH_ALERTS_REQUEST = ( + gca_model_monitoring_service.SearchModelMonitoringAlertsRequest( + model_monitor=_TEST_MODEL_MONITOR_RESOURCE_NAME, + model_monitoring_job=_TEST_MODEL_MONITORING_JOB_RESOURCE_NAME, + ) +) +_TEST_SEARCH_ALERTS_RESPONSE = ( + gca_model_monitoring_service.SearchModelMonitoringAlertsResponse() +) +_TEST_LIST_REQUEST = gca_model_monitoring_service.ListModelMonitoringJobsRequest( + parent=_TEST_MODEL_MONITOR_RESOURCE_NAME +) +_TEST_LIST_RESPONSE = gca_model_monitoring_service.ListModelMonitoringJobsResponse( + model_monitoring_jobs=[ + _TEST_MODEL_MONITORING_JOB_OBJ, + 
_TEST_MODEL_MONITORING_JOB_OBJ, + ], + next_page_token="1", +) + + +@pytest.fixture +def authorized_session_mock(): + with mock.patch( + "google.auth.transport.requests.AuthorizedSession" + ) as mock_authorized_session: + mock_auth_session = mock_authorized_session(_TEST_CREDENTIALS) + yield mock_auth_session + + +@pytest.fixture(scope="module") +def google_auth_mock(): + with mock.patch.object(auth, "default") as google_auth_mock: + google_auth_mock.return_value = ( + auth_credentials.AnonymousCredentials(), + "test-project", + ) + yield google_auth_mock + + +@pytest.fixture +def create_client_mock(): + with mock.patch.object( + initializer.global_config, "create_client" + ) as create_client_mock: + api_client_mock = mock.Mock( + spec=model_monitoring_service_client.ModelMonitoringServiceClient + ) + api_client_mock.get_model_monitor.return_value = _TEST_MODEL_MONITOR_OBJ + create_client_mock.return_value = api_client_mock + yield create_client_mock + + +@pytest.fixture +def create_model_monitor_mock(): + with mock.patch.object( + model_monitoring_service_client.ModelMonitoringServiceClient, + "create_model_monitor", + ) as create_model_monitor_mock: + create_model_monitor_lro_mock = mock.Mock(ga_operation.Operation) + create_model_monitor_lro_mock.result.return_value = _TEST_MODEL_MONITOR_OBJ + create_model_monitor_mock.return_value = create_model_monitor_lro_mock + yield create_model_monitor_mock + + +@pytest.fixture +def get_model_monitor_mock(): + with mock.patch.object( + model_monitoring_service_client.ModelMonitoringServiceClient, + "get_model_monitor", + ) as get_model_monitor_mock: + get_model_monitor_mock.return_value = _TEST_MODEL_MONITOR_OBJ + yield get_model_monitor_mock + + +@pytest.fixture +def update_model_monitor_mock(): + with mock.patch.object( + model_monitoring_service_client.ModelMonitoringServiceClient, + "update_model_monitor", + ) as update_model_monitor_mock: + update_model_monitor_lro_mock = mock.Mock(ga_operation.Operation) + 
update_model_monitor_lro_mock.result.return_value = ( + _TEST_UPDATED_MODEL_MONITOR_OBJ + ) + update_model_monitor_mock.return_value = update_model_monitor_lro_mock + yield update_model_monitor_mock + + +@pytest.fixture +def create_schedule_mock(): + with mock.patch.object( + schedule_service_client.ScheduleServiceClient, "create_schedule" + ) as create_schedule_mock: + create_schedule_mock.return_value = _TEST_SCHEDULE_OBJ + yield create_schedule_mock + + +@pytest.fixture +def update_schedule_mock(): + with mock.patch.object( + schedule_service_client.ScheduleServiceClient, "update_schedule" + ) as update_schedule_mock: + update_schedule_mock.return_value = _TEST_UPDATED_SCHEDULE_OBJ + yield update_schedule_mock + + +@pytest.fixture +def get_schedule_mock(): + with mock.patch.object( + schedule_service_client.ScheduleServiceClient, "get_schedule" + ) as get_schedule_mock: + get_schedule_mock.return_value = _TEST_SCHEDULE_OBJ + yield get_schedule_mock + + +@pytest.fixture +def search_metrics_mock(): + with mock.patch.object( + model_monitoring_service_client.ModelMonitoringServiceClient, + "search_model_monitoring_stats", + ) as search_metrics_mock: + search_metrics_mock.return_value = ( + model_monitoring_service_client.pagers.SearchModelMonitoringStatsPager( + method=search_metrics_mock, + request=_TEST_SEARCH_REQUEST, + response=_TEST_SEARCH_RESPONSE, + ) + ) + yield search_metrics_mock + + +@pytest.fixture +def search_alerts_mock(): + with mock.patch.object( + model_monitoring_service_client.ModelMonitoringServiceClient, + "search_model_monitoring_alerts", + ) as search_alerts_mock: + search_alerts_mock.return_value = ( + model_monitoring_service_client.pagers.SearchModelMonitoringAlertsPager( + method=search_alerts_mock, + request=_TEST_SEARCH_ALERTS_REQUEST, + response=_TEST_SEARCH_ALERTS_RESPONSE, + ) + ) + yield search_alerts_mock + + +@pytest.fixture +def list_model_monitoring_jobs_mock(): + with mock.patch.object( + 
model_monitoring_service_client.ModelMonitoringServiceClient, + "list_model_monitoring_jobs", + ) as list_model_monitoring_jobs_mock: + list_model_monitoring_jobs_mock.return_value = ( + model_monitoring_service_client.pagers.ListModelMonitoringJobsPager( + method=list_model_monitoring_jobs_mock, + request=_TEST_LIST_REQUEST, + response=_TEST_LIST_RESPONSE, + ) + ) + yield list_model_monitoring_jobs_mock + + +@pytest.fixture +def delete_model_monitor_mock(): + with mock.patch.object( + model_monitoring_service_client.ModelMonitoringServiceClient, + "delete_model_monitor", + ) as delete_model_monitor_mock: + delete_model_monitor_lro_mock = mock.Mock(ga_operation.Operation) + delete_model_monitor_lro_mock.result.return_value = empty_pb2.Empty + delete_model_monitor_mock.return_value = delete_model_monitor_lro_mock + yield delete_model_monitor_mock + + +@pytest.fixture +def create_model_monitoring_job_mock(): + with mock.patch.object( + model_monitoring_service_client.ModelMonitoringServiceClient, + "create_model_monitoring_job", + ) as create_model_monitoring_job_mock: + create_model_monitoring_job_mock.return_value = _TEST_MODEL_MONITORING_JOB_OBJ + yield create_model_monitoring_job_mock + + +@pytest.fixture +def get_model_monitoring_job_mock(): + with mock.patch.object( + model_monitoring_service_client.ModelMonitoringServiceClient, + "get_model_monitoring_job", + ) as get_model_monitoring_job_mock: + model_monitoring_job_mock = mock.Mock( + spec=gca_model_monitoring_job.ModelMonitoringJob + ) + model_monitoring_job_mock.state = gca_job_state.JobState.JOB_STATE_SUCCEEDED + model_monitoring_job_mock.name = _TEST_MODEL_MONITORING_JOB_RESOURCE_NAME + get_model_monitoring_job_mock.return_value = model_monitoring_job_mock + yield get_model_monitoring_job_mock + + +@pytest.fixture +def delete_model_monitoring_job_mock(): + with mock.patch.object( + model_monitoring_service_client.ModelMonitoringServiceClient, + "delete_model_monitoring_job", + ) as 
delete_model_monitoring_job_mock: + delete_model_monitoring_job_lro_mock = mock.Mock(ga_operation.Operation) + delete_model_monitoring_job_lro_mock.result.return_value = empty_pb2.Empty + delete_model_monitoring_job_mock.return_value = ( + delete_model_monitoring_job_lro_mock + ) + yield delete_model_monitoring_job_mock + + +@pytest.mark.usefixtures("google_auth_mock") +class TestModelMonitor: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_constructor_creates_client(self, create_client_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + ModelMonitor(_TEST_MODEL_MONITOR_ID) + create_client_mock.assert_any_call( + client_class=utils.ModelMonitoringClientWithOverride, + credentials=initializer.global_config.credentials, + location_override=_TEST_LOCATION, + appended_user_agent=None, + ) + + def test_constructor_create_client_with_custom_location(self, create_client_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + ModelMonitor(_TEST_MODEL_MONITOR_ID, location=_TEST_LOCATION_2) + create_client_mock.assert_any_call( + client_class=utils.ModelMonitoringClientWithOverride, + credentials=initializer.global_config.credentials, + location_override=_TEST_LOCATION_2, + appended_user_agent=None, + ) + + def test_constructor_creates_client_with_custom_credentials( + self, create_client_mock + ): + creds = auth_credentials.AnonymousCredentials() + ModelMonitor(_TEST_MODEL_MONITOR_ID, credentials=creds) + create_client_mock.assert_any_call( + client_class=utils.ModelMonitoringClientWithOverride, + credentials=creds, + location_override=_TEST_LOCATION, + appended_user_agent=None, + ) + + @pytest.mark.usefixtures("create_model_monitor_mock") + def 
test_create_model_monitor(self, create_model_monitor_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + ModelMonitor.create( + training_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + model_name=_TEST_MODEL_NAME, + model_version_id=_TEST_MODEL_VERSION_ID, + display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME, + tabular_objective_spec=ml_monitoring.spec.TabularObjective( + feature_drift_spec=ml_monitoring.spec.DataDriftSpec( + default_categorical_alert_threshold=0.1, + default_numeric_alert_threshold=0.2, + ), + ), + output_spec=ml_monitoring.spec.OutputSpec(gcs_base_dir=_TEST_OUTPUT_PATH), + notification_spec=ml_monitoring.spec.NotificationSpec( + user_emails=[_TEST_NOTIFICATION_EMAIL] + ), + ) + create_model_monitor_mock.assert_called_once_with( + request=gca_model_monitoring_service.CreateModelMonitorRequest( + parent=_TEST_MODEL_MONITOR_PARENT, + model_monitor=_TEST_CREATE_MODEL_MONITOR_OBJ, + ), + ) + + @pytest.mark.usefixtures("create_model_monitor_mock") + def test_create_model_monitor_with_user_id(self, create_model_monitor_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + ModelMonitor.create( + training_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + model_name=_TEST_MODEL_NAME, + model_version_id=_TEST_MODEL_VERSION_ID, + display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME, + tabular_objective_spec=ml_monitoring.spec.TabularObjective( + feature_drift_spec=ml_monitoring.spec.DataDriftSpec( + default_categorical_alert_threshold=0.1, + default_numeric_alert_threshold=0.2, + ), + ), + output_spec=ml_monitoring.spec.OutputSpec(gcs_base_dir=_TEST_OUTPUT_PATH), + notification_spec=ml_monitoring.spec.NotificationSpec( + user_emails=[_TEST_NOTIFICATION_EMAIL] + ), + model_monitor_id=_TEST_MODEL_MONITOR_USER_ID, + ) + 
create_model_monitor_mock.assert_called_once_with( + request=gca_model_monitoring_service.CreateModelMonitorRequest( + parent=_TEST_MODEL_MONITOR_PARENT, + model_monitor=_TEST_CREATE_MODEL_MONITOR_OBJ, + model_monitor_id=_TEST_MODEL_MONITOR_USER_ID, + ), + ) + + @pytest.mark.usefixtures( + "create_model_monitor_mock", + "get_model_monitor_mock", + "update_model_monitor_mock", + ) + def test_update_model_monitor(self, update_model_monitor_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + test_model_monitor = ModelMonitor.create( + training_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + model_name=_TEST_MODEL_NAME, + model_version_id=_TEST_MODEL_VERSION_ID, + display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME, + tabular_objective_spec=ml_monitoring.spec.TabularObjective( + feature_drift_spec=ml_monitoring.spec.DataDriftSpec( + default_categorical_alert_threshold=0.1, + default_numeric_alert_threshold=0.2, + ), + ), + output_spec=ml_monitoring.spec.OutputSpec(gcs_base_dir=_TEST_OUTPUT_PATH), + notification_spec=ml_monitoring.spec.NotificationSpec( + user_emails=[_TEST_NOTIFICATION_EMAIL] + ), + ) + assert isinstance(test_model_monitor, ModelMonitor) + test_model_monitor.update( + notification_spec=ml_monitoring.spec.NotificationSpec( + user_emails=[_TEST_NOTIFICATION_EMAIL, "456@test.com"] + ), + ) + update_model_monitor_mock.assert_called_once_with( + model_monitor=_TEST_UPDATED_MODEL_MONITOR_OBJ, + update_mask=field_mask_pb2.FieldMask(paths=["notification_spec"]), + ) + + @pytest.mark.usefixtures("create_schedule_mock", "create_model_monitor_mock") + def test_create_schedule(self, create_schedule_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + test_model_monitor = ModelMonitor.create( + training_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + 
model_name=_TEST_MODEL_NAME, + model_version_id=_TEST_MODEL_VERSION_ID, + display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME, + tabular_objective_spec=ml_monitoring.spec.TabularObjective( + feature_drift_spec=ml_monitoring.spec.DataDriftSpec( + default_categorical_alert_threshold=0.1, + default_numeric_alert_threshold=0.2, + ), + ), + output_spec=ml_monitoring.spec.OutputSpec(gcs_base_dir=_TEST_OUTPUT_PATH), + notification_spec=ml_monitoring.spec.NotificationSpec( + user_emails=[_TEST_NOTIFICATION_EMAIL] + ), + explanation_spec=explanation.ExplanationSpec( + parameters=explanation.ExplanationParameters(top_k=10) + ), + ) + test_model_monitor.create_schedule( + display_name=_TEST_SCHEDULE_NAME, + model_monitoring_job_display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME, + cron=_TEST_CRON, + target_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_TARGET_RESOURCE + ), + ) + create_schedule_mock.assert_called_once_with( + request=gca_schedule_service.CreateScheduleRequest( + parent=_TEST_MODEL_MONITOR_PARENT, + schedule=gca_schedule.Schedule( + display_name=_TEST_SCHEDULE_NAME, + cron=_TEST_CRON, + create_model_monitoring_job_request=gca_model_monitoring_service.CreateModelMonitoringJobRequest( + parent=_TEST_MODEL_MONITOR_RESOURCE_NAME, + model_monitoring_job=_TEST_CREATE_MODEL_MONITORING_JOB_OBJ, + ), + max_concurrent_run_count=1, + ), + ) + ) + + @pytest.mark.usefixtures( + "create_schedule_mock", + "update_schedule_mock", + "get_schedule_mock", + "create_model_monitor_mock", + ) + def test_update_schedule(self, update_schedule_mock, get_schedule_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + test_model_monitor = ModelMonitor.create( + training_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + model_name=_TEST_MODEL_NAME, + model_version_id=_TEST_MODEL_VERSION_ID, + display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME, + 
tabular_objective_spec=ml_monitoring.spec.TabularObjective( + feature_drift_spec=ml_monitoring.spec.DataDriftSpec( + default_categorical_alert_threshold=0.1, + default_numeric_alert_threshold=0.2, + ), + ), + output_spec=ml_monitoring.spec.OutputSpec(gcs_base_dir=_TEST_OUTPUT_PATH), + notification_spec=ml_monitoring.spec.NotificationSpec( + user_emails=[_TEST_NOTIFICATION_EMAIL] + ), + ) + test_model_monitor.create_schedule( + display_name=_TEST_SCHEDULE_NAME, + model_monitoring_job_display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME, + cron=_TEST_CRON, + target_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_TARGET_RESOURCE + ), + ) + test_model_monitor.update_schedule( + schedule_name=_TEST_SCHEDULE_NAME, + cron=r"America/New_York 0 \* \* \* \*", + baseline_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + target_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_TARGET_RESOURCE + ), + tabular_objective_spec=ml_monitoring.spec.TabularObjective( + feature_drift_spec=ml_monitoring.spec.DataDriftSpec( + categorical_metric_type="l_infinity", + numeric_metric_type="jensen_shannon_divergence", + default_categorical_alert_threshold=0.1, + default_numeric_alert_threshold=0.2, + ), + ), + ) + update_schedule_mock.assert_called_once_with( + schedule=_TEST_UPDATED_SCHEDULE_OBJ, + update_mask=field_mask_pb2.FieldMask( + paths=["cron", "create_model_monitoring_job_request"] + ), + ) + assert get_schedule_mock.call_count == 1 + + @pytest.mark.usefixtures( + "create_model_monitoring_job_mock", + "create_model_monitor_mock", + "get_model_monitoring_job_mock", + ) + @pytest.mark.parametrize("sync", [True, False]) + def test_run_model_monitoring_job(self, create_model_monitoring_job_mock, sync): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + test_model_monitor = ModelMonitor.create( + training_dataset=ml_monitoring.spec.MonitoringInput( + 
vertex_dataset=_TEST_BASELINE_RESOURCE + ), + model_name=_TEST_MODEL_NAME, + display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME, + model_version_id=_TEST_MODEL_VERSION_ID, + ) + test_model_monitoring_job = test_model_monitor.run( + display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME, + baseline_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + target_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_TARGET_RESOURCE + ), + tabular_objective_spec=ml_monitoring.spec.TabularObjective( + feature_drift_spec=ml_monitoring.spec.DataDriftSpec( + default_categorical_alert_threshold=0.1, + default_numeric_alert_threshold=0.2, + ), + ), + output_spec=ml_monitoring.spec.OutputSpec(gcs_base_dir=_TEST_OUTPUT_PATH), + notification_spec=ml_monitoring.spec.NotificationSpec( + user_emails=[_TEST_NOTIFICATION_EMAIL] + ), + explanation_spec=explanation.ExplanationSpec( + parameters=explanation.ExplanationParameters(top_k=10) + ), + sync=sync, + ) + + if not sync: + test_model_monitoring_job.wait() + + create_model_monitoring_job_mock.assert_called_once_with( + request=gca_model_monitoring_service.CreateModelMonitoringJobRequest( + parent=_TEST_MODEL_MONITOR_RESOURCE_NAME, + model_monitoring_job=_TEST_CREATE_MODEL_MONITORING_JOB_OBJ, + ) + ) + + @pytest.mark.usefixtures( + "create_model_monitoring_job_mock", + "create_model_monitor_mock", + "get_model_monitoring_job_mock", + ) + def test_run_model_monitoring_job_with_user_id( + self, create_model_monitoring_job_mock + ): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + test_model_monitor = ModelMonitor.create( + training_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + model_name=_TEST_MODEL_NAME, + display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME, + model_version_id=_TEST_MODEL_VERSION_ID, + ) + test_model_monitor.run( + display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME, + 
baseline_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + target_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_TARGET_RESOURCE + ), + tabular_objective_spec=ml_monitoring.spec.TabularObjective( + feature_drift_spec=ml_monitoring.spec.DataDriftSpec( + default_categorical_alert_threshold=0.1, + default_numeric_alert_threshold=0.2, + ), + ), + output_spec=ml_monitoring.spec.OutputSpec(gcs_base_dir=_TEST_OUTPUT_PATH), + notification_spec=ml_monitoring.spec.NotificationSpec( + user_emails=[_TEST_NOTIFICATION_EMAIL] + ), + explanation_spec=explanation.ExplanationSpec( + parameters=explanation.ExplanationParameters(top_k=10) + ), + model_monitoring_job_id=_TEST_MODEL_MONITORING_JOB_USER_ID, + sync=True, + ) + create_model_monitoring_job_mock.assert_called_once_with( + request=gca_model_monitoring_service.CreateModelMonitoringJobRequest( + parent=_TEST_MODEL_MONITOR_RESOURCE_NAME, + model_monitoring_job=_TEST_CREATE_MODEL_MONITORING_JOB_OBJ, + model_monitoring_job_id=_TEST_MODEL_MONITORING_JOB_USER_ID, + ) + ) + + @pytest.mark.usefixtures( + "create_model_monitoring_job_mock", + "create_model_monitor_mock", + "search_metrics_mock", + "get_model_monitoring_job_mock", + ) + def test_search_metrics(self, search_metrics_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + test_model_monitor = ModelMonitor.create( + training_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + model_name=_TEST_MODEL_NAME, + model_version_id=_TEST_MODEL_VERSION_ID, + display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME, + tabular_objective_spec=ml_monitoring.spec.TabularObjective( + feature_drift_spec=ml_monitoring.spec.DataDriftSpec( + default_categorical_alert_threshold=0.1, + default_numeric_alert_threshold=0.2, + ), + ), + output_spec=ml_monitoring.spec.OutputSpec(gcs_base_dir=_TEST_OUTPUT_PATH), + 
notification_spec=ml_monitoring.spec.NotificationSpec( + user_emails=[_TEST_NOTIFICATION_EMAIL] + ), + ) + test_model_monitor.run( + display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME, + target_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_TARGET_RESOURCE + ), + sync=True, + ) + test_model_monitor.search_metrics( + model_monitoring_job_name=_TEST_MODEL_MONITORING_JOB_RESOURCE_NAME + ) + search_metrics_mock.assert_called_once_with(request=_TEST_SEARCH_REQUEST) + + @pytest.mark.usefixtures( + "create_model_monitoring_job_mock", + "create_model_monitor_mock", + "search_alerts_mock", + "get_model_monitoring_job_mock", + ) + def test_search_alerts(self, search_alerts_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + ) + test_model_monitor = ModelMonitor.create( + training_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_BASELINE_RESOURCE + ), + model_name=_TEST_MODEL_NAME, + model_version_id=_TEST_MODEL_VERSION_ID, + display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME, + tabular_objective_spec=ml_monitoring.spec.TabularObjective( + feature_drift_spec=ml_monitoring.spec.DataDriftSpec( + default_categorical_alert_threshold=0.1, + default_numeric_alert_threshold=0.2, + ), + ), + output_spec=ml_monitoring.spec.OutputSpec(gcs_base_dir=_TEST_OUTPUT_PATH), + notification_spec=ml_monitoring.spec.NotificationSpec( + user_emails=[_TEST_NOTIFICATION_EMAIL] + ), + ) + test_model_monitor.run( + display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME, + target_dataset=ml_monitoring.spec.MonitoringInput( + vertex_dataset=_TEST_TARGET_RESOURCE + ), + sync=True, + ) + test_model_monitor.search_alerts( + model_monitoring_job_name=_TEST_MODEL_MONITORING_JOB_RESOURCE_NAME + ) + search_alerts_mock.assert_called_once_with(request=_TEST_SEARCH_ALERTS_REQUEST) + + @pytest.mark.usefixtures("create_model_monitor_mock", "delete_model_monitor_mock") + @pytest.mark.parametrize("force", [True, False]) + 
def test_delete_model_monitor(self, delete_model_monitor_mock, force):
    """Delete a freshly created ModelMonitor and check the request that was sent.

    Expects exactly one ``DeleteModelMonitorRequest`` on the delete mock,
    carrying the monitor resource name and the ``force`` value this test
    received.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )
    spec = ml_monitoring.spec
    training_input = spec.MonitoringInput(vertex_dataset=_TEST_BASELINE_RESOURCE)
    drift = spec.DataDriftSpec(
        default_categorical_alert_threshold=0.1,
        default_numeric_alert_threshold=0.2,
    )
    objective = spec.TabularObjective(feature_drift_spec=drift)
    output = spec.OutputSpec(gcs_base_dir=_TEST_OUTPUT_PATH)
    notification = spec.NotificationSpec(user_emails=[_TEST_NOTIFICATION_EMAIL])
    monitor = ModelMonitor.create(
        training_dataset=training_input,
        model_name=_TEST_MODEL_NAME,
        model_version_id=_TEST_MODEL_VERSION_ID,
        display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME,
        tabular_objective_spec=objective,
        output_spec=output,
        notification_spec=notification,
    )

    monitor.delete(force=force)

    expected_request = gca_model_monitoring_service.DeleteModelMonitorRequest(
        name=_TEST_MODEL_MONITOR_RESOURCE_NAME, force=force
    )
    delete_model_monitor_mock.assert_called_once_with(request=expected_request)

@pytest.mark.usefixtures(
    "create_model_monitoring_job_mock", "get_model_monitoring_job_mock"
)
@pytest.mark.parametrize("sync", [True, False])
def test_create_model_monitoring_job(self, create_model_monitoring_job_mock, sync):
    """Create a ModelMonitoringJob directly and check the create request.

    Runs for ``sync`` True and False; in the non-sync case the test waits on
    the returned job before asserting. Expects exactly one
    ``CreateModelMonitoringJobRequest`` whose parent is the model monitor
    resource name.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )
    spec = ml_monitoring.spec
    drift = spec.DataDriftSpec(
        default_categorical_alert_threshold=0.1,
        default_numeric_alert_threshold=0.2,
    )
    objective = spec.TabularObjective(feature_drift_spec=drift)
    baseline_input = spec.MonitoringInput(vertex_dataset=_TEST_BASELINE_RESOURCE)
    target_input = spec.MonitoringInput(vertex_dataset=_TEST_TARGET_RESOURCE)
    output = spec.OutputSpec(gcs_base_dir=_TEST_OUTPUT_PATH)
    notification = spec.NotificationSpec(user_emails=[_TEST_NOTIFICATION_EMAIL])
    explanation_config = explanation.ExplanationSpec(
        parameters=explanation.ExplanationParameters(top_k=10)
    )
    job = ModelMonitoringJob.create(
        display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
        model_monitor_name=_TEST_MODEL_MONITOR_RESOURCE_NAME,
        tabular_objective_spec=objective,
        baseline_dataset=baseline_input,
        target_dataset=target_input,
        output_spec=output,
        notification_spec=notification,
        explanation_spec=explanation_config,
        sync=sync,
    )

    # Non-sync creation must be waited on before the mock can be inspected.
    if not sync:
        job.wait()

    expected_request = gca_model_monitoring_service.CreateModelMonitoringJobRequest(
        parent=_TEST_MODEL_MONITOR_RESOURCE_NAME,
        model_monitoring_job=_TEST_CREATE_MODEL_MONITORING_JOB_OBJ,
    )
    create_model_monitoring_job_mock.assert_called_once_with(
        request=expected_request
    )

@pytest.mark.usefixtures(
    "create_model_monitor_mock",
    "create_model_monitoring_job_mock",
    "delete_model_monitoring_job_mock",
    "get_model_monitoring_job_mock",
)
def test_delete_model_monitoring_job(self, delete_model_monitoring_job_mock):
    """Create a ModelMonitoringJob synchronously, delete it, check the call.

    Expects the delete mock to be called exactly once with the job resource
    name passed as the ``name`` keyword.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )
    spec = ml_monitoring.spec
    drift = spec.DataDriftSpec(
        default_categorical_alert_threshold=0.1,
        default_numeric_alert_threshold=0.2,
    )
    objective = spec.TabularObjective(feature_drift_spec=drift)
    baseline_input = spec.MonitoringInput(vertex_dataset=_TEST_BASELINE_RESOURCE)
    target_input = spec.MonitoringInput(vertex_dataset=_TEST_TARGET_RESOURCE)
    output = spec.OutputSpec(gcs_base_dir=_TEST_OUTPUT_PATH)
    notification = spec.NotificationSpec(user_emails=[_TEST_NOTIFICATION_EMAIL])
    job = ModelMonitoringJob.create(
        display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
        model_monitor_name=_TEST_MODEL_MONITOR_RESOURCE_NAME,
        tabular_objective_spec=objective,
        baseline_dataset=baseline_input,
        target_dataset=target_input,
        output_spec=output,
        notification_spec=notification,
        sync=True,
    )

    job.delete()

    delete_model_monitoring_job_mock.assert_called_once_with(
        name=_TEST_MODEL_MONITORING_JOB_RESOURCE_NAME
    )

@pytest.mark.usefixtures(
    "create_model_monitor_mock",
    "get_model_monitoring_job_mock",
    "create_model_monitoring_job_mock",
)
def test_get_model_monitoring_job(self):
    """Create a monitor, run a job, and fetch that job back by resource name.

    Only checks that ``get_model_monitoring_job`` returns a
    ``ModelMonitoringJob`` instance.
    """
    aiplatform.init(
        project=_TEST_PROJECT,
        location=_TEST_LOCATION,
        credentials=_TEST_CREDENTIALS,
    )
    spec = ml_monitoring.spec
    training_input = spec.MonitoringInput(vertex_dataset=_TEST_BASELINE_RESOURCE)
    drift = spec.DataDriftSpec(
        default_categorical_alert_threshold=0.1,
        default_numeric_alert_threshold=0.2,
    )
    objective = spec.TabularObjective(feature_drift_spec=drift)
    output = spec.OutputSpec(gcs_base_dir=_TEST_OUTPUT_PATH)
    notification = spec.NotificationSpec(user_emails=[_TEST_NOTIFICATION_EMAIL])
    monitor = ModelMonitor.create(
        training_dataset=training_input,
        model_name=_TEST_MODEL_NAME,
        model_version_id=_TEST_MODEL_VERSION_ID,
        display_name=_TEST_MODEL_MONITOR_DISPLAY_NAME,
        tabular_objective_spec=objective,
        output_spec=output,
        notification_spec=notification,
    )
    target_input = spec.MonitoringInput(vertex_dataset=_TEST_TARGET_RESOURCE)
    monitor.run(
        display_name=_TEST_MODEL_MONITORING_JOB_DISPLAY_NAME,
        target_dataset=target_input,
        sync=True,
    )

    fetched_job = monitor.get_model_monitoring_job(
        model_monitoring_job_name=_TEST_MODEL_MONITORING_JOB_RESOURCE_NAME
    )
    assert isinstance(fetched_job, ModelMonitoringJob)


# TODO: Add unit tests for visualization methods.
diff --git a/tests/unit/vertexai/conftest.py b/tests/unit/vertexai/conftest.py index a12ea8a96d..c29c0a2478 100644 --- a/tests/unit/vertexai/conftest.py +++ b/tests/unit/vertexai/conftest.py @@ -54,7 +54,7 @@ ResourceRuntimeSpec, ServiceAccountSpec, ) -from feature_store_constants import ( +from vertexai_feature_store_constants import ( _TEST_BIGTABLE_FOS1, _TEST_EMBEDDING_FV1, _TEST_ESF_OPTIMIZED_FOS, diff --git a/tests/unit/vertexai/test_vertexai_feature.py b/tests/unit/vertexai/test_vertexai_feature.py new file mode 100644 index 0000000000..151b926438 --- /dev/null +++ b/tests/unit/vertexai/test_vertexai_feature.py @@ -0,0 +1,291 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import re +from typing import Dict, List, Optional +from unittest import mock +from unittest.mock import call, patch + +from google.api_core import operation as ga_operation +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from google.cloud.aiplatform.compat import types +from google.cloud.aiplatform.compat.services import ( + feature_registry_service_client, +) +from google.cloud.aiplatform_v1beta1.services.feature_registry_service import ( + FeatureRegistryServiceClient, +) +from vertexai_feature_store_constants import ( + _TEST_FG1_F1_DESCRIPTION, + _TEST_FG1_F1_FEATURE_STATS_AND_ANOMALY, + _TEST_FG1_F1_ID, + _TEST_FG1_F1_LABELS, + _TEST_FG1_F1_PATH, + _TEST_FG1_F1_POINT_OF_CONTACT, + _TEST_FG1_F1_WITH_STATS, + _TEST_FG1_F2_DESCRIPTION, + _TEST_FG1_F2_ID, + _TEST_FG1_F2_LABELS, + _TEST_FG1_F2_PATH, + _TEST_FG1_F2_POINT_OF_CONTACT, + _TEST_FG1_F2_VERSION_COLUMN_NAME, + _TEST_FG1_ID, + _TEST_LOCATION, + _TEST_PROJECT, +) +from vertexai.resources.preview import ( + Feature, + FeatureGroup, +) +import pytest + + +@pytest.fixture +def delete_feature_mock(): + with patch.object( + feature_registry_service_client.FeatureRegistryServiceClient, + "delete_feature", + ) as delete_feature_mock: + delete_feature_lro_mock = mock.Mock(ga_operation.Operation) + delete_feature_mock.return_value = delete_feature_lro_mock + yield delete_feature_mock + + +@pytest.fixture +def get_feature_with_stats_and_anomalies_mock(): + with patch.object( + FeatureRegistryServiceClient, + "get_feature", + ) as get_feature_with_stats_and_anomalies_mock: + get_feature_with_stats_and_anomalies_mock.return_value = _TEST_FG1_F1_WITH_STATS + yield get_feature_with_stats_and_anomalies_mock + + +pytestmark = pytest.mark.usefixtures("google_auth_mock") + + +def feature_eq( + feature_to_check: Feature, + name: str, + resource_name: str, + project: str, + location: str, + description: str, + labels: Dict[str, str], + point_of_contact: str, + version_column_name: 
Optional[str] = None, + feature_stats_and_anomalies: Optional[ + List[types.feature_monitor.FeatureStatsAndAnomaly] + ] = None, +): + """Check if a Feature has the appropriate values set.""" + assert feature_to_check.name == name + assert feature_to_check.resource_name == resource_name + assert feature_to_check.project == project + assert feature_to_check.location == location + assert feature_to_check.description == description + assert feature_to_check.labels == labels + assert feature_to_check.point_of_contact == point_of_contact + + if version_column_name: + assert feature_to_check.version_column_name == version_column_name + if feature_stats_and_anomalies: + assert ( + feature_to_check.feature_stats_and_anomalies == feature_stats_and_anomalies + ) + + +def test_init_with_feature_id_and_no_fg_id_raises_error(get_feature_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape( + "Since feature 'my_fg1_f1' is not provided as a path, please specify" + + " feature_group_id." + ), + ): + Feature(_TEST_FG1_F1_ID) + + +def test_init_with_feature_path_and_fg_id_raises_error(get_feature_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape( + "Since feature 'projects/test-project/locations/us-central1/featureGroups/my_fg1/features/my_fg1_f1' is provided as a path, feature_group_id should not be specified." 
+ ), + ): + Feature(_TEST_FG1_F1_PATH, feature_group_id=_TEST_FG1_ID) + + +def test_init_with_feature_id(get_feature_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + feature = Feature(_TEST_FG1_F1_ID, feature_group_id=_TEST_FG1_ID) + + get_feature_mock.assert_called_once_with( + name=_TEST_FG1_F1_PATH, + retry=base._DEFAULT_RETRY, + ) + + feature_eq( + feature, + name=_TEST_FG1_F1_ID, + resource_name=_TEST_FG1_F1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + ) + + +def test_init_with_feature_id_for_explicit_version_column( + get_feature_with_version_column_mock, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + feature = Feature(_TEST_FG1_F2_ID, feature_group_id=_TEST_FG1_ID) + + get_feature_with_version_column_mock.assert_called_once_with( + name=_TEST_FG1_F2_PATH, + retry=base._DEFAULT_RETRY, + ) + + feature_eq( + feature, + name=_TEST_FG1_F2_ID, + resource_name=_TEST_FG1_F2_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F2_DESCRIPTION, + labels=_TEST_FG1_F2_LABELS, + point_of_contact=_TEST_FG1_F2_POINT_OF_CONTACT, + version_column_name=_TEST_FG1_F2_VERSION_COLUMN_NAME, + ) + + +def test_init_with_feature_path(get_feature_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + feature = Feature(_TEST_FG1_F1_PATH) + + get_feature_mock.assert_called_once_with( + name=_TEST_FG1_F1_PATH, + retry=base._DEFAULT_RETRY, + ) + + feature_eq( + feature, + name=_TEST_FG1_F1_ID, + resource_name=_TEST_FG1_F1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + ) + + +def test_init_with_feature_path_for_explicit_version_column( + get_feature_with_version_column_mock, +): + aiplatform.init(project=_TEST_PROJECT, 
location=_TEST_LOCATION) + + feature = Feature(_TEST_FG1_F2_PATH) + + get_feature_with_version_column_mock.assert_called_once_with( + name=_TEST_FG1_F2_PATH, + retry=base._DEFAULT_RETRY, + ) + + feature_eq( + feature, + name=_TEST_FG1_F2_ID, + resource_name=_TEST_FG1_F2_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + version_column_name=_TEST_FG1_F2_VERSION_COLUMN_NAME, + description=_TEST_FG1_F2_DESCRIPTION, + labels=_TEST_FG1_F2_LABELS, + point_of_contact=_TEST_FG1_F2_POINT_OF_CONTACT, + ) + + +def test_init_with_latest_stats_count(get_feature_with_stats_and_anomalies_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + feature = Feature(name=_TEST_FG1_F1_PATH, latest_stats_count=1) + + get_feature_with_stats_and_anomalies_mock.assert_called_once_with( + request=types.featurestore_service_v1beta1.GetFeatureRequest( + name=_TEST_FG1_F1_PATH, + feature_stats_and_anomaly_spec=types.feature_monitor.FeatureStatsAndAnomalySpec( + latest_stats_count=1 + ), + ) + ) + + feature_eq( + feature, + name=_TEST_FG1_F1_ID, + resource_name=_TEST_FG1_F1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + feature_stats_and_anomalies=[_TEST_FG1_F1_FEATURE_STATS_AND_ANOMALY], + ) + + +@pytest.mark.parametrize("sync", [True]) +def test_delete_feature( + get_fg_mock, get_feature_mock, delete_feature_mock, base_logger_mock, sync +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + feature = FeatureGroup(_TEST_FG1_ID).get_feature(_TEST_FG1_F1_ID) + feature.delete(sync=sync) + + if not sync: + feature.wait() + + delete_feature_mock.assert_called_once_with( + name=_TEST_FG1_F1_PATH, + ) + + base_logger_mock.assert_has_calls( + [ + call( + "Deleting Feature resource:" + " projects/test-project/locations/us-central1/featureGroups/my_fg1/features/my_fg1_f1" + ), + call( + "Delete Feature backing LRO:" + f" 
{delete_feature_mock.return_value.operation.name}" + ), + call( + "Feature resource" + " projects/test-project/locations/us-central1/featureGroups/my_fg1/features/my_fg1_f1" + " deleted." + ), + ] + ) diff --git a/tests/unit/vertexai/test_vertexai_feature_group.py b/tests/unit/vertexai/test_vertexai_feature_group.py new file mode 100644 index 0000000000..5767e1eb9b --- /dev/null +++ b/tests/unit/vertexai/test_vertexai_feature_group.py @@ -0,0 +1,1027 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import re +from typing import Dict, List +from unittest import mock +from unittest.mock import call, patch + +from google.auth import credentials as auth_credentials +from google.api_core import operation as ga_operation +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from google.cloud.aiplatform_v1beta1.services.feature_registry_service import ( + FeatureRegistryServiceClient, +) +from vertexai.resources.preview.feature_store import ( + feature_group, +) +from vertexai.resources.preview import ( + FeatureGroup, +) +from vertexai.resources.preview.feature_store import ( + FeatureGroupBigQuerySource, +) +import pytest +from google.cloud.aiplatform.compat.services import ( + feature_registry_service_client, +) +from google.cloud.aiplatform.compat import types + + +from vertexai_feature_store_constants import ( + _TEST_PARENT, + _TEST_PROJECT, + _TEST_LOCATION, + _TEST_FG1, + _TEST_FG1_ID, + _TEST_FG1_PATH, + _TEST_FG1_BQ_URI, + _TEST_FG1_ENTITY_ID_COLUMNS, + _TEST_FG1_LABELS, + _TEST_FG2_ID, + _TEST_FG2_PATH, + _TEST_FG2_BQ_URI, + _TEST_FG2_ENTITY_ID_COLUMNS, + _TEST_FG2_LABELS, + _TEST_FG3_ID, + _TEST_FG3_PATH, + _TEST_FG3_BQ_URI, + _TEST_FG3_ENTITY_ID_COLUMNS, + _TEST_FG3_LABELS, + _TEST_FG_LIST, + _TEST_FG1_F1, + _TEST_FG1_F1_ID, + _TEST_FG1_F1_PATH, + _TEST_FG1_F1_DESCRIPTION, + _TEST_FG1_F1_LABELS, + _TEST_FG1_F1_POINT_OF_CONTACT, + _TEST_FG1_F1_WITH_STATS, + _TEST_FG1_F1_FEATURE_STATS_AND_ANOMALY, + _TEST_FG1_F2, + _TEST_FG1_F2_ID, + _TEST_FG1_F2_PATH, + _TEST_FG1_F2_DESCRIPTION, + _TEST_FG1_F2_LABELS, + _TEST_FG1_F2_POINT_OF_CONTACT, + _TEST_FG1_F2_VERSION_COLUMN_NAME, + _TEST_FG1_FEATURE_LIST, + _TEST_FG1_FM1, + _TEST_FG1_FM1_ID, + _TEST_FG1_FM1_PATH, + _TEST_FG1_FM1_DESCRIPTION, + _TEST_FG1_FM1_LABELS, + _TEST_FG1_FM1_FEATURE_SELECTION_CONFIGS, + _TEST_FG1_FM1_SCHEDULE_CONFIG, + _TEST_FG1_FM2_ID, + _TEST_FG1_FM2_PATH, + _TEST_FG1_FM2_DESCRIPTION, + _TEST_FG1_FM2_LABELS, + _TEST_FG1_FM2_FEATURE_SELECTION_CONFIGS, + 
_TEST_FG1_FM2_SCHEDULE_CONFIG, + _TEST_FG1_FM_LIST, +) +from test_vertexai_feature import feature_eq +from test_vertexai_feature_monitor import ( + feature_monitor_eq, +) + + +pytestmark = pytest.mark.usefixtures("google_auth_mock") + + +@pytest.fixture +def fg_logger_mock(): + with patch.object( + feature_group._LOGGER, + "info", + wraps=feature_group._LOGGER.info, + ) as logger_mock: + yield logger_mock + + +@pytest.fixture +def create_fg_mock(): + with patch.object( + feature_registry_service_client.FeatureRegistryServiceClient, + "create_feature_group", + ) as create_fg_mock: + create_fg_lro_mock = mock.Mock(ga_operation.Operation) + create_fg_lro_mock.result.return_value = _TEST_FG1 + create_fg_mock.return_value = create_fg_lro_mock + yield create_fg_mock + + +@pytest.fixture +def list_fg_mock(): + with patch.object( + feature_registry_service_client.FeatureRegistryServiceClient, + "list_feature_groups", + ) as list_fg_mock: + list_fg_mock.return_value = _TEST_FG_LIST + yield list_fg_mock + + +@pytest.fixture +def delete_fg_mock(): + with patch.object( + feature_registry_service_client.FeatureRegistryServiceClient, + "delete_feature_group", + ) as delete_fg_mock: + delete_fg_lro_mock = mock.Mock(ga_operation.Operation) + delete_fg_mock.return_value = delete_fg_lro_mock + yield delete_fg_mock + + +@pytest.fixture +def create_feature_mock(): + with patch.object( + feature_registry_service_client.FeatureRegistryServiceClient, + "create_feature", + ) as create_feature_mock: + create_feature_lro_mock = mock.Mock(ga_operation.Operation) + create_feature_lro_mock.result.return_value = _TEST_FG1_F1 + create_feature_mock.return_value = create_feature_lro_mock + yield create_feature_mock + + +@pytest.fixture +def create_feature_monitor_mock(): + with patch.object( + FeatureRegistryServiceClient, + "create_feature_monitor", + ) as create_feature_monitor_mock: + create_feature_monitor_lro_mock = mock.Mock(ga_operation.Operation) + 
create_feature_monitor_lro_mock.result.return_value = _TEST_FG1_FM1 + create_feature_monitor_mock.return_value = create_feature_monitor_lro_mock + yield create_feature_monitor_mock + + +@pytest.fixture +def create_feature_with_version_column_mock(): + with patch.object( + feature_registry_service_client.FeatureRegistryServiceClient, + "create_feature", + ) as create_feature_mock: + create_feature_lro_mock = mock.Mock(ga_operation.Operation) + create_feature_lro_mock.result.return_value = _TEST_FG1_F2 + create_feature_mock.return_value = create_feature_lro_mock + yield create_feature_mock + + +@pytest.fixture +def list_features_mock(): + with patch.object( + feature_registry_service_client.FeatureRegistryServiceClient, + "list_features", + ) as list_features_mock: + list_features_mock.return_value = _TEST_FG1_FEATURE_LIST + yield list_features_mock + + +@pytest.fixture +def list_feature_monitors_mock(): + with patch.object( + FeatureRegistryServiceClient, + "list_feature_monitors", + ) as list_feature_monitors_mock: + list_feature_monitors_mock.return_value = _TEST_FG1_FM_LIST + yield list_feature_monitors_mock + + +@pytest.fixture +def get_feature_with_stats_and_anomalies_mock(): + with patch.object( + FeatureRegistryServiceClient, + "get_feature", + ) as get_feature_with_stats_and_anomalies_mock: + get_feature_with_stats_and_anomalies_mock.return_value = _TEST_FG1_F1_WITH_STATS + yield get_feature_with_stats_and_anomalies_mock + + +@pytest.fixture() +def mock_base_instantiate_client(): + with patch.object( + aiplatform.base.VertexAiResourceNoun, + "_instantiate_client", + ) as base_instantiate_client_mock: + base_instantiate_client_mock.return_value = mock.MagicMock() + yield base_instantiate_client_mock + + +def fg_eq( + fg_to_check: FeatureGroup, + name: str, + resource_name: str, + source_uri: str, + entity_id_columns: List[str], + project: str, + location: str, + labels: Dict[str, str], +): + """Check if a FeatureGroup has the appropriate values set.""" + 
assert fg_to_check.name == name + assert fg_to_check.resource_name == resource_name + assert fg_to_check.source == FeatureGroupBigQuerySource( + uri=source_uri, + entity_id_columns=entity_id_columns, + ) + assert fg_to_check.project == project + assert fg_to_check.location == location + assert fg_to_check.labels == labels + + +@pytest.mark.parametrize( + "feature_group_name", + [_TEST_FG1_ID, _TEST_FG1_PATH], +) +def test_init(feature_group_name, get_fg_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fg = FeatureGroup(feature_group_name) + + get_fg_mock.assert_called_once_with( + name=_TEST_FG1_PATH, + retry=base._DEFAULT_RETRY, + ) + + fg_eq( + fg, + name=_TEST_FG1_ID, + resource_name=_TEST_FG1_PATH, + source_uri=_TEST_FG1_BQ_URI, + entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FG1_LABELS, + ) + + +def test_create_fg_no_source_raises_error(): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape("Please specify a valid source."), + ): + FeatureGroup.create("fg") + + +def test_create_fg_bad_source_raises_error(): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape("Only FeatureGroupBigQuerySource is a supported source."), + ): + FeatureGroup.create("fg", source=int(1)) + + +def test_create_fg_no_source_bq_uri_raises_error(): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape("Please specify URI in BigQuery source."), + ): + FeatureGroup.create( + "fg", source=FeatureGroupBigQuerySource(uri=None, entity_id_columns=None) + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +@pytest.mark.parametrize("sync", [True, False]) +def test_create_fg( + create_fg_mock, get_fg_mock, fg_logger_mock, create_request_timeout, sync +): + 
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fg = FeatureGroup.create( + _TEST_FG1_ID, + source=FeatureGroupBigQuerySource( + uri=_TEST_FG1_BQ_URI, + entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FG1_LABELS, + create_request_timeout=create_request_timeout, + sync=sync, + ) + + if not sync: + fg.wait() + + # When creating, the FeatureOnlineStore object doesn't have the path set. + expected_fg = types.feature_group.FeatureGroup( + name=_TEST_FG1_ID, + big_query=types.feature_group.FeatureGroup.BigQuery( + big_query_source=types.io.BigQuerySource( + input_uri=_TEST_FG1_BQ_URI, + ), + entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FG1_LABELS, + ) + create_fg_mock.assert_called_once_with( + parent=_TEST_PARENT, + feature_group=expected_fg, + feature_group_id=_TEST_FG1_ID, + metadata=(), + timeout=create_request_timeout, + ) + + fg_logger_mock.assert_has_calls( + [ + call("Creating FeatureGroup"), + call( + f"Create FeatureGroup backing LRO: {create_fg_mock.return_value.operation.name}" + ), + call( + "FeatureGroup created. 
Resource name: projects/test-project/locations/us-central1/featureGroups/my_fg1" + ), + call("To use this FeatureGroup in another session:"), + call( + "feature_group = aiplatform.FeatureGroup('projects/test-project/locations/us-central1/featureGroups/my_fg1')" + ), + ] + ) + + fg_eq( + fg, + name=_TEST_FG1_ID, + resource_name=_TEST_FG1_PATH, + source_uri=_TEST_FG1_BQ_URI, + entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FG1_LABELS, + ) + + +def test_list(list_fg_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + feature_groups = FeatureGroup.list() + + list_fg_mock.assert_called_once_with(request={"parent": _TEST_PARENT}) + assert len(feature_groups) == len(_TEST_FG_LIST) + fg_eq( + feature_groups[0], + name=_TEST_FG1_ID, + resource_name=_TEST_FG1_PATH, + source_uri=_TEST_FG1_BQ_URI, + entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FG1_LABELS, + ) + fg_eq( + feature_groups[1], + name=_TEST_FG2_ID, + resource_name=_TEST_FG2_PATH, + source_uri=_TEST_FG2_BQ_URI, + entity_id_columns=_TEST_FG2_ENTITY_ID_COLUMNS, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FG2_LABELS, + ) + fg_eq( + feature_groups[2], + name=_TEST_FG3_ID, + resource_name=_TEST_FG3_PATH, + source_uri=_TEST_FG3_BQ_URI, + entity_id_columns=_TEST_FG3_ENTITY_ID_COLUMNS, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FG3_LABELS, + ) + + +@pytest.mark.parametrize("force", [True, False]) +@pytest.mark.parametrize("sync", [True]) +def test_delete(delete_fg_mock, get_fg_mock, fg_logger_mock, force, sync): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fg = FeatureGroup(_TEST_FG1_ID) + fg.delete(force=force, sync=sync) + + if not sync: + fg.wait() + + delete_fg_mock.assert_called_once_with( + name=_TEST_FG1_PATH, + force=force, + ) + + fg_logger_mock.assert_has_calls( + [ + call( + "Deleting 
FeatureGroup resource: projects/test-project/locations/us-central1/featureGroups/my_fg1" + ), + call( + f"Delete FeatureGroup backing LRO: {delete_fg_mock.return_value.operation.name}" + ), + call( + "FeatureGroup resource projects/test-project/locations/us-central1/featureGroups/my_fg1 deleted." + ), + ] + ) + + +def test_get_feature(get_fg_mock, get_feature_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fg = FeatureGroup(_TEST_FG1_ID) + feature = fg.get_feature(_TEST_FG1_F1_ID) + + get_feature_mock.assert_called_once_with( + name=_TEST_FG1_F1_PATH, + retry=base._DEFAULT_RETRY, + ) + + feature_eq( + feature, + name=_TEST_FG1_F1_ID, + resource_name=_TEST_FG1_F1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + ) + + +def test_get_feature_with_latest_stats_count( + get_fg_mock, get_feature_with_stats_and_anomalies_mock +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fg = FeatureGroup(_TEST_FG1_ID) + feature = fg.get_feature(_TEST_FG1_F1_ID, latest_stats_count=1) + + get_feature_with_stats_and_anomalies_mock.assert_called_once_with( + request=types.featurestore_service_v1beta1.GetFeatureRequest( + name=_TEST_FG1_F1_PATH, + feature_stats_and_anomaly_spec=types.feature_monitor.FeatureStatsAndAnomalySpec( + latest_stats_count=1 + ), + ) + ) + + feature_eq( + feature, + name=_TEST_FG1_F1_ID, + resource_name=_TEST_FG1_F1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + feature_stats_and_anomalies=[_TEST_FG1_F1_FEATURE_STATS_AND_ANOMALY], + ) + + +def test_get_feature_credentials_set_in_init(mock_base_instantiate_client): + credentials = mock.MagicMock(spec=auth_credentials.Credentials) + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, 
credentials=credentials + ) + + mock_base_instantiate_client.return_value.get_feature_group.return_value = _TEST_FG1 + mock_base_instantiate_client.return_value.get_feature.return_value = _TEST_FG1_F1 + + fg = FeatureGroup(_TEST_FG1_ID) + mock_base_instantiate_client.assert_called_with( + location=_TEST_LOCATION, + credentials=credentials, + appended_user_agent=None, + ) + + feature = fg.get_feature(_TEST_FG1_F1_ID) + mock_base_instantiate_client.assert_called_with( + location=_TEST_LOCATION, + credentials=credentials, + appended_user_agent=None, + ) + + feature_eq( + feature, + name=_TEST_FG1_F1_ID, + resource_name=_TEST_FG1_F1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + ) + + +def test_get_feature_from_feature_group_with_explicit_credentials( + mock_base_instantiate_client, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + mock_base_instantiate_client.return_value.get_feature_group.return_value = _TEST_FG1 + mock_base_instantiate_client.return_value.get_feature.return_value = _TEST_FG1_F1 + + credentials = mock.MagicMock(spec=auth_credentials.Credentials) + fg = FeatureGroup(_TEST_FG1_ID, credentials=credentials) + mock_base_instantiate_client.assert_called_with( + location=_TEST_LOCATION, + credentials=credentials, + appended_user_agent=None, + ) + + feature = fg.get_feature(_TEST_FG1_F1_ID) + mock_base_instantiate_client.assert_called_with( + location=_TEST_LOCATION, + credentials=credentials, + appended_user_agent=None, + ) + + feature_eq( + feature, + name=_TEST_FG1_F1_ID, + resource_name=_TEST_FG1_F1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + ) + + +def test_get_feature_from_feature_group_with_explicit_credentials_overrides_init_credentials( + 
mock_base_instantiate_client, +): + init_credentials = mock.MagicMock(spec=auth_credentials.Credentials) + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=init_credentials + ) + + mock_base_instantiate_client.return_value.get_feature_group.return_value = _TEST_FG1 + mock_base_instantiate_client.return_value.get_feature.return_value = _TEST_FG1_F1 + + credentials = mock.MagicMock(spec=auth_credentials.Credentials) + fg = FeatureGroup(_TEST_FG1_ID, credentials=credentials) + mock_base_instantiate_client.assert_called_with( + location=_TEST_LOCATION, + credentials=credentials, + appended_user_agent=None, + ) + + feature = fg.get_feature(_TEST_FG1_F1_ID) + mock_base_instantiate_client.assert_called_with( + location=_TEST_LOCATION, + credentials=credentials, + appended_user_agent=None, + ) + + feature_eq( + feature, + name=_TEST_FG1_F1_ID, + resource_name=_TEST_FG1_F1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + ) + + +def test_get_feature_with_explicit_credentials(mock_base_instantiate_client): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + mock_base_instantiate_client.return_value.get_feature_group.return_value = _TEST_FG1 + mock_base_instantiate_client.return_value.get_feature.return_value = _TEST_FG1_F1 + + fg = FeatureGroup(_TEST_FG1_ID) + mock_base_instantiate_client.assert_called_with( + location=_TEST_LOCATION, + credentials=mock.ANY, + appended_user_agent=None, + ) + + credentials = mock.MagicMock(spec=auth_credentials.Credentials) + feature = fg.get_feature(_TEST_FG1_F1_ID, credentials=credentials) + mock_base_instantiate_client.assert_called_with( + location=_TEST_LOCATION, + credentials=credentials, + appended_user_agent=None, + ) + + feature_eq( + feature, + name=_TEST_FG1_F1_ID, + resource_name=_TEST_FG1_F1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + 
description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + ) + + +def test_get_feature_with_explicit_credentials_overrides_init_credentials( + mock_base_instantiate_client, +): + init_credentials = mock.MagicMock(spec=auth_credentials.Credentials) + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=init_credentials + ) + + mock_base_instantiate_client.return_value.get_feature_group.return_value = _TEST_FG1 + mock_base_instantiate_client.return_value.get_feature.return_value = _TEST_FG1_F1 + + fg = FeatureGroup(_TEST_FG1_ID) + mock_base_instantiate_client.assert_called_with( + location=_TEST_LOCATION, + credentials=init_credentials, + appended_user_agent=None, + ) + + credentials = mock.MagicMock(spec=auth_credentials.Credentials) + feature = fg.get_feature(_TEST_FG1_F1_ID, credentials=credentials) + mock_base_instantiate_client.assert_called_with( + location=_TEST_LOCATION, + credentials=credentials, + appended_user_agent=None, + ) + + feature_eq( + feature, + name=_TEST_FG1_F1_ID, + resource_name=_TEST_FG1_F1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + ) + + +def test_get_feature_with_explicit_credentials_overrides_feature_group_credentials( + mock_base_instantiate_client, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + mock_base_instantiate_client.return_value.get_feature_group.return_value = _TEST_FG1 + mock_base_instantiate_client.return_value.get_feature.return_value = _TEST_FG1_F1 + + feature_group_credentials = mock.MagicMock(spec=auth_credentials.Credentials) + fg = FeatureGroup(_TEST_FG1_ID, credentials=feature_group_credentials) + mock_base_instantiate_client.assert_called_with( + location=_TEST_LOCATION, + credentials=feature_group_credentials, + appended_user_agent=None, + ) + + credentials = 
mock.MagicMock(spec=auth_credentials.Credentials) + feature = fg.get_feature(_TEST_FG1_F1_ID, credentials=credentials) + mock_base_instantiate_client.assert_called_with( + location=_TEST_LOCATION, + credentials=credentials, + appended_user_agent=None, + ) + + feature_eq( + feature, + name=_TEST_FG1_F1_ID, + resource_name=_TEST_FG1_F1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + ) + + +def test_get_feature_with_explicit_credentials_overrides_init_and_feature_group_credentials( + mock_base_instantiate_client, +): + init_credentials = mock.MagicMock(spec=auth_credentials.Credentials) + aiplatform.init( + project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=init_credentials + ) + + mock_base_instantiate_client.return_value.get_feature_group.return_value = _TEST_FG1 + mock_base_instantiate_client.return_value.get_feature.return_value = _TEST_FG1_F1 + + feature_group_credentials = mock.MagicMock(spec=auth_credentials.Credentials) + fg = FeatureGroup(_TEST_FG1_ID, credentials=feature_group_credentials) + mock_base_instantiate_client.assert_called_with( + location=_TEST_LOCATION, + credentials=feature_group_credentials, + appended_user_agent=None, + ) + + credentials = mock.MagicMock(spec=auth_credentials.Credentials) + feature = fg.get_feature(_TEST_FG1_F1_ID, credentials=credentials) + mock_base_instantiate_client.assert_called_with( + location=_TEST_LOCATION, + credentials=credentials, + appended_user_agent=None, + ) + + feature_eq( + feature, + name=_TEST_FG1_F1_ID, + resource_name=_TEST_FG1_F1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +@pytest.mark.parametrize("sync", [True, False]) +def test_create_feature( + get_fg_mock, 
+ create_feature_mock, + get_feature_mock, + fg_logger_mock, + create_request_timeout, + sync, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fg = FeatureGroup(_TEST_FG1_ID) + feature = fg.create_feature( + _TEST_FG1_F1_ID, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + create_request_timeout=create_request_timeout, + sync=sync, + ) + + if not sync: + feature.wait() + + expected_feature = types.feature.Feature( + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + ) + create_feature_mock.assert_called_once_with( + parent=_TEST_FG1_PATH, + feature=expected_feature, + feature_id=_TEST_FG1_F1_ID, + metadata=(), + timeout=create_request_timeout, + ) + + feature_eq( + feature, + name=_TEST_FG1_F1_ID, + resource_name=_TEST_FG1_F1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + ) + + fg_logger_mock.assert_has_calls( + [ + call("Creating Feature"), + call( + f"Create Feature backing LRO: {create_feature_mock.return_value.operation.name}" + ), + call( + "Feature created. 
Resource name: projects/test-project/locations/us-central1/featureGroups/my_fg1/features/my_fg1_f1" + ), + call("To use this Feature in another session:"), + call( + "feature = aiplatform.Feature('projects/test-project/locations/us-central1/featureGroups/my_fg1/features/my_fg1_f1')" + ), + ] + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +@pytest.mark.parametrize("sync", [True, False]) +def test_create_feature_with_version_feature_column( + get_fg_mock, + create_feature_with_version_column_mock, + get_feature_with_version_column_mock, + fg_logger_mock, + create_request_timeout, + sync, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fg = FeatureGroup(_TEST_FG1_ID) + feature = fg.create_feature( + _TEST_FG1_F2_ID, + version_column_name=_TEST_FG1_F2_VERSION_COLUMN_NAME, + description=_TEST_FG1_F2_DESCRIPTION, + labels=_TEST_FG1_F2_LABELS, + point_of_contact=_TEST_FG1_F2_POINT_OF_CONTACT, + create_request_timeout=create_request_timeout, + sync=sync, + ) + + if not sync: + feature.wait() + + expected_feature = types.feature.Feature( + version_column_name=_TEST_FG1_F2_VERSION_COLUMN_NAME, + description=_TEST_FG1_F2_DESCRIPTION, + labels=_TEST_FG1_F2_LABELS, + point_of_contact=_TEST_FG1_F2_POINT_OF_CONTACT, + ) + create_feature_with_version_column_mock.assert_called_once_with( + parent=_TEST_FG1_PATH, + feature=expected_feature, + feature_id=_TEST_FG1_F2_ID, + metadata=(), + timeout=create_request_timeout, + ) + + feature_eq( + feature, + name=_TEST_FG1_F2_ID, + resource_name=_TEST_FG1_F2_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F2_DESCRIPTION, + labels=_TEST_FG1_F2_LABELS, + point_of_contact=_TEST_FG1_F2_POINT_OF_CONTACT, + version_column_name=_TEST_FG1_F2_VERSION_COLUMN_NAME, + ) + + fg_logger_mock.assert_has_calls( + [ + call("Creating Feature"), + call( + f"Create Feature backing LRO: {create_feature_with_version_column_mock.return_value.operation.name}" + ), + call( + 
"Feature created. Resource name: projects/test-project/locations/us-central1/featureGroups/my_fg1/features/my_fg1_f2" + ), + call("To use this Feature in another session:"), + call( + "feature = aiplatform.Feature('projects/test-project/locations/us-central1/featureGroups/my_fg1/features/my_fg1_f2')" + ), + ] + ) + + +def test_list_features(get_fg_mock, list_features_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + features = FeatureGroup(_TEST_FG1_ID).list_features() + + list_features_mock.assert_called_once_with(request={"parent": _TEST_FG1_PATH}) + assert len(features) == len(_TEST_FG1_FEATURE_LIST) + feature_eq( + features[0], + name=_TEST_FG1_F1_ID, + resource_name=_TEST_FG1_F1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F1_DESCRIPTION, + labels=_TEST_FG1_F1_LABELS, + point_of_contact=_TEST_FG1_F1_POINT_OF_CONTACT, + ) + feature_eq( + features[1], + name=_TEST_FG1_F2_ID, + resource_name=_TEST_FG1_F2_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_F2_DESCRIPTION, + labels=_TEST_FG1_F2_LABELS, + point_of_contact=_TEST_FG1_F2_POINT_OF_CONTACT, + version_column_name=_TEST_FG1_F2_VERSION_COLUMN_NAME, + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +def test_create_feature_monitor( + get_fg_mock, + get_feature_monitor_mock, + create_feature_monitor_mock, + fg_logger_mock, + create_request_timeout, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fg = FeatureGroup(_TEST_FG1_ID) + feature_monitor = fg.create_feature_monitor( + _TEST_FG1_FM1_ID, + description=_TEST_FG1_FM1_DESCRIPTION, + labels=_TEST_FG1_FM1_LABELS, + schedule_config=_TEST_FG1_FM1_SCHEDULE_CONFIG, + feature_selection_configs=_TEST_FG1_FM1_FEATURE_SELECTION_CONFIGS, + create_request_timeout=create_request_timeout, + ) + + expected_feature_monitor = types.feature_monitor.FeatureMonitor( + description=_TEST_FG1_FM1_DESCRIPTION, + labels=_TEST_FG1_FM1_LABELS, 
+ schedule_config=types.feature_monitor.ScheduleConfig( + cron=_TEST_FG1_FM1_SCHEDULE_CONFIG + ), + feature_selection_config=types.feature_monitor.FeatureSelectionConfig( + feature_configs=[ + types.feature_monitor.FeatureSelectionConfig.FeatureConfig( + feature_id="my_fg1_f1", drift_threshold=0.3 + ), + types.feature_monitor.FeatureSelectionConfig.FeatureConfig( + feature_id="my_fg1_f2", drift_threshold=0.4 + ), + ] + ), + ) + create_feature_monitor_mock.assert_called_once_with( + parent=_TEST_FG1_PATH, + feature_monitor_id=_TEST_FG1_FM1_ID, + feature_monitor=expected_feature_monitor, + metadata=(), + timeout=create_request_timeout, + ) + + feature_monitor_eq( + feature_monitor, + name=_TEST_FG1_FM1_ID, + resource_name=_TEST_FG1_FM1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_FM1_DESCRIPTION, + labels=_TEST_FG1_FM1_LABELS, + schedule_config=_TEST_FG1_FM1_SCHEDULE_CONFIG, + feature_selection_configs=_TEST_FG1_FM1_FEATURE_SELECTION_CONFIGS, + ) + + fg_logger_mock.assert_has_calls( + [ + call("Creating FeatureMonitor"), + call( + f"Create FeatureMonitor backing LRO:" + f" {create_feature_monitor_mock.return_value.operation.name}" + ), + call( + "FeatureMonitor created. 
Resource name:" + " projects/test-project/locations/us-central1/featureGroups/" + "my_fg1/featureMonitors/my_fg1_fm1" + ), + call("To use this FeatureMonitor in another session:"), + call( + "feature_monitor = aiplatform.FeatureMonitor(" + "'projects/test-project/locations/us-central1/featureGroups/" + "my_fg1/featureMonitors/my_fg1_fm1')" + ), + ] + ) + + +def test_list_feature_monitors( + get_fg_mock, get_feature_monitor_mock, list_feature_monitors_mock +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + feature_monitors = FeatureGroup(_TEST_FG1_ID).list_feature_monitors() + + list_feature_monitors_mock.assert_called_once_with( + request={"parent": _TEST_FG1_PATH} + ) + assert len(feature_monitors) == len(_TEST_FG1_FM_LIST) + feature_monitor_eq( + feature_monitors[0], + name=_TEST_FG1_FM1_ID, + resource_name=_TEST_FG1_FM1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_FM1_DESCRIPTION, + labels=_TEST_FG1_FM1_LABELS, + schedule_config=_TEST_FG1_FM1_SCHEDULE_CONFIG, + feature_selection_configs=_TEST_FG1_FM1_FEATURE_SELECTION_CONFIGS, + ) + feature_monitor_eq( + feature_monitors[1], + name=_TEST_FG1_FM2_ID, + resource_name=_TEST_FG1_FM2_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_FM2_DESCRIPTION, + labels=_TEST_FG1_FM2_LABELS, + schedule_config=_TEST_FG1_FM2_SCHEDULE_CONFIG, + feature_selection_configs=_TEST_FG1_FM2_FEATURE_SELECTION_CONFIGS, + ) diff --git a/tests/unit/vertexai/test_vertexai_feature_monitor.py b/tests/unit/vertexai/test_vertexai_feature_monitor.py new file mode 100644 index 0000000000..71191319b0 --- /dev/null +++ b/tests/unit/vertexai/test_vertexai_feature_monitor.py @@ -0,0 +1,370 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import re +from typing import Dict, List, Optional, Tuple +from unittest.mock import patch + +from google.cloud import aiplatform +from google.cloud.aiplatform import base + +from vertexai_feature_store_constants import ( + _TEST_PROJECT, + _TEST_LOCATION, + _TEST_FG1_ID, + _TEST_FG1_FM1_DESCRIPTION, + _TEST_FG1_FM1_FEATURE_SELECTION_CONFIGS, + _TEST_FG1_FM1_ID, + _TEST_FG1_FM1_LABELS, + _TEST_FG1_FM1_PATH, + _TEST_FG1_FM1_SCHEDULE_CONFIG, + _TEST_FG1_FMJ1, + _TEST_FG1_FMJ1_DESCRIPTION, + _TEST_FG1_FMJ1_FEATURE_STATS_AND_ANOMALIES, + _TEST_FG1_FMJ1_ID, + _TEST_FG1_FMJ1_LABELS, + _TEST_FG1_FMJ_LIST, + _TEST_FG1_FMJ1_PATH, + _TEST_FG1_FMJ2_DESCRIPTION, + _TEST_FG1_FMJ2_LABELS, + _TEST_FG1_FMJ2_PATH, +) +from vertexai.resources.preview import FeatureMonitor +from google.cloud.aiplatform_v1beta1.services.feature_registry_service import ( + FeatureRegistryServiceClient, +) +from google.cloud.aiplatform.compat import types +from vertexai.resources.preview.feature_store import ( + feature_monitor, +) +import pytest + + +pytestmark = pytest.mark.usefixtures("google_auth_mock") + + +@pytest.fixture +def fm_logger_mock(): + with patch.object( + feature_monitor._LOGGER, + "info", + wraps=feature_monitor._LOGGER.info, + ) as logger_mock: + yield logger_mock + + +@pytest.fixture +def get_feature_monitor_job_mock(): + with patch.object( + FeatureRegistryServiceClient, + "get_feature_monitor_job", + ) as get_fmj_mock: + get_fmj_mock.return_value = _TEST_FG1_FMJ1 + yield get_fmj_mock + + +@pytest.fixture +def create_feature_monitor_job_mock(): + with patch.object( + 
FeatureRegistryServiceClient, + "create_feature_monitor_job", + ) as create_feature_monitor_job_mock: + create_feature_monitor_job_mock.return_value = _TEST_FG1_FMJ1 + yield create_feature_monitor_job_mock + + +@pytest.fixture +def list_feature_monitor_jobs_mock(): + with patch.object( + FeatureRegistryServiceClient, + "list_feature_monitor_jobs", + ) as list_feature_monitor_jobs_mock: + list_feature_monitor_jobs_mock.return_value = _TEST_FG1_FMJ_LIST + yield list_feature_monitor_jobs_mock + + +def feature_monitor_eq( + feature_monitor_to_check: FeatureMonitor, + name: str, + resource_name: str, + project: str, + location: str, + description: str, + labels: Dict[str, str], + schedule_config: str, + feature_selection_configs: List[Tuple[str, float]], +): + """Check if a Feature Monitor has the appropriate values set.""" + assert feature_monitor_to_check.name == name + assert feature_monitor_to_check.resource_name == resource_name + assert feature_monitor_to_check.project == project + assert feature_monitor_to_check.location == location + assert feature_monitor_to_check.description == description + assert feature_monitor_to_check.labels == labels + assert feature_monitor_to_check.schedule_config == schedule_config + assert ( + feature_monitor_to_check.feature_selection_configs == feature_selection_configs + ) + + +def feature_monitor_job_eq( + feature_monitor_job_to_check: FeatureMonitor.FeatureMonitorJob, + resource_name: str, + project: str, + location: str, + description: str, + labels: Dict[str, str], + feature_stats_and_anomalies: Optional[ + List[types.feature_monitor.FeatureStatsAndAnomaly] + ] = None, +): + """Check if a Feature Monitor Job has the appropriate values set.""" + assert feature_monitor_job_to_check.resource_name == resource_name + assert feature_monitor_job_to_check.project == project + assert feature_monitor_job_to_check.location == location + assert feature_monitor_job_to_check.description == description + assert 
feature_monitor_job_to_check.labels == labels + if feature_stats_and_anomalies: + assert ( + feature_monitor_job_to_check.feature_stats_and_anomalies + == feature_stats_and_anomalies + ) + + +def test_init_with_feature_monitor_id_and_no_fg_id_raises_error(): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + with pytest.raises( + ValueError, + match=re.escape( + "Since feature monitor 'my_fg1_fm1' is not provided as a path, please" + " specify feature_group_id." + ), + ): + FeatureMonitor(_TEST_FG1_FM1_ID) + + +def test_init_with_feature_monitor_path_and_fg_id_raises_error(): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + with pytest.raises( + ValueError, + match=re.escape( + "Since feature monitor 'projects/test-project/locations/us-central1/" + "featureGroups/my_fg1/featureMonitors/my_fg1_fm1' is provided as a " + "path, feature_group_id should not be specified." + ), + ): + FeatureMonitor( + _TEST_FG1_FM1_PATH, + feature_group_id=_TEST_FG1_ID, + ) + + +def test_init_with_feature_monitor_id(get_feature_monitor_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + feature_monitor = FeatureMonitor( + _TEST_FG1_FM1_ID, + feature_group_id=_TEST_FG1_ID, + ) + + get_feature_monitor_mock.assert_called_once_with( + name=_TEST_FG1_FM1_PATH, + retry=base._DEFAULT_RETRY, + ) + + feature_monitor_eq( + feature_monitor, + name=_TEST_FG1_FM1_ID, + resource_name=_TEST_FG1_FM1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_FM1_DESCRIPTION, + labels=_TEST_FG1_FM1_LABELS, + schedule_config=_TEST_FG1_FM1_SCHEDULE_CONFIG, + feature_selection_configs=_TEST_FG1_FM1_FEATURE_SELECTION_CONFIGS, + ) + + +def test_init_with_feature_monitor_path(get_feature_monitor_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + feature_monitor = FeatureMonitor(_TEST_FG1_FM1_PATH) + + get_feature_monitor_mock.assert_called_once_with( + 
name=_TEST_FG1_FM1_PATH, + retry=base._DEFAULT_RETRY, + ) + + feature_monitor_eq( + feature_monitor, + name=_TEST_FG1_FM1_ID, + resource_name=_TEST_FG1_FM1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_FM1_DESCRIPTION, + labels=_TEST_FG1_FM1_LABELS, + schedule_config=_TEST_FG1_FM1_SCHEDULE_CONFIG, + feature_selection_configs=_TEST_FG1_FM1_FEATURE_SELECTION_CONFIGS, + ) + + +def test_init_with_feature_monitor_job_path(get_feature_monitor_job_mock): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + feature_monitor_job = FeatureMonitor.FeatureMonitorJob(_TEST_FG1_FMJ1_PATH) + + get_feature_monitor_job_mock.assert_called_once_with( + name=_TEST_FG1_FMJ1_PATH, + retry=base._DEFAULT_RETRY, + ) + + feature_monitor_job_eq( + feature_monitor_job, + resource_name=_TEST_FG1_FMJ1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_FMJ1_DESCRIPTION, + labels=_TEST_FG1_FMJ1_LABELS, + feature_stats_and_anomalies=_TEST_FG1_FMJ1_FEATURE_STATS_AND_ANOMALIES, + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +def test_create_feature_monitor_job( + get_feature_monitor_mock, + get_feature_monitor_job_mock, + create_feature_monitor_job_mock, + create_request_timeout, + fm_logger_mock, +): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + fm = FeatureMonitor( + _TEST_FG1_FM1_ID, + feature_group_id=_TEST_FG1_ID, + ) + feature_monitor_job = fm.create_feature_monitor_job( + description=_TEST_FG1_FMJ1_DESCRIPTION, + labels=_TEST_FG1_FMJ1_LABELS, + create_request_timeout=create_request_timeout, + ) + + expected_feature_monitor_job = types.feature_monitor_job.FeatureMonitorJob( + description=_TEST_FG1_FMJ1_DESCRIPTION, + labels=_TEST_FG1_FMJ1_LABELS, + ) + create_feature_monitor_job_mock.assert_called_once_with( + parent=_TEST_FG1_FM1_PATH, + feature_monitor_job=expected_feature_monitor_job, + metadata=(), + timeout=create_request_timeout, + ) + + 
feature_monitor_job_eq( + feature_monitor_job, + resource_name=_TEST_FG1_FMJ1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_FMJ1_DESCRIPTION, + labels=_TEST_FG1_FMJ1_LABELS, + feature_stats_and_anomalies=_TEST_FG1_FMJ1_FEATURE_STATS_AND_ANOMALIES, + ) + + +def test_get_feature_monitor_job( + get_feature_monitor_mock, get_feature_monitor_job_mock +): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + fm = FeatureMonitor( + _TEST_FG1_FM1_ID, + feature_group_id=_TEST_FG1_ID, + ) + feature_monitor_job = fm.get_feature_monitor_job(_TEST_FG1_FMJ1_ID) + + get_feature_monitor_job_mock.assert_called_once_with( + name=_TEST_FG1_FMJ1_PATH, + retry=base._DEFAULT_RETRY, + ) + + feature_monitor_job_eq( + feature_monitor_job, + resource_name=_TEST_FG1_FMJ1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_FMJ1_DESCRIPTION, + labels=_TEST_FG1_FMJ1_LABELS, + feature_stats_and_anomalies=_TEST_FG1_FMJ1_FEATURE_STATS_AND_ANOMALIES, + ) + + +def test_list_feature_monitors_jobs( + get_feature_monitor_mock, list_feature_monitor_jobs_mock +): + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + feature_monitor_jobs = FeatureMonitor( + _TEST_FG1_FM1_ID, + feature_group_id=_TEST_FG1_ID, + ).list_feature_monitor_jobs() + + list_feature_monitor_jobs_mock.assert_called_once_with( + request={"parent": _TEST_FG1_FM1_PATH} + ) + assert len(feature_monitor_jobs) == len(_TEST_FG1_FMJ_LIST) + feature_monitor_job_eq( + feature_monitor_jobs[0], + resource_name=_TEST_FG1_FMJ1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_FMJ1_DESCRIPTION, + labels=_TEST_FG1_FMJ1_LABELS, + ) + feature_monitor_job_eq( + feature_monitor_jobs[1], + resource_name=_TEST_FG1_FMJ2_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + description=_TEST_FG1_FMJ2_DESCRIPTION, + labels=_TEST_FG1_FMJ2_LABELS, + ) diff --git 
a/tests/unit/vertexai/test_vertexai_feature_online_store.py b/tests/unit/vertexai/test_vertexai_feature_online_store.py new file mode 100644 index 0000000000..5df9e0b339 --- /dev/null +++ b/tests/unit/vertexai/test_vertexai_feature_online_store.py @@ -0,0 +1,847 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import re +from typing import Dict +from unittest import mock +from unittest.mock import call +from unittest.mock import patch + +from google.api_core import operation as ga_operation +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from google.cloud.aiplatform.compat import types +from google.cloud.aiplatform.compat.services import ( + feature_online_store_admin_service_client, +) +from vertexai_feature_store_constants import ( + _TEST_BIGTABLE_FOS1_ID, + _TEST_BIGTABLE_FOS1_LABELS, + _TEST_BIGTABLE_FOS1_PATH, + _TEST_BIGTABLE_FOS2_ID, + _TEST_BIGTABLE_FOS2_LABELS, + _TEST_BIGTABLE_FOS2_PATH, + _TEST_BIGTABLE_FOS3_ID, + _TEST_BIGTABLE_FOS3_LABELS, + _TEST_BIGTABLE_FOS3_PATH, + _TEST_ESF_OPTIMIZED_FOS_ID, + _TEST_ESF_OPTIMIZED_FOS_LABELS, + _TEST_ESF_OPTIMIZED_FOS_PATH, + _TEST_FG1_F1_ID, + _TEST_FG1_F2_ID, + _TEST_FG2_F1_ID, + _TEST_FG2_F2_ID, + _TEST_FG1_ID, + _TEST_FG2_ID, + _TEST_FOS_LIST, + _TEST_FV1_BQ_URI, + _TEST_FV1_ENTITY_ID_COLUMNS, + _TEST_FV1_ID, + _TEST_FV1_LABELS, + _TEST_FV1_PATH, + _TEST_FV2_ID, + _TEST_FV2_LABELS, + _TEST_FV2_PATH, + _TEST_FV3_BQ_URI, 
+ _TEST_FV3_ID, + _TEST_FV3_LABELS, + _TEST_FV3_PATH, + _TEST_FV4_ID, + _TEST_FV4_LABELS, + _TEST_FV4_PATH, + _TEST_FV_LIST, + _TEST_LOCATION, + _TEST_OPTIMIZED_EMBEDDING_FV_ID, + _TEST_OPTIMIZED_EMBEDDING_FV_PATH, + _TEST_PARENT, + _TEST_PROJECT, + _TEST_PSC_OPTIMIZED_FOS_ID, + _TEST_PSC_OPTIMIZED_FOS_LABELS, + _TEST_PSC_OPTIMIZED_FOS_PATH, + _TEST_PSC_PROJECT_ALLOWLIST, +) +from test_vertexai_feature_view import fv_eq +from vertexai.resources.preview import ( + DistanceMeasureType, + FeatureOnlineStore, + FeatureOnlineStoreType, + FeatureViewBigQuerySource, + FeatureViewRegistrySource, + FeatureViewVertexRagSource, + IndexConfig, + TreeAhConfig, +) +from vertexai.resources.preview.feature_store import ( + feature_online_store, +) +import pytest + + +@pytest.fixture +def fos_logger_mock(): + with patch.object( + feature_online_store._LOGGER, + "info", + wraps=feature_online_store._LOGGER.info, + ) as logger_mock: + yield logger_mock + + +@pytest.fixture +def list_fos_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "list_feature_online_stores", + ) as list_fos_mock: + list_fos_mock.return_value = _TEST_FOS_LIST + yield list_fos_mock + + +@pytest.fixture +def delete_fos_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "delete_feature_online_store", + ) as delete_fos_mock: + delete_fos_lro_mock = mock.Mock(ga_operation.Operation) + delete_fos_mock.return_value = delete_fos_lro_mock + yield delete_fos_mock + + +def fos_eq( + fos_to_check: FeatureOnlineStore, + name: str, + resource_name: str, + project: str, + location: str, + labels: Dict[str, str], + type: FeatureOnlineStoreType, +): + """Check if a FeatureOnlineStore has the appropriate values set.""" + assert fos_to_check.name == name + assert fos_to_check.resource_name == resource_name + assert fos_to_check.project == project + assert fos_to_check.location == location + assert 
fos_to_check.labels == labels + assert fos_to_check.feature_online_store_type == type + + +pytestmark = pytest.mark.usefixtures("google_auth_mock") + + +@pytest.mark.parametrize( + "online_store_name", + [_TEST_BIGTABLE_FOS1_ID, _TEST_BIGTABLE_FOS1_PATH], +) +def test_init(online_store_name, get_fos_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fos = FeatureOnlineStore(online_store_name) + + get_fos_mock.assert_called_once_with( + name=_TEST_BIGTABLE_FOS1_PATH, retry=base._DEFAULT_RETRY + ) + + fos_eq( + fos, + name=_TEST_BIGTABLE_FOS1_ID, + resource_name=_TEST_BIGTABLE_FOS1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_BIGTABLE_FOS1_LABELS, + type=FeatureOnlineStoreType.BIGTABLE, + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +def test_create( + create_request_timeout, + create_bigtable_fos_mock, + get_fos_mock, + fos_logger_mock, + sync=True, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fos = FeatureOnlineStore.create_bigtable_store( + _TEST_BIGTABLE_FOS1_ID, + labels=_TEST_BIGTABLE_FOS1_LABELS, + create_request_timeout=create_request_timeout, + sync=sync, + ) + + if not sync: + fos.wait() + + # When creating, the FeatureOnlineStore object doesn't have the path set. 
+ expected_feature_online_store = types.feature_online_store_v1.FeatureOnlineStore( + bigtable=types.feature_online_store_v1.FeatureOnlineStore.Bigtable( + auto_scaling=types.feature_online_store_v1.FeatureOnlineStore.Bigtable.AutoScaling( + min_node_count=1, + max_node_count=1, + cpu_utilization_target=50, + ) + ), + labels=_TEST_BIGTABLE_FOS1_LABELS, + ) + create_bigtable_fos_mock.assert_called_once_with( + parent=_TEST_PARENT, + feature_online_store=expected_feature_online_store, + feature_online_store_id=_TEST_BIGTABLE_FOS1_ID, + metadata=(), + timeout=create_request_timeout, + ) + + fos_logger_mock.assert_has_calls( + [ + call("Creating FeatureOnlineStore"), + call( + f"Create FeatureOnlineStore backing LRO: {create_bigtable_fos_mock.return_value.operation.name}" + ), + call( + "FeatureOnlineStore created. Resource name: projects/test-project/locations/us-central1/featureOnlineStores/my_fos1" + ), + call("To use this FeatureOnlineStore in another session:"), + call( + "feature_online_store = aiplatform.FeatureOnlineStore('projects/test-project/locations/us-central1/featureOnlineStores/my_fos1')" + ), + ] + ) + + fos_eq( + fos, + name=_TEST_BIGTABLE_FOS1_ID, + resource_name=_TEST_BIGTABLE_FOS1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_BIGTABLE_FOS1_LABELS, + type=FeatureOnlineStoreType.BIGTABLE, + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +def test_create_esf_optimized_store( + create_request_timeout, + create_esf_optimized_fos_mock, + get_esf_optimized_fos_mock, + fos_logger_mock, + sync=True, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fos = FeatureOnlineStore.create_optimized_store( + _TEST_ESF_OPTIMIZED_FOS_ID, + labels=_TEST_ESF_OPTIMIZED_FOS_LABELS, + create_request_timeout=create_request_timeout, + sync=sync, + ) + + if not sync: + fos.wait() + + expected_feature_online_store = types.feature_online_store_v1.FeatureOnlineStore( + 
optimized=types.feature_online_store_v1.FeatureOnlineStore.Optimized(), + dedicated_serving_endpoint=types.feature_online_store_v1.FeatureOnlineStore.DedicatedServingEndpoint(), + labels=_TEST_ESF_OPTIMIZED_FOS_LABELS, + ) + create_esf_optimized_fos_mock.assert_called_once_with( + parent=_TEST_PARENT, + feature_online_store=expected_feature_online_store, + feature_online_store_id=_TEST_ESF_OPTIMIZED_FOS_ID, + metadata=(), + timeout=create_request_timeout, + ) + + fos_logger_mock.assert_has_calls( + [ + call("Creating FeatureOnlineStore"), + call( + "Create FeatureOnlineStore backing LRO:" + f" {create_esf_optimized_fos_mock.return_value.operation.name}" + ), + call( + "FeatureOnlineStore created. Resource name:" + " projects/test-project/locations/us-central1/featureOnlineStores/my_esf_optimized_fos" + ), + call("To use this FeatureOnlineStore in another session:"), + call( + "feature_online_store =" + " aiplatform.FeatureOnlineStore('projects/test-project/locations/us-central1/featureOnlineStores/my_esf_optimized_fos')" + ), + ] + ) + + fos_eq( + fos, + name=_TEST_ESF_OPTIMIZED_FOS_ID, + resource_name=_TEST_ESF_OPTIMIZED_FOS_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_ESF_OPTIMIZED_FOS_LABELS, + type=FeatureOnlineStoreType.OPTIMIZED, + ) + + +def test_create_psc_optimized_store_no_project_allowlist_raises_error(): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape( + "`project_allowlist` cannot be empty when `enable_private_service_connect` is" + " set to true." 
+ ), + ): + FeatureOnlineStore.create_optimized_store( + _TEST_PSC_OPTIMIZED_FOS_ID, + labels=_TEST_PSC_OPTIMIZED_FOS_LABELS, + enable_private_service_connect=True, + ) + + +def test_create_psc_optimized_store_empty_project_allowlist_raises_error(): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape( + "`project_allowlist` cannot be empty when `enable_private_service_connect` is" + " set to true." + ), + ): + FeatureOnlineStore.create_optimized_store( + _TEST_PSC_OPTIMIZED_FOS_ID, + enable_private_service_connect=True, + project_allowlist=[], + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +@pytest.mark.parametrize("sync", [True, False]) +def test_create_psc_optimized_store( + create_psc_optimized_fos_mock, + get_psc_optimized_fos_mock, + fos_logger_mock, + create_request_timeout, + sync, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore.create_optimized_store( + _TEST_PSC_OPTIMIZED_FOS_ID, + labels=_TEST_PSC_OPTIMIZED_FOS_LABELS, + create_request_timeout=create_request_timeout, + enable_private_service_connect=True, + project_allowlist=_TEST_PSC_PROJECT_ALLOWLIST, + ) + + if not sync: + fos.wait() + + expected_feature_online_store = types.feature_online_store_v1.FeatureOnlineStore( + optimized=types.feature_online_store_v1.FeatureOnlineStore.Optimized(), + dedicated_serving_endpoint=types.feature_online_store_v1.FeatureOnlineStore.DedicatedServingEndpoint( + private_service_connect_config=types.service_networking_v1.PrivateServiceConnectConfig( + enable_private_service_connect=True, + project_allowlist=_TEST_PSC_PROJECT_ALLOWLIST, + ) + ), + labels=_TEST_PSC_OPTIMIZED_FOS_LABELS, + ) + create_psc_optimized_fos_mock.assert_called_once_with( + parent=_TEST_PARENT, + feature_online_store=expected_feature_online_store, + feature_online_store_id=_TEST_PSC_OPTIMIZED_FOS_ID, + metadata=(), + timeout=create_request_timeout, + ) + + 
fos_logger_mock.assert_has_calls( + [ + call("Creating FeatureOnlineStore"), + call( + "Create FeatureOnlineStore backing LRO:" + f" {create_psc_optimized_fos_mock.return_value.operation.name}" + ), + call( + "FeatureOnlineStore created. Resource name:" + " projects/test-project/locations/us-central1/featureOnlineStores/my_psc_optimized_fos" + ), + call("To use this FeatureOnlineStore in another session:"), + call( + "feature_online_store =" + " aiplatform.FeatureOnlineStore('projects/test-project/locations/us-central1/featureOnlineStores/my_psc_optimized_fos')" + ), + ] + ) + + fos_eq( + fos, + name=_TEST_PSC_OPTIMIZED_FOS_ID, + resource_name=_TEST_PSC_OPTIMIZED_FOS_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_PSC_OPTIMIZED_FOS_LABELS, + type=FeatureOnlineStoreType.OPTIMIZED, + ) + + +def test_list(list_fos_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + online_stores = FeatureOnlineStore.list() + + list_fos_mock.assert_called_once_with(request={"parent": _TEST_PARENT}) + assert len(online_stores) == len(_TEST_FOS_LIST) + fos_eq( + online_stores[0], + name=_TEST_BIGTABLE_FOS1_ID, + resource_name=_TEST_BIGTABLE_FOS1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_BIGTABLE_FOS1_LABELS, + type=FeatureOnlineStoreType.BIGTABLE, + ) + fos_eq( + online_stores[1], + name=_TEST_BIGTABLE_FOS2_ID, + resource_name=_TEST_BIGTABLE_FOS2_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_BIGTABLE_FOS2_LABELS, + type=FeatureOnlineStoreType.BIGTABLE, + ) + fos_eq( + online_stores[2], + name=_TEST_BIGTABLE_FOS3_ID, + resource_name=_TEST_BIGTABLE_FOS3_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_BIGTABLE_FOS3_LABELS, + type=FeatureOnlineStoreType.BIGTABLE, + ) + + +@pytest.mark.parametrize("force", [True, False]) +def test_delete(force, delete_fos_mock, get_fos_mock, fos_logger_mock, sync=True): + aiplatform.init(project=_TEST_PROJECT, 
location=_TEST_LOCATION) + + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + fos.delete(force=force, sync=sync) + + if not sync: + fos.wait() + + delete_fos_mock.assert_called_once_with( + name=_TEST_BIGTABLE_FOS1_PATH, + force=force, + ) + + fos_logger_mock.assert_has_calls( + [ + call( + "Deleting FeatureOnlineStore resource: projects/test-project/locations/us-central1/featureOnlineStores/my_fos1" + ), + call( + f"Delete FeatureOnlineStore backing LRO: {delete_fos_mock.return_value.operation.name}" + ), + call( + "FeatureOnlineStore resource projects/test-project/locations/us-central1/featureOnlineStores/my_fos1 deleted." + ), + ] + ) + + +def test_create_fv_none_source_raises_error(get_fos_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + + with pytest.raises( + ValueError, + match=re.escape("Please specify a valid source."), + ): + fos.create_feature_view("bq_fv", None) + + +def test_create_fv_wrong_object_type_raises_error(get_fos_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + + with pytest.raises( + ValueError, + match=re.escape( + "Only FeatureViewBigQuerySource, FeatureViewVertexRagSource and" + " FeatureViewRegistrySource are supported sources." 
+ ), + ): + fos.create_feature_view("bq_fv", fos) + + +def test_create_bq_fv_bad_uri_raises_error(get_fos_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + + with pytest.raises( + ValueError, + match=re.escape("Please specify URI in BigQuery source."), + ): + fos.create_feature_view( + "bq_fv", + FeatureViewBigQuerySource(uri=None, entity_id_columns=["entity_id"]), + ) + + +@pytest.mark.parametrize("entity_id_columns", [None, []]) +def test_create_bq_fv_bad_entity_id_columns_raises_error( + entity_id_columns, get_fos_mock +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + + with pytest.raises( + ValueError, + match=re.escape("Please specify entity ID columns in BigQuery source."), + ): + fos.create_feature_view( + "bq_fv", + FeatureViewBigQuerySource(uri="hi", entity_id_columns=entity_id_columns), + ) + + +@pytest.mark.parametrize( + "features", [None, [], [".feature"], [".", "feature"], ["feature.", "feature"]] +) +def test_create_fr_fv_invalid_feature_name_raises_error(features, get_fos_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + + with pytest.raises( + ValueError, + match=re.escape( + "Please specify features in Registry Source in format `.`." 
+ ), + ): + fos.create_feature_view( + "fr_fv", + FeatureViewRegistrySource(features=features), + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +@pytest.mark.parametrize("sync", [True, False]) +def test_create_bq_fv( + create_request_timeout, + sync, + get_fos_mock, + create_bq_fv_mock, + get_fv_mock, + fos_logger_mock, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + + fv = fos.create_feature_view( + _TEST_FV1_ID, + FeatureViewBigQuerySource( + uri=_TEST_FV1_BQ_URI, entity_id_columns=_TEST_FV1_ENTITY_ID_COLUMNS + ), + labels=_TEST_FV1_LABELS, + create_request_timeout=create_request_timeout, + ) + + if not sync: + fos.wait() + + # When creating, the FeatureView object doesn't have the path set. + expected_fv = types.feature_view.FeatureView( + big_query_source=types.feature_view.FeatureView.BigQuerySource( + uri=_TEST_FV1_BQ_URI, + entity_id_columns=_TEST_FV1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FV1_LABELS, + ) + create_bq_fv_mock.assert_called_with( + parent=_TEST_BIGTABLE_FOS1_PATH, + feature_view=expected_fv, + feature_view_id=_TEST_FV1_ID, + metadata=(), + timeout=create_request_timeout, + ) + + fv_eq( + fv_to_check=fv, + name=_TEST_FV1_ID, + resource_name=_TEST_FV1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV1_LABELS, + ) + + fos_logger_mock.assert_has_calls( + [ + call("Creating FeatureView"), + call( + f"Create FeatureView backing LRO: {create_bq_fv_mock.return_value.operation.name}" + ), + call( + "FeatureView created. 
Resource name: projects/test-project/locations/us-central1/featureOnlineStores/my_fos1/featureViews/my_fv1" + ), + call("To use this FeatureView in another session:"), + call( + "feature_view = aiplatform.FeatureView('projects/test-project/locations/us-central1/featureOnlineStores/my_fos1/featureViews/my_fv1')" + ), + ] + ) + + +def test_create_embedding_fv( + get_esf_optimized_fos_mock, + create_embedding_fv_from_bq_mock, + get_optimized_embedding_fv_mock, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore(_TEST_ESF_OPTIMIZED_FOS_ID) + + embedding_fv = fos.create_feature_view( + _TEST_OPTIMIZED_EMBEDDING_FV_ID, + FeatureViewBigQuerySource(uri="hi", entity_id_columns=["entity_id"]), + index_config=IndexConfig( + embedding_column="embedding", + dimensions=1536, + filter_columns=["currency_code", "gender", "shipping_country_codes"], + crowding_column="crowding", + distance_measure_type=DistanceMeasureType.SQUARED_L2_DISTANCE, + algorithm_config=TreeAhConfig(), + ), + ) + fv_eq( + fv_to_check=embedding_fv, + name=_TEST_OPTIMIZED_EMBEDDING_FV_ID, + resource_name=_TEST_OPTIMIZED_EMBEDDING_FV_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV1_LABELS, + ) + + +def test_create_rag_fv_bad_uri_raises_error(get_fos_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + + with pytest.raises( + ValueError, + match=re.escape("Please specify URI in Vertex RAG source."), + ): + fos.create_feature_view( + "rag_fv", + FeatureViewVertexRagSource(uri=None), + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +@pytest.mark.parametrize("sync", [True, False]) +def test_create_rag_fv( + create_request_timeout, + sync, + get_fos_mock, + create_rag_fv_mock, + get_rag_fv_mock, + fos_logger_mock, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + + rag_fv = 
fos.create_feature_view( + _TEST_FV3_ID, + FeatureViewVertexRagSource(uri=_TEST_FV3_BQ_URI), + labels=_TEST_FV3_LABELS, + create_request_timeout=create_request_timeout, + ) + + if not sync: + fos.wait() + + # When creating, the FeatureView object doesn't have the path set. + expected_fv = types.feature_view.FeatureView( + vertex_rag_source=types.feature_view.FeatureView.VertexRagSource( + uri=_TEST_FV3_BQ_URI, + ), + labels=_TEST_FV3_LABELS, + ) + create_rag_fv_mock.assert_called_with( + parent=_TEST_BIGTABLE_FOS1_PATH, + feature_view=expected_fv, + feature_view_id=_TEST_FV3_ID, + metadata=(), + timeout=create_request_timeout, + ) + + fv_eq( + fv_to_check=rag_fv, + name=_TEST_FV3_ID, + resource_name=_TEST_FV3_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV3_LABELS, + ) + + fos_logger_mock.assert_has_calls( + [ + call("Creating FeatureView"), + call( + "Create FeatureView backing LRO:" + f" {create_rag_fv_mock.return_value.operation.name}" + ), + call( + "FeatureView created. 
Resource name:" + " projects/test-project/locations/us-central1/featureOnlineStores/my_fos1/featureViews/my_fv3" + ), + call("To use this FeatureView in another session:"), + call( + "feature_view =" + " aiplatform.FeatureView('projects/test-project/locations/us-central1/featureOnlineStores/my_fos1/featureViews/my_fv3')" + ), + ] + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +@pytest.mark.parametrize("sync", [True, False]) +def test_create_registry_fv( + create_request_timeout, + sync, + get_fos_mock, + create_registry_fv_mock, + get_registry_fv_mock, + fos_logger_mock, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + + registry_fv = fos.create_feature_view( + _TEST_FV4_ID, + FeatureViewRegistrySource( + features=[ + ".".join([_TEST_FG1_ID, _TEST_FG1_F1_ID]), + ".".join([_TEST_FG1_ID, _TEST_FG1_F2_ID]), + ".".join([_TEST_FG2_ID, _TEST_FG2_F1_ID]), + ".".join([_TEST_FG2_ID, _TEST_FG2_F2_ID]), + ] + ), + labels=_TEST_FV4_LABELS, + create_request_timeout=create_request_timeout, + ) + + if not sync: + fos.wait() + + # When creating, the FeatureView object doesn't have the path set. 
+ expected_fv = types.feature_view.FeatureView( + feature_registry_source=types.feature_view.FeatureView.FeatureRegistrySource( + feature_groups=[ + types.feature_view.FeatureView.FeatureRegistrySource.FeatureGroup( + feature_group_id=_TEST_FG1_ID, + feature_ids=[_TEST_FG1_F1_ID, _TEST_FG1_F2_ID], + ), + types.feature_view.FeatureView.FeatureRegistrySource.FeatureGroup( + feature_group_id=_TEST_FG2_ID, + feature_ids=[_TEST_FG2_F1_ID, _TEST_FG2_F2_ID], + ), + ], + ), + labels=_TEST_FV4_LABELS, + ) + create_registry_fv_mock.assert_called_with( + parent=_TEST_BIGTABLE_FOS1_PATH, + feature_view=expected_fv, + feature_view_id=_TEST_FV4_ID, + metadata=(), + timeout=create_request_timeout, + ) + + fv_eq( + fv_to_check=registry_fv, + name=_TEST_FV4_ID, + resource_name=_TEST_FV4_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV4_LABELS, + ) + + fos_logger_mock.assert_has_calls( + [ + call("Creating FeatureView"), + call( + "Create FeatureView backing LRO:" + f" {create_registry_fv_mock.return_value.operation.name}" + ), + call( + "FeatureView created. 
Resource name:" + " projects/test-project/locations/us-central1/featureOnlineStores/my_fos1/featureViews/my_fv4" + ), + call("To use this FeatureView in another session:"), + call( + "feature_view =" + " aiplatform.FeatureView('projects/test-project/locations/us-central1/featureOnlineStores/my_fos1/featureViews/my_fv4')" + ), + ] + ) + + +def test_list_feature_views( + get_fos_mock, + list_fv_mock, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + feature_views = fos.list_feature_views() + + list_fv_mock.assert_called_once_with(request={"parent": _TEST_BIGTABLE_FOS1_PATH}) + assert len(feature_views) == len(_TEST_FV_LIST) + + fv_eq( + feature_views[0], + name=_TEST_FV1_ID, + resource_name=_TEST_FV1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV1_LABELS, + ) + fv_eq( + feature_views[1], + name=_TEST_FV2_ID, + resource_name=_TEST_FV2_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV2_LABELS, + ) + fv_eq( + feature_views[2], + name=_TEST_FV3_ID, + resource_name=_TEST_FV3_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV3_LABELS, + ) + fv_eq( + feature_views[3], + name=_TEST_FV4_ID, + resource_name=_TEST_FV4_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV4_LABELS, + ) diff --git a/tests/unit/vertexai/test_vertexai_feature_view.py b/tests/unit/vertexai/test_vertexai_feature_view.py new file mode 100644 index 0000000000..724959ab95 --- /dev/null +++ b/tests/unit/vertexai/test_vertexai_feature_view.py @@ -0,0 +1,856 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import re +from typing import Dict +from unittest import mock +from unittest.mock import call, patch +from google.api_core import operation as ga_operation + +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from vertexai.resources.preview import ( + FeatureView, +) +import vertexai.resources.preview.feature_store.utils as fs_utils +import pytest +from google.cloud.aiplatform.compat.services import ( + feature_online_store_admin_service_client, + feature_online_store_service_client, +) +from vertexai.resources.preview.feature_store import ( + feature_view, +) + +from vertexai_feature_store_constants import ( + _TEST_BIGTABLE_FOS1_ID, + _TEST_BIGTABLE_FOS1_PATH, + _TEST_EMBEDDING_FV1_PATH, + _TEST_STRING_FILTER, + _TEST_FV1_ID, + _TEST_FV1_LABELS, + _TEST_FV1_PATH, + _TEST_FV2_ID, + _TEST_FV2_LABELS, + _TEST_FV2_PATH, + _TEST_FV3_ID, + _TEST_FV3_LABELS, + _TEST_FV3_PATH, + _TEST_FV4_ID, + _TEST_FV4_LABELS, + _TEST_FV4_PATH, + _TEST_FV_FETCH1, + _TEST_FV_LIST, + _TEST_FV_SEARCH1, + _TEST_FV_SYNC1, + _TEST_FV_SYNC1_ID, + _TEST_FV_SYNC1_PATH, + _TEST_FV_SYNC2_ID, + _TEST_FV_SYNC2_PATH, + _TEST_FV_SYNC_LIST, + _TEST_LOCATION, + _TEST_OPTIMIZED_FV1_PATH, + _TEST_OPTIMIZED_FV2_PATH, + _TEST_PROJECT, + _TEST_FV_SYNC1_RESPONSE, +) + + +pytestmark = pytest.mark.usefixtures("google_auth_mock") + + +@pytest.fixture +def fv_logger_mock(): + with patch.object( + feature_view._LOGGER, + "info", + wraps=feature_view._LOGGER.info, + ) as logger_mock: + yield logger_mock + + +@pytest.fixture +def delete_fv_mock(): + with patch.object( + 
feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "delete_feature_view", + ) as delete_fv: + delete_fv_lro_mock = mock.Mock(ga_operation.Operation) + delete_fv.return_value = delete_fv_lro_mock + yield delete_fv + + +@pytest.fixture +def get_fv_sync_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "get_feature_view_sync", + ) as get_fv_sync_mock: + get_fv_sync_mock.return_value = _TEST_FV_SYNC1 + yield get_fv_sync_mock + + +@pytest.fixture +def list_fv_syncs_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "list_feature_view_syncs", + ) as list_fv_syncs_mock: + list_fv_syncs_mock.return_value = _TEST_FV_SYNC_LIST + yield list_fv_syncs_mock + + +@pytest.fixture +def sync_fv_sync_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "sync_feature_view", + ) as sync_fv_sync_mock: + sync_fv_sync_mock.return_value = _TEST_FV_SYNC1_RESPONSE + yield sync_fv_sync_mock + + +@pytest.fixture +def fetch_feature_values_mock(): + with patch.object( + feature_online_store_service_client.FeatureOnlineStoreServiceClient, + "fetch_feature_values", + ) as fetch_feature_values_mock: + fetch_feature_values_mock.return_value = _TEST_FV_FETCH1 + yield fetch_feature_values_mock + + +@pytest.fixture +def search_nearest_entities_mock(): + with patch.object( + feature_online_store_service_client.FeatureOnlineStoreServiceClient, + "search_nearest_entities", + ) as search_nearest_entities_mock: + search_nearest_entities_mock.return_value = _TEST_FV_SEARCH1 + yield search_nearest_entities_mock + + +@pytest.fixture +def transport_mock(): + with mock.patch( + "google.cloud.aiplatform_v1.services.feature_online_store_service.transports.grpc.FeatureOnlineStoreServiceGrpcTransport" + ) as transport: + transport.return_value = mock.MagicMock(autospec=True) + yield transport + + 
+@pytest.fixture +def grpc_insecure_channel_mock(): + import grpc + + with mock.patch.object(grpc, "insecure_channel", autospec=True) as channel: + channel.return_value = mock.MagicMock(autospec=True) + yield channel + + +@pytest.fixture +def client_mock(): + with mock.patch( + "google.cloud.aiplatform_v1.services.feature_online_store_service.FeatureOnlineStoreServiceClient" + ) as client_mock: + yield client_mock + + +@pytest.fixture +def utils_client_with_override_mock(): + with mock.patch( + "google.cloud.aiplatform.utils.FeatureOnlineStoreClientWithOverride" + ) as client_mock: + yield client_mock + + +def fv_eq( + fv_to_check: FeatureView, + name: str, + resource_name: str, + project: str, + location: str, + labels: Dict[str, str], +): + """Check if a FeatureView has the appropriate values set.""" + assert fv_to_check.name == name + assert fv_to_check.resource_name == resource_name + assert fv_to_check.project == project + assert fv_to_check.location == location + assert fv_to_check.labels == labels + + +def fv_sync_eq( + fv_sync_to_check: FeatureView.FeatureViewSync, + name: str, + resource_name: str, + project: str, + location: str, +): + """Check if a FeatureViewSync has the appropriate values set.""" + assert fv_sync_to_check.name == name + assert fv_sync_to_check.resource_name == resource_name + assert fv_sync_to_check.project == project + assert fv_sync_to_check.location == location + + +def test_init_with_fv_id_and_no_fos_id_raises_error(get_fv_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape( + "Since feature view is not provided as a path, please specify" + + " feature_online_store_id." 
+ ), + ): + FeatureView(_TEST_FV1_ID) + + +def test_init_with_fv_id(get_fv_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fv = FeatureView(_TEST_FV1_ID, feature_online_store_id=_TEST_BIGTABLE_FOS1_ID) + + get_fv_mock.assert_called_once_with( + name=_TEST_FV1_PATH, + retry=base._DEFAULT_RETRY, + ) + + fv_eq( + fv_to_check=fv, + name=_TEST_FV1_ID, + resource_name=_TEST_FV1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV1_LABELS, + ) + + +def test_init_with_fv_path(get_fv_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fv = FeatureView(_TEST_FV1_PATH) + + get_fv_mock.assert_called_once_with( + name=_TEST_FV1_PATH, + retry=base._DEFAULT_RETRY, + ) + + fv_eq( + fv_to_check=fv, + name=_TEST_FV1_ID, + resource_name=_TEST_FV1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV1_LABELS, + ) + + +def test_list(list_fv_mock, get_fos_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + feature_views = FeatureView.list(feature_online_store_id=_TEST_BIGTABLE_FOS1_ID) + + list_fv_mock.assert_called_once_with(request={"parent": _TEST_BIGTABLE_FOS1_PATH}) + assert len(feature_views) == len(_TEST_FV_LIST) + + fv_eq( + feature_views[0], + name=_TEST_FV1_ID, + resource_name=_TEST_FV1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV1_LABELS, + ) + fv_eq( + feature_views[1], + name=_TEST_FV2_ID, + resource_name=_TEST_FV2_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV2_LABELS, + ) + fv_eq( + feature_views[2], + name=_TEST_FV3_ID, + resource_name=_TEST_FV3_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV3_LABELS, + ) + fv_eq( + feature_views[3], + name=_TEST_FV4_ID, + resource_name=_TEST_FV4_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV4_LABELS, + ) + + +def test_delete(delete_fv_mock, fv_logger_mock, get_fos_mock, get_fv_mock, sync=True): + 
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fv = FeatureView(name=_TEST_FV1_ID, feature_online_store_id=_TEST_BIGTABLE_FOS1_ID) + fv.delete() + + if not sync: + fv.wait() + + delete_fv_mock.assert_called_once_with(name=_TEST_FV1_PATH) + + fv_logger_mock.assert_has_calls( + [ + call( + "Deleting FeatureView resource: projects/test-project/locations/us-central1/featureOnlineStores/my_fos1/featureViews/my_fv1" + ), + call( + f"Delete FeatureView backing LRO: {delete_fv_mock.return_value.operation.name}" + ), + call( + "FeatureView resource projects/test-project/locations/us-central1/featureOnlineStores/my_fos1/featureViews/my_fv1 deleted." + ), + ] + ) + + +def test_get_sync(get_fv_mock, get_fv_sync_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fv_sync = FeatureView(_TEST_FV1_PATH).get_sync(_TEST_FV_SYNC1_ID) + + get_fv_mock.assert_called_once_with( + name=_TEST_FV1_PATH, + retry=base._DEFAULT_RETRY, + ) + + get_fv_sync_mock.assert_called_once_with( + name=_TEST_FV_SYNC1_PATH, + retry=base._DEFAULT_RETRY, + ) + + fv_sync_eq( + fv_sync_to_check=fv_sync, + name=_TEST_FV_SYNC1_ID, + resource_name=_TEST_FV_SYNC1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + +def test_list_syncs(get_fv_mock, list_fv_syncs_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fv_syncs = FeatureView(_TEST_FV1_PATH).list_syncs() + + get_fv_mock.assert_called_once_with( + name=_TEST_FV1_PATH, + retry=base._DEFAULT_RETRY, + ) + + list_fv_syncs_mock.assert_called_once_with(request={"parent": _TEST_FV1_PATH}) + assert len(fv_syncs) == len(_TEST_FV_SYNC_LIST) + + fv_sync_eq( + fv_sync_to_check=fv_syncs[0], + name=_TEST_FV_SYNC1_ID, + resource_name=_TEST_FV_SYNC1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + fv_sync_eq( + fv_sync_to_check=fv_syncs[1], + name=_TEST_FV_SYNC2_ID, + resource_name=_TEST_FV_SYNC2_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + +def 
test_on_demand_sync(get_fv_mock, get_fv_sync_mock, sync_fv_sync_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fv_sync = FeatureView(_TEST_FV1_PATH).sync() + + get_fv_mock.assert_called_once_with( + name=_TEST_FV1_PATH, + retry=base._DEFAULT_RETRY, + ) + + sync_fv_sync_mock.assert_called_once_with( + request={"feature_view": _TEST_FV1_PATH}, + ) + + get_fv_sync_mock.assert_called_once_with( + name=_TEST_FV_SYNC1_PATH, + retry=base._DEFAULT_RETRY, + ) + + fv_sync_eq( + fv_sync_to_check=fv_sync, + name=_TEST_FV_SYNC1_ID, + resource_name=_TEST_FV_SYNC1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + +@pytest.mark.parametrize("output_type", ["dict", "proto"]) +def test_fetch_feature_values_bigtable( + get_fos_mock, get_fv_mock, fetch_feature_values_mock, fv_logger_mock, output_type +): + if output_type == "dict": + fv_dict = FeatureView(_TEST_FV1_PATH).read(key=["key1"]).to_dict() + assert fv_dict == { + "features": [{"name": "key1", "value": {"string_value": "value1"}}] + } + elif output_type == "proto": + fv_proto = FeatureView(_TEST_FV1_PATH).read(key=["key1"]).to_proto() + assert fv_proto == _TEST_FV_FETCH1 + + fv_logger_mock.assert_has_calls( + [ + call("Connecting to Bigtable online store name my_fos1"), + ] + ) + + +@pytest.mark.parametrize("output_type", ["dict", "proto"]) +def test_fetch_feature_values_optimized( + get_esf_optimized_fos_mock, + get_optimized_fv_mock, + fetch_feature_values_mock, + fv_logger_mock, + output_type, +): + if output_type == "dict": + fv_dict = FeatureView(_TEST_OPTIMIZED_FV1_PATH).read(key=["key1"]).to_dict() + assert fv_dict == { + "features": [{"name": "key1", "value": {"string_value": "value1"}}] + } + elif output_type == "proto": + fv_proto = FeatureView(_TEST_OPTIMIZED_FV1_PATH).read(key=["key1"]).to_proto() + assert fv_proto == _TEST_FV_FETCH1 + + fv_logger_mock.assert_has_calls( + [ + call( + "Public endpoint for the optimized online store my_esf_optimized_fos is 
test-esf-endpoint" + ), + ] + ) + + +def test_fetch_feature_values_optimized_no_endpoint( + get_esf_optimized_fos_no_endpoint_mock, + get_optimized_fv_no_endpointmock, + fetch_feature_values_mock, +): + """Tests that the public endpoint is not created for the optimized online store.""" + with pytest.raises( + fs_utils.PublicEndpointNotFoundError, + match=re.escape( + "Public endpoint is not created yet for the optimized online " + "store:my_esf_optimised_fos2. Please run sync and wait for it " + "to complete." + ), + ): + FeatureView(_TEST_OPTIMIZED_FV2_PATH).read(key=["key1"]).to_dict() + + +def test_ffv_optimized_psc_with_no_connection_options_raises_error( + get_psc_optimized_fos_mock, + get_optimized_fv_mock, +): + with pytest.raises(ValueError) as excinfo: + FeatureView(_TEST_OPTIMIZED_FV1_PATH).read(key=["key1"]) + + assert str(excinfo.value) == ( + "Use `connection_options` to specify an IP address. Required for optimized online store with private service connect." + ) + + +def test_ffv_optimized_psc_with_no_connection_transport_raises_error( + get_psc_optimized_fos_mock, + get_optimized_fv_mock, +): + with pytest.raises(ValueError) as excinfo: + FeatureView(_TEST_OPTIMIZED_FV1_PATH).read( + key=["key1"], + connection_options=fs_utils.ConnectionOptions( + host="1.2.3.4", transport=None + ), + ) + + assert str(excinfo.value) == ( + "Unsupported connection transport type, got transport: None" + ) + + +def test_ffv_optimized_psc_with_bad_connection_transport_raises_error( + get_psc_optimized_fos_mock, + get_optimized_fv_mock, +): + with pytest.raises(ValueError) as excinfo: + FeatureView(_TEST_OPTIMIZED_FV1_PATH).read( + key=["key1"], + connection_options=fs_utils.ConnectionOptions( + host="1.2.3.4", transport="hi" + ), + ) + + assert str(excinfo.value) == ( + "Unsupported connection transport type, got transport: hi" + ) + + +@pytest.mark.parametrize("output_type", ["dict", "proto"]) +def test_ffv_optimized_psc( + get_psc_optimized_fos_mock, + 
get_optimized_fv_mock, + transport_mock, + grpc_insecure_channel_mock, + fetch_feature_values_mock, + output_type, +): + rsp = FeatureView(_TEST_OPTIMIZED_FV1_PATH).read( + key=["key1"], + connection_options=fs_utils.ConnectionOptions( + host="1.2.3.4", + transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(), + ), + ) + + # Ensure that we create and use insecure channel to the target. + grpc_insecure_channel_mock.assert_called_once_with("1.2.3.4:10002") + transport_grpc_channel = transport_mock.call_args.kwargs["channel"] + assert transport_grpc_channel == grpc_insecure_channel_mock.return_value + + if output_type == "dict": + assert rsp.to_dict() == { + "features": [{"name": "key1", "value": {"string_value": "value1"}}] + } + elif output_type == "proto": + assert rsp.to_proto() == _TEST_FV_FETCH1 + + +def test_same_connection_options_are_equal(): + opt1 = fs_utils.ConnectionOptions( + host="1.1.1.1", + transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(), + ) + opt2 = fs_utils.ConnectionOptions( + host="1.1.1.1", + transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(), + ) + assert opt1 == opt2 + + +def test_different_host_in_connection_options_are_not_equal(): + opt1 = fs_utils.ConnectionOptions( + host="1.1.1.2", + transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(), + ) + opt2 = fs_utils.ConnectionOptions( + host="1.1.1.1", + transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(), + ) + + assert opt1 != opt2 + + +def test_bad_transport_in_compared_connection_options_raises_error(): + opt1 = fs_utils.ConnectionOptions( + host="1.1.1.1", + transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(), + ) + opt2 = fs_utils.ConnectionOptions( + host="1.1.1.1", + transport=None, + ) + + with pytest.raises(ValueError) as excinfo: + assert opt1 != opt2 + + assert str(excinfo.value) == ( + "Transport 'ConnectionOptions.InsecureGrpcChannel()' cannot be compared to transport 'None'." 
+    )
+
+
+def test_bad_transport_in_connection_options_raises_error():
+    """Comparing options whose transport is ``None`` raises ``ValueError``.
+
+    One ``ConnectionOptions`` is built with ``transport=None`` and the other
+    with an ``InsecureGrpcChannel``; evaluating ``opt1 != opt2`` is expected to
+    fail with "Unsupported transport supplied: None".
+    """
+    opt1 = fs_utils.ConnectionOptions(
+        host="1.1.1.1",
+        transport=None,
+    )
+    opt2 = fs_utils.ConnectionOptions(
+        host="1.1.1.1",
+        transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(),
+    )
+
+    # The comparison itself is what is expected to raise.
+    with pytest.raises(ValueError) as excinfo:
+        assert opt1 != opt2
+
+    assert str(excinfo.value) == ("Unsupported transport supplied: None")
+
+
+def test_same_connection_options_have_same_hash():
+    """Options with equal host and transport are interchangeable as dict keys."""
+    opt1 = fs_utils.ConnectionOptions(
+        host="1.1.1.1",
+        transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(),
+    )
+    opt2 = fs_utils.ConnectionOptions(
+        host="1.1.1.1",
+        transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(),
+    )
+
+    # A value stored under opt1 must be retrievable through the equal opt2.
+    d = {}
+    d[opt1] = "hi"
+    assert d[opt2] == "hi"
+
+
+@pytest.mark.parametrize(
+    "hosts",
+    [
+        ("1.1.1.1", "1.1.1.2"),
+        ("1.1.1.2", "1.1.1.1"),
+        ("10.0.0.1", "9.9.9.9"),
+    ],
+)
+def test_different_host_in_connection_options_have_different_hash(hosts):
+    """Options that differ only in host are not found under each other's key.
+
+    Args:
+        hosts: Pair of host strings; the first builds the stored key and the
+            second builds the key that must be absent from the dict.
+    """
+    opt1 = fs_utils.ConnectionOptions(
+        host=hosts[0],
+        transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(),
+    )
+    opt2 = fs_utils.ConnectionOptions(
+        host=hosts[1],
+        transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(),
+    )
+
+    d = {}
+    d[opt1] = "hi"
+    assert opt2 not in d
+
+
+@pytest.mark.parametrize(
+    "transports",
+    [
+        (fs_utils.ConnectionOptions.InsecureGrpcChannel(), None),
+        (None, fs_utils.ConnectionOptions.InsecureGrpcChannel()),
+        (None, "hi"),
+        ("hi", None),
+    ],
+)
+def test_bad_transport_in_connection_options_have_different_hash(transports):
+    """Options that differ only in transport are not found under each other's key.
+
+    Args:
+        transports: Pair of transport values (``InsecureGrpcChannel``, ``None``
+            or a plain string); the first builds the stored key and the second
+            builds the key that must be absent from the dict.
+    """
+    opt1 = fs_utils.ConnectionOptions(
+        host="1.1.1.1",
+        transport=transports[0],
+    )
+    opt2 = fs_utils.ConnectionOptions(
+        host="1.1.1.1",
+        transport=transports[1],
+    )
+
+    d = {}
+    d[opt1] = "hi"
+    assert opt2 not in d
+
+
+def test_diff_host_and_bad_transport_in_connection_options_have_different_hash():
+    """Options that differ in both host and transport do not share a dict key."""
+    opt1 = fs_utils.ConnectionOptions(
+        host="1.1.1.1",
+        transport=None,
+    )
+    opt2 = fs_utils.ConnectionOptions(
+        host="9.9.9.9",
+        transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(),
+    )
+
+    d = {}
+    d[opt1] = "hi"
+    assert opt2 not in d
+
+
+def test_ffv_optimized_psc_reuse_client_for_same_connection_options_in_same_ffv(
+    get_psc_optimized_fos_mock,
+    get_optimized_fv_mock,
+    client_mock,
+    transport_mock,
+    grpc_insecure_channel_mock,
+    fetch_feature_values_mock,
+):
+    """Two reads with equal connection options create one channel and transport.
+
+    The same ``FeatureView`` is read twice with separately constructed but
+    equal ``ConnectionOptions``; the insecure channel mock and the transport
+    mock must each have been called exactly once.
+    """
+    fv = FeatureView(_TEST_OPTIMIZED_FV1_PATH)
+    fv.read(
+        key=["key1"],
+        connection_options=fs_utils.ConnectionOptions(
+            host="1.1.1.1",
+            transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(),
+        ),
+    )
+    fv.read(
+        key=["key2"],
+        connection_options=fs_utils.ConnectionOptions(
+            host="1.1.1.1",
+            transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(),
+        ),
+    )
+
+    # Insecure channel and transport creation should only be done once.
+    assert grpc_insecure_channel_mock.call_args_list == [mock.call("1.1.1.1:10002")]
+    assert transport_mock.call_args_list == [
+        mock.call(channel=grpc_insecure_channel_mock.return_value),
+    ]
+
+
+def test_ffv_optimized_psc_different_client_for_different_connection_options(
+    get_psc_optimized_fos_mock,
+    get_optimized_fv_mock,
+    client_mock,
+    transport_mock,
+    grpc_insecure_channel_mock,
+    fetch_feature_values_mock,
+):
+    """Reads with different hosts create one channel and transport per host.
+
+    The insecure channel mock hands out a distinct channel per call so the
+    test can check that each transport was built from its own channel.
+    """
+    # Return two different grpc channels each time insecure channel is called.
+    import grpc
+
+    grpc_chan1 = mock.MagicMock(spec=grpc.Channel)
+    grpc_chan2 = mock.MagicMock(spec=grpc.Channel)
+    grpc_insecure_channel_mock.side_effect = [grpc_chan1, grpc_chan2]
+
+    fv = FeatureView(_TEST_OPTIMIZED_FV1_PATH)
+    fv.read(
+        key=["key1"],
+        connection_options=fs_utils.ConnectionOptions(
+            host="1.1.1.1",
+            transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(),
+        ),
+    )
+    fv.read(
+        key=["key2"],
+        connection_options=fs_utils.ConnectionOptions(
+            host="1.2.3.4",
+            transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(),
+        ),
+    )
+
+    # Insecure channel and transport creation should be done twice - one for each different connection.
+    assert grpc_insecure_channel_mock.call_args_list == [
+        mock.call("1.1.1.1:10002"),
+        mock.call("1.2.3.4:10002"),
+    ]
+    assert transport_mock.call_args_list == [
+        mock.call(channel=grpc_chan1),
+        mock.call(channel=grpc_chan2),
+    ]
+
+
+def test_ffv_optimized_psc_bad_gapic_client_raises_error(
+    get_psc_optimized_fos_mock, get_optimized_fv_mock, utils_client_with_override_mock
+):
+    """``read`` raises ``ValueError`` naming the unexpected gapic client class.
+
+    The expected message embeds the value returned by the mocked
+    ``get_gapic_client_class``.
+    """
+    with pytest.raises(ValueError) as excinfo:
+        FeatureView(_TEST_OPTIMIZED_FV1_PATH).read(
+            key=["key1"],
+            connection_options=fs_utils.ConnectionOptions(
+                host="1.1.1.1",
+                transport=fs_utils.ConnectionOptions.InsecureGrpcChannel(),
+            ),
+        )
+
+    assert str(excinfo.value) == (
+        f"Unexpected gapic class '{utils_client_with_override_mock.get_gapic_client_class.return_value}' used by internal client."
+    )
+
+
+@pytest.mark.parametrize("output_type", ["dict", "proto"])
+def test_search_nearest_entities(
+    get_esf_optimized_fos_mock,
+    get_embedding_fv_mock,
+    search_nearest_entities_mock,
+    fv_logger_mock,
+    output_type,
+):
+    """``search`` results convert to the expected dict or proto.
+
+    The "dict" case searches by ``entity_id`` with every optional argument
+    set; the "proto" case searches by ``embedding_value`` only. Both cases
+    then check that the public endpoint message was logged.
+
+    Args:
+        output_type: "dict" or "proto"; selects which conversion is exercised.
+    """
+    if output_type == "dict":
+        fv_dict = (
+            # Test with entity_id input.
+            FeatureView(_TEST_EMBEDDING_FV1_PATH)
+            .search(
+                entity_id="key1",
+                neighbor_count=2,
+                string_filters=[_TEST_STRING_FILTER],
+                per_crowding_attribute_neighbor_count=1,
+                return_full_entity=True,
+                approximate_neighbor_candidates=3,
+                leaf_nodes_search_fraction=0.5,
+            )
+            .to_dict()
+        )
+        assert fv_dict == {
+            "neighbors": [{"distance": 0.1, "entity_id": "neighbor_entity_id_1"}]
+        }
+    elif output_type == "proto":
+        fv_proto = (
+            # Test with embedding_value input.
+            FeatureView(_TEST_EMBEDDING_FV1_PATH)
+            .search(embedding_value=[0.1, 0.2, 0.3])
+            .to_proto()
+        )
+        assert fv_proto == _TEST_FV_SEARCH1
+
+    fv_logger_mock.assert_has_calls(
+        [
+            call(
+                "Public endpoint for the optimized online store my_esf_optimized_fos"
+                " is test-esf-endpoint"
+            ),
+        ]
+    )
+
+
+def test_search_nearest_entities_without_entity_id_or_embedding(
+    get_esf_optimized_fos_mock,
+    get_embedding_fv_mock,
+    search_nearest_entities_mock,
+    fv_logger_mock,
+):
+    """``search`` without ``entity_id`` or ``embedding_value`` raises ``ValueError``.
+
+    ``pytest.raises`` makes the test fail when no error is raised; assertions
+    placed only inside an ``except`` clause would be skipped in that case.
+    """
+    with pytest.raises(ValueError) as excinfo:
+        FeatureView(_TEST_EMBEDDING_FV1_PATH).search().to_proto()
+
+    assert str(excinfo.value) == (
+        "Either entity_id or embedding_value needs to be provided for search."
+    )
+
+
+def test_search_nearest_entities_no_endpoint(
+    get_esf_optimized_fos_no_endpoint_mock,
+    get_optimized_fv_no_endpointmock,
+    fetch_feature_values_mock,
+):
+    """Tests that the public endpoint is not created for the optimized online store.
+
+    ``search`` must raise ``PublicEndpointNotFoundError`` with the expected
+    message. ``pytest.raises`` makes the test fail when no error is raised;
+    assertions placed only inside an ``except`` clause would be skipped then.
+    """
+    with pytest.raises(fs_utils.PublicEndpointNotFoundError) as excinfo:
+        FeatureView(_TEST_OPTIMIZED_FV2_PATH).search(entity_id="key1").to_dict()
+
+    assert str(excinfo.value) == (
+        "Public endpoint is not created yet for the optimized online "
+        "store:my_esf_optimised_fos2. Please run sync and wait for it "
+        "to complete."
+    )
diff --git a/tests/unit/vertexai/test_model_monitors.py b/tests/unit/vertexai/test_vertexai_model_monitors.py
similarity index 100%
rename from tests/unit/vertexai/test_model_monitors.py
rename to tests/unit/vertexai/test_vertexai_model_monitors.py
diff --git a/tests/unit/vertexai/feature_store_constants.py b/tests/unit/vertexai/vertexai_feature_store_constants.py
similarity index 100%
rename from tests/unit/vertexai/feature_store_constants.py
rename to tests/unit/vertexai/vertexai_feature_store_constants.py