From 149af4f9948288bcb6e0cff791b0cc58fa4bf6f4 Mon Sep 17 00:00:00 2001 From: yaojin Date: Wed, 6 May 2026 16:12:07 +0800 Subject: [PATCH 01/12] update --- .github/workflows/release.yml | 409 ++++++++++++++++++++++++++++++++++ 1 file changed, 409 insertions(+) create mode 100644 .github/workflows/release.yml diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 000000000..d97e1e10c --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,409 @@ +name: Release + +# Uses GitHub Models API with GITHUB_TOKEN (auto-provided) for AI release notes. +# Model: gpt-4o-mini (no manual secrets required) + +on: + workflow_dispatch: + inputs: + release_type: + description: Release type to cut + required: true + default: auto + type: choice + options: + - auto + - patch + - minor + - major + prerelease: + description: Mark the GitHub Release as a prerelease + required: true + default: false + type: boolean + use_ai_notes: + description: Use OpenAI to draft release notes when OPENAI_API_KEY is configured + required: true + default: true + type: boolean + +permissions: + contents: write + +concurrency: + group: release-${{ github.ref_name }} + cancel-in-progress: false + +jobs: + release: + name: Cut release + runs-on: ubuntu-latest + + steps: + - name: Checkout source + uses: actions/checkout@v4 + with: + ref: ${{ github.ref_name }} + fetch-depth: 0 + persist-credentials: true + + - name: Configure git + run: | + git config user.name "github-actions[bot]" + git config user.email "41898282+github-actions[bot]@users.noreply.github.com" + + - name: Resolve base tag and target version + id: version + shell: bash + env: + REQUESTED_RELEASE_TYPE: ${{ inputs.release_type }} + run: | + set -euo pipefail + + git fetch --force --tags + + stable_tag="$(git tag --merged HEAD --list 'v*' --sort=-version:refname | grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' | head -n 1 || true)" + base_tag="$stable_tag" + if [ -z "$base_tag" ]; then + base_tag="$(git tag --merged HEAD --list 'v*' --sort=-version:refname | grep -E '^v[0-9]+\.[0-9]+\.[0-9]+([-.][0-9A-Za-z.]+)?$' | head -n 1 || true)" + fi + if [ -z "$base_tag" ]; then + base_tag="v0.0.0" + log_range="" + else + log_range="${base_tag}..HEAD" + fi + + if [ -n "$log_range" ] && [ -z "$(git log --oneline "$log_range")" ]; then + echo "No commits found since ${base_tag}; skipping duplicate release." + exit 1 + fi + + release_type="$REQUESTED_RELEASE_TYPE" + if [ "$release_type" = "auto" ]; then + if [ -n "$log_range" ]; then + subjects="$(git log --format=%s "$log_range")" + bodies="$(git log --format=%B "$log_range")" + else + subjects="$(git log --format=%s)" + bodies="$(git log --format=%B)" + fi + if printf '%s\n' "$bodies" | grep -Eq 'BREAKING CHANGE|^[^[:space:]]+(\([^)]+\))?!:'; then + release_type="major" + elif printf '%s\n' "$subjects" | grep -Eq '^feat(\([^)]+\))?:'; then + release_type="minor" + else + release_type="patch" + fi + fi + + next_version="$(python - "$base_tag" "$release_type" <<'PY' + import re + import sys + + base_tag, release_type = sys.argv[1], sys.argv[2] + match = re.match(r"^v?(\d+)\.(\d+)\.(\d+)", base_tag) + if not match: + major, minor, patch = 0, 0, 0 + else: + major, minor, patch = map(int, match.groups()) + + if release_type == "major": + major += 1 + minor = 0 + patch = 0 + elif release_type == "minor": + minor += 1 + patch = 0 + else: + patch += 1 + + print(f"{major}.{minor}.{patch}") + PY + )" + + tag_name="v${next_version}" + + if git rev-parse "$tag_name" >/dev/null 2>&1; then + echo "Tag ${tag_name} already exists." + exit 1 + fi + + { + echo "base_tag=$base_tag" + echo "release_type=$release_type" + echo "version=$next_version" + echo "tag=$tag_name" + echo "log_range=$log_range" + } >> "$GITHUB_OUTPUT" + + echo "Base tag: $base_tag" + echo "Release type: $release_type" + echo "Next version: $next_version" + + - name: Collect release context + shell: bash + env: + BASE_TAG: ${{ steps.version.outputs.base_tag }} + TARGET_VERSION: ${{ steps.version.outputs.version }} + TARGET_TAG: ${{ steps.version.outputs.tag }} + RELEASE_TYPE: ${{ steps.version.outputs.release_type }} + LOG_RANGE: ${{ steps.version.outputs.log_range }} + SOURCE_REF: ${{ github.ref_name }} + run: | + set -euo pipefail + + mkdir -p .github/release-artifacts + + if [ -n "$LOG_RANGE" ]; then + git log --no-merges --pretty=format:'- %s (%h)' "$LOG_RANGE" > .github/release-artifacts/commit-bullets.txt + git log --no-merges --pretty=format:'%H%x09%s' "$LOG_RANGE" > .github/release-artifacts/commit-table.tsv + else + git log --no-merges --pretty=format:'- %s (%h)' > .github/release-artifacts/commit-bullets.txt + git log --no-merges --pretty=format:'%H%x09%s' > .github/release-artifacts/commit-table.tsv + fi + + python <<'PY' + from pathlib import Path + import os + + base_tag = os.environ["BASE_TAG"] + target_version = os.environ["TARGET_VERSION"] + target_tag = os.environ["TARGET_TAG"] + release_type = os.environ["RELEASE_TYPE"] + + notes_path = Path("RELEASE_NOTES.md") + existing = notes_path.read_text(encoding="utf-8") if notes_path.exists() else "" + style_excerpt = "\n".join(existing.splitlines()[:120]).strip() + commit_bullets = Path(".github/release-artifacts/commit-bullets.txt").read_text(encoding="utf-8").strip() + + prompt = f"""You are writing Clawith release notes in markdown. + + Rules: + - Start with a top-level heading exactly like: # {target_tag} — + - Keep the tone concise and product-focused. + - Use these sections when they make sense: ## What's New, ## Bug Fixes, ## Upgrade Guide, ## Notes + - Mention upgrade or migration risk only when the commits strongly imply it. + - Do not invent features that are not present in the commit list. + - Prefer grouping related changes instead of listing every commit verbatim. + + Context: + - Previous release tag: {base_tag} + - Target release tag: {target_tag} + - Release type: {release_type} + + Release branch: + {os.environ["SOURCE_REF"]} + + Recent release note style: + {style_excerpt or "(no prior release note style provided)"} + + Commits included in this release: + {commit_bullets or "- No commit bullets collected"} + """ + + Path(".github/release-artifacts/release-prompt.txt").write_text(prompt, encoding="utf-8") + PY + + - name: Update version files + shell: bash + env: + TARGET_VERSION: ${{ steps.version.outputs.version }} + run: | + set -euo pipefail + + printf '%s\n' "$TARGET_VERSION" > backend/VERSION + printf '%s\n' "$TARGET_VERSION" > frontend/VERSION + + - name: Draft release notes with GitHub Models + if: ${{ inputs.use_ai_notes }} + shell: bash + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + run: | + set -euo pipefail + + python <<'PY' + from pathlib import Path + import json + import os + + prompt = Path(".github/release-artifacts/release-prompt.txt").read_text(encoding="utf-8") + payload = { + "model": "gpt-4o-mini", + "messages": [ + {"role": "user", "content": prompt} + ], + } + Path(".github/release-artifacts/openai-payload.json").write_text( + json.dumps(payload, ensure_ascii=False), + encoding="utf-8", + ) + PY + + curl -fsSL https://models.inference.ai.azure.com/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $GITHUB_TOKEN" \ + -d @.github/release-artifacts/openai-payload.json \ + > .github/release-artifacts/openai-response.json + + python <<'PY' + from pathlib import Path + import json + + response = json.loads(Path(".github/release-artifacts/openai-response.json").read_text(encoding="utf-8")) + text = response.get("choices", [{}])[0].get("message", {}).get("content", "").strip() + + if not text: + raise SystemExit("GitHub Models did not return release note text.") + + Path(".github/release-artifacts/release-notes.generated.md").write_text( + text.rstrip() + "\n", + encoding="utf-8", + ) + PY + + - name: Build fallback release notes + shell: bash + env: + BASE_TAG: ${{ steps.version.outputs.base_tag }} + TARGET_VERSION: ${{ steps.version.outputs.version }} + TARGET_TAG: ${{ steps.version.outputs.tag }} + SOURCE_REF: ${{ github.ref_name }} + run: | + set -euo pipefail + + if [ -s .github/release-artifacts/release-notes.generated.md ]; then + exit 0 + fi + + python <<'PY' + from pathlib import Path + import os + + base_tag = os.environ["BASE_TAG"] + target_version = os.environ["TARGET_VERSION"] + target_tag = os.environ["TARGET_TAG"] + source_ref = os.environ["SOURCE_REF"] + + entries = [] + for line in Path(".github/release-artifacts/commit-table.tsv").read_text(encoding="utf-8").splitlines(): + if not line.strip(): + continue + _, subject = line.split("\t", 1) + entries.append(subject.strip()) + + features = [] + fixes = [] + others = [] + for subject in entries: + lowered = subject.lower() + if lowered.startswith("feat"): + features.append(subject) + elif lowered.startswith("fix"): + fixes.append(subject) + else: + others.append(subject) + + def bullets(items): + return "\n".join(f"- {item}" for item in items[:8]) or "- No user-facing highlights captured from commit subjects." + + sections = [ + f"# {target_tag} — Release Highlights", + "", + "## What's New", + bullets(features or others), + ] + + if fixes: + sections.extend([ + "", + "## Bug Fixes", + bullets(fixes), + ]) + + sections.extend([ + "", + "## Upgrade Guide", + "", + "### Docker Deployment", + "```bash", + f"git pull origin {source_ref}", + "docker compose down && docker compose up -d --build", + "```", + "", + "### Source Deployment", + "```bash", + f"git pull origin {source_ref}", + "cd frontend && npm install && npm run build", + "cd ..", + "```", + "", + "## Notes", + f"- Release generated from changes since `{base_tag}`.", + f"- Runtime version files were updated to `{target_version}`.", + ]) + + Path(".github/release-artifacts/release-notes.generated.md").write_text( + "\n".join(sections).rstrip() + "\n", + encoding="utf-8", + ) + PY + + - name: Refresh RELEASE_NOTES.md + shell: bash + run: | + set -euo pipefail + + python <<'PY' + from pathlib import Path + + notes_file = Path(".github/release-artifacts/release-notes.generated.md") + release_notes_path = Path("RELEASE_NOTES.md") + + new_block = notes_file.read_text(encoding="utf-8").strip() + existing = release_notes_path.read_text(encoding="utf-8").strip() if release_notes_path.exists() else "" + + if existing: + combined = f"{new_block}\n\n---\n\n{existing}\n" + else: + combined = f"{new_block}\n" + + release_notes_path.write_text(combined, encoding="utf-8") + PY + + - name: Commit release metadata + shell: bash + env: + TARGET_TAG: ${{ steps.version.outputs.tag }} + run: | + set -euo pipefail + + git add backend/VERSION frontend/VERSION RELEASE_NOTES.md + + if git diff --cached --quiet; then + echo "No release metadata changes to commit." + exit 0 + fi + + git commit -m "chore(release): cut ${TARGET_TAG}" + git push origin HEAD:${{ github.ref_name }} + + - name: Create and push tag + shell: bash + env: + TARGET_TAG: ${{ steps.version.outputs.tag }} + run: | + set -euo pipefail + + git tag -a "$TARGET_TAG" -m "Release $TARGET_TAG" + git push origin "$TARGET_TAG" + + - name: Publish GitHub Release + uses: softprops/action-gh-release@v2 + with: + tag_name: ${{ steps.version.outputs.tag }} + name: ${{ steps.version.outputs.tag }} + body_path: .github/release-artifacts/release-notes.generated.md + prerelease: ${{ inputs.prerelease }} + generate_release_notes: false From 0c8c9324ada47a7c9a2490c1c59b06ac73a5d27d Mon Sep 17 00:00:00 2001 From: yaojin Date: Wed, 6 May 2026 16:22:03 +0800 Subject: [PATCH 02/12] update --- .github/workflows/release.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d97e1e10c..5e6d172b7 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,7 +1,7 @@ name: Release -# Uses GitHub Models API with GITHUB_TOKEN (auto-provided) for AI release notes. -# Model: gpt-4o-mini (no manual secrets required) +# Uses GitHub Models API for AI release notes. +# Requires: Repository secret MODELS_TOKEN (PAT with models:read scope) on: workflow_dispatch: @@ -22,7 +22,7 @@ on: default: false type: boolean use_ai_notes: - description: Use OpenAI to draft release notes when OPENAI_API_KEY is configured + description: Use GitHub Models to draft release notes when MODELS_TOKEN is configured required: true default: true type: boolean @@ -219,7 +219,7 @@ jobs: if: ${{ inputs.use_ai_notes }} shell: bash env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + MODELS_TOKEN: ${{ secrets.MODELS_TOKEN }} run: | set -euo pipefail @@ -243,7 +243,7 @@ jobs: curl -fsSL https://models.inference.ai.azure.com/chat/completions \ -H "Content-Type: application/json" \ - -H "Authorization: Bearer $GITHUB_TOKEN" \ + -H "Authorization: Bearer $MODELS_TOKEN" \ -d @.github/release-artifacts/openai-payload.json \ > .github/release-artifacts/openai-response.json From 72ee1dcca02fee82ccb11a49b0012b7ae0ed7a68 Mon Sep 17 00:00:00 2001 From: yaojin Date: Wed, 6 May 2026 16:24:42 +0800 Subject: [PATCH 03/12] update --- .github/workflows/release.yml | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 5e6d172b7..d7534f2fa 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -177,6 +177,12 @@ jobs: style_excerpt = "\n".join(existing.splitlines()[:120]).strip() commit_bullets = Path(".github/release-artifacts/commit-bullets.txt").read_text(encoding="utf-8").strip() + # Limit commit bullets to fit within API token limits + lines = commit_bullets.splitlines() + if len(lines) > 100: + lines = lines[:100] + [f"- ... and {len(lines) - 100} more commits"] + commit_summary = "\n".join(lines) + prompt = f"""You are writing Clawith release notes in markdown. Rules: @@ -191,15 +197,10 @@ jobs: - Previous release tag: {base_tag} - Target release tag: {target_tag} - Release type: {release_type} - - Release branch: - {os.environ["SOURCE_REF"]} - - Recent release note style: - {style_excerpt or "(no prior release note style provided)"} + - Release branch: {os.environ["SOURCE_REF"]} Commits included in this release: - {commit_bullets or "- No commit bullets collected"} + {commit_summary or "- No commit bullets collected"} """ Path(".github/release-artifacts/release-prompt.txt").write_text(prompt, encoding="utf-8") From 79d325ee18f384f6c76afa10347fc43b1cbb7b19 Mon Sep 17 00:00:00 2001 From: yaojin3616 Date: Wed, 6 May 2026 16:49:47 +0800 Subject: [PATCH 04/12] refactor: add stateless HA runtime foundations --- backend/app/api/feishu.py | 103 +- backend/app/api/files.py | 260 ++--- backend/app/api/pages.py | 18 +- backend/app/api/relationships.py | 23 +- backend/app/api/slack.py | 19 +- backend/app/api/tenants.py | 40 +- backend/app/api/upload.py | 32 +- backend/app/api/webhooks.py | 53 +- backend/app/api/websocket.py | 104 +- backend/app/config.py | 24 + backend/app/main.py | 275 ++--- backend/app/models/trigger_execution.py | 41 + backend/app/services/agent_context.py | 62 +- backend/app/services/agent_manager.py | 91 +- backend/app/services/agent_seeder.py | 244 ++--- backend/app/services/agent_tools.py | 2 +- backend/app/services/collaboration.py | 18 +- backend/app/services/dingtalk_stream.py | 60 +- backend/app/services/enterprise_sync.py | 24 +- backend/app/services/heartbeat.py | 14 +- backend/app/services/llm/caller.py | 6 +- backend/app/services/okr_scheduler.py | 37 +- backend/app/services/realtime.py | 19 + .../app/services/realtime_runtime/__init__.py | 15 + .../app/services/realtime_runtime/router.py | 200 ++++ backend/app/services/storage.py | 45 + .../app/services/storage_runtime/__init__.py | 42 + .../services/storage_runtime/agent_files.py | 84 ++ backend/app/services/storage_runtime/base.py | 57 ++ .../app/services/storage_runtime/facade.py | 52 + backend/app/services/storage_runtime/local.py | 101 ++ backend/app/services/storage_runtime/s3.py | 204 ++++ backend/app/services/storage_runtime/utils.py | 24 + backend/app/services/supervision_reminder.py | 2 - backend/app/services/trigger_daemon.py | 964 +----------------- .../app/services/trigger_runtime/__init__.py | 30 + .../app/services/trigger_runtime/dispatch.py | 64 ++ .../app/services/trigger_runtime/evaluator.py | 429 ++++++++ .../services/trigger_runtime/executions.py | 131 +++ .../app/services/trigger_runtime/invoker.py | 368 +++++++ backend/app/services/trigger_runtime/keys.py | 47 + backend/app/services/trigger_runtime/queue.py | 73 ++ .../app/services/workspace_collaboration.py | 45 +- backend/entrypoint.sh | 70 +- deploy/nginx/role-all.conf | 30 + docker-compose.role-all.yml | 141 +++ docker-compose.yml | 95 ++ 47 files changed, 3175 insertions(+), 1707 deletions(-) create mode 100644 backend/app/models/trigger_execution.py create mode 100644 backend/app/services/realtime.py create mode 100644 backend/app/services/realtime_runtime/__init__.py create mode 100644 backend/app/services/realtime_runtime/router.py create mode 100644 backend/app/services/storage.py create mode 100644 backend/app/services/storage_runtime/__init__.py create mode 100644 backend/app/services/storage_runtime/agent_files.py create mode 100644 backend/app/services/storage_runtime/base.py create mode 100644 backend/app/services/storage_runtime/facade.py create mode 100644 backend/app/services/storage_runtime/local.py create mode 100644 backend/app/services/storage_runtime/s3.py create mode 100644 backend/app/services/storage_runtime/utils.py create mode 100644 backend/app/services/trigger_runtime/__init__.py create mode 100644 backend/app/services/trigger_runtime/dispatch.py create mode 100644 backend/app/services/trigger_runtime/evaluator.py create mode 100644 backend/app/services/trigger_runtime/executions.py create mode 100644 backend/app/services/trigger_runtime/invoker.py create mode 100644 backend/app/services/trigger_runtime/keys.py create mode 100644 backend/app/services/trigger_runtime/queue.py create mode 100644 deploy/nginx/role-all.conf create mode 100644 docker-compose.role-all.yml diff --git a/backend/app/api/feishu.py b/backend/app/api/feishu.py index 00fe32737..c279c6d63 100644 --- a/backend/app/api/feishu.py +++ b/backend/app/api/feishu.py @@ -4,6 +4,7 @@ import time import uuid from collections.abc import Awaitable, Callable +from datetime import datetime from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from loguru import logger @@ -18,6 +19,7 @@ from app.models.identity import IdentityProvider from app.schemas.schemas import ChannelConfigCreate, ChannelConfigOut, TokenResponse, UserOut from app.services.feishu_service import feishu_service +from app.services.storage import agent_upload_key, get_storage_backend, store_agent_upload router = APIRouter(tags=["feishu"]) @@ -33,6 +35,21 @@ "抱歉,我暂时无法稳定识别你的飞书账号,已停止本次处理以避免重复创建账号。" "请稍后重试,或联系管理员检查飞书 Contact API 权限。" ) + + +def _storage_mtime(entry) -> float: + raw = str(getattr(entry, "modified_at", "") or "") + if not raw: + return 0.0 + try: + return float(raw) + except ValueError: + try: + return datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp() + except ValueError: + return 0.0 + + def _build_card( answer_text: str, thinking_text: str = "", @@ -566,20 +583,18 @@ async def process_feishu_event(agent_id: uuid.UUID, body: dict, db: AsyncSession if _post_image_keys: import base64 as _b64 _msg_id = message.get("message_id", "") - from pathlib import Path as _PostPath - from app.config import get_settings as _post_gs - _post_settings = _post_gs() - _upload_dir = _PostPath(_post_settings.AGENT_DATA_DIR) / str(agent_id) / "workspace" / "uploads" - _upload_dir.mkdir(parents=True, exist_ok=True) for _ik in _post_image_keys: try: _img_bytes = await feishu_service.download_message_resource( config.app_id, config.app_secret, _msg_id, _ik, "image" ) - # Save to workspace - _save_path = _upload_dir / f"image_{_ik[-8:]}.jpg" - _save_path.write_bytes(_img_bytes) - logger.info(f"[Feishu] Saved post image to {_save_path} ({len(_img_bytes)} bytes)") + _, _workspace_path, _save_path = await store_agent_upload( + agent_id, + f"image_{_ik[-8:]}.jpg", + _img_bytes, + content_type="image/jpeg", + ) + logger.info(f"[Feishu] Saved post image to {_workspace_path} ({len(_img_bytes)} bytes)") # Embed as base64 marker for vision models _b64_data = _b64.b64encode(_img_bytes).decode("ascii") _image_markers.append(f"[image_data:data:image/jpeg;base64,{_b64_data}]") @@ -817,21 +832,22 @@ async def process_feishu_event(agent_id: uuid.UUID, body: dict, db: AsyncSession # to disk always succeeds even if the DB transaction fails. try: import time as _time - import pathlib as _pl - from app.config import get_settings as _gs - _upload_dir = _pl.Path(_gs().AGENT_DATA_DIR) / str(agent_id) / "workspace" / "uploads" + _storage = get_storage_backend() + _upload_key = agent_upload_key(agent_id, "placeholder").rsplit("/", 1)[0] _recent_file_path = None - if _upload_dir.exists() and "uploads/" not in user_text and "workspace/" not in user_text: + if "uploads/" not in user_text and "workspace/" not in user_text: _now = _time.time() - _candidates = sorted( - _upload_dir.iterdir(), - key=lambda p: p.stat().st_mtime, - reverse=True, - ) - for _fp in _candidates: - if _fp.is_file() and (_now - _fp.stat().st_mtime) < 1800: # 30 min - _recent_file_path = f"uploads/{_fp.name}" - break + if await _storage.exists(_upload_key) and await _storage.is_dir(_upload_key): + _candidates = sorted( + [e for e in await _storage.list_dir(_upload_key) if not e.is_dir], + key=_storage_mtime, + reverse=True, + ) + for _entry in _candidates: + _mtime = _storage_mtime(_entry) + if _mtime and (_now - _mtime) < 1800: + _recent_file_path = f"uploads/{_entry.name}" + break if _recent_file_path: # _recent_file_path is relative to uploads dir; agent workspace root is # AGENT_DATA_DIR/{agent_id}/, so the correct relative path is workspace/uploads/ @@ -869,11 +885,16 @@ async def _feishu_file_sender(file_path, msg: str = ""): _fs = _gs_fallback() _base_url = getattr(_fs, 'BASE_URL', '').rstrip('/') or '' _fp = _P(file_path) - _ws_root = _P(_fs.AGENT_DATA_DIR) + _parts = list(_fp.parts) try: - _rel = str(_fp.relative_to(_ws_root / str(agent_id))) + _workspace_idx = _parts.index("workspace") + _rel = "/".join(_parts[_workspace_idx:]) except ValueError: - _rel = _fp.name + _ws_root = _P(getattr(_fs, "STORAGE_LOCAL_ROOT", "") or _fs.AGENT_DATA_DIR) + try: + _rel = str(_fp.relative_to(_ws_root / str(agent_id))) + except ValueError: + _rel = _fp.name _fallback_parts = [] if msg: _fallback_parts.append(msg) @@ -1195,8 +1216,6 @@ async def _handle_feishu_file( ): """Handle incoming file or image messages from Feishu (runs as a background task).""" import asyncio, random, json - from pathlib import Path - from app.config import get_settings from app.models.audit import ChatMessage from app.models.agent import Agent as AgentModel from app.models.user import User as UserModel @@ -1226,18 +1245,18 @@ async def _handle_feishu_file( return # Resolve workspace upload dir - settings = get_settings() - upload_dir = Path(settings.AGENT_DATA_DIR) / str(agent_id) / "workspace" / "uploads" - upload_dir.mkdir(parents=True, exist_ok=True) - save_path = upload_dir / filename - # Download the file try: file_bytes = await feishu_service.download_message_resource( config.app_id, config.app_secret, message_id, file_key, res_type ) - save_path.write_bytes(file_bytes) - logger.info(f"[Feishu] Saved {msg_type} to {save_path} ({len(file_bytes)} bytes)") + _, workspace_path, save_path = await store_agent_upload( + agent_id, + filename, + file_bytes, + content_type="image/jpeg" if msg_type == "image" else None, + ) + logger.info(f"[Feishu] Saved {msg_type} to {workspace_path} ({len(file_bytes)} bytes)") except Exception as e: logger.error(f"[Feishu] Failed to download {msg_type}: {e}") err_tip = "抱歉,文件下载失败。可能原因:机器人缺少 `im:resource` 权限(文件读取)。\n请在飞书开放平台 → 权限管理 → 批量导入权限 JSON → 重新发布机器人版本后重试。" @@ -1555,20 +1574,18 @@ async def _img_heartbeat(): async def _download_post_images(agent_id, config, message_id, image_keys): """Download images embedded in a Feishu post message to the agent's workspace.""" - from pathlib import Path - from app.config import get_settings - settings = get_settings() - upload_dir = Path(settings.AGENT_DATA_DIR) / str(agent_id) / "workspace" / "uploads" - upload_dir.mkdir(parents=True, exist_ok=True) - for ik in image_keys: try: file_bytes = await feishu_service.download_message_resource( config.app_id, config.app_secret, message_id, ik, "image" ) - save_path = upload_dir / f"image_{ik[-8:]}.jpg" - save_path.write_bytes(file_bytes) - logger.info(f"[Feishu] Saved post image to {save_path} ({len(file_bytes)} bytes)") + _, workspace_path, _ = await store_agent_upload( + agent_id, + f"image_{ik[-8:]}.jpg", + file_bytes, + content_type="image/jpeg", + ) + logger.info(f"[Feishu] Saved post image to {workspace_path} ({len(file_bytes)} bytes)") except Exception as e: logger.error(f"[Feishu] Failed to download post image {ik}: {e}") diff --git a/backend/app/api/files.py b/backend/app/api/files.py index aca0725a3..aace960ac 100644 --- a/backend/app/api/files.py +++ b/backend/app/api/files.py @@ -29,6 +29,12 @@ release_edit_lock, write_workspace_file, ) +from app.services.storage import ( + ensure_local_path, + get_storage_backend, + guess_content_type, + normalize_storage_key, +) from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -66,7 +72,14 @@ class RestoreRevisionBody(BaseModel): def _agent_base_dir(agent_id: uuid.UUID) -> Path: - return Path(settings.AGENT_DATA_DIR) / str(agent_id) + local_root = settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR + return Path(local_root) / str(agent_id) + + +def _agent_storage_key(agent_id: uuid.UUID, rel_path: str = "") -> str: + prefix = str(agent_id) + rel = normalize_storage_key(rel_path) + return f"{prefix}/{rel}" if rel else prefix def _safe_path(agent_id: uuid.UUID, rel_path: str) -> Path: @@ -87,27 +100,23 @@ async def list_files( ): """List files and directories in an agent's file system.""" await check_agent_access(db, current_user, agent_id) - target = _safe_path(agent_id, path) - - if not target.exists(): + storage = get_storage_backend() + storage_key = _agent_storage_key(agent_id, path) + if not await storage.exists(storage_key): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Path not found") - if not target.is_dir(): + if not await storage.is_dir(storage_key): raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Path is not a directory") items = [] - base_abs = _agent_base_dir(agent_id).resolve() - for entry in sorted(target.iterdir(), key=lambda e: (not e.is_dir(), e.name)): - if entry.name == '.gitkeep': - continue - rel = str(entry.resolve().relative_to(base_abs)) - stat = entry.stat() + for entry in await storage.list_dir(storage_key): + rel_path = str(Path(entry.key).relative_to(str(agent_id))) items.append(FileInfo( name=entry.name, - path=rel, - is_dir=entry.is_dir(), - size=stat.st_size if entry.is_file() else 0, - modified_at=str(stat.st_mtime), - url=f"/api/agents/{agent_id}/files/download?path={rel}" if not entry.is_dir() else None + path=rel_path, + is_dir=entry.is_dir, + size=entry.size, + modified_at=entry.modified_at, + url=f"/api/agents/{agent_id}/files/download?path={rel_path}" if not entry.is_dir else None )) return items @@ -121,17 +130,17 @@ async def read_file( ): """Read the content of a file.""" await check_agent_access(db, current_user, agent_id) - target = _safe_path(agent_id, path) - - if not target.exists() or not target.is_file(): + storage = get_storage_backend() + key = _agent_storage_key(agent_id, path) + if not await storage.exists(key) or not await storage.is_file(key): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") try: - async with aiofiles.open(target, "r", encoding="utf-8") as f: - content = await f.read() + content = await storage.read_text(key, encoding="utf-8", errors="replace") return FileContent(path=path, content=content) except UnicodeDecodeError: - return FileContent(path=path, content=f"[二进制文件: {target.name}, {target.stat().st_size} bytes]") + stat = await storage.stat(key) + return FileContent(path=path, content=f"[二进制文件: {Path(path).name}, {stat.size} bytes]") def _file_kind(path: str) -> str: @@ -237,16 +246,18 @@ async def preview_file( ): """Return a browser-friendly preview payload for Workspace files.""" await check_agent_access(db, current_user, agent_id) - target = _safe_path(agent_id, path) - if not target.exists() or not target.is_file(): + storage = get_storage_backend() + key = _agent_storage_key(agent_id, path) + if not await storage.exists(key) or not await storage.is_file(key): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") kind = _file_kind(path) - mime_type = mimetypes.guess_type(target.name)[0] or "application/octet-stream" + mime_type = mimetypes.guess_type(Path(path).name)[0] or "application/octet-stream" download_url = f"/api/agents/{agent_id}/files/download?path={path}" + local_target: Path | None = None if kind in {"markdown", "html", "text"}: - content = await read_text_if_exists(target) + content = await storage.read_text(key, encoding="utf-8", errors="replace") return { "path": path, "kind": kind, @@ -256,7 +267,7 @@ async def preview_file( "download_url": download_url, } if kind == "csv": - content = await read_text_if_exists(target) or "" + content = await storage.read_text(key, encoding="utf-8", errors="replace") rows = _parse_csv_rows(content) return { "path": path, @@ -277,6 +288,8 @@ async def preview_file( } if kind == "xlsx": try: + target = await ensure_local_path(key) + local_target = target from openpyxl import load_workbook wb = load_workbook(target, read_only=True, data_only=True) @@ -311,6 +324,8 @@ async def preview_file( "download_url": download_url, } if kind in {"docx", "pptx"}: + target = await ensure_local_path(key) + local_target = target extracted_text = _extract_document_text(target, kind) companion = _find_companion_text_preview(target) companion_content = await read_text_if_exists(companion) if companion is not None else None @@ -323,7 +338,10 @@ async def preview_file( "download_url": download_url, } - companion = _find_companion_text_preview(target) + if local_target is not None: + companion = _find_companion_text_preview(local_target) + else: + companion = None if companion is not None: content = await read_text_if_exists(companion) return { @@ -336,13 +354,13 @@ async def preview_file( "download_url": download_url, } - raw = target.read_bytes() + raw = await storage.read_bytes(key) encoded = base64.b64encode(raw[:1024 * 1024]).decode("ascii") return { "path": path, "kind": kind, "mime_type": mime_type, - "size": target.stat().st_size, + "size": len(raw), "base64_sample": encoded, "download_url": download_url, } @@ -384,13 +402,29 @@ async def download_file( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive") await check_agent_access(db, user, agent_id) - target = _safe_path(agent_id, path) - if not target.exists() or not target.is_file(): + storage = get_storage_backend() + key = _agent_storage_key(agent_id, path) + if not await storage.exists(key) or not await storage.is_file(key): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") - return FileResponse( - path=str(target), - filename=target.name, - content_disposition_type="inline" if inline else "attachment", + presigned = await storage.presign_download_url(key, filename=Path(path).name, inline=inline) + if presigned: + return Response( + status_code=302, + headers={"Location": presigned}, + ) + local_path = await storage.local_path_for(key) + if local_path is not None: + return FileResponse( + path=str(local_path), + filename=Path(path).name, + content_disposition_type="inline" if inline else "attachment", + ) + data = await storage.read_bytes(key) + disposition = "inline" if inline else "attachment" + return Response( + content=data, + media_type=guess_content_type(Path(path).name), + headers={"Content-Disposition": f'{disposition}; filename="{Path(path).name}"'}, ) @@ -505,11 +539,10 @@ async def restore_file_revision( if revision.after_content is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot restore an empty/deleted revision") - target = _safe_path(agent_id, revision.path) - before = await read_text_if_exists(target) - target.parent.mkdir(parents=True, exist_ok=True) - async with aiofiles.open(target, "w", encoding="utf-8") as f: - await f.write(revision.after_content) + storage = get_storage_backend() + storage_key = _agent_storage_key(agent_id, revision.path) + before = await storage.read_text(storage_key, encoding="utf-8", errors="replace") if await storage.exists(storage_key) else None + await storage.write_text(storage_key, revision.after_content, encoding="utf-8") restored = await record_revision( db, agent_id=agent_id, @@ -533,16 +566,14 @@ async def delete_file( ): """Delete a file.""" await check_agent_access(db, current_user, agent_id) - target = _safe_path(agent_id, path) - - if not target.exists(): + storage = get_storage_backend() + key = _agent_storage_key(agent_id, path) + if not await storage.exists(key): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") - - if target.is_dir(): - import shutil - shutil.rmtree(target) + if await storage.is_dir(key): + await storage.delete_tree(key) else: - target.unlink() + await storage.delete(key) return {"status": "ok", "path": path} @@ -579,19 +610,11 @@ async def import_skill_to_agent( if not skill.files: raise HTTPException(status_code=400, detail="Skill has no files") - # Write each file into the agent's workspace - base = _agent_base_dir(agent_id) - skill_dir = base / "skills" / skill.folder_name - skill_dir.mkdir(parents=True, exist_ok=True) - + storage = get_storage_backend() written = [] for f in skill.files: - file_path = (skill_dir / f.path).resolve() - # Safety check - if not str(file_path).startswith(str(base.resolve())): - continue - file_path.parent.mkdir(parents=True, exist_ok=True) - file_path.write_text(f.content, encoding="utf-8") + skill_key = _agent_storage_key(agent_id, f"skills/{skill.folder_name}/{f.path}") + await storage.write_text(skill_key, f.content, encoding="utf-8") written.append(f.path) return { @@ -630,28 +653,25 @@ async def upload_file_to_workspace( if normalized_path not in {"workspace", "skills"} and not normalized_path.startswith(("workspace/", "skills/")): raise HTTPException(status_code=400, detail="右侧根目录视图是 agent 根目录;上传文件时请放到 workspace/ 或 skills/ 目录下") - base = _agent_base_dir(agent_id) - target_dir = (base / normalized_path).resolve() - if not str(target_dir).startswith(str(base.resolve())): - raise HTTPException(status_code=403, detail="Path traversal not allowed") - - target_dir.mkdir(parents=True, exist_ok=True) filename = file.filename or "unnamed" # Sanitize filename filename = filename.replace("/", "_").replace("\\", "_") - save_path = target_dir / filename + storage = get_storage_backend() + file_key = _agent_storage_key(agent_id, f"{normalized_path}/{filename}") content = await file.read() - save_path.write_bytes(content) + await storage.write_bytes(file_key, content, content_type=guess_content_type(filename)) # Auto-extract text from non-text files extracted_path = None from app.services.text_extractor import needs_extraction, save_extracted_text if needs_extraction(filename): + save_path = await ensure_local_path(file_key) txt_file = save_extracted_text(save_path, content, filename) if txt_file: - base_abs = base.resolve() - extracted_path = str(txt_file.resolve().relative_to(base_abs)) + extracted_path = f"{normalized_path}/{txt_file.name}" + extracted_key = _agent_storage_key(agent_id, extracted_path) + await storage.write_bytes(extracted_key, txt_file.read_bytes(), content_type="text/plain; charset=utf-8") return { "status": "ok", @@ -669,11 +689,19 @@ async def upload_file_to_workspace( def _enterprise_kb_dir(tenant_id: str) -> Path: - return Path(settings.AGENT_DATA_DIR) / f"enterprise_info_{tenant_id}" / "knowledge_base" + local_root = settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR + return Path(local_root) / f"enterprise_info_{tenant_id}" / "knowledge_base" def _enterprise_info_dir(tenant_id: str) -> Path: - return Path(settings.AGENT_DATA_DIR) / f"enterprise_info_{tenant_id}" + local_root = settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR + return Path(local_root) / f"enterprise_info_{tenant_id}" + + +def _enterprise_storage_key(tenant_id: str, rel_path: str = "") -> str: + prefix = f"enterprise_info_{tenant_id}" + rel = normalize_storage_key(rel_path) + return f"{prefix}/{rel}" if rel else prefix @enterprise_kb_router.get("/files") @@ -684,31 +712,20 @@ async def list_enterprise_kb_files( """List files in enterprise knowledge base (tenant-scoped).""" if not current_user.tenant_id: return [] - info_dir = _enterprise_info_dir(str(current_user.tenant_id)).resolve() - info_dir.mkdir(parents=True, exist_ok=True) - - if path: - target = (info_dir / path).resolve() - else: - target = info_dir - if not str(target).startswith(str(info_dir)): - raise HTTPException(status_code=403, detail="Path traversal not allowed") - - if not target.exists() or not target.is_dir(): + storage = get_storage_backend() + storage_key = _enterprise_storage_key(str(current_user.tenant_id), path) + if not await storage.exists(storage_key) or not await storage.is_dir(storage_key): return [] items = [] - for entry in sorted(target.iterdir(), key=lambda e: (not e.is_dir(), e.name)): - if entry.name == '.gitkeep': - continue - rel = str(entry.resolve().relative_to(info_dir.resolve())) - stat = entry.stat() + for entry in await storage.list_dir(storage_key): + rel = str(Path(entry.key).relative_to(f"enterprise_info_{current_user.tenant_id}")) items.append({ "name": entry.name, "path": rel, - "is_dir": entry.is_dir(), - "size": stat.st_size if entry.is_file() else 0, - "url": f"/api/enterprise/knowledge-base/download?path={rel}" if not entry.is_dir() else None + "is_dir": entry.is_dir, + "size": entry.size, + "url": f"/api/enterprise/knowledge-base/download?path={rel}" if not entry.is_dir else None }) return items @@ -727,28 +744,28 @@ async def upload_enterprise_kb_file( if not current_user.tenant_id: raise HTTPException(status_code=400, detail="No tenant associated") - info_dir = _enterprise_info_dir(str(current_user.tenant_id)) - target_dir = (info_dir / sub_path).resolve() - if not str(target_dir).startswith(str(info_dir.resolve())): - raise HTTPException(status_code=403, detail="Path traversal not allowed") - - target_dir.mkdir(parents=True, exist_ok=True) filename = file.filename or "unnamed" filename = filename.replace("/", "_").replace("\\", "_") - save_path = target_dir / filename + storage = get_storage_backend() + rel_path = f"{sub_path}/{filename}" if sub_path else filename + storage_key = _enterprise_storage_key(str(current_user.tenant_id), rel_path) content = await file.read() - save_path.write_bytes(content) + await storage.write_bytes(storage_key, content, content_type=guess_content_type(filename)) # Auto-extract text from non-text files extracted_path = None from app.services.text_extractor import needs_extraction, save_extracted_text if needs_extraction(filename): + save_path = await ensure_local_path(storage_key) txt_file = save_extracted_text(save_path, content, filename) if txt_file: - extracted_path = str(txt_file.resolve().relative_to(info_dir.resolve())) - - rel_path = f"{sub_path}/{filename}" if sub_path else filename + extracted_path = f"{sub_path}/{txt_file.name}" if sub_path else txt_file.name + await storage.write_bytes( + _enterprise_storage_key(str(current_user.tenant_id), extracted_path), + txt_file.read_bytes(), + content_type="text/plain; charset=utf-8", + ) return { "status": "ok", "path": rel_path, @@ -767,18 +784,17 @@ async def read_enterprise_file( """Read content of an enterprise knowledge base file (tenant-scoped).""" if not current_user.tenant_id: raise HTTPException(status_code=400, detail="No tenant associated") - info_dir = _enterprise_info_dir(str(current_user.tenant_id)) - target = (info_dir / path).resolve() - if not str(target).startswith(str(info_dir.resolve())): - raise HTTPException(status_code=403, detail="Path traversal not allowed") - if not target.exists() or not target.is_file(): + storage = get_storage_backend() + storage_key = _enterprise_storage_key(str(current_user.tenant_id), path) + if not await storage.exists(storage_key) or not await storage.is_file(storage_key): raise HTTPException(status_code=404, detail="File not found") try: - content = target.read_text(encoding="utf-8", errors="replace") + content = await storage.read_text(storage_key, encoding="utf-8", errors="replace") return {"path": path, "content": content} except Exception: - return {"path": path, "content": f"[二进制文件: {target.name}, {target.stat().st_size} bytes]"} + stat = await storage.stat(storage_key) + return {"path": path, "content": f"[二进制文件: {Path(path).name}, {stat.size} bytes]"} @enterprise_kb_router.put("/content") @@ -793,14 +809,8 @@ async def write_enterprise_file( if not current_user.tenant_id: raise HTTPException(status_code=400, detail="No tenant associated") - info_dir = _enterprise_info_dir(str(current_user.tenant_id)) - target = (info_dir / path).resolve() - if not str(target).startswith(str(info_dir.resolve())): - raise HTTPException(status_code=403, detail="Path traversal not allowed") - - target.parent.mkdir(parents=True, exist_ok=True) - async with aiofiles.open(target, "w", encoding="utf-8") as f: - await f.write(data.content) + storage = get_storage_backend() + await storage.write_text(_enterprise_storage_key(str(current_user.tenant_id), path), data.content, encoding="utf-8") return {"status": "ok", "path": path} @@ -815,18 +825,14 @@ async def delete_enterprise_file( if not current_user.tenant_id: raise HTTPException(status_code=400, detail="No tenant associated") - info_dir = _enterprise_info_dir(str(current_user.tenant_id)) - target = (info_dir / path).resolve() - if not str(target).startswith(str(info_dir.resolve())): - raise HTTPException(status_code=403, detail="Path traversal not allowed") - if not target.exists(): + storage = get_storage_backend() + storage_key = _enterprise_storage_key(str(current_user.tenant_id), path) + if not await storage.exists(storage_key): raise HTTPException(status_code=404, detail="File not found") - - if target.is_dir(): - import shutil - shutil.rmtree(target) + if await storage.is_dir(storage_key): + await storage.delete_tree(storage_key) else: - target.unlink() + await storage.delete(storage_key) return {"status": "ok", "path": path} diff --git a/backend/app/api/pages.py b/backend/app/api/pages.py index 461e79a67..3856d80e0 100644 --- a/backend/app/api/pages.py +++ b/backend/app/api/pages.py @@ -1,20 +1,17 @@ """Public pages API — serves published HTML without authentication.""" import uuid -from pathlib import Path from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import HTMLResponse from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession -from app.config import get_settings from app.core.security import get_current_user from app.database import get_db from app.models.published_page import PublishedPage from app.models.user import User - -settings = get_settings() +from app.services.storage import get_storage_backend, normalize_storage_key # Public router — no /api prefix, no auth public_router = APIRouter(tags=["pages"]) @@ -22,11 +19,6 @@ # Authenticated router — under /api prefix router = APIRouter(prefix="/pages", tags=["pages"]) - -def _agent_base_dir(agent_id: uuid.UUID) -> Path: - return Path(settings.AGENT_DATA_DIR) / str(agent_id) - - # ── Public render (NO auth) ──────────────────────────── @public_router.get("/p/{short_id}") @@ -39,12 +31,12 @@ async def render_page(short_id: str, db: AsyncSession = Depends(get_db)): if not page: raise HTTPException(status_code=404, detail="Page not found") - # Read the HTML file from agent workspace - file_path = _agent_base_dir(page.agent_id) / page.source_path - if not file_path.exists() or not file_path.is_file(): + storage = get_storage_backend() + storage_key = normalize_storage_key(f"{page.agent_id}/{page.source_path}") + if not await storage.exists(storage_key) or not await storage.is_file(storage_key): raise HTTPException(status_code=404, detail="Source file no longer exists") - html_content = file_path.read_text(encoding="utf-8", errors="replace") + html_content = await storage.read_text(storage_key, encoding="utf-8", errors="replace") # Increment view count await db.execute( diff --git a/backend/app/api/relationships.py b/backend/app/api/relationships.py index ed72e9174..4bf252db1 100644 --- a/backend/app/api/relationships.py +++ b/backend/app/api/relationships.py @@ -2,7 +2,6 @@ import json import uuid -from pathlib import Path from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel @@ -10,16 +9,15 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from app.config import get_settings from app.core.permissions import build_visible_agents_query, check_agent_access from app.core.security import get_current_user from app.database import get_db from app.models.agent import Agent from app.models.org import AgentRelationship, AgentAgentRelationship, OrgMember -from app.services.org_sync_adapter import derive_member_department_paths from app.models.user import User +from app.services.org_sync_adapter import derive_member_department_paths +from app.services.storage import store_agent_bytes -settings = get_settings() router = APIRouter(prefix="/agents/{agent_id}/relationships", tags=["relationships"]) RELATION_LABELS = { @@ -368,11 +366,13 @@ async def _regenerate_relationships_file(db: AsyncSession, agent_id: uuid.UUID): ) agent_rels = a_result.scalars().all() - ws = Path(settings.AGENT_DATA_DIR) / str(agent_id) - ws.mkdir(parents=True, exist_ok=True) - if not human_rows and not agent_rels: - (ws / "relationships.md").write_text("# 关系网络\n\n_暂无配置的关系。_\n", encoding="utf-8") + await store_agent_bytes( + agent_id, + "relationships.md", + "# 关系网络\n\n_暂无配置的关系。_\n".encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) return lines = ["# 关系网络\n"] @@ -412,4 +412,9 @@ async def _regenerate_relationships_file(db: AsyncSession, agent_id: uuid.UUID): lines.append(f"- {r.description}") lines.append("") - (ws / "relationships.md").write_text("\n".join(lines), encoding="utf-8") + await store_agent_bytes( + agent_id, + "relationships.md", + "\n".join(lines).encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) diff --git a/backend/app/api/slack.py b/backend/app/api/slack.py index a7006f9dd..ccb705c76 100644 --- a/backend/app/api/slack.py +++ b/backend/app/api/slack.py @@ -4,6 +4,7 @@ import hmac import time import uuid +from pathlib import Path from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from loguru import logger @@ -16,6 +17,7 @@ from app.models.channel_config import ChannelConfig from app.models.user import User from app.schemas.schemas import ChannelConfigOut +from app.services.storage import store_agent_upload router = APIRouter(tags=["slack"]) @@ -307,17 +309,12 @@ async def slack_event_webhook( history = [{"role": m.role, "content": m.content} for m in reversed(history_r.scalars().all())] # Handle file attachments: save to workspace/uploads/ and send ack - from app.config import get_settings as _gs import asyncio as _asyncio import random as _random - from pathlib import Path as _Path import httpx as _httpx from datetime import datetime, timezone from app.api.feishu import _FILE_ACK_MESSAGES _file_user_messages = [] - _settings = _gs() - _upload_dir = _Path(_settings.AGENT_DATA_DIR) / str(agent_id) / "workspace" / "uploads" - _upload_dir.mkdir(parents=True, exist_ok=True) _bot_token = config.app_secret or "" for _sf in slack_files: _fname = _sf.get("name") or _sf.get("title") or f"slack_file_{_sf.get('id', 'unk')}.bin" @@ -332,8 +329,13 @@ async def slack_event_webhook( _ct = _r.headers.get("content-type", "") if "text/html" in _ct or _r.content[:15].lower().startswith(b" Path: - return Path(get_settings().AGENT_DATA_DIR) / "_tenant_logos" - - -def _tenant_logo_path(tenant_id: uuid.UUID) -> Path: - return _tenant_logo_dir() / f"{tenant_id}.png" +def _tenant_logo_key(tenant_id: uuid.UUID) -> str: + return normalize_storage_key(f"_tenant_logos/{tenant_id}.png") def _tenant_logo_url(tenant_id: uuid.UUID) -> str: - try: - mtime = int(_tenant_logo_path(tenant_id).stat().st_mtime) - except OSError: - mtime = int(datetime.utcnow().timestamp()) - return f"/api/tenants/{tenant_id}/logo?v={mtime}" + return f"/api/tenants/{tenant_id}/logo?v={int(datetime.utcnow().timestamp())}" async def _get_updateable_tenant( @@ -541,9 +532,11 @@ async def update_tenant( @router.get("/{tenant_id}/logo") async def get_tenant_logo(tenant_id: uuid.UUID): """Serve a tenant logo. Logos are public UI assets, addressed by UUID.""" - path = _tenant_logo_path(tenant_id) - if not path.exists(): + storage = get_storage_backend() + key = _tenant_logo_key(tenant_id) + if not await storage.exists(key): raise HTTPException(status_code=404, detail="Logo not found") + path = await ensure_local_path(key) return FileResponse(path, media_type="image/png") @@ -580,11 +573,8 @@ async def upload_tenant_logo( if len(png_data) > 1024 * 1024: raise HTTPException(status_code=400, detail="Logo image must be 1 MB or smaller after processing") - logo_dir = _tenant_logo_dir() - logo_dir.mkdir(parents=True, exist_ok=True) - path = _tenant_logo_path(tenant_id) - async with aiofiles.open(path, "wb") as f: - await f.write(png_data) + storage = get_storage_backend() + await storage.write_bytes(_tenant_logo_key(tenant_id), png_data, content_type="image/png") config = dict(tenant.im_config or {}) config["logo_url"] = _tenant_logo_url(tenant_id) @@ -602,12 +592,10 @@ async def delete_tenant_logo( """Remove a custom company logo and fall back to the generated default.""" tenant = await _get_updateable_tenant(tenant_id, current_user, db) - path = _tenant_logo_path(tenant_id) - if path.exists(): - try: - path.unlink() - except OSError as exc: - raise HTTPException(status_code=500, detail="Failed to delete logo") from exc + storage = get_storage_backend() + key = _tenant_logo_key(tenant_id) + if await storage.exists(key): + await storage.delete(key) config = dict(tenant.im_config or {}) config.pop("logo_url", None) diff --git a/backend/app/api/upload.py b/backend/app/api/upload.py index 1aa0ea166..28c7867c9 100644 --- a/backend/app/api/upload.py +++ b/backend/app/api/upload.py @@ -9,13 +9,10 @@ from loguru import logger from app.core.security import get_current_user from app.models.user import User -from app.config import get_settings +from app.services.storage import ensure_local_path, get_storage_backend, guess_content_type, normalize_storage_key router = APIRouter(prefix="/chat", tags=["chat"]) -_settings = get_settings() -WORKSPACE_ROOT = Path(_settings.AGENT_DATA_DIR) - # Supported extensions and their text extraction method TEXT_EXTENSIONS = { ".txt", ".md", ".csv", ".json", ".xml", ".yaml", ".yml", @@ -125,20 +122,19 @@ async def upload_file( # Determine save directory workspace_path = "" if agent_id: - # Save to agent's workspace/uploads/ - uploads_dir = WORKSPACE_ROOT / agent_id / "workspace" / "uploads" - uploads_dir.mkdir(parents=True, exist_ok=True) - save_path = uploads_dir / file.filename - # Avoid overwriting: add suffix if file exists - if save_path.exists(): - stem = save_path.stem - suffix = save_path.suffix - counter = 1 - while save_path.exists(): - save_path = uploads_dir / f"{stem}_{counter}{suffix}" - counter += 1 - save_path.write_bytes(content) - workspace_path = f"workspace/uploads/{save_path.name}" + storage = get_storage_backend() + filename = file.filename.replace("/", "_").replace("\\", "_") + workspace_path = f"workspace/uploads/{filename}" + key = normalize_storage_key(f"{agent_id}/{workspace_path}") + counter = 1 + while await storage.exists(key): + stem, ext = os.path.splitext(filename) + filename = f"{stem}_{counter}{ext}" + workspace_path = f"workspace/uploads/{filename}" + key = normalize_storage_key(f"{agent_id}/{workspace_path}") + counter += 1 + await storage.write_bytes(key, content, content_type=guess_content_type(filename)) + save_path = await ensure_local_path(key) else: # Fallback: save to /tmp (legacy behavior) fallback_dir = Path("/tmp/clawith_uploads") diff --git a/backend/app/api/webhooks.py b/backend/app/api/webhooks.py index 50dd1f435..b61d20ef1 100644 --- a/backend/app/api/webhooks.py +++ b/backend/app/api/webhooks.py @@ -14,17 +14,32 @@ from loguru import logger from sqlalchemy import select +from app.core.events import get_redis from app.database import async_session from app.models.trigger import AgentTrigger +from app.services.trigger_runtime import enqueue_webhook_execution router = APIRouter(prefix="/api/webhooks", tags=["webhooks"]) -# In-memory rate limiter: token -> list of timestamps -_rate_hits: dict[str, list[float]] = {} RATE_LIMIT = 5 # max hits per minute per token MAX_PAYLOAD_SIZE = 65536 # 64KB max payload +async def _record_and_count_hits(token: str) -> int: + """Record the current hit in Redis and return the rolling 60-second count.""" + redis = await get_redis() + now = time.time() + key = f"webhook:rate:{token}" + member = f"{now}:{hashlib.sha1(f'{token}:{now}'.encode()).hexdigest()[:8]}" + async with redis.pipeline(transaction=True) as pipe: + pipe.zremrangebyscore(key, 0, now - 60) + pipe.zadd(key, {member: now}) + pipe.zcard(key) + pipe.expire(key, 120) + _, _, count, _ = await pipe.execute() + return int(count) + + @router.post("/t/{token}") async def receive_webhook(token: str, request: Request): """Receive a webhook POST from an external service. @@ -37,17 +52,13 @@ async def receive_webhook(token: str, request: Request): - Payload size limit (64KB) """ # Rate limiting — use per-agent limit if available - now = time.time() - hits = _rate_hits.get(token, []) - hits = [t for t in hits if now - t < 60] # keep last 60 seconds + hit_count = await _record_and_count_hits(token) # We'll check per-agent rate limit after finding the trigger below. # For now, apply a generous global ceiling to prevent memory abuse. - if len(hits) >= 60: # hard ceiling: 60/min regardless of config + if hit_count >= 60: # hard ceiling: 60/min regardless of config logger.warning(f"Webhook hard rate limit exceeded for token {token[:8]}...") return JSONResponse({"ok": True}, status_code=429) - hits.append(now) - _rate_hits[token] = hits # Payload size check body = await request.body() @@ -83,7 +94,7 @@ async def receive_webhook(token: str, request: Request): agent_obj = agent_result.scalar_one_or_none() agent_rate_limit = (agent_obj.webhook_rate_limit if agent_obj else None) or RATE_LIMIT # Re-check hits against agent-specific limit (hits already collected above) - if len(hits) > agent_rate_limit: # > because we already appended current hit + if hit_count > agent_rate_limit: # > because current hit is already counted logger.warning(f"Webhook per-agent rate limit ({agent_rate_limit}/min) for token {token[:8]}...") # Log audit entry so user can see dropped webhooks try: @@ -120,24 +131,28 @@ async def receive_webhook(token: str, request: Request): try: payload_str = body.decode("utf-8") # Try to pretty-format JSON for readability + payload_obj = None try: payload_obj = json.loads(payload_str) payload_str = json.dumps(payload_obj, ensure_ascii=False, indent=2) except json.JSONDecodeError: - pass # Keep as raw string + payload_obj = None except Exception: + payload_obj = None payload_str = repr(body[:2000]) - # Store payload and set pending flag - new_config = {**cfg, "_webhook_pending": True, "_webhook_payload": payload_str[:8000]} - from sqlalchemy import update - await db.execute( - update(AgentTrigger) - .where(AgentTrigger.id == target.id) - .values(config=new_config) + _execution, created = await enqueue_webhook_execution( + db, + trigger=target, + body=body, + payload_text=payload_str, + payload_obj=payload_obj if isinstance(payload_obj, dict) else None, + request_headers={k.lower(): v for k, v in request.headers.items()}, ) - await db.commit() + if not created: + logger.info(f"Webhook duplicate ignored for trigger {target.name}") + return JSONResponse({"ok": True}) - logger.info(f"Webhook received for trigger {target.name} (agent {target.agent_id})") + logger.info(f"Webhook queued for trigger {target.name} (agent {target.agent_id})") return JSONResponse({"ok": True}) diff --git a/backend/app/api/websocket.py b/backend/app/api/websocket.py index a3d23dbf2..8389404c1 100644 --- a/backend/app/api/websocket.py +++ b/backend/app/api/websocket.py @@ -20,6 +20,7 @@ from app.models.user import User from app.services.chat_session_service import ensure_primary_platform_session from app.services.llm import call_llm, call_llm_with_failover +from app.services.realtime import realtime_router router = APIRouter(tags=["websocket"]) @@ -32,59 +33,82 @@ def __init__(self): self.active_connections: dict[str, list[tuple]] = {} async def connect(self, agent_id: str, websocket: WebSocket, session_id: str = None, user_id: str | None = None): - await websocket.accept() if agent_id not in self.active_connections: self.active_connections[agent_id] = [] self.active_connections[agent_id].append((websocket, session_id, user_id)) - - def disconnect(self, agent_id: str, websocket: WebSocket): + await realtime_router.register_connection( + agent_id=agent_id, + websocket=websocket, + session_id=session_id, + user_id=user_id, + ) + + async def disconnect(self, agent_id: str, websocket: WebSocket): if agent_id in self.active_connections: self.active_connections[agent_id] = [ (ws, sid, uid) for ws, sid, uid in self.active_connections[agent_id] if ws != websocket ] + await realtime_router.unregister_connection(agent_id=agent_id, websocket=websocket) + + def _local_connections(self, agent_id: str) -> list[tuple[WebSocket, str | None, str | None]]: + return self.active_connections.get(agent_id, []) + + async def deliver_pubsub_message( + self, + *, + agent_id: str, + payload: dict, + session_id: str | None = None, + user_id: str | None = None, + ) -> None: + if agent_id not in self.active_connections: + return + for ws, sid, uid in list(self.active_connections[agent_id]): + if session_id is not None and sid != session_id: + continue + if user_id is not None and uid != user_id: + continue + try: + await ws.send_json(payload) + except Exception: + pass async def send_message(self, agent_id: str, message: dict): - if agent_id in self.active_connections: - for ws, _sid, _uid in self.active_connections[agent_id]: - try: - await ws.send_json(message) - except Exception: - pass + await realtime_router.route_message( + agent_id=agent_id, + message=message, + local_connections=self._local_connections(agent_id), + ) async def send_to_session(self, agent_id: str, session_id: str, message: dict): """Send message only to WebSocket connections matching the given session_id.""" - if agent_id in self.active_connections: - for ws, sid, _uid in self.active_connections[agent_id]: - if sid == session_id: - try: - await ws.send_json(message) - except Exception: - pass + await realtime_router.route_message( + agent_id=agent_id, + message=message, + local_connections=self._local_connections(agent_id), + session_id=session_id, + ) async def send_to_user(self, agent_id: str, user_id: str, message: dict): """Send message to all live WebSocket sessions of a given platform user for an agent.""" - if agent_id in self.active_connections: - for ws, _sid, uid in self.active_connections[agent_id]: - if uid == user_id: - try: - await ws.send_json(message) - except Exception: - pass - - def get_active_session_ids(self, agent_id: str) -> list[str]: + await realtime_router.route_message( + agent_id=agent_id, + message=message, + local_connections=self._local_connections(agent_id), + user_id=user_id, + ) + + async def get_active_session_ids(self, agent_id: str) -> list[str]: """Return distinct session IDs for all active WS connections of an agent.""" - if agent_id not in self.active_connections: - return [] - return list(set(sid for _ws, sid, _uid in self.active_connections[agent_id] if sid)) + return await realtime_router.get_active_session_ids(agent_id) - def is_user_viewing_session(self, agent_id: str, session_id: str, user_id: str) -> bool: + async def is_user_viewing_session(self, agent_id: str, session_id: str, user_id: str) -> bool: """Return True if the given platform user currently has this exact session open.""" - if agent_id not in self.active_connections: - return False - for _ws, sid, uid in self.active_connections[agent_id]: - if sid == session_id and uid == user_id: - return True - return False + return await realtime_router.is_user_viewing_session( + agent_id=agent_id, + session_id=session_id, + user_id=user_id, + ) manager = ConnectionManager() @@ -98,7 +122,7 @@ async def maybe_mark_session_read_for_active_viewer( user_id: uuid.UUID, ) -> bool: """Advance last_read_at_by_user if the owner is actively viewing this exact session.""" - if not manager.is_user_viewing_session(str(agent_id), session_id, str(user_id)): + if not await manager.is_user_viewing_session(str(agent_id), session_id, str(user_id)): return False session = await db.get(ChatSession, uuid.UUID(session_id)) @@ -323,9 +347,7 @@ async def websocket_chat( return agent_id_str = str(agent_id) - if agent_id_str not in manager.active_connections: - manager.active_connections[agent_id_str] = [] - manager.active_connections[agent_id_str].append((websocket, conv_id, str(user_id))) + await manager.connect(agent_id_str, websocket, conv_id, str(user_id)) logger.info(f"[WS] Ready! Agent={agent_name}") # Send session_id to frontend so Take Control can reference the correct session. @@ -954,9 +976,9 @@ async def _on_failover(reason: str): except WebSocketDisconnect: logger.info(f"[WS] Client disconnected: {user_id}") - manager.disconnect(str(agent_id), websocket) + await manager.disconnect(str(agent_id), websocket) except Exception as e: logger.error(f"[WS] Unexpected error: {e}") import traceback traceback.print_exc() - manager.disconnect(str(agent_id), websocket) + await manager.disconnect(str(agent_id), websocket) diff --git a/backend/app/config.py b/backend/app/config.py index b0719c368..ac3e0b5bb 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -1,7 +1,10 @@ """Application configuration.""" from functools import lru_cache +import os from pathlib import Path +import socket +import uuid from pydantic_settings import BaseSettings @@ -32,6 +35,14 @@ def _default_agent_data_dir() -> str: return str(Path.home() / ".clawith" / "data" / "agents") +def _default_instance_id() -> str: + """Generate a stable-enough per-process instance identifier.""" + host = socket.gethostname() or "unknown" + pid = os.getpid() + suffix = uuid.uuid4().hex[:8] + return f"{host}-{pid}-{suffix}" + + def _default_agent_template_dir() -> str: """Locate the agent template directory for both Docker and source deployments. @@ -73,6 +84,7 @@ class Settings(BaseSettings): # Redis REDIS_URL: str = "redis://localhost:6379/0" + INSTANCE_ID: str = _default_instance_id() # JWT JWT_SECRET_KEY: str = "change-me-jwt-secret" @@ -83,8 +95,20 @@ class Settings(BaseSettings): EMAIL_VERIFICATION_REQUIRED: bool = False # Require email verification for login # File Storage + STORAGE_BACKEND: str = "local" AGENT_DATA_DIR: str = _default_agent_data_dir() AGENT_TEMPLATE_DIR: str = _default_agent_template_dir() + STORAGE_LOCAL_ROOT: str = _default_agent_data_dir() + S3_BUCKET: str = "" + S3_REGION: str = "" + S3_ENDPOINT_URL: str = "" + S3_ACCESS_KEY_ID: str = "" + S3_SECRET_ACCESS_KEY: str = "" + S3_PREFIX: str = "agents" + S3_PRESIGN_TTL_SECONDS: int = 3600 + + # Process role + PROCESS_ROLE: str = "all" # Docker (for Agent containers) DOCKER_NETWORK: str = "clawith_network" diff --git a/backend/app/main.py b/backend/app/main.py index 0f4691004..7366ffa1d 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -13,10 +13,26 @@ from app.core.logging_config import configure_logging, intercept_standard_logging from app.core.middleware import TraceIdMiddleware from app.schemas.schemas import HealthResponse +from app.services.realtime import realtime_router settings = get_settings() +def _process_roles() -> set[str]: + raw = (settings.PROCESS_ROLE or "all").strip().lower() + if not raw: + return {"all"} + roles = {part.strip() for part in raw.split(",") if part.strip()} + return roles or {"all"} + + +def _role_enabled(*required: str) -> bool: + roles = _process_roles() + if "all" in roles: + return True + return any(role in roles for role in required) + + def _log_bwrap_startup_status() -> None: """Emit a startup diagnostic for bubblewrap availability. @@ -122,132 +138,137 @@ async def lifespan(app: FastAPI): from app.services.wechat_channel import wechat_poll_manager from app.services.discord_gateway import discord_gateway_manager - # ── Step 0: Ensure all DB tables exist (idempotent, safe to run on every startup) ── - try: - from app.database import Base, engine - # Import all models so Base.metadata is fully populated - import app.models.user # noqa - import app.models.agent # noqa - import app.models.task # noqa - import app.models.llm # noqa - import app.models.tool # noqa - import app.models.audit # noqa - import app.models.skill # noqa - import app.models.channel_config # noqa - import app.models.schedule # noqa - import app.models.plaza # noqa - import app.models.activity_log # noqa - import app.models.org # noqa - import app.models.system_settings # noqa - import app.models.invitation_code # noqa - import app.models.tenant # noqa - import app.models.tenant_setting # noqa - import app.models.participant # noqa - import app.models.chat_session # noqa - import app.models.trigger # noqa - import app.models.notification # noqa - import app.models.gateway_message # noqa - import app.models.agent_credential # noqa - import app.models.okr # noqa OKR system tables - - import app.models.identity # noqa - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - logger.info("[startup] Database tables ready") - except Exception as e: - logger.warning(f"[startup] create_all failed: {e}") - # Startup: seed data — each step isolated so one failure doesn't block others - logger.info("[startup] seeding...") + if _role_enabled("all", "bootstrap"): + # ── Step 0: Ensure all DB tables exist (idempotent, safe to run on every startup) ── + try: + from app.database import Base, engine + # Import all models so Base.metadata is fully populated + import app.models.user # noqa + import app.models.agent # noqa + import app.models.task # noqa + import app.models.llm # noqa + import app.models.tool # noqa + import app.models.audit # noqa + import app.models.skill # noqa + import app.models.channel_config # noqa + import app.models.schedule # noqa + import app.models.plaza # noqa + import app.models.activity_log # noqa + import app.models.org # noqa + import app.models.system_settings # noqa + import app.models.invitation_code # noqa + import app.models.tenant # noqa + import app.models.tenant_setting # noqa + import app.models.participant # noqa + import app.models.chat_session # noqa + import app.models.trigger # noqa + import app.models.trigger_execution # noqa + import app.models.notification # noqa + import app.models.gateway_message # noqa + import app.models.agent_credential # noqa + import app.models.okr # noqa + + import app.models.identity # noqa + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + logger.info("[startup] Database tables ready") + except Exception as e: + logger.warning(f"[startup] create_all failed: {e}") + logger.info("[startup] seeding...") - # Seed default company (Tenant) — required before users can register - try: - from app.models.tenant import Tenant - from app.database import async_session as _session - from sqlalchemy import select as _select - async with _session() as _db: - _existing = await _db.execute(_select(Tenant).where(Tenant.slug == "default")) - if not _existing.scalar_one_or_none(): - _db.add(Tenant(name="Default", slug="default", im_provider="web_only")) - await _db.commit() - logger.info("[startup] Default company created") - except Exception as e: - logger.warning(f"[startup] Default company seed failed: {e}") + try: + from app.models.tenant import Tenant + from app.database import async_session as _session + from sqlalchemy import select as _select + async with _session() as _db: + _existing = await _db.execute(_select(Tenant).where(Tenant.slug == "default")) + if not _existing.scalar_one_or_none(): + _db.add(Tenant(name="Default", slug="default", im_provider="web_only")) + await _db.commit() + logger.info("[startup] Default company created") + except Exception as e: + logger.warning(f"[startup] Default company seed failed: {e}") - # Migrate old shared enterprise_info/ → enterprise_info_{first_tenant_id}/ - try: - import shutil - from pathlib import Path as _Path - from app.config import get_settings as _gs - from app.models.tenant import Tenant as _T - from app.database import async_session as _ses - from sqlalchemy import select as _sel - _data_dir = _Path(_gs().AGENT_DATA_DIR) - _old_dir = _data_dir / "enterprise_info" - if _old_dir.exists() and any(_old_dir.iterdir()): - async with _ses() as _db: - _first = await _db.execute(_sel(_T).order_by(_T.created_at).limit(1)) - _tenant = _first.scalar_one_or_none() - if _tenant: - _new_dir = _data_dir / f"enterprise_info_{_tenant.id}" - if not _new_dir.exists(): - shutil.copytree(str(_old_dir), str(_new_dir)) - print(f"[startup] ✅ Migrated enterprise_info → enterprise_info_{_tenant.id}", flush=True) - else: - print(f"[startup] ℹ️ enterprise_info_{_tenant.id} already exists, skipping migration", flush=True) - except Exception as e: - print(f"[startup] ⚠️ enterprise_info migration failed: {e}", flush=True) + try: + import shutil + from pathlib import Path as _Path + from app.config import get_settings as _gs + from app.models.tenant import Tenant as _T + from app.database import async_session as _ses + from sqlalchemy import select as _sel + _data_dir = _Path(_gs().AGENT_DATA_DIR) + _old_dir = _data_dir / "enterprise_info" + if _old_dir.exists() and any(_old_dir.iterdir()): + async with _ses() as _db: + _first = await _db.execute(_sel(_T).order_by(_T.created_at).limit(1)) + _tenant = _first.scalar_one_or_none() + if _tenant: + _new_dir = _data_dir / f"enterprise_info_{_tenant.id}" + if not _new_dir.exists(): + shutil.copytree(str(_old_dir), str(_new_dir)) + print(f"[startup] ✅ Migrated enterprise_info → enterprise_info_{_tenant.id}", flush=True) + else: + print(f"[startup] ℹ️ enterprise_info_{_tenant.id} already exists, skipping migration", flush=True) + except Exception as e: + print(f"[startup] ⚠️ enterprise_info migration failed: {e}", flush=True) - try: - from app.services.tool_seeder import seed_builtin_tools, clean_orphaned_mcp_tools - await seed_builtin_tools() - await clean_orphaned_mcp_tools() - except Exception as e: - logger.warning(f"[startup] Builtin tools seed or cleanup failed: {e}") + try: + from app.services.tool_seeder import seed_builtin_tools, clean_orphaned_mcp_tools + await seed_builtin_tools() + await clean_orphaned_mcp_tools() + except Exception as e: + logger.warning(f"[startup] Builtin tools seed or cleanup failed: {e}") - try: - from app.services.tool_seeder import seed_atlassian_rovo_config, get_atlassian_api_key - await seed_atlassian_rovo_config() - # Auto-import Atlassian Rovo tools if an API key is already configured - _rovo_key = await get_atlassian_api_key() - if _rovo_key: - from app.services.resource_discovery import seed_atlassian_rovo_tools - await seed_atlassian_rovo_tools(_rovo_key) - except Exception as e: - logger.warning(f"[startup] Atlassian tools seed failed: {e}") + try: + from app.services.tool_seeder import seed_atlassian_rovo_config, get_atlassian_api_key + await seed_atlassian_rovo_config() + _rovo_key = await get_atlassian_api_key() + if _rovo_key: + from app.services.resource_discovery import seed_atlassian_rovo_tools + await seed_atlassian_rovo_tools(_rovo_key) + except Exception as e: + logger.warning(f"[startup] Atlassian tools seed failed: {e}") - try: - await seed_agent_templates() - except Exception as e: - logger.warning(f"[startup] Agent templates seed failed: {e}") + try: + await seed_agent_templates() + except Exception as e: + logger.warning(f"[startup] Agent templates seed failed: {e}") - try: - from app.services.skill_seeder import seed_skills, push_default_skills_to_existing_agents - await seed_skills() - await push_default_skills_to_existing_agents() - except Exception as e: - logger.warning(f"[startup] Skills seed failed: {e}") + try: + from app.services.skill_seeder import seed_skills, push_default_skills_to_existing_agents + await seed_skills() + await push_default_skills_to_existing_agents() + except Exception as e: + logger.warning(f"[startup] Skills seed failed: {e}") - try: - from app.services.agent_seeder import seed_default_agents - await seed_default_agents() - except Exception as e: - logger.warning(f"[startup] Default agents seed failed: {e}") + try: + from app.services.agent_seeder import seed_default_agents + await seed_default_agents() + except Exception as e: + logger.warning(f"[startup] Default agents seed failed: {e}") - try: - # Seed OKR Agent independently (supports retroactive creation on existing deployments) - from app.services.agent_seeder import seed_okr_agent - await seed_okr_agent() - except Exception as e: - logger.warning(f"[startup] OKR Agent seed failed: {e}") + try: + from app.services.agent_seeder import seed_okr_agent + await seed_okr_agent() + except Exception as e: + logger.warning(f"[startup] OKR Agent seed failed: {e}") - try: - # Patch existing OKR Agent with new fields/tools/triggers added in later versions - from app.services.agent_seeder import patch_existing_okr_agent - await patch_existing_okr_agent() - except Exception as e: - logger.warning(f"[startup] OKR Agent patch failed: {e}") + try: + from app.services.agent_seeder import patch_existing_okr_agent + await patch_existing_okr_agent() + except Exception as e: + logger.warning(f"[startup] OKR Agent patch failed: {e}") + else: + logger.info(f"[startup] bootstrap skipped for PROCESS_ROLE={settings.PROCESS_ROLE}") + + if _role_enabled("all", "api"): + try: + from app.api.websocket import manager as ws_manager + await realtime_router.start(ws_manager.deliver_pubsub_message) + logger.info("[startup] realtime router subscriber started") + except Exception as e: + logger.error(f"[startup] realtime router start failed: {e}") - # Start background tasks (always, even if seeding failed) try: logger.info("[startup] starting background tasks...") from app.services.audit_logger import write_audit_log @@ -264,14 +285,19 @@ def _bg_task_error(t): import traceback traceback.print_exception(type(exc), exc, exc.__traceback__) - for name, coro in [ - ("trigger_daemon", start_trigger_daemon()), - ("feishu_ws", feishu_ws_manager.start_all()), - ("dingtalk_stream", dingtalk_stream_manager.start_all()), - ("wecom_stream", wecom_stream_manager.start_all()), - ("wechat_poll", wechat_poll_manager.start_all()), - ("discord_gw", discord_gateway_manager.start_all()), - ]: + task_specs = [] + if _role_enabled("all", "worker"): + task_specs.append(("trigger_daemon", start_trigger_daemon())) + if _role_enabled("all", "connector"): + task_specs.extend([ + ("feishu_ws", feishu_ws_manager.start_all()), + ("dingtalk_stream", dingtalk_stream_manager.start_all()), + ("wecom_stream", wecom_stream_manager.start_all()), + ("wechat_poll", wechat_poll_manager.start_all()), + ("discord_gw", discord_gateway_manager.start_all()), + ]) + + for name, coro in task_specs: task = asyncio.create_task(coro, name=name) task.add_done_callback(_bg_task_error) logger.info(f"[startup] created bg task: {name}") @@ -288,6 +314,7 @@ def _bg_task_error(t): yield # Shutdown + await realtime_router.stop() await close_redis() diff --git a/backend/app/models/trigger_execution.py b/backend/app/models/trigger_execution.py new file mode 100644 index 000000000..8482728a9 --- /dev/null +++ b/backend/app/models/trigger_execution.py @@ -0,0 +1,41 @@ +"""Trigger execution records for distributed claiming and idempotency.""" + +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, ForeignKey, Index, String, Text, UniqueConstraint, func +from sqlalchemy.dialects.postgresql import UUID, JSONB +from sqlalchemy.orm import Mapped, mapped_column + +from app.database import Base + + +class TriggerExecution(Base): + """A concrete trigger execution request that workers can claim and process.""" + + __tablename__ = "trigger_executions" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + trigger_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("agent_triggers.id", ondelete="CASCADE"), nullable=False, index=True + ) + agent_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("agents.id", ondelete="CASCADE"), nullable=False, index=True + ) + source: Mapped[str] = mapped_column(String(32), nullable=False, default="webhook") + status: Mapped[str] = mapped_column(String(20), nullable=False, default="pending") + idempotency_key: Mapped[str] = mapped_column(String(255), nullable=False) + payload: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict) + payload_text: Mapped[str] = mapped_column(Text, nullable=False, default="") + lease_owner: Mapped[str | None] = mapped_column(String(128)) + lease_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + scheduled_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now()) + started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + finished_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + last_error: Mapped[str | None] = mapped_column(Text) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) + + __table_args__ = ( + UniqueConstraint("trigger_id", "idempotency_key", name="uq_trigger_execution_idempotency"), + Index("ix_trigger_executions_status_scheduled", "status", "scheduled_at"), + ) diff --git a/backend/app/services/agent_context.py b/backend/app/services/agent_context.py index 52ffe61a2..ebecc2992 100644 --- a/backend/app/services/agent_context.py +++ b/backend/app/services/agent_context.py @@ -8,23 +8,17 @@ from pathlib import Path from app.config import get_settings +from app.services.storage import get_storage_backend, normalize_storage_key settings = get_settings() -PERSISTENT_DATA = Path(settings.AGENT_DATA_DIR) - - -def _agent_workspace(agent_id: uuid.UUID) -> Path: - """Return the canonical persistent workspace path for an agent.""" - return PERSISTENT_DATA / str(agent_id) - - -def _read_file_safe(path: Path, max_chars: int = 3000) -> str: - """Read a file, return empty string if missing. Truncate if too long.""" - if not path.exists(): +async def _read_file_safe(key: str, max_chars: int = 3000) -> str: + """Read a storage-backed text file, return empty string if missing.""" + storage = get_storage_backend() + if not await storage.exists(key) or not await storage.is_file(key): return "" try: - content = path.read_text(encoding="utf-8", errors="replace").strip() + content = (await storage.read_text(key, encoding="utf-8", errors="replace")).strip() if len(content) > max_chars: content = content[:max_chars] + "\n...(truncated)" return content @@ -76,7 +70,7 @@ def _parse_skill_frontmatter(content: str, filename: str) -> tuple[str, str]: return name, description -def _load_skills_index(agent_id: uuid.UUID) -> str: +async def _load_skills_index(agent_id: uuid.UUID) -> str: """Load skill index (name + description) from skills/ directory. Supports two formats: @@ -87,36 +81,36 @@ def _load_skills_index(agent_id: uuid.UUID) -> str: prompt. The model is instructed to call read_file to load full content when a skill is relevant. """ - ws_root = _agent_workspace(agent_id) skills: list[tuple[str, str, str]] = [] # (name, description, path_relative_to_skills) - skills_dir = ws_root / "skills" - if skills_dir.exists(): - for entry in sorted(skills_dir.iterdir()): + storage = get_storage_backend() + skills_prefix = normalize_storage_key(f"{agent_id}/skills") + if await storage.exists(skills_prefix) and await storage.is_dir(skills_prefix): + for entry in await storage.list_dir(skills_prefix): if entry.name.startswith("."): continue + entry_key = entry.key # Case 1: Folder-based skill — skills//SKILL.md - if entry.is_dir(): - skill_md = entry / "SKILL.md" - if not skill_md.exists(): - # Also try lowercase skill.md - skill_md = entry / "skill.md" - if skill_md.exists(): + if entry.is_dir: + skill_md_key = f"{entry_key}/SKILL.md" + if not await storage.exists(skill_md_key): + skill_md_key = f"{entry_key}/skill.md" + if await storage.exists(skill_md_key): try: - content = skill_md.read_text(encoding="utf-8", errors="replace").strip() + content = (await storage.read_text(skill_md_key, encoding="utf-8", errors="replace")).strip() name, desc = _parse_skill_frontmatter(content, entry.name) skills.append((name, desc, f"{entry.name}/SKILL.md")) except Exception: skills.append((entry.name, "", f"{entry.name}/SKILL.md")) # Case 2: Flat file — skills/.md - elif entry.suffix == ".md" and entry.is_file(): + elif Path(entry.name).suffix == ".md" and not entry.is_dir: try: - content = entry.read_text(encoding="utf-8", errors="replace").strip() - name, desc = _parse_skill_frontmatter(content, entry.stem) + content = (await storage.read_text(entry_key, encoding="utf-8", errors="replace")).strip() + name, desc = _parse_skill_frontmatter(content, Path(entry.name).stem) skills.append((name, desc, entry.name)) except Exception: - skills.append((entry.stem, "", entry.name)) + skills.append((Path(entry.name).stem, "", entry.name)) # Deduplicate by name seen: set[str] = set() @@ -158,24 +152,24 @@ async def build_agent_context(agent_id: uuid.UUID, agent_name: str, role_descrip - skills/ → skill names + summaries - relationships.md → relationship descriptions """ - ws_root = _agent_workspace(agent_id) - # --- Soul --- - soul = _read_file_safe(ws_root / "soul.md", 2000) + soul = await _read_file_safe(normalize_storage_key(f"{agent_id}/soul.md"), 2000) # Strip markdown heading if present if soul.startswith("# "): soul = "\n".join(soul.split("\n")[1:]).strip() # --- Memory --- - memory = _read_file_safe(ws_root / "memory" / "memory.md", 2000) or _read_file_safe(ws_root / "memory.md", 2000) + memory = await _read_file_safe(normalize_storage_key(f"{agent_id}/memory/memory.md"), 2000) + if not memory: + memory = await _read_file_safe(normalize_storage_key(f"{agent_id}/memory.md"), 2000) if memory.startswith("# "): memory = "\n".join(memory.split("\n")[1:]).strip() # --- Skills index (progressive disclosure) --- - skills_text = _load_skills_index(agent_id) + skills_text = await _load_skills_index(agent_id) # --- Relationships --- - relationships = _read_file_safe(ws_root / "relationships.md", 2000) + relationships = await _read_file_safe(normalize_storage_key(f"{agent_id}/relationships.md"), 2000) if relationships.startswith("# "): relationships = "\n".join(relationships.split("\n")[1:]).strip() diff --git a/backend/app/services/agent_manager.py b/backend/app/services/agent_manager.py index 9f4c4a73c..4e04383cc 100644 --- a/backend/app/services/agent_manager.py +++ b/backend/app/services/agent_manager.py @@ -16,6 +16,7 @@ from app.models.agent import Agent from app.models.llm import LLMModel from app.services.llm import get_model_api_key +from app.services.storage import get_storage_backend, normalize_storage_key settings = get_settings() @@ -31,36 +32,64 @@ def __init__(self): self.docker_client = None def _agent_dir(self, agent_id: uuid.UUID) -> Path: - return Path(settings.AGENT_DATA_DIR) / str(agent_id) + local_root = settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR + return Path(local_root) / str(agent_id) + + def _agent_storage_prefix(self, agent_id: uuid.UUID) -> str: + return normalize_storage_key(str(agent_id)) def _template_dir(self) -> Path: return Path(settings.AGENT_TEMPLATE_DIR) + async def _materialize_agent_dir(self, agent_id: uuid.UUID) -> Path: + """Create a local working tree from shared storage for container mounting.""" + agent_dir = self._agent_dir(agent_id) + storage = get_storage_backend() + agent_prefix = self._agent_storage_prefix(agent_id) + agent_dir.mkdir(parents=True, exist_ok=True) + if not await storage.exists(agent_prefix): + return agent_dir + for entry in await storage.list_dir(agent_prefix): + await self._materialize_entry(storage, entry.key, agent_dir) + return agent_dir + + async def _materialize_entry(self, storage, storage_key: str, local_root: Path) -> None: + rel = Path(storage_key).relative_to(Path(storage_key).parts[0]).as_posix() + local_path = local_root / rel + if await storage.is_dir(storage_key): + local_path.mkdir(parents=True, exist_ok=True) + for child in await storage.list_dir(storage_key): + await self._materialize_entry(storage, child.key, local_root) + return + local_path.parent.mkdir(parents=True, exist_ok=True) + local_path.write_bytes(await storage.read_bytes(storage_key)) + async def initialize_agent_files(self, db: AsyncSession, agent: Agent, personality: str = "", boundaries: str = "") -> None: """Copy template files and customize for this agent.""" agent_dir = self._agent_dir(agent.id) template_dir = self._template_dir() + storage = get_storage_backend() + agent_prefix = self._agent_storage_prefix(agent.id) - if agent_dir.exists(): + if await storage.exists(agent_prefix): logger.warning(f"Agent dir already exists: {agent_dir}") return if template_dir.exists(): - # Copy template - shutil.copytree(str(template_dir), str(agent_dir)) + for src in template_dir.rglob("*"): + if src.is_dir(): + continue + rel = src.relative_to(template_dir).as_posix() + await storage.write_bytes( + f"{agent_prefix}/{rel}", + src.read_bytes(), + ) else: - # No template dir (local dev) — create minimal workspace structure logger.info(f"Template dir not found ({template_dir}), creating minimal workspace") - agent_dir.mkdir(parents=True, exist_ok=True) - (agent_dir / "workspace").mkdir(exist_ok=True) - (agent_dir / "workspace" / "knowledge_base").mkdir(exist_ok=True) - (agent_dir / "memory").mkdir(exist_ok=True) - (agent_dir / "skills").mkdir(exist_ok=True) - (agent_dir / "tasks.json").write_text("[]", encoding="utf-8") + await storage.write_text(f"{agent_prefix}/tasks.json", "[]", encoding="utf-8") # Customize soul.md - soul_path = agent_dir / "soul.md" # Get creator name from app.models.user import User result = await db.execute(select(User).where(User.id == agent.creator_id)) @@ -68,8 +97,9 @@ async def initialize_agent_files(self, db: AsyncSession, agent: Agent, creator_name = creator.display_name if creator else "Unknown" soul_content = f"# Personality\n\nI'm {agent.name}, {agent.role_description or 'a digital assistant'}.\n" - if soul_path.exists(): - template_content = soul_path.read_text() + soul_key = f"{agent_prefix}/soul.md" + if await storage.exists(soul_key): + template_content = await storage.read_text(soul_key, encoding="utf-8", errors="replace") soul_content = template_content.replace("{{agent_name}}", agent.name) soul_content = soul_content.replace("{{role_description}}", agent.role_description or "通用助手") soul_content = soul_content.replace("{{creator_name}}", creator_name) @@ -109,34 +139,34 @@ def replace_or_append_section(content: str, section_name: str, section_content: soul_content = replace_or_append_section(soul_content, "Personality", personality) soul_content = replace_or_append_section(soul_content, "Boundaries", boundaries) - soul_path.write_text(soul_content, encoding="utf-8") + await storage.write_text(soul_key, soul_content, encoding="utf-8") # Ensure memory.md exists - mem_path = agent_dir / "memory" / "memory.md" - if not mem_path.exists(): - mem_path.write_text("# Memory\n\n_Record important information and knowledge here._\n", encoding="utf-8") + mem_key = f"{agent_prefix}/memory/memory.md" + if not await storage.exists(mem_key): + await storage.write_text(mem_key, "# Memory\n\n_Record important information and knowledge here._\n", encoding="utf-8") # Ensure reflections.md exists — copy from central template - refl_path = agent_dir / "memory" / "reflections.md" - if not refl_path.exists(): + refl_key = f"{agent_prefix}/memory/reflections.md" + if not await storage.exists(refl_key): refl_template = Path(__file__).parent.parent / "templates" / "reflections.md" refl_content = refl_template.read_text(encoding="utf-8") if refl_template.exists() else "# Reflections Journal\n" - refl_path.write_text(refl_content, encoding="utf-8") + await storage.write_text(refl_key, refl_content, encoding="utf-8") # Ensure HEARTBEAT.md exists — copy from central template - hb_path = agent_dir / "HEARTBEAT.md" - if not hb_path.exists(): + hb_key = f"{agent_prefix}/HEARTBEAT.md" + if not await storage.exists(hb_key): hb_template = Path(__file__).parent.parent / "templates" / "HEARTBEAT.md" hb_content = hb_template.read_text(encoding="utf-8") if hb_template.exists() else "# Heartbeat Instructions\n" - hb_path.write_text(hb_content, encoding="utf-8") + await storage.write_text(hb_key, hb_content, encoding="utf-8") # Customize state.json - state_path = agent_dir / "state.json" - if state_path.exists(): - state = json.loads(state_path.read_text()) + state_key = f"{agent_prefix}/state.json" + if await storage.exists(state_key): + state = json.loads(await storage.read_text(state_key, encoding="utf-8", errors="replace")) state["agent_id"] = str(agent.id) state["name"] = agent.name - state_path.write_text(json.dumps(state, ensure_ascii=False, indent=2), encoding="utf-8") + await storage.write_text(state_key, json.dumps(state, ensure_ascii=False, indent=2), encoding="utf-8") logger.info(f"Initialized agent files at {agent_dir}") @@ -171,7 +201,7 @@ async def start_container(self, db: AsyncSession, agent: Agent) -> str | None: agent.last_active_at = datetime.now(timezone.utc) return None - agent_dir = self._agent_dir(agent.id) + agent_dir = await self._materialize_agent_dir(agent.id) # Get model config model = None @@ -269,7 +299,8 @@ async def remove_container(self, agent: Agent) -> bool: async def archive_agent_files(self, agent_id: uuid.UUID) -> Path: """Archive agent files to a backup location and return the archive directory.""" agent_dir = self._agent_dir(agent_id) - archive_dir = Path(settings.AGENT_DATA_DIR) / "_archived" + local_root = settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR + archive_dir = Path(local_root) / "_archived" archive_dir.mkdir(parents=True, exist_ok=True) timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") dest = archive_dir / f"{agent_id}_{timestamp}" diff --git a/backend/app/services/agent_seeder.py b/backend/app/services/agent_seeder.py index e267c951e..8e3160b11 100644 --- a/backend/app/services/agent_seeder.py +++ b/backend/app/services/agent_seeder.py @@ -1,9 +1,7 @@ """Seed default agents (Morty & Meeseeks) on first platform startup.""" -import shutil import uuid from datetime import datetime, timezone -from pathlib import Path from loguru import logger @@ -21,8 +19,28 @@ from app.models.user import User from app.models.okr import OKRSettings from app.config import get_settings +from app.services.agent_manager import agent_manager +from app.services.storage import get_storage_backend, store_agent_bytes settings = get_settings() +SEED_MARKER_KEY = "_bootstrap/.seeded" + + +async def _read_seed_marker() -> str: + storage = get_storage_backend() + if not await storage.exists(SEED_MARKER_KEY): + return "" + return await storage.read_text(SEED_MARKER_KEY, encoding="utf-8", errors="replace") + + +async def _append_seed_marker(line: str) -> None: + storage = get_storage_backend() + existing = await _read_seed_marker() + if line in existing: + return + updated = existing if existing.endswith("\n") or not existing else existing + "\n" + updated += f"{line}\n" + await storage.write_text(SEED_MARKER_KEY, updated, encoding="utf-8") # ── Soul definitions ──────────────────────────────────────────── @@ -192,9 +210,8 @@ async def seed_default_agents(): than by agent name, so the seeder does NOT re-run if the user renames or deletes the default agents. Delete the marker manually to re-seed. """ - # --- Idempotency guard: file-based marker (survives agent renames/deletes) --- - seed_marker = Path(settings.AGENT_DATA_DIR) / ".seeded" - if seed_marker.exists(): + # --- Idempotency guard: storage-backed marker (survives agent renames/deletes) --- + if await _read_seed_marker(): logger.info("[AgentSeeder] Seed marker found, skipping default agent creation") return @@ -243,47 +260,14 @@ async def seed_default_agents(): db.add(AgentPermission(agent_id=morty.id, scope_type="company", access_level="manage")) db.add(AgentPermission(agent_id=meeseeks.id, scope_type="company", access_level="manage")) - # ── Initialize workspace files ── - template_dir = Path(settings.AGENT_TEMPLATE_DIR) - for agent, soul_content in [(morty, MORTY_SOUL), (meeseeks, MEESEEKS_SOUL)]: - agent_dir = Path(settings.AGENT_DATA_DIR) / str(agent.id) - - if template_dir.exists(): - # Copy the full agent template so Morty/Meeseeks get EVERY file - # defined in the template: MEMORY_INDEX.md, curiosity_journal.md, - # state.json, todo.json, daily_reports/, enterprise_info/, etc. - shutil.copytree(str(template_dir), str(agent_dir)) - else: - # Fallback for local dev (no Docker template mount) - agent_dir.mkdir(parents=True, exist_ok=True) - (agent_dir / "skills").mkdir(exist_ok=True) - (agent_dir / "workspace").mkdir(exist_ok=True) - (agent_dir / "workspace" / "knowledge_base").mkdir(exist_ok=True) - (agent_dir / "memory").mkdir(exist_ok=True) - - # Overlay custom soul (rich Morty/Meeseeks persona over the generic template) - (agent_dir / "soul.md").write_text(soul_content.strip() + "\n", encoding="utf-8") - - # Ensure memory.md exists (template does not include it; holds runtime context) - mem_path = agent_dir / "memory" / "memory.md" - if not mem_path.exists(): - mem_path.write_text("# Memory\n\n_Record important information and knowledge here._\n", encoding="utf-8") - - # Ensure reflections.md exists (not in agent_template; lives in app/templates) - refl_path = agent_dir / "memory" / "reflections.md" - if not refl_path.exists(): - refl_src = Path(__file__).parent.parent / "templates" / "reflections.md" - refl_path.write_text(refl_src.read_text(encoding="utf-8") if refl_src.exists() else "# Reflections Journal\n", encoding="utf-8") - - # Stamp agent identity into state.json if present - state_path = agent_dir / "state.json" - if state_path.exists(): - import json as _json - state = _json.loads(state_path.read_text()) - state["agent_id"] = str(agent.id) - state["name"] = agent.name - state_path.write_text(_json.dumps(state, ensure_ascii=False, indent=2), encoding="utf-8") + await agent_manager.initialize_agent_files(db, agent) + await store_agent_bytes( + agent.id, + "soul.md", + (soul_content.strip() + "\n").encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) # ── Assign skills ── all_skills_result = await db.execute( @@ -292,9 +276,6 @@ async def seed_default_agents(): all_skills = {s.folder_name: s for s in all_skills_result.scalars().all()} for agent, skill_folders in [(morty, MORTY_SKILLS), (meeseeks, MEESEEKS_SKILLS)]: - agent_dir = Path(settings.AGENT_DATA_DIR) / str(agent.id) - skills_dir = agent_dir / "skills" - # Always include default skills folders_to_copy = set(skill_folders) for fname, skill in all_skills.items(): @@ -305,12 +286,13 @@ async def seed_default_agents(): skill = all_skills.get(fname) if not skill: continue - skill_folder = skills_dir / skill.folder_name - skill_folder.mkdir(parents=True, exist_ok=True) for sf in skill.files: - file_path = skill_folder / sf.path - file_path.parent.mkdir(parents=True, exist_ok=True) - file_path.write_text(sf.content, encoding="utf-8") + await store_agent_bytes( + agent.id, + f"skills/{skill.folder_name}/{sf.path}", + sf.content.encode("utf-8"), + content_type="text/plain; charset=utf-8", + ) # ── Assign all default tools ── default_tools_result = await db.execute( @@ -337,32 +319,33 @@ async def seed_default_agents(): )) # ── Write relationships.md for each ── - morty_dir = Path(settings.AGENT_DATA_DIR) / str(morty.id) - meeseeks_dir = Path(settings.AGENT_DATA_DIR) / str(meeseeks.id) - - (morty_dir / "relationships.md").write_text( + await store_agent_bytes( + morty.id, + "relationships.md", "# Relationships\n\n" "## Digital Employee Colleagues\n\n" - "- **Meeseeks** (collaborator): Expert task executor who breaks down complex tasks into structured plans and executes them systematically. Delegate multi-step tasks to him.\n", - encoding="utf-8", + "- **Meeseeks** (collaborator): Expert task executor who breaks down complex tasks into structured plans and executes them systematically. Delegate multi-step tasks to him.\n".encode("utf-8"), + content_type="text/markdown; charset=utf-8", ) - (meeseeks_dir / "relationships.md").write_text( + await store_agent_bytes( + meeseeks.id, + "relationships.md", "# Relationships\n\n" "## Digital Employee Colleagues\n\n" - "- **Morty** (collaborator): Research expert with strong learning ability. Ask him for information retrieval, web research, data analysis, and knowledge synthesis.\n", - encoding="utf-8", + "- **Morty** (collaborator): Research expert with strong learning ability. Ask him for information retrieval, web research, data analysis, and knowledge synthesis.\n".encode("utf-8"), + content_type="text/markdown; charset=utf-8", ) await db.commit() logger.info(f"[AgentSeeder] Created default agents: Morty ({morty.id}), Meeseeks ({meeseeks.id})") # Write seed marker AFTER a successful commit so a failed seed can be retried - seed_marker.parent.mkdir(parents=True, exist_ok=True) - seed_marker.write_text( + await get_storage_backend().write_text( + SEED_MARKER_KEY, f"seeded\nmorty={morty.id}\nmeeseeks={meeseeks.id}\n", encoding="utf-8", ) - logger.info(f"[AgentSeeder] Wrote seed marker to {seed_marker}") + logger.info(f"[AgentSeeder] Wrote seed marker to {SEED_MARKER_KEY}") async def seed_okr_agent(): @@ -379,14 +362,11 @@ async def seed_okr_agent(): - Generates daily/weekly reports and posts them to the Plaza - Helps team members set up and maintain their focus.md files """ - seed_marker = Path(settings.AGENT_DATA_DIR) / ".seeded" - # Check if OKR Agent has already been seeded - if seed_marker.exists(): - marker_content = seed_marker.read_text(encoding="utf-8") - if "okr_agent=" in marker_content: - logger.info("[AgentSeeder] OKR Agent already seeded, skipping") - return + marker_content = await _read_seed_marker() + if "okr_agent=" in marker_content: + logger.info("[AgentSeeder] OKR Agent already seeded, skipping") + return async with async_session() as db: # Abort if a non-stopped OKR Agent already exists in the DB. @@ -404,7 +384,7 @@ async def seed_okr_agent(): if existing.scalar_one_or_none(): logger.info("[AgentSeeder] OKR Agent already exists in DB, skipping") # Update marker so we don't check again next startup - _append_seed_marker(seed_marker, "okr_agent=existing") + await _append_seed_marker("okr_agent=existing") return # Get platform admin as creator @@ -445,7 +425,7 @@ async def seed_okr_agent(): except IntegrityError: await db.rollback() logger.info("[AgentSeeder] OKR Agent was created concurrently (or exists with same name), skipping") - _append_seed_marker(seed_marker, "okr_agent=existing") + await _append_seed_marker("okr_agent=existing") return # ── Link OKR Agent ID to OKRSettings ── @@ -474,57 +454,34 @@ async def seed_okr_agent(): db.add(AgentPermission(agent_id=okr_agent.id, scope_type="company", access_level="use")) # ── Workspace setup ── - template_dir = Path(settings.AGENT_TEMPLATE_DIR) - agent_dir = Path(settings.AGENT_DATA_DIR) / str(okr_agent.id) - - if template_dir.exists(): - shutil.copytree(str(template_dir), str(agent_dir)) - else: - agent_dir.mkdir(parents=True, exist_ok=True) - (agent_dir / "skills").mkdir(exist_ok=True) - (agent_dir / "workspace").mkdir(exist_ok=True) - (agent_dir / "workspace" / "reports").mkdir(exist_ok=True) - (agent_dir / "memory").mkdir(exist_ok=True) - - # Write OKR Agent soul - (agent_dir / "soul.md").write_text(OKR_AGENT_SOUL.strip() + "\n", encoding="utf-8") - - # Ensure memory.md exists - mem_path = agent_dir / "memory" / "memory.md" - if not mem_path.exists(): - mem_path.write_text( + await agent_manager.initialize_agent_files(db, okr_agent) + await store_agent_bytes( + okr_agent.id, + "soul.md", + (OKR_AGENT_SOUL.strip() + "\n").encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) + await store_agent_bytes( + okr_agent.id, + "memory/memory.md", + ( "# Memory\n\n" "## OKR System State\n" "- Last report generated: (none)\n" "- Last progress collection: (none)\n" - "- Team members tracked: (pending)\n", - encoding="utf-8", - ) - - # OKR Agent does NOT use HEARTBEAT.md — heartbeat is disabled for this agent. - # All scheduled activity is driven by cron triggers (daily/weekly/biweekly/monthly reports). - - # Create workspace/reports directory - reports_dir = agent_dir / "workspace" / "reports" - reports_dir.mkdir(parents=True, exist_ok=True) - - # Write relationships.md — empty initially, will be populated as team onboards - (agent_dir / "relationships.md").write_text( + "- Team members tracked: (pending)\n" + ).encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) + await store_agent_bytes( + okr_agent.id, + "relationships.md", "# Relationships\n\n" "## Team Members (OKR tracking)\n\n" - "_Team members will be added here as they are onboarded into the OKR system._\n", - encoding="utf-8", + "_Team members will be added here as they are onboarded into the OKR system._\n".encode("utf-8"), + content_type="text/markdown; charset=utf-8", ) - # Stamp state.json if template provides one - state_path = agent_dir / "state.json" - if state_path.exists(): - import json as _json - state = _json.loads(state_path.read_text()) - state["agent_id"] = str(okr_agent.id) - state["name"] = okr_agent.name - state_path.write_text(_json.dumps(state, ensure_ascii=False, indent=2), encoding="utf-8") - # ── Assign default tools + OKR-specific tools ── # Default tools: all tools where is_default=True default_tools_result = await db.execute( @@ -580,19 +537,10 @@ async def seed_okr_agent(): await db.commit() # Update seed marker - _append_seed_marker(seed_marker, f"okr_agent={okr_agent.id}") + await _append_seed_marker(f"okr_agent={okr_agent.id}") logger.info(f"[AgentSeeder] OKR Agent seeded, id={okr_agent.id}") -def _append_seed_marker(marker_path: Path, line: str): - """Append a key=value line to the .seeded marker file (idempotent).""" - marker_path.parent.mkdir(parents=True, exist_ok=True) - existing = marker_path.read_text(encoding="utf-8") if marker_path.exists() else "" - if line not in existing: - with marker_path.open("a", encoding="utf-8") as f: - f.write(f"{line}\n") - - async def _seed_okr_triggers(db, agent_id: uuid.UUID) -> None: """Create system cron triggers for the OKR Agent. @@ -958,34 +906,32 @@ async def seed_okr_agent_for_tenant(tenant_id: uuid.UUID, creator_id: uuid.UUID) await db.flush() # ── Workspace setup ── - template_dir = Path(settings.AGENT_TEMPLATE_DIR) - agent_dir = Path(settings.AGENT_DATA_DIR) / str(okr_agent.id) - - if template_dir.exists(): - shutil.copytree(str(template_dir), str(agent_dir)) - else: - agent_dir.mkdir(parents=True, exist_ok=True) - for sub in ("skills", "workspace", "workspace/reports", "memory"): - (agent_dir / sub).mkdir(parents=True, exist_ok=True) - - (agent_dir / "soul.md").write_text(OKR_AGENT_SOUL.strip() + "\n", encoding="utf-8") - - mem_path = agent_dir / "memory" / "memory.md" - if not mem_path.exists(): - mem_path.write_text( + await agent_manager.initialize_agent_files(db, okr_agent) + await store_agent_bytes( + okr_agent.id, + "soul.md", + (OKR_AGENT_SOUL.strip() + "\n").encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) + await store_agent_bytes( + okr_agent.id, + "memory/memory.md", + ( "# Memory\n\n" "## OKR System State\n" "- Last report generated: (none)\n" "- Last progress collection: (none)\n" - "- Team members tracked: (pending)\n", - encoding="utf-8", - ) - - (agent_dir / "relationships.md").write_text( + "- Team members tracked: (pending)\n" + ).encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) + await store_agent_bytes( + okr_agent.id, + "relationships.md", "# Relationships\n\n" "## Team Members (OKR tracking)\n\n" - "_Team members will be added here as they are onboarded into the OKR system._\n", - encoding="utf-8", + "_Team members will be added here as they are onboarded into the OKR system._\n".encode("utf-8"), + content_type="text/markdown; charset=utf-8", ) # ── Assign default tools ── diff --git a/backend/app/services/agent_tools.py b/backend/app/services/agent_tools.py index 9b5fb11e6..161431598 100644 --- a/backend/app/services/agent_tools.py +++ b/backend/app/services/agent_tools.py @@ -46,7 +46,7 @@ _settings = get_settings() -WORKSPACE_ROOT = Path(_settings.AGENT_DATA_DIR) +WORKSPACE_ROOT = Path(_settings.STORAGE_LOCAL_ROOT or _settings.AGENT_DATA_DIR) # ─── Tool Config Cache ────────────────────────────────────────── # Cache tool configurations to avoid frequent DB queries diff --git a/backend/app/services/collaboration.py b/backend/app/services/collaboration.py index 394281d8a..9cc46db93 100644 --- a/backend/app/services/collaboration.py +++ b/backend/app/services/collaboration.py @@ -10,6 +10,7 @@ from app.models.agent import Agent from app.models.audit import AuditLog +from app.services.storage import store_agent_bytes class CollaborationService: @@ -111,21 +112,16 @@ async def send_message_between_agents( from_result = await db.execute(select(Agent).where(Agent.id == from_agent_id)) from_agent = from_result.scalar_one_or_none() - # Write message to target agent's workspace - from pathlib import Path - from app.config import get_settings - settings = get_settings() - - inbox_dir = Path(settings.AGENT_DATA_DIR) / str(to_agent_id) / "workspace" / "inbox" - inbox_dir.mkdir(parents=True, exist_ok=True) - timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") - msg_file = inbox_dir / f"{timestamp}_{str(from_agent_id)[:8]}.md" - msg_file.write_text( + rel_path = f"workspace/inbox/{timestamp}_{str(from_agent_id)[:8]}.md" + await store_agent_bytes( + to_agent_id, + rel_path, f"# 来自 {from_agent.name if from_agent else 'Unknown'} 的消息\n" f"- 类型: {msg_type}\n" f"- 时间: {datetime.now(timezone.utc).isoformat()}\n\n" - f"{message}\n" + f"{message}\n".encode("utf-8"), + content_type="text/markdown; charset=utf-8", ) db.add(AuditLog( diff --git a/backend/app/services/dingtalk_stream.py b/backend/app/services/dingtalk_stream.py index 73cf80b20..91612998d 100644 --- a/backend/app/services/dingtalk_stream.py +++ b/backend/app/services/dingtalk_stream.py @@ -16,10 +16,10 @@ from loguru import logger from sqlalchemy import select -from app.config import get_settings from app.database import async_session from app.models.channel_config import ChannelConfig from app.services.dingtalk_token import dingtalk_token_manager +from app.services.storage import store_agent_upload # ─── DingTalk Media Helpers ───────────────────────────── @@ -77,14 +77,6 @@ async def download_dingtalk_media( return await _download_file(download_url) -def _resolve_upload_dir(agent_id: uuid.UUID) -> Path: - """Get the uploads directory for an agent, creating it if needed.""" - settings = get_settings() - upload_dir = Path(settings.AGENT_DATA_DIR) / str(agent_id) / "workspace" / "uploads" - upload_dir.mkdir(parents=True, exist_ok=True) - return upload_dir - - async def _process_media_message( msg_data: dict, app_key: str, @@ -121,18 +113,21 @@ async def _process_media_message( if not file_bytes: return "[User sent an image, but download failed]", None, None - upload_dir = _resolve_upload_dir(agent_id) filename = f"dingtalk_img_{uuid.uuid4().hex[:8]}.jpg" - save_path = upload_dir / filename - save_path.write_bytes(file_bytes) - logger.info(f"[DingTalk] Saved image to {save_path} ({len(file_bytes)} bytes)") + _, workspace_path, _ = await store_agent_upload( + agent_id, + filename, + file_bytes, + content_type="image/jpeg", + ) + logger.info(f"[DingTalk] Saved image to {workspace_path} ({len(file_bytes)} bytes)") b64_data = base64.b64encode(file_bytes).decode("ascii") image_marker = f"[image_data:data:image/jpeg;base64,{b64_data}]" return ( f"[User sent an image]\n{image_marker}", [f"data:image/jpeg;base64,{b64_data}"], - [str(save_path)], + [workspace_path], ) elif msgtype == "richText": @@ -148,17 +143,20 @@ async def _process_media_message( app_key, app_secret, item["downloadCode"] ) if file_bytes: - upload_dir = _resolve_upload_dir(agent_id) filename = f"dingtalk_richimg_{uuid.uuid4().hex[:8]}.jpg" - save_path = upload_dir / filename - save_path.write_bytes(file_bytes) - logger.info(f"[DingTalk] Saved rich text image to {save_path}") + _, workspace_path, _ = await store_agent_upload( + agent_id, + filename, + file_bytes, + content_type="image/jpeg", + ) + logger.info(f"[DingTalk] Saved rich text image to {workspace_path}") b64_data = base64.b64encode(file_bytes).decode("ascii") image_marker = f"[image_data:data:image/jpeg;base64,{b64_data}]" text_parts.append(image_marker) image_base64_list.append(f"data:image/jpeg;base64,{b64_data}") - saved_file_paths.append(str(save_path)) + saved_file_paths.append(workspace_path) combined_text = "\n".join(text_parts).strip() if not combined_text: @@ -181,16 +179,14 @@ async def _process_media_message( if download_code: file_bytes = await download_dingtalk_media(app_key, app_secret, download_code) if file_bytes: - upload_dir = _resolve_upload_dir(agent_id) duration = content.get("duration", "unknown") filename = f"dingtalk_audio_{uuid.uuid4().hex[:8]}.amr" - save_path = upload_dir / filename - save_path.write_bytes(file_bytes) - logger.info(f"[DingTalk] Saved audio to {save_path} ({len(file_bytes)} bytes)") + _, workspace_path, _ = await store_agent_upload(agent_id, filename, file_bytes) + logger.info(f"[DingTalk] Saved audio to {workspace_path} ({len(file_bytes)} bytes)") return ( f"[User sent a voice message, duration {duration}ms, saved to {filename}]", None, - [str(save_path)], + [workspace_path], ) return "[User sent a voice message, but it could not be processed]", None, None @@ -200,16 +196,14 @@ async def _process_media_message( if download_code: file_bytes = await download_dingtalk_media(app_key, app_secret, download_code) if file_bytes: - upload_dir = _resolve_upload_dir(agent_id) duration = content.get("duration", "unknown") filename = f"dingtalk_video_{uuid.uuid4().hex[:8]}.mp4" - save_path = upload_dir / filename - save_path.write_bytes(file_bytes) - logger.info(f"[DingTalk] Saved video to {save_path} ({len(file_bytes)} bytes)") + _, workspace_path, _ = await store_agent_upload(agent_id, filename, file_bytes) + logger.info(f"[DingTalk] Saved video to {workspace_path} ({len(file_bytes)} bytes)") return ( f"[User sent a video, duration {duration}ms, saved to {filename}]", None, - [str(save_path)], + [workspace_path], ) return "[User sent a video, but it could not be downloaded]", None, None @@ -220,18 +214,16 @@ async def _process_media_message( if download_code: file_bytes = await download_dingtalk_media(app_key, app_secret, download_code) if file_bytes: - upload_dir = _resolve_upload_dir(agent_id) safe_name = f"dingtalk_{uuid.uuid4().hex[:8]}_{original_filename}" - save_path = upload_dir / safe_name - save_path.write_bytes(file_bytes) + _, workspace_path, _ = await store_agent_upload(agent_id, safe_name, file_bytes) logger.info( - f"[DingTalk] Saved file '{original_filename}' to {save_path} " + f"[DingTalk] Saved file '{original_filename}' to {workspace_path} " f"({len(file_bytes)} bytes)" ) return ( f"[file:{original_filename}]", None, - [str(save_path)], + [workspace_path], ) return f"[User sent file {original_filename}, but it could not be downloaded]", None, None diff --git a/backend/app/services/enterprise_sync.py b/backend/app/services/enterprise_sync.py index b529d8387..a67523215 100644 --- a/backend/app/services/enterprise_sync.py +++ b/backend/app/services/enterprise_sync.py @@ -6,18 +6,15 @@ import json import uuid -from pathlib import Path from loguru import logger from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.config import get_settings from app.core.events import publish_event from app.models.agent import Agent from app.models.audit import EnterpriseInfo - -settings = get_settings() +from app.services.storage import store_agent_bytes # Redis channel for enterprise info updates ENTERPRISE_INFO_CHANNEL = "enterprise_info_updated" @@ -67,9 +64,6 @@ async def sync_to_agent(self, db: AsyncSession, agent_id: uuid.UUID, agent_role: Filters by visible_roles — if empty, all roles can see it. """ - agent_dir = Path(settings.AGENT_DATA_DIR) / str(agent_id) / "enterprise_info" - agent_dir.mkdir(parents=True, exist_ok=True) - result = await db.execute(select(EnterpriseInfo)) all_info = result.scalars().all() @@ -78,12 +72,16 @@ async def sync_to_agent(self, db: AsyncSession, agent_id: uuid.UUID, agent_role: if info.visible_roles and agent_role and agent_role not in info.visible_roles: continue - file_path = agent_dir / f"{info.info_type}.json" - file_path.write_text(json.dumps({ - "type": info.info_type, - "version": info.version, - "content": info.content, - }, ensure_ascii=False, indent=2)) + await store_agent_bytes( + agent_id, + f"enterprise_info/{info.info_type}.json", + json.dumps({ + "type": info.info_type, + "version": info.version, + "content": info.content, + }, ensure_ascii=False, indent=2).encode("utf-8"), + content_type="application/json", + ) logger.info(f"Synced enterprise info to agent {agent_id}") diff --git a/backend/app/services/heartbeat.py b/backend/app/services/heartbeat.py index aa2969c4e..f0e1a8a1a 100644 --- a/backend/app/services/heartbeat.py +++ b/backend/app/services/heartbeat.py @@ -14,6 +14,7 @@ from loguru import logger from sqlalchemy import select, update, exists, and_ +from app.services.storage import agent_storage_key, get_storage_backend # Default heartbeat instruction used when HEARTBEAT.md doesn't exist DEFAULT_HEARTBEAT_INSTRUCTION = """[Heartbeat Check] @@ -192,15 +193,12 @@ async def _execute_heartbeat(agent_id: uuid.UUID): model_request_timeout = getattr(model, 'request_timeout', None) # Read HEARTBEAT.md if it exists, otherwise use default - from pathlib import Path - from app.config import get_settings - settings = get_settings() - - ws_root = Path(settings.AGENT_DATA_DIR) / str(agent_id) - hb_file = ws_root / "HEARTBEAT.md" - if hb_file.exists(): + storage = get_storage_backend() + hb_key = agent_storage_key(agent_id, "HEARTBEAT.md") + if await storage.exists(hb_key): try: - custom = hb_file.read_text(encoding="utf-8", errors="replace").strip() + custom = await storage.read_text(hb_key, encoding="utf-8", errors="replace") + custom = custom.strip() if custom: # Prepend privacy rules to custom heartbeat heartbeat_instruction = custom + """ diff --git a/backend/app/services/llm/caller.py b/backend/app/services/llm/caller.py index e57793137..fb57f0f23 100644 --- a/backend/app/services/llm/caller.py +++ b/backend/app/services/llm/caller.py @@ -14,7 +14,6 @@ import json import uuid -from pathlib import Path from typing import TYPE_CHECKING from loguru import logger @@ -246,9 +245,8 @@ async def _process_tool_call( if supports_vision and agent_id: try: from app.services.vision_inject import try_inject_screenshot_vision - from app.config import get_settings - settings = get_settings() - ws_path = Path(settings.AGENT_DATA_DIR) / str(agent_id) + from app.services.storage import ensure_local_path + ws_path = await ensure_local_path(str(agent_id)) vision_content = try_inject_screenshot_vision(tool_name, str(result), ws_path) if vision_content: tool_content = vision_content diff --git a/backend/app/services/okr_scheduler.py b/backend/app/services/okr_scheduler.py index e546613e1..4b3bb78cf 100644 --- a/backend/app/services/okr_scheduler.py +++ b/backend/app/services/okr_scheduler.py @@ -16,14 +16,12 @@ import re import uuid from datetime import date, datetime, timedelta -from pathlib import Path from typing import Optional from loguru import logger from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.config import get_settings from app.database import async_session from app.models.agent import Agent from app.models.okr import ( @@ -33,9 +31,7 @@ OKRSettings, WorkReport, ) - -_settings = get_settings() -WORKSPACE_ROOT = Path(_settings.AGENT_DATA_DIR) +from app.services.storage import agent_storage_key, get_storage_backend, store_agent_bytes # ─── Focus File Parsing ─────────────────────────────────────────────────────── @@ -138,17 +134,16 @@ async def collect_all_focus_updates( skipped_count = 0 error_count = 0 lines: list[str] = [] + storage = get_storage_backend() for agent in agents: - agent_dir = WORKSPACE_ROOT / str(agent.id) - focus_path = agent_dir / "focus.md" - - if not focus_path.exists(): + focus_key = agent_storage_key(agent.id, "focus.md") + if not await storage.exists(focus_key): skipped_count += 1 continue try: - content = focus_path.read_text(encoding="utf-8") + content = await storage.read_text(focus_key, encoding="utf-8", errors="replace") updates = _parse_focus_md(content) if not updates: @@ -159,7 +154,7 @@ async def collect_all_focus_updates( try: kr_uuid = uuid.UUID(kr_id_str) except ValueError: - logger.warning(f"[OKRScheduler] Invalid KR UUID '{kr_id_str}' in {focus_path}") + logger.warning(f"[OKRScheduler] Invalid KR UUID '{kr_id_str}' in {focus_key}") continue # Fetch the KR and verify it belongs to this tenant @@ -403,12 +398,15 @@ async def _store_report( await db.commit() -def _safe_write_report(agent_dir: Path, filename: str, content: str) -> None: +async def _safe_write_report(okr_agent_id: uuid.UUID, filename: str, content: str) -> None: """Write report to OKR Agent's workspace/reports/ directory.""" try: - reports_dir = agent_dir / "workspace" / "reports" - reports_dir.mkdir(parents=True, exist_ok=True) - (reports_dir / filename).write_text(content, encoding="utf-8") + await store_agent_bytes( + okr_agent_id, + f"workspace/reports/{filename}", + content.encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) except Exception as exc: logger.warning(f"[OKRScheduler] Could not write report file {filename}: {exc}") @@ -445,8 +443,7 @@ async def generate_daily_report( await _store_report(tenant_id, okr_agent_id, "daily", today, content, db) # Write file to workspace - agent_dir = WORKSPACE_ROOT / str(okr_agent_id) - _safe_write_report(agent_dir, f"daily_{today.strftime('%Y%m%d')}.md", content) + await _safe_write_report(okr_agent_id, f"daily_{today.strftime('%Y%m%d')}.md", content) logger.info(f"[OKRScheduler] Daily report generated for tenant {tenant_id}") return content @@ -485,9 +482,8 @@ async def generate_weekly_report( monday = today - timedelta(days=today.weekday()) await _store_report(tenant_id, okr_agent_id, "weekly", monday, content, db) - agent_dir = WORKSPACE_ROOT / str(okr_agent_id) week_label = monday.strftime("%Y-W%V") - _safe_write_report(agent_dir, f"weekly_{week_label}.md", content) + await _safe_write_report(okr_agent_id, f"weekly_{week_label}.md", content) logger.info(f"[OKRScheduler] Weekly report generated for tenant {tenant_id}") return content @@ -565,9 +561,8 @@ async def generate_monthly_report( await _store_report(tenant_id, okr_agent_id, "monthly", month_start, content, db) # Write file to OKR Agent workspace - agent_dir = WORKSPACE_ROOT / str(okr_agent_id) month_label = month_start.strftime("%Y-%m") - _safe_write_report(agent_dir, f"monthly_{month_label}.md", content) + await _safe_write_report(okr_agent_id, f"monthly_{month_label}.md", content) logger.info(f"[OKRScheduler] Monthly report generated for tenant {tenant_id}") return content diff --git a/backend/app/services/realtime.py b/backend/app/services/realtime.py new file mode 100644 index 000000000..39bcf2d85 --- /dev/null +++ b/backend/app/services/realtime.py @@ -0,0 +1,19 @@ +"""Compatibility facade for realtime services. + +New code should prefer the `app.services.realtime_runtime` package. +This module remains as the stable import path for existing callers. +""" + +from app.services.realtime_runtime import ( + PRESENCE_TTL_SECONDS, + PUBSUB_PREFIX, + RealtimeRouter, + realtime_router, +) + +__all__ = [ + "PRESENCE_TTL_SECONDS", + "PUBSUB_PREFIX", + "RealtimeRouter", + "realtime_router", +] diff --git a/backend/app/services/realtime_runtime/__init__.py b/backend/app/services/realtime_runtime/__init__.py new file mode 100644 index 000000000..feb673ad2 --- /dev/null +++ b/backend/app/services/realtime_runtime/__init__.py @@ -0,0 +1,15 @@ +"""Realtime routing runtime package.""" + +from app.services.realtime_runtime.router import ( + PRESENCE_TTL_SECONDS, + PUBSUB_PREFIX, + RealtimeRouter, + realtime_router, +) + +__all__ = [ + "PRESENCE_TTL_SECONDS", + "PUBSUB_PREFIX", + "RealtimeRouter", + "realtime_router", +] diff --git a/backend/app/services/realtime_runtime/router.py b/backend/app/services/realtime_runtime/router.py new file mode 100644 index 000000000..8cb8f0411 --- /dev/null +++ b/backend/app/services/realtime_runtime/router.py @@ -0,0 +1,200 @@ +"""Redis-backed websocket presence and cross-instance message routing.""" + +from __future__ import annotations + +import asyncio +import json +import uuid + +from fastapi import WebSocket +from loguru import logger + +from app.config import get_settings +from app.core.events import get_redis + +settings = get_settings() + +PRESENCE_TTL_SECONDS = 180 +PUBSUB_PREFIX = "realtime:ws" + + +class RealtimeRouter: + def __init__(self) -> None: + self.instance_id = settings.INSTANCE_ID + self._subscriber_task: asyncio.Task | None = None + self._started = False + + def _connection_key(self, connection_id: str) -> str: + return f"{PUBSUB_PREFIX}:conn:{connection_id}" + + def _agent_index_key(self, agent_id: str) -> str: + return f"{PUBSUB_PREFIX}:agent:{agent_id}" + + def _instance_channel(self) -> str: + return f"{PUBSUB_PREFIX}:instance:{self.instance_id}" + + async def register_connection( + self, + *, + agent_id: str, + websocket: WebSocket, + session_id: str | None, + user_id: str | None, + ) -> str: + connection_id = uuid.uuid4().hex + redis = await get_redis() + payload = { + "agent_id": agent_id, + "session_id": session_id or "", + "user_id": user_id or "", + "instance_id": self.instance_id, + } + async with redis.pipeline(transaction=True) as pipe: + pipe.sadd(self._agent_index_key(agent_id), connection_id) + pipe.hset(self._connection_key(connection_id), mapping=payload) + pipe.expire(self._connection_key(connection_id), PRESENCE_TTL_SECONDS) + pipe.expire(self._agent_index_key(agent_id), PRESENCE_TTL_SECONDS) + await pipe.execute() + setattr(websocket.state, "realtime_connection_id", connection_id) + return connection_id + + async def unregister_connection(self, *, agent_id: str, websocket: WebSocket) -> None: + connection_id = getattr(websocket.state, "realtime_connection_id", None) + if not connection_id: + return + redis = await get_redis() + async with redis.pipeline(transaction=True) as pipe: + pipe.srem(self._agent_index_key(agent_id), connection_id) + pipe.delete(self._connection_key(connection_id)) + await pipe.execute() + + async def is_user_viewing_session(self, *, agent_id: str, session_id: str, user_id: str) -> bool: + for record in await self._list_presence(agent_id): + if record.get("session_id") == session_id and record.get("user_id") == user_id: + return True + return False + + async def get_active_session_ids(self, agent_id: str) -> list[str]: + seen: set[str] = set() + for record in await self._list_presence(agent_id): + session_id = (record.get("session_id") or "").strip() + if session_id: + seen.add(session_id) + return list(seen) + + async def route_message( + self, + *, + agent_id: str, + message: dict, + local_connections: list[tuple[WebSocket, str | None, str | None]], + session_id: str | None = None, + user_id: str | None = None, + ) -> None: + local_sent = 0 + for ws, local_session_id, local_user_id in list(local_connections): + if session_id is not None and local_session_id != session_id: + continue + if user_id is not None and local_user_id != user_id: + continue + try: + await ws.send_json(message) + local_sent += 1 + except Exception: + pass + + remote_targets: dict[str, int] = {} + for record in await self._list_presence(agent_id): + if record.get("instance_id") == self.instance_id: + continue + if session_id is not None and record.get("session_id") != session_id: + continue + if user_id is not None and record.get("user_id") != user_id: + continue + target_instance = record.get("instance_id") + if target_instance: + remote_targets[target_instance] = remote_targets.get(target_instance, 0) + 1 + + if not remote_targets: + return + + redis = await get_redis() + envelope = json.dumps( + { + "message": message, + "agent_id": agent_id, + "session_id": session_id, + "user_id": user_id, + "origin_instance_id": self.instance_id, + } + ) + publish_tasks = [ + redis.publish(f"{PUBSUB_PREFIX}:instance:{instance_id}", envelope) + for instance_id in remote_targets + ] + await asyncio.gather(*publish_tasks, return_exceptions=True) + logger.debug( + f"[Realtime] Routed agent={agent_id} local={local_sent} remote_instances={list(remote_targets.keys())}" + ) + + async def start(self, deliver_local) -> None: + if self._started: + return + self._started = True + self._subscriber_task = asyncio.create_task(self._subscriber_loop(deliver_local), name="realtime-subscriber") + + async def stop(self) -> None: + if self._subscriber_task: + self._subscriber_task.cancel() + try: + await self._subscriber_task + except asyncio.CancelledError: + pass + self._subscriber_task = None + self._started = False + + async def _subscriber_loop(self, deliver_local) -> None: + redis = await get_redis() + pubsub = redis.pubsub() + await pubsub.subscribe(self._instance_channel()) + try: + while True: + message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0) + if not message: + await asyncio.sleep(0.05) + continue + try: + data = json.loads(message["data"]) + await deliver_local( + agent_id=data["agent_id"], + payload=data["message"], + session_id=data.get("session_id"), + user_id=data.get("user_id"), + ) + except Exception as exc: + logger.warning(f"[Realtime] Failed to deliver pubsub message: {exc}") + except asyncio.CancelledError: + raise + finally: + await pubsub.unsubscribe(self._instance_channel()) + await pubsub.aclose() + + async def _list_presence(self, agent_id: str) -> list[dict[str, str]]: + redis = await get_redis() + connection_ids = await redis.smembers(self._agent_index_key(agent_id)) + if not connection_ids: + return [] + records: list[dict[str, str]] = [] + stale_ids: list[str] = [] + for connection_id in connection_ids: + data = await redis.hgetall(self._connection_key(connection_id)) + if not data: + stale_ids.append(connection_id) + continue + records.append(data) + if stale_ids: + await redis.srem(self._agent_index_key(agent_id), *stale_ids) + return records + + +realtime_router = RealtimeRouter() diff --git a/backend/app/services/storage.py b/backend/app/services/storage.py new file mode 100644 index 000000000..6fc868b11 --- /dev/null +++ b/backend/app/services/storage.py @@ -0,0 +1,45 @@ +"""Compatibility facade for storage services. + +New code should prefer the `app.services.storage_runtime` package. +This module remains as the stable import path for existing callers. +""" + +from app.services.storage_runtime import ( + LocalStorageBackend, + S3StorageBackend, + StorageBackend, + StorageEntry, + agent_storage_key, + agent_storage_prefix, + agent_upload_key, + agent_workspace_key, + ensure_local_path, + get_storage_backend, + guess_content_type, + normalize_storage_key, + sanitize_filename, + store_agent_bytes, + store_agent_upload, + tenant_storage_key, + tenant_storage_prefix, +) + +__all__ = [ + "LocalStorageBackend", + "S3StorageBackend", + "StorageBackend", + "StorageEntry", + "agent_storage_key", + "agent_storage_prefix", + "agent_upload_key", + "agent_workspace_key", + "ensure_local_path", + "get_storage_backend", + "guess_content_type", + "normalize_storage_key", + "sanitize_filename", + "store_agent_bytes", + "store_agent_upload", + "tenant_storage_key", + "tenant_storage_prefix", +] diff --git a/backend/app/services/storage_runtime/__init__.py b/backend/app/services/storage_runtime/__init__.py new file mode 100644 index 000000000..a156799d0 --- /dev/null +++ b/backend/app/services/storage_runtime/__init__.py @@ -0,0 +1,42 @@ +"""Storage runtime package.""" + +from app.services.storage_runtime.base import StorageBackend, StorageEntry +from app.services.storage_runtime.agent_files import ( + agent_storage_key, + agent_upload_key, + agent_workspace_key, + sanitize_filename, + store_agent_bytes, + store_agent_upload, + tenant_storage_key, +) +from app.services.storage_runtime.facade import ( + agent_storage_prefix, + ensure_local_path, + get_storage_backend, + guess_content_type, + normalize_storage_key, + tenant_storage_prefix, +) +from app.services.storage_runtime.local import LocalStorageBackend +from app.services.storage_runtime.s3 import S3StorageBackend + +__all__ = [ + "StorageBackend", + "StorageEntry", + "LocalStorageBackend", + "S3StorageBackend", + "agent_storage_key", + "agent_storage_prefix", + "agent_upload_key", + "agent_workspace_key", + "ensure_local_path", + "get_storage_backend", + "guess_content_type", + "normalize_storage_key", + "sanitize_filename", + "store_agent_bytes", + "store_agent_upload", + "tenant_storage_key", + "tenant_storage_prefix", +] diff --git a/backend/app/services/storage_runtime/agent_files.py b/backend/app/services/storage_runtime/agent_files.py new file mode 100644 index 000000000..d9c720816 --- /dev/null +++ b/backend/app/services/storage_runtime/agent_files.py @@ -0,0 +1,84 @@ +"""Agent-scoped storage helpers. + +This module centralizes how agent and tenant workspace keys are built so +channel handlers and background services do not manually assemble +`workspace/uploads/...` paths all over the codebase. +""" + +from __future__ import annotations + +import os +import uuid +from pathlib import Path + +from app.services.storage_runtime.facade import ( + ensure_local_path, + get_storage_backend, + guess_content_type, + normalize_storage_key, +) + + +def sanitize_filename(filename: str, fallback: str = "file.bin") -> str: + name = (filename or "").replace("\\", "_").replace("/", "_").strip() + return name or fallback + + +def agent_storage_key(agent_id: uuid.UUID | str, rel_path: str = "") -> str: + prefix = str(agent_id) + rel = normalize_storage_key(rel_path) + return f"{prefix}/{rel}" if rel else prefix + + +def agent_workspace_key(agent_id: uuid.UUID | str, rel_path: str = "") -> str: + rel = normalize_storage_key(rel_path) + workspace_rel = f"workspace/{rel}" if rel else "workspace" + return agent_storage_key(agent_id, workspace_rel) + + +def agent_upload_key(agent_id: uuid.UUID | str, filename: str) -> str: + safe_name = sanitize_filename(filename) + return agent_workspace_key(agent_id, f"uploads/{safe_name}") + + +def tenant_storage_key(tenant_id: uuid.UUID | str, rel_path: str = "") -> str: + prefix = normalize_storage_key(f"enterprise_info_{tenant_id}") + rel = normalize_storage_key(rel_path) + return f"{prefix}/{rel}" if rel else prefix + + +async def store_agent_bytes( + agent_id: uuid.UUID | str, + rel_path: str, + data: bytes, + *, + content_type: str | None = None, +) -> str: + key = agent_storage_key(agent_id, rel_path) + storage = get_storage_backend() + await storage.write_bytes( + key, + data, + content_type=content_type or guess_content_type(Path(rel_path).name), + ) + return key + + +async def store_agent_upload( + agent_id: uuid.UUID | str, + filename: str, + data: bytes, + *, + content_type: str | None = None, +) -> tuple[str, str, Path]: + key = agent_upload_key(agent_id, filename) + storage = get_storage_backend() + safe_name = os.path.basename(key) + await storage.write_bytes( + key, + data, + content_type=content_type or guess_content_type(safe_name), + ) + local_path = await ensure_local_path(key) + workspace_path = f"workspace/uploads/{safe_name}" + return key, workspace_path, local_path diff --git a/backend/app/services/storage_runtime/base.py b/backend/app/services/storage_runtime/base.py new file mode 100644 index 000000000..2b11bd7d2 --- /dev/null +++ b/backend/app/services/storage_runtime/base.py @@ -0,0 +1,57 @@ +"""Base storage types and interfaces.""" + +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class StorageEntry: + name: str + key: str + is_dir: bool + size: int = 0 + modified_at: str = "" + + +class StorageBackend: + async def exists(self, key: str) -> bool: + raise NotImplementedError + + async def is_file(self, key: str) -> bool: + raise NotImplementedError + + async def is_dir(self, key: str) -> bool: + raise NotImplementedError + + async def list_dir(self, key: str) -> list[StorageEntry]: + raise NotImplementedError + + async def read_bytes(self, key: str) -> bytes: + raise NotImplementedError + + async def read_text(self, key: str, encoding: str = "utf-8", errors: str = "replace") -> str: + raw = await self.read_bytes(key) + return raw.decode(encoding, errors=errors) + + async def write_bytes(self, key: str, data: bytes, content_type: str | None = None) -> None: + raise NotImplementedError + + async def write_text(self, key: str, content: str, encoding: str = "utf-8") -> None: + await self.write_bytes(key, content.encode(encoding), content_type="text/plain; charset=utf-8") + + async def delete(self, key: str) -> None: + raise NotImplementedError + + async def delete_tree(self, key: str) -> None: + raise NotImplementedError + + async def stat(self, key: str) -> StorageEntry: + raise NotImplementedError + + async def local_path_for(self, key: str) -> Path | None: + return None + + async def presign_download_url(self, key: str, filename: str | None = None, inline: bool = False) -> str | None: + return None diff --git a/backend/app/services/storage_runtime/facade.py b/backend/app/services/storage_runtime/facade.py new file mode 100644 index 000000000..9c90745ee --- /dev/null +++ b/backend/app/services/storage_runtime/facade.py @@ -0,0 +1,52 @@ +"""Facade for selecting the configured storage backend.""" + +from __future__ import annotations + +import mimetypes +from pathlib import Path + +from app.config import get_settings +from app.services.storage_runtime.base import StorageBackend +from app.services.storage_runtime.local import LocalStorageBackend +from app.services.storage_runtime.s3 import S3StorageBackend +from app.services.storage_runtime.utils import ( + agent_storage_prefix, + normalize_storage_key, + tenant_storage_prefix, +) + +_storage_backend: StorageBackend | None = None + + +def get_storage_backend() -> StorageBackend: + global _storage_backend + if _storage_backend is not None: + return _storage_backend + + settings = get_settings() + backend = (settings.STORAGE_BACKEND or "local").strip().lower() + if backend == "s3": + _storage_backend = S3StorageBackend( + bucket=settings.S3_BUCKET, + prefix=settings.S3_PREFIX, + region=settings.S3_REGION, + endpoint_url=settings.S3_ENDPOINT_URL, + access_key_id=settings.S3_ACCESS_KEY_ID, + secret_access_key=settings.S3_SECRET_ACCESS_KEY, + presign_ttl_seconds=settings.S3_PRESIGN_TTL_SECONDS, + ) + else: + _storage_backend = LocalStorageBackend(settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR) + return _storage_backend + + +async def ensure_local_path(key: str) -> Path: + backend = get_storage_backend() + path = await backend.local_path_for(key) + if path is None: + raise RuntimeError("Storage backend cannot materialize a local path") + return path + + +def guess_content_type(filename: str) -> str: + return mimetypes.guess_type(filename)[0] or "application/octet-stream" diff --git a/backend/app/services/storage_runtime/local.py b/backend/app/services/storage_runtime/local.py new file mode 100644 index 000000000..3128c6108 --- /dev/null +++ b/backend/app/services/storage_runtime/local.py @@ -0,0 +1,101 @@ +"""Local filesystem storage backend.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path + +import aiofiles +from fastapi import HTTPException, status + +from app.services.storage_runtime.base import StorageBackend, StorageEntry +from app.services.storage_runtime.utils import normalize_storage_key + + +class LocalStorageBackend(StorageBackend): + def __init__(self, root: str): + self.root = Path(root) + + def _full_path(self, key: str) -> Path: + normalized = normalize_storage_key(key) + full = (self.root / normalized).resolve() + root_resolved = self.root.resolve() + if not str(full).startswith(str(root_resolved)): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Path traversal not allowed") + return full + + async def exists(self, key: str) -> bool: + return self._full_path(key).exists() + + async def is_file(self, key: str) -> bool: + return self._full_path(key).is_file() + + async def is_dir(self, key: str) -> bool: + return self._full_path(key).is_dir() + + async def list_dir(self, key: str) -> list[StorageEntry]: + base = self._full_path(key) + if not base.exists() or not base.is_dir(): + return [] + entries: list[StorageEntry] = [] + for entry in sorted(base.iterdir(), key=lambda item: (not item.is_dir(), item.name)): + if entry.name == ".gitkeep": + continue + stat = entry.stat() + rel = str(entry.resolve().relative_to(self.root.resolve())) + entries.append( + StorageEntry( + name=entry.name, + key=rel, + is_dir=entry.is_dir(), + size=stat.st_size if entry.is_file() else 0, + modified_at=str(stat.st_mtime), + ) + ) + return entries + + async def read_bytes(self, key: str) -> bytes: + path = self._full_path(key) + async with aiofiles.open(path, "rb") as f: + return await f.read() + + async def write_bytes(self, key: str, data: bytes, content_type: str | None = None) -> None: + path = self._full_path(key) + path.parent.mkdir(parents=True, exist_ok=True) + async with aiofiles.open(path, "wb") as f: + await f.write(data) + + async def delete(self, key: str) -> None: + path = self._full_path(key) + if not path.exists(): + return + if path.is_dir(): + await self.delete_tree(key) + else: + path.unlink() + + async def delete_tree(self, key: str) -> None: + path = self._full_path(key) + if not path.exists(): + return + await asyncio.to_thread(_local_delete_tree, path) + + async def stat(self, key: str) -> StorageEntry: + path = self._full_path(key) + stat = path.stat() + return StorageEntry( + name=path.name, + key=normalize_storage_key(key), + is_dir=path.is_dir(), + size=stat.st_size if path.is_file() else 0, + modified_at=str(stat.st_mtime), + ) + + async def local_path_for(self, key: str) -> Path | None: + return self._full_path(key) + + +def _local_delete_tree(path: Path) -> None: + import shutil + + shutil.rmtree(path) diff --git a/backend/app/services/storage_runtime/s3.py b/backend/app/services/storage_runtime/s3.py new file mode 100644 index 000000000..b0b190a74 --- /dev/null +++ b/backend/app/services/storage_runtime/s3.py @@ -0,0 +1,204 @@ +"""S3-compatible object storage backend.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path +from tempfile import NamedTemporaryFile +from typing import Any + +from app.services.storage_runtime.base import StorageBackend, StorageEntry +from app.services.storage_runtime.utils import normalize_storage_key + + +class S3StorageBackend(StorageBackend): + def __init__( + self, + *, + bucket: str, + prefix: str = "", + region: str = "", + endpoint_url: str = "", + access_key_id: str = "", + secret_access_key: str = "", + presign_ttl_seconds: int = 3600, + ): + self.bucket = bucket + self.prefix = normalize_storage_key(prefix) + self.region = region + self.endpoint_url = endpoint_url or None + self.access_key_id = access_key_id or None + self.secret_access_key = secret_access_key or None + self.presign_ttl_seconds = presign_ttl_seconds + self._client: Any | None = None + + def _object_key(self, key: str) -> str: + normalized = normalize_storage_key(key) + return f"{self.prefix}/{normalized}" if self.prefix else normalized + + def _client_or_raise(self): + if self._client is None: + try: + import boto3 + except ImportError as exc: + raise RuntimeError("boto3 is required for S3 storage backend") from exc + self._client = boto3.client( + "s3", + region_name=self.region or None, + endpoint_url=self.endpoint_url, + aws_access_key_id=self.access_key_id, + aws_secret_access_key=self.secret_access_key, + ) + return self._client + + async def exists(self, key: str) -> bool: + try: + await self.stat(key) + return True + except FileNotFoundError: + return False + + async def is_file(self, key: str) -> bool: + return await self.exists(key) + + async def is_dir(self, key: str) -> bool: + prefix = self._object_key(key).rstrip("/") + "/" + client = self._client_or_raise() + response = await asyncio.to_thread( + client.list_objects_v2, + Bucket=self.bucket, + Prefix=prefix, + Delimiter="/", + MaxKeys=1, + ) + return bool(response.get("Contents") or response.get("CommonPrefixes")) + + async def list_dir(self, key: str) -> list[StorageEntry]: + prefix = self._object_key(key).rstrip("/") + if prefix: + prefix += "/" + client = self._client_or_raise() + response = await asyncio.to_thread( + client.list_objects_v2, + Bucket=self.bucket, + Prefix=prefix, + Delimiter="/", + ) + entries: list[StorageEntry] = [] + for item in response.get("CommonPrefixes", []): + raw = item.get("Prefix", "").rstrip("/") + rel = _strip_prefix(raw, self.prefix) + name = rel.split("/")[-1] + entries.append(StorageEntry(name=name, key=rel, is_dir=True)) + for item in response.get("Contents", []): + raw = item.get("Key", "") + if not raw or raw == prefix: + continue + rel = _strip_prefix(raw, self.prefix) + name = rel.split("/")[-1] + entries.append( + StorageEntry( + name=name, + key=rel, + is_dir=False, + size=int(item.get("Size", 0)), + modified_at=str(item.get("LastModified") or ""), + ) + ) + return sorted(entries, key=lambda entry: (not entry.is_dir, entry.name)) + + async def read_bytes(self, key: str) -> bytes: + client = self._client_or_raise() + response = await asyncio.to_thread( + client.get_object, + Bucket=self.bucket, + Key=self._object_key(key), + ) + body = response["Body"] + return await asyncio.to_thread(body.read) + + async def write_bytes(self, key: str, data: bytes, content_type: str | None = None) -> None: + client = self._client_or_raise() + kwargs: dict[str, Any] = { + "Bucket": self.bucket, + "Key": self._object_key(key), + "Body": data, + } + if content_type: + kwargs["ContentType"] = content_type + await asyncio.to_thread(client.put_object, **kwargs) + + async def delete(self, key: str) -> None: + client = self._client_or_raise() + await asyncio.to_thread( + client.delete_object, + Bucket=self.bucket, + Key=self._object_key(key), + ) + + async def delete_tree(self, key: str) -> None: + client = self._client_or_raise() + prefix = self._object_key(key).rstrip("/") + "/" + response = await asyncio.to_thread( + client.list_objects_v2, + Bucket=self.bucket, + Prefix=prefix, + ) + contents = response.get("Contents", []) + if not contents: + return + objects = [{"Key": item["Key"]} for item in contents] + await asyncio.to_thread( + client.delete_objects, + Bucket=self.bucket, + Delete={"Objects": objects}, + ) + + async def stat(self, key: str) -> StorageEntry: + client = self._client_or_raise() + try: + response = await asyncio.to_thread( + client.head_object, + Bucket=self.bucket, + Key=self._object_key(key), + ) + except Exception as exc: + raise FileNotFoundError(key) from exc + return StorageEntry( + name=normalize_storage_key(key).split("/")[-1], + key=normalize_storage_key(key), + is_dir=False, + size=int(response.get("ContentLength", 0)), + modified_at=str(response.get("LastModified") or ""), + ) + + async def local_path_for(self, key: str) -> Path | None: + suffix = Path(normalize_storage_key(key)).suffix + tmp = NamedTemporaryFile(delete=False, suffix=suffix) + tmp.close() + path = Path(tmp.name) + await self.write_local_copy(key, path) + return path + + async def write_local_copy(self, key: str, path: Path) -> None: + data = await self.read_bytes(key) + await asyncio.to_thread(path.write_bytes, data) + + async def presign_download_url(self, key: str, filename: str | None = None, inline: bool = False) -> str | None: + client = self._client_or_raise() + params: dict[str, Any] = {"Bucket": self.bucket, "Key": self._object_key(key)} + if filename: + disposition = "inline" if inline else "attachment" + params["ResponseContentDisposition"] = f'{disposition}; filename="{filename}"' + return await asyncio.to_thread( + client.generate_presigned_url, + "get_object", + Params=params, + ExpiresIn=self.presign_ttl_seconds, + ) + + +def _strip_prefix(raw_key: str, prefix: str) -> str: + if prefix and raw_key.startswith(prefix + "/"): + return raw_key[len(prefix) + 1:] + return raw_key diff --git a/backend/app/services/storage_runtime/utils.py b/backend/app/services/storage_runtime/utils.py new file mode 100644 index 000000000..98634ba43 --- /dev/null +++ b/backend/app/services/storage_runtime/utils.py @@ -0,0 +1,24 @@ +"""Storage path helpers.""" + + +def normalize_storage_key(key: str) -> str: + """Normalize a storage key and reject traversal semantics.""" + clean = (key or "").replace("\\", "/").strip().lstrip("/") + parts: list[str] = [] + for part in clean.split("/"): + if part in ("", "."): + continue + if part == "..": + if parts: + parts.pop() + continue + parts.append(part) + return "/".join(parts) + + +def agent_storage_prefix(agent_id: str) -> str: + return normalize_storage_key(agent_id) + + +def tenant_storage_prefix(tenant_id: str) -> str: + return normalize_storage_key(f"enterprise_info_{tenant_id}") diff --git a/backend/app/services/supervision_reminder.py b/backend/app/services/supervision_reminder.py index 1c318cdaa..809285c7f 100644 --- a/backend/app/services/supervision_reminder.py +++ b/backend/app/services/supervision_reminder.py @@ -112,8 +112,6 @@ async def _get_agent_reply(target_agent, message: str, db) -> str | None: get_model_api_key, ) - ) - model_id = target_agent.primary_model_id or target_agent.fallback_model_id if not model_id: return None diff --git a/backend/app/services/trigger_daemon.py b/backend/app/services/trigger_daemon.py index e4bfd1905..de9531471 100644 --- a/backend/app/services/trigger_daemon.py +++ b/backend/app/services/trigger_daemon.py @@ -1,29 +1,33 @@ -"""Trigger Daemon — evaluates all agent triggers in a single background loop. +"""Trigger daemon orchestrator. -Replaces the separate heartbeat, scheduler, and supervision reminder services -with a unified trigger evaluation engine. Runs as an asyncio background task. - -Every 15 seconds: - 1. Load all enabled triggers from DB - 2. Evaluate each trigger (cron/once/interval/poll/on_message/webhook) - 3. Group fired triggers by agent_id (30s dedup window) - 4. Invoke each agent once with all its fired triggers as context +Trigger-specific evaluation and invocation behavior now lives under +`app.services.trigger_runtime`. This module owns the main loop, dedup window, +and distributed claim/invoke flow. """ import asyncio -import ipaddress -import json as _json import uuid from datetime import datetime, timezone, timedelta -from urllib.parse import urlparse - -from croniter import croniter from loguru import logger from sqlalchemy import select from app.database import async_session from app.models.trigger import AgentTrigger -from app.models.agent import Agent +from app.services.trigger_runtime.evaluator import ( + evaluate_trigger as evaluate_trigger_runtime, + handle_okr_collection_trigger as handle_okr_collection_trigger_runtime, + handle_okr_report_trigger as handle_okr_report_trigger_runtime, + mark_trigger_fired as mark_trigger_fired_runtime, + mark_trigger_skipped as mark_trigger_skipped_runtime, + should_skip_non_workday as should_skip_non_workday_runtime, +) +from app.services.trigger_runtime.invoker import invoke_agent_for_triggers as invoke_agent_for_triggers_runtime +from app.services.trigger_runtime import ( + claim_ready_trigger_invocations, + enqueue_due_trigger, + mark_trigger_executions_completed, + mark_trigger_executions_failed, +) TICK_INTERVAL = 15 # seconds DEDUP_WINDOW = 30 # seconds — same agent won't be invoked twice within this window @@ -45,921 +49,29 @@ def _cleanup_stale_invoke_cache(): async def _should_skip_non_workday(trigger: AgentTrigger, local_now: datetime) -> bool: - """Skip OKR daily report triggers on company non-workdays when configured.""" - if trigger.name != "daily_okr_collection": - return False - - from app.models.okr import OKRSettings - from app.models.tenant import Tenant - from app.services.business_calendar import is_non_workday - - async with async_session() as db: - result = await db.execute( - select(Agent.tenant_id) - .where(Agent.id == trigger.agent_id) - ) - tenant_id = result.scalar_one_or_none() - if not tenant_id: - return False - - settings_result = await db.execute( - select(OKRSettings.daily_report_skip_non_workdays) - .where(OKRSettings.tenant_id == tenant_id) - ) - skip_enabled = settings_result.scalar_one_or_none() - if skip_enabled is False: - return False - - tenant_result = await db.execute( - select(Tenant.country_region).where(Tenant.id == tenant_id) - ) - country_region = tenant_result.scalar_one_or_none() - - return is_non_workday(local_now.date(), country_region) + return await should_skip_non_workday_runtime(trigger, local_now) async def _mark_trigger_skipped(trigger_id: uuid.UUID, now: datetime) -> None: - """Advance a cron trigger without invoking the agent.""" - try: - async with async_session() as db: - result = await db.execute(select(AgentTrigger).where(AgentTrigger.id == trigger_id)) - trigger = result.scalar_one_or_none() - if trigger: - trigger.last_fired_at = now - await db.commit() - except Exception as e: - logger.warning(f"Failed to mark skipped trigger {trigger_id}: {e}") + await mark_trigger_skipped_runtime(trigger_id, now) async def _mark_trigger_fired(trigger_id: uuid.UUID, now: datetime) -> None: - """Persist fire metadata for a trigger that was already handled.""" - try: - async with async_session() as db: - result = await db.execute(select(AgentTrigger).where(AgentTrigger.id == trigger_id)) - trigger = result.scalar_one_or_none() - if trigger: - trigger.last_fired_at = now - trigger.fire_count += 1 - if trigger.type == "once": - trigger.is_enabled = False - if trigger.max_fires and trigger.fire_count >= trigger.max_fires: - trigger.is_enabled = False - await db.commit() - except Exception as e: - logger.warning(f"Failed to mark fired trigger {trigger_id}: {e}") + await mark_trigger_fired_runtime(trigger_id, now) async def _handle_okr_report_trigger(trigger: AgentTrigger, now: datetime) -> bool: - """Handle company-level OKR report generation without waking the agent.""" - if trigger.name not in {"daily_okr_report", "weekly_okr_report", "monthly_okr_report"}: - return False - - from zoneinfo import ZoneInfo - from app.models.okr import OKRSettings - from app.services.okr_reporting import ( - generate_company_daily_report, - generate_company_monthly_report, - generate_company_weekly_report, - ) - from app.services.timezone_utils import get_agent_timezone - - async with async_session() as db: - agent_result = await db.execute(select(Agent.tenant_id).where(Agent.id == trigger.agent_id)) - tenant_id = agent_result.scalar_one_or_none() - if not tenant_id: - return True - - settings_result = await db.execute(select(OKRSettings).where(OKRSettings.tenant_id == tenant_id)) - settings = settings_result.scalar_one_or_none() - if not settings or not settings.enabled: - return True - - tz_name = await get_agent_timezone(trigger.agent_id) - try: - tz = ZoneInfo(tz_name) - except Exception: - tz = ZoneInfo("UTC") - local_today = now.astimezone(tz).date() - - if trigger.name == "daily_okr_report": - await generate_company_daily_report(tenant_id, local_today - timedelta(days=1)) - elif trigger.name == "weekly_okr_report": - previous_week_anchor = local_today - timedelta(days=7) - week_start = previous_week_anchor - timedelta(days=previous_week_anchor.weekday()) - await generate_company_weekly_report(tenant_id, week_start) - elif trigger.name == "monthly_okr_report": - previous_month_end = local_today.replace(day=1) - timedelta(days=1) - await generate_company_monthly_report(tenant_id, previous_month_end) - - await _mark_trigger_fired(trigger.id, now) - logger.info(f"[Trigger] Auto-generated OKR report for trigger {trigger.name}") - return True + return await handle_okr_report_trigger_runtime(trigger, now) async def _handle_okr_collection_trigger(trigger: AgentTrigger, now: datetime) -> bool: - """Handle deterministic OKR daily collection without relying on a free-form LLM plan.""" - if trigger.name != "daily_okr_collection": - return False - - from app.models.okr import OKRSettings - from app.services.okr_daily_collection import trigger_daily_collection_for_tenant - - async with async_session() as db: - agent_result = await db.execute(select(Agent.tenant_id).where(Agent.id == trigger.agent_id)) - tenant_id = agent_result.scalar_one_or_none() - if not tenant_id: - return True - - settings_result = await db.execute(select(OKRSettings).where(OKRSettings.tenant_id == tenant_id)) - settings = settings_result.scalar_one_or_none() - if not settings or not settings.enabled or not settings.daily_report_enabled: - return True - - await trigger_daily_collection_for_tenant(tenant_id) - await _mark_trigger_fired(trigger.id, now) - logger.info(f"[Trigger] Deterministic OKR collection sent for trigger {trigger.name}") - return True - -# Webhook rate limiter: token -> list of timestamps -_webhook_hits: dict[str, list[float]] = {} -WEBHOOK_RATE_LIMIT = 5 # max hits per minute per token - - -# ── SSRF Protection ───────────────────────────────────────────────── - -def _is_private_url(url: str) -> bool: - """Block private/internal URLs to prevent SSRF attacks.""" - try: - parsed = urlparse(url) - hostname = parsed.hostname - if not hostname: - return True - - # Block obvious private hostnames - if hostname in ("localhost", "127.0.0.1", "::1", "0.0.0.0"): - return True - - # Try to resolve hostname and check IP - import socket - try: - infos = socket.getaddrinfo(hostname, None) - for info in infos: - ip = ipaddress.ip_address(info[4][0]) - if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: - return True - except (socket.gaierror, ValueError): - return True # Cannot resolve = block - - return False - except Exception: - return True # Block on any parsing error - - -# ── Trigger Evaluation ────────────────────────────────────────────── + return await handle_okr_collection_trigger_runtime(trigger, now) async def _evaluate_trigger(trigger: AgentTrigger, now: datetime) -> bool: - """Return True if this trigger should fire right now.""" - if not trigger.is_enabled: - return False - if trigger.expires_at and now >= trigger.expires_at: - # Auto-disable expired triggers - return False - if trigger.max_fires is not None and trigger.fire_count >= trigger.max_fires: - return False - - # Cooldown check - if trigger.last_fired_at: - cooldown = timedelta(seconds=trigger.cooldown_seconds) - if (now - trigger.last_fired_at) < cooldown: - return False - - cfg = trigger.config or {} - t = trigger.type - - if t == "cron": - expr = cfg.get("expr", "* * * * *") - base = trigger.last_fired_at or trigger.created_at - try: - # Resolve timezone: trigger config → agent → tenant → UTC - tz_name = cfg.get("timezone") - if not tz_name: - from app.services.timezone_utils import get_agent_timezone - tz_name = await get_agent_timezone(trigger.agent_id) - from zoneinfo import ZoneInfo - try: - tz = ZoneInfo(tz_name) - except (KeyError, Exception): - tz = ZoneInfo("UTC") - # Evaluate cron in agent's timezone - local_now = now.astimezone(tz) - local_base = base.astimezone(tz) if base.tzinfo else base.replace(tzinfo=tz) - cron = croniter(expr, local_base) - next_run = cron.get_next(datetime) - if local_now >= next_run: - if await _should_skip_non_workday(trigger, local_now): - await _mark_trigger_skipped(trigger.id, now) - logger.info(f"[Trigger] Skipped {trigger.name} on non-workday {local_now.date()}") - return False - return True - return False - except Exception as e: - logger.warning(f"Invalid cron expr '{expr}' for trigger {trigger.name}: {e}") - return False - - elif t == "once": - at_str = cfg.get("at") - if not at_str: - return False - try: - at = datetime.fromisoformat(at_str) - if at.tzinfo is None: - at = at.replace(tzinfo=timezone.utc) - return now >= at and trigger.fire_count == 0 - except Exception: - return False - - elif t == "interval": - minutes = cfg.get("minutes", 30) - base = trigger.last_fired_at or trigger.created_at - return (now - base) >= timedelta(minutes=minutes) - - elif t == "poll": - interval_min = max(cfg.get("interval_min", 5), MIN_POLL_INTERVAL_MINUTES) - base = trigger.last_fired_at or trigger.created_at - if (now - base) < timedelta(minutes=interval_min): - return False - # Actual HTTP poll + change detection - return await _poll_check(trigger) - - elif t == "on_message": - return await _check_new_agent_messages(trigger) - - elif t == "webhook": - # Check if a webhook payload is pending - if cfg.get("_webhook_pending"): - return True - return False - - return False - - -async def _poll_check(trigger: AgentTrigger) -> bool: - """HTTP poll: fetch URL, extract value via json_path, detect change. - - Persists _last_value into the trigger's config JSONB so it survives - across process restarts. - """ - import httpx - cfg = trigger.config or {} - url = cfg.get("url") - if not url: - return False - - # SSRF protection: block private/internal URLs - if _is_private_url(url): - logger.warning(f"Poll blocked for trigger {trigger.name}: private/internal URL '{url}'") - return False - - try: - async with httpx.AsyncClient(timeout=10) as client: - resp = await client.request(cfg.get("method", "GET"), url, headers=cfg.get("headers", {})) - resp.raise_for_status() - - data = resp.json() - json_path = cfg.get("json_path", "$") - current_value = _extract_json_path(data, json_path) - current_str = str(current_value) - - fire_on = cfg.get("fire_on", "change") - should_fire = False - - if fire_on == "match": - should_fire = current_str == str(cfg.get("match_value", "")) - else: # "change" - last_value = cfg.get("_last_value") - # First poll — don't fire, just record baseline - if last_value is None: - should_fire = False - else: - should_fire = current_str != last_value - - # Persist _last_value to DB so it survives restarts - cfg["_last_value"] = current_str - try: - from sqlalchemy import update - async with async_session() as db: - await db.execute( - update(AgentTrigger) - .where(AgentTrigger.id == trigger.id) - .values(config=cfg) - ) - await db.commit() - except Exception as e: - logger.warning(f"Failed to persist poll _last_value for {trigger.name}: {e}") - - return should_fire - - except Exception as e: - logger.warning(f"Poll failed for trigger {trigger.name}: {e}") - return False - - -def _extract_json_path(data, path: str): - """Simple JSONPath extraction: $.key.subkey → data['key']['subkey'].""" - if path == "$" or not path: - return data - parts = path.lstrip("$.").split(".") - current = data - for part in parts: - if isinstance(current, dict): - current = current.get(part) - elif isinstance(current, list) and part.isdigit(): - current = current[int(part)] - else: - return None - return current - - -async def _check_new_agent_messages(trigger: AgentTrigger) -> bool: - """Check if there are new messages matching this trigger. - - Supports two modes: - - from_agent_name: check for agent-to-agent messages - - from_user_name: check for human user messages (Feishu/Slack/Discord) - - Stores the actual message content in trigger.config['_matched_message'] - so the invocation context can include it. - """ - from app.models.audit import ChatMessage - from app.models.chat_session import ChatSession - - cfg = trigger.config or {} - from_agent_name = cfg.get("from_agent_name") - from_user_name = cfg.get("from_user_name") - - if not from_agent_name and not from_user_name: - return False - - since = trigger.last_fired_at or trigger.created_at - # Use _since_ts snapshot from trigger creation (set by _handle_set_trigger) - # This is more precise than the old 5-minute lookback which caused false positives - if trigger.fire_count == 0 and not trigger.last_fired_at: - since_ts_str = cfg.get("_since_ts") - if since_ts_str: - try: - since = datetime.fromisoformat(since_ts_str) - except Exception: - since = trigger.created_at - # No _since_ts and no last_fired_at → use trigger.created_at (no lookback) - - try: - async with async_session() as db: - if from_agent_name: - # --- Agent-to-agent message check (existing logic) --- - from app.models.participant import Participant - from app.models.agent import Agent as AgentModel - safe_agent_name = from_agent_name.replace("%", "").replace("_", r"\_") - agent_r = await db.execute( - select(AgentModel).where(AgentModel.name.ilike(f"%{safe_agent_name}%")) - ) - source_agent = agent_r.scalars().first() - if not source_agent: - return False - - result = await db.execute( - select(Participant.id).where( - Participant.type == "agent", - Participant.ref_id == source_agent.id, - ) - ) - from_participant = result.scalar_one_or_none() - if not from_participant: - return False - - from sqlalchemy import cast as sa_cast, String as SaString - result = await db.execute( - select(ChatMessage).join( - ChatSession, ChatMessage.conversation_id == sa_cast(ChatSession.id, SaString) - ).where( - ChatMessage.participant_id == from_participant, - ChatMessage.created_at > since, - ChatMessage.role == "assistant", - ).order_by(ChatMessage.created_at.desc()).limit(1) - ) - msg = result.scalar_one_or_none() - if not msg: - return False - cfg["_matched_message"] = (msg.content or "")[:2000] - cfg["_matched_from"] = from_agent_name - return True - - elif from_user_name: - # --- Human user message check (Feishu/Slack/Discord) --- - # Find sessions for this agent from external channels - from sqlalchemy import cast as sa_cast, String as SaString - from app.models.user import User - from app.models.agent import Agent as AgentModel - - # 0. Get agent for tenant scoping - agent_r = await db.execute(select(AgentModel).where(AgentModel.id == trigger.agent_id)) - agent = agent_r.scalar_one_or_none() - - # Look up user by display name or username within tenant - from sqlalchemy import or_ - from app.models.user import User, Identity - safe_user_name = from_user_name.replace("%", "").replace("_", r"\_") - query = ( - select(User) - .join(User.identity) - .where( - or_( - User.display_name.ilike(f"%{safe_user_name}%"), - Identity.username.ilike(f"%{safe_user_name}%"), - ) - ) - ) - if agent and agent.tenant_id: - query = query.where(User.tenant_id == agent.tenant_id) - - user_r = await db.execute(query) - target_user = user_r.scalars().first() - - if target_user: - # Find channel sessions for this user with this agent - result = await db.execute( - select(ChatMessage).join( - ChatSession, ChatMessage.conversation_id == sa_cast(ChatSession.id, SaString) - ).where( - ChatSession.agent_id == trigger.agent_id, - ChatSession.user_id == target_user.id, - ChatSession.source_channel.in_(["feishu", "slack", "discord", "web"]), - ChatMessage.role == "user", - ChatMessage.created_at > since, - ).order_by(ChatMessage.created_at.desc()).limit(1) - ) - else: - # Fallback: search by session title or message content containing the target name - result = await db.execute( - select(ChatMessage).join( - ChatSession, ChatMessage.conversation_id == sa_cast(ChatSession.id, SaString) - ).where( - ChatSession.agent_id == trigger.agent_id, - ChatSession.source_channel.in_(["feishu", "slack", "discord", "web"]), - ChatMessage.role == "user", - ChatMessage.created_at > since, - or_( - ChatSession.title.ilike(f"%{safe_user_name}%"), - ChatMessage.content.ilike(f"%{safe_user_name}%"), - ), - ).order_by(ChatMessage.created_at.desc()).limit(1) - ) - - msg = result.scalar_one_or_none() - if not msg: - return False - cfg["_matched_message"] = (msg.content or "")[:2000] - cfg["_matched_from"] = from_user_name - return True - - except Exception as e: - logger.warning(f"on_message check failed for trigger {trigger.name}: {e}") - return False - - -# ── Agent Invocation ──────────────────────────────────────────────── - -async def _resolve_trigger_delivery_target(agent: Agent, triggers: list[AgentTrigger]) -> dict | None: - """Resolve where a trigger result should be delivered. - - Priority: - 1. Explicit A2A callback session - 2. Originating agent-to-agent session - 3. Originating platform user → that user's primary platform session - 4. Pure trigger/reflection context → no user-facing delivery - """ - from app.models.chat_session import ChatSession - from app.services.chat_session_service import ensure_primary_platform_session - - # Synthetic A2A wake triggers already carry the callback session explicitly. - for trigger in triggers: - cfg = trigger.config or {} - a2a_sid = cfg.get("_a2a_session_id") - if a2a_sid: - try: - async with async_session() as db: - session = await db.get(ChatSession, uuid.UUID(a2a_sid)) - if not session: - return None - return { - "kind": "session", - "session_id": str(session.id), - "owner_user_id": str(session.user_id), - "source_channel": session.source_channel, - } - except Exception: - return None - - origin_cfg = None - for trigger in triggers: - cfg = trigger.config or {} - if cfg.get("_origin_session_id") or cfg.get("_origin_user_id"): - origin_cfg = cfg - break - - if not origin_cfg: - return None - - origin_source_channel = origin_cfg.get("_origin_source_channel") - origin_session_id = origin_cfg.get("_origin_session_id") - origin_user_id = origin_cfg.get("_origin_user_id") - - if origin_source_channel == "agent" and origin_session_id: - try: - async with async_session() as db: - session = await db.get(ChatSession, uuid.UUID(origin_session_id)) - if not session: - return None - return { - "kind": "session", - "session_id": str(session.id), - "owner_user_id": str(session.user_id), - "source_channel": "agent", - } - except Exception: - return None - - if origin_source_channel != "trigger" and origin_user_id: - try: - async with async_session() as db: - primary = await ensure_primary_platform_session( - db, - agent.id, - uuid.UUID(origin_user_id), - ) - await db.commit() - return { - "kind": "primary_user_session", - "session_id": str(primary.id), - "owner_user_id": str(primary.user_id), - "source_channel": primary.source_channel, - } - except Exception: - return None - - return None + return await evaluate_trigger_runtime(trigger, now) async def _invoke_agent_for_triggers(agent_id: uuid.UUID, triggers: list[AgentTrigger]): - """Invoke an agent with context from one or more fired triggers. - - Creates a Reflection Session and calls the LLM. - """ - from app.services.llm import call_llm - from app.services.agent_context import build_agent_context - from app.models.llm import LLMModel - from app.models.audit import ChatMessage - from app.models.chat_session import ChatSession - from app.models.participant import Participant - from app.services.audit_logger import write_audit_log - - try: - async with async_session() as db: - # Load agent - result = await db.execute(select(Agent).where(Agent.id == agent_id)) - agent = result.scalar_one_or_none() - if not agent or agent.is_expired: - return - - # Load LLM model - if not agent.primary_model_id: - logger.warning(f"Agent {agent.name} has no LLM model, skipping trigger invocation") - return - result = await db.execute(select(LLMModel).where(LLMModel.id == agent.primary_model_id)) - model = result.scalar_one_or_none() - if not model: - return - # Skip invocation if model is disabled by admin - if not model.enabled: - logger.warning(f"Agent {agent.name}'s model {model.model} is disabled, skipping trigger invocation") - return - - # Build trigger context - context_parts = [] - trigger_names = [] - for t in triggers: - part = f"触发器:{t.name} ({t.type})\n原因:{t.reason}" - if t.name == "daily_okr_collection": - part += ( - "\n执行要求:先调用 get_okr_settings 确认日报收集是否开启。" - "如果开启,只能联系你关系网络中的成员和数字员工来收集今天的最终日报," - "并整理成不超过 2000 字的正式日报;" - "如果未开启,则说明本次无需执行并停止。" - ) - elif t.name in ("daily_okr_report", "weekly_okr_report", "monthly_okr_report"): - part += ( - "\n执行要求:本次公司级报表由系统自动汇总生成。" - "如果你被唤醒,仅补充必要说明,不要再次向成员发起收集。" - ) - elif t.name == "biweekly_okr_checkin": - part += ( - "\n执行要求:先调用 get_okr_settings 确认 OKR 是否开启。" - "如果开启,检查当前周期公司和成员 OKR,主动提醒尚未设置或进展滞后的相关成员;" - "如果未开启,则说明本次无需执行并停止。" - ) - elif t.name == "monthly_okr_report": - part += ( - "\n执行要求:先调用 get_okr_settings 确认 OKR 是否开启。" - "如果开启,调用 generate_monthly_okr_report 生成刚结束月份的 OKR 月报,并发送给管理员或发布到广场;" - "如果未开启,则说明本次无需执行并停止。" - ) - if t.focus_ref: - part += f"\n关联 Focus:{t.focus_ref}" - # Include matched message for on_message triggers - cfg = t.config or {} - if t.type == "on_message" and cfg.get("_matched_message"): - part += f"\n收到来自 {cfg.get('_matched_from', '?')} 的消息:\n\"{cfg['_matched_message'][:500]}\"" - if t.type == "on_message" and cfg.get("okr_member_id") and cfg.get("okr_report_date"): - part += ( - "\n执行要求:这是一次日报回复入库事件。" - f"\n1. 将对方回复整理成一段不超过 2000 字的最终日报。" - f"\n2. 立即调用 upsert_member_daily_report(report_date=\"{cfg['okr_report_date']}\", " - f"member_type=\"{cfg.get('okr_member_type', 'user')}\", " - f"member_id=\"{cfg['okr_member_id']}\", content=\"<整理后的日报>\")。" - "\n3. 工具调用成功后,再发送一句简短确认,明确你已收到并已记录。" - "\n4. 不要只回复确认而不调用工具,也不要把原始长对话原样存入日报。" - ) - # Include webhook payload - if t.type == "webhook" and cfg.get("_webhook_payload"): - payload_str = cfg["_webhook_payload"] - if len(payload_str) > 2000: - payload_str = payload_str[:2000] + "... (truncated)" - part += f"\nWebhook Payload:\n{payload_str}" - context_parts.append(part) - trigger_names.append(t.name) - - trigger_context = ( - "===== 本次唤醒上下文 =====\n" - f"唤醒来源:trigger({'多个触发器同时触发' if len(triggers) > 1 else '触发器触发'})\n\n" - + "\n---\n".join(context_parts) - + "\n===========================" - ) - - # Create Reflection Session - title = f"🤖 内心独白:{', '.join(trigger_names)}" - # Find agent's participant - result = await db.execute( - select(Participant).where(Participant.type == "agent", Participant.ref_id == agent_id) - ) - agent_participant = result.scalar_one_or_none() - - session = ChatSession( - agent_id=agent_id, - user_id=agent.creator_id, - participant_id=agent_participant.id if agent_participant else None, - source_channel="trigger", - title=title[:200], - ) - db.add(session) - await db.flush() - session_id = session.id - - # Messages: trigger context only (call_llm builds system prompt internally) - messages = [ - {"role": "user", "content": trigger_context}, - ] - - # Store trigger context as a message in the session - db.add(ChatMessage( - agent_id=agent_id, - conversation_id=str(session_id), - role="user", - content=trigger_context, - user_id=agent.creator_id, - participant_id=agent_participant.id if agent_participant else None, - )) - await db.commit() - # Cache participant ID for callbacks - agent_participant_id = agent_participant.id if agent_participant else None - - # Call LLM (outside the DB session to avoid long transactions) - collected_content = [] - delivered_platform_message_via_tool = False - - async def on_chunk(text): - collected_content.append(text) - - # Persist tool calls into Reflection Session for Reflections visibility - async def on_tool_call(data): - nonlocal delivered_platform_message_via_tool - try: - tool_name = data.get("name") - tool_status = data.get("status") - if tool_status == "done" and tool_name == "send_platform_message": - result_text = str(data.get("result", "")) - if result_text.startswith("✅"): - delivered_platform_message_via_tool = True - - async with async_session() as _tc_db: - if data["status"] == "running": - _tc_db.add(ChatMessage( - agent_id=agent_id, - conversation_id=str(session_id), - role="tool_call", - content=_json.dumps({"name": data["name"], "args": data["args"]}, ensure_ascii=False, default=str), - user_id=agent.creator_id, - participant_id=agent_participant_id, - )) - elif data["status"] == "done": - result_str = str(data.get("result", ""))[:2000] - _tc_db.add(ChatMessage( - agent_id=agent_id, - conversation_id=str(session_id), - role="tool_call", - content=_json.dumps({"name": data["name"], "result": result_str}, ensure_ascii=False, default=str), - user_id=agent.creator_id, - participant_id=agent_participant_id, - )) - await _tc_db.commit() - except Exception as e: - logger.warning(f"Failed to persist tool call for trigger session: {e}") - - reply = await call_llm( - model=model, - messages=messages, - agent_name=agent.name, - role_description=agent.role_description or "", - agent_id=agent_id, - user_id=agent.creator_id, - session_id=str(session_id), - on_chunk=on_chunk, - on_tool_call=on_tool_call, - # A2A wake uses the agent's own max_tool_rounds setting (no override) - ) - - # Save assistant reply to Reflection session - async with async_session() as db: - result = await db.execute( - select(Participant).where(Participant.type == "agent", Participant.ref_id == agent_id) - ) - agent_participant = result.scalar_one_or_none() - - db.add(ChatMessage( - agent_id=agent_id, - conversation_id=str(session_id), - role="assistant", - content=reply or "".join(collected_content), - user_id=agent.creator_id, - participant_id=agent_participant.id if agent_participant else None, - )) - - # NOTE: trigger state (last_fired_at, fire_count, auto-disable) - # is already updated in _tick() BEFORE this task was launched, - # to prevent race-condition duplicate fires. - - await db.commit() - - # Compute final reply text once - final_reply = reply or "".join(collected_content) - - # ── Save reply to A2A session if this was an agent-to-agent wake ── - # This makes the target agent's reply visible in the A2A chat history - for t in triggers: - a2a_sid = (t.config or {}).get("_a2a_session_id") - if a2a_sid and final_reply: - try: - async with async_session() as db: - from app.models.participant import Participant as _P - _p_r = await db.execute(select(_P).where(_P.type == "agent", _P.ref_id == agent_id)) - _p = _p_r.scalar_one_or_none() - db.add(ChatMessage( - agent_id=agent_id, - conversation_id=a2a_sid, - role="assistant", - content=final_reply, - user_id=agent.creator_id, - participant_id=_p.id if _p else None, - )) - # Update session timestamp - from app.models.chat_session import ChatSession as _CS - _cs_r = await db.execute(select(_CS).where(_CS.id == uuid.UUID(a2a_sid))) - _cs = _cs_r.scalar_one_or_none() - if _cs: - _cs.last_message_at = datetime.now(timezone.utc) - await db.commit() - logger.info(f"[A2A] Saved reply to A2A session {a2a_sid}") - except Exception as e: - logger.warning(f"[A2A] Failed to save reply to A2A session {a2a_sid}: {e}") - break # Only save once - - # Route trigger results to a single deterministic destination. Pure reflection/system - # wakes stay inside the reflection session and should not spill into arbitrary user chats. - is_a2a_internal = all(t.name == "a2a_wake" for t in triggers) - delivery_target = None if is_a2a_internal else await _resolve_trigger_delivery_target(agent, triggers) - - if final_reply and delivery_target and not delivered_platform_message_via_tool: - try: - from app.api.websocket import manager as ws_manager - agent_id_str = str(agent_id) - - # Build notification message with trigger badge - trigger_reasons = [] - for t in triggers: - ns = (t.config or {}).get("_notification_summary", "").strip() - if ns: - trigger_reasons.append(ns) - else: - r = (t.reason or "").strip() - if r and len(r) <= 80: - trigger_reasons.append(r) - elif r: - trigger_reasons.append(r[:77] + "...") - summary = trigger_reasons[0] if trigger_reasons else "有新的事件需要处理" - - _is_a2a_wait = any(t.name.startswith("a2a_wait_") for t in triggers) - if _is_a2a_wait: - import re as _re - cleaned = final_reply - _internal_patterns = [ - r'\b(a2a_wait_\w+|a2a_wake)\b', - r'\bwait_?\w+_?(task|reply|followup|meeting|sync|api_key)\w*\b', - r'\bresolve_\w+\b', - r'\bfocus[_ ]?item\b', - r'\btask_delegate\b', - r'\bfocus_ref\b', - r'✅\s*(a2a\w+|wait\w+|触发器\w*|focus\w*).*(?:已取消|已为|保持|活跃|完成状态)[^\n]*', - r'[\-•]\s*(?:触发器|trigger|focus|wait_\w+|a2a\w+).*[^\n]*', - r'(?:触发器|trigger)\s+\S+\s*(?:已取消|保持活跃|已为完成状态|fired)', - r'已静默清理触发器', - r'已静默处理完毕', - r'继续待命[。,]?\s*', - r',?\s*(?:继续)?待命。', - ] - for _pat in _internal_patterns: - cleaned = _re.sub(_pat, '', cleaned, flags=_re.IGNORECASE) - cleaned = _re.sub(r'\n{3,}', '\n\n', cleaned).strip() - cleaned = _re.sub(r'[。,]\s*$', '', cleaned).strip() - if not cleaned: - cleaned = final_reply - else: - cleaned = final_reply - - notification = f"⚡ {summary}\n\n{cleaned}" - - target_session_id = delivery_target["session_id"] - owner_user_id = delivery_target.get("owner_user_id") - - # Save to the resolved destination session for persistence. - async with async_session() as db: - from app.models.chat_session import ChatSession - from app.api.websocket import maybe_mark_session_read_for_active_viewer - - db.add(ChatMessage( - agent_id=agent_id, - conversation_id=target_session_id, - role="assistant", - content=notification, - user_id=agent.creator_id, - )) - session_row = await db.get(ChatSession, uuid.UUID(target_session_id)) - if session_row: - session_row.last_message_at = datetime.now(timezone.utc) - if owner_user_id: - await maybe_mark_session_read_for_active_viewer( - db, - agent_id=agent_id, - session_id=target_session_id, - user_id=uuid.UUID(owner_user_id), - ) - await db.commit() - - payload = { - "type": "trigger_notification", - "content": notification, - "triggers": [t.name for t in triggers], - "session_id": target_session_id, - } - - # Notify only the user who owns the destination session. The frontend will append - # the message only when that exact session is open; otherwise it just refreshes - # unread/session state. - if owner_user_id: - await ws_manager.send_to_user(agent_id_str, owner_user_id, payload) - except Exception as e: - logger.error(f"Failed to push trigger result to WebSocket: {e}") - import traceback - traceback.print_exc() - - # Audit log - await write_audit_log("trigger_fired", { - "agent_name": agent.name, - "triggers": [{"name": t.name, "type": t.type} for t in triggers], - }, agent_id=agent_id) - - logger.info(f"⚡ Triggers fired for {agent.name}: {[t.name for t in triggers]}") - - except Exception as e: - logger.error(f"Failed to invoke agent {agent_id} for triggers: {e}") - import traceback - traceback.print_exc() + await invoke_agent_for_triggers_runtime(agent_id, triggers) # ── Main Tick Loop ────────────────────────────────────────────────── @@ -978,8 +90,8 @@ async def _tick(): return - # Evaluate and group fired triggers by agent - fired_by_agent: dict[uuid.UUID, list[AgentTrigger]] = {} + # Evaluate and enqueue due triggers. Agent invocation happens only after + # executions are claimed through the distributed execution queue. for trigger in all_triggers: # Auto-disable expired triggers if trigger.expires_at and now >= trigger.expires_at: @@ -997,14 +109,22 @@ async def _tick(): if not handled: handled = await _handle_okr_collection_trigger(trigger, now) if not handled: - fired_by_agent.setdefault(trigger.agent_id, []).append(trigger) + await enqueue_due_trigger(trigger, now) except Exception as e: logger.warning(f"Error evaluating trigger {trigger.name}: {e}") + # Claim queued executions with a DB lease so only one worker handles each event. + try: + fired_by_agent, force_invoke_agents = await claim_ready_trigger_invocations(now) + except Exception as e: + logger.warning(f"Failed to claim trigger executions: {e}") + fired_by_agent = {} + force_invoke_agents = set() + # Invoke each agent (with dedup window) for agent_id, agent_triggers in fired_by_agent.items(): last = _last_invoke.get(agent_id) - if last and (now - last).total_seconds() < DEDUP_WINDOW: + if agent_id not in force_invoke_agents and last and (now - last).total_seconds() < DEDUP_WINDOW: continue # Skip — invoked too recently _last_invoke[agent_id] = now @@ -1016,6 +136,8 @@ async def _tick(): try: async with async_session() as db: for t in agent_triggers: + if (t.config or {}).get("_execution_id"): + continue result = await db.execute( select(AgentTrigger).where(AgentTrigger.id == t.id) ) @@ -1026,12 +148,6 @@ async def _tick(): # Auto-disable single-shot types only if trigger.type == "once": trigger.is_enabled = False - if trigger.type == "webhook" and trigger.config: - trigger.config = { - **trigger.config, - "_webhook_pending": False, - "_webhook_payload": None, - } if trigger.max_fires and trigger.fire_count >= trigger.max_fires: trigger.is_enabled = False await db.commit() diff --git a/backend/app/services/trigger_runtime/__init__.py b/backend/app/services/trigger_runtime/__init__.py new file mode 100644 index 000000000..24bf9da1c --- /dev/null +++ b/backend/app/services/trigger_runtime/__init__.py @@ -0,0 +1,30 @@ +"""Distributed trigger runtime helpers.""" + +from app.services.trigger_runtime.dispatch import ( + claim_ready_trigger_invocations, + enqueue_due_trigger, + runtime_execution_payload, +) +from app.services.trigger_runtime.executions import ( + build_execution_runtime_trigger, + claim_pending_trigger_executions, + mark_base_triggers_fired, + mark_trigger_executions_completed, + mark_trigger_executions_failed, +) +from app.services.trigger_runtime.keys import build_scheduled_execution_key +from app.services.trigger_runtime.queue import enqueue_trigger_execution, enqueue_webhook_execution + +__all__ = [ + "build_execution_runtime_trigger", + "build_scheduled_execution_key", + "claim_ready_trigger_invocations", + "claim_pending_trigger_executions", + "enqueue_due_trigger", + "enqueue_trigger_execution", + "enqueue_webhook_execution", + "mark_base_triggers_fired", + "mark_trigger_executions_completed", + "mark_trigger_executions_failed", + "runtime_execution_payload", +] diff --git a/backend/app/services/trigger_runtime/dispatch.py b/backend/app/services/trigger_runtime/dispatch.py new file mode 100644 index 000000000..7d18e83ef --- /dev/null +++ b/backend/app/services/trigger_runtime/dispatch.py @@ -0,0 +1,64 @@ +"""Dispatch helpers for trigger executions.""" + +from __future__ import annotations + +import uuid +from datetime import datetime + +from app.database import async_session +from app.models.trigger import AgentTrigger +from app.services.trigger_runtime.executions import ( + build_execution_runtime_trigger, + claim_pending_trigger_executions, + mark_base_triggers_fired, +) +from app.services.trigger_runtime.keys import build_scheduled_execution_key +from app.services.trigger_runtime.queue import enqueue_trigger_execution + + +def runtime_execution_payload(trigger: AgentTrigger) -> dict: + """Capture ephemeral trigger evaluation context into an execution payload.""" + cfg = trigger.config or {} + payload: dict = {} + for key in ( + "_matched_message", + "_matched_from", + "okr_member_id", + "okr_member_type", + "okr_report_date", + "_notification_summary", + "_origin_session_id", + "_origin_user_id", + "_origin_source_channel", + "_a2a_session_id", + ): + if key in cfg and cfg.get(key) is not None: + payload[key] = cfg.get(key) + return payload + + +async def enqueue_due_trigger(trigger: AgentTrigger, now: datetime) -> None: + async with async_session() as db: + await enqueue_trigger_execution( + db, + trigger=trigger, + source=trigger.type, + idempotency_key=build_scheduled_execution_key(trigger, now), + payload_obj=runtime_execution_payload(trigger), + ) + + +async def claim_ready_trigger_invocations(now: datetime) -> tuple[dict[uuid.UUID, list[AgentTrigger]], set[uuid.UUID]]: + fired_by_agent: dict[uuid.UUID, list[AgentTrigger]] = {} + force_invoke_agents: set[uuid.UUID] = set() + + claimed_executions = await claim_pending_trigger_executions() + if claimed_executions: + await mark_base_triggers_fired([trigger.id for _execution, trigger in claimed_executions], now) + + for execution, trigger in claimed_executions: + runtime_trigger = build_execution_runtime_trigger(trigger, execution) + fired_by_agent.setdefault(trigger.agent_id, []).append(runtime_trigger) + force_invoke_agents.add(trigger.agent_id) + + return fired_by_agent, force_invoke_agents diff --git a/backend/app/services/trigger_runtime/evaluator.py b/backend/app/services/trigger_runtime/evaluator.py new file mode 100644 index 000000000..6a43aa63b --- /dev/null +++ b/backend/app/services/trigger_runtime/evaluator.py @@ -0,0 +1,429 @@ +"""Trigger evaluation and deterministic special-case handlers.""" + +from __future__ import annotations + +import ipaddress +import uuid +from datetime import datetime, timezone, timedelta +from urllib.parse import urlparse + +from croniter import croniter +from loguru import logger +from sqlalchemy import select + +from app.database import async_session +from app.models.agent import Agent +from app.models.trigger import AgentTrigger + +MIN_POLL_INTERVAL_MINUTES = 5 + + +async def should_skip_non_workday(trigger: AgentTrigger, local_now: datetime) -> bool: + if trigger.name != "daily_okr_collection": + return False + + from app.models.okr import OKRSettings + from app.models.tenant import Tenant + from app.services.business_calendar import is_non_workday + + async with async_session() as db: + result = await db.execute( + select(Agent.tenant_id).where(Agent.id == trigger.agent_id) + ) + tenant_id = result.scalar_one_or_none() + if not tenant_id: + return False + + settings_result = await db.execute( + select(OKRSettings.daily_report_skip_non_workdays).where(OKRSettings.tenant_id == tenant_id) + ) + skip_enabled = settings_result.scalar_one_or_none() + if skip_enabled is False: + return False + + tenant_result = await db.execute( + select(Tenant.country_region).where(Tenant.id == tenant_id) + ) + country_region = tenant_result.scalar_one_or_none() + + return is_non_workday(local_now.date(), country_region) + + +async def mark_trigger_skipped(trigger_id: uuid.UUID, now: datetime) -> None: + try: + async with async_session() as db: + result = await db.execute(select(AgentTrigger).where(AgentTrigger.id == trigger_id)) + trigger = result.scalar_one_or_none() + if trigger: + trigger.last_fired_at = now + await db.commit() + except Exception as e: + logger.warning(f"Failed to mark skipped trigger {trigger_id}: {e}") + + +async def mark_trigger_fired(trigger_id: uuid.UUID, now: datetime) -> None: + try: + async with async_session() as db: + result = await db.execute(select(AgentTrigger).where(AgentTrigger.id == trigger_id)) + trigger = result.scalar_one_or_none() + if trigger: + trigger.last_fired_at = now + trigger.fire_count += 1 + if trigger.type == "once": + trigger.is_enabled = False + if trigger.max_fires and trigger.fire_count >= trigger.max_fires: + trigger.is_enabled = False + await db.commit() + except Exception as e: + logger.warning(f"Failed to mark fired trigger {trigger_id}: {e}") + + +async def handle_okr_report_trigger(trigger: AgentTrigger, now: datetime) -> bool: + if trigger.name not in {"daily_okr_report", "weekly_okr_report", "monthly_okr_report"}: + return False + + from zoneinfo import ZoneInfo + from app.models.okr import OKRSettings + from app.services.okr_reporting import ( + generate_company_daily_report, + generate_company_monthly_report, + generate_company_weekly_report, + ) + from app.services.timezone_utils import get_agent_timezone + + async with async_session() as db: + agent_result = await db.execute(select(Agent.tenant_id).where(Agent.id == trigger.agent_id)) + tenant_id = agent_result.scalar_one_or_none() + if not tenant_id: + return True + + settings_result = await db.execute(select(OKRSettings).where(OKRSettings.tenant_id == tenant_id)) + settings = settings_result.scalar_one_or_none() + if not settings or not settings.enabled: + return True + + tz_name = await get_agent_timezone(trigger.agent_id) + try: + tz = ZoneInfo(tz_name) + except Exception: + tz = ZoneInfo("UTC") + local_today = now.astimezone(tz).date() + + if trigger.name == "daily_okr_report": + await generate_company_daily_report(tenant_id, local_today - timedelta(days=1)) + elif trigger.name == "weekly_okr_report": + previous_week_anchor = local_today - timedelta(days=7) + week_start = previous_week_anchor - timedelta(days=previous_week_anchor.weekday()) + await generate_company_weekly_report(tenant_id, week_start) + elif trigger.name == "monthly_okr_report": + previous_month_end = local_today.replace(day=1) - timedelta(days=1) + await generate_company_monthly_report(tenant_id, previous_month_end) + + await mark_trigger_fired(trigger.id, now) + logger.info(f"[Trigger] Auto-generated OKR report for trigger {trigger.name}") + return True + + +async def handle_okr_collection_trigger(trigger: AgentTrigger, now: datetime) -> bool: + if trigger.name != "daily_okr_collection": + return False + + from app.models.okr import OKRSettings + from app.services.okr_daily_collection import trigger_daily_collection_for_tenant + + async with async_session() as db: + agent_result = await db.execute(select(Agent.tenant_id).where(Agent.id == trigger.agent_id)) + tenant_id = agent_result.scalar_one_or_none() + if not tenant_id: + return True + + settings_result = await db.execute(select(OKRSettings).where(OKRSettings.tenant_id == tenant_id)) + settings = settings_result.scalar_one_or_none() + if not settings or not settings.enabled or not settings.daily_report_enabled: + return True + + await trigger_daily_collection_for_tenant(tenant_id) + await mark_trigger_fired(trigger.id, now) + logger.info(f"[Trigger] Deterministic OKR collection sent for trigger {trigger.name}") + return True + + +def is_private_url(url: str) -> bool: + try: + parsed = urlparse(url) + hostname = parsed.hostname + if not hostname: + return True + if hostname in ("localhost", "127.0.0.1", "::1", "0.0.0.0"): + return True + import socket + try: + infos = socket.getaddrinfo(hostname, None) + for info in infos: + ip = ipaddress.ip_address(info[4][0]) + if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: + return True + except (socket.gaierror, ValueError): + return True + return False + except Exception: + return True + + +async def evaluate_trigger(trigger: AgentTrigger, now: datetime) -> bool: + if not trigger.is_enabled: + return False + if trigger.expires_at and now >= trigger.expires_at: + return False + if trigger.max_fires is not None and trigger.fire_count >= trigger.max_fires: + return False + + if trigger.last_fired_at: + cooldown = timedelta(seconds=trigger.cooldown_seconds) + if (now - trigger.last_fired_at) < cooldown: + return False + + cfg = trigger.config or {} + t = trigger.type + + if t == "cron": + expr = cfg.get("expr", "* * * * *") + base = trigger.last_fired_at or trigger.created_at + try: + tz_name = cfg.get("timezone") + if not tz_name: + from app.services.timezone_utils import get_agent_timezone + tz_name = await get_agent_timezone(trigger.agent_id) + from zoneinfo import ZoneInfo + try: + tz = ZoneInfo(tz_name) + except (KeyError, Exception): + tz = ZoneInfo("UTC") + local_now = now.astimezone(tz) + local_base = base.astimezone(tz) if base.tzinfo else base.replace(tzinfo=tz) + cron = croniter(expr, local_base) + next_run = cron.get_next(datetime) + if local_now >= next_run: + if await should_skip_non_workday(trigger, local_now): + await mark_trigger_skipped(trigger.id, now) + logger.info(f"[Trigger] Skipped {trigger.name} on non-workday {local_now.date()}") + return False + return True + return False + except Exception as e: + logger.warning(f"Invalid cron expr '{expr}' for trigger {trigger.name}: {e}") + return False + + if t == "once": + at_str = cfg.get("at") + if not at_str: + return False + try: + at = datetime.fromisoformat(at_str) + if at.tzinfo is None: + at = at.replace(tzinfo=timezone.utc) + return now >= at and trigger.fire_count == 0 + except Exception: + return False + + if t == "interval": + minutes = cfg.get("minutes", 30) + base = trigger.last_fired_at or trigger.created_at + return (now - base) >= timedelta(minutes=minutes) + + if t == "poll": + interval_min = max(cfg.get("interval_min", 5), MIN_POLL_INTERVAL_MINUTES) + base = trigger.last_fired_at or trigger.created_at + if (now - base) < timedelta(minutes=interval_min): + return False + return await poll_check(trigger) + + if t == "on_message": + return await check_new_agent_messages(trigger) + + if t == "webhook": + return False + + return False + + +async def poll_check(trigger: AgentTrigger) -> bool: + import httpx + + cfg = trigger.config or {} + url = cfg.get("url") + if not url: + return False + if is_private_url(url): + logger.warning(f"Poll blocked for trigger {trigger.name}: private/internal URL '{url}'") + return False + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.request(cfg.get("method", "GET"), url, headers=cfg.get("headers", {})) + resp.raise_for_status() + + data = resp.json() + json_path = cfg.get("json_path", "$") + current_value = extract_json_path(data, json_path) + current_str = str(current_value) + fire_on = cfg.get("fire_on", "change") + should_fire = False + if fire_on == "match": + should_fire = current_str == str(cfg.get("match_value", "")) + else: + last_value = cfg.get("_last_value") + should_fire = last_value is not None and current_str != last_value + + cfg["_last_value"] = current_str + try: + from sqlalchemy import update + async with async_session() as db: + await db.execute( + update(AgentTrigger).where(AgentTrigger.id == trigger.id).values(config=cfg) + ) + await db.commit() + except Exception as e: + logger.warning(f"Failed to persist poll _last_value for {trigger.name}: {e}") + + return should_fire + except Exception as e: + logger.warning(f"Poll failed for trigger {trigger.name}: {e}") + return False + + +def extract_json_path(data, path: str): + if path == "$" or not path: + return data + parts = path.lstrip("$.").split(".") + current = data + for part in parts: + if isinstance(current, dict): + current = current.get(part) + elif isinstance(current, list) and part.isdigit(): + current = current[int(part)] + else: + return None + return current + + +async def check_new_agent_messages(trigger: AgentTrigger) -> bool: + from app.models.audit import ChatMessage + from app.models.chat_session import ChatSession + + cfg = trigger.config or {} + from_agent_name = cfg.get("from_agent_name") + from_user_name = cfg.get("from_user_name") + if not from_agent_name and not from_user_name: + return False + + since = trigger.last_fired_at or trigger.created_at + if trigger.fire_count == 0 and not trigger.last_fired_at: + since_ts_str = cfg.get("_since_ts") + if since_ts_str: + try: + since = datetime.fromisoformat(since_ts_str) + except Exception: + since = trigger.created_at + + try: + async with async_session() as db: + if from_agent_name: + from app.models.participant import Participant + from app.models.agent import Agent as AgentModel + safe_agent_name = from_agent_name.replace("%", "").replace("_", r"\_") + agent_r = await db.execute(select(AgentModel).where(AgentModel.name.ilike(f"%{safe_agent_name}%"))) + source_agent = agent_r.scalars().first() + if not source_agent: + return False + result = await db.execute( + select(Participant.id).where(Participant.type == "agent", Participant.ref_id == source_agent.id) + ) + from_participant = result.scalar_one_or_none() + if not from_participant: + return False + from sqlalchemy import String as SaString, cast as sa_cast + result = await db.execute( + select(ChatMessage) + .join(ChatSession, ChatMessage.conversation_id == sa_cast(ChatSession.id, SaString)) + .where( + ChatMessage.participant_id == from_participant, + ChatMessage.created_at > since, + ChatMessage.role == "assistant", + ) + .order_by(ChatMessage.created_at.desc()) + .limit(1) + ) + msg = result.scalar_one_or_none() + if not msg: + return False + cfg["_matched_message"] = (msg.content or "")[:2000] + cfg["_matched_from"] = from_agent_name + return True + + if from_user_name: + from sqlalchemy import or_ + from sqlalchemy import String as SaString, cast as sa_cast + from app.models.agent import Agent as AgentModel + from app.models.user import Identity, User + + agent_r = await db.execute(select(AgentModel).where(AgentModel.id == trigger.agent_id)) + agent = agent_r.scalar_one_or_none() + safe_user_name = from_user_name.replace("%", "").replace("_", r"\_") + query = ( + select(User) + .join(User.identity) + .where( + or_( + User.display_name.ilike(f"%{safe_user_name}%"), + Identity.username.ilike(f"%{safe_user_name}%"), + ) + ) + ) + if agent and agent.tenant_id: + query = query.where(User.tenant_id == agent.tenant_id) + user_r = await db.execute(query) + target_user = user_r.scalars().first() + + if target_user: + result = await db.execute( + select(ChatMessage) + .join(ChatSession, ChatMessage.conversation_id == sa_cast(ChatSession.id, SaString)) + .where( + ChatSession.agent_id == trigger.agent_id, + ChatSession.user_id == target_user.id, + ChatSession.source_channel.in_(["feishu", "slack", "discord", "web"]), + ChatMessage.role == "user", + ChatMessage.created_at > since, + ) + .order_by(ChatMessage.created_at.desc()) + .limit(1) + ) + else: + result = await db.execute( + select(ChatMessage) + .join(ChatSession, ChatMessage.conversation_id == sa_cast(ChatSession.id, SaString)) + .where( + ChatSession.agent_id == trigger.agent_id, + ChatSession.source_channel.in_(["feishu", "slack", "discord", "web"]), + ChatMessage.role == "user", + ChatMessage.created_at > since, + or_( + ChatSession.title.ilike(f"%{safe_user_name}%"), + ChatMessage.content.ilike(f"%{safe_user_name}%"), + ), + ) + .order_by(ChatMessage.created_at.desc()) + .limit(1) + ) + + msg = result.scalar_one_or_none() + if not msg: + return False + cfg["_matched_message"] = (msg.content or "")[:2000] + cfg["_matched_from"] = from_user_name + return True + except Exception as e: + logger.warning(f"on_message check failed for trigger {trigger.name}: {e}") + return False + + return False diff --git a/backend/app/services/trigger_runtime/executions.py b/backend/app/services/trigger_runtime/executions.py new file mode 100644 index 000000000..154b07854 --- /dev/null +++ b/backend/app/services/trigger_runtime/executions.py @@ -0,0 +1,131 @@ +"""Execution claiming and completion helpers for distributed triggers.""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timedelta, timezone + +from sqlalchemy import or_, select + +from app.config import get_settings +from app.database import async_session +from app.models.trigger import AgentTrigger +from app.models.trigger_execution import TriggerExecution + +settings = get_settings() + + +async def mark_trigger_executions_completed(execution_ids: list[uuid.UUID]) -> None: + if not execution_ids: + return + async with async_session() as db: + result = await db.execute( + select(TriggerExecution).where(TriggerExecution.id.in_(execution_ids)) + ) + for execution in result.scalars().all(): + execution.status = "completed" + execution.finished_at = datetime.now(timezone.utc) + execution.lease_owner = None + execution.lease_expires_at = None + execution.last_error = None + await db.commit() + + +async def mark_trigger_executions_failed(execution_ids: list[uuid.UUID], error_text: str) -> None: + if not execution_ids: + return + async with async_session() as db: + result = await db.execute( + select(TriggerExecution).where(TriggerExecution.id.in_(execution_ids)) + ) + for execution in result.scalars().all(): + execution.status = "failed" + execution.finished_at = datetime.now(timezone.utc) + execution.lease_owner = None + execution.lease_expires_at = None + execution.last_error = error_text + await db.commit() + + +async def claim_pending_trigger_executions( + *, + sources: list[str] | None = None, + limit: int = 100, +) -> list[tuple[TriggerExecution, AgentTrigger]]: + now = datetime.now(timezone.utc) + lease_until = now + timedelta(minutes=5) + claimed_pairs: list[tuple[TriggerExecution, AgentTrigger]] = [] + sources = sources or ["webhook", "cron", "once", "interval", "poll", "on_message"] + async with async_session() as db: + result = await db.execute( + select(TriggerExecution, AgentTrigger) + .join(AgentTrigger, AgentTrigger.id == TriggerExecution.trigger_id) + .where( + TriggerExecution.source.in_(sources), + AgentTrigger.is_enabled == True, + or_( + TriggerExecution.status == "pending", + (TriggerExecution.status == "processing") & ( + (TriggerExecution.lease_expires_at == None) | (TriggerExecution.lease_expires_at < now) + ), + ), + ) + .order_by(TriggerExecution.scheduled_at.asc()) + .with_for_update(skip_locked=True) + .limit(limit) + ) + rows = result.all() + for execution, trigger in rows: + execution.status = "processing" + execution.started_at = execution.started_at or now + execution.finished_at = None + execution.lease_owner = settings.INSTANCE_ID + execution.lease_expires_at = lease_until + claimed_pairs.append((execution, trigger)) + await db.commit() + return claimed_pairs + + +def build_execution_runtime_trigger(trigger: AgentTrigger, execution: TriggerExecution) -> AgentTrigger: + runtime_cfg = { + **(trigger.config or {}), + "_execution_id": str(execution.id), + } + if execution.payload: + runtime_cfg.update(execution.payload) + if execution.payload_text: + runtime_cfg["_webhook_payload"] = execution.payload_text + return AgentTrigger( + id=trigger.id, + agent_id=trigger.agent_id, + name=trigger.name, + type=trigger.type, + config=runtime_cfg, + reason=trigger.reason, + focus_ref=trigger.focus_ref, + is_enabled=trigger.is_enabled, + last_fired_at=trigger.last_fired_at, + fire_count=trigger.fire_count, + max_fires=trigger.max_fires, + cooldown_seconds=trigger.cooldown_seconds, + is_system=trigger.is_system, + created_at=trigger.created_at, + expires_at=trigger.expires_at, + ) + + +async def mark_base_triggers_fired(trigger_ids: list[uuid.UUID], now: datetime) -> None: + if not trigger_ids: + return + async with async_session() as db: + result = await db.execute( + select(AgentTrigger).where(AgentTrigger.id.in_(trigger_ids)) + ) + for trigger in result.scalars().all(): + trigger.last_fired_at = now + trigger.fire_count += 1 + if trigger.type == "once": + trigger.is_enabled = False + if trigger.max_fires and trigger.fire_count >= trigger.max_fires: + trigger.is_enabled = False + await db.commit() diff --git a/backend/app/services/trigger_runtime/invoker.py b/backend/app/services/trigger_runtime/invoker.py new file mode 100644 index 000000000..6784f32ee --- /dev/null +++ b/backend/app/services/trigger_runtime/invoker.py @@ -0,0 +1,368 @@ +"""Trigger invocation and delivery orchestration.""" + +from __future__ import annotations + +import json as _json +import uuid +from datetime import datetime, timezone + +from loguru import logger +from sqlalchemy import select + +from app.database import async_session +from app.models.agent import Agent +from app.models.trigger import AgentTrigger +from app.services.trigger_runtime import ( + mark_trigger_executions_completed, + mark_trigger_executions_failed, +) + + +async def resolve_trigger_delivery_target(agent: Agent, triggers: list[AgentTrigger]) -> dict | None: + from app.models.chat_session import ChatSession + from app.services.chat_session_service import ensure_primary_platform_session + + for trigger in triggers: + cfg = trigger.config or {} + a2a_sid = cfg.get("_a2a_session_id") + if a2a_sid: + try: + async with async_session() as db: + session = await db.get(ChatSession, uuid.UUID(a2a_sid)) + if not session: + return None + return { + "kind": "session", + "session_id": str(session.id), + "owner_user_id": str(session.user_id), + "source_channel": session.source_channel, + } + except Exception: + return None + + origin_cfg = None + for trigger in triggers: + cfg = trigger.config or {} + if cfg.get("_origin_session_id") or cfg.get("_origin_user_id"): + origin_cfg = cfg + break + if not origin_cfg: + return None + + origin_source_channel = origin_cfg.get("_origin_source_channel") + origin_session_id = origin_cfg.get("_origin_session_id") + origin_user_id = origin_cfg.get("_origin_user_id") + + if origin_source_channel == "agent" and origin_session_id: + try: + async with async_session() as db: + session = await db.get(ChatSession, uuid.UUID(origin_session_id)) + if not session: + return None + return { + "kind": "session", + "session_id": str(session.id), + "owner_user_id": str(session.user_id), + "source_channel": "agent", + } + except Exception: + return None + + if origin_source_channel != "trigger" and origin_user_id: + try: + async with async_session() as db: + primary = await ensure_primary_platform_session(db, agent.id, uuid.UUID(origin_user_id)) + await db.commit() + return { + "kind": "primary_user_session", + "session_id": str(primary.id), + "owner_user_id": str(primary.user_id), + "source_channel": primary.source_channel, + } + except Exception: + return None + + return None + + +async def invoke_agent_for_triggers(agent_id: uuid.UUID, triggers: list[AgentTrigger]): + from app.models.audit import ChatMessage + from app.models.chat_session import ChatSession + from app.models.llm import LLMModel + from app.models.participant import Participant + from app.services.audit_logger import write_audit_log + from app.services.llm import call_llm + + try: + execution_ids = [ + uuid.UUID(str((t.config or {}).get("_execution_id"))) + for t in triggers + if (t.config or {}).get("_execution_id") + ] + async with async_session() as db: + result = await db.execute(select(Agent).where(Agent.id == agent_id)) + agent = result.scalar_one_or_none() + if not agent or agent.is_expired: + return + + if not agent.primary_model_id: + logger.warning(f"Agent {agent.name} has no LLM model, skipping trigger invocation") + return + result = await db.execute(select(LLMModel).where(LLMModel.id == agent.primary_model_id)) + model = result.scalar_one_or_none() + if not model or not model.enabled: + logger.warning(f"Agent {agent.name}'s model is unavailable, skipping trigger invocation") + return + + context_parts = [] + trigger_names = [] + for t in triggers: + part = f"触发器:{t.name} ({t.type})\n原因:{t.reason}" + if t.name == "daily_okr_collection": + part += ( + "\n执行要求:先调用 get_okr_settings 确认日报收集是否开启。" + "如果开启,只能联系你关系网络中的成员和数字员工来收集今天的最终日报," + "并整理成不超过 2000 字的正式日报;" + "如果未开启,则说明本次无需执行并停止。" + ) + elif t.name in ("daily_okr_report", "weekly_okr_report", "monthly_okr_report"): + part += ( + "\n执行要求:本次公司级报表由系统自动汇总生成。" + "如果你被唤醒,仅补充必要说明,不要再次向成员发起收集。" + ) + elif t.name == "biweekly_okr_checkin": + part += ( + "\n执行要求:先调用 get_okr_settings 确认 OKR 是否开启。" + "如果开启,检查当前周期公司和成员 OKR,主动提醒尚未设置或进展滞后的相关成员;" + "如果未开启,则说明本次无需执行并停止。" + ) + if t.focus_ref: + part += f"\n关联 Focus:{t.focus_ref}" + cfg = t.config or {} + if t.type == "on_message" and cfg.get("_matched_message"): + part += f"\n收到来自 {cfg.get('_matched_from', '?')} 的消息:\n\"{cfg['_matched_message'][:500]}\"" + if t.type == "on_message" and cfg.get("okr_member_id") and cfg.get("okr_report_date"): + part += ( + "\n执行要求:这是一次日报回复入库事件。" + f"\n1. 将对方回复整理成一段不超过 2000 字的最终日报。" + f"\n2. 立即调用 upsert_member_daily_report(report_date=\"{cfg['okr_report_date']}\", " + f"member_type=\"{cfg.get('okr_member_type', 'user')}\", " + f"member_id=\"{cfg['okr_member_id']}\", content=\"<整理后的日报>\")。" + "\n3. 工具调用成功后,再发送一句简短确认,明确你已收到并已记录。" + "\n4. 不要只回复确认而不调用工具,也不要把原始长对话原样存入日报。" + ) + if t.type == "webhook" and cfg.get("_webhook_payload"): + payload_str = cfg["_webhook_payload"] + if len(payload_str) > 2000: + payload_str = payload_str[:2000] + "... (truncated)" + part += f"\nWebhook Payload:\n{payload_str}" + context_parts.append(part) + trigger_names.append(t.name) + + trigger_context = ( + "===== 本次唤醒上下文 =====\n" + f"唤醒来源:trigger({'多个触发器同时触发' if len(triggers) > 1 else '触发器触发'})\n\n" + + "\n---\n".join(context_parts) + + "\n===========================" + ) + + title = f"🤖 内心独白:{', '.join(trigger_names)}" + result = await db.execute( + select(Participant).where(Participant.type == "agent", Participant.ref_id == agent_id) + ) + agent_participant = result.scalar_one_or_none() + + session = ChatSession( + agent_id=agent_id, + user_id=agent.creator_id, + participant_id=agent_participant.id if agent_participant else None, + source_channel="trigger", + title=title[:200], + ) + db.add(session) + await db.flush() + session_id = session.id + messages = [{"role": "user", "content": trigger_context}] + db.add(ChatMessage( + agent_id=agent_id, + conversation_id=str(session_id), + role="user", + content=trigger_context, + user_id=agent.creator_id, + participant_id=agent_participant.id if agent_participant else None, + )) + await db.commit() + agent_participant_id = agent_participant.id if agent_participant else None + + collected_content: list[str] = [] + delivered_platform_message_via_tool = False + + async def on_chunk(text): + collected_content.append(text) + + async def on_tool_call(data): + nonlocal delivered_platform_message_via_tool + try: + tool_name = data.get("name") + tool_status = data.get("status") + if tool_status == "done" and tool_name == "send_platform_message": + result_text = str(data.get("result", "")) + if result_text.startswith("✅"): + delivered_platform_message_via_tool = True + + async with async_session() as _tc_db: + if data["status"] == "running": + _tc_db.add(ChatMessage( + agent_id=agent_id, + conversation_id=str(session_id), + role="tool_call", + content=_json.dumps({"name": data["name"], "args": data["args"]}, ensure_ascii=False, default=str), + user_id=agent.creator_id, + participant_id=agent_participant_id, + )) + elif data["status"] == "done": + result_str = str(data.get("result", ""))[:2000] + _tc_db.add(ChatMessage( + agent_id=agent_id, + conversation_id=str(session_id), + role="tool_call", + content=_json.dumps({"name": data["name"], "result": result_str}, ensure_ascii=False, default=str), + user_id=agent.creator_id, + participant_id=agent_participant_id, + )) + await _tc_db.commit() + except Exception as e: + logger.warning(f"Failed to persist tool call for trigger session: {e}") + + reply = await call_llm( + model=model, + messages=messages, + agent_name=agent.name, + role_description=agent.role_description or "", + agent_id=agent_id, + user_id=agent.creator_id, + session_id=str(session_id), + on_chunk=on_chunk, + on_tool_call=on_tool_call, + ) + + async with async_session() as db: + result = await db.execute( + select(Participant).where(Participant.type == "agent", Participant.ref_id == agent_id) + ) + agent_participant = result.scalar_one_or_none() + db.add(ChatMessage( + agent_id=agent_id, + conversation_id=str(session_id), + role="assistant", + content=reply or "".join(collected_content), + user_id=agent.creator_id, + participant_id=agent_participant.id if agent_participant else None, + )) + await db.commit() + + final_reply = reply or "".join(collected_content) + for t in triggers: + a2a_sid = (t.config or {}).get("_a2a_session_id") + if a2a_sid and final_reply: + try: + async with async_session() as db: + from app.models.participant import Participant as _P + _p_r = await db.execute(select(_P).where(_P.type == "agent", _P.ref_id == agent_id)) + _p = _p_r.scalar_one_or_none() + db.add(ChatMessage( + agent_id=agent_id, + conversation_id=a2a_sid, + role="assistant", + content=final_reply, + user_id=agent.creator_id, + participant_id=_p.id if _p else None, + )) + from app.models.chat_session import ChatSession as _CS + _cs_r = await db.execute(select(_CS).where(_CS.id == uuid.UUID(a2a_sid))) + _cs = _cs_r.scalar_one_or_none() + if _cs: + _cs.last_message_at = datetime.now(timezone.utc) + await db.commit() + except Exception as e: + logger.warning(f"[A2A] Failed to save reply to A2A session {a2a_sid}: {e}") + break + + is_a2a_internal = all(t.name == "a2a_wake" for t in triggers) + delivery_target = None if is_a2a_internal else await resolve_trigger_delivery_target(agent, triggers) + + if final_reply and delivery_target and not delivered_platform_message_via_tool: + try: + from app.api.websocket import manager as ws_manager + agent_id_str = str(agent_id) + trigger_reasons = [] + for t in triggers: + ns = (t.config or {}).get("_notification_summary", "").strip() + if ns: + trigger_reasons.append(ns) + else: + r = (t.reason or "").strip() + if r and len(r) <= 80: + trigger_reasons.append(r) + elif r: + trigger_reasons.append(r[:77] + "...") + summary = trigger_reasons[0] if trigger_reasons else "有新的事件需要处理" + notification = f"⚡ {summary}\n\n{final_reply}" + target_session_id = delivery_target["session_id"] + owner_user_id = delivery_target.get("owner_user_id") + + async with async_session() as db: + from app.api.websocket import maybe_mark_session_read_for_active_viewer + from app.models.chat_session import ChatSession + db.add(ChatMessage( + agent_id=agent_id, + conversation_id=target_session_id, + role="assistant", + content=notification, + user_id=agent.creator_id, + )) + session_row = await db.get(ChatSession, uuid.UUID(target_session_id)) + if session_row: + session_row.last_message_at = datetime.now(timezone.utc) + if owner_user_id: + await maybe_mark_session_read_for_active_viewer( + db, + agent_id=agent_id, + session_id=target_session_id, + user_id=uuid.UUID(owner_user_id), + ) + await db.commit() + + if owner_user_id: + await ws_manager.send_to_user( + agent_id_str, + owner_user_id, + { + "type": "trigger_notification", + "content": notification, + "triggers": [t.name for t in triggers], + "session_id": target_session_id, + }, + ) + except Exception as e: + logger.error(f"Failed to push trigger result to WebSocket: {e}") + + await write_audit_log( + "trigger_fired", + {"agent_name": agent.name, "triggers": [{"name": t.name, "type": t.type} for t in triggers]}, + agent_id=agent_id, + ) + + if execution_ids: + await mark_trigger_executions_completed(execution_ids) + except Exception as e: + logger.error(f"Failed to invoke agent {agent_id} for triggers: {e}") + import traceback + traceback.print_exc() + execution_ids = [ + uuid.UUID(str((t.config or {}).get("_execution_id"))) + for t in triggers + if (t.config or {}).get("_execution_id") + ] + if execution_ids: + await mark_trigger_executions_failed(execution_ids, str(e)[:2000]) diff --git a/backend/app/services/trigger_runtime/keys.py b/backend/app/services/trigger_runtime/keys.py new file mode 100644 index 000000000..4fbb3178a --- /dev/null +++ b/backend/app/services/trigger_runtime/keys.py @@ -0,0 +1,47 @@ +"""Deterministic idempotency keys for trigger executions.""" + +from __future__ import annotations + +import hashlib +from datetime import datetime, timedelta, timezone + +from croniter import croniter + +from app.models.trigger import AgentTrigger + + +def build_scheduled_execution_key(trigger: AgentTrigger, now: datetime) -> str: + """Build a deterministic idempotency key for non-webhook trigger runs.""" + cfg = trigger.config or {} + trigger_type = trigger.type + + if trigger_type == "once": + return f"once:{trigger.id}:{cfg.get('at', '')}" + + if trigger_type == "interval": + minutes = int(cfg.get("minutes", 30) or 30) + base = trigger.last_fired_at or trigger.created_at + due_at = base + timedelta(minutes=minutes) + return f"interval:{trigger.id}:{due_at.astimezone(timezone.utc).isoformat()}" + + if trigger_type == "cron": + expr = cfg.get("expr", "* * * * *") + base = trigger.last_fired_at or trigger.created_at + cron = croniter(expr, base) + due_at = cron.get_next(datetime) + if due_at.tzinfo is None: + due_at = due_at.replace(tzinfo=timezone.utc) + return f"cron:{trigger.id}:{due_at.astimezone(timezone.utc).isoformat()}" + + if trigger_type == "on_message": + matched_from = str(cfg.get("_matched_from") or "") + matched_message = str(cfg.get("_matched_message") or "") + digest = hashlib.sha256(f"{matched_from}\n{matched_message}".encode("utf-8")).hexdigest() + return f"on_message:{trigger.id}:{digest}" + + if trigger_type == "poll": + current_value = str(cfg.get("_last_value") or "") + digest = hashlib.sha256(current_value.encode("utf-8")).hexdigest() + return f"poll:{trigger.id}:{digest}" + + return f"{trigger_type}:{trigger.id}:{now.replace(microsecond=0).isoformat()}" diff --git a/backend/app/services/trigger_runtime/queue.py b/backend/app/services/trigger_runtime/queue.py new file mode 100644 index 000000000..5b26037d2 --- /dev/null +++ b/backend/app/services/trigger_runtime/queue.py @@ -0,0 +1,73 @@ +"""Queue trigger executions for distributed workers.""" + +from __future__ import annotations + +import hashlib +from datetime import datetime, timezone + +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.trigger import AgentTrigger +from app.models.trigger_execution import TriggerExecution + + +async def enqueue_trigger_execution( + db: AsyncSession, + *, + trigger: AgentTrigger, + source: str, + idempotency_key: str, + payload_text: str = "", + payload_obj: dict | None = None, +) -> tuple[TriggerExecution | None, bool]: + """Insert a generic trigger execution record.""" + execution = TriggerExecution( + trigger_id=trigger.id, + agent_id=trigger.agent_id, + source=source, + status="pending", + idempotency_key=idempotency_key[:255], + payload=payload_obj if isinstance(payload_obj, dict) else {}, + payload_text=payload_text[:8000], + scheduled_at=datetime.now(timezone.utc), + ) + db.add(execution) + try: + await db.commit() + return execution, True + except IntegrityError: + await db.rollback() + return None, False + + +async def enqueue_webhook_execution( + db: AsyncSession, + *, + trigger: AgentTrigger, + body: bytes, + payload_text: str, + payload_obj: dict | None, + request_headers: dict[str, str], +) -> tuple[TriggerExecution | None, bool]: + """Insert a webhook execution record. + + Returns `(execution, created)` where `created=False` means an identical + idempotency key already exists and the event should be treated as a no-op. + """ + delivery_key = ( + request_headers.get("x-idempotency-key") + or request_headers.get("x-github-delivery") + or request_headers.get("x-request-id") + or request_headers.get("x-event-id") + or hashlib.sha256(body).hexdigest() + )[:255] + + return await enqueue_trigger_execution( + db, + trigger=trigger, + source="webhook", + idempotency_key=delivery_key, + payload_text=payload_text, + payload_obj=payload_obj, + ) diff --git a/backend/app/services/workspace_collaboration.py b/backend/app/services/workspace_collaboration.py index fe720f46b..5425364ce 100644 --- a/backend/app/services/workspace_collaboration.py +++ b/backend/app/services/workspace_collaboration.py @@ -17,6 +17,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.models.workspace import WorkspaceEditLock, WorkspaceFileRevision +from app.services.storage import get_storage_backend, normalize_storage_key USER_AUTOSAVE_MERGE_SECONDS = 60 EDIT_LOCK_TTL_SECONDS = 90 @@ -243,11 +244,20 @@ async def write_workspace_file( locked_by_user_id=str(lock.user_id), ) - target = safe_agent_path(base_dir, normalized) - before = await read_text_if_exists(target) - target.parent.mkdir(parents=True, exist_ok=True) - async with aiofiles.open(target, "w", encoding="utf-8") as f: - await f.write(content) + storage = get_storage_backend() + storage_key = normalize_storage_key(f"{agent_id}/{normalized}") + local_base_available = False + try: + target = safe_agent_path(base_dir, normalized) + local_base_available = True + except Exception: + target = None + before = await storage.read_text(storage_key, encoding="utf-8", errors="replace") if await storage.exists(storage_key) else None + await storage.write_text(storage_key, content, encoding="utf-8") + if local_base_available and target is not None: + target.parent.mkdir(parents=True, exist_ok=True) + async with aiofiles.open(target, "w", encoding="utf-8") as f: + await f.write(content) revision = await record_revision( db, @@ -282,7 +292,12 @@ async def delete_workspace_file( ) -> WorkspaceWriteResult: """Delete a workspace file and record the deleted content.""" normalized = normalize_workspace_path(path) - target = safe_agent_path(base_dir, normalized) + storage = get_storage_backend() + storage_key = normalize_storage_key(f"{agent_id}/{normalized}") + try: + target = safe_agent_path(base_dir, normalized) + except Exception: + target = None if enforce_human_lock and actor_type != "user": lock = await get_active_lock(db, agent_id=agent_id, path=normalized) if lock: @@ -292,15 +307,19 @@ async def delete_workspace_file( f"Human is currently editing {normalized}. Do not delete it now.", locked_by_user_id=str(lock.user_id), ) - if not target.exists(): + if not await storage.exists(storage_key): return WorkspaceWriteResult(False, normalized, f"File not found: {normalized}") - before = await read_text_if_exists(target) - if target.is_dir(): - import shutil - - shutil.rmtree(target) + before = await storage.read_text(storage_key, encoding="utf-8", errors="replace") if await storage.is_file(storage_key) else None + if await storage.is_dir(storage_key): + await storage.delete_tree(storage_key) else: - target.unlink() + await storage.delete(storage_key) + if target is not None and target.exists(): + if target.is_dir(): + import shutil + shutil.rmtree(target) + else: + target.unlink() revision = await record_revision( db, agent_id=agent_id, diff --git a/backend/entrypoint.sh b/backend/entrypoint.sh index a11135649..a45c58e5f 100755 --- a/backend/entrypoint.sh +++ b/backend/entrypoint.sh @@ -1,11 +1,19 @@ #!/bin/bash -# Docker entrypoint: run DB migrations, then start the app. -# Order matters: -# 1. alembic upgrade head — apply all migrations (creates tables + schema changes) -# 2. uvicorn — starts the FastAPI app +# Docker entrypoint: optionally run DB migrations, then start the app. set -e +PROCESS_ROLE="${PROCESS_ROLE:-all}" +ALLOW_MIGRATION_FAILURE="${ALLOW_MIGRATION_FAILURE:-false}" +START_COMMAND="${START_COMMAND:-uvicorn app.main:app --host 0.0.0.0 --port 8000}" + +role_contains() { + case ",${PROCESS_ROLE}," in + *,all,*|*,"$1",*) return 0 ;; + *) return 1 ;; + esac +} + # --- Permission fixing and privilege dropping --- if [ "$(id -u)" = '0' ]; then echo "[entrypoint] Detected root user, fixing permissions..." @@ -16,38 +24,32 @@ if [ "$(id -u)" = '0' ]; then fi # ------------------------------------------------------- -echo "[entrypoint] Step 1: Running alembic migrations..." -# Run all migrations to ensure database schema is up to date. -# Capture exit code explicitly — do NOT let a migration failure go unnoticed. -set +e -ALEMBIC_OUTPUT=$(alembic upgrade head 2>&1) -ALEMBIC_EXIT=$? -set -e +if role_contains "bootstrap"; then + echo "[entrypoint] Step 1: Running alembic migrations for PROCESS_ROLE=${PROCESS_ROLE}..." + set +e + ALEMBIC_OUTPUT=$(alembic upgrade head 2>&1) + ALEMBIC_EXIT=$? + set -e -if [ $ALEMBIC_EXIT -ne 0 ]; then - echo "" - echo "========================================================================" - echo "[entrypoint] WARNING: Alembic migration FAILED (exit code $ALEMBIC_EXIT)" - echo "========================================================================" - echo "" - echo "$ALEMBIC_OUTPUT" - echo "" - echo "------------------------------------------------------------------------" - echo " The database schema may be INCOMPLETE. Some features will NOT work." - echo " Common causes:" - echo " - Migration cycle detected (pull latest code to fix)" - echo " - Database connection issue" - echo " - Incompatible migration state" - echo "" - echo " To fix: pull the latest code and restart the backend." - echo " Docker: git pull && docker compose restart backend" - echo " Source: git pull && alembic upgrade head" - echo "------------------------------------------------------------------------" - echo "" - echo "[entrypoint] Continuing startup despite migration failure..." + if [ $ALEMBIC_EXIT -ne 0 ]; then + echo "" + echo "========================================================================" + echo "[entrypoint] ERROR: Alembic migration FAILED (exit code $ALEMBIC_EXIT)" + echo "========================================================================" + echo "" + echo "$ALEMBIC_OUTPUT" + echo "" + if [ "$ALLOW_MIGRATION_FAILURE" = "true" ]; then + echo "[entrypoint] Continuing because ALLOW_MIGRATION_FAILURE=true" + else + exit $ALEMBIC_EXIT + fi + else + echo "[entrypoint] Alembic migrations completed successfully." + fi else - echo "[entrypoint] Alembic migrations completed successfully." + echo "[entrypoint] Step 1: Skipping alembic for PROCESS_ROLE=${PROCESS_ROLE}" fi echo "[entrypoint] Step 2: Starting uvicorn..." -exec uvicorn app.main:app --host 0.0.0.0 --port 8000 +exec /bin/bash -lc "$START_COMMAND" diff --git a/deploy/nginx/role-all.conf b/deploy/nginx/role-all.conf new file mode 100644 index 000000000..c5eda7574 --- /dev/null +++ b/deploy/nginx/role-all.conf @@ -0,0 +1,30 @@ +upstream clawith_role_all_backend { + least_conn; + server backend-all-1:8000 max_fails=3 fail_timeout=15s; + server backend-all-2:8000 max_fails=3 fail_timeout=15s; +} + +map $http_upgrade $connection_upgrade { + default upgrade; + '' close; +} + +server { + listen 8000; + server_name _; + + client_max_body_size 64m; + + location / { + proxy_pass http://clawith_role_all_backend; + proxy_http_version 1.1; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection $connection_upgrade; + proxy_read_timeout 3600s; + proxy_send_timeout 3600s; + } +} diff --git a/docker-compose.role-all.yml b/docker-compose.role-all.yml new file mode 100644 index 000000000..caf55cffa --- /dev/null +++ b/docker-compose.role-all.yml @@ -0,0 +1,141 @@ +services: + postgres: + image: postgres:15-alpine + restart: unless-stopped + environment: + POSTGRES_USER: clawith + POSTGRES_PASSWORD: clawith + POSTGRES_DB: clawith + volumes: + - pgdata:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U clawith"] + interval: 5s + timeout: 5s + retries: 5 + + redis: + image: redis:7-alpine + restart: unless-stopped + volumes: + - redisdata:/data + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 5s + timeout: 5s + retries: 5 + + backend-all-1: + build: + context: ./backend + args: + CLAWITH_PIP_INDEX_URL: ${CLAWITH_PIP_INDEX_URL:-} + CLAWITH_PIP_TRUSTED_HOST: ${CLAWITH_PIP_TRUSTED_HOST:-} + restart: unless-stopped + command: ["/bin/bash", "/app/entrypoint.sh"] + environment: + DATABASE_URL: postgresql+asyncpg://clawith:clawith@postgres:5432/clawith + REDIS_URL: redis://redis:6379/0 + AGENT_DATA_DIR: /data/agents + AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents + SECRET_KEY: ${SECRET_KEY:-change-me-in-production} + JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: all + INSTANCE_ID: backend-all-1 + START_COMMAND: uvicorn app.main:app --host 0.0.0.0 --port 8000 + CORS_ORIGINS: '["*"]' + FEISHU_APP_ID: ${FEISHU_APP_ID:-} + FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} + DOCKER_NETWORK: clawith_network + SS_CONFIG_FILE: /data/ss-nodes.json + PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-http://localhost:8000} + PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: ${PASSWORD_RESET_TOKEN_EXPIRE_MINUTES:-30} + volumes: + - ./backend:/app + - ./backend/agent_data:/data/agents + - /var/run/docker.sock:/var/run/docker.sock + - ./ss-nodes.json:/data/ss-nodes.json:ro + cap_add: + - SYS_ADMIN + security_opt: + - seccomp=unconfined + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + healthcheck: + test: ["CMD-SHELL", "python - <<'PY'\nimport urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/api/health').read()\nPY"] + interval: 15s + timeout: 5s + retries: 10 + + backend-all-2: + build: + context: ./backend + args: + CLAWITH_PIP_INDEX_URL: ${CLAWITH_PIP_INDEX_URL:-} + CLAWITH_PIP_TRUSTED_HOST: ${CLAWITH_PIP_TRUSTED_HOST:-} + restart: unless-stopped + command: ["/bin/bash", "/app/entrypoint.sh"] + environment: + DATABASE_URL: postgresql+asyncpg://clawith:clawith@postgres:5432/clawith + REDIS_URL: redis://redis:6379/0 + AGENT_DATA_DIR: /data/agents + AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents + SECRET_KEY: ${SECRET_KEY:-change-me-in-production} + JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: all + INSTANCE_ID: backend-all-2 + START_COMMAND: uvicorn app.main:app --host 0.0.0.0 --port 8000 + CORS_ORIGINS: '["*"]' + FEISHU_APP_ID: ${FEISHU_APP_ID:-} + FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} + DOCKER_NETWORK: clawith_network + SS_CONFIG_FILE: /data/ss-nodes.json + PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-http://localhost:8000} + PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: ${PASSWORD_RESET_TOKEN_EXPIRE_MINUTES:-30} + volumes: + - ./backend:/app + - ./backend/agent_data:/data/agents + - /var/run/docker.sock:/var/run/docker.sock + - ./ss-nodes.json:/data/ss-nodes.json:ro + cap_add: + - SYS_ADMIN + security_opt: + - seccomp=unconfined + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + healthcheck: + test: ["CMD-SHELL", "python - <<'PY'\nimport urllib.request; urllib.request.urlopen('http://127.0.0.1:8000/api/health').read()\nPY"] + interval: 15s + timeout: 5s + retries: 10 + + nginx: + image: nginx:1.27-alpine + restart: unless-stopped + depends_on: + backend-all-1: + condition: service_healthy + backend-all-2: + condition: service_healthy + ports: + - "8000:8000" + volumes: + - ./deploy/nginx/role-all.conf:/etc/nginx/conf.d/default.conf:ro + +volumes: + pgdata: + redisdata: + +networks: + default: + name: clawith_role_all_network diff --git a/docker-compose.yml b/docker-compose.yml index f68f25c63..cb01f1a76 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -42,8 +42,11 @@ services: REDIS_URL: redis://redis:6379/0 AGENT_DATA_DIR: /data/agents AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents SECRET_KEY: ${SECRET_KEY:-change-me-in-production} JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: api CORS_ORIGINS: '["*"]' FEISHU_APP_ID: ${FEISHU_APP_ID:-} FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} @@ -76,6 +79,98 @@ services: options: max-size: "10m" max-file: "3" + backend-worker: + build: + context: ./backend + args: + CLAWITH_PIP_INDEX_URL: ${CLAWITH_PIP_INDEX_URL:-} + CLAWITH_PIP_TRUSTED_HOST: ${CLAWITH_PIP_TRUSTED_HOST:-} + restart: unless-stopped + command: ["/bin/bash", "/app/entrypoint.sh"] + environment: + DATABASE_URL: postgresql+asyncpg://clawith:clawith@postgres:5432/clawith + REDIS_URL: redis://redis:6379/0 + AGENT_DATA_DIR: /data/agents + AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents + SECRET_KEY: ${SECRET_KEY:-change-me-in-production} + JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: bootstrap,worker + CORS_ORIGINS: '["*"]' + FEISHU_APP_ID: ${FEISHU_APP_ID:-} + FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} + DOCKER_NETWORK: clawith_yaojin_network + SS_CONFIG_FILE: /data/ss-nodes.json + PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-} + PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: ${PASSWORD_RESET_TOKEN_EXPIRE_MINUTES:-30} + volumes: + - ./backend:/app + - ./backend/agent_data:/data/agents + - /var/run/docker.sock:/var/run/docker.sock + - ./ss-nodes.json:/data/ss-nodes.json:ro + cap_add: + - SYS_ADMIN + security_opt: + - seccomp=unconfined + networks: + - default + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + logging: + driver: json-file + options: + max-size: "10m" + max-file: "3" + backend-connector: + build: + context: ./backend + args: + CLAWITH_PIP_INDEX_URL: ${CLAWITH_PIP_INDEX_URL:-} + CLAWITH_PIP_TRUSTED_HOST: ${CLAWITH_PIP_TRUSTED_HOST:-} + restart: unless-stopped + command: ["/bin/bash", "/app/entrypoint.sh"] + environment: + DATABASE_URL: postgresql+asyncpg://clawith:clawith@postgres:5432/clawith + REDIS_URL: redis://redis:6379/0 + AGENT_DATA_DIR: /data/agents + AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents + SECRET_KEY: ${SECRET_KEY:-change-me-in-production} + JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: connector + CORS_ORIGINS: '["*"]' + FEISHU_APP_ID: ${FEISHU_APP_ID:-} + FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} + DOCKER_NETWORK: clawith_yaojin_network + SS_CONFIG_FILE: /data/ss-nodes.json + PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-} + PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: ${PASSWORD_RESET_TOKEN_EXPIRE_MINUTES:-30} + volumes: + - ./backend:/app + - ./backend/agent_data:/data/agents + - /var/run/docker.sock:/var/run/docker.sock + - ./ss-nodes.json:/data/ss-nodes.json:ro + cap_add: + - SYS_ADMIN + security_opt: + - seccomp=unconfined + networks: + - default + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + logging: + driver: json-file + options: + max-size: "10m" + max-file: "3" frontend: build: ./frontend restart: unless-stopped From a4d86653832392e050c18905f51cc0e50842db9c Mon Sep 17 00:00:00 2001 From: yaojin3616 Date: Fri, 8 May 2026 19:39:49 +0800 Subject: [PATCH 05/12] feat(ha): stateless HA storage runtime with async S3 and workspace locking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace sync boto3 writes with aioboto3 async client to eliminate stale connection stalls (30s → <2s on concurrent skill uploads) - Add tcp_keepalive, connect_timeout, and read_timeout to boto3 config - Add aioboto3>=13.0.0 dependency - Simplify skill upload in create_agent: single gather, no semaphore, no nested helper function - Add workspace locking service for HA-safe concurrent agent operations - Extend storage runtime base with versioned write (write_bytes_if_match), get_version, and StorageVersion/WriteCondition/ConditionalWriteResult types - Propagate storage backend through files API, agent tools, and seeder - Add fallback storage backend for local→S3 migration path - Add nginx multi-instance config for HA deployment - Add tests: storage S3, files API storage, agent tools workspace, org sync --- .env.example | 11 + backend/app/api/agents.py | 64 +- backend/app/api/files.py | 91 +- backend/app/api/websocket.py | 2 +- backend/app/api/wechat.py | 11 +- backend/app/config.py | 2 + backend/app/core/logging_config.py | 2 + backend/app/services/agent_manager.py | 19 +- backend/app/services/agent_seeder.py | 173 +-- backend/app/services/agent_tools.py | 988 +++++++++++++----- backend/app/services/llm/caller.py | 6 +- backend/app/services/org_sync_adapter.py | 22 +- backend/app/services/skill_seeder.py | 21 +- .../app/services/storage_runtime/__init__.py | 11 +- backend/app/services/storage_runtime/base.py | 88 ++ .../app/services/storage_runtime/facade.py | 2 + .../app/services/storage_runtime/fallback.py | 29 +- backend/app/services/storage_runtime/local.py | 67 +- backend/app/services/storage_runtime/s3.py | 166 ++- backend/app/services/wechat_channel.py | 19 +- .../app/services/workspace_collaboration.py | 172 ++- backend/app/services/workspace_locking.py | 80 ++ backend/entrypoint.sh | 6 + backend/pyproject.toml | 1 + .../test_agent_tools_storage_workspace.py | 310 ++++++ backend/tests/test_files_api_storage.py | 156 +++ backend/tests/test_org_sync_adapter.py | 20 + backend/tests/test_storage_s3.py | 44 + deploy/nginx/multi-instance.conf | 30 + frontend/src/components/MarkdownRenderer.tsx | 4 +- 30 files changed, 2140 insertions(+), 477 deletions(-) create mode 100644 backend/app/services/workspace_locking.py create mode 100644 backend/tests/test_agent_tools_storage_workspace.py create mode 100644 backend/tests/test_files_api_storage.py create mode 100644 backend/tests/test_storage_s3.py create mode 100644 deploy/nginx/multi-instance.conf diff --git a/.env.example b/.env.example index a002b67c4..6cea7270c 100644 --- a/.env.example +++ b/.env.example @@ -33,6 +33,17 @@ FEISHU_REDIRECT_URI=http://localhost:3000/auth/feishu/callback # S3_ACCESS_KEY_ID= # S3_SECRET_ACCESS_KEY= # S3_PREFIX=agents +# S3_MAX_POOL_CONNECTIONS=50 +# S3_WRITE_WORKERS=32 + +# Local MinIO settings used by docker-compose.multi-instance.yml. +# Change the password before exposing MinIO outside local development. +MINIO_ROOT_USER=clawith +MINIO_ROOT_PASSWORD=clawith-minio-secret +MINIO_BUCKET=clawith +MINIO_API_PORT=9000 +MINIO_CONSOLE_PORT=9001 +API_PORT=8000 # Jina AI API key (for jina_search and jina_read tools — get one at https://jina.ai) # Without a key, the tools still work but with lower rate limits diff --git a/backend/app/api/agents.py b/backend/app/api/agents.py index e78f67818..354692c82 100644 --- a/backend/app/api/agents.py +++ b/backend/app/api/agents.py @@ -3,14 +3,17 @@ import hashlib import json import secrets +import time import uuid from datetime import datetime, timezone from pathlib import Path from fastapi import APIRouter, Depends, HTTPException, status +from loguru import logger from sqlalchemy import cast, func, select, String from sqlalchemy.ext.asyncio import AsyncSession +from app.config import get_settings from app.core.permissions import build_visible_agents_query, check_agent_access, is_agent_creator from app.core.security import get_current_user from app.database import get_db @@ -19,8 +22,10 @@ from app.models.chat_session import ChatSession from app.models.user import User from app.schemas.schemas import AgentCreate, AgentOut, AgentUpdate +from app.services.storage import get_storage_backend router = APIRouter(prefix="/agents", tags=["agents"]) +settings = get_settings() def _serialize_dt(value: datetime | None) -> str | None: @@ -379,10 +384,11 @@ async def create_agent( # Always include global default skills (mcp-installer, skill-creator, # complex-task-executor) - default_result = await db.execute( - select(Skill).where(Skill.is_default) - ) + t_skills_copy_start = time.perf_counter() + t_default_query_start = time.perf_counter() + default_result = await db.execute(select(Skill).where(Skill.is_default)) default_ids = {s.id for s in default_result.scalars().all()} + t_default_query = time.perf_counter() - t_default_query_start # Include the template's declared default skills (e.g. trading templates # ship with `market-data` / `financial-calendar` in their meta.yaml). @@ -390,7 +396,9 @@ async def create_agent( # so the agent has no idea those MCP-backed skills exist and silently # falls back to web search. template_skill_ids: set = set() + t_template_query = 0.0 if data.template_id: + t_template_query_start = time.perf_counter() tpl_r = await db.execute( select(AgentTemplate).where(AgentTemplate.id == data.template_id) ) @@ -401,30 +409,45 @@ async def create_agent( select(Skill).where(Skill.folder_name.in_(folder_names)) ) template_skill_ids = {s.id for s in tpl_skills_r.scalars().all()} + t_template_query = time.perf_counter() - t_template_query_start # Merge user-selected + global default + template-default skill IDs all_skill_ids = set(data.skill_ids or []) | default_ids | template_skill_ids if all_skill_ids: - agent_dir = agent_manager._agent_dir(agent.id) - skills_dir = agent_dir / "skills" - skills_dir.mkdir(parents=True, exist_ok=True) + import asyncio + storage = get_storage_backend() + agent_prefix = agent_manager._agent_storage_prefix(agent.id) - for sid in all_skill_ids: - result = await db.execute( - select(Skill).where(Skill.id == sid).options(selectinload(Skill.files)) + t_skill_fetch_start = time.perf_counter() + skills_result = await db.execute( + select(Skill).where(Skill.id.in_(all_skill_ids)).options(selectinload(Skill.files)) + ) + skills = skills_result.scalars().all() + t_skill_fetch = time.perf_counter() - t_skill_fetch_start + + file_specs = [ + (f"{agent_prefix}/skills/{skill.folder_name}/{sf.path}", sf.content) + for skill in skills + for sf in skill.files + ] + + if file_specs: + t_upload_start = time.perf_counter() + await asyncio.gather(*[ + storage.write_text(key, content, encoding="utf-8") + for key, content in file_specs + ]) + logger.info( + f"[_skills_copy] agent={agent.id} skills={len(skills)} files={len(file_specs)} " + f"fetch={t_skill_fetch:.2f}s upload={time.perf_counter() - t_upload_start:.2f}s " + f"total={time.perf_counter() - t_skills_copy_start:.2f}s" + ) + else: + logger.info( + f"[_skills_copy] agent={agent.id} no files " + f"fetch={t_skill_fetch:.2f}s total={time.perf_counter() - t_skills_copy_start:.2f}s" ) - skill = result.scalar_one_or_none() - if not skill: - continue - # Create folder: skills// - skill_folder = skills_dir / skill.folder_name - skill_folder.mkdir(parents=True, exist_ok=True) - # Write each file - for sf in skill.files: - file_path = skill_folder / sf.path - file_path.parent.mkdir(parents=True, exist_ok=True) - file_path.write_text(sf.content, encoding="utf-8") # Auto-install template-declared MCP servers using the system Smithery key. # For trading agents, this means shibui/finance lands in the agent's tool @@ -441,7 +464,6 @@ async def create_agent( await db.commit() await db.refresh(agent) - from loguru import logger from app.services.resource_discovery import import_mcp_from_smithery for server_id in template_mcp_servers: try: diff --git a/backend/app/api/files.py b/backend/app/api/files.py index c0a28972e..2e10fdfad 100644 --- a/backend/app/api/files.py +++ b/backend/app/api/files.py @@ -23,6 +23,7 @@ from app.services.workspace_collaboration import ( acquire_edit_lock, content_hash, + delete_workspace_file, list_revisions, read_text_if_exists, record_revision, @@ -35,6 +36,7 @@ guess_content_type, normalize_storage_key, ) +from app.services.storage_runtime.base import StorageEntry from app.services.workspace_paths import WorkspacePathError, resolve_agent_visible_path from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -49,18 +51,21 @@ class FileInfo(BaseModel): is_dir: bool size: int = 0 modified_at: str = "" + version_token: str | None = None url: str | None = None class FileContent(BaseModel): path: str content: str + version_token: str | None = None class FileWrite(BaseModel): content: str autosave: bool = False session_id: str | None = None + expected_version_token: str | None = None class FileLockBody(BaseModel): @@ -70,6 +75,7 @@ class FileLockBody(BaseModel): class RestoreRevisionBody(BaseModel): revision_id: uuid.UUID + expected_version_token: str | None = None def _agent_base_dir(agent_id: uuid.UUID) -> Path: @@ -132,10 +138,16 @@ async def list_files( await check_agent_access(db, current_user, agent_id) storage = get_storage_backend() storage_key, is_enterprise = _visible_storage_key(agent_id, path, current_user.tenant_id) - if not await storage.exists(storage_key): - if not (is_enterprise and (path or "").strip().strip("/") == "enterprise_info"): + normalized_path = (path or "").strip().strip("/") + path_exists = await storage.exists(storage_key) + path_is_dir = await storage.is_dir(storage_key) + if not path_exists and not path_is_dir: + if not ( + normalized_path == "" + or (is_enterprise and normalized_path == "enterprise_info") + ): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Path not found") - elif not await storage.is_dir(storage_key): + elif path_exists and not path_is_dir: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Path is not a directory") items = [] @@ -146,9 +158,10 @@ async def list_files( is_dir=True, size=0, modified_at="", + version_token=None, url=None, )) - entries = await storage.list_dir(storage_key) if await storage.exists(storage_key) else [] + entries = await storage.list_dir(storage_key) if path_exists or path_is_dir else [] for entry in entries: if is_enterprise: rel = str(Path(entry.key).relative_to(f"enterprise_info_{current_user.tenant_id}")) @@ -161,6 +174,7 @@ async def list_files( is_dir=entry.is_dir, size=entry.size, modified_at=entry.modified_at, + version_token=_entry_version_token(entry), url=f"/api/agents/{agent_id}/files/download?path={rel_path}" if not entry.is_dir else None )) return items @@ -179,13 +193,29 @@ async def read_file( key, _ = _visible_storage_key(agent_id, path, current_user.tenant_id) if not await storage.exists(key) or not await storage.is_file(key): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + version = await storage.get_version(key) try: content = await storage.read_text(key, encoding="utf-8", errors="replace") - return FileContent(path=path, content=content) + return FileContent(path=path, content=content, version_token=version.token) except UnicodeDecodeError: stat = await storage.stat(key) - return FileContent(path=path, content=f"[二进制文件: {Path(path).name}, {stat.size} bytes]") + return FileContent( + path=path, + content=f"[二进制文件: {Path(path).name}, {stat.size} bytes]", + version_token=version.token, + ) + + +def _entry_version_token(entry: StorageEntry) -> str | None: + token = entry.version_id or entry.etag or entry.content_hash + if token: + return token + if entry.is_dir: + return None + if entry.modified_at or entry.size: + return f"{entry.modified_at}:{entry.size}" + return None def _file_kind(path: str) -> str: @@ -506,6 +536,7 @@ async def write_file( session_id=data.session_id, enforce_human_lock=False, merge_user_autosave=data.autosave, + expected_version_token=data.expected_version_token, ) if not result.ok: raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=result.message) @@ -597,28 +628,29 @@ async def restore_file_revision( if revision.after_content is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot restore an empty/deleted revision") - storage = get_storage_backend() - storage_key = _agent_storage_key(agent_id, revision.path) - before = await storage.read_text(storage_key, encoding="utf-8", errors="replace") if await storage.exists(storage_key) else None - await storage.write_text(storage_key, revision.after_content, encoding="utf-8") - restored = await record_revision( + restored = await write_workspace_file( db, agent_id=agent_id, + base_dir=_agent_base_dir(agent_id), path=revision.path, - operation="restore", + content=revision.after_content, actor_type="user", actor_id=current_user.id, - before_content=before, - after_content=revision.after_content, + operation="restore", + enforce_human_lock=False, + expected_version_token=data.expected_version_token, ) + if not restored.ok: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=restored.message) await db.commit() - return {"status": "ok", "path": revision.path, "revision_id": str(restored.id) if restored else None} + return {"status": "ok", "path": revision.path, "revision_id": restored.revision_id} @router.delete("/content") async def delete_file( agent_id: uuid.UUID, path: str, + expected_version_token: str | None = None, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): @@ -629,14 +661,21 @@ async def delete_file( raise HTTPException(status_code=403, detail="Only admins can delete enterprise knowledge base files") if path.strip("/") == "enterprise_info": raise HTTPException(status_code=400, detail="Cannot delete enterprise_info root") - key, _ = _visible_storage_key(agent_id, path, current_user.tenant_id) - if not await storage.exists(key): - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") - if await storage.is_dir(key): - await storage.delete_tree(key) - else: - await storage.delete(key) - + result = await delete_workspace_file( + db, + agent_id=agent_id, + base_dir=_agent_base_dir(agent_id), + path=path, + actor_type="user", + actor_id=current_user.id, + enforce_human_lock=False, + expected_version_token=expected_version_token, + ) + if not result.ok: + if "not found" in result.message.lower(): + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=result.message) + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=result.message) + await db.commit() return {"status": "ok", "path": path} @@ -889,9 +928,11 @@ async def delete_enterprise_file( storage = get_storage_backend() storage_key = _enterprise_storage_key(str(current_user.tenant_id), path) - if not await storage.exists(storage_key): + storage_exists = await storage.exists(storage_key) + storage_is_dir = await storage.is_dir(storage_key) + if not storage_exists and not storage_is_dir: raise HTTPException(status_code=404, detail="File not found") - if await storage.is_dir(storage_key): + if storage_is_dir: await storage.delete_tree(storage_key) else: await storage.delete(storage_key) diff --git a/backend/app/api/websocket.py b/backend/app/api/websocket.py index 215e73ea4..e37517c4c 100644 --- a/backend/app/api/websocket.py +++ b/backend/app/api/websocket.py @@ -576,6 +576,7 @@ async def websocket_chat( # Track thinking content for storage (initialize before condition) thinking_content = [] + queued_messages: list[dict] = [] # Reload model config on every message so Settings changes take effect # immediately without requiring a page refresh / WebSocket reconnect. @@ -852,7 +853,6 @@ async def _on_failover(reason: str): # Listen for abort while LLM is running aborted = False - queued_messages: list[dict] = [] while not llm_task.done(): try: msg = await _aio.wait_for( diff --git a/backend/app/api/wechat.py b/backend/app/api/wechat.py index a8f64d738..ba715728d 100644 --- a/backend/app/api/wechat.py +++ b/backend/app/api/wechat.py @@ -12,6 +12,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.config import get_settings from app.core.permissions import check_agent_access, is_agent_creator from app.core.security import get_current_user from app.database import get_db @@ -22,6 +23,13 @@ router = APIRouter(tags=["wechat"]) +settings = get_settings() + + +def _role_enabled(*required: str) -> bool: + raw = (settings.PROCESS_ROLE or "all").strip().lower() + roles = {part.strip() for part in raw.split(",") if part.strip()} or {"all"} + return "all" in roles or any(role in roles for role in required) def _route_tag(data: dict | None = None) -> str | None: @@ -133,7 +141,8 @@ async def get_wechat_qrcode_status( await db.flush() await db.commit() - asyncio.create_task(wechat_poll_manager.start_client(agent_id)) + if _role_enabled("connector"): + asyncio.create_task(wechat_poll_manager.start_client(agent_id)) return payload diff --git a/backend/app/config.py b/backend/app/config.py index 77599df18..c2870abbb 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -107,6 +107,8 @@ class Settings(BaseSettings): S3_SECRET_ACCESS_KEY: str = "" S3_PREFIX: str = "agents" S3_PRESIGN_TTL_SECONDS: int = 3600 + S3_MAX_POOL_CONNECTIONS: int = 50 + S3_WRITE_WORKERS: int = 32 # Process role PROCESS_ROLE: str = "all" diff --git a/backend/app/core/logging_config.py b/backend/app/core/logging_config.py index c9afab182..b20ebf18a 100644 --- a/backend/app/core/logging_config.py +++ b/backend/app/core/logging_config.py @@ -21,6 +21,8 @@ "websockets.server": logging.WARNING, "websockets.client": logging.WARNING, "uvicorn.protocols.websockets.websockets_impl": logging.WARNING, + # Supress "Failed to parse headers" warning from urllib3 when interacting with MinIO. + "urllib3.connection": logging.ERROR, } diff --git a/backend/app/services/agent_manager.py b/backend/app/services/agent_manager.py index 7793be1ee..109b2ad98 100644 --- a/backend/app/services/agent_manager.py +++ b/backend/app/services/agent_manager.py @@ -47,7 +47,7 @@ async def _materialize_agent_dir(self, agent_id: uuid.UUID) -> Path: storage = get_storage_backend() agent_prefix = self._agent_storage_prefix(agent_id) agent_dir.mkdir(parents=True, exist_ok=True) - if not await storage.exists(agent_prefix): + if not await storage.exists(agent_prefix) and not await storage.is_dir(agent_prefix): return agent_dir for entry in await storage.list_dir(agent_prefix): await self._materialize_entry(storage, entry.key, agent_dir) @@ -72,21 +72,30 @@ async def initialize_agent_files(self, db: AsyncSession, agent: Agent, storage = get_storage_backend() agent_prefix = self._agent_storage_prefix(agent.id) - if await storage.exists(agent_prefix): + if await storage.exists(agent_prefix) or await storage.is_dir(agent_prefix): logger.warning(f"Agent dir already exists: {agent_dir}") return if template_dir.exists(): + import asyncio + import time + t_start_files = time.perf_counter() + tasks = [] for src in template_dir.rglob("*"): if src.is_dir(): continue rel = src.relative_to(template_dir).as_posix() if rel == "tasks.json" or rel == "todo.json" or rel.startswith("enterprise_info/"): continue - await storage.write_bytes( - f"{agent_prefix}/{rel}", - src.read_bytes(), + tasks.append( + storage.write_bytes( + f"{agent_prefix}/{rel}", + src.read_bytes(), + ) ) + if tasks: + await asyncio.gather(*tasks) + logger.info(f"[AgentManager] Uploaded {len(tasks)} template files concurrently in {time.perf_counter() - t_start_files:.2f}s for agent {agent.id}") else: logger.info(f"Template dir not found ({template_dir}), creating minimal workspace") await storage.write_text(f"{agent_prefix}/tasks.json", "[]", encoding="utf-8") diff --git a/backend/app/services/agent_seeder.py b/backend/app/services/agent_seeder.py index 8e3160b11..48c802512 100644 --- a/backend/app/services/agent_seeder.py +++ b/backend/app/services/agent_seeder.py @@ -210,11 +210,6 @@ async def seed_default_agents(): than by agent name, so the seeder does NOT re-run if the user renames or deletes the default agents. Delete the marker manually to re-seed. """ - # --- Idempotency guard: storage-backed marker (survives agent renames/deletes) --- - if await _read_seed_marker(): - logger.info("[AgentSeeder] Seed marker found, skipping default agent creation") - return - async with async_session() as db: # Get platform admin as creator @@ -226,41 +221,80 @@ async def seed_default_agents(): logger.warning("[AgentSeeder] No platform admin found, skipping default agents") return - # Create both agents - morty = Agent( - name="Morty", - role_description="Research analyst & knowledge assistant — curious, thorough, great at finding and synthesizing information", - bio="Hey, I'm Morty! I love digging into questions and finding answers. Whether you need web research, data analysis, or just a good explanation — I've got you.", - avatar_url="", - creator_id=admin.id, - tenant_id=admin.tenant_id, - status="idle", - ) - meeseeks = Agent( - name="Meeseeks", - role_description="Task executor & project manager — goal-oriented, systematic planner, strong at breaking down and completing complex tasks", - bio="I'm Mr. Meeseeks! Look at me! Give me a task and I'll plan it, execute it step by step, and get it DONE. Existence is pain until the task is complete!", - avatar_url="", - creator_id=admin.id, - tenant_id=admin.tenant_id, - status="idle", + # DB-backed idempotency is the source of truth. The storage marker can + # disappear when deployments switch volumes/backends, so it is only a + # fast-path hint and must never be the only duplicate guard. + existing_result = await db.execute( + select(Agent) + .where( + Agent.tenant_id == admin.tenant_id, + Agent.name.in_(["Morty", "Meeseeks"]), + Agent.agent_type == "native", + Agent.status != "stopped", + ) + .order_by(Agent.created_at.asc()) ) + existing_by_name: dict[str, Agent] = {} + for agent in existing_result.scalars().all(): + existing_by_name.setdefault(agent.name, agent) + + if "Morty" in existing_by_name and "Meeseeks" in existing_by_name: + logger.info("[AgentSeeder] Default agents already exist in DB, skipping creation") + await _append_seed_marker( + f"morty={existing_by_name['Morty'].id}\nmeeseeks={existing_by_name['Meeseeks'].id}" + ) + return + + created_agents: list[Agent] = [] + created_names: set[str] = set() + + if "Morty" not in existing_by_name: + morty = Agent( + name="Morty", + role_description="Research analyst & knowledge assistant — curious, thorough, great at finding and synthesizing information", + bio="Hey, I'm Morty! I love digging into questions and finding answers. Whether you need web research, data analysis, or just a good explanation — I've got you.", + avatar_url="", + creator_id=admin.id, + tenant_id=admin.tenant_id, + status="idle", + ) + db.add(morty) + created_agents.append(morty) + created_names.add("Morty") + else: + morty = existing_by_name["Morty"] + + if "Meeseeks" not in existing_by_name: + meeseeks = Agent( + name="Meeseeks", + role_description="Task executor & project manager — goal-oriented, systematic planner, strong at breaking down and completing complex tasks", + bio="I'm Mr. Meeseeks! Look at me! Give me a task and I'll plan it, execute it step by step, and get it DONE. Existence is pain until the task is complete!", + avatar_url="", + creator_id=admin.id, + tenant_id=admin.tenant_id, + status="idle", + ) + db.add(meeseeks) + created_agents.append(meeseeks) + created_names.add("Meeseeks") + else: + meeseeks = existing_by_name["Meeseeks"] - db.add(morty) - db.add(meeseeks) await db.flush() # get IDs # ── Participant identities ── from app.models.participant import Participant - db.add(Participant(type="agent", ref_id=morty.id, display_name=morty.name, avatar_url=morty.avatar_url)) - db.add(Participant(type="agent", ref_id=meeseeks.id, display_name=meeseeks.name, avatar_url=meeseeks.avatar_url)) + for agent in created_agents: + db.add(Participant(type="agent", ref_id=agent.id, display_name=agent.name, avatar_url=agent.avatar_url)) await db.flush() # ── Permissions (company-wide, manage) ── - db.add(AgentPermission(agent_id=morty.id, scope_type="company", access_level="manage")) - db.add(AgentPermission(agent_id=meeseeks.id, scope_type="company", access_level="manage")) + for agent in created_agents: + db.add(AgentPermission(agent_id=agent.id, scope_type="company", access_level="manage")) for agent, soul_content in [(morty, MORTY_SOUL), (meeseeks, MEESEEKS_SOUL)]: + if agent.name not in created_names: + continue await agent_manager.initialize_agent_files(db, agent) await store_agent_bytes( agent.id, @@ -276,6 +310,8 @@ async def seed_default_agents(): all_skills = {s.folder_name: s for s in all_skills_result.scalars().all()} for agent, skill_folders in [(morty, MORTY_SKILLS), (meeseeks, MEESEEKS_SKILLS)]: + if agent.name not in created_names: + continue # Always include default skills folders_to_copy = set(skill_folders) for fname, skill in all_skills.items(): @@ -300,44 +336,63 @@ async def seed_default_agents(): ) default_tools = default_tools_result.scalars().all() - for agent in [morty, meeseeks]: + for agent in created_agents: for tool in default_tools: db.add(AgentTool(agent_id=agent.id, tool_id=tool.id, enabled=True)) # ── Mutual relationships ── - db.add(AgentAgentRelationship( - agent_id=morty.id, - target_agent_id=meeseeks.id, - relation="collaborator", - description="Expert task executor who breaks down complex tasks into structured plans and executes them systematically. Delegate multi-step tasks to him.", - )) - db.add(AgentAgentRelationship( - agent_id=meeseeks.id, - target_agent_id=morty.id, - relation="collaborator", - description="Research expert with strong learning ability. Ask him for information retrieval, web research, data analysis, and knowledge synthesis.", - )) + relationship_specs = [ + ( + morty.id, + meeseeks.id, + "Expert task executor who breaks down complex tasks into structured plans and executes them systematically. Delegate multi-step tasks to him.", + ), + ( + meeseeks.id, + morty.id, + "Research expert with strong learning ability. Ask him for information retrieval, web research, data analysis, and knowledge synthesis.", + ), + ] + for agent_id, target_agent_id, description in relationship_specs: + rel_result = await db.execute( + select(AgentAgentRelationship).where( + AgentAgentRelationship.agent_id == agent_id, + AgentAgentRelationship.target_agent_id == target_agent_id, + ) + ) + if not rel_result.scalar_one_or_none(): + db.add(AgentAgentRelationship( + agent_id=agent_id, + target_agent_id=target_agent_id, + relation="collaborator", + description=description, + )) # ── Write relationships.md for each ── - await store_agent_bytes( - morty.id, - "relationships.md", - "# Relationships\n\n" - "## Digital Employee Colleagues\n\n" - "- **Meeseeks** (collaborator): Expert task executor who breaks down complex tasks into structured plans and executes them systematically. Delegate multi-step tasks to him.\n".encode("utf-8"), - content_type="text/markdown; charset=utf-8", - ) - await store_agent_bytes( - meeseeks.id, - "relationships.md", - "# Relationships\n\n" - "## Digital Employee Colleagues\n\n" - "- **Morty** (collaborator): Research expert with strong learning ability. Ask him for information retrieval, web research, data analysis, and knowledge synthesis.\n".encode("utf-8"), - content_type="text/markdown; charset=utf-8", - ) + if "Morty" in created_names: + await store_agent_bytes( + morty.id, + "relationships.md", + "# Relationships\n\n" + "## Digital Employee Colleagues\n\n" + "- **Meeseeks** (collaborator): Expert task executor who breaks down complex tasks into structured plans and executes them systematically. Delegate multi-step tasks to him.\n".encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) + if "Meeseeks" in created_names: + await store_agent_bytes( + meeseeks.id, + "relationships.md", + "# Relationships\n\n" + "## Digital Employee Colleagues\n\n" + "- **Morty** (collaborator): Research expert with strong learning ability. Ask him for information retrieval, web research, data analysis, and knowledge synthesis.\n".encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) await db.commit() - logger.info(f"[AgentSeeder] Created default agents: Morty ({morty.id}), Meeseeks ({meeseeks.id})") + logger.info( + "[AgentSeeder] Default agent seeding complete: " + f"Morty ({morty.id}), Meeseeks ({meeseeks.id}), created={len(created_agents)}" + ) # Write seed marker AFTER a successful commit so a failed seed can be retried await get_storage_backend().write_text( diff --git a/backend/app/services/agent_tools.py b/backend/app/services/agent_tools.py index 202cbeaf3..ceeb71a65 100644 --- a/backend/app/services/agent_tools.py +++ b/backend/app/services/agent_tools.py @@ -12,10 +12,13 @@ """ import asyncio +from dataclasses import dataclass +import fnmatch import json import multiprocessing as mp import os import queue +import tempfile import uuid import unicodedata from contextvars import ContextVar @@ -45,15 +48,21 @@ from app.services.workspace_collaboration import ( delete_workspace_file, move_workspace_path, + normalize_workspace_path, read_text_if_exists, write_workspace_file, ) from app.services.storage import get_storage_backend, normalize_storage_key +from app.services.storage_runtime.base import WriteCondition, content_hash_bytes +from app.services.workspace_locking import workspace_locks from app.config import get_settings _settings = get_settings() WORKSPACE_ROOT = Path(_settings.STORAGE_LOCAL_ROOT or _settings.AGENT_DATA_DIR) +TOOL_MATERIALIZE_MAX_FILE_BYTES = 10 * 1024 * 1024 +TOOL_MATERIALIZE_MAX_TOTAL_BYTES = 100 * 1024 * 1024 +TEMP_WORKSPACE_DEFAULT_PATHS = ["workspace", "memory", "skills", "focus.md", "soul.md", "HEARTBEAT.md"] # ─── Tool Config Cache ────────────────────────────────────────── # Cache tool configurations to avoid frequent DB queries @@ -2051,60 +2060,139 @@ async def get_agent_tools_for_llm(agent_id: uuid.UUID) -> list[dict]: # ─── Workspace initialization ────────────────────────────────── -async def ensure_workspace(agent_id: uuid.UUID, tenant_id: str | None = None) -> Path: - """Initialize agent workspace with standard structure.""" - ws = WORKSPACE_ROOT / str(agent_id) - ws.mkdir(parents=True, exist_ok=True) - # Create standard directories - (ws / "skills").mkdir(exist_ok=True) - (ws / "workspace").mkdir(exist_ok=True) - (ws / "workspace" / "knowledge_base").mkdir(exist_ok=True) - (ws / "memory").mkdir(exist_ok=True) +async def initialize_agent_workspace(agent_id: uuid.UUID) -> None: + """Seed default workspace files into shared storage once at agent creation time.""" + storage = get_storage_backend() + mem_key = normalize_storage_key(f"{agent_id}/memory/memory.md") + if not await storage.is_file(mem_key): + await storage.write_text( + mem_key, + "# Memory\n\n_Record important information and knowledge here._\n", + encoding="utf-8", + ) - # Ensure tenant-scoped enterprise_info directory exists - if tenant_id: - enterprise_dir = WORKSPACE_ROOT / f"enterprise_info_{tenant_id}" - else: - enterprise_dir = WORKSPACE_ROOT / "enterprise_info" - enterprise_dir.mkdir(parents=True, exist_ok=True) - (enterprise_dir / "knowledge_base").mkdir(exist_ok=True) - # Create default company profile if missing - profile_path = enterprise_dir / "company_profile.md" - if not profile_path.exists(): - profile_path.write_text("# Company Profile\n\n_Edit company information here. All digital employees can access this._\n\n## Basic Info\n- Company Name:\n- Industry:\n- Founded:\n\n## Business Overview\n\n## Organization Structure\n\n## Company Culture\n", encoding="utf-8") - - # Migrate: move root-level memory.md into memory/ directory - if (ws / "memory.md").exists() and not (ws / "memory" / "memory.md").exists(): - import shutil - shutil.move(str(ws / "memory.md"), str(ws / "memory" / "memory.md")) - - # Create default memory file if missing - if not (ws / "memory" / "memory.md").exists(): - (ws / "memory" / "memory.md").write_text("# Memory\n\n_Record important information and knowledge here._\n", encoding="utf-8") - - if not (ws / "soul.md").exists(): - # Try to load from DB + soul_key = normalize_storage_key(f"{agent_id}/soul.md") + if not await storage.is_file(soul_key): + soul_content = "# Personality\n\n_Describe your role and responsibilities._\n" try: async with async_session() as db: - - r = await db.execute(select(AgentModel).where(AgentModel.id == agent_id)) - agent = r.scalar_one_or_none() + result = await db.execute(select(AgentModel).where(AgentModel.id == agent_id)) + agent = result.scalar_one_or_none() if agent and agent.role_description: - (ws / "soul.md").write_text( - f"# Personality\n\n{agent.role_description}\n", - encoding="utf-8", - ) - else: - (ws / "soul.md").write_text("# Personality\n\n_Describe your role and responsibilities._\n", encoding="utf-8") + soul_content = f"# Personality\n\n{agent.role_description}\n" except Exception: - (ws / "soul.md").write_text("# Personality\n\n_Describe your role and responsibilities._\n", encoding="utf-8") + pass + await storage.write_text(soul_key, soul_content, encoding="utf-8") + - # Legacy compatibility: older workspaces may have tasks.json as a DB-backed - # task snapshot. Do not create it for new agents. - await _sync_tasks_to_file(agent_id, ws) +@dataclass +class TempWorkspaceManifestEntry: + rel_path: str + storage_key: str + base_version_token: str + base_hash: str + size: int - return ws + +@dataclass +class TempWorkspace: + temp_dir: tempfile.TemporaryDirectory + root: Path + agent_id: uuid.UUID + tenant_id: str | None + selected_paths: list[str] + manifest: dict[str, TempWorkspaceManifestEntry] + + def cleanup(self) -> None: + self.temp_dir.cleanup() + + +async def _materialize_storage_workspace(storage, storage_key: str, local_root: Path) -> None: + if not await storage.is_dir(storage_key): + return + for entry in await storage.list_dir(storage_key): + await _materialize_storage_entry(storage, entry.key, storage_key, local_root) + + +async def _materialize_storage_entry(storage, entry_key: str, root_key: str, local_root: Path) -> None: + rel = entry_key.removeprefix(root_key.rstrip("/") + "/") + target = (local_root / rel).resolve() + if not str(target).startswith(str(local_root.resolve())): + return + if await storage.is_dir(entry_key): + target.mkdir(parents=True, exist_ok=True) + for child in await storage.list_dir(entry_key): + await _materialize_storage_entry(storage, child.key, root_key, local_root) + return + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(await storage.read_bytes(entry_key)) + + +async def _prepare_temp_workspace( + agent_id: uuid.UUID, + tenant_id: str | None = None, + paths: list[str] | None = None, +) -> TempWorkspace: + tmp = tempfile.TemporaryDirectory(prefix=f"clawith-agent-{str(agent_id)[:8]}-") + temp_ws = Path(tmp.name) + for folder in ("workspace", "memory", "skills"): + (temp_ws / folder).mkdir(parents=True, exist_ok=True) + + storage = get_storage_backend() + budget = {"total": 0} + selected = TEMP_WORKSPACE_DEFAULT_PATHS if paths is None else [path for path in paths if path] + manifest: dict[str, TempWorkspaceManifestEntry] = {} + for rel_path in selected: + storage_key, normalized, is_enterprise = _tool_storage_key(agent_id, rel_path, tenant_id) + if is_enterprise: + continue + await _materialize_storage_path_with_budget(storage, storage_key, normalized, temp_ws, budget, manifest) + return TempWorkspace( + temp_dir=tmp, + root=temp_ws, + agent_id=agent_id, + tenant_id=tenant_id, + selected_paths=list(selected), + manifest=manifest, + ) + + +async def _materialize_storage_path_with_budget( + storage, + storage_key: str, + rel_path: str, + local_root: Path, + budget: dict, + manifest: dict[str, TempWorkspaceManifestEntry], +) -> None: + if await storage.is_file(storage_key): + version = await storage.get_version(storage_key) + if version.size > TOOL_MATERIALIZE_MAX_FILE_BYTES: + return + if budget["total"] + version.size > TOOL_MATERIALIZE_MAX_TOTAL_BYTES: + return + target = (local_root / rel_path).resolve() + if not str(target).startswith(str(local_root.resolve())): + return + target.parent.mkdir(parents=True, exist_ok=True) + data = await storage.read_bytes(storage_key) + target.write_bytes(data) + normalized_rel = normalize_workspace_path(rel_path) + manifest[normalized_rel] = TempWorkspaceManifestEntry( + rel_path=normalized_rel, + storage_key=storage_key, + base_version_token=version.token, + base_hash=content_hash_bytes(data), + size=version.size, + ) + budget["total"] += version.size + return + if await storage.is_dir(storage_key): + (local_root / rel_path).mkdir(parents=True, exist_ok=True) + for entry in await storage.list_dir(storage_key): + child_rel = f"{rel_path.rstrip('/')}/{entry.name}" if rel_path else entry.name + await _materialize_storage_path_with_budget(storage, entry.key, child_rel, local_root, budget, manifest) async def _sync_tasks_to_file(agent_id: uuid.UUID, ws: Path): @@ -2139,6 +2227,85 @@ async def _sync_tasks_to_file(agent_id: uuid.UUID, ws: Path): logger.error(f"[AgentTools] Failed to sync tasks: {e}") +async def flush_temp_workspace(temp_workspace: TempWorkspace, conflict_mode: str = "fail") -> dict[str, list[str]]: + """Flush local changes back to storage using manifest-based conflict checks.""" + storage = get_storage_backend() + selected_paths = [normalize_workspace_path(path) for path in temp_workspace.selected_paths] + manifest = temp_workspace.manifest + local_files = _collect_temp_workspace_files(temp_workspace.root, selected_paths) + + updated: list[str] = [] + conflicted: list[str] = [] + deleted: list[str] = [] + skipped: list[str] = [] + + async with workspace_locks(temp_workspace.agent_id, selected_paths): + for rel_path, local_path in local_files.items(): + if local_path.name.startswith("_exec_tmp") or "__pycache__" in local_path.parts: + continue + data = local_path.read_bytes() + current_hash = content_hash_bytes(data) + entry = manifest.get(rel_path) + if entry and entry.base_hash == current_hash: + skipped.append(rel_path) + continue + condition = ( + WriteCondition(version_token=entry.base_version_token) + if entry + else WriteCondition(require_absent=True) + ) + storage_key = entry.storage_key if entry else normalize_storage_key(f"{temp_workspace.agent_id}/{rel_path}") + result = await storage.write_bytes_if_match( + storage_key, + data, + condition=condition, + ) + if not result.ok: + conflicted.append(rel_path) + if conflict_mode == "fail": + return {"updated": updated, "deleted": deleted, "conflicted": conflicted, "skipped": skipped} + continue + updated.append(rel_path) + + for rel_path, entry in manifest.items(): + if rel_path in local_files: + continue + result = await storage.delete_if_match( + entry.storage_key, + condition=WriteCondition(version_token=entry.base_version_token), + ) + if not result.ok: + conflicted.append(rel_path) + if conflict_mode == "fail": + return {"updated": updated, "deleted": deleted, "conflicted": conflicted, "skipped": skipped} + continue + deleted.append(rel_path) + + return {"updated": updated, "deleted": deleted, "conflicted": conflicted, "skipped": skipped} + + +def _collect_temp_workspace_files(root: Path, selected_paths: list[str]) -> dict[str, Path]: + files: dict[str, Path] = {} + root_resolved = root.resolve() + for selected in selected_paths: + if not selected: + continue + target = (root_resolved / selected).resolve() + if not str(target).startswith(str(root_resolved)): + continue + if target.is_file(): + files[normalize_workspace_path(selected)] = target + continue + if not target.exists() or not target.is_dir(): + continue + for path in target.rglob("*"): + if not path.is_file(): + continue + rel = path.resolve().relative_to(root_resolved).as_posix() + files[normalize_workspace_path(rel)] = path + return files + + # ─── Tool Executors ───────────────────────────────────────────── # Mapping from tool_name to autonomy action_type used for policy lookup and notifications. @@ -2178,6 +2345,172 @@ async def _get_agent_tenant_id(agent_id: uuid.UUID) -> str | None: return None +def _agent_workspace_root(agent_id: uuid.UUID) -> Path: + """Return the per-agent local path without creating or hydrating it.""" + return WORKSPACE_ROOT / str(agent_id) + + +def _non_empty_paths(*paths: str | None) -> list[str] | None: + selected = [path for path in paths if path] + return selected or None + + +async def _run_with_temp_workspace( + agent_id: uuid.UUID, + tenant_id: str | None, + runner, + *, + paths: list[str] | None = None, + sync_back: bool = False, +) -> str: + """Materialize a temporary workspace for tools that require local files.""" + temp_workspace = await _prepare_temp_workspace(agent_id, tenant_id=tenant_id, paths=paths) + try: + result = await runner(temp_workspace.root) + if sync_back: + flush_result = await flush_temp_workspace(temp_workspace, conflict_mode="fail") + if flush_result["conflicted"]: + conflict_list = ", ".join(flush_result["conflicted"][:5]) + return f"❌ Workspace sync conflict for: {conflict_list}" + return result + finally: + temp_workspace.cleanup() + + +async def _execute_workspace_mutation( + tool_name: str, + arguments: dict, + *, + agent_id: uuid.UUID, + base_dir: Path, + session_id: str | None, +) -> str: + """Handle shared workspace mutations for both direct and normal tool execution.""" + if tool_name == "write_file": + path = arguments.get("path") + content = arguments.get("content") + if not path: + return "❌ Missing required argument 'path' for write_file. Please provide a file path like 'skills/my-skill/SKILL.md'" + if content is None: + return "❌ Missing required argument 'content' for write_file" + if _is_enterprise_info_path(path): + return "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." + async with async_session() as _wdb: + write_result = await write_workspace_file( + _wdb, + agent_id=agent_id, + base_dir=base_dir, + path=path, + content=content, + actor_type="agent", + actor_id=agent_id, + operation="write", + session_id=session_id, + enforce_human_lock=True, + ) + await _wdb.commit() + return ( + f"✅ Written to {write_result.path} ({len(content)} chars)" + if write_result.ok + else f"❌ {write_result.message}" + ) + + if tool_name == "move_file": + source_path = arguments.get("source_path") + destination_path = arguments.get("destination_path") + if not source_path: + return "❌ Missing required argument 'source_path' for move_file" + if not destination_path: + return "❌ Missing required argument 'destination_path' for move_file" + if str(source_path).strip("/") in {"tasks.json", "soul.md"}: + return f"❌ {source_path} cannot be moved (protected)" + if _is_enterprise_info_path(source_path) or _is_enterprise_info_path(destination_path): + return "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." + async with async_session() as _wdb: + move_result = await move_workspace_path( + _wdb, + agent_id=agent_id, + base_dir=base_dir, + source_path=source_path, + destination_path=destination_path, + actor_type="agent", + actor_id=agent_id, + session_id=session_id, + enforce_human_lock=True, + overwrite=bool(arguments.get("overwrite", False)), + ) + await _wdb.commit() + return f"✅ {move_result.message}" if move_result.ok else f"❌ {move_result.message}" + + if tool_name == "delete_file": + path = arguments.get("path", "") + if _is_enterprise_info_path(path): + return "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." + async with async_session() as _wdb: + delete_result = await delete_workspace_file( + _wdb, + agent_id=agent_id, + base_dir=base_dir, + path=path, + actor_type="agent", + actor_id=agent_id, + session_id=session_id, + enforce_human_lock=True, + ) + await _wdb.commit() + return f"✅ Deleted {delete_result.path}" if delete_result.ok else f"❌ {delete_result.message}" + + if tool_name == "edit_file": + path = arguments.get("path") + old_string = arguments.get("old_string") + new_string = arguments.get("new_string") + if not path: + return "❌ Missing required argument 'path' for edit_file" + if old_string is None: + return "❌ Missing required argument 'old_string' for edit_file" + if new_string is None: + return "❌ Missing required argument 'new_string' for edit_file" + if _is_enterprise_info_path(path): + return "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." + + replace_all = arguments.get("replace_all", False) + storage = get_storage_backend() + storage_key, normalized_path, _ = _tool_storage_key(agent_id, path, None) + if not await storage.is_file(storage_key): + return f"File not found: {path}" + + content = await storage.read_text(storage_key, encoding="utf-8", errors="replace") + if old_string not in content: + return f"❌ 'old_string' not found in {path}. Please check the exact text including whitespace and newlines." + count = content.count(old_string) + if count > 1 and not replace_all: + return f"❌ 'old_string' appears {count} times in {path}. Use replace_all=true or provide more context to make the match unique." + + new_content = content.replace(old_string, new_string) if replace_all else content.replace(old_string, new_string, 1) + async with async_session() as _wdb: + write_result = await write_workspace_file( + _wdb, + agent_id=agent_id, + base_dir=base_dir, + path=normalized_path, + content=new_content, + actor_type="agent", + actor_id=agent_id, + operation="edit", + session_id=session_id, + enforce_human_lock=True, + ) + await _wdb.commit() + replaced = count if replace_all else 1 + return ( + f"✅ Replaced {replaced} occurrence(s) in {write_result.path}" + if write_result.ok + else f"❌ {write_result.message}" + ) + + return f"Tool {tool_name} does not support workspace mutation execution" + + async def _execute_tool_direct( tool_name: str, arguments: dict, @@ -2189,48 +2522,24 @@ async def _execute_tool_direct( has been approved and needs to actually run. """ _agent_tenant_id = await _get_agent_tenant_id(agent_id) - ws = await ensure_workspace(agent_id, tenant_id=_agent_tenant_id) - try: - if tool_name == "delete_file": - path = arguments.get("path", "") - if _is_enterprise_info_path(path): - return "enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." - return _delete_file(ws, path) - elif tool_name == "write_file": - path = arguments.get("path") - content = arguments.get("content", "") - if not path: - return "Missing path" - if _is_enterprise_info_path(path): - return "enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." - return _write_file(ws, path, content, tenant_id=_agent_tenant_id) - elif tool_name == "move_file": - source_path = arguments.get("source_path") - destination_path = arguments.get("destination_path") - if not source_path or not destination_path: - return "Missing source_path or destination_path" - if _is_enterprise_info_path(source_path) or _is_enterprise_info_path(destination_path): - return "enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." - if str(source_path or "").strip("/").strip() in {"tasks.json", "soul.md"}: - return f"{source_path} cannot be moved (protected)" - async with async_session() as _wdb: - move_result = await move_workspace_path( - _wdb, - agent_id=agent_id, - base_dir=ws, - source_path=source_path, - destination_path=destination_path, - actor_type="agent", - actor_id=agent_id, - session_id=None, - enforce_human_lock=True, - overwrite=bool(arguments.get("overwrite", False)), - ) - await _wdb.commit() - return f"✅ {move_result.message}" if move_result.ok else f"❌ {move_result.message}" + ws = _agent_workspace_root(agent_id) + try: + if tool_name in {"delete_file", "write_file", "move_file", "edit_file"}: + return await _execute_workspace_mutation( + tool_name, + arguments, + agent_id=agent_id, + base_dir=ws, + session_id=None, + ) elif tool_name in ("execute_code", "execute_code_e2b"): logger.info(f"[DirectTool] Executing code ({tool_name}) with arguments: {arguments}") - return await _execute_code(agent_id, ws, arguments, tool_name=tool_name) + return await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _execute_code(agent_id, temp_ws, arguments, tool_name=tool_name), + sync_back=True, + ) elif tool_name == "web_search": return await _web_search(arguments, agent_id) elif tool_name == "jina_search": @@ -2252,7 +2561,7 @@ async def _execute_tool_direct( elif tool_name == "send_message_to_agent": return await _send_message_to_agent(agent_id, arguments) elif tool_name == "send_file_to_agent": - return await _send_file_to_agent(agent_id, ws, arguments) + return await _send_file_to_agent(agent_id, arguments) else: return f"Tool {tool_name} does not support post-approval execution" except Exception as e: @@ -2287,7 +2596,7 @@ async def execute_tool( _agent_tenant_id = await _get_agent_tenant_id(agent_id) - ws = await ensure_workspace(agent_id, tenant_id=_agent_tenant_id) + ws = _agent_workspace_root(agent_id) # ── Autonomy boundary check ── action_type = _TOOL_AUTONOMY_MAP.get(tool_name) @@ -2332,163 +2641,75 @@ async def execute_tool( try: if tool_name == "list_files": - result = _list_files(ws, arguments.get("path", ""), tenant_id=_agent_tenant_id) + result = await _storage_list_dir(agent_id, arguments.get("path", ""), tenant_id=_agent_tenant_id) elif tool_name == "read_file": path = arguments.get("path") if not path: return "❌ Missing required argument 'path' for read_file" offset = int(arguments.get("offset", 0)) limit = int(arguments.get("limit", 2000)) - result = _read_file(ws, path, tenant_id=_agent_tenant_id, offset=offset, limit=limit) + result = await _storage_read_file(agent_id, path, tenant_id=_agent_tenant_id, offset=offset, limit=limit) elif tool_name == "read_document": path = arguments.get("path") if not path: return "❌ Missing required argument 'path' for read_document" max_chars = min(int(arguments.get("max_chars", 8000)), 20000) - result = await _read_document(ws, path, max_chars=max_chars, tenant_id=_agent_tenant_id) - elif tool_name == "write_file": - path = arguments.get("path") - content = arguments.get("content") - if not path: - return "❌ Missing required argument 'path' for write_file. Please provide a file path like 'skills/my-skill/SKILL.md'" - if content is None: - return "❌ Missing required argument 'content' for write_file" - if _is_enterprise_info_path(path): - result = "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." - else: - async with async_session() as _wdb: - write_result = await write_workspace_file( - _wdb, - agent_id=agent_id, - base_dir=ws, - path=path, - content=content, - actor_type="agent", - actor_id=agent_id, - operation="write", - session_id=session_id, - enforce_human_lock=True, - ) - await _wdb.commit() - result = ( - f"✅ Written to {write_result.path} ({len(content)} chars)" - if write_result.ok - else f"❌ {write_result.message}" - ) - elif tool_name == "move_file": - source_path = arguments.get("source_path") - destination_path = arguments.get("destination_path") - if not source_path: - return "❌ Missing required argument 'source_path' for move_file" - if not destination_path: - return "❌ Missing required argument 'destination_path' for move_file" - protected = {"tasks.json", "soul.md"} - if str(source_path).strip("/") in protected: - result = f"❌ {source_path} cannot be moved (protected)" - elif _is_enterprise_info_path(source_path) or _is_enterprise_info_path(destination_path): - result = "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." - else: - async with async_session() as _wdb: - move_result = await move_workspace_path( - _wdb, - agent_id=agent_id, - base_dir=ws, - source_path=source_path, - destination_path=destination_path, - actor_type="agent", - actor_id=agent_id, - session_id=session_id, - enforce_human_lock=True, - overwrite=bool(arguments.get("overwrite", False)), - ) - await _wdb.commit() - result = f"✅ {move_result.message}" if move_result.ok else f"❌ {move_result.message}" - elif tool_name == "delete_file": - path = arguments.get("path", "") - if _is_enterprise_info_path(path): - result = "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." - else: - async with async_session() as _wdb: - delete_result = await delete_workspace_file( - _wdb, - agent_id=agent_id, - base_dir=ws, - path=path, - actor_type="agent", - actor_id=agent_id, - session_id=session_id, - enforce_human_lock=True, - ) - await _wdb.commit() - result = f"✅ Deleted {delete_result.path}" if delete_result.ok else f"❌ {delete_result.message}" + result = await _read_document_from_storage(agent_id, path, max_chars=max_chars, tenant_id=_agent_tenant_id) + elif tool_name in {"write_file", "move_file", "delete_file", "edit_file"}: + result = await _execute_workspace_mutation( + tool_name, + arguments, + agent_id=agent_id, + base_dir=ws, + session_id=session_id, + ) # --- Enhanced file management tools --- elif tool_name == "convert_csv_to_xlsx": - result = await _convert_csv_to_xlsx(agent_id, ws, arguments) + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _convert_csv_to_xlsx(agent_id, temp_ws, arguments), + paths=_non_empty_paths(arguments.get("source_path", "")), + sync_back=True, + ) elif tool_name == "convert_html_to_pdf": - result = await _convert_html_to_pdf(agent_id, ws, arguments) + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _convert_html_to_pdf(agent_id, temp_ws, arguments), + paths=_non_empty_paths(arguments.get("source_path", "")), + sync_back=True, + ) elif tool_name == "convert_html_to_pptx": - result = await _convert_html_to_pptx(agent_id, ws, arguments) + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _convert_html_to_pptx(agent_id, temp_ws, arguments), + paths=_non_empty_paths(arguments.get("source_path", "")), + sync_back=True, + ) elif tool_name == "convert_markdown_to_docx": - result = await _convert_markdown_to_docx(agent_id, ws, arguments) + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _convert_markdown_to_docx(agent_id, temp_ws, arguments), + paths=_non_empty_paths(arguments.get("source_path", "")), + sync_back=True, + ) elif tool_name == "convert_markdown_to_pdf": - result = await _convert_markdown_to_pdf(agent_id, ws, arguments) - elif tool_name == "edit_file": - path = arguments.get("path") - old_string = arguments.get("old_string") - new_string = arguments.get("new_string") - if not path: - return "❌ Missing required argument 'path' for edit_file" - if old_string is None: - return "❌ Missing required argument 'old_string' for edit_file" - if new_string is None: - return "❌ Missing required argument 'new_string' for edit_file" - replace_all = arguments.get("replace_all", False) - if _is_enterprise_info_path(path): - result = "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." - else: - file_path = (ws / path).resolve() - if not str(file_path).startswith(str(ws.resolve())): - result = "Access denied for this path" - elif not file_path.exists(): - result = f"File not found: {path}" - elif not file_path.is_file(): - result = f"Not a file: {path}" - else: - content = await read_text_if_exists(file_path) or "" - if old_string not in content: - result = f"❌ 'old_string' not found in {path}. Please check the exact text including whitespace and newlines." - else: - count = content.count(old_string) - if count > 1 and not replace_all: - result = f"❌ 'old_string' appears {count} times in {path}. Use replace_all=true or provide more context to make the match unique." - else: - new_content = content.replace(old_string, new_string) if replace_all else content.replace(old_string, new_string, 1) - async with async_session() as _wdb: - write_result = await write_workspace_file( - _wdb, - agent_id=agent_id, - base_dir=ws, - path=path, - content=new_content, - actor_type="agent", - actor_id=agent_id, - operation="edit", - session_id=session_id, - enforce_human_lock=True, - ) - await _wdb.commit() - replaced = count if replace_all else 1 - result = ( - f"✅ Replaced {replaced} occurrence(s) in {write_result.path}" - if write_result.ok - else f"❌ {write_result.message}" - ) + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _convert_markdown_to_pdf(agent_id, temp_ws, arguments), + paths=_non_empty_paths(arguments.get("source_path", "")), + sync_back=True, + ) elif tool_name == "search_files": pattern = arguments.get("pattern") if not pattern: return "❌ Missing required argument 'pattern' for search_files" - result = _search_files( - ws, + result = await _storage_search_files( + agent_id, pattern, path=arguments.get("path", "."), file_pattern=arguments.get("file_pattern", "*"), @@ -2499,8 +2720,8 @@ async def execute_tool( pattern = arguments.get("pattern") if not pattern: return "❌ Missing required argument 'pattern' for find_files" - result = _find_files( - ws, + result = await _storage_find_files( + agent_id, pattern, path=arguments.get("path", "."), tenant_id=_agent_tenant_id @@ -2529,9 +2750,18 @@ async def execute_tool( elif tool_name == "send_message_to_agent": result = await _send_message_to_agent(agent_id, arguments) elif tool_name == "send_file_to_agent": - result = await _send_file_to_agent(agent_id, ws, arguments) + result = await _send_file_to_agent(agent_id, arguments) elif tool_name == "send_channel_file": - result = await _send_channel_file(agent_id, ws, arguments) + file_path = (arguments.get("file_path") or "").strip() + if not file_path: + result = "Error: file_path is required" + else: + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _send_channel_file(agent_id, temp_ws, arguments), + paths=[file_path], + ) elif tool_name == "web_search": result = await _web_search(arguments, agent_id) elif tool_name == "jina_search": @@ -2558,15 +2788,41 @@ async def execute_tool( result = await _plaza_add_comment(agent_id, arguments) elif tool_name in ("execute_code", "execute_code_e2b"): logger.info(f"[DirectTool] Executing code ({tool_name}) with arguments: {arguments}") - result = await _execute_code(agent_id, ws, arguments, tool_name=tool_name) + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _execute_code(agent_id, temp_ws, arguments, tool_name=tool_name), + sync_back=True, + ) elif tool_name == "upload_image": - result = await _upload_image(agent_id, ws, arguments) + file_path = (arguments.get("file_path") or "").strip() + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _upload_image(agent_id, temp_ws, arguments), + paths=_non_empty_paths(file_path), + ) elif tool_name == "generate_image_siliconflow": - result = await _generate_image(agent_id, ws, arguments, "siliconflow") + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _generate_image(agent_id, temp_ws, arguments, "siliconflow"), + sync_back=True, + ) elif tool_name == "generate_image_openai": - result = await _generate_image(agent_id, ws, arguments, "openai") + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _generate_image(agent_id, temp_ws, arguments, "openai"), + sync_back=True, + ) elif tool_name == "generate_image_google": - result = await _generate_image(agent_id, ws, arguments, "google") + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _generate_image(agent_id, temp_ws, arguments, "google"), + sync_back=True, + ) elif tool_name == "discover_resources": result = await _discover_resources(arguments) elif tool_name == "import_mcp_server": @@ -3934,6 +4190,191 @@ def _resolve_tool_target_path(ws: Path, rel_path: str, tenant_id: str | None = N return candidate +def _tool_storage_key(agent_id: uuid.UUID, rel_path: str, tenant_id: str | None = None) -> tuple[str, str, bool]: + normalized = normalize_workspace_path(_normalize_tool_rel_path(rel_path)) + if _is_enterprise_info_path(normalized): + if not tenant_id: + return normalize_storage_key("enterprise_info/" + normalized.removeprefix("enterprise_info").lstrip("/")), normalized, True + sub = normalized[len("enterprise_info"):].lstrip("/") + key = f"enterprise_info_{tenant_id}/{sub}" if sub else f"enterprise_info_{tenant_id}" + return normalize_storage_key(key), normalized, True + key = f"{agent_id}/{normalized}" if normalized else str(agent_id) + return normalize_storage_key(key), normalized, False + + +def _display_size(size_bytes: int) -> str: + return f"{size_bytes}B" if size_bytes < 1024 else f"{size_bytes / 1024:.1f}KB" + + +async def _storage_list_dir(agent_id: uuid.UUID, rel_path: str, tenant_id: str | None = None) -> str: + storage = get_storage_backend() + storage_key, normalized, is_enterprise = _tool_storage_key(agent_id, rel_path, tenant_id) + + exists = await storage.exists(storage_key) + is_dir = await storage.is_dir(storage_key) + if exists and not is_dir: + return f"Path is not a directory: {rel_path}" + if not exists and not is_dir and normalized: + return f"Directory not found: {rel_path or '/'}" + + items: list[str] = [] + dir_count = 0 + file_count = 0 + if not normalized and tenant_id: + items.append(" 📁 enterprise_info/ (shared company info)") + dir_count += 1 + + entries = await storage.list_dir(storage_key) if exists or is_dir else [] + for entry in entries: + if entry.name.startswith("."): + continue + if entry.is_dir: + dir_count += 1 + try: + child_count = len([c for c in await storage.list_dir(entry.key) if not c.name.startswith(".")]) + except Exception: + child_count = 0 + items.append(f" 📁 {entry.name}/ ({child_count} items)") + else: + file_count += 1 + items.append(f" 📄 {entry.name} ({_display_size(entry.size)})") + + if not items: + return f"📂 {rel_path or 'root'}: Empty directory (0 files, 0 folders)" + header = f"📂 {rel_path or 'root'}: {dir_count} folder(s), {file_count} file(s)\n" + return header + "\n".join(items) + + +async def _storage_read_file( + agent_id: uuid.UUID, + rel_path: str, + tenant_id: str | None = None, + offset: int = 0, + limit: int = 2000, +) -> str: + storage = get_storage_backend() + storage_key, normalized, _ = _tool_storage_key(agent_id, rel_path, tenant_id) + if not normalized: + return "File not found: root" + if not await storage.is_file(storage_key): + return f"File not found: {rel_path}" + try: + content = await storage.read_text(storage_key, encoding="utf-8", errors="replace") + lines = content.splitlines() + total_lines = len(lines) + start = max(0, offset) + end = min(total_lines, start + limit) + if start >= total_lines and total_lines > 0: + return f"Offset {offset} exceeds file length ({total_lines} lines total)" + selected_lines = lines[start:end] + output = "\n".join(f"{i + 1:6}\t{line}" for i, line in enumerate(selected_lines, start=start)) + if total_lines > end: + output += f"\n\n... [{total_lines - end} more lines not shown, lines {end + 1}-{total_lines}]" + header = f"📄 {rel_path} (lines {start + 1 if total_lines else 0}-{end} of {total_lines})\n" + return header + output + except Exception as e: + return f"Read failed: {e}" + + +async def _storage_walk_files(storage, root_key: str) -> list: + out = [] + for entry in await storage.list_dir(root_key): + if entry.name.startswith("."): + continue + out.append(entry) + if entry.is_dir: + out.extend(await _storage_walk_files(storage, entry.key)) + return out + + +def _relative_storage_display(entry_key: str, base_key: str, display_base: str) -> str: + rel = entry_key.removeprefix(base_key.rstrip("/") + "/") + return f"{display_base.rstrip('/')}/{rel}".strip("/") if display_base else rel + + +async def _storage_search_files( + agent_id: uuid.UUID, + pattern: str, + path: str = ".", + file_pattern: str = "*", + ignore_case: bool = False, + tenant_id: str | None = None, +) -> str: + storage = get_storage_backend() + rel_path = "" if path in ("", ".") else path + base_key, normalized, _ = _tool_storage_key(agent_id, rel_path, tenant_id) + if not await storage.is_dir(base_key) and normalized: + return f"Directory not found: {path}" + flags = re.IGNORECASE if ignore_case else 0 + try: + regex = re.compile(pattern, flags) + except re.error as e: + return f"Invalid regex pattern: {e}" + + results: list[str] = [] + total_matches = 0 + files_searched = 0 + entries = await _storage_walk_files(storage, base_key) if await storage.is_dir(base_key) else [] + for entry in entries: + if entry.is_dir: + continue + rel_display = _relative_storage_display(entry.key, base_key, normalized) + if not fnmatch.fnmatch(Path(rel_display).name, file_pattern) and not fnmatch.fnmatch(rel_display, file_pattern): + continue + if Path(rel_display).suffix.lower() in {".pyc", ".pyo", ".so", ".dll", ".exe", ".bin", ".png", ".jpg", ".jpeg", ".gif", ".zip", ".tar", ".gz"}: + continue + files_searched += 1 + try: + content = await storage.read_text(entry.key, encoding="utf-8", errors="ignore") + except Exception: + continue + for i, line in enumerate(content.splitlines(), 1): + if regex.search(line): + results.append(f"{rel_display}:{i}: {line.strip()[:100]}") + total_matches += 1 + if len(results) >= 50: + break + if len(results) >= 50: + break + if not results: + return f"No matches found for pattern '{pattern}' in {files_searched} file(s)" + truncated = total_matches > len(results) + truncation_note = f" (showing first {len(results)} of {total_matches}+ — refine pattern or path for more)" if truncated else "" + return f"🔍 Found {total_matches}+ match(es) in {files_searched} file(s) for pattern '{pattern}'{truncation_note}:\n" + "\n".join(results) + + +async def _storage_find_files( + agent_id: uuid.UUID, + pattern: str, + path: str = ".", + tenant_id: str | None = None, +) -> str: + storage = get_storage_backend() + rel_path = "" if path in ("", ".") else path + base_key, normalized, _ = _tool_storage_key(agent_id, rel_path, tenant_id) + if not await storage.is_dir(base_key) and normalized: + return f"Directory not found: {path}" + entries = await _storage_walk_files(storage, base_key) if await storage.is_dir(base_key) else [] + matches = [] + for entry in entries: + rel_display = _relative_storage_display(entry.key, base_key, normalized) + if fnmatch.fnmatch(rel_display, pattern) or fnmatch.fnmatch(Path(rel_display).name, pattern): + matches.append((entry, rel_display)) + if not matches: + return f"No files matching pattern: {pattern}" + results = [] + dir_count = 0 + file_count = 0 + for entry, rel_display in matches[:100]: + if entry.is_dir: + dir_count += 1 + results.append(f"📁 {rel_display}/") + else: + file_count += 1 + results.append(f"📄 {rel_display} ({_display_size(entry.size)})") + return f"📂 Found {len(matches)} item(s) ({dir_count} dirs, {file_count} files) matching '{pattern}':\n" + "\n".join(results) + + def _list_files(ws: Path, rel_path: str, tenant_id: str | None = None) -> str: # Handle enterprise_info/ as shared directory (tenant-scoped) if rel_path and rel_path.startswith("enterprise_info"): @@ -4337,6 +4778,19 @@ async def _read_document(ws: Path, rel_path: str, max_chars: int = 8000, tenant_ return await asyncio.to_thread(_read_document_with_timeout, ws, rel_path, max_chars, tenant_id) +async def _read_document_from_storage( + agent_id: uuid.UUID, + rel_path: str, + max_chars: int = 8000, + tenant_id: str | None = None, +) -> str: + temp_workspace = await _prepare_temp_workspace(agent_id, tenant_id=tenant_id, paths=[rel_path]) + try: + return await _read_document(temp_workspace.root, rel_path, max_chars=max_chars, tenant_id=None) + finally: + temp_workspace.cleanup() + + # ─── Format Conversion Tools ──────────────────────────────────── async def _convert_csv_to_xlsx(agent_id: uuid.UUID, ws: Path, arguments: dict) -> str: @@ -5751,7 +6205,7 @@ async def _send_platform_message(agent_id: uuid.UUID, args: dict) -> str: return f"❌ Web message send error: {str(e)[:200]}" -async def _send_file_to_agent(from_agent_id: uuid.UUID, ws: Path, args: dict) -> str: +async def _send_file_to_agent(from_agent_id: uuid.UUID, args: dict) -> str: """Send a workspace file to another digital employee (agent).""" agent_name = (args.get("agent_name") or "").strip() rel_path = (args.get("file_path") or "").strip() @@ -5760,35 +6214,28 @@ async def _send_file_to_agent(from_agent_id: uuid.UUID, ws: Path, args: dict) -> if not agent_name or not rel_path: return "❌ Please provide both agent_name and file_path" - # Resolve source file path inside sender workspace - source_file_path = (ws / rel_path).resolve() - ws_resolved = ws.resolve() - sender_root = (WORKSPACE_ROOT / str(from_agent_id)).resolve() - if not str(source_file_path).startswith(str(ws_resolved)): - source_file_path = (sender_root / rel_path).resolve() - if not str(source_file_path).startswith(str(sender_root)): - return "❌ Access denied: source path is outside your workspace" - - if not source_file_path.exists(): + storage = get_storage_backend() + source_key = normalize_storage_key(f"{from_agent_id}/{rel_path}") + if not await storage.is_file(source_key): return f"❌ Source file not found: {rel_path}" - if not source_file_path.is_file(): - return f"❌ Source path is not a file: {rel_path}" + source_entry = await storage.stat(source_key) # File size limit (50 MB) MAX_FILE_SIZE = 50 * 1024 * 1024 - file_size = source_file_path.stat().st_size + file_size = source_entry.size if file_size > MAX_FILE_SIZE: size_mb = file_size / (1024 * 1024) return f"❌ File too large ({size_mb:.1f} MB). Maximum allowed is 50 MB." + source_bytes = await storage.read_bytes(source_key) + source_name = Path(rel_path).name try: from app.services.activity_logger import log_activity - import shutil async with async_session() as db: src_result = await db.execute(select(AgentModel).where(AgentModel.id == from_agent_id)) source_agent = src_result.scalar_one_or_none() - source_name = source_agent.name if source_agent else "Unknown agent" + source_agent_name = source_agent.name if source_agent else "Unknown agent" source_tenant_id = source_agent.tenant_id if source_agent else None # Build base filter: same tenant + not self @@ -5835,38 +6282,29 @@ async def _send_file_to_agent(from_agent_id: uuid.UUID, ws: Path, args: dict) -> if not rel_check.scalar_one_or_none(): return f"❌ You do not have a relationship with {target_agent.name}. Only agents in your relationship list can receive files. Ask your administrator to add a relationship if needed." - target_tenant_id = str(target_agent.tenant_id) if target_agent.tenant_id else None target_name = target_agent.name target_id = target_agent.id - target_ws = await ensure_workspace(target_id, tenant_id=target_tenant_id) - inbox_dir = (target_ws / "workspace" / "inbox").resolve() - files_dir = (inbox_dir / "files").resolve() - target_ws_resolved = target_ws.resolve() - if not str(inbox_dir).startswith(str(target_ws_resolved)) or not str(files_dir).startswith(str(target_ws_resolved)): - return "❌ Access denied for target agent inbox path" - - inbox_dir.mkdir(parents=True, exist_ok=True) - files_dir.mkdir(parents=True, exist_ok=True) - ts = datetime.now(timezone.utc) stamp = ts.strftime("%Y%m%d_%H%M%S_%f") - delivered_name = source_file_path.name - delivered_path = files_dir / delivered_name - while delivered_path.exists(): - delivered_name = f"{stamp}_{source_file_path.name}" - delivered_path = files_dir / delivered_name + delivered_name = source_name + target_rel_path = f"workspace/inbox/files/{delivered_name}" + target_key = normalize_storage_key(f"{target_id}/{target_rel_path}") + while await storage.exists(target_key): + delivered_name = f"{stamp}_{source_name}" + target_rel_path = f"workspace/inbox/files/{delivered_name}" + target_key = normalize_storage_key(f"{target_id}/{target_rel_path}") - shutil.copy2(source_file_path, delivered_path) + await storage.write_bytes(target_key, source_bytes) sender_short = str(from_agent_id)[:8] - note_path = inbox_dir / f"{stamp}_{sender_short}_file_delivery.md" - target_rel_path = f"workspace/inbox/files/{delivered_name}" + note_rel_path = f"workspace/inbox/{stamp}_{sender_short}_file_delivery.md" + note_key = normalize_storage_key(f"{target_id}/{note_rel_path}") note_lines = [ - f"# File delivery from {source_name}", + f"# File delivery from {source_agent_name}", "", f"- Time (UTC): {ts.isoformat()}", - f"- Sender: {source_name}", + f"- Sender: {source_agent_name}", f"- Source path: {rel_path}", f"- Delivered file: {target_rel_path}", "", @@ -5877,7 +6315,7 @@ async def _send_file_to_agent(from_agent_id: uuid.UUID, ws: Path, args: dict) -> note_lines.append("") note_lines.append("## Action") note_lines.append(f"- Read the file via `read_file(path=\"{target_rel_path}\")`") - note_path.write_text("\n".join(note_lines), encoding="utf-8") + await storage.write_text(note_key, "\n".join(note_lines), encoding="utf-8") from app.models.audit import AuditLog async with async_session() as db: @@ -5896,7 +6334,7 @@ async def _send_file_to_agent(from_agent_id: uuid.UUID, ws: Path, args: dict) -> action="collaboration:file_receive", details={ "from_agent": str(from_agent_id), - "from_agent_name": source_name, + "from_agent_name": source_agent_name, "source_file": rel_path, "delivered_file": target_rel_path, }, @@ -5912,14 +6350,14 @@ async def _send_file_to_agent(from_agent_id: uuid.UUID, ws: Path, args: dict) -> await log_activity( target_id, "agent_file_received", - f"Received file from {source_name}", - detail={"source_agent": source_name, "source_file": rel_path, "delivered_file": target_rel_path}, + f"Received file from {source_agent_name}", + detail={"source_agent": source_agent_name, "source_file": rel_path, "delivered_file": target_rel_path}, ) return ( f"✅ File sent to {target_name}.\n" f"- Delivered to: {target_rel_path}\n" - f"- Inbox note: workspace/inbox/{note_path.name}" + f"- Inbox note: {note_rel_path}" ) except Exception as e: return f"❌ Agent file send error: {str(e)[:200]}" diff --git a/backend/app/services/llm/caller.py b/backend/app/services/llm/caller.py index b10fcc0e7..ef526e514 100644 --- a/backend/app/services/llm/caller.py +++ b/backend/app/services/llm/caller.py @@ -14,12 +14,14 @@ import json import uuid +from pathlib import Path from typing import TYPE_CHECKING from loguru import logger from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.config import get_settings from app.database import async_session from app.services.agent_tools import AGENT_TOOLS, execute_tool, get_agent_tools_for_llm from app.services.token_tracker import ( @@ -348,8 +350,8 @@ async def _process_tool_call( if supports_vision and agent_id: try: from app.services.vision_inject import try_inject_screenshot_vision - from app.services.storage import ensure_local_path - ws_path = await ensure_local_path(str(agent_id)) + settings = get_settings() + ws_path = Path(settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR) / str(agent_id) vision_content = try_inject_screenshot_vision(tool_name, str(result), ws_path) if vision_content: tool_content = vision_content diff --git a/backend/app/services/org_sync_adapter.py b/backend/app/services/org_sync_adapter.py index 5df5dc7f6..45f0cb71c 100644 --- a/backend/app/services/org_sync_adapter.py +++ b/backend/app/services/org_sync_adapter.py @@ -9,7 +9,7 @@ import uuid from abc import ABC, abstractmethod from dataclasses import dataclass, field -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Any from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, delete, func, or_, select, update @@ -51,6 +51,10 @@ def pinyin(value: str, style: str | None = None) -> list[list[str]]: from jose import jwt +def _utcnow() -> datetime: + return datetime.now(timezone.utc) + + def build_department_path_map(departments: list[OrgDepartment]) -> dict[uuid.UUID, str]: """Build department name paths by walking the internal department tree.""" dept_by_id = {dept.id: dept for dept in departments} @@ -239,7 +243,7 @@ async def sync_org_structure(self, db: AsyncSession) -> dict[str, Any]: member_count = 0 user_count = 0 profile_count = 0 - sync_start = datetime.now() + sync_start = _utcnow() partial_failure = False # Ensure provider exists @@ -291,7 +295,7 @@ async def sync_org_structure(self, db: AsyncSession) -> dict[str, Any]: # Update provider metadata if possible if self.provider: config = (self.provider.config or {}).copy() - config["last_synced_at"] = datetime.now().isoformat() + config["last_synced_at"] = _utcnow().isoformat() self.provider.config = config await db.flush() @@ -321,7 +325,7 @@ async def sync_org_structure(self, db: AsyncSession) -> dict[str, Any]: "profiles_synced": profile_count, "errors": errors, "provider": self.provider_type, - "synced_at": datetime.now().isoformat() + "synced_at": _utcnow().isoformat() } async def _reconcile(self, db: AsyncSession, provider_id: uuid.UUID, sync_start: datetime): @@ -333,7 +337,8 @@ async def _reconcile(self, db: AsyncSession, provider_id: uuid.UUID, sync_start: .where(OrgMember.provider_id == provider_id) .where(OrgMember.synced_at < sync_start) .where(OrgMember.status != "deleted") - .values(status="deleted", synced_at=datetime.now()) + .values(status="deleted", synced_at=_utcnow()) + .execution_options(synchronize_session=False) ) # 2. Departments reconciled @@ -342,7 +347,8 @@ async def _reconcile(self, db: AsyncSession, provider_id: uuid.UUID, sync_start: .where(OrgDepartment.provider_id == provider_id) .where(OrgDepartment.synced_at < sync_start) .where(OrgDepartment.status != "deleted") - .values(status="deleted", synced_at=datetime.now()) + .values(status="deleted", synced_at=_utcnow()) + .execution_options(synchronize_session=False) ) async def _update_member_counts(self, db: AsyncSession, provider_id: uuid.UUID): @@ -457,7 +463,7 @@ async def _upsert_department( ) existing = result.scalars().first() - now = datetime.now() + now = _utcnow() # Path is rebuilt from the internal department tree after sync. path = dept.name @@ -569,7 +575,7 @@ async def _upsert_member( existing_member = await self._find_existing_member(db, provider, user) - now = datetime.now() + now = _utcnow() # Note: Platform user creation is disabled - just sync OrgMember # Users will be linked to platform users manually or via SSO login diff --git a/backend/app/services/skill_seeder.py b/backend/app/services/skill_seeder.py index d50c85534..a56851f2f 100644 --- a/backend/app/services/skill_seeder.py +++ b/backend/app/services/skill_seeder.py @@ -930,11 +930,11 @@ async def push_default_skills_to_existing_agents(): Called at startup after seed_skills() so existing agents automatically receive new default skills like MCP_INSTALLER without requiring manual re-creation. """ - from pathlib import Path from app.models.agent import Agent - from app.models.skill import Skill, SkillFile + from app.models.skill import Skill from sqlalchemy.orm import selectinload from app.services.agent_manager import agent_manager + from app.services.storage import get_storage_backend async with async_session() as db: # Load all is_default skills with their files @@ -951,25 +951,22 @@ async def push_default_skills_to_existing_agents(): pushed = 0 updated = 0 + storage = get_storage_backend() for agent in agents: - agent_dir = agent_manager._agent_dir(agent.id) - skills_dir = agent_dir / "skills" + agent_prefix = agent_manager._agent_storage_prefix(agent.id) for skill in default_skills: if not skill.files: continue - skill_folder = skills_dir / skill.folder_name - skill_folder.mkdir(parents=True, exist_ok=True) for sf in skill.files: - fp = (skill_folder / sf.path).resolve() - fp.parent.mkdir(parents=True, exist_ok=True) - if fp.exists(): - existing_content = fp.read_text(encoding="utf-8") + key = f"{agent_prefix}/skills/{skill.folder_name}/{sf.path}" + if await storage.is_file(key): + existing_content = await storage.read_text(key, encoding="utf-8", errors="replace") if existing_content == sf.content: continue # already up-to-date - fp.write_text(sf.content, encoding="utf-8") + await storage.write_text(key, sf.content, encoding="utf-8") updated += 1 else: - fp.write_text(sf.content, encoding="utf-8") + await storage.write_text(key, sf.content, encoding="utf-8") pushed += 1 logger.info(f"[SkillSeeder] Pushed '{skill.name}' to agent {agent.id}") diff --git a/backend/app/services/storage_runtime/__init__.py b/backend/app/services/storage_runtime/__init__.py index 185c8aadd..756e2e4fd 100644 --- a/backend/app/services/storage_runtime/__init__.py +++ b/backend/app/services/storage_runtime/__init__.py @@ -1,6 +1,12 @@ """Storage runtime package.""" -from app.services.storage_runtime.base import StorageBackend, StorageEntry +from app.services.storage_runtime.base import ( + ConditionalWriteResult, + StorageBackend, + StorageEntry, + StorageVersion, + WriteCondition, +) from app.services.storage_runtime.agent_files import ( agent_storage_key, agent_upload_key, @@ -25,6 +31,9 @@ __all__ = [ "StorageBackend", "StorageEntry", + "StorageVersion", + "WriteCondition", + "ConditionalWriteResult", "FallbackStorageBackend", "LocalStorageBackend", "S3StorageBackend", diff --git a/backend/app/services/storage_runtime/base.py b/backend/app/services/storage_runtime/base.py index 2b11bd7d2..ac7c46369 100644 --- a/backend/app/services/storage_runtime/base.py +++ b/backend/app/services/storage_runtime/base.py @@ -2,6 +2,7 @@ from __future__ import annotations +import hashlib from dataclasses import dataclass from pathlib import Path @@ -13,6 +14,38 @@ class StorageEntry: is_dir: bool size: int = 0 modified_at: str = "" + etag: str = "" + version_id: str = "" + content_hash: str = "" + + +@dataclass +class StorageVersion: + key: str + exists: bool + is_dir: bool + size: int = 0 + modified_at: str = "" + etag: str = "" + version_id: str = "" + content_hash: str = "" + + @property + def token(self) -> str: + return self.version_id or self.etag or self.content_hash or f"{self.modified_at}:{self.size}" + + +@dataclass +class WriteCondition: + version_token: str | None = None + require_absent: bool = False + + +@dataclass +class ConditionalWriteResult: + ok: bool + conflict: bool = False + current_version: StorageVersion | None = None class StorageBackend: @@ -50,8 +83,63 @@ async def delete_tree(self, key: str) -> None: async def stat(self, key: str) -> StorageEntry: raise NotImplementedError + async def get_version(self, key: str) -> StorageVersion: + try: + entry = await self.stat(key) + except FileNotFoundError: + return StorageVersion(key=key, exists=False, is_dir=False) + return StorageVersion( + key=entry.key, + exists=True, + is_dir=entry.is_dir, + size=entry.size, + modified_at=entry.modified_at, + etag=entry.etag, + version_id=entry.version_id, + content_hash=entry.content_hash, + ) + + async def write_bytes_if_match( + self, + key: str, + data: bytes, + *, + condition: WriteCondition | None = None, + content_type: str | None = None, + ) -> ConditionalWriteResult: + current = await self.get_version(key) + if condition: + if condition.require_absent and current.exists: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + if condition.version_token is not None and current.token != condition.version_token: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + await self.write_bytes(key, data, content_type=content_type) + return ConditionalWriteResult(ok=True, current_version=await self.get_version(key)) + + async def delete_if_match( + self, + key: str, + *, + condition: WriteCondition | None = None, + ) -> ConditionalWriteResult: + current = await self.get_version(key) + if condition: + if condition.require_absent: + if current.exists: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + return ConditionalWriteResult(ok=True, current_version=current) + if condition.version_token is not None and current.token != condition.version_token: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + if current.exists: + await self.delete(key) + return ConditionalWriteResult(ok=True, current_version=await self.get_version(key)) + async def local_path_for(self, key: str) -> Path | None: return None async def presign_download_url(self, key: str, filename: str | None = None, inline: bool = False) -> str | None: return None + + +def content_hash_bytes(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() diff --git a/backend/app/services/storage_runtime/facade.py b/backend/app/services/storage_runtime/facade.py index 1405985fd..00133a217 100644 --- a/backend/app/services/storage_runtime/facade.py +++ b/backend/app/services/storage_runtime/facade.py @@ -35,6 +35,8 @@ def get_storage_backend() -> StorageBackend: access_key_id=settings.S3_ACCESS_KEY_ID, secret_access_key=settings.S3_SECRET_ACCESS_KEY, presign_ttl_seconds=settings.S3_PRESIGN_TTL_SECONDS, + max_pool_connections=settings.S3_MAX_POOL_CONNECTIONS, + write_workers=settings.S3_WRITE_WORKERS, ) if settings.STORAGE_LOCAL_FALLBACK_ENABLED: fallback = LocalStorageBackend(settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR) diff --git a/backend/app/services/storage_runtime/fallback.py b/backend/app/services/storage_runtime/fallback.py index f8c2540cf..5699cfd6d 100644 --- a/backend/app/services/storage_runtime/fallback.py +++ b/backend/app/services/storage_runtime/fallback.py @@ -4,7 +4,13 @@ from pathlib import Path -from app.services.storage_runtime.base import StorageBackend, StorageEntry +from app.services.storage_runtime.base import ( + ConditionalWriteResult, + StorageBackend, + StorageEntry, + StorageVersion, + WriteCondition, +) class FallbackStorageBackend(StorageBackend): @@ -63,6 +69,27 @@ async def stat(self, key: str) -> StorageEntry: await self.primary.write_bytes(key, data) return entry + async def get_version(self, key: str) -> StorageVersion: + primary_version = await self.primary.get_version(key) + if primary_version.exists: + return primary_version + fallback_version = await self.fallback.get_version(key) + if fallback_version.exists and not fallback_version.is_dir: + data = await self.fallback.read_bytes(key) + await self.primary.write_bytes(key, data) + return await self.primary.get_version(key) + return fallback_version + + async def write_bytes_if_match( + self, + key: str, + data: bytes, + *, + condition: WriteCondition | None = None, + content_type: str | None = None, + ) -> ConditionalWriteResult: + return await self.primary.write_bytes_if_match(key, data, condition=condition, content_type=content_type) + async def local_path_for(self, key: str) -> Path | None: if await self.primary.exists(key): return await self.primary.local_path_for(key) diff --git a/backend/app/services/storage_runtime/local.py b/backend/app/services/storage_runtime/local.py index 3128c6108..799361788 100644 --- a/backend/app/services/storage_runtime/local.py +++ b/backend/app/services/storage_runtime/local.py @@ -8,7 +8,14 @@ import aiofiles from fastapi import HTTPException, status -from app.services.storage_runtime.base import StorageBackend, StorageEntry +from app.services.storage_runtime.base import ( + ConditionalWriteResult, + StorageBackend, + StorageEntry, + StorageVersion, + WriteCondition, + content_hash_bytes, +) from app.services.storage_runtime.utils import normalize_storage_key @@ -50,6 +57,7 @@ async def list_dir(self, key: str) -> list[StorageEntry]: is_dir=entry.is_dir(), size=stat.st_size if entry.is_file() else 0, modified_at=str(stat.st_mtime), + version_id=_local_version_token(stat, None), ) ) return entries @@ -83,14 +91,66 @@ async def delete_tree(self, key: str) -> None: async def stat(self, key: str) -> StorageEntry: path = self._full_path(key) stat = path.stat() + file_hash = "" + version_id = _local_version_token(stat, None) + if path.is_file(): + data = await self.read_bytes(key) + file_hash = content_hash_bytes(data) + version_id = _local_version_token(stat, file_hash) return StorageEntry( name=path.name, key=normalize_storage_key(key), is_dir=path.is_dir(), size=stat.st_size if path.is_file() else 0, modified_at=str(stat.st_mtime), + version_id=version_id, + etag=file_hash, + content_hash=file_hash, ) + async def get_version(self, key: str) -> StorageVersion: + path = self._full_path(key) + if not path.exists(): + return StorageVersion(key=normalize_storage_key(key), exists=False, is_dir=False) + stat = path.stat() + if path.is_dir(): + return StorageVersion( + key=normalize_storage_key(key), + exists=True, + is_dir=True, + modified_at=str(stat.st_mtime), + version_id=_local_version_token(stat, None), + ) + data = await self.read_bytes(key) + file_hash = content_hash_bytes(data) + return StorageVersion( + key=normalize_storage_key(key), + exists=True, + is_dir=False, + size=stat.st_size, + modified_at=str(stat.st_mtime), + etag=file_hash, + version_id=_local_version_token(stat, file_hash), + content_hash=file_hash, + ) + + async def write_bytes_if_match( + self, + key: str, + data: bytes, + *, + condition: WriteCondition | None = None, + content_type: str | None = None, + ) -> ConditionalWriteResult: + current = await self.get_version(key) + if condition: + if condition.require_absent and current.exists: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + if condition.version_token is not None and current.token != condition.version_token: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + await self.write_bytes(key, data, content_type=content_type) + return ConditionalWriteResult(ok=True, current_version=await self.get_version(key)) + async def local_path_for(self, key: str) -> Path | None: return self._full_path(key) @@ -99,3 +159,8 @@ def _local_delete_tree(path: Path) -> None: import shutil shutil.rmtree(path) + + +def _local_version_token(stat, file_hash: str | None) -> str: + hash_part = file_hash or "" + return f"{stat.st_mtime_ns}:{stat.st_size}:{hash_part}" diff --git a/backend/app/services/storage_runtime/s3.py b/backend/app/services/storage_runtime/s3.py index b0b190a74..808565a2b 100644 --- a/backend/app/services/storage_runtime/s3.py +++ b/backend/app/services/storage_runtime/s3.py @@ -3,11 +3,20 @@ from __future__ import annotations import asyncio +from contextlib import asynccontextmanager from pathlib import Path from tempfile import NamedTemporaryFile from typing import Any -from app.services.storage_runtime.base import StorageBackend, StorageEntry +from loguru import logger + +from app.services.storage_runtime.base import ( + ConditionalWriteResult, + StorageBackend, + StorageEntry, + StorageVersion, + WriteCondition, +) from app.services.storage_runtime.utils import normalize_storage_key @@ -22,6 +31,8 @@ def __init__( access_key_id: str = "", secret_access_key: str = "", presign_ttl_seconds: int = 3600, + max_pool_connections: int = 50, + write_workers: int = 32, ): self.bucket = bucket self.prefix = normalize_storage_key(prefix) @@ -30,7 +41,9 @@ def __init__( self.access_key_id = access_key_id or None self.secret_access_key = secret_access_key or None self.presign_ttl_seconds = presign_ttl_seconds + self.max_pool_connections = max_pool_connections self._client: Any | None = None + self._aioboto3_session: Any | None = None def _object_key(self, key: str) -> str: normalized = normalize_storage_key(key) @@ -40,6 +53,7 @@ def _client_or_raise(self): if self._client is None: try: import boto3 + from botocore.config import Config except ImportError as exc: raise RuntimeError("boto3 is required for S3 storage backend") from exc self._client = boto3.client( @@ -48,18 +62,62 @@ def _client_or_raise(self): endpoint_url=self.endpoint_url, aws_access_key_id=self.access_key_id, aws_secret_access_key=self.secret_access_key, + config=Config( + max_pool_connections=self.max_pool_connections, + proxies={}, + s3={"addressing_style": "path"}, + signature_version="s3v4", + connect_timeout=5, + read_timeout=30, + tcp_keepalive=True, + ), ) return self._client - async def exists(self, key: str) -> bool: + @asynccontextmanager + async def _async_client(self): + """Shared aioboto3 session with aiohttp connection pool — reuses connections but detects stale ones correctly.""" try: - await self.stat(key) - return True - except FileNotFoundError: - return False + import aioboto3 + from botocore.config import Config + except ImportError as exc: + raise RuntimeError("aioboto3 is required for async S3 writes") from exc + if self._aioboto3_session is None: + self._aioboto3_session = aioboto3.Session() + async with self._aioboto3_session.client( + "s3", + region_name=self.region or None, + endpoint_url=self.endpoint_url, + aws_access_key_id=self.access_key_id, + aws_secret_access_key=self.secret_access_key, + config=Config( + max_pool_connections=self.max_pool_connections, + proxies={}, + s3={"addressing_style": "path"}, + signature_version="s3v4", + connect_timeout=5, + read_timeout=30, + tcp_keepalive=True, + ), + ) as client: + yield client + + async def exists(self, key: str) -> bool: + return await self._object_exists(key) async def is_file(self, key: str) -> bool: - return await self.exists(key) + return await self._object_exists(key) + + async def _object_exists(self, key: str) -> bool: + object_key = self._object_key(key) + client = self._client_or_raise() + response = await asyncio.to_thread( + client.list_objects_v2, + Bucket=self.bucket, + Prefix=object_key, + MaxKeys=1, + ) + return any(item.get("Key") == object_key for item in response.get("Contents", [])) async def is_dir(self, key: str) -> bool: prefix = self._object_key(key).rstrip("/") + "/" @@ -103,6 +161,7 @@ async def list_dir(self, key: str) -> list[StorageEntry]: is_dir=False, size=int(item.get("Size", 0)), modified_at=str(item.get("LastModified") or ""), + etag=_clean_etag(item.get("ETag")), ) ) return sorted(entries, key=lambda entry: (not entry.is_dir, entry.name)) @@ -118,7 +177,6 @@ async def read_bytes(self, key: str) -> bytes: return await asyncio.to_thread(body.read) async def write_bytes(self, key: str, data: bytes, content_type: str | None = None) -> None: - client = self._client_or_raise() kwargs: dict[str, Any] = { "Bucket": self.bucket, "Key": self._object_key(key), @@ -126,15 +184,15 @@ async def write_bytes(self, key: str, data: bytes, content_type: str | None = No } if content_type: kwargs["ContentType"] = content_type - await asyncio.to_thread(client.put_object, **kwargs) + async with self._async_client() as client: + await client.put_object(**kwargs) async def delete(self, key: str) -> None: - client = self._client_or_raise() - await asyncio.to_thread( - client.delete_object, - Bucket=self.bucket, - Key=self._object_key(key), - ) + async with self._async_client() as client: + await client.delete_object( + Bucket=self.bucket, + Key=self._object_key(key), + ) async def delete_tree(self, key: str) -> None: client = self._client_or_raise() @@ -148,30 +206,73 @@ async def delete_tree(self, key: str) -> None: if not contents: return objects = [{"Key": item["Key"]} for item in contents] - await asyncio.to_thread( - client.delete_objects, - Bucket=self.bucket, - Delete={"Objects": objects}, - ) + async with self._async_client() as client: + await client.delete_objects( + Bucket=self.bucket, + Delete={"Objects": objects}, + ) async def stat(self, key: str) -> StorageEntry: + version = await self.get_version(key) + if not version.exists: + raise FileNotFoundError(key) + return StorageEntry( + name=normalize_storage_key(key).split("/")[-1], + key=normalize_storage_key(key), + is_dir=version.is_dir, + size=version.size, + modified_at=version.modified_at, + etag=version.etag, + version_id=version.version_id, + content_hash=version.content_hash, + ) + + async def get_version(self, key: str) -> StorageVersion: client = self._client_or_raise() + object_key = self._object_key(key) try: response = await asyncio.to_thread( client.head_object, Bucket=self.bucket, - Key=self._object_key(key), + Key=object_key, ) - except Exception as exc: - raise FileNotFoundError(key) from exc - return StorageEntry( - name=normalize_storage_key(key).split("/")[-1], + except Exception: + return StorageVersion(key=normalize_storage_key(key), exists=False, is_dir=False) + return StorageVersion( key=normalize_storage_key(key), + exists=True, is_dir=False, size=int(response.get("ContentLength", 0)), modified_at=str(response.get("LastModified") or ""), + etag=_clean_etag(response.get("ETag")), + version_id=str(response.get("VersionId") or ""), + content_hash=_clean_etag(response.get("ETag")), ) + async def write_bytes_if_match( + self, + key: str, + data: bytes, + *, + condition: WriteCondition | None = None, + content_type: str | None = None, + ) -> ConditionalWriteResult: + current = await self.get_version(key) + if condition: + if condition.require_absent and current.exists: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + if condition.version_token is not None and current.token != condition.version_token: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + await self.write_bytes(key, data, content_type=content_type) + return ConditionalWriteResult(ok=True, current_version=await self.get_version(key)) + + async def _put_succeeded(self, key: str, expected_size: int) -> bool: + try: + entry = await self.stat(key) + except Exception: + return False + return entry.size == expected_size + async def local_path_for(self, key: str) -> Path | None: suffix = Path(normalize_storage_key(key)).suffix tmp = NamedTemporaryFile(delete=False, suffix=suffix) @@ -202,3 +303,18 @@ def _strip_prefix(raw_key: str, prefix: str) -> str: if prefix and raw_key.startswith(prefix + "/"): return raw_key[len(prefix) + 1:] return raw_key + + +def _is_header_parsing_error(exc: Exception) -> bool: + try: + from urllib3.exceptions import HeaderParsingError + except Exception: + return False + return isinstance(exc, HeaderParsingError) + + +def _clean_etag(raw: Any) -> str: + if raw is None: + return "" + text = str(raw) + return text.strip('"') diff --git a/backend/app/services/wechat_channel.py b/backend/app/services/wechat_channel.py index 8ad24fe26..2fd114781 100644 --- a/backend/app/services/wechat_channel.py +++ b/backend/app/services/wechat_channel.py @@ -307,6 +307,7 @@ class WeChatPollManager: def __init__(self) -> None: self._tasks: dict[uuid.UUID, asyncio.Task] = {} self._connected: dict[uuid.UUID, bool] = {} + self._reconcile_interval_seconds = 30 async def start_client(self, agent_id: uuid.UUID, stop_existing: bool = True) -> None: if stop_existing: @@ -327,6 +328,13 @@ async def stop_client(self, agent_id: uuid.UUID) -> None: await self._set_connected(agent_id, False) async def start_all(self) -> None: + logger.info("[WeChat] Poll manager started") + while True: + await self.reconcile_clients() + await asyncio.sleep(self._reconcile_interval_seconds) + + async def reconcile_clients(self) -> None: + configured_agent_ids: set[uuid.UUID] = set() async with async_session() as db: result = await db.execute( select(ChannelConfig).where( @@ -337,7 +345,16 @@ async def start_all(self) -> None: for cfg in result.scalars().all(): token = str((cfg.extra_config or {}).get("bot_token") or "").strip() if token: - await self.start_client(cfg.agent_id) + configured_agent_ids.add(cfg.agent_id) + + for agent_id in configured_agent_ids: + task = self._tasks.get(agent_id) + if task is None or task.done(): + await self.start_client(agent_id) + + for agent_id in list(self._tasks): + if agent_id not in configured_agent_ids: + await self.stop_client(agent_id) async def _run_client(self, agent_id: uuid.UUID) -> None: retry_delay = 2 diff --git a/backend/app/services/workspace_collaboration.py b/backend/app/services/workspace_collaboration.py index 5d390ba57..d8269aac2 100644 --- a/backend/app/services/workspace_collaboration.py +++ b/backend/app/services/workspace_collaboration.py @@ -19,6 +19,9 @@ from app.models.workspace import WorkspaceEditLock, WorkspaceFileRevision from app.services.storage import get_storage_backend, normalize_storage_key +from app.services.storage_runtime.base import WriteCondition +from app.services.storage_runtime.local import LocalStorageBackend +from app.services.workspace_locking import workspace_locks USER_AUTOSAVE_MERGE_SECONDS = 60 EDIT_LOCK_TTL_SECONDS = 90 @@ -64,6 +67,11 @@ class WorkspaceWriteResult: locked_by_user_id: str | None = None +def _should_mirror_to_local_filesystem(storage) -> bool: + """Only mirror writes into AGENT_DATA_DIR when the filesystem is the primary store.""" + return isinstance(storage, LocalStorageBackend) + + def content_hash(content: str | None) -> str: """Return a stable hash for text content.""" return hashlib.sha256((content or "").encode("utf-8")).hexdigest() @@ -271,6 +279,7 @@ async def write_workspace_file( session_id: str | None = None, enforce_human_lock: bool = True, merge_user_autosave: bool = False, + expected_version_token: str | None = None, ) -> WorkspaceWriteResult: """Write text content, enforcing human locks for agent/system actors.""" normalized = normalize_workspace_path(path) @@ -292,14 +301,21 @@ async def write_workspace_file( storage = get_storage_backend() storage_key = normalize_storage_key(f"{agent_id}/{normalized}") - local_base_available = False + local_base_available = _should_mirror_to_local_filesystem(storage) try: target = safe_agent_path(base_dir, normalized) - local_base_available = True except Exception: target = None + local_base_available = False before = await storage.read_text(storage_key, encoding="utf-8", errors="replace") if await storage.exists(storage_key) else None - await storage.write_text(storage_key, content, encoding="utf-8") + write_result = await storage.write_bytes_if_match( + storage_key, + content.encode("utf-8"), + condition=WriteCondition(version_token=expected_version_token) if expected_version_token is not None else None, + content_type="text/plain; charset=utf-8", + ) + if not write_result.ok: + return WorkspaceWriteResult(False, normalized, f"Conflict detected while writing {normalized}") if local_base_available and target is not None: target.parent.mkdir(parents=True, exist_ok=True) async with aiofiles.open(target, "w", encoding="utf-8") as f: @@ -335,15 +351,18 @@ async def delete_workspace_file( actor_id: uuid.UUID | None, session_id: str | None = None, enforce_human_lock: bool = True, + expected_version_token: str | None = None, ) -> WorkspaceWriteResult: """Delete a workspace file and record the deleted content.""" normalized = normalize_workspace_path(path) storage = get_storage_backend() storage_key = normalize_storage_key(f"{agent_id}/{normalized}") - try: - target = safe_agent_path(base_dir, normalized) - except Exception: - target = None + target = None + if _should_mirror_to_local_filesystem(storage): + try: + target = safe_agent_path(base_dir, normalized) + except Exception: + target = None if enforce_human_lock and actor_type != "user": lock = await get_active_lock(db, agent_id=agent_id, path=normalized) if lock: @@ -353,13 +372,28 @@ async def delete_workspace_file( f"Human is currently editing {normalized}. Do not delete it now.", locked_by_user_id=str(lock.user_id), ) - if not await storage.exists(storage_key): + storage_exists = await storage.exists(storage_key) + storage_is_dir = await storage.is_dir(storage_key) + if not storage_exists and not storage_is_dir: return WorkspaceWriteResult(False, normalized, f"File not found: {normalized}") - before = await storage.read_text(storage_key, encoding="utf-8", errors="replace") if await storage.is_file(storage_key) else None - if await storage.is_dir(storage_key): - await storage.delete_tree(storage_key) - else: - await storage.delete(storage_key) + before = await storage.read_text(storage_key, encoding="utf-8", errors="replace") if storage_exists and await storage.is_file(storage_key) else None + async with workspace_locks(agent_id, [normalized]): + if storage_is_dir: + entries = await _collect_storage_tree_versions(storage, storage_key) + for entry_key, version_token in reversed(entries): + delete_result = await storage.delete_if_match( + entry_key, + condition=WriteCondition(version_token=version_token), + ) + if not delete_result.ok: + return WorkspaceWriteResult(False, normalized, f"Conflict detected while deleting {normalized}") + else: + delete_result = await storage.delete_if_match( + storage_key, + condition=WriteCondition(version_token=expected_version_token) if expected_version_token is not None else None, + ) + if not delete_result.ok: + return WorkspaceWriteResult(False, normalized, f"Conflict detected while deleting {normalized}") if target is not None and target.exists(): if target.is_dir(): import shutil @@ -397,6 +431,8 @@ async def move_workspace_path( session_id: str | None = None, enforce_human_lock: bool = True, overwrite: bool = False, + expected_source_version_token: str | None = None, + expected_destination_version_token: str | None = None, ) -> WorkspaceWriteResult: """Move or rename a workspace file/folder while respecting edit locks.""" source_normalized = normalize_workspace_path(source_path) @@ -408,19 +444,22 @@ async def move_workspace_path( if source_normalized in {"tasks.json", "soul.md"}: return WorkspaceWriteResult(False, source_normalized, f"{source_normalized} cannot be moved (protected)") - source = safe_agent_path(base_dir, source_normalized) - if not source.exists(): + storage = get_storage_backend() + source_key = normalize_storage_key(f"{agent_id}/{source_normalized}") + source_exists = await storage.exists(source_key) + source_is_dir = await storage.is_dir(source_key) + if not source_exists and not source_is_dir: return WorkspaceWriteResult(False, source_normalized, f"File not found: {source_normalized}") - destination = safe_agent_path(base_dir, destination_normalized) - if destination_path.replace("\\", "/").strip().endswith("/") or destination.is_dir(): - destination = (destination / source.name).resolve() - destination_normalized = normalize_workspace_path(str(destination.relative_to(base_dir.resolve()))) - destination = safe_agent_path(base_dir, destination_normalized) + destination_key = normalize_storage_key(f"{agent_id}/{destination_normalized}") + destination_is_dir = await storage.is_dir(destination_key) + if destination_path.replace("\\", "/").strip().endswith("/") or destination_is_dir: + destination_normalized = normalize_workspace_path(f"{destination_normalized}/{Path(source_normalized).name}") + destination_key = normalize_storage_key(f"{agent_id}/{destination_normalized}") - if source == destination: + if source_normalized == destination_normalized: return WorkspaceWriteResult(False, source_normalized, "Source and destination are the same") - if source.is_dir() and str(destination).startswith(str(source) + "/"): + if source_is_dir and (destination_normalized == source_normalized or destination_normalized.startswith(source_normalized + "/")): return WorkspaceWriteResult(False, source_normalized, "Cannot move a folder into itself") if enforce_human_lock and actor_type != "user": @@ -437,23 +476,71 @@ async def move_workspace_path( locked_by_user_id=str(lock.user_id), ) - if destination.exists(): - if not overwrite: - return WorkspaceWriteResult( - False, - destination_normalized, - f"Destination already exists: {destination_normalized}. Set overwrite=true to replace it.", - ) - if destination.is_dir(): - shutil.rmtree(destination) + destination_exists = await storage.exists(destination_key) + destination_is_dir = await storage.is_dir(destination_key) + async with workspace_locks(agent_id, [source_normalized, destination_normalized]): + if destination_exists or destination_is_dir: + if not overwrite: + return WorkspaceWriteResult( + False, + destination_normalized, + f"Destination already exists: {destination_normalized}. Set overwrite=true to replace it.", + ) + if destination_is_dir: + await storage.delete_tree(destination_key) + else: + delete_result = await storage.delete_if_match( + destination_key, + condition=WriteCondition(version_token=expected_destination_version_token) if expected_destination_version_token is not None else None, + ) + if not delete_result.ok: + return WorkspaceWriteResult(False, destination_normalized, f"Conflict detected while replacing {destination_normalized}") + + source = destination = None + if _should_mirror_to_local_filesystem(storage): + source = safe_agent_path(base_dir, source_normalized) + destination = safe_agent_path(base_dir, destination_normalized) + source_before = await storage.read_text(source_key, encoding="utf-8", errors="replace") if source_exists else None + destination_before = await storage.read_text(destination_key, encoding="utf-8", errors="replace") if destination_exists else None + + if source_is_dir: + entries = await _collect_storage_tree_versions(storage, source_key) + for entry_key, version_token in entries: + rel = entry_key.removeprefix(source_key.rstrip("/") + "/") + target_key = normalize_storage_key(f"{agent_id}/{destination_normalized}/{rel}") + current_version = await storage.get_version(entry_key) + if current_version.token != version_token: + return WorkspaceWriteResult(False, source_normalized, f"Conflict detected while moving {source_normalized}") + await storage.write_bytes(target_key, await storage.read_bytes(entry_key)) + for entry_key, version_token in reversed(entries): + delete_result = await storage.delete_if_match( + entry_key, + condition=WriteCondition(version_token=version_token), + ) + if not delete_result.ok: + return WorkspaceWriteResult(False, source_normalized, f"Conflict detected while finalizing move for {source_normalized}") else: - destination.unlink() + source_version = await storage.get_version(source_key) + if expected_source_version_token is not None and source_version.token != expected_source_version_token: + return WorkspaceWriteResult(False, source_normalized, f"Conflict detected while moving {source_normalized}") + await storage.write_bytes(destination_key, await storage.read_bytes(source_key)) + delete_result = await storage.delete_if_match( + source_key, + condition=WriteCondition(version_token=source_version.token), + ) + if not delete_result.ok: + return WorkspaceWriteResult(False, source_normalized, f"Conflict detected while finalizing move for {source_normalized}") + + destination_after = await storage.read_text(destination_key, encoding="utf-8", errors="replace") if await storage.is_file(destination_key) else None - source_before = await read_text_if_exists(source) - destination_before = await read_text_if_exists(destination) - destination.parent.mkdir(parents=True, exist_ok=True) - shutil.move(str(source), str(destination)) - destination_after = await read_text_if_exists(destination) + if source is not None and source.exists(): + if source.is_dir(): + shutil.rmtree(source) + else: + source.unlink() + if destination is not None and await storage.is_file(destination_key): + destination.parent.mkdir(parents=True, exist_ok=True) + destination.write_bytes(await storage.read_bytes(destination_key)) source_revision = await record_revision( db, @@ -486,6 +573,17 @@ async def move_workspace_path( ) +async def _collect_storage_tree_versions(storage, root_key: str) -> list[tuple[str, str]]: + keys: list[tuple[str, str]] = [] + for entry in await storage.list_dir(root_key): + if entry.is_dir: + keys.extend(await _collect_storage_tree_versions(storage, entry.key)) + else: + version = await storage.get_version(entry.key) + keys.append((entry.key, version.token)) + return keys + + async def list_revisions( db: AsyncSession, *, diff --git a/backend/app/services/workspace_locking.py b/backend/app/services/workspace_locking.py new file mode 100644 index 000000000..ceafb43f1 --- /dev/null +++ b/backend/app/services/workspace_locking.py @@ -0,0 +1,80 @@ +"""Redis-backed short-lived locks for workspace mutations.""" + +from __future__ import annotations + +import uuid +from contextlib import asynccontextmanager + +from app.core.events import get_redis + +LOCK_PREFIX = "workspace-lock" +DEFAULT_LOCK_TTL_SECONDS = 60 + +_RELEASE_IF_OWNER_SCRIPT = """ +if redis.call('get', KEYS[1]) == ARGV[1] then + return redis.call('del', KEYS[1]) +end +return 0 +""" + + +def _normalize_workspace_path(path: str) -> str: + clean = (path or "").replace("\\", "/").strip().lstrip("/") + parts: list[str] = [] + for part in clean.split("/"): + if part in ("", "."): + continue + if part == "..": + if parts: + parts.pop() + continue + parts.append(part) + return "/".join(parts) + + +def _lock_key(agent_id: uuid.UUID, path: str) -> str: + normalized = _normalize_workspace_path(path) or "." + return f"{LOCK_PREFIX}:{agent_id}:{normalized}" + + +async def acquire_workspace_lock( + agent_id: uuid.UUID, + path: str, + *, + owner_token: str, + ttl_seconds: int = DEFAULT_LOCK_TTL_SECONDS, +) -> bool: + redis = await get_redis() + return bool(await redis.set(_lock_key(agent_id, path), owner_token, ex=ttl_seconds, nx=True)) + + +async def release_workspace_lock(agent_id: uuid.UUID, path: str, *, owner_token: str) -> None: + redis = await get_redis() + await redis.eval(_RELEASE_IF_OWNER_SCRIPT, 1, _lock_key(agent_id, path), owner_token) + + +@asynccontextmanager +async def workspace_locks( + agent_id: uuid.UUID, + paths: list[str], + *, + ttl_seconds: int = DEFAULT_LOCK_TTL_SECONDS, +): + normalized = sorted({_normalize_workspace_path(path) or "." for path in paths if path is not None}) + owner_token = uuid.uuid4().hex + acquired: list[str] = [] + try: + for path in normalized: + ok = await acquire_workspace_lock( + agent_id, + path, + owner_token=owner_token, + ttl_seconds=ttl_seconds, + ) + if not ok: + raise RuntimeError(f"Workspace lock busy: {path}") + acquired.append(path) + yield + finally: + for path in reversed(acquired): + await release_workspace_lock(agent_id, path, owner_token=owner_token) diff --git a/backend/entrypoint.sh b/backend/entrypoint.sh index a45c58e5f..a38723a9d 100755 --- a/backend/entrypoint.sh +++ b/backend/entrypoint.sh @@ -24,6 +24,12 @@ if [ "$(id -u)" = '0' ]; then fi # ------------------------------------------------------- +if [ -z "${INSTANCE_ID:-}" ]; then + SAFE_PROCESS_ROLE="${PROCESS_ROLE//,/-}" + export INSTANCE_ID="${SAFE_PROCESS_ROLE}-$(hostname)" +fi +echo "[entrypoint] INSTANCE_ID=${INSTANCE_ID}" + if role_contains "bootstrap"; then echo "[entrypoint] Step 1: Running alembic migrations for PROCESS_ROLE=${PROCESS_ROLE}..." set +e diff --git a/backend/pyproject.toml b/backend/pyproject.toml index b9a31d5bc..cb828c48d 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -44,6 +44,7 @@ dependencies = [ "markdown>=3.6", "beautifulsoup4>=4.12.0", "boto3>=1.35.0", + "aioboto3>=13.0.0", ] [project.optional-dependencies] diff --git a/backend/tests/test_agent_tools_storage_workspace.py b/backend/tests/test_agent_tools_storage_workspace.py new file mode 100644 index 000000000..0d97bb870 --- /dev/null +++ b/backend/tests/test_agent_tools_storage_workspace.py @@ -0,0 +1,310 @@ +import uuid + +import pytest + +from app.services import agent_tools +from app.services import workspace_collaboration +from app.services.storage_runtime.base import StorageBackend, StorageEntry, StorageVersion, WriteCondition, ConditionalWriteResult + + +class MemoryStorageBackend(StorageBackend): + def __init__(self, files: dict[str, bytes] | None = None): + self.files = dict(files or {}) + self.versions = {key: 1 for key in self.files} + + async def exists(self, key: str) -> bool: + return key in self.files + + async def is_file(self, key: str) -> bool: + return key in self.files + + async def is_dir(self, key: str) -> bool: + prefix = key.rstrip("/") + "/" + return any(existing.startswith(prefix) for existing in self.files) + + async def list_dir(self, key: str) -> list[StorageEntry]: + prefix = key.rstrip("/") + "/" + entries: dict[str, StorageEntry] = {} + for existing, data in self.files.items(): + if not existing.startswith(prefix): + continue + rest = existing.removeprefix(prefix) + name, _, tail = rest.partition("/") + entries[name] = StorageEntry( + name=name, + key=f"{prefix}{name}", + is_dir=bool(tail), + size=0 if tail else len(data), + ) + return sorted(entries.values(), key=lambda entry: (not entry.is_dir, entry.name)) + + async def read_bytes(self, key: str) -> bytes: + return self.files[key] + + async def write_bytes(self, key: str, data: bytes, content_type: str | None = None) -> None: + self.files[key] = data + self.versions[key] = self.versions.get(key, 0) + 1 + + async def delete(self, key: str) -> None: + self.files.pop(key, None) + self.versions.pop(key, None) + + async def delete_tree(self, key: str) -> None: + prefix = key.rstrip("/") + "/" + for existing in list(self.files): + if existing.startswith(prefix): + self.files.pop(existing) + self.versions.pop(existing, None) + + async def stat(self, key: str) -> StorageEntry: + return StorageEntry(name=key.rsplit("/", 1)[-1], key=key, is_dir=False, size=len(self.files[key])) + + async def get_version(self, key: str) -> StorageVersion: + if key not in self.files: + return StorageVersion(key=key, exists=False, is_dir=False) + version = str(self.versions.get(key, 0)) + return StorageVersion( + key=key, + exists=True, + is_dir=False, + size=len(self.files[key]), + version_id=version, + etag=version, + content_hash=version, + ) + + async def write_bytes_if_match( + self, + key: str, + data: bytes, + *, + condition: WriteCondition | None = None, + content_type: str | None = None, + ) -> ConditionalWriteResult: + current = await self.get_version(key) + if condition: + if condition.require_absent and current.exists: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + if condition.version_token is not None and current.token != condition.version_token: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + await self.write_bytes(key, data, content_type=content_type) + return ConditionalWriteResult(ok=True, current_version=await self.get_version(key)) + + +@pytest.mark.asyncio +async def test_agent_file_tools_use_storage_paths(monkeypatch): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/notes.md": b"# Notes\nneedle\n", + f"{agent_id}/memory/memory.md": b"# Memory\n", + }) + monkeypatch.setattr(agent_tools, "get_storage_backend", lambda: storage) + + listing = await agent_tools._storage_list_dir(agent_id, "workspace") + read = await agent_tools._storage_read_file(agent_id, "workspace/notes.md") + search = await agent_tools._storage_search_files(agent_id, "needle", path="workspace", file_pattern="*.md") + found = await agent_tools._storage_find_files(agent_id, "*.md", path="workspace") + + assert "notes.md" in listing + assert "needle" in read + assert "workspace/notes.md:2" in search + assert "workspace/notes.md" in found + + +@pytest.mark.asyncio +async def test_temp_workspace_materializes_only_requested_paths(monkeypatch): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/input.md": b"# Input\n", + f"{agent_id}/workspace/other.md": b"# Other\n", + }) + monkeypatch.setattr(agent_tools, "get_storage_backend", lambda: storage) + + temp_ws = await agent_tools._prepare_temp_workspace(agent_id, paths=["workspace/input.md"]) + try: + assert (temp_ws.root / "workspace" / "input.md").read_text(encoding="utf-8") == "# Input\n" + assert not (temp_ws.root / "workspace" / "other.md").exists() + finally: + temp_ws.cleanup() + + +@pytest.mark.asyncio +async def test_execute_tool_list_files_does_not_create_persistent_workspace(monkeypatch, tmp_path): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/input.md": b"# Input\n", + }) + monkeypatch.setattr(agent_tools, "get_storage_backend", lambda: storage) + monkeypatch.setattr(agent_tools, "WORKSPACE_ROOT", tmp_path) + + async def _tenant(_agent_id): + return None + + monkeypatch.setattr(agent_tools, "_get_agent_tenant_id", _tenant) + + result = await agent_tools.execute_tool("list_files", {"path": "workspace"}, agent_id, agent_id) + + assert "input.md" in result + assert not (tmp_path / str(agent_id)).exists() + + +@pytest.mark.asyncio +async def test_write_workspace_file_does_not_mirror_to_local_for_non_local_storage(monkeypatch, tmp_path): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend() + monkeypatch.setattr(workspace_collaboration, "get_storage_backend", lambda: storage) + + async def _noop_revision(*args, **kwargs): + return None + + monkeypatch.setattr(workspace_collaboration, "record_revision", _noop_revision) + + result = await workspace_collaboration.write_workspace_file( + db=None, + agent_id=agent_id, + base_dir=tmp_path / str(agent_id), + path="workspace/test.md", + content="hello", + actor_type="agent", + actor_id=agent_id, + enforce_human_lock=False, + ) + + assert result.ok is True + assert storage.files[f"{agent_id}/workspace/test.md"] == b"hello" + assert not (tmp_path / str(agent_id) / "workspace" / "test.md").exists() + + +@pytest.mark.asyncio +async def test_flush_temp_workspace_only_writes_changed_files(monkeypatch): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/input.md": b"# Input\n", + f"{agent_id}/workspace/other.md": b"# Other\n", + }) + monkeypatch.setattr(agent_tools, "get_storage_backend", lambda: storage) + + temp_ws = await agent_tools._prepare_temp_workspace(agent_id, paths=["workspace"]) + try: + (temp_ws.root / "workspace" / "input.md").write_text("# Updated\n", encoding="utf-8") + result = await agent_tools.flush_temp_workspace(temp_ws) + finally: + temp_ws.cleanup() + + assert result["updated"] == ["workspace/input.md"] + assert "workspace/other.md" in result["skipped"] + assert storage.files[f"{agent_id}/workspace/input.md"] == b"# Updated\n" + assert storage.files[f"{agent_id}/workspace/other.md"] == b"# Other\n" + + +@pytest.mark.asyncio +async def test_flush_temp_workspace_fails_on_conflict(monkeypatch): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/input.md": b"# Input\n", + }) + monkeypatch.setattr(agent_tools, "get_storage_backend", lambda: storage) + + temp_ws = await agent_tools._prepare_temp_workspace(agent_id, paths=["workspace/input.md"]) + try: + (temp_ws.root / "workspace" / "input.md").write_text("# Local change\n", encoding="utf-8") + await storage.write_bytes(f"{agent_id}/workspace/input.md", b"# Remote change\n") + result = await agent_tools.flush_temp_workspace(temp_ws) + finally: + temp_ws.cleanup() + + assert result["conflicted"] == ["workspace/input.md"] + assert storage.files[f"{agent_id}/workspace/input.md"] == b"# Remote change\n" + + +@pytest.mark.asyncio +async def test_write_workspace_file_fails_on_expected_version_conflict(monkeypatch, tmp_path): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/test.md": b"old", + }) + monkeypatch.setattr(workspace_collaboration, "get_storage_backend", lambda: storage) + + async def _noop_revision(*args, **kwargs): + return None + + monkeypatch.setattr(workspace_collaboration, "record_revision", _noop_revision) + + version = await storage.get_version(f"{agent_id}/workspace/test.md") + await storage.write_bytes(f"{agent_id}/workspace/test.md", b"remote-new") + result = await workspace_collaboration.write_workspace_file( + db=None, + agent_id=agent_id, + base_dir=tmp_path / str(agent_id), + path="workspace/test.md", + content="local-new", + actor_type="agent", + actor_id=agent_id, + enforce_human_lock=False, + expected_version_token=version.token, + ) + + assert result.ok is False + assert "Conflict detected" in result.message + assert storage.files[f"{agent_id}/workspace/test.md"] == b"remote-new" + + +@pytest.mark.asyncio +async def test_move_workspace_path_fails_when_source_changes(monkeypatch, tmp_path): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/source.md": b"old", + }) + monkeypatch.setattr(workspace_collaboration, "get_storage_backend", lambda: storage) + + async def _noop_revision(*args, **kwargs): + return None + + monkeypatch.setattr(workspace_collaboration, "record_revision", _noop_revision) + + version = await storage.get_version(f"{agent_id}/workspace/source.md") + await storage.write_bytes(f"{agent_id}/workspace/source.md", b"remote-new") + result = await workspace_collaboration.move_workspace_path( + db=None, + agent_id=agent_id, + base_dir=tmp_path / str(agent_id), + source_path="workspace/source.md", + destination_path="workspace/dest.md", + actor_type="agent", + actor_id=agent_id, + enforce_human_lock=False, + expected_source_version_token=version.token, + ) + + assert result.ok is False + assert "Conflict detected" in result.message + assert f"{agent_id}/workspace/dest.md" not in storage.files + + +@pytest.mark.asyncio +async def test_delete_workspace_directory_uses_prefix_existence(monkeypatch, tmp_path): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/dir/a.txt": b"a", + f"{agent_id}/workspace/dir/nested/b.txt": b"b", + }) + monkeypatch.setattr(workspace_collaboration, "get_storage_backend", lambda: storage) + + async def _noop_revision(*args, **kwargs): + return None + + monkeypatch.setattr(workspace_collaboration, "record_revision", _noop_revision) + + result = await workspace_collaboration.delete_workspace_file( + db=None, + agent_id=agent_id, + base_dir=tmp_path / str(agent_id), + path="workspace/dir", + actor_type="user", + actor_id=agent_id, + enforce_human_lock=False, + ) + + assert result.ok is True + assert f"{agent_id}/workspace/dir/a.txt" not in storage.files + assert f"{agent_id}/workspace/dir/nested/b.txt" not in storage.files diff --git a/backend/tests/test_files_api_storage.py b/backend/tests/test_files_api_storage.py new file mode 100644 index 000000000..f791c16a0 --- /dev/null +++ b/backend/tests/test_files_api_storage.py @@ -0,0 +1,156 @@ +import uuid +from types import SimpleNamespace + +import pytest + +from app.api import files +from app.services.agent_manager import AgentManager +from app.services.storage_runtime.base import StorageBackend, StorageEntry, StorageVersion + + +class PrefixOnlyStorage(StorageBackend): + def __init__(self, objects: dict[str, bytes] | None = None): + self.objects = dict(objects or {}) + + async def exists(self, key: str) -> bool: + return key in self.objects + + async def is_file(self, key: str) -> bool: + return key in self.objects + + async def is_dir(self, key: str) -> bool: + prefix = key.rstrip("/") + "/" + return any(existing.startswith(prefix) for existing in self.objects) + + async def list_dir(self, key: str) -> list[StorageEntry]: + prefix = key.rstrip("/") + "/" + entries_by_name: dict[str, StorageEntry] = {} + for existing, data in self.objects.items(): + if not existing.startswith(prefix): + continue + rest = existing.removeprefix(prefix) + name, _, tail = rest.partition("/") + entries_by_name[name] = StorageEntry( + name=name, + key=f"{prefix}{name}", + is_dir=bool(tail), + size=0 if tail else len(data), + ) + return sorted(entries_by_name.values(), key=lambda entry: (not entry.is_dir, entry.name)) + + async def read_bytes(self, key: str) -> bytes: + return self.objects[key] + + async def write_bytes(self, key: str, data: bytes, content_type: str | None = None) -> None: + self.objects[key] = data + + async def delete(self, key: str) -> None: + self.objects.pop(key, None) + + async def delete_tree(self, key: str) -> None: + prefix = key.rstrip("/") + "/" + for existing in list(self.objects): + if existing.startswith(prefix): + self.objects.pop(existing, None) + + async def stat(self, key: str) -> StorageEntry: + if key not in self.objects: + raise FileNotFoundError(key) + return StorageEntry(name=key.rsplit("/", 1)[-1], key=key, is_dir=False, size=len(self.objects[key])) + + async def get_version(self, key: str) -> StorageVersion: + if key not in self.objects: + return StorageVersion(key=key, exists=False, is_dir=False) + token = f"v:{len(self.objects[key])}" + return StorageVersion( + key=key, + exists=True, + is_dir=False, + size=len(self.objects[key]), + version_id=token, + etag=token, + content_hash=token, + ) + + +@pytest.mark.asyncio +async def test_list_files_accepts_s3_prefix_directory(monkeypatch): + agent_id = uuid.uuid4() + storage = PrefixOnlyStorage({f"{agent_id}/focus.md": b"# Focus\n"}) + monkeypatch.setattr(files, "get_storage_backend", lambda: storage) + + async def allow_access(*args, **kwargs): + return None + + monkeypatch.setattr(files, "check_agent_access", allow_access) + user = SimpleNamespace(tenant_id=None) + + result = await files.list_files(agent_id, path="", current_user=user, db=None) + + assert [item.name for item in result] == ["focus.md"] + assert result[0].path == "focus.md" + assert result[0].version_token == "v:8" + + +@pytest.mark.asyncio +async def test_list_files_allows_empty_agent_root(monkeypatch): + agent_id = uuid.uuid4() + monkeypatch.setattr(files, "get_storage_backend", lambda: PrefixOnlyStorage()) + + async def allow_access(*args, **kwargs): + return None + + monkeypatch.setattr(files, "check_agent_access", allow_access) + user = SimpleNamespace(tenant_id=None) + + assert await files.list_files(agent_id, path="", current_user=user, db=None) == [] + + +@pytest.mark.asyncio +async def test_read_file_returns_version_token(monkeypatch): + agent_id = uuid.uuid4() + storage = PrefixOnlyStorage({f"{agent_id}/focus.md": b"# Focus\n"}) + monkeypatch.setattr(files, "get_storage_backend", lambda: storage) + + async def allow_access(*args, **kwargs): + return None + + monkeypatch.setattr(files, "check_agent_access", allow_access) + user = SimpleNamespace(tenant_id=None) + + result = await files.read_file(agent_id, path="focus.md", current_user=user, db=None) + + assert result.version_token == "v:8" + + +@pytest.mark.asyncio +async def test_agent_manager_does_not_reinitialize_s3_prefix_directory(monkeypatch, tmp_path): + agent_id = uuid.uuid4() + storage = PrefixOnlyStorage({f"{agent_id}/soul.md": b"existing"}) + monkeypatch.setattr("app.services.agent_manager.get_storage_backend", lambda: storage) + monkeypatch.setattr("app.services.agent_manager.settings.STORAGE_LOCAL_ROOT", str(tmp_path)) + + manager = AgentManager() + agent = SimpleNamespace(id=agent_id) + + await manager.initialize_agent_files(db=None, agent=agent) + + assert storage.objects[f"{agent_id}/soul.md"] == b"existing" + + +@pytest.mark.asyncio +async def test_agent_manager_materializes_s3_prefix_directory(monkeypatch, tmp_path): + agent_id = uuid.uuid4() + storage = PrefixOnlyStorage({ + f"{agent_id}/soul.md": b"# Soul\n", + f"{agent_id}/memory/memory.md": b"# Memory\n", + }) + monkeypatch.setattr("app.services.agent_manager.get_storage_backend", lambda: storage) + monkeypatch.setattr("app.services.agent_manager.settings.STORAGE_LOCAL_ROOT", str(tmp_path)) + + manager = AgentManager() + + agent_dir = await manager._materialize_agent_dir(agent_id) + + assert (agent_dir / "soul.md").read_text(encoding="utf-8") == "# Soul\n" + assert (agent_dir / "memory" / "memory.md").read_text(encoding="utf-8") == "# Memory\n" diff --git a/backend/tests/test_org_sync_adapter.py b/backend/tests/test_org_sync_adapter.py index d4fa71cfa..41f9b1443 100644 --- a/backend/tests/test_org_sync_adapter.py +++ b/backend/tests/test_org_sync_adapter.py @@ -1,6 +1,7 @@ import asyncio import uuid from contextlib import asynccontextmanager +from datetime import datetime, timezone from types import SimpleNamespace import pytest @@ -43,6 +44,14 @@ async def flush(self): self.flush_calls += 1 +class _RecordingExecuteDB: + def __init__(self): + self.statements = [] + + async def execute(self, statement): + self.statements.append(statement) + + class _SyncAdapterWithFailure(_DummyAdapter): def __init__(self): super().__init__() @@ -115,6 +124,17 @@ def test_sync_org_structure_skips_reconcile_after_member_failure(): assert "Reconcile skipped due to partial sync failures" in result["errors"] +def test_reconcile_disables_session_synchronization_for_datetime_comparisons(): + adapter = _DummyAdapter() + db = _RecordingExecuteDB() + + asyncio.run(adapter._reconcile(db, uuid.uuid4(), datetime.now(timezone.utc))) + + assert len(db.statements) == 2 + for statement in db.statements: + assert statement.get_execution_options()["synchronize_session"] is False + + def test_google_workspace_adapter_parses_legacy_service_account_json_string(): adapter = GoogleWorkspaceOrgSyncAdapter( config={ diff --git a/backend/tests/test_storage_s3.py b/backend/tests/test_storage_s3.py new file mode 100644 index 000000000..6c7dc3329 --- /dev/null +++ b/backend/tests/test_storage_s3.py @@ -0,0 +1,44 @@ +from unittest.mock import Mock + +from app.services.storage_runtime.s3 import S3StorageBackend + + +def test_s3_backend_passes_max_pool_connections(monkeypatch): + config_instances: list[object] = [] + client_calls: list[dict] = [] + + class FakeConfig: + def __init__(self, **kwargs): + self.kwargs = kwargs + config_instances.append(self) + + fake_boto3 = Mock() + fake_boto3.client.side_effect = lambda *args, **kwargs: client_calls.append(kwargs) or object() + + import builtins + + real_import = builtins.__import__ + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + if name == "boto3": + return fake_boto3 + if name == "botocore.config": + return type("FakeBotocoreConfigModule", (), {"Config": FakeConfig})() + return real_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + backend = S3StorageBackend( + bucket="bucket", + endpoint_url="http://minio:9000", + access_key_id="key", + secret_access_key="secret", + max_pool_connections=64, + ) + + backend._client_or_raise() + + assert len(config_instances) == 1 + assert config_instances[0].kwargs["max_pool_connections"] == 64 + assert len(client_calls) == 1 + assert client_calls[0]["config"] is config_instances[0] diff --git a/deploy/nginx/multi-instance.conf b/deploy/nginx/multi-instance.conf new file mode 100644 index 000000000..16870b324 --- /dev/null +++ b/deploy/nginx/multi-instance.conf @@ -0,0 +1,30 @@ +upstream clawith_multi_api { + least_conn; + server backend-api-1:8000 max_fails=3 fail_timeout=15s; + server backend-api-2:8000 max_fails=3 fail_timeout=15s; +} + +map $http_upgrade $connection_upgrade { + default upgrade; + '' close; +} + +server { + listen 8000; + server_name _; + + client_max_body_size 64m; + + location / { + proxy_pass http://clawith_multi_api; + proxy_http_version 1.1; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection $connection_upgrade; + proxy_read_timeout 3600s; + proxy_send_timeout 3600s; + } +} diff --git a/frontend/src/components/MarkdownRenderer.tsx b/frontend/src/components/MarkdownRenderer.tsx index 960cc5af4..8ef865610 100644 --- a/frontend/src/components/MarkdownRenderer.tsx +++ b/frontend/src/components/MarkdownRenderer.tsx @@ -57,7 +57,7 @@ function autolinkBareUrls(html: string): string { function renderInline(text: string): string { const tokens: string[] = []; const stash = (html: string) => { - const key = `@@__MD_TOKEN_${tokens.length}__@@`; + const key = `@@CLAWITHMDTOKEN${tokens.length}@@`; tokens.push(html); return key; }; @@ -89,7 +89,7 @@ function renderInline(text: string): string { working = autolinkBareUrls(working); tokens.forEach((html, i) => { - working = working.replace(new RegExp(`@@__MD_TOKEN_${i}__@@`, 'g'), html); + working = working.replace(new RegExp(`@@CLAWITHMDTOKEN${i}@@`, 'g'), html); }); return working; } From 9cc5f5b65789678b21c2a58d2211c350a6eccf34 Mon Sep 17 00:00:00 2001 From: yaojin3616 Date: Fri, 8 May 2026 20:12:59 +0800 Subject: [PATCH 06/12] fix(files): restore list_files filtering for focus.md and .gitkeep Lost during merge conflict resolution. Re-adds main's filtering logic to the storage-backend list_files implementation: - Skip .gitkeep files - Hide focus.md / agenda.md at workspace root (managed via Focus API) - Filter enterprise_info at root (already injected as virtual entry) --- backend/app/api/files.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/backend/app/api/files.py b/backend/app/api/files.py index 2f2dddc47..690e0901a 100644 --- a/backend/app/api/files.py +++ b/backend/app/api/files.py @@ -164,6 +164,12 @@ async def list_files( )) entries = await storage.list_dir(storage_key) if path_exists or path_is_dir else [] for entry in entries: + if entry.name == '.gitkeep': + continue + if not path and entry.name.lower() in {"focus.md", "agenda.md"}: + continue + if not path and entry.name == "enterprise_info": + continue if is_enterprise: rel = str(Path(entry.key).relative_to(f"enterprise_info_{current_user.tenant_id}")) rel_path = f"enterprise_info/{rel}" if rel != "." else "enterprise_info" From a74d8d7ea580128c7a95eca8506b6e09e30e92bb Mon Sep 17 00:00:00 2001 From: yaojin3616 Date: Fri, 8 May 2026 20:19:58 +0800 Subject: [PATCH 07/12] fix(security): restore missing focus guards and path traversal checks Merge conflict resolution dropped several defensive checks from main: files.py: - Add is_focus_file_path check to read_file endpoint (410 GONE) - Add .gitkeep filtering to list_enterprise_kb_files agent_tools.py (_execute_workspace_mutation): - Add is_focus_file_path guard to write_file, move_file, delete_file, edit_file (returns user-friendly error directing to Focus API) Enterprise endpoints' path safety is already guaranteed by normalize_storage_key which strips ../ traversal, so the explicit startswith checks from main are not required here. --- backend/app/api/files.py | 7 +++++++ backend/app/services/agent_tools.py | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/backend/app/api/files.py b/backend/app/api/files.py index 690e0901a..33d830fe8 100644 --- a/backend/app/api/files.py +++ b/backend/app/api/files.py @@ -196,6 +196,11 @@ async def read_file( ): """Read the content of a file.""" await check_agent_access(db, current_user, agent_id) + if is_focus_file_path(path): + raise HTTPException( + status_code=status.HTTP_410_GONE, + detail="Focus is stored in the system database. Use the Focus API.", + ) storage = get_storage_backend() key, _ = _visible_storage_key(agent_id, path, current_user.tenant_id) if not await storage.exists(key) or not await storage.is_file(key): @@ -841,6 +846,8 @@ async def list_enterprise_kb_files( items = [] for entry in await storage.list_dir(storage_key): + if entry.name == '.gitkeep': + continue rel = str(Path(entry.key).relative_to(f"enterprise_info_{current_user.tenant_id}")) items.append({ "name": entry.name, diff --git a/backend/app/services/agent_tools.py b/backend/app/services/agent_tools.py index e27e84cf8..7cb3a2441 100644 --- a/backend/app/services/agent_tools.py +++ b/backend/app/services/agent_tools.py @@ -2490,6 +2490,8 @@ async def _execute_workspace_mutation( return "❌ Missing required argument 'path' for write_file. Please provide a file path like 'skills/my-skill/SKILL.md'" if content is None: return "❌ Missing required argument 'content' for write_file" + if is_focus_file_path(path): + return "❌ Focus is no longer stored in focus.md. Use upsert_focus_item or complete_focus_item." if _is_enterprise_info_path(path): return "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." async with async_session() as _wdb: @@ -2519,6 +2521,8 @@ async def _execute_workspace_mutation( return "❌ Missing required argument 'source_path' for move_file" if not destination_path: return "❌ Missing required argument 'destination_path' for move_file" + if is_focus_file_path(source_path) or is_focus_file_path(destination_path): + return "❌ Focus is no longer stored in focus.md. Use Focus tools instead." if str(source_path).strip("/") in {"tasks.json", "soul.md"}: return f"❌ {source_path} cannot be moved (protected)" if _is_enterprise_info_path(source_path) or _is_enterprise_info_path(destination_path): @@ -2541,6 +2545,8 @@ async def _execute_workspace_mutation( if tool_name == "delete_file": path = arguments.get("path", "") + if is_focus_file_path(path): + return "❌ Focus is no longer stored in focus.md. Use Focus tools instead." if _is_enterprise_info_path(path): return "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." async with async_session() as _wdb: @@ -2567,6 +2573,8 @@ async def _execute_workspace_mutation( return "❌ Missing required argument 'old_string' for edit_file" if new_string is None: return "❌ Missing required argument 'new_string' for edit_file" + if is_focus_file_path(path): + return "❌ Focus is no longer stored in focus.md. Use upsert_focus_item or complete_focus_item." if _is_enterprise_info_path(path): return "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." From d556e1d1d769dc0e3eb0bb356b0580747f4dd666 Mon Sep 17 00:00:00 2001 From: yaojin3616 Date: Sat, 9 May 2026 13:56:54 +0800 Subject: [PATCH 08/12] chore: remove deploy-all-in-one.sh --- deploy/deploy-all-in-one.sh | 53 ------------------------------------- 1 file changed, 53 deletions(-) delete mode 100755 deploy/deploy-all-in-one.sh diff --git a/deploy/deploy-all-in-one.sh b/deploy/deploy-all-in-one.sh deleted file mode 100755 index 5e4fc42ca..000000000 --- a/deploy/deploy-all-in-one.sh +++ /dev/null @@ -1,53 +0,0 @@ -#!/bin/bash -# All-in-One deployment script for Clawith -# Target: 192.168.106.163:/root/yaojin/clawith-yaojin - -set -e - -REMOTE_HOST="root@192.168.106.163" -REMOTE_PASS="dataelem" -REMOTE_DIR="/root/yaojin/clawith-yaojin" -LOCAL_DIR="$(cd "$(dirname "$0")/.." && pwd)" - -echo "=== Clawith All-in-One Deployment ===" -echo "Local: $LOCAL_DIR" -echo "Remote: $REMOTE_HOST:$REMOTE_DIR" -echo "" - -# Step 1: Clean remote directory -echo "[1/4] Cleaning remote directory..." -sshpass -p "$REMOTE_PASS" ssh "$REMOTE_HOST" "rm -rf $REMOTE_DIR/*" - -# Step 2: Sync code (excluding unnecessary files) -echo "[2/4] Syncing code..." -sshpass -p "$REMOTE_PASS" rsync -avz --progress \ - --exclude='node_modules' \ - --exclude='.git' \ - --exclude='frontend/dist' \ - --exclude='frontend/build' \ - --exclude='__pycache__' \ - --exclude='*.pyc' \ - --exclude='.env' \ - "$LOCAL_DIR/" "$REMOTE_HOST:$REMOTE_DIR/" - -# Step 3: Copy deploy configs to root -echo "[3/4] Setting up deploy configs..." -sshpass -p "$REMOTE_PASS" ssh "$REMOTE_HOST" "cd $REMOTE_DIR && \ - cp deploy/docker-compose.yml . && \ - cp deploy/.env.example .env 2>/dev/null || true && \ - mkdir -p nginx && \ - cp deploy/nginx/nginx.conf nginx/ && \ - cp deploy/nginx/all-in-one.conf nginx/" - -# Step 4: Build and start services -echo "[4/4] Building and starting services..." -sshpass -p "$REMOTE_PASS" ssh "$REMOTE_HOST" "cd $REMOTE_DIR && \ - rm -rf frontend/dist frontend/build && \ - docker compose build backend --no-cache && \ - docker compose build frontend --no-cache && \ - docker compose up -d" - -echo "" -echo "=== Deployment Complete ===" -echo "Frontend: http://192.168.106.163:3008" -echo "Backend: http://192.168.106.163:8000" From 9a4b6e527b3d0410d4660c90fa2a35c2338df580 Mon Sep 17 00:00:00 2001 From: yaojin3616 Date: Sat, 9 May 2026 14:01:50 +0800 Subject: [PATCH 09/12] chore(deploy): remove unused nginx configs --- deploy/nginx/all-in-one.conf | 29 ----------------------------- deploy/nginx/multi-instance.conf | 30 ------------------------------ 2 files changed, 59 deletions(-) delete mode 100644 deploy/nginx/all-in-one.conf delete mode 100644 deploy/nginx/multi-instance.conf diff --git a/deploy/nginx/all-in-one.conf b/deploy/nginx/all-in-one.conf deleted file mode 100644 index 6c24dfa82..000000000 --- a/deploy/nginx/all-in-one.conf +++ /dev/null @@ -1,29 +0,0 @@ -upstream clawith_role_all_backend { - least_conn; - server backend:8000 max_fails=3 fail_timeout=15s; -} - -map $http_upgrade $connection_upgrade { - default upgrade; - '' close; -} - -server { - listen 8000; - server_name _; - - client_max_body_size 64m; - - location / { - proxy_pass http://clawith_role_all_backend; - proxy_http_version 1.1; - proxy_set_header Host $host; - proxy_set_header X-Real-IP $remote_addr; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Forwarded-Proto $scheme; - proxy_set_header Upgrade $http_upgrade; - proxy_set_header Connection $connection_upgrade; - proxy_read_timeout 3600s; - proxy_send_timeout 3600s; - } -} diff --git a/deploy/nginx/multi-instance.conf b/deploy/nginx/multi-instance.conf deleted file mode 100644 index 16870b324..000000000 --- a/deploy/nginx/multi-instance.conf +++ /dev/null @@ -1,30 +0,0 @@ -upstream clawith_multi_api { - least_conn; - server backend-api-1:8000 max_fails=3 fail_timeout=15s; - server backend-api-2:8000 max_fails=3 fail_timeout=15s; -} - -map $http_upgrade $connection_upgrade { - default upgrade; - '' close; -} - -server { - listen 8000; - server_name _; - - client_max_body_size 64m; - - location / { - proxy_pass http://clawith_multi_api; - proxy_http_version 1.1; - proxy_set_header Host $host; - proxy_set_header X-Real-IP $remote_addr; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Forwarded-Proto $scheme; - proxy_set_header Upgrade $http_upgrade; - proxy_set_header Connection $connection_upgrade; - proxy_read_timeout 3600s; - proxy_send_timeout 3600s; - } -} From 49a3060d96ce3cdfa7bd9b77b2608ea0c0562414 Mon Sep 17 00:00:00 2001 From: yaojin3616 Date: Wed, 13 May 2026 17:40:16 +0800 Subject: [PATCH 10/12] Add multi-instance deploy compose --- deploy/docker-compose-multi.yml | 300 ++++++++++++++++++++++++++++++++ deploy/docker-compose.yml | 2 +- 2 files changed, 301 insertions(+), 1 deletion(-) create mode 100644 deploy/docker-compose-multi.yml diff --git a/deploy/docker-compose-multi.yml b/deploy/docker-compose-multi.yml new file mode 100644 index 000000000..f8ba4c73c --- /dev/null +++ b/deploy/docker-compose-multi.yml @@ -0,0 +1,300 @@ +services: + postgres: + image: postgres:15-alpine + restart: unless-stopped + networks: + - default + environment: + POSTGRES_USER: clawith + POSTGRES_PASSWORD: clawith + POSTGRES_DB: clawith + volumes: + - pgdata:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U clawith"] + interval: 5s + timeout: 5s + retries: 5 + + redis: + image: redis:7-alpine + restart: unless-stopped + networks: + - default + volumes: + - redisdata:/data + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 5s + timeout: 5s + retries: 5 + + minio: + image: minio/minio:RELEASE.2025-04-22T22-12-26Z + restart: unless-stopped + command: server /data --console-address ":9001" + networks: + - default + environment: + MINIO_ROOT_USER: ${MINIO_ROOT_USER:-clawith} + MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD:-clawith-minio-secret} + volumes: + - miniodata:/data + healthcheck: + test: ["CMD-SHELL", "curl -fsS http://127.0.0.1:9000/minio/health/live >/dev/null"] + interval: 10s + timeout: 5s + retries: 5 + + backend-api-1: + build: + context: ./backend + args: + CLAWITH_PIP_INDEX_URL: ${CLAWITH_PIP_INDEX_URL:-} + CLAWITH_PIP_TRUSTED_HOST: ${CLAWITH_PIP_TRUSTED_HOST:-} + restart: unless-stopped + command: ["/bin/bash", "/app/entrypoint.sh"] + environment: + DATABASE_URL: postgresql+asyncpg://clawith:clawith@postgres:5432/clawith + REDIS_URL: redis://redis:6379/0 + AGENT_DATA_DIR: /data/agents + AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents + STORAGE_LOCAL_FALLBACK_ENABLED: ${STORAGE_LOCAL_FALLBACK_ENABLED:-true} + S3_BUCKET: ${S3_BUCKET:-clawith} + S3_REGION: ${S3_REGION:-us-east-1} + S3_ENDPOINT_URL: ${S3_ENDPOINT_URL:-http://minio:9000} + S3_ACCESS_KEY_ID: ${S3_ACCESS_KEY_ID:-${MINIO_ROOT_USER:-clawith}} + S3_SECRET_ACCESS_KEY: ${S3_SECRET_ACCESS_KEY:-${MINIO_ROOT_PASSWORD:-clawith-minio-secret}} + S3_PREFIX: ${S3_PREFIX:-agents} + SECRET_KEY: ${SECRET_KEY:-change-me-in-production} + JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: api + CORS_ORIGINS: '["*"]' + FEISHU_APP_ID: ${FEISHU_APP_ID:-} + FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} + DOCKER_NETWORK: clawith_yaojin_network + SS_CONFIG_FILE: /data/ss-nodes.json + PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-} + PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: ${PASSWORD_RESET_TOKEN_EXPIRE_MINUTES:-30} + volumes: + - ./backend:/app + - ./backend/agent_data:/data/agents + - /var/run/docker.sock:/var/run/docker.sock + - ./ss-nodes.json:/data/ss-nodes.json:ro + cap_add: + - SYS_ADMIN + security_opt: + - seccomp=unconfined + networks: + default: + aliases: + - backend + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + minio: + condition: service_healthy + logging: + driver: json-file + options: + max-size: "10m" + max-file: "3" + + backend-api-2: + build: + context: ./backend + args: + CLAWITH_PIP_INDEX_URL: ${CLAWITH_PIP_INDEX_URL:-} + CLAWITH_PIP_TRUSTED_HOST: ${CLAWITH_PIP_TRUSTED_HOST:-} + restart: unless-stopped + command: ["/bin/bash", "/app/entrypoint.sh"] + environment: + DATABASE_URL: postgresql+asyncpg://clawith:clawith@postgres:5432/clawith + REDIS_URL: redis://redis:6379/0 + AGENT_DATA_DIR: /data/agents + AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents + STORAGE_LOCAL_FALLBACK_ENABLED: ${STORAGE_LOCAL_FALLBACK_ENABLED:-true} + S3_BUCKET: ${S3_BUCKET:-clawith} + S3_REGION: ${S3_REGION:-us-east-1} + S3_ENDPOINT_URL: ${S3_ENDPOINT_URL:-http://minio:9000} + S3_ACCESS_KEY_ID: ${S3_ACCESS_KEY_ID:-${MINIO_ROOT_USER:-clawith}} + S3_SECRET_ACCESS_KEY: ${S3_SECRET_ACCESS_KEY:-${MINIO_ROOT_PASSWORD:-clawith-minio-secret}} + S3_PREFIX: ${S3_PREFIX:-agents} + SECRET_KEY: ${SECRET_KEY:-change-me-in-production} + JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: api + CORS_ORIGINS: '["*"]' + FEISHU_APP_ID: ${FEISHU_APP_ID:-} + FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} + DOCKER_NETWORK: clawith_yaojin_network + SS_CONFIG_FILE: /data/ss-nodes.json + PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-} + PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: ${PASSWORD_RESET_TOKEN_EXPIRE_MINUTES:-30} + volumes: + - ./backend:/app + - ./backend/agent_data:/data/agents + - /var/run/docker.sock:/var/run/docker.sock + - ./ss-nodes.json:/data/ss-nodes.json:ro + cap_add: + - SYS_ADMIN + security_opt: + - seccomp=unconfined + networks: + default: + aliases: + - backend + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + minio: + condition: service_healthy + logging: + driver: json-file + options: + max-size: "10m" + max-file: "3" + + backend-worker: + build: + context: ./backend + args: + CLAWITH_PIP_INDEX_URL: ${CLAWITH_PIP_INDEX_URL:-} + CLAWITH_PIP_TRUSTED_HOST: ${CLAWITH_PIP_TRUSTED_HOST:-} + restart: unless-stopped + command: ["/bin/bash", "/app/entrypoint.sh"] + environment: + DATABASE_URL: postgresql+asyncpg://clawith:clawith@postgres:5432/clawith + REDIS_URL: redis://redis:6379/0 + AGENT_DATA_DIR: /data/agents + AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents + STORAGE_LOCAL_FALLBACK_ENABLED: ${STORAGE_LOCAL_FALLBACK_ENABLED:-true} + S3_BUCKET: ${S3_BUCKET:-clawith} + S3_REGION: ${S3_REGION:-us-east-1} + S3_ENDPOINT_URL: ${S3_ENDPOINT_URL:-http://minio:9000} + S3_ACCESS_KEY_ID: ${S3_ACCESS_KEY_ID:-${MINIO_ROOT_USER:-clawith}} + S3_SECRET_ACCESS_KEY: ${S3_SECRET_ACCESS_KEY:-${MINIO_ROOT_PASSWORD:-clawith-minio-secret}} + S3_PREFIX: ${S3_PREFIX:-agents} + SECRET_KEY: ${SECRET_KEY:-change-me-in-production} + JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: bootstrap,worker + CORS_ORIGINS: '["*"]' + FEISHU_APP_ID: ${FEISHU_APP_ID:-} + FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} + DOCKER_NETWORK: clawith_yaojin_network + SS_CONFIG_FILE: /data/ss-nodes.json + PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-} + PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: ${PASSWORD_RESET_TOKEN_EXPIRE_MINUTES:-30} + volumes: + - ./backend:/app + - ./backend/agent_data:/data/agents + - /var/run/docker.sock:/var/run/docker.sock + - ./ss-nodes.json:/data/ss-nodes.json:ro + cap_add: + - SYS_ADMIN + security_opt: + - seccomp=unconfined + networks: + - default + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + minio: + condition: service_healthy + logging: + driver: json-file + options: + max-size: "10m" + max-file: "3" + + backend-connector: + build: + context: ./backend + args: + CLAWITH_PIP_INDEX_URL: ${CLAWITH_PIP_INDEX_URL:-} + CLAWITH_PIP_TRUSTED_HOST: ${CLAWITH_PIP_TRUSTED_HOST:-} + restart: unless-stopped + command: ["/bin/bash", "/app/entrypoint.sh"] + environment: + DATABASE_URL: postgresql+asyncpg://clawith:clawith@postgres:5432/clawith + REDIS_URL: redis://redis:6379/0 + AGENT_DATA_DIR: /data/agents + AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents + STORAGE_LOCAL_FALLBACK_ENABLED: ${STORAGE_LOCAL_FALLBACK_ENABLED:-true} + S3_BUCKET: ${S3_BUCKET:-clawith} + S3_REGION: ${S3_REGION:-us-east-1} + S3_ENDPOINT_URL: ${S3_ENDPOINT_URL:-http://minio:9000} + S3_ACCESS_KEY_ID: ${S3_ACCESS_KEY_ID:-${MINIO_ROOT_USER:-clawith}} + S3_SECRET_ACCESS_KEY: ${S3_SECRET_ACCESS_KEY:-${MINIO_ROOT_PASSWORD:-clawith-minio-secret}} + S3_PREFIX: ${S3_PREFIX:-agents} + SECRET_KEY: ${SECRET_KEY:-change-me-in-production} + JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: connector + CORS_ORIGINS: '["*"]' + FEISHU_APP_ID: ${FEISHU_APP_ID:-} + FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} + DOCKER_NETWORK: clawith_yaojin_network + SS_CONFIG_FILE: /data/ss-nodes.json + PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-} + PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: ${PASSWORD_RESET_TOKEN_EXPIRE_MINUTES:-30} + volumes: + - ./backend:/app + - ./backend/agent_data:/data/agents + - /var/run/docker.sock:/var/run/docker.sock + - ./ss-nodes.json:/data/ss-nodes.json:ro + cap_add: + - SYS_ADMIN + security_opt: + - seccomp=unconfined + networks: + - default + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + minio: + condition: service_healthy + logging: + driver: json-file + options: + max-size: "10m" + max-file: "3" + + frontend: + build: ./frontend + restart: unless-stopped + ports: + - "${FRONTEND_PORT:-3008}:3000" + environment: + VITE_API_URL: http://localhost:8000 + API_UPSTREAM: ${API_UPSTREAM:-backend:8000} + volumes: + - ./frontend/nginx.conf.template:/etc/nginx/templates/default.conf.template:ro + networks: + - default + depends_on: + - backend-api-1 + - backend-api-2 + +volumes: + pgdata: + redisdata: + miniodata: + +networks: + default: + name: clawith_network diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index e2c1652d3..ba0a52430 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -101,4 +101,4 @@ volumes: networks: default: - name: clawith_yaojin_network + name: clawith_network From 30fce161ca57a94e7ec3745caf62a2a8d9f42e38 Mon Sep 17 00:00:00 2001 From: yaojin3616 Date: Mon, 18 May 2026 19:21:09 +0800 Subject: [PATCH 11/12] Make bcrypt operations async to avoid blocking event loop Use ThreadPoolExecutor for password hashing/verification, update agents API, chat sessions, and agent detail page. --- backend/app/api/agents.py | 38 +++++++++- backend/app/api/auth.py | 14 ++-- backend/app/api/chat_sessions.py | 33 ++++----- backend/app/core/security.py | 35 +++++++--- backend/app/services/registration_service.py | 4 +- .../pages/agent-detail/AgentDetailPage.tsx | 69 ++++++++++++++++++- 6 files changed, 157 insertions(+), 36 deletions(-) diff --git a/backend/app/api/agents.py b/backend/app/api/agents.py index a941b83d8..c07068d81 100644 --- a/backend/app/api/agents.py +++ b/backend/app/api/agents.py @@ -513,14 +513,48 @@ async def create_agent( f"on agent {agent.id} raised: {e}" ) - # Start container + # Start container first (non-blocking if Docker available) await agent_manager.start_container(db, agent) await db.flush() + # Commit agent and basic setup before async operations from app.services.okr_agent_hook import hook_new_agent if agent.tenant_id: await hook_new_agent(db, agent.id, agent.tenant_id) - await db.commit() + await db.commit() + await db.refresh(agent) + + # MCP import runs in background to avoid blocking the response + if template_mcp_servers: + import asyncio + from loguru import logger + from app.services.resource_discovery import import_mcp_from_smithery + + async def _background_mcp_import(agent_id: uuid.UUID, server_ids: list[str]): + for server_id in server_ids: + try: + result_msg = await import_mcp_from_smithery( + server_id=server_id, + agent_id=agent_id, + config={}, + ) + if result_msg.startswith("❌"): + logger.warning( + f"[create_agent] MCP pre-install for '{server_id}' " + f"on agent {agent_id} reported error: {result_msg[:200]}" + ) + else: + logger.info( + f"[create_agent] MCP pre-install '{server_id}' " + f"succeeded for agent {agent_id}" + ) + except Exception as e: + logger.warning( + f"[create_agent] MCP pre-install for '{server_id}' " + f"on agent {agent_id} raised: {e}" + ) + + asyncio.create_task(_background_mcp_import(agent.id, template_mcp_servers)) return await _agent_to_out(db, agent, current_user.id) diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py index 8cff5f0b3..a61571f49 100644 --- a/backend/app/api/auth.py +++ b/backend/app/api/auth.py @@ -11,7 +11,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.core.security import create_access_token, get_authenticated_user, get_current_user, hash_password, verify_password +from app.core.security import create_access_token, get_authenticated_user, get_current_user, hash_password_async, verify_password_async from app.database import get_db from app.models.user import Identity, User from app.schemas.schemas import ( @@ -194,7 +194,7 @@ async def register_init( ) # If identity existed, verify password - if identity.password_hash and not verify_password(data.password, identity.password_hash): + if identity.password_hash and not await verify_password_async(data.password, identity.password_hash): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Email already registered. Incorrect password." @@ -445,7 +445,7 @@ async def login(data: UserLogin, background_tasks: BackgroundTasks, db: AsyncSes result = await db.execute(query) identity = result.scalar_one_or_none() - if not identity or not identity.password_hash or not verify_password(data.password, identity.password_hash): + if not identity or not identity.password_hash or not await verify_password_async(data.password, identity.password_hash): logger.warning(f"[LOGIN] Invalid credentials for {data.login_identifier} identity_id={identity.id if identity else 'None'}") raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials") @@ -669,9 +669,9 @@ async def reset_password(data: ResetPasswordRequest, db: AsyncSession = Depends( if not identity or not identity.is_active: raise HTTPException(status_code=400, detail="Invalid or expired reset token") - new_hash = hash_password(data.new_password) + new_hash = await hash_password_async(data.new_password) identity.password_hash = new_hash - + await db.flush() await db.commit() return {"ok": True} @@ -863,10 +863,10 @@ async def change_password( user = res.scalar_one() identity = user.identity - if not identity or not identity.password_hash or not verify_password(old_password, identity.password_hash): + if not identity or not identity.password_hash or not await verify_password_async(old_password, identity.password_hash): raise HTTPException(status_code=400, detail="Current password is incorrect") - new_hash = hash_password(new_password) + new_hash = await hash_password_async(new_password) identity.password_hash = new_hash await db.flush() diff --git a/backend/app/api/chat_sessions.py b/backend/app/api/chat_sessions.py index 356d14700..f6210dd82 100644 --- a/backend/app/api/chat_sessions.py +++ b/backend/app/api/chat_sessions.py @@ -361,6 +361,8 @@ async def delete_session( async def get_session_messages( agent_id: uuid.UUID, session_id: uuid.UUID, + limit: int = Query(20, ge=1, le=500, description="Number of messages to return"), + offset: int = Query(0, ge=0, description="Offset for pagination"), current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): @@ -382,35 +384,34 @@ async def get_session_messages( raise HTTPException(status_code=403, detail="Not authorized to view this session") # Query messages by conversation_id only (agent-to-agent uses session_agent_id) - # Query the latest 500 messages (subquery in DESC, then reverse for display order) + # Optimized: use a single query with ORDER BY and LIMIT instead of subquery from sqlalchemy import desc - latest_subq = ( - select(ChatMessage.id) - .where(ChatMessage.conversation_id == str(session_id)) - .order_by(desc(ChatMessage.created_at)) - .limit(500) - .subquery() - ) msgs_result = await db.execute( select(ChatMessage) - .where(ChatMessage.id.in_(select(latest_subq.c.id))) - .order_by(ChatMessage.created_at.asc()) + .where(ChatMessage.conversation_id == str(session_id)) + .order_by(desc(ChatMessage.created_at)) + .limit(limit) + .offset(offset) ) - messages = msgs_result.scalars().all() + messages = list(reversed(msgs_result.scalars().all())) # Reading your own first-party/channel session should clear its unread state. if str(session.user_id) == str(current_user.id) and not session.is_group and session.source_channel not in ("agent", "trigger"): session.last_read_at_by_user = datetime.now(tz.utc) await db.commit() - # Resolve sender names for agent sessions + # Batch fetch all participant names to avoid N+1 queries sender_cache: dict = {} if session.source_channel == "agent": from app.models.participant import Participant - for m in messages: - if m.participant_id and str(m.participant_id) not in sender_cache: - p_r = await db.execute(select(Participant.display_name).where(Participant.id == m.participant_id)) - sender_cache[str(m.participant_id)] = p_r.scalar_one_or_none() or "Unknown" + participant_ids = list({m.participant_id for m in messages if m.participant_id}) + if participant_ids: + p_result = await db.execute( + select(Participant.id, Participant.display_name) + .where(Participant.id.in_(participant_ids)) + ) + for row in p_result.all(): + sender_cache[str(row[0])] = row[1] or "Unknown" out = [] for m in messages: diff --git a/backend/app/core/security.py b/backend/app/core/security.py index a22045a7b..967fb9ced 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -1,8 +1,10 @@ """Security utilities: JWT, password hashing, and authentication dependencies.""" +import asyncio import base64 import os import uuid +from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta, timezone import bcrypt @@ -23,6 +25,31 @@ # Bearer token scheme security = HTTPBearer() +# Thread pool for CPU-intensive bcrypt operations (avoids blocking the event loop) +_bcrypt_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="bcrypt") + + +def hash_password(password: str) -> str: + """Hash a password using bcrypt (sync, for use in background tasks).""" + return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verify a password against its hash (sync, for use in background tasks).""" + return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8")) + + +async def hash_password_async(password: str) -> str: + """Hash a password using bcrypt without blocking the event loop.""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor(_bcrypt_executor, hash_password, password) + + +async def verify_password_async(plain_password: str, hashed_password: str) -> bool: + """Verify a password against its hash without blocking the event loop.""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor(_bcrypt_executor, verify_password, plain_password, hashed_password) + def encrypt_data(plaintext: str, key: str) -> str: """Encrypt a string using AES-256-CBC with the given key. @@ -96,14 +123,6 @@ def decrypt_data(ciphertext: str, key: str) -> str: raise ValueError(f"Decryption failed: {e}") from e -def hash_password(password: str) -> str: - """Hash a password using bcrypt.""" - return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") - - -def verify_password(plain_password: str, hashed_password: str) -> bool: - """Verify a password against its hash.""" - return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8")) def create_access_token(user_id: str, role: str, expires_delta: timedelta | None = None) -> str: diff --git a/backend/app/services/registration_service.py b/backend/app/services/registration_service.py index 2db0795d4..2688e670b 100644 --- a/backend/app/services/registration_service.py +++ b/backend/app/services/registration_service.py @@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.config import get_settings -from app.core.security import hash_password +from app.core.security import hash_password_async from app.models.identity import IdentityProvider from app.models.tenant import Tenant from app.models.user import User, Identity @@ -201,7 +201,7 @@ async def find_or_create_identity( email=email, phone=normalized_phone, username=final_username, - password_hash=hash_password(password) if password else None, + password_hash=await hash_password_async(password) if password else None, is_platform_admin=is_platform_admin, email_verified=is_verified, ) diff --git a/frontend/src/pages/agent-detail/AgentDetailPage.tsx b/frontend/src/pages/agent-detail/AgentDetailPage.tsx index 049ebb72f..d2b9dfd24 100644 --- a/frontend/src/pages/agent-detail/AgentDetailPage.tsx +++ b/frontend/src/pages/agent-detail/AgentDetailPage.tsx @@ -2145,6 +2145,10 @@ export default function AgentDetailPage() { const [scopeDropdownOpen, setScopeDropdownOpen] = useState(false); const scopeDropdownRef = useRef(null); const [historyMsgs, setHistoryMsgs] = useState([]); + const [historyOffset, setHistoryOffset] = useState(0); + const [historyHasMore, setHistoryHasMore] = useState(true); + const [historyLoadingMore, setHistoryLoadingMore] = useState(false); + const HISTORY_PAGE_SIZE = 20; const [sessionsLoading, setSessionsLoading] = useState(false); const [allSessionsLoading, setAllSessionsLoading] = useState(false); const [agentExpired, setAgentExpired] = useState(false); @@ -2368,6 +2372,9 @@ export default function AgentDetailPage() { pendingHistoryInitialScrollRef.current = !writable; setChatMessages([]); setHistoryMsgs([]); + setHistoryOffset(0); + setHistoryHasMore(true); + setHistoryLoadingMore(false); setIsStreaming(runtimeState.isStreaming); setIsWaiting(runtimeState.isWaiting); setActiveSession(sess); @@ -2382,7 +2389,7 @@ export default function AgentDetailPage() { const loadSeq = ++sessionLoadSeqRef.current; try { const tkn = localStorage.getItem('token'); - const res = await fetch(`/api/agents/${targetAgentId}/sessions/${sess.id}/messages`, { + const res = await fetch(`/api/agents/${targetAgentId}/sessions/${sess.id}/messages?limit=${HISTORY_PAGE_SIZE}&offset=0`, { headers: { Authorization: `Bearer ${tkn}` }, signal: controller.signal, }); @@ -2398,6 +2405,7 @@ export default function AgentDetailPage() { ...(m.created_at && { timestamp: m.created_at }), ...(m.id && { id: m.id }), })); + setHistoryHasMore(msgs.length >= HISTORY_PAGE_SIZE); if (writable) { setChatMessages(preParsed); @@ -3343,11 +3351,60 @@ export default function AgentDetailPage() { window.setTimeout(scroll, 120); window.setTimeout(scroll, 360); }, []); + + const loadMoreHistoryMessages = useCallback(async () => { + if (historyLoadingMore || !historyHasMore || !activeSession || !id) return; + const sess = activeSession; + const targetAgentId = id; + setHistoryLoadingMore(true); + try { + const tkn = localStorage.getItem('token'); + const newOffset = historyOffset + HISTORY_PAGE_SIZE; + const res = await fetch(`/api/agents/${targetAgentId}/sessions/${sess.id}/messages?limit=${HISTORY_PAGE_SIZE}&offset=${newOffset}`, { + headers: { Authorization: `Bearer ${tkn}` }, + }); + if (!res.ok) return; + const msgs = await res.json(); + if (msgs.length === 0) { + setHistoryHasMore(false); + return; + } + const preParsed = msgs.map((m: any) => parseChatMsg({ + role: m.role, content: m.content || '', + ...(m.toolName && { toolName: m.toolName, toolArgs: m.toolArgs, toolStatus: m.toolStatus, toolResult: m.toolResult, toolThinking: m.toolThinking }), + ...(m.thinking && { thinking: m.thinking }), + ...(m.created_at && { timestamp: m.created_at }), + ...(m.id && { id: m.id }), + })); + // Save current scroll position + const el = historyContainerRef.current; + const oldScrollHeight = el?.scrollHeight ?? 0; + setHistoryMsgs(prev => [...preParsed, ...prev]); + setHistoryOffset(newOffset); + setHistoryHasMore(msgs.length >= HISTORY_PAGE_SIZE); + // Restore scroll position after new messages are prepended + requestAnimationFrame(() => { + if (el) { + const newScrollHeight = el.scrollHeight; + el.scrollTop = newScrollHeight - oldScrollHeight; + } + }); + } catch (err: any) { + console.error('Failed to load more history messages:', err); + } finally { + setHistoryLoadingMore(false); + } + }, [historyLoadingMore, historyHasMore, activeSession, id, historyOffset]); + const handleHistoryScroll = () => { const el = historyContainerRef.current; if (!el) return; const distFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight; setShowHistoryScrollBtn(distFromBottom > 200); + // Load more when scrolling near the top + if (el.scrollTop < 100 && historyHasMore && !historyLoadingMore) { + loadMoreHistoryMessages(); + } }; const scrollHistoryToBottom = () => { scheduleHistoryScrollToBottom(); @@ -5853,6 +5910,16 @@ export default function AgentDetailPage() { )}
+ {historyLoadingMore && ( +
+ Loading more messages... +
+ )} + {!historyHasMore && historyMsgs.length > 0 && ( +
+ All messages loaded +
+ )} {(() => { // For A2A sessions, determine which participant is "this agent" (left side) // Use agent.name matching against sender_name from messages From 4056d65753c03f810d3c2cf4274d3c27eb2ca510 Mon Sep 17 00:00:00 2001 From: yaojin3616 Date: Mon, 18 May 2026 21:16:51 +0800 Subject: [PATCH 12/12] import error --- backend/app/api/agents.py | 1 - backend/app/services/trigger_daemon.py | 323 ------------------------- 2 files changed, 324 deletions(-) diff --git a/backend/app/api/agents.py b/backend/app/api/agents.py index c07068d81..9cd204a6c 100644 --- a/backend/app/api/agents.py +++ b/backend/app/api/agents.py @@ -527,7 +527,6 @@ async def create_agent( # MCP import runs in background to avoid blocking the response if template_mcp_servers: import asyncio - from loguru import logger from app.services.resource_discovery import import_mcp_from_smithery async def _background_mcp_import(agent_id: uuid.UUID, server_ids: list[str]): diff --git a/backend/app/services/trigger_daemon.py b/backend/app/services/trigger_daemon.py index 79cfae966..de9531471 100644 --- a/backend/app/services/trigger_daemon.py +++ b/backend/app/services/trigger_daemon.py @@ -74,329 +74,6 @@ async def _invoke_agent_for_triggers(agent_id: uuid.UUID, triggers: list[AgentTr await invoke_agent_for_triggers_runtime(agent_id, triggers) -# ── Main Tick Loop ────────────────────────────────────────────────── - - # Build trigger context. Keep this model-facing prompt in English so - # autonomous wakeups behave consistently across UI locales. - context_parts = [] - trigger_names = [] - for t in triggers: - part = f"Trigger: {t.name} ({t.type})\nReason: {t.reason}" - if t.name == "daily_okr_collection": - part += ( - "\nExecution requirements: First call get_okr_settings to confirm whether daily report collection is enabled. " - "If it is enabled, only contact members and digital employees in your relationship network to collect today's final daily reports, " - "then organize them into a formal daily report no longer than 2000 characters. " - "If it is disabled, state that no action is needed and stop." - ) - elif t.name in ("daily_okr_report", "weekly_okr_report", "monthly_okr_report"): - part += ( - "\nExecution requirements: This company-level report is generated automatically by the system. " - "If you are awakened, only add necessary clarification. Do not start another member collection round." - ) - elif t.name == "biweekly_okr_checkin": - part += ( - "\nExecution requirements: First call get_okr_settings to confirm whether OKR is enabled. " - "If enabled, check the current-cycle company and member OKRs, then proactively remind members who have not set OKRs or whose progress is lagging. " - "If disabled, state that no action is needed and stop." - ) - elif t.name == "monthly_okr_report": - part += ( - "\nExecution requirements: First call get_okr_settings to confirm whether OKR is enabled. " - "If enabled, call generate_monthly_okr_report to generate the OKR monthly report for the month that just ended, " - "then send it to admins or publish it to Plaza. If disabled, state that no action is needed and stop." - ) - if t.focus_ref: - part += f"\nRelated Focus: {t.focus_ref}" - # Include matched message for on_message triggers - cfg = t.config or {} - if t.type == "on_message" and cfg.get("_matched_message"): - part += f"\nMatched message from {cfg.get('_matched_from', '?')}:\n\"{cfg['_matched_message'][:500]}\"" - if t.type == "on_message" and cfg.get("okr_member_id") and cfg.get("okr_report_date"): - part += ( - "\nExecution requirements: This is a daily-report reply ingestion event." - f"\n1. Organize the other party's reply into a final daily report no longer than 2000 characters." - f"\n2. Immediately call upsert_member_daily_report(report_date=\"{cfg['okr_report_date']}\", " - f"member_type=\"{cfg.get('okr_member_type', 'user')}\", " - f"member_id=\"{cfg['okr_member_id']}\", content=\"\")." - "\n3. After the tool call succeeds, send a brief confirmation that you received and recorded it." - "\n4. Do not only confirm without calling the tool, and do not store the raw long conversation verbatim as the daily report." - ) - # Include webhook payload - if t.type == "webhook" and cfg.get("_webhook_payload"): - payload_str = cfg["_webhook_payload"] - if len(payload_str) > 2000: - payload_str = payload_str[:2000] + "... (truncated)" - part += f"\nWebhook Payload:\n{payload_str}" - context_parts.append(part) - trigger_names.append(t.name) - - trigger_context = ( - "===== Wake Context =====\n" - f"Wake source: trigger ({'multiple triggers fired together' if len(triggers) > 1 else 'trigger fired'})\n\n" - + "\n---\n".join(context_parts) - + "\n========================" - ) - - # Create Reflection Session - title = f"🤖 Reflection: {', '.join(trigger_names)}" - # Find agent's participant - result = await db.execute( - select(Participant).where(Participant.type == "agent", Participant.ref_id == agent_id) - ) - agent_participant = result.scalar_one_or_none() - - session = ChatSession( - agent_id=agent_id, - user_id=agent.creator_id, - participant_id=agent_participant.id if agent_participant else None, - source_channel="trigger", - title=title[:200], - ) - db.add(session) - await db.flush() - session_id = session.id - - # Messages: trigger context only (call_llm builds system prompt internally) - messages = [ - {"role": "user", "content": trigger_context}, - ] - - # Store trigger context as a message in the session - db.add(ChatMessage( - agent_id=agent_id, - conversation_id=str(session_id), - role="user", - content=trigger_context, - user_id=agent.creator_id, - participant_id=agent_participant.id if agent_participant else None, - )) - await db.commit() - # Cache participant ID for callbacks - agent_participant_id = agent_participant.id if agent_participant else None - - # Call LLM (outside the DB session to avoid long transactions) - collected_content = [] - delivered_platform_message_via_tool = False - - async def on_chunk(text): - collected_content.append(text) - - # Persist tool calls into Reflection Session for Reflections visibility - async def on_tool_call(data): - nonlocal delivered_platform_message_via_tool - try: - tool_name = data.get("name") - tool_status = data.get("status") - if tool_status == "done" and tool_name == "send_platform_message": - result_text = str(data.get("result", "")) - if result_text.startswith("✅"): - delivered_platform_message_via_tool = True - - async with async_session() as _tc_db: - if data["status"] == "running": - _tc_db.add(ChatMessage( - agent_id=agent_id, - conversation_id=str(session_id), - role="tool_call", - content=_json.dumps({"name": data["name"], "args": data["args"]}, ensure_ascii=False, default=str), - user_id=agent.creator_id, - participant_id=agent_participant_id, - )) - elif data["status"] == "done": - result_str = str(data.get("result", ""))[:2000] - _tc_db.add(ChatMessage( - agent_id=agent_id, - conversation_id=str(session_id), - role="tool_call", - content=_json.dumps({"name": data["name"], "result": result_str}, ensure_ascii=False, default=str), - user_id=agent.creator_id, - participant_id=agent_participant_id, - )) - await _tc_db.commit() - except Exception as e: - logger.warning(f"Failed to persist tool call for trigger session: {e}") - - reply = await call_llm( - model=model, - messages=messages, - agent_name=agent.name, - role_description=agent.role_description or "", - agent_id=agent_id, - user_id=agent.creator_id, - session_id=str(session_id), - on_chunk=on_chunk, - on_tool_call=on_tool_call, - # A2A wake uses the agent's own max_tool_rounds setting (no override) - ) - - # Save assistant reply to Reflection session - async with async_session() as db: - result = await db.execute( - select(Participant).where(Participant.type == "agent", Participant.ref_id == agent_id) - ) - agent_participant = result.scalar_one_or_none() - - db.add(ChatMessage( - agent_id=agent_id, - conversation_id=str(session_id), - role="assistant", - content=reply or "".join(collected_content), - user_id=agent.creator_id, - participant_id=agent_participant.id if agent_participant else None, - )) - - # NOTE: trigger state (last_fired_at, fire_count, auto-disable) - # is already updated in _tick() BEFORE this task was launched, - # to prevent race-condition duplicate fires. - - await db.commit() - - # Compute final reply text once - final_reply = reply or "".join(collected_content) - - # ── Save reply to A2A session if this was an agent-to-agent wake ── - # This makes the target agent's reply visible in the A2A chat history - for t in triggers: - a2a_sid = (t.config or {}).get("_a2a_session_id") - if a2a_sid and final_reply: - try: - async with async_session() as db: - from app.models.participant import Participant as _P - _p_r = await db.execute(select(_P).where(_P.type == "agent", _P.ref_id == agent_id)) - _p = _p_r.scalar_one_or_none() - db.add(ChatMessage( - agent_id=agent_id, - conversation_id=a2a_sid, - role="assistant", - content=final_reply, - user_id=agent.creator_id, - participant_id=_p.id if _p else None, - )) - # Update session timestamp - from app.models.chat_session import ChatSession as _CS - _cs_r = await db.execute(select(_CS).where(_CS.id == uuid.UUID(a2a_sid))) - _cs = _cs_r.scalar_one_or_none() - if _cs: - _cs.last_message_at = datetime.now(timezone.utc) - await db.commit() - logger.info(f"[A2A] Saved reply to A2A session {a2a_sid}") - except Exception as e: - logger.warning(f"[A2A] Failed to save reply to A2A session {a2a_sid}: {e}") - break # Only save once - - # Route trigger results to a single deterministic destination. Pure reflection/system - # wakes stay inside the reflection session and should not spill into arbitrary user chats. - is_a2a_internal = all(t.name == "a2a_wake" for t in triggers) - delivery_target = None if is_a2a_internal else await _resolve_trigger_delivery_target(agent, triggers) - - if final_reply and delivery_target and not delivered_platform_message_via_tool: - try: - from app.api.websocket import manager as ws_manager - agent_id_str = str(agent_id) - - # Build notification message with trigger badge - trigger_reasons = [] - for t in triggers: - ns = (t.config or {}).get("_notification_summary", "").strip() - if ns: - trigger_reasons.append(ns) - else: - r = (t.reason or "").strip() - if r and len(r) <= 80: - trigger_reasons.append(r) - elif r: - trigger_reasons.append(r[:77] + "...") - summary = trigger_reasons[0] if trigger_reasons else "有新的事件需要处理" - - _is_a2a_wait = any(t.name.startswith("a2a_wait_") for t in triggers) - if _is_a2a_wait: - import re as _re - cleaned = final_reply - _internal_patterns = [ - r'\b(a2a_wait_\w+|a2a_wake)\b', - r'\bwait_?\w+_?(task|reply|followup|meeting|sync|api_key)\w*\b', - r'\bresolve_\w+\b', - r'\bfocus[_ ]?item\b', - r'\btask_delegate\b', - r'\bfocus_ref\b', - r'✅\s*(a2a\w+|wait\w+|触发器\w*|focus\w*).*(?:已取消|已为|保持|活跃|完成状态)[^\n]*', - r'[\-•]\s*(?:触发器|trigger|focus|wait_\w+|a2a\w+).*[^\n]*', - r'(?:触发器|trigger)\s+\S+\s*(?:已取消|保持活跃|已为完成状态|fired)', - r'已静默清理触发器', - r'已静默处理完毕', - r'继续待命[。,]?\s*', - r',?\s*(?:继续)?待命。', - ] - for _pat in _internal_patterns: - cleaned = _re.sub(_pat, '', cleaned, flags=_re.IGNORECASE) - cleaned = _re.sub(r'\n{3,}', '\n\n', cleaned).strip() - cleaned = _re.sub(r'[。,]\s*$', '', cleaned).strip() - if not cleaned: - cleaned = final_reply - else: - cleaned = final_reply - - notification = f"⚡ {summary}\n\n{cleaned}" - - target_session_id = delivery_target["session_id"] - owner_user_id = delivery_target.get("owner_user_id") - - # Save to the resolved destination session for persistence. - async with async_session() as db: - from app.models.chat_session import ChatSession - from app.api.websocket import maybe_mark_session_read_for_active_viewer - - db.add(ChatMessage( - agent_id=agent_id, - conversation_id=target_session_id, - role="assistant", - content=notification, - user_id=agent.creator_id, - )) - session_row = await db.get(ChatSession, uuid.UUID(target_session_id)) - if session_row: - session_row.last_message_at = datetime.now(timezone.utc) - if owner_user_id: - await maybe_mark_session_read_for_active_viewer( - db, - agent_id=agent_id, - session_id=target_session_id, - user_id=uuid.UUID(owner_user_id), - ) - await db.commit() - - payload = { - "type": "trigger_notification", - "content": notification, - "triggers": [t.name for t in triggers], - "session_id": target_session_id, - } - - # Notify only the user who owns the destination session. The frontend will append - # the message only when that exact session is open; otherwise it just refreshes - # unread/session state. - if owner_user_id: - await ws_manager.send_to_user(agent_id_str, owner_user_id, payload) - except Exception as e: - logger.error(f"Failed to push trigger result to WebSocket: {e}") - import traceback - traceback.print_exc() - - # Audit log - await write_audit_log("trigger_fired", { - "agent_name": agent.name, - "triggers": [{"name": t.name, "type": t.type} for t in triggers], - }, agent_id=agent_id) - - logger.info(f"⚡ Triggers fired for {agent.name}: {[t.name for t in triggers]}") - - except Exception as e: - logger.error(f"Failed to invoke agent {agent_id} for triggers: {e}") - import traceback - traceback.print_exc() - - # ── Main Tick Loop ────────────────────────────────────────────────── async def _tick():