From d9472958a05525650960a11e06aec011109a26bc Mon Sep 17 00:00:00 2001 From: Yvonne Yu Date: Wed, 6 May 2026 17:12:59 -0700 Subject: [PATCH] feat: migrate rag to agentplatform PiperOrigin-RevId: 911632603 --- agentplatform/preview/__init__.py | 2 +- agentplatform/preview/rag/__init__.py | 169 ++ agentplatform/preview/rag/rag_data.py | 1235 ++++++++++++++ agentplatform/preview/rag/rag_retrieval.py | 536 +++++++ agentplatform/preview/rag/rag_store.py | 168 ++ .../preview/rag/utils/_gapic_utils.py | 1418 +++++++++++++++++ agentplatform/preview/rag/utils/resources.py | 817 ++++++++++ agentplatform/rag/__init__.py | 127 ++ agentplatform/rag/rag_data.py | 1152 +++++++++++++ agentplatform/rag/rag_inline_citations.py | 176 ++ agentplatform/rag/rag_retrieval.py | 505 ++++++ agentplatform/rag/rag_store.py | 165 ++ agentplatform/rag/utils/_gapic_utils.py | 824 ++++++++++ agentplatform/rag/utils/resources.py | 545 +++++++ tests/unit/vertex_rag/conftest.py | 14 +- tests/unit/vertex_rag/test_rag_constants.py | 6 +- .../vertex_rag/test_rag_constants_preview.py | 38 +- tests/unit/vertex_rag/test_rag_data.py | 10 +- .../unit/vertex_rag/test_rag_data_preview.py | 360 +---- tests/unit/vertex_rag/test_rag_retrieval.py | 7 +- .../vertex_rag/test_rag_retrieval_preview.py | 132 +- tests/unit/vertex_rag/test_rag_store.py | 121 +- .../unit/vertex_rag/test_rag_store_preview.py | 224 +-- 23 files changed, 7992 insertions(+), 759 deletions(-) create mode 100644 agentplatform/preview/rag/__init__.py create mode 100644 agentplatform/preview/rag/rag_data.py create mode 100644 agentplatform/preview/rag/rag_retrieval.py create mode 100644 agentplatform/preview/rag/rag_store.py create mode 100644 agentplatform/preview/rag/utils/_gapic_utils.py create mode 100644 agentplatform/preview/rag/utils/resources.py create mode 100644 agentplatform/rag/__init__.py create mode 100644 agentplatform/rag/rag_data.py create mode 100644 agentplatform/rag/rag_inline_citations.py create mode 100644 
agentplatform/rag/rag_retrieval.py create mode 100644 agentplatform/rag/rag_store.py create mode 100644 agentplatform/rag/utils/_gapic_utils.py create mode 100644 agentplatform/rag/utils/resources.py diff --git a/agentplatform/preview/__init__.py b/agentplatform/preview/__init__.py index bb17a41dfd..26dad00a3b 100644 --- a/agentplatform/preview/__init__.py +++ b/agentplatform/preview/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""The agentplatform.preview module.""" +"""The agentplatform.preview module.""" \ No newline at end of file diff --git a/agentplatform/preview/rag/__init__.py b/agentplatform/preview/rag/__init__.py new file mode 100644 index 0000000000..282fc65523 --- /dev/null +++ b/agentplatform/preview/rag/__init__.py @@ -0,0 +1,169 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from agentplatform.preview.rag.rag_data import ( + batch_create_data_schemas, + batch_create_metadata, + batch_delete_data_schemas, + batch_delete_metadata, + create_corpus, + delete_corpus, + delete_file, + get_corpus, + get_file, + get_rag_engine_config, + import_files, + import_files_async, + list_corpora, + list_data_schemas, + list_files, + list_metadata, + update_corpus, + update_metadata, + update_rag_engine_config, + upload_file, +) +from agentplatform.preview.rag.rag_retrieval import ( + ask_contexts, + async_retrieve_contexts, + retrieval_query, +) +from agentplatform.preview.rag.rag_store import ( + Retrieval, + VertexRagStore, +) +from agentplatform.preview.rag.utils.resources import ( + ANN, + Basic, + ChunkingConfig, + DocumentCorpus, + EmbeddingModelConfig, + Filter, + HybridSearch, + JiraQuery, + JiraSource, + KNN, + LayoutParserConfig, + LlmParserConfig, + LlmRanker, + MemoryCorpus, + MetadataValue, + Pinecone, + RagCorpus, + RagCorpusTypeConfig, + RagDataSchema, + RagEmbeddingModelConfig, + RagEngineConfig, + RagFile, + RagManagedDb, + RagManagedDbConfig, + RagManagedVertexVectorSearch, + RagMetadata, + RagMetadataSchemaDetails, + RagResource, + RagRetrievalConfig, + RagVectorDbConfig, + RankService, + Ranking, + Scaled, + Serverless, + SharePointSource, + SharePointSources, + SlackChannel, + SlackChannelsSource, + Spanner, + TransformationConfig, + Unprovisioned, + UserSpecifiedMetadata, + VertexAiSearchConfig, + VertexFeatureStore, + VertexPredictionEndpoint, + VertexVectorSearch, + Weaviate, +) + +__all__ = ( + "ANN", + "Basic", + "ChunkingConfig", + "DocumentCorpus", + "EmbeddingModelConfig", + "Filter", + "HybridSearch", + "JiraQuery", + "JiraSource", + "KNN", + "LayoutParserConfig", + "LlmParserConfig", + "LlmRanker", + "MemoryCorpus", + "MetadataValue", + "Pinecone", + "RagEngineConfig", + "RagCorpus", + "RagCorpusTypeConfig", + "RagDataSchema", + "RagEmbeddingModelConfig", + "RagFile", + "RagManagedDb", + "RagManagedDbConfig", + 
"RagManagedVertexVectorSearch", + "RagMetadata", + "RagMetadataSchemaDetails", + "RagResource", + "RagRetrievalConfig", + "RagVectorDbConfig", + "Ranking", + "RankService", + "Retrieval", + "Scaled", + "Serverless", + "SharePointSource", + "SharePointSources", + "SlackChannel", + "SlackChannelsSource", + "Spanner", + "TransformationConfig", + "Unprovisioned", + "UserSpecifiedMetadata", + "VertexAiSearchConfig", + "VertexFeatureStore", + "VertexPredictionEndpoint", + "VertexRagStore", + "VertexVectorSearch", + "Weaviate", + "ask_contexts", + "batch_create_data_schemas", + "batch_create_metadata", + "batch_delete_data_schemas", + "batch_delete_metadata", + "create_corpus", + "delete_corpus", + "delete_file", + "get_corpus", + "get_file", + "import_files", + "import_files_async", + "list_corpora", + "list_data_schemas", + "list_files", + "list_metadata", + "retrieval_query", + "async_retrieve_contexts", + "upload_file", + "update_corpus", + "update_metadata", + "update_rag_engine_config", + "get_rag_engine_config", +) diff --git a/agentplatform/preview/rag/rag_data.py b/agentplatform/preview/rag/rag_data.py new file mode 100644 index 0000000000..bb9a2248f4 --- /dev/null +++ b/agentplatform/preview/rag/rag_data.py @@ -0,0 +1,1235 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +"""RAG data management SDK.""" +# +from typing import Optional, Sequence, Union +from google import auth +from google.api_core import operation_async +from google.auth.transport import requests as google_auth_requests +from google.cloud import aiplatform +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import utils +from google.cloud.aiplatform_v1beta1 import ( + BatchCreateRagDataSchemasRequest, + BatchCreateRagMetadataRequest, + BatchDeleteRagDataSchemasRequest, + BatchDeleteRagMetadataRequest, + CreateRagCorpusRequest, + CreateRagDataSchemaRequest, + CreateRagMetadataRequest, + DeleteRagCorpusRequest, + DeleteRagFileRequest, + GetRagCorpusRequest, + GetRagEngineConfigRequest, + GetRagFileRequest, + ImportRagFilesResponse, + ListRagCorporaRequest, + ListRagDataSchemasRequest, + ListRagFilesRequest, + ListRagMetadataRequest, + RagCorpus as GapicRagCorpus, + UpdateRagCorpusRequest, + UpdateRagEngineConfigRequest, + UpdateRagMetadataRequest, +) +from google.cloud.aiplatform_v1beta1.services.vertex_rag_data_service.pagers import ( + ListRagCorporaPager, + ListRagDataSchemasPager, + ListRagFilesPager, + ListRagMetadataPager, +) +from google.cloud.aiplatform_v1beta1.types import EncryptionSpec +from agentplatform.preview.rag.utils import ( + _gapic_utils, +) +from agentplatform.preview.rag.utils.resources import ( + EmbeddingModelConfig, + JiraSource, + LayoutParserConfig, + LlmParserConfig, + Pinecone, + RagCorpus, + RagCorpusTypeConfig, + RagDataSchema, + RagEngineConfig, + RagFile, + RagManagedDb, + RagManagedVertexVectorSearch, + RagMetadata, + RagVectorDbConfig, + SharePointSources, + SlackChannelsSource, + TransformationConfig, + VertexAiSearchConfig, + VertexFeatureStore, + VertexVectorSearch, + Weaviate, +) + + +def create_corpus( + display_name: Optional[str] = None, + description: Optional[str] = None, + corpus_type_config: Optional[RagCorpusTypeConfig] = None, + vertex_ai_search_config: Optional[VertexAiSearchConfig] = None, 
+ backend_config: Optional[RagVectorDbConfig] = None, + encryption_spec: Optional[EncryptionSpec] = None, + timeout: int = 600, +) -> RagCorpus: + """Creates a new RagCorpus resource. + + Example usage: + ``` + import agentplatform + from agentplatform.preview import rag + + agentplatform.init(project="my-project") + + rag_corpus = rag.create_corpus( + display_name="my-corpus-1", + ) + ``` + + Args: + display_name: If not provided, SDK will create one. The display name of + the RagCorpus. The name can be up to 128 characters long and can consist + of any UTF-8 characters. + description: The description of the RagCorpus. + corpus_type_config: The corpus type config of the RagCorpus. + vertex_ai_search_config: The Vertex AI Search config of the RagCorpus. + backend_config: The backend config of the RagCorpus. It can specify a + Vector DB and/or the embedding model config. + encryption_spec: The encryption spec of the RagCorpus. + timeout: Default is 600 seconds. + + Returns: + RagCorpus. + Raises: + RuntimeError: Failed in RagCorpus creation due to exception. + RuntimeError: Failed in RagCorpus creation due to operation error. + """ + if not display_name: + display_name = "vertex-" + utils.timestamped_unique_name() + parent = initializer.global_config.common_location_path(project=None, location=None) + + rag_corpus = GapicRagCorpus(display_name=display_name, description=description) + + if corpus_type_config: + _gapic_utils.set_corpus_type_config( + corpus_type_config=corpus_type_config, + rag_corpus=rag_corpus, + ) + + if vertex_ai_search_config and backend_config: + raise ValueError( + "Only one of vertex_ai_search_config or backend_config can be set." 
+ ) + + if backend_config: + _gapic_utils.set_backend_config( + backend_config=backend_config, + rag_corpus=rag_corpus, + ) + + if vertex_ai_search_config: + _gapic_utils.set_vertex_ai_search_config( + vertex_ai_search_config=vertex_ai_search_config, + rag_corpus=rag_corpus, + ) + else: + _gapic_utils.set_vector_db( + vector_db=None, + rag_corpus=rag_corpus, + ) + + if encryption_spec: + _gapic_utils.set_encryption_spec( + encryption_spec=encryption_spec, + rag_corpus=rag_corpus, + ) + + request = CreateRagCorpusRequest( + parent=parent, + rag_corpus=rag_corpus, + ) + client = _gapic_utils.create_rag_data_service_client() + + try: + response = client.create_rag_corpus(request=request) + except Exception as e: + raise RuntimeError("Failed in RagCorpus creation due to: ", e) from e + return _gapic_utils.convert_gapic_to_rag_corpus(response.result(timeout=timeout)) + + +def update_corpus( + corpus_name: str, + display_name: Optional[str] = None, + description: Optional[str] = None, + vector_db: Optional[ + Union[ + Weaviate, + VertexFeatureStore, + VertexVectorSearch, + Pinecone, + RagManagedDb, + ] + ] = None, + vertex_ai_search_config: Optional[VertexAiSearchConfig] = None, + backend_config: Optional[RagVectorDbConfig] = None, + timeout: int = 600, +) -> RagCorpus: + """Updates a RagCorpus resource. + + Example usage: + ``` + import agentplatform + from agentplatform.preview import rag + + agentplatform.init(project="my-project") + + rag_corpus = rag.update_corpus( + corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", + display_name="my-corpus-1", + ) + ``` + + Args: + corpus_name: The name of the RagCorpus resource to update. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` or + ``{rag_corpus}``. + display_name: If not provided, the display name will not be updated. The + display name of the RagCorpus. The name can be up to 128 characters long + and can consist of any UTF-8 characters. 
+ description: The description of the RagCorpus. If not provided, the + description will not be updated. + vector_db: The vector db config of the RagCorpus. If not provided, the + vector db will not be updated. + vertex_ai_search_config: The Vertex AI Search config of the RagCorpus. If + not provided, the Vertex AI Search config will not be updated. + Note: embedding_model_config or vector_db cannot be set if + vertex_ai_search_config is specified. + backend_config: The backend config of the RagCorpus. Specifies a Vector DB + and/or the embedding model config. + timeout: Default is 600 seconds. + + Returns: + RagCorpus. + Raises: + RuntimeError: Failed in RagCorpus update due to exception. + RuntimeError: Failed in RagCorpus update due to operation error. + """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + if display_name and description: + rag_corpus = GapicRagCorpus( + name=corpus_name, display_name=display_name, description=description + ) + elif display_name: + rag_corpus = GapicRagCorpus(name=corpus_name, display_name=display_name) + elif description: + rag_corpus = GapicRagCorpus(name=corpus_name, description=description) + else: + rag_corpus = GapicRagCorpus(name=corpus_name) + + if vertex_ai_search_config and vector_db: + raise ValueError("Only one of vertex_ai_search_config or vector_db can be set.") + + if backend_config: + _gapic_utils.set_backend_config( + backend_config=backend_config, + rag_corpus=rag_corpus, + ) + + if vertex_ai_search_config: + _gapic_utils.set_vertex_ai_search_config( + vertex_ai_search_config=vertex_ai_search_config, + rag_corpus=rag_corpus, + ) + else: + _gapic_utils.set_vector_db( + vector_db=vector_db, + rag_corpus=rag_corpus, + ) + + request = UpdateRagCorpusRequest( + rag_corpus=rag_corpus, + ) + client = _gapic_utils.create_rag_data_service_client() + + try: + response = client.update_rag_corpus(request=request) + except Exception as e: + raise RuntimeError("Failed in RagCorpus update due to: ", e) from e + 
return _gapic_utils.convert_gapic_to_rag_corpus_no_embedding_model_config( + response.result(timeout=timeout) + ) + + +def get_corpus(name: str) -> RagCorpus: + """ + Get an existing RagCorpus. + + Args: + name: An existing RagCorpus resource name. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. + Returns: + RagCorpus. + """ + corpus_name = _gapic_utils.get_corpus_name(name) + request = GetRagCorpusRequest(name=corpus_name) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.get_rag_corpus(request=request) + except Exception as e: + raise RuntimeError("Failed in getting the RagCorpus due to: ", e) from e + return _gapic_utils.convert_gapic_to_rag_corpus(response) + + +def list_corpora( + page_size: Optional[int] = None, page_token: Optional[str] = None +) -> ListRagCorporaPager: + """ + List all RagCorpora in the same project and location. + + Example usage: + ``` + import agentplatform + from agentplatform.preview import rag + + agentplatform.init(project="my-project") + + # List all corpora. + rag_corpora = list(rag.list_corpora()) + + # Alternatively, return a ListRagCorporaPager. + pager_1 = rag.list_corpora(page_size=10) + # Then get the next page, use the generated next_page_token from the last pager. + pager_2 = rag.list_corpora(page_size=10, page_token=pager_1.next_page_token) + + ``` + Args: + page_size: The standard list page size. Leaving out the page_size + causes all of the results to be returned. + page_token: The standard list page token. + + Returns: + ListRagCorporaPager. 
+ """ + parent = initializer.global_config.common_location_path(project=None, location=None) + request = ListRagCorporaRequest( + parent=parent, + page_size=page_size, + page_token=page_token, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + pager = client.list_rag_corpora(request=request) + except Exception as e: + raise RuntimeError("Failed in listing the RagCorpora due to: ", e) from e + + return pager + + +def delete_corpus(name: str) -> None: + """ + Delete an existing RagCorpus. + + Args: + name: An existing RagCorpus resource name. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. + """ + corpus_name = _gapic_utils.get_corpus_name(name) + request = DeleteRagCorpusRequest(name=corpus_name) + + client = _gapic_utils.create_rag_data_service_client() + try: + client.delete_rag_corpus(request=request) + except Exception as e: + raise RuntimeError("Failed in RagCorpus deletion due to: ", e) from e + return None + + +def upload_file( + corpus_name: str, + path: Union[str, Sequence[str]], + display_name: Optional[str] = None, + description: Optional[str] = None, + transformation_config: Optional[TransformationConfig] = None, +) -> RagFile: + """ + Synchronous file upload to an existing RagCorpus. + + Example usage: + + ``` + import agentplatform + from agentplatform.preview import rag + + agentplatform.init(project="my-project") + + # Optional. + transformation_config = TransformationConfig( + chunking_config=ChunkingConfig( + chunk_size=1024, + chunk_overlap=200, + ), + ) + + rag_file = rag.upload_file( + corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", + display_name="my_file.txt", + path="usr/home/my_file.txt", + transformation_config=transformation_config, + ) + ``` + + Args: + corpus_name: The name of the RagCorpus resource into which to upload the file. + Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. 
+ path: A local file path. For example, + "usr/home/my_file.txt". + display_name: The display name of the data file. + description: The description of the RagFile. + transformation_config: The config for transforming the RagFile, such as chunking. + Returns: + RagFile. + Raises: + RuntimeError: Failed in RagFile upload. + ValueError: RagCorpus is not found. + RuntimeError: Failed in indexing the RagFile. + """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + location = initializer.global_config.location + # GAPIC doesn't expose a path (scotty). Use requests API instead + if display_name is None: + display_name = "vertex-" + utils.timestamped_unique_name() + headers = {"X-Goog-Upload-Protocol": "multipart"} + if not initializer.global_config.api_endpoint: + request_endpoint = "{}-{}".format( + location, aiplatform.constants.base.API_BASE_PATH + ) + else: + request_endpoint = initializer.global_config.api_endpoint + upload_request_uri = "https://{}/upload/v1beta1/{}/ragFiles:upload".format( + request_endpoint, + corpus_name, + ) + js_rag_file = {"rag_file": {"display_name": display_name}} + + if description: + js_rag_file["rag_file"]["description"] = description + + if transformation_config and transformation_config.chunking_config: + chunk_size = transformation_config.chunking_config.chunk_size + chunk_overlap = transformation_config.chunking_config.chunk_overlap + js_rag_file["upload_rag_file_config"] = { + "rag_file_transformation_config": { + "rag_file_chunking_config": { + "fixed_length_chunking": { + "chunk_size": chunk_size, + "chunk_overlap": chunk_overlap, + } + } + } + } + files = { + "metadata": (None, str(js_rag_file)), + "file": open(path, "rb"), + } + credentials, _ = auth.default( + scopes=["https://www.googleapis.com/auth/cloud-platform"] + ) + authorized_session = google_auth_requests.AuthorizedSession(credentials=credentials) + try: + response = authorized_session.post( + url=upload_request_uri, + files=files, + headers=headers, + ) + 
except Exception as e: + raise RuntimeError("Failed in uploading the RagFile due to: ", e) from e + + if response.status_code == 404: + raise ValueError( + "RagCorpus '%s' is not found: %s", corpus_name, upload_request_uri + ) + if response.json().get("error"): + raise RuntimeError( + "Failed in indexing the RagFile due to: ", response.json().get("error") + ) + return _gapic_utils.convert_json_to_rag_file(response.json()) + + +def import_files( + corpus_name: str, + paths: Optional[Sequence[str]] = None, + source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None, + transformation_config: Optional[TransformationConfig] = None, + timeout: int = 600, + max_embedding_requests_per_min: int = 1000, + global_max_embedding_requests_per_min: Optional[int] = None, + layout_parser: Optional[LayoutParserConfig] = None, + llm_parser: Optional[LlmParserConfig] = None, + rebuild_ann_index: Optional[bool] = False, +) -> ImportRagFilesResponse: + """ + Import files to an existing RagCorpus, wait until completion. + + Example usage: + + ``` + import agentplatform + from agentplatform.preview import rag + from google.protobuf import timestamp_pb2 + + agentplatform.init(project="my-project") + # Google Drive example + paths = [ + "https://drive.google.com/file/d/123", + "https://drive.google.com/drive/folders/456" + ] + # Google Cloud Storage example + paths = ["gs://my_bucket/my_files_dir", ...] 
+ + transformation_config = TransformationConfig( + chunking_config=ChunkingConfig( + chunk_size=1024, + chunk_overlap=200, + ), + ) + + response = rag.import_files( + corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", + paths=paths, + transformation_config=transformation_config, + ) + + # Slack example + start_time = timestamp_pb2.Timestamp() + start_time.FromJsonString('2020-12-31T21:33:44Z') + end_time = timestamp_pb2.Timestamp() + end_time.GetCurrentTime() + source = rag.SlackChannelsSource( + channels = [ + SlackChannel("channel1", "api_key1"), + SlackChannel("channel2", "api_key2", start_time, end_time) + ], + ) + # Jira Example + jira_query = rag.JiraQuery( + email="xxx@yyy.com", + jira_projects=["project1", "project2"], + custom_queries=["query1", "query2"], + api_key="api_key", + server_uri="server.atlassian.net" + ) + source = rag.JiraSource( + queries=[jira_query], + ) + + response = rag.import_files( + corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", + source=source, + transformation_config=transformation_config, + ) + + # SharePoint Example. + sharepoint_query = rag.SharePointSource( + sharepoint_folder_path="https://my-sharepoint-site.com/my-folder", + sharepoint_site_name="my-sharepoint-site.com", + client_id="my-client-id", + client_secret="my-client-secret", + tenant_id="my-tenant-id", + drive_id="my-drive-id", + ) + source = rag.SharePointSources( + share_point_sources=[sharepoint_query], + ) + + # Return the number of imported RagFiles after completion. + print(response.imported_rag_files_count) + + ``` + Args: + corpus_name: The name of the RagCorpus resource into which to import files. + Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. + paths: A list of uris. Eligible uris will be Google Cloud Storage + directory ("gs://my-bucket/my_dir") or a Google Drive url for file + (https://drive.google.com/file/... 
or folder + "https://drive.google.com/corp/drive/folders/..."). + source: The source of the Slack or Jira import. + Must be either a SlackChannelsSource or JiraSource. + transformation_config: The config for transforming the imported + RagFiles. + max_embedding_requests_per_min: + Optional. The max number of queries per + minute that this job is allowed to make to the + embedding model specified on the corpus. This + value is specific to this job and not shared + across other import jobs. Consult the Quotas + page on the project to set an appropriate value + here. If unspecified, a default value of 1,000 + QPM would be used. + global_max_embedding_requests_per_min: + Optional. The max number of queries per minute that the indexing + pipeline job is allowed to make to the embedding model specified in + the project. Please follow the quota usage guideline of the embedding + model you use to set the value properly. If this value is not specified, + max_embedding_requests_per_min will be used by indexing pipeline job + as the global limit and this means parallel import jobs are not allowed. + timeout: Default is 600 seconds. + layout_parser: Configuration for the Document AI Layout Parser Processor + to use for document parsing. Optional. + If not None, the other parser configs must be None. + llm_parser: Configuration for the LLM Parser to use for document parsing. + Optional. + If not None, the other parser configs must be None. + rebuild_ann_index: Rebuilds the ANN index to optimize for recall on the + imported data. Only applicable for RagCorpora running on + RagManagedDb with ``retrieval_strategy`` set to ``ANN``. The + rebuild will be performed using the existing ANN config set + on the RagCorpus. To change the ANN config, please use the + UpdateRagCorpus API. Optional.Default is false, i.e., index is not + rebuilt. + Returns: + ImportRagFilesResponse. 
+ """ + if source is not None and paths is not None: + raise ValueError("Only one of source or paths must be passed in at a time") + if source is None and paths is None: + raise ValueError("One of source or paths must be passed in") + if layout_parser is not None and llm_parser is not None: + raise ValueError( + "Only one of layout_parser or llm_parser may be passed in at a time" + ) + + rebuild_ann_index_request = ( + rebuild_ann_index if rebuild_ann_index is not None else False + ) + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + request = _gapic_utils.prepare_import_files_request( + corpus_name=corpus_name, + paths=paths, + source=source, + chunk_size=1024, + chunk_overlap=200, + transformation_config=transformation_config, + max_embedding_requests_per_min=max_embedding_requests_per_min, + global_max_embedding_requests_per_min=global_max_embedding_requests_per_min, + use_advanced_pdf_parsing=False, + layout_parser=layout_parser, + llm_parser=llm_parser, + rebuild_ann_index=rebuild_ann_index_request, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.import_rag_files(request=request) + except Exception as e: + raise RuntimeError("Failed in importing the RagFiles due to: ", e) from e + + return response.result(timeout=timeout) + + +async def import_files_async( + corpus_name: str, + paths: Optional[Sequence[str]] = None, + source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None, + transformation_config: Optional[TransformationConfig] = None, + max_embedding_requests_per_min: int = 1000, + global_max_embedding_requests_per_min: Optional[int] = None, + layout_parser: Optional[LayoutParserConfig] = None, + llm_parser: Optional[LlmParserConfig] = None, + rebuild_ann_index: Optional[bool] = False, +) -> operation_async.AsyncOperation: + """ + Import files to an existing RagCorpus asynchronously. 
+ + Example usage: + + ``` + import agentplatform + from agentplatform.preview import rag + from google.protobuf import timestamp_pb2 + + agentplatform.init(project="my-project") + + # Google Drive example + paths = [ + "https://drive.google.com/file/d/123", + "https://drive.google.com/drive/folders/456" + ] + # Google Cloud Storage example + paths = ["gs://my_bucket/my_files_dir", ...] + + transformation_config = TransformationConfig( + chunking_config=ChunkingConfig( + chunk_size=1024, + chunk_overlap=200, + ), + ) + + response = await rag.import_files_async( + corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", + paths=paths, + transformation_config=transformation_config, + ) + + # Slack example + start_time = timestamp_pb2.Timestamp() + start_time.FromJsonString('2020-12-31T21:33:44Z') + end_time = timestamp_pb2.Timestamp() + end_time.GetCurrentTime() + source = rag.SlackChannelsSource( + channels = [ + SlackChannel("channel1", "api_key1"), + SlackChannel("channel2", "api_key2", start_time, end_time) + ], + ) + # Jira Example + jira_query = rag.JiraQuery( + email="xxx@yyy.com", + jira_projects=["project1", "project2"], + custom_queries=["query1", "query2"], + api_key="api_key", + server_uri="server.atlassian.net" + ) + source = rag.JiraSource( + queries=[jira_query], + ) + + response = await rag.import_files_async( + corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", + source=source, + transformation_config=transformation_config, + ) + + # SharePoint Example. + sharepoint_query = rag.SharePointSource( + sharepoint_folder_path="https://my-sharepoint-site.com/my-folder", + sharepoint_site_name="my-sharepoint-site.com", + client_id="my-client-id", + client_secret="my-client-secret", + tenant_id="my-tenant-id", + drive_id="my-drive-id", + ) + source = rag.SharePointSources( + share_point_sources=[sharepoint_query], + ) + + # Get the result. 
+ await response.result() + + ``` + Args: + corpus_name: The name of the RagCorpus resource into which to import files. + Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. + paths: A list of uris. Eligible uris will be Google Cloud Storage + directory ("gs://my-bucket/my_dir") or a Google Drive url for file + (https://drive.google.com/file/... or folder + "https://drive.google.com/corp/drive/folders/..."). + source: The source of the Slack or Jira import. + Must be either a SlackChannelsSource or JiraSource. + transformation_config: The config for transforming the imported + RagFiles. + max_embedding_requests_per_min: + Optional. The max number of queries per + minute that this job is allowed to make to the + embedding model specified on the corpus. This + value is specific to this job and not shared + across other import jobs. Consult the Quotas + page on the project to set an appropriate value + here. If unspecified, a default value of 1,000 + QPM would be used. + global_max_embedding_requests_per_min: + Optional. The max number of queries per minute that the indexing + pipeline job is allowed to make to the embedding model specified in + the project. Please follow the quota usage guideline of the embedding + model you use to set the value properly. If this value is not specified, + max_embedding_requests_per_min will be used by indexing pipeline job + as the global limit and this means parallel import jobs are not allowed. + layout_parser: Configuration for the Document AI Layout Parser Processor + to use for document parsing. Optional. + If not None, the other parser configs must be None. + llm_parser: Configuration for the LLM Parser to use for document parsing. + Optional. + If not None, the other parser configs must be None. + rebuild_ann_index: Rebuilds the ANN index to optimize for recall on the + imported data. 
Only applicable for RagCorpora running on + RagManagedDb with ``retrieval_strategy`` set to ``ANN``. The + rebuild will be performed using the existing ANN config set + on the RagCorpus. To change the ANN config, please use the + UpdateRagCorpus API. Optional.Default is false, i.e., index is not + rebuilt. + Returns: + operation_async.AsyncOperation. + """ + if source is not None and paths is not None: + raise ValueError("Only one of source or paths must be passed in at a time") + if source is None and paths is None: + raise ValueError("One of source or paths must be passed in") + if layout_parser is not None and llm_parser is not None: + raise ValueError( + "Only one of layout_parser or llm_parser may be passed in at a time" + ) + rebuild_ann_index_request = ( + rebuild_ann_index if rebuild_ann_index is not None else False + ) + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + request = _gapic_utils.prepare_import_files_request( + corpus_name=corpus_name, + paths=paths, + source=source, + chunk_size=1024, + chunk_overlap=200, + transformation_config=transformation_config, + max_embedding_requests_per_min=max_embedding_requests_per_min, + global_max_embedding_requests_per_min=global_max_embedding_requests_per_min, + use_advanced_pdf_parsing=False, + layout_parser=layout_parser, + llm_parser=llm_parser, + rebuild_ann_index=rebuild_ann_index_request, + ) + async_client = _gapic_utils.create_rag_data_service_async_client() + try: + response = await async_client.import_rag_files(request=request) + except Exception as e: + raise RuntimeError("Failed in importing the RagFiles due to: ", e) from e + return response + + +def get_file(name: str, corpus_name: Optional[str] = None) -> RagFile: + """ + Get an existing RagFile. + + Args: + name: Either a full RagFile resource name must be provided, or a RagCorpus + name and a RagFile name must be provided. 
Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`` + or ``{rag_file}``. + corpus_name: If `name` is not a full resource name, an existing RagCorpus + name must be provided. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. + Returns: + RagFile. + """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + name = _gapic_utils.get_file_name(name, corpus_name) + request = GetRagFileRequest(name=name) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.get_rag_file(request=request) + except Exception as e: + raise RuntimeError("Failed in getting the RagFile due to: ", e) from e + return _gapic_utils.convert_gapic_to_rag_file(response) + + +def list_files( + corpus_name: str, page_size: Optional[int] = None, page_token: Optional[str] = None +) -> ListRagFilesPager: + """ + List all RagFiles in an existing RagCorpus. + + Example usage: + ``` + import agentplatform + + agentplatform.init(project="my-project") + # List all corpora. + rag_corpora = list(rag.list_corpora()) + + # List all files of the first corpus. + rag_files = list(rag.list_files(corpus_name=rag_corpora[0].name)) + + # Alternatively, return a ListRagFilesPager. + pager_1 = rag.list_files( + corpus_name=rag_corpora[0].name, + page_size=10 + ) + # Then get the next page, use the generated next_page_token from the last pager. + pager_2 = rag.list_files( + corpus_name=rag_corpora[0].name, + page_size=10, + page_token=pager_1.next_page_token + ) + + ``` + + Args: + corpus_name: An existing RagCorpus name. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. + page_size: The standard list page size. Leaving out the page_size + causes all of the results to be returned. + page_token: The standard list page token. + Returns: + ListRagFilesPager. 
+ """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + request = ListRagFilesRequest( + parent=corpus_name, + page_size=page_size, + page_token=page_token, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + pager = client.list_rag_files(request=request) + except Exception as e: + raise RuntimeError("Failed in listing the RagFiles due to: ", e) from e + + return pager + + +def delete_file(name: str, corpus_name: Optional[str] = None) -> None: + """ + Delete RagFile from an existing RagCorpus. + + Args: + name: Either a full RagFile resource name must be provided, or a RagCorpus + name and a RagFile name must be provided. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`` + or ``{rag_file}``. + corpus_name: If `name` is not a full resource name, an existing RagCorpus + name must be provided. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. + """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + name = _gapic_utils.get_file_name(name, corpus_name) + request = DeleteRagFileRequest(name=name) + + client = _gapic_utils.create_rag_data_service_client() + try: + client.delete_rag_file(request=request) + except Exception as e: + raise RuntimeError("Failed in RagFile deletion due to: ", e) from e + return None + + +def batch_create_data_schemas( + corpus_name: str, + requests: Sequence[RagDataSchema], + timeout: int = 600, +) -> Sequence[RagDataSchema]: + """Batch creates RagDataSchema resources. + + Args: + corpus_name: The name of the RagCorpus resource to create the + RagDataSchemas in. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` or + ``{rag_corpus}``. + requests: The RagDataSchemas to create. + timeout: Default is 600 seconds. + + Returns: + Sequence of RagDataSchema. 
+ """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + gapic_requests = [] + for request in requests: + gapic_requests.append( + CreateRagDataSchemaRequest( + parent=corpus_name, + rag_data_schema=_gapic_utils.convert_rag_data_schema_to_gapic(request), + ) + ) + request = BatchCreateRagDataSchemasRequest( + parent=corpus_name, + requests=gapic_requests, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.batch_create_rag_data_schemas(request=request) + except Exception as e: + raise RuntimeError("Failed in RagDataSchema batch creation due to: ", e) from e + result = response.result(timeout=timeout) + if result.rag_data_schemas: + return [ + _gapic_utils.convert_gapic_to_rag_data_schema(schema) + for schema in result.rag_data_schemas + ] + return [] + + +def batch_delete_data_schemas( + corpus_name: str, + names: Sequence[str], + timeout: int = 600, +) -> None: + """Batch deletes RagDataSchema resources. + + Args: + corpus_name: The name of the RagCorpus resource from which to delete the + RagDataSchemas. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` or + ``{rag_corpus}``. + names: The RagDataSchema resource names to delete. + timeout: Default is 600 seconds. + """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + full_names = [ + _gapic_utils.get_data_schema_name(name, corpus_name) for name in names + ] + request = BatchDeleteRagDataSchemasRequest( + parent=corpus_name, + names=full_names, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.batch_delete_rag_data_schemas(request=request) + response.result(timeout=timeout) + except Exception as e: + raise RuntimeError("Failed in RagDataSchema batch deletion due to: ", e) from e + return None + + +def list_data_schemas( + corpus_name: str, + page_size: Optional[int] = None, + page_token: Optional[str] = None, +) -> ListRagDataSchemasPager: + """Lists RagDataSchemas in an existing RagCorpus. 
+ + Args: + corpus_name: An existing RagCorpus name. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` or + ``{rag_corpus}``. + page_size: The standard list page size. + page_token: The standard list page token. + + Returns: + ListRagDataSchemasPager. + """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + request = ListRagDataSchemasRequest( + parent=corpus_name, + page_size=page_size, + page_token=page_token, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + pager = client.list_rag_data_schemas(request=request) + except Exception as e: + raise RuntimeError("Failed in listing the RagDataSchemas due to: ", e) from e + return pager + + +def batch_create_metadata( + corpus_name: str, + file_name: str, + requests: Sequence[RagMetadata], + timeout: int = 600, +) -> Sequence[RagMetadata]: + """Batch creates RagMetadata resources. + + Args: + corpus_name: The name of the RagCorpus resource. + file_name: The name of the RagFile resource. + requests: The RagMetadata to create. + timeout: Default is 600 seconds. + + Returns: + Sequence of RagMetadata. 
+ """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + file_name = _gapic_utils.get_file_name(file_name, corpus_name) + gapic_requests = [] + for request in requests: + gapic_requests.append( + CreateRagMetadataRequest( + parent=file_name, + rag_metadata=_gapic_utils.convert_rag_metadata_to_gapic(request), + ) + ) + request = BatchCreateRagMetadataRequest( + parent=file_name, + requests=gapic_requests, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.batch_create_rag_metadata(request=request) + except Exception as e: + raise RuntimeError("Failed in RagMetadata batch creation due to: ", e) from e + result = response.result(timeout=timeout) + if result.rag_metadata: + return [ + _gapic_utils.convert_gapic_to_rag_metadata(metadata) + for metadata in result.rag_metadata + ] + return [] + + +def batch_delete_metadata( + corpus_name: str, + file_name: str, + names: Sequence[str], + timeout: int = 600, +) -> None: + """Batch deletes RagMetadata resources. + + Args: + corpus_name: The name of the RagCorpus resource. + file_name: The name of the RagFile resource. + names: The RagMetadata resource names to delete. + timeout: Default is 600 seconds. 
+ """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + file_name = _gapic_utils.get_file_name(file_name, corpus_name) + full_names = [ + _gapic_utils.get_metadata_name(name, corpus_name, file_name) for name in names + ] + request = BatchDeleteRagMetadataRequest( + parent=file_name, + names=full_names, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.batch_delete_rag_metadata(request=request) + response.result(timeout=timeout) + except Exception as e: + raise RuntimeError("Failed in RagMetadata batch deletion due to: ", e) from e + return None + + +def list_metadata( + corpus_name: str, + file_name: str, + page_size: Optional[int] = None, + page_token: Optional[str] = None, +) -> ListRagMetadataPager: + """Lists RagMetadata in an existing RagFile. + + Args: + corpus_name: An existing RagCorpus name. + file_name: An existing RagFile name. + page_size: The standard list page size. + page_token: The standard list page token. + + Returns: + ListRagMetadataPager. + """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + file_name = _gapic_utils.get_file_name(file_name, corpus_name) + request = ListRagMetadataRequest( + parent=file_name, + page_size=page_size, + page_token=page_token, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + pager = client.list_rag_metadata(request=request) + except Exception as e: + raise RuntimeError("Failed in listing the RagMetadata due to: ", e) from e + return pager + + +def update_metadata( + rag_metadata: RagMetadata, +) -> RagMetadata: + """Updates a RagMetadata resource. + + Args: + rag_metadata: The RagMetadata which replaces the resource on the server. + + Returns: + RagMetadata. 
+ """ + request = UpdateRagMetadataRequest( + rag_metadata=_gapic_utils.convert_rag_metadata_to_gapic(rag_metadata), + ) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.update_rag_metadata(request=request) + except Exception as e: + raise RuntimeError("Failed in RagMetadata update due to: ", e) from e + return _gapic_utils.convert_gapic_to_rag_metadata(response) + + +def update_rag_engine_config( + rag_engine_config: RagEngineConfig, + timeout: int = 600, +) -> RagEngineConfig: + """Update RagEngineConfig. + + Example usage: + ``` + import agentplatform + from agentplatform.preview import rag + agentplatform.init(project="my-project") + rag_engine_config = rag.RagEngineConfig( + rag_managed_db_config=rag.RagManagedDbConfig( + rag_managed_db=rag.RagManagedDb( + db_basic_tier=rag.Basic(), + ), + ) + ), + ) + rag.update_rag_engine_config(rag_engine_config=rag_engine_config) + ``` + + Args: + rag_engine_config: The RagEngineConfig to update. + timeout: Default is 600 seconds. + + Raises: + RuntimeError: Failed in RagEngineConfig update due to exception. + """ + gapic_rag_engine_config = _gapic_utils.convert_rag_engine_config_to_gapic( + rag_engine_config + ) + request = UpdateRagEngineConfigRequest( + rag_engine_config=gapic_rag_engine_config, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.update_rag_engine_config(request=request) + except Exception as e: + raise RuntimeError("Failed in RagEngineConfig update due to: ", e) from e + return _gapic_utils.convert_gapic_to_rag_engine_config( + response.result(timeout=timeout) + ) + + +def get_rag_engine_config(name: str) -> RagEngineConfig: + """Get an existing RagEngineConfig. 
+ + Example usage: + ``` + import agentplatform + from agentplatform.preview import rag + agentplatform.init(project="my-project") + rag_engine_config = rag.get_rag_engine_config( + name="projects/my-project/locations/us-central1/ragEngineConfig" + ) + ``` + Args: + name: The RagEngineConfig resource name pattern of the singleton resource. + + Returns: + RagEngineConfig. + Raises: + RuntimeError: Failed in getting the RagEngineConfig. + """ + request = GetRagEngineConfigRequest(name=name) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.get_rag_engine_config(request=request) + except Exception as e: + raise RuntimeError("Failed in getting the RagEngineConfig due to: ", e) from e + return _gapic_utils.convert_gapic_to_rag_engine_config(response) diff --git a/agentplatform/preview/rag/rag_retrieval.py b/agentplatform/preview/rag/rag_retrieval.py new file mode 100644 index 0000000000..6860b0247e --- /dev/null +++ b/agentplatform/preview/rag/rag_retrieval.py @@ -0,0 +1,536 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +"""Retrieval query to get relevant contexts.""" + +import re +from typing import List, Optional +import warnings + +from google.cloud import aiplatform_v1beta1 +from google.cloud.aiplatform import initializer +from agentplatform.preview.rag.utils import _gapic_utils +from agentplatform.preview.rag.utils import resources + +from google.protobuf import any_pb2 + + +def retrieval_query( + text: str, + rag_resources: Optional[List[resources.RagResource]] = None, + rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None, +) -> aiplatform_v1beta1.RetrieveContextsResponse: + """Retrieve top k relevant docs/chunks. + + Example usage: + ``` + import agentplatform + + agentplatform.init(project="my-project") + + # Using RagRetrievalConfig. + config = agentplatform.preview.rag.RagRetrievalConfig( + top_k=2, + filter=agentplatform.preview.rag.Filter( + vector_distance_threshold=0.5 + ), + hybrid_search=agentplatform.preview.rag.rag_retrieval_config.hybrid_search( + alpha=0.5 + ), + ranking=vertex.preview.rag.Ranking( + llm_ranker=agentplatform.preview.rag.LlmRanker( + model_name="gemini-1.5-flash-002" + ) + ) + ) + + results = agentplatform.preview.rag.retrieval_query( + text="Why is the sky blue?", + rag_resources=[agentplatform.preview.rag.RagResource( + rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1", + rag_file_ids=["rag-file-1", "rag-file-2", ...], + )], + rag_retrieval_config=config, + ) + ``` + + Args: + text: The query in text format to get relevant contexts. + rag_resources: A list of RagResource. It can be used to specify corpus + only or ragfiles. Currently only support one corpus or multiple files + from one corpus. In the future we may open up multiple corpora support. + rag_retrieval_config: Optional. The config containing the retrieval + parameters, including top_k, vector_distance_threshold, and alpha. + + Returns: + RetrieveContextsResonse. 
+ """ + parent = initializer.global_config.common_location_path() + + client = _gapic_utils.create_rag_service_client() + + if rag_resources: + if len(rag_resources) > 1: + raise ValueError("Currently only support 1 RagResource.") + name = rag_resources[0].rag_corpus + else: + raise ValueError("rag_resources must be specified.") + + data_client = _gapic_utils.create_rag_data_service_client() + if data_client.parse_rag_corpus_path(name): + rag_corpus_name = name + elif re.match( + "^{}$".format( + _gapic_utils._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access + ), + name, + ): + rag_corpus_name = parent + "/ragCorpora/" + name + else: + raise ValueError( + f"Invalid RagCorpus name: {name}. Proper format should be:" + " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" + ) + + gapic_rag_resource = ( + aiplatform_v1beta1.RetrieveContextsRequest.VertexRagStore.RagResource( + rag_corpus=rag_corpus_name, + rag_file_ids=rag_resources[0].rag_file_ids, + ) + ) + vertex_rag_store = aiplatform_v1beta1.RetrieveContextsRequest.VertexRagStore( + rag_resources=[gapic_rag_resource], + ) + + if not rag_retrieval_config: + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig() + else: + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig() + if rag_retrieval_config.top_k: + api_retrival_config.top_k = rag_retrieval_config.top_k + if ( + rag_retrieval_config.hybrid_search + and rag_retrieval_config.hybrid_search.alpha + ): + api_retrival_config.hybrid_search.alpha = ( + rag_retrieval_config.hybrid_search.alpha + ) + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_distance_threshold + and rag_retrieval_config.filter.vector_similarity_threshold + ): + raise ValueError( + "Only one of vector_distance_threshold or" + " vector_similarity_threshold can be specified at a time" + " in rag_retrieval_config." 
+ ) + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_distance_threshold + ): + api_retrival_config.filter.vector_distance_threshold = ( + rag_retrieval_config.filter.vector_distance_threshold + ) + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_similarity_threshold + ): + api_retrival_config.filter.vector_similarity_threshold = ( + rag_retrieval_config.filter.vector_similarity_threshold + ) + if rag_retrieval_config.filter and rag_retrieval_config.filter.metadata_filter: + api_retrival_config.filter.metadata_filter = ( + rag_retrieval_config.filter.metadata_filter + ) + + if ( + rag_retrieval_config.ranking + and rag_retrieval_config.ranking.rank_service + and rag_retrieval_config.ranking.llm_ranker + ): + raise ValueError("Only one of rank_service and llm_ranker can be set.") + if rag_retrieval_config.ranking and rag_retrieval_config.ranking.rank_service: + api_retrival_config.ranking.rank_service.model_name = ( + rag_retrieval_config.ranking.rank_service.model_name + ) + elif rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker: + api_retrival_config.ranking.llm_ranker.model_name = ( + rag_retrieval_config.ranking.llm_ranker.model_name + ) + query = aiplatform_v1beta1.RagQuery( + text=text, + rag_retrieval_config=api_retrival_config, + ) + request = aiplatform_v1beta1.RetrieveContextsRequest( + vertex_rag_store=vertex_rag_store, + parent=parent, + query=query, + ) + try: + response = client.retrieve_contexts(request=request) + except Exception as e: + raise RuntimeError("Failed in retrieving contexts due to: ", e) from e + + return response + + +async def async_retrieve_contexts( + text: str, + rag_resources: Optional[List[resources.RagResource]] = None, + rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None, + timeout: int = 600, +) -> aiplatform_v1beta1.RetrieveContextsResponse: + """Retrieve top k relevant docs/chunks asynchronously. 
+ + Example usage: + ``` + import agentplatform + + agentplatform.init(project="my-project") + + config = agentplatform.preview.rag.RagRetrievalConfig( + top_k=2, + ) + + results = await agentplatform.preview.rag.async_retrieve_contexts( + text="Why is the sky blue?", + rag_resources=[agentplatform.preview.rag.RagResource( + rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1", + rag_file_ids=["rag-file-1", "rag-file-2", ...], + )], + rag_retrieval_config=config, + ) + ``` + + Args: + text: Required. The query in text format to get relevant contexts. + rag_resources: Optional. A list of RagResource. It can be used to specify + corpus only or ragfiles. Currently only support one corpus or multiple + files from one corpus. In the future we may open up multiple corpora + support. + rag_retrieval_config: Optional. The config containing the retrieval + parameters, including top_k, vector_distance_threshold, and alpha. + timeout: Optional. The timeout for the request in seconds. Default is 600. + + Returns: + RetrieveContextsResponse. + """ + parent = initializer.global_config.common_location_path() + + client = _gapic_utils.create_rag_service_async_client() + + if not rag_resources: + raise ValueError("rag_resources must be specified.") + + data_client = _gapic_utils.create_rag_data_service_client() + + gapic_rag_resources = [] + if rag_resources: + for rag_resource in rag_resources: + name = rag_resource.rag_corpus + if data_client.parse_rag_corpus_path(name): + rag_corpus_name = name + elif re.match( + "^{}$".format( + _gapic_utils._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access + ), + name, + ): + rag_corpus_name = parent + "/ragCorpora/" + name + else: + raise ValueError( + f"Invalid RagCorpus name: {name}. 
Proper format should be:" + " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" + ) + gapic_rag_resources.append( + aiplatform_v1beta1.VertexRagStore.RagResource( + rag_corpus=rag_corpus_name, + rag_file_ids=rag_resource.rag_file_ids, + ) + ) + vertex_rag_store = aiplatform_v1beta1.VertexRagStore( + rag_resources=gapic_rag_resources, + ) + + if not rag_retrieval_config: + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig() + else: + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig() + if rag_retrieval_config.top_k: + api_retrival_config.top_k = rag_retrieval_config.top_k + + if ( + rag_retrieval_config.hybrid_search + and rag_retrieval_config.hybrid_search.alpha + ): + api_retrival_config.hybrid_search.alpha = ( + rag_retrieval_config.hybrid_search.alpha + ) + + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_distance_threshold + and rag_retrieval_config.filter.vector_similarity_threshold + ): + raise ValueError( + "Only one of vector_distance_threshold or" + " vector_similarity_threshold can be specified at a time" + " in rag_retrieval_config." 
+ ) + + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_distance_threshold + ): + api_retrival_config.filter.vector_distance_threshold = ( + rag_retrieval_config.filter.vector_distance_threshold + ) + + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_similarity_threshold + ): + api_retrival_config.filter.vector_similarity_threshold = ( + rag_retrieval_config.filter.vector_similarity_threshold + ) + + if ( + rag_retrieval_config.ranking + and rag_retrieval_config.ranking.rank_service + and rag_retrieval_config.ranking.llm_ranker + ): + raise ValueError("Only one of rank_service and llm_ranker can be set.") + if rag_retrieval_config.ranking and rag_retrieval_config.ranking.rank_service: + api_retrival_config.ranking.rank_service.model_name = ( + rag_retrieval_config.ranking.rank_service.model_name + ) + elif rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker: + api_retrival_config.ranking.llm_ranker.model_name = ( + rag_retrieval_config.ranking.llm_ranker.model_name + ) + if rag_retrieval_config.filter and rag_retrieval_config.filter.metadata_filter: + api_retrival_config.filter.metadata_filter = ( + rag_retrieval_config.filter.metadata_filter + ) + + query = aiplatform_v1beta1.RagQuery( + text=text, + rag_retrieval_config=api_retrival_config, + ) + + vertex_rag_store.rag_retrieval_config = api_retrival_config + + tool = aiplatform_v1beta1.Tool( + retrieval=aiplatform_v1beta1.Retrieval( + vertex_rag_store=vertex_rag_store, + ) + ) + + request = aiplatform_v1beta1.AsyncRetrieveContextsRequest( + parent=parent, + query=query, + tools=[tool], + ) + try: + response_lro = await client.async_retrieve_contexts( + request=request, timeout=timeout + ) + try: + response = await response_lro.result(timeout=timeout) + except Exception as e: + if response_lro.done(): + raw_op = response_lro.operation + if raw_op.WhichOneof("result") == "response": + any_response = raw_op.response + inner_any = 
any_pb2.Any() + if any_response.Unpack(inner_any): + inner_any.type_url = "type.googleapis.com/google.cloud.aiplatform.v1beta1.RagContexts" + rag_contexts = aiplatform_v1beta1.RagContexts() + if inner_any.Unpack(rag_contexts._pb): + return aiplatform_v1beta1.AsyncRetrieveContextsResponse( + contexts=rag_contexts + ) + raise e + except Exception as e: + raise RuntimeError( + "Failed in retrieving contexts asynchronously due to: ", e + ) from e + + return response + + +def ask_contexts( + text: str, + rag_resources: Optional[List[resources.RagResource]] = None, + rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None, + timeout: int = 600, +) -> aiplatform_v1beta1.AskContextsResponse: + """Ask questions on top k relevant docs/chunks. + + Example usage: + ``` + import agentplatform + + agentplatform.init(project="my-project") + + config = agentplatform.preview.rag.RagRetrievalConfig( + top_k=2, + ) + + results = agentplatform.preview.rag.ask_contexts( + text="Why is the sky blue?", + rag_resources=[agentplatform.preview.rag.RagResource( + rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1", + rag_file_ids=["rag-file-1", "rag-file-2", ...], + )], + rag_retrieval_config=config, + ) + ``` + + Args: + text: Required. The query in text format to get relevant contexts. + rag_resources: Optional. A list of RagResource. It can be used to specify + corpus only or ragfiles. Currently only support one corpus or multiple + files from one corpus. In the future we may open up multiple corpora + support. + rag_retrieval_config: Optional. The config containing the retrieval + parameters, including top_k, vector_distance_threshold, and alpha. + timeout: Optional. The timeout for the request in seconds. Default is 600. + + Returns: + AskContextsResponse. 
+ """ + parent = initializer.global_config.common_location_path() + + client = _gapic_utils.create_rag_service_client() + + if not rag_resources: + raise ValueError("rag_resources must be specified.") + + data_client = _gapic_utils.create_rag_data_service_client() + + gapic_rag_resources = [] + if rag_resources: + for rag_resource in rag_resources: + name = rag_resource.rag_corpus + if data_client.parse_rag_corpus_path(name): + rag_corpus_name = name + elif re.match( + "^{}$".format( + _gapic_utils._VALID_RESOURCE_NAME_REGEX # pylint: disable=protected-access + ), + name, + ): + rag_corpus_name = parent + "/ragCorpora/" + name + else: + raise ValueError( + f"Invalid RagCorpus name: {name}. Proper format should be:" + " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" + ) + gapic_rag_resources.append( + aiplatform_v1beta1.VertexRagStore.RagResource( + rag_corpus=rag_corpus_name, + rag_file_ids=rag_resource.rag_file_ids, + ) + ) + vertex_rag_store = aiplatform_v1beta1.VertexRagStore( + rag_resources=gapic_rag_resources, + ) + + if not rag_retrieval_config: + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig() + else: + api_retrival_config = aiplatform_v1beta1.RagRetrievalConfig() + if rag_retrieval_config.top_k: + api_retrival_config.top_k = rag_retrieval_config.top_k + + if ( + rag_retrieval_config.hybrid_search + and rag_retrieval_config.hybrid_search.alpha + ): + api_retrival_config.hybrid_search.alpha = ( + rag_retrieval_config.hybrid_search.alpha + ) + + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_distance_threshold + and rag_retrieval_config.filter.vector_similarity_threshold + ): + raise ValueError( + "Only one of vector_distance_threshold or" + " vector_similarity_threshold can be specified at a time" + " in rag_retrieval_config." 
+ ) + + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_distance_threshold + ): + api_retrival_config.filter.vector_distance_threshold = ( + rag_retrieval_config.filter.vector_distance_threshold + ) + + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_similarity_threshold + ): + api_retrival_config.filter.vector_similarity_threshold = ( + rag_retrieval_config.filter.vector_similarity_threshold + ) + + if ( + rag_retrieval_config.ranking + and rag_retrieval_config.ranking.rank_service + and rag_retrieval_config.ranking.llm_ranker + ): + raise ValueError("Only one of rank_service and llm_ranker can be set.") + if rag_retrieval_config.ranking and rag_retrieval_config.ranking.rank_service: + api_retrival_config.ranking.rank_service.model_name = ( + rag_retrieval_config.ranking.rank_service.model_name + ) + elif rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker: + api_retrival_config.ranking.llm_ranker.model_name = ( + rag_retrieval_config.ranking.llm_ranker.model_name + ) + if rag_retrieval_config.filter and rag_retrieval_config.filter.metadata_filter: + api_retrival_config.filter.metadata_filter = ( + rag_retrieval_config.filter.metadata_filter + ) + + query = aiplatform_v1beta1.RagQuery( + text=text, + rag_retrieval_config=api_retrival_config, + ) + + vertex_rag_store.rag_retrieval_config = api_retrival_config + + tool = aiplatform_v1beta1.Tool( + retrieval=aiplatform_v1beta1.Retrieval( + vertex_rag_store=vertex_rag_store, + ) + ) + + request = aiplatform_v1beta1.AskContextsRequest( + parent=parent, + query=query, + tools=[tool], + ) + try: + response = client.ask_contexts(request=request, timeout=timeout) + except Exception as e: + raise RuntimeError("Failed in asking contexts due to: ", e) from e + + return response diff --git a/agentplatform/preview/rag/rag_store.py b/agentplatform/preview/rag/rag_store.py new file mode 100644 index 0000000000..f296def7f4 --- /dev/null +++ 
"""RAG retrieval tool for content generation."""

import re
from typing import List, Optional, Union
import warnings

from google.cloud import aiplatform_v1beta1
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform_v1beta1.types import tool as gapic_tool_types
from agentplatform.preview.rag.utils import _gapic_utils
from agentplatform.preview.rag.utils import resources


def _resolve_corpus_name(name: str) -> str:
    """Resolves a user-supplied corpus identifier to a full resource name.

    Accepts either a full `projects/.../ragCorpora/...` path or a bare
    corpus id, which is qualified against the globally configured
    project/location.

    Raises:
        ValueError: If `name` is neither a full resource name nor a valid
            bare resource id.
    """
    data_client = _gapic_utils.create_rag_data_service_client()
    if data_client.parse_rag_corpus_path(name):
        return name
    if re.match("^{}$".format(_gapic_utils._VALID_RESOURCE_NAME_REGEX), name):
        parent = initializer.global_config.common_location_path()
        return parent + "/ragCorpora/" + name
    raise ValueError(
        f"Invalid RagCorpus name: {name}. Proper format should"
        " be: projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}"
    )


def _build_retrieval_config(
    rag_retrieval_config: Optional[resources.RagRetrievalConfig],
) -> aiplatform_v1beta1.RagRetrievalConfig:
    """Converts the user-facing retrieval config into its GAPIC form.

    Returns an empty GAPIC config when no user config is given. Mutually
    exclusive settings (distance vs. similarity threshold; rank service
    vs. LLM ranker) are validated here.

    Raises:
        ValueError: If mutually exclusive settings are both populated.
    """
    api_retrieval_config = aiplatform_v1beta1.RagRetrievalConfig()
    if not rag_retrieval_config:
        return api_retrieval_config

    if rag_retrieval_config.top_k:
        api_retrieval_config.top_k = rag_retrieval_config.top_k

    filter_config = rag_retrieval_config.filter
    if filter_config:
        if (
            filter_config.vector_distance_threshold
            and filter_config.vector_similarity_threshold
        ):
            raise ValueError(
                "Only one of vector_distance_threshold or"
                " vector_similarity_threshold can be specified at a time"
                " in rag_retrieval_config."
            )
        if filter_config.vector_distance_threshold:
            api_retrieval_config.filter.vector_distance_threshold = (
                filter_config.vector_distance_threshold
            )
        if filter_config.vector_similarity_threshold:
            api_retrieval_config.filter.vector_similarity_threshold = (
                filter_config.vector_similarity_threshold
            )

    ranking = rag_retrieval_config.ranking
    if ranking:
        # Only an error when *both* rankers carry a model name, matching
        # the original permissive validation.
        if (
            ranking.rank_service
            and ranking.rank_service.model_name
            and ranking.llm_ranker
            and ranking.llm_ranker.model_name
        ):
            raise ValueError(
                "Only one of rank_service or llm_ranker can be specified"
                " at a time in rag_retrieval_config."
            )
        if ranking.rank_service:
            api_retrieval_config.ranking.rank_service.model_name = (
                ranking.rank_service.model_name
            )
        if ranking.llm_ranker:
            api_retrieval_config.ranking.llm_ranker.model_name = (
                ranking.llm_ranker.model_name
            )
    return api_retrieval_config


class Retrieval:
    """Defines a retrieval tool that a model can call to access external knowledge."""

    def __init__(
        self,
        source: "VertexRagStore",
        disable_attribution: Optional[bool] = False,
    ):
        """Creates a Retrieval tool backed by a Vertex RAG store.

        Args:
            source: The Vertex RAG store to retrieve contexts from.
            disable_attribution: Optional. If True, disables source
                attribution in the response.
        """
        self._raw_retrieval = gapic_tool_types.Retrieval(
            vertex_rag_store=source._raw_vertex_rag_store,
            disable_attribution=disable_attribution,
        )


class VertexRagStore:
    """Retrieve from Vertex RAG Store."""

    def __init__(
        self,
        rag_resources: Optional[List[resources.RagResource]] = None,
        rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None,
    ):
        """Initializes a Vertex RAG store tool.

        Example usage:
        ```
        import agentplatform

        agentplatform.init(project="my-project")

        config = agentplatform.preview.rag.RagRetrievalConfig(
            top_k=2,
            filter=agentplatform.preview.rag.RagRetrievalConfig.Filter(
                vector_distance_threshold=0.5
            ),
        )

        tool = Tool.from_retrieval(
            retrieval=agentplatform.preview.rag.Retrieval(
                source=agentplatform.preview.rag.VertexRagStore(
                    rag_resources=[
                        agentplatform.preview.rag.RagResource(
                            rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1"
                        )
                    ],
                    rag_retrieval_config=config,
                ),
            )
        )
        ```

        Args:
            rag_resources: List of RagResource to retrieve from. It can be
                used to specify corpus only or ragfiles. Currently only
                supports one corpus or multiple files from one corpus. In
                the future we may open up multiple corpora support.
            rag_retrieval_config: Optional. The config containing the
                retrieval parameters, including top_k and
                vector_distance_threshold.

        Raises:
            ValueError: If no or more than one RagResource is given, or if
                the corpus name is malformed.
        """
        if not rag_resources:
            raise ValueError("rag_resources must be specified.")
        if len(rag_resources) > 1:
            raise ValueError("Currently only support 1 RagResource.")

        rag_corpus_name = _resolve_corpus_name(rag_resources[0].rag_corpus)
        api_retrieval_config = _build_retrieval_config(rag_retrieval_config)

        gapic_rag_resource = gapic_tool_types.VertexRagStore.RagResource(
            rag_corpus=rag_corpus_name,
            rag_file_ids=rag_resources[0].rag_file_ids,
        )
        self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore(
            rag_resources=[gapic_rag_resource],
            rag_retrieval_config=api_retrieval_config,
        )
import re
from typing import Any, Dict, Optional, Sequence, Union

from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.utils import (
    VertexRagAsyncClientWithOverride,
    VertexRagClientWithOverride,
    VertexRagDataAsyncClientWithOverride,
    VertexRagDataClientWithOverride,
)
from google.cloud.aiplatform_v1beta1 import (
    GoogleDriveSource,
    ImportRagFilesConfig,
    ImportRagFilesRequest,
    JiraSource as GapicJiraSource,
    RagCorpus as GapicRagCorpus,
    RagEmbeddingModelConfig as GapicRagEmbeddingModelConfig,
    RagEngineConfig as GapicRagEngineConfig,
    RagFileChunkingConfig,
    RagFileParsingConfig,
    RagFileTransformationConfig,
    RagFile as GapicRagFile,
    RagManagedDbConfig as GapicRagManagedDbConfig,
    RagVectorDbConfig as GapicRagVectorDbConfig,
    SharePointSources as GapicSharePointSources,
    SlackSource as GapicSlackSource,
    VertexAiSearchConfig as GapicVertexAiSearchConfig,
)
from google.cloud.aiplatform_v1beta1.types import api_auth
from google.cloud.aiplatform_v1beta1.types import EncryptionSpec
from google.cloud.aiplatform_v1beta1.types import (
    vertex_rag_data as GapicRagDataTypes,
)
from agentplatform.preview.rag.utils.resources import (
    ANN,
    Basic,
    DocumentCorpus,
    EmbeddingModelConfig,
    JiraSource,
    KNN,
    LayoutParserConfig,
    LlmParserConfig,
    MemoryCorpus,
    MetadataValue,
    Pinecone,
    RagCorpus,
    RagCorpusTypeConfig,
    RagDataSchema,
    RagEmbeddingModelConfig,
    RagEngineConfig,
    RagFile,
    RagManagedDb,
    RagManagedDbConfig,
    RagManagedVertexVectorSearch,
    RagMetadata,
    RagMetadataSchemaDetails,
    RagVectorDbConfig,
    Scaled,
    Serverless,
    SharePointSources,
    SlackChannelsSource,
    Spanner,
    TransformationConfig,
    Unprovisioned,
    UserSpecifiedMetadata,
    VertexAiSearchConfig,
    VertexFeatureStore,
    VertexPredictionEndpoint,
    VertexVectorSearch,
    Weaviate,
)


# A bare (non-fully-qualified) resource id: lowercase first char, then up
# to 127 chars of [a-zA-Z0-9._-].
_VALID_RESOURCE_NAME_REGEX = "[a-z][a-zA-Z0-9._-]{0,127}"
# A Document AI processor, optionally pinned to a processor version.
_VALID_DOCUMENT_AI_PROCESSOR_NAME_REGEX = (
    r"projects/[^/]+/locations/[^/]+/processors/[^/]+(?:/processorVersions/[^/]+)?"
)


def create_rag_data_service_client():
    """Returns a v1beta1 VertexRagDataService client for the configured project."""
    return initializer.global_config.create_client(
        client_class=VertexRagDataClientWithOverride,
    ).select_version("v1beta1")


def create_rag_data_service_async_client():
    """Returns a v1beta1 async VertexRagDataService client."""
    return initializer.global_config.create_client(
        client_class=VertexRagDataAsyncClientWithOverride,
    ).select_version("v1beta1")


def create_rag_service_client():
    """Returns a v1beta1 VertexRagService client for the configured project."""
    return initializer.global_config.create_client(
        client_class=VertexRagClientWithOverride,
    ).select_version("v1beta1")


def create_rag_service_async_client():
    """Returns a v1beta1 async VertexRagService client."""
    return initializer.global_config.create_client(
        client_class=VertexRagAsyncClientWithOverride,
    ).select_version("v1beta1")


def convert_gapic_to_embedding_model_config(
    gapic_embedding_model_config: GapicRagEmbeddingModelConfig,
) -> EmbeddingModelConfig:
    """Convert GapicRagEmbeddingModelConfig to EmbeddingModelConfig."""
    embedding_model_config = EmbeddingModelConfig()
    path = gapic_embedding_model_config.vertex_prediction_endpoint.endpoint
    # The endpoint path is either a Google publisher model or a
    # user-deployed prediction endpoint; the two patterns are disjoint.
    publisher_model = re.match(
        r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
        path,
    )
    endpoint = re.match(
        r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
        path,
    )
    if publisher_model:
        embedding_model_config.publisher_model = path
    if endpoint:
        embedding_model_config.endpoint = path
        embedding_model_config.model = (
            gapic_embedding_model_config.vertex_prediction_endpoint.model
        )
        embedding_model_config.model_version_id = (
            gapic_embedding_model_config.vertex_prediction_endpoint.model_version_id
        )

    return embedding_model_config


def _check_weaviate(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
    """True if the `weaviate` oneof field is set on the vector db config."""
    try:
        return gapic_vector_db.__contains__("weaviate")
    except AttributeError:
        # Older proto wrappers lack __contains__; fall back to raw pb
        # presence tracking.
        if "weaviate" in gapic_vector_db._pb.DESCRIPTOR.fields_by_name:
            return gapic_vector_db._pb.HasField("weaviate")
        return False
def _proto_has_field(message, field_name: str) -> bool:
    """True if `field_name` is present on `message`.

    Prefers proto-plus membership (`__contains__`); when unsupported,
    falls back to raw-pb `HasField` presence tracking, returning False
    for unknown fields.
    """
    try:
        return message.__contains__(field_name)
    except AttributeError:
        if field_name in message._pb.DESCRIPTOR.fields_by_name:
            return message._pb.HasField(field_name)
        return False


def _proto_field_nonempty(message, field_name: str) -> bool:
    """True if `field_name` is present on `message`.

    Like `_proto_has_field`, but the fallback treats a populated
    (non-zero serialized size) submessage as present instead of relying
    on `HasField` presence tracking.
    """
    try:
        return message.__contains__(field_name)
    except AttributeError:
        return getattr(message, field_name)._pb.ByteSize() > 0


def _check_rag_managed_db(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
    """True if the `rag_managed_db` backend is set."""
    return _proto_has_field(gapic_vector_db, "rag_managed_db")


def _check_knn(gapic_rag_managed_db: GapicRagVectorDbConfig.RagManagedDb) -> bool:
    """True if the KNN retrieval strategy is set."""
    return _proto_has_field(gapic_rag_managed_db, "knn")


def _check_ann(gapic_rag_managed_db: GapicRagVectorDbConfig.RagManagedDb) -> bool:
    """True if the ANN retrieval strategy is set."""
    return _proto_has_field(gapic_rag_managed_db, "ann")


def _check_vertex_feature_store(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
    """True if the `vertex_feature_store` backend is set."""
    return _proto_has_field(gapic_vector_db, "vertex_feature_store")


def _check_pinecone(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
    """True if the `pinecone` backend is set."""
    return _proto_has_field(gapic_vector_db, "pinecone")


def _check_vertex_vector_search(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
    """True if the `vertex_vector_search` backend is set."""
    return _proto_has_field(gapic_vector_db, "vertex_vector_search")


def _check_rag_managed_vertex_vector_search(
    gapic_vector_db: GapicRagVectorDbConfig,
) -> bool:
    """True if the `rag_managed_vertex_vector_search` backend is set."""
    return _proto_field_nonempty(gapic_vector_db, "rag_managed_vertex_vector_search")


def _check_rag_embedding_model_config(
    gapic_vector_db: GapicRagVectorDbConfig,
) -> bool:
    """True if an embedding model config is attached to the vector db config."""
    return _proto_field_nonempty(gapic_vector_db, "rag_embedding_model_config")


def _check_document_corpus(
    gapic_corpus_type_config: GapicRagCorpus.CorpusTypeConfig,
) -> bool:
    """True if the corpus is configured as a document corpus."""
    return _proto_field_nonempty(gapic_corpus_type_config, "document_corpus")


def _check_memory_corpus(
    gapic_corpus_type_config: GapicRagCorpus.CorpusTypeConfig,
) -> bool:
    """True if the corpus is configured as a memory corpus."""
    return _proto_field_nonempty(gapic_corpus_type_config, "memory_corpus")


def _convert_gapic_to_rag_managed_db(
    gapic_rag_managed_db: GapicRagVectorDbConfig.RagManagedDb,
) -> RagManagedDb:
    """Convert Gapic RagManagedDbConfig to RagManagedDb."""
    if _check_knn(gapic_rag_managed_db):
        return RagManagedDb(retrieval_strategy=KNN())
    elif _check_ann(gapic_rag_managed_db):
        return RagManagedDb(
            retrieval_strategy=ANN(
                tree_depth=gapic_rag_managed_db.ann.tree_depth,
                leaf_count=gapic_rag_managed_db.ann.leaf_count,
            )
        )
    else:
        # No explicit strategy configured; the service default applies.
        return RagManagedDb()


def _convert_rag_managed_db_to_gapic(
    rag_managed_db: RagManagedDb,
) -> GapicRagVectorDbConfig.RagManagedDb:
    """Convert RagManagedDb to Gapic RagManagedDb."""
    if isinstance(rag_managed_db.retrieval_strategy, KNN):
        return GapicRagVectorDbConfig.RagManagedDb(
            knn=GapicRagVectorDbConfig.RagManagedDb.KNN()
        )
    elif isinstance(rag_managed_db.retrieval_strategy, ANN):
        return GapicRagVectorDbConfig.RagManagedDb(
            ann=GapicRagVectorDbConfig.RagManagedDb.ANN(
                tree_depth=rag_managed_db.retrieval_strategy.tree_depth,
                leaf_count=rag_managed_db.retrieval_strategy.leaf_count,
            )
        )
    else:
        return GapicRagVectorDbConfig.RagManagedDb()


def convert_gapic_to_vector_db(
    gapic_vector_db: GapicRagVectorDbConfig,
) -> Union[Weaviate, VertexFeatureStore, VertexVectorSearch, Pinecone, RagManagedDb]:
    """Convert Gapic GapicRagVectorDbConfig to Weaviate, VertexFeatureStore, VertexVectorSearch, RagManagedDb, or Pinecone.

    Returns None when no backend field is set.
    """
    if _check_weaviate(gapic_vector_db):
        return Weaviate(
            weaviate_http_endpoint=gapic_vector_db.weaviate.http_endpoint,
            collection_name=gapic_vector_db.weaviate.collection_name,
            api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
        )
    elif _check_vertex_feature_store(gapic_vector_db):
        return VertexFeatureStore(
            resource_name=gapic_vector_db.vertex_feature_store.feature_view_resource_name,
        )
    elif _check_pinecone(gapic_vector_db):
        return Pinecone(
            index_name=gapic_vector_db.pinecone.index_name,
            api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
        )
    elif _check_rag_managed_vertex_vector_search(gapic_vector_db):
        return RagManagedVertexVectorSearch(
            collection_name=gapic_vector_db.rag_managed_vertex_vector_search.collection_name,
        )
    elif _check_vertex_vector_search(gapic_vector_db):
        return VertexVectorSearch(
            index_endpoint=gapic_vector_db.vertex_vector_search.index_endpoint,
            index=gapic_vector_db.vertex_vector_search.index,
        )
    elif _check_rag_managed_db(gapic_vector_db):
        return _convert_gapic_to_rag_managed_db(gapic_vector_db.rag_managed_db)
    else:
        return None


def convert_gapic_to_vertex_ai_search_config(
    gapic_vertex_ai_search_config: GapicVertexAiSearchConfig,
) -> Optional[VertexAiSearchConfig]:
    """Convert Gapic VertexAiSearchConfig to VertexAiSearchConfig."""
    if gapic_vertex_ai_search_config.serving_config:
        return VertexAiSearchConfig(
            serving_config=gapic_vertex_ai_search_config.serving_config,
        )
    return None
def convert_gapic_to_rag_embedding_model_config(
    gapic_embedding_model_config: GapicRagEmbeddingModelConfig,
) -> RagEmbeddingModelConfig:
    """Convert GapicRagEmbeddingModelConfig to RagEmbeddingModelConfig."""
    embedding_model_config = RagEmbeddingModelConfig()
    path = gapic_embedding_model_config.vertex_prediction_endpoint.endpoint
    # Distinguish a Google publisher model path from a user-deployed
    # prediction endpoint path; the two regexes are disjoint.
    publisher_model = re.match(
        r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
        path,
    )
    endpoint = re.match(
        r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
        path,
    )
    if publisher_model:
        embedding_model_config.vertex_prediction_endpoint = VertexPredictionEndpoint(
            publisher_model=path
        )
    if endpoint:
        embedding_model_config.vertex_prediction_endpoint = VertexPredictionEndpoint(
            endpoint=path,
            model=gapic_embedding_model_config.vertex_prediction_endpoint.model,
            model_version_id=gapic_embedding_model_config.vertex_prediction_endpoint.model_version_id,
        )
    return embedding_model_config


def convert_gapic_to_backend_config(
    gapic_vector_db: GapicRagVectorDbConfig,
) -> RagVectorDbConfig:
    """Convert Gapic RagVectorDbConfig to VertexVectorSearch, Pinecone, or RagManagedDb.

    Returns None when neither a vector db backend nor an embedding model
    config is present.
    """
    vector_config = RagVectorDbConfig()
    if _check_pinecone(gapic_vector_db):
        vector_config.vector_db = Pinecone(
            index_name=gapic_vector_db.pinecone.index_name,
            api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
        )
    elif _check_rag_managed_vertex_vector_search(gapic_vector_db):
        vector_config.vector_db = RagManagedVertexVectorSearch(
            collection_name=gapic_vector_db.rag_managed_vertex_vector_search.collection_name,
        )
    elif _check_vertex_vector_search(gapic_vector_db):
        vector_config.vector_db = VertexVectorSearch(
            index_endpoint=gapic_vector_db.vertex_vector_search.index_endpoint,
            index=gapic_vector_db.vertex_vector_search.index,
        )
    elif _check_rag_managed_db(gapic_vector_db):
        vector_config.vector_db = _convert_gapic_to_rag_managed_db(
            gapic_vector_db.rag_managed_db
        )
    if _check_rag_embedding_model_config(gapic_vector_db):
        vector_config.rag_embedding_model_config = (
            convert_gapic_to_rag_embedding_model_config(
                gapic_vector_db.rag_embedding_model_config
            )
        )
    if (
        vector_config.vector_db is None
        and vector_config.rag_embedding_model_config is None
    ):
        return None
    return vector_config


def convert_gapic_to_rag_corpus_type_config(
    gapic_rag_corpus_type_config: GapicRagCorpus.CorpusTypeConfig,
) -> RagCorpusTypeConfig:
    """Convert GapicRagCorpus.CorpusTypeConfig to RagCorpusTypeConfig.

    Returns None when neither corpus type is set.
    """
    if _check_document_corpus(gapic_rag_corpus_type_config):
        return RagCorpusTypeConfig(corpus_type_config=DocumentCorpus())
    elif _check_memory_corpus(gapic_rag_corpus_type_config):
        return RagCorpusTypeConfig(
            corpus_type_config=MemoryCorpus(
                llm_parser=LlmParserConfig(
                    model_name=gapic_rag_corpus_type_config.memory_corpus.llm_parser.model_name,
                    max_parsing_requests_per_min=gapic_rag_corpus_type_config.memory_corpus.llm_parser.max_parsing_requests_per_min,
                    global_max_parsing_requests_per_min=gapic_rag_corpus_type_config.memory_corpus.llm_parser.global_max_parsing_requests_per_min,
                    custom_parsing_prompt=gapic_rag_corpus_type_config.memory_corpus.llm_parser.custom_parsing_prompt,
                )
            )
        )
    return None


def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
    """Convert GapicRagCorpus to RagCorpus."""
    rag_corpus = RagCorpus(
        name=gapic_rag_corpus.name,
        display_name=gapic_rag_corpus.display_name,
        description=gapic_rag_corpus.description,
        corpus_type_config=convert_gapic_to_rag_corpus_type_config(
            gapic_rag_corpus.corpus_type_config
        ),
        vertex_ai_search_config=convert_gapic_to_vertex_ai_search_config(
            gapic_rag_corpus.vertex_ai_search_config
        ),
        backend_config=convert_gapic_to_backend_config(
            gapic_rag_corpus.vector_db_config
        ),
        encryption_spec=gapic_rag_corpus.encryption_spec,
    )
    return rag_corpus


def convert_gapic_to_rag_corpus_no_embedding_model_config(
    gapic_rag_corpus: GapicRagCorpus,
) -> RagCorpus:
    """Convert GapicRagCorpus without embedding model config (for UpdateRagCorpus) to RagCorpus."""
    vertex_ai_search_config = convert_gapic_to_vertex_ai_search_config(
        gapic_rag_corpus.vertex_ai_search_config
    )
    old_config = gapic_rag_corpus.vector_db_config
    # Rebuild the vector db config field-by-field, deliberately dropping
    # rag_embedding_model_config (immutable on update).
    rag_vector_db_config_no_embedding_model_config = old_config.__class__()
    if _check_rag_managed_db(old_config):
        rag_vector_db_config_no_embedding_model_config.rag_managed_db = (
            old_config.rag_managed_db
        )
    elif _check_pinecone(old_config):
        rag_vector_db_config_no_embedding_model_config.pinecone = old_config.pinecone
    elif _check_vertex_vector_search(old_config):
        rag_vector_db_config_no_embedding_model_config.vertex_vector_search = (
            old_config.vertex_vector_search
        )
    elif _check_weaviate(old_config):
        rag_vector_db_config_no_embedding_model_config.weaviate = old_config.weaviate
    elif _check_vertex_feature_store(old_config):
        rag_vector_db_config_no_embedding_model_config.vertex_feature_store = (
            old_config.vertex_feature_store
        )
    try:
        if old_config.__contains__("api_auth"):
            rag_vector_db_config_no_embedding_model_config.api_auth = (
                old_config.api_auth
            )
    except AttributeError:
        # Wrapper without __contains__; api_auth is silently skipped.
        pass
    rag_corpus = RagCorpus(
        name=gapic_rag_corpus.name,
        display_name=gapic_rag_corpus.display_name,
        description=gapic_rag_corpus.description,
        vertex_ai_search_config=vertex_ai_search_config,
        backend_config=convert_gapic_to_backend_config(
            rag_vector_db_config_no_embedding_model_config
        ),
        encryption_spec=gapic_rag_corpus.encryption_spec,
    )
    return rag_corpus


def convert_gapic_to_rag_file(gapic_rag_file: GapicRagFile) -> RagFile:
    """Convert GapicRagFile to RagFile."""
    rag_file = RagFile(
        name=gapic_rag_file.name,
        display_name=gapic_rag_file.display_name,
        description=gapic_rag_file.description,
    )
    return rag_file
def convert_gapic_to_rag_metadata(
    gapic_rag_metadata: GapicRagDataTypes.RagMetadata,
) -> RagMetadata:
    """Convert Gapic RagMetadata to RagMetadata."""
    return RagMetadata(
        name=gapic_rag_metadata.name,
        user_specified_metadata=convert_gapic_to_user_specified_metadata(
            gapic_rag_metadata.user_specified_metadata
        ),
    )


def convert_gapic_to_user_specified_metadata(
    gapic_user_specified_metadata: GapicRagDataTypes.UserSpecifiedMetadata,
) -> UserSpecifiedMetadata:
    """Convert Gapic UserSpecifiedMetadata to UserSpecifiedMetadata.

    Returns None when the input is falsy. The GAPIC message carries a
    single key/value pair, represented here as a one-entry dict.
    """
    if not gapic_user_specified_metadata:
        return None
    return UserSpecifiedMetadata(
        values={
            gapic_user_specified_metadata.key: convert_gapic_to_metadata_value(
                gapic_user_specified_metadata.value
            )
        }
    )


def convert_gapic_to_metadata_value(
    gapic_metadata_value: GapicRagDataTypes.MetadataValue,
) -> MetadataValue:
    """Convert Gapic MetadataValue to MetadataValue.

    Dispatches on the populated `value` oneof; an unset oneof yields an
    empty MetadataValue, and a falsy input yields None.
    """
    if not gapic_metadata_value:
        return None
    oneof_field = gapic_metadata_value._pb.WhichOneof("value")
    if oneof_field == "str_value":
        return MetadataValue(string_value=gapic_metadata_value.str_value)
    elif oneof_field == "int_value":
        return MetadataValue(int_value=gapic_metadata_value.int_value)
    elif oneof_field == "float_value":
        return MetadataValue(float_value=gapic_metadata_value.float_value)
    elif oneof_field == "bool_value":
        return MetadataValue(bool_value=gapic_metadata_value.bool_value)
    return MetadataValue()


def convert_rag_metadata_to_gapic(
    rag_metadata: RagMetadata,
) -> GapicRagDataTypes.RagMetadata:
    """Convert RagMetadata to Gapic RagMetadata."""
    return GapicRagDataTypes.RagMetadata(
        name=rag_metadata.name,
        user_specified_metadata=convert_user_specified_metadata_to_gapic(
            rag_metadata.user_specified_metadata
        ),
    )


def convert_user_specified_metadata_to_gapic(
    user_specified_metadata: UserSpecifiedMetadata,
) -> GapicRagDataTypes.UserSpecifiedMetadata:
    """Convert UserSpecifiedMetadata to Gapic UserSpecifiedMetadata.

    Raises:
        ValueError: If more than one key-value pair is supplied; the GAPIC
            message holds exactly one pair.
    """
    if not user_specified_metadata:
        return None
    if user_specified_metadata.values:
        if len(user_specified_metadata.values) > 1:
            raise ValueError(
                "Only one key-value pair is supported in UserSpecifiedMetadata."
            )
        key = next(iter(user_specified_metadata.values))
        return GapicRagDataTypes.UserSpecifiedMetadata(
            key=key,
            value=convert_metadata_value_to_gapic(user_specified_metadata.values[key]),
        )
    return GapicRagDataTypes.UserSpecifiedMetadata()


def convert_metadata_value_to_gapic(
    metadata_value: MetadataValue,
) -> GapicRagDataTypes.MetadataValue:
    """Convert MetadataValue to Gapic MetadataValue.

    The first non-None typed value wins, checked in the order string,
    int, float, bool.
    """
    if not metadata_value:
        return None
    if metadata_value.string_value is not None:
        return GapicRagDataTypes.MetadataValue(str_value=metadata_value.string_value)
    if metadata_value.int_value is not None:
        return GapicRagDataTypes.MetadataValue(int_value=metadata_value.int_value)
    if metadata_value.float_value is not None:
        return GapicRagDataTypes.MetadataValue(float_value=metadata_value.float_value)
    if metadata_value.bool_value is not None:
        return GapicRagDataTypes.MetadataValue(bool_value=metadata_value.bool_value)
    return GapicRagDataTypes.MetadataValue()


def convert_gapic_to_rag_data_schema(
    gapic_rag_data_schema: GapicRagDataTypes.RagDataSchema,
) -> RagDataSchema:
    """Convert Gapic RagDataSchema to RagDataSchema."""
    return RagDataSchema(
        name=gapic_rag_data_schema.name,
        key=gapic_rag_data_schema.key,
        schema_details=convert_gapic_to_rag_metadata_schema_details(
            gapic_rag_data_schema.schema_details
        ),
    )


def convert_gapic_to_rag_metadata_schema_details(
    gapic_details: GapicRagDataTypes.RagMetadataSchemaDetails,
) -> RagMetadataSchemaDetails:
    """Convert Gapic RagMetadataSchemaDetails to RagMetadataSchemaDetails.

    Enum fields are converted to their string names; list configs are
    converted recursively. Returns None for a falsy input.
    """
    if not gapic_details:
        return None
    list_config = None
    if gapic_details.list_config:
        list_config = RagMetadataSchemaDetails.ListConfig(
            value_schema=convert_gapic_to_rag_metadata_schema_details(
                gapic_details.list_config.value_schema
            )
        )
    search_strategy = None
    if gapic_details.search_strategy:
        search_strategy = RagMetadataSchemaDetails.SearchStrategy(
            search_strategy_type=GapicRagDataTypes.RagMetadataSchemaDetails.SearchStrategy.SearchStrategyType(
                gapic_details.search_strategy.search_strategy_type
            ).name
        )
    return RagMetadataSchemaDetails(
        type=GapicRagDataTypes.RagMetadataSchemaDetails.DataType(
            gapic_details.type_
        ).name,
        granularity=GapicRagDataTypes.RagMetadataSchemaDetails.Granularity(
            gapic_details.granularity
        ).name,
        list_config=list_config,
        search_strategy=search_strategy,
    )


def convert_rag_data_schema_to_gapic(
    rag_data_schema: RagDataSchema,
) -> GapicRagDataTypes.RagDataSchema:
    """Convert RagDataSchema to Gapic RagDataSchema."""
    return GapicRagDataTypes.RagDataSchema(
        name=rag_data_schema.name,
        key=rag_data_schema.key,
        schema_details=convert_rag_metadata_schema_details_to_gapic(
            rag_data_schema.schema_details
        ),
    )


def convert_rag_metadata_schema_details_to_gapic(
    details: RagMetadataSchemaDetails,
) -> GapicRagDataTypes.RagMetadataSchemaDetails:
    """Convert RagMetadataSchemaDetails to Gapic RagMetadataSchemaDetails.

    Unset type/granularity fall back to their UNSPECIFIED enum values;
    list configs are converted recursively. Returns None for a falsy
    input.
    """
    if not details:
        return None
    list_config = None
    if details.list_config:
        list_config = GapicRagDataTypes.RagMetadataSchemaDetails.ListConfig(
            value_schema=convert_rag_metadata_schema_details_to_gapic(
                details.list_config.value_schema
            )
        )
    search_strategy = None
    if details.search_strategy:
        search_strategy = GapicRagDataTypes.RagMetadataSchemaDetails.SearchStrategy(
            search_strategy_type=details.search_strategy.search_strategy_type
        )
    return GapicRagDataTypes.RagMetadataSchemaDetails(
        type_=(
            details.type
            if details.type
            else GapicRagDataTypes.RagMetadataSchemaDetails.DataType.DATA_TYPE_UNSPECIFIED
        ),
        granularity=(
            details.granularity
            if details.granularity
            else GapicRagDataTypes.RagMetadataSchemaDetails.Granularity.GRANULARITY_UNSPECIFIED
        ),
        list_config=list_config,
        search_strategy=search_strategy,
    )


def convert_json_to_rag_file(upload_rag_file_response: Dict[str, Any]) -> RagFile:
    """Converts a JSON response to a RagFile."""
    rag_file_json = upload_rag_file_response.get("ragFile")
    rag_file = RagFile(
        name=rag_file_json.get("name"),
        display_name=rag_file_json.get("displayName"),
        description=rag_file_json.get("description"),
    )
    return rag_file


def convert_path_to_resource_id(
    path: str,
) -> Union[str, GoogleDriveSource.ResourceId]:
    """Converts a path to a Google Cloud storage uri or GoogleDriveSource.ResourceId.

    GCS uris (`gs://...`) pass through unchanged; Google Drive file or
    folder urls are parsed into a GoogleDriveSource.ResourceId.

    Raises:
        ValueError: If `path` is neither a GCS uri nor a recognizable
            Google Drive url.
    """
    if path.startswith("gs://"):
        # Google Cloud Storage source
        return path
    elif path.startswith("https://drive.google.com/"):
        # Google Drive source
        path_list = path.split("/")
        if "file" in path_list:
            # .../file/d/<id>... -> the id sits two segments past "file".
            index = path_list.index("file") + 2
            resource_id = path_list[index].split("?")[0]
            resource_type = GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FILE
        elif "folders" in path_list:
            index = path_list.index("folders") + 1
            resource_id = path_list[index].split("?")[0]
            resource_type = (
                GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER
            )
        else:
            raise ValueError(f"path {path} is not a valid Google Drive url.")

        return GoogleDriveSource.ResourceId(
            resource_id=resource_id,
            resource_type=resource_type,
        )
    else:
        raise ValueError(
            "path must be a Google Cloud Storage uri or a Google Drive url."
        )
def _convert_slack_source(source: SlackChannelsSource) -> GapicSlackSource:
    """Builds a GapicSlackSource with one SlackChannels entry per channel."""
    result_source_channels = []
    for channel in source.channels:
        result_source_channels.append(
            GapicSlackSource.SlackChannels(
                channels=[
                    GapicSlackSource.SlackChannels.SlackChannel(
                        channel_id=channel.channel_id,
                        start_time=channel.start_time,
                        end_time=channel.end_time,
                    )
                ],
                api_key_config=api_auth.ApiAuth.ApiKeyConfig(
                    api_key_secret_version=channel.api_key
                ),
            )
        )
    return GapicSlackSource(
        channels=result_source_channels,
    )


def _convert_jira_source(source: JiraSource) -> GapicJiraSource:
    """Builds a GapicJiraSource with one JiraQueries entry per query."""
    result_source_queries = []
    for query in source.queries:
        result_source_queries.append(
            GapicJiraSource.JiraQueries(
                custom_queries=query.custom_queries,
                projects=query.jira_projects,
                email=query.email,
                server_uri=query.server_uri,
                api_key_config=api_auth.ApiAuth.ApiKeyConfig(
                    api_key_secret_version=query.api_key
                ),
            )
        )
    return GapicJiraSource(
        jira_queries=result_source_queries,
    )


def _convert_sharepoint_sources(source: SharePointSources) -> GapicSharePointSources:
    """Builds GapicSharePointSources, validating mutually exclusive fields.

    Raises:
        ValueError: If both folder path and folder id are set, both drive
            name and drive id are set, or neither drive field is set.
    """
    result_source_share_point_sources = []
    for share_point_source in source.share_point_sources:
        result_share_point_source = GapicSharePointSources.SharePointSource(
            client_id=share_point_source.client_id,
            client_secret=api_auth.ApiAuth.ApiKeyConfig(
                api_key_secret_version=share_point_source.client_secret
            ),
            tenant_id=share_point_source.tenant_id,
            sharepoint_site_name=share_point_source.sharepoint_site_name,
        )
        folder_path = share_point_source.sharepoint_folder_path
        folder_id = share_point_source.sharepoint_folder_id
        if folder_path is not None and folder_id is not None:
            raise ValueError(
                "sharepoint_folder_path and sharepoint_folder_id cannot both be set."
            )
        elif folder_path is not None:
            result_share_point_source.sharepoint_folder_path = folder_path
        elif folder_id is not None:
            result_share_point_source.sharepoint_folder_id = folder_id
        # A folder is optional; a drive identifier is mandatory.
        drive_name = share_point_source.drive_name
        drive_id = share_point_source.drive_id
        if drive_name is not None and drive_id is not None:
            raise ValueError("drive_name and drive_id cannot both be set.")
        elif drive_name is not None:
            result_share_point_source.drive_name = drive_name
        elif drive_id is not None:
            result_share_point_source.drive_id = drive_id
        else:
            raise ValueError("Either drive_name or drive_id must be set.")
        result_source_share_point_sources.append(result_share_point_source)
    return GapicSharePointSources(
        share_point_sources=result_source_share_point_sources,
    )


def convert_source_for_rag_import(
    source: Union[SlackChannelsSource, JiraSource, SharePointSources],
) -> Union[GapicSlackSource, GapicJiraSource]:
    """Converts a SlackChannelsSource or JiraSource to a GapicSlackSource or GapicJiraSource.

    Raises:
        TypeError: If `source` is not one of the supported source types.
    """
    if isinstance(source, SlackChannelsSource):
        return _convert_slack_source(source)
    elif isinstance(source, JiraSource):
        return _convert_jira_source(source)
    elif isinstance(source, SharePointSources):
        return _convert_sharepoint_sources(source)
    else:
        raise TypeError(
            "source must be a SlackChannelsSource or JiraSource or SharePointSources."
        )
+ ) + + +def prepare_import_files_request( + corpus_name: str, + paths: Optional[Sequence[str]] = None, + source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None, + chunk_size: int = 1024, + chunk_overlap: int = 200, + transformation_config: Optional[TransformationConfig] = None, + max_embedding_requests_per_min: int = 1000, + global_max_embedding_requests_per_min: Optional[int] = None, + use_advanced_pdf_parsing: bool = False, + partial_failures_sink: Optional[str] = None, + layout_parser: Optional[LayoutParserConfig] = None, + llm_parser: Optional[LlmParserConfig] = None, + rebuild_ann_index: bool = False, +) -> ImportRagFilesRequest: + if len(corpus_name.split("/")) != 6: + raise ValueError( + "corpus_name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`" + ) + + rag_file_parsing_config = RagFileParsingConfig( + advanced_parser=RagFileParsingConfig.AdvancedParser( + use_advanced_pdf_parsing=use_advanced_pdf_parsing, + ), + ) + if layout_parser is not None: + if ( + re.fullmatch( + _VALID_DOCUMENT_AI_PROCESSOR_NAME_REGEX, layout_parser.processor_name + ) + is None + ): + raise ValueError( + "processor_name must be of the format " + "`projects/{project_id}/locations/{location}/processors/{processor_id}`" + "or " + "`projects/{project_id}/locations/{location}/processors/{processor_id}/processorVersions/{processor_version_id}`, " + f"got {layout_parser.processor_name!r}" + ) + rag_file_parsing_config.layout_parser = RagFileParsingConfig.LayoutParser( + processor_name=layout_parser.processor_name, + ) + if layout_parser.max_parsing_requests_per_min is not None: + rag_file_parsing_config.layout_parser.max_parsing_requests_per_min = ( + layout_parser.max_parsing_requests_per_min + ) + if layout_parser.global_max_parsing_requests_per_min is not None: + rag_file_parsing_config.layout_parser.global_max_parsing_requests_per_min = ( + layout_parser.global_max_parsing_requests_per_min + ) + if llm_parser is 
not None: + rag_file_parsing_config.llm_parser = RagFileParsingConfig.LlmParser( + model_name=llm_parser.model_name + ) + if llm_parser.max_parsing_requests_per_min is not None: + rag_file_parsing_config.llm_parser.max_parsing_requests_per_min = ( + llm_parser.max_parsing_requests_per_min + ) + if llm_parser.global_max_parsing_requests_per_min is not None: + rag_file_parsing_config.llm_parser.global_max_parsing_requests_per_min = ( + llm_parser.global_max_parsing_requests_per_min + ) + if llm_parser.custom_parsing_prompt is not None: + rag_file_parsing_config.llm_parser.custom_parsing_prompt = ( + llm_parser.custom_parsing_prompt + ) + + local_chunk_size = chunk_size + local_chunk_overlap = chunk_overlap + if transformation_config and transformation_config.chunking_config: + local_chunk_size = transformation_config.chunking_config.chunk_size + local_chunk_overlap = transformation_config.chunking_config.chunk_overlap + + rag_file_transformation_config = RagFileTransformationConfig( + rag_file_chunking_config=RagFileChunkingConfig( + fixed_length_chunking=RagFileChunkingConfig.FixedLengthChunking( + chunk_size=local_chunk_size, + chunk_overlap=local_chunk_overlap, + ), + ), + ) + import_rag_files_config = ImportRagFilesConfig( + rag_file_transformation_config=rag_file_transformation_config, + max_embedding_requests_per_min=max_embedding_requests_per_min, + rag_file_parsing_config=rag_file_parsing_config, + rebuild_ann_index=rebuild_ann_index, + ) + + if global_max_embedding_requests_per_min is not None: + import_rag_files_config.global_max_embedding_requests_per_min = ( + global_max_embedding_requests_per_min + ) + if source is not None: + gapic_source = convert_source_for_rag_import(source) + if isinstance(gapic_source, GapicSlackSource): + import_rag_files_config.slack_source = gapic_source + if isinstance(gapic_source, GapicJiraSource): + import_rag_files_config.jira_source = gapic_source + if isinstance(gapic_source, GapicSharePointSources): + 
def get_corpus_name(
    name: str,
) -> str:
    """Normalize a corpus id or full resource name to a full resource name.

    Accepts either a full
    ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` name
    (returned unchanged) or a bare ``{rag_corpus}`` id, which is expanded with
    the globally configured project and location. A falsy ``name`` is
    returned as-is.
    """
    if not name:
        return name
    client = create_rag_data_service_client()
    if client.parse_rag_corpus_path(name):
        return name
    if re.match("^{}$".format(_VALID_RESOURCE_NAME_REGEX), name):
        return client.rag_corpus_path(
            project=initializer.global_config.project,
            location=initializer.global_config.location,
            rag_corpus=name,
        )
    raise ValueError(
        "name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}` or `{rag_corpus}`"
    )


