diff --git a/veadk/tools/vanna_tools/agent_memory.py b/veadk/tools/vanna_tools/agent_memory.py new file mode 100644 index 00000000..f4034ef4 --- /dev/null +++ b/veadk/tools/vanna_tools/agent_memory.py @@ -0,0 +1,324 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, List +from google.adk.tools import BaseTool, ToolContext +from google.genai import types +from vanna.tools.agent_memory import ( + SaveQuestionToolArgsTool as VannaSaveQuestionToolArgsTool, + SearchSavedCorrectToolUsesTool as VannaSearchSavedCorrectToolUsesTool, + SaveTextMemoryTool as VannaSaveTextMemoryTool, +) +from vanna.core.user import User +from vanna.core.tool import ToolContext as VannaToolContext + + +class SaveQuestionToolArgsTool(BaseTool): + """Save successful question-tool-argument combinations for future reference.""" + + def __init__( + self, + agent_memory, + access_groups: Optional[List[str]] = None, + ): + """ + Initialize the save tool usage tool with custom agent_memory. + + Args: + agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory) + access_groups: List of user groups that can access this tool (e.g., ['admin']) + """ + self.agent_memory = agent_memory + self.vanna_tool = VannaSaveQuestionToolArgsTool() + self.access_groups = access_groups or ["admin"] # Default: only admin + + super().__init__( + name="save_question_tool_args", # Keep the same name as Vanna + description="Save a successful question-tool-argument combination for future reference.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "question": types.Schema( + type=types.Type.STRING, + description="The original question that was asked", + ), + "tool_name": types.Schema( + type=types.Type.STRING, + description="The name of the tool that was used successfully", + ), + "args": types.Schema( + type=types.Type.OBJECT, + description="The arguments that were passed to the tool", + ), + }, + required=["question", "tool_name", "args"], + ), + ) + + def _get_user_groups(self, tool_context: ToolContext) -> List[str]: + """Get user groups from context.""" + user_groups = tool_context.state.get("user_groups", ["user"]) + return user_groups + + def _check_access(self, user_groups: List[str]) -> bool: + """Check if user has access to this tool.""" + return any(group in self.access_groups for group in user_groups) + + def _create_vanna_context( + self, tool_context: ToolContext, user_groups: List[str] + ) -> VannaToolContext: + """Create Vanna context from Veadk ToolContext.""" + user_id = tool_context.user_id + user_email = tool_context.state.get("user_email", "user@example.com") + + vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups) + + vanna_context = VannaToolContext( + user=vanna_user, + conversation_id=tool_context.session.id, + request_id=tool_context.session.id, + agent_memory=self.agent_memory, + ) + + return vanna_context + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> str: + """Save a tool usage pattern.""" + question = args.get("question", "").strip() + tool_name = args.get("tool_name", "").strip() + tool_args = args.get("args", {}) + + if not question: + return "Error: No question provided" + + if not tool_name: + return "Error: No tool name provided" + + try: + user_groups = self._get_user_groups(tool_context) + + if not self._check_access(user_groups): + return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}" + + vanna_context = self._create_vanna_context(tool_context, user_groups) + + args_model = self.vanna_tool.get_args_schema()( + question=question, tool_name=tool_name, args=tool_args + ) + result = await self.vanna_tool.execute(vanna_context, args_model) + + return str(result.result_for_llm) + except Exception as e: + return f"Error saving tool usage: {str(e)}" + + +class SearchSavedCorrectToolUsesTool(BaseTool): + """Search for similar tool usage patterns based on a question.""" + + def __init__( + self, + agent_memory, + access_groups: Optional[List[str]] = None, + ): + """ + Initialize the search similar tools tool with custom agent_memory. + + Args: + agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory) + access_groups: List of user groups that can access this tool (e.g., ['admin', 'user']) + user_group_resolver: Optional callable that takes ToolContext and returns user groups + """ + self.agent_memory = agent_memory + self.vanna_tool = VannaSearchSavedCorrectToolUsesTool() + self.access_groups = access_groups or ["admin", "user"] + + super().__init__( + name="search_saved_correct_tool_uses", # Keep the same name as Vanna + description="Search for similar tool usage patterns based on a question.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "question": types.Schema( + type=types.Type.STRING, + description="The question to find similar tool usage patterns for", + ), + "limit": types.Schema( + type=types.Type.INTEGER, + description="Maximum number of results to return (default: 10)", + ), + }, + required=["question"], + ), + ) + + def _get_user_groups(self, tool_context: ToolContext) -> List[str]: + """Get user groups from context.""" + user_groups = tool_context.state.get("user_groups", ["user"]) + return user_groups + + def _check_access(self, user_groups: List[str]) -> bool: + """Check if user has access to this tool.""" + return any(group in self.access_groups for group in user_groups) + + def _create_vanna_context( + self, tool_context: ToolContext, user_groups: List[str] + ) -> VannaToolContext: + """Create Vanna context from Veadk ToolContext.""" + user_id = tool_context.user_id + user_email = tool_context.state.get("user_email", "user@example.com") + + vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups) + + vanna_context = VannaToolContext( + user=vanna_user, + conversation_id=tool_context.session.id, + request_id=tool_context.session.id, + agent_memory=self.agent_memory, + ) + + return vanna_context + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> str: + """Search for similar tool usage patterns.""" + question = args.get("question", "").strip() + limit = args.get("limit", 10) + + if not question: + return "Error: No question provided" + + try: + user_groups = self._get_user_groups(tool_context) + + if not self._check_access(user_groups): + return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}" + + vanna_context = self._create_vanna_context(tool_context, user_groups) + + args_model = self.vanna_tool.get_args_schema()( + question=question, limit=limit + ) + result = await self.vanna_tool.execute(vanna_context, args_model) + + return str(result.result_for_llm) + except Exception as e: + return f"Error searching similar tools: {str(e)}" + + +class SaveTextMemoryTool(BaseTool): + """Save free-form text memories for important insights, observations, or context.""" + + def __init__( + self, + agent_memory, + access_groups: Optional[List[str]] = None, + ): + """ + Initialize the save text memory tool with custom agent_memory. + + Args: + agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory) + access_groups: List of user groups that can access this tool (e.g., ['admin', 'user']) + user_group_resolver: Optional callable that takes ToolContext and returns user groups + """ + self.agent_memory = agent_memory + self.vanna_tool = VannaSaveTextMemoryTool() + self.access_groups = access_groups or ["admin", "user"] + + super().__init__( + name="save_text_memory", # Keep the same name as Vanna + description="Save free-form text memory for important insights, observations, or context.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "content": types.Schema( + type=types.Type.STRING, + description="The text content to save as a memory", + ), + }, + required=["content"], + ), + ) + + def _get_user_groups(self, tool_context: ToolContext) -> List[str]: + """Get user groups from context.""" + user_groups = tool_context.state.get("user_groups", ["user"]) + return user_groups + + def _check_access(self, user_groups: List[str]) -> bool: + """Check if user has access to this tool.""" + return any(group in self.access_groups for group in user_groups) + + def _create_vanna_context( + self, tool_context: ToolContext, user_groups: List[str] + ) -> VannaToolContext: + """Create Vanna context from Veadk ToolContext.""" + user_id = tool_context.user_id + user_email = tool_context.state.get("user_email", "user@example.com") + + vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups) + + vanna_context = VannaToolContext( + user=vanna_user, + conversation_id=tool_context.session.id, + request_id=tool_context.session.id, + agent_memory=self.agent_memory, + ) + + return vanna_context + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> str: + """Save a text memory.""" + content = args.get("content", "").strip() + + if not content: + return "Error: No content provided" + + try: + user_groups = self._get_user_groups(tool_context) + + if not self._check_access(user_groups): + return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}" + + vanna_context = self._create_vanna_context(tool_context, user_groups) + + args_model = self.vanna_tool.get_args_schema()(content=content) + result = await self.vanna_tool.execute(vanna_context, args_model) + + return str(result.result_for_llm) + except Exception as e: + return f"Error saving text memory: {str(e)}" diff --git a/veadk/tools/vanna_tools/examples/agent.py b/veadk/tools/vanna_tools/examples/agent.py new file mode 100644 index 00000000..917b17bf --- /dev/null +++ b/veadk/tools/vanna_tools/examples/agent.py @@ -0,0 +1,322 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from veadk import Agent, Runner + +# Import Vanna dependencies for initialization +from vanna.integrations.sqlite import SqliteRunner +from vanna.tools import LocalFileSystem +from vanna.integrations.local.agent_memory import DemoAgentMemory +import httpx + +# Import the refactored class-based tools +from veadk.tools.vanna_tools.run_sql import RunSqlTool +from veadk.tools.vanna_tools.visualize_data import VisualizeDataTool +from veadk.tools.vanna_tools.file_system import WriteFileTool +from veadk.tools.vanna_tools.agent_memory import ( + SaveQuestionToolArgsTool, + SearchSavedCorrectToolUsesTool, +) +from veadk.tools.vanna_tools.summarize_data import SummarizeDataTool + +from google.adk.sessions import InMemorySessionService + + +# Setup SQLite database +def setup_sqlite(): + """Download and setup the Chinook SQLite database.""" + db_path = "/tmp/Chinook.sqlite" + if not os.path.exists(db_path): + print("Downloading Chinook.sqlite...") + url = "https://vanna.ai/Chinook.sqlite" + try: + with open(db_path, "wb") as f: + with httpx.stream("GET", url) as response: + for chunk in response.iter_bytes(): + f.write(chunk) + print("Database downloaded successfully!") + except Exception as e: + print(f"Error downloading database: {e}") + return db_path + + +# Create a session with user groups for access control +async def create_session(user_groups: list = ["user"]): + session_service = InMemorySessionService() + example_session = await session_service.create_session( + app_name="example_app", + user_id="example_user", + state={"user_groups": user_groups}, + ) + return session_service, example_session + + +# Initialize user-customizable resources +db_path = setup_sqlite() + +# 1. SQL Runner - can be SqliteRunner, PostgresRunner, MySQLRunner, etc. +sqlite_runner = SqliteRunner(database_path=db_path) + +# 2. File System - customize working directory as needed +file_system = LocalFileSystem(working_directory="/tmp/data_storage") +if not os.path.exists("/tmp/data_storage"): + os.makedirs("/tmp/data_storage", exist_ok=True) + +# 3. Agent Memory - customize memory implementation and capacity +agent_memory = DemoAgentMemory(max_items=1000) + +# Initialize tools with user-defined components and access control +# Tool names now match Vanna's original names for compatibility +run_sql_tool = RunSqlTool( + sql_runner=sqlite_runner, + file_system=file_system, + agent_memory=agent_memory, + access_groups=["admin", "user"], # Both admin and user can use +) + +visualize_data_tool = VisualizeDataTool( + file_system=file_system, + agent_memory=agent_memory, + access_groups=["admin", "user"], +) + +write_file_tool = WriteFileTool( + file_system=file_system, + agent_memory=agent_memory, + access_groups=["admin", "user"], +) + +# Memory tools: save only for admin, search for all users +save_tool = SaveQuestionToolArgsTool( + agent_memory=agent_memory, + access_groups=["admin"], # Only admin can save +) + +search_tool = SearchSavedCorrectToolUsesTool( + agent_memory=agent_memory, + access_groups=["admin", "user"], # All users can search +) + +summarize_data_tool = SummarizeDataTool( + file_system=file_system, + agent_memory=agent_memory, + access_groups=["admin", "user"], +) + +# Define the Veadk Agent with class-based tools +agent: Agent = Agent( + name="vanna_sql_agent", + description="An intelligent agent that can query databases, visualize data, and generate reports.", + instruction=""" + You are a helpful assistant that can answer questions about data in the Chinook database. + You can execute SQL queries, visualize the results, save/search useful tool usage patterns, and generate documents. + + Here is the schema of the Chinook database: + ```sql + CREATE TABLE [Album] + ( + [AlbumId] INTEGER NOT NULL, + [Title] NVARCHAR(160) NOT NULL, + [ArtistId] INTEGER NOT NULL, + CONSTRAINT [PK_Album] PRIMARY KEY ([AlbumId]), + FOREIGN KEY ([ArtistId]) REFERENCES [Artist] ([ArtistId]) + ON DELETE NO ACTION ON UPDATE NO ACTION + ); + CREATE TABLE [Artist] + ( + [ArtistId] INTEGER NOT NULL, + [Name] NVARCHAR(120), + CONSTRAINT [PK_Artist] PRIMARY KEY ([ArtistId]) + ); + CREATE TABLE [Customer] + ( + [CustomerId] INTEGER NOT NULL, + [FirstName] NVARCHAR(40) NOT NULL, + [LastName] NVARCHAR(20) NOT NULL, + [Company] NVARCHAR(80), + [Address] NVARCHAR(70), + [City] NVARCHAR(40), + [State] NVARCHAR(40), + [Country] NVARCHAR(40), + [PostalCode] NVARCHAR(10), + [Phone] NVARCHAR(24), + [Fax] NVARCHAR(24), + [Email] NVARCHAR(60) NOT NULL, + [SupportRepId] INTEGER, + CONSTRAINT [PK_Customer] PRIMARY KEY ([CustomerId]), + FOREIGN KEY ([SupportRepId]) REFERENCES [Employee] ([EmployeeId]) + ON DELETE NO ACTION ON UPDATE NO ACTION + ); + CREATE TABLE [Employee] + ( + [EmployeeId] INTEGER NOT NULL, + [LastName] NVARCHAR(20) NOT NULL, + [FirstName] NVARCHAR(20) NOT NULL, + [Title] NVARCHAR(30), + [ReportsTo] INTEGER, + [BirthDate] DATETIME, + [HireDate] DATETIME, + [Address] NVARCHAR(70), + [City] NVARCHAR(40), + [State] NVARCHAR(40), + [Country] NVARCHAR(40), + [PostalCode] NVARCHAR(10), + [Phone] NVARCHAR(24), + [Fax] NVARCHAR(24), + [Email] NVARCHAR(60), + CONSTRAINT [PK_Employee] PRIMARY KEY ([EmployeeId]), + FOREIGN KEY ([ReportsTo]) REFERENCES [Employee] ([EmployeeId]) + ON DELETE NO ACTION ON UPDATE NO ACTION + ); + CREATE TABLE [Genre] + ( + [GenreId] INTEGER NOT NULL, + [Name] NVARCHAR(120), + CONSTRAINT [PK_Genre] PRIMARY KEY ([GenreId]) + ); + CREATE TABLE [Invoice] + ( + [InvoiceId] INTEGER NOT NULL, + [CustomerId] INTEGER NOT NULL, + [InvoiceDate] DATETIME NOT NULL, + [BillingAddress] NVARCHAR(70), + [BillingCity] NVARCHAR(40), + [BillingState] NVARCHAR(40), + [BillingCountry] NVARCHAR(40), + [BillingPostalCode] NVARCHAR(10), + [Total] NUMERIC(10,2) NOT NULL, + CONSTRAINT [PK_Invoice] PRIMARY KEY ([InvoiceId]), + FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) + ON DELETE NO ACTION ON UPDATE NO ACTION + ); + CREATE TABLE [InvoiceLine] + ( + [InvoiceLineId] INTEGER NOT NULL, + [InvoiceId] INTEGER NOT NULL, + [TrackId] INTEGER NOT NULL, + [UnitPrice] NUMERIC(10,2) NOT NULL, + [Quantity] INTEGER NOT NULL, + CONSTRAINT [PK_InvoiceLine] PRIMARY KEY ([InvoiceLineId]), + FOREIGN KEY ([InvoiceId]) REFERENCES [Invoice] ([InvoiceId]) + ON DELETE NO ACTION ON UPDATE NO ACTION, + FOREIGN KEY ([TrackId]) REFERENCES [Track] ([TrackId]) + ON DELETE NO ACTION ON UPDATE NO ACTION + ); + CREATE TABLE [MediaType] + ( + [MediaTypeId] INTEGER NOT NULL, + [Name] NVARCHAR(120), + CONSTRAINT [PK_MediaType] PRIMARY KEY ([MediaTypeId]) + ); + CREATE TABLE [Playlist] + ( + [PlaylistId] INTEGER NOT NULL, + [Name] NVARCHAR(120), + CONSTRAINT [PK_Playlist] PRIMARY KEY ([PlaylistId]) + ); + CREATE TABLE [PlaylistTrack] + ( + [PlaylistId] INTEGER NOT NULL, + [TrackId] INTEGER NOT NULL, + CONSTRAINT [PK_PlaylistTrack] PRIMARY KEY ([PlaylistId], [TrackId]), + FOREIGN KEY ([PlaylistId]) REFERENCES [Playlist] ([PlaylistId]) + ON DELETE NO ACTION ON UPDATE NO ACTION, + FOREIGN KEY ([TrackId]) REFERENCES [Track] ([TrackId]) + ON DELETE NO ACTION ON UPDATE NO ACTION + ); + CREATE TABLE [Track] + ( + [TrackId] INTEGER NOT NULL, + [Name] NVARCHAR(200) NOT NULL, + [AlbumId] INTEGER, + [MediaTypeId] INTEGER NOT NULL, + [GenreId] INTEGER, + [Composer] NVARCHAR(220), + [Milliseconds] INTEGER NOT NULL, + [Bytes] INTEGER, + [UnitPrice] NUMERIC(10,2) NOT NULL, + CONSTRAINT [PK_Track] PRIMARY KEY ([TrackId]), + FOREIGN KEY ([AlbumId]) REFERENCES [Album] ([AlbumId]) + ON DELETE NO ACTION ON UPDATE NO ACTION, + FOREIGN KEY ([GenreId]) REFERENCES [Genre] ([GenreId]) + ON DELETE NO ACTION ON UPDATE NO ACTION, + FOREIGN KEY ([MediaTypeId]) REFERENCES [MediaType] ([MediaTypeId]) + ON DELETE NO ACTION ON UPDATE NO ACTION + ); + ``` + + Here are some examples of how to query this database: + + Q: Get all the tracks in the album 'Balls to the Wall'. + A: SELECT * FROM Track WHERE AlbumId = (SELECT AlbumId FROM Album WHERE Title = 'Balls to the Wall') + + Q: Get the total sales for each customer. + A: SELECT c.FirstName, c.LastName, SUM(i.Total) as TotalSales FROM Customer c JOIN Invoice i ON c.CustomerId = i.CustomerId GROUP BY c.CustomerId + + Q: How many tracks are there in each genre? + A: SELECT g.Name, COUNT(t.TrackId) as TrackCount FROM Genre g JOIN Track t ON g.GenreId = t.GenreId GROUP BY g.GenreId + + Available tools (using Vanna's original names): + 1. `run_sql` - Execute SQL queries + 2. `visualize_data` - Create visualizations from CSV files + 3. `write_file` - Save content to files + 4. `save_question_tool_args` - Save successful tool usage patterns (admin only) + 5. `search_saved_correct_tool_uses` - Search for similar tool usage patterns + 6. `summarize_data` - Generate statistical summaries of CSV files + """, + tools=[ + run_sql_tool, + visualize_data_tool, + write_file_tool, + save_tool, + search_tool, + summarize_data_tool, + ], + model_extra_config={"extra_body": {"thinking": {"type": "disabled"}}}, +) + + +async def main(prompt: str, user_groups: list = None) -> str: + session_service, example_session = await create_session( + user_groups + ) # Default to 'user' group if not specified + + runner = Runner( + agent=agent, + app_name=example_session.app_name, + user_id=example_session.user_id, + session_service=session_service, + ) + + response = await runner.run( + messages=prompt, + session_id=example_session.id, + ) + + return response + + +if __name__ == "__main__": + import asyncio + + # print("=== Example 1: Regular User ===") + # user_input = "What are the top 5 selling albums?" + # response = asyncio.run(main(user_input, user_groups=['user'])) + # print(response) + + print("\n=== Example 2: Admin User (can save patterns) ===") + admin_input = "What are the top 5 selling albums?" + response = asyncio.run(main(admin_input, user_groups=["admin"])) + print(response) diff --git a/veadk/tools/vanna_tools/file_system.py b/veadk/tools/vanna_tools/file_system.py new file mode 100644 index 00000000..432fb112 --- /dev/null +++ b/veadk/tools/vanna_tools/file_system.py @@ -0,0 +1,467 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, List +from google.adk.tools import BaseTool, ToolContext +from google.genai import types +from vanna.tools.file_system import ( + WriteFileTool as VannaWriteFileTool, + ReadFileTool as VannaReadFileTool, + ListFilesTool as VannaListFilesTool, + SearchFilesTool as VannaSearchFilesTool, + EditFileTool as VannaEditFileTool, +) +from vanna.core.user import User +from vanna.core.tool import ToolContext as VannaToolContext + + +class WriteFileTool(BaseTool): + """Write content to a file.""" + + def __init__( + self, + file_system, + agent_memory, + access_groups: Optional[List[str]] = None, + ): + """ + Initialize the write file tool with custom file_system. + + Args: + agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory) + file_system: A Vanna file system instance (e.g., LocalFileSystem) + access_groups: List of user groups that can access this tool + """ + self.file_system = file_system + self.agent_memory = agent_memory + self.vanna_tool = VannaWriteFileTool(file_system=file_system) + self.access_groups = access_groups or ["admin", "user"] + + super().__init__( + name="write_file", + description="Write content to a file.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "filename": types.Schema( + type=types.Type.STRING, + description="Name of the file to write", + ), + "content": types.Schema( + type=types.Type.STRING, + description="Content to write to the file", + ), + "overwrite": types.Schema( + type=types.Type.BOOLEAN, + description="Whether to overwrite existing files (default: False)", + ), + }, + required=["filename", "content"], + ), + ) + + def _get_user_groups(self, tool_context: ToolContext) -> List[str]: + user_groups = tool_context.state.get("user_groups", ["user"]) + return user_groups + + def _check_access(self, user_groups: List[str]) -> bool: + return any(group in self.access_groups for group in user_groups) + + def _create_vanna_context( + self, tool_context: ToolContext, user_groups: List[str] + ) -> VannaToolContext: + user_id = tool_context.user_id + user_email = tool_context.state.get("user_email", "user@example.com") + + vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups) + return VannaToolContext( + user=vanna_user, + conversation_id=tool_context.session.id, + request_id=tool_context.session.id, + agent_memory=self.agent_memory, + ) + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> str: + filename = args.get("filename", "").strip() + content = args.get("content", "") + overwrite = args.get("overwrite", False) + + if not filename: + return "Error: No filename provided" + + try: + user_groups = self._get_user_groups(tool_context) + if not self._check_access(user_groups): + return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}" + + vanna_context = self._create_vanna_context(tool_context, user_groups) + args_model = self.vanna_tool.get_args_schema()( + filename=filename, content=content, overwrite=overwrite + ) + result = await self.vanna_tool.execute(vanna_context, args_model) + return str(result.result_for_llm) + except Exception as e: + return f"Error writing file: {str(e)}" + + +class ReadFileTool(BaseTool): + """Read the contents of a file.""" + + def __init__( + self, + file_system, + agent_memory, + access_groups: Optional[List[str]] = None, + ): + self.file_system = file_system + self.agent_memory = agent_memory + self.vanna_tool = VannaReadFileTool(file_system=file_system) + self.access_groups = access_groups or ["admin", "user"] + + super().__init__( + name="read_file", + description="Read the contents of a file.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "filename": types.Schema( + type=types.Type.STRING, + description="Name of the file to read", + ), + }, + required=["filename"], + ), + ) + + def _get_user_groups(self, tool_context: ToolContext) -> List[str]: + user_groups = tool_context.state.get("user_groups", ["user"]) + return user_groups + + def _check_access(self, user_groups: List[str]) -> bool: + return any(group in self.access_groups for group in user_groups) + + def _create_vanna_context( + self, tool_context: ToolContext, user_groups: List[str] + ) -> VannaToolContext: + user_id = tool_context.user_id + user_email = tool_context.state.get("user_email", "user@example.com") + + vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups) + return VannaToolContext( + user=vanna_user, + conversation_id=tool_context.session.id, + request_id=tool_context.session.id, + agent_memory=self.agent_memory, + ) + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> str: + filename = args.get("filename", "").strip() + + if not filename: + return "Error: No filename provided" + + try: + user_groups = self._get_user_groups(tool_context) + if not self._check_access(user_groups): + return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}" + + vanna_context = self._create_vanna_context(tool_context, user_groups) + args_model = self.vanna_tool.get_args_schema()(filename=filename) + result = await self.vanna_tool.execute(vanna_context, args_model) + return str(result.result_for_llm) + except Exception as e: + return f"Error reading file: {str(e)}" + + +class ListFilesTool(BaseTool): + """List files in a directory.""" + + def __init__( + self, + file_system, + agent_memory, + access_groups: Optional[List[str]] = None, + ): + self.file_system = file_system + self.agent_memory = agent_memory + self.vanna_tool = VannaListFilesTool(file_system=file_system) + self.access_groups = access_groups or ["admin", "user"] + + super().__init__( + name="list_files", + description="List files in a directory.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "directory": types.Schema( + type=types.Type.STRING, + description="Directory to list (defaults to current directory)", + ), + }, + required=[], + ), + ) + + def _get_user_groups(self, tool_context: ToolContext) -> List[str]: + user_groups = tool_context.state.get("user_groups", ["user"]) + return user_groups + + def _check_access(self, user_groups: List[str]) -> bool: + return any(group in self.access_groups for group in user_groups) + + def _create_vanna_context( + self, tool_context: ToolContext, user_groups: List[str] + ) -> VannaToolContext: + user_id = tool_context.user_id + user_email = tool_context.state.get("user_email", "user@example.com") + + vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups) + return VannaToolContext( + user=vanna_user, + conversation_id=tool_context.session.id, + request_id=tool_context.session.id, + agent_memory=self.agent_memory, + ) + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> str: + directory = args.get("directory", ".") + + try: + user_groups = self._get_user_groups(tool_context) + if not self._check_access(user_groups): + return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}" + + vanna_context = self._create_vanna_context(tool_context, user_groups) + args_model = self.vanna_tool.get_args_schema()(directory=directory) + result = await self.vanna_tool.execute(vanna_context, args_model) + return str(result.result_for_llm) + except Exception as e: + return f"Error listing files: {str(e)}" + + +class SearchFilesTool(BaseTool): + """Search for files by name or content.""" + + def __init__( + self, + file_system, + agent_memory, + access_groups: Optional[List[str]] = None, + ): + self.file_system = file_system + self.agent_memory = agent_memory + self.vanna_tool = VannaSearchFilesTool(file_system=file_system) + self.access_groups = access_groups or ["admin", "user"] + + super().__init__( + name="search_files", + description="Search for files by name or content.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "query": types.Schema( + type=types.Type.STRING, + description="Text to search for in file names or contents", + ), + "include_content": types.Schema( + type=types.Type.BOOLEAN, + description="Whether to search within file contents (default: True)", + ), + "max_results": types.Schema( + type=types.Type.INTEGER, + description="Maximum number of matches to return (default: 20)", + ), + }, + required=["query"], + ), + ) + + def _get_user_groups(self, tool_context: ToolContext) -> List[str]: + user_groups = tool_context.state.get("user_groups", ["user"]) + return user_groups + + def _check_access(self, user_groups: List[str]) -> bool: + return any(group in self.access_groups for group in user_groups) + + def _create_vanna_context( + self, tool_context: ToolContext, user_groups: List[str] + ) -> VannaToolContext: + user_id = tool_context.user_id + user_email = tool_context.state.get("user_email", "user@example.com") + + vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups) + return VannaToolContext( + user=vanna_user, + conversation_id=tool_context.session.id, + request_id=tool_context.session.id, + agent_memory=self.agent_memory, + ) + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> str: + query = args.get("query", "").strip() + include_content = args.get("include_content", True) + max_results = args.get("max_results", 20) + + if not query: + return "Error: No search query provided" + + try: + user_groups = self._get_user_groups(tool_context) + if not self._check_access(user_groups): + return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}" + + vanna_context = self._create_vanna_context(tool_context, user_groups) + args_model = self.vanna_tool.get_args_schema()( + query=query, include_content=include_content, max_results=max_results + ) + result = await self.vanna_tool.execute(vanna_context, args_model) + return str(result.result_for_llm) + except Exception as e: + return f"Error searching files: {str(e)}" + + +class EditFileTool(BaseTool): + """Modify specific lines within a file.""" + + def __init__( + self, + file_system, + agent_memory, + access_groups: Optional[List[str]] = None, + ): + self.file_system = file_system + self.agent_memory = agent_memory + self.vanna_tool = VannaEditFileTool(file_system=file_system) + self.access_groups = access_groups or ["admin", "user"] + + super().__init__( + name="edit_file", + description="Modify specific lines within a file.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "filename": types.Schema( + type=types.Type.STRING, + description="Path to the file to edit", + ), + "edits": types.Schema( + type=types.Type.ARRAY, + description="List of edits to apply", + items=types.Schema( + type=types.Type.OBJECT, + properties={ + "start_line": types.Schema( + type=types.Type.INTEGER, + description="First line (1-based) affected by this edit", + ), + "end_line": types.Schema( + type=types.Type.INTEGER, + description="Last line (1-based, inclusive) to replace", + ), + "new_content": types.Schema( + type=types.Type.STRING, + description="Replacement text", + ), + }, + ), + ), + }, + required=["filename", "edits"], + ), + ) + + def _get_user_groups(self, tool_context: ToolContext) -> List[str]: + user_groups = tool_context.state.get("user_groups", ["user"]) + return user_groups + + def _check_access(self, user_groups: List[str]) -> bool: + return any(group in self.access_groups for group in user_groups) + + def _create_vanna_context( + self, tool_context: ToolContext, user_groups: List[str] + ) -> VannaToolContext: + user_id = tool_context.user_id + user_email = tool_context.state.get("user_email", "user@example.com") + + vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups) + return VannaToolContext( + user=vanna_user, + conversation_id=tool_context.session.id, + request_id=tool_context.session.id, + agent_memory=self.agent_memory, + ) + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> str: + filename = args.get("filename", "").strip() + edits = args.get("edits", []) + + if not filename: + return "Error: No filename provided" + + if not edits: + return "Error: No edits provided" + + try: + user_groups = self._get_user_groups(tool_context) + if not self._check_access(user_groups): + return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}" + + vanna_context = self._create_vanna_context(tool_context, user_groups) + args_model = self.vanna_tool.get_args_schema()( + filename=filename, edits=edits + ) + result = await self.vanna_tool.execute(vanna_context, args_model) + return str(result.result_for_llm) + except Exception as e: + return f"Error editing file: {str(e)}" diff --git a/veadk/tools/vanna_tools/python.py b/veadk/tools/vanna_tools/python.py new file mode 100644 index 00000000..c2d3ba5a --- /dev/null +++ b/veadk/tools/vanna_tools/python.py @@ -0,0 +1,230 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, List +from google.adk.tools import BaseTool, ToolContext +from google.genai import types +from vanna.tools.python import ( + RunPythonFileTool as VannaRunPythonFileTool, + PipInstallTool as VannaPipInstallTool, +) +from vanna.core.user import User +from vanna.core.tool import ToolContext as VannaToolContext + + +class RunPythonFileTool(BaseTool): + """Execute a Python file using the workspace interpreter.""" + + def __init__( + self, + file_system, + agent_memory, + access_groups: Optional[List[str]] = None, + ): + """ + Initialize the run Python file tool with custom file_system. + + Args: + file_system: A Vanna file system instance (e.g., LocalFileSystem) + agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory) + access_groups: List of user groups that can access this tool + """ + self.file_system = file_system + self.agent_memory = agent_memory + self.vanna_tool = VannaRunPythonFileTool(file_system=file_system) + self.access_groups = access_groups or ["admin", "user"] + + super().__init__( + name="run_python_file", + description="Execute a Python file using the workspace interpreter.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "filename": types.Schema( + type=types.Type.STRING, + description="Python file to execute (relative to the workspace root)", + ), + "arguments": types.Schema( + type=types.Type.ARRAY, + description="Optional arguments to pass to the Python script", + items=types.Schema(type=types.Type.STRING), + ), + "timeout_seconds": types.Schema( + type=types.Type.NUMBER, + description="Optional timeout for the command in seconds", + ), + }, + required=["filename"], + ), + ) + + def _get_user_groups(self, tool_context: ToolContext) -> List[str]: + user_groups = tool_context.state.get("user_groups", ["user"]) + return user_groups + + def _check_access(self, user_groups: List[str]) -> bool: + return any(group in self.access_groups for group in user_groups) + + def _create_vanna_context( + self, tool_context: ToolContext, user_groups: List[str] + ) -> VannaToolContext: + user_id = tool_context.user_id + user_email = tool_context.state.get("user_email", "user@example.com") + + vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups) + return VannaToolContext( + user=vanna_user, + conversation_id=tool_context.session.id, + request_id=tool_context.session.id, + agent_memory=self.agent_memory, + ) + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> str: + filename = args.get("filename", "").strip() + arguments = args.get("arguments", []) + timeout_seconds = args.get("timeout_seconds") + + if not filename: + return "Error: No filename provided" + + try: + user_groups = self._get_user_groups(tool_context) + if not self._check_access(user_groups): + return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}" + + vanna_context = self._create_vanna_context(tool_context, user_groups) + args_model = self.vanna_tool.get_args_schema()( + filename=filename, arguments=arguments, timeout_seconds=timeout_seconds + ) + result = await self.vanna_tool.execute(vanna_context, args_model) + return str(result.result_for_llm) + except Exception as e: + return f"Error running Python file: {str(e)}" + + +class PipInstallTool(BaseTool): + """Install Python packages using pip.""" + + def __init__( + self, + file_system, + agent_memory, + access_groups: Optional[List[str]] = None, + ): + """ + Initialize the pip install tool with custom file_system. + + Args: + file_system: A Vanna file system instance (e.g., LocalFileSystem) + agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory) + access_groups: List of user groups that can access this tool (default: admin only) + """ + self.file_system = file_system + self.agent_memory = agent_memory + self.vanna_tool = VannaPipInstallTool(file_system=file_system) + self.access_groups = access_groups or [ + "admin" + ] # Default: only admin can install packages + + super().__init__( + name="pip_install", + description="Install Python packages using pip.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "packages": types.Schema( + type=types.Type.ARRAY, + description="Packages (with optional specifiers) to install", + items=types.Schema(type=types.Type.STRING), + ), + "upgrade": types.Schema( + type=types.Type.BOOLEAN, + description="Whether to include --upgrade in the pip invocation (default: False)", + ), + "extra_args": types.Schema( + type=types.Type.ARRAY, + description="Additional arguments to pass to pip install", + items=types.Schema(type=types.Type.STRING), + ), + "timeout_seconds": types.Schema( + type=types.Type.NUMBER, + description="Optional timeout for the command in seconds", + ), + }, + required=["packages"], + ), + ) + + def _get_user_groups(self, tool_context: ToolContext) -> List[str]: + user_groups = tool_context.state.get("user_groups", ["user"]) + return user_groups + + def _check_access(self, user_groups: List[str]) -> bool: + return any(group in self.access_groups for group in user_groups) + + def _create_vanna_context( + self, tool_context: ToolContext, user_groups: List[str] + ) -> VannaToolContext: + user_id = tool_context.user_id + user_email = tool_context.state.get("user_email", "user@example.com") + + vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups) + return VannaToolContext( + user=vanna_user, + conversation_id=tool_context.session.id, + request_id=tool_context.session.id, + agent_memory=self.agent_memory, + ) + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> str: + packages = args.get("packages", []) + upgrade = args.get("upgrade", False) + extra_args = args.get("extra_args", []) + timeout_seconds = args.get("timeout_seconds") + + if not packages: + return "Error: No packages provided" + + try: + user_groups = self._get_user_groups(tool_context) + if not self._check_access(user_groups): + return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}" + + vanna_context = self._create_vanna_context(tool_context, user_groups) + args_model = self.vanna_tool.get_args_schema()( + packages=packages, + upgrade=upgrade, + extra_args=extra_args, + timeout_seconds=timeout_seconds, + ) + result = await self.vanna_tool.execute(vanna_context, args_model) + return str(result.result_for_llm) + except Exception as e: + return f"Error installing packages: {str(e)}" diff --git a/veadk/tools/vanna_tools/run_sql.py b/veadk/tools/vanna_tools/run_sql.py new file mode 100644 index 00000000..0a083423 --- /dev/null +++ b/veadk/tools/vanna_tools/run_sql.py @@ -0,0 +1,123 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, List +from google.adk.tools import BaseTool, ToolContext +from google.genai import types +from vanna.tools import RunSqlTool as VannaRunSqlTool +from vanna.core.user import User +from vanna.core.tool import ToolContext as VannaToolContext + + +class RunSqlTool(BaseTool): + """Execute SQL queries against a database.""" + + def __init__( + self, + sql_runner, + file_system, + agent_memory, + access_groups: Optional[List[str]] = None, + ): + """ + Initialize the SQL tool with custom sql_runner and file_system. + + Args: + sql_runner: A Vanna SQL runner instance (e.g., SqliteRunner, PostgresRunner) + file_system: A Vanna file system instance (e.g., LocalFileSystem) + agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory) + access_groups: List of user groups that can access this tool (e.g., ['admin', 'user']) + """ + self.sql_runner = sql_runner + self.file_system = file_system + self.agent_memory = agent_memory + self.vanna_tool = VannaRunSqlTool( + sql_runner=sql_runner, file_system=file_system + ) + self.access_groups = access_groups or ["admin", "user"] # Default: all groups + + super().__init__( + name="run_sql", # Keep the same name as Vanna + description="Execute a SQL query against the database and return results as a CSV file.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "sql": types.Schema( + type=types.Type.STRING, + description="The SQL query to execute", + ), + }, + required=["sql"], + ), + ) + + def _get_user_groups(self, tool_context: ToolContext) -> List[str]: + """Get user groups from context.""" + user_groups = tool_context.state.get("user_groups", ["user"]) + return user_groups + + def _check_access(self, user_groups: List[str]) -> bool: + """Check if user has access to this tool.""" + return any(group in self.access_groups for group in user_groups) + + def _create_vanna_context( + self, tool_context: ToolContext, user_groups: List[str] + ) -> VannaToolContext: + """Create Vanna context from Veadk ToolContext.""" + user_id = tool_context.user_id + user_email = tool_context.state.get("user_email", "user@example.com") + + vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups) + + vanna_context = VannaToolContext( + user=vanna_user, + conversation_id=tool_context.session.id, + request_id=tool_context.session.id, + agent_memory=self.agent_memory, + ) + + return vanna_context + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> str: + """Execute the SQL query.""" + sql = args.get("sql", "").strip() + + if not sql: + return "Error: No SQL query provided" + + try: + # Get user groups and check access + user_groups = self._get_user_groups(tool_context) + + if not self._check_access(user_groups): + return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}" + + # Create Vanna context once per request + vanna_context = self._create_vanna_context(tool_context, user_groups) + + # Execute using Vanna tool + args_model = self.vanna_tool.get_args_schema()(sql=sql) + result = await self.vanna_tool.execute(vanna_context, args_model) + + return str(result.result_for_llm) + except Exception as e: + return f"Error executing SQL query: {str(e)}" diff --git a/veadk/tools/vanna_tools/summarize_data.py b/veadk/tools/vanna_tools/summarize_data.py new file mode 100644 index 00000000..1a5a407a --- /dev/null +++ b/veadk/tools/vanna_tools/summarize_data.py @@ -0,0 +1,124 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +import io +from typing import Any, Dict, Optional, List +from google.adk.tools import BaseTool, ToolContext +from google.genai import types +from vanna.core.user import User +from vanna.core.tool import ToolContext as VannaToolContext + + +class SummarizeDataTool(BaseTool): + """Generate statistical summaries of CSV data files.""" + + def __init__( + self, + file_system, + agent_memory, + access_groups: Optional[List[str]] = None, + ): + """ + Initialize the summarize data tool with custom file_system. + + Args: + agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory) + file_system: A Vanna file system instance (e.g., LocalFileSystem) + access_groups: List of user groups that can access this tool (e.g., ['admin', 'user']) + """ + self.agent_memory = agent_memory + self.file_system = file_system + self.access_groups = access_groups or ["admin", "user"] + + super().__init__( + name="summarize_data", + description="Generate a statistical summary of data from a CSV file.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "filename": types.Schema( + type=types.Type.STRING, + description="The name of the CSV file to summarize", + ), + }, + required=["filename"], + ), + ) + + def _get_user_groups(self, tool_context: ToolContext) -> List[str]: + """Get user groups from context.""" + user_groups = tool_context.state.get("user_groups", ["user"]) + return user_groups + + def _check_access(self, user_groups: List[str]) -> bool: + """Check if user has access to this tool.""" + return any(group in self.access_groups for group in user_groups) + + def _create_vanna_context( + self, tool_context: ToolContext, user_groups: List[str] + ) -> VannaToolContext: + """Create Vanna context from Veadk ToolContext.""" + user_id = tool_context.user_id + user_email = tool_context.state.get("user_email", "user@example.com") + + vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups) + + vanna_context = VannaToolContext( + user=vanna_user, + conversation_id=tool_context.session.id, + request_id=tool_context.session.id, + agent_memory=self.agent_memory, + ) + + return vanna_context + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> str: + """Generate a statistical summary of CSV data.""" + filename = args.get("filename", "").strip() + + if not filename: + return "Error: No filename provided" + + try: + user_groups = self._get_user_groups(tool_context) + + if not self._check_access(user_groups): + return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}" + + vanna_context = self._create_vanna_context(tool_context, user_groups) + + # Read the file content + content = await self.file_system.read_file(filename, vanna_context) + + # Parse into DataFrame + df = pd.read_csv(io.StringIO(content)) + + # Generate summary stats + description = df.describe().to_markdown() + head = df.head().to_markdown() + info = f"Rows: {len(df)}, Columns: {len(df.columns)}\nColumn Names: {', '.join(df.columns)}" + + summary = f"**Data Summary for {filename}**\n\n**Info:**\n{info}\n\n**First 5 Rows:**\n{head}\n\n**Statistical Description:**\n{description}" + return summary + except Exception as e: + return f"Failed to summarize data: {str(e)}" diff --git a/veadk/tools/vanna_tools/visualize_data.py b/veadk/tools/vanna_tools/visualize_data.py new file mode 100644 index 00000000..4a0fee8c --- /dev/null +++ b/veadk/tools/vanna_tools/visualize_data.py @@ -0,0 +1,123 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, List +from google.adk.tools import BaseTool, ToolContext +from google.genai import types +from vanna.tools import VisualizeDataTool as VannaVisualizeDataTool +from vanna.core.user import User +from vanna.core.tool import ToolContext as VannaToolContext + + +class VisualizeDataTool(BaseTool): + """Visualize data from CSV files.""" + + def __init__( + self, + file_system, + agent_memory, + access_groups: Optional[List[str]] = None, + ): + """ + Initialize the visualization tool with custom file_system. + + Args: + file_system: A Vanna file system instance (e.g., LocalFileSystem) + agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory) + access_groups: List of user groups that can access this tool (e.g., ['admin', 'user']) + user_group_resolver: Optional callable that takes ToolContext and returns user groups + """ + self.file_system = file_system + self.agent_memory = agent_memory + self.vanna_tool = VannaVisualizeDataTool(file_system=file_system) + self.access_groups = access_groups or ["admin", "user"] + + super().__init__( + name="visualize_data", # Keep the same name as Vanna + description="Create visualizations from CSV data files.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "filename": types.Schema( + type=types.Type.STRING, + description="The name of the CSV file to visualize", + ), + "title": types.Schema( + type=types.Type.STRING, + description="Optional title for the chart", + ), + }, + required=["filename"], + ), + ) + + def _get_user_groups(self, tool_context: ToolContext) -> List[str]: + """Get user groups from context.""" + user_groups = tool_context.state.get("user_groups", ["user"]) + return user_groups + + def _check_access(self, user_groups: List[str]) -> bool: + """Check if user has access to this tool.""" + return any(group in self.access_groups for group in user_groups) + + def _create_vanna_context( + self, tool_context: ToolContext, user_groups: List[str] + ) -> VannaToolContext: + """Create Vanna context from Veadk ToolContext.""" + user_id = tool_context.user_id + user_email = tool_context.state.get("user_email", "user@example.com") + + vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups) + + vanna_context = VannaToolContext( + user=vanna_user, + conversation_id=tool_context.session.id, + request_id=tool_context.session.id, + agent_memory=self.agent_memory, + ) + + return vanna_context + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> str: + """Create a visualization from CSV data.""" + filename = args.get("filename", "").strip() + title = args.get("title") + + if not filename: + return "Error: No filename provided" + + try: + user_groups = self._get_user_groups(tool_context) + + if not self._check_access(user_groups): + return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}" + + vanna_context = self._create_vanna_context(tool_context, user_groups) + + args_model = self.vanna_tool.get_args_schema()( + filename=filename, title=title + ) + result = await self.vanna_tool.execute(vanna_context, args_model) + + return str(result.result_for_llm) + except Exception as e: + return f"Error visualizing data: {str(e)}"