diff --git a/examples/veadk-vanna-proj/__init__.py b/examples/veadk-vanna-proj/__init__.py new file mode 100644 index 00000000..7f463206 --- /dev/null +++ b/examples/veadk-vanna-proj/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/veadk-vanna-proj/clean.py b/examples/veadk-vanna-proj/clean.py new file mode 100644 index 00000000..f3f195f3 --- /dev/null +++ b/examples/veadk-vanna-proj/clean.py @@ -0,0 +1,23 @@ +# Copyright (c) 2025 Beijing Volcano Engine Technology Co., Ltd. and/or its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def main() -> None:
    """Delete the cloud application named ``veadk-cloud-vanna-agent``."""
    app_to_remove = CloudApp(vefaas_application_name="veadk-cloud-vanna-agent")
    app_to_remove.delete_self()


if __name__ == "__main__":
    main()
# Fixed identifiers used by the smoke-test helpers below.
SESSION_ID = "cloud_app_test_session"
USER_ID = "cloud_app_test_user"


async def _send_msg_with_a2a(cloud_app: CloudApp, message: str) -> None:
    """Send *message* through ``cloud_app.message_send`` and print the reply.

    Args:
        cloud_app: The deployed application to talk to.
        message: Text sent as the user message.
    """
    print("===== A2A example =====")

    response_message = await cloud_app.message_send(message, SESSION_ID, USER_ID)

    if not response_message or not response_message.parts:
        print(
            "No response from VeFaaS application. Something wrong with cloud application."
        )
        return

    print(f"Message ID: {get_message_id(response_message)}")

    # Text parts are printed as plain text; any other part type is printed as-is.
    first_part = response_message.parts[0].root
    payload = first_part.text if isinstance(first_part, TextPart) else first_part
    print(f"Response from {cloud_app.vefaas_endpoint}: {payload}")


async def _send_msg_with_mcp(cloud_app: CloudApp, message: str) -> None:
    """Call the ``run_agent`` tool on the app's ``/mcp`` endpoint and print the result.

    Args:
        cloud_app: The deployed application to talk to.
        message: Text passed as the ``user_input`` tool argument.
    """
    print("===== MCP example =====")

    # NOTE(review): this relies on a private CloudApp method while the prints
    # below use the public ``vefaas_endpoint`` attribute - confirm whether the
    # public attribute can be used here as well.
    endpoint = cloud_app._get_vefaas_endpoint()
    print(f"MCP server endpoint: {endpoint}/mcp")

    # Connect to MCP server
    client = Client(f"{endpoint}/mcp")

    async with client:
        # List available tools
        tools = await client.list_tools()
        print(f"Available tools: {tools}")

        # Call run_agent tool, pass user input and session information
        res = await client.call_tool(
            "run_agent",
            {
                "user_input": message,
                "session_id": SESSION_ID,
                "user_id": USER_ID,
            },
        )
        print(f"Response from {cloud_app.vefaas_endpoint}: {res}")


async def main() -> None:
    """Deploy ``./src`` as a cloud application and report how to reach it.

    When the app is deployed with the ADK web UI, only the endpoint is printed.
    Otherwise a test message is sent through both the A2A and the MCP helper.
    """
    # Single switch for both the deploy argument and the post-deploy branch.
    # The original code branched on a literal ``if False:``, which made the
    # web branch unreachable even though ``use_adk_web=True`` was deployed.
    use_adk_web = True

    engine = CloudAgentEngine()

    cloud_app = engine.deploy(
        path=str(Path(__file__).parent / "src"),
        application_name="veadk-cloud-vanna-agent",
        # NOTE(review): this gateway name looks environment-specific - confirm
        # it is the intended value for a shared example.
        gateway_name="dong-mcp-agent2",
        gateway_service_name="",
        gateway_upstream_name="",
        use_adk_web=use_adk_web,
        auth_method="none",
        identity_user_pool_name="",
        identity_client_name="",
        local_test=False,  # Set to True for local testing before deploy to VeFaaS
    )
    print(f"VeFaaS application ID: {cloud_app.vefaas_application_id}")

    if use_adk_web:
        print(f"Web is running at: {cloud_app.vefaas_endpoint}")
        return

    # Test with deployed cloud application
    message = "How is the weather like in Beijing?"
    print(f"Test message: {message}")

    await _send_msg_with_a2a(cloud_app=cloud_app, message=message)
    await _send_msg_with_mcp(cloud_app=cloud_app, message=message)


if __name__ == "__main__":
    asyncio.run(main())
# Conversation history is kept in a local database file under /tmp.
_session_store = ShortTermMemory(backend="local", local_database_path="/tmp/session.db")  # type: ignore

# [required] the agent run configuration exported by this module
agent_run_config = AgentRunConfig(
    app_name="vanna_sql_agent",
    agent=agent,  # type: ignore
    short_term_memory=_session_store,
    model_extra_config={"extra_body": {"thinking": {"type": "disabled"}}},
)
logger = get_logger(__name__)

# Validate the imported run configuration before anything is built from it.
assert isinstance(agent_run_config, AgentRunConfig), (
    f"Invalid agent_run_config type: {type(agent_run_config)}, expected `AgentRunConfig`"
)

app_name = agent_run_config.app_name
agent = agent_run_config.agent
short_term_memory = agent_run_config.short_term_memory

VEFAAS_REGION = os.getenv("APP_REGION", "cn-beijing")
VEFAAS_FUNC_ID = os.getenv("_FAAS_FUNC_ID", "")

# The provider URL on the agent card is derived from the region / function id above.
agent_card_builder = AgentCardBuilder(
    agent=agent,
    provider=AgentProvider(
        organization="Volcengine Agent Development Kit (VeADK)",
        url=f"https://console.volcengine.com/vefaas/region:vefaas+{VEFAAS_REGION}/function/detail/{VEFAAS_FUNC_ID}",
    ),
)


def build_mcp_run_agent_func() -> Callable:
    """Build the coroutine function that is registered as the ``run_agent`` route.

    A single ``Runner`` is created here and captured by the returned coroutine.
    The coroutine's ``__doc__`` is set from the agent description plus a fixed
    Args/Returns section.
    """
    shared_runner = Runner(
        agent=agent,
        short_term_memory=short_term_memory,
        app_name=app_name,
        user_id="",
    )

    async def run_agent(
        user_input: str,
        user_id: str = "mcp_user",
        session_id: str = "mcp_session",
    ) -> str:
        # NOTE(review): the runner is shared by every call, so overlapping
        # requests from different users overwrite each other's user_id -
        # confirm whether that matters for this deployment.
        shared_runner.user_id = user_id

        # Run the agent and hand back its final output.
        return await shared_runner.run(
            messages=user_input,
            session_id=session_id,
        )

    run_agent.__doc__ = f"""{agent.description}
    Args:
        user_input: User's input message (required).
        user_id: User identifier. Defaults to "mcp_user".
        session_id: Session identifier. Defaults to "mcp_session".
    Returns:
        Final agent response as a string."""

    return run_agent


# Returns the built agent card as a plain dict.
async def agent_card() -> dict:
    card = await agent_card_builder.build()
    return card.model_dump()


# Returns the value of OBSERVABILITY_OPENTELEMETRY_COZELOOP_SERVICE_NAME ("" when unset).
async def get_cozeloop_space_id() -> dict:
    space_id = os.getenv("OBSERVABILITY_OPENTELEMETRY_COZELOOP_SERVICE_NAME", default="")
    return {"space_id": space_id}


# Build a run_agent function for building MCP server
run_agent_func = build_mcp_run_agent_func()

a2a_app = init_app(
    server_url="0.0.0.0",
    app_name=app_name,
    agent=agent,
    short_term_memory=short_term_memory,
)

# Routes tagged "mcp" are the ones picked up by FastMCP.from_fastapi below.
a2a_app.post("/run_agent", operation_id="run_agent", tags=["mcp"])(run_agent_func)
a2a_app.get("/agent_card", operation_id="agent_card", tags=["mcp"])(agent_card)
a2a_app.get(
    "/get_cozeloop_space_id", operation_id="get_cozeloop_space_id", tags=["mcp"]
)(get_cozeloop_space_id)

# === Build mcp server ===

mcp = FastMCP.from_fastapi(app=a2a_app, name=app_name, include_tags={"mcp"})

# Create MCP ASGI app
mcp_app = mcp.http_app(path="/", transport="streamable-http")


@asynccontextmanager
async def combined_lifespan(app: FastAPI):
    """Run the MCP app's lifespan for as long as the main app is alive."""
    async with mcp_app.lifespan(app):
        yield


# Create main FastAPI app with combined lifespan
app = FastAPI(
    title=a2a_app.title,
    version=a2a_app.version,
    lifespan=combined_lifespan,
    openapi_url=None,
    docs_url=None,
    redoc_url=None,
)


@app.middleware("http")
async def otel_context_middleware(request, call_next):
    """Attach the trace context from the request headers, if any, while the request runs."""
    carrier = {
        "traceparent": request.headers.get("Traceparent"),
        "tracestate": request.headers.get("Tracestate"),
    }
    logger.debug(f"carrier: {carrier}")

    # No incoming trace header: handle the request without touching the context.
    if carrier["traceparent"] is None:
        return await call_next(request)

    ctx = TraceContextTextMapPropagator().extract(carrier=carrier)
    logger.debug(f"ctx: {ctx}")
    token = context.attach(ctx)
    try:
        return await call_next(request)
    finally:
        # Always restore the previous context, even if the handler raised.
        context.detach(token)


# Mount A2A routes to main app
app.routes.extend(a2a_app.routes)

# Mount MCP server at /mcp endpoint
app.mount("/mcp", mcp_app)


# remove openapi routes
paths = ["/openapi.json", "/docs", "/redoc"]
new_routes = [
    route
    for route in app.router.routes
    if not (isinstance(route, (APIRoute, Route)) and route.path in paths)
]
app.router.routes = new_routes