def get_file_name(
    name: str,
    corpus_name: str,
) -> str:
    """Normalize a file id or full RagFile resource name to a full name.

    A bare ``{rag_file}`` id additionally requires ``corpus_name`` so the
    parent corpus path can be resolved.
    """
    client = create_rag_data_service_client()
    if client.parse_rag_file_path(name):
        return name
    if re.match("^{}$".format(_VALID_RESOURCE_NAME_REGEX), name):
        if not corpus_name:
            raise ValueError(
                "corpus_name must be provided if name is a `{rag_file}`, not a "
                "full resource name"
                " (`projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`). "
            )
        return client.rag_file_path(
            project=initializer.global_config.project,
            location=initializer.global_config.location,
            rag_corpus=get_corpus_name(corpus_name),
            rag_file=name,
        )
    raise ValueError(
        "name must be of the format"
        " `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`"
        " or `{rag_file}`"
    )
def get_data_schema_name(
    name: str,
    corpus_name: str,
) -> str:
    """Get the full resource name for a RagDataSchema.

    ``name`` may already be a full resource name (8 path segments) or a bare
    ``{rag_data_schema}`` id, in which case ``corpus_name`` is required to
    build the parent path. A falsy ``name`` is returned unchanged.
    """
    if not name:
        return name
    if len(name.split("/")) == 8:
        return name
    if re.match("^{}$".format(_VALID_RESOURCE_NAME_REGEX), name):
        if not corpus_name:
            raise ValueError(
                "corpus_name must be provided if name is a `{rag_data_schema}`,"
                " not a "
                "full resource name"
                " (`projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragDataSchemas/{rag_data_schema}`). "
            )
        return "projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragDataSchemas/{rag_data_schema}".format(
            project=initializer.global_config.project,
            location=initializer.global_config.location,
            rag_corpus=get_corpus_name(corpus_name).split("/")[-1],
            rag_data_schema=name,
        )
    raise ValueError(
        "name must be of the format"
        " `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragDataSchemas/{rag_data_schema}`"
        " or `{rag_data_schema}`"
    )


