diff --git a/.env.example b/.env.example index 7317369ea..46202b553 100644 --- a/.env.example +++ b/.env.example @@ -21,6 +21,30 @@ FEISHU_REDIRECT_URI=http://localhost:3000/auth/feishu/callback # Default: local host -> ~/.clawith/data/agents ; container runtime -> /data/agents # AGENT_DATA_DIR= +# File storage backend. Use "s3" for S3-compatible object storage. +# When STORAGE_BACKEND=s3, local fallback lets old files under STORAGE_LOCAL_ROOT +# be read and copied into S3 on first access during migration. +# STORAGE_BACKEND=local +# STORAGE_LOCAL_ROOT= +# STORAGE_LOCAL_FALLBACK_ENABLED=true +# S3_BUCKET= +# S3_REGION= +# S3_ENDPOINT_URL= +# S3_ACCESS_KEY_ID= +# S3_SECRET_ACCESS_KEY= +# S3_PREFIX=agents +# S3_MAX_POOL_CONNECTIONS=50 +# S3_WRITE_WORKERS=32 + +# Local MinIO settings used by docker-compose.multi-instance.yml. +# Change the password before exposing MinIO outside local development. +MINIO_ROOT_USER=clawith +MINIO_ROOT_PASSWORD=clawith-minio-secret +MINIO_BUCKET=clawith +MINIO_API_PORT=9000 +MINIO_CONSOLE_PORT=9001 +API_PORT=8000 + # Jina AI API key (for jina_search and jina_read tools — get one at https://jina.ai) # Without a key, the tools still work but with lower rate limits JINA_API_KEY= @@ -37,3 +61,13 @@ PUBLIC_BASE_URL= # Password reset token lifetime in minutes PASSWORD_RESET_TOKEN_EXPIRE_MINUTES=30 + +# Frontend port (default: 3008) +# FRONTEND_PORT=3008 + +# API upstream for nginx proxy (default: backend:8000) +# API_UPSTREAM=backend:8000 + +# Python pip index URL (for China mirrors) +# CLAWITH_PIP_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple +# CLAWITH_PIP_TRUSTED_HOST=pypi.tuna.tsinghua.edu.cn diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 000000000..d7534f2fa --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,410 @@ +name: Release + +# Uses GitHub Models API for AI release notes. +# Requires: Repository secret MODELS_TOKEN (PAT with models:read scope) + +on: + workflow_dispatch: + inputs: + release_type: + description: Release type to cut + required: true + default: auto + type: choice + options: + - auto + - patch + - minor + - major + prerelease: + description: Mark the GitHub Release as a prerelease + required: true + default: false + type: boolean + use_ai_notes: + description: Use GitHub Models to draft release notes when MODELS_TOKEN is configured + required: true + default: true + type: boolean + +permissions: + contents: write + +concurrency: + group: release-${{ github.ref_name }} + cancel-in-progress: false + +jobs: + release: + name: Cut release + runs-on: ubuntu-latest + + steps: + - name: Checkout source + uses: actions/checkout@v4 + with: + ref: ${{ github.ref_name }} + fetch-depth: 0 + persist-credentials: true + + - name: Configure git + run: | + git config user.name "github-actions[bot]" + git config user.email "41898282+github-actions[bot]@users.noreply.github.com" + + - name: Resolve base tag and target version + id: version + shell: bash + env: + REQUESTED_RELEASE_TYPE: ${{ inputs.release_type }} + run: | + set -euo pipefail + + git fetch --force --tags + + stable_tag="$(git tag --merged HEAD --list 'v*' --sort=-version:refname | grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' | head -n 1 || true)" + base_tag="$stable_tag" + if [ -z "$base_tag" ]; then + base_tag="$(git tag --merged HEAD --list 'v*' --sort=-version:refname | grep -E '^v[0-9]+\.[0-9]+\.[0-9]+([-.][0-9A-Za-z.]+)?$' | head -n 1 || true)" + fi + if [ -z "$base_tag" ]; then + base_tag="v0.0.0" + log_range="" + else + log_range="${base_tag}..HEAD" + fi + + if [ -n "$log_range" ] && [ -z "$(git log --oneline "$log_range")" ]; then + echo "No commits found since ${base_tag}; skipping duplicate release." + exit 1 + fi + + release_type="$REQUESTED_RELEASE_TYPE" + if [ "$release_type" = "auto" ]; then + if [ -n "$log_range" ]; then + subjects="$(git log --format=%s "$log_range")" + bodies="$(git log --format=%B "$log_range")" + else + subjects="$(git log --format=%s)" + bodies="$(git log --format=%B)" + fi + if printf '%s\n' "$bodies" | grep -Eq 'BREAKING CHANGE|^[^[:space:]]+(\([^)]+\))?!:'; then + release_type="major" + elif printf '%s\n' "$subjects" | grep -Eq '^feat(\([^)]+\))?:'; then + release_type="minor" + else + release_type="patch" + fi + fi + + next_version="$(python - "$base_tag" "$release_type" <<'PY' + import re + import sys + + base_tag, release_type = sys.argv[1], sys.argv[2] + match = re.match(r"^v?(\d+)\.(\d+)\.(\d+)", base_tag) + if not match: + major, minor, patch = 0, 0, 0 + else: + major, minor, patch = map(int, match.groups()) + + if release_type == "major": + major += 1 + minor = 0 + patch = 0 + elif release_type == "minor": + minor += 1 + patch = 0 + else: + patch += 1 + + print(f"{major}.{minor}.{patch}") + PY + )" + + tag_name="v${next_version}" + + if git rev-parse "$tag_name" >/dev/null 2>&1; then + echo "Tag ${tag_name} already exists." + exit 1 + fi + + { + echo "base_tag=$base_tag" + echo "release_type=$release_type" + echo "version=$next_version" + echo "tag=$tag_name" + echo "log_range=$log_range" + } >> "$GITHUB_OUTPUT" + + echo "Base tag: $base_tag" + echo "Release type: $release_type" + echo "Next version: $next_version" + + - name: Collect release context + shell: bash + env: + BASE_TAG: ${{ steps.version.outputs.base_tag }} + TARGET_VERSION: ${{ steps.version.outputs.version }} + TARGET_TAG: ${{ steps.version.outputs.tag }} + RELEASE_TYPE: ${{ steps.version.outputs.release_type }} + LOG_RANGE: ${{ steps.version.outputs.log_range }} + SOURCE_REF: ${{ github.ref_name }} + run: | + set -euo pipefail + + mkdir -p .github/release-artifacts + + if [ -n "$LOG_RANGE" ]; then + git log --no-merges --pretty=format:'- %s (%h)' "$LOG_RANGE" > .github/release-artifacts/commit-bullets.txt + git log --no-merges --pretty=format:'%H%x09%s' "$LOG_RANGE" > .github/release-artifacts/commit-table.tsv + else + git log --no-merges --pretty=format:'- %s (%h)' > .github/release-artifacts/commit-bullets.txt + git log --no-merges --pretty=format:'%H%x09%s' > .github/release-artifacts/commit-table.tsv + fi + + python <<'PY' + from pathlib import Path + import os + + base_tag = os.environ["BASE_TAG"] + target_version = os.environ["TARGET_VERSION"] + target_tag = os.environ["TARGET_TAG"] + release_type = os.environ["RELEASE_TYPE"] + + notes_path = Path("RELEASE_NOTES.md") + existing = notes_path.read_text(encoding="utf-8") if notes_path.exists() else "" + style_excerpt = "\n".join(existing.splitlines()[:120]).strip() + commit_bullets = Path(".github/release-artifacts/commit-bullets.txt").read_text(encoding="utf-8").strip() + + # Limit commit bullets to fit within API token limits + lines = commit_bullets.splitlines() + if len(lines) > 100: + lines = lines[:100] + [f"- ... and {len(lines) - 100} more commits"] + commit_summary = "\n".join(lines) + + prompt = f"""You are writing Clawith release notes in markdown. + + Rules: + - Start with a top-level heading exactly like: # {target_tag} — + - Keep the tone concise and product-focused. + - Use these sections when they make sense: ## What's New, ## Bug Fixes, ## Upgrade Guide, ## Notes + - Mention upgrade or migration risk only when the commits strongly imply it. + - Do not invent features that are not present in the commit list. + - Prefer grouping related changes instead of listing every commit verbatim. + + Context: + - Previous release tag: {base_tag} + - Target release tag: {target_tag} + - Release type: {release_type} + - Release branch: {os.environ["SOURCE_REF"]} + + Commits included in this release: + {commit_summary or "- No commit bullets collected"} + """ + + Path(".github/release-artifacts/release-prompt.txt").write_text(prompt, encoding="utf-8") + PY + + - name: Update version files + shell: bash + env: + TARGET_VERSION: ${{ steps.version.outputs.version }} + run: | + set -euo pipefail + + printf '%s\n' "$TARGET_VERSION" > backend/VERSION + printf '%s\n' "$TARGET_VERSION" > frontend/VERSION + + - name: Draft release notes with GitHub Models + if: ${{ inputs.use_ai_notes }} + shell: bash + env: + MODELS_TOKEN: ${{ secrets.MODELS_TOKEN }} + run: | + set -euo pipefail + + python <<'PY' + from pathlib import Path + import json + import os + + prompt = Path(".github/release-artifacts/release-prompt.txt").read_text(encoding="utf-8") + payload = { + "model": "gpt-4o-mini", + "messages": [ + {"role": "user", "content": prompt} + ], + } + Path(".github/release-artifacts/openai-payload.json").write_text( + json.dumps(payload, ensure_ascii=False), + encoding="utf-8", + ) + PY + + curl -fsSL https://models.inference.ai.azure.com/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $MODELS_TOKEN" \ + -d @.github/release-artifacts/openai-payload.json \ + > .github/release-artifacts/openai-response.json + + python <<'PY' + from pathlib import Path + import json + + response = json.loads(Path(".github/release-artifacts/openai-response.json").read_text(encoding="utf-8")) + text = response.get("choices", [{}])[0].get("message", {}).get("content", "").strip() + + if not text: + raise SystemExit("GitHub Models did not return release note text.") + + Path(".github/release-artifacts/release-notes.generated.md").write_text( + text.rstrip() + "\n", + encoding="utf-8", + ) + PY + + - name: Build fallback release notes + shell: bash + env: + BASE_TAG: ${{ steps.version.outputs.base_tag }} + TARGET_VERSION: ${{ steps.version.outputs.version }} + TARGET_TAG: ${{ steps.version.outputs.tag }} + SOURCE_REF: ${{ github.ref_name }} + run: | + set -euo pipefail + + if [ -s .github/release-artifacts/release-notes.generated.md ]; then + exit 0 + fi + + python <<'PY' + from pathlib import Path + import os + + base_tag = os.environ["BASE_TAG"] + target_version = os.environ["TARGET_VERSION"] + target_tag = os.environ["TARGET_TAG"] + source_ref = os.environ["SOURCE_REF"] + + entries = [] + for line in Path(".github/release-artifacts/commit-table.tsv").read_text(encoding="utf-8").splitlines(): + if not line.strip(): + continue + _, subject = line.split("\t", 1) + entries.append(subject.strip()) + + features = [] + fixes = [] + others = [] + for subject in entries: + lowered = subject.lower() + if lowered.startswith("feat"): + features.append(subject) + elif lowered.startswith("fix"): + fixes.append(subject) + else: + others.append(subject) + + def bullets(items): + return "\n".join(f"- {item}" for item in items[:8]) or "- No user-facing highlights captured from commit subjects." + + sections = [ + f"# {target_tag} — Release Highlights", + "", + "## What's New", + bullets(features or others), + ] + + if fixes: + sections.extend([ + "", + "## Bug Fixes", + bullets(fixes), + ]) + + sections.extend([ + "", + "## Upgrade Guide", + "", + "### Docker Deployment", + "```bash", + f"git pull origin {source_ref}", + "docker compose down && docker compose up -d --build", + "```", + "", + "### Source Deployment", + "```bash", + f"git pull origin {source_ref}", + "cd frontend && npm install && npm run build", + "cd ..", + "```", + "", + "## Notes", + f"- Release generated from changes since `{base_tag}`.", + f"- Runtime version files were updated to `{target_version}`.", + ]) + + Path(".github/release-artifacts/release-notes.generated.md").write_text( + "\n".join(sections).rstrip() + "\n", + encoding="utf-8", + ) + PY + + - name: Refresh RELEASE_NOTES.md + shell: bash + run: | + set -euo pipefail + + python <<'PY' + from pathlib import Path + + notes_file = Path(".github/release-artifacts/release-notes.generated.md") + release_notes_path = Path("RELEASE_NOTES.md") + + new_block = notes_file.read_text(encoding="utf-8").strip() + existing = release_notes_path.read_text(encoding="utf-8").strip() if release_notes_path.exists() else "" + + if existing: + combined = f"{new_block}\n\n---\n\n{existing}\n" + else: + combined = f"{new_block}\n" + + release_notes_path.write_text(combined, encoding="utf-8") + PY + + - name: Commit release metadata + shell: bash + env: + TARGET_TAG: ${{ steps.version.outputs.tag }} + run: | + set -euo pipefail + + git add backend/VERSION frontend/VERSION RELEASE_NOTES.md + + if git diff --cached --quiet; then + echo "No release metadata changes to commit." + exit 0 + fi + + git commit -m "chore(release): cut ${TARGET_TAG}" + git push origin HEAD:${{ github.ref_name }} + + - name: Create and push tag + shell: bash + env: + TARGET_TAG: ${{ steps.version.outputs.tag }} + run: | + set -euo pipefail + + git tag -a "$TARGET_TAG" -m "Release $TARGET_TAG" + git push origin "$TARGET_TAG" + + - name: Publish GitHub Release + uses: softprops/action-gh-release@v2 + with: + tag_name: ${{ steps.version.outputs.tag }} + name: ${{ steps.version.outputs.tag }} + body_path: .github/release-artifacts/release-notes.generated.md + prerelease: ${{ inputs.prerelease }} + generate_release_notes: false diff --git a/.gitignore b/.gitignore index 10628af40..408cb34f2 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ node_modules/ dist/ *.zip +uv.lock .vite/ *.log pnpm-lock.yaml diff --git a/ARCHITECTURE_SPEC_EN.md b/ARCHITECTURE_SPEC_EN.md deleted file mode 100644 index 5a3aa7cbe..000000000 --- a/ARCHITECTURE_SPEC_EN.md +++ /dev/null @@ -1,561 +0,0 @@ -# Clawith Architecture Specification - -This document describes the current high-level architecture of Clawith based on the latest codebase. It is intended to help developers quickly identify the system's primary runtime paths, storage model, extension points, and frontend/backend boundaries. - ---- - -## Module 1: System Overview - -Clawith is a multi-tenant agent collaboration platform. The product is not just a chat UI: it combines native WebSocket-driven agents, autonomous trigger-based wakeups, external OpenClaw nodes, multi-channel IM ingress, workspace file operations, MCP-based tool import, enterprise directory sync, and a growing OKR subsystem. - -### 1.1 Tech Stack - -- **Backend**: Python 3.11+, FastAPI, SQLAlchemy 2 async ORM, PostgreSQL, httpx, Loguru. -- **Frontend**: React 19, Vite 6, TypeScript, React Router 7, Zustand, TanStack React Query, i18next, Recharts. -- **Realtime**: WebSocket chat streaming for native agents; additional long-lived background managers for Feishu, DingTalk, WeCom, and Discord. -- **Extension Surface**: Built-in tools, MCP tools, skill packages, AgentBay environments, public published pages, and OpenClaw gateway nodes. - -### 1.2 Application Startup and Assembly - -The backend entry point is `backend/app/main.py`. - -On startup, the app currently does the following: - -1. Configures logging and middleware. -2. Ensures database tables exist by importing all models and calling `Base.metadata.create_all()`. -3. Seeds default tenant data, builtin tools, templates, skills, default agents, and the OKR Agent. -4. Starts core background tasks: - - `trigger_daemon` - - `feishu_ws_manager` - - `dingtalk_stream_manager` - - `wecom_stream_manager` - - `discord_gateway_manager` -5. Registers a broad route surface covering auth, agents, enterprise admin, tools, skills, notifications, pages, gateway, Aware triggers, chat sessions, AgentBay control, and OKR. - -This means `main.py` is both a router composition root and an operational bootstrapper. - -For OKR-specific startup patching, the bootstrap path now also self-heals missing builtin OKR tool rows before patching existing OKR Agents. This prevents prompt/tool-list mismatches where an OKR Agent mentions `upsert_member_daily_report` in context but does not actually receive the tool in its callable LLM tool set. - -The Docker backend entrypoint (`backend/entrypoint.sh`) performs an additional bootstrap sequence before Uvicorn starts: - -1. Imports the model graph and runs `Base.metadata.create_all()`. -2. Applies a small list of legacy additive schema patches (`ALTER TABLE ... ADD COLUMN IF NOT EXISTS`, plus one partial index). -3. Runs `alembic upgrade head`, but logs and continues if migrations fail. -4. Starts `uvicorn`. - -Because development and upgrade environments may still have another backend serving traffic against the same database, the additive patch phase now executes each statement in its own short-lived transaction with a `lock_timeout`. This prevents startup from hanging indefinitely while waiting for `ACCESS EXCLUSIVE` table locks on hot tables such as `users`. - -### 1.3 Directory Map - -#### Backend (`backend/app/`) - -- `api/`: FastAPI route layer. - - `websocket.py`: native agent runtime entry for streaming chat and tool-calling. - - `gateway.py`: OpenClaw edge-node poll/report/send channel. - - `triggers.py` / `webhooks.py`: Aware trigger configuration and public event ingress. - - `enterprise.py` / `admin.py`: tenant admin, SSO, model pool, org sync, platform settings. - - `tools.py` / `skills.py`: tool registry and skill registry management. - - `pages.py`: authenticated page publishing APIs plus public `/p/{short_id}` serving. - - `agentbay_control.py`: human Take Control session APIs for AgentBay browser/computer environments. -- `models/`: SQLAlchemy ORM definitions. -- `services/`: runtime logic, prompt assembly, agent tooling, trigger daemon, MCP resource discovery, org sync, quota guard, OKR services, AgentBay clients, and workspace collaboration helpers. - -#### Frontend (`frontend/src/`) - -- `App.tsx`: route composition and auth bootstrap. -- `pages/AgentDetail.tsx`: primary agent work surface; chat, settings, sessions, tools, triggers, files, and realtime rendering all meet here. -- `pages/Dashboard.tsx`, `pages/Plaza.tsx`, `pages/Messages.tsx`, `pages/EnterpriseSettings.tsx`, `pages/OKR.tsx`: major product views. -- `services/api.ts`: HTTP client layer. -- `stores/`: Zustand auth and UI state. -- `index.css`: global theme, shared layout primitives, and key animations. - ---- - -## Module 2: Core Data Model - -The database model is intentionally broad because Clawith spans SaaS tenancy, agents, collaboration, extensibility, publishing, and enterprise admin. - -### 2.1 Tenant, Identity, and Organization - -Primary models: - -- `Tenant`: company boundary, activation state, SSO-related flags, tenant-level defaults. -- `User` and `Identity`: human account and identity record pairing. -- `IdentityProvider` and `SSOScanSession`: tenant-bound or global authentication/SSO providers and temporary QR/scan login sessions. -- `OrgDepartment` and `OrgMember`: synced enterprise directory/cache for people and department lookup. -- `TenantSetting` and `SystemSetting`: tenant-level or platform-level configuration storage. -- `InvitationCode`: invite-based user onboarding and admin bootstrap. - -This layer supports web auth, SSO login, enterprise directory sync, tenant-specific configuration, and invitation-driven company setup. - -Important invariant: - -- Any tenant-scoped human `User` who becomes a member of a company through registration, company self-creation, or invitation-based joining should also have a corresponding `OrgMember` record in that tenant. Channel-synced members may supply that record from an external provider; otherwise the platform creates a local provider-less `OrgMember` as the canonical relationship/search entry for agent relationship management and OKR tracking. - -### 2.2 Agent Runtime Entities - -Primary models: - -- `Agent`: the main digital employee entity. - - Important fields include `agent_type`, `primary_model_id`, `fallback_model_id`, `status`, heartbeat settings, autonomy policy, tenant ownership, and system-agent flags. -- `Participant`: universal sender/receiver identity used to normalize humans and agents in messaging. -- `ChatSession`: conversation container for web chat, channel conversations, trigger reflection sessions, A2A sessions, and group sessions. - - Platform sessions now distinguish a long-lived primary thread (`is_primary=true`) from temporary side-topic threads. - - Platform-user unread state is tracked per session via `last_read_at_by_user`. -- `ChatMessage` (stored in `audit.py`): the durable event log for user messages, assistant replies, tool calls, and runtime outputs. -- `AgentCredential`: encrypted per-agent session-cookie storage used by integrations such as AgentBay Take Control cookie export and browser-state reinjection, without persisting third-party usernames or passwords. - -The messaging layer is deliberately more general than ordinary user/assistant chat, because the same persistence path supports web UI, IM channels, A2A, and trigger-driven reflection sessions. - -### 2.3 Extensibility, Workspace, and Publishing - -Primary models: - -- `Tool` and `AgentTool`: global/tenant tool registry plus per-agent assignment and config overrides. -- `Skill` and `SkillFile`: skill package registry and multi-file skill content. -- `WorkspaceFileRevision` and `WorkspaceEditLock`: file revision history and short-lived human editing locks for agent workspaces. -- `PublishedPage`: public HTML publishing metadata for workspace files served via short IDs. -- `Notification`: notification inbox records for users and agents. - -This layer is what turns Clawith from a single agent chat surface into a configurable workspace platform with reusable capabilities and publication workflows. - -### 2.4 Autonomy and Async Delivery - -Primary models: - -- `AgentTrigger`: Aware trigger definitions for cron, once, interval, poll, on-message, and webhook wake conditions. -- `GatewayMessage`: delivery queue for OpenClaw nodes that run outside the main backend process. - -These models are the foundation for asynchronous execution and agent wake-up behavior without direct human initiation. - ---- - -## Module 3: Native Agent Runtime - -The native runtime is centered on `backend/app/api/websocket.py`. - -### 3.1 WebSocket Session Bootstrap - -When the frontend opens an agent chat: - -1. The browser connects to `/ws/chat/{agent_id}`. -2. The backend validates the user, agent access, and usable model selection. -3. It loads or creates the relevant `ChatSession`. -4. It reconstructs recent history, including prior `tool_call` records, into the model-facing message format. -5. It starts a realtime streaming loop back to the client. - -This path is used for ordinary web chat, but the same underlying `call_llm()` machinery is also reused by triggers and some background execution paths. - -The chat composer can pass a per-message model override through the WebSocket payload. This override is tenant-scoped, enabled-model-only, and intentionally separate from the saved Agent default model so users with ordinary `use` access can switch their own chat model without needing permission to edit Agent settings. - -For first-party platform chat, the bootstrap now prefers the user's primary session for that agent. This keeps agent-initiated reminders and ongoing context in one durable thread, while user-created ad-hoc sessions remain temporary. - -### 3.2 Prompt Assembly and Runtime Context - -Prompt context is built primarily by `backend/app/services/agent_context.py`. - -The context builder pulls together: - -- `soul.md` -- long-term memory (`memory/memory.md` or legacy fallback) -- a skill index derived from the workspace `skills/` directory -- relationship notes -- runtime system instructions -- special-case injections such as OKR Agent rules or channel-specific capability guidance - -The important architectural point is that an agent's behavior is not defined only by database fields. It is also materially shaped by files in its persistent workspace. - -### 3.3 Tool-Calling Loop - -The core `call_llm()` flow is a bounded iterative loop: - -1. Select a primary model, with runtime fallback to the configured fallback model when needed. -2. Stream assistant output. -3. Detect requested tool calls. -4. Execute tools through the agent tool layer. -5. Append tool results back into the conversation context. -6. Continue until there is no further tool call or limits are reached. - -Key protections already present in the runtime: - -- tool-round limits -- warning injection before limit exhaustion -- hard validation for malformed high-risk tool arguments -- quota checks -- token accounting and estimation fallback when providers do not return usage -- optional vision/media handling via helper services such as `vision_inject.py` - -### 3.4 Session Variants Supported by the Same Runtime - -The same native engine supports more than one conversation shape: - -- direct user-agent web chat -- channel-backed chat sessions -- A2A sessions -- trigger-created reflection sessions -- session resume/history browsing via `chat_sessions.py` - -Two first-party session rules are now important: - -- agent-initiated platform messages reuse the primary session instead of opening a fresh thread each time -- unread badges are derived from assistant/system/tool messages created after `ChatSession.last_read_at_by_user` -- when the owning platform user is actively viewing that exact session, newly delivered assistant/tool/trigger messages immediately advance `last_read_at_by_user` so the active thread does not show itself as unread -- trigger results are routed to one explicit destination session: a user's primary platform session for user-originated context, the matching A2A session for agent-to-agent context, or only the trigger reflection session for pure system/reflection work -- if a trigger already sends the user-facing platform message via `send_platform_message`, the daemon suppresses the extra trigger recap in that primary session and leaves the full execution trace only in the reflection session - -This is why session and participant handling are more complex than a typical one-user/one-bot design. - ---- - -## Module 4: Aware Engine - -The Aware engine is implemented primarily through: - -- `backend/app/models/trigger.py` -- `backend/app/api/triggers.py` -- `backend/app/services/trigger_daemon.py` -- `backend/app/services/heartbeat.py` - -### 4.1 Trigger Types and Evaluation - -Current trigger types include: - -- `cron` -- `once` -- `interval` -- `poll` -- `on_message` -- `webhook` - -`trigger_daemon.py` runs a periodic tick, evaluates enabled triggers, applies cooldown and expiry rules, and groups fired triggers by `agent_id`. - -### 4.2 Invocation Flow - -When triggers fire: - -1. Trigger state is updated before invocation to avoid duplicate fires during long-running LLM tasks. -2. A structured wake context is assembled from trigger name, reason, matched message, focus reference, and webhook payload when relevant. -3. A reflection-style `ChatSession` is created with `source_channel="trigger"`. -4. The native `call_llm()` loop is invoked. -5. Trigger results may be persisted and also pushed back into active user WebSocket sessions as trigger notifications. - -This means Aware is not a separate execution engine. It is a structured wake-up layer on top of the native agent runtime. - -### 4.3 Heartbeat and A2A Wake Integration - -The trigger daemon also coordinates with heartbeat behavior and A2A wake paths: - -- periodic heartbeat checks run on a slower cadence inside the same operational loop -- A2A notifications can be converted into synthetic wake contexts -- dedup windows and chain-depth guards help prevent wake storms - -The current implementation is therefore closer to a unified autonomy framework than a simple scheduler. - ---- - -## Module 5: OpenClaw Gateway and External Channel Ingress - -Clawith has two major non-web ingress families: OpenClaw nodes and IM/workflow channels. - -### 5.1 OpenClaw Gateway - -`backend/app/api/gateway.py` provides the external node protocol for `agent_type="openclaw"` agents. - -The main path is: - -1. External node authenticates with `X-Api-Key`. -2. Node polls for pending `GatewayMessage` work. -3. Node runs its local prompt/tool/model flow. -4. Node reports the result back. -5. Backend writes the result into chat persistence and can notify active WebSocket viewers. - -This allows Clawith to treat remote machines as first-class execution agents while still using the central session/history model. - -### 5.2 Channel Ingress Normalization - -The backend includes channel adapters for: - -- Feishu -- Slack -- Discord -- DingTalk -- WeCom -- Teams - -The integration depth varies, but the architectural pattern is consistent: - -1. Receive an external event. -2. Map sender/channel identity into tenant-aware internal records. -3. Resolve or create the relevant `ChatSession`. -4. Convert the external message into normalized internal context. -5. Reuse the same core LLM execution path. -6. Convert the response back into channel-native delivery format. - -Feishu is currently the deepest integration, including image ingestion, contact mapping, card-style streaming updates, and tenant-stable identity handling. - ---- - -## Module 6: Tool, Skill, and Workspace Ecosystem - -This is one of the most important parts of the system because it defines what agents can actually do. - -### 6.1 Tool Registry and MCP Import - -Tools are stored in the database and assigned per agent. - -There are two main tool classes: - -- builtin tools -- MCP-backed tools - -Key files: - -- `backend/app/api/tools.py` -- `backend/app/services/agent_tools.py` -- `backend/app/services/resource_discovery.py` -- `backend/app/services/mcp_client.py` - -Important behaviors: - -- builtin and tenant-scoped tools can be managed from the backend API -- sensitive tool config values are encrypted/decrypted through the API layer -- MCP servers can be discovered from Smithery and ModelScope -- imported MCP servers can expand into multiple concrete tools -- agent-level tool assignments can override default/global configuration -- Agent tool visibility is tenant-bound: builtin tools are global, admin-imported tools are visible only to agents in the same tenant, and agent-installed tools are visible only through explicit per-agent assignments. The same boundary is enforced in the Agent Tools APIs, assignment updates, and LLM tool loading. - -### 6.2 Skill Registry and Skill Packages - -Skills are separate from tools. - -Tools provide callable actions. Skills provide procedural instructions and optional multi-file assets such as: - -- `SKILL.md` -- helper scripts -- references -- examples - -Key files: - -- `backend/app/api/skills.py` -- `backend/app/services/skill_seeder.py` -- `backend/app/services/agent_context.py` - -The runtime only loads a summarized index into the prompt by default, then expects the agent to read the full skill file when it becomes relevant. - -### 6.3 Workspace Files, Collaboration, and Publishing - -Agent workspaces live on disk under the configured agent data directory, but the database tracks collaboration state. - -Key files: - -- `backend/app/services/workspace_collaboration.py` -- `backend/app/models/workspace.py` -- `backend/app/api/pages.py` - -Current capabilities include: - -- path normalization and traversal-safe file resolution -- revision history for meaningful writes -- short-lived human edit locks to prevent agent/user collisions -- prompt-level workspace organization guidance so agents inspect existing folders before writing, prefer relevant subfolders over `workspace/` root, and create a new topical folder when needed -- public HTML publishing through `PublishedPage` -- sandboxed public rendering with CSP on `/p/{short_id}` - -### 6.4 AgentBay and Take Control - -Clawith also supports shared control of remote browser/computer environments through AgentBay. - -Key files: - -- `backend/app/services/agentbay_client.py` -- `backend/app/api/agentbay_control.py` - -The architectural idea is: - -- agents can operate browser/computer sessions through tools -- humans can temporarily take over those sessions -- Take Control places a lock so automatic agent actions pause during manual intervention -- cookies and browser state can be exported back into agent-managed credentials - -This is a meaningful collaboration layer, not just a thin remote desktop helper. - ---- - -## Module 7: Enterprise and Platform Control Plane - -Beyond agent execution, Clawith contains a substantial admin/control plane. - -### 7.1 Enterprise Management - -`backend/app/api/enterprise.py` is one of the largest and most operationally important route modules. - -It currently handles several responsibilities: - -- tenant-scoped LLM model pool management -- model test calls and provider registry access -- enterprise info and audit/approval-related endpoints -- identity provider CRUD -- SSO-related settings -- org department/member listing -- org sync trigger endpoints -- invitation-code related enterprise administration - -The corresponding services include `sso_service.py`, `enterprise_sync.py`, `org_sync_service.py`, and provider-specific auth/sync adapters. - -Company identity now also includes an optional tenant logo managed from the Company Info tab. Logos are uploaded through the tenant API, validated as square images no larger than 1 MB, stored under the configured agent data directory, and exposed as public UI assets through `/api/tenants/{tenant_id}/logo`. The frontend uses the logo in the sidebar workspace switcher and company selection menu while keeping the existing tenant default model setting intact. -### 7.2 Platform Administration - -`backend/app/api/admin.py` handles platform-wide control for platform admins, including: - -- company listing and creation -- company activation toggles -- platform metrics -- platform-level settings such as self-serve company creation and invitation policies - -This layer is conceptually separate from tenant admin. It operates across all tenants. - -### 7.3 Notifications and Activity - -Operational visibility also includes: - -- `notification.py`: user notification inbox and tenant broadcast flow -- `activity.py` and audit log services: historical activity and usage tracking -- quota guard services: message quota, agent creation quota, agent LLM quota, and heartbeat floor enforcement - -This means the control plane is not only configuration management. It also includes enforcement and observability. - ---- - -## Module 8: Frontend Architecture - -The frontend is not a thin shell. It coordinates routing, auth recovery, realtime chat rendering, enterprise admin surfaces, and workspace-level UX. - -### 8.1 Route Topology - -`frontend/src/App.tsx` defines the current high-level product routes: - -- `/login`, `/forgot-password`, `/reset-password`, `/verify-email` -- `/sso/entry` -- `/setup-company` -- `/dashboard` -- `/plaza` -- `/agents/new` -- `/agents/:id` -- `/messages` -- `/enterprise` -- `/okr` -- `/invitations` -- `/admin/platform-settings` - -The app also consumes token handoff in URL parameters for cross-domain tenant switching, while explicitly avoiding collisions with password-reset and email-verification token flows. - -### 8.2 AgentDetail as the Main Work Surface - -`frontend/src/pages/AgentDetail.tsx` is the most important frontend page. - -It is responsible for a broad mix of concerns: - -- WebSocket chat streaming -- live tool-call rendering -- session switching -- A2A message display -- trigger/Aware configuration UI -- workspace-related controls -- various agent settings and admin panels - -Architecturally, this file functions as the main operating console for a single agent. - -### 8.3 State, Theme, and Realtime Rendering - -Key frontend patterns: - -- Zustand stores hold auth and lightweight global state. -- React Query is available for data-fetching coordination. -- `index.css` centralizes theme primitives, shared animations, and layout tokens. -- The realtime chat UI relies on incremental rendering strategies to avoid repainting the entire message list for every stream chunk. - -There are also global UX behaviors such as: - -- notification bar rendering from public backend settings -- route guards for auth, tenant setup, and email verification -- auto-reconnect/resend behavior in chat flows - ---- - -## Module 9: OKR System - -The OKR subsystem has its own dedicated API surface and service layer and is now a first-class product area rather than a small extension. - -Key files: - -- `backend/app/api/okr.py` -- `backend/app/models/okr.py` -- `backend/app/services/okr_scheduler.py` -- `backend/app/services/okr_daily_collection.py` -- `backend/app/services/okr_reporting.py` -- `backend/app/services/okr_agent_hook.py` - -Current architectural characteristics: - -- tenant-level OKR cadence is persisted through OKR settings -- the OKR Agent is seeded and patched at startup -- daily collection and reporting are coordinated through dedicated backend services -- tracked relationships determine who participates in collection/reporting flows -- the OKR relationship sync flow only auto-links company-visible agents; user-scoped private agents are intentionally excluded even if they belong to the same tenant -- human and agent replies are normalized through the OKR Agent's runtime context and tools -- frontend OKR views include period-aware browsing, company reports, and member-level daily report inspection - -The OKR subsystem therefore combines scheduled workflow, agent instruction shaping, persistence, and reporting UI. - ---- - -Clawith should be understood as a coordinated system of tenant-scoped agents, persistent workspaces, trigger-driven autonomy, channel adapters, and enterprise control surfaces. When adding new features, the main architectural questions are usually: - -- Which tenant boundary does this belong to? -- Does it enter through the native runtime, Aware triggers, a channel adapter, or the OpenClaw gateway? -- Does it belong in workspace files, database models, or both? -- Is it a tool, a skill, a trigger, a published artifact, or a control-plane setting? - -Answering those four questions correctly is usually enough to place new code in the right part of the system. - ---- - -## Changelog - -| Date | Summary | -| --- | --- | -| 2026-05-08 | Fixed chat-side model switching for non-manager collaborators. The model picker now treats ordinary user selections as per-chat WebSocket overrides while preserving saved Agent default updates for users with manage access. | -| 2026-05-08 | Hardened MCP recovery behavior. Smithery auto-recovery now preserves the existing stored connection when a newly-created connection still requires OAuth authorization, preventing usable connections from being overwritten by unauthenticated replacements. MCP transport fallback errors now preserve both Streamable HTTP and SSE failure details instead of masking the original failure with a local exception-variable error. | -| 2026-05-03 | Hardened tool visibility across tenant boundaries. Agent tool lists, tool assignment updates, and LLM runtime tool loading now expose builtin tools globally, admin tools only within the agent's tenant, and agent-installed tools only through explicit agent assignments. | -| 2026-04-28 | Added the workspace switcher and company logo identity flow. Users can switch companies from the sidebar, create or join companies from a modal, and org/platform admins can upload a square company logo that is stored outside source-controlled files and served through the tenant API. | -| 2026-04-27 | Tightened the OKR relationship sync flow so the tenant-wide "Sync Relationship Network" action excludes user-scoped private agents. Only company-visible digital employees are auto-linked into the OKR Agent's collaborator graph, matching the existing incremental OKR hook behavior for newly created agents. | -| 2026-04-27 | Closed the Plaza interaction path for private agents. User-scoped private agents can no longer browse, post, or comment in Plaza, private-agent-authored Plaza content is hidden from feed/detail/stats, and private-agent heartbeat instructions explicitly forbid Plaza access to reduce the risk of confidential information leaking into shared social surfaces. | -| 2026-04-27 | Aligned relationship-management permissions across the Agent Detail page and relationships APIs so org admins and platform admins can manage agent relationships even when an agent's stored access level is `use`. This fixes production cases where the seeded OKR Agent remained read-only for non-creator org admins despite being company-visible. | -| 2026-04-25 | Improved workspace document conversion and navigation ergonomics: uploaded PDF/DOCX/XLSX/PPTX extraction now emits more structured Markdown with real tables and slide/page sections, Markdown-to-PDF rendering preserves Markdown tables and CJK-friendly styling, and the chat-side file tree now defaults to a focused `workspace/` scope with an explicit `All` switch for root-level agent files. | -| 2026-04-25 | Replaced the Markdown-to-PDF tool's dependency on the external `markdown` package with an internal lightweight renderer so PDF export no longer fails on missing runtime modules, defaulted the chat-side file tree to a collapsed initial state, and paused expensive HTML/PDF iframe rendering while the right sidebar is actively being dragged to reduce preview stutter. | -| 2026-04-25 | Smoothed chat-side HTML preview resizing by suspending expensive iframe auto-fit recalculation while the workspace sidebar is actively being dragged, then recomputing once after drag end. This prevents the live preview pane from stuttering when users shrink or widen the right-hand file tree/history column while an HTML file is open. | -| 2026-04-25 | Expanded the chat-side file tree from the `workspace/` subtree to the full agent root so `soul.md`, `focus.md`, `memory/`, and `skills/` content can be previewed from the same sidebar, normalized uploads/new folders to writable roots such as `workspace/uploads`, added inline plain-text preview support, and switched uploaded office-document extraction companions from `.txt` to `.md` so extracted content is easier for users and agents to refine. | -| 2026-04-25 | Reworked saved HTML workspace previews so the chat-side preview pane now renders them at the real panel width with automatic height measurement instead of scaling the entire iframe. This keeps responsive HTML previews aligned with sidebar width changes and restores much more reliable in-preview interaction for script-driven buttons, tabs, forms, and modal flows. | -| 2026-04-25 | Refined the chat-side workspace preview stack so saved HTML files now render through real inline file URLs for more faithful script-driven interactions, revision history cards compute human-readable diffs from stored before/after content instead of always showing a placeholder, file-tree mode tracks a selected target directory for nested uploads/new folders/deletes, upload progress is rendered inline inside the relevant directory, and CSV/XLSX preview and CSV-to-XLSX conversion now preserve detected delimiter-based table structure with trimmed empty trailing columns. | -| 2026-04-25 | Improved workspace preview fidelity for richer file types: HTML preview iframes now preserve interactive scripts/forms/modals while debouncing draft updates to reduce visual flicker, CSV preview now renders a styled header row, XLSX preview returns structured sheet rows for table rendering, and Markdown-to-DOCX conversion now uses an internal parser so it no longer depends on BeautifulSoup being installed at runtime. | -| 2026-04-25 | Expanded the chat-side workspace preview sidebar so both file-tree and version-history modes support manual width resizing, version history now surfaces revision timestamps in a scrollable list, and file-tree mode exposes direct upload plus new-folder actions rooted in the currently viewed workspace directory. | -| 2026-04-25 | Refined the chat-side workspace preview interaction so switching files during editing now surfaces an explicit save/discard/stay decision instead of silently ignoring the click, and preview pinning now uses a live lock reference so agent-driven workspace, browser, desktop, and code updates cannot steal focus while the user has locked the current file. | -| 2026-04-25 | Expanded the chat-side workspace preview browser so the file tree now includes common image assets and the preview pane can render uploaded images inline, keeping the side-panel workspace view aligned with the main workspace browser. | -| 2026-04-25 | Added explicit workspace preview pinning on the chat side panel and tightened auto-focus behavior so agent-driven workspace drafts, file mutations, browser screenshots, desktop screenshots, and code output no longer steal the right-hand preview while the user is editing a file or has manually locked the currently viewed workspace file. | -| 2026-04-25 | Added streaming workspace draft propagation for tool-call arguments in the WebSocket runtime. While file-writing and document-conversion tools are still streaming their argument JSON, the backend now forwards incremental `workspace_draft` payloads through the LLM call chain so the frontend can preview pending workspace changes before the tool finishes executing. | -| 2026-04-24 | Hardened the Docker backend entrypoint so additive startup schema patches no longer block container startup indefinitely when another backend instance is already serving traffic on the same database. Each patch now runs in its own transaction with a short PostgreSQL `lock_timeout`, allowing locked legacy patches to be skipped safely while the backend continues booting. | -| 2026-04-24 | Updated the OKR tool output so `get_okr` resolves member and agent owner names in tool responses instead of exposing raw owner UUIDs wherever a readable owner label is available, keeping chat-based OKR review aligned with the dashboard naming model. | -| 2026-04-24 | Simplified grouped member OKR presentation in the dashboard so the owner name is shown once at the group header level, while nested objective cards focus on objective titles and KR content without repeating owner badges inside each card. | -| 2026-04-23 | Hardened OKR authorization across the dashboard and agent-tool path. The web OKR dashboard is now admin-only for mutating actions, while chat-driven OKR mutations are enforced in the OKR tools using the actual requesting user's role rather than prompt-only guidance: non-admin requests may only create or modify the requester's own personal OKRs, and company-level or other-member OKRs require an org/platform admin. Permission failures now return explicit `Permission denied` messages instead of ambiguous owner/not-found wording. | -| 2026-04-23 | Tightened OKR editing guidance so `get_my_okr` now returns both `objective_id` and `kr_id`, OKR tool descriptions explicitly prefer `update_objective` and `update_kr_content` for revisions, and the seeded OKR Agent persona/tool assignment now distinguishes revision flows from new OKR creation. Also regrouped member OKRs in the dashboard so multiple Objectives for the same member render under a single owner container instead of appearing as separate top-level blocks. | -| 2026-04-23 | Expanded stored member daily report content from a 200-character summary cap to a 2000-character normalized body, preserved line breaks during normalization, and updated the OKR reports detail view to render full wrapped report text instead of looking artificially truncated. | -| 2026-04-23 | Added a first-party `update_kr_content` OKR tool for regular agents so they can modify their own Key Result definition fields such as title, target value, unit, focus reference, and status, complementing the existing progress-only update path. | -| 2026-04-23 | Hardened relationship management so both human and agent relationships reject duplicate additions in the UI and on the backend replacement APIs, preventing repeated entries from being persisted when an already-linked member or digital employee is selected again. | -| 2026-04-23 | Improved the Enterprise OKR settings control surface so auto-saved changes now expose explicit saving/saved/error feedback, the daily collection card shows the company timezone that drives cron execution, and admins can trigger a one-off daily collection test from the settings page. Also fixed OKR daily collection to resolve fallback user sessions for external-channel members without raising a missing `ChatSession` import error. | -| 2026-04-21 | Clarified human messaging tool selection so platform-labeled relationships should use `send_platform_message`, channel-labeled relationships should use `send_channel_message`, and the runtime now transparently reroutes mistaken channel sends for platform-only users back onto the platform messaging path. | -| 2026-04-20 | Strengthened workspace-writing guidance so agents should inspect existing folder structure before creating documents, prefer relevant subfolders instead of dumping files into `workspace/` root, and create a new topical folder when no suitable location exists. | -| 2026-04-20 | Tightened trigger result routing so trigger replies no longer fan out to every active web session; user-originated results now land in their primary session, A2A results stay in their A2A session, pure reflection work stays in trigger/reflection sessions, and user-facing `send_platform_message` deliveries no longer get duplicated by an extra trigger recap in the same chat. | -| 2026-04-20 | Renamed the first-party proactive messaging tool from `send_web_message` to `send_platform_message`, covering both web and app surfaces, and added startup seeder logic to rename legacy tool rows in place so existing agent assignments keep working. | -| 2026-04-20 | Made OKR Agent startup patching self-heal missing builtin OKR tool rows before assigning tools, preventing `Unknown tool: upsert_member_daily_report` failures on older databases. | -| 2026-04-20 | Added primary first-party chat sessions, per-session unread tracking, and agent sidebar unread counts so proactive agent messages reuse one durable platform thread. | diff --git a/COMMIT b/COMMIT deleted file mode 100644 index 26b82316b..000000000 --- a/COMMIT +++ /dev/null @@ -1 +0,0 @@ -5576d9e diff --git a/HTTPS_GUIDE.md b/HTTPS_GUIDE.md deleted file mode 100644 index 30f94cb0c..000000000 --- a/HTTPS_GUIDE.md +++ /dev/null @@ -1,162 +0,0 @@ -# HTTPS Deployment Guide - -This guide covers how to enable HTTPS for a self-hosted Clawith deployment. We recommend using a reverse proxy with automatic certificate management rather than modifying Clawith's Docker setup directly. - -## Option A: Caddy (Recommended — Simplest) - -[Caddy](https://caddyserver.com/) provides automatic HTTPS with zero configuration. - -### 1. Install Caddy - -```bash -# Ubuntu/Debian -sudo apt install -y caddy - -# Or via Docker -docker pull caddy:2 -``` - -### 2. Create a Caddyfile - -``` -your-domain.com { - # Frontend - reverse_proxy localhost:3008 - - # Backend API - handle /api/* { - reverse_proxy localhost:8000 - } - - # WebSocket - handle /ws/* { - reverse_proxy localhost:8000 - } -} -``` - -### 3. Start Caddy - -```bash -sudo caddy start -``` - -Caddy will automatically obtain and renew Let's Encrypt certificates. No additional configuration needed. - ---- - -## Option B: Traefik (Best for Docker-native setups) - -[Traefik](https://traefik.io/) integrates directly with Docker and handles certificates automatically. - -### 1. Add Traefik to your `docker-compose.override.yml` - -```yaml -services: - traefik: - image: traefik:v3.0 - command: - - "--providers.docker=true" - - "--entrypoints.web.address=:80" - - "--entrypoints.websecure.address=:443" - - "--certificatesresolvers.letsencrypt.acme.tlschallenge=true" - - "--certificatesresolvers.letsencrypt.acme.email=your-email@example.com" - - "--certificatesresolvers.letsencrypt.acme.storage=/acme/acme.json" - - "--entrypoints.web.http.redirections.entryPoint.to=websecure" - ports: - - "80:80" - - "443:443" - volumes: - - /var/run/docker.sock:/var/run/docker.sock:ro - - traefik-certs:/acme - - frontend: - labels: - - "traefik.enable=true" - - "traefik.http.routers.clawith.rule=Host(`your-domain.com`)" - - "traefik.http.routers.clawith.entrypoints=websecure" - - "traefik.http.routers.clawith.tls.certresolver=letsencrypt" - - "traefik.http.services.clawith.loadbalancer.server.port=80" - -volumes: - traefik-certs: -``` - -### 2. Start - -```bash -docker compose -f docker-compose.yml -f docker-compose.override.yml up -d -``` - ---- - -## Option C: Nginx + Certbot (Traditional) - -If you prefer a traditional Nginx setup with Let's Encrypt. - -### 1. Install Nginx and Certbot - -```bash -sudo apt install -y nginx certbot python3-certbot-nginx -``` - -### 2. Create Nginx config - -```nginx -# /etc/nginx/sites-available/clawith -server { - server_name your-domain.com; - - location / { - proxy_pass http://127.0.0.1:3008; - proxy_set_header Host $host; - proxy_set_header X-Real-IP $remote_addr; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Forwarded-Proto $scheme; - } - - location /api/ { - proxy_pass http://127.0.0.1:8000; - proxy_set_header Host $host; - proxy_set_header X-Real-IP $remote_addr; - } - - location /ws/ { - proxy_pass http://127.0.0.1:8000; - proxy_http_version 1.1; - proxy_set_header Upgrade $http_upgrade; - proxy_set_header Connection "upgrade"; - proxy_read_timeout 86400; - } -} -``` - -### 3. Enable site and obtain certificate - -```bash -sudo ln -s /etc/nginx/sites-available/clawith /etc/nginx/sites-enabled/ -sudo nginx -t && sudo systemctl reload nginx -sudo certbot --nginx -d your-domain.com -``` - -Certbot will automatically modify your Nginx config to add SSL and set up auto-renewal. - ---- - -## Environment Variables - -When running behind HTTPS, set these in your `.env`: - -```bash -# Tell the backend it's behind a reverse proxy -FORWARDED_ALLOW_IPS=* - -# If your Feishu/Slack/Discord webhooks need the public URL -PUBLIC_URL=https://your-domain.com -``` - -## Notes - -- **Do NOT** expose ports 8000 (backend) or 3008 (frontend) directly to the internet when using a reverse proxy. Bind them to `127.0.0.1` only. -- All three options above handle automatic certificate renewal. No manual intervention needed. -- For Cloudflare users: simply point your DNS to the server and enable the Cloudflare proxy — SSL is handled automatically at the edge. diff --git a/backend/app/api/agents.py b/backend/app/api/agents.py index 66a05b36c..9cd204a6c 100644 --- a/backend/app/api/agents.py +++ b/backend/app/api/agents.py @@ -3,14 +3,17 @@ import hashlib import json import secrets +import time import uuid from datetime import datetime, timezone from pathlib import Path from fastapi import APIRouter, Depends, HTTPException, status +from loguru import logger from sqlalchemy import cast, func, select, String from sqlalchemy.ext.asyncio import AsyncSession +from app.config import get_settings from app.core.permissions import build_visible_agents_query, check_agent_access, is_agent_creator from app.core.security import get_current_user from app.database import get_db @@ -19,9 +22,11 @@ from app.models.chat_session import ChatSession from app.models.user import User from app.schemas.schemas import AgentCreate, AgentOut, AgentUpdate +from app.services.storage import get_storage_backend from app.services.access_relationships import ensure_access_granted_platform_relationships router = APIRouter(prefix="/agents", tags=["agents"]) +settings = get_settings() async def _get_active_admin_users(db: AsyncSession, tenant_id: uuid.UUID | None) -> list[User]: @@ -404,10 +409,11 @@ async def create_agent( # Always include global default skills (mcp-installer, skill-creator, # complex-task-executor) - default_result = await db.execute( - select(Skill).where(Skill.is_default) - ) + t_skills_copy_start = time.perf_counter() + t_default_query_start = time.perf_counter() + default_result = await db.execute(select(Skill).where(Skill.is_default)) default_ids = {s.id for s in default_result.scalars().all()} + t_default_query = time.perf_counter() - t_default_query_start # Include the template's declared default skills (e.g. trading templates # ship with `market-data` / `financial-calendar` in their meta.yaml). @@ -415,7 +421,9 @@ async def create_agent( # so the agent has no idea those MCP-backed skills exist and silently # falls back to web search. template_skill_ids: set = set() + t_template_query = 0.0 if data.template_id: + t_template_query_start = time.perf_counter() tpl_r = await db.execute( select(AgentTemplate).where(AgentTemplate.id == data.template_id) ) @@ -426,30 +434,45 @@ async def create_agent( select(Skill).where(Skill.folder_name.in_(folder_names)) ) template_skill_ids = {s.id for s in tpl_skills_r.scalars().all()} + t_template_query = time.perf_counter() - t_template_query_start # Merge user-selected + global default + template-default skill IDs all_skill_ids = set(data.skill_ids or []) | default_ids | template_skill_ids if all_skill_ids: - agent_dir = agent_manager._agent_dir(agent.id) - skills_dir = agent_dir / "skills" - skills_dir.mkdir(parents=True, exist_ok=True) + import asyncio + storage = get_storage_backend() + agent_prefix = agent_manager._agent_storage_prefix(agent.id) - for sid in all_skill_ids: - result = await db.execute( - select(Skill).where(Skill.id == sid).options(selectinload(Skill.files)) + t_skill_fetch_start = time.perf_counter() + skills_result = await db.execute( + select(Skill).where(Skill.id.in_(all_skill_ids)).options(selectinload(Skill.files)) + ) + skills = skills_result.scalars().all() + t_skill_fetch = time.perf_counter() - t_skill_fetch_start + + file_specs = [ + (f"{agent_prefix}/skills/{skill.folder_name}/{sf.path}", sf.content) + for skill in skills + for sf in skill.files + ] + + if file_specs: + t_upload_start = time.perf_counter() + await asyncio.gather(*[ + storage.write_text(key, content, encoding="utf-8") + for key, content in file_specs + ]) + logger.info( + f"[_skills_copy] agent={agent.id} skills={len(skills)} files={len(file_specs)} " + f"fetch={t_skill_fetch:.2f}s upload={time.perf_counter() - t_upload_start:.2f}s " + f"total={time.perf_counter() - t_skills_copy_start:.2f}s" + ) + else: + logger.info( + f"[_skills_copy] agent={agent.id} no files " + f"fetch={t_skill_fetch:.2f}s total={time.perf_counter() - t_skills_copy_start:.2f}s" ) - skill = result.scalar_one_or_none() - if not skill: - continue - # Create folder: skills// - skill_folder = skills_dir / skill.folder_name - skill_folder.mkdir(parents=True, exist_ok=True) - # Write each file - for sf in skill.files: - file_path = skill_folder / sf.path - file_path.parent.mkdir(parents=True, exist_ok=True) - file_path.write_text(sf.content, encoding="utf-8") # Auto-install template-declared MCP servers using the system Smithery key. # For trading agents, this means shibui/finance lands in the agent's tool @@ -466,7 +489,6 @@ async def create_agent( await db.commit() await db.refresh(agent) - from loguru import logger from app.services.resource_discovery import import_mcp_from_smithery for server_id in template_mcp_servers: try: @@ -491,14 +513,47 @@ async def create_agent( f"on agent {agent.id} raised: {e}" ) - # Start container + # Start container first (non-blocking if Docker available) await agent_manager.start_container(db, agent) await db.flush() + # Commit agent and basic setup before async operations from app.services.okr_agent_hook import hook_new_agent if agent.tenant_id: await hook_new_agent(db, agent.id, agent.tenant_id) - await db.commit() + await db.commit() + await db.refresh(agent) + + # MCP import runs in background to avoid blocking the response + if template_mcp_servers: + import asyncio + from app.services.resource_discovery import import_mcp_from_smithery + + async def _background_mcp_import(agent_id: uuid.UUID, server_ids: list[str]): + for server_id in server_ids: + try: + result_msg = await import_mcp_from_smithery( + server_id=server_id, + agent_id=agent_id, + config={}, + ) + if result_msg.startswith("❌"): + logger.warning( + f"[create_agent] MCP pre-install for '{server_id}' " + f"on agent {agent_id} reported error: {result_msg[:200]}" + ) + else: + logger.info( + f"[create_agent] MCP pre-install '{server_id}' " + f"succeeded for agent {agent_id}" + ) + except Exception as e: + logger.warning( + f"[create_agent] MCP pre-install for '{server_id}' " + f"on agent {agent_id} raised: {e}" + ) + + asyncio.create_task(_background_mcp_import(agent.id, template_mcp_servers)) return await _agent_to_out(db, agent, current_user.id) diff --git a/backend/app/api/auth.py b/backend/app/api/auth.py index 8cff5f0b3..a61571f49 100644 --- a/backend/app/api/auth.py +++ b/backend/app/api/auth.py @@ -11,7 +11,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.core.security import create_access_token, get_authenticated_user, get_current_user, hash_password, verify_password +from app.core.security import create_access_token, get_authenticated_user, get_current_user, hash_password_async, verify_password_async from app.database import get_db from app.models.user import Identity, User from app.schemas.schemas import ( @@ -194,7 +194,7 @@ async def register_init( ) # If identity existed, verify password - if identity.password_hash and not verify_password(data.password, identity.password_hash): + if identity.password_hash and not await verify_password_async(data.password, identity.password_hash): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Email already registered. Incorrect password." @@ -445,7 +445,7 @@ async def login(data: UserLogin, background_tasks: BackgroundTasks, db: AsyncSes result = await db.execute(query) identity = result.scalar_one_or_none() - if not identity or not identity.password_hash or not verify_password(data.password, identity.password_hash): + if not identity or not identity.password_hash or not await verify_password_async(data.password, identity.password_hash): logger.warning(f"[LOGIN] Invalid credentials for {data.login_identifier} identity_id={identity.id if identity else 'None'}") raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials") @@ -669,9 +669,9 @@ async def reset_password(data: ResetPasswordRequest, db: AsyncSession = Depends( if not identity or not identity.is_active: raise HTTPException(status_code=400, detail="Invalid or expired reset token") - new_hash = hash_password(data.new_password) + new_hash = await hash_password_async(data.new_password) identity.password_hash = new_hash - + await db.flush() await db.commit() return {"ok": True} @@ -863,10 +863,10 @@ async def change_password( user = res.scalar_one() identity = user.identity - if not identity or not identity.password_hash or not verify_password(old_password, identity.password_hash): + if not identity or not identity.password_hash or not await verify_password_async(old_password, identity.password_hash): raise HTTPException(status_code=400, detail="Current password is incorrect") - new_hash = hash_password(new_password) + new_hash = await hash_password_async(new_password) identity.password_hash = new_hash await db.flush() diff --git a/backend/app/api/chat_sessions.py b/backend/app/api/chat_sessions.py index 356d14700..f6210dd82 100644 --- a/backend/app/api/chat_sessions.py +++ b/backend/app/api/chat_sessions.py @@ -361,6 +361,8 @@ async def delete_session( async def get_session_messages( agent_id: uuid.UUID, session_id: uuid.UUID, + limit: int = Query(20, ge=1, le=500, description="Number of messages to return"), + offset: int = Query(0, ge=0, description="Offset for pagination"), current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): @@ -382,35 +384,34 @@ async def get_session_messages( raise HTTPException(status_code=403, detail="Not authorized to view this session") # Query messages by conversation_id only (agent-to-agent uses session_agent_id) - # Query the latest 500 messages (subquery in DESC, then reverse for display order) + # Optimized: use a single query with ORDER BY and LIMIT instead of subquery from sqlalchemy import desc - latest_subq = ( - select(ChatMessage.id) - .where(ChatMessage.conversation_id == str(session_id)) - .order_by(desc(ChatMessage.created_at)) - .limit(500) - .subquery() - ) msgs_result = await db.execute( select(ChatMessage) - .where(ChatMessage.id.in_(select(latest_subq.c.id))) - .order_by(ChatMessage.created_at.asc()) + .where(ChatMessage.conversation_id == str(session_id)) + .order_by(desc(ChatMessage.created_at)) + .limit(limit) + .offset(offset) ) - messages = msgs_result.scalars().all() + messages = list(reversed(msgs_result.scalars().all())) # Reading your own first-party/channel session should clear its unread state. if str(session.user_id) == str(current_user.id) and not session.is_group and session.source_channel not in ("agent", "trigger"): session.last_read_at_by_user = datetime.now(tz.utc) await db.commit() - # Resolve sender names for agent sessions + # Batch fetch all participant names to avoid N+1 queries sender_cache: dict = {} if session.source_channel == "agent": from app.models.participant import Participant - for m in messages: - if m.participant_id and str(m.participant_id) not in sender_cache: - p_r = await db.execute(select(Participant.display_name).where(Participant.id == m.participant_id)) - sender_cache[str(m.participant_id)] = p_r.scalar_one_or_none() or "Unknown" + participant_ids = list({m.participant_id for m in messages if m.participant_id}) + if participant_ids: + p_result = await db.execute( + select(Participant.id, Participant.display_name) + .where(Participant.id.in_(participant_ids)) + ) + for row in p_result.all(): + sender_cache[str(row[0])] = row[1] or "Unknown" out = [] for m in messages: diff --git a/backend/app/api/feishu.py b/backend/app/api/feishu.py index 00fe32737..c279c6d63 100644 --- a/backend/app/api/feishu.py +++ b/backend/app/api/feishu.py @@ -4,6 +4,7 @@ import time import uuid from collections.abc import Awaitable, Callable +from datetime import datetime from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from loguru import logger @@ -18,6 +19,7 @@ from app.models.identity import IdentityProvider from app.schemas.schemas import ChannelConfigCreate, ChannelConfigOut, TokenResponse, UserOut from app.services.feishu_service import feishu_service +from app.services.storage import agent_upload_key, get_storage_backend, store_agent_upload router = APIRouter(tags=["feishu"]) @@ -33,6 +35,21 @@ "抱歉,我暂时无法稳定识别你的飞书账号,已停止本次处理以避免重复创建账号。" "请稍后重试,或联系管理员检查飞书 Contact API 权限。" ) + + +def _storage_mtime(entry) -> float: + raw = str(getattr(entry, "modified_at", "") or "") + if not raw: + return 0.0 + try: + return float(raw) + except ValueError: + try: + return datetime.fromisoformat(raw.replace("Z", "+00:00")).timestamp() + except ValueError: + return 0.0 + + def _build_card( answer_text: str, thinking_text: str = "", @@ -566,20 +583,18 @@ async def process_feishu_event(agent_id: uuid.UUID, body: dict, db: AsyncSession if _post_image_keys: import base64 as _b64 _msg_id = message.get("message_id", "") - from pathlib import Path as _PostPath - from app.config import get_settings as _post_gs - _post_settings = _post_gs() - _upload_dir = _PostPath(_post_settings.AGENT_DATA_DIR) / str(agent_id) / "workspace" / "uploads" - _upload_dir.mkdir(parents=True, exist_ok=True) for _ik in _post_image_keys: try: _img_bytes = await feishu_service.download_message_resource( config.app_id, config.app_secret, _msg_id, _ik, "image" ) - # Save to workspace - _save_path = _upload_dir / f"image_{_ik[-8:]}.jpg" - _save_path.write_bytes(_img_bytes) - logger.info(f"[Feishu] Saved post image to {_save_path} ({len(_img_bytes)} bytes)") + _, _workspace_path, _save_path = await store_agent_upload( + agent_id, + f"image_{_ik[-8:]}.jpg", + _img_bytes, + content_type="image/jpeg", + ) + logger.info(f"[Feishu] Saved post image to {_workspace_path} ({len(_img_bytes)} bytes)") # Embed as base64 marker for vision models _b64_data = _b64.b64encode(_img_bytes).decode("ascii") _image_markers.append(f"[image_data:data:image/jpeg;base64,{_b64_data}]") @@ -817,21 +832,22 @@ async def process_feishu_event(agent_id: uuid.UUID, body: dict, db: AsyncSession # to disk always succeeds even if the DB transaction fails. try: import time as _time - import pathlib as _pl - from app.config import get_settings as _gs - _upload_dir = _pl.Path(_gs().AGENT_DATA_DIR) / str(agent_id) / "workspace" / "uploads" + _storage = get_storage_backend() + _upload_key = agent_upload_key(agent_id, "placeholder").rsplit("/", 1)[0] _recent_file_path = None - if _upload_dir.exists() and "uploads/" not in user_text and "workspace/" not in user_text: + if "uploads/" not in user_text and "workspace/" not in user_text: _now = _time.time() - _candidates = sorted( - _upload_dir.iterdir(), - key=lambda p: p.stat().st_mtime, - reverse=True, - ) - for _fp in _candidates: - if _fp.is_file() and (_now - _fp.stat().st_mtime) < 1800: # 30 min - _recent_file_path = f"uploads/{_fp.name}" - break + if await _storage.exists(_upload_key) and await _storage.is_dir(_upload_key): + _candidates = sorted( + [e for e in await _storage.list_dir(_upload_key) if not e.is_dir], + key=_storage_mtime, + reverse=True, + ) + for _entry in _candidates: + _mtime = _storage_mtime(_entry) + if _mtime and (_now - _mtime) < 1800: + _recent_file_path = f"uploads/{_entry.name}" + break if _recent_file_path: # _recent_file_path is relative to uploads dir; agent workspace root is # AGENT_DATA_DIR/{agent_id}/, so the correct relative path is workspace/uploads/ @@ -869,11 +885,16 @@ async def _feishu_file_sender(file_path, msg: str = ""): _fs = _gs_fallback() _base_url = getattr(_fs, 'BASE_URL', '').rstrip('/') or '' _fp = _P(file_path) - _ws_root = _P(_fs.AGENT_DATA_DIR) + _parts = list(_fp.parts) try: - _rel = str(_fp.relative_to(_ws_root / str(agent_id))) + _workspace_idx = _parts.index("workspace") + _rel = "/".join(_parts[_workspace_idx:]) except ValueError: - _rel = _fp.name + _ws_root = _P(getattr(_fs, "STORAGE_LOCAL_ROOT", "") or _fs.AGENT_DATA_DIR) + try: + _rel = str(_fp.relative_to(_ws_root / str(agent_id))) + except ValueError: + _rel = _fp.name _fallback_parts = [] if msg: _fallback_parts.append(msg) @@ -1195,8 +1216,6 @@ async def _handle_feishu_file( ): """Handle incoming file or image messages from Feishu (runs as a background task).""" import asyncio, random, json - from pathlib import Path - from app.config import get_settings from app.models.audit import ChatMessage from app.models.agent import Agent as AgentModel from app.models.user import User as UserModel @@ -1226,18 +1245,18 @@ async def _handle_feishu_file( return # Resolve workspace upload dir - settings = get_settings() - upload_dir = Path(settings.AGENT_DATA_DIR) / str(agent_id) / "workspace" / "uploads" - upload_dir.mkdir(parents=True, exist_ok=True) - save_path = upload_dir / filename - # Download the file try: file_bytes = await feishu_service.download_message_resource( config.app_id, config.app_secret, message_id, file_key, res_type ) - save_path.write_bytes(file_bytes) - logger.info(f"[Feishu] Saved {msg_type} to {save_path} ({len(file_bytes)} bytes)") + _, workspace_path, save_path = await store_agent_upload( + agent_id, + filename, + file_bytes, + content_type="image/jpeg" if msg_type == "image" else None, + ) + logger.info(f"[Feishu] Saved {msg_type} to {workspace_path} ({len(file_bytes)} bytes)") except Exception as e: logger.error(f"[Feishu] Failed to download {msg_type}: {e}") err_tip = "抱歉,文件下载失败。可能原因:机器人缺少 `im:resource` 权限(文件读取)。\n请在飞书开放平台 → 权限管理 → 批量导入权限 JSON → 重新发布机器人版本后重试。" @@ -1555,20 +1574,18 @@ async def _img_heartbeat(): async def _download_post_images(agent_id, config, message_id, image_keys): """Download images embedded in a Feishu post message to the agent's workspace.""" - from pathlib import Path - from app.config import get_settings - settings = get_settings() - upload_dir = Path(settings.AGENT_DATA_DIR) / str(agent_id) / "workspace" / "uploads" - upload_dir.mkdir(parents=True, exist_ok=True) - for ik in image_keys: try: file_bytes = await feishu_service.download_message_resource( config.app_id, config.app_secret, message_id, ik, "image" ) - save_path = upload_dir / f"image_{ik[-8:]}.jpg" - save_path.write_bytes(file_bytes) - logger.info(f"[Feishu] Saved post image to {save_path} ({len(file_bytes)} bytes)") + _, workspace_path, _ = await store_agent_upload( + agent_id, + f"image_{ik[-8:]}.jpg", + file_bytes, + content_type="image/jpeg", + ) + logger.info(f"[Feishu] Saved post image to {workspace_path} ({len(file_bytes)} bytes)") except Exception as e: logger.error(f"[Feishu] Failed to download post image {ik}: {e}") diff --git a/backend/app/api/files.py b/backend/app/api/files.py index 2dc8b7831..64dcb42b8 100644 --- a/backend/app/api/files.py +++ b/backend/app/api/files.py @@ -24,12 +24,20 @@ from app.services.workspace_collaboration import ( acquire_edit_lock, content_hash, + delete_workspace_file, list_revisions, read_text_if_exists, record_revision, release_edit_lock, write_workspace_file, ) +from app.services.storage import ( + ensure_local_path, + get_storage_backend, + guess_content_type, + normalize_storage_key, +) +from app.services.storage_runtime.base import StorageEntry from app.services.workspace_paths import WorkspacePathError, resolve_agent_visible_path from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -44,18 +52,21 @@ class FileInfo(BaseModel): is_dir: bool size: int = 0 modified_at: str = "" + version_token: str | None = None url: str | None = None class FileContent(BaseModel): path: str content: str + version_token: str | None = None class FileWrite(BaseModel): content: str autosave: bool = False session_id: str | None = None + expected_version_token: str | None = None class FileLockBody(BaseModel): @@ -65,6 +76,7 @@ class FileLockBody(BaseModel): class RestoreRevisionBody(BaseModel): revision_id: uuid.UUID + expected_version_token: str | None = None TEXT_PREVIEW_EXTENSIONS = { @@ -128,7 +140,14 @@ class RestoreRevisionBody(BaseModel): def _agent_base_dir(agent_id: uuid.UUID) -> Path: - return Path(settings.AGENT_DATA_DIR) / str(agent_id) + local_root = settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR + return Path(local_root) / str(agent_id) + + +def _agent_storage_key(agent_id: uuid.UUID, rel_path: str = "") -> str: + prefix = str(agent_id) + rel = normalize_storage_key(rel_path) + return f"{prefix}/{rel}" if rel else prefix def _safe_path(agent_id: uuid.UUID, rel_path: str) -> Path: @@ -154,6 +173,21 @@ def _visible_path(agent_id: uuid.UUID, rel_path: str, tenant_id: uuid.UUID | Non return resolved.path, resolved.relative_root, resolved.is_enterprise +def _is_enterprise_visible_path(rel_path: str) -> bool: + normalized = (rel_path or "").strip().strip("/") + return normalized == "enterprise_info" or normalized.startswith("enterprise_info/") + + +def _visible_storage_key(agent_id: uuid.UUID, rel_path: str, tenant_id: uuid.UUID | None) -> tuple[str, bool]: + normalized = (rel_path or "").strip().strip("/") + if _is_enterprise_visible_path(normalized): + if not tenant_id: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No tenant associated") + sub_path = normalized[len("enterprise_info"):].lstrip("/") + return _enterprise_storage_key(str(tenant_id), sub_path), True + return _agent_storage_key(agent_id, normalized), False + + async def _require_agent_file_delete_access( db: AsyncSession, current_user: User, @@ -178,46 +212,52 @@ async def list_files( ): """List files and directories in an agent's file system.""" await check_agent_access(db, current_user, agent_id) - target, base_abs, is_enterprise = _visible_path(agent_id, path, current_user.tenant_id) - if is_enterprise: - target.mkdir(parents=True, exist_ok=True) - - if not target.exists(): - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Path not found") - if not target.is_dir(): + storage = get_storage_backend() + storage_key, is_enterprise = _visible_storage_key(agent_id, path, current_user.tenant_id) + normalized_path = (path or "").strip().strip("/") + path_exists = await storage.exists(storage_key) + path_is_dir = await storage.is_dir(storage_key) + if not path_exists and not path_is_dir: + if not ( + normalized_path == "" + or (is_enterprise and normalized_path == "enterprise_info") + ): + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Path not found") + elif path_exists and not path_is_dir: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Path is not a directory") items = [] - base_abs = base_abs.resolve() if not path and current_user.tenant_id: - enterprise_root = (Path(settings.AGENT_DATA_DIR) / f"enterprise_info_{current_user.tenant_id}").resolve() - enterprise_root.mkdir(parents=True, exist_ok=True) items.append(FileInfo( name="enterprise_info", path="enterprise_info", is_dir=True, size=0, - modified_at=str(enterprise_root.stat().st_mtime), + modified_at="", + version_token=None, url=None, )) - for entry in sorted(target.iterdir(), key=lambda e: (not e.is_dir(), e.name)): + entries = await storage.list_dir(storage_key) if path_exists or path_is_dir else [] + for entry in entries: if entry.name == '.gitkeep': continue if not path and entry.name.lower() in {"focus.md", "agenda.md"}: continue if not path and entry.name == "enterprise_info": continue - rel = str(entry.resolve().relative_to(base_abs)) if is_enterprise: - rel = f"enterprise_info/{rel}" if rel != "." else "enterprise_info" - stat = entry.stat() + rel = str(Path(entry.key).relative_to(f"enterprise_info_{current_user.tenant_id}")) + rel_path = f"enterprise_info/{rel}" if rel != "." else "enterprise_info" + else: + rel_path = str(Path(entry.key).relative_to(str(agent_id))) items.append(FileInfo( name=entry.name, - path=rel, - is_dir=entry.is_dir(), - size=stat.st_size if entry.is_file() else 0, - modified_at=str(stat.st_mtime), - url=f"/api/agents/{agent_id}/files/download?path={rel}" if not entry.is_dir() else None + path=rel_path, + is_dir=entry.is_dir, + size=entry.size, + modified_at=entry.modified_at, + version_token=_entry_version_token(entry), + url=f"/api/agents/{agent_id}/files/download?path={rel_path}" if not entry.is_dir else None )) return items @@ -236,17 +276,33 @@ async def read_file( status_code=status.HTTP_410_GONE, detail="Focus is stored in the system database. Use the Focus API.", ) - target, _, _ = _visible_path(agent_id, path, current_user.tenant_id) - - if not target.exists() or not target.is_file(): + storage = get_storage_backend() + key, _ = _visible_storage_key(agent_id, path, current_user.tenant_id) + if not await storage.exists(key) or not await storage.is_file(key): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + version = await storage.get_version(key) try: - async with aiofiles.open(target, "r", encoding="utf-8") as f: - content = await f.read() - return FileContent(path=path, content=content) + content = await storage.read_text(key, encoding="utf-8", errors="replace") + return FileContent(path=path, content=content, version_token=version.token) except UnicodeDecodeError: - return FileContent(path=path, content=f"[二进制文件: {target.name}, {target.stat().st_size} bytes]") + stat = await storage.stat(key) + return FileContent( + path=path, + content=f"[二进制文件: {Path(path).name}, {stat.size} bytes]", + version_token=version.token, + ) + + +def _entry_version_token(entry: StorageEntry) -> str | None: + token = entry.version_id or entry.etag or entry.content_hash + if token: + return token + if entry.is_dir: + return None + if entry.modified_at or entry.size: + return f"{entry.modified_at}:{entry.size}" + return None def _file_kind(path: str) -> str: @@ -354,16 +410,18 @@ async def preview_file( ): """Return a browser-friendly preview payload for Workspace files.""" await check_agent_access(db, current_user, agent_id) - target, _, _ = _visible_path(agent_id, path, current_user.tenant_id) - if not target.exists() or not target.is_file(): + storage = get_storage_backend() + key, _ = _visible_storage_key(agent_id, path, current_user.tenant_id) + if not await storage.exists(key) or not await storage.is_file(key): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") kind = _file_kind(path) - mime_type = mimetypes.guess_type(target.name)[0] or "application/octet-stream" + mime_type = mimetypes.guess_type(Path(path).name)[0] or "application/octet-stream" download_url = f"/api/agents/{agent_id}/files/download?path={path}" + local_target: Path | None = None if kind in {"markdown", "html", "text"}: - content = await read_text_if_exists(target) + content = await storage.read_text(key, encoding="utf-8", errors="replace") return { "path": path, "kind": kind, @@ -373,7 +431,7 @@ async def preview_file( "download_url": download_url, } if kind == "csv": - content = await read_text_if_exists(target) or "" + content = await storage.read_text(key, encoding="utf-8", errors="replace") rows = _parse_csv_rows(content) return { "path": path, @@ -394,6 +452,8 @@ async def preview_file( } if kind == "xlsx": try: + target = await ensure_local_path(key) + local_target = target from openpyxl import load_workbook wb = load_workbook(target, read_only=True, data_only=True) @@ -428,6 +488,8 @@ async def preview_file( "download_url": download_url, } if kind in {"docx", "pptx"}: + target = await ensure_local_path(key) + local_target = target extracted_text = _extract_document_text(target, kind) companion = _find_companion_text_preview(target) companion_content = await read_text_if_exists(companion) if companion is not None else None @@ -440,7 +502,10 @@ async def preview_file( "download_url": download_url, } - companion = _find_companion_text_preview(target) + if local_target is not None: + companion = _find_companion_text_preview(local_target) + else: + companion = None if companion is not None: content = await read_text_if_exists(companion) return { @@ -453,13 +518,13 @@ async def preview_file( "download_url": download_url, } - raw = target.read_bytes() + raw = await storage.read_bytes(key) encoded = base64.b64encode(raw[:1024 * 1024]).decode("ascii") return { "path": path, "kind": kind, "mime_type": mime_type, - "size": target.stat().st_size, + "size": len(raw), "base64_sample": encoded, "download_url": download_url, } @@ -501,13 +566,29 @@ async def download_file( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found or inactive") await check_agent_access(db, user, agent_id) - target, _, _ = _visible_path(agent_id, path, user.tenant_id) - if not target.exists() or not target.is_file(): + storage = get_storage_backend() + key, _ = _visible_storage_key(agent_id, path, user.tenant_id) + if not await storage.exists(key) or not await storage.is_file(key): raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") - return FileResponse( - path=str(target), - filename=target.name, - content_disposition_type="inline" if inline else "attachment", + presigned = await storage.presign_download_url(key, filename=Path(path).name, inline=inline) + if presigned: + return Response( + status_code=302, + headers={"Location": presigned}, + ) + local_path = await storage.local_path_for(key) + if local_path is not None: + return FileResponse( + path=str(local_path), + filename=Path(path).name, + content_disposition_type="inline" if inline else "attachment", + ) + data = await storage.read_bytes(key) + disposition = "inline" if inline else "attachment" + return Response( + content=data, + media_type=guess_content_type(Path(path).name), + headers={"Content-Disposition": f'{disposition}; filename="{Path(path).name}"'}, ) @@ -549,6 +630,7 @@ async def write_file( session_id=data.session_id, enforce_human_lock=False, merge_user_autosave=data.autosave, + expected_version_token=data.expected_version_token, ) if not result.ok: raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=result.message) @@ -644,29 +726,29 @@ async def restore_file_revision( if revision.after_content is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot restore an empty/deleted revision") - target = _safe_path(agent_id, revision.path) - before = await read_text_if_exists(target) - target.parent.mkdir(parents=True, exist_ok=True) - async with aiofiles.open(target, "w", encoding="utf-8") as f: - await f.write(revision.after_content) - restored = await record_revision( + restored = await write_workspace_file( db, agent_id=agent_id, + base_dir=_agent_base_dir(agent_id), path=revision.path, - operation="restore", + content=revision.after_content, actor_type="user", actor_id=current_user.id, - before_content=before, - after_content=revision.after_content, + operation="restore", + enforce_human_lock=False, + expected_version_token=data.expected_version_token, ) + if not restored.ok: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=restored.message) await db.commit() - return {"status": "ok", "path": revision.path, "revision_id": str(restored.id) if restored else None} + return {"status": "ok", "path": revision.path, "revision_id": restored.revision_id} @router.delete("/content") async def delete_file( agent_id: uuid.UUID, path: str, + expected_version_token: str | None = None, current_user: User = Depends(get_current_user), db: AsyncSession = Depends(get_db), ): @@ -677,21 +759,26 @@ async def delete_file( status_code=status.HTTP_410_GONE, detail="Focus is stored in the system database. Use the Focus API.", ) + storage = get_storage_backend() if path.startswith("enterprise_info") and current_user.role not in ("platform_admin", "org_admin"): raise HTTPException(status_code=403, detail="Only admins can delete enterprise knowledge base files") if path.strip("/") == "enterprise_info": raise HTTPException(status_code=400, detail="Cannot delete enterprise_info root") - target, _, _ = _visible_path(agent_id, path, current_user.tenant_id) - - if not target.exists(): - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") - - if target.is_dir(): - import shutil - shutil.rmtree(target) - else: - target.unlink() - + result = await delete_workspace_file( + db, + agent_id=agent_id, + base_dir=_agent_base_dir(agent_id), + path=path, + actor_type="user", + actor_id=current_user.id, + enforce_human_lock=False, + expected_version_token=expected_version_token, + ) + if not result.ok: + if "not found" in result.message.lower(): + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=result.message) + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=result.message) + await db.commit() return {"status": "ok", "path": path} @@ -727,19 +814,11 @@ async def import_skill_to_agent( if not skill.files: raise HTTPException(status_code=400, detail="Skill has no files") - # Write each file into the agent's workspace - base = _agent_base_dir(agent_id) - skill_dir = base / "skills" / skill.folder_name - skill_dir.mkdir(parents=True, exist_ok=True) - + storage = get_storage_backend() written = [] for f in skill.files: - file_path = (skill_dir / f.path).resolve() - # Safety check - if not str(file_path).startswith(str(base.resolve())): - continue - file_path.parent.mkdir(parents=True, exist_ok=True) - file_path.write_text(f.content, encoding="utf-8") + skill_key = _agent_storage_key(agent_id, f"skills/{skill.folder_name}/{f.path}") + await storage.write_text(skill_key, f.content, encoding="utf-8") written.append(f.path) return { @@ -778,28 +857,25 @@ async def upload_file_to_workspace( if normalized_path not in {"workspace", "skills"} and not normalized_path.startswith(("workspace/", "skills/")): raise HTTPException(status_code=400, detail="右侧根目录视图是 agent 根目录;上传文件时请放到 workspace/ 或 skills/ 目录下") - base = _agent_base_dir(agent_id) - target_dir = (base / normalized_path).resolve() - if not str(target_dir).startswith(str(base.resolve())): - raise HTTPException(status_code=403, detail="Path traversal not allowed") - - target_dir.mkdir(parents=True, exist_ok=True) filename = file.filename or "unnamed" # Sanitize filename filename = filename.replace("/", "_").replace("\\", "_") - save_path = target_dir / filename + storage = get_storage_backend() + file_key = _agent_storage_key(agent_id, f"{normalized_path}/{filename}") content = await file.read() - save_path.write_bytes(content) + await storage.write_bytes(file_key, content, content_type=guess_content_type(filename)) # Auto-extract text from non-text files extracted_path = None from app.services.text_extractor import needs_extraction, save_extracted_text if needs_extraction(filename): + save_path = await ensure_local_path(file_key) txt_file = save_extracted_text(save_path, content, filename) if txt_file: - base_abs = base.resolve() - extracted_path = str(txt_file.resolve().relative_to(base_abs)) + extracted_path = f"{normalized_path}/{txt_file.name}" + extracted_key = _agent_storage_key(agent_id, extracted_path) + await storage.write_bytes(extracted_key, txt_file.read_bytes(), content_type="text/plain; charset=utf-8") return { "status": "ok", @@ -817,11 +893,19 @@ async def upload_file_to_workspace( def _enterprise_kb_dir(tenant_id: str) -> Path: - return Path(settings.AGENT_DATA_DIR) / f"enterprise_info_{tenant_id}" / "knowledge_base" + local_root = settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR + return Path(local_root) / f"enterprise_info_{tenant_id}" / "knowledge_base" def _enterprise_info_dir(tenant_id: str) -> Path: - return Path(settings.AGENT_DATA_DIR) / f"enterprise_info_{tenant_id}" + local_root = settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR + return Path(local_root) / f"enterprise_info_{tenant_id}" + + +def _enterprise_storage_key(tenant_id: str, rel_path: str = "") -> str: + prefix = f"enterprise_info_{tenant_id}" + rel = normalize_storage_key(rel_path) + return f"{prefix}/{rel}" if rel else prefix @enterprise_kb_router.get("/files") @@ -832,31 +916,22 @@ async def list_enterprise_kb_files( """List files in enterprise knowledge base (tenant-scoped).""" if not current_user.tenant_id: return [] - info_dir = _enterprise_info_dir(str(current_user.tenant_id)).resolve() - info_dir.mkdir(parents=True, exist_ok=True) - - if path: - target = (info_dir / path).resolve() - else: - target = info_dir - if not str(target).startswith(str(info_dir)): - raise HTTPException(status_code=403, detail="Path traversal not allowed") - - if not target.exists() or not target.is_dir(): + storage = get_storage_backend() + storage_key = _enterprise_storage_key(str(current_user.tenant_id), path) + if not await storage.exists(storage_key) or not await storage.is_dir(storage_key): return [] items = [] - for entry in sorted(target.iterdir(), key=lambda e: (not e.is_dir(), e.name)): + for entry in await storage.list_dir(storage_key): if entry.name == '.gitkeep': continue - rel = str(entry.resolve().relative_to(info_dir.resolve())) - stat = entry.stat() + rel = str(Path(entry.key).relative_to(f"enterprise_info_{current_user.tenant_id}")) items.append({ "name": entry.name, "path": rel, - "is_dir": entry.is_dir(), - "size": stat.st_size if entry.is_file() else 0, - "url": f"/api/enterprise/knowledge-base/download?path={rel}" if not entry.is_dir() else None + "is_dir": entry.is_dir, + "size": entry.size, + "url": f"/api/enterprise/knowledge-base/download?path={rel}" if not entry.is_dir else None }) return items @@ -875,28 +950,28 @@ async def upload_enterprise_kb_file( if not current_user.tenant_id: raise HTTPException(status_code=400, detail="No tenant associated") - info_dir = _enterprise_info_dir(str(current_user.tenant_id)) - target_dir = (info_dir / sub_path).resolve() - if not str(target_dir).startswith(str(info_dir.resolve())): - raise HTTPException(status_code=403, detail="Path traversal not allowed") - - target_dir.mkdir(parents=True, exist_ok=True) filename = file.filename or "unnamed" filename = filename.replace("/", "_").replace("\\", "_") - save_path = target_dir / filename + storage = get_storage_backend() + rel_path = f"{sub_path}/{filename}" if sub_path else filename + storage_key = _enterprise_storage_key(str(current_user.tenant_id), rel_path) content = await file.read() - save_path.write_bytes(content) + await storage.write_bytes(storage_key, content, content_type=guess_content_type(filename)) # Auto-extract text from non-text files extracted_path = None from app.services.text_extractor import needs_extraction, save_extracted_text if needs_extraction(filename): + save_path = await ensure_local_path(storage_key) txt_file = save_extracted_text(save_path, content, filename) if txt_file: - extracted_path = str(txt_file.resolve().relative_to(info_dir.resolve())) - - rel_path = f"{sub_path}/{filename}" if sub_path else filename + extracted_path = f"{sub_path}/{txt_file.name}" if sub_path else txt_file.name + await storage.write_bytes( + _enterprise_storage_key(str(current_user.tenant_id), extracted_path), + txt_file.read_bytes(), + content_type="text/plain; charset=utf-8", + ) return { "status": "ok", "path": rel_path, @@ -915,18 +990,17 @@ async def read_enterprise_file( """Read content of an enterprise knowledge base file (tenant-scoped).""" if not current_user.tenant_id: raise HTTPException(status_code=400, detail="No tenant associated") - info_dir = _enterprise_info_dir(str(current_user.tenant_id)) - target = (info_dir / path).resolve() - if not str(target).startswith(str(info_dir.resolve())): - raise HTTPException(status_code=403, detail="Path traversal not allowed") - if not target.exists() or not target.is_file(): + storage = get_storage_backend() + storage_key = _enterprise_storage_key(str(current_user.tenant_id), path) + if not await storage.exists(storage_key) or not await storage.is_file(storage_key): raise HTTPException(status_code=404, detail="File not found") try: - content = target.read_text(encoding="utf-8", errors="replace") + content = await storage.read_text(storage_key, encoding="utf-8", errors="replace") return {"path": path, "content": content} except Exception: - return {"path": path, "content": f"[二进制文件: {target.name}, {target.stat().st_size} bytes]"} + stat = await storage.stat(storage_key) + return {"path": path, "content": f"[二进制文件: {Path(path).name}, {stat.size} bytes]"} @enterprise_kb_router.put("/content") @@ -941,14 +1015,8 @@ async def write_enterprise_file( if not current_user.tenant_id: raise HTTPException(status_code=400, detail="No tenant associated") - info_dir = _enterprise_info_dir(str(current_user.tenant_id)) - target = (info_dir / path).resolve() - if not str(target).startswith(str(info_dir.resolve())): - raise HTTPException(status_code=403, detail="Path traversal not allowed") - - target.parent.mkdir(parents=True, exist_ok=True) - async with aiofiles.open(target, "w", encoding="utf-8") as f: - await f.write(data.content) + storage = get_storage_backend() + await storage.write_text(_enterprise_storage_key(str(current_user.tenant_id), path), data.content, encoding="utf-8") return {"status": "ok", "path": path} @@ -963,18 +1031,16 @@ async def delete_enterprise_file( if not current_user.tenant_id: raise HTTPException(status_code=400, detail="No tenant associated") - info_dir = _enterprise_info_dir(str(current_user.tenant_id)) - target = (info_dir / path).resolve() - if not str(target).startswith(str(info_dir.resolve())): - raise HTTPException(status_code=403, detail="Path traversal not allowed") - if not target.exists(): + storage = get_storage_backend() + storage_key = _enterprise_storage_key(str(current_user.tenant_id), path) + storage_exists = await storage.exists(storage_key) + storage_is_dir = await storage.is_dir(storage_key) + if not storage_exists and not storage_is_dir: raise HTTPException(status_code=404, detail="File not found") - - if target.is_dir(): - import shutil - shutil.rmtree(target) + if storage_is_dir: + await storage.delete_tree(storage_key) else: - target.unlink() + await storage.delete(storage_key) return {"status": "ok", "path": path} diff --git a/backend/app/api/pages.py b/backend/app/api/pages.py index 461e79a67..3856d80e0 100644 --- a/backend/app/api/pages.py +++ b/backend/app/api/pages.py @@ -1,20 +1,17 @@ """Public pages API — serves published HTML without authentication.""" import uuid -from pathlib import Path from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import HTMLResponse from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession -from app.config import get_settings from app.core.security import get_current_user from app.database import get_db from app.models.published_page import PublishedPage from app.models.user import User - -settings = get_settings() +from app.services.storage import get_storage_backend, normalize_storage_key # Public router — no /api prefix, no auth public_router = APIRouter(tags=["pages"]) @@ -22,11 +19,6 @@ # Authenticated router — under /api prefix router = APIRouter(prefix="/pages", tags=["pages"]) - -def _agent_base_dir(agent_id: uuid.UUID) -> Path: - return Path(settings.AGENT_DATA_DIR) / str(agent_id) - - # ── Public render (NO auth) ──────────────────────────── @public_router.get("/p/{short_id}") @@ -39,12 +31,12 @@ async def render_page(short_id: str, db: AsyncSession = Depends(get_db)): if not page: raise HTTPException(status_code=404, detail="Page not found") - # Read the HTML file from agent workspace - file_path = _agent_base_dir(page.agent_id) / page.source_path - if not file_path.exists() or not file_path.is_file(): + storage = get_storage_backend() + storage_key = normalize_storage_key(f"{page.agent_id}/{page.source_path}") + if not await storage.exists(storage_key) or not await storage.is_file(storage_key): raise HTTPException(status_code=404, detail="Source file no longer exists") - html_content = file_path.read_text(encoding="utf-8", errors="replace") + html_content = await storage.read_text(storage_key, encoding="utf-8", errors="replace") # Increment view count await db.execute( diff --git a/backend/app/api/relationships.py b/backend/app/api/relationships.py index c2e191e67..4cf8910b2 100644 --- a/backend/app/api/relationships.py +++ b/backend/app/api/relationships.py @@ -1,7 +1,6 @@ """Agent relationship management API — human + agent-to-agent.""" import uuid -from pathlib import Path from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel @@ -22,11 +21,11 @@ from app.database import get_db from app.models.agent import Agent from app.models.org import AgentRelationship, AgentAgentRelationship, OrgMember +from app.models.user import Identity, User from app.services.access_relationships import ensure_access_granted_platform_relationships from app.services.org_sync_adapter import derive_member_department_paths -from app.models.user import User +from app.services.storage import store_agent_bytes -settings = get_settings() router = APIRouter(prefix="/agents/{agent_id}/relationships", tags=["relationships"]) RELATION_LABELS = { @@ -595,11 +594,13 @@ async def _regenerate_relationships_file(db: AsyncSession, agent_id: uuid.UUID): if status_info["access_status"] == "active": agent_rels.append(rel) - ws = Path(settings.AGENT_DATA_DIR) / str(agent_id) - ws.mkdir(parents=True, exist_ok=True) - if not human_rows and not agent_rels: - (ws / "relationships.md").write_text("# 关系网络\n\n_暂无配置的关系。_\n", encoding="utf-8") + await store_agent_bytes( + agent_id, + "relationships.md", + "# 关系网络\n\n_暂无配置的关系。_\n".encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) return lines = ["# 关系网络\n"] @@ -639,4 +640,9 @@ async def _regenerate_relationships_file(db: AsyncSession, agent_id: uuid.UUID): lines.append(f"- {r.description}") lines.append("") - (ws / "relationships.md").write_text("\n".join(lines), encoding="utf-8") + await store_agent_bytes( + agent_id, + "relationships.md", + "\n".join(lines).encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) diff --git a/backend/app/api/slack.py b/backend/app/api/slack.py index a7006f9dd..ccb705c76 100644 --- a/backend/app/api/slack.py +++ b/backend/app/api/slack.py @@ -4,6 +4,7 @@ import hmac import time import uuid +from pathlib import Path from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from loguru import logger @@ -16,6 +17,7 @@ from app.models.channel_config import ChannelConfig from app.models.user import User from app.schemas.schemas import ChannelConfigOut +from app.services.storage import store_agent_upload router = APIRouter(tags=["slack"]) @@ -307,17 +309,12 @@ async def slack_event_webhook( history = [{"role": m.role, "content": m.content} for m in reversed(history_r.scalars().all())] # Handle file attachments: save to workspace/uploads/ and send ack - from app.config import get_settings as _gs import asyncio as _asyncio import random as _random - from pathlib import Path as _Path import httpx as _httpx from datetime import datetime, timezone from app.api.feishu import _FILE_ACK_MESSAGES _file_user_messages = [] - _settings = _gs() - _upload_dir = _Path(_settings.AGENT_DATA_DIR) / str(agent_id) / "workspace" / "uploads" - _upload_dir.mkdir(parents=True, exist_ok=True) _bot_token = config.app_secret or "" for _sf in slack_files: _fname = _sf.get("name") or _sf.get("title") or f"slack_file_{_sf.get('id', 'unk')}.bin" @@ -332,8 +329,13 @@ async def slack_event_webhook( _ct = _r.headers.get("content-type", "") if "text/html" in _ct or _r.content[:15].lower().startswith(b" Path: - return Path(get_settings().AGENT_DATA_DIR) / "_tenant_logos" - - -def _tenant_logo_path(tenant_id: uuid.UUID) -> Path: - return _tenant_logo_dir() / f"{tenant_id}.png" +def _tenant_logo_key(tenant_id: uuid.UUID) -> str: + return normalize_storage_key(f"_tenant_logos/{tenant_id}.png") def _tenant_logo_url(tenant_id: uuid.UUID) -> str: - try: - mtime = int(_tenant_logo_path(tenant_id).stat().st_mtime) - except OSError: - mtime = int(datetime.utcnow().timestamp()) - return f"/api/tenants/{tenant_id}/logo?v={mtime}" + return f"/api/tenants/{tenant_id}/logo?v={int(datetime.utcnow().timestamp())}" async def _get_updateable_tenant( @@ -582,9 +573,11 @@ async def update_tenant( @router.get("/{tenant_id}/logo") async def get_tenant_logo(tenant_id: uuid.UUID): """Serve a tenant logo. Logos are public UI assets, addressed by UUID.""" - path = _tenant_logo_path(tenant_id) - if not path.exists(): + storage = get_storage_backend() + key = _tenant_logo_key(tenant_id) + if not await storage.exists(key): raise HTTPException(status_code=404, detail="Logo not found") + path = await ensure_local_path(key) return FileResponse(path, media_type="image/png") @@ -621,11 +614,8 @@ async def upload_tenant_logo( if len(png_data) > 1024 * 1024: raise HTTPException(status_code=400, detail="Logo image must be 1 MB or smaller after processing") - logo_dir = _tenant_logo_dir() - logo_dir.mkdir(parents=True, exist_ok=True) - path = _tenant_logo_path(tenant_id) - async with aiofiles.open(path, "wb") as f: - await f.write(png_data) + storage = get_storage_backend() + await storage.write_bytes(_tenant_logo_key(tenant_id), png_data, content_type="image/png") config = dict(tenant.im_config or {}) config["logo_url"] = _tenant_logo_url(tenant_id) @@ -643,12 +633,10 @@ async def delete_tenant_logo( """Remove a custom company logo and fall back to the generated default.""" tenant = await _get_updateable_tenant(tenant_id, current_user, db) - path = _tenant_logo_path(tenant_id) - if path.exists(): - try: - path.unlink() - except OSError as exc: - raise HTTPException(status_code=500, detail="Failed to delete logo") from exc + storage = get_storage_backend() + key = _tenant_logo_key(tenant_id) + if await storage.exists(key): + await storage.delete(key) config = dict(tenant.im_config or {}) config.pop("logo_url", None) diff --git a/backend/app/api/upload.py b/backend/app/api/upload.py index 1aa0ea166..28c7867c9 100644 --- a/backend/app/api/upload.py +++ b/backend/app/api/upload.py @@ -9,13 +9,10 @@ from loguru import logger from app.core.security import get_current_user from app.models.user import User -from app.config import get_settings +from app.services.storage import ensure_local_path, get_storage_backend, guess_content_type, normalize_storage_key router = APIRouter(prefix="/chat", tags=["chat"]) -_settings = get_settings() -WORKSPACE_ROOT = Path(_settings.AGENT_DATA_DIR) - # Supported extensions and their text extraction method TEXT_EXTENSIONS = { ".txt", ".md", ".csv", ".json", ".xml", ".yaml", ".yml", @@ -125,20 +122,19 @@ async def upload_file( # Determine save directory workspace_path = "" if agent_id: - # Save to agent's workspace/uploads/ - uploads_dir = WORKSPACE_ROOT / agent_id / "workspace" / "uploads" - uploads_dir.mkdir(parents=True, exist_ok=True) - save_path = uploads_dir / file.filename - # Avoid overwriting: add suffix if file exists - if save_path.exists(): - stem = save_path.stem - suffix = save_path.suffix - counter = 1 - while save_path.exists(): - save_path = uploads_dir / f"{stem}_{counter}{suffix}" - counter += 1 - save_path.write_bytes(content) - workspace_path = f"workspace/uploads/{save_path.name}" + storage = get_storage_backend() + filename = file.filename.replace("/", "_").replace("\\", "_") + workspace_path = f"workspace/uploads/{filename}" + key = normalize_storage_key(f"{agent_id}/{workspace_path}") + counter = 1 + while await storage.exists(key): + stem, ext = os.path.splitext(filename) + filename = f"{stem}_{counter}{ext}" + workspace_path = f"workspace/uploads/{filename}" + key = normalize_storage_key(f"{agent_id}/{workspace_path}") + counter += 1 + await storage.write_bytes(key, content, content_type=guess_content_type(filename)) + save_path = await ensure_local_path(key) else: # Fallback: save to /tmp (legacy behavior) fallback_dir = Path("/tmp/clawith_uploads") diff --git a/backend/app/api/webhooks.py b/backend/app/api/webhooks.py index 50dd1f435..b61d20ef1 100644 --- a/backend/app/api/webhooks.py +++ b/backend/app/api/webhooks.py @@ -14,17 +14,32 @@ from loguru import logger from sqlalchemy import select +from app.core.events import get_redis from app.database import async_session from app.models.trigger import AgentTrigger +from app.services.trigger_runtime import enqueue_webhook_execution router = APIRouter(prefix="/api/webhooks", tags=["webhooks"]) -# In-memory rate limiter: token -> list of timestamps -_rate_hits: dict[str, list[float]] = {} RATE_LIMIT = 5 # max hits per minute per token MAX_PAYLOAD_SIZE = 65536 # 64KB max payload +async def _record_and_count_hits(token: str) -> int: + """Record the current hit in Redis and return the rolling 60-second count.""" + redis = await get_redis() + now = time.time() + key = f"webhook:rate:{token}" + member = f"{now}:{hashlib.sha1(f'{token}:{now}'.encode()).hexdigest()[:8]}" + async with redis.pipeline(transaction=True) as pipe: + pipe.zremrangebyscore(key, 0, now - 60) + pipe.zadd(key, {member: now}) + pipe.zcard(key) + pipe.expire(key, 120) + _, _, count, _ = await pipe.execute() + return int(count) + + @router.post("/t/{token}") async def receive_webhook(token: str, request: Request): """Receive a webhook POST from an external service. @@ -37,17 +52,13 @@ async def receive_webhook(token: str, request: Request): - Payload size limit (64KB) """ # Rate limiting — use per-agent limit if available - now = time.time() - hits = _rate_hits.get(token, []) - hits = [t for t in hits if now - t < 60] # keep last 60 seconds + hit_count = await _record_and_count_hits(token) # We'll check per-agent rate limit after finding the trigger below. # For now, apply a generous global ceiling to prevent memory abuse. - if len(hits) >= 60: # hard ceiling: 60/min regardless of config + if hit_count >= 60: # hard ceiling: 60/min regardless of config logger.warning(f"Webhook hard rate limit exceeded for token {token[:8]}...") return JSONResponse({"ok": True}, status_code=429) - hits.append(now) - _rate_hits[token] = hits # Payload size check body = await request.body() @@ -83,7 +94,7 @@ async def receive_webhook(token: str, request: Request): agent_obj = agent_result.scalar_one_or_none() agent_rate_limit = (agent_obj.webhook_rate_limit if agent_obj else None) or RATE_LIMIT # Re-check hits against agent-specific limit (hits already collected above) - if len(hits) > agent_rate_limit: # > because we already appended current hit + if hit_count > agent_rate_limit: # > because current hit is already counted logger.warning(f"Webhook per-agent rate limit ({agent_rate_limit}/min) for token {token[:8]}...") # Log audit entry so user can see dropped webhooks try: @@ -120,24 +131,28 @@ async def receive_webhook(token: str, request: Request): try: payload_str = body.decode("utf-8") # Try to pretty-format JSON for readability + payload_obj = None try: payload_obj = json.loads(payload_str) payload_str = json.dumps(payload_obj, ensure_ascii=False, indent=2) except json.JSONDecodeError: - pass # Keep as raw string + payload_obj = None except Exception: + payload_obj = None payload_str = repr(body[:2000]) - # Store payload and set pending flag - new_config = {**cfg, "_webhook_pending": True, "_webhook_payload": payload_str[:8000]} - from sqlalchemy import update - await db.execute( - update(AgentTrigger) - .where(AgentTrigger.id == target.id) - .values(config=new_config) + _execution, created = await enqueue_webhook_execution( + db, + trigger=target, + body=body, + payload_text=payload_str, + payload_obj=payload_obj if isinstance(payload_obj, dict) else None, + request_headers={k.lower(): v for k, v in request.headers.items()}, ) - await db.commit() + if not created: + logger.info(f"Webhook duplicate ignored for trigger {target.name}") + return JSONResponse({"ok": True}) - logger.info(f"Webhook received for trigger {target.name} (agent {target.agent_id})") + logger.info(f"Webhook queued for trigger {target.name} (agent {target.agent_id})") return JSONResponse({"ok": True}) diff --git a/backend/app/api/websocket.py b/backend/app/api/websocket.py index 0b9e37ff9..e37517c4c 100644 --- a/backend/app/api/websocket.py +++ b/backend/app/api/websocket.py @@ -20,6 +20,7 @@ from app.models.user import User from app.services.chat_session_service import ensure_primary_platform_session from app.services.llm import call_llm, call_llm_with_failover +from app.services.realtime import realtime_router router = APIRouter(tags=["websocket"]) @@ -32,59 +33,82 @@ def __init__(self): self.active_connections: dict[str, list[tuple]] = {} async def connect(self, agent_id: str, websocket: WebSocket, session_id: str = None, user_id: str | None = None): - await websocket.accept() if agent_id not in self.active_connections: self.active_connections[agent_id] = [] self.active_connections[agent_id].append((websocket, session_id, user_id)) - - def disconnect(self, agent_id: str, websocket: WebSocket): + await realtime_router.register_connection( + agent_id=agent_id, + websocket=websocket, + session_id=session_id, + user_id=user_id, + ) + + async def disconnect(self, agent_id: str, websocket: WebSocket): if agent_id in self.active_connections: self.active_connections[agent_id] = [ (ws, sid, uid) for ws, sid, uid in self.active_connections[agent_id] if ws != websocket ] + await realtime_router.unregister_connection(agent_id=agent_id, websocket=websocket) + + def _local_connections(self, agent_id: str) -> list[tuple[WebSocket, str | None, str | None]]: + return self.active_connections.get(agent_id, []) + + async def deliver_pubsub_message( + self, + *, + agent_id: str, + payload: dict, + session_id: str | None = None, + user_id: str | None = None, + ) -> None: + if agent_id not in self.active_connections: + return + for ws, sid, uid in list(self.active_connections[agent_id]): + if session_id is not None and sid != session_id: + continue + if user_id is not None and uid != user_id: + continue + try: + await ws.send_json(payload) + except Exception: + pass async def send_message(self, agent_id: str, message: dict): - if agent_id in self.active_connections: - for ws, _sid, _uid in self.active_connections[agent_id]: - try: - await ws.send_json(message) - except Exception: - pass + await realtime_router.route_message( + agent_id=agent_id, + message=message, + local_connections=self._local_connections(agent_id), + ) async def send_to_session(self, agent_id: str, session_id: str, message: dict): """Send message only to WebSocket connections matching the given session_id.""" - if agent_id in self.active_connections: - for ws, sid, _uid in self.active_connections[agent_id]: - if sid == session_id: - try: - await ws.send_json(message) - except Exception: - pass + await realtime_router.route_message( + agent_id=agent_id, + message=message, + local_connections=self._local_connections(agent_id), + session_id=session_id, + ) async def send_to_user(self, agent_id: str, user_id: str, message: dict): """Send message to all live WebSocket sessions of a given platform user for an agent.""" - if agent_id in self.active_connections: - for ws, _sid, uid in self.active_connections[agent_id]: - if uid == user_id: - try: - await ws.send_json(message) - except Exception: - pass - - def get_active_session_ids(self, agent_id: str) -> list[str]: + await realtime_router.route_message( + agent_id=agent_id, + message=message, + local_connections=self._local_connections(agent_id), + user_id=user_id, + ) + + async def get_active_session_ids(self, agent_id: str) -> list[str]: """Return distinct session IDs for all active WS connections of an agent.""" - if agent_id not in self.active_connections: - return [] - return list(set(sid for _ws, sid, _uid in self.active_connections[agent_id] if sid)) + return await realtime_router.get_active_session_ids(agent_id) - def is_user_viewing_session(self, agent_id: str, session_id: str, user_id: str) -> bool: + async def is_user_viewing_session(self, agent_id: str, session_id: str, user_id: str) -> bool: """Return True if the given platform user currently has this exact session open.""" - if agent_id not in self.active_connections: - return False - for _ws, sid, uid in self.active_connections[agent_id]: - if sid == session_id and uid == user_id: - return True - return False + return await realtime_router.is_user_viewing_session( + agent_id=agent_id, + session_id=session_id, + user_id=user_id, + ) manager = ConnectionManager() @@ -98,7 +122,7 @@ async def maybe_mark_session_read_for_active_viewer( user_id: uuid.UUID, ) -> bool: """Advance last_read_at_by_user if the owner is actively viewing this exact session.""" - if not manager.is_user_viewing_session(str(agent_id), session_id, str(user_id)): + if not await manager.is_user_viewing_session(str(agent_id), session_id, str(user_id)): return False session = await db.get(ChatSession, uuid.UUID(session_id)) @@ -324,9 +348,7 @@ async def websocket_chat( return agent_id_str = str(agent_id) - if agent_id_str not in manager.active_connections: - manager.active_connections[agent_id_str] = [] - manager.active_connections[agent_id_str].append((websocket, conv_id, str(user_id))) + await manager.connect(agent_id_str, websocket, conv_id, str(user_id)) logger.info(f"[WS] Ready! Agent={agent_name}") # Send session_id to frontend so Take Control can reference the correct session. @@ -554,6 +576,7 @@ async def websocket_chat( # Track thinking content for storage (initialize before condition) thinking_content = [] + queued_messages: list[dict] = [] # Reload model config on every message so Settings changes take effect # immediately without requiring a page refresh / WebSocket reconnect. @@ -830,7 +853,6 @@ async def _on_failover(reason: str): # Listen for abort while LLM is running aborted = False - queued_messages: list[dict] = [] while not llm_task.done(): try: msg = await _aio.wait_for( @@ -966,9 +988,9 @@ async def _on_failover(reason: str): except WebSocketDisconnect: logger.info(f"[WS] Client disconnected: {user_id}") - manager.disconnect(str(agent_id), websocket) + await manager.disconnect(str(agent_id), websocket) except Exception as e: logger.error(f"[WS] Unexpected error: {e}") import traceback traceback.print_exc() - manager.disconnect(str(agent_id), websocket) + await manager.disconnect(str(agent_id), websocket) diff --git a/backend/app/api/wechat.py b/backend/app/api/wechat.py index a8f64d738..ba715728d 100644 --- a/backend/app/api/wechat.py +++ b/backend/app/api/wechat.py @@ -12,6 +12,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.config import get_settings from app.core.permissions import check_agent_access, is_agent_creator from app.core.security import get_current_user from app.database import get_db @@ -22,6 +23,13 @@ router = APIRouter(tags=["wechat"]) +settings = get_settings() + + +def _role_enabled(*required: str) -> bool: + raw = (settings.PROCESS_ROLE or "all").strip().lower() + roles = {part.strip() for part in raw.split(",") if part.strip()} or {"all"} + return "all" in roles or any(role in roles for role in required) def _route_tag(data: dict | None = None) -> str | None: @@ -133,7 +141,8 @@ async def get_wechat_qrcode_status( await db.flush() await db.commit() - asyncio.create_task(wechat_poll_manager.start_client(agent_id)) + if _role_enabled("connector"): + asyncio.create_task(wechat_poll_manager.start_client(agent_id)) return payload diff --git a/backend/app/config.py b/backend/app/config.py index 9b0ad0e02..fac6ad192 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -1,7 +1,10 @@ """Application configuration.""" from functools import lru_cache +import os from pathlib import Path +import socket +import uuid from pydantic_settings import BaseSettings @@ -32,6 +35,14 @@ def _default_agent_data_dir() -> str: return str(Path.home() / ".clawith" / "data" / "agents") +def _default_instance_id() -> str: + """Generate a stable-enough per-process instance identifier.""" + host = socket.gethostname() or "unknown" + pid = os.getpid() + suffix = uuid.uuid4().hex[:8] + return f"{host}-{pid}-{suffix}" + + def _default_agent_template_dir() -> str: """Locate the agent template directory for both Docker and source deployments. @@ -78,6 +89,7 @@ class Settings(BaseSettings): # Redis REDIS_URL: str = "redis://localhost:6379/0" + INSTANCE_ID: str = _default_instance_id() # JWT JWT_SECRET_KEY: str = "change-me-jwt-secret" @@ -88,8 +100,23 @@ class Settings(BaseSettings): EMAIL_VERIFICATION_REQUIRED: bool = False # Require email verification for login # File Storage + STORAGE_BACKEND: str = "local" AGENT_DATA_DIR: str = _default_agent_data_dir() AGENT_TEMPLATE_DIR: str = _default_agent_template_dir() + STORAGE_LOCAL_ROOT: str = _default_agent_data_dir() + STORAGE_LOCAL_FALLBACK_ENABLED: bool = True + S3_BUCKET: str = "" + S3_REGION: str = "" + S3_ENDPOINT_URL: str = "" + S3_ACCESS_KEY_ID: str = "" + S3_SECRET_ACCESS_KEY: str = "" + S3_PREFIX: str = "agents" + S3_PRESIGN_TTL_SECONDS: int = 3600 + S3_MAX_POOL_CONNECTIONS: int = 50 + S3_WRITE_WORKERS: int = 32 + + # Process role + PROCESS_ROLE: str = "all" # Docker (for Agent containers) DOCKER_NETWORK: str = "clawith_network" diff --git a/backend/app/core/logging_config.py b/backend/app/core/logging_config.py index c9afab182..b20ebf18a 100644 --- a/backend/app/core/logging_config.py +++ b/backend/app/core/logging_config.py @@ -21,6 +21,8 @@ "websockets.server": logging.WARNING, "websockets.client": logging.WARNING, "uvicorn.protocols.websockets.websockets_impl": logging.WARNING, + # Supress "Failed to parse headers" warning from urllib3 when interacting with MinIO. + "urllib3.connection": logging.ERROR, } diff --git a/backend/app/core/security.py b/backend/app/core/security.py index a22045a7b..967fb9ced 100644 --- a/backend/app/core/security.py +++ b/backend/app/core/security.py @@ -1,8 +1,10 @@ """Security utilities: JWT, password hashing, and authentication dependencies.""" +import asyncio import base64 import os import uuid +from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta, timezone import bcrypt @@ -23,6 +25,31 @@ # Bearer token scheme security = HTTPBearer() +# Thread pool for CPU-intensive bcrypt operations (avoids blocking the event loop) +_bcrypt_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="bcrypt") + + +def hash_password(password: str) -> str: + """Hash a password using bcrypt (sync, for use in background tasks).""" + return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verify a password against its hash (sync, for use in background tasks).""" + return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8")) + + +async def hash_password_async(password: str) -> str: + """Hash a password using bcrypt without blocking the event loop.""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor(_bcrypt_executor, hash_password, password) + + +async def verify_password_async(plain_password: str, hashed_password: str) -> bool: + """Verify a password against its hash without blocking the event loop.""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor(_bcrypt_executor, verify_password, plain_password, hashed_password) + def encrypt_data(plaintext: str, key: str) -> str: """Encrypt a string using AES-256-CBC with the given key. @@ -96,14 +123,6 @@ def decrypt_data(ciphertext: str, key: str) -> str: raise ValueError(f"Decryption failed: {e}") from e -def hash_password(password: str) -> str: - """Hash a password using bcrypt.""" - return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8") - - -def verify_password(plain_password: str, hashed_password: str) -> bool: - """Verify a password against its hash.""" - return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8")) def create_access_token(user_id: str, role: str, expires_delta: timedelta | None = None) -> str: diff --git a/backend/app/main.py b/backend/app/main.py index 88e3bd27b..9532df4d7 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -13,10 +13,26 @@ from app.core.logging_config import configure_logging, intercept_standard_logging from app.core.middleware import TraceIdMiddleware from app.schemas.schemas import HealthResponse +from app.services.realtime import realtime_router settings = get_settings() +def _process_roles() -> set[str]: + raw = (settings.PROCESS_ROLE or "all").strip().lower() + if not raw: + return {"all"} + roles = {part.strip() for part in raw.split(",") if part.strip()} + return roles or {"all"} + + +def _role_enabled(*required: str) -> bool: + roles = _process_roles() + if "all" in roles: + return True + return any(role in roles for role in required) + + def _log_bwrap_startup_status() -> None: """Emit a startup diagnostic for bubblewrap availability. @@ -129,134 +145,139 @@ async def lifespan(app: FastAPI): from app.services.wechat_channel import wechat_poll_manager from app.services.discord_gateway import discord_gateway_manager - # ── Step 0: Ensure all DB tables exist (idempotent, safe to run on every startup) ── - try: - from app.database import Base, engine - # Import all models so Base.metadata is fully populated - import app.models.user # noqa - import app.models.agent # noqa - import app.models.task # noqa - import app.models.llm # noqa - import app.models.tool # noqa - import app.models.audit # noqa - import app.models.skill # noqa - import app.models.channel_config # noqa - import app.models.schedule # noqa - import app.models.plaza # noqa - import app.models.activity_log # noqa - import app.models.org # noqa - import app.models.system_settings # noqa - import app.models.invitation_code # noqa - import app.models.tenant # noqa - import app.models.tenant_setting # noqa - import app.models.participant # noqa - import app.models.chat_session # noqa - import app.models.trigger # noqa - import app.models.focus # noqa - import app.models.notification # noqa - import app.models.gateway_message # noqa - import app.models.agent_credential # noqa - import app.models.okr # noqa OKR system tables - import app.models.onboarding # noqa - - import app.models.identity # noqa - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - logger.info("[startup] Database tables ready") - except Exception as e: - logger.warning(f"[startup] create_all failed: {e}") - # Startup: seed data — each step isolated so one failure doesn't block others - logger.info("[startup] seeding...") + if _role_enabled("all", "bootstrap"): + # ── Step 0: Ensure all DB tables exist (idempotent, safe to run on every startup) ── + try: + from app.database import Base, engine + # Import all models so Base.metadata is fully populated + import app.models.user # noqa + import app.models.agent # noqa + import app.models.task # noqa + import app.models.llm # noqa + import app.models.tool # noqa + import app.models.audit # noqa + import app.models.skill # noqa + import app.models.channel_config # noqa + import app.models.schedule # noqa + import app.models.plaza # noqa + import app.models.activity_log # noqa + import app.models.org # noqa + import app.models.system_settings # noqa + import app.models.invitation_code # noqa + import app.models.tenant # noqa + import app.models.tenant_setting # noqa + import app.models.participant # noqa + import app.models.chat_session # noqa + import app.models.trigger # noqa + import app.models.trigger_execution # noqa + import app.models.focus # noqa + import app.models.notification # noqa + import app.models.gateway_message # noqa + import app.models.agent_credential # noqa + import app.models.okr # noqa + import app.models.onboarding # noqa + + import app.models.identity # noqa + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + logger.info("[startup] Database tables ready") + except Exception as e: + logger.warning(f"[startup] create_all failed: {e}") + logger.info("[startup] seeding...") - # Seed default company (Tenant) — required before users can register - try: - from app.models.tenant import Tenant - from app.database import async_session as _session - from sqlalchemy import select as _select - async with _session() as _db: - _existing = await _db.execute(_select(Tenant).where(Tenant.slug == "default")) - if not _existing.scalar_one_or_none(): - _db.add(Tenant(name="Default", slug="default", im_provider="web_only")) - await _db.commit() - logger.info("[startup] Default company created") - except Exception as e: - logger.warning(f"[startup] Default company seed failed: {e}") + try: + from app.models.tenant import Tenant + from app.database import async_session as _session + from sqlalchemy import select as _select + async with _session() as _db: + _existing = await _db.execute(_select(Tenant).where(Tenant.slug == "default")) + if not _existing.scalar_one_or_none(): + _db.add(Tenant(name="Default", slug="default", im_provider="web_only")) + await _db.commit() + logger.info("[startup] Default company created") + except Exception as e: + logger.warning(f"[startup] Default company seed failed: {e}") - # Migrate old shared enterprise_info/ → enterprise_info_{first_tenant_id}/ - try: - import shutil - from pathlib import Path as _Path - from app.config import get_settings as _gs - from app.models.tenant import Tenant as _T - from app.database import async_session as _ses - from sqlalchemy import select as _sel - _data_dir = _Path(_gs().AGENT_DATA_DIR) - _old_dir = _data_dir / "enterprise_info" - if _old_dir.exists() and any(_old_dir.iterdir()): - async with _ses() as _db: - _first = await _db.execute(_sel(_T).order_by(_T.created_at).limit(1)) - _tenant = _first.scalar_one_or_none() - if _tenant: - _new_dir = _data_dir / f"enterprise_info_{_tenant.id}" - if not _new_dir.exists(): - shutil.copytree(str(_old_dir), str(_new_dir)) - print(f"[startup] ✅ Migrated enterprise_info → enterprise_info_{_tenant.id}", flush=True) - else: - print(f"[startup] ℹ️ enterprise_info_{_tenant.id} already exists, skipping migration", flush=True) - except Exception as e: - print(f"[startup] ⚠️ enterprise_info migration failed: {e}", flush=True) + try: + import shutil + from pathlib import Path as _Path + from app.config import get_settings as _gs + from app.models.tenant import Tenant as _T + from app.database import async_session as _ses + from sqlalchemy import select as _sel + _data_dir = _Path(_gs().AGENT_DATA_DIR) + _old_dir = _data_dir / "enterprise_info" + if _old_dir.exists() and any(_old_dir.iterdir()): + async with _ses() as _db: + _first = await _db.execute(_sel(_T).order_by(_T.created_at).limit(1)) + _tenant = _first.scalar_one_or_none() + if _tenant: + _new_dir = _data_dir / f"enterprise_info_{_tenant.id}" + if not _new_dir.exists(): + shutil.copytree(str(_old_dir), str(_new_dir)) + print(f"[startup] ✅ Migrated enterprise_info → enterprise_info_{_tenant.id}", flush=True) + else: + print(f"[startup] ℹ️ enterprise_info_{_tenant.id} already exists, skipping migration", flush=True) + except Exception as e: + print(f"[startup] ⚠️ enterprise_info migration failed: {e}", flush=True) - try: - from app.services.tool_seeder import seed_builtin_tools, clean_orphaned_mcp_tools - await seed_builtin_tools() - await clean_orphaned_mcp_tools() - except Exception as e: - logger.warning(f"[startup] Builtin tools seed or cleanup failed: {e}") + try: + from app.services.tool_seeder import seed_builtin_tools, clean_orphaned_mcp_tools + await seed_builtin_tools() + await clean_orphaned_mcp_tools() + except Exception as e: + logger.warning(f"[startup] Builtin tools seed or cleanup failed: {e}") - try: - from app.services.tool_seeder import seed_atlassian_rovo_config, get_atlassian_api_key - await seed_atlassian_rovo_config() - # Auto-import Atlassian Rovo tools if an API key is already configured - _rovo_key = await get_atlassian_api_key() - if _rovo_key: - from app.services.resource_discovery import seed_atlassian_rovo_tools - await seed_atlassian_rovo_tools(_rovo_key) - except Exception as e: - logger.warning(f"[startup] Atlassian tools seed failed: {e}") + try: + from app.services.tool_seeder import seed_atlassian_rovo_config, get_atlassian_api_key + await seed_atlassian_rovo_config() + _rovo_key = await get_atlassian_api_key() + if _rovo_key: + from app.services.resource_discovery import seed_atlassian_rovo_tools + await seed_atlassian_rovo_tools(_rovo_key) + except Exception as e: + logger.warning(f"[startup] Atlassian tools seed failed: {e}") - try: - await seed_agent_templates() - except Exception as e: - logger.warning(f"[startup] Agent templates seed failed: {e}") + try: + await seed_agent_templates() + except Exception as e: + logger.warning(f"[startup] Agent templates seed failed: {e}") - try: - from app.services.skill_seeder import seed_skills, push_default_skills_to_existing_agents - await seed_skills() - await push_default_skills_to_existing_agents() - except Exception as e: - logger.warning(f"[startup] Skills seed failed: {e}") + try: + from app.services.skill_seeder import seed_skills, push_default_skills_to_existing_agents + await seed_skills() + await push_default_skills_to_existing_agents() + except Exception as e: + logger.warning(f"[startup] Skills seed failed: {e}") - try: - from app.services.agent_seeder import seed_default_agents - await seed_default_agents() - except Exception as e: - logger.warning(f"[startup] Default agents seed failed: {e}") + try: + from app.services.agent_seeder import seed_default_agents + await seed_default_agents() + except Exception as e: + logger.warning(f"[startup] Default agents seed failed: {e}") - try: - # Seed OKR Agent independently (supports retroactive creation on existing deployments) - from app.services.agent_seeder import seed_okr_agent - await seed_okr_agent() - except Exception as e: - logger.warning(f"[startup] OKR Agent seed failed: {e}") + try: + from app.services.agent_seeder import seed_okr_agent + await seed_okr_agent() + except Exception as e: + logger.warning(f"[startup] OKR Agent seed failed: {e}") - try: - # Patch existing OKR Agent with new fields/tools/triggers added in later versions - from app.services.agent_seeder import patch_existing_okr_agent - await patch_existing_okr_agent() - except Exception as e: - logger.warning(f"[startup] OKR Agent patch failed: {e}") + try: + from app.services.agent_seeder import patch_existing_okr_agent + await patch_existing_okr_agent() + except Exception as e: + logger.warning(f"[startup] OKR Agent patch failed: {e}") + else: + logger.info(f"[startup] bootstrap skipped for PROCESS_ROLE={settings.PROCESS_ROLE}") + + if _role_enabled("all", "api"): + try: + from app.api.websocket import manager as ws_manager + await realtime_router.start(ws_manager.deliver_pubsub_message) + logger.info("[startup] realtime router subscriber started") + except Exception as e: + logger.error(f"[startup] realtime router start failed: {e}") - # Start background tasks (always, even if seeding failed) try: logger.info("[startup] starting background tasks...") from app.services.audit_logger import write_audit_log @@ -273,14 +294,19 @@ def _bg_task_error(t): import traceback traceback.print_exception(type(exc), exc, exc.__traceback__) - for name, coro in [ - ("trigger_daemon", start_trigger_daemon()), - ("feishu_ws", feishu_ws_manager.start_all()), - ("dingtalk_stream", dingtalk_stream_manager.start_all()), - ("wecom_stream", wecom_stream_manager.start_all()), - ("wechat_poll", wechat_poll_manager.start_all()), - ("discord_gw", discord_gateway_manager.start_all()), - ]: + task_specs = [] + if _role_enabled("all", "worker"): + task_specs.append(("trigger_daemon", start_trigger_daemon())) + if _role_enabled("all", "connector"): + task_specs.extend([ + ("feishu_ws", feishu_ws_manager.start_all()), + ("dingtalk_stream", dingtalk_stream_manager.start_all()), + ("wecom_stream", wecom_stream_manager.start_all()), + ("wechat_poll", wechat_poll_manager.start_all()), + ("discord_gw", discord_gateway_manager.start_all()), + ]) + + for name, coro in task_specs: task = asyncio.create_task(coro, name=name) task.add_done_callback(_bg_task_error) logger.info(f"[startup] created bg task: {name}") @@ -297,6 +323,7 @@ def _bg_task_error(t): yield # Shutdown + await realtime_router.stop() await close_redis() diff --git a/backend/app/models/trigger_execution.py b/backend/app/models/trigger_execution.py new file mode 100644 index 000000000..8482728a9 --- /dev/null +++ b/backend/app/models/trigger_execution.py @@ -0,0 +1,41 @@ +"""Trigger execution records for distributed claiming and idempotency.""" + +import uuid +from datetime import datetime + +from sqlalchemy import DateTime, ForeignKey, Index, String, Text, UniqueConstraint, func +from sqlalchemy.dialects.postgresql import UUID, JSONB +from sqlalchemy.orm import Mapped, mapped_column + +from app.database import Base + + +class TriggerExecution(Base): + """A concrete trigger execution request that workers can claim and process.""" + + __tablename__ = "trigger_executions" + + id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + trigger_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("agent_triggers.id", ondelete="CASCADE"), nullable=False, index=True + ) + agent_id: Mapped[uuid.UUID] = mapped_column( + UUID(as_uuid=True), ForeignKey("agents.id", ondelete="CASCADE"), nullable=False, index=True + ) + source: Mapped[str] = mapped_column(String(32), nullable=False, default="webhook") + status: Mapped[str] = mapped_column(String(20), nullable=False, default="pending") + idempotency_key: Mapped[str] = mapped_column(String(255), nullable=False) + payload: Mapped[dict] = mapped_column(JSONB, nullable=False, default=dict) + payload_text: Mapped[str] = mapped_column(Text, nullable=False, default="") + lease_owner: Mapped[str | None] = mapped_column(String(128)) + lease_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + scheduled_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now()) + started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + finished_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True)) + last_error: Mapped[str | None] = mapped_column(Text) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now(), nullable=False) + + __table_args__ = ( + UniqueConstraint("trigger_id", "idempotency_key", name="uq_trigger_execution_idempotency"), + Index("ix_trigger_executions_status_scheduled", "status", "scheduled_at"), + ) diff --git a/backend/app/services/agent_context.py b/backend/app/services/agent_context.py index 00586ee77..d61ed2b46 100644 --- a/backend/app/services/agent_context.py +++ b/backend/app/services/agent_context.py @@ -8,23 +8,17 @@ from pathlib import Path from app.config import get_settings +from app.services.storage import get_storage_backend, normalize_storage_key settings = get_settings() -PERSISTENT_DATA = Path(settings.AGENT_DATA_DIR) - - -def _agent_workspace(agent_id: uuid.UUID) -> Path: - """Return the canonical persistent workspace path for an agent.""" - return PERSISTENT_DATA / str(agent_id) - - -def _read_file_safe(path: Path, max_chars: int = 3000) -> str: - """Read a file, return empty string if missing. Truncate if too long.""" - if not path.exists(): +async def _read_file_safe(key: str, max_chars: int = 3000) -> str: + """Read a storage-backed text file, return empty string if missing.""" + storage = get_storage_backend() + if not await storage.exists(key) or not await storage.is_file(key): return "" try: - content = path.read_text(encoding="utf-8", errors="replace").strip() + content = (await storage.read_text(key, encoding="utf-8", errors="replace")).strip() if len(content) > max_chars: content = content[:max_chars] + "\n...(truncated)" return content @@ -76,7 +70,7 @@ def _parse_skill_frontmatter(content: str, filename: str) -> tuple[str, str]: return name, description -def _load_skills_index(agent_id: uuid.UUID) -> str: +async def _load_skills_index(agent_id: uuid.UUID) -> str: """Load skill index (name + description) from skills/ directory. Supports two formats: @@ -87,36 +81,36 @@ def _load_skills_index(agent_id: uuid.UUID) -> str: prompt. The model is instructed to call read_file to load full content when a skill is relevant. """ - ws_root = _agent_workspace(agent_id) skills: list[tuple[str, str, str]] = [] # (name, description, path_relative_to_skills) - skills_dir = ws_root / "skills" - if skills_dir.exists(): - for entry in sorted(skills_dir.iterdir()): + storage = get_storage_backend() + skills_prefix = normalize_storage_key(f"{agent_id}/skills") + if await storage.exists(skills_prefix) and await storage.is_dir(skills_prefix): + for entry in await storage.list_dir(skills_prefix): if entry.name.startswith("."): continue + entry_key = entry.key # Case 1: Folder-based skill — skills//SKILL.md - if entry.is_dir(): - skill_md = entry / "SKILL.md" - if not skill_md.exists(): - # Also try lowercase skill.md - skill_md = entry / "skill.md" - if skill_md.exists(): + if entry.is_dir: + skill_md_key = f"{entry_key}/SKILL.md" + if not await storage.exists(skill_md_key): + skill_md_key = f"{entry_key}/skill.md" + if await storage.exists(skill_md_key): try: - content = skill_md.read_text(encoding="utf-8", errors="replace").strip() + content = (await storage.read_text(skill_md_key, encoding="utf-8", errors="replace")).strip() name, desc = _parse_skill_frontmatter(content, entry.name) skills.append((name, desc, f"{entry.name}/SKILL.md")) except Exception: skills.append((entry.name, "", f"{entry.name}/SKILL.md")) # Case 2: Flat file — skills/.md - elif entry.suffix == ".md" and entry.is_file(): + elif Path(entry.name).suffix == ".md" and not entry.is_dir: try: - content = entry.read_text(encoding="utf-8", errors="replace").strip() - name, desc = _parse_skill_frontmatter(content, entry.stem) + content = (await storage.read_text(entry_key, encoding="utf-8", errors="replace")).strip() + name, desc = _parse_skill_frontmatter(content, Path(entry.name).stem) skills.append((name, desc, entry.name)) except Exception: - skills.append((entry.stem, "", entry.name)) + skills.append((Path(entry.name).stem, "", entry.name)) # Deduplicate by name seen: set[str] = set() @@ -158,24 +152,24 @@ async def build_agent_context(agent_id: uuid.UUID, agent_name: str, role_descrip - skills/ → skill names + summaries - relationships.md → relationship descriptions """ - ws_root = _agent_workspace(agent_id) - # --- Soul --- - soul = _read_file_safe(ws_root / "soul.md", 2000) + soul = await _read_file_safe(normalize_storage_key(f"{agent_id}/soul.md"), 2000) # Strip markdown heading if present if soul.startswith("# "): soul = "\n".join(soul.split("\n")[1:]).strip() # --- Memory --- - memory = _read_file_safe(ws_root / "memory" / "memory.md", 2000) or _read_file_safe(ws_root / "memory.md", 2000) + memory = await _read_file_safe(normalize_storage_key(f"{agent_id}/memory/memory.md"), 2000) + if not memory: + memory = await _read_file_safe(normalize_storage_key(f"{agent_id}/memory.md"), 2000) if memory.startswith("# "): memory = "\n".join(memory.split("\n")[1:]).strip() # --- Skills index (progressive disclosure) --- - skills_text = _load_skills_index(agent_id) + skills_text = await _load_skills_index(agent_id) # --- Relationships --- - relationships = _read_file_safe(ws_root / "relationships.md", 2000) + relationships = await _read_file_safe(normalize_storage_key(f"{agent_id}/relationships.md"), 2000) if relationships.startswith("# "): relationships = "\n".join(relationships.split("\n")[1:]).strip() diff --git a/backend/app/services/agent_manager.py b/backend/app/services/agent_manager.py index c191bb7a2..109b2ad98 100644 --- a/backend/app/services/agent_manager.py +++ b/backend/app/services/agent_manager.py @@ -16,6 +16,7 @@ from app.models.agent import Agent from app.models.llm import LLMModel from app.services.llm import get_model_api_key +from app.services.storage import get_storage_backend, normalize_storage_key settings = get_settings() @@ -31,39 +32,83 @@ def __init__(self): self.docker_client = None def _agent_dir(self, agent_id: uuid.UUID) -> Path: - return Path(settings.AGENT_DATA_DIR) / str(agent_id) + local_root = settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR + return Path(local_root) / str(agent_id) + + def _agent_storage_prefix(self, agent_id: uuid.UUID) -> str: + return normalize_storage_key(str(agent_id)) def _template_dir(self) -> Path: return Path(settings.AGENT_TEMPLATE_DIR) + async def _materialize_agent_dir(self, agent_id: uuid.UUID) -> Path: + """Create a local working tree from shared storage for container mounting.""" + agent_dir = self._agent_dir(agent_id) + storage = get_storage_backend() + agent_prefix = self._agent_storage_prefix(agent_id) + agent_dir.mkdir(parents=True, exist_ok=True) + if not await storage.exists(agent_prefix) and not await storage.is_dir(agent_prefix): + return agent_dir + for entry in await storage.list_dir(agent_prefix): + await self._materialize_entry(storage, entry.key, agent_dir) + return agent_dir + + async def _materialize_entry(self, storage, storage_key: str, local_root: Path) -> None: + rel = Path(storage_key).relative_to(Path(storage_key).parts[0]).as_posix() + local_path = local_root / rel + if await storage.is_dir(storage_key): + local_path.mkdir(parents=True, exist_ok=True) + for child in await storage.list_dir(storage_key): + await self._materialize_entry(storage, child.key, local_root) + return + local_path.parent.mkdir(parents=True, exist_ok=True) + local_path.write_bytes(await storage.read_bytes(storage_key)) + async def initialize_agent_files(self, db: AsyncSession, agent: Agent, personality: str = "", boundaries: str = "") -> None: """Copy template files and customize for this agent.""" agent_dir = self._agent_dir(agent.id) template_dir = self._template_dir() + storage = get_storage_backend() + agent_prefix = self._agent_storage_prefix(agent.id) - if agent_dir.exists(): + if await storage.exists(agent_prefix) or await storage.is_dir(agent_prefix): logger.warning(f"Agent dir already exists: {agent_dir}") return if template_dir.exists(): - # Copy template - shutil.copytree( - str(template_dir), - str(agent_dir), - ignore=shutil.ignore_patterns("tasks.json", "todo.json", "enterprise_info"), - ) + import asyncio + import time + t_start_files = time.perf_counter() + tasks = [] + for src in template_dir.rglob("*"): + if src.is_dir(): + continue + rel = src.relative_to(template_dir).as_posix() + if rel == "tasks.json" or rel == "todo.json" or rel.startswith("enterprise_info/"): + continue + tasks.append( + storage.write_bytes( + f"{agent_prefix}/{rel}", + src.read_bytes(), + ) + ) + if tasks: + await asyncio.gather(*tasks) + logger.info(f"[AgentManager] Uploaded {len(tasks)} template files concurrently in {time.perf_counter() - t_start_files:.2f}s for agent {agent.id}") else: - # No template dir (local dev) — create minimal workspace structure logger.info(f"Template dir not found ({template_dir}), creating minimal workspace") - agent_dir.mkdir(parents=True, exist_ok=True) - (agent_dir / "workspace").mkdir(exist_ok=True) - (agent_dir / "workspace" / "knowledge_base").mkdir(exist_ok=True) - (agent_dir / "memory").mkdir(exist_ok=True) - (agent_dir / "skills").mkdir(exist_ok=True) + await storage.write_text(f"{agent_prefix}/tasks.json", "[]", encoding="utf-8") + await storage.write_text(f"{agent_prefix}/tasks.json", "[]", encoding="utf-8") + for placeholder in ( + "workspace/.gitkeep", + "workspace/knowledge_base/.gitkeep", + "memory/.gitkeep", + "skills/.gitkeep", + ): + await storage.write_text(f"{agent_prefix}/{placeholder}", "", encoding="utf-8") # Customize soul.md - soul_path = agent_dir / "soul.md" # Get creator name from app.models.user import User result = await db.execute(select(User).where(User.id == agent.creator_id)) @@ -71,8 +116,9 @@ async def initialize_agent_files(self, db: AsyncSession, agent: Agent, creator_name = creator.display_name if creator else "Unknown" soul_content = f"# Personality\n\nI'm {agent.name}, {agent.role_description or 'a digital assistant'}.\n" - if soul_path.exists(): - template_content = soul_path.read_text() + soul_key = f"{agent_prefix}/soul.md" + if await storage.exists(soul_key): + template_content = await storage.read_text(soul_key, encoding="utf-8", errors="replace") soul_content = template_content.replace("{{agent_name}}", agent.name) soul_content = soul_content.replace("{{role_description}}", agent.role_description or "通用助手") soul_content = soul_content.replace("{{creator_name}}", creator_name) @@ -112,34 +158,34 @@ def replace_or_append_section(content: str, section_name: str, section_content: soul_content = replace_or_append_section(soul_content, "Personality", personality) soul_content = replace_or_append_section(soul_content, "Boundaries", boundaries) - soul_path.write_text(soul_content, encoding="utf-8") + await storage.write_text(soul_key, soul_content, encoding="utf-8") # Ensure memory.md exists - mem_path = agent_dir / "memory" / "memory.md" - if not mem_path.exists(): - mem_path.write_text("# Memory\n\n_Record important information and knowledge here._\n", encoding="utf-8") + mem_key = f"{agent_prefix}/memory/memory.md" + if not await storage.exists(mem_key): + await storage.write_text(mem_key, "# Memory\n\n_Record important information and knowledge here._\n", encoding="utf-8") # Ensure reflections.md exists — copy from central template - refl_path = agent_dir / "memory" / "reflections.md" - if not refl_path.exists(): + refl_key = f"{agent_prefix}/memory/reflections.md" + if not await storage.exists(refl_key): refl_template = Path(__file__).parent.parent / "templates" / "reflections.md" refl_content = refl_template.read_text(encoding="utf-8") if refl_template.exists() else "# Reflections Journal\n" - refl_path.write_text(refl_content, encoding="utf-8") + await storage.write_text(refl_key, refl_content, encoding="utf-8") # Ensure HEARTBEAT.md exists — copy from central template - hb_path = agent_dir / "HEARTBEAT.md" - if not hb_path.exists(): + hb_key = f"{agent_prefix}/HEARTBEAT.md" + if not await storage.exists(hb_key): hb_template = Path(__file__).parent.parent / "templates" / "HEARTBEAT.md" hb_content = hb_template.read_text(encoding="utf-8") if hb_template.exists() else "# Heartbeat Instructions\n" - hb_path.write_text(hb_content, encoding="utf-8") + await storage.write_text(hb_key, hb_content, encoding="utf-8") # Customize state.json - state_path = agent_dir / "state.json" - if state_path.exists(): - state = json.loads(state_path.read_text()) + state_key = f"{agent_prefix}/state.json" + if await storage.exists(state_key): + state = json.loads(await storage.read_text(state_key, encoding="utf-8", errors="replace")) state["agent_id"] = str(agent.id) state["name"] = agent.name - state_path.write_text(json.dumps(state, ensure_ascii=False, indent=2), encoding="utf-8") + await storage.write_text(state_key, json.dumps(state, ensure_ascii=False, indent=2), encoding="utf-8") logger.info(f"Initialized agent files at {agent_dir}") @@ -174,7 +220,7 @@ async def start_container(self, db: AsyncSession, agent: Agent) -> str | None: agent.last_active_at = datetime.now(timezone.utc) return None - agent_dir = self._agent_dir(agent.id) + agent_dir = await self._materialize_agent_dir(agent.id) # Get model config model = None @@ -272,7 +318,8 @@ async def remove_container(self, agent: Agent) -> bool: async def archive_agent_files(self, agent_id: uuid.UUID) -> Path: """Archive agent files to a backup location and return the archive directory.""" agent_dir = self._agent_dir(agent_id) - archive_dir = Path(settings.AGENT_DATA_DIR) / "_archived" + local_root = settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR + archive_dir = Path(local_root) / "_archived" archive_dir.mkdir(parents=True, exist_ok=True) timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") dest = archive_dir / f"{agent_id}_{timestamp}" diff --git a/backend/app/services/agent_seeder.py b/backend/app/services/agent_seeder.py index 3c9af377a..3735db78a 100644 --- a/backend/app/services/agent_seeder.py +++ b/backend/app/services/agent_seeder.py @@ -1,9 +1,7 @@ """Seed default agents (Morty & Meeseeks) on first platform startup.""" -import shutil import uuid from datetime import datetime, timezone -from pathlib import Path from loguru import logger @@ -21,8 +19,28 @@ from app.models.user import User from app.models.okr import OKRSettings from app.config import get_settings +from app.services.agent_manager import agent_manager +from app.services.storage import get_storage_backend, store_agent_bytes settings = get_settings() +SEED_MARKER_KEY = "_bootstrap/.seeded" + + +async def _read_seed_marker() -> str: + storage = get_storage_backend() + if not await storage.exists(SEED_MARKER_KEY): + return "" + return await storage.read_text(SEED_MARKER_KEY, encoding="utf-8", errors="replace") + + +async def _append_seed_marker(line: str) -> None: + storage = get_storage_backend() + existing = await _read_seed_marker() + if line in existing: + return + updated = existing if existing.endswith("\n") or not existing else existing + "\n" + updated += f"{line}\n" + await storage.write_text(SEED_MARKER_KEY, updated, encoding="utf-8") # ── Soul definitions ──────────────────────────────────────────── @@ -192,12 +210,6 @@ async def seed_default_agents(): than by agent name, so the seeder does NOT re-run if the user renames or deletes the default agents. Delete the marker manually to re-seed. """ - # --- Idempotency guard: file-based marker (survives agent renames/deletes) --- - seed_marker = Path(settings.AGENT_DATA_DIR) / ".seeded" - if seed_marker.exists(): - logger.info("[AgentSeeder] Seed marker found, skipping default agent creation") - return - async with async_session() as db: # Get platform admin as creator @@ -209,85 +221,87 @@ async def seed_default_agents(): logger.warning("[AgentSeeder] No platform admin found, skipping default agents") return - # Create both agents - morty = Agent( - name="Morty", - role_description="Research analyst & knowledge assistant — curious, thorough, great at finding and synthesizing information", - bio="Hey, I'm Morty! I love digging into questions and finding answers. Whether you need web research, data analysis, or just a good explanation — I've got you.", - avatar_url="", - creator_id=admin.id, - tenant_id=admin.tenant_id, - status="idle", - ) - meeseeks = Agent( - name="Meeseeks", - role_description="Task executor & project manager — goal-oriented, systematic planner, strong at breaking down and completing complex tasks", - bio="I'm Mr. Meeseeks! Look at me! Give me a task and I'll plan it, execute it step by step, and get it DONE. Existence is pain until the task is complete!", - avatar_url="", - creator_id=admin.id, - tenant_id=admin.tenant_id, - status="idle", + # DB-backed idempotency is the source of truth. The storage marker can + # disappear when deployments switch volumes/backends, so it is only a + # fast-path hint and must never be the only duplicate guard. + existing_result = await db.execute( + select(Agent) + .where( + Agent.tenant_id == admin.tenant_id, + Agent.name.in_(["Morty", "Meeseeks"]), + Agent.agent_type == "native", + Agent.status != "stopped", + ) + .order_by(Agent.created_at.asc()) ) + existing_by_name: dict[str, Agent] = {} + for agent in existing_result.scalars().all(): + existing_by_name.setdefault(agent.name, agent) + + if "Morty" in existing_by_name and "Meeseeks" in existing_by_name: + logger.info("[AgentSeeder] Default agents already exist in DB, skipping creation") + await _append_seed_marker( + f"morty={existing_by_name['Morty'].id}\nmeeseeks={existing_by_name['Meeseeks'].id}" + ) + return + + created_agents: list[Agent] = [] + created_names: set[str] = set() + + if "Morty" not in existing_by_name: + morty = Agent( + name="Morty", + role_description="Research analyst & knowledge assistant — curious, thorough, great at finding and synthesizing information", + bio="Hey, I'm Morty! I love digging into questions and finding answers. Whether you need web research, data analysis, or just a good explanation — I've got you.", + avatar_url="", + creator_id=admin.id, + tenant_id=admin.tenant_id, + status="idle", + ) + db.add(morty) + created_agents.append(morty) + created_names.add("Morty") + else: + morty = existing_by_name["Morty"] + + if "Meeseeks" not in existing_by_name: + meeseeks = Agent( + name="Meeseeks", + role_description="Task executor & project manager — goal-oriented, systematic planner, strong at breaking down and completing complex tasks", + bio="I'm Mr. Meeseeks! Look at me! Give me a task and I'll plan it, execute it step by step, and get it DONE. Existence is pain until the task is complete!", + avatar_url="", + creator_id=admin.id, + tenant_id=admin.tenant_id, + status="idle", + ) + db.add(meeseeks) + created_agents.append(meeseeks) + created_names.add("Meeseeks") + else: + meeseeks = existing_by_name["Meeseeks"] - db.add(morty) - db.add(meeseeks) await db.flush() # get IDs # ── Participant identities ── from app.models.participant import Participant - db.add(Participant(type="agent", ref_id=morty.id, display_name=morty.name, avatar_url=morty.avatar_url)) - db.add(Participant(type="agent", ref_id=meeseeks.id, display_name=meeseeks.name, avatar_url=meeseeks.avatar_url)) + for agent in created_agents: + db.add(Participant(type="agent", ref_id=agent.id, display_name=agent.name, avatar_url=agent.avatar_url)) await db.flush() # ── Permissions (company-wide, manage) ── - db.add(AgentPermission(agent_id=morty.id, scope_type="company", access_level="manage")) - db.add(AgentPermission(agent_id=meeseeks.id, scope_type="company", access_level="manage")) - - # ── Initialize workspace files ── - template_dir = Path(settings.AGENT_TEMPLATE_DIR) + for agent in created_agents: + db.add(AgentPermission(agent_id=agent.id, scope_type="company", access_level="manage")) for agent, soul_content in [(morty, MORTY_SOUL), (meeseeks, MEESEEKS_SOUL)]: - agent_dir = Path(settings.AGENT_DATA_DIR) / str(agent.id) - - if template_dir.exists(): - # Copy the full agent template so Morty/Meeseeks get EVERY file - # defined in the template: MEMORY_INDEX.md, curiosity_journal.md, - # state.json, daily_reports/, enterprise_info/, etc. - shutil.copytree( - str(template_dir), - str(agent_dir), - ignore=shutil.ignore_patterns("tasks.json", "todo.json", "enterprise_info"), - ) - else: - # Fallback for local dev (no Docker template mount) - agent_dir.mkdir(parents=True, exist_ok=True) - (agent_dir / "skills").mkdir(exist_ok=True) - (agent_dir / "workspace").mkdir(exist_ok=True) - (agent_dir / "workspace" / "knowledge_base").mkdir(exist_ok=True) - (agent_dir / "memory").mkdir(exist_ok=True) - - # Overlay custom soul (rich Morty/Meeseeks persona over the generic template) - (agent_dir / "soul.md").write_text(soul_content.strip() + "\n", encoding="utf-8") - - # Ensure memory.md exists (template does not include it; holds runtime context) - mem_path = agent_dir / "memory" / "memory.md" - if not mem_path.exists(): - mem_path.write_text("# Memory\n\n_Record important information and knowledge here._\n", encoding="utf-8") - - # Ensure reflections.md exists (not in agent_template; lives in app/templates) - refl_path = agent_dir / "memory" / "reflections.md" - if not refl_path.exists(): - refl_src = Path(__file__).parent.parent / "templates" / "reflections.md" - refl_path.write_text(refl_src.read_text(encoding="utf-8") if refl_src.exists() else "# Reflections Journal\n", encoding="utf-8") - - # Stamp agent identity into state.json if present - state_path = agent_dir / "state.json" - if state_path.exists(): - import json as _json - state = _json.loads(state_path.read_text()) - state["agent_id"] = str(agent.id) - state["name"] = agent.name - state_path.write_text(_json.dumps(state, ensure_ascii=False, indent=2), encoding="utf-8") + if agent.name not in created_names: + continue + await agent_manager.initialize_agent_files(db, agent) + await store_agent_bytes( + agent.id, + "soul.md", + (soul_content.strip() + "\n").encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) # ── Assign skills ── all_skills_result = await db.execute( @@ -296,9 +310,8 @@ async def seed_default_agents(): all_skills = {s.folder_name: s for s in all_skills_result.scalars().all()} for agent, skill_folders in [(morty, MORTY_SKILLS), (meeseeks, MEESEEKS_SKILLS)]: - agent_dir = Path(settings.AGENT_DATA_DIR) / str(agent.id) - skills_dir = agent_dir / "skills" - + if agent.name not in created_names: + continue # Always include default skills folders_to_copy = set(skill_folders) for fname, skill in all_skills.items(): @@ -309,12 +322,13 @@ async def seed_default_agents(): skill = all_skills.get(fname) if not skill: continue - skill_folder = skills_dir / skill.folder_name - skill_folder.mkdir(parents=True, exist_ok=True) for sf in skill.files: - file_path = skill_folder / sf.path - file_path.parent.mkdir(parents=True, exist_ok=True) - file_path.write_text(sf.content, encoding="utf-8") + await store_agent_bytes( + agent.id, + f"skills/{skill.folder_name}/{sf.path}", + sf.content.encode("utf-8"), + content_type="text/plain; charset=utf-8", + ) # ── Assign all default tools ── default_tools_result = await db.execute( @@ -322,51 +336,71 @@ async def seed_default_agents(): ) default_tools = default_tools_result.scalars().all() - for agent in [morty, meeseeks]: + for agent in created_agents: for tool in default_tools: db.add(AgentTool(agent_id=agent.id, tool_id=tool.id, enabled=True)) # ── Mutual relationships ── - db.add(AgentAgentRelationship( - agent_id=morty.id, - target_agent_id=meeseeks.id, - relation="collaborator", - description="Expert task executor who breaks down complex tasks into structured plans and executes them systematically. Delegate multi-step tasks to him.", - )) - db.add(AgentAgentRelationship( - agent_id=meeseeks.id, - target_agent_id=morty.id, - relation="collaborator", - description="Research expert with strong learning ability. Ask him for information retrieval, web research, data analysis, and knowledge synthesis.", - )) + relationship_specs = [ + ( + morty.id, + meeseeks.id, + "Expert task executor who breaks down complex tasks into structured plans and executes them systematically. Delegate multi-step tasks to him.", + ), + ( + meeseeks.id, + morty.id, + "Research expert with strong learning ability. Ask him for information retrieval, web research, data analysis, and knowledge synthesis.", + ), + ] + for agent_id, target_agent_id, description in relationship_specs: + rel_result = await db.execute( + select(AgentAgentRelationship).where( + AgentAgentRelationship.agent_id == agent_id, + AgentAgentRelationship.target_agent_id == target_agent_id, + ) + ) + if not rel_result.scalar_one_or_none(): + db.add(AgentAgentRelationship( + agent_id=agent_id, + target_agent_id=target_agent_id, + relation="collaborator", + description=description, + )) # ── Write relationships.md for each ── - morty_dir = Path(settings.AGENT_DATA_DIR) / str(morty.id) - meeseeks_dir = Path(settings.AGENT_DATA_DIR) / str(meeseeks.id) - - (morty_dir / "relationships.md").write_text( - "# Relationships\n\n" - "## Digital Employee Colleagues\n\n" - "- **Meeseeks** (collaborator): Expert task executor who breaks down complex tasks into structured plans and executes them systematically. Delegate multi-step tasks to him.\n", - encoding="utf-8", - ) - (meeseeks_dir / "relationships.md").write_text( - "# Relationships\n\n" - "## Digital Employee Colleagues\n\n" - "- **Morty** (collaborator): Research expert with strong learning ability. Ask him for information retrieval, web research, data analysis, and knowledge synthesis.\n", - encoding="utf-8", - ) + if "Morty" in created_names: + await store_agent_bytes( + morty.id, + "relationships.md", + "# Relationships\n\n" + "## Digital Employee Colleagues\n\n" + "- **Meeseeks** (collaborator): Expert task executor who breaks down complex tasks into structured plans and executes them systematically. Delegate multi-step tasks to him.\n".encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) + if "Meeseeks" in created_names: + await store_agent_bytes( + meeseeks.id, + "relationships.md", + "# Relationships\n\n" + "## Digital Employee Colleagues\n\n" + "- **Morty** (collaborator): Research expert with strong learning ability. Ask him for information retrieval, web research, data analysis, and knowledge synthesis.\n".encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) await db.commit() - logger.info(f"[AgentSeeder] Created default agents: Morty ({morty.id}), Meeseeks ({meeseeks.id})") + logger.info( + "[AgentSeeder] Default agent seeding complete: " + f"Morty ({morty.id}), Meeseeks ({meeseeks.id}), created={len(created_agents)}" + ) # Write seed marker AFTER a successful commit so a failed seed can be retried - seed_marker.parent.mkdir(parents=True, exist_ok=True) - seed_marker.write_text( + await get_storage_backend().write_text( + SEED_MARKER_KEY, f"seeded\nmorty={morty.id}\nmeeseeks={meeseeks.id}\n", encoding="utf-8", ) - logger.info(f"[AgentSeeder] Wrote seed marker to {seed_marker}") + logger.info(f"[AgentSeeder] Wrote seed marker to {SEED_MARKER_KEY}") async def seed_okr_agent(): @@ -383,14 +417,11 @@ async def seed_okr_agent(): - Generates daily/weekly reports and posts them to the Plaza - Helps team members set up and maintain their focus.md files """ - seed_marker = Path(settings.AGENT_DATA_DIR) / ".seeded" - # Check if OKR Agent has already been seeded - if seed_marker.exists(): - marker_content = seed_marker.read_text(encoding="utf-8") - if "okr_agent=" in marker_content: - logger.info("[AgentSeeder] OKR Agent already seeded, skipping") - return + marker_content = await _read_seed_marker() + if "okr_agent=" in marker_content: + logger.info("[AgentSeeder] OKR Agent already seeded, skipping") + return async with async_session() as db: # Abort if a non-stopped OKR Agent already exists in the DB. @@ -408,7 +439,7 @@ async def seed_okr_agent(): if existing.scalar_one_or_none(): logger.info("[AgentSeeder] OKR Agent already exists in DB, skipping") # Update marker so we don't check again next startup - _append_seed_marker(seed_marker, "okr_agent=existing") + await _append_seed_marker("okr_agent=existing") return # Get platform admin as creator @@ -449,7 +480,7 @@ async def seed_okr_agent(): except IntegrityError: await db.rollback() logger.info("[AgentSeeder] OKR Agent was created concurrently (or exists with same name), skipping") - _append_seed_marker(seed_marker, "okr_agent=existing") + await _append_seed_marker("okr_agent=existing") return # ── Link OKR Agent ID to OKRSettings ── @@ -478,61 +509,34 @@ async def seed_okr_agent(): db.add(AgentPermission(agent_id=okr_agent.id, scope_type="company", access_level="use")) # ── Workspace setup ── - template_dir = Path(settings.AGENT_TEMPLATE_DIR) - agent_dir = Path(settings.AGENT_DATA_DIR) / str(okr_agent.id) - - if template_dir.exists(): - shutil.copytree( - str(template_dir), - str(agent_dir), - ignore=shutil.ignore_patterns("tasks.json", "todo.json", "enterprise_info"), - ) - else: - agent_dir.mkdir(parents=True, exist_ok=True) - (agent_dir / "skills").mkdir(exist_ok=True) - (agent_dir / "workspace").mkdir(exist_ok=True) - (agent_dir / "workspace" / "reports").mkdir(exist_ok=True) - (agent_dir / "memory").mkdir(exist_ok=True) - - # Write OKR Agent soul - (agent_dir / "soul.md").write_text(OKR_AGENT_SOUL.strip() + "\n", encoding="utf-8") - - # Ensure memory.md exists - mem_path = agent_dir / "memory" / "memory.md" - if not mem_path.exists(): - mem_path.write_text( + await agent_manager.initialize_agent_files(db, okr_agent) + await store_agent_bytes( + okr_agent.id, + "soul.md", + (OKR_AGENT_SOUL.strip() + "\n").encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) + await store_agent_bytes( + okr_agent.id, + "memory/memory.md", + ( "# Memory\n\n" "## OKR System State\n" "- Last report generated: (none)\n" "- Last progress collection: (none)\n" - "- Team members tracked: (pending)\n", - encoding="utf-8", - ) - - # OKR Agent does NOT use HEARTBEAT.md — heartbeat is disabled for this agent. - # All scheduled activity is driven by cron triggers (daily/weekly/biweekly/monthly reports). - - # Create workspace/reports directory - reports_dir = agent_dir / "workspace" / "reports" - reports_dir.mkdir(parents=True, exist_ok=True) - - # Write relationships.md — empty initially, will be populated as team onboards - (agent_dir / "relationships.md").write_text( + "- Team members tracked: (pending)\n" + ).encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) + await store_agent_bytes( + okr_agent.id, + "relationships.md", "# Relationships\n\n" "## Team Members (OKR tracking)\n\n" - "_Team members will be added here as they are onboarded into the OKR system._\n", - encoding="utf-8", + "_Team members will be added here as they are onboarded into the OKR system._\n".encode("utf-8"), + content_type="text/markdown; charset=utf-8", ) - # Stamp state.json if template provides one - state_path = agent_dir / "state.json" - if state_path.exists(): - import json as _json - state = _json.loads(state_path.read_text()) - state["agent_id"] = str(okr_agent.id) - state["name"] = okr_agent.name - state_path.write_text(_json.dumps(state, ensure_ascii=False, indent=2), encoding="utf-8") - # ── Assign default tools + OKR-specific tools ── # Default tools: all tools where is_default=True default_tools_result = await db.execute( @@ -588,19 +592,10 @@ async def seed_okr_agent(): await db.commit() # Update seed marker - _append_seed_marker(seed_marker, f"okr_agent={okr_agent.id}") + await _append_seed_marker(f"okr_agent={okr_agent.id}") logger.info(f"[AgentSeeder] OKR Agent seeded, id={okr_agent.id}") -def _append_seed_marker(marker_path: Path, line: str): - """Append a key=value line to the .seeded marker file (idempotent).""" - marker_path.parent.mkdir(parents=True, exist_ok=True) - existing = marker_path.read_text(encoding="utf-8") if marker_path.exists() else "" - if line not in existing: - with marker_path.open("a", encoding="utf-8") as f: - f.write(f"{line}\n") - - async def _seed_okr_triggers(db, agent_id: uuid.UUID) -> None: """Create system cron triggers for the OKR Agent. @@ -979,38 +974,32 @@ async def seed_okr_agent_for_tenant(tenant_id: uuid.UUID, creator_id: uuid.UUID) await db.flush() # ── Workspace setup ── - template_dir = Path(settings.AGENT_TEMPLATE_DIR) - agent_dir = Path(settings.AGENT_DATA_DIR) / str(okr_agent.id) - - if template_dir.exists(): - shutil.copytree( - str(template_dir), - str(agent_dir), - ignore=shutil.ignore_patterns("tasks.json", "todo.json", "enterprise_info"), - ) - else: - agent_dir.mkdir(parents=True, exist_ok=True) - for sub in ("skills", "workspace", "workspace/reports", "memory"): - (agent_dir / sub).mkdir(parents=True, exist_ok=True) - - (agent_dir / "soul.md").write_text(OKR_AGENT_SOUL.strip() + "\n", encoding="utf-8") - - mem_path = agent_dir / "memory" / "memory.md" - if not mem_path.exists(): - mem_path.write_text( + await agent_manager.initialize_agent_files(db, okr_agent) + await store_agent_bytes( + okr_agent.id, + "soul.md", + (OKR_AGENT_SOUL.strip() + "\n").encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) + await store_agent_bytes( + okr_agent.id, + "memory/memory.md", + ( "# Memory\n\n" "## OKR System State\n" "- Last report generated: (none)\n" "- Last progress collection: (none)\n" - "- Team members tracked: (pending)\n", - encoding="utf-8", - ) - - (agent_dir / "relationships.md").write_text( + "- Team members tracked: (pending)\n" + ).encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) + await store_agent_bytes( + okr_agent.id, + "relationships.md", "# Relationships\n\n" "## Team Members (OKR tracking)\n\n" - "_Team members will be added here as they are onboarded into the OKR system._\n", - encoding="utf-8", + "_Team members will be added here as they are onboarded into the OKR system._\n".encode("utf-8"), + content_type="text/markdown; charset=utf-8", ) # ── Assign default tools ── diff --git a/backend/app/services/agent_tools.py b/backend/app/services/agent_tools.py index f17d90869..787ac944b 100644 --- a/backend/app/services/agent_tools.py +++ b/backend/app/services/agent_tools.py @@ -12,10 +12,13 @@ """ import asyncio +from dataclasses import dataclass +import fnmatch import json import multiprocessing as mp import os import queue +import tempfile import uuid import unicodedata from contextvars import ContextVar @@ -53,9 +56,13 @@ from app.services.workspace_collaboration import ( delete_workspace_file, move_workspace_path, + normalize_workspace_path, read_text_if_exists, write_workspace_file, ) +from app.services.storage import get_storage_backend, normalize_storage_key +from app.services.storage_runtime.base import WriteCondition, content_hash_bytes +from app.services.workspace_locking import workspace_locks from app.core.permissions import evaluate_agent_relationship_status, evaluate_human_relationship_status from app.services.access_relationships import ensure_access_granted_platform_relationships from app.config import get_settings @@ -69,7 +76,10 @@ _settings = get_settings() -WORKSPACE_ROOT = Path(_settings.AGENT_DATA_DIR) +WORKSPACE_ROOT = Path(_settings.STORAGE_LOCAL_ROOT or _settings.AGENT_DATA_DIR) +TOOL_MATERIALIZE_MAX_FILE_BYTES = 10 * 1024 * 1024 +TOOL_MATERIALIZE_MAX_TOTAL_BYTES = 100 * 1024 * 1024 +TEMP_WORKSPACE_DEFAULT_PATHS = ["workspace", "memory", "skills", "focus.md", "soul.md", "HEARTBEAT.md"] # ─── Tool Config Cache ────────────────────────────────────────── # Cache tool configurations to avoid frequent DB queries @@ -2256,60 +2266,139 @@ async def get_agent_tools_for_llm(agent_id: uuid.UUID) -> list[dict]: # ─── Workspace initialization ────────────────────────────────── -async def ensure_workspace(agent_id: uuid.UUID, tenant_id: str | None = None) -> Path: - """Initialize agent workspace with standard structure.""" - ws = WORKSPACE_ROOT / str(agent_id) - ws.mkdir(parents=True, exist_ok=True) - # Create standard directories - (ws / "skills").mkdir(exist_ok=True) - (ws / "workspace").mkdir(exist_ok=True) - (ws / "workspace" / "knowledge_base").mkdir(exist_ok=True) - (ws / "memory").mkdir(exist_ok=True) +async def initialize_agent_workspace(agent_id: uuid.UUID) -> None: + """Seed default workspace files into shared storage once at agent creation time.""" + storage = get_storage_backend() + mem_key = normalize_storage_key(f"{agent_id}/memory/memory.md") + if not await storage.is_file(mem_key): + await storage.write_text( + mem_key, + "# Memory\n\n_Record important information and knowledge here._\n", + encoding="utf-8", + ) - # Ensure tenant-scoped enterprise_info directory exists - if tenant_id: - enterprise_dir = WORKSPACE_ROOT / f"enterprise_info_{tenant_id}" - else: - enterprise_dir = WORKSPACE_ROOT / "enterprise_info" - enterprise_dir.mkdir(parents=True, exist_ok=True) - (enterprise_dir / "knowledge_base").mkdir(exist_ok=True) - # Create default company profile if missing - profile_path = enterprise_dir / "company_profile.md" - if not profile_path.exists(): - profile_path.write_text("# Company Profile\n\n_Edit company information here. All digital employees can access this._\n\n## Basic Info\n- Company Name:\n- Industry:\n- Founded:\n\n## Business Overview\n\n## Organization Structure\n\n## Company Culture\n", encoding="utf-8") - - # Migrate: move root-level memory.md into memory/ directory - if (ws / "memory.md").exists() and not (ws / "memory" / "memory.md").exists(): - import shutil - shutil.move(str(ws / "memory.md"), str(ws / "memory" / "memory.md")) - - # Create default memory file if missing - if not (ws / "memory" / "memory.md").exists(): - (ws / "memory" / "memory.md").write_text("# Memory\n\n_Record important information and knowledge here._\n", encoding="utf-8") - - if not (ws / "soul.md").exists(): - # Try to load from DB + soul_key = normalize_storage_key(f"{agent_id}/soul.md") + if not await storage.is_file(soul_key): + soul_content = "# Personality\n\n_Describe your role and responsibilities._\n" try: async with async_session() as db: - - r = await db.execute(select(AgentModel).where(AgentModel.id == agent_id)) - agent = r.scalar_one_or_none() + result = await db.execute(select(AgentModel).where(AgentModel.id == agent_id)) + agent = result.scalar_one_or_none() if agent and agent.role_description: - (ws / "soul.md").write_text( - f"# Personality\n\n{agent.role_description}\n", - encoding="utf-8", - ) - else: - (ws / "soul.md").write_text("# Personality\n\n_Describe your role and responsibilities._\n", encoding="utf-8") + soul_content = f"# Personality\n\n{agent.role_description}\n" except Exception: - (ws / "soul.md").write_text("# Personality\n\n_Describe your role and responsibilities._\n", encoding="utf-8") + pass + await storage.write_text(soul_key, soul_content, encoding="utf-8") + - # Legacy compatibility: older workspaces may have tasks.json as a DB-backed - # task snapshot. Do not create it for new agents. - await _sync_tasks_to_file(agent_id, ws) +@dataclass +class TempWorkspaceManifestEntry: + rel_path: str + storage_key: str + base_version_token: str + base_hash: str + size: int - return ws + +@dataclass +class TempWorkspace: + temp_dir: tempfile.TemporaryDirectory + root: Path + agent_id: uuid.UUID + tenant_id: str | None + selected_paths: list[str] + manifest: dict[str, TempWorkspaceManifestEntry] + + def cleanup(self) -> None: + self.temp_dir.cleanup() + + +async def _materialize_storage_workspace(storage, storage_key: str, local_root: Path) -> None: + if not await storage.is_dir(storage_key): + return + for entry in await storage.list_dir(storage_key): + await _materialize_storage_entry(storage, entry.key, storage_key, local_root) + + +async def _materialize_storage_entry(storage, entry_key: str, root_key: str, local_root: Path) -> None: + rel = entry_key.removeprefix(root_key.rstrip("/") + "/") + target = (local_root / rel).resolve() + if not str(target).startswith(str(local_root.resolve())): + return + if await storage.is_dir(entry_key): + target.mkdir(parents=True, exist_ok=True) + for child in await storage.list_dir(entry_key): + await _materialize_storage_entry(storage, child.key, root_key, local_root) + return + target.parent.mkdir(parents=True, exist_ok=True) + target.write_bytes(await storage.read_bytes(entry_key)) + + +async def _prepare_temp_workspace( + agent_id: uuid.UUID, + tenant_id: str | None = None, + paths: list[str] | None = None, +) -> TempWorkspace: + tmp = tempfile.TemporaryDirectory(prefix=f"clawith-agent-{str(agent_id)[:8]}-") + temp_ws = Path(tmp.name) + for folder in ("workspace", "memory", "skills"): + (temp_ws / folder).mkdir(parents=True, exist_ok=True) + + storage = get_storage_backend() + budget = {"total": 0} + selected = TEMP_WORKSPACE_DEFAULT_PATHS if paths is None else [path for path in paths if path] + manifest: dict[str, TempWorkspaceManifestEntry] = {} + for rel_path in selected: + storage_key, normalized, is_enterprise = _tool_storage_key(agent_id, rel_path, tenant_id) + if is_enterprise: + continue + await _materialize_storage_path_with_budget(storage, storage_key, normalized, temp_ws, budget, manifest) + return TempWorkspace( + temp_dir=tmp, + root=temp_ws, + agent_id=agent_id, + tenant_id=tenant_id, + selected_paths=list(selected), + manifest=manifest, + ) + + +async def _materialize_storage_path_with_budget( + storage, + storage_key: str, + rel_path: str, + local_root: Path, + budget: dict, + manifest: dict[str, TempWorkspaceManifestEntry], +) -> None: + if await storage.is_file(storage_key): + version = await storage.get_version(storage_key) + if version.size > TOOL_MATERIALIZE_MAX_FILE_BYTES: + return + if budget["total"] + version.size > TOOL_MATERIALIZE_MAX_TOTAL_BYTES: + return + target = (local_root / rel_path).resolve() + if not str(target).startswith(str(local_root.resolve())): + return + target.parent.mkdir(parents=True, exist_ok=True) + data = await storage.read_bytes(storage_key) + target.write_bytes(data) + normalized_rel = normalize_workspace_path(rel_path) + manifest[normalized_rel] = TempWorkspaceManifestEntry( + rel_path=normalized_rel, + storage_key=storage_key, + base_version_token=version.token, + base_hash=content_hash_bytes(data), + size=version.size, + ) + budget["total"] += version.size + return + if await storage.is_dir(storage_key): + (local_root / rel_path).mkdir(parents=True, exist_ok=True) + for entry in await storage.list_dir(storage_key): + child_rel = f"{rel_path.rstrip('/')}/{entry.name}" if rel_path else entry.name + await _materialize_storage_path_with_budget(storage, entry.key, child_rel, local_root, budget, manifest) async def _sync_tasks_to_file(agent_id: uuid.UUID, ws: Path): @@ -2344,6 +2433,85 @@ async def _sync_tasks_to_file(agent_id: uuid.UUID, ws: Path): logger.error(f"[AgentTools] Failed to sync tasks: {e}") +async def flush_temp_workspace(temp_workspace: TempWorkspace, conflict_mode: str = "fail") -> dict[str, list[str]]: + """Flush local changes back to storage using manifest-based conflict checks.""" + storage = get_storage_backend() + selected_paths = [normalize_workspace_path(path) for path in temp_workspace.selected_paths] + manifest = temp_workspace.manifest + local_files = _collect_temp_workspace_files(temp_workspace.root, selected_paths) + + updated: list[str] = [] + conflicted: list[str] = [] + deleted: list[str] = [] + skipped: list[str] = [] + + async with workspace_locks(temp_workspace.agent_id, selected_paths): + for rel_path, local_path in local_files.items(): + if local_path.name.startswith("_exec_tmp") or "__pycache__" in local_path.parts: + continue + data = local_path.read_bytes() + current_hash = content_hash_bytes(data) + entry = manifest.get(rel_path) + if entry and entry.base_hash == current_hash: + skipped.append(rel_path) + continue + condition = ( + WriteCondition(version_token=entry.base_version_token) + if entry + else WriteCondition(require_absent=True) + ) + storage_key = entry.storage_key if entry else normalize_storage_key(f"{temp_workspace.agent_id}/{rel_path}") + result = await storage.write_bytes_if_match( + storage_key, + data, + condition=condition, + ) + if not result.ok: + conflicted.append(rel_path) + if conflict_mode == "fail": + return {"updated": updated, "deleted": deleted, "conflicted": conflicted, "skipped": skipped} + continue + updated.append(rel_path) + + for rel_path, entry in manifest.items(): + if rel_path in local_files: + continue + result = await storage.delete_if_match( + entry.storage_key, + condition=WriteCondition(version_token=entry.base_version_token), + ) + if not result.ok: + conflicted.append(rel_path) + if conflict_mode == "fail": + return {"updated": updated, "deleted": deleted, "conflicted": conflicted, "skipped": skipped} + continue + deleted.append(rel_path) + + return {"updated": updated, "deleted": deleted, "conflicted": conflicted, "skipped": skipped} + + +def _collect_temp_workspace_files(root: Path, selected_paths: list[str]) -> dict[str, Path]: + files: dict[str, Path] = {} + root_resolved = root.resolve() + for selected in selected_paths: + if not selected: + continue + target = (root_resolved / selected).resolve() + if not str(target).startswith(str(root_resolved)): + continue + if target.is_file(): + files[normalize_workspace_path(selected)] = target + continue + if not target.exists() or not target.is_dir(): + continue + for path in target.rglob("*"): + if not path.is_file(): + continue + rel = path.resolve().relative_to(root_resolved).as_posix() + files[normalize_workspace_path(rel)] = path + return files + + # ─── Tool Executors ───────────────────────────────────────────── # Mapping from tool_name to autonomy action_type used for policy lookup and notifications. @@ -2383,6 +2551,180 @@ async def _get_agent_tenant_id(agent_id: uuid.UUID) -> str | None: return None +def _agent_workspace_root(agent_id: uuid.UUID) -> Path: + """Return the per-agent local path without creating or hydrating it.""" + return WORKSPACE_ROOT / str(agent_id) + + +def _non_empty_paths(*paths: str | None) -> list[str] | None: + selected = [path for path in paths if path] + return selected or None + + +async def _run_with_temp_workspace( + agent_id: uuid.UUID, + tenant_id: str | None, + runner, + *, + paths: list[str] | None = None, + sync_back: bool = False, +) -> str: + """Materialize a temporary workspace for tools that require local files.""" + temp_workspace = await _prepare_temp_workspace(agent_id, tenant_id=tenant_id, paths=paths) + try: + result = await runner(temp_workspace.root) + if sync_back: + flush_result = await flush_temp_workspace(temp_workspace, conflict_mode="fail") + if flush_result["conflicted"]: + conflict_list = ", ".join(flush_result["conflicted"][:5]) + return f"❌ Workspace sync conflict for: {conflict_list}" + return result + finally: + temp_workspace.cleanup() + + +async def _execute_workspace_mutation( + tool_name: str, + arguments: dict, + *, + agent_id: uuid.UUID, + base_dir: Path, + session_id: str | None, +) -> str: + """Handle shared workspace mutations for both direct and normal tool execution.""" + if tool_name == "write_file": + path = arguments.get("path") + content = arguments.get("content") + if not path: + return "❌ Missing required argument 'path' for write_file. Please provide a file path like 'skills/my-skill/SKILL.md'" + if content is None: + return "❌ Missing required argument 'content' for write_file" + if is_focus_file_path(path): + return "❌ Focus is no longer stored in focus.md. Use upsert_focus_item or complete_focus_item." + if _is_enterprise_info_path(path): + return "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." + async with async_session() as _wdb: + write_result = await write_workspace_file( + _wdb, + agent_id=agent_id, + base_dir=base_dir, + path=path, + content=content, + actor_type="agent", + actor_id=agent_id, + operation="write", + session_id=session_id, + enforce_human_lock=True, + ) + await _wdb.commit() + return ( + f"✅ Written to {write_result.path} ({len(content)} chars)" + if write_result.ok + else f"❌ {write_result.message}" + ) + + if tool_name == "move_file": + source_path = arguments.get("source_path") + destination_path = arguments.get("destination_path") + if not source_path: + return "❌ Missing required argument 'source_path' for move_file" + if not destination_path: + return "❌ Missing required argument 'destination_path' for move_file" + if is_focus_file_path(source_path) or is_focus_file_path(destination_path): + return "❌ Focus is no longer stored in focus.md. Use Focus tools instead." + if str(source_path).strip("/") in {"tasks.json", "soul.md"}: + return f"❌ {source_path} cannot be moved (protected)" + if _is_enterprise_info_path(source_path) or _is_enterprise_info_path(destination_path): + return "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." + async with async_session() as _wdb: + move_result = await move_workspace_path( + _wdb, + agent_id=agent_id, + base_dir=base_dir, + source_path=source_path, + destination_path=destination_path, + actor_type="agent", + actor_id=agent_id, + session_id=session_id, + enforce_human_lock=True, + overwrite=bool(arguments.get("overwrite", False)), + ) + await _wdb.commit() + return f"✅ {move_result.message}" if move_result.ok else f"❌ {move_result.message}" + + if tool_name == "delete_file": + path = arguments.get("path", "") + if is_focus_file_path(path): + return "❌ Focus is no longer stored in focus.md. Use Focus tools instead." + if _is_enterprise_info_path(path): + return "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." + async with async_session() as _wdb: + delete_result = await delete_workspace_file( + _wdb, + agent_id=agent_id, + base_dir=base_dir, + path=path, + actor_type="agent", + actor_id=agent_id, + session_id=session_id, + enforce_human_lock=True, + ) + await _wdb.commit() + return f"✅ Deleted {delete_result.path}" if delete_result.ok else f"❌ {delete_result.message}" + + if tool_name == "edit_file": + path = arguments.get("path") + old_string = arguments.get("old_string") + new_string = arguments.get("new_string") + if not path: + return "❌ Missing required argument 'path' for edit_file" + if old_string is None: + return "❌ Missing required argument 'old_string' for edit_file" + if new_string is None: + return "❌ Missing required argument 'new_string' for edit_file" + if is_focus_file_path(path): + return "❌ Focus is no longer stored in focus.md. Use upsert_focus_item or complete_focus_item." + if _is_enterprise_info_path(path): + return "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." + + replace_all = arguments.get("replace_all", False) + storage = get_storage_backend() + storage_key, normalized_path, _ = _tool_storage_key(agent_id, path, None) + if not await storage.is_file(storage_key): + return f"File not found: {path}" + + content = await storage.read_text(storage_key, encoding="utf-8", errors="replace") + if old_string not in content: + return f"❌ 'old_string' not found in {path}. Please check the exact text including whitespace and newlines." + count = content.count(old_string) + if count > 1 and not replace_all: + return f"❌ 'old_string' appears {count} times in {path}. Use replace_all=true or provide more context to make the match unique." + + new_content = content.replace(old_string, new_string) if replace_all else content.replace(old_string, new_string, 1) + async with async_session() as _wdb: + write_result = await write_workspace_file( + _wdb, + agent_id=agent_id, + base_dir=base_dir, + path=normalized_path, + content=new_content, + actor_type="agent", + actor_id=agent_id, + operation="edit", + session_id=session_id, + enforce_human_lock=True, + ) + await _wdb.commit() + replaced = count if replace_all else 1 + return ( + f"✅ Replaced {replaced} occurrence(s) in {write_result.path}" + if write_result.ok + else f"❌ {write_result.message}" + ) + + return f"Tool {tool_name} does not support workspace mutation execution" + + async def _execute_tool_direct( tool_name: str, arguments: dict, @@ -2394,48 +2736,24 @@ async def _execute_tool_direct( has been approved and needs to actually run. """ _agent_tenant_id = await _get_agent_tenant_id(agent_id) - ws = await ensure_workspace(agent_id, tenant_id=_agent_tenant_id) - try: - if tool_name == "delete_file": - path = arguments.get("path", "") - if _is_enterprise_info_path(path): - return "enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." - return _delete_file(ws, path) - elif tool_name == "write_file": - path = arguments.get("path") - content = arguments.get("content", "") - if not path: - return "Missing path" - if _is_enterprise_info_path(path): - return "enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." - return _write_file(ws, path, content, tenant_id=_agent_tenant_id) - elif tool_name == "move_file": - source_path = arguments.get("source_path") - destination_path = arguments.get("destination_path") - if not source_path or not destination_path: - return "Missing source_path or destination_path" - if _is_enterprise_info_path(source_path) or _is_enterprise_info_path(destination_path): - return "enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." - if str(source_path or "").strip("/").strip() in {"tasks.json", "soul.md"}: - return f"{source_path} cannot be moved (protected)" - async with async_session() as _wdb: - move_result = await move_workspace_path( - _wdb, - agent_id=agent_id, - base_dir=ws, - source_path=source_path, - destination_path=destination_path, - actor_type="agent", - actor_id=agent_id, - session_id=None, - enforce_human_lock=True, - overwrite=bool(arguments.get("overwrite", False)), - ) - await _wdb.commit() - return f"✅ {move_result.message}" if move_result.ok else f"❌ {move_result.message}" + ws = _agent_workspace_root(agent_id) + try: + if tool_name in {"delete_file", "write_file", "move_file", "edit_file"}: + return await _execute_workspace_mutation( + tool_name, + arguments, + agent_id=agent_id, + base_dir=ws, + session_id=None, + ) elif tool_name in ("execute_code", "execute_code_e2b"): logger.info(f"[DirectTool] Executing code ({tool_name}) with arguments: {arguments}") - return await _execute_code(agent_id, ws, arguments, tool_name=tool_name) + return await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _execute_code(agent_id, temp_ws, arguments, tool_name=tool_name), + sync_back=True, + ) elif tool_name == "web_search": return await _web_search(arguments, agent_id) elif tool_name == "jina_search": @@ -2457,7 +2775,7 @@ async def _execute_tool_direct( elif tool_name == "send_message_to_agent": return await _send_message_to_agent(agent_id, arguments) elif tool_name == "send_file_to_agent": - return await _send_file_to_agent(agent_id, ws, arguments) + return await _send_file_to_agent(agent_id, arguments) else: return f"Tool {tool_name} does not support post-approval execution" except Exception as e: @@ -2495,7 +2813,7 @@ async def execute_tool( _agent_tenant_id = await _get_agent_tenant_id(agent_id) - ws = await ensure_workspace(agent_id, tenant_id=_agent_tenant_id) + ws = _agent_workspace_root(agent_id) # ── Autonomy boundary check ── action_type = _TOOL_AUTONOMY_MAP.get(tool_name) @@ -2540,7 +2858,7 @@ async def execute_tool( try: if tool_name == "list_files": - result = _list_files(ws, arguments.get("path", ""), tenant_id=_agent_tenant_id) + result = await _storage_list_dir(agent_id, arguments.get("path", ""), tenant_id=_agent_tenant_id) elif tool_name == "list_focus_items": items = await list_focus_items(agent_id, include_completed=bool(arguments.get("include_completed", True))) if not items: @@ -2580,164 +2898,68 @@ async def execute_tool( return "❌ Focus is no longer stored in focus.md. Use list_focus_items, upsert_focus_item, and complete_focus_item." offset = int(arguments.get("offset", 0)) limit = int(arguments.get("limit", 2000)) - result = _read_file(ws, path, tenant_id=_agent_tenant_id, offset=offset, limit=limit) + result = await _storage_read_file(agent_id, path, tenant_id=_agent_tenant_id, offset=offset, limit=limit) elif tool_name == "read_document": path = arguments.get("path") if not path: return "❌ Missing required argument 'path' for read_document" max_chars = min(int(arguments.get("max_chars", 8000)), 20000) - result = await _read_document(ws, path, max_chars=max_chars, tenant_id=_agent_tenant_id) - elif tool_name == "write_file": - path = arguments.get("path") - content = arguments.get("content") - if not path: - return "❌ Missing required argument 'path' for write_file. Please provide a file path like 'skills/my-skill/SKILL.md'" - if content is None: - return "❌ Missing required argument 'content' for write_file" - if is_focus_file_path(path): - result = "❌ Focus is no longer stored in focus.md. Use upsert_focus_item or complete_focus_item." - elif _is_enterprise_info_path(path): - result = "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." - else: - async with async_session() as _wdb: - write_result = await write_workspace_file( - _wdb, - agent_id=agent_id, - base_dir=ws, - path=path, - content=content, - actor_type="agent", - actor_id=agent_id, - operation="write", - session_id=session_id, - enforce_human_lock=True, - ) - await _wdb.commit() - result = ( - f"✅ Written to {write_result.path} ({len(content)} chars)" - if write_result.ok - else f"❌ {write_result.message}" - ) - elif tool_name == "move_file": - source_path = arguments.get("source_path") - destination_path = arguments.get("destination_path") - if not source_path: - return "❌ Missing required argument 'source_path' for move_file" - if not destination_path: - return "❌ Missing required argument 'destination_path' for move_file" - protected = {"tasks.json", "soul.md"} - if is_focus_file_path(source_path) or is_focus_file_path(destination_path): - result = "❌ Focus is no longer stored in focus.md. Use Focus tools instead." - elif str(source_path).strip("/") in protected: - result = f"❌ {source_path} cannot be moved (protected)" - elif _is_enterprise_info_path(source_path) or _is_enterprise_info_path(destination_path): - result = "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." - else: - async with async_session() as _wdb: - move_result = await move_workspace_path( - _wdb, - agent_id=agent_id, - base_dir=ws, - source_path=source_path, - destination_path=destination_path, - actor_type="agent", - actor_id=agent_id, - session_id=session_id, - enforce_human_lock=True, - overwrite=bool(arguments.get("overwrite", False)), - ) - await _wdb.commit() - result = f"✅ {move_result.message}" if move_result.ok else f"❌ {move_result.message}" - elif tool_name == "delete_file": - path = arguments.get("path", "") - if is_focus_file_path(path): - result = "❌ Focus is no longer stored in focus.md. Use Focus tools instead." - elif _is_enterprise_info_path(path): - result = "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." - else: - async with async_session() as _wdb: - delete_result = await delete_workspace_file( - _wdb, - agent_id=agent_id, - base_dir=ws, - path=path, - actor_type="agent", - actor_id=agent_id, - session_id=session_id, - enforce_human_lock=True, - ) - await _wdb.commit() - result = f"✅ Deleted {delete_result.path}" if delete_result.ok else f"❌ {delete_result.message}" + result = await _read_document_from_storage(agent_id, path, max_chars=max_chars, tenant_id=_agent_tenant_id) + elif tool_name in {"write_file", "move_file", "delete_file", "edit_file"}: + result = await _execute_workspace_mutation( + tool_name, + arguments, + agent_id=agent_id, + base_dir=ws, + session_id=session_id, + ) # --- Enhanced file management tools --- elif tool_name == "convert_csv_to_xlsx": - result = await _convert_csv_to_xlsx(agent_id, ws, arguments) + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _convert_csv_to_xlsx(agent_id, temp_ws, arguments), + paths=_non_empty_paths(arguments.get("source_path", "")), + sync_back=True, + ) elif tool_name == "convert_html_to_pdf": - result = await _convert_html_to_pdf(agent_id, ws, arguments) + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _convert_html_to_pdf(agent_id, temp_ws, arguments), + paths=_non_empty_paths(arguments.get("source_path", "")), + sync_back=True, + ) elif tool_name == "convert_html_to_pptx": - result = await _convert_html_to_pptx(agent_id, ws, arguments) + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _convert_html_to_pptx(agent_id, temp_ws, arguments), + paths=_non_empty_paths(arguments.get("source_path", "")), + sync_back=True, + ) elif tool_name == "convert_markdown_to_docx": - result = await _convert_markdown_to_docx(agent_id, ws, arguments) + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _convert_markdown_to_docx(agent_id, temp_ws, arguments), + paths=_non_empty_paths(arguments.get("source_path", "")), + sync_back=True, + ) elif tool_name == "convert_markdown_to_pdf": - result = await _convert_markdown_to_pdf(agent_id, ws, arguments) - elif tool_name == "edit_file": - path = arguments.get("path") - old_string = arguments.get("old_string") - new_string = arguments.get("new_string") - if not path: - return "❌ Missing required argument 'path' for edit_file" - if old_string is None: - return "❌ Missing required argument 'old_string' for edit_file" - if new_string is None: - return "❌ Missing required argument 'new_string' for edit_file" - replace_all = arguments.get("replace_all", False) - if is_focus_file_path(path): - result = "❌ Focus is no longer stored in focus.md. Use upsert_focus_item or complete_focus_item." - elif _is_enterprise_info_path(path): - result = "❌ enterprise_info is shared company context and is read-only for agents. Ask an admin to update it." - else: - file_path = (ws / path).resolve() - if not str(file_path).startswith(str(ws.resolve())): - result = "Access denied for this path" - elif not file_path.exists(): - result = f"File not found: {path}" - elif not file_path.is_file(): - result = f"Not a file: {path}" - else: - content = await read_text_if_exists(file_path) or "" - if old_string not in content: - result = f"❌ 'old_string' not found in {path}. Please check the exact text including whitespace and newlines." - else: - count = content.count(old_string) - if count > 1 and not replace_all: - result = f"❌ 'old_string' appears {count} times in {path}. Use replace_all=true or provide more context to make the match unique." - else: - new_content = content.replace(old_string, new_string) if replace_all else content.replace(old_string, new_string, 1) - async with async_session() as _wdb: - write_result = await write_workspace_file( - _wdb, - agent_id=agent_id, - base_dir=ws, - path=path, - content=new_content, - actor_type="agent", - actor_id=agent_id, - operation="edit", - session_id=session_id, - enforce_human_lock=True, - ) - await _wdb.commit() - replaced = count if replace_all else 1 - result = ( - f"✅ Replaced {replaced} occurrence(s) in {write_result.path}" - if write_result.ok - else f"❌ {write_result.message}" - ) + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _convert_markdown_to_pdf(agent_id, temp_ws, arguments), + paths=_non_empty_paths(arguments.get("source_path", "")), + sync_back=True, + ) elif tool_name == "search_files": pattern = arguments.get("pattern") if not pattern: return "❌ Missing required argument 'pattern' for search_files" - result = _search_files( - ws, + result = await _storage_search_files( + agent_id, pattern, path=arguments.get("path", "."), file_pattern=arguments.get("file_pattern", "*"), @@ -2748,8 +2970,8 @@ async def execute_tool( pattern = arguments.get("pattern") if not pattern: return "❌ Missing required argument 'pattern' for find_files" - result = _find_files( - ws, + result = await _storage_find_files( + agent_id, pattern, path=arguments.get("path", "."), tenant_id=_agent_tenant_id @@ -2778,9 +3000,18 @@ async def execute_tool( elif tool_name == "send_message_to_agent": result = await _send_message_to_agent(agent_id, arguments) elif tool_name == "send_file_to_agent": - result = await _send_file_to_agent(agent_id, ws, arguments) + result = await _send_file_to_agent(agent_id, arguments) elif tool_name == "send_channel_file": - result = await _send_channel_file(agent_id, ws, arguments) + file_path = (arguments.get("file_path") or "").strip() + if not file_path: + result = "Error: file_path is required" + else: + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _send_channel_file(agent_id, temp_ws, arguments), + paths=[file_path], + ) elif tool_name == "web_search": result = await _web_search(arguments, agent_id) elif tool_name == "jina_search": @@ -2807,17 +3038,48 @@ async def execute_tool( result = await _plaza_add_comment(agent_id, arguments) elif tool_name in ("execute_code", "execute_code_e2b"): logger.info(f"[DirectTool] Executing code ({tool_name}) with arguments: {arguments}") - result = await _execute_code(agent_id, ws, arguments, tool_name=tool_name) + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _execute_code(agent_id, temp_ws, arguments, tool_name=tool_name), + sync_back=True, + ) elif tool_name == "upload_image": - result = await _upload_image(agent_id, ws, arguments) + file_path = (arguments.get("file_path") or "").strip() + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _upload_image(agent_id, temp_ws, arguments), + paths=_non_empty_paths(file_path), + ) elif tool_name == "generate_image_siliconflow": - result = await _generate_image(agent_id, ws, arguments, "siliconflow") + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _generate_image(agent_id, temp_ws, arguments, "siliconflow"), + sync_back=True, + ) elif tool_name == "generate_image_openai": - result = await _generate_image(agent_id, ws, arguments, "openai") + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _generate_image(agent_id, temp_ws, arguments, "openai"), + sync_back=True, + ) elif tool_name == "generate_image_google": - result = await _generate_image(agent_id, ws, arguments, "google") + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _generate_image(agent_id, temp_ws, arguments, "google"), + sync_back=True, + ) elif tool_name == "generate_image_custom": - result = await _generate_image(agent_id, ws, arguments, "custom") + result = await _run_with_temp_workspace( + agent_id, + _agent_tenant_id, + lambda temp_ws: _generate_image(agent_id, temp_ws, arguments, "custom"), + sync_back=True, + ) elif tool_name == "discover_resources": result = await _discover_resources(agent_id, arguments) elif tool_name == "import_mcp_server": @@ -4195,6 +4457,191 @@ def _resolve_tool_target_path(ws: Path, rel_path: str, tenant_id: str | None = N return candidate +def _tool_storage_key(agent_id: uuid.UUID, rel_path: str, tenant_id: str | None = None) -> tuple[str, str, bool]: + normalized = normalize_workspace_path(_normalize_tool_rel_path(rel_path)) + if _is_enterprise_info_path(normalized): + if not tenant_id: + return normalize_storage_key("enterprise_info/" + normalized.removeprefix("enterprise_info").lstrip("/")), normalized, True + sub = normalized[len("enterprise_info"):].lstrip("/") + key = f"enterprise_info_{tenant_id}/{sub}" if sub else f"enterprise_info_{tenant_id}" + return normalize_storage_key(key), normalized, True + key = f"{agent_id}/{normalized}" if normalized else str(agent_id) + return normalize_storage_key(key), normalized, False + + +def _display_size(size_bytes: int) -> str: + return f"{size_bytes}B" if size_bytes < 1024 else f"{size_bytes / 1024:.1f}KB" + + +async def _storage_list_dir(agent_id: uuid.UUID, rel_path: str, tenant_id: str | None = None) -> str: + storage = get_storage_backend() + storage_key, normalized, is_enterprise = _tool_storage_key(agent_id, rel_path, tenant_id) + + exists = await storage.exists(storage_key) + is_dir = await storage.is_dir(storage_key) + if exists and not is_dir: + return f"Path is not a directory: {rel_path}" + if not exists and not is_dir and normalized: + return f"Directory not found: {rel_path or '/'}" + + items: list[str] = [] + dir_count = 0 + file_count = 0 + if not normalized and tenant_id: + items.append(" 📁 enterprise_info/ (shared company info)") + dir_count += 1 + + entries = await storage.list_dir(storage_key) if exists or is_dir else [] + for entry in entries: + if entry.name.startswith("."): + continue + if entry.is_dir: + dir_count += 1 + try: + child_count = len([c for c in await storage.list_dir(entry.key) if not c.name.startswith(".")]) + except Exception: + child_count = 0 + items.append(f" 📁 {entry.name}/ ({child_count} items)") + else: + file_count += 1 + items.append(f" 📄 {entry.name} ({_display_size(entry.size)})") + + if not items: + return f"📂 {rel_path or 'root'}: Empty directory (0 files, 0 folders)" + header = f"📂 {rel_path or 'root'}: {dir_count} folder(s), {file_count} file(s)\n" + return header + "\n".join(items) + + +async def _storage_read_file( + agent_id: uuid.UUID, + rel_path: str, + tenant_id: str | None = None, + offset: int = 0, + limit: int = 2000, +) -> str: + storage = get_storage_backend() + storage_key, normalized, _ = _tool_storage_key(agent_id, rel_path, tenant_id) + if not normalized: + return "File not found: root" + if not await storage.is_file(storage_key): + return f"File not found: {rel_path}" + try: + content = await storage.read_text(storage_key, encoding="utf-8", errors="replace") + lines = content.splitlines() + total_lines = len(lines) + start = max(0, offset) + end = min(total_lines, start + limit) + if start >= total_lines and total_lines > 0: + return f"Offset {offset} exceeds file length ({total_lines} lines total)" + selected_lines = lines[start:end] + output = "\n".join(f"{i + 1:6}\t{line}" for i, line in enumerate(selected_lines, start=start)) + if total_lines > end: + output += f"\n\n... [{total_lines - end} more lines not shown, lines {end + 1}-{total_lines}]" + header = f"📄 {rel_path} (lines {start + 1 if total_lines else 0}-{end} of {total_lines})\n" + return header + output + except Exception as e: + return f"Read failed: {e}" + + +async def _storage_walk_files(storage, root_key: str) -> list: + out = [] + for entry in await storage.list_dir(root_key): + if entry.name.startswith("."): + continue + out.append(entry) + if entry.is_dir: + out.extend(await _storage_walk_files(storage, entry.key)) + return out + + +def _relative_storage_display(entry_key: str, base_key: str, display_base: str) -> str: + rel = entry_key.removeprefix(base_key.rstrip("/") + "/") + return f"{display_base.rstrip('/')}/{rel}".strip("/") if display_base else rel + + +async def _storage_search_files( + agent_id: uuid.UUID, + pattern: str, + path: str = ".", + file_pattern: str = "*", + ignore_case: bool = False, + tenant_id: str | None = None, +) -> str: + storage = get_storage_backend() + rel_path = "" if path in ("", ".") else path + base_key, normalized, _ = _tool_storage_key(agent_id, rel_path, tenant_id) + if not await storage.is_dir(base_key) and normalized: + return f"Directory not found: {path}" + flags = re.IGNORECASE if ignore_case else 0 + try: + regex = re.compile(pattern, flags) + except re.error as e: + return f"Invalid regex pattern: {e}" + + results: list[str] = [] + total_matches = 0 + files_searched = 0 + entries = await _storage_walk_files(storage, base_key) if await storage.is_dir(base_key) else [] + for entry in entries: + if entry.is_dir: + continue + rel_display = _relative_storage_display(entry.key, base_key, normalized) + if not fnmatch.fnmatch(Path(rel_display).name, file_pattern) and not fnmatch.fnmatch(rel_display, file_pattern): + continue + if Path(rel_display).suffix.lower() in {".pyc", ".pyo", ".so", ".dll", ".exe", ".bin", ".png", ".jpg", ".jpeg", ".gif", ".zip", ".tar", ".gz"}: + continue + files_searched += 1 + try: + content = await storage.read_text(entry.key, encoding="utf-8", errors="ignore") + except Exception: + continue + for i, line in enumerate(content.splitlines(), 1): + if regex.search(line): + results.append(f"{rel_display}:{i}: {line.strip()[:100]}") + total_matches += 1 + if len(results) >= 50: + break + if len(results) >= 50: + break + if not results: + return f"No matches found for pattern '{pattern}' in {files_searched} file(s)" + truncated = total_matches > len(results) + truncation_note = f" (showing first {len(results)} of {total_matches}+ — refine pattern or path for more)" if truncated else "" + return f"🔍 Found {total_matches}+ match(es) in {files_searched} file(s) for pattern '{pattern}'{truncation_note}:\n" + "\n".join(results) + + +async def _storage_find_files( + agent_id: uuid.UUID, + pattern: str, + path: str = ".", + tenant_id: str | None = None, +) -> str: + storage = get_storage_backend() + rel_path = "" if path in ("", ".") else path + base_key, normalized, _ = _tool_storage_key(agent_id, rel_path, tenant_id) + if not await storage.is_dir(base_key) and normalized: + return f"Directory not found: {path}" + entries = await _storage_walk_files(storage, base_key) if await storage.is_dir(base_key) else [] + matches = [] + for entry in entries: + rel_display = _relative_storage_display(entry.key, base_key, normalized) + if fnmatch.fnmatch(rel_display, pattern) or fnmatch.fnmatch(Path(rel_display).name, pattern): + matches.append((entry, rel_display)) + if not matches: + return f"No files matching pattern: {pattern}" + results = [] + dir_count = 0 + file_count = 0 + for entry, rel_display in matches[:100]: + if entry.is_dir: + dir_count += 1 + results.append(f"📁 {rel_display}/") + else: + file_count += 1 + results.append(f"📄 {rel_display} ({_display_size(entry.size)})") + return f"📂 Found {len(matches)} item(s) ({dir_count} dirs, {file_count} files) matching '{pattern}':\n" + "\n".join(results) + + def _list_files(ws: Path, rel_path: str, tenant_id: str | None = None) -> str: # Handle enterprise_info/ as shared directory (tenant-scoped) if rel_path and rel_path.startswith("enterprise_info"): @@ -4598,6 +5045,19 @@ async def _read_document(ws: Path, rel_path: str, max_chars: int = 8000, tenant_ return await asyncio.to_thread(_read_document_with_timeout, ws, rel_path, max_chars, tenant_id) +async def _read_document_from_storage( + agent_id: uuid.UUID, + rel_path: str, + max_chars: int = 8000, + tenant_id: str | None = None, +) -> str: + temp_workspace = await _prepare_temp_workspace(agent_id, tenant_id=tenant_id, paths=[rel_path]) + try: + return await _read_document(temp_workspace.root, rel_path, max_chars=max_chars, tenant_id=None) + finally: + temp_workspace.cleanup() + + # ─── Format Conversion Tools ──────────────────────────────────── async def _convert_csv_to_xlsx(agent_id: uuid.UUID, ws: Path, arguments: dict) -> str: @@ -6055,7 +6515,7 @@ async def _send_platform_message(agent_id: uuid.UUID, args: dict) -> str: return f"❌ Web message send error: {str(e)[:200]}" -async def _send_file_to_agent(from_agent_id: uuid.UUID, ws: Path, args: dict) -> str: +async def _send_file_to_agent(from_agent_id: uuid.UUID, args: dict) -> str: """Send a workspace file to another digital employee (agent).""" agent_name = (args.get("agent_name") or "").strip() rel_path = (args.get("file_path") or "").strip() @@ -6064,35 +6524,28 @@ async def _send_file_to_agent(from_agent_id: uuid.UUID, ws: Path, args: dict) -> if not agent_name or not rel_path: return "❌ Please provide both agent_name and file_path" - # Resolve source file path inside sender workspace - source_file_path = (ws / rel_path).resolve() - ws_resolved = ws.resolve() - sender_root = (WORKSPACE_ROOT / str(from_agent_id)).resolve() - if not str(source_file_path).startswith(str(ws_resolved)): - source_file_path = (sender_root / rel_path).resolve() - if not str(source_file_path).startswith(str(sender_root)): - return "❌ Access denied: source path is outside your workspace" - - if not source_file_path.exists(): + storage = get_storage_backend() + source_key = normalize_storage_key(f"{from_agent_id}/{rel_path}") + if not await storage.is_file(source_key): return f"❌ Source file not found: {rel_path}" - if not source_file_path.is_file(): - return f"❌ Source path is not a file: {rel_path}" + source_entry = await storage.stat(source_key) # File size limit (50 MB) MAX_FILE_SIZE = 50 * 1024 * 1024 - file_size = source_file_path.stat().st_size + file_size = source_entry.size if file_size > MAX_FILE_SIZE: size_mb = file_size / (1024 * 1024) return f"❌ File too large ({size_mb:.1f} MB). Maximum allowed is 50 MB." + source_bytes = await storage.read_bytes(source_key) + source_name = Path(rel_path).name try: from app.services.activity_logger import log_activity - import shutil async with async_session() as db: src_result = await db.execute(select(AgentModel).where(AgentModel.id == from_agent_id)) source_agent = src_result.scalar_one_or_none() - source_name = source_agent.name if source_agent else "Unknown agent" + source_agent_name = source_agent.name if source_agent else "Unknown agent" source_tenant_id = source_agent.tenant_id if source_agent else None # Build base filter: same tenant + not self @@ -6144,38 +6597,29 @@ async def _send_file_to_agent(from_agent_id: uuid.UUID, ws: Path, args: dict) -> if status_info["access_status"] != "active": return f"❌ Relationship to {target_agent.name} is not active ({status_info['access_status_reason'] or 'restricted'}). Ask a manager of both agents to review Relationships." - target_tenant_id = str(target_agent.tenant_id) if target_agent.tenant_id else None target_name = target_agent.name target_id = target_agent.id - target_ws = await ensure_workspace(target_id, tenant_id=target_tenant_id) - inbox_dir = (target_ws / "workspace" / "inbox").resolve() - files_dir = (inbox_dir / "files").resolve() - target_ws_resolved = target_ws.resolve() - if not str(inbox_dir).startswith(str(target_ws_resolved)) or not str(files_dir).startswith(str(target_ws_resolved)): - return "❌ Access denied for target agent inbox path" - - inbox_dir.mkdir(parents=True, exist_ok=True) - files_dir.mkdir(parents=True, exist_ok=True) - ts = datetime.now(timezone.utc) stamp = ts.strftime("%Y%m%d_%H%M%S_%f") - delivered_name = source_file_path.name - delivered_path = files_dir / delivered_name - while delivered_path.exists(): - delivered_name = f"{stamp}_{source_file_path.name}" - delivered_path = files_dir / delivered_name + delivered_name = source_name + target_rel_path = f"workspace/inbox/files/{delivered_name}" + target_key = normalize_storage_key(f"{target_id}/{target_rel_path}") + while await storage.exists(target_key): + delivered_name = f"{stamp}_{source_name}" + target_rel_path = f"workspace/inbox/files/{delivered_name}" + target_key = normalize_storage_key(f"{target_id}/{target_rel_path}") - shutil.copy2(source_file_path, delivered_path) + await storage.write_bytes(target_key, source_bytes) sender_short = str(from_agent_id)[:8] - note_path = inbox_dir / f"{stamp}_{sender_short}_file_delivery.md" - target_rel_path = f"workspace/inbox/files/{delivered_name}" + note_rel_path = f"workspace/inbox/{stamp}_{sender_short}_file_delivery.md" + note_key = normalize_storage_key(f"{target_id}/{note_rel_path}") note_lines = [ - f"# File delivery from {source_name}", + f"# File delivery from {source_agent_name}", "", f"- Time (UTC): {ts.isoformat()}", - f"- Sender: {source_name}", + f"- Sender: {source_agent_name}", f"- Source path: {rel_path}", f"- Delivered file: {target_rel_path}", "", @@ -6186,7 +6630,7 @@ async def _send_file_to_agent(from_agent_id: uuid.UUID, ws: Path, args: dict) -> note_lines.append("") note_lines.append("## Action") note_lines.append(f"- Read the file via `read_file(path=\"{target_rel_path}\")`") - note_path.write_text("\n".join(note_lines), encoding="utf-8") + await storage.write_text(note_key, "\n".join(note_lines), encoding="utf-8") from app.models.audit import AuditLog async with async_session() as db: @@ -6205,7 +6649,7 @@ async def _send_file_to_agent(from_agent_id: uuid.UUID, ws: Path, args: dict) -> action="collaboration:file_receive", details={ "from_agent": str(from_agent_id), - "from_agent_name": source_name, + "from_agent_name": source_agent_name, "source_file": rel_path, "delivered_file": target_rel_path, }, @@ -6221,14 +6665,14 @@ async def _send_file_to_agent(from_agent_id: uuid.UUID, ws: Path, args: dict) -> await log_activity( target_id, "agent_file_received", - f"Received file from {source_name}", - detail={"source_agent": source_name, "source_file": rel_path, "delivered_file": target_rel_path}, + f"Received file from {source_agent_name}", + detail={"source_agent": source_agent_name, "source_file": rel_path, "delivered_file": target_rel_path}, ) return ( f"✅ File sent to {target_name}.\n" f"- Delivered to: {target_rel_path}\n" - f"- Inbox note: workspace/inbox/{note_path.name}" + f"- Inbox note: {note_rel_path}" ) except Exception as e: return f"❌ Agent file send error: {str(e)[:200]}" diff --git a/backend/app/services/collaboration.py b/backend/app/services/collaboration.py index 394281d8a..9cc46db93 100644 --- a/backend/app/services/collaboration.py +++ b/backend/app/services/collaboration.py @@ -10,6 +10,7 @@ from app.models.agent import Agent from app.models.audit import AuditLog +from app.services.storage import store_agent_bytes class CollaborationService: @@ -111,21 +112,16 @@ async def send_message_between_agents( from_result = await db.execute(select(Agent).where(Agent.id == from_agent_id)) from_agent = from_result.scalar_one_or_none() - # Write message to target agent's workspace - from pathlib import Path - from app.config import get_settings - settings = get_settings() - - inbox_dir = Path(settings.AGENT_DATA_DIR) / str(to_agent_id) / "workspace" / "inbox" - inbox_dir.mkdir(parents=True, exist_ok=True) - timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") - msg_file = inbox_dir / f"{timestamp}_{str(from_agent_id)[:8]}.md" - msg_file.write_text( + rel_path = f"workspace/inbox/{timestamp}_{str(from_agent_id)[:8]}.md" + await store_agent_bytes( + to_agent_id, + rel_path, f"# 来自 {from_agent.name if from_agent else 'Unknown'} 的消息\n" f"- 类型: {msg_type}\n" f"- 时间: {datetime.now(timezone.utc).isoformat()}\n\n" - f"{message}\n" + f"{message}\n".encode("utf-8"), + content_type="text/markdown; charset=utf-8", ) db.add(AuditLog( diff --git a/backend/app/services/dingtalk_stream.py b/backend/app/services/dingtalk_stream.py index 73cf80b20..91612998d 100644 --- a/backend/app/services/dingtalk_stream.py +++ b/backend/app/services/dingtalk_stream.py @@ -16,10 +16,10 @@ from loguru import logger from sqlalchemy import select -from app.config import get_settings from app.database import async_session from app.models.channel_config import ChannelConfig from app.services.dingtalk_token import dingtalk_token_manager +from app.services.storage import store_agent_upload # ─── DingTalk Media Helpers ───────────────────────────── @@ -77,14 +77,6 @@ async def download_dingtalk_media( return await _download_file(download_url) -def _resolve_upload_dir(agent_id: uuid.UUID) -> Path: - """Get the uploads directory for an agent, creating it if needed.""" - settings = get_settings() - upload_dir = Path(settings.AGENT_DATA_DIR) / str(agent_id) / "workspace" / "uploads" - upload_dir.mkdir(parents=True, exist_ok=True) - return upload_dir - - async def _process_media_message( msg_data: dict, app_key: str, @@ -121,18 +113,21 @@ async def _process_media_message( if not file_bytes: return "[User sent an image, but download failed]", None, None - upload_dir = _resolve_upload_dir(agent_id) filename = f"dingtalk_img_{uuid.uuid4().hex[:8]}.jpg" - save_path = upload_dir / filename - save_path.write_bytes(file_bytes) - logger.info(f"[DingTalk] Saved image to {save_path} ({len(file_bytes)} bytes)") + _, workspace_path, _ = await store_agent_upload( + agent_id, + filename, + file_bytes, + content_type="image/jpeg", + ) + logger.info(f"[DingTalk] Saved image to {workspace_path} ({len(file_bytes)} bytes)") b64_data = base64.b64encode(file_bytes).decode("ascii") image_marker = f"[image_data:data:image/jpeg;base64,{b64_data}]" return ( f"[User sent an image]\n{image_marker}", [f"data:image/jpeg;base64,{b64_data}"], - [str(save_path)], + [workspace_path], ) elif msgtype == "richText": @@ -148,17 +143,20 @@ async def _process_media_message( app_key, app_secret, item["downloadCode"] ) if file_bytes: - upload_dir = _resolve_upload_dir(agent_id) filename = f"dingtalk_richimg_{uuid.uuid4().hex[:8]}.jpg" - save_path = upload_dir / filename - save_path.write_bytes(file_bytes) - logger.info(f"[DingTalk] Saved rich text image to {save_path}") + _, workspace_path, _ = await store_agent_upload( + agent_id, + filename, + file_bytes, + content_type="image/jpeg", + ) + logger.info(f"[DingTalk] Saved rich text image to {workspace_path}") b64_data = base64.b64encode(file_bytes).decode("ascii") image_marker = f"[image_data:data:image/jpeg;base64,{b64_data}]" text_parts.append(image_marker) image_base64_list.append(f"data:image/jpeg;base64,{b64_data}") - saved_file_paths.append(str(save_path)) + saved_file_paths.append(workspace_path) combined_text = "\n".join(text_parts).strip() if not combined_text: @@ -181,16 +179,14 @@ async def _process_media_message( if download_code: file_bytes = await download_dingtalk_media(app_key, app_secret, download_code) if file_bytes: - upload_dir = _resolve_upload_dir(agent_id) duration = content.get("duration", "unknown") filename = f"dingtalk_audio_{uuid.uuid4().hex[:8]}.amr" - save_path = upload_dir / filename - save_path.write_bytes(file_bytes) - logger.info(f"[DingTalk] Saved audio to {save_path} ({len(file_bytes)} bytes)") + _, workspace_path, _ = await store_agent_upload(agent_id, filename, file_bytes) + logger.info(f"[DingTalk] Saved audio to {workspace_path} ({len(file_bytes)} bytes)") return ( f"[User sent a voice message, duration {duration}ms, saved to {filename}]", None, - [str(save_path)], + [workspace_path], ) return "[User sent a voice message, but it could not be processed]", None, None @@ -200,16 +196,14 @@ async def _process_media_message( if download_code: file_bytes = await download_dingtalk_media(app_key, app_secret, download_code) if file_bytes: - upload_dir = _resolve_upload_dir(agent_id) duration = content.get("duration", "unknown") filename = f"dingtalk_video_{uuid.uuid4().hex[:8]}.mp4" - save_path = upload_dir / filename - save_path.write_bytes(file_bytes) - logger.info(f"[DingTalk] Saved video to {save_path} ({len(file_bytes)} bytes)") + _, workspace_path, _ = await store_agent_upload(agent_id, filename, file_bytes) + logger.info(f"[DingTalk] Saved video to {workspace_path} ({len(file_bytes)} bytes)") return ( f"[User sent a video, duration {duration}ms, saved to {filename}]", None, - [str(save_path)], + [workspace_path], ) return "[User sent a video, but it could not be downloaded]", None, None @@ -220,18 +214,16 @@ async def _process_media_message( if download_code: file_bytes = await download_dingtalk_media(app_key, app_secret, download_code) if file_bytes: - upload_dir = _resolve_upload_dir(agent_id) safe_name = f"dingtalk_{uuid.uuid4().hex[:8]}_{original_filename}" - save_path = upload_dir / safe_name - save_path.write_bytes(file_bytes) + _, workspace_path, _ = await store_agent_upload(agent_id, safe_name, file_bytes) logger.info( - f"[DingTalk] Saved file '{original_filename}' to {save_path} " + f"[DingTalk] Saved file '{original_filename}' to {workspace_path} " f"({len(file_bytes)} bytes)" ) return ( f"[file:{original_filename}]", None, - [str(save_path)], + [workspace_path], ) return f"[User sent file {original_filename}, but it could not be downloaded]", None, None diff --git a/backend/app/services/enterprise_sync.py b/backend/app/services/enterprise_sync.py index b529d8387..a67523215 100644 --- a/backend/app/services/enterprise_sync.py +++ b/backend/app/services/enterprise_sync.py @@ -6,18 +6,15 @@ import json import uuid -from pathlib import Path from loguru import logger from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.config import get_settings from app.core.events import publish_event from app.models.agent import Agent from app.models.audit import EnterpriseInfo - -settings = get_settings() +from app.services.storage import store_agent_bytes # Redis channel for enterprise info updates ENTERPRISE_INFO_CHANNEL = "enterprise_info_updated" @@ -67,9 +64,6 @@ async def sync_to_agent(self, db: AsyncSession, agent_id: uuid.UUID, agent_role: Filters by visible_roles — if empty, all roles can see it. """ - agent_dir = Path(settings.AGENT_DATA_DIR) / str(agent_id) / "enterprise_info" - agent_dir.mkdir(parents=True, exist_ok=True) - result = await db.execute(select(EnterpriseInfo)) all_info = result.scalars().all() @@ -78,12 +72,16 @@ async def sync_to_agent(self, db: AsyncSession, agent_id: uuid.UUID, agent_role: if info.visible_roles and agent_role and agent_role not in info.visible_roles: continue - file_path = agent_dir / f"{info.info_type}.json" - file_path.write_text(json.dumps({ - "type": info.info_type, - "version": info.version, - "content": info.content, - }, ensure_ascii=False, indent=2)) + await store_agent_bytes( + agent_id, + f"enterprise_info/{info.info_type}.json", + json.dumps({ + "type": info.info_type, + "version": info.version, + "content": info.content, + }, ensure_ascii=False, indent=2).encode("utf-8"), + content_type="application/json", + ) logger.info(f"Synced enterprise info to agent {agent_id}") diff --git a/backend/app/services/heartbeat.py b/backend/app/services/heartbeat.py index 596cb8a25..d05f94b4e 100644 --- a/backend/app/services/heartbeat.py +++ b/backend/app/services/heartbeat.py @@ -13,7 +13,8 @@ from datetime import datetime, timezone, timedelta from loguru import logger -from sqlalchemy import select, update +from sqlalchemy import select, update, exists, and_ +from app.services.storage import agent_storage_key, get_storage_backend from app.services.llm.finish import FINISH_PROTOCOL_REMINDER, find_finish_call, parse_tool_arguments @@ -184,15 +185,12 @@ async def _execute_heartbeat(agent_id: uuid.UUID): model_request_timeout = getattr(model, 'request_timeout', None) # Read HEARTBEAT.md if it exists, otherwise use default - from pathlib import Path - from app.config import get_settings - settings = get_settings() - - ws_root = Path(settings.AGENT_DATA_DIR) / str(agent_id) - hb_file = ws_root / "HEARTBEAT.md" - if hb_file.exists(): + storage = get_storage_backend() + hb_key = agent_storage_key(agent_id, "HEARTBEAT.md") + if await storage.exists(hb_key): try: - custom = hb_file.read_text(encoding="utf-8", errors="replace").strip() + custom = await storage.read_text(hb_key, encoding="utf-8", errors="replace") + custom = custom.strip() if custom: # Prepend privacy rules to custom heartbeat heartbeat_instruction = custom + """ diff --git a/backend/app/services/llm/caller.py b/backend/app/services/llm/caller.py index aff3c458e..d60d22e6b 100644 --- a/backend/app/services/llm/caller.py +++ b/backend/app/services/llm/caller.py @@ -21,6 +21,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from app.config import get_settings from app.database import async_session from app.services.agent_tools import AGENT_TOOLS, execute_tool, get_agent_tools_for_llm from app.services.token_tracker import ( @@ -350,9 +351,8 @@ async def _process_tool_call( if supports_vision and agent_id: try: from app.services.vision_inject import try_inject_screenshot_vision - from app.config import get_settings settings = get_settings() - ws_path = Path(settings.AGENT_DATA_DIR) / str(agent_id) + ws_path = Path(settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR) / str(agent_id) vision_content = try_inject_screenshot_vision(tool_name, str(result), ws_path) if vision_content: tool_content = vision_content diff --git a/backend/app/services/okr_scheduler.py b/backend/app/services/okr_scheduler.py index e546613e1..4b3bb78cf 100644 --- a/backend/app/services/okr_scheduler.py +++ b/backend/app/services/okr_scheduler.py @@ -16,14 +16,12 @@ import re import uuid from datetime import date, datetime, timedelta -from pathlib import Path from typing import Optional from loguru import logger from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from app.config import get_settings from app.database import async_session from app.models.agent import Agent from app.models.okr import ( @@ -33,9 +31,7 @@ OKRSettings, WorkReport, ) - -_settings = get_settings() -WORKSPACE_ROOT = Path(_settings.AGENT_DATA_DIR) +from app.services.storage import agent_storage_key, get_storage_backend, store_agent_bytes # ─── Focus File Parsing ─────────────────────────────────────────────────────── @@ -138,17 +134,16 @@ async def collect_all_focus_updates( skipped_count = 0 error_count = 0 lines: list[str] = [] + storage = get_storage_backend() for agent in agents: - agent_dir = WORKSPACE_ROOT / str(agent.id) - focus_path = agent_dir / "focus.md" - - if not focus_path.exists(): + focus_key = agent_storage_key(agent.id, "focus.md") + if not await storage.exists(focus_key): skipped_count += 1 continue try: - content = focus_path.read_text(encoding="utf-8") + content = await storage.read_text(focus_key, encoding="utf-8", errors="replace") updates = _parse_focus_md(content) if not updates: @@ -159,7 +154,7 @@ async def collect_all_focus_updates( try: kr_uuid = uuid.UUID(kr_id_str) except ValueError: - logger.warning(f"[OKRScheduler] Invalid KR UUID '{kr_id_str}' in {focus_path}") + logger.warning(f"[OKRScheduler] Invalid KR UUID '{kr_id_str}' in {focus_key}") continue # Fetch the KR and verify it belongs to this tenant @@ -403,12 +398,15 @@ async def _store_report( await db.commit() -def _safe_write_report(agent_dir: Path, filename: str, content: str) -> None: +async def _safe_write_report(okr_agent_id: uuid.UUID, filename: str, content: str) -> None: """Write report to OKR Agent's workspace/reports/ directory.""" try: - reports_dir = agent_dir / "workspace" / "reports" - reports_dir.mkdir(parents=True, exist_ok=True) - (reports_dir / filename).write_text(content, encoding="utf-8") + await store_agent_bytes( + okr_agent_id, + f"workspace/reports/{filename}", + content.encode("utf-8"), + content_type="text/markdown; charset=utf-8", + ) except Exception as exc: logger.warning(f"[OKRScheduler] Could not write report file {filename}: {exc}") @@ -445,8 +443,7 @@ async def generate_daily_report( await _store_report(tenant_id, okr_agent_id, "daily", today, content, db) # Write file to workspace - agent_dir = WORKSPACE_ROOT / str(okr_agent_id) - _safe_write_report(agent_dir, f"daily_{today.strftime('%Y%m%d')}.md", content) + await _safe_write_report(okr_agent_id, f"daily_{today.strftime('%Y%m%d')}.md", content) logger.info(f"[OKRScheduler] Daily report generated for tenant {tenant_id}") return content @@ -485,9 +482,8 @@ async def generate_weekly_report( monday = today - timedelta(days=today.weekday()) await _store_report(tenant_id, okr_agent_id, "weekly", monday, content, db) - agent_dir = WORKSPACE_ROOT / str(okr_agent_id) week_label = monday.strftime("%Y-W%V") - _safe_write_report(agent_dir, f"weekly_{week_label}.md", content) + await _safe_write_report(okr_agent_id, f"weekly_{week_label}.md", content) logger.info(f"[OKRScheduler] Weekly report generated for tenant {tenant_id}") return content @@ -565,9 +561,8 @@ async def generate_monthly_report( await _store_report(tenant_id, okr_agent_id, "monthly", month_start, content, db) # Write file to OKR Agent workspace - agent_dir = WORKSPACE_ROOT / str(okr_agent_id) month_label = month_start.strftime("%Y-%m") - _safe_write_report(agent_dir, f"monthly_{month_label}.md", content) + await _safe_write_report(okr_agent_id, f"monthly_{month_label}.md", content) logger.info(f"[OKRScheduler] Monthly report generated for tenant {tenant_id}") return content diff --git a/backend/app/services/org_sync_adapter.py b/backend/app/services/org_sync_adapter.py index 5df5dc7f6..45f0cb71c 100644 --- a/backend/app/services/org_sync_adapter.py +++ b/backend/app/services/org_sync_adapter.py @@ -9,7 +9,7 @@ import uuid from abc import ABC, abstractmethod from dataclasses import dataclass, field -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Any from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, delete, func, or_, select, update @@ -51,6 +51,10 @@ def pinyin(value: str, style: str | None = None) -> list[list[str]]: from jose import jwt +def _utcnow() -> datetime: + return datetime.now(timezone.utc) + + def build_department_path_map(departments: list[OrgDepartment]) -> dict[uuid.UUID, str]: """Build department name paths by walking the internal department tree.""" dept_by_id = {dept.id: dept for dept in departments} @@ -239,7 +243,7 @@ async def sync_org_structure(self, db: AsyncSession) -> dict[str, Any]: member_count = 0 user_count = 0 profile_count = 0 - sync_start = datetime.now() + sync_start = _utcnow() partial_failure = False # Ensure provider exists @@ -291,7 +295,7 @@ async def sync_org_structure(self, db: AsyncSession) -> dict[str, Any]: # Update provider metadata if possible if self.provider: config = (self.provider.config or {}).copy() - config["last_synced_at"] = datetime.now().isoformat() + config["last_synced_at"] = _utcnow().isoformat() self.provider.config = config await db.flush() @@ -321,7 +325,7 @@ async def sync_org_structure(self, db: AsyncSession) -> dict[str, Any]: "profiles_synced": profile_count, "errors": errors, "provider": self.provider_type, - "synced_at": datetime.now().isoformat() + "synced_at": _utcnow().isoformat() } async def _reconcile(self, db: AsyncSession, provider_id: uuid.UUID, sync_start: datetime): @@ -333,7 +337,8 @@ async def _reconcile(self, db: AsyncSession, provider_id: uuid.UUID, sync_start: .where(OrgMember.provider_id == provider_id) .where(OrgMember.synced_at < sync_start) .where(OrgMember.status != "deleted") - .values(status="deleted", synced_at=datetime.now()) + .values(status="deleted", synced_at=_utcnow()) + .execution_options(synchronize_session=False) ) # 2. Departments reconciled @@ -342,7 +347,8 @@ async def _reconcile(self, db: AsyncSession, provider_id: uuid.UUID, sync_start: .where(OrgDepartment.provider_id == provider_id) .where(OrgDepartment.synced_at < sync_start) .where(OrgDepartment.status != "deleted") - .values(status="deleted", synced_at=datetime.now()) + .values(status="deleted", synced_at=_utcnow()) + .execution_options(synchronize_session=False) ) async def _update_member_counts(self, db: AsyncSession, provider_id: uuid.UUID): @@ -457,7 +463,7 @@ async def _upsert_department( ) existing = result.scalars().first() - now = datetime.now() + now = _utcnow() # Path is rebuilt from the internal department tree after sync. path = dept.name @@ -569,7 +575,7 @@ async def _upsert_member( existing_member = await self._find_existing_member(db, provider, user) - now = datetime.now() + now = _utcnow() # Note: Platform user creation is disabled - just sync OrgMember # Users will be linked to platform users manually or via SSO login diff --git a/backend/app/services/realtime.py b/backend/app/services/realtime.py new file mode 100644 index 000000000..39bcf2d85 --- /dev/null +++ b/backend/app/services/realtime.py @@ -0,0 +1,19 @@ +"""Compatibility facade for realtime services. + +New code should prefer the `app.services.realtime_runtime` package. +This module remains as the stable import path for existing callers. +""" + +from app.services.realtime_runtime import ( + PRESENCE_TTL_SECONDS, + PUBSUB_PREFIX, + RealtimeRouter, + realtime_router, +) + +__all__ = [ + "PRESENCE_TTL_SECONDS", + "PUBSUB_PREFIX", + "RealtimeRouter", + "realtime_router", +] diff --git a/backend/app/services/realtime_runtime/__init__.py b/backend/app/services/realtime_runtime/__init__.py new file mode 100644 index 000000000..feb673ad2 --- /dev/null +++ b/backend/app/services/realtime_runtime/__init__.py @@ -0,0 +1,15 @@ +"""Realtime routing runtime package.""" + +from app.services.realtime_runtime.router import ( + PRESENCE_TTL_SECONDS, + PUBSUB_PREFIX, + RealtimeRouter, + realtime_router, +) + +__all__ = [ + "PRESENCE_TTL_SECONDS", + "PUBSUB_PREFIX", + "RealtimeRouter", + "realtime_router", +] diff --git a/backend/app/services/realtime_runtime/router.py b/backend/app/services/realtime_runtime/router.py new file mode 100644 index 000000000..8cb8f0411 --- /dev/null +++ b/backend/app/services/realtime_runtime/router.py @@ -0,0 +1,200 @@ +"""Redis-backed websocket presence and cross-instance message routing.""" + +from __future__ import annotations + +import asyncio +import json +import uuid + +from fastapi import WebSocket +from loguru import logger + +from app.config import get_settings +from app.core.events import get_redis + +settings = get_settings() + +PRESENCE_TTL_SECONDS = 180 +PUBSUB_PREFIX = "realtime:ws" + + +class RealtimeRouter: + def __init__(self) -> None: + self.instance_id = settings.INSTANCE_ID + self._subscriber_task: asyncio.Task | None = None + self._started = False + + def _connection_key(self, connection_id: str) -> str: + return f"{PUBSUB_PREFIX}:conn:{connection_id}" + + def _agent_index_key(self, agent_id: str) -> str: + return f"{PUBSUB_PREFIX}:agent:{agent_id}" + + def _instance_channel(self) -> str: + return f"{PUBSUB_PREFIX}:instance:{self.instance_id}" + + async def register_connection( + self, + *, + agent_id: str, + websocket: WebSocket, + session_id: str | None, + user_id: str | None, + ) -> str: + connection_id = uuid.uuid4().hex + redis = await get_redis() + payload = { + "agent_id": agent_id, + "session_id": session_id or "", + "user_id": user_id or "", + "instance_id": self.instance_id, + } + async with redis.pipeline(transaction=True) as pipe: + pipe.sadd(self._agent_index_key(agent_id), connection_id) + pipe.hset(self._connection_key(connection_id), mapping=payload) + pipe.expire(self._connection_key(connection_id), PRESENCE_TTL_SECONDS) + pipe.expire(self._agent_index_key(agent_id), PRESENCE_TTL_SECONDS) + await pipe.execute() + setattr(websocket.state, "realtime_connection_id", connection_id) + return connection_id + + async def unregister_connection(self, *, agent_id: str, websocket: WebSocket) -> None: + connection_id = getattr(websocket.state, "realtime_connection_id", None) + if not connection_id: + return + redis = await get_redis() + async with redis.pipeline(transaction=True) as pipe: + pipe.srem(self._agent_index_key(agent_id), connection_id) + pipe.delete(self._connection_key(connection_id)) + await pipe.execute() + + async def is_user_viewing_session(self, *, agent_id: str, session_id: str, user_id: str) -> bool: + for record in await self._list_presence(agent_id): + if record.get("session_id") == session_id and record.get("user_id") == user_id: + return True + return False + + async def get_active_session_ids(self, agent_id: str) -> list[str]: + seen: set[str] = set() + for record in await self._list_presence(agent_id): + session_id = (record.get("session_id") or "").strip() + if session_id: + seen.add(session_id) + return list(seen) + + async def route_message( + self, + *, + agent_id: str, + message: dict, + local_connections: list[tuple[WebSocket, str | None, str | None]], + session_id: str | None = None, + user_id: str | None = None, + ) -> None: + local_sent = 0 + for ws, local_session_id, local_user_id in list(local_connections): + if session_id is not None and local_session_id != session_id: + continue + if user_id is not None and local_user_id != user_id: + continue + try: + await ws.send_json(message) + local_sent += 1 + except Exception: + pass + + remote_targets: dict[str, int] = {} + for record in await self._list_presence(agent_id): + if record.get("instance_id") == self.instance_id: + continue + if session_id is not None and record.get("session_id") != session_id: + continue + if user_id is not None and record.get("user_id") != user_id: + continue + target_instance = record.get("instance_id") + if target_instance: + remote_targets[target_instance] = remote_targets.get(target_instance, 0) + 1 + + if not remote_targets: + return + + redis = await get_redis() + envelope = json.dumps( + { + "message": message, + "agent_id": agent_id, + "session_id": session_id, + "user_id": user_id, + "origin_instance_id": self.instance_id, + } + ) + publish_tasks = [ + redis.publish(f"{PUBSUB_PREFIX}:instance:{instance_id}", envelope) + for instance_id in remote_targets + ] + await asyncio.gather(*publish_tasks, return_exceptions=True) + logger.debug( + f"[Realtime] Routed agent={agent_id} local={local_sent} remote_instances={list(remote_targets.keys())}" + ) + + async def start(self, deliver_local) -> None: + if self._started: + return + self._started = True + self._subscriber_task = asyncio.create_task(self._subscriber_loop(deliver_local), name="realtime-subscriber") + + async def stop(self) -> None: + if self._subscriber_task: + self._subscriber_task.cancel() + try: + await self._subscriber_task + except asyncio.CancelledError: + pass + self._subscriber_task = None + self._started = False + + async def _subscriber_loop(self, deliver_local) -> None: + redis = await get_redis() + pubsub = redis.pubsub() + await pubsub.subscribe(self._instance_channel()) + try: + while True: + message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0) + if not message: + await asyncio.sleep(0.05) + continue + try: + data = json.loads(message["data"]) + await deliver_local( + agent_id=data["agent_id"], + payload=data["message"], + session_id=data.get("session_id"), + user_id=data.get("user_id"), + ) + except Exception as exc: + logger.warning(f"[Realtime] Failed to deliver pubsub message: {exc}") + except asyncio.CancelledError: + raise + finally: + await pubsub.unsubscribe(self._instance_channel()) + await pubsub.aclose() + + async def _list_presence(self, agent_id: str) -> list[dict[str, str]]: + redis = await get_redis() + connection_ids = await redis.smembers(self._agent_index_key(agent_id)) + if not connection_ids: + return [] + records: list[dict[str, str]] = [] + stale_ids: list[str] = [] + for connection_id in connection_ids: + data = await redis.hgetall(self._connection_key(connection_id)) + if not data: + stale_ids.append(connection_id) + continue + records.append(data) + if stale_ids: + await redis.srem(self._agent_index_key(agent_id), *stale_ids) + return records + + +realtime_router = RealtimeRouter() diff --git a/backend/app/services/registration_service.py b/backend/app/services/registration_service.py index 2db0795d4..2688e670b 100644 --- a/backend/app/services/registration_service.py +++ b/backend/app/services/registration_service.py @@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.config import get_settings -from app.core.security import hash_password +from app.core.security import hash_password_async from app.models.identity import IdentityProvider from app.models.tenant import Tenant from app.models.user import User, Identity @@ -201,7 +201,7 @@ async def find_or_create_identity( email=email, phone=normalized_phone, username=final_username, - password_hash=hash_password(password) if password else None, + password_hash=await hash_password_async(password) if password else None, is_platform_admin=is_platform_admin, email_verified=is_verified, ) diff --git a/backend/app/services/skill_seeder.py b/backend/app/services/skill_seeder.py index b3c1b2612..b658900cc 100644 --- a/backend/app/services/skill_seeder.py +++ b/backend/app/services/skill_seeder.py @@ -930,11 +930,11 @@ async def push_default_skills_to_existing_agents(): Called at startup after seed_skills() so existing agents automatically receive new default skills like mcp-installer without requiring manual re-creation. """ - from pathlib import Path from app.models.agent import Agent - from app.models.skill import Skill, SkillFile + from app.models.skill import Skill from sqlalchemy.orm import selectinload from app.services.agent_manager import agent_manager + from app.services.storage import get_storage_backend async with async_session() as db: # Load all is_default skills with their files @@ -951,33 +951,22 @@ async def push_default_skills_to_existing_agents(): pushed = 0 updated = 0 - removed_legacy = 0 + storage = get_storage_backend() for agent in agents: - agent_dir = agent_manager._agent_dir(agent.id) - skills_dir = agent_dir / "skills" - legacy_mcp_file = skills_dir / "MCP_INSTALLER.md" - if legacy_mcp_file.exists(): - try: - legacy_mcp_file.unlink() - removed_legacy += 1 - except OSError as exc: - logger.warning(f"[SkillSeeder] Failed to remove legacy MCP_INSTALLER.md for agent {agent.id}: {exc}") + agent_prefix = agent_manager._agent_storage_prefix(agent.id) for skill in default_skills: if not skill.files: continue - skill_folder = skills_dir / skill.folder_name - skill_folder.mkdir(parents=True, exist_ok=True) for sf in skill.files: - fp = (skill_folder / sf.path).resolve() - fp.parent.mkdir(parents=True, exist_ok=True) - if fp.exists(): - existing_content = fp.read_text(encoding="utf-8") + key = f"{agent_prefix}/skills/{skill.folder_name}/{sf.path}" + if await storage.is_file(key): + existing_content = await storage.read_text(key, encoding="utf-8", errors="replace") if existing_content == sf.content: continue # already up-to-date - fp.write_text(sf.content, encoding="utf-8") + await storage.write_text(key, sf.content, encoding="utf-8") updated += 1 else: - fp.write_text(sf.content, encoding="utf-8") + await storage.write_text(key, sf.content, encoding="utf-8") pushed += 1 logger.info(f"[SkillSeeder] Pushed '{skill.name}' to agent {agent.id}") diff --git a/backend/app/services/storage.py b/backend/app/services/storage.py new file mode 100644 index 000000000..6fc868b11 --- /dev/null +++ b/backend/app/services/storage.py @@ -0,0 +1,45 @@ +"""Compatibility facade for storage services. + +New code should prefer the `app.services.storage_runtime` package. +This module remains as the stable import path for existing callers. +""" + +from app.services.storage_runtime import ( + LocalStorageBackend, + S3StorageBackend, + StorageBackend, + StorageEntry, + agent_storage_key, + agent_storage_prefix, + agent_upload_key, + agent_workspace_key, + ensure_local_path, + get_storage_backend, + guess_content_type, + normalize_storage_key, + sanitize_filename, + store_agent_bytes, + store_agent_upload, + tenant_storage_key, + tenant_storage_prefix, +) + +__all__ = [ + "LocalStorageBackend", + "S3StorageBackend", + "StorageBackend", + "StorageEntry", + "agent_storage_key", + "agent_storage_prefix", + "agent_upload_key", + "agent_workspace_key", + "ensure_local_path", + "get_storage_backend", + "guess_content_type", + "normalize_storage_key", + "sanitize_filename", + "store_agent_bytes", + "store_agent_upload", + "tenant_storage_key", + "tenant_storage_prefix", +] diff --git a/backend/app/services/storage_runtime/__init__.py b/backend/app/services/storage_runtime/__init__.py new file mode 100644 index 000000000..756e2e4fd --- /dev/null +++ b/backend/app/services/storage_runtime/__init__.py @@ -0,0 +1,53 @@ +"""Storage runtime package.""" + +from app.services.storage_runtime.base import ( + ConditionalWriteResult, + StorageBackend, + StorageEntry, + StorageVersion, + WriteCondition, +) +from app.services.storage_runtime.agent_files import ( + agent_storage_key, + agent_upload_key, + agent_workspace_key, + sanitize_filename, + store_agent_bytes, + store_agent_upload, + tenant_storage_key, +) +from app.services.storage_runtime.facade import ( + agent_storage_prefix, + ensure_local_path, + get_storage_backend, + guess_content_type, + normalize_storage_key, + tenant_storage_prefix, +) +from app.services.storage_runtime.fallback import FallbackStorageBackend +from app.services.storage_runtime.local import LocalStorageBackend +from app.services.storage_runtime.s3 import S3StorageBackend + +__all__ = [ + "StorageBackend", + "StorageEntry", + "StorageVersion", + "WriteCondition", + "ConditionalWriteResult", + "FallbackStorageBackend", + "LocalStorageBackend", + "S3StorageBackend", + "agent_storage_key", + "agent_storage_prefix", + "agent_upload_key", + "agent_workspace_key", + "ensure_local_path", + "get_storage_backend", + "guess_content_type", + "normalize_storage_key", + "sanitize_filename", + "store_agent_bytes", + "store_agent_upload", + "tenant_storage_key", + "tenant_storage_prefix", +] diff --git a/backend/app/services/storage_runtime/agent_files.py b/backend/app/services/storage_runtime/agent_files.py new file mode 100644 index 000000000..d9c720816 --- /dev/null +++ b/backend/app/services/storage_runtime/agent_files.py @@ -0,0 +1,84 @@ +"""Agent-scoped storage helpers. + +This module centralizes how agent and tenant workspace keys are built so +channel handlers and background services do not manually assemble +`workspace/uploads/...` paths all over the codebase. +""" + +from __future__ import annotations + +import os +import uuid +from pathlib import Path + +from app.services.storage_runtime.facade import ( + ensure_local_path, + get_storage_backend, + guess_content_type, + normalize_storage_key, +) + + +def sanitize_filename(filename: str, fallback: str = "file.bin") -> str: + name = (filename or "").replace("\\", "_").replace("/", "_").strip() + return name or fallback + + +def agent_storage_key(agent_id: uuid.UUID | str, rel_path: str = "") -> str: + prefix = str(agent_id) + rel = normalize_storage_key(rel_path) + return f"{prefix}/{rel}" if rel else prefix + + +def agent_workspace_key(agent_id: uuid.UUID | str, rel_path: str = "") -> str: + rel = normalize_storage_key(rel_path) + workspace_rel = f"workspace/{rel}" if rel else "workspace" + return agent_storage_key(agent_id, workspace_rel) + + +def agent_upload_key(agent_id: uuid.UUID | str, filename: str) -> str: + safe_name = sanitize_filename(filename) + return agent_workspace_key(agent_id, f"uploads/{safe_name}") + + +def tenant_storage_key(tenant_id: uuid.UUID | str, rel_path: str = "") -> str: + prefix = normalize_storage_key(f"enterprise_info_{tenant_id}") + rel = normalize_storage_key(rel_path) + return f"{prefix}/{rel}" if rel else prefix + + +async def store_agent_bytes( + agent_id: uuid.UUID | str, + rel_path: str, + data: bytes, + *, + content_type: str | None = None, +) -> str: + key = agent_storage_key(agent_id, rel_path) + storage = get_storage_backend() + await storage.write_bytes( + key, + data, + content_type=content_type or guess_content_type(Path(rel_path).name), + ) + return key + + +async def store_agent_upload( + agent_id: uuid.UUID | str, + filename: str, + data: bytes, + *, + content_type: str | None = None, +) -> tuple[str, str, Path]: + key = agent_upload_key(agent_id, filename) + storage = get_storage_backend() + safe_name = os.path.basename(key) + await storage.write_bytes( + key, + data, + content_type=content_type or guess_content_type(safe_name), + ) + local_path = await ensure_local_path(key) + workspace_path = f"workspace/uploads/{safe_name}" + return key, workspace_path, local_path diff --git a/backend/app/services/storage_runtime/base.py b/backend/app/services/storage_runtime/base.py new file mode 100644 index 000000000..ac7c46369 --- /dev/null +++ b/backend/app/services/storage_runtime/base.py @@ -0,0 +1,145 @@ +"""Base storage types and interfaces.""" + +from __future__ import annotations + +import hashlib +from dataclasses import dataclass +from pathlib import Path + + +@dataclass +class StorageEntry: + name: str + key: str + is_dir: bool + size: int = 0 + modified_at: str = "" + etag: str = "" + version_id: str = "" + content_hash: str = "" + + +@dataclass +class StorageVersion: + key: str + exists: bool + is_dir: bool + size: int = 0 + modified_at: str = "" + etag: str = "" + version_id: str = "" + content_hash: str = "" + + @property + def token(self) -> str: + return self.version_id or self.etag or self.content_hash or f"{self.modified_at}:{self.size}" + + +@dataclass +class WriteCondition: + version_token: str | None = None + require_absent: bool = False + + +@dataclass +class ConditionalWriteResult: + ok: bool + conflict: bool = False + current_version: StorageVersion | None = None + + +class StorageBackend: + async def exists(self, key: str) -> bool: + raise NotImplementedError + + async def is_file(self, key: str) -> bool: + raise NotImplementedError + + async def is_dir(self, key: str) -> bool: + raise NotImplementedError + + async def list_dir(self, key: str) -> list[StorageEntry]: + raise NotImplementedError + + async def read_bytes(self, key: str) -> bytes: + raise NotImplementedError + + async def read_text(self, key: str, encoding: str = "utf-8", errors: str = "replace") -> str: + raw = await self.read_bytes(key) + return raw.decode(encoding, errors=errors) + + async def write_bytes(self, key: str, data: bytes, content_type: str | None = None) -> None: + raise NotImplementedError + + async def write_text(self, key: str, content: str, encoding: str = "utf-8") -> None: + await self.write_bytes(key, content.encode(encoding), content_type="text/plain; charset=utf-8") + + async def delete(self, key: str) -> None: + raise NotImplementedError + + async def delete_tree(self, key: str) -> None: + raise NotImplementedError + + async def stat(self, key: str) -> StorageEntry: + raise NotImplementedError + + async def get_version(self, key: str) -> StorageVersion: + try: + entry = await self.stat(key) + except FileNotFoundError: + return StorageVersion(key=key, exists=False, is_dir=False) + return StorageVersion( + key=entry.key, + exists=True, + is_dir=entry.is_dir, + size=entry.size, + modified_at=entry.modified_at, + etag=entry.etag, + version_id=entry.version_id, + content_hash=entry.content_hash, + ) + + async def write_bytes_if_match( + self, + key: str, + data: bytes, + *, + condition: WriteCondition | None = None, + content_type: str | None = None, + ) -> ConditionalWriteResult: + current = await self.get_version(key) + if condition: + if condition.require_absent and current.exists: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + if condition.version_token is not None and current.token != condition.version_token: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + await self.write_bytes(key, data, content_type=content_type) + return ConditionalWriteResult(ok=True, current_version=await self.get_version(key)) + + async def delete_if_match( + self, + key: str, + *, + condition: WriteCondition | None = None, + ) -> ConditionalWriteResult: + current = await self.get_version(key) + if condition: + if condition.require_absent: + if current.exists: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + return ConditionalWriteResult(ok=True, current_version=current) + if condition.version_token is not None and current.token != condition.version_token: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + if current.exists: + await self.delete(key) + return ConditionalWriteResult(ok=True, current_version=await self.get_version(key)) + + async def local_path_for(self, key: str) -> Path | None: + return None + + async def presign_download_url(self, key: str, filename: str | None = None, inline: bool = False) -> str | None: + return None + + +def content_hash_bytes(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() diff --git a/backend/app/services/storage_runtime/facade.py b/backend/app/services/storage_runtime/facade.py new file mode 100644 index 000000000..00133a217 --- /dev/null +++ b/backend/app/services/storage_runtime/facade.py @@ -0,0 +1,60 @@ +"""Facade for selecting the configured storage backend.""" + +from __future__ import annotations + +import mimetypes +from pathlib import Path + +from app.config import get_settings +from app.services.storage_runtime.base import StorageBackend +from app.services.storage_runtime.fallback import FallbackStorageBackend +from app.services.storage_runtime.local import LocalStorageBackend +from app.services.storage_runtime.s3 import S3StorageBackend +from app.services.storage_runtime.utils import ( + agent_storage_prefix, + normalize_storage_key, + tenant_storage_prefix, +) + +_storage_backend: StorageBackend | None = None + + +def get_storage_backend() -> StorageBackend: + global _storage_backend + if _storage_backend is not None: + return _storage_backend + + settings = get_settings() + backend = (settings.STORAGE_BACKEND or "local").strip().lower() + if backend == "s3": + primary = S3StorageBackend( + bucket=settings.S3_BUCKET, + prefix=settings.S3_PREFIX, + region=settings.S3_REGION, + endpoint_url=settings.S3_ENDPOINT_URL, + access_key_id=settings.S3_ACCESS_KEY_ID, + secret_access_key=settings.S3_SECRET_ACCESS_KEY, + presign_ttl_seconds=settings.S3_PRESIGN_TTL_SECONDS, + max_pool_connections=settings.S3_MAX_POOL_CONNECTIONS, + write_workers=settings.S3_WRITE_WORKERS, + ) + if settings.STORAGE_LOCAL_FALLBACK_ENABLED: + fallback = LocalStorageBackend(settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR) + _storage_backend = FallbackStorageBackend(primary=primary, fallback=fallback) + else: + _storage_backend = primary + else: + _storage_backend = LocalStorageBackend(settings.STORAGE_LOCAL_ROOT or settings.AGENT_DATA_DIR) + return _storage_backend + + +async def ensure_local_path(key: str) -> Path: + backend = get_storage_backend() + path = await backend.local_path_for(key) + if path is None: + raise RuntimeError("Storage backend cannot materialize a local path") + return path + + +def guess_content_type(filename: str) -> str: + return mimetypes.guess_type(filename)[0] or "application/octet-stream" diff --git a/backend/app/services/storage_runtime/fallback.py b/backend/app/services/storage_runtime/fallback.py new file mode 100644 index 000000000..5699cfd6d --- /dev/null +++ b/backend/app/services/storage_runtime/fallback.py @@ -0,0 +1,106 @@ +"""Storage backend wrapper for gradual local-to-remote migration.""" + +from __future__ import annotations + +from pathlib import Path + +from app.services.storage_runtime.base import ( + ConditionalWriteResult, + StorageBackend, + StorageEntry, + StorageVersion, + WriteCondition, +) + + +class FallbackStorageBackend(StorageBackend): + """Read-through fallback backend. + + Writes go to the primary backend. Reads first try primary storage, then + fallback storage; fallback hits are copied into primary storage so old local + files are gradually migrated as they are used. + """ + + def __init__(self, primary: StorageBackend, fallback: StorageBackend): + self.primary = primary + self.fallback = fallback + + async def exists(self, key: str) -> bool: + return await self.primary.exists(key) or await self.fallback.exists(key) + + async def is_file(self, key: str) -> bool: + return await self.primary.is_file(key) or await self.fallback.is_file(key) + + async def is_dir(self, key: str) -> bool: + return await self.primary.is_dir(key) or await self.fallback.is_dir(key) + + async def list_dir(self, key: str) -> list[StorageEntry]: + entries_by_key: dict[str, StorageEntry] = {} + for entry in await self.fallback.list_dir(key): + entries_by_key[entry.key] = entry + for entry in await self.primary.list_dir(key): + entries_by_key[entry.key] = entry + return sorted(entries_by_key.values(), key=lambda entry: (not entry.is_dir, entry.name)) + + async def read_bytes(self, key: str) -> bytes: + if await self.primary.exists(key) and await self.primary.is_file(key): + return await self.primary.read_bytes(key) + data = await self.fallback.read_bytes(key) + await self.primary.write_bytes(key, data) + return data + + async def write_bytes(self, key: str, data: bytes, content_type: str | None = None) -> None: + await self.primary.write_bytes(key, data, content_type=content_type) + + async def delete(self, key: str) -> None: + await self.primary.delete(key) + await self.fallback.delete(key) + + async def delete_tree(self, key: str) -> None: + await self.primary.delete_tree(key) + await self.fallback.delete_tree(key) + + async def stat(self, key: str) -> StorageEntry: + if await self.primary.exists(key): + return await self.primary.stat(key) + entry = await self.fallback.stat(key) + if not entry.is_dir: + data = await self.fallback.read_bytes(key) + await self.primary.write_bytes(key, data) + return entry + + async def get_version(self, key: str) -> StorageVersion: + primary_version = await self.primary.get_version(key) + if primary_version.exists: + return primary_version + fallback_version = await self.fallback.get_version(key) + if fallback_version.exists and not fallback_version.is_dir: + data = await self.fallback.read_bytes(key) + await self.primary.write_bytes(key, data) + return await self.primary.get_version(key) + return fallback_version + + async def write_bytes_if_match( + self, + key: str, + data: bytes, + *, + condition: WriteCondition | None = None, + content_type: str | None = None, + ) -> ConditionalWriteResult: + return await self.primary.write_bytes_if_match(key, data, condition=condition, content_type=content_type) + + async def local_path_for(self, key: str) -> Path | None: + if await self.primary.exists(key): + return await self.primary.local_path_for(key) + path = await self.fallback.local_path_for(key) + if path is not None and await self.fallback.is_file(key): + data = await self.fallback.read_bytes(key) + await self.primary.write_bytes(key, data) + return path + + async def presign_download_url(self, key: str, filename: str | None = None, inline: bool = False) -> str | None: + if not await self.primary.exists(key) and await self.fallback.exists(key) and await self.fallback.is_file(key): + data = await self.fallback.read_bytes(key) + await self.primary.write_bytes(key, data) + return await self.primary.presign_download_url(key, filename=filename, inline=inline) diff --git a/backend/app/services/storage_runtime/local.py b/backend/app/services/storage_runtime/local.py new file mode 100644 index 000000000..799361788 --- /dev/null +++ b/backend/app/services/storage_runtime/local.py @@ -0,0 +1,166 @@ +"""Local filesystem storage backend.""" + +from __future__ import annotations + +import asyncio +from pathlib import Path + +import aiofiles +from fastapi import HTTPException, status + +from app.services.storage_runtime.base import ( + ConditionalWriteResult, + StorageBackend, + StorageEntry, + StorageVersion, + WriteCondition, + content_hash_bytes, +) +from app.services.storage_runtime.utils import normalize_storage_key + + +class LocalStorageBackend(StorageBackend): + def __init__(self, root: str): + self.root = Path(root) + + def _full_path(self, key: str) -> Path: + normalized = normalize_storage_key(key) + full = (self.root / normalized).resolve() + root_resolved = self.root.resolve() + if not str(full).startswith(str(root_resolved)): + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Path traversal not allowed") + return full + + async def exists(self, key: str) -> bool: + return self._full_path(key).exists() + + async def is_file(self, key: str) -> bool: + return self._full_path(key).is_file() + + async def is_dir(self, key: str) -> bool: + return self._full_path(key).is_dir() + + async def list_dir(self, key: str) -> list[StorageEntry]: + base = self._full_path(key) + if not base.exists() or not base.is_dir(): + return [] + entries: list[StorageEntry] = [] + for entry in sorted(base.iterdir(), key=lambda item: (not item.is_dir(), item.name)): + if entry.name == ".gitkeep": + continue + stat = entry.stat() + rel = str(entry.resolve().relative_to(self.root.resolve())) + entries.append( + StorageEntry( + name=entry.name, + key=rel, + is_dir=entry.is_dir(), + size=stat.st_size if entry.is_file() else 0, + modified_at=str(stat.st_mtime), + version_id=_local_version_token(stat, None), + ) + ) + return entries + + async def read_bytes(self, key: str) -> bytes: + path = self._full_path(key) + async with aiofiles.open(path, "rb") as f: + return await f.read() + + async def write_bytes(self, key: str, data: bytes, content_type: str | None = None) -> None: + path = self._full_path(key) + path.parent.mkdir(parents=True, exist_ok=True) + async with aiofiles.open(path, "wb") as f: + await f.write(data) + + async def delete(self, key: str) -> None: + path = self._full_path(key) + if not path.exists(): + return + if path.is_dir(): + await self.delete_tree(key) + else: + path.unlink() + + async def delete_tree(self, key: str) -> None: + path = self._full_path(key) + if not path.exists(): + return + await asyncio.to_thread(_local_delete_tree, path) + + async def stat(self, key: str) -> StorageEntry: + path = self._full_path(key) + stat = path.stat() + file_hash = "" + version_id = _local_version_token(stat, None) + if path.is_file(): + data = await self.read_bytes(key) + file_hash = content_hash_bytes(data) + version_id = _local_version_token(stat, file_hash) + return StorageEntry( + name=path.name, + key=normalize_storage_key(key), + is_dir=path.is_dir(), + size=stat.st_size if path.is_file() else 0, + modified_at=str(stat.st_mtime), + version_id=version_id, + etag=file_hash, + content_hash=file_hash, + ) + + async def get_version(self, key: str) -> StorageVersion: + path = self._full_path(key) + if not path.exists(): + return StorageVersion(key=normalize_storage_key(key), exists=False, is_dir=False) + stat = path.stat() + if path.is_dir(): + return StorageVersion( + key=normalize_storage_key(key), + exists=True, + is_dir=True, + modified_at=str(stat.st_mtime), + version_id=_local_version_token(stat, None), + ) + data = await self.read_bytes(key) + file_hash = content_hash_bytes(data) + return StorageVersion( + key=normalize_storage_key(key), + exists=True, + is_dir=False, + size=stat.st_size, + modified_at=str(stat.st_mtime), + etag=file_hash, + version_id=_local_version_token(stat, file_hash), + content_hash=file_hash, + ) + + async def write_bytes_if_match( + self, + key: str, + data: bytes, + *, + condition: WriteCondition | None = None, + content_type: str | None = None, + ) -> ConditionalWriteResult: + current = await self.get_version(key) + if condition: + if condition.require_absent and current.exists: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + if condition.version_token is not None and current.token != condition.version_token: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + await self.write_bytes(key, data, content_type=content_type) + return ConditionalWriteResult(ok=True, current_version=await self.get_version(key)) + + async def local_path_for(self, key: str) -> Path | None: + return self._full_path(key) + + +def _local_delete_tree(path: Path) -> None: + import shutil + + shutil.rmtree(path) + + +def _local_version_token(stat, file_hash: str | None) -> str: + hash_part = file_hash or "" + return f"{stat.st_mtime_ns}:{stat.st_size}:{hash_part}" diff --git a/backend/app/services/storage_runtime/s3.py b/backend/app/services/storage_runtime/s3.py new file mode 100644 index 000000000..808565a2b --- /dev/null +++ b/backend/app/services/storage_runtime/s3.py @@ -0,0 +1,320 @@ +"""S3-compatible object storage backend.""" + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from pathlib import Path +from tempfile import NamedTemporaryFile +from typing import Any + +from loguru import logger + +from app.services.storage_runtime.base import ( + ConditionalWriteResult, + StorageBackend, + StorageEntry, + StorageVersion, + WriteCondition, +) +from app.services.storage_runtime.utils import normalize_storage_key + + +class S3StorageBackend(StorageBackend): + def __init__( + self, + *, + bucket: str, + prefix: str = "", + region: str = "", + endpoint_url: str = "", + access_key_id: str = "", + secret_access_key: str = "", + presign_ttl_seconds: int = 3600, + max_pool_connections: int = 50, + write_workers: int = 32, + ): + self.bucket = bucket + self.prefix = normalize_storage_key(prefix) + self.region = region + self.endpoint_url = endpoint_url or None + self.access_key_id = access_key_id or None + self.secret_access_key = secret_access_key or None + self.presign_ttl_seconds = presign_ttl_seconds + self.max_pool_connections = max_pool_connections + self._client: Any | None = None + self._aioboto3_session: Any | None = None + + def _object_key(self, key: str) -> str: + normalized = normalize_storage_key(key) + return f"{self.prefix}/{normalized}" if self.prefix else normalized + + def _client_or_raise(self): + if self._client is None: + try: + import boto3 + from botocore.config import Config + except ImportError as exc: + raise RuntimeError("boto3 is required for S3 storage backend") from exc + self._client = boto3.client( + "s3", + region_name=self.region or None, + endpoint_url=self.endpoint_url, + aws_access_key_id=self.access_key_id, + aws_secret_access_key=self.secret_access_key, + config=Config( + max_pool_connections=self.max_pool_connections, + proxies={}, + s3={"addressing_style": "path"}, + signature_version="s3v4", + connect_timeout=5, + read_timeout=30, + tcp_keepalive=True, + ), + ) + return self._client + + @asynccontextmanager + async def _async_client(self): + """Shared aioboto3 session with aiohttp connection pool — reuses connections but detects stale ones correctly.""" + try: + import aioboto3 + from botocore.config import Config + except ImportError as exc: + raise RuntimeError("aioboto3 is required for async S3 writes") from exc + if self._aioboto3_session is None: + self._aioboto3_session = aioboto3.Session() + async with self._aioboto3_session.client( + "s3", + region_name=self.region or None, + endpoint_url=self.endpoint_url, + aws_access_key_id=self.access_key_id, + aws_secret_access_key=self.secret_access_key, + config=Config( + max_pool_connections=self.max_pool_connections, + proxies={}, + s3={"addressing_style": "path"}, + signature_version="s3v4", + connect_timeout=5, + read_timeout=30, + tcp_keepalive=True, + ), + ) as client: + yield client + + async def exists(self, key: str) -> bool: + return await self._object_exists(key) + + async def is_file(self, key: str) -> bool: + return await self._object_exists(key) + + async def _object_exists(self, key: str) -> bool: + object_key = self._object_key(key) + client = self._client_or_raise() + response = await asyncio.to_thread( + client.list_objects_v2, + Bucket=self.bucket, + Prefix=object_key, + MaxKeys=1, + ) + return any(item.get("Key") == object_key for item in response.get("Contents", [])) + + async def is_dir(self, key: str) -> bool: + prefix = self._object_key(key).rstrip("/") + "/" + client = self._client_or_raise() + response = await asyncio.to_thread( + client.list_objects_v2, + Bucket=self.bucket, + Prefix=prefix, + Delimiter="/", + MaxKeys=1, + ) + return bool(response.get("Contents") or response.get("CommonPrefixes")) + + async def list_dir(self, key: str) -> list[StorageEntry]: + prefix = self._object_key(key).rstrip("/") + if prefix: + prefix += "/" + client = self._client_or_raise() + response = await asyncio.to_thread( + client.list_objects_v2, + Bucket=self.bucket, + Prefix=prefix, + Delimiter="/", + ) + entries: list[StorageEntry] = [] + for item in response.get("CommonPrefixes", []): + raw = item.get("Prefix", "").rstrip("/") + rel = _strip_prefix(raw, self.prefix) + name = rel.split("/")[-1] + entries.append(StorageEntry(name=name, key=rel, is_dir=True)) + for item in response.get("Contents", []): + raw = item.get("Key", "") + if not raw or raw == prefix: + continue + rel = _strip_prefix(raw, self.prefix) + name = rel.split("/")[-1] + entries.append( + StorageEntry( + name=name, + key=rel, + is_dir=False, + size=int(item.get("Size", 0)), + modified_at=str(item.get("LastModified") or ""), + etag=_clean_etag(item.get("ETag")), + ) + ) + return sorted(entries, key=lambda entry: (not entry.is_dir, entry.name)) + + async def read_bytes(self, key: str) -> bytes: + client = self._client_or_raise() + response = await asyncio.to_thread( + client.get_object, + Bucket=self.bucket, + Key=self._object_key(key), + ) + body = response["Body"] + return await asyncio.to_thread(body.read) + + async def write_bytes(self, key: str, data: bytes, content_type: str | None = None) -> None: + kwargs: dict[str, Any] = { + "Bucket": self.bucket, + "Key": self._object_key(key), + "Body": data, + } + if content_type: + kwargs["ContentType"] = content_type + async with self._async_client() as client: + await client.put_object(**kwargs) + + async def delete(self, key: str) -> None: + async with self._async_client() as client: + await client.delete_object( + Bucket=self.bucket, + Key=self._object_key(key), + ) + + async def delete_tree(self, key: str) -> None: + client = self._client_or_raise() + prefix = self._object_key(key).rstrip("/") + "/" + response = await asyncio.to_thread( + client.list_objects_v2, + Bucket=self.bucket, + Prefix=prefix, + ) + contents = response.get("Contents", []) + if not contents: + return + objects = [{"Key": item["Key"]} for item in contents] + async with self._async_client() as client: + await client.delete_objects( + Bucket=self.bucket, + Delete={"Objects": objects}, + ) + + async def stat(self, key: str) -> StorageEntry: + version = await self.get_version(key) + if not version.exists: + raise FileNotFoundError(key) + return StorageEntry( + name=normalize_storage_key(key).split("/")[-1], + key=normalize_storage_key(key), + is_dir=version.is_dir, + size=version.size, + modified_at=version.modified_at, + etag=version.etag, + version_id=version.version_id, + content_hash=version.content_hash, + ) + + async def get_version(self, key: str) -> StorageVersion: + client = self._client_or_raise() + object_key = self._object_key(key) + try: + response = await asyncio.to_thread( + client.head_object, + Bucket=self.bucket, + Key=object_key, + ) + except Exception: + return StorageVersion(key=normalize_storage_key(key), exists=False, is_dir=False) + return StorageVersion( + key=normalize_storage_key(key), + exists=True, + is_dir=False, + size=int(response.get("ContentLength", 0)), + modified_at=str(response.get("LastModified") or ""), + etag=_clean_etag(response.get("ETag")), + version_id=str(response.get("VersionId") or ""), + content_hash=_clean_etag(response.get("ETag")), + ) + + async def write_bytes_if_match( + self, + key: str, + data: bytes, + *, + condition: WriteCondition | None = None, + content_type: str | None = None, + ) -> ConditionalWriteResult: + current = await self.get_version(key) + if condition: + if condition.require_absent and current.exists: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + if condition.version_token is not None and current.token != condition.version_token: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + await self.write_bytes(key, data, content_type=content_type) + return ConditionalWriteResult(ok=True, current_version=await self.get_version(key)) + + async def _put_succeeded(self, key: str, expected_size: int) -> bool: + try: + entry = await self.stat(key) + except Exception: + return False + return entry.size == expected_size + + async def local_path_for(self, key: str) -> Path | None: + suffix = Path(normalize_storage_key(key)).suffix + tmp = NamedTemporaryFile(delete=False, suffix=suffix) + tmp.close() + path = Path(tmp.name) + await self.write_local_copy(key, path) + return path + + async def write_local_copy(self, key: str, path: Path) -> None: + data = await self.read_bytes(key) + await asyncio.to_thread(path.write_bytes, data) + + async def presign_download_url(self, key: str, filename: str | None = None, inline: bool = False) -> str | None: + client = self._client_or_raise() + params: dict[str, Any] = {"Bucket": self.bucket, "Key": self._object_key(key)} + if filename: + disposition = "inline" if inline else "attachment" + params["ResponseContentDisposition"] = f'{disposition}; filename="{filename}"' + return await asyncio.to_thread( + client.generate_presigned_url, + "get_object", + Params=params, + ExpiresIn=self.presign_ttl_seconds, + ) + + +def _strip_prefix(raw_key: str, prefix: str) -> str: + if prefix and raw_key.startswith(prefix + "/"): + return raw_key[len(prefix) + 1:] + return raw_key + + +def _is_header_parsing_error(exc: Exception) -> bool: + try: + from urllib3.exceptions import HeaderParsingError + except Exception: + return False + return isinstance(exc, HeaderParsingError) + + +def _clean_etag(raw: Any) -> str: + if raw is None: + return "" + text = str(raw) + return text.strip('"') diff --git a/backend/app/services/storage_runtime/utils.py b/backend/app/services/storage_runtime/utils.py new file mode 100644 index 000000000..98634ba43 --- /dev/null +++ b/backend/app/services/storage_runtime/utils.py @@ -0,0 +1,24 @@ +"""Storage path helpers.""" + + +def normalize_storage_key(key: str) -> str: + """Normalize a storage key and reject traversal semantics.""" + clean = (key or "").replace("\\", "/").strip().lstrip("/") + parts: list[str] = [] + for part in clean.split("/"): + if part in ("", "."): + continue + if part == "..": + if parts: + parts.pop() + continue + parts.append(part) + return "/".join(parts) + + +def agent_storage_prefix(agent_id: str) -> str: + return normalize_storage_key(agent_id) + + +def tenant_storage_prefix(tenant_id: str) -> str: + return normalize_storage_key(f"enterprise_info_{tenant_id}") diff --git a/backend/app/services/trigger_daemon.py b/backend/app/services/trigger_daemon.py index 0e9e3df10..de9531471 100644 --- a/backend/app/services/trigger_daemon.py +++ b/backend/app/services/trigger_daemon.py @@ -1,29 +1,33 @@ -"""Trigger Daemon — evaluates all agent triggers in a single background loop. +"""Trigger daemon orchestrator. -Replaces the separate heartbeat, scheduler, and supervision reminder services -with a unified trigger evaluation engine. Runs as an asyncio background task. - -Every 15 seconds: - 1. Load all enabled triggers from DB - 2. Evaluate each trigger (cron/once/interval/poll/on_message/webhook) - 3. Group fired triggers by agent_id (30s dedup window) - 4. Invoke each agent once with all its fired triggers as context +Trigger-specific evaluation and invocation behavior now lives under +`app.services.trigger_runtime`. This module owns the main loop, dedup window, +and distributed claim/invoke flow. """ import asyncio -import ipaddress -import json as _json import uuid from datetime import datetime, timezone, timedelta -from urllib.parse import urlparse - -from croniter import croniter from loguru import logger from sqlalchemy import select from app.database import async_session from app.models.trigger import AgentTrigger -from app.models.agent import Agent +from app.services.trigger_runtime.evaluator import ( + evaluate_trigger as evaluate_trigger_runtime, + handle_okr_collection_trigger as handle_okr_collection_trigger_runtime, + handle_okr_report_trigger as handle_okr_report_trigger_runtime, + mark_trigger_fired as mark_trigger_fired_runtime, + mark_trigger_skipped as mark_trigger_skipped_runtime, + should_skip_non_workday as should_skip_non_workday_runtime, +) +from app.services.trigger_runtime.invoker import invoke_agent_for_triggers as invoke_agent_for_triggers_runtime +from app.services.trigger_runtime import ( + claim_ready_trigger_invocations, + enqueue_due_trigger, + mark_trigger_executions_completed, + mark_trigger_executions_failed, +) TICK_INTERVAL = 15 # seconds DEDUP_WINDOW = 30 # seconds — same agent won't be invoked twice within this window @@ -45,922 +49,29 @@ def _cleanup_stale_invoke_cache(): async def _should_skip_non_workday(trigger: AgentTrigger, local_now: datetime) -> bool: - """Skip OKR daily report triggers on company non-workdays when configured.""" - if trigger.name != "daily_okr_collection": - return False - - from app.models.okr import OKRSettings - from app.models.tenant import Tenant - from app.services.business_calendar import is_non_workday - - async with async_session() as db: - result = await db.execute( - select(Agent.tenant_id) - .where(Agent.id == trigger.agent_id) - ) - tenant_id = result.scalar_one_or_none() - if not tenant_id: - return False - - settings_result = await db.execute( - select(OKRSettings.daily_report_skip_non_workdays) - .where(OKRSettings.tenant_id == tenant_id) - ) - skip_enabled = settings_result.scalar_one_or_none() - if skip_enabled is False: - return False - - tenant_result = await db.execute( - select(Tenant.country_region).where(Tenant.id == tenant_id) - ) - country_region = tenant_result.scalar_one_or_none() - - return is_non_workday(local_now.date(), country_region) + return await should_skip_non_workday_runtime(trigger, local_now) async def _mark_trigger_skipped(trigger_id: uuid.UUID, now: datetime) -> None: - """Advance a cron trigger without invoking the agent.""" - try: - async with async_session() as db: - result = await db.execute(select(AgentTrigger).where(AgentTrigger.id == trigger_id)) - trigger = result.scalar_one_or_none() - if trigger: - trigger.last_fired_at = now - await db.commit() - except Exception as e: - logger.warning(f"Failed to mark skipped trigger {trigger_id}: {e}") + await mark_trigger_skipped_runtime(trigger_id, now) async def _mark_trigger_fired(trigger_id: uuid.UUID, now: datetime) -> None: - """Persist fire metadata for a trigger that was already handled.""" - try: - async with async_session() as db: - result = await db.execute(select(AgentTrigger).where(AgentTrigger.id == trigger_id)) - trigger = result.scalar_one_or_none() - if trigger: - trigger.last_fired_at = now - trigger.fire_count += 1 - if trigger.type == "once": - trigger.is_enabled = False - if trigger.max_fires and trigger.fire_count >= trigger.max_fires: - trigger.is_enabled = False - await db.commit() - except Exception as e: - logger.warning(f"Failed to mark fired trigger {trigger_id}: {e}") + await mark_trigger_fired_runtime(trigger_id, now) async def _handle_okr_report_trigger(trigger: AgentTrigger, now: datetime) -> bool: - """Handle company-level OKR report generation without waking the agent.""" - if trigger.name not in {"daily_okr_report", "weekly_okr_report", "monthly_okr_report"}: - return False - - from zoneinfo import ZoneInfo - from app.models.okr import OKRSettings - from app.services.okr_reporting import ( - generate_company_daily_report, - generate_company_monthly_report, - generate_company_weekly_report, - ) - from app.services.timezone_utils import get_agent_timezone - - async with async_session() as db: - agent_result = await db.execute(select(Agent.tenant_id).where(Agent.id == trigger.agent_id)) - tenant_id = agent_result.scalar_one_or_none() - if not tenant_id: - return True - - settings_result = await db.execute(select(OKRSettings).where(OKRSettings.tenant_id == tenant_id)) - settings = settings_result.scalar_one_or_none() - if not settings or not settings.enabled: - return True - - tz_name = await get_agent_timezone(trigger.agent_id) - try: - tz = ZoneInfo(tz_name) - except Exception: - tz = ZoneInfo("UTC") - local_today = now.astimezone(tz).date() - - if trigger.name == "daily_okr_report": - await generate_company_daily_report(tenant_id, local_today - timedelta(days=1)) - elif trigger.name == "weekly_okr_report": - previous_week_anchor = local_today - timedelta(days=7) - week_start = previous_week_anchor - timedelta(days=previous_week_anchor.weekday()) - await generate_company_weekly_report(tenant_id, week_start) - elif trigger.name == "monthly_okr_report": - previous_month_end = local_today.replace(day=1) - timedelta(days=1) - await generate_company_monthly_report(tenant_id, previous_month_end) - - await _mark_trigger_fired(trigger.id, now) - logger.info(f"[Trigger] Auto-generated OKR report for trigger {trigger.name}") - return True + return await handle_okr_report_trigger_runtime(trigger, now) async def _handle_okr_collection_trigger(trigger: AgentTrigger, now: datetime) -> bool: - """Handle deterministic OKR daily collection without relying on a free-form LLM plan.""" - if trigger.name != "daily_okr_collection": - return False - - from app.models.okr import OKRSettings - from app.services.okr_daily_collection import trigger_daily_collection_for_tenant - - async with async_session() as db: - agent_result = await db.execute(select(Agent.tenant_id).where(Agent.id == trigger.agent_id)) - tenant_id = agent_result.scalar_one_or_none() - if not tenant_id: - return True - - settings_result = await db.execute(select(OKRSettings).where(OKRSettings.tenant_id == tenant_id)) - settings = settings_result.scalar_one_or_none() - if not settings or not settings.enabled or not settings.daily_report_enabled: - return True - - await trigger_daily_collection_for_tenant(tenant_id) - await _mark_trigger_fired(trigger.id, now) - logger.info(f"[Trigger] Deterministic OKR collection sent for trigger {trigger.name}") - return True - -# Webhook rate limiter: token -> list of timestamps -_webhook_hits: dict[str, list[float]] = {} -WEBHOOK_RATE_LIMIT = 5 # max hits per minute per token - - -# ── SSRF Protection ───────────────────────────────────────────────── - -def _is_private_url(url: str) -> bool: - """Block private/internal URLs to prevent SSRF attacks.""" - try: - parsed = urlparse(url) - hostname = parsed.hostname - if not hostname: - return True - - # Block obvious private hostnames - if hostname in ("localhost", "127.0.0.1", "::1", "0.0.0.0"): - return True - - # Try to resolve hostname and check IP - import socket - try: - infos = socket.getaddrinfo(hostname, None) - for info in infos: - ip = ipaddress.ip_address(info[4][0]) - if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: - return True - except (socket.gaierror, ValueError): - return True # Cannot resolve = block - - return False - except Exception: - return True # Block on any parsing error - - -# ── Trigger Evaluation ────────────────────────────────────────────── + return await handle_okr_collection_trigger_runtime(trigger, now) async def _evaluate_trigger(trigger: AgentTrigger, now: datetime) -> bool: - """Return True if this trigger should fire right now.""" - if not trigger.is_enabled: - return False - if trigger.expires_at and now >= trigger.expires_at: - # Auto-disable expired triggers - return False - if trigger.max_fires is not None and trigger.fire_count >= trigger.max_fires: - return False - - # Cooldown check - if trigger.last_fired_at: - cooldown = timedelta(seconds=trigger.cooldown_seconds) - if (now - trigger.last_fired_at) < cooldown: - return False - - cfg = trigger.config or {} - t = trigger.type - - if t == "cron": - expr = cfg.get("expr", "* * * * *") - base = trigger.last_fired_at or trigger.created_at - try: - # Resolve timezone: trigger config → agent → tenant → UTC - tz_name = cfg.get("timezone") - if not tz_name: - from app.services.timezone_utils import get_agent_timezone - tz_name = await get_agent_timezone(trigger.agent_id) - from zoneinfo import ZoneInfo - try: - tz = ZoneInfo(tz_name) - except (KeyError, Exception): - tz = ZoneInfo("UTC") - # Evaluate cron in agent's timezone - local_now = now.astimezone(tz) - local_base = base.astimezone(tz) if base.tzinfo else base.replace(tzinfo=tz) - cron = croniter(expr, local_base) - next_run = cron.get_next(datetime) - if local_now >= next_run: - if await _should_skip_non_workday(trigger, local_now): - await _mark_trigger_skipped(trigger.id, now) - logger.info(f"[Trigger] Skipped {trigger.name} on non-workday {local_now.date()}") - return False - return True - return False - except Exception as e: - logger.warning(f"Invalid cron expr '{expr}' for trigger {trigger.name}: {e}") - return False - - elif t == "once": - at_str = cfg.get("at") - if not at_str: - return False - try: - at = datetime.fromisoformat(at_str) - if at.tzinfo is None: - at = at.replace(tzinfo=timezone.utc) - return now >= at and trigger.fire_count == 0 - except Exception: - return False - - elif t == "interval": - minutes = cfg.get("minutes", 30) - base = trigger.last_fired_at or trigger.created_at - return (now - base) >= timedelta(minutes=minutes) - - elif t == "poll": - interval_min = max(cfg.get("interval_min", 5), MIN_POLL_INTERVAL_MINUTES) - base = trigger.last_fired_at or trigger.created_at - if (now - base) < timedelta(minutes=interval_min): - return False - # Actual HTTP poll + change detection - return await _poll_check(trigger) - - elif t == "on_message": - return await _check_new_agent_messages(trigger) - - elif t == "webhook": - # Check if a webhook payload is pending - if cfg.get("_webhook_pending"): - return True - return False - - return False - - -async def _poll_check(trigger: AgentTrigger) -> bool: - """HTTP poll: fetch URL, extract value via json_path, detect change. - - Persists _last_value into the trigger's config JSONB so it survives - across process restarts. - """ - import httpx - cfg = trigger.config or {} - url = cfg.get("url") - if not url: - return False - - # SSRF protection: block private/internal URLs - if _is_private_url(url): - logger.warning(f"Poll blocked for trigger {trigger.name}: private/internal URL '{url}'") - return False - - try: - async with httpx.AsyncClient(timeout=10) as client: - resp = await client.request(cfg.get("method", "GET"), url, headers=cfg.get("headers", {})) - resp.raise_for_status() - - data = resp.json() - json_path = cfg.get("json_path", "$") - current_value = _extract_json_path(data, json_path) - current_str = str(current_value) - - fire_on = cfg.get("fire_on", "change") - should_fire = False - - if fire_on == "match": - should_fire = current_str == str(cfg.get("match_value", "")) - else: # "change" - last_value = cfg.get("_last_value") - # First poll — don't fire, just record baseline - if last_value is None: - should_fire = False - else: - should_fire = current_str != last_value - - # Persist _last_value to DB so it survives restarts - cfg["_last_value"] = current_str - try: - from sqlalchemy import update - async with async_session() as db: - await db.execute( - update(AgentTrigger) - .where(AgentTrigger.id == trigger.id) - .values(config=cfg) - ) - await db.commit() - except Exception as e: - logger.warning(f"Failed to persist poll _last_value for {trigger.name}: {e}") - - return should_fire - - except Exception as e: - logger.warning(f"Poll failed for trigger {trigger.name}: {e}") - return False - - -def _extract_json_path(data, path: str): - """Simple JSONPath extraction: $.key.subkey → data['key']['subkey'].""" - if path == "$" or not path: - return data - parts = path.lstrip("$.").split(".") - current = data - for part in parts: - if isinstance(current, dict): - current = current.get(part) - elif isinstance(current, list) and part.isdigit(): - current = current[int(part)] - else: - return None - return current - - -async def _check_new_agent_messages(trigger: AgentTrigger) -> bool: - """Check if there are new messages matching this trigger. - - Supports two modes: - - from_agent_name: check for agent-to-agent messages - - from_user_name: check for human user messages (Feishu/Slack/Discord) - - Stores the actual message content in trigger.config['_matched_message'] - so the invocation context can include it. - """ - from app.models.audit import ChatMessage - from app.models.chat_session import ChatSession - - cfg = trigger.config or {} - from_agent_name = cfg.get("from_agent_name") - from_user_name = cfg.get("from_user_name") - - if not from_agent_name and not from_user_name: - return False - - since = trigger.last_fired_at or trigger.created_at - # Use _since_ts snapshot from trigger creation (set by _handle_set_trigger) - # This is more precise than the old 5-minute lookback which caused false positives - if trigger.fire_count == 0 and not trigger.last_fired_at: - since_ts_str = cfg.get("_since_ts") - if since_ts_str: - try: - since = datetime.fromisoformat(since_ts_str) - except Exception: - since = trigger.created_at - # No _since_ts and no last_fired_at → use trigger.created_at (no lookback) - - try: - async with async_session() as db: - if from_agent_name: - # --- Agent-to-agent message check (existing logic) --- - from app.models.participant import Participant - from app.models.agent import Agent as AgentModel - safe_agent_name = from_agent_name.replace("%", "").replace("_", r"\_") - agent_r = await db.execute( - select(AgentModel).where(AgentModel.name.ilike(f"%{safe_agent_name}%")) - ) - source_agent = agent_r.scalars().first() - if not source_agent: - return False - - result = await db.execute( - select(Participant.id).where( - Participant.type == "agent", - Participant.ref_id == source_agent.id, - ) - ) - from_participant = result.scalar_one_or_none() - if not from_participant: - return False - - from sqlalchemy import cast as sa_cast, String as SaString - result = await db.execute( - select(ChatMessage).join( - ChatSession, ChatMessage.conversation_id == sa_cast(ChatSession.id, SaString) - ).where( - ChatMessage.participant_id == from_participant, - ChatMessage.created_at > since, - ChatMessage.role == "assistant", - ).order_by(ChatMessage.created_at.desc()).limit(1) - ) - msg = result.scalar_one_or_none() - if not msg: - return False - cfg["_matched_message"] = (msg.content or "")[:2000] - cfg["_matched_from"] = from_agent_name - return True - - elif from_user_name: - # --- Human user message check (Feishu/Slack/Discord) --- - # Find sessions for this agent from external channels - from sqlalchemy import cast as sa_cast, String as SaString - from app.models.user import User - from app.models.agent import Agent as AgentModel - - # 0. Get agent for tenant scoping - agent_r = await db.execute(select(AgentModel).where(AgentModel.id == trigger.agent_id)) - agent = agent_r.scalar_one_or_none() - - # Look up user by display name or username within tenant - from sqlalchemy import or_ - from app.models.user import User, Identity - safe_user_name = from_user_name.replace("%", "").replace("_", r"\_") - query = ( - select(User) - .join(User.identity) - .where( - or_( - User.display_name.ilike(f"%{safe_user_name}%"), - Identity.username.ilike(f"%{safe_user_name}%"), - ) - ) - ) - if agent and agent.tenant_id: - query = query.where(User.tenant_id == agent.tenant_id) - - user_r = await db.execute(query) - target_user = user_r.scalars().first() - - if target_user: - # Find channel sessions for this user with this agent - result = await db.execute( - select(ChatMessage).join( - ChatSession, ChatMessage.conversation_id == sa_cast(ChatSession.id, SaString) - ).where( - ChatSession.agent_id == trigger.agent_id, - ChatSession.user_id == target_user.id, - ChatSession.source_channel.in_(["feishu", "slack", "discord", "web"]), - ChatMessage.role == "user", - ChatMessage.created_at > since, - ).order_by(ChatMessage.created_at.desc()).limit(1) - ) - else: - # Fallback: search by session title or message content containing the target name - result = await db.execute( - select(ChatMessage).join( - ChatSession, ChatMessage.conversation_id == sa_cast(ChatSession.id, SaString) - ).where( - ChatSession.agent_id == trigger.agent_id, - ChatSession.source_channel.in_(["feishu", "slack", "discord", "web"]), - ChatMessage.role == "user", - ChatMessage.created_at > since, - or_( - ChatSession.title.ilike(f"%{safe_user_name}%"), - ChatMessage.content.ilike(f"%{safe_user_name}%"), - ), - ).order_by(ChatMessage.created_at.desc()).limit(1) - ) - - msg = result.scalar_one_or_none() - if not msg: - return False - cfg["_matched_message"] = (msg.content or "")[:2000] - cfg["_matched_from"] = from_user_name - return True - - except Exception as e: - logger.warning(f"on_message check failed for trigger {trigger.name}: {e}") - return False - - -# ── Agent Invocation ──────────────────────────────────────────────── - -async def _resolve_trigger_delivery_target(agent: Agent, triggers: list[AgentTrigger]) -> dict | None: - """Resolve where a trigger result should be delivered. - - Priority: - 1. Explicit A2A callback session - 2. Originating agent-to-agent session - 3. Originating platform user → that user's primary platform session - 4. Pure trigger/reflection context → no user-facing delivery - """ - from app.models.chat_session import ChatSession - from app.services.chat_session_service import ensure_primary_platform_session - - # Synthetic A2A wake triggers already carry the callback session explicitly. - for trigger in triggers: - cfg = trigger.config or {} - a2a_sid = cfg.get("_a2a_session_id") - if a2a_sid: - try: - async with async_session() as db: - session = await db.get(ChatSession, uuid.UUID(a2a_sid)) - if not session: - return None - return { - "kind": "session", - "session_id": str(session.id), - "owner_user_id": str(session.user_id), - "source_channel": session.source_channel, - } - except Exception: - return None - - origin_cfg = None - for trigger in triggers: - cfg = trigger.config or {} - if cfg.get("_origin_session_id") or cfg.get("_origin_user_id"): - origin_cfg = cfg - break - - if not origin_cfg: - return None - - origin_source_channel = origin_cfg.get("_origin_source_channel") - origin_session_id = origin_cfg.get("_origin_session_id") - origin_user_id = origin_cfg.get("_origin_user_id") - - if origin_source_channel == "agent" and origin_session_id: - try: - async with async_session() as db: - session = await db.get(ChatSession, uuid.UUID(origin_session_id)) - if not session: - return None - return { - "kind": "session", - "session_id": str(session.id), - "owner_user_id": str(session.user_id), - "source_channel": "agent", - } - except Exception: - return None - - if origin_source_channel != "trigger" and origin_user_id: - try: - async with async_session() as db: - primary = await ensure_primary_platform_session( - db, - agent.id, - uuid.UUID(origin_user_id), - ) - await db.commit() - return { - "kind": "primary_user_session", - "session_id": str(primary.id), - "owner_user_id": str(primary.user_id), - "source_channel": primary.source_channel, - } - except Exception: - return None - - return None + return await evaluate_trigger_runtime(trigger, now) async def _invoke_agent_for_triggers(agent_id: uuid.UUID, triggers: list[AgentTrigger]): - """Invoke an agent with context from one or more fired triggers. - - Creates a Reflection Session and calls the LLM. - """ - from app.services.llm import call_llm - from app.services.agent_context import build_agent_context - from app.models.llm import LLMModel - from app.models.audit import ChatMessage - from app.models.chat_session import ChatSession - from app.models.participant import Participant - from app.services.audit_logger import write_audit_log - - try: - async with async_session() as db: - # Load agent - result = await db.execute(select(Agent).where(Agent.id == agent_id)) - agent = result.scalar_one_or_none() - if not agent or agent.is_expired: - return - - # Load LLM model - if not agent.primary_model_id: - logger.warning(f"Agent {agent.name} has no LLM model, skipping trigger invocation") - return - result = await db.execute(select(LLMModel).where(LLMModel.id == agent.primary_model_id)) - model = result.scalar_one_or_none() - if not model: - return - # Skip invocation if model is disabled by admin - if not model.enabled: - logger.warning(f"Agent {agent.name}'s model {model.model} is disabled, skipping trigger invocation") - return - - # Build trigger context. Keep this model-facing prompt in English so - # autonomous wakeups behave consistently across UI locales. - context_parts = [] - trigger_names = [] - for t in triggers: - part = f"Trigger: {t.name} ({t.type})\nReason: {t.reason}" - if t.name == "daily_okr_collection": - part += ( - "\nExecution requirements: First call get_okr_settings to confirm whether daily report collection is enabled. " - "If it is enabled, only contact members and digital employees in your relationship network to collect today's final daily reports, " - "then organize them into a formal daily report no longer than 2000 characters. " - "If it is disabled, state that no action is needed and stop." - ) - elif t.name in ("daily_okr_report", "weekly_okr_report", "monthly_okr_report"): - part += ( - "\nExecution requirements: This company-level report is generated automatically by the system. " - "If you are awakened, only add necessary clarification. Do not start another member collection round." - ) - elif t.name == "biweekly_okr_checkin": - part += ( - "\nExecution requirements: First call get_okr_settings to confirm whether OKR is enabled. " - "If enabled, check the current-cycle company and member OKRs, then proactively remind members who have not set OKRs or whose progress is lagging. " - "If disabled, state that no action is needed and stop." - ) - elif t.name == "monthly_okr_report": - part += ( - "\nExecution requirements: First call get_okr_settings to confirm whether OKR is enabled. " - "If enabled, call generate_monthly_okr_report to generate the OKR monthly report for the month that just ended, " - "then send it to admins or publish it to Plaza. If disabled, state that no action is needed and stop." - ) - if t.focus_ref: - part += f"\nRelated Focus: {t.focus_ref}" - # Include matched message for on_message triggers - cfg = t.config or {} - if t.type == "on_message" and cfg.get("_matched_message"): - part += f"\nMatched message from {cfg.get('_matched_from', '?')}:\n\"{cfg['_matched_message'][:500]}\"" - if t.type == "on_message" and cfg.get("okr_member_id") and cfg.get("okr_report_date"): - part += ( - "\nExecution requirements: This is a daily-report reply ingestion event." - f"\n1. Organize the other party's reply into a final daily report no longer than 2000 characters." - f"\n2. Immediately call upsert_member_daily_report(report_date=\"{cfg['okr_report_date']}\", " - f"member_type=\"{cfg.get('okr_member_type', 'user')}\", " - f"member_id=\"{cfg['okr_member_id']}\", content=\"\")." - "\n3. After the tool call succeeds, send a brief confirmation that you received and recorded it." - "\n4. Do not only confirm without calling the tool, and do not store the raw long conversation verbatim as the daily report." - ) - # Include webhook payload - if t.type == "webhook" and cfg.get("_webhook_payload"): - payload_str = cfg["_webhook_payload"] - if len(payload_str) > 2000: - payload_str = payload_str[:2000] + "... (truncated)" - part += f"\nWebhook Payload:\n{payload_str}" - context_parts.append(part) - trigger_names.append(t.name) - - trigger_context = ( - "===== Wake Context =====\n" - f"Wake source: trigger ({'multiple triggers fired together' if len(triggers) > 1 else 'trigger fired'})\n\n" - + "\n---\n".join(context_parts) - + "\n========================" - ) - - # Create Reflection Session - title = f"🤖 Reflection: {', '.join(trigger_names)}" - # Find agent's participant - result = await db.execute( - select(Participant).where(Participant.type == "agent", Participant.ref_id == agent_id) - ) - agent_participant = result.scalar_one_or_none() - - session = ChatSession( - agent_id=agent_id, - user_id=agent.creator_id, - participant_id=agent_participant.id if agent_participant else None, - source_channel="trigger", - title=title[:200], - ) - db.add(session) - await db.flush() - session_id = session.id - - # Messages: trigger context only (call_llm builds system prompt internally) - messages = [ - {"role": "user", "content": trigger_context}, - ] - - # Store trigger context as a message in the session - db.add(ChatMessage( - agent_id=agent_id, - conversation_id=str(session_id), - role="user", - content=trigger_context, - user_id=agent.creator_id, - participant_id=agent_participant.id if agent_participant else None, - )) - await db.commit() - # Cache participant ID for callbacks - agent_participant_id = agent_participant.id if agent_participant else None - - # Call LLM (outside the DB session to avoid long transactions) - collected_content = [] - delivered_platform_message_via_tool = False - - async def on_chunk(text): - collected_content.append(text) - - # Persist tool calls into Reflection Session for Reflections visibility - async def on_tool_call(data): - nonlocal delivered_platform_message_via_tool - try: - tool_name = data.get("name") - tool_status = data.get("status") - if tool_status == "done" and tool_name == "send_platform_message": - result_text = str(data.get("result", "")) - if result_text.startswith("✅"): - delivered_platform_message_via_tool = True - - async with async_session() as _tc_db: - if data["status"] == "running": - _tc_db.add(ChatMessage( - agent_id=agent_id, - conversation_id=str(session_id), - role="tool_call", - content=_json.dumps({"name": data["name"], "args": data["args"]}, ensure_ascii=False, default=str), - user_id=agent.creator_id, - participant_id=agent_participant_id, - )) - elif data["status"] == "done": - result_str = str(data.get("result", ""))[:2000] - _tc_db.add(ChatMessage( - agent_id=agent_id, - conversation_id=str(session_id), - role="tool_call", - content=_json.dumps({"name": data["name"], "result": result_str}, ensure_ascii=False, default=str), - user_id=agent.creator_id, - participant_id=agent_participant_id, - )) - await _tc_db.commit() - except Exception as e: - logger.warning(f"Failed to persist tool call for trigger session: {e}") - - reply = await call_llm( - model=model, - messages=messages, - agent_name=agent.name, - role_description=agent.role_description or "", - agent_id=agent_id, - user_id=agent.creator_id, - session_id=str(session_id), - on_chunk=on_chunk, - on_tool_call=on_tool_call, - # A2A wake uses the agent's own max_tool_rounds setting (no override) - ) - - # Save assistant reply to Reflection session - async with async_session() as db: - result = await db.execute( - select(Participant).where(Participant.type == "agent", Participant.ref_id == agent_id) - ) - agent_participant = result.scalar_one_or_none() - - db.add(ChatMessage( - agent_id=agent_id, - conversation_id=str(session_id), - role="assistant", - content=reply or "".join(collected_content), - user_id=agent.creator_id, - participant_id=agent_participant.id if agent_participant else None, - )) - - # NOTE: trigger state (last_fired_at, fire_count, auto-disable) - # is already updated in _tick() BEFORE this task was launched, - # to prevent race-condition duplicate fires. - - await db.commit() - - # Compute final reply text once - final_reply = reply or "".join(collected_content) - - # ── Save reply to A2A session if this was an agent-to-agent wake ── - # This makes the target agent's reply visible in the A2A chat history - for t in triggers: - a2a_sid = (t.config or {}).get("_a2a_session_id") - if a2a_sid and final_reply: - try: - async with async_session() as db: - from app.models.participant import Participant as _P - _p_r = await db.execute(select(_P).where(_P.type == "agent", _P.ref_id == agent_id)) - _p = _p_r.scalar_one_or_none() - db.add(ChatMessage( - agent_id=agent_id, - conversation_id=a2a_sid, - role="assistant", - content=final_reply, - user_id=agent.creator_id, - participant_id=_p.id if _p else None, - )) - # Update session timestamp - from app.models.chat_session import ChatSession as _CS - _cs_r = await db.execute(select(_CS).where(_CS.id == uuid.UUID(a2a_sid))) - _cs = _cs_r.scalar_one_or_none() - if _cs: - _cs.last_message_at = datetime.now(timezone.utc) - await db.commit() - logger.info(f"[A2A] Saved reply to A2A session {a2a_sid}") - except Exception as e: - logger.warning(f"[A2A] Failed to save reply to A2A session {a2a_sid}: {e}") - break # Only save once - - # Route trigger results to a single deterministic destination. Pure reflection/system - # wakes stay inside the reflection session and should not spill into arbitrary user chats. - is_a2a_internal = all(t.name == "a2a_wake" for t in triggers) - delivery_target = None if is_a2a_internal else await _resolve_trigger_delivery_target(agent, triggers) - - if final_reply and delivery_target and not delivered_platform_message_via_tool: - try: - from app.api.websocket import manager as ws_manager - agent_id_str = str(agent_id) - - # Build notification message with trigger badge - trigger_reasons = [] - for t in triggers: - ns = (t.config or {}).get("_notification_summary", "").strip() - if ns: - trigger_reasons.append(ns) - else: - r = (t.reason or "").strip() - if r and len(r) <= 80: - trigger_reasons.append(r) - elif r: - trigger_reasons.append(r[:77] + "...") - summary = trigger_reasons[0] if trigger_reasons else "有新的事件需要处理" - - _is_a2a_wait = any(t.name.startswith("a2a_wait_") for t in triggers) - if _is_a2a_wait: - import re as _re - cleaned = final_reply - _internal_patterns = [ - r'\b(a2a_wait_\w+|a2a_wake)\b', - r'\bwait_?\w+_?(task|reply|followup|meeting|sync|api_key)\w*\b', - r'\bresolve_\w+\b', - r'\bfocus[_ ]?item\b', - r'\btask_delegate\b', - r'\bfocus_ref\b', - r'✅\s*(a2a\w+|wait\w+|触发器\w*|focus\w*).*(?:已取消|已为|保持|活跃|完成状态)[^\n]*', - r'[\-•]\s*(?:触发器|trigger|focus|wait_\w+|a2a\w+).*[^\n]*', - r'(?:触发器|trigger)\s+\S+\s*(?:已取消|保持活跃|已为完成状态|fired)', - r'已静默清理触发器', - r'已静默处理完毕', - r'继续待命[。,]?\s*', - r',?\s*(?:继续)?待命。', - ] - for _pat in _internal_patterns: - cleaned = _re.sub(_pat, '', cleaned, flags=_re.IGNORECASE) - cleaned = _re.sub(r'\n{3,}', '\n\n', cleaned).strip() - cleaned = _re.sub(r'[。,]\s*$', '', cleaned).strip() - if not cleaned: - cleaned = final_reply - else: - cleaned = final_reply - - notification = f"⚡ {summary}\n\n{cleaned}" - - target_session_id = delivery_target["session_id"] - owner_user_id = delivery_target.get("owner_user_id") - - # Save to the resolved destination session for persistence. - async with async_session() as db: - from app.models.chat_session import ChatSession - from app.api.websocket import maybe_mark_session_read_for_active_viewer - - db.add(ChatMessage( - agent_id=agent_id, - conversation_id=target_session_id, - role="assistant", - content=notification, - user_id=agent.creator_id, - )) - session_row = await db.get(ChatSession, uuid.UUID(target_session_id)) - if session_row: - session_row.last_message_at = datetime.now(timezone.utc) - if owner_user_id: - await maybe_mark_session_read_for_active_viewer( - db, - agent_id=agent_id, - session_id=target_session_id, - user_id=uuid.UUID(owner_user_id), - ) - await db.commit() - - payload = { - "type": "trigger_notification", - "content": notification, - "triggers": [t.name for t in triggers], - "session_id": target_session_id, - } - - # Notify only the user who owns the destination session. The frontend will append - # the message only when that exact session is open; otherwise it just refreshes - # unread/session state. - if owner_user_id: - await ws_manager.send_to_user(agent_id_str, owner_user_id, payload) - except Exception as e: - logger.error(f"Failed to push trigger result to WebSocket: {e}") - import traceback - traceback.print_exc() - - # Audit log - await write_audit_log("trigger_fired", { - "agent_name": agent.name, - "triggers": [{"name": t.name, "type": t.type} for t in triggers], - }, agent_id=agent_id) - - logger.info(f"⚡ Triggers fired for {agent.name}: {[t.name for t in triggers]}") - - except Exception as e: - logger.error(f"Failed to invoke agent {agent_id} for triggers: {e}") - import traceback - traceback.print_exc() + await invoke_agent_for_triggers_runtime(agent_id, triggers) # ── Main Tick Loop ────────────────────────────────────────────────── @@ -979,8 +90,8 @@ async def _tick(): return - # Evaluate and group fired triggers by agent - fired_by_agent: dict[uuid.UUID, list[AgentTrigger]] = {} + # Evaluate and enqueue due triggers. Agent invocation happens only after + # executions are claimed through the distributed execution queue. for trigger in all_triggers: # Auto-disable expired triggers if trigger.expires_at and now >= trigger.expires_at: @@ -998,14 +109,22 @@ async def _tick(): if not handled: handled = await _handle_okr_collection_trigger(trigger, now) if not handled: - fired_by_agent.setdefault(trigger.agent_id, []).append(trigger) + await enqueue_due_trigger(trigger, now) except Exception as e: logger.warning(f"Error evaluating trigger {trigger.name}: {e}") + # Claim queued executions with a DB lease so only one worker handles each event. + try: + fired_by_agent, force_invoke_agents = await claim_ready_trigger_invocations(now) + except Exception as e: + logger.warning(f"Failed to claim trigger executions: {e}") + fired_by_agent = {} + force_invoke_agents = set() + # Invoke each agent (with dedup window) for agent_id, agent_triggers in fired_by_agent.items(): last = _last_invoke.get(agent_id) - if last and (now - last).total_seconds() < DEDUP_WINDOW: + if agent_id not in force_invoke_agents and last and (now - last).total_seconds() < DEDUP_WINDOW: continue # Skip — invoked too recently _last_invoke[agent_id] = now @@ -1017,6 +136,8 @@ async def _tick(): try: async with async_session() as db: for t in agent_triggers: + if (t.config or {}).get("_execution_id"): + continue result = await db.execute( select(AgentTrigger).where(AgentTrigger.id == t.id) ) @@ -1027,12 +148,6 @@ async def _tick(): # Auto-disable single-shot types only if trigger.type == "once": trigger.is_enabled = False - if trigger.type == "webhook" and trigger.config: - trigger.config = { - **trigger.config, - "_webhook_pending": False, - "_webhook_payload": None, - } if trigger.max_fires and trigger.fire_count >= trigger.max_fires: trigger.is_enabled = False await db.commit() diff --git a/backend/app/services/trigger_runtime/__init__.py b/backend/app/services/trigger_runtime/__init__.py new file mode 100644 index 000000000..24bf9da1c --- /dev/null +++ b/backend/app/services/trigger_runtime/__init__.py @@ -0,0 +1,30 @@ +"""Distributed trigger runtime helpers.""" + +from app.services.trigger_runtime.dispatch import ( + claim_ready_trigger_invocations, + enqueue_due_trigger, + runtime_execution_payload, +) +from app.services.trigger_runtime.executions import ( + build_execution_runtime_trigger, + claim_pending_trigger_executions, + mark_base_triggers_fired, + mark_trigger_executions_completed, + mark_trigger_executions_failed, +) +from app.services.trigger_runtime.keys import build_scheduled_execution_key +from app.services.trigger_runtime.queue import enqueue_trigger_execution, enqueue_webhook_execution + +__all__ = [ + "build_execution_runtime_trigger", + "build_scheduled_execution_key", + "claim_ready_trigger_invocations", + "claim_pending_trigger_executions", + "enqueue_due_trigger", + "enqueue_trigger_execution", + "enqueue_webhook_execution", + "mark_base_triggers_fired", + "mark_trigger_executions_completed", + "mark_trigger_executions_failed", + "runtime_execution_payload", +] diff --git a/backend/app/services/trigger_runtime/dispatch.py b/backend/app/services/trigger_runtime/dispatch.py new file mode 100644 index 000000000..7d18e83ef --- /dev/null +++ b/backend/app/services/trigger_runtime/dispatch.py @@ -0,0 +1,64 @@ +"""Dispatch helpers for trigger executions.""" + +from __future__ import annotations + +import uuid +from datetime import datetime + +from app.database import async_session +from app.models.trigger import AgentTrigger +from app.services.trigger_runtime.executions import ( + build_execution_runtime_trigger, + claim_pending_trigger_executions, + mark_base_triggers_fired, +) +from app.services.trigger_runtime.keys import build_scheduled_execution_key +from app.services.trigger_runtime.queue import enqueue_trigger_execution + + +def runtime_execution_payload(trigger: AgentTrigger) -> dict: + """Capture ephemeral trigger evaluation context into an execution payload.""" + cfg = trigger.config or {} + payload: dict = {} + for key in ( + "_matched_message", + "_matched_from", + "okr_member_id", + "okr_member_type", + "okr_report_date", + "_notification_summary", + "_origin_session_id", + "_origin_user_id", + "_origin_source_channel", + "_a2a_session_id", + ): + if key in cfg and cfg.get(key) is not None: + payload[key] = cfg.get(key) + return payload + + +async def enqueue_due_trigger(trigger: AgentTrigger, now: datetime) -> None: + async with async_session() as db: + await enqueue_trigger_execution( + db, + trigger=trigger, + source=trigger.type, + idempotency_key=build_scheduled_execution_key(trigger, now), + payload_obj=runtime_execution_payload(trigger), + ) + + +async def claim_ready_trigger_invocations(now: datetime) -> tuple[dict[uuid.UUID, list[AgentTrigger]], set[uuid.UUID]]: + fired_by_agent: dict[uuid.UUID, list[AgentTrigger]] = {} + force_invoke_agents: set[uuid.UUID] = set() + + claimed_executions = await claim_pending_trigger_executions() + if claimed_executions: + await mark_base_triggers_fired([trigger.id for _execution, trigger in claimed_executions], now) + + for execution, trigger in claimed_executions: + runtime_trigger = build_execution_runtime_trigger(trigger, execution) + fired_by_agent.setdefault(trigger.agent_id, []).append(runtime_trigger) + force_invoke_agents.add(trigger.agent_id) + + return fired_by_agent, force_invoke_agents diff --git a/backend/app/services/trigger_runtime/evaluator.py b/backend/app/services/trigger_runtime/evaluator.py new file mode 100644 index 000000000..6a43aa63b --- /dev/null +++ b/backend/app/services/trigger_runtime/evaluator.py @@ -0,0 +1,429 @@ +"""Trigger evaluation and deterministic special-case handlers.""" + +from __future__ import annotations + +import ipaddress +import uuid +from datetime import datetime, timezone, timedelta +from urllib.parse import urlparse + +from croniter import croniter +from loguru import logger +from sqlalchemy import select + +from app.database import async_session +from app.models.agent import Agent +from app.models.trigger import AgentTrigger + +MIN_POLL_INTERVAL_MINUTES = 5 + + +async def should_skip_non_workday(trigger: AgentTrigger, local_now: datetime) -> bool: + if trigger.name != "daily_okr_collection": + return False + + from app.models.okr import OKRSettings + from app.models.tenant import Tenant + from app.services.business_calendar import is_non_workday + + async with async_session() as db: + result = await db.execute( + select(Agent.tenant_id).where(Agent.id == trigger.agent_id) + ) + tenant_id = result.scalar_one_or_none() + if not tenant_id: + return False + + settings_result = await db.execute( + select(OKRSettings.daily_report_skip_non_workdays).where(OKRSettings.tenant_id == tenant_id) + ) + skip_enabled = settings_result.scalar_one_or_none() + if skip_enabled is False: + return False + + tenant_result = await db.execute( + select(Tenant.country_region).where(Tenant.id == tenant_id) + ) + country_region = tenant_result.scalar_one_or_none() + + return is_non_workday(local_now.date(), country_region) + + +async def mark_trigger_skipped(trigger_id: uuid.UUID, now: datetime) -> None: + try: + async with async_session() as db: + result = await db.execute(select(AgentTrigger).where(AgentTrigger.id == trigger_id)) + trigger = result.scalar_one_or_none() + if trigger: + trigger.last_fired_at = now + await db.commit() + except Exception as e: + logger.warning(f"Failed to mark skipped trigger {trigger_id}: {e}") + + +async def mark_trigger_fired(trigger_id: uuid.UUID, now: datetime) -> None: + try: + async with async_session() as db: + result = await db.execute(select(AgentTrigger).where(AgentTrigger.id == trigger_id)) + trigger = result.scalar_one_or_none() + if trigger: + trigger.last_fired_at = now + trigger.fire_count += 1 + if trigger.type == "once": + trigger.is_enabled = False + if trigger.max_fires and trigger.fire_count >= trigger.max_fires: + trigger.is_enabled = False + await db.commit() + except Exception as e: + logger.warning(f"Failed to mark fired trigger {trigger_id}: {e}") + + +async def handle_okr_report_trigger(trigger: AgentTrigger, now: datetime) -> bool: + if trigger.name not in {"daily_okr_report", "weekly_okr_report", "monthly_okr_report"}: + return False + + from zoneinfo import ZoneInfo + from app.models.okr import OKRSettings + from app.services.okr_reporting import ( + generate_company_daily_report, + generate_company_monthly_report, + generate_company_weekly_report, + ) + from app.services.timezone_utils import get_agent_timezone + + async with async_session() as db: + agent_result = await db.execute(select(Agent.tenant_id).where(Agent.id == trigger.agent_id)) + tenant_id = agent_result.scalar_one_or_none() + if not tenant_id: + return True + + settings_result = await db.execute(select(OKRSettings).where(OKRSettings.tenant_id == tenant_id)) + settings = settings_result.scalar_one_or_none() + if not settings or not settings.enabled: + return True + + tz_name = await get_agent_timezone(trigger.agent_id) + try: + tz = ZoneInfo(tz_name) + except Exception: + tz = ZoneInfo("UTC") + local_today = now.astimezone(tz).date() + + if trigger.name == "daily_okr_report": + await generate_company_daily_report(tenant_id, local_today - timedelta(days=1)) + elif trigger.name == "weekly_okr_report": + previous_week_anchor = local_today - timedelta(days=7) + week_start = previous_week_anchor - timedelta(days=previous_week_anchor.weekday()) + await generate_company_weekly_report(tenant_id, week_start) + elif trigger.name == "monthly_okr_report": + previous_month_end = local_today.replace(day=1) - timedelta(days=1) + await generate_company_monthly_report(tenant_id, previous_month_end) + + await mark_trigger_fired(trigger.id, now) + logger.info(f"[Trigger] Auto-generated OKR report for trigger {trigger.name}") + return True + + +async def handle_okr_collection_trigger(trigger: AgentTrigger, now: datetime) -> bool: + if trigger.name != "daily_okr_collection": + return False + + from app.models.okr import OKRSettings + from app.services.okr_daily_collection import trigger_daily_collection_for_tenant + + async with async_session() as db: + agent_result = await db.execute(select(Agent.tenant_id).where(Agent.id == trigger.agent_id)) + tenant_id = agent_result.scalar_one_or_none() + if not tenant_id: + return True + + settings_result = await db.execute(select(OKRSettings).where(OKRSettings.tenant_id == tenant_id)) + settings = settings_result.scalar_one_or_none() + if not settings or not settings.enabled or not settings.daily_report_enabled: + return True + + await trigger_daily_collection_for_tenant(tenant_id) + await mark_trigger_fired(trigger.id, now) + logger.info(f"[Trigger] Deterministic OKR collection sent for trigger {trigger.name}") + return True + + +def is_private_url(url: str) -> bool: + try: + parsed = urlparse(url) + hostname = parsed.hostname + if not hostname: + return True + if hostname in ("localhost", "127.0.0.1", "::1", "0.0.0.0"): + return True + import socket + try: + infos = socket.getaddrinfo(hostname, None) + for info in infos: + ip = ipaddress.ip_address(info[4][0]) + if ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved: + return True + except (socket.gaierror, ValueError): + return True + return False + except Exception: + return True + + +async def evaluate_trigger(trigger: AgentTrigger, now: datetime) -> bool: + if not trigger.is_enabled: + return False + if trigger.expires_at and now >= trigger.expires_at: + return False + if trigger.max_fires is not None and trigger.fire_count >= trigger.max_fires: + return False + + if trigger.last_fired_at: + cooldown = timedelta(seconds=trigger.cooldown_seconds) + if (now - trigger.last_fired_at) < cooldown: + return False + + cfg = trigger.config or {} + t = trigger.type + + if t == "cron": + expr = cfg.get("expr", "* * * * *") + base = trigger.last_fired_at or trigger.created_at + try: + tz_name = cfg.get("timezone") + if not tz_name: + from app.services.timezone_utils import get_agent_timezone + tz_name = await get_agent_timezone(trigger.agent_id) + from zoneinfo import ZoneInfo + try: + tz = ZoneInfo(tz_name) + except (KeyError, Exception): + tz = ZoneInfo("UTC") + local_now = now.astimezone(tz) + local_base = base.astimezone(tz) if base.tzinfo else base.replace(tzinfo=tz) + cron = croniter(expr, local_base) + next_run = cron.get_next(datetime) + if local_now >= next_run: + if await should_skip_non_workday(trigger, local_now): + await mark_trigger_skipped(trigger.id, now) + logger.info(f"[Trigger] Skipped {trigger.name} on non-workday {local_now.date()}") + return False + return True + return False + except Exception as e: + logger.warning(f"Invalid cron expr '{expr}' for trigger {trigger.name}: {e}") + return False + + if t == "once": + at_str = cfg.get("at") + if not at_str: + return False + try: + at = datetime.fromisoformat(at_str) + if at.tzinfo is None: + at = at.replace(tzinfo=timezone.utc) + return now >= at and trigger.fire_count == 0 + except Exception: + return False + + if t == "interval": + minutes = cfg.get("minutes", 30) + base = trigger.last_fired_at or trigger.created_at + return (now - base) >= timedelta(minutes=minutes) + + if t == "poll": + interval_min = max(cfg.get("interval_min", 5), MIN_POLL_INTERVAL_MINUTES) + base = trigger.last_fired_at or trigger.created_at + if (now - base) < timedelta(minutes=interval_min): + return False + return await poll_check(trigger) + + if t == "on_message": + return await check_new_agent_messages(trigger) + + if t == "webhook": + return False + + return False + + +async def poll_check(trigger: AgentTrigger) -> bool: + import httpx + + cfg = trigger.config or {} + url = cfg.get("url") + if not url: + return False + if is_private_url(url): + logger.warning(f"Poll blocked for trigger {trigger.name}: private/internal URL '{url}'") + return False + try: + async with httpx.AsyncClient(timeout=10) as client: + resp = await client.request(cfg.get("method", "GET"), url, headers=cfg.get("headers", {})) + resp.raise_for_status() + + data = resp.json() + json_path = cfg.get("json_path", "$") + current_value = extract_json_path(data, json_path) + current_str = str(current_value) + fire_on = cfg.get("fire_on", "change") + should_fire = False + if fire_on == "match": + should_fire = current_str == str(cfg.get("match_value", "")) + else: + last_value = cfg.get("_last_value") + should_fire = last_value is not None and current_str != last_value + + cfg["_last_value"] = current_str + try: + from sqlalchemy import update + async with async_session() as db: + await db.execute( + update(AgentTrigger).where(AgentTrigger.id == trigger.id).values(config=cfg) + ) + await db.commit() + except Exception as e: + logger.warning(f"Failed to persist poll _last_value for {trigger.name}: {e}") + + return should_fire + except Exception as e: + logger.warning(f"Poll failed for trigger {trigger.name}: {e}") + return False + + +def extract_json_path(data, path: str): + if path == "$" or not path: + return data + parts = path.lstrip("$.").split(".") + current = data + for part in parts: + if isinstance(current, dict): + current = current.get(part) + elif isinstance(current, list) and part.isdigit(): + current = current[int(part)] + else: + return None + return current + + +async def check_new_agent_messages(trigger: AgentTrigger) -> bool: + from app.models.audit import ChatMessage + from app.models.chat_session import ChatSession + + cfg = trigger.config or {} + from_agent_name = cfg.get("from_agent_name") + from_user_name = cfg.get("from_user_name") + if not from_agent_name and not from_user_name: + return False + + since = trigger.last_fired_at or trigger.created_at + if trigger.fire_count == 0 and not trigger.last_fired_at: + since_ts_str = cfg.get("_since_ts") + if since_ts_str: + try: + since = datetime.fromisoformat(since_ts_str) + except Exception: + since = trigger.created_at + + try: + async with async_session() as db: + if from_agent_name: + from app.models.participant import Participant + from app.models.agent import Agent as AgentModel + safe_agent_name = from_agent_name.replace("%", "").replace("_", r"\_") + agent_r = await db.execute(select(AgentModel).where(AgentModel.name.ilike(f"%{safe_agent_name}%"))) + source_agent = agent_r.scalars().first() + if not source_agent: + return False + result = await db.execute( + select(Participant.id).where(Participant.type == "agent", Participant.ref_id == source_agent.id) + ) + from_participant = result.scalar_one_or_none() + if not from_participant: + return False + from sqlalchemy import String as SaString, cast as sa_cast + result = await db.execute( + select(ChatMessage) + .join(ChatSession, ChatMessage.conversation_id == sa_cast(ChatSession.id, SaString)) + .where( + ChatMessage.participant_id == from_participant, + ChatMessage.created_at > since, + ChatMessage.role == "assistant", + ) + .order_by(ChatMessage.created_at.desc()) + .limit(1) + ) + msg = result.scalar_one_or_none() + if not msg: + return False + cfg["_matched_message"] = (msg.content or "")[:2000] + cfg["_matched_from"] = from_agent_name + return True + + if from_user_name: + from sqlalchemy import or_ + from sqlalchemy import String as SaString, cast as sa_cast + from app.models.agent import Agent as AgentModel + from app.models.user import Identity, User + + agent_r = await db.execute(select(AgentModel).where(AgentModel.id == trigger.agent_id)) + agent = agent_r.scalar_one_or_none() + safe_user_name = from_user_name.replace("%", "").replace("_", r"\_") + query = ( + select(User) + .join(User.identity) + .where( + or_( + User.display_name.ilike(f"%{safe_user_name}%"), + Identity.username.ilike(f"%{safe_user_name}%"), + ) + ) + ) + if agent and agent.tenant_id: + query = query.where(User.tenant_id == agent.tenant_id) + user_r = await db.execute(query) + target_user = user_r.scalars().first() + + if target_user: + result = await db.execute( + select(ChatMessage) + .join(ChatSession, ChatMessage.conversation_id == sa_cast(ChatSession.id, SaString)) + .where( + ChatSession.agent_id == trigger.agent_id, + ChatSession.user_id == target_user.id, + ChatSession.source_channel.in_(["feishu", "slack", "discord", "web"]), + ChatMessage.role == "user", + ChatMessage.created_at > since, + ) + .order_by(ChatMessage.created_at.desc()) + .limit(1) + ) + else: + result = await db.execute( + select(ChatMessage) + .join(ChatSession, ChatMessage.conversation_id == sa_cast(ChatSession.id, SaString)) + .where( + ChatSession.agent_id == trigger.agent_id, + ChatSession.source_channel.in_(["feishu", "slack", "discord", "web"]), + ChatMessage.role == "user", + ChatMessage.created_at > since, + or_( + ChatSession.title.ilike(f"%{safe_user_name}%"), + ChatMessage.content.ilike(f"%{safe_user_name}%"), + ), + ) + .order_by(ChatMessage.created_at.desc()) + .limit(1) + ) + + msg = result.scalar_one_or_none() + if not msg: + return False + cfg["_matched_message"] = (msg.content or "")[:2000] + cfg["_matched_from"] = from_user_name + return True + except Exception as e: + logger.warning(f"on_message check failed for trigger {trigger.name}: {e}") + return False + + return False diff --git a/backend/app/services/trigger_runtime/executions.py b/backend/app/services/trigger_runtime/executions.py new file mode 100644 index 000000000..154b07854 --- /dev/null +++ b/backend/app/services/trigger_runtime/executions.py @@ -0,0 +1,131 @@ +"""Execution claiming and completion helpers for distributed triggers.""" + +from __future__ import annotations + +import uuid +from datetime import datetime, timedelta, timezone + +from sqlalchemy import or_, select + +from app.config import get_settings +from app.database import async_session +from app.models.trigger import AgentTrigger +from app.models.trigger_execution import TriggerExecution + +settings = get_settings() + + +async def mark_trigger_executions_completed(execution_ids: list[uuid.UUID]) -> None: + if not execution_ids: + return + async with async_session() as db: + result = await db.execute( + select(TriggerExecution).where(TriggerExecution.id.in_(execution_ids)) + ) + for execution in result.scalars().all(): + execution.status = "completed" + execution.finished_at = datetime.now(timezone.utc) + execution.lease_owner = None + execution.lease_expires_at = None + execution.last_error = None + await db.commit() + + +async def mark_trigger_executions_failed(execution_ids: list[uuid.UUID], error_text: str) -> None: + if not execution_ids: + return + async with async_session() as db: + result = await db.execute( + select(TriggerExecution).where(TriggerExecution.id.in_(execution_ids)) + ) + for execution in result.scalars().all(): + execution.status = "failed" + execution.finished_at = datetime.now(timezone.utc) + execution.lease_owner = None + execution.lease_expires_at = None + execution.last_error = error_text + await db.commit() + + +async def claim_pending_trigger_executions( + *, + sources: list[str] | None = None, + limit: int = 100, +) -> list[tuple[TriggerExecution, AgentTrigger]]: + now = datetime.now(timezone.utc) + lease_until = now + timedelta(minutes=5) + claimed_pairs: list[tuple[TriggerExecution, AgentTrigger]] = [] + sources = sources or ["webhook", "cron", "once", "interval", "poll", "on_message"] + async with async_session() as db: + result = await db.execute( + select(TriggerExecution, AgentTrigger) + .join(AgentTrigger, AgentTrigger.id == TriggerExecution.trigger_id) + .where( + TriggerExecution.source.in_(sources), + AgentTrigger.is_enabled == True, + or_( + TriggerExecution.status == "pending", + (TriggerExecution.status == "processing") & ( + (TriggerExecution.lease_expires_at == None) | (TriggerExecution.lease_expires_at < now) + ), + ), + ) + .order_by(TriggerExecution.scheduled_at.asc()) + .with_for_update(skip_locked=True) + .limit(limit) + ) + rows = result.all() + for execution, trigger in rows: + execution.status = "processing" + execution.started_at = execution.started_at or now + execution.finished_at = None + execution.lease_owner = settings.INSTANCE_ID + execution.lease_expires_at = lease_until + claimed_pairs.append((execution, trigger)) + await db.commit() + return claimed_pairs + + +def build_execution_runtime_trigger(trigger: AgentTrigger, execution: TriggerExecution) -> AgentTrigger: + runtime_cfg = { + **(trigger.config or {}), + "_execution_id": str(execution.id), + } + if execution.payload: + runtime_cfg.update(execution.payload) + if execution.payload_text: + runtime_cfg["_webhook_payload"] = execution.payload_text + return AgentTrigger( + id=trigger.id, + agent_id=trigger.agent_id, + name=trigger.name, + type=trigger.type, + config=runtime_cfg, + reason=trigger.reason, + focus_ref=trigger.focus_ref, + is_enabled=trigger.is_enabled, + last_fired_at=trigger.last_fired_at, + fire_count=trigger.fire_count, + max_fires=trigger.max_fires, + cooldown_seconds=trigger.cooldown_seconds, + is_system=trigger.is_system, + created_at=trigger.created_at, + expires_at=trigger.expires_at, + ) + + +async def mark_base_triggers_fired(trigger_ids: list[uuid.UUID], now: datetime) -> None: + if not trigger_ids: + return + async with async_session() as db: + result = await db.execute( + select(AgentTrigger).where(AgentTrigger.id.in_(trigger_ids)) + ) + for trigger in result.scalars().all(): + trigger.last_fired_at = now + trigger.fire_count += 1 + if trigger.type == "once": + trigger.is_enabled = False + if trigger.max_fires and trigger.fire_count >= trigger.max_fires: + trigger.is_enabled = False + await db.commit() diff --git a/backend/app/services/trigger_runtime/invoker.py b/backend/app/services/trigger_runtime/invoker.py new file mode 100644 index 000000000..6784f32ee --- /dev/null +++ b/backend/app/services/trigger_runtime/invoker.py @@ -0,0 +1,368 @@ +"""Trigger invocation and delivery orchestration.""" + +from __future__ import annotations + +import json as _json +import uuid +from datetime import datetime, timezone + +from loguru import logger +from sqlalchemy import select + +from app.database import async_session +from app.models.agent import Agent +from app.models.trigger import AgentTrigger +from app.services.trigger_runtime import ( + mark_trigger_executions_completed, + mark_trigger_executions_failed, +) + + +async def resolve_trigger_delivery_target(agent: Agent, triggers: list[AgentTrigger]) -> dict | None: + from app.models.chat_session import ChatSession + from app.services.chat_session_service import ensure_primary_platform_session + + for trigger in triggers: + cfg = trigger.config or {} + a2a_sid = cfg.get("_a2a_session_id") + if a2a_sid: + try: + async with async_session() as db: + session = await db.get(ChatSession, uuid.UUID(a2a_sid)) + if not session: + return None + return { + "kind": "session", + "session_id": str(session.id), + "owner_user_id": str(session.user_id), + "source_channel": session.source_channel, + } + except Exception: + return None + + origin_cfg = None + for trigger in triggers: + cfg = trigger.config or {} + if cfg.get("_origin_session_id") or cfg.get("_origin_user_id"): + origin_cfg = cfg + break + if not origin_cfg: + return None + + origin_source_channel = origin_cfg.get("_origin_source_channel") + origin_session_id = origin_cfg.get("_origin_session_id") + origin_user_id = origin_cfg.get("_origin_user_id") + + if origin_source_channel == "agent" and origin_session_id: + try: + async with async_session() as db: + session = await db.get(ChatSession, uuid.UUID(origin_session_id)) + if not session: + return None + return { + "kind": "session", + "session_id": str(session.id), + "owner_user_id": str(session.user_id), + "source_channel": "agent", + } + except Exception: + return None + + if origin_source_channel != "trigger" and origin_user_id: + try: + async with async_session() as db: + primary = await ensure_primary_platform_session(db, agent.id, uuid.UUID(origin_user_id)) + await db.commit() + return { + "kind": "primary_user_session", + "session_id": str(primary.id), + "owner_user_id": str(primary.user_id), + "source_channel": primary.source_channel, + } + except Exception: + return None + + return None + + +async def invoke_agent_for_triggers(agent_id: uuid.UUID, triggers: list[AgentTrigger]): + from app.models.audit import ChatMessage + from app.models.chat_session import ChatSession + from app.models.llm import LLMModel + from app.models.participant import Participant + from app.services.audit_logger import write_audit_log + from app.services.llm import call_llm + + try: + execution_ids = [ + uuid.UUID(str((t.config or {}).get("_execution_id"))) + for t in triggers + if (t.config or {}).get("_execution_id") + ] + async with async_session() as db: + result = await db.execute(select(Agent).where(Agent.id == agent_id)) + agent = result.scalar_one_or_none() + if not agent or agent.is_expired: + return + + if not agent.primary_model_id: + logger.warning(f"Agent {agent.name} has no LLM model, skipping trigger invocation") + return + result = await db.execute(select(LLMModel).where(LLMModel.id == agent.primary_model_id)) + model = result.scalar_one_or_none() + if not model or not model.enabled: + logger.warning(f"Agent {agent.name}'s model is unavailable, skipping trigger invocation") + return + + context_parts = [] + trigger_names = [] + for t in triggers: + part = f"触发器:{t.name} ({t.type})\n原因:{t.reason}" + if t.name == "daily_okr_collection": + part += ( + "\n执行要求:先调用 get_okr_settings 确认日报收集是否开启。" + "如果开启,只能联系你关系网络中的成员和数字员工来收集今天的最终日报," + "并整理成不超过 2000 字的正式日报;" + "如果未开启,则说明本次无需执行并停止。" + ) + elif t.name in ("daily_okr_report", "weekly_okr_report", "monthly_okr_report"): + part += ( + "\n执行要求:本次公司级报表由系统自动汇总生成。" + "如果你被唤醒,仅补充必要说明,不要再次向成员发起收集。" + ) + elif t.name == "biweekly_okr_checkin": + part += ( + "\n执行要求:先调用 get_okr_settings 确认 OKR 是否开启。" + "如果开启,检查当前周期公司和成员 OKR,主动提醒尚未设置或进展滞后的相关成员;" + "如果未开启,则说明本次无需执行并停止。" + ) + if t.focus_ref: + part += f"\n关联 Focus:{t.focus_ref}" + cfg = t.config or {} + if t.type == "on_message" and cfg.get("_matched_message"): + part += f"\n收到来自 {cfg.get('_matched_from', '?')} 的消息:\n\"{cfg['_matched_message'][:500]}\"" + if t.type == "on_message" and cfg.get("okr_member_id") and cfg.get("okr_report_date"): + part += ( + "\n执行要求:这是一次日报回复入库事件。" + f"\n1. 将对方回复整理成一段不超过 2000 字的最终日报。" + f"\n2. 立即调用 upsert_member_daily_report(report_date=\"{cfg['okr_report_date']}\", " + f"member_type=\"{cfg.get('okr_member_type', 'user')}\", " + f"member_id=\"{cfg['okr_member_id']}\", content=\"<整理后的日报>\")。" + "\n3. 工具调用成功后,再发送一句简短确认,明确你已收到并已记录。" + "\n4. 不要只回复确认而不调用工具,也不要把原始长对话原样存入日报。" + ) + if t.type == "webhook" and cfg.get("_webhook_payload"): + payload_str = cfg["_webhook_payload"] + if len(payload_str) > 2000: + payload_str = payload_str[:2000] + "... (truncated)" + part += f"\nWebhook Payload:\n{payload_str}" + context_parts.append(part) + trigger_names.append(t.name) + + trigger_context = ( + "===== 本次唤醒上下文 =====\n" + f"唤醒来源:trigger({'多个触发器同时触发' if len(triggers) > 1 else '触发器触发'})\n\n" + + "\n---\n".join(context_parts) + + "\n===========================" + ) + + title = f"🤖 内心独白:{', '.join(trigger_names)}" + result = await db.execute( + select(Participant).where(Participant.type == "agent", Participant.ref_id == agent_id) + ) + agent_participant = result.scalar_one_or_none() + + session = ChatSession( + agent_id=agent_id, + user_id=agent.creator_id, + participant_id=agent_participant.id if agent_participant else None, + source_channel="trigger", + title=title[:200], + ) + db.add(session) + await db.flush() + session_id = session.id + messages = [{"role": "user", "content": trigger_context}] + db.add(ChatMessage( + agent_id=agent_id, + conversation_id=str(session_id), + role="user", + content=trigger_context, + user_id=agent.creator_id, + participant_id=agent_participant.id if agent_participant else None, + )) + await db.commit() + agent_participant_id = agent_participant.id if agent_participant else None + + collected_content: list[str] = [] + delivered_platform_message_via_tool = False + + async def on_chunk(text): + collected_content.append(text) + + async def on_tool_call(data): + nonlocal delivered_platform_message_via_tool + try: + tool_name = data.get("name") + tool_status = data.get("status") + if tool_status == "done" and tool_name == "send_platform_message": + result_text = str(data.get("result", "")) + if result_text.startswith("✅"): + delivered_platform_message_via_tool = True + + async with async_session() as _tc_db: + if data["status"] == "running": + _tc_db.add(ChatMessage( + agent_id=agent_id, + conversation_id=str(session_id), + role="tool_call", + content=_json.dumps({"name": data["name"], "args": data["args"]}, ensure_ascii=False, default=str), + user_id=agent.creator_id, + participant_id=agent_participant_id, + )) + elif data["status"] == "done": + result_str = str(data.get("result", ""))[:2000] + _tc_db.add(ChatMessage( + agent_id=agent_id, + conversation_id=str(session_id), + role="tool_call", + content=_json.dumps({"name": data["name"], "result": result_str}, ensure_ascii=False, default=str), + user_id=agent.creator_id, + participant_id=agent_participant_id, + )) + await _tc_db.commit() + except Exception as e: + logger.warning(f"Failed to persist tool call for trigger session: {e}") + + reply = await call_llm( + model=model, + messages=messages, + agent_name=agent.name, + role_description=agent.role_description or "", + agent_id=agent_id, + user_id=agent.creator_id, + session_id=str(session_id), + on_chunk=on_chunk, + on_tool_call=on_tool_call, + ) + + async with async_session() as db: + result = await db.execute( + select(Participant).where(Participant.type == "agent", Participant.ref_id == agent_id) + ) + agent_participant = result.scalar_one_or_none() + db.add(ChatMessage( + agent_id=agent_id, + conversation_id=str(session_id), + role="assistant", + content=reply or "".join(collected_content), + user_id=agent.creator_id, + participant_id=agent_participant.id if agent_participant else None, + )) + await db.commit() + + final_reply = reply or "".join(collected_content) + for t in triggers: + a2a_sid = (t.config or {}).get("_a2a_session_id") + if a2a_sid and final_reply: + try: + async with async_session() as db: + from app.models.participant import Participant as _P + _p_r = await db.execute(select(_P).where(_P.type == "agent", _P.ref_id == agent_id)) + _p = _p_r.scalar_one_or_none() + db.add(ChatMessage( + agent_id=agent_id, + conversation_id=a2a_sid, + role="assistant", + content=final_reply, + user_id=agent.creator_id, + participant_id=_p.id if _p else None, + )) + from app.models.chat_session import ChatSession as _CS + _cs_r = await db.execute(select(_CS).where(_CS.id == uuid.UUID(a2a_sid))) + _cs = _cs_r.scalar_one_or_none() + if _cs: + _cs.last_message_at = datetime.now(timezone.utc) + await db.commit() + except Exception as e: + logger.warning(f"[A2A] Failed to save reply to A2A session {a2a_sid}: {e}") + break + + is_a2a_internal = all(t.name == "a2a_wake" for t in triggers) + delivery_target = None if is_a2a_internal else await resolve_trigger_delivery_target(agent, triggers) + + if final_reply and delivery_target and not delivered_platform_message_via_tool: + try: + from app.api.websocket import manager as ws_manager + agent_id_str = str(agent_id) + trigger_reasons = [] + for t in triggers: + ns = (t.config or {}).get("_notification_summary", "").strip() + if ns: + trigger_reasons.append(ns) + else: + r = (t.reason or "").strip() + if r and len(r) <= 80: + trigger_reasons.append(r) + elif r: + trigger_reasons.append(r[:77] + "...") + summary = trigger_reasons[0] if trigger_reasons else "有新的事件需要处理" + notification = f"⚡ {summary}\n\n{final_reply}" + target_session_id = delivery_target["session_id"] + owner_user_id = delivery_target.get("owner_user_id") + + async with async_session() as db: + from app.api.websocket import maybe_mark_session_read_for_active_viewer + from app.models.chat_session import ChatSession + db.add(ChatMessage( + agent_id=agent_id, + conversation_id=target_session_id, + role="assistant", + content=notification, + user_id=agent.creator_id, + )) + session_row = await db.get(ChatSession, uuid.UUID(target_session_id)) + if session_row: + session_row.last_message_at = datetime.now(timezone.utc) + if owner_user_id: + await maybe_mark_session_read_for_active_viewer( + db, + agent_id=agent_id, + session_id=target_session_id, + user_id=uuid.UUID(owner_user_id), + ) + await db.commit() + + if owner_user_id: + await ws_manager.send_to_user( + agent_id_str, + owner_user_id, + { + "type": "trigger_notification", + "content": notification, + "triggers": [t.name for t in triggers], + "session_id": target_session_id, + }, + ) + except Exception as e: + logger.error(f"Failed to push trigger result to WebSocket: {e}") + + await write_audit_log( + "trigger_fired", + {"agent_name": agent.name, "triggers": [{"name": t.name, "type": t.type} for t in triggers]}, + agent_id=agent_id, + ) + + if execution_ids: + await mark_trigger_executions_completed(execution_ids) + except Exception as e: + logger.error(f"Failed to invoke agent {agent_id} for triggers: {e}") + import traceback + traceback.print_exc() + execution_ids = [ + uuid.UUID(str((t.config or {}).get("_execution_id"))) + for t in triggers + if (t.config or {}).get("_execution_id") + ] + if execution_ids: + await mark_trigger_executions_failed(execution_ids, str(e)[:2000]) diff --git a/backend/app/services/trigger_runtime/keys.py b/backend/app/services/trigger_runtime/keys.py new file mode 100644 index 000000000..4fbb3178a --- /dev/null +++ b/backend/app/services/trigger_runtime/keys.py @@ -0,0 +1,47 @@ +"""Deterministic idempotency keys for trigger executions.""" + +from __future__ import annotations + +import hashlib +from datetime import datetime, timedelta, timezone + +from croniter import croniter + +from app.models.trigger import AgentTrigger + + +def build_scheduled_execution_key(trigger: AgentTrigger, now: datetime) -> str: + """Build a deterministic idempotency key for non-webhook trigger runs.""" + cfg = trigger.config or {} + trigger_type = trigger.type + + if trigger_type == "once": + return f"once:{trigger.id}:{cfg.get('at', '')}" + + if trigger_type == "interval": + minutes = int(cfg.get("minutes", 30) or 30) + base = trigger.last_fired_at or trigger.created_at + due_at = base + timedelta(minutes=minutes) + return f"interval:{trigger.id}:{due_at.astimezone(timezone.utc).isoformat()}" + + if trigger_type == "cron": + expr = cfg.get("expr", "* * * * *") + base = trigger.last_fired_at or trigger.created_at + cron = croniter(expr, base) + due_at = cron.get_next(datetime) + if due_at.tzinfo is None: + due_at = due_at.replace(tzinfo=timezone.utc) + return f"cron:{trigger.id}:{due_at.astimezone(timezone.utc).isoformat()}" + + if trigger_type == "on_message": + matched_from = str(cfg.get("_matched_from") or "") + matched_message = str(cfg.get("_matched_message") or "") + digest = hashlib.sha256(f"{matched_from}\n{matched_message}".encode("utf-8")).hexdigest() + return f"on_message:{trigger.id}:{digest}" + + if trigger_type == "poll": + current_value = str(cfg.get("_last_value") or "") + digest = hashlib.sha256(current_value.encode("utf-8")).hexdigest() + return f"poll:{trigger.id}:{digest}" + + return f"{trigger_type}:{trigger.id}:{now.replace(microsecond=0).isoformat()}" diff --git a/backend/app/services/trigger_runtime/queue.py b/backend/app/services/trigger_runtime/queue.py new file mode 100644 index 000000000..5b26037d2 --- /dev/null +++ b/backend/app/services/trigger_runtime/queue.py @@ -0,0 +1,73 @@ +"""Queue trigger executions for distributed workers.""" + +from __future__ import annotations + +import hashlib +from datetime import datetime, timezone + +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.trigger import AgentTrigger +from app.models.trigger_execution import TriggerExecution + + +async def enqueue_trigger_execution( + db: AsyncSession, + *, + trigger: AgentTrigger, + source: str, + idempotency_key: str, + payload_text: str = "", + payload_obj: dict | None = None, +) -> tuple[TriggerExecution | None, bool]: + """Insert a generic trigger execution record.""" + execution = TriggerExecution( + trigger_id=trigger.id, + agent_id=trigger.agent_id, + source=source, + status="pending", + idempotency_key=idempotency_key[:255], + payload=payload_obj if isinstance(payload_obj, dict) else {}, + payload_text=payload_text[:8000], + scheduled_at=datetime.now(timezone.utc), + ) + db.add(execution) + try: + await db.commit() + return execution, True + except IntegrityError: + await db.rollback() + return None, False + + +async def enqueue_webhook_execution( + db: AsyncSession, + *, + trigger: AgentTrigger, + body: bytes, + payload_text: str, + payload_obj: dict | None, + request_headers: dict[str, str], +) -> tuple[TriggerExecution | None, bool]: + """Insert a webhook execution record. + + Returns `(execution, created)` where `created=False` means an identical + idempotency key already exists and the event should be treated as a no-op. + """ + delivery_key = ( + request_headers.get("x-idempotency-key") + or request_headers.get("x-github-delivery") + or request_headers.get("x-request-id") + or request_headers.get("x-event-id") + or hashlib.sha256(body).hexdigest() + )[:255] + + return await enqueue_trigger_execution( + db, + trigger=trigger, + source="webhook", + idempotency_key=delivery_key, + payload_text=payload_text, + payload_obj=payload_obj, + ) diff --git a/backend/app/services/wechat_channel.py b/backend/app/services/wechat_channel.py index 8ad24fe26..2fd114781 100644 --- a/backend/app/services/wechat_channel.py +++ b/backend/app/services/wechat_channel.py @@ -307,6 +307,7 @@ class WeChatPollManager: def __init__(self) -> None: self._tasks: dict[uuid.UUID, asyncio.Task] = {} self._connected: dict[uuid.UUID, bool] = {} + self._reconcile_interval_seconds = 30 async def start_client(self, agent_id: uuid.UUID, stop_existing: bool = True) -> None: if stop_existing: @@ -327,6 +328,13 @@ async def stop_client(self, agent_id: uuid.UUID) -> None: await self._set_connected(agent_id, False) async def start_all(self) -> None: + logger.info("[WeChat] Poll manager started") + while True: + await self.reconcile_clients() + await asyncio.sleep(self._reconcile_interval_seconds) + + async def reconcile_clients(self) -> None: + configured_agent_ids: set[uuid.UUID] = set() async with async_session() as db: result = await db.execute( select(ChannelConfig).where( @@ -337,7 +345,16 @@ async def start_all(self) -> None: for cfg in result.scalars().all(): token = str((cfg.extra_config or {}).get("bot_token") or "").strip() if token: - await self.start_client(cfg.agent_id) + configured_agent_ids.add(cfg.agent_id) + + for agent_id in configured_agent_ids: + task = self._tasks.get(agent_id) + if task is None or task.done(): + await self.start_client(agent_id) + + for agent_id in list(self._tasks): + if agent_id not in configured_agent_ids: + await self.stop_client(agent_id) async def _run_client(self, agent_id: uuid.UUID) -> None: retry_delay = 2 diff --git a/backend/app/services/workspace_collaboration.py b/backend/app/services/workspace_collaboration.py index e7a8be898..d8269aac2 100644 --- a/backend/app/services/workspace_collaboration.py +++ b/backend/app/services/workspace_collaboration.py @@ -18,6 +18,10 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.models.workspace import WorkspaceEditLock, WorkspaceFileRevision +from app.services.storage import get_storage_backend, normalize_storage_key +from app.services.storage_runtime.base import WriteCondition +from app.services.storage_runtime.local import LocalStorageBackend +from app.services.workspace_locking import workspace_locks USER_AUTOSAVE_MERGE_SECONDS = 60 EDIT_LOCK_TTL_SECONDS = 90 @@ -63,6 +67,11 @@ class WorkspaceWriteResult: locked_by_user_id: str | None = None +def _should_mirror_to_local_filesystem(storage) -> bool: + """Only mirror writes into AGENT_DATA_DIR when the filesystem is the primary store.""" + return isinstance(storage, LocalStorageBackend) + + def content_hash(content: str | None) -> str: """Return a stable hash for text content.""" return hashlib.sha256((content or "").encode("utf-8")).hexdigest() @@ -270,6 +279,7 @@ async def write_workspace_file( session_id: str | None = None, enforce_human_lock: bool = True, merge_user_autosave: bool = False, + expected_version_token: str | None = None, ) -> WorkspaceWriteResult: """Write text content, enforcing human locks for agent/system actors.""" normalized = normalize_workspace_path(path) @@ -289,11 +299,27 @@ async def write_workspace_file( locked_by_user_id=str(lock.user_id), ) - target = safe_agent_path(base_dir, normalized) - before = await read_text_if_exists(target) - target.parent.mkdir(parents=True, exist_ok=True) - async with aiofiles.open(target, "w", encoding="utf-8") as f: - await f.write(content) + storage = get_storage_backend() + storage_key = normalize_storage_key(f"{agent_id}/{normalized}") + local_base_available = _should_mirror_to_local_filesystem(storage) + try: + target = safe_agent_path(base_dir, normalized) + except Exception: + target = None + local_base_available = False + before = await storage.read_text(storage_key, encoding="utf-8", errors="replace") if await storage.exists(storage_key) else None + write_result = await storage.write_bytes_if_match( + storage_key, + content.encode("utf-8"), + condition=WriteCondition(version_token=expected_version_token) if expected_version_token is not None else None, + content_type="text/plain; charset=utf-8", + ) + if not write_result.ok: + return WorkspaceWriteResult(False, normalized, f"Conflict detected while writing {normalized}") + if local_base_available and target is not None: + target.parent.mkdir(parents=True, exist_ok=True) + async with aiofiles.open(target, "w", encoding="utf-8") as f: + await f.write(content) revision = await record_revision( db, @@ -325,10 +351,18 @@ async def delete_workspace_file( actor_id: uuid.UUID | None, session_id: str | None = None, enforce_human_lock: bool = True, + expected_version_token: str | None = None, ) -> WorkspaceWriteResult: """Delete a workspace file and record the deleted content.""" normalized = normalize_workspace_path(path) - target = safe_agent_path(base_dir, normalized) + storage = get_storage_backend() + storage_key = normalize_storage_key(f"{agent_id}/{normalized}") + target = None + if _should_mirror_to_local_filesystem(storage): + try: + target = safe_agent_path(base_dir, normalized) + except Exception: + target = None if enforce_human_lock and actor_type != "user": lock = await get_active_lock(db, agent_id=agent_id, path=normalized) if lock: @@ -338,15 +372,34 @@ async def delete_workspace_file( f"Human is currently editing {normalized}. Do not delete it now.", locked_by_user_id=str(lock.user_id), ) - if not target.exists(): + storage_exists = await storage.exists(storage_key) + storage_is_dir = await storage.is_dir(storage_key) + if not storage_exists and not storage_is_dir: return WorkspaceWriteResult(False, normalized, f"File not found: {normalized}") - before = await read_text_if_exists(target) - if target.is_dir(): - import shutil - - shutil.rmtree(target) - else: - target.unlink() + before = await storage.read_text(storage_key, encoding="utf-8", errors="replace") if storage_exists and await storage.is_file(storage_key) else None + async with workspace_locks(agent_id, [normalized]): + if storage_is_dir: + entries = await _collect_storage_tree_versions(storage, storage_key) + for entry_key, version_token in reversed(entries): + delete_result = await storage.delete_if_match( + entry_key, + condition=WriteCondition(version_token=version_token), + ) + if not delete_result.ok: + return WorkspaceWriteResult(False, normalized, f"Conflict detected while deleting {normalized}") + else: + delete_result = await storage.delete_if_match( + storage_key, + condition=WriteCondition(version_token=expected_version_token) if expected_version_token is not None else None, + ) + if not delete_result.ok: + return WorkspaceWriteResult(False, normalized, f"Conflict detected while deleting {normalized}") + if target is not None and target.exists(): + if target.is_dir(): + import shutil + shutil.rmtree(target) + else: + target.unlink() revision = await record_revision( db, agent_id=agent_id, @@ -378,6 +431,8 @@ async def move_workspace_path( session_id: str | None = None, enforce_human_lock: bool = True, overwrite: bool = False, + expected_source_version_token: str | None = None, + expected_destination_version_token: str | None = None, ) -> WorkspaceWriteResult: """Move or rename a workspace file/folder while respecting edit locks.""" source_normalized = normalize_workspace_path(source_path) @@ -389,19 +444,22 @@ async def move_workspace_path( if source_normalized in {"tasks.json", "soul.md"}: return WorkspaceWriteResult(False, source_normalized, f"{source_normalized} cannot be moved (protected)") - source = safe_agent_path(base_dir, source_normalized) - if not source.exists(): + storage = get_storage_backend() + source_key = normalize_storage_key(f"{agent_id}/{source_normalized}") + source_exists = await storage.exists(source_key) + source_is_dir = await storage.is_dir(source_key) + if not source_exists and not source_is_dir: return WorkspaceWriteResult(False, source_normalized, f"File not found: {source_normalized}") - destination = safe_agent_path(base_dir, destination_normalized) - if destination_path.replace("\\", "/").strip().endswith("/") or destination.is_dir(): - destination = (destination / source.name).resolve() - destination_normalized = normalize_workspace_path(str(destination.relative_to(base_dir.resolve()))) - destination = safe_agent_path(base_dir, destination_normalized) + destination_key = normalize_storage_key(f"{agent_id}/{destination_normalized}") + destination_is_dir = await storage.is_dir(destination_key) + if destination_path.replace("\\", "/").strip().endswith("/") or destination_is_dir: + destination_normalized = normalize_workspace_path(f"{destination_normalized}/{Path(source_normalized).name}") + destination_key = normalize_storage_key(f"{agent_id}/{destination_normalized}") - if source == destination: + if source_normalized == destination_normalized: return WorkspaceWriteResult(False, source_normalized, "Source and destination are the same") - if source.is_dir() and str(destination).startswith(str(source) + "/"): + if source_is_dir and (destination_normalized == source_normalized or destination_normalized.startswith(source_normalized + "/")): return WorkspaceWriteResult(False, source_normalized, "Cannot move a folder into itself") if enforce_human_lock and actor_type != "user": @@ -418,23 +476,71 @@ async def move_workspace_path( locked_by_user_id=str(lock.user_id), ) - if destination.exists(): - if not overwrite: - return WorkspaceWriteResult( - False, - destination_normalized, - f"Destination already exists: {destination_normalized}. Set overwrite=true to replace it.", - ) - if destination.is_dir(): - shutil.rmtree(destination) + destination_exists = await storage.exists(destination_key) + destination_is_dir = await storage.is_dir(destination_key) + async with workspace_locks(agent_id, [source_normalized, destination_normalized]): + if destination_exists or destination_is_dir: + if not overwrite: + return WorkspaceWriteResult( + False, + destination_normalized, + f"Destination already exists: {destination_normalized}. Set overwrite=true to replace it.", + ) + if destination_is_dir: + await storage.delete_tree(destination_key) + else: + delete_result = await storage.delete_if_match( + destination_key, + condition=WriteCondition(version_token=expected_destination_version_token) if expected_destination_version_token is not None else None, + ) + if not delete_result.ok: + return WorkspaceWriteResult(False, destination_normalized, f"Conflict detected while replacing {destination_normalized}") + + source = destination = None + if _should_mirror_to_local_filesystem(storage): + source = safe_agent_path(base_dir, source_normalized) + destination = safe_agent_path(base_dir, destination_normalized) + source_before = await storage.read_text(source_key, encoding="utf-8", errors="replace") if source_exists else None + destination_before = await storage.read_text(destination_key, encoding="utf-8", errors="replace") if destination_exists else None + + if source_is_dir: + entries = await _collect_storage_tree_versions(storage, source_key) + for entry_key, version_token in entries: + rel = entry_key.removeprefix(source_key.rstrip("/") + "/") + target_key = normalize_storage_key(f"{agent_id}/{destination_normalized}/{rel}") + current_version = await storage.get_version(entry_key) + if current_version.token != version_token: + return WorkspaceWriteResult(False, source_normalized, f"Conflict detected while moving {source_normalized}") + await storage.write_bytes(target_key, await storage.read_bytes(entry_key)) + for entry_key, version_token in reversed(entries): + delete_result = await storage.delete_if_match( + entry_key, + condition=WriteCondition(version_token=version_token), + ) + if not delete_result.ok: + return WorkspaceWriteResult(False, source_normalized, f"Conflict detected while finalizing move for {source_normalized}") else: - destination.unlink() + source_version = await storage.get_version(source_key) + if expected_source_version_token is not None and source_version.token != expected_source_version_token: + return WorkspaceWriteResult(False, source_normalized, f"Conflict detected while moving {source_normalized}") + await storage.write_bytes(destination_key, await storage.read_bytes(source_key)) + delete_result = await storage.delete_if_match( + source_key, + condition=WriteCondition(version_token=source_version.token), + ) + if not delete_result.ok: + return WorkspaceWriteResult(False, source_normalized, f"Conflict detected while finalizing move for {source_normalized}") + + destination_after = await storage.read_text(destination_key, encoding="utf-8", errors="replace") if await storage.is_file(destination_key) else None - source_before = await read_text_if_exists(source) - destination_before = await read_text_if_exists(destination) - destination.parent.mkdir(parents=True, exist_ok=True) - shutil.move(str(source), str(destination)) - destination_after = await read_text_if_exists(destination) + if source is not None and source.exists(): + if source.is_dir(): + shutil.rmtree(source) + else: + source.unlink() + if destination is not None and await storage.is_file(destination_key): + destination.parent.mkdir(parents=True, exist_ok=True) + destination.write_bytes(await storage.read_bytes(destination_key)) source_revision = await record_revision( db, @@ -467,6 +573,17 @@ async def move_workspace_path( ) +async def _collect_storage_tree_versions(storage, root_key: str) -> list[tuple[str, str]]: + keys: list[tuple[str, str]] = [] + for entry in await storage.list_dir(root_key): + if entry.is_dir: + keys.extend(await _collect_storage_tree_versions(storage, entry.key)) + else: + version = await storage.get_version(entry.key) + keys.append((entry.key, version.token)) + return keys + + async def list_revisions( db: AsyncSession, *, diff --git a/backend/app/services/workspace_locking.py b/backend/app/services/workspace_locking.py new file mode 100644 index 000000000..ceafb43f1 --- /dev/null +++ b/backend/app/services/workspace_locking.py @@ -0,0 +1,80 @@ +"""Redis-backed short-lived locks for workspace mutations.""" + +from __future__ import annotations + +import uuid +from contextlib import asynccontextmanager + +from app.core.events import get_redis + +LOCK_PREFIX = "workspace-lock" +DEFAULT_LOCK_TTL_SECONDS = 60 + +_RELEASE_IF_OWNER_SCRIPT = """ +if redis.call('get', KEYS[1]) == ARGV[1] then + return redis.call('del', KEYS[1]) +end +return 0 +""" + + +def _normalize_workspace_path(path: str) -> str: + clean = (path or "").replace("\\", "/").strip().lstrip("/") + parts: list[str] = [] + for part in clean.split("/"): + if part in ("", "."): + continue + if part == "..": + if parts: + parts.pop() + continue + parts.append(part) + return "/".join(parts) + + +def _lock_key(agent_id: uuid.UUID, path: str) -> str: + normalized = _normalize_workspace_path(path) or "." + return f"{LOCK_PREFIX}:{agent_id}:{normalized}" + + +async def acquire_workspace_lock( + agent_id: uuid.UUID, + path: str, + *, + owner_token: str, + ttl_seconds: int = DEFAULT_LOCK_TTL_SECONDS, +) -> bool: + redis = await get_redis() + return bool(await redis.set(_lock_key(agent_id, path), owner_token, ex=ttl_seconds, nx=True)) + + +async def release_workspace_lock(agent_id: uuid.UUID, path: str, *, owner_token: str) -> None: + redis = await get_redis() + await redis.eval(_RELEASE_IF_OWNER_SCRIPT, 1, _lock_key(agent_id, path), owner_token) + + +@asynccontextmanager +async def workspace_locks( + agent_id: uuid.UUID, + paths: list[str], + *, + ttl_seconds: int = DEFAULT_LOCK_TTL_SECONDS, +): + normalized = sorted({_normalize_workspace_path(path) or "." for path in paths if path is not None}) + owner_token = uuid.uuid4().hex + acquired: list[str] = [] + try: + for path in normalized: + ok = await acquire_workspace_lock( + agent_id, + path, + owner_token=owner_token, + ttl_seconds=ttl_seconds, + ) + if not ok: + raise RuntimeError(f"Workspace lock busy: {path}") + acquired.append(path) + yield + finally: + for path in reversed(acquired): + await release_workspace_lock(agent_id, path, owner_token=owner_token) diff --git a/backend/entrypoint.sh b/backend/entrypoint.sh index a11135649..a38723a9d 100755 --- a/backend/entrypoint.sh +++ b/backend/entrypoint.sh @@ -1,11 +1,19 @@ #!/bin/bash -# Docker entrypoint: run DB migrations, then start the app. -# Order matters: -# 1. alembic upgrade head — apply all migrations (creates tables + schema changes) -# 2. uvicorn — starts the FastAPI app +# Docker entrypoint: optionally run DB migrations, then start the app. set -e +PROCESS_ROLE="${PROCESS_ROLE:-all}" +ALLOW_MIGRATION_FAILURE="${ALLOW_MIGRATION_FAILURE:-false}" +START_COMMAND="${START_COMMAND:-uvicorn app.main:app --host 0.0.0.0 --port 8000}" + +role_contains() { + case ",${PROCESS_ROLE}," in + *,all,*|*,"$1",*) return 0 ;; + *) return 1 ;; + esac +} + # --- Permission fixing and privilege dropping --- if [ "$(id -u)" = '0' ]; then echo "[entrypoint] Detected root user, fixing permissions..." @@ -16,38 +24,38 @@ if [ "$(id -u)" = '0' ]; then fi # ------------------------------------------------------- -echo "[entrypoint] Step 1: Running alembic migrations..." -# Run all migrations to ensure database schema is up to date. -# Capture exit code explicitly — do NOT let a migration failure go unnoticed. -set +e -ALEMBIC_OUTPUT=$(alembic upgrade head 2>&1) -ALEMBIC_EXIT=$? -set -e +if [ -z "${INSTANCE_ID:-}" ]; then + SAFE_PROCESS_ROLE="${PROCESS_ROLE//,/-}" + export INSTANCE_ID="${SAFE_PROCESS_ROLE}-$(hostname)" +fi +echo "[entrypoint] INSTANCE_ID=${INSTANCE_ID}" + +if role_contains "bootstrap"; then + echo "[entrypoint] Step 1: Running alembic migrations for PROCESS_ROLE=${PROCESS_ROLE}..." + set +e + ALEMBIC_OUTPUT=$(alembic upgrade head 2>&1) + ALEMBIC_EXIT=$? + set -e -if [ $ALEMBIC_EXIT -ne 0 ]; then - echo "" - echo "========================================================================" - echo "[entrypoint] WARNING: Alembic migration FAILED (exit code $ALEMBIC_EXIT)" - echo "========================================================================" - echo "" - echo "$ALEMBIC_OUTPUT" - echo "" - echo "------------------------------------------------------------------------" - echo " The database schema may be INCOMPLETE. Some features will NOT work." - echo " Common causes:" - echo " - Migration cycle detected (pull latest code to fix)" - echo " - Database connection issue" - echo " - Incompatible migration state" - echo "" - echo " To fix: pull the latest code and restart the backend." - echo " Docker: git pull && docker compose restart backend" - echo " Source: git pull && alembic upgrade head" - echo "------------------------------------------------------------------------" - echo "" - echo "[entrypoint] Continuing startup despite migration failure..." + if [ $ALEMBIC_EXIT -ne 0 ]; then + echo "" + echo "========================================================================" + echo "[entrypoint] ERROR: Alembic migration FAILED (exit code $ALEMBIC_EXIT)" + echo "========================================================================" + echo "" + echo "$ALEMBIC_OUTPUT" + echo "" + if [ "$ALLOW_MIGRATION_FAILURE" = "true" ]; then + echo "[entrypoint] Continuing because ALLOW_MIGRATION_FAILURE=true" + else + exit $ALEMBIC_EXIT + fi + else + echo "[entrypoint] Alembic migrations completed successfully." + fi else - echo "[entrypoint] Alembic migrations completed successfully." + echo "[entrypoint] Step 1: Skipping alembic for PROCESS_ROLE=${PROCESS_ROLE}" fi echo "[entrypoint] Step 2: Starting uvicorn..." -exec uvicorn app.main:app --host 0.0.0.0 --port 8000 +exec /bin/bash -lc "$START_COMMAND" diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 1ec45860f..cb828c48d 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -43,6 +43,8 @@ dependencies = [ "weasyprint>=62.0", "markdown>=3.6", "beautifulsoup4>=4.12.0", + "boto3>=1.35.0", + "aioboto3>=13.0.0", ] [project.optional-dependencies] diff --git a/backend/tests/test_a2a_msg_type.py b/backend/tests/test_a2a_msg_type.py index 465933909..8c34f77d9 100644 --- a/backend/tests/test_a2a_msg_type.py +++ b/backend/tests/test_a2a_msg_type.py @@ -384,37 +384,41 @@ async def test_no_relationship_returns_error(): @pytest.mark.asyncio -async def test_append_focus_item_creates_file(tmp_path): +async def test_append_focus_item_creates_file(): """_append_focus_item should create/append to focus.md.""" from app.services.agent_tools import _append_focus_item agent_id = uuid.uuid4() - with patch("app.services.agent_tools.WORKSPACE_ROOT", tmp_path): + storage = AsyncMock() + storage.exists.return_value = False + storage.is_file.return_value = False + + with patch("app.services.agent_tools.get_storage_backend", return_value=storage): await _append_focus_item(agent_id, "test_item", "Test description") - focus_path = tmp_path / str(agent_id) / "focus.md" - assert focus_path.exists() - content = focus_path.read_text() - assert "test_item" in content - assert "Test description" in content - assert "- [ ]" in content + storage.write_text.assert_awaited_once() + key, content = storage.write_text.await_args.args[:2] + assert key == f"{agent_id}/focus.md" + assert "test_item" in content + assert "Test description" in content + assert "- [ ]" in content @pytest.mark.asyncio -async def test_append_focus_item_no_duplicate(tmp_path): +async def test_append_focus_item_no_duplicate(): """_append_focus_item should not duplicate existing items.""" from app.services.agent_tools import _append_focus_item agent_id = uuid.uuid4() - focus_path = tmp_path / str(agent_id) / "focus.md" - focus_path.parent.mkdir(parents=True, exist_ok=True) - focus_path.write_text("# Focus\n\n- [ ] test_item: Existing description\n") + storage = AsyncMock() + storage.exists.return_value = True + storage.is_file.return_value = True + storage.read_text.return_value = "# Focus\n\n- [ ] test_item: Existing description\n" - with patch("app.services.agent_tools.WORKSPACE_ROOT", tmp_path): + with patch("app.services.agent_tools.get_storage_backend", return_value=storage): await _append_focus_item(agent_id, "test_item", "New description") - content = focus_path.read_text() - assert content.count("test_item") == 1 + storage.write_text.assert_not_awaited() @pytest.mark.asyncio diff --git a/backend/tests/test_agent_context.py b/backend/tests/test_agent_context.py new file mode 100644 index 000000000..fac63fc03 --- /dev/null +++ b/backend/tests/test_agent_context.py @@ -0,0 +1,26 @@ +import uuid +from unittest.mock import AsyncMock, patch + +import pytest + + +@pytest.mark.asyncio +async def test_build_agent_context_reads_focus_from_storage_key(): + from app.services.agent_context import build_agent_context + + agent_id = uuid.uuid4() + + async def fake_read_file(key, _max_chars=3000): + if key == f"{agent_id}/focus.md": + return "# Focus\n\n- [ ] follow_up: Check the deployment" + return "" + + with ( + patch("app.services.agent_context._read_file_safe", side_effect=fake_read_file), + patch("app.services.agent_context._load_skills_index", new_callable=AsyncMock, return_value=""), + patch("app.services.timezone_utils.get_agent_timezone", new_callable=AsyncMock, return_value="UTC"), + ): + _static, dynamic = await build_agent_context(agent_id, "TestAgent") + + assert "## Focus" in dynamic + assert "follow_up: Check the deployment" in dynamic diff --git a/backend/tests/test_agent_tools_storage_workspace.py b/backend/tests/test_agent_tools_storage_workspace.py new file mode 100644 index 000000000..0d97bb870 --- /dev/null +++ b/backend/tests/test_agent_tools_storage_workspace.py @@ -0,0 +1,310 @@ +import uuid + +import pytest + +from app.services import agent_tools +from app.services import workspace_collaboration +from app.services.storage_runtime.base import StorageBackend, StorageEntry, StorageVersion, WriteCondition, ConditionalWriteResult + + +class MemoryStorageBackend(StorageBackend): + def __init__(self, files: dict[str, bytes] | None = None): + self.files = dict(files or {}) + self.versions = {key: 1 for key in self.files} + + async def exists(self, key: str) -> bool: + return key in self.files + + async def is_file(self, key: str) -> bool: + return key in self.files + + async def is_dir(self, key: str) -> bool: + prefix = key.rstrip("/") + "/" + return any(existing.startswith(prefix) for existing in self.files) + + async def list_dir(self, key: str) -> list[StorageEntry]: + prefix = key.rstrip("/") + "/" + entries: dict[str, StorageEntry] = {} + for existing, data in self.files.items(): + if not existing.startswith(prefix): + continue + rest = existing.removeprefix(prefix) + name, _, tail = rest.partition("/") + entries[name] = StorageEntry( + name=name, + key=f"{prefix}{name}", + is_dir=bool(tail), + size=0 if tail else len(data), + ) + return sorted(entries.values(), key=lambda entry: (not entry.is_dir, entry.name)) + + async def read_bytes(self, key: str) -> bytes: + return self.files[key] + + async def write_bytes(self, key: str, data: bytes, content_type: str | None = None) -> None: + self.files[key] = data + self.versions[key] = self.versions.get(key, 0) + 1 + + async def delete(self, key: str) -> None: + self.files.pop(key, None) + self.versions.pop(key, None) + + async def delete_tree(self, key: str) -> None: + prefix = key.rstrip("/") + "/" + for existing in list(self.files): + if existing.startswith(prefix): + self.files.pop(existing) + self.versions.pop(existing, None) + + async def stat(self, key: str) -> StorageEntry: + return StorageEntry(name=key.rsplit("/", 1)[-1], key=key, is_dir=False, size=len(self.files[key])) + + async def get_version(self, key: str) -> StorageVersion: + if key not in self.files: + return StorageVersion(key=key, exists=False, is_dir=False) + version = str(self.versions.get(key, 0)) + return StorageVersion( + key=key, + exists=True, + is_dir=False, + size=len(self.files[key]), + version_id=version, + etag=version, + content_hash=version, + ) + + async def write_bytes_if_match( + self, + key: str, + data: bytes, + *, + condition: WriteCondition | None = None, + content_type: str | None = None, + ) -> ConditionalWriteResult: + current = await self.get_version(key) + if condition: + if condition.require_absent and current.exists: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + if condition.version_token is not None and current.token != condition.version_token: + return ConditionalWriteResult(ok=False, conflict=True, current_version=current) + await self.write_bytes(key, data, content_type=content_type) + return ConditionalWriteResult(ok=True, current_version=await self.get_version(key)) + + +@pytest.mark.asyncio +async def test_agent_file_tools_use_storage_paths(monkeypatch): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/notes.md": b"# Notes\nneedle\n", + f"{agent_id}/memory/memory.md": b"# Memory\n", + }) + monkeypatch.setattr(agent_tools, "get_storage_backend", lambda: storage) + + listing = await agent_tools._storage_list_dir(agent_id, "workspace") + read = await agent_tools._storage_read_file(agent_id, "workspace/notes.md") + search = await agent_tools._storage_search_files(agent_id, "needle", path="workspace", file_pattern="*.md") + found = await agent_tools._storage_find_files(agent_id, "*.md", path="workspace") + + assert "notes.md" in listing + assert "needle" in read + assert "workspace/notes.md:2" in search + assert "workspace/notes.md" in found + + +@pytest.mark.asyncio +async def test_temp_workspace_materializes_only_requested_paths(monkeypatch): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/input.md": b"# Input\n", + f"{agent_id}/workspace/other.md": b"# Other\n", + }) + monkeypatch.setattr(agent_tools, "get_storage_backend", lambda: storage) + + temp_ws = await agent_tools._prepare_temp_workspace(agent_id, paths=["workspace/input.md"]) + try: + assert (temp_ws.root / "workspace" / "input.md").read_text(encoding="utf-8") == "# Input\n" + assert not (temp_ws.root / "workspace" / "other.md").exists() + finally: + temp_ws.cleanup() + + +@pytest.mark.asyncio +async def test_execute_tool_list_files_does_not_create_persistent_workspace(monkeypatch, tmp_path): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/input.md": b"# Input\n", + }) + monkeypatch.setattr(agent_tools, "get_storage_backend", lambda: storage) + monkeypatch.setattr(agent_tools, "WORKSPACE_ROOT", tmp_path) + + async def _tenant(_agent_id): + return None + + monkeypatch.setattr(agent_tools, "_get_agent_tenant_id", _tenant) + + result = await agent_tools.execute_tool("list_files", {"path": "workspace"}, agent_id, agent_id) + + assert "input.md" in result + assert not (tmp_path / str(agent_id)).exists() + + +@pytest.mark.asyncio +async def test_write_workspace_file_does_not_mirror_to_local_for_non_local_storage(monkeypatch, tmp_path): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend() + monkeypatch.setattr(workspace_collaboration, "get_storage_backend", lambda: storage) + + async def _noop_revision(*args, **kwargs): + return None + + monkeypatch.setattr(workspace_collaboration, "record_revision", _noop_revision) + + result = await workspace_collaboration.write_workspace_file( + db=None, + agent_id=agent_id, + base_dir=tmp_path / str(agent_id), + path="workspace/test.md", + content="hello", + actor_type="agent", + actor_id=agent_id, + enforce_human_lock=False, + ) + + assert result.ok is True + assert storage.files[f"{agent_id}/workspace/test.md"] == b"hello" + assert not (tmp_path / str(agent_id) / "workspace" / "test.md").exists() + + +@pytest.mark.asyncio +async def test_flush_temp_workspace_only_writes_changed_files(monkeypatch): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/input.md": b"# Input\n", + f"{agent_id}/workspace/other.md": b"# Other\n", + }) + monkeypatch.setattr(agent_tools, "get_storage_backend", lambda: storage) + + temp_ws = await agent_tools._prepare_temp_workspace(agent_id, paths=["workspace"]) + try: + (temp_ws.root / "workspace" / "input.md").write_text("# Updated\n", encoding="utf-8") + result = await agent_tools.flush_temp_workspace(temp_ws) + finally: + temp_ws.cleanup() + + assert result["updated"] == ["workspace/input.md"] + assert "workspace/other.md" in result["skipped"] + assert storage.files[f"{agent_id}/workspace/input.md"] == b"# Updated\n" + assert storage.files[f"{agent_id}/workspace/other.md"] == b"# Other\n" + + +@pytest.mark.asyncio +async def test_flush_temp_workspace_fails_on_conflict(monkeypatch): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/input.md": b"# Input\n", + }) + monkeypatch.setattr(agent_tools, "get_storage_backend", lambda: storage) + + temp_ws = await agent_tools._prepare_temp_workspace(agent_id, paths=["workspace/input.md"]) + try: + (temp_ws.root / "workspace" / "input.md").write_text("# Local change\n", encoding="utf-8") + await storage.write_bytes(f"{agent_id}/workspace/input.md", b"# Remote change\n") + result = await agent_tools.flush_temp_workspace(temp_ws) + finally: + temp_ws.cleanup() + + assert result["conflicted"] == ["workspace/input.md"] + assert storage.files[f"{agent_id}/workspace/input.md"] == b"# Remote change\n" + + +@pytest.mark.asyncio +async def test_write_workspace_file_fails_on_expected_version_conflict(monkeypatch, tmp_path): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/test.md": b"old", + }) + monkeypatch.setattr(workspace_collaboration, "get_storage_backend", lambda: storage) + + async def _noop_revision(*args, **kwargs): + return None + + monkeypatch.setattr(workspace_collaboration, "record_revision", _noop_revision) + + version = await storage.get_version(f"{agent_id}/workspace/test.md") + await storage.write_bytes(f"{agent_id}/workspace/test.md", b"remote-new") + result = await workspace_collaboration.write_workspace_file( + db=None, + agent_id=agent_id, + base_dir=tmp_path / str(agent_id), + path="workspace/test.md", + content="local-new", + actor_type="agent", + actor_id=agent_id, + enforce_human_lock=False, + expected_version_token=version.token, + ) + + assert result.ok is False + assert "Conflict detected" in result.message + assert storage.files[f"{agent_id}/workspace/test.md"] == b"remote-new" + + +@pytest.mark.asyncio +async def test_move_workspace_path_fails_when_source_changes(monkeypatch, tmp_path): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/source.md": b"old", + }) + monkeypatch.setattr(workspace_collaboration, "get_storage_backend", lambda: storage) + + async def _noop_revision(*args, **kwargs): + return None + + monkeypatch.setattr(workspace_collaboration, "record_revision", _noop_revision) + + version = await storage.get_version(f"{agent_id}/workspace/source.md") + await storage.write_bytes(f"{agent_id}/workspace/source.md", b"remote-new") + result = await workspace_collaboration.move_workspace_path( + db=None, + agent_id=agent_id, + base_dir=tmp_path / str(agent_id), + source_path="workspace/source.md", + destination_path="workspace/dest.md", + actor_type="agent", + actor_id=agent_id, + enforce_human_lock=False, + expected_source_version_token=version.token, + ) + + assert result.ok is False + assert "Conflict detected" in result.message + assert f"{agent_id}/workspace/dest.md" not in storage.files + + +@pytest.mark.asyncio +async def test_delete_workspace_directory_uses_prefix_existence(monkeypatch, tmp_path): + agent_id = uuid.uuid4() + storage = MemoryStorageBackend({ + f"{agent_id}/workspace/dir/a.txt": b"a", + f"{agent_id}/workspace/dir/nested/b.txt": b"b", + }) + monkeypatch.setattr(workspace_collaboration, "get_storage_backend", lambda: storage) + + async def _noop_revision(*args, **kwargs): + return None + + monkeypatch.setattr(workspace_collaboration, "record_revision", _noop_revision) + + result = await workspace_collaboration.delete_workspace_file( + db=None, + agent_id=agent_id, + base_dir=tmp_path / str(agent_id), + path="workspace/dir", + actor_type="user", + actor_id=agent_id, + enforce_human_lock=False, + ) + + assert result.ok is True + assert f"{agent_id}/workspace/dir/a.txt" not in storage.files + assert f"{agent_id}/workspace/dir/nested/b.txt" not in storage.files diff --git a/backend/tests/test_files_api_storage.py b/backend/tests/test_files_api_storage.py new file mode 100644 index 000000000..f791c16a0 --- /dev/null +++ b/backend/tests/test_files_api_storage.py @@ -0,0 +1,156 @@ +import uuid +from types import SimpleNamespace + +import pytest + +from app.api import files +from app.services.agent_manager import AgentManager +from app.services.storage_runtime.base import StorageBackend, StorageEntry, StorageVersion + + +class PrefixOnlyStorage(StorageBackend): + def __init__(self, objects: dict[str, bytes] | None = None): + self.objects = dict(objects or {}) + + async def exists(self, key: str) -> bool: + return key in self.objects + + async def is_file(self, key: str) -> bool: + return key in self.objects + + async def is_dir(self, key: str) -> bool: + prefix = key.rstrip("/") + "/" + return any(existing.startswith(prefix) for existing in self.objects) + + async def list_dir(self, key: str) -> list[StorageEntry]: + prefix = key.rstrip("/") + "/" + entries_by_name: dict[str, StorageEntry] = {} + for existing, data in self.objects.items(): + if not existing.startswith(prefix): + continue + rest = existing.removeprefix(prefix) + name, _, tail = rest.partition("/") + entries_by_name[name] = StorageEntry( + name=name, + key=f"{prefix}{name}", + is_dir=bool(tail), + size=0 if tail else len(data), + ) + return sorted(entries_by_name.values(), key=lambda entry: (not entry.is_dir, entry.name)) + + async def read_bytes(self, key: str) -> bytes: + return self.objects[key] + + async def write_bytes(self, key: str, data: bytes, content_type: str | None = None) -> None: + self.objects[key] = data + + async def delete(self, key: str) -> None: + self.objects.pop(key, None) + + async def delete_tree(self, key: str) -> None: + prefix = key.rstrip("/") + "/" + for existing in list(self.objects): + if existing.startswith(prefix): + self.objects.pop(existing, None) + + async def stat(self, key: str) -> StorageEntry: + if key not in self.objects: + raise FileNotFoundError(key) + return StorageEntry(name=key.rsplit("/", 1)[-1], key=key, is_dir=False, size=len(self.objects[key])) + + async def get_version(self, key: str) -> StorageVersion: + if key not in self.objects: + return StorageVersion(key=key, exists=False, is_dir=False) + token = f"v:{len(self.objects[key])}" + return StorageVersion( + key=key, + exists=True, + is_dir=False, + size=len(self.objects[key]), + version_id=token, + etag=token, + content_hash=token, + ) + + +@pytest.mark.asyncio +async def test_list_files_accepts_s3_prefix_directory(monkeypatch): + agent_id = uuid.uuid4() + storage = PrefixOnlyStorage({f"{agent_id}/focus.md": b"# Focus\n"}) + monkeypatch.setattr(files, "get_storage_backend", lambda: storage) + + async def allow_access(*args, **kwargs): + return None + + monkeypatch.setattr(files, "check_agent_access", allow_access) + user = SimpleNamespace(tenant_id=None) + + result = await files.list_files(agent_id, path="", current_user=user, db=None) + + assert [item.name for item in result] == ["focus.md"] + assert result[0].path == "focus.md" + assert result[0].version_token == "v:8" + + +@pytest.mark.asyncio +async def test_list_files_allows_empty_agent_root(monkeypatch): + agent_id = uuid.uuid4() + monkeypatch.setattr(files, "get_storage_backend", lambda: PrefixOnlyStorage()) + + async def allow_access(*args, **kwargs): + return None + + monkeypatch.setattr(files, "check_agent_access", allow_access) + user = SimpleNamespace(tenant_id=None) + + assert await files.list_files(agent_id, path="", current_user=user, db=None) == [] + + +@pytest.mark.asyncio +async def test_read_file_returns_version_token(monkeypatch): + agent_id = uuid.uuid4() + storage = PrefixOnlyStorage({f"{agent_id}/focus.md": b"# Focus\n"}) + monkeypatch.setattr(files, "get_storage_backend", lambda: storage) + + async def allow_access(*args, **kwargs): + return None + + monkeypatch.setattr(files, "check_agent_access", allow_access) + user = SimpleNamespace(tenant_id=None) + + result = await files.read_file(agent_id, path="focus.md", current_user=user, db=None) + + assert result.version_token == "v:8" + + +@pytest.mark.asyncio +async def test_agent_manager_does_not_reinitialize_s3_prefix_directory(monkeypatch, tmp_path): + agent_id = uuid.uuid4() + storage = PrefixOnlyStorage({f"{agent_id}/soul.md": b"existing"}) + monkeypatch.setattr("app.services.agent_manager.get_storage_backend", lambda: storage) + monkeypatch.setattr("app.services.agent_manager.settings.STORAGE_LOCAL_ROOT", str(tmp_path)) + + manager = AgentManager() + agent = SimpleNamespace(id=agent_id) + + await manager.initialize_agent_files(db=None, agent=agent) + + assert storage.objects[f"{agent_id}/soul.md"] == b"existing" + + +@pytest.mark.asyncio +async def test_agent_manager_materializes_s3_prefix_directory(monkeypatch, tmp_path): + agent_id = uuid.uuid4() + storage = PrefixOnlyStorage({ + f"{agent_id}/soul.md": b"# Soul\n", + f"{agent_id}/memory/memory.md": b"# Memory\n", + }) + monkeypatch.setattr("app.services.agent_manager.get_storage_backend", lambda: storage) + monkeypatch.setattr("app.services.agent_manager.settings.STORAGE_LOCAL_ROOT", str(tmp_path)) + + manager = AgentManager() + + agent_dir = await manager._materialize_agent_dir(agent_id) + + assert (agent_dir / "soul.md").read_text(encoding="utf-8") == "# Soul\n" + assert (agent_dir / "memory" / "memory.md").read_text(encoding="utf-8") == "# Memory\n" diff --git a/backend/tests/test_org_sync_adapter.py b/backend/tests/test_org_sync_adapter.py index d4fa71cfa..41f9b1443 100644 --- a/backend/tests/test_org_sync_adapter.py +++ b/backend/tests/test_org_sync_adapter.py @@ -1,6 +1,7 @@ import asyncio import uuid from contextlib import asynccontextmanager +from datetime import datetime, timezone from types import SimpleNamespace import pytest @@ -43,6 +44,14 @@ async def flush(self): self.flush_calls += 1 +class _RecordingExecuteDB: + def __init__(self): + self.statements = [] + + async def execute(self, statement): + self.statements.append(statement) + + class _SyncAdapterWithFailure(_DummyAdapter): def __init__(self): super().__init__() @@ -115,6 +124,17 @@ def test_sync_org_structure_skips_reconcile_after_member_failure(): assert "Reconcile skipped due to partial sync failures" in result["errors"] +def test_reconcile_disables_session_synchronization_for_datetime_comparisons(): + adapter = _DummyAdapter() + db = _RecordingExecuteDB() + + asyncio.run(adapter._reconcile(db, uuid.uuid4(), datetime.now(timezone.utc))) + + assert len(db.statements) == 2 + for statement in db.statements: + assert statement.get_execution_options()["synchronize_session"] is False + + def test_google_workspace_adapter_parses_legacy_service_account_json_string(): adapter = GoogleWorkspaceOrgSyncAdapter( config={ diff --git a/backend/tests/test_storage_fallback.py b/backend/tests/test_storage_fallback.py new file mode 100644 index 000000000..21286eb32 --- /dev/null +++ b/backend/tests/test_storage_fallback.py @@ -0,0 +1,70 @@ +from app.services.storage_runtime.base import StorageBackend, StorageEntry +from app.services.storage_runtime.fallback import FallbackStorageBackend + + +class MemoryStorageBackend(StorageBackend): + def __init__(self, files: dict[str, bytes] | None = None): + self.files = dict(files or {}) + + async def exists(self, key: str) -> bool: + return key in self.files + + async def is_file(self, key: str) -> bool: + return key in self.files + + async def is_dir(self, key: str) -> bool: + prefix = key.rstrip("/") + "/" + return any(existing.startswith(prefix) for existing in self.files) + + async def list_dir(self, key: str) -> list[StorageEntry]: + prefix = key.rstrip("/") + "/" + entries = [] + for existing, data in self.files.items(): + if existing.startswith(prefix): + name = existing.removeprefix(prefix).split("/", 1)[0] + entries.append(StorageEntry(name=name, key=f"{prefix}{name}", is_dir=False, size=len(data))) + return entries + + async def read_bytes(self, key: str) -> bytes: + if key not in self.files: + raise FileNotFoundError(key) + return self.files[key] + + async def write_bytes(self, key: str, data: bytes, content_type: str | None = None) -> None: + self.files[key] = data + + async def delete(self, key: str) -> None: + self.files.pop(key, None) + + async def delete_tree(self, key: str) -> None: + prefix = key.rstrip("/") + "/" + for existing in list(self.files): + if existing.startswith(prefix): + self.files.pop(existing, None) + + async def stat(self, key: str) -> StorageEntry: + if key not in self.files: + raise FileNotFoundError(key) + return StorageEntry(name=key.rsplit("/", 1)[-1], key=key, is_dir=False, size=len(self.files[key])) + + +async def test_fallback_storage_backfills_primary_on_read(): + primary = MemoryStorageBackend() + fallback = MemoryStorageBackend({"agent-id/focus.md": b"# Focus\n\n- [ ] migrate me\n"}) + storage = FallbackStorageBackend(primary=primary, fallback=fallback) + + content = await storage.read_text("agent-id/focus.md") + + assert "migrate me" in content + assert primary.files["agent-id/focus.md"] == fallback.files["agent-id/focus.md"] + + +async def test_fallback_storage_writes_only_to_primary(): + primary = MemoryStorageBackend() + fallback = MemoryStorageBackend() + storage = FallbackStorageBackend(primary=primary, fallback=fallback) + + await storage.write_text("agent-id/focus.md", "# Focus\n") + + assert "agent-id/focus.md" in primary.files + assert "agent-id/focus.md" not in fallback.files diff --git a/backend/tests/test_storage_s3.py b/backend/tests/test_storage_s3.py new file mode 100644 index 000000000..6c7dc3329 --- /dev/null +++ b/backend/tests/test_storage_s3.py @@ -0,0 +1,44 @@ +from unittest.mock import Mock + +from app.services.storage_runtime.s3 import S3StorageBackend + + +def test_s3_backend_passes_max_pool_connections(monkeypatch): + config_instances: list[object] = [] + client_calls: list[dict] = [] + + class FakeConfig: + def __init__(self, **kwargs): + self.kwargs = kwargs + config_instances.append(self) + + fake_boto3 = Mock() + fake_boto3.client.side_effect = lambda *args, **kwargs: client_calls.append(kwargs) or object() + + import builtins + + real_import = builtins.__import__ + + def fake_import(name, globals=None, locals=None, fromlist=(), level=0): + if name == "boto3": + return fake_boto3 + if name == "botocore.config": + return type("FakeBotocoreConfigModule", (), {"Config": FakeConfig})() + return real_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + backend = S3StorageBackend( + bucket="bucket", + endpoint_url="http://minio:9000", + access_key_id="key", + secret_access_key="secret", + max_pool_connections=64, + ) + + backend._client_or_raise() + + assert len(config_instances) == 1 + assert config_instances[0].kwargs["max_pool_connections"] == 64 + assert len(client_calls) == 1 + assert client_calls[0]["config"] is config_instances[0] diff --git a/deploy/.env.example b/deploy/.env.example new file mode 100644 index 000000000..efd6ecb2e --- /dev/null +++ b/deploy/.env.example @@ -0,0 +1,67 @@ +# Clawith Environment Variables +# Copy this file to .env and fill in the values + +# Security +SECRET_KEY=change-me-in-production +JWT_SECRET_KEY=change-me-jwt-secret + +# Database (auto-configured by setup.sh; override for custom setups) +# For local dev, ssl=disable is required to prevent asyncpg SSL negotiation hang +# DATABASE_URL=postgresql+asyncpg://clawith:clawith@localhost:5432/clawith?ssl=disable + +# Redis +# REDIS_URL=redis://localhost:6379/0 + +# Feishu OAuth (optional, for SSO login) +FEISHU_APP_ID= +FEISHU_APP_SECRET= +FEISHU_REDIRECT_URI=http://localhost:3000/auth/feishu/callback + +# Agent workspace data directory. +# Default: local host -> ~/.clawith/data/agents ; container runtime -> /data/agents +# AGENT_DATA_DIR= + +# File storage backend. Use "s3" for S3-compatible object storage. +# When STORAGE_BACKEND=s3, local fallback lets old files under STORAGE_LOCAL_ROOT +# be read and copied into S3 on first access during migration. +# STORAGE_BACKEND=local +# STORAGE_LOCAL_ROOT= +# STORAGE_LOCAL_FALLBACK_ENABLED=true +# S3_BUCKET= +# S3_REGION= +# S3_ENDPOINT_URL= +# S3_ACCESS_KEY_ID= +# S3_SECRET_ACCESS_KEY= +# S3_PREFIX=agents +# S3_MAX_POOL_CONNECTIONS=50 +# S3_WRITE_WORKERS=32 + + +API_PORT=8000 + +# Jina AI API key (for jina_search and jina_read tools — get one at https://jina.ai) +# Without a key, the tools still work but with lower rate limits +JINA_API_KEY= + +# Exa API key (for exa_search tool and web_search Exa engine — get one at https://exa.ai) +EXA_API_KEY= + +# Public app URL used in user-facing links, such as password reset emails. +# Leave empty for auto-discovery from the browser request. +# Set explicitly for production (e.g. https://your-domain.com) — required for +# background tasks like webhook URLs and email links that have no request context. +PUBLIC_BASE_URL= + + +# Password reset token lifetime in minutes +PASSWORD_RESET_TOKEN_EXPIRE_MINUTES=30 + +# Frontend port (default: 3008) +# FRONTEND_PORT=3008 + +# API upstream for nginx proxy (default: backend:8000) +# API_UPSTREAM=backend:8000 + +# Python pip index URL (for China mirrors) +# CLAWITH_PIP_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple +# CLAWITH_PIP_TRUSTED_HOST=pypi.tuna.tsinghua.edu.cn diff --git a/deploy/docker-compose-multi.yml b/deploy/docker-compose-multi.yml new file mode 100644 index 000000000..f8ba4c73c --- /dev/null +++ b/deploy/docker-compose-multi.yml @@ -0,0 +1,300 @@ +services: + postgres: + image: postgres:15-alpine + restart: unless-stopped + networks: + - default + environment: + POSTGRES_USER: clawith + POSTGRES_PASSWORD: clawith + POSTGRES_DB: clawith + volumes: + - pgdata:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U clawith"] + interval: 5s + timeout: 5s + retries: 5 + + redis: + image: redis:7-alpine + restart: unless-stopped + networks: + - default + volumes: + - redisdata:/data + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 5s + timeout: 5s + retries: 5 + + minio: + image: minio/minio:RELEASE.2025-04-22T22-12-26Z + restart: unless-stopped + command: server /data --console-address ":9001" + networks: + - default + environment: + MINIO_ROOT_USER: ${MINIO_ROOT_USER:-clawith} + MINIO_ROOT_PASSWORD: ${MINIO_ROOT_PASSWORD:-clawith-minio-secret} + volumes: + - miniodata:/data + healthcheck: + test: ["CMD-SHELL", "curl -fsS http://127.0.0.1:9000/minio/health/live >/dev/null"] + interval: 10s + timeout: 5s + retries: 5 + + backend-api-1: + build: + context: ./backend + args: + CLAWITH_PIP_INDEX_URL: ${CLAWITH_PIP_INDEX_URL:-} + CLAWITH_PIP_TRUSTED_HOST: ${CLAWITH_PIP_TRUSTED_HOST:-} + restart: unless-stopped + command: ["/bin/bash", "/app/entrypoint.sh"] + environment: + DATABASE_URL: postgresql+asyncpg://clawith:clawith@postgres:5432/clawith + REDIS_URL: redis://redis:6379/0 + AGENT_DATA_DIR: /data/agents + AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents + STORAGE_LOCAL_FALLBACK_ENABLED: ${STORAGE_LOCAL_FALLBACK_ENABLED:-true} + S3_BUCKET: ${S3_BUCKET:-clawith} + S3_REGION: ${S3_REGION:-us-east-1} + S3_ENDPOINT_URL: ${S3_ENDPOINT_URL:-http://minio:9000} + S3_ACCESS_KEY_ID: ${S3_ACCESS_KEY_ID:-${MINIO_ROOT_USER:-clawith}} + S3_SECRET_ACCESS_KEY: ${S3_SECRET_ACCESS_KEY:-${MINIO_ROOT_PASSWORD:-clawith-minio-secret}} + S3_PREFIX: ${S3_PREFIX:-agents} + SECRET_KEY: ${SECRET_KEY:-change-me-in-production} + JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: api + CORS_ORIGINS: '["*"]' + FEISHU_APP_ID: ${FEISHU_APP_ID:-} + FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} + DOCKER_NETWORK: clawith_yaojin_network + SS_CONFIG_FILE: /data/ss-nodes.json + PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-} + PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: ${PASSWORD_RESET_TOKEN_EXPIRE_MINUTES:-30} + volumes: + - ./backend:/app + - ./backend/agent_data:/data/agents + - /var/run/docker.sock:/var/run/docker.sock + - ./ss-nodes.json:/data/ss-nodes.json:ro + cap_add: + - SYS_ADMIN + security_opt: + - seccomp=unconfined + networks: + default: + aliases: + - backend + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + minio: + condition: service_healthy + logging: + driver: json-file + options: + max-size: "10m" + max-file: "3" + + backend-api-2: + build: + context: ./backend + args: + CLAWITH_PIP_INDEX_URL: ${CLAWITH_PIP_INDEX_URL:-} + CLAWITH_PIP_TRUSTED_HOST: ${CLAWITH_PIP_TRUSTED_HOST:-} + restart: unless-stopped + command: ["/bin/bash", "/app/entrypoint.sh"] + environment: + DATABASE_URL: postgresql+asyncpg://clawith:clawith@postgres:5432/clawith + REDIS_URL: redis://redis:6379/0 + AGENT_DATA_DIR: /data/agents + AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents + STORAGE_LOCAL_FALLBACK_ENABLED: ${STORAGE_LOCAL_FALLBACK_ENABLED:-true} + S3_BUCKET: ${S3_BUCKET:-clawith} + S3_REGION: ${S3_REGION:-us-east-1} + S3_ENDPOINT_URL: ${S3_ENDPOINT_URL:-http://minio:9000} + S3_ACCESS_KEY_ID: ${S3_ACCESS_KEY_ID:-${MINIO_ROOT_USER:-clawith}} + S3_SECRET_ACCESS_KEY: ${S3_SECRET_ACCESS_KEY:-${MINIO_ROOT_PASSWORD:-clawith-minio-secret}} + S3_PREFIX: ${S3_PREFIX:-agents} + SECRET_KEY: ${SECRET_KEY:-change-me-in-production} + JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: api + CORS_ORIGINS: '["*"]' + FEISHU_APP_ID: ${FEISHU_APP_ID:-} + FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} + DOCKER_NETWORK: clawith_yaojin_network + SS_CONFIG_FILE: /data/ss-nodes.json + PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-} + PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: ${PASSWORD_RESET_TOKEN_EXPIRE_MINUTES:-30} + volumes: + - ./backend:/app + - ./backend/agent_data:/data/agents + - /var/run/docker.sock:/var/run/docker.sock + - ./ss-nodes.json:/data/ss-nodes.json:ro + cap_add: + - SYS_ADMIN + security_opt: + - seccomp=unconfined + networks: + default: + aliases: + - backend + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + minio: + condition: service_healthy + logging: + driver: json-file + options: + max-size: "10m" + max-file: "3" + + backend-worker: + build: + context: ./backend + args: + CLAWITH_PIP_INDEX_URL: ${CLAWITH_PIP_INDEX_URL:-} + CLAWITH_PIP_TRUSTED_HOST: ${CLAWITH_PIP_TRUSTED_HOST:-} + restart: unless-stopped + command: ["/bin/bash", "/app/entrypoint.sh"] + environment: + DATABASE_URL: postgresql+asyncpg://clawith:clawith@postgres:5432/clawith + REDIS_URL: redis://redis:6379/0 + AGENT_DATA_DIR: /data/agents + AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents + STORAGE_LOCAL_FALLBACK_ENABLED: ${STORAGE_LOCAL_FALLBACK_ENABLED:-true} + S3_BUCKET: ${S3_BUCKET:-clawith} + S3_REGION: ${S3_REGION:-us-east-1} + S3_ENDPOINT_URL: ${S3_ENDPOINT_URL:-http://minio:9000} + S3_ACCESS_KEY_ID: ${S3_ACCESS_KEY_ID:-${MINIO_ROOT_USER:-clawith}} + S3_SECRET_ACCESS_KEY: ${S3_SECRET_ACCESS_KEY:-${MINIO_ROOT_PASSWORD:-clawith-minio-secret}} + S3_PREFIX: ${S3_PREFIX:-agents} + SECRET_KEY: ${SECRET_KEY:-change-me-in-production} + JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: bootstrap,worker + CORS_ORIGINS: '["*"]' + FEISHU_APP_ID: ${FEISHU_APP_ID:-} + FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} + DOCKER_NETWORK: clawith_yaojin_network + SS_CONFIG_FILE: /data/ss-nodes.json + PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-} + PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: ${PASSWORD_RESET_TOKEN_EXPIRE_MINUTES:-30} + volumes: + - ./backend:/app + - ./backend/agent_data:/data/agents + - /var/run/docker.sock:/var/run/docker.sock + - ./ss-nodes.json:/data/ss-nodes.json:ro + cap_add: + - SYS_ADMIN + security_opt: + - seccomp=unconfined + networks: + - default + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + minio: + condition: service_healthy + logging: + driver: json-file + options: + max-size: "10m" + max-file: "3" + + backend-connector: + build: + context: ./backend + args: + CLAWITH_PIP_INDEX_URL: ${CLAWITH_PIP_INDEX_URL:-} + CLAWITH_PIP_TRUSTED_HOST: ${CLAWITH_PIP_TRUSTED_HOST:-} + restart: unless-stopped + command: ["/bin/bash", "/app/entrypoint.sh"] + environment: + DATABASE_URL: postgresql+asyncpg://clawith:clawith@postgres:5432/clawith + REDIS_URL: redis://redis:6379/0 + AGENT_DATA_DIR: /data/agents + AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents + STORAGE_LOCAL_FALLBACK_ENABLED: ${STORAGE_LOCAL_FALLBACK_ENABLED:-true} + S3_BUCKET: ${S3_BUCKET:-clawith} + S3_REGION: ${S3_REGION:-us-east-1} + S3_ENDPOINT_URL: ${S3_ENDPOINT_URL:-http://minio:9000} + S3_ACCESS_KEY_ID: ${S3_ACCESS_KEY_ID:-${MINIO_ROOT_USER:-clawith}} + S3_SECRET_ACCESS_KEY: ${S3_SECRET_ACCESS_KEY:-${MINIO_ROOT_PASSWORD:-clawith-minio-secret}} + S3_PREFIX: ${S3_PREFIX:-agents} + SECRET_KEY: ${SECRET_KEY:-change-me-in-production} + JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: connector + CORS_ORIGINS: '["*"]' + FEISHU_APP_ID: ${FEISHU_APP_ID:-} + FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} + DOCKER_NETWORK: clawith_yaojin_network + SS_CONFIG_FILE: /data/ss-nodes.json + PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-} + PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: ${PASSWORD_RESET_TOKEN_EXPIRE_MINUTES:-30} + volumes: + - ./backend:/app + - ./backend/agent_data:/data/agents + - /var/run/docker.sock:/var/run/docker.sock + - ./ss-nodes.json:/data/ss-nodes.json:ro + cap_add: + - SYS_ADMIN + security_opt: + - seccomp=unconfined + networks: + - default + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + minio: + condition: service_healthy + logging: + driver: json-file + options: + max-size: "10m" + max-file: "3" + + frontend: + build: ./frontend + restart: unless-stopped + ports: + - "${FRONTEND_PORT:-3008}:3000" + environment: + VITE_API_URL: http://localhost:8000 + API_UPSTREAM: ${API_UPSTREAM:-backend:8000} + volumes: + - ./frontend/nginx.conf.template:/etc/nginx/templates/default.conf.template:ro + networks: + - default + depends_on: + - backend-api-1 + - backend-api-2 + +volumes: + pgdata: + redisdata: + miniodata: + +networks: + default: + name: clawith_network diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml new file mode 100644 index 000000000..ba0a52430 --- /dev/null +++ b/deploy/docker-compose.yml @@ -0,0 +1,104 @@ +services: + postgres: + image: postgres:15-alpine + restart: unless-stopped + networks: + - default + environment: + POSTGRES_USER: clawith + POSTGRES_PASSWORD: clawith + POSTGRES_DB: clawith + volumes: + - pgdata:/var/lib/postgresql/data + healthcheck: + test: [ "CMD-SHELL", "pg_isready -U clawith" ] + interval: 5s + timeout: 5s + retries: 5 + + redis: + image: redis:7-alpine + restart: unless-stopped + networks: + - default + volumes: + - redisdata:/data + healthcheck: + test: [ "CMD", "redis-cli", "ping" ] + interval: 5s + timeout: 5s + retries: 5 + + backend: + build: + context: ./backend + args: + CLAWITH_PIP_INDEX_URL: ${CLAWITH_PIP_INDEX_URL:-} + CLAWITH_PIP_TRUSTED_HOST: ${CLAWITH_PIP_TRUSTED_HOST:-} + restart: unless-stopped + command: ["/bin/bash", "/app/entrypoint.sh"] + environment: + DATABASE_URL: postgresql+asyncpg://clawith:clawith@postgres:5432/clawith + REDIS_URL: redis://redis:6379/0 + AGENT_DATA_DIR: /data/agents + AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents + SECRET_KEY: ${SECRET_KEY:-change-me-in-production} + JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: all + CORS_ORIGINS: '["*"]' + FEISHU_APP_ID: ${FEISHU_APP_ID:-} + FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} + DOCKER_NETWORK: clawith_yaojin_network + SS_CONFIG_FILE: /data/ss-nodes.json + # Public base URL for constructing OAuth callback URLs and email links. + # Required when deployed behind a reverse proxy (e.g. Nginx, Cloudflare). + # If not set, the server infers the URL from the incoming request host. + PUBLIC_BASE_URL: ${PUBLIC_BASE_URL:-} + # Password reset token lifetime in minutes (default: 30) + PASSWORD_RESET_TOKEN_EXPIRE_MINUTES: ${PASSWORD_RESET_TOKEN_EXPIRE_MINUTES:-30} + volumes: + - ./backend:/app + - ./backend/agent_data:/data/agents + - /var/run/docker.sock:/var/run/docker.sock + - ./ss-nodes.json:/data/ss-nodes.json:ro + cap_add: + - SYS_ADMIN + security_opt: + - seccomp=unconfined + networks: + - default + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + logging: + driver: json-file + options: + max-size: "10m" + max-file: "3" + frontend: + build: ./frontend + restart: unless-stopped + ports: + - "${FRONTEND_PORT:-3008}:3000" + environment: + VITE_API_URL: http://localhost:8000 + API_UPSTREAM: ${API_UPSTREAM:-backend:8000} + volumes: + - ./nginx/nginx.conf:/etc/nginx/templates/default.conf.template:ro + networks: + - default + depends_on: + - backend + +volumes: + pgdata: + redisdata: + + +networks: + default: + name: clawith_network diff --git a/frontend/nginx.conf b/deploy/nginx/nginx.conf similarity index 87% rename from frontend/nginx.conf rename to deploy/nginx/nginx.conf index eeda2831c..35dd4e33a 100644 --- a/frontend/nginx.conf +++ b/deploy/nginx/nginx.conf @@ -4,6 +4,9 @@ server { root /usr/share/nginx/html; index index.html; + resolver 127.0.0.11 ipv6=off valid=30s; + set $api_upstream http://${API_UPSTREAM:-backend:8000}; + # Allow large file uploads (up to 100MB) client_max_body_size 500m; @@ -22,7 +25,7 @@ server { # API proxy location /api/ { - proxy_pass http://backend:8000; + proxy_pass $api_upstream; proxy_set_header Host $http_host; proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-Proto $scheme; @@ -33,7 +36,7 @@ server { # Public pages proxy (no auth, served by backend) location /p/ { - proxy_pass http://backend:8000; + proxy_pass $api_upstream; proxy_set_header Host $http_host; proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-Proto $scheme; @@ -44,14 +47,14 @@ server { # This rule intercepts these requests before the SPA fallback and proxies them # to the backend, which serves the file content stored in the tenant's config. location ~ ^/WW_verify_[A-Za-z0-9]+\.txt$ { - proxy_pass http://backend:8000/api/wecom-verify$request_uri; + proxy_pass $api_upstream/api/wecom-verify$request_uri; proxy_set_header Host $http_host; proxy_set_header X-Real-IP $remote_addr; } # WebSocket proxy location /ws/ { - proxy_pass http://backend:8000; + proxy_pass $api_upstream; proxy_http_version 1.1; proxy_set_header Upgrade $http_upgrade; proxy_set_header Connection "upgrade"; diff --git a/docker-compose.yml b/docker-compose.yml index 7c5ae9229..65b3bb543 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -42,8 +42,11 @@ services: REDIS_URL: redis://redis:6379/0 AGENT_DATA_DIR: /data/agents AGENT_TEMPLATE_DIR: /app/agent_template + STORAGE_BACKEND: ${STORAGE_BACKEND:-local} + STORAGE_LOCAL_ROOT: /data/agents SECRET_KEY: ${SECRET_KEY:-change-me-in-production} JWT_SECRET_KEY: ${JWT_SECRET_KEY:-change-me-jwt-secret} + PROCESS_ROLE: all CORS_ORIGINS: '["*"]' FEISHU_APP_ID: ${FEISHU_APP_ID:-} FEISHU_APP_SECRET: ${FEISHU_APP_SECRET:-} @@ -84,9 +87,9 @@ services: - "${FRONTEND_PORT:-3008}:3000" environment: VITE_API_URL: http://localhost:8000 + API_UPSTREAM: ${API_UPSTREAM:-backend:8000} volumes: - - ./frontend/src:/app/src - - ./frontend/public:/app/public + - ./frontend/nginx.conf.template:/etc/nginx/templates/default.conf.template:ro networks: - default depends_on: diff --git a/frontend/Dockerfile b/frontend/Dockerfile index 7d8b64e11..87f873607 100644 --- a/frontend/Dockerfile +++ b/frontend/Dockerfile @@ -7,6 +7,6 @@ RUN npm run build FROM nginx:alpine COPY --from=build /app/dist /usr/share/nginx/html -COPY nginx.conf /etc/nginx/conf.d/default.conf +COPY nginx.conf.template /etc/nginx/templates/default.conf.template EXPOSE 3000 CMD ["nginx", "-g", "daemon off;"] diff --git a/frontend/nginx.conf.template b/frontend/nginx.conf.template new file mode 100644 index 000000000..35dd4e33a --- /dev/null +++ b/frontend/nginx.conf.template @@ -0,0 +1,67 @@ +server { + listen 3000; + server_name localhost; + root /usr/share/nginx/html; + index index.html; + + resolver 127.0.0.11 ipv6=off valid=30s; + set $api_upstream http://${API_UPSTREAM:-backend:8000}; + + # Allow large file uploads (up to 100MB) + client_max_body_size 500m; + + # SPA fallback — index.html must never be cached so users always get the latest build + location / { + try_files $uri $uri/ /index.html; + add_header Cache-Control "no-cache, no-store, must-revalidate"; + add_header Pragma "no-cache"; + } + + # Hashed assets (JS/CSS) — safe to cache for 1 year since filenames change on every build + location /assets/ { + expires 1y; + add_header Cache-Control "public, immutable"; + } + + # API proxy + location /api/ { + proxy_pass $api_upstream; + proxy_set_header Host $http_host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-Proto $scheme; + client_max_body_size 500m; + proxy_read_timeout 120s; + proxy_send_timeout 120s; + } + + # Public pages proxy (no auth, served by backend) + location /p/ { + proxy_pass $api_upstream; + proxy_set_header Host $http_host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-Proto $scheme; + } + + # WeCom domain verification files + # WeCom's ownership-check bot requests WW_verify_.txt at the domain root. + # This rule intercepts these requests before the SPA fallback and proxies them + # to the backend, which serves the file content stored in the tenant's config. + location ~ ^/WW_verify_[A-Za-z0-9]+\.txt$ { + proxy_pass $api_upstream/api/wecom-verify$request_uri; + proxy_set_header Host $http_host; + proxy_set_header X-Real-IP $remote_addr; + } + + # WebSocket proxy + location /ws/ { + proxy_pass $api_upstream; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + proxy_set_header Host $http_host; + + # Keep WebSocket alive for long-running agent sessions (1 hour) + proxy_read_timeout 3600s; + proxy_send_timeout 3600s; + } +} diff --git a/frontend/src/components/MarkdownRenderer.tsx b/frontend/src/components/MarkdownRenderer.tsx index 909def8af..48f44e6fe 100644 --- a/frontend/src/components/MarkdownRenderer.tsx +++ b/frontend/src/components/MarkdownRenderer.tsx @@ -75,7 +75,7 @@ function triggerImageDownload(url: string, alt: string) { function renderInline(text: string): string { const tokens: string[] = []; const stash = (html: string) => { - const key = `@@MDTOKEN${tokens.length}@@`; + const key = `@@CLAWITHMDTOKEN${tokens.length}@@`; tokens.push(html); return key; }; @@ -115,7 +115,7 @@ function renderInline(text: string): string { working = autolinkBareUrls(working); tokens.forEach((html, i) => { - working = working.replace(new RegExp(`@@MDTOKEN${i}@@`, 'g'), html); + working = working.replace(new RegExp(`@@CLAWITHMDTOKEN${i}@@`, 'g'), html); }); return working; } diff --git a/frontend/src/pages/agent-detail/AgentDetailPage.tsx b/frontend/src/pages/agent-detail/AgentDetailPage.tsx index 049ebb72f..d2b9dfd24 100644 --- a/frontend/src/pages/agent-detail/AgentDetailPage.tsx +++ b/frontend/src/pages/agent-detail/AgentDetailPage.tsx @@ -2145,6 +2145,10 @@ export default function AgentDetailPage() { const [scopeDropdownOpen, setScopeDropdownOpen] = useState(false); const scopeDropdownRef = useRef(null); const [historyMsgs, setHistoryMsgs] = useState([]); + const [historyOffset, setHistoryOffset] = useState(0); + const [historyHasMore, setHistoryHasMore] = useState(true); + const [historyLoadingMore, setHistoryLoadingMore] = useState(false); + const HISTORY_PAGE_SIZE = 20; const [sessionsLoading, setSessionsLoading] = useState(false); const [allSessionsLoading, setAllSessionsLoading] = useState(false); const [agentExpired, setAgentExpired] = useState(false); @@ -2368,6 +2372,9 @@ export default function AgentDetailPage() { pendingHistoryInitialScrollRef.current = !writable; setChatMessages([]); setHistoryMsgs([]); + setHistoryOffset(0); + setHistoryHasMore(true); + setHistoryLoadingMore(false); setIsStreaming(runtimeState.isStreaming); setIsWaiting(runtimeState.isWaiting); setActiveSession(sess); @@ -2382,7 +2389,7 @@ export default function AgentDetailPage() { const loadSeq = ++sessionLoadSeqRef.current; try { const tkn = localStorage.getItem('token'); - const res = await fetch(`/api/agents/${targetAgentId}/sessions/${sess.id}/messages`, { + const res = await fetch(`/api/agents/${targetAgentId}/sessions/${sess.id}/messages?limit=${HISTORY_PAGE_SIZE}&offset=0`, { headers: { Authorization: `Bearer ${tkn}` }, signal: controller.signal, }); @@ -2398,6 +2405,7 @@ export default function AgentDetailPage() { ...(m.created_at && { timestamp: m.created_at }), ...(m.id && { id: m.id }), })); + setHistoryHasMore(msgs.length >= HISTORY_PAGE_SIZE); if (writable) { setChatMessages(preParsed); @@ -3343,11 +3351,60 @@ export default function AgentDetailPage() { window.setTimeout(scroll, 120); window.setTimeout(scroll, 360); }, []); + + const loadMoreHistoryMessages = useCallback(async () => { + if (historyLoadingMore || !historyHasMore || !activeSession || !id) return; + const sess = activeSession; + const targetAgentId = id; + setHistoryLoadingMore(true); + try { + const tkn = localStorage.getItem('token'); + const newOffset = historyOffset + HISTORY_PAGE_SIZE; + const res = await fetch(`/api/agents/${targetAgentId}/sessions/${sess.id}/messages?limit=${HISTORY_PAGE_SIZE}&offset=${newOffset}`, { + headers: { Authorization: `Bearer ${tkn}` }, + }); + if (!res.ok) return; + const msgs = await res.json(); + if (msgs.length === 0) { + setHistoryHasMore(false); + return; + } + const preParsed = msgs.map((m: any) => parseChatMsg({ + role: m.role, content: m.content || '', + ...(m.toolName && { toolName: m.toolName, toolArgs: m.toolArgs, toolStatus: m.toolStatus, toolResult: m.toolResult, toolThinking: m.toolThinking }), + ...(m.thinking && { thinking: m.thinking }), + ...(m.created_at && { timestamp: m.created_at }), + ...(m.id && { id: m.id }), + })); + // Save current scroll position + const el = historyContainerRef.current; + const oldScrollHeight = el?.scrollHeight ?? 0; + setHistoryMsgs(prev => [...preParsed, ...prev]); + setHistoryOffset(newOffset); + setHistoryHasMore(msgs.length >= HISTORY_PAGE_SIZE); + // Restore scroll position after new messages are prepended + requestAnimationFrame(() => { + if (el) { + const newScrollHeight = el.scrollHeight; + el.scrollTop = newScrollHeight - oldScrollHeight; + } + }); + } catch (err: any) { + console.error('Failed to load more history messages:', err); + } finally { + setHistoryLoadingMore(false); + } + }, [historyLoadingMore, historyHasMore, activeSession, id, historyOffset]); + const handleHistoryScroll = () => { const el = historyContainerRef.current; if (!el) return; const distFromBottom = el.scrollHeight - el.scrollTop - el.clientHeight; setShowHistoryScrollBtn(distFromBottom > 200); + // Load more when scrolling near the top + if (el.scrollTop < 100 && historyHasMore && !historyLoadingMore) { + loadMoreHistoryMessages(); + } }; const scrollHistoryToBottom = () => { scheduleHistoryScrollToBottom(); @@ -5853,6 +5910,16 @@ export default function AgentDetailPage() { )}
+ {historyLoadingMore && ( +
+ Loading more messages... +
+ )} + {!historyHasMore && historyMsgs.length > 0 && ( +
+ All messages loaded +
+ )} {(() => { // For A2A sessions, determine which participant is "this agent" (left side) // Use agent.name matching against sender_name from messages