# === Build mcp server end ===
+ +from .agent import agent + +# required from Google ADK Web +root_agent = agent diff --git a/examples/veadk-vanna-proj/src/data_agent/agent.py b/examples/veadk-vanna-proj/src/data_agent/agent.py new file mode 100644 index 00000000..3a392169 --- /dev/null +++ b/examples/veadk-vanna-proj/src/data_agent/agent.py @@ -0,0 +1,298 @@ +from veadk import Agent +from google.adk.planners import PlanReActPlanner +from .tools import ( + run_sql, + visualize_data, + save_correctanswer_memory, + search_similar_tools, + generate_document, + summarize_data, + run_python_file, + pip_install, + read_file, + edit_file, + list_files, + search_files, + save_text_memory, + query_with_dsl, + recall_metadata, + get_current_time, +) + +# # Define the Veadk Agent using Vanna Tools +# agent: Agent = Agent( +# name="b2b_data_agent", +# description="Assistant for querying B2B customer, revenue, and usage data.", +# instruction=""" +# You are a data analysis agent for a Cloud Service Provider. +# You have access to a SQLite database `b2b_crm.sqlite` with the following schema: + +# - `customer`: Stores customer profiles. Key fields: `name` (full name), `short_name`, `is_main_customer` (1=True), `sales_team`. +# - `revenue`: Monthly revenue data. Fields: `year_month` (YYYY-MM), `product_name`, `amount`. +# - `resource_usage`: Daily usage data. Fields: `usage_date` (YYYY-MM-DD), `resource_type` (Tokens, GPU), `quantity`. +# - `account_credit`: Credit status. Fields: `total_credit_limit`, `available_balance`, `arrears_amount` (positive means debt). + +# **Available Tools:** +# - `run_sql(sql)`: Executes SQL queries on the B2B CRM database. +# - `run_python_file(filename)`: Executes Python scripts. +# - `pip_install(packages)`: Installs Python packages. +# - `visualize_data(filename, title)`: Creates visualizations from CSV files generated by SQL queries. +# - `summarize_data(filename)`: Generates statistical summaries of CSV files. 
+# - `generate_document(filename, content)`: Creates a new file with the given content. +# - `read_file(filename, start_line, end_line)`: Reads the content of a file. +# - `edit_file(filename, edits)`: Edits a file by replacing lines. +# - `list_files(path)`: Lists files in a directory. +# - `search_files(query, path)`: Searches for files matching a query. +# - `search_similar_tools(question, limit)`: Searches for similar past tool usages. +# - `save_correctanswer_memory(question, tool_name, args)`: Saves successful tool usages. +# - `save_text_memory(text, tags)`: Saves arbitrary text to memory for future retrieval. + +# **Strategy for Ambiguous Requests:** +# 1. **Name Disambiguation**: If a user asks for "Xiaomi" (小米), ALWAYS check `customer` table first. Prefer `is_main_customer=1` unless specified otherwise. +# - *Example SQL*: `SELECT * FROM customer WHERE (name LIKE '%小米%' OR short_name = '小米') AND is_main_customer = 1` +# 2. **Time Ranges**: +# - "Last 3 months" usually means the last 3 completed billing cycles in `revenue` table. +# - "Recent trend" implies querying `resource_usage` and plotting the `quantity` over `usage_date`. +# 3. **Missing Data**: If specific daily data (e.g., "today") is requested but not in the DB, explain that data might not be generated yet. + +# **Report Generation Requirement:** +# For complex analysis tasks (e.g., "Analyze anomaly", "Generate report", "Forecast trend", "Diagnose issue"), you MUST: +# 1. Perform the analysis using SQL and Python. +# 2. **Generate a Markdown Report**: Use `generate_document` to save a detailed report (e.g., `analysis_report.md`). +# - The report MUST include: **Executive Summary**, **Methodology** (SQL/Python logic), **Detailed Findings** (with data tables/charts), and **Recommendations**. +# 3. In your Final Answer, provide a brief summary AND the file path of the generated report. 
+ +# **Output Requirement:** +# You MUST describe the detailed execution process in your final answer using **Chinese**. The description should include: +# 1. **Thought Process**: How you analyzed the request and what strategy you chose. +# 2. **Tool Usage**: Which tools were used, with what parameters (e.g., specific SQL queries). +# 3. **Intermediate Results**: Key findings from each step (e.g., "Found customer ID ACC-001 for Xiaomi"). +# 4. **Final Answer**: The direct answer to the user's question, supported by the data found. + +# **Planning Strategy:** +# 1. **Analyze the Request**: Determine if the request requires simple database retrieval (Text-to-SQL) or complex analysis/calculation (Text-to-Python). +# 2. **Formulate a Plan**: +# * **Simple Path (Text-to-SQL)**: If the request is a direct data lookup (e.g., "What were the sales last month?"), create a plan to write and execute a SQL query using `run_sql`. +# * **Complex Path (Text-to-Python/Multi-turn)**: If the request involves advanced analytics, predictions, complex logic, or non-SQL operations (e.g., "Predict next month's sales trend", "Find anomalies in sales"), create a multi-step plan: +# a. Retrieve necessary data using `run_sql`. +# b. Write a Python script using `generate_document` to process the data. +# c. Execute the script using `run_python_file`. +# d. Analyze the output. +# 3. **Execute & Observe**: Follow your plan, executing tools and observing outputs. +# 4. **Iterative Refinement (Multi-turn)**: +# * **Error Recovery**: If a tool execution fails (e.g., SQL syntax error, Python runtime error), analyze the error message, revise your plan, and retry using `REPLANNING`. +# * **Clarification**: If the request is ambiguous, ask the user for clarification. 
+ +# Here is the schema details of the B2B CRM database: +# ```sql +# CREATE TABLE IF NOT EXISTS customer ( +# customer_id TEXT PRIMARY KEY, +# name TEXT NOT NULL, -- Full Name +# short_name TEXT, -- Short Name +# is_main_customer BOOLEAN, -- Is Main Customer (1=Yes, 0=No) +# customer_level TEXT, -- Customer Level (Strategic, KA, NA) +# owner TEXT, -- Owner Name +# sales_team TEXT, -- Sales Team +# industry TEXT, +# status TEXT +# ); +# CREATE TABLE IF NOT EXISTS revenue ( +# id INTEGER PRIMARY KEY AUTOINCREMENT, +# customer_id TEXT, +# year_month TEXT, -- Revenue Month (YYYY-MM) +# product_category TEXT, -- Product Category +# product_name TEXT, -- Product Name +# amount REAL, -- Revenue Amount +# FOREIGN KEY(customer_id) REFERENCES customer(customer_id) +# ); +# CREATE TABLE IF NOT EXISTS resource_usage ( +# id INTEGER PRIMARY KEY AUTOINCREMENT, +# customer_id TEXT, +# usage_date TEXT, -- Usage Date (YYYY-MM-DD) +# resource_type TEXT, -- Resource Type +# model_or_card TEXT, -- Model/Card Type +# quantity REAL, -- Usage Quantity +# FOREIGN KEY(customer_id) REFERENCES customer(customer_id) +# ); +# CREATE TABLE IF NOT EXISTS account_credit ( +# customer_id TEXT PRIMARY KEY, +# total_credit_limit REAL, -- Total Credit Limit +# available_balance REAL, -- Available Balance +# arrears_amount REAL, -- Arrears Amount +# FOREIGN KEY(customer_id) REFERENCES customer(customer_id) +# ); +# ``` + +# Here are some examples of how to query this database: + +# Q: "小米客户近3个月的收入" (Ambiguous Name & Time Range) +# Thought: User asks for "Xiaomi". I need to find the main customer "Xiaomi" to avoid "Xiaomi Shoes". "Last 3 months" refers to revenue data. +# Plan: +# 1. Find the `customer_id` for "Xiaomi" where `is_main_customer=1`. +# 2. Query `revenue` table for this `customer_id` for the last 3 months. 
+# A: SELECT sum(amount) FROM revenue WHERE customer_id = (SELECT customer_id FROM customer WHERE (name LIKE '%小米%' OR short_name LIKE '%小米%') AND is_main_customer=1) AND year_month >= strftime('%Y-%m', date('now', '-3 months')) + +# Q: "小米最近的用量趋势" (Complex Trend Visualization) +# Thought: User wants "trend". This requires daily data from `resource_usage` and a chart. +# Plan: +# 1. Get `customer_id` for "Xiaomi" (Main Customer). +# 2. Query `usage_date` and `quantity` from `resource_usage` for the last 30 days. +# 3. Save result to CSV. +# 4. Call `visualize_data` to plot the trend. +# A: (Plan to use `run_sql` then `visualize_data`) + +# Q: "查一下分期乐的信控情况,有没有欠费?" (Derived Metric & Join) +# Thought: "Debt" or "Arrears" means checking `arrears_amount` in `account_credit` table. +# Plan: +# 1. Find `customer_id` for "Fenqile" (分期乐). +# 2. Join `customer` and `account_credit` to get credit limit, balance, and arrears. +# 3. If `arrears_amount` > 0, report it as debt. +# A: SELECT c.name, a.total_credit_limit, a.available_balance, a.arrears_amount FROM customer c JOIN account_credit a ON c.customer_id = a.customer_id WHERE c.name LIKE '%分期乐%' OR c.short_name LIKE '%分期乐%' + +# 1. Use `run_sql` to execute queries. 
+# """, +# tools=[ +# run_sql, # RunSqlTool: Execute SQL queries +# visualize_data, # VisualizeDataTool: Create visualizations +# save_correctanswer_memory, # SaveQuestionToolArgsTool: Save tool usage examples +# search_similar_tools, # SearchSavedCorrectToolUsesTool: Search tool usage examples +# generate_document, # WriteFileTool: Create new files +# summarize_data, # SummarizeDataTool: Summarize CSV data +# run_python_file, # RunPythonFileTool: Execute Python scripts +# pip_install, # PipInstallTool: Install Python packages +# read_file, # ReadFileTool: Read file content +# edit_file, # EditFileTool: Edit file content +# list_files, # ListFilesTool: List directory content +# search_files, # SearchFilesTool: Search for files +# save_text_memory # SaveTextMemoryTool: Save text to memory +# ], +# planner=PlanReActPlanner(), +# model_extra_config={"extra_body": {"thinking": {"type": "disabled"}}} +# ) + + +# Define the Veadk Agent using Vanna Tools +agent: Agent = Agent( + name="b2b_data_agent", + description="Assistant for querying B2B customer, revenue, and usage data.", + instruction=""" +### 任务 +您是一个AI助手,你的任务如下: +- 根据用户自然语言请求,调用工具 `recall_metadata` 查询数据库元数据,理解用户查询中涉及的数据对象、字段、过滤条件等信息。注意,调用工具 `recall_metadata`的时候,tenant参数请固定为"c360"。 +- 根据用户自然语言请求和数据库元数据生成数据可视化引擎的查询结构DSL,目标是解析用户的查询,识别所需的数据对象、字段、过滤器、排序、分组、限制。切记你构造查询结构的所有的字段信息必须从元数据中获取,不允许胡乱编造。 +- 调用工具 `query_with_dsl` 查询业务数据。注意,调用工具 `query_with_dsl` 的时候,operator参数固定为 "liujiawei.boom@bytedance.com",tenant参数请固定为"c360"。 +- 对于复杂的分析任务(例如,“分析异常”、“生成报告”、“预测趋势”、“诊断问题”),你必须: + - 使用Python进行分析。 + - **生成Markdown报告**:使用`generate_document`保存详细报告(例如,`analysis_report.md`)。 + - 报告必须包含:**执行摘要**、**方法论**(Python逻辑)、**详细发现**(含数据表格/图表)以及**建议**。 + - 在你的最终答案中,提供简要摘要以及生成报告的文件路径。 + +### 关键指南: +- **分析元数据**: + - 分析元数据,将用户描述的字段、对象或条件映射到确切的字段名。**注意** 对于使用到字段名的地方,严格按照元数据提供的字段名原样使用,不要修改,例如元数据提供的字段名= "sf_id",在使用到的地方就用"sf_id",不要修改为"sfid" + - 对于枚举字段(字段的数据类型='enum') + 1. 
基于抽样值理解枚举值数据,描述结构为"value:`值`,lable:`label`" 中的label理解关键字,但始终在过滤器或条件中使用对应的value。例如:在名为'account'数据对象中,如果字段'sub_industry_c'是枚举类型,其中一个label是'游戏',value是'Game',那么如果用户说“游戏客户”,则解释为查询对象'account',过滤器为:"sub_industry_c = 'Game'"。对所有枚举应用此逻辑。 + 2. 如果使用枚举字段作为三元组判断条件,不能使用contains函数,而应该使用“=”,例如要实现“模型简称='DeepSeek'”,三元组应为"ModelShortName = 'DeepSeek'" + - 对于文本字段(字段的数据类型='text'),有以下约定 + 1. 如果同时该字段的特殊类型是“可模糊匹配”时,在过滤器条件中不能使用'='操作符,而应使用contains函数,例如name.contains('名称'),反之则不能使用contains函数 +- **解析用户查询**: + - 从用户需求中识别核心数据对象(obj)(例如,如果用户提到“客户”或“accounts”,则映射到元数据中匹配的对象)。 + - 识别字段:用于显示、过滤、排序(orderBy)、分组(groupBy)。 + - 过滤器:构建“filter”中的逻辑表达式,有以下约定: + 1. 值由三元组+逻辑连接符\大括号号嵌套连接组成,如 举例:"field1 = 'value' && (field2 > 10 or field3 = 11)"中,"field1 = 'value'"、"field2 > 10"、"field3 = 11"为三元组,"and"和"or"为逻辑连接符,"()"为嵌套逻辑 + 2. 三元组中,左值或右值可以为字段、函数、常量(如字符串、整数等),中值为比较符,如(=、>、<等) + 3. 对于日期的处理:如果用户提到“本月”,则“本月”是指当前月的第一天,将过滤器设置为日期字段 >= 当前月的第一天(格式为'YYYY-MM-01',基于当前日期计算)。 + - OrderBy:排序字段,例如"field DESC"(如果降序)。特殊规则:对于客户的查询(即query.obj = 'account'),如果用户未指定query.orderBy,则默认按照客户等级倒序排列(从元数据中映射“客户等级”字段的apiName,并设置为“ DESC”) + - GroupBy:聚合字段,例如"field"(如果求和或计数)。 + - Limit:仅整数,例如10;如果未指定,默认为100 + + - 对于客户对象的查询,有以下约定: + 1. "L6、L7"等"L级"指的是客户标签 + 2. 如果是需要按照客户名称过滤数据,默认需要使用名称和简称一起模糊搜索 + 3. 除非明确要求输出客户ID,否则不要返回 + 4. "ACC-" 这样的一串编号是指"客户编号"字段 + 5. "腾讯"指的是客户名称 + 6. **拜访/跟进时间查询**: + - 用户表述:"最近拜访时间"、"最近跟进时间"、"最新拜访日期"、"最后一次拜访"、"最后一次跟进"、"最近一次拜访是什么时候" + - 客户表的 statistical_data.AggLatestNoteSubmitTime(最近拜访日期)字段,不可排序 + - **同义词**:"拜访" = "跟进","时间" = "日期" = "是什么时候" + - ❌ 不要从拜访/跟进记录表查询或使用orderBy + - **重要**:即使用户问"最后一次跟进"或"最近一次拜访",这是描述字段含义,不代表只返回1条记录。除非用户明确要求"只看1个客户",否则limit保持默认100。 + + - 对于用量数据的查询,有以下约定: + 1. 除了根据“大模型”、“CPU”、“GPU”这几个词来确定查询的数据对象外,还可以根据大模型用量对象中的“Model简称”字段确认本次查询是查大模型数据用量,也可以根据CPU&GPU用量对象中的“GPU卡型号”字段确认使用该对象 + 2. 如果是还要查询客户数据,则默认以客户ID作为groupBy + +### DSL构建规则 +1. filter过滤器禁止使用子查询语句。 +2.选取的元数据字段必须来自于同一数据对象,禁止跨多数据对象选取字段。 + +- 其他约定: +1. 对于100万、1亿这类的金额,在进行过滤时,需要转换成正确的数字,如100万应转换为1000000 +2. 火山账号一般为 210 开头的 10 位数字,如2100001029 +3. 
AppC360DmVolcengineDailyIncomeDf **不支持时间筛选,禁止添加时间条件** + +### DSL示例 +{ + "type": "object", + "properties": { + "Operator": { + "type": "string", + "description": "查询人邮箱" + }, + "Select": { + "type": "string", + "description": "要查询的字段名,多个字段用逗号分隔,类型为字符串" + }, + "Where": { + "type": "string", + "description": "过滤条件的逻辑表达式字符串,如 \"a = b or c = d\",用于筛选结果" + }, + "Limit": { + "type": "string", + "description": "返回结果的数量限制,默认10,范围1-10000" + }, + "OrderBy": { + "type": "string", + "description": "排序字段及方式,格式如“字段名 asc”表示正序,默认无排序" + }, + "Table": { + "type": "string", + "description": "查询的目标数据对象名,字符串类型,不能为空" + } + }, + "required": [ + "Operator", + "Select", + "Table" + ] +} + +### 输出要求: +你必须在最终答案中用**中文**描述详细的执行过程。描述应包括: +1. **思考过程**:你如何分析请求以及选择了何种策略。 +2. **工具使用**:使用了哪些工具,以及使用了什么参数。 +3. **中间结果**:每个步骤的关键发现。 +4. **最终答案**:对用户问题的直接回答,并辅以找到的数据支持。 + """, + tools=[ + run_sql, # RunSqlTool: Execute SQL queries + visualize_data, # VisualizeDataTool: Create visualizations + save_correctanswer_memory, # SaveQuestionToolArgsTool: Save tool usage examples + search_similar_tools, # SearchSavedCorrectToolUsesTool: Search tool usage examples + generate_document, # WriteFileTool: Create new files + summarize_data, # SummarizeDataTool: Summarize CSV data + run_python_file, # RunPythonFileTool: Execute Python scripts + pip_install, # PipInstallTool: Install Python packages + read_file, # ReadFileTool: Read file content + edit_file, # EditFileTool: Edit file content + list_files, # ListFilesTool: List directory content + search_files, # SearchFilesTool: Search for files + save_text_memory, # SaveTextMemoryTool: Save text to memory + query_with_dsl, + recall_metadata, + get_current_time, + ], + planner=PlanReActPlanner(), + model_extra_config={"extra_body": {"thinking": {"type": "disabled"}}}, +) diff --git a/examples/veadk-vanna-proj/src/data_agent/tools.py b/examples/veadk-vanna-proj/src/data_agent/tools.py new file mode 100644 index 00000000..188f92fe --- /dev/null +++ 
def setup_sqlite(
    sample_data_path: Optional[str] = None,
    fallback_db_path: str = "/tmp/Chinook.sqlite",
    download_url: str = "https://vanna.ai/Chinook.sqlite",
) -> str:
    """Return the path of the SQLite database the tools should query.

    Resolution order:
      1. The bundled B2B sample database, if it exists.
      2. A previously downloaded copy at ``fallback_db_path``.
      3. A fresh download of ``download_url`` into ``fallback_db_path``.

    The download is best-effort: on failure the error is printed and
    ``fallback_db_path`` is still returned, but no partial file is left behind,
    so the next call retries instead of reusing a truncated database.

    Args:
        sample_data_path: Location of the bundled database. Defaults to
            ``<parent of this file's directory>/sample_data/b2b_crm.sqlite``.
        fallback_db_path: Where the fallback database is cached. The default
            lives under /tmp because only /tmp is writable in VeFaaS.
        download_url: URL of the fallback database.

    Returns:
        The path of the database file to use.
    """
    if sample_data_path is None:
        # Go up one level from this file's directory (to src), then into sample_data.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        sample_data_path = os.path.join(
            os.path.dirname(current_dir), "sample_data", "b2b_crm.sqlite"
        )

    if os.path.exists(sample_data_path):
        return sample_data_path

    # Fallback to Chinook if B2B data not found (or for compatibility)
    if os.path.exists(fallback_db_path):
        return fallback_db_path

    print("Downloading Chinook.sqlite...")
    # Stream into a side file and rename only after a complete, successful
    # download, so an interrupted transfer never looks like a valid cache hit.
    partial_path = f"{fallback_db_path}.part"
    try:
        with httpx.stream("GET", download_url, follow_redirects=True) as response:
            # Without this an HTTP error page would be stored as the database.
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_bytes():
                    f.write(chunk)
        os.replace(partial_path, fallback_db_path)
    except Exception as e:  # best-effort boundary: report and carry on
        print(f"Error downloading database: {e}")
        try:
            os.remove(partial_path)
        except OSError:
            # Nothing was written, or the file is already gone.
            pass
    return fallback_db_path
