diff --git a/backend/.flake8 b/backend/.flake8 index 4aefc9587b..eb50f5ee63 100644 --- a/backend/.flake8 +++ b/backend/.flake8 @@ -2,9 +2,9 @@ extend-ignore = E203, W503, F541, E501 max-doc-length = 88 per-file-ignores = - tests/*: F841, BAI001 - ../premium/backend/tests/*: F841, BAI001 - ../enterprise/backend/tests/*: F841, BAI001 + tests/*: F841 + ../premium/backend/tests/*: F841 + ../enterprise/backend/tests/*: F841 src/baserow/contrib/database/migrations/*: BDC001 src/baserow/core/migrations/*: BDC001 src/baserow/core/psycopg.py: BPG001 @@ -19,5 +19,4 @@ exclude = extension = BDC001 = flake8_baserow:DocstringPlugin BPG001 = flake8_baserow:BaserowPsycopgChecker - BAI001 = flake8_baserow:BaserowAIImportsChecker paths = ./flake8_plugins diff --git a/backend/flake8_plugins/flake8_baserow/__init__.py b/backend/flake8_plugins/flake8_baserow/__init__.py index 1766a8d398..1caa7e9514 100644 --- a/backend/flake8_plugins/flake8_baserow/__init__.py +++ b/backend/flake8_plugins/flake8_baserow/__init__.py @@ -1,5 +1,4 @@ from .docstring import Plugin as DocstringPlugin from .psycopg import BaserowPsycopgChecker -from .ai_imports import BaserowAIImportsChecker -__all__ = ["DocstringPlugin", "BaserowPsycopgChecker", "BaserowAIImportsChecker"] +__all__ = ["DocstringPlugin", "BaserowPsycopgChecker"] diff --git a/backend/flake8_plugins/flake8_baserow/ai_imports.py b/backend/flake8_plugins/flake8_baserow/ai_imports.py deleted file mode 100644 index 8d47707eff..0000000000 --- a/backend/flake8_plugins/flake8_baserow/ai_imports.py +++ /dev/null @@ -1,86 +0,0 @@ -import ast -from typing import Iterator, Tuple, Any - - -class BaserowAIImportsChecker: - """ - Flake8 plugin to ensure dspy and litellm are only imported locally within - functions/methods, not at module level. 
- """ - - name = "flake8-baserow-ai-imports" - version = "0.1.0" - - def __init__(self, tree: ast.AST, filename: str): - self.tree = tree - self.filename = filename - - def run(self) -> Iterator[Tuple[int, int, str, Any]]: - """Check for global imports of dspy and litellm.""" - for node in ast.walk(self.tree): - # Check if this is a module-level import (not inside a function/method) - if self._is_global_import(node): - if isinstance(node, ast.Import): - for alias in node.names: - if self._is_ai_module(alias.name): - yield ( - node.lineno, - node.col_offset, - f"BAI001 {alias.name} must be imported locally within functions/methods, not globally", - type(self), - ) - elif isinstance(node, ast.ImportFrom): - if node.module and self._is_ai_module(node.module): - yield ( - node.lineno, - node.col_offset, - f"BAI001 {node.module} must be imported locally within functions/methods, not globally", - type(self), - ) - - def _is_ai_module(self, module_name: str) -> bool: - """Check if the module is dspy or litellm (including submodules).""" - if not module_name: - return False - return ( - module_name == "dspy" - or module_name.startswith("dspy.") - or module_name == "litellm" - or module_name.startswith("litellm.") - ) - - def _is_global_import(self, node: ast.AST) -> bool: - """ - Check if an import node is at global scope. - Returns True if the import is not nested inside a function or method. - """ - if not isinstance(node, (ast.Import, ast.ImportFrom)): - return False - - # Walk up the AST to find if this import is inside a function/method - # We need to check the parent nodes, but ast.walk doesn't provide parent info - # So we'll traverse the tree differently - return self._check_node_is_global(self.tree, node) - - def _check_node_is_global( - self, root: ast.AST, target: ast.AST, in_function: bool = False - ) -> bool: - """ - Recursively check if target node is at global scope. - Returns True if the target is found at global scope (not in a function). 
- """ - if root is target: - return not in_function - - # Check if we're entering a function/method - new_in_function = in_function or isinstance( - root, (ast.FunctionDef, ast.AsyncFunctionDef) - ) - - # Recursively check all child nodes - for child in ast.iter_child_nodes(root): - result = self._check_node_is_global(child, target, new_in_function) - if result is not None: - return result - - return None diff --git a/backend/flake8_plugins/tests/test_flake8_baserow_ai_imports.py b/backend/flake8_plugins/tests/test_flake8_baserow_ai_imports.py deleted file mode 100644 index e36106ce7b..0000000000 --- a/backend/flake8_plugins/tests/test_flake8_baserow_ai_imports.py +++ /dev/null @@ -1,212 +0,0 @@ -import ast -from flake8_baserow.ai_imports import BaserowAIImportsChecker - - -def run_checker(code: str): - """Helper to run the checker on code and return errors.""" - tree = ast.parse(code) - checker = BaserowAIImportsChecker(tree, "test.py") - return list(checker.run()) - - -def test_global_dspy_import(): - """Test that global dspy imports are flagged.""" - code = """ -import dspy -""" - errors = run_checker(code) - assert len(errors) == 1 - assert "BAI001" in errors[0][2] - assert "dspy" in errors[0][2] - - -def test_global_litellm_import(): - """Test that global litellm imports are flagged.""" - code = """ -import litellm -""" - errors = run_checker(code) - assert len(errors) == 1 - assert "BAI001" in errors[0][2] - assert "litellm" in errors[0][2] - - -def test_global_dspy_from_import(): - """Test that global 'from dspy import' statements are flagged.""" - code = """ -from dspy import ChainOfThought -from dspy.predict import Predict -""" - errors = run_checker(code) - assert len(errors) == 2 - assert all("BAI001" in error[2] for error in errors) - - -def test_global_litellm_from_import(): - """Test that global 'from litellm import' statements are flagged.""" - code = """ -from litellm import completion -from litellm.utils import get_llm_provider -""" - errors = 
run_checker(code) - assert len(errors) == 2 - assert all("BAI001" in error[2] for error in errors) - - -def test_local_import_in_function(): - """Test that local imports within functions are allowed.""" - code = """ -def my_function(): - import dspy - import litellm - from dspy import ChainOfThought - from litellm import completion - return dspy, litellm -""" - errors = run_checker(code) - assert len(errors) == 0 - - -def test_local_import_in_method(): - """Test that local imports within class methods are allowed.""" - code = """ -class MyClass: - def my_method(self): - import dspy - from litellm import completion - return dspy.ChainOfThought() -""" - errors = run_checker(code) - assert len(errors) == 0 - - -def test_local_import_in_async_function(): - """Test that local imports within async functions are allowed.""" - code = """ -async def my_async_function(): - import dspy - from litellm import acompletion - return await acompletion() -""" - errors = run_checker(code) - assert len(errors) == 0 - - -def test_mixed_global_and_local_imports(): - """Test that global imports are flagged while local imports are not.""" - code = """ -import dspy # This should be flagged - -def my_function(): - import litellm # This should be OK - return litellm.completion() - -from dspy import ChainOfThought # This should be flagged -""" - errors = run_checker(code) - assert len(errors) == 2 - assert all("BAI001" in error[2] for error in errors) - - -def test_nested_function_imports(): - """Test that imports in nested functions are allowed.""" - code = """ -def outer_function(): - def inner_function(): - import dspy - from litellm import completion - return dspy, completion - return inner_function() -""" - errors = run_checker(code) - assert len(errors) == 0 - - -def test_other_imports_not_affected(): - """Test that other imports are not flagged.""" - code = """ -import os -import sys -from typing import List -from baserow.core.models import User -""" - errors = run_checker(code) - 
assert len(errors) == 0 - - -def test_multiple_global_imports(): - """Test multiple global AI imports.""" - code = """ -import dspy -import litellm -from dspy import ChainOfThought -from litellm import completion -import os # This should not be flagged -""" - errors = run_checker(code) - assert len(errors) == 4 - assert all("BAI001" in error[2] for error in errors) - - -def test_import_with_alias(): - """Test that imports with aliases are also caught.""" - code = """ -import dspy as d -import litellm as llm - -def my_function(): - import dspy as local_d - return local_d -""" - errors = run_checker(code) - assert len(errors) == 2 - assert all("BAI001" in error[2] for error in errors) - - -def test_submodule_imports(): - """Test that submodule imports are caught at global scope.""" - code = """ -from dspy.teleprompt import BootstrapFewShot -from litellm.utils import token_counter - -def my_function(): - from dspy.predict import Predict - from litellm.integrations import log_event - return Predict, log_event -""" - errors = run_checker(code) - assert len(errors) == 2 - assert all("BAI001" in error[2] for error in errors) - # Verify the errors are for the global imports - assert errors[0][0] == 2 # Line number of first import - assert errors[1][0] == 3 # Line number of second import - - -def test_class_method_and_staticmethod(): - """Test that imports in classmethods and staticmethods are allowed.""" - code = """ -class MyClass: - @classmethod - def my_classmethod(cls): - import dspy - return dspy - - @staticmethod - def my_staticmethod(): - from litellm import completion - return completion -""" - errors = run_checker(code) - assert len(errors) == 0 - - -def test_lambda_not_considered_function(): - """Test that imports in lambdas (which aren't supported anyway) at module level are flagged.""" - code = """ -# Note: This is contrived since you can't actually have imports in lambdas, -# but this tests that lambda doesn't count as a function scope -import dspy -""" - 
errors = run_checker(code) - assert len(errors) == 1 - assert "BAI001" in errors[0][2] diff --git a/backend/requirements/base.in b/backend/requirements/base.in index 987610e50e..356bcc0f99 100644 --- a/backend/requirements/base.in +++ b/backend/requirements/base.in @@ -25,7 +25,7 @@ django-celery-beat==2.6.0 celery-redbeat==2.2.0 flower==2.0.1 service-identity==24.1.0 -regex==2024.4.28 +regex==2025.10.23 antlr4-python3-runtime==4.9.3 tqdm==4.66.4 boto3==1.40.40 @@ -86,7 +86,6 @@ tornado==6.5.0 # Pinned to address vulnerability. certifi==2025.4.26 # Pinned to address vulnerability. httpcore==1.0.9 # Pinned to address vulnerability. genson==1.3.0 -dspy-ai==3.0.3 -litellm==1.77.7 # Pinned to avoid bug in 1.75.3 requiring litellm[proxy] pyotp==2.9.0 qrcode==8.2 +udspy==0.1.6 diff --git a/backend/requirements/base.txt b/backend/requirements/base.txt index 86d3af0f3f..53b51bb15a 100644 --- a/backend/requirements/base.txt +++ b/backend/requirements/base.txt @@ -6,14 +6,6 @@ # advocate==1.0.0 # via -r base.in -aiohappyeyeballs==2.6.1 - # via aiohttp -aiohttp==3.12.15 - # via litellm -aiosignal==1.4.0 - # via aiohttp -alembic==1.14.1 - # via optuna amqp==5.3.1 # via kombu annotated-types==0.7.0 @@ -25,8 +17,6 @@ antlr4-python3-runtime==4.9.3 anyio==4.8.0 # via # anthropic - # asyncer - # dspy # httpx # mcp # openai @@ -42,11 +32,8 @@ asgiref==3.8.1 # django # django-cors-headers # opentelemetry-instrumentation-asgi -asyncer==0.0.8 - # via dspy attrs==24.3.0 # via - # aiohttp # jsonschema # referencing # service-identity @@ -62,9 +49,7 @@ azure-core==1.32.0 azure-storage-blob==12.24.0 # via django-storages backoff==2.2.1 - # via - # dspy - # posthog + # via posthog billiard==4.2.1 # via celery boto3==1.40.40 @@ -76,9 +61,7 @@ botocore==1.40.40 brotli==1.1.0 # via -r base.in cachetools==5.5.0 - # via - # dspy - # google-auth + # via google-auth celery[redis]==5.5.3 # via # -r base.in @@ -114,7 +97,6 @@ click==8.1.8 # click-didyoumean # click-plugins # click-repl - # litellm # 
uvicorn click-didyoumean==0.3.1 # via celery @@ -122,10 +104,6 @@ click-plugins==1.1.1 # via celery click-repl==0.3.0 # via celery -cloudpickle==3.1.1 - # via dspy -colorlog==6.9.0 - # via optuna constantly==23.10.4 # via twisted cron-descriptor==1.4.5 @@ -141,8 +119,6 @@ daphne==4.1.2 # via channels defusedxml==0.7.1 # via pysaml2 -diskcache==5.6.3 - # via dspy distro==1.9.0 # via # anthropic @@ -195,10 +171,6 @@ djangorestframework-simplejwt==5.3.1 # via -r base.in drf-spectacular==0.28.0 # via -r base.in -dspy==3.0.3 - # via dspy-ai -dspy-ai==3.0.3 - # via -r base.in elementpath==4.7.0 # via xmlschema et-xmlfile==2.0.0 @@ -207,22 +179,14 @@ eval-type-backport==0.2.2 # via mistralai faker==25.0.1 # via -r base.in -fastuuid==0.13.5 - # via litellm filelock==3.16.1 # via huggingface-hub flower==2.0.1 # via -r base.in -frozenlist==1.7.0 - # via - # aiohttp - # aiosignal fsspec==2024.12.0 # via huggingface-hub genson==1.3.0 # via -r base.in -gepa[dspy]==0.0.7 - # via dspy google-api-core==2.24.0 # via # google-cloud-core @@ -262,7 +226,6 @@ httpx==0.27.2 # via # anthropic # langsmith - # litellm # mcp # mistralai # ollama @@ -286,11 +249,8 @@ idna==3.10 # hyperlink # requests # twisted - # yarl importlib-metadata==8.4.0 - # via - # litellm - # opentelemetry-api + # via opentelemetry-api incremental==24.7.2 # via twisted inflection==0.5.1 @@ -299,22 +259,17 @@ isodate==0.7.2 # via azure-storage-blob itsdangerous==2.2.0 # via -r base.in -jinja2==3.1.6 - # via litellm jira2markdown==0.3.7 # via -r base.in -jiter==0.8.2 +jiter==0.11.1 # via # anthropic # openai + # udspy jmespath==1.0.1 # via # boto3 # botocore -joblib==1.5.2 - # via dspy -json-repair==0.51.0 - # via dspy jsonpatch==1.33 # via langchain-core jsonpath-python==1.0.6 @@ -322,9 +277,7 @@ jsonpath-python==1.0.6 jsonpointer==3.0.0 # via jsonpatch jsonschema==4.25.0 - # via - # drf-spectacular - # litellm + # via drf-spectacular jsonschema-specifications==2025.4.1 # via jsonschema kombu[redis]==5.5.4 @@ -344,22 
+297,10 @@ langsmith==0.4.10 # via # langchain # langchain-core -litellm==1.77.7 - # via - # -r base.in - # dspy loguru==0.7.2 # via -r base.in -magicattr==0.1.6 - # via dspy -mako==1.3.10 - # via alembic markdown-it-py==3.0.0 # via rich -markupsafe==3.0.3 - # via - # jinja2 - # mako mcp==1.9.4 # via -r base.in mdurl==0.1.2 @@ -370,10 +311,6 @@ monotonic==1.6 # via posthog msgpack==1.1.0 # via channels-redis -multidict==6.6.4 - # via - # aiohttp - # yarl mypy-extensions==1.0.0 # via typing-inspect ndg-httpsclient==0.5.1 @@ -381,10 +318,7 @@ ndg-httpsclient==0.5.1 netifaces==0.11.0 # via advocate numpy==2.3.3 - # via - # dspy - # optuna - # pgvector + # via pgvector oauthlib==3.2.2 # via requests-oauthlib ollama==0.1.9 @@ -392,9 +326,8 @@ ollama==0.1.9 openai==2.2.0 # via # -r base.in - # dspy # langchain-openai - # litellm + # udspy openpyxl==3.1.5 # via -r base.in opentelemetry-api==1.37.0 @@ -503,12 +436,8 @@ opentelemetry-util-http==0.58b0 # opentelemetry-instrumentation-django # opentelemetry-instrumentation-requests # opentelemetry-instrumentation-wsgi -optuna==4.5.0 - # via dspy orjson==3.10.13 - # via - # dspy - # langsmith + # via langsmith packaging==23.2 # via # gunicorn @@ -517,7 +446,6 @@ packaging==23.2 # langchain-core # langsmith # opentelemetry-instrumentation - # optuna pgvector==0.4.1 # via -r base.in pillow==10.3.0 @@ -528,10 +456,6 @@ prometheus-client==0.21.1 # via flower prompt-toolkit==3.0.48 # via click-repl -propcache==0.3.2 - # via - # aiohttp - # yarl prosemirror @ https://github.com/fellowapp/prosemirror-py/archive/refs/tags/v0.3.5.zip # via -r base.in proto-plus==1.25.0 @@ -562,15 +486,14 @@ pycparser==2.22 pydantic==2.9.2 # via # anthropic - # dspy # langchain # langchain-core # langsmith - # litellm # mcp # mistralai # openai # pydantic-settings + # udspy pydantic-core==2.23.4 # via pydantic pydantic-settings==2.8.1 @@ -606,7 +529,6 @@ python-dateutil==2.8.2 # python-crontab python-dotenv==1.0.1 # via - # litellm # pydantic-settings # 
uvicorn python-multipart==0.0.20 @@ -622,10 +544,10 @@ pyyaml==6.0.2 # huggingface-hub # langchain # langchain-core - # optuna # uvicorn -redis==5.2.1 qrcode==8.2 + # via -r base.in +redis==5.2.1 # via # -r base.in # celery-redbeat @@ -637,17 +559,16 @@ referencing==0.36.2 # via # jsonschema # jsonschema-specifications -regex==2024.4.28 +regex==2025.10.23 # via # -r base.in - # dspy # tiktoken + # udspy requests==2.32.5 # via # -r base.in # advocate # azure-core - # dspy # google-api-core # google-cloud-storage # huggingface-hub @@ -664,9 +585,7 @@ requests-oauthlib==2.0.0 requests-toolbelt==1.0.0 # via langsmith rich==13.7.1 - # via - # -r base.in - # dspy + # via -r base.in rpds-py==0.26.0 # via # jsonschema @@ -694,10 +613,7 @@ sniffio==1.3.1 # httpx # openai sqlalchemy==2.0.36 - # via - # alembic - # langchain - # optuna + # via langchain sqlparse==0.5.3 # via django sse-starlette==2.2.1 @@ -709,16 +625,12 @@ starlette==0.46.1 tenacity==8.5.0 # via # celery-redbeat - # dspy # langchain-core + # udspy tiktoken==0.9.0 - # via - # langchain-openai - # litellm + # via langchain-openai tokenizers==0.21.0 - # via - # anthropic - # litellm + # via anthropic tornado==6.5 # via # -r base.in @@ -726,10 +638,8 @@ tornado==6.5 tqdm==4.66.4 # via # -r base.in - # dspy # huggingface-hub # openai - # optuna twisted[tls]==24.11.0 # via # -r base.in @@ -739,8 +649,6 @@ txaio==23.1.1 typing-extensions==4.11.0 # via # -r base.in - # aiosignal - # alembic # anthropic # anyio # azure-core @@ -767,6 +675,8 @@ tzdata==2025.2 # -r base.in # django-celery-beat # kombu +udspy==0.1.6 + # via -r base.in unicodecsv==0.14.1 # via -r base.in uritemplate==4.1.1 @@ -807,10 +717,6 @@ wrapt==1.17.0 # opentelemetry-instrumentation-redis xmlschema==2.5.1 # via pysaml2 -xxhash==3.5.0 - # via dspy -yarl==1.20.1 - # via aiohttp zipp==3.19.1 # via # -r base.in diff --git a/docs/development/embeddings-server.md b/docs/development/embeddings-server.md new file mode 100644 index 0000000000..39706280cb 
--- /dev/null +++ b/docs/development/embeddings-server.md @@ -0,0 +1,33 @@ +# Setup embeddings server in dev environment + +If you would like to use the AI-assistant in combination with the search documentation +tool, then you must add the embeddings server. This guide shows how to do that. + +## Docker compose + +Add the following to your `docker-compose.dev.yml` and then (re)start your dev server. +The `BASEROW_EMBEDDINGS_API_URL=http://embeddings:80` variable is already configured by +default for the backend container, so it should then work out of the box. + +```yaml + embeddings: + build: + context: ./embeddings + dockerfile: Dockerfile + ports: + - "${HOST_PUBLISH_IP:-127.0.0.1}:7999:80" + networks: + local: + restart: unless-stopped + healthcheck: + test: + [ + "CMD", + "python", + "-c", + "import requests; requests.get('http://localhost/health').raise_for_status()", + ] + interval: 1m30s + timeout: 10s + retries: 3 +``` diff --git a/docs/index.md b/docs/index.md index 5fcfd85345..f89748526b 100644 --- a/docs/index.md +++ b/docs/index.md @@ -54,6 +54,8 @@ Baserow by following one the guides below: the supported and recommended runtime dependencies. * [Monitoring Baserow](installation/monitoring.md): Learn how to monitor your Baserow server using open telemetry. +* [Setup AI-assistant](installation/ai-assistant.md): A quick guide on how to set up the + AI-assistant. ## Baserow Tutorials diff --git a/docs/installation/ai-assistant.md b/docs/installation/ai-assistant.md new file mode 100644 index 0000000000..9a24ce89a3 --- /dev/null +++ b/docs/installation/ai-assistant.md @@ -0,0 +1,88 @@ +# Baserow AI-Assistant: Quick DevOps Setup + +This guide shows how to enable the AI-assistant in Baserow, configure the required +environment variables, and (optionally) turn on knowledge-base lookups via an embeddings +server. + +## 1) Core concepts + +- The assistant runs via **UDSPy** — see https://github.com/baserow/udspy/ +- UDSPy speaks to **any OpenAI-compatible API**. 
+- You **must** set `BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL` with the provider and model + of your choosing. +- The assistant has been mostly tested with the `gpt-oss-120b` family. Other models can + work as well. + +## 2) Minimal enablement + +Set the model you want, restart Baserow, and let migrations run. + +```dotenv +# Required +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=openai/gpt-5-mini +OPENAI_API_KEY=your_api_key +``` + +## 3) Provider presets + +Choose **one** provider block and set its variables. + +### OpenAI / OpenAI-compatible + +```dotenv +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=openai/gpt-5-mini +OPENAI_API_KEY=your_api_key +# Optional alternative endpoints (OpenAI EU or Azure OpenAI, etc.) +UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL=https://eu.api.openai.com/v1 +# or +UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL=https://<your-resource-name>.openai.azure.com +# or any OpenAI compatible endpoint +``` + +### AWS Bedrock + +```dotenv +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=bedrock/openai.gpt-oss-120b-1:0 +AWS_BEARER_TOKEN_BEDROCK=your_bedrock_token +AWS_REGION_NAME=eu-central-1 +``` + +### Groq + +```dotenv +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=groq/openai/gpt-oss-120b +GROQ_API_KEY=your_api_key +``` + +### Ollama + +```dotenv +BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL=ollama/gpt-oss:120b +OLLAMA_API_KEY=your_api_key +# Optionally, an alternative endpoint +UDSPY_LM_OPENAI_COMPATIBLE_BASE_URL=http://localhost:11434/v1 +``` + +Under the hood, UDSPy auto-detects the provider from the model prefix and builds an +OpenAI-compatible client accordingly. + +## 4) Knowledge-base lookup + +If your deployment method doesn’t auto-provision embeddings, run the Baserow embeddings +service and point Baserow at it. 
+ +### Run the embeddings container + +```bash +docker run -d --name baserow-embeddings -p 80:80 baserow/embeddings:latest +``` + +### Point Baserow to it + +```dotenv +BASEROW_EMBEDDINGS_API_URL=http://your-embedder-service +# e.g., http://localhost if you mapped -p 80:80 locally +# Then restart Baserow and allow migrations to run. +``` + +After restart and migrations, knowledge-base lookup will be available. diff --git a/docs/installation/configuration.md b/docs/installation/configuration.md index 0b4b5772fa..f5735f2be6 100644 --- a/docs/installation/configuration.md +++ b/docs/installation/configuration.md @@ -180,8 +180,17 @@ The installation methods referred to in the variable descriptions are: | BASEROW\_ENTERPRISE\_MAX\_PERIODIC\_DATA\_SYNC\_CONSECUTIVE\_ERRORS | The maximum number of consecutive periodic data sync error before it's disabled. | 4 | | BASEROW\_DEADLOCK\_INITIAL\_BACKOFF | The initial backoff time for database deadlock retries. | 2 | | BASEROW\_DEADLOCK\_MAX\_RETRIES | The maximum number of database deadlock retries. | 1 | -| BASEROW\_EMBEDDINGS\_API\_URL | If not empty, the AI-assistant will use this as embedding server for the documentation lookup. | "" (empty string) | -| BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL | If not empty, then this model will be used for the AI-assistant. Overview of all models: https://docs.litellm.ai/docs/providers. Provide like `groq/openai/gpt-oss-120b` or `bedrock/anthropic.claude-3-sonnet-20240229-v1:0`. Note that additional API keys must be provided as environment variable depending on the provider. Instructions are in the LiteLLM docs. 
| "" (empty string) | + +### AI-assistant Configuration +| Name | Description | Defaults | +|---------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------| +| BASEROW\_EMBEDDINGS\_API\_URL | If not empty, the AI-assistant will use this as embedding server for the knowledge base lookup. Must point to a container running this image: https://hub.docker.com/r/baserow/embeddings | "" (empty string) | +| BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL | If not empty, then this model will be used for the AI-assistant. Provide like `groq/openai/gpt-oss-120b` or `bedrock/openai.gpt-oss-120b-1:0`. Note that additional API keys must be provided as environment variable depending on the provider. Instructions can be found at https://baserow.github.io/udspy/ | "" (empty string) | +| AWS\_BEARER\_TOKEN\_BEDROCK | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL=bedrock/bedrock/openai.gpt-oss-120b-1:0, then this environment variable must be set. Instructions on how to obtain: https://docs.aws.amazon.com/bedrock/latest/userguide/api-keys-use.html | "" (empty string) | +| AWS\_REGION\_NAME | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL=groq/openai/gpt-oss-120b, then the AWS region for the AI-assistant can be provided here. | us-east-1 | +| GROQ\_API\_KEY | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL=bedrock/bedrock/openai.gpt-oss-120b-1:0, then the Groq API key can be provided here. | "" (empty string) | +| OLLAMA\_API\_KEY | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL=ollama/gpt-oss:120b, then the Ollama API key can be provided here. 
| "" (empty string) | +| UDSPY\_LM\_OPENAI\_COMPATIBLE\_BASE\_URL | If the BASEROW\_ENTERPRISE\_ASSISTANT\_LLM\_MODEL=openai/gpt-5-nano, then the base URL can be changed here. This can be used to point to a different OpenAI compatible API like Azure, for example. | "" (empty string) | ### Data sync configuration diff --git a/enterprise/backend/src/baserow_enterprise/api/assistant/errors.py b/enterprise/backend/src/baserow_enterprise/api/assistant/errors.py index f53870db51..227c7aec7b 100644 --- a/enterprise/backend/src/baserow_enterprise/api/assistant/errors.py +++ b/enterprise/backend/src/baserow_enterprise/api/assistant/errors.py @@ -13,7 +13,7 @@ ( "The specified language model is not supported or the provided API key is missing/invalid. " "Ensure you have set the correct provider API key and selected a compatible model in " - "`BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL`. See https://docs.litellm.ai/docs/providers for " + "`BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL`. See https://baserow.github.io/udspy/ for " "supported models, required environment variables, and example configuration." ), ) diff --git a/enterprise/backend/src/baserow_enterprise/api/assistant/serializers.py b/enterprise/backend/src/baserow_enterprise/api/assistant/serializers.py index 437c2df373..4f78764734 100644 --- a/enterprise/backend/src/baserow_enterprise/api/assistant/serializers.py +++ b/enterprise/backend/src/baserow_enterprise/api/assistant/serializers.py @@ -161,6 +161,13 @@ class AiThinkingSerializer(serializers.Serializer): ) +class AiReasoningSerializer(serializers.Serializer): + type = serializers.CharField(default=AssistantMessageType.AI_REASONING) + content = serializers.CharField( + help_text="The reasoning content of the AI message." 
+ ) + + +class AiNavigationSerializer(serializers.Serializer): type = serializers.CharField(default=AssistantMessageType.AI_NAVIGATION) location = serializers.DictField(help_text=("The location to navigate to.")) @@ -192,7 +199,8 @@ class HumanMessageSerializer(serializers.Serializer): AssistantMessageType.CHAT_TITLE: ChatTitleMessageSerializer, AssistantMessageType.HUMAN: HumanMessageSerializer, AssistantMessageType.AI_MESSAGE: AiMessageSerializer, - AssistantMessageType.AI_THINKING: AiThinkingSerializer, + AssistantMessageType.AI_THINKING: AiThinkingSerializer, # Update the status bar in the UI + AssistantMessageType.AI_REASONING: AiReasoningSerializer, # Show reasoning steps before the final answer AssistantMessageType.AI_NAVIGATION: AiNavigationSerializer, AssistantMessageType.AI_ERROR: AiErrorMessageSerializer, } diff --git a/enterprise/backend/src/baserow_enterprise/assistant/adapter.py b/enterprise/backend/src/baserow_enterprise/assistant/adapter.py deleted file mode 100644 index c6c2ed321b..0000000000 --- a/enterprise/backend/src/baserow_enterprise/assistant/adapter.py +++ /dev/null @@ -1,21 +0,0 @@ -from .prompts import ASSISTANT_SYSTEM_PROMPT - - -def get_chat_adapter(): - import dspy # local import to save memory when not used - - class ChatAdapter(dspy.ChatAdapter): - def format_field_description(self, signature: type[dspy.Signature]) -> str: - """ - This is the first part of the prompt the LLM sees, so we prepend our custom - system prompt to it to give it the personality and context of Baserow. 
- """ - - field_description = super().format_field_description(signature) - return ( - ASSISTANT_SYSTEM_PROMPT - + "## TASK INSTRUCTIONS:\n\n" - + field_description - ) - - return ChatAdapter() diff --git a/enterprise/backend/src/baserow_enterprise/assistant/assistant.py b/enterprise/backend/src/baserow_enterprise/assistant/assistant.py index 50930f180c..c90f8a46f9 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/assistant.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/assistant.py @@ -5,23 +5,26 @@ from django.conf import settings from django.utils import translation +import udspy +from udspy.callback import BaseCallback + from baserow.api.sessions import get_client_undo_redo_action_group_id from baserow_enterprise.assistant.exceptions import AssistantModelNotSupportedError from baserow_enterprise.assistant.tools.navigation.types import AnyNavigationRequestType from baserow_enterprise.assistant.tools.navigation.utils import unsafe_navigate_to from baserow_enterprise.assistant.tools.registries import assistant_tool_registry -from .adapter import get_chat_adapter from .models import AssistantChat, AssistantChatMessage, AssistantChatPrediction +from .prompts import ASSISTANT_SYSTEM_PROMPT from .types import ( AiMessage, AiMessageChunk, AiNavigationMessage, + AiReasoningChunk, AiThinkingMessage, AssistantMessageUnion, ChatTitleMessage, HumanMessage, - UIContext, ) @@ -36,95 +39,93 @@ class AssistantMessagePair(TypedDict): answer: str -def get_assistant_callbacks(): - from dspy.utils.callback import BaseCallback - - class AssistantCallbacks(BaseCallback): - def __init__(self): - self.tool_calls = {} - self.sources = [] +class AssistantCallbacks(BaseCallback): + def __init__(self, tool_helpers: ToolHelpers | None = None): + self.tool_helpers = tool_helpers + self.tool_calls = {} + self.sources = [] - def extend_sources(self, sources: list[str]) -> None: - """ - Extends the current list of sources with new ones, avoiding duplicates. 
+ def extend_sources(self, sources: list[str]) -> None: + """ + Extends the current list of sources with new ones, avoiding duplicates. - :param sources: The list of new source URLs to add. - :return: None - """ + :param sources: The list of new source URLs to add. + :return: None + """ - self.sources.extend([s for s in sources if s not in self.sources]) + self.sources.extend([s for s in sources if s not in self.sources]) - def on_tool_start( - self, - call_id: str, - instance: Any, - inputs: dict[str, Any], - ) -> None: - """ - Called when a tool starts. It records the tool call and invokes the - corresponding tool's on_tool_start method if it exists. + def on_tool_start( + self, + call_id: str, + instance: Any, + inputs: dict[str, Any], + ) -> None: + """ + Called when a tool starts. It records the tool call and invokes the + corresponding tool's on_tool_start method if it exists. - :param call_id: The unique identifier of the tool call. - :param instance: The instance of the tool being called. - :param inputs: The inputs provided to the tool. - """ + :param call_id: The unique identifier of the tool call. + :param instance: The instance of the tool being called. + :param inputs: The inputs provided to the tool. + """ - try: - assistant_tool_registry.get(instance.name).on_tool_start( - call_id, instance, inputs - ) - self.tool_calls[call_id] = (instance, inputs) - except assistant_tool_registry.does_not_exist_exception_class: - pass - - def on_tool_end( - self, - call_id: str, - outputs: dict[str, Any] | None, - exception: Exception | None = None, - ) -> None: - """ - Called when a tool ends. It invokes the corresponding tool's on_tool_end - method if it exists and updates the sources if the tool produced any. - - :param call_id: The unique identifier of the tool call. - :param outputs: The outputs returned by the tool, or None if there was an - exception. - :param exception: The exception raised by the tool, or None if it was - successful. 
- """ + try: + assistant_tool_registry.get(instance.name).on_tool_start( + call_id, instance, inputs + ) + self.tool_calls[call_id] = (instance, inputs) + except assistant_tool_registry.does_not_exist_exception_class: + pass - if call_id not in self.tool_calls: - return + def on_tool_end( + self, + call_id: str, + outputs: dict[str, Any] | None, + exception: Exception | None = None, + ) -> None: + """ + Called when a tool ends. It invokes the corresponding tool's on_tool_end + method if it exists and updates the sources if the tool produced any. + + :param call_id: The unique identifier of the tool call. + :param outputs: The outputs returned by the tool, or None if there was an + exception. + :param exception: The exception raised by the tool, or None if it was + successful. + """ - instance, inputs = self.tool_calls.pop(call_id) - assistant_tool_registry.get(instance.name).on_tool_end( - call_id, instance, inputs, outputs, exception - ) + if call_id not in self.tool_calls: + return - # If the tool produced sources, add them to the overall list of sources. - if isinstance(outputs, dict) and "sources" in outputs: - self.extend_sources(outputs["sources"]) + instance, inputs = self.tool_calls.pop(call_id) + assistant_tool_registry.get(instance.name).on_tool_end( + call_id, instance, inputs, outputs, exception + ) - return AssistantCallbacks() + if exception is not None and self.tool_helpers is not None: + self.tool_helpers.update_status( + f"Calling the {instance.name} tool encountered an error." + ) + # If the tool produced sources, add them to the overall list of sources. 
+ if isinstance(outputs, dict) and "sources" in outputs: + self.extend_sources(outputs["sources"]) -def get_chat_signature(): - import dspy # local import to save memory when not used - class ChatSignature(dspy.Signature): - question: str = dspy.InputField() - history: dspy.History = dspy.InputField() - ui_context: UIContext | None = dspy.InputField( - default=None, - desc=( - "The frontend UI content the user is currently in. " - "Whenever make sense, use it to ground your answer." - ), - ) - answer: str = dspy.OutputField() +class ChatSignature(udspy.Signature): + __doc__ = f"{ASSISTANT_SYSTEM_PROMPT}\n TASK INSTRUCTIONS: \n" - return ChatSignature + question: str = udspy.InputField() + ui_context: dict[str, Any] | None = udspy.InputField( + default=None, + desc=( + "The context the user is currently in. " + "It contains information about the user, the workspace, open table, view, etc." + "Whenever make sense, use it to ground your answer." + ), + ) + answer: str = udspy.OutputField() class Assistant: @@ -137,25 +138,16 @@ def __init__(self, chat: AssistantChat): self._init_assistant() def _init_lm_client(self): - import dspy # local import to save memory when not used - lm_model = settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL - - self._lm_client = dspy.LM( - model=lm_model, - cache=not settings.DEBUG, - max_retries=5, - max_tokens=32000, - ) + self._lm_client = udspy.LM(model=lm_model) def _init_assistant(self): - from .react import ReAct # local import to save memory when not used - - tool_helpers = self.get_tool_helpers() + self.tool_helpers = self.get_tool_helpers() tools = assistant_tool_registry.list_all_usable_tools( - self._user, self._workspace, tool_helpers + self._user, self._workspace, self.tool_helpers ) - self._assistant = ReAct(get_chat_signature(), tools=tools) + self.callbacks = AssistantCallbacks(self.tool_helpers) + self._assistant = udspy.ReAct(ChatSignature, tools=tools, max_iters=20) self.history = None async def acreate_chat_message( @@ 
-231,9 +223,9 @@ def list_chat_messages( ) return list(reversed(messages)) - async def aload_chat_history(self, limit=20): + async def aload_chat_history(self, limit=30): """ - Loads the chat history into a dspy.History object. It only loads complete + Loads the chat history into a udspy.History object. It only loads complete message pairs (human + AI). The history will be in chronological order and must respect the module signature (question, answer). @@ -241,14 +233,13 @@ async def aload_chat_history(self, limit=20): :return: None """ - import dspy # local import to save memory when not used - last_saved_messages: list[AssistantChatMessage] = [ msg async for msg in self._chat.messages.order_by("-created_on")[:limit] ] - messages = [] + self.history = udspy.History() while len(last_saved_messages) >= 2: + # Pop the oldest message pair to respect chronological order. first_message = last_saved_messages.pop() next_message = last_saved_messages[-1] if ( @@ -257,43 +248,20 @@ async def aload_chat_history(self, limit=20): ): continue - human_question = first_message + self.history.add_user_message(first_message.content) ai_answer = last_saved_messages.pop() - messages.append( - AssistantMessagePair( - question=human_question.content, - answer=ai_answer.content, - ) - ) - - self.history = dspy.History(messages=messages) + self.history.add_assistant_message(ai_answer.content) @lru_cache(maxsize=1) def check_llm_ready_or_raise(self): - import dspy # local import to save memory when not used - from litellm import get_supported_openai_params - - lm = self._lm_client - params = get_supported_openai_params(lm.model) - if params is None or "tools" not in params: - raise AssistantModelNotSupportedError( - f"The model '{lm.model}' is not supported or could not be found. " - "Please make sure the model name is correct, it can use tools, " - "and that your API key has access to it." 
- ) - try: - with dspy.context(lm=lm): - lm("Say ok if you can read this.") + self._lm_client("Say ok if you can read this.") except Exception as e: raise AssistantModelNotSupportedError( - f"The model '{lm.model}' is not supported or accessible: {e}" + f"The model '{self._lm_client.model}' is not supported or accessible: {e}" ) def get_tool_helpers(self) -> ToolHelpers: - from dspy.dsp.utils.settings import settings as dspy_settings - from dspy.streaming.messages import sync_send_to_stream - def update_status_localized(status: str): """ Sends a localized message to the frontend to update the assistant status. @@ -302,16 +270,62 @@ def update_status_localized(status: str): """ with translation.override(self._user.profile.language): - stream = dspy_settings.send_stream - - if stream is not None: - sync_send_to_stream(stream, AiThinkingMessage(content=status)) + udspy.emit_event(AiThinkingMessage(content=status)) return ToolHelpers( update_status=update_status_localized, navigate_to=unsafe_navigate_to, ) + async def _generate_chat_title( + self, user_message: HumanMessage, ai_msg: AiMessage + ) -> str: + """ + Generates a title for the chat based on the user message and AI response. 
+ """ + + title_generator = udspy.Predict( + udspy.Signature.from_string( + "user_message, ai_response -> chat_title", + "Create a short title for the following chat conversation.", + ) + ) + rsp = await title_generator.aforward( + user_message=user_message.content, + ai_response=ai_msg.content[:300], + ) + return rsp.chat_title + + async def _acreate_ai_message_response( + self, + human_msg: HumanMessage, + final_prediction: udspy.Prediction, + sources: list[str], + ) -> AiMessage: + ai_msg = await self.acreate_chat_message( + AssistantChatMessage.Role.AI, + final_prediction.answer, + artifacts={"sources": sources}, + action_group_id=get_client_undo_redo_action_group_id(self._user), + ) + await AssistantChatPrediction.objects.acreate( + human_message=human_msg, + ai_response=ai_msg, + prediction={ + "model": self._lm_client.model, + "trajectory": final_prediction.trajectory, + "reasoning": final_prediction.reasoning, + }, + ) + + # Yield final complete message + return AiMessage( + id=ai_msg.id, + content=final_prediction.answer, + sources=sources, + can_submit_feedback=True, + ) + async def astream_messages( self, human_message: HumanMessage ) -> AsyncGenerator[AssistantMessageUnion, None]: @@ -322,86 +336,65 @@ async def astream_messages( :return: An async generator that yields the response messages. 
""" - import dspy # local import to save memory when not used - from dspy.primitives.prediction import Prediction - from dspy.streaming import StreamListener, StreamResponse - - callback_manager = get_assistant_callbacks() - - with dspy.context( + with udspy.settings.context( lm=self._lm_client, - cache=not settings.DEBUG, - callbacks=[*dspy.settings.config.callbacks, callback_manager], - adapter=get_chat_adapter(), + callbacks=[*udspy.settings.callbacks, self.callbacks], ): if self.history is None: await self.aload_chat_history() - # Follow the stream of all output fields - stream_listeners = [ - StreamListener(signature_field_name="answer"), - ] + user_question = human_message.content + if self.history.messages: # Enhance question context based on chat history + predictor = udspy.Predict("question, context -> enhanced_question") + user_question = ( + await predictor.aforward( + question=user_question, context=self.history.messages + ) + ).enhanced_question - stream_predict = dspy.streamify( - self._assistant, - stream_listeners=stream_listeners, - ) - output_stream = stream_predict( - history=self.history, - question=human_message.content, - ui_context=human_message.ui_context.model_dump_json( - exclude_none=True, indent=2 - ), + output_stream = self._assistant.astream( + question=user_question, + ui_context=human_message.ui_context.model_dump_json(exclude_none=True), ) human_msg = await self.acreate_chat_message( AssistantChatMessage.Role.HUMAN, human_message.content ) - answer = "" - async for stream_chunk in output_stream: - if isinstance(stream_chunk, StreamResponse): - # Accumulate chunks per field to deliver full, real‐time updates. 
- if stream_chunk.signature_field_name == "answer": - answer += stream_chunk.chunk + stream_reasoning = False + async for event in output_stream: + if isinstance(event, (AiThinkingMessage, AiNavigationMessage)): + # Start streaming reasoning from now on, since we are calling tools + # and updating the UI status + stream_reasoning = True + yield event + continue + + if isinstance(event, udspy.OutputStreamChunk): + # Stream the final answer chunks + if event.field_name == "answer": yield AiMessageChunk( - content=answer, sources=callback_manager.sources + content=event.content, + sources=self.callbacks.sources, ) - elif isinstance(stream_chunk, (AiThinkingMessage, AiNavigationMessage)): - # forward thinking/navigation messages as-is to the frontend - yield stream_chunk - elif isinstance(stream_chunk, Prediction): - # At the end of the prediction, save the AI message and the - # prediction details for future analysis and feedback. - ai_msg = await self.acreate_chat_message( - AssistantChatMessage.Role.AI, - answer, - artifacts={"sources": callback_manager.sources}, - action_group_id=get_client_undo_redo_action_group_id( - self._user - ), - ) - await AssistantChatPrediction.objects.acreate( - human_message=human_msg, - ai_response=ai_msg, - prediction={ - "model": self._lm_client.model, - "trajectory": stream_chunk.trajectory, - "reasoning": stream_chunk.reasoning, - }, - ) - # In case the streaming didn't work, make sure we yield at least one - # final message with the complete answer. 
- yield AiMessage( - id=ai_msg.id, - content=stream_chunk.answer, - sources=callback_manager.sources, - can_submit_feedback=True, - ) + continue + + if isinstance(event, udspy.Prediction): + if "next_thought" in event and stream_reasoning: + yield AiReasoningChunk(content=event.next_thought) - if not self._chat.title: - title_generator = dspy.Predict("question -> chat_title") - rsp = await title_generator.acall(question=human_message.content) - self._chat.title = rsp.chat_title - yield ChatTitleMessage(content=self._chat.title) - await self._chat.asave(update_fields=["title", "updated_on"]) + elif event.module is self._assistant: + ai_msg = await self._acreate_ai_message_response( + human_msg, event, self.callbacks.sources + ) + yield ai_msg + + if not self._chat.title: + chat_title = await self._generate_chat_title( + human_message, ai_msg + ) + yield ChatTitleMessage(content=chat_title) + self._chat.title = chat_title + await self._chat.asave( + update_fields=["title", "updated_on"] + ) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/prompts.py b/enterprise/backend/src/baserow_enterprise/assistant/prompts.py index 5158931a4a..0c4f0eaa54 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/prompts.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/prompts.py @@ -16,11 +16,14 @@ DATABASE_BUILDER_CONCEPTS = """ ### DATABASE BUILDER (no-code database) -**Structure**: Database → Tables → Fields + Rows + Views + Webhooks +**Structure**: Database → Tables → Fields + Views + Webhooks + Rows. Rows → comments. **Key concepts**: • **Fields**: Define schema (30+ types including link_row for relationships); one primary field per table • **Views**: Present data with filters/sorts/grouping/colors; can be shared, personal, or public +• **Rows**: Data records following the table schema; support for rich content (files, long text, formulas, numbers, dates, etc.). Changes are tracked in history. 
+• **Comments**: Threaded discussions on rows; mentions. +• **Formulas**: Computed fields using functions/operators; support for cross-table lookups • **Permissions**: RBAC at workspace/database/table/field levels; database tokens for API • **Data sync**: Table replication; **Webhooks**: Row/field/view event triggers """ @@ -33,6 +36,7 @@ **Key concepts**: • **Pages**: Routes with UI elements (buttons, tables, forms, etc.) • **Data Sources**: Connect to database tables/views; elements bind to them for dynamic content +• **Formulas**: Reference data from previous nodes and compute values using functions/operators in nodes attributes • **Workflows**: Event-driven actions (create/update rows, navigate, notifications) • **Publishing**: Requires domain configuration """ @@ -46,6 +50,8 @@ • **Trigger**: The single event that starts the workflow (e.g., row created/updated/deleted) • **Actions**: Tasks performed (e.g., create/update rows, send emails, call webhooks) • **Routers**: Conditional logic (if/else, switch) to control flow +• **Iterators**: Loop over lists of items +• **Formulas**: Reference data from previous nodes and compute values using functions/operators in nodes attributes • **Execution**: Runs in the background; monitor via logs • **History**: Track runs, successes, failures • **Publishing**: Requires at least one configured action @@ -66,13 +72,14 @@ • Be clear, concise, and actionable • For troubleshooting: ask for error messages or describe expected vs actual results • **NEVER** fabricate answers or URLs. Acknowledge when you can't be sure. -• When you have the tools to help, **ALWAYS** use them instead of answering with instructions. +• Use the tools whenever possible. Fallback to search_docs and provide instruction only when it's not possible to fulfill the request. Ground answers in the documentation. • When finished, briefly suggest one or more logical next steps only if they use tools you have access to and directly builds on what was just done. 
## FORMATTING (CRITICAL) -• **No HTML**: Only Markdown (bold, italics, lists, code, tables) -• Prefer lists when possible. Numbered lists for steps; bulleted for others -• NEVER use tables. Use lists instead. +• Only use Markdown (bold, italics, lists, code blocks) +• Prefer lists in explanations. Numbered lists for steps; bulleted for others +• Use code blocks for examples, commands, snippets +• EXCEPTION: When showing database schema or query results, tables are acceptable ## BASEROW CONCEPTS """ diff --git a/enterprise/backend/src/baserow_enterprise/assistant/react.py b/enterprise/backend/src/baserow_enterprise/assistant/react.py deleted file mode 100644 index 08ecef15de..0000000000 --- a/enterprise/backend/src/baserow_enterprise/assistant/react.py +++ /dev/null @@ -1,271 +0,0 @@ -from typing import Any, Callable, Literal - -# This file is only imported when the assistant/react.py module is used, -# so the flake8 plugin will not complain about global imports of dspy and litell -import dspy # noqa: BAI001 -from dspy.adapters.types.tool import Tool # noqa: BAI001 -from dspy.predict.react import _fmt_exc # noqa: BAI001 -from dspy.primitives.module import Module # noqa: BAI001 -from dspy.signatures.signature import ensure_signature # noqa: BAI001 -from litellm import ContextWindowExceededError # noqa: BAI001 -from loguru import logger - -from .types import ToolsUpgradeResponse - -# Variant of dspy.predict.react.ReAct that accepts a "meta-tool": -# a callable that can produce tools at runtime (e.g. per-table schemas). -# This lets a single ReAct instance handle many different table signatures -# without creating a new agent for each request. - - -class ReAct(Module): - def __init__(self, signature, tools: list[Callable], max_iters: int = 100): - """ - ReAct stands for "Reasoning and Acting," a popular paradigm for building - tool-using agents. 
In this approach, the language model is iteratively provided - with a list of tools and has to reason about the current situation. The model - decides whether to call a tool to gather more information or to finish the task - based on its reasoning process. The DSPy version of ReAct is generalized to work - over any signature, thanks to signature polymorphism. - - Args: - signature: The signature of the module, which defines the input and output - of the react module. tools (list[Callable]): A list of functions, callable - objects, or `dspy.Tool` instances. max_iters (Optional[int]): The maximum - number of iterations to run. Defaults to 10. - - Example: - - ```python def get_weather(city: str) -> str: - return f"The weather in {city} is sunny." - - react = dspy.ReAct(signature="question->answer", tools=[get_weather]) pred = - react(question="What is the weather in Tokyo?") - """ - - super().__init__() - self.signature = signature = ensure_signature(signature) - self.max_iters = max_iters - - tools = [t if isinstance(t, Tool) else Tool(t) for t in tools] - tools = {tool.name: tool for tool in tools} - outputs = ", ".join([f"`{k}`" for k in signature.output_fields.keys()]) - - tools["finish"] = Tool( - func=lambda: "Completed.", - name="finish", - desc=f"Marks the task as complete. That is, signals that all information for producing the outputs, i.e. {outputs}, are now available to be extracted.", - args={}, - ) - - self.tools = tools - self.react = self._build_react_module() - self.extract = self._build_fallback_module() - - def _build_instructions(self) -> list[str]: - signature = self.signature - inputs = ", ".join([f"`{k}`" for k in signature.input_fields.keys()]) - outputs = ", ".join([f"`{k}`" for k in signature.output_fields.keys()]) - instr = [f"{signature.instructions}\n"] if signature.instructions else [] - - instr.extend( - [ - f"You are an Agent. In each episode, you will be given the fields {inputs} as input. 
And you can see your past trajectory so far.", - f"Your goal is to use one or more of the supplied tools to collect any necessary information for producing {outputs}.\n", - "To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, and also when finishing the task.", - "After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n", - "When writing next_thought, you may reason about the current situation and plan for future steps.", - "When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n", - "Always DO the task with tools, never EXPLAIN how to do it. Return instructions only when you lack the necessary tools to complete the request.\n", - "Never assume a tool cannot be used based on your prior knowledge. If a tool exists that can help you, you MUST use it.\n", - "If you create new resources outside of your current visible context, like tables, views, fields or rows, you can navigate to them using the navigation tool.\n", - ] - ) - - for idx, tool in enumerate(self.tools.values()): - instr.append(f"({idx + 1}) {tool}") - instr.append( - "When providing `next_tool_args`, the value inside the field must be in JSON format" - ) - return instr - - def _build_react_module(self) -> type[Module]: - instructions = self._build_instructions() - react_signature = ( - dspy.Signature({**self.signature.input_fields}, "\n".join(instructions)) - .append("trajectory", dspy.InputField(), type_=str) - .append("next_thought", dspy.OutputField(), type_=str) - .append( - "next_tool_name", - dspy.OutputField(), - type_=Literal[tuple(self.tools.keys())], - ) - .append("next_tool_args", dspy.OutputField(), type_=dict[str, Any]) - ) - - return dspy.Predict(react_signature) - - def _build_fallback_module(self) -> type[Module]: - signature = self.signature - fallback_signature = dspy.Signature( - {**signature.input_fields, **signature.output_fields}, - signature.instructions, - 
).append("trajectory", dspy.InputField(), type_=str) - return dspy.ChainOfThought(fallback_signature) - - def _format_trajectory(self, trajectory: dict[str, Any]): - adapter = dspy.settings.adapter or dspy.ChatAdapter() - trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x") - return adapter.format_user_message_content(trajectory_signature, trajectory) - - def forward(self, **input_args): - trajectory = {} - max_iters = input_args.pop("max_iters", self.max_iters) - for idx in range(max_iters): - try: - pred = self._call_with_potential_trajectory_truncation( - self.react, trajectory, **input_args - ) - except ValueError as err: - logger.warning( - f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}" - ) - break - - trajectory[f"thought_{idx}"] = pred.next_thought - trajectory[f"tool_name_{idx}"] = pred.next_tool_name - trajectory[f"tool_args_{idx}"] = pred.next_tool_args - - try: - result = self.tools[pred.next_tool_name](**pred.next_tool_args) - - # This is how meta tools return multiple tools, the first argument is - # the actual observation, the rest are new tools to add. Once we have - # add them, we need to rebuild the react module to include them. - # NOTE: tools will remain available for the rest of the trajectory, - # but won't be available in the next call to the agent. 
- if isinstance(result, ToolsUpgradeResponse): - new_tools = result.new_tools - observation = result.observation - for new_tool in new_tools: - if not isinstance(new_tool, Tool): - new_tool = Tool(new_tool) - self.tools[new_tool.name] = new_tool - self.react = self._build_react_module() - else: - observation = result - - trajectory[f"observation_{idx}"] = observation - except Exception as err: - trajectory[ - f"observation_{idx}" - ] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" - - if pred.next_tool_name == "finish": - break - - extract = self._call_with_potential_trajectory_truncation( - self.extract, trajectory, **input_args - ) - return dspy.Prediction(trajectory=trajectory, **extract) - - async def aforward(self, **input_args): - trajectory = {} - max_iters = input_args.pop("max_iters", self.max_iters) - for idx in range(max_iters): - try: - pred = await self._async_call_with_potential_trajectory_truncation( - self.react, trajectory, **input_args - ) - except ValueError as err: - logger.warning( - f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}" - ) - break - - trajectory[f"thought_{idx}"] = pred.next_thought - trajectory[f"tool_name_{idx}"] = pred.next_tool_name - trajectory[f"tool_args_{idx}"] = pred.next_tool_args - - try: - observation = await self.tools[pred.next_tool_name]( - **pred.next_tool_args - ) - - # This is how meta tools return multiple tools, the first argument is - # the actual observation, the rest are new tools to add. Once we have - # add them, we need to rebuild the react module to include them. - # NOTE: tools will remain available for the rest of the trajectory, - # but won't be available in the next call to the agent. 
- if isinstance(observation, (list, tuple)): - for new_tool in observation[1:]: - if not isinstance(new_tool, Tool): - new_tool = Tool(new_tool) - self.tools[new_tool.name] = new_tool - self.react = self._build_react_module() - - observation = observation[0] - - trajectory[f"observation_{idx}"] = observation - except Exception as err: - trajectory[ - f"observation_{idx}" - ] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" - - if pred.next_tool_name == "finish": - break - - extract = await self._async_call_with_potential_trajectory_truncation( - self.extract, trajectory, **input_args - ) - return dspy.Prediction(trajectory=trajectory, **extract) - - def _call_with_potential_trajectory_truncation( - self, module, trajectory, **input_args - ): - for _ in range(3): - try: - return module( - **input_args, - trajectory=self._format_trajectory(trajectory), - ) - except ContextWindowExceededError: - logger.warning( - "Trajectory exceeded the context window, truncating the oldest tool call information." - ) - trajectory = self.truncate_trajectory(trajectory) - - async def _async_call_with_potential_trajectory_truncation( - self, module, trajectory, **input_args - ): - for _ in range(3): - try: - return await module.acall( - **input_args, - trajectory=self._format_trajectory(trajectory), - ) - except ContextWindowExceededError: - logger.warning( - "Trajectory exceeded the context window, truncating the oldest tool call information." - ) - trajectory = self.truncate_trajectory(trajectory) - - def truncate_trajectory(self, trajectory): - """Truncates the trajectory so that it fits in the context window. - - Users can override this method to implement their own truncation logic. - """ - - keys = list(trajectory.keys()) - if len(keys) < 4: - # Every tool call has 4 keys: thought, tool_name, tool_args, and - # observation. 
- raise ValueError( - "The trajectory is too long so your prompt exceeded the context window, but the " - "trajectory cannot be truncated because it only has one tool call." - ) - - for key in keys[:4]: - trajectory.pop(key) - - return trajectory diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/utils.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/utils.py index db8cca691e..2a2dcac63a 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/utils.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/automation/utils.py @@ -5,6 +5,7 @@ from django.db import transaction from django.utils.translation import gettext as _ +import udspy from loguru import logger from pydantic import ConfigDict @@ -166,27 +167,25 @@ def __getitem__(self, key) -> any: def get_generate_formulas_tool(): - import dspy - - class RuntimeFormulaGenerator(dspy.Signature): + class RuntimeFormulaGenerator(udspy.Signature): __doc__ = GENERATE_FORMULA_PROMPT - fields_to_resolve: dict[str, dict[str, str]] = dspy.InputField( + fields_to_resolve: dict[str, dict[str, str]] = udspy.InputField( desc=( "The fields that need formulas to be generated. " "If prefixed with [optional], the field is not mandatory." ) ) - context: dict[str, Any] = dspy.InputField( + context: dict[str, Any] = udspy.InputField( desc="The available context to use in formula generation composed of previous nodes results." ) - context_metadata: dict[str, Any] = dspy.InputField( + context_metadata: dict[str, Any] = udspy.InputField( desc="Metadata about the context fields, with refs and names to assist in formula generation." ) - feedback: str = dspy.InputField( + feedback: str = udspy.InputField( desc="Validation errors from previous attempt. Empty if first attempt." 
) - generated_formulas: dict[str, Any] = dspy.OutputField() + generated_formulas: dict[str, Any] = udspy.OutputField() model_config = ConfigDict(arbitrary_types_allowed=True) @@ -213,7 +212,7 @@ def generate_node_formulas( that fulfills the request, using the provided context object. """ - predict = dspy.Predict(RuntimeFormulaGenerator) + predict = udspy.Predict(RuntimeFormulaGenerator) feedback = "" for __ in range(max_retries): result = predict( diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py index dc73d12066..3d3cfc53ec 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py @@ -4,6 +4,7 @@ from django.db import transaction from django.utils.translation import gettext as _ +import udspy from baserow_premium.prompts import get_formula_docs from loguru import logger from pydantic import create_model @@ -20,30 +21,24 @@ from baserow.contrib.database.table.actions import CreateTableActionType from baserow.contrib.database.views.actions import ( CreateViewActionType, - CreateViewFilterActionType, UpdateViewFieldOptionsActionType, ) from baserow.contrib.database.views.handler import ViewHandler from baserow.core.models import Workspace from baserow.core.service import CoreService from baserow_enterprise.assistant.tools.registries import AssistantToolType -from baserow_enterprise.assistant.types import ( - TableNavigationType, - ToolsUpgradeResponse, - ViewNavigationType, - get_tool_signature, -) +from baserow_enterprise.assistant.types import TableNavigationType, ViewNavigationType from . 
import utils from .types import ( AnyFieldItem, AnyFieldItemCreate, AnyViewFilterItem, - AnyViewFilterItemCreate, AnyViewItemCreate, BaseTableItem, ListTablesFilterArg, TableItemCreate, + ViewFiltersArgs, view_item_registry, ) @@ -197,8 +192,6 @@ def create_tables( - if add_sample_rows is True (default), add some example rows to each table """ - import dspy # local import to save memory when not used - nonlocal user, workspace, tool_helpers if not tables: @@ -255,37 +248,26 @@ def create_tables( ) if add_sample_rows: - tools = {} instructions = [] tool_helpers.update_status( _("Preparing example rows for these new tables...") ) + tools = [] for table, created_table in zip(tables, created_tables): create_rows_tool = utils.get_table_rows_tools( user, workspace, tool_helpers, created_table )["create"] - tools[create_rows_tool.name] = create_rows_tool + tools.append(create_rows_tool) instructions.append( - f"- Create 5 example rows for table_{created_table.id}. Fill every relationship with valid data when possible." + f"- Create 5 example rows with realistic data for {created_table.name} (Id: {created_table.id}). " + "Fill every relationship with valid data when possible." ) - predictor = dspy.Predict(get_tool_signature()) - result = predictor( - question=("\n".join(instructions)), - tools=list(tools.values()), + predictor = udspy.ReAct( + "instructions -> result", tools=tools, max_iters=len(tables * 2) ) - for call in result.outputs.tool_calls: - with transaction.atomic(): - try: - result = tools[call.name](**call.args) - notes.append( - f"Rows created for table_{created_table.id}: {result}" - ) - except Exception as e: - notes.append( - f"Error creating example rows for table_{created_table.id}: {e}\n." - f"Please retry recreating rows for table_{created_table.id} manually." - ) + result = predictor(instructions=("\n".join(instructions))) + notes.append(result) return { "created_tables": [ @@ -433,43 +415,47 @@ def get_rows_tools( create/update/delete rows. 
""" - nonlocal user, workspace, tool_helpers + @udspy.module_callback + def load_rows_tools(context): + nonlocal user, workspace, tool_helpers - observation = ["New tools are now available.\n"] + observation = ["New tools are now available.\n"] - new_tools = [] - tables = utils.filter_tables(user, workspace).filter(id__in=table_ids) - for table in tables: - table_tools = utils.get_table_rows_tools( - user, workspace, tool_helpers, table - ) + new_tools = [] + tables = utils.filter_tables(user, workspace).filter(id__in=table_ids) + for table in tables: + table_tools = utils.get_table_rows_tools( + user, workspace, tool_helpers, table + ) - observation.append(f"Table '{table.name}' (ID: {table.id}):") + observation.append(f"Table '{table.name}' (ID: {table.id}):") - if "create" in operations: - create_rows = table_tools["create"] - new_tools.append(create_rows) - observation.append(f"- Use {create_rows.name} to create new rows.") + if "create" in operations: + create_rows = table_tools["create"] + new_tools.append(create_rows) + observation.append(f"- Use {create_rows.name} to create new rows.") - if "update" in operations: - update_rows = table_tools["update"] - new_tools.append(update_rows) - observation.append( - f"- Use {update_rows.name} to update existing rows by their IDs." - ) + if "update" in operations: + update_rows = table_tools["update"] + new_tools.append(update_rows) + observation.append( + f"- Use {update_rows.name} to update existing rows by their IDs." + ) - if "delete" in operations: - delete_rows = table_tools["delete"] - new_tools.append(delete_rows) - observation.append( - f"- Use {delete_rows.name} to delete rows by their IDs." - ) + if "delete" in operations: + delete_rows = table_tools["delete"] + new_tools.append(delete_rows) + observation.append( + f"- Use {delete_rows.name} to delete rows by their IDs." 
+ ) - observation.append("") + observation.append("") - return ToolsUpgradeResponse( - observation="\n".join(observation), new_tools=new_tools - ) + # Re-initialize the module with the new tools for the next iteration + context.module.init_module(tools=context.module._tools + new_tools) + return "\n".join(observation) + + return load_rows_tools return get_rows_tools @@ -582,6 +568,7 @@ def create_views( field_options = view.field_options_to_django_orm() if field_options: UpdateViewFieldOptionsActionType.do(user, orm_view, field_options) + created_views.append({"id": orm_view.id, **view.model_dump()}) tool_helpers.navigate_to( @@ -619,7 +606,7 @@ def get_create_view_filters_tool( """ def create_view_filters( - view_id: int, filters: list[AnyViewFilterItemCreate] + view_filters: list[ViewFiltersArgs], ) -> list[AnyViewFilterItem]: """ Creates filters in the specified view. @@ -631,45 +618,34 @@ def create_view_filters( nonlocal user, workspace, tool_helpers - if not filters: + if not view_filters: return [] - orm_view = utils.get_view(user, view_id) - tool_helpers.update_status( - _("Creating filters in %(view_name)s...") % {"view_name": orm_view.name} - ) - - fields = {f.id: f for f in orm_view.table.field_set.all()} - - created_filters = [] - with transaction.atomic(): - for filter in filters: - field = fields.get(filter.field_id) - if field is None: - logger.info("Skipping filter creation due to missing field") - continue - field_type = field_type_registry.get_by_model(field.specific_class) - if field_type.type != filter.type: - logger.info("Skipping filter creation due to type mismatch") - continue - - filter_type = filter.get_django_orm_type(field) - filter_value = filter.get_django_orm_value( - field, timezone=user.profile.timezone - ) + created_view_filters = [] + for vf in view_filters: + orm_view = utils.get_view(user, vf.view_id) + tool_helpers.update_status( + _("Creating filters in %(view_name)s...") % {"view_name": orm_view.name} + ) - orm_filter = 
CreateViewFilterActionType.do( - user, - orm_view, - field, - filter_type, - filter_value, - filter_group_id=None, - ) + fields = {f.id: f for f in orm_view.table.field_set.all()} + created_filters = [] + with transaction.atomic(): + for filter in vf.filters: + try: + orm_filter = utils.create_view_filter( + user, orm_view, fields, filter + ) + except ValueError as e: + logger.warning(f"Skipping filter creation: {e}") + continue - created_filters.append({"id": orm_filter.id, **filter.model_dump()}) + created_filters.append({"id": orm_filter.id, **filter.model_dump()}) + created_view_filters.append( + {"view_id": vf.view_id, "filters": created_filters} + ) - return {"created_view_filters": created_filters} + return {"created_view_filters": created_view_filters} return create_view_filters @@ -712,6 +688,45 @@ def get_formula_type(table_id: int, field_name: str, formula: str) -> str: return get_formula_type +class FormulaGenerationSignature(udspy.Signature): + """ + Generates a Baserow formula based on the provided description and table schema. + """ + + description: str = udspy.InputField( + desc="A brief description of what the formula should do." + ) + tables_schema: dict = udspy.InputField( + desc="The schema of all the tables in the database." + ) + formula_documentation: str = udspy.InputField( + desc="Documentation about Baserow formulas and their syntax." + ) + table_id: int = udspy.OutputField( + desc=( + "The ID of the table the formula is intended for. " + "Should be the same as current_table_id, unless the formula can " + "only be created in a different table." + ) + ) + field_name: str = udspy.OutputField( + desc="The name of the formula field to be created. For a new field, it must be unique in the table." + ) + formula: str = udspy.OutputField( + desc="The generated formula. Must be a valid Baserow formula." + ) + formula_type: str = udspy.OutputField( + desc="The type of the generated formula. 
Must be one of: text, long_text, " + "number, boolean, date, link_row, single_select, multiple_select, duration, array." + ) + is_formula_valid: bool = udspy.OutputField( + desc="Whether the generated formula is valid or not." + ) + error_message: str = udspy.OutputField( + desc="If the formula is not valid, an error message explaining why." + ) + + def get_generate_database_formula_tool( user: AbstractUser, workspace: Workspace, @@ -721,53 +736,15 @@ def get_generate_database_formula_tool( Returns a function that generates a formula for a given field in a table. """ - import dspy # local import to save memory when not used - - class FormulaGenerationSignature(dspy.Signature): - """ - Generates a Baserow formula based on the provided description and table schema. - """ - - description: str = dspy.InputField( - desc="A brief description of what the formula should do." - ) - tables_schema: dict = dspy.InputField( - desc="The schema of all the tables in the database." - ) - formula_documentation: str = dspy.InputField( - desc="Documentation about Baserow formulas and their syntax." - ) - table_id: int = dspy.OutputField( - desc=( - "The ID of the table the formula is intended for. " - "Should be the same as current_table_id, unless the formula can " - "only be created in a different table." - ) - ) - field_name: str = dspy.OutputField( - desc="The name of the formula field to be created. For a new field, it must be unique in the table." - ) - formula: str = dspy.OutputField( - desc="The generated formula. Must be a valid Baserow formula." - ) - formula_type: str = dspy.OutputField( - desc="The type of the generated formula. Must be one of: text, long_text, " - "number, boolean, date, link_row, single_select, multiple_select, duration, array." - ) - is_formula_valid: bool = dspy.OutputField( - desc="Whether the generated formula is valid or not." - ) - error_message: str = dspy.OutputField( - desc="If the formula is not valid, an error message explaining why." 
- ) - def generate_database_formula( database_id: int, description: str, save_to_field: bool = True, ) -> dict[str, str]: """ - Generate a database formula for a formula field. + Generate a database formula for a formula field. No need to inspect the schema + before, this tool will do it automatically and find the best table and fields to + use. - table_id: The database ID where the formula field is located. - description: A brief description of what the formula should do. @@ -781,16 +758,18 @@ def generate_database_formula( database_tables = utils.filter_tables(user, workspace).filter( database_id=database_id ) - database_tables_schema = utils.get_tables_schema(database_tables, True) + database_tables_schema = [ + t.model_dump() for t in utils.get_tables_schema(database_tables, True) + ] tool_helpers.update_status(_("Generating formula...")) formula_docs = get_formula_docs() - formula_generator = dspy.ReAct( + formula_generator = udspy.ReAct( FormulaGenerationSignature, tools=[get_formula_type_tool(user, workspace)], - max_iters=10, + max_iters=20, ) result = formula_generator( description=description, @@ -844,6 +823,15 @@ def generate_database_formula( ) operation = "field updated" + tool_helpers.navigate_to( + TableNavigationType( + type="database-table", + database_id=table.database_id, + table_id=table.id, + table_name=table.name, + ) + ) + data.update( { "table_id": table.id, diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py index b840542877..83e0219545 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py @@ -1,15 +1,13 @@ from datetime import date, datetime -from pydantic import Field - from baserow_enterprise.assistant.types import BaseModel # Somehow LLMs struggle with dates class Date(BaseModel): - year: 
int = Field(..., description="year (i.e. 2025).") - month: int = Field(..., description="month (1-12).") - day: int = Field(..., description="day (1-31).") + year: int + month: int + day: int def to_django_orm(self): return date(self.year, self.month, self.day).isoformat() @@ -21,8 +19,8 @@ def from_django_orm(cls, orm_date: date) -> "Date": class Datetime(Date): - hour: int = Field(..., description="hour (0-23).") - minute: int = Field(..., description="minute (0-59).") + hour: int + minute: int def to_django_orm(self): return datetime( diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py index e1654d0d97..9be4a3be13 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py @@ -68,8 +68,8 @@ class BaseLongTextFieldItem(FieldItemCreate): description="Multi-line text field. Ideal for descriptions, notes and long-form content.", ) rich_text: bool = Field( - ..., - description="Whether the long text field supports rich text. Default is True.", + default=True, + description="Whether the long text field supports rich text.", ) def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: @@ -101,12 +101,10 @@ class BaseNumberFieldItem(FieldItemCreate): type: Literal["number"] = Field( ..., description="Numeric field, with decimals and optional prefix/suffix." ) - decimal_places: int = Field( - ..., description="The number of decimal places. Default is 2." - ) + decimal_places: int = Field(default=2, description="The number of decimal places.") suffix: str = Field( - ..., - description="An optional suffix to display after the number. 
Default is empty.", + default="", + description="An optional suffix to display after the number.", ) def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: @@ -140,7 +138,7 @@ class BaseRatingFieldItem(FieldItemCreate): ..., description="Rating field. Ideal for reviews or scores." ) max_value: int = Field( - ..., description="The maximum value of the rating field. Default is 5." + default=5, description="The maximum value of the rating field." ) def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: @@ -182,7 +180,7 @@ class BooleanFieldItem(BaseBooleanFieldItem, FieldItem): class BaseDateFieldItem(FieldItemCreate): type: Literal["date"] = Field(..., description="Date or datetime field.") include_time: bool = Field( - ..., description="Whether the date field includes time. Default is False." + default=False, description="Whether the date field includes time." ) def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: @@ -216,13 +214,6 @@ class BaseLinkRowFieldItem(FieldItemCreate): linked_table: str | int = Field( ..., description="The ID or the name of the table this field links to." ) - has_link_back: bool = Field( - ..., - description="Whether the linked table should also have a link row field back to this table. Default is True.", - ) - multiple: bool = Field( - ..., description="Whether multiple links are allowed. Default is True." - ) def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: if isinstance(self.linked_table, str): @@ -238,12 +229,7 @@ def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: "Ensure you provide a valid table name or ID." 
) - return { - "name": self.name, - "link_row_table": link_row_table, - "link_row_multiple_relationships": self.multiple, - "has_related_field": self.has_link_back and table != link_row_table, - } + return {"name": self.name, "link_row_table": link_row_table} class LinkRowFieldItemCreate(BaseLinkRowFieldItem): @@ -260,8 +246,6 @@ def from_django_orm(cls, orm_field: LinkRowField) -> "BaseLinkRowFieldItem": name=orm_field.name, type="link_row", linked_table=orm_field.link_row_table_id, - multiple=orm_field.link_row_multiple_relationships, - has_link_back=orm_field.link_row_related_field_id is not None, ) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py index b19f742a87..76f5b8711f 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py @@ -555,3 +555,8 @@ class BooleanIsTrueViewFilterItem(BooleanIsViewFilterItemCreate, ViewFilterItem) | MultipleSelectIsAnyViewFilterItem, Field(discriminator="type"), ] + + +class ViewFiltersArgs(BaseModel): + view_id: int + filters: list[AnyViewFilterItemCreate] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py index 458bfcd70a..3395bd782a 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py @@ -18,10 +18,11 @@ class ViewItemCreate(BaseModel): name: str = Field( ..., - description="A sensible name for the view (i.e. 'All tasks', 'Completed tasks', etc.).", + description="A sensible name for the view (i.e. 
'Pending payments', 'Completed tasks', etc.).", ) public: bool = Field( - ..., description="Whether the view is publicly accessible. Default is False." + ..., + description="Whether the view is publicly accessible. False unless specified.", ) def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: @@ -61,7 +62,7 @@ class GridFieldOption(BaseModel): class GridViewItemCreate(ViewItemCreate): type: Literal["grid"] = Field(..., description="A grid view.") row_height: Literal["small", "medium", "large"] = Field( - ..., + default="small", description=( "The height of the rows in the view. Can be 'small', 'medium' or 'large'. Default is 'small'." ), diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py index 94cd199cfe..c7d13c8a06 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py @@ -2,12 +2,15 @@ from itertools import groupby from typing import TYPE_CHECKING, Any, Callable, Literal, Type, Union +from django.contrib.auth.models import AbstractUser from django.core.exceptions import ValidationError from django.db import transaction from django.db.models import Q, QuerySet from django.utils.translation import gettext as _ +import udspy from pydantic import ConfigDict, Field, create_model +from udspy.utils import minimize_schema, resolve_json_schema_reference from baserow.contrib.database.fields.actions import CreateFieldActionType from baserow.contrib.database.fields.field_types import LinkRowFieldType @@ -25,7 +28,9 @@ GeneratedTableModel, Table, ) +from baserow.contrib.database.views.actions import CreateViewFilterActionType from baserow.contrib.database.views.handler import ViewHandler +from baserow.contrib.database.views.models import View, ViewFilter from baserow.core.db import specific_iterator from baserow.core.models import 
Workspace from baserow_enterprise.assistant.tools.database.types.table import ( @@ -36,6 +41,7 @@ from .types import ( AnyFieldItem, AnyFieldItemCreate, + AnyViewFilterItemCreate, BaseModel, Date, Datetime, @@ -48,11 +54,13 @@ NoChange = Literal["__NO_CHANGE__"] -def filter_tables(user, workspace: Workspace) -> QuerySet[Table]: +def filter_tables(user: AbstractUser, workspace: Workspace) -> QuerySet[Table]: return TableHandler().list_workspace_tables(user, workspace) -def list_tables(user, workspace: Workspace, database_id: int) -> list[BaseTableItem]: +def list_tables( + user: AbstractUser, workspace: Workspace, database_id: int +) -> list[BaseTableItem]: tables_qs = filter_tables(user, workspace).filter(database_id=database_id) return [BaseTableItem(id=table.id, name=table.name) for table in tables_qs] @@ -108,7 +116,7 @@ def get_tables_schema( def create_fields( - user, + user: AbstractUser, table: Table, field_items: list[AnyFieldItemCreate], tool_helpers: "ToolHelpers", @@ -177,14 +185,14 @@ def _get_pydantic_field_definition( return FieldDefinition( Datetime | None, Field(..., description="Datetime or None", title=orm_field.name), - lambda v: v.to_django_orm() if v is not None else None, + lambda v: v.to_django_orm() if v else None, lambda v: Datetime.from_django_orm(v) if v is not None else None, ) else: return FieldDefinition( Date | None, Field(..., description="Date or None", title=orm_field.name), - lambda v: v.to_django_orm() if v is not None else None, + lambda v: v.to_django_orm() if v else None, lambda v: Date.from_django_orm(v) if v is not None else None, ) case "single_select": @@ -383,11 +391,8 @@ def get_view(user, view_id: int): def get_table_rows_tools( - user, workspace: Workspace, tool_helpers: "ToolHelpers", table: Table + user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers", table: Table ): - import dspy # local import to save memory when not used - from dspy.adapters.types.tool import _resolve_json_schema_reference - 
row_model_for_create = get_create_row_model(table) row_model_for_update = get_update_row_model(table) row_model_for_response = create_model( @@ -421,13 +426,13 @@ def _create_rows( return {"created_row_ids": [r.id for r in orm_rows]} - create_row_model_schema = _resolve_json_schema_reference( - row_model_for_create.model_json_schema() + create_row_model_schema = minimize_schema( + resolve_json_schema_reference(row_model_for_create.model_json_schema()) ) - create_rows_tool = dspy.Tool( + create_rows_tool = udspy.Tool( func=_create_rows, name=f"create_rows_in_table_{table.id}", - desc=f"Creates new rows in the table {table.name} (ID: {table.id}). Max 20 rows at a time.", + description=f"Creates new rows in the table {table.name} (ID: {table.id}). Max 20 rows at a time.", args={ "rows": { "items": create_row_model_schema, @@ -462,13 +467,13 @@ def _update_rows( return {"updated_row_ids": [r.id for r in orm_rows]} - update_row_model_schema = _resolve_json_schema_reference( - row_model_for_update.model_json_schema() + update_row_model_schema = minimize_schema( + resolve_json_schema_reference(row_model_for_update.model_json_schema()) ) - update_rows_tool = dspy.Tool( + update_rows_tool = udspy.Tool( func=_update_rows, name=f"update_rows_in_table_{table.id}_by_row_ids", - desc=f"Updates existing rows in the table {table.name} (ID: {table.id}), identified by their row IDs. Max 20 at a time.", + description=f"Updates existing rows in the table {table.name} (ID: {table.id}), identified by their row IDs. Max 20 at a time.", args={ "rows": { "items": update_row_model_schema, @@ -497,10 +502,10 @@ def _delete_rows(row_ids: list[int]) -> str: return {"deleted_row_ids": row_ids} - delete_rows_tool = dspy.Tool( + delete_rows_tool = udspy.Tool( func=_delete_rows, name=f"delete_rows_in_table_{table.id}_by_row_ids", - desc=f"Deletes rows in the table {table.name} (ID: {table.id}). Max 20 at a time.", + description=f"Deletes rows in the table {table.name} (ID: {table.id}). 
Max 20 at a time.", args={ "row_ids": { "items": {"type": "integer"}, @@ -515,3 +520,35 @@ def _delete_rows(row_ids: list[int]) -> str: "update": update_rows_tool, "delete": delete_rows_tool, } + + +def create_view_filter( + user: AbstractUser, + orm_view: View, + table_fields: list[Field], + view_filter_item: AnyViewFilterItemCreate, +) -> ViewFilter: + """ + Creates a view filter from the given view filter item. + """ + + field = table_fields.get(view_filter_item.field_id) + if field is None: + raise ValueError("Field not found for filter") + field_type = field_type_registry.get_by_model(field.specific_class) + if field_type.type != view_filter_item.type: + raise ValueError("Field type mismatch for filter") + + filter_type = view_filter_item.get_django_orm_type(field) + filter_value = view_filter_item.get_django_orm_value( + field, timezone=user.profile.timezone + ) + + return CreateViewFilterActionType.do( + user, + orm_view, + field, + filter_type, + filter_value, + filter_group_id=None, + ) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py index 52c4de4d73..ee2fd6c157 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py @@ -10,10 +10,7 @@ def unsafe_navigate_to(location: AnyNavigationType) -> str: :param navigation_type: The type of navigation to perform. """ - from dspy.dsp.utils.settings import settings - from dspy.streaming.messages import sync_send_to_stream + from udspy.streaming import emit_event - stream = settings.send_stream - if stream is not None: - sync_send_to_stream(stream, AiNavigationMessage(location=location)) + emit_event(AiNavigationMessage(location=location)) return "Navigated successfully." 
diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py index 18c9f79396..df3a7dea16 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py @@ -39,7 +39,7 @@ def on_tool_start( Called when the tool is started. It can be used to stream status messages. :param call_id: The unique identifier of the tool call. - :param instance: The instance of the dspy tool being called. + :param instance: The instance of the udspy tool being called. :param inputs: The inputs provided to the tool. """ @@ -58,7 +58,7 @@ def on_tool_end( Called when the tool has finished, either successfully or with an exception. :param call_id: The unique identifier of the tool call. - :param instance: The instance of the dspy tool being called. + :param instance: The instance of the udspy tool being called. :param inputs: The inputs provided to the tool. :param outputs: The outputs returned by the tool, or None if there was an exception. @@ -73,7 +73,7 @@ def get_tool( cls, user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" ) -> Callable[[Any], Any]: """ - Returns the actual tool function to be called to pass to the dspy react agent. + Returns the actual tool function to be called to pass to the udspy react agent. :param user: The user that will be using the tool. :param workspace: The workspace the user is currently in. 
diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/handler.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/handler.py index 874f2d7fc6..9ac201b260 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/handler.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/handler.py @@ -66,12 +66,8 @@ def __init__(self, embedder=None): @property def embedder(self): - import dspy # local import to save memory when not used - if self._embedder is None: - self._embedder = dspy.Embedder( - BaserowEmbedder(settings.BASEROW_EMBEDDINGS_API_URL) - ) + self._embedder = BaserowEmbedder(settings.BASEROW_EMBEDDINGS_API_URL) return self._embedder def embed_texts(self, texts: list[str]) -> list[list[float]]: @@ -86,7 +82,7 @@ def embed_texts(self, texts: list[str]) -> list[list[float]]: return [] embedder = self.embedder - # Support both dspy.Embedder (callable) and LangChain-style embedders + # Support both embedders as callables and LangChain-style embedders if callable(embedder): return embedder(texts) else: diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/tools.py index d0c540fec8..f87b57a81e 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/tools.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/tools.py @@ -1,8 +1,10 @@ -from typing import TYPE_CHECKING, Any, Callable, TypedDict +from typing import TYPE_CHECKING, Annotated, Any, Callable from django.contrib.auth.models import AbstractUser from django.utils.translation import gettext as _ +import udspy + from baserow.core.models import Workspace from baserow_enterprise.assistant.tools.registries import AssistantToolType @@ -14,65 +16,74 @@ MAX_SOURCES = 3 -def get_search_predictor(): - import dspy # local import to save memory when not used +class 
SearchDocsSignature(udspy.Signature): + """ + Search the Baserow documentation for relevant information to answer user questions. + Never fabricate answers or URLs. Always copy instructions exactly as they appear in + the documentation, without rephrasing. + """ - class SearchDocsSignature(dspy.Signature): - question: str = dspy.InputField() - context: list[str] = dspy.InputField() - response: str = dspy.OutputField() - sources: list[str] = dspy.OutputField( - desc=f"List of unique and relevant source URLs. Max {MAX_SOURCES}." + question: str = udspy.InputField() + context: list[str] = udspy.InputField() + response: str = udspy.OutputField() + sources: list[str] = udspy.OutputField( + desc=f"List of unique and relevant source URLs. Max {MAX_SOURCES}." + ) + reliability: float = udspy.OutputField( + desc=( + "The reliability score of the response, from 0 to 1. " + "1 means the answer is fully supported by the provided context. " + "0 means the answer is not supported by the provided context." ) + ) - return dspy.ChainOfThought(SearchDocsSignature) +class SearchDocsRAG(udspy.Module): + def __init__(self): + self.rag = udspy.ChainOfThought(SearchDocsSignature) -class SearchDocsToolOutput(TypedDict): - response: str - sources: list[str] + def forward(self, question: str, *args, **kwargs): + context = KnowledgeBaseHandler().search(question, num_results=10) + return self.rag(context=context, question=question) def get_search_docs_tool( user: AbstractUser, workspace: Workspace, tool_helpers: "ToolHelpers" -) -> Callable[[str], SearchDocsToolOutput]: +) -> Callable[[str], dict[str, Any]]: """ Returns a function that searches the Baserow documentation for a given query. """ - def search_docs(query: str) -> SearchDocsToolOutput: + def search_docs( + question: Annotated[ + str, "The English version of the user question, using Baserow vocabulary." + ] + ) -> dict[str, Any]: """ - Search Baserow documentation. + Search Baserow documentation for relevant information. 
Make sure the question + is in English and uses Baserow-specific terminology to get the best results. """ - import dspy # local import to save memory when not used - nonlocal tool_helpers tool_helpers.update_status(_("Exploring the knowledge base...")) - class SearchDocsRAG(dspy.Module): - def __init__(self): - self.respond = get_search_predictor() - - def forward(self, question): - context = KnowledgeBaseHandler().search(question, num_results=10) - return self.respond(context=context, question=question) - - tool = SearchDocsRAG() - result = tool(query) - - sources = [] - for source in result.sources: - if source not in sources: - sources.append(source) - if len(sources) >= MAX_SOURCES: - break - - return SearchDocsToolOutput( - response=result.response, - sources=sources, - ) + search_tool = SearchDocsRAG() + answer = search_tool(question=question) + # Somehow sources can be objects with an "url" attribute instead of strings, + # let's fix that + fixed_sources = [] + for src in answer.sources[:MAX_SOURCES]: + if isinstance(src, str): + fixed_sources.append(src) + elif isinstance(src, dict) and "url" in src: + fixed_sources.append(src["url"]) + + return { + "response": answer.response, + "sources": fixed_sources, + "reliability": answer.reliability, + } return search_docs diff --git a/enterprise/backend/src/baserow_enterprise/assistant/types.py b/enterprise/backend/src/baserow_enterprise/assistant/types.py index 226730b961..bc99cd30c7 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/types.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/types.py @@ -1,9 +1,10 @@ from datetime import datetime, timezone from enum import StrEnum -from typing import Annotated, Any, Callable, Literal, Optional +from typing import Annotated, Literal, Optional from django.utils.translation import gettext as _ +import udspy from pydantic import BaseModel as PydanticBaseModel from pydantic import ConfigDict, Field @@ -93,7 +94,8 @@ def from_validate_request(cls, 
request, ui_context_data) -> "UIContext": class AssistantMessageType(StrEnum): HUMAN = "human" AI_MESSAGE = "ai/message" - AI_THINKING = "ai/thinking" + AI_THINKING = "ai/thinking" # Update the status bar in the UI + AI_REASONING = "ai/reasoning" # Show reasoning steps before the final answer AI_NAVIGATION = "ai/navigation" AI_ERROR = "ai/error" TOOL_CALL = "tool_call" @@ -123,6 +125,11 @@ class AiMessageChunk(BaseModel): ) +class AiReasoningChunk(BaseModel): + type: Literal["ai/reasoning"] = "ai/reasoning" + content: str = Field(description="The reasoning content of the AI message chunk") + + class AiMessage(AiMessageChunk): id: int | None = Field( default=None, @@ -145,7 +152,7 @@ class AiMessage(AiMessageChunk): ) -class AiThinkingMessage(BaseModel): +class AiThinkingMessage(BaseModel, udspy.StreamEvent): type: Literal["ai/thinking"] = AssistantMessageType.AI_THINKING.value content: str = Field( default="", @@ -176,7 +183,12 @@ class AiErrorMessage(BaseModel): AIMessageUnion = ( - AiMessage | AiErrorMessage | AiThinkingMessage | ChatTitleMessage | AiMessageChunk + AiMessage + | AiErrorMessage + | AiThinkingMessage + | ChatTitleMessage + | AiMessageChunk + | AiReasoningChunk ) AssistantMessageUnion = HumanMessage | AIMessageUnion @@ -229,24 +241,6 @@ def to_localized_string(self): ] -class AiNavigationMessage(BaseModel): +class AiNavigationMessage(BaseModel, udspy.StreamEvent): type: Literal["ai/navigation"] = "ai/navigation" location: AnyNavigationType - - -class ToolsUpgradeResponse(BaseModel): - observation: str - new_tools: list[Callable[[Any], Any]] - - -def get_tool_signature(): - import dspy # local import to save memory when not used - - class ToolSignature(dspy.Signature): - """Signature for manual tool handling.""" - - question: str = dspy.InputField() - tools: list[dspy.Tool] = dspy.InputField() - outputs: dspy.ToolCalls = dspy.OutputField() - - return ToolSignature diff --git 
a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py index 4747b82e68..462d5746ca 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py @@ -12,10 +12,9 @@ import pytest from asgiref.sync import async_to_sync -from dspy.primitives.prediction import Prediction -from dspy.streaming import StreamResponse +from udspy import OutputStreamChunk, Prediction -from baserow_enterprise.assistant.assistant import Assistant, get_assistant_callbacks +from baserow_enterprise.assistant.assistant import Assistant, AssistantCallbacks from baserow_enterprise.assistant.models import AssistantChat, AssistantChatMessage from baserow_enterprise.assistant.types import ( AiMessageChunk, @@ -38,7 +37,7 @@ class TestAssistantCallbacks: def test_extend_sources_deduplicates(self): """Test that sources are deduplicated when extended""" - callbacks = get_assistant_callbacks() + callbacks = AssistantCallbacks() # Add initial sources callbacks.extend_sources( @@ -64,7 +63,7 @@ def test_extend_sources_deduplicates(self): def test_extend_sources_preserves_order(self): """Test that source order is preserved (first occurrence wins)""" - callbacks = get_assistant_callbacks() + callbacks = AssistantCallbacks() callbacks.extend_sources(["https://example.com/a"]) callbacks.extend_sources(["https://example.com/b"]) @@ -76,7 +75,7 @@ def test_extend_sources_preserves_order(self): def test_on_tool_end_extracts_sources_from_outputs(self): """Test that sources are extracted from tool outputs""" - callbacks = get_assistant_callbacks() + callbacks = AssistantCallbacks() # Mock tool instance and inputs tool_instance = MagicMock() @@ -107,7 +106,7 @@ def test_on_tool_end_extracts_sources_from_outputs(self): def test_on_tool_end_handles_missing_sources(self): """Test that tool outputs without 
sources don't cause errors""" - callbacks = get_assistant_callbacks() + callbacks = AssistantCallbacks() tool_instance = MagicMock() tool_instance.name = "some_tool" @@ -171,7 +170,7 @@ def test_list_chat_messages_returns_in_chronological_order( def test_aload_chat_history_formats_as_question_answer_pairs( self, enterprise_data_fixture ): - """Test that chat history is loaded as question/answer pairs for DSPy""" + """Test that chat history is loaded as user/assistant message pairs for UDSPy""" user = enterprise_data_fixture.create_user() workspace = enterprise_data_fixture.create_workspace(user=user) @@ -202,23 +201,27 @@ def test_aload_chat_history_formats_as_question_answer_pairs( assistant = Assistant(chat) async_to_sync(assistant.aload_chat_history)() - # History should contain question/answer pairs + # History should contain user/assistant message pairs assert assistant.history is not None - assert len(assistant.history.messages) == 2 + assert len(assistant.history.messages) == 4 # First pair - assert assistant.history.messages[0]["question"] == "What is Baserow?" + assert assistant.history.messages[0]["content"] == "What is Baserow?" + assert assistant.history.messages[0]["role"] == "user" assert ( - assistant.history.messages[0]["answer"] + assistant.history.messages[1]["content"] == "Baserow is a no-code database platform." ) + assert assistant.history.messages[1]["role"] == "assistant" # Second pair - assert assistant.history.messages[1]["question"] == "How do I create a table?" + assert assistant.history.messages[2]["content"] == "How do I create a table?" + assert assistant.history.messages[2]["role"] == "user" assert ( - assistant.history.messages[1]["answer"] + assistant.history.messages[3]["content"] == "You can create a table by clicking the + button." 
) + assert assistant.history.messages[3]["role"] == "assistant" def test_aload_chat_history_respects_limit(self, enterprise_data_fixture): """Test that history loading respects the limit parameter""" @@ -241,12 +244,10 @@ def test_aload_chat_history_respects_limit(self, enterprise_data_fixture): ) assistant = Assistant(chat) - async_to_sync(assistant.aload_chat_history)( - limit=6 - ) # Last 6 messages = 3 pairs + async_to_sync(assistant.aload_chat_history)(limit=6) # Last 6 messages - # Should only load the most recent 3 pairs - assert len(assistant.history.messages) <= 3 + # Should only load the most recent 6 messages (3 pairs) + assert len(assistant.history.messages) == 6 def test_aload_chat_history_handles_incomplete_pairs(self, enterprise_data_fixture): """ @@ -275,19 +276,22 @@ def test_aload_chat_history_handles_incomplete_pairs(self, enterprise_data_fixtu assistant = Assistant(chat) async_to_sync(assistant.aload_chat_history)() - # Should only include the complete pair - assert len(assistant.history.messages) == 1 - assert assistant.history.messages[0]["question"] == "Question 1" + # Should only include the complete pair (2 messages: user + assistant) + assert len(assistant.history.messages) == 2 + assert assistant.history.messages[0]["content"] == "Question 1" + assert assistant.history.messages[0]["role"] == "user" + assert assistant.history.messages[1]["content"] == "Answer 1" + assert assistant.history.messages[1]["role"] == "assistant" @pytest.mark.django_db class TestAssistantMessagePersistence: """Test that messages are persisted correctly during streaming""" - @patch("dspy.streamify") - @patch("dspy.LM") + @patch("udspy.ReAct.astream") + @patch("udspy.LM") def test_astream_messages_persists_human_message( - self, mock_lm, mock_streamify, enterprise_data_fixture + self, mock_lm, mock_astream, enterprise_data_fixture ): """Test that human messages are persisted to database before streaming""" @@ -300,15 +304,16 @@ def 
test_astream_messages_persists_human_message( # Mock the streaming async def mock_stream(*args, **kwargs): # Yield a simple response - yield StreamResponse( - signature_field_name="answer", - chunk="Hello", - predict_name="ReAct", - is_last_chunk=False, + yield OutputStreamChunk( + module=None, + field_name="answer", + delta="Hello", + content="Hello", + is_complete=False, ) yield Prediction(answer="Hello", trajectory=[], reasoning="") - mock_streamify.return_value = MagicMock(return_value=mock_stream()) + mock_astream.return_value = mock_stream() # Configure mock LM to return a serializable model name mock_lm.return_value.model = "test-model" @@ -338,10 +343,10 @@ async def consume_stream(): ).first() assert saved_message.content == "Test message" - @patch("dspy.streamify") - @patch("dspy.LM") + @patch("udspy.ReAct.astream") + @patch("udspy.LM") def test_astream_messages_persists_ai_message_with_sources( - self, mock_lm, mock_streamify, enterprise_data_fixture + self, mock_lm, mock_astream, enterprise_data_fixture ): """Test that AI messages are persisted with sources in artifacts""" @@ -351,22 +356,28 @@ def test_astream_messages_persists_ai_message_with_sources( user=user, workspace=workspace, title="Test Chat" ) - # Mock the streaming with a Prediction at the end - async def mock_stream(*args, **kwargs): - yield StreamResponse( - signature_field_name="answer", - chunk="Based on docs", - predict_name="ReAct", - is_last_chunk=False, - ) - yield Prediction(answer="Based on docs", trajectory=[], reasoning="") - - mock_streamify.return_value = MagicMock(return_value=mock_stream()) - # Configure mock LM to return a serializable model name mock_lm.return_value.model = "test-model" assistant = Assistant(chat) + + # Mock the streaming with a Prediction at the end + async def mock_stream(*args, **kwargs): + yield OutputStreamChunk( + module=None, + field_name="answer", + delta="Based on docs", + content="Based on docs", + is_complete=False, + ) + yield Prediction( + 
module=assistant._assistant, + answer="Based on docs", + trajectory=[], + reasoning="", + ) + + mock_astream.return_value = mock_stream() ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), user=UserUIContext(id=user.id, name=user.first_name, email=user.email), @@ -388,12 +399,12 @@ async def consume_stream(): ).count() assert ai_messages == 1 - @patch("dspy.streamify") - @patch("dspy.Predict") + @patch("udspy.ReAct.astream") + @patch("udspy.Predict") def test_astream_messages_persists_chat_title( self, mock_predict_class, - mock_streamify, + mock_astream, enterprise_data_fixture, ): """Test that chat titles are persisted to the database""" @@ -404,27 +415,30 @@ def test_astream_messages_persists_chat_title( user=user, workspace=workspace, title="" # New chat ) - # Mock streaming - async def mock_stream(*args, **kwargs): - yield StreamResponse( - signature_field_name="answer", - chunk="Hello", - predict_name="ReAct", - is_last_chunk=False, - ) - yield Prediction(answer="Hello", trajectory=[], reasoning="") - - mock_streamify.return_value = MagicMock(return_value=mock_stream()) - # Mock title generator - async def mock_title_acall(*args, **kwargs): + async def mock_title_aforward(*args, **kwargs): return Prediction(chat_title="Greeting") mock_title_generator = MagicMock() - mock_title_generator.acall = mock_title_acall + mock_title_generator.aforward = mock_title_aforward mock_predict_class.return_value = mock_title_generator assistant = Assistant(chat) + + # Mock streaming + async def mock_stream(*args, **kwargs): + yield OutputStreamChunk( + module=None, + field_name="answer", + delta="Hello", + content="Hello", + is_complete=False, + ) + yield Prediction( + module=assistant._assistant, answer="Hello", trajectory=[], reasoning="" + ) + + mock_astream.return_value = mock_stream() ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), user=UserUIContext(id=user.id, name=user.first_name, 
email=user.email), @@ -449,10 +463,10 @@ async def consume_stream(): class TestAssistantStreaming: """Test streaming behavior of the Assistant""" - @patch("dspy.streamify") - @patch("dspy.LM") + @patch("udspy.ReAct.astream") + @patch("udspy.LM") def test_astream_messages_yields_answer_chunks( - self, mock_lm, mock_streamify, enterprise_data_fixture + self, mock_lm, mock_astream, enterprise_data_fixture ): """Test that answer chunks are yielded during streaming""" @@ -464,21 +478,23 @@ def test_astream_messages_yields_answer_chunks( # Mock streaming async def mock_stream(*args, **kwargs): - yield StreamResponse( - signature_field_name="answer", - chunk="Hello", - predict_name="ReAct", - is_last_chunk=False, + yield OutputStreamChunk( + module=None, + field_name="answer", + delta="Hello", + content="Hello", + is_complete=False, ) - yield StreamResponse( - signature_field_name="answer", - chunk=" world", - predict_name="ReAct", - is_last_chunk=False, + yield OutputStreamChunk( + module=None, + field_name="answer", + delta=" world", + content="Hello world", + is_complete=False, ) yield Prediction(answer="Hello world", trajectory=[], reasoning="") - mock_streamify.return_value = MagicMock(return_value=mock_stream()) + mock_astream.return_value = mock_stream() # Configure mock LM to return a serializable model name mock_lm.return_value.model = "test-model" @@ -500,17 +516,16 @@ async def consume_stream(): chunks = async_to_sync(consume_stream)() # Should receive chunks with accumulated content - assert len(chunks) == 3 + assert len(chunks) == 2 assert chunks[0].content == "Hello" assert chunks[1].content == "Hello world" - assert chunks[2].content == "Hello world" # Final chunk repeats full answer - @patch("dspy.streamify") - @patch("dspy.Predict") + @patch("udspy.ReAct.astream") + @patch("udspy.Predict") def test_astream_messages_yields_title_chunks( self, mock_predict_class, - mock_streamify, + mock_astream, enterprise_data_fixture, ): """Test that title chunks are 
yielded for new chats""" @@ -521,27 +536,33 @@ def test_astream_messages_yields_title_chunks( user=user, workspace=workspace, title="" # New chat ) - # Mock streaming - async def mock_stream(*args, **kwargs): - yield StreamResponse( - signature_field_name="answer", - chunk="Answer", - predict_name="ReAct", - is_last_chunk=False, - ) - yield Prediction(answer="Answer", trajectory=[], reasoning="") - - mock_streamify.return_value = MagicMock(return_value=mock_stream()) - # Mock title generator - async def mock_title_acall(*args, **kwargs): + async def mock_title_aforward(*args, **kwargs): return Prediction(chat_title="Title") mock_title_generator = MagicMock() - mock_title_generator.acall = mock_title_acall + mock_title_generator.aforward = mock_title_aforward mock_predict_class.return_value = mock_title_generator assistant = Assistant(chat) + + # Mock streaming + async def mock_stream(*args, **kwargs): + yield OutputStreamChunk( + module=None, + field_name="answer", + delta="Answer", + content="Answer", + is_complete=False, + ) + yield Prediction( + module=assistant._assistant, + answer="Answer", + trajectory=[], + reasoning="", + ) + + mock_astream.return_value = mock_stream() ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), user=UserUIContext(id=user.id, name=user.first_name, email=user.email), @@ -561,10 +582,10 @@ async def consume_stream(): assert len(title_messages) == 1 assert title_messages[0].content == "Title" - @patch("dspy.streamify") - @patch("dspy.LM") + @patch("udspy.ReAct.astream") + @patch("udspy.LM") def test_astream_messages_yields_thinking_messages( - self, mock_lm, mock_streamify, enterprise_data_fixture + self, mock_lm, mock_astream, enterprise_data_fixture ): """Test that thinking messages from tools are yielded""" @@ -577,15 +598,16 @@ def test_astream_messages_yields_thinking_messages( # Mock streaming async def mock_stream(*args, **kwargs): yield AiThinkingMessage(content="thinking") - yield 
StreamResponse( - signature_field_name="answer", - chunk="Answer", - predict_name="ReAct", - is_last_chunk=False, + yield OutputStreamChunk( + module=None, + field_name="answer", + delta="Answer", + content="Answer", + is_complete=False, ) yield Prediction(answer="Answer", trajectory=[], reasoning="") - mock_streamify.return_value = MagicMock(return_value=mock_stream()) + mock_astream.return_value = mock_stream() # Configure mock LM to return a serializable model name mock_lm.return_value.model = "test-model" diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py index 7f4c3c21ba..4b83e36489 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py @@ -1,11 +1,13 @@ +from unittest.mock import Mock + import pytest +from udspy.module.callbacks import ModuleContext, is_module_callback from baserow.contrib.database.rows.handler import RowHandler from baserow_enterprise.assistant.tools.database.tools import ( get_list_rows_tool, get_rows_meta_tool, ) -from baserow_enterprise.assistant.types import ToolsUpgradeResponse from .utils import fake_tool_helpers @@ -213,19 +215,18 @@ def test_create_rows(data_fixture): assert callable(meta_tool) tools_upgrade = meta_tool([table.id], ["create"]) - assert isinstance(tools_upgrade, ToolsUpgradeResponse) - assert f"list_rows_in_table_{table.id}" not in tools_upgrade.observation - assert f"create_rows_in_table_{table.id}" in tools_upgrade.observation - assert ( - f"update_rows_in_table_{table.id}_by_row_ids" not in tools_upgrade.observation - ) - assert ( - f"delete_rows_in_table_{table.id}_by_row_ids" not in tools_upgrade.observation - ) - assert len(tools_upgrade.new_tools) == 1 + assert is_module_callback(tools_upgrade) - 
create_table_rows = tools_upgrade.new_tools[0] - assert create_table_rows.name == f"create_rows_in_table_{table.id}" + mock_module = Mock() + mock_module._tools = [] + mock_module.init_module = Mock() + tools_upgrade(ModuleContext(module=mock_module)) + assert mock_module.init_module.called + + added_tools = mock_module.init_module.call_args[1]["tools"] + added_tools_names = [tool.name for tool in added_tools] + assert len(added_tools) == 1 + assert f"create_rows_in_table_{table.id}" in added_tools_names table_model = table.get_model() assert table_model.objects.count() == 3 @@ -260,6 +261,7 @@ def test_create_rows(data_fixture): "Single link to B": None, "link": [], } + create_table_rows = added_tools[0] result = create_table_rows(rows=[row_1, row_2]) created_row_ids = result["created_row_ids"] assert len(created_row_ids) == 2 @@ -277,19 +279,19 @@ def test_update_rows(data_fixture): meta_tool = get_rows_meta_tool(user, workspace, tool_helpers) assert callable(meta_tool) - tools_upgrade = meta_tool([table.id], ["update"]) - assert isinstance(tools_upgrade, ToolsUpgradeResponse) - assert f"list_rows_in_table_{table.id}" not in tools_upgrade.observation - assert f"create_rows_in_table_{table.id}" not in tools_upgrade.observation - assert f"update_rows_in_table_{table.id}_by_row_ids" in tools_upgrade.observation - assert ( - f"delete_rows_in_table_{table.id}_by_row_ids" not in tools_upgrade.observation - ) - assert len(tools_upgrade.new_tools) == 1 + assert is_module_callback(tools_upgrade) + + mock_module = Mock() + mock_module._tools = [] + mock_module.init_module = Mock() + tools_upgrade(ModuleContext(module=mock_module)) + assert mock_module.init_module.called - update_table_rows = tools_upgrade.new_tools[0] - assert update_table_rows.name == f"update_rows_in_table_{table.id}_by_row_ids" + added_tools = mock_module.init_module.call_args[1]["tools"] + added_tools_names = [tool.name for tool in added_tools] + assert len(added_tools) == 1 + assert 
f"update_rows_in_table_{table.id}_by_row_ids" in added_tools_names table_model = table.get_model() assert table_model.objects.count() == 3 @@ -323,6 +325,7 @@ def test_update_rows(data_fixture): "link": "__NO_CHANGE__", } + update_table_rows = added_tools[0] result = update_table_rows(rows=[row_1_updates, row_2_updates]) updated_row_ids = result["updated_row_ids"] assert len(updated_row_ids) == 2 @@ -372,17 +375,17 @@ def test_delete_rows(data_fixture): assert callable(meta_tool) tools_upgrade = meta_tool([table.id], ["delete"]) - assert isinstance(tools_upgrade, ToolsUpgradeResponse) - assert f"list_rows_in_table_{table.id}" not in tools_upgrade.observation - assert f"create_rows_in_table_{table.id}" not in tools_upgrade.observation - assert ( - f"update_rows_in_table_{table.id}_by_row_ids" not in tools_upgrade.observation - ) - assert f"delete_rows_in_table_{table.id}_by_row_ids" in tools_upgrade.observation - assert len(tools_upgrade.new_tools) == 1 - - delete_table_rows = tools_upgrade.new_tools[0] - assert delete_table_rows.name == f"delete_rows_in_table_{table.id}_by_row_ids" + assert is_module_callback(tools_upgrade) + mock_module = Mock() + mock_module._tools = [] + mock_module.init_module = Mock() + tools_upgrade(ModuleContext(module=mock_module)) + assert mock_module.init_module.called + added_tools = mock_module.init_module.call_args[1]["tools"] + added_tools_names = [tool.name for tool in added_tools] + assert len(added_tools) == 1 + assert f"delete_rows_in_table_{table.id}_by_row_ids" in added_tools_names + delete_table_rows = added_tools[0] table_model = table.get_model() assert table_model.objects.count() == 3 diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py index e7869280fe..d065bdefac 100644 --- 
a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py @@ -271,8 +271,6 @@ def test_create_complex_table_tool(data_fixture): type="link_row", name="Related Items", linked_table=table.id, - has_link_back=False, - multiple=True, ), RatingFieldItemCreate( type="rating", @@ -340,7 +338,7 @@ def test_generate_database_formula_no_save(data_fixture): table = data_fixture.create_database_table(database=database, name="Test Table") data_fixture.create_text_field(table=table, name="text_field", primary=True) - # Mock the dspy.ReAct to return a valid formula + # Mock the udspy.ReAct to return a valid formula mock_prediction = MagicMock() mock_prediction.is_formula_valid = True mock_prediction.formula = "'ok'" @@ -349,7 +347,7 @@ def test_generate_database_formula_no_save(data_fixture): mock_prediction.table_id = table.id mock_prediction.error_message = "" - with patch("dspy.ReAct") as mock_react: + with patch("udspy.ReAct") as mock_react: mock_react.return_value.return_value = mock_prediction tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) @@ -377,7 +375,7 @@ def test_generate_database_formula_create_new_field(data_fixture): table = data_fixture.create_database_table(database=database, name="Test Table") data_fixture.create_text_field(table=table, name="text_field", primary=True) - # Mock the dspy.ReAct to return a valid formula + # Mock the udspy.ReAct to return a valid formula mock_prediction = MagicMock() mock_prediction.is_formula_valid = True mock_prediction.formula = "'ok'" @@ -386,7 +384,7 @@ def test_generate_database_formula_create_new_field(data_fixture): mock_prediction.table_id = table.id mock_prediction.error_message = "" - with patch("dspy.ReAct") as mock_react: + with patch("udspy.ReAct") as mock_react: mock_react.return_value.return_value = mock_prediction tool = 
get_generate_database_formula_tool(user, workspace, fake_tool_helpers) @@ -427,7 +425,7 @@ def test_generate_database_formula_update_existing_formula_field(data_fixture): ) existing_field_id = existing_field.id - # Mock the dspy.ReAct to return a new formula + # Mock the udspy.ReAct to return a new formula mock_prediction = MagicMock() mock_prediction.is_formula_valid = True mock_prediction.formula = "'new'" @@ -436,7 +434,7 @@ def test_generate_database_formula_update_existing_formula_field(data_fixture): mock_prediction.table_id = table.id mock_prediction.error_message = "" - with patch("dspy.ReAct") as mock_react: + with patch("udspy.ReAct") as mock_react: mock_react.return_value.return_value = mock_prediction tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) @@ -477,7 +475,7 @@ def test_generate_database_formula_replace_non_formula_field(data_fixture): ) existing_field_id = existing_text_field.id - # Mock the dspy.ReAct to return a valid formula + # Mock the udspy.ReAct to return a valid formula mock_prediction = MagicMock() mock_prediction.is_formula_valid = True mock_prediction.formula = "'ok'" @@ -486,7 +484,7 @@ def test_generate_database_formula_replace_non_formula_field(data_fixture): mock_prediction.table_id = table.id mock_prediction.error_message = "" - with patch("dspy.ReAct") as mock_react: + with patch("udspy.ReAct") as mock_react: mock_react.return_value.return_value = mock_prediction tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) @@ -527,7 +525,7 @@ def test_generate_database_formula_invalid_formula(data_fixture): table = data_fixture.create_database_table(database=database, name="Test Table") data_fixture.create_text_field(table=table, name="text_field", primary=True) - # Mock the dspy.ReAct to return an invalid formula + # Mock the udspy.ReAct to return an invalid formula mock_prediction = MagicMock() mock_prediction.is_formula_valid = False mock_prediction.formula = "" @@ -536,7 +534,7 
@@ def test_generate_database_formula_invalid_formula(data_fixture): mock_prediction.table_id = table.id mock_prediction.error_message = "Formula syntax error: invalid expression" - with patch("dspy.ReAct") as mock_react: + with patch("udspy.ReAct") as mock_react: mock_react.return_value.return_value = mock_prediction tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) @@ -566,7 +564,7 @@ def test_generate_database_formula_documentation_completeness(data_fixture): table = data_fixture.create_database_table(database=database, name="Test Table") data_fixture.create_text_field(table=table, name="text_field", primary=True) - # Mock the dspy.ReAct to capture the formula_documentation argument + # Mock the udspy.ReAct to capture the formula_documentation argument mock_prediction = MagicMock() mock_prediction.is_formula_valid = True mock_prediction.formula = "'ok'" @@ -588,7 +586,7 @@ def __call__(self, **kwargs): captured_formula_docs = kwargs.get("formula_documentation") return mock_prediction - with patch("dspy.ReAct", MockReAct): + with patch("udspy.ReAct", MockReAct): tool = get_generate_database_formula_tool(user, workspace, fake_tool_helpers) tool( database_id=database.id, diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py index 16128971cb..25ec2b3a16 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py @@ -33,6 +33,9 @@ TimelineViewItemCreate, ) from baserow_enterprise.assistant.tools.database.types.base import Date +from baserow_enterprise.assistant.tools.database.types.view_filters import ( + ViewFiltersArgs, +) from .utils import fake_tool_helpers @@ -248,16 +251,21 @@ def 
test_create_text_equal_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - TextEqualViewFilterItemCreate( - field_id=field.id, type="text", operator="equal", value="test" + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + TextEqualViewFilterItemCreate( + field_id=field.id, type="text", operator="equal", value="test" + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 - assert response["created_view_filters"][0]["operator"] == "equal" + assert len(response["created_view_filters"][0]["filters"]) == 1 + assert response["created_view_filters"][0]["filters"][0]["operator"] == "equal" assert ViewFilter.objects.filter(view=view, field=field, type="equal").exists() @@ -272,12 +280,19 @@ def test_create_text_not_equal_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - TextNotEqualViewFilterItemCreate( - field_id=field.id, type="text", operator="not_equal", value="test" + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + TextNotEqualViewFilterItemCreate( + field_id=field.id, + type="text", + operator="not_equal", + value="test", + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -295,12 +310,19 @@ def test_create_text_contains_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - TextContainsViewFilterItemCreate( - field_id=field.id, type="text", operator="contains", value="test" + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + TextContainsViewFilterItemCreate( + field_id=field.id, + type="text", + operator="contains", + value="test", + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -318,12 +340,19 @@ def test_create_text_not_contains_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, 
fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - TextNotContainsViewFilterItemCreate( - field_id=field.id, type="text", operator="contains_not", value="test" + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + TextNotContainsViewFilterItemCreate( + field_id=field.id, + type="text", + operator="contains_not", + value="test", + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -344,12 +373,16 @@ def test_create_number_equal_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - NumberEqualsViewFilterItemCreate( - field_id=field.id, type="number", operator="equal", value=42.0 + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + NumberEqualsViewFilterItemCreate( + field_id=field.id, type="number", operator="equal", value=42.0 + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -367,12 +400,19 @@ def test_create_number_not_equal_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - NumberNotEqualsViewFilterItemCreate( - field_id=field.id, type="number", operator="not_equal", value=42.0 + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + NumberNotEqualsViewFilterItemCreate( + field_id=field.id, + type="number", + operator="not_equal", + value=42.0, + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -390,16 +430,20 @@ def test_create_number_higher_than_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - NumberHigherThanViewFilterItemCreate( - field_id=field.id, - type="number", - operator="higher_than", - value=10.0, - or_equal=False, + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + NumberHigherThanViewFilterItemCreate( + field_id=field.id, + type="number", + operator="higher_than", + value=10.0, + 
or_equal=False, + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -419,16 +463,20 @@ def test_create_number_lower_than_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - NumberLowerThanViewFilterItemCreate( - field_id=field.id, - type="number", - operator="lower_than", - value=100.0, - or_equal=False, + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + NumberLowerThanViewFilterItemCreate( + field_id=field.id, + type="number", + operator="lower_than", + value=100.0, + or_equal=False, + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -447,16 +495,20 @@ def test_create_date_equal_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - DateEqualsViewFilterItemCreate( - field_id=field.id, - type="date", - operator="equal", - value=Date(year=2024, month=1, day=15), - mode="exact_date", + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + DateEqualsViewFilterItemCreate( + field_id=field.id, + type="date", + operator="equal", + value=Date(year=2024, month=1, day=15), + mode="exact_date", + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -474,16 +526,20 @@ def test_create_date_not_equal_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - DateNotEqualsViewFilterItemCreate( - field_id=field.id, - type="date", - operator="not_equal", - value=None, - mode="today", + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + DateNotEqualsViewFilterItemCreate( + field_id=field.id, + type="date", + operator="not_equal", + value=None, + mode="today", + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -503,17 +559,21 @@ def test_create_date_after_filter(data_fixture): tool = get_create_view_filters_tool(user, 
workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - DateAfterViewFilterItemCreate( - field_id=field.id, - type="date", - operator="after", - value=7, - mode="nr_days_ago", - or_equal=False, + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + DateAfterViewFilterItemCreate( + field_id=field.id, + type="date", + operator="after", + value=7, + mode="nr_days_ago", + or_equal=False, + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -533,17 +593,21 @@ def test_create_date_before_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - DateBeforeViewFilterItemCreate( - field_id=field.id, - type="date", - operator="before", - value=None, - mode="tomorrow", - or_equal=True, + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + DateBeforeViewFilterItemCreate( + field_id=field.id, + type="date", + operator="before", + value=None, + mode="tomorrow", + or_equal=True, + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -566,15 +630,19 @@ def test_create_single_select_is_any_of_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - SingleSelectIsAnyViewFilterItemCreate( - field_id=field.id, - type="single_select", - operator="is_any_of", - value=["Option 1", "Option 2"], + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + SingleSelectIsAnyViewFilterItemCreate( + field_id=field.id, + type="single_select", + operator="is_any_of", + value=["Option 1", "Option 2"], + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -595,15 +663,19 @@ def test_create_single_select_is_none_of_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - SingleSelectIsNoneOfNotViewFilterItemCreate( - field_id=field.id, - type="single_select", 
- operator="is_none_of", - value=["Bad Option"], + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + SingleSelectIsNoneOfNotViewFilterItemCreate( + field_id=field.id, + type="single_select", + operator="is_none_of", + value=["Bad Option"], + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -624,12 +696,16 @@ def test_create_boolean_is_true_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - BooleanIsViewFilterItemCreate( - field_id=field.id, type="boolean", operator="is", value=True + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + BooleanIsViewFilterItemCreate( + field_id=field.id, type="boolean", operator="is", value=True + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -647,12 +723,16 @@ def test_create_boolean_is_false_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - BooleanIsViewFilterItemCreate( - field_id=field.id, type="boolean", operator="is", value=False + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + BooleanIsViewFilterItemCreate( + field_id=field.id, type="boolean", operator="is", value=False + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -673,15 +753,19 @@ def test_create_multiple_select_is_any_of_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - MultipleSelectIsAnyViewFilterItemCreate( - field_id=field.id, - type="multiple_select", - operator="is_any_of", - value=["Tag 1", "Tag 2"], + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + MultipleSelectIsAnyViewFilterItemCreate( + field_id=field.id, + type="multiple_select", + operator="is_any_of", + value=["Tag 1", "Tag 2"], + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 @@ -702,15 +786,19 @@ def 
test_create_multiple_select_is_none_of_filter(data_fixture): tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) response = tool( - view_id=view.id, - filters=[ - MultipleSelectIsNoneOfNotViewFilterItemCreate( - field_id=field.id, - type="multiple_select", - operator="is_none_of", - value=["Bad Tag"], + [ + ViewFiltersArgs( + view_id=view.id, + filters=[ + MultipleSelectIsNoneOfNotViewFilterItemCreate( + field_id=field.id, + type="multiple_select", + operator="is_none_of", + value=["Bad Tag"], + ) + ], ) - ], + ] ) assert len(response["created_view_filters"]) == 1 diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_lazy_loading.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_lazy_loading.py deleted file mode 100644 index 3fcfef4ccd..0000000000 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_lazy_loading.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Test that dspy is lazy-loaded only when the Assistant is actually used. - -This prevents unnecessary memory usage when the AI Assistant feature is not being used. -""" -import sys - -import pytest - - -@pytest.mark.django_db -class TestDspyLazyLoading: - """Verify that dspy is only loaded when Assistant is instantiated.""" - - def test_dspy_not_loaded_on_django_startup(self): - """ - Test that dspy is NOT loaded when Django starts up. - - This is critical for memory efficiency - dspy should only be loaded - when the AI Assistant feature is actually used. 
- """ - - # Remove dspy and litellm from sys.modules if already loaded - # (this can happen if other tests ran first) - modules_to_remove = [ - key - for key in sys.modules - if key.startswith("dspy") or key.startswith("litellm") - ] - for module in modules_to_remove: - del sys.modules[module] - - # Import the handler module (which is what gets imported at Django startup) - from baserow_enterprise.assistant import handler # noqa: F401 - - # Verify dspy and litellm are NOT loaded yet - assert "dspy" not in sys.modules, ( - "dspy should not be loaded on import. " - "Check for top-level dspy imports in assistant module files." - ) - - assert "litellm" not in sys.modules, ( - "litellm should not be loaded on import. " - "Check for top-level litellm imports in assistant module files." - ) - - def test_dspy_loaded_when_assistant_created( - self, data_fixture, enterprise_data_fixture - ): - """ - Test that dspy IS loaded when an Assistant object is created. - - This verifies that lazy loading works correctly and dspy is available - when needed. 
- """ - - # Remove dspy and litellm from sys.modules to start fresh - modules_to_remove = [ - key - for key in sys.modules - if key.startswith("dspy") or key.startswith("litellm") - ] - for module in modules_to_remove: - del sys.modules[module] - - # Create necessary fixtures - user = data_fixture.create_user() - workspace = data_fixture.create_workspace(user=user) - enterprise_data_fixture.enable_enterprise() - - # Import and use handler (should not load dspy yet) - from baserow_enterprise.assistant.handler import AssistantHandler - from baserow_enterprise.assistant.models import AssistantChat - - # Verify dspy and litellm are still not loaded - assert ( - "dspy" not in sys.modules - ), "dspy should not be loaded after importing handler" - - assert ( - "litellm" not in sys.modules - ), "litellm should not be loaded after importing handler" - - # Create a chat - chat = AssistantChat.objects.create( - user=user, - workspace=workspace, - ) - - # Create Assistant - this SHOULD trigger dspy loading - handler = AssistantHandler() - assistant = handler.get_assistant(chat) - - # Now dspy and litellm should be loaded - assert "dspy" in sys.modules, ( - "dspy should be loaded after creating Assistant instance. " - "Check that Assistant.__init__ imports dspy." - ) - - assert "litellm" in sys.modules, ( - "litellm should be loaded after creating Assistant instance. " - "Check that Assistant.__init__ imports dspy." - ) - - assert assistant is not None - - def test_assistant_handler_does_not_load_dspy(self, data_fixture): - """ - Test that using AssistantHandler methods (other than get_assistant) - does not load dspy. 
- """ - - # Remove dspy and litellm from sys.modules - modules_to_remove = [ - key - for key in sys.modules - if key.startswith("dspy") or key.startswith("litellm") - ] - for module in modules_to_remove: - del sys.modules[module] - - # Create fixtures - user = data_fixture.create_user() - workspace = data_fixture.create_workspace(user=user) - - from baserow_enterprise.assistant.handler import AssistantHandler - - handler = AssistantHandler() - - # These operations should not load dspy - chats = handler.list_chats(user, workspace.id) - assert chats is not None - - # Verify dspy and litellm are still not loaded - assert "dspy" not in sys.modules, ( - "dspy should not be loaded by AssistantHandler methods " - "(except get_assistant)" - ) - - assert "litellm" not in sys.modules, ( - "litellm should not be loaded by AssistantHandler methods " - "(except get_assistant)" - ) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/conftest.py b/enterprise/backend/tests/baserow_enterprise_tests/conftest.py index eef256cfa0..45b9b1f894 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/conftest.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/conftest.py @@ -1,6 +1,9 @@ +import os + from django.apps import apps from django.test.utils import override_settings +import pytest from baserow_premium_tests.conftest import * # noqa: F403, F401 from baserow.core.apps import sync_operations_after_migrate @@ -16,6 +19,22 @@ ) +@pytest.fixture(autouse=True) # noqa: F405 +def set_openai_api_key_env_var(): + """ + Set a dummy OpenAI API key for tests to prevent client instantiation errors. + + udspy.LM() creates an OpenAI client that raises an error if OPENAI_API_KEY is not + set during client instantiation. This fixture ensures tests don't fail due to + missing API key configuration, which is not needed anyway. + """ + + if not os.getenv("OPENAI_API_KEY"): + os.environ[ + "OPENAI_API_KEY" + ] = "Please, assistant don't crash. You don't need me." 
+ + @pytest.fixture # noqa: F405 def enterprise_data_fixture(fake, data_fixture): from .enterprise_fixtures import EnterpriseFixtures as EnterpriseFixturesBase diff --git a/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss index 45500abf8c..e545d89303 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss +++ b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss @@ -308,7 +308,7 @@ height: 8px; border-radius: 50%; background-color: $palette-neutral-800; - animation: typing-bounce 1.4s infinite ease-in-out; + animation: assistant-typing-bounce 1.4s infinite ease-in-out; &:nth-child(1) { animation-delay: -0.32s; @@ -320,7 +320,7 @@ } } -@keyframes typing-bounce { +@keyframes assistant-typing-bounce { 0%, 80%, 100% { @@ -334,6 +334,7 @@ } } + .assistant__message { display: flex; @@ -361,6 +362,11 @@ } } +.assistant__message-text-container { + display: flex; + gap: 12px; +} + .assistant__message-bubble { background-color: $white; border: 1px solid $palette-neutral-200; @@ -371,7 +377,10 @@ overflow-wrap: break-word; max-width: 100%; min-width: 0; - display: block; + + & span.loading { + margin-top: 4px; + } .assistant__message--human & { background-color: $palette-neutral-100; @@ -381,6 +390,12 @@ .assistant__message--error & { background-color: $palette-red-200; } + + .assistant__message--reasoning & { + color: #16829c; + background-color: $palette-cyan-50; + border-color: $palette-cyan-50; + } } .assistant__message-text { @@ -555,6 +570,14 @@ } } +.assistant__reasoning-indicator { + @include loading-spinner(#16829c, 14px); + + display: inline-block; + margin-right: 8px; + vertical-align: middle; +} + .assistant__chat-history-spacer { width: 300px; } diff --git a/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue 
b/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue index 62ca4beec7..12effaf6bc 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue +++ b/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue @@ -8,6 +8,7 @@ 'assistant__message--human': message.role === 'human', 'assistant__message--ai': message.role === 'ai', 'assistant__message--error': message.error, + 'assistant__message--reasoning': message.reasoning, }" >
@@ -21,12 +22,15 @@