def get_metadata_name(
    name: str,
    corpus_name: str,
    file_name: str,
) -> str:
    """Get the full resource name for a RagMetadata.

    ``name`` may already be a full resource name (10 path segments) or a bare
    ``{rag_metadata}`` id, in which case both ``corpus_name`` and
    ``file_name`` are required to build the parent path. A falsy ``name`` is
    returned unchanged.
    """
    if not name:
        return name
    if len(name.split("/")) == 10:
        return name
    if re.match("^{}$".format(_VALID_RESOURCE_NAME_REGEX), name):
        if not corpus_name or not file_name:
            raise ValueError(
                "corpus_name and file_name must be provided if name is a"
                " `{rag_metadata}`, not a "
                "full resource name"
                " (`projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}/ragMetadata/{rag_metadata}`). "
            )
        return "projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}/ragMetadata/{rag_metadata}".format(
            project=initializer.global_config.project,
            location=initializer.global_config.location,
            rag_corpus=get_corpus_name(corpus_name).split("/")[-1],
            rag_file=get_file_name(file_name, corpus_name).split("/")[-1],
            rag_metadata=name,
        )
    raise ValueError(
        "name must be of the format"
        " `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}/ragMetadata/{rag_metadata}`"
        " or `{rag_metadata}`"
    )
def set_corpus_type_config(
    corpus_type_config: RagCorpusTypeConfig,
    rag_corpus: GapicRagCorpus,
) -> None:
    """Set corpus type config in GapicRagCorpus.

    Maps a user-facing DocumentCorpus/MemoryCorpus onto the gapic
    ``corpus_type_config`` oneof. For a MemoryCorpus, the optional LLM parser
    settings (model, rate limits, custom prompt) are copied field by field.

    Raises:
        TypeError: If the wrapped config is neither DocumentCorpus nor
            MemoryCorpus.
    """
    config = corpus_type_config.corpus_type_config
    if isinstance(config, DocumentCorpus):
        rag_corpus.corpus_type_config = GapicRagCorpus.CorpusTypeConfig(
            document_corpus=GapicRagCorpus.CorpusTypeConfig.DocumentCorpus()
        )
    elif isinstance(config, MemoryCorpus):
        memory_corpus = GapicRagCorpus.CorpusTypeConfig.MemoryCorpus()
        parser = config.llm_parser
        if parser is not None:
            memory_corpus.llm_parser = RagFileParsingConfig.LlmParser(
                model_name=parser.model_name
            )
            if parser.max_parsing_requests_per_min is not None:
                memory_corpus.llm_parser.max_parsing_requests_per_min = (
                    parser.max_parsing_requests_per_min
                )
            if parser.global_max_parsing_requests_per_min is not None:
                memory_corpus.llm_parser.global_max_parsing_requests_per_min = (
                    parser.global_max_parsing_requests_per_min
                )
            if parser.custom_parsing_prompt is not None:
                memory_corpus.llm_parser.custom_parsing_prompt = (
                    parser.custom_parsing_prompt
                )
        rag_corpus.corpus_type_config = GapicRagCorpus.CorpusTypeConfig(
            memory_corpus=memory_corpus
        )
    else:
        # BUGFIX: the original raised a bare TypeError with no message.
        raise TypeError(
            "corpus_type_config must be a DocumentCorpus or MemoryCorpus."
        )
def set_embedding_model_config(
    embedding_model_config: EmbeddingModelConfig,
    rag_corpus: GapicRagCorpus,
) -> None:
    """Sets the embedding model config for the rag corpus.

    Exactly one of ``publisher_model`` or ``endpoint`` must be set. Either may
    be a full resource name (used verbatim) or a short name
    (``publishers/google/models/{model}`` / ``endpoints/{endpoint}``) that is
    resolved against the globally configured project and location.

    Raises:
        ValueError: If both or neither of publisher_model/endpoint are set, or
            if the provided name matches neither accepted format.
    """
    if embedding_model_config.publisher_model and embedding_model_config.endpoint:
        raise ValueError("publisher_model and endpoint cannot be set at the same time.")
    if (
        not embedding_model_config.publisher_model
        and not embedding_model_config.endpoint
    ):
        raise ValueError("At least one of publisher_model and endpoint must be set.")
    parent = initializer.global_config.common_location_path(project=None, location=None)

    if embedding_model_config.publisher_model:
        publisher_model = embedding_model_config.publisher_model
        # BUGFIX: the named groups were garbled to the invalid pattern
        # `(?P.+?)`, which raises re.error at call time; restore valid names.
        full_resource_name = re.match(
            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
            publisher_model,
        )
        resource_name = re.match(
            r"^publishers/google/models/(?P<model_id>.+?)$",
            publisher_model,
        )
        if full_resource_name:
            rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
                publisher_model
            )
        elif resource_name:
            rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
                parent + "/" + publisher_model
            )
        else:
            raise ValueError(
                "publisher_model must be of the format `projects/{project}/locations/{location}/publishers/google/models/{model_id}` or `publishers/google/models/{model_id}`"
            )

    if embedding_model_config.endpoint:
        endpoint = embedding_model_config.endpoint
        # Same named-group restoration as above.
        full_resource_name = re.match(
            r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
            endpoint,
        )
        resource_name = re.match(
            r"^endpoints/(?P<endpoint>.+?)$",
            endpoint,
        )
        if full_resource_name:
            rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
                endpoint
            )
        elif resource_name:
            rag_corpus.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = (
                parent + "/" + endpoint
            )
        else:
            raise ValueError(
                "endpoint must be of the format `projects/{project}/locations/{location}/endpoints/{endpoint}` or `endpoints/{endpoint}`"
            )
def set_encryption_spec(
    encryption_spec: EncryptionSpec,
    rag_corpus: GapicRagCorpus,
) -> None:
    """Sets the encryption spec for the rag corpus.

    Raises:
        ValueError: If ``kms_key_name`` is missing or is not a full KMS crypto
            key resource name.
    """
    # `not x` already covers both None and the empty string; the original
    # `is None or not ...` double-check was redundant.
    if not encryption_spec.kms_key_name:
        raise ValueError("kms_key_name must be set if encryption_spec is set.")

    # BUGFIX: restore valid named groups (`(?P.+?)` is invalid regex syntax
    # and would raise re.error at call time).
    if not re.match(
        r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/keyRings/(?P<key_ring>.+?)/cryptoKeys/(?P<crypto_key>.+?)$",
        encryption_spec.kms_key_name,
    ):
        raise ValueError(
            "kms_key_name must be of the format "
            "`projects/{project}/locations/{location}/keyRings/{key_ring}/cryptoKeys/{crypto_key}`"
        )

    rag_corpus.encryption_spec = encryption_spec


def set_vector_db(
    vector_db: Union[
        Weaviate,
        VertexFeatureStore,
        VertexVectorSearch,
        Pinecone,
        RagManagedDb,
        RagManagedVertexVectorSearch,
        None,
    ],
    rag_corpus: GapicRagCorpus,
) -> None:
    """Sets the vector db configuration for the rag corpus.

    None yields an empty GapicRagVectorDbConfig. Weaviate and Pinecone also
    carry a SecretManager-backed api key into ``api_auth``.

    Raises:
        TypeError: For an unsupported vector db type.

    Note: the annotation now includes RagManagedVertexVectorSearch, which the
    body already handled but the original Union omitted.
    """
    if vector_db is None:
        rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig()
    elif isinstance(vector_db, RagManagedDb):
        rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig(
            rag_managed_db=_convert_rag_managed_db_to_gapic(vector_db)
        )
    elif isinstance(vector_db, Weaviate):
        rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig(
            weaviate=GapicRagVectorDbConfig.Weaviate(
                http_endpoint=vector_db.weaviate_http_endpoint,
                collection_name=vector_db.collection_name,
            ),
            api_auth=api_auth.ApiAuth(
                api_key_config=api_auth.ApiAuth.ApiKeyConfig(
                    api_key_secret_version=vector_db.api_key
                ),
            ),
        )
    elif isinstance(vector_db, VertexFeatureStore):
        rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig(
            vertex_feature_store=GapicRagVectorDbConfig.VertexFeatureStore(
                feature_view_resource_name=vector_db.resource_name,
            ),
        )
    elif isinstance(vector_db, VertexVectorSearch):
        rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig(
            vertex_vector_search=GapicRagVectorDbConfig.VertexVectorSearch(
                index_endpoint=vector_db.index_endpoint,
                index=vector_db.index,
            ),
        )
    elif isinstance(vector_db, Pinecone):
        rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig(
            pinecone=GapicRagVectorDbConfig.Pinecone(
                index_name=vector_db.index_name,
            ),
            api_auth=api_auth.ApiAuth(
                api_key_config=api_auth.ApiAuth.ApiKeyConfig(
                    api_key_secret_version=vector_db.api_key
                ),
            ),
        )
    elif isinstance(vector_db, RagManagedVertexVectorSearch):
        rag_corpus.rag_vector_db_config = GapicRagVectorDbConfig(
            rag_managed_vertex_vector_search=GapicRagVectorDbConfig.RagManagedVertexVectorSearch(),
        )
    else:
        raise TypeError(
            "vector_db must be a Weaviate, VertexFeatureStore, VertexVectorSearch, RagManagedDb, Pinecone, or RagManagedVertexVectorSearch."
        )