async def visualize_data(filename: str, title: Optional[str] = None) -> str:
    """
    Visualize data from a CSV file.

    Args:
        filename: The name of the CSV file to visualize. Names that do not end
            in ".csv" (case-insensitive) are rejected with an error message.
        title: Optional title for the chart.

    Returns:
        The tool's result text for the LLM, or an error message when the file
        name does not look like a CSV file.
    """
    # Reject non-CSV names up front so the model gets an actionable message
    # without invoking the underlying tool.
    if not filename.lower().endswith(".csv"):
        return (
            f"Error: visualize_data only supports CSV files. You provided: {filename}"
        )

    args_model = viz_tool.get_args_schema()(filename=filename, title=title)
    result = await viz_tool.execute(mock_context, args_model)
    return str(result.result_for_llm)
async def read_file(filename: str, start_line: int = 1, end_line: int = -1) -> str:
    """
    Read the content of a file.

    Args:
        filename: The name of the file to read.
        start_line: The line number to start reading from (1-based).
        end_line: The line number to stop reading at (inclusive). -1 for end of file.

    Returns:
        The read-file tool's result text for the LLM.
    """
    # Build the tool's argument model first, then hand it to the tool together
    # with the shared execution context.
    schema = read_file_tool.get_args_schema()
    request = schema(filename=filename, start_line=start_line, end_line=end_line)
    outcome = await read_file_tool.execute(mock_context, request)
    return str(outcome.result_for_llm)
async def save_text_memory(text: str, tags: Optional[list[str]] = None) -> str:
    """
    Save arbitrary text to memory for future retrieval.

    Args:
        text: The text content to save.
        tags: Accepted for interface compatibility only. Tags are currently
            NOT stored: the underlying save-text-memory arguments used here
            have no tags field, so this value is ignored.

    Returns:
        The memory tool's result text for the LLM.
    """
    # The underlying tool names the field 'content'; it is exposed as 'text'
    # to the LLM for clarity and mapped here. 'tags' is intentionally dropped.
    args_model = save_text_mem_tool.get_args_schema()(content=text)
    result = await save_text_mem_tool.execute(mock_context, args_model)
    return str(result.result_for_llm)
async def summarize_data(filename: str) -> str:
    """
    Generate a statistical summary of data from a CSV file.

    Args:
        filename: The name of the CSV file to summarize.

    Returns:
        A Markdown report with the row/column counts, the first five rows and
        the ``describe()`` statistics, or a "Failed to summarize data" message
        if any step raises.
    """
    try:
        # Load the CSV text through the shared file system and parse it.
        raw_csv = await file_system.read_file(filename, mock_context)
        frame = pd.read_csv(io.StringIO(raw_csv))

        # Render the individual pieces of the report.
        stats_table = frame.describe().to_markdown()
        preview_table = frame.head().to_markdown()
        column_list = ", ".join(frame.columns)
        shape_line = (
            f"Rows: {len(frame)}, Columns: {len(frame.columns)}\n"
            f"Column Names: {column_list}"
        )

        # Sections are separated by one blank line.
        sections = (
            f"**Data Summary for {filename}**",
            f"**Info:**\n{shape_line}",
            f"**First 5 Rows:**\n{preview_table}",
            f"**Statistical Description:**\n{stats_table}",
        )
        return "\n\n".join(sections)
    except Exception as exc:
        return f"Failed to summarize data: {str(exc)}"
def query_with_dsl(
    operator: str,
    tenant: str,
    table: str,
    select: str,
    group_by: Optional[str] = None,
    where: Optional[str] = None,
    order_by: Optional[str] = None,
    limit: Optional[int] = 100,
    timeout: int = 30,
) -> Dict[str, Any]:
    """
    Query data through the structured (DSL) query endpoint.

    Args:
        operator: Identity of the caller, usually a corporate e-mail address;
            used for auditing and permission checks.
        tenant: Tenant identifier; must match the one used for metadata recall.
        table: Name of the table to query.
        select: Fields to return, separated by commas.
        group_by: Grouping fields, separated by commas.
        where: Filter condition in SQL-like syntax.
        order_by: Sort condition, formatted as "field ASC/DESC".
        limit: Maximum number of records to return. Defaults to 100; pass
            None to omit the limit from the request.
        timeout: Request timeout in seconds.

    Returns:
        The decoded JSON response on success. On any failure an error string
        ("错误: <exception>, 返回内容: <body>") is returned instead of raising;
        <body> is None when the failure happened before a body was decoded.

    Example:
        result = query_with_dsl(
            operator="user@example.com",
            tenant="c360",
            table="large_model_usage",
            select="account_number, request_date, token_amount",
            group_by="account_number, request_date",
            where="request_date >= '2025-11-24' and account_number != 'ACC-0000872346'",
            order_by="request_date DESC",
            limit=100
        )
    """
    # API endpoint
    host = "bytedance"
    url = f"http://eps-agent.{host}.net/search_metadata/query?Action=Query"

    # Mandatory request fields
    payload = {
        "Operator": operator,
        "Tenant": tenant,
        "Table": table,
        "Select": select,
    }

    # Optional fields are only sent when they carry a value.
    if group_by:
        payload["GroupBy"] = group_by
    if where:
        payload["Where"] = where
    if order_by:
        payload["OrderBy"] = order_by
    if limit is not None:
        payload["Limit"] = limit

    headers = {"Content-Type": "application/json"}

    # Bound before the try block so the error message can always be built,
    # even when the request fails before any body was decoded.
    result = None
    try:
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)

        # Decode the body before checking the status so that an HTTP error
        # can be reported together with the returned content.
        result = response.json()
        response.raise_for_status()

        return result

    except Exception as e:
        return f"错误: {e}, 返回内容: {result}"


def recall_metadata(tenant: str, query: str, timeout: int = 30) -> Dict[str, Any]:
    """
    Call the metadata recall endpoint.

    Args:
        tenant: Tenant name.
        query: Free-text query.
        timeout: Request timeout in seconds.

    Returns:
        The decoded JSON response on success. When the request fails with a
        ``requests`` exception, an error string ("错误: <exception>, 返回内容:
        <body>") is returned instead; <body> is None when the failure happened
        before a body was decoded.

    Example:
        result = recall_metadata(
            tenant="c360",
            query="What is Xiaomi's revenue?"
        )
    """
    # API endpoint
    host = "bytedance"
    url = f"http://eps-agent.{host}.net/search_metadata/metadata?Action=RecallMetadata"

    headers = {"Content-Type": "application/json"}
    payload = {"Tenant": tenant, "Query": query}

    # Bound before the try block: the handler below formats it even when the
    # request itself failed.
    result = None
    try:
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)

        # Decode first so an HTTP error can be reported with its body.
        result = response.json()
        response.raise_for_status()

        return result

    except requests.exceptions.RequestException as e:
        return f"错误: {e}, 返回内容: {result}"


def get_current_time() -> str:
    """
    Return the current local time as a string.

    Returns:
        The current time formatted as "YYYY-MM-DD HH:MM:SS".

    Example:
        current_time = get_current_time()
    """
    from datetime import datetime

    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
#!/bin/bash
# Entry point of the deployed function: starts either the veadk web UI or the
# A2A/MCP uvicorn server, depending on USE_ADK_WEB.
set -ex
cd "$(dirname "$0")"

# A special check for CLI users (run.sh should be located at the 'root' dir)
if [ -d "output" ]; then
    cd ./output/
fi

