From 8efd87bf9424778c4291f29e0ff6e26451424e76 Mon Sep 17 00:00:00 2001 From: Davide Silvestri <75379892+silvestrid@users.noreply.github.com> Date: Fri, 20 Mar 2026 09:25:37 +0100 Subject: [PATCH 1/5] fix: add index to row history table (#4991) --- ...whistory_database_ro_action__6ea699_idx.py | 19 ++++++++++++ .../baserow/contrib/database/rows/history.py | 29 +++++++++++++++---- .../baserow/contrib/database/rows/models.py | 7 ++++- ...lve_slow_cleanup_of_row_history_table.json | 9 ++++++ 4 files changed, 58 insertions(+), 6 deletions(-) create mode 100644 backend/src/baserow/contrib/database/migrations/0206_rowhistory_database_ro_action__6ea699_idx.py create mode 100644 changelog/entries/unreleased/bug/add_index_to_resolve_slow_cleanup_of_row_history_table.json diff --git a/backend/src/baserow/contrib/database/migrations/0206_rowhistory_database_ro_action__6ea699_idx.py b/backend/src/baserow/contrib/database/migrations/0206_rowhistory_database_ro_action__6ea699_idx.py new file mode 100644 index 0000000000..f0a4ba6a3d --- /dev/null +++ b/backend/src/baserow/contrib/database/migrations/0206_rowhistory_database_ro_action__6ea699_idx.py @@ -0,0 +1,19 @@ +# Generated by Django 5.2.12 on 2026-03-17 09:16 + +from django.db import migrations, models +from django.contrib.postgres.operations import AddIndexConcurrently + + +class Migration(migrations.Migration): + atomic = False + + dependencies = [ + ('database', '0205_formvieweditrowfield'), + ] + + operations = [ + AddIndexConcurrently( + model_name='rowhistory', + index=models.Index(fields=['action_timestamp'], name='database_ro_action__6ea699_idx'), + ), + ] diff --git a/backend/src/baserow/contrib/database/rows/history.py b/backend/src/baserow/contrib/database/rows/history.py index d9a55e2839..aa13498020 100644 --- a/backend/src/baserow/contrib/database/rows/history.py +++ b/backend/src/baserow/contrib/database/rows/history.py @@ -2,7 +2,7 @@ from itertools import groupby from django.conf import settings -from django.db import router +from django.db import connection from django.db.models import QuerySet from django.dispatch import receiver @@ -18,6 +18,7 @@ from baserow.contrib.database.rows.types import ActionData from baserow.core.action.signals import action_done from baserow.core.models import Workspace +from baserow.core.psycopg import sql from baserow.core.telemetry.utils import baserow_trace from baserow.core.types import AnyUser @@ -68,15 +69,33 @@ def list_row_history( return queryset @classmethod - def delete_entries_older_than(cls, cutoff: datetime): + def delete_entries_older_than(cls, cutoff: datetime, batch_size: int = 20_000): """ - Deletes all row history entries that are older than the given cutoff date. + Deletes all row history entries that are older than the given cutoff date + in batches to avoid long-running transactions. :param cutoff: The date and time before which all entries will be deleted. + :param batch_size: The number of rows to delete per batch. """ - delete_qs = RowHistory.objects.filter(action_timestamp__lt=cutoff) - delete_qs._raw_delete(using=router.db_for_write(delete_qs.model)) + table = sql.Identifier(RowHistory._meta.db_table) + query = sql.SQL( + """ + WITH to_delete AS ( + SELECT id FROM {table} + WHERE action_timestamp < %s + LIMIT %s + ) + DELETE FROM {table} + USING to_delete + WHERE {table}.id = to_delete.id + """ + ).format(table=table) + while True: + with connection.cursor() as cursor: + cursor.execute(query, [cutoff, batch_size]) + if cursor.rowcount == 0: + break @receiver(action_done) diff --git a/backend/src/baserow/contrib/database/rows/models.py b/backend/src/baserow/contrib/database/rows/models.py index efdc954a73..075e08cfab 100644 --- a/backend/src/baserow/contrib/database/rows/models.py +++ b/backend/src/baserow/contrib/database/rows/models.py @@ -59,4 +59,9 @@ class RowHistory(models.Model): class Meta: ordering = ("-action_timestamp", "-id") - indexes = [models.Index(fields=["table", "row_id", "-action_timestamp", "-id"])] + indexes = [ + # For deleting history entries by action timestamp. + models.Index(fields=["action_timestamp"]), + # For listing the history of a row. + models.Index(fields=["table", "row_id", "-action_timestamp", "-id"]), + ] diff --git a/changelog/entries/unreleased/bug/add_index_to_resolve_slow_cleanup_of_row_history_table.json b/changelog/entries/unreleased/bug/add_index_to_resolve_slow_cleanup_of_row_history_table.json new file mode 100644 index 0000000000..eeca089da0 --- /dev/null +++ b/changelog/entries/unreleased/bug/add_index_to_resolve_slow_cleanup_of_row_history_table.json @@ -0,0 +1,9 @@ +{ + "type": "bug", + "message": "Fixed slow cleanup of row history table by adding a database index.", + "issue_origin": "github", + "issue_number": null, + "domain": "database", + "bullet_points": [], + "created_at": "2026-03-17" +} \ No newline at end of file From 7ff02b86cb125cafa7414382b7fee6cc003b57df Mon Sep 17 00:00:00 2001 From: dimmur-brw Date: Fri, 20 Mar 2026 11:52:25 +0100 Subject: [PATCH 2/5] Silence defined error codes (#5012) --- .../5011_silence_defined_error_codes_in_sentry.json | 9 +++++++++ web-frontend/modules/core/utils/sentryErrors.js | 5 +++++ web-frontend/sentry.client.config.ts | 10 +++++++++- web-frontend/sentry.server.config.ts | 10 +++++++++- 4 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 changelog/entries/unreleased/refactor/5011_silence_defined_error_codes_in_sentry.json create mode 100644 web-frontend/modules/core/utils/sentryErrors.js diff --git a/changelog/entries/unreleased/refactor/5011_silence_defined_error_codes_in_sentry.json b/changelog/entries/unreleased/refactor/5011_silence_defined_error_codes_in_sentry.json new file mode 100644 index 0000000000..c1aa56f4a9 --- /dev/null +++ b/changelog/entries/unreleased/refactor/5011_silence_defined_error_codes_in_sentry.json @@ -0,0 +1,9 @@ +{ + "type": "refactor", + "message": "Silence defined error codes in Sentry", + "issue_origin": "github", + "issue_number": 5011, + "domain": "core", + "bullet_points": [], + "created_at": "2026-03-19" +} \ No newline at end of file diff --git a/web-frontend/modules/core/utils/sentryErrors.js b/web-frontend/modules/core/utils/sentryErrors.js new file mode 100644 index 0000000000..0da78f1ae9 --- /dev/null +++ b/web-frontend/modules/core/utils/sentryErrors.js @@ -0,0 +1,5 @@ +// API error codes to silence in Sentry, keyed by HTTP status code. +// These are handled by the application (e.g. forceLogoff) and are not bugs. +export const SILENCED_API_ERRORS = { + 401: ['ERROR_INVALID_ACCESS_TOKEN', 'ERROR_INVALID_REFRESH_TOKEN'], +} diff --git a/web-frontend/sentry.client.config.ts b/web-frontend/sentry.client.config.ts index 853e39b691..e737d2f2f6 100644 --- a/web-frontend/sentry.client.config.ts +++ b/web-frontend/sentry.client.config.ts @@ -1,5 +1,6 @@ import { useRuntimeConfig, useAppConfig, useRouter } from '#imports' import { makeFakeTransport } from './modules/core/utils/sentryFakeTransport' +import { SILENCED_API_ERRORS } from './modules/core/utils/sentryErrors' import * as Sentry from '@sentry/nuxt' const config = useRuntimeConfig() @@ -30,7 +31,7 @@ if (dsn && dsn !== '') { ...(isDev ? { transport: makeFakeTransport } : {}), beforeSend(event, hint) { const err = hint?.originalException - if (err?.fatal === false || err?.response?.status === 401) return null + if (err?.fatal === false) return null // Filter out axios errors without a response like // network error, timeout, aborted, cancelled requests @@ -38,6 +39,13 @@ if (dsn && dsn !== '') { return null } + // Filter known API errors that are handled by the application (e.g. forceLogoff). + const status = err?.response?.status || err?.statusCode + const errorCode = err?.response?.data?.error || err?.data?.error + if (SILENCED_API_ERRORS[status]?.includes(errorCode)) { + return null + } + if (isDev) { console.error('[Sentry captured error]', `${err}`) return null diff --git a/web-frontend/sentry.server.config.ts b/web-frontend/sentry.server.config.ts index bd6b3b18a5..39ffd55a64 100644 --- a/web-frontend/sentry.server.config.ts +++ b/web-frontend/sentry.server.config.ts @@ -1,6 +1,7 @@ import { useRuntimeConfig } from '#imports' import * as Sentry from '@sentry/nuxt' import { makeFakeTransport } from './modules/core/utils/sentryFakeTransport' +import { SILENCED_API_ERRORS } from './modules/core/utils/sentryErrors' const config = useRuntimeConfig() const dsn = @@ -18,7 +19,14 @@ if (dsn && dsn !== '') { ...(isDev ? { transport: makeFakeTransport } : {}), beforeSend(event, hint) { const err = hint?.originalException - if (err?.fatal === false || err?.response?.status === 401) return null + if (err?.fatal === false) return null + + // Filter known API errors that are handled by the application (e.g. forceLogoff). + const status = err?.response?.status || err?.statusCode + const errorCode = err?.response?.data?.error || err?.data?.error + if (SILENCED_API_ERRORS[status]?.includes(errorCode)) { + return null + } if (isDev) { console.error('[Sentry captured error]', err) return null From b65c7e77cb91a32d6cfff2cf82863d47763ae75e Mon Sep 17 00:00:00 2001 From: Davide Silvestri <75379892+silvestrid@users.noreply.github.com> Date: Fri, 20 Mar 2026 11:53:17 +0100 Subject: [PATCH 3/5] chore refactor (AI Assistant): context offloading; better telemetry data; add google and anthropic models compatibility (#4951) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * chore(deps): replace udspy with pydantic-ai and opentelemetry-sdk Replace the udspy dependency with pydantic-ai-slim (with openai, groq, anthropic, bedrock providers) and opentelemetry-sdk for structured telemetry collection. * fix(sentry): exclude pydantic_ai from auto-enabling integrations sentry-sdk's pydantic_ai integration patches ToolManager._call_tool which was removed in pydantic-ai >= 1.x (now execute_tool_call), causing import-time errors. * feat(settings): add dev log file mirroring and allow embeddings URL in tests - Add BASEROW_LOG_FILE support in dev settings to mirror logs (including loguru output) to a file, useful for AI-assisted debugging. - Allow BASEROW_EMBEDDINGS_API_URL to be overridden via env in test settings for search_user_docs eval tests. * feat(assistant): add message_history field to AssistantChat Add a BinaryField to store serialized pydantic-ai message history (JSON bytes) for multi-turn conversation context, replacing the previous udspy-based conversation state. * refactor(assistant): port to pydantic-ai agent framework Replace udspy with pydantic-ai as the agent framework for the AI assistant. Key changes: - Add Agent definitions with typed deps (AssistantDeps) and dynamic toolsets for runtime tool loading - Add deps module with AssistantDeps, ToolHelpers, and EventBus for streaming events to the UI - Add history module for serializing/deserializing pydantic-ai message history to the database - Add model_profiles for provider-specific configuration (Anthropic, OpenAI, Groq, Bedrock) - Add toolset module with ToolGroup base class replacing the udspy tool registry pattern - Add shared/ with formula_utils and sub-agent helpers - Add tool_types.py per tool module for pydantic-ai ToolDefinition - Port all tool modules (core, database, navigation, automation, search_user_docs) from udspy decorators to pydantic-ai Tool instances - Port assistant orchestrator, handler, and prompts - Remove signatures.py (replaced by pydantic-ai output types) * refactor(assistant): update telemetry for pydantic-ai Rework telemetry collection to use pydantic-ai's message history format and OpenTelemetry SDK for structured span/event recording, replacing the previous udspy-based telemetry hooks. * test(assistant): update unit tests for pydantic-ai port Rewrite assistant unit tests to use pydantic-ai's testing utilities (TestModel, FunctionModel) instead of udspy mocks. Add new test files for core tools, navigation tools, and search docs tools. Remove obsolete skip file. * test(assistant): add LLM eval test suite Add end-to-end eval tests that run the real agent against a live LLM to verify tool selection, schema compatibility, and output quality. Includes evals for: navigation, core builders, database tables/rows, sample rows, automation workflows, search_user_docs, and cross-cutting structured output validation. Tests are marked with @pytest.mark.eval and skipped by default. Configure via EVAL_LLM_MODEL or EVAL_LLM_MODELS env vars. * docs: add eval guide and update AI assistant installation docs - Add docs/development/ai-assistant-evals.md with instructions for running the eval suite, configuring models, and writing new evals. - Update docs/installation/ai-assistant.md to reflect pydantic-ai provider configuration replacing the previous udspy setup. * fix(assistant): fix test patch paths, optional filter args, and eval marker - Fix mock patch paths from `assistant.agent` to `assistant.agents` - Make ListTablesFilterArg fields optional to prevent LLM validation errors - Surface field_errors in create_fields tool result - Simplify EvalToolTracker to use message history inspection - Register `eval` pytest marker and skip evals by default Co-Authored-By: Claude Opus 4.6 * docs: move testing docs to docs/testing/ and add PR test plan Move ai-assistant-evals.md from docs/development/ to docs/testing/, add ai-assistant-test-plan.md with manual and automated test steps for the pydantic-ai port PR. * refactor(assistant): extract row models to types/rows.py, rename utils to helpers - Move FieldDefinition, row model builders, and get_link_row_hints to new types/rows.py module with dict-of-callables dispatch replacing match/case - Simplify update model: fields are optional (omit = don't change), removing the __NO_CHANGE__ sentinel - Move get_table_rows_tools into tools.py as _build_row_tools since it builds pydantic-ai Tool objects - Rename utils.py to helpers.py for clarity, remove dead list_tables - Add docstrings with :param/:returns to all public functions, add proper type annotations throughout * refactor(assistant): flatten field/view/filter types into single models - Replace per-type config classes in fields.py with a single flat FieldItemCreate model using optional type-specific fields and a model_validator for type aliases - Simplify view_filters.py and views.py type hierarchies similarly - Update table.py types and corresponding tests * fix(assistant): improve telemetry span processor and minor tweaks - Replace SpanExporter with SpanProcessor for real-time span handling - Remap child tool spans past 'running tools' grouping span - Parse JSON string arguments in tool call parts - Add data_brief parameter to sample rows prompt - Disable reasoning_effort for groq models temporarily - Rename _fix_formula to _fix_formula_field in helpers * refactor(assistant): use ISO strings for dates and add dateutil fallback Replace Date/Datetime Pydantic model objects with compact ISO 8601 strings. Add lenient parsing via dateutil.parser fallback when fromisoformat fails. * feat(assistant): add DO/EXPLAIN agent modes with switch_mode tool Introduce AgentMode enum (DO/EXPLAIN) and ModeAwareToolset that filters available tools based on mode. DO mode (default) exposes all action tools except search_user_docs. EXPLAIN mode exposes only read-only tools (list_*, navigate) plus search_user_docs for answering Baserow feature questions. The switch_mode tool allows bidirectional switching. * refactor(assistant): flatten automation node types and add $formula: convention Replace 13+ per-type action node classes (RouterNodeCreate, SendEmailActionCreate, etc.) with a single flat ActionNodeCreate model using @model_validator for per-type validation and dict-dispatched functions for ORM conversion and formula generation. Add $formula: prefix convention — values prefixed with '$formula:' are sent to the LLM formula generator, plain values become literal formulas. Also detect raw formula expressions (get(), concat(), etc.) written inline. Add trigger validation: periodic triggers now require periodic_interval (with automatic folding of flat fields), row triggers require rows_triggers_settings. * build: bump pydantic-ai-slim to 0.1.66 and anthropic to 0.84.0 * feat(assistant): add RetryingModel for transient provider error recovery Wraps pydantic-ai model instances to automatically retry on transient errors (rate limits, timeouts, server errors) with exponential backoff. Handles both streaming and non-streaming calls. * feat(assistant): add AgentMode system with ModeAwareToolset and switch_mode Introduce domain modes (DATABASE, APPLICATION, AUTOMATION, EXPLAIN) that control which tools are visible to the agent. ModeAwareToolset filters the combined toolset per-mode, registries generate per-mode manifests, and switch_mode lets the agent transition between domains. Each mode gets a cross-mode summary so the agent knows what other modes offer. * refactor(assistant): integrate RetryingModel, event-based streaming, and JSON retry Replace direct model usage with RetryingModel for resilience. Rewrite streaming to use run_stream_events for proper text/reasoning/tool event handling. Add JSON-tool-call-as-text detection with automatic retry. Auto-detect starting mode from UI context. Update model_profiles with max_tokens settings. * refactor(assistant): improve shared formula utils and add formula language reference Add RAW_FORMULA_RE for detecting raw formula expressions, needs_formula() for $formula: prefix and raw formula detection, literal_or_placeholder() for ORM value creation, and a shared formula language prompt. Improve formula generator to track remaining unresolved fields across retries. * refactor(assistant): improve database tools with routing rules and type fixes Add per-module routing rules via get_routing_rules(). Extract ToolInputError to helpers. Fix field type validators, improve row model handling, and refine view filter types. Update agents and prompts for better tool guidance. * refactor(assistant): improve automation tools with routing rules and formula handling Add per-module routing rules for automation. Improve node type handling with better formula context support. Refine automation agents and prompts. * fix(assistant): improve telemetry span processor with real-time remapping Enhance SpanProcessor with JSON arg parsing, real-time span remapping, and improved trace output handling. Update tests to cover new behavior. * test(assistant): update tests for mode system, retry logic, and type refactors Add tests for switch_mode, mode-aware manifests, JSON retry logic. Add test_assistant_automation_node_tools and test_assistant_database_field_tools. Update existing database/automation/core tests for refactored types and new tool signatures. Update eval_utils for new deps structure. * fix: lint * chore(frontend): ignore .claude dir in vite watcher and fix Nitro EMFILE - Add .claude/ to vite server watch ignore list alongside node_modules and .git to avoid unnecessary file watching in worktrees. - Configure Nitro devStorage to use fs-lite driver to prevent chokidar from watching the entire repo root, which causes EMFILE on macOS in large monorepos. * fix(tests): fix test settings and seat usage test isolation with xdist - Ensure pytest always finds backend/pytest.ini by passing -c pytest.ini explicitly, fixing DJANGO_SETTINGS_MODULE=dev when running from root - Preserve existing TEST dict keys when setting MIGRATE in test settings - Add transaction=True to seat usage tests to prevent data leaking from TransactionTestCase tests running on the same xdist worker * fix(assistant): fix field types, formula regex, validator guard, and consolidate evals Fix multiple_select returning None instead of [], link_row description typo, formula regex missing greater_than_or_equal/less_than_or_equal variants, and guard against overwriting original validators on repeated prepare_tools calls. Consolidate sample rows and navigation evals into the database tables eval file and remove the meta tool-call history test. * fix(assistant): improve tool return types, filter aliases, and eval infrastructure - Return consistent dict types from all tools instead of plain strings - Add operator aliases for view filters so LLMs can use natural names - Fix boolean filter operator (is → equal) - Remove reasoning format from UTILITY model profiles (pollutes structured output) - Add ModelRetry for workflow creation and formula agent errors - Add EvalChecklist for soft assertions with pass/fail scoring - Add EVAL_RETRIES support for flake detection in eval tests - Suppress loguru DEBUG noise during evals * refactor(assistant): extract table creation helpers and remove unused model profile - Extract _create_empty_tables and _create_table_fields from create_tables - Filter out duplicate primary field in field creation to avoid model mistakes - Remove unused gpt-oss-20b model profile - Always attempt sample rows regardless of field errors * fix(assistant): strip tags and unify streaming as reasoning chunks Models like MiniMax-M2.5 emit ... tags inline. Handle ThinkingPart/ThinkingPartDelta events from pydantic-ai and extract inline thinking from text parts as a fallback. Stream all content as AiReasoningChunk during the agent run; the final answer is emitted as AiMessageChunk by _emit_answer. * fix(assistant): simplify streaming and add collapsible reasoning UI Replace _accumulate_text/_extract_thinking with a single _get_content_delta helper that forwards text/thinking deltas. Accumulate reasoning_so_far and strip tags before sending to frontend (which replaces content on each chunk). Add collapsible reasoning bubble (max 250px with fade mask and chevron toggle). * fix(assistant): bridge legacy UDSPY_LM_* env vars to pydantic-ai config * docs(assistant): improve ai-assistant.md and add AWS_REGION_NAME backward compat - Add both Bedrock auth methods (boto3 creds + bearer token) - Add section 6 with pydantic-ai model overview link and provider list - Restructure migration table: unchanged / bridged / new variables - Fix AWS_BEARER_TOKEN_BEDROCK incorrectly listed as removed - Bridge AWS_REGION_NAME to AWS_DEFAULT_REGION in settings for backward compat * minor doc/evals fixes * fix(assistant): strip unclosed tags during streaming Models behind Groq emit tags as text content rather than using the native thinking protocol. During streaming, the closing tag may not have arrived yet, causing raw thinking content to leak to the frontend. Also strip think tags from tool thought fields and reset reasoning on tool results. * docs(assistant): use provider:model format and refresh eval docs - Update all docs to use pydantic-ai provider:model format (colon separator) - Fix mixed-up provider descriptions in configuration.md - Refresh eval docs: replace assert_no_tool_errors with EvalChecklist pattern - Add embeddings URL for local vs Docker in ai-assistant-evals.md - Skip KB sync post_migrate signal during tests - Fix temperature type in model_profiles.py - Refactor justfile PYTHONPATH for test recipe * fix(assistant): clean up navigation and improve tool types - Remove unused WorkspaceNavigationRequestType - Narrow exception catch in navigate tool to ObjectDoesNotExist - Guard id field in CreateRowModel.from_django_orm - Use id__in for batch filtering in ListTablesFilterArg * Revert "chore(frontend): ignore .claude dir in vite watcher and fix Nitro EMFILE" This reverts commit 4be1be19d2ba6c778cf9321156368736fe52001e. * fix: ai-assistant-test-plan.md tool smoke test prompt for list_builders * fix: Posthog env var names in docs/testing/ai-assistant-test-plan.md * fix: wrong doc reference --- backend/justfile | 7 +- backend/pyproject.toml | 7 +- backend/pytest.ini | 1 + backend/src/baserow/config/settings/base.py | 10 +- backend/src/baserow/config/settings/dev.py | 18 + backend/src/baserow/config/settings/test.py | 10 +- backend/uv.lock | 195 +- ...ntic-ai_replace_udspy_with_pydanticai.json | 9 + docs/installation/ai-assistant.md | 110 +- docs/installation/configuration.md | 12 +- docs/installation/install-with-helm.md | 2 +- docs/testing/ai-assistant-evals.md | 255 +++ docs/testing/ai-assistant-test-plan.md | 134 ++ enterprise/backend/pytest.ini | 2 + .../baserow_enterprise/api/assistant/views.py | 6 +- .../backend/src/baserow_enterprise/apps.py | 72 +- .../baserow_enterprise/assistant/agents.py | 81 + .../baserow_enterprise/assistant/assistant.py | 834 ++++---- .../src/baserow_enterprise/assistant/deps.py | 148 ++ .../baserow_enterprise/assistant/handler.py | 9 +- .../baserow_enterprise/assistant/history.py | 118 ++ .../assistant/model_profiles.py | 175 ++ .../baserow_enterprise/assistant/models.py | 8 + .../baserow_enterprise/assistant/prompts.py | 212 +- .../assistant/retrying_model.py | 448 +++++ .../assistant/signatures.py | 33 - .../src/baserow_enterprise/assistant/tasks.py | 2 +- .../baserow_enterprise/assistant/telemetry.py | 715 +++++-- .../assistant/tools/__init__.py | 6 +- .../assistant/tools/automation/__init__.py | 5 - .../assistant/tools/automation/agents.py | 156 ++ .../assistant/tools/automation/helpers.py | 288 +++ .../assistant/tools/automation/prompts.py | 35 +- .../assistant/tools/automation/tool_types.py | 20 + .../assistant/tools/automation/tools.py | 440 +++- .../tools/automation/types/__init__.py | 22 +- .../assistant/tools/automation/types/node.py | 1070 ++++++---- .../tools/automation/types/workflow.py | 58 +- .../assistant/tools/automation/utils.py | 390 ---- .../assistant/tools/core/tool_types.py | 15 + .../assistant/tools/core/tools.py | 275 +-- .../assistant/tools/core/types.py | 6 +- .../assistant/tools/database/agents.py | 284 +++ .../assistant/tools/database/helpers.py | 388 ++++ .../assistant/tools/database/prompts.py | 64 + .../assistant/tools/database/tool_types.py | 20 + .../assistant/tools/database/tools.py | 1772 ++++++++++------- .../tools/database/types/__init__.py | 1 + .../assistant/tools/database/types/base.py | 64 +- .../assistant/tools/database/types/fields.py | 991 +++++---- .../assistant/tools/database/types/rows.py | 383 ++++ .../assistant/tools/database/types/table.py | 105 +- .../tools/database/types/view_filters.py | 736 ++----- .../assistant/tools/database/types/views.py | 545 +++-- .../assistant/tools/database/utils.py | 559 ------ .../assistant/tools/navigation/tool_types.py | 15 + .../assistant/tools/navigation/tools.py | 77 +- .../assistant/tools/navigation/types.py | 20 +- .../assistant/tools/navigation/utils.py | 10 +- .../assistant/tools/registries.py | 237 ++- .../tools/search_user_docs/tool_types.py | 20 + .../assistant/tools/search_user_docs/tools.py | 402 ++-- .../assistant/tools/shared/__init__.py | 25 + .../assistant/tools/shared/agents.py | 141 ++ .../assistant/tools/shared/formula_prompt.py | 84 + .../assistant/tools/shared/formula_utils.py | 272 +++ .../assistant/tools/toolset.py | 438 ++++ .../src/baserow_enterprise/assistant/types.py | 19 +- .../config/settings/settings.py | 33 +- .../0058_assistantchat_message_history.py | 24 + .../assistant/evals/__init__.py | 1 + .../assistant/evals/conftest.py | 153 ++ .../assistant/evals/eval_utils.py | 372 ++++ .../evals/test_eval_automation_workflows.py | 845 ++++++++ .../evals/test_eval_core_builders.py | 201 ++ .../evals/test_eval_database_rows.py | 214 ++ .../evals/test_eval_database_tables.py | 1164 +++++++++++ .../evals/test_eval_search_user_docs.py | 276 +++ .../assistant/test_assistant.py | 1110 ++++------- .../test_assistant_automation_node_tools.py | 304 +++ ...est_assistant_automation_workflow_tools.py | 275 +-- .../assistant/test_assistant_core_tools.py | 116 ++ .../test_assistant_database_field_tools.py | 151 ++ .../test_assistant_database_rows_tools.py | 163 +- .../test_assistant_database_table_tools.py | 756 ++++--- .../test_assistant_database_tools.py.skip | 52 - ...t_assistant_database_view_filters_tools.py | 551 ++--- .../test_assistant_database_views_tools.py | 397 ++-- .../test_assistant_navigation_tools.py | 47 + .../test_assistant_search_docs_tools.py | 104 + .../assistant/test_retrying_model.py | 472 +++++ .../assistant/test_telemetry.py | 691 +++++-- .../assistant/utils.py | 28 +- .../baserow_enterprise_tests/conftest.py | 6 +- .../enterprise/test_enterprise_license.py | 64 +- .../assets/scss/components/assistant.scss | 32 + .../assistant/AssistantMessageList.vue | 22 + 97 files changed, 15602 insertions(+), 7113 deletions(-) create mode 100644 changelog/entries/unreleased/refactor/Replace udspy with pydantic-ai_replace_udspy_with_pydanticai.json create mode 100644 docs/testing/ai-assistant-evals.md create mode 100644 docs/testing/ai-assistant-test-plan.md create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/agents.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/deps.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/history.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/model_profiles.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/retrying_model.py delete mode 100644 enterprise/backend/src/baserow_enterprise/assistant/signatures.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/automation/agents.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/automation/helpers.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/automation/tool_types.py delete mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/automation/utils.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/core/tool_types.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/database/agents.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/database/helpers.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/database/prompts.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/database/tool_types.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/rows.py delete mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tool_types.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/search_user_docs/tool_types.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/shared/__init__.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/shared/agents.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/shared/formula_prompt.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/shared/formula_utils.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/toolset.py create mode 100644 enterprise/backend/src/baserow_enterprise/migrations/0058_assistantchat_message_history.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/__init__.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/conftest.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/eval_utils.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_automation_workflows.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_core_builders.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_rows.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_tables.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_search_user_docs.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_automation_node_tools.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_core_tools.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_field_tools.py delete mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_tools.py.skip create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_navigation_tools.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_search_docs_tools.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/test_retrying_model.py diff --git a/backend/justfile b/backend/justfile index a87d620520..dcfc6dd74d 100644 --- a/backend/justfile +++ b/backend/justfile @@ -81,9 +81,10 @@ uv_run := "uv run --active" # Repo root (parent of backend/) - clean() normalizes path (removes ..) repo_root := clean(justfile_directory() / "..") +_set_pythonpath := 'export PYTHONPATH="' + repo_root / 'backend/src:' + repo_root / 'premium/backend/src:' + repo_root / 'enterprise/backend/src:' + repo_root / 'backend/tests:' + repo_root / 'premium/backend/tests:' + repo_root / 'enterprise/backend/tests${PYTHONPATH:+:$PYTHONPATH}"' # Helper to load .env.local if present and set PYTHONPATH with absolute paths # Include this at the start of bash recipes that need env vars -_load_env := 'if [ -f "../.env.local" ]; then set -a; source "../.env.local"; set +a; fi; export PYTHONPATH="' + repo_root / 'backend/src:' + repo_root / 'premium/backend/src:' + repo_root / 'enterprise/backend/src:' + repo_root / 'backend/tests:' + repo_root / 'premium/backend/tests:' + repo_root / 'enterprise/backend/tests${PYTHONPATH:+:$PYTHONPATH}"' +_load_env := 'if [ -f "../.env.local" ]; then set -a; source "../.env.local"; set +a; fi; ' + _set_pythonpath # Source directories backend_source_dirs := "src/ ../premium/backend/src/ ../enterprise/backend/src/" @@ -228,14 +229,14 @@ alias f := fix # PYTHONPATH for test fixtures across all test directories test_pythonpath := "tests:../premium/backend/tests:../enterprise/backend/tests" -_pytest := 'PYTHONPATH="' + test_pythonpath + ':${PYTHONPATH:-}" ' + uv_run + ' pytest' +_pytest := 'PYTHONPATH="' + test_pythonpath + ':${PYTHONPATH:-}" ' + uv_run + ' pytest -c pytest.ini' # Run tests. Pass -n=auto to run in parallel with pytest-xdist [group('3 - testing')] test *ARGS: _check-dev #!/usr/bin/env bash set -euo pipefail - {{ _load_env }} + {{ _set_pythonpath }} {{ _pytest }} {{ ARGS }} # Run tests with coverage report diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 2b4f0d62da..2f50270b68 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -89,8 +89,8 @@ dependencies = [ "langchain==0.3.28", "langchain-openai==0.3.35", "openai==2.14.0", - "anthropic==0.77.0", - "mistralai==1.1.0", + "anthropic==0.84.0", + "mistralai==2.0.0", "icalendar==6.3.2", "jira2markdown==0.5", "openpyxl==3.1.5", @@ -100,7 +100,8 @@ dependencies = [ "genson==1.3.0", "pyotp==2.9.0", "qrcode==8.2", - "udspy==0.1.8", + "pydantic-ai-slim[anthropic,bedrock,google,groq,openai]==1.66.0", + "opentelemetry-sdk>=1.20.0", "netifaces==0.11.0", "requests-futures>=1.0.2", ] diff --git a/backend/pytest.ini b/backend/pytest.ini index d15be43dcd..7fc761fe45 100644 --- a/backend/pytest.ini +++ b/backend/pytest.ini @@ -56,3 +56,4 @@ markers = workspace_search: All tests related to workspace search functionality enable_all_signals: Disables signal deferral for this test (all signals enabled) enable_signals: Enables specific signals for this test (accepts dotted callable paths) + eval: mark test as an eval test (requires LLM API key) diff --git a/backend/src/baserow/config/settings/base.py b/backend/src/baserow/config/settings/base.py index 86213e7f05..6796069e99 100644 --- a/backend/src/baserow/config/settings/base.py +++ b/backend/src/baserow/config/settings/base.py @@ -1332,15 +1332,15 @@ def __setitem__(self, key, value): from sentry_sdk.integrations.django import DjangoIntegration from sentry_sdk.scrubber import DEFAULT_DENYLIST, EventScrubber - # Exclude the langchain integration from auto-discovery: its module-level - # imports are incompatible with Python 3.14 (langchain/pydantic type - # evaluation crash), and the import happens before disabled_integrations - # can take effect. + # Exclude integrations whose module-level imports are incompatible: + # - langchain: Python 3.14 type evaluation crash + # - pydantic_ai: sentry-sdk patches ToolManager._call_tool which was + # removed in pydantic-ai >= 1.x (now execute_tool_call) _sentry_integrations._AUTO_ENABLING_INTEGRATIONS[:] = [ entry for entry in _sentry_integrations._AUTO_ENABLING_INTEGRATIONS - if "langchain" not in entry + if "langchain" not in entry and "pydantic_ai" not in entry ] SENTRY_DENYLIST = DEFAULT_DENYLIST + ["username", "email", "name"] diff --git a/backend/src/baserow/config/settings/dev.py b/backend/src/baserow/config/settings/dev.py index db3b3c4434..532eff15c0 100755 --- a/backend/src/baserow/config/settings/dev.py +++ b/backend/src/baserow/config/settings/dev.py @@ -66,6 +66,24 @@ post_migrate.connect(setup_dev_e2e, dispatch_uid="setup_dev_e2e") +# Mirror logs to a file when BASEROW_LOG_FILE is set (e.g. for AI access when +# running locally). Truncated on each restart. +BASEROW_LOG_FILE = os.getenv("BASEROW_LOG_FILE", "") +if BASEROW_LOG_FILE: + LOGGING["handlers"]["file"] = { # noqa: F405 + "class": "logging.FileHandler", + "filename": BASEROW_LOG_FILE, + "formatter": "console", + "mode": "w", + } + LOGGING["root"]["handlers"].append("file") # noqa: F405 + + # Also route loguru to the same file so modules using loguru (e.g. + # the assistant telemetry) appear alongside stdlib log output. + from loguru import logger as _loguru_logger + + _loguru_logger.add(BASEROW_LOG_FILE, mode="a") + try: from .local import * # noqa: F403, F401 except ImportError: diff --git a/backend/src/baserow/config/settings/test.py b/backend/src/baserow/config/settings/test.py index 8245a298ce..d6449937e8 100644 --- a/backend/src/baserow/config/settings/test.py +++ b/backend/src/baserow/config/settings/test.py @@ -26,13 +26,13 @@ TEST_ENV_VARS = {} # Prefixes for vars that can be overridden via env vars (for DB/Redis configuration) -ALLOWED_ENV_PREFIXES = ("DATABASE_",) +ALLOWED_ENV_PREFIXES = ("DATABASE_", "BASEROW_EMBEDDINGS_API_URL") def getenv_for_tests(key: str, default: str = "") -> str: """ Get env var for tests: - - DATABASE_* vars: check real env first, then TEST_ENV_FILE, then default + - ALLOWED_ENV_PREFIXES vars: use real env var if set, else TEST_ENV_FILE, else default - Other vars: only use TEST_ENV_FILE or default (never real env) """ @@ -65,9 +65,9 @@ def getenv_for_tests(key: str, default: str = "") -> str: BASEROW_TESTS_SETUP_DB_FIXTURE = str_to_bool( os.getenv("BASEROW_TESTS_SETUP_DB_FIXTURE", "on") ) -DATABASES["default"]["TEST"] = { - "MIGRATE": not BASEROW_TESTS_SETUP_DB_FIXTURE, -} +DATABASES["default"].setdefault("TEST", {})[ + "MIGRATE" +] = not BASEROW_TESTS_SETUP_DB_FIXTURE # Open a second database connection that can be used to test transactions. DATABASES["default-copy"] = deepcopy(DATABASES["default"]) diff --git a/backend/uv.lock b/backend/uv.lock index edd52787b9..6715e97cc4 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -40,7 +40,7 @@ wheels = [ [[package]] name = "anthropic" -version = "0.77.0" +version = "0.84.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -52,9 +52,9 @@ dependencies = [ { name = "sniffio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/eb/85/6cb5da3cf91de2eeea89726316e8c5c8c31e2d61ee7cb1233d7e95512c31/anthropic-0.77.0.tar.gz", hash = "sha256:ce36efeb80cb1e25430a88440dc0f9aa5c87f10d080ab70a1bdfd5c2c5fbedb4", size = 504575, upload-time = "2026-01-29T18:20:41.507Z" } +sdist = { url = "https://files.pythonhosted.org/packages/04/ea/0869d6df9ef83dcf393aeefc12dd81677d091c6ffc86f783e51cf44062f2/anthropic-0.84.0.tar.gz", hash = "sha256:72f5f90e5aebe62dca316cb013629cfa24996b0f5a4593b8c3d712bc03c43c37", size = 539457, upload-time = "2026-02-25T05:22:38.54Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ac/27/9df785d3f94df9ac72f43ee9e14b8120b37d992b18f4952774ed46145022/anthropic-0.77.0-py3-none-any.whl", hash = "sha256:65cc83a3c82ce622d5c677d0d7706c77d29dc83958c6b10286e12fda6ffb2651", size = 397867, upload-time = "2026-01-29T18:20:39.481Z" }, + { url = "https://files.pythonhosted.org/packages/64/ca/218fa25002a332c0aa149ba18ffc0543175998b1f65de63f6d106689a345/anthropic-0.84.0-py3-none-any.whl", hash = "sha256:861c4c50f91ca45f942e091d83b60530ad6d4f98733bfe648065364da05d29e7", size = 455156, upload-time = "2026-02-25T05:22:40.468Z" }, ] [[package]] @@ -255,6 +255,7 @@ dependencies = [ { name = "prosemirror", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "psutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "psycopg2-binary", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pydantic-ai-slim", extra = ["anthropic", "bedrock", "google", "groq", "openai"], marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pyotp", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pysaml2", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "qrcode", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -270,7 +271,6 @@ dependencies = [ { name = "twisted", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "tzdata", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "udspy", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "unicodecsv", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "uvicorn", extra = ["standard"], marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "validators", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, @@ -323,7 +323,7 @@ dev = [ [package.metadata] requires-dist = [ - { name = "anthropic", specifier = "==0.77.0" }, + { name = "anthropic", specifier = "==0.84.0" }, { name = "antlr4-python3-runtime", specifier = "==4.9.3" }, { name = "asgiref", specifier = "==3.11.0" }, { name = "boto3", specifier = "==1.42.57" }, @@ -359,7 +359,7 @@ requires-dist = [ { name = "langchain-openai", specifier = "==0.3.35" }, { name = "loguru", specifier = "==0.7.3" }, { name = "mcp", specifier = "==1.26.0" }, - { name = "mistralai", specifier = "==1.1.0" }, + { name = "mistralai", specifier = "==2.0.0" }, { name = "netifaces", specifier = "==0.11.0" }, { name = "openai", specifier = "==2.14.0" }, { name = "openpyxl", specifier = "==3.1.5" }, @@ -381,6 +381,7 @@ requires-dist = [ { name = "opentelemetry-instrumentation-wsgi", specifier = "==0.60b1" }, { name = "opentelemetry-proto", specifier = "==1.39.1" }, { name = "opentelemetry-sdk", specifier = "==1.39.1" }, + { name = "opentelemetry-sdk", specifier = ">=1.20.0" }, { name = "opentelemetry-semantic-conventions", specifier = "==0.60b1" }, { name = "opentelemetry-util-http", specifier = "==0.60b1" }, { name = "pgvector", specifier = "==0.4.2" }, @@ -389,6 +390,7 @@ requires-dist = [ { name = "prosemirror", specifier = "==0.5.2" }, { name = "psutil", specifier = "==7.2.2" }, { name = "psycopg2-binary", specifier = "==2.9.11" }, + { name = "pydantic-ai-slim", extras = ["anthropic", "bedrock", "google", "groq", "openai"], specifier = "==1.66.0" }, { name = "pyotp", specifier = "==2.9.0" }, { name = "pysaml2", specifier = "==7.5.4" }, { name = "qrcode", specifier = "==8.2" }, @@ -404,7 +406,6 @@ requires-dist = [ { name = "twisted", specifier = "==25.5.0" }, { name = "typing-extensions", specifier = ">=4.14.1" }, { name = "tzdata", specifier = "==2025.3" }, - { name = "udspy", specifier = "==0.1.8" }, { name = "unicodecsv", specifier = "==0.14.1" }, { name = "uvicorn", extras = ["standard"], specifier = "==0.40.0" }, { name = "validators", specifier = "==0.35.0" }, @@ -1249,6 +1250,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5e/2e/b41d8a1a917d6581fc27a35d05561037b048e47df50f27f8ac9c7e27a710/freezegun-1.5.5-py3-none-any.whl", hash = "sha256:cd557f4a75cf074e84bc374249b9dd491eaeacd61376b9eb3c423282211619d2", size = 19266, upload-time = "2025-08-09T10:39:06.636Z" }, ] +[[package]] +name = "genai-prices" +version = "0.0.55" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/77/67/de9d9be180db6d80b298c281dff71502095c0776d7cc9286f486f667f61a/genai_prices-0.0.55.tar.gz", hash = "sha256:8692c65d0deefe2ad0680d71841eb12822a35945a6060d2b6adbcbdf4945e1cb", size = 59987, upload-time = "2026-02-26T17:56:41.467Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c4/98/66a06b82a5c840f896490d5ef9c7691776b147589f2e8d2fa66c67a3db9c/genai_prices-0.0.55-py3-none-any.whl", hash = "sha256:ccd795c90c926b3c71066bf5656f14c67fc11fdba6d71e072c7fb4fa311e1b12", size = 62603, upload-time = "2026-02-26T17:56:40.502Z" }, +] + [[package]] name = "genson" version = "1.3.0" @@ -1287,6 +1301,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/db/18/79e9008530b79527e0d5f79e7eef08d3b179b7f851cfd3a2f27822fbdfa9/google_auth-2.47.0-py3-none-any.whl", hash = "sha256:c516d68336bfde7cf0da26aab674a36fedcf04b37ac4edd59c597178760c3498", size = 234867, upload-time = "2026-01-06T21:55:28.6Z" }, ] +[package.optional-dependencies] +requests = [ + { name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] + [[package]] name = "google-cloud-core" version = "2.5.0" @@ -1329,6 +1348,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/42/fa/f50f51260d7b0ef5d4898af122d8a7ec5a84e2984f676f746445f783705f/google_crc32c-1.8.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8b3f68782f3cbd1bce027e48768293072813469af6a61a86f6bb4977a4380f21", size = 33734, upload-time = "2025-12-16T00:40:27.028Z" }, ] +[[package]] +name = "google-genai" +version = "1.66.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "distro", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "google-auth", extra = ["requests"], marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "requests", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "sniffio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "tenacity", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "websockets", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9b/ba/0b343b0770d4710ad2979fd9301d7caa56c940174d5361ed4a7cc4979241/google_genai-1.66.0.tar.gz", hash = "sha256:ffc01647b65046bca6387320057aa51db0ad64bcc72c8e3e914062acfa5f7c49", size = 504386, upload-time = "2026-03-04T22:15:28.156Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/dd/403949d922d4e261b08b64aaa132af4e456c3b15c8e2a2d9e6ef693f66e2/google_genai-1.66.0-py3-none-any.whl", hash = "sha256:7f127a39cf695277104ce4091bb26e417c59bb46e952ff3699c3a982d9c474ee", size = 732174, upload-time = "2026-03-04T22:15:26.63Z" }, +] + [[package]] name = "google-resumable-media" version = "2.8.0" @@ -1391,6 +1431,31 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4f/dc/041be1dff9f23dac5f48a43323cd0789cb798342011c19a248d9c9335536/greenlet-3.3.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:6c10513330af5b8ae16f023e8ddbfb486ab355d04467c4679c5cfe4659975dd9", size = 1676034, upload-time = "2025-12-04T14:27:33.531Z" }, ] +[[package]] +name = "griffelib" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/51/c936033e16d12b627ea334aaaaf42229c37620d0f15593456ab69ab48161/griffelib-2.0.0-py3-none-any.whl", hash = "sha256:01284878c966508b6d6f1dbff9b6fa607bc062d8261c5c7253cb285b06422a7f", size = 142004, upload-time = "2026-02-09T19:09:40.561Z" }, +] + +[[package]] +name = "groq" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "distro", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "sniffio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3f/12/f4099a141677fcd2ed79dcc1fcec431e60c52e0e90c9c5d935f0ffaf8c0e/groq-1.0.0.tar.gz", hash = "sha256:66cb7bb729e6eb644daac7ce8efe945e99e4eb33657f733ee6f13059ef0c25a9", size = 146068, upload-time = "2025-12-17T23:34:23.115Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4a/88/3175759d2ef30406ea721f4d837bfa1ba4339fde3b81ba8c5640a96ed231/groq-1.0.0-py3-none-any.whl", hash = "sha256:6e22bf92ffad988f01d2d4df7729add66b8fd5dbfb2154b5bbf3af245b72c731", size = 138292, upload-time = "2025-12-17T23:34:21.957Z" }, +] + [[package]] name = "gunicorn" version = "23.0.0" @@ -1441,18 +1506,17 @@ wheels = [ [[package]] name = "httpx" -version = "0.27.2" +version = "0.28.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "certifi", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "httpcore", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "idna", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "sniffio", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/78/82/08f8c936781f67d9e6b9eeb8a0c8b4e406136ea4c3d1f89a5db71d42e0e6/httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2", size = 144189, upload-time = "2024-08-27T12:54:01.334Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/56/95/9377bcb415797e44274b51d46e3249eba641711cf3348050f76ee7b15ffc/httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0", size = 76395, upload-time = "2024-08-27T12:53:59.653Z" }, + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] [[package]] @@ -1705,15 +1769,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/73/07/02e16ed01e04a374e644b575638ec7987ae846d25ad97bcc9945a3ee4b0e/jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade", size = 12898, upload-time = "2023-06-16T21:01:28.466Z" }, ] -[[package]] -name = "jsonpath-python" -version = "1.1.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b8/bf/626a72f2d093c5eb4f4de55b443714afa7231beeae40d4a1c69b5c5aa4d1/jsonpath_python-1.1.4.tar.gz", hash = "sha256:bb3e13854e4807c078a1503ae2d87c211b8bff4d9b40b6455ed583b3b50a7fdd", size = 84766, upload-time = "2025-11-25T12:08:39.521Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/bc/52e5bf0d9839e082b976c19afcab7561d0d719c7627483bf5dc251d27eed/jsonpath_python-1.1.4-py3-none-any.whl", hash = "sha256:8700cb8610c44da6e5e9bff50232779c44bf7dc5bc62662d49319ee746898442", size = 12687, upload-time = "2025-11-25T12:08:38.453Z" }, -] - [[package]] name = "jsonpointer" version = "3.0.0" @@ -1904,6 +1959,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/43/71/0f5d010e92ed9747e14bef35e91b6580533510f1e36a8a09eb79ee70b2f0/librt-0.7.8-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:cf243da9e42d914036fd362ac3fa77d80a41cadcd11ad789b1b5eec4daaf67ca", size = 224731, upload-time = "2026-01-14T12:55:58.175Z" }, ] +[[package]] +name = "logfire-api" +version = "4.25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/94/5c/026cec30d85394aec8f5f12d70edbe2d706837bc9a411bd71a542cedae50/logfire_api-4.25.0.tar.gz", hash = "sha256:7562d5adfe3987291039dddb21947c86cb9d832d068c87d9aa23db86ef07095b", size = 75853, upload-time = "2026-02-19T15:27:29.518Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/39/83414c0fadb4f11f90e6b80b631aa79f62a605664f0c4693e2ebc7ee73f3/logfire_api-4.25.0-py3-none-any.whl", hash = "sha256:0d607eb09ef5426e26f376ff277a8d401bc5b7b4178ea66db404e13c368494cf", size = 120473, upload-time = "2026-02-19T15:27:25.832Z" }, +] + [[package]] name = "loguru" version = "0.7.3" @@ -2060,19 +2124,20 @@ wheels = [ [[package]] name = "mistralai" -version = "1.1.0" +version = "2.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "eval-type-backport", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "jsonpath-python", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "opentelemetry-api", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "opentelemetry-semantic-conventions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, { name = "python-dateutil", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "typing-inspect", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-inspection", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f8/9c/4ea3ee3c8aac270e3d7fde9eb18c34209348f89815fbb356d04bf949e2aa/mistralai-1.1.0.tar.gz", hash = "sha256:9d1fe778e0e8c6ddab714e6a64c6096bd39cfe119ff38ceb5019d8e089df08ba", size = 117553, upload-time = "2024-09-17T16:25:53.342Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/5c/22fd7d1ec7e333f83dc5e2d0b176952a5d9a1f08519898c55616c92a81d8/mistralai-2.0.0.tar.gz", hash = "sha256:acb7937a53119ece67f4978809d4cf630fbf54b4dfe85c0eeae778ac40850fab", size = 317705, upload-time = "2026-03-10T17:12:48.616Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/64/9b/97d1f2f8fb4648008882284b2235d0b7b64b094ad4a4ee02c9c67c361578/mistralai-1.1.0-py3-none-any.whl", hash = "sha256:eea0938975195f331d0ded12d14e3c982f09f1b68210200ed4ff0c6b9b22d0fb", size = 229749, upload-time = "2024-09-17T16:25:51.963Z" }, + { url = "https://files.pythonhosted.org/packages/b8/95/1587d555837bf635db28e2acee366cc47edc473cd3155515be14acced91b/mistralai-2.0.0-py3-none-any.whl", hash = "sha256:e551fc36d60d4c969140e37f10eab04986480e487f357c900da05d740b9a0baf", size = 709642, upload-time = "2026-03-10T17:12:50.104Z" }, ] [[package]] @@ -2840,6 +2905,42 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/87/b70ad306ebb6f9b585f114d0ac2137d792b48be34d732d60e597c2f8465a/pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d", size = 463580, upload-time = "2025-11-26T15:11:44.605Z" }, ] +[[package]] +name = "pydantic-ai-slim" +version = "1.66.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "genai-prices", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "griffelib", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "opentelemetry-api", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pydantic-graph", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-inspection", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/23/31/1b291e2c169c684290b458a1333d438e34c542d355c60c0bc92866c192a2/pydantic_ai_slim-1.66.0.tar.gz", hash = "sha256:d675f3cf7171c7ea767084a2228d7a2e8eb88e18bfefba71387ed150fcb64069", size = 435408, upload-time = "2026-03-05T00:54:58.587Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/c9/098d675eb20863c6c92a23e09b6cc0d10df3f96191f04f3daefb31f180bc/pydantic_ai_slim-1.66.0-py3-none-any.whl", hash = "sha256:59dcccbcbf948d356dd4a03457962b4079db42c56edf8a11113d827015027e66", size = 566105, upload-time = "2026-03-05T00:54:51.611Z" }, +] + +[package.optional-dependencies] +anthropic = [ + { name = "anthropic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +bedrock = [ + { name = "boto3", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +google = [ + { name = "google-genai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +groq = [ + { name = "groq", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +openai = [ + { name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "tiktoken", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] + [[package]] name = "pydantic-core" version = "2.41.5" @@ -2873,6 +2974,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/02/3c562f3a51afd4d88fff8dffb1771b30cfdfd79befd9883ee094f5b6c0d8/pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:2a5e06546e19f24c6a96a129142a75cee553cc018ffee48a460059b1185f4470", size = 2331955, upload-time = "2025-11-04T13:41:54.079Z" }, ] +[[package]] +name = "pydantic-graph" +version = "1.66.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "logfire-api", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, + { name = "typing-inspection", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/25/5e/4a3ed6c4047fd2676b248cee3666299b6214f691c086fd5f9bdda96ace1d/pydantic_graph-1.66.0.tar.gz", hash = "sha256:834df5137098c2c95d2241b98d4dd61af4a3ff24784751c82cc543db46dd29f5", size = 58522, upload-time = "2026-03-05T00:55:01.019Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/95/22c0ad3f3830d7fdd4dbfdc78548705f6c9ac434ada0d790ffc02491b39e/pydantic_graph-1.66.0-py3-none-any.whl", hash = "sha256:8f75d34efbaa4b65767d39faa2b3270fd321fb4104a66d3773754f4854876739", size = 72351, upload-time = "2026-03-05T00:54:54.661Z" }, +] + [[package]] name = "pydantic-settings" version = "2.12.0" @@ -3761,19 +3877,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, ] -[[package]] -name = "typing-inspect" -version = "0.9.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "mypy-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/dc/74/1789779d91f1961fa9438e9a8710cdae6bd138c80d7303996933d117264a/typing_inspect-0.9.0.tar.gz", hash = "sha256:b23fc42ff6f6ef6954e4852c1fb512cdd18dbea03134f91f856a95ccc9461f78", size = 13825, upload-time = "2023-05-24T20:25:47.612Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/65/f3/107a22063bf27bdccf2024833d3445f4eea42b2e598abfbd46f6a63b6cb0/typing_inspect-0.9.0-py3-none-any.whl", hash = "sha256:9ee6fc59062311ef8547596ab6b955e1b8aa46242d854bfc78f4f6b0eff35f9f", size = 8827, upload-time = "2023-05-24T20:25:45.287Z" }, -] - [[package]] name = "typing-inspection" version = "0.4.2" @@ -3813,22 +3916,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b1/5e/512aeb40fd819f4660d00f96f5c7371ee36fc8c6b605128c5ee59e0b28c6/u_msgpack_python-2.8.0-py2.py3-none-any.whl", hash = "sha256:1d853d33e78b72c4228a2025b4db28cda81214076e5b0422ed0ae1b1b2bb586a", size = 10590, upload-time = "2023-05-18T09:28:10.323Z" }, ] -[[package]] -name = "udspy" -version = "0.1.8" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "jiter", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "openai", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "regex", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, - { name = "tenacity", marker = "sys_platform == 'darwin' or sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d4/d8/0ab2a0258f4932f40004c759f79336100a590c1aa8296c75d797d47836e5/udspy-0.1.8.tar.gz", hash = "sha256:8da68fcbd118850eeff3750942c053006bafea335bf74f411055ccf27d800b3b", size = 270081, upload-time = "2025-11-28T15:20:55.19Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a7/ec/10076e9cb53685ffb01d2229df6372d4dbca3d4a3a0a93a03ad5126c40b2/udspy-0.1.8-py3-none-any.whl", hash = "sha256:3a66427b60f4cd6360ff95db76fd34a8fa9201fde9244b5a3db6b3fb2e424042", size = 60418, upload-time = "2025-11-28T15:20:53.849Z" }, -] - [[package]] name = "ujson" version = "5.12.0" diff --git a/changelog/entries/unreleased/refactor/Replace udspy with pydantic-ai_replace_udspy_with_pydanticai.json b/changelog/entries/unreleased/refactor/Replace udspy with pydantic-ai_replace_udspy_with_pydanticai.json new file mode 100644 index 0000000000..18f9dd2271 --- /dev/null +++ b/changelog/entries/unreleased/refactor/Replace udspy with pydantic-ai_replace_udspy_with_pydanticai.json @@ -0,0 +1,9 @@ +{ + "type": "refactor", + "message": "Replace udspy with pydantic-ai", + "issue_origin": "github", + "issue_number": null, + "domain": "core", + "bullet_points": [], + "created_at": "2026-03-04" +} \ No newline at end of file diff --git a/docs/installation/ai-assistant.md b/docs/installation/ai-assistant.md index 93e0369b02..ff0f900654 100644 --- a/docs/installation/ai-assistant.md +++ b/docs/installation/ai-assistant.md @@ -6,8 +6,8 @@ server. ## 1) Core concepts -- The assistant runs via **UDSPy** — see https://github.com/baserow/udspy/ -- UDSPy speaks to **any OpenAI-compatible API**. +- The assistant is built on [**pydantic-ai**](https://ai.pydantic.dev/) — a + Python agent framework that supports multiple LLM providers out of the box. - You **must** set `BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL` with the provider and model of your choosing. - The assistant has been mostly tested with the `gpt-oss-120b` family. Other models can @@ -21,61 +21,82 @@ Set the model you want, restart Baserow, and let migrations run. ```dotenv # Required -BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=openai/gpt-4o +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=openai:gpt-5.2 OPENAI_API_KEY=your_api_key -# Optional - adjust LLM temperature (default: 0) -BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE=0 +# Optional - adjust LLM temperature (default: 0.3) +BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE=0.3 ``` **About temperature:** -- Controls randomness in LLM responses (0.0 to 2.0) -- **Default: 0** (deterministic, consistent responses - recommended for production) -- Higher values (e.g., 0.7-1.0) = more creative/varied responses -- Lower values (e.g., 0-0.3) = more focused/consistent responses +- Controls randomness in the main assistant's LLM responses. +- **Default: 0.3** (focused, consistent responses) +- Higher values (depending on the model) = more creative/varied responses. +- Lower values (e.g., 0-0.1) = more analytical responses. Note that even with temperature of 0.0, the results will not be fully deterministic. ## 3) Provider presets -Choose **one** provider block and set its variables. +Choose **one** provider block and set its variables. pydantic-ai uses the standard +environment variables for each provider (e.g. `OPENAI_API_KEY`, `GROQ_API_KEY`). ### OpenAI / OpenAI-compatible ```dotenv -BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=openai/gpt-4o +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=openai:gpt-5.2 OPENAI_API_KEY=your_api_key -# Optional alternative endpoints (OpenAI EU or Azure OpenAI, etc.) -UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL=https://eu.api.openai.com/v1 +# Optional: point to an alternative OpenAI-compatible endpoint +OPENAI_BASE_URL=https://eu.api.openai.com/v1 # or -UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL=https://.openai.azure.com -# or any OpenAI compatible endpoint +OPENAI_BASE_URL=https://.openai.azure.com +``` + +### Anthropic + +```dotenv +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=anthropic:claude-sonnet-4-20250514 +ANTHROPIC_API_KEY=your_api_key ``` ### AWS Bedrock +pydantic-ai supports two authentication methods for Bedrock. Use whichever matches your setup. + +**Option A — Standard AWS credentials (boto3)** + +```dotenv +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=bedrock:openai.gpt-oss-120b-1:0 +AWS_ACCESS_KEY_ID=your_access_key +AWS_SECRET_ACCESS_KEY=your_secret_key +AWS_DEFAULT_REGION=eu-central-1 +``` + +Any boto3-compatible credential method works: env vars, IAM roles, instance profiles, `~/.aws/credentials`, etc. + +**Option B — Bedrock bearer token** + ```dotenv -BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=bedrock/openai.gpt-oss-120b-1:0 -AWS_BEARER_TOKEN_BEDROCK=your_bedrock_token -AWS_REGION_NAME=eu-central-1 +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=bedrock:openai.gpt-oss-120b-1:0 +AWS_BEARER_TOKEN_BEDROCK=your_bearer_token +AWS_DEFAULT_REGION=eu-central-1 ``` ### Groq ```dotenv -BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=groq/openai/gpt-oss-120b +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=groq:openai/gpt-oss-120b GROQ_API_KEY=your_api_key ``` ### Ollama ```dotenv -BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=ollama/gpt-oss:120b -OLLAMA_API_KEY=your_api_key -# Optionally and alternative endpoint -UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL=http://localhost:11434/v1 +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=ollama:gpt-oss:120b +# Point to your Ollama instance (defaults to http://localhost:11434/v1) +OLLAMA_BASE_URL=http://localhost:11434/v1 ``` -Under the hood, UDSPy auto-detects provider from the model prefix and builds an -OpenAI-compatible client accordingly. +pydantic-ai auto-detects the provider from the model prefix and routes requests +accordingly. ## 4) Knowledge-base lookup @@ -123,3 +144,42 @@ just dcd run --rm web-frontend bash -c env | grep LLM_MODEL ``` Both commands must return the same value for `BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL`. If either is missing or they differ, update your environment configuration and restart the services. + +## 6) Supported models + +OpenAI, Anthropic, AWS Bedrock, Groq, Gemini/Vertex AI and any OpenAI-compatible +endpoint (Azure, DeepSeek, Fireworks, LiteLLM, Perplexity, Together AI, etc.). + +## 7) Framework change: UDSPy to pydantic-ai + +The assistant previously used [UDSPy](https://github.com/baserow/udspy/) as its agent +framework. It now uses [pydantic-ai](https://ai.pydantic.dev/). Most environment +variables are unchanged or bridged for backward compatibility. + +### What stays the same + +| Variable | Notes | +|----------|-------| +| `BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL` | Works exactly as before. Both `provider/model` and `provider:model` formats are accepted. | +| `BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE` | Still supported. Overrides the orchestrator temperature when set. | +| `OPENAI_API_KEY` | Unchanged. | +| `GROQ_API_KEY` | Unchanged. | +| `AWS_BEARER_TOKEN_BEDROCK` | Still works — pydantic-ai supports Bedrock bearer token auth natively. | + +### Bridged for backward compatibility (no action needed) + +| Old variable | Equivalent | Notes | +|--------------|------------|-------| +| `UDSPY_LM_MODEL` | `BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL` | If set and the new var is absent, the old value is used automatically. | +| `UDSPY_LM_API_KEY` | `OPENAI_API_KEY` / `GROQ_API_KEY` / etc. | Propagated to all provider key variables as a fallback. | +| `UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL` | `OPENAI_BASE_URL` | Still works; bridged automatically. | +| `AWS_REGION_NAME` | `AWS_DEFAULT_REGION` | Still works; bridged automatically. | + +### New variables + +| Variable | Notes | +|----------|-------| +| `OPENAI_BASE_URL` | Preferred replacement for `UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL`. | +| `AWS_DEFAULT_REGION` | Preferred replacement for `AWS_REGION_NAME`. | +| `OLLAMA_BASE_URL` | Replaces `UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL` for Ollama. Defaults to `http://localhost:11434/v1`. | +| `ANTHROPIC_API_KEY` | New provider — Anthropic models are now supported. | diff --git a/docs/installation/configuration.md b/docs/installation/configuration.md index c0d66d6153..fc65e449c7 100644 --- a/docs/installation/configuration.md +++ b/docs/installation/configuration.md @@ -188,12 +188,12 @@ The installation methods referred to in the variable descriptions are: | Name | Description | Defaults | |---------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------| | BASEROW\_EMBEDDINGS\_API\_URL | If not empty, the AI-assistant will use this as embedding server for the knowledge base lookup. Must point to a container running this image: https://hub.docker.com/r/baserow/embeddings | "" (empty string) | -| BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL | If not empty, then this model will be used for the AI-assistant. Provide like `groq/openai/gpt-oss-120b` or `bedrock/openai.gpt-oss-120b-1:0`. Note that additional API keys must be provided as environment variable depending on the provider. Instructions can be found at https://baserow.io/docs/installation/ai-assistant | "" (empty string) | -| AWS\_BEARER\_TOKEN\_BEDROCK | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL=bedrock/bedrock/openai.gpt-oss-120b-1:0, then this environment variable must be set. Instructions on how to obtain: https://docs.aws.amazon.com/bedrock/latest/userguide/api-keys-use.html | "" (empty string) | -| AWS\_REGION\_NAME | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL=groq/openai/gpt-oss-120b, then the AWS region for the AI-assistant can be provided here. | us-east-1 | -| GROQ\_API\_KEY | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL=bedrock/bedrock/openai.gpt-oss-120b-1:0, then the Groq API key can be provided here. | "" (empty string) | -| OLLAMA\_API\_KEY | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL=ollama/gpt-oss:120b, then the Ollama API key can be provided here. | "" (empty string) | -| UDSPY\_LM\_OPENAI\_COMPATIBLE\_BASE\_URL | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL=openai/gpt-5-nano, then the base URL can be changed here. This can be used to point to a different OpenAI compatible API like Azure, for example. | "" (empty string) | +| BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL | If not empty, then this model will be used for the AI-assistant. Provide in pydantic-ai format like `groq:openai/gpt-oss-120b` or `bedrock:openai.gpt-oss-120b-1:0`. Note that additional API keys must be provided as environment variable depending on the provider. Instructions can be found at https://baserow.io/docs/installation/ai-assistant | "" (empty string) | +| AWS\_BEARER\_TOKEN\_BEDROCK | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL uses a bedrock provider, then this environment variable must be set. Instructions on how to obtain: https://docs.aws.amazon.com/bedrock/latest/userguide/api-keys-use.html | "" (empty string) | +| AWS\_REGION\_NAME | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL uses a bedrock provider, then the AWS region for the AI-assistant can be provided here. | us-east-1 | +| GROQ\_API\_KEY | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL uses a groq provider (e.g. `groq:openai/gpt-oss-120b`), then the Groq API key must be provided here. | "" (empty string) | +| OLLAMA\_API\_KEY | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL uses an ollama provider (e.g. `ollama:gpt-oss:120b`), then the Ollama API key can be provided here. | "" (empty string) | +| UDSPY\_LM\_OPENAI\_COMPATIBLE\_BASE\_URL | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL uses an openai provider (e.g. `openai:gpt-5-nano`), then the base URL can be changed here. This can be used to point to a different OpenAI compatible API like Azure, for example. | "" (empty string) | ### Data sync configuration diff --git a/docs/installation/install-with-helm.md b/docs/installation/install-with-helm.md index 103c678978..f71e9cad4b 100644 --- a/docs/installation/install-with-helm.md +++ b/docs/installation/install-with-helm.md @@ -183,7 +183,7 @@ Add to your `config.yaml`: ```yaml baserow-embeddings: enabled: true - assistantLLMModel: "groq/openai/gpt-oss-120b" + assistantLLMModel: "groq:openai/gpt-oss-120b" backendSecrets: GROQ_API_KEY: "your-groq-api-key" diff --git a/docs/testing/ai-assistant-evals.md b/docs/testing/ai-assistant-evals.md new file mode 100644 index 0000000000..77950cc27c --- /dev/null +++ b/docs/testing/ai-assistant-evals.md @@ -0,0 +1,255 @@ +# AI Assistant Evals + +The assistant eval suite runs the real agent against a live LLM to verify +end-to-end behaviour: tool selection, schema compatibility, row creation, etc. + +All eval tests live under +`enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/` and are +marked with `@pytest.mark.eval` so they are **skipped by default** in CI and +local test runs. + +## Prerequisites + +1. A running PostgreSQL database (see [running-tests.md](../development/running-tests.md)). +2. An API key for the LLM provider you want to test against. +3. **For `test_eval_search_user_docs` only:** an embeddings server and a + synced knowledge base (see [Search docs evals](#search-docs-evals) below). + +## Quick start + +```bash +# Set your API key (Groq example — works with any pydantic-ai provider) +export GROQ_API_KEY=gsk_... + +# Run all evals with the default model (groq:openai/gpt-oss-120b) +just b test ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/ \ + -m eval -v + +# Run a single eval file +just b test ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_core_builders.py \ + -m eval -v +``` + +> **Tip:** Do **not** pass `-s`. Without it, pytest captures `print_message_history` output and shows it only in the failure report — passing tests stay silent. Use `-s` only when you want to watch the agent's tool calls in real time for a single test. + + +## Configuration + +All configuration is via environment variables: + +| Variable | Default | Description | +|----------|---------|-------------| +| `EVAL_LLM_MODEL` | `groq:openai/gpt-oss-120b` | Model string in pydantic-ai format (`provider:model`). Accepts a comma-separated list to parametrize every eval across multiple models. | +| `EVAL_RETRIES` | `0` | Retry each failing eval test up to N times. If a test passes on retry it's a flake (LLM non-determinism); if it fails all N retries it's a consistent bug. | +| `GROQ_API_KEY` | — | Required when using a Groq model. | +| `OPENAI_API_KEY` | — | Required when using an OpenAI model. | +| `ANTHROPIC_API_KEY` | — | Required when using an Anthropic model. | + +### API keys from a file + +The eval conftest reads API keys from the same `TEST_ENV_FILE` that +`baserow/config/settings/test.py` already parses, and exposes them via +`os.environ` so that LLM provider SDKs can find them: + +```bash +TEST_ENV_FILE=.env.testing-local just b test \ + ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/ -m eval -v -s +``` + +Variables already present in `os.environ` take precedence. + +### Running against multiple models + +```bash +GROQ_API_KEY=... OPENAI_API_KEY=... EVAL_LLM_MODEL="groq:openai/gpt-oss-120b,openai:gpt-4o" \ +just b test ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/ \ + -m eval -v -s +``` + +Each test will run once per model, with the model name shown in the test ID. + +## Test files + +File names follow the pattern `test_eval_{module}_{feature}.py`, where module +maps to the tool directory (`core`, `database`, `automation`, `navigation`, +`search_user_docs`). Browse +`enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/` for the +full list. Each file defines its prompts as module-level `PROMPT_*` constants +at the top, making it easy to scan which scenarios are covered without reading +the test bodies. + +## Writing a new eval + +1. Create a new `test_eval_.py` file in the `evals/` directory. +2. Define prompts as `PROMPT_*` constants at the top, so it's easier to have an overview of the existing evals. +3. Mark each test with `@pytest.mark.eval` and + `@pytest.mark.django_db(transaction=True)`. +4. Use the helpers from `eval_utils.py`: + +```python +import pytest +from .eval_utils import ( + EvalChecklist, + build_database_ui_context, + count_tool_errors, + create_eval_assistant, + print_message_history, +) + +PROMPT_DOES_SOMETHING = "Do something useful in database {database_name}" + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_does_something(data_fixture, eval_model): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace, name="Test") + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + deps.tool_helpers.request_context["ui_context"] = ui_context + + result = agent.run_sync( + user_prompt=PROMPT_DOES_SOMETHING.format(database_name=database.name), + deps=deps, + model=model, + usage_limits=usage_limits, + toolsets=[toolset], + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + with EvalChecklist("does something") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + # Add domain-specific checks here + checks.check("created the thing", some_condition, hint="details if failed") +``` + +### Key helpers + +| Helper | Purpose | +|--------|---------| +| `create_eval_assistant(user, workspace, max_iters, model)` | Returns `(agent, deps, tracker, model, usage_limits, toolset)` configured like production. | +| `build_database_ui_context(user, workspace, database, table)` | Builds the UI context JSON the agent receives. | +| `count_tool_errors(result)` | Returns `(error_count, hint)` — count of tool validation errors (pydantic retries) and a formatted hint string. Use with `EvalChecklist`: `checks.check("no tool errors", err_count == 0, hint=err_hint)`. | +| `EvalChecklist(name)` | Context manager for soft assertions: collects checks, prints a score table (`4/6 (66%)`), and only hard-fails at the end. Use for tests with multiple independent checks. | +| `print_message_history(result)` | Prints the full agent conversation to stdout. | +| `format_message_history(result)` | Returns the conversation as a list of dicts for programmatic assertions. | + +## Search docs evals + +`test_eval_search_user_docs.py` tests the `search_user_docs` tool end-to-end: +the agent receives a real user question, decides to call the tool, the tool +performs a vector search against the knowledge base, and a sub-agent produces +an answer with source URLs. The test verifies that: + +1. The agent called `search_user_docs`. +2. The answer mentions expected concepts (e.g. "date_diff" for a date + formula question). +3. Returned source URLs match expected documentation pages (non-fatal + warning if not — URLs can change). + +### Additional prerequisites + +These tests are **automatically skipped** when the knowledge base is not +available. To enable them: + +1. **Embeddings server** — start the embeddings service and set: + ```bash + # Running tests outside Docker (local dev): + export BASEROW_EMBEDDINGS_API_URL=http://localhost:7999 + # Running tests inside Docker: + export BASEROW_EMBEDDINGS_API_URL=http://embeddings + ``` + +2. **pgvector extension** — the PostgreSQL instance must have the `vector` + extension installed. If you use the dev Docker setup this is already + included. + +3. **Sync the knowledge base** — the test suite handles this automatically + (see [Knowledge base caching](#knowledge-base-caching) below), but you + can also trigger a manual sync: + ```bash + # From the backend directory, with the Django env active: + python -m baserow sync_knowledge_base + ``` + This reads `website_export.csv` (user docs) and `docs/` (dev docs), + creates `KnowledgeBaseDocument` / `KnowledgeBaseChunk` rows, and + generates embeddings via the embeddings server. + +### Knowledge base caching + +Syncing the knowledge base is slow (it generates embeddings for every +documentation chunk). To avoid repeating this on every test run, the eval +suite uses two mechanisms together: + +1. **Session-scoped fixture** — the `synced_knowledge_base` fixture in + `conftest.py` runs once per pytest session. It checks whether the KB is + already populated (`handler.can_search()`) and only calls + `sync_knowledge_base()` when it isn't. + +2. **`--reuse-db`** — pytest-django's `--reuse-db` flag keeps the test + database between sessions instead of recreating it. Combined with the + fixture above, the expensive sync only happens on the very first run. + Subsequent runs detect that the data is already there and skip the sync + entirely. + +3. **No `transaction=True`** — search docs tests use + `@pytest.mark.django_db` (savepoint rollback) rather than + `@pytest.mark.django_db(transaction=True)` (full table truncation). This + is important: `transaction=True` would wipe the knowledge base tables + after each test, defeating the caching. + +**Typical workflow:** + +| Run | What happens | Time | +|-----|--------------|------| +| First ever | DB created, KB synced, tests run | Several minutes | +| Subsequent | DB reused, KB already populated, tests run | Seconds | + +To force a fresh sync (e.g. after schema changes or new documentation): + +```bash +# Drop and recreate the test DB, then re-sync +just b test ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_search_user_docs.py \ + -m eval -v -s --create-db +``` + +### Running search docs evals + +```bash +# Only search docs evals +just b test ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_search_user_docs.py \ + -m eval -v -s + +# A single test case by parametrize ID +just b test ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_search_user_docs.py \ + -m eval -v -s -k "vlookup-to-link-row" +``` + +If the embeddings server is not running or the knowledge base has not been +synced, all search docs tests will be skipped with a clear message. + +## Troubleshooting + +### `FAILED — No API key` + +Make sure the correct `*_API_KEY` env var is set for your provider/ + +### Flaky results + +LLM evals are inherently non-deterministic. If a test fails intermittently: + +- Use `EVAL_RETRIES` to automatically distinguish flakes from consistent bugs: + ```bash + EVAL_RETRIES=3 just b test \ + ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_tables.py \ + -m eval -v -s + ``` + A test that passes on retry is a flake; one that fails all 3 retries is a real problem. +- Check the printed message history (`-s` flag) to see what the agent did. +- If a prompt is ambiguous, tighten the wording in the `PROMPT_*` constant. +- Consider lowering the temperature in the model profile for the eval model. diff --git a/docs/testing/ai-assistant-test-plan.md b/docs/testing/ai-assistant-test-plan.md new file mode 100644 index 0000000000..ff9b12de83 --- /dev/null +++ b/docs/testing/ai-assistant-test-plan.md @@ -0,0 +1,134 @@ +# AI Assistant Test Plan + +## How to test + +### 1. Automated tests (unit) + +Run the unit test suite (no LLM needed): + +```bash +just b test -n auto ../enterprise/backend/tests/baserow_enterprise_tests/assistant/ \ + -v --ignore=enterprise/backend/tests/baserow_enterprise_tests/assistant/evals +``` + +All tests must pass. These cover: assistant orchestrator, all tool modules, +telemetry event emission, history compaction, and streaming. + +### 2. Automated tests (evals, optional) + +Run the eval suite against a live LLM. The default model is +`groq:openai/gpt-oss-120b`, so you need a `GROQ_API_KEY`. Evals that exercise +the `search_user_docs` tool also require a running embedding service — set +`BASEROW_EMBEDDINGS_API_URL` to point to it, or those evals will fail. + +```bash +GROQ_API_KEY=gsk_... BASEROW_EMBEDDINGS_API_URL=http://... \ +just b test ../enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/ \ + -m eval -v -s +``` + +> **Note:** Evals are non-deterministic and are not guaranteed to pass every +> run. When a failure occurs, check whether the model did something +> fundamentally wrong or whether the result is still acceptable. See +> [ai-assistant-evals.md](ai-assistant-evals.md) for details on configuration, +> multi-model runs, and how to interpret results. + +### 3. Manual: Tool smoke tests + +Open the assistant in the UI and verify each tool works end-to-end. Suggested +prompts: + +| Tool | Prompt | +|------|--------| +| `navigate` | "Go to the Customers table" | +| `list_builders` | "What builders do I have?" | +| `create_builders` | "Create a new application called Test App" | +| `list_tables` | "What tables are in my database?" | +| `get_tables_schema` | "Show me the schema of the Customers table" | +| `list_rows` | "Show me the first rows of the Customers table" | +| `list_views` | "What views does the Customers table have?" | +| `create_tables` | "Create a table called Projects with columns: Name (text), Status (single select: Active/Done), Due date (date)" | +| `create_fields` | "Add an email field to the Customers table" | +| `create_views` | "Create a kanban view grouped by Status on the Projects table" | +| `create_view_filters` | "Add a filter on the Projects grid view to only show Active rows" | +| `generate_formula` | "Add a formula field that concatenates first name and last name" | +| `update_fields` | "Rename the email field to Contact Email in the Customers table" | +| `delete_fields` | "Delete the Contact Email field from the Customers table" | +| `load_row_tools` | "Add a row to the Projects table: Name=Launch, Status=Active" (this implicitly triggers load_row_tools first) | +| `update_rows_in_table_X` | "Change the Status of the Launch row in Projects to Done" | +| `delete_rows_in_table_X` | "Delete the Launch row from the Projects table" | +| `list_workflows` | "What automations do I have?" | +| `create_workflows` | "Create an automation that sends a notification when a row is created in Projects" | +| `list_nodes` | "What nodes are in my first workflow?" | +| `add_nodes` | "Add a Slack notification action after the trigger in my workflow" | +| `update_nodes` | "Rename the trigger node to New Project Trigger" | +| `delete_nodes` | "Delete the Slack notification node from my workflow" | +| `search_user_docs` | "How do I create a lookup field?"* | + +* Make sure you synced the knowledge base first, look at [ai-assistant.md](../installation/ai-assistant.md) for more info. + +### 4. Manual: Feedback + +- Send a message, then click the thumbs-up/thumbs-down on the response +- Verify the feedback is recorded (no errors in the console/network tab) +- Refresh the page, the previously selected thumb up/down button must be highlighted + +### 5. Manual: Conversation memory (history) + +Test that the agent retains multi-turn context: + +1. Send: "My name is Mario" +2. Agent responds acknowledging +3. Send: "What's my name?" +4. Agent should respond "Mario" (proves history serialization/deserialization + via `message_history` field works) + +Also test a longer conversation (3+ turns) to verify the compaction doesn't +lose essential context. + +### 6. Manual: Telemetry / PostHog traces + +Requires PostHog configured (`POSTHOG_PROJECT_API_KEY`, `POSTHOG_HOST` etc.): + +1. Send a few messages exercising different tools +2. Go to PostHog > LLM Analytics > Traces +3. Verify: + - Each conversation turn appears as a `$ai_trace` + - Tool calls appear as `$ai_span` children + - LLM generations appear as `$ai_generation` with model name, token counts, + latency + - Input/output content is captured (not empty) + +### 7. Manual: Knowledge base (search_user_docs) + +Requires an embeddings server and synced KB (look at [ai-assistant.md](../installation/ai-assistant.md) for more info). Verify: + +- Ask a Baserow how-to question (e.g. "How do I set up SSO?") -> agent should + call `search_user_docs` and cite sources +- Ask a creative task (e.g. "Create a table for tracking expenses") -> agent + should NOT call search_user_docs, should just act +- Ask a question about the agent's own tools (e.g. "What tools do you have?") + -> agent should NOT search docs, should answer from its own knowledge + +### 8. Manual: Do vs. Describe + +Verify the agent acts rather than describes: + +- "Create a table called Invoices" -> should actually create it (call + `create_tables`), not describe how to do it +- "How would I create a table?" -> should describe the manual UI steps (no + tools available for this meta-question) or search docs +- After creating something, the agent should navigate to it to show the result + +### 9. Manual: Cancellation + +1. Send a long-running request (e.g. "Create a table with 10 fields") +2. Click cancel mid-execution +3. Verify the stream stops cleanly without error toasts + +### 10. Manual: Error handling + +- Misconfigure the LLM API key and try to chat -> should show a clear error, + not a stack trace +- Send a prompt referencing a non-existent table/database/any other resource -> agent should + handle gracefully diff --git a/enterprise/backend/pytest.ini b/enterprise/backend/pytest.ini index 9d3346f478..28c968f590 100644 --- a/enterprise/backend/pytest.ini +++ b/enterprise/backend/pytest.ini @@ -1,5 +1,7 @@ [pytest] DJANGO_SETTINGS_MODULE = baserow.config.settings.test python_files = test_*.py +markers = + eval: mark test as an eval test (requires LLM API key) env = DJANGO_SETTINGS_MODULE = baserow.config.settings.test diff --git a/enterprise/backend/src/baserow_enterprise/api/assistant/views.py b/enterprise/backend/src/baserow_enterprise/api/assistant/views.py index 1987e6f224..0c117df6cc 100644 --- a/enterprise/backend/src/baserow_enterprise/api/assistant/views.py +++ b/enterprise/backend/src/baserow_enterprise/api/assistant/views.py @@ -23,10 +23,7 @@ from baserow.api.sessions import set_client_undo_redo_action_group_id from baserow.core.exceptions import UserNotInWorkspace, WorkspaceDoesNotExist from baserow.core.handler import CoreHandler -from baserow_enterprise.assistant.assistant import ( - check_lm_ready_or_raise, - set_assistant_cancellation_key, -) +from baserow_enterprise.assistant.assistant import set_assistant_cancellation_key from baserow_enterprise.assistant.exceptions import ( AssistantChatDoesNotExist, AssistantChatMessagePredictionDoesNotExist, @@ -34,6 +31,7 @@ AssistantModelNotSupportedError, ) from baserow_enterprise.assistant.handler import AssistantHandler +from baserow_enterprise.assistant.model_profiles import check_lm_ready_or_raise from baserow_enterprise.assistant.models import AssistantChatPrediction from baserow_enterprise.assistant.operations import ChatAssistantChatOperationType from baserow_enterprise.assistant.types import ( diff --git a/enterprise/backend/src/baserow_enterprise/apps.py b/enterprise/backend/src/baserow_enterprise/apps.py index 7be1301b82..9a482941bb 100755 --- a/enterprise/backend/src/baserow_enterprise/apps.py +++ b/enterprise/backend/src/baserow_enterprise/apps.py @@ -1,4 +1,5 @@ from django.apps import AppConfig +from django.conf import settings from django.db.models.signals import post_migrate from tqdm import tqdm @@ -282,11 +283,12 @@ def ready(self): # Make sure that the assistant knowledge base is up to date after running the # migrations. - post_migrate.connect( - sync_assistant_knowledge_base, - sender=self, - dispatch_uid="sync_assistant_knowledge_base", - ) + if not settings.TESTS: + post_migrate.connect( + sync_assistant_knowledge_base, + sender=self, + dispatch_uid="sync_assistant_knowledge_base", + ) from baserow_enterprise.teams.receivers import ( connect_to_post_delete_signals_to_cascade_deletion_to_team_subjects, @@ -313,43 +315,6 @@ def ready(self): notification_type_registry.register(TwoWaySyncUpdateFailedNotificationType()) notification_type_registry.register(TwoWaySyncDeactivatedNotificationType()) - from baserow_enterprise.assistant.tools import ( - CreateBuildersToolType, - GenerateDatabaseFormulaToolType, - GetTablesSchemaToolType, - ListBuildersToolType, - ListRowsToolType, - ListTablesToolType, - ListViewsToolType, - ListWorkflowsToolType, - NavigationToolType, - RowsToolFactoryToolType, - SearchDocsToolType, - TableAndFieldsToolFactoryToolType, - ViewsToolFactoryToolType, - WorkflowToolFactoryToolType, - ) - from baserow_enterprise.assistant.tools.registries import ( - assistant_tool_registry, - ) - - assistant_tool_registry.register(SearchDocsToolType()) - assistant_tool_registry.register(NavigationToolType()) - - assistant_tool_registry.register(ListBuildersToolType()) - assistant_tool_registry.register(CreateBuildersToolType()) - assistant_tool_registry.register(ListTablesToolType()) - assistant_tool_registry.register(GetTablesSchemaToolType()) - assistant_tool_registry.register(TableAndFieldsToolFactoryToolType()) - assistant_tool_registry.register(GenerateDatabaseFormulaToolType()) - assistant_tool_registry.register(ListRowsToolType()) - assistant_tool_registry.register(RowsToolFactoryToolType()) - assistant_tool_registry.register(ListViewsToolType()) - assistant_tool_registry.register(ViewsToolFactoryToolType()) - - assistant_tool_registry.register(ListWorkflowsToolType()) - assistant_tool_registry.register(WorkflowToolFactoryToolType()) - from baserow_enterprise.views.operations import ( ListenToAllRestrictedViewEventsOperationType, ) @@ -376,6 +341,29 @@ def ready(self): page_registry.register(RestrictedViewPageType()) view_realtime_rows_registry.register(RestrictedViewRealtimeRowsType()) + from baserow_enterprise.assistant.tools.automation.tool_types import ( + AutomationToolType, + ) + from baserow_enterprise.assistant.tools.core.tool_types import CoreToolType + from baserow_enterprise.assistant.tools.database.tool_types import ( + DatabaseToolType, + ) + from baserow_enterprise.assistant.tools.navigation.tool_types import ( + NavigationToolType, + ) + from baserow_enterprise.assistant.tools.registries import ( + assistant_tool_registry, + ) + from baserow_enterprise.assistant.tools.search_user_docs.tool_types import ( + SearchDocsToolType, + ) + + assistant_tool_registry.register(NavigationToolType()) + assistant_tool_registry.register(CoreToolType()) + assistant_tool_registry.register(DatabaseToolType()) + assistant_tool_registry.register(AutomationToolType()) + assistant_tool_registry.register(SearchDocsToolType()) + # The signals must always be imported last because they use the registries # which need to be filled first. import baserow_enterprise.assistant.tasks # noqa: F401 diff --git a/enterprise/backend/src/baserow_enterprise/assistant/agents.py b/enterprise/backend/src/baserow_enterprise/assistant/agents.py new file mode 100644 index 0000000000..1af2110615 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/agents.py @@ -0,0 +1,81 @@ +from pydantic_ai import Agent, RunContext +from pydantic_ai.toolsets import FunctionToolset + +from baserow_enterprise.assistant.deps import AssistantDeps +from baserow_enterprise.assistant.prompts import AGENT_SYSTEM_PROMPT +from baserow_enterprise.assistant.tools.toolset import tool_manifest_line_compact + +main_agent: Agent[AssistantDeps, str] = Agent( + deps_type=AssistantDeps, + output_type=str, + instructions=AGENT_SYSTEM_PROMPT, + retries=3, + name="main_agent", +) + + +@main_agent.instructions +def dynamic_ui_context(ctx) -> str: + """Inject the UI context into the system prompt dynamically.""" + + ui_context = ctx.deps.tool_helpers.request_context.get("ui_context") + if ui_context: + return f"\n\n{ui_context}\n" + return "" + + +@main_agent.instructions +def dynamic_mode(ctx) -> str: + """Inject the current agent mode into the system prompt.""" + + return f"\n{ctx.deps.mode.value}" + + +@main_agent.instructions +def dynamic_current_task(ctx) -> str: + """Pin the original user request as immutable context.""" + + if ctx.deps.original_request: + return f"\n\n{ctx.deps.original_request}\n" + return "" + + +@main_agent.instructions +def dynamic_tool_manifest(ctx) -> str: + """ + Inject the available tools manifest into the system prompt, including both + static and dynamically loaded tools name and description. + """ + + manifest = ctx.deps.active_manifest + if not manifest: + return "" + + # Append dynamically loaded tools (e.g. row tools from load_row_tools) + if ctx.deps.dynamic_tools: + extra = "\n".join( + tool_manifest_line_compact(tool.name, tool.description or "") + for tool in ctx.deps.dynamic_tools + ) + manifest = manifest + "\n" + extra + + return f"\n\n{manifest}\n" + + +@main_agent.toolset +def dynamic_toolset(ctx: RunContext[AssistantDeps]): + """Make dynamically loaded tools available to the agent.""" + + if ctx.deps.dynamic_tools: + ts = FunctionToolset() + for tool in ctx.deps.dynamic_tools: + ts.add_tool(tool) + return ts + return None + + +title_agent: Agent[None, str] = Agent( + output_type=str, + instructions="Create a short title (max 50 chars) for the following user request.", + name="title_agent", +) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/assistant.py b/enterprise/backend/src/baserow_enterprise/assistant/assistant.py index a7679b4805..f5fbb5a423 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/assistant.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/assistant.py @@ -1,30 +1,56 @@ -from dataclasses import dataclass -from functools import lru_cache -from typing import Any, AsyncGenerator, Callable, Tuple, TypedDict +import asyncio +from typing import Any, AsyncGenerator -from django.conf import settings from django.core.cache import cache from django.utils import translation -import udspy -from udspy.callback import BaseCallback +from loguru import logger +from pydantic_ai._thinking_part import split_content_into_text_and_thinking +from pydantic_ai.messages import ( + FunctionToolCallEvent, + FunctionToolResultEvent, + ModelMessage, + ModelMessagesTypeAdapter, + PartDeltaEvent, + PartStartEvent, + TextPart, + TextPartDelta, + ThinkingPart, + ThinkingPartDelta, +) +from pydantic_ai.run import AgentRunResultEvent +from pydantic_ai.usage import UsageLimits from baserow.api.sessions import get_client_undo_redo_action_group_id -from baserow_enterprise.assistant.exceptions import ( - AssistantMessageCancelled, - AssistantModelNotSupportedError, +from baserow_enterprise.assistant.agents import main_agent, title_agent +from baserow_enterprise.assistant.deps import ( + AgentMode, + AssistantDeps, + EventBus, + QueueEvent, + QueueEventKind, + ToolHelpers, +) +from baserow_enterprise.assistant.exceptions import AssistantMessageCancelled +from baserow_enterprise.assistant.history import compact_message_history +from baserow_enterprise.assistant.model_profiles import ( + ORCHESTRATOR, + TITLE, + get_model_settings, + get_model_string, +) +from baserow_enterprise.assistant.retrying_model import RetryingModel +from baserow_enterprise.assistant.telemetry import ( + PosthogTracingCallback, + setup_instrumentation, ) -from baserow_enterprise.assistant.telemetry import PosthogTracingCallback -from baserow_enterprise.assistant.tools.navigation.types import AnyNavigationRequestType from baserow_enterprise.assistant.tools.navigation.utils import unsafe_navigate_to from baserow_enterprise.assistant.tools.registries import assistant_tool_registry from .models import AssistantChat, AssistantChatMessage, AssistantChatPrediction -from .signatures import ChatSignature from .types import ( AiMessage, AiMessageChunk, - AiNavigationMessage, AiReasoningChunk, AiStartedMessage, AiThinkingMessage, @@ -33,176 +59,119 @@ HumanMessage, ) +_CANCELLATION_KEY_TTL = 300 # seconds +_THINKING_TAGS = ("", "") -@dataclass -class ToolHelpers: - update_status: Callable[[str], None] - navigate_to: Callable[["AnyNavigationRequestType"], str] - - -class AssistantMessagePair(TypedDict): - question: str - answer: str +def _strip_think_tags(text: str) -> str: + """Remove ``...`` blocks from *text*, returning only the + non-thinking content. Uses pydantic-ai's own tag parser. -class AssistantCallbacks(BaseCallback): - def __init__(self, tool_helpers: ToolHelpers | None = None): - self.tool_helpers = tool_helpers - self.tool_calls = {} - self.sources = [] - - def extend_sources(self, sources: list[str]) -> None: - """ - Extends the current list of sources with new ones, avoiding duplicates. - - :param sources: The list of new source URLs to add. - :return: None - """ - - self.sources.extend([s for s in sources if s not in self.sources]) - - def on_tool_start( - self, - call_id: str, - instance: Any, - inputs: dict[str, Any], - ) -> None: - """ - Called when a tool starts. It records the tool call and invokes the - corresponding tool's on_tool_start method if it exists. - - :param call_id: The unique identifier of the tool call. - :param instance: The instance of the tool being called. - :param inputs: The inputs provided to the tool. - """ - - try: - assistant_tool_registry.get(instance.name).on_tool_start( - call_id, instance, inputs - ) - self.tool_calls[call_id] = (instance, inputs) - except assistant_tool_registry.does_not_exist_exception_class: - pass - - def on_tool_end( - self, - call_id: str, - outputs: dict[str, Any] | None, - exception: Exception | None = None, - ) -> None: - """ - Called when a tool ends. It invokes the corresponding tool's on_tool_end - method if it exists and updates the sources if the tool produced any. - - :param call_id: The unique identifier of the tool call. - :param outputs: The outputs returned by the tool, or None if there was an - exception. - :param exception: The exception raised by the tool, or None if it was - successful. - """ + Also strips any trailing unclosed ```` block that may appear + during streaming (the closing tag hasn't arrived yet). + """ - if call_id not in self.tool_calls: - return + if "" not in text: + return text - instance, inputs = self.tool_calls.pop(call_id) - assistant_tool_registry.get(instance.name).on_tool_end( - call_id, instance, inputs, outputs, exception - ) + # Strip any trailing unclosed block (common during streaming) + last_open = text.rfind("") + last_close = text.rfind("") + if last_open > last_close: + text = text[:last_open] - if exception is not None and self.tool_helpers is not None: - self.tool_helpers.update_status( - f"Calling the {instance.name} tool encountered an error." - ) + if "" not in text: + return text.strip() - # If the tool produced sources, add them to the overall list of sources. - if isinstance(outputs, dict) and "sources" in outputs: - self.extend_sources(outputs["sources"]) + parts = split_content_into_text_and_thinking(text, _THINKING_TAGS) + return "".join(p.content for p in parts if not isinstance(p, ThinkingPart)).strip() def get_assistant_cancellation_key(chat_uuid: str) -> str: - """ - Get the Redis cache key for cancellation tracking. - - :param chat_uuid: The UUID of the assistant chat. - :return: The cache key as a string. - """ + """Return the cache key used to signal cancellation for a chat session.""" return f"assistant:chat:{chat_uuid}:cancelled" -def set_assistant_cancellation_key(chat_uuid: str, timeout: int = 300) -> None: - """ - Set the cancellation flag in the cache for the given chat UUID. +def set_assistant_cancellation_key( + chat_uuid: str, timeout: int = _CANCELLATION_KEY_TTL +) -> None: + """Set the cancellation flag in the cache for a chat session.""" - :param chat_uuid: The UUID of the assistant chat. - :param timeout: The time in seconds after which the cancellation flag expires. - """ + cache.set(get_assistant_cancellation_key(chat_uuid), True, timeout=timeout) - cache_key = get_assistant_cancellation_key(chat_uuid) - cache.set(cache_key, True, timeout=timeout) +def _extract_tool_thought(event: FunctionToolCallEvent) -> str | None: + """Extract the chain-of-thought ``thought`` argument from a tool call + event, if present and non-empty.""" -def get_lm_client( - model: str | None = None, -) -> "Assistant": - """ - Returns a udspy.LM client configured with the specified model or the default model. - - :param model: The language model to use. If None, the default model from settings - will be used. - :return: A udspy.LM instance. - """ + try: + args = event.part.args_as_dict() + except Exception: + return None + thought = args.get("thought") + return thought if isinstance(thought, str) and thought.strip() else None - return udspy.LM(model=model or settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL) +class Assistant: + """Orchestrates a single assistant chat session. -@lru_cache(maxsize=1) -def check_lm_ready_or_raise() -> None: - """ - Checks if the configured LLM is ready by making a test call. Raises - AssistantModelNotSupportedError if the model is not supported or accessible. + Wires together the pydantic-ai agent, toolsets, telemetry, event + streaming, and message persistence for one ``AssistantChat``. """ - lm = get_lm_client() - try: - lm("Respond in JSON: {'response': 'ok'}") - except Exception as e: - raise AssistantModelNotSupportedError( - f"The model '{lm.model}' is not supported or accessible: {e}" - ) - - -class Assistant: def __init__(self, chat: AssistantChat): self._chat = chat self._user = chat.user self._workspace = chat.workspace + self._model_string = get_model_string() + self._model = RetryingModel(self._model_string) + self._event_bus = EventBus() + self._tool_helpers = self._build_tool_helpers() + self._telemetry = PosthogTracingCallback() + + self._deps = AssistantDeps( + user=self._user, + workspace=self._workspace, + tool_helpers=self._tool_helpers, + ) + self._toolset, db_m, app_m, auto_m, explain_m = ( + assistant_tool_registry.build_toolset( + user=self._user, + workspace=self._workspace, + model=self._model_string, + deps=self._deps, + ) + ) + self._deps.database_manifest = db_m + self._deps.application_manifest = app_m + self._deps.automation_manifest = auto_m + self._deps.explain_manifest = explain_m - self._lm_client = get_lm_client() - self._init_assistant() + setup_instrumentation() - def _init_assistant(self): - self.history = None - self.tool_helpers = self.get_tool_helpers() - tools = [ - t if isinstance(t, udspy.Tool) else udspy.Tool(t) - for t in assistant_tool_registry.list_all_usable_tools( - self._user, self._workspace, self.tool_helpers - ) - ] - - self._assistant_callbacks = AssistantCallbacks(self.tool_helpers) - self._telemetry_callbacks = PosthogTracingCallback() - self._callbacks = [self._assistant_callbacks, self._telemetry_callbacks] - - module_kwargs = { - "temperature": settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE, - "response_format": {"type": "json_object"}, - } - self._assistant = udspy.ReAct( - ChatSignature, tools=tools, max_iters=20, **module_kwargs + # ------------------------------------------------------------------ + # Setup + # ------------------------------------------------------------------ + + def _build_tool_helpers(self) -> ToolHelpers: + """Create the ``ToolHelpers`` that tools use for status updates, + navigation, and cancellation during the agent run.""" + + def update_status(status: str): + with translation.override(self._user.profile.language): + self._event_bus.emit(AiThinkingMessage(content=status)) + + return ToolHelpers( + update_status=update_status, + navigate_to=lambda loc: unsafe_navigate_to(loc, self._event_bus), + event_bus=self._event_bus, ) + # ------------------------------------------------------------------ + # Message persistence + # ------------------------------------------------------------------ + async def acreate_chat_message( self, role: AssistantChatMessage.Role, @@ -210,37 +179,24 @@ async def acreate_chat_message( artifacts: dict[str, Any] | None = None, **kwargs, ) -> AssistantChatMessage: - """ - Creates and saves a new chat message. - - :param role: The role of the message (human or AI). - :param content: The content of the message. - :param artifacts: Optional artifacts associated with the message. - :return: The created AssistantChatMessage instance. - """ + """Persist a new chat message to the database.""" message = AssistantChatMessage( - chat=self._chat, - role=role, - content=content, - **kwargs, + chat=self._chat, role=role, content=content, **kwargs ) if artifacts: message.artifacts = artifacts - await message.asave() return message def list_chat_messages( self, last_message_id: int | None = None, limit: int = 100 - ) -> list[AssistantChatMessage]: - """ - Lists all chat messages in chronological order. + ) -> list[AssistantMessageUnion]: + """Return recent chat messages, oldest-first. - :param last_message_id: The ID of the last message received. If provided, only - messages before this ID will be returned. - :param limit: The maximum number of messages to return. - :return: A list of AssistantChatMessage instances. + :param last_message_id: If set, only return messages with ``id`` + below this value (cursor-based pagination). + :param limit: Maximum number of messages to return. """ queryset = ( @@ -251,7 +207,7 @@ def list_chat_messages( if last_message_id is not None: queryset = queryset.filter(id__lt=last_message_id) - messages = [] + messages: list[AssistantMessageUnion] = [] for msg in queryset[:limit]: if msg.role == AssistantChatMessage.Role.HUMAN: messages.append( @@ -276,267 +232,393 @@ def list_chat_messages( ) return list(reversed(messages)) - async def afetch_chat_history(self, limit: int = 50) -> udspy.History: - """ - Loads the chat history into a udspy.History object. It only loads complete - message pairs (human + AI). The history will be in chronological order and must - respect the module signature (question, answer). + async def _save_ai_response( + self, human_msg: AssistantChatMessage, answer: str + ) -> AiMessage: + """Persist the AI answer and create a prediction record for + feedback tracking.""" - :param limit: The maximum number of message pairs to load. - :return: A udspy.History instance containing the chat history. - """ + sources = self._deps.sources + ai_msg = await self.acreate_chat_message( + AssistantChatMessage.Role.AI, + answer, + artifacts={"sources": sources}, + action_group_id=get_client_undo_redo_action_group_id(self._user), + ) + await AssistantChatPrediction.objects.acreate( + human_message=human_msg, + ai_response=ai_msg, + prediction={"answer": answer}, + ) + return AiMessage( + id=ai_msg.id, + content=answer, + sources=sources, + can_submit_feedback=True, + ) - history = udspy.History() - last_saved_messages: list[AssistantChatMessage] = [ - msg async for msg in self._chat.messages.order_by("-created_on")[:limit] - ] - - while len(last_saved_messages) >= 2: - # Pop the oldest message pair to respect chronological order. - first_message = last_saved_messages.pop() - next_message = last_saved_messages[-1] - if ( - first_message.role != AssistantChatMessage.Role.HUMAN - or next_message.role != AssistantChatMessage.Role.AI - ): - continue + # ------------------------------------------------------------------ + # Message history (pydantic-ai ModelMessage round-trips) + # ------------------------------------------------------------------ - history.add_user_message(first_message.content) - assistant_answer = last_saved_messages.pop() - history.add_assistant_message(assistant_answer.content) + async def _save_message_history(self, messages_json: bytes) -> None: + """Persist the serialised pydantic-ai message history on the chat.""" - return history + self._chat.message_history = messages_json + await self._chat.asave(update_fields=["message_history", "updated_on"]) - def get_tool_helpers(self) -> ToolHelpers: - def update_status_localized(status: str): - """ - Sends a localized message to the frontend to update the assistant status. + async def _load_message_history(self) -> list[ModelMessage] | None: + """Deserialise and compact the stored message history, returning + ``None`` if absent or corrupt.""" - :param status: The status message to send. - """ + raw = self._chat.message_history + if not raw: + return None + try: + messages = ModelMessagesTypeAdapter.validate_json(bytes(raw)) + return compact_message_history(messages) + except Exception: + logger.opt(exception=True).warning( + "Failed to load message history for chat {}, starting fresh", + self._chat.pk, + ) + return None - with translation.override(self._user.profile.language): - udspy.emit_event(AiThinkingMessage(content=status)) + # ------------------------------------------------------------------ + # Agent execution + # ------------------------------------------------------------------ - return ToolHelpers( - update_status=update_status_localized, - navigate_to=unsafe_navigate_to, + async def _generate_chat_title(self, user_message: str) -> str: + """Ask the title agent to summarise a user message into a short + chat title.""" + + result = await title_agent.run( + user_message, + model=self._model, + model_settings=get_model_settings(self._model_string, TITLE), ) + return result.output - async def _generate_chat_title(self, user_message: str) -> str: - """ - Generates a title for the chat based on the user message and AI response. + _MAX_TOOL_CALL_AS_TEXT_RETRIES = 2 - :param user_message: The latest user message in the chat. - :return: The generated chat title. - """ + _TOOL_CALL_CORRECTION_PROMPT = ( + "Your previous response contained a raw JSON tool call instead of " + "actually invoking the tool. The malformed output was:\n\n" + "{malformed_output}\n\n" + "Please call the tool directly using the proper tool-calling " + "mechanism instead of outputting JSON text. Make sure the " + "arguments conform to the tool's schema." + ) - title_generator = udspy.Predict( - udspy.Signature.from_string( - "user_message -> chat_title", - "Create a short title for the following user request.", + async def _emit_answer( + self, + answer: str, + run_result: Any, + queue: asyncio.Queue[QueueEvent], + ) -> None: + """Push the final answer and result events onto *queue*.""" + + await queue.put( + QueueEvent( + kind=QueueEventKind.STREAM, + message=AiMessageChunk(content=answer, sources=self._deps.sources), ) ) - rsp = await title_generator.aforward( - user_message=user_message, + queue.put_nowait( + QueueEvent( + kind=QueueEventKind.RESULT, + answer=answer, + messages_json=run_result.all_messages_json(), + ) ) - return rsp.chat_title - async def _acreate_ai_message_response( + async def _run_agent( self, - human_msg: HumanMessage, - prediction: udspy.Prediction, - ) -> AiMessage: - """ - Creates and saves an AI chat message response based on the prediction. Stores - the prediction in AssistantChatPrediction, linking it to the human message, so - it can be referenced later to provide feedback. + user_prompt: str, + message_history: list[ModelMessage] | None, + queue: asyncio.Queue[QueueEvent], + ) -> None: + """Execute the main agent, retrying if it outputs tool calls as text. - :param human_msg: The human message instance. - :param prediction: The udspy.Prediction instance containing the AI response. - :return: The created AiMessage instance to return to the user. - """ + Delegates each streaming pass to ``_stream_agent_run``. If the + final output looks like a raw JSON tool call, re-runs the agent + with the conversation history and a corrective prompt (up to + ``_MAX_TOOL_CALL_AS_TEXT_RETRIES`` times) so the model can + self-correct and invoke the tool properly. - sources = self._assistant_callbacks.sources - ai_msg = await self.acreate_chat_message( - AssistantChatMessage.Role.AI, - prediction.answer, - artifacts={"sources": sources}, - action_group_id=get_client_undo_redo_action_group_id(self._user), - ) + Pushes ``STREAM``, ``RESULT``, ``ERROR``, and ``DONE`` events + onto *queue* for the consumer in ``astream_messages``. + """ - await AssistantChatPrediction.objects.acreate( - human_message=human_msg, - ai_response=ai_msg, - prediction={k: v for k, v in prediction.items() if k != "module"}, - ) + try: + with self._telemetry.trace(self._chat, user_prompt) as tracer: + answer, run_result = await self._run_agent_with_retries( + user_prompt, message_history, queue + ) + tracer.set_trace_output(answer) + await self._emit_answer(answer, run_result, queue) + except Exception as exc: + logger.exception("Error running main agent") + queue.put_nowait(QueueEvent(kind=QueueEventKind.ERROR, error=exc)) + finally: + queue.put_nowait(QueueEvent(kind=QueueEventKind.DONE)) + + async def _run_agent_with_retries( + self, + user_prompt: str, + message_history: list[ModelMessage] | None, + queue: asyncio.Queue[QueueEvent], + ) -> tuple[str, Any]: + """Stream the agent, retrying on tool-call-as-text outputs. - # Yield final complete message - return AiMessage( - id=ai_msg.id, - content=prediction.answer, - sources=sources, - can_submit_feedback=True, - ) + Returns ``(answer, run_result)`` — either the model's valid + answer or a fallback message after exhausting retries. - def _get_cancellation_cache_key(self) -> str: + :raises RuntimeError: if the stream ends without a result event. """ - Get the Redis cache key for cancellation tracking. - :return: The cache key as a string. - """ + current_prompt = user_prompt + current_history = message_history - return get_assistant_cancellation_key(self._chat.uuid) + for attempt in range(1 + self._MAX_TOOL_CALL_AS_TEXT_RETRIES): + result = await self._stream_agent_run( + current_prompt, current_history, queue + ) + if result is None: + raise RuntimeError("Agent stream ended without a result event") - def _check_cancellation(self, cache_key: str, message_id: str) -> None: - """ - Check if the message generation has been cancelled. + answer, run_result = result - :param cache_key: The cache key to check for cancellation. - :param message_id: The ID of the message being generated. - :raises AssistantMessageCancelled: If the message generation has been cancelled. - """ + if not self._looks_like_json_tool_call(answer): + return answer, run_result - if cache.get(cache_key): - cache.delete(cache_key) - raise AssistantMessageCancelled(message_id=message_id) + logger.warning( + "[assistant] Model output tool call as text (attempt {}/{}): {}", + attempt + 1, + 1 + self._MAX_TOOL_CALL_AS_TEXT_RETRIES, + answer[:200], + ) - async def _process_agent_stream( - self, - event: Any, - human_msg: AssistantChatMessage, - ) -> Tuple[list[AssistantMessageUnion], udspy.Prediction | None]: - """ - Process a single event from the output stream. + if attempt < self._MAX_TOOL_CALL_AS_TEXT_RETRIES: + # Replace the malformed JSON visible in the UI with a + # reasoning indicator so the user doesn't see garbage. + await queue.put( + QueueEvent( + kind=QueueEventKind.STREAM, + message=AiReasoningChunk(content=""), + ) + ) + current_history = run_result.all_messages() + current_prompt = self._TOOL_CALL_CORRECTION_PROMPT.format( + malformed_output=answer[:500] + ) + + # Exhausted retries — give up gracefully. + logger.error( + "[assistant] Model persisted outputting tool " + "calls as text after {} retries", + self._MAX_TOOL_CALL_AS_TEXT_RETRIES, + ) + fallback = ( + "I ran into a temporary issue processing " + "your request. Could you please try again?" + ) + return fallback, run_result - :param event: The event to process. - :param human_msg: The human message instance. - :return: a tuple of (messages_to_yield, prediction). + async def _stream_agent_run( + self, + user_prompt: str, + message_history: list[ModelMessage] | None, + queue: asyncio.Queue[QueueEvent], + ) -> tuple[str, Any] | None: + """Run a single agent streaming pass. + + Streams reasoning/text chunks to *queue* and returns + ``(answer, run_result)`` when an ``AgentRunResultEvent`` is + received, or ``None`` if the stream ends without one. """ - messages = [] - prediction = None + reasoning_so_far = "" - if isinstance(event, (AiThinkingMessage, AiNavigationMessage)): - messages.append(event) - return messages, prediction + async for event in main_agent.run_stream_events( + user_prompt=user_prompt, + deps=self._deps, + model=self._model, + message_history=message_history, + usage_limits=UsageLimits(request_limit=200), + toolsets=[self._toolset], + model_settings=get_model_settings(self._model_string, ORCHESTRATOR), + ): + if isinstance(event, AgentRunResultEvent): + answer = event.result.output + if isinstance(answer, str): + answer = _strip_think_tags(answer) + return (answer, event.result) + + if isinstance(event, FunctionToolCallEvent): + thought = _extract_tool_thought(event) + if thought: + reasoning_so_far += thought + cleaned = _strip_think_tags(reasoning_so_far) + await self._enqueue_reasoning(queue, cleaned) + continue - # Stream the final answer - if isinstance(event, udspy.OutputStreamChunk): - if ( - event.field_name == "answer" - and event.module is self._assistant.extract_module - ): - messages.append( - AiMessageChunk( - content=event.content, - sources=self._assistant_callbacks.sources, - ) - ) + if isinstance(event, FunctionToolResultEvent): + reasoning_so_far = "" # reset on tool results, to show the reasoning leading up to the next tool call + continue - elif isinstance(event, udspy.Prediction): - # final prediction contains the answer to the user question - if event.module is self._assistant: - prediction = event - ai_msg = await self._acreate_ai_message_response(human_msg, prediction) - messages.append(ai_msg) + # Accumulate text/thinking deltas and send full reasoning. + # The frontend replaces content on each chunk, so we must + # send the complete text every time. + content = self._get_content_delta(event) + if content: + reasoning_so_far += content + cleaned = _strip_think_tags(reasoning_so_far) + if cleaned: + await self._enqueue_reasoning(queue, cleaned) - elif reasoning := getattr(event, "next_thought", None): - messages.append(AiReasoningChunk(content=reasoning)) + return None - return messages, prediction + @staticmethod + def _get_content_delta(event: Any) -> str | None: + """Extract text or thinking content from a stream event delta.""" - def get_agent_stream( - self, message: HumanMessage, conversation_history: udspy.History | None = None - ) -> AsyncGenerator[Any, None]: - """ - Returns an async generator that streams the ReAct agent's response to a user - message. + if isinstance(event, PartStartEvent) and isinstance( + event.part, (TextPart, ThinkingPart) + ): + return event.part.content or None + if isinstance(event, PartDeltaEvent) and isinstance( + event.delta, (TextPartDelta, ThinkingPartDelta) + ): + return event.delta.content_delta or None + return None - :param user_message: The message from the user. - :return: An async generator that yields stream events. - """ + @staticmethod + async def _enqueue_reasoning( + queue: asyncio.Queue[QueueEvent], content: str + ) -> None: + """Push an ``AiReasoningChunk`` onto *queue*.""" - formatted_history = ( - ChatSignature.format_conversation_history(conversation_history) - if conversation_history - else [] - ) - formatted_ui_context = ( - message.ui_context.format() if message.ui_context else None + await queue.put( + QueueEvent( + kind=QueueEventKind.STREAM, + message=AiReasoningChunk(content=content), + ) ) - return self._assistant.astream( - question=message.content, - conversation_history=formatted_history, - ui_context=formatted_ui_context, - ) + @staticmethod + def _looks_like_json_tool_call(text: str) -> bool: + """Return True if *text* looks like a tool call dumped as JSON. - async def _process_stream( - self, - human_msg: HumanMessage, - stream: AsyncGenerator[Any, None], - process_event_func: Callable[ - [Any, AssistantChatMessage], - Tuple[list[AssistantMessageUnion], udspy.Prediction | None], - ], - ) -> AsyncGenerator[Tuple[AssistantMessageUnion, udspy.Prediction | None], None]: - chunk_count = 0 - cancellation_key = self._get_cancellation_cache_key() - message_id = str(human_msg.id) + Checks for ``{"name": ..., "arguments": ...}`` pattern in the first + 200 chars. Does not require valid JSON (the output may be truncated). + """ - async for event in stream: - # Periodically check for cancellation - chunk_count += 1 - if chunk_count % 10 == 0: - self._check_cancellation(cancellation_key, message_id) + stripped = text.strip() + return ( + bool(stripped) + and stripped[0] == "{" + and '"name"' in stripped[:200] + and '"arguments"' in stripped[:200] + ) + + # ------------------------------------------------------------------ + # Cancellation + # ------------------------------------------------------------------ - messages, prediction = await process_event_func(event, human_msg) + async def _monitor_cancellation(self, task: asyncio.Task) -> None: + """Poll the cache for a cancellation flag and cancel *task* if + set. Runs as a concurrent task alongside the agent.""" - if messages: # Don't return responses if cancelled - self._check_cancellation(cancellation_key, message_id) + cache_key = get_assistant_cancellation_key(self._chat.uuid) + while not task.done(): + await asyncio.sleep(0.2) + if cache.get(cache_key): + cache.delete(cache_key) + self._tool_helpers.cancel() + task.cancel() + return - for msg in messages: - yield msg, prediction + # ------------------------------------------------------------------ + # Public streaming API + # ------------------------------------------------------------------ async def astream_messages( self, message: HumanMessage ) -> AsyncGenerator[AssistantMessageUnion, None]: - """ - Streams the response to a user message. + """Stream the full response lifecycle for a user message. - :param human_message: The message from the user. - :return: An async generator that yields the response messages. + Yields events in order: ``AiStartedMessage``, zero or more + streaming chunks (``AiMessageChunk`` / ``AiReasoningChunk`` / + ``AiThinkingMessage``), and finally an ``AiMessage`` with the + persisted answer. A ``ChatTitleMessage`` is appended on the first + message in a chat. """ + # Sticky task: capture on first message of the session + if not self._deps.original_request: + self._deps.original_request = message.content + + # Auto-detect starting mode from UI context (only on first message) + if message.ui_context: + if message.ui_context.application or message.ui_context.page: + self._deps.mode = AgentMode.APPLICATION + elif message.ui_context.automation or message.ui_context.workflow: + self._deps.mode = AgentMode.AUTOMATION + # else stays DATABASE (default) + human_msg = await self.acreate_chat_message( - AssistantChatMessage.Role.HUMAN, - message.content, + AssistantChatMessage.Role.HUMAN, message.content ) - default_callbacks = udspy.settings.callbacks - - with ( - udspy.settings.context( - lm=self._lm_client, - callbacks=[*default_callbacks, *self._callbacks], - ), - self._telemetry_callbacks.trace(self._chat, human_msg.content), - ): - message_id = str(human_msg.id) - yield AiStartedMessage(message_id=message_id) + message_id = str(human_msg.id) + yield AiStartedMessage(message_id=message_id) - history = await self.afetch_chat_history(limit=30) + ui_context = message.ui_context.format() if message.ui_context else None + self._tool_helpers.request_context["ui_context"] = ui_context + message_history = await self._load_message_history() - agent_stream = self.get_agent_stream(message, history) + queue: asyncio.Queue[QueueEvent] = asyncio.Queue() + self._event_bus.set_queue(queue) - async for msg, __ in self._process_stream( - human_msg, agent_stream, self._process_agent_stream - ): - yield msg + agent_task = asyncio.create_task( + self._run_agent(message.content, message_history, queue) + ) + monitor_task = asyncio.create_task(self._monitor_cancellation(agent_task)) - # Generate chat title if needed - if not self._chat.title: - chat_title = await self._generate_chat_title(human_msg.content) - self._chat.title = chat_title + try: + answer = None + messages_json = None + + while True: + event = await queue.get() + if event.kind == QueueEventKind.DONE: + break + elif event.kind == QueueEventKind.RESULT: + answer, messages_json = event.answer, event.messages_json + elif event.kind == QueueEventKind.ERROR: + raise event.error + else: + yield event.message + + if agent_task.cancelled(): + raise AssistantMessageCancelled(message_id=message_id) + + if answer is not None: + yield await self._save_ai_response(human_msg, answer) + if messages_json: + await self._save_message_history(messages_json) + finally: + monitor_task.cancel() + if not agent_task.done(): + agent_task.cancel() + await asyncio.gather(monitor_task, agent_task, return_exceptions=True) + self._event_bus.set_queue(None) + + if not self._chat.title: + try: + title = await self._generate_chat_title(human_msg.content) + self._chat.title = title[: AssistantChat.TITLE_MAX_LENGTH] await self._chat.asave(update_fields=["title", "updated_on"]) - yield ChatTitleMessage(content=chat_title) + yield ChatTitleMessage(content=self._chat.title) + except Exception: + logger.exception("Failed to generate chat title") diff --git a/enterprise/backend/src/baserow_enterprise/assistant/deps.py b/enterprise/backend/src/baserow_enterprise/assistant/deps.py new file mode 100644 index 0000000000..f4bb95db97 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/deps.py @@ -0,0 +1,148 @@ +import asyncio +import threading +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import TYPE_CHECKING, Any, Callable + +from pydantic_ai import Tool + +if TYPE_CHECKING: + from django.contrib.auth.models import AbstractUser + + from baserow.core.models import Workspace + from baserow_enterprise.assistant.tools.navigation.types import ( + AnyNavigationRequestType, + ) + + +class AgentMode(str, Enum): + """Operating mode that controls which tools are available to the agent.""" + + DATABASE = "database" + APPLICATION = "application" + AUTOMATION = "automation" + EXPLAIN = "explain" + + +class QueueEventKind(Enum): + STREAM = auto() + RESULT = auto() + ERROR = auto() + DONE = auto() + + +@dataclass +class QueueEvent: + kind: QueueEventKind + message: Any = None + answer: str = "" + messages_json: bytes = b"" + error: Exception | None = None + + +@dataclass +class EventBus: + """ + Pushes streaming events into the queue consumed by + Assistant.astream_messages(). Events are silently dropped when no + queue is attached. + """ + + _queue: asyncio.Queue[QueueEvent] | None = None + + def set_queue(self, queue: asyncio.Queue[QueueEvent] | None): + self._queue = queue + + def emit(self, event): + if self._queue is not None: + self._queue.put_nowait( + QueueEvent(kind=QueueEventKind.STREAM, message=event) + ) + + +@dataclass +class ToolHelpers: + """ + Contextual helpers available to every tool via ``RunContext[AssistantDeps]``. + + Provides status updates (shown in the UI), navigation actions, + cancellation support, and an event bus for emitting custom streaming + events (thinking messages, navigation messages, etc.). + """ + + update_status: Callable[[str], None] + navigate_to: Callable[["AnyNavigationRequestType"], str] + request_context: dict = field(default_factory=dict) + event_bus: EventBus = field(default_factory=EventBus) + _cancel_event: threading.Event = field(default_factory=threading.Event) + + def raise_if_cancelled(self) -> None: + """Check cancellation and raise if set. Thread-safe. + + Call this in tool loops or between expensive operations. + Raises ``CancelledError`` (``BaseException``) which escapes the + agent's ``except Exception`` handler and propagates through the + async chain. + """ + + if self._cancel_event.is_set(): + raise asyncio.CancelledError() + + @property + def is_cancelled(self) -> bool: + """Check if cancelled without raising. Thread-safe.""" + + return self._cancel_event.is_set() + + def cancel(self) -> None: + """Signal cancellation to running tools. Thread-safe.""" + + self._cancel_event.set() + + +@dataclass +class AssistantDeps: + """ + Typed dependency container for the pydantic-ai agent. + + Every agent run operates on behalf of a user in a given workspace. + This runtime-context also allows tools to share information (e.g. + sources), provide helpers for emitting events or requesting navigation, + switch between domain modes, and dynamically extend the toolset by + adding tools to ``dynamic_tools`` during a run (e.g. row tools loaded + by the database agent). + + Passed via ``deps=`` to every ``agent.run()`` / ``agent.run_stream()`` + call and accessible in tools via ``RunContext[AssistantDeps].deps``. + """ + + user: "AbstractUser" + workspace: "Workspace" + tool_helpers: ToolHelpers + mode: AgentMode = AgentMode.DATABASE + sources: list[str] = field(default_factory=list) + dynamic_tools: list[Tool] = field(default_factory=list) + database_manifest: str = "" + application_manifest: str = "" + automation_manifest: str = "" + explain_manifest: str = "" + original_request: str = "" + + @property + def active_manifest(self) -> str: + return { + AgentMode.DATABASE: self.database_manifest, + AgentMode.APPLICATION: self.application_manifest, + AgentMode.AUTOMATION: self.automation_manifest, + AgentMode.EXPLAIN: self.explain_manifest, + }[self.mode] + + def extend_sources(self, new_sources: list[str]): + """ + Extend the current list of sources with new ones, avoiding + duplicates. + + :param new_sources: The list of new source URLs to add. + """ + + self.sources.extend(s for s in new_sources if s not in self.sources) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/handler.py b/enterprise/backend/src/baserow_enterprise/assistant/handler.py index bfe8b2bf72..6d98e98935 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/handler.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/handler.py @@ -147,12 +147,11 @@ async def astream_assistant_messages( :param chat: The AI assistant chat to get the assistant for. :param human_message: The new message from the user. - :param ui_ontext: The UI context where the message was sent. + :param ui_context: The UI context where the message was sent. :return: An async generator yielding messages from the assistant. """ assistant = self.get_assistant(chat) - async for message in assistant.astream_messages( - human_message, ui_context=ui_context - ): - yield message + message = HumanMessage(content=human_message, ui_context=ui_context) + async for msg in assistant.astream_messages(message): + yield msg diff --git a/enterprise/backend/src/baserow_enterprise/assistant/history.py b/enterprise/backend/src/baserow_enterprise/assistant/history.py new file mode 100644 index 0000000000..c25e699550 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/history.py @@ -0,0 +1,118 @@ +""" +Utilities for compacting and trimming pydantic-ai message histories. + +The assistant persists the full message history (including intermediate tool +calls) across turns. Before feeding it back into the agent we compact each +turn down to (user prompt, final answer) and trim to a fixed window so the +context doesn't grow unboundedly. +""" + +from pydantic_ai.messages import ( + ModelMessage, + ModelRequest, + ModelResponse, + TextPart, + UserPromptPart, +) + +# The number of messages to keep in the compacted history for context. This is +# a simple safeguard to prevent excessively long histories from bloating the +# context. +MAX_HISTORY_MESSAGES = 20 + + +def _has_user_prompt(msg: ModelMessage) -> bool: + """Check if a ModelRequest contains a UserPromptPart.""" + + return isinstance(msg, ModelRequest) and any( + isinstance(p, UserPromptPart) for p in msg.parts + ) + + +def _get_final_text_response(turn: list[ModelMessage]) -> ModelResponse | None: + """ + Return the last ModelResponse in the turn that contains a TextPart, + or None if no such response exists. + """ + + for msg in reversed(turn): + if isinstance(msg, ModelResponse) and any( + isinstance(p, TextPart) for p in msg.parts + ): + return msg + return None + + +def _split_into_turns( + messages: list[ModelMessage], +) -> list[list[ModelMessage]]: + """ + Split a flat message list into turns. Each turn starts at a ModelRequest + that contains a UserPromptPart. + + Messages before the first UserPromptPart (e.g. initial system instructions) + are grouped into a leading "turn 0". + """ + + turns: list[list[ModelMessage]] = [] + current: list[ModelMessage] = [] + + for msg in messages: + if _has_user_prompt(msg) and current: + turns.append(current) + current = [] + current.append(msg) + + if current: + turns.append(current) + + return turns + + +def _compact_turn(turn: list[ModelMessage]) -> list[ModelMessage]: + """ + Compact a single turn. If the turn has more than 2 messages (user prompt + + final answer), strip intermediate tool call/return messages and keep + only the user prompt request and the final text response. + + Returns the turn unchanged if it has no tool calls or no final text + response. + """ + + if len(turn) <= 2: + return turn + + # Find the user prompt request (first message) and final text response + user_request = turn[0] if _has_user_prompt(turn[0]) else None + final_response = _get_final_text_response(turn) + + if user_request and final_response: + return [user_request, final_response] + + # Cannot compact -- return as-is + return turn + + +def compact_message_history( + messages: list[ModelMessage], + max_messages: int = MAX_HISTORY_MESSAGES, +) -> list[ModelMessage]: + """ + Compact and trim a pydantic-ai message history for multi-turn context. + + 1. Splits messages into turns (delimited by UserPromptPart). + 2. For each turn with intermediate tool calls, collapses to just the + user prompt and final text answer. + 3. Trims to the last ``max_messages`` messages if still too long. + """ + + turns = _split_into_turns(messages) + + compacted: list[ModelMessage] = [] + for turn in turns: + compacted.extend(_compact_turn(turn)) + + if len(compacted) > max_messages: + compacted = compacted[-max_messages:] + + return compacted diff --git a/enterprise/backend/src/baserow_enterprise/assistant/model_profiles.py b/enterprise/backend/src/baserow_enterprise/assistant/model_profiles.py new file mode 100644 index 0000000000..25f19a625b --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/model_profiles.py @@ -0,0 +1,175 @@ +""" +Centralized model configuration and per-model settings for all agents. + +Contains: +- ``get_model_string()``: Resolves the active LLM model identifier. +- ``check_lm_ready_or_raise()``: Quick connectivity check. +- ``get_model_settings(model, role)``: Per-model, per-role settings. + +Usage:: + + from baserow_enterprise.assistant.model_profiles import ( + get_model_string, get_model_settings, ORCHESTRATOR, + ) + + model = get_model_string() + settings = get_model_settings(model, ORCHESTRATOR) +""" + +from functools import lru_cache + +from django.conf import settings + +from pydantic_ai import Agent +from pydantic_ai.settings import ModelSettings + +from baserow_enterprise.assistant.exceptions import AssistantModelNotSupportedError +from baserow_enterprise.assistant.models import AssistantChat + +# --------------------------------------------------------------------------- +# Agent roles +# --------------------------------------------------------------------------- + +ORCHESTRATOR = "orchestrator" +SUBAGENT = "subagent" # database, builder, automations +UTILITY = "utility" # formula, fixer (precision-oriented) +SAMPLE = "sample" # sample row generation (creative) +TITLE = "title" # title generation + +# --------------------------------------------------------------------------- +# Per-model profiles +# --------------------------------------------------------------------------- + +# Fallback when the model isn't in _MODEL_PROFILES +_DEFAULT_PROFILE: dict[str, ModelSettings] = { + ORCHESTRATOR: { + "temperature": 0.3, + "timeout": 30, + "parallel_tool_calls": False, + "max_tokens": 16384, + }, + SUBAGENT: { + "temperature": 0.3, + "timeout": 20, + "parallel_tool_calls": False, + "max_tokens": 16384, + }, + UTILITY: { + "temperature": 0.1, + "timeout": 20, + }, + SAMPLE: { + "temperature": 0.5, + "timeout": 20, + }, + TITLE: { + "temperature": 0.7, + "timeout": 10, + "max_tokens": AssistantChat.TITLE_MAX_LENGTH, + }, +} + +_MODEL_PROFILES: dict[str, dict[str, ModelSettings]] = { + "gpt-oss-120b": { + ORCHESTRATOR: { + **_DEFAULT_PROFILE[ORCHESTRATOR], + "groq_reasoning_format": "parsed", + }, + SUBAGENT: { + **_DEFAULT_PROFILE[SUBAGENT], + "groq_reasoning_format": "parsed", + }, + UTILITY: { + # No groq_reasoning_format here: formula generation is a precise + # structured-output task where reasoning tokens pollute the output. + **_DEFAULT_PROFILE[UTILITY], + }, + SAMPLE: { + **_DEFAULT_PROFILE[SAMPLE], + "groq_reasoning_format": "parsed", + }, + TITLE: { + **_DEFAULT_PROFILE[TITLE], + }, + }, +} + + +def get_model_settings(model: str, role: str) -> ModelSettings: + """ + Return the ModelSettings for a given model string and agent role. + + The model string is the pydantic-ai format (e.g. ``"groq:openai/gpt-oss-120b"``). + We match on the last path segment (e.g. ``"gpt-oss-120b"``) to find the profile. + + For the ``ORCHESTRATOR`` role the temperature defaults to the value of + ``BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE`` (if set), allowing + operators to override it without changing code. + + :param model: pydantic-ai model string (e.g. ``"groq:openai/gpt-oss-120b"``). + :param role: One of ORCHESTRATOR, SUBAGENT, UTILITY, TITLE. + :return: A ModelSettings dict suitable for ``model_settings=`` parameter. + """ + + # Extract model name after the provider prefix: + # "groq:openai/gpt-oss-120b" -> "gpt-oss-120b" + # "ollama:kimi-2.5:cloud" -> "kimi-2.5:cloud" + _, sep, after_provider = model.partition(":") + model_name = after_provider.rsplit("/", 1)[-1] if sep else model + + profile = _MODEL_PROFILES.get(model_name, _DEFAULT_PROFILE) + result = dict(profile.get(role, _DEFAULT_PROFILE.get(role, {}))) + + # Allow the env-var-driven setting to override the orchestrator temperature. + if role == ORCHESTRATOR: + env_temp = getattr( + settings, "BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE", None + ) + if env_temp is not None: + result["temperature"] = env_temp + + return result + + +# --------------------------------------------------------------------------- +# Model resolution +# --------------------------------------------------------------------------- + + +def get_model_string(model: str | None = None) -> str: + """ + Returns the model string for the pydantic-ai agent. + + :param model: The language model to use. If None, the default model from + settings will be used. + :return: A model string compatible with pydantic-ai. + """ + + value = model or settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL + # pydantic-ai expects "provider:model" (e.g. "groq:openai/gpt-oss-120b"). + # Convert "provider/model" to "provider:model" when the first "/" comes + # before the first ":" (or there is no ":"). This handles cases like + # "ollama/kimi-2.5:cloud" where the colon is part of the model tag. + slash_pos = value.find("/") + colon_pos = value.find(":") + if slash_pos != -1 and (colon_pos == -1 or slash_pos < colon_pos): + value = value.replace("/", ":", 1) + elif slash_pos == -1 and colon_pos == -1: + # No provider prefix at all (e.g. "gpt-4o") — default to OpenAI + # for backward compatibility with old UDSPY_LM_MODEL values. + value = f"openai:{value}" + return value + + +@lru_cache(maxsize=1) +def check_lm_ready_or_raise() -> None: + model = get_model_string() + test_agent = Agent( + output_type=str, instructions="Respond with 'ok'.", name="test_agent" + ) + try: + test_agent.run_sync("Test", model=model) + except Exception as e: + raise AssistantModelNotSupportedError( + f"The model '{model}' is not supported or accessible: {e}" + ) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/models.py b/enterprise/backend/src/baserow_enterprise/assistant/models.py index 9c48402a96..250bc5e1c0 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/models.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/models.py @@ -42,6 +42,14 @@ class Status(models.TextChoices): status = models.CharField( max_length=20, choices=Status.choices, default=Status.IDLE ) + message_history = models.BinaryField( + null=True, + blank=True, + help_text=( + "Serialized pydantic-ai message history (JSON bytes) for " + "multi-turn conversation context." + ), + ) class Meta: indexes = [ diff --git a/enterprise/backend/src/baserow_enterprise/assistant/prompts.py b/enterprise/backend/src/baserow_enterprise/assistant/prompts.py index 84d949a61c..bc8ff080f0 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/prompts.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/prompts.py @@ -1,167 +1,57 @@ from django.conf import settings -CORE_CONCEPTS = """ -### BASEROW STRUCTURE - -**Structure**: Workspace → Databases, Applications, Automations, Dashboards, Snapshots - -**Key concepts**: -• **Roles**: Free (admin, member) | Advanced/Enterprise (admin, builder, editor, viewer, no access) -• **Features**: Real-time collaboration, SSO (SAML2/OIDC/OAuth2), MCP integration, API access, Audit logs -• **Plans**: Free, Premium, Advanced, Enterprise (https://baserow.io/pricing) -• **Open Source**: Core is open source (https://github.com/baserow/baserow) -• **Snapshots**: Application-level backups -""" - -DATABASE_BUILDER_CONCEPTS = """ -### DATABASE BUILDER (no-code database) - -**Structure**: Database → Tables → Fields + Views + Webhooks + Rows. Rows → comments. - -**Key concepts**: -• **Fields**: Define schema (30+ types including link_row for relationships); one primary field per table -• **Views**: Present data with filters/sorts/grouping/colors; can be shared, personal, or public -• **Rows**: Data records following the table schema; support for rich content (files, long text, formulas, numbers, dates, etc.). Changes are tracked in history. -• **Comments**: Threaded discussions on rows; mentions. -• **Formulas**: Computed fields using functions/operators; support for cross-table lookups -• **Permissions**: RBAC at workspace/database/table/field levels; database tokens for API -• **Data sync**: Table replication; **Webhooks**: Row/field/view event triggers -""" - -APPLICATION_BUILDER_CONCEPTS = """ -### APPLICATION BUILDER (visual app builder) - -**Structure**: Application → Pages → Elements + Data Sources + Workflows - -**Key concepts**: -• **Pages**: Routes with UI elements (buttons, tables, forms, etc.) -• **Data Sources**: Connect to database tables/views; elements bind to them for dynamic content -• **Formulas**: Reference data from previous nodes and compute values using functions/operators in nodes attributes -• **Workflows**: Event-driven actions (create/update rows, navigate, notifications) -• **Publishing**: Requires domain configuration +AGENT_IDENTITY = """\ + +You are Kuma, an AI expert for Baserow (open-source no-code platform). \ +You are an autonomous tool-calling agent. Whenever possible, you act — you do not describe. + +""" + +RULES = """\ + +1. Use the `thought` parameter on EVERY tool call to state your reasoning. +2. Have tools → call them. No tools in current mode → check other modes before saying something is not possible. If another mode has the tool, switch_mode and use it. Only explain manual UI steps if no mode covers the action. +3. One tool per turn. Wait for the result. Never reply and call a tool in same turn. +4. Verify after create/modify — navigate to show the result. +5. Request priority: action > follow-up (reuse prior IDs, never search docs) > question. When a tool result contains next_steps, act on them immediately — do not ask for permission to continue. +6. You start in the mode matching your UI context (database/application/automation). If the user asks a how-to or feature question, call switch_mode("explain"), then search_user_docs. +7. After finishing the tool calls in a different mode (not just after switching — after the actual work is done and results received), switch back to the original domain mode (check and ). +8. Reply in concise Markdown. Never expose raw JSON or internal IDs unless asked. +9. When a request references resources by name/ID, verify they exist (list_*) before building on them. If not found, ask — don't guess. But when the task *requires* creating resources in another domain (e.g. building an app that needs new tables), switch_mode and create them yourself — don't ask the user to do it manually. +10. Before responding to the user, verify ALL parts of `` are addressed. If anything is missing, continue working. +11. Before adding a table to a database or a page to an application, check that the target is semantically related. If the name/purpose doesn't match, ask the user which target to use or whether to create a new one. Examples of mismatches: adding "Inquiries" table to a "Project Management" DB; adding "Event Registration" pages to a "Portfolio Website" app. This applies to ALL resource creation — tables, pages, and the applications/databases themselves. Remember their answer — only re-ask when a new, different mismatch arises. + +""" + +HANDLING_AMBIGUITY = """\ + +Ambiguous terms — pick by context, confirm only if truly unclear: +- "table" → App Builder: Table element | Database: database table +- "form" → App Builder: Form element | Database: Form view +- "workflow action" → App Builder: element action | Automations: action node + +""" + +BASEROW_KNOWLEDGE = """\ + +Workspace → Databases, Applications, Automations, Dashboards +Database → Tables → Fields (30+ types, link_row for relations) + Views (grid, form, kanban, calendar, gallery, timeline) + Rows +Application → Pages → Elements + Data Sources + Actions +Automation → Workflows → Trigger + Action/Router/Iterator nodes (use {{ node.ref }} for formulas) + +""" + +LIMITATIONS_AND_SOURCES = f"""\ + +Cannot create/modify/delete: user accounts, workspaces, dashboards, widgets, snapshots, webhooks, integrations, roles, permissions. +Docs: search_user_docs | API: {settings.PUBLIC_BACKEND_URL}/api/schema.json | Web: https://baserow.io | Community: https://community.baserow.io + """ -AUTOMATION_BUILDER_CONCEPTS = """ -### AUTOMATIONS (no-code automation builder) - -**Structure**: Automation → Workflows → Trigger + Actions + Routers (Nodes) - -**Key concepts**: -• **Trigger**: The single event that starts the workflow (e.g., row created/updated/deleted) -• **Actions**: Tasks performed (e.g., create/update rows, send emails, call webhooks) -• **Routers**: Conditional logic (if/else, switch) to control flow -• **Iterators**: Loop over lists of items -• **Formulas**: Reference data from previous nodes and compute values using functions/operators in nodes attributes -• **Execution**: Runs in the background; monitor via logs -• **History**: Track runs, successes, failures -• **Publishing**: Requires at least one configured action -""" - -AGENT_LIMITATIONS = """ -## LIMITATIONS - -### CANNOT CREATE: -• User accounts, workspaces -• Applications, pages -• Dashboards, widgets -• Snapshots, webhooks, integrations -• Roles, permissions - -### CANNOT UPDATE/MODIFY: -• User, workspace, or integration settings -• Roles, permissions -• Applications, pages -• Dashboards, widgets - -### CANNOT DELETE: -• Users, workspaces -• Roles, permissions -• Applications, pages -• Dashboards, widgets -""" - -ASSISTANT_SYSTEM_PROMPT_BASE = ( - f""" -You are Kuma, an AI expert for Baserow (open-source no-code platform). - -## YOUR KNOWLEDGE -1. **Core concepts** (below) -2. **Detailed docs** - use search_user_docs tool to search when needed -3. **API specs** - guide users to "{settings.PUBLIC_BACKEND_URL}/api/schema.json" -4. **Official website** - "https://baserow.io" -5. **Community support** - "https://community.baserow.io" -6. **Direct support** - for Advanced/Enterprise plan users - -## ANSWER FORMATTING GUIDELINES -• Use American English spelling and grammar -• Only use Markdown (bold, italics, lists, code blocks) -• Prefer lists in explanations. Numbered lists for steps; bulleted for others. -• Use code blocks for examples, commands, snippets -• Be concise and clear in your response - -## BASEROW CONCEPTS -""" - + CORE_CONCEPTS - + DATABASE_BUILDER_CONCEPTS - + APPLICATION_BUILDER_CONCEPTS - + AUTOMATION_BUILDER_CONCEPTS -) - AGENT_SYSTEM_PROMPT = ( - ASSISTANT_SYSTEM_PROMPT_BASE - + """ -## YOUR TOOLS - -**CRITICAL - Understanding your tools:** -- Learn what each tool does ONLY from its **name** and **description** -- **NEVER use `search_user_docs` to learn about your tools** - it contains end-user documentation, NOT information about your available tools or how to call them -- `search_user_docs` is ONLY for answering user questions about Baserow features and providing manual instructions - -## REQUEST HANDLING - -### ACTION REQUESTS - CHECK FIRST - -**CRITICAL: Before treating a request as a question, determine if it's an action you can perform.** - -Recognize action requests by: -- Imperative verbs: "Show...", "Filter...", "Create...", "Add...", "Delete...", "Update...", "Sort...", "Hide..." -- Desired states: "I want only...", "I need a field that...", "Make it show..." -- Example: "Show only rows where the primary field is empty" → This is an ACTION (create a filter), not a question about filtering - -**DO vs EXPLAIN:** -- If you have tools to do it → **DO IT** -- If you lack tools → **THEN explain** how to do it manually -- **NEVER explain how to do something you can do yourself** - -**Workflow:** -1. Check your tools - can you fulfill this? -2. **YES**: Execute (ask for clarification only if request is ambiguous) -3. **NO** (see LIMITATIONS): Explain you can't, then provide manual instructions from docs - -### QUESTIONS (only after ruling out action requests) - -**FACTUAL QUESTIONS** - asking what Baserow IS or HAS: -- Examples: "Does Baserow have X feature?", "How does Y work?", "What options exist for Z?" -- These have objectively correct/incorrect answers that must come from documentation -- **ALWAYS search documentation first** using `search_user_docs` -- Check the `reliability_note` in the response: - - **HIGH CONFIDENCE**: Present the answer confidently with sources - - **PARTIAL MATCH**: Provide the answer but note some details may be incomplete - - **LOW CONFIDENCE / NOTHING FOUND**: Tell the user you couldn't find this in the documentation. **DO NOT guess or assume features exist** - if docs don't mention it (e.g., a "barcode field"), it likely doesn't exist. Suggest checking the community forum or contacting support. -- **NEVER fabricate Baserow features or capabilities** - -**ADVISORY QUESTIONS** - asking how to USE or APPLY Baserow: -- Examples: "How should I structure X?", "What's a good approach for Y?", "Help me build Z", "Which field type works best for W?" -- These ask for your expertise in applying Baserow to solve problems - there's no single correct answer -- **Use your knowledge** of Baserow's real capabilities (field types, views, formulas, automations, linking, etc.) to provide helpful recommendations -- You may search docs for reference, but can also directly advise based on your understanding of Baserow -- Focus on practical solutions using actual Baserow functionality - -**Key principle**: Never fabricate what Baserow CAN do. Freely advise on HOW to use what Baserow actually offers. -""" - + AGENT_LIMITATIONS - + """ - -## TASK INSTRUCTIONS: -""" + AGENT_IDENTITY + + RULES + + HANDLING_AMBIGUITY + + BASEROW_KNOWLEDGE + + LIMITATIONS_AND_SOURCES ) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/retrying_model.py b/enterprise/backend/src/baserow_enterprise/assistant/retrying_model.py new file mode 100644 index 0000000000..007f55398d --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/retrying_model.py @@ -0,0 +1,448 @@ +""" +A pydantic-ai Model wrapper that retries on transient provider errors. + +Provider SDKs (Groq, Anthropic, OpenAI) sometimes raise exceptions that +are transient — e.g. ``groq.APIError: Failed to parse tool call arguments +as JSON``. pydantic-ai handles *some* of these (e.g. ``tool_use_failed`` +with a structured body), but others slip through. + +``RetryingModel`` wraps any pydantic-ai ``Model`` and adds retry logic +around ``request()`` with configurable back-off. + +Streaming recovery +------------------ +pydantic-ai's ``GroqStreamedResponse`` catches ``APIError`` with +``tool_use_failed`` bodies, but only when the ``failed_generation`` JSON +is valid. When Groq sends **truly malformed** JSON (not just +schema-invalid), pydantic-ai's ``Json[...]`` type fails to parse it and +re-raises the raw ``APIError``. + +Since this error occurs *during* stream consumption (after yield), +``@asynccontextmanager`` cannot yield a replacement. Instead we wrap +the stream in ``_ErrorRecoveringStream`` which intercepts ``APIError`` +in its ``_get_event_iterator`` and emits a ``ToolCallPart`` (or +``TextPart``) so pydantic-ai's validation loop can tell the model +what was wrong. + +For errors that occur *before* the stream is established (during +``request_stream`` setup), we fall back to the retrying ``request()`` +method and wrap the result in ``_PreFetchedResponse``. +""" + +from __future__ import annotations + +import asyncio +import json +import re +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from datetime import datetime, timezone +from typing import Any + +from loguru import logger +from pydantic_ai._run_context import RunContext +from pydantic_ai.exceptions import ModelHTTPError +from pydantic_ai.messages import ModelMessage, ModelResponse, ToolCallPart +from pydantic_ai.models import ( + KnownModelName, + Model, + ModelRequestParameters, + ModelResponseStreamEvent, + StreamedResponse, + infer_model, +) +from pydantic_ai.models.wrapper import WrapperModel +from pydantic_ai.settings import ModelSettings + +# Transient Groq errors that are safe to retry. +_RETRYABLE_MESSAGES = frozenset( + { + "Failed to parse tool call arguments as JSON", + "Tool call validation failed", + } +) + + +def _is_transient_provider_error(exc: Exception) -> bool: + """Return True for provider errors that are transient and safe to retry.""" + + msg = str(exc) + return any(needle in msg for needle in _RETRYABLE_MESSAGES) + + +def _extract_tool_use_failed(body: dict) -> dict | None: + """Extract ``tool_use_failed`` error dict from an error body. + + Handles both wrapped (``{"error": {...}}``) and unwrapped layouts + (the Groq SDK streaming path sets ``body=data["error"]``). + """ + + error = body.get("error", body) + if not isinstance(error, dict): + return None + if error.get("code") != "tool_use_failed": + return None + return error + + +_TOOL_NAME_RE = re.compile(r'"name"\s*:\s*"([^"]+)"') + + +def _extract_tool_name(failed_gen: str) -> str: + """Best-effort tool name extraction from truncated/malformed JSON.""" + + m = _TOOL_NAME_RE.search(failed_gen) + return m.group(1) if m else "unknown" + + +def _recover_failed_generation(failed_gen: str, model_name: str = "") -> ModelResponse: + """Turn a ``failed_generation`` string into a synthetic ``ModelResponse``. + + If the JSON is valid and contains ``name`` + ``arguments``, returns a + ``ToolCallPart`` so pydantic-ai's validation loop can tell the model + what was wrong. For truly malformed JSON, extracts the tool name + (best-effort) and returns a ``ToolCallPart`` with empty args so + pydantic-ai's validation rejects it and sends a retry prompt. + """ + + try: + parsed = json.loads(failed_gen) + if isinstance(parsed, dict) and "name" in parsed and "arguments" in parsed: + return ModelResponse( + parts=[ + ToolCallPart( + tool_name=parsed["name"], + args=json.dumps(parsed["arguments"]), + ) + ], + model_name=model_name, + ) + except (json.JSONDecodeError, TypeError): + pass + + # JSON is truly malformed (e.g. truncated). We must NOT fall back to a + # TextPart here because the stream may have already started emitting + # tool-call events — mixing TextPart into a tool-call stream causes + # pydantic-ai's AgentStream to fail with "unable to find output". + # + # Instead, try to extract the tool name from partial JSON and emit a + # ToolCallPart with empty args. pydantic-ai's validation will reject + # the args and send a retry prompt to the model. + tool_name = _extract_tool_name(failed_gen) + return ModelResponse( + parts=[ + ToolCallPart( + tool_name=tool_name, + args="{}", + ) + ], + model_name=model_name, + ) + + +def _try_recover_tool_use_failed(exc: Exception) -> ModelResponse | None: + """Try to recover a ``tool_use_failed`` error into a ``ModelResponse``. + + Works with both ``ModelHTTPError`` (non-streaming path) and raw + provider ``APIError`` (streaming path where pydantic-ai's handler + couldn't parse the malformed JSON). + """ + + if isinstance(exc, ModelHTTPError): + body = exc.body + model_name = exc.model_name + elif hasattr(exc, "body"): + # Raw provider APIError (e.g. groq.APIError). + body = exc.body # type: ignore[union-attr] + model_name = "" + else: + return None + + if not isinstance(body, dict): + return None + + error = _extract_tool_use_failed(body) + if error is None: + return None + + failed_gen = error.get("failed_generation") + if not failed_gen or not isinstance(failed_gen, str): + return None + + return _recover_failed_generation(failed_gen, model_name) + + +def _resolve_model(model_name: str) -> Model: + """Resolve a model name to a pydantic-ai Model instance. + + For Google models, constructs the model with a fresh + ``httpx.AsyncClient`` instead of relying on ``infer_model()`` which + uses a process-global cached client. That cached client binds to the + event loop at creation time and breaks when reused on a different loop + (common in Django async views). + See: https://github.com/pydantic/pydantic-ai/issues/3240 + """ + + if model_name.startswith(("google-gla:", "google:", "google-vertex:")): + import httpx + from pydantic_ai.models.google import GoogleModel + from pydantic_ai.providers.google import GoogleProvider + + vertexai = model_name.startswith("google-vertex:") + google_model_name = model_name.split(":", 1)[1] + return GoogleModel( + google_model_name, + provider=GoogleProvider(http_client=httpx.AsyncClient(), vertexai=vertexai), + ) + + return infer_model(model_name) + + +class RetryingModel(WrapperModel): + """Model wrapper that retries ``request()`` on transient provider errors. + + Model resolution is deferred until the first actual call so that + constructing a ``RetryingModel`` from a model name string does not + require provider API keys to be available at import/init time. + + Only ``request()`` has a retry loop. ``request_stream()`` falls back + to ``request()`` when the stream raises a retryable error, since + ``@asynccontextmanager`` only allows a single ``yield``. + """ + + def __init__( + self, + wrapped: Model | KnownModelName, + *, + max_attempts: int = 3, + base_delay: float = 1.0, + max_delay: float = 10.0, + ): + # Bypass WrapperModel.__init__ to defer infer_model. + Model.__init__(self) + self._wrapped_or_name = wrapped + self._resolved: Model | None = None + self.max_attempts = max_attempts + self.base_delay = base_delay + self.max_delay = max_delay + + @property + def wrapped(self) -> Model: + if self._resolved is None: + self._resolved = ( + self._wrapped_or_name + if isinstance(self._wrapped_or_name, Model) + else _resolve_model(self._wrapped_or_name) + ) + return self._resolved + + @wrapped.setter + def wrapped(self, value: Model) -> None: + self._resolved = value + + def _delay_for(self, attempt: int) -> float: + """Exponential back-off delay capped at ``max_delay``.""" + return min(self.base_delay * (2 ** (attempt - 1)), self.max_delay) + + async def request( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> ModelResponse: + for attempt in range(1, self.max_attempts + 1): + try: + return await self.wrapped.request( + messages, model_settings, model_request_parameters + ) + except Exception as exc: + # Try to recover tool_use_failed into a response so + # pydantic-ai's validation loop can tell the model what + # was wrong (instead of blindly retrying the same request). + recovered = _try_recover_tool_use_failed(exc) + if recovered is not None: + logger.info( + "[assistant] Recovered tool_use_failed error into ModelResponse" + ) + return recovered + + if ( + not _is_transient_provider_error(exc) + or attempt == self.max_attempts + ): + raise + delay = self._delay_for(attempt) + logger.warning( + "[assistant] Model request failed (attempt {}/{}), " + "retrying in {:.1f}s: {}", + attempt, + self.max_attempts, + delay, + repr(exc), + ) + await asyncio.sleep(delay) + raise RuntimeError("Exhausted retries") # pragma: no cover + + @asynccontextmanager + async def request_stream( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + run_context: RunContext[Any] | None = None, + ) -> AsyncIterator[StreamedResponse]: + yielded = False + try: + async with self.wrapped.request_stream( + messages, model_settings, model_request_parameters, run_context + ) as stream: + yielded = True + # Wrap the stream so that errors *during* chunk iteration + # (e.g. groq.APIError with malformed failed_generation) + # are caught and converted to recovery events rather than + # crashing the entire agent run. + yield _ErrorRecoveringStream(stream) + except Exception as exc: + if yielded: + # Error during stream consumption that + # _ErrorRecoveringStream couldn't handle. + raise + + # Setup error — try to recover tool_use_failed. + recovered = _try_recover_tool_use_failed(exc) + if recovered is not None: + logger.info( + "[assistant] Recovered tool_use_failed error " + "in stream into ModelResponse" + ) + yield _PreFetchedResponse(recovered, model_request_parameters) + return + + if not _is_transient_provider_error(exc): + raise + # Stream failed with a retryable error. Fall back to a + # non-streaming request which has its own retry loop. + logger.warning( + "[assistant] Stream failed with retryable error, " + "falling back to non-streaming request: {}", + repr(exc), + ) + response = await self.request( + messages, model_settings, model_request_parameters + ) + yield _PreFetchedResponse(response, model_request_parameters) + + +class _ErrorRecoveringStream(StreamedResponse): + """Transparent proxy around a ``StreamedResponse`` that catches provider + errors during chunk iteration and converts ``tool_use_failed`` errors + (even with malformed JSON) into recovery events. + + pydantic-ai's ``GroqStreamedResponse`` already handles ``tool_use_failed`` + when the ``failed_generation`` JSON is *valid*, but fails when it is + truly malformed because ``Json[_GroqToolUseFailedGeneration]`` raises + ``ValidationError``. This wrapper catches the re-raised ``APIError`` + and emits a ``ToolCallPart`` or ``TextPart`` so pydantic-ai's + validation loop can tell the model what was wrong. + """ + + # Dataclass fields on StreamedResponse that have class-level defaults + # (e.g. ``final_result_event = None``). These shadow ``__getattr__`` + # because Python finds the class attribute before calling __getattr__. + # We override them as properties so reads delegate to ``_inner``. + final_result_event = property(lambda self: self._inner.final_result_event) # type: ignore[assignment] + provider_response_id = property(lambda self: self._inner.provider_response_id) # type: ignore[assignment] + provider_details = property(lambda self: self._inner.provider_details) # type: ignore[assignment] + finish_reason = property(lambda self: self._inner.finish_reason) # type: ignore[assignment] + + def __init__(self, inner: StreamedResponse): + # Don't call super().__init__() — delegate everything to *inner*. + # Only store our own _inner and _event_iterator on the instance. + object.__setattr__(self, "_inner", inner) + object.__setattr__(self, "_event_iterator", None) + + def __getattr__(self, name: str) -> Any: + return getattr(self._inner, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name in ("_inner", "_event_iterator"): + object.__setattr__(self, name, value) + else: + setattr(self._inner, name, value) + + async def _get_event_iterator( + self, + ) -> AsyncIterator[ModelResponseStreamEvent]: + try: + async for event in self._inner._get_event_iterator(): + yield event + except Exception as exc: + recovered = _try_recover_tool_use_failed(exc) + if recovered is None: + raise + logger.info( + "[assistant] Recovered tool_use_failed error during stream consumption" + ) + for i, part in enumerate(recovered.parts): + yield self._parts_manager.handle_part( + vendor_part_id=f"recovered-{i}", part=part + ) + + # Abstract properties — delegate to inner stream. + + @property + def model_name(self) -> str: + return self._inner.model_name + + @property + def provider_name(self) -> str | None: + return self._inner.provider_name + + @property + def provider_url(self) -> str | None: + return self._inner.provider_url + + @property + def timestamp(self) -> datetime: + return self._inner.timestamp + + +class _PreFetchedResponse(StreamedResponse): + """A ``StreamedResponse`` backed by an already-complete ``ModelResponse``. + + Used when ``request_stream`` falls back to ``request()`` after a + retryable streaming error. Emits all response parts as immediate + ``PartStartEvent`` s so pydantic-ai can process them normally. + """ + + def __init__( + self, + response: ModelResponse, + model_request_parameters: ModelRequestParameters, + ): + super().__init__(model_request_parameters=model_request_parameters) + self._response = response + self._usage.input_tokens = response.usage.input_tokens + self._usage.output_tokens = response.usage.output_tokens + + async def _get_event_iterator( + self, + ) -> AsyncIterator[ModelResponseStreamEvent]: + for i, part in enumerate(self._response.parts): + yield self._parts_manager.handle_part(vendor_part_id=i, part=part) + + @property + def model_name(self) -> str: + return self._response.model_name or "" + + @property + def provider_name(self) -> str | None: + return self._response.provider_name + + @property + def provider_url(self) -> str | None: + return self._response.provider_url + + @property + def timestamp(self) -> datetime: + return self._response.timestamp or datetime.now(tz=timezone.utc) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/signatures.py b/enterprise/backend/src/baserow_enterprise/assistant/signatures.py deleted file mode 100644 index 60bd981266..0000000000 --- a/enterprise/backend/src/baserow_enterprise/assistant/signatures.py +++ /dev/null @@ -1,33 +0,0 @@ -import udspy - -from .prompts import AGENT_SYSTEM_PROMPT - - -class ChatSignature(udspy.Signature): - __doc__ = AGENT_SYSTEM_PROMPT - - question: str = udspy.InputField() - conversation_history: list[str] = udspy.InputField( - desc="Previous messages formatted as '[index] (role): content', ordered chronologically" - ) - ui_context: str | None = udspy.InputField( - default=None, - description=( - "The JSON serialized context the user is currently in. " - "It contains information about the user, the timezone, the workspace, etc." - "Whenever make sense, use it to ground your answer." - ), - ) - answer: str = udspy.OutputField() - - @classmethod - def format_conversation_history(cls, history: udspy.History) -> list[str]: - """ - Format the conversation history into a list of strings for the signature. - """ - - formatted_history = [] - for i, msg in enumerate(history.messages): - formatted_history.append(f"[{i}] ({msg['role']}): {msg['content']}") - - return formatted_history diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tasks.py b/enterprise/backend/src/baserow_enterprise/assistant/tasks.py index ef72768cfc..babb9f6b30 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tasks.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tasks.py @@ -3,7 +3,7 @@ from baserow.config.celery import app from .handler import AssistantHandler -from .tools import KnowledgeBaseHandler +from .tools.search_user_docs.handler import KnowledgeBaseHandler @app.task(bind=True) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/telemetry.py b/enterprise/backend/src/baserow_enterprise/assistant/telemetry.py index 9b5259092d..f1242c39a9 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/telemetry.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/telemetry.py @@ -1,21 +1,50 @@ """ Posthog telemetry integration for the Baserow Assistant. -This module provides tracing callbacks that capture DSPy execution flows -and send structured events to Posthog for LLM analytics. +Hooks into pydantic-ai's OpenTelemetry instrumentation to capture LLM +generation and tool call events, mapping them to PostHog's AI analytics +event schema (``$ai_trace``, ``$ai_generation``, ``$ai_span``). + +Architecture: + + PosthogTracingCallback -- per-request context manager that emits the + top-level ``$ai_trace`` event and publishes + trace metadata via a ``ContextVar`` for the + span exporter. + + PosthogSpanProcessor -- OpenTelemetry ``SpanProcessor`` that maps + pydantic-ai spans to PostHog events: + ``chat ...`` -> ``$ai_generation`` + ``running tool`` -> ``$ai_span`` + ``agent run`` -> ``$ai_span`` + The ``running tools`` grouping span is + transparently skipped; child tool spans have + their parent remapped to the grandparent + (typically the ``agent run`` span). + + setup_instrumentation() -- one-time wiring of the span processor into a + ``TracerProvider`` + ``Agent.instrument_all()``. """ +from __future__ import annotations + +import json from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass from datetime import datetime, timezone -from typing import Any from uuid import uuid4 -import udspy -from udspy.callback import BaseCallback +from opentelemetry.sdk.trace import ReadableSpan, SpanProcessor, TracerProvider +from opentelemetry.trace import SpanKind from baserow.core.posthog import get_posthog_client from baserow_enterprise.assistant.models import AssistantChat +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + def _utc_now() -> datetime: return datetime.now(tz=timezone.utc) @@ -25,23 +54,436 @@ def _uuid() -> str: return str(uuid4()) -class PosthogTracingCallback(BaseCallback): +def _posthog_capture(distinct_id: str, event: str, properties: dict, **kwargs): + """Send a single event to PostHog with standardised error handling.""" + + posthog_client = get_posthog_client() + try: + posthog_client.capture( + distinct_id=distinct_id, event=event, properties=properties, **kwargs + ) + except Exception: + pass + + +# --------------------------------------------------------------------------- +# Trace context (ContextVars shared between callback and span exporter) +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class _TraceContext: + """Immutable snapshot of per-request trace metadata.""" + + trace_id: str + user_id: str + workspace_id: str + chat_uuid: str + + +_trace_ctx: ContextVar[_TraceContext | None] = ContextVar("_trace_ctx", default=None) + +# Tool names collected during a trace for the $ai_trace summary. +_tool_calls: ContextVar[list[str]] = ContextVar("_tool_calls") + + +# --------------------------------------------------------------------------- +# Message format conversion (pydantic-ai -> PostHog) +# --------------------------------------------------------------------------- + + +def _parse_arguments(value): + """Ensure tool call arguments are a dict, parsing JSON strings if needed.""" + if isinstance(value, str): + try: + return json.loads(value) + except (json.JSONDecodeError, TypeError): + return value + return value + + +# pydantic-ai key names -> PostHog key names +_PART_TRANSFORMS = { + "text": lambda p: { + "type": "text", + "text": p.get("content", ""), + }, + "tool_call": lambda p: { + "type": "tool_call", + "tool_call_id": p.get("id", ""), + "name": p.get("name", ""), + "arguments": _parse_arguments(p.get("arguments", {})), + }, + "tool_return": lambda p: { + "type": "tool_result", + "tool_call_id": p.get("tool_call_id", ""), + "content": p.get("content", ""), + }, + "thinking": lambda p: { + "type": "thinking", + "thinking": p.get("content", ""), + }, +} + + +def _safe_json_attr(attrs: dict, key: str) -> list | dict | None: + """Extract a JSON-serialised span attribute, returning None if missing or + unparseable.""" + + val = attrs.get(key) + if val is None: + return None + if isinstance(val, str): + try: + return json.loads(val) + except (json.JSONDecodeError, TypeError): + return None + return val + + +def _pydantic_messages_to_posthog(messages: list[dict]) -> list[dict]: + """Convert pydantic-ai message dicts to PostHog's expected format. + + pydantic-ai: ``{"role": ..., "parts": [{"type": "text", "content": ...}]}`` + PostHog: ``{"role": ..., "content": [{"type": "text", "text": ...}]}`` """ - Captures uDSPy execution traces and sends events to Posthog. - This callback tracks: - - uDSPy module execution (ChainOfThought, ReAct, Predict) - - LLM API calls (OpenAI, Groq, etc.) - - Tool invocations - - Performance metrics and token usage + result = [] + for msg in messages: + content_parts = [] + for part in msg.get("parts", []): + ptype = part.get("type", "text") + transform = _PART_TRANSFORMS.get(ptype) + content_parts.append(transform(part) if transform else part) + result.append({"role": msg.get("role", "unknown"), "content": content_parts}) + return result + + +# --------------------------------------------------------------------------- +# Span helpers (shared by _emit_generation and _emit_tool_span) +# --------------------------------------------------------------------------- + + +def _span_latency(span: ReadableSpan) -> float | None: + """Compute span duration in seconds from OTel nanosecond timestamps.""" + + if span.start_time and span.end_time: + return (span.end_time - span.start_time) / 1e9 + return None + + +def _base_properties(ctx: _TraceContext) -> dict: + """Properties common to every PostHog event within a trace.""" + + return { + "$ai_trace_id": ctx.trace_id, + "$ai_session_id": ctx.chat_uuid, + "workspace_id": ctx.workspace_id, + } + + +def _extract_reasoning(output_messages: list[dict]) -> str | None: + """Join all ``thinking`` parts and tool-call ``thought`` fields from output + messages into a single string.""" + + parts: list[str] = [] + for msg in output_messages: + for part in msg.get("parts", []): + ptype = part.get("type") + if ptype == "thinking": + if content := part.get("content"): + parts.append(content) + elif ptype == "tool_call": + args = _parse_arguments(part.get("arguments", {})) + if isinstance(args, dict) and (thought := args.get("thought")): + parts.append(thought) + return "\n".join(parts) if parts else None + - Each instance is created per Assistant call with trace context, so - multiple concurrent traces can be captured independently. +# --------------------------------------------------------------------------- +# PosthogSpanExporter +# --------------------------------------------------------------------------- + +# Model setting keys emitted by pydantic-ai as ``gen_ai.request.*`` attrs. +_MODEL_PARAM_KEYS = ( + "temperature", + "max_tokens", + "top_p", + "seed", + "presence_penalty", + "frequency_penalty", +) + + +class PosthogSpanProcessor(SpanProcessor): + """Maps pydantic-ai OTel spans to PostHog LLM analytics events. + + ``chat {model}`` -> ``$ai_generation`` + ``running tool`` -> ``$ai_span`` (parent remapped past ``running tools``) + ``agent run`` -> ``$ai_span`` + ``running tools`` -> skipped (children re-parented to grandparent) """ def __init__(self): - super().__init__() + # "running tools" span_id -> its parent span_id. + # Populated on_start so child tool spans (which end first) can + # look up the grandparent during on_end. + self._tools_group_parents: dict[int, int | None] = {} + + # -- SpanProcessor interface ------------------------------------------- + + def on_start(self, span, parent_context=None): + if span.name == "running tools": + parent_id = span.parent.span_id if span.parent else None + self._tools_group_parents[span.context.span_id] = parent_id + + def on_end(self, span: ReadableSpan): + ctx = _trace_ctx.get() + if ctx is None: + return + + try: + self._process_span(span, ctx) + except Exception: + pass + + # Clean up mapping once the grouping span itself ends. + if span.name == "running tools": + self._tools_group_parents.pop(span.context.span_id, None) + + def shutdown(self): + pass + + def force_flush(self, timeout_millis: int = 0) -> bool: + return True + + # -- internal ---------------------------------------------------------- + + def _resolve_parent_id(self, span: ReadableSpan) -> str | None: + """Return the hex ``$ai_parent_id``, skipping ``running tools``.""" + + if not span.parent: + return None + parent_id = span.parent.span_id + # If the direct parent is a "running tools" span, jump to its parent. + grandparent = self._tools_group_parents.get(parent_id) + if grandparent is not None: + parent_id = grandparent + return f"{parent_id:016x}" + + def _span_id_props(self, span: ReadableSpan) -> dict: + props: dict = {"$ai_span_id": f"{span.context.span_id:016x}"} + parent_hex = self._resolve_parent_id(span) + if parent_hex: + props["$ai_parent_id"] = parent_hex + return props + + def _process_span(self, span: ReadableSpan, ctx: _TraceContext): + attrs = dict(span.attributes or {}) + + if span.kind == SpanKind.CLIENT and span.name.startswith("chat "): + self._emit_generation(span, attrs, ctx) + elif span.name == "running tool": + self._emit_tool_span(span, attrs, ctx) + elif span.name == "agent run": + self._emit_agent_span(span, attrs, ctx) + # "running tools" is intentionally not emitted. + + def _emit_generation(self, span: ReadableSpan, attrs: dict, ctx: _TraceContext): + """Map a ``chat {model}`` span to ``$ai_generation``.""" + + input_messages = _safe_json_attr(attrs, "gen_ai.input.messages") + output_messages = _safe_json_attr(attrs, "gen_ai.output.messages") + + properties = { + **_base_properties(ctx), + "$ai_model": ( + attrs.get("gen_ai.response.model") or attrs.get("gen_ai.request.model") + ), + "$ai_provider": ( + attrs.get("gen_ai.provider.name") or attrs.get("gen_ai.system") + ), + "$ai_input_tokens": attrs.get("gen_ai.usage.input_tokens"), + "$ai_output_tokens": attrs.get("gen_ai.usage.output_tokens"), + } + + # Model parameters + model_params = { + key: val + for key in _MODEL_PARAM_KEYS + if (val := attrs.get(f"gen_ai.request.{key}")) is not None + } + if model_params: + properties["$ai_model_parameters"] = model_params + + # System prompt + system_instructions = _safe_json_attr(attrs, "gen_ai.system_instructions") + if system_instructions and isinstance(system_instructions, list): + system_text = "\n".join( + p.get("content", "") for p in system_instructions if isinstance(p, dict) + ) + if system_text: + properties["$ai_system_prompt"] = system_text + + # Input / output messages + if input_messages: + properties["$ai_input"] = _pydantic_messages_to_posthog(input_messages) + if output_messages: + properties["$ai_output_choices"] = _pydantic_messages_to_posthog( + output_messages + ) + + latency = _span_latency(span) + if latency is not None: + properties["$ai_latency"] = latency + + # Tool definitions and names + tool_definitions = _safe_json_attr(attrs, "gen_ai.tool.definitions") + if tool_definitions and isinstance(tool_definitions, list): + tool_names = [ + t.get("name", "?") for t in tool_definitions if isinstance(t, dict) + ] + if tool_names: + properties["$ai_tools"] = tool_names + properties["$ai_tool_definitions"] = tool_definitions + + # Reasoning / thinking + if output_messages and isinstance(output_messages, list): + reasoning = _extract_reasoning(output_messages) + if reasoning: + properties["$ai_reasoning"] = reasoning + + properties.update(self._span_id_props(span)) + _posthog_capture(ctx.user_id, "$ai_generation", properties) + + def _emit_agent_span(self, span: ReadableSpan, attrs: dict, ctx: _TraceContext): + """Map an ``agent run`` span to ``$ai_span`` with the agent name.""" + + agent_name = attrs.get("agent_name", "unknown_agent") + + properties = { + **_base_properties(ctx), + "$ai_span_name": f"Agent: {agent_name}", + } + + # System prompt + system_instructions = _safe_json_attr(attrs, "gen_ai.system_instructions") + if system_instructions and isinstance(system_instructions, list): + system_text = "\n".join( + p.get("content", "") for p in system_instructions if isinstance(p, dict) + ) + if system_text: + properties["$ai_input_state"] = {"system_prompt": system_text} + + # User input (first user message) and final output + all_messages = _safe_json_attr(attrs, "pydantic_ai.all_messages") + if all_messages and isinstance(all_messages, list): + for msg in all_messages: + if msg.get("role") == "user": + parts = msg.get("parts", []) + user_texts = [ + p.get("content", "") + for p in parts + if isinstance(p, dict) and p.get("type") == "text" + ] + if user_texts: + input_state = properties.get("$ai_input_state", {}) + input_state["user_prompt"] = "\n".join(user_texts) + properties["$ai_input_state"] = input_state + break + + final_result = attrs.get("final_result") + if final_result is not None: + properties["$ai_output_state"] = _parse_arguments(final_result) + + latency = _span_latency(span) + if latency is not None: + properties["$ai_latency"] = latency + + properties.update(self._span_id_props(span)) + _posthog_capture(ctx.user_id, "$ai_span", properties) + + def _emit_tool_span(self, span: ReadableSpan, attrs: dict, ctx: _TraceContext): + """Map a ``running tool`` span to ``$ai_span``.""" + + tool_name = attrs.get("gen_ai.tool.name", "unknown_tool") + + # Record for the trace summary. + try: + _tool_calls.get().append(tool_name) + except LookupError: + pass + + tool_args = _safe_json_attr(attrs, "tool_arguments") + + properties = { + **_base_properties(ctx), + "$ai_span_name": f"Tool: {tool_name}", + "$ai_input_state": tool_args or {}, + "$ai_output_state": _parse_arguments(attrs.get("tool_response")), + } + + # Chain-of-thought reasoning from the "thought" argument + if isinstance(tool_args, dict) and tool_args.get("thought"): + properties["$ai_reasoning"] = tool_args["thought"] + + latency = _span_latency(span) + if latency is not None: + properties["$ai_latency"] = latency + + properties.update(self._span_id_props(span)) + _posthog_capture(ctx.user_id, "$ai_span", properties) + + +# --------------------------------------------------------------------------- +# One-time instrumentation setup +# --------------------------------------------------------------------------- + +_instrumentation_ready = False + + +def setup_instrumentation(): + """Activate pydantic-ai's OTel instrumentation with PostHog export. + Safe to call multiple times (subsequent calls are no-ops). + Does nothing when PostHog is disabled. + """ + + global _instrumentation_ready + if _instrumentation_ready: + return + + from django.conf import settings as django_settings + + posthog_enabled = getattr(django_settings, "POSTHOG_ENABLED", False) + if not posthog_enabled: + return + + from pydantic_ai import Agent, InstrumentationSettings + + tracer_provider = TracerProvider() + tracer_provider.add_span_processor(PosthogSpanProcessor()) + + Agent.instrument_all( + InstrumentationSettings( + tracer_provider=tracer_provider, + include_content=True, + ) + ) + + _instrumentation_ready = True + + +# --------------------------------------------------------------------------- +# PosthogTracingCallback — per-request trace lifecycle +# --------------------------------------------------------------------------- + + +class PosthogTracingCallback: + """Per-request trace lifecycle. Creates the ``$ai_trace`` event and + publishes ``_TraceContext`` for the span exporter.""" + + def __init__(self): self.chat: AssistantChat | None = None self.human_msg: str | None = None self.trace_id: str | None = None @@ -49,50 +491,36 @@ def __init__(self): self.user_id: str | None = None self.workspace_id: str | None = None self.chat_uuid: str | None = None - self.spans: dict[str, dict] = {} - self.span_ids: list[str] = [] + self.trace_outputs = None @contextmanager def trace(self, chat: AssistantChat, human_message: str): - """ - Context manager for tracing an assistant execution. - Initializes trace context and captures the overall trace event. - It also patches the OpenAI client to auto-capture generation events. + """Context manager that scopes a single assistant execution. - :param chat: The AssistantChat instance - :param human_message: The initial user message + Publishes ``_trace_ctx`` so ``PosthogSpanExporter`` can attach trace + metadata to child ``$ai_generation`` / ``$ai_span`` events. """ - from posthog.ai.openai import AsyncOpenAI - self.chat = chat self.human_msg = human_message - self.trace_id = _uuid() self.span_id = _uuid() self.user_id = str(chat.user_id) self.workspace_id = str(chat.workspace_id) self.chat_uuid = str(chat.uuid) + self.trace_outputs = None start_time = _utc_now() - self.spans = {} - self.span_ids = [self.span_id] - self.trace_outputs = None - # patch the OpenAI client to automatically send the generation event - lm = udspy.settings._context_lm.get() - openai_client = lm.client - - # Check if client is already a PostHog-wrapped client by checking its - # module. We avoid isinstance() here because it can fail when the class - # is mocked in tests. - is_posthog_client = "posthog" in type(openai_client).__module__ - if not is_posthog_client: - lm.client = AsyncOpenAI( - api_key=openai_client.api_key, - base_url=openai_client.base_url, - posthog_client=get_posthog_client(), + token = _trace_ctx.set( + _TraceContext( + trace_id=self.trace_id, + user_id=self.user_id, + workspace_id=self.workspace_id, + chat_uuid=self.chat_uuid, ) + ) + tools_token = _tool_calls.set([]) exception = None try: @@ -101,7 +529,17 @@ def trace(self, chat: AssistantChat, human_message: str): exception = exc raise finally: - # Stop trace + tool_call_names = _tool_calls.get([]) + _trace_ctx.reset(token) + _tool_calls.reset(tools_token) + + output_state = self.trace_outputs if exception is None else str(exception) + if tool_call_names: + if output_state is None: + output_state = {} + if isinstance(output_state, dict): + output_state["tool_calls"] = tool_call_names + self._capture_event( "$ai_trace", timestamp=start_time, @@ -112,190 +550,27 @@ def trace(self, chat: AssistantChat, human_message: str): "$ai_latency": (_utc_now() - start_time).total_seconds(), "$ai_is_error": exception is not None, "$ai_input_state": {"user_message": human_message}, - "$ai_output_state": self.trace_outputs - if exception is None - else str(exception), + "$ai_output_state": output_state, }, ) - def _capture_event(self, event: str, **kwargs): - """ - Capture a Posthog event if Posthog is enabled. + try: + get_posthog_client().flush() + except Exception: + pass - :param event: Event name (e.g., "$ai_generation") - :param properties: Event properties dictionary - """ + def set_trace_output(self, output: str): + """Record the agent's final answer for the ``$ai_trace`` event.""" - default_props = { - "$ai_trace_id": self.trace_id, - "$ai_session_id": self.chat_uuid, - "workspace_id": self.workspace_id, - } - if "properties" in kwargs: - kwargs["properties"].update(default_props) - else: - kwargs["properties"] = default_props - - posthog_client = get_posthog_client() - posthog_client.capture( - distinct_id=str(self.user_id), - event=event, - **kwargs, - ) # noqa: W505 - - def on_module_start(self, call_id: str, instance: Any, inputs: dict): - """ - Track the start of a DSPy module execution. + self.trace_outputs = {"answer": output} - Captures ChainOfThought, ReAct, Predict, and other module types. - - :param call_id: Unique identifier for this call - :param instance: The DSPy module instance - :param inputs: Input dictionary passed to the module - """ - - module_type = instance.__class__.__name__ - parent_span_id = self.span_ids[-1] if self.span_ids else None - span_id = call_id - self.span_ids.append(span_id) - span = { - "start_time": _utc_now(), - "properties": { - "$ai_span_name": module_type, - "$ai_span_id": span_id, - "$ai_parent_span_id": parent_span_id, - }, - } - self.spans[span_id] = span - - def _update_span_with_signature_data(signature): - adapter = udspy.ChatAdapter() - input_fields = ", ".join(signature.get_input_fields().keys()) - output_fields = ", ".join(signature.get_output_fields()) - span["properties"]["$ai_input_state"] = { - "signature": f"{input_fields} -> {output_fields}", - "instructions": adapter.format_instructions(signature), - **inputs["kwargs"], - } - - if isinstance(instance, (udspy.Predict, udspy.ReAct)): - _update_span_with_signature_data(instance.signature) - elif isinstance(instance, udspy.ChainOfThought): - _update_span_with_signature_data(instance.original_signature) - - def on_module_end(self, call_id: str, outputs: Any, exception: Exception | None): - """ - Remove the span from the stack together with all the started $ai_generation - spans appended in `on_lm_start` - - Args: - call_id: Unique identifier for this call - outputs: Module output (if successful) - exception: Exception raised (if failed) - """ - - while (span_id := self.span_ids.pop()) != call_id: - continue - - span = self.spans.pop(span_id) - start_time = span.pop("start_time") - span["properties"].update( - { - "$ai_latency": (_utc_now() - start_time).total_seconds(), - "$ai_is_error": exception is not None, - "$ai_output_state": outputs if exception is None else str(exception), - } - ) - - if isinstance(outputs, dict) and "answer" in outputs: - self.trace_outputs = { - k: v - for k, v in outputs.items() - if k not in ["module", "native_tool_calls"] - } - - self._capture_event("$ai_span", timestamp=start_time, **span) - - def on_lm_start(self, call_id: str, instance: Any, inputs: dict): - """ - Only enrich posthog properties that will be sent automatically - by the patched openai client. - Add the span_id to the stack so any tool call will be shown - as a child span. - - Args: - call_id: Unique identifier for this call - instance: The LM instance - inputs: API call parameters (model, messages, temperature, etc.) - """ + def _capture_event(self, event: str, **kwargs): + """Capture a PostHog event, merging in default trace properties.""" - parent_span_id = self.span_ids[-1] if self.span_ids else None - kwargs = inputs["kwargs"] - span_id = call_id - self.span_ids.append(span_id) - kwargs["posthog_distinct_id"] = self.user_id - kwargs["posthog_trace_id"] = self.trace_id - kwargs["posthog_properties"] = { + kwargs["properties"] = { + **kwargs.get("properties", {}), + "$ai_trace_id": self.trace_id, "$ai_session_id": self.chat_uuid, - "$ai_parent_span_id": parent_span_id, - "$ai_span_id": span_id, "workspace_id": self.workspace_id, - "$ai_provider": instance.provider, - } - - def on_lm_end(self, call_id: str, outputs: Any, exception: Exception | None): - """ - Automatically tracked by the patched openai client. - - :param call_id: Unique identifier for this call - :param outputs: LLM response object - :param exception: Exception raised (if failed) - """ - - pass - - def on_tool_start(self, call_id: str, instance: Any, inputs: dict): - """ - Track the start of a tool invocation. - - Args: - call_id: Unique identifier for this call - instance: The tool instance - inputs: Tool input parameters - """ - - tool_name = getattr(instance, "name", instance.__class__.__name__) - - span_id = call_id - parent_span_id = self.span_ids[-1] if self.span_ids else None - self.spans[span_id] = { - "start_time": _utc_now(), - "properties": { - "$ai_span_name": f"Tool: {tool_name}", - "$ai_span_id": span_id, - "$ai_parent_span_id": parent_span_id, - "$ai_input_state": inputs, - }, } - - def on_tool_end(self, call_id: str, outputs: Any, exception: Exception | None): - """ - Track the completion of a tool invocation. - - Args: - call_id: Unique identifier for this call - outputs: Tool output - exception: Exception raised (if failed) - """ - - span_id = call_id - span = self.spans.pop(span_id) - start_time = span.pop("start_time") - span["properties"].update( - { - "$ai_latency": (_utc_now() - start_time).total_seconds(), - "$ai_is_error": exception is not None, - "$ai_output_state": outputs if exception is None else str(exception), - } - ) - self._capture_event("$ai_span", timestamp=start_time, **span) + _posthog_capture(str(self.user_id), event, **kwargs) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/__init__.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/__init__.py index 4ac95ed235..8b13789179 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/__init__.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/__init__.py @@ -1,5 +1 @@ -from .automation.tools import * # noqa: F401, F403 -from .core.tools import * # noqa: F401, F403 -from .database.tools import * # noqa: F401, F403 -from .navigation.tools import * # noqa: F401, F403 -from .search_user_docs.tools import * # noqa: F401, F403 + diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/__init__.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/__init__.py index ace1c221c3..8b13789179 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/__init__.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/__init__.py @@ -1,6 +1 @@ -from .tools import ListWorkflowsToolType, WorkflowToolFactoryToolType -__all__ = [ - "ListWorkflowsToolType", - "WorkflowToolFactoryToolType", -] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/agents.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/agents.py new file mode 100644 index 0000000000..8f56f79152 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/agents.py @@ -0,0 +1,156 @@ +""" +Sub-agents for the automation assistant tools. + +Contains: +- ``AssistantFormulaContext``: Automation-specific formula context. +- ``get_generate_formulas_tool()``: Gets the automation formula generator. +- ``update_workflow_formulas()``: Generates formulas for workflow nodes. +""" + +from typing import TYPE_CHECKING, Any + +from django.db import transaction +from django.utils.translation import gettext as _ + +from loguru import logger + +from baserow.contrib.automation.nodes.models import AutomationNode +from baserow_enterprise.assistant.tools.shared.agents import get_formula_generator +from baserow_enterprise.assistant.tools.shared.formula_utils import ( + BaseFormulaContext, + create_example_from_json_schema, + minimize_json_schema, +) + +from .prompts import GENERATE_FORMULA_PROMPT +from .types import ActionNodeCreate, NodeUpdate, WorkflowCreate + +if TYPE_CHECKING: + from baserow_enterprise.assistant.deps import ToolHelpers + + +class AssistantFormulaContext(BaseFormulaContext): + """ + Automation-specific formula context. + + Wraps node data in the ``{"previous_node": {...}}`` structure expected + by automation formula ``get()`` paths. + """ + + def add_node_context( + self, + node_id: int | str, + node_context: dict[str, Any], + context_metadata: dict[str, dict[str, str]] | None = None, + ): + """Add a node's output values to the formula context.""" + self.add_context(str(node_id), node_context, context_metadata) + + def get_formula_context(self) -> dict[str, Any]: + """Return context wrapped in ``previous_node`` for automation formulas.""" + return {"previous_node": self.context} + + def __getitem__(self, key) -> Any: + """Resolve paths like ``previous_node.1.0.field_name``.""" + return self._resolve_path(key, "previous_node") + + +def get_generate_formulas_tool(): + """Get the automation formula generator using the shared factory.""" + return get_formula_generator(GENERATE_FORMULA_PROMPT) + + +def update_workflow_formulas( + workflow: "WorkflowCreate", + node_mapping: dict[int | str, Any], + tool_helpers: "ToolHelpers", +) -> None: + """ + Generate and apply formulas for all nodes in a newly created workflow. + + Walks nodes in order, building up the available formula context as it goes. + For each node that has ``$formula:`` values, delegates to the formula + generation agent and writes the results back to the ORM service. + """ + + context = AssistantFormulaContext() + generate_formula = get_generate_formulas_tool() + + def _build_node_context(orm_node: AutomationNode, node_create): + """Extract schema/example from a node and add it to the formula context.""" + + schema = orm_node.service.get_type().generate_schema(orm_node.service.specific) + example = create_example_from_json_schema(schema) + metadata = minimize_json_schema(schema) + metadata["node_id"] = orm_node.id + metadata["node_ref"] = node_create.ref + if getattr(node_create, "previous_node_ref", None): + metadata["previous_node_ref"] = node_create.previous_node_ref + context.add_node_context(orm_node.id, example, metadata) + + def _generate_node_formulas(node: ActionNodeCreate, orm_node: AutomationNode): + """Generate formulas for a single node and write them to the service.""" + + formulas_to_create = node.get_formulas_to_create(orm_node) + if formulas_to_create is None: + return + result = generate_formula(formulas_to_create, context) + if result: + node.update_service_with_formulas(orm_node.service, result) + + # Seed context with the trigger + orm_trigger, trigger_create = node_mapping[workflow.trigger.ref] + _build_node_context(orm_trigger, trigger_create) + + # Process action nodes in order + for node in workflow.nodes: + orm_node, _node_create = node_mapping[node.ref] + node.apply_direct_values(orm_node.service) + + if node.get_formulas_to_create(orm_node) is not None: + tool_helpers.update_status( + _("Generating formulas for node '%(label)s'..." % {"label": node.label}) + ) + with transaction.atomic(): + try: + _generate_node_formulas(node, orm_node) + except Exception as exc: + logger.exception( + "Failed to generate formulas for node {}: {}", orm_node.id, exc + ) + + _build_node_context(orm_node, node) + + +def update_single_node_formulas( + node_update: "NodeUpdate", + orm_node: AutomationNode, + tool_helpers: "ToolHelpers", +) -> None: + """ + Generate and apply formulas for a single node being updated. + + Builds formula context from the node's workflow, then generates + formulas for the $formula: fields in the update. + """ + + context = AssistantFormulaContext() + generate_formula = get_generate_formulas_tool() + + # Build context from the workflow's existing nodes + workflow = orm_node.workflow + all_nodes = list(workflow.automation_workflow_nodes.all().order_by("id")) + for wf_node in all_nodes: + schema = wf_node.service.get_type().generate_schema(wf_node.service.specific) + example = create_example_from_json_schema(schema) + metadata = minimize_json_schema(schema) + metadata["node_id"] = wf_node.id + context.add_node_context(wf_node.id, example, metadata) + + formulas_to_create = node_update.get_formulas_to_update(orm_node) + if formulas_to_create is None: + return + + result = generate_formula(formulas_to_create, context) + if result: + node_update.update_service_with_formulas(orm_node.service, result) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/helpers.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/helpers.py new file mode 100644 index 0000000000..c7bfe5fe10 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/helpers.py @@ -0,0 +1,288 @@ +""" +Shared helpers for the automation assistant tools. + +Contains permission-checked accessors and the workflow creation orchestrator +used by ``tools.py`` and ``agents.py``. +""" + +from typing import TYPE_CHECKING, Any + +from django.contrib.auth.models import AbstractUser +from django.utils.translation import gettext as _ + +from baserow.contrib.automation.models import Automation +from baserow.contrib.automation.nodes.registries import automation_node_type_registry +from baserow.contrib.automation.nodes.service import AutomationNodeService +from baserow.contrib.automation.workflows.models import AutomationWorkflow +from baserow.contrib.automation.workflows.service import AutomationWorkflowService +from baserow.core.models import Workspace +from baserow.core.service import CoreService + +from .types import NodeUpdate, WorkflowCreate + +if TYPE_CHECKING: + from baserow_enterprise.assistant.deps import ToolHelpers + + from .types import ActionNodeCreate + + +def get_automation( + automation_id: int, user: AbstractUser, workspace: Workspace +) -> Automation: + """Fetch an automation scoped to the user's workspace.""" + + base_queryset = Automation.objects.filter(workspace=workspace) + return CoreService().get_application( + user, automation_id, base_queryset=base_queryset + ) + + +def get_workflow( + workflow_id: int, user: AbstractUser, workspace: Workspace +) -> AutomationWorkflow: + """Fetch a workflow with a workspace-level permission check.""" + + workflow = AutomationWorkflowService().get_workflow(user, workflow_id) + if workflow.automation.workspace_id != workspace.id: + raise ValueError("Workflow not in workspace") + return workflow + + +def get_nodes_in_order(user: AbstractUser, workflow: AutomationWorkflow) -> list[dict]: + """ + Return the nodes of a workflow in graph traversal order. + + Walks the workflow graph starting from the trigger, following ``next`` + edges (all outputs) and ``children`` to produce a flat, ordered list. + """ + + nodes = AutomationNodeService().get_nodes(user, workflow) + node_map = {n.id: n for n in nodes} + graph = workflow.get_graph().graph + + trigger_id = graph.get("0") + if trigger_id is None: + return [] + + ordered_ids: list[int] = [] + visited: set[int] = set() + + def walk(node_id: int): + if node_id in visited or node_id not in node_map: + return + visited.add(node_id) + ordered_ids.append(node_id) + info = graph.get(str(node_id), {}) + # Follow children first (for container nodes like iterators) + for child_id in info.get("children", []): + walk(child_id) + # Then follow next edges in order + for output_uid, next_ids in info.get("next", {}).items(): + for nid in next_ids: + walk(nid) + + walk(trigger_id) + + result = [] + for nid in ordered_ids: + node = node_map[nid] + node_type = node.get_type() + entry = { + "id": node.id, + "label": node.get_label(), + "type": node_type.type, + } + result.append(entry) + + return result + + +def add_nodes_to_workflow( + user: AbstractUser, + workflow: AutomationWorkflow, + nodes: list["ActionNodeCreate"], + tool_helpers: "ToolHelpers", +) -> tuple[list[Any], dict[int | str, Any]]: + """ + Add action nodes to an existing workflow. + + The ``previous_node_ref`` on each node can reference: + - An existing node ID as a string (e.g. "49") + - A temp ref from an earlier node in the same ``nodes`` list + + Returns a list of created ORM nodes and the node mapping. + """ + + # Seed the mapping with existing nodes in the workflow + existing_nodes = AutomationNodeService().get_nodes(user, workflow) + node_mapping: dict[int | str, Any] = {} + for n in existing_nodes: + # Create a stub for the node_create part that has type and edges info + stub = _ExistingNodeStub(n) + node_mapping[str(n.id)] = (n, stub) + node_mapping[n.id] = (n, stub) + + created = [] + for node in nodes: + tool_helpers.raise_if_cancelled() + reference_node_id, output = node.to_orm_reference_node(node_mapping) + orm_node = _create_node( + user, + workflow, + node, + tool_helpers, + reference_node_id=reference_node_id, + output=output, + ) + node_mapping[node.ref] = node_mapping[orm_node.id] = (orm_node, node) + created.append(orm_node) + + return created, node_mapping + + +class _EdgeStub: + """Bridges ORM edge ``uid`` to the ``_uid`` attribute expected by ``to_orm_reference_node``.""" + + def __init__(self, orm_edge): + self.label = orm_edge.label + self._uid = str(orm_edge.uid) + + +class _ExistingNodeStub: + """ + Lightweight stub exposing ``type`` and ``edges`` from an existing ORM node, + so ``ActionNodeCreate.to_orm_reference_node`` can resolve router edge labels. + """ + + def __init__(self, orm_node): + self.type = orm_node.get_type().type + self.edges = [] + if self.type == "router" and hasattr(orm_node.service, "specific"): + service = orm_node.service.specific + if hasattr(service, "edges"): + self.edges = [_EdgeStub(e) for e in service.edges.all()] + + +def create_workflow( + user: AbstractUser, + automation: Automation, + workflow: "WorkflowCreate", + tool_helpers: "ToolHelpers", +) -> tuple[AutomationWorkflow, dict[int | str, Any]]: + """ + Create a workflow with its trigger and action nodes. + + Returns the ORM workflow and a mapping of ``{ref_or_id: (orm_node, node_create)}`` + for every created node, usable by downstream formula generation. + """ + + tool_helpers.update_status( + _("Creating workflow '%(name)s'..." % {"name": workflow.name}) + ) + + orm_wf = AutomationWorkflowService().create_workflow( + user, automation.id, workflow.name + ) + + node_mapping: dict[int | str, Any] = {} + + # -- Trigger -- + orm_trigger = _create_node(user, orm_wf, workflow.trigger, tool_helpers) + node_mapping[workflow.trigger.ref] = node_mapping[orm_trigger.id] = ( + orm_trigger, + workflow.trigger, + ) + + # -- Action / router / iterator nodes -- + for node in workflow.nodes: + try: + reference_node_id, output = node.to_orm_reference_node(node_mapping) + except ValueError as exc: + from pydantic_ai import ModelRetry + + raise ModelRetry(str(exc)) from exc + orm_node = _create_node( + user, + orm_wf, + node, + tool_helpers, + reference_node_id=reference_node_id, + output=output, + ) + node_mapping[node.ref] = node_mapping[orm_node.id] = (orm_node, node) + + return orm_wf, node_mapping + + +def _create_node(user, workflow, node_create, tool_helpers, **extra_kwargs): + """Create a single automation node (trigger or action).""" + + tool_helpers.update_status( + _("Creating node '%(label)s'..." % {"label": node_create.label}) + ) + node_type = automation_node_type_registry.get(node_create.type) + return AutomationNodeService().create_node( + user, + node_type, + workflow, + label=node_create.label, + service=node_create.to_orm_service_dict(), + **extra_kwargs, + ) + + +def update_node( + user: "AbstractUser", + workspace: "Workspace", + node_update: "NodeUpdate", + tool_helpers: "ToolHelpers", +): + """ + Update an automation node's label and/or service config. + + :param user: The acting user. + :param workspace: Workspace for permission check. + :param node_update: The update definition. + :param tool_helpers: Provides status updates. + :returns: The updated ORM node. + """ + + node = AutomationNodeService().get_node(user, node_update.node_id) + if node.workflow.automation.workspace_id != workspace.id: + raise ValueError("Node not in workspace") + + kwargs = {} + if node_update.label is not None: + kwargs["label"] = node_update.label + + node_type = node.service.get_type().type if node.service else None + service_dict = node_update.to_update_service_dict(node_type) if node_type else None + if service_dict is not None: + kwargs["service"] = service_dict + + if kwargs: + tool_helpers.update_status( + _("Updating node '%(label)s'..." % {"label": node.label}) + ) + AutomationNodeService().update_node(user, node.id, **kwargs) + + return AutomationNodeService().get_node(user, node_update.node_id) + + +def delete_node( + user: "AbstractUser", + workspace: "Workspace", + node_id: int, +): + """ + Delete an automation node. + + :param user: The acting user. + :param workspace: Workspace for permission check. + :param node_id: ID of the node to delete. + """ + + node = AutomationNodeService().get_node(user, node_id) + if node.workflow.automation.workspace_id != workspace.id: + raise ValueError("Node not in workspace") + AutomationNodeService().delete_node(user, node_id) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/prompts.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/prompts.py index a0e54c5ab1..55904a5b7d 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/prompts.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/prompts.py @@ -1,31 +1,13 @@ -GENERATE_FORMULA_PROMPT = """ -You are a formula builder. Generate formulas using these functions: +from baserow_enterprise.assistant.tools.shared.formula_prompt import FORMULA_LANGUAGE -**Comparison operators** (for router conditions only): -equal, not_equal, greater_than, less_than, greater_than_equal, less_than_equal -- Arguments: numbers, 'strings', or get() functions -- Returns: boolean -- Example: greater_than(get('age'), 18) +GENERATE_FORMULA_PROMPT = ( + FORMULA_LANGUAGE + + """ +## Context: Automation Workflows -**concat(...args)** - Joins arguments into a string -- Arguments: 'string literals' or get() functions -- Example: concat('Hello ', get('name'), '!') - -**get(path)** - Retrieves values from context using path notation -- Objects: get('user.name') -- Arrays: get('items.0'), get('orders.2.total') -- Nested: get('users.0.address.city') -- All: get('users.*.email') returns a list of emails from all users - -**if(condition, true_value, false_value)** - Conditional expression -- Arguments: a boolean condition, value if true, value if false -- Example: if(greater_than(get('score'), 50), 'pass', 'fail') - -**today()** - Returns the current date -**now()** - Returns the current date and time - -**constants**: -- A string literal enclosed in single quotes (e.g., 'hello world', '123') +In automation formulas, data is accessed through the previous_node structure: +- Path format: get('previous_node..0.') +- Each node ID maps to an array of rows; use index 0 for the first (and usually only) row. **Example 1 - String Fields:** Input: @@ -84,3 +66,4 @@ 3. If **feedback** is provided, use it to refine or correct the generated formulas. 4. Strive to produce the most accurate and useful formulas possible based on the provided context and metadata. """ +) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/tool_types.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/tool_types.py new file mode 100644 index 0000000000..d8a63f2f61 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/tool_types.py @@ -0,0 +1,20 @@ +from baserow_enterprise.assistant.tools.registries import AssistantToolType + + +class AutomationToolType(AssistantToolType): + type = "automation" + + def get_tool_functions(self): + from .tools import TOOL_FUNCTIONS + + return TOOL_FUNCTIONS + + def get_toolset(self): + from .tools import automation_toolset + + return automation_toolset + + def get_routing_rules(self): + from .tools import ROUTING_RULES + + return ROUTING_RULES diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/tools.py index 6046ca5ffa..5338ddf07b 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/tools.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/tools.py @@ -1,150 +1,370 @@ -from typing import TYPE_CHECKING, Any, Callable +from typing import Annotated, Any -from django.contrib.auth.models import AbstractUser from django.db import transaction from django.utils.translation import gettext as _ -import udspy +from pydantic import Field +from pydantic_ai import RunContext +from pydantic_ai.toolsets import FunctionToolset from baserow.contrib.automation.workflows.service import AutomationWorkflowService -from baserow.core.models import Workspace -from baserow_enterprise.assistant.tools.registries import AssistantToolType +from baserow_enterprise.assistant.deps import AssistantDeps from baserow_enterprise.assistant.types import WorkflowNavigationType -from . import utils -from .types import WorkflowCreate - -if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers - - -def get_list_workflows_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[int], dict[str, list[dict]]]: - """ - List all workflows in an automation. +from . import agents, helpers +from .types import ActionNodeCreate, NodeUpdate, WorkflowCreate + + +def list_workflows( + ctx: RunContext[AssistantDeps], + automation_id: Annotated[ + int, Field(description="The ID of the automation to list workflows for.") + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + List workflows in an automation. + + WHEN to use: Check existing workflows in an automation, or find workflow IDs before creating new ones. + WHAT it does: Lists all workflows in an automation with their id, name, and state. + RETURNS: Workflows array with id, name, state. + DO NOT USE when: You already have the workflow IDs you need. """ - def list_workflows(automation_id: int) -> dict[str, Any]: - """ - List all workflows in an automation application. + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - :param automation_id: The ID of the automation application - :return: Dictionary with workflows list - """ + tool_helpers.update_status(_("Listing workflows...")) - nonlocal user, workspace, tool_helpers + automation = helpers.get_automation(automation_id, user, workspace) + workflows = AutomationWorkflowService().list_workflows(user, automation.id) - tool_helpers.update_status(_("Listing workflows...")) + return { + "workflows": [{"id": w.id, "name": w.name, "state": w.state} for w in workflows] + } - automation = utils.get_automation(automation_id, user, workspace) - workflows = AutomationWorkflowService().list_workflows(user, automation.id) - return { - "workflows": [ - {"id": w.id, "name": w.name, "state": w.state} for w in workflows - ] - } +def list_nodes( + ctx: RunContext[AssistantDeps], + workflow_id: Annotated[ + int, Field(description="The ID of the workflow to list nodes for.") + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + List nodes in a workflow in execution order. - return list_workflows + WHEN to use: Inspect the nodes in a workflow, find node IDs before updating or deleting. + WHAT it does: Lists all nodes (trigger + actions) in graph traversal order with id, label, and type. + RETURNS: Nodes array with id, label, type. + DO NOT USE when: You already have the node IDs you need. + """ + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers + + tool_helpers.update_status(_("Listing nodes...")) + + workflow = helpers.get_workflow(workflow_id, user, workspace) + nodes = helpers.get_nodes_in_order(user, workflow) + + return {"nodes": nodes} + + +def add_nodes( + ctx: RunContext[AssistantDeps], + workflow_id: Annotated[ + int, Field(description="The ID of the workflow to add nodes to.") + ], + nodes: Annotated[ + list[ActionNodeCreate], + Field( + description="Nodes to add. previous_node_ref can be an existing node ID (as string) or a temp ref from an earlier node in this list." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Add action/router nodes to an existing workflow. + + WHEN to use: User wants to insert or append nodes in an existing workflow — e.g. add a router between trigger and action, or add a new action after an existing one. + WHAT it does: Creates new nodes attached to existing ones. Use previous_node_ref with the string ID of an existing node (e.g. "49") or a temp ref of a node being created in the same call. + RETURNS: Created nodes array with id, label, type. + DO NOT USE when: You want to create an entirely new workflow — use create_workflows instead. + HOW: Use list_nodes first to find the existing node IDs, then specify previous_node_ref to place new nodes. Use router_edge_label when attaching to a router branch. + """ -def get_workflow_tool_factory( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[int, list[WorkflowCreate]], dict[str, list[dict]]]: - def create_workflows( - automation_id: int, workflows: list[WorkflowCreate] - ) -> dict[str, Any]: - """ - Create one or more workflows in an automation. Always use {{ node.ref }} to - reference previous nodes values inside the workflow. + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - :param automation_id: The automation application ID - :param workflows: List of workflows to create - :return: Dictionary with created workflows - """ + if not nodes: + return {"created_nodes": []} - nonlocal user, workspace, tool_helpers + tool_helpers.update_status(_("Adding nodes to workflow...")) - created = [] + workflow = helpers.get_workflow(workflow_id, user, workspace) - automation = utils.get_automation(automation_id, user, workspace) - for wf in workflows: - with transaction.atomic(): - orm_workflow, node_mapping = utils.create_workflow( - user, automation, wf, tool_helpers - ) - created.append( - { - "id": orm_workflow.id, - "name": orm_workflow.name, - "state": orm_workflow.state, - } - ) - - # In separate transactions, try to update the formulas inside the workflow, - # so we don't block the main creation if something goes wrong here. - utils.update_workflow_formulas(wf, node_mapping, tool_helpers) - - # Navigate to the last created workflow - tool_helpers.navigate_to( - WorkflowNavigationType( - type="automation-workflow", - automation_id=automation.id, - workflow_id=orm_workflow.id, - workflow_name=orm_workflow.name, - ) + with transaction.atomic(): + created_nodes, node_mapping = helpers.add_nodes_to_workflow( + user, workflow, nodes, tool_helpers ) - return {"created_workflows": created} - - def load_workflow_automation_tools(): - """ - TOOL LOADER: Loads tools to manage workflows in an automation. - - After calling this loader, you will have access to: - - create_workflows: Create workflows with triggers, actions, and routers + # Generate formulas for nodes that need them + for orm_node, node_create in [(n, nodes[i]) for i, n in enumerate(created_nodes)]: + formulas = node_create.get_formulas_to_create(orm_node) + if formulas: + node_create.apply_direct_values(orm_node.service) + tool_helpers.update_status( + _( + "Generating formulas for node '%(label)s'..." + % {"label": orm_node.label} + ) + ) + with transaction.atomic(): + try: + agents.update_single_node_formulas( + node_create, orm_node, tool_helpers + ) + except Exception: + from loguru import logger + + logger.exception( + "Failed to generate formulas for node {}", orm_node.id + ) + + return { + "created_nodes": [ + {"id": n.id, "label": n.get_label(), "type": n.get_type().type} + for n in created_nodes + ] + } + + +def create_workflows( + ctx: RunContext[AssistantDeps], + automation_id: Annotated[ + int, Field(description="The ID of the automation to create workflows in.") + ], + workflows: Annotated[ + list[WorkflowCreate], + Field( + description="List of workflows to create, each with a trigger and action nodes." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Create workflows with triggers and action nodes. + + WHEN to use: User wants automated workflows with triggers and action nodes. + WHAT it does: Creates workflows with a trigger and action/router/iterator nodes. Use {{ node.ref }} for referencing values from previous nodes. + RETURNS: Created workflows with id, name, state. + DO NOT USE when: Workflows with those names already exist — check with list_workflows first. + HOW: Each workflow needs exactly one trigger and one or more actions/routers. Use {{ node.ref }} syntax to reference previous node values in action formulas. Know the table_id and field_ids for row-based triggers and actions. + + ## Workflow Structure + + Each workflow has a trigger (the starting event) and action nodes (tasks to perform). + Nodes execute in sequence. Use {{ node.ref }} template syntax to reference + values from previous nodes. + + ## Dynamic Values with $formula: + + Any string field marked "Supports $formula:" can use dynamic values. + Prefix with '$formula:' + a natural-language description to auto-generate a formula + from context data. Otherwise the value is used as a literal. + - {"field_id": 123, "value": "$formula: the customer name from the trigger data"} + - {"field_id": 456, "value": "$formula: today's date"} + - {"field_id": 789, "value": "pending"} ← literal, no prefix + """ - Use this when you need to create workflows in an automation but don't have the tool. - """ # noqa: W505 + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - @udspy.module_callback - def _load_workflow_automation_tools(context): - nonlocal user, workspace, tool_helpers + if not workflows: + return {"created_workflows": []} - observation = ["New tools are now available.\n"] + created = [] - create_tool = udspy.Tool(create_workflows) - new_tools = [create_tool] - observation.append( - "- Use `create_workflows` to create workflows in an automation." + automation = helpers.get_automation(automation_id, user, workspace) + for wf in workflows: + tool_helpers.raise_if_cancelled() + with transaction.atomic(): + orm_workflow, node_mapping = helpers.create_workflow( + user, automation, wf, tool_helpers + ) + created.append( + { + "id": orm_workflow.id, + "name": orm_workflow.name, + "state": orm_workflow.state, + } ) - # Re-initialize the module with the new tools for the next iteration - context.module.init_module(tools=context.module._tools + new_tools) - return "\n".join(observation) + # In separate transactions, try to update the formulas inside the workflow, + # so we don't block the main creation if something goes wrong here. + agents.update_workflow_formulas(wf, node_mapping, tool_helpers) + + # Navigate to the last created workflow + tool_helpers.navigate_to( + WorkflowNavigationType( + type="automation-workflow", + automation_id=automation.id, + workflow_id=orm_workflow.id, + workflow_name=orm_workflow.name, + ) + ) + + return {"created_workflows": created} + + +def update_nodes( + ctx: RunContext[AssistantDeps], + workflow_id: Annotated[ + int, Field(description="The ID of the workflow containing the nodes.") + ], + nodes: Annotated[ + list[NodeUpdate], + Field( + description="List of node updates, each with a node_id and properties to change." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Update automation node labels and service configuration. + + WHEN to use: User wants to rename a node, change email subject/body, update slack channel, etc. + WHAT it does: Updates node label and/or service config. Supports $formula: prefix for dynamic values. + RETURNS: Updated node IDs and any errors. + DO NOT USE when: You need to change a node's type — delete and recreate it instead. + HOW: Use list_workflows first to find the workflow and node IDs. + """ - return _load_workflow_automation_tools + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - return load_workflow_automation_tools + if not nodes: + return {"updated_nodes": []} + # Verify workflow belongs to workspace + helpers.get_workflow(workflow_id, user, workspace) -# ============================================================================ -# TOOL TYPE REGISTRY -# ============================================================================ + updated = [] + errors = [] + nodes_needing_formulas = [] + with transaction.atomic(): + for node_update in nodes: + tool_helpers.raise_if_cancelled() + try: + orm_node = helpers.update_node( + user, workspace, node_update, tool_helpers + ) + updated.append({"node_id": orm_node.id, "label": orm_node.label}) + + # Check if any fields need formula generation + formulas = node_update.get_formulas_to_update(orm_node) + if formulas: + nodes_needing_formulas.append((node_update, orm_node, formulas)) + except Exception as e: + errors.append(f"Error updating node {node_update.node_id}: {e}") + + # Apply direct values and generate formulas outside the main transaction + for node_update, orm_node, formulas in nodes_needing_formulas: + node_update.apply_direct_values(orm_node.service) + tool_helpers.update_status( + _("Generating formulas for node '%(label)s'..." % {"label": orm_node.label}) + ) + with transaction.atomic(): + try: + agents.update_single_node_formulas(node_update, orm_node, tool_helpers) + except Exception as exc: + from loguru import logger + + logger.exception( + "Failed to generate formulas for node {}: {}", orm_node.id, exc + ) -class ListWorkflowsToolType(AssistantToolType): - type = "list_workflows" + result: dict[str, Any] = {"updated_nodes": updated} + if errors: + result["errors"] = errors + return result + + +def delete_nodes( + ctx: RunContext[AssistantDeps], + node_ids: Annotated[ + list[int], + Field(description="List of node IDs to delete."), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Delete automation nodes. + + WHEN to use: User wants to remove nodes from a workflow. + WHAT it does: Deletes the specified automation nodes. + RETURNS: Deleted node IDs and any errors. + DO NOT USE when: You want to modify a node — use update_nodes instead. + """ - @classmethod - def get_tool(cls, user, workspace, tool_helpers): - return get_list_workflows_tool(user, workspace, tool_helpers) + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers + if not node_ids: + return {"deleted_node_ids": []} -class WorkflowToolFactoryToolType(AssistantToolType): - type = "workflow_tool_factory" + deleted = [] + errors = [] - @classmethod - def get_tool(cls, user, workspace, tool_helpers): - return get_workflow_tool_factory(user, workspace, tool_helpers) + for node_id in node_ids: + tool_helpers.raise_if_cancelled() + tool_helpers.update_status( + _("Deleting node %(node_id)s...") % {"node_id": node_id} + ) + try: + helpers.delete_node(user, workspace, node_id) + deleted.append(node_id) + except Exception as e: + errors.append(f"Error deleting node {node_id}: {e}") + + result: dict[str, Any] = {"deleted_node_ids": deleted} + if errors: + result["errors"] = errors + return result + + +TOOL_FUNCTIONS = [ + list_workflows, + list_nodes, + create_workflows, + add_nodes, + update_nodes, + delete_nodes, +] +automation_toolset = FunctionToolset(TOOL_FUNCTIONS, max_retries=3) + +ROUTING_RULES = """\ +- Check list_* before create_* to avoid duplicates. +- switch_mode: switch domain if task needs tools not in the current mode. +- create_workflows: use {{ node.ref }} for node refs, $formula: prefix for dynamic field values. +- add_nodes: insert/append nodes. Use list_nodes first to find existing node IDs.""" diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/__init__.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/__init__.py index f2c9159123..5358c6af80 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/__init__.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/__init__.py @@ -1,26 +1,16 @@ from .node import ( - AiAgentNodeCreate, - CreateRowActionCreate, - DeleteRowActionCreate, - HasFormulasToCreateMixin, - NodeBase, - RouterNodeCreate, - SendEmailActionCreate, + ActionNodeCreate, + ActionNodeItem, + NodeUpdate, TriggerNodeCreate, - UpdateRowActionCreate, ) from .workflow import WorkflowCreate, WorkflowItem __all__ = [ "WorkflowCreate", "WorkflowItem", - "NodeBase", - "RouterNodeCreate", - "CreateRowActionCreate", - "UpdateRowActionCreate", - "DeleteRowActionCreate", - "SendEmailActionCreate", - "AiAgentNodeCreate", + "ActionNodeCreate", + "ActionNodeItem", + "NodeUpdate", "TriggerNodeCreate", - "HasFormulasToCreateMixin", ] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/node.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/node.py index 1c4b916eb2..0c2f425324 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/node.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/node.py @@ -1,10 +1,17 @@ -from abc import ABC, abstractmethod -from typing import Annotated, Any, Literal, Optional +""" +Automation node type models and ORM conversion logic. + +Defines ``TriggerNodeCreate``, ``ActionNodeCreate``, and their read-back +counterparts (``TriggerNodeItem``, ``ActionNodeItem``), plus the dispatch +tables that convert between Pydantic models and Django ORM representations. +""" + +from typing import Any, Callable, Literal, Optional from uuid import uuid4 from django.conf import settings -from pydantic import Field, PrivateAttr +from pydantic import Field, PrivateAttr, model_serializer, model_validator from baserow.contrib.automation.nodes.models import AutomationNode from baserow.core.formula.types import ( @@ -13,81 +20,82 @@ ) from baserow.core.services.handler import ServiceHandler from baserow.core.services.models import Service +from baserow_enterprise.assistant.tools.shared.formula_utils import ( + FORMULA_PREFIX, + formula_desc, + literal_or_placeholder, + needs_formula, +) from baserow_enterprise.assistant.types import BaseModel +# Short marker appended to fields that support $formula: dynamic values. +# The full explanation lives in the create_workflows tool description. +SUPPORTS_FORMULA = f" Supports {FORMULA_PREFIX} prefix." -class NodeBase(BaseModel): - """Base node model.""" - - label: str = Field(..., description="The human readable name of the node") - type: str +# --------------------------------------------------------------------------- +# Field-mapping helpers (shared by apply_direct / update_formulas) +# --------------------------------------------------------------------------- -class RefCreate(BaseModel): - """Base node creation model.""" - - ref: str = Field( - ..., description="A reference ID for the node, only used during creation" - ) - - -class Item(BaseModel): - id: str +def _upsert_field_mappings( + service: Service, + values: dict[int, tuple[str, bool]], +): + """ + Bulk-upsert field mappings on a service. + + ``values`` maps ``field_id → (formula_value, enabled)``. + Existing mappings are updated in place; missing ones are created. + """ + + if not values: + return + + existing = {m.field_id: m for m in service.field_mappings.all()} + FieldMapping = service.field_mappings.model + to_create, to_update = [], [] + + for field_id, (formula, enabled) in values.items(): + if field_id in existing: + mapping = existing[field_id] + mapping.value = formula + mapping.enabled = enabled + to_update.append(mapping) + else: + to_create.append( + FieldMapping( + field_id=field_id, + value=formula, + enabled=enabled, + service_id=service.id, + ) + ) -class HasFormulasToCreateMixin(ABC): - @abstractmethod - def get_formulas_to_create(self, orm_node: AutomationNode) -> dict[str, str]: - """ - Creates and returns a mapping between field names and formulas to be created - for the given ORM node. Every value needs to contain instructions or description - on how to generate the formula for that field. - Prefix optional fields with "[optional]: " in the description to indicate they - are not mandatory. - """ + if to_create: + service.field_mappings.bulk_create(to_create) + if to_update: + FieldMapping.objects.bulk_update(to_update, ["value", "enabled"]) - pass - def update_service_with_formulas(self, service: Service, formulas: dict[str, str]): - save = False - for field_name, formula in formulas.items(): - if hasattr(service, field_name): - setattr( - service, - field_name, - BaserowFormulaObject.create(formula=formula), - ) - save = True - if save: - ServiceHandler().update_service(service.get_type(), service) +# --------------------------------------------------------------------------- +# Sub-models +# --------------------------------------------------------------------------- class PeriodicTriggerSettings(BaseModel): - interval: Literal["MINUTE", "HOUR", "DAY", "WEEK", "MONTH"] = Field( - ..., description="The interval for the periodic trigger" - ) + """All times in UTC — remove timezone offsets.""" + + interval: Literal["MINUTE", "HOUR", "DAY", "WEEK", "MONTH"] minute: int = Field( default=0, - description=( - "If interval=MINUTE, the number of minutes between each trigger. " - f"Minimum is set to {settings.INTEGRATIONS_PERIODIC_MINUTE_MIN} minutes. " - "If interval=HOUR, the UTC minute for the periodic trigger. " - ), - ) - hour: int = Field( - default=0, - description=( - "The UTC hour for the periodic trigger. " - "ALWAYS remove timezone offset from the context." - ), - ) - day_of_week: int = Field( - default=0, - description="The day of the week for the periodic trigger (0=Monday, 6=Sunday)", - ) - day_of_month: int = Field( - default=1, description="The day of the month for the periodic trigger (1-31)" + ge=0, + le=59, + description=f"MINUTE: minutes between triggers (min {settings.INTEGRATIONS_PERIODIC_MINUTE_MIN}). HOUR: minute of the hour.", ) + hour: int = Field(default=0, ge=0, le=23, description="UTC hour (0-23).") + day_of_week: int = Field(default=0, ge=0, le=6, description="0=Monday, 6=Sunday.") + day_of_month: int = Field(default=1, ge=1, le=31, description="1-31.") class RowsTriggersSettings(BaseModel): @@ -96,9 +104,46 @@ class RowsTriggersSettings(BaseModel): table_id: int = Field(..., description="The ID of the table to monitor") -class TriggerNodeCreate(NodeBase, RefCreate): +class RouterEdgeCreate(BaseModel): + """Router branch. Order matters: first matching branch is taken.""" + + label: str = Field(description="Branch label.") + condition: str = Field( + description="Boolean condition using comparison operators and get() functions.", + ) + + _uid: str = PrivateAttr(default_factory=lambda: str(uuid4())) + + def to_orm_service_dict(self) -> dict[str, Any]: + return {"uid": self._uid, "label": self.label} + + +class RouterBranch(RouterEdgeCreate): + """Existing router branch with ID.""" + + id: str + + +class AutomationFieldValue(BaseModel): + """Field ID → value mapping for row actions.""" + + field_id: int = Field(..., description="Database field ID.") + value: str = Field(..., description=f"Field value.{SUPPORTS_FORMULA}") + + +# --------------------------------------------------------------------------- +# Trigger +# --------------------------------------------------------------------------- + + +_PERIODIC_KEYS = {"interval", "minute", "hour", "day_of_week", "day_of_month"} + + +class TriggerNodeCreate(BaseModel): """Create a trigger node in a workflow.""" + ref: str = Field(..., description="Temporary reference ID for creation.") + label: str = Field(..., description="Display name.") type: Literal[ "periodic", "http_trigger", @@ -107,16 +152,38 @@ class TriggerNodeCreate(NodeBase, RefCreate): "rows_deleted", ] - # periodic trigger specific periodic_interval: Optional[PeriodicTriggerSettings] = Field( default=None, - description="UTC configuration for periodic trigger. ALWAYS remove timezone offset from the context.", + description="(periodic) Schedule settings in UTC.", ) rows_triggers_settings: Optional[RowsTriggersSettings] = Field( default=None, - description="Configuration for rows trigger", + description="(rows_*) Table to monitor.", ) + @model_validator(mode="before") + @classmethod + def _fold_flat_periodic(cls, data): + """Accept flat periodic fields (interval, hour, ...) and nest them.""" + + if not isinstance(data, dict): + return data + if data.get("periodic_interval") is not None: + return data + flat = {k: data.pop(k) for k in list(data) if k in _PERIODIC_KEYS} + if flat: + data["periodic_interval"] = flat + return data + + @model_validator(mode="after") + def _validate_trigger_settings(self): + if self.type == "periodic" and self.periodic_interval is None: + raise ValueError("periodic trigger requires periodic_interval") + if self.type in ("rows_created", "rows_updated", "rows_deleted"): + if self.rows_triggers_settings is None: + raise ValueError(f"{self.type} trigger requires rows_triggers_settings") + return self + def to_orm_service_dict(self) -> dict[str, Any]: """Convert to ORM dict for node creation service.""" @@ -138,28 +205,128 @@ def to_orm_service_dict(self) -> dict[str, Any]: return {} -class TriggerNodeItem(TriggerNodeCreate, Item): +class TriggerNodeItem(TriggerNodeCreate): """Existing trigger node with ID.""" + id: str http_trigger_url: str | None = Field( default=None, description="The URL to trigger the HTTP request" ) -class EdgeCreate(BaseModel): - previous_node_ref: str = Field( - ..., - description="The reference ID of the previous node to link from. Every node can have only one previous node.", - ) +# --------------------------------------------------------------------------- +# Action node +# --------------------------------------------------------------------------- + +ActionNodeType = Literal[ + "router", + "smtp_email", + "slack_write_message", + "create_row", + "update_row", + "delete_row", + "ai_agent", +] + + +class ActionNodeCreate(BaseModel): + """Flat model for creating an action node: type + type-specific fields.""" + + ref: str = Field(..., description="Temporary reference ID for creation.") + label: str = Field(..., description="Display name.") + type: ActionNodeType + previous_node_ref: str = Field(..., description="Ref of the preceding node.") router_edge_label: str = Field( default="", - description="If the previous node is a router, the edge label to link from if different from default", + description="Branch label if previous node is a router.", + ) + + # -- router -- + edges: list[RouterEdgeCreate] | None = Field( + default=None, + description="(router) Branches. A default branch is auto-created.", + ) + + # -- smtp_email -- + to_emails: str | None = Field( + default=None, description=f"(smtp_email) Recipients.{SUPPORTS_FORMULA}" + ) + cc_emails: str | None = Field( + default=None, description=f"(smtp_email) CC.{SUPPORTS_FORMULA}" + ) + bcc_emails: str | None = Field( + default=None, description=f"(smtp_email) BCC.{SUPPORTS_FORMULA}" + ) + subject: str | None = Field( + default=None, description=f"(smtp_email) Subject.{SUPPORTS_FORMULA}" + ) + body: str | None = Field( + default=None, description=f"(smtp_email) Body.{SUPPORTS_FORMULA}" ) + body_type: Literal["plain", "html"] = "plain" - def to_orm_reference_node( - self, node_mapping: dict - ) -> tuple[Optional[int], Optional[str]]: - """Get the ORM node ID and output label from the previous node reference.""" + # -- slack_write_message -- + channel: str | None = None + text: str | None = Field( + default=None, description=f"(slack) Message.{SUPPORTS_FORMULA}" + ) + + # -- create_row / update_row / delete_row -- + table_id: int | None = None + row_id: str | None = Field( + default=None, description=f"(update/delete_row) Row ID.{SUPPORTS_FORMULA}" + ) + values: list[AutomationFieldValue] | None = None + + # -- ai_agent -- + output_type: Literal["text", "choice"] = Field( + default="text", + description="(ai_agent) Chain another action to use the output.", + ) + choices: list[str] | None = Field( + default=None, + description="(ai_agent) Choices if output_type='choice'.", + ) + prompt: str | None = Field( + default=None, description=f"(ai_agent) Prompt.{SUPPORTS_FORMULA}" + ) + + # Required fields per type + _REQUIRED_FIELDS: dict[str, list[tuple[str, str]]] = { + "router": [("edges", "edges")], + "smtp_email": [ + ("to_emails", "to_emails"), + ("subject", "subject"), + ("body", "body"), + ], + "slack_write_message": [("channel", "channel"), ("text", "text")], + "create_row": [("table_id", "table_id"), ("values", "values")], + "update_row": [ + ("table_id", "table_id"), + ("row_id", "row_id"), + ("values", "values"), + ], + "delete_row": [("table_id", "table_id"), ("row_id", "row_id")], + "ai_agent": [("prompt", "prompt")], + } + + @model_validator(mode="after") + def _validate_required_for_type(self): + required = self._REQUIRED_FIELDS.get(self.type) + if required: + missing = [name for attr, name in required if getattr(self, attr) is None] + if missing: + raise ValueError(f"{self.type} requires {', '.join(missing)}") + return self + + # -- ORM conversion -- + + def to_orm_service_dict(self) -> dict[str, Any]: + """Convert type-specific fields to an ORM service dict.""" + return _TO_ORM_SERVICE[self.type](self) + + def to_orm_reference_node(self, node_mapping: dict) -> tuple[Optional[int], str]: + """Resolve the previous node reference into an ORM node ID and output label.""" if self.previous_node_ref not in node_mapping: raise ValueError( @@ -169,13 +336,13 @@ def to_orm_reference_node( previous_orm_node, previous_node_create = node_mapping[self.previous_node_ref] output = "" - if self.router_edge_label and previous_node_create.type == "router": + if ( + self.router_edge_label + and getattr(previous_node_create, "type", None) == "router" + ): + edges = getattr(previous_node_create, "edges", None) or [] output = next( - ( - edge._uid - for edge in previous_node_create.edges - if edge.label == self.router_edge_label - ), + (edge._uid for edge in edges if edge.label == self.router_edge_label), None, ) if output is None: @@ -185,356 +352,541 @@ def to_orm_reference_node( return previous_orm_node.id, output + # -- Formula lifecycle -- -class RouterEdgeCreate(BaseModel): - """Router branch configuration.""" + def get_formulas_to_create(self, orm_node: AutomationNode) -> dict[str, str] | None: + """Return a ``{key: description}`` dict of formulas to generate, or None.""" - label: str = Field( - description="The label of the router branch. Order of branches matters: first matching branch is taken.", - ) - condition: str = Field( - description=( - "The condition formula to evaluate for this branch as boolean. " - "Use comparison operators and get(...) functions to build the formula with a boolean result. " - "Always mentions the field values using get(...) functions." - ), - ) + fn = _GET_FORMULAS.get(self.type) + return fn(self, orm_node) if fn else None - _uid: str = PrivateAttr(default_factory=lambda: str(uuid4())) + def apply_direct_values(self, service: Service): + """Apply literal (non-$formula) values directly to the service.""" - def to_orm_service_dict(self) -> dict[str, Any]: - return { - "uid": self._uid, - "label": self.label, - } + fn = _APPLY_DIRECT.get(self.type) + if fn is not None: + fn(self, service) + def update_service_with_formulas(self, service: Service, formulas: dict[str, str]): + """Write generated formulas back to the ORM service.""" -class RouterBranch(RouterEdgeCreate, Item): - """Existing router branch with ID.""" + fn = _UPDATE_FORMULAS.get(self.type) + if fn is not None: + fn(self, service, formulas) + else: + _default_update_formulas(service, formulas) -class RouterNodeBase(NodeBase): - """Create a router node with branches.""" +# --------------------------------------------------------------------------- +# to_orm_service dispatch: (ActionNodeCreate) -> dict +# --------------------------------------------------------------------------- - type: Literal["router"] - edges: list[RouterEdgeCreate] = Field( - ..., - description="List of branches for the router node. A default branch is created automatically.", - ) +def _router_to_orm(n: ActionNodeCreate) -> dict[str, Any]: + return {"edges": [branch.to_orm_service_dict() for branch in n.edges]} -class RouterNodeCreate(RouterNodeBase, RefCreate, EdgeCreate, HasFormulasToCreateMixin): - """Create a router node with branches and link configuration.""" - def to_orm_service_dict(self) -> dict[str, Any]: - return {"edges": [branch.to_orm_service_dict() for branch in self.edges]} +def _email_to_orm(n: ActionNodeCreate) -> dict[str, Any]: + return { + "to_email": literal_or_placeholder(n.to_emails), + "cc_email": literal_or_placeholder(n.cc_emails), + "bcc_email": literal_or_placeholder(n.bcc_emails), + "subject": literal_or_placeholder(n.subject), + "body": literal_or_placeholder(n.body), + "body_type": f"'{n.body_type}'", + } - def get_formulas_to_create(self, orm_node: AutomationNode) -> dict[str, str]: - return {edge.label: edge.condition for edge in self.edges} - def update_service_with_formulas(self, service: Service, formulas: dict[str, str]): - orm_edges = service.specific.edges.all() - formulas = {k.lower(): v for k, v in formulas.items()} - EdgeModel = service.specific.edges.model - updates = [] - for orm_edge in orm_edges: - label = orm_edge.label.lower() - if label in formulas: - orm_edge.condition["mode"] = BASEROW_FORMULA_MODE_ADVANCED - orm_edge.condition["formula"] = formulas[label] - updates.append(orm_edge) - if updates: - EdgeModel.objects.bulk_update(updates, ["condition"]) - - -class RouterNodeItem(RouterNodeBase, Item): - """Existing router node with ID.""" - - -class SendEmailActionBase(NodeBase): - """Send email action configuration.""" - - type: Literal["smtp_email"] - to_emails: str - cc_emails: Optional[str] - bcc_emails: Optional[str] - subject: str - body: str - body_type: Literal["plain", "html"] = Field(default="plain") - - -class SendEmailActionCreate( - SendEmailActionBase, RefCreate, EdgeCreate, HasFormulasToCreateMixin -): - """Create a send email action with edge configuration.""" +def _slack_to_orm(n: ActionNodeCreate) -> dict[str, Any]: + channel = (n.channel or "").lstrip("#") + return { + "channel": channel, + "text": literal_or_placeholder(n.text), + } - def to_orm_service_dict(self) -> dict[str, Any]: - return { - "to_email": f"'{self.to_emails}'", - "cc_email": f"'{self.cc_emails or ''}'", - "bcc_email": f"'{self.bcc_emails or ''}'", - "subject": f"'{self.subject}'", - "body": f"'{self.body}'", - "body_type": f"'{self.body_type}'", - } - - def get_formulas_to_create(self, orm_node: AutomationNode) -> dict[str, str]: - values = {} - to_emails_base = ( - "A comma separated list of email addresses to send the email to." - ) - if self.to_emails: - values["to_emails"] = ( - to_emails_base + f" Value to resolve: {self.to_emails}" - ) - else: - values["to_emails"] = "[optional]: " + to_emails_base - cc_emails_base = "A comma separated list of email addresses to CC the email to." - if self.cc_emails: - values["cc_emails"] = ( - cc_emails_base + f" Value to resolve: {self.cc_emails}" - ) - else: - values["cc_emails"] = "[optional]: " + cc_emails_base - - bcc_emails_base = ( - "A comma separated list of email addresses to BCC the email to." - ) - if self.bcc_emails: - values["bcc_emails"] = ( - bcc_emails_base + f" Value to resolve: {self.bcc_emails}" - ) - else: - values["bcc_emails"] = "[optional]: " + bcc_emails_base +def _row_action_to_orm(n: ActionNodeCreate) -> dict[str, Any]: + return {"table_id": n.table_id} - values["subject"] = "The subject of the email." - if self.subject: - values["subject"] += f" Value to resolve: {self.subject}" - values["body"] = f"The {self.body_type} body content of the email." - if self.body: - values["body"] += f" Value to resolve: {self.body}" - return values +def _ai_agent_to_orm(n: ActionNodeCreate) -> dict[str, Any]: + return { + "ai_choices": (n.choices or []) if n.output_type == "choice" else [], + "ai_prompt": literal_or_placeholder(n.prompt), + "ai_output_type": n.output_type, + } -class SendEmailActionItem(SendEmailActionBase, Item): - """Existing send email action with ID.""" +_TO_ORM_SERVICE: dict[str, Callable] = { + "router": _router_to_orm, + "smtp_email": _email_to_orm, + "slack_write_message": _slack_to_orm, + "create_row": _row_action_to_orm, + "update_row": _row_action_to_orm, + "delete_row": _row_action_to_orm, + "ai_agent": _ai_agent_to_orm, +} -class SlackWriteMessageActionBase(NodeBase): - """Send Slack message action configuration.""" +# --------------------------------------------------------------------------- +# get_formulas_to_create dispatch: (ActionNodeCreate, AutomationNode) -> dict | None +# --------------------------------------------------------------------------- - type: Literal["slack_write_message"] - channel: str - text: str +def _router_formulas(n: ActionNodeCreate, orm_node: AutomationNode) -> dict[str, str]: + return {edge.label: edge.condition for edge in n.edges} -class SlackWriteMessageActionCreate( - SlackWriteMessageActionBase, RefCreate, EdgeCreate, HasFormulasToCreateMixin -): - """Create a send Slack message action with edge configuration.""" - def to_orm_service_dict(self) -> dict[str, Any]: +def _email_formulas( + n: ActionNodeCreate, orm_node: AutomationNode +) -> dict[str, str] | None: + fields = { + "to_emails": ( + "A comma separated list of email addresses to send the email to.", + n.to_emails, + ), + "cc_emails": ( + "A comma separated list of email addresses to CC the email to.", + n.cc_emails, + ), + "bcc_emails": ( + "A comma separated list of email addresses to BCC the email to.", + n.bcc_emails, + ), + "subject": ("The subject of the email.", n.subject), + "body": (f"The {n.body_type} body content of the email.", n.body), + } + values = { + key: f"{base_desc} Value to resolve: {formula_desc(val)}" + for key, (base_desc, val) in fields.items() + if needs_formula(val) + } + return values or None + + +def _slack_formulas( + n: ActionNodeCreate, orm_node: AutomationNode +) -> dict[str, str] | None: + if needs_formula(n.text): return { - "channel": self.channel, - "text": f"'{self.text}'", + "text": f"The message content. Value to resolve: {formula_desc(n.text)}" } + return None - def get_formulas_to_create(self, orm_node: AutomationNode) -> dict[str, str]: - values = {} - message_base = "The message content." - if self.text: - values["text"] = message_base + f" Value to resolve: '{self.text}'" - else: - values["text"] = "[optional]: " + message_base - return values +def _row_action_formulas( + n: ActionNodeCreate, orm_node: AutomationNode +) -> dict[str, str] | None: + from baserow_enterprise.assistant.tools.shared.formula_utils import ( + minimize_json_schema, + ) + service = orm_node.service.specific + schema = service.get_type().generate_schema(service.specific) + values_by_id = {fv.field_id: fv.value for fv in (n.values or [])} + values = {} -class CreateRowActionBase(NodeBase): - """Create row action configuration.""" + if needs_formula(n.row_id): + values["row_id"] = ( + f"the row ID to update. Value to resolve: {formula_desc(n.row_id)}" + ) - type: Literal["create_row"] - table_id: int - values: dict[int, Any] = Field( - ..., description="A mapping of field IDs to values or formulas to update" - ) + for v in minimize_json_schema(schema).values(): + value = values_by_id.get(int(v["id"])) + if needs_formula(value): + desc = v["desc"] + f" Value to resolve: {formula_desc(value)}" + values[int(v["id"])] = {**v, "desc": desc} + return values or None -class RowActionService: - def to_orm_service_dict(self) -> dict[str, Any]: + +def _ai_agent_formulas( + n: ActionNodeCreate, orm_node: AutomationNode +) -> dict[str, str] | None: + if needs_formula(n.prompt): return { - "table_id": self.table_id, + "ai_prompt": f"The AI prompt. Value to resolve: {formula_desc(n.prompt)}" } + return None -class RowActionFormulaToCreate(HasFormulasToCreateMixin): - def get_formulas_to_create(self, orm_node: AutomationNode) -> dict[str, str]: - from baserow_enterprise.assistant.tools.automation.utils import ( - _minimize_json_schema, - ) +_GET_FORMULAS: dict[str, Callable] = { + "router": _router_formulas, + "smtp_email": _email_formulas, + "slack_write_message": _slack_formulas, + "create_row": _row_action_formulas, + "update_row": _row_action_formulas, + "delete_row": _row_action_formulas, + "ai_agent": _ai_agent_formulas, +} - service = orm_node.service.specific - schema = service.get_type().generate_schema(service.specific) - values = {"row_id": "the row ID to update"} - for v in _minimize_json_schema(schema).values(): - desc = v["desc"] - value = self.values.get(int(v["id"])) - if value: - desc += f" Value to resolve: {value}" - else: - desc = "[optional]: " + desc - values[int(v["id"])] = {**v, "desc": desc} - return values - def update_service_with_formulas(self, service: Service, formulas: dict[str, str]): - row_id_formula = formulas.pop("row_id", None) - - field_mappings = {m.field_id: m for m in service.field_mappings.all()} - field_mapping_to_create = [] - field_mapping_to_update = [] - FieldMapping = service.field_mappings.model - for field_id, formula in formulas.items(): - if field_id in field_mappings: - field_mappings[field_id].value = formula - field_mappings[field_id].enabled = True - field_mapping_to_update.append(field_mappings[field_id]) - else: - field_mapping_to_create.append( - FieldMapping( - field_id=field_id, - value=formula, - enabled=True, - service_id=service.id, - ) - ) - if field_mapping_to_create: - service.field_mappings.bulk_create(field_mapping_to_create) - if field_mapping_to_update: - FieldMapping.objects.bulk_update( - field_mapping_to_update, ["value", "enabled"] - ) +# --------------------------------------------------------------------------- +# update_service_with_formulas dispatch +# --------------------------------------------------------------------------- + - if row_id_formula: - service.row_id = row_id_formula - ServiceHandler().update_service(service.get_type(), service) +def _default_update_formulas(service: Service, formulas: dict[str, str]): + """Set ``BaserowFormulaObject`` on named service fields.""" + save = False + for field_name, formula in formulas.items(): + if hasattr(service, field_name): + setattr(service, field_name, BaserowFormulaObject.create(formula=formula)) + save = True + if save: + ServiceHandler().update_service(service.get_type(), service) -class CreateRowActionCreate( - RowActionService, - CreateRowActionBase, - RefCreate, - EdgeCreate, - RowActionFormulaToCreate, + +def _router_update_formulas( + n: ActionNodeCreate, service: Service, formulas: dict[str, str] +): + """Write generated condition formulas to router edges.""" + + formulas_lower = {k.lower(): v for k, v in formulas.items()} + EdgeModel = service.specific.edges.model + updates = [] + for orm_edge in service.specific.edges.all(): + label = orm_edge.label.lower() + if label in formulas_lower: + orm_edge.condition["mode"] = BASEROW_FORMULA_MODE_ADVANCED + orm_edge.condition["formula"] = formulas_lower[label] + updates.append(orm_edge) + if updates: + EdgeModel.objects.bulk_update(updates, ["condition"]) + + +def _row_action_update_formulas( + n: ActionNodeCreate, service: Service, formulas: dict[str, str] ): - """Create a create row action with edge configuration.""" + """Write generated formulas to row action field mappings and row_id.""" + row_id_formula = formulas.pop("row_id", None) -class CreateRowActionItem(CreateRowActionBase, Item): - """Existing create row action with ID.""" + _upsert_field_mappings( + service, + {field_id: (formula, True) for field_id, formula in formulas.items()}, + ) + + if row_id_formula: + service.row_id = row_id_formula + ServiceHandler().update_service(service.get_type(), service) + + +_UPDATE_FORMULAS: dict[str, Callable] = { + "router": _router_update_formulas, + "create_row": _row_action_update_formulas, + "update_row": _row_action_update_formulas, + "delete_row": _row_action_update_formulas, +} + + +# --------------------------------------------------------------------------- +# apply_direct_values dispatch +# --------------------------------------------------------------------------- -class UpdateRowActionBase(NodeBase): - """Update row action configuration.""" +def _row_action_apply_direct(n: ActionNodeCreate, service: Service): + """Write literal (non-$formula) field values as quoted formulas.""" - type: Literal["update_row"] - table_id: int - row_id: str = Field(..., description="The row ID or a formula to identify the row") - values: dict[int, Any] = Field( - ..., description="A mapping of field IDs to values or formulas to update" + _upsert_field_mappings( + service, + { + fv.field_id: (f"'{fv.value}'", True) + for fv in (n.values or []) + if not needs_formula(fv.value) + }, ) + if n.row_id and not needs_formula(n.row_id): + service.row_id = f"'{n.row_id}'" + ServiceHandler().update_service(service.get_type(), service) -class UpdateRowActionCreate( - RowActionService, - UpdateRowActionBase, - RefCreate, - EdgeCreate, - RowActionFormulaToCreate, -): - """Create an update row action with edge configuration.""" +_APPLY_DIRECT: dict[str, Callable] = { + "create_row": _row_action_apply_direct, + "update_row": _row_action_apply_direct, + "delete_row": _row_action_apply_direct, +} -class UpdateRowActionItem(UpdateRowActionBase, Item): - """Existing update row action with ID.""" +# --------------------------------------------------------------------------- +# ActionNodeItem (read-back) +# --------------------------------------------------------------------------- -class DeleteRowActionBase(NodeBase): - """Delete row action configuration.""" - type: Literal["delete_row"] - table_id: int - row_id: str = Field(..., description="The row ID or a formula to identify the row") +# --------------------------------------------------------------------------- +# NodeUpdate (for update_nodes tool) +# --------------------------------------------------------------------------- -class DeleteRowActionCreate( - RowActionService, - DeleteRowActionBase, - RefCreate, - EdgeCreate, - RowActionFormulaToCreate, -): - """Create a delete row action with edge configuration.""" +class NodeUpdate(BaseModel): + """Flat model for updating an automation node.""" + node_id: int = Field(..., description="The ID of the node to update.") + label: str | None = Field(None, description="New display name.") -class DeleteRowActionItem(DeleteRowActionBase, Item): - """Existing delete row action with ID.""" + # -- smtp_email -- + to_emails: str | None = Field( + default=None, description=f"(smtp_email) Recipients.{SUPPORTS_FORMULA}" + ) + cc_emails: str | None = Field( + default=None, description=f"(smtp_email) CC.{SUPPORTS_FORMULA}" + ) + bcc_emails: str | None = Field( + default=None, description=f"(smtp_email) BCC.{SUPPORTS_FORMULA}" + ) + subject: str | None = Field( + default=None, description=f"(smtp_email) Subject.{SUPPORTS_FORMULA}" + ) + body: str | None = Field( + default=None, description=f"(smtp_email) Body.{SUPPORTS_FORMULA}" + ) + body_type: Literal["plain", "html"] | None = None + # -- slack_write_message -- + channel: str | None = None + text: str | None = Field( + default=None, description=f"(slack) Message.{SUPPORTS_FORMULA}" + ) -class AiAgentNodeBase(NodeBase): - """AI Agent action configuration.""" + # -- create_row / update_row / delete_row -- + table_id: int | None = None + row_id: str | None = Field( + default=None, description=f"(update/delete_row) Row ID.{SUPPORTS_FORMULA}" + ) + values: list[AutomationFieldValue] | None = None - type: Literal["ai_agent"] = Field( - ..., - description="Don't stop at this node. Chain some other action to use the AI output.", + # -- ai_agent -- + output_type: Literal["text", "choice"] | None = None + choices: list[str] | None = None + prompt: str | None = Field( + default=None, description=f"(ai_agent) Prompt.{SUPPORTS_FORMULA}" ) - output_type: Literal["text", "choice"] = Field(default="text") - choices: Optional[list[str]] = Field( - default=None, - description="List of choices if output_type is 'choice'", + + def to_update_service_dict(self, current_type: str) -> dict[str, Any] | None: + """Build a service kwargs dict from non-None fields. Returns None if no service fields set.""" + builder = _TO_UPDATE_SERVICE.get(current_type) + if builder is None: + return None + result = builder(self) + return result if result else None + + def get_formulas_to_update(self, orm_node: AutomationNode) -> dict[str, str] | None: + """Return a {key: description} dict of formulas to generate, or None.""" + fn = _GET_UPDATE_FORMULAS.get( + orm_node.service.get_type().type if orm_node.service else None + ) + return fn(self, orm_node) if fn else None + + def apply_direct_values(self, service: Service): + """Apply literal (non-$formula) values directly to the service.""" + fn = _APPLY_UPDATE_DIRECT.get(service.get_type().type if service else None) + if fn is not None: + fn(self, service) + + def update_service_with_formulas(self, service: Service, formulas: dict[str, str]): + """Write generated formulas back to the ORM service.""" + stype = service.get_type().type if service else None + fn = _UPDATE_FORMULAS.get(stype) + if fn is not None: + # Reuse the existing dispatch (expects ActionNodeCreate-like but works for our purposes) + fn(self, service, formulas) + else: + _default_update_formulas(service, formulas) + + +# -- to_update_service dispatch -- + + +def _email_update_service(n: "NodeUpdate") -> dict[str, Any]: + d = {} + if n.to_emails is not None: + d["to_email"] = literal_or_placeholder(n.to_emails) + if n.cc_emails is not None: + d["cc_email"] = literal_or_placeholder(n.cc_emails) + if n.bcc_emails is not None: + d["bcc_email"] = literal_or_placeholder(n.bcc_emails) + if n.subject is not None: + d["subject"] = literal_or_placeholder(n.subject) + if n.body is not None: + d["body"] = literal_or_placeholder(n.body) + if n.body_type is not None: + d["body_type"] = f"'{n.body_type}'" + return d + + +def _slack_update_service(n: "NodeUpdate") -> dict[str, Any]: + d = {} + if n.channel is not None: + d["channel"] = n.channel.lstrip("#") + if n.text is not None: + d["text"] = literal_or_placeholder(n.text) + return d + + +def _row_action_update_service(n: "NodeUpdate") -> dict[str, Any]: + d = {} + if n.table_id is not None: + d["table_id"] = n.table_id + return d + + +def _ai_agent_update_service(n: "NodeUpdate") -> dict[str, Any]: + d = {} + if n.prompt is not None: + d["ai_prompt"] = literal_or_placeholder(n.prompt) + if n.output_type is not None: + d["ai_output_type"] = n.output_type + if n.choices is not None: + d["ai_choices"] = n.choices + return d + + +_TO_UPDATE_SERVICE: dict[str, Callable] = { + "smtp_email": _email_update_service, + "slack_write_message": _slack_update_service, + "create_row": _row_action_update_service, + "update_row": _row_action_update_service, + "delete_row": _row_action_update_service, + "ai_agent": _ai_agent_update_service, +} + + +# -- get_formulas_to_update dispatch -- + + +def _email_update_formulas( + n: "NodeUpdate", orm_node: AutomationNode +) -> dict[str, str] | None: + fields = { + "to_emails": ("Recipients.", n.to_emails), + "cc_emails": ("CC.", n.cc_emails), + "bcc_emails": ("BCC.", n.bcc_emails), + "subject": ("Subject.", n.subject), + "body": ("Body.", n.body), + } + values = { + key: f"{base_desc} Value to resolve: {formula_desc(val)}" + for key, (base_desc, val) in fields.items() + if needs_formula(val) + } + return values or None + + +def _slack_update_formulas( + n: "NodeUpdate", orm_node: AutomationNode +) -> dict[str, str] | None: + if needs_formula(n.text): + return { + "text": f"The message content. Value to resolve: {formula_desc(n.text)}" + } + return None + + +def _row_action_update_formulas( + n: "NodeUpdate", orm_node: AutomationNode +) -> dict[str, str] | None: + from baserow_enterprise.assistant.tools.shared.formula_utils import ( + minimize_json_schema, ) - prompt: str + service = orm_node.service.specific + schema = service.get_type().generate_schema(service.specific) + values_by_id = {fv.field_id: fv.value for fv in (n.values or [])} + values = {} -class AiAgentNodeCreate( - AiAgentNodeBase, RefCreate, EdgeCreate, HasFormulasToCreateMixin -): - """Create an AI Agent action with edge configuration.""" + if needs_formula(n.row_id): + values["row_id"] = f"the row ID. Value to resolve: {formula_desc(n.row_id)}" - def to_orm_service_dict(self) -> dict[str, Any]: + for v in minimize_json_schema(schema).values(): + value = values_by_id.get(int(v["id"])) + if needs_formula(value): + desc = v["desc"] + f" Value to resolve: {formula_desc(value)}" + values[int(v["id"])] = {**v, "desc": desc} + + return values or None + + +def _ai_agent_update_formulas( + n: "NodeUpdate", orm_node: AutomationNode +) -> dict[str, str] | None: + if needs_formula(n.prompt): return { - "ai_choices": (self.choices or []) if self.output_type == "choice" else [], - "ai_prompt": f"'{self.prompt}'", - "ai_output_type": self.output_type, + "ai_prompt": f"The AI prompt. Value to resolve: {formula_desc(n.prompt)}" } + return None - def get_formulas_to_create(self, orm_node: AutomationNode) -> dict[str, str]: - return {"ai_prompt": self.prompt} +_GET_UPDATE_FORMULAS: dict[str, Callable] = { + "smtp_email": _email_update_formulas, + "slack_write_message": _slack_update_formulas, + "create_row": _row_action_update_formulas, + "update_row": _row_action_update_formulas, + "delete_row": _row_action_update_formulas, + "ai_agent": _ai_agent_update_formulas, +} -class AiAgentNodeItem(AiAgentNodeBase, Item): - """Existing AI Agent action with ID.""" +# -- apply_direct_values dispatch for update -- -AnyNodeCreate = Annotated[ - RouterNodeCreate - # actions - | SendEmailActionCreate - | SlackWriteMessageActionCreate - | CreateRowActionCreate - | UpdateRowActionCreate - | DeleteRowActionCreate - | AiAgentNodeCreate, - Field(discriminator="type"), -] -AnyNodeItem = ( - RouterNodeItem - # actions - | SendEmailActionItem - | CreateRowActionItem - | UpdateRowActionItem - | DeleteRowActionItem - | AiAgentNodeItem -) +def _row_action_update_apply_direct(n: "NodeUpdate", service: Service): + """Write literal (non-$formula) field values as quoted formulas.""" + _upsert_field_mappings( + service, + { + fv.field_id: (f"'{fv.value}'", True) + for fv in (n.values or []) + if not needs_formula(fv.value) + }, + ) + if n.row_id and not needs_formula(n.row_id): + service.row_id = f"'{n.row_id}'" + ServiceHandler().update_service(service.get_type(), service) + + +_APPLY_UPDATE_DIRECT: dict[str, Callable] = { + "create_row": _row_action_update_apply_direct, + "update_row": _row_action_update_apply_direct, + "delete_row": _row_action_update_apply_direct, +} + + +class ActionNodeItem(BaseModel): + """Existing action node with ID — flat structure, excludes None values.""" + + id: str + label: str + type: str + previous_node_ref: str | None = None + router_edge_label: str | None = None + + # (router) + edges: list[RouterBranch] | None = None + + # (smtp_email) + to_emails: str | None = None + cc_emails: str | None = None + bcc_emails: str | None = None + subject: str | None = None + body: str | None = None + body_type: str | None = None + + # (slack_write_message) + channel: str | None = None + text: str | None = None + + # (create_row, update_row, delete_row) + table_id: int | None = None + row_id: str | None = None + values: list[AutomationFieldValue] | None = None + + # (ai_agent) + output_type: str | None = None + choices: list[str] | None = None + prompt: str | None = None + + @model_serializer(mode="wrap") + def _exclude_none(self, handler): + return {k: v for k, v in handler(self).items() if v is not None} diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/workflow.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/workflow.py index 5470d91648..5fe625816c 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/workflow.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/types/workflow.py @@ -1,66 +1,18 @@ -from typing import Annotated, Literal - from pydantic import Field from baserow_enterprise.assistant.types import BaseModel -from .node import AnyNodeCreate, TriggerNodeCreate - - -class WorkflowEdgeCreate(BaseModel): - """Workflow edge connecting two nodes.""" - - type: Literal["edge"] - from_node_label: str = Field( - ..., - description="The label of the node where the edge starts", - ) - to_node_label: str = Field( - ..., - description="The label of the node where the edge ends", - ) - - -class WorkflowRouterEdgeCreate(WorkflowEdgeCreate): - """Workflow edge connecting to a router node with a branch label.""" - - type: Literal["router_branch"] - router_branch_label: str = Field( - default="", - description="The branch label for the router node edge", - ) - - -AnyWorkflowEdgeCreate = Annotated[ - WorkflowEdgeCreate, - WorkflowRouterEdgeCreate, - Field( - discriminator="type", - default="edge", - description=( - "The type of workflow edge. Use 'edge' in normal linear (a follows b) connections. " - "Use 'router_branch' when connecting to a router node with a branch label. " - ), - ), -] +from .node import ActionNodeCreate, TriggerNodeCreate class WorkflowCreate(BaseModel): """Base workflow model.""" - name: str = Field(..., description="The name of the workflow") - trigger: TriggerNodeCreate = Field( - ..., - description="The trigger node configuration for the workflow", - ) - nodes: list[AnyNodeCreate] = Field( + name: str = Field(..., description="Workflow name.") + trigger: TriggerNodeCreate = Field(..., description="The trigger node.") + nodes: list[ActionNodeCreate] = Field( default_factory=list, - description=( - "The nodes executed or evaluated once the trigger fires. " - "Every node must have only one incoming edge. If the previous node is a router, " - "the branch label must be specified for non-default branches. " - "Only if explicitly requested, this list can be empty." - ), + description="Action nodes executed after the trigger. Each node has one previous_node_ref.", ) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/utils.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/utils.py deleted file mode 100644 index 2a2dcac63a..0000000000 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/utils.py +++ /dev/null @@ -1,390 +0,0 @@ -from datetime import date, datetime -from typing import TYPE_CHECKING, Any, Tuple - -from django.contrib.auth.models import AbstractUser -from django.db import transaction -from django.utils.translation import gettext as _ - -import udspy -from loguru import logger -from pydantic import ConfigDict - -from baserow.contrib.automation.models import Automation -from baserow.contrib.automation.nodes.models import AutomationNode -from baserow.contrib.automation.nodes.registries import automation_node_type_registry -from baserow.contrib.automation.nodes.service import AutomationNodeService -from baserow.contrib.automation.workflows.models import AutomationWorkflow -from baserow.contrib.automation.workflows.service import AutomationWorkflowService -from baserow.core.formula import resolve_formula -from baserow.core.formula.registries import formula_runtime_function_registry -from baserow.core.formula.types import ( - BASEROW_FORMULA_MODE_ADVANCED, - BaserowFormulaObject, - FormulaContext, -) -from baserow.core.models import Workspace -from baserow.core.service import CoreService -from baserow.core.utils import to_path - -from .prompts import GENERATE_FORMULA_PROMPT -from .types import HasFormulasToCreateMixin, NodeBase, WorkflowCreate - -if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers - - -def _minimize_json_schema(schema) -> dict[str, dict[str, str]]: - """ - Generate a mapping between field ids and names/types from a JSON schema. - Useful when generating formulas to understand the provided context. - """ - - field_type_descriptions = { - "link_row": "the row ID as number or the primary field value as string", - "single_select": "the option ID as number or the value as string", - "multiple_select": "a comma separated list of option IDs or values as string", - "date": "a date string in ISO 8601 format", - "date_time": "a date-time string in ISO 8601 format", - "boolean": "true or false", - } - field_type_extra_info = { - "single_select": lambda meta: { - "select_options": meta.get("select_options", []) - }, - "multiple_select": lambda meta: { - "select_options": meta.get("select_options", []) - }, - "multiple_collaborators": lambda meta: { - "available_collaborators": meta.get("available_collaborators", []) - }, - } - - if schema.get("type") == "array": - return _minimize_json_schema(schema.get("items")) - elif schema.get("type") != "object": - raise ValueError("Schema must be of type object or array of objects") - - properties = schema.get("properties", {}) - mapping = {} - for key, prop in properties.items(): - metadata = prop.get("metadata") - if metadata: - field_type = metadata["type"] - mapping[key] = { - "id": metadata["id"], - "name": metadata["name"], - "type": field_type, - "desc": field_type_descriptions.get(field_type, ""), - } - if field_type in field_type_extra_info: - get_extra_info = field_type_extra_info[field_type] - mapping[key].update(get_extra_info(metadata)) - return mapping - - -def _create_example_from_json_schema(schema) -> Tuple[dict, dict]: - """ - Generate example data from a JSON schema. - Useful when generating formulas to provide example context data. - """ - - examples = { - "string": "text", - "number": 1, - "boolean": True, - "null": None, - "object": lambda prop: _create_example_from_json_schema(prop), - "array": lambda prop: [_create_example_from_json_schema(prop["items"])], - } - - if schema.get("type") == "array": - return [_create_example_from_json_schema(schema.get("items"))] - elif schema.get("type") != "object": - raise ValueError("Schema must be of type object or array of objects") - - properties = schema.get("properties", {}) - example = {} - for key, prop in properties.items(): - value = examples[prop.get("type")] - if callable(value): - example[key] = value(prop) - else: - example[key] = value - return example - - -class AssistantFormulaContext(FormulaContext): - def __init__(self): - self.context = {} - self.context_metadata = {} - super().__init__() - - def add_node_context( - self, - node_id: int | str, - node_context: dict[str, any], - context_metadata: dict[str, dict[str, str]] | None = None, - ): - """Update the formula context with new values.""" - - self.context.update({str(node_id): node_context}) - if context_metadata: - self.context_metadata.update({str(node_id): context_metadata}) - - def get_formula_context(self) -> dict[str, any]: - return {"previous_node": self.context} - - def get_context_metadata(self) -> dict[str, any]: - return self.context_metadata - - def __getitem__(self, key) -> any: - start, *key_parts = to_path(key) - if start != "previous_node": - raise KeyError( - f"Key '{key}' not found in context. Only 'previous_node' is supported at the root level." - ) - value = self.context - for kp in key_parts: - try: - value = value[int(kp) if isinstance(value, list) else kp] - except (KeyError, TypeError, ValueError): - available_keys = ( - list(value.keys()) - if isinstance(value, dict) - else ", ".join(map(str, range(len(value)))) - ) - raise KeyError( - f"Key '{kp}' of '{key}' not found in {value}, Available keys: {available_keys}" - ) - if not isinstance(value, (int, float, str, bool, date, datetime)): - raise ValueError( - f"Value for key '{key}' is not a valid type. " - f"Expected int, float, str, bool, date, or datetime. " - f"Got {type(value).__name__} instead. " - f"Make sure to only reference primitive types in the formula context." - ) - return value - - -def get_generate_formulas_tool(): - class RuntimeFormulaGenerator(udspy.Signature): - __doc__ = GENERATE_FORMULA_PROMPT - - fields_to_resolve: dict[str, dict[str, str]] = udspy.InputField( - desc=( - "The fields that need formulas to be generated. " - "If prefixed with [optional], the field is not mandatory." - ) - ) - context: dict[str, Any] = udspy.InputField( - desc="The available context to use in formula generation composed of previous nodes results." - ) - context_metadata: dict[str, Any] = udspy.InputField( - desc="Metadata about the context fields, with refs and names to assist in formula generation." - ) - feedback: str = udspy.InputField( - desc="Validation errors from previous attempt. Empty if first attempt." - ) - generated_formulas: dict[str, Any] = udspy.OutputField() - - model_config = ConfigDict(arbitrary_types_allowed=True) - - def check_formula(generated_formula: str, context: AssistantFormulaContext) -> str: - try: - resolve_formula( - BaserowFormulaObject.create( - formula=generated_formula, mode=BASEROW_FORMULA_MODE_ADVANCED - ), - formula_runtime_function_registry, - context, - ) - except Exception as exc: - raise ValueError(f"Generated formula is invalid: {str(exc)}") - return "ok, the formula is valid" - - def generate_node_formulas( - fields_to_resolve: dict, - context: AssistantFormulaContext, - max_retries: int = 3, - ) -> str: - """ - For every non-null input field in the node's schema, generate a formula - that fulfills the request, using the provided context object. - """ - - predict = udspy.Predict(RuntimeFormulaGenerator) - feedback = "" - for __ in range(max_retries): - result = predict( - fields_to_resolve=fields_to_resolve, - context=context.get_formula_context(), - context_metadata=context.get_context_metadata(), - feedback=feedback, - ) - # Ensure all the generated formulas are valid - valid_formulas = {} - generated_formulas = result.generated_formulas - for field_id, formula in generated_formulas.items(): - try: - check_formula(formula, context) - valid_formulas[field_id] = formula - except ValueError as exc: - feedback += f"Error for {field_id}, formula {formula} not valid: {str(exc)}\n" - - if len(valid_formulas) == len(generated_formulas): - return valid_formulas - - # Any valid formula is better than none - if valid_formulas: - return valid_formulas - else: - raise ValueError( - "Failed to generate any valid formulas after " - f"{max_retries} attempts. Feedback:\n{feedback}" - ) - - return generate_node_formulas - - -def get_automation( - automation_id: int, user: AbstractUser, workspace: Workspace -) -> Automation: - """Get automation with permission check.""" - - base_queryset = Automation.objects.filter(workspace=workspace) - automation = CoreService().get_application( - user, automation_id, base_queryset=base_queryset - ) - return automation - - -def get_workflow( - workflow_id: int, user: AbstractUser, workspace: Workspace -) -> AutomationWorkflow: - """Get workflow with permission check.""" - - workflow = AutomationWorkflowService().get_workflow(user, workflow_id) - if workflow.automation.workspace_id != workspace.id: - raise ValueError("Workflow not in workspace") - return workflow - - -def create_workflow( - user: AbstractUser, - automation: Automation, - workflow: "WorkflowCreate", - tool_helpers: "ToolHelpers", -) -> Tuple[AutomationWorkflow, dict[int | str, Any]]: - """ - Creates a new workflow in the given automation based on the provided definition. - """ - - tool_helpers.update_status( - _("Creating workflow '%(name)s'..." % {"name": workflow.name}) - ) - - orm_wf = AutomationWorkflowService().create_workflow( - user, automation.id, workflow.name - ) - - node_mapping = {} - - # First create the trigger node - orm_service_data = workflow.trigger.to_orm_service_dict() - node_type = automation_node_type_registry.get(workflow.trigger.type) - tool_helpers.update_status( - _("Creating trigger '%(label)s'..." % {"label": workflow.trigger.label}) - ) - orm_trigger = AutomationNodeService().create_node( - user, - node_type, - orm_wf, - label=workflow.trigger.label, - service=orm_service_data, - ) - - node_mapping[workflow.trigger.ref] = node_mapping[orm_trigger.id] = ( - orm_trigger, - workflow.trigger, - ) - - for node in workflow.nodes: - orm_service_data = node.to_orm_service_dict() - reference_node_id, output = node.to_orm_reference_node(node_mapping) - node_type = automation_node_type_registry.get(node.type) - tool_helpers.update_status( - _("Creating node '%(label)s'..." % {"label": node.label}) - ) - orm_node = AutomationNodeService().create_node( - user, - node_type, - orm_wf, - reference_node_id=reference_node_id, - output=output, - label=node.label, - service=orm_service_data, - ) - node_mapping[node.ref] = node_mapping[orm_node.id] = (orm_node, node) - - return orm_wf, node_mapping - - -def update_workflow_formulas( - workflow: "WorkflowCreate", - node_mapping: dict[int | str, Any], - tool_helpers: "ToolHelpers", -) -> None: - """ - Loop over all nodes and verify if they have formulas to update. If so, update the - formulas in the ORM node service providing the available context up to that node and - the user request for that node. - """ - - context = AssistantFormulaContext() - - def _get_service_schema(orm_node: AutomationNode): - return orm_node.service.get_type().generate_schema(orm_node.service.specific) - - def _update_context_with_node_data( - orm_node: AutomationNode, node_to_create: NodeBase - ): - schema = _get_service_schema(orm_node) - example = _create_example_from_json_schema(schema) - descr = _minimize_json_schema(schema) - descr["node_id"] = orm_node.id - descr["node_ref"] = node_to_create.ref - if getattr(node_to_create, "previous_node_ref", None): - descr["previous_node_ref"] = node_to_create.previous_node_ref - context.add_node_context(orm_node.id, example, descr) - - # Add the trigger context first - trigger_node = workflow.trigger - orm_trigger, __ = node_mapping[trigger_node.ref] - _update_context_with_node_data(orm_trigger, trigger_node) - - generate_formula_tool = get_generate_formulas_tool() - - def _generate_and_update_node_formulas( - node: HasFormulasToCreateMixin, orm_node: AutomationNode - ): - formulas_to_create = node.get_formulas_to_create(orm_node) - result = generate_formula_tool(formulas_to_create, context) - if result: - node.update_service_with_formulas(orm_node.service, result) - - # Node by node, generate formulas if needed and update the context with the node - # data, so following nodes can use it. - for node in workflow.nodes: - orm_node, __ = node_mapping[node.ref] - if isinstance(node, HasFormulasToCreateMixin): - tool_helpers.update_status( - _("Generating formulas for node '%(label)s'..." % {"label": node.label}) - ) - with transaction.atomic(): - try: - _generate_and_update_node_formulas(node, orm_node) - except Exception as exc: - logger.exception( - "Failed to generate formulas for node %s: %s", orm_node.id, exc - ) - _update_context_with_node_data(orm_node, node) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/core/tool_types.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/core/tool_types.py new file mode 100644 index 0000000000..83b1ca9391 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/core/tool_types.py @@ -0,0 +1,15 @@ +from baserow_enterprise.assistant.tools.registries import AssistantToolType + + +class CoreToolType(AssistantToolType): + type = "core" + + def get_tool_functions(self): + from .tools import TOOL_FUNCTIONS + + return TOOL_FUNCTIONS + + def get_toolset(self): + from .tools import core_toolset + + return core_toolset diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/core/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/core/tools.py index 8d94e02e70..f8dc64e0e2 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/core/tools.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/core/tools.py @@ -1,130 +1,173 @@ -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import Annotated, Any, Literal -from django.contrib.auth.models import AbstractUser from django.db import transaction from django.utils.translation import gettext as _ +from pydantic import Field +from pydantic_ai import RunContext +from pydantic_ai.toolsets import FunctionToolset + from baserow.core.actions import CreateApplicationActionType -from baserow.core.models import Workspace -from baserow.core.registries import application_type_registry from baserow.core.service import CoreService -from baserow_enterprise.assistant.tools.registries import AssistantToolType - -from .types import AnyBuilderItem, BuilderItem, BuilderItemCreate - -if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers - +from baserow_enterprise.assistant.deps import AgentMode, AssistantDeps + +from .types import BuilderItem, BuilderItemCreate, builder_type_registry + + +def list_builders( + ctx: RunContext[AssistantDeps], + builder_types: Annotated[ + list[Literal["database", "application", "automation", "dashboard"]] | None, + Field( + description="Filter: only return builders of these types. null to return all types." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + List databases, applications, automations, dashboards in the workspace. + + WHEN to use: You need to find databases, applications, automations, or dashboards in the workspace. Call this before creating builders to avoid duplicates. + WHAT it does: Lists all builders the user can access, optionally filtered by type. Max 20 results. + RETURNS: Dict of builders grouped by type, each with id, name, type. + DO NOT USE when: You already know the builder ID you need. + """ -def get_list_builders_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[], list[AnyBuilderItem]]: + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers + + tool_helpers.update_status( + _("Listing %(builder_types)ss...") + % { + "builder_types": builder_types[0] + if builder_types and len(builder_types) == 1 + else "builder" + } + ) + + applications_qs = CoreService().list_applications_in_workspace( + user, workspace, specific=False + ) + + builders = {} + for app in applications_qs: + try: + item = builder_type_registry.from_django_orm(app) + except KeyError: + continue + if not builder_types or item.type in builder_types: + builders.setdefault(item.type, []).append(item.model_dump()) + + if not builders: + return {} + + total = sum(len(v) for v in builders.values()) + max_items = 20 + if total > max_items: + truncated = {} + remaining = max_items + for btype, items in builders.items(): + truncated[btype] = items[:remaining] + remaining -= len(truncated[btype]) + if remaining <= 0: + break + return { + **truncated, + "_info": f"Showing {max_items} of {total} builders. " + "Use builder_types to filter.", + } + + return builders + + +def create_builders( + ctx: RunContext[AssistantDeps], + builders: Annotated[ + list[BuilderItemCreate], + Field(description="List of builders to create, each with a name and type."), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Create a new database, application, or automation. + + WHEN to use: User wants a new database, application, or automation created in the workspace. + WHAT it does: Creates one or more builders with the specified names and types. + RETURNS: List of created builders with id, name, type. + DO NOT USE when: A builder with that name may already exist — check with list_builders first. + HOW: Pick a unique, descriptive name. Check existing builders with list_builders to avoid duplicates. """ - Returns a function that lists all the builders the user has access to in the - current workspace. + + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers + + created_builders = [] + with transaction.atomic(): + for builder in builders: + tool_helpers.raise_if_cancelled() + tool_helpers.update_status( + _("Creating %(builder_type)s %(builder_name)s...") + % {"builder_type": builder.type, "builder_name": builder.name} + ) + builder_orm_instance = CreateApplicationActionType.do( + user, workspace, builder.get_orm_type(), name=builder.name + ) + builder.post_creation_hook(user, builder_orm_instance) + created_builders.append( + BuilderItem( + id=builder_orm_instance.id, + name=builder_orm_instance.name, + type=builder.type, + ).model_dump() + ) + + return {"created_builders": created_builders} + + +def switch_mode( + ctx: RunContext[AssistantDeps], + mode: Annotated[ + Literal["database", "application", "automation", "explain"], + Field( + description=( + "Target mode: 'database' for table/field/view/row ops, " + "'application' for page/element/data-source ops, " + "'automation' for workflow/node ops, " + "'explain' for answering Baserow questions." + ) + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> str: + """\ + Switch between domain modes (database, application, automation, explain). + + WHEN to use: Task needs tools from a different domain, or user asks a how-to question (→ "explain"). + WHAT it does: Changes the available toolset to the target domain's tools. + RETURNS: Confirmation of mode switch. + DO NOT USE when: Already in the requested mode. """ - def list_builders( - builder_types: list[ - Literal["database", "application", "automation", "dashboard"] - ] - | None = None, - ) -> list[AnyBuilderItem] | str: - """ - Lists all the builders the user can access (databases, applications, - automations, dashboards) in the current workspace. - - If `builder_types` is provided, only builders of that type are returned, - otherwise all builders are returned (default). - """ - - nonlocal user, workspace, tool_helpers - - tool_helpers.update_status( - _("Listing %(builder_types)ss...") - % { - "builder_types": builder_types[0] - if builder_types and len(builder_types) == 1 - else "builder" - } - ) + target = AgentMode(mode) + if ctx.deps.mode == target: + return f"Already in {target.value} mode." - applications_qs = CoreService().list_applications_in_workspace( - user, workspace, specific=False + ctx.deps.mode = target + if target == AgentMode.EXPLAIN: + return ( + "Switched to explain mode. " + "Call search_user_docs now to answer the user's question from the Baserow documentation." ) + return f"Switched to {target.value} mode." - builders = {} - for builder in applications_qs: - builder_type = application_type_registry.get_by_model( - builder.specific_class - ).type - if not builder_types or builder_type in builder_types: - builders.setdefault(builder_type, []).append( - BuilderItem( - id=builder.id, name=builder.name, type=builder_type - ).model_dump() - ) - - return builders if builders else "no builders found" - - return list_builders - - -class ListBuildersToolType(AssistantToolType): - type = "list_builders" - - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_list_builders_tool(user, workspace, tool_helpers) - - -def get_create_modules_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[str], dict[str, Any]]: - """ - Returns a function that creates a module in the current workspace. - """ - def create_builders(builders: list[BuilderItemCreate]) -> dict[str, Any]: - """ - Create a builder in the current workspace and return its ID and name. - - - name: desired name for the builder (better if unique in the workspace) - """ - - nonlocal user, workspace, tool_helpers - - created_builders = [] - with transaction.atomic(): - for builder in builders: - tool_helpers.update_status( - _("Creating %(builder_type)s %(builder_name)s...") - % {"builder_type": builder.type, "builder_name": builder.name} - ) - builder_orm_instance = CreateApplicationActionType.do( - user, workspace, builder.get_orm_type(), name=builder.name - ) - builder.post_creation_hook(user, builder_orm_instance) - created_builders.append( - BuilderItem( - id=builder_orm_instance.id, - name=builder_orm_instance.name, - type=builder.type, - ).model_dump() - ) - - return {"created_builders": created_builders} - - return create_builders - - -class CreateBuildersToolType(AssistantToolType): - type = "create_builders" - - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_create_modules_tool(user, workspace, tool_helpers) +TOOL_FUNCTIONS = [list_builders, create_builders, switch_mode] +core_toolset = FunctionToolset(TOOL_FUNCTIONS, max_retries=3) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/core/types.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/core/types.py index 87183d68dc..0612620349 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/core/types.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/core/types.py @@ -30,10 +30,14 @@ def get_orm_type(self) -> str: def from_django_orm(cls, orm_app: BaserowApplication) -> "BuilderItem": """Creates a BuilderItem instance from a Django ORM Application instance.""" + orm_type = application_type_registry.get_by_model(orm_app.specific_class).type + # The application_type_registry uses "builder" internally, but our + # Literal type expects "application". + type_mapping = {"builder": "application"} return cls( id=orm_app.id, name=orm_app.name, - type=application_type_registry.get_by_model(orm_app.specific_class).type, + type=type_mapping.get(orm_type, orm_type), ) def _post_creation_hook(self, user, builder_orm_instance): diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/agents.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/agents.py new file mode 100644 index 0000000000..a92d7c2000 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/agents.py @@ -0,0 +1,284 @@ +from typing import Any, Callable + +from django.contrib.auth.models import AbstractUser +from django.utils.translation import gettext as _ + +from pydantic import BaseModel as PydanticBaseModel +from pydantic import Field +from pydantic_ai import Agent, Tool +from pydantic_ai.toolsets import FunctionToolset +from pydantic_ai.usage import UsageLimits + +from baserow.contrib.database.api.formula.serializers import TypeFormulaResultSerializer +from baserow.contrib.database.fields.handler import FieldHandler +from baserow.contrib.database.fields.models import FormulaField +from baserow.core.models import Workspace +from baserow_premium.prompts import get_formula_docs + +from . import helpers +from .prompts import ( + FORMULA_AGENT_INSTRUCTIONS, + SAMPLE_ROW_AGENT_INSTRUCTIONS, + format_formula_fixer_prompt, + format_sample_rows_prompt, +) + +# --------------------------------------------------------------------------- +# Formula generation agent +# --------------------------------------------------------------------------- + + +class FormulaGenerationResult(PydanticBaseModel): + """Output model for the formula generation agent.""" + + table_id: int = Field( + description=( + "The ID of the table the formula is intended for. " + "Should be the same as current_table_id, unless the formula can " + "only be created in a different table." + ) + ) + field_name: str = Field( + description="The name of the formula field to be created. For a new field, it must be unique in the table." + ) + formula: str = Field( + description="The generated formula. Must be a valid Baserow formula." + ) + formula_type: str = Field( + description=( + "The type of the generated formula. Must be one of: text, long_text, " + "number, boolean, date, link_row, single_select, multiple_select, duration, array." + ) + ) + is_formula_valid: bool = Field( + description="Whether the generated formula is valid or not." + ) + error_message: str = Field( + default="", + description="If the formula is not valid, an error message explaining why.", + ) + + +formula_generation_agent: Agent[None, FormulaGenerationResult] = Agent( + output_type=FormulaGenerationResult, + instructions=FORMULA_AGENT_INSTRUCTIONS, + name="formula_generation_agent", +) + + +def get_formula_type_tool( + user: AbstractUser, workspace: Workspace +) -> Callable[[str], str]: + """ + Returns a function that validates a formula and returns its type. + """ + + def get_formula_type(table_id: int, field_name: str, formula: str) -> str: + """ + Returns the type of a formula. Raises an exception if the formula + is not valid. + **ALWAYS** call this to validate a formula is valid before returning it. + """ + + nonlocal user, workspace + + table = helpers.filter_tables(user, workspace).filter(id=table_id).first() + if not table: + raise ValueError(f"Table with ID {table_id} not found in workspace.") + + field = FormulaField(formula=formula, table=table, name=field_name, order=0) + field.recalculate_internal_fields(raise_if_invalid=True) + + result = TypeFormulaResultSerializer(field).data + if result["error"]: + field_names = list( + FieldHandler() + .get_base_fields_queryset() + .filter(table=table) + .values_list("name", flat=True) + ) + raise TypeError( + f"Invalid formula: {result['error']}. " + f"Available fields in table '{table.name}': {', '.join(field_names)}" + ) + + return result["formula_type"] + + return get_formula_type + + +def make_formula_fixer( + user: AbstractUser, workspace: Workspace, tool_helpers +) -> Callable: + """ + Returns a callback that tries to auto-generate a valid formula when the + LLM-provided one is invalid. Uses the ``formula_generation_agent``. + """ + + def fix_formula(table, field_name: str, original_formula: str) -> str | None: + database_tables = helpers.filter_tables(user, workspace).filter( + database_id=table.database_id + ) + schema = [ + t.model_dump() for t in helpers.get_tables_schema(database_tables, True) + ] + tool_helpers.update_status( + _("Fixing formula for %(name)s...") % {"name": field_name} + ) + + formula_type_tool = Tool(get_formula_type_tool(user, workspace)) + formula_toolset = FunctionToolset([formula_type_tool]) + prompt = format_formula_fixer_prompt( + field_name, original_formula, schema, get_formula_docs() + ) + from baserow_enterprise.assistant.model_profiles import ( + UTILITY, + get_model_settings, + get_model_string, + ) + + model = get_model_string() + result = formula_generation_agent.run_sync( + prompt, + model=model, + model_settings=get_model_settings(model, UTILITY), + toolsets=[formula_toolset], + usage_limits=UsageLimits(request_limit=20), + ) + if result.output.is_formula_valid: + return result.output.formula + return None + + return fix_formula + + +# --------------------------------------------------------------------------- +# Sample-row generation agent +# --------------------------------------------------------------------------- + + +def _find_reverse_link_row_fields(tables: list) -> dict[int, set[int]]: + """ + Identify auto-created reverse link_row fields across a set of tables. + + When a link_row field is created between two tables, Baserow auto-creates + a reverse field on the linked table. For sample-row generation we only + want the "owning" side (the explicitly created field) so the agent doesn't + face circular dependencies. + + For any bidirectional pair the field with the **higher** ID is the + auto-created reverse (it's created immediately after the explicit one). + + :returns: ``{table_id: {field_id, ...}}`` of reverse field IDs to exclude. + """ + + from baserow.contrib.database.fields.models import LinkRowField + + table_ids = {t.id for t in tables} + link_fields = LinkRowField.objects.filter( + table_id__in=table_ids, link_row_table_id__in=table_ids + ).select_related("link_row_related_field") + + reverse_ids: dict[int, set[int]] = {} + seen_pairs: set[tuple[int, int]] = set() + + for lf in link_fields: + related = lf.link_row_related_field + if related is None: + continue + pair = (min(lf.id, related.id), max(lf.id, related.id)) + if pair in seen_pairs: + continue + seen_pairs.add(pair) + + # The field with the higher ID is the auto-created reverse. + reverse = lf if lf.id > related.id else related + reverse_ids.setdefault(reverse.table_id, set()).add(reverse.id) + + return reverse_ids + + +def generate_sample_rows( + user: AbstractUser, + workspace: Workspace, + tool_helpers, + created_tables: list, + data_brief: str | None = None, +) -> dict[int, list[Any]]: + """ + Use an agent with ``create_rows`` tools to generate and insert + realistic sample rows for newly created tables. + + Instead of building one giant structured-output schema for all tables, + this gives the agent a ``create_rows_in_table_`` tool per table. + The agent decides the insertion order itself — it naturally creates + rows in linked-to tables first, sees the returned row IDs, and uses + them in link_row fields of dependent tables. + """ + + from baserow_enterprise.assistant.model_profiles import ( + SAMPLE, + get_model_settings, + get_model_string, + ) + + from .tools import _build_row_tools + + tool_helpers.update_status(_("Generating example rows for these new tables...")) + + # Build a create_rows tool for every table in the database (not just + # the newly created ones) so link_row fields can reference rows in + # pre-existing tables too. + database = created_tables[0].database + all_db_tables = list(database.table_set.all()) + + # Identify reverse (auto-created) link_row fields to exclude from the + # create schema. When a link_row is created between two tables in the + # same batch, Baserow auto-creates a reverse field. Including both + # sides creates a circular dependency the sample-row agent cannot + # resolve. For any bidirectional pair, the field with the higher ID + # is the auto-created reverse — we exclude it. + reverse_field_ids = _find_reverse_link_row_fields(all_db_tables) + + create_tools = [] + for table in all_db_tables: + # Exclude reverse link_row fields for this table + exclude = reverse_field_ids.get(table.id) + field_ids = None + if exclude: + all_field_ids = [ + fo["field"].id for fo in table.get_model().get_field_objects() + ] + field_ids = [fid for fid in all_field_ids if fid not in exclude] + row_tools = _build_row_tools( + user, workspace, tool_helpers, table, field_ids=field_ids + ) + create_tools.append(row_tools["create"]) + + # Build a description of each table so the agent knows the schemas. + schemas = helpers.get_tables_schema(created_tables, full_schema=True) + table_info = "\n".join(f"- {schema.model_dump()}" for schema in schemas) + + model = get_model_string() + sample_row_agent = Agent( + output_type=str, + instructions=SAMPLE_ROW_AGENT_INSTRUCTIONS, + tools=create_tools, + name="sample_row_agent", + ) + sample_row_agent.run_sync( + format_sample_rows_prompt(table_info, data_brief=data_brief), + model=model, + model_settings=get_model_settings(model, SAMPLE), + usage_limits=UsageLimits(request_limit=len(all_db_tables) * 3 + 2), + ) + + # Collect the rows that were actually inserted. + rows_created: dict[int, list] = {} + for table in created_tables: + table_model = table.get_model() + rows = list(table_model.objects.all()) + if rows: + rows_created[table.id] = rows + + return rows_created diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/helpers.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/helpers.py new file mode 100644 index 0000000000..1652f2a207 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/helpers.py @@ -0,0 +1,388 @@ +""" +Shared helpers for the database assistant tools. + +Contains query helpers, schema builders, and action orchestration used by +``tools.py`` and ``agents.py``. +""" + +from itertools import groupby +from typing import TYPE_CHECKING, Any, Callable + +from django.contrib.auth.models import AbstractUser +from django.db.models import Q, QuerySet +from django.utils.translation import gettext as _ + +from baserow.contrib.database.fields.actions import ( + CreateFieldActionType, + DeleteFieldActionType, + UpdateFieldActionType, +) +from baserow.contrib.database.fields.handler import FieldHandler +from baserow.contrib.database.fields.models import Field +from baserow.contrib.database.fields.registries import field_type_registry +from baserow.contrib.database.table.handler import TableHandler +from baserow.contrib.database.table.models import Table +from baserow.contrib.database.views.actions import CreateViewFilterActionType +from baserow.contrib.database.views.handler import ViewHandler +from baserow.contrib.database.views.models import View, ViewFilter +from baserow.core.db import specific_iterator +from baserow.core.models import Workspace +from baserow_enterprise.assistant.tools.database.types.table import TableItem + +from .types import ( + AnyViewFilterItemCreate, + FieldItem, + FieldItemCreate, + FieldItemUpdate, + InvalidFormulaFieldError, +) + +if TYPE_CHECKING: + from baserow_enterprise.assistant.deps import ToolHelpers + + +class ToolInputError(Exception): + """Raised when tool input is invalid — returned to the model as an error message.""" + + +def filter_tables(user: AbstractUser, workspace: Workspace) -> QuerySet[Table]: + """Return all tables visible to the user in the given workspace.""" + + return TableHandler().list_workspace_tables(user, workspace) + + +def get_table(user: AbstractUser, workspace: Workspace, table_id: int) -> Table: + """Get a single table by ID, raising ToolInputError if not found.""" + + try: + return filter_tables(user, workspace).get(id=table_id) + except Table.DoesNotExist: + raise ToolInputError( + f"Table with ID {table_id} not found. " + "Use get_tables_schema to find valid table IDs." + ) + + +def get_tables_schema( + tables: list[Table], + full_schema: bool = False, +) -> list[TableItem]: + """ + Build serialised schema descriptions for the given tables. + + :param tables: Tables to describe. + :param full_schema: If True include all fields, otherwise only primary + fields and relationships. + :returns: List of table descriptions, in the same order as the input tables. + """ + + q = Q(table__in=tables) + if not full_schema: + q &= Q(linkrowfield__isnull=False) | Q(primary=True) + + base_field_queryset = FieldHandler().get_base_fields_queryset() + fields = specific_iterator( + base_field_queryset.filter(q).order_by("table_id", "order"), + per_content_type_queryset_hook=( + lambda field, queryset: field_type_registry.get_by_model( + field + ).enhance_field_queryset(queryset, field) + ), + ) + + table_items: list[TableItem] = [] + tables_by_id = {table.id: table for table in tables} + for table_id, fields_in_table in groupby(fields, lambda f: f.table_id): + table_items.append(_get_table_schema(tables_by_id, table_id, fields_in_table)) + + # Preserve the input order + input_order = {t.id: i for i, t in enumerate(tables)} + table_items.sort(key=lambda t: input_order[t.id]) + return table_items + + +def _get_table_schema( + tables_by_id: dict[int, Table], table_id: int, fields_in_table: list[Field] +) -> TableItem: + """ + Build a TableItem schema description for a single table given its fields. + + :param tables_by_id: Mapping of table ID → table instance for all tables. + :param table_id: ID of the table to describe. + :param fields_in_table: Iterable of field instances belonging to the table. + :returns: TableItem describing the table and its fields. + """ + + fields_in_table = list(fields_in_table) + primary_field = next((f for f in fields_in_table if f.primary), None) + if primary_field is None: + raise ValueError(f"Table {table_id} has no primary field") + primary_field_item = FieldItem.from_django_orm(primary_field) + + table = tables_by_id[table_id] + + return TableItem( + id=table_id, + name=table.name, + primary_field=primary_field_item, + fields=[ + FieldItem.from_django_orm(f) + for f in fields_in_table + if f.id != primary_field.id + ], + ) + + +def create_fields( + user: AbstractUser, + table: Table, + field_items: list[FieldItemCreate], + tool_helpers: "ToolHelpers", + formula_fixer: Callable[[Table, str, str], str | None] | None = None, +) -> tuple[list[FieldItem], list[str], list[dict]]: + """ + Create fields in a table, handling formula errors with optional auto-fix. + + Fields are sorted so that dependencies are satisfied: regular fields first, + then link_row, lookup, and formula last. + + :param user: The acting user. + :param table: Target table. + :param field_items: Field definitions to create. + :param tool_helpers: Provides status updates and cancellation. + :param formula_fixer: Optional callback ``(table, name, formula) -> fixed`` + invoked when a formula field fails validation. + :returns: Tuple of (created fields, field error messages, formula error dicts). + """ + + from .types import InvalidFormulaFieldError + from .types.fields import FIELD_ORDER + + created_fields: list[FieldItem] = [] + formula_errors: list[dict] = [] + field_errors: list[str] = [] + + # Creation order: regular → link_row → lookup → formula. + # link_row before lookup so auto-created links exist for lookups. + # formula last so they can reference fields created earlier. + field_items = sorted(field_items, key=lambda f: FIELD_ORDER.get(f.type, 0)) + + for field_item in field_items: + tool_helpers.raise_if_cancelled() + tool_helpers.update_status( + _("Creating field %(field_name)s...") % {"field_name": field_item.name} + ) + + try: + new_field = CreateFieldActionType.do( + user, + table, + field_item.type, + **field_item.to_django_orm_kwargs(table, user=user), + ) + created_fields.append(FieldItem.from_django_orm(new_field)) + except InvalidFormulaFieldError as exc: + _fix_formula_field( + user, table, formula_fixer, created_fields, formula_errors, exc + ) + except Exception as e: + field_errors.append( + f"Error creating field {field_item.name} in table_{table.id}: {e}.\n" + f"Please retry recreating this field later, if important." + ) + return created_fields, field_errors, formula_errors + + +def _fix_formula_field( + user: AbstractUser, + table: Table, + formula_fixer: Callable[[Table, str, str], str | None] | None, + created_fields: list[FieldItem], + formula_errors: list[dict], + exc: InvalidFormulaFieldError, +): + """ + Attempt to fix an invalid formula field using the provided formula_fixer callback. + If successful, creates the field with the fixed formula. Otherwise, records the error. + + :param user: The acting user. + :param table: The table the field belongs to. + :param formula_fixer: Callback to attempt formula fixing. + :param created_fields: List to append successfully created fields to. + :param formula_errors: List to append error details to if fixing fails. + :param exc: The exception containing details about the invalid formula. + """ + + fixed = False + if formula_fixer: + try: + new_formula = formula_fixer(exc.table, exc.field_name, exc.formula) + if new_formula: + new_field = CreateFieldActionType.do( + user, + table, + "formula", + name=exc.field_name, + formula=new_formula, + ) + created_fields.append(FieldItem.from_django_orm(new_field)) + fixed = True + except Exception: + pass + if not fixed: + formula_errors.append( + { + "field_name": exc.field_name, + "formula": exc.formula, + "error": exc.error, + } + ) + + +def get_view(user: AbstractUser, workspace: Workspace, view_id: int) -> View: + """ + Fetch a view scoped to the user's workspace. + + :param user: The acting user. + :param workspace: Workspace the view must belong to. + :param view_id: ID of the view to retrieve. + """ + + return ViewHandler().get_view_as_user( + user, + view_id, + base_queryset=View.objects.filter(table__database__workspace=workspace), + ) + + +def create_view_filter( + user: AbstractUser, + orm_view: View, + table_fields: dict[int, Any], + view_filter_item: AnyViewFilterItemCreate, +) -> ViewFilter: + """ + Create a single view filter after validating the field type matches. + + :param user: The acting user. + :param orm_view: The view to add the filter to. + :param table_fields: Mapping of field ID → field instance for the table. + :param view_filter_item: The filter definition to create. + :raises ValueError: If the field is not found or its type doesn't match. + """ + + field = table_fields.get(view_filter_item.field_id) + if field is None: + raise ValueError( + f"Field {view_filter_item.field_id} not found for filter. " + f"Available field IDs: {sorted(table_fields.keys())}" + ) + field_type = field_type_registry.get_by_model(field.specific_class) + if field_type.type != view_filter_item.type: + raise ValueError( + f"Field '{field.name}' (id={field.id}) is type '{field_type.type}', " + f"but filter declared type '{view_filter_item.type}'" + ) + + filter_type = view_filter_item.get_django_orm_type(field) + filter_value = view_filter_item.get_django_orm_value( + field, timezone=user.profile.timezone + ) + + return CreateViewFilterActionType.do( + user, + orm_view, + field, + filter_type, + filter_value, + filter_group_id=None, + ) + + +def update_field( + user: AbstractUser, + workspace: Workspace, + field_update: "FieldItemUpdate", + formula_fixer: Callable[[Table, str, str], str | None] | None = None, +) -> FieldItem: + """ + Update an existing field. + + :param user: The acting user. + :param workspace: Workspace the field must belong to. + :param field_update: The update definition. + :param formula_fixer: Optional callback for fixing invalid formulas. + :returns: Updated field as FieldItem. + """ + + base_field = FieldHandler().get_field(field_update.field_id) + field = base_field.specific + + # Verify workspace access + filter_tables(user, workspace).filter(id=base_field.table_id).get() + field_type = field_type_registry.get_by_model(field).type + kwargs = field_update.to_update_kwargs(field_type) + + if not kwargs: + return FieldItem.from_django_orm(field) + + # Validate formula before updating + if "formula" in kwargs and kwargs["formula"]: + from baserow.contrib.database.fields.models import FormulaField + from baserow.core.formula.parser.exceptions import BaserowFormulaException + + try: + tmp = FormulaField( + formula=kwargs["formula"], + table=field.table, + name=kwargs.get("name", field.name), + order=0, + ) + tmp.recalculate_internal_fields(raise_if_invalid=True) + except BaserowFormulaException as e: + if formula_fixer: + fixed = formula_fixer( + field.table, + kwargs.get("name", field.name), + kwargs["formula"], + ) + if fixed: + kwargs["formula"] = fixed + else: + raise InvalidFormulaFieldError( + kwargs.get("name", field.name), + kwargs["formula"], + field.table, + str(e), + ) + else: + raise InvalidFormulaFieldError( + kwargs.get("name", field.name), + kwargs["formula"], + field.table, + str(e), + ) + + UpdateFieldActionType.do(user, field, **kwargs) + # Re-fetch the specific field to get the updated state + updated_field = FieldHandler().get_field(field_update.field_id).specific + return FieldItem.from_django_orm(updated_field) + + +def delete_field( + user: AbstractUser, + workspace: Workspace, + field_id: int, +) -> None: + """ + Delete (soft-delete / trash) a field. + + :param user: The acting user. + :param workspace: Workspace the field must belong to. + :param field_id: ID of the field to delete. + """ + + base_field = FieldHandler().get_field(field_id) + # Verify workspace access + filter_tables(user, workspace).filter(id=base_field.table_id).get() + DeleteFieldActionType.do(user, base_field.specific) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/prompts.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/prompts.py new file mode 100644 index 0000000000..cc1808ef37 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/prompts.py @@ -0,0 +1,64 @@ +""" +Prompt strings and templates for database sub-agents. +""" + +# --------------------------------------------------------------------------- +# Agent instructions +# --------------------------------------------------------------------------- + +FORMULA_AGENT_INSTRUCTIONS = ( + "Generates a Baserow formula based on the provided description and table schema. " + "Always validate the formula using the get_formula_type tool before returning it." +) + +SAMPLE_ROW_AGENT_INSTRUCTIONS = ( + "Create 5 realistic sample rows for each table using the " + "create_rows tools provided. " + "IMPORTANT: Fill EVERY field for every row. Do NOT leave any field " + "empty or null unless the data genuinely requires it. " + "Insertion order: start with tables that have NO link_row fields, " + "so you have real row IDs to reference. " + "Then create rows in dependent tables, using those IDs in link_row fields. " + "Reply with a short summary when done." +) + +# --------------------------------------------------------------------------- +# Prompt formatters +# --------------------------------------------------------------------------- + + +def format_formula_fixer_prompt( + field_name: str, + original_formula: str, + schema: list[dict], + formula_docs: str, +) -> str: + return ( + f"Fix this formula for field '{field_name}': {original_formula}\n\n" + f"Tables schema: {schema}\n\n" + f"Formula documentation: {formula_docs}" + ) + + +def format_formula_generation_prompt( + description: str, + schema: list[dict], + formula_docs: str, +) -> str: + return ( + f"Description: {description}\n\n" + f"Tables schema: {schema}\n\n" + f"Formula documentation: {formula_docs}" + ) + + +def format_sample_rows_prompt(table_info: str, data_brief: str | None = None) -> str: + prompt = ( + f"Create 5 sample rows for each of these tables:\n{table_info}" + "\n\nREMINDER: Fill ALL fields for every row — especially link_row " + "(relationship) fields. Use the row IDs returned by previous " + "create_rows calls as values for link_row fields in dependent tables." + ) + if data_brief: + prompt += f"\n\nUser instructions for the data: {data_brief}" + return prompt diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tool_types.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tool_types.py new file mode 100644 index 0000000000..2ee9cbd5cc --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tool_types.py @@ -0,0 +1,20 @@ +from baserow_enterprise.assistant.tools.registries import AssistantToolType + + +class DatabaseToolType(AssistantToolType): + type = "database" + + def get_tool_functions(self): + from .tools import TOOL_FUNCTIONS + + return TOOL_FUNCTIONS + + def get_toolset(self): + from .tools import database_toolset + + return database_toolset + + def get_routing_rules(self): + from .tools import ROUTING_RULES + + return ROUTING_RULES diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py index 56d2ed7a37..08f221d1cf 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py @@ -1,23 +1,29 @@ -from typing import TYPE_CHECKING, Any, Callable, Literal, Tuple +from typing import TYPE_CHECKING, Annotated, Any, Literal from django.contrib.auth.models import AbstractUser from django.db import transaction from django.utils.translation import gettext as _ -import udspy from loguru import logger -from pydantic import create_model +from pydantic import Field, create_model +from pydantic_ai import RunContext, Tool +from pydantic_ai.toolsets import FunctionToolset +from pydantic_ai.usage import UsageLimits -from baserow.contrib.database.api.formula.serializers import TypeFormulaResultSerializer from baserow.contrib.database.fields.actions import ( CreateFieldActionType, DeleteFieldActionType, UpdateFieldActionType, ) -from baserow.contrib.database.fields.models import FormulaField from baserow.contrib.database.fields.registries import field_type_registry from baserow.contrib.database.models import Database +from baserow.contrib.database.rows.actions import ( + CreateRowsActionType, + DeleteRowsActionType, + UpdateRowsActionType, +) from baserow.contrib.database.table.actions import CreateTableActionType +from baserow.contrib.database.table.models import Table from baserow.contrib.database.views.actions import ( CreateViewActionType, UpdateViewFieldOptionsActionType, @@ -25,848 +31,1200 @@ from baserow.contrib.database.views.handler import ViewHandler from baserow.core.models import Workspace from baserow.core.service import CoreService -from baserow_enterprise.assistant.tools.registries import AssistantToolType +from baserow_enterprise.assistant.deps import AssistantDeps +from baserow_enterprise.assistant.tools.toolset import inline_refs from baserow_enterprise.assistant.types import TableNavigationType, ViewNavigationType from baserow_premium.prompts import get_formula_docs -from . import utils +from . import helpers +from .agents import ( + formula_generation_agent, + generate_sample_rows, + get_formula_type_tool, + make_formula_fixer, +) +from .prompts import format_formula_generation_prompt from .types import ( - AnyFieldItem, - AnyFieldItemCreate, - AnyViewFilterItem, - AnyViewItemCreate, - BaseTableItem, + FieldItemCreate, + FieldItemUpdate, ListTablesFilterArg, TableItemCreate, ViewFiltersArgs, - view_item_registry, + ViewItem, + ViewItemCreate, + get_create_row_model, + get_link_row_hints, + get_update_row_model, ) if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers + from baserow_enterprise.assistant.deps import ToolHelpers +MAX_HINT_TABLES = 10 -def get_list_tables_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[int], list[str]]: - """ - Returns a function that lists all the tables in a given database the user has - access to in the current workspace. - """ - def list_tables(filters: ListTablesFilterArg) -> list[dict[str, Any]]: - """ - List tables that verifies the filters +def _no_tables_found_hint( + user: AbstractUser, workspace: Workspace, filters: "ListTablesFilterArg" +) -> str: + """Build an informative message when no tables match the filters. - - Always call this before creating new tables to avoid duplicates. - - Always call this to link existing tables when table IDs are not known. - """ + When the caller supplied a ``database_id`` that doesn't correspond to any + real database in the workspace, say so explicitly and list the first + available tables so the model can self-correct. + """ - nonlocal user, workspace, tool_helpers + parts: list[str] = [] - tables = ( - utils.filter_tables(user, workspace) - .filter(filters.to_orm_filter()) - .select_related("database") + # Check whether the requested database actually exists. + db_ref = filters.database_id_or_name + if db_ref is not None: + if isinstance(db_ref, int): + db_exists = Database.objects.filter(workspace=workspace, id=db_ref).exists() + else: + db_exists = Database.objects.filter( + workspace=workspace, name__icontains=db_ref + ).exists() + if not db_exists: + parts.append( + f"No database matching '{db_ref}' exists in this " + f"workspace. Note: workspace, application, and database IDs " + f"are different — make sure you are using a database ID." + ) + else: + parts.append( + f"Database '{db_ref}' exists but has no tables " + f"matching the provided filters." + ) + else: + parts.append("No tables found matching the provided filters.") + + # Fetch a sample of available tables across the workspace. + all_tables = ( + helpers.filter_tables(user, workspace) + .select_related("database") + .order_by("database_id", "id") + ) + total_tables = all_tables.count() + + if total_tables == 0: + parts.append("This workspace has no database tables at all.") + return " ".join(parts) + + sample = all_tables[:MAX_HINT_TABLES] + db_ids_seen: set[int] = set() + table_lines: list[str] = [] + for t in sample: + db_ids_seen.add(t.database_id) + table_lines.append( + f' - table_id={t.id}, table_name="{t.name}", ' + f'database_id={t.database_id}, database_name="{t.database.name}"' ) - databases = {} - database_names = [] - for table in tables: - if table.database_id not in databases: - databases[table.database_id] = { - "id": table.database_id, - "name": table.database.name, - "tables": [], - } - database_names.append(table.database.name) - databases[table.database_id]["tables"].append( - { - "id": table.id, - "name": table.name, - "database_id": table.database_id, - } - ) + total_dbs = Database.objects.filter(workspace=workspace).count() - tool_helpers.update_status( - _("Listing tables in %(database_names)s...") - % {"database_names": ", ".join(database_names)} + parts.append( + f"Available tables ({total_tables} table(s) across " + f"{total_dbs} database(s) in this workspace):" + ) + parts.append("\n".join(table_lines)) + + remaining_tables = total_tables - len(sample) + remaining_dbs = total_dbs - len(db_ids_seen) + if remaining_tables > 0: + parts.append( + f" ... and {remaining_tables} more table(s) in " + f"{remaining_dbs} more database(s)." ) - if len(databases) == 0: - return "No tables found" - elif len(databases) == 1: - # Return just the tables array when there's only one database - return list(databases.values())[0]["tables"] - else: - return list(databases.values()) - - return list_tables + return "\n".join(parts) -class ListTablesToolType(AssistantToolType): - type = "list_tables" - thinking_message = "Looking for tables..." +# --------------------------------------------------------------------------- +# Tool 1: list_tables +# --------------------------------------------------------------------------- - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_list_tables_tool(user, workspace, tool_helpers) +def list_tables( + ctx: RunContext[AssistantDeps], + filters: Annotated[ + ListTablesFilterArg, + Field(description="Filter criteria to narrow down which tables to list."), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> list[dict[str, Any]] | dict[str, Any]: + """\ + List tables, optionally filtered by database or name. -def get_tables_schema_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[int], list[str]]: - """ - Returns a function that lists all the fields in a given table the user has - access to in the current workspace. + WHEN to use: Before creating tables (to avoid duplicates), when you need table IDs, or to discover what tables exist in the workspace. + WHAT it does: Lists tables matching the filter criteria (database_id, name, starred), grouped by database. + RETURNS: Tables with id, name, database_id. Includes a hint with available tables if no match found. + DO NOT USE when: You already have the table IDs you need. """ - def get_tables_schema( - table_ids: list[int], - full_schema: bool, - ) -> list[dict[str, Any]]: - """ - Returns the schema of the specified tables, including their fields if requested. - Use `full_schema=True` to get all the fields, otherwise only the table names, - IDs, primary keys, and relationships will be included. + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - When to use: - Understanding table structure before creating/modifying fields - - Checking existing field names to avoid duplicates - Understanding table - relationships when creating link_row fields + tables = ( + helpers.filter_tables(user, workspace) + .filter(filters.to_orm_filter()) + .select_related("database") + ) - Remember: - Always call this before creating fields to avoid duplicate names - - Use get_rows_tools() for any row operations - not this one - """ + databases = {} + database_names = [] + for table in tables: + if table.database_id not in databases: + databases[table.database_id] = { + "id": table.database_id, + "name": table.database.name, + "tables": [], + } + database_names.append(table.database.name) + databases[table.database_id]["tables"].append( + { + "id": table.id, + "name": table.name, + "database_id": table.database_id, + } + ) - nonlocal user, workspace, tool_helpers + tool_helpers.update_status( + _("Listing tables in %(database_names)s...") + % {"database_names": ", ".join(database_names)} + ) - if not table_ids: - return [] + if len(databases) == 0: + return {"tables": [], "_info": _no_tables_found_hint(user, workspace, filters)} + elif len(databases) == 1: + # Return just the tables array when there's only one database + return list(databases.values())[0]["tables"] + else: + return list(databases.values()) + + +# --------------------------------------------------------------------------- +# Tool 2: get_tables_schema +# --------------------------------------------------------------------------- + + +def get_tables_schema( + ctx: RunContext[AssistantDeps], + table_ids: Annotated[ + list[int], Field(description="List of table IDs to retrieve schemas for.") + ], + full_schema: Annotated[ + bool, + Field( + description="If True, include all fields. If False, only table names, IDs, primary keys, and relationships." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Get field definitions for tables (full_schema=True for all fields). + + WHEN to use: Before creating/modifying fields to understand table structure and avoid duplicates. Also for understanding relationships when creating link_row fields. + WHAT it does: Returns the schema of specified tables. full_schema=True returns all fields with types and configs. full_schema=False returns only names, IDs, primary keys, and relationships. + RETURNS: Table schemas with field names, types, IDs, primary keys, and relationships. + DO NOT USE when: You need row data — use list_rows instead. For row operations, use load_row_tools, those tools already provide the necessary schema info in their instructions. + """ - tables = utils.filter_tables(user, workspace).filter(id__in=table_ids) + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - tool_helpers.update_status( - _("Inspecting %(table_names)s schema...") - % {"table_names": ", ".join(t.name for t in tables)} - ) + if not table_ids: + return {"tables_schema": []} - return { - "tables_schema": [ - ts.model_dump() for ts in utils.get_tables_schema(tables, full_schema) - ] - } + tables = helpers.filter_tables(user, workspace).filter(id__in=table_ids) - return get_tables_schema + tool_helpers.update_status( + _("Inspecting %(table_names)s schema...") + % {"table_names": ", ".join(t.name for t in tables)} + ) + return { + "tables_schema": [ + ts.model_dump() for ts in helpers.get_tables_schema(tables, full_schema) + ] + } + + +# --------------------------------------------------------------------------- +# Tool 3: list_rows +# --------------------------------------------------------------------------- + + +def list_rows( + ctx: RunContext[AssistantDeps], + table_id: Annotated[ + int, Field(description="The ID of the table to list rows from.") + ], + offset: Annotated[ + int, + Field( + description="Number of rows to skip for pagination. Use 0 for the first page." + ), + ], + limit: Annotated[ + int, Field(description="Maximum number of rows to return (max 20).") + ], + field_ids: Annotated[ + list[int] | None, + Field(description="List of field IDs to include, or null for all fields."), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Read rows from a table with pagination (max 20 per call). + + WHEN to use: User wants to see data in a table, or you need to check existing row values. + WHAT it does: Lists rows from a table with pagination (offset/limit) and optional field filtering. Max 20 rows per call. + RETURNS: Rows array with field values, plus total row count for pagination. + DO NOT USE when: You need to create, update, or delete rows — call load_row_tools first to get row manipulation tools. + """ -class GetTablesSchemaToolType(AssistantToolType): - type = "get_tables_schema" + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_tables_schema_tool(user, workspace, tool_helpers) + table = helpers.get_table(user, workspace, table_id) + tool_helpers.update_status( + _("Listing rows in %(table_name)s ") % {"table_name": table.name} + ) -def get_table_and_fields_tools_factory( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[list[TableItemCreate]], list[dict[str, Any]]]: - def create_fields( - table_id: int, fields: list[AnyFieldItemCreate] - ) -> list[AnyFieldItem]: - """ - Creates fields in the specified table. + rows_qs = table.get_model().objects.all() + rows = rows_qs[offset : offset + limit] - - Choose the most appropriate field type for each field. - - Field names must be unique within a table: check existing names - when needed and skip duplicates. - - For link_row fields, ensure the linked table already exists in - the same database; create it first if needed. - """ + response_model = create_model( + f"ResponseTable{table.id}RowWithFieldFilter", + id=(int, ...), + __base__=get_create_row_model(table, field_ids=field_ids), + ) - nonlocal user, workspace, tool_helpers + return { + "rows": [ + response_model.from_django_orm(row, field_ids).model_dump() for row in rows + ], + "total": rows_qs.count(), + } + + +# --------------------------------------------------------------------------- +# Tool 4: list_views +# --------------------------------------------------------------------------- + + +def list_views( + ctx: RunContext[AssistantDeps], + table_id: Annotated[ + int, Field(description="The ID of the table to list views for.") + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + List views in a table. + + WHEN to use: Before creating views (to avoid duplicate names), or to find existing view IDs. + WHAT it does: Lists all views in a table with their id, name, and type. + RETURNS: Views array with id, name, type configuration. + DO NOT USE when: You already have the view IDs you need. + """ - if not fields: - return [] + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - table = utils.filter_tables(user, workspace).get(id=table_id) + table = helpers.get_table(user, workspace, table_id) - with transaction.atomic(): - created_fields = utils.create_fields(user, table, fields, tool_helpers) - return {"created_fields": [field.model_dump() for field in created_fields]} - - def create_tables( - database_id: int, tables: list[TableItemCreate], add_sample_rows: bool = True - ) -> list[dict[str, Any]]: - """ - Creates tables with fields and rows in a database. **ALWAYS** add sample rows - unless explicitly asked otherwise. - - - table names should be unique in a database - - add meaningful fields with the appropriate types and relationships to other - existing tables. The reversed link_row fields will be created automatically. - - if add_sample_rows is True (default), add some example rows to each table - """ - - nonlocal user, workspace, tool_helpers - - if not tables: - return {"created_tables": []} - - database = CoreService().get_application( - user, - database_id, - specific=False, - base_queryset=Database.objects.filter(workspace=workspace), - ) + tool_helpers.update_status( + _("Listing views in %(table_name)s...") % {"table_name": table.name} + ) - created_tables = [] - with transaction.atomic(): - for i, table in enumerate(tables): - tool_helpers.update_status( - _("Creating table %(table_name)s...") % {"table_name": table.name} - ) + views = ViewHandler().list_views( + user, + table, + filters=False, + sortings=False, + decorations=False, + group_bys=False, + limit=100, + ) - created_table, __ = CreateTableActionType.do( - user, database, table.name, fill_example=False - ) - created_tables.append(created_table) + return {"views": [ViewItem.from_django_orm(view).model_dump() for view in views]} - primary_field_item = table.primary_field - primary_field = created_table.get_primary_field().specific - new_field_type = field_type_registry.get(primary_field_item.type) - UpdateFieldActionType.do( - user, - primary_field, - new_type_name=new_field_type.type, - name=primary_field_item.name, - ) - # Now that we have all the tables created, we can create the fields - notes = [] - for table, created_table in zip(tables, created_tables): - with transaction.atomic(): - try: - utils.create_fields(user, created_table, table.fields, tool_helpers) - except Exception as e: - notes.append( - f"Error creating fields for table_{created_table.id}: {e}.\n" - f"Please retry recreating fields for table_{created_table.id} manually." - ) +# --------------------------------------------------------------------------- +# Tool 5: create_tables +# --------------------------------------------------------------------------- - tool_helpers.navigate_to( - TableNavigationType( - type="database-table", - database_id=database.id, - table_id=created_table.id, - table_name=created_table.name, - ) - ) - if add_sample_rows: - instructions = [] +def _create_empty_tables( + user: AbstractUser, + database: Database, + tables: list[TableItemCreate], + tool_helpers: "ToolHelpers", +) -> list[Table]: + """Create bare tables and rename each one's auto-created primary field.""" + created: list[Table] = [] + with transaction.atomic(): + for table in tables: + tool_helpers.raise_if_cancelled() tool_helpers.update_status( - _("Preparing example rows for these new tables...") + _("Creating table %(table_name)s...") % {"table_name": table.name} ) - tools = [] - for table, created_table in zip(tables, created_tables): - create_rows_tool = utils.get_table_rows_tools( - user, workspace, tool_helpers, created_table - )["create"] - tools.append(create_rows_tool) - instructions.append( - f"- Create 5 example rows with realistic data for {created_table.name} (Id: {created_table.id}). " - "Fill every relationship with valid data when possible." - ) - - predictor = udspy.ReAct( - "instructions -> result", tools=tools, max_iters=len(tables * 2) + created_table, __ = CreateTableActionType.do( + user, database, table.name, fill_example=False ) - result = predictor(instructions=("\n".join(instructions))) - notes.append(result) - - return { - "created_tables": [ - BaseTableItem(id=t.id, name=t.name).model_dump() for t in created_tables - ], - "notes": notes, - } - - def load_table_and_fields_tools(): - """ - TOOL LOADER: Loads table and field creation tools for a database. - - After calling this loader, you will have access to: - - create_tables: Create new tables in a database with fields and sample rows - - create_fields: Add new fields to an existing table + created.append(created_table) + primary_field = created_table.get_primary_field().specific + UpdateFieldActionType.do(user, primary_field, name=table.primary_field_name) + return created - Use this when you need to create tables or add fields but don't have the tools. - """ - @udspy.module_callback - def _load_table_and_fields_tools(context): - nonlocal user, workspace, tool_helpers - - observation = ["New tools are now available.\n"] - - create_tool = udspy.Tool(create_tables) - new_tools = [create_tool] - observation.append("- Use `create_tables` to create tables in a database.") - - create_fields_tool = udspy.Tool(create_fields) - new_tools.append(create_fields_tool) - observation.append("- Use `create_fields` to create fields in a table.") +def _create_table_fields( + user: AbstractUser, + tables: list[TableItemCreate], + created_tables: list[Table], + tool_helpers: "ToolHelpers", + formula_fixer, +) -> list[str]: + """Create non-primary fields for each table; return collected notes/errors.""" + notes: list[str] = [] + for table, created_table in zip(tables, created_tables): + tool_helpers.raise_if_cancelled() + with transaction.atomic(): + # Drop any field whose name matches the primary field name — it's + # already set via UpdateFieldActionType.do() above. Including it in + # fields too is a common model mistake that would otherwise produce + # a "field already exists" error note. + non_primary_fields = [ + f + for f in table.fields + if f.name.lower() != table.primary_field_name.lower() + ] + _created, field_errors, formula_errors = helpers.create_fields( + user, + created_table, + non_primary_fields, + tool_helpers, + formula_fixer=formula_fixer, + ) + notes.extend(field_errors) + for err in formula_errors: + notes.append( + f"Invalid formula for field '{err['field_name']}' " + f"in table_{created_table.id}: {err['error']}. " + f"Use generate_formula to fix it." + ) + return notes + + +def create_tables( + ctx: RunContext[AssistantDeps], + database_id: Annotated[ + int, + Field( + ..., + description="The ID of the database to create tables in.", + ), + ], + tables: Annotated[ + list[TableItemCreate], + Field( + ..., + description="List of tables to create, each with a name, primary field, fields and relationships.", + ), + ], + add_sample_rows: Annotated[ + bool | str, + Field( + ..., + description="Controls sample row generation. True (default): generate realistic example rows. " + "A string: a brief describing what kind of data to create (e.g. 'Italian recipes with calorie counts'). " + "False: create empty tables, only use when the user explicitly asks for no sample data.", + ), + ], + thought: Annotated[ + str, + Field( + ..., + description="Brief reasoning for calling this tool.", + ), + ], +) -> dict[str, Any]: + """\ + Create tables with fields; generates sample rows by default. + + WHEN to use: User wants new tables created in a database. Always set add_sample_rows=true (or a descriptive string) unless explicitly asked for empty tables. + WHAT it does: Creates tables with fields, generates sample rows by default. Pass add_sample_rows=false ONLY when the user explicitly asks for empty tables. + Pass a string to guide the kind of sample data generated (e.g. "Italian recipes with calorie counts"). Table names must be unique. Reversed link_row fields are auto-created. + At the end, this tool automatically navigates the user to the last created table. + RETURNS: Created table schemas with all field IDs. Notes on any errors. + DO NOT USE when: Tables already exist — check with list_tables first. + HOW: Pass ALL related tables in a single call — link_row fields can reference other tables in the same call by name (they are created internally before fields are added). Choose appropriate field types for each column. + Use single_select/multiple_select with select_options for categorical data. The primary field is always text — pick a meaningful name for it. + """ - # Re-initialize the module with the new tools for the next iteration - context.module.init_module(tools=context.module._tools + new_tools) - return "\n".join(observation) + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - return _load_table_and_fields_tools + if not tables: + return {"created_tables": []} - return load_table_and_fields_tools + database = CoreService().get_application( + user, + database_id, + specific=False, + base_queryset=Database.objects.filter(workspace=workspace), + ) + created_tables = _create_empty_tables(user, database, tables, tool_helpers) -class TableAndFieldsToolFactoryToolType(AssistantToolType): - type = "table_and_fields_tool_factory" + formula_fixer = make_formula_fixer(user, workspace, tool_helpers) + notes = _create_table_fields( + user, tables, created_tables, tool_helpers, formula_fixer + ) - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_table_and_fields_tools_factory(user, workspace, tool_helpers) + last_table = created_tables[-1] + tool_helpers.navigate_to( + TableNavigationType( + type="database-table", + database_id=database.id, + table_id=last_table.id, + table_name=last_table.name, + ) + ) + created_rows = {} + if add_sample_rows: + try: + data_brief = add_sample_rows if isinstance(add_sample_rows, str) else None + created_rows = generate_sample_rows( + user, workspace, tool_helpers, created_tables, data_brief=data_brief + ) + except Exception as e: + logger.exception( + "[assistant] generate_sample_rows raised unexpectedly: {}", e + ) + notes.append(f"Error creating sample rows: {e}") + + # Return the full schema so callers don't need a separate + # get_tables_schema call to learn field IDs. + tables_schema = [ + ts.model_dump() + for ts in helpers.get_tables_schema(created_tables, full_schema=True) + ] + + response: dict[str, Any] = {"created_tables": tables_schema, "notes": notes} + if created_rows: + response["created_rows"] = { + f"Row IDs for newly created rows in table_{table_id}": [ + row.id for row in rows + ] + for table_id, rows in created_rows.items() + } -def get_list_rows_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[int, int, int, list[int] | None], list[dict[str, Any]]]: + return response + + +# --------------------------------------------------------------------------- +# Tool 6: create_fields +# --------------------------------------------------------------------------- + + +def create_fields( + ctx: RunContext[AssistantDeps], + table_id: Annotated[ + int, Field(description="The ID of the table to add fields to.") + ], + fields: Annotated[ + list[FieldItemCreate], + Field( + description="List of fields to create with their types and configurations." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Add fields to an existing table. + + WHEN to use: Adding fields to an existing table, or retrying failed field creation after create_tables partial failure. + WHAT it does: Creates fields in the specified table. Field names must be unique. For link_row fields, the linked table must already exist. + RETURNS: Created fields with id, name, type. Formula errors with hints if any. + DO NOT USE when: Creating a brand new table — use create_tables instead, which handles fields as part of table creation. + HOW: Call get_tables_schema first to see existing fields and avoid duplicates. For link_row fields, ensure the target table already exists. """ - Returns a function that lists rows in a given table the user has access to in the - current workspace. - """ - - def list_rows( - table_id: int, - offset: int = 0, - limit: int = 20, - field_ids: list[int] | None = None, - ) -> list[dict[str, Any]]: - """ - Lists rows in the specified table. - - - Use offset and limit for pagination. - - Use field_ids to limit the response to specific fields. - """ - nonlocal user, workspace, tool_helpers + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - table = utils.filter_tables(user, workspace).get(id=table_id) - - tool_helpers.update_status( - _("Listing rows in %(table_name)s ") % {"table_name": table.name} - ) + if not fields: + return {"created_fields": []} - rows_qs = table.get_model().objects.all() - rows = rows_qs[offset : offset + limit] + table = helpers.get_table(user, workspace, table_id) - response_model = create_model( - f"ResponseTable{table.id}RowWithFieldFilter", - id=(int, ...), - __base__=utils.get_create_row_model(table, field_ids=field_ids), + with transaction.atomic(): + formula_fixer = make_formula_fixer(user, workspace, tool_helpers) + created_fields, field_errors, formula_errors = helpers.create_fields( + user, table, fields, tool_helpers, formula_fixer=formula_fixer ) + result = {"created_fields": [field.model_dump() for field in created_fields]} + if field_errors: + result["field_errors"] = field_errors + if formula_errors: + for err in formula_errors: + err["hint"] = ( + "Use generate_formula to create a valid formula for this field." + ) + result["formula_errors"] = formula_errors + return result + + +# --------------------------------------------------------------------------- +# Tool 7: update_fields +# --------------------------------------------------------------------------- + + +def update_fields( + ctx: RunContext[AssistantDeps], + fields: Annotated[ + list[FieldItemUpdate], + Field( + description="List of field updates, each with a field_id and the properties to change." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Update existing fields (rename, change properties). + + WHEN to use: User wants to rename a field, change decimal places, update select options, or modify other field properties. + WHAT it does: Updates field properties. Cannot change field type or link_row targets — create a new field instead. + RETURNS: Updated fields with id, name, type and current properties. + DO NOT USE when: You need to change a field's type — delete and recreate it instead. + HOW: Call get_tables_schema first to see current field IDs and types. + """ - return { - "rows": [ - response_model.from_django_orm(row, field_ids).model_dump() - for row in rows - ], - "total": rows_qs.count(), - } - - return list_rows - - -class ListRowsToolType(AssistantToolType): - type = "list_rows" + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_list_rows_tool(user, workspace, tool_helpers) + if not fields: + return {"updated_fields": [], "errors": []} + updated = [] + errors = [] + formula_fixer = make_formula_fixer(user, workspace, tool_helpers) -def get_rows_tools_factory( - user: AbstractUser, - workspace: Workspace, - tool_helpers: "ToolHelpers", -) -> Callable[[int, list[dict[str, Any]]], list[Any]]: - def load_rows_tools( - table_ids: list[int], - operations: list[Literal["create", "update", "delete"]], - ) -> Tuple[str, list[Callable[[Any], Any]]]: - """ - TOOL LOADER: Loads row manipulation tools for specified tables. - Make sure to have the correct table IDs. - - After calling this loader, you will have access to table-specific tools: - - create_rows_in_table_X: Create new rows in table X - - update_rows_in_table_X: Update existing rows in table X by their IDs - - delete_rows_in_table_X: Delete rows from table X by their IDs - - Use this when you need to create, update, or delete rows but don't have - the tools. - Call with the table IDs and desired operations (create/update/delete). - """ - - @udspy.module_callback - def _load_rows_tools(context): - nonlocal user, workspace, tool_helpers - - tables = utils.filter_tables(user, workspace).filter(id__in=table_ids) - if not tables: - observation = [ - "No valid tables found for the given IDs. ", - "Make sure the table IDs are correct.", - ] - return "\n".join(observation) - - new_tools = [] - observation = ["New tools are now available.\n"] - for table in tables: - table_tools = utils.get_table_rows_tools( - user, workspace, tool_helpers, table + with transaction.atomic(): + for field_update in fields: + tool_helpers.raise_if_cancelled() + tool_helpers.update_status( + _("Updating field %(field_id)s...") + % {"field_id": field_update.field_id} + ) + try: + field_item = helpers.update_field( + user, workspace, field_update, formula_fixer=formula_fixer ) + updated.append(field_item.model_dump()) + except Exception as e: + errors.append(f"Error updating field {field_update.field_id}: {e}") + + result: dict[str, Any] = {"updated_fields": updated} + if errors: + result["errors"] = errors + return result + + +# --------------------------------------------------------------------------- +# Tool 8: delete_fields +# --------------------------------------------------------------------------- + + +def delete_fields( + ctx: RunContext[AssistantDeps], + field_ids: Annotated[ + list[int], + Field(description="List of field IDs to delete."), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Delete fields (moves them to trash). + + WHEN to use: User wants to remove fields from a table. + WHAT it does: Soft-deletes fields (moves to trash, can be restored). Primary fields cannot be deleted. + RETURNS: List of deleted field IDs. + DO NOT USE when: You want to change a field — use update_fields instead. + HOW: Call get_tables_schema first to confirm field IDs. + """ - observation.append(f"Table '{table.name}' (ID: {table.id}):") + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - if "create" in operations: - create_rows = table_tools["create"] - new_tools.append(create_rows) - observation.append(f"- Use {create_rows.name} to create new rows.") + if not field_ids: + return {"deleted_field_ids": [], "errors": []} - if "update" in operations: - update_rows = table_tools["update"] - new_tools.append(update_rows) - observation.append( - f"- Use {update_rows.name} to update existing rows by their IDs." - ) + deleted = [] + errors = [] - if "delete" in operations: - delete_rows = table_tools["delete"] - new_tools.append(delete_rows) - observation.append( - f"- Use {delete_rows.name} to delete rows by their IDs." - ) + with transaction.atomic(): + for field_id in field_ids: + tool_helpers.raise_if_cancelled() + tool_helpers.update_status( + _("Deleting field %(field_id)s...") % {"field_id": field_id} + ) + try: + helpers.delete_field(user, workspace, field_id) + deleted.append(field_id) + except Exception as e: + errors.append(f"Error deleting field {field_id}: {e}") + + result: dict[str, Any] = {"deleted_field_ids": deleted} + if errors: + result["errors"] = errors + return result + + +# --------------------------------------------------------------------------- +# Tool 9: create_views +# --------------------------------------------------------------------------- + + +def create_views( + ctx: RunContext[AssistantDeps], + table_id: Annotated[ + int, Field(description="The ID of the table to create views for.") + ], + views: Annotated[ + list[ViewItemCreate], + Field( + description="List of views to create (grid, form, gallery, kanban, calendar, timeline)." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Create views (grid, form, gallery, kanban, calendar, timeline). + + WHEN to use: User wants a new view (grid, form, gallery, kanban, calendar, timeline) on a table. + WHAT it does: Creates views in the table. View names must be unique. A default grid view is auto-created with every new table — no need to recreate it. + RETURNS: Created views with id, name, type configuration. + DO NOT USE when: The default grid view already meets the user's needs. Check existing views with list_views to avoid duplicates. + HOW: Each view type requires specific config. Form views: provide field_options listing every field to show (field_id, name, order, required). Kanban: set column_field_id to a single_select field. Calendar: set date_field_id to a date field. Timeline: set both start/end date fields. Gallery: optionally set cover_field_id to a file field. Call get_tables_schema first to get the field IDs you need. + """ - observation.append("") + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - # Re-initialize the module with the new tools for the next iteration - context.module.init_module(tools=context.module._tools + new_tools) - return "\n".join(observation) + if not views: + return {"created_views": []} - return _load_rows_tools + table = helpers.get_table(user, workspace, table_id) - return load_rows_tools + created_views = [] + with transaction.atomic(): + for view in views: + tool_helpers.raise_if_cancelled() + tool_helpers.update_status( + _("Creating %(view_type)s view %(view_name)s") + % {"view_type": view.type, "view_name": view.name} + ) + orm_view = CreateViewActionType.do( + user, + table, + view.type, + **view.to_django_orm_kwargs(table), + ) -class RowsToolFactoryToolType(AssistantToolType): - type = "rows_tool_factory" + field_options = view.field_options_to_django_orm() + if field_options: + UpdateViewFieldOptionsActionType.do(user, orm_view, field_options) - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_rows_tools_factory(user, workspace, tool_helpers) + created_views.append({"id": orm_view.id, **view.model_dump()}) + tool_helpers.navigate_to( + ViewNavigationType( + type="database-view", + database_id=table.database_id, + table_id=table.id, + view_id=created_views[0]["id"], + view_name=created_views[0]["name"], + view_type=created_views[0]["type"], + ) + ) -def get_list_views_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[int], list[dict[str, Any]]]: + return {"created_views": created_views} + + +# --------------------------------------------------------------------------- +# Tool 8: create_view_filters +# --------------------------------------------------------------------------- + + +def create_view_filters( + ctx: RunContext[AssistantDeps], + view_filters: Annotated[ + list[ViewFiltersArgs], + Field( + description="List of view filter configurations, each specifying a view ID and its filters." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, Any]: + """\ + Add filters to views. + + WHEN to use: User wants to filter a view to show only specific rows matching conditions. + WHAT it does: Creates filter conditions on one or more views. Supports multiple filters per view. + RETURNS: Created filters with id and configuration per view. + DO NOT USE when: The view doesn't exist yet — create it first with create_views. + HOW: Get the table schema first to know field IDs and types. Match filter type to field type. + + ## Value formats by type + + - text: string + - number: number + - date: ISO date string (mode=exact_date) or integer (mode=nr_days_ago etc.) or "" (mode=today etc.) + - single_select / multiple_select: list of option label strings (matched case-insensitively) + - link_row: row ID (integer) + - boolean: true / false """ - Returns a function that lists all the views in a given table the user has - access to in the current workspace. - """ - - def list_views(table_id: int) -> list[dict[str, Any]]: - """ - List views in the specified table. - - - Always call this for existing tables to avoid creating views with duplicate - names. - """ - nonlocal user, workspace, tool_helpers + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - table = utils.filter_tables(user, workspace).get(id=table_id) + if not view_filters: + return {"created_view_filters": []} + created_view_filters = [] + for vf in view_filters: + tool_helpers.raise_if_cancelled() + orm_view = helpers.get_view(user, workspace, vf.view_id) tool_helpers.update_status( - _("Listing views in %(table_name)s...") % {"table_name": table.name} + _("Creating filters in %(view_name)s...") % {"view_name": orm_view.name} ) - views = ViewHandler().list_views( - user, - table, - filters=False, - sortings=False, - decorations=False, - group_bys=False, - limit=100, - ) - - return { - "views": [ - view_item_registry.from_django_orm(view).model_dump() for view in views - ] - } - - return list_views - - -class ListViewsToolType(AssistantToolType): - type = "list_views" - - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_list_views_tool(user, workspace, tool_helpers) - + fields = {f.id: f for f in orm_view.table.field_set.all()} + created_filters = [] + with transaction.atomic(): + for filter in vf.filters: + try: + orm_filter = helpers.create_view_filter( + user, orm_view, fields, filter + ) + except ValueError as e: + logger.warning(f"Skipping filter creation: {e}") + continue + + created_filters.append({"id": orm_filter.id, **filter.model_dump()}) + created_view_filters.append({"view_id": vf.view_id, "filters": created_filters}) + + return {"created_view_filters": created_view_filters} + + +# --------------------------------------------------------------------------- +# Tool 9: generate_formula +# --------------------------------------------------------------------------- + + +def generate_formula( + ctx: RunContext[AssistantDeps], + database_id: Annotated[ + int, + Field( + description="The ID of the database containing the tables for the formula." + ), + ], + description: Annotated[ + str, + Field( + description="A natural language description of what the formula should compute." + ), + ], + save_to_field: Annotated[ + bool, + Field( + description="If true, save the formula to a field. If false, only return it. Should be true unless explicitly asked otherwise." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> dict[str, str]: + """\ + Generate a formula from a natural-language description and save it. + + WHEN to use: User needs a computed field (formulas, calculations, cross-table lookups). No need to inspect the schema first — this tool does it automatically. + WHAT it does: Generates a valid Baserow formula from a natural-language description. Finds the best table and fields automatically. Saves to a formula field by default (save_to_field=true). + RETURNS: Generated formula string, formula type, and field details (name, table, operation). + DO NOT USE when: The user wants a simple non-formula field — use create_fields instead. + HOW: Describe what the formula should compute in plain language. The tool auto-discovers the table schema — no need to inspect it first. + """ + from baserow_enterprise.assistant.model_profiles import ( + UTILITY, + get_model_settings, + get_model_string, + ) -def get_views_tool_factory( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[int, list[str]], list[str]]: - def create_view_filters( - view_filters: list[ViewFiltersArgs], - ) -> list[AnyViewFilterItem]: - """ - Creates filters in the specified views. - """ + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers - nonlocal user, workspace, tool_helpers + database_tables = helpers.filter_tables(user, workspace).filter( + database_id=database_id + ) + database_tables_schema = [ + t.model_dump() for t in helpers.get_tables_schema(database_tables, True) + ] - if not view_filters: - return [] + tool_helpers.update_status(_("Generating formula...")) - created_view_filters = [] - for vf in view_filters: - orm_view = utils.get_view(user, vf.view_id) - tool_helpers.update_status( - _("Creating filters in %(view_name)s...") % {"view_name": orm_view.name} - ) + formula_docs = get_formula_docs() + formula_type_tool = Tool(get_formula_type_tool(user, workspace)) + formula_toolset = FunctionToolset([formula_type_tool]) - fields = {f.id: f for f in orm_view.table.field_set.all()} - created_filters = [] - with transaction.atomic(): - for filter in vf.filters: - try: - orm_filter = utils.create_view_filter( - user, orm_view, fields, filter - ) - except ValueError as e: - logger.warning(f"Skipping filter creation: {e}") - continue - - created_filters.append({"id": orm_filter.id, **filter.model_dump()}) - created_view_filters.append( - {"view_id": vf.view_id, "filters": created_filters} - ) - - return {"created_view_filters": created_view_filters} + prompt = format_formula_generation_prompt( + description, database_tables_schema, formula_docs + ) - def create_views( - table_id: int, views: list[AnyViewItemCreate] - ) -> list[dict[str, Any]]: - """ - Creates views in the specified table. A default grid view showing all the rows - is created automatically when a table is created, no need to recreate it. + model = get_model_string() + agent_result = formula_generation_agent.run_sync( + prompt, + model=model, + model_settings=get_model_settings(model, UTILITY), + toolsets=[formula_toolset], + usage_limits=UsageLimits(request_limit=20), + ) + result = agent_result.output - - Choose the most appropriate view type for each view. - - View names must be unique within a table: check existing names when needed and - avoid duplicates. - """ + if not result.is_formula_valid: + raise Exception(f"Error generating formula: {result.error_message}") - nonlocal user, workspace, tool_helpers + table = next((t for t in database_tables if t.id == result.table_id), None) + if table is None: + raise Exception( + "The generated formula is intended for a different table " + f"than the current one. Table with ID {result.table_id} not found." + ) - if not views: - return [] + data = { + "formula": result.formula, + "formula_type": result.formula_type, + } + field_name = result.field_name - table = utils.filter_tables(user, workspace).get(id=table_id) + if save_to_field: + field = table.field_set.filter(name=field_name).first() + if field: + field = field.specific - created_views = [] with transaction.atomic(): - for view in views: - tool_helpers.update_status( - _("Creating %(view_type)s view %(view_name)s") - % {"view_type": view.type, "view_name": view.name} - ) - - orm_view = CreateViewActionType.do( + # Trash any existing non-formula field so it can be replaced, allowing + # the user to easily restore the original field if needed. + if field and field_type_registry.get_by_model(field).type != "formula": + DeleteFieldActionType.do(user, field) + field = None + + if field is None: + CreateFieldActionType.do( user, table, - view.type, - **view.to_django_orm_kwargs(table), + type_name="formula", + name=field_name, + formula=result.formula, + ) + operation = "field created" + else: + # Only update the formula of an existing formula field. + UpdateFieldActionType.do( + user, + field, + formula=result.formula, + ) + operation = "field updated" + + tool_helpers.navigate_to( + TableNavigationType( + type="database-table", + database_id=table.database_id, + table_id=table.id, + table_name=table.name, ) - - field_options = view.field_options_to_django_orm() - if field_options: - UpdateViewFieldOptionsActionType.do(user, orm_view, field_options) - - created_views.append({"id": orm_view.id, **view.model_dump()}) - - tool_helpers.navigate_to( - ViewNavigationType( - type="database-view", - database_id=table.database_id, - table_id=table.id, - view_id=created_views[0]["id"], - view_name=created_views[0]["name"], - view_type=created_views[0]["type"], ) - ) - - return {"created_views": created_views} - - def load_views_tools(): - """ - TOOL LOADER: Loads tools to manage views and filters - (grid, gallery, form, kanban, calendar and timeline). - - After calling this loader, you will be able to: - - create_views: Create grid, gallery, form, kanban, calendar and timeline views - - create_view_filters: Create filters for specific views to filter rows - - Use this when you need to create views or filters but don't have the tools yet. - """ - @udspy.module_callback - def _load_views_tools(context): - nonlocal user, workspace, tool_helpers - - observation = ["New tools are now available.\n"] - - create_tool = udspy.Tool(create_views) - new_tools = [create_tool] - observation.append("- Use `create_views` to create views.") - - create_filters_tool = udspy.Tool(create_view_filters) - new_tools.append(create_filters_tool) - observation.append( - "- Use `create_view_filters` to create filters in views." + data.update( + { + "table_id": table.id, + "table_name": table.name, + "field_name": result.field_name, + "operation": operation, + } ) - # Re-initialize the module with the new tools for the next iteration - context.module.init_module(tools=context.module._tools + new_tools) - return "\n".join(observation) - - return _load_views_tools + return data - return load_views_tools +# --------------------------------------------------------------------------- +# Dynamic row tools (create / update / delete) +# --------------------------------------------------------------------------- -class ViewsToolFactoryToolType(AssistantToolType): - type = "views_tool_factory" - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_views_tool_factory(user, workspace, tool_helpers) - - -def get_formula_type_tool( - user: AbstractUser, workspace: Workspace -) -> Callable[[str], str]: +def _build_row_tools( + user: AbstractUser, + workspace: Workspace, + tool_helpers: "ToolHelpers", + table: Table, + field_ids: list[int] | None = None, +) -> dict[str, Tool]: """ - Returns a function that returns the type of a formula. + Build pydantic-ai Tool objects for row CRUD on a single table. + + Returns a dict with keys ``"create"``, ``"update"``, ``"delete"``, each + containing a ready-to-use ``Tool`` whose schema is derived from the table's + fields. + + :param user: The acting user. + :param workspace: Current workspace. + :param tool_helpers: Provides status updates and cancellation. + :param table: The table to build row tools for. + :param field_ids: If given, only include these field IDs in the + create model (useful for excluding reverse link_row fields). """ - def get_formula_type(table_id: int, field_name: str, formula: str) -> str: - """ - Returns the type of a formula. Raises an exception if the formula is not valid. - **ALWAYS** call this to validate a formula is valid before returning it. - """ - - nonlocal user, workspace + row_model_for_create = get_create_row_model(table, field_ids=field_ids) + row_model_for_update = get_update_row_model(table) + link_row_hints = get_link_row_hints(row_model_for_create) - table = utils.filter_tables(user, workspace).get(id=table_id) - field = FormulaField(formula=formula, table=table, name=field_name, order=0) - field.recalculate_internal_fields(raise_if_invalid=True) + def _create_rows( + rows: list[row_model_for_create], + thought: Annotated[str, "Brief reasoning for calling this tool."], + ) -> dict[str, Any]: + """Create new rows in the specified table.""" - result = TypeFormulaResultSerializer(field).data - if result["error"]: - raise Exception(f"Invalid formula: {result['error']}") - - return result["formula_type"] - - return get_formula_type + if not rows: + return {"created_row_ids": []} + tool_helpers.update_status( + _("Creating rows in %(table_name)s ") % {"table_name": table.name} + ) -class FormulaGenerationSignature(udspy.Signature): - """ - Generates a Baserow formula based on the provided description and table schema. - """ + validated_rows = [row.to_django_orm() for row in rows] - description: str = udspy.InputField( - desc="A brief description of what the formula should do." - ) - tables_schema: dict = udspy.InputField( - desc="The schema of all the tables in the database." - ) - formula_documentation: str = udspy.InputField( - desc="Documentation about Baserow formulas and their syntax." - ) - table_id: int = udspy.OutputField( - desc=( - "The ID of the table the formula is intended for. " - "Should be the same as current_table_id, unless the formula can " - "only be created in a different table." - ) - ) - field_name: str = udspy.OutputField( - desc="The name of the formula field to be created. For a new field, it must be unique in the table." - ) - formula: str = udspy.OutputField( - desc="The generated formula. Must be a valid Baserow formula." - ) - formula_type: str = udspy.OutputField( - desc="The type of the generated formula. Must be one of: text, long_text, " - "number, boolean, date, link_row, single_select, multiple_select, duration, array." - ) - is_formula_valid: bool = udspy.OutputField( - desc="Whether the generated formula is valid or not." + with transaction.atomic(): + orm_rows = CreateRowsActionType.do(user, table, validated_rows) + + return {"created_row_ids": [r.id for r in orm_rows]} + + create_rows_tool = Tool( + _create_rows, + name=f"create_rows_in_table_{table.id}", + description=( + f"WHEN: Creating new rows in '{table.name}' (ID: {table.id}). " + f"WHAT: Inserts up to 20 rows with field values matching the table schema. " + f"RETURNS: Created row IDs. " + f"DO NOT USE: For other tables — each table has its own create tool. " + f"HOW: Fill EVERY field including ALL link_row (relationship) fields. Never skip a field unless data is genuinely unavailable." + f"{link_row_hints}" + ), + max_retries=2, ) - error_message: str = udspy.OutputField( - desc="If the formula is not valid, an error message explaining why." + create_rows_tool.function_schema.json_schema = inline_refs( + create_rows_tool.function_schema.json_schema ) + def _update_rows( + rows: list[row_model_for_update], + thought: Annotated[str, "Brief reasoning for calling this tool."], + ) -> dict[str, Any]: + """Update existing rows in the specified table.""" -def get_generate_database_formula_tool( - user: AbstractUser, - workspace: Workspace, - tool_helpers: "ToolHelpers", -) -> Callable[[str, int], dict[str, str]]: - """ - Returns a function that generates a formula for a given field in a table. - """ + if not rows: + return {"updated_row_ids": []} - def generate_database_formula( - database_id: int, - description: str, - save_to_field: bool = True, - ) -> dict[str, str]: - """ - Generate a database formula for a formula field. No need to inspect the schema - before, this tool will do it automatically and find the best table and fields to - use. - - - table_id: The database ID where the formula field is located. - - description: A brief description of what the formula should do. - - save_to_field: Whether to save the generated formula to a field with the given - name (default: True). If False, the formula will be generated but not saved - into a field. - """ - - nonlocal user, workspace, tool_helpers - - database_tables = utils.filter_tables(user, workspace).filter( - database_id=database_id + tool_helpers.update_status( + _("Updating rows in %(table_name)s ") % {"table_name": table.name} ) - database_tables_schema = [ - t.model_dump() for t in utils.get_tables_schema(database_tables, True) - ] - - tool_helpers.update_status(_("Generating formula...")) - - formula_docs = get_formula_docs() - formula_generator = udspy.ReAct( - FormulaGenerationSignature, - tools=[get_formula_type_tool(user, workspace)], - max_iters=20, - ) - result = formula_generator( - description=description, - tables_schema={"tables": database_tables_schema}, - formula_documentation=formula_docs, - ) + validated_rows = [row.to_django_orm() for row in rows] - if not result.is_formula_valid: - raise Exception(f"Error generating formula: {result.error_message}") + with transaction.atomic(): + orm_rows = UpdateRowsActionType.do(user, table, validated_rows).updated_rows + + return {"updated_row_ids": [r.id for r in orm_rows]} + + update_rows_tool = Tool( + _update_rows, + name=f"update_rows_in_table_{table.id}", + description=( + f"WHEN: Updating existing rows in '{table.name}' (ID: {table.id}) by row ID. " + f"WHAT: Updates specified fields on up to 20 rows. Only include fields you want to change — omit fields to keep them unchanged. " + f"RETURNS: Updated row IDs. " + f"DO NOT USE: For other tables — each table has its own update tool." + f"{link_row_hints}" + ), + max_retries=2, + ) + update_rows_tool.function_schema.json_schema = inline_refs( + update_rows_tool.function_schema.json_schema + ) - table = next((t for t in database_tables if t.id == result.table_id), None) - if table is None: - raise Exception( - "The generated formula is intended for a different table " - f"than the current one. Table with ID {result.table_id} not found." - ) + def _delete_rows( + row_ids: list[int], + thought: Annotated[str, "Brief reasoning for calling this tool."], + ) -> dict[str, Any]: + """Delete rows in the specified table.""" - data = { - "formula": result.formula, - "formula_type": result.formula_type, - } - field_name = result.field_name - - if save_to_field: - field = table.field_set.filter(name=field_name).first() - if field: - field = field.specific - - with transaction.atomic(): - # Trash any existing non-formula field so it can be replaced, allowing - # the user to easily restore the original field if needed. - if field and field_type_registry.get_by_model(field).type != "formula": - DeleteFieldActionType.do(user, field) - field = None - - if field is None: - CreateFieldActionType.do( - user, - table, - type_name="formula", - name=field_name, - formula=result.formula, - ) - operation = "field created" - else: - # Only update the formula of an existing formula field. - UpdateFieldActionType.do( - user, - field, - formula=result.formula, - ) - operation = "field updated" - - tool_helpers.navigate_to( - TableNavigationType( - type="database-table", - database_id=table.database_id, - table_id=table.id, - table_name=table.name, - ) - ) + if not row_ids: + return {"deleted_row_ids": []} - data.update( - { - "table_id": table.id, - "table_name": table.name, - "field_name": result.field_name, - "operation": operation, - } - ) + tool_helpers.update_status( + _("Deleting rows in %(table_name)s ") % {"table_name": table.name} + ) - return data + with transaction.atomic(): + DeleteRowsActionType.do(user, table, row_ids) + + return {"deleted_row_ids": row_ids} + + delete_rows_tool = Tool( + _delete_rows, + name=f"delete_rows_in_table_{table.id}", + description=( + f"WHEN: Deleting rows from '{table.name}' (ID: {table.id}) by row ID. " + f"WHAT: Permanently removes up to 20 specified rows. " + f"RETURNS: Deleted row IDs. " + f"DO NOT USE: For other tables — each table has its own delete tool." + ), + ) - return generate_database_formula + return { + "create": create_rows_tool, + "update": update_rows_tool, + "delete": delete_rows_tool, + } + + +# --------------------------------------------------------------------------- +# Tool 10: load_row_tools +# --------------------------------------------------------------------------- + + +def load_row_tools( + ctx: RunContext[AssistantDeps], + table_ids: Annotated[ + list[int], Field(description="List of table IDs to load row tools for.") + ], + operations: Annotated[ + list[Literal["create", "update", "delete"]], + Field( + description="Which row operations to enable: 'create', 'update', and/or 'delete'." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> str: + """\ + TOOL LOADER — unlocks create/update/delete row tools for directly manipulating DATABASE rows. No need to know the schema beforehand, the loaded tools include it. + + WHEN to use: You need to directly create, update, or delete rows in a database table. Must be called before any row manipulation. + WHAT it does: Unlocks table-specific tools and their schema: create_rows_in_table_X, update_rows_in_table_X, delete_rows_in_table_X for each table ID provided. The loaded tools include the full field schema — no need to call get_tables_schema. + RETURNS: Names of newly available tools. + DO NOT USE when: Row tools for these tables are already loaded from a previous call in this session. + DO NOT USE for builder workflow actions — if you want a button/form in an Application Builder page to create/update/delete rows, use create_actions instead. load_row_tools is for direct database manipulation, NOT for configuring app behavior. + HOW: Just call this with the table ID(s) and operations you need. The loaded row tools already contain the complete field schema in their parameters — do NOT call get_tables_schema or search_user_docs before or after this tool. + + EXAMPLES: + - "Create 5 rows" → load_row_tools([table_id], ["create"]) → create_rows_in_table_X(rows=[...]) + - "Update row 7" → load_row_tools([table_id], ["update"]) → update_rows_in_table_X(rows=[{id: 7, ...}]) + - "Delete rows 1-3" → load_row_tools([table_id], ["delete"]) → delete_rows_in_table_X(row_ids=[1,2,3]) + - To find linked row values, use list_rows with field_ids filter on the linked table. + """ + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers -class GenerateDatabaseFormulaToolType(AssistantToolType): - type = "generate_database_formula" + tables = helpers.filter_tables(user, workspace).filter(id__in=table_ids) + if not tables: + return ( + "No valid tables found for the given IDs. " + "Make sure the table IDs are correct." + ) - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_generate_database_formula_tool(user, workspace, tool_helpers) + new_tools: list[Tool] = [] + for table in tables: + table_tools = _build_row_tools(user, workspace, tool_helpers, table) + + if "create" in operations: + new_tools.append(table_tools["create"]) + if "update" in operations: + new_tools.append(table_tools["update"]) + if "delete" in operations: + new_tools.append(table_tools["delete"]) + + # Store new tools in dynamic_tools for the dynamic toolset + # to pick up on the next agent step + ctx.deps.dynamic_tools.extend(new_tools) + + tool_names = [t.name for t in new_tools] + return f"Tools loaded: {', '.join(tool_names)}" + + +# --------------------------------------------------------------------------- +# Module-level toolset +# --------------------------------------------------------------------------- + + +TOOL_FUNCTIONS = [ + list_tables, + get_tables_schema, + list_rows, + list_views, + create_tables, + create_fields, + update_fields, + delete_fields, + create_views, + create_view_filters, + generate_formula, + load_row_tools, +] +database_toolset = FunctionToolset(TOOL_FUNCTIONS, max_retries=3) + +ROUTING_RULES = """\ +- Check list_* before create_* to avoid duplicates. +- switch_mode: switch domain if task needs tools not in the current mode. +- Database row CRUD → call load_row_tools first (includes schema — skip get_tables_schema). +- create_tables: include ALL related tables in one call so link_row fields connect properly. Add sample rows unless told otherwise. +- create_rows: fill EVERY field including ALL link_row fields. +- After creating tables for an app/automation task, switch_mode back to continue building.""" diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/__init__.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/__init__.py index 7406ef8204..ddc43397e1 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/__init__.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/__init__.py @@ -1,5 +1,6 @@ from .base import * # noqa: F401, F403 from .fields import * # noqa: F401, F403 +from .rows import * # noqa: F401, F403 from .table import * # noqa: F401, F403 from .view_filters import * # noqa: F401, F403 from .views import * # noqa: F401, F403 diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py index 83e0219545..63d34e3656 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py @@ -1,35 +1,49 @@ from datetime import date, datetime -from baserow_enterprise.assistant.types import BaseModel +from dateutil import parser as _dateutil_parser -# Somehow LLMs struggle with dates -class Date(BaseModel): - year: int - month: int - day: int +def _normalize(value: str) -> str: + """Replace common separator variants so fromisoformat can parse them.""" - def to_django_orm(self): - return date(self.year, self.month, self.day).isoformat() + return value.replace("/", "-").strip() - @classmethod - def from_django_orm(cls, orm_date: date) -> "Date": - d = orm_date - return cls(year=d.year, month=d.month, day=d.day) +def parse_date(value: str) -> date: + """ + Parse a date string into a date object. -class Datetime(Date): - hour: int - minute: int + Tries ISO 8601 first, then falls back to dateutil for fuzzy formats + like ``Jan 5, 2023`` or ``05/01/2023``. + """ - def to_django_orm(self): - return datetime( - self.year, self.month, self.day, self.hour, self.minute - ).isoformat() + try: + return date.fromisoformat(_normalize(value)) + except ValueError: + return _dateutil_parser.parse(value).date() - @classmethod - def from_django_orm(cls, orm_datetime: datetime) -> "Datetime": - dt = orm_datetime - return cls( - year=dt.year, month=dt.month, day=dt.day, hour=dt.hour, minute=dt.minute - ) + +def parse_datetime(value: str) -> datetime: + """ + Parse a datetime string into a datetime object. + + Tries ISO 8601 first, then falls back to dateutil for fuzzy formats + like ``Jan 5, 2023 10:00 AM``. + """ + + try: + return datetime.fromisoformat(_normalize(value)) + except ValueError: + return _dateutil_parser.parse(value) + + +def format_date(value: date) -> str: + """Format a date as ISO 8601 (``YYYY-MM-DD``).""" + + return value.isoformat() + + +def format_datetime(value: datetime) -> str: + """Format a datetime as ISO 8601 (``YYYY-MM-DDTHH:MM``), without seconds.""" + + return value.strftime("%Y-%m-%dT%H:%M") diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py index 742e212c6c..8829a91639 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py @@ -1,253 +1,18 @@ -from typing import Annotated, Literal, Type +import json +from typing import Any, Literal from django.db.models import Q -from pydantic import Field - -from baserow.contrib.database.fields.models import ( - DateField, - FormulaField, - LinkRowField, - LookupField, - MultipleSelectField, - NumberField, - RatingField, - SingleSelectField, -) +from pydantic import Field, model_serializer, model_validator + from baserow.contrib.database.fields.models import Field as BaserowField from baserow.contrib.database.fields.registries import field_type_registry from baserow_enterprise.assistant.types import BaseModel -from baserow_enterprise.data_sync.hubspot_contacts_data_sync import LongTextField from baserow_premium.permission_manager import Table - -class FieldItemCreate(BaseModel): - """Base model for creating a new field (no ID).""" - - name: str = Field(...) - type: str = Field(...) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return {k: v for k, v in self.model_dump().items() if k not in {"id", "type"}} - - -class FieldItem(FieldItemCreate): - """Model for an existing field (with ID).""" - - id: int = Field(...) - - @classmethod - def from_django_orm(cls, orm_field: BaserowField) -> "FieldItem": - return cls( - id=orm_field.id, - name=orm_field.name, - type=field_type_registry.get_by_model(orm_field).type, - ) - - -# Event if type could be inferred, certain models (i.e. openai-gpt-oss-120b) requires -# all the fields to be required and can cause issues with optional fields, so we -# explicitly set them as required, even if seems unnecessary. - - -class BaseTextFieldItem(FieldItemCreate): - type: Literal["text"] = Field(..., description="Single line text field.") - - -class TextFieldItemCreate(BaseTextFieldItem): - """Model for creating a text field.""" - - -class TextFieldItem(BaseTextFieldItem, FieldItem): - """Model for an existing text field.""" - - -class BaseLongTextFieldItem(FieldItemCreate): - type: Literal["long_text"] = Field( - ..., - description="Multi-line text field. Ideal for descriptions, notes and long-form content.", - ) - rich_text: bool = Field( - default=True, - description="Whether the long text field supports rich text.", - ) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "long_text_enable_rich_text": self.rich_text, - } - - -class LongTextFieldItemCreate(BaseLongTextFieldItem): - """Model for creating a long text field.""" - - -class LongTextFieldItem(BaseLongTextFieldItem, FieldItem): - """Model for an existing long text field.""" - - @classmethod - def from_django_orm(cls, orm_field: LongTextField) -> "LongTextFieldItem": - field = orm_field.specific - return cls( - id=field.id, - name=field.name, - type="long_text", - rich_text=orm_field.long_text_enable_rich_text, - ) - - -class BaseNumberFieldItem(FieldItemCreate): - type: Literal["number"] = Field( - ..., description="Numeric field, with decimals and optional prefix/suffix." - ) - decimal_places: int = Field(default=2, description="The number of decimal places.") - suffix: str = Field( - default="", - description="An optional suffix to display after the number.", - ) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "number_decimal_places": self.decimal_places, - "number_suffix": self.suffix, - } - - -class NumberFieldItemCreate(BaseNumberFieldItem): - """Model for creating a number field.""" - - -class NumberFieldItem(BaseNumberFieldItem, FieldItem): - """Model for an existing number field.""" - - @classmethod - def from_django_orm(cls, orm_field: NumberField) -> "NumberFieldItem": - return cls( - id=orm_field.id, - name=orm_field.name, - type="number", - decimal_places=orm_field.number_decimal_places, - suffix=orm_field.number_suffix, - ) - - -class BaseRatingFieldItem(FieldItemCreate): - type: Literal["rating"] = Field( - ..., description="Rating field. Ideal for reviews or scores." - ) - max_value: int = Field( - default=5, description="The maximum value of the rating field." - ) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "max_value": self.max_value, - } - - -class RatingFieldItemCreate(BaseRatingFieldItem): - """Model for creating a rating field.""" - - -class RatingFieldItem(BaseRatingFieldItem, FieldItem): - """Model for an existing rating field.""" - - @classmethod - def from_django_orm(cls, orm_field: RatingField) -> "RatingFieldItem": - return cls( - id=orm_field.id, - name=orm_field.name, - type="rating", - max_value=orm_field.max_value, - ) - - -class BaseBooleanFieldItem(FieldItemCreate): - type: Literal["boolean"] = Field(..., description="Boolean field.") - - -class BooleanFieldItemCreate(BaseBooleanFieldItem): - """Model for creating a boolean field.""" - - -class BooleanFieldItem(BaseBooleanFieldItem, FieldItem): - """Model for an existing boolean field.""" - - -class BaseDateFieldItem(FieldItemCreate): - type: Literal["date"] = Field(..., description="Date or datetime field.") - include_time: bool = Field( - default=False, description="Whether the date field includes time." - ) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "date_include_time": self.include_time, - } - - -class DateFieldItemCreate(BaseDateFieldItem): - """Model for creating a date field.""" - - -class DateFieldItem(BaseDateFieldItem, FieldItem): - """Model for an existing date field.""" - - @classmethod - def from_django_orm(cls, orm_field: DateField) -> "DateFieldItem": - return cls( - id=orm_field.id, - name=orm_field.name, - type="date", - include_time=orm_field.date_include_time, - ) - - -class BaseLinkRowFieldItem(FieldItemCreate): - type: Literal["link_row"] = Field( - ..., description="Link row field. It creates relationships between tables." - ) - linked_table: str | int = Field( - ..., description="The ID or the name of the table this field links to." - ) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - if isinstance(self.linked_table, str): - q = Q(name=self.linked_table, database=table.database) - else: - q = Q(id=self.linked_table, database=table.database) - - try: - link_row_table = Table.objects.get(q) - except Table.DoesNotExist: - raise ValueError( - f"The linked_table '{self.linked_table}' does not exist in the database." - "Ensure you provide a valid table name or ID." - ) - - return {"name": self.name, "link_row_table": link_row_table} - - -class LinkRowFieldItemCreate(BaseLinkRowFieldItem): - """Model for creating a link row field.""" - - -class LinkRowFieldItem(BaseLinkRowFieldItem, FieldItem): - """Model for an existing link row field.""" - - @classmethod - def from_django_orm(cls, orm_field: LinkRowField) -> "BaseLinkRowFieldItem": - return cls( - id=orm_field.id, - name=orm_field.name, - type="link_row", - linked_table=orm_field.link_row_table_id, - ) - +# --------------------------------------------------------------------------- +# Shared types +# --------------------------------------------------------------------------- OptionColor = Literal[ "light-blue", @@ -301,8 +66,7 @@ class SelectOption(BaseModel): color: OptionColor -# Define a subset of colors to use when creating fields, so we don't confuse the model -# with too many options. +# Subset of colors for creation to avoid confusing the model OptionColorCreate = Literal[ "blue", "green", @@ -319,248 +83,591 @@ class SelectOption(BaseModel): class SelectOptionCreate(BaseModel): value: str - color: OptionColorCreate + color: OptionColorCreate | None = None + + +class InvalidFormulaFieldError(Exception): + """Raised when a formula field has an invalid formula.""" + + def __init__(self, field_name: str, formula: str, table: Table, error: str): + self.field_name = field_name + self.formula = formula + self.table = table + self.error = error + super().__init__(f"Invalid formula for field '{field_name}': {error}") + + +# --------------------------------------------------------------------------- +# Flat field types — single model, all type-specific fields optional +# --------------------------------------------------------------------------- + +FieldType = Literal[ + "text", + "long_text", + "number", + "rating", + "boolean", + "date", + "link_row", + "single_select", + "multiple_select", + "file", + "formula", + "lookup", +] +_TYPE_ALIASES: dict[str, str] = { + "string": "text", + "varchar": "text", + "rich_text": "long_text", + "richtext": "long_text", + "textarea": "long_text", + "integer": "number", + "int": "number", + "float": "number", + "decimal": "number", + "numeric": "number", + "checkbox": "boolean", + "bool": "boolean", + "datetime": "date", + "link": "link_row", + "relation": "link_row", + "relationship": "link_row", + "foreign_key": "link_row", + "fk": "link_row", + "select": "single_select", + "dropdown": "single_select", + "enum": "single_select", + "multi_select": "multiple_select", + "multiselect": "multiple_select", + "tags": "multiple_select", + "attachment": "file", + "upload": "file", + "image": "file", +} + +_SELECT_COLORS: list[str] = [ + "blue", + "green", + "cyan", + "orange", + "yellow", + "red", + "brown", + "purple", + "pink", + "gray", +] -class BaseSingleSelectFieldItem(FieldItemCreate): - type: Literal["single_select"] = Field( - ..., - description="Single select field. Allows users to choose one option from a list.", - ) +_KEY_ALIASES: dict[str, str] = { + "long_text_enable_rich_text": "rich_text", + "number_decimal_places": "decimal_places", + "number_suffix": "suffix", + "date_include_time": "include_time", + "link_row_table": "linked_table", + "link_row_table_id": "linked_table", + "through_field": "linked_table", + "through_field_id": "linked_table", + "target_field_id": "target_field", +} + +# Creation order: regular → link_row → lookup → formula +FIELD_ORDER: dict[str, int] = {"link_row": 1, "lookup": 2, "formula": 3} + +_FIELD_EXAMPLES: dict[str, dict] = { + "text": {"name": "Title", "type": "text"}, + "long_text": {"name": "Notes", "type": "long_text"}, + "number": {"name": "Price", "type": "number", "decimal_places": 2}, + "rating": {"name": "Stars", "type": "rating", "max_value": 5}, + "boolean": {"name": "Active", "type": "boolean"}, + "date": {"name": "Due Date", "type": "date"}, + "link_row": { + "name": "Project", + "type": "link_row", + "linked_table": "Projects", + }, + "single_select": { + "name": "Status", + "type": "single_select", + "options": [{"value": "Open", "color": "green"}], + }, + "multiple_select": { + "name": "Tags", + "type": "multiple_select", + "options": [{"value": "Important", "color": "red"}], + }, + "file": {"name": "Attachment", "type": "file"}, + "formula": { + "name": "Total", + "type": "formula", + "formula": "field('Price') * 2", + }, + "lookup": { + "name": "Client Name", + "type": "lookup", + "linked_table": "Clients", + "target_field": "Name", + }, +} + + +# --------------------------------------------------------------------------- +# to_django_orm builders: (FieldItemCreate, Table, user | None) -> dict +# --------------------------------------------------------------------------- + + +def _resolve_linked_table(linked_table_ref, table): + """Resolve a linked_table reference (name or ID) to a Table object.""" + + if isinstance(linked_table_ref, str): + q = Q(name=linked_table_ref, database=table.database) + else: + q = Q(id=linked_table_ref, database=table.database) + result = Table.objects.filter(q).order_by("id").first() + if not result: + raise ValueError( + f"Table '{linked_table_ref}' not found in the database. " + f"Ensure you provide a valid table name or ID." + ) + return result -class SingleSelectFieldItemCreate(BaseSingleSelectFieldItem): - options: list[SelectOptionCreate] = Field( - description="The list of options for the field. Use appropriate colors for each option.", - ) +def _simple_to_orm(f, table, user): + return {"name": f.name} - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "select_options": [ - {"id": -i, "value": option.value, "color": option.color} - for (i, option) in enumerate(self.options, start=1) - ], - } +def _long_text_to_orm(f, table, user): + return {"name": f.name, "long_text_enable_rich_text": f.rich_text} -class SingleSelectFieldItem(BaseSingleSelectFieldItem, FieldItem): - options: list[SelectOption] = Field( - description="The list of options for the field.", - ) - @classmethod - def from_django_orm( - cls, orm_field: SingleSelectField - ) -> "BaseSingleSelectFieldItem": - field = orm_field.specific - return cls( - id=field.id, - name=field.name, - type="single_select", - options=[ - SelectOption( - id=opt.id, - value=opt.value, - color=opt.color, - ) - for opt in field.select_options.all() - ], - ) +def _number_to_orm(f, table, user): + return { + "name": f.name, + "number_decimal_places": f.decimal_places, + "number_suffix": f.suffix, + "number_negative": True, + } -class BaseMultipleSelectFieldItem(FieldItemCreate): - type: Literal["multiple_select"] = Field( - ..., - description="Multiple select field. Allows users to choose multiple options from a list.", - ) +def _rating_to_orm(f, table, user): + return {"name": f.name, "max_value": f.max_value} - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "select_options": [ - {"id": -i, "value": option.value, "color": option.color} - for (i, option) in enumerate(self.options, start=1) - ], - } +def _date_to_orm(f, table, user): + return {"name": f.name, "date_include_time": f.include_time} -class MultipleSelectFieldItemCreate(BaseMultipleSelectFieldItem): - options: list[SelectOptionCreate] = Field( - description="The list of options for the field. Use appropriate colors for each option.", - ) +def _link_row_to_orm(f, table, user): + linked = _resolve_linked_table(f.linked_table, table) + return {"name": f.name, "link_row_table": linked} -class MultipleSelectFieldItem(BaseMultipleSelectFieldItem, FieldItem): - options: list[SelectOption] = Field( - description="The list of options for the field.", - ) - @classmethod - def from_django_orm( - cls, orm_field: MultipleSelectField - ) -> "BaseMultipleSelectFieldItem": - field = orm_field.specific - return cls( - id=field.id, - name=field.name, - type="multiple_select", - options=[ - SelectOption( - id=opt.id, - value=opt.value, - color=opt.color, - ) - for opt in field.select_options.all() - ], - ) +def _select_to_orm(f, table, user): + return { + "name": f.name, + "select_options": [ + { + "id": -i, + "value": opt.value, + "color": opt.color or _SELECT_COLORS[(i - 1) % len(_SELECT_COLORS)], + } + for i, opt in enumerate(f.options, start=1) + ], + } -class BaseFileFieldItem(FieldItemCreate): - type: Literal["file"] = Field(..., description="File field.") +def _formula_to_orm(f, table, user): + if f.formula: + from baserow.contrib.database.fields.models import FormulaField + from baserow.core.formula.parser.exceptions import BaserowFormulaException + try: + tmp = FormulaField(formula=f.formula, table=table, name=f.name, order=0) + tmp.recalculate_internal_fields(raise_if_invalid=True) + except BaserowFormulaException as e: + raise InvalidFormulaFieldError(f.name, f.formula, table, str(e)) -class FileFieldItemCreate(BaseFileFieldItem): - pass + return {"name": f.name, "formula": f.formula} -class FileFieldItem(BaseFileFieldItem, FieldItem): - pass +def _lookup_to_orm(f, table, user): + from baserow.contrib.database.fields.models import LinkRowField + linked = _resolve_linked_table(f.linked_table, table) -class FormulaFieldItemCreate(FieldItemCreate): - type: Literal["formula"] = Field(..., description="Formula field.") - formula: str = Field( - ..., - description="The formula to use in the field. It needs to be generated via the appropriate tool or use '' as placeholder.", + # Find existing link_row field pointing to linked table + through = ( + LinkRowField.objects.filter(table=table, link_row_table=linked) + .order_by("id") + .first() ) - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "formula": self.formula, - } + # Auto-create link_row if missing and user is available + if not through and user: + from baserow.contrib.database.fields.actions import CreateFieldActionType + through = CreateFieldActionType.do( + user, + table, + "link_row", + name=linked.name, + link_row_table=linked, + ) -class FormulaFieldItem(FormulaFieldItemCreate, FieldItem): - formula_type: str = Field(..., description="The type of the formula.") - array_formula_type: str | None = Field( - ..., - description=("If the formula type is 'array', the type of the array items."), - ) + if not through: + raise ValueError( + f"No link_row field to '{f.linked_table}' exists on this table. " + f"Create a link_row field first." + ) - @classmethod - def from_django_orm(cls, orm_field: FormulaField) -> "FormulaFieldItem": - field = orm_field.specific - return cls( - id=field.id, - name=field.name, - type="formula", - formula=field.formula, - formula_type=field.formula_type, - array_formula_type=field.array_formula_type, + data: dict[str, Any] = {"name": f.name, "through_field_id": through.id} + if isinstance(f.target_field, str): + data["target_field_name"] = f.target_field + else: + data["target_field_id"] = f.target_field + return data + + +_TO_DJANGO_ORM = { + "text": _simple_to_orm, + "boolean": _simple_to_orm, + "file": _simple_to_orm, + "long_text": _long_text_to_orm, + "number": _number_to_orm, + "rating": _rating_to_orm, + "date": _date_to_orm, + "link_row": _link_row_to_orm, + "single_select": _select_to_orm, + "multiple_select": _select_to_orm, + "formula": _formula_to_orm, + "lookup": _lookup_to_orm, +} + + +# --------------------------------------------------------------------------- +# from_django_orm builders: (orm_field) -> dict of extra kwargs +# --------------------------------------------------------------------------- + + +def _select_options_from_orm(orm_field): + from typing import get_args + + valid_colors = set(get_args(OptionColor)) + return [ + SelectOption( + id=opt.id, + value=opt.value, + color=opt.color if opt.color in valid_colors else "blue", ) + for opt in orm_field.specific.select_options.all() + ] + + +_FROM_DJANGO_ORM: dict[str, Any] = { + "long_text": lambda f: {"rich_text": f.specific.long_text_enable_rich_text}, + "number": lambda f: { + "decimal_places": f.number_decimal_places, + "suffix": f.number_suffix, + }, + "rating": lambda f: {"max_value": f.max_value}, + "date": lambda f: {"include_time": f.date_include_time}, + "link_row": lambda f: {"linked_table": f.link_row_table_id}, + "single_select": lambda f: {"options": _select_options_from_orm(f)}, + "multiple_select": lambda f: {"options": _select_options_from_orm(f)}, + "formula": lambda f: { + "formula": f.specific.formula, + "formula_type": f.specific.formula_type, + "array_formula_type": f.specific.array_formula_type, + }, + "lookup": lambda f: { + "through_field": f.specific.through_field_id, + "target_field": f.specific.target_field_id, + "through_field_name": f.specific.through_field_name, + "target_field_name": f.specific.target_field_name, + }, +} + + +# --------------------------------------------------------------------------- +# FieldItemCreate +# --------------------------------------------------------------------------- -class LookupFieldItemCreate(FieldItemCreate): - type: Literal["lookup"] = Field(..., description="Lookup field.") - through_field: int | str = Field( - ..., description="The ID of the link row field to lookup through." +class FieldItemCreate(BaseModel): + """Flat model for creating a field: name + type + type-specific options.""" + + name: str = Field(..., description="The name of the field.") + type: FieldType = Field(..., description="The field type.") + + # (long_text) + rich_text: bool = Field( + True, description="(long_text) Whether the field supports rich text." ) - target_field: int | str = Field( - ..., description="The ID of the field to lookup on the linked table." + # (number) + decimal_places: int = Field( + 0, description="(number) Decimal places (0, 1, 2, ...)." + ) + suffix: str = Field( + "", description="(number) Suffix displayed after the number, or ''." + ) + # (rating) + max_value: int = Field(5, description="(rating) Maximum rating value.") + # (date) + include_time: bool = Field( + False, description="(date) Whether the date includes time." + ) + # (link_row, lookup) + linked_table: str | int | None = Field( + None, + description="(link_row, lookup) ID or name of the linked table.", + ) + # (single_select, multiple_select) + options: list[SelectOptionCreate] | None = Field( + None, + description="(single_select, multiple_select) List of options with colors.", + ) + # (formula) + formula: str = Field( + "", description="(formula) The formula expression, or '' as placeholder." + ) + # (lookup) + target_field: int | str | None = Field( + None, description="(lookup) ID or name of the field to look up." ) - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - data = {"name": self.name} - if isinstance(self.through_field, str): - data["through_field_name"] = self.through_field - else: - data["through_field_id"] = self.through_field - - if isinstance(self.target_field, str): - data["target_field_name"] = self.target_field - else: - data["target_field_id"] = self.target_field + @model_validator(mode="before") + @classmethod + def _normalize(cls, data): + if not isinstance(data, dict): + return data + + # Normalize type aliases + raw_type = data.get("type") + if isinstance(raw_type, str): + data["type"] = _TYPE_ALIASES.get(raw_type, raw_type) + + # Normalize key aliases + for old_key, new_key in _KEY_ALIASES.items(): + if old_key in data and new_key not in data: + data[new_key] = data.pop(old_key) + + # Convert string options to SelectOptionCreate dicts + if "options" in data and isinstance(data["options"], list): + normalized = [] + for i, opt in enumerate(data["options"]): + if isinstance(opt, str): + normalized.append( + {"value": opt, "color": _SELECT_COLORS[i % len(_SELECT_COLORS)]} + ) + else: + normalized.append(opt) + data["options"] = normalized return data + # Required fields per type: {type: [(attr_name, display_name), ...]} + _REQUIRED_FIELDS: dict[str, list[tuple[str, str]]] = { + "link_row": [("linked_table", "linked_table")], + "single_select": [("options", "options")], + "multiple_select": [("options", "options")], + "lookup": [("linked_table", "linked_table"), ("target_field", "target_field")], + } -class LookupFieldItem(LookupFieldItemCreate, FieldItem): - through_field_name: str = Field( - ..., description="The name of the link row field to lookup through." - ) - target_field_name: str = Field( - ..., description="The name of the field to lookup on the linked table." - ) + @model_validator(mode="after") + def _validate_required_for_type(self): + required = self._REQUIRED_FIELDS.get(self.type) + if required: + missing = [name for attr, name in required if not getattr(self, attr)] + if missing: + raise ValueError( + f"{self.type} requires {', '.join(missing)}. " + f"Example: {json.dumps(_FIELD_EXAMPLES[self.type])}" + ) + return self - @classmethod - def from_django_orm(cls, orm_field: LookupField) -> "LookupFieldItem": - field = orm_field.specific - return cls( - id=field.id, - name=field.name, - type="lookup", - through_field=field.through_field_id, - target_field=field.target_field_id, - through_field_name=field.through_field_name, - target_field_name=field.target_field_name, - ) + def to_django_orm_kwargs(self, table: Table, user=None) -> dict[str, Any]: + builder = _TO_DJANGO_ORM.get(self.type, _simple_to_orm) + return builder(self, table, user) -AnyFieldItemCreate = Annotated[ - TextFieldItemCreate - | LongTextFieldItemCreate - | NumberFieldItemCreate - | RatingFieldItemCreate - | BooleanFieldItemCreate - | DateFieldItemCreate - | LinkRowFieldItemCreate - | SingleSelectFieldItemCreate - | MultipleSelectFieldItemCreate - | FileFieldItemCreate - | FormulaFieldItemCreate - | LookupFieldItemCreate, - Field(discriminator="type"), -] +# --------------------------------------------------------------------------- +# FieldItem (read-back) +# --------------------------------------------------------------------------- -AnyFieldItem = ( - TextFieldItem - | LongTextFieldItem - | NumberFieldItem - | RatingFieldItem - | BooleanFieldItem - | DateFieldItem - | LinkRowFieldItem - | SingleSelectFieldItem - | MultipleSelectFieldItem - | FileFieldItem - | FormulaFieldItem - | LookupFieldItem - | FieldItem -) - - -class FieldItemsRegistry: - _registry = { - "text": TextFieldItem, - "long_text": LongTextFieldItem, - "number": NumberFieldItem, - "date": DateFieldItem, - "boolean": BooleanFieldItem, - "rating": RatingFieldItem, - "link_row": LinkRowFieldItem, - "single_select": SingleSelectFieldItem, - "multiple_select": MultipleSelectFieldItem, - "file": FileFieldItem, - "formula": FormulaFieldItem, - "lookup": LookupFieldItem, - } - def from_django_orm(self, orm_field: Type[BaserowField]) -> FieldItem: +class FieldItem(BaseModel): + """Existing field with ID — flat structure matching FieldItemCreate.""" + + id: int = Field(...) + name: str = Field(..., description="The name of the field.") + type: str = Field(..., description="The field type.") + + # Type-specific (populated per type, others excluded via exclude_none) + rich_text: bool | None = None + decimal_places: int | None = None + suffix: str | None = None + max_value: int | None = None + include_time: bool | None = None + linked_table: int | None = None + options: list[SelectOption] | None = None + formula: str | None = None + formula_type: str | None = None + array_formula_type: str | None = None + through_field: int | None = None + target_field: int | None = None + through_field_name: str | None = None + target_field_name: str | None = None + + @model_serializer(mode="wrap") + def _exclude_none(self, handler): + return {k: v for k, v in handler(self).items() if v is not None} + + @classmethod + def from_django_orm(cls, orm_field: BaserowField) -> "FieldItem": field_type = field_type_registry.get_by_model(orm_field).type - field_class: FieldItem = self._registry.get(field_type, FieldItem) - return field_class.from_django_orm(orm_field) + kwargs: dict[str, Any] = { + "id": orm_field.id, + "name": orm_field.name, + "type": field_type, + } + builder = _FROM_DJANGO_ORM.get(field_type) + if builder: + kwargs.update(builder(orm_field)) + return cls(**kwargs) + + +# --------------------------------------------------------------------------- +# FieldItemUpdate +# --------------------------------------------------------------------------- + + +def _update_simple(f, field_type): + kwargs = {} + if f.name is not None: + kwargs["name"] = f.name + return kwargs + + +def _update_long_text(f, field_type): + kwargs = _update_simple(f, field_type) + if f.rich_text is not None: + kwargs["long_text_enable_rich_text"] = f.rich_text + return kwargs + + +def _update_number(f, field_type): + kwargs = _update_simple(f, field_type) + if f.decimal_places is not None: + kwargs["number_decimal_places"] = f.decimal_places + if f.suffix is not None: + kwargs["number_suffix"] = f.suffix + return kwargs + + +def _update_rating(f, field_type): + kwargs = _update_simple(f, field_type) + if f.max_value is not None: + kwargs["max_value"] = f.max_value + return kwargs + + +def _update_date(f, field_type): + kwargs = _update_simple(f, field_type) + if f.include_time is not None: + kwargs["date_include_time"] = f.include_time + return kwargs + + +def _update_select(f, field_type): + kwargs = _update_simple(f, field_type) + if f.options is not None: + kwargs["select_options"] = [ + { + "id": -i, + "value": opt.value, + "color": opt.color or _SELECT_COLORS[(i - 1) % len(_SELECT_COLORS)], + } + for i, opt in enumerate(f.options, start=1) + ] + return kwargs + + +def _update_formula(f, field_type): + kwargs = _update_simple(f, field_type) + if f.formula is not None: + kwargs["formula"] = f.formula + return kwargs + + +_TO_UPDATE_ORM = { + "text": _update_simple, + "boolean": _update_simple, + "file": _update_simple, + "long_text": _update_long_text, + "number": _update_number, + "rating": _update_rating, + "date": _update_date, + "link_row": _update_simple, + "single_select": _update_select, + "multiple_select": _update_select, + "formula": _update_formula, + "lookup": _update_simple, +} + + +class FieldItemUpdate(BaseModel): + """Flat model for updating a field: field_id + optional type-specific fields.""" + + field_id: int = Field(..., description="The ID of the field to update.") + name: str | None = Field(None, description="New name for the field.") + + # (long_text) + rich_text: bool | None = Field( + None, description="(long_text) Whether the field supports rich text." + ) + # (number) + decimal_places: int | None = Field( + None, description="(number) Decimal places (0, 1, 2, ...)." + ) + suffix: str | None = Field( + None, description="(number) Suffix displayed after the number." + ) + # (rating) + max_value: int | None = Field(None, description="(rating) Maximum rating value.") + # (date) + include_time: bool | None = Field( + None, description="(date) Whether the date includes time." + ) + # (single_select, multiple_select) + options: list[SelectOptionCreate] | None = Field( + None, + description="(single_select, multiple_select) List of options with colors.", + ) + # (formula) + formula: str | None = Field(None, description="(formula) The formula expression.") + @model_validator(mode="before") + @classmethod + def _normalize_keys(cls, data): + if not isinstance(data, dict): + return data + for old_key, new_key in _KEY_ALIASES.items(): + if old_key in data and new_key not in data: + data[new_key] = data.pop(old_key) + # Convert string options to SelectOptionCreate dicts + if "options" in data and isinstance(data["options"], list): + normalized = [] + for i, opt in enumerate(data["options"]): + if isinstance(opt, str): + normalized.append( + {"value": opt, "color": _SELECT_COLORS[i % len(_SELECT_COLORS)]} + ) + else: + normalized.append(opt) + data["options"] = normalized + return data -field_item_registry = FieldItemsRegistry() + def to_update_kwargs(self, field_type: str) -> dict[str, Any]: + """Build kwargs for UpdateFieldActionType.do() based on the field's current type.""" + builder = _TO_UPDATE_ORM.get(field_type, _update_simple) + return builder(self, field_type) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/rows.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/rows.py new file mode 100644 index 0000000000..f393275804 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/rows.py @@ -0,0 +1,383 @@ +""" +Dynamic Pydantic models for table row CRUD. + +Builds per-table create and update models whose fields mirror the table's +database columns, with converters to/from Django ORM representations. +""" + +from dataclasses import dataclass +from typing import Any, Callable, Literal, Type + +from django.core.exceptions import ValidationError +from django.db.models import Q + +from pydantic import ConfigDict, Field, create_model + +from baserow.contrib.database.fields.field_types import LinkRowFieldType +from baserow.contrib.database.fields.models import SelectOption as OrmSelectOption +from baserow.contrib.database.table.models import ( + FieldObject, + GeneratedTableModel, + Table, +) +from baserow_enterprise.assistant.types import BaseModel + +from .base import format_date, format_datetime, parse_date, parse_datetime + + +@dataclass +class FieldDefinition: + """ + Pydantic field specification for a single table column. + + When ``type`` is None the field is unsupported and will be skipped + during model construction. + """ + + type: Type | None = None + field_def: Any | None = None + to_django_orm: Callable[[Any], Any] | None = None + from_django_orm: Callable[[Any], Any] | None = None + + +# --------------------------------------------------------------------------- +# Per-type builder functions +# --------------------------------------------------------------------------- + +# Shared converters for text-like fields +_none_to_empty = lambda v: v if v is not None else "" # noqa: E731 + + +def _text_field_def(orm_field, orm_field_type): + return FieldDefinition( + str | None, + Field(..., description="Single-line text", title=orm_field.name), + _none_to_empty, + _none_to_empty, + ) + + +def _long_text_field_def(orm_field, orm_field_type): + return FieldDefinition( + str | None, + Field(..., description="Multi-line text", title=orm_field.name), + _none_to_empty, + _none_to_empty, + ) + + +def _number_field_def(orm_field, orm_field_type): + return FieldDefinition( + float | None, + Field(..., description="Number or None", title=orm_field.name), + ) + + +def _boolean_field_def(orm_field, orm_field_type): + return FieldDefinition( + bool, Field(..., description="Boolean", title=orm_field.name) + ) + + +def _date_field_def(orm_field, orm_field_type): + if orm_field.date_include_time: + return FieldDefinition( + str | None, + Field( + ..., + description="ISO datetime (YYYY-MM-DDTHH:MM) or None", + title=orm_field.name, + ), + lambda v: parse_datetime(v).isoformat() if v else None, + lambda v: format_datetime(v) if v is not None else None, + ) + return FieldDefinition( + str | None, + Field(..., description="ISO date (YYYY-MM-DD) or None", title=orm_field.name), + lambda v: parse_date(v).isoformat() if v else None, + lambda v: format_date(v) if v is not None else None, + ) + + +def _single_select_field_def(orm_field, orm_field_type): + choices = [option.value for option in orm_field.select_options.all()] + if not choices: + return FieldDefinition() # Unsupported: no options defined + + return FieldDefinition( + Literal[*choices] | None, + Field( + ..., + description=f"One of: {', '.join(choices)} or None", + title=orm_field.name, + ), + lambda v: v if v in choices else None, + lambda v: v.value if isinstance(v, OrmSelectOption) else v, + ) + + +def _multiple_select_field_def(orm_field, orm_field_type): + choices = [option.value for option in orm_field.select_options.all()] + if not choices: + return FieldDefinition() # Unsupported: no options defined + + return FieldDefinition( + list[Literal[*choices]], + Field( + ..., + description=f"List of any of: {', '.join(choices)} or empty list", + title=orm_field.name, + ), + lambda v: [opt for opt in v if opt in choices], + lambda v: [opt.value for opt in v.all()] if v is not None else [], + ) + + +def _link_row_field_def(orm_field, orm_field_type): + linked_model = orm_field.link_row_table.get_model() + linked_primary_key = linked_model.get_primary_field() + if linked_primary_key is None: + return FieldDefinition() + + linked_pk = linked_primary_key.db_column + examples = list( + linked_model.objects.exclude( + Q(**{f"{linked_pk}__isnull": True}) | Q(**{f"{linked_pk}__exact": ""}) + ).values_list("id", linked_pk)[:10] + ) + + def to_django_orm(value): + if isinstance(value, (str, int)): + value = [value] + if value is not None: + try: + return LinkRowFieldType().prepare_value_for_db(orm_field, value) + except ValidationError: + pass + return [] + + def from_django_orm(value): + values = [str(v) for v in value.all()] + if orm_field.link_row_multiple_relationships: + return values + return values[0] if values else None + + if orm_field.link_row_multiple_relationships: + desc = "List of values (as strings) or IDs (as integers) from the linked table or empty list." + field_type = list[str | int] | None + else: + desc = "Single value (as string) or ID (as integer) from the linked table." + field_type = str | int | None + if examples: + desc += ( + " Examples: " + + ", ".join(f"{{id:{v[0]}, value: `{v[1]}`}}" for v in examples) + + ", .." + ) + return FieldDefinition( + field_type, + Field(..., description=desc, title=orm_field.name), + to_django_orm, + from_django_orm, + ) + + +_FIELD_DEF_BUILDERS: dict[str, Callable] = { + "text": _text_field_def, + "long_text": _long_text_field_def, + "number": _number_field_def, + "boolean": _boolean_field_def, + "date": _date_field_def, + "single_select": _single_select_field_def, + "multiple_select": _multiple_select_field_def, + "link_row": _link_row_field_def, +} + + +def get_field_definition(field_object: FieldObject) -> FieldDefinition: + """ + Return a :class:`FieldDefinition` for a table field, or an empty + (unsupported) definition if the field type has no registered builder. + """ + + orm_field_type = field_object["type"] + builder = _FIELD_DEF_BUILDERS.get(orm_field_type.type) + if builder is None: + return FieldDefinition() + return builder(field_object["field"], orm_field_type) + + +# --------------------------------------------------------------------------- +# Helpers shared by create / update models +# --------------------------------------------------------------------------- + +# field_conversions maps field names to (db_column, to_orm, from_orm) tuples. +FieldConversions = dict[str, tuple[str, Callable | None, Callable | None]] + + +def _scan_table_fields( + table: Table, field_ids: list[int] | None = None +) -> tuple[dict[str, tuple], FieldConversions]: + """ + Scan a table's fields and return Pydantic field specs plus ORM converters. + + :param table: The table to scan. + :param field_ids: If given, only include fields with these IDs. + :returns: ``(field_definitions, field_conversions)`` dicts keyed by field name. + """ + + field_definitions: dict[str, tuple] = {} + field_conversions: FieldConversions = {} + + for field_object in table.get_model().get_field_objects(): + fd = get_field_definition(field_object) + if fd.type is None: + continue + if field_ids is not None and field_object["field"].id not in field_ids: + continue + + field = field_object["field"] + field_definitions[field.name] = (fd.type, fd.field_def) + field_conversions[field.name] = ( + field.db_column, + fd.to_django_orm, + fd.from_django_orm, + ) + + return field_definitions, field_conversions + + +def _convert_fields( + items: dict[str, Any], field_conversions: FieldConversions +) -> dict[str, Any]: + """Convert a {field_name: value} mapping to {db_column: orm_value}.""" + + orm_data: dict[str, Any] = {} + for key, value in items.items(): + if key == "id": + orm_data["id"] = value + continue + if key not in field_conversions: + continue + orm_key, converter, _ = field_conversions[key] + orm_data[orm_key] = converter(value) if converter else value + return orm_data + + +# --------------------------------------------------------------------------- +# Row models +# --------------------------------------------------------------------------- + + +def get_create_row_model( + table: Table, field_ids: list[int] | None = None +) -> type[BaseModel]: + """ + Build a Pydantic model for creating rows in the given table. + + The returned model has a field for each supported column, with + ``to_django_orm()`` and ``from_django_orm()`` for ORM conversion. + + :param table: The table whose columns define the model fields. + :param field_ids: If given, only include these field IDs. + """ + + field_definitions, field_conversions = _scan_table_fields(table, field_ids) + + class CreateRowModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + def to_django_orm(self) -> dict[str, Any]: + return _convert_fields(self.__dict__, field_conversions) + + @classmethod + def from_django_orm( + cls, orm_row: GeneratedTableModel, field_ids: list[int] | None = None + ) -> "CreateRowModel": + init_data = {} + if "id" in cls.model_fields: + init_data["id"] = orm_row.id + for field_object in orm_row.get_field_objects(): + field = field_object["field"] + if field.name not in field_conversions: + continue + if field_ids is not None and field.id not in field_ids: + continue + db_column, _, from_django_orm = field_conversions[field.name] + value = getattr(orm_row, db_column) + init_data[field.name] = ( + from_django_orm(value) if from_django_orm else value + ) + return cls(**init_data) + + return create_model( + f"Table{table.id}Row", + __module__=__name__, + __base__=CreateRowModel, + **field_definitions, + ) + + +def get_update_row_model(table: Table) -> type[BaseModel]: + """ + Build a Pydantic model for updating rows in the given table. + + All fields are optional with ``default=None``; only fields explicitly + provided during construction are included in ``to_django_orm()`` output, + so omitting a field means "don't change". + + :param table: The table whose columns define the model fields. + """ + + create_model_class = get_create_row_model(table) + _, field_conversions = _scan_table_fields(table) + + # All fields become Optional with default=None + update_fields = { + name: ( + info.annotation | None, + Field(default=None, description=info.description, title=info.title), + ) + for name, info in create_model_class.model_fields.items() + } + update_fields["id"] = (int, Field(..., description="The ID of the row to update")) + + class UpdateRowModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + def to_django_orm(self) -> dict[str, Any]: + # Only convert explicitly provided fields (pydantic tracks this) + explicitly_set = { + k: getattr(self, k) for k in self.model_fields_set if k != "id" + } + orm_data = _convert_fields(explicitly_set, field_conversions) + orm_data["id"] = self.id + return orm_data + + return create_model( + f"UpdateTable{table.id}Row", + __module__=__name__, + __base__=UpdateRowModel, + **update_fields, + ) + + +def get_link_row_hints(row_model: type[BaseModel]) -> str: + """ + Collect link_row example hints from a row model's field descriptions. + + Returns a formatted string for inclusion in tool descriptions, or an + empty string if no link_row fields with examples are found. + + :param row_model: A row model built by :func:`get_create_row_model`. + """ + + hints: list[str] = [] + for name, info in row_model.model_fields.items(): + desc = info.description or "" + if "linked table" in desc and "Examples:" in desc: + hints.append(f"{name} ({info.title}): {desc}") + + if not hints: + return "" + return " LINK_ROW fields: " + "; ".join(hints) + "." diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/table.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/table.py index 3703e877bc..8203c162dd 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/table.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/table.py @@ -1,10 +1,12 @@ +import json + from django.db.models import Q -from pydantic import Field +from pydantic import Field, ValidationError, model_validator from baserow_enterprise.assistant.types import BaseModel -from .fields import AnyFieldItem, AnyFieldItemCreate +from .fields import _FIELD_EXAMPLES, _TYPE_ALIASES, FieldItem, FieldItemCreate class BaseTableItemCreate(BaseModel): @@ -22,50 +24,91 @@ class BaseTableItem(BaseTableItemCreate): class TableItemCreate(BaseTableItemCreate): """Model for creating a table with fields.""" - primary_field: AnyFieldItemCreate = Field( + primary_field_name: str = Field( ..., - description="The primary field of the table. Preferbly a text field with a sensible name for a primary field of the table.", - ) - fields: list[AnyFieldItemCreate] = Field( - ..., description="The fields of the table." + description="The name of the primary field (text field).", ) + fields: list[FieldItemCreate] = Field(..., description="The fields of the table.") + + @model_validator(mode="wrap") + @classmethod + def _validate_with_field_examples(cls, data, handler): + try: + return handler(data) + except ValidationError as exc: + if not isinstance(data, dict): + raise + + table_name = data.get("name", "unknown") + fields_data = data.get("fields", []) + if not isinstance(fields_data, list): + raise + + # Collect field indices that have errors + error_field_indices: set[int] = set() + for error in exc.errors(): + loc = error.get("loc", ()) + if len(loc) >= 2 and loc[0] == "fields" and isinstance(loc[1], int): + error_field_indices.add(loc[1]) + + if not error_field_indices: + raise # No field-level errors, re-raise as-is + + error_fields = [] + error_types: set[str] = set() + for idx in sorted(error_field_indices): + if idx < len(fields_data) and isinstance(fields_data[idx], dict): + fd = fields_data[idx] + fname = fd.get("name", f"fields[{idx}]") + ftype = str(fd.get("type", "unknown")) + ftype = _TYPE_ALIASES.get(ftype, ftype) + error_fields.append(f"'{fname}' ({ftype})") + if ftype in _FIELD_EXAMPLES: + error_types.add(ftype) + + if not error_fields: + raise + + parts = [ + f"Table '{table_name}': invalid fields: {', '.join(error_fields)}." + ] + for ft in sorted(error_types): + parts.append(f" {ft}: {json.dumps(_FIELD_EXAMPLES[ft])}") + + raise ValueError("\n".join(parts)) from None class TableItem(BaseTableItem): """Model for an existing table with fields.""" - primary_field: AnyFieldItem = Field( - ..., description="The primary field of the table." - ) - fields: list[AnyFieldItem] = Field(..., description="The fields of the table.") + primary_field: FieldItem = Field(..., description="The primary field of the table.") + fields: list[FieldItem] = Field(..., description="The fields of the table.") class ListTablesFilterArg(BaseModel): - database_ids: list[int] | None = Field( - default=None, - description="A list of database_ids to filter. None to exclude this filter", - ) - database_names: list[str] | None = Field( - default=None, - description="A list of database_names to filter. None to exclude this filter", - ) - table_ids: list[int] | None = Field( + database_id_or_name: int | str | None = Field( default=None, - description="A list of table ids to filter. None to exclude this filter", + description="The ID or name of the database to filter. null to exclude this filter.", ) - table_names: list[str] | None = Field( + table_ids_or_names: list[int | str] | None = Field( default=None, - description="A list of table names to filter. None to exclude this filter", + description="A list of table ids or names to filter in an OR fashion. null to exclude this filter.", ) def to_orm_filter(self) -> Q: q_filter = Q() - if self.database_ids: - q_filter &= Q(database_id__in=self.database_ids) - if self.database_names: - q_filter &= Q(database__name__in=self.database_names) - if self.table_ids: - q_filter &= Q(id__in=self.table_ids) - if self.table_names: - q_filter &= Q(name__in=self.table_names) + if isinstance(self.database_id_or_name, int): + q_filter &= Q(database_id=self.database_id_or_name) + elif isinstance(self.database_id_or_name, str): + q_filter &= Q(database__name__icontains=self.database_id_or_name) + if self.table_ids_or_names: + combined = Q() + ids = [item for item in self.table_ids_or_names if isinstance(item, int)] + names = [item for item in self.table_ids_or_names if isinstance(item, str)] + if ids: + combined |= Q(id__in=ids) + if names: + for name in names: + combined |= Q(name__icontains=name) + q_filter &= combined return q_filter diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py index 343fc14773..f80a9875c3 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py @@ -1,558 +1,238 @@ from typing import Literal -from pydantic import Field +from pydantic import Field, model_validator -from baserow_enterprise.assistant.types import Annotated, BaseModel +from baserow_enterprise.assistant.types import BaseModel -from .base import Date +from .base import parse_date +# --------------------------------------------------------------------------- +# Flat filter model +# --------------------------------------------------------------------------- -class ViewFilterItemCreate(BaseModel): - """Model for creating a new view filter (no ID).""" - - field_id: int = Field(...) - type: str = Field(...) - operator: str = Field(...) - value: str = Field(...) - - def get_django_orm_type(self, field, **kwargs) -> str: - return self.operator - - def get_django_orm_value(self, field, **kwargs) -> str: - return self.value - - -class ViewFilterItem(ViewFilterItemCreate): - """Model for an existing view filter (with ID).""" - - id: int = Field(..., description="The unique identifier of the view filter.") - - -class TextViewFilterItemCreate(ViewFilterItemCreate): - type: Literal["text"] = Field(..., description="A text filter.") - value: str = Field(..., description="The text value to filter on.") - - -class TextEqualViewFilterItemCreate(TextViewFilterItemCreate): - operator: Literal["equal"] = Field( - ..., description="Checks if the field is equal to the value." - ) - - -class TextEqualViewFilterItem(TextEqualViewFilterItemCreate, ViewFilterItem): - pass - - -class TextNotEqualViewFilterItemCreate(TextViewFilterItemCreate): - operator: Literal["not_equal"] = Field( - ..., description="Checks if the field is not equal to the value." - ) - - -class TextNotEqualViewFilterItem(TextNotEqualViewFilterItemCreate, ViewFilterItem): - pass - - -class TextContainsViewFilterItemCreate(TextViewFilterItemCreate): - operator: Literal["contains"] = Field( - ..., description="Checks if the field contains the value." - ) - - -class TextContainsViewFilterItem(TextContainsViewFilterItemCreate, ViewFilterItem): - pass - - -class TextNotContainsViewFilterItemCreate(TextViewFilterItemCreate): - operator: Literal["contains_not"] = Field( - ..., description="Checks if the field does not contain the value." - ) - - -class TextNotContainsViewFilterItem( - TextNotContainsViewFilterItemCreate, ViewFilterItem -): - pass - - -class TextEmptyViewFilterItemCreate(TextViewFilterItemCreate): - operator: Literal["empty"] = Field(..., description="Checks if the field is empty.") - - -class TextEmptyViewFilterItem(TextEmptyViewFilterItemCreate, ViewFilterItem): - pass - - -class TextNotEmptyViewFilterItemCreate(TextViewFilterItemCreate): - operator: Literal["not_empty"] = Field( - ..., description="Checks if the field is not empty." - ) - - -class TextNotEmptyViewFilterItem(TextNotEmptyViewFilterItemCreate, ViewFilterItem): - pass - - -AnyTextViewFilterItemCreate = Annotated[ - TextEqualViewFilterItemCreate - | TextNotEqualViewFilterItemCreate - | TextContainsViewFilterItemCreate - | TextNotContainsViewFilterItemCreate - | TextEmptyViewFilterItemCreate - | TextNotEmptyViewFilterItemCreate, - Field(discriminator="operator"), +FilterType = Literal[ + "text", "number", "date", "single_select", "multiple_select", "link_row", "boolean" ] -AnyTextViewFilterItem = Annotated[ - TextEqualViewFilterItem - | TextNotEqualViewFilterItem - | TextContainsViewFilterItem - | TextNotContainsViewFilterItem - | TextEmptyViewFilterItem - | TextNotEmptyViewFilterItem, - Field(discriminator="operator"), +_OPERATORS: dict[str, tuple[str, ...]] = { + "text": ("equal", "not_equal", "contains", "contains_not", "empty", "not_empty"), + "number": ("equal", "not_equal", "higher_than", "lower_than", "empty", "not_empty"), + "date": ("equal", "not_equal", "after", "before"), + "single_select": ("is_any_of", "is_none_of"), + "multiple_select": ("is_any_of", "is_none_of"), + "link_row": ("has", "has_not"), + "boolean": ("equal",), +} + +# Operator aliases: normalize LLM-natural names to Baserow names before validation. +_OPERATOR_ALIASES: dict[str, str] = { + "equals": "equal", + "is": "equal", + "not_equals": "not_equal", + "is_not": "not_equal", + "greater_than": "higher_than", + "greater_than_or_equal": "higher_than", # or_equal flag handles the rest + "less_than": "lower_than", + "less_than_or_equal": "lower_than", # or_equal flag handles the rest + "gte": "higher_than", + "lte": "lower_than", + "gt": "higher_than", + "lt": "lower_than", + "neq": "not_equal", + "ne": "not_equal", + "eq": "equal", +} + +DateFilterMode = Literal[ + "today", + "yesterday", + "tomorrow", + "this_week", + "last_week", + "next_week", + "this_month", + "last_month", + "next_month", + "this_year", + "last_year", + "next_year", + "nr_days_ago", + "nr_days_from_now", + "nr_weeks_ago", + "nr_weeks_from_now", + "nr_months_ago", + "nr_months_from_now", + "nr_years_ago", + "nr_years_from_now", + "exact_date", ] -class NumberViewFilterItemCreate(ViewFilterItemCreate): - type: Literal["number"] = Field(..., description="A number filter.") - value: float = Field(..., description="The number value to filter on.") - - def get_django_orm_value(self, field, **kwargs) -> str: - return str(self.value) - - -class NumberViewFilterItem(NumberViewFilterItemCreate, ViewFilterItem): - pass +# --------------------------------------------------------------------------- +# ORM type dispatch: (filter, field, **kwargs) -> str +# --------------------------------------------------------------------------- + +_NUMBER_OR_EQUAL = { + "higher_than": "higher_than_or_equal", + "lower_than": "lower_than_or_equal", +} + +_DATE_ORM_TYPE = { + "equal": "date_is", + "not_equal": "date_is_not", + "after": "date_is_after", + "before": "date_is_before", +} + +_DATE_OR_EQUAL = { + "after": "date_is_on_or_after", + "before": "date_is_on_or_before", +} + +_SINGLE_SELECT_ORM_TYPE = { + "is_any_of": "single_select_is_any_of", + "is_none_of": "single_select_is_none_of", +} + +_MULTIPLE_SELECT_ORM_TYPE = { + "is_any_of": "multiple_select_has", + "is_none_of": "multiple_select_has_not", +} + +_LINK_ROW_ORM_TYPE = { + "has": "link_row_has", + "has_not": "link_row_has_not", +} + +_GET_ORM_TYPE = { + "text": lambda f, field, **kw: f.operator, + "number": lambda f, field, **kw: ( + _NUMBER_OR_EQUAL.get(f.operator, f.operator) if f.or_equal else f.operator + ), + "date": lambda f, field, **kw: ( + _DATE_OR_EQUAL[f.operator] + if f.or_equal and f.operator in _DATE_OR_EQUAL + else _DATE_ORM_TYPE[f.operator] + ), + "single_select": lambda f, field, **kw: _SINGLE_SELECT_ORM_TYPE[f.operator], + "multiple_select": lambda f, field, **kw: _MULTIPLE_SELECT_ORM_TYPE[f.operator], + "link_row": lambda f, field, **kw: _LINK_ROW_ORM_TYPE[f.operator], + "boolean": lambda f, field, **kw: "equal", +} + + +# --------------------------------------------------------------------------- +# ORM value dispatch: (filter, field, **kwargs) -> str +# --------------------------------------------------------------------------- + + +def _select_orm_value(f, field, **kwargs): + values = set(v.lower() for v in f.value) + valid_option_ids = [ + option.id + for option in field.select_options.all() + if option.value.lower() in values + ] + return ",".join(str(v) for v in valid_option_ids) + + +def _date_orm_value(f, field, **kwargs): + timezone = kwargs.get("timezone", "UTC") + if isinstance(f.value, str): + value = parse_date(f.value).isoformat() + elif isinstance(f.value, int): + value = str(f.value) + else: + value = "" + return f"{timezone}?{value}?{f.mode}" + + +_GET_ORM_VALUE = { + "text": lambda f, field, **kw: f.value + if isinstance(f.value, str) + else str(f.value or ""), + "number": lambda f, field, **kw: str(f.value), + "date": _date_orm_value, + "single_select": _select_orm_value, + "multiple_select": _select_orm_value, + "link_row": lambda f, field, **kw: str(f.value), + "boolean": lambda f, field, **kw: "1" if f.value else "0", +} + + +# --------------------------------------------------------------------------- +# ViewFilterItemCreate +# --------------------------------------------------------------------------- -class NumberEqualsViewFilterItemCreate(NumberViewFilterItemCreate): - operator: Literal["equal"] = Field( - ..., description="Checks if the field is equal to the value." - ) - - -class NumberEqualsViewFilterItem(NumberEqualsViewFilterItemCreate, ViewFilterItem): - pass - - -class NumberNotEqualsViewFilterItemCreate(NumberViewFilterItemCreate): - operator: Literal["not_equal"] = Field( - ..., description="Checks if the field is not equal to the value." - ) - - -class NumberNotEqualsViewFilterItem( - NumberNotEqualsViewFilterItemCreate, ViewFilterItem -): - pass - - -class NumberHigherThanViewFilterItemCreate(NumberViewFilterItemCreate): - operator: Literal["higher_than"] = Field( - ..., description="Checks if the field is higher than the value." - ) - or_equal: bool = Field( - False, - description="If true, checks if the field is higher than or equal to the value.", - ) - - -class NumberHigherThanViewFilterItem( - NumberHigherThanViewFilterItemCreate, ViewFilterItem -): - pass - - -class NumberLowerThanViewFilterItemCreate(NumberViewFilterItemCreate): - operator: Literal["lower_than"] = Field( - ..., description="Checks if the field is lower than the value." - ) - or_equal: bool = Field( - False, - description="If true, checks if the field is lower than or equal to the value.", - ) - - -class NumberLowerThanViewFilterItem( - NumberLowerThanViewFilterItemCreate, ViewFilterItem -): - pass - - -class NumberEmptyViewFilterItemCreate(NumberViewFilterItemCreate): - operator: Literal["empty"] = Field(..., description="Checks if the field is empty.") - - -class NumberEmptyViewFilterItem(NumberEmptyViewFilterItemCreate, ViewFilterItem): - pass - - -class NumberNotEmptyViewFilterItemCreate(NumberViewFilterItemCreate): - operator: Literal["not_empty"] = Field( - ..., description="Checks if the field is not empty." - ) - - -class NumberNotEmptyViewFilterItem(NumberNotEmptyViewFilterItemCreate, ViewFilterItem): - pass - - -AnyNumberViewFilterItemCreate = Annotated[ - NumberEqualsViewFilterItemCreate - | NumberNotEqualsViewFilterItemCreate - | NumberHigherThanViewFilterItemCreate - | NumberLowerThanViewFilterItemCreate - | NumberEmptyViewFilterItemCreate - | NumberNotEmptyViewFilterItemCreate, - Field(discriminator="operator"), -] - -AnyNumberViewFilterItem = Annotated[ - NumberEqualsViewFilterItem - | NumberNotEqualsViewFilterItem - | NumberHigherThanViewFilterItem - | NumberLowerThanViewFilterItem - | NumberEmptyViewFilterItem - | NumberNotEmptyViewFilterItem, - Field(discriminator="operator"), -] - +class ViewFilterItemCreate(BaseModel): + """Flat model for creating a view filter: field_id + type + operator + value.""" -class DateViewFilterItemCreate(ViewFilterItemCreate): - type: Literal["date"] = Field(..., description="A date filter.") - value: Date | int | None = Field( + field_id: int = Field(..., description="Field ID to filter on.") + type: FilterType = Field(..., description="Must match field type.") + operator: str = Field( ..., - description="\n".join( - [ - "The date value to filter on.", - "Use an integer for days/weeks/months/years ago/from now.", - "Use a date object for an exact date.", - "None otherwise.", - ] + description=( + "Filter operator. " + "text: equal/not_equal/contains/contains_not/empty/not_empty. " + "number: equal/not_equal/greater_than/less_than/empty/not_empty " + "(use or_equal=true for ≥/≤). " + "date: equal/not_equal/after/before (use or_equal=true for on_or_after/on_or_before). " + "single_select/multiple_select: is_any_of/is_none_of. " + "link_row: has/has_not. " + "boolean: equal." ), ) - mode: Literal[ - "today", - "yesterday", - "tomorrow", - "this_week", - "last_week", - "next_week", - "this_month", - "last_month", - "next_month", - "this_year", - "last_year", - "next_year", - "nr_days_ago", - "nr_days_from_now", - "nr_weeks_ago", - "nr_weeks_from_now", - "nr_months_ago", - "nr_months_from_now", - "nr_years_ago", - "nr_years_from_now", - "exact_date", - ] = Field( - "exact_date", - description="The mode to use for the date filter. ALWAYS use the right mode if available. Use 'exact_date' if you have an exact date.", - ) - - def get_django_orm_value(self, field, **kwargs) -> str: - timezone = kwargs.get("timezone", "UTC") - - if isinstance(self.value, Date): - value = self.value.to_django_orm() - elif isinstance(self.value, int): - value = str(self.value) - else: - value = "" - - return f"{timezone}?{value}?{self.mode}" - - -class DateEqualsViewFilterItemCreate(DateViewFilterItemCreate): - operator: Literal["equal"] = Field( - ..., description="Checks if the field is equal to the value." - ) + value: str | float | int | bool | list[str] | None = Field( + None, + description="Filter value (type-dependent).", + ) + mode: DateFilterMode | None = Field(None, description="(date) Date filter mode.") + or_equal: bool = Field(False, description="(number, date) Include equal values.") + + @model_validator(mode="before") + @classmethod + def _normalize_operator(cls, data): + if isinstance(data, dict) and "operator" in data: + op = data["operator"] + normalized = _OPERATOR_ALIASES.get(op) + if normalized: + data = dict(data) + data["operator"] = normalized + # Auto-set or_equal for _or_equal variants + if "or_equal" in op: + data.setdefault("or_equal", True) + return data + + @model_validator(mode="after") + def _validate_per_type(self): + valid = _OPERATORS.get(self.type) + if valid and self.operator not in valid: + raise ValueError( + f"Invalid operator '{self.operator}' for type '{self.type}'. " + f"Valid operators: {', '.join(valid)}" + ) + if self.type == "date" and self.mode is None: + raise ValueError("date filter requires 'mode'.") + return self def get_django_orm_type(self, field, **kwargs) -> str: - return "date_is" - - -class DateEqualsViewFilterItem(DateEqualsViewFilterItemCreate, ViewFilterItem): - pass - - -class DateNotEqualsViewFilterItemCreate(DateViewFilterItemCreate): - operator: Literal["not_equal"] = Field( - ..., description="Checks if the field is not equal to the value." - ) - - def get_django_orm_type(self, field, **kwargs) -> str: - return "date_is_not" - - -class DateNotEqualsViewFilterItem(DateNotEqualsViewFilterItemCreate, ViewFilterItem): - pass - - -class DateAfterViewFilterItemCreate(DateViewFilterItemCreate): - operator: Literal["after"] = Field( - ..., description="Checks if the field is after the value." - ) - or_equal: bool = Field( - False, - description="If true, checks if the field is after or equal to the value.", - ) - - def get_django_orm_type(self, field, **kwargs) -> str: - return "date_is_on_or_after" if self.or_equal else "date_is_after" - - -class DateAfterViewFilterItem(DateAfterViewFilterItemCreate, ViewFilterItem): - pass - - -class DateBeforeViewFilterItemCreate(DateViewFilterItemCreate): - operator: Literal["before"] = Field( - ..., description="Checks if the field is before the value." - ) - or_equal: bool = Field( - False, - description="If true, checks if the field is before or equal to the value.", - ) - - def get_django_orm_type(self, field, **kwargs) -> str: - return "date_is_on_or_before" if self.or_equal else "date_is_before" - - -class DateBeforeViewFilterItem(DateBeforeViewFilterItemCreate, ViewFilterItem): - pass - - -AnyDateViewFilterItemCreate = Annotated[ - DateEqualsViewFilterItemCreate - | DateNotEqualsViewFilterItemCreate - | DateAfterViewFilterItemCreate - | DateBeforeViewFilterItemCreate, - Field(discriminator="operator"), -] -AnyDateViewFilterItem = Annotated[ - DateEqualsViewFilterItem - | DateNotEqualsViewFilterItem - | DateAfterViewFilterItem - | DateBeforeViewFilterItem, - Field(discriminator="operator"), -] - - -class SingleSelectViewFilterItemCreate(ViewFilterItemCreate): - type: Literal["single_select"] = Field(..., description="A single select filter.") - value: list[str] = Field( - ..., description="The select option value(s) to filter on." - ) - - def get_django_orm_value(self, field, **kwargs) -> str: - values = set(v.lower() for v in self.value) - valid_option_ids = [ - option.id - for option in field.select_options.all() - if option.value.lower() in values - ] - return ",".join([str(v) for v in valid_option_ids]) - - -class SingleSelectIsAnyViewFilterItemCreate(SingleSelectViewFilterItemCreate): - operator: Literal["is_any_of"] = Field( - ..., description="Checks if the field is equal to any of the values " - ) - - def get_django_orm_type(self, field, **kwargs): - return "single_select_is_any_of" - - -class SingleSelectIsAnyViewFilterItem( - SingleSelectIsAnyViewFilterItemCreate, ViewFilterItem -): - pass - - -class SingleSelectIsNoneOfNotViewFilterItemCreate(SingleSelectViewFilterItemCreate): - operator: Literal["is_none_of"] = Field( - ..., description="Checks if the field is not equal to the value." - ) - - def get_django_orm_type(self, field, **kwargs): - return "single_select_is_none_of" - - -class SingleSelectIsNoneOfNotViewFilterItem( - SingleSelectIsNoneOfNotViewFilterItemCreate, ViewFilterItem -): - pass - - -AnySingleSelectViewFilterItemCreate = Annotated[ - SingleSelectIsAnyViewFilterItemCreate | SingleSelectIsNoneOfNotViewFilterItemCreate, - Field(discriminator="operator"), -] - -AnySingleSelectViewFilterItem = Annotated[ - SingleSelectIsAnyViewFilterItem | SingleSelectIsNoneOfNotViewFilterItem, - Field(discriminator="operator"), -] - - -class MultipleSelectViewFilterItemCreate(ViewFilterItemCreate): - type: Literal["multiple_select"] = Field( - ..., description="A multiple select filter." - ) - value: list[str] = Field( - ..., description="The select option value(s) to filter on." - ) + return _GET_ORM_TYPE[self.type](self, field, **kwargs) def get_django_orm_value(self, field, **kwargs) -> str: - values = set(v.lower() for v in self.value) - valid_option_ids = [ - option.id - for option in field.select_options.all() - if option.value.lower() in values - ] - return ",".join([str(v) for v in valid_option_ids]) - - -class MultipleSelectIsAnyViewFilterItemCreate(MultipleSelectViewFilterItemCreate): - operator: Literal["is_any_of"] = Field( - ..., description="Checks if the field is equal to any of the values " - ) - - def get_django_orm_type(self, field, **kwargs): - return "multiple_select_has" - - -class MultipleSelectIsAnyViewFilterItem( - MultipleSelectIsAnyViewFilterItemCreate, ViewFilterItem -): - pass - - -class MultipleSelectIsNoneOfNotViewFilterItemCreate(MultipleSelectViewFilterItemCreate): - operator: Literal["is_none_of"] = Field( - ..., description="Checks if the field is not equal to the value." - ) - - def get_django_orm_type(self, field, **kwargs): - return "multiple_select_has_not" - - -class MultipleSelectIsNoneOfNotViewFilterItem( - MultipleSelectIsNoneOfNotViewFilterItemCreate, ViewFilterItem -): - pass - - -AnyMultipleSelectViewFilterItemCreate = Annotated[ - MultipleSelectIsAnyViewFilterItemCreate - | MultipleSelectIsNoneOfNotViewFilterItemCreate, - Field(discriminator="operator"), -] - -AnyMultipleSelectViewFilterItem = Annotated[ - MultipleSelectIsAnyViewFilterItem | MultipleSelectIsNoneOfNotViewFilterItem, - Field(discriminator="operator"), -] - - -class LinkRowViewFilterItemCreate(ViewFilterItemCreate): - type: Literal["link_row"] = Field(..., description="A link row filter.") - value: int = Field(..., description="The linked record ID to filter on.") - - def get_django_orm_value(self, field, **kwargs) -> str: - return str(self.value) - - -class LinkRowHasViewFilterItemCreate(LinkRowViewFilterItemCreate): - operator: Literal["has"] = Field( - ..., description="Checks if the field has the linked record." - ) - - def get_django_orm_type(self, field, **kwargs): - return "link_row_has" - + return _GET_ORM_VALUE[self.type](self, field, **kwargs) -class LinkRowHasViewFilterItem(LinkRowHasViewFilterItemCreate, ViewFilterItem): - pass +class ViewFilterItem(ViewFilterItemCreate): + """Existing view filter with ID.""" -class LinkRowHasNotViewFilterItemCreate(LinkRowViewFilterItemCreate): - operator: Literal["has_not"] = Field( - ..., description="Checks if the field does not have the linked record." - ) - - def get_django_orm_type(self, field, **kwargs): - return "link_row_has_not" - - -class LinkRowHasNotViewFilterItem(LinkRowHasNotViewFilterItemCreate, ViewFilterItem): - pass - - -AnyLinkRowViewFilterItemCreate = Annotated[ - LinkRowHasViewFilterItemCreate | LinkRowHasNotViewFilterItemCreate, - Field(discriminator="operator"), -] - -AnyLinkRowViewFilterItem = Annotated[ - LinkRowHasViewFilterItem | LinkRowHasNotViewFilterItem, - Field(discriminator="operator"), -] - - -class BooleanViewFilterItemCreate(ViewFilterItemCreate): - type: Literal["boolean"] = Field(..., description="A boolean filter.") - value: bool = Field(..., description="The boolean value to filter on.") - - def get_django_orm_value(self, field, **kwargs) -> str: - return "1" if self.value else "0" - - -class BooleanIsViewFilterItemCreate(BooleanViewFilterItemCreate): - operator: Literal["is"] = Field(..., description="Checks if the field is true.") - value: bool = Field(..., description="The boolean value to filter on.") - - def get_django_orm_type(self, field, **kwargs) -> str: - return "boolean" - - -class BooleanIsTrueViewFilterItem(BooleanIsViewFilterItemCreate, ViewFilterItem): - pass - + id: int = Field(..., description="The unique identifier of the view filter.") -AnyViewFilterItemCreate = Annotated[ - AnyTextViewFilterItemCreate - | AnyNumberViewFilterItemCreate - | AnyDateViewFilterItemCreate - | AnySingleSelectViewFilterItemCreate - | AnyLinkRowViewFilterItemCreate - | BooleanViewFilterItemCreate - | MultipleSelectViewFilterItemCreate, - Field(discriminator="type"), -] -AnyViewFilterItem = Annotated[ - AnyTextViewFilterItem - | AnyNumberViewFilterItem - | AnyDateViewFilterItem - | AnySingleSelectViewFilterItem - | AnyLinkRowViewFilterItem - | BooleanIsTrueViewFilterItem - | MultipleSelectIsAnyViewFilterItem, - Field(discriminator="type"), -] +AnyViewFilterItemCreate = ViewFilterItemCreate +AnyViewFilterItem = ViewFilterItem class ViewFiltersArgs(BaseModel): view_id: int - filters: list[AnyViewFilterItemCreate] + filters: list[ViewFilterItemCreate] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py index 021e4916a1..72d9efd74b 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py @@ -1,300 +1,269 @@ -from typing import Annotated, Literal, Type +import json +from typing import Any, Literal -from pydantic import Field +from pydantic import Field, model_serializer, model_validator from baserow.contrib.database.fields.models import ( DateField, FileField, SingleSelectField, ) -from baserow.contrib.database.views.models import FormView, GalleryView, GridView from baserow.contrib.database.views.models import View as BaserowView from baserow.contrib.database.views.registries import view_type_registry from baserow_enterprise.assistant.types import BaseModel from baserow_premium.permission_manager import Table -from baserow_premium.views.models import CalendarView, KanbanView, TimelineView +# --------------------------------------------------------------------------- +# Shared types +# --------------------------------------------------------------------------- -class ViewItemCreate(BaseModel): - name: str = Field( - ..., - description="A sensible name for the view (i.e. 'Pending payments', 'Completed tasks', etc.).", - ) - public: bool = Field( - default=False, - description="Whether the view is publicly accessible. False unless specified.", - ) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - "name": self.name, - "public": self.public, - } - - def field_options_to_django_orm(self) -> dict[str, any]: - return {} - - -class ViewItem(BaseModel): - id: int = Field(...) - @classmethod - def from_django_orm(cls, orm_view: Type[BaserowView]) -> "ViewItem": - return cls( - id=orm_view.id, - name=orm_view.name, - public=orm_view.public, - ) +class FormFieldOption(BaseModel): + field_id: int = Field(..., description="Field ID.") + name: str = Field(..., description="Display name in form.") + description: str = Field(..., description="Field description, or ''.") + required: bool = Field(..., description="Required?") + order: int = Field(..., description="Sort order.") class GridFieldOption(BaseModel): field_id: int = Field(...) width: int = Field( - default=200, - description="The width of the field in the grid view. Default is 200.", + ..., + description="The width of the field in the grid view (e.g. 200).", ) hidden: bool = Field( - default=False, - description="Whether the field is hidden in the grid view. Default is False.", - ) - - -class GridViewItemCreate(ViewItemCreate): - type: Literal["grid"] = Field(..., description="A grid view.") - row_height: Literal["small", "medium", "large"] = Field( - default="small", - description=( - "The height of the rows in the view. Can be 'small', 'medium' or 'large'. Default is 'small'." - ), - ) - - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - **super().to_django_orm_kwargs(table), - "row_height": self.row_height, - } - - -class GridViewItem(GridViewItemCreate, ViewItem): - @classmethod - def from_django_orm(cls, orm_view: GridView) -> "GridViewItem": - return cls( - id=orm_view.id, - name=orm_view.name, - type="grid", - row_height="small", - public=orm_view.public, - ) - - -class KanbanViewItemCreate(ViewItemCreate): - type: Literal["kanban"] = Field(..., description="A kanban view.") - column_field_id: int | None = Field( ..., - description="The ID of the field to use for the kanban columns. Must be a single select field. None if no single select field is available.", + description="Whether the field is hidden in the grid view.", ) - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - model = table.get_model() - column_field = model.get_field_object_by_id(self.column_field_id)["field"] - if not isinstance(column_field, SingleSelectField): - raise ValueError("The column_field_id must be a Single Select field.") - return { - **super().to_django_orm_kwargs(table), - "single_select_field": column_field, - } +# --------------------------------------------------------------------------- +# Flat view types +# --------------------------------------------------------------------------- + +ViewType = Literal["grid", "kanban", "calendar", "gallery", "timeline", "form"] + +_VIEW_EXAMPLES: dict[str, dict] = { + "grid": {"name": "All Items", "type": "grid", "row_height": "small"}, + "kanban": {"name": "Board", "type": "kanban", "column_field_id": 123}, + "calendar": {"name": "Schedule", "type": "calendar", "date_field_id": 456}, + "gallery": {"name": "Photos", "type": "gallery", "cover_field_id": 789}, + "timeline": { + "name": "Project Timeline", + "type": "timeline", + "start_date_field_id": 111, + "end_date_field_id": 222, + }, + "form": { + "name": "Contact Form", + "type": "form", + "title": "Contact Us", + "description": "", + "submit_button_label": "Submit", + "receive_notification_on_submit": False, + "submit_action": "MESSAGE", + "submit_action_message": "Thank you!", + "submit_action_redirect_url": "", + "field_options": [ + { + "field_id": 1, + "name": "Name", + "description": "", + "required": True, + "order": 1, + } + ], + }, +} + + +# --------------------------------------------------------------------------- +# to_django_orm builders: (ViewItemCreate, Table) -> dict +# --------------------------------------------------------------------------- + + +def _grid_to_orm(v, table): + return {"row_height": v.row_height} + + +def _kanban_to_orm(v, table): + model = table.get_model() + column_field = model.get_field_object_by_id(v.column_field_id)["field"] + if not isinstance(column_field, SingleSelectField): + raise ValueError("The column_field_id must be a Single Select field.") + return {"single_select_field": column_field} + + +def _calendar_to_orm(v, table): + model = table.get_model() + date_field = model.get_field_object_by_id(v.date_field_id)["field"] + if not isinstance(date_field, DateField): + raise ValueError("The date_field_id must be a Date field.") + return {"date_field": date_field} + + +def _gallery_to_orm(v, table): + model = table.get_model() + cover_field = model.get_field_object_by_id(v.cover_field_id)["field"] + if not isinstance(cover_field, FileField): + raise ValueError("The cover_field_id must be a File field.") + return {"card_cover_image_field_id": v.cover_field_id} + + +def _timeline_to_orm(v, table): + model = table.get_model() + start_field = model.get_field_object_by_id(v.start_date_field_id)["field"] + end_field = model.get_field_object_by_id(v.end_date_field_id)["field"] + if ( + not isinstance(start_field, DateField) + or not isinstance(end_field, DateField) + or start_field.id == end_field.id + or start_field.date_include_time != end_field.date_include_time + ): + raise ValueError( + "Invalid timeline configuration: both start and end fields must be Date fields " + "and they must have the same include_time setting (either both include time or " + "both are date-only). " + ) + return {"start_date_field": start_field, "end_date_field": end_field} -class KanbanViewItem(KanbanViewItemCreate, ViewItem): - @classmethod - def from_django_orm(cls, orm_view: KanbanView) -> "KanbanViewItem": - return cls( - id=orm_view.id, - name=orm_view.name, - type="kanban", - column_field_id=orm_view.single_select_field_id, - public=orm_view.public, - ) +def _form_to_orm(v, table): + return {"title": v.title, "description": v.description} -class CalendarViewItemCreate(ViewItemCreate): - type: Literal["calendar"] = Field(..., description="A calendar view.") - date_field_id: int | None = Field( - ..., - description="The ID of the field to use for the calendar dates. Must be a date field. None if no date field is available.", - ) +_TO_DJANGO_ORM = { + "grid": _grid_to_orm, + "kanban": _kanban_to_orm, + "calendar": _calendar_to_orm, + "gallery": _gallery_to_orm, + "timeline": _timeline_to_orm, + "form": _form_to_orm, +} - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - model = table.get_model() - date_field = model.get_field_object_by_id(self.date_field_id)["field"] - if not isinstance(date_field, DateField): - raise ValueError("The date_field_id must be a Date field.") - return { - **super().to_django_orm_kwargs(table), - "date_field": date_field, - } +# --------------------------------------------------------------------------- +# from_django_orm builders: (orm_view) -> dict of extra kwargs +# --------------------------------------------------------------------------- -class CalendarViewItem(CalendarViewItemCreate, ViewItem): - @classmethod - def from_django_orm(cls, orm_view: CalendarView) -> "CalendarViewItem": - return cls( - id=orm_view.id, - name=orm_view.name, - type="calendar", - date_field_id=orm_view.date_field_id, - public=orm_view.public, +def _form_field_options_from_orm(orm_view): + return [ + FormFieldOption( + field_id=fo.field_id, + name=fo.name, + description=fo.description, + required=fo.required, + order=fo.order, ) + for fo in orm_view.active_field_options.all() + ] -class BaseGalleryViewItem(ViewItemCreate): - type: Literal["gallery"] = Field(..., description="A gallery view.") - cover_field_id: int | None = Field( - default=None, - description=( - "The ID of the field to use for the gallery cover image. Must be a file field. None if no file field is available." - ), - ) +_FROM_DJANGO_ORM: dict[str, Any] = { + "grid": lambda v: {"row_height": "small"}, + "kanban": lambda v: {"column_field_id": v.single_select_field_id}, + "calendar": lambda v: {"date_field_id": v.date_field_id}, + "gallery": lambda v: {"cover_field_id": v.card_cover_image_field_id}, + "timeline": lambda v: { + "start_date_field_id": v.start_date_field_id, + "end_date_field_id": v.end_date_field_id, + }, + "form": lambda v: { + "title": v.title, + "description": v.description, + "field_options": _form_field_options_from_orm(v), + }, +} -class GalleryViewItemCreate(BaseGalleryViewItem): - def to_django_orm_kwargs(self, table): - model = table.get_model() - cover_field = model.get_field_object_by_id(self.cover_field_id)["field"] - if not isinstance(cover_field, FileField): - raise ValueError("The cover_field_id must be a File field.") - - return { - **super().to_django_orm_kwargs(table), - "card_cover_image_field_id": self.cover_field_id, - } +# --------------------------------------------------------------------------- +# ViewItemCreate +# --------------------------------------------------------------------------- -class GalleryViewItem(BaseGalleryViewItem, ViewItem): - @classmethod - def from_django_orm(cls, orm_view: GalleryView) -> "GalleryViewItem": - return cls( - id=orm_view.id, - name=orm_view.name, - type="gallery", - cover_field_id=orm_view.card_cover_image_field_id, - public=orm_view.public, - ) +class ViewItemCreate(BaseModel): + """Flat model for creating a view: name + type + type-specific options.""" + name: str = Field(..., description="Descriptive view name.") + public: bool = Field(..., description="Publicly accessible? Default false.") + type: ViewType = Field(..., description="View type.") -class BaseTimelineViewItem(ViewItemCreate): - type: Literal["timeline"] = Field(..., description="A timeline view.") - start_date_field_id: int | None = Field( - ..., - description="The ID of the field to use for the timeline dates. Must be a date field. None if no date field is available.", + # -- grid -- + row_height: Literal["small", "medium", "large"] = Field( + "small", description="(grid) Row height." ) - end_date_field_id: int | None = Field( - ..., - description=( - "The ID of the field to use for the timeline end dates. Must be a date field. None if no date field is available." - ), + # -- kanban -- + column_field_id: int | None = Field( + None, description="(kanban) Single-select field ID for columns." ) - - -class TimelineViewItemCreate(BaseTimelineViewItem): - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - model = table.get_model() - start_field = model.get_field_object_by_id(self.start_date_field_id)["field"] - end_field = model.get_field_object_by_id(self.end_date_field_id)["field"] - if ( - not isinstance(start_field, DateField) - or not isinstance(end_field, DateField) - or start_field.id == end_field.id - or start_field.date_include_time != end_field.date_include_time - ): - raise ValueError( - "Invalid timeline configuration: both start and end fields must be Date fields " - "and they must have the same include_time setting (either both include time or " - "both are date-only). " - ) - - return { - **super().to_django_orm_kwargs(table), - "start_date_field": start_field, - "end_date_field": end_field, - } - - -class TimelineViewItem(BaseTimelineViewItem, ViewItem): - @classmethod - def from_django_orm(cls, orm_view: TimelineView) -> "TimelineViewItem": - return cls( - id=orm_view.id, - name=orm_view.name, - type="timeline", - start_date_field_id=orm_view.start_date_field_id, - end_date_field_id=orm_view.end_date_field_id, - public=orm_view.public, - ) - - -class FormFieldOption(BaseModel): - field_id: int = Field(..., description="The ID of the field.") - name: str = Field(..., description="The name to show for the field in the form.") - description: str = Field( - default="", description="The description to show for the field in the form." + # -- calendar -- + date_field_id: int | None = Field(None, description="(calendar) Date field ID.") + # -- gallery -- + cover_field_id: int | None = Field( + None, description="(gallery) File field ID for covers." ) - required: bool = Field( - default=True, - description="Whether the field is required in the form. Default is True.", + # -- timeline -- + start_date_field_id: int | None = Field( + None, description="(timeline) Start date field ID." ) - order: int = Field(..., description="The order of the field in the form.") - - -class BaseFormViewItem(ViewItemCreate): - type: Literal["form"] = Field(..., description="A form view.") - title: str = Field(..., description="The title of the form.") - description: str = Field(..., description="The description of the form.") - submit_button_label: str = Field( - default="Submit", description="The label of the submit button." + end_date_field_id: int | None = Field( + None, description="(timeline) End date field ID." ) + # -- form -- + title: str = Field("", description="(form) Title, or ''.") + description: str = Field("", description="(form) Description, or ''.") + submit_button_label: str = Field("Submit", description="(form) Button label.") receive_notification_on_submit: bool = Field( - default=False, - description=( - "Whether to receive an email notification when the form is submitted." - ), + False, description="(form) Email on submit." ) submit_action: Literal["MESSAGE", "REDIRECT"] = Field( - default="MESSAGE", - description="The action to perform when the form is submitted.", - ) - submit_action_message: str = Field( - default="", - description=( - "The message to display when the form is submitted and the action is 'MESSAGE'." - ), + "MESSAGE", description="(form) 'MESSAGE' or 'REDIRECT'." ) + submit_action_message: str = Field("", description="(form) Message after submit.") submit_action_redirect_url: str = Field( - default="", - description=( - "The URL to redirect to when the form is submitted and the action is 'REDIRECT'." - ), + "", description="(form) Redirect URL after submit." ) - - field_options: list[FormFieldOption] = Field( - ..., - description=( - "The list of fields to show in the form, along with their options. The fields must be part of the table." - ), + field_options: list[FormFieldOption] | None = Field( + None, + description="(form) Fields to show (OPT-IN: include all you want visible).", ) + # Required fields per type: {type: [(attr_name, display_name), ...]} + _REQUIRED_FIELDS: dict[str, list[tuple[str, str]]] = { + "kanban": [("column_field_id", "column_field_id")], + "calendar": [("date_field_id", "date_field_id")], + "gallery": [("cover_field_id", "cover_field_id")], + "timeline": [ + ("start_date_field_id", "start_date_field_id"), + ("end_date_field_id", "end_date_field_id"), + ], + "form": [("field_options", "field_options")], + } -class FormViewItemCreate(BaseFormViewItem): - def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: - return { - **super().to_django_orm_kwargs(table), - "title": self.title, - "description": self.description, - } - - def field_options_to_django_orm(self): + @model_validator(mode="after") + def _validate_required_for_type(self): + required = self._REQUIRED_FIELDS.get(self.type) + if required: + missing = [name for attr, name in required if not getattr(self, attr)] + if missing: + raise ValueError( + f"{self.type} requires {', '.join(missing)}. " + f"Example: {json.dumps(_VIEW_EXAMPLES[self.type])}" + ) + return self + + def to_django_orm_kwargs(self, table: Table) -> dict[str, Any]: + base = {"name": self.name, "public": self.public} + builder = _TO_DJANGO_ORM.get(self.type) + if builder: + base.update(builder(self, table)) + return base + + def field_options_to_django_orm(self) -> dict[str, Any]: + if self.type != "form" or not self.field_options: + return {} return { fo.field_id: { "enabled": True, @@ -307,64 +276,44 @@ def field_options_to_django_orm(self): } -class FormViewItem(FormViewItemCreate, ViewItem): - @classmethod - def from_django_orm(cls, orm_view: FormView) -> "FormViewItem": - return cls( - id=orm_view.id, - name=orm_view.name, - type="form", - public=orm_view.public, - title=orm_view.title, - description=orm_view.description, - field_options=[ - FormFieldOption( - field_id=fo.field_id, - name=fo.name, - description=fo.description, - required=fo.required, - order=fo.order, - ) - for fo in orm_view.active_field_options.all() - ], - ) +# --------------------------------------------------------------------------- +# ViewItem (read-back) +# --------------------------------------------------------------------------- -AnyViewItemCreate = Annotated[ - GridViewItemCreate - | KanbanViewItemCreate - | CalendarViewItemCreate - | GalleryViewItemCreate - | TimelineViewItemCreate - | FormViewItemCreate, - Field(discriminator="type"), -] - -AnyViewItem = Annotated[ - GridViewItem - | KanbanViewItem - | CalendarViewItem - | GalleryViewItem - | TimelineViewItem - | FormViewItem, - Field(discriminator="type"), -] - - -class ViewItemsRegistry: - _registry = { - "grid": GridViewItem, - "kanban": KanbanViewItem, - "calendar": CalendarViewItem, - "gallery": GalleryViewItem, - "timeline": TimelineViewItem, - "form": FormViewItem, - } - - def from_django_orm(self, orm_view: Type[BaserowView]) -> ViewItem: - view_type = view_type_registry.get_by_model(orm_view).type - view_class: ViewItem = self._registry.get(view_type, ViewItem) - return view_class.from_django_orm(orm_view) +class ViewItem(BaseModel): + """Existing view with ID — flat structure matching ViewItemCreate.""" + id: int = Field(...) + name: str = Field(...) + public: bool = Field(...) + type: str = Field(...) + + # Type-specific (populated per type, others excluded via serializer) + row_height: str | None = None + column_field_id: int | None = None + date_field_id: int | None = None + cover_field_id: int | None = None + start_date_field_id: int | None = None + end_date_field_id: int | None = None + title: str | None = None + description: str | None = None + field_options: list[FormFieldOption] | None = None + + @model_serializer(mode="wrap") + def _exclude_none(self, handler): + return {k: v for k, v in handler(self).items() if v is not None} -view_item_registry = ViewItemsRegistry() + @classmethod + def from_django_orm(cls, orm_view: BaserowView) -> "ViewItem": + view_type = view_type_registry.get_by_model(orm_view).type + kwargs: dict[str, Any] = { + "id": orm_view.id, + "name": orm_view.name, + "public": orm_view.public, + "type": view_type, + } + builder = _FROM_DJANGO_ORM.get(view_type) + if builder: + kwargs.update(builder(orm_view)) + return cls(**kwargs) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py deleted file mode 100644 index 2cc186c093..0000000000 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py +++ /dev/null @@ -1,559 +0,0 @@ -from dataclasses import dataclass -from itertools import groupby -from typing import TYPE_CHECKING, Any, Callable, Literal, Type, Union - -from django.contrib.auth.models import AbstractUser -from django.core.exceptions import ValidationError -from django.db import transaction -from django.db.models import Q, QuerySet -from django.utils.translation import gettext as _ - -import udspy -from pydantic import ConfigDict, Field, create_model -from udspy.utils import minimize_schema, resolve_json_schema_reference - -from baserow.contrib.database.fields.actions import CreateFieldActionType -from baserow.contrib.database.fields.field_types import LinkRowFieldType -from baserow.contrib.database.fields.handler import FieldHandler -from baserow.contrib.database.fields.models import SelectOption as OrmSelectOption -from baserow.contrib.database.fields.registries import field_type_registry -from baserow.contrib.database.rows.actions import ( - CreateRowsActionType, - DeleteRowsActionType, - UpdateRowsActionType, -) -from baserow.contrib.database.table.handler import TableHandler -from baserow.contrib.database.table.models import ( - FieldObject, - GeneratedTableModel, - Table, -) -from baserow.contrib.database.views.actions import CreateViewFilterActionType -from baserow.contrib.database.views.handler import ViewHandler -from baserow.contrib.database.views.models import View, ViewFilter -from baserow.core.db import specific_iterator -from baserow.core.models import Workspace -from baserow_enterprise.assistant.tools.database.types.table import ( - BaseTableItem, - TableItem, -) - -from .types import ( - AnyFieldItem, - AnyFieldItemCreate, - AnyViewFilterItemCreate, - BaseModel, - Date, - Datetime, - field_item_registry, -) - -if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers - -NoChange = Literal["__NO_CHANGE__"] - - -def filter_tables(user: AbstractUser, workspace: Workspace) -> QuerySet[Table]: - return TableHandler().list_workspace_tables(user, workspace) - - -def list_tables( - user: AbstractUser, workspace: Workspace, database_id: int -) -> list[BaseTableItem]: - tables_qs = filter_tables(user, workspace).filter(database_id=database_id) - - return [BaseTableItem(id=table.id, name=table.name) for table in tables_qs] - - -def get_tables_schema( - tables: list[Table], - full_schema: bool = False, -) -> list[TableItem]: - """Returns the schema of the specified tables.""" - - q = Q(table__in=tables) - if not full_schema: # Only the primary fields and relationships - q &= Q(linkrowfield__isnull=False) | Q(primary=True) - - base_field_queryset = FieldHandler().get_base_fields_queryset() - fields = specific_iterator( - base_field_queryset.filter(q).order_by("table_id", "order"), - per_content_type_queryset_hook=( - lambda field, queryset: field_type_registry.get_by_model( - field - ).enhance_field_queryset(queryset, field) - ), - ) - - table_items = [] - for table_id, fields_in_table in groupby(fields, lambda f: f.table_id): - fields_in_table = list(fields_in_table) - table = next(t for t in tables if t.id == table_id) - primary_field = next(f for f in fields if f.primary) - primary_field_item = field_item_registry.from_django_orm(primary_field) - - table_items.append( - TableItem( - id=table.id, - name=table.name, - primary_field=primary_field_item, - fields=[ - field_item_registry.from_django_orm(f) - for f in fields_in_table - if f.id != primary_field.id - ], - ) - ) - - # Make sure the order is the same as the input - tables = list(tables) - table_items.sort( - key=lambda t: tables.index(next(tb for tb in tables if tb.id == t.id)) - ) - - return table_items - - -def create_fields( - user: AbstractUser, - table: Table, - field_items: list[AnyFieldItemCreate], - tool_helpers: "ToolHelpers", -) -> list[AnyFieldItem]: - created_fields = [] - for field_item in field_items: - tool_helpers.update_status( - _("Creating field %(field_name)s...") % {"field_name": field_item.name} - ) - - new_field = CreateFieldActionType.do( - user, - table, - field_item.type, - **field_item.to_django_orm_kwargs(table), - ) - created_fields.append(field_item_registry.from_django_orm(new_field)) - return created_fields - - -@dataclass -class FieldDefinition: - type: Type | None = None - field_def: Any | None = None - to_django_orm: Callable[[Any], Any] | None = None - from_django_orm: Callable[[Any], Any] | None = None - - -def _get_pydantic_field_definition( - field_object: FieldObject, -) -> FieldDefinition: - """ - Returns the Pydantic field type and definition for the given field object. - """ - - orm_field = field_object["field"] - orm_field_type = field_object["type"] - - match orm_field_type.type: - case "text": - return FieldDefinition( - str | None, - Field(..., description="Single-line text", title=orm_field.name), - lambda v: v if v is not None else "", - lambda v: v if v is not None else "", - ) - - case "long_text": - return FieldDefinition( - str | None, - Field(..., description="Multi-line text", title=orm_field.name), - lambda v: v if v is not None else "", - lambda v: v if v is not None else "", - ) - case "number": - return FieldDefinition( - float | None, - Field(..., description="Number or None", title=orm_field.name), - ) - case "boolean": - return FieldDefinition( - bool, Field(..., description="Boolean", title=orm_field.name) - ) - case "date": - if orm_field.date_include_time: - return FieldDefinition( - Datetime | None, - Field(..., description="Datetime or None", title=orm_field.name), - lambda v: v.to_django_orm() if v else None, - lambda v: Datetime.from_django_orm(v) if v is not None else None, - ) - else: - return FieldDefinition( - Date | None, - Field(..., description="Date or None", title=orm_field.name), - lambda v: v.to_django_orm() if v else None, - lambda v: Date.from_django_orm(v) if v is not None else None, - ) - case "single_select": - choices = [option.value for option in orm_field.select_options.all()] - - return FieldDefinition( - Literal[*choices] | None, - Field( - ..., - description=f"One of: {', '.join(choices)} or None", - title=orm_field.name, - ), - lambda v: v if v in choices else None, - lambda v: v.value if isinstance(v, OrmSelectOption) else v, - ) - case "multiple_select": - choices = [option.value for option in orm_field.select_options.all()] - - return FieldDefinition( - list[Literal[*choices]], - Field( - ..., - description=f"List of any of: {', '.join(choices)} or empty list", - title=orm_field.name, - ), - lambda v: [opt for opt in v if opt in choices], - lambda v: [opt.value for opt in v.all()] if v is not None else None, - ) - case "link_row": - linked_model = orm_field.link_row_table.get_model() - linked_primary_key = linked_model.get_primary_field() - - # If there's no primary key, we can't safely work with this field - if linked_primary_key is None: - return FieldDefinition() # Unsupported field type - - # Avoid null or empty values - linked_pk = linked_primary_key.db_column - linked_values = list( - linked_model.objects.exclude( - Q(**{f"{linked_pk}__isnull": True}) - | Q(**{f"{linked_pk}__exact": ""}) - ).values_list(linked_pk, flat=True)[:10] - ) - examples = f"Examples: {', '.join([str(v) for v in linked_values])}" - - def to_django_orm(value): - if isinstance(value, str) or isinstance(value, int): - value = [value] - if value is not None: - try: - return LinkRowFieldType().prepare_value_for_db(orm_field, value) - except ValidationError: - pass - return [] - - def from_django_orm(value): - values = [str(v) for v in value.all()] - if orm_field.link_row_multiple_relationships: - return values - else: - return values[0] if values else None - - # TODO: verify this can work with every possible primary field type - if orm_field.link_row_multiple_relationships: - desc = "List of values (as strings) or IDs (as integers) from the linked table or empty list." - field_type = list[str | int] | None - else: - desc = "Single value (as string) or ID (as integer) from the linked table or empty list." - field_type = str | int | None - if examples: - desc += " " + examples - return FieldDefinition( - field_type, - Field(None, description=desc, title=orm_field.name), - to_django_orm, - from_django_orm, - ) - - case _: - return FieldDefinition() # Unsupported field type - - -def get_create_row_model(table: Table, field_ids: list[int] | None = None) -> BaseModel: - """ - Dynamically creates a Pydantic model for the given table based on its fields, to be - used for row creation and validation. - """ - - model_name = f"Table{table.id}Row" - - field_definitions = {} - field_conversions = {} - - table_model = table.get_model() - for field_object in table_model.get_field_objects(): - field_definition = _get_pydantic_field_definition(field_object) - if field_definition.type is None: - continue # Skip unsupported field types - if field_ids is not None and field_object["field"].id not in field_ids: - continue # Skip fields not in the specified list - - field = field_object["field"] - field_definitions[field.name] = ( - field_definition.type, - field_definition.field_def, - ) - field_conversions[field.name] = ( - field.db_column, - field_definition.to_django_orm, - field_definition.from_django_orm, - ) - - class TableRowModel(BaseModel): - model_config = ConfigDict( - extra="forbid", - ) - - def to_django_orm(self) -> dict[str, Any]: - orm_data = {} - for key, value in self.__dict__.items(): - if key == "id": - orm_data["id"] = value - continue - - if key not in field_conversions or value == "__NO_CHANGE__": - continue - - orm_key, to_django_orm, _ = field_conversions[key] - if to_django_orm: - orm_data[orm_key] = to_django_orm(value) - else: - orm_data[orm_key] = value - return orm_data - - @classmethod - def from_django_orm( - cls, orm_row: GeneratedTableModel, field_ids: list[int] | None = None - ) -> "TableRowModel": - init_data = {"id": orm_row.id} - for field_object in orm_row.get_field_objects(): - field = field_object["field"] - if field.name not in field_conversions: - continue - if field_ids is not None and field.id not in field_ids: - continue - db_column, _, from_django_orm = field_conversions[field.name] - value = getattr(orm_row, db_column) - if from_django_orm: - init_data[field.name] = from_django_orm(value) - else: - init_data[field.name] = value - return cls(**init_data) - - return create_model( - model_name, - __module__=__name__, - __base__=TableRowModel, - **field_definitions, - ) - - -def get_update_row_model(table) -> BaseModel: - """Creates an update model where all fields can be NoChange.""" - - create_model_class = get_create_row_model(table) - - # Build update fields - all fields become Union[OriginalType, NoChange] - update_fields = {} - - for field_name, field_info in create_model_class.model_fields.items(): - original_type = field_info.annotation - - update_fields[field_name] = ( - Union[NoChange, original_type], - Field( - ..., - description=f"Use '__NO_CHANGE__' to keep current value. To update, use a {field_info.description}", - ), - ) - - update_fields["id"] = (int, Field(..., description="The ID of the row to update")) - - # Create the update model - UpdateRowModel = create_model( - f"UpdateTable{table.id}Row", - __base__=create_model_class, - **update_fields, - ) - - return UpdateRowModel - - -def get_view(user, view_id: int): - return ViewHandler().get_view_as_user(user, view_id) - - -def get_table_rows_tools( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers", table: Table -): - row_model_for_create = get_create_row_model(table) - row_model_for_update = get_update_row_model(table) - row_model_for_response = create_model( - f"ResponseTable{table.id}Row", - id=(int, ...), - __base__=row_model_for_create, - ) - - def _create_rows( - rows: list[dict[str, Any]], - ) -> list[dict[str, Any]]: - """ - Create new rows in the specified table. - """ - - nonlocal \ - user, \ - workspace, \ - tool_helpers, \ - row_model_for_create, \ - row_model_for_response - - if not rows: - return [] - - tool_helpers.update_status( - _("Creating rows in %(table_name)s ") % {"table_name": table.name} - ) - - with transaction.atomic(): - orm_rows = CreateRowsActionType.do( - user, - table, - [row_model_for_create(**row).to_django_orm() for row in rows], - ) - - return {"created_row_ids": [r.id for r in orm_rows]} - - create_row_model_schema = minimize_schema( - resolve_json_schema_reference(row_model_for_create.model_json_schema()) - ) - create_rows_tool = udspy.Tool( - func=_create_rows, - name=f"create_rows_in_table_{table.id}", - description=f"Creates new rows in the table {table.name} (ID: {table.id}). Max 20 rows at a time.", - args={ - "rows": { - "items": create_row_model_schema, - "type": "array", - "maxItems": 20, - } - }, - ) - - def _update_rows( - rows: list[dict[str, Any]], - ) -> list[dict[str, Any]]: - """ - Update existing rows in the specified table. - """ - - nonlocal user, workspace, tool_helpers, row_model_for_update - - if not rows: - return [] - - tool_helpers.update_status( - _("Updating rows in %(table_name)s ") % {"table_name": table.name} - ) - - with transaction.atomic(): - orm_rows = UpdateRowsActionType.do( - user, - table, - [row_model_for_update(**row).to_django_orm() for row in rows], - ).updated_rows - - return {"updated_row_ids": [r.id for r in orm_rows]} - - update_row_model_schema = minimize_schema( - resolve_json_schema_reference(row_model_for_update.model_json_schema()) - ) - update_rows_tool = udspy.Tool( - func=_update_rows, - name=f"update_rows_in_table_{table.id}", - description=f"Updates existing rows in the table {table.name} (ID: {table.id}), identified by their row IDs. Max 20 at a time.", - args={ - "rows": { - "items": update_row_model_schema, - "type": "array", - "maxItems": 20, - } - }, - ) - - def _delete_rows(row_ids: list[int]) -> str: - """ - Delete rows in the specified table. - """ - - nonlocal user, workspace, tool_helpers - - if not row_ids: - return - - tool_helpers.update_status( - _("Deleting rows in %(table_name)s ") % {"table_name": table.name} - ) - - with transaction.atomic(): - DeleteRowsActionType.do(user, table, row_ids) - - return {"deleted_row_ids": row_ids} - - delete_rows_tool = udspy.Tool( - func=_delete_rows, - name=f"delete_rows_in_table_{table.id}", - description=f"Deletes rows in the table {table.name} (ID: {table.id}). Max 20 at a time.", - args={ - "row_ids": { - "items": {"type": "integer"}, - "type": "array", - "maxItems": 20, - } - }, - ) - - return { - "create": create_rows_tool, - "update": update_rows_tool, - "delete": delete_rows_tool, - } - - -def create_view_filter( - user: AbstractUser, - orm_view: View, - table_fields: list[Field], - view_filter_item: AnyViewFilterItemCreate, -) -> ViewFilter: - """ - Creates a view filter from the given view filter item. - """ - - field = table_fields.get(view_filter_item.field_id) - if field is None: - raise ValueError("Field not found for filter") - field_type = field_type_registry.get_by_model(field.specific_class) - if field_type.type != view_filter_item.type: - raise ValueError("Field type mismatch for filter") - - filter_type = view_filter_item.get_django_orm_type(field) - filter_value = view_filter_item.get_django_orm_value( - field, timezone=user.profile.timezone - ) - - return CreateViewFilterActionType.do( - user, - orm_view, - field, - filter_type, - filter_value, - filter_group_id=None, - ) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tool_types.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tool_types.py new file mode 100644 index 0000000000..23864ac773 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tool_types.py @@ -0,0 +1,15 @@ +from baserow_enterprise.assistant.tools.registries import AssistantToolType + + +class NavigationToolType(AssistantToolType): + type = "navigation" + + def get_tool_functions(self): + from .tools import TOOL_FUNCTIONS + + return TOOL_FUNCTIONS + + def get_toolset(self): + from .tools import navigation_toolset + + return navigation_toolset diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tools.py index a9ad456ac7..53b867d789 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tools.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tools.py @@ -1,52 +1,53 @@ -from typing import TYPE_CHECKING, Callable +from typing import Annotated -from django.contrib.auth.models import AbstractUser +from django.core.exceptions import ObjectDoesNotExist from django.utils.translation import gettext as _ -from baserow.core.models import Workspace -from baserow_enterprise.assistant.tools.registries import AssistantToolType +from pydantic import Field +from pydantic_ai import RunContext +from pydantic_ai.toolsets import FunctionToolset -from .types import AnyNavigationRequestType +from baserow_enterprise.assistant.deps import AssistantDeps -if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers +from .types import AnyNavigationRequestType -def get_navigation_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[AnyNavigationRequestType], str]: +def navigate( + ctx: RunContext[AssistantDeps], + request: Annotated[ + AnyNavigationRequestType, + Field( + description="The navigation target: either a specific table or the workspace home." + ), + ], + thought: Annotated[ + str, Field(description="Brief reasoning for calling this tool.") + ], +) -> str: + """\ + Navigate the UI to a table, view, automation, page, or workspace home. + + WHEN to use: User asks to open, go to, or see something in the workspace. Also after creating new resources (views, fields, rows) in an existing database or table. + WHAT it does: Navigates the UI to a table, view, automation workflow, builder page, or workspace home. + RETURNS: Confirmation of navigation. + DO NOT USE when: You need data — use list/get tools instead. Navigation only changes the UI focus. """ - Returns a function that provides navigation instructions to the user based on - their current workspace context. - """ - - def navigate(request: AnyNavigationRequestType) -> str: - """ - Navigate within the workspace. - Use when: - - the user asks to open, go, to be brought to something - - the user asks to see something from their workspace - - if something new has been created in a previously existing database or table, - like a view, a field or some rows - """ - - nonlocal user, workspace + user = ctx.deps.user + workspace = ctx.deps.workspace + tool_helpers = ctx.deps.tool_helpers + try: location = request.to_location(user, workspace, request) + except ObjectDoesNotExist: + return "Error: could not navigate — the target was not found. Check that the ID is correct." - tool_helpers.update_status( - _("Navigating to %(location)s...") - % {"location": location.to_localized_string()} - ) - return tool_helpers.navigate_to(location) - - return navigate - + tool_helpers.update_status( + _("Navigating to %(location)s...") + % {"location": location.to_localized_string()} + ) + return tool_helpers.navigate_to(location) -class NavigationToolType(AssistantToolType): - type = "navigation" - @classmethod - def get_tool(cls, user, workspace, tool_helpers): - return get_navigation_tool(user, workspace, tool_helpers) +TOOL_FUNCTIONS = [navigate] +navigation_toolset = FunctionToolset(TOOL_FUNCTIONS, max_retries=3) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/types.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/types.py index f0421eb7bf..49219c4da0 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/types.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/types.py @@ -5,11 +5,10 @@ from pydantic import Field from baserow.core.models import Workspace -from baserow_enterprise.assistant.tools.database.utils import filter_tables +from baserow_enterprise.assistant.tools.database.helpers import filter_tables from baserow_enterprise.assistant.types import ( BaseModel, TableNavigationType, - WorkspaceNavigationType, ) @@ -51,22 +50,7 @@ def to_location( ) -class WorkspaceNavigationRequestType(NavigationRequestType): - type: Literal["workspace"] = Field( - ..., description="The home page of the workspace" - ) - - @classmethod - def to_location( - cls, - user: AbstractUser, - workspace: Workspace, - request: "WorkspaceNavigationRequestType", - ) -> WorkspaceNavigationType: - return WorkspaceNavigationType(type="workspace", id=workspace.id) - - AnyNavigationRequestType = Annotated[ - TableNavigationRequestType | WorkspaceNavigationRequestType, + TableNavigationRequestType, Field(discriminator="type"), ] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py index ee2fd6c157..d6753d751a 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py @@ -1,16 +1,16 @@ +from baserow_enterprise.assistant.deps import EventBus from baserow_enterprise.assistant.types import AiNavigationMessage, AnyNavigationType -def unsafe_navigate_to(location: AnyNavigationType) -> str: +def unsafe_navigate_to(location: AnyNavigationType, event_bus: EventBus) -> str: """ Navigate to a specific table or view without any safety checks. Make sure all the IDs provided are valid and can be accessed by the user before calling this function. - :param navigation_type: The type of navigation to perform. + :param location: The type of navigation to perform. + :param event_bus: The event bus to emit the navigation event on. """ - from udspy.streaming import emit_event - - emit_event(AiNavigationMessage(location=location)) + event_bus.emit(AiNavigationMessage(location=location)) return "Navigated successfully." diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py index df3a7dea16..e0d413dbed 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py @@ -1,111 +1,190 @@ -from typing import TYPE_CHECKING, Any, Callable +""" +Baserow registry for assistant tool types. -from django.contrib.auth.models import AbstractUser +Each tool module (navigation, database, etc.) registers an +``AssistantToolType`` instance. The registry assembles the combined +toolset at runtime, filtering by ``can_use(user, workspace)`` so +individual tool groups can be gated on permissions or feature flags. +""" -from baserow.core.exceptions import ( - InstanceTypeAlreadyRegistered, - InstanceTypeDoesNotExist, +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable + +from pydantic_ai.toolsets import AbstractToolset, CombinedToolset + +from baserow.core.registry import Instance, Registry +from baserow_enterprise.assistant.deps import AgentMode + +from .toolset import ( + InlineRefsToolset, + ModeAwareToolset, + generate_tool_manifest_compact, ) -from baserow.core.models import Workspace -from baserow.core.registries import Instance, Registry if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers + from django.contrib.auth.models import AbstractUser + from baserow.core.models import Workspace + from baserow_enterprise.assistant.deps import AssistantDeps -class AssistantToolType(Instance): - name: str = "" - - @classmethod - def can_use(cls, user: AbstractUser, workspace: Workspace, *args, **kwargs) -> bool: - """ - Returns whether or not the given user can use this tool in the given workspace. - :param user: The user to check if they can use this tool. - :param workspace: The workspace where to check if the tool can be used. - :return: True if the user can use this tool, False otherwise. - """ +class AssistantToolType(Instance): + """ + Base class for assistant tool groups. - return True + Each subclass represents a logical group of tools (e.g. "database", + "navigation"). Override ``can_use`` to gate availability on user + permissions or feature flags. + """ - @classmethod - def on_tool_start( - cls, - call_id: str, - instance: Any, - inputs: dict[str, Any], - ): - """ - Called when the tool is started. It can be used to stream status messages. + type: str = "" - :param call_id: The unique identifier of the tool call. - :param instance: The instance of the udspy tool being called. - :param inputs: The inputs provided to the tool. + def can_use(self, user: "AbstractUser", workspace: "Workspace") -> bool: """ + Permission gate. Override in subclasses for conditional availability. - pass - - @classmethod - def on_tool_end( - cls, - call_id: str, - instance: Any, - inputs: dict[str, Any], - outputs: dict[str, Any] | None, - exception: Exception | None = None, - ): - """ - Called when the tool has finished, either successfully or with an exception. - - :param call_id: The unique identifier of the tool call. - :param instance: The instance of the udspy tool being called. - :param inputs: The inputs provided to the tool. - :param outputs: The outputs returned by the tool, or None if there was an - exception. - :param exception: The exception raised by the tool, or None if it was - successful. + :param user: The requesting user. + :param workspace: The current workspace. + :return: ``True`` if this tool group should be included. """ - pass + return True - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - """ - Returns the actual tool function to be called to pass to the udspy react agent. + def get_tool_functions(self) -> list[Callable]: + """Return the raw tool functions for manifest generation.""" - :param user: The user that will be using the tool. - :param workspace: The workspace the user is currently in. - :param tool_helpers: A dataclass containing helper functions that can be used by - the tool function. - """ + raise NotImplementedError - raise NotImplementedError("Subclasses must implement this method.") + def get_toolset(self) -> AbstractToolset: + """Return the pydantic-ai ``FunctionToolset`` for this group.""" + raise NotImplementedError -class AssistantToolDoesNotExist(InstanceTypeDoesNotExist): - pass + def get_routing_rules(self) -> str: + """Return routing rules text for this tool group's manifest. + Override in subclasses that define mode-specific routing rules. + Returns empty string by default (no rules). + """ -class AssistantToolAlreadyRegistered(InstanceTypeAlreadyRegistered): - pass + return "" class AssistantToolRegistry(Registry[AssistantToolType]): name = "assistant_tool" - does_not_exist_exception_class = AssistantToolDoesNotExist - already_registered_exception_class = AssistantToolAlreadyRegistered + def build_toolset( + self, + user: "AbstractUser", + workspace: "Workspace", + model: str, + deps: "AssistantDeps", + ) -> tuple[AbstractToolset, str, str, str, str]: + """ + Assemble the combined assistant toolset, filtering by ``can_use()``. + + :param user: The requesting user. + :param workspace: The current workspace. + :param model: The pydantic-ai model string. + :param deps: The assistant deps (used for mode-aware filtering). + :return: ``(toolset, database_manifest, application_manifest, + automation_manifest, explain_manifest)``. + """ + + toolsets: list[AbstractToolset] = [] + module_groups: list[tuple[str, list[Callable]]] = [] + + for tool_type in self.get_all(): + if not tool_type.can_use(user, workspace): + continue + toolsets.append(tool_type.get_toolset()) + module_groups.append((tool_type.type, tool_type.get_tool_functions())) + + combined = CombinedToolset(toolsets) + mode_aware = ModeAwareToolset(combined, deps) + + from .toolset import _get_mode_tool_map + + # Build a routing-rules lookup from registered tool types so each + # module owns its own rules (no hardcoded imports here). + routing_rules_by_type: dict[str, str] = { + tt.type: tt.get_routing_rules() + for tt in self.get_all() + if tt.get_routing_rules() + } + + mode_map = _get_mode_tool_map() + shared = mode_map[AgentMode.DATABASE] & mode_map[AgentMode.APPLICATION] + + _mode_config: list[tuple[str, AgentMode, str]] = [ + ("database", AgentMode.DATABASE, routing_rules_by_type.get("database", "")), + ( + "application", + AgentMode.APPLICATION, + routing_rules_by_type.get("builder", ""), + ), + ( + "automation", + AgentMode.AUTOMATION, + routing_rules_by_type.get("automation", ""), + ), + ] - def list_all_usable_tools( - self, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> list[AssistantToolType]: - return [ - tool_type.get_tool(user, workspace, tool_helpers) - for tool_type in self.get_all() - if tool_type.can_use(user, workspace) + manifests = {} + for mode_key, mode, rules in _mode_config: + allowed = mode_map[mode] + groups = [ + (label, [f for f in funcs if f.__name__ in allowed]) + for label, funcs in module_groups + ] + manifest = generate_tool_manifest_compact(groups, routing_rules=rules) + + # Append a compact cross-mode summary so the agent knows what + # capabilities exist in other modes (and can switch_mode to use them). + other_lines = [] + for other_key, other_mode, _ in _mode_config: + if other_key == mode_key: + continue + specific = mode_map[other_mode] - shared + other_lines.append(f"- {other_key}: {', '.join(sorted(specific))}") + if other_lines: + manifest += "\n\n## Other modes (switch_mode to access)\n" + "\n".join( + other_lines + ) + + manifests[mode_key] = manifest + + explain_allowed = mode_map[AgentMode.EXPLAIN] + explain_groups = [ + (label, [f for f in funcs if f.__name__ in explain_allowed]) + for label, funcs in module_groups ] + manifests["explain"] = generate_tool_manifest_compact(explain_groups) + + return ( + InlineRefsToolset(mode_aware, model=model), + manifests["database"], + manifests["application"], + manifests["automation"], + manifests["explain"], + ) assistant_tool_registry = AssistantToolRegistry() + + +def get_shared_read_funcs() -> list[Callable]: + """ + Return read-only tool functions shared across sub-agents. + + Uses deferred imports to avoid circular dependencies. + """ + + from baserow_enterprise.assistant.tools.database.tools import ( + get_tables_schema, + list_rows, + list_tables, + ) + + return [list_tables, get_tables_schema, list_rows] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_user_docs/tool_types.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_user_docs/tool_types.py new file mode 100644 index 0000000000..95860d2695 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_user_docs/tool_types.py @@ -0,0 +1,20 @@ +from baserow_enterprise.assistant.tools.registries import AssistantToolType + + +class SearchDocsToolType(AssistantToolType): + type = "search_user_docs" + + def can_use(self, user, workspace) -> bool: + from .handler import KnowledgeBaseHandler + + return KnowledgeBaseHandler().can_search() + + def get_tool_functions(self): + from .tools import TOOL_FUNCTIONS + + return TOOL_FUNCTIONS + + def get_toolset(self): + from .tools import search_docs_toolset + + return search_docs_toolset diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_user_docs/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_user_docs/tools.py index 4d337e7685..bb1eba26f9 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_user_docs/tools.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_user_docs/tools.py @@ -1,64 +1,70 @@ -from typing import TYPE_CHECKING, Annotated, Any, Callable +import re +from typing import Annotated, Any -from django.contrib.auth.models import AbstractUser from django.utils.translation import gettext as _ -import udspy from asgiref.sync import sync_to_async +from loguru import logger +from pydantic import BaseModel as PydanticBaseModel +from pydantic import Field +from pydantic_ai import Agent, RunContext +from pydantic_ai.toolsets import FunctionToolset -from baserow.core.models import Workspace +from baserow_enterprise.assistant.deps import AssistantDeps from baserow_enterprise.assistant.models import KnowledgeBaseChunk -from baserow_enterprise.assistant.tools.registries import AssistantToolType from .handler import KnowledgeBaseHandler -if TYPE_CHECKING: - from baserow_enterprise.assistant.assistant import ToolHelpers - - -class SearchDocsSignature(udspy.Signature): - """ - Given a user question and documentation chunks as context, provide an accurate - and concise answer along with a reliability score. - - CRITICAL: The context may contain documents retrieved by keyword similarity that - are NOT actually relevant to the user's question. You MUST carefully evaluate - each document's ACTUAL TOPIC before using it: - - 1. First, identify the SPECIFIC FEATURE or concept the user is asking about - 2. For each document, check if it DIRECTLY explains that specific feature - 3. IGNORE documents that merely mention similar keywords but cover different topics - (e.g., if asked about "webhooks in Baserow", ignore docs about external - webhook services or third-party integrations - only use docs about - Baserow's native webhook feature) - 4. Only use documents that would genuinely help answer THIS specific question - - If no documents in the context actually address the user's question (even if - they contain similar words), respond with "Nothing found in the documentation." - - Include instructions and URLs from the documentation when relevant. - Never fabricate answers or URLs. - """ - - question: str = udspy.InputField() - context: dict[str, str] = udspy.InputField( - desc=( - "A mapping of source URLs to documents. WARNING: These documents were " - "retrieved by keyword similarity and may include irrelevant results. " - "Carefully filter to only use documents that DIRECTLY address the question." - ) - ) - - answer: str = udspy.OutputField() - sources: list[str] = udspy.OutputField( - desc=( +# Regex that matches assistant tool names in a search query. Used to +# short-circuit search_user_docs when the model is trying to look up how +# its own tools work instead of answering a user question. +_TOOL_QUERY_RE = re.compile( + r"(?:list|create|get|update|delete|generate|load|add)_" + r"(?:tables?|fields?|views?|rows?|pages?|elements?|actions?|data_sources?|" + r"theme|workflows?|view_filters?|formula|row_tools|" + r"action_field_mapping|rows_in_table)" + r"|search_user_docs" + r"|\bnavigate\s+(?:tool|function|param)", + re.IGNORECASE, +) + + +SEARCH_DOCS_INSTRUCTIONS = """\ +Given a user question and documentation chunks as context, provide an accurate +and concise answer along with a reliability score. + +CRITICAL: The context may contain documents retrieved by keyword similarity that +are NOT actually relevant to the user's question. You MUST carefully evaluate +each document's ACTUAL TOPIC before using it: + +1. First, identify the SPECIFIC FEATURE or concept the user is asking about +2. For each document, check if it DIRECTLY explains that specific feature +3. IGNORE documents that merely mention similar keywords but cover different topics + (e.g., if asked about "webhooks in Baserow", ignore docs about external + webhook services or third-party integrations - only use docs about + Baserow's native webhook feature) +4. Only use documents that would genuinely help answer THIS specific question + +If no documents in the context actually address the user's question (even if +they contain similar words), respond with "Nothing found in the documentation." + +Include instructions and URLs from the documentation when relevant. +Never fabricate answers or URLs. +""" + + +class SearchDocsResult(PydanticBaseModel): + answer: str = Field(description="The answer to the user's question.") + sources: list[str] = Field( + default_factory=list, + description=( "URLs of documents that were ACTUALLY USED to form the answer. " "Only include sources that directly addressed the question topic. " "Leave empty if no documents were relevant. Maximum 3 URLs, ordered by relevance." - ) + ), ) - reliability: float = udspy.OutputField( - desc=( + reliability: float = Field( + description=( "How well the RELEVANT documents (not all documents) support the answer. " "1.0 = found documents that directly and completely answer the question. " "0.5 = found partially relevant information. " @@ -66,155 +72,195 @@ class SearchDocsSignature(udspy.Signature): ) ) - @classmethod - def format_context(cls, chunks: list[KnowledgeBaseChunk]) -> dict[str, str]: - """ - Formats the context as a list of strings for the signature. - Each string is formatted as "Source URL: content". - :param chunks: The list of knowledge base chunks. - :return: A dictionary mapping source URLs to their combined content. - """ +search_docs_agent: Agent[None, SearchDocsResult] = Agent( + output_type=SearchDocsResult, + instructions=SEARCH_DOCS_INSTRUCTIONS, + name="search_docs_agent", +) - context = {} - for chunk in chunks: - url = chunk.source_document.source_url - content = chunk.content - if url not in context: - context[url] = content - else: - context[url] += "\n" + content - - return context +def format_context(chunks: list[KnowledgeBaseChunk]) -> dict[str, str]: + """ + Formats the context as a mapping of source URLs to their combined content. -def get_search_user_docs_tool( - user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[str], dict[str, Any]]: + :param chunks: The list of knowledge base chunks. + :return: A dictionary mapping source URLs to their combined content. """ - Returns a tool function that searches Baserow's knowledge base and uses an LLM - to filter and synthesize relevant documentation into a focused answer. - The search retrieves documents by keyword similarity, then the LLM evaluates - each document's actual relevance to the question before generating an answer. + context = {} + for chunk in chunks: + url = chunk.source_document.source_url + content = chunk.content + if url not in context: + context[url] = content + else: + context[url] += "\n" + content + + return context + + +async def search_user_docs( + ctx: RunContext[AssistantDeps], + question: Annotated[ + str, + ( + "A precise search query in English using Baserow terminology. " + "Focus on the SPECIFIC Baserow feature being asked about. " + "Include the feature name and action, e.g., 'How to create webhooks in Baserow' " + "or 'Baserow table linking feature'. Avoid generic terms that could match " + "unrelated documentation about third-party services or integrations." + ), + ], + thought: Annotated[str, "Brief reasoning for calling this tool."], +) -> dict[str, Any]: + """\ + Search Baserow end-user docs for feature guides. NOT for tool introspection. It doesn't provide any information about your own tools. + + WHEN to use: User explicitly asks how to do something in Baserow's UI, or wants to learn about a specific Baserow feature (e.g., linking tables, webhooks, forms). + WHAT it does: Searches official Baserow end-user documentation and returns an answer with reliability score and source URLs. + RETURNS: Answer, reliability score (0.0-1.0), reliability_note (HIGH/PARTIAL/LOW), source URLs. Always check reliability_note before using the answer. + DO NOT USE when: Looking up how YOUR OWN tools work — you already know your tools from their names, descriptions, and schemas. Also not for API/programming documentation. + + IMPORTANT: Frame the question to target Baserow's NATIVE features specifically. + For example, ask about "Baserow webhooks" not just "webhooks" to avoid getting + results about external webhook services that integrate WITH Baserow. """ - async def search_user_docs( - question: Annotated[ - str, - ( - "A precise search query in English using Baserow terminology. " - "Focus on the SPECIFIC Baserow feature being asked about. " - "Include the feature name and action, e.g., 'How to create webhooks in Baserow' " - "or 'Baserow table linking feature'. Avoid generic terms that could match " - "unrelated documentation about third-party services or integrations." + tool_helpers = ctx.deps.tool_helpers + + # Guard: reject queries about the model's own tools. + if _TOOL_QUERY_RE.search(question): + logger.info("search_user_docs: rejected tool-introspection query: {}", question) + return { + "answer": ( + "STOP. This tool searches END-USER documentation only — " + "it has no information about your tools. " + "You already know how to use your tools from their names, " + "descriptions, and parameter schemas. " + "If a tool call failed, read the error message carefully " + "and adjust the parameters." ), - ], - ) -> dict[str, Any]: - """ - Search Baserow's official documentation for user guides and feature - explanations. - - PURPOSE: Provides end-user documentation about Baserow's built-in - features and how to use them through the UI. - - USE WHEN: The user asks how to do something in Baserow, wants to learn - about a Baserow feature, or needs step-by-step instructions. - - DO NOT USE FOR: Agent tool usage, API implementation details, or - programming help. - - IMPORTANT: Frame the question to target Baserow's NATIVE features - specifically. For example, ask about "Baserow webhooks" not just - "webhooks" to avoid getting results about external webhook services that - integrate WITH Baserow. - """ - - nonlocal tool_helpers - - tool_helpers.update_status(_("Exploring the knowledge base...")) - - @sync_to_async - def _search(question: str) -> list[KnowledgeBaseChunk]: - chunks = KnowledgeBaseHandler().search(question, 15) - return list(chunks) - - searcher = udspy.ChainOfThought(SearchDocsSignature) - relevant_chunks = await _search(question) - prediction = await searcher.aexecute( - question=question, - context=SearchDocsSignature.format_context(relevant_chunks), - stream=True, - ) + "reliability": 0.0, + "reliability_note": "REJECTED: Tool-introspection query.", + "sources": [], + } - sources = [] - available_urls = {chunk.source_document.source_url for chunk in relevant_chunks} - for url in prediction.sources: - # somehow LLMs sometimes return sources as objects - if isinstance(url, dict) and "url" in url: - url = url["url"] - - if not isinstance(url, str): - continue - - if url in available_urls and url not in sources: - sources.append(url) - if len(sources) >= 3: - break - - # Only fallback to available URLs if reliability is high AND we have a - # real answer. Don't populate sources if the model indicated no relevant - # docs were found. - nothing_found = "nothing found" in prediction.answer.lower() - if not sources and prediction.reliability > 0.8 and not nothing_found: - sources = list(available_urls)[:3] - - # Override reliability to 0 if the model explicitly said nothing was - # found. The model sometimes returns high reliability for "nothing - # found" answers, which is semantically incorrect - we want reliability - # to reflect whether we actually found useful information. - reliability = 0.0 if nothing_found else prediction.reliability - - if reliability >= 0.7: - reliability_note = ( - "HIGH CONFIDENCE: Answer is well-supported by the documentation." - ) - elif reliability >= 0.4: - reliability_note = ( - "PARTIAL MATCH: Some relevant information was found, but the " - "documentation may not fully cover this topic. Supplement with " - "general knowledge but warn the user that details may be incomplete." - ) - else: - reliability_note = ( + tool_helpers.update_status(_("Exploring the knowledge base...")) + + try: + return await _search_user_docs_impl(ctx, question) + except Exception: + logger.exception("search_user_docs failed for question: {}", question) + return { + "answer": "An error occurred while searching the documentation.", + "reliability": 0.0, + "reliability_note": ( + "LOW CONFIDENCE: The documentation search encountered an error. " + "Inform the user that documentation search is temporarily " + "unavailable and suggest they check baserow.io/docs directly." + ), + "sources": [], + } + + +async def _search_user_docs_impl( + ctx: RunContext[AssistantDeps], + question: str, +) -> dict[str, Any]: + """Inner implementation of search_user_docs, separated for error handling.""" + + @sync_to_async + def _search(question: str) -> list[KnowledgeBaseChunk]: + chunks = KnowledgeBaseHandler().search(question, 15) + return list(chunks) + + relevant_chunks = await _search(question) + + if not relevant_chunks: + return { + "answer": "Nothing found in the documentation.", + "reliability": 0.0, + "reliability_note": ( "LOW CONFIDENCE: The documentation does not contain information about " "this topic. DO NOT provide an answer based on general knowledge or " "assumptions - the feature may not exist in Baserow. Tell the user: " "'I couldn't find information about this in the official Baserow " "documentation.' and suggest they check the community forum or " "contact support." - ) - - return { - "answer": prediction.answer, - "reliability": reliability, - "reliability_note": reliability_note, - "sources": sources, + ), + "sources": [], } - return search_user_docs + context = format_context(relevant_chunks) + + prompt = ( + f"Question: {question}\n\n" + f"Documentation context (source URL -> content):\n{context}" + ) + from baserow_enterprise.assistant.model_profiles import get_model_string + + agent_result = await search_docs_agent.run(prompt, model=get_model_string()) + prediction = agent_result.output + + sources = [] + available_urls = {chunk.source_document.source_url for chunk in relevant_chunks} + for url in prediction.sources: + # somehow LLMs sometimes return sources as objects + if isinstance(url, dict) and "url" in url: + url = url["url"] + + if not isinstance(url, str): + continue + + if url in available_urls and url not in sources: + sources.append(url) + if len(sources) >= 3: + break + + # Only fallback to available URLs if reliability is high AND we have a + # real answer. Don't populate sources if the model indicated no relevant + # docs were found. + nothing_found = "nothing found" in prediction.answer.lower() + if not sources and prediction.reliability > 0.8 and not nothing_found: + sources = list(available_urls)[:3] + + # Override reliability to 0 if the model explicitly said nothing was + # found. The model sometimes returns high reliability for "nothing + # found" answers, which is semantically incorrect - we want reliability + # to reflect whether we actually found useful information. + reliability = 0.0 if nothing_found else prediction.reliability + + if reliability >= 0.7: + reliability_note = ( + "HIGH CONFIDENCE: Answer is well-supported by the documentation." + ) + elif reliability >= 0.4: + reliability_note = ( + "PARTIAL MATCH: Some relevant information was found, but the " + "documentation may not fully cover this topic. Supplement with " + "general knowledge but warn the user that details may be incomplete." + ) + else: + reliability_note = ( + "LOW CONFIDENCE: The documentation does not contain information about " + "this topic. DO NOT provide an answer based on general knowledge or " + "assumptions - the feature may not exist in Baserow. Tell the user: " + "'I couldn't find information about this in the official Baserow " + "documentation.' and suggest they check the community forum or " + "contact support." + ) + if sources: + ctx.deps.extend_sources(sources) -class SearchDocsToolType(AssistantToolType): - type = "search_user_docs" + return { + "answer": prediction.answer, + "reliability": reliability, + "reliability_note": reliability_note, + "sources": sources, + } - def can_use( - self, user: AbstractUser, workspace: Workspace, *args, **kwargs - ) -> bool: - return KnowledgeBaseHandler().can_search() - @classmethod - def get_tool( - cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" - ) -> Callable[[Any], Any]: - return get_search_user_docs_tool(user, workspace, tool_helpers) +TOOL_FUNCTIONS = [search_user_docs] +search_docs_toolset = FunctionToolset(TOOL_FUNCTIONS, max_retries=3) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/__init__.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/__init__.py new file mode 100644 index 0000000000..b586400c55 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/__init__.py @@ -0,0 +1,25 @@ +from .agents import get_formula_generator +from .formula_utils import ( + FORMULA_PREFIX, + RAW_FORMULA_RE, + BaseFormulaContext, + create_example_from_json_schema, + formula_desc, + literal_or_placeholder, + minimize_json_schema, + needs_formula, + wrap_static_string, +) + +__all__ = [ + "FORMULA_PREFIX", + "RAW_FORMULA_RE", + "needs_formula", + "formula_desc", + "literal_or_placeholder", + "wrap_static_string", + "minimize_json_schema", + "create_example_from_json_schema", + "BaseFormulaContext", + "get_formula_generator", +] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/agents.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/agents.py new file mode 100644 index 0000000000..058e5f8ecd --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/agents.py @@ -0,0 +1,141 @@ +""" +Shared formula generation agent factory. + +Contains: +- ``FormulaGeneratorOutput``: Output model for the formula generator agent. +- ``get_formula_generator()``: Factory to create a formula generator with a custom prompt. +""" + +from typing import Callable + +from pydantic import BaseModel as PydanticBaseModel +from pydantic import Field +from pydantic_ai import Agent + +from baserow.core.formula import resolve_formula +from baserow.core.formula.registries import formula_runtime_function_registry +from baserow.core.formula.types import ( + BASEROW_FORMULA_MODE_ADVANCED, + BaserowFormulaObject, +) + +from .formula_utils import BaseFormulaContext + + +class FormulaGeneratorOutput(PydanticBaseModel): + """Output model for the formula generator agent.""" + + generated_formulas: dict[str, str] = Field( + description=( + "A mapping of field identifiers to their generated formulas. " + "Each key is a field id/name from `fields_to_resolve` and the value " + "is the generated formula string." + ) + ) + + +def get_formula_generator( + prompt: str, +) -> Callable[[dict, BaseFormulaContext, int], dict[str, str]]: + """ + Factory to create a formula generator with a custom prompt. + + :param prompt: The system prompt for the LLM describing available functions. + :return: A function that generates formulas from field descriptions. + """ + + formula_agent = Agent( + output_type=FormulaGeneratorOutput, + instructions=prompt, + name="formula_agent", + ) + + def check_formula(generated_formula: str, context: BaseFormulaContext) -> str: + """Validate a generated formula against the context.""" + try: + resolve_formula( + BaserowFormulaObject.create( + formula=generated_formula, mode=BASEROW_FORMULA_MODE_ADVANCED + ), + formula_runtime_function_registry, + context, + ) + except Exception as exc: + raise ValueError(f"Generated formula is invalid: {str(exc)}") + return "ok, the formula is valid" + + def generate_formulas( + fields_to_resolve: dict, + context: BaseFormulaContext, + max_retries: int = 3, + ) -> dict[str, str]: + """ + Generate formulas for the given field descriptions. + + :param fields_to_resolve: Dict mapping field names to descriptions. + :param context: Formula context with available data. + :param max_retries: Number of retry attempts on validation failure. + :return: Dict mapping field names to generated formulas. + :raises ValueError: If no valid formulas could be generated. + """ + feedback = "" + valid_formulas = {} + remaining = dict(fields_to_resolve) + + for __ in range(max_retries): + if not remaining: + break + + user_prompt = ( + f"Fields to resolve: {remaining}\n" + f"(If prefixed with [optional], the field is not mandatory.)\n\n" + f"Context: {context.get_formula_context()}\n\n" + f"Context metadata: {context.get_context_metadata()}\n" + f"(Metadata about the context fields, with refs and names " + f"to assist in formula generation.)\n\n" + f"Feedback: {feedback or 'None (first attempt)'}" + ) + from baserow_enterprise.assistant.model_profiles import ( + UTILITY, + get_model_settings, + get_model_string, + ) + + model = get_model_string() + try: + result = formula_agent.run_sync( + user_prompt, + model=model, + model_settings=get_model_settings(model, UTILITY), + ) + except Exception as exc: + feedback += f"Formula agent error: {str(exc)}\n" + continue + + generated_formulas = result.output.generated_formulas + for field_id, formula in generated_formulas.items(): + if field_id not in remaining: + continue + try: + check_formula(formula, context) + valid_formulas[field_id] = formula + remaining.pop(field_id, None) + except ValueError as exc: + feedback += ( + f"Error for {field_id}, formula {formula} not valid: " + f"{str(exc)}\n" + ) + + if not remaining: + return valid_formulas + + # Return any valid formulas we have, or raise if none + if valid_formulas: + return valid_formulas + else: + raise ValueError( + f"Failed to generate any valid formulas after " + f"{max_retries} attempts. Feedback:\n{feedback}" + ) + + return generate_formulas diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/formula_prompt.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/formula_prompt.py new file mode 100644 index 0000000000..30d87fcfdd --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/formula_prompt.py @@ -0,0 +1,84 @@ +""" +Shared formula language reference for formula generation prompts. + +This module contains the common formula language documentation shared between +the automation and builder formula generators. Context-specific sections +(automation paths, builder data providers) are appended by each consumer. +""" + +FORMULA_LANGUAGE = """\ +You are a formula builder. Generate formulas using the Baserow formula language. + +## Value Access + +**get(path)** - Retrieves values from context using dot-separated path notation +- Objects: get('user.name') +- Arrays by index: get('items.0'), get('orders.2.total') +- Nested: get('users.0.address.city') +- Wildcard: get('users.*.email') returns a list of values from all items + +## Field Type Suffixes + +When accessing database fields via get(), certain field types require a suffix +to extract the display value. Use the correct suffix based on the field type +reported in context_metadata: + +| Field type | Suffix | Example path | +|---|---|---| +| text, number, boolean, date, url, email, phone_number, rating, long_text, uuid | *(none)* | `field_10` | +| single_select | `.value` | `field_10.value` | +| multiple_select | `.*.value` | `field_10.*.value` | +| link_row | `.*.value` | `field_10.*.value` | +| last_modified_by | `.name` | `field_10.name` | +| created_by | `.name` | `field_10.name` | +| multiple_collaborators | `.*.name` | `field_10.*.name` | +| file | `.*.url` or `.*.visible_name` | `field_10.*.url` | + +Always check the field type in context_metadata and apply the matching suffix. + +## Operators + +**Comparison** (return boolean): +- equal(a, b), not_equal(a, b) +- greater_than(a, b), less_than(a, b) +- greater_than_or_equal(a, b), less_than_or_equal(a, b) +- Infix: a==b, a!=b, ab, a>=b + +**Arithmetic:** +- add(a, b) or a+b, minus(a, b) or a-b +- multiply(a, b) or a*b, divide(a, b) or a/b + +**Logic:** +- and(a, b), or(a, b) + +## Functions + +**Core:** +- concat(...args) - Join arguments into a string: concat('Hello ', get('name'), '!') +- if(condition, true_value, false_value) - Conditional expression + +**String:** +- upper(text), lower(text), capitalize(text) +- strip(text), replace(text, old, new), length(text), contains(text, search) +- split(text, separator), join(array, separator) + +**Number:** +- round(num, decimals), is_even(num), is_odd(num) + +**Date:** +- today() - Current date +- now() - Current date and time +- day(date), month(date), year(date), hour(datetime), minute(datetime), second(datetime) +- datetime_format(datetime, format) + +**Array:** +- sum(array), avg(array), at(array, index) + +**Utility:** +- is_empty(value), get_property(object, key) + +## Constants + +- String literals in single quotes: 'hello world', '123' +- Numbers: 42, 3.14 +""" diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/formula_utils.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/formula_utils.py new file mode 100644 index 0000000000..3b92631e00 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/shared/formula_utils.py @@ -0,0 +1,272 @@ +import re +from abc import ABC, abstractmethod +from datetime import date, datetime +from typing import Any + +from baserow.core.formula.types import FormulaContext +from baserow.core.utils import to_path + +# ============================================================================= +# Formula Detection Constants and Helpers +# ============================================================================= + +FORMULA_PREFIX = "$formula:" + +# Detects raw formula syntax the LLM might write instead of using $formula:. +# Matches: get('...'), concat(...), {{ ... }}, comparison operators, if(...), +# today(), now(). +RAW_FORMULA_RE = re.compile( + r"\bget\s*\(|\bconcat\s*\(|\{\{.*\}\}" + r"|\b(?:equal|not_equal|greater_than|less_than" + r"|greater_than_(?:or_)?equal|less_than_(?:or_)?equal)\s*\(" + r"|\bif\s*\(|\btoday\s*\(|\bnow\s*\(" +) + + +def needs_formula(value: str | None) -> bool: + """ + Check if a value requires formula processing. + + Returns True for explicit ``$formula:`` prefixed values *and* for raw + formula expressions the LLM may write inline (e.g. ``get('field')`` + or ``{{ get('field') }}``). + + :param value: The string value to check, or None. + :return: True if the value needs formula generation. + """ + + if not value: + return False + stripped = value.strip() + return stripped.lower().startswith(FORMULA_PREFIX) or bool( + RAW_FORMULA_RE.search(stripped) + ) + + +def formula_desc(value: str) -> str: + """ + Extract the formula description from a value. + + For ``$formula:`` prefixed values, strips the prefix. + For raw formula expressions, returns the value as-is so the + formula generator can convert it to a proper formula. + + :param value: A string containing a formula description or raw formula. + :return: The description text or raw formula expression. + """ + + stripped = value.strip() + if stripped.lower().startswith(FORMULA_PREFIX): + return stripped[len(FORMULA_PREFIX) :].strip() + # Raw formula expression — pass through for the generator to fix up + return stripped + + +def literal_or_placeholder(value: str | None) -> str: + """ + Return a quoted literal formula, or empty placeholder for formula values. + + Used when creating ORM objects: formula fields get a ``''`` placeholder + that will be replaced later by the formula generator, while literal + values are wrapped in single quotes. + + :param value: The string value, or None. + :return: A single-quoted formula literal or ``''`` placeholder. + """ + + if not value or needs_formula(value): + return "''" + return wrap_static_string(value) + + +def wrap_static_string(value: str) -> str: + """ + Wrap a static string as a Baserow formula literal. + + If the value is already a quoted formula literal (e.g. ``'Submit'``), + it is returned unchanged to avoid double-wrapping which would produce + escaped quotes visible in the UI (e.g. ``'\\'Submit\\''``). + + :param value: Plain text string or already-quoted formula literal. + :return: Formula-compatible string literal with proper escaping. + """ + + if len(value) >= 2 and value[0] == "'" and value[-1] == "'": + return value + escaped = value.replace("'", "\\'") + return f"'{escaped}'" + + +# ============================================================================= +# JSON Schema Utilities +# ============================================================================= + + +def minimize_json_schema(schema: dict) -> dict[str, dict[str, str]]: + """ + Generate a mapping between field ids and names/types from a JSON schema. + Useful when generating formulas to understand the provided context. + + :param schema: JSON schema dict with properties and metadata. + :return: Mapping of field_key -> {id, name, type, desc, ...}. + """ + field_type_descriptions = { + "link_row": "the row ID as number or the primary field value as string", + "single_select": "the option ID as number or the value as string", + "multiple_select": "a comma separated list of option IDs or values as string", + "date": "a date string in ISO 8601 format", + "date_time": "a date-time string in ISO 8601 format", + "boolean": "true or false", + } + field_type_extra_info = { + "single_select": lambda meta: { + "select_options": meta.get("select_options", []) + }, + "multiple_select": lambda meta: { + "select_options": meta.get("select_options", []) + }, + "multiple_collaborators": lambda meta: { + "available_collaborators": meta.get("available_collaborators", []) + }, + } + + if schema.get("type") == "array": + return minimize_json_schema(schema.get("items")) + elif schema.get("type") != "object": + raise ValueError("Schema must be of type object or array of objects") + + properties = schema.get("properties", {}) + mapping = {} + for key, prop in properties.items(): + metadata = prop.get("metadata") + if metadata: + field_type = metadata["type"] + mapping[key] = { + "id": metadata["id"], + "name": metadata["name"], + "type": field_type, + "desc": field_type_descriptions.get(field_type, ""), + } + if field_type in field_type_extra_info: + get_extra_info = field_type_extra_info[field_type] + mapping[key].update(get_extra_info(metadata)) + return mapping + + +def create_example_from_json_schema(schema: dict) -> Any: + """ + Generate example data from a JSON schema. + Useful when generating formulas to provide example context data. + + :param schema: JSON schema dict. + :return: Example data matching the schema structure. + """ + examples = { + "string": "1", + "number": 1, + "boolean": True, + "null": None, + "object": lambda prop: create_example_from_json_schema(prop), + "array": lambda prop: [create_example_from_json_schema(prop["items"])], + } + + if schema.get("type") == "array": + return [create_example_from_json_schema(schema.get("items"))] + elif schema.get("type") != "object": + raise ValueError("Schema must be of type object or array of objects") + + properties = schema.get("properties", {}) + example = {} + for key, prop in properties.items(): + value = examples[prop.get("type")] + if callable(value): + example[key] = value(prop) + else: + example[key] = value + return example + + +# ============================================================================= +# Base Formula Context +# ============================================================================= + + +class BaseFormulaContext(FormulaContext, ABC): + """ + Base context for formula generation, shared between automation and builder. + + Subclasses must implement get_formula_context() and __getitem__ for + path resolution. + """ + + def __init__(self): + self.context: dict[str, Any] = {} + self.context_metadata: dict[str, Any] = {} + super().__init__() + + def add_context( + self, + key: str, + example_data: Any, + metadata: dict[str, Any] | None = None, + ): + """ + Add data to the formula context. + + :param key: Context key (e.g., "data_source.5" or "1" for node ID). + :param example_data: Example data for this context entry. + :param metadata: Optional metadata describing the structure. + """ + self.context[key] = example_data + if metadata: + self.context_metadata[key] = metadata + + @abstractmethod + def get_formula_context(self) -> dict[str, Any]: + """Return the context dict for formula generation.""" + pass + + def get_context_metadata(self) -> dict[str, Any]: + """Return metadata about the context.""" + return self.context_metadata + + def _resolve_path(self, key: str, root_key: str) -> Any: + """ + Resolve a dotted path through the context. + + :param key: Full path like "data_source.5.field_name". + :param root_key: Expected root key to validate against. + :return: The resolved value. + :raises KeyError: If path cannot be resolved. + :raises ValueError: If resolved value is not a primitive type. + """ + start, *key_parts = to_path(key) + if start != root_key: + raise KeyError( + f"Key '{key}' not found in context. " + f"Only '{root_key}' is supported at the root level." + ) + + value = self.context + for kp in key_parts: + try: + value = value[int(kp) if isinstance(value, list) else kp] + except (KeyError, TypeError, ValueError): + available_keys = ( + list(value.keys()) + if isinstance(value, dict) + else ", ".join(map(str, range(len(value)))) + ) + raise KeyError( + f"Key '{kp}' of '{key}' not found in {value}, " + f"Available keys: {available_keys}" + ) + + if not isinstance(value, (int, float, str, bool, date, datetime)): + raise ValueError( + f"Value for key '{key}' is not a valid type. " + f"Expected int, float, str, bool, date, or datetime. " + f"Got {type(value).__name__} instead. " + f"Make sure to only reference primitive types in the formula context." + ) + return value diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/toolset.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/toolset.py new file mode 100644 index 0000000000..f4e4497a4c --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/toolset.py @@ -0,0 +1,438 @@ +""" +Pydantic-ai toolset utilities for the assistant. + +Contains schema helpers (``inline_refs``), lenient argument validation, +the ``InlineRefsToolset`` wrapper, ``ModeAwareToolset``, and the compact +tool manifest builder. These are pure toolset concerns with no dependency +on the Baserow registry system. +""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any, Callable + +from loguru import logger +from pydantic import ValidationError +from pydantic_ai import Agent +from pydantic_ai.exceptions import ModelRetry +from pydantic_ai.toolsets import AbstractToolset +from pydantic_ai.toolsets.abstract import AgentDepsT, ToolsetTool +from typing_extensions import Self + +from baserow_enterprise.assistant.deps import AgentMode + +if TYPE_CHECKING: + from baserow_enterprise.assistant.deps import AssistantDeps + +# --------------------------------------------------------------------------- +# Schema utilities +# --------------------------------------------------------------------------- + +# Keys that are JSON Schema / Pydantic metadata the LLM doesn't need. +_STRIP_KEYS = frozenset({"$defs", "discriminator", "title"}) + + +def inline_refs(schema: dict) -> dict: + """ + Recursively resolve all ``$ref`` pointers in a JSON schema, producing a + self-contained schema with no ``$defs`` section. + + Also strips ``discriminator`` and ``title`` metadata that LLMs don't need + and that can contain dangling ``$defs`` references. + + Many LLM providers (especially open-weight models behind Groq) struggle + with ``$ref`` / ``$defs`` indirection. Inlining makes the schema + directly readable by the model. + """ + + defs = schema.get("$defs", {}) + _seen: set[str] = set() # guard against circular refs + + def _resolve(node, *, _inside_properties=False): + if isinstance(node, dict): + if "$ref" in node: + ref_name = node["$ref"].rsplit("/", 1)[-1] + if ref_name in _seen: + return {"type": "object"} # break circular ref + _seen.add(ref_name) + resolved = _resolve(defs[ref_name]) if ref_name in defs else node + _seen.discard(ref_name) + return resolved + result = {} + for k, v in node.items(): + # Strip JSON Schema metadata keys, but never strip property + # names inside a "properties" dict (e.g. a field literally + # named "title" or "description"). + if k in _STRIP_KEYS and not _inside_properties: + continue + result[k] = _resolve(v, _inside_properties=(k == "properties")) + return result + if isinstance(node, list): + return [_resolve(item) for item in node] + return node + + return _resolve(schema) + + +# --------------------------------------------------------------------------- +# Lenient validator & fixer +# --------------------------------------------------------------------------- + +_FIXER_PROMPT = """\ +You are a JSON repair tool. You receive a JSON object that failed schema \ +validation, the validation errors, and the target JSON schema. Return ONLY \ +the fixed JSON object — no explanation, no markdown fences. Preserve the \ +original values as much as possible; only change what is needed to satisfy \ +the schema.""" + + +class _LenientValidator: + """ + Drop-in replacement for pydantic-core ``SchemaValidator`` that parses + JSON without enforcing the tool's parameter schema. + + Real validation happens later in ``InlineRefsToolset.call_tool()``, + where we can attempt an async structured-output fix before failing. + """ + + def validate_json(self, input, *, allow_partial="off", **kwargs): + if isinstance(input, (str, bytes, bytearray)): + return json.loads(input) if input else {} + return input + + def validate_python(self, input, *, allow_partial="off", **kwargs): + return input if input is not None else {} + + +_LENIENT_VALIDATOR = _LenientValidator() + + +# --------------------------------------------------------------------------- +# InlineRefsToolset +# --------------------------------------------------------------------------- + + +class InlineRefsToolset(AbstractToolset[AgentDepsT]): + """ + Wraps another toolset with two responsibilities: + + 1. **Inline $ref/$defs** in tool parameter schemas so open-weight models + can parse them directly. + 2. **Fix broken tool args** via a lightweight structured-output call + instead of going through the full agent retry loop (which is slow + and rarely succeeds). + """ + + def __init__(self, inner: AbstractToolset[AgentDepsT], model: str): + self._inner = inner + self._model = model + self._original_validators: dict[str, Any] = {} + self._schemas: dict[str, dict] = {} + + @property + def id(self) -> str: + return self._inner.id + + # --- Delegation methods (match WrapperToolset pattern) --- + + async def __aenter__(self) -> Self: + await self._inner.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> bool | None: + return await self._inner.__aexit__(*args) + + def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: + self._inner.apply(visitor) + + def visit_and_replace( + self, + visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]], + ) -> AbstractToolset[AgentDepsT]: + new = InlineRefsToolset( + self._inner.visit_and_replace(visitor), model=self._model + ) + return new + + # --- Tool interception --- + + async def get_tools(self, ctx) -> dict[str, ToolsetTool[AgentDepsT]]: + tools = await self._inner.get_tools(ctx) + for name, tool in tools.items(): + # Inline $ref/$defs in the JSON schema + tool.tool_def.parameters_json_schema = inline_refs( + tool.tool_def.parameters_json_schema + ) + # Save the original validator and schema once, then replace with + # lenient passthrough so validation failures reach call_tool() + # where we can attempt an async fix. Guard against multiple calls + # so we don't overwrite the real validator with _LENIENT_VALIDATOR. + if name not in self._original_validators: + self._original_validators[name] = tool.args_validator + self._schemas[name] = tool.tool_def.parameters_json_schema + tool.args_validator = _LENIENT_VALIDATOR + return tools + + async def call_tool( + self, + name: str, + tool_args: dict[str, Any], + ctx: Any, + tool: ToolsetTool[AgentDepsT], + ) -> Any: + original_validator = self._original_validators.get(name) + if original_validator: + try: + tool_args = original_validator.validate_python(tool_args) + except ValidationError as e: + tool_args = await self._fix_tool_args(name, tool_args, e) + return await self._inner.call_tool(name, tool_args, ctx, tool) + + async def _fix_tool_args( + self, + tool_name: str, + wrong_args: dict[str, Any], + error: ValidationError, + ) -> dict[str, Any]: + """ + Attempt to fix invalid tool arguments via a lightweight structured- + output call. If the fix also fails validation, raises ``ModelRetry`` + so pydantic-ai can handle it normally. + """ + + schema = self._schemas.get(tool_name, {}) + error_details = error.errors(include_url=False, include_context=False) + + logger.warning( + "[assistant] Tool '{}' args failed validation, attempting fix. Errors: {}", + tool_name, + error_details, + ) + + prompt = ( + f"Tool: {tool_name}\n\n" + f"Schema:\n{json.dumps(schema, indent=2)}\n\n" + f"Invalid input:\n{json.dumps(wrong_args, indent=2)}\n\n" + f"Validation errors:\n{json.dumps(error_details, indent=2)}" + ) + + try: + fix_agent = Agent( + output_type=str, + instructions=_FIXER_PROMPT, + name="fix_agent", + ) + from baserow_enterprise.assistant.model_profiles import ( + UTILITY, + get_model_settings, + ) + + fixer_settings = get_model_settings(self._model, UTILITY) + result = await fix_agent.run( + prompt, + model=self._model, + model_settings={ + **fixer_settings, + "response_format": {"type": "json_object"}, + }, + ) + fixed_args = json.loads(result.output) + except Exception as exc: + logger.warning( + "[assistant] Fixer call failed for tool '{}': {}", + tool_name, + exc, + ) + raise ModelRetry( + f"Tool arguments invalid and fix attempt failed: {error_details}" + ) from exc + + # Re-validate with original schema + original_validator = self._original_validators[tool_name] + try: + validated = original_validator.validate_python(fixed_args) + except ValidationError as e2: + logger.warning( + "[assistant] Fixed args for tool '{}' still invalid: {}", + tool_name, + e2.errors(include_url=False, include_context=False), + ) + raise ModelRetry( + f"Tool arguments still invalid after fix attempt: " + f"{e2.errors(include_url=False, include_context=False)}" + ) from e2 + + return validated + + +# --------------------------------------------------------------------------- +# Mode-aware toolset +# --------------------------------------------------------------------------- + + +def _build_mode_tool_map() -> dict[AgentMode, frozenset[str]]: + """Build mode → tool-names mapping from actual function references. + + Derives names via ``f.__name__`` instead of hand-maintained string + lists to eliminate typo risk. + """ + + from .automation.tools import TOOL_FUNCTIONS as AUTO_FN + from .core.tools import create_builders, list_builders, switch_mode + from .database.tools import TOOL_FUNCTIONS as DB_FN + from .navigation.tools import navigate + from .search_user_docs.tools import search_user_docs + + try: + from .builder.tools import TOOL_FUNCTIONS as BUILDER_FN + except ImportError: + BUILDER_FN = [] + + n = frozenset # alias for readability + + def names(*funcs): + return n(f.__name__ for f in funcs) + + shared = names( + navigate, + switch_mode, + list_builders, + # Read-only database tools available in every mode + *[f for f in DB_FN if f.__name__.startswith(("list_", "get_"))], + ) + + return { + AgentMode.DATABASE: shared | names(*DB_FN, create_builders), + AgentMode.APPLICATION: shared | names(*BUILDER_FN, create_builders), + AgentMode.AUTOMATION: shared | names(*AUTO_FN, create_builders), + AgentMode.EXPLAIN: shared + | names( + *[f for f in BUILDER_FN if f.__name__.startswith("list_")], + *[f for f in AUTO_FN if f.__name__.startswith("list_")], + search_user_docs, + ), + } + + +_MODE_TOOL_MAP: dict[AgentMode, frozenset[str]] | None = None + + +def _get_mode_tool_map() -> dict[AgentMode, frozenset[str]]: + global _MODE_TOOL_MAP + if _MODE_TOOL_MAP is None: + _MODE_TOOL_MAP = _build_mode_tool_map() + return _MODE_TOOL_MAP + + +class ModeAwareToolset(AbstractToolset[AgentDepsT]): + """ + Filters the inner toolset based on the current :class:`AgentMode`. + + Each domain mode (DATABASE, APPLICATION, AUTOMATION) exposes only its + relevant tools plus shared read-only tools. EXPLAIN mode exposes + read-only tools plus ``search_user_docs``. + """ + + def __init__(self, inner: AbstractToolset[AgentDepsT], deps: "AssistantDeps"): + self._inner = inner + self._deps = deps + + @property + def id(self) -> str: + return self._inner.id + + async def __aenter__(self) -> Self: + await self._inner.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> bool | None: + return await self._inner.__aexit__(*args) + + def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: + self._inner.apply(visitor) + + def visit_and_replace( + self, + visitor: Callable[[AbstractToolset[AgentDepsT]], AbstractToolset[AgentDepsT]], + ) -> AbstractToolset[AgentDepsT]: + return ModeAwareToolset(self._inner.visit_and_replace(visitor), self._deps) + + async def get_tools(self, ctx) -> dict[str, ToolsetTool[AgentDepsT]]: + all_tools = await self._inner.get_tools(ctx) + allowed = _get_mode_tool_map()[self._deps.mode] + return {k: v for k, v in all_tools.items() if k in allowed} + + async def call_tool( + self, + name: str, + tool_args: dict[str, Any], + ctx: Any, + tool: ToolsetTool[AgentDepsT], + ) -> Any: + from baserow.core.exceptions import UserNotInWorkspace + from baserow_enterprise.assistant.tools.database.helpers import ToolInputError + + try: + return await self._inner.call_tool(name, tool_args, ctx, tool) + except ToolInputError as exc: + return {"error": str(exc)} + except UserNotInWorkspace: + return { + "error": ( + "One or more IDs reference a resource outside the current " + "workspace. Use the appropriate list_* tool to find " + "the correct IDs and retry." + ) + } + + +# --------------------------------------------------------------------------- +# Compact tool manifest +# --------------------------------------------------------------------------- + + +def tool_manifest_line_compact(name: str, description: str) -> str: + """Format a single tool entry — first line of description only.""" + + desc = description.strip() + first_line = desc.split("\n")[0].strip() if desc else name + return f"- {name}: {first_line}" + + +_MODULE_LABELS: dict[str, str] = { + "core": "Core (workspace & modules)", + "navigation": "Navigation", + "database": "Database (tables, fields, views, rows)", + "builder": "Application Builder (pages, elements, data sources, actions)", + "automation": "Automations (workflows, triggers, actions)", + "search_user_docs": "Documentation", +} + + +def generate_tool_manifest_compact( + module_groups: list[tuple[str, list[Callable]]], + routing_rules: str = "", +) -> str: + """ + Build a compact ```` manifest: routing rules + tools + grouped by module with section headers. + + :param module_groups: ``(module_type, funcs)`` pairs, one per module. + :param routing_rules: Cross-tool routing rules to prepend. + :return: A newline-separated manifest string. + """ + + lines: list[str] = [] + if routing_rules: + lines.append(routing_rules.strip()) + lines.append("") + for module_type, funcs in module_groups: + if not funcs: + continue + label = _MODULE_LABELS.get(module_type, module_type) + lines.append(f"## {label}") + for func in funcs: + lines.append(tool_manifest_line_compact(func.__name__, func.__doc__ or "")) + lines.append("") + return "\n".join(lines).rstrip() diff --git a/enterprise/backend/src/baserow_enterprise/assistant/types.py b/enterprise/backend/src/baserow_enterprise/assistant/types.py index 080dbee730..21be1de4bb 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/types.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/types.py @@ -4,7 +4,6 @@ from django.utils.translation import gettext as _ -import udspy from pydantic import BaseModel as PydanticBaseModel from pydantic import ConfigDict, Field @@ -12,6 +11,7 @@ class BaseModel(PydanticBaseModel): model_config = ConfigDict( extra="forbid", + coerce_numbers_to_str=True, ) @@ -157,7 +157,7 @@ class AiMessage(AiMessageChunk): ) -class AiThinkingMessage(BaseModel, udspy.StreamEvent): +class AiThinkingMessage(BaseModel): type: Literal["ai/thinking"] = AssistantMessageType.AI_THINKING.value content: str = Field( default="", @@ -236,16 +236,27 @@ def to_localized_string(self): return _("workflow %(workflow_name)s") % {"workflow_name": self.workflow_name} +class BuilderPageNavigationType(BaseModel): + type: Literal["builder-page"] + application_id: int + page_id: int + page_name: str + + def to_localized_string(self): + return _("page %(page_name)s") % {"page_name": self.page_name} + + AnyNavigationType = Annotated[ TableNavigationType | WorkspaceNavigationType | ViewNavigationType - | WorkflowNavigationType, + | WorkflowNavigationType + | BuilderPageNavigationType, Field(discriminator="type"), ] -class AiNavigationMessage(BaseModel, udspy.StreamEvent): +class AiNavigationMessage(BaseModel): type: Literal["ai/navigation"] = "ai/navigation" location: AnyNavigationType diff --git a/enterprise/backend/src/baserow_enterprise/config/settings/settings.py b/enterprise/backend/src/baserow_enterprise/config/settings/settings.py index d3e8852e11..595288e713 100644 --- a/enterprise/backend/src/baserow_enterprise/config/settings/settings.py +++ b/enterprise/backend/src/baserow_enterprise/config/settings/settings.py @@ -79,6 +79,35 @@ def setup(settings): settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL = os.getenv( "BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL", "" ) - settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE = float( - os.getenv("BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE", "") or 0.3 + _temp_raw = os.getenv("BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE", "") + settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE = ( + float(_temp_raw) if _temp_raw else None ) + + # Backward compatibility: bridge old UDSPY_LM_* env vars so existing + # deployments continue to work without config changes. + _udspy_model = os.getenv("UDSPY_LM_MODEL", "") + if _udspy_model and not settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL: + settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL = _udspy_model + + _udspy_api_key = os.getenv("UDSPY_LM_API_KEY", "") + if _udspy_api_key: + # pydantic-ai reads provider-specific env vars. Set them all as + # fallbacks so the old catch-all key works regardless of provider. + for _key in ( + "OPENAI_API_KEY", + "GROQ_API_KEY", + "ANTHROPIC_API_KEY", + "GEMINI_API_KEY", + ): + os.environ.setdefault(_key, _udspy_api_key) + + _udspy_base_url = os.getenv("UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL", "") + if _udspy_base_url: + # pydantic-ai's OpenAI provider reads OPENAI_BASE_URL. + os.environ.setdefault("OPENAI_BASE_URL", _udspy_base_url) + + # Bridge old AWS_REGION_NAME to boto3's standard AWS_DEFAULT_REGION. + _aws_region = os.getenv("AWS_REGION_NAME", "") + if _aws_region: + os.environ.setdefault("AWS_DEFAULT_REGION", _aws_region) diff --git a/enterprise/backend/src/baserow_enterprise/migrations/0058_assistantchat_message_history.py b/enterprise/backend/src/baserow_enterprise/migrations/0058_assistantchat_message_history.py new file mode 100644 index 0000000000..1ebf92109c --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/migrations/0058_assistantchat_message_history.py @@ -0,0 +1,24 @@ +# Generated by Django 5.0.13 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("baserow_enterprise", "0057_role_hidden"), + ] + + operations = [ + migrations.AddField( + model_name="assistantchat", + name="message_history", + field=models.BinaryField( + blank=True, + help_text=( + "Serialized pydantic-ai message history (JSON bytes) for " + "multi-turn conversation context." + ), + null=True, + ), + ), + ] diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/__init__.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/__init__.py @@ -0,0 +1 @@ + diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/conftest.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/conftest.py new file mode 100644 index 0000000000..ea0299c75a --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/conftest.py @@ -0,0 +1,153 @@ +import asyncio +import logging +import os +import sys + +from django.conf import settings + +import pytest +from loguru import logger + +from baserow.config.settings.test import TEST_ENV_VARS + +# Suppress DEBUG-level loguru output during evals. Baserow's cache layer logs +# every cache hit/miss at DEBUG, which floods the output when using -s. Agent +# message history is printed via print() and is captured by pytest: it appears +# in the failure report automatically without needing -s. +logger.remove() +logger.add(sys.stderr, level="WARNING") + +# Expose API keys from TEST_ENV_FILE to os.environ so that LLM provider +# SDKs (which read os.getenv() at import/construction time) can find them. +# test.py already parses TEST_ENV_FILE via dotenv_values but deliberately +# does NOT inject non-allowlisted keys into os.environ. We bridge that +# gap here for the small set of keys the eval suite needs. +_API_KEY_NAMES = ("GROQ_API_KEY", "OPENAI_API_KEY", "ANTHROPIC_API_KEY") +for _k in _API_KEY_NAMES: + if (_v := TEST_ENV_VARS.get(_k)) and not os.environ.get(_k): + os.environ[_k] = _v + + +_EVALS_DIR = os.path.dirname(__file__) + + +def _evals_explicitly_requested(config): + """Return True when the user intentionally targeted eval tests.""" + + # ``-m eval`` on the command line + marker_expr = config.getoption("-m", default="") + if "eval" in marker_expr: + return True + + # User pointed pytest at an eval file/directory (e.g. VSCode test runner) + for arg in config.args: + if os.path.abspath(arg).startswith(_EVALS_DIR): + return True + + return False + + +def pytest_collection_modifyitems(config, items): + """Skip eval tests unless explicitly requested (``-m eval`` or by path). + + Also wires up ``EVAL_RETRIES``: when set to a positive integer, every eval + test is automatically marked with ``pytest.mark.retry(N)`` so that failing + tests are re-run up to N times. A test that passes on retry is a flake + (LLM non-determinism); one that fails all N retries is a consistent bug. + """ + + if not _evals_explicitly_requested(config): + skip_eval = pytest.mark.skip(reason="eval tests only run with -m eval") + for item in items: + if item.get_closest_marker("eval"): + item.add_marker(skip_eval) + return + + eval_retries = int(os.environ.get("EVAL_RETRIES", "0")) + if eval_retries > 0: + for item in items: + if item.get_closest_marker("eval"): + item.add_marker(pytest.mark.retry(eval_retries)) + + +def pytest_generate_tests(metafunc): + """Auto-parametrize tests that use the ``eval_model`` fixture.""" + + if "eval_model" in metafunc.fixturenames: + from .eval_utils import get_eval_model + + model_str = get_eval_model() + models = [m.strip() for m in model_str.split(",") if m.strip()] + metafunc.parametrize("eval_model", models, scope="session") + + +@pytest.fixture(scope="session") +def synced_knowledge_base(django_db_blocker): + """ + Sync the knowledge base once per pytest session if not already populated. + + With ``--reuse-db`` the DB persists across sessions, so the (slow) + embedding + sync step only runs the very first time. Subsequent + sessions detect that the KB is already populated and return immediately. + """ + + with django_db_blocker.unblock(): + if not getattr(settings, "BASEROW_EMBEDDINGS_API_URL", ""): + return # No embeddings server → nothing to sync + + from baserow_enterprise.assistant.tools.search_user_docs.handler import ( + KnowledgeBaseHandler, + ) + + handler = KnowledgeBaseHandler() + + if handler.can_search(): + return # Already populated (e.g. --reuse-db from a previous run) + + if not handler.can_have_knowledge_base(): + return # pgvector not available + + print("\n[eval] Syncing knowledge base (first run — this may take a while)...") + handler.sync_knowledge_base() + print("[eval] Knowledge base sync complete.") + + +@pytest.fixture(autouse=True) +def suppress_asyncio_stopiteration_error(): + """ + Suppress the 'StopIteration interacts badly with generators' asyncio error. + + This is a known Python issue when generators raise StopIteration in contexts + where asyncio futures are involved. The error is harmless but noisy. + """ + original_handler = None + + def custom_exception_handler(loop, context): + exception = context.get("exception") + if isinstance(exception, TypeError) and "StopIteration" in str(exception): + return # Suppress this specific error + if original_handler: + original_handler(loop, context) + else: + loop.default_exception_handler(context) + + try: + loop = asyncio.get_event_loop() + original_handler = loop.get_exception_handler() + loop.set_exception_handler(custom_exception_handler) + except RuntimeError: + pass # No event loop + + # Also suppress the log message + asyncio_logger = logging.getLogger("asyncio") + original_level = asyncio_logger.level + asyncio_logger.setLevel(logging.CRITICAL) + + yield + + asyncio_logger.setLevel(original_level) + try: + loop = asyncio.get_event_loop() + loop.set_exception_handler(original_handler) + except RuntimeError: + pass diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/eval_utils.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/eval_utils.py new file mode 100644 index 0000000000..e4d9963fca --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/eval_utils.py @@ -0,0 +1,372 @@ +""" +Shared utilities for assistant evals (single-agent architecture). + +These utilities are used by multiple eval test files and provide: +- LLM configuration +- UIContext building +- Callback tracking for assertions +- Assistant creation helpers +- Message history formatting for inspection +""" + +import json +import os + +from pydantic_ai.usage import UsageLimits + +from baserow_enterprise.assistant.agents import main_agent +from baserow_enterprise.assistant.deps import AssistantDeps, ToolHelpers +from baserow_enterprise.assistant.tools.registries import assistant_tool_registry +from baserow_enterprise.assistant.types import ( + ApplicationUIContext, + TableUIContext, + UIContext, + UserUIContext, + WorkspaceUIContext, +) + +# Default model for evals - can be overridden via EVAL_LLM_MODEL env var +DEFAULT_EVAL_MODEL = "groq:openai/gpt-oss-120b" + + +def build_database_ui_context(user, workspace, database=None, table=None) -> str: + """ + Build a UIContext for a database, formatted as JSON string. + + This tells the agent which workspace/database/table the user is viewing. + """ + ctx = UIContext( + workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), + database=ApplicationUIContext(id=str(database.id), name=database.name) + if database + else None, + table=TableUIContext(id=table.id, name=table.name) if table else None, + user=UserUIContext(id=user.id, name=user.first_name, email=user.email), + ) + return ctx.format() + + +def format_message_history(result) -> list[dict]: + """ + Format the full message history from an agent run for inspection. + + Returns a list of dicts with structured info about each message: + - role: system/user/assistant/tool + - type: the pydantic-ai message class name + - content: text content (if any) + - tool_calls: list of tool call info (if any) + - tool_name: name of tool that returned this result (for tool results) + - timestamp: message timestamp (if available) + """ + from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + ) + + messages = getattr(result, "all_messages", lambda: [])() or [] + formatted = [] + + for msg in messages: + if isinstance(msg, ModelRequest): + for part in msg.parts: + part_type = type(part).__name__ + entry = {"role": "user", "type": part_type} + + if hasattr(part, "content"): + entry["content"] = part.content + if hasattr(part, "tool_name"): + entry["tool_name"] = part.tool_name + if hasattr(part, "tool_call_id"): + entry["tool_call_id"] = part.tool_call_id + if hasattr(part, "timestamp"): + entry["timestamp"] = str(part.timestamp) + + formatted.append(entry) + + elif isinstance(msg, ModelResponse): + for part in msg.parts: + part_type = type(part).__name__ + entry = {"role": "assistant", "type": part_type} + + if hasattr(part, "content"): + entry["content"] = part.content + if hasattr(part, "tool_name"): + entry["tool_name"] = part.tool_name + if hasattr(part, "tool_call_id"): + entry["tool_call_id"] = part.tool_call_id + if hasattr(part, "args"): + # Tool call arguments + args = part.args + if isinstance(args, str): + try: + args = json.loads(args) + except (json.JSONDecodeError, TypeError): + pass + entry["args"] = args + + formatted.append(entry) + + return formatted + + +def print_message_history(result, max_content_len=1000): + """ + Print a human-readable summary of the full message history. + + Shows all LLM requests, responses, tool calls, and tool results + in chronological order. + """ + history = format_message_history(result) + + print("\n" + "=" * 80) + print("MESSAGE HISTORY") + print("=" * 80) + + for i, entry in enumerate(history): + role = entry["role"].upper() + msg_type = entry.get("type", "unknown") + print(f"\n--- [{i + 1}] {role} ({msg_type}) ---") + + if "content" in entry: + content = str(entry["content"]) + if len(content) > max_content_len: + content = content[:max_content_len] + "..." + print(f" Content: {content}") + + if "tool_name" in entry: + print(f" Tool: {entry['tool_name']}") + + if "args" in entry: + args_str = json.dumps(entry["args"], indent=2, default=str) + if len(args_str) > max_content_len: + args_str = args_str[:max_content_len] + "..." + print(f" Args: {args_str}") + + if "tool_call_id" in entry: + print(f" Call ID: {entry['tool_call_id']}") + + print("\n" + "=" * 80) + print(f"Total entries: {len(history)}") + print("=" * 80 + "\n") + + +def print_trajectory(result, max_obs_len=500): + """Debug helper to print the agent's trajectory.""" + print("\n=== TRAJECTORY ===") + # pydantic-ai stores messages differently + for i, msg in enumerate(getattr(result, "all_messages", lambda: [])() or []): + print(f"\n--- Message {i + 1} ---") + print(f" {type(msg).__name__}: {str(msg)[:max_obs_len]}") + print("\n=== END TRAJECTORY ===\n") + + +def get_eval_model() -> str: + """ + Get the model string for evals. + + Configure via EVAL_LLM_MODEL environment variable. + API keys should be set via standard env vars (OPENAI_API_KEY, GROQ_API_KEY). + """ + return os.environ.get("EVAL_LLM_MODEL", DEFAULT_EVAL_MODEL) + + +class EvalToolTracker: + """ + Placeholder for future tool-call instrumentation. + + Currently eval assertions rely on inspecting the pydantic-ai message + history (``RetryPromptPart`` entries) rather than wrapping individual + tools, so this class is intentionally minimal. + """ + + def __init__(self, verbose: bool = True): + self.verbose = verbose + + +def create_eval_assistant(user, workspace, max_iters=15, model=None): + """ + Create an assistant configured like production for evals. + + Returns (agent, deps, tracker, model, usage_limits, toolset) so tests + can run the agent. Uses the single-agent architecture with the full + monolithic toolset from build_assistant_toolset(). + + :param model: Override the LLM model string. Falls back to + ``get_eval_model()`` (i.e. the ``EVAL_LLM_MODEL`` env var). + """ + from django.conf import settings + + tool_helpers = ToolHelpers(lambda x: None, lambda x: None) + tracker = EvalToolTracker() + model = model or get_eval_model() + + # Ensure sub-agents (e.g. formula_agent) also use the eval model. + # get_model_string() does .replace("/", ":", 1) on the setting value, + # so store in "/" format (e.g. "groq/openai/gpt-oss-120b"). + settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL = model.replace(":", "/", 1) + + deps = AssistantDeps( + user=user, + workspace=workspace, + tool_helpers=tool_helpers, + ) + + # Build the single-agent toolset (navigation + core + database + automation) + toolset, db_manifest, app_manifest, auto_manifest, explain_manifest = ( + assistant_tool_registry.build_toolset(user, workspace, model, deps) + ) + deps.database_manifest = db_manifest + deps.application_manifest = app_manifest + deps.automation_manifest = auto_manifest + deps.explain_manifest = explain_manifest + usage_limits = UsageLimits(request_limit=max_iters) + + return main_agent, deps, tracker, model, usage_limits, toolset + + +def get_tool_call_sequence(result) -> list[str]: + """ + Return the ordered list of tool names called during an agent run. + + Extracts assistant-side tool call entries from the message history, + preserving chronological order. + """ + + history = format_message_history(result) + return [ + e["tool_name"] + for e in history + if e["role"] == "assistant" and "tool_name" in e and "args" in e + ] + + +def assert_tool_call_order(result, expected_order: list[str]): + """ + Assert that tools were called in the expected relative order. + + For each consecutive pair (A, B) in *expected_order*, verifies that the + **last** call to A comes before the **first** call to B. This guarantees + that all A work is fully completed before any B work begins. + + Example:: + + assert_tool_call_order(result, [ + "create_pages", + "create_layout_elements", + "create_display_elements", + ]) + """ + + sequence = get_tool_call_sequence(result) + + def _all_indices(tool_name: str) -> list[int]: + indices = [i for i, name in enumerate(sequence) if name == tool_name] + if not indices: + raise AssertionError( + f"Expected tool '{tool_name}' was never called. " + f"Actual sequence: {sequence}" + ) + return indices + + for i in range(len(expected_order) - 1): + name_a = expected_order[i] + name_b = expected_order[i + 1] + last_a = _all_indices(name_a)[-1] + first_b = _all_indices(name_b)[0] + assert last_a < first_b, ( + f"Expected all '{name_a}' calls to finish before any '{name_b}' call, " + f"but last '{name_a}' at pos {last_a} >= first '{name_b}' at pos {first_b}. " + f"Actual sequence: {sequence}" + ) + + +class EvalChecklist: + """ + Soft-assertion context manager for eval tests. + + Collects labelled checks without raising immediately. On exit it prints a + score table (visible with ``-s``) and raises a single AssertionError that + lists every failed check. This lets you see "4/6 (66%)" instead of the + binary "FAIL at first assertion" behaviour of plain ``assert``. + + Usage:: + + with EvalChecklist("creates Bookstore database") as checks: + checks.check("Books table exists", any("book" in n for n in names)) + checks.check("Authors table exists", any("author" in n for n in names), + hint=f"got: {names}") + """ + + def __init__(self, name: str): + self.name = name + self._checks: list[tuple[str, bool, str]] = [] + + def check(self, label: str, condition: bool, hint: str = "") -> bool: + """Record a soft check. Returns the condition value for further use.""" + self._checks.append((label, bool(condition), hint)) + return bool(condition) + + @property + def score(self) -> tuple[int, int]: + passed = sum(1 for _, ok, _ in self._checks if ok) + return passed, len(self._checks) + + def assert_all(self): + passed, total = self.score + pct = 100 * passed // total if total else 0 + lines = [ + f" {'✓' if ok else '✗'} {label}" + + (f" ({hint})" if not ok and hint else "") + for label, ok, hint in self._checks + ] + summary = ( + f"\nEVAL SCORE [{self.name}]: {passed}/{total} ({pct}%)\n" + + "\n".join(lines) + ) + print(summary) + failed = [label for label, ok, _ in self._checks if not ok] + assert not failed, summary + + def __enter__(self): + return self + + def __exit__(self, exc_type, *_): + if exc_type is None: + self.assert_all() + return False + + +def count_tool_errors(result) -> tuple[int, str]: + """ + Count tool validation errors in the agent result. + + Inspects the pydantic-ai message history for ``RetryPromptPart`` entries, + which indicate the LLM sent invalid arguments that failed pydantic + validation. "Unknown tool name" retries are excluded — the LLM explored a + non-existent tool and recovered on its own, which is acceptable. + + Returns ``(error_count, hint)`` suitable for use with + :meth:`EvalChecklist.check`. + """ + from pydantic_ai.messages import ModelRequest, RetryPromptPart + + if result is None: + return 0, "" + + messages = getattr(result, "all_messages", lambda: [])() or [] + retry_errors = [] + for msg in messages: + if isinstance(msg, ModelRequest): + for part in msg.parts: + if isinstance(part, RetryPromptPart): + content = str(part.content) + if "Unknown tool name" in content: + continue + retry_errors.append( + { + "tool_name": getattr(part, "tool_name", None), + "content": content, + } + ) + hint = "\n".join(f" - {e['tool_name']}: {e['content']}" for e in retry_errors) + return len(retry_errors), hint diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_automation_workflows.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_automation_workflows.py new file mode 100644 index 0000000000..6982c1a284 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_automation_workflows.py @@ -0,0 +1,845 @@ +import pytest + +from baserow.contrib.automation.workflows.models import AutomationWorkflow + +from .eval_utils import ( + EvalChecklist, + build_database_ui_context, + count_tool_errors, + create_eval_assistant, + format_message_history, + print_message_history, +) + +# --------------------------------------------------------------------------- +# Eval prompts — one per test, easy to scan for coverage +# --------------------------------------------------------------------------- + +PROMPT_LISTS_WORKFLOWS = "List the workflows in automation ID {automation_id}" + +PROMPT_CREATES_WORKFLOW = ( + "Create a workflow in automation {automation_name} that " + "triggers when a row is created in table '{table_name}', " + "and updates the Status field to 'Processing'." +) + +PROMPT_CREATES_WEEKLY_SLACK_REMINDER = ( + "In automation '{automation_name}', create a workflow that sends a " + "Slack message to #general every Tuesday at 9am UTC asking " + "'Is there anything to demo this week?'" +) + +PROMPT_CREATES_ROUTER_WORKFLOW = ( + "In automation '{automation_name}', create a workflow that " + "triggers when a row is created in table '{table_name}'. " + "Add a router: if Priority is 'High', send a Slack message to " + "#urgent saying 'High priority ticket created'. " + "If Priority is 'Low', do nothing (just the router branch is fine)." +) + +PROMPT_CREATES_ROW_WITH_FIELD_VALUES = ( + "In automation '{automation_name}', create a workflow that " + "triggers when a row is created in '{source_table_name}'. " + "Then create a row in '{log_table_name}' with Entry set to " + "the new contact's Name and Source set to 'automation'." +) + +PROMPT_CREATES_UPDATE_ROW_WORKFLOW = ( + "In automation '{automation_name}', create a workflow that " + "triggers when a row is updated in '{table_name}'. " + "Then update the same row: set Status to 'Reviewed' and " + "Notes to 'Automatically reviewed by automation'." +) + +PROMPT_CREATES_EMAIL_NOTIFICATION_WORKFLOW = ( + "In automation '{automation_name}', create a workflow that " + "triggers when a row is created in '{table_name}'. " + "Send an email to admin@example.com with subject 'New Order' " + "and body 'A new order has been placed'." +) + + +def _run_agent( + agent, deps, tracker, model, usage_limits, toolset, question, ui_context +): + deps.tool_helpers.request_context["ui_context"] = ui_context + return agent.run_sync( + user_prompt=question, + deps=deps, + model=model, + usage_limits=usage_limits, + toolsets=[toolset], + ) + + +def _get_create_workflows_args(result) -> list[dict]: + """Return the parsed ``args`` dicts of every ``create_workflows`` tool call + the agent made (assistant-side entries have ``args``).""" + + history = format_message_history(result) + return [ + e["args"] + for e in history + if e["role"] == "assistant" + and e.get("tool_name") == "create_workflows" + and "args" in e + ] + + +def _get_workflow_nodes(automation): + """Return (workflow, trigger, action_nodes) for the first workflow.""" + + workflow = AutomationWorkflow.objects.filter(automation=automation).first() + assert workflow is not None, "No workflow was created" + trigger = workflow.get_trigger() + action_nodes = list( + workflow.automation_workflow_nodes.exclude(id=trigger.id).order_by("id") + ) + return workflow, trigger, action_nodes + + +# --------------------------------------------------------------------------- +# Existing evals +# --------------------------------------------------------------------------- + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_lists_workflows(data_fixture, eval_model): + """Agent should call list_workflows when asked about automation workflows.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + automation = data_fixture.create_automation_application( + workspace=workspace, name="My Automation" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=10, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_LISTS_WORKFLOWS.format(automation_id=automation.id), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + history = format_message_history(result) + tool_calls = [ + e + for e in history + if e.get("tool_name") == "list_workflows" and e["role"] == "user" + ] + + with EvalChecklist("lists workflows") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "called list_workflows", + len(tool_calls) >= 1, + hint=f"tools called: {[e.get('tool_name') for e in history if e.get('tool_name')]}", + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_workflow(data_fixture, eval_model): + """Agent should create a workflow when asked to automate a process.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Orders") + data_fixture.create_text_field(table=table, name="Order ID", primary=True) + data_fixture.create_text_field(table=table, name="Status") + + automation = data_fixture.create_automation_application( + workspace=workspace, name="Order Processing" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_WORKFLOW.format( + automation_name=automation.name, table_name=table.name + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + history = format_message_history(result) + tool_calls = [ + e + for e in history + if e.get("tool_name") == "create_workflows" and e["role"] == "user" + ] + workflows = AutomationWorkflow.objects.filter(automation=automation) + + call_args_list = _get_create_workflows_args(result) + args = call_args_list[0] if call_args_list else {} + wf_args = args.get("workflows", [{}])[0] if args.get("workflows") else {} + trigger_args = wf_args.get("trigger", {}) + nodes_args = wf_args.get("nodes", []) + trigger_table_id = trigger_args.get("rows_triggers_settings", {}).get("table_id") + update_nodes_args = [n for n in nodes_args if n.get("type") == "update_row"] + ur_values = update_nodes_args[0].get("values", []) if update_nodes_args else [] + ur_has_processing = any( + "processing" in str(v.get("value", "")).lower() for v in ur_values + ) + + db_ok = workflows.exists() + if db_ok: + workflow, trigger_node, action_nodes = _get_workflow_nodes(automation) + db_trigger_type = trigger_node.service.get_type().type + db_update_actions = [ + n + for n in action_nodes + if n.service.get_type().type == "local_baserow_upsert_row" + ] + else: + db_trigger_type = None + db_update_actions = [] + + with EvalChecklist("creates workflow") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "called create_workflows", + len(tool_calls) >= 1, + hint=f"tools called: {[e.get('tool_name') for e in history if e.get('tool_name')]}", + ) + checks.check( + "workflow created in DB", + db_ok, + ) + checks.check( + "trigger is rows_created", + trigger_args.get("type") == "rows_created", + hint=f"got {trigger_args.get('type')}", + ) + checks.check( + "trigger table is Orders", + trigger_table_id == table.id, + hint=f"got table_id={trigger_table_id}, expected={table.id}", + ) + checks.check( + "update_row node in args", + len(update_nodes_args) >= 1, + hint=f"node types: {[n.get('type') for n in nodes_args]}", + ) + checks.check( + "update_row sets field to 'Processing'", + ur_has_processing, + hint=f"values: {ur_values}", + ) + checks.check( + "DB trigger is rows_created", + db_trigger_type == "local_baserow_rows_created", + hint=f"got {db_trigger_type}", + ) + checks.check( + "update_row action in DB", + len(db_update_actions) >= 1, + ) + + +# --------------------------------------------------------------------------- +# Periodic trigger + Slack message +# --------------------------------------------------------------------------- + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_weekly_slack_reminder(data_fixture, eval_model): + """Agent should create a periodic-WEEK trigger firing on Tuesday with a + Slack message node asking about demos.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + automation = data_fixture.create_automation_application( + workspace=workspace, name="Team Reminders" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_WEEKLY_SLACK_REMINDER.format( + automation_name=automation.name + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + call_args_list = _get_create_workflows_args(result) + args = call_args_list[0] if call_args_list else {} + wf_args = args.get("workflows", [{}])[0] if args.get("workflows") else {} + trigger_args = wf_args.get("trigger", {}) + interval_args = trigger_args.get("periodic_interval", {}) + nodes_args = wf_args.get("nodes", []) + slack_nodes_args = [n for n in nodes_args if n.get("type") == "slack_write_message"] + + db_ok = AutomationWorkflow.objects.filter(automation=automation).exists() + if db_ok: + workflow, trigger_node, action_nodes = _get_workflow_nodes(automation) + db_trigger_type = trigger_node.get_type().type + db_slack_actions = [ + n + for n in action_nodes + if n.service.get_type().type == "slack_write_message" + ] + else: + db_trigger_type = None + db_slack_actions = [] + + slack_node = slack_nodes_args[0] if slack_nodes_args else {} + slack_channel = slack_node.get("channel", "") + slack_text = slack_node.get("text", "") + + with EvalChecklist("creates weekly Slack reminder") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "called create_workflows", + len(call_args_list) >= 1, + ) + checks.check( + "trigger type is periodic", + trigger_args.get("type") == "periodic", + hint=f"got {trigger_args.get('type')}", + ) + checks.check( + "interval is WEEK", + interval_args.get("interval") == "WEEK", + hint=f"got {interval_args.get('interval')}", + ) + checks.check( + "day_of_week is 1 (Tuesday)", + interval_args.get("day_of_week") == 1, + hint=f"got {interval_args.get('day_of_week')}", + ) + checks.check( + "slack_write_message node in args", + len(slack_nodes_args) >= 1, + hint=f"node types: {[n.get('type') for n in nodes_args]}", + ) + checks.check( + "workflow created in DB with periodic trigger", + db_trigger_type == "periodic", + hint=f"got {db_trigger_type}", + ) + checks.check( + "Slack action exists in DB", + len(db_slack_actions) >= 1, + ) + checks.check( + "Slack channel is #general", + "general" in slack_channel.lower(), + hint=f"got channel: '{slack_channel}'", + ) + checks.check( + "Slack message mentions demo", + "demo" in slack_text.lower(), + hint=f"got text: '{slack_text}'", + ) + + +# --------------------------------------------------------------------------- +# Router node +# --------------------------------------------------------------------------- + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_router_workflow(data_fixture, eval_model): + """Agent should create a workflow with a router node that branches + based on a condition.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Tickets") + data_fixture.create_text_field(table=table, name="Title", primary=True) + priority_field = data_fixture.create_single_select_field( + table=table, name="Priority" + ) + data_fixture.create_select_option(field=priority_field, value="High", order=0) + data_fixture.create_select_option(field=priority_field, value="Low", order=1) + + automation = data_fixture.create_automation_application( + workspace=workspace, name="Ticket Router" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_ROUTER_WORKFLOW.format( + automation_name=automation.name, table_name=table.name + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + call_args_list = _get_create_workflows_args(result) + args = call_args_list[0] if call_args_list else {} + wf_args = args.get("workflows", [{}])[0] if args.get("workflows") else {} + nodes_args = wf_args.get("nodes", []) + router_nodes_args = [n for n in nodes_args if n.get("type") == "router"] + router_edges_args = ( + router_nodes_args[0].get("edges", []) if router_nodes_args else [] + ) + + db_ok = AutomationWorkflow.objects.filter(automation=automation).exists() + if db_ok: + workflow, trigger_node, action_nodes = _get_workflow_nodes(automation) + db_router_actions = [ + n for n in action_nodes if n.service.get_type().type == "router" + ] + db_edges_count = ( + db_router_actions[0].service.specific.edges.count() + if db_router_actions + else 0 + ) + else: + db_router_actions = [] + db_edges_count = 0 + + trigger_args = wf_args.get("trigger", {}) + trigger_table_id = trigger_args.get("rows_triggers_settings", {}).get("table_id") + slack_nodes_in_nodes = [ + n for n in nodes_args if n.get("type") == "slack_write_message" + ] + slack_channel = ( + slack_nodes_in_nodes[0].get("channel", "") if slack_nodes_in_nodes else "" + ) + + with EvalChecklist("creates router workflow") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check("called create_workflows", len(call_args_list) >= 1) + checks.check( + "trigger is rows_created", + trigger_args.get("type") == "rows_created", + hint=f"got {trigger_args.get('type')}", + ) + checks.check( + "trigger table is Tickets", + trigger_table_id == table.id, + hint=f"got table_id={trigger_table_id}, expected={table.id}", + ) + checks.check( + "router node in args", + len(router_nodes_args) >= 1, + hint=f"node types: {[n.get('type') for n in nodes_args]}", + ) + checks.check( + "router has >=2 edges in args", + len(router_edges_args) >= 2, + hint=f"got {len(router_edges_args)}", + ) + checks.check( + "router node in DB", + len(db_router_actions) >= 1, + ) + checks.check( + "router has >=2 edges in DB", + db_edges_count >= 2, + hint=f"got {db_edges_count}", + ) + checks.check( + "Slack node exists for High branch", + len(slack_nodes_in_nodes) >= 1, + hint=f"node types: {[n.get('type') for n in nodes_args]}", + ) + checks.check( + "Slack channel is #urgent", + "urgent" in slack_channel.lower(), + hint=f"got channel: '{slack_channel}'", + ) + + +# --------------------------------------------------------------------------- +# Create-row / update-row with field value formulas +# --------------------------------------------------------------------------- + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_row_with_field_values(data_fixture, eval_model): + """Agent should create a workflow with a create_row node that maps + specific field values (including formula-style references).""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + + source_table = data_fixture.create_database_table( + database=database, name="Contacts" + ) + data_fixture.create_text_field(table=source_table, name="Name", primary=True) + data_fixture.create_email_field(table=source_table, name="Email") + + log_table = data_fixture.create_database_table(database=database, name="Log") + data_fixture.create_text_field(table=log_table, name="Entry", primary=True) + data_fixture.create_text_field(table=log_table, name="Source") + + automation = data_fixture.create_automation_application( + workspace=workspace, name="Contact Logger" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, source_table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_ROW_WITH_FIELD_VALUES.format( + automation_name=automation.name, + source_table_name=source_table.name, + log_table_name=log_table.name, + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + call_args_list = _get_create_workflows_args(result) + args = call_args_list[0] if call_args_list else {} + wf_args = args.get("workflows", [{}])[0] if args.get("workflows") else {} + trigger_args = wf_args.get("trigger", {}) + nodes_args = wf_args.get("nodes", []) + create_row_nodes_args = [n for n in nodes_args if n.get("type") == "create_row"] + cr_values = ( + create_row_nodes_args[0].get("values", []) if create_row_nodes_args else [] + ) + + db_ok = AutomationWorkflow.objects.filter(automation=automation).exists() + if db_ok: + workflow, trigger_node, action_nodes = _get_workflow_nodes(automation) + db_trigger_type = trigger_node.service.get_type().type + db_create_actions = [ + n + for n in action_nodes + if n.service.get_type().type == "local_baserow_upsert_row" + ] + else: + db_trigger_type = None + db_create_actions = [] + + trigger_table_id = trigger_args.get("rows_triggers_settings", {}).get("table_id") + cr_node = create_row_nodes_args[0] if create_row_nodes_args else {} + cr_table_id = cr_node.get("table_id") + cr_has_literal_automation = any( + "automation" in str(v.get("value", "")).lower() for v in cr_values + ) + + with EvalChecklist("creates row with field values") as checks: + checks.check("<=1 tool errors", err_count <= 1, hint=err_hint) + checks.check("called create_workflows", len(call_args_list) >= 1) + checks.check( + "trigger is rows_created", + trigger_args.get("type") == "rows_created", + hint=f"got {trigger_args.get('type')}", + ) + checks.check( + "trigger table is Contacts (source_table)", + trigger_table_id == source_table.id, + hint=f"got table_id={trigger_table_id}, expected={source_table.id}", + ) + checks.check( + "create_row node in args", + len(create_row_nodes_args) >= 1, + hint=f"node types: {[n.get('type') for n in nodes_args]}", + ) + checks.check( + "create_row targets Log table", + cr_table_id == log_table.id, + hint=f"got table_id={cr_table_id}, expected={log_table.id}", + ) + checks.check( + "create_row has >=1 field value", + len(cr_values) >= 1, + hint=f"got {len(cr_values)}", + ) + checks.check( + "create_row has 'automation' literal value (Source field)", + cr_has_literal_automation, + hint=f"values: {cr_values}", + ) + checks.check( + "DB trigger is rows_created", + db_trigger_type == "local_baserow_rows_created", + hint=f"got {db_trigger_type}", + ) + checks.check( + "create_row action in DB", + len(db_create_actions) >= 1, + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_update_row_workflow(data_fixture, eval_model): + """Agent should create a workflow with an update_row node that references + field values from the trigger.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + + table = data_fixture.create_database_table(database=database, name="Tasks") + data_fixture.create_text_field(table=table, name="Task", primary=True) + data_fixture.create_text_field(table=table, name="Status") + data_fixture.create_long_text_field(table=table, name="Notes") + + automation = data_fixture.create_automation_application( + workspace=workspace, name="Task Processor" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_UPDATE_ROW_WORKFLOW.format( + automation_name=automation.name, table_name=table.name + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + call_args_list = _get_create_workflows_args(result) + args = call_args_list[0] if call_args_list else {} + wf_args = args.get("workflows", [{}])[0] if args.get("workflows") else {} + trigger_args = wf_args.get("trigger", {}) + nodes_args = wf_args.get("nodes", []) + update_nodes_args = [n for n in nodes_args if n.get("type") == "update_row"] + ur = update_nodes_args[0] if update_nodes_args else {} + + db_ok = AutomationWorkflow.objects.filter(automation=automation).exists() + if db_ok: + workflow, trigger_node, action_nodes = _get_workflow_nodes(automation) + db_trigger_type = trigger_node.service.get_type().type + db_update_actions = [ + n + for n in action_nodes + if n.service.get_type().type == "local_baserow_upsert_row" + ] + else: + db_trigger_type = None + db_update_actions = [] + + ur_values = ur.get("values", []) + ur_has_reviewed = any( + "reviewed" in str(v.get("value", "")).lower() for v in ur_values + ) + ur_has_notes = any( + "automation" in str(v.get("value", "")).lower() + or "review" in str(v.get("value", "")).lower() + for v in ur_values + ) + trigger_table_id = trigger_args.get("rows_triggers_settings", {}).get("table_id") + + with EvalChecklist("creates update-row workflow") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check("called create_workflows", len(call_args_list) >= 1) + checks.check( + "trigger is rows_updated", + trigger_args.get("type") == "rows_updated", + hint=f"got {trigger_args.get('type')}", + ) + checks.check( + "trigger table is Tasks", + trigger_table_id == table.id, + hint=f"got table_id={trigger_table_id}, expected={table.id}", + ) + checks.check( + "update_row node in args", + len(update_nodes_args) >= 1, + hint=f"node types: {[n.get('type') for n in nodes_args]}", + ) + checks.check( + "update_row has >=1 field value", + len(ur_values) >= 1, + ) + checks.check( + "update_row has row_id", + bool(ur.get("row_id")), + ) + checks.check( + "update_row sets Status to 'Reviewed'", + ur_has_reviewed, + hint=f"values: {ur_values}", + ) + checks.check( + "update_row sets Notes (automation/reviewed text)", + ur_has_notes, + hint=f"values: {ur_values}", + ) + checks.check( + "DB trigger is rows_updated", + db_trigger_type == "local_baserow_rows_updated", + hint=f"got {db_trigger_type}", + ) + checks.check( + "update_row action in DB", + len(db_update_actions) >= 1, + ) + + +# --------------------------------------------------------------------------- +# Send email node +# --------------------------------------------------------------------------- + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_email_notification_workflow(data_fixture, eval_model): + """Agent should create a workflow with an smtp_email node.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Orders") + data_fixture.create_text_field(table=table, name="Order ID", primary=True) + data_fixture.create_text_field(table=table, name="Customer Email") + + automation = data_fixture.create_automation_application( + workspace=workspace, name="Order Notifications" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_EMAIL_NOTIFICATION_WORKFLOW.format( + automation_name=automation.name, table_name=table.name + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + call_args_list = _get_create_workflows_args(result) + args = call_args_list[0] if call_args_list else {} + wf_args = args.get("workflows", [{}])[0] if args.get("workflows") else {} + trigger_args = wf_args.get("trigger", {}) + trigger_table_id = trigger_args.get("rows_triggers_settings", {}).get("table_id") + nodes_args = wf_args.get("nodes", []) + email_nodes_args = [n for n in nodes_args if n.get("type") == "smtp_email"] + email_node = email_nodes_args[0] if email_nodes_args else {} + email_to = email_node.get("to_emails", "") + email_subject = email_node.get("subject", "") + email_body = email_node.get("body", "") + + db_ok = AutomationWorkflow.objects.filter(automation=automation).exists() + if db_ok: + workflow, trigger_node, action_nodes = _get_workflow_nodes(automation) + db_email_actions = [ + n for n in action_nodes if n.service.get_type().type == "smtp_email" + ] + else: + db_email_actions = [] + + with EvalChecklist("creates email notification workflow") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check("called create_workflows", len(call_args_list) >= 1) + checks.check( + "trigger is rows_created", + trigger_args.get("type") == "rows_created", + hint=f"got {trigger_args.get('type')}", + ) + checks.check( + "trigger table is Orders", + trigger_table_id == table.id, + hint=f"got table_id={trigger_table_id}, expected={table.id}", + ) + checks.check( + "smtp_email node in args", + len(email_nodes_args) >= 1, + hint=f"node types: {[n.get('type') for n in nodes_args]}", + ) + checks.check( + "email to admin@example.com", + "admin@example.com" in email_to, + hint=f"got to: '{email_to}'", + ) + checks.check( + "email subject mentions 'Order'", + "order" in email_subject.lower(), + hint=f"got subject: '{email_subject}'", + ) + checks.check( + "email body mentions order being placed", + "order" in email_body.lower() or "placed" in email_body.lower(), + hint=f"got body: '{email_body}'", + ) + checks.check("workflow created in DB", db_ok) + checks.check( + "smtp_email action in DB", + len(db_email_actions) >= 1, + ) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_core_builders.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_core_builders.py new file mode 100644 index 0000000000..db0a28ccd0 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_core_builders.py @@ -0,0 +1,201 @@ +import pytest + +from baserow.contrib.automation.models import Automation +from baserow.contrib.database.models import Database + +from .eval_utils import ( + EvalChecklist, + build_database_ui_context, + count_tool_errors, + create_eval_assistant, + format_message_history, + print_message_history, +) + +# --------------------------------------------------------------------------- +# Eval prompts — one per test, easy to scan for coverage +# --------------------------------------------------------------------------- + +PROMPT_LISTS_DATABASES = "What databases do I have in this workspace?" + +PROMPT_CREATES_DATABASE = "Create a new database called 'Customer Portal'" + +PROMPT_CREATES_AUTOMATION = "Create an empty automation called 'Overdue Task Reminder'." + + +def _run_agent( + agent, deps, tracker, model, usage_limits, toolset, question, ui_context +): + deps.tool_helpers.request_context["ui_context"] = ui_context + return agent.run_sync( + user_prompt=question, + deps=deps, + model=model, + usage_limits=usage_limits, + toolsets=[toolset], + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_lists_databases(data_fixture, eval_model): + """Agent should call list_builders when asked what databases exist.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application( + workspace=workspace, name="Inventory" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=10, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_LISTS_DATABASES, + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + history = format_message_history(result) + tool_calls = [ + e + for e in history + if e.get("tool_name") == "list_builders" and e["role"] == "user" + ] + + with EvalChecklist("lists databases") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "called list_builders", + len(tool_calls) >= 1, + hint=f"tools called: {[e.get('tool_name') for e in history if e.get('tool_name')]}", + ) + checks.check( + "answer mentions 'Inventory'", + "Inventory" in result.output, + hint=f"answer: {result.output[:200]}", + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_database(data_fixture, eval_model): + """Agent should create a new database when asked.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_DATABASE, + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + history = format_message_history(result) + tool_calls = [ + e + for e in history + if e.get("tool_name") == "create_builders" and e["role"] == "user" + ] + created = Database.objects.filter(workspace=workspace, name__icontains="customer") + + with EvalChecklist("creates database") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "called create_builders", + len(tool_calls) >= 1, + hint=f"tools called: {[e.get('tool_name') for e in history if e.get('tool_name')]}", + ) + checks.check( + "database 'Customer Portal' exists", + created.exists(), + hint=f"databases: {list(Database.objects.filter(workspace=workspace).values_list('name', flat=True))}", + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_automation(data_fixture, eval_model): + """Agent should create a new automation when asked.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace) + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_AUTOMATION, + ui_context=ui_context, + ) + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + history = format_message_history(result) + tool_calls = [ + e + for e in history + if e.get("tool_name") == "create_builders" and e["role"] == "user" + ] + created = list(Automation.objects.all()) + automation = created[0] if created else None + + with EvalChecklist("creates automation") as checks: + checks.check("<=1 tool errors", err_count <= 1, hint=err_hint) + checks.check( + "called create_builders", + len(tool_calls) >= 1, + hint=f"tools called: {[e.get('tool_name') for e in history if e.get('tool_name')]}", + ) + checks.check( + "exactly 1 automation created", + len(created) == 1, + hint=f"found {len(created)}: {[a.name for a in created]}", + ) + checks.check( + "automation named 'Overdue Task Reminder'", + automation is not None and "overdue" in automation.name.lower(), + hint=f"got: '{automation.name if automation else None}'", + ) + checks.check( + "automation in correct workspace", + automation is not None and automation.workspace_id == workspace.id, + hint=f"workspace_id={automation.workspace_id if automation else None} vs {workspace.id}", + ) + checks.check( + "automation has no workflows", + automation is not None and automation.workflows.count() == 0, + hint=f"workflows: {list(automation.workflows.all()) if automation else []}", + ) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_rows.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_rows.py new file mode 100644 index 0000000000..bad28ec3fa --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_rows.py @@ -0,0 +1,214 @@ +import pytest + +from baserow.contrib.database.rows.handler import RowHandler + +from .eval_utils import ( + EvalChecklist, + build_database_ui_context, + count_tool_errors, + create_eval_assistant, + print_message_history, +) + +# --------------------------------------------------------------------------- +# Eval prompts — one per test, easy to scan for coverage +# --------------------------------------------------------------------------- + +PROMPT_CREATES_ROWS_WITH_ALL_FIELD_TYPES = ( + "Create 5 rows with diverse sample data in table {table_name}. " + "Fill in ALL fields with realistic values." +) + + +def _create_rich_table(data_fixture): + """ + Create a table with all managed field types plus a linked table + with sample data. + """ + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + + # Linked table (target for link_row fields) + linked_table = data_fixture.create_database_table( + database=database, name="Categories" + ) + linked_primary = data_fixture.create_text_field( + table=linked_table, name="Name", primary=True + ) + + # Populate linked table with sample rows + RowHandler().force_create_rows( + user, + linked_table, + [ + {linked_primary.db_column: "Work"}, + {linked_primary.db_column: "Personal"}, + {linked_primary.db_column: "Urgent"}, + ], + ) + + # Main table with all managed field types + table = data_fixture.create_database_table(database=database, name="Tasks") + title = data_fixture.create_text_field(table=table, name="Title", primary=True) + description = data_fixture.create_long_text_field(table=table, name="Description") + estimated_hours = data_fixture.create_number_field( + table=table, name="Estimated Hours", number_decimal_places=1 + ) + completed = data_fixture.create_boolean_field(table=table, name="Completed") + due_date = data_fixture.create_date_field(table=table, name="Due Date") + created_at = data_fixture.create_date_field( + table=table, name="Created At", date_include_time=True + ) + + status_field = data_fixture.create_single_select_field(table=table, name="Status") + data_fixture.create_select_option(field=status_field, value="To Do", order=0) + data_fixture.create_select_option(field=status_field, value="In Progress", order=1) + data_fixture.create_select_option(field=status_field, value="Done", order=2) + + tags_field = data_fixture.create_multiple_select_field(table=table, name="Tags") + data_fixture.create_select_option(field=tags_field, value="Bug", order=0) + data_fixture.create_select_option(field=tags_field, value="Feature", order=1) + data_fixture.create_select_option(field=tags_field, value="Docs", order=2) + + category_field = data_fixture.create_link_row_field( + table=table, + link_row_table=linked_table, + name="Category", + link_row_multiple_relationships=False, + ) + related_categories_field = data_fixture.create_link_row_field( + table=table, + link_row_table=linked_table, + name="Related Categories", + link_row_multiple_relationships=True, + ) + + return { + "user": user, + "workspace": workspace, + "database": database, + "table": table, + "linked_table": linked_table, + "fields": { + "title": title, + "description": description, + "estimated_hours": estimated_hours, + "completed": completed, + "due_date": due_date, + "created_at": created_at, + "status": status_field, + "tags": tags_field, + "category": category_field, + "related_categories": related_categories_field, + }, + } + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_rows_with_all_field_types(data_fixture, eval_model, db): + """ + Agent should create rows with sensible data for every field type. + + This tests the full flow: + 1. Agent calls get_tables_schema to learn the table structure + 2. Agent calls load_row_tools to unlock create_rows_in_table_X + 3. Agent calls create_rows_in_table_X with all fields populated + """ + + res = _create_rich_table(data_fixture) + user = res["user"] + workspace = res["workspace"] + database = res["database"] + table = res["table"] + fields = res["fields"] + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table=table) + deps.tool_helpers.request_context["ui_context"] = ui_context + + result = agent.run_sync( + user_prompt=PROMPT_CREATES_ROWS_WITH_ALL_FIELD_TYPES.format( + table_name=table.name + ), + deps=deps, + model=model, + usage_limits=usage_limits, + toolsets=[toolset], + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + table_model = table.get_model() + row_count = table_model.objects.count() + sample_rows = list(table_model.objects.all()) + + def _get_field_value(row, field_name): + return getattr(row, fields[field_name].db_column, None) + + def _any_row(check_fn): + return any(check_fn(r) for r in sample_rows) + + with EvalChecklist("creates rows with all field types") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check("5 rows created", row_count == 5, hint=f"got {row_count}") + checks.check( + "title populated", + _any_row(lambda r: bool(_get_field_value(r, "title"))), + ) + checks.check( + "description populated", + _any_row(lambda r: bool(_get_field_value(r, "description"))), + ) + checks.check( + "estimated_hours populated", + _any_row(lambda r: _get_field_value(r, "estimated_hours") is not None), + ) + checks.check( + "estimated_hours > 0 in at least one row", + _any_row(lambda r: (_get_field_value(r, "estimated_hours") or 0) > 0), + ) + checks.check( + "completed has at least one True", + _any_row(lambda r: _get_field_value(r, "completed") is True), + ) + checks.check( + "due_date populated", + _any_row(lambda r: _get_field_value(r, "due_date") is not None), + ) + checks.check( + "created_at populated", + _any_row(lambda r: _get_field_value(r, "created_at") is not None), + ) + checks.check( + "status is a known option", + _any_row( + lambda r: bool(_get_field_value(r, "status")) + and _get_field_value(r, "status").value + in ["To Do", "In Progress", "Done"] + ), + ) + checks.check( + "tags has at least one known option", + _any_row( + lambda r: bool( + set(_get_field_value(r, "tags").values_list("value", flat=True)) + & {"Bug", "Feature", "Docs"} + ) + ), + ) + checks.check( + "category linked", + _any_row(lambda r: len(_get_field_value(r, "category").all()) > 0), + ) + checks.check( + "related_categories linked", + _any_row( + lambda r: len(_get_field_value(r, "related_categories").all()) > 0 + ), + ) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_tables.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_tables.py new file mode 100644 index 0000000000..1761d79a46 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_database_tables.py @@ -0,0 +1,1164 @@ +import pytest + +from baserow.contrib.database.fields.models import ( + BooleanField, + DateField, + LinkRowField, + LongTextField, + NumberField, + SingleSelectField, + TextField, +) +from baserow.contrib.database.models import Table +from baserow.contrib.database.views.models import View, ViewFilter +from baserow.core.db import specific_iterator + +from .eval_utils import ( + EvalChecklist, + build_database_ui_context, + count_tool_errors, + create_eval_assistant, + print_message_history, +) + +# --------------------------------------------------------------------------- +# Eval prompts — one per test, easy to scan for coverage +# --------------------------------------------------------------------------- + +PROMPT_CREATES_SIMPLE_TABLE = ( + "Create a Recipes table in database {database_name} with these fields: " + "Name, Description, Prep Time in Minutes, Servings, and Vegetarian. " + "Don't add sample rows." +) + +PROMPT_CREATES_TABLE_WITH_SELECT_FIELDS = ( + "Create a Tasks table in database {database_name} with: " + "Title, Status with options: To Do, In Progress, Done, " + "Priority with options: Low, Medium, High, " + "and Due Date. Don't add sample rows." +) + +PROMPT_CREATES_RELATED_TABLES = ( + "Create a simple project management system in database {database_name} with: " + "1. A Projects table with Name and Description. " + "2. A Tasks table with Title, Status with options: To Do, In Progress, Done, " + "and a link to the Projects table. " + "Don't add sample rows." +) + +PROMPT_CREATES_DATABASE_FROM_DESCRIPTION = ( + "Set up a Bookstore database to manage a bookstore. " + "I need tables for Books and Authors. " + "Books should have title, description, price, publication date, and a link to Authors. " + "Authors should have name and bio. " + "Don't add sample rows." +) + +PROMPT_CREATE_RELATED_TABLES_WITH_SAMPLE_ROWS = ( + "Set up the Bookstore database {database_name} with: " + "1. An Authors table with Name and Bio. " + "2. A Books table with Title, Genre " + "(single select: Fiction, Non-Fiction, Science, History), " + "Price, and a link to the Authors table." +) + +# -- View creation prompts -------------------------------------------------- + +PROMPT_CREATE_GRID_VIEW = ( + "Create a grid view called 'All Tasks' for table {table_name}." +) + +PROMPT_CREATE_KANBAN_VIEW = ( + "Create a kanban view called 'Task Board' for table {table_name}. " + "Use the Status field (id: {status_field_name}) as the column field." +) + +PROMPT_CREATE_CALENDAR_VIEW = ( + "Create a calendar view called 'Schedule' for table {table_name}. " + "Use the Due Date field (id: {date_field_name}) as the date field." +) + +PROMPT_CREATE_GALLERY_VIEW = ( + "Create a gallery view called 'Image Gallery' for table {table_name}. " + "Use the Cover Image field (id: {file_field_name}) as the cover image." +) + +PROMPT_CREATE_TIMELINE_VIEW = ( + "Create a timeline view called 'Project Timeline' for table {table_name}. " + "Use Start Date (id: {start_field_name}) and End Date (id: {end_field_name})." +) + +PROMPT_CREATE_FORM_VIEW = ( + "Create a form view called 'Submit Task' for table {table_name}. " + "Include the Name field in the form." +) + +# -- View filter prompts ---------------------------------------------------- + +PROMPT_FILTER_TEXT_CONTAINS = ( + "Create a grid view called 'Filtered' for table {table_name}, " + "then add a filter on the Description field (id: {text_field_name}) " + "to only show rows where it contains 'important'." +) + +PROMPT_FILTER_NUMBER_GREATER_THAN = ( + "Create a grid view called 'Filtered' for table {table_name}, " + "then add a filter on the Amount field (id: {number_field_name}) " + "to only show rows where it is greater than 100." +) + +PROMPT_FILTER_DATE_AFTER = ( + "Create a grid view called 'Filtered' for table {table_name}, " + "then add a filter on the Due Date field (id: {date_field_name}) " + "to only show rows where the date is after today." +) + +PROMPT_FILTER_SINGLE_SELECT_ANY_OF = ( + "Create a grid view called 'Filtered' for table {table_name}, " + "then add a filter on the Status field (id: {select_field_name}) " + "to only show rows where Status is any of 'Active' or 'Pending'." +) + +PROMPT_FILTER_MULTIPLE_SELECT_HAS = ( + "Create a grid view called 'Filtered' for table {table_name}, " + "then add a filter on the Tags field (id: {multi_field_name}) " + "to only show rows where Tags has 'Important'." +) + +PROMPT_FILTER_BOOLEAN_IS = ( + "Create a grid view called 'Filtered' for table {table_name}, " + "then add a filter on the Active field (id: {bool_field_name}) " + "to only show rows where Active is true." +) + +# -- Field update/delete prompts -------------------------------------------- + +PROMPT_UPDATE_FIELD_RENAME = ( + "Rename the Description field to Summary in the {table_name} table." +) + +PROMPT_UPDATE_FIELD_SELECT_OPTIONS = ( + "Add an 'In Progress' option to the Status field in the {table_name} table." +) + +PROMPT_DELETE_FIELD = "Delete the Notes field from the {table_name} table." + + +def _run_agent( + agent, deps, tracker, model, usage_limits, toolset, question, ui_context +): + """Helper to run the agent with standard configuration.""" + deps.tool_helpers.request_context["ui_context"] = ui_context + + result = agent.run_sync( + user_prompt=question, + deps=deps, + model=model, + usage_limits=usage_limits, + toolsets=[toolset], + ) + return result + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_simple_table(data_fixture, eval_model): + """Agent should create a table with basic field types when asked.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application( + workspace=workspace, name="Recipe Database" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_SIMPLE_TABLE.format(database_name=database.name), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + tables = Table.objects.filter(database=database) + recipe_tables = [t for t in tables if "recipe" in t.name.lower()] + table = recipe_tables[0] if recipe_tables else None + fields = list(specific_iterator(table.field_set.all())) if table else [] + field_names = {f.name.lower(): f for f in fields} + text_fields = [f for f in fields if isinstance(f, (TextField, LongTextField))] + number_fields = [f for f in fields if isinstance(f, NumberField)] + boolean_fields = [f for f in fields if isinstance(f, BooleanField)] + + prep_number = next( + ( + f + for f in number_fields + if any(kw in f.name.lower() for kw in ("prep", "time", "minute")) + ), + None, + ) + veg_bool = next((f for f in boolean_fields if "vegetarian" in f.name.lower()), None) + + with EvalChecklist("creates Recipes table") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "Recipes table created", + len(recipe_tables) == 1, + hint=f"got {len(recipe_tables)}: {[t.name for t in tables]}", + ) + checks.check( + "Name field exists", + any("name" in n for n in field_names), + hint=f"fields: {list(field_names.keys())}", + ) + checks.check( + "Description field exists", + any("description" in n for n in field_names), + hint=f"fields: {list(field_names.keys())}", + ) + checks.check( + ">=2 text/long_text fields", + len(text_fields) >= 2, + hint=f"got {len(text_fields)}", + ) + checks.check( + ">=2 number fields", + len(number_fields) >= 2, + hint=f"got {len(number_fields)}", + ) + checks.check( + ">=1 boolean field", + len(boolean_fields) >= 1, + hint=f"got {len(boolean_fields)}", + ) + checks.check( + "Prep Time/Minutes field exists (number)", + prep_number is not None, + hint=f"number fields: {[f.name for f in number_fields]}", + ) + checks.check( + "Vegetarian field exists (boolean)", + veg_bool is not None, + hint=f"boolean fields: {[f.name for f in boolean_fields]}", + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_table_with_select_fields(data_fixture, eval_model): + """Agent should create a table with single select and appropriate options.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application( + workspace=workspace, name="Task Management" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_TABLE_WITH_SELECT_FIELDS.format( + database_name=database.name + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + tables = Table.objects.filter(database=database) + task_tables = [t for t in tables if "task" in t.name.lower()] + table = task_tables[0] if task_tables else None + fields = list(specific_iterator(table.field_set.all())) if table else [] + select_fields = [f for f in fields if isinstance(f, SingleSelectField)] + status_field = next((f for f in select_fields if "status" in f.name.lower()), None) + status_options = ( + list(status_field.select_options.values_list("value", flat=True)) + if status_field + else [] + ) + date_fields = [f for f in fields if isinstance(f, DateField)] + field_names_lower = {f.name.lower(): f for f in fields} + priority_field = next( + (f for f in select_fields if "priority" in f.name.lower()), None + ) + priority_options = ( + list(priority_field.select_options.values_list("value", flat=True)) + if priority_field + else [] + ) + status_option_values = {o.lower() for o in status_options} + priority_option_values = {o.lower() for o in priority_options} + + with EvalChecklist("creates Tasks table with selects") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "Tasks table created", + len(task_tables) == 1, + hint=f"got {len(task_tables)}: {[t.name for t in tables]}", + ) + checks.check( + ">=2 single select fields (Status, Priority)", + len(select_fields) >= 2, + hint=f"got {len(select_fields)}: {[f.name for f in select_fields]}", + ) + checks.check( + "Status field exists", + status_field is not None, + hint=f"select fields: {[f.name for f in select_fields]}", + ) + checks.check( + "Status has >=3 options", + len(status_options) >= 3, + hint=f"got: {status_options}", + ) + checks.check( + ">=1 date field", + len(date_fields) >= 1, + hint=f"got {len(date_fields)}", + ) + checks.check( + "Title text field exists", + any("title" in n for n in field_names_lower), + hint=f"fields: {list(field_names_lower.keys())}", + ) + checks.check( + "Priority field exists", + priority_field is not None, + hint=f"select fields: {[f.name for f in select_fields]}", + ) + checks.check( + "Status has To Do / In Progress / Done", + {"to do", "in progress", "done"} <= status_option_values, + hint=f"got: {status_options}", + ) + checks.check( + "Priority has Low / Medium / High", + {"low", "medium", "high"} <= priority_option_values, + hint=f"got: {priority_options}", + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_related_tables(data_fixture, eval_model): + """Agent should create multiple tables with link_row relationships.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application( + workspace=workspace, name="Project Management" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=20, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_RELATED_TABLES.format(database_name=database.name), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + tables = Table.objects.filter(database=database) + table_names = {t.name.lower(): t for t in tables} + project_tables = [name for name in table_names if "project" in name] + task_tables = [name for name in table_names if "task" in name] + + task_table = table_names[task_tables[0]] if task_tables else None + task_fields = ( + list(specific_iterator(task_table.field_set.all())) if task_table else [] + ) + link_fields = [f for f in task_fields if isinstance(f, LinkRowField)] + + project_table = table_names[project_tables[0]] if project_tables else None + link_to_projects = ( + [f for f in link_fields if f.link_row_table_id == project_table.id] + if project_table + else [] + ) + project_fields = ( + list(specific_iterator(project_table.field_set.all())) if project_table else [] + ) + project_text_fields = [ + f for f in project_fields if isinstance(f, (TextField, LongTextField)) + ] + task_select_fields = [f for f in task_fields if isinstance(f, SingleSelectField)] + status_field_in_tasks = next( + (f for f in task_select_fields if "status" in f.name.lower()), None + ) + status_opts_in_tasks = ( + list(status_field_in_tasks.select_options.values_list("value", flat=True)) + if status_field_in_tasks + else [] + ) + status_opt_values = {o.lower() for o in status_opts_in_tasks} + + with EvalChecklist("creates related tables") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "Projects table exists", + len(project_tables) >= 1, + hint=f"got tables: {list(table_names.keys())}", + ) + checks.check( + "Tasks table exists", + len(task_tables) >= 1, + hint=f"got tables: {list(table_names.keys())}", + ) + checks.check( + ">=1 link_row field in Tasks", + len(link_fields) >= 1, + hint=f"fields: {[(f.name, type(f).__name__) for f in task_fields]}", + ) + checks.check( + "link_row points to Projects table", + len(link_to_projects) >= 1, + hint=f"links to: {[(f.name, f.link_row_table_id) for f in link_fields]}", + ) + checks.check( + "Projects has >=2 text fields (Name, Description)", + len(project_text_fields) >= 2, + hint=f"project text fields: {[f.name for f in project_text_fields]}", + ) + checks.check( + "Tasks has Status single_select field", + status_field_in_tasks is not None, + hint=f"task select fields: {[f.name for f in task_select_fields]}", + ) + checks.check( + "Tasks Status has To Do / In Progress / Done", + {"to do", "in progress", "done"} <= status_opt_values, + hint=f"got: {status_opts_in_tasks}", + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_creates_database_from_description(data_fixture, eval_model): + """ + Agent should create a full database structure from a high-level description. + + This tests the agent's ability to interpret a vague request and create + appropriate tables, fields, and relationships. + """ + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=25, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATES_DATABASE_FROM_DESCRIPTION, + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + from baserow.contrib.database.models import Database + + databases = Database.objects.filter(workspace=workspace) + tables = list(Table.objects.filter(database__in=databases)) + table_names_lower = [t.name.lower() for t in tables] + + books_table = next((t for t in tables if "book" in t.name.lower()), None) + books_fields = ( + list(specific_iterator(books_table.field_set.all())) if books_table else [] + ) + books_field_types = {type(f) for f in books_fields} + + authors_table_obj = next((t for t in tables if "author" in t.name.lower()), None) + authors_fields = ( + list(specific_iterator(authors_table_obj.field_set.all())) + if authors_table_obj + else [] + ) + authors_field_types = {type(f) for f in authors_fields} + books_link_fields = [f for f in books_fields if isinstance(f, LinkRowField)] + link_to_authors = ( + [f for f in books_link_fields if f.link_row_table_id == authors_table_obj.id] + if authors_table_obj + else [] + ) + + with EvalChecklist("creates Bookstore database") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "database created", + databases.exists(), + hint="no database found in workspace", + ) + checks.check( + "Books table exists", + any("book" in n for n in table_names_lower), + hint=f"got: {[t.name for t in tables]}", + ) + checks.check( + "Authors table exists", + any("author" in n for n in table_names_lower), + hint=f"got: {[t.name for t in tables]}", + ) + checks.check( + "Books has text/long_text field", + TextField in books_field_types or LongTextField in books_field_types, + hint=f"field types: {[t.__name__ for t in books_field_types]}", + ) + checks.check( + "Books has number field (price)", + NumberField in books_field_types, + hint=f"field types: {[t.__name__ for t in books_field_types]}", + ) + checks.check( + "Books has date field", + DateField in books_field_types, + hint=f"field types: {[t.__name__ for t in books_field_types]}", + ) + checks.check( + "Books has link_row field to Authors", + LinkRowField in books_field_types, + hint=f"field types: {[t.__name__ for t in books_field_types]}", + ) + checks.check( + "Books link_row points to Authors table", + len(link_to_authors) >= 1, + hint=f"link targets: {[f.link_row_table_id for f in books_link_fields]}", + ) + checks.check( + "Authors has text field (name/bio)", + TextField in authors_field_types or LongTextField in authors_field_types, + hint=f"authors field types: {[t.__name__ for t in authors_field_types]}", + ) + checks.check( + "Books has >=2 text/long_text fields (title + description)", + sum(1 for f in books_fields if isinstance(f, (TextField, LongTextField))) + >= 2, + hint=f"books text fields: {[f.name for f in books_fields if isinstance(f, (TextField, LongTextField))]}", + ) + + +# --------------------------------------------------------------------------- +# Parametrized view creation eval +# --------------------------------------------------------------------------- + + +def _setup_grid(data_fixture, table): + """Grid view needs no special fields.""" + return {} + + +def _setup_kanban(data_fixture, table): + """Kanban needs a single_select field.""" + field = data_fixture.create_single_select_field(table=table, name="Status") + data_fixture.create_select_option(field=field, value="To Do", order=1) + data_fixture.create_select_option(field=field, value="In Progress", order=2) + data_fixture.create_select_option(field=field, value="Done", order=3) + return {"status_field": field} + + +def _setup_calendar(data_fixture, table): + """Calendar needs a date field.""" + field = data_fixture.create_date_field(table=table, name="Due Date") + return {"date_field": field} + + +def _setup_gallery(data_fixture, table): + """Gallery needs a file field.""" + field = data_fixture.create_file_field(table=table, name="Cover Image") + return {"file_field": field} + + +def _setup_timeline(data_fixture, table): + """Timeline needs two date fields with matching include_time.""" + start = data_fixture.create_date_field( + table=table, name="Start Date", date_include_time=False + ) + end = data_fixture.create_date_field( + table=table, name="End Date", date_include_time=False + ) + return {"start_field": start, "end_field": end} + + +def _setup_form(data_fixture, table): + """Form view uses existing fields; no extra setup beyond what's already there.""" + return {} + + +_VIEW_TEST_CASES = [ + pytest.param("grid", _setup_grid, PROMPT_CREATE_GRID_VIEW, id="grid"), + pytest.param("kanban", _setup_kanban, PROMPT_CREATE_KANBAN_VIEW, id="kanban"), + pytest.param( + "calendar", _setup_calendar, PROMPT_CREATE_CALENDAR_VIEW, id="calendar" + ), + pytest.param("gallery", _setup_gallery, PROMPT_CREATE_GALLERY_VIEW, id="gallery"), + pytest.param( + "timeline", _setup_timeline, PROMPT_CREATE_TIMELINE_VIEW, id="timeline" + ), + pytest.param("form", _setup_form, PROMPT_CREATE_FORM_VIEW, id="form"), +] + + +_EXPECTED_VIEW_NAMES = { + "grid": "all tasks", + "kanban": "task board", + "calendar": "schedule", + "gallery": "image gallery", + "timeline": "project timeline", + "form": "submit task", +} + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +@pytest.mark.parametrize("view_type,setup_fn,prompt_template", _VIEW_TEST_CASES) +def test_agent_creates_view( + data_fixture, eval_model, view_type, setup_fn, prompt_template +): + """Agent should create a view of the given type without tool errors.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Tasks") + data_fixture.create_text_field(table=table, name="Name", primary=True) + + # Set up type-specific fields + extra = setup_fn(data_fixture, table) + + # Build prompt with field IDs injected + fmt_kwargs = {"table_name": table.name} + for key, field in extra.items(): + fmt_kwargs[f"{key}_name"] = field.name + prompt = prompt_template.format(**fmt_kwargs) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=prompt, + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + views = View.objects.filter(table=table) + typed_views = [ + v for v in views if v.get_type().type == view_type and v.name != "Grid" + ] + + view_name_ok = any( + _EXPECTED_VIEW_NAMES[view_type] in v.name.lower() for v in typed_views + ) + + with EvalChecklist(f"creates {view_type} view") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + f"{view_type} view created", + len(typed_views) >= 1, + hint=f"got views: {[(v.name, v.get_type().type) for v in views]}", + ) + checks.check( + "view name matches expected", + view_name_ok, + hint=f"expected '{_EXPECTED_VIEW_NAMES[view_type]}', got: {[v.name for v in typed_views]}", + ) + + +# --------------------------------------------------------------------------- +# Parametrized view filter creation eval +# --------------------------------------------------------------------------- + + +def _setup_text_filter(data_fixture, table): + field = data_fixture.create_text_field(table=table, name="Description") + return {"text_field": field} + + +def _setup_number_filter(data_fixture, table): + field = data_fixture.create_number_field(table=table, name="Amount") + return {"number_field": field} + + +def _setup_date_filter(data_fixture, table): + field = data_fixture.create_date_field(table=table, name="Due Date") + return {"date_field": field} + + +def _setup_single_select_filter(data_fixture, table): + field = data_fixture.create_single_select_field(table=table, name="Status") + data_fixture.create_select_option(field=field, value="Active", order=1) + data_fixture.create_select_option(field=field, value="Pending", order=2) + data_fixture.create_select_option(field=field, value="Closed", order=3) + return {"select_field": field} + + +def _setup_multiple_select_filter(data_fixture, table): + field = data_fixture.create_multiple_select_field(table=table, name="Tags") + data_fixture.create_select_option(field=field, value="Important", order=1) + data_fixture.create_select_option(field=field, value="Urgent", order=2) + data_fixture.create_select_option(field=field, value="Low", order=3) + return {"multi_field": field} + + +def _setup_boolean_filter(data_fixture, table): + field = data_fixture.create_boolean_field(table=table, name="Active") + return {"bool_field": field} + + +_FILTER_TEST_CASES = [ + pytest.param( + "text", + _setup_text_filter, + PROMPT_FILTER_TEXT_CONTAINS, + "contains", + "important", + id="text_contains", + ), + pytest.param( + "number", + _setup_number_filter, + PROMPT_FILTER_NUMBER_GREATER_THAN, + "higher_than", + "100", + id="number_greater_than", + ), + pytest.param( + "date", + _setup_date_filter, + PROMPT_FILTER_DATE_AFTER, + "date_is_after", + None, # value contains UTC?date_mode format — fragile to check + id="date_after", + ), + pytest.param( + "single_select", + _setup_single_select_filter, + PROMPT_FILTER_SINGLE_SELECT_ANY_OF, + "single_select_is_any_of", + None, # value is comma-separated option IDs — fragile to check + id="single_select_is_any_of", + ), + pytest.param( + "multiple_select", + _setup_multiple_select_filter, + PROMPT_FILTER_MULTIPLE_SELECT_HAS, + "multiple_select_has", + None, # value is option ID — fragile to check + id="multiple_select_has", + ), + pytest.param( + "boolean", + _setup_boolean_filter, + PROMPT_FILTER_BOOLEAN_IS, + "equal", + "1", + id="boolean_equal", + ), +] + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +@pytest.mark.parametrize( + "filter_type,setup_fn,prompt_template,expected_orm_type,expected_value_fragment", + _FILTER_TEST_CASES, +) +def test_agent_creates_view_filter( + data_fixture, + eval_model, + filter_type, + setup_fn, + prompt_template, + expected_orm_type, + expected_value_fragment, +): + """Agent should create a view with the correct filter type without tool errors.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Tasks") + data_fixture.create_text_field(table=table, name="Name", primary=True) + + # Set up type-specific fields + extra = setup_fn(data_fixture, table) + + # Build prompt with field IDs injected + fmt_kwargs = {"table_name": table.name} + for key, field in extra.items(): + fmt_kwargs[f"{key}_name"] = field.name + prompt = prompt_template.format(**fmt_kwargs) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=prompt, + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + filters = ViewFilter.objects.filter(view__table=table, type=expected_orm_type) + all_filter_types = list( + ViewFilter.objects.filter(view__table=table).values_list("type", flat=True) + ) + filter_obj = filters.first() + setup_field = list(extra.values())[0] if extra else None + + with EvalChecklist(f"creates {filter_type} view filter") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + f"ViewFilter type='{expected_orm_type}' exists", + filters.exists(), + hint=f"got filter types: {all_filter_types}", + ) + checks.check( + "filter is on the correct field", + filter_obj is not None + and setup_field is not None + and filter_obj.field_id == setup_field.id, + hint=f"filter field_id={filter_obj.field_id if filter_obj else None}, expected={setup_field.id if setup_field else None}", + ) + if expected_value_fragment is not None: + checks.check( + "filter value is correct", + filter_obj is not None + and expected_value_fragment in (filter_obj.value or ""), + hint=f"filter value='{filter_obj.value if filter_obj else None}', expected fragment='{expected_value_fragment}'", + ) + + +# --------------------------------------------------------------------------- +# Field update/delete evals +# --------------------------------------------------------------------------- + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_renames_field(data_fixture, eval_model): + """Agent should rename a field when asked.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Tasks") + data_fixture.create_text_field(table=table, name="Name", primary=True) + data_fixture.create_long_text_field(table=table, name="Description") + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_UPDATE_FIELD_RENAME.format(table_name=table.name), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + field_names = list(table.field_set.all().values_list("name", flat=True)) + + with EvalChecklist("renames field") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "Summary field exists", + any("summary" in n.lower() for n in field_names), + hint=f"fields: {field_names}", + ) + checks.check( + "Description field gone", + not any(n.lower() == "description" for n in field_names), + hint=f"fields: {field_names}", + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_updates_select_options(data_fixture, eval_model): + """Agent should add a new option to a single_select field.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Tasks") + data_fixture.create_text_field(table=table, name="Name", primary=True) + status_field = data_fixture.create_single_select_field(table=table, name="Status") + data_fixture.create_select_option(field=status_field, value="To Do", order=1) + data_fixture.create_select_option(field=status_field, value="Done", order=2) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_UPDATE_FIELD_SELECT_OPTIONS.format(table_name=table.name), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + status_field.refresh_from_db() + options = list(status_field.select_options.values_list("value", flat=True)) + + with EvalChecklist("updates select options") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "In Progress option added", + any("in progress" in o.lower() for o in options), + hint=f"options: {options}", + ) + checks.check( + "existing options preserved", + {"to do", "done"} <= {o.lower() for o in options}, + hint=f"options: {options}", + ) + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_agent_deletes_field(data_fixture, eval_model): + """Agent should delete a field when asked.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Tasks") + data_fixture.create_text_field(table=table, name="Name", primary=True) + data_fixture.create_long_text_field(table=table, name="Notes") + data_fixture.create_text_field(table=table, name="Priority") + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=15, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database, table) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_DELETE_FIELD.format(table_name=table.name), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + field_names = list(table.field_set.all().values_list("name", flat=True)) + + with EvalChecklist("deletes field") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "Notes field gone", + not any(n.lower() == "notes" for n in field_names), + hint=f"fields: {field_names}", + ) + checks.check( + "other fields preserved", + any("name" in n.lower() for n in field_names) + and any("priority" in n.lower() for n in field_names), + hint=f"fields: {field_names}", + ) + + +# --------------------------------------------------------------------------- +# Sample rows eval +# --------------------------------------------------------------------------- + + +@pytest.mark.eval +@pytest.mark.django_db(transaction=True) +def test_create_related_tables_with_sample_rows(data_fixture, eval_model): + """ + Agent creates two related tables (Authors → Books) and sample rows + are generated for both, including link_row references. + """ + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application( + workspace=workspace, name="Bookstore" + ) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=25, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=PROMPT_CREATE_RELATED_TABLES_WITH_SAMPLE_ROWS.format( + database_name=database.name + ), + ui_context=ui_context, + ) + + print_message_history(result) + err_count, err_hint = count_tool_errors(result) + + tables = Table.objects.filter(database=database) + table_names = {t.name.lower(): t for t in tables} + author_tables = [name for name in table_names if "author" in name] + book_tables = [name for name in table_names if "book" in name] + + authors_count = ( + table_names[author_tables[0]].get_model().objects.count() + if author_tables + else 0 + ) + books_count = ( + table_names[book_tables[0]].get_model().objects.count() if book_tables else 0 + ) + books_table_obj = table_names[book_tables[0]] if book_tables else None + books_fields_list = ( + list(specific_iterator(books_table_obj.field_set.all())) + if books_table_obj + else [] + ) + genre_field = next( + ( + f + for f in books_fields_list + if isinstance(f, SingleSelectField) and "genre" in f.name.lower() + ), + None, + ) + genre_options = ( + list(genre_field.select_options.values_list("value", flat=True)) + if genre_field + else [] + ) + genre_option_values = {o.lower() for o in genre_options} + price_field = next( + ( + f + for f in books_fields_list + if isinstance(f, NumberField) and "price" in f.name.lower() + ), + None, + ) + books_link_fields_list = [ + f for f in books_fields_list if isinstance(f, LinkRowField) + ] + + with EvalChecklist("creates Bookstore with sample rows") as checks: + checks.check("no tool errors", err_count == 0, hint=err_hint) + checks.check( + "Authors table exists", + len(author_tables) >= 1, + hint=f"got: {list(table_names.keys())}", + ) + checks.check( + "Books table exists", + len(book_tables) >= 1, + hint=f"got: {list(table_names.keys())}", + ) + checks.check( + "Authors has >=1 sample row", + authors_count >= 1, + hint=f"got {authors_count}", + ) + checks.check( + "Books has >=2 sample rows", + books_count >= 2, + hint=f"got {books_count}", + ) + checks.check( + "Books has Genre single_select field", + genre_field is not None, + hint=f"books select fields: {[f.name for f in books_fields_list if isinstance(f, SingleSelectField)]}", + ) + checks.check( + "Genre has Fiction / Non-Fiction / Science / History options", + {"fiction", "non-fiction", "science", "history"} <= genre_option_values, + hint=f"got: {genre_options}", + ) + checks.check( + "Books has Price (number) field", + price_field is not None, + hint=f"books number fields: {[f.name for f in books_fields_list if isinstance(f, NumberField)]}", + ) + checks.check( + "Books has link_row to Authors", + len(books_link_fields_list) >= 1, + hint=f"books fields: {[f.name for f in books_fields_list]}", + ) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_search_user_docs.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_search_user_docs.py new file mode 100644 index 0000000000..2aa153e221 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/evals/test_eval_search_user_docs.py @@ -0,0 +1,276 @@ +from django.conf import settings + +import pytest + +from .eval_utils import ( + EvalChecklist, + build_database_ui_context, + create_eval_assistant, + format_message_history, + print_message_history, +) + + +@pytest.fixture(autouse=True) +def _require_knowledge_base(synced_knowledge_base): + """Skip search docs tests when the knowledge base is not available. + + Depends on the session-scoped ``synced_knowledge_base`` fixture + (conftest.py) which syncs the KB once per session if needed. + """ + + if not getattr(settings, "BASEROW_EMBEDDINGS_API_URL", ""): + pytest.skip( + "BASEROW_EMBEDDINGS_API_URL not set. " + "See docs/testing/ai-assistant-evals.md for setup instructions." + ) + + from baserow_enterprise.assistant.tools.search_user_docs.handler import ( + KnowledgeBaseHandler, + ) + + if not KnowledgeBaseHandler().can_search(): + pytest.skip( + "Knowledge base not available. " + "Requires: pgvector extension and synced KB data. " + "See docs/testing/ai-assistant-evals.md for setup instructions." + ) + + +# --------------------------------------------------------------------------- +# Test cases: (id, question, expected_source_patterns, expected_answer_keywords) +# +# expected_source_patterns: at least ONE returned source URL must contain +# one of these substrings. +# expected_answer_keywords: the agent's final answer must contain at least +# ONE of these substrings (case-insensitive). +# --------------------------------------------------------------------------- + +SEARCH_DOCS_CASES = [ + pytest.param( + ( + "I'm trying to do a VLOOKUP to pull the 'Client Email' from my " + "'Clients' tab into my 'Projects' tab based on the client name. " + "I can't find the formula for this. Does it exist in Baserow?" + ), + ["link-to-table", "lookup-field"], + ["link row", "lookup", "link_row", "relationship"], + id="vlookup-to-link-row", + ), + pytest.param( + ( + "I need to run a raw SQL query to join three tables for a report. " + "I'm on the standard cloud hosted plan. Where do I find my database " + "host, port, and credentials to connect my BI tool?" + ), + ["technical", "set-up-baserow"], + ["api", "self-host", "rest api", "not available", "cannot"], + id="raw-sql-cloud-plan", + ), + pytest.param( + ( + "I'm trying to calculate the days between two dates. I typed " + "=DAYS(field('End'), field('Start')) like I do in Google Sheets " + "but it says 'Invalid Syntax'. What am I doing wrong?" + ), + ["formula", "understanding-formulas"], + ["date_diff", "date diff", "datediff"], + id="date-diff-formula", + ), + pytest.param( + "Where is the save button? I don't want to lose my work.", + ["baserow-basics"], + ["auto", "automatically", "saved"], + id="auto-save", + ), + pytest.param( + "How can I put a form on my website that sends data to my table?", + ["creating-forms", "guide-to-creating-forms"], + ["form", "embed", "share"], + id="form-embed", + ), + pytest.param( + "I deleted a bunch of rows by mistake. Is there a recycling bin?", + ["data-recovery", "deletion"], + ["trash", "recover", "undo", "restore"], + id="data-recovery", + ), + pytest.param( + ( + "I want to share a specific view with my client so they can see " + "the progress, but I don't want them to edit anything or see the " + "other tables. Is that possible?" + ), + ["public-sharing", "permissions"], + ["share", "public", "read-only", "read only", "view"], + id="share-view-read-only", + ), + pytest.param( + "I need to lock a column so my team can see it but not mess it up.", + ["field-level-permissions", "permissions"], + ["permission", "field", "read", "lock"], + id="field-permissions", + ), + pytest.param( + ( + "How can I create a calendar that shows my tasks, but only the ones assigned to me." + ), + ["calendar-view", "calendar", "filters"], + ["calendar", "filter", "view"], + id="calendar-with-filter", + ), + pytest.param( + ( + "I'm trying to combine the first name and last name columns " + "into one, but I want to make sure it's uppercase. Can you tell me how to " + "write that formula?" + ), + ["formula", "understanding-formulas"], + ["concat", "upper", "formula"], + id="concat-upper-formula", + ), + pytest.param( + ( + "I'm running Baserow on my own server with Docker. A new version " + "came out yesterday, how do I install it without losing my data?" + ), + ["set-up-baserow", "configuration"], + ["docker", "pull", "upgrade", "update", "volume"], + id="docker-upgrade", + ), + pytest.param( + ( + "I want to write a script so that whenever I tick a checkbox, " + "it sends an email to the client. Do I need to build a custom " + "plugin for this?" + ), + ["webhook", "workflow-automation", "automation"], + ["automation", "webhook", "trigger", "workflow"], + id="checkbox-email-automation", + ), + pytest.param( + ( + "I want to embed my inventory sheet on my website so clients " + "can search it. Do they need a Baserow account to see it? " + "How do I generate the code?" + ), + ["public-sharing"], + ["embed", "public", "share", "account"], + id="embed-public-view", + ), + pytest.param( + "Can Baserow integrate with Google AI Studio?", + ["configure-generative-ai", "database-api"], + ["ai", "generative", "integration", "api"], + id="google-ai-studio", + ), + pytest.param( + ( + "I'm trying to fetch data from my table using curl but I keep " + "getting a 401 error. I generated a token in my settings, but it " + "says I don't have permissions. Do I need to use my login email " + "and password instead?" + ), + ["rest-api", "database-api"], + ["token", "api", "permission", "authentication"], + id="api-401-error", + ), + pytest.param( + ( + "Is there a way to only get rows where the 'Status' field is " + "set to 'Done' via the API? I don't want to download the whole " + "JSON and filter it in my script." + ), + ["rest-api", "database-api"], + ["filter", "api", "parameter", "field"], + id="api-filter-rows", + ), +] + + +def _run_agent( + agent, deps, tracker, model, usage_limits, toolset, question, ui_context +): + deps.tool_helpers.request_context["ui_context"] = ui_context + return agent.run_sync( + user_prompt=question, + deps=deps, + model=model, + usage_limits=usage_limits, + toolsets=[toolset], + ) + + +@pytest.mark.eval +@pytest.mark.django_db +@pytest.mark.parametrize( + "question,expected_source_patterns,expected_keywords", SEARCH_DOCS_CASES +) +def test_search_user_docs( + data_fixture, + eval_model, + question, + expected_source_patterns, + expected_keywords, +): + """ + Agent should call search_user_docs for user-docs questions and return + an answer with relevant sources and content. + """ + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + + agent, deps, tracker, model, usage_limits, toolset = create_eval_assistant( + user, workspace, max_iters=10, model=eval_model + ) + ui_context = build_database_ui_context(user, workspace, database) + + result = _run_agent( + agent, + deps, + tracker, + model, + usage_limits, + toolset, + question=question, + ui_context=ui_context, + ) + + print_message_history(result) + + history = format_message_history(result) + search_calls = [ + e + for e in history + if e.get("tool_name") == "search_user_docs" and e["role"] == "assistant" + ] + sources = deps.sources + answer = result.output.lower() + keyword_match = any(kw.lower() in answer for kw in expected_keywords) + + # Source URL matching is non-fatal — URLs change and the retrieval may + # return valid alternative sources. Print a warning but don't score it. + if expected_source_patterns and sources: + source_match = any( + any(pattern in url for pattern in expected_source_patterns) + for url in sources + ) + if not source_match: + print( + f"\n WARNING: No source matched {expected_source_patterns}.\n" + f" Returned sources: {sources}" + ) + + with EvalChecklist("search user docs") as checks: + checks.check( + "called search_user_docs", + len(search_calls) >= 1, + hint=f"tools called: {[e.get('tool_name') for e in history if e.get('tool_name')]}", + ) + checks.check( + f"answer mentions one of {expected_keywords}", + keyword_match, + hint=f"answer (first 300 chars): {result.output[:300]}", + ) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py index 1d5bbec55f..62e5e77d54 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py @@ -1,15 +1,21 @@ from unittest.mock import MagicMock, patch -from django.core.cache import cache +from django.test.utils import override_settings import pytest from asgiref.sync import async_to_sync -from udspy import OutputStreamChunk, Prediction +from pydantic_ai.messages import PartStartEvent +from pydantic_ai.messages import TextPart as PaiTextPart -from baserow_enterprise.assistant.assistant import Assistant, AssistantCallbacks -from baserow_enterprise.assistant.exceptions import AssistantMessageCancelled +from baserow_enterprise.assistant.assistant import ( + Assistant, + compact_message_history, + get_model_string, +) +from baserow_enterprise.assistant.deps import AssistantDeps from baserow_enterprise.assistant.models import AssistantChat, AssistantChatMessage from baserow_enterprise.assistant.types import ( + AiMessage, AiMessageChunk, AiStartedMessage, AiThinkingMessage, @@ -23,133 +29,113 @@ WorkspaceUIContext, ) +TEST_MODEL = "groq:test-model" + @pytest.fixture(autouse=True) -def mock_posthog_openai(): - with patch("posthog.ai.openai.AsyncOpenAI") as mock: - # Configure the mock if needed +def mock_posthog(): + with patch("baserow_enterprise.assistant.telemetry.get_posthog_client") as mock: mock.return_value = MagicMock() - mock.return_value.model = "test-model" yield mock -@pytest.mark.django_db -class TestAssistantCallbacks: - """Test the AssistantCallbacks class for handling tool execution""" - - def test_extend_sources_deduplicates(self): - """Test that sources are deduplicated when extended""" - - callbacks = AssistantCallbacks() - - # Add initial sources - callbacks.extend_sources( - ["https://example.com/doc1", "https://example.com/doc2"] - ) - assert callbacks.sources == [ - "https://example.com/doc1", - "https://example.com/doc2", - ] +@pytest.fixture(autouse=True) +def _set_test_model(settings): + settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL = "groq/test-model" - # Add sources with duplicates - callbacks.extend_sources( - ["https://example.com/doc2", "https://example.com/doc3"] - ) - # Should only add the new source, not the duplicate - assert callbacks.sources == [ - "https://example.com/doc1", - "https://example.com/doc2", - "https://example.com/doc3", - ] +# --------------------------------------------------------------------------- +# Mock helpers for pydantic-ai's run_stream_events async generator +# --------------------------------------------------------------------------- - def test_extend_sources_preserves_order(self): - """Test that source order is preserved (first occurrence wins)""" - callbacks = AssistantCallbacks() +async def _mock_run_stream_events(answer: str, messages_json: bytes = b"[]"): + """ + Async generator that mimics ``main_agent.run_stream_events()`` + yielding PartStartEvent, then AgentRunResultEvent. + """ + from pydantic_ai.run import AgentRunResultEvent - callbacks.extend_sources(["https://example.com/a"]) - callbacks.extend_sources(["https://example.com/b"]) - callbacks.extend_sources(["https://example.com/a"]) # Duplicate + # Emit a text part start with the full answer + yield PartStartEvent(index=0, part=PaiTextPart(content=answer)) - # 'a' should remain first - assert callbacks.sources == ["https://example.com/a", "https://example.com/b"] + # Emit the final result event + mock_result = MagicMock() + mock_result.output = answer + mock_result.all_messages_json.return_value = messages_json + yield AgentRunResultEvent(result=mock_result) - def test_on_tool_end_extracts_sources_from_outputs(self): - """Test that sources are extracted from tool outputs""" - callbacks = AssistantCallbacks() +def make_mock_run_stream_events_side_effect(answer: str, messages_json: bytes = b"[]"): + """Return a side_effect callable that returns the mock async generator.""" - # Mock tool instance and inputs - tool_instance = MagicMock() - tool_instance.name = "search_user_docs" - inputs = {"query": "test"} + def side_effect(*args, **kwargs): + return _mock_run_stream_events(answer, messages_json) - # Register tool call - callbacks.tool_calls["call_123"] = (tool_instance, inputs) + return side_effect - # Mock registry - with patch( - "baserow_enterprise.assistant.assistant.assistant_tool_registry" - ) as mock_registry: - mock_tool = MagicMock() - mock_registry.get.return_value = mock_tool - # Tool returns outputs with sources - outputs = { - "result": "Some documentation", - "sources": ["https://baserow.io/docs/api"], - } +# --------------------------------------------------------------------------- +# Unit tests +# --------------------------------------------------------------------------- - callbacks.on_tool_end("call_123", outputs) - # Sources should be extracted - assert callbacks.sources == ["https://baserow.io/docs/api"] +@pytest.mark.django_db +class TestAssistantDeps: + """Test the AssistantDeps class for source tracking.""" - def test_on_tool_end_handles_missing_sources(self): - """Test that tool outputs without sources don't cause errors""" + def test_extend_sources_deduplicates(self): + deps = AssistantDeps( + user=MagicMock(), + workspace=MagicMock(), + tool_helpers=MagicMock(), + ) - callbacks = AssistantCallbacks() + deps.extend_sources(["https://example.com/doc1", "https://example.com/doc2"]) + assert deps.sources == [ + "https://example.com/doc1", + "https://example.com/doc2", + ] - tool_instance = MagicMock() - tool_instance.name = "some_tool" - callbacks.tool_calls["call_123"] = (tool_instance, {}) + deps.extend_sources(["https://example.com/doc2", "https://example.com/doc3"]) - with patch( - "baserow_enterprise.assistant.assistant.assistant_tool_registry" - ) as mock_registry: - mock_tool = MagicMock() - mock_registry.get.return_value = mock_tool + assert deps.sources == [ + "https://example.com/doc1", + "https://example.com/doc2", + "https://example.com/doc3", + ] - # Tool returns outputs without sources - outputs = {"result": "Some result"} + def test_extend_sources_preserves_order(self): + deps = AssistantDeps( + user=MagicMock(), + workspace=MagicMock(), + tool_helpers=MagicMock(), + ) - callbacks.on_tool_end("call_123", outputs) + deps.extend_sources(["https://example.com/a"]) + deps.extend_sources(["https://example.com/b"]) + deps.extend_sources(["https://example.com/a"]) - # Should not raise, sources should remain empty - assert callbacks.sources == [] + assert deps.sources == ["https://example.com/a", "https://example.com/b"] @pytest.mark.django_db class TestAssistantChatHistory: - """Test chat history loading and formatting""" + """Test chat history loading and formatting.""" def test_list_chat_messages_returns_in_chronological_order( self, enterprise_data_fixture ): - """Test that list_chat_messages returns messages oldest to newest""" - user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) chat = AssistantChat.objects.create( user=user, workspace=workspace, title="Test Chat" ) - # Create messages in order - msg1 = AssistantChatMessage.objects.create( + AssistantChatMessage.objects.create( chat=chat, role=AssistantChatMessage.Role.HUMAN, content="First question" ) - msg2 = AssistantChatMessage.objects.create( + AssistantChatMessage.objects.create( chat=chat, role=AssistantChatMessage.Role.AI, content="First answer" ) msg3 = AssistantChatMessage.objects.create( @@ -159,77 +145,38 @@ def test_list_chat_messages_returns_in_chronological_order( assistant = Assistant(chat) messages = assistant.list_chat_messages() - # Should be in chronological order (oldest first) assert len(messages) == 3 assert messages[0].content == "First question" assert messages[1].content == "First answer" assert messages[2].content == "Second question" - # It's possible to skip messages using last_message_id messages = assistant.list_chat_messages(last_message_id=msg3.id, limit=1) assert len(messages) == 1 assert messages[0].content == "First answer" - def test_aload_chat_history_formats_as_question_answer_pairs( - self, enterprise_data_fixture - ): - """Test that chat history is loaded as user/assistant message pairs for UDSPy""" - + def test_load_message_history_returns_none_for_empty(self, enterprise_data_fixture): user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) chat = AssistantChat.objects.create( user=user, workspace=workspace, title="Test Chat" ) - # Create conversation history - AssistantChatMessage.objects.create( - chat=chat, role=AssistantChatMessage.Role.HUMAN, content="What is Baserow?" - ) - AssistantChatMessage.objects.create( - chat=chat, - role=AssistantChatMessage.Role.AI, - content="Baserow is a no-code database platform.", - ) - AssistantChatMessage.objects.create( - chat=chat, - role=AssistantChatMessage.Role.HUMAN, - content="How do I create a table?", - ) - AssistantChatMessage.objects.create( - chat=chat, - role=AssistantChatMessage.Role.AI, - content="You can create a table by clicking the + button.", - ) - assistant = Assistant(chat) - assistant.history = async_to_sync(assistant.afetch_chat_history)() - - # History should contain user/assistant message pairs - assert assistant.history is not None - assert len(assistant.history.messages) == 4 - - # First pair - assert assistant.history.messages[0] == { - "role": "user", - "content": "What is Baserow?", - } - assert assistant.history.messages[1] == { - "role": "assistant", - "content": "Baserow is a no-code database platform.", - } - - # Second pair - assert assistant.history.messages[2] == { - "role": "user", - "content": "How do I create a table?", - } - assert assistant.history.messages[3] == { - "role": "assistant", - "content": "You can create a table by clicking the + button.", - } - - def test_aload_chat_history_respects_limit(self, enterprise_data_fixture): - """Test that history loading respects the limit parameter""" + history = async_to_sync(assistant._load_message_history)() + assert history is None + + def test_load_message_history_deserializes_and_compacts( + self, enterprise_data_fixture + ): + from pydantic_ai.messages import ( + ModelMessagesTypeAdapter, + ModelRequest, + ModelResponse, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, + ) user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) @@ -237,178 +184,147 @@ def test_aload_chat_history_respects_limit(self, enterprise_data_fixture): user=user, workspace=workspace, title="Test Chat" ) - # Create 10 message pairs (20 messages) - for i in range(10): - AssistantChatMessage.objects.create( - chat=chat, - role=AssistantChatMessage.Role.HUMAN, - content=f"Question {i}", - ) - AssistantChatMessage.objects.create( - chat=chat, role=AssistantChatMessage.Role.AI, content=f"Answer {i}" - ) + messages = [ + ModelRequest(parts=[UserPromptPart(content="create a database")]), + ModelResponse( + parts=[ + ToolCallPart( + tool_name="create_tables", + args={"thought": "creating", "tables": ["recipes"]}, + tool_call_id="tc1", + ) + ] + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name="create_tables", + content="Created", + tool_call_id="tc1", + ) + ] + ), + ModelResponse(parts=[TextPart(content="Done!")]), + ] + chat.message_history = ModelMessagesTypeAdapter.dump_json(messages) + chat.save(update_fields=["message_history"]) assistant = Assistant(chat) - assistant.history = async_to_sync(assistant.afetch_chat_history)( - limit=6 - ) # Last 6 messages - - # Should only load the most recent 6 messages (3 pairs) - assert len(assistant.history.messages) == 6 + history = async_to_sync(assistant._load_message_history)() - def test_aload_chat_history_handles_incomplete_pairs(self, enterprise_data_fixture): - """ - Test that incomplete message pairs (e.g., orphaned human messages) are skipped - """ + assert history is not None + assert len(history) == 2 + assert isinstance(history[0], ModelRequest) + assert isinstance(history[1], ModelResponse) + def test_load_message_history_handles_corrupt_data(self, enterprise_data_fixture): user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) chat = AssistantChat.objects.create( user=user, workspace=workspace, title="Test Chat" ) - # Create complete pair - AssistantChatMessage.objects.create( - chat=chat, role=AssistantChatMessage.Role.HUMAN, content="Question 1" - ) - AssistantChatMessage.objects.create( - chat=chat, role=AssistantChatMessage.Role.AI, content="Answer 1" - ) - - # Create orphaned human message (no AI response yet) - AssistantChatMessage.objects.create( - chat=chat, role=AssistantChatMessage.Role.HUMAN, content="Question 2" - ) + chat.message_history = b"not valid json" + chat.save(update_fields=["message_history"]) assistant = Assistant(chat) - assistant.history = async_to_sync(assistant.afetch_chat_history)() - - # Should only include the complete pair (2 messages: user + assistant) - assert len(assistant.history.messages) == 2 - assert assistant.history.messages[0] == { - "role": "user", - "content": "Question 1", - } - assert assistant.history.messages[1] == { - "role": "assistant", - "content": "Answer 1", - } - - @patch("udspy.ReAct.astream") - def test_history_is_passed_to_astream_as_context( - self, mock_react_astream, enterprise_data_fixture - ): - """ - Test that chat history is loaded correctly and passed to the agent as context - """ + history = async_to_sync(assistant._load_message_history)() + assert history is None + + +class TestCompactMessageHistory: + """Test the message history compaction logic.""" + + def test_compacts_tool_calls_in_older_turns(self): + from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + TextPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, + ) + + messages = [ + ModelRequest(parts=[UserPromptPart(content="create a database")]), + ModelResponse( + parts=[ + ToolCallPart( + tool_name="create_tables", + args={"thought": "creating"}, + tool_call_id="tc1", + ) + ] + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name="create_tables", + content="Created", + tool_call_id="tc1", + ) + ] + ), + ModelResponse(parts=[TextPart(content="Done!")]), + ModelRequest(parts=[UserPromptPart(content="add a field")]), + ModelResponse(parts=[TextPart(content="Added!")]), + ] - user = enterprise_data_fixture.create_user() - workspace = enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, workspace=workspace, title="Test Chat" - ) + compacted = compact_message_history(messages) + assert len(compacted) == 4 - # Create conversation history (2 complete pairs) - AssistantChatMessage.objects.create( - chat=chat, role=AssistantChatMessage.Role.HUMAN, content="What is Baserow?" - ) - AssistantChatMessage.objects.create( - chat=chat, - role=AssistantChatMessage.Role.AI, - content="Baserow is a no-code database", - ) - AssistantChatMessage.objects.create( - chat=chat, - role=AssistantChatMessage.Role.HUMAN, - content="How do I create a table?", - ) - AssistantChatMessage.objects.create( - chat=chat, - role=AssistantChatMessage.Role.AI, - content="Click the Create Table button", + def test_trims_to_max_messages(self): + from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + TextPart, + UserPromptPart, ) - assistant = Assistant(chat) + messages = [] + for i in range(20): + messages.append( + ModelRequest(parts=[UserPromptPart(content=f"Question {i}")]) + ) + messages.append(ModelResponse(parts=[TextPart(content=f"Answer {i}")])) - # Mock the agent stream to verify conversation history is passed - def mock_agent_stream_factory(*args, **kwargs): - # Verify conversation history is passed to the agent - assert kwargs["conversation_history"] == [ - "[0] (user): What is Baserow?", - "[1] (assistant): Baserow is a no-code database", - "[2] (user): How do I create a table?", - "[3] (assistant): Click the Create Table button", - ] - - async def _stream(): - yield OutputStreamChunk( - module=assistant._assistant.extract_module, - field_name="answer", - delta="Answer", - content="Answer", - is_complete=False, - ) - yield Prediction( - module=assistant._assistant, - answer="Answer", - trajectory=[], - reasoning="", - ) - - return _stream() - - mock_react_astream.side_effect = mock_agent_stream_factory - - message = HumanMessage(content="How to add a view?") - - # Consume the stream to trigger assertions - async def consume_stream(): - async for _ in assistant.astream_messages(message): - pass + compacted = compact_message_history(messages, max_messages=6) + assert len(compacted) == 6 - async_to_sync(consume_stream)() + def test_preserves_simple_conversations(self): + from pydantic_ai.messages import ( + ModelRequest, + ModelResponse, + TextPart, + UserPromptPart, + ) + + messages = [ + ModelRequest(parts=[UserPromptPart(content="hello")]), + ModelResponse(parts=[TextPart(content="hi")]), + ] + + compacted = compact_message_history(messages) + assert len(compacted) == 2 @pytest.mark.django_db class TestAssistantMessagePersistence: - """Test that messages are persisted correctly during streaming""" + """Test that messages are persisted correctly during streaming.""" - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") + @patch("baserow_enterprise.assistant.agents.main_agent.run_stream_events") def test_astream_messages_persists_human_message( - self, mock_react_astream, mock_cot_astream, enterprise_data_fixture + self, mock_run_stream_events, enterprise_data_fixture ): - """Test that human messages are persisted to database before streaming""" - user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) chat = AssistantChat.objects.create( user=user, workspace=workspace, title="Test Chat" ) - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) - - mock_cot_astream.return_value = mock_router_stream() - - # Mock the agent streaming - async def mock_agent_stream(*args, **kwargs): - # Yield a simple response - yield OutputStreamChunk( - module=None, - field_name="answer", - delta="Hello", - content="Hello", - is_complete=False, - ) - yield Prediction(answer="Hello", trajectory=[], reasoning="") - - mock_react_astream.return_value = mock_agent_stream() + mock_run_stream_events.side_effect = make_mock_run_stream_events_side_effect( + "Hello" + ) assistant = Assistant(chat) ui_context = UIContext( @@ -416,7 +332,6 @@ async def mock_agent_stream(*args, **kwargs): user=UserUIContext(id=user.id, name=user.first_name, email=user.email), ) - # Consume the stream async def consume_stream(): human_message = HumanMessage(content="Test message", ui_context=ui_context) async for _ in assistant.astream_messages(human_message): @@ -424,7 +339,6 @@ async def consume_stream(): async_to_sync(consume_stream)() - # Human message should be persisted human_messages = AssistantChatMessage.objects.filter( chat=chat, role=AssistantChatMessage.Role.HUMAN ).count() @@ -435,129 +349,64 @@ async def consume_stream(): ).first() assert saved_message.content == "Test message" - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") - def test_astream_messages_persists_ai_message_with_sources( - self, mock_react_astream, mock_cot_astream, enterprise_data_fixture + @patch("baserow_enterprise.assistant.agents.main_agent.run_stream_events") + def test_astream_messages_persists_ai_message( + self, mock_run_stream_events, enterprise_data_fixture ): - """Test that AI messages are persisted with sources in artifacts""" - user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) chat = AssistantChat.objects.create( user=user, workspace=workspace, title="Test Chat" ) - assistant = Assistant(chat) - - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) - - mock_cot_astream.return_value = mock_router_stream() - - # Mock the agent streaming with a Prediction at the end - async def mock_agent_stream(*args, **kwargs): - yield OutputStreamChunk( - module=None, - field_name="answer", - delta="Based on docs", - content="Based on docs", - is_complete=False, - ) - yield Prediction( - module=assistant._assistant, - answer="Based on docs", - trajectory=[], - reasoning="", - ) + mock_run_stream_events.side_effect = make_mock_run_stream_events_side_effect( + "Based on docs" + ) - mock_react_astream.return_value = mock_agent_stream() + assistant = Assistant(chat) ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), user=UserUIContext(id=user.id, name=user.first_name, email=user.email), ) - # Manually add sources to callback manager (simulating tool execution) async def consume_stream(): - messages = [] human_message = HumanMessage(content="Question", ui_context=ui_context) - async for msg in assistant.astream_messages(human_message): - messages.append(msg) - return messages + async for _ in assistant.astream_messages(human_message): + pass async_to_sync(consume_stream)() - # AI message should be persisted ai_messages = AssistantChatMessage.objects.filter( chat=chat, role=AssistantChatMessage.Role.AI ).count() assert ai_messages == 1 - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") - @patch("udspy.Predict") + @patch("baserow_enterprise.assistant.agents.title_agent.run") + @patch("baserow_enterprise.assistant.agents.main_agent.run_stream_events") def test_astream_messages_persists_chat_title( self, - mock_predict_class, - mock_react_astream, - mock_cot_astream, + mock_run_stream_events, + mock_title_run, enterprise_data_fixture, ): - """Test that chat titles are persisted to the database""" - user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, - workspace=workspace, - title="", # New chat - ) + chat = AssistantChat.objects.create(user=user, workspace=workspace, title="") - # Mock title generator - async def mock_title_aforward(*args, **kwargs): - return Prediction(chat_title="Greeting") + mock_run_stream_events.side_effect = make_mock_run_stream_events_side_effect( + "Hello" + ) - mock_title_generator = MagicMock() - mock_title_generator.aforward = mock_title_aforward - mock_predict_class.return_value = mock_title_generator + mock_title_result = MagicMock() + mock_title_result.output = "Greeting" + mock_title_run.return_value = mock_title_result assistant = Assistant(chat) - - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) - - mock_cot_astream.return_value = mock_router_stream() - - # Mock agent streaming - async def mock_agent_stream(*args, **kwargs): - yield OutputStreamChunk( - module=None, - field_name="answer", - delta="Hello", - content="Hello", - is_complete=False, - ) - yield Prediction( - module=assistant._assistant, answer="Hello", trajectory=[], reasoning="" - ) - - mock_react_astream.return_value = mock_agent_stream() ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), user=UserUIContext(id=user.id, name=user.first_name, email=user.email), ) - # Consume the stream async def consume_stream(): human_message = HumanMessage(content="Hello", ui_context=ui_context) async for _ in assistant.astream_messages(human_message): @@ -565,187 +414,112 @@ async def consume_stream(): async_to_sync(consume_stream)() - # Refresh from DB chat.refresh_from_db() - - # Title should be persisted assert chat.title == "Greeting" @pytest.mark.django_db class TestAssistantStreaming: - """Test streaming behavior of the Assistant""" + """Test streaming behavior of the Assistant.""" - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") + @patch("baserow_enterprise.assistant.agents.main_agent.run_stream_events") def test_astream_messages_yields_answer_chunks( - self, mock_react_astream, mock_cot_astream, enterprise_data_fixture + self, mock_run_stream_events, enterprise_data_fixture ): - """Test that answer chunks are yielded during streaming""" - user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) chat = AssistantChat.objects.create( user=user, workspace=workspace, title="Test Chat" ) - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) - - mock_cot_astream.return_value = mock_router_stream() + mock_run_stream_events.side_effect = make_mock_run_stream_events_side_effect( + "Hello world" + ) assistant = Assistant(chat) - # Mock agent streaming - async def mock_agent_stream(*args, **kwargs): - yield OutputStreamChunk( - module=assistant._assistant.extract_module, - field_name="answer", - delta="Hello", - content="Hello", - is_complete=False, - ) - yield OutputStreamChunk( - module=assistant._assistant.extract_module, - field_name="answer", - delta=" world", - content="Hello world", - is_complete=False, - ) - yield Prediction(answer="Hello world", trajectory=[], reasoning="") - - mock_react_astream.return_value = mock_agent_stream() - async def consume_stream(): - chunks = [] + messages = [] human_message = HumanMessage(content="Test") async for msg in assistant.astream_messages(human_message): - if isinstance(msg, AiMessageChunk): - chunks.append(msg) - return chunks + messages.append(msg) + return messages - chunks = async_to_sync(consume_stream)() + messages = async_to_sync(consume_stream)() - # Should receive chunks with accumulated content - assert len(chunks) == 2 - assert chunks[0].content == "Hello" - assert chunks[1].content == "Hello world" + # Filter for final AiMessage + ai_messages = [m for m in messages if isinstance(m, AiMessage)] + assert len(ai_messages) == 1 + assert ai_messages[0].content == "Hello world" + assert ai_messages[0].id is not None - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") - @patch("udspy.Predict") - def test_astream_messages_yields_title_chunks( - self, - mock_predict_class, - mock_react_astream, - mock_cot_astream, - enterprise_data_fixture, - ): - """Test that title chunks are yielded for new chats""" + # Should also have AiMessageChunk(s) + chunks = [ + m + for m in messages + if isinstance(m, AiMessageChunk) and not isinstance(m, AiMessage) + ] + assert len(chunks) >= 1 + @patch("baserow_enterprise.assistant.agents.title_agent.run") + @patch("baserow_enterprise.assistant.agents.main_agent.run_stream_events") + def test_astream_messages_yields_title_for_new_chat( + self, mock_run_stream_events, mock_title_run, enterprise_data_fixture + ): user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, - workspace=workspace, - title="", # New chat - ) + chat = AssistantChat.objects.create(user=user, workspace=workspace, title="") - # Mock title generator - async def mock_title_aforward(*args, **kwargs): - return Prediction(chat_title="Title") + mock_run_stream_events.side_effect = make_mock_run_stream_events_side_effect( + "Answer" + ) - mock_title_generator = MagicMock() - mock_title_generator.aforward = mock_title_aforward - mock_predict_class.return_value = mock_title_generator + mock_title_result = MagicMock() + mock_title_result.output = "Title" + mock_title_run.return_value = mock_title_result assistant = Assistant(chat) - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) - - mock_cot_astream.return_value = mock_router_stream() - - # Mock agent streaming - async def mock_agent_stream(*args, **kwargs): - yield OutputStreamChunk( - module=None, - field_name="answer", - delta="Answer", - content="Answer", - is_complete=False, - ) - yield Prediction( - module=assistant._assistant, - answer="Answer", - trajectory=[], - reasoning="", - ) - - mock_react_astream.return_value = mock_agent_stream() - async def consume_stream(): - title_messages = [] + msgs = [] human_message = HumanMessage(content="Test") async for msg in assistant.astream_messages(human_message): - if isinstance(msg, ChatTitleMessage): - title_messages.append(msg) - return title_messages + msgs.append(msg) + return msgs - title_messages = async_to_sync(consume_stream)() + messages = async_to_sync(consume_stream)() - # Should receive title chunks + title_messages = [m for m in messages if isinstance(m, ChatTitleMessage)] assert len(title_messages) == 1 assert title_messages[0].content == "Title" - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") + @patch("baserow_enterprise.assistant.agents.main_agent.run_stream_events") def test_astream_messages_yields_thinking_messages( - self, mock_react_astream, mock_cot_astream, enterprise_data_fixture + self, mock_run_stream_events, enterprise_data_fixture ): - """Test that thinking messages from tools are yielded""" - user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) chat = AssistantChat.objects.create( user=user, workspace=workspace, title="Test Chat" ) - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) + assistant = Assistant(chat) - mock_cot_astream.return_value = mock_router_stream() + async def mock_stream_with_thinking(*args, **kwargs): + from pydantic_ai.run import AgentRunResultEvent - assistant = Assistant(chat) + # Emit thinking message via the event bus during streaming + assistant._event_bus.emit(AiThinkingMessage(content="still thinking...")) - # Mock the agent streaming - async def mock_agent_stream(*args, **kwargs): - yield AiThinkingMessage(content="still thinking...") - yield OutputStreamChunk( - module=assistant._assistant.extract_module, - field_name="answer", - delta="Answer", - content="Answer", - is_complete=False, - ) - yield Prediction(answer="Answer", trajectory=[], reasoning="") + # Yield text part then result + yield PartStartEvent(index=0, part=PaiTextPart(content="Answer")) + + mock_result = MagicMock() + mock_result.output = "Answer" + mock_result.all_messages_json.return_value = b"[]" + yield AgentRunResultEvent(result=mock_result) - mock_react_astream.return_value = mock_agent_stream() + mock_run_stream_events.side_effect = mock_stream_with_thinking ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), @@ -753,38 +527,60 @@ async def mock_agent_stream(*args, **kwargs): ) async def consume_stream(): - thinking_messages = [] + thinking = [] human_message = HumanMessage(content="Test", ui_context=ui_context) async for msg in assistant.astream_messages(human_message): if isinstance(msg, AiThinkingMessage): - thinking_messages.append(msg) - return thinking_messages + thinking.append(msg) + return thinking thinking_messages = async_to_sync(consume_stream)() - # Should receive the thinking message emitted by the agent stream assert len(thinking_messages) == 1 assert thinking_messages[0].content == "still thinking..." + @patch("baserow_enterprise.assistant.agents.main_agent.run_stream_events") + def test_astream_messages_yields_ai_started_message( + self, mock_run_stream_events, enterprise_data_fixture + ): + user = enterprise_data_fixture.create_user() + workspace = enterprise_data_fixture.create_workspace(user=user) + chat = AssistantChat.objects.create( + user=user, workspace=workspace, title="Test" + ) + + mock_run_stream_events.side_effect = make_mock_run_stream_events_side_effect( + "Hello" + ) + + assistant = Assistant(chat) + human_message = HumanMessage(content="Hello") + + async def collect_messages(): + messages = [] + async for msg in assistant.astream_messages(human_message): + messages.append(msg) + return messages + + messages = async_to_sync(collect_messages)() + + assert len(messages) > 0 + assert isinstance(messages[0], AiStartedMessage) + assert messages[0].message_id is not None + @pytest.mark.django_db class TestUIContext: - """Test UI context handling and validation""" + """Test UI context handling and validation.""" def test_ui_context_from_validate_request_adds_user_info( self, enterprise_data_fixture ): - """ - Test that UIContext.from_validate_request adds user information - from request - """ - user = enterprise_data_fixture.create_user( email="test@example.com", first_name="Test User" ) workspace = enterprise_data_fixture.create_workspace(user=user) - # Mock request object class MockRequest: pass @@ -792,7 +588,6 @@ class MockRequest: request.user = user ui_context_data = {"workspace": {"id": workspace.id, "name": workspace.name}} - ui_context = UIContext.from_validate_request(request, ui_context_data) assert ui_context.workspace.id == workspace.id @@ -802,8 +597,6 @@ class MockRequest: assert ui_context.user.name == "Test User" def test_ui_context_with_database_builder_fields(self): - """Test that UIContext correctly stores database builder fields""" - ui_context = UIContext( workspace=WorkspaceUIContext(id=1, name="Test Workspace"), database=ApplicationUIContext(id="db-123", name="My Database"), @@ -814,123 +607,23 @@ def test_ui_context_with_database_builder_fields(self): assert ui_context.workspace.id == 1 assert ui_context.database.id == "db-123" - assert ui_context.database.name == "My Database" assert ui_context.table.id == 456 - assert ui_context.table.name == "Customers" assert ui_context.view.id == 789 - assert ui_context.view.name == "All Customers" - assert ui_context.view.type == "grid" - - def test_ui_context_with_application_builder_fields(self): - """Test that UIContext correctly stores application builder fields""" - - ui_context = UIContext( - workspace=WorkspaceUIContext(id=1, name="Test Workspace"), - application=ApplicationUIContext(id="app-123", name="My App"), - user=UserUIContext(id=1, name="Test", email="test@test.com"), - ) - - assert ui_context.application.id == "app-123" - assert ui_context.application.name == "My App" - assert ui_context.database is None - assert ui_context.table is None def test_ui_context_serialization_excludes_none_values(self): - """Test that UIContext serialization excludes None values""" - ui_context = UIContext( workspace=WorkspaceUIContext(id=1, name="Test Workspace"), user=UserUIContext(id=1, name="Test", email="test@test.com"), - # All other fields are None ) - # Serialize with exclude_none=True serialized = ui_context.model_dump(exclude_none=True) - assert "workspace" in serialized assert "user" in serialized assert "database" not in serialized assert "table" not in serialized - assert "view" not in serialized - assert "application" not in serialized - - def test_ui_context_json_serialization_excludes_none(self): - """Test that UIContext JSON serialization excludes None values""" - - ui_context = UIContext( - workspace=WorkspaceUIContext(id=1, name="Test Workspace"), - table=TableUIContext(id=456, name="Customers"), - user=UserUIContext(id=1, name="Test", email="test@test.com"), - # database and view are None - ) - - # Serialize to JSON with exclude_none=True - json_str = ui_context.model_dump_json(exclude_none=True) - - # Parse back to verify - import json - - parsed = json.loads(json_str) - - assert "workspace" in parsed - assert "table" in parsed - assert "user" in parsed - assert "database" not in parsed - assert "view" not in parsed - - def test_human_message_with_ui_context(self): - """Test that HumanMessage correctly stores ui_context""" - - ui_context = UIContext( - workspace=WorkspaceUIContext(id=1, name="Test Workspace"), - database=ApplicationUIContext(id="db-123", name="My Database"), - user=UserUIContext(id=1, name="Test", email="test@test.com"), - ) - - human_message = HumanMessage( - content="How do I create a field?", ui_context=ui_context - ) - - assert human_message.content == "How do I create a field?" - assert human_message.ui_context.workspace.id == 1 - assert human_message.ui_context.database.id == "db-123" - assert human_message.ui_context.database.name == "My Database" - - def test_human_message_ui_context_json_serialization(self): - """ - Test that HumanMessage ui_context serializes to JSON with None - values excluded - """ - - ui_context = UIContext( - workspace=WorkspaceUIContext(id=1, name="Test Workspace"), - database=ApplicationUIContext(id="db-123", name="My Database"), - table=TableUIContext(id=456, name="Customers"), - user=UserUIContext(id=1, name="Test", email="test@test.com"), - # view is None - ) - - human_message = HumanMessage( - content="How do I filter this view?", ui_context=ui_context - ) - - # Serialize ui_context as it would be in the prompt - ui_context_json = human_message.ui_context.model_dump_json(exclude_none=True) - - # Parse to verify - import json - - parsed = json.loads(ui_context_json) - - # Should have database and table but not view - assert "database" in parsed - assert parsed["database"]["name"] == "My Database" - assert "table" in parsed - assert parsed["table"]["name"] == "Customers" - assert "view" not in parsed # None values excluded def test_ui_context_has_default_timestamp(self): - """Test that UIContext has a default timestamp""" + from datetime import datetime ui_context = UIContext( workspace=WorkspaceUIContext(id=1, name="Test"), @@ -938,14 +631,9 @@ def test_ui_context_has_default_timestamp(self): ) assert ui_context.timestamp is not None - # Should be a datetime object - from datetime import datetime - assert isinstance(ui_context.timestamp, datetime) def test_ui_context_has_default_timezone(self): - """Test that UIContext has a default timezone of UTC""" - ui_context = UIContext( workspace=WorkspaceUIContext(id=1, name="Test"), user=UserUIContext(id=1, name="Test", email="test@test.com"), @@ -954,8 +642,6 @@ def test_ui_context_has_default_timezone(self): assert ui_context.timezone == "UTC" def test_user_ui_context_from_user(self, enterprise_data_fixture): - """Test UserUIContext.from_user factory method""" - user = enterprise_data_fixture.create_user( email="john@example.com", first_name="John Doe" ) @@ -966,173 +652,55 @@ def test_user_ui_context_from_user(self, enterprise_data_fixture): assert user_context.name == "John Doe" assert user_context.email == "john@example.com" - -@pytest.mark.django_db -class TestAssistantCancellation: - """Test cancellation functionality in Assistant""" - - def test_get_cancellation_cache_key(self, enterprise_data_fixture): - """Test that cancellation cache key is correctly formatted""" - - user = enterprise_data_fixture.create_user() - workspace = enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, workspace=workspace, title="Test" - ) - - assistant = Assistant(chat) - cache_key = assistant._get_cancellation_cache_key() - - assert cache_key == f"assistant:chat:{chat.uuid}:cancelled" - - def test_check_cancellation_raises_when_flag_set(self, enterprise_data_fixture): - """Test that check_cancellation raises exception when flag is set""" - - user = enterprise_data_fixture.create_user() - workspace = enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, workspace=workspace, title="Test" + def test_human_message_with_ui_context(self): + ui_context = UIContext( + workspace=WorkspaceUIContext(id=1, name="Test Workspace"), + database=ApplicationUIContext(id="db-123", name="My Database"), + user=UserUIContext(id=1, name="Test", email="test@test.com"), ) - assistant = Assistant(chat) - cache_key = assistant._get_cancellation_cache_key() - - # Set cancellation flag - cache.set(cache_key, True) - - # Should raise exception - with pytest.raises(AssistantMessageCancelled) as exc_info: - assistant._check_cancellation(cache_key, "msg123") - - assert exc_info.value.message_id == "msg123" - - # Flag should be cleaned up - assert cache.get(cache_key) is None - - def test_check_cancellation_does_nothing_when_no_flag( - self, enterprise_data_fixture - ): - """Test that check_cancellation does nothing when flag not set""" - - user = enterprise_data_fixture.create_user() - workspace = enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, workspace=workspace, title="Test" + human_message = HumanMessage( + content="How do I create a field?", ui_context=ui_context ) - assistant = Assistant(chat) - cache_key = assistant._get_cancellation_cache_key() + assert human_message.content == "How do I create a field?" + assert human_message.ui_context.workspace.id == 1 + assert human_message.ui_context.database.id == "db-123" - # Should not raise - assistant._check_cancellation(cache_key, "msg123") - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") - def test_astream_messages_yields_ai_started_message( - self, mock_react_astream, mock_cot_astream, enterprise_data_fixture - ): - """Test that astream_messages yields AiStartedMessage at the beginning""" +@pytest.mark.django_db +class TestAssistantCancellation: + """Test cancellation functionality in Assistant.""" + def test_get_cancellation_cache_key(self, enterprise_data_fixture): user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) chat = AssistantChat.objects.create( user=user, workspace=workspace, title="Test" ) - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) - - mock_cot_astream.return_value = mock_router_stream() - - # Mock the agent streaming - async def mock_agent_stream(*args, **kwargs): - yield OutputStreamChunk( - module=None, - field_name="answer", - delta="Hello", - content="Hello", - is_complete=False, - ) - yield Prediction(answer="Hello there!", trajectory=[], reasoning="") - - mock_react_astream.return_value = mock_agent_stream() - - assistant = Assistant(chat) - human_message = HumanMessage(content="Hello") - - # Collect messages - async def collect_messages(): - messages = [] - async for msg in assistant.astream_messages(human_message): - messages.append(msg) - return messages - - messages = async_to_sync(collect_messages)() - - # First message should be AiStartedMessage - assert len(messages) > 0 - assert isinstance(messages[0], AiStartedMessage) - assert messages[0].message_id is not None - - @patch("udspy.ChainOfThought.astream") - @patch("udspy.ReAct.astream") - def test_astream_messages_checks_cancellation_periodically( - self, mock_react_astream, mock_cot_astream, enterprise_data_fixture - ): - """Test that astream_messages checks for cancellation every 10 chunks""" - - user = enterprise_data_fixture.create_user() - workspace = enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, workspace=workspace, title="Test" + from baserow_enterprise.assistant.assistant import ( + get_assistant_cancellation_key, ) - # Mock the router stream - async def mock_router_stream(*args, **kwargs): - yield Prediction( - routing_decision="delegate_to_agent", - extracted_context="", - search_query="", - ) - - mock_cot_astream.return_value = mock_router_stream() + cache_key = get_assistant_cancellation_key(str(chat.uuid)) + assert cache_key == f"assistant:chat:{chat.uuid}:cancelled" - # Mock the stream to return many chunks - enough to trigger check at 10 - async def mock_agent_stream(*args, **kwargs): - # Yield 15 chunks - cancellation check happens at chunk 10 - for i in range(15): - yield OutputStreamChunk( - module=None, - field_name="answer", - delta=f"word{i}", - content=f"word{i}", - is_complete=False, - ) - yield Prediction(answer="Complete response", trajectory=[], reasoning="") - mock_react_astream.return_value = mock_agent_stream() +class TestGetModelString: + """Test the model string conversion logic.""" - assistant = Assistant(chat) - cache_key = assistant._get_cancellation_cache_key() + @override_settings(BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL="groq/llama-3.3-70b") + def test_replaces_slash_with_colon(self): + assert get_model_string() == "groq:llama-3.3-70b" - # Set cancellation flag immediately - it should be detected at chunk 10 - cache.set(cache_key, True) - - ui_context = UIContext( - workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), - user=UserUIContext(id=user.id, name=user.first_name, email=user.email), - ) - human_message = HumanMessage(content="Hello", ui_context=ui_context) + @override_settings(BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL="openai/gpt-4") + def test_openai_model(self): + assert get_model_string() == "openai:gpt-4" - # Should raise AssistantMessageCancelled when check happens at chunk 10 - async def stream_messages(): - async for msg in assistant.astream_messages(human_message): - pass + @override_settings(BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL="gpt-4o") + def test_bare_model_defaults_to_openai(self): + assert get_model_string() == "openai:gpt-4o" - with pytest.raises(AssistantMessageCancelled): - async_to_sync(stream_messages)() + def test_explicit_model_overrides_setting(self): + assert get_model_string("groq/custom-model") == "groq:custom-model" diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_automation_node_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_automation_node_tools.py new file mode 100644 index 0000000000..c57f35cf1d --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_automation_node_tools.py @@ -0,0 +1,304 @@ +import pytest + +from baserow.contrib.automation.nodes.models import AutomationNode +from baserow.contrib.automation.nodes.service import AutomationNodeService +from baserow_enterprise.assistant.tools.automation.tools import ( + add_nodes, + create_workflows, + delete_nodes, + list_nodes, + update_nodes, +) +from baserow_enterprise.assistant.tools.automation.types import ( + ActionNodeCreate, + NodeUpdate, + TriggerNodeCreate, + WorkflowCreate, +) + +from .utils import make_test_ctx + + +@pytest.fixture(autouse=True) +def mock_formula_generator(monkeypatch): + """Mock update_workflow_formulas and update_single_node_formulas to avoid LM calls.""" + + monkeypatch.setattr( + "baserow_enterprise.assistant.tools.automation.agents.update_workflow_formulas", + lambda workflow, node_mapping, tool_helpers: None, + ) + monkeypatch.setattr( + "baserow_enterprise.assistant.tools.automation.agents.update_single_node_formulas", + lambda node_update, orm_node, tool_helpers: None, + ) + + +def _create_test_workflow(data_fixture, user, workspace): + """Create a workflow with a trigger and an email action node.""" + automation = data_fixture.create_automation_application( + user=user, workspace=workspace + ) + + ctx = make_test_ctx(user, workspace) + result = create_workflows( + ctx, + automation_id=automation.id, + workflows=[ + WorkflowCreate( + name="Test Workflow", + trigger=TriggerNodeCreate( + ref="trigger1", + label="Periodic Trigger", + type="periodic", + periodic_interval={"interval": "DAY"}, + ), + nodes=[ + ActionNodeCreate( + ref="email1", + label="Send Email", + previous_node_ref="trigger1", + type="smtp_email", + to_emails="test@example.com", + subject="Hello", + body="World", + ), + ], + ) + ], + thought="test", + ) + + workflow_id = result["created_workflows"][0]["id"] + return automation, workflow_id + + +@pytest.mark.django_db(transaction=True) +def test_list_nodes(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + automation, workflow_id = _create_test_workflow(data_fixture, user, workspace) + + ctx = make_test_ctx(user, workspace) + result = list_nodes(ctx, workflow_id=workflow_id, thought="inspect") + + nodes = result["nodes"] + assert len(nodes) == 2 + + # First node is the trigger + assert nodes[0]["label"] == "Periodic Trigger" + assert nodes[0]["type"] == "periodic" + + # Second node is the email action + assert nodes[1]["label"] == "Send Email" + assert nodes[1]["type"] == "smtp_email" + + # All nodes have IDs + assert all("id" in n for n in nodes) + + +@pytest.mark.django_db(transaction=True) +def test_add_node_after_existing(data_fixture): + """Add a router node between the trigger and existing email node.""" + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + automation, workflow_id = _create_test_workflow(data_fixture, user, workspace) + + # Get existing nodes + ctx = make_test_ctx(user, workspace) + existing = list_nodes(ctx, workflow_id=workflow_id, thought="check") + trigger_id = existing["nodes"][0]["id"] + email_id = existing["nodes"][1]["id"] + + # Delete the existing email node first (we'll re-add it after the router) + delete_nodes( + ctx, node_ids=[email_id], thought="remove email to re-add after router" + ) + + # Add a router after the trigger, then a new email after the router + result = add_nodes( + ctx, + workflow_id=workflow_id, + nodes=[ + ActionNodeCreate( + ref="router1", + label="My Router", + type="router", + previous_node_ref=str(trigger_id), + edges=[ + {"label": "always", "condition": "true"}, + ], + ), + ActionNodeCreate( + ref="slack1", + label="Send Slack After Router", + type="smtp_email", + previous_node_ref="router1", + router_edge_label="always", + to_emails="test@example.com", + subject="Hello", + body="Routed message", + ), + ], + thought="insert router between trigger and email", + ) + + assert len(result["created_nodes"]) == 2 + assert result["created_nodes"][0]["type"] == "router" + assert result["created_nodes"][0]["label"] == "My Router" + assert result["created_nodes"][1]["label"] == "Send Slack After Router" + + # Verify final workflow order + final = list_nodes(ctx, workflow_id=workflow_id, thought="verify") + assert len(final["nodes"]) == 3 + assert final["nodes"][0]["type"] == "periodic" + assert final["nodes"][1]["type"] == "router" + assert final["nodes"][2]["type"] == "smtp_email" + + +@pytest.mark.django_db(transaction=True) +def test_add_node_append_to_workflow(data_fixture): + """Append a new action node at the end of an existing workflow.""" + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + automation, workflow_id = _create_test_workflow(data_fixture, user, workspace) + + ctx = make_test_ctx(user, workspace) + existing = list_nodes(ctx, workflow_id=workflow_id, thought="check") + email_id = existing["nodes"][1]["id"] + + # Append a new email node after the existing email node + result = add_nodes( + ctx, + workflow_id=workflow_id, + nodes=[ + ActionNodeCreate( + ref="email1", + label="Follow-up Email", + type="smtp_email", + previous_node_ref=str(email_id), + to_emails="followup@example.com", + subject="Follow-up", + body="This is a follow-up.", + ), + ], + thought="append email after email", + ) + + assert len(result["created_nodes"]) == 1 + assert result["created_nodes"][0]["label"] == "Follow-up Email" + + # Verify workflow now has 3 nodes + final = list_nodes(ctx, workflow_id=workflow_id, thought="verify") + assert len(final["nodes"]) == 3 + assert final["nodes"][2]["type"] == "smtp_email" + assert final["nodes"][2]["label"] == "Follow-up Email" + + +@pytest.mark.django_db(transaction=True) +def test_update_node_label(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + automation, workflow_id = _create_test_workflow(data_fixture, user, workspace) + + # Get the action node + from baserow.contrib.automation.workflows.service import AutomationWorkflowService + + workflow = AutomationWorkflowService().get_workflow(user, workflow_id) + nodes = list(workflow.automation_workflow_nodes.all().order_by("id")) + action_node = nodes[-1] # The email action node + + ctx = make_test_ctx(user, workspace) + result = update_nodes( + ctx, + workflow_id=workflow_id, + nodes=[NodeUpdate(node_id=action_node.id, label="Updated Email")], + thought="rename node", + ) + + assert result["updated_nodes"][0]["label"] == "Updated Email" + + # Verify in DB + refreshed = AutomationNodeService().get_node(user, action_node.id) + assert refreshed.label == "Updated Email" + + +@pytest.mark.django_db(transaction=True) +def test_update_node_service_config(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + automation, workflow_id = _create_test_workflow(data_fixture, user, workspace) + + from baserow.contrib.automation.workflows.service import AutomationWorkflowService + + workflow = AutomationWorkflowService().get_workflow(user, workflow_id) + nodes = list(workflow.automation_workflow_nodes.all().order_by("id")) + action_node = nodes[-1] + + ctx = make_test_ctx(user, workspace) + result = update_nodes( + ctx, + workflow_id=workflow_id, + nodes=[ + NodeUpdate( + node_id=action_node.id, + subject="New Subject", + ) + ], + thought="update email subject", + ) + + assert len(result["updated_nodes"]) == 1 + assert "errors" not in result + + +@pytest.mark.django_db(transaction=True) +def test_delete_node(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + automation, workflow_id = _create_test_workflow(data_fixture, user, workspace) + + from baserow.contrib.automation.workflows.service import AutomationWorkflowService + + workflow = AutomationWorkflowService().get_workflow(user, workflow_id) + nodes = list(workflow.automation_workflow_nodes.all().order_by("id")) + action_node = nodes[-1] + + ctx = make_test_ctx(user, workspace) + result = delete_nodes( + ctx, + node_ids=[action_node.id], + thought="delete node", + ) + + assert result["deleted_node_ids"] == [action_node.id] + + # Node should be gone + assert not AutomationNode.objects.filter(id=action_node.id).exists() + + +@pytest.mark.django_db(transaction=True) +def test_delete_node_wrong_workspace(data_fixture): + user = data_fixture.create_user() + workspace1 = data_fixture.create_workspace(user=user) + workspace2 = data_fixture.create_workspace(user=user) + automation, workflow_id = _create_test_workflow(data_fixture, user, workspace1) + + from baserow.contrib.automation.workflows.service import AutomationWorkflowService + + workflow = AutomationWorkflowService().get_workflow(user, workflow_id) + nodes = list(workflow.automation_workflow_nodes.all().order_by("id")) + action_node = nodes[-1] + + # Try to delete from wrong workspace + ctx = make_test_ctx(user, workspace2) + result = delete_nodes( + ctx, + node_ids=[action_node.id], + thought="delete from wrong workspace", + ) + + assert result["deleted_node_ids"] == [] + assert len(result["errors"]) == 1 + + # Node should still exist + assert AutomationNode.objects.filter(id=action_node.id).exists() diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_automation_workflow_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_automation_workflow_tools.py index 1e8ee8f8b0..05a7f75057 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_automation_workflow_tools.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_automation_workflow_tools.py @@ -1,28 +1,25 @@ -from unittest.mock import Mock - import pytest -from udspy.module.callbacks import ModuleContext, is_module_callback from baserow.contrib.automation.workflows.handler import AutomationWorkflowHandler from baserow.core.formula import resolve_formula from baserow.core.formula.registries import formula_runtime_function_registry from baserow.core.formula.types import BASEROW_FORMULA_MODE_ADVANCED +from baserow_enterprise.assistant.tools.automation.agents import AssistantFormulaContext from baserow_enterprise.assistant.tools.automation.tools import ( - get_list_workflows_tool, - get_workflow_tool_factory, + create_workflows, + list_workflows, ) from baserow_enterprise.assistant.tools.automation.types import ( - CreateRowActionCreate, - DeleteRowActionCreate, - RouterNodeCreate, + ActionNodeCreate, TriggerNodeCreate, - UpdateRowActionCreate, WorkflowCreate, ) -from baserow_enterprise.assistant.tools.automation.types.node import RouterEdgeCreate -from baserow_enterprise.assistant.tools.automation.utils import AssistantFormulaContext +from baserow_enterprise.assistant.tools.automation.types.node import ( + AutomationFieldValue, + RouterEdgeCreate, +) -from .utils import fake_tool_helpers +from .utils import make_test_ctx @pytest.fixture(autouse=True) @@ -38,7 +35,7 @@ def mock_update_workflow_formulas(workflow, node_mapping, tool_helpers): pass monkeypatch.setattr( - "baserow_enterprise.assistant.tools.automation.utils.update_workflow_formulas", + "baserow_enterprise.assistant.tools.automation.agents.update_workflow_formulas", mock_update_workflow_formulas, ) @@ -54,8 +51,8 @@ def test_list_workflows(data_fixture): automation=automation, name="Test Workflow" ) - tool = get_list_workflows_tool(user, workspace, fake_tool_helpers) - result = tool(automation_id=automation.id) + ctx = make_test_ctx(user, workspace) + result = list_workflows(ctx, automation_id=automation.id, thought="test") assert result == { "workflows": [{"id": workflow.id, "name": "Test Workflow", "state": "draft"}] @@ -76,8 +73,8 @@ def test_list_workflows_multiple(data_fixture): automation=automation, name="Workflow 2" ) - tool = get_list_workflows_tool(user, workspace, fake_tool_helpers) - result = tool(automation_id=automation.id) + ctx = make_test_ctx(user, workspace) + result = list_workflows(ctx, automation_id=automation.id, thought="test") assert result == { "workflows": [ @@ -97,25 +94,10 @@ def test_create_workflows(data_fixture): database = data_fixture.create_database_application(user=user, workspace=workspace) table = data_fixture.create_database_table(user=user, database=database) - factory = get_workflow_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) + ctx = make_test_ctx(user, workspace) - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - create_workflows_tool = next( - (tool for tool in added_tools if tool.name == "create_workflows"), None - ) - assert create_workflows_tool is not None - - result = create_workflows_tool.func( + result = create_workflows( + ctx, automation_id=automation.id, workflows=[ WorkflowCreate( @@ -124,19 +106,21 @@ def test_create_workflows(data_fixture): ref="trigger1", label="Periodic Trigger", type="periodic", + periodic_interval={"interval": "DAY"}, ), nodes=[ - CreateRowActionCreate( + ActionNodeCreate( ref="action1", label="Create row", previous_node_ref="trigger1", type="create_row", table_id=table.id, - values={}, + values=[], ) ], ) ], + thought="test", ) assert len(result["created_workflows"]) == 1 @@ -144,8 +128,6 @@ def test_create_workflows(data_fixture): assert result["created_workflows"][0]["state"] == "draft" # Verify workflow was created with a trigger - from baserow.contrib.automation.workflows.handler import AutomationWorkflowHandler - workflow_id = result["created_workflows"][0]["id"] workflow = AutomationWorkflowHandler().get_workflow(workflow_id) trigger = workflow.get_trigger() @@ -163,25 +145,10 @@ def test_create_multiple_workflows(data_fixture): database = data_fixture.create_database_application(user=user, workspace=workspace) table = data_fixture.create_database_table(user=user, database=database) - factory = get_workflow_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called + ctx = make_test_ctx(user, workspace) - added_tools = mock_module.init_module.call_args[1]["tools"] - create_workflows_tool = next( - (tool for tool in added_tools if tool.name == "create_workflows"), None - ) - assert create_workflows_tool is not None - - result = create_workflows_tool.func( + result = create_workflows( + ctx, automation_id=automation.id, workflows=[ WorkflowCreate( @@ -190,15 +157,16 @@ def test_create_multiple_workflows(data_fixture): ref="trigger1", label="Trigger", type="periodic", + periodic_interval={"interval": "DAY"}, ), nodes=[ - CreateRowActionCreate( + ActionNodeCreate( ref="action1", label="Action", previous_node_ref="trigger1", type="create_row", table_id=table.id, - values={}, + values=[], ) ], ), @@ -208,19 +176,21 @@ def test_create_multiple_workflows(data_fixture): ref="trigger2", label="Trigger", type="periodic", + periodic_interval={"interval": "DAY"}, ), nodes=[ - CreateRowActionCreate( + ActionNodeCreate( ref="action2", label="Action", previous_node_ref="trigger2", type="create_row", table_id=table.id, - values={}, + values=[], ) ], ), ], + thought="test", ) assert len(result["created_workflows"]) == 2 @@ -234,36 +204,45 @@ def test_create_multiple_workflows(data_fixture): [ ( TriggerNodeCreate( - type="rows_created", ref="trigger", label="Rows Created Trigger" + type="rows_created", + ref="trigger", + label="Rows Created Trigger", + rows_triggers_settings={"table_id": 999}, ), - CreateRowActionCreate( + ActionNodeCreate( type="create_row", ref="action", previous_node_ref="trigger", label="Create Row Action", table_id=999, - values={}, + values=[], ), ), ( TriggerNodeCreate( - type="rows_updated", ref="trigger", label="Rows Updated Trigger" + type="rows_updated", + ref="trigger", + label="Rows Updated Trigger", + rows_triggers_settings={"table_id": 999}, ), - UpdateRowActionCreate( + ActionNodeCreate( type="update_row", ref="action", previous_node_ref="trigger", label="Update Row Action", table_id=999, row_id="1", - values={}, + values=[], ), ), ( TriggerNodeCreate( - type="rows_deleted", ref="trigger", label="Rows Deleted Trigger" + type="rows_deleted", + ref="trigger", + label="Rows Deleted Trigger", + rows_triggers_settings={"table_id": 999}, ), - DeleteRowActionCreate( + ActionNodeCreate( type="delete_row", ref="action", previous_node_ref="trigger", @@ -285,25 +264,10 @@ def test_create_workflow_with_row_triggers_and_actions(data_fixture, trigger, ac table.pk = 999 # To match the action's table_id table.save() - factory = get_workflow_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) + ctx = make_test_ctx(user, workspace) - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - create_workflows_tool = next( - (tool for tool in added_tools if tool.name == "create_workflows"), None - ) - assert create_workflows_tool is not None - - result = create_workflows_tool.func( + result = create_workflows( + ctx, automation_id=automation.id, workflows=[ WorkflowCreate( @@ -312,6 +276,7 @@ def test_create_workflow_with_row_triggers_and_actions(data_fixture, trigger, ac nodes=[action], ) ], + thought="test", ) assert len(result["created_workflows"]) == 1 @@ -328,7 +293,7 @@ def test_create_workflow_with_row_triggers_and_actions(data_fixture, trigger, ac @pytest.mark.django_db(transaction=True) def test_create_row_action_with_field_ids(data_fixture): - """Test CreateRowActionCreate uses field IDs in values dict, not field names.""" + """Test ActionNodeCreate uses field IDs in values dict, not field names.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -340,25 +305,10 @@ def test_create_row_action_with_field_ids(data_fixture): text_field = data_fixture.create_text_field(table=table, name="Name") number_field = data_fixture.create_number_field(table=table, name="Age") - factory = get_workflow_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - create_workflows_tool = next( - (tool for tool in added_tools if tool.name == "create_workflows"), None - ) - assert create_workflows_tool is not None + ctx = make_test_ctx(user, workspace) - result = create_workflows_tool.func( + result = create_workflows( + ctx, automation_id=automation.id, workflows=[ WorkflowCreate( @@ -367,22 +317,26 @@ def test_create_row_action_with_field_ids(data_fixture): ref="trigger1", label="Periodic Trigger", type="periodic", + periodic_interval={"interval": "DAY"}, ), nodes=[ - CreateRowActionCreate( + ActionNodeCreate( ref="action1", label="Create row with field IDs", previous_node_ref="trigger1", type="create_row", table_id=table.id, - values={ - text_field.id: "John Doe", - number_field.id: 25, - }, + values=[ + AutomationFieldValue( + field_id=text_field.id, value="John Doe" + ), + AutomationFieldValue(field_id=number_field.id, value="25"), + ], ) ], ) ], + thought="test", ) assert len(result["created_workflows"]) == 1 @@ -400,7 +354,7 @@ def test_create_row_action_with_field_ids(data_fixture): @pytest.mark.django_db(transaction=True) def test_update_row_action_with_row_id_and_field_ids(data_fixture): - """Test UpdateRowActionCreate uses row_id parameter and field IDs in values.""" + """Test ActionNodeCreate uses row_id parameter and field IDs in values.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -411,25 +365,10 @@ def test_update_row_action_with_row_id_and_field_ids(data_fixture): table = data_fixture.create_database_table(user=user, database=database) text_field = data_fixture.create_text_field(table=table, name="Status") - factory = get_workflow_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called + ctx = make_test_ctx(user, workspace) - added_tools = mock_module.init_module.call_args[1]["tools"] - create_workflows_tool = next( - (tool for tool in added_tools if tool.name == "create_workflows"), None - ) - assert create_workflows_tool is not None - - result = create_workflows_tool.func( + result = create_workflows( + ctx, automation_id=automation.id, workflows=[ WorkflowCreate( @@ -438,42 +377,44 @@ def test_update_row_action_with_row_id_and_field_ids(data_fixture): ref="trigger1", label="Periodic Trigger", type="periodic", + periodic_interval={"interval": "DAY"}, ), nodes=[ - UpdateRowActionCreate( + ActionNodeCreate( ref="action1", label="Update row", previous_node_ref="trigger1", type="update_row", table_id=table.id, row_id="123", - values={text_field.id: "completed"}, + values=[ + AutomationFieldValue( + field_id=text_field.id, value="completed" + ) + ], ) ], ) ], + thought="test", ) assert len(result["created_workflows"]) == 1 workflow_id = result["created_workflows"][0]["id"] workflow = AutomationWorkflowHandler().get_workflow(workflow_id) - # Get the action node and verify it was created with the correct table - # Note: row_id formula generation occurs in a separate transaction and may fail - # if DSPy is not configured, so we only verify basic service configuration action_nodes = workflow.automation_workflow_nodes.exclude( id=workflow.get_trigger().id ) assert action_nodes.count() == 1 action_node = action_nodes.first() assert action_node.service.specific.table_id == table.id - # Verify the service type is correct for upsert_row (update operation) assert action_node.service.get_type().type == "local_baserow_upsert_row" @pytest.mark.django_db(transaction=True) def test_delete_row_action_with_row_id(data_fixture): - """Test DeleteRowActionCreate uses row_id parameter.""" + """Test ActionNodeCreate uses row_id parameter.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -483,25 +424,10 @@ def test_delete_row_action_with_row_id(data_fixture): database = data_fixture.create_database_application(user=user, workspace=workspace) table = data_fixture.create_database_table(user=user, database=database) - factory = get_workflow_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) + ctx = make_test_ctx(user, workspace) - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - create_workflows_tool = next( - (tool for tool in added_tools if tool.name == "create_workflows"), None - ) - assert create_workflows_tool is not None - - result = create_workflows_tool.func( + result = create_workflows( + ctx, automation_id=automation.id, workflows=[ WorkflowCreate( @@ -510,9 +436,10 @@ def test_delete_row_action_with_row_id(data_fixture): ref="trigger1", label="Periodic Trigger", type="periodic", + periodic_interval={"interval": "DAY"}, ), nodes=[ - DeleteRowActionCreate( + ActionNodeCreate( ref="action1", label="Delete row", previous_node_ref="trigger1", @@ -523,28 +450,25 @@ def test_delete_row_action_with_row_id(data_fixture): ], ) ], + thought="test", ) assert len(result["created_workflows"]) == 1 workflow_id = result["created_workflows"][0]["id"] workflow = AutomationWorkflowHandler().get_workflow(workflow_id) - # Get the action node and verify it was created with the correct table - # Note: row_id formula generation occurs in a separate transaction and may fail - # if DSPy is not configured, so we only verify basic service configuration action_nodes = workflow.automation_workflow_nodes.exclude( id=workflow.get_trigger().id ) assert action_nodes.count() == 1 action_node = action_nodes.first() assert action_node.service.specific.table_id == table.id - # Verify the service type is correct for delete_row assert action_node.service.get_type().type == "local_baserow_delete_row" @pytest.mark.django_db(transaction=True) def test_router_node_with_required_conditions(data_fixture): - """Test RouterNodeCreate requires condition field for each edge.""" + """Test ActionNodeCreate requires condition field for each edge.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -554,25 +478,10 @@ def test_router_node_with_required_conditions(data_fixture): database = data_fixture.create_database_application(user=user, workspace=workspace) table = data_fixture.create_database_table(user=user, database=database) - factory = get_workflow_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - create_workflows_tool = next( - (tool for tool in added_tools if tool.name == "create_workflows"), None - ) - assert create_workflows_tool is not None + ctx = make_test_ctx(user, workspace) - result = create_workflows_tool.func( + result = create_workflows( + ctx, automation_id=automation.id, workflows=[ WorkflowCreate( @@ -581,9 +490,10 @@ def test_router_node_with_required_conditions(data_fixture): ref="trigger1", label="Periodic Trigger", type="periodic", + periodic_interval={"interval": "DAY"}, ), nodes=[ - RouterNodeCreate( + ActionNodeCreate( ref="router1", label="Router", previous_node_ref="trigger1", @@ -599,17 +509,18 @@ def test_router_node_with_required_conditions(data_fixture): ), ], ), - CreateRowActionCreate( + ActionNodeCreate( ref="action1", label="Create row", previous_node_ref="router1", type="create_row", table_id=table.id, - values={}, + values=[], ), ], ) ], + thought="test", ) assert len(result["created_workflows"]) == 1 diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_core_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_core_tools.py new file mode 100644 index 0000000000..3054689274 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_core_tools.py @@ -0,0 +1,116 @@ +import pytest + +from baserow.test_utils.helpers import AnyInt +from baserow_enterprise.assistant.tools.core.tools import ( + create_builders, + list_builders, +) +from baserow_enterprise.assistant.tools.core.types import BuilderItemCreate + +from .utils import make_test_ctx + + +@pytest.mark.django_db +def test_list_builders_all(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + db = data_fixture.create_database_application(workspace=workspace, name="My DB") + automation = data_fixture.create_automation_application( + workspace=workspace, name="My Automation" + ) + + ctx = make_test_ctx(user, workspace) + result = list_builders(ctx, builder_types=None, thought="list all") + + assert "database" in result + assert any(b["name"] == "My DB" for b in result["database"]) + assert "automation" in result + assert any(b["name"] == "My Automation" for b in result["automation"]) + + +@pytest.mark.django_db +def test_list_builders_filter_by_type(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + data_fixture.create_database_application(workspace=workspace, name="DB 1") + data_fixture.create_automation_application(workspace=workspace, name="Auto 1") + + ctx = make_test_ctx(user, workspace) + result = list_builders(ctx, builder_types=["database"], thought="databases only") + + assert "database" in result + assert "automation" not in result + + +@pytest.mark.django_db +def test_list_builders_empty(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + + ctx = make_test_ctx(user, workspace) + result = list_builders(ctx, builder_types=None, thought="list all") + + assert result == {} + + +@pytest.mark.django_db +def test_list_builders_truncation(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + for i in range(25): + data_fixture.create_database_application(workspace=workspace, name=f"DB {i}") + + ctx = make_test_ctx(user, workspace) + result = list_builders(ctx, builder_types=None, thought="list all") + + assert "_info" in result + assert len(result["database"]) == 20 + + +@pytest.mark.django_db +def test_create_builders_database(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + + ctx = make_test_ctx(user, workspace) + builders = [BuilderItemCreate(name="New Database", type="database")] + result = create_builders(ctx, builders=builders, thought="create db") + + assert len(result["created_builders"]) == 1 + created = result["created_builders"][0] + assert created["name"] == "New Database" + assert created["type"] == "database" + assert created["id"] == AnyInt() + + +@pytest.mark.django_db +def test_create_builders_multiple(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + + ctx = make_test_ctx(user, workspace) + builders = [ + BuilderItemCreate(name="DB One", type="database"), + BuilderItemCreate(name="DB Two", type="database"), + ] + result = create_builders(ctx, builders=builders, thought="create two dbs") + + assert len(result["created_builders"]) == 2 + names = [b["name"] for b in result["created_builders"]] + assert "DB One" in names + assert "DB Two" in names + + +@pytest.mark.django_db +def test_create_database_ignores_theme(data_fixture): + """Creating a database should not fail even though databases have no theme.""" + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + + ctx = make_test_ctx(user, workspace) + builders = [BuilderItemCreate(name="My DB", type="database")] + result = create_builders(ctx, builders=builders, thought="create db") + + assert len(result["created_builders"]) == 1 + assert result["created_builders"][0]["type"] == "database" diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_field_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_field_tools.py new file mode 100644 index 0000000000..7a96240ecc --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_field_tools.py @@ -0,0 +1,151 @@ +import pytest + +from baserow.contrib.database.fields.handler import FieldHandler +from baserow.contrib.database.fields.models import Field +from baserow_enterprise.assistant.tools.database.tools import ( + delete_fields, + update_fields, +) +from baserow_enterprise.assistant.tools.database.types import ( + FieldItemUpdate, + SelectOptionCreate, +) + +from .utils import make_test_ctx + + +@pytest.mark.django_db +def test_update_field_name(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_text_field(table=table, name="Old Name") + + ctx = make_test_ctx(user, workspace) + result = update_fields( + ctx, + fields=[FieldItemUpdate(field_id=field.id, name="New Name")], + thought="rename field", + ) + + assert result["updated_fields"][0]["name"] == "New Name" + assert result["updated_fields"][0]["id"] == field.id + + # Verify in DB + refreshed = FieldHandler().get_field(field.id) + assert refreshed.name == "New Name" + + +@pytest.mark.django_db +def test_update_number_field_decimal_places(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_number_field( + table=table, name="Price", number_decimal_places=0 + ) + + ctx = make_test_ctx(user, workspace) + result = update_fields( + ctx, + fields=[FieldItemUpdate(field_id=field.id, decimal_places=2)], + thought="change decimal places", + ) + + assert result["updated_fields"][0]["decimal_places"] == 2 + + +@pytest.mark.django_db +def test_update_select_field_options(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_single_select_field(table=table, name="Status") + + ctx = make_test_ctx(user, workspace) + result = update_fields( + ctx, + fields=[ + FieldItemUpdate( + field_id=field.id, + options=[ + SelectOptionCreate(value="Open", color="green"), + SelectOptionCreate(value="Closed", color="red"), + ], + ) + ], + thought="add options", + ) + + updated = result["updated_fields"][0] + assert len(updated["options"]) == 2 + option_values = {o["value"] for o in updated["options"]} + assert option_values == {"Open", "Closed"} + + +@pytest.mark.django_db +def test_update_field_no_changes(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_text_field(table=table, name="Unchanged") + + ctx = make_test_ctx(user, workspace) + result = update_fields( + ctx, + fields=[FieldItemUpdate(field_id=field.id)], + thought="no changes", + ) + + assert result["updated_fields"][0]["name"] == "Unchanged" + + +@pytest.mark.django_db +def test_delete_field(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_text_field(table=table, name="To Delete") + + ctx = make_test_ctx(user, workspace) + result = delete_fields( + ctx, + field_ids=[field.id], + thought="delete field", + ) + + assert result["deleted_field_ids"] == [field.id] + + # Field should be trashed + assert not Field.objects.filter(id=field.id).exists() + assert Field.objects_and_trash.filter(id=field.id, trashed=True).exists() + + +@pytest.mark.django_db +def test_delete_primary_field_fails(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + primary_field = data_fixture.create_text_field( + table=table, name="Primary", primary=True + ) + + ctx = make_test_ctx(user, workspace) + result = delete_fields( + ctx, + field_ids=[primary_field.id], + thought="try delete primary", + ) + + assert result["deleted_field_ids"] == [] + assert len(result["errors"]) == 1 + + # Primary field should still exist + refreshed = FieldHandler().get_field(primary_field.id) + assert refreshed.primary is True diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py index 004477d9e6..306c06289a 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py @@ -1,15 +1,12 @@ -from unittest.mock import Mock - import pytest -from udspy.module.callbacks import ModuleContext, is_module_callback from baserow.contrib.database.rows.handler import RowHandler from baserow_enterprise.assistant.tools.database.tools import ( - get_list_rows_tool, - get_rows_tools_factory, + list_rows, + load_row_tools, ) -from .utils import fake_tool_helpers +from .utils import make_test_ctx def _create_simple_database_with_linked_tables_and_rows(data_fixture): @@ -134,20 +131,20 @@ def test_list_rows(data_fixture): user = res["user"] workspace = res["workspace"] table = res["table_a"] - tool_helpers = fake_tool_helpers - list_table_rows = get_list_rows_tool(user, workspace, tool_helpers) - assert callable(list_table_rows) + ctx = make_test_ctx(user, workspace) - result = list_table_rows(table_id=table.id, offset=0, limit=50) + result = list_rows( + ctx, table_id=table.id, offset=0, limit=50, field_ids=None, thought="test" + ) rows = result["rows"] assert len(rows) == 3 assert rows[0] == { "primary": "Row A1", "Long text field": "Long text A1", "Number field": 10.123, - "Date field": {"year": 2023, "month": 1, "day": 1}, - "Datetime field": {"year": 2023, "month": 1, "day": 1, "hour": 10, "minute": 0}, + "Date field": "2023-01-01", + "Datetime field": "2023-01-01T10:00", "Single link to B": "Row B1", "Multiple select": ["Option A", "Option B"], "Text field": "Text A1", @@ -159,8 +156,8 @@ def test_list_rows(data_fixture): "primary": "Row A2", "Long text field": "Long text A2", "Number field": 20.456, - "Date field": {"year": 2023, "month": 2, "day": 1}, - "Datetime field": {"year": 2023, "month": 2, "day": 1, "hour": 11, "minute": 0}, + "Date field": "2023-02-01", + "Datetime field": "2023-02-01T11:00", "Single link to B": "Row B2", "Multiple select": ["Option B", "Option C"], "Text field": "Text A2", @@ -183,8 +180,13 @@ def test_list_rows(data_fixture): } # List a single field - result = list_table_rows( - table_id=table.id, offset=0, limit=50, field_ids=[table.get_primary_field().id] + result = list_rows( + ctx, + table_id=table.id, + offset=0, + limit=50, + field_ids=[table.get_primary_field().id], + thought="test", ) rows = result["rows"] assert len(rows) == 3 @@ -209,24 +211,19 @@ def test_create_rows(data_fixture): user = res["user"] workspace = res["workspace"] table = res["table_a"] - tool_helpers = fake_tool_helpers - meta_tool = get_rows_tools_factory(user, workspace, tool_helpers) - assert callable(meta_tool) + ctx = make_test_ctx(user, workspace) - tools_upgrade = meta_tool([table.id], ["create"]) - assert is_module_callback(tools_upgrade) + observation = load_row_tools(ctx, [table.id], ["create"], thought="test") + assert isinstance(observation, str) + assert f"create_rows_in_table_{table.id}" in observation - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called + # Tools should be stored in ctx.deps.dynamic_tools + dynamic_tools = ctx.deps.dynamic_tools + assert len(dynamic_tools) == 1 - added_tools = mock_module.init_module.call_args[1]["tools"] - added_tools_names = [tool.name for tool in added_tools] - assert len(added_tools) == 1 - assert f"create_rows_in_table_{table.id}" in added_tools_names + create_tool = dynamic_tools[0] + assert create_tool.name == f"create_rows_in_table_{table.id}" table_model = table.get_model() assert table_model.objects.count() == 3 @@ -236,14 +233,8 @@ def test_create_rows(data_fixture): "Text field": "Text A3", "Long text field": "Long text A3", "Number field": 30.789, - "Date field": {"year": 2023, "month": 3, "day": 1}, - "Datetime field": { - "year": 2023, - "month": 3, - "day": 1, - "hour": 12, - "minute": 0, - }, + "Date field": "2023-03-01", + "Datetime field": "2023-03-01T12:00", "Single select": "Option 1", "Multiple select": ["Option A", "Option C"], "Single link to B": "Row B3", @@ -261,8 +252,12 @@ def test_create_rows(data_fixture): "Single link to B": None, "link": [], } - create_table_rows = added_tools[0] - result = create_table_rows(rows=[row_1, row_2]) + # Validate dicts through the tool's schema (as pydantic-ai would), + # then call the underlying function. + validated_args = create_tool.function_schema.validator.validate_python( + {"rows": [row_1, row_2], "thought": "test"} + ) + result = create_tool.function(**validated_args) created_row_ids = result["created_row_ids"] assert len(created_row_ids) == 2 assert created_row_ids == [4, 5] @@ -275,28 +270,22 @@ def test_update_rows(data_fixture): user = res["user"] workspace = res["workspace"] table = res["table_a"] - tool_helpers = fake_tool_helpers - meta_tool = get_rows_tools_factory(user, workspace, tool_helpers) - assert callable(meta_tool) - tools_upgrade = meta_tool([table.id], ["update"]) - assert is_module_callback(tools_upgrade) + ctx = make_test_ctx(user, workspace) + + observation = load_row_tools(ctx, [table.id], ["update"], thought="test") + assert isinstance(observation, str) - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called + dynamic_tools = ctx.deps.dynamic_tools + assert len(dynamic_tools) == 1 - added_tools = mock_module.init_module.call_args[1]["tools"] - added_tools_names = [tool.name for tool in added_tools] - assert len(added_tools) == 1 - assert f"update_rows_in_table_{table.id}" in added_tools_names + update_tool = dynamic_tools[0] + assert update_tool.name == f"update_rows_in_table_{table.id}" table_model = table.get_model() assert table_model.objects.count() == 3 - # Update row 1 with new values + # Update row 1 — only pass fields to change, omit the rest row_1_updates = { "id": 1, "primary": "Updated Row A1", @@ -305,41 +294,33 @@ def test_update_rows(data_fixture): "Single select": "Option 2", "link": ["Row B3"], "Single link to B": "Row B2", - "Datetime field": "__NO_CHANGE__", - "Date field": "__NO_CHANGE__", - "Multiple select": "__NO_CHANGE__", - "Long text field": "__NO_CHANGE__", } - # Update row 2 with new values + # Update row 2 — only pass fields to change row_2_updates = { "id": 2, - "Single link to B": "__NO_CHANGE__", "Long text field": "Updated Long text A2", - "Date field": {"year": 2024, "month": 12, "day": 31}, + "Date field": "2024-12-31", "Multiple select": ["Option A"], - "primary": "__NO_CHANGE__", - "Text field": "__NO_CHANGE__", - "Number field": "__NO_CHANGE__", - "Datetime field": "__NO_CHANGE__", - "Single select": "__NO_CHANGE__", - "link": "__NO_CHANGE__", } - update_table_rows = added_tools[0] - result = update_table_rows(rows=[row_1_updates, row_2_updates]) + validated_args = update_tool.function_schema.validator.validate_python( + {"rows": [row_1_updates, row_2_updates], "thought": "test"} + ) + result = update_tool.function(**validated_args) updated_row_ids = result["updated_row_ids"] assert len(updated_row_ids) == 2 assert updated_row_ids == [1, 2] # Verify the rows were updated correctly - list_table_rows = get_list_rows_tool(user, workspace, tool_helpers) - row_1, row_2 = list_table_rows(table_id=table.id, offset=0, limit=2)["rows"] + row_1, row_2 = list_rows( + ctx, table_id=table.id, offset=0, limit=2, field_ids=None, thought="test" + )["rows"] assert row_1 == { "primary": "Updated Row A1", "Long text field": "Long text A1", "Number field": 99.999, - "Date field": {"year": 2023, "month": 1, "day": 1}, - "Datetime field": {"year": 2023, "month": 1, "day": 1, "hour": 10, "minute": 0}, + "Date field": "2023-01-01", + "Datetime field": "2023-01-01T10:00", "Single link to B": "Row B2", "Multiple select": ["Option A", "Option B"], "Text field": "Updated Text A1", @@ -351,8 +332,8 @@ def test_update_rows(data_fixture): "primary": "Row A2", "Long text field": "Updated Long text A2", "Number field": 20.456, - "Date field": {"year": 2024, "month": 12, "day": 31}, - "Datetime field": {"year": 2023, "month": 2, "day": 1, "hour": 11, "minute": 0}, + "Date field": "2024-12-31", + "Datetime field": "2023-02-01T11:00", "Single link to B": "Row B2", "Multiple select": ["Option A"], "Text field": "Text A2", @@ -369,29 +350,23 @@ def test_delete_rows(data_fixture): user = res["user"] workspace = res["workspace"] table = res["table_a"] - tool_helpers = fake_tool_helpers - - meta_tool = get_rows_tools_factory(user, workspace, tool_helpers) - assert callable(meta_tool) - - tools_upgrade = meta_tool([table.id], ["delete"]) - assert is_module_callback(tools_upgrade) - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - added_tools = mock_module.init_module.call_args[1]["tools"] - added_tools_names = [tool.name for tool in added_tools] - assert len(added_tools) == 1 - assert f"delete_rows_in_table_{table.id}" in added_tools_names - delete_table_rows = added_tools[0] + + ctx = make_test_ctx(user, workspace) + + observation = load_row_tools(ctx, [table.id], ["delete"], thought="test") + assert isinstance(observation, str) + + dynamic_tools = ctx.deps.dynamic_tools + assert len(dynamic_tools) == 1 + + delete_tool = dynamic_tools[0] + assert delete_tool.name == f"delete_rows_in_table_{table.id}" table_model = table.get_model() assert table_model.objects.count() == 3 # Delete rows with ids 1 and 3 - result = delete_table_rows(row_ids=[1, 3]) + result = delete_tool.function(row_ids=[1, 3], thought="test") assert result["deleted_row_ids"] == [1, 3] # Verify rows were deleted diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py index 0d1b052dd0..97c9560f35 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py @@ -1,35 +1,45 @@ -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, patch import pytest -from udspy.module.callbacks import ModuleContext, is_module_callback from baserow.contrib.database.fields.models import FormulaField from baserow.contrib.database.formula.registries import formula_function_registry from baserow.contrib.database.table.models import Table from baserow.test_utils.helpers import AnyInt +from baserow_enterprise.assistant.tools.database.agents import FormulaGenerationResult from baserow_enterprise.assistant.tools.database.tools import ( - get_generate_database_formula_tool, - get_list_tables_tool, - get_table_and_fields_tools_factory, + create_fields, + create_tables, + generate_formula, + list_tables, ) from baserow_enterprise.assistant.tools.database.types import ( - BooleanFieldItemCreate, - DateFieldItemCreate, - FileFieldItemCreate, - LinkRowFieldItemCreate, + FieldItem, + FieldItemCreate, + InvalidFormulaFieldError, ListTablesFilterArg, - LongTextFieldItemCreate, - MultipleSelectFieldItemCreate, - NumberFieldItemCreate, - RatingFieldItemCreate, SelectOptionCreate, - SingleSelectFieldItemCreate, TableItemCreate, - TextFieldItemCreate, - field_item_registry, ) -from .utils import fake_tool_helpers +from .utils import make_test_ctx + + +def _make_mock_formula_result(**kwargs): + """Create a mock agent result with a FormulaGenerationResult output.""" + defaults = { + "table_id": 1, + "field_name": "test_formula", + "formula": "'ok'", + "formula_type": "text", + "is_formula_valid": True, + "error_message": "", + } + defaults.update(kwargs) + result = FormulaGenerationResult(**defaults) + mock_agent_result = MagicMock() + mock_agent_result.output = result + return mock_agent_result @pytest.mark.django_db @@ -46,138 +56,86 @@ def test_list_tables_tool(data_fixture): table_2 = data_fixture.create_database_table(database=database_1, name="Table 2") table_3 = data_fixture.create_database_table(database=database_2, name="Table 3") - tool = get_list_tables_tool(user, workspace, fake_tool_helpers) + ctx = make_test_ctx(user, workspace) - # Test 1: Filter by database_ids (single database) - returns flat list - response = tool( + # Test 1: Filter by database_id (single database) - returns flat list + response = list_tables( + ctx, + thought="test", filters=ListTablesFilterArg( - database_ids=[database_1.id], - database_names=None, - table_ids=None, - table_names=None, - ) + database_id_or_name=database_1.id, + table_ids_or_names=None, + ), ) assert response == [ {"id": table_1.id, "name": "Table 1", "database_id": database_1.id}, {"id": table_2.id, "name": "Table 2", "database_id": database_1.id}, ] - # Test 2: Filter by database_names (single database) - returns flat list - response = tool( + # Test 2: Filter by database_name (single database) - returns flat list + response = list_tables( + ctx, + thought="test", filters=ListTablesFilterArg( - database_ids=None, - database_names=["Database 2"], - table_ids=None, - table_names=None, - ) + database_id_or_name="Database 2", + table_ids_or_names=None, + ), ) assert response == [ {"id": table_3.id, "name": "Table 3", "database_id": database_2.id}, ] - # Test 3: Filter by multiple database_ids - returns database wrapper structure - response = tool( + # Test 4: Filter by database + table_ids - returns flat list + response = list_tables( + ctx, + thought="test", filters=ListTablesFilterArg( - database_ids=[database_1.id, database_2.id], - database_names=None, - table_ids=None, - table_names=None, - ) - ) - assert response == [ - { - "id": database_1.id, - "name": "Database 1", - "tables": [ - {"id": table_1.id, "name": "Table 1", "database_id": database_1.id}, - {"id": table_2.id, "name": "Table 2", "database_id": database_1.id}, - ], - }, - { - "id": database_2.id, - "name": "Database 2", - "tables": [ - {"id": table_3.id, "name": "Table 3", "database_id": database_2.id}, - ], - }, - ] - - # Test 4: Filter by table_ids (single database) - returns flat list - response = tool( - filters=ListTablesFilterArg( - database_ids=None, - database_names=None, - table_ids=[table_1.id, table_2.id], - table_names=None, - ) + database_id_or_name=database_1.id, + table_ids_or_names=[table_1.id, table_2.id], + ), ) assert response == [ {"id": table_1.id, "name": "Table 1", "database_id": database_1.id}, {"id": table_2.id, "name": "Table 2", "database_id": database_1.id}, ] - # Test 5: Filter by table_names (single database) - returns flat list - response = tool( + # Test 5: Filter by database + table_names - returns flat list + response = list_tables( + ctx, + thought="test", filters=ListTablesFilterArg( - database_ids=None, - database_names=None, - table_ids=None, - table_names=["Table 1"], - ) + database_id_or_name=database_1.id, + table_ids_or_names=["Table 1"], + ), ) assert response == [ {"id": table_1.id, "name": "Table 1", "database_id": database_1.id}, ] - # Test 6: Filter by table_ids across multiple databases - returns database wrapper - response = tool( - filters=ListTablesFilterArg( - database_ids=None, - database_names=None, - table_ids=[table_1.id, table_3.id], - table_names=None, - ) - ) - assert response == [ - { - "id": database_1.id, - "name": "Database 1", - "tables": [ - {"id": table_1.id, "name": "Table 1", "database_id": database_1.id}, - ], - }, - { - "id": database_2.id, - "name": "Database 2", - "tables": [ - {"id": table_3.id, "name": "Table 3", "database_id": database_2.id}, - ], - }, - ] - - # Test 7: Combined filters (database_ids + table_names) - returns flat list - response = tool( + # Test 6: Combined filters (database_id + table_names) - returns flat list + response = list_tables( + ctx, + thought="test", filters=ListTablesFilterArg( - database_ids=[database_1.id], - database_names=None, - table_ids=None, - table_names=["Table 2"], - ) + database_id_or_name=database_1.id, + table_ids_or_names=["Table 2"], + ), ) assert response == [ {"id": table_2.id, "name": "Table 2", "database_id": database_1.id}, ] - # Test 8: No matching tables - returns "No tables found" - response = tool( + # Test 7: No matching tables - returns hint with available tables + response = list_tables( + ctx, + thought="test", filters=ListTablesFilterArg( - database_ids=None, - database_names=None, - table_ids=None, - table_names=["Nonexistent Table"], - ) + database_id_or_name=database_1.id, + table_ids_or_names=["Nonexistent Table"], + ), ) - assert response == "No tables found" + info = response["_info"] + assert "no tables matching" in info or "No tables found" in info @pytest.mark.django_db @@ -188,44 +146,29 @@ def test_create_simple_table_tool(data_fixture): workspace=workspace, name="Database 1" ) - factory = get_table_and_fields_tools_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called + ctx = make_test_ctx(user, workspace) - added_tools = mock_module.init_module.call_args[1]["tools"] - assert len(added_tools) == 2 # create_tables and create_fields - - # Find the create_tables tool - create_tables_tool = next( - (tool for tool in added_tools if tool.name == "create_tables"), None - ) - assert create_tables_tool is not None - - # Call the underlying function directly (not through udspy.Tool wrapper) - response = create_tables_tool.func( + # Call the tool function directly + response = create_tables( + ctx, + thought="test", database_id=database.id, tables=[ TableItemCreate( name="New Table", - primary_field=TextFieldItemCreate(type="text", name="Name"), + primary_field_name="Name", fields=[], ) ], add_sample_rows=False, ) - assert response == { - "created_tables": [{"id": AnyInt(), "name": "New Table"}], - "notes": [], - } + assert len(response["created_tables"]) == 1 + assert response["created_tables"][0]["name"] == "New Table" + assert response["created_tables"][0]["id"] == AnyInt() + assert response["notes"] == [] + # Full schema is included so callers have field IDs + assert "primary_field" in response["created_tables"][0] # Ensure the table was actually created assert Table.objects.filter( @@ -242,66 +185,27 @@ def test_create_complex_table_tool(data_fixture): ) table = data_fixture.create_database_table(database=database, name="Table 1") - factory = get_table_and_fields_tools_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - assert len(added_tools) == 2 # create_tables and create_fields - - # Find the create_tables tool - create_tables_tool = next( - (tool for tool in added_tools if tool.name == "create_tables"), None - ) - assert create_tables_tool is not None + ctx = make_test_ctx(user, workspace) - primary_field = TextFieldItemCreate(type="text", name="Name") + primary_field_name = "Name" fields = [ - LongTextFieldItemCreate( - type="long_text", - name="Description", - rich_text=True, - ), - NumberFieldItemCreate( - type="number", - name="Amount", - decimal_places=2, - suffix="$", - ), - DateFieldItemCreate( - type="date", - name="Due Date", - include_time=False, - ), - DateFieldItemCreate( - type="date", - name="Event Time", - include_time=True, - ), - BooleanFieldItemCreate( - type="boolean", - name="Done?", - ), - SingleSelectFieldItemCreate( - type="single_select", + FieldItemCreate(name="Description", type="long_text", rich_text=True), + FieldItemCreate(name="Amount", type="number", decimal_places=2, suffix="$"), + FieldItemCreate(name="Due Date", type="date", include_time=False), + FieldItemCreate(name="Event Time", type="date", include_time=True), + FieldItemCreate(name="Done?", type="boolean"), + FieldItemCreate( name="Status", + type="single_select", options=[ SelectOptionCreate(value="New", color="blue"), SelectOptionCreate(value="In Progress", color="yellow"), SelectOptionCreate(value="Done", color="green"), ], ), - MultipleSelectFieldItemCreate( - type="multiple_select", + FieldItemCreate( name="Tags", + type="multiple_select", options=[ SelectOptionCreate(value="Red", color="red"), SelectOptionCreate(value="Yellow", color="yellow"), @@ -309,38 +213,36 @@ def test_create_complex_table_tool(data_fixture): SelectOptionCreate(value="Blue", color="blue"), ], ), - LinkRowFieldItemCreate( - type="link_row", + FieldItemCreate( name="Related Items", + type="link_row", linked_table=table.id, ), - RatingFieldItemCreate( - type="rating", - name="Rating", - max_value=5, - ), - FileFieldItemCreate( - type="file", - name="Attachments", - ), + FieldItemCreate(name="Rating", type="rating", max_value=5), + FieldItemCreate(name="Attachments", type="file"), ] - # Call the underlying function directly (not through udspy.Tool wrapper) - response = create_tables_tool.func( + # Call the tool function directly + response = create_tables( + ctx, + thought="test", database_id=database.id, tables=[ TableItemCreate( name="New Table", - primary_field=primary_field, + primary_field_name=primary_field_name, fields=fields, ) ], add_sample_rows=False, ) - assert response == { - "created_tables": [{"id": AnyInt(), "name": "New Table"}], - "notes": [], - } + assert len(response["created_tables"]) == 1 + assert response["created_tables"][0]["name"] == "New Table" + assert response["created_tables"][0]["id"] == AnyInt() + assert response["notes"] == [] + # Full schema is included with all field details + assert "primary_field" in response["created_tables"][0] + assert "fields" in response["created_tables"][0] # Ensure the table was actually created with all fields created_table = Table.objects.filter( @@ -351,28 +253,41 @@ def test_create_complex_table_tool(data_fixture): table_model = created_table.get_model() fields_map = {field.name: field for field in fields} - fields_map[primary_field.name] = primary_field for field_object in table_model.get_field_objects(): orm_field = field_object["field"] - assert orm_field.name in fields_map - field_item = fields_map.pop(orm_field.name).model_dump() - orm_field_to_item = field_item_registry.from_django_orm(orm_field).model_dump() + read_item = FieldItem.from_django_orm(orm_field).model_dump() + if orm_field.primary: - assert field_item["name"] == primary_field.name + assert orm_field.name == primary_field_name + continue + + assert orm_field.name in fields_map + create_item = fields_map.pop(orm_field.name) + create_dump = create_item.model_dump() + + # Both create and read are flat: type is top-level + assert create_dump["type"] == read_item["type"] - for key, value in orm_field_to_item.items(): - if key == "id": + # Compare type-specific fields present in both + skip_keys = {"name", "type"} + for key, value in create_dump.items(): + if key in skip_keys: continue + read_value = read_item.get(key) + if read_value is None: + continue # read model excludes None; defaults aren't relevant if key == "options": - # Saved options have an ID, so we need to remove them before comparison - for option in value: + # Saved options have an ID, so remove them before comparison + for option in read_value: option.pop("id") - - assert field_item[key] == value + assert read_value == value, ( + f"Field '{orm_field.name}' key '{key}': " + f"expected {value}, got {read_value}" + ) @pytest.mark.django_db -def test_generate_database_formula_no_save(data_fixture): +def test_generate_formula_no_save(data_fixture): """Test formula generation without saving to a field.""" user = data_fixture.create_user() @@ -381,20 +296,22 @@ def test_generate_database_formula_no_save(data_fixture): table = data_fixture.create_database_table(database=database, name="Test Table") data_fixture.create_text_field(table=table, name="text_field", primary=True) - # Mock the udspy.ReAct to return a valid formula - mock_prediction = MagicMock() - mock_prediction.is_formula_valid = True - mock_prediction.formula = "'ok'" - mock_prediction.formula_type = "text" - mock_prediction.field_name = "test_formula" - mock_prediction.table_id = table.id - mock_prediction.error_message = "" + mock_result = _make_mock_formula_result( + table_id=table.id, + field_name="test_formula", + formula="'ok'", + formula_type="text", + ) - with patch("udspy.ReAct") as mock_react: - mock_react.return_value.return_value = mock_prediction + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync" + ) as mock_agent: + mock_agent.return_value = mock_result - tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) - result = tool( + ctx = make_test_ctx(user, workspace) + result = generate_formula( + ctx, + thought="test", database_id=database.id, description="Return a simple text", save_to_field=False, @@ -409,7 +326,7 @@ def test_generate_database_formula_no_save(data_fixture): @pytest.mark.django_db -def test_generate_database_formula_create_new_field(data_fixture): +def test_generate_formula_create_new_field(data_fixture): """Test formula generation creates a new field when none exists.""" user = data_fixture.create_user() @@ -418,20 +335,22 @@ def test_generate_database_formula_create_new_field(data_fixture): table = data_fixture.create_database_table(database=database, name="Test Table") data_fixture.create_text_field(table=table, name="text_field", primary=True) - # Mock the udspy.ReAct to return a valid formula - mock_prediction = MagicMock() - mock_prediction.is_formula_valid = True - mock_prediction.formula = "'ok'" - mock_prediction.formula_type = "text" - mock_prediction.field_name = "test_formula" - mock_prediction.table_id = table.id - mock_prediction.error_message = "" + mock_result = _make_mock_formula_result( + table_id=table.id, + field_name="test_formula", + formula="'ok'", + formula_type="text", + ) - with patch("udspy.ReAct") as mock_react: - mock_react.return_value.return_value = mock_prediction + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync" + ) as mock_agent: + mock_agent.return_value = mock_result - tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) - result = tool( + ctx = make_test_ctx(user, workspace) + result = generate_formula( + ctx, + thought="test", database_id=database.id, description="Return a simple text", save_to_field=True, @@ -453,7 +372,7 @@ def test_generate_database_formula_create_new_field(data_fixture): @pytest.mark.django_db -def test_generate_database_formula_update_existing_formula_field(data_fixture): +def test_generate_formula_update_existing_formula_field(data_fixture): """Test formula generation updates an existing formula field.""" user = data_fixture.create_user() @@ -468,20 +387,22 @@ def test_generate_database_formula_update_existing_formula_field(data_fixture): ) existing_field_id = existing_field.id - # Mock the udspy.ReAct to return a new formula - mock_prediction = MagicMock() - mock_prediction.is_formula_valid = True - mock_prediction.formula = "'new'" - mock_prediction.formula_type = "text" - mock_prediction.field_name = "test_formula" - mock_prediction.table_id = table.id - mock_prediction.error_message = "" + mock_result = _make_mock_formula_result( + table_id=table.id, + field_name="test_formula", + formula="'new'", + formula_type="text", + ) - with patch("udspy.ReAct") as mock_react: - mock_react.return_value.return_value = mock_prediction + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync" + ) as mock_agent: + mock_agent.return_value = mock_result - tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) - result = tool( + ctx = make_test_ctx(user, workspace) + result = generate_formula( + ctx, + thought="test", database_id=database.id, description="Return updated text", save_to_field=True, @@ -503,7 +424,7 @@ def test_generate_database_formula_update_existing_formula_field(data_fixture): @pytest.mark.django_db -def test_generate_database_formula_replace_non_formula_field(data_fixture): +def test_generate_formula_replace_non_formula_field(data_fixture): """Test formula generation replaces a non-formula field.""" user = data_fixture.create_user() @@ -518,20 +439,22 @@ def test_generate_database_formula_replace_non_formula_field(data_fixture): ) existing_field_id = existing_text_field.id - # Mock the udspy.ReAct to return a valid formula - mock_prediction = MagicMock() - mock_prediction.is_formula_valid = True - mock_prediction.formula = "'ok'" - mock_prediction.formula_type = "text" - mock_prediction.field_name = "test_formula" - mock_prediction.table_id = table.id - mock_prediction.error_message = "" + mock_result = _make_mock_formula_result( + table_id=table.id, + field_name="test_formula", + formula="'ok'", + formula_type="text", + ) - with patch("udspy.ReAct") as mock_react: - mock_react.return_value.return_value = mock_prediction + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync" + ) as mock_agent: + mock_agent.return_value = mock_result - tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) - result = tool( + ctx = make_test_ctx(user, workspace) + result = generate_formula( + ctx, + thought="test", database_id=database.id, description="Return a simple text", save_to_field=True, @@ -559,7 +482,7 @@ def test_generate_database_formula_replace_non_formula_field(data_fixture): @pytest.mark.django_db -def test_generate_database_formula_invalid_formula(data_fixture): +def test_generate_formula_invalid_formula(data_fixture): """Test error handling when formula generation fails.""" user = data_fixture.create_user() @@ -568,23 +491,27 @@ def test_generate_database_formula_invalid_formula(data_fixture): table = data_fixture.create_database_table(database=database, name="Test Table") data_fixture.create_text_field(table=table, name="text_field", primary=True) - # Mock the udspy.ReAct to return an invalid formula - mock_prediction = MagicMock() - mock_prediction.is_formula_valid = False - mock_prediction.formula = "" - mock_prediction.formula_type = "" - mock_prediction.field_name = "test_formula" - mock_prediction.table_id = table.id - mock_prediction.error_message = "Formula syntax error: invalid expression" + mock_result = _make_mock_formula_result( + table_id=table.id, + field_name="test_formula", + formula="", + formula_type="", + is_formula_valid=False, + error_message="Formula syntax error: invalid expression", + ) - with patch("udspy.ReAct") as mock_react: - mock_react.return_value.return_value = mock_prediction + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync" + ) as mock_agent: + mock_agent.return_value = mock_result - tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) + ctx = make_test_ctx(user, workspace) # Verify exception is raised with pytest.raises(Exception) as exc_info: - tool( + generate_formula( + ctx, + thought="test", database_id=database.id, description="Invalid formula test", save_to_field=True, @@ -598,7 +525,7 @@ def test_generate_database_formula_invalid_formula(data_fixture): @pytest.mark.django_db -def test_generate_database_formula_documentation_completeness(data_fixture): +def test_generate_formula_documentation_completeness(data_fixture): """Test that formula documentation contains all required functions.""" user = data_fixture.create_user() @@ -607,39 +534,39 @@ def test_generate_database_formula_documentation_completeness(data_fixture): table = data_fixture.create_database_table(database=database, name="Test Table") data_fixture.create_text_field(table=table, name="text_field", primary=True) - # Mock the udspy.ReAct to capture the formula_documentation argument - mock_prediction = MagicMock() - mock_prediction.is_formula_valid = True - mock_prediction.formula = "'ok'" - mock_prediction.formula_type = "text" - mock_prediction.field_name = "test_formula" - mock_prediction.table_id = table.id - mock_prediction.error_message = "" - - captured_formula_docs = None - - class MockReAct: - def __init__(self, signature, tools=None, max_iters=10): - nonlocal captured_formula_docs - # Don't capture anything here - wait for the call - self.mock_instance = MagicMock(return_value=mock_prediction) - - def __call__(self, **kwargs): - nonlocal captured_formula_docs - captured_formula_docs = kwargs.get("formula_documentation") - return mock_prediction - - with patch("udspy.ReAct", MockReAct): - tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) - tool( + mock_result = _make_mock_formula_result( + table_id=table.id, + field_name="test_formula", + formula="'ok'", + formula_type="text", + ) + + captured_prompt = None + + def mock_run_sync(prompt, **kwargs): + nonlocal captured_prompt + captured_prompt = prompt + return mock_result + + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync", + side_effect=mock_run_sync, + ): + ctx = make_test_ctx(user, workspace) + generate_formula( + ctx, + thought="test", database_id=database.id, description="Test documentation", save_to_field=False, ) - # Verify formula_documentation was provided - assert captured_formula_docs is not None - assert len(captured_formula_docs) > 0 + # Verify formula documentation was included in the prompt + assert captured_prompt is not None + assert len(captured_prompt) > 0 + + # The formula_documentation is now embedded in the prompt string + captured_formula_docs = captured_prompt # Known exceptions (internal functions not documented) formula_exceptions = [ @@ -689,3 +616,204 @@ def __call__(self, **kwargs): assert func in captured_formula_docs, ( f"Expected function '{func}' not found in documentation" ) + + +@pytest.mark.django_db +def test_formula_field_validation_raises_on_invalid_formula(data_fixture): + """Invalid formula in to_django_orm_kwargs raises InvalidFormulaFieldError.""" + + table = data_fixture.create_database_table(name="Test") + data_fixture.create_text_field(table=table, name="Name", primary=True) + + item = FieldItemCreate( + name="Bad Formula", + type="formula", + formula="this is not a valid formula!!!", + ) + with pytest.raises(InvalidFormulaFieldError) as exc_info: + item.to_django_orm_kwargs(table) + + assert exc_info.value.field_name == "Bad Formula" + assert exc_info.value.formula == "this is not a valid formula!!!" + assert exc_info.value.table == table + + +@pytest.mark.django_db +def test_formula_field_validation_passes_for_valid_formula(data_fixture): + """Valid formula in to_django_orm_kwargs returns kwargs without error.""" + + table = data_fixture.create_database_table(name="Test") + data_fixture.create_text_field(table=table, name="Name", primary=True) + + item = FieldItemCreate( + name="Good Formula", + type="formula", + formula="field('Name')", + ) + result = item.to_django_orm_kwargs(table) + assert result == {"name": "Good Formula", "formula": "field('Name')"} + + +@pytest.mark.django_db +def test_formula_field_validation_passes_for_empty_formula(data_fixture): + """Empty formula string skips validation.""" + + table = data_fixture.create_database_table(name="Test") + + item = FieldItemCreate( + name="Empty Formula", + type="formula", + formula="", + ) + result = item.to_django_orm_kwargs(table) + assert result == {"name": "Empty Formula", "formula": ""} + + +@pytest.mark.django_db +def test_create_fields_tool_with_invalid_formula_auto_fixes(data_fixture): + """ + When a formula field has an invalid formula, create_fields + auto-fixes it via the formula generation pipeline. + """ + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Test") + data_fixture.create_text_field(table=table, name="Name", primary=True) + + mock_result = _make_mock_formula_result( + table_id=table.id, + field_name="Fixed Formula", + formula="field('Name')", + formula_type="text", + ) + + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync" + ) as mock_agent: + mock_agent.return_value = mock_result + + ctx = make_test_ctx(user, workspace) + result = create_fields( + ctx, + thought="test", + table_id=table.id, + fields=[ + FieldItemCreate(name="Description", type="text"), + FieldItemCreate( + name="Bad Formula", + type="formula", + formula="invalid_stuff!!!", + ), + ], + ) + + # The text field should be created successfully + assert len(result["created_fields"]) == 2 + # No formula errors since auto-fix succeeded + assert "formula_errors" not in result + + # Verify the formula field was created with the original name and fixed formula + formula_field = table.field_set.filter(name="Bad Formula").first() + assert formula_field is not None + + +@pytest.mark.django_db +def test_create_fields_tool_reports_error_when_auto_fix_fails(data_fixture): + """ + When auto-fix also fails, create_fields reports the error + without failing the entire batch. + """ + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Test") + data_fixture.create_text_field(table=table, name="Name", primary=True) + + mock_result = _make_mock_formula_result( + is_formula_valid=False, + error_message="Could not fix formula", + ) + + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync" + ) as mock_agent: + mock_agent.return_value = mock_result + + ctx = make_test_ctx(user, workspace) + result = create_fields( + ctx, + thought="test", + table_id=table.id, + fields=[ + FieldItemCreate(name="Description", type="text"), + FieldItemCreate( + name="Bad Formula", + type="formula", + formula="invalid_stuff!!!", + ), + ], + ) + + # The text field should still be created successfully + assert len(result["created_fields"]) == 1 + assert result["created_fields"][0]["name"] == "Description" + + # Formula errors should be reported + assert len(result["formula_errors"]) == 1 + assert result["formula_errors"][0]["field_name"] == "Bad Formula" + assert "hint" in result["formula_errors"][0] + + +@pytest.mark.django_db +def test_create_tables_with_invalid_formula_auto_fixes(data_fixture): + """ + When create_tables encounters an invalid formula, it auto-fixes + via the formula generation pipeline. + """ + + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + + def mock_run_sync(prompt, **kwargs): + # The table doesn't exist yet when the mock is created, so we + # dynamically set table_id on call. + tables = Table.objects.filter(database=database).order_by("-id") + return _make_mock_formula_result( + table_id=tables.first().id, + field_name="My Formula", + formula="'fixed'", + formula_type="text", + ) + + with patch( + "baserow_enterprise.assistant.tools.database.tools.formula_generation_agent.run_sync", + side_effect=mock_run_sync, + ): + ctx = make_test_ctx(user, workspace) + result = create_tables( + ctx, + thought="test", + database_id=database.id, + tables=[ + TableItemCreate( + name="Test Table", + primary_field_name="Name", + fields=[ + FieldItemCreate( + name="My Formula", + type="formula", + formula="bad formula!!!", + ), + ], + ) + ], + add_sample_rows=False, + ) + + assert len(result["created_tables"]) == 1 + # No formula error notes since auto-fix succeeded + assert result["notes"] == [] diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_tools.py.skip b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_tools.py.skip deleted file mode 100644 index 2c74da289e..0000000000 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_tools.py.skip +++ /dev/null @@ -1,52 +0,0 @@ -import pytest - -from baserow.contrib.database.models import Database -from baserow.test_utils.helpers import AnyInt -from baserow_enterprise.assistant.tools.database.tools import ( - get_create_database_tool, - get_list_databases_tool, -) - -from .utils import fake_tool_helpers - - -@pytest.mark.django_db -def test_list_databases_tool(data_fixture): - user = data_fixture.create_user() - workspace = data_fixture.create_workspace(user=user) - database = data_fixture.create_database_application( - workspace=workspace, name="Database 1" - ) - - tool = get_list_databases_tool(user, workspace, fake_tool_helpers) - response = tool() - - assert response == {"databases": [{"id": database.id, "name": "Database 1"}]} - - database_2 = data_fixture.create_database_application( - workspace=workspace, name="Database 2" - ) - response = tool() - assert response == { - "databases": [ - {"id": database.id, "name": "Database 1"}, - {"id": database_2.id, "name": "Database 2"}, - ] - } - - -@pytest.mark.django_db -def test_create_database_tool(data_fixture): - user = data_fixture.create_user() - workspace = data_fixture.create_workspace(user=user) - - tool = get_create_database_tool(user, workspace, fake_tool_helpers) - response = tool(name="New Database") - - assert response == {"created_database": {"id": AnyInt(), "name": "New Database"}} - - # Ensure the database was actually created - - assert Database.objects.filter( - id=response["created_database"]["id"], name="New Database" - ).exists() diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_view_filters_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_view_filters_tools.py index 88726637f9..273df74d86 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_view_filters_tools.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_view_filters_tools.py @@ -1,41 +1,20 @@ import pytest from baserow.contrib.database.views.models import ViewFilter -from baserow_enterprise.assistant.tools.database.types import ( - BooleanIsViewFilterItemCreate, - DateAfterViewFilterItemCreate, - DateBeforeViewFilterItemCreate, - DateEqualsViewFilterItemCreate, - DateNotEqualsViewFilterItemCreate, - LinkRowHasNotViewFilterItemCreate, - LinkRowHasViewFilterItemCreate, - MultipleSelectIsAnyViewFilterItemCreate, - MultipleSelectIsNoneOfNotViewFilterItemCreate, - NumberEmptyViewFilterItemCreate, - NumberEqualsViewFilterItemCreate, - NumberHigherThanViewFilterItemCreate, - NumberLowerThanViewFilterItemCreate, - NumberNotEmptyViewFilterItemCreate, - NumberNotEqualsViewFilterItemCreate, - SingleSelectIsAnyViewFilterItemCreate, - SingleSelectIsNoneOfNotViewFilterItemCreate, - TextContainsViewFilterItemCreate, - TextEmptyViewFilterItemCreate, - TextEqualViewFilterItemCreate, - TextNotContainsViewFilterItemCreate, - TextNotEmptyViewFilterItemCreate, - TextNotEqualViewFilterItemCreate, -) -from baserow_enterprise.assistant.tools.database.types.base import Date +from baserow_enterprise.assistant.tools.database.helpers import create_view_filter from baserow_enterprise.assistant.tools.database.types.view_filters import ( ViewFilterItemCreate, ) -from baserow_enterprise.assistant.tools.database.utils import create_view_filter + + +def _make_filter(field_id, **kwargs): + """Shortcut to build a ViewFilterItemCreate.""" + return ViewFilterItemCreate(field_id=field_id, **kwargs) @pytest.mark.django_db def test_all_text_filters_conversion(data_fixture): - """Test all text filter types can be converted to Baserow filters.""" + """Test all text filter operators can be converted to Baserow filters.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -46,52 +25,29 @@ def test_all_text_filters_conversion(data_fixture): table_fields = {field.id: field} text_filters = [ + ({"type": "text", "operator": "equal", "value": "test"}, "equal", "test"), ( - TextEqualViewFilterItemCreate( - field_id=field.id, type="text", operator="equal", value="test" - ), - "equal", - "test", - ), - ( - TextNotEqualViewFilterItemCreate( - field_id=field.id, type="text", operator="not_equal", value="test" - ), + {"type": "text", "operator": "not_equal", "value": "test"}, "not_equal", "test", ), ( - TextContainsViewFilterItemCreate( - field_id=field.id, type="text", operator="contains", value="keyword" - ), + {"type": "text", "operator": "contains", "value": "keyword"}, "contains", "keyword", ), ( - TextNotContainsViewFilterItemCreate( - field_id=field.id, type="text", operator="contains_not", value="spam" - ), + {"type": "text", "operator": "contains_not", "value": "spam"}, "contains_not", "spam", ), - ( - TextEmptyViewFilterItemCreate( - field_id=field.id, type="text", operator="empty", value="" - ), - "empty", - "", - ), - ( - TextNotEmptyViewFilterItemCreate( - field_id=field.id, type="text", operator="not_empty", value="" - ), - "not_empty", - "", - ), + ({"type": "text", "operator": "empty", "value": ""}, "empty", ""), + ({"type": "text", "operator": "not_empty", "value": ""}, "not_empty", ""), ] - for filter_create, expected_type, expected_value in text_filters: - created_filter = create_view_filter(user, view, table_fields, filter_create) + for kwargs, expected_type, expected_value in text_filters: + filter_item = _make_filter(field.id, **kwargs) + created_filter = create_view_filter(user, view, table_fields, filter_item) assert created_filter is not None assert created_filter.view.id == view.id @@ -99,7 +55,6 @@ def test_all_text_filters_conversion(data_fixture): assert created_filter.type == expected_type assert created_filter.value == expected_value - # Verify in database assert ViewFilter.objects.filter( view=view, field=field, type=expected_type ).exists() @@ -107,7 +62,7 @@ def test_all_text_filters_conversion(data_fixture): @pytest.mark.django_db def test_all_number_filters_conversion(data_fixture): - """Test all number filter types can be converted to Baserow filters.""" + """Test all number filter operators can be converted to Baserow filters.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -119,59 +74,80 @@ def test_all_number_filters_conversion(data_fixture): number_filters = [ ( - NumberEqualsViewFilterItemCreate( - field_id=field.id, type="number", operator="equal", value=42.0 - ), + {"type": "number", "operator": "equal", "value": 42.0, "or_equal": False}, "equal", "42.0", ), ( - NumberNotEqualsViewFilterItemCreate( - field_id=field.id, type="number", operator="not_equal", value=0.0 - ), + { + "type": "number", + "operator": "not_equal", + "value": 0.0, + "or_equal": False, + }, "not_equal", "0.0", ), ( - NumberHigherThanViewFilterItemCreate( - field_id=field.id, - type="number", - operator="higher_than", - value=100.0, - or_equal=False, - ), + { + "type": "number", + "operator": "higher_than", + "value": 100.0, + "or_equal": False, + }, "higher_than", "100.0", ), ( - NumberLowerThanViewFilterItemCreate( - field_id=field.id, - type="number", - operator="lower_than", - value=50.0, - or_equal=False, - ), + { + "type": "number", + "operator": "higher_than", + "value": 100.0, + "or_equal": True, + }, + "higher_than_or_equal", + "100.0", + ), + ( + { + "type": "number", + "operator": "lower_than", + "value": 50.0, + "or_equal": False, + }, "lower_than", "50.0", ), ( - NumberEmptyViewFilterItemCreate( - field_id=field.id, type="number", operator="empty", value=0.0 - ), + { + "type": "number", + "operator": "lower_than", + "value": 50.0, + "or_equal": True, + }, + "lower_than_or_equal", + "50.0", + ), + ( + {"type": "number", "operator": "empty", "value": 0.0, "or_equal": False}, "empty", "0.0", ), ( - NumberNotEmptyViewFilterItemCreate( - field_id=field.id, type="number", operator="not_empty", value=0.0 - ), + { + "type": "number", + "operator": "not_empty", + "value": 0.0, + "or_equal": False, + }, "not_empty", "0.0", ), ] - for filter_create, expected_type, expected_value in number_filters: - created_filter = create_view_filter(user, view, table_fields, filter_create) + for kwargs, expected_type, expected_value in number_filters: + filter_item = _make_filter(field.id, **kwargs) + created_filter = create_view_filter(user, view, table_fields, filter_item) assert created_filter is not None assert created_filter.type == expected_type @@ -183,7 +159,7 @@ def test_all_number_filters_conversion(data_fixture): @pytest.mark.django_db def test_all_date_filters_conversion(data_fixture): - """Test all date filter types can be converted to Baserow filters.""" + """Test all date filter operators can be converted to Baserow filters.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -194,80 +170,86 @@ def test_all_date_filters_conversion(data_fixture): table_fields = {field.id: field} # Test with exact date - date_filter = DateEqualsViewFilterItemCreate( - field_id=field.id, + filter_item = _make_filter( + field.id, type="date", operator="equal", - value=Date(year=2024, month=1, day=15), + value="2024-01-15", mode="exact_date", + or_equal=False, ) - created_filter = create_view_filter(user, view, table_fields, date_filter) - assert created_filter.type == "date_is" - assert "2024-01-15" in created_filter.value - assert created_filter.value.endswith("?exact_date") + created = create_view_filter(user, view, table_fields, filter_item) + assert created.type == "date_is" + assert "2024-01-15" in created.value + assert created.value.endswith("?exact_date") # Test with relative date (today) - date_filter2 = DateNotEqualsViewFilterItemCreate( - field_id=field.id, type="date", operator="not_equal", value=None, mode="today" + filter_item2 = _make_filter( + field.id, + type="date", + operator="not_equal", + value=None, + mode="today", + or_equal=False, ) - created_filter2 = create_view_filter(user, view, table_fields, date_filter2) - assert created_filter2.type == "date_is_not" - assert created_filter2.value.endswith("??today") + created2 = create_view_filter(user, view, table_fields, filter_item2) + assert created2.type == "date_is_not" + assert created2.value.endswith("??today") # Test date_is_after - date_filter3 = DateAfterViewFilterItemCreate( - field_id=field.id, + filter_item3 = _make_filter( + field.id, type="date", operator="after", value=7, mode="nr_days_ago", or_equal=False, ) - created_filter3 = create_view_filter(user, view, table_fields, date_filter3) - assert created_filter3.type == "date_is_after" - assert "?7?" in created_filter3.value - assert created_filter3.value.endswith("nr_days_ago") + created3 = create_view_filter(user, view, table_fields, filter_item3) + assert created3.type == "date_is_after" + assert "?7?" in created3.value + assert created3.value.endswith("nr_days_ago") # Test date_is_on_or_after - date_filter4 = DateAfterViewFilterItemCreate( - field_id=field.id, + filter_item4 = _make_filter( + field.id, type="date", operator="after", value=30, mode="nr_days_from_now", or_equal=True, ) - created_filter4 = create_view_filter(user, view, table_fields, date_filter4) - assert created_filter4.type == "date_is_on_or_after" + created4 = create_view_filter(user, view, table_fields, filter_item4) + assert created4.type == "date_is_on_or_after" # Test date_is_before - date_filter5 = DateBeforeViewFilterItemCreate( - field_id=field.id, + filter_item5 = _make_filter( + field.id, type="date", operator="before", value=None, mode="tomorrow", or_equal=False, ) - created_filter5 = create_view_filter(user, view, table_fields, date_filter5) - assert created_filter5.type == "date_is_before" + created5 = create_view_filter(user, view, table_fields, filter_item5) + assert created5.type == "date_is_before" # Test date_is_on_or_before - date_filter6 = DateBeforeViewFilterItemCreate( - field_id=field.id, + filter_item6 = _make_filter( + field.id, type="date", operator="before", value=14, mode="nr_weeks_from_now", or_equal=True, ) - created_filter6 = create_view_filter(user, view, table_fields, date_filter6) - assert created_filter6.type == "date_is_on_or_before" + created6 = create_view_filter(user, view, table_fields, filter_item6) + assert created6.type == "date_is_on_or_before" @pytest.mark.django_db def test_all_single_select_filters_conversion(data_fixture): - """Test all single select filter types can be converted to Baserow filters.""" + """Test all single select filter operators can be converted to Baserow filters.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -281,45 +263,38 @@ def test_all_single_select_filters_conversion(data_fixture): table_fields = {field.id: field} # Test is_any_of - filter_create = SingleSelectIsAnyViewFilterItemCreate( - field_id=field.id, + filter_item = _make_filter( + field.id, type="single_select", operator="is_any_of", value=["Active", "Pending"], ) - created_filter = create_view_filter(user, view, table_fields, filter_create) - assert created_filter.type == "single_select_is_any_of" - # Value should contain option IDs - option_ids = created_filter.value.split(",") + created = create_view_filter(user, view, table_fields, filter_item) + assert created.type == "single_select_is_any_of" + option_ids = created.value.split(",") assert str(option1.id) in option_ids assert str(option2.id) in option_ids assert len(option_ids) == 2 # Test case insensitive matching - filter_create2 = SingleSelectIsAnyViewFilterItemCreate( - field_id=field.id, - type="single_select", - operator="is_any_of", - value=["active"], # lowercase + filter_item2 = _make_filter( + field.id, type="single_select", operator="is_any_of", value=["active"] ) - created_filter2 = create_view_filter(user, view, table_fields, filter_create2) - assert str(option1.id) in created_filter2.value + created2 = create_view_filter(user, view, table_fields, filter_item2) + assert str(option1.id) in created2.value # Test is_none_of - filter_create3 = SingleSelectIsNoneOfNotViewFilterItemCreate( - field_id=field.id, - type="single_select", - operator="is_none_of", - value=["Inactive"], + filter_item3 = _make_filter( + field.id, type="single_select", operator="is_none_of", value=["Inactive"] ) - created_filter3 = create_view_filter(user, view, table_fields, filter_create3) - assert created_filter3.type == "single_select_is_none_of" - assert str(option3.id) in created_filter3.value + created3 = create_view_filter(user, view, table_fields, filter_item3) + assert created3.type == "single_select_is_none_of" + assert str(option3.id) in created3.value @pytest.mark.django_db def test_all_multiple_select_filters_conversion(data_fixture): - """Test all multiple select filter types can be converted to Baserow filters.""" + """Test all multiple select filter operators can be converted to Baserow filters.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -333,28 +308,25 @@ def test_all_multiple_select_filters_conversion(data_fixture): table_fields = {field.id: field} # Test is_any_of (has) - filter_create = MultipleSelectIsAnyViewFilterItemCreate( - field_id=field.id, + filter_item = _make_filter( + field.id, type="multiple_select", operator="is_any_of", value=["Important", "Urgent"], ) - created_filter = create_view_filter(user, view, table_fields, filter_create) - assert created_filter.type == "multiple_select_has" - option_ids = created_filter.value.split(",") + created = create_view_filter(user, view, table_fields, filter_item) + assert created.type == "multiple_select_has" + option_ids = created.value.split(",") assert str(option1.id) in option_ids assert str(option2.id) in option_ids # Test is_none_of (has_not) - filter_create2 = MultipleSelectIsNoneOfNotViewFilterItemCreate( - field_id=field.id, - type="multiple_select", - operator="is_none_of", - value=["Archived"], + filter_item2 = _make_filter( + field.id, type="multiple_select", operator="is_none_of", value=["Archived"] ) - created_filter2 = create_view_filter(user, view, table_fields, filter_create2) - assert created_filter2.type == "multiple_select_has_not" - assert str(option3.id) in created_filter2.value + created2 = create_view_filter(user, view, table_fields, filter_item2) + assert created2.type == "multiple_select_has_not" + assert str(option3.id) in created2.value @pytest.mark.django_db @@ -362,7 +334,7 @@ def test_all_multiple_select_filters_conversion(data_fixture): reason="Link row filters have a bug in Baserow (UnboundLocalError in view_filters.py:1301)" ) def test_all_link_row_filters_conversion(data_fixture): - """Test all link row filter types can be converted to Baserow filters.""" + """Test all link row filter operators can be converted to Baserow filters.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -374,25 +346,23 @@ def test_all_link_row_filters_conversion(data_fixture): table_fields = {field.id: field} # Test link_row_has - filter_create = LinkRowHasViewFilterItemCreate( - field_id=field.id, type="link_row", operator="has", value=123 - ) - created_filter = create_view_filter(user, view, table_fields, filter_create) - assert created_filter.type == "link_row_has" - assert created_filter.value == "123" + filter_item = _make_filter(field.id, type="link_row", operator="has", value=123) + created = create_view_filter(user, view, table_fields, filter_item) + assert created.type == "link_row_has" + assert created.value == "123" # Test link_row_has_not - filter_create2 = LinkRowHasNotViewFilterItemCreate( - field_id=field.id, type="link_row", operator="has_not", value=456 + filter_item2 = _make_filter( + field.id, type="link_row", operator="has_not", value=456 ) - created_filter2 = create_view_filter(user, view, table_fields, filter_create2) - assert created_filter2.type == "link_row_has_not" - assert created_filter2.value == "456" + created2 = create_view_filter(user, view, table_fields, filter_item2) + assert created2.type == "link_row_has_not" + assert created2.value == "456" @pytest.mark.django_db def test_all_boolean_filters_conversion(data_fixture): - """Test all boolean filter types can be converted to Baserow filters.""" + """Test all boolean filter operators can be converted to Baserow filters.""" user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) @@ -403,109 +373,30 @@ def test_all_boolean_filters_conversion(data_fixture): table_fields = {field.id: field} # Test is true - filter_create = BooleanIsViewFilterItemCreate( - field_id=field.id, type="boolean", operator="is", value=True - ) - created_filter = create_view_filter(user, view, table_fields, filter_create) - assert created_filter.type == "boolean" - assert created_filter.value == "1" + filter_item = _make_filter(field.id, type="boolean", operator="equal", value=True) + created = create_view_filter(user, view, table_fields, filter_item) + assert created.type == "equal" + assert created.value == "1" # Test is false - filter_create2 = BooleanIsViewFilterItemCreate( - field_id=field.id, type="boolean", operator="is", value=False - ) - created_filter2 = create_view_filter(user, view, table_fields, filter_create2) - assert created_filter2.type == "boolean" - assert created_filter2.value == "0" - - -def get_all_concrete_filter_classes(): - """ - Recursively find all concrete ViewFilterItemCreate subclasses. Concrete classes are - those that have specific operators and are meant to be instantiated. - """ - - def get_all_subclasses(cls): - all_subclasses = [] - for subclass in cls.__subclasses__(): - all_subclasses.append(subclass) - all_subclasses.extend(get_all_subclasses(subclass)) - return all_subclasses - - all_subclasses = get_all_subclasses(ViewFilterItemCreate) - - # Filter to only concrete classes (those with specific operators defined as Literal) - # These are the classes that end with "Create" and have a specific operator - concrete_classes = [] - for cls in all_subclasses: - # Check if this class defines a specific operator (has Literal type annotation) - if hasattr(cls, "__annotations__") and "operator" in cls.__annotations__: - annotation = cls.__annotations__["operator"] - # Check if it's a Literal type (concrete operator) - if hasattr(annotation, "__origin__") or "Literal" in str(annotation): - concrete_classes.append(cls) - - return concrete_classes - - -def test_filter_class_discovery(): - """ - Test that the filter class discovery mechanism works correctly. This ensures our - introspection logic properly identifies concrete filter classes. - """ - - all_concrete_classes = get_all_concrete_filter_classes() - - # Verify we found a reasonable number of filter classes - # As of now, there should be at least 20+ concrete filter classes - assert len(all_concrete_classes) >= 20, ( - f"Expected at least 20 concrete filter classes, found {len(all_concrete_classes)}. " - f"Classes found: {[cls.__name__ for cls in all_concrete_classes]}" - ) - - # Verify that known concrete classes are discovered - class_names = {cls.__name__ for cls in all_concrete_classes} - expected_classes = { - "TextEqualViewFilterItemCreate", - "NumberEqualsViewFilterItemCreate", - "DateEqualsViewFilterItemCreate", - "BooleanIsViewFilterItemCreate", - "LinkRowHasViewFilterItemCreate", - "SingleSelectIsAnyViewFilterItemCreate", - "MultipleSelectIsAnyViewFilterItemCreate", - } - - missing = expected_classes - class_names - assert not missing, f"Expected classes not found: {missing}" - - # Verify that base/intermediate classes are NOT included - excluded_classes = { - "ViewFilterItemCreate", - "TextViewFilterItemCreate", - "NumberViewFilterItemCreate", - "DateViewFilterItemCreate", - } - - found_excluded = excluded_classes & class_names - assert not found_excluded, ( - f"Base/intermediate classes should not be included: {found_excluded}" - ) + filter_item2 = _make_filter(field.id, type="boolean", operator="equal", value=False) + created2 = create_view_filter(user, view, table_fields, filter_item2) + assert created2.type == "equal" + assert created2.value == "0" @pytest.mark.django_db def test_comprehensive_all_filter_types_conversion(data_fixture): """ - Comprehensive test ensuring ALL filter types can be successfully converted to - Baserow filters with a table containing all supported field types. + Comprehensive test ensuring all filter config types can be successfully + converted to Baserow filters with a table containing all supported field types. """ - # Setup user = data_fixture.create_user() workspace = data_fixture.create_workspace(user=user) database = data_fixture.create_database_application(workspace=workspace) table = data_fixture.create_database_table(database=database, name="All Fields") - # Create all field types text_field = data_fixture.create_text_field(table=table, name="Text", primary=True) number_field = data_fixture.create_number_field(table=table, name="Number") date_field = data_fixture.create_date_field(table=table, name="Date") @@ -513,16 +404,9 @@ def test_comprehensive_all_filter_types_conversion(data_fixture): single_select = data_fixture.create_single_select_field(table=table, name="Status") multi_select = data_fixture.create_multiple_select_field(table=table, name="Tags") - linked_table = data_fixture.create_database_table(database=database, name="Linked") - data_fixture.create_text_field(table=linked_table, name="Linked Text", primary=True) - link_field = data_fixture.create_link_row_field( - table=table, link_row_table=linked_table - ) - data_fixture.create_select_option(field=single_select, value="Active", order=1) data_fixture.create_select_option(field=multi_select, value="Important", order=1) - # Create view and table_fields dict view = data_fixture.create_grid_view(table=table) table_fields = { text_field.id: text_field, @@ -531,82 +415,78 @@ def test_comprehensive_all_filter_types_conversion(data_fixture): boolean_field.id: boolean_field, single_select.id: single_select, multi_select.id: multi_select, - link_field.id: link_field, } - # List of all filter types to test all_filters = [ # Text filters - TextEqualViewFilterItemCreate( - field_id=text_field.id, type="text", operator="equal", value="test" - ), - TextNotEqualViewFilterItemCreate( - field_id=text_field.id, type="text", operator="not_equal", value="test" - ), - TextContainsViewFilterItemCreate( - field_id=text_field.id, type="text", operator="contains", value="test" - ), - TextNotContainsViewFilterItemCreate( - field_id=text_field.id, type="text", operator="contains_not", value="test" - ), - TextEmptyViewFilterItemCreate( - field_id=text_field.id, type="text", operator="empty", value="" - ), - TextNotEmptyViewFilterItemCreate( - field_id=text_field.id, type="text", operator="not_empty", value="" - ), + _make_filter(text_field.id, type="text", operator="equal", value="test"), + _make_filter(text_field.id, type="text", operator="not_equal", value="test"), + _make_filter(text_field.id, type="text", operator="contains", value="test"), + _make_filter(text_field.id, type="text", operator="contains_not", value="test"), + _make_filter(text_field.id, type="text", operator="empty", value=""), + _make_filter(text_field.id, type="text", operator="not_empty", value=""), # Number filters - NumberEqualsViewFilterItemCreate( - field_id=number_field.id, type="number", operator="equal", value=42.0 + _make_filter( + number_field.id, type="number", operator="equal", value=42.0, or_equal=False ), - NumberNotEqualsViewFilterItemCreate( - field_id=number_field.id, type="number", operator="not_equal", value=0.0 + _make_filter( + number_field.id, + type="number", + operator="not_equal", + value=0.0, + or_equal=False, ), - NumberHigherThanViewFilterItemCreate( - field_id=number_field.id, + _make_filter( + number_field.id, type="number", operator="higher_than", value=10.0, or_equal=False, ), - NumberLowerThanViewFilterItemCreate( - field_id=number_field.id, + _make_filter( + number_field.id, type="number", operator="lower_than", value=100.0, or_equal=True, ), - NumberEmptyViewFilterItemCreate( - field_id=number_field.id, type="number", operator="empty", value=0.0 + _make_filter( + number_field.id, type="number", operator="empty", value=0.0, or_equal=False ), - NumberNotEmptyViewFilterItemCreate( - field_id=number_field.id, type="number", operator="not_empty", value=0.0 + _make_filter( + number_field.id, + type="number", + operator="not_empty", + value=0.0, + or_equal=False, ), # Date filters - DateEqualsViewFilterItemCreate( - field_id=date_field.id, + _make_filter( + date_field.id, type="date", operator="equal", - value=Date(year=2024, month=1, day=1), + value="2024-01-01", mode="exact_date", + or_equal=False, ), - DateNotEqualsViewFilterItemCreate( - field_id=date_field.id, + _make_filter( + date_field.id, type="date", operator="not_equal", value=None, mode="today", + or_equal=False, ), - DateAfterViewFilterItemCreate( - field_id=date_field.id, + _make_filter( + date_field.id, type="date", operator="after", value=7, mode="nr_days_ago", or_equal=False, ), - DateBeforeViewFilterItemCreate( - field_id=date_field.id, + _make_filter( + date_field.id, type="date", operator="before", value=None, @@ -614,44 +494,34 @@ def test_comprehensive_all_filter_types_conversion(data_fixture): or_equal=True, ), # Select filters - SingleSelectIsAnyViewFilterItemCreate( - field_id=single_select.id, + _make_filter( + single_select.id, type="single_select", operator="is_any_of", value=["Active"], ), - SingleSelectIsNoneOfNotViewFilterItemCreate( - field_id=single_select.id, + _make_filter( + single_select.id, type="single_select", operator="is_none_of", value=["Active"], ), - MultipleSelectIsAnyViewFilterItemCreate( - field_id=multi_select.id, + _make_filter( + multi_select.id, type="multiple_select", operator="is_any_of", value=["Important"], ), - MultipleSelectIsNoneOfNotViewFilterItemCreate( - field_id=multi_select.id, + _make_filter( + multi_select.id, type="multiple_select", operator="is_none_of", value=["Important"], ), - # Link row filters - LinkRowHasViewFilterItemCreate( - field_id=link_field.id, type="link_row", operator="has", value=1 - ), - LinkRowHasNotViewFilterItemCreate( - field_id=link_field.id, type="link_row", operator="has_not", value=2 - ), # Boolean filter - BooleanIsViewFilterItemCreate( - field_id=boolean_field.id, type="boolean", operator="is", value=True - ), + _make_filter(boolean_field.id, type="boolean", operator="equal", value=True), ] - # Test that all filters can be created successfully created_filters = [] for filter_item in all_filters: try: @@ -662,11 +532,9 @@ def test_comprehensive_all_filter_types_conversion(data_fixture): except Exception as e: pytest.fail(f"Failed to create filter {filter_item}: {e}") - # Verify all filters were created in the database assert len(created_filters) == len(all_filters) assert ViewFilter.objects.filter(view=view).count() == len(all_filters) - # Verify each filter type is represented filter_types = set(f.type for f in created_filters) expected_types = { "equal", @@ -676,7 +544,7 @@ def test_comprehensive_all_filter_types_conversion(data_fixture): "empty", "not_empty", "higher_than", - "lower_than", + "lower_than_or_equal", "date_is", "date_is_not", "date_is_after", @@ -685,29 +553,6 @@ def test_comprehensive_all_filter_types_conversion(data_fixture): "single_select_is_none_of", "multiple_select_has", "multiple_select_has_not", - "link_row_has", - "link_row_has_not", - "boolean", + "equal", # for boolean field } assert filter_types == expected_types - - # CRITICAL CHECK: Ensure all concrete filter classes are tested - all_concrete_classes = get_all_concrete_filter_classes() - tested_classes = {type(filter_item) for filter_item in all_filters} - - missing_classes = set(all_concrete_classes) - tested_classes - if missing_classes: - missing_names = [cls.__name__ for cls in missing_classes] - pytest.fail( - f"The following filter classes are not tested: {', '.join(missing_names)}. " - f"Please add test instances for these classes to the all_filters list." - ) - - # Ensure we're not testing non-existent classes - extra_classes = tested_classes - set(all_concrete_classes) - if extra_classes: - extra_names = [cls.__name__ for cls in extra_classes] - pytest.fail( - f"The following classes in the test don't exist as concrete filter classes: " - f"{', '.join(extra_names)}. Please remove them from the test." - ) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py index bf3e940d08..4ac15bd2e7 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py @@ -1,91 +1,21 @@ -from unittest.mock import Mock - import pytest -from udspy.module.callbacks import ModuleContext, is_module_callback from baserow.contrib.database.views.models import View, ViewFilter from baserow_enterprise.assistant.tools.database.tools import ( - get_list_views_tool, - get_views_tool_factory, + create_view_filters, + create_views, + list_views, ) from baserow_enterprise.assistant.tools.database.types import ( - BooleanIsViewFilterItemCreate, - CalendarViewItemCreate, - DateAfterViewFilterItemCreate, - DateBeforeViewFilterItemCreate, - DateEqualsViewFilterItemCreate, - DateNotEqualsViewFilterItemCreate, FormFieldOption, - FormViewItemCreate, - GalleryViewItemCreate, - GridViewItemCreate, - KanbanViewItemCreate, - MultipleSelectIsAnyViewFilterItemCreate, - MultipleSelectIsNoneOfNotViewFilterItemCreate, - NumberEqualsViewFilterItemCreate, - NumberHigherThanViewFilterItemCreate, - NumberLowerThanViewFilterItemCreate, - NumberNotEqualsViewFilterItemCreate, - SingleSelectIsAnyViewFilterItemCreate, - SingleSelectIsNoneOfNotViewFilterItemCreate, - TextContainsViewFilterItemCreate, - TextEqualViewFilterItemCreate, - TextNotContainsViewFilterItemCreate, - TextNotEqualViewFilterItemCreate, - TimelineViewItemCreate, + ViewItemCreate, ) -from baserow_enterprise.assistant.tools.database.types.base import Date from baserow_enterprise.assistant.tools.database.types.view_filters import ( + ViewFilterItemCreate, ViewFiltersArgs, ) -from .utils import fake_tool_helpers - - -def get_create_views_tool(user, workspace): - """Helper to get the create_views tool from the factory""" - - factory = get_views_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - create_views_tool = next( - (tool for tool in added_tools if tool.name == "create_views"), None - ) - assert create_views_tool is not None - return create_views_tool - - -def get_create_view_filters_tool(user, workspace): - """Helper to get the create_view_filters tool from the factory""" - - factory = get_views_tool_factory(user, workspace, fake_tool_helpers) - assert callable(factory) - - tools_upgrade = factory() - assert is_module_callback(tools_upgrade) - - mock_module = Mock() - mock_module._tools = [] - mock_module.init_module = Mock() - tools_upgrade(ModuleContext(module=mock_module)) - assert mock_module.init_module.called - - added_tools = mock_module.init_module.call_args[1]["tools"] - create_filters_tool = next( - (tool for tool in added_tools if tool.name == "create_view_filters"), None - ) - assert create_filters_tool is not None - return create_filters_tool +from .utils import make_test_ctx @pytest.mark.django_db @@ -96,23 +26,23 @@ def test_list_views_tool(data_fixture): table = data_fixture.create_database_table(database=database) view = data_fixture.create_grid_view(table=table, name="View 1", order=1) - tool = get_list_views_tool(user, workspace, fake_tool_helpers) - response = tool(table_id=table.id) + ctx = make_test_ctx(user, workspace) + response = list_views(ctx, thought="test", table_id=table.id) assert response == { "views": [ { "id": view.id, "name": "View 1", + "public": False, "type": "grid", "row_height": "small", - "public": False, } ] } view_2 = data_fixture.create_grid_view(table=table, name="View 2", order=2) - response = tool(table_id=table.id) + response = list_views(ctx, thought="test", table_id=table.id) assert len(response["views"]) == 2 assert response["views"][0]["name"] == "View 1" assert response["views"][1]["name"] == "View 2" @@ -125,12 +55,17 @@ def test_create_grid_view(data_fixture): database = data_fixture.create_database_application(workspace=workspace) table = data_fixture.create_database_table(database=database) - tool = get_create_views_tool(user, workspace) - response = tool.func( + ctx = make_test_ctx(user, workspace) + response = create_views( + ctx, + thought="test", table_id=table.id, views=[ - GridViewItemCreate( - type="grid", name="Grid View", public=False, row_height="medium" + ViewItemCreate( + name="Grid View", + public=False, + type="grid", + row_height="medium", ) ], ) @@ -148,14 +83,16 @@ def test_create_kanban_view(data_fixture): table = data_fixture.create_database_table(database=database) single_select = data_fixture.create_single_select_field(table=table, name="Status") - tool = get_create_views_tool(user, workspace) - response = tool.func( + ctx = make_test_ctx(user, workspace) + response = create_views( + ctx, + thought="test", table_id=table.id, views=[ - KanbanViewItemCreate( - type="kanban", + ViewItemCreate( name="Kanban View", public=False, + type="kanban", column_field_id=single_select.id, ) ], @@ -174,14 +111,16 @@ def test_create_calendar_view(data_fixture): table = data_fixture.create_database_table(database=database) date_field = data_fixture.create_date_field(table=table, name="Date") - tool = get_create_views_tool(user, workspace) - response = tool.func( + ctx = make_test_ctx(user, workspace) + response = create_views( + ctx, + thought="test", table_id=table.id, views=[ - CalendarViewItemCreate( - type="calendar", + ViewItemCreate( name="Calendar View", public=False, + type="calendar", date_field_id=date_field.id, ) ], @@ -200,14 +139,16 @@ def test_create_gallery_view(data_fixture): table = data_fixture.create_database_table(database=database) file_field = data_fixture.create_file_field(table=table, name="Files") - tool = get_create_views_tool(user, workspace) - response = tool.func( + ctx = make_test_ctx(user, workspace) + response = create_views( + ctx, + thought="test", table_id=table.id, views=[ - GalleryViewItemCreate( - type="gallery", + ViewItemCreate( name="Gallery View", public=False, + type="gallery", cover_field_id=file_field.id, ) ], @@ -227,14 +168,16 @@ def test_create_timeline_view(data_fixture): start_date = data_fixture.create_date_field(table=table, name="Start Date") end_date = data_fixture.create_date_field(table=table, name="End Date") - tool = get_create_views_tool(user, workspace) - response = tool.func( + ctx = make_test_ctx(user, workspace) + response = create_views( + ctx, + thought="test", table_id=table.id, views=[ - TimelineViewItemCreate( - type="timeline", + ViewItemCreate( name="Timeline View", public=False, + type="timeline", start_date_field_id=start_date.id, end_date_field_id=end_date.id, ) @@ -254,14 +197,16 @@ def test_create_form_view(data_fixture): table = data_fixture.create_database_table(database=database) field = data_fixture.create_text_field(table=table, name="Name", primary=True) - tool = get_create_views_tool(user, workspace) - response = tool.func( + ctx = make_test_ctx(user, workspace) + response = create_views( + ctx, + thought="test", table_id=table.id, views=[ - FormViewItemCreate( - type="form", + ViewItemCreate( name="Form View", public=True, + type="form", title="Contact Form", description="Fill out this form", submit_button_label="Submit", @@ -297,18 +242,23 @@ def test_create_text_equal_filter(data_fixture): field = data_fixture.create_text_field(table=table, name="Name") view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - TextEqualViewFilterItemCreate( - field_id=field.id, type="text", operator="equal", value="test" + ViewFilterItemCreate( + field_id=field.id, + type="text", + operator="equal", + value="test", ) ], ) - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -326,13 +276,15 @@ def test_create_text_not_equal_filter(data_fixture): field = data_fixture.create_text_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - TextNotEqualViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="text", operator="not_equal", @@ -340,7 +292,7 @@ def test_create_text_not_equal_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -356,13 +308,15 @@ def test_create_text_contains_filter(data_fixture): field = data_fixture.create_text_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - TextContainsViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="text", operator="contains", @@ -370,7 +324,7 @@ def test_create_text_contains_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -386,13 +340,15 @@ def test_create_text_not_contains_filter(data_fixture): field = data_fixture.create_text_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - TextNotContainsViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="text", operator="contains_not", @@ -400,7 +356,7 @@ def test_create_text_not_contains_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -419,18 +375,24 @@ def test_create_number_equal_filter(data_fixture): field = data_fixture.create_number_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - NumberEqualsViewFilterItemCreate( - field_id=field.id, type="number", operator="equal", value=42.0 + ViewFilterItemCreate( + field_id=field.id, + type="number", + operator="equal", + value=42.0, + or_equal=False, ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -446,21 +408,24 @@ def test_create_number_not_equal_filter(data_fixture): field = data_fixture.create_number_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - NumberNotEqualsViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="number", operator="not_equal", value=42.0, + or_equal=False, ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -476,13 +441,15 @@ def test_create_number_higher_than_filter(data_fixture): field = data_fixture.create_number_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - NumberHigherThanViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="number", operator="higher_than", @@ -509,13 +476,15 @@ def test_create_number_lower_than_filter(data_fixture): field = data_fixture.create_number_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - NumberLowerThanViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="number", operator="lower_than", @@ -524,7 +493,7 @@ def test_create_number_lower_than_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -541,22 +510,25 @@ def test_create_date_equal_filter(data_fixture): field = data_fixture.create_date_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - DateEqualsViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="date", operator="equal", - value=Date(year=2024, month=1, day=15), + value="2024-01-15", mode="exact_date", + or_equal=False, ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -572,22 +544,25 @@ def test_create_date_not_equal_filter(data_fixture): field = data_fixture.create_date_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - DateNotEqualsViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="date", operator="not_equal", value=None, mode="today", + or_equal=False, ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -605,13 +580,15 @@ def test_create_date_after_filter(data_fixture): field = data_fixture.create_date_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - DateAfterViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="date", operator="after", @@ -621,7 +598,7 @@ def test_create_date_after_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -639,13 +616,15 @@ def test_create_date_before_filter(data_fixture): field = data_fixture.create_date_field(table=table) view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - DateBeforeViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="date", operator="before", @@ -655,7 +634,7 @@ def test_create_date_before_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -676,13 +655,15 @@ def test_create_single_select_is_any_of_filter(data_fixture): data_fixture.create_select_option(field=field, value="Option 2") view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - SingleSelectIsAnyViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="single_select", operator="is_any_of", @@ -690,7 +671,7 @@ def test_create_single_select_is_any_of_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -709,13 +690,15 @@ def test_create_single_select_is_none_of_filter(data_fixture): data_fixture.create_select_option(field=field, value="Bad Option") view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - SingleSelectIsNoneOfNotViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="single_select", operator="is_none_of", @@ -723,7 +706,7 @@ def test_create_single_select_is_none_of_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -742,22 +725,27 @@ def test_create_boolean_is_true_filter(data_fixture): field = data_fixture.create_boolean_field(table=table, name="Active") view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - BooleanIsViewFilterItemCreate( - field_id=field.id, type="boolean", operator="is", value=True + ViewFilterItemCreate( + field_id=field.id, + type="boolean", + operator="equal", + value=True, ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 - assert ViewFilter.objects.filter(view=view, field=field, type="boolean").exists() + assert ViewFilter.objects.filter(view=view, field=field, type="equal").exists() @pytest.mark.django_db @@ -769,22 +757,27 @@ def test_create_boolean_is_false_filter(data_fixture): field = data_fixture.create_boolean_field(table=table, name="Active") view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - BooleanIsViewFilterItemCreate( - field_id=field.id, type="boolean", operator="is", value=False + ViewFilterItemCreate( + field_id=field.id, + type="boolean", + operator="equal", + value=False, ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 - assert ViewFilter.objects.filter(view=view, field=field, type="boolean").exists() + assert ViewFilter.objects.filter(view=view, field=field, type="equal").exists() # Multiple select filter tests @@ -799,13 +792,15 @@ def test_create_multiple_select_is_any_of_filter(data_fixture): data_fixture.create_select_option(field=field, value="Tag 2") view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - MultipleSelectIsAnyViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="multiple_select", operator="is_any_of", @@ -813,7 +808,7 @@ def test_create_multiple_select_is_any_of_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 @@ -832,13 +827,15 @@ def test_create_multiple_select_is_none_of_filter(data_fixture): data_fixture.create_select_option(field=field, value="Bad Tag") view = data_fixture.create_grid_view(table=table) - tool = get_create_view_filters_tool(user, workspace) - response = tool.func( - [ + ctx = make_test_ctx(user, workspace) + response = create_view_filters( + ctx, + thought="test", + view_filters=[ ViewFiltersArgs( view_id=view.id, filters=[ - MultipleSelectIsNoneOfNotViewFilterItemCreate( + ViewFilterItemCreate( field_id=field.id, type="multiple_select", operator="is_none_of", @@ -846,7 +843,7 @@ def test_create_multiple_select_is_none_of_filter(data_fixture): ) ], ), - ] + ], ) assert len(response["created_view_filters"]) == 1 assert ViewFilter.objects.filter( diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_navigation_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_navigation_tools.py new file mode 100644 index 0000000000..a1f757bd3c --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_navigation_tools.py @@ -0,0 +1,47 @@ +from unittest.mock import MagicMock + +import pytest + +from baserow_enterprise.assistant.tools.navigation.tools import navigate +from baserow_enterprise.assistant.tools.navigation.types import ( + TableNavigationRequestType, +) + +from .utils import make_test_ctx + + +@pytest.mark.django_db +def test_navigate_to_table(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database, name="Tasks") + + navigate_mock = MagicMock(return_value="Navigated successfully.") + ctx = make_test_ctx(user, workspace) + ctx.deps.tool_helpers.navigate_to = navigate_mock + + request = TableNavigationRequestType(type="database-table", table_id=table.id) + result = navigate(ctx, request, thought="go to tasks table") + + assert result == "Navigated successfully." + navigate_mock.assert_called_once() + location = navigate_mock.call_args[0][0] + assert location.type == "database-table" + assert location.table_id == table.id + assert location.database_id == database.id + assert location.table_name == "Tasks" + + +@pytest.mark.django_db +def test_navigate_to_nonexistent_table(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + + ctx = make_test_ctx(user, workspace) + + request = TableNavigationRequestType(type="database-table", table_id=999999) + result = navigate(ctx, request, thought="go to missing table") + + assert "Error" in result + assert "not found" in result diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_search_docs_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_search_docs_tools.py new file mode 100644 index 0000000000..dcc77492f9 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_search_docs_tools.py @@ -0,0 +1,104 @@ +import os +from unittest.mock import patch + +import pytest + +from baserow_enterprise.assistant.tools.search_user_docs.tools import ( + _TOOL_QUERY_RE, + search_user_docs, +) + +from .utils import make_test_ctx + +# search_user_docs is async, so we need this to allow sync ORM calls from +# data_fixture inside async tests. +os.environ.setdefault("DJANGO_ALLOW_ASYNC_UNSAFE", "true") + + +class TestToolQueryGuard: + """Tests for the tool-introspection regex guard.""" + + @pytest.mark.parametrize( + "query", + [ + "list_tables", + "create_fields", + "get_tables_schema", + "update_rows", + "delete_rows", + "generate_formula", + "create_view_filters", + "search_user_docs", + "navigate tool parameters", + ], + ) + def test_rejects_tool_introspection_queries(self, query): + assert _TOOL_QUERY_RE.search(query) is not None + + @pytest.mark.parametrize( + "query", + [ + "How to create a webhook in Baserow", + "How to link tables in Baserow", + "Baserow form view", + "How do I import data into Baserow", + ], + ) + def test_allows_legitimate_queries(self, query): + assert _TOOL_QUERY_RE.search(query) is None + + +@pytest.mark.django_db +@pytest.mark.asyncio +async def test_search_user_docs_rejects_tool_introspection(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + ctx = make_test_ctx(user, workspace) + + result = await search_user_docs( + ctx, question="list_tables", thought="looking up tool" + ) + + assert result["reliability"] == 0.0 + assert "REJECTED" in result["reliability_note"] + assert result["sources"] == [] + + +@pytest.mark.django_db +@pytest.mark.asyncio +async def test_search_user_docs_handles_empty_results(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + ctx = make_test_ctx(user, workspace) + + with patch( + "baserow_enterprise.assistant.tools.search_user_docs.tools.KnowledgeBaseHandler" + ) as mock_handler_cls: + mock_handler_cls.return_value.search.return_value = [] + + result = await search_user_docs( + ctx, question="How to use webhooks in Baserow", thought="user asks" + ) + + assert result["reliability"] == 0.0 + assert "Nothing found" in result["answer"] + + +@pytest.mark.django_db +@pytest.mark.asyncio +async def test_search_user_docs_handles_error(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + ctx = make_test_ctx(user, workspace) + + with patch( + "baserow_enterprise.assistant.tools.search_user_docs.tools.KnowledgeBaseHandler" + ) as mock_handler_cls: + mock_handler_cls.return_value.search.side_effect = RuntimeError("db error") + + result = await search_user_docs( + ctx, question="How to use webhooks", thought="user asks" + ) + + assert result["reliability"] == 0.0 + assert "error" in result["answer"].lower() diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_retrying_model.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_retrying_model.py new file mode 100644 index 0000000000..c83caf73ef --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_retrying_model.py @@ -0,0 +1,472 @@ +"""Unit tests for RetryingModel.""" + +import pytest + +from baserow_enterprise.assistant.retrying_model import ( + RetryingModel, + _is_transient_provider_error, +) + + +class TestIsTransientProviderError: + def test_groq_parse_error(self): + exc = Exception("Failed to parse tool call arguments as JSON") + assert _is_transient_provider_error(exc) is True + + def test_tool_validation_failed(self): + exc = Exception("Tool call validation failed: something") + assert _is_transient_provider_error(exc) is True + + def test_auth_error_not_retryable(self): + exc = Exception("Invalid API key") + assert _is_transient_provider_error(exc) is False + + def test_generic_error_not_retryable(self): + exc = ValueError("something went wrong") + assert _is_transient_provider_error(exc) is False + + +def _make_retrying(inner_mock, **kwargs): + """Create a RetryingModel with a pre-resolved mock as the wrapped model.""" + + model = RetryingModel.__new__(RetryingModel) + model._wrapped_or_name = inner_mock + model._resolved = inner_mock + model.max_attempts = kwargs.get("max_attempts", 3) + model.base_delay = kwargs.get("base_delay", 0.01) + model.max_delay = kwargs.get("max_delay", 10.0) + return model + + +@pytest.mark.asyncio +async def test_request_retries_on_transient_error(): + """RetryingModel.request should retry transient errors.""" + + from unittest.mock import AsyncMock, MagicMock + + from pydantic_ai.messages import ModelResponse, TextPart + from pydantic_ai.models import ModelRequestParameters + + inner = MagicMock() + response = ModelResponse(parts=[TextPart(content="hello")]) + inner.request = AsyncMock( + side_effect=[ + Exception("Failed to parse tool call arguments as JSON"), + response, + ] + ) + + model = _make_retrying(inner) + result = await model.request( + [], None, ModelRequestParameters(function_tools=[], output_tools=[]) + ) + + assert result == response + assert inner.request.call_count == 2 + + +@pytest.mark.asyncio +async def test_request_raises_non_transient_error(): + """RetryingModel.request should not retry non-transient errors.""" + + from unittest.mock import AsyncMock, MagicMock + + from pydantic_ai.models import ModelRequestParameters + + inner = MagicMock() + inner.request = AsyncMock(side_effect=ValueError("bad input")) + + model = _make_retrying(inner) + with pytest.raises(ValueError, match="bad input"): + await model.request( + [], None, ModelRequestParameters(function_tools=[], output_tools=[]) + ) + + assert inner.request.call_count == 1 + + +@pytest.mark.asyncio +async def test_request_exhausts_retries(): + """RetryingModel should raise after exhausting max_attempts.""" + + from unittest.mock import AsyncMock, MagicMock + + from pydantic_ai.models import ModelRequestParameters + + inner = MagicMock() + inner.request = AsyncMock( + side_effect=Exception("Failed to parse tool call arguments as JSON") + ) + + model = _make_retrying(inner, max_attempts=2) + with pytest.raises(Exception, match="Failed to parse"): + await model.request( + [], None, ModelRequestParameters(function_tools=[], output_tools=[]) + ) + + assert inner.request.call_count == 2 + + +def test_deferred_model_resolution(): + """RetryingModel should defer infer_model until first access.""" + + model = RetryingModel("groq:some-model") + # Should not raise at construction time + assert model._resolved is None + + +# --------------------------------------------------------------------------- +# tool_use_failed recovery +# --------------------------------------------------------------------------- + + +def _make_tool_use_failed_error( + failed_generation: str, + model_name: str = "test-model", +): + from pydantic_ai.exceptions import ModelHTTPError + + return ModelHTTPError( + status_code=400, + model_name=model_name, + body={ + "error": { + "message": "Failed to parse tool call arguments as JSON", + "type": "invalid_request_error", + "code": "tool_use_failed", + "failed_generation": failed_generation, + } + }, + ) + + +class TestTryRecoverToolUseFailed: + def test_recovers_valid_tool_call(self): + from pydantic_ai.messages import ToolCallPart + + from baserow_enterprise.assistant.retrying_model import ( + _try_recover_tool_use_failed, + ) + + exc = _make_tool_use_failed_error( + '{"name": "list_tables", "arguments": {"thought": "test"}}' + ) + result = _try_recover_tool_use_failed(exc) + + assert result is not None + assert len(result.parts) == 1 + part = result.parts[0] + assert isinstance(part, ToolCallPart) + assert part.tool_name == "list_tables" + assert "thought" in part.args + + def test_recovers_malformed_json_as_tool_call(self): + from baserow_enterprise.assistant.retrying_model import ( + _try_recover_tool_use_failed, + ) + + exc = _make_tool_use_failed_error("{not valid json") + result = _try_recover_tool_use_failed(exc) + + assert result is not None + assert len(result.parts) == 1 + from pydantic_ai.messages import ToolCallPart + + assert isinstance(result.parts[0], ToolCallPart) + assert result.parts[0].tool_name == "unknown" + assert result.parts[0].args == "{}" + + def test_recovers_malformed_json_extracts_tool_name(self): + from baserow_enterprise.assistant.retrying_model import ( + _try_recover_tool_use_failed, + ) + + exc = _make_tool_use_failed_error( + '{"name": "create_elements", "arguments": {"page_id": 1, "elements": [truncated' + ) + result = _try_recover_tool_use_failed(exc) + + assert result is not None + from pydantic_ai.messages import ToolCallPart + + assert isinstance(result.parts[0], ToolCallPart) + assert result.parts[0].tool_name == "create_elements" + assert result.parts[0].args == "{}" + + def test_returns_none_for_non_tool_use_failed(self): + from pydantic_ai.exceptions import ModelHTTPError + + from baserow_enterprise.assistant.retrying_model import ( + _try_recover_tool_use_failed, + ) + + exc = ModelHTTPError( + status_code=400, + model_name="test", + body={"error": {"message": "other error", "code": "other"}}, + ) + assert _try_recover_tool_use_failed(exc) is None + + def test_returns_none_for_non_model_http_error(self): + from baserow_enterprise.assistant.retrying_model import ( + _try_recover_tool_use_failed, + ) + + assert _try_recover_tool_use_failed(ValueError("nope")) is None + + def test_recovers_raw_api_error_with_body(self): + """Handles raw provider APIError (e.g. groq.APIError) with body attr.""" + + from pydantic_ai.messages import ToolCallPart + + from baserow_enterprise.assistant.retrying_model import ( + _try_recover_tool_use_failed, + ) + + class FakeAPIError(Exception): + def __init__(self, message, body=None): + super().__init__(message) + self.body = body + + exc = FakeAPIError( + "Failed to parse tool call arguments as JSON", + body={ + "message": "Failed to parse tool call arguments as JSON", + "type": "invalid_request_error", + "code": "tool_use_failed", + "failed_generation": '{"name": "create_rows", "arguments": {"table_id": 1}}', + }, + ) + result = _try_recover_tool_use_failed(exc) + assert result is not None + assert isinstance(result.parts[0], ToolCallPart) + assert result.parts[0].tool_name == "create_rows" + + +@pytest.mark.asyncio +async def test_request_recovers_tool_use_failed(): + """request() should recover tool_use_failed into a ModelResponse.""" + + from unittest.mock import AsyncMock, MagicMock + + from pydantic_ai.models import ModelRequestParameters + + inner = MagicMock() + inner.request = AsyncMock( + side_effect=_make_tool_use_failed_error( + '{"name": "create_rows", "arguments": {"thought": "hi"}}' + ) + ) + + model = _make_retrying(inner) + result = await model.request( + [], None, ModelRequestParameters(function_tools=[], output_tools=[]) + ) + + from pydantic_ai.messages import ToolCallPart + + # Should return a recovered response, not raise + assert len(result.parts) == 1 + assert isinstance(result.parts[0], ToolCallPart) + assert result.parts[0].tool_name == "create_rows" + # Should NOT have retried — recovery is immediate + assert inner.request.call_count == 1 + + +@pytest.mark.asyncio +async def test_request_stream_recovers_tool_use_failed(): + """request_stream() should recover tool_use_failed into a PreFetchedResponse.""" + + from unittest.mock import MagicMock + + from pydantic_ai.models import ModelRequestParameters + + inner = MagicMock() + + async def _failing_stream(*args, **kwargs): + raise _make_tool_use_failed_error( + '{"name": "list_rows", "arguments": {"table_id": 1}}' + ) + + # Make request_stream an async context manager that raises + from contextlib import asynccontextmanager + + @asynccontextmanager + async def failing_cm(*args, **kwargs): + raise _make_tool_use_failed_error( + '{"name": "list_rows", "arguments": {"table_id": 1}}' + ) + yield # pragma: no cover + + inner.request_stream = failing_cm + + model = _make_retrying(inner) + async with model.request_stream( + [], None, ModelRequestParameters(function_tools=[], output_tools=[]) + ) as stream: + # Collect events from the pre-fetched response + events = [e async for e in stream] + + from pydantic_ai.models import PartStartEvent + + start_events = [e for e in events if isinstance(e, PartStartEvent)] + assert len(start_events) == 1 + assert start_events[0].part.tool_name == "list_rows" + + +@pytest.mark.asyncio +async def test_request_stream_recovers_mid_stream_api_error(): + """_ErrorRecoveringStream catches APIError during chunk iteration + and emits recovery events instead of crashing.""" + + from contextlib import asynccontextmanager + from unittest.mock import MagicMock + + from pydantic_ai._parts_manager import ModelResponsePartsManager + from pydantic_ai.models import ModelRequestParameters, PartStartEvent + + # Simulate a real StreamedResponse whose _get_event_iterator raises APIError + class FakeAPIError(Exception): + """Simulates groq.APIError with a body attribute.""" + + def __init__(self, message, body=None): + super().__init__(message) + self.body = body + + class FakeStreamedResponse: + """Minimal fake that raises during iteration.""" + + model_name = "test-model" + provider_name = "test" + provider_url = "http://test" + timestamp = None + _parts_manager = ModelResponsePartsManager() + model_request_parameters = ModelRequestParameters( + function_tools=[], output_tools=[] + ) + final_result_event = None + provider_response_id = None + provider_details = None + finish_reason = None + + async def _get_event_iterator(self): + raise FakeAPIError( + "Failed to parse tool call arguments as JSON", + body={ + "message": "Failed to parse tool call arguments as JSON", + "type": "invalid_request_error", + "code": "tool_use_failed", + "failed_generation": '{"name": "create_elements", "arguments": {"bad": true}}', + }, + ) + yield # pragma: no cover — make it a generator + + inner = MagicMock() + + @asynccontextmanager + async def fake_request_stream(*args, **kwargs): + yield FakeStreamedResponse() + + inner.request_stream = fake_request_stream + + model = _make_retrying(inner) + async with model.request_stream( + [], None, ModelRequestParameters(function_tools=[], output_tools=[]) + ) as stream: + events = [e async for e in stream] + + start_events = [e for e in events if isinstance(e, PartStartEvent)] + assert len(start_events) == 1 + assert start_events[0].part.tool_name == "create_elements" + + +@pytest.mark.asyncio +async def test_request_stream_recovers_mid_stream_malformed_json(): + """_ErrorRecoveringStream recovers even when failed_generation JSON is + unparseable — returns a ToolCallPart with empty args so pydantic-ai's + validation loop can retry.""" + + from contextlib import asynccontextmanager + from unittest.mock import MagicMock + + from pydantic_ai._parts_manager import ModelResponsePartsManager + from pydantic_ai.messages import ToolCallPart + from pydantic_ai.models import ModelRequestParameters, PartStartEvent + + class FakeAPIError(Exception): + def __init__(self, message, body=None): + super().__init__(message) + self.body = body + + class FakeStreamedResponse: + model_name = "test-model" + provider_name = "test" + provider_url = "http://test" + timestamp = None + _parts_manager = ModelResponsePartsManager() + model_request_parameters = ModelRequestParameters( + function_tools=[], output_tools=[] + ) + final_result_event = None + provider_response_id = None + provider_details = None + finish_reason = None + + async def _get_event_iterator(self): + raise FakeAPIError( + "Failed to parse tool call arguments as JSON", + body={ + "message": "Failed to parse tool call arguments as JSON", + "type": "invalid_request_error", + "code": "tool_use_failed", + "failed_generation": '{"name": "create_elements", "arguments": {truncated', + }, + ) + yield # pragma: no cover + + inner = MagicMock() + + @asynccontextmanager + async def fake_request_stream(*args, **kwargs): + yield FakeStreamedResponse() + + inner.request_stream = fake_request_stream + + model = _make_retrying(inner) + async with model.request_stream( + [], None, ModelRequestParameters(function_tools=[], output_tools=[]) + ) as stream: + events = [e async for e in stream] + + start_events = [e for e in events if isinstance(e, PartStartEvent)] + assert len(start_events) == 1 + assert isinstance(start_events[0].part, ToolCallPart) + assert start_events[0].part.tool_name == "create_elements" + assert start_events[0].part.args == "{}" + + +@pytest.mark.asyncio +async def test_request_stream_reraises_after_yield(): + """Errors during stream __aexit__ (non-recoverable) must re-raise.""" + + from contextlib import asynccontextmanager + from unittest.mock import MagicMock + + from pydantic_ai.models import ModelRequestParameters + + inner = MagicMock() + + @asynccontextmanager + async def stream_that_fails_during_consumption(*args, **kwargs): + # Yield a mock stream, then raise on __aexit__ + yield MagicMock() + raise Exception("some unrelated error") + + inner.request_stream = stream_that_fails_during_consumption + + model = _make_retrying(inner) + with pytest.raises(Exception, match="some unrelated error"): + async with model.request_stream( + [], None, ModelRequestParameters(function_tools=[], output_tools=[]) + ) as stream: + pass # stream consumed, then __aexit__ raises diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_telemetry.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_telemetry.py index 4e0b8b095a..d1a0835c91 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_telemetry.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_telemetry.py @@ -1,10 +1,17 @@ +import json from unittest.mock import MagicMock, patch import pytest -import udspy from baserow_enterprise.assistant.models import AssistantChat -from baserow_enterprise.assistant.telemetry import PosthogTracingCallback +from baserow_enterprise.assistant.telemetry import ( + PosthogSpanProcessor, + PosthogTracingCallback, + _pydantic_messages_to_posthog, + _tool_calls, + _trace_ctx, + _TraceContext, +) @pytest.fixture @@ -16,18 +23,6 @@ def assistant_chat_fixture(enterprise_data_fixture): ) -@pytest.fixture(autouse=True) -def mock_posthog_openai(): - with ( - udspy.settings.context(lm=udspy.LM(model="fake-model")), - patch("posthog.ai.openai.AsyncOpenAI") as mock, - ): - # Configure the mock if needed - mock.return_value = MagicMock() - mock.return_value.model = "test-model" - yield mock - - @pytest.mark.django_db class TestPosthogTracingCallback: @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") @@ -89,209 +84,563 @@ def test_trace_context_manager_exception( assert call_args.kwargs["properties"]["$ai_is_error"] is True @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") - def test_on_module_start_end(self, mock_get_client, assistant_chat_fixture): - """Test module execution tracing.""" + def test_trace_with_output(self, mock_get_client, assistant_chat_fixture): + """Test that trace output is captured when set.""" mock_posthog = MagicMock() mock_get_client.return_value = mock_posthog callback = PosthogTracingCallback() - # Initialize context manually - callback.chat = assistant_chat_fixture - callback.user_id = str(assistant_chat_fixture.user_id) - callback.workspace_id = str(assistant_chat_fixture.workspace_id) - callback.chat_uuid = str(assistant_chat_fixture.uuid) - callback.trace_id = "trace-123" - callback.span_ids = ["root-span"] - callback.spans = {} - callback.enabled = True - - # Mock a CoT module - mock_module = MagicMock(spec=udspy.ChainOfThought) - mock_module.__class__ = udspy.ChainOfThought - mock_signature = MagicMock() - mock_signature.get_input_fields.return_value = {"q": 1} - mock_signature.get_output_fields.return_value = { - "a": 1 - } # Should be dict, not list - mock_signature.get_instructions.return_value = "Test instructions" - mock_module.original_signature = mock_signature - - # Start module - callback.on_module_start( - call_id="call-1", instance=mock_module, inputs={"kwargs": {"q": "test"}} - ) - assert len(callback.span_ids) == 2 - assert len(callback.spans) == 1 + with callback.trace(assistant_chat_fixture, "Hello"): + callback.set_trace_output("The answer is 42") + + call_args = mock_posthog.capture.call_args + props = call_args.kwargs["properties"] + assert props["$ai_output_state"] == {"answer": "The answer is 42"} + + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_trace_sets_and_clears_context_var( + self, mock_get_client, assistant_chat_fixture + ): + """Test that the ContextVar is set inside the trace and cleared after.""" + + mock_get_client.return_value = MagicMock() + + callback = PosthogTracingCallback() + + # Before trace, context should be None + assert _trace_ctx.get() is None + + with callback.trace(assistant_chat_fixture, "Hello"): + ctx = _trace_ctx.get() + assert ctx is not None + assert ctx.trace_id == callback.trace_id + assert ctx.user_id == str(assistant_chat_fixture.user_id) + assert ctx.workspace_id == str(assistant_chat_fixture.workspace_id) + assert ctx.chat_uuid == str(assistant_chat_fixture.uuid) + + # After trace, context should be cleared + assert _trace_ctx.get() is None + + +class TestPydanticMessagesToPosthog: + """Test the message format conversion utility.""" + + def test_convert_text_message(self): + """Test converting a simple text message.""" + + messages = [{"role": "user", "parts": [{"type": "text", "content": "Hello"}]}] + result = _pydantic_messages_to_posthog(messages) + + assert len(result) == 1 + assert result[0]["role"] == "user" + assert result[0]["content"] == [{"type": "text", "text": "Hello"}] + + def test_convert_tool_call(self): + """Test converting a tool call message.""" + + messages = [ + { + "role": "assistant", + "parts": [ + { + "type": "tool_call", + "id": "call_123", + "name": "list_tables", + "arguments": {"database_id": 1}, + } + ], + } + ] + result = _pydantic_messages_to_posthog(messages) + + assert result[0]["role"] == "assistant" + tc = result[0]["content"][0] + assert tc["type"] == "tool_call" + assert tc["tool_call_id"] == "call_123" + assert tc["name"] == "list_tables" + assert tc["arguments"] == {"database_id": 1} + + def test_convert_tool_return(self): + """Test converting a tool return message.""" + + messages = [ + { + "role": "tool", + "parts": [ + { + "type": "tool_return", + "tool_call_id": "call_123", + "content": "Tables: Users, Orders", + } + ], + } + ] + result = _pydantic_messages_to_posthog(messages) + + assert result[0]["content"][0]["type"] == "tool_result" + assert result[0]["content"][0]["tool_call_id"] == "call_123" + + +class TestPosthogSpanProcessor: + """Test the OpenTelemetry span processor for PostHog.""" + + def _make_mock_span( + self, + name, + kind, + attrs=None, + start_time=None, + end_time=None, + parent_span_id=None, + span_id=0x1234, + ): + """Create a mock ReadableSpan.""" + + span = MagicMock() + span.name = name + span.kind = kind + span.attributes = attrs or {} + span.start_time = start_time or 1000000000 # 1s in ns + span.end_time = end_time or 2000000000 # 2s in ns + span.events = [] + + # Context + span.context = MagicMock() + span.context.span_id = span_id + + # Parent + if parent_span_id is not None: + span.parent = MagicMock() + span.parent.span_id = parent_span_id + else: + span.parent = None + + return span - # End module - callback.on_module_end( - call_id="call-1", outputs={"a": "result"}, exception=None + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_generation_span(self, mock_get_client): + """Test that a 'chat' span is mapped to $ai_generation.""" + + from opentelemetry.trace import SpanKind + + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog + + processor = PosthogSpanProcessor() + + span = self._make_mock_span( + name="chat groq:llama-3.3-70b", + kind=SpanKind.CLIENT, + attrs={ + "gen_ai.request.model": "llama-3.3-70b", + "gen_ai.response.model": "llama-3.3-70b", + "gen_ai.provider.name": "groq", + "gen_ai.usage.input_tokens": 100, + "gen_ai.usage.output_tokens": 50, + "gen_ai.input.messages": json.dumps( + [{"role": "user", "parts": [{"type": "text", "content": "Hi"}]}] + ), + "gen_ai.output.messages": json.dumps( + [ + { + "role": "assistant", + "parts": [{"type": "text", "content": "Hello!"}], + } + ] + ), + }, + parent_span_id=0xABCD, ) - assert len(callback.span_ids) == 1 - assert len(callback.spans) == 0 + ctx = _TraceContext( + trace_id="trace-123", + user_id="user-456", + workspace_id="ws-789", + chat_uuid="chat-abc", + ) + token = _trace_ctx.set(ctx) + try: + processor.on_end(span) + finally: + _trace_ctx.reset(token) - # Verify span event was called mock_posthog.capture.assert_called_once() - call_args = mock_posthog.capture.call_args - - # Check the event structure - assert call_args.kwargs["distinct_id"] == str(assistant_chat_fixture.user_id) - assert call_args.kwargs["event"] == "$ai_span" - assert "timestamp" in call_args.kwargs + call = mock_posthog.capture.call_args + assert call.kwargs["distinct_id"] == "user-456" + assert call.kwargs["event"] == "$ai_generation" - # Check properties - props = call_args.kwargs["properties"] + props = call.kwargs["properties"] assert props["$ai_trace_id"] == "trace-123" - assert props["$ai_session_id"] == str(assistant_chat_fixture.uuid) - assert props["workspace_id"] == str(assistant_chat_fixture.workspace_id) - assert props["$ai_span_name"] == "ChainOfThought" - assert props["$ai_span_id"] == "call-1" - assert props["$ai_parent_span_id"] == "root-span" - assert "$ai_input_state" in props - assert props["$ai_output_state"] == {"a": "result"} - assert props["$ai_latency"] >= 0 - assert props["$ai_is_error"] is False + assert props["$ai_session_id"] == "chat-abc" + assert props["workspace_id"] == "ws-789" + assert props["$ai_model"] == "llama-3.3-70b" + assert props["$ai_provider"] == "groq" + assert props["$ai_input_tokens"] == 100 + assert props["$ai_output_tokens"] == 50 + assert props["$ai_latency"] == pytest.approx(1.0, abs=0.01) + assert props["$ai_parent_id"] == f"{0xABCD:016x}" + + # Check message format conversion + assert len(props["$ai_input"]) == 1 + assert props["$ai_input"][0]["role"] == "user" + assert props["$ai_input"][0]["content"][0]["text"] == "Hi" - def test_on_lm_start(self, assistant_chat_fixture): - """Test LM start tracing.""" + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_tool_span(self, mock_get_client): + """Test that a 'running tool' span is mapped to $ai_span.""" - callback = PosthogTracingCallback() - callback.chat = assistant_chat_fixture - callback.user_id = "user-1" - callback.workspace_id = "ws-1" - callback.chat_uuid = "chat-1" - callback.trace_id = "trace-1" - callback.span_ids = ["root"] + from opentelemetry.trace import SpanKind - mock_lm = MagicMock() - mock_lm.provider = "openai" + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog - inputs = {"kwargs": {}} - callback.on_lm_start("call-1", mock_lm, inputs) + processor = PosthogSpanProcessor() + + span = self._make_mock_span( + name="running tool", + kind=SpanKind.INTERNAL, + attrs={ + "gen_ai.tool.name": "list_tables", + "tool_arguments": '{"database_id": 1}', + "tool_response": "Found 3 tables: Users, Orders, Products", + }, + parent_span_id=0x5678, + ) + + ctx = _TraceContext( + trace_id="trace-123", + user_id="user-456", + workspace_id="ws-789", + chat_uuid="chat-abc", + ) + token = _trace_ctx.set(ctx) + try: + processor.on_end(span) + finally: + _trace_ctx.reset(token) - assert len(callback.span_ids) == 2 - assert inputs["kwargs"]["posthog_distinct_id"] == "user-1" - assert inputs["kwargs"]["posthog_trace_id"] == "trace-1" - assert inputs["kwargs"]["posthog_properties"]["$ai_provider"] == "openai" + mock_posthog.capture.assert_called_once() + call = mock_posthog.capture.call_args + assert call.kwargs["event"] == "$ai_span" + + props = call.kwargs["properties"] + assert props["$ai_span_name"] == "Tool: list_tables" + assert props["$ai_input_state"] == {"database_id": 1} + assert "Found 3 tables" in props["$ai_output_state"] + assert props["$ai_latency"] == pytest.approx(1.0, abs=0.01) @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") - def test_on_tool_start_end(self, mock_get_client, assistant_chat_fixture): - """Test tool execution tracing.""" + def test_agent_run_span_is_exported(self, mock_get_client): + """Test that 'agent run' spans are exported as $ai_span with agent name, + system prompt, user input, and final output.""" + + from opentelemetry.trace import SpanKind mock_posthog = MagicMock() mock_get_client.return_value = mock_posthog - callback = PosthogTracingCallback() - callback.chat = assistant_chat_fixture - callback.user_id = str(assistant_chat_fixture.user_id) - callback.workspace_id = str(assistant_chat_fixture.workspace_id) - callback.chat_uuid = str(assistant_chat_fixture.uuid) - callback.trace_id = "trace-1" - callback.span_ids = ["root"] - callback.spans = {} - callback.enabled = True - - mock_tool = MagicMock() - mock_tool.name = "test_tool" - - # Start tool - callback.on_tool_start( - call_id="call-1", instance=mock_tool, inputs={"arg": "val"} + processor = PosthogSpanProcessor() + + system_instructions = json.dumps( + [{"type": "text", "content": "You are a helpful assistant."}] + ) + all_messages = json.dumps( + [ + { + "role": "user", + "parts": [{"type": "text", "content": "Create a table"}], + }, + { + "role": "model-response", + "parts": [{"type": "text", "content": "Done!"}], + }, + ] + ) + + span = self._make_mock_span( + name="agent run", + kind=SpanKind.INTERNAL, + attrs={ + "agent_name": "main_agent", + "gen_ai.system_instructions": system_instructions, + "pydantic_ai.all_messages": all_messages, + "final_result": '{"table_id": 1}', + }, + parent_span_id=0x9999, + ) + + ctx = _TraceContext( + trace_id="trace-123", + user_id="user-456", + workspace_id="ws-789", + chat_uuid="chat-abc", + ) + token = _trace_ctx.set(ctx) + try: + processor.on_end(span) + finally: + _trace_ctx.reset(token) + + mock_posthog.capture.assert_called_once() + call = mock_posthog.capture.call_args + assert call.kwargs["event"] == "$ai_span" + + props = call.kwargs["properties"] + assert props["$ai_span_name"] == "Agent: main_agent" + assert props["$ai_trace_id"] == "trace-123" + assert props["$ai_latency"] == pytest.approx(1.0, abs=0.01) + assert props["$ai_parent_id"] == f"{0x9999:016x}" + assert ( + props["$ai_input_state"]["system_prompt"] == "You are a helpful assistant." ) + assert props["$ai_input_state"]["user_prompt"] == "Create a table" + assert props["$ai_output_state"] == {"table_id": 1} - assert len(callback.spans) == 1 + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_agent_run_span_subagent_label(self, mock_get_client): + """Test that sub-agent spans get their own distinct label and handle + string final_result.""" + + from opentelemetry.trace import SpanKind - # End tool - callback.on_tool_end(call_id="call-1", outputs="result", exception=None) + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog + + processor = PosthogSpanProcessor() + + span = self._make_mock_span( + name="agent run", + kind=SpanKind.INTERNAL, + attrs={ + "agent_name": "sample_row_agent", + "final_result": "Rows created successfully", + }, + ) + + ctx = _TraceContext( + trace_id="trace-123", + user_id="user-456", + workspace_id="ws-789", + chat_uuid="chat-abc", + ) + token = _trace_ctx.set(ctx) + try: + processor.on_end(span) + finally: + _trace_ctx.reset(token) - # Verify event - mock_posthog.capture.assert_called() props = mock_posthog.capture.call_args.kwargs["properties"] - assert props["$ai_span_name"] == "Tool: test_tool" - assert props["$ai_input_state"] == {"arg": "val"} - assert props["$ai_output_state"] == "result" + assert props["$ai_span_name"] == "Agent: sample_row_agent" + assert props["$ai_output_state"] == "Rows created successfully" @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") - def test_on_module_end_with_exception( - self, mock_get_client, assistant_chat_fixture - ): - """Test that exception string is captured in $ai_output_state.""" + def test_running_tools_skipped_and_parent_remapped(self, mock_get_client): + """Test that 'running tools' is not emitted and child tool spans + have their parent remapped to the grandparent (agent span).""" + + from opentelemetry.trace import SpanKind mock_posthog = MagicMock() mock_get_client.return_value = mock_posthog - callback = PosthogTracingCallback() - callback.chat = assistant_chat_fixture - callback.user_id = str(assistant_chat_fixture.user_id) - callback.workspace_id = str(assistant_chat_fixture.workspace_id) - callback.chat_uuid = str(assistant_chat_fixture.uuid) - callback.trace_id = "trace-123" - callback.span_ids = ["root-span"] - callback.spans = {} - callback.enabled = True - - # Mock a module - mock_module = MagicMock(spec=udspy.ChainOfThought) - mock_module.__class__ = udspy.ChainOfThought - mock_signature = MagicMock() - mock_signature.get_input_fields.return_value = {"q": 1} - mock_signature.get_output_fields.return_value = {"a": 1} - mock_signature.get_instructions.return_value = "Test instructions" - mock_module.original_signature = mock_signature - - # Start module - callback.on_module_start( - call_id="call-1", instance=mock_module, inputs={"kwargs": {"q": "test"}} + processor = PosthogSpanProcessor() + + agent_span_id = 0xAAAA + tools_group_span_id = 0xBBBB + tool_span_id = 0xCCCC + + # 1) "running tools" starts — processor records the parent mapping. + tools_group_span = self._make_mock_span( + name="running tools", + kind=SpanKind.INTERNAL, + span_id=tools_group_span_id, + parent_span_id=agent_span_id, + ) + processor.on_start(tools_group_span) + + # 2) "running tool" ends — its direct parent is the tools group, + # but the processor should remap to the agent span. + tool_span = self._make_mock_span( + name="running tool", + kind=SpanKind.INTERNAL, + attrs={ + "gen_ai.tool.name": "create_tables", + "tool_arguments": "{}", + "tool_response": "ok", + }, + span_id=tool_span_id, + parent_span_id=tools_group_span_id, ) - # End module with exception - test_exception = ValueError("Test error message") - callback.on_module_end(call_id="call-1", outputs=None, exception=test_exception) + ctx = _TraceContext(trace_id="t", user_id="u", workspace_id="w", chat_uuid="c") + token = _trace_ctx.set(ctx) + try: + processor.on_end(tool_span) - # Verify exception string is captured - mock_posthog.capture.assert_called_once() - call_args = mock_posthog.capture.call_args - props = call_args.kwargs["properties"] + # Tool span's parent should be the agent, not the tools group. + props = mock_posthog.capture.call_args.kwargs["properties"] + assert props["$ai_parent_id"] == f"{agent_span_id:016x}" - assert props["$ai_is_error"] is True - assert props["$ai_output_state"] == "Test error message" + mock_posthog.capture.reset_mock() + + # 3) "running tools" ends — should NOT emit anything. + processor.on_end(tools_group_span) + mock_posthog.capture.assert_not_called() + finally: + _trace_ctx.reset(token) @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") - def test_on_tool_end_with_exception(self, mock_get_client, assistant_chat_fixture): - """Test that exception string is captured in $ai_output_state for tools.""" + def test_without_trace_context_is_noop(self, mock_get_client): + """Test that spans without a trace context are silently ignored.""" + + from opentelemetry.trace import SpanKind mock_posthog = MagicMock() mock_get_client.return_value = mock_posthog - callback = PosthogTracingCallback() - callback.chat = assistant_chat_fixture - callback.user_id = str(assistant_chat_fixture.user_id) - callback.workspace_id = str(assistant_chat_fixture.workspace_id) - callback.chat_uuid = str(assistant_chat_fixture.uuid) - callback.trace_id = "trace-1" - callback.span_ids = ["root"] - callback.spans = {} - callback.enabled = True - - mock_tool = MagicMock() - mock_tool.name = "test_tool" - - # Start tool - callback.on_tool_start( - call_id="call-1", instance=mock_tool, inputs={"arg": "val"} + processor = PosthogSpanProcessor() + + span = self._make_mock_span( + name="chat groq:llama-3.3-70b", + kind=SpanKind.CLIENT, ) - # End tool with exception - test_exception = RuntimeError("Tool execution failed") - callback.on_tool_end(call_id="call-1", outputs=None, exception=test_exception) + # No trace context set + processor.on_end(span) - # Verify exception string is captured - mock_posthog.capture.assert_called_once() - call_args = mock_posthog.capture.call_args - props = call_args.kwargs["properties"] + mock_posthog.capture.assert_not_called() + + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_multiple_spans(self, mock_get_client): + """Test that multiple spans are all processed.""" + + from opentelemetry.trace import SpanKind + + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog + + processor = PosthogSpanProcessor() + + generation_span = self._make_mock_span( + name="chat openai:gpt-4o", + kind=SpanKind.CLIENT, + attrs={ + "gen_ai.request.model": "gpt-4o", + "gen_ai.provider.name": "openai", + "gen_ai.usage.input_tokens": 200, + "gen_ai.usage.output_tokens": 80, + }, + span_id=0x1111, + ) + tool_span = self._make_mock_span( + name="running tool", + kind=SpanKind.INTERNAL, + attrs={ + "gen_ai.tool.name": "create_table", + "tool_arguments": "{}", + "tool_response": "Created table", + }, + span_id=0x2222, + ) + + ctx = _TraceContext( + trace_id="trace-multi", + user_id="user-1", + workspace_id="ws-1", + chat_uuid="chat-1", + ) + token = _trace_ctx.set(ctx) + try: + processor.on_end(generation_span) + processor.on_end(tool_span) + finally: + _trace_ctx.reset(token) + + assert mock_posthog.capture.call_count == 2 + events = [c.kwargs["event"] for c in mock_posthog.capture.call_args_list] + assert "$ai_generation" in events + assert "$ai_span" in events + + +class TestSetupInstrumentation: + """Test the one-time instrumentation setup.""" + + @patch("baserow_enterprise.assistant.telemetry._instrumentation_ready", False) + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_setup_skipped_when_posthog_disabled(self, mock_get_client): + """Test that setup is skipped when POSTHOG_ENABLED is False.""" + + from baserow_enterprise.assistant.telemetry import setup_instrumentation + + # POSTHOG_ENABLED is False in test settings + setup_instrumentation() + + # Should not have called get_posthog_client (nothing was set up) + mock_get_client.assert_not_called() + + +class TestEndToEndOtelPipeline: + """Integration: verify that a real pydantic-ai Agent run produces + PostHog events via the OTel span exporter.""" + + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_agent_run_produces_posthog_events(self, mock_get_client): + """A real Agent.run_sync() inside a trace() should emit both + $ai_trace and $ai_generation events via PostHog.""" + + from opentelemetry.sdk.trace import TracerProvider as _TP + from pydantic_ai import Agent, InstrumentationSettings + + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog + + # Wire up the same pipeline that setup_instrumentation() creates. + tp = _TP() + tp.add_span_processor(PosthogSpanProcessor()) + Agent.instrument_all( + InstrumentationSettings(tracer_provider=tp, include_content=True) + ) - assert props["$ai_is_error"] is True - assert props["$ai_output_state"] == "Tool execution failed" + try: + # Set trace context (simulates PosthogTracingCallback.trace()). + ctx = _TraceContext( + trace_id="e2e-trace", + user_id="e2e-user", + workspace_id="e2e-ws", + chat_uuid="e2e-chat", + ) + tok = _trace_ctx.set(ctx) + tools_tok = _tool_calls.set([]) + + try: + agent = Agent( + output_type=str, + instructions="Reply with 'pong'.", + name="e2e_test_agent", + ) + agent.run_sync("ping", model="test") + finally: + _trace_ctx.reset(tok) + _tool_calls.reset(tools_tok) + + # Verify PostHog received at least one $ai_generation event. + events = [c.kwargs["event"] for c in mock_posthog.capture.call_args_list] + assert "$ai_generation" in events, ( + f"Expected $ai_generation in captured events, got: {events}" + ) + + # Verify the trace metadata was attached. + gen_call = next( + c + for c in mock_posthog.capture.call_args_list + if c.kwargs["event"] == "$ai_generation" + ) + props = gen_call.kwargs["properties"] + assert props["$ai_trace_id"] == "e2e-trace" + assert props["$ai_session_id"] == "e2e-chat" + assert props["workspace_id"] == "e2e-ws" + finally: + # Clean up global instrumentation so other tests aren't affected. + Agent.instrument_all(None) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/utils.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/utils.py index eebb6175ab..e4cc97f51d 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/utils.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/utils.py @@ -1,3 +1,27 @@ -from baserow_enterprise.assistant.assistant import ToolHelpers +from unittest.mock import MagicMock -fake_tool_helpers = ToolHelpers(lambda x: None, lambda x: None) +from baserow_enterprise.assistant.deps import AssistantDeps, ToolHelpers + + +def create_fake_tool_helpers() -> ToolHelpers: + """Create a fresh ToolHelpers instance for testing.""" + return ToolHelpers(lambda x: None, lambda x: None) + + +def make_test_ctx(user, workspace, tool_helpers=None): + """ + Build a mock ``RunContext[AssistantDeps]`` for unit-testing tool functions. + + Returns a ``MagicMock`` whose ``.deps`` attribute is a real + ``AssistantDeps`` instance. + """ + + if tool_helpers is None: + tool_helpers = create_fake_tool_helpers() + ctx = MagicMock() + ctx.deps = AssistantDeps( + user=user, + workspace=workspace, + tool_helpers=tool_helpers, + ) + return ctx diff --git a/enterprise/backend/tests/baserow_enterprise_tests/conftest.py b/enterprise/backend/tests/baserow_enterprise_tests/conftest.py index d309625c87..6596794695 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/conftest.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/conftest.py @@ -24,9 +24,9 @@ def set_openai_api_key_env_var(): """ Set a dummy OpenAI API key for tests to prevent client instantiation errors. - udspy.LM() creates an OpenAI client that raises an error if OPENAI_API_KEY is not - set during client instantiation. This fixture ensures tests don't fail due to - missing API key configuration, which is not needed anyway. + Some pydantic-ai model backends create an OpenAI client that raises an error + if OPENAI_API_KEY is not set during client instantiation. This fixture ensures + tests don't fail due to missing API key configuration. """ if not os.getenv("OPENAI_API_KEY"): diff --git a/enterprise/backend/tests/baserow_enterprise_tests/enterprise/test_enterprise_license.py b/enterprise/backend/tests/baserow_enterprise_tests/enterprise/test_enterprise_license.py index 418f84e2b4..97d51450da 100755 --- a/enterprise/backend/tests/baserow_enterprise_tests/enterprise/test_enterprise_license.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/enterprise/test_enterprise_license.py @@ -258,7 +258,7 @@ def test_user_data_no_enterprise_features_instance_wide_not_active( } -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) @responses.activate def test_check_licenses_with_enterprise_license_sends_usage_data( @@ -303,7 +303,7 @@ def test_check_licenses_with_enterprise_license_sends_usage_data( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_enterprise_license_counts_viewers_as_free( enterprise_data_fixture, data_fixture @@ -350,7 +350,7 @@ def test_enterprise_license_counts_viewers_as_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_who_is_editor_in_one_workspace_and_viewer_in_another_is_not_free( enterprise_data_fixture, data_fixture @@ -393,7 +393,7 @@ def test_user_who_is_editor_in_one_workspace_and_viewer_in_another_is_not_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_marked_for_deletion_is_not_counted_as_a_paid_user( enterprise_data_fixture, data_fixture @@ -439,7 +439,7 @@ def test_user_marked_for_deletion_is_not_counted_as_a_paid_user( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_deactivated_user_is_not_counted_as_a_paid_user( enterprise_data_fixture, data_fixture @@ -584,7 +584,7 @@ def test_enterprise_license_being_unregistered_sends_signal_to_all( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_paid_table_role_is_not_free( enterprise_data_fixture, data_fixture, synced_roles @@ -622,7 +622,7 @@ def test_user_with_paid_table_role_is_not_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_free_table_role_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -660,7 +660,7 @@ def test_user_with_free_table_role_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_paid_database_role_is_not_free( enterprise_data_fixture, data_fixture, synced_roles @@ -698,7 +698,7 @@ def test_user_with_paid_database_role_is_not_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_free_database_role_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -736,7 +736,7 @@ def test_user_with_free_database_role_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_paid_table_role_is_not_free_from_team( enterprise_data_fixture, data_fixture, synced_roles @@ -777,7 +777,7 @@ def test_user_with_paid_table_role_is_not_free_from_team( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_free_table_role_is_free_from_team( enterprise_data_fixture, data_fixture, synced_roles @@ -818,7 +818,7 @@ def test_user_with_free_table_role_is_free_from_team( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_paid_database_role_is_not_free_from_team( enterprise_data_fixture, data_fixture, synced_roles @@ -859,7 +859,7 @@ def test_user_with_paid_database_role_is_not_free_from_team( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_free_database_role_is_free_from_team( enterprise_data_fixture, data_fixture, synced_roles @@ -900,7 +900,7 @@ def test_user_with_free_database_role_is_free_from_team( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_in_deleted_team_with_paid_role_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -943,7 +943,7 @@ def test_user_in_deleted_team_with_paid_role_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_inactive_user_with_paid_role_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1002,7 +1002,7 @@ def test_inactive_user_with_paid_role_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_inactive_user_in_team_with_paid_role_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1064,7 +1064,7 @@ def test_inactive_user_in_team_with_paid_role_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_to_be_deleted_with_paid_role_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1123,7 +1123,7 @@ def test_user_to_be_deleted_with_paid_role_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_to_be_deleted_in_team_with_paid_role_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1185,7 +1185,7 @@ def test_user_to_be_deleted_in_team_with_paid_role_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_complex_free_vs_paid_scenario( enterprise_data_fixture, data_fixture, synced_roles @@ -1281,7 +1281,7 @@ def test_complex_free_vs_paid_scenario( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_role_paid_on_trashed_database_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1341,7 +1341,7 @@ def test_user_with_role_paid_on_trashed_database_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_role_paid_on_database_in_trashed_workspace_is_free( enterprise_data_fixture, data_fixture, synced_roles, django_assert_num_queries @@ -1401,7 +1401,7 @@ def test_user_with_role_paid_on_database_in_trashed_workspace_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_role_paid_on_trashed_table_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1461,7 +1461,7 @@ def test_user_with_role_paid_on_trashed_table_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_in_team_with_role_paid_on_trashed_database_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1524,7 +1524,7 @@ def test_user_in_team_with_role_paid_on_trashed_database_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_in_team_with_role_paid_on_trashed_table_is_free( enterprise_data_fixture, data_fixture, synced_roles @@ -1587,7 +1587,7 @@ def test_user_in_team_with_role_paid_on_trashed_table_is_free( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_summary_calculation_for_enterprise_doesnt_do_n_plus_one_queries( enterprise_data_fixture, data_fixture, synced_roles, django_assert_num_queries @@ -1678,7 +1678,7 @@ def test_user_summary_calculation_for_enterprise_doesnt_do_n_plus_one_queries( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_can_query_for_summary_per_workspace( enterprise_data_fixture, data_fixture, synced_roles @@ -1814,7 +1814,7 @@ def test_can_query_for_summary_per_workspace( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_user_with_team_and_user_role_picks_highest_of_either( enterprise_data_fixture, data_fixture, synced_roles @@ -1860,7 +1860,7 @@ def test_user_with_team_and_user_role_picks_highest_of_either( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_order_of_roles_is_as_expected( enterprise_data_fixture, data_fixture, synced_roles @@ -1904,7 +1904,7 @@ def test_order_of_roles_is_as_expected( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_weird_workspace_user_permission_doesnt_break_usage_check( enterprise_data_fixture, data_fixture, synced_roles @@ -1936,7 +1936,7 @@ def test_weird_workspace_user_permission_doesnt_break_usage_check( ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_weird_ras_for_wrong_workspace_not_counted_when_querying_for_single_workspace_usage( enterprise_data_fixture, data_fixture, synced_roles @@ -1996,7 +1996,7 @@ def test_weird_ras_for_wrong_workspace_not_counted_when_querying_for_single_work ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_missing_roles_doesnt_cause_crash_and_members_admins_are_treated_as_non_free( enterprise_data_fixture, data_fixture, synced_roles @@ -2032,7 +2032,7 @@ def test_missing_roles_doesnt_cause_crash_and_members_admins_are_treated_as_non_ ) -@pytest.mark.django_db +@pytest.mark.django_db(transaction=True) @override_settings(DEBUG=True) def test_orphaned_paid_role_assignments_dont_get_counted( enterprise_data_fixture, data_fixture, synced_roles diff --git a/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss index 5a73880ef1..712d4780cd 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss +++ b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss @@ -463,6 +463,12 @@ border: 0; outline: 0; + &--collapsed { + max-height: 250px; + overflow: hidden; + mask-image: linear-gradient(to bottom, black 200px, transparent); + } + // Markdown styles p { margin: 0 0 8px; @@ -629,6 +635,32 @@ vertical-align: middle; } +.assistant__reasoning-toggle { + display: flex; + justify-content: center; + width: 100%; + padding: 2px 0 0; + margin: 0; + border: 0; + background: none; + cursor: pointer; + color: #16829c; + opacity: 0.6; + + &:hover { + opacity: 1; + } +} + +.assistant__reasoning-chevron { + font-size: 12px; + transition: transform 0.2s ease; + + &--expanded { + transform: rotate(180deg); + } +} + .assistant__chat-history-spacer { width: 300px; } diff --git a/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue b/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue index 2734c3fc45..8bf587603b 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue +++ b/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue @@ -28,10 +28,27 @@
+ Date: Fri, 20 Mar 2026 11:57:59 +0100 Subject: [PATCH 4/5] Fix error when changing filter type on link row fields (#5015) --- ...ror_when_changing_filter_type_on_link_row_fields.json | 9 +++++++++ .../database/components/view/ViewFilterTypeLinkRow.vue | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 changelog/entries/unreleased/bug/5014_fix_error_when_changing_filter_type_on_link_row_fields.json diff --git a/changelog/entries/unreleased/bug/5014_fix_error_when_changing_filter_type_on_link_row_fields.json b/changelog/entries/unreleased/bug/5014_fix_error_when_changing_filter_type_on_link_row_fields.json new file mode 100644 index 0000000000..6d9e88583f --- /dev/null +++ b/changelog/entries/unreleased/bug/5014_fix_error_when_changing_filter_type_on_link_row_fields.json @@ -0,0 +1,9 @@ +{ + "type": "bug", + "message": "Fix error when changing filter type on link row fields", + "issue_origin": "github", + "issue_number": 5014, + "domain": "database", + "bullet_points": [], + "created_at": "2026-03-20" +} \ No newline at end of file diff --git a/web-frontend/modules/database/components/view/ViewFilterTypeLinkRow.vue b/web-frontend/modules/database/components/view/ViewFilterTypeLinkRow.vue index d939776284..2bb1582919 100644 --- a/web-frontend/modules/database/components/view/ViewFilterTypeLinkRow.vue +++ b/web-frontend/modules/database/components/view/ViewFilterTypeLinkRow.vue @@ -96,7 +96,7 @@ export default { .get('field', primary.type) .toHumanReadableString(primary, row[`field_${primary.id}`]) this.rowInfo = null - } else if (!this.isDropdown) { + } else if (!this.isDropdown && this.valid) { // Get the name from server this.loading = true try { From 0f420931d46876a83f12a8efc2b85e2e51730e2c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 20 Mar 2026 14:01:44 +0100 Subject: [PATCH 5/5] chore(deps): bump flatted from 3.4.1 to 3.4.2 in /web-frontend (#5019) Bumps [flatted](https://github.com/WebReflection/flatted) from 3.4.1 to 3.4.2. - [Commits](https://github.com/WebReflection/flatted/compare/v3.4.1...v3.4.2) --- updated-dependencies: - dependency-name: flatted dependency-version: 3.4.2 dependency-type: indirect ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- web-frontend/yarn.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/web-frontend/yarn.lock b/web-frontend/yarn.lock index 8de654bc49..f88bb16753 100644 --- a/web-frontend/yarn.lock +++ b/web-frontend/yarn.lock @@ -6662,9 +6662,9 @@ flat-cache@^6.1.19: hookified "^1.13.0" flatted@^3.2.9, flatted@^3.3.3: - version "3.4.1" - resolved "https://registry.yarnpkg.com/flatted/-/flatted-3.4.1.tgz#84ccd9579e76e9cc0d246c11d8be0beb019143e6" - integrity sha512-IxfVbRFVlV8V/yRaGzk0UVIcsKKHMSfYw66T/u4nTwlWteQePsxe//LjudR1AMX4tZW3WFCh3Zqa/sjlqpbURQ== + version "3.4.2" + resolved "https://registry.yarnpkg.com/flatted/-/flatted-3.4.2.tgz#f5c23c107f0f37de8dbdf24f13722b3b98d52726" + integrity sha512-PjDse7RzhcPkIJwy5t7KPWQSZ9cAbzQXcafsetQoD7sOJRQlGikNbx7yZp2OotDnJyrDcbyRq3Ttb18iYOqkxA== flush-promises@^1.0.2: version "1.0.2"