def set_vertex_ai_search_config(
    vertex_ai_search_config: VertexAiSearchConfig,
    rag_corpus: GapicRagCorpus,
) -> None:
    """Validates and copies a Vertex AI Search serving config onto the corpus.

    Raises:
        ValueError: If ``serving_config`` is unset or matches neither the
            engine nor the data-store serving config format.
    """
    if not vertex_ai_search_config.serving_config:
        raise ValueError("serving_config must be set.")
    # BUGFIX: restore valid named groups (`(?P.+?)` is invalid regex syntax).
    engine_resource_name = re.match(
        r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/collections/(?P<collection>.+?)/engines/(?P<engine>.+?)/servingConfigs/(?P<serving_config>.+?)$",
        vertex_ai_search_config.serving_config,
    )
    data_store_resource_name = re.match(
        r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/collections/(?P<collection>.+?)/dataStores/(?P<data_store>.+?)/servingConfigs/(?P<serving_config>.+?)$",
        vertex_ai_search_config.serving_config,
    )
    if engine_resource_name or data_store_resource_name:
        rag_corpus.vertex_ai_search_config = GapicVertexAiSearchConfig(
            serving_config=vertex_ai_search_config.serving_config,
        )
    else:
        raise ValueError(
            "serving_config must be of the format `projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}` or `projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}`"
        )


def set_backend_config(
    backend_config: Optional[
        Union[
            RagVectorDbConfig,
            None,
        ]
    ],
    rag_corpus: GapicRagCorpus,
) -> None:
    """Sets the vector db configuration for the rag corpus.

    No-op when ``backend_config`` is None. Otherwise copies the vector db
    choice and, if present, the embedding model config onto ``rag_corpus``.

    Raises:
        TypeError: For an unsupported ``vector_db`` type.
    """
    if backend_config is None:
        return

    if backend_config.vector_db is not None:
        vector_config = backend_config.vector_db
        # NOTE: the original had an additional `vector_config is None` branch
        # here that was unreachable under the enclosing `is not None` guard;
        # it has been removed.
        if isinstance(vector_config, RagManagedDb):
            rag_corpus.vector_db_config.rag_managed_db = (
                _convert_rag_managed_db_to_gapic(vector_config)
            )
        elif isinstance(vector_config, VertexVectorSearch):
            rag_corpus.vector_db_config.vertex_vector_search.index_endpoint = (
                vector_config.index_endpoint
            )
            rag_corpus.vector_db_config.vertex_vector_search.index = (
                vector_config.index
            )
        elif isinstance(vector_config, Pinecone):
            rag_corpus.vector_db_config.pinecone.index_name = (
                vector_config.index_name
            )
            rag_corpus.vector_db_config.api_auth.api_key_config.api_key_secret_version = (
                vector_config.api_key
            )
        elif isinstance(vector_config, RagManagedVertexVectorSearch):
            rag_corpus.vector_db_config.rag_managed_vertex_vector_search = (
                GapicRagVectorDbConfig.RagManagedVertexVectorSearch()
            )
        else:
            # BUGFIX: the message now lists the types this function actually
            # accepts — the original mentioned VertexFeatureStore (not
            # handled here) and omitted VertexVectorSearch (handled).
            raise TypeError(
                "backend_config.vector_db must be a RagManagedDb,"
                " VertexVectorSearch, Pinecone, or RagManagedVertexVectorSearch."
            )
    if backend_config.rag_embedding_model_config:
        set_embedding_model_config(
            backend_config.rag_embedding_model_config, rag_corpus
        )
def _convert_gapic_to_spanner(
    gapic_spanner: GapicRagManagedDbConfig.Spanner,
) -> Spanner:
    """Converts a GapicRagManagedDbConfig.Spanner to a Spanner.

    Raises:
        ValueError: When none of the tier oneof fields is populated.
    """
    spanner = Spanner()
    # Idiom: use the `in` operator rather than calling __contains__ directly.
    if "scaled" in gapic_spanner:
        spanner.tier = Scaled()
    elif "basic" in gapic_spanner:
        spanner.tier = Basic()
    elif "unprovisioned" in gapic_spanner:
        spanner.tier = Unprovisioned()
    else:
        raise ValueError("At least one of scaled, basic, or unprovisioned must be set.")
    return spanner


def convert_gapic_to_rag_engine_config(
    response: GapicRagEngineConfig,
) -> RagEngineConfig:
    """Converts a GapicRagEngineConfig to a RagEngineConfig.

    Maps the populated field of ``rag_managed_db_config`` onto the
    user-facing mode classes; bare `basic`/`unprovisioned`/`scaled` fields
    are interpreted as Spanner tiers.

    Raises:
        ValueError: When no mode field is populated.
    """
    rag_managed_db_config = RagManagedDbConfig()
    db = response.rag_managed_db_config
    # Idiom: `in` instead of explicit __contains__ calls.
    if "spanner" in db:
        rag_managed_db_config.mode = _convert_gapic_to_spanner(db.spanner)
    elif "serverless" in db:
        rag_managed_db_config.mode = Serverless()
    elif "basic" in db:
        rag_managed_db_config.mode = Spanner(tier=Basic())
    elif "unprovisioned" in db:
        rag_managed_db_config.mode = Spanner(tier=Unprovisioned())
    elif "scaled" in db:
        rag_managed_db_config.mode = Spanner(tier=Scaled())
    else:
        raise ValueError("At least one of rag_managed_db_config mode must be set.")
    return RagEngineConfig(
        name=response.name,
        rag_managed_db_config=rag_managed_db_config,
    )
def _convert_spanner_to_gapic(
    spanner: Spanner,
) -> GapicRagManagedDbConfig.Spanner:
    """Converts a Spanner to a GapicRagManagedDbConfig.Spanner."""
    out = GapicRagManagedDbConfig.Spanner()
    tier = spanner.tier
    if isinstance(tier, Scaled):
        out.scaled = GapicRagManagedDbConfig.Scaled()
    elif isinstance(tier, Basic):
        out.basic = GapicRagManagedDbConfig.Basic()
    elif isinstance(tier, Unprovisioned):
        out.unprovisioned = GapicRagManagedDbConfig.Unprovisioned()
    # An unset/unknown tier leaves the message empty, matching the original.
    return out


def convert_rag_engine_config_to_gapic(
    rag_engine_config: RagEngineConfig,
) -> GapicRagEngineConfig:
    """Converts a RagEngineConfig to a GapicRagEngineConfig.

    When no mode is configured, defaults to Spanner with the Basic tier.
    """
    config = rag_engine_config.rag_managed_db_config
    if config is None or config.mode is None:
        db_config = GapicRagManagedDbConfig(
            spanner=GapicRagManagedDbConfig.Spanner(
                basic=GapicRagManagedDbConfig.Basic()
            )
        )
    else:
        db_config = GapicRagManagedDbConfig()
        mode = config.mode
        if isinstance(mode, Serverless):
            db_config.serverless = GapicRagManagedDbConfig.Serverless()
        elif isinstance(mode, Spanner):
            db_config.spanner = _convert_spanner_to_gapic(mode)
    return GapicRagEngineConfig(
        name=rag_engine_config.name,
        rag_managed_db_config=db_config,
    )
@dataclasses.dataclass
class RagFile:
    """RAG file (output only).

    Attributes:
        name: Generated resource name, in the form
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}/ragFiles/{rag_file}``.
        display_name: Client-configured display name.
        description: Client-configured description.
    """

    name: Optional[str] = None
    display_name: Optional[str] = None
    description: Optional[str] = None


@dataclasses.dataclass
class EmbeddingModelConfig:
    """Embedding model configuration.

    Users supply either a 1P publisher model or a 1P fine-tuned embedding
    model endpoint.

    Attributes:
        publisher_model: 1P publisher model resource name,
            ``publishers/google/models/{model}`` or
            ``projects/{project}/locations/{location}/publishers/google/models/{model}``.
        endpoint: 1P fine-tuned embedding model resource name,
            ``endpoints/{endpoint}`` or
            ``projects/{project}/locations/{location}/endpoints/{endpoint}``.
        model: Output only. Resource name of the model deployed on the
            endpoint (only for non-publisher endpoints), pattern
            ``projects/{project}/locations/{location}/models/{model}``.
        model_version_id: Output only. Version ID of the deployed model
            (only for non-publisher endpoints).
    """

    publisher_model: Optional[str] = None
    endpoint: Optional[str] = None
    model: Optional[str] = None
    model_version_id: Optional[str] = None


@dataclasses.dataclass
class VertexPredictionEndpoint:
    """Vertex Prediction endpoint used as an embedding model.

    Attributes:
        endpoint: 1P fine-tuned embedding model resource name,
            ``endpoints/{endpoint}`` or
            ``projects/{project}/locations/{location}/endpoints/{endpoint}``.
        publisher_model: 1P publisher model resource name,
            ``publishers/google/models/{model}`` or
            ``projects/{project}/locations/{location}/publishers/google/models/{model}``.
        model: Output only. Resource name of the model deployed on the
            endpoint (only for non-publisher endpoints).
        model_version_id: Output only. Version ID of the deployed model
            (only for non-publisher endpoints).
    """

    endpoint: Optional[str] = None
    publisher_model: Optional[str] = None
    model: Optional[str] = None
    model_version_id: Optional[str] = None


@dataclasses.dataclass
class RagEmbeddingModelConfig:
    """Embedding model config wrapper.

    Attributes:
        vertex_prediction_endpoint: The Vertex AI Prediction Endpoint config.
    """

    vertex_prediction_endpoint: Optional[VertexPredictionEndpoint] = None


@dataclasses.dataclass
class Weaviate:
    """Weaviate vector DB config.

    Attributes:
        weaviate_http_endpoint: HTTP endpoint of the Weaviate DB instance.
        collection_name: Weaviate collection this corpus maps to.
        api_key: SecretManager resource name for the Weaviate API token,
            ``projects/{project}/secrets/{secret}/versions/{version}``.
    """

    weaviate_http_endpoint: Optional[str] = None
    collection_name: Optional[str] = None
    api_key: Optional[str] = None


@dataclasses.dataclass
class VertexFeatureStore:
    """Vertex Feature Store vector DB config.

    Attributes:
        resource_name: FeatureView resource name,
            ``projects/{project}/locations/{location}/featureOnlineStores/{feature_online_store}/featureViews/{feature_view}``.
    """

    resource_name: Optional[str] = None


@dataclasses.dataclass
class VertexVectorSearch:
    """Vertex Vector Search vector DB config.

    Attributes:
        index_endpoint: Index Endpoint resource name,
            ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}``.
        index: Index resource name,
            ``projects/{project}/locations/{location}/indexes/{index}``.
    """

    index_endpoint: Optional[str] = None
    index: Optional[str] = None


@dataclasses.dataclass
class KNN:
    """Config for KNN search."""


@dataclasses.dataclass
class ANN:
    """Config for ANN search.

    RagManagedDb partitions data with a tree-based structure for faster
    search, at the cost of longer indexing and manually triggered index
    rebuilds (ImportRagFiles / UpdateRagCorpus).

    Attributes:
        tree_depth: Depth of the tree (2 or 3). 2 is recommended for up to
            ~O(10K) RagFiles, 3 beyond that. Default is 2.
        leaf_count: Number of leaf nodes; each groups closely related vectors
            with their centroid. Recommended: 10 * sqrt(number of RagFiles).
            Default is 500.
    """

    tree_depth: Optional[int] = None
    leaf_count: Optional[int] = None


@dataclasses.dataclass
class RagManagedDb:
    """RAG-managed vector DB config.

    Attributes:
        retrieval_strategy: KNN or ANN search over the RagCorpus; KNN is the
            default when unspecified.
    """

    retrieval_strategy: Optional[Union[KNN, ANN]] = None


@dataclasses.dataclass
class Pinecone:
    """Pinecone vector DB config.

    Attributes:
        index_name: The Pinecone index name.
        api_key: SecretManager resource name for the Pinecone API token,
            ``projects/{project}/secrets/{secret}/versions/{version}``.
    """

    index_name: Optional[str] = None
    api_key: Optional[str] = None


@dataclasses.dataclass
class RagManagedVertexVectorSearch:
    """RAG-managed Vertex Vector Search config.

    Attributes:
        collection_name: Resource name of the Vector Search 2.0 Collection
            RAG created for the corpus; populated only after successful corpus
            creation. Format:
            ``projects/{project}/locations/{location}/collections/{collection_id}``.
    """

    collection_name: Optional[str] = None


@dataclasses.dataclass
class VertexAiSearchConfig:
    """Vertex AI Search config.

    Attributes:
        serving_config: Serving config resource name, either
            ``projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}``
            or
            ``projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}``.
    """

    serving_config: Optional[str] = None


@dataclasses.dataclass
class RagVectorDbConfig:
    """Vector DB backend config for a corpus.

    Attributes:
        vector_db: One of Weaviate, VertexFeatureStore, VertexVectorSearch,
            Pinecone, RagManagedDb, RagManagedVertexVectorSearch.
        rag_embedding_model_config: Embedding model config for the vector DB.
    """

    vector_db: Optional[
        Union[
            Weaviate,
            VertexFeatureStore,
            VertexVectorSearch,
            Pinecone,
            RagManagedDb,
            RagManagedVertexVectorSearch,
        ]
    ] = None
    rag_embedding_model_config: Optional[RagEmbeddingModelConfig] = None
@dataclasses.dataclass
class RagResource:
    """RagResource.

    The representation of the rag source. It can be used to specify corpus
    only or ragfiles. Currently only support one corpus or multiple files from
    one corpus. In the future we may open up multiple corpora support.

    Attributes:
        rag_corpus: A Rag corpus resource name or corpus id. Format:
            ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}``
            or ``{rag_corpus_id}``.
        rag_file_ids: List of Rag file resource name or file ids in the same
            corpus. Format: ``{rag_file}``.
    """
    # BUGFIX: the docstring documented a non-existent attribute
    # `rag_files_id`; the actual field is `rag_file_ids`.

    rag_corpus: Optional[str] = None
    rag_file_ids: Optional[List[str]] = None


@dataclasses.dataclass
class SlackChannel:
    """SlackChannel.

    Attributes:
        channel_id: The Slack channel ID.
        api_key: The SecretManager resource name for the Slack API token.
            Format: ``projects/{project}/secrets/{secret}/versions/{version}``
            See: https://api.slack.com/tutorials/tracks/getting-a-token.
        start_time: The starting timestamp for messages to import.
        end_time: The ending timestamp for messages to import.
    """

    channel_id: str
    api_key: str
    start_time: Optional[timestamp_pb2.Timestamp] = None
    end_time: Optional[timestamp_pb2.Timestamp] = None


@dataclasses.dataclass
class SlackChannelsSource:
    """SlackChannelsSource.

    Attributes:
        channels: The Slack channels.
    """

    channels: Sequence[SlackChannel]
@dataclasses.dataclass
class JiraQuery:
    """JiraQuery.

    Attributes:
        email: The Jira email address.
        jira_projects: A list of Jira projects to import in their entirety.
        custom_queries: A list of custom JQL Jira queries to import.
        api_key: The SecretManager version resource name for Jira API access.
            Format: ``projects/{project}/secrets/{secret}/versions/{version}``
            See: https://support.atlassian.com/atlassian-account/docs/manage-api-tokens-for-your-atlassian-account/
        server_uri: The Jira server URI. Format: ``{server}.atlassian.net``
    """

    email: str
    jira_projects: Sequence[str]
    custom_queries: Sequence[str]
    api_key: str
    server_uri: str


@dataclasses.dataclass
class JiraSource:
    """JiraSource.

    Attributes:
        queries: The Jira queries.
    """

    queries: Sequence[JiraQuery]


@dataclasses.dataclass
class SharePointSource:
    """SharePointSource.

    Attributes:
        sharepoint_folder_path: The path of the SharePoint folder to download
            from.
        sharepoint_folder_id: The ID of the SharePoint folder to download
            from.
        drive_name: The name of the drive to download from.
        drive_id: The ID of the drive to download from.
        client_id: The Application ID for the app registered in Microsoft
            Azure Portal. The application must also be configured with MS
            Graph permissions "Files.ReadAll", "Sites.ReadAll" and
            BrowserSiteLists.Read.All.
        client_secret: The application secret for the app registered in Azure.
        tenant_id: Unique identifier of the Azure Active Directory Instance.
        sharepoint_site_name: The name of the SharePoint site to download
            from. This can be the site name or the site id.
    """

    sharepoint_folder_path: Optional[str] = None
    sharepoint_folder_id: Optional[str] = None
    drive_name: Optional[str] = None
    drive_id: Optional[str] = None
    # BUGFIX: these four fields were annotated as plain `str` while
    # defaulting to None; the annotations now match the defaults.
    client_id: Optional[str] = None
    client_secret: Optional[str] = None
    tenant_id: Optional[str] = None
    sharepoint_site_name: Optional[str] = None


@dataclasses.dataclass
class SharePointSources:
    """SharePointSources.

    Attributes:
        share_point_sources: The SharePoint sources.
    """

    share_point_sources: Sequence[SharePointSource]
@dataclasses.dataclass
class Filter:
    """Retrieval filter config.

    Attributes:
        vector_distance_threshold: Only return contexts whose vector distance
            is smaller than this threshold.
        vector_similarity_threshold: Only return contexts whose vector
            similarity is larger than this threshold.
        metadata_filter: Metadata filtering string.
    """

    vector_distance_threshold: Optional[float] = None
    vector_similarity_threshold: Optional[float] = None
    metadata_filter: Optional[str] = None


@dataclasses.dataclass
class HybridSearch:
    """Hybrid search config.

    Attributes:
        alpha: Weight between dense and sparse vector search results, in
            [0, 1]: 0 means sparse-only, 1 means dense-only. Defaults to 0.5,
            weighting both equally.
    """

    alpha: Optional[float] = None


@dataclasses.dataclass
class LlmRanker:
    """LLM ranker config.

    Attributes:
        model_name: Model used for ranking; only Gemini models are supported
            for now.
    """

    model_name: Optional[str] = None


@dataclasses.dataclass
class RankService:
    """Rank service config.

    Attributes:
        model_name: Rank service model name, e.g.
            ``semantic-ranker-512@latest``.
    """

    model_name: Optional[str] = None


@dataclasses.dataclass
class Ranking:
    """Ranking / reranking config.

    Attributes:
        rank_service: Config for the Rank Service.
        llm_ranker: Config for the LLM ranker.
    """

    rank_service: Optional[RankService] = None
    llm_ranker: Optional[LlmRanker] = None


@dataclasses.dataclass
class RagRetrievalConfig:
    """Retrieval config for RAG queries.

    Attributes:
        top_k: Number of contexts to retrieve.
        filter: Filter config.
        hybrid_search: Hybrid search config.
        ranking: Ranking and reranking config.
    """

    top_k: Optional[int] = None
    filter: Optional[Filter] = None
    hybrid_search: Optional[HybridSearch] = None
    ranking: Optional[Ranking] = None
@dataclasses.dataclass
class ChunkingConfig:
    """Chunking config.

    Attributes:
        chunk_size: Size of each chunk.
        chunk_overlap: Overlap between consecutive chunks.
    """

    chunk_size: int
    chunk_overlap: int


@dataclasses.dataclass
class TransformationConfig:
    """Transformation config.

    Attributes:
        chunking_config: The chunking config.
    """

    chunking_config: Optional[ChunkingConfig] = None


@dataclasses.dataclass
class LayoutParserConfig:
    """Configuration for the Document AI Layout Parser Processor.

    Attributes:
        processor_name: Full resource name of a Document AI processor or
            processor version; the processor must have type
            `LAYOUT_PARSER_PROCESSOR`. Format:
            - `projects/{project_id}/locations/{location}/processors/{processor_id}`
            - `projects/{project_id}/locations/{location}/processors/{processor_id}/processorVersions/{processor_version_id}`
        max_parsing_requests_per_min: Maximum requests per minute the job may
            make to the Document AI processor. Consult
            https://cloud.google.com/document-ai/quotas and your project's
            Quota page; defaults to 120 QPM if unspecified.
        global_max_parsing_requests_per_min: Project-wide per-minute request
            limit for the Document AI processor. If unspecified, the indexing
            pipeline uses max_parsing_requests_per_min as the global limit.
    """

    processor_name: str
    max_parsing_requests_per_min: Optional[int] = None
    global_max_parsing_requests_per_min: Optional[int] = None