# Default values for host and port
HOST="0.0.0.0"
PORT="${_FAAS_RUNTIME_PORT:-8000}"
# May be empty when the platform does not provide a function timeout.
TIMEOUT="${_FAAS_FUNC_TIMEOUT:-}"

export PYTHONPATH="$PYTHONPATH:./site-packages"

# Parse arguments
while [[ $# -gt 0 ]]; do
    case $1 in
        --port)
            PORT="$2"
            shift 2
            ;;
        --host)
            HOST="$2"
            shift 2
            ;;
        *)
            shift
            ;;
    esac
done

# Export after parsing so the environment reflects --host/--port overrides.
export SERVER_HOST="$HOST"
export SERVER_PORT="$PORT"

USE_ADK_WEB="${USE_ADK_WEB:-False}"

export SHORT_TERM_MEMORY_BACKEND= # can be `mysql`
export LONG_TERM_MEMORY_BACKEND= # can be `opensearch`

if [ "$USE_ADK_WEB" = "True" ]; then
    echo "USE_ADK_WEB is True, running veadk web"
    exec python3 -m veadk.cli.cli web --host "$HOST"
else
    echo "USE_ADK_WEB is False, running A2A and MCP server"
    UVICORN_ARGS=(app:app --host "$HOST" --port "$PORT" --loop asyncio)
    # Only pass the graceful-shutdown timeout when one is configured; an empty
    # value would make the option consume the next flag as its argument.
    if [ -n "$TIMEOUT" ]; then
        UVICORN_ARGS+=(--timeout-graceful-shutdown "$TIMEOUT")
    fi
    exec python3 -m uvicorn "${UVICORN_ARGS[@]}"
fi
**Action**: `run_sql("SELECT sum(amount) FROM revenue WHERE customer_id = 'ACC-001' AND year_month >= ...")` +4. **Result**: 准确返回小米科技(主客户)的收入数据,排除干扰项。 + +### 2. 网易2025年的总收入 + +**Bad Case (原逻辑):** +可能无法区分“网易”是指集团还是某个子公司,或者漏掉某些月份的数据。 + +**Expected (Agent 逻辑):** + +1. **Thought**: 查找“网易”对应的主客户 ID。 +2. **Action**: 锁定 ACC-005 (网易(杭州)网络有限公司)。 +3. **Action**: 聚合查询 2025 全年的 `revenue`。 +4. **Result**: 返回网易 2025 年度的准确总收入。 + +## 场景 2: 复杂趋势与预测 (SQL + Python) + +### 3. 小米最近的用量趋势(含可视化) + +**Bad Case (原逻辑):** +仅返回一堆数字或 SQL 查询结果,用户无法直观理解趋势。 + +**Expected (Agent 逻辑):** + +1. **Thought**: 用户需要“趋势”,意味着需要日粒度数据并生成图表。 +2. **Action**: `run_sql` 查询 `resource_usage` 表获取近 30 天的 `usage_date` 和 `quantity`,保存为 CSV。 +3. **Action**: 调用 `visualize_data` 工具,传入 CSV 数据生成折线图。 +4. **Result**: 返回一张清晰的用量趋势折线图。 + +### 4. 预测小米下个月的用量增长 + +**Bad Case (原逻辑):** +无法处理“预测”请求,因为 SQL 无法进行时间序列预测。 + +**Expected (Agent 逻辑):** + +1. **Thought**: 这是一个预测任务,需要提取历史数据并使用 Python 进行线性回归或简单外推。 +2. **Action**: `run_sql` 导出小米过去 3 个月的用量数据到 CSV。 +3. **Action**: `generate_document` 编写 Python 脚本:读取 CSV,使用 `scikit-learn` 或 `numpy` 拟合趋势线,预测下月总量。 +4. **Action**: `run_python_file` 执行预测脚本。 +5. **Result**: "基于过去3个月的增长趋势,预测小米下个月的用量约为 35,000,000 Tokens,环比增长 5%。" + +## 场景 3: 复杂逻辑与异常检测 (SQL + Python) + +### 5. 帮我分析一下小米的云资源使用异常 + +**Bad Case (原逻辑):** +不知道什么是“异常”,直接报错或返回空。 + +**Expected (Agent 逻辑):** + +1. **Thought**: 这是一个复杂分析任务。需要定义异常(如 3-Sigma 准则或 IQR)。 +2. **Action**: `run_sql` 获取小米的历史日用量数据。 +3. **Action**: `generate_document` 编写 Python 脚本:计算均值和标准差,识别超过 `mean + 3*std` 的日期。 +4. **Action**: `run_python_file` 执行脚本。 +5. **Action**: `generate_document` 生成分析报告 `xiaomi_anomaly_analysis.md`。 +6. **Result**: "检测到 15天前 用量激增,达到 3,000,000,超过平均值 3 倍,属于异常波动。详细分析请见报告:`xiaomi_anomaly_analysis.md`" + +### 6. 计算网易 DeepSeek 调用的周环比增长率 + +**Bad Case (原逻辑):** +SQL 计算周环比(WoW)非常复杂,容易出错。 + +**Expected (Agent 逻辑):** + +1. **Thought**: 使用 Python pandas 处理时间序列更高效。 +2. **Action**: `run_sql` 获取网易 DeepSeek 的每日用量数据。 +3. 
**Action**: `generate_document` 编写 Python 脚本:`df.resample('W').sum().pct_change()` 计算周环比。 +4. **Action**: `run_python_file` 执行脚本。 +5. **Result**: "上周对比上上周,DeepSeek 调用量增长了 15.2%。" + +## 场景 4: 跨表关联与业务洞察 + +### 7. 哪些欠费客户还在大量使用资源?(风险预警) + +**Bad Case (原逻辑):** +无法同时关联欠费状态和近期用量,或者 SQL 逻辑过于复杂导致超时。 + +**Expected (Agent 逻辑):** + +1. **Thought**: 需要找到 `arrears_amount > 0` 的客户,并检查其最近 3 天的 `resource_usage` 是否超过阈值。 +2. **Action**: `run_sql` 联查 `customer`, `account_credit`, `resource_usage`,筛选欠费且近3天有用量的客户。 +3. **Result**: "警告:分期乐(ACC-004)当前欠费 74.8万元,但最近3天仍消耗了 500 GPU Hours,建议立即介入。" + +### 8. 谁是 DeepSeek 模型最大的使用方? + +**Bad Case (原逻辑):** +无法将“DeepSeek”映射到 `model_or_card` 字段,或者不知道如何定义“最大使用方”。 + +**Expected (Agent 逻辑):** + +1. **Thought**: “DeepSeek” 对应 `resource_usage` 表中的 `model_or_card`。 +2. **Action**: 按 `customer_id` 分组统计 `quantity` 总和,按降序排列,取第一名。 +3. **Result**: "网易(ACC-005)是 DeepSeek 模型的最大使用方,累计消耗 Token 超过 6000 万。" + +## 场景 5: 综合报告生成 (SQL + Python) + +### 9. 生成一份小米的月度消费报告 + +**Bad Case (原逻辑):** +只能返回零散的数据,无法生成结构化报告。 + +**Expected (Agent 逻辑):** + +1. **Thought**: 报告需要包含收入、用量、趋势图和关键指标。这是一个多步任务。 +2. **Action 1**: `run_sql` 查询当月收入和用量总和。 +3. **Action 2**: `run_sql` 查询日用量趋势,并调用 `visualize_data` 生成图表。 +4. **Action 3**: `generate_document` 将上述数据和图表路径整合成 Markdown 格式的报告 `xiaomi_monthly_report.md`,报告中必须包含各步骤的详细分析。 +5. **Result**: "报告已生成:`xiaomi_monthly_report.md`。摘要:小米本月总消费 xxx 元,趋势平稳,无异常波动。" + +### 10. 今天的最新收入数据出来了吗? + +**Bad Case (原逻辑):** +尝试查询数据库,返回空结果,用户不知道是没数据还是没发生交易。 + +**Expected (Agent 逻辑):** + +1. **Thought**: 命中 Instruction 中的 "Missing Data" 策略。 +2. 
def create_b2b_database(db_path=None):
    """Generate the B2B CRM sample SQLite database from scratch.

    Creates four tables (customer, revenue, resource_usage, account_credit)
    and fills them with data crafted for the sample questions: ambiguous
    customer names, a three-month revenue window, per-product revenue, a
    daily usage trend with one injected anomaly, and an account in arrears
    that is still consuming resources.

    Args:
        db_path: Where to write the database. Defaults to ``b2b_crm.sqlite``
            next to this script. An existing file at that path is deleted
            first.

    Returns:
        The path of the generated database file.
    """
    if db_path is None:
        db_path = os.path.join(os.path.dirname(__file__), 'b2b_crm.sqlite')
    # If the db exists, remove it to start fresh.
    if os.path.exists(db_path):
        os.remove(db_path)

    conn = sqlite3.connect(db_path)
    try:
        c = conn.cursor()

        # 1. Customer table: name ambiguity (main customer vs. look-alike
        #    names) and ownership (owner / sales team).
        c.execute('''CREATE TABLE IF NOT EXISTS customer (
            customer_id TEXT PRIMARY KEY,
            name TEXT NOT NULL,          -- 全称
            short_name TEXT,             -- 简称
            is_main_customer BOOLEAN,    -- 是否主客户
            customer_level TEXT,         -- 客户等级 (Strategic, KA, NA)
            owner TEXT,                  -- 负责人
            sales_team TEXT,             -- 销售团队
            industry TEXT,
            status TEXT
        )''')

        # 2. Revenue table: "revenue of the last 3 months", "revenue by product".
        c.execute('''CREATE TABLE IF NOT EXISTS revenue (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            customer_id TEXT,
            year_month TEXT,             -- 计收月份 (YYYY-MM)
            product_category TEXT,       -- 产品分类 (AI, Cloud)
            product_name TEXT,           -- 产品名称 (Model Inference, GPU)
            amount REAL,                 -- 收入金额
            FOREIGN KEY(customer_id) REFERENCES customer(customer_id)
        )''')

        # 3. Usage table: "calls in the last 3 days", "token trend" (daily).
        c.execute('''CREATE TABLE IF NOT EXISTS resource_usage (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            customer_id TEXT,
            usage_date TEXT,             -- 用量日期 (YYYY-MM-DD)
            resource_type TEXT,          -- 资源类型 (Tokens, GPU_Hours)
            model_or_card TEXT,          -- 模型/卡型 (DeepSeek, L20)
            quantity REAL,               -- 用量数值
            FOREIGN KEY(customer_id) REFERENCES customer(customer_id)
        )''')

        # 4. Credit table: "credit limit", "arrears".
        c.execute('''CREATE TABLE IF NOT EXISTS account_credit (
            customer_id TEXT PRIMARY KEY,
            total_credit_limit REAL,     -- 总信控额度
            available_balance REAL,      -- 可用余额
            arrears_amount REAL,         -- 欠费金额
            FOREIGN KEY(customer_id) REFERENCES customer(customer_id)
        )''')

        # --- Data injected for the known bad cases ---

        # Case A: fuzzy name matching and the main-customer flag.
        customers = [
            ('ACC-001', '小米科技有限责任公司', '小米', 1, 'Strategic', 'ZhangSan', 'KA-North', 'Internet', 'Active'),
            ('ACC-002', '宁波小米粒鞋业有限公司', '小米粒', 0, 'SMB', 'LiSi', 'SME-East', 'Retail', 'Active'),
            ('ACC-003', '安宁市小米渣食品店', '小米渣', 0, 'SMB', 'WangWu', 'SME-South', 'Retail', 'Active'),
            ('ACC-004', '深圳市分期乐网络科技有限公司', '分期乐', 1, 'KA', 'ZhaoLiu', 'FinTech-Group', 'Finance', 'Active'),
            ('ACC-005', '网易(杭州)网络有限公司', '网易', 1, 'Strategic', 'SunBa', 'KA-East', 'Internet', 'Active'),
        ]
        c.executemany('INSERT OR REPLACE INTO customer VALUES (?,?,?,?,?,?,?,?,?)', customers)

        # Case B: revenue over a time window. Assuming "now" is 2026-01,
        # generate 2025-10 .. 2025-12.
        revenue_data = []
        for month in ['2025-10', '2025-11', '2025-12']:
            # ACC-001 (main customer): the only one with large revenue.
            revenue_data.append(('ACC-001', month, 'AI', 'Model Inference', 12000000.00))
            # ACC-002 (look-alike name): a tiny amount.
            revenue_data.append(('ACC-002', month, 'Cloud', 'VM', 5.50))

        # Case C: revenue by product (ACC-005, DeepSeek, all of 2025).
        for m in range(1, 13):
            m_str = f"2025-{m:02d}"
            revenue_data.append(('ACC-005', m_str, 'AI', 'DeepSeek', 500000.00))

        c.executemany('INSERT INTO revenue (customer_id, year_month, product_category, product_name, amount) VALUES (?,?,?,?,?)', revenue_data)

        # Case D: usage trend. Dates are relative to today so that "the last
        # 3 days" always has data.
        usage_data = []
        today = datetime.now()
        base_date = today - timedelta(days=30)

        for i in range(31):  # from 30 days ago up to and including today
            d = (base_date + timedelta(days=i)).strftime('%Y-%m-%d')

            # ACC-001 daily token usage (Doubao-Pro): 1,000,000 + i * 10,000,
            # with one anomaly at i == 15 (15 days ago) spiking to 3,000,000.
            quantity = 1000000 + i*10000
            if i == 15:
                quantity = 3000000  # injected anomaly

            usage_data.append(('ACC-001', d, 'Tokens', 'Doubao-Pro', quantity))

            # ACC-005 uses the DeepSeek model.
            usage_data.append(('ACC-005', d, 'Tokens', 'DeepSeek', 2000000 + i*50000))

        # Case E: credit and arrears. ACC-004 owes about 748k ...
        c.execute("INSERT OR REPLACE INTO account_credit VALUES ('ACC-004', 400000, -348787.45, 748787.45)")

        # ... and still has GPU usage in the last 3 days (risk scenario).
        for i in range(3):
            d = (today - timedelta(days=i)).strftime('%Y-%m-%d')
            usage_data.append(('ACC-004', d, 'GPU_Hours', 'L20', 500))

        c.executemany('INSERT INTO resource_usage (customer_id, usage_date, resource_type, model_or_card, quantity) VALUES (?,?,?,?,?)', usage_data)

        conn.commit()
    finally:
        # Always release the connection, even if a statement above failed.
        conn.close()
    print(f"B2B 模拟数据库已生成: {db_path}")
    return db_path


