From 99877b7b6a9569da632dfd325a680375448a6c71 Mon Sep 17 00:00:00 2001 From: Nishar Miya <103556082+miyannishar@users.noreply.github.com> Date: Thu, 11 Dec 2025 02:13:55 +0600 Subject: [PATCH 1/4] added s3 artifact --- .../samples/s3_artifact_example/__init__.py | 16 + .../samples/s3_artifact_example/agent.py | 68 ++ pyproject.toml | 5 +- src/google/adk_community/__init__.py | 1 + src/google/adk_community/artifacts/README.md | 48 ++ .../adk_community/artifacts/__init__.py | 20 + .../artifacts/s3_artifact_service.py | 612 ++++++++++++++++++ .../artifacts/test_s3_artifact_service.py | 518 +++++++++++++++ 8 files changed, 1286 insertions(+), 2 deletions(-) create mode 100644 contributing/samples/s3_artifact_example/__init__.py create mode 100644 contributing/samples/s3_artifact_example/agent.py create mode 100644 src/google/adk_community/artifacts/README.md create mode 100644 src/google/adk_community/artifacts/__init__.py create mode 100644 src/google/adk_community/artifacts/s3_artifact_service.py create mode 100644 tests/unittests/artifacts/test_s3_artifact_service.py diff --git a/contributing/samples/s3_artifact_example/__init__.py b/contributing/samples/s3_artifact_example/__init__.py new file mode 100644 index 0000000..8ce90a2 --- /dev/null +++ b/contributing/samples/s3_artifact_example/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent + diff --git a/contributing/samples/s3_artifact_example/agent.py b/contributing/samples/s3_artifact_example/agent.py new file mode 100644 index 0000000..f0e166d --- /dev/null +++ b/contributing/samples/s3_artifact_example/agent.py @@ -0,0 +1,68 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example agent demonstrating S3 artifact storage. + +This example shows how to configure an ADK agent to use Amazon S3 for +artifact storage using the community S3ArtifactService. + +Before running: +1. Install: pip install google-adk-community boto3 +2. Set AWS credentials (see README.md) +3. Create S3 bucket +4. 
Update bucket name below or set ADK_S3_BUCKET environment variable
+"""
+from __future__ import annotations
+
+import os
+
+from google.adk import Agent
+from google.adk.apps import App
+from google.adk_community.artifacts import S3ArtifactService
+
+# Get bucket name from environment or use default
+BUCKET_NAME = os.getenv("ADK_S3_BUCKET", "my-adk-artifacts")
+AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
+
+# Initialize S3 artifact service
+artifact_service = S3ArtifactService(
+    bucket_name=BUCKET_NAME,
+    region_name=AWS_REGION,
+)
+
+# Define the agent
+root_agent = Agent(
+    name="s3_artifact_agent",
+    model="gemini-2.0-flash",
+    instruction="""You are a helpful assistant that can save and retrieve files.
+
+When users ask you to save information, use the save_artifact tool to store
+it in S3. When they ask for previously saved information, use the load_artifact
+tool to retrieve it.
+
+Examples:
+- "Save this report to a file called quarterly_report.pdf"
+- "Load the file called quarterly_report.pdf"
+- "List all my saved files"
+""",
+    description="An assistant that demonstrates S3 artifact storage",
+)
+
+# Create app with S3 artifact service
+app = App(
+    name="s3_artifact_example",
+    root_agent=root_agent,
+    artifact_service=artifact_service,
+)
+
diff --git a/pyproject.toml b/pyproject.toml
index 11afcd8..0d6c566 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,12 +25,13 @@ classifiers = [ # List of https://pypi.org/classifiers/
 ]
 dependencies = [
   # go/keep-sorted start
-  "google-genai>=1.21.1, <2.0.0", # Google GenAI SDK
+  "boto3>=1.28.0", # For S3ArtifactService
   "google-adk", # Google ADK
+  "google-genai>=1.21.1, <2.0.0", # Google GenAI SDK
   "httpx>=0.27.0, <1.0.0", # For OpenMemory service
+  "orjson>=3.11.3",
   "redis>=5.0.0, <6.0.0", # Redis for session storage
   # go/keep-sorted end
-  "orjson>=3.11.3",
 ]
 
 dynamic = ["version"]
diff --git a/src/google/adk_community/__init__.py b/src/google/adk_community/__init__.py
index 9a1dc35..803823d 100644
--- a/src/google/adk_community/__init__.py
+++ b/src/google/adk_community/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from . import artifacts
 from . import memory
 from . import sessions
 from . import version
diff --git a/src/google/adk_community/artifacts/README.md b/src/google/adk_community/artifacts/README.md
new file mode 100644
index 0000000..7bfa8a1
--- /dev/null
+++ b/src/google/adk_community/artifacts/README.md
@@ -0,0 +1,64 @@
+# Community Artifact Services
+
+This module contains community-contributed artifact service implementations for ADK.
+
+## Available Services
+
+### S3ArtifactService
+
+Production-ready artifact storage using Amazon S3.
+
+**Installation:**
+```bash
+pip install google-adk-community boto3
+```
+
+**Usage:**
+```python
+from google.adk_community.artifacts import S3ArtifactService
+
+artifact_service = S3ArtifactService(
+    bucket_name="my-adk-artifacts",
+    region_name="us-east-1"
+)
+```
+
+**Features:**
+- Session-scoped and user-scoped artifacts
+- Automatic version management
+- Custom metadata support
+- URL encoding for special characters
+- Works with S3-compatible services (MinIO, DigitalOcean Spaces, etc.)
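+
+Because any extra keyword arguments are forwarded verbatim to
+`boto3.client("s3", ...)`, the service can also target an S3-compatible
+endpoint by passing `endpoint_url`. A minimal sketch for a local MinIO
+instance (the URL and credentials below are illustrative placeholders):
+
+```python
+from google.adk_community.artifacts import S3ArtifactService
+
+artifact_service = S3ArtifactService(
+    bucket_name="my-adk-artifacts",
+    endpoint_url="http://localhost:9000",  # assumed local MinIO address
+    aws_access_key_id="minioadmin",  # assumed demo credentials
+    aws_secret_access_key="minioadmin",
+)
+```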
+
+**See Also:**
+- [S3ArtifactService Implementation](./s3_artifact_service.py)
+- [Example Agent](../../../../contributing/samples/s3_artifact_example/)
+- [Tests](../../../../tests/unittests/artifacts/test_s3_artifact_service.py)
+
+## Contributing
+
+Want to add a new artifact service? See our [contribution guide](../../../../CONTRIBUTING.md).
+
+Examples of artifact services to contribute:
+- Azure Blob Storage
+- Google Drive
+- Dropbox
+- MinIO (dedicated implementation)
+- Any S3-compatible service
+
diff --git a/src/google/adk_community/artifacts/__init__.py b/src/google/adk_community/artifacts/__init__.py
new file mode 100644
index 0000000..92b8473
--- /dev/null
+++ b/src/google/adk_community/artifacts/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .s3_artifact_service import S3ArtifactService
+
+__all__ = [
+    'S3ArtifactService',
+]
+
diff --git a/src/google/adk_community/artifacts/s3_artifact_service.py b/src/google/adk_community/artifacts/s3_artifact_service.py
new file mode 100644
index 0000000..3ec7638
--- /dev/null
+++ b/src/google/adk_community/artifacts/s3_artifact_service.py
@@ -0,0 +1,612 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""An artifact service implementation using Amazon S3.
+
+The object key format used depends on whether the filename has a user namespace:
+  - For files with user namespace (starting with "user:"):
+    {app_name}/{user_id}/user/{filename}/{version}
+  - For regular session-scoped files:
+    {app_name}/{user_id}/{session_id}/{filename}/{version}
+"""
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import Any
+from typing import Optional
+from urllib.parse import quote
+from urllib.parse import unquote
+
+from google.adk.artifacts.base_artifact_service import ArtifactVersion
+from google.adk.artifacts.base_artifact_service import BaseArtifactService
+from google.adk.errors.input_validation_error import InputValidationError
+from google.genai import types
+from typing_extensions import override
+
+logger = logging.getLogger("google_adk_community." + __name__)
+
+
+class S3ArtifactService(BaseArtifactService):
+  """An artifact service implementation using Amazon S3."""
+
+  def __init__(
+      self,
+      bucket_name: str,
+      region_name: Optional[str] = None,
+      **kwargs,
+  ):
+    """Initializes the S3ArtifactService.
+
+    Args:
+      bucket_name: The name of the S3 bucket to use.
+      region_name: AWS region name (optional).
+      **kwargs: Additional keyword arguments to pass to boto3.client().
+    """
+    try:
+      import boto3
+    except ImportError as exc:
+      raise ImportError(
+          "boto3 is required to use S3ArtifactService. 
" + "Install it with: pip install boto3" + ) from exc + + self.bucket_name = bucket_name + client_kwargs = dict(kwargs) + if region_name: + client_kwargs["region_name"] = region_name + + self.s3_client = boto3.client("s3", **client_kwargs) + + # Verify bucket access + try: + self.s3_client.head_bucket(Bucket=self.bucket_name) + logger.info("S3ArtifactService initialized with bucket: %s", bucket_name) + except Exception as e: + logger.error("Cannot access S3 bucket '%s': %s", bucket_name, e) + raise + + def _encode_filename(self, filename: str) -> str: + """URL-encode filename to handle special characters. + + Args: + filename: The filename to encode. + + Returns: + The URL-encoded filename. + """ + return quote(filename, safe="") + + def _decode_filename(self, encoded_filename: str) -> str: + """URL-decode filename to restore original filename. + + Args: + encoded_filename: The encoded filename to decode. + + Returns: + The decoded filename. + """ + return unquote(encoded_filename) + + def _file_has_user_namespace(self, filename: str) -> bool: + """Checks if the filename has a user namespace. + + Args: + filename: The filename to check. + + Returns: + True if the filename has a user namespace (starts with "user:"), + False otherwise. + """ + return filename.startswith("user:") + + def _get_object_key_prefix( + self, + app_name: str, + user_id: str, + filename: str, + session_id: Optional[str] = None, + ) -> tuple[str, str]: + """Constructs the S3 object key prefix and encoded filename. + + Args: + app_name: The name of the application. + user_id: The ID of the user. + filename: The name of the artifact file. + session_id: The ID of the session. + + Returns: + A tuple of (prefix, encoded_filename). + """ + if self._file_has_user_namespace(filename): + # Remove "user:" prefix before encoding + actual_filename = filename[5:] # len("user:") == 5 + encoded_filename = self._encode_filename(actual_filename) + return f"{app_name}/{user_id}/user", encoded_filename + + if session_id is None: + raise InputValidationError( + "Session ID must be provided for session-scoped artifacts." + ) + encoded_filename = self._encode_filename(filename) + return f"{app_name}/{user_id}/{session_id}", encoded_filename + + def _get_object_key( + self, + app_name: str, + user_id: str, + filename: str, + version: int, + session_id: Optional[str] = None, + ) -> str: + """Constructs the full S3 object key. + + Args: + app_name: The name of the application. + user_id: The ID of the user. + filename: The name of the artifact file. + version: The version of the artifact. + session_id: The ID of the session. + + Returns: + The constructed S3 object key. 
+ """ + prefix, encoded_filename = self._get_object_key_prefix( + app_name, user_id, filename, session_id + ) + return f"{prefix}/{encoded_filename}/{version}" + + @override + async def save_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + artifact: types.Part, + session_id: Optional[str] = None, + custom_metadata: Optional[dict[str, Any]] = None, + ) -> int: + return await asyncio.to_thread( + self._save_artifact_sync, + app_name, + user_id, + session_id, + filename, + artifact, + custom_metadata, + ) + + @override + async def load_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: Optional[str] = None, + version: Optional[int] = None, + ) -> Optional[types.Part]: + return await asyncio.to_thread( + self._load_artifact_sync, + app_name, + user_id, + session_id, + filename, + version, + ) + + @override + async def list_artifact_keys( + self, *, app_name: str, user_id: str, session_id: Optional[str] = None + ) -> list[str]: + return await asyncio.to_thread( + self._list_artifact_keys_sync, + app_name, + user_id, + session_id, + ) + + @override + async def delete_artifact( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: Optional[str] = None, + ) -> None: + return await asyncio.to_thread( + self._delete_artifact_sync, + app_name, + user_id, + session_id, + filename, + ) + + @override + async def list_versions( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: Optional[str] = None, + ) -> list[int]: + return await asyncio.to_thread( + self._list_versions_sync, + app_name, + user_id, + session_id, + filename, + ) + + def _save_artifact_sync( + self, + app_name: str, + user_id: str, + session_id: Optional[str], + filename: str, + artifact: types.Part, + custom_metadata: Optional[dict[str, Any]], + ) -> int: + """Synchronous implementation of save_artifact.""" + # Get next version number + versions = self._list_versions_sync( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + version = 0 if not versions else max(versions) + 1 + + object_key = self._get_object_key( + app_name, user_id, filename, version, session_id + ) + + # Prepare data and content type + if artifact.inline_data: + data = artifact.inline_data.data + content_type = artifact.inline_data.mime_type or "application/octet-stream" + elif artifact.text: + data = artifact.text.encode("utf-8") + content_type = "text/plain; charset=utf-8" + else: + raise InputValidationError( + "Artifact must have either inline_data or text content." 
+ ) + + # Prepare put_object arguments + put_kwargs: dict[str, Any] = { + "Bucket": self.bucket_name, + "Key": object_key, + "Body": data, + "ContentType": content_type, + } + + # Add custom metadata if provided + if custom_metadata: + put_kwargs["Metadata"] = { + str(k): str(v) for k, v in custom_metadata.items() + } + + try: + self.s3_client.put_object(**put_kwargs) + logger.debug( + "Saved artifact %s version %d to S3 key %s", + filename, + version, + object_key, + ) + return version + except Exception as e: + logger.error("Failed to save artifact '%s' to S3: %s", filename, e) + raise + + def _load_artifact_sync( + self, + app_name: str, + user_id: str, + session_id: Optional[str], + filename: str, + version: Optional[int], + ) -> Optional[types.Part]: + """Synchronous implementation of load_artifact.""" + if version is None: + versions = self._list_versions_sync( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + if not versions: + return None + version = max(versions) + + object_key = self._get_object_key( + app_name, user_id, filename, version, session_id + ) + + try: + response = self.s3_client.get_object( + Bucket=self.bucket_name, Key=object_key + ) + content_type = response.get("ContentType", "application/octet-stream") + data = response["Body"].read() + + if not data: + return None + + artifact = types.Part.from_bytes(data=data, mime_type=content_type) + logger.debug( + "Loaded artifact %s version %d from S3 key %s", + filename, + version, + object_key, + ) + return artifact + + except self.s3_client.exceptions.NoSuchKey: + logger.debug( + "Artifact %s version %d not found in S3", filename, version + ) + return None + except Exception as e: + logger.error("Failed to load artifact '%s' from S3: %s", filename, e) + raise + + def _list_artifact_keys_sync( + self, + app_name: str, + user_id: str, + session_id: Optional[str], + ) -> list[str]: + """Synchronous implementation of list_artifact_keys.""" + filenames: set[str] = set() + + # List session-scoped artifacts + if session_id: + session_prefix = f"{app_name}/{user_id}/{session_id}/" + try: + response = self.s3_client.list_objects_v2( + Bucket=self.bucket_name, Prefix=session_prefix + ) + if "Contents" in response: + for obj in response["Contents"]: + # Parse: {prefix}/{encoded_filename}/{version} + key = obj["Key"] + parts = key[len(session_prefix) :].split("/") + if len(parts) >= 2: + encoded_filename = parts[0] + filename = self._decode_filename(encoded_filename) + filenames.add(filename) + except Exception as e: + logger.error( + "Failed to list session artifacts for %s: %s", session_id, e + ) + + # List user-scoped artifacts + user_prefix = f"{app_name}/{user_id}/user/" + try: + response = self.s3_client.list_objects_v2( + Bucket=self.bucket_name, Prefix=user_prefix + ) + if "Contents" in response: + for obj in response["Contents"]: + # Parse: {prefix}/{encoded_filename}/{version} + key = obj["Key"] + parts = key[len(user_prefix) :].split("/") + if len(parts) >= 2: + encoded_filename = parts[0] + filename = self._decode_filename(encoded_filename) + filenames.add(f"user:{filename}") + except Exception as e: + logger.error("Failed to list user artifacts for %s: %s", user_id, e) + + return sorted(list(filenames)) + + def _delete_artifact_sync( + self, + app_name: str, + user_id: str, + session_id: Optional[str], + filename: str, + ) -> None: + """Synchronous implementation of delete_artifact.""" + versions = self._list_versions_sync( + app_name=app_name, + user_id=user_id, + 
session_id=session_id, + filename=filename, + ) + + for version in versions: + object_key = self._get_object_key( + app_name, user_id, filename, version, session_id + ) + try: + self.s3_client.delete_object(Bucket=self.bucket_name, Key=object_key) + logger.debug("Deleted S3 object: %s", object_key) + except Exception as e: + logger.error("Failed to delete S3 object %s: %s", object_key, e) + + def _list_versions_sync( + self, + app_name: str, + user_id: str, + session_id: Optional[str], + filename: str, + ) -> list[int]: + """Lists all available versions of an artifact. + + This method retrieves all versions of a specific artifact by querying S3 + objects that match the constructed object key prefix. + + Args: + app_name: The name of the application. + user_id: The ID of the user who owns the artifact. + session_id: The ID of the session (ignored for user-namespaced files). + filename: The name of the artifact file. + + Returns: + A list of version numbers (integers) available for the specified + artifact. Returns an empty list if no versions are found. + """ + prefix, encoded_filename = self._get_object_key_prefix( + app_name, user_id, filename, session_id + ) + full_prefix = f"{prefix}/{encoded_filename}/" + + try: + response = self.s3_client.list_objects_v2( + Bucket=self.bucket_name, Prefix=full_prefix + ) + versions: list[int] = [] + if "Contents" in response: + for obj in response["Contents"]: + # Extract version from key: {prefix}/{encoded_filename}/{version} + key = obj["Key"] + version_str = key.split("/")[-1] + if version_str.isdigit(): + versions.append(int(version_str)) + return sorted(versions) + except Exception as e: + logger.error("Failed to list versions for '%s': %s", filename, e) + return [] + + def _get_artifact_version_sync( + self, + app_name: str, + user_id: str, + session_id: Optional[str], + filename: str, + version: Optional[int], + ) -> Optional[ArtifactVersion]: + """Synchronous implementation of get_artifact_version.""" + if version is None: + versions = self._list_versions_sync( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + if not versions: + return None + version = max(versions) + + object_key = self._get_object_key( + app_name, user_id, filename, version, session_id + ) + + try: + response = self.s3_client.head_object( + Bucket=self.bucket_name, Key=object_key + ) + + metadata = response.get("Metadata", {}) or {} + last_modified = response.get("LastModified") + create_time = ( + last_modified.timestamp() + if hasattr(last_modified, "timestamp") + else None + ) + + canonical_uri = f"s3://{self.bucket_name}/{object_key}" + + return ArtifactVersion( + version=version, + canonical_uri=canonical_uri, + custom_metadata={str(k): str(v) for k, v in metadata.items()}, + create_time=create_time, + mime_type=response.get("ContentType"), + ) + except self.s3_client.exceptions.NoSuchKey: + logger.debug( + "Artifact %s version %d not found in S3", filename, version + ) + return None + except Exception as e: + logger.error( + "Failed to get artifact version for '%s' version %d: %s", + filename, + version, + e, + ) + return None + + def _list_artifact_versions_sync( + self, + app_name: str, + user_id: str, + session_id: Optional[str], + filename: str, + ) -> list[ArtifactVersion]: + """Lists all versions and their metadata of an artifact.""" + versions = self._list_versions_sync( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + + artifact_versions: list[ArtifactVersion] = [] + for 
version in versions: + artifact_version = self._get_artifact_version_sync( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + version=version, + ) + if artifact_version: + artifact_versions.append(artifact_version) + + return artifact_versions + + @override + async def list_artifact_versions( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: Optional[str] = None, + ) -> list[ArtifactVersion]: + return await asyncio.to_thread( + self._list_artifact_versions_sync, + app_name, + user_id, + session_id, + filename, + ) + + @override + async def get_artifact_version( + self, + *, + app_name: str, + user_id: str, + filename: str, + session_id: Optional[str] = None, + version: Optional[int] = None, + ) -> Optional[ArtifactVersion]: + return await asyncio.to_thread( + self._get_artifact_version_sync, + app_name, + user_id, + session_id, + filename, + version, + ) + diff --git a/tests/unittests/artifacts/test_s3_artifact_service.py b/tests/unittests/artifacts/test_s3_artifact_service.py new file mode 100644 index 0000000..0d9f605 --- /dev/null +++ b/tests/unittests/artifacts/test_s3_artifact_service.py @@ -0,0 +1,518 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=missing-class-docstring,missing-function-docstring + +"""Tests for S3ArtifactService.""" + +from datetime import datetime +from typing import Any +from typing import Optional +from unittest import mock + +from google.adk.artifacts.base_artifact_service import ArtifactVersion +from google.adk_community.artifacts import S3ArtifactService +from google.genai import types +import pytest + +# Define a fixed datetime object for consistent testing +FIXED_DATETIME = datetime(2025, 1, 1, 12, 0, 0) + + +class MockS3Object: + """Mocks an S3 object.""" + + def __init__(self, key: str) -> None: + self.key = key + self.data: Optional[bytes] = None + self.content_type: Optional[str] = None + self.last_modified = FIXED_DATETIME + self.metadata: dict[str, Any] = {} + + def set_data(self, data: bytes, content_type: str, metadata: dict[str, Any]): + """Sets the object data.""" + self.data = data + self.content_type = content_type + self.metadata = metadata or {} + + +class MockS3Bucket: + """Mocks an S3 bucket.""" + + def __init__(self, name: str) -> None: + self.name = name + self.objects: dict[str, MockS3Object] = {} + + +class MockS3Client: + """Mocks the boto3 S3 client.""" + + def __init__(self, **kwargs) -> None: + self.buckets: dict[str, MockS3Bucket] = {} + self.exceptions = type( + "Exceptions", (), {"NoSuchKey": KeyError, "NoSuchBucket": Exception} + )() + + def head_bucket(self, Bucket: str): + """Mocks head_bucket call.""" + if Bucket not in self.buckets: + self.buckets[Bucket] = MockS3Bucket(Bucket) + return {} + + def put_object( + self, + Bucket: str, + Key: str, + Body: bytes, + ContentType: str, + Metadata: Optional[dict[str, str]] = None, + **kwargs, + ): + """Mocks put_object call.""" + if Bucket not in self.buckets: + 
self.buckets[Bucket] = MockS3Bucket(Bucket) + bucket = self.buckets[Bucket] + if Key not in bucket.objects: + bucket.objects[Key] = MockS3Object(Key) + bucket.objects[Key].set_data(Body, ContentType, Metadata or {}) + + def get_object(self, Bucket: str, Key: str): + """Mocks get_object call.""" + bucket = self.buckets.get(Bucket) + if not bucket or Key not in bucket.objects: + raise self.exceptions.NoSuchKey(f"Object {Key} not found") + obj = bucket.objects[Key] + if obj.data is None: + raise self.exceptions.NoSuchKey(f"Object {Key} not found") + + class MockBody: + + def __init__(self, data: bytes): + self._data = data + + def read(self) -> bytes: + return self._data + + return { + "Body": MockBody(obj.data), + "ContentType": obj.content_type, + "LastModified": obj.last_modified, + "Metadata": obj.metadata, + } + + def head_object(self, Bucket: str, Key: str): + """Mocks head_object call.""" + bucket = self.buckets.get(Bucket) + if not bucket or Key not in bucket.objects: + raise self.exceptions.NoSuchKey(f"Object {Key} not found") + obj = bucket.objects[Key] + if obj.data is None: + raise self.exceptions.NoSuchKey(f"Object {Key} not found") + return { + "ContentType": obj.content_type, + "LastModified": obj.last_modified, + "Metadata": obj.metadata, + } + + def delete_object(self, Bucket: str, Key: str): + """Mocks delete_object call.""" + bucket = self.buckets.get(Bucket) + if bucket and Key in bucket.objects: + del bucket.objects[Key] + + def list_objects_v2(self, Bucket: str, Prefix: str = ""): + """Mocks list_objects_v2 call.""" + bucket = self.buckets.get(Bucket) + if not bucket: + return {} + + contents = [] + for key, obj in bucket.objects.items(): + if key.startswith(Prefix) and obj.data is not None: + contents.append({"Key": key}) + + if contents: + return {"Contents": contents} + return {} + + +@pytest.fixture +def mock_s3_service(): + """Provides a mocked S3ArtifactService for testing.""" + with mock.patch("boto3.client", return_value=MockS3Client()): + return S3ArtifactService(bucket_name="test_bucket") + + +@pytest.mark.asyncio +async def test_load_empty(mock_s3_service): + """Tests loading an artifact when none exists.""" + assert not await mock_s3_service.load_artifact( + app_name="test_app", + user_id="test_user", + session_id="session_id", + filename="filename", + ) + + +@pytest.mark.asyncio +async def test_save_load_delete(mock_s3_service): + """Tests saving, loading, and deleting an artifact.""" + artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain") + app_name = "app0" + user_id = "user0" + session_id = "123" + filename = "file456" + + await mock_s3_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=artifact, + ) + assert ( + await mock_s3_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + == artifact + ) + + # Attempt to load a version that doesn't exist + assert not await mock_s3_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + version=3, + ) + + await mock_s3_service.delete_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + assert not await mock_s3_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + + +@pytest.mark.asyncio +async def test_list_keys(mock_s3_service): + """Tests listing keys in the artifact service.""" + artifact = 
types.Part.from_bytes(data=b"test_data", mime_type="text/plain") + app_name = "app0" + user_id = "user0" + session_id = "123" + filename = "filename" + filenames = [filename + str(i) for i in range(5)] + + for f in filenames: + await mock_s3_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=f, + artifact=artifact, + ) + + assert ( + await mock_s3_service.list_artifact_keys( + app_name=app_name, user_id=user_id, session_id=session_id + ) + == filenames + ) + + +@pytest.mark.asyncio +async def test_list_versions(mock_s3_service): + """Tests listing versions of an artifact.""" + app_name = "app0" + user_id = "user0" + session_id = "123" + filename = "with/slash/filename" + versions = [ + types.Part.from_bytes( + data=i.to_bytes(2, byteorder="big"), mime_type="text/plain" + ) + for i in range(3) + ] + versions.append(types.Part.from_text(text="hello")) + + for i in range(4): + await mock_s3_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=versions[i], + ) + + response_versions = await mock_s3_service.list_versions( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + + assert response_versions == list(range(4)) + + +@pytest.mark.asyncio +async def test_list_keys_preserves_user_prefix(mock_s3_service): + """Tests that list_artifact_keys preserves 'user:' prefix in returned names.""" + artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain") + app_name = "app0" + user_id = "user0" + session_id = "123" + + # Save artifacts with "user:" prefix (cross-session artifacts) + await mock_s3_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename="user:document.pdf", + artifact=artifact, + ) + + await mock_s3_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename="user:image.png", + artifact=artifact, + ) + + # Save session-scoped artifact without prefix + await mock_s3_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename="session_file.txt", + artifact=artifact, + ) + + # List artifacts should return names with "user:" prefix for user-scoped + artifact_keys = await mock_s3_service.list_artifact_keys( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + # Should contain prefixed names and session file + expected_keys = ["session_file.txt", "user:document.pdf", "user:image.png"] + assert sorted(artifact_keys) == sorted(expected_keys) + + +@pytest.mark.asyncio +async def test_list_artifact_versions_and_get_artifact_version( + mock_s3_service, +): + """Tests listing artifact versions and getting a specific version.""" + app_name = "app0" + user_id = "user0" + session_id = "123" + filename = "filename" + versions = [ + types.Part.from_bytes( + data=i.to_bytes(2, byteorder="big"), mime_type="text/plain" + ) + for i in range(4) + ] + + for i in range(4): + custom_metadata = {"key": "value" + str(i)} + await mock_s3_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=versions[i], + custom_metadata=custom_metadata, + ) + + artifact_versions = await mock_s3_service.list_artifact_versions( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + + assert len(artifact_versions) == 4 + for i, av in enumerate(artifact_versions): + assert av.version == i + assert av.canonical_uri == 
f"s3://test_bucket/{app_name}/{user_id}/{session_id}/{filename}/{i}" + assert av.custom_metadata["key"] == f"value{i}" + assert av.mime_type == "text/plain" + + # Get latest artifact version when version is not specified + latest = await mock_s3_service.get_artifact_version( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + assert latest is not None + assert latest.version == 3 + + # Get artifact version by version number + specific = await mock_s3_service.get_artifact_version( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + version=2, + ) + assert specific is not None + assert specific.version == 2 + + +@pytest.mark.asyncio +async def test_list_artifact_versions_with_user_prefix(mock_s3_service): + """Tests listing artifact versions with user prefix.""" + app_name = "app0" + user_id = "user0" + session_id = "123" + user_scoped_filename = "user:document.pdf" + versions = [ + types.Part.from_bytes( + data=i.to_bytes(2, byteorder="big"), mime_type="text/plain" + ) + for i in range(4) + ] + + for i in range(4): + custom_metadata = {"key": "value" + str(i)} + await mock_s3_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=user_scoped_filename, + artifact=versions[i], + custom_metadata=custom_metadata, + ) + + artifact_versions = await mock_s3_service.list_artifact_versions( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=user_scoped_filename, + ) + + assert len(artifact_versions) == 4 + for i, av in enumerate(artifact_versions): + assert av.version == i + # User-scoped: {app}/{user}/user/document.pdf/{version} + assert av.canonical_uri == f"s3://test_bucket/{app_name}/{user_id}/user/document.pdf/{i}" + + +@pytest.mark.asyncio +async def test_get_artifact_version_artifact_does_not_exist(mock_s3_service): + """Tests getting an artifact version when artifact does not exist.""" + assert not await mock_s3_service.get_artifact_version( + app_name="test_app", + user_id="test_user", + session_id="session_id", + filename="filename", + ) + + +@pytest.mark.asyncio +async def test_get_artifact_version_out_of_index(mock_s3_service): + """Tests loading an artifact with an out-of-index version.""" + app_name = "app0" + user_id = "user0" + session_id = "123" + filename = "filename" + artifact = types.Part.from_bytes(data=b"test_data", mime_type="text/plain") + + await mock_s3_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=artifact, + ) + + # Attempt to get a version that doesn't exist + assert not await mock_s3_service.get_artifact_version( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + version=3, + ) + + +@pytest.mark.asyncio +async def test_special_characters_in_filename(mock_s3_service): + """Tests URL encoding for special characters in filenames.""" + artifact = types.Part(text="Test content") + app_name = "app0" + user_id = "user0" + session_id = "123" + # Filename with special characters that need encoding + filename = "my file/with:special&chars.txt" + + version = await mock_s3_service.save_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + artifact=artifact, + ) + + loaded = await mock_s3_service.load_artifact( + app_name=app_name, + user_id=user_id, + session_id=session_id, + filename=filename, + ) + + assert loaded is not None + # Loaded artifacts come back as inline_data (bytes), not 
text
+  assert loaded.inline_data is not None
+  assert loaded.inline_data.data == b"Test content"
+
+
+@pytest.mark.asyncio
+async def test_custom_metadata(mock_s3_service):
+  """Tests custom metadata storage and retrieval."""
+  artifact = types.Part(text="Test")
+  custom_metadata = {"author": "test", "tags": "integration,test"}
+
+  await mock_s3_service.save_artifact(
+      app_name="app0",
+      user_id="user0",
+      session_id="123",
+      filename="test.txt",
+      artifact=artifact,
+      custom_metadata=custom_metadata,
+  )
+
+  version_info = await mock_s3_service.get_artifact_version(
+      app_name="app0",
+      user_id="user0",
+      session_id="123",
+      filename="test.txt",
+  )
+
+  assert version_info is not None
+  assert version_info.custom_metadata["author"] == "test"
+  assert version_info.custom_metadata["tags"] == "integration,test"
+

From 63e1d0ebd19d58d772ca8a4603f7d513235aaa9d Mon Sep 17 00:00:00 2001
From: Nishar Miya <103556082+miyannishar@users.noreply.github.com>
Date: Thu, 11 Dec 2025 02:32:39 +0600
Subject: [PATCH 2/4] moved boto3 to an optional dependency

---
 pyproject.toml | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 0d6c566..7b81695 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,6 @@ classifiers = [ # List of https://pypi.org/classifiers/
 ]
 dependencies = [
   # go/keep-sorted start
-  "boto3>=1.28.0", # For S3ArtifactService
   "google-adk", # Google ADK
   "google-genai>=1.21.1, <2.0.0", # Google GenAI SDK
   "httpx>=0.27.0, <1.0.0", # For OpenMemory service
@@ -42,6 +41,9 @@ changelog = "https://github.com/google/adk-python-community/blob/main/CHANGELOG.
 documentation = "https://google.github.io/adk-docs/"
 
 [project.optional-dependencies]
+s3 = [
+    "boto3>=1.28.0", # For S3ArtifactService
+]
 test = [
     "pytest>=8.4.2",
     "pytest-asyncio>=1.2.0",

From a5a987629b33b6f7de7032af59d5c4d8e6ad95e5 Mon Sep 17 00:00:00 2001
From: Nishar Miya <103556082+miyannishar@users.noreply.github.com>
Date: Thu, 11 Dec 2025 02:34:24 +0600
Subject: [PATCH 3/4] added specific error handling

---
 .../artifacts/s3_artifact_service.py          | 101 ++++++++++++++++--
 1 file changed, 93 insertions(+), 8 deletions(-)

diff --git a/src/google/adk_community/artifacts/s3_artifact_service.py b/src/google/adk_community/artifacts/s3_artifact_service.py
index 3ec7638..f07b911 100644
--- a/src/google/adk_community/artifacts/s3_artifact_service.py
+++ b/src/google/adk_community/artifacts/s3_artifact_service.py
@@ -34,6 +34,7 @@
 from google.adk.errors.input_validation_error import InputValidationError
 from google.genai import types
 from typing_extensions import override
+from botocore.exceptions import ClientError
 
 logger = logging.getLogger("google_adk_community." 
+ __name__) @@ -73,8 +74,19 @@ def __init__( try: self.s3_client.head_bucket(Bucket=self.bucket_name) logger.info("S3ArtifactService initialized with bucket: %s", bucket_name) + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "Unknown") + logger.error( + "Cannot access S3 bucket '%s': %s (Error: %s)", + bucket_name, + e, + error_code, + ) + raise except Exception as e: - logger.error("Cannot access S3 bucket '%s': %s", bucket_name, e) + logger.error( + "Unexpected error accessing S3 bucket '%s': %s", bucket_name, e + ) raise def _encode_filename(self, filename: str) -> str: @@ -310,8 +322,19 @@ def _save_artifact_sync( object_key, ) return version + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "Unknown") + logger.error( + "Failed to save artifact '%s' to S3: %s (Error: %s)", + filename, + e, + error_code, + ) + raise except Exception as e: - logger.error("Failed to save artifact '%s' to S3: %s", filename, e) + logger.error( + "Unexpected error saving artifact '%s' to S3: %s", filename, e + ) raise def _load_artifact_sync( @@ -362,8 +385,19 @@ def _load_artifact_sync( "Artifact %s version %d not found in S3", filename, version ) return None + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "Unknown") + logger.error( + "Failed to load artifact '%s' from S3: %s (Error: %s)", + filename, + e, + error_code, + ) + raise except Exception as e: - logger.error("Failed to load artifact '%s' from S3: %s", filename, e) + logger.error( + "Unexpected error loading artifact '%s' from S3: %s", filename, e + ) raise def _list_artifact_keys_sync( @@ -391,9 +425,19 @@ def _list_artifact_keys_sync( encoded_filename = parts[0] filename = self._decode_filename(encoded_filename) filenames.add(filename) + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "Unknown") + logger.error( + "Failed to list session artifacts for %s: %s (Error: %s)", + session_id, + e, + error_code, + ) except Exception as e: logger.error( - "Failed to list session artifacts for %s: %s", session_id, e + "Unexpected error listing session artifacts for %s: %s", + session_id, + e, ) # List user-scoped artifacts @@ -411,8 +455,18 @@ def _list_artifact_keys_sync( encoded_filename = parts[0] filename = self._decode_filename(encoded_filename) filenames.add(f"user:{filename}") + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "Unknown") + logger.error( + "Failed to list user artifacts for %s: %s (Error: %s)", + user_id, + e, + error_code, + ) except Exception as e: - logger.error("Failed to list user artifacts for %s: %s", user_id, e) + logger.error( + "Unexpected error listing user artifacts for %s: %s", user_id, e + ) return sorted(list(filenames)) @@ -438,8 +492,18 @@ def _delete_artifact_sync( try: self.s3_client.delete_object(Bucket=self.bucket_name, Key=object_key) logger.debug("Deleted S3 object: %s", object_key) + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "Unknown") + logger.error( + "Failed to delete S3 object %s: %s (Error: %s)", + object_key, + e, + error_code, + ) except Exception as e: - logger.error("Failed to delete S3 object %s: %s", object_key, e) + logger.error( + "Unexpected error deleting S3 object %s: %s", object_key, e + ) def _list_versions_sync( self, @@ -481,8 +545,19 @@ def _list_versions_sync( if version_str.isdigit(): versions.append(int(version_str)) return sorted(versions) + except ClientError as e: + error_code = 
e.response.get("Error", {}).get("Code", "Unknown") + logger.error( + "Failed to list versions for '%s': %s (Error: %s)", + filename, + e, + error_code, + ) + return [] except Exception as e: - logger.error("Failed to list versions for '%s': %s", filename, e) + logger.error( + "Unexpected error listing versions for '%s': %s", filename, e + ) return [] def _get_artifact_version_sync( @@ -536,9 +611,19 @@ def _get_artifact_version_sync( "Artifact %s version %d not found in S3", filename, version ) return None + except ClientError as e: + error_code = e.response.get("Error", {}).get("Code", "Unknown") + logger.error( + "Failed to get artifact version for '%s' version %d: %s (Error: %s)", + filename, + version, + e, + error_code, + ) + return None except Exception as e: logger.error( - "Failed to get artifact version for '%s' version %d: %s", + "Unexpected error getting artifact version for '%s' version %d: %s", filename, version, e, From bf0ced1cc4a2827b6db0a380eb4c9e9a3545147c Mon Sep 17 00:00:00 2001 From: Nishar Miya <103556082+miyannishar@users.noreply.github.com> Date: Thu, 11 Dec 2025 02:35:43 +0600 Subject: [PATCH 4/4] address comment for empty file --- .../artifacts/s3_artifact_service.py | 3 -- .../artifacts/test_s3_artifact_service.py | 29 +++++++++++++++++++ 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/google/adk_community/artifacts/s3_artifact_service.py b/src/google/adk_community/artifacts/s3_artifact_service.py index f07b911..7213545 100644 --- a/src/google/adk_community/artifacts/s3_artifact_service.py +++ b/src/google/adk_community/artifacts/s3_artifact_service.py @@ -368,9 +368,6 @@ def _load_artifact_sync( content_type = response.get("ContentType", "application/octet-stream") data = response["Body"].read() - if not data: - return None - artifact = types.Part.from_bytes(data=data, mime_type=content_type) logger.debug( "Loaded artifact %s version %d from S3 key %s", diff --git a/tests/unittests/artifacts/test_s3_artifact_service.py b/tests/unittests/artifacts/test_s3_artifact_service.py index 0d9f605..8c5512b 100644 --- a/tests/unittests/artifacts/test_s3_artifact_service.py +++ b/tests/unittests/artifacts/test_s3_artifact_service.py @@ -516,3 +516,32 @@ async def test_custom_metadata(mock_s3_service): assert version_info.custom_metadata["author"] == "test" assert version_info.custom_metadata["tags"] == "integration,test" + +@pytest.mark.asyncio +async def test_empty_artifact(mock_s3_service): + """Tests saving and loading empty (0-byte) artifacts.""" + # Create empty artifact + empty_artifact = types.Part.from_bytes(data=b"", mime_type="text/plain") + + version = await mock_s3_service.save_artifact( + app_name="app0", + user_id="user0", + session_id="123", + filename="empty.txt", + artifact=empty_artifact, + ) + + assert version == 0 + + # Load empty artifact - should succeed, not return None + loaded = await mock_s3_service.load_artifact( + app_name="app0", + user_id="user0", + session_id="123", + filename="empty.txt", + ) + + assert loaded is not None + assert loaded.inline_data is not None + assert loaded.inline_data.data == b"" +
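
A note on large buckets: a single `list_objects_v2` call returns at most
1,000 keys, and the listing helpers in this series read only that first
page, so artifacts beyond it would be silently skipped. A possible
follow-up (not part of these patches) is to route the listing helpers
through boto3's standard paginator. A minimal sketch of
`_list_versions_sync` rewritten that way, assuming the same `self.s3_client`
and key layout as the patch (error handling elided):

```python
def _list_versions_sync(
    self,
    app_name: str,
    user_id: str,
    session_id: Optional[str],
    filename: str,
) -> list[int]:
  prefix, encoded_filename = self._get_object_key_prefix(
      app_name, user_id, filename, session_id
  )
  full_prefix = f"{prefix}/{encoded_filename}/"

  versions: list[int] = []
  # get_paginator transparently follows continuation tokens across pages,
  # so buckets with more than 1,000 matching keys are fully enumerated.
  paginator = self.s3_client.get_paginator("list_objects_v2")
  for page in paginator.paginate(Bucket=self.bucket_name, Prefix=full_prefix):
    for obj in page.get("Contents", []):
      # Keys have the form {prefix}/{encoded_filename}/{version}.
      version_str = obj["Key"].split("/")[-1]
      if version_str.isdigit():
        versions.append(int(version_str))
  return sorted(versions)
```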