If this value is not specified, + max_parsing_requests_per_min will be used by indexing + pipeline as the global limit. + """ + + processor_name: str + max_parsing_requests_per_min: Optional[int] = None + global_max_parsing_requests_per_min: Optional[int] = None + + +@dataclasses.dataclass +class LlmParserConfig: + """Configuration for the LLM Parser Processor. + + Attributes: + model_name (str): + The full resource name of a Vertex AI model. Format: + - `projects/{project_id}/locations/{location}/publishers/google/models/{model_id}` + - `projects/{project_id}/locations/{location}/models/{model_id}` + max_parsing_requests_per_min (int): + The maximum number of requests the job is allowed to make to the + Vertex AI model per minute. Consult + https://cloud.google.com/vertex-ai/generative-ai/docs/quotas and + the Quota page for your project to set an appropriate value here. + If unspecified, a default value of 5000 QPM will be used. + global_max_parsing_requests_per_min (int): + The maximum number of requests the job is allowed to make to + the LLM model per minute in this project. Consult + https://cloud.google.com/vertex-ai/generative-ai/docs/quotas + and your document size to set an appropriate value here. If + this value is not specified, max_parsing_requests_per_min + will be used by indexing pipeline job as the global limit. + custom_parsing_prompt (str): + A custom prompt to use for parsing. + """ + + model_name: str + max_parsing_requests_per_min: Optional[int] = None + global_max_parsing_requests_per_min: Optional[int] = None + custom_parsing_prompt: Optional[str] = None + + +@dataclasses.dataclass +class Scaled: + """Scaled tier offers production grade performance along with + + autoscaling functionality. It is suitable for customers with large + amounts of data or performance sensitive workloads. 
+ """ + + +@dataclasses.dataclass +class Basic: + """Basic tier is a cost-effective and low compute tier suitable for the following cases: + + * Experimenting with RagManagedDb. + * Small data size. + * Latency insensitive workload. + * Only using RAG Engine with external vector DBs. + + NOTE: This is the default tier if not explicitly chosen. + """ + + +@dataclasses.dataclass +class Unprovisioned: + """Disables the RAG Engine service and deletes all your data held within + this service. This will halt the billing of the service. + + NOTE: Once deleted the data cannot be recovered. To start using + RAG Engine again, you will need to update the tier by calling the + UpdateRagEngineConfig API. + """ + + +@dataclasses.dataclass +class Spanner: + """Switches RAG Engine to use Spanner/RagManagedDb as the backend. + + Attributes: + tier: The tier of the RagManagedDb. The default tier is Basic. + + NOTE: This is the default mode if not explicitly chosen. + """ + + tier: Optional[Union[Basic, Scaled, Unprovisioned]] = None + + +@dataclasses.dataclass +class Serverless: + """Switches RAG Engine to use serverless mode as the backend.""" + + +@dataclasses.dataclass +class RagManagedDbConfig: + """RagManagedDbConfig. + + The config of the RagManagedDb used by RagEngine. + + Attributes: + mode: The choice of backend for your RAG Engine. The default mode is + Spanner with Basic tier. + """ + + mode: Optional[Union[Spanner, Serverless]] = None + + +@dataclasses.dataclass +class RagEngineConfig: + """RagEngineConfig. + + Attributes: + name: Generated resource name for singleton resource. Format: + ``projects/{project}/locations/{location}/ragEngineConfig`` + rag_managed_db_config: The config of the RagManagedDb used by RagEngine. + The default tier is Basic. + """ + + name: str + rag_managed_db_config: Optional[RagManagedDbConfig] = None + + +@dataclasses.dataclass +class DocumentCorpus: + """DocumentCorpus.""" + + +@dataclasses.dataclass +class MemoryCorpus: + """MemoryCorpus. 
+ + Attributes: + llm_parser: The LLM parser to use for the memory corpus. + """ + + llm_parser: Optional[LlmParserConfig] = None + + +@dataclasses.dataclass +class RagCorpusTypeConfig: + """CorpusTypeConfig. + + Attributes: + corpus_type_config: Can be one of the following: DocumentCorpus, + MemoryCorpus. + """ + + corpus_type_config: Optional[Union[DocumentCorpus, MemoryCorpus]] = None + + +@dataclasses.dataclass +class RagCorpus: + """RAG corpus(output only). + + Attributes: + name: Generated resource name. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}`` + display_name: Display name that was configured at client side. + description: The description of the RagCorpus. + corpus_type_config: The corpus type config of the RagCorpus. + vertex_ai_search_config: The Vertex AI Search config of the RagCorpus. + backend_config: The backend config of the RagCorpus. It can specify a + Vector DB and/or the embedding model config. + encryption_spec: The encryption spec of the RagCorpus. Immutable. + """ + + name: Optional[str] = None + display_name: Optional[str] = None + description: Optional[str] = None + corpus_type_config: Optional[RagCorpusTypeConfig] = None + vertex_ai_search_config: Optional[VertexAiSearchConfig] = None + backend_config: Optional[RagVectorDbConfig] = None + encryption_spec: Optional[EncryptionSpec] = None + + +@dataclasses.dataclass +class RagMetadataSchemaDetails: + """Data schema details indicates the data type and the data + + struct corresponding to the key of user specified metadata. + + Attributes: + type (str): Type of the metadata. + list_config (RagMetadataSchemaDetails.ListConfig): Config for List data + type. + granularity (str): The granularity associated with this RagMetadataSchema. + search_strategy (RagMetadataSchemaDetails.SearchStrategy): The search + strategy for the metadata value of the key. + """ + + @dataclasses.dataclass + class ListConfig: + """Config for List data type. 
+ + Attributes: + value_schema (RagMetadataSchemaDetails): The value's data type in the + list. + """ + + value_schema: Optional["RagMetadataSchemaDetails"] = None + + @dataclasses.dataclass + class SearchStrategy: + """The search strategy for the metadata value of the key. + + Attributes: + search_strategy_type (str): The search strategy type to be applied on + the metadata key. + """ + + search_strategy_type: Optional[str] = None + + type: Optional[str] = None + list_config: Optional[ListConfig] = None + granularity: Optional[str] = None + search_strategy: Optional[SearchStrategy] = None + + +@dataclasses.dataclass +class RagDataSchema: + """The schema of the user specified metadata. + + Attributes: + name (str): Identifier. Resource name of the data schema. + key (str): Required. The key of this data schema. + schema_details (RagMetadataSchemaDetails): The schema details mapping to + the key. + """ + + name: Optional[str] = None + key: Optional[str] = None + schema_details: Optional[RagMetadataSchemaDetails] = None + + +@dataclasses.dataclass +class MetadataValue: + """The value of metadata. + + Attributes: + string_value (str): The string value. + int_value (int): The int value. + float_value (float): The float value. + bool_value (bool): The bool value. + """ + + string_value: Optional[str] = None + int_value: Optional[int] = None + float_value: Optional[float] = None + bool_value: Optional[bool] = None + + +@dataclasses.dataclass +class RagMetadata: + """Metadata for RagFile provided by users. + + Attributes: + name (str): Identifier. Resource name of the RagMetadata. + user_specified_metadata (UserSpecifiedMetadata): User provided metadata. + """ + + name: Optional[str] = None + user_specified_metadata: Optional["UserSpecifiedMetadata"] = None + + +@dataclasses.dataclass +class UserSpecifiedMetadata: + """Metadata provided by users. + + Attributes: + values (Dict[str, MetadataValue]): Required. The values of the metadata. 
+ """ + + values: dict[str, MetadataValue] diff --git a/agentplatform/rag/__init__.py b/agentplatform/rag/__init__.py new file mode 100644 index 0000000000..f3224c9099 --- /dev/null +++ b/agentplatform/rag/__init__.py @@ -0,0 +1,127 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from agentplatform.rag.rag_data import ( + add_inline_citations_and_references, + create_corpus, + delete_corpus, + delete_file, + get_corpus, + get_file, + get_rag_engine_config, + import_files, + import_files_async, + list_corpora, + list_files, + update_corpus, + update_rag_engine_config, + upload_file, +) +from agentplatform.rag.rag_retrieval import ( + ask_contexts, + async_retrieve_contexts, + retrieval_query, +) +from agentplatform.rag.rag_store import ( + Retrieval, + VertexRagStore, +) +from agentplatform.rag.utils.resources import ( + Basic, + ChunkingConfig, + Filter, + JiraQuery, + JiraSource, + LayoutParserConfig, + LlmParserConfig, + LlmRanker, + Pinecone, + RagCitedGenerationResponse, + RagCorpus, + RagEmbeddingModelConfig, + RagEngineConfig, + RagFile, + RagManagedDb, + RagManagedDbConfig, + RagResource, + RagRetrievalConfig, + RagVectorDbConfig, + RankService, + Ranking, + Scaled, + SharePointSource, + SharePointSources, + SlackChannel, + SlackChannelsSource, + TransformationConfig, + Unprovisioned, + VertexAiSearchConfig, + VertexPredictionEndpoint, + VertexVectorSearch, +) + + +__all__ = ( + "Basic", + "ChunkingConfig", + "Filter", + 
"JiraQuery", + "JiraSource", + "LayoutParserConfig", + "LlmParserConfig", + "LlmRanker", + "Pinecone", + "RagCorpus", + "RagEmbeddingModelConfig", + "RagEngineConfig", + "RagFile", + "RagCitedGenerationResponse", + "RagManagedDb", + "RagManagedDbConfig", + "RagResource", + "RagRetrievalConfig", + "RagVectorDbConfig", + "Ranking", + "RankService", + "Retrieval", + "Scaled", + "SharePointSource", + "SharePointSources", + "SlackChannel", + "SlackChannelsSource", + "TransformationConfig", + "Unprovisioned", + "VertexAiSearchConfig", + "VertexRagStore", + "VertexPredictionEndpoint", + "VertexVectorSearch", + "ask_contexts", + "create_corpus", + "delete_corpus", + "delete_file", + "get_corpus", + "get_rag_engine_config", + "get_file", + "import_files", + "import_files_async", + "list_corpora", + "list_files", + "retrieval_query", + "async_retrieve_contexts", + "upload_file", + "update_corpus", + "update_rag_engine_config", + "add_inline_citations_and_references", +) diff --git a/agentplatform/rag/rag_data.py b/agentplatform/rag/rag_data.py new file mode 100644 index 0000000000..39e6661a5d --- /dev/null +++ b/agentplatform/rag/rag_data.py @@ -0,0 +1,1152 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +"""RAG data management SDK.""" + +from typing import Optional, Sequence, Union +from google import auth +from google.api_core import operation_async +from google.auth.transport import requests as google_auth_requests +from google.cloud import aiplatform +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform import utils +from google.cloud.aiplatform_v1 import ( + CreateRagCorpusRequest, + DeleteRagCorpusRequest, + DeleteRagFileRequest, + GetRagCorpusRequest, + GetRagEngineConfigRequest, + GetRagFileRequest, + ImportRagFilesResponse, + ListRagCorporaRequest, + ListRagFilesRequest, + RagCorpus as GapicRagCorpus, + UpdateRagCorpusRequest, + UpdateRagEngineConfigRequest, +) +from google.cloud.aiplatform_v1.services.vertex_rag_data_service.pagers import ( + ListRagCorporaPager, + ListRagFilesPager, +) +from agentplatform.rag.rag_inline_citations import ( + format_bibliography, + populate_cited_chunk_references, +) +from agentplatform.rag.utils import ( + _gapic_utils, +) +from google.cloud.aiplatform_v1.types import EncryptionSpec +from agentplatform.rag.utils.resources import ( + JiraSource, + LayoutParserConfig, + LlmParserConfig, + RagCitedGenerationResponse, + RagCorpus, + RagEngineConfig, + RagFile, + RagVectorDbConfig, + SharePointSources, + SlackChannelsSource, + VertexAiSearchConfig, + TransformationConfig, +) + + +def create_corpus( + display_name: Optional[str] = None, + description: Optional[str] = None, + vertex_ai_search_config: Optional[VertexAiSearchConfig] = None, + backend_config: Optional[ + Union[ + RagVectorDbConfig, + None, + ] + ] = None, + encryption_spec: Optional[EncryptionSpec] = None, + timeout: int = 600, +) -> RagCorpus: + """Creates a new RagCorpus resource. 
+ + Example usage: + ``` + import agentplatform + from agentplatform import rag + + agentplatform.init(project="my-project") + + rag_corpus = rag.create_corpus( + display_name="my-corpus-1", + ) + ``` + + Args: + display_name: If not provided, SDK will create one. The display name of + the RagCorpus. The name can be up to 128 characters long and can consist + of any UTF-8 characters. + description: The description of the RagCorpus. + vertex_ai_search_config: The Vertex AI Search config of the RagCorpus. + Note: backend_config cannot be set if vertex_ai_search_config is + specified. + backend_config: The backend config of the RagCorpus, specifying a data + store and/or embedding model. + encryption_spec: The encryption spec of the RagCorpus. + timeout: Default is 600 seconds. + + Returns: + RagCorpus. + Raises: + RuntimeError: Failed in RagCorpus creation due to exception. + RuntimeError: Failed in RagCorpus creation due to operation error. + """ + if vertex_ai_search_config and backend_config: + raise ValueError( + "Only one of vertex_ai_search_config or backend_config can be set." 
+ ) + + if not display_name: + display_name = "vertex-" + utils.timestamped_unique_name() + parent = initializer.global_config.common_location_path(project=None, location=None) + + rag_corpus = GapicRagCorpus(display_name=display_name, description=description) + + if backend_config: + _gapic_utils.set_backend_config( + backend_config=backend_config, + rag_corpus=rag_corpus, + ) + elif vertex_ai_search_config: + _gapic_utils.set_vertex_ai_search_config( + vertex_ai_search_config=vertex_ai_search_config, + rag_corpus=rag_corpus, + ) + + if encryption_spec: + _gapic_utils.set_encryption_spec( + encryption_spec=encryption_spec, + rag_corpus=rag_corpus, + ) + + request = CreateRagCorpusRequest( + parent=parent, + rag_corpus=rag_corpus, + ) + client = _gapic_utils.create_rag_data_service_client() + + try: + response = client.create_rag_corpus(request=request) + except Exception as e: + raise RuntimeError("Failed in RagCorpus creation due to: ", e) from e + return _gapic_utils.convert_gapic_to_rag_corpus(response.result(timeout=timeout)) + + +def update_corpus( + corpus_name: str, + display_name: Optional[str] = None, + description: Optional[str] = None, + vertex_ai_search_config: Optional[VertexAiSearchConfig] = None, + backend_config: Optional[ + Union[ + RagVectorDbConfig, + None, + ] + ] = None, + timeout: int = 600, +) -> RagCorpus: + """Updates a RagCorpus resource. + + It is intended to update 3rd party vector DBs (Vector Search, Vertex AI + Feature Store, Weaviate, Pinecone) but not Vertex RagManagedDb. + + Example usage: + ``` + import agentplatform + from agentplatform import rag + + agentplatform.init(project="my-project") + + rag_corpus = rag.update_corpus( + corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", + display_name="my-corpus-1", + ) + ``` + + Args: + corpus_name: The name of the RagCorpus resource to update. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` or + ``{rag_corpus}``. 
+ display_name: If not provided, the display name will not be updated. The + display name of the RagCorpus. The name can be up to 128 characters long + and can consist of any UTF-8 characters. + description: The description of the RagCorpus. If not provided, the + description will not be updated. + vertex_ai_search_config: The Vertex AI Search config of the RagCorpus. If + not provided, the Vertex AI Search config will not be updated. + Note: backend_config cannot be set if vertex_ai_search_config is + specified. + backend_config: The backend config of the RagCorpus, specifying a data + store and/or embedding model. + timeout: Default is 600 seconds. + + Returns: + RagCorpus. + Raises: + RuntimeError: Failed in RagCorpus update due to exception. + RuntimeError: Failed in RagCorpus update due to operation error. + """ + if vertex_ai_search_config and backend_config: + raise ValueError( + "Only one of vertex_ai_search_config or backend_config can be set." + ) + + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + if display_name and description: + rag_corpus = GapicRagCorpus( + name=corpus_name, display_name=display_name, description=description + ) + elif display_name: + rag_corpus = GapicRagCorpus(name=corpus_name, display_name=display_name) + elif description: + rag_corpus = GapicRagCorpus(name=corpus_name, description=description) + else: + rag_corpus = GapicRagCorpus(name=corpus_name) + + if backend_config: + _gapic_utils.set_backend_config( + backend_config=backend_config, + rag_corpus=rag_corpus, + ) + + if vertex_ai_search_config: + _gapic_utils.set_vertex_ai_search_config( + vertex_ai_search_config=vertex_ai_search_config, + rag_corpus=rag_corpus, + ) + + request = UpdateRagCorpusRequest( + rag_corpus=rag_corpus, + ) + client = _gapic_utils.create_rag_data_service_client() + + try: + response = client.update_rag_corpus(request=request) + except Exception as e: + raise RuntimeError("Failed in RagCorpus update due to: ", e) from e + return 
_gapic_utils.convert_gapic_to_rag_corpus_no_embedding_model_config( + response.result(timeout=timeout) + ) + + +def get_corpus(name: str) -> RagCorpus: + """ + Get an existing RagCorpus. + + Args: + name: An existing RagCorpus resource name. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. + Returns: + RagCorpus. + """ + corpus_name = _gapic_utils.get_corpus_name(name) + request = GetRagCorpusRequest(name=corpus_name) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.get_rag_corpus(request=request) + except Exception as e: + raise RuntimeError("Failed in getting the RagCorpus due to: ", e) from e + return _gapic_utils.convert_gapic_to_rag_corpus(response) + + +def list_corpora( + page_size: Optional[int] = None, page_token: Optional[str] = None +) -> ListRagCorporaPager: + """ + List all RagCorpora in the same project and location. + + Example usage: + ``` + import agentplatform + from agentplatform import rag + + agentplatform.init(project="my-project") + + # List all corpora. + rag_corpora = list(rag.list_corpora()) + + # Alternatively, return a ListRagCorporaPager. + pager_1 = rag.list_corpora(page_size=10) + # Then get the next page, use the generated next_page_token from the last pager. + pager_2 = rag.list_corpora(page_size=10, page_token=pager_1.next_page_token) + + ``` + Args: + page_size: The standard list page size. Leaving out the page_size + causes all of the results to be returned. + page_token: The standard list page token. + + Returns: + ListRagCorporaPager. 
+ """ + parent = initializer.global_config.common_location_path(project=None, location=None) + request = ListRagCorporaRequest( + parent=parent, + page_size=page_size, + page_token=page_token, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + pager = client.list_rag_corpora(request=request) + except Exception as e: + raise RuntimeError("Failed in listing the RagCorpora due to: ", e) from e + + return pager + + +def delete_corpus(name: str) -> None: + """ + Delete an existing RagCorpus. + + Args: + name: An existing RagCorpus resource name. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. + """ + corpus_name = _gapic_utils.get_corpus_name(name) + request = DeleteRagCorpusRequest(name=corpus_name) + + client = _gapic_utils.create_rag_data_service_client() + try: + client.delete_rag_corpus(request=request) + print("Successfully deleted the RagCorpus.") + except Exception as e: + raise RuntimeError("Failed in RagCorpus deletion due to: ", e) from e + return None + + +def upload_file( + corpus_name: str, + path: Union[str, Sequence[str]], + display_name: Optional[str] = None, + description: Optional[str] = None, + transformation_config: Optional[TransformationConfig] = None, + timeout: int = 600, +) -> RagFile: + """ + Synchronous file upload to an existing RagCorpus. + + Example usage: + + ``` + import agentplatform + from agentplatform import rag + + agentplatform.init(project="my-project") + + // Optional. + transformation_config = TransformationConfig( + chunking_config=ChunkingConfig( + chunk_size=1024, + chunk_overlap=200, + ), + ) + + rag_file = rag.upload_file( + corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", + display_name="my_file.txt", + path="usr/home/my_file.txt", + transformation_config=transformation_config, + ) + ``` + + Args: + corpus_name: The name of the RagCorpus resource into which to upload the file. 
+ Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. + path: A local file path. For example, + "usr/home/my_file.txt". + display_name: The display name of the data file. + description: The description of the RagFile. + transformation_config: The config for transforming the RagFile, like chunking. + timeout: Default is 600 seconds. + + Returns: + RagFile. + Raises: + RuntimeError: Failed in RagFile upload. + ValueError: RagCorpus is not found. + RuntimeError: Failed in indexing the RagFile. + """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + location = initializer.global_config.location + # GAPIC doesn't expose a path (scotty). Use requests API instead + if display_name is None: + display_name = "vertex-" + utils.timestamped_unique_name() + headers = {"X-Goog-Upload-Protocol": "multipart"} + if not initializer.global_config.api_endpoint: + request_endpoint = "{}-{}".format( + location, aiplatform.constants.base.API_BASE_PATH + ) + else: + request_endpoint = initializer.global_config.api_endpoint + upload_request_uri = "https://{}/upload/v1/{}/ragFiles:upload".format( + request_endpoint, + corpus_name, + ) + js_rag_file = {"rag_file": {"display_name": display_name}} + + if description: + js_rag_file["rag_file"]["description"] = description + + if transformation_config and transformation_config.chunking_config: + chunk_size = transformation_config.chunking_config.chunk_size + chunk_overlap = transformation_config.chunking_config.chunk_overlap + js_rag_file["upload_rag_file_config"] = { + "rag_file_transformation_config": { + "rag_file_chunking_config": { + "fixed_length_chunking": { + "chunk_size": chunk_size, + "chunk_overlap": chunk_overlap, + } + } + } + } + + files = { + "metadata": (None, str(js_rag_file)), + "file": open(path, "rb"), + } + credentials = initializer.global_config.credentials + if not credentials: + credentials, _ = auth.default() + authorized_session = 
google_auth_requests.AuthorizedSession(credentials=credentials) + try: + response = authorized_session.post( + url=upload_request_uri, + files=files, + headers=headers, + timeout=timeout, + ) + except Exception as e: + raise RuntimeError("Failed in uploading the RagFile due to: ", e) from e + + if response.status_code == 404: + raise ValueError( + "RagCorpus '%s' is not found: %s", corpus_name, upload_request_uri + ) + if response.json().get("error"): + raise RuntimeError( + "Failed in indexing the RagFile due to: ", response.json().get("error") + ) + return _gapic_utils.convert_json_to_rag_file(response.json()) + + +def import_files( + corpus_name: str, + paths: Optional[Sequence[str]] = None, + source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None, + transformation_config: Optional[TransformationConfig] = None, + timeout: int = 600, + max_embedding_requests_per_min: int = 1000, + import_result_sink: Optional[str] = None, + layout_parser: Optional[LayoutParserConfig] = None, + llm_parser: Optional[LlmParserConfig] = None, +) -> ImportRagFilesResponse: + """ + Import files to an existing RagCorpus, wait until completion. + + Example usage: + + ``` + import agentplatform + from agentplatform import rag + from google.protobuf import timestamp_pb2 + + agentplatform.init(project="my-project") + # Google Drive example + paths = [ + "https://drive.google.com/file/d/123", + "https://drive.google.com/drive/folders/456" + ] + # Google Cloud Storage example + paths = ["gs://my_bucket/my_files_dir", ...] 
+ + transformation_config = TransformationConfig( + chunking_config=ChunkingConfig( + chunk_size=1024, + chunk_overlap=200, + ), + ) + + response = rag.import_files( + corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", + paths=paths, + transformation_config=transformation_config, + ) + + # Slack example + start_time = timestamp_pb2.Timestamp() + start_time.FromJsonString('2020-12-31T21:33:44Z') + end_time = timestamp_pb2.Timestamp() + end_time.GetCurrentTime() + source = rag.SlackChannelsSource( + channels = [ + SlackChannel("channel1", "api_key1"), + SlackChannel("channel2", "api_key2", start_time, end_time) + ], + ) + # Jira Example + jira_query = rag.JiraQuery( + email="xxx@yyy.com", + jira_projects=["project1", "project2"], + custom_queries=["query1", "query2"], + api_key="api_key", + server_uri="server.atlassian.net" + ) + source = rag.JiraSource( + queries=[jira_query], + ) + + response = rag.import_files( + corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", + source=source, + transformation_config=transformation_config, + ) + + # SharePoint Example. + sharepoint_query = rag.SharePointSource( + sharepoint_folder_path="https://my-sharepoint-site.com/my-folder", + sharepoint_site_name="my-sharepoint-site.com", + client_id="my-client-id", + client_secret="my-client-secret", + tenant_id="my-tenant-id", + drive_id="my-drive-id", + ) + source = rag.SharePointSources( + share_point_sources=[sharepoint_query], + ) + + # Return the number of imported RagFiles after completion. + print(response.imported_rag_files_count) + + # Document AI Layout Parser example. 
+        parser = LayoutParserConfig(
+            processor_name="projects/my-project/locations/us-central1/processors/my-processor-id",
+            max_parsing_requests_per_min=120,
+        )
+        response = rag.import_files(
+            corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1",
+            paths=paths,
+            layout_parser=parser,
+        )
+
+    ```
+    Args:
+        corpus_name: The name of the RagCorpus resource into which to import files.
+            Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
+            or ``{rag_corpus}``.
+        paths: A list of uris. Eligible uris will be Google Cloud Storage
+            directory ("gs://my-bucket/my_dir") or a Google Drive url for file
+            (https://drive.google.com/file/... or folder
+            "https://drive.google.com/corp/drive/folders/...").
+        source: The source of the Slack or Jira import.
+            Must be either a SlackChannelsSource or JiraSource.
+        transformation_config: The config for transforming the imported
+            RagFiles.
+        max_embedding_requests_per_min:
+            Optional. The max number of queries per
+            minute that this job is allowed to make to the
+            embedding model specified on the corpus. This
+            value is specific to this job and not shared
+            across other import jobs. Consult the Quotas
+            page on the project to set an appropriate value
+            here. If unspecified, a default value of 1,000
+            QPM would be used.
+        timeout: Default is 600 seconds.
+        import_result_sink: Either a GCS path to store import results or a
+            BigQuery table to store import results. The format is
+            "gs://my-bucket/my/object.ndjson" for GCS or
+            "bq://my-project.my-dataset.my-table" for BigQuery. An existing GCS
+            object cannot be used. However, the BigQuery table may or may not
+            exist - if it does not exist, it will be created. If it does exist,
+            the schema will be checked and the import results will be appended
+            to the table.
+        layout_parser: Document parser to use. Should be either None (default
+            parser) or a LayoutParserConfig (to parse documents using a
+            Document AI Layout Parser processor).
+        llm_parser: LLM parser to use. Should be either None (default parser)
+            or an LlmParserConfig (to parse documents using an LLM). Only one
+            of layout_parser or llm_parser may be set.
+ Returns: + ImportRagFilesResponse. + """ + if source is not None and paths is not None: + raise ValueError("Only one of source or paths must be passed in at a time") + if source is None and paths is None: + raise ValueError("One of source or paths must be passed in") + if layout_parser is not None and llm_parser is not None: + raise ValueError( + "Only one of layout_parser or llm_parser may be passed in at a time" + ) + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + request = _gapic_utils.prepare_import_files_request( + corpus_name=corpus_name, + paths=paths, + source=source, + transformation_config=transformation_config, + max_embedding_requests_per_min=max_embedding_requests_per_min, + import_result_sink=import_result_sink, + layout_parser=layout_parser, + llm_parser=llm_parser, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.import_rag_files(request=request) + except Exception as e: + raise RuntimeError("Failed in importing the RagFiles due to: ", e) from e + + return response.result(timeout=timeout) + + +async def import_files_async( + corpus_name: str, + paths: Optional[Sequence[str]] = None, + source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None, + transformation_config: Optional[TransformationConfig] = None, + max_embedding_requests_per_min: int = 1000, + import_result_sink: Optional[str] = None, + layout_parser: Optional[LayoutParserConfig] = None, + llm_parser: Optional[LlmParserConfig] = None, +) -> operation_async.AsyncOperation: + """ + Import files to an existing RagCorpus asynchronously. + + Example usage: + + ``` + import agentplatform + from agentplatform import rag + from google.protobuf import timestamp_pb2 + + agentplatform.init(project="my-project") + + # Google Drive example + paths = [ + "https://drive.google.com/file/d/123", + "https://drive.google.com/drive/folders/456" + ] + # Google Cloud Storage example + paths = ["gs://my_bucket/my_files_dir", ...] 
+ + transformation_config = TransformationConfig( + chunking_config=ChunkingConfig( + chunk_size=1024, + chunk_overlap=200, + ), + ) + + response = await rag.import_files_async( + corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", + paths=paths, + transformation_config=transformation_config, + ) + + # Slack example + start_time = timestamp_pb2.Timestamp() + start_time.FromJsonString('2020-12-31T21:33:44Z') + end_time = timestamp_pb2.Timestamp() + end_time.GetCurrentTime() + source = rag.SlackChannelsSource( + channels = [ + SlackChannel("channel1", "api_key1"), + SlackChannel("channel2", "api_key2", start_time, end_time) + ], + ) + # Jira Example + jira_query = rag.JiraQuery( + email="xxx@yyy.com", + jira_projects=["project1", "project2"], + custom_queries=["query1", "query2"], + api_key="api_key", + server_uri="server.atlassian.net" + ) + source = rag.JiraSource( + queries=[jira_query], + ) + + response = await rag.import_files_async( + corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", + source=source, + transformation_config=transformation_config, + ) + + # SharePoint Example. + sharepoint_query = rag.SharePointSource( + sharepoint_folder_path="https://my-sharepoint-site.com/my-folder", + sharepoint_site_name="my-sharepoint-site.com", + client_id="my-client-id", + client_secret="my-client-secret", + tenant_id="my-tenant-id", + drive_id="my-drive-id", + ) + source = rag.SharePointSources( + share_point_sources=[sharepoint_query], + ) + + # Document AI Layout Parser example. + parser = LayoutParserConfig( + processor_name="projects/my-project/locations/us-central1/processors/my-processor-id", + max_parsing_requests_per_min=120, + ) + response = rag.import_files_async( + corpus_name="projects/my-project/locations/us-central1/ragCorpora/my-corpus-1", + paths=paths, + parser=parser, + ) + + # Get the result. 
+    await response.result()
+
+    ```
+    Args:
+        corpus_name: The name of the RagCorpus resource into which to import files.
+            Format: ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}``
+            or ``{rag_corpus}``.
+        paths: A list of uris. Eligible uris will be Google Cloud Storage
+            directory ("gs://my-bucket/my_dir") or a Google Drive url for file
+            (https://drive.google.com/file/... or folder
+            "https://drive.google.com/corp/drive/folders/...").
+        source: The source of the Slack, Jira, or SharePoint import. Must be
+            one of SlackChannelsSource, JiraSource, or SharePointSources.
+        transformation_config: The config for transforming the imported
+            RagFiles.
+        max_embedding_requests_per_min:
+            Optional. The max number of queries per
+            minute that this job is allowed to make to the
+            embedding model specified on the corpus. This
+            value is specific to this job and not shared
+            across other import jobs. Consult the Quotas
+            page on the project to set an appropriate value
+            here. If unspecified, a default value of 1,000
+            QPM would be used.
+        import_result_sink: Either a GCS path to store import results or a
+            BigQuery table to store import results. The format is
+            "gs://my-bucket/my/object.ndjson" for GCS or
+            "bq://my-project.my-dataset.my-table" for BigQuery. An existing GCS
+            object cannot be used. However, the BigQuery table may or may not
+            exist - if it does not exist, it will be created. If it does exist,
+            the schema will be checked and the import results will be appended
+            to the table.
+        layout_parser: Configuration for parsing documents with a Document AI
+            Layout Parser processor. Mutually exclusive with llm_parser; if
+            neither is set, the default parser is used.
+        llm_parser: Configuration for parsing documents with an LLM parser.
+            Mutually exclusive with layout_parser.
+    Returns:
+        operation_async.AsyncOperation.
+ """ + if source is not None and paths is not None: + raise ValueError("Only one of source or paths must be passed in at a time") + if source is None and paths is None: + raise ValueError("One of source or paths must be passed in") + if layout_parser is not None and llm_parser is not None: + raise ValueError( + "Only one of layout_parser or llm_parser may be passed in at a time" + ) + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + request = _gapic_utils.prepare_import_files_request( + corpus_name=corpus_name, + paths=paths, + source=source, + transformation_config=transformation_config, + max_embedding_requests_per_min=max_embedding_requests_per_min, + import_result_sink=import_result_sink, + layout_parser=layout_parser, + llm_parser=llm_parser, + ) + async_client = _gapic_utils.create_rag_data_service_async_client() + try: + response = await async_client.import_rag_files(request=request) + except Exception as e: + raise RuntimeError("Failed in importing the RagFiles due to: ", e) from e + return response + + +def get_file(name: str, corpus_name: Optional[str] = None) -> RagFile: + """ + Get an existing RagFile. + + Args: + name: Either a full RagFile resource name must be provided, or a RagCorpus + name and a RagFile name must be provided. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`` + or ``{rag_file}``. + corpus_name: If `name` is not a full resource name, an existing RagCorpus + name must be provided. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. + Returns: + RagFile. 
+ """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + name = _gapic_utils.get_file_name(name, corpus_name) + request = GetRagFileRequest(name=name) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.get_rag_file(request=request) + except Exception as e: + raise RuntimeError("Failed in getting the RagFile due to: ", e) from e + return _gapic_utils.convert_gapic_to_rag_file(response) + + +def list_files( + corpus_name: str, page_size: Optional[int] = None, page_token: Optional[str] = None +) -> ListRagFilesPager: + """ + List all RagFiles in an existing RagCorpus. + + Example usage: + ``` + import agentplatform + + agentplatform.init(project="my-project") + # List all corpora. + rag_corpora = list(rag.list_corpora()) + + # List all files of the first corpus. + rag_files = list(rag.list_files(corpus_name=rag_corpora[0].name)) + + # Alternatively, return a ListRagFilesPager. + pager_1 = rag.list_files( + corpus_name=rag_corpora[0].name, + page_size=10 + ) + # Then get the next page, use the generated next_page_token from the last pager. + pager_2 = rag.list_files( + corpus_name=rag_corpora[0].name, + page_size=10, + page_token=pager_1.next_page_token + ) + + ``` + + Args: + corpus_name: An existing RagCorpus name. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. + page_size: The standard list page size. Leaving out the page_size + causes all of the results to be returned. + page_token: The standard list page token. + Returns: + ListRagFilesPager. 
+ """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + request = ListRagFilesRequest( + parent=corpus_name, + page_size=page_size, + page_token=page_token, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + pager = client.list_rag_files(request=request) + except Exception as e: + raise RuntimeError("Failed in listing the RagFiles due to: ", e) from e + + return pager + + +def delete_file(name: str, corpus_name: Optional[str] = None) -> None: + """ + Delete RagFile from an existing RagCorpus. + + Args: + name: Either a full RagFile resource name must be provided, or a RagCorpus + name and a RagFile name must be provided. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`` + or ``{rag_file}``. + corpus_name: If `name` is not a full resource name, an existing RagCorpus + name must be provided. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`` + or ``{rag_corpus}``. + """ + corpus_name = _gapic_utils.get_corpus_name(corpus_name) + name = _gapic_utils.get_file_name(name, corpus_name) + request = DeleteRagFileRequest(name=name) + + client = _gapic_utils.create_rag_data_service_client() + try: + client.delete_rag_file(request=request) + print("Successfully deleted the RagFile.") + except Exception as e: + raise RuntimeError("Failed in RagFile deletion due to: ", e) from e + return None + + +def update_rag_engine_config( + rag_engine_config: RagEngineConfig, + timeout: int = 600, +) -> RagEngineConfig: + """Update RagEngineConfig. + + Example usage: + ``` + import agentplatform + from agentplatform import rag + agentplatform.init(project="my-project") + rag_engine_config = rag.RagEngineConfig( + rag_managed_db_config=rag.RagManagedDbConfig( + rag_managed_db=rag.RagManagedDb( + db_basic_tier=rag.Basic(), + ), + ) + ), + ) + rag.update_rag_engine_config(rag_engine_config=rag_engine_config) + ``` + + Args: + rag_engine_config: The RagEngineConfig to update. 
+ timeout: Default is 600 seconds. + + Raises: + RuntimeError: Failed in RagEngineConfig update due to exception. + """ + gapic_rag_engine_config = _gapic_utils.convert_rag_engine_config_to_gapic( + rag_engine_config + ) + request = UpdateRagEngineConfigRequest( + rag_engine_config=gapic_rag_engine_config, + ) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.update_rag_engine_config(request=request) + except Exception as e: + raise RuntimeError("Failed in RagEngineConfig update due to: ", e) from e + return _gapic_utils.convert_gapic_to_rag_engine_config( + response.result(timeout=timeout) + ) + + +def get_rag_engine_config(name: str) -> RagEngineConfig: + """Get an existing RagEngineConfig. + + Example usage: + ``` + import agentplatform + from agentplatform import rag + agentplatform.init(project="my-project") + rag_engine_config = rag.get_rag_engine_config( + name="projects/my-project/locations/us-central1/ragEngineConfig" + ) + ``` + Args: + name: The RagEngineConfig resource name pattern of the singleton resource. + + Returns: + RagEngineConfig. + Raises: + RuntimeError: Failed in getting the RagEngineConfig. + """ + request = GetRagEngineConfigRequest(name=name) + client = _gapic_utils.create_rag_data_service_client() + try: + response = client.get_rag_engine_config(request=request) + except Exception as e: + raise RuntimeError("Failed in getting the RagEngineConfig due to: ", e) from e + return _gapic_utils.convert_gapic_to_rag_engine_config(response) + + +def add_inline_citations_and_references( + original_text_str, grounding_supports, grounding_chunks +) -> RagCitedGenerationResponse: + """Adds inline citations to a text string based on grounding_supports and grounding_chunks. + + Args: + original_text_str (str): The text (as a Unicode string) to which citations + will be added. 
+ grounding_supports (list): A list of objects, where each object represents + a grounding support and has attributes: - segment: An object with + 'end_index' (byte offset relative to UTF-8). - grounding_chunk_indices: + A list of integers. + grounding_chunks (list): A list of objects, where each object is a source + chunk wrapper. To get URI: obj.retrieved_context.uri. To get page + span: obj.retrieved_context.rag_chunk.page_span. + + Returns: + RagCitedGenerationResponse: An object containing the text with inline + citations and a formatted bibliography string. + + Raises: + TypeError: If original_text_str is not a string, or if grounding_supports, + grounding_chunks, or internally generated reference dictionaries + are of an unexpected type (raised by this function or helpers). + ValueError: If original_text_str has encoding/decoding issues, + if segment data in grounding_supports is invalid, + if calculated insertion indices are out of bounds, + or if essential data within grounding_chunks (like the chunks + themselves, retrieved_context, or URI) is None when a value + is expected (raised by this function or helpers). + IndexError: If chunk indices used for citation or bibliography generation + are out of bounds for the provided grounding_chunks (raised by helpers). + AttributeError: If expected attributes (e.g., 'retrieved_context', 'uri') + are missing from the data structures within grounding_chunks + (raised by helpers). + Example usage: + ``` + import agentplatform + + agentplatform.init(project="my-project") + + rag_retrieval_tool = Tool.from_retrieval( + retrieval=rag.Retrieval( + source=rag.VertexRagStore( + rag_resources=[ + rag.RagResource( + rag_corpus=corpus_name, + # Optional: supply IDs from `rag.list_files()`. 
+ # rag_file_ids=["rag-file-1", "rag-file-2", ...], + ) + ], + rag_retrieval_config=rag.RagRetrievalConfig( + top_k=10, + filter=rag.utils.resources.Filter(vector_distance_threshold=0.5), + ), + ), + ) + ) + + rag_model = GenerativeModel( + model_name="gemini-2.5-pro-preview-05-06", tools=[rag_retrieval_tool] + ) + response = rag_model.generate_content("Why is the sky blue?") + + rag_cited_generation_response = rag.add_inline_citations_and_references( + original_text_str=response.candidates[0].content.parts[0].text, + grounding_supports=list(response.candidates[0].grounding_metadata.grounding_supports), + grounding_chunks=list(response.candidates[0].grounding_metadata.grounding_chunks), + ) + print(rag_cited_generation_response.cited_text) + print(rag_cited_generation_response.final_bibliography) + + """ + if not isinstance(original_text_str, str): + raise TypeError("original_text_str must be a string.") + if not isinstance(grounding_supports, list): + raise TypeError("grounding_supports must be a list.") + if not isinstance(grounding_chunks, list): + raise TypeError("grounding_chunks must be a list.") + + cited_chunk_references = {} + + populate_cited_chunk_references( + grounding_supports, grounding_chunks, cited_chunk_references + ) + + # If there are no grounding supports, return original text and empty bibliography. + if not grounding_supports: + return _gapic_utils.convert_tuple_to_rag_cited_generation_response( + original_text_str, "" + ) + + try: + original_text_bytes_equivalent = original_text_str.encode("utf-8") + except UnicodeEncodeError as e: + raise ValueError( + "Could not encode original_text_str to UTF-8 for index" + f" conversion: {e}. Cannot process citations." 
+ ) from e + + insertions = [] + + for support_idx, support in enumerate(grounding_supports): + current_support_chunk_indices = [] + if ( + hasattr(support, "grounding_chunk_indices") + and support.grounding_chunk_indices is not None + ): + # Ensure indices are integers and unique for the citation string + valid_indices = [ + idx for idx in support.grounding_chunk_indices if isinstance(idx, int) + ] + current_support_chunk_indices = sorted(list(set(valid_indices))) + + # Validate segment and end_index + if not ( + hasattr(support, "segment") + and support.segment + and hasattr(support.segment, "end_index") + and isinstance(support.segment.end_index, int) + and support.segment.end_index >= 0 # Ensure end_index is not negative + ): + raise ValueError(f"Invalid segment data for support at index {support_idx}") + + byte_end_idx = support.segment.end_index + + if byte_end_idx > len(original_text_bytes_equivalent): + byte_end_idx = len(original_text_bytes_equivalent) + + char_end_idx = -1 # Initialize before try block + try: + # Slice the byte string up to the (potentially clamped) byte_end_idx + prefix_bytes = original_text_bytes_equivalent[:byte_end_idx] + # Decode this prefix to find the corresponding character length + char_end_idx = len(prefix_bytes.decode("utf-8")) + except UnicodeDecodeError as e: + # If decoding the prefix fails, we cannot reliably determine char_end_idx. + raise ValueError( + "Could not decode prefix of original_text_str (up to byte" + f" {byte_end_idx}) from UTF-8 for support at index {support_idx}:" + f" {e}. Cannot accurately place citation." + ) from e + except IndexError as e: # Should be less common with slice clamping + raise ValueError( + "Index error during byte-to-char conversion for support at index" + f" {support_idx}: {e}. This might indicate an issue with byte_end_idx" + " calculation." 
+ ) from e + + if char_end_idx > len(original_text_str): + char_end_idx = len(original_text_str) + + citation_str = "".join([f"[{idx}]" for idx in current_support_chunk_indices]) + + if citation_str: + insertions.append( + { + "char_index": char_end_idx, + "citation": citation_str, + "original_byte_idx": byte_end_idx, + } + ) + + # Sort insertions by character index in reverse order to insert from end to start, + # preserving earlier indices. + insertions.sort( + key=lambda x: (x["char_index"], x["original_byte_idx"]), reverse=True + ) + + modified_text_list = list(original_text_str) + for insertion_info in insertions: + idx = insertion_info["char_index"] + citation = insertion_info["citation"] + if 0 <= idx <= len(modified_text_list): + modified_text_list.insert(idx, citation) + else: + raise ValueError( + f"Calculated insertion index {idx} is out of bounds for" + f" modified_text_list (length {len(modified_text_list)}). Citation:" + f" '{citation}'" + ) + + cited_text_output = "".join(modified_text_list) + references_output_string = format_bibliography( + cited_chunk_references, grounding_chunks + ) + return _gapic_utils.convert_tuple_to_rag_cited_generation_response( + cited_text_output, references_output_string + ) diff --git a/agentplatform/rag/rag_inline_citations.py b/agentplatform/rag/rag_inline_citations.py new file mode 100644 index 0000000000..66bc928531 --- /dev/null +++ b/agentplatform/rag/rag_inline_citations.py @@ -0,0 +1,176 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Helper functions for processing and formatting citations from RAG generation outputs.""" + + +def populate_cited_chunk_references( + grounding_supports, + grounding_chunks, + cited_refs_dict, +) -> None: + """Populates cited_refs_dict with URI information for all unique chunk indices found in grounding_supports. + + Args: + grounding_supports: A list of support items, where each item might contain + grounding chunk indices. + grounding_chunks: A list of all available chunk items from which to retrieve + context and URI. + cited_refs_dict: A dictionary to populate with chunk_idx as key and URI + as value. + + Raises: + TypeError: If grounding_chunks is not a list, or cited_refs_dict is not a + dictionary. + ValueError: If grounding_chunks or cited_refs_dict is None. + If a chunk_item at a valid index is None. + If 'retrieved_context' or 'uri' attribute of a chunk is None. + IndexError: If a chunk_idx is out of bounds for grounding_chunks. + AttributeError: If 'retrieved_context' or 'uri' attribute is missing from + a chunk or its context. 
+ """ + if grounding_chunks is None or not grounding_chunks: + raise ValueError("grounding_chunks cannot be None or empty.") + if grounding_supports is None or not grounding_supports: + raise ValueError("grounding_supports cannot be None or empty.") + if not isinstance(grounding_chunks, list): + raise TypeError("grounding_chunks must be a list.") + if not isinstance(grounding_supports, list): + raise TypeError("grounding_supports must be a list.") + if cited_refs_dict is None: + raise ValueError("cited_refs_dict cannot be None.") + if not isinstance(cited_refs_dict, dict): + raise TypeError("cited_refs_dict must be a dictionary.") + + for support in grounding_supports: + current_support_chunk_indices = [] + if ( + hasattr(support, "grounding_chunk_indices") + and support.grounding_chunk_indices is not None + ): + valid_indices = [ + idx for idx in support.grounding_chunk_indices if isinstance(idx, int) + ] + current_support_chunk_indices = sorted(list(set(valid_indices))) + + for chunk_idx in current_support_chunk_indices: + if chunk_idx not in cited_refs_dict: + if not (0 <= chunk_idx < len(grounding_chunks)): + raise IndexError( + f"Chunk index {chunk_idx} is out of bounds for grounding_chunks of size {len(grounding_chunks)}." + ) + chunk_item = grounding_chunks[chunk_idx] + if chunk_item is None: + raise ValueError(f"Chunk item at index {chunk_idx} is None.") + if not hasattr(chunk_item, "retrieved_context"): + raise AttributeError( + f"Chunk item at index {chunk_idx} is missing 'retrieved_context' attribute." + ) + retrieved_context_obj = chunk_item.retrieved_context + if retrieved_context_obj is None: + raise ValueError( + f"Attribute 'retrieved_context' for chunk {chunk_idx} is None." + ) + if not hasattr(retrieved_context_obj, "uri"): + raise AttributeError( + f"retrieved_context for chunk {chunk_idx} is missing 'uri' attribute." 
+ ) + uri = retrieved_context_obj.uri + if uri is None: + raise ValueError(f"Attribute 'uri' for chunk {chunk_idx} is None.") + cited_refs_dict[chunk_idx] = uri + + +def format_bibliography(cited_refs_dict, grounding_chunks) -> str: + """Formats the bibliography string from the populated cited_refs_dict. + + Omits page information if page numbers are not valid (e.g., not >= 1). + + Args: + cited_refs_dict: A dictionary with chunk_idx as key and URI as value. + It's expected that populate_cited_chunk_references has successfully + populated this dict. + grounding_chunks: A list of all available chunk items, used to retrieve page + span information. + + Returns: + A string representing the formatted bibliography, with each reference + on a new line. + + Raises: + TypeError: If cited_refs_dict is not a dictionary or grounding_chunks is not a list. + ValueError: If cited_refs_dict or grounding_chunks is None. + If a chunk_item in grounding_chunks referenced by cited_refs_dict is None. + IndexError: If a chunk_idx from cited_refs_dict is out of bounds for + grounding_chunks. + """ + if cited_refs_dict is None: + raise ValueError("cited_refs_dict cannot be None.") + if not isinstance(cited_refs_dict, dict): + raise TypeError("cited_refs_dict must be a dictionary.") + if grounding_chunks is None: + raise ValueError("grounding_chunks cannot be None.") + if not isinstance(grounding_chunks, list): + raise TypeError("grounding_chunks must be a list.") + + reference_lines = [] + for chunk_idx_ref in sorted(list(cited_refs_dict.keys())): + uri = cited_refs_dict[chunk_idx_ref] + page_info_str = "" + if not ( + isinstance(chunk_idx_ref, int) + and 0 <= chunk_idx_ref < len(grounding_chunks) + ): + raise IndexError( + f"Chunk index {chunk_idx_ref} from cited_refs_dict is invalid or out of bounds " + f"for grounding_chunks of size {len(grounding_chunks)}." 
+ ) + chunk_item = grounding_chunks[chunk_idx_ref] + if chunk_item is None: + raise ValueError( + f"Chunk item at index {chunk_idx_ref} in grounding_chunks is None, " + "but was referenced in cited_refs_dict." + ) + page_span_data = None + if ( + hasattr(chunk_item, "retrieved_context") + and chunk_item.retrieved_context + and hasattr(chunk_item.retrieved_context, "rag_chunk") + and chunk_item.retrieved_context.rag_chunk + and hasattr(chunk_item.retrieved_context.rag_chunk, "page_span") + and chunk_item.retrieved_context.rag_chunk.page_span + ): + page_span_data = chunk_item.retrieved_context.rag_chunk.page_span + if ( + page_span_data + and hasattr(page_span_data, "first_page") + and hasattr(page_span_data, "last_page") + ): + first_page_val = page_span_data.first_page + last_page_val = page_span_data.last_page + is_first_page_valid_num = ( + isinstance(first_page_val, int) and first_page_val >= 1 + ) + is_last_page_valid_num = ( + isinstance(last_page_val, int) and last_page_val >= 1 + ) + if is_first_page_valid_num and is_last_page_valid_num: + if last_page_val >= first_page_val: + page_info_str = ( + f", p.{first_page_val}-{last_page_val}" + if first_page_val != last_page_val + else f", p.{first_page_val}" + ) + reference_lines.append(f"[{chunk_idx_ref}] {uri}{page_info_str}") + return "\n".join(reference_lines) diff --git a/agentplatform/rag/rag_retrieval.py b/agentplatform/rag/rag_retrieval.py new file mode 100644 index 0000000000..70ead25de4 --- /dev/null +++ b/agentplatform/rag/rag_retrieval.py @@ -0,0 +1,505 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Retrieval query to get relevant contexts.""" + +import re +from typing import List, Optional + +from google.cloud import aiplatform_v1 +from google.cloud.aiplatform import initializer +from agentplatform.rag.utils import _gapic_utils +from agentplatform.rag.utils import resources + +from google.protobuf import any_pb2 + + +def retrieval_query( + text: str, + parent_override: Optional[str] = None, + api_path_override: Optional[str] = None, + rag_resources: Optional[List[resources.RagResource]] = None, + rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None, +) -> aiplatform_v1.RetrieveContextsResponse: + """Retrieve top k relevant docs/chunks. + + Example usage: + ``` + import agentplatform + + agentplatform.init(project="my-project") + + config = agentplatform.rag.RagRetrievalConfig( + top_k=2, + filter=agentplatform.rag.Filter( + vector_distance_threshold=0.5 + ), + ranking=agentplatform.rag.Ranking( + llm_ranker=rag.LlmRanker( + model_name="gemini-2.5-flash" + ) + ) + ) + + results = agentplatform.rag.retrieval_query( + text="Why is the sky blue?", + rag_resources=[agentplatform.rag.RagResource( + rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1", + rag_file_ids=["rag-file-1", "rag-file-2", ...], + )], + rag_retrieval_config=config, + ) + ``` + + Args: + text: The query in text format to get relevant contexts. + parent_override: Optional. The resource path of the parent. + api_path_override: Optional. The base API endpoint to use for the request. + rag_resources: A list of RagResource. 
It can be used to specify corpus
+            only or ragfiles. Currently only support one corpus or multiple files
+            from one corpus. In the future we may open up multiple corpora support.
+        rag_retrieval_config: Optional. The config containing the retrieval
+            parameters, including top_k and vector_distance_threshold.
+
+    Returns:
+        RetrieveContextsResponse.
+    """
+    parent = initializer.global_config.common_location_path()
+    if parent_override:
+        parent = parent_override
+
+    client = _gapic_utils.create_rag_service_client(api_path_override)
+
+    if rag_resources:
+        if len(rag_resources) > 1:
+            raise ValueError("Currently only support 1 RagResource.")
+        name = rag_resources[0].rag_corpus
+    else:
+        raise ValueError("rag_resources must be specified.")
+
+    data_client = _gapic_utils.create_rag_data_service_client(api_path_override)
+    if data_client.parse_rag_corpus_path(name):
+        rag_corpus_name = name
+    elif re.match("^{}$".format(_gapic_utils._VALID_RESOURCE_NAME_REGEX), name):
+        rag_corpus_name = parent + "/ragCorpora/" + name
+    else:
+        raise ValueError(
+            f"Invalid RagCorpus name: {name}. Proper format should be:"
+            " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}"
+        )
+
+    if rag_resources:
+        gapic_rag_resource = (
+            aiplatform_v1.RetrieveContextsRequest.VertexRagStore.RagResource(
+                rag_corpus=rag_corpus_name,
+                rag_file_ids=rag_resources[0].rag_file_ids,
+            )
+        )
+        vertex_rag_store = aiplatform_v1.RetrieveContextsRequest.VertexRagStore(
+            rag_resources=[gapic_rag_resource],
+        )
+    else:
+        vertex_rag_store = aiplatform_v1.RetrieveContextsRequest.VertexRagStore(
+            rag_corpora=[rag_corpus_name],
+        )
+
+    # If rag_retrieval_config is not specified, set it to default values.
+    if not rag_retrieval_config:
+        api_retrieval_config = aiplatform_v1.RagRetrievalConfig()
+    else:
+        # If rag_retrieval_config is specified, check for missing parameters.
+ api_retrieval_config = aiplatform_v1.RagRetrievalConfig() + api_retrieval_config.top_k = rag_retrieval_config.top_k + # Set vector_distance_threshold to config value if specified + if rag_retrieval_config.filter: + # Check if both vector_distance_threshold and vector_similarity_threshold + # are specified. + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_distance_threshold + and rag_retrieval_config.filter.vector_similarity_threshold + ): + raise ValueError( + "Only one of vector_distance_threshold or" + " vector_similarity_threshold can be specified at a time" + " in rag_retrieval_config." + ) + api_retrieval_config.filter.vector_distance_threshold = ( + rag_retrieval_config.filter.vector_distance_threshold + ) + api_retrieval_config.filter.vector_similarity_threshold = ( + rag_retrieval_config.filter.vector_similarity_threshold + ) + if ( + rag_retrieval_config.ranking + and rag_retrieval_config.ranking.rank_service + and rag_retrieval_config.ranking.llm_ranker + ): + raise ValueError("Only one of rank_service and llm_ranker can be set.") + if rag_retrieval_config.ranking and rag_retrieval_config.ranking.rank_service: + api_retrieval_config.ranking.rank_service.model_name = ( + rag_retrieval_config.ranking.rank_service.model_name + ) + elif rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker: + api_retrieval_config.ranking.llm_ranker.model_name = ( + rag_retrieval_config.ranking.llm_ranker.model_name + ) + + query = aiplatform_v1.RagQuery( + text=text, + rag_retrieval_config=api_retrieval_config, + ) + request = aiplatform_v1.RetrieveContextsRequest( + vertex_rag_store=vertex_rag_store, + parent=parent, + query=query, + ) + try: + response = client.retrieve_contexts(request=request) + except Exception as e: + raise RuntimeError("Failed in retrieving contexts due to: ", e) from e + + return response + + +async def async_retrieve_contexts( + text: str, + parent_override: Optional[str] = None, + api_path_override: 
Optional[str] = None, + rag_resources: Optional[List[resources.RagResource]] = None, + rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None, + timeout: int = 600, +) -> aiplatform_v1.RetrieveContextsResponse: + """Retrieve top k relevant docs/chunks asynchronously. + + Example usage: + ``` + import agentplatform + + agentplatform.init(project="my-project") + + config = agentplatform.rag.RagRetrievalConfig( + top_k=2, + ) + + results = await agentplatform.rag.async_retrieve_contexts( + text="Why is the sky blue?", + rag_resources=[agentplatform.rag.RagResource( + rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1", + rag_file_ids=["rag-file-1", "rag-file-2", ...], + )], + rag_retrieval_config=config, + ) + ``` + + Args: + text: Required. The query in text format to get relevant contexts. + parent_override: Optional. The parent resource name to use for the API + request. If not specified, the parent is determined from the global + configuration. + api_path_override: Optional. The API path override to use for the API + request. If not specified, the path is determined from the global + configuration. + rag_resources: Optional. A list of RagResource. It can be used to specify + corpus only or ragfiles. Currently only support one corpus or multiple + files from one corpus. In the future we may open up multiple corpora + support. + rag_retrieval_config: Optional. The config containing the retrieval + parameters, including top_k. + timeout: Optional. The timeout in seconds for the request. + + Returns: + RetrieveContextsResponse. 
+ """ + if parent_override: + parent = parent_override + else: + parent = initializer.global_config.common_location_path() + + client = _gapic_utils.create_rag_service_async_client( + api_path_override=api_path_override + ) + + if not rag_resources: + raise ValueError("rag_resources must be specified.") + + data_client = _gapic_utils.create_rag_data_service_client( + api_path_override=api_path_override + ) + + gapic_rag_resources = [] + for rag_resource in rag_resources: + name = rag_resource.rag_corpus + if data_client.parse_rag_corpus_path(name): + rag_corpus_name = name + elif re.match("^{}$".format(_gapic_utils._VALID_RESOURCE_NAME_REGEX), name): + rag_corpus_name = parent + "/ragCorpora/" + name + else: + raise ValueError( + f"Invalid RagCorpus name: {name}. Proper format should be:" + " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}" + ) + gapic_rag_resources.append( + aiplatform_v1.VertexRagStore.RagResource( + rag_corpus=rag_corpus_name, + rag_file_ids=rag_resource.rag_file_ids, + ) + ) + + vertex_rag_store = aiplatform_v1.VertexRagStore( + rag_resources=gapic_rag_resources, + ) + + # If rag_retrieval_config is not specified, set it to default values. + if not rag_retrieval_config: + api_retrieval_config = aiplatform_v1.RagRetrievalConfig() + else: + # If rag_retrieval_config is specified, check for missing parameters. + api_retrieval_config = aiplatform_v1.RagRetrievalConfig() + api_retrieval_config.top_k = rag_retrieval_config.top_k + # Set vector_distance_threshold to config value if specified + if rag_retrieval_config.filter: + # Check if both vector_distance_threshold and vector_similarity_threshold + # are specified. + if ( + rag_retrieval_config.filter + and rag_retrieval_config.filter.vector_distance_threshold + and rag_retrieval_config.filter.vector_similarity_threshold + ): + raise ValueError( + "Only one of vector_distance_threshold or" + " vector_similarity_threshold can be specified at a time" + " in rag_retrieval_config." 
+ ) + api_retrieval_config.filter.vector_distance_threshold = ( + rag_retrieval_config.filter.vector_distance_threshold + ) + api_retrieval_config.filter.vector_similarity_threshold = ( + rag_retrieval_config.filter.vector_similarity_threshold + ) + if ( + rag_retrieval_config.ranking + and rag_retrieval_config.ranking.rank_service + and rag_retrieval_config.ranking.llm_ranker + ): + raise ValueError("Only one of rank_service and llm_ranker can be set.") + if rag_retrieval_config.ranking and rag_retrieval_config.ranking.rank_service: + api_retrieval_config.ranking.rank_service.model_name = ( + rag_retrieval_config.ranking.rank_service.model_name + ) + elif rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker: + api_retrieval_config.ranking.llm_ranker.model_name = ( + rag_retrieval_config.ranking.llm_ranker.model_name + ) + + query = aiplatform_v1.RagQuery( + text=text, + rag_retrieval_config=api_retrieval_config, + ) + + vertex_rag_store.rag_retrieval_config = api_retrieval_config + + tool = aiplatform_v1.Tool( + retrieval=aiplatform_v1.Retrieval( + vertex_rag_store=vertex_rag_store, + ) + ) + + request = aiplatform_v1.AsyncRetrieveContextsRequest( + parent=parent, + query=query, + tools=[tool], + ) + try: + response_lro = await client.async_retrieve_contexts( + request=request, timeout=timeout + ) + try: + response = await response_lro.result(timeout=timeout) + except Exception as e: + if response_lro.done(): + raw_op = response_lro.operation + if raw_op.WhichOneof("result") == "response": + any_response = raw_op.response + inner_any = any_pb2.Any() + if any_response.Unpack(inner_any): + inner_any.type_url = ( + "type.googleapis.com/google.cloud.aiplatform.v1.RagContexts" + ) + rag_contexts = aiplatform_v1.RagContexts() + if inner_any.Unpack(rag_contexts._pb): + return aiplatform_v1.AsyncRetrieveContextsResponse( + contexts=rag_contexts + ) + raise e + except Exception as e: + raise RuntimeError( + "Failed in retrieving contexts asynchronously 
def ask_contexts(
    text: str,
    parent_override: Optional[str] = None,
    api_path_override: Optional[str] = None,
    rag_resources: Optional[List[resources.RagResource]] = None,
    rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None,
    timeout: int = 600,
) -> aiplatform_v1.AskContextsResponse:
    """Ask questions on top k relevant docs/chunks.

    Example usage:
    ```
    import agentplatform

    agentplatform.init(project="my-project")

    config = agentplatform.rag.RagRetrievalConfig(
        top_k=2,
    )

    results = agentplatform.rag.ask_contexts(
        text="Why is the sky blue?",
        rag_resources=[agentplatform.rag.RagResource(
            rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1",
            rag_file_ids=["rag-file-1", "rag-file-2", ...],
        )],
        rag_retrieval_config=config,
    )
    ```

    Args:
        text: Required. The query in text format to get relevant contexts.
        parent_override: Optional. The parent resource name to use for the API
            request. If not specified, the parent is determined from the global
            configuration.
        api_path_override: Optional. The API path override to use for the API
            request. If not specified, the path is determined from the global
            configuration.
        rag_resources: Optional. A list of RagResource. It can be used to
            specify corpus only or ragfiles. Currently only support one corpus
            or multiple files from one corpus. In the future we may open up
            multiple corpora support.
        rag_retrieval_config: Optional. The config containing the retrieval
            parameters, including top_k.
        timeout: Optional. The timeout in seconds for the request.

    Returns:
        AskContextsResponse.

    Raises:
        ValueError: If ``rag_resources`` is missing, a corpus name is
            malformed, or mutually exclusive retrieval options are both set.
        RuntimeError: If the AskContexts RPC fails.
    """
    parent = parent_override or initializer.global_config.common_location_path()

    client = _gapic_utils.create_rag_service_client(api_path_override=api_path_override)

    if not rag_resources:
        raise ValueError("rag_resources must be specified.")

    # The data-service client is only used here to parse/validate corpus names.
    data_client = _gapic_utils.create_rag_data_service_client(
        api_path_override=api_path_override
    )

    gapic_rag_resources = []
    for rag_resource in rag_resources:
        name = rag_resource.rag_corpus
        if data_client.parse_rag_corpus_path(name):
            rag_corpus_name = name
        elif re.match("^{}$".format(_gapic_utils._VALID_RESOURCE_NAME_REGEX), name):
            # A bare corpus id: qualify it under the resolved parent.
            rag_corpus_name = parent + "/ragCorpora/" + name
        else:
            raise ValueError(
                f"Invalid RagCorpus name: {name}. Proper format should be:"
                " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}"
            )
        gapic_rag_resources.append(
            aiplatform_v1.VertexRagStore.RagResource(
                rag_corpus=rag_corpus_name,
                rag_file_ids=rag_resource.rag_file_ids,
            )
        )

    vertex_rag_store = aiplatform_v1.VertexRagStore(
        rag_resources=gapic_rag_resources,
    )

    # Build the retrieval config; an empty proto is used when the caller gave
    # no config so the service applies its defaults.
    api_retrieval_config = aiplatform_v1.RagRetrievalConfig()
    if rag_retrieval_config:
        api_retrieval_config.top_k = rag_retrieval_config.top_k
        if rag_retrieval_config.filter:
            # vector_distance_threshold and vector_similarity_threshold are
            # mutually exclusive.
            if (
                rag_retrieval_config.filter.vector_distance_threshold
                and rag_retrieval_config.filter.vector_similarity_threshold
            ):
                raise ValueError(
                    "Only one of vector_distance_threshold or"
                    " vector_similarity_threshold can be specified at a time"
                    " in rag_retrieval_config."
                )
            api_retrieval_config.filter.vector_distance_threshold = (
                rag_retrieval_config.filter.vector_distance_threshold
            )
            api_retrieval_config.filter.vector_similarity_threshold = (
                rag_retrieval_config.filter.vector_similarity_threshold
            )
        if (
            rag_retrieval_config.ranking
            and rag_retrieval_config.ranking.rank_service
            and rag_retrieval_config.ranking.llm_ranker
        ):
            raise ValueError("Only one of rank_service and llm_ranker can be set.")
        if rag_retrieval_config.ranking and rag_retrieval_config.ranking.rank_service:
            api_retrieval_config.ranking.rank_service.model_name = (
                rag_retrieval_config.ranking.rank_service.model_name
            )
        elif rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker:
            api_retrieval_config.ranking.llm_ranker.model_name = (
                rag_retrieval_config.ranking.llm_ranker.model_name
            )

    query = aiplatform_v1.RagQuery(
        text=text,
        rag_retrieval_config=api_retrieval_config,
    )

    vertex_rag_store.rag_retrieval_config = api_retrieval_config

    tool = aiplatform_v1.Tool(
        retrieval=aiplatform_v1.Retrieval(
            vertex_rag_store=vertex_rag_store,
        )
    )

    request = aiplatform_v1.AskContextsRequest(
        parent=parent,
        query=query,
        tools=[tool],
    )
    try:
        response = client.ask_contexts(request=request, timeout=timeout)
    except Exception as e:
        # Embed the cause in the message; the original passed it as a second
        # positional arg, which left the message as an unformatted tuple.
        raise RuntimeError(f"Failed in asking contexts due to: {e}") from e

    return response
"""RAG retrieval tool for content generation."""

import re
from typing import List, Optional

from google.cloud import aiplatform_v1beta1
from agentplatform.rag.utils import _gapic_utils
from agentplatform.rag.utils import resources
from google.cloud.aiplatform import initializer
from google.cloud.aiplatform_v1beta1.types import tool as gapic_tool_types


class Retrieval:
    """Defines a retrieval tool that a model can call to access external knowledge."""

    def __init__(
        self,
        source: "VertexRagStore",
        disable_attribution: Optional[bool] = False,
    ):
        """Wraps a VertexRagStore as a gapic Retrieval tool.

        Args:
            source: The VertexRagStore to retrieve from.
            disable_attribution: Optional. If True, attribution of the
                retrieved contexts is disabled.
        """
        self._raw_retrieval = gapic_tool_types.Retrieval(
            vertex_rag_store=source._raw_vertex_rag_store,
            disable_attribution=disable_attribution,
        )


class VertexRagStore:
    """Retrieve from Vertex RAG Store."""

    def __init__(
        self,
        rag_resources: Optional[List[resources.RagResource]] = None,
        rag_retrieval_config: Optional[resources.RagRetrievalConfig] = None,
    ):
        """Initializes a Vertex RAG store tool.

        Example usage:
        ```
        import agentplatform

        agentplatform.init(project="my-project")

        config = agentplatform.rag.RagRetrievalConfig(
            top_k=2,
            filter=agentplatform.rag.RagRetrievalConfig.Filter(
                vector_distance_threshold=0.5
            ),
            ranking=agentplatform.rag.Ranking(
                llm_ranker=agentplatform.rag.LlmRanker(
                    model_name="gemini-1.5-flash-002"
                )
            )
        )

        tool = Tool.from_retrieval(
            retrieval=agentplatform.rag.Retrieval(
                source=agentplatform.rag.VertexRagStore(
                    rag_resources=[agentplatform.rag.RagResource(
                        rag_corpus="projects/my-project/locations/us-central1/ragCorpora/rag-corpus-1",
                    )],
                    rag_retrieval_config=config,
                ),
            )
        )
        ```

        Args:
            rag_resources: List of RagResource to retrieve from. It can be used
                to specify corpus only or ragfiles. Currently only support one
                corpus or multiple files from one corpus. In the future we
                may open up multiple corpora support.
            rag_retrieval_config: Optional. The config containing the retrieval
                parameters, including similarity_top_k and
                vector_distance_threshold.

        Raises:
            ValueError: If ``rag_resources`` is missing or holds more than one
                entry, the corpus name is malformed, or mutually exclusive
                retrieval options are both set.
        """
        if rag_resources:
            if len(rag_resources) > 1:
                raise ValueError("Currently only support 1 RagResource.")
            name = rag_resources[0].rag_corpus
        else:
            raise ValueError("rag_resources must be specified.")

        data_client = _gapic_utils.create_rag_data_service_client()
        if data_client.parse_rag_corpus_path(name):
            rag_corpus_name = name
        elif re.match("^{}$".format(_gapic_utils._VALID_RESOURCE_NAME_REGEX), name):
            # A bare corpus id: qualify it under the globally-configured parent.
            parent = initializer.global_config.common_location_path()
            rag_corpus_name = parent + "/ragCorpora/" + name
        else:
            raise ValueError(
                f"Invalid RagCorpus name: {name}. Proper format should be:"
                " projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}"
            )

        # An empty proto is used when the caller gave no config so the
        # service applies its defaults.
        api_retrieval_config = aiplatform_v1beta1.RagRetrievalConfig()
        if rag_retrieval_config:
            api_retrieval_config.top_k = rag_retrieval_config.top_k
            if rag_retrieval_config.filter:
                # vector_distance_threshold and vector_similarity_threshold
                # are mutually exclusive.
                if (
                    rag_retrieval_config.filter.vector_distance_threshold
                    and rag_retrieval_config.filter.vector_similarity_threshold
                ):
                    raise ValueError(
                        "Only one of vector_distance_threshold or"
                        " vector_similarity_threshold can be specified at a time"
                        " in rag_retrieval_config."
                    )
                api_retrieval_config.filter.vector_distance_threshold = (
                    rag_retrieval_config.filter.vector_distance_threshold
                )
                api_retrieval_config.filter.vector_similarity_threshold = (
                    rag_retrieval_config.filter.vector_similarity_threshold
                )
            # rank_service and llm_ranker are mutually exclusive.
            if (
                rag_retrieval_config.ranking
                and rag_retrieval_config.ranking.rank_service
                and rag_retrieval_config.ranking.rank_service.model_name
                and rag_retrieval_config.ranking.llm_ranker
                and rag_retrieval_config.ranking.llm_ranker.model_name
            ):
                raise ValueError(
                    "Only one of rank_service or llm_ranker can be specified"
                    " at a time in rag_retrieval_config."
                )
            if (
                rag_retrieval_config.ranking
                and rag_retrieval_config.ranking.rank_service
            ):
                api_retrieval_config.ranking.rank_service.model_name = (
                    rag_retrieval_config.ranking.rank_service.model_name
                )
            if rag_retrieval_config.ranking and rag_retrieval_config.ranking.llm_ranker:
                api_retrieval_config.ranking.llm_ranker.model_name = (
                    rag_retrieval_config.ranking.llm_ranker.model_name
                )

        gapic_rag_resource = gapic_tool_types.VertexRagStore.RagResource(
            rag_corpus=rag_corpus_name,
            rag_file_ids=rag_resources[0].rag_file_ids,
        )
        self._raw_vertex_rag_store = gapic_tool_types.VertexRagStore(
            rag_resources=[gapic_rag_resource],
            rag_retrieval_config=api_retrieval_config,
        )
import re
from typing import Any, Dict, Optional, Sequence, Union

from google.cloud.aiplatform import initializer
from google.cloud.aiplatform.utils import (
    VertexRagAsyncClientWithOverride,
    VertexRagClientWithOverride,
    VertexRagDataAsyncClientWithOverride,
    VertexRagDataClientWithOverride,
)
from google.cloud.aiplatform_v1 import (
    GoogleDriveSource,
    ImportRagFilesConfig,
    ImportRagFilesRequest,
    JiraSource as GapicJiraSource,
    RagCorpus as GapicRagCorpus,
    RagEmbeddingModelConfig as GapicRagEmbeddingModelConfig,
    RagEngineConfig as GapicRagEngineConfig,
    RagFileChunkingConfig,
    RagFileParsingConfig,
    RagFileTransformationConfig,
    RagFile as GapicRagFile,
    RagManagedDbConfig as GapicRagManagedDbConfig,
    RagVectorDbConfig as GapicRagVectorDbConfig,
    SharePointSources as GapicSharePointSources,
    SlackSource as GapicSlackSource,
    VertexAiSearchConfig as GapicVertexAiSearchConfig,
)
from google.cloud.aiplatform_v1.types import api_auth
from google.cloud.aiplatform_v1.types import EncryptionSpec
from agentplatform.rag.utils.resources import (
    Basic,
    JiraSource,
    LayoutParserConfig,
    LlmParserConfig,
    Pinecone,
    RagCitedGenerationResponse,
    RagCorpus,
    RagEmbeddingModelConfig,
    RagEngineConfig,
    RagFile,
    RagManagedDb,
    RagManagedDbConfig,
    RagVectorDbConfig,
    Scaled,
    SharePointSources,
    SlackChannelsSource,
    TransformationConfig,
    Unprovisioned,
    VertexAiSearchConfig,
    VertexPredictionEndpoint,
    VertexVectorSearch,
)


# A bare (unqualified) resource id, e.g. "my-corpus-1".
_VALID_RESOURCE_NAME_REGEX = "[a-zA-Z0-9][a-zA-Z0-9._-]{0,127}"
# A Document AI processor name, optionally pinned to a processor version.
_VALID_DOCUMENT_AI_PROCESSOR_NAME_REGEX = (
    r"projects/[^/]+/locations/[^/]+/processors/[^/]+(?:/processorVersions/[^/]+)?"
)


def create_rag_data_service_client(
    api_path_override: Optional[str] = None,
):
    """Creates a v1 RAG data-service client from the global config."""
    return initializer.global_config.create_client(
        client_class=VertexRagDataClientWithOverride,
        api_path_override=api_path_override,
    ).select_version("v1")


def create_rag_data_service_async_client(
    api_path_override: Optional[str] = None,
):
    """Creates a v1 async RAG data-service client from the global config."""
    return initializer.global_config.create_client(
        client_class=VertexRagDataAsyncClientWithOverride,
        api_path_override=api_path_override,
    ).select_version("v1")


def create_rag_service_client(
    api_path_override: Optional[str] = None,
):
    """Creates a v1 RAG service client from the global config."""
    return initializer.global_config.create_client(
        client_class=VertexRagClientWithOverride,
        api_path_override=api_path_override,
    ).select_version("v1")


def create_rag_service_async_client(
    api_path_override: Optional[str] = None,
):
    """Creates a v1 async RAG service client from the global config."""
    return initializer.global_config.create_client(
        client_class=VertexRagAsyncClientWithOverride,
        api_path_override=api_path_override,
    ).select_version("v1")


def convert_gapic_to_rag_embedding_model_config(
    gapic_embedding_model_config: GapicRagEmbeddingModelConfig,
) -> RagEmbeddingModelConfig:
    """Convert GapicRagEmbeddingModelConfig to RagEmbeddingModelConfig."""
    embedding_model_config = RagEmbeddingModelConfig()
    path = gapic_embedding_model_config.vertex_prediction_endpoint.endpoint
    # Named groups restored here: the un-named "(?P.+?)" form is an invalid
    # regex and raises re.error at call time.
    publisher_model = re.match(
        r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/publishers/google/models/(?P<model_id>.+?)$",
        path,
    )
    endpoint = re.match(
        r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$",
        path,
    )
    if publisher_model:
        embedding_model_config.vertex_prediction_endpoint = VertexPredictionEndpoint(
            publisher_model=path
        )
    if endpoint:
        embedding_model_config.vertex_prediction_endpoint = VertexPredictionEndpoint(
            endpoint=path,
            model=gapic_embedding_model_config.vertex_prediction_endpoint.model,
            model_version_id=gapic_embedding_model_config.vertex_prediction_endpoint.model_version_id,
        )
    return embedding_model_config


def _vector_db_has_field(
    gapic_vector_db: GapicRagVectorDbConfig, field_name: str
) -> bool:
    """Returns True if the given oneof field is set on the vector-db config.

    Newer proto-plus messages support ``__contains__``; for older ones we fall
    back to the raw protobuf ``HasField`` (only if the field exists in the
    descriptor, so removed fields simply report False).
    """
    try:
        return gapic_vector_db.__contains__(field_name)
    except AttributeError:
        if field_name in gapic_vector_db._pb.DESCRIPTOR.fields_by_name:
            return gapic_vector_db._pb.HasField(field_name)
        return False


def _check_weaviate(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
    return _vector_db_has_field(gapic_vector_db, "weaviate")


def _check_rag_managed_db(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
    return _vector_db_has_field(gapic_vector_db, "rag_managed_db")


def _check_vertex_feature_store(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
    return _vector_db_has_field(gapic_vector_db, "vertex_feature_store")


def _check_pinecone(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
    return _vector_db_has_field(gapic_vector_db, "pinecone")


def _check_vertex_vector_search(gapic_vector_db: GapicRagVectorDbConfig) -> bool:
    return _vector_db_has_field(gapic_vector_db, "vertex_vector_search")


def _check_rag_embedding_model_config(
    gapic_vector_db: GapicRagVectorDbConfig,
) -> bool:
    return _vector_db_has_field(gapic_vector_db, "rag_embedding_model_config")
def convert_gapic_to_backend_config(
    gapic_vector_db: GapicRagVectorDbConfig,
) -> RagVectorDbConfig:
    """Convert a Gapic RagVectorDbConfig to a RagVectorDbConfig.

    The resulting config's ``vector_db`` is a VertexVectorSearch, Pinecone, or
    RagManagedDb, depending on which oneof is set; returns None for an empty
    input.
    """
    if not gapic_vector_db:
        return None
    vector_config = RagVectorDbConfig()
    if _check_pinecone(gapic_vector_db):
        vector_config.vector_db = Pinecone(
            index_name=gapic_vector_db.pinecone.index_name,
            api_key=gapic_vector_db.api_auth.api_key_config.api_key_secret_version,
        )
    elif _check_vertex_vector_search(gapic_vector_db):
        vector_config.vector_db = VertexVectorSearch(
            index_endpoint=gapic_vector_db.vertex_vector_search.index_endpoint,
            index=gapic_vector_db.vertex_vector_search.index,
        )
    elif _check_rag_managed_db(gapic_vector_db):
        vector_config.vector_db = RagManagedDb()
    if _check_rag_embedding_model_config(gapic_vector_db):
        vector_config.rag_embedding_model_config = (
            convert_gapic_to_rag_embedding_model_config(
                gapic_vector_db.rag_embedding_model_config
            )
        )
    return vector_config


def convert_gapic_to_vertex_ai_search_config(
    gapic_vertex_ai_search_config: GapicVertexAiSearchConfig,
) -> Optional[VertexAiSearchConfig]:
    """Convert Gapic VertexAiSearchConfig to VertexAiSearchConfig.

    Returns None when no serving config is set.
    """
    # Leftover debug print statements removed.
    if gapic_vertex_ai_search_config.serving_config:
        return VertexAiSearchConfig(
            serving_config=gapic_vertex_ai_search_config.serving_config,
        )
    return None


def convert_gapic_to_rag_corpus(gapic_rag_corpus: GapicRagCorpus) -> RagCorpus:
    """Convert GapicRagCorpus to RagCorpus."""
    return RagCorpus(
        name=gapic_rag_corpus.name,
        display_name=gapic_rag_corpus.display_name,
        description=gapic_rag_corpus.description,
        vertex_ai_search_config=convert_gapic_to_vertex_ai_search_config(
            gapic_rag_corpus.vertex_ai_search_config
        ),
        backend_config=convert_gapic_to_backend_config(
            gapic_rag_corpus.vector_db_config
        ),
        encryption_spec=gapic_rag_corpus.encryption_spec,
    )


def convert_gapic_to_rag_corpus_no_embedding_model_config(
    gapic_rag_corpus: GapicRagCorpus,
) -> RagCorpus:
    """Convert GapicRagCorpus without embedding model config (for UpdateRagCorpus) to RagCorpus."""
    vertex_ai_search_config = convert_gapic_to_vertex_ai_search_config(
        gapic_rag_corpus.vertex_ai_search_config
    )
    old_config = gapic_rag_corpus.vector_db_config
    # Rebuild the vector-db config field by field, deliberately dropping
    # rag_embedding_model_config.
    stripped_config = old_config.__class__()
    if _check_rag_managed_db(old_config):
        stripped_config.rag_managed_db = old_config.rag_managed_db
    elif _check_pinecone(old_config):
        stripped_config.pinecone = old_config.pinecone
    elif _check_vertex_vector_search(old_config):
        stripped_config.vertex_vector_search = old_config.vertex_vector_search
    elif _check_weaviate(old_config):
        stripped_config.weaviate = old_config.weaviate
    elif _check_vertex_feature_store(old_config):
        stripped_config.vertex_feature_store = old_config.vertex_feature_store
    try:
        # Carry over api_auth when the message supports membership testing.
        if old_config.__contains__("api_auth"):
            stripped_config.api_auth = old_config.api_auth
    except AttributeError:
        pass
    return RagCorpus(
        name=gapic_rag_corpus.name,
        display_name=gapic_rag_corpus.display_name,
        description=gapic_rag_corpus.description,
        vertex_ai_search_config=vertex_ai_search_config,
        backend_config=convert_gapic_to_backend_config(stripped_config),
        encryption_spec=gapic_rag_corpus.encryption_spec,
    )


def convert_gapic_to_rag_file(gapic_rag_file: GapicRagFile) -> RagFile:
    """Convert GapicRagFile to RagFile."""
    return RagFile(
        name=gapic_rag_file.name,
        display_name=gapic_rag_file.display_name,
        description=gapic_rag_file.description,
    )


def convert_json_to_rag_file(upload_rag_file_response: Dict[str, Any]) -> RagFile:
    """Converts a JSON response to a RagFile."""
    rag_file_json = upload_rag_file_response.get("ragFile")
    return RagFile(
        name=rag_file_json.get("name"),
        display_name=rag_file_json.get("displayName"),
        description=rag_file_json.get("description"),
    )


def convert_tuple_to_rag_cited_generation_response(
    cited_text: str, final_bibliography: str
) -> RagCitedGenerationResponse:
    """Converts a tuple to a RagCitedGenerationResponse."""
    return RagCitedGenerationResponse(
        cited_text=cited_text,
        final_bibliography=final_bibliography,
    )


def convert_path_to_resource_id(
    path: str,
) -> Union[str, GoogleDriveSource.ResourceId]:
    """Converts a path to a Google Cloud storage uri or GoogleDriveSource.ResourceId."""
    if path.startswith("gs://"):
        # Google Cloud Storage source.
        return path
    elif path.startswith("https://drive.google.com/"):
        # Google Drive source: pull the file/folder id out of the url.
        path_list = path.split("/")
        if "file" in path_list:
            index = path_list.index("file") + 2
            resource_id = path_list[index].split("?")[0]
            resource_type = GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FILE
        elif "folders" in path_list:
            index = path_list.index("folders") + 1
            resource_id = path_list[index].split("?")[0]
            resource_type = (
                GoogleDriveSource.ResourceId.ResourceType.RESOURCE_TYPE_FOLDER
            )
        else:
            # The original passed logging-style args to ValueError, leaving
            # the message unformatted; interpolate explicitly instead.
            raise ValueError(f"path {path} is not a valid Google Drive url.")

        return GoogleDriveSource.ResourceId(
            resource_id=resource_id,
            resource_type=resource_type,
        )
    else:
        raise ValueError(
            "path must be a Google Cloud Storage uri or a Google Drive url."
        )


def convert_source_for_rag_import(
    source: Union[SlackChannelsSource, JiraSource, SharePointSources],
) -> Union[GapicSlackSource, GapicJiraSource, GapicSharePointSources]:
    """Converts a SlackChannelsSource, JiraSource, or SharePointSources to its Gapic equivalent."""
    if isinstance(source, SlackChannelsSource):
        result_source_channels = []
        for channel in source.channels:
            # One gapic SlackChannels entry per channel, each carrying its
            # own api-key secret.
            result_channels = GapicSlackSource.SlackChannels(
                channels=[
                    GapicSlackSource.SlackChannels.SlackChannel(
                        channel_id=channel.channel_id,
                        start_time=channel.start_time,
                        end_time=channel.end_time,
                    )
                ],
                api_key_config=api_auth.ApiAuth.ApiKeyConfig(
                    api_key_secret_version=channel.api_key
                ),
            )
            result_source_channels.append(result_channels)
        return GapicSlackSource(
            channels=result_source_channels,
        )
    elif isinstance(source, JiraSource):
        result_source_queries = []
        for query in source.queries:
            result_query = GapicJiraSource.JiraQueries(
                custom_queries=query.custom_queries,
                projects=query.jira_projects,
                email=query.email,
                server_uri=query.server_uri,
                api_key_config=api_auth.ApiAuth.ApiKeyConfig(
                    api_key_secret_version=query.api_key
                ),
            )
            result_source_queries.append(result_query)
        return GapicJiraSource(
            jira_queries=result_source_queries,
        )
    elif isinstance(source, SharePointSources):
        result_source_share_point_sources = []
        for share_point_source in source.share_point_sources:
            sharepoint_folder_path = share_point_source.sharepoint_folder_path
            sharepoint_folder_id = share_point_source.sharepoint_folder_id
            drive_name = share_point_source.drive_name
            drive_id = share_point_source.drive_id
            result_share_point_source = GapicSharePointSources.SharePointSource(
                client_id=share_point_source.client_id,
                client_secret=api_auth.ApiAuth.ApiKeyConfig(
                    api_key_secret_version=share_point_source.client_secret
                ),
                tenant_id=share_point_source.tenant_id,
                sharepoint_site_name=share_point_source.sharepoint_site_name,
            )
            # folder path and folder id are mutually exclusive; neither set
            # means "whole drive".
            if sharepoint_folder_path is not None and sharepoint_folder_id is not None:
                raise ValueError(
                    "sharepoint_folder_path and sharepoint_folder_id cannot both be set."
                )
            elif sharepoint_folder_path is not None:
                result_share_point_source.sharepoint_folder_path = (
                    sharepoint_folder_path
                )
            elif sharepoint_folder_id is not None:
                result_share_point_source.sharepoint_folder_id = sharepoint_folder_id
            # Exactly one of drive_name / drive_id is required.
            if drive_name is not None and drive_id is not None:
                raise ValueError("drive_name and drive_id cannot both be set.")
            elif drive_name is not None:
                result_share_point_source.drive_name = drive_name
            elif drive_id is not None:
                result_share_point_source.drive_id = drive_id
            else:
                # Fixed "and" -> "or": exactly one of the two is required.
                raise ValueError("Either drive_name or drive_id must be set.")
            result_source_share_point_sources.append(result_share_point_source)
        return GapicSharePointSources(
            share_point_sources=result_source_share_point_sources,
        )
    else:
        raise TypeError(
            "source must be a SlackChannelsSource or JiraSource or SharePointSources."
        )
def prepare_import_files_request(
    corpus_name: str,
    paths: Optional[Sequence[str]] = None,
    source: Optional[Union[SlackChannelsSource, JiraSource, SharePointSources]] = None,
    transformation_config: Optional[TransformationConfig] = None,
    max_embedding_requests_per_min: int = 1000,
    import_result_sink: Optional[str] = None,
    partial_failures_sink: Optional[str] = None,
    layout_parser: Optional[LayoutParserConfig] = None,
    llm_parser: Optional[LlmParserConfig] = None,
) -> ImportRagFilesRequest:
    """Builds an ImportRagFilesRequest from the given import options.

    Args:
        corpus_name: Full RagCorpus resource name to import into.
        paths: GCS uris and/or Google Drive urls to import.
        source: Alternative to ``paths``: a Slack, Jira, or SharePoint source.
        transformation_config: Optional chunking configuration.
        max_embedding_requests_per_min: QPM cap for embedding requests.
        import_result_sink: Optional GCS (``gs://``) or BigQuery (``bq://``)
            sink for import results.
        partial_failures_sink: Deprecated alias for ``import_result_sink``;
            used only when the latter is unset.
        layout_parser: Optional Document AI layout parser configuration.
        llm_parser: Optional LLM parser configuration.

    Returns:
        The prepared ImportRagFilesRequest.

    Raises:
        ValueError: On malformed names/sinks, or when neither ``paths`` nor
            ``source`` is provided.
    """
    if len(corpus_name.split("/")) != 6:
        raise ValueError(
            "corpus_name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}`"
        )

    rag_file_parsing_config = RagFileParsingConfig()
    if layout_parser is not None:
        if (
            re.fullmatch(
                _VALID_DOCUMENT_AI_PROCESSOR_NAME_REGEX,
                layout_parser.processor_name,
            )
            is None
        ):
            raise ValueError(
                "processor_name must be of the format"
                " `projects/{project_id}/locations/{location}/processors/{processor_id}`or"
                " `projects/{project_id}/locations/{location}/processors/{processor_id}/processorVersions/{processor_version_id}`,"
                f" got {layout_parser.processor_name!r}"
            )
        rag_file_parsing_config.layout_parser = RagFileParsingConfig.LayoutParser(
            processor_name=layout_parser.processor_name,
            max_parsing_requests_per_min=layout_parser.max_parsing_requests_per_min,
        )
    if llm_parser is not None:
        rag_file_parsing_config.llm_parser = RagFileParsingConfig.LlmParser(
            model_name=llm_parser.model_name
        )
        if llm_parser.max_parsing_requests_per_min is not None:
            rag_file_parsing_config.llm_parser.max_parsing_requests_per_min = (
                llm_parser.max_parsing_requests_per_min
            )
        if llm_parser.custom_parsing_prompt is not None:
            rag_file_parsing_config.llm_parser.custom_parsing_prompt = (
                llm_parser.custom_parsing_prompt
            )

    # Chunking defaults used when no transformation config is supplied.
    chunk_size = 1024
    chunk_overlap = 200
    if transformation_config and transformation_config.chunking_config:
        chunk_size = transformation_config.chunking_config.chunk_size
        chunk_overlap = transformation_config.chunking_config.chunk_overlap

    rag_file_transformation_config = RagFileTransformationConfig(
        rag_file_chunking_config=RagFileChunkingConfig(
            fixed_length_chunking=RagFileChunkingConfig.FixedLengthChunking(
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            ),
        ),
    )

    import_rag_files_config = ImportRagFilesConfig(
        rag_file_transformation_config=rag_file_transformation_config,
        rag_file_parsing_config=rag_file_parsing_config,
        max_embedding_requests_per_min=max_embedding_requests_per_min,
    )

    # partial_failures_sink is the deprecated name for import_result_sink.
    import_result_sink = import_result_sink or partial_failures_sink

    if import_result_sink is not None:
        if import_result_sink.startswith("gs://"):
            import_rag_files_config.partial_failure_gcs_sink.output_uri_prefix = (
                import_result_sink
            )
        elif import_result_sink.startswith("bq://"):
            import_rag_files_config.partial_failure_bigquery_sink.output_uri = (
                import_result_sink
            )
        else:
            raise ValueError(
                "import_result_sink must be a GCS path or a BigQuery table."
            )

    if source is not None:
        gapic_source = convert_source_for_rag_import(source)
        if isinstance(gapic_source, GapicSlackSource):
            import_rag_files_config.slack_source = gapic_source
        if isinstance(gapic_source, GapicJiraSource):
            import_rag_files_config.jira_source = gapic_source
        if isinstance(gapic_source, GapicSharePointSources):
            import_rag_files_config.share_point_sources = gapic_source
    else:
        # Previously `paths=None` here crashed with TypeError while iterating;
        # fail with an explicit, actionable error instead.
        if not paths:
            raise ValueError("Either paths or source must be specified.")
        uris = []
        resource_ids = []
        for p in paths:
            output = convert_path_to_resource_id(p)
            if isinstance(output, str):
                uris.append(p)
            else:
                resource_ids.append(output)
        if uris:
            import_rag_files_config.gcs_source.uris = uris
        if resource_ids:
            import_rag_files_config.google_drive_source = GoogleDriveSource(
                resource_ids=resource_ids,
            )

    return ImportRagFilesRequest(
        parent=corpus_name, import_rag_files_config=import_rag_files_config
    )


def get_corpus_name(
    name: str,
) -> str:
    """Resolves a bare corpus id or full resource name to a full resource name."""
    if name:
        client = create_rag_data_service_client()
        if client.parse_rag_corpus_path(name):
            return name
        elif re.match("^{}$".format(_VALID_RESOURCE_NAME_REGEX), name):
            return client.rag_corpus_path(
                project=initializer.global_config.project,
                location=initializer.global_config.location,
                rag_corpus=name,
            )
        else:
            raise ValueError(
                "name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}` or `{rag_corpus}`"
            )
    return name


def get_file_name(
    name: str,
    corpus_name: str,
) -> str:
    """Resolves a bare file id (plus corpus) or full file name to a full resource name."""
    client = create_rag_data_service_client()
    if client.parse_rag_file_path(name):
        return name
    elif re.match("^{}$".format(_VALID_RESOURCE_NAME_REGEX), name):
        if not corpus_name:
            raise ValueError(
                "corpus_name must be provided if name is a `{rag_file}`, not a "
                "full resource name (`projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}`). "
            )
        return client.rag_file_path(
            project=initializer.global_config.project,
            location=initializer.global_config.location,
            rag_corpus=get_corpus_name(corpus_name),
            rag_file=name,
        )
    else:
        raise ValueError(
            "name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}` or `{rag_file}`"
        )
" + ) + return client.rag_file_path( + project=initializer.global_config.project, + location=initializer.global_config.location, + rag_corpus=get_corpus_name(corpus_name), + rag_file=name, + ) + else: + raise ValueError( + "name must be of the format `projects/{project}/locations/{location}/ragCorpora/{rag_corpus}/ragFiles/{rag_file}` or `{rag_file}`" + ) + + +def set_embedding_model_config( + embedding_model_config: RagEmbeddingModelConfig, + rag_corpus: GapicRagCorpus, +) -> None: + if embedding_model_config.vertex_prediction_endpoint is None: + return + if ( + embedding_model_config.vertex_prediction_endpoint.publisher_model + and embedding_model_config.vertex_prediction_endpoint.endpoint + ): + raise ValueError("publisher_model and endpoint cannot be set at the same time.") + if ( + not embedding_model_config.vertex_prediction_endpoint.publisher_model + and not embedding_model_config.vertex_prediction_endpoint.endpoint + ): + raise ValueError("At least one of publisher_model and endpoint must be set.") + parent = initializer.global_config.common_location_path(project=None, location=None) + + if embedding_model_config.vertex_prediction_endpoint.publisher_model: + publisher_model = ( + embedding_model_config.vertex_prediction_endpoint.publisher_model + ) + full_resource_name = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)/publishers/google/models/(?P.+?)$", + publisher_model, + ) + resource_name = re.match( + r"^publishers/google/models/(?P.+?)$", + publisher_model, + ) + if full_resource_name: + rag_corpus.vector_db_config.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = ( + publisher_model + ) + elif resource_name: + rag_corpus.vector_db_config.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = ( + parent + "/" + publisher_model + ) + else: + raise ValueError( + "publisher_model must be of the format `projects/{project}/locations/{location}/publishers/google/models/{model_id}` or `publishers/google/models/{model_id}`" + ) + + 
if embedding_model_config.vertex_prediction_endpoint.endpoint: + endpoint = embedding_model_config.vertex_prediction_endpoint.endpoint + full_resource_name = re.match( + r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/endpoints/(?P<endpoint>.+?)$", + endpoint, + ) + resource_name = re.match( + r"^endpoints/(?P<endpoint>.+?)$", + endpoint, + ) + if full_resource_name: + rag_corpus.vector_db_config.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = ( + endpoint + ) + elif resource_name: + rag_corpus.vector_db_config.rag_embedding_model_config.vertex_prediction_endpoint.endpoint = ( + parent + "/" + endpoint + ) + else: + raise ValueError( + "endpoint must be of the format `projects/{project}/locations/{location}/endpoints/{endpoint}` or `endpoints/{endpoint}`" + ) + + +def set_backend_config( + backend_config: Optional[ + Union[ + RagVectorDbConfig, + None, + ] + ], + rag_corpus: GapicRagCorpus, +) -> None: + """Sets the vector db configuration for the rag corpus.""" + if backend_config is None: + return + + if backend_config.vector_db is not None: + vector_config = backend_config.vector_db + if isinstance(vector_config, RagManagedDb): + rag_corpus.vector_db_config.rag_managed_db.CopyFrom( + GapicRagVectorDbConfig.RagManagedDb() + ) + elif isinstance(vector_config, VertexVectorSearch): + index_endpoint = vector_config.index_endpoint + index = vector_config.index + + rag_corpus.vector_db_config.vertex_vector_search.index_endpoint = ( + index_endpoint + ) + rag_corpus.vector_db_config.vertex_vector_search.index = index + elif isinstance(vector_config, Pinecone): + index_name = vector_config.index_name + api_key = vector_config.api_key + + rag_corpus.vector_db_config.pinecone.index_name = index_name + rag_corpus.vector_db_config.api_auth.api_key_config.api_key_secret_version = ( + api_key + ) + elif vector_config is not None: + raise TypeError( + "backend_config must be a VertexFeatureStore," + "RagManagedDb, or Pinecone." 
+ ) + if backend_config.rag_embedding_model_config: + set_embedding_model_config( + backend_config.rag_embedding_model_config, rag_corpus + ) + + +def set_encryption_spec( + encryption_spec: EncryptionSpec, + rag_corpus: GapicRagCorpus, +) -> None: + """Sets the encryption spec for the rag corpus.""" + # Raises value error if encryption_spec.kms_key_name is None or empty, + if encryption_spec.kms_key_name is None or not encryption_spec.kms_key_name: + raise ValueError("kms_key_name must be set if encryption_spec is set.") + + # Raises value error if encryption_spec.kms_key_name is not a valid KMS key name. + if not re.match( + r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/keyRings/(?P<key_ring>.+?)/cryptoKeys/(?P<crypto_key>.+?)$", + encryption_spec.kms_key_name, + ): + raise ValueError( + "kms_key_name must be of the format " + "`projects/{project}/locations/{location}/keyRings/{key_ring}/cryptoKeys/{crypto_key}`" + ) + + rag_corpus.encryption_spec = encryption_spec + + +def set_vertex_ai_search_config( + vertex_ai_search_config: VertexAiSearchConfig, + rag_corpus: GapicRagCorpus, +) -> None: + if not vertex_ai_search_config.serving_config: + raise ValueError("serving_config must be set.") + engine_resource_name = re.match( + r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/collections/(?P<collection>.+?)/engines/(?P<engine>.+?)/servingConfigs/(?P<serving_config>.+?)$", + vertex_ai_search_config.serving_config, + ) + data_store_resource_name = re.match( + r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)/collections/(?P<collection>.+?)/dataStores/(?P<data_store>.+?)/servingConfigs/(?P<serving_config>.+?)$", + vertex_ai_search_config.serving_config, + ) + if engine_resource_name or data_store_resource_name: + rag_corpus.vertex_ai_search_config = GapicVertexAiSearchConfig( + serving_config=vertex_ai_search_config.serving_config, + ) + else: + raise ValueError( + "serving_config must be of the format `projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}` or 
`projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}`" + ) + + +def convert_gapic_to_rag_engine_config( + response: GapicRagEngineConfig, +) -> RagEngineConfig: + """Converts a GapicRagEngineConfig to a RagEngineConfig.""" + rag_managed_db_config = RagManagedDbConfig() + # If future fields are added with similar names, beware that __contains__ + # may match them. + if response.rag_managed_db_config.__contains__("basic"): + rag_managed_db_config.tier = Basic() + elif response.rag_managed_db_config.__contains__("unprovisioned"): + rag_managed_db_config.tier = Unprovisioned() + elif response.rag_managed_db_config.__contains__("scaled"): + rag_managed_db_config.tier = Scaled() + else: + raise ValueError("At least one of rag_managed_db_config must be set.") + return RagEngineConfig( + name=response.name, + rag_managed_db_config=rag_managed_db_config, + ) + + +def convert_rag_engine_config_to_gapic( + rag_engine_config: RagEngineConfig, +) -> GapicRagEngineConfig: + """Converts a RagEngineConfig to a GapicRagEngineConfig.""" + rag_managed_db_config = GapicRagManagedDbConfig() + if ( + rag_engine_config.rag_managed_db_config is None + or rag_engine_config.rag_managed_db_config.tier is None + ): + rag_managed_db_config = GapicRagManagedDbConfig( + basic=GapicRagManagedDbConfig.Basic() + ) + else: + if isinstance(rag_engine_config.rag_managed_db_config.tier, Basic): + rag_managed_db_config.basic = GapicRagManagedDbConfig.Basic() + elif isinstance(rag_engine_config.rag_managed_db_config.tier, Unprovisioned): + rag_managed_db_config.unprovisioned = ( + GapicRagManagedDbConfig.Unprovisioned() + ) + elif isinstance(rag_engine_config.rag_managed_db_config.tier, Scaled): + rag_managed_db_config.scaled = GapicRagManagedDbConfig.Scaled() + return GapicRagEngineConfig( + name=rag_engine_config.name, + rag_managed_db_config=rag_managed_db_config, + ) diff --git a/agentplatform/rag/utils/resources.py 
b/agentplatform/rag/utils/resources.py new file mode 100644 index 0000000000..2f50516280 --- /dev/null +++ b/agentplatform/rag/utils/resources.py @@ -0,0 +1,545 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import dataclasses +from typing import List, Optional, Sequence, Union + +from google.protobuf import timestamp_pb2 +from google.cloud.aiplatform_v1.types import EncryptionSpec + + +@dataclasses.dataclass +class RagFile: + """RAG file (output only). + + Attributes: + name: Generated resource name. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}/ragFiles/{rag_file}`` + display_name: Display name that was configured at client side. + description: The description of the RagFile. + """ + + name: Optional[str] = None + display_name: Optional[str] = None + description: Optional[str] = None + + +@dataclasses.dataclass +class VertexPredictionEndpoint: + """VertexPredictionEndpoint. + + Attributes: + publisher_model: 1P publisher model resource name. Format: + ``publishers/google/models/{model}`` or + ``projects/{project}/locations/{location}/publishers/google/models/{model}`` + endpoint: 1P fine tuned embedding model resource name. Format: + ``endpoints/{endpoint}`` or + ``projects/{project}/locations/{location}/endpoints/{endpoint}``. + model: + Output only. The resource name of the model that is deployed + on the endpoint. Present only when the endpoint is not a + publisher model. 
Pattern: + ``projects/{project}/locations/{location}/models/{model}`` + model_version_id: + Output only. Version ID of the model that is + deployed on the endpoint. Present only when the + endpoint is not a publisher model. + """ + + endpoint: Optional[str] = None + publisher_model: Optional[str] = None + model: Optional[str] = None + model_version_id: Optional[str] = None + + +@dataclasses.dataclass +class RagEmbeddingModelConfig: + """RagEmbeddingModelConfig. + + Attributes: + vertex_prediction_endpoint: The Vertex AI Prediction Endpoint config. + """ + + vertex_prediction_endpoint: Optional[VertexPredictionEndpoint] = None + + +@dataclasses.dataclass +class Weaviate: + """Weaviate. + + Attributes: + weaviate_http_endpoint: The Weaviate DB instance HTTP endpoint + collection_name: The corresponding Weaviate collection this corpus maps to + api_key: The SecretManager resource name for the Weaviate DB API token. Format: + ``projects/{project}/secrets/{secret}/versions/{version}`` + """ + + weaviate_http_endpoint: Optional[str] = None + collection_name: Optional[str] = None + api_key: Optional[str] = None + + +@dataclasses.dataclass +class VertexFeatureStore: + """VertexFeatureStore. + + Attributes: + resource_name: The resource name of the FeatureView. Format: + ``projects/{project}/locations/{location}/featureOnlineStores/ + {feature_online_store}/featureViews/{feature_view}`` + """ + + resource_name: Optional[str] = None + + +@dataclasses.dataclass +class VertexVectorSearch: + """VertexVectorSearch. + + Attributes: + index_endpoint (str): + The resource name of the Index Endpoint. Format: + ``projects/{project}/locations/{location}/indexEndpoints/{index_endpoint}`` + index (str): + The resource name of the Index. 
Format: + ``projects/{project}/locations/{location}/indexes/{index}`` + """ + + index_endpoint: Optional[str] = None + index: Optional[str] = None + + +@dataclasses.dataclass +class RagManagedDb: + """RagManagedDb.""" + + +@dataclasses.dataclass +class Pinecone: + """Pinecone. + + Attributes: + index_name: The Pinecone index name. + api_key: The SecretManager resource name for the Pinecone DB API token. Format: + ``projects/{project}/secrets/{secret}/versions/{version}`` + """ + + index_name: Optional[str] = None + api_key: Optional[str] = None + + +@dataclasses.dataclass +class VertexAiSearchConfig: + """VertexAiSearchConfig. + + Attributes: + serving_config: The resource name of the Vertex AI Search serving config. + Format: + ``projects/{project}/locations/{location}/collections/{collection}/engines/{engine}/servingConfigs/{serving_config}`` + or + ``projects/{project}/locations/{location}/collections/{collection}/dataStores/{data_store}/servingConfigs/{serving_config}`` + """ + + serving_config: Optional[str] = None + + +@dataclasses.dataclass +class RagVectorDbConfig: + """RagVectorDbConfig. + + Attributes: + vector_db: Can be one of the following: RagManagedDb, Pinecone, + VertexVectorSearch. + rag_embedding_model_config: The embedding model config of the Vector DB. + """ + + vector_db: Optional[ + Union[ + VertexVectorSearch, + Pinecone, + RagManagedDb, + ] + ] = None + rag_embedding_model_config: Optional[RagEmbeddingModelConfig] = None + + +@dataclasses.dataclass +class RagCorpus: + """RAG corpus(output only). + + Attributes: + name: Generated resource name. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}`` + display_name: Display name that was configured at client side. + description: The description of the RagCorpus. + vertex_ai_search_config: The Vertex AI Search config of the RagCorpus. + backend_config: The backend config of the RagCorpus. It can be a data + store and/or retrieval engine. 
+ encryption_spec: The encryption spec of the RagCorpus. Immutable. + """ + + name: Optional[str] = None + display_name: Optional[str] = None + description: Optional[str] = None + vertex_ai_search_config: Optional[VertexAiSearchConfig] = None + backend_config: Optional[ + Union[ + RagVectorDbConfig, + None, + ] + ] = None + encryption_spec: Optional[EncryptionSpec] = None + + +@dataclasses.dataclass +class RagResource: + """RagResource. + + The representation of the rag source. It can be used to specify corpus only + or ragfiles. Currently only support one corpus or multiple files from one + corpus. In the future we may open up multiple corpora support. + + Attributes: + rag_corpus: A Rag corpus resource name or corpus id. Format: + ``projects/{project}/locations/{location}/ragCorpora/{rag_corpus_id}`` + or ``{rag_corpus_id}``. + rag_files_id: List of Rag file resource name or file ids in the same corpus. Format: + ``{rag_file}``. + """ + + rag_corpus: Optional[str] = None + rag_file_ids: Optional[List[str]] = None + + +@dataclasses.dataclass +class SlackChannel: + """SlackChannel. + + Attributes: + channel_id: The Slack channel ID. + api_key: The SecretManager resource name for the Slack API token. Format: + ``projects/{project}/secrets/{secret}/versions/{version}`` + See: https://api.slack.com/tutorials/tracks/getting-a-token. + start_time: The starting timestamp for messages to import. + end_time: The ending timestamp for messages to import. + """ + + channel_id: str + api_key: str + start_time: Optional[timestamp_pb2.Timestamp] = None + end_time: Optional[timestamp_pb2.Timestamp] = None + + +@dataclasses.dataclass +class SlackChannelsSource: + """SlackChannelsSource. + + Attributes: + channels: The Slack channels. + """ + + channels: Sequence[SlackChannel] + + +@dataclasses.dataclass +class JiraQuery: + """JiraQuery. + + Attributes: + email: The Jira email address. + jira_projects: A list of Jira projects to import in their entirety. 
+ custom_queries: A list of custom JQL Jira queries to import. + api_key: The SecretManager version resource name for Jira API access. Format: + ``projects/{project}/secrets/{secret}/versions/{version}`` + See: https://support.atlassian.com/atlassian-account/docs/manage-api-tokens-for-your-atlassian-account/ + server_uri: The Jira server URI. Format: + ``{server}.atlassian.net`` + """ + + email: str + jira_projects: Sequence[str] + custom_queries: Sequence[str] + api_key: str + server_uri: str + + +@dataclasses.dataclass +class JiraSource: + """JiraSource. + + Attributes: + queries: The Jira queries. + """ + + queries: Sequence[JiraQuery] + + +@dataclasses.dataclass +class SharePointSource: + """SharePointSource. + + Attributes: + sharepoint_folder_path: The path of the SharePoint folder to download + from. + sharepoint_folder_id: The ID of the SharePoint folder to download + from. + drive_name: The name of the drive to download from. + drive_id: The ID of the drive to download from. + client_id: The Application ID for the app registered in + Microsoft Azure Portal. The application must + also be configured with MS Graph permissions + "Files.ReadAll", "Sites.ReadAll" and + BrowserSiteLists.Read.All. + client_secret: The application secret for the app registered + in Azure. + tenant_id: Unique identifier of the Azure Active + Directory Instance. + sharepoint_site_name: The name of the SharePoint site to download + from. This can be the site name or the site id. + """ + + sharepoint_folder_path: Optional[str] = None + sharepoint_folder_id: Optional[str] = None + drive_name: Optional[str] = None + drive_id: Optional[str] = None + client_id: str = None + client_secret: str = None + tenant_id: str = None + sharepoint_site_name: str = None + + +@dataclasses.dataclass +class SharePointSources: + """SharePointSources. + + Attributes: + share_point_sources: The SharePoint sources. 
+ """ + + share_point_sources: Sequence[SharePointSource] + + +@dataclasses.dataclass +class Filter: + """Filter. + + Attributes: + vector_distance_threshold: Only returns contexts with vector + distance smaller than the threshold. + vector_similarity_threshold: Only returns contexts with vector + similarity larger than the threshold. + metadata_filter: String for metadata filtering. + """ + + vector_distance_threshold: Optional[float] = None + vector_similarity_threshold: Optional[float] = None + metadata_filter: Optional[str] = None + + +@dataclasses.dataclass +class LlmRanker: + """LlmRanker. + + Attributes: + model_name: The model name used for ranking. Only Gemini models are + supported for now. + """ + + model_name: Optional[str] = None + + +@dataclasses.dataclass +class RankService: + """RankService. + + Attributes: + model_name: The model name of the rank service. Format: + ``semantic-ranker-512@latest`` + """ + + model_name: Optional[str] = None + + +@dataclasses.dataclass +class Ranking: + """Ranking. + + Attributes: + rank_service: Config for Rank Service. + llm_ranker: Config for LlmRanker. + """ + + rank_service: Optional[RankService] = None + llm_ranker: Optional[LlmRanker] = None + + +@dataclasses.dataclass +class RagRetrievalConfig: + """RagRetrievalConfig. + + Attributes: + top_k: The number of contexts to retrieve. + filter: Config for filters. + ranking: Config for ranking. + """ + + top_k: Optional[int] = None + filter: Optional[Filter] = None + ranking: Optional[Ranking] = None + + +@dataclasses.dataclass +class ChunkingConfig: + """ChunkingConfig. + + Attributes: + chunk_size: The size of each chunk. + chunk_overlap: The size of the overlap between chunks. + """ + + chunk_size: int + chunk_overlap: int + + +@dataclasses.dataclass +class TransformationConfig: + """TransformationConfig. + + Attributes: + chunking_config: The chunking config. 
+ """ + + chunking_config: Optional[ChunkingConfig] = None + + +@dataclasses.dataclass +class LayoutParserConfig: + """Configuration for the Document AI Layout Parser Processor. + + Attributes: + processor_name: The full resource name of a Document AI processor or + processor version. The processor must have type + `LAYOUT_PARSER_PROCESSOR`. + Format must be one of the following: + - `projects/{project_id}/locations/{location}/processors/{processor_id}` + - `projects/{project_id}/locations/{location}/processors/{processor_id}/processorVersions/{processor_version_id}` + max_parsing_requests_per_min: The maximum number of requests the job is + allowed to make to the Document AI processor per minute. Consult + https://cloud.google.com/document-ai/quotas and the Quota page for + your project to set an appropriate value here. If unspecified, a + default value of 120 QPM will be used. + """ + + processor_name: str + max_parsing_requests_per_min: Optional[int] = None + + +@dataclasses.dataclass +class LlmParserConfig: + """Configuration for the Document AI Layout Parser Processor. + + Attributes: + model_name (str): + The full resource name of a Vertex AI model. Format: + - `projects/{project_id}/locations/{location}/publishers/google/models/{model_id}` + - `projects/{project_id}/locations/{location}/models/{model_id}` + max_parsing_requests_per_min (int): + The maximum number of requests the job is allowed to make to the + Vertex AI model per minute. Consult + https://cloud.google.com/vertex-ai/generative-ai/docs/quotas and + the Quota page for your project to set an appropriate value here. + If unspecified, a default value of 120 QPM will be used. + custom_parsing_prompt (str): + A custom prompt to use for parsing. + """ + + model_name: str + max_parsing_requests_per_min: Optional[int] = None + custom_parsing_prompt: Optional[str] = None + + +@dataclasses.dataclass +class RagCitedGenerationResponse: + """RagCitedGenerationResponse. 
+ + Attributes: + cited_text: The text with inline citations. + final_bibliography: List of all unique cited chunks, their URIs, and page + numbers (if applicable). + """ + + cited_text: str + final_bibliography: str + + +@dataclasses.dataclass +class Scaled: + """Scaled tier offers production grade performance along with + + autoscaling functionality. It is suitable for customers with large + amounts of data or performance sensitive workloads. + """ + + +@dataclasses.dataclass +class Basic: + """Basic tier is a cost-effective and low compute tier suitable for the following cases: + + * Experimenting with RagManagedDb. + * Small data size. + * Latency insensitive workload. + * Only using RAG Engine with external vector DBs. + + NOTE: This is the default tier if not explicitly chosen. + """ + + +@dataclasses.dataclass +class Unprovisioned: + """Disables the RAG Engine service and deletes all your data held within + this service. This will halt the billing of the service. + + NOTE: Once deleted the data cannot be recovered. To start using + RAG Engine again, you will need to update the tier by calling the + UpdateRagEngineConfig API. + """ + + +@dataclasses.dataclass +class RagManagedDbConfig: + """RagManagedDbConfig. + + The config of the RagManagedDb used by RagEngine. + + Attributes: + tier: The tier of the RagManagedDb. The default tier is Basic. + """ + + tier: Optional[Union[Basic, Scaled, Unprovisioned]] = None + + +@dataclasses.dataclass +class RagEngineConfig: + """RagEngineConfig. + + Attributes: + name: Generated resource name for singleton resource. Format: + ``projects/{project}/locations/{location}/ragEngineConfig`` + rag_managed_db_config: The config of the RagManagedDb used by RagEngine. + The default tier is Basic. 
+ """ + + name: str + rag_managed_db_config: Optional[RagManagedDbConfig] = None diff --git a/tests/unit/vertex_rag/conftest.py b/tests/unit/vertex_rag/conftest.py index f9bbe134de..d9f3662b0c 100644 --- a/tests/unit/vertex_rag/conftest.py +++ b/tests/unit/vertex_rag/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2024 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -57,7 +57,7 @@ def authorized_session_mock(): @pytest.fixture def rag_data_client_mock(): - from vertexai.rag.utils import _gapic_utils + from agentplatform.rag.utils import _gapic_utils with mock.patch.object( _gapic_utils, "create_rag_data_service_client" @@ -83,7 +83,7 @@ def rag_data_client_mock(): @pytest.fixture def rag_data_client_preview_mock(): - from vertexai.preview.rag.utils import _gapic_utils + from agentplatform.preview.rag.utils import _gapic_utils with mock.patch.object( _gapic_utils, "create_rag_data_service_client" @@ -109,7 +109,7 @@ def rag_data_client_preview_mock(): @pytest.fixture def rag_data_client_mock_exception(): - from vertexai.rag.utils import _gapic_utils + from agentplatform.rag.utils import _gapic_utils with mock.patch.object( _gapic_utils, "create_rag_data_service_client" @@ -141,7 +141,7 @@ def rag_data_client_mock_exception(): @pytest.fixture def rag_data_client_preview_mock_exception(): - from vertexai.preview.rag.utils import _gapic_utils + from agentplatform.preview.rag.utils import _gapic_utils with mock.patch.object( _gapic_utils, "create_rag_data_service_client" @@ -177,7 +177,7 @@ def rag_data_client_preview_mock_exception(): @pytest.fixture def rag_data_async_client_mock_exception(): - from vertexai.rag.utils import _gapic_utils + from agentplatform.rag.utils import _gapic_utils with mock.patch.object( _gapic_utils, "create_rag_data_service_async_client" @@ -191,7 +191,7 @@ def rag_data_async_client_mock_exception(): @pytest.fixture def 
rag_data_async_client_preview_mock_exception(): - from vertexai.preview.rag.utils import _gapic_utils + from agentplatform.preview.rag.utils import _gapic_utils with mock.patch.object( _gapic_utils, "create_rag_data_service_async_client" diff --git a/tests/unit/vertex_rag/test_rag_constants.py b/tests/unit/vertex_rag/test_rag_constants.py index 65459cec2a..64a894d472 100644 --- a/tests/unit/vertex_rag/test_rag_constants.py +++ b/tests/unit/vertex_rag/test_rag_constants.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,7 +16,7 @@ from google.cloud import aiplatform -from vertexai.rag import ( +from agentplatform.rag import ( Basic, Filter, LayoutParserConfig, diff --git a/tests/unit/vertex_rag/test_rag_constants_preview.py b/tests/unit/vertex_rag/test_rag_constants_preview.py index 9819137268..56364b531a 100644 --- a/tests/unit/vertex_rag/test_rag_constants_preview.py +++ b/tests/unit/vertex_rag/test_rag_constants_preview.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -44,12 +42,11 @@ ) from google.cloud.aiplatform_v1beta1.types import api_auth from google.cloud.aiplatform_v1beta1.types import EncryptionSpec -from vertexai.preview.rag import ( +from agentplatform.preview.rag import ( ANN, Basic, DocumentCorpus, EmbeddingModelConfig, - Enterprise, Filter, HybridSearch, JiraQuery, @@ -282,7 +279,6 @@ name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, description=TEST_CORPUS_DISCRIPTION, - embedding_model_config=TEST_EMBEDDING_MODEL_CONFIG, ) TEST_CMEK_RAG_CORPUS = RagCorpus( name=TEST_RAG_CORPUS_RESOURCE_NAME, @@ -296,46 +292,48 @@ name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, description=TEST_CORPUS_DISCRIPTION, - vector_db=TEST_WEAVIATE_CONFIG, + backend_config=RagVectorDbConfig(vector_db=TEST_WEAVIATE_CONFIG), ) TEST_RAG_CORPUS_VERTEX_FEATURE_STORE = RagCorpus( name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, description=TEST_CORPUS_DISCRIPTION, - vector_db=TEST_VERTEX_FEATURE_STORE_CONFIG, + backend_config=RagVectorDbConfig(vector_db=TEST_VERTEX_FEATURE_STORE_CONFIG), ) TEST_RAG_CORPUS_PINECONE = RagCorpus( name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, description=TEST_CORPUS_DISCRIPTION, - vector_db=TEST_PINECONE_CONFIG, + backend_config=RagVectorDbConfig(vector_db=TEST_PINECONE_CONFIG), ) TEST_RAG_CORPUS_RAG_MANAGED_DB = RagCorpus( name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, - vector_db=TEST_RAG_MANAGED_DB_CONFIG, + backend_config=RagVectorDbConfig(vector_db=TEST_RAG_MANAGED_DB_CONFIG), ) TEST_RAG_CORPUS_RAG_MANAGED_DB_KNN = RagCorpus( name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, - vector_db=TEST_RAG_MANAGED_DB_KNN_CONFIG, + backend_config=RagVectorDbConfig(vector_db=TEST_RAG_MANAGED_DB_KNN_CONFIG), ) TEST_RAG_CORPUS_RAG_MANAGED_DB_ANN = RagCorpus( name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, - 
vector_db=TEST_RAG_MANAGED_DB_ANN_CONFIG, + backend_config=RagVectorDbConfig(vector_db=TEST_RAG_MANAGED_DB_ANN_CONFIG), ) TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH = RagCorpus( name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, description=TEST_CORPUS_DISCRIPTION, - vector_db=TEST_VERTEX_VECTOR_SEARCH_CONFIG, + backend_config=RagVectorDbConfig(vector_db=TEST_VERTEX_VECTOR_SEARCH_CONFIG), ) TEST_RAG_CORPUS_RAG_MANAGED_VERTEX_VECTOR_SEARCH = RagCorpus( name=TEST_RAG_CORPUS_RESOURCE_NAME, display_name=TEST_CORPUS_DISPLAY_NAME, description=TEST_CORPUS_DISCRIPTION, - vector_db=TEST_RAG_MANAGED_VERTEX_VECTOR_SEARCH_CONFIG, + backend_config=RagVectorDbConfig( + vector_db=TEST_RAG_MANAGED_VERTEX_VECTOR_SEARCH_CONFIG + ), ) TEST_PAGE_TOKEN = "test-page-token" @@ -626,19 +624,15 @@ ) TEST_RAG_ENGINE_CONFIG_BASIC = RagEngineConfig( name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, - rag_managed_db_config=RagManagedDbConfig(tier=Basic()), + rag_managed_db_config=RagManagedDbConfig(mode=Spanner(tier=Basic())), ) TEST_RAG_ENGINE_CONFIG_SCALED = RagEngineConfig( name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, - rag_managed_db_config=RagManagedDbConfig(tier=Scaled()), + rag_managed_db_config=RagManagedDbConfig(mode=Spanner(tier=Scaled())), ) TEST_RAG_ENGINE_CONFIG_UNPROVISIONED = RagEngineConfig( name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, - rag_managed_db_config=RagManagedDbConfig(tier=Unprovisioned()), -) -TEST_RAG_ENGINE_CONFIG_ENTERPRISE = RagEngineConfig( - name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, - rag_managed_db_config=RagManagedDbConfig(tier=Enterprise()), + rag_managed_db_config=RagManagedDbConfig(mode=Spanner(tier=Unprovisioned())), ) TEST_DEFAULT_RAG_ENGINE_CONFIG = RagEngineConfig( name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, @@ -648,9 +642,9 @@ name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, rag_managed_db_config=RagManagedDbConfig( mode=Spanner(tier=Basic()), - tier=Scaled(), ), ) +TEST_BAD_RAG_ENGINE_CONFIG_WITH_MODE_AND_TIER.rag_managed_db_config.tier = Scaled() 
TEST_GAPIC_RAG_ENGINE_CONFIG_SERVERLESS = GapicRagEngineConfig( name=TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, rag_managed_db_config=GapicRagManagedDbConfig( diff --git a/tests/unit/vertex_rag/test_rag_data.py b/tests/unit/vertex_rag/test_rag_data.py index ff27968b35..763d03a92d 100644 --- a/tests/unit/vertex_rag/test_rag_data.py +++ b/tests/unit/vertex_rag/test_rag_data.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,12 +14,12 @@ # import importlib from google.api_core import operation as ga_operation -from vertexai import rag -from vertexai.rag.utils._gapic_utils import ( +from agentplatform import rag +from agentplatform.rag.utils._gapic_utils import ( prepare_import_files_request, set_embedding_model_config, ) -from vertexai.rag.utils.resources import ( +from agentplatform.rag.utils.resources import ( ChunkingConfig, TransformationConfig, ) diff --git a/tests/unit/vertex_rag/test_rag_data_preview.py b/tests/unit/vertex_rag/test_rag_data_preview.py index 68a2fda465..ab453d9e84 100644 --- a/tests/unit/vertex_rag/test_rag_data_preview.py +++ b/tests/unit/vertex_rag/test_rag_data_preview.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -16,12 +14,12 @@ # import importlib from google.api_core import operation as ga_operation -from vertexai.preview import rag -from vertexai.preview.rag.utils._gapic_utils import ( +from agentplatform.preview import rag +from agentplatform.preview.rag.utils._gapic_utils import ( prepare_import_files_request, set_embedding_model_config, ) -from vertexai.rag.utils.resources import ( +from agentplatform.rag.utils.resources import ( ChunkingConfig, TransformationConfig, ) @@ -923,7 +921,7 @@ def update_rag_metadata_mock(): @pytest.fixture def rag_data_client_preview_mock_exception(): - from vertexai.preview.rag.utils import _gapic_utils + from agentplatform.preview.rag.utils import _gapic_utils with mock.patch.object( _gapic_utils, "create_rag_data_service_client" @@ -993,7 +991,6 @@ def rag_metadata_eq(returned_metadata, expected_metadata): def rag_corpus_eq(returned_corpus, expected_corpus): assert returned_corpus.name == expected_corpus.name assert returned_corpus.display_name == expected_corpus.display_name - assert returned_corpus.vector_db == expected_corpus.vector_db assert returned_corpus.backend_config == expected_corpus.backend_config assert ( returned_corpus.vertex_ai_search_config @@ -1049,9 +1046,19 @@ def import_files_request_eq(returned_request, expected_request): def rag_engine_config_eq(returned_config, expected_config): assert returned_config.name == expected_config.name - assert ( - returned_config.rag_managed_db_config == expected_config.rag_managed_db_config + ret_mode = ( + returned_config.rag_managed_db_config.mode + if returned_config.rag_managed_db_config + and returned_config.rag_managed_db_config.mode + else rag.Spanner(tier=rag.Basic()) ) + exp_mode = ( + expected_config.rag_managed_db_config.mode + if expected_config.rag_managed_db_config + and expected_config.rag_managed_db_config.mode + else rag.Spanner(tier=rag.Basic()) + ) + assert ret_mode == exp_mode @pytest.mark.usefixtures("google_auth_mock") @@ -1071,7 +1078,7 @@ def 
teardown_method(self): def test_create_corpus_success(self): rag_corpus = rag.create_corpus( display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - embedding_model_config=test_rag_constants_preview.TEST_EMBEDDING_MODEL_CONFIG, + backend_config=test_rag_constants_preview.TEST_BACKEND_CONFIG_EMBEDDING_MODEL_CONFIG, ) rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS) @@ -1094,49 +1101,6 @@ def test_create_corpus_cmek_success(self): rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_CMEK_RAG_CORPUS) - @pytest.mark.usefixtures("create_rag_corpus_mock_weaviate") - def test_create_corpus_weaviate_success(self): - rag_corpus = rag.create_corpus( - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - vector_db=test_rag_constants_preview.TEST_WEAVIATE_CONFIG, - ) - - rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_WEAVIATE) - - @pytest.mark.usefixtures("create_rag_corpus_mock_vertex_feature_store") - def test_create_corpus_vertex_feature_store_success(self): - rag_corpus = rag.create_corpus( - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - vector_db=test_rag_constants_preview.TEST_VERTEX_FEATURE_STORE_CONFIG, - ) - - rag_corpus_eq( - rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_FEATURE_STORE - ) - - @pytest.mark.usefixtures("create_rag_corpus_mock_vertex_vector_search") - def test_create_corpus_vertex_vector_search_success(self): - rag_corpus = rag.create_corpus( - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - vector_db=test_rag_constants_preview.TEST_VERTEX_VECTOR_SEARCH_CONFIG, - ) - - rag_corpus_eq( - rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH - ) - - @pytest.mark.usefixtures("create_rag_corpus_mock_rag_managed_vertex_vector_search") - def test_create_corpus_rag_managed_vertex_vector_search_success(self): - rag_corpus = rag.create_corpus( - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, 
- vector_db=test_rag_constants_preview.TEST_RAG_MANAGED_VERTEX_VECTOR_SEARCH_CONFIG, - ) - - rag_corpus_eq( - rag_corpus, - test_rag_constants_preview.TEST_RAG_CORPUS_RAG_MANAGED_VERTEX_VECTOR_SEARCH, - ) - @pytest.mark.usefixtures("create_rag_corpus_mock_vertex_vector_search_backend") def test_create_corpus_vertex_vector_search_backend_success(self): rag_corpus = rag.create_corpus( @@ -1149,48 +1113,6 @@ def test_create_corpus_vertex_vector_search_backend_success(self): test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH_BACKEND, ) - @pytest.mark.usefixtures("create_rag_corpus_mock_pinecone") - def test_create_corpus_pinecone_success(self): - rag_corpus = rag.create_corpus( - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - vector_db=test_rag_constants_preview.TEST_PINECONE_CONFIG, - ) - - rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_PINECONE) - - @pytest.mark.usefixtures("create_rag_corpus_mock_rag_managed_db") - def test_create_corpus_rag_managed_db_success(self): - rag_corpus = rag.create_corpus( - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - vector_db=test_rag_constants_preview.TEST_RAG_MANAGED_DB_CONFIG, - ) - - rag_corpus_eq( - rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_RAG_MANAGED_DB - ) - - @pytest.mark.usefixtures("create_rag_corpus_mock_rag_managed_db_knn") - def test_create_corpus_rag_managed_db_knn_success(self): - rag_corpus = rag.create_corpus( - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - vector_db=test_rag_constants_preview.TEST_RAG_MANAGED_DB_KNN_CONFIG, - ) - - rag_corpus_eq( - rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_RAG_MANAGED_DB_KNN - ) - - @pytest.mark.usefixtures("create_rag_corpus_mock_rag_managed_db_ann") - def test_create_corpus_rag_managed_db_ann_success(self): - rag_corpus = rag.create_corpus( - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - 
vector_db=test_rag_constants_preview.TEST_RAG_MANAGED_DB_ANN_CONFIG, - ) - - rag_corpus_eq( - rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_RAG_MANAGED_DB_ANN - ) - @pytest.mark.usefixtures("create_rag_corpus_mock_pinecone_backend") def test_create_corpus_pinecone_backend_success(self): rag_corpus = rag.create_corpus( @@ -1238,32 +1160,6 @@ def test_create_corpus_rag_managed_db_ann_backend_success(self): test_rag_constants_preview.TEST_RAG_CORPUS_RAG_MANAGED_DB_ANN_BACKEND, ) - def test_create_corpus_backend_config_with_embedding_model_config_failure( - self, - ): - with pytest.raises(ValueError) as e: - rag.create_corpus( - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - backend_config=test_rag_constants_preview.TEST_BACKEND_CONFIG_EMBEDDING_MODEL_CONFIG, - embedding_model_config=test_rag_constants_preview.TEST_EMBEDDING_MODEL_CONFIG, - ) - e.match( - "Only one of backend_config or embedding_model_config and vector_db can be set. embedding_model_config and vector_db are deprecated, use backend_config instead." - ) - - def test_create_corpus_backend_config_with_vector_db_failure( - self, - ): - with pytest.raises(ValueError) as e: - rag.create_corpus( - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - backend_config=test_rag_constants_preview.TEST_BACKEND_CONFIG_EMBEDDING_MODEL_CONFIG, - vector_db=test_rag_constants_preview.TEST_PINECONE_CONFIG, - ) - e.match( - "Only one of backend_config or embedding_model_config and vector_db can be set. embedding_model_config and vector_db are deprecated, use backend_config instead." 
- ) - @pytest.mark.usefixtures("create_rag_corpus_mock_vertex_ai_engine_search_config") def test_create_corpus_vais_engine_search_config_success(self): rag_corpus = rag.create_corpus( @@ -1288,15 +1184,6 @@ def test_create_corpus_vais_datastore_search_config_success(self): test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_AI_DATASTORE_SEARCH_CONFIG, ) - def test_create_corpus_vais_datastore_search_config_with_vector_db_failure(self): - with pytest.raises(ValueError) as e: - rag.create_corpus( - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - vertex_ai_search_config=test_rag_constants_preview.TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE, - vector_db=test_rag_constants_preview.TEST_VERTEX_VECTOR_SEARCH_CONFIG, - ) - e.match("Only one of vertex_ai_search_config or vector_db can be set.") - def test_create_corpus_vais_datastore_search_config_with_backend_config_failure( self, ): @@ -1308,19 +1195,6 @@ def test_create_corpus_vais_datastore_search_config_with_backend_config_failure( ) e.match("Only one of vertex_ai_search_config or backend_config can be set.") - def test_create_corpus_vais_datastore_search_config_with_embedding_model_config_failure( - self, - ): - with pytest.raises(ValueError) as e: - rag.create_corpus( - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - vertex_ai_search_config=test_rag_constants_preview.TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE, - embedding_model_config=test_rag_constants_preview.TEST_EMBEDDING_MODEL_CONFIG, - ) - e.match( - "Only one of vertex_ai_search_config or embedding_model_config can be set." 
- ) - def test_set_vertex_ai_search_config_with_invalid_serving_config_failure(self): with pytest.raises(ValueError) as e: rag.create_corpus( @@ -1369,110 +1243,6 @@ def test_create_document_corpus_success(self): rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_DOCUMENT_CORPUS) - @pytest.mark.usefixtures("update_rag_corpus_mock_weaviate") - def test_update_corpus_weaviate_success(self): - rag_corpus = rag.update_corpus( - corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - vector_db=test_rag_constants_preview.TEST_WEAVIATE_CONFIG, - ) - - rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_WEAVIATE) - - @pytest.mark.usefixtures("update_rag_corpus_mock_weaviate") - def test_update_corpus_weaviate_no_display_name_success(self): - rag_corpus = rag.update_corpus( - corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, - vector_db=test_rag_constants_preview.TEST_WEAVIATE_CONFIG, - ) - - rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_WEAVIATE) - - @pytest.mark.usefixtures("update_rag_corpus_mock_weaviate") - def test_update_corpus_weaviate_with_description_success(self): - rag_corpus = rag.update_corpus( - corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, - description=test_rag_constants_preview.TEST_CORPUS_DISCRIPTION, - vector_db=test_rag_constants_preview.TEST_WEAVIATE_CONFIG, - ) - - rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_WEAVIATE) - - @pytest.mark.usefixtures("update_rag_corpus_mock_weaviate") - def test_update_corpus_weaviate_with_description_and_display_name_success(self): - rag_corpus = rag.update_corpus( - corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, - description=test_rag_constants_preview.TEST_CORPUS_DISCRIPTION, - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - 
vector_db=test_rag_constants_preview.TEST_WEAVIATE_CONFIG, - ) - - rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_WEAVIATE) - - @pytest.mark.usefixtures("update_rag_corpus_mock_vertex_feature_store") - def test_update_corpus_vertex_feature_store_success(self): - rag_corpus = rag.update_corpus( - corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - vector_db=test_rag_constants_preview.TEST_VERTEX_FEATURE_STORE_CONFIG, - ) - - rag_corpus_eq( - rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_FEATURE_STORE - ) - - @pytest.mark.usefixtures("update_rag_corpus_mock_vertex_vector_search") - def test_update_corpus_vertex_vector_search_success(self): - rag_corpus = rag.update_corpus( - corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - vector_db=test_rag_constants_preview.TEST_VERTEX_VECTOR_SEARCH_CONFIG, - ) - rag_corpus_eq( - rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_VECTOR_SEARCH - ) - - @pytest.mark.usefixtures("update_rag_corpus_mock_pinecone") - def test_update_corpus_pinecone_success(self): - rag_corpus = rag.update_corpus( - corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - vector_db=test_rag_constants_preview.TEST_PINECONE_CONFIG, - ) - rag_corpus_eq(rag_corpus, test_rag_constants_preview.TEST_RAG_CORPUS_PINECONE) - - @pytest.mark.usefixtures("rag_data_client_preview_mock_exception") - def test_update_corpus_failure(self): - with pytest.raises(RuntimeError) as e: - rag.update_corpus( - corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - ) - e.match("Failed in RagCorpus update due to") - - 
@pytest.mark.usefixtures("update_rag_corpus_mock_vertex_ai_engine_search_config") - def test_update_corpus_vais_engine_search_config_success(self): - rag_corpus = rag.update_corpus( - corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - vertex_ai_search_config=test_rag_constants_preview.TEST_VERTEX_AI_SEARCH_CONFIG_ENGINE, - ) - - rag_corpus_eq( - rag_corpus, - test_rag_constants_preview.TEST_RAG_CORPUS_VERTEX_AI_ENGINE_SEARCH_CONFIG, - ) - - def test_update_corpus_vais_datastore_search_config_with_vector_db_failure(self): - with pytest.raises(ValueError) as e: - rag.update_corpus( - corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, - display_name=test_rag_constants_preview.TEST_CORPUS_DISPLAY_NAME, - vertex_ai_search_config=test_rag_constants_preview.TEST_VERTEX_AI_SEARCH_CONFIG_DATASTORE, - vector_db=test_rag_constants_preview.TEST_VERTEX_VECTOR_SEARCH_CONFIG, - ) - e.match("Only one of vertex_ai_search_config or vector_db can be set.") - @pytest.mark.usefixtures("rag_data_client_preview_mock") def test_get_corpus_success(self): rag_corpus = rag.get_corpus( @@ -1888,35 +1658,6 @@ def test_prepare_import_files_request_invalid_layout_parser_name(self): ) e.match("processor_name must be of the format") - def test_advanced_pdf_parsing_and_layout_parser_both_set_error(self): - with pytest.raises(ValueError) as e: - rag.import_files( - corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, - paths=[test_rag_constants_preview.TEST_DRIVE_FOLDER], - transformation_config=create_transformation_config(), - use_advanced_pdf_parsing=True, - layout_parser=test_rag_constants_preview.TEST_LAYOUT_PARSER_WITH_PROCESSOR_PATH_CONFIG, - ) - e.match( - "Only one of use_advanced_pdf_parsing or layout_parser may be " - "passed in at a time" - ) - - @pytest.mark.asyncio - async def test_advanced_pdf_parsing_and_layout_parser_both_set_error_async(self): - with 
pytest.raises(ValueError) as e: - await rag.import_files_async( - corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, - paths=[test_rag_constants_preview.TEST_DRIVE_FOLDER], - transformation_config=create_transformation_config(), - use_advanced_pdf_parsing=True, - layout_parser=test_rag_constants_preview.TEST_LAYOUT_PARSER_WITH_PROCESSOR_PATH_CONFIG, - ) - e.match( - "Only one of use_advanced_pdf_parsing or layout_parser may be " - "passed in at a time" - ) - def test_prepare_import_files_request_llm_parser(self): request = prepare_import_files_request( corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, @@ -1929,20 +1670,6 @@ def test_prepare_import_files_request_llm_parser(self): test_rag_constants_preview.TEST_IMPORT_REQUEST_LLM_PARSER, ) - def test_advanced_pdf_parsing_and_llm_parser_both_set_error(self): - with pytest.raises(ValueError) as e: - rag.import_files( - corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, - paths=[test_rag_constants_preview.TEST_DRIVE_FOLDER], - transformation_config=create_transformation_config(), - use_advanced_pdf_parsing=True, - llm_parser=test_rag_constants_preview.TEST_LLM_PARSER_CONFIG, - ) - e.match( - "Only one of use_advanced_pdf_parsing or llm_parser may be " - "passed in at a time" - ) - def test_layout_parser_and_llm_parser_both_set_error(self): with pytest.raises(ValueError) as e: rag.import_files( @@ -1954,21 +1681,6 @@ def test_layout_parser_and_llm_parser_both_set_error(self): ) e.match("Only one of layout_parser or llm_parser may be passed in at a time") - @pytest.mark.asyncio - async def test_advanced_pdf_parsing_and_llm_parser_both_set_error_async(self): - with pytest.raises(ValueError) as e: - await rag.import_files_async( - corpus_name=test_rag_constants_preview.TEST_RAG_CORPUS_RESOURCE_NAME, - paths=[test_rag_constants_preview.TEST_DRIVE_FOLDER], - transformation_config=create_transformation_config(), - use_advanced_pdf_parsing=True, - 
llm_parser=test_rag_constants_preview.TEST_LLM_PARSER_CONFIG, - ) - e.match( - "Only one of use_advanced_pdf_parsing or llm_parser may be " - "passed in at a time" - ) - @pytest.mark.asyncio async def test_layout_parser_and_llm_parser_both_set_error_async(self): with pytest.raises(ValueError) as e: @@ -2155,18 +1867,6 @@ def test_update_rag_engine_config_success( test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_BASIC, ) - def test_update_rag_engine_config_enterprise_success( - self, update_rag_engine_config_enterprise_mock - ): - rag_config = rag.update_rag_engine_config( - rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_ENTERPRISE, - ) - assert update_rag_engine_config_enterprise_mock.call_count == 1 - rag_engine_config_eq( - rag_config, - test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_ENTERPRISE, - ) - def test_update_rag_engine_config_scaled_success( self, update_rag_engine_config_scaled_mock ): @@ -2251,21 +1951,6 @@ def test_update_rag_engine_config_serverless_success( test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_SERVERLESS, ) - def test_update_rag_engine_config_with_mode_and_tier_failure(self): - with pytest.raises(ValueError) as e: - rag.update_rag_engine_config( - rag_engine_config=test_rag_constants_preview.TEST_BAD_RAG_ENGINE_CONFIG_WITH_MODE_AND_TIER, - ) - e.match("mode and tier both cannot be set at the same time") - - @pytest.mark.usefixtures("update_rag_engine_config_mock_exception") - def test_update_rag_engine_config_failure(self): - with pytest.raises(RuntimeError) as e: - rag.update_rag_engine_config( - rag_engine_config=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_ENTERPRISE, - ) - e.match("Failed in RagEngineConfig update due to") - @pytest.mark.usefixtures("update_rag_engine_config_basic_mock") def test_update_rag_engine_config_bad_input( self, update_rag_engine_config_basic_mock @@ -2328,15 +2013,6 @@ def test_get_rag_engine_config_success(self): rag_config, 
test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_BASIC ) - @pytest.mark.usefixtures("get_rag_engine_enterprise_config_mock") - def test_get_rag_engine_config_enterprise_success(self): - rag_config = rag.get_rag_engine_config( - name=test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_RESOURCE_NAME, - ) - rag_engine_config_eq( - rag_config, test_rag_constants_preview.TEST_RAG_ENGINE_CONFIG_ENTERPRISE - ) - @pytest.mark.usefixtures("get_rag_engine_scaled_config_mock") def test_get_rag_engine_config_scaled_success(self): rag_config = rag.get_rag_engine_config( diff --git a/tests/unit/vertex_rag/test_rag_retrieval.py b/tests/unit/vertex_rag/test_rag_retrieval.py index 7e72540b3f..5bdde36435 100644 --- a/tests/unit/vertex_rag/test_rag_retrieval.py +++ b/tests/unit/vertex_rag/test_rag_retrieval.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,11 +15,12 @@ """Tests for vertex_rag.retrieval.""" import importlib + from google.cloud import aiplatform +from agentplatform import rag from google.cloud.aiplatform_v1 import VertexRagServiceAsyncClient from google.cloud.aiplatform_v1 import VertexRagServiceClient import test_rag_constants as tc -from vertexai import rag import mock import pytest diff --git a/tests/unit/vertex_rag/test_rag_retrieval_preview.py b/tests/unit/vertex_rag/test_rag_retrieval_preview.py index b233481532..18f56e5ce7 100644 --- a/tests/unit/vertex_rag/test_rag_retrieval_preview.py +++ b/tests/unit/vertex_rag/test_rag_retrieval_preview.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -21,7 +19,7 @@ from google.cloud.aiplatform_v1beta1 import VertexRagServiceAsyncClient from google.cloud.aiplatform_v1beta1 import VertexRagServiceClient import test_rag_constants_preview -from vertexai.preview import rag +from agentplatform.preview import rag import mock import pytest @@ -99,15 +97,11 @@ def teardown_method(self): @pytest.mark.usefixtures("retrieve_contexts_mock") def test_retrieval_query_rag_resources_success(self): - with pytest.warns(DeprecationWarning): - response = rag.retrieval_query( - rag_resources=[tc.TEST_RAG_RESOURCE], - text=tc.TEST_QUERY_TEXT, - similarity_top_k=2, - vector_distance_threshold=0.5, - vector_search_alpha=0.5, - ) - retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE) + response = rag.retrieval_query( + rag_resources=[tc.TEST_RAG_RESOURCE], + text=tc.TEST_QUERY_TEXT, + ) + retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE) @pytest.mark.usefixtures("ask_contexts_mock") def test_ask_contexts_rag_resources_success(self): @@ -139,16 +133,6 @@ def test_ask_contexts_multiple_rag_resources_success(self): ) retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE) - @pytest.mark.usefixtures("ask_contexts_mock") - def test_ask_contexts_multiple_rag_corpora_success(self): - with pytest.warns(DeprecationWarning): - response = rag.ask_contexts( - rag_corpora=[tc.TEST_RAG_CORPUS_ID, tc.TEST_RAG_CORPUS_ID], - text=tc.TEST_QUERY_TEXT, - rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG_ALPHA, - ) - retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE) - @pytest.mark.asyncio async def test_async_retrieve_contexts_rag_resources_success( self, async_retrieve_contexts_mock @@ -185,18 +169,6 @@ async def test_async_retrieve_contexts_multiple_rag_resources_success( ) retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE) - @pytest.mark.asyncio - async def test_async_retrieve_contexts_multiple_rag_corpora_success( - self, async_retrieve_contexts_mock - ): - with pytest.warns(DeprecationWarning): - response = await 
rag.async_retrieve_contexts( - rag_corpora=[tc.TEST_RAG_CORPUS_ID, tc.TEST_RAG_CORPUS_ID], - text=tc.TEST_QUERY_TEXT, - rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG_ALPHA, - ) - retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE) - @pytest.mark.usefixtures("retrieve_contexts_mock") def test_retrieval_query_rag_resources_config_success(self): response = rag.retrieval_query( @@ -223,73 +195,12 @@ def test_retrieval_query_rag_resources_default_config_success(self): ) retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE) - @pytest.mark.usefixtures("retrieve_contexts_mock") - def test_retrieval_query_rag_corpora_success(self): - with pytest.warns(DeprecationWarning): - response = rag.retrieval_query( - rag_corpora=[tc.TEST_RAG_CORPUS_ID], - text=tc.TEST_QUERY_TEXT, - similarity_top_k=2, - vector_distance_threshold=0.5, - ) - retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE) - - @pytest.mark.usefixtures("retrieve_contexts_mock") - def test_retrieval_query_rag_corpora_config_success(self): - response = rag.retrieval_query( - rag_corpora=[tc.TEST_RAG_CORPUS_ID], - text=tc.TEST_QUERY_TEXT, - rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG, - ) - retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE) - - @pytest.mark.usefixtures("retrieve_contexts_mock") - def test_retrieval_query_rag_corpora_config_rank_service_success(self): - response = rag.retrieval_query( - rag_corpora=[tc.TEST_RAG_CORPUS_ID], - text=tc.TEST_QUERY_TEXT, - rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG_RANK_SERVICE, - ) - retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE) - - @pytest.mark.usefixtures("retrieve_contexts_mock") - def test_retrieval_query_with_metadata_filter(self, retrieve_contexts_mock): - metadata_filter = 'doc.metadata.genre == "fiction"' - rag_retrieval_config = rag.RagRetrievalConfig( - top_k=10, - filter=rag.Filter( - vector_distance_threshold=0.5, metadata_filter=metadata_filter - ), - ) - rag.retrieval_query( - 
rag_resources=[tc.TEST_RAG_RESOURCE], - text=tc.TEST_QUERY_TEXT, - rag_retrieval_config=rag_retrieval_config, - ) - retrieve_contexts_mock.assert_called_once() - args, kwargs = retrieve_contexts_mock.call_args - request = kwargs["request"] - assert ( - request.query.rag_retrieval_config.filter.metadata_filter == metadata_filter - ) - - @pytest.mark.usefixtures("retrieve_contexts_mock") - def test_retrieval_query_rag_corpora_config_llm_ranker_success(self): - response = rag.retrieval_query( - rag_corpora=[tc.TEST_RAG_CORPUS_ID], - text=tc.TEST_QUERY_TEXT, - rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG_LLM_RANKER, - ) - retrieve_contexts_eq(response, tc.TEST_RETRIEVAL_RESPONSE) - @pytest.mark.usefixtures("rag_client_mock_exception") def test_retrieval_query_failure(self): with pytest.raises(RuntimeError) as e: rag.retrieval_query( rag_resources=[tc.TEST_RAG_RESOURCE], text=tc.TEST_QUERY_TEXT, - similarity_top_k=2, - vector_distance_threshold=0.5, ) e.match("Failed in retrieving contexts due to") @@ -308,8 +219,6 @@ def test_retrieval_query_invalid_name(self): rag.retrieval_query( rag_resources=[tc.TEST_RAG_RESOURCE_INVALID_NAME], text=tc.TEST_QUERY_TEXT, - similarity_top_k=2, - vector_distance_threshold=0.5, ) e.match("Invalid RagCorpus name") @@ -322,31 +231,6 @@ def test_retrieval_query_invalid_name_config(self): ) e.match("Invalid RagCorpus name") - def test_retrieval_query_multiple_rag_corpora(self): - with pytest.raises(ValueError) as e: - rag.retrieval_query( - rag_corpora=[ - tc.TEST_RAG_CORPUS_ID, - tc.TEST_RAG_CORPUS_ID, - ], - text=tc.TEST_QUERY_TEXT, - similarity_top_k=2, - vector_distance_threshold=0.5, - ) - e.match("Currently only support 1 RagCorpus") - - def test_retrieval_query_multiple_rag_corpora_config(self): - with pytest.raises(ValueError) as e: - rag.retrieval_query( - rag_corpora=[ - tc.TEST_RAG_CORPUS_ID, - tc.TEST_RAG_CORPUS_ID, - ], - text=tc.TEST_QUERY_TEXT, - rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG, - ) - e.match("Currently 
only support 1 RagCorpus") - def test_retrieval_query_multiple_rag_resources(self): with pytest.raises(ValueError) as e: rag.retrieval_query( @@ -355,8 +239,6 @@ def test_retrieval_query_multiple_rag_resources(self): tc.TEST_RAG_RESOURCE, ], text=tc.TEST_QUERY_TEXT, - similarity_top_k=2, - vector_distance_threshold=0.5, ) e.match("Currently only support 1 RagResource") diff --git a/tests/unit/vertex_rag/test_rag_store.py b/tests/unit/vertex_rag/test_rag_store.py index 91a169bfb6..17dfd99a78 100644 --- a/tests/unit/vertex_rag/test_rag_store.py +++ b/tests/unit/vertex_rag/test_rag_store.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from vertexai import rag -from vertexai.preview.generative_models import Tool +from agentplatform import rag import pytest import test_rag_constants as tc @@ -23,100 +20,74 @@ @pytest.mark.usefixtures("google_auth_mock") class TestRagStoreValidations: def test_retrieval_tool_success(self): - tool = Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_resources=[tc.TEST_RAG_RESOURCE], - rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG, - ), - ) + retrieval = rag.Retrieval( + source=rag.VertexRagStore( + rag_resources=[tc.TEST_RAG_RESOURCE], + rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG, + ), ) - assert tool is not None + assert retrieval is not None def test_retrieval_tool_vector_similarity_success(self): - tool = Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_resources=[tc.TEST_RAG_RESOURCE], - rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG, - ), - ) + retrieval = rag.Retrieval( + source=rag.VertexRagStore( + 
rag_resources=[tc.TEST_RAG_RESOURCE], + rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG, + ), ) - assert tool is not None + assert retrieval is not None def test_retrieval_tool_no_rag_resources(self): with pytest.raises(ValueError) as e: - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG, - ), - ) + rag.VertexRagStore( + rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG, ) - e.match("rag_resources must be specified.") + e.match("rag_resources must be specified.") def test_retrieval_tool_ranking_config_success(self): - tool = Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_resources=[tc.TEST_RAG_RESOURCE], - rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_RANKING_CONFIG, - ), - ) + retrieval = rag.Retrieval( + source=rag.VertexRagStore( + rag_resources=[tc.TEST_RAG_RESOURCE], + rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_RANKING_CONFIG, + ), ) - assert tool is not None + assert retrieval is not None def test_retrieval_tool_invalid_name(self): with pytest.raises(ValueError) as e: - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_resources=[tc.TEST_RAG_RESOURCE_INVALID_NAME], - rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG, - ), - ) + rag.VertexRagStore( + rag_resources=[tc.TEST_RAG_RESOURCE_INVALID_NAME], + rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG, ) - e.match("Invalid RagCorpus name") + e.match("Invalid RagCorpus name") def test_retrieval_tool_multiple_rag_resources(self): with pytest.raises(ValueError) as e: - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_resources=[tc.TEST_RAG_RESOURCE, tc.TEST_RAG_RESOURCE], - rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG, - ), - ) + rag.VertexRagStore( + rag_resources=[tc.TEST_RAG_RESOURCE, tc.TEST_RAG_RESOURCE], + rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_CONFIG, ) - 
e.match("Currently only support 1 RagResource") + e.match("Currently only support 1 RagResource") def test_retrieval_tool_invalid_config_filter(self): with pytest.raises(ValueError) as e: - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_resources=[tc.TEST_RAG_RESOURCE], - rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_ERROR_CONFIG, - ) - ) - ) - e.match( - "Only one of vector_distance_threshold or" - " vector_similarity_threshold can be specified at a time" - " in rag_retrieval_config." + rag.VertexRagStore( + rag_resources=[tc.TEST_RAG_RESOURCE], + rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_ERROR_CONFIG, ) + e.match( + "Only one of vector_distance_threshold or" + " vector_similarity_threshold can be specified at a time" + " in rag_retrieval_config." + ) def test_retrieval_tool_invalid_ranking_config_filter(self): with pytest.raises(ValueError) as e: - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_resources=[tc.TEST_RAG_RESOURCE], - rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_ERROR_RANKING_CONFIG, - ) - ) - ) - e.match( - "Only one of rank_service or llm_ranker can be specified" - " at a time in rag_retrieval_config." + rag.VertexRagStore( + rag_resources=[tc.TEST_RAG_RESOURCE], + rag_retrieval_config=tc.TEST_RAG_RETRIEVAL_ERROR_RANKING_CONFIG, ) + e.match( + "Only one of rank_service or llm_ranker can be specified" + " at a time in rag_retrieval_config." + ) diff --git a/tests/unit/vertex_rag/test_rag_store_preview.py b/tests/unit/vertex_rag/test_rag_store_preview.py index 0529cfafa8..35ab62e259 100644 --- a/tests/unit/vertex_rag/test_rag_store_preview.py +++ b/tests/unit/vertex_rag/test_rag_store_preview.py @@ -1,6 +1,4 @@ -# -*- coding: utf-8 -*- - -# Copyright 2024 Google LLC +# Copyright 2026 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -14,208 +12,88 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from vertexai.preview import rag -from vertexai.preview.generative_models import Tool +from agentplatform.preview import rag import pytest import test_rag_constants_preview @pytest.mark.usefixtures("google_auth_mock") class TestRagStore: - def test_retrieval_tool_success(self): - with pytest.warns(DeprecationWarning): - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE], - similarity_top_k=3, - vector_distance_threshold=0.4, - ), - ) - ) - - def test_retrieval_tool_config_success(self): - with pytest.warns(DeprecationWarning): - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_corpora=[ - test_rag_constants_preview.TEST_RAG_CORPUS_ID, - ], - rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG, - ), - ) - ) - - def test_retrieval_tool_similarity_config_success(self): - with pytest.warns(DeprecationWarning): - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_corpora=[ - test_rag_constants_preview.TEST_RAG_CORPUS_ID, - ], - rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG, - ), - ) - ) - - def test_retrieval_tool_ranking_config_success(self): - with pytest.warns(DeprecationWarning): - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_corpora=[ - test_rag_constants_preview.TEST_RAG_CORPUS_ID, - ], - rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_RANKING_CONFIG, - ), - ) - ) - def test_empty_retrieval_tool_success(self): - tool = Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE], - rag_retrieval_config=rag.RagRetrievalConfig(), - similarity_top_k=3, - vector_distance_threshold=0.4, - ), - ) + 
retrieval = rag.Retrieval( + source=rag.VertexRagStore( + rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE], + rag_retrieval_config=rag.RagRetrievalConfig(), + ), ) - assert tool is not None + assert retrieval is not None def test_retrieval_tool_no_rag_resources(self): with pytest.raises(ValueError) as e: - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG, - ), - ) + rag.VertexRagStore( + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_SIMILARITY_CONFIG, ) - e.match("rag_resources or rag_corpora must be specified.") + e.match("rag_resources must be specified.") def test_retrieval_tool_invalid_name(self): with pytest.raises(ValueError) as e: - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_resources=[ - test_rag_constants_preview.TEST_RAG_RESOURCE_INVALID_NAME - ], - similarity_top_k=3, - vector_distance_threshold=0.4, - ), - ) + rag.VertexRagStore( + rag_resources=[ + test_rag_constants_preview.TEST_RAG_RESOURCE_INVALID_NAME + ], ) - e.match("Invalid RagCorpus name") + e.match("Invalid RagCorpus name") def test_retrieval_tool_invalid_name_config(self): with pytest.raises(ValueError) as e: - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_resources=[ - test_rag_constants_preview.TEST_RAG_RESOURCE_INVALID_NAME - ], - rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG, - ), - ) + rag.VertexRagStore( + rag_resources=[ + test_rag_constants_preview.TEST_RAG_RESOURCE_INVALID_NAME + ], + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG, ) - e.match("Invalid RagCorpus name") - - def test_retrieval_tool_multiple_rag_corpora(self): - with pytest.raises(ValueError) as e: - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_corpora=[ - test_rag_constants_preview.TEST_RAG_CORPUS_ID, - 
test_rag_constants_preview.TEST_RAG_CORPUS_ID, - ], - similarity_top_k=3, - vector_distance_threshold=0.4, - ), - ) - ) - e.match("Currently only support 1 RagCorpus") - - def test_retrieval_tool_multiple_rag_corpora_config(self): - with pytest.raises(ValueError) as e: - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_corpora=[ - test_rag_constants_preview.TEST_RAG_CORPUS_ID, - test_rag_constants_preview.TEST_RAG_CORPUS_ID, - ], - rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG, - ), - ) - ) - e.match("Currently only support 1 RagCorpus") + e.match("Invalid RagCorpus name") def test_retrieval_tool_multiple_rag_resources(self): with pytest.raises(ValueError) as e: - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_resources=[ - test_rag_constants_preview.TEST_RAG_RESOURCE, - test_rag_constants_preview.TEST_RAG_RESOURCE, - ], - similarity_top_k=3, - vector_distance_threshold=0.4, - ), - ) + rag.VertexRagStore( + rag_resources=[ + test_rag_constants_preview.TEST_RAG_RESOURCE, + test_rag_constants_preview.TEST_RAG_RESOURCE, + ], ) - e.match("Currently only support 1 RagResource") + e.match("Currently only support 1 RagResource") def test_retrieval_tool_multiple_rag_resources_config(self): with pytest.raises(ValueError) as e: - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_resources=[ - test_rag_constants_preview.TEST_RAG_RESOURCE, - test_rag_constants_preview.TEST_RAG_RESOURCE, - ], - rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG, - ), - ) + rag.VertexRagStore( + rag_resources=[ + test_rag_constants_preview.TEST_RAG_RESOURCE, + test_rag_constants_preview.TEST_RAG_RESOURCE, + ], + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_CONFIG, ) - e.match("Currently only support 1 RagResource") + e.match("Currently only support 1 RagResource") def test_retrieval_tool_invalid_config_filter(self): with 
pytest.raises(ValueError) as e: - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE], - rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_ERROR_CONFIG, - ) - ) - ) - e.match( - "Only one of vector_distance_threshold or" - " vector_similarity_threshold can be specified at a time" - " in rag_retrieval_config." - ) + rag.VertexRagStore( + rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE], + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_ERROR_CONFIG, + ) + e.match( + "Only one of vector_distance_threshold or" + " vector_similarity_threshold can be specified at a time" + " in rag_retrieval_config." + ) def test_retrieval_tool_invalid_ranking_config_filter(self): with pytest.raises(ValueError) as e: - Tool.from_retrieval( - retrieval=rag.Retrieval( - source=rag.VertexRagStore( - rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE], - rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_ERROR_RANKING_CONFIG, - ) - ) - ) - e.match( - "Only one of vector_distance_threshold or" - " vector_similarity_threshold can be specified at a time" - " in rag_retrieval_config." + rag.VertexRagStore( + rag_resources=[test_rag_constants_preview.TEST_RAG_RESOURCE], + rag_retrieval_config=test_rag_constants_preview.TEST_RAG_RETRIEVAL_ERROR_RANKING_CONFIG, ) + e.match( + "Only one of rank_service or llm_ranker can be specified" + " at a time in rag_retrieval_config." + )