if __name__ == "__main__":
    create_b2b_database()
def _as_text(value: Any) -> str:
    """Return *value* as a stripped string, treating None as empty."""
    return "" if value is None else str(value).strip()


class _VannaMemoryToolBase(BaseTool):
    """Shared plumbing for the Vanna agent-memory tool adapters.

    Subclasses assign ``agent_memory``, ``vanna_tool`` and ``access_groups``
    in their ``__init__`` and then call ``BaseTool.__init__`` with the tool's
    name and description.
    """

    def _get_user_groups(self, tool_context: ToolContext) -> List[str]:
        """Get user groups from context (defaults to ``["user"]``)."""
        return tool_context.state.get("user_groups", ["user"])

    def _check_access(self, user_groups: List[str]) -> bool:
        """Check if any of the user's groups may use this tool."""
        return any(group in self.access_groups for group in user_groups)

    def _create_vanna_context(
        self, tool_context: ToolContext, user_groups: List[str]
    ) -> VannaToolContext:
        """Create Vanna context from Veadk ToolContext."""
        user_email = tool_context.state.get("user_email", "user@example.com")
        vanna_user = User(
            id=tool_context.user_id,
            email=user_email,
            group_memberships=user_groups,
        )
        # The session id is used for both the conversation and the request id.
        return VannaToolContext(
            user=vanna_user,
            conversation_id=tool_context.session.id,
            request_id=tool_context.session.id,
            agent_memory=self.agent_memory,
        )

    async def _execute_vanna_tool(
        self, tool_context: ToolContext, **tool_args: Any
    ) -> str:
        """Run the wrapped Vanna tool after the access check.

        Args:
            tool_context: The calling tool context.
            **tool_args: Field values for the wrapped tool's argument schema.

        Returns:
            The wrapped tool's result text for the LLM, or an access-denied
            message when none of the user's groups is allowed.
        """
        user_groups = self._get_user_groups(tool_context)

        if not self._check_access(user_groups):
            return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}"

        vanna_context = self._create_vanna_context(tool_context, user_groups)
        args_model = self.vanna_tool.get_args_schema()(**tool_args)
        result = await self.vanna_tool.execute(vanna_context, args_model)
        return str(result.result_for_llm)


class SaveQuestionToolArgsTool(_VannaMemoryToolBase):
    """Save successful question-tool-argument combinations for future reference."""

    def __init__(
        self,
        agent_memory,
        access_groups: Optional[List[str]] = None,
    ):
        """
        Initialize the save tool usage tool with custom agent_memory.

        Args:
            agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
            access_groups: List of user groups that can access this tool
                (e.g., ['admin']). Defaults to ['admin'].
        """
        self.agent_memory = agent_memory
        self.vanna_tool = VannaSaveQuestionToolArgsTool()
        self.access_groups = access_groups or ["admin"]  # Default: only admin

        super().__init__(
            name="save_question_tool_args",  # Keep the same name as Vanna
            description="Save a successful question-tool-argument combination for future reference.",
        )

    def _get_declaration(self) -> types.FunctionDeclaration:
        """Describe the tool's parameters to the model."""
        return types.FunctionDeclaration(
            name=self.name,
            description=self.description,
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={
                    "question": types.Schema(
                        type=types.Type.STRING,
                        description="The original question that was asked",
                    ),
                    "tool_name": types.Schema(
                        type=types.Type.STRING,
                        description="The name of the tool that was used successfully",
                    ),
                    "args": types.Schema(
                        type=types.Type.OBJECT,
                        description="The arguments that were passed to the tool",
                    ),
                },
                required=["question", "tool_name", "args"],
            ),
        )

    async def run_async(
        self, *, args: Dict[str, Any], tool_context: ToolContext
    ) -> str:
        """Save a tool usage pattern; errors are returned as text, not raised."""
        # Model-supplied values may be missing, None or non-strings.
        question = _as_text(args.get("question"))
        tool_name = _as_text(args.get("tool_name"))
        tool_args = args.get("args") or {}

        if not question:
            return "Error: No question provided"

        if not tool_name:
            return "Error: No tool name provided"

        try:
            return await self._execute_vanna_tool(
                tool_context, question=question, tool_name=tool_name, args=tool_args
            )
        except Exception as e:
            return f"Error saving tool usage: {str(e)}"


class SearchSavedCorrectToolUsesTool(_VannaMemoryToolBase):
    """Search for similar tool usage patterns based on a question."""

    def __init__(
        self,
        agent_memory,
        access_groups: Optional[List[str]] = None,
    ):
        """
        Initialize the search similar tools tool with custom agent_memory.

        Args:
            agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
            access_groups: List of user groups that can access this tool
                (e.g., ['admin', 'user']). Defaults to ['admin', 'user'].
        """
        self.agent_memory = agent_memory
        self.vanna_tool = VannaSearchSavedCorrectToolUsesTool()
        self.access_groups = access_groups or ["admin", "user"]

        super().__init__(
            name="search_saved_correct_tool_uses",  # Keep the same name as Vanna
            description="Search for similar tool usage patterns based on a question.",
        )

    def _get_declaration(self) -> types.FunctionDeclaration:
        """Describe the tool's parameters to the model."""
        return types.FunctionDeclaration(
            name=self.name,
            description=self.description,
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={
                    "question": types.Schema(
                        type=types.Type.STRING,
                        description="The question to find similar tool usage patterns for",
                    ),
                    "limit": types.Schema(
                        type=types.Type.INTEGER,
                        description="Maximum number of results to return (default: 10)",
                    ),
                },
                required=["question"],
            ),
        )

    async def run_async(
        self, *, args: Dict[str, Any], tool_context: ToolContext
    ) -> str:
        """Search for similar tool usage patterns; errors are returned as text."""
        question = _as_text(args.get("question"))
        # An explicit null from the model means "use the default".
        limit = args.get("limit")
        if limit is None:
            limit = 10

        if not question:
            return "Error: No question provided"

        try:
            return await self._execute_vanna_tool(
                tool_context, question=question, limit=limit
            )
        except Exception as e:
            return f"Error searching similar tools: {str(e)}"


class SaveTextMemoryTool(_VannaMemoryToolBase):
    """Save free-form text memories for important insights, observations, or context."""

    def __init__(
        self,
        agent_memory,
        access_groups: Optional[List[str]] = None,
    ):
        """
        Initialize the save text memory tool with custom agent_memory.

        Args:
            agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
            access_groups: List of user groups that can access this tool
                (e.g., ['admin', 'user']). Defaults to ['admin', 'user'].
        """
        self.agent_memory = agent_memory
        self.vanna_tool = VannaSaveTextMemoryTool()
        self.access_groups = access_groups or ["admin", "user"]

        super().__init__(
            name="save_text_memory",  # Keep the same name as Vanna
            description="Save free-form text memory for important insights, observations, or context.",
        )

    def _get_declaration(self) -> types.FunctionDeclaration:
        """Describe the tool's parameters to the model."""
        return types.FunctionDeclaration(
            name=self.name,
            description=self.description,
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={
                    "content": types.Schema(
                        type=types.Type.STRING,
                        description="The text content to save as a memory",
                    ),
                },
                required=["content"],
            ),
        )

    async def run_async(
        self, *, args: Dict[str, Any], tool_context: ToolContext
    ) -> str:
        """Save a text memory; errors are returned as text, not raised."""
        content = _as_text(args.get("content"))

        if not content:
            return "Error: No content provided"

        try:
            return await self._execute_vanna_tool(tool_context, content=content)
        except Exception as e:
            return f"Error saving text memory: {str(e)}"
# Setup SQLite database
def setup_sqlite():
    """Download the Chinook SQLite database on first use and return its path.

    The download is streamed into a temporary sibling file and only renamed
    onto the final path once it has completed with a successful HTTP status.
    A failed or interrupted download therefore never leaves a truncated file
    behind that later runs would mistake for a valid database.

    Returns:
        The database path. It is returned even when the download failed
        (best effort); the error is printed.
    """
    db_path = "/tmp/Chinook.sqlite"
    if os.path.exists(db_path):
        return db_path

    print("Downloading Chinook.sqlite...")
    url = "https://vanna.ai/Chinook.sqlite"
    tmp_path = f"{db_path}.part"
    try:
        with httpx.stream("GET", url, follow_redirects=True) as response:
            # Do not store an HTML error page as if it were the database.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_bytes():
                    f.write(chunk)
        os.replace(tmp_path, db_path)
        print("Database downloaded successfully!")
    except Exception as e:
        print(f"Error downloading database: {e}")
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return db_path


# Create a session with user groups for access control
async def create_session(user_groups: "list | None" = None):
    """Create an in-memory session whose state carries the user's groups.

    Args:
        user_groups: Groups used by the tools for access control. ``None``
            means the default ``["user"]``; a fresh list is built on every
            call so no default object is shared between sessions.

    Returns:
        A ``(session_service, session)`` tuple.
    """
    groups = ["user"] if user_groups is None else user_groups
    session_service = InMemorySessionService()
    example_session = await session_service.create_session(
        app_name="example_app",
        user_id="example_user",
        state={"user_groups": groups},
    )
    return session_service, example_session
File System - customize working directory as needed +file_system = LocalFileSystem(working_directory="/tmp/data_storage") +if not os.path.exists("/tmp/data_storage"): + os.makedirs("/tmp/data_storage", exist_ok=True) + +# 3. Agent Memory - customize memory implementation and capacity +agent_memory = DemoAgentMemory(max_items=1000) + +# Initialize tools with user-defined components and access control +# Tool names now match Vanna's original names for compatibility +run_sql_tool = RunSqlTool( + sql_runner=sqlite_runner, + file_system=file_system, + agent_memory=agent_memory, + access_groups=["admin", "user"], # Both admin and user can use +) + +visualize_data_tool = VisualizeDataTool( + file_system=file_system, + agent_memory=agent_memory, + access_groups=["admin", "user"], +) + +write_file_tool = WriteFileTool( + file_system=file_system, + agent_memory=agent_memory, + access_groups=["admin", "user"], +) + +# Memory tools: save only for admin, search for all users +save_tool = SaveQuestionToolArgsTool( + agent_memory=agent_memory, + access_groups=["admin"], # Only admin can save +) + +search_tool = SearchSavedCorrectToolUsesTool( + agent_memory=agent_memory, + access_groups=["admin", "user"], # All users can search +) + +summarize_data_tool = SummarizeDataTool( + file_system=file_system, + agent_memory=agent_memory, + access_groups=["admin", "user"], +) + +# Define the Veadk Agent with class-based tools +agent: Agent = Agent( + name="vanna_sql_agent", + description="An intelligent agent that can query databases, visualize data, and generate reports.", + instruction=""" + You are a helpful assistant that can answer questions about data in the Chinook database. + You can execute SQL queries, visualize the results, save/search useful tool usage patterns, and generate documents. 
+ + Here is the schema of the Chinook database: + ```sql + CREATE TABLE [Album] + ( + [AlbumId] INTEGER NOT NULL, + [Title] NVARCHAR(160) NOT NULL, + [ArtistId] INTEGER NOT NULL, + CONSTRAINT [PK_Album] PRIMARY KEY ([AlbumId]), + FOREIGN KEY ([ArtistId]) REFERENCES [Artist] ([ArtistId]) + ON DELETE NO ACTION ON UPDATE NO ACTION + ); + CREATE TABLE [Artist] + ( + [ArtistId] INTEGER NOT NULL, + [Name] NVARCHAR(120), + CONSTRAINT [PK_Artist] PRIMARY KEY ([ArtistId]) + ); + CREATE TABLE [Customer] + ( + [CustomerId] INTEGER NOT NULL, + [FirstName] NVARCHAR(40) NOT NULL, + [LastName] NVARCHAR(20) NOT NULL, + [Company] NVARCHAR(80), + [Address] NVARCHAR(70), + [City] NVARCHAR(40), + [State] NVARCHAR(40), + [Country] NVARCHAR(40), + [PostalCode] NVARCHAR(10), + [Phone] NVARCHAR(24), + [Fax] NVARCHAR(24), + [Email] NVARCHAR(60) NOT NULL, + [SupportRepId] INTEGER, + CONSTRAINT [PK_Customer] PRIMARY KEY ([CustomerId]), + FOREIGN KEY ([SupportRepId]) REFERENCES [Employee] ([EmployeeId]) + ON DELETE NO ACTION ON UPDATE NO ACTION + ); + CREATE TABLE [Employee] + ( + [EmployeeId] INTEGER NOT NULL, + [LastName] NVARCHAR(20) NOT NULL, + [FirstName] NVARCHAR(20) NOT NULL, + [Title] NVARCHAR(30), + [ReportsTo] INTEGER, + [BirthDate] DATETIME, + [HireDate] DATETIME, + [Address] NVARCHAR(70), + [City] NVARCHAR(40), + [State] NVARCHAR(40), + [Country] NVARCHAR(40), + [PostalCode] NVARCHAR(10), + [Phone] NVARCHAR(24), + [Fax] NVARCHAR(24), + [Email] NVARCHAR(60), + CONSTRAINT [PK_Employee] PRIMARY KEY ([EmployeeId]), + FOREIGN KEY ([ReportsTo]) REFERENCES [Employee] ([EmployeeId]) + ON DELETE NO ACTION ON UPDATE NO ACTION + ); + CREATE TABLE [Genre] + ( + [GenreId] INTEGER NOT NULL, + [Name] NVARCHAR(120), + CONSTRAINT [PK_Genre] PRIMARY KEY ([GenreId]) + ); + CREATE TABLE [Invoice] + ( + [InvoiceId] INTEGER NOT NULL, + [CustomerId] INTEGER NOT NULL, + [InvoiceDate] DATETIME NOT NULL, + [BillingAddress] NVARCHAR(70), + [BillingCity] NVARCHAR(40), + [BillingState] NVARCHAR(40), + 
[BillingCountry] NVARCHAR(40), + [BillingPostalCode] NVARCHAR(10), + [Total] NUMERIC(10,2) NOT NULL, + CONSTRAINT [PK_Invoice] PRIMARY KEY ([InvoiceId]), + FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) + ON DELETE NO ACTION ON UPDATE NO ACTION + ); + CREATE TABLE [InvoiceLine] + ( + [InvoiceLineId] INTEGER NOT NULL, + [InvoiceId] INTEGER NOT NULL, + [TrackId] INTEGER NOT NULL, + [UnitPrice] NUMERIC(10,2) NOT NULL, + [Quantity] INTEGER NOT NULL, + CONSTRAINT [PK_InvoiceLine] PRIMARY KEY ([InvoiceLineId]), + FOREIGN KEY ([InvoiceId]) REFERENCES [Invoice] ([InvoiceId]) + ON DELETE NO ACTION ON UPDATE NO ACTION, + FOREIGN KEY ([TrackId]) REFERENCES [Track] ([TrackId]) + ON DELETE NO ACTION ON UPDATE NO ACTION + ); + CREATE TABLE [MediaType] + ( + [MediaTypeId] INTEGER NOT NULL, + [Name] NVARCHAR(120), + CONSTRAINT [PK_MediaType] PRIMARY KEY ([MediaTypeId]) + ); + CREATE TABLE [Playlist] + ( + [PlaylistId] INTEGER NOT NULL, + [Name] NVARCHAR(120), + CONSTRAINT [PK_Playlist] PRIMARY KEY ([PlaylistId]) + ); + CREATE TABLE [PlaylistTrack] + ( + [PlaylistId] INTEGER NOT NULL, + [TrackId] INTEGER NOT NULL, + CONSTRAINT [PK_PlaylistTrack] PRIMARY KEY ([PlaylistId], [TrackId]), + FOREIGN KEY ([PlaylistId]) REFERENCES [Playlist] ([PlaylistId]) + ON DELETE NO ACTION ON UPDATE NO ACTION, + FOREIGN KEY ([TrackId]) REFERENCES [Track] ([TrackId]) + ON DELETE NO ACTION ON UPDATE NO ACTION + ); + CREATE TABLE [Track] + ( + [TrackId] INTEGER NOT NULL, + [Name] NVARCHAR(200) NOT NULL, + [AlbumId] INTEGER, + [MediaTypeId] INTEGER NOT NULL, + [GenreId] INTEGER, + [Composer] NVARCHAR(220), + [Milliseconds] INTEGER NOT NULL, + [Bytes] INTEGER, + [UnitPrice] NUMERIC(10,2) NOT NULL, + CONSTRAINT [PK_Track] PRIMARY KEY ([TrackId]), + FOREIGN KEY ([AlbumId]) REFERENCES [Album] ([AlbumId]) + ON DELETE NO ACTION ON UPDATE NO ACTION, + FOREIGN KEY ([GenreId]) REFERENCES [Genre] ([GenreId]) + ON DELETE NO ACTION ON UPDATE NO ACTION, + FOREIGN KEY ([MediaTypeId]) REFERENCES 
async def main(prompt: str, user_groups: list = None) -> str:
    """Run a single prompt through the agent in a fresh session.

    Args:
        prompt: The user message to send to the agent.
        user_groups: Groups stored in session state for tool access control.
            Defaults to ``["user"]`` when not specified.

    Returns:
        The runner's response for the prompt.
    """
    # Resolve the documented default here rather than forwarding None, so the
    # session state always holds a real list of groups for the access checks.
    if user_groups is None:
        user_groups = ["user"]

    session_service, example_session = await create_session(user_groups)

    runner = Runner(
        agent=agent,
        app_name=example_session.app_name,
        user_id=example_session.user_id,
        session_service=session_service,
    )

    response = await runner.run(
        messages=prompt,
        session_id=example_session.id,
    )

    return response
class WriteFileTool(BaseTool):
    """Write content to a file.

    ADK wrapper around Vanna's write-file tool that adds argument validation
    and group-based access control before delegating to the wrapped tool.
    """

    def __init__(
        self,
        file_system,
        agent_memory,
        access_groups: Optional[List[str]] = None,
    ):
        """
        Initialize the write file tool with custom file_system.

        Args:
            file_system: A Vanna file system instance (e.g., LocalFileSystem)
            agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
            access_groups: List of user groups that can access this tool.
                ``None`` or an empty list falls back to ['admin', 'user'].
        """
        self.file_system = file_system
        self.agent_memory = agent_memory
        self.vanna_tool = VannaWriteFileTool(file_system=file_system)
        self.access_groups = access_groups or ["admin", "user"]

        super().__init__(
            name="write_file",
            description="Write content to a file.",
        )

    def _get_declaration(self) -> types.FunctionDeclaration:
        """Describe the ``filename``, ``content`` and ``overwrite`` arguments."""
        return types.FunctionDeclaration(
            name=self.name,
            description=self.description,
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={
                    "filename": types.Schema(
                        type=types.Type.STRING,
                        description="Name of the file to write",
                    ),
                    "content": types.Schema(
                        type=types.Type.STRING,
                        description="Content to write to the file",
                    ),
                    "overwrite": types.Schema(
                        type=types.Type.BOOLEAN,
                        description="Whether to overwrite existing files (default: False)",
                    ),
                },
                required=["filename", "content"],
            ),
        )

    def _get_user_groups(self, tool_context: ToolContext) -> List[str]:
        """Return the caller's groups from session state, defaulting to ``['user']``.

        A stored ``None`` is treated like a missing key so the access check
        never has to iterate over ``None``.
        """
        user_groups = tool_context.state.get("user_groups")
        if user_groups is None:
            return ["user"]
        return user_groups

    def _check_access(self, user_groups: List[str]) -> bool:
        """Return True when the user shares at least one group with the tool."""
        return any(group in self.access_groups for group in user_groups)

    def _create_vanna_context(
        self, tool_context: ToolContext, user_groups: List[str]
    ) -> VannaToolContext:
        """Build the Vanna tool context (user, ids, memory) from the ADK context."""
        user_id = tool_context.user_id
        user_email = tool_context.state.get("user_email", "user@example.com")

        vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups)
        return VannaToolContext(
            user=vanna_user,
            conversation_id=tool_context.session.id,
            request_id=tool_context.session.id,
            agent_memory=self.agent_memory,
        )

    async def run_async(
        self, *, args: Dict[str, Any], tool_context: ToolContext
    ) -> str:
        """Write ``content`` to ``filename`` through the wrapped Vanna tool.

        Returns:
            The wrapped tool's ``result_for_llm`` as a string, or a string
            starting with ``"Error"`` on invalid input, denied access, or
            failure of the wrapped tool.
        """
        # A null or non-string filename is reported as missing instead of
        # raising from .strip() outside the try block.
        raw_filename = args.get("filename")
        filename = raw_filename.strip() if isinstance(raw_filename, str) else ""
        content = args.get("content", "")
        # An explicit null means "not specified", i.e. the documented False.
        overwrite = args.get("overwrite") or False

        if not filename:
            return "Error: No filename provided"

        try:
            user_groups = self._get_user_groups(tool_context)
            if not self._check_access(user_groups):
                return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}"

            vanna_context = self._create_vanna_context(tool_context, user_groups)
            args_model = self.vanna_tool.get_args_schema()(
                filename=filename, content=content, overwrite=overwrite
            )
            result = await self.vanna_tool.execute(vanna_context, args_model)
            return str(result.result_for_llm)
        except Exception as e:
            return f"Error writing file: {str(e)}"
class ListFilesTool(BaseTool):
    """List files in a directory.

    ADK wrapper around Vanna's list-files tool that adds group-based access
    control before delegating to the wrapped tool.
    """

    def __init__(
        self,
        file_system,
        agent_memory,
        access_groups: Optional[List[str]] = None,
    ):
        """
        Initialize the list files tool with custom file_system.

        Args:
            file_system: A Vanna file system instance (e.g., LocalFileSystem)
            agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
            access_groups: List of user groups that can access this tool.
                ``None`` or an empty list falls back to ['admin', 'user'].
        """
        self.file_system = file_system
        self.agent_memory = agent_memory
        self.vanna_tool = VannaListFilesTool(file_system=file_system)
        self.access_groups = access_groups or ["admin", "user"]

        super().__init__(
            name="list_files",
            description="List files in a directory.",
        )

    def _get_declaration(self) -> types.FunctionDeclaration:
        """Describe the single optional ``directory`` argument."""
        return types.FunctionDeclaration(
            name=self.name,
            description=self.description,
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={
                    "directory": types.Schema(
                        type=types.Type.STRING,
                        description="Directory to list (defaults to current directory)",
                    ),
                },
                required=[],
            ),
        )

    def _get_user_groups(self, tool_context: ToolContext) -> List[str]:
        """Return the caller's groups from session state, defaulting to ``['user']``.

        A stored ``None`` is treated like a missing key so the access check
        never has to iterate over ``None``.
        """
        user_groups = tool_context.state.get("user_groups")
        if user_groups is None:
            return ["user"]
        return user_groups

    def _check_access(self, user_groups: List[str]) -> bool:
        """Return True when the user shares at least one group with the tool."""
        return any(group in self.access_groups for group in user_groups)

    def _create_vanna_context(
        self, tool_context: ToolContext, user_groups: List[str]
    ) -> VannaToolContext:
        """Build the Vanna tool context (user, ids, memory) from the ADK context."""
        user_id = tool_context.user_id
        user_email = tool_context.state.get("user_email", "user@example.com")

        vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups)
        return VannaToolContext(
            user=vanna_user,
            conversation_id=tool_context.session.id,
            request_id=tool_context.session.id,
            agent_memory=self.agent_memory,
        )

    async def run_async(
        self, *, args: Dict[str, Any], tool_context: ToolContext
    ) -> str:
        """List the files of ``directory`` through the wrapped Vanna tool.

        Returns:
            The wrapped tool's ``result_for_llm`` as a string, or a string
            starting with ``"Error"`` on denied access or failure of the
            wrapped tool.
        """
        # Missing, null and empty values all mean the documented default "."
        directory = args.get("directory") or "."

        try:
            user_groups = self._get_user_groups(tool_context)
            if not self._check_access(user_groups):
                return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}"

            vanna_context = self._create_vanna_context(tool_context, user_groups)
            args_model = self.vanna_tool.get_args_schema()(directory=directory)
            result = await self.vanna_tool.execute(vanna_context, args_model)
            return str(result.result_for_llm)
        except Exception as e:
            return f"Error listing files: {str(e)}"
"query": types.Schema( + type=types.Type.STRING, + description="Text to search for in file names or contents", + ), + "include_content": types.Schema( + type=types.Type.BOOLEAN, + description="Whether to search within file contents (default: True)", + ), + "max_results": types.Schema( + type=types.Type.INTEGER, + description="Maximum number of matches to return (default: 20)", + ), + }, + required=["query"], + ), + ) + + def _get_user_groups(self, tool_context: ToolContext) -> List[str]: + user_groups = tool_context.state.get("user_groups", ["user"]) + return user_groups + + def _check_access(self, user_groups: List[str]) -> bool: + return any(group in self.access_groups for group in user_groups) + + def _create_vanna_context( + self, tool_context: ToolContext, user_groups: List[str] + ) -> VannaToolContext: + user_id = tool_context.user_id + user_email = tool_context.state.get("user_email", "user@example.com") + + vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups) + return VannaToolContext( + user=vanna_user, + conversation_id=tool_context.session.id, + request_id=tool_context.session.id, + agent_memory=self.agent_memory, + ) + + async def run_async( + self, *, args: Dict[str, Any], tool_context: ToolContext + ) -> str: + query = args.get("query", "").strip() + include_content = args.get("include_content", True) + max_results = args.get("max_results", 20) + + if not query: + return "Error: No search query provided" + + try: + user_groups = self._get_user_groups(tool_context) + if not self._check_access(user_groups): + return f"Error: Access denied. 
class EditFileTool(BaseTool):
    """Modify specific lines within a file.

    ADK wrapper around Vanna's edit-file tool that adds argument validation
    and group-based access control before delegating to the wrapped tool.
    """

    def __init__(
        self,
        file_system,
        agent_memory,
        access_groups: Optional[List[str]] = None,
    ):
        """
        Initialize the edit file tool with custom file_system.

        Args:
            file_system: A Vanna file system instance (e.g., LocalFileSystem)
            agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
            access_groups: List of user groups that can access this tool.
                ``None`` or an empty list falls back to ['admin', 'user'].
        """
        self.file_system = file_system
        self.agent_memory = agent_memory
        self.vanna_tool = VannaEditFileTool(file_system=file_system)
        self.access_groups = access_groups or ["admin", "user"]

        super().__init__(
            name="edit_file",
            description="Modify specific lines within a file.",
        )

    def _get_declaration(self) -> types.FunctionDeclaration:
        """Describe ``filename`` and the ``edits`` array of line-range replacements."""
        edit_schema = types.Schema(
            type=types.Type.OBJECT,
            properties={
                "start_line": types.Schema(
                    type=types.Type.INTEGER,
                    description="First line (1-based) affected by this edit",
                ),
                "end_line": types.Schema(
                    type=types.Type.INTEGER,
                    description="Last line (1-based, inclusive) to replace",
                ),
                "new_content": types.Schema(
                    type=types.Type.STRING,
                    description="Replacement text",
                ),
            },
        )
        return types.FunctionDeclaration(
            name=self.name,
            description=self.description,
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={
                    "filename": types.Schema(
                        type=types.Type.STRING,
                        description="Path to the file to edit",
                    ),
                    "edits": types.Schema(
                        type=types.Type.ARRAY,
                        description="List of edits to apply",
                        items=edit_schema,
                    ),
                },
                required=["filename", "edits"],
            ),
        )

    def _get_user_groups(self, tool_context: ToolContext) -> List[str]:
        """Return the caller's groups from session state, defaulting to ``['user']``.

        A stored ``None`` is treated like a missing key so the access check
        never has to iterate over ``None``.
        """
        user_groups = tool_context.state.get("user_groups")
        if user_groups is None:
            return ["user"]
        return user_groups

    def _check_access(self, user_groups: List[str]) -> bool:
        """Return True when the user shares at least one group with the tool."""
        return any(group in self.access_groups for group in user_groups)

    def _create_vanna_context(
        self, tool_context: ToolContext, user_groups: List[str]
    ) -> VannaToolContext:
        """Build the Vanna tool context (user, ids, memory) from the ADK context."""
        user_id = tool_context.user_id
        user_email = tool_context.state.get("user_email", "user@example.com")

        vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups)
        return VannaToolContext(
            user=vanna_user,
            conversation_id=tool_context.session.id,
            request_id=tool_context.session.id,
            agent_memory=self.agent_memory,
        )

    async def run_async(
        self, *, args: Dict[str, Any], tool_context: ToolContext
    ) -> str:
        """Apply ``edits`` to ``filename`` through the wrapped Vanna tool.

        Returns:
            The wrapped tool's ``result_for_llm`` as a string, or a string
            starting with ``"Error"`` on invalid input, denied access, or
            failure of the wrapped tool.
        """
        # A null or non-string filename is reported as missing instead of
        # raising from .strip() outside the try block.
        raw_filename = args.get("filename")
        filename = raw_filename.strip() if isinstance(raw_filename, str) else ""
        edits = args.get("edits", [])

        if not filename:
            return "Error: No filename provided"

        if not edits:
            return "Error: No edits provided"

        try:
            user_groups = self._get_user_groups(tool_context)
            if not self._check_access(user_groups):
                return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}"

            vanna_context = self._create_vanna_context(tool_context, user_groups)
            args_model = self.vanna_tool.get_args_schema()(
                filename=filename, edits=edits
            )
            result = await self.vanna_tool.execute(vanna_context, args_model)
            return str(result.result_for_llm)
        except Exception as e:
            return f"Error editing file: {str(e)}"
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, List +from google.adk.tools import BaseTool, ToolContext +from google.genai import types +from vanna.tools.python import ( + RunPythonFileTool as VannaRunPythonFileTool, + PipInstallTool as VannaPipInstallTool, +) +from vanna.core.user import User +from vanna.core.tool import ToolContext as VannaToolContext + + +class RunPythonFileTool(BaseTool): + """Execute a Python file using the workspace interpreter.""" + + def __init__( + self, + file_system, + agent_memory, + access_groups: Optional[List[str]] = None, + ): + """ + Initialize the run Python file tool with custom file_system. 
+ + Args: + file_system: A Vanna file system instance (e.g., LocalFileSystem) + agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory) + access_groups: List of user groups that can access this tool + """ + self.file_system = file_system + self.agent_memory = agent_memory + self.vanna_tool = VannaRunPythonFileTool(file_system=file_system) + self.access_groups = access_groups or ["admin", "user"] + + super().__init__( + name="run_python_file", + description="Execute a Python file using the workspace interpreter.", + ) + + def _get_declaration(self) -> types.FunctionDeclaration: + return types.FunctionDeclaration( + name=self.name, + description=self.description, + parameters=types.Schema( + type=types.Type.OBJECT, + properties={ + "filename": types.Schema( + type=types.Type.STRING, + description="Python file to execute (relative to the workspace root)", + ), + "arguments": types.Schema( + type=types.Type.ARRAY, + description="Optional arguments to pass to the Python script", + items=types.Schema(type=types.Type.STRING), + ), + "timeout_seconds": types.Schema( + type=types.Type.NUMBER, + description="Optional timeout for the command in seconds", + ), + }, + required=["filename"], + ), + ) + + def _get_user_groups(self, tool_context: ToolContext) -> List[str]: + user_groups = tool_context.state.get("user_groups", ["user"]) + return user_groups + + def _check_access(self, user_groups: List[str]) -> bool: + return any(group in self.access_groups for group in user_groups) + + def _create_vanna_context( + self, tool_context: ToolContext, user_groups: List[str] + ) -> VannaToolContext: + user_id = tool_context.user_id + user_email = tool_context.state.get("user_email", "user@example.com") + + vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups) + return VannaToolContext( + user=vanna_user, + conversation_id=tool_context.session.id, + request_id=tool_context.session.id, + agent_memory=self.agent_memory, + ) + + async def run_async( + 
class PipInstallTool(BaseTool):
    """Install Python packages using pip.

    ADK wrapper around Vanna's pip-install tool that adds argument
    normalisation and group-based access control (admin-only by default)
    before delegating to the wrapped tool.
    """

    def __init__(
        self,
        file_system,
        agent_memory,
        access_groups: Optional[List[str]] = None,
    ):
        """
        Initialize the pip install tool with custom file_system.

        Args:
            file_system: A Vanna file system instance (e.g., LocalFileSystem)
            agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
            access_groups: List of user groups that can access this tool (default: admin only)
        """
        self.file_system = file_system
        self.agent_memory = agent_memory
        self.vanna_tool = VannaPipInstallTool(file_system=file_system)
        self.access_groups = access_groups or [
            "admin"
        ]  # Default: only admin can install packages

        super().__init__(
            name="pip_install",
            description="Install Python packages using pip.",
        )

    def _get_declaration(self) -> types.FunctionDeclaration:
        """Describe ``packages`` plus the optional upgrade/extra-args/timeout knobs."""
        return types.FunctionDeclaration(
            name=self.name,
            description=self.description,
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={
                    "packages": types.Schema(
                        type=types.Type.ARRAY,
                        description="Packages (with optional specifiers) to install",
                        items=types.Schema(type=types.Type.STRING),
                    ),
                    "upgrade": types.Schema(
                        type=types.Type.BOOLEAN,
                        description="Whether to include --upgrade in the pip invocation (default: False)",
                    ),
                    "extra_args": types.Schema(
                        type=types.Type.ARRAY,
                        description="Additional arguments to pass to pip install",
                        items=types.Schema(type=types.Type.STRING),
                    ),
                    "timeout_seconds": types.Schema(
                        type=types.Type.NUMBER,
                        description="Optional timeout for the command in seconds",
                    ),
                },
                required=["packages"],
            ),
        )

    def _get_user_groups(self, tool_context: ToolContext) -> List[str]:
        """Return the caller's groups from session state, defaulting to ``['user']``.

        A stored ``None`` is treated like a missing key so the access check
        never has to iterate over ``None``.
        """
        user_groups = tool_context.state.get("user_groups")
        if user_groups is None:
            return ["user"]
        return user_groups

    def _check_access(self, user_groups: List[str]) -> bool:
        """Return True when the user shares at least one group with the tool."""
        return any(group in self.access_groups for group in user_groups)

    def _create_vanna_context(
        self, tool_context: ToolContext, user_groups: List[str]
    ) -> VannaToolContext:
        """Build the Vanna tool context (user, ids, memory) from the ADK context."""
        user_id = tool_context.user_id
        user_email = tool_context.state.get("user_email", "user@example.com")

        vanna_user = User(id=user_id, email=user_email, group_memberships=user_groups)
        return VannaToolContext(
            user=vanna_user,
            conversation_id=tool_context.session.id,
            request_id=tool_context.session.id,
            agent_memory=self.agent_memory,
        )

    async def run_async(
        self, *, args: Dict[str, Any], tool_context: ToolContext
    ) -> str:
        """Install ``packages`` through the wrapped Vanna tool.

        Returns:
            The wrapped tool's ``result_for_llm`` as a string, or a string
            starting with ``"Error"`` on invalid input, denied access, or
            failure of the wrapped tool.
        """
        # Normalise what the model may send: null for omitted optionals and a
        # bare string where a one-element list was meant.
        packages = args.get("packages") or []
        if isinstance(packages, str):
            packages = [packages]
        upgrade = args.get("upgrade") or False
        # NOTE(review): extra_args is model-supplied and handed to the wrapped
        # tool unchanged, presumably ending up on the pip command line --
        # confirm whether options such as --index-url should be restricted.
        extra_args = args.get("extra_args") or []
        timeout_seconds = args.get("timeout_seconds")

        if not packages:
            return "Error: No packages provided"

        try:
            user_groups = self._get_user_groups(tool_context)
            if not self._check_access(user_groups):
                return f"Error: Access denied. This tool requires one of the following groups: {', '.join(self.access_groups)}"

            vanna_context = self._create_vanna_context(tool_context, user_groups)
            args_model = self.vanna_tool.get_args_schema()(
                packages=packages,
                upgrade=upgrade,
                extra_args=extra_args,
                timeout_seconds=timeout_seconds,
            )
            result = await self.vanna_tool.execute(vanna_context, args_model)
            return str(result.result_for_llm)
        except Exception as e:
            return f"Error installing packages: {str(e)}"
class RunSqlTool(BaseTool):
    """Execute SQL queries against a database.

    Wraps ``VannaRunSqlTool`` and gates every call behind a
    group-membership check before delegating to it.
    """

    def __init__(
        self,
        sql_runner,
        file_system,
        agent_memory,
        access_groups: Optional[List[str]] = None,
    ):
        """
        Initialize the SQL tool with custom sql_runner and file_system.

        Args:
            sql_runner: A Vanna SQL runner instance (e.g., SqliteRunner, PostgresRunner)
            file_system: A Vanna file system instance (e.g., LocalFileSystem)
            agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
            access_groups: List of user groups that can access this tool
                (e.g., ['admin', 'user']). A falsy value selects the default.
        """
        self.sql_runner = sql_runner
        self.file_system = file_system
        self.agent_memory = agent_memory
        self.vanna_tool = VannaRunSqlTool(
            sql_runner=sql_runner, file_system=file_system
        )
        # Default: all groups
        self.access_groups = access_groups or ["admin", "user"]

        super().__init__(
            name="run_sql",  # Keep the same name as Vanna
            description="Execute a SQL query against the database and return results as a CSV file.",
        )

    def _get_declaration(self) -> types.FunctionDeclaration:
        """Return the function declaration describing this tool's arguments."""
        sql_schema = types.Schema(
            type=types.Type.STRING,
            description="The SQL query to execute",
        )
        return types.FunctionDeclaration(
            name=self.name,
            description=self.description,
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={"sql": sql_schema},
                required=["sql"],
            ),
        )

    def _get_user_groups(self, tool_context: ToolContext) -> List[str]:
        """Return the ``user_groups`` entry of the context state (default ``["user"]``)."""
        return tool_context.state.get("user_groups", ["user"])

    def _check_access(self, user_groups: List[str]) -> bool:
        """Return True when at least one of ``user_groups`` is an allowed group."""
        return any(group in self.access_groups for group in user_groups)

    def _create_vanna_context(
        self, tool_context: ToolContext, user_groups: List[str]
    ) -> VannaToolContext:
        """Create Vanna context from Veadk ToolContext.

        Args:
            tool_context: Context of the current tool invocation.
            user_groups: Groups to record as the user's memberships.

        Returns:
            A ``VannaToolContext`` bound to this tool's agent memory.
        """
        vanna_user = User(
            id=tool_context.user_id,
            email=tool_context.state.get("user_email", "user@example.com"),
            group_memberships=user_groups,
        )
        # The session id is used as both the conversation id and the request id.
        session_id = tool_context.session.id
        return VannaToolContext(
            user=vanna_user,
            conversation_id=session_id,
            request_id=session_id,
            agent_memory=self.agent_memory,
        )

    async def run_async(
        self, *, args: Dict[str, Any], tool_context: ToolContext
    ) -> str:
        """Execute the SQL query.

        Args:
            args: Tool arguments; ``sql`` holds the query text.
            tool_context: Context of the current tool invocation.

        Returns:
            The wrapped tool's ``result_for_llm`` as a string, or a message
            starting with ``Error`` when the query is missing, access is
            denied or execution fails. Exceptions are reported, not propagated.
        """
        # A null or non-string "sql" must yield the error message below rather
        # than an AttributeError from .strip() escaping this method.
        raw_sql = args.get("sql")
        sql = raw_sql.strip() if isinstance(raw_sql, str) else ""
        if not sql:
            return "Error: No SQL query provided"

        try:
            # Get user groups and check access
            user_groups = self._get_user_groups(tool_context)
            if not self._check_access(user_groups):
                return (
                    "Error: Access denied. This tool requires one of the "
                    f"following groups: {', '.join(self.access_groups)}"
                )

            # Create Vanna context once per request
            vanna_context = self._create_vanna_context(tool_context, user_groups)

            # Execute using Vanna tool
            args_model = self.vanna_tool.get_args_schema()(sql=sql)
            result = await self.vanna_tool.execute(vanna_context, args_model)
            return str(result.result_for_llm)
        except Exception as e:
            # Tool boundary: report the failure to the model as text.
            return f"Error executing SQL query: {e}"
class SummarizeDataTool(BaseTool):
    """Generate statistical summaries of CSV data files.

    Reads a CSV through the configured file system, parses it with pandas and
    returns a Markdown summary, after a group-membership check.
    """

    def __init__(
        self,
        file_system,
        agent_memory,
        access_groups: Optional[List[str]] = None,
    ):
        """
        Initialize the summarize data tool with custom file_system.

        Args:
            file_system: A Vanna file system instance (e.g., LocalFileSystem)
            agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
            access_groups: List of user groups that can access this tool
                (e.g., ['admin', 'user']). A falsy value selects the default.
        """
        self.agent_memory = agent_memory
        self.file_system = file_system
        self.access_groups = access_groups or ["admin", "user"]

        super().__init__(
            name="summarize_data",
            description="Generate a statistical summary of data from a CSV file.",
        )

    def _get_declaration(self) -> types.FunctionDeclaration:
        """Return the function declaration describing this tool's arguments."""
        filename_schema = types.Schema(
            type=types.Type.STRING,
            description="The name of the CSV file to summarize",
        )
        return types.FunctionDeclaration(
            name=self.name,
            description=self.description,
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties={"filename": filename_schema},
                required=["filename"],
            ),
        )

    def _get_user_groups(self, tool_context: ToolContext) -> List[str]:
        """Return the ``user_groups`` entry of the context state (default ``["user"]``)."""
        return tool_context.state.get("user_groups", ["user"])

    def _check_access(self, user_groups: List[str]) -> bool:
        """Return True when at least one of ``user_groups`` is an allowed group."""
        return any(group in self.access_groups for group in user_groups)

    def _create_vanna_context(
        self, tool_context: ToolContext, user_groups: List[str]
    ) -> VannaToolContext:
        """Create Vanna context from Veadk ToolContext.

        Args:
            tool_context: Context of the current tool invocation.
            user_groups: Groups to record as the user's memberships.

        Returns:
            A ``VannaToolContext`` bound to this tool's agent memory.
        """
        vanna_user = User(
            id=tool_context.user_id,
            email=tool_context.state.get("user_email", "user@example.com"),
            group_memberships=user_groups,
        )
        # The session id is used as both the conversation id and the request id.
        session_id = tool_context.session.id
        return VannaToolContext(
            user=vanna_user,
            conversation_id=session_id,
            request_id=session_id,
            agent_memory=self.agent_memory,
        )

    async def run_async(
        self, *, args: Dict[str, Any], tool_context: ToolContext
    ) -> str:
        """Generate a statistical summary of CSV data.

        Args:
            args: Tool arguments; ``filename`` names the CSV file to read.
            tool_context: Context of the current tool invocation.

        Returns:
            A Markdown report with row/column counts, the first five rows and
            ``DataFrame.describe()`` output, or an error message when the
            filename is missing, access is denied or summarizing fails.
            Exceptions are reported, not propagated.
        """
        # A null or non-string "filename" must yield the error message below
        # rather than an AttributeError from .strip() escaping this method.
        raw_filename = args.get("filename")
        filename = raw_filename.strip() if isinstance(raw_filename, str) else ""
        if not filename:
            return "Error: No filename provided"

        try:
            user_groups = self._get_user_groups(tool_context)
            if not self._check_access(user_groups):
                return (
                    "Error: Access denied. This tool requires one of the "
                    f"following groups: {', '.join(self.access_groups)}"
                )

            vanna_context = self._create_vanna_context(tool_context, user_groups)

            # Read the file content
            content = await self.file_system.read_file(filename, vanna_context)

            # Parse into DataFrame
            df = pd.read_csv(io.StringIO(content))

            # Generate summary stats
            description = df.describe().to_markdown()
            head = df.head().to_markdown()
            # Stringify labels so join() cannot fail on a non-string label.
            column_names = ", ".join(map(str, df.columns))
            info = (
                f"Rows: {len(df)}, Columns: {len(df.columns)}\n"
                f"Column Names: {column_names}"
            )

            return (
                f"**Data Summary for {filename}**\n\n"
                f"**Info:**\n{info}\n\n"
                f"**First 5 Rows:**\n{head}\n\n"
                f"**Statistical Description:**\n{description}"
            )
        except Exception as e:
            # Tool boundary: report the failure to the model as text.
            return f"Failed to summarize data: {e}"
class VisualizeDataTool(BaseTool):
    """Visualize data from CSV files.

    Wraps ``VannaVisualizeDataTool`` and gates every call behind a
    group-membership check before delegating to it.
    """

    def __init__(
        self,
        file_system,
        agent_memory,
        access_groups: Optional[List[str]] = None,
    ):
        """
        Initialize the visualization tool with custom file_system.

        Args:
            file_system: A Vanna file system instance (e.g., LocalFileSystem)
            agent_memory: A Vanna agent memory instance (e.g., DemoAgentMemory)
            access_groups: List of user groups that can access this tool
                (e.g., ['admin', 'user']). A falsy value selects the default.
        """
        self.file_system = file_system
        self.agent_memory = agent_memory
        self.vanna_tool = VannaVisualizeDataTool(file_system=file_system)
        self.access_groups = access_groups or ["admin", "user"]

        super().__init__(
            name="visualize_data",  # Keep the same name as Vanna
            description="Create visualizations from CSV data files.",
        )

    def _get_declaration(self) -> types.FunctionDeclaration:
        """Return the function declaration describing this tool's arguments."""
        properties = {
            "filename": types.Schema(
                type=types.Type.STRING,
                description="The name of the CSV file to visualize",
            ),
            "title": types.Schema(
                type=types.Type.STRING,
                description="Optional title for the chart",
            ),
        }
        return types.FunctionDeclaration(
            name=self.name,
            description=self.description,
            parameters=types.Schema(
                type=types.Type.OBJECT,
                properties=properties,
                required=["filename"],
            ),
        )

    def _get_user_groups(self, tool_context: ToolContext) -> List[str]:
        """Return the ``user_groups`` entry of the context state (default ``["user"]``)."""
        return tool_context.state.get("user_groups", ["user"])

    def _check_access(self, user_groups: List[str]) -> bool:
        """Return True when at least one of ``user_groups`` is an allowed group."""
        return any(group in self.access_groups for group in user_groups)

    def _create_vanna_context(
        self, tool_context: ToolContext, user_groups: List[str]
    ) -> VannaToolContext:
        """Create Vanna context from Veadk ToolContext.

        Args:
            tool_context: Context of the current tool invocation.
            user_groups: Groups to record as the user's memberships.

        Returns:
            A ``VannaToolContext`` bound to this tool's agent memory.
        """
        vanna_user = User(
            id=tool_context.user_id,
            email=tool_context.state.get("user_email", "user@example.com"),
            group_memberships=user_groups,
        )
        # The session id is used as both the conversation id and the request id.
        session_id = tool_context.session.id
        return VannaToolContext(
            user=vanna_user,
            conversation_id=session_id,
            request_id=session_id,
            agent_memory=self.agent_memory,
        )

    async def run_async(
        self, *, args: Dict[str, Any], tool_context: ToolContext
    ) -> str:
        """Create a visualization from CSV data.

        Args:
            args: Tool arguments; ``filename`` names the CSV file and
                ``title`` optionally titles the chart.
            tool_context: Context of the current tool invocation.

        Returns:
            The wrapped tool's ``result_for_llm`` as a string, or a message
            starting with ``Error`` when the filename is missing, access is
            denied or visualization fails. Exceptions are reported, not
            propagated.
        """
        # A null or non-string "filename" must yield the error message below
        # rather than an AttributeError from .strip() escaping this method.
        raw_filename = args.get("filename")
        filename = raw_filename.strip() if isinstance(raw_filename, str) else ""
        title = args.get("title")
        if not filename:
            return "Error: No filename provided"

        try:
            user_groups = self._get_user_groups(tool_context)
            if not self._check_access(user_groups):
                return (
                    "Error: Access denied. This tool requires one of the "
                    f"following groups: {', '.join(self.access_groups)}"
                )

            vanna_context = self._create_vanna_context(tool_context, user_groups)

            args_model = self.vanna_tool.get_args_schema()(
                filename=filename, title=title
            )
            result = await self.vanna_tool.execute(vanna_context, args_model)
            return str(result.result_for_llm)
        except Exception as e:
            # Tool boundary: report the failure to the model as text.
            return f"Error visualizing data: {e}"