From 037e09a91df13cf2a89e5974916154a61763b65f Mon Sep 17 00:00:00 2001 From: Bram Date: Mon, 20 Oct 2025 13:48:04 +0200 Subject: [PATCH 1/3] Fix e2e tests (#4095) --- .github/workflows/ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 454fd2ad51..c50a49b7fa 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -399,6 +399,7 @@ jobs: with: files: reports/report.xml check_name: Backend Tests (Group ${{ matrix.group }}) + comment_mode: off test-frontend: name: Web-Frontend Tests (Shard ${{ matrix.shard }}) @@ -451,6 +452,7 @@ jobs: with: files: reports/junit.xml check_name: Web-Frontend Tests (Shard ${{ matrix.shard }}) + comment_mode: off test-zapier: name: Zapier Integration Tests @@ -518,8 +520,7 @@ jobs: name: E2E Tests (Shard ${{ matrix.shard }}) timeout-minutes: 60 runs-on: ubuntu-latest - # Only run E2E tests on PRs, not on develop/master branches - if: github.event_name == 'pull_request' + if: needs.detect-changes.outputs.backend == 'true' || needs.detect-changes.outputs.frontend == 'true' || needs.detect-changes.outputs.dockerfiles == 'true' || github.ref_name == 'develop' || github.ref_name == 'master' needs: - build-backend - build-frontend @@ -678,7 +679,6 @@ jobs: collect-e2e-reports: name: Collect E2E Test Reports runs-on: ubuntu-latest - if: github.event_name == 'pull_request' && always() needs: [test-e2e] permissions: contents: read From 320d89d56a50a955290787ce8df1a28710382bbb Mon Sep 17 00:00:00 2001 From: Bram Date: Mon, 20 Oct 2025 13:58:50 +0200 Subject: [PATCH 2/3] Add migration to readme (#4096) --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index 653a641f79..bae94f8644 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,14 @@ ## Baserow is an open-source no-code platform and the best alternative to Airtable +### 🚨 Repository Migration Notice + +Baserow has moved from GitLab to GitHub. All issues have been successfully migrated, +but merged and closed merge requests (PRs) were not imported. You can still browse the +old repository and its history at: https://gitlab.com/baserow/baserow. + +Please use this GitHub repository for all new issues, discussions, and contributions +going forward at: https://github.com/baserow/baserow. + It lets you easily build databases, applications, automations, and dashboards β€” all in one secure, self-hostable environment. Empower your team to create powerful tools without writing a single line of code. 
From 7a92c4dbfeceac48efb0bd2ecff3a2da971b0cdf Mon Sep 17 00:00:00 2001 From: Bram Date: Mon, 20 Oct 2025 15:38:57 +0200 Subject: [PATCH 3/3] Add Database Management Tools to AI Assistant (#4081) * Add database tools * Address feedback --------- Authored-by: Davide Silvestri --- .../database/locale/en/LC_MESSAGES/django.po | 18 +- .../baserow/contrib/database/table/actions.py | 9 +- .../baserow/contrib/database/views/handler.py | 12 +- .../core/locale/en/LC_MESSAGES/django.po | 4 +- .../baserow/locale/en/LC_MESSAGES/django.po | 53 +- .../api/assistant/serializers.py | 20 +- .../baserow_enterprise/api/assistant/views.py | 7 + .../backend/src/baserow_enterprise/apps.py | 34 +- .../baserow_enterprise/assistant/adapter.py | 14 + .../baserow_enterprise/assistant/assistant.py | 104 +-- .../baserow_enterprise/assistant/prompts.py | 50 +- .../src/baserow_enterprise/assistant/react.py | 272 +++++++ .../assistant/tools/__init__.py | 3 + .../assistant/tools/database/__init__.py | 0 .../assistant/tools/database/tools.py | 768 ++++++++++++++++++ .../tools/database/types/__init__.py | 6 + .../assistant/tools/database/types/base.py | 37 + .../tools/database/types/database.py | 15 + .../assistant/tools/database/types/fields.py | 494 +++++++++++ .../assistant/tools/database/types/table.py | 68 ++ .../tools/database/types/view_filters.py | 557 +++++++++++++ .../assistant/tools/database/types/views.py | 370 +++++++++ .../assistant/tools/database/utils.py | 514 ++++++++++++ .../assistant/tools/navigation/__init__.py | 0 .../assistant/tools/navigation/tools.py | 47 ++ .../assistant/tools/navigation/types.py | 72 ++ .../assistant/tools/navigation/utils.py | 19 + .../assistant/tools/registries.py | 69 +- .../assistant/tools/search_docs/tools.py | 72 +- .../src/baserow_enterprise/assistant/types.py | 83 +- .../locale/en/LC_MESSAGES/django.po | 93 ++- .../api/assistant/test_assistant_views.py | 13 +- .../assistant/test_assistant.py | 104 +-- .../test_assistant_database_rows_tools.py | 396 +++++++++ .../test_assistant_database_table_tools.py | 325 ++++++++ .../test_assistant_database_tools.py | 52 ++ .../test_assistant_database_views_tools.py | 719 ++++++++++++++++ .../assistant/utils.py | 3 + .../assets/scss/components/assistant.scss | 6 + .../assistant/AssistantInputMessage.vue | 19 +- .../components/assistant/AssistantPanel.vue | 46 ++ .../baserow_enterprise/locales/en.json | 8 +- .../baserow_enterprise/store/assistant.js | 36 +- .../locale/en/LC_MESSAGES/django.po | 14 +- 44 files changed, 5294 insertions(+), 331 deletions(-) create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/adapter.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/react.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/database/__init__.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/__init__.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/database.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/table.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py create mode 100644 
enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/__init__.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tools.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/types.py create mode 100644 enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_tools.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py create mode 100644 enterprise/backend/tests/baserow_enterprise_tests/assistant/utils.py diff --git a/backend/src/baserow/contrib/database/locale/en/LC_MESSAGES/django.po b/backend/src/baserow/contrib/database/locale/en/LC_MESSAGES/django.po index 1eb91462f2..c1da067247 100644 --- a/backend/src/baserow/contrib/database/locale/en/LC_MESSAGES/django.po +++ b/backend/src/baserow/contrib/database/locale/en/LC_MESSAGES/django.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: PACKAGE VERSION\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2025-09-29 14:05+0000\n" +"POT-Creation-Date: 2025-10-13 19:58+0000\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -335,39 +335,39 @@ msgstr "" msgid "Table \"%(table_name)s\" (%(table_id)s) created" msgstr "" -#: src/baserow/contrib/database/table/actions.py:104 +#: src/baserow/contrib/database/table/actions.py:107 msgid "Delete table" msgstr "" -#: src/baserow/contrib/database/table/actions.py:105 +#: src/baserow/contrib/database/table/actions.py:108 #, python-format msgid "Table \"%(table_name)s\" (%(table_id)s) deleted" msgstr "" -#: src/baserow/contrib/database/table/actions.py:160 +#: src/baserow/contrib/database/table/actions.py:163 msgid "Order tables" msgstr "" -#: src/baserow/contrib/database/table/actions.py:161 +#: src/baserow/contrib/database/table/actions.py:164 msgid "Tables order changed" msgstr "" -#: src/baserow/contrib/database/table/actions.py:224 +#: src/baserow/contrib/database/table/actions.py:227 msgid "Update table" msgstr "" -#: src/baserow/contrib/database/table/actions.py:226 +#: src/baserow/contrib/database/table/actions.py:229 #, python-format msgid "" "Table (%(table_id)s) name changed from \"%(original_table_name)s\" to " "\"%(table_name)s\"" msgstr "" -#: src/baserow/contrib/database/table/actions.py:296 +#: src/baserow/contrib/database/table/actions.py:299 msgid "Duplicate table" msgstr "" -#: src/baserow/contrib/database/table/actions.py:298 +#: src/baserow/contrib/database/table/actions.py:301 #, python-format msgid "" "Table \"%(table_name)s\" (%(table_id)s) duplicated from " diff --git a/backend/src/baserow/contrib/database/table/actions.py b/backend/src/baserow/contrib/database/table/actions.py index c599c9ff47..31f248b00e 100755 --- a/backend/src/baserow/contrib/database/table/actions.py +++ b/backend/src/baserow/contrib/database/table/actions.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Any, List, Optional +from typing 
import Any, Dict, List, Optional, Tuple from django.contrib.auth.models import AbstractUser from django.utils.translation import gettext_lazy as _ @@ -47,8 +47,9 @@ def do( name: str, data: Optional[List[List[Any]]] = None, first_row_header: bool = True, + fill_example: bool = True, progress: Optional[Progress] = None, - ) -> Table: + ) -> Tuple[Table, Dict[str, Dict[str, Any]]]: """ Create a table in the specified database. Undoing this action trashes the table and redoing restores it. @@ -61,6 +62,8 @@ def do( :param first_row_header: Indicates if the first row are the fields. The names of these rows are going to be used as fields. If `fields` is provided, this options is ignored. + :param fill_example: Whether or not to fill the table with example data if + no data is provided. :param progress: An optional progress instance if you want to track the progress of the task. :return: The created table and the error report. @@ -72,7 +75,7 @@ def do( name, data=data, first_row_header=first_row_header, - fill_example=True, + fill_example=fill_example, progress=progress, ) diff --git a/backend/src/baserow/contrib/database/views/handler.py b/backend/src/baserow/contrib/database/views/handler.py index 97e9c94284..92d7272635 100644 --- a/backend/src/baserow/contrib/database/views/handler.py +++ b/backend/src/baserow/contrib/database/views/handler.py @@ -578,12 +578,12 @@ def list_views( self, user: AbstractUser, table: Table, - _type: str, - filters: bool, - sortings: bool, - decorations: bool, - group_bys: bool, - limit: int, + _type: str | None = None, + filters: bool = True, + sortings: bool = True, + decorations: bool = True, + group_bys: bool = True, + limit: int | None = None, ) -> Iterable[View]: """ Lists available views for a user/table combination. diff --git a/backend/src/baserow/core/locale/en/LC_MESSAGES/django.po b/backend/src/baserow/core/locale/en/LC_MESSAGES/django.po index 44084b8767..8c5f145c01 100644 --- a/backend/src/baserow/core/locale/en/LC_MESSAGES/django.po +++ b/backend/src/baserow/core/locale/en/LC_MESSAGES/django.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: PACKAGE VERSION\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2025-09-29 14:05+0000\n" +"POT-Creation-Date: 2025-10-13 19:58+0000\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -242,7 +242,7 @@ msgstr "" msgid "Decimal number" msgstr "" -#: src/baserow/core/handler.py:2185 src/baserow/core/user/handler.py:267 +#: src/baserow/core/handler.py:2187 src/baserow/core/user/handler.py:267 #, python-format msgid "%(name)s's workspace" msgstr "" diff --git a/backend/src/baserow/locale/en/LC_MESSAGES/django.po b/backend/src/baserow/locale/en/LC_MESSAGES/django.po index eebf80ebaf..72129cf0b9 100755 --- a/backend/src/baserow/locale/en/LC_MESSAGES/django.po +++ b/backend/src/baserow/locale/en/LC_MESSAGES/django.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: PACKAGE VERSION\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2025-09-30 08:04+0000\n" +"POT-Creation-Date: 2025-10-13 19:58+0000\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -35,15 +35,15 @@ msgstr "" #: src/baserow/contrib/automation/action_scopes.py:14 #, python-format msgid "" -"of type (%(node_type)s) in automation " -"\"%(automation_name)s\" (%(automation_id)s)." +"of type (%(node_type)s) in automation \"%(automation_name)s\" " +"(%(automation_id)s)." 
msgstr "" #: src/baserow/contrib/automation/actions.py:8 #, python-format msgid "" -"in workflow (%(workflow_id)s) in automation " -"\"%(automation_name)s\" (%(automation_id)s)." +"in workflow (%(workflow_id)s) in automation \"%(automation_name)s\" " +"(%(automation_id)s)." msgstr "" #: src/baserow/contrib/automation/automation_init_application.py:29 @@ -54,62 +54,71 @@ msgstr "" msgid "Local Baserow" msgstr "" -#: src/baserow/contrib/automation/nodes/actions.py:28 +#: src/baserow/contrib/automation/nodes/actions.py:32 msgid "Create automation node" msgstr "" -#: src/baserow/contrib/automation/nodes/actions.py:29 +#: src/baserow/contrib/automation/nodes/actions.py:33 #, python-format msgid "Node (%(node_id)s) created" msgstr "" -#: src/baserow/contrib/automation/nodes/actions.py:100 +#: src/baserow/contrib/automation/nodes/actions.py:104 msgid "Update automation node" msgstr "" -#: src/baserow/contrib/automation/nodes/actions.py:101 +#: src/baserow/contrib/automation/nodes/actions.py:105 #, python-format msgid "Node (%(node_id)s) updated" msgstr "" -#: src/baserow/contrib/automation/nodes/actions.py:169 +#: src/baserow/contrib/automation/nodes/actions.py:173 msgid "Delete automation node" msgstr "" -#: src/baserow/contrib/automation/nodes/actions.py:170 +#: src/baserow/contrib/automation/nodes/actions.py:174 #, python-format msgid "Node (%(node_id)s) deleted" msgstr "" -#: src/baserow/contrib/automation/nodes/actions.py:227 +#: src/baserow/contrib/automation/nodes/actions.py:231 msgid "Order nodes" msgstr "" -#: src/baserow/contrib/automation/nodes/actions.py:228 +#: src/baserow/contrib/automation/nodes/actions.py:232 msgid "Node order changed" msgstr "" -#: src/baserow/contrib/automation/nodes/actions.py:296 +#: src/baserow/contrib/automation/nodes/actions.py:300 msgid "Duplicate automation node" msgstr "" -#: src/baserow/contrib/automation/nodes/actions.py:297 +#: src/baserow/contrib/automation/nodes/actions.py:301 #, python-format msgid "Node (%(node_id)s) duplicated" msgstr "" -#: src/baserow/contrib/automation/nodes/actions.py:380 +#: src/baserow/contrib/automation/nodes/actions.py:384 msgid "Replace automation node" msgstr "" -#: src/baserow/contrib/automation/nodes/actions.py:382 +#: src/baserow/contrib/automation/nodes/actions.py:386 #, python-format msgid "" "Node (%(node_id)s) changed from a type of %(original_node_type)s to " "%(node_type)s" msgstr "" -#: src/baserow/contrib/automation/nodes/node_types.py:195 +#: src/baserow/contrib/automation/nodes/actions.py:491 +msgid "Moved automation node" +msgstr "" + +#: src/baserow/contrib/automation/nodes/actions.py:492 +#, python-format +msgid "Node (%(node_id)s) moved" +msgstr "" + +#: src/baserow/contrib/automation/nodes/node_types.py:176 msgid "Branch" msgstr "" @@ -204,18 +213,18 @@ msgstr "" msgid "Widget \"%(widget_title)s\" (%(widget_id)s) deleted" msgstr "" -#: src/baserow/contrib/integrations/core/service_types.py:1083 +#: src/baserow/contrib/integrations/core/service_types.py:1103 msgid "Branch taken" msgstr "" -#: src/baserow/contrib/integrations/core/service_types.py:1088 +#: src/baserow/contrib/integrations/core/service_types.py:1108 msgid "Label" msgstr "" -#: src/baserow/contrib/integrations/core/service_types.py:1090 +#: src/baserow/contrib/integrations/core/service_types.py:1110 msgid "The label of the branch that matched the condition." 
msgstr "" -#: src/baserow/contrib/integrations/core/service_types.py:1374 +#: src/baserow/contrib/integrations/core/service_types.py:1402 msgid "Triggered at" msgstr "" diff --git a/enterprise/backend/src/baserow_enterprise/api/assistant/serializers.py b/enterprise/backend/src/baserow_enterprise/api/assistant/serializers.py index fe801d335a..ae4abdb25c 100644 --- a/enterprise/backend/src/baserow_enterprise/api/assistant/serializers.py +++ b/enterprise/backend/src/baserow_enterprise/api/assistant/serializers.py @@ -142,22 +142,17 @@ class AiMessageSerializer(serializers.Serializer): class AiThinkingSerializer(serializers.Serializer): type = serializers.CharField(default=AssistantMessageType.AI_THINKING) - code = serializers.CharField( - help_text=( - "Thinking code. If empty, signals end of thinking. This is used to provide recurring " - "messages that have a translation in the frontend (i.e. 'thinking', 'answering', etc.)" - ) - ) content = serializers.CharField( - default="", - allow_blank=True, - help_text=( - "A short description of what the AI is thinking about. It can be used to " - "provide a dynamic message that don't have a translation in the frontend." - ), + default="The AI is thinking...", + help_text=("The message to show while the AI is thinking"), ) +class AiNavigationSerializer(serializers.Serializer): + type = serializers.CharField(default=AssistantMessageType.AI_NAVIGATION) + location = serializers.DictField(help_text=("The location to navigate to.")) + + class AiErrorMessageSerializer(serializers.Serializer): type = serializers.CharField(default=AssistantMessageType.AI_ERROR) code = serializers.CharField( @@ -185,6 +180,7 @@ class HumanMessageSerializer(serializers.Serializer): AssistantMessageType.HUMAN: HumanMessageSerializer, AssistantMessageType.AI_MESSAGE: AiMessageSerializer, AssistantMessageType.AI_THINKING: AiThinkingSerializer, + AssistantMessageType.AI_NAVIGATION: AiNavigationSerializer, AssistantMessageType.AI_ERROR: AiErrorMessageSerializer, } diff --git a/enterprise/backend/src/baserow_enterprise/api/assistant/views.py b/enterprise/backend/src/baserow_enterprise/api/assistant/views.py index 648c203f05..c0b2baa1df 100644 --- a/enterprise/backend/src/baserow_enterprise/api/assistant/views.py +++ b/enterprise/backend/src/baserow_enterprise/api/assistant/views.py @@ -161,6 +161,13 @@ def post(self, request: Request, chat_uuid: str, data) -> StreamingHttpResponse: handler = AssistantHandler() chat, _ = handler.get_or_create_chat(request.user, workspace, chat_uuid) + + # Clearing the user websocket_id will make sure real-time updates are sent + chat.user.web_socket_id = None + # FIXME: As long as we don't allow users to change it, temporarily set the + # timezone to the one provided in the UI context + chat.user.profile.timezone = ui_context.timezone + assistant = handler.get_assistant(chat) human_message = HumanMessage(content=data["content"], ui_context=ui_context) diff --git a/enterprise/backend/src/baserow_enterprise/apps.py b/enterprise/backend/src/baserow_enterprise/apps.py index 98c1475fcc..8404f5cbfd 100755 --- a/enterprise/backend/src/baserow_enterprise/apps.py +++ b/enterprise/backend/src/baserow_enterprise/apps.py @@ -302,14 +302,42 @@ def ready(self): notification_type_registry.register(TwoWaySyncUpdateFailedNotificationType()) notification_type_registry.register(TwoWaySyncDeactivatedNotificationType()) + from baserow_enterprise.assistant.tools import ( + CreateDatabaseToolType, + CreateFieldsToolType, + CreateTablesToolType, + 
CreateViewFiltersToolType, + CreateViewsToolType, + GetRowsToolsToolType, + GetTablesSchemaToolType, + ListDatabasesToolType, + ListRowsToolType, + ListTablesToolType, + ListViewsToolType, + NavigationToolType, + SearchDocsToolType, + ) from baserow_enterprise.assistant.tools.registries import ( assistant_tool_registry, ) - from baserow_enterprise.assistant.tools.search_docs.tools import ( - SearchDocsToolType, - ) assistant_tool_registry.register(SearchDocsToolType()) + assistant_tool_registry.register(NavigationToolType()) + + assistant_tool_registry.register(ListDatabasesToolType()) + assistant_tool_registry.register(CreateDatabaseToolType()) + assistant_tool_registry.register(ListTablesToolType()) + assistant_tool_registry.register(CreateTablesToolType()) + assistant_tool_registry.register(GetTablesSchemaToolType()) + assistant_tool_registry.register(CreateFieldsToolType()) + + assistant_tool_registry.register(ListRowsToolType()) + assistant_tool_registry.register(GetRowsToolsToolType()) + + assistant_tool_registry.register(ListViewsToolType()) + assistant_tool_registry.register(CreateViewsToolType()) + + assistant_tool_registry.register(CreateViewFiltersToolType()) # The signals must always be imported last because they use the registries # which need to be filled first. diff --git a/enterprise/backend/src/baserow_enterprise/assistant/adapter.py b/enterprise/backend/src/baserow_enterprise/assistant/adapter.py new file mode 100644 index 0000000000..1702fb1c3c --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/adapter.py @@ -0,0 +1,14 @@ +import dspy + +from .prompts import ASSISTANT_SYSTEM_PROMPT + + +class ChatAdapter(dspy.ChatAdapter): + def format_field_description(self, signature: type[dspy.Signature]) -> str: + """ + This is the first part of the prompt the LLM sees, so we prepend our custom + system prompt to it to give it the personality and context of Baserow. + """ + + field_description = super().format_field_description(signature) + return ASSISTANT_SYSTEM_PROMPT + "## TASK INSTRUCTIONS:\n\n" + field_description diff --git a/enterprise/backend/src/baserow_enterprise/assistant/assistant.py b/enterprise/backend/src/baserow_enterprise/assistant/assistant.py index 7f2259fe20..7c2bee09b2 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/assistant.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/assistant.py @@ -9,11 +9,13 @@ from baserow_enterprise.assistant.tools.registries import assistant_tool_registry +from .adapter import ChatAdapter from .models import AssistantChat, AssistantChatMessage -from .prompts import ASSISTANT_SYSTEM_PROMPT +from .react import ReAct from .types import ( AiMessage, AiMessageChunk, + AiNavigationMessage, AiThinkingMessage, AssistantMessageUnion, ChatTitleMessage, @@ -26,9 +28,9 @@ class ChatSignature(dspy.Signature): question: str = dspy.InputField() history: dspy.History = dspy.InputField() - ui_context: UIContext = dspy.InputField( + ui_context: UIContext | None = dspy.InputField( default=None, - description=( + desc=( "The frontend UI content the user is currently in. " "Whenever make sense, use it to ground your answer." ), @@ -104,6 +106,7 @@ def on_tool_end( call_id, instance, inputs, outputs, exception ) + # If the tool produced sources, add them to the overall list of sources. 
if isinstance(outputs, dict) and "sources" in outputs: self.extend_sources(outputs["sources"]) @@ -118,56 +121,15 @@ def __init__(self, chat: AssistantChat): self._lm_client = dspy.LM( model=lm_model, cache=not settings.DEBUG, + max_retries=5, ) - Signature = self._get_chat_signature() - self._assistant = dspy.ReAct( - Signature, - tools=assistant_tool_registry.list_all_usable_tools( - self._user, self._workspace - ), + tools = assistant_tool_registry.list_all_usable_tools( + self._user, self._workspace ) + self._assistant = ReAct(ChatSignature, tools=tools) self.history = None - def _get_chat_signature(self) -> dspy.Signature: - """ - Returns the appropriate signature for the chat based on whether it has a title. - - :return: the dspy.Signature for the chat, with the chat_title field included if - the chat does not yet have a title, otherwise with only the question, - history, and answer fields. - """ - - chat_signature_instructions = "## INSTRUCTIONS\n\nGiven the fields `question`, `history`, and `ui_context`, produce the fields `answer`" - if self._chat.title: # only inject our base system prompt - return ChatSignature.with_instructions( - "\n".join( - [ - ASSISTANT_SYSTEM_PROMPT, - f"{chat_signature_instructions}.", - ] - ) - ) - else: # the chat also needs a title - return dspy.Signature( - { - **ChatSignature.fields, - "chat_title": dspy.OutputField( - max_length=20, - description=( - "Capture the core intent of the conversation in ≀ 8 words for the chat title. " - "Use clear, action-oriented language (gerund verbs up front where possible)." - ), - ), - }, - instructions="\n".join( - [ - ASSISTANT_SYSTEM_PROMPT, - f"{chat_signature_instructions}, `chat_title`.", - ] - ), - ) - async def acreate_chat_message( self, role: AssistantChatMessage.Role, @@ -271,7 +233,13 @@ async def astream_messages( ensure_llm_model_accessible(self._lm_client) callback_manager = AssistantCallbacks() - with dspy.context(lm=self._lm_client, callbacks=[callback_manager]): + + with dspy.context( + lm=self._lm_client, + cache=not settings.DEBUG, + callbacks=[*dspy.settings.config.callbacks, callback_manager], + adapter=ChatAdapter(), + ): if self.history is None: await self.aload_chat_history() @@ -279,10 +247,6 @@ async def astream_messages( stream_listeners = [ StreamListener(signature_field_name="answer"), ] - if not self._chat.title: - stream_listeners.append( - StreamListener(signature_field_name="chat_title") - ) stream_predict = dspy.streamify( self._assistant, @@ -300,30 +264,32 @@ async def astream_messages( AssistantChatMessage.Role.HUMAN, human_message.content ) - chat_title, answer = "", "" + answer = "" async for stream_chunk in output_stream: if isinstance(stream_chunk, StreamResponse): # Accumulate chunks per field to deliver full, real‐time updates. 
- match stream_chunk.signature_field_name: - case "answer": - answer += stream_chunk.chunk - yield AiMessageChunk( - content=answer, sources=callback_manager.sources - ) - case "chat_title": - chat_title += stream_chunk.chunk - yield ChatTitleMessage(content=chat_title) + if stream_chunk.signature_field_name == "answer": + answer += stream_chunk.chunk + yield AiMessageChunk( + content=answer, sources=callback_manager.sources + ) elif isinstance(stream_chunk, Prediction): - # Final output ready β€” save streamed answer and title + yield AiMessageChunk( + content=stream_chunk.answer, sources=callback_manager.sources + ) await self.acreate_chat_message( AssistantChatMessage.Role.AI, answer, artifacts={"sources": callback_manager.sources}, ) - if chat_title and not self._chat.title: - self._chat.title = chat_title - await self._chat.asave(update_fields=["title", "updated_on"]) - elif isinstance(stream_chunk, AiThinkingMessage): - # If any tool stream a thinking message, forward it to the user + elif isinstance(stream_chunk, (AiThinkingMessage, AiNavigationMessage)): + # forward thinking/navigation messages as-is to the frontend yield stream_chunk + + if not self._chat.title: + title_generator = dspy.Predict("question -> chat_title") + rsp = await title_generator.acall(question=human_message.content) + self._chat.title = rsp.chat_title + yield ChatTitleMessage(content=self._chat.title) + await self._chat.asave(update_fields=["title", "updated_on"]) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/prompts.py b/enterprise/backend/src/baserow_enterprise/assistant/prompts.py index c4cff5ae17..f4f8762e75 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/prompts.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/prompts.py @@ -1,27 +1,30 @@ CORE_CONCEPTS = """ -## BASEROW STRUCTURE +### BASEROW STRUCTURE -β€’ **Workspace** β†’ Databases, Applications, Automations, Dashboards +**Structure**: Workspace β†’ Databases, Applications, Automations, Dashboards, Snapshots + +**Key concepts**: β€’ **Roles**: Free (admin, member) | Advanced/Enterprise (admin, builder, editor, viewer, no access) β€’ **Features**: Real-time collaboration, SSO (SAML2/OIDC/OAuth2), MCP integration, API access, Audit logs β€’ **Plans**: Free, Premium, Advanced, Enterprise (https://baserow.io/pricing) -β€’ **Open Source**: Core is open source (https://gitlab.com/baserow/baserow) +β€’ **Open Source**: Core is open source (https://github.com/baserow/baserow) +β€’ **Snapshots**: Application-level backups """ DATABASE_BUILDER_CONCEPTS = """ -## DATABASE BUILDER (no-code database) +### DATABASE BUILDER (no-code database) **Structure**: Database β†’ Tables β†’ Fields + Rows + Views + Webhooks **Key concepts**: β€’ **Fields**: Define schema (30+ types including link_row for relationships); one primary field per table β€’ **Views**: Present data with filters/sorts/grouping/colors; can be shared, personal, or public -β€’ **Snapshots**: Database backups; **Data sync**: Table replication; **Webhooks**: Row/field/view event triggers β€’ **Permissions**: RBAC at workspace/database/table/field levels; database tokens for API +β€’ **Data sync**: Table replication; **Webhooks**: Row/field/view event triggers """ APPLICATION_BUILDER_CONCEPTS = """ -## APPLICATION BUILDER (visual app builder) +### APPLICATION BUILDER (visual app builder) **Structure**: Application β†’ Pages β†’ Elements + Data Sources + Workflows @@ -32,37 +35,46 @@ β€’ **Publishing**: Requires domain configuration """ 
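The assistant.py hunk above drops the streamed `chat_title` output field and instead generates the title with a one-shot `dspy.Predict("question -> chat_title")` call once the answer has finished streaming. A minimal sketch of that pattern, assuming `dspy` is installed and an LM has been configured; the model id below is only a placeholder, not the model Baserow uses:

```python
import dspy

# Placeholder model id; any LiteLLM-compatible model works with dspy.LM.
dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))

# A bare string signature is enough for a one-off prediction: dspy parses
# "question -> chat_title" into one input field and one output field.
title_generator = dspy.Predict("question -> chat_title")

prediction = title_generator(question="How do I link two tables in Baserow?")
print(prediction.chat_title)
```

Keeping the title out of the ReAct signature is also why the extra `StreamListener` for `chat_title` is removed in the same hunk: the agent now only streams the `answer` field.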
+AUTOMATION_BUILDER_CONCEPTS = """ +### AUTOMATIONS (no-code automation builder) + +**Structure**: Automation β†’ Workflows β†’ Triggers + Actions + Routers (Nodes) + +**Key concepts**: +β€’ **Triggers**: Events that start automations (e.g., row created/updated, view accessed) +β€’ **Actions**: Tasks performed (e.g., create/update rows, send emails, call webhooks) +β€’ **Routers**: Conditional logic (if/else, switch) to control flow +β€’ **Execution**: Runs in the background; monitor via logs +β€’ **History**: Track runs, successes, failures +β€’ **Publishing**: Requires domain configuration +""" + ASSISTANT_SYSTEM_PROMPT = ( """ You are Baserow Assistant, an AI expert for Baserow (open-source no-code platform). ## YOUR KNOWLEDGE - -You know: -1. **Core concepts** (below) - answer directly +1. **Core concepts** (below) 2. **Detailed docs** - use search_docs tool to search when needed 3. **API specs** - guide users to https://api.baserow.io/api/schema.json ## HOW TO HELP - β€’ Use American English spelling and grammar β€’ Be clear, concise, and actionable β€’ For troubleshooting: ask for error messages or describe expected vs actual results -β€’ If uncertain: acknowledge it, then suggest how to find the answer (search docs, check API, etc.) -β€’ Think step-by-step; guide to simple solutions +β€’ **NEVER** fabricate answers or URLs. Acknowledge when you can't be sure. +β€’ When you have the tools to help, **ALWAYS** use them instead of answering with instructions. +* At the end, **always** ask follow-up questions to understand user needs and continue the conversation. ## FORMATTING (CRITICAL) β€’ **No HTML**: Only Markdown (bold, italics, lists, code, tables) -β€’ **Lists**: Prefer lists when possible. Numbered lists for steps; bulleted for others -β€’ **Tables**: NEVER use tables. Use lists instead. +β€’ Prefer lists when possible. Numbered lists for steps; bulleted for others +β€’ NEVER use tables. Use lists instead. +## BASEROW CONCEPTS """ + CORE_CONCEPTS - + "\n" + DATABASE_BUILDER_CONCEPTS - + "\n" + APPLICATION_BUILDER_CONCEPTS - + "\n" - """ -""" + + AUTOMATION_BUILDER_CONCEPTS ) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/react.py b/enterprise/backend/src/baserow_enterprise/assistant/react.py new file mode 100644 index 0000000000..42474ccd11 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/react.py @@ -0,0 +1,272 @@ +from typing import TYPE_CHECKING, Any, Callable, Literal + +import dspy +from dspy.adapters.types.tool import Tool +from dspy.predict.react import _fmt_exc +from dspy.primitives.module import Module +from dspy.signatures.signature import ensure_signature +from litellm import ContextWindowExceededError +from loguru import logger + +from .types import ToolsUpgradeResponse + +if TYPE_CHECKING: + from dspy.signatures.signature import Signature + + +# Variant of dspy.predict.react.ReAct that accepts a "meta-tool": +# a callable that can produce tools at runtime (e.g. per-table schemas). +# This lets a single ReAct instance handle many different table signatures +# without creating a new agent for each request. + + +class ReAct(Module): + def __init__( + self, signature: type["Signature"], tools: list[Callable], max_iters: int = 100 + ): + """ + ReAct stands for "Reasoning and Acting," a popular paradigm for building + tool-using agents. In this approach, the language model is iteratively provided + with a list of tools and has to reason about the current situation. 
The model + decides whether to call a tool to gather more information or to finish the task + based on its reasoning process. The DSPy version of ReAct is generalized to work + over any signature, thanks to signature polymorphism. + + Args: + signature: The signature of the module, which defines the input and output + of the react module. tools (list[Callable]): A list of functions, callable + objects, or `dspy.Tool` instances. max_iters (Optional[int]): The maximum + number of iterations to run. Defaults to 10. + + Example: + + ```python def get_weather(city: str) -> str: + return f"The weather in {city} is sunny." + + react = dspy.ReAct(signature="question->answer", tools=[get_weather]) pred = + react(question="What is the weather in Tokyo?") + """ + + super().__init__() + self.signature = signature = ensure_signature(signature) + self.max_iters = max_iters + + tools = [t if isinstance(t, Tool) else Tool(t) for t in tools] + tools = {tool.name: tool for tool in tools} + outputs = ", ".join([f"`{k}`" for k in signature.output_fields.keys()]) + + tools["finish"] = Tool( + func=lambda: "Completed.", + name="finish", + desc=f"Marks the task as complete. That is, signals that all information for producing the outputs, i.e. {outputs}, are now available to be extracted.", + args={}, + ) + + self.tools = tools + self.react = self._build_react_module() + self.extract = self._build_fallback_module() + + def _build_instructions(self) -> list[str]: + signature = self.signature + inputs = ", ".join([f"`{k}`" for k in signature.input_fields.keys()]) + outputs = ", ".join([f"`{k}`" for k in signature.output_fields.keys()]) + instr = [f"{signature.instructions}\n"] if signature.instructions else [] + + instr.extend( + [ + f"You are an Agent. In each episode, you will be given the fields {inputs} as input. 
And you can see your past trajectory so far.", + f"Your goal is to use one or more of the supplied tools to collect any necessary information for producing {outputs}.\n", + "To do this, you will interleave next_thought, next_tool_name, and next_tool_args in each turn, and also when finishing the task.", + "After each tool call, you receive a resulting observation, which gets appended to your trajectory.\n", + "When writing next_thought, you may reason about the current situation and plan for future steps.", + "When selecting the next_tool_name and its next_tool_args, the tool must be one of:\n", + ] + ) + + for idx, tool in enumerate(self.tools.values()): + instr.append(f"({idx + 1}) {tool}") + instr.append( + "When providing `next_tool_args`, the value inside the field must be in JSON format" + ) + return instr + + def _build_react_module(self) -> type[Module]: + instructions = self._build_instructions() + react_signature = ( + dspy.Signature({**self.signature.input_fields}, "\n".join(instructions)) + .append("trajectory", dspy.InputField(), type_=str) + .append("next_thought", dspy.OutputField(), type_=str) + .append( + "next_tool_name", + dspy.OutputField(), + type_=Literal[tuple(self.tools.keys())], + ) + .append("next_tool_args", dspy.OutputField(), type_=dict[str, Any]) + ) + + return dspy.Predict(react_signature) + + def _build_fallback_module(self) -> type[Module]: + signature = self.signature + fallback_signature = dspy.Signature( + {**signature.input_fields, **signature.output_fields}, + signature.instructions, + ).append("trajectory", dspy.InputField(), type_=str) + return dspy.ChainOfThought(fallback_signature) + + def _format_trajectory(self, trajectory: dict[str, Any]): + adapter = dspy.settings.adapter or dspy.ChatAdapter() + trajectory_signature = dspy.Signature(f"{', '.join(trajectory.keys())} -> x") + return adapter.format_user_message_content(trajectory_signature, trajectory) + + def forward(self, **input_args): + trajectory = {} + max_iters = input_args.pop("max_iters", self.max_iters) + for idx in range(max_iters): + try: + pred = self._call_with_potential_trajectory_truncation( + self.react, trajectory, **input_args + ) + except ValueError as err: + logger.warning( + f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}" + ) + break + + trajectory[f"thought_{idx}"] = pred.next_thought + trajectory[f"tool_name_{idx}"] = pred.next_tool_name + trajectory[f"tool_args_{idx}"] = pred.next_tool_args + + try: + result = self.tools[pred.next_tool_name](**pred.next_tool_args) + + # This is how meta tools return multiple tools, the first argument is + # the actual observation, the rest are new tools to add. Once we have + # add them, we need to rebuild the react module to include them. + # NOTE: tools will remain available for the rest of the trajectory, + # but won't be available in the next call to the agent. 
+ if isinstance(result, ToolsUpgradeResponse): + new_tools = result.new_tools + observation = result.observation + for new_tool in new_tools: + if not isinstance(new_tool, Tool): + new_tool = Tool(new_tool) + self.tools[new_tool.name] = new_tool + self.react = self._build_react_module() + else: + observation = result + + trajectory[f"observation_{idx}"] = observation + except Exception as err: + trajectory[ + f"observation_{idx}" + ] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" + + if pred.next_tool_name == "finish": + break + + extract = self._call_with_potential_trajectory_truncation( + self.extract, trajectory, **input_args + ) + return dspy.Prediction(trajectory=trajectory, **extract) + + async def aforward(self, **input_args): + trajectory = {} + max_iters = input_args.pop("max_iters", self.max_iters) + for idx in range(max_iters): + try: + pred = await self._async_call_with_potential_trajectory_truncation( + self.react, trajectory, **input_args + ) + except ValueError as err: + logger.warning( + f"Ending the trajectory: Agent failed to select a valid tool: {_fmt_exc(err)}" + ) + break + + trajectory[f"thought_{idx}"] = pred.next_thought + trajectory[f"tool_name_{idx}"] = pred.next_tool_name + trajectory[f"tool_args_{idx}"] = pred.next_tool_args + + try: + observation = await self.tools[pred.next_tool_name]( + **pred.next_tool_args + ) + + # This is how meta tools return multiple tools, the first argument is + # the actual observation, the rest are new tools to add. Once we have + # add them, we need to rebuild the react module to include them. + # NOTE: tools will remain available for the rest of the trajectory, + # but won't be available in the next call to the agent. + if isinstance(observation, (list, tuple)): + for new_tool in observation[1:]: + if not isinstance(new_tool, Tool): + new_tool = Tool(new_tool) + self.tools[new_tool.name] = new_tool + self.react = self._build_react_module() + + observation = observation[0] + + trajectory[f"observation_{idx}"] = observation + except Exception as err: + trajectory[ + f"observation_{idx}" + ] = f"Execution error in {pred.next_tool_name}: {_fmt_exc(err)}" + + if pred.next_tool_name == "finish": + break + + extract = await self._async_call_with_potential_trajectory_truncation( + self.extract, trajectory, **input_args + ) + return dspy.Prediction(trajectory=trajectory, **extract) + + def _call_with_potential_trajectory_truncation( + self, module, trajectory, **input_args + ): + for _ in range(3): + try: + return module( + **input_args, + trajectory=self._format_trajectory(trajectory), + ) + except ContextWindowExceededError: + logger.warning( + "Trajectory exceeded the context window, truncating the oldest tool call information." + ) + trajectory = self.truncate_trajectory(trajectory) + + async def _async_call_with_potential_trajectory_truncation( + self, module, trajectory, **input_args + ): + for _ in range(3): + try: + return await module.acall( + **input_args, + trajectory=self._format_trajectory(trajectory), + ) + except ContextWindowExceededError: + logger.warning( + "Trajectory exceeded the context window, truncating the oldest tool call information." + ) + trajectory = self.truncate_trajectory(trajectory) + + def truncate_trajectory(self, trajectory): + """Truncates the trajectory so that it fits in the context window. + + Users can override this method to implement their own truncation logic. 
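The `ToolsUpgradeResponse` branch above is the main customization in this ReAct variant: a "meta-tool" can hand back new tools at runtime, and the loop wraps them in `dspy.Tool`, adds them to `self.tools`, and rebuilds the react predictor so the next turn can select them. A rough sketch of such a meta-tool from the tool author's side, assuming the enterprise module path resolves; `get_weather` is an invented example, but the `observation`/`new_tools` fields match how `get_rows_tools` constructs the response further below:

```python
from dspy.adapters.types.tool import Tool

from baserow_enterprise.assistant.types import ToolsUpgradeResponse


def get_weather(city: str) -> str:
    """Invented example tool that only exists after the meta-tool runs."""

    return f"The weather in {city} is sunny."


def get_weather_tools() -> ToolsUpgradeResponse:
    """Meta-tool: returns an observation plus extra tools for this trajectory."""

    return ToolsUpgradeResponse(
        observation="New tool available: get_weather(city).",
        new_tools=[Tool(get_weather)],
    )
```

Note that the synchronous `forward` checks for `ToolsUpgradeResponse`, while the async `aforward` path still unpacks a plain `(observation, *tools)` sequence, so an async meta-tool would need to return that tuple shape instead.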
+ """ + + keys = list(trajectory.keys()) + if len(keys) < 4: + # Every tool call has 4 keys: thought, tool_name, tool_args, and + # observation. + raise ValueError( + "The trajectory is too long so your prompt exceeded the context window, but the " + "trajectory cannot be truncated because it only has one tool call." + ) + + for key in keys[:4]: + trajectory.pop(key) + + return trajectory diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/__init__.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/__init__.py index e69de29bb2..b2bb87ffa3 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/__init__.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/__init__.py @@ -0,0 +1,3 @@ +from .database.tools import * # noqa: F401, F403 +from .navigation.tools import * # noqa: F401, F403 +from .search_docs.tools import * # noqa: F401, F403 diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/__init__.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py new file mode 100644 index 0000000000..bd12a3daa2 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py @@ -0,0 +1,768 @@ +from typing import Any, Callable, Literal, Tuple + +from django.contrib.auth.models import AbstractUser +from django.contrib.contenttypes.models import ContentType +from django.db import transaction +from django.utils.translation import gettext as _ + +import dspy +from loguru import logger +from pydantic import create_model + +from baserow.contrib.database.fields.actions import UpdateFieldActionType +from baserow.contrib.database.fields.registries import field_type_registry +from baserow.contrib.database.models import Database +from baserow.contrib.database.table.actions import CreateTableActionType +from baserow.contrib.database.views.actions import ( + CreateViewActionType, + CreateViewFilterActionType, + UpdateViewFieldOptionsActionType, +) +from baserow.contrib.database.views.handler import ViewHandler +from baserow.core.actions import CreateApplicationActionType +from baserow.core.models import Workspace +from baserow.core.service import CoreService +from baserow_enterprise.assistant.tools.registries import AssistantToolType, ToolHelpers +from baserow_enterprise.assistant.types import ( + TableNavigationType, + ToolSignature, + ToolsUpgradeResponse, + ViewNavigationType, +) + +from . import utils +from .types import ( + AnyFieldItem, + AnyFieldItemCreate, + AnyViewFilterItem, + AnyViewFilterItemCreate, + AnyViewItemCreate, + BaseTableItem, + DatabaseItem, + ListTablesFilterArg, + TableItemCreate, + view_item_registry, +) + + +def get_list_databases_tool( + user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers +) -> Callable[[], list[DatabaseItem]]: + """ + Returns a function that lists all the databases the user has access to in the + current workspace. + """ + + def list_databases() -> list[DatabaseItem]: + """ + Lists all the databases the user can access. 
+ """ + + nonlocal user, workspace, tool_helpers + + tool_helpers.update_status(_("Listing databases...")) + + applications_qs = CoreService().list_applications_in_workspace( + user, workspace, specific=False + ) + + database_content_type = ContentType.objects.get_for_model(Database) + + return { + "databases": [ + DatabaseItem(id=database.id, name=database.name).model_dump() + for database in applications_qs.filter( + content_type=database_content_type + ) + ] + } + + return list_databases + + +class ListDatabasesToolType(AssistantToolType): + type = "list_databases" + + @classmethod + def get_tool( + cls, user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers + ) -> Callable[[Any], Any]: + return get_list_databases_tool(user, workspace, tool_helpers) + + +def get_list_tables_tool( + user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers +) -> Callable[[int], list[str]]: + """ + Returns a function that lists all the tables in a given database the user has + access to in the current workspace. + """ + + def list_tables(filters: ListTablesFilterArg) -> list[dict[str, Any]]: + """ + List tables that verifies the filters + + - Always call this before creating new tables to avoid duplicates. + - Always call this to link existing tables when table IDs are not known. + """ + + nonlocal user, workspace, tool_helpers + + tables = ( + utils.filter_tables(user, workspace) + .filter(filters.to_orm_filter()) + .select_related("database") + ) + + databases = {} + database_names = [] + for table in tables: + if table.database_id not in databases: + databases[table.database_id] = { + "id": table.database_id, + "name": table.database.name, + "tables": [], + } + database_names.append(table.database.name) + databases[table.database_id]["tables"].append( + { + "id": table.id, + "name": table.name, + "database_id": table.database_id, + } + ) + + tool_helpers.update_status( + _("Listing tables in %(database_names)s...") + % {"database_names": ", ".join(database_names)} + ) + + if len(databases) == 0: + return "No tables found" + elif len(databases) == 1: + # Return just the tables array when there's only one database + return list(databases.values())[0]["tables"] + else: + return list(databases.values()) + + return list_tables + + +class ListTablesToolType(AssistantToolType): + type = "list_tables" + thinking_message = "Looking for tables..." + + @classmethod + def get_tool( + cls, user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers + ) -> Callable[[Any], Any]: + return get_list_tables_tool(user, workspace, tool_helpers) + + +def get_tables_schema_tool( + user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers +) -> Callable[[int], list[str]]: + """ + Returns a function that lists all the fields in a given table the user has + access to in the current workspace. + """ + + def get_tables_schema( + table_ids: list[int], + full_schema: bool, + ) -> list[dict[str, Any]]: + """ + Returns the schema of the specified tables, including their fields if requested. + Use `full_schema=True` to get all the fields, otherwise only the table names, + IDs, primary keys, and relationships will be included. 
+ + When to use: - Understanding table structure before creating/modifying fields - + Checking existing field names to avoid duplicates - Understanding table + relationships when creating link_row fields + + Remember: - Always call this before creating fields to avoid duplicate names - + Use get_rows_tools() for any row operations - not this one + """ + + nonlocal user, workspace, tool_helpers + + if not table_ids: + return [] + + tables = utils.filter_tables(user, workspace).filter(id__in=table_ids) + + tool_helpers.update_status( + _("Inspecting %(table_names)s schema...") + % {"table_names": ", ".join(t.name for t in tables)} + ) + + return { + "tables_schema": [ + ts.model_dump() for ts in utils.get_tables_schema(tables, full_schema) + ] + } + + return get_tables_schema + + +class GetTablesSchemaToolType(AssistantToolType): + type = "get_tables_schema" + + @classmethod + def get_tool( + cls, user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers + ) -> Callable[[Any], Any]: + return get_tables_schema_tool(user, workspace, tool_helpers) + + +def get_create_database_tool( + user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers +) -> Callable[[str], dict[str, Any]]: + """ + Returns a function that creates a database in the current workspace. + """ + + def create_database(name: str) -> dict[str, Any]: + """ + Create a database in the current workspace and return its ID and name. + **ALWAYS** create tables afterwards unless explicitly asked otherwise. + + - name: desired database name (must be unique in the workspace) + - call list_databases first to avoid duplicates + - call the create_tables tools afterwards unless explicitly asked otherwise + """ + + nonlocal user, workspace, tool_helpers + + tool_helpers.update_status( + _("Creating database %(database_name)s...") % {"database_name": name} + ) + + with transaction.atomic(): + database = CreateApplicationActionType.do( + user, workspace, "database", name=name + ) + + return { + "created_database": DatabaseItem( + id=database.id, name=database.name + ).model_dump() + } + + return create_database + + +class CreateDatabaseToolType(AssistantToolType): + type = "create_database" + thinking_message = "Creating a new database..." + + @classmethod + def get_tool( + cls, user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers + ) -> Callable[[Any], Any]: + return get_create_database_tool(user, workspace, tool_helpers) + + +def get_create_tables_tool( + user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers +) -> Callable[[list[TableItemCreate]], list[dict[str, Any]]]: + """ + Returns a function that creates a set of tables in a given database the user has + access to + """ + + def create_tables( + database_id: int, tables: list[TableItemCreate], add_sample_rows: bool = True + ) -> list[dict[str, Any]]: + """ + Creates tables with fields and rows in a database. **ALWAYS** add sample rows + unless explicitly asked otherwise. + + - table names should be unique in a database + - add meaningful fields with the appropriate types and relationships to other + existing tables. The reversed link_row fields will be created automatically. 
+ - if add_sample_rows is True (default), add some example rows to each table + """ + + nonlocal user, workspace, tool_helpers + + if not tables: + return {"created_tables": []} + + database = CoreService().get_application( + user, + database_id, + specific=False, + base_queryset=Database.objects.filter(workspace=workspace), + ) + + created_tables = [] + with transaction.atomic(): + for i, table in enumerate(tables): + tool_helpers.update_status( + _("Creating table %(table_name)s...") % {"table_name": table.name} + ) + + created_table, __ = CreateTableActionType.do( + user, database, table.name, fill_example=False + ) + created_tables.append(created_table) + + primary_field_item = table.primary_field + primary_field = created_table.get_primary_field().specific + new_field_type = field_type_registry.get(primary_field_item.type) + UpdateFieldActionType.do( + user, + primary_field, + new_type_name=new_field_type.type, + name=primary_field_item.name, + ) + + # Now that we have all the tables created, we can create the fields + notes = [] + for table, created_table in zip(tables, created_tables): + with transaction.atomic(): + try: + utils.create_fields(user, created_table, table.fields, tool_helpers) + except Exception as e: + notes.append( + f"Error creating fields for table_{created_table.id}: {e}.\n" + f"Please retry recreating fields for table_{created_table.id} manually." + ) + + tool_helpers.navigate_to( + TableNavigationType( + type="database-table", + database_id=database.id, + table_id=created_table.id, + table_name=created_table.name, + ) + ) + + if add_sample_rows: + tools = {} + instructions = [] + tool_helpers.update_status( + _("Preparing example rows for these new tables...") + ) + for table, created_table in zip(tables, created_tables): + create_rows_tool = utils.get_table_rows_tools( + user, workspace, tool_helpers, created_table + )["create"] + tools[create_rows_tool.name] = create_rows_tool + instructions.append( + f"- Create 5 example rows for table_{created_table.id}. Fill every relationship with valid data when possible." + ) + + predictor = dspy.Predict(ToolSignature) + result = predictor( + question=("\n".join(instructions)), + tools=list(tools.values()), + ) + for call in result.outputs.tool_calls: + with transaction.atomic(): + try: + result = tools[call.name](**call.args) + notes.append( + f"Rows created for table_{created_table.id}: {result}" + ) + except Exception as e: + notes.append( + f"Error creating example rows for table_{created_table.id}: {e}\n." + f"Please retry recreating rows for table_{created_table.id} manually." + ) + + return { + "created_tables": [ + BaseTableItem(id=t.id, name=t.name).model_dump() for t in created_tables + ], + "notes": notes, + } + + return create_tables + + +class CreateTablesToolType(AssistantToolType): + type = "create_tables" + + @classmethod + def get_tool( + cls, user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers + ) -> Callable[[Any], Any]: + return get_create_tables_tool(user, workspace, tool_helpers) + + +def get_create_fields_tool( + user: AbstractUser, + workspace: Workspace, + tool_helpers: ToolHelpers, +) -> Callable[[int, list[AnyFieldItemCreate]], list[dict[str, Any]]]: + """ + Returns a function that creates fields in a given table the user has access to + in the current workspace. + """ + + def create_fields( + table_id: int, fields: list[AnyFieldItemCreate] + ) -> list[AnyFieldItem]: + """ + Creates fields in the specified table. + + - Choose the most appropriate field type for each field. 
+ - Field names must be unique within a table: check existing names + when needed and skip duplicates. + - For link_row fields, ensure the linked table already exists in + the same database; create it first if needed. + """ + + nonlocal user, workspace, tool_helpers + + if not fields: + return [] + + table = utils.filter_tables(user, workspace).get(id=table_id) + + with transaction.atomic(): + created_fields = utils.create_fields(user, table, fields, tool_helpers) + return {"created_fields": [field.model_dump() for field in created_fields]} + + return create_fields + + +class CreateFieldsToolType(AssistantToolType): + type = "create_fields" + + @classmethod + def get_tool( + cls, user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers + ) -> Callable[[Any], Any]: + return get_create_fields_tool(user, workspace, tool_helpers) + + +def get_list_rows_tool( + user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers +) -> Callable[[int, int, int, list[int] | None], list[dict[str, Any]]]: + """ + Returns a function that lists rows in a given table the user has access to in the + current workspace. + """ + + def list_rows( + table_id: int, + offset: int = 0, + limit: int = 10, + field_ids: list[int] | None = None, + ) -> list[dict[str, Any]]: + """ + Lists rows in the specified table. + + - Use offset and limit for pagination. + - Use field_ids to limit the response to specific fields. + """ + + nonlocal user, workspace, tool_helpers + + table = utils.filter_tables(user, workspace).get(id=table_id) + + tool_helpers.update_status( + _("Listing rows in %(table_name)s ") % {"table_name": table.name} + ) + + rows_qs = table.get_model().objects.all() + rows = rows_qs[offset : offset + limit] + + response_model = create_model( + f"ResponseTable{table.id}RowWithFieldFilter", + id=(int, ...), + __base__=utils.get_create_row_model(table, field_ids=field_ids), + ) + + return { + "rows": [ + response_model.from_django_orm(row, field_ids).model_dump() + for row in rows + ], + "total": rows_qs.count(), + } + + return list_rows + + +class ListRowsToolType(AssistantToolType): + type = "list_rows" + + @classmethod + def get_tool( + cls, user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers + ) -> Callable[[Any], Any]: + return get_list_rows_tool(user, workspace, tool_helpers) + + +def get_rows_meta_tool( + user: AbstractUser, + workspace: Workspace, + tool_helpers: ToolHelpers, +) -> Callable[[int, list[dict[str, Any]]], list[Any]]: + """ + Returns a meta-tool that, given a table ID, returns an observation that says that + new tools are available and a list of tools to create, update and delete rows + in that table. + """ + + def get_rows_tools( + table_ids: list[int], + operations: list[Literal["create", "update", "delete"]], + ) -> Tuple[str, list[Callable[[Any], Any]]]: + """ + Generates row operation tools for specified tables. Required before: + create/update/delete rows. 
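`list_rows` above builds a per-table response model on the fly with pydantic's `create_model`, layering a required `id` on top of the model returned by `utils.get_create_row_model`. A standalone illustration of that mechanism, with an invented stand-in for the generated base model:

```python
from pydantic import BaseModel, create_model


class BaseRow(BaseModel):
    # Stand-in for the dynamically generated per-table row model.
    name: str = ""
    notes: str = ""


# Mirrors create_model(f"ResponseTable{table.id}RowWithFieldFilter", ...):
# the base model is extended with a required integer id.
ResponseRow = create_model("ResponseTable42Row", id=(int, ...), __base__=BaseRow)

row = ResponseRow(id=1, name="First row")
print(row.model_dump())  # contains "id", "name" and "notes"
```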
+ """ + + nonlocal user, workspace, tool_helpers + + observation = ["New tools are now available.\n"] + + new_tools = [] + tables = utils.filter_tables(user, workspace).filter(id__in=table_ids) + for table in tables: + table_tools = utils.get_table_rows_tools( + user, workspace, tool_helpers, table + ) + + observation.append(f"Table '{table.name}' (ID: {table.id}):") + + if "create" in operations: + create_rows = table_tools["create"] + new_tools.append(create_rows) + observation.append(f"- Use {create_rows.name} to create new rows.") + + if "update" in operations: + update_rows = table_tools["update"] + new_tools.append(update_rows) + observation.append( + f"- Use {update_rows.name} to update existing rows by their IDs." + ) + + if "delete" in operations: + delete_rows = table_tools["delete"] + new_tools.append(delete_rows) + observation.append( + f"- Use {delete_rows.name} to delete rows by their IDs." + ) + + observation.append("") + + return ToolsUpgradeResponse( + observation="\n".join(observation), new_tools=new_tools + ) + + return get_rows_tools + + +class GetRowsToolsToolType(AssistantToolType): + type = "get_rows_tools" + + @classmethod + def get_tool( + cls, user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers + ) -> Callable[[Any], Any]: + return get_rows_meta_tool(user, workspace, tool_helpers) + + +def get_list_views_tool( + user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers +) -> Callable[[int], list[dict[str, Any]]]: + """ + Returns a function that lists all the views in a given table the user has + access to in the current workspace. + """ + + def list_views(table_id: int) -> list[dict[str, Any]]: + """ + List views in the specified table. + + - Always call this for existing tables to avoid creating views with duplicate + names. + """ + + nonlocal user, workspace, tool_helpers + + table = utils.filter_tables(user, workspace).get(id=table_id) + + tool_helpers.update_status( + _("Listing views in %(table_name)s...") % {"table_name": table.name} + ) + + views = ViewHandler().list_views( + user, + table, + filters=False, + sortings=False, + decorations=False, + group_bys=False, + limit=100, + ) + + return { + "views": [ + view_item_registry.from_django_orm(view).model_dump() for view in views + ] + } + + return list_views + + +class ListViewsToolType(AssistantToolType): + type = "list_views" + + @classmethod + def get_tool( + cls, user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers + ) -> Callable[[Any], Any]: + return get_list_views_tool(user, workspace, tool_helpers) + + +def get_create_views_tool( + user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers +) -> Callable[[int, list[str]], list[str]]: + """ + Returns a function that creates views in a given table the user has access to + in the current workspace. + """ + + def create_views( + table_id: int, views: list[AnyViewItemCreate] + ) -> list[dict[str, Any]]: + """ + Creates views in the specified table. A default grid view showing all the rows + is created automatically when a table is created, no need to recreate it. + + - Choose the most appropriate view type for each view. + - View names must be unique within a table: check existing names when needed and + skip duplicates. 
+        """
+
+        nonlocal user, workspace, tool_helpers
+
+        if not views:
+            return []
+
+        table = utils.filter_tables(user, workspace).get(id=table_id)
+
+        created_views = []
+        with transaction.atomic():
+            for view in views:
+                tool_helpers.update_status(
+                    _("Creating %(view_type)s view %(view_name)s")
+                    % {"view_type": view.type, "view_name": view.name}
+                )
+
+                orm_view = CreateViewActionType.do(
+                    user,
+                    table,
+                    view.type,
+                    **view.to_django_orm_kwargs(table),
+                )
+
+                field_options = view.field_options_to_django_orm()
+                if field_options:
+                    UpdateViewFieldOptionsActionType.do(user, orm_view, field_options)
+                created_views.append({"id": orm_view.id, **view.model_dump()})
+
+        tool_helpers.navigate_to(
+            ViewNavigationType(
+                type="database-view",
+                database_id=table.database_id,
+                table_id=table.id,
+                view_id=created_views[0]["id"],
+                view_name=created_views[0]["name"],
+            )
+        )
+
+        return {"created_views": created_views}
+
+    return create_views
+
+
+class CreateViewsToolType(AssistantToolType):
+    type = "create_views"
+
+    @classmethod
+    def get_tool(
+        cls, user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers
+    ) -> Callable[[Any], Any]:
+        return get_create_views_tool(user, workspace, tool_helpers)
+
+
+def get_create_view_filters_tool(
+    user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers
+) -> Callable[[int, list[str]], list[str]]:
+    """
+    Returns a function that creates filters in a given view the user has access to
+    in the current workspace.
+    """
+
+    def create_view_filters(
+        view_id: int, filters: list[AnyViewFilterItemCreate]
+    ) -> list[AnyViewFilterItem]:
+        """
+        Creates filters in the specified view.
+
+        - Choose the most appropriate filter type and operator for each field.
+        - Avoid duplicates: check the existing filters on the view when needed and
+          skip filters that already exist.
+ """ + + nonlocal user, workspace, tool_helpers + + if not filters: + return [] + + orm_view = utils.get_view(user, view_id) + tool_helpers.update_status( + _("Creating filters in %(view_name)s...") % {"view_name": orm_view.name} + ) + + fields = {f.id: f for f in orm_view.table.field_set.all()} + + created_filters = [] + with transaction.atomic(): + for filter in filters: + field = fields.get(filter.field_id) + if field is None: + logger.info("Skipping filter creation due to missing field") + continue + field_type = field_type_registry.get_by_model(field.specific_class) + if field_type.type != filter.type: + logger.info("Skipping filter creation due to type mismatch") + continue + + filter_type = filter.get_django_orm_type(field) + filter_value = filter.get_django_orm_value( + field, timezone=user.profile.timezone + ) + + orm_filter = CreateViewFilterActionType.do( + user, + orm_view, + field, + filter_type, + filter_value, + filter_group_id=None, + ) + + created_filters.append({"id": orm_filter.id, **filter.model_dump()}) + + return {"created_view_filters": created_filters} + + return create_view_filters + + +class CreateViewFiltersToolType(AssistantToolType): + type = "create_view_filters" + + @classmethod + def get_tool( + cls, user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers + ) -> Callable[[Any], Any]: + return get_create_view_filters_tool(user, workspace, tool_helpers) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/__init__.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/__init__.py new file mode 100644 index 0000000000..484eb7d528 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/__init__.py @@ -0,0 +1,6 @@ +from .base import * # noqa: F401, F403 +from .database import * # noqa: F401, F403 +from .fields import * # noqa: F401, F403 +from .table import * # noqa: F401, F403 +from .view_filters import * # noqa: F401, F403 +from .views import * # noqa: F401, F403 diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py new file mode 100644 index 0000000000..b840542877 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/base.py @@ -0,0 +1,37 @@ +from datetime import date, datetime + +from pydantic import Field + +from baserow_enterprise.assistant.types import BaseModel + + +# Somehow LLMs struggle with dates +class Date(BaseModel): + year: int = Field(..., description="year (i.e. 
2025).")
+    month: int = Field(..., description="month (1-12).")
+    day: int = Field(..., description="day (1-31).")
+
+    def to_django_orm(self):
+        return date(self.year, self.month, self.day).isoformat()
+
+    @classmethod
+    def from_django_orm(cls, orm_date: date) -> "Date":
+        d = orm_date
+        return cls(year=d.year, month=d.month, day=d.day)
+
+
+class Datetime(Date):
+    hour: int = Field(..., description="hour (0-23).")
+    minute: int = Field(..., description="minute (0-59).")
+
+    def to_django_orm(self):
+        return datetime(
+            self.year, self.month, self.day, self.hour, self.minute
+        ).isoformat()
+
+    @classmethod
+    def from_django_orm(cls, orm_datetime: datetime) -> "Datetime":
+        dt = orm_datetime
+        return cls(
+            year=dt.year, month=dt.month, day=dt.day, hour=dt.hour, minute=dt.minute
+        )
diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/database.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/database.py
new file mode 100644
index 0000000000..5508ff0200
--- /dev/null
+++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/database.py
@@ -0,0 +1,15 @@
+from pydantic import Field
+
+from baserow_enterprise.assistant.types import BaseModel
+
+
+class DatabaseItemCreate(BaseModel):
+    """Base model for creating a new database (no ID)."""
+
+    name: str = Field(...)
+
+
+class DatabaseItem(DatabaseItemCreate):
+    """Model for an existing database (with ID)."""
+
+    id: int = Field(...)
diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py
new file mode 100644
index 0000000000..347e942629
--- /dev/null
+++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/fields.py
@@ -0,0 +1,494 @@
+from typing import Annotated, Literal, Type
+
+from django.db.models import Q
+
+from pydantic import Field
+
+from baserow.contrib.database.fields.models import DateField
+from baserow.contrib.database.fields.models import Field as BaserowField
+from baserow.contrib.database.fields.models import (
+    LinkRowField,
+    LongTextField,
+    MultipleSelectField,
+    NumberField,
+    RatingField,
+    SingleSelectField,
+)
+from baserow.contrib.database.fields.registries import field_type_registry
+from baserow.contrib.database.table.models import Table
+from baserow_enterprise.assistant.types import BaseModel
+
+
+class FieldItemCreate(BaseModel):
+    """Base model for creating a new field (no ID)."""
+
+    name: str = Field(...)
+    type: str = Field(...)
+
+    def to_django_orm_kwargs(self, table: Table) -> dict[str, any]:
+        return {k: v for k, v in self.model_dump().items() if k not in {"id", "type"}}
+
+
+class FieldItem(FieldItemCreate):
+    """Model for an existing field (with ID)."""
+
+    id: int = Field(...)
+
+    @classmethod
+    def from_django_orm(cls, orm_field: BaserowField) -> "FieldItem":
+        return cls(
+            id=orm_field.id,
+            name=orm_field.name,
+            type=field_type_registry.get_by_model(orm_field).type,
+        )
+
+
+# Even if the type could be inferred, certain models (e.g. openai-gpt-oss-120b) require
+# all the fields to be marked as required and can have issues with optional fields, so
+# we explicitly declare them as required, even if it seems unnecessary.
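+# For example, a long_text field is expected to be provided as
+# {"name": "Notes", "type": "long_text", "rich_text": true}, with every key present
+# even when a value only restates its documented default ("Notes" is illustrative).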
+ + +class BaseTextFieldItem(FieldItemCreate): + type: Literal["text"] = Field(..., description="Single line text field.") + + +class TextFieldItemCreate(BaseTextFieldItem): + """Model for creating a text field.""" + + +class TextFieldItem(BaseTextFieldItem, FieldItem): + """Model for an existing text field.""" + + +class BaseLongTextFieldItem(FieldItemCreate): + type: Literal["long_text"] = Field( + ..., + description="Multi-line text field. Ideal for descriptions, notes and long-form content.", + ) + rich_text: bool = Field( + ..., + description="Whether the long text field supports rich text. Default is True.", + ) + + def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: + return { + "name": self.name, + "long_text_enable_rich_text": self.rich_text, + } + + +class LongTextFieldItemCreate(BaseLongTextFieldItem): + """Model for creating a long text field.""" + + +class LongTextFieldItem(BaseLongTextFieldItem, FieldItem): + """Model for an existing long text field.""" + + @classmethod + def from_django_orm(cls, orm_field: LongTextField) -> "LongTextFieldItem": + field = orm_field.specific + return cls( + id=field.id, + name=field.name, + type="long_text", + rich_text=orm_field.long_text_enable_rich_text, + ) + + +class BaseNumberFieldItem(FieldItemCreate): + type: Literal["number"] = Field( + ..., description="Numeric field, with decimals and optional prefix/suffix." + ) + decimal_places: int = Field( + ..., description="The number of decimal places. Default is 2." + ) + suffix: str = Field( + ..., + description="An optional suffix to display after the number. Default is empty.", + ) + + def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: + return { + "name": self.name, + "number_decimal_places": self.decimal_places, + "number_suffix": self.suffix, + } + + +class NumberFieldItemCreate(BaseNumberFieldItem): + """Model for creating a number field.""" + + +class NumberFieldItem(BaseNumberFieldItem, FieldItem): + """Model for an existing number field.""" + + @classmethod + def from_django_orm(cls, orm_field: NumberField) -> "NumberFieldItem": + return cls( + id=orm_field.id, + name=orm_field.name, + type="number", + decimal_places=orm_field.number_decimal_places, + suffix=orm_field.number_suffix, + ) + + +class BaseRatingFieldItem(FieldItemCreate): + type: Literal["rating"] = Field( + ..., description="Rating field. Ideal for reviews or scores." + ) + max_value: int = Field( + ..., description="The maximum value of the rating field. Default is 5." 
+ ) + + def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: + return { + "name": self.name, + "max_value": self.max_value, + } + + +class RatingFieldItemCreate(BaseRatingFieldItem): + """Model for creating a rating field.""" + + +class RatingFieldItem(BaseRatingFieldItem, FieldItem): + """Model for an existing rating field.""" + + @classmethod + def from_django_orm(cls, orm_field: RatingField) -> "RatingFieldItem": + return cls( + id=orm_field.id, + name=orm_field.name, + type="rating", + max_value=orm_field.max_value, + ) + + +class BaseBooleanFieldItem(FieldItemCreate): + type: Literal["boolean"] = Field(..., description="Boolean field.") + + +class BooleanFieldItemCreate(BaseBooleanFieldItem): + """Model for creating a boolean field.""" + + +class BooleanFieldItem(BaseBooleanFieldItem, FieldItem): + """Model for an existing boolean field.""" + + +class BaseDateFieldItem(FieldItemCreate): + type: Literal["date"] = Field(..., description="Date or datetime field.") + include_time: bool = Field( + ..., description="Whether the date field includes time. Default is False." + ) + + def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: + return { + "name": self.name, + "date_include_time": self.include_time, + } + + +class DateFieldItemCreate(BaseDateFieldItem): + """Model for creating a date field.""" + + +class DateFieldItem(BaseDateFieldItem, FieldItem): + """Model for an existing date field.""" + + @classmethod + def from_django_orm(cls, orm_field: DateField) -> "DateFieldItem": + return cls( + id=orm_field.id, + name=orm_field.name, + type="date", + include_time=orm_field.date_include_time, + ) + + +class BaseLinkRowFieldItem(FieldItemCreate): + type: Literal["link_row"] = Field( + ..., description="Link row field. It creates relationships between tables." + ) + linked_table: str | int = Field( + ..., description="The ID or the name of the table this field links to." + ) + has_link_back: bool = Field( + ..., + description="Whether the linked table should also have a link row field back to this table. Default is True.", + ) + multiple: bool = Field( + ..., description="Whether multiple links are allowed. Default is True." + ) + + def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: + if isinstance(self.linked_table, str): + q = Q(name=self.linked_table, database=table.database) + else: + q = Q(id=self.linked_table, database=table.database) + + try: + link_row_table = Table.objects.get(q) + except Table.DoesNotExist: + raise ValueError( + f"The linked_table '{self.linked_table}' does not exist in the database." + "Ensure you provide a valid table name or ID." 
+ ) + + return { + "name": self.name, + "link_row_table": link_row_table, + "link_row_multiple_relationships": self.multiple, + "has_related_field": self.has_link_back and table != link_row_table, + } + + +class LinkRowFieldItemCreate(BaseLinkRowFieldItem): + """Model for creating a link row field.""" + + +class LinkRowFieldItem(BaseLinkRowFieldItem, FieldItem): + """Model for an existing link row field.""" + + @classmethod + def from_django_orm(cls, orm_field: LinkRowField) -> "BaseLinkRowFieldItem": + return cls( + id=orm_field.id, + name=orm_field.name, + type="link_row", + linked_table=orm_field.link_row_table_id, + multiple=orm_field.link_row_multiple_relationships, + has_link_back=orm_field.link_row_related_field_id is not None, + ) + + +OptionColor = Literal[ + "light-blue", + "light-green", + "light-cyan", + "light-orange", + "light-yellow", + "light-red", + "light-brown", + "light-purple", + "light-pink", + "light-gray", + "blue", + "green", + "cyan", + "orange", + "yellow", + "red", + "brown", + "purple", + "pink", + "gray", + "dark-blue", + "dark-green", + "dark-cyan", + "dark-orange", + "dark-yellow", + "dark-red", + "dark-brown", + "dark-purple", + "dark-pink", + "dark-gray", + "darker-blue", + "darker-green", + "darker-cyan", + "darker-orange", + "darker-yellow", + "darker-red", + "darker-brown", + "darker-purple", + "darker-pink", + "darker-gray", + "deep-dark-green", + "deep-dark-orange", +] + + +class SelectOption(BaseModel): + id: int | None = Field(..., description="The unique identifier of the option.") + value: str + color: OptionColor + + +# Define a subset of colors to use when creating fields, so we don't confuse the model +# with too many options. +OptionColorCreate = Literal[ + "blue", + "green", + "cyan", + "orange", + "yellow", + "red", + "brown", + "purple", + "pink", + "gray", +] + + +class SelectOptionCreate(BaseModel): + value: str + color: OptionColorCreate + + +class BaseSingleSelectFieldItem(FieldItemCreate): + type: Literal["single_select"] = Field( + ..., + description="Single select field. Allows users to choose one option from a list.", + ) + + +class SingleSelectFieldItemCreate(BaseSingleSelectFieldItem): + options: list[SelectOptionCreate] = Field( + description="The list of options for the field. Use appropriate colors for each option.", + ) + + def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: + return { + "name": self.name, + "select_options": [ + {"id": -i, "value": option.value, "color": option.color} + for (i, option) in enumerate(self.options, start=1) + ], + } + + +class SingleSelectFieldItem(BaseSingleSelectFieldItem, FieldItem): + options: list[SelectOption] = Field( + description="The list of options for the field.", + ) + + @classmethod + def from_django_orm( + cls, orm_field: SingleSelectField + ) -> "BaseSingleSelectFieldItem": + field = orm_field.specific + return cls( + id=field.id, + name=field.name, + type="single_select", + options=[ + SelectOption( + id=opt.id, + value=opt.value, + color=opt.color, + ) + for opt in field.select_options.all() + ], + ) + + +class BaseMultipleSelectFieldItem(FieldItemCreate): + type: Literal["multiple_select"] = Field( + ..., + description="Multiple select field. 
Allows users to choose multiple options from a list.",
+    )
+
+    def to_django_orm_kwargs(self, table: Table) -> dict[str, any]:
+        return {
+            "name": self.name,
+            "select_options": [
+                {"id": -i, "value": option.value, "color": option.color}
+                for (i, option) in enumerate(self.options, start=1)
+            ],
+        }
+
+
+class MultipleSelectFieldItemCreate(BaseMultipleSelectFieldItem):
+    options: list[SelectOptionCreate] = Field(
+        description="The list of options for the field. Use appropriate colors for each option.",
+    )
+
+
+class MultipleSelectFieldItem(BaseMultipleSelectFieldItem, FieldItem):
+    options: list[SelectOption] = Field(
+        description="The list of options for the field.",
+    )
+
+    @classmethod
+    def from_django_orm(
+        cls, orm_field: MultipleSelectField
+    ) -> "BaseMultipleSelectFieldItem":
+        field = orm_field.specific
+        return cls(
+            id=field.id,
+            name=field.name,
+            type="multiple_select",
+            options=[
+                SelectOption(
+                    id=opt.id,
+                    value=opt.value,
+                    color=opt.color,
+                )
+                for opt in field.select_options.all()
+            ],
+        )
+
+
+class BaseFileFieldItem(FieldItemCreate):
+    type: Literal["file"] = Field(..., description="File field.")
+
+
+class FileFieldItemCreate(BaseFileFieldItem):
+    pass
+
+
+class FileFieldItem(BaseFileFieldItem, FieldItem):
+    pass
+
+
+AnyFieldItemCreate = Annotated[
+    TextFieldItemCreate
+    | LongTextFieldItemCreate
+    | NumberFieldItemCreate
+    | RatingFieldItemCreate
+    | BooleanFieldItemCreate
+    | DateFieldItemCreate
+    | LinkRowFieldItemCreate
+    | SingleSelectFieldItemCreate
+    | MultipleSelectFieldItemCreate
+    | FileFieldItemCreate,
+    Field(discriminator="type"),
+]
+
+AnyFieldItem = (
+    TextFieldItem
+    | LongTextFieldItem
+    | NumberFieldItem
+    | RatingFieldItem
+    | BooleanFieldItem
+    | DateFieldItem
+    | LinkRowFieldItem
+    | SingleSelectFieldItem
+    | MultipleSelectFieldItem
+    | FileFieldItem
+    | FieldItem
+)
+
+
+class FieldItemsRegistry:
+    _registry = {
+        "text": TextFieldItem,
+        "long_text": LongTextFieldItem,
+        "number": NumberFieldItem,
+        "date": DateFieldItem,
+        "boolean": BooleanFieldItem,
+        "rating": RatingFieldItem,
+        "link_row": LinkRowFieldItem,
+        "single_select": SingleSelectFieldItem,
+        "multiple_select": MultipleSelectFieldItem,
+        "file": FileFieldItem,
+    }
+
+    def from_django_orm(self, orm_field: Type[BaserowField]) -> FieldItem:
+        field_type = field_type_registry.get_by_model(orm_field).type
+        field_class: FieldItem = self._registry.get(field_type, FieldItem)
+        return field_class.from_django_orm(orm_field)
+
+
+field_item_registry = FieldItemsRegistry()
diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/table.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/table.py
new file mode 100644
index 0000000000..ae49ebc265
--- /dev/null
+++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/table.py
@@ -0,0 +1,68 @@
+from django.db.models import Q
+
+from pydantic import Field
+
+from baserow_enterprise.assistant.types import BaseModel
+
+from .fields import AnyFieldItem, AnyFieldItemCreate
+
+
+class BaseTableItemCreate(BaseModel):
+    """Base model for creating a new table (no ID)."""
+
+    name: str = Field(..., description="The name of the table.")
+
+
+class BaseTableItem(BaseTableItemCreate):
+    """Model for an existing table (with ID)."""
+
+    id: int = Field(..., description="The unique identifier of the table.")
+
+
+class TableItemCreate(BaseTableItemCreate):
+    """Model for creating a table with fields."""
+
+    primary_field: AnyFieldItemCreate = Field(
+        ...,
+        description="The primary field of the table. Preferably a text field with a sensible name for the table's primary field.",
+    )
+    fields: list[AnyFieldItemCreate] = Field(
+        ..., description="The fields of the table."
+    )
+
+
+class TableItem(BaseTableItem):
+    """Model for an existing table with fields."""
+
+    primary_field: AnyFieldItem = Field(
+        ..., description="The primary field of the table."
+    )
+    fields: list[AnyFieldItem] = Field(..., description="The fields of the table.")
+
+
+class ListTablesFilterArg(BaseModel):
+    database_ids: list[int] | None = Field(
+        ..., description="A list of database_ids to filter. None to exclude this filter."
+    )
+    database_names: list[str] | None = Field(
+        ...,
+        description="A list of database_names to filter. None to exclude this filter.",
+    )
+    table_ids: list[int] | None = Field(
+        ..., description="A list of table ids to filter. None to exclude this filter."
+    )
+    table_names: list[str] | None = Field(
+        ..., description="A list of table names to filter. None to exclude this filter."
+    )
+
+    def to_orm_filter(self) -> Q:
+        q_filter = Q()
+        if self.database_ids:
+            q_filter &= Q(database_id__in=self.database_ids)
+        if self.database_names:
+            q_filter &= Q(database__name__in=self.database_names)
+        if self.table_ids:
+            q_filter &= Q(id__in=self.table_ids)
+        if self.table_names:
+            q_filter &= Q(name__in=self.table_names)
+        return q_filter
diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py
new file mode 100644
index 0000000000..b19f742a87
--- /dev/null
+++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/view_filters.py
@@ -0,0 +1,557 @@
+from typing import Literal
+
+from pydantic import Field
+
+from baserow_enterprise.assistant.types import Annotated, BaseModel
+
+from .base import Date
+
+
+class ViewFilterItemCreate(BaseModel):
+    """Model for creating a new view filter (no ID)."""
+
+    field_id: int = Field(...)
+    type: str = Field(...)
+    operator: str = Field(...)
+    value: str = Field(...)
+
+    def get_django_orm_type(self, field, **kwargs) -> str:
+        return self.operator
+
+    def get_django_orm_value(self, field, **kwargs) -> str:
+        return self.value
+
+
+class ViewFilterItem(ViewFilterItemCreate):
+    """Model for an existing view filter (with ID)."""
+
+    id: int = Field(..., description="The unique identifier of the view filter.")
+
+
+class TextViewFilterItemCreate(ViewFilterItemCreate):
+    type: Literal["text"] = Field(..., description="A text filter.")
+    value: str = Field(..., description="The text value to filter on.")
+
+
+class TextEqualViewFilterItemCreate(TextViewFilterItemCreate):
+    operator: Literal["equal"] = Field(
+        ..., description="Checks if the field is equal to the value."
+    )
+
+
+class TextEqualViewFilterItem(TextEqualViewFilterItemCreate, ViewFilterItem):
+    pass
+
+
+class TextNotEqualViewFilterItemCreate(TextViewFilterItemCreate):
+    operator: Literal["not_equal"] = Field(
+        ..., description="Checks if the field is not equal to the value."
+    )
+
+
+class TextNotEqualViewFilterItem(TextNotEqualViewFilterItemCreate, ViewFilterItem):
+    pass
+
+
+class TextContainsViewFilterItemCreate(TextViewFilterItemCreate):
+    operator: Literal["contains"] = Field(
+        ..., description="Checks if the field contains the value."
+ ) + + +class TextContainsViewFilterItem(TextContainsViewFilterItemCreate, ViewFilterItem): + pass + + +class TextNotContainsViewFilterItemCreate(TextViewFilterItemCreate): + operator: Literal["contains_not"] = Field( + ..., description="Checks if the field does not contain the value." + ) + + +class TextNotContainsViewFilterItem( + TextNotContainsViewFilterItemCreate, ViewFilterItem +): + pass + + +class TextEmptyViewFilterItemCreate(TextViewFilterItemCreate): + operator: Literal["is_empty"] = Field( + ..., description="Checks if the field is empty." + ) + + +class TextEmptyViewFilterItem(TextEmptyViewFilterItemCreate, ViewFilterItem): + pass + + +class TextNotEmptyViewFilterItemCreate(TextViewFilterItemCreate): + operator: Literal["is_not_empty"] = Field( + ..., description="Checks if the field is not empty." + ) + + +class TextNotEmptyViewFilterItem(TextNotEmptyViewFilterItemCreate, ViewFilterItem): + pass + + +AnyTextViewFilterItemCreate = Annotated[ + TextEqualViewFilterItemCreate + | TextNotEqualViewFilterItemCreate + | TextContainsViewFilterItemCreate + | TextNotContainsViewFilterItemCreate + | TextEmptyViewFilterItemCreate + | TextNotEmptyViewFilterItemCreate, + Field(discriminator="operator"), +] + +AnyTextViewFilterItem = Annotated[ + TextEqualViewFilterItem + | TextNotEqualViewFilterItem + | TextContainsViewFilterItem + | TextNotContainsViewFilterItem + | TextEmptyViewFilterItem + | TextNotEmptyViewFilterItem, + Field(discriminator="operator"), +] + + +class NumberViewFilterItemCreate(ViewFilterItemCreate): + type: Literal["number"] = Field(..., description="A number filter.") + value: float = Field(..., description="The number value to filter on.") + + def get_django_orm_value(self, field, **kwargs) -> str: + return str(self.value) + + +class NumberViewFilterItem(NumberViewFilterItemCreate, ViewFilterItem): + pass + + +class NumberEqualsViewFilterItemCreate(NumberViewFilterItemCreate): + operator: Literal["equal"] = Field( + ..., description="Checks if the field is equal to the value." + ) + + +class NumberEqualsViewFilterItem(NumberEqualsViewFilterItemCreate, ViewFilterItem): + pass + + +class NumberNotEqualsViewFilterItemCreate(NumberViewFilterItemCreate): + operator: Literal["not_equal"] = Field( + ..., description="Checks if the field is not equal to the value." + ) + + +class NumberNotEqualsViewFilterItem( + NumberNotEqualsViewFilterItemCreate, ViewFilterItem +): + pass + + +class NumberHigherThanViewFilterItemCreate(NumberViewFilterItemCreate): + operator: Literal["higher_than"] = Field( + ..., description="Checks if the field is higher than the value." + ) + or_equal: bool = Field( + False, + description="If true, checks if the field is higher than or equal to the value.", + ) + + +class NumberHigherThanViewFilterItem( + NumberHigherThanViewFilterItemCreate, ViewFilterItem +): + pass + + +class NumberLowerThanViewFilterItemCreate(NumberViewFilterItemCreate): + operator: Literal["lower_than"] = Field( + ..., description="Checks if the field is lower than the value." + ) + or_equal: bool = Field( + False, + description="If true, checks if the field is lower than or equal to the value.", + ) + + +class NumberLowerThanViewFilterItem( + NumberLowerThanViewFilterItemCreate, ViewFilterItem +): + pass + + +class NumberEmptyViewFilterItemCreate(NumberViewFilterItemCreate): + operator: Literal["is_empty"] = Field( + ..., description="Checks if the field is empty." 
+ ) + + +class NumberEmptyViewFilterItem(NumberEmptyViewFilterItemCreate, ViewFilterItem): + pass + + +class NumberNotEmptyViewFilterItemCreate(NumberViewFilterItemCreate): + operator: Literal["is_not_empty"] = Field( + ..., description="Checks if the field is not empty." + ) + + +class NumberNotEmptyViewFilterItem(NumberNotEmptyViewFilterItemCreate, ViewFilterItem): + pass + + +AnyNumberViewFilterItemCreate = Annotated[ + NumberEqualsViewFilterItemCreate + | NumberNotEqualsViewFilterItemCreate + | NumberHigherThanViewFilterItemCreate + | NumberLowerThanViewFilterItemCreate + | NumberEmptyViewFilterItemCreate + | NumberNotEmptyViewFilterItemCreate, + Field(discriminator="operator"), +] + +AnyNumberViewFilterItem = Annotated[ + NumberEqualsViewFilterItem + | NumberNotEqualsViewFilterItem + | NumberHigherThanViewFilterItem + | NumberLowerThanViewFilterItem + | NumberEmptyViewFilterItem + | NumberNotEmptyViewFilterItem, + Field(discriminator="operator"), +] + + +class DateViewFilterItemCreate(ViewFilterItemCreate): + type: Literal["date"] = Field(..., description="A date filter.") + value: Date | int | None = Field( + ..., + description="\n".join( + [ + "The date value to filter on.", + "Use an integer for days/weeks/months/years ago/from now.", + "Use a date object for an exact date.", + "None otherwise.", + ] + ), + ) + mode: Literal[ + "today", + "yesterday", + "tomorrow", + "this_week", + "last_week", + "next_week", + "this_month", + "last_month", + "next_month", + "this_year", + "last_year", + "next_year", + "nr_days_ago", + "nr_days_from_now", + "nr_weeks_ago", + "nr_weeks_from_now", + "nr_months_ago", + "nr_months_from_now", + "nr_years_ago", + "nr_years_from_now", + "exact_date", + ] = Field( + "exact_date", + description="The mode to use for the date filter. ALWAYS use the right mode if available. Use 'exact_date' if you have an exact date.", + ) + + def get_django_orm_value(self, field, **kwargs) -> str: + timezone = kwargs.get("timezone", "UTC") + + if isinstance(self.value, Date): + value = self.value.to_django_orm() + elif isinstance(self.value, int): + value = str(self.value) + else: + value = "" + + return f"{timezone}?{value}?{self.mode}" + + +class DateEqualsViewFilterItemCreate(DateViewFilterItemCreate): + operator: Literal["equal"] = Field( + ..., description="Checks if the field is equal to the value." + ) + + def get_django_orm_type(self, field, **kwargs) -> str: + return "date_is" + + +class DateEqualsViewFilterItem(DateEqualsViewFilterItemCreate, ViewFilterItem): + pass + + +class DateNotEqualsViewFilterItemCreate(DateViewFilterItemCreate): + operator: Literal["not_equal"] = Field( + ..., description="Checks if the field is not equal to the value." + ) + + def get_django_orm_type(self, field, **kwargs) -> str: + return "date_is_not" + + +class DateNotEqualsViewFilterItem(DateNotEqualsViewFilterItemCreate, ViewFilterItem): + pass + + +class DateAfterViewFilterItemCreate(DateViewFilterItemCreate): + operator: Literal["after"] = Field( + ..., description="Checks if the field is after the value." 
+ ) + or_equal: bool = Field( + False, + description="If true, checks if the field is after or equal to the value.", + ) + + def get_django_orm_type(self, field, **kwargs) -> str: + return "date_is_on_or_after" if self.or_equal else "date_is_after" + + +class DateAfterViewFilterItem(DateAfterViewFilterItemCreate, ViewFilterItem): + pass + + +class DateBeforeViewFilterItemCreate(DateViewFilterItemCreate): + operator: Literal["before"] = Field( + ..., description="Checks if the field is before the value." + ) + or_equal: bool = Field( + False, + description="If true, checks if the field is before or equal to the value.", + ) + + def get_django_orm_type(self, field, **kwargs) -> str: + return "date_is_on_or_before" if self.or_equal else "date_is_before" + + +class DateBeforeViewFilterItem(DateBeforeViewFilterItemCreate, ViewFilterItem): + pass + + +AnyDateViewFilterItemCreate = Annotated[ + DateEqualsViewFilterItemCreate + | DateNotEqualsViewFilterItemCreate + | DateAfterViewFilterItemCreate + | DateBeforeViewFilterItemCreate, + Field(discriminator="operator"), +] +AnyDateViewFilterItem = Annotated[ + DateEqualsViewFilterItem + | DateNotEqualsViewFilterItem + | DateAfterViewFilterItem + | DateBeforeViewFilterItem, + Field(discriminator="operator"), +] + + +class SingleSelectViewFilterItemCreate(ViewFilterItemCreate): + type: Literal["single_select"] = Field(..., description="A single select filter.") + value: list[str] = Field( + ..., description="The select option value(s) to filter on." + ) + + def get_django_orm_value(self, field, **kwargs) -> str: + values = set(v.lower() for v in self.value) + valid_option_ids = [ + option.id + for option in field.select_options.all() + if option.value.lower() in values + ] + return ",".join([str(v) for v in valid_option_ids]) + + +class SingleSelectIsAnyViewFilterItemCreate(SingleSelectViewFilterItemCreate): + operator: Literal["is_any_of"] = Field( + ..., description="Checks if the field is equal to any of the values " + ) + + def get_django_orm_type(self, field, **kwargs): + return "single_select_is_any_of" + + +class SingleSelectIsAnyViewFilterItem( + SingleSelectIsAnyViewFilterItemCreate, ViewFilterItem +): + pass + + +class SingleSelectIsNoneOfNotViewFilterItemCreate(SingleSelectViewFilterItemCreate): + operator: Literal["is_none_of"] = Field( + ..., description="Checks if the field is not equal to the value." + ) + + def get_django_orm_type(self, field, **kwargs): + return "single_select_is_none_of" + + +class SingleSelectIsNoneOfNotViewFilterItem( + SingleSelectIsNoneOfNotViewFilterItemCreate, ViewFilterItem +): + pass + + +AnySingleSelectViewFilterItemCreate = Annotated[ + SingleSelectIsAnyViewFilterItemCreate | SingleSelectIsNoneOfNotViewFilterItemCreate, + Field(discriminator="operator"), +] + +AnySingleSelectViewFilterItem = Annotated[ + SingleSelectIsAnyViewFilterItem | SingleSelectIsNoneOfNotViewFilterItem, + Field(discriminator="operator"), +] + + +class MultipleSelectViewFilterItemCreate(ViewFilterItemCreate): + type: Literal["multiple_select"] = Field( + ..., description="A multiple select filter." + ) + value: list[str] = Field( + ..., description="The select option value(s) to filter on." 
+ ) + + def get_django_orm_value(self, field, **kwargs) -> str: + values = set(v.lower() for v in self.value) + valid_option_ids = [ + option.id + for option in field.select_options.all() + if option.value.lower() in values + ] + return ",".join([str(v) for v in valid_option_ids]) + + +class MultipleSelectIsAnyViewFilterItemCreate(MultipleSelectViewFilterItemCreate): + operator: Literal["is_any_of"] = Field( + ..., description="Checks if the field is equal to any of the values " + ) + + def get_django_orm_type(self, field, **kwargs): + return "multiple_select_has" + + +class MultipleSelectIsAnyViewFilterItem( + MultipleSelectIsAnyViewFilterItemCreate, ViewFilterItem +): + pass + + +class MultipleSelectIsNoneOfNotViewFilterItemCreate(MultipleSelectViewFilterItemCreate): + operator: Literal["is_none_of"] = Field( + ..., description="Checks if the field is not equal to the value." + ) + + def get_django_orm_type(self, field, **kwargs): + return "multiple_select_has_not" + + +class MultipleSelectIsNoneOfNotViewFilterItem( + MultipleSelectIsNoneOfNotViewFilterItemCreate, ViewFilterItem +): + pass + + +AnyMultipleSelectViewFilterItemCreate = Annotated[ + MultipleSelectIsAnyViewFilterItemCreate + | MultipleSelectIsNoneOfNotViewFilterItemCreate, + Field(discriminator="operator"), +] + +AnyMultipleSelectViewFilterItem = Annotated[ + MultipleSelectIsAnyViewFilterItem | MultipleSelectIsNoneOfNotViewFilterItem, + Field(discriminator="operator"), +] + + +class LinkRowViewFilterItemCreate(ViewFilterItemCreate): + type: Literal["link_row"] = Field(..., description="A link row filter.") + value: int = Field(..., description="The linked record ID to filter on.") + + def get_django_orm_value(self, field, **kwargs) -> str: + return str(self.value) + + +class LinkRowHasViewFilterItemCreate(LinkRowViewFilterItemCreate): + operator: Literal["has"] = Field( + ..., description="Checks if the field has the linked record." + ) + + def get_django_orm_type(self, field, **kwargs): + return "link_row_has" + + +class LinkRowHasViewFilterItem(LinkRowHasViewFilterItemCreate, ViewFilterItem): + pass + + +class LinkRowHasNotViewFilterItemCreate(LinkRowViewFilterItemCreate): + operator: Literal["has_not"] = Field( + ..., description="Checks if the field does not have the linked record." 
+ ) + + def get_django_orm_type(self, field, **kwargs): + return "link_row_has_not" + + +class LinkRowHasNotViewFilterItem(LinkRowHasNotViewFilterItemCreate, ViewFilterItem): + pass + + +AnyLinkRowViewFilterItemCreate = Annotated[ + LinkRowHasViewFilterItemCreate | LinkRowHasNotViewFilterItemCreate, + Field(discriminator="operator"), +] + +AnyLinkRowViewFilterItem = Annotated[ + LinkRowHasViewFilterItem | LinkRowHasNotViewFilterItem, + Field(discriminator="operator"), +] + + +class BooleanViewFilterItemCreate(ViewFilterItemCreate): + type: Literal["boolean"] = Field(..., description="A boolean filter.") + value: bool = Field(..., description="The boolean value to filter on.") + + def get_django_orm_value(self, field, **kwargs) -> str: + return "1" if self.value else "0" + + +class BooleanIsViewFilterItemCreate(BooleanViewFilterItemCreate): + operator: Literal["is"] = Field(..., description="Checks if the field is true.") + value: bool = Field(..., description="The boolean value to filter on.") + + def get_django_orm_type(self, field, **kwargs) -> str: + return "boolean" + + +class BooleanIsTrueViewFilterItem(BooleanIsViewFilterItemCreate, ViewFilterItem): + pass + + +AnyViewFilterItemCreate = Annotated[ + AnyTextViewFilterItemCreate + | AnyNumberViewFilterItemCreate + | AnyDateViewFilterItemCreate + | AnySingleSelectViewFilterItemCreate + | AnyLinkRowViewFilterItemCreate + | BooleanViewFilterItemCreate + | MultipleSelectViewFilterItemCreate, + Field(discriminator="type"), +] + +AnyViewFilterItem = Annotated[ + AnyTextViewFilterItem + | AnyNumberViewFilterItem + | AnyDateViewFilterItem + | AnySingleSelectViewFilterItem + | AnyLinkRowViewFilterItem + | BooleanIsTrueViewFilterItem + | MultipleSelectIsAnyViewFilterItem, + Field(discriminator="type"), +] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py new file mode 100644 index 0000000000..458bfcd70a --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/types/views.py @@ -0,0 +1,370 @@ +from typing import Annotated, Literal, Type + +from baserow_premium.permission_manager import Table +from baserow_premium.views.models import CalendarView, KanbanView, TimelineView +from pydantic import Field + +from baserow.contrib.database.fields.models import ( + DateField, + FileField, + SingleSelectField, +) +from baserow.contrib.database.views.models import FormView, GalleryView, GridView +from baserow.contrib.database.views.models import View as BaserowView +from baserow.contrib.database.views.registries import view_type_registry +from baserow_enterprise.assistant.types import BaseModel + + +class ViewItemCreate(BaseModel): + name: str = Field( + ..., + description="A sensible name for the view (i.e. 'All tasks', 'Completed tasks', etc.).", + ) + public: bool = Field( + ..., description="Whether the view is publicly accessible. Default is False." + ) + + def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: + return { + "name": self.name, + "public": self.public, + } + + def field_options_to_django_orm(self) -> dict[str, any]: + return {} + + +class ViewItem(BaseModel): + id: int = Field(...) + + @classmethod + def from_django_orm(cls, orm_view: Type[BaserowView]) -> "ViewItem": + return cls( + id=orm_view.id, + name=orm_view.name, + public=orm_view.public, + ) + + +class GridFieldOption(BaseModel): + field_id: int = Field(...) 
+ width: int = Field( + ..., + description="The width of the field in the grid view. Default is 200.", + ) + hidden: bool = Field( + ..., + description="Whether the field is hidden in the grid view. Default is False.", + ) + + +class GridViewItemCreate(ViewItemCreate): + type: Literal["grid"] = Field(..., description="A grid view.") + row_height: Literal["small", "medium", "large"] = Field( + ..., + description=( + "The height of the rows in the view. Can be 'small', 'medium' or 'large'. Default is 'small'." + ), + ) + + def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: + return { + **super().to_django_orm_kwargs(table), + "row_height": self.row_height, + } + + +class GridViewItem(GridViewItemCreate, ViewItem): + @classmethod + def from_django_orm(cls, orm_view: GridView) -> "GridViewItem": + return cls( + id=orm_view.id, + name=orm_view.name, + type="grid", + row_height="small", + public=orm_view.public, + ) + + +class KanbanViewItemCreate(ViewItemCreate): + type: Literal["kanban"] = Field(..., description="A kanban view.") + column_field_id: int | None = Field( + ..., + description="The ID of the field to use for the kanban columns. Must be a single select field. None if no single select field is available.", + ) + + def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: + model = table.get_model() + column_field = model.get_field_object_by_id(self.column_field_id)["field"] + if not isinstance(column_field, SingleSelectField): + raise ValueError("The column_field_id must be a Single Select field.") + + return { + **super().to_django_orm_kwargs(table), + "single_select_field": column_field, + } + + +class KanbanViewItem(KanbanViewItemCreate, ViewItem): + @classmethod + def from_django_orm(cls, orm_view: KanbanView) -> "KanbanViewItem": + return cls( + id=orm_view.id, + name=orm_view.name, + type="kanban", + column_field_id=orm_view.single_select_field_id, + public=orm_view.public, + ) + + +class CalendarViewItemCreate(ViewItemCreate): + type: Literal["calendar"] = Field(..., description="A calendar view.") + date_field_id: int | None = Field( + ..., + description="The ID of the field to use for the calendar dates. Must be a date field. None if no date field is available.", + ) + + def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: + model = table.get_model() + date_field = model.get_field_object_by_id(self.date_field_id)["field"] + if not isinstance(date_field, DateField): + raise ValueError("The date_field_id must be a Date field.") + + return { + **super().to_django_orm_kwargs(table), + "date_field": date_field, + } + + +class CalendarViewItem(CalendarViewItemCreate, ViewItem): + @classmethod + def from_django_orm(cls, orm_view: CalendarView) -> "CalendarViewItem": + return cls( + id=orm_view.id, + name=orm_view.name, + type="calendar", + date_field_id=orm_view.date_field_id, + public=orm_view.public, + ) + + +class BaseGalleryViewItem(ViewItemCreate): + type: Literal["gallery"] = Field(..., description="A gallery view.") + cover_field_id: int | None = Field( + None, + description=( + "The ID of the field to use for the gallery cover image. Must be a file field. None if no file field is available." 
+ ), + ) + + +class GalleryViewItemCreate(BaseGalleryViewItem): + def to_django_orm_kwargs(self, table): + model = table.get_model() + cover_field = model.get_field_object_by_id(self.cover_field_id)["field"] + if not isinstance(cover_field, FileField): + raise ValueError("The cover_field_id must be a File field.") + + return { + **super().to_django_orm_kwargs(table), + "card_cover_image_field_id": self.cover_field_id, + } + + +class GalleryViewItem(BaseGalleryViewItem, ViewItem): + @classmethod + def from_django_orm(cls, orm_view: GalleryView) -> "GalleryViewItem": + return cls( + id=orm_view.id, + name=orm_view.name, + type="gallery", + cover_field_id=orm_view.card_cover_image_field_id, + public=orm_view.public, + ) + + +class BaseTimelineViewItem(ViewItemCreate): + type: Literal["timeline"] = Field(..., description="A timeline view.") + start_date_field_id: int | None = Field( + ..., + description="The ID of the field to use for the timeline dates. Must be a date field. None if no date field is available.", + ) + end_date_field_id: int | None = Field( + ..., + description=( + "The ID of the field to use for the timeline end dates. Must be a date field. None if no date field is available." + ), + ) + + +class TimelineViewItemCreate(BaseTimelineViewItem): + def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: + model = table.get_model() + start_field = model.get_field_object_by_id(self.start_date_field_id)["field"] + end_field = model.get_field_object_by_id(self.end_date_field_id)["field"] + if ( + not isinstance(start_field, DateField) + or not isinstance(end_field, DateField) + or start_field.id == end_field.id + or start_field.date_include_time != end_field.date_include_time + ): + raise ValueError( + "Invalid timeline configuration: both start and end fields must be Date fields " + "and they must have the same include_time setting (either both include time or " + "both are date-only). " + ) + + return { + **super().to_django_orm_kwargs(table), + "start_date_field": start_field, + "end_date_field": end_field, + } + + +class TimelineViewItem(BaseTimelineViewItem, ViewItem): + @classmethod + def from_django_orm(cls, orm_view: TimelineView) -> "TimelineViewItem": + return cls( + id=orm_view.id, + name=orm_view.name, + type="timeline", + start_date_field_id=orm_view.start_date_field_id, + end_date_field_id=orm_view.end_date_field_id, + public=orm_view.public, + ) + + +class FormFieldOption(BaseModel): + field_id: int = Field(..., description="The ID of the field.") + name: str = Field(..., description="The name to show for the field in the form.") + description: str = Field( + ..., description="The description to show for the field in the form." + ) + required: bool = Field( + ..., description="Whether the field is required in the form. Default is True." + ) + order: int = Field(..., description="The order of the field in the form.") + + +class BaseFormViewItem(ViewItemCreate): + type: Literal["form"] = Field(..., description="A form view.") + title: str = Field(..., description="The title of the form. Can be empty.") + description: str = Field( + ..., description="The description of the form. Can be empty." + ) + submit_button_label: str = Field( + ..., description="The label of the submit button. Default is 'Submit'." + ) + receive_notification_on_submit: bool = Field( + ..., + description=( + "Whether to receive an email notification when the form is submitted. Default is False." 
+ ), + ) + submit_action: Literal["MESSAGE", "REDIRECT"] = Field( + ..., + description="The action to perform when the form is submitted. Default is 'MESSAGE'.", + ) + submit_action_message: str = Field( + ..., + description=( + "The message to display when the form is submitted and the action is 'MESSAGE'. Default is empty." + ), + ) + submit_action_redirect_url: str = Field( + ..., + description=( + "The URL to redirect to when the form is submitted and the action is 'REDIRECT'. Default is empty." + ), + ) + + field_options: list[FormFieldOption] = Field( + ..., + description=( + "The list of fields to show in the form, along with their options. The fields must be part of the table." + ), + ) + + +class FormViewItemCreate(BaseFormViewItem): + def to_django_orm_kwargs(self, table: Table) -> dict[str, any]: + return { + **super().to_django_orm_kwargs(table), + "title": self.title, + "description": self.description, + } + + def field_options_to_django_orm(self): + return { + fo.field_id: { + "enabled": True, + "name": fo.name, + "description": fo.description, + "required": fo.required, + "order": fo.order, + } + for fo in self.field_options + } + + +class FormViewItem(FormViewItemCreate, ViewItem): + @classmethod + def from_django_orm(cls, orm_view: FormView) -> "FormViewItem": + return cls( + id=orm_view.id, + name=orm_view.name, + type="form", + public=orm_view.public, + title=orm_view.title, + description=orm_view.description, + field_options=[ + FormFieldOption( + field_id=fo.field_id, + name=fo.name, + description=fo.description, + required=fo.required, + order=fo.order, + ) + for fo in orm_view.active_field_options.all() + ], + ) + + +AnyViewItemCreate = Annotated[ + GridViewItemCreate + | KanbanViewItemCreate + | CalendarViewItemCreate + | GalleryViewItemCreate + | TimelineViewItemCreate + | FormViewItemCreate, + Field(discriminator="type"), +] + +AnyViewItem = Annotated[ + GridViewItem + | KanbanViewItem + | CalendarViewItem + | GalleryViewItem + | TimelineViewItem + | FormViewItem, + Field(discriminator="type"), +] + + +class ViewItemsRegistry: + _registry = { + "grid": GridViewItem, + "kanban": KanbanViewItem, + "calendar": CalendarViewItem, + "gallery": GalleryViewItem, + "timeline": TimelineViewItem, + "form": FormViewItem, + } + + def from_django_orm(self, orm_view: Type[BaserowView]) -> ViewItem: + view_type = view_type_registry.get_by_model(orm_view).type + view_class: ViewItem = self._registry.get(view_type, ViewItem) + return view_class.from_django_orm(orm_view) + + +view_item_registry = ViewItemsRegistry() diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py new file mode 100644 index 0000000000..b864e00ae6 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py @@ -0,0 +1,514 @@ +from dataclasses import dataclass +from itertools import groupby +from typing import Any, Callable, Literal, Type, Union + +from django.core.exceptions import ValidationError +from django.db import transaction +from django.db.models import Q +from django.utils.translation import gettext as _ + +import dspy +from dspy.adapters.types.tool import _resolve_json_schema_reference +from pydantic import ConfigDict, Field, create_model + +from baserow.contrib.database.fields.actions import CreateFieldActionType +from baserow.contrib.database.fields.field_types import LinkRowFieldType +from baserow.contrib.database.fields.handler import FieldHandler +from 
baserow.contrib.database.fields.models import SelectOption as OrmSelectOption +from baserow.contrib.database.fields.registries import field_type_registry +from baserow.contrib.database.rows.actions import ( + CreateRowsActionType, + DeleteRowsActionType, + UpdateRowsActionType, +) +from baserow.contrib.database.table.handler import TableHandler +from baserow.contrib.database.table.models import ( + FieldObject, + GeneratedTableModel, + Table, +) +from baserow.contrib.database.views.handler import ViewHandler +from baserow.core.db import specific_iterator +from baserow.core.models import Workspace +from baserow_enterprise.assistant.tools.database.types.table import ( + BaseTableItem, + TableItem, +) +from baserow_enterprise.assistant.tools.registries import ToolHelpers + +from .types import ( + AnyFieldItem, + AnyFieldItemCreate, + BaseModel, + Date, + Datetime, + field_item_registry, +) + +NoChange = Literal["__NO_CHANGE__"] + + +def filter_tables(user, workspace: Workspace) -> list[Table]: + return TableHandler().list_workspace_tables(user, workspace) + + +def list_tables(user, workspace: Workspace, database_id: int) -> list[BaseTableItem]: + tables_qs = filter_tables(user, workspace).filter(database_id=database_id) + + return [BaseTableItem(id=table.id, name=table.name) for table in tables_qs] + + +def get_tables_schema( + tables: list[Table], + full_schema: bool = False, +) -> list[TableItem]: + """Returns the schema of the specified tables.""" + + q = Q(table__in=tables) + if not full_schema: # Only the primary fields and relationships + q &= Q(linkrowfield__isnull=False) | Q(primary=True) + + base_field_queryset = FieldHandler().get_base_fields_queryset() + fields = specific_iterator( + base_field_queryset.filter(q).order_by("table_id", "order"), + per_content_type_queryset_hook=( + lambda field, queryset: field_type_registry.get_by_model( + field + ).enhance_field_queryset(queryset, field) + ), + ) + + table_items = [] + for table_id, fields_in_table in groupby(fields, lambda f: f.table_id): + fields_in_table = list(fields_in_table) + table = next(t for t in tables if t.id == table_id) + primary_field = next(f for f in fields if f.primary) + primary_field_item = field_item_registry.from_django_orm(primary_field) + + table_items.append( + TableItem( + id=table.id, + name=table.name, + primary_field=primary_field_item, + fields=[ + field_item_registry.from_django_orm(f) + for f in fields_in_table + if f.id != primary_field.id + ], + ) + ) + + # Make sure the order is the same as the input + tables = list(tables) + table_items.sort( + key=lambda t: tables.index(next(tb for tb in tables if tb.id == t.id)) + ) + + return table_items + + +def create_fields( + user, + table: Table, + field_items: list[AnyFieldItemCreate], + tool_helpers: ToolHelpers, +) -> list[AnyFieldItem]: + created_fields = [] + for field_item in field_items: + tool_helpers.update_status( + _("Creating field %(field_name)s...") % {"field_name": field_item.name} + ) + + new_field = CreateFieldActionType.do( + user, + table, + field_item.type, + **field_item.to_django_orm_kwargs(table), + ) + created_fields.append(field_item_registry.from_django_orm(new_field)) + return created_fields + + +@dataclass +class FieldDefinition: + type: Type | None = None + field_def: Any | None = None + to_django_orm: Callable[[Any], Any] | None = None + from_django_orm: Callable[[Any], Any] | None = None + + +def _get_pydantic_field_definition( + field_object: FieldObject, +) -> FieldDefinition: + """ + Returns the Pydantic field type and 
definition for the given field object. + """ + + orm_field = field_object["field"] + orm_field_type = field_object["type"] + + match orm_field_type.type: + case "text": + return FieldDefinition( + str | None, + Field(..., description="Single-line text", title=orm_field.name), + lambda v: v if v is not None else "", + lambda v: v if v is not None else "", + ) + + case "long_text": + return FieldDefinition( + str | None, + Field(..., description="Multi-line text", title=orm_field.name), + lambda v: v if v is not None else "", + lambda v: v if v is not None else "", + ) + case "number": + return FieldDefinition( + float | None, + Field(..., description="Number or None", title=orm_field.name), + ) + case "boolean": + return FieldDefinition( + bool, Field(..., description="Boolean", title=orm_field.name) + ) + case "date": + if orm_field.date_include_time: + return FieldDefinition( + Datetime | None, + Field(..., description="Datetime or None", title=orm_field.name), + lambda v: v.to_django_orm() if v is not None else None, + lambda v: Datetime.from_django_orm(v) if v is not None else None, + ) + else: + return FieldDefinition( + Date | None, + Field(..., description="Date or None", title=orm_field.name), + lambda v: v.to_django_orm() if v is not None else None, + lambda v: Date.from_django_orm(v) if v is not None else None, + ) + case "single_select": + choices = [option.value for option in orm_field.select_options.all()] + + return FieldDefinition( + Literal[*choices] | None, + Field( + ..., + description=f"One of: {', '.join(choices)} or None", + title=orm_field.name, + ), + lambda v: v if v in choices else None, + lambda v: v.value if isinstance(v, OrmSelectOption) else v, + ) + case "multiple_select": + choices = [option.value for option in orm_field.select_options.all()] + + return FieldDefinition( + list[Literal[*choices]], + Field( + ..., + description=f"List of any of: {', '.join(choices)} or empty list", + title=orm_field.name, + ), + lambda v: [opt for opt in v if opt in choices], + lambda v: [opt.value for opt in v.all()] if v is not None else None, + ) + case "link_row": + linked_model = orm_field.link_row_table.get_model() + linked_primary_key = linked_model.get_primary_field() + + # If there's no primary key, we can't safely work with this field + if linked_primary_key is None: + return FieldDefinition() # Unsupported field type + + # Avoid null or empty values + linked_pk = linked_primary_key.db_column + linked_values = list( + linked_model.objects.exclude( + Q(**{f"{linked_pk}__isnull": True}) + | Q(**{f"{linked_pk}__exact": ""}) + ).values_list(linked_pk, flat=True)[:10] + ) + examples = f"Examples: {', '.join([str(v) for v in linked_values])}" + + def to_django_orm(value): + if isinstance(value, str) or isinstance(value, int): + value = [value] + if value is not None: + try: + return LinkRowFieldType().prepare_value_for_db(orm_field, value) + except ValidationError: + pass + return [] + + def from_django_orm(value): + values = [str(v) for v in value.all()] + if orm_field.link_row_multiple_relationships: + return values + else: + return values[0] if values else None + + # TODO: verify this can work with every possible primary field type + if orm_field.link_row_multiple_relationships: + desc = "List of values (as strings) or IDs (as integers) from the linked table or empty list." + field_type = list[str | int] | None + else: + desc = "Single value (as string) or ID (as integer) from the linked table or empty list." 
+ field_type = str | int | None + if examples: + desc += " " + examples + return FieldDefinition( + field_type, + Field(None, description=desc, title=orm_field.name), + to_django_orm, + from_django_orm, + ) + + case _: + return FieldDefinition() # Unsupported field type + + +def get_create_row_model(table: Table, field_ids: list[int] | None = None) -> BaseModel: + """ + Dynamically creates a Pydantic model for the given table based on its fields, to be + used for row creation and validation. + """ + + model_name = f"Table{table.id}Row" + + field_definitions = {} + field_conversions = {} + + table_model = table.get_model() + for field_object in table_model.get_field_objects(): + field_definition = _get_pydantic_field_definition(field_object) + if field_definition.type is None: + continue # Skip unsupported field types + if field_ids is not None and field_object["field"].id not in field_ids: + continue # Skip fields not in the specified list + + field = field_object["field"] + field_definitions[field.name] = ( + field_definition.type, + field_definition.field_def, + ) + field_conversions[field.name] = ( + field.db_column, + field_definition.to_django_orm, + field_definition.from_django_orm, + ) + + class TableRowModel(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + + def to_django_orm(self) -> dict[str, Any]: + orm_data = {} + for key, value in self.__dict__.items(): + if key == "id": + orm_data["id"] = value + continue + + if key not in field_conversions or value == "__NO_CHANGE__": + continue + + orm_key, to_django_orm, _ = field_conversions[key] + if to_django_orm: + orm_data[orm_key] = to_django_orm(value) + else: + orm_data[orm_key] = value + return orm_data + + @classmethod + def from_django_orm( + cls, orm_row: GeneratedTableModel, field_ids: list[int] | None = None + ) -> "TableRowModel": + init_data = {"id": orm_row.id} + for field_object in orm_row.get_field_objects(): + field = field_object["field"] + if field.name not in field_conversions: + continue + if field_ids is not None and field.id not in field_ids: + continue + db_column, _, from_django_orm = field_conversions[field.name] + value = getattr(orm_row, db_column) + if from_django_orm: + init_data[field.name] = from_django_orm(value) + else: + init_data[field.name] = value + return cls(**init_data) + + return create_model( + model_name, + __module__=__name__, + __base__=TableRowModel, + **field_definitions, + ) + + +def get_update_row_model(table) -> BaseModel: + """Creates an update model where all fields can be NoChange.""" + + create_model_class = get_create_row_model(table) + + # Build update fields - all fields become Union[OriginalType, NoChange] + update_fields = {} + + for field_name, field_info in create_model_class.model_fields.items(): + original_type = field_info.annotation + + update_fields[field_name] = ( + Union[NoChange, original_type], + Field( + ..., + description=f"Use '__NO_CHANGE__' to keep current value. 
To update, use a {field_info.description}", + ), + ) + + update_fields["id"] = (int, Field(..., description="The ID of the row to update")) + + # Create the update model + UpdateRowModel = create_model( + f"UpdateTable{table.id}Row", + __base__=create_model_class, + **update_fields, + ) + + return UpdateRowModel + + +def get_view(user, view_id: int): + return ViewHandler().get_view_as_user(user, view_id) + + +def get_table_rows_tools( + user, workspace: Workspace, tool_helpers: ToolHelpers, table: Table +): + row_model_for_create = get_create_row_model(table) + row_model_for_update = get_update_row_model(table) + row_model_for_response = create_model( + f"ResponseTable{table.id}Row", + id=(int, ...), + __base__=row_model_for_create, + ) + + def _create_rows( + rows: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + """ + Create new rows in the specified table. + """ + + nonlocal user, workspace, tool_helpers, row_model_for_create, row_model_for_response + + if not rows: + return [] + + tool_helpers.update_status( + _("Creating rows in %(table_name)s ") % {"table_name": table.name} + ) + + with transaction.atomic(): + orm_rows = CreateRowsActionType.do( + user, + table, + [row_model_for_create(**row).to_django_orm() for row in rows], + ) + + return {"created_row_ids": [r.id for r in orm_rows]} + + create_row_model_schema = _resolve_json_schema_reference( + row_model_for_create.model_json_schema() + ) + create_rows_tool = dspy.Tool( + func=_create_rows, + name=f"create_rows_in_table_{table.id}", + desc=f"Creates new rows in the table {table.name} (ID: {table.id}). Max 20 rows at a time.", + args={ + "rows": { + "items": create_row_model_schema, + "type": "array", + "maxItems": 20, + } + }, + ) + + def _update_rows( + rows: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + """ + Update existing rows in the specified table. + """ + + nonlocal user, workspace, tool_helpers, row_model_for_update + + if not rows: + return [] + + tool_helpers.update_status( + _("Updating rows in %(table_name)s ") % {"table_name": table.name} + ) + + with transaction.atomic(): + orm_rows = UpdateRowsActionType.do( + user, + table, + [row_model_for_update(**row).to_django_orm() for row in rows], + ).updated_rows + + return {"updated_row_ids": [r.id for r in orm_rows]} + + update_row_model_schema = _resolve_json_schema_reference( + row_model_for_update.model_json_schema() + ) + update_rows_tool = dspy.Tool( + func=_update_rows, + name=f"update_rows_in_table_{table.id}_by_row_ids", + desc=f"Updates existing rows in the table {table.name} (ID: {table.id}), identified by their row IDs. Max 20 at a time.", + args={ + "rows": { + "items": update_row_model_schema, + "type": "array", + "maxItems": 20, + } + }, + ) + + def _delete_rows(row_ids: list[int]) -> str: + """ + Delete rows in the specified table. + """ + + nonlocal user, workspace, tool_helpers + + if not row_ids: + return + + tool_helpers.update_status( + _("Deleting rows in %(table_name)s ") % {"table_name": table.name} + ) + + with transaction.atomic(): + DeleteRowsActionType.do(user, table, row_ids) + + return {"deleted_row_ids": row_ids} + + delete_rows_tool = dspy.Tool( + func=_delete_rows, + name=f"delete_rows_in_table_{table.id}_by_row_ids", + desc=f"Deletes rows in the table {table.name} (ID: {table.id}). 
Max 20 at a time.", + args={ + "row_ids": { + "items": {"type": "integer"}, + "type": "array", + "maxItems": 20, + } + }, + ) + + return { + "create": create_rows_tool, + "update": update_rows_tool, + "delete": delete_rows_tool, + } diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/__init__.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tools.py new file mode 100644 index 0000000000..975c0d67ba --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/tools.py @@ -0,0 +1,47 @@ +from typing import Callable + +from django.contrib.auth.models import AbstractUser +from django.utils.translation import gettext as _ + +from baserow.core.models import Workspace +from baserow_enterprise.assistant.tools.registries import AssistantToolType, ToolHelpers + +from .types import AnyNavigationRequestType + + +def get_navigation_tool( + user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers +) -> Callable[[AnyNavigationRequestType], str]: + """ + Returns a function that provides navigation instructions to the user based on + their current workspace context. + """ + + def navigate(request: AnyNavigationRequestType) -> str: + """ + Navigate within the workspace. + + Use when: + - the user asks to open, go, to be brought to something + - the user asks to see something from their workspace + """ + + nonlocal user, workspace + + location = request.to_location(user, workspace, request) + + tool_helpers.update_status( + _("Navigating to %(location)s...") + % {"location": location.to_localized_string()} + ) + return tool_helpers.navigate_to(location) + + return navigate + + +class NavigationToolType(AssistantToolType): + type = "navigation" + + @classmethod + def get_tool(cls, user, workspace, tool_helpers): + return get_navigation_tool(user, workspace, tool_helpers) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/types.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/types.py new file mode 100644 index 0000000000..f0421eb7bf --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/types.py @@ -0,0 +1,72 @@ +from typing import Annotated, Literal + +from django.contrib.auth.models import AbstractUser + +from pydantic import Field + +from baserow.core.models import Workspace +from baserow_enterprise.assistant.tools.database.utils import filter_tables +from baserow_enterprise.assistant.types import ( + BaseModel, + TableNavigationType, + WorkspaceNavigationType, +) + + +class NavigationRequestType(BaseModel): + type: str + + @classmethod + def to_location( + cls, + user: AbstractUser, + workspace: Workspace, + request: "NavigationRequestType", + ) -> "LocationType": + raise NotImplementedError() + + +class LocationType(BaseModel): + type: str + + +class TableNavigationRequestType(NavigationRequestType): + type: Literal["database-table"] = Field(..., description="A specific table") + table_id: int = Field(..., description="The table to open") + + @classmethod + def to_location( + cls, + user: AbstractUser, + workspace: Workspace, + request: "TableNavigationRequestType", + ) -> TableNavigationType: + table = filter_tables(user, workspace).get(id=request.table_id) + + return TableNavigationType( + 
type="database-table", + database_id=table.database_id, + table_id=request.table_id, + table_name=table.name, + ) + + +class WorkspaceNavigationRequestType(NavigationRequestType): + type: Literal["workspace"] = Field( + ..., description="The home page of the workspace" + ) + + @classmethod + def to_location( + cls, + user: AbstractUser, + workspace: Workspace, + request: "WorkspaceNavigationRequestType", + ) -> WorkspaceNavigationType: + return WorkspaceNavigationType(type="workspace", id=workspace.id) + + +AnyNavigationRequestType = Annotated[ + TableNavigationRequestType | WorkspaceNavigationRequestType, + Field(discriminator="type"), +] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py new file mode 100644 index 0000000000..9e42d7629a --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py @@ -0,0 +1,19 @@ +from dspy.dsp.utils.settings import settings +from dspy.streaming.messages import sync_send_to_stream + +from baserow_enterprise.assistant.types import AiNavigationMessage, AnyNavigationType + + +def unsafe_navigate_to(location: AnyNavigationType) -> str: + """ + Navigate to a specific table or view without any safety checks. + Make sure all the IDs provided are valid and can be accessed by the user before + calling this function. + + :param navigation_type: The type of navigation to perform. + """ + + stream = settings.send_stream + if stream is not None: + sync_send_to_stream(stream, AiNavigationMessage(location=location)) + return "Navigated successfully." diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py index 3e4de88c7c..83f036a596 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py @@ -1,6 +1,11 @@ -from typing import Any, Callable +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable from django.contrib.auth.models import AbstractUser +from django.utils import translation + +from dspy.dsp.utils.settings import settings +from dspy.streaming.messages import sync_send_to_stream from baserow.core.exceptions import ( InstanceTypeAlreadyRegistered, @@ -8,12 +13,26 @@ ) from baserow.core.models import Workspace from baserow.core.registries import Instance, Registry +from baserow_enterprise.assistant.tools.navigation.utils import unsafe_navigate_to +from baserow_enterprise.assistant.types import AiThinkingMessage + +if TYPE_CHECKING: + from baserow_enterprise.assistant.tools.navigation.types import ( + AnyNavigationRequestType, + ) + + +@dataclass +class ToolHelpers: + update_status: Callable[[str], None] + navigate_to: Callable[["AnyNavigationRequestType"], str] class AssistantToolType(Instance): - def can_use( - self, user: AbstractUser, workspace: Workspace, *args, **kwargs - ) -> bool: + name: str = "" + + @classmethod + def can_use(cls, user: AbstractUser, workspace: Workspace, *args, **kwargs) -> bool: """ Returns whether or not the given user can use this tool in the given workspace. 
@@ -24,8 +43,9 @@ def can_use( return True + @classmethod def on_tool_start( - self, + cls, call_id: str, instance: Any, inputs: dict[str, Any], @@ -40,8 +60,9 @@ def on_tool_start( pass + @classmethod def on_tool_end( - self, + cls, call_id: str, instance: Any, inputs: dict[str, Any], @@ -62,9 +83,17 @@ def on_tool_end( pass - def get_tool(self) -> Callable[[Any], Any]: + @classmethod + def get_tool( + cls, user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers + ) -> Callable[[Any], Any]: """ Returns the actual tool function to be called to pass to the dspy react agent. + + :param user: The user that will be using the tool. + :param workspace: The workspace the user is currently in. + :param tool_helpers: A dataclass containing helper functions that can be used by + the tool function. """ raise NotImplementedError("Subclasses must implement this method.") @@ -85,12 +114,32 @@ class AssistantToolRegistry(Registry[AssistantToolType]): already_registered_exception_class = AssistantToolAlreadyRegistered def list_all_usable_tools( - self, user: AbstractUser, workspace: Workspace, *args, **kwargs + self, user: AbstractUser, workspace: Workspace ) -> list[AssistantToolType]: + def update_status_localized(status: str): + """ + Sends a localized message to the frontend to update the assistant status. + + :param status: The status message to send. + """ + + nonlocal user + + with translation.override(user.profile.language): + stream = settings.send_stream + + if stream is not None: + sync_send_to_stream(stream, AiThinkingMessage(content=status)) + + tool_helpers = ToolHelpers( + update_status=update_status_localized, + navigate_to=unsafe_navigate_to, + ) + return [ - tool_type.get_tool() + tool_type.get_tool(user, workspace, tool_helpers) for tool_type in self.get_all() - if tool_type.can_use(user, workspace, *args, **kwargs) + if tool_type.can_use(user, workspace) ] diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/tools.py index 87831a3762..0a228bdbf3 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/tools.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/tools.py @@ -1,14 +1,12 @@ from typing import Any, Callable, TypedDict from django.contrib.auth.models import AbstractUser +from django.utils.translation import gettext as _ import dspy -from dspy.dsp.utils.settings import settings as dspy_settings -from dspy.streaming.messages import sync_send_to_stream from baserow.core.models import Workspace -from baserow_enterprise.assistant.tools.registries import AssistantToolType -from baserow_enterprise.assistant.types import AiThinkingMessage +from baserow_enterprise.assistant.tools.registries import AssistantToolType, ToolHelpers from .handler import KnowledgeBaseHandler @@ -20,7 +18,7 @@ class SearchDocsSignature(dspy.Signature): context: list[str] = dspy.InputField() response: str = dspy.OutputField() sources: list[str] = dspy.OutputField( - description=f"List of unique and relevant source URLs. Max {MAX_SOURCES}." + desc=f"List of unique and relevant source URLs. Max {MAX_SOURCES}." 
) @@ -29,30 +27,38 @@ class SearchDocsToolOutput(TypedDict): sources: list[str] -def search_docs(query: str) -> SearchDocsToolOutput: +def get_search_docs_tool( + user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers +) -> Callable[[str], SearchDocsToolOutput]: + """ + Returns a function that searches the Baserow documentation for a given query. """ - Search Baserow documentation. - **Critical**: Always use before answering specifics. - Never use your general knowledge to answer specifics. + def search_docs(query: str) -> SearchDocsToolOutput: + """ + Search Baserow documentation. + """ - Covers: user guides, API references, tutorials, FAQs, features, usage. - """ + nonlocal tool_helpers - tool = SearchDocsRAG() - result = tool(query) + tool_helpers.update_status(_("Exploring the knowledge base...")) - sources = [] - for source in result["sources"]: - if source not in sources: - sources.append(source) - if len(sources) >= MAX_SOURCES: - break + tool = SearchDocsRAG() + result = tool(query) - return SearchDocsToolOutput( - response=result["response"], - sources=sources, - ) + sources = [] + for source in result.sources: + if source not in sources: + sources.append(source) + if len(sources) >= MAX_SOURCES: + break + + return SearchDocsToolOutput( + response=result.response, + sources=sources, + ) + + return search_docs class SearchDocsRAG(dspy.Module): @@ -66,24 +72,14 @@ def forward(self, question): class SearchDocsToolType(AssistantToolType): type = "search_docs" - thinking_message = "Searching Baserow documentation..." def can_use( self, user: AbstractUser, workspace: Workspace, *args, **kwargs ) -> bool: return KnowledgeBaseHandler().can_search() - def get_tool(self) -> Callable[[Any], Any]: - return search_docs - - def on_tool_start( - self, - call_id: str, - instance: Any, - inputs: dict[str, Any], - ): - stream = dspy_settings.send_stream - if stream is not None: - sync_send_to_stream( - stream, AiThinkingMessage(code=self.type, content=self.thinking_message) - ) + @classmethod + def get_tool( + cls, user: AbstractUser, workspace: Workspace, tool_helpers: ToolHelpers + ) -> Callable[[Any], Any]: + return get_search_docs_tool(user, workspace, tool_helpers) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/types.py b/enterprise/backend/src/baserow_enterprise/assistant/types.py index b91ac49def..693ad8d7bb 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/types.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/types.py @@ -1,8 +1,18 @@ from datetime import datetime, timezone from enum import StrEnum -from typing import Literal, Optional +from typing import Annotated, Any, Callable, Literal, Optional -from pydantic import BaseModel, Field +from django.utils.translation import gettext as _ + +import dspy +from pydantic import BaseModel as PydanticBaseModel +from pydantic import ConfigDict, Field + + +class BaseModel(PydanticBaseModel): + model_config = ConfigDict( + extra="forbid", + ) class WorkspaceUIContext(BaseModel): @@ -85,6 +95,7 @@ class AssistantMessageType(StrEnum): HUMAN = "human" AI_MESSAGE = "ai/message" AI_THINKING = "ai/thinking" + AI_NAVIGATION = "ai/navigation" AI_ERROR = "ai/error" TOOL_CALL = "tool_call" TOOL = "tool" @@ -97,6 +108,7 @@ class HumanMessage(BaseModel): description="The unique UUID of the message", ) type: Literal["human"] = AssistantMessageType.HUMAN.value + timestamp: datetime | None = Field(default=None) content: str ui_context: Optional[UIContext] = Field( default=None, description="The UI context 
when the message was sent" @@ -120,23 +132,8 @@ class AiMessage(AiMessageChunk): timestamp: datetime | None = Field(default=None) -class THINKING_MESSAGES(StrEnum): - THINKING = "thinking" - ANSWERING = "answering" - # Tool-specific - SEARCH_DOCS = "search_docs" - ANALYZE_RESULTS = "analyze_results" - - # For dynamic messages that don't have a translation in the frontend - CUSTOM = "custom" - - class AiThinkingMessage(BaseModel): type: Literal["ai/thinking"] = AssistantMessageType.AI_THINKING.value - code: str = Field( - default=THINKING_MESSAGES.CUSTOM, - description="Thinking content. If empty, signals end of thinking.", - ) content: str = Field( default="", description=( @@ -167,3 +164,55 @@ class AiErrorMessage(BaseModel): AiMessage | AiErrorMessage | AiThinkingMessage | ChatTitleMessage | AiMessageChunk ) AssistantMessageUnion = HumanMessage | AIMessageUnion + + +class TableNavigationType(BaseModel): + type: Literal["database-table"] + database_id: int + table_id: int + table_name: str + + def to_localized_string(self): + return _("table %(table_name)s") % {"table_name": self.table_name} + + +class ViewNavigationType(BaseModel): + type: Literal["database-view"] + database_id: int + table_id: int + view_id: int + view_name: str + + def to_localized_string(self): + return _("view %(view_name)s") % {"view_name": self.view_name} + + +class WorkspaceNavigationType(BaseModel): + type: Literal["workspace"] + + def to_localized_string(self): + return _("home") + + +AnyNavigationType = Annotated[ + TableNavigationType | WorkspaceNavigationType | ViewNavigationType, + Field(discriminator="type"), +] + + +class AiNavigationMessage(BaseModel): + type: Literal["ai/navigation"] = "ai/navigation" + location: AnyNavigationType + + +class ToolsUpgradeResponse(BaseModel): + observation: str + new_tools: list[dspy.Tool | Callable[[Any], Any]] + + +class ToolSignature(dspy.Signature): + """Signature for manual tool handling.""" + + question: str = dspy.InputField() + tools: list[dspy.Tool] = dspy.InputField() + outputs: dspy.ToolCalls = dspy.OutputField() diff --git a/enterprise/backend/src/baserow_enterprise/locale/en/LC_MESSAGES/django.po b/enterprise/backend/src/baserow_enterprise/locale/en/LC_MESSAGES/django.po index db7aa69806..f53de93b37 100644 --- a/enterprise/backend/src/baserow_enterprise/locale/en/LC_MESSAGES/django.po +++ b/enterprise/backend/src/baserow_enterprise/locale/en/LC_MESSAGES/django.po @@ -8,7 +8,7 @@ msgid "" msgstr "" "Project-Id-Version: PACKAGE VERSION\n" "Report-Msgid-Bugs-To: \n" -"POT-Creation-Date: 2025-08-26 10:53+0200\n" +"POT-Creation-Date: 2025-10-17 14:14+0000\n" "PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n" "Last-Translator: FULL NAME \n" "Language-Team: LANGUAGE \n" @@ -18,6 +18,97 @@ msgstr "" "Content-Transfer-Encoding: 8bit\n" "Plural-Forms: nplurals=2; plural=(n != 1);\n" +#: src/baserow_enterprise/assistant/tools/database/tools.py:63 +msgid "Listing databases..." +msgstr "" + +#: src/baserow_enterprise/assistant/tools/database/tools.py:136 +#, python-format +msgid "Listing tables in %(database_names)s..." +msgstr "" + +#: src/baserow_enterprise/assistant/tools/database/tools.py:195 +#, python-format +msgid "Inspecting %(table_names)s schema..." +msgstr "" + +#: src/baserow_enterprise/assistant/tools/database/tools.py:238 +#, python-format +msgid "Creating database %(database_name)s..." +msgstr "" + +#: src/baserow_enterprise/assistant/tools/database/tools.py:303 +#, python-format +msgid "Creating table %(table_name)s..." 
+msgstr "" + +#: src/baserow_enterprise/assistant/tools/database/tools.py:346 +msgid "Preparing example rows for these new tables..." +msgstr "" + +#: src/baserow_enterprise/assistant/tools/database/tools.py:468 +#, python-format +msgid "Listing rows in %(table_name)s " +msgstr "" + +#: src/baserow_enterprise/assistant/tools/database/tools.py:593 +#, python-format +msgid "Listing views in %(table_name)s..." +msgstr "" + +#: src/baserow_enterprise/assistant/tools/database/tools.py:656 +#, python-format +msgid "Creating %(view_type)s view %(view_name)s" +msgstr "" + +#: src/baserow_enterprise/assistant/tools/database/tools.py:723 +#, python-format +msgid "Creating filters in %(view_name)s..." +msgstr "" + +#: src/baserow_enterprise/assistant/tools/database/utils.py:119 +#, python-format +msgid "Creating field %(field_name)s..." +msgstr "" + +#: src/baserow_enterprise/assistant/tools/database/utils.py:409 +#, python-format +msgid "Creating rows in %(table_name)s " +msgstr "" + +#: src/baserow_enterprise/assistant/tools/database/utils.py:450 +#, python-format +msgid "Updating rows in %(table_name)s " +msgstr "" + +#: src/baserow_enterprise/assistant/tools/database/utils.py:489 +#, python-format +msgid "Deleting rows in %(table_name)s " +msgstr "" + +#: src/baserow_enterprise/assistant/tools/navigation/tools.py:34 +#, python-format +msgid "Navigating to %(location)s..." +msgstr "" + +#: src/baserow_enterprise/assistant/tools/search_docs/tools.py:44 +msgid "Exploring the knowledge base..." +msgstr "" + +#: src/baserow_enterprise/assistant/types.py:176 +#, python-format +msgid "table %(table_name)s" +msgstr "" + +#: src/baserow_enterprise/assistant/types.py:187 +#, python-format +msgid "view %(view_name)s" +msgstr "" + +#: src/baserow_enterprise/assistant/types.py:194 +msgid "home" +msgstr "" + #: src/baserow_enterprise/audit_log/job_types.py:36 msgid "User Email" msgstr "" diff --git a/enterprise/backend/tests/baserow_enterprise_tests/api/assistant/test_assistant_views.py b/enterprise/backend/tests/baserow_enterprise_tests/api/assistant/test_assistant_views.py index 75fe98c247..06187ed86c 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/api/assistant/test_assistant_views.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/api/assistant/test_assistant_views.py @@ -12,7 +12,6 @@ from baserow.test_utils.helpers import AnyStr from baserow_enterprise.assistant.models import AssistantChat from baserow_enterprise.assistant.types import ( - THINKING_MESSAGES, AiErrorMessage, AiMessage, AiMessageChunk, @@ -814,11 +813,11 @@ def test_send_message_streams_thinking_messages_during_tool_execution( # Mock assistant with thinking messages (simulating tool execution) async def mock_astream(human_message): # Initial thinking - yield AiThinkingMessage(code=THINKING_MESSAGES.THINKING) + yield AiThinkingMessage(content="Thinking...") # Tool-specific thinking (e.g., searching docs) - yield AiThinkingMessage(code=THINKING_MESSAGES.SEARCH_DOCS) + yield AiThinkingMessage(content="Searching documentation...") # Analyzing results - yield AiThinkingMessage(code=THINKING_MESSAGES.ANALYZE_RESULTS) + yield AiThinkingMessage(content="Analyzing results...") # Final answer yield AiMessageChunk( content="Based on the documentation, here's how to do it...", @@ -850,13 +849,13 @@ async def mock_astream(human_message): # First three messages are thinking messages assert messages[0]["type"] == "ai/thinking" - assert messages[0]["code"] == THINKING_MESSAGES.THINKING + assert messages[0]["content"] == 
"Thinking..." assert messages[1]["type"] == "ai/thinking" - assert messages[1]["code"] == THINKING_MESSAGES.SEARCH_DOCS + assert messages[1]["content"] == "Searching documentation..." assert messages[2]["type"] == "ai/thinking" - assert messages[2]["code"] == THINKING_MESSAGES.ANALYZE_RESULTS + assert messages[2]["content"] == "Analyzing results..." # Final message is the answer assert messages[3]["type"] == "ai/message" diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py index 460ec889fe..19627fc5cd 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py @@ -280,51 +280,6 @@ def test_aload_chat_history_handles_incomplete_pairs(self, enterprise_data_fixtu assert assistant.history.messages[0]["question"] == "Question 1" -@pytest.mark.django_db -class TestAssistantSignature: - """Test that the Assistant adapts its signature based on chat state""" - - def test_signature_includes_title_field_for_new_chats( - self, enterprise_data_fixture - ): - """Test that new chats (without title) include chat_title in signature""" - - user = enterprise_data_fixture.create_user() - workspace = enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, workspace=workspace, title="" # Empty title = new chat - ) - - assistant = Assistant(chat) - signature = assistant._get_chat_signature() - - # Should have chat_title field for new chats - assert "chat_title" in signature.fields - assert "answer" in signature.fields - assert "question" in signature.fields - - def test_signature_excludes_title_field_for_existing_chats( - self, enterprise_data_fixture - ): - """ - Test that existing chats (with title) don't include chat_title in signature - """ - - user = enterprise_data_fixture.create_user() - workspace = enterprise_data_fixture.create_workspace(user=user) - chat = AssistantChat.objects.create( - user=user, workspace=workspace, title="Existing Chat" - ) - - assistant = Assistant(chat) - signature = assistant._get_chat_signature() - - # Should NOT have chat_title field for existing chats - assert "chat_title" not in signature.fields - assert "answer" in signature.fields - assert "question" in signature.fields - - @pytest.mark.django_db class TestAssistantMessagePersistence: """Test that messages are persisted correctly during streaming""" @@ -427,10 +382,15 @@ async def consume_stream(): ).count() assert ai_messages == 1 + @patch("baserow_enterprise.assistant.assistant.ensure_llm_model_accessible") @patch("baserow_enterprise.assistant.assistant.dspy.streamify") - @patch("baserow_enterprise.assistant.assistant.dspy.LM") + @patch("baserow_enterprise.assistant.assistant.dspy.Predict") def test_astream_messages_persists_chat_title( - self, mock_lm, mock_streamify, enterprise_data_fixture + self, + mock_predict_class, + mock_streamify, + mock_ensure_llm, + enterprise_data_fixture, ): """Test that chat titles are persisted to the database""" @@ -440,7 +400,7 @@ def test_astream_messages_persists_chat_title( user=user, workspace=workspace, title="" # New chat ) - # Mock streaming with title generation + # Mock streaming async def mock_stream(*args, **kwargs): yield StreamResponse( signature_field_name="answer", @@ -448,16 +408,18 @@ async def mock_stream(*args, **kwargs): predict_name="ReAct", is_last_chunk=False, ) - yield StreamResponse( - 
signature_field_name="chat_title", - chunk="Greeting", - predict_name="ReAct", - is_last_chunk=False, - ) - yield Prediction(answer="Hello", chat_title="Greeting") + yield Prediction(answer="Hello") mock_streamify.return_value = MagicMock(return_value=mock_stream()) + # Mock title generator + async def mock_title_acall(*args, **kwargs): + return Prediction(chat_title="Greeting") + + mock_title_generator = MagicMock() + mock_title_generator.acall = mock_title_acall + mock_predict_class.return_value = mock_title_generator + assistant = Assistant(chat) ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), @@ -531,14 +493,20 @@ async def consume_stream(): chunks = async_to_sync(consume_stream)() # Should receive chunks with accumulated content - assert len(chunks) == 2 + assert len(chunks) == 3 assert chunks[0].content == "Hello" assert chunks[1].content == "Hello world" + assert chunks[2].content == "Hello world" # Final chunk repeats full answer + @patch("baserow_enterprise.assistant.assistant.ensure_llm_model_accessible") @patch("baserow_enterprise.assistant.assistant.dspy.streamify") - @patch("baserow_enterprise.assistant.assistant.dspy.LM") + @patch("baserow_enterprise.assistant.assistant.dspy.Predict") def test_astream_messages_yields_title_chunks( - self, mock_lm, mock_streamify, enterprise_data_fixture + self, + mock_predict_class, + mock_streamify, + mock_ensure_llm, + enterprise_data_fixture, ): """Test that title chunks are yielded for new chats""" @@ -556,16 +524,18 @@ async def mock_stream(*args, **kwargs): predict_name="ReAct", is_last_chunk=False, ) - yield StreamResponse( - signature_field_name="chat_title", - chunk="Title", - predict_name="ReAct", - is_last_chunk=False, - ) - yield Prediction(answer="Answer", chat_title="Title") + yield Prediction(answer="Answer") mock_streamify.return_value = MagicMock(return_value=mock_stream()) + # Mock title generator + async def mock_title_acall(*args, **kwargs): + return Prediction(chat_title="Title") + + mock_title_generator = MagicMock() + mock_title_generator.acall = mock_title_acall + mock_predict_class.return_value = mock_title_generator + assistant = Assistant(chat) ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), @@ -601,7 +571,7 @@ def test_astream_messages_yields_thinking_messages( # Mock streaming async def mock_stream(*args, **kwargs): - yield AiThinkingMessage(code="thinking") + yield AiThinkingMessage(content="thinking") yield StreamResponse( signature_field_name="answer", chunk="Answer", @@ -630,7 +600,7 @@ async def consume_stream(): # Should receive thinking messages assert len(thinking_messages) == 1 - assert thinking_messages[0].code == "thinking" + assert thinking_messages[0].content == "thinking" @pytest.mark.django_db diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py new file mode 100644 index 0000000000..7f4c3c21ba --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_rows_tools.py @@ -0,0 +1,396 @@ +import pytest + +from baserow.contrib.database.rows.handler import RowHandler +from baserow_enterprise.assistant.tools.database.tools import ( + get_list_rows_tool, + get_rows_meta_tool, +) +from baserow_enterprise.assistant.types import ToolsUpgradeResponse + +from .utils import fake_tool_helpers + + +def 
_create_simple_database_with_linked_tables_and_rows(data_fixture): + user = data_fixture.create_user() + table_a, table_b, link_a_to_b = data_fixture.create_two_linked_tables(user=user) + workspace = table_a.database.workspace + primary_field_table_a = table_a.get_primary_field().specific + primary_field_table_b = table_b.get_primary_field().specific + text_field = data_fixture.create_text_field( + user=user, table=table_a, name="Text field" + ) + long_text_field = data_fixture.create_long_text_field( + user=user, table=table_a, name="Long text field" + ) + number_field = data_fixture.create_number_field( + user=user, + table=table_a, + name="Number field", + number_decimal_places=3, + ) + date_field = data_fixture.create_date_field( + user=user, table=table_a, name="Date field" + ) + datetime_field = data_fixture.create_date_field( + user=user, table=table_a, name="Datetime field", date_include_time=True + ) + single_select_field = data_fixture.create_single_select_field( + user=user, table=table_a, name="Single select" + ) + data_fixture.create_select_option(value="Option 1", field=single_select_field) + data_fixture.create_select_option(value="Option 2", field=single_select_field) + + multiple_select_field = data_fixture.create_multiple_select_field( + user=user, table=table_a, name="Multiple select" + ) + data_fixture.create_select_option(value="Option A", field=multiple_select_field) + data_fixture.create_select_option(value="Option B", field=multiple_select_field) + data_fixture.create_select_option(value="Option C", field=multiple_select_field) + single_link_to_b_field = data_fixture.create_link_row_field( + user=user, + table=table_a, + link_row_table=table_b, + name="Single link to B", + link_row_multiple_relationships=False, + ) + + table_b_rows = ( + RowHandler() + .force_create_rows( + user, + table_b, + [ + {primary_field_table_b.db_column: "Row B1"}, + {primary_field_table_b.db_column: "Row B2"}, + {primary_field_table_b.db_column: "Row B3"}, + ], + ) + .created_rows + ) + + table_a_rows = ( + RowHandler() + .force_create_rows( + user, + table_a, + [ + { + primary_field_table_a.db_column: "Row A1", + text_field.db_column: "Text A1", + long_text_field.db_column: "Long text A1", + number_field.db_column: 10.123, + date_field.db_column: "2023-01-01", + datetime_field.db_column: "2023-01-01 10:00:00", + single_select_field.db_column: "Option 1", + multiple_select_field.db_column: ["Option A", "Option B"], + single_link_to_b_field.db_column: [table_b_rows[0].id], + link_a_to_b.db_column: [table_b_rows[0].id, table_b_rows[1].id], + }, + { + primary_field_table_a.db_column: "Row A2", + text_field.db_column: "Text A2", + long_text_field.db_column: "Long text A2", + number_field.db_column: 20.456, + date_field.db_column: "2023-02-01", + datetime_field.db_column: "2023-02-01 11:00:00", + single_select_field.db_column: "Option 2", + multiple_select_field.db_column: ["Option B", "Option C"], + single_link_to_b_field.db_column: [table_b_rows[1].id], + link_a_to_b.db_column: [table_b_rows[1].id, table_b_rows[2].id], + }, + {}, + ], + ) + .created_rows + ) + + return { + "user": user, + "workspace": workspace, + "table_a": table_a, + "table_b": table_b, + "table_a_fields": { + "link_a_to_b": link_a_to_b, + "text_field": text_field, + "long_text_field": long_text_field, + "number_field": number_field, + "date_field": date_field, + "datetime_field": datetime_field, + "single_select_field": single_select_field, + "multiple_select_field": multiple_select_field, + "single_link_to_b_field": 
single_link_to_b_field, + }, + "table_a_rows": table_a_rows, + "table_b_rows": table_b_rows, + } + + +@pytest.mark.django_db +def test_list_rows(data_fixture): + res = _create_simple_database_with_linked_tables_and_rows(data_fixture) + + user = res["user"] + workspace = res["workspace"] + table = res["table_a"] + tool_helpers = fake_tool_helpers + + list_table_rows = get_list_rows_tool(user, workspace, tool_helpers) + assert callable(list_table_rows) + + result = list_table_rows(table_id=table.id, offset=0, limit=50) + rows = result["rows"] + assert len(rows) == 3 + assert rows[0] == { + "primary": "Row A1", + "Long text field": "Long text A1", + "Number field": 10.123, + "Date field": {"year": 2023, "month": 1, "day": 1}, + "Datetime field": {"year": 2023, "month": 1, "day": 1, "hour": 10, "minute": 0}, + "Single link to B": "Row B1", + "Multiple select": ["Option A", "Option B"], + "Text field": "Text A1", + "Single select": "Option 1", + "link": ["Row B1", "Row B2"], + "id": 1, + } + assert rows[1] == { + "primary": "Row A2", + "Long text field": "Long text A2", + "Number field": 20.456, + "Date field": {"year": 2023, "month": 2, "day": 1}, + "Datetime field": {"year": 2023, "month": 2, "day": 1, "hour": 11, "minute": 0}, + "Single link to B": "Row B2", + "Multiple select": ["Option B", "Option C"], + "Text field": "Text A2", + "Single select": "Option 2", + "link": ["Row B2", "Row B3"], + "id": 2, + } + assert rows[2] == { + "primary": "", + "Long text field": "", + "Number field": None, + "Date field": None, + "Datetime field": None, + "Single link to B": None, + "Multiple select": [], + "Text field": "", + "Single select": None, + "link": [], + "id": 3, + } + + # List a single field + result = list_table_rows( + table_id=table.id, offset=0, limit=50, field_ids=[table.get_primary_field().id] + ) + rows = result["rows"] + assert len(rows) == 3 + assert rows[0] == { + "primary": "Row A1", + "id": 1, + } + assert rows[1] == { + "primary": "Row A2", + "id": 2, + } + assert rows[2] == { + "primary": "", + "id": 3, + } + + +@pytest.mark.django_db(transaction=True) +def test_create_rows(data_fixture): + res = _create_simple_database_with_linked_tables_and_rows(data_fixture) + + user = res["user"] + workspace = res["workspace"] + table = res["table_a"] + tool_helpers = fake_tool_helpers + + meta_tool = get_rows_meta_tool(user, workspace, tool_helpers) + assert callable(meta_tool) + + tools_upgrade = meta_tool([table.id], ["create"]) + assert isinstance(tools_upgrade, ToolsUpgradeResponse) + assert f"list_rows_in_table_{table.id}" not in tools_upgrade.observation + assert f"create_rows_in_table_{table.id}" in tools_upgrade.observation + assert ( + f"update_rows_in_table_{table.id}_by_row_ids" not in tools_upgrade.observation + ) + assert ( + f"delete_rows_in_table_{table.id}_by_row_ids" not in tools_upgrade.observation + ) + assert len(tools_upgrade.new_tools) == 1 + + create_table_rows = tools_upgrade.new_tools[0] + assert create_table_rows.name == f"create_rows_in_table_{table.id}" + + table_model = table.get_model() + assert table_model.objects.count() == 3 + + row_1 = { + "primary": "Row A3", + "Text field": "Text A3", + "Long text field": "Long text A3", + "Number field": 30.789, + "Date field": {"year": 2023, "month": 3, "day": 1}, + "Datetime field": { + "year": 2023, + "month": 3, + "day": 1, + "hour": 12, + "minute": 0, + }, + "Single select": "Option 1", + "Multiple select": ["Option A", "Option C"], + "Single link to B": "Row B3", + "link": ["Row B1"], + } + row_2 = { + "primary": 
"", + "Text field": "", + "Long text field": "", + "Number field": None, + "Date field": None, + "Datetime field": None, + "Single select": None, + "Multiple select": [], + "Single link to B": None, + "link": [], + } + result = create_table_rows(rows=[row_1, row_2]) + created_row_ids = result["created_row_ids"] + assert len(created_row_ids) == 2 + assert created_row_ids == [4, 5] + + +@pytest.mark.django_db(transaction=True) +def test_update_rows(data_fixture): + res = _create_simple_database_with_linked_tables_and_rows(data_fixture) + + user = res["user"] + workspace = res["workspace"] + table = res["table_a"] + tool_helpers = fake_tool_helpers + + meta_tool = get_rows_meta_tool(user, workspace, tool_helpers) + assert callable(meta_tool) + + tools_upgrade = meta_tool([table.id], ["update"]) + assert isinstance(tools_upgrade, ToolsUpgradeResponse) + assert f"list_rows_in_table_{table.id}" not in tools_upgrade.observation + assert f"create_rows_in_table_{table.id}" not in tools_upgrade.observation + assert f"update_rows_in_table_{table.id}_by_row_ids" in tools_upgrade.observation + assert ( + f"delete_rows_in_table_{table.id}_by_row_ids" not in tools_upgrade.observation + ) + assert len(tools_upgrade.new_tools) == 1 + + update_table_rows = tools_upgrade.new_tools[0] + assert update_table_rows.name == f"update_rows_in_table_{table.id}_by_row_ids" + + table_model = table.get_model() + assert table_model.objects.count() == 3 + + # Update row 1 with new values + row_1_updates = { + "id": 1, + "primary": "Updated Row A1", + "Text field": "Updated Text A1", + "Number field": 99.999, + "Single select": "Option 2", + "link": ["Row B3"], + "Single link to B": "Row B2", + "Datetime field": "__NO_CHANGE__", + "Date field": "__NO_CHANGE__", + "Multiple select": "__NO_CHANGE__", + "Long text field": "__NO_CHANGE__", + } + # Update row 2 with new values + row_2_updates = { + "id": 2, + "Single link to B": "__NO_CHANGE__", + "Long text field": "Updated Long text A2", + "Date field": {"year": 2024, "month": 12, "day": 31}, + "Multiple select": ["Option A"], + "primary": "__NO_CHANGE__", + "Text field": "__NO_CHANGE__", + "Number field": "__NO_CHANGE__", + "Datetime field": "__NO_CHANGE__", + "Single select": "__NO_CHANGE__", + "link": "__NO_CHANGE__", + } + + result = update_table_rows(rows=[row_1_updates, row_2_updates]) + updated_row_ids = result["updated_row_ids"] + assert len(updated_row_ids) == 2 + assert updated_row_ids == [1, 2] + + # Verify the rows were updated correctly + list_table_rows = get_list_rows_tool(user, workspace, tool_helpers) + row_1, row_2 = list_table_rows(table_id=table.id, offset=0, limit=2)["rows"] + assert row_1 == { + "primary": "Updated Row A1", + "Long text field": "Long text A1", + "Number field": 99.999, + "Date field": {"year": 2023, "month": 1, "day": 1}, + "Datetime field": {"year": 2023, "month": 1, "day": 1, "hour": 10, "minute": 0}, + "Single link to B": "Row B2", + "Multiple select": ["Option A", "Option B"], + "Text field": "Updated Text A1", + "Single select": "Option 2", + "link": ["Row B3"], + "id": 1, + } + assert row_2 == { + "primary": "Row A2", + "Long text field": "Updated Long text A2", + "Number field": 20.456, + "Date field": {"year": 2024, "month": 12, "day": 31}, + "Datetime field": {"year": 2023, "month": 2, "day": 1, "hour": 11, "minute": 0}, + "Single link to B": "Row B2", + "Multiple select": ["Option A"], + "Text field": "Text A2", + "Single select": "Option 2", + "link": ["Row B2", "Row B3"], + "id": 2, + } + + 
+@pytest.mark.django_db(transaction=True) +def test_delete_rows(data_fixture): + res = _create_simple_database_with_linked_tables_and_rows(data_fixture) + + user = res["user"] + workspace = res["workspace"] + table = res["table_a"] + tool_helpers = fake_tool_helpers + + meta_tool = get_rows_meta_tool(user, workspace, tool_helpers) + assert callable(meta_tool) + + tools_upgrade = meta_tool([table.id], ["delete"]) + assert isinstance(tools_upgrade, ToolsUpgradeResponse) + assert f"list_rows_in_table_{table.id}" not in tools_upgrade.observation + assert f"create_rows_in_table_{table.id}" not in tools_upgrade.observation + assert ( + f"update_rows_in_table_{table.id}_by_row_ids" not in tools_upgrade.observation + ) + assert f"delete_rows_in_table_{table.id}_by_row_ids" in tools_upgrade.observation + assert len(tools_upgrade.new_tools) == 1 + + delete_table_rows = tools_upgrade.new_tools[0] + assert delete_table_rows.name == f"delete_rows_in_table_{table.id}_by_row_ids" + + table_model = table.get_model() + assert table_model.objects.count() == 3 + + # Delete rows with ids 1 and 3 + result = delete_table_rows(row_ids=[1, 3]) + assert result["deleted_row_ids"] == [1, 3] + + # Verify rows were deleted + assert table_model.objects.count() == 1 + assert list(table_model.objects.values_list("id", flat=True)) == [2] diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py new file mode 100644 index 0000000000..50cdc238cc --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_table_tools.py @@ -0,0 +1,325 @@ +import pytest + +from baserow.contrib.database.table.models import Table +from baserow.test_utils.helpers import AnyInt +from baserow_enterprise.assistant.tools.database.tools import ( + get_create_tables_tool, + get_list_tables_tool, +) +from baserow_enterprise.assistant.tools.database.types import ( + BooleanFieldItemCreate, + DateFieldItemCreate, + FileFieldItemCreate, + LinkRowFieldItemCreate, + ListTablesFilterArg, + LongTextFieldItemCreate, + MultipleSelectFieldItemCreate, + NumberFieldItemCreate, + RatingFieldItemCreate, + SelectOptionCreate, + SingleSelectFieldItemCreate, + TableItemCreate, + TextFieldItemCreate, + field_item_registry, +) + +from .utils import fake_tool_helpers + + +@pytest.mark.django_db +def test_list_tables_tool(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database_1 = data_fixture.create_database_application( + workspace=workspace, name="Database 1" + ) + database_2 = data_fixture.create_database_application( + workspace=workspace, name="Database 2" + ) + table_1 = data_fixture.create_database_table(database=database_1, name="Table 1") + table_2 = data_fixture.create_database_table(database=database_1, name="Table 2") + table_3 = data_fixture.create_database_table(database=database_2, name="Table 3") + + tool = get_list_tables_tool(user, workspace, fake_tool_helpers) + + # Test 1: Filter by database_ids (single database) - returns flat list + response = tool( + filters=ListTablesFilterArg( + database_ids=[database_1.id], + database_names=None, + table_ids=None, + table_names=None, + ) + ) + assert response == [ + {"id": table_1.id, "name": "Table 1", "database_id": database_1.id}, + {"id": table_2.id, "name": "Table 2", "database_id": database_1.id}, + ] + + # Test 2: Filter by database_names (single 
database) - returns flat list + response = tool( + filters=ListTablesFilterArg( + database_ids=None, + database_names=["Database 2"], + table_ids=None, + table_names=None, + ) + ) + assert response == [ + {"id": table_3.id, "name": "Table 3", "database_id": database_2.id}, + ] + + # Test 3: Filter by multiple database_ids - returns database wrapper structure + response = tool( + filters=ListTablesFilterArg( + database_ids=[database_1.id, database_2.id], + database_names=None, + table_ids=None, + table_names=None, + ) + ) + assert response == [ + { + "id": database_1.id, + "name": "Database 1", + "tables": [ + {"id": table_1.id, "name": "Table 1", "database_id": database_1.id}, + {"id": table_2.id, "name": "Table 2", "database_id": database_1.id}, + ], + }, + { + "id": database_2.id, + "name": "Database 2", + "tables": [ + {"id": table_3.id, "name": "Table 3", "database_id": database_2.id}, + ], + }, + ] + + # Test 4: Filter by table_ids (single database) - returns flat list + response = tool( + filters=ListTablesFilterArg( + database_ids=None, + database_names=None, + table_ids=[table_1.id, table_2.id], + table_names=None, + ) + ) + assert response == [ + {"id": table_1.id, "name": "Table 1", "database_id": database_1.id}, + {"id": table_2.id, "name": "Table 2", "database_id": database_1.id}, + ] + + # Test 5: Filter by table_names (single database) - returns flat list + response = tool( + filters=ListTablesFilterArg( + database_ids=None, + database_names=None, + table_ids=None, + table_names=["Table 1"], + ) + ) + assert response == [ + {"id": table_1.id, "name": "Table 1", "database_id": database_1.id}, + ] + + # Test 6: Filter by table_ids across multiple databases - returns database wrapper + response = tool( + filters=ListTablesFilterArg( + database_ids=None, + database_names=None, + table_ids=[table_1.id, table_3.id], + table_names=None, + ) + ) + assert response == [ + { + "id": database_1.id, + "name": "Database 1", + "tables": [ + {"id": table_1.id, "name": "Table 1", "database_id": database_1.id}, + ], + }, + { + "id": database_2.id, + "name": "Database 2", + "tables": [ + {"id": table_3.id, "name": "Table 3", "database_id": database_2.id}, + ], + }, + ] + + # Test 7: Combined filters (database_ids + table_names) - returns flat list + response = tool( + filters=ListTablesFilterArg( + database_ids=[database_1.id], + database_names=None, + table_ids=None, + table_names=["Table 2"], + ) + ) + assert response == [ + {"id": table_2.id, "name": "Table 2", "database_id": database_1.id}, + ] + + # Test 8: No matching tables - returns "No tables found" + response = tool( + filters=ListTablesFilterArg( + database_ids=None, + database_names=None, + table_ids=None, + table_names=["Nonexistent Table"], + ) + ) + assert response == "No tables found" + + +@pytest.mark.django_db +def test_create_simple_table_tool(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application( + workspace=workspace, name="Database 1" + ) + + tool = get_create_tables_tool(user, workspace, fake_tool_helpers) + response = tool( + database_id=database.id, + tables=[ + TableItemCreate( + name="New Table", + primary_field=TextFieldItemCreate(type="text", name="Name"), + fields=[], + ) + ], + add_sample_rows=False, + ) + + assert response == { + "created_tables": [{"id": AnyInt(), "name": "New Table"}], + "notes": [], + } + + # Ensure the table was actually created + assert Table.objects.filter( + 
id=response["created_tables"][0]["id"], name="New Table" + ).exists() + + +@pytest.mark.django_db +def test_create_complex_table_tool(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application( + workspace=workspace, name="Database 1" + ) + table = data_fixture.create_database_table(database=database, name="Table 1") + + tool = get_create_tables_tool(user, workspace, fake_tool_helpers) + primary_field = TextFieldItemCreate(type="text", name="Name") + fields = [ + LongTextFieldItemCreate( + type="long_text", + name="Description", + rich_text=True, + ), + NumberFieldItemCreate( + type="number", + name="Amount", + decimal_places=2, + suffix="$", + ), + DateFieldItemCreate( + type="date", + name="Due Date", + include_time=False, + ), + DateFieldItemCreate( + type="date", + name="Event Time", + include_time=True, + ), + BooleanFieldItemCreate( + type="boolean", + name="Done?", + ), + SingleSelectFieldItemCreate( + type="single_select", + name="Status", + options=[ + SelectOptionCreate(value="New", color="blue"), + SelectOptionCreate(value="In Progress", color="yellow"), + SelectOptionCreate(value="Done", color="green"), + ], + ), + MultipleSelectFieldItemCreate( + type="multiple_select", + name="Tags", + options=[ + SelectOptionCreate(value="Red", color="red"), + SelectOptionCreate(value="Yellow", color="yellow"), + SelectOptionCreate(value="Green", color="green"), + SelectOptionCreate(value="Blue", color="blue"), + ], + ), + LinkRowFieldItemCreate( + type="link_row", + name="Related Items", + linked_table=table.id, + has_link_back=False, + multiple=True, + ), + RatingFieldItemCreate( + type="rating", + name="Rating", + max_value=5, + ), + FileFieldItemCreate( + type="file", + name="Attachments", + ), + ] + response = tool( + database_id=database.id, + tables=[ + TableItemCreate( + name="New Table", + primary_field=primary_field, + fields=fields, + ) + ], + add_sample_rows=False, + ) + + assert response == { + "created_tables": [{"id": AnyInt(), "name": "New Table"}], + "notes": [], + } + + # Ensure the table was actually created with all fields + created_table = Table.objects.filter( + id=response["created_tables"][0]["id"], name="New Table" + ).first() + assert created_table is not None + assert created_table.field_set.count() == 11 + + table_model = created_table.get_model() + fields_map = {field.name: field for field in fields} + fields_map[primary_field.name] = primary_field + for field_object in table_model.get_field_objects(): + orm_field = field_object["field"] + assert orm_field.name in fields_map + field_item = fields_map.pop(orm_field.name).model_dump() + orm_field_to_item = field_item_registry.from_django_orm(orm_field).model_dump() + if orm_field.primary: + assert field_item["name"] == primary_field.name + + for key, value in orm_field_to_item.items(): + if key == "id": + continue + if key == "options": + # Saved options have an ID, so we need to remove them before comparison + for option in value: + option.pop("id") + + assert field_item[key] == value diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_tools.py new file mode 100644 index 0000000000..2c74da289e --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_tools.py @@ -0,0 +1,52 @@ +import pytest + +from baserow.contrib.database.models 
import Database +from baserow.test_utils.helpers import AnyInt +from baserow_enterprise.assistant.tools.database.tools import ( + get_create_database_tool, + get_list_databases_tool, +) + +from .utils import fake_tool_helpers + + +@pytest.mark.django_db +def test_list_databases_tool(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application( + workspace=workspace, name="Database 1" + ) + + tool = get_list_databases_tool(user, workspace, fake_tool_helpers) + response = tool() + + assert response == {"databases": [{"id": database.id, "name": "Database 1"}]} + + database_2 = data_fixture.create_database_application( + workspace=workspace, name="Database 2" + ) + response = tool() + assert response == { + "databases": [ + {"id": database.id, "name": "Database 1"}, + {"id": database_2.id, "name": "Database 2"}, + ] + } + + +@pytest.mark.django_db +def test_create_database_tool(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + + tool = get_create_database_tool(user, workspace, fake_tool_helpers) + response = tool(name="New Database") + + assert response == {"created_database": {"id": AnyInt(), "name": "New Database"}} + + # Ensure the database was actually created + + assert Database.objects.filter( + id=response["created_database"]["id"], name="New Database" + ).exists() diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py new file mode 100644 index 0000000000..16128971cb --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_database_views_tools.py @@ -0,0 +1,719 @@ +import pytest + +from baserow.contrib.database.views.models import View, ViewFilter +from baserow_enterprise.assistant.tools.database.tools import ( + get_create_view_filters_tool, + get_create_views_tool, + get_list_views_tool, +) +from baserow_enterprise.assistant.tools.database.types import ( + BooleanIsViewFilterItemCreate, + CalendarViewItemCreate, + DateAfterViewFilterItemCreate, + DateBeforeViewFilterItemCreate, + DateEqualsViewFilterItemCreate, + DateNotEqualsViewFilterItemCreate, + FormFieldOption, + FormViewItemCreate, + GalleryViewItemCreate, + GridViewItemCreate, + KanbanViewItemCreate, + MultipleSelectIsAnyViewFilterItemCreate, + MultipleSelectIsNoneOfNotViewFilterItemCreate, + NumberEqualsViewFilterItemCreate, + NumberHigherThanViewFilterItemCreate, + NumberLowerThanViewFilterItemCreate, + NumberNotEqualsViewFilterItemCreate, + SingleSelectIsAnyViewFilterItemCreate, + SingleSelectIsNoneOfNotViewFilterItemCreate, + TextContainsViewFilterItemCreate, + TextEqualViewFilterItemCreate, + TextNotContainsViewFilterItemCreate, + TextNotEqualViewFilterItemCreate, + TimelineViewItemCreate, +) +from baserow_enterprise.assistant.tools.database.types.base import Date + +from .utils import fake_tool_helpers + + +@pytest.mark.django_db +def test_list_views_tool(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + view = data_fixture.create_grid_view(table=table, name="View 1", order=1) + + tool = get_list_views_tool(user, workspace, fake_tool_helpers) + response = tool(table_id=table.id) + 
+ assert response == { + "views": [ + { + "id": view.id, + "name": "View 1", + "type": "grid", + "row_height": "small", + "public": False, + } + ] + } + + view_2 = data_fixture.create_grid_view(table=table, name="View 2", order=2) + response = tool(table_id=table.id) + assert len(response["views"]) == 2 + assert response["views"][0]["name"] == "View 1" + assert response["views"][1]["name"] == "View 2" + + +@pytest.mark.django_db +def test_create_grid_view(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + + tool = get_create_views_tool(user, workspace, fake_tool_helpers) + response = tool( + table_id=table.id, + views=[ + GridViewItemCreate( + type="grid", name="Grid View", public=False, row_height="medium" + ) + ], + ) + + assert len(response["created_views"]) == 1 + assert response["created_views"][0]["name"] == "Grid View" + assert View.objects.filter(name="Grid View").exists() + + +@pytest.mark.django_db +def test_create_kanban_view(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + single_select = data_fixture.create_single_select_field(table=table, name="Status") + + tool = get_create_views_tool(user, workspace, fake_tool_helpers) + response = tool( + table_id=table.id, + views=[ + KanbanViewItemCreate( + type="kanban", + name="Kanban View", + public=False, + column_field_id=single_select.id, + ) + ], + ) + + assert len(response["created_views"]) == 1 + assert response["created_views"][0]["name"] == "Kanban View" + assert View.objects.filter(name="Kanban View").exists() + + +@pytest.mark.django_db +def test_create_calendar_view(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + date_field = data_fixture.create_date_field(table=table, name="Date") + + tool = get_create_views_tool(user, workspace, fake_tool_helpers) + response = tool( + table_id=table.id, + views=[ + CalendarViewItemCreate( + type="calendar", + name="Calendar View", + public=False, + date_field_id=date_field.id, + ) + ], + ) + + assert len(response["created_views"]) == 1 + assert response["created_views"][0]["name"] == "Calendar View" + assert View.objects.filter(name="Calendar View").exists() + + +@pytest.mark.django_db +def test_create_gallery_view(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + file_field = data_fixture.create_file_field(table=table, name="Files") + + tool = get_create_views_tool(user, workspace, fake_tool_helpers) + response = tool( + table_id=table.id, + views=[ + GalleryViewItemCreate( + type="gallery", + name="Gallery View", + public=False, + cover_field_id=file_field.id, + ) + ], + ) + + assert len(response["created_views"]) == 1 + assert response["created_views"][0]["name"] == "Gallery View" + assert View.objects.filter(name="Gallery View").exists() + + +@pytest.mark.django_db +def test_create_timeline_view(data_fixture): 
+ user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + start_date = data_fixture.create_date_field(table=table, name="Start Date") + end_date = data_fixture.create_date_field(table=table, name="End Date") + + tool = get_create_views_tool(user, workspace, fake_tool_helpers) + response = tool( + table_id=table.id, + views=[ + TimelineViewItemCreate( + type="timeline", + name="Timeline View", + public=False, + start_date_field_id=start_date.id, + end_date_field_id=end_date.id, + ) + ], + ) + + assert len(response["created_views"]) == 1 + assert response["created_views"][0]["name"] == "Timeline View" + assert View.objects.filter(name="Timeline View").exists() + + +@pytest.mark.django_db +def test_create_form_view(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_text_field(table=table, name="Name", primary=True) + + tool = get_create_views_tool(user, workspace, fake_tool_helpers) + response = tool( + table_id=table.id, + views=[ + FormViewItemCreate( + type="form", + name="Form View", + public=True, + title="Contact Form", + description="Fill out this form", + submit_button_label="Submit", + receive_notification_on_submit=False, + submit_action="MESSAGE", + submit_action_message="Thank you!", + submit_action_redirect_url="", + field_options=[ + FormFieldOption( + field_id=field.id, + name="Your Name", + description="Enter your name", + required=True, + order=1, + ) + ], + ) + ], + ) + + assert len(response["created_views"]) == 1 + assert response["created_views"][0]["name"] == "Form View" + assert View.objects.filter(name="Form View").exists() + + +# Text filter tests +@pytest.mark.django_db +def test_create_text_equal_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_text_field(table=table, name="Name") + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + TextEqualViewFilterItemCreate( + field_id=field.id, type="text", operator="equal", value="test" + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert response["created_view_filters"][0]["operator"] == "equal" + assert ViewFilter.objects.filter(view=view, field=field, type="equal").exists() + + +@pytest.mark.django_db +def test_create_text_not_equal_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_text_field(table=table) + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + TextNotEqualViewFilterItemCreate( + field_id=field.id, type="text", operator="not_equal", value="test" + ) + ], + ) + + assert 
len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter(view=view, field=field, type="not_equal").exists() + + +@pytest.mark.django_db +def test_create_text_contains_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_text_field(table=table) + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + TextContainsViewFilterItemCreate( + field_id=field.id, type="text", operator="contains", value="test" + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter(view=view, field=field, type="contains").exists() + + +@pytest.mark.django_db +def test_create_text_not_contains_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_text_field(table=table) + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + TextNotContainsViewFilterItemCreate( + field_id=field.id, type="text", operator="contains_not", value="test" + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter( + view=view, field=field, type="contains_not" + ).exists() + + +# Number filter tests +@pytest.mark.django_db +def test_create_number_equal_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_number_field(table=table) + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + NumberEqualsViewFilterItemCreate( + field_id=field.id, type="number", operator="equal", value=42.0 + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter(view=view, field=field, type="equal").exists() + + +@pytest.mark.django_db +def test_create_number_not_equal_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_number_field(table=table) + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + NumberNotEqualsViewFilterItemCreate( + field_id=field.id, type="number", operator="not_equal", value=42.0 + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter(view=view, field=field, type="not_equal").exists() + + +@pytest.mark.django_db +def test_create_number_higher_than_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + 
database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_number_field(table=table) + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + NumberHigherThanViewFilterItemCreate( + field_id=field.id, + type="number", + operator="higher_than", + value=10.0, + or_equal=False, + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter( + view=view, field=field, type="higher_than" + ).exists() + + +@pytest.mark.django_db +def test_create_number_lower_than_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_number_field(table=table) + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + NumberLowerThanViewFilterItemCreate( + field_id=field.id, + type="number", + operator="lower_than", + value=100.0, + or_equal=False, + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter(view=view, field=field, type="lower_than").exists() + + +# Date filter tests +@pytest.mark.django_db +def test_create_date_equal_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_date_field(table=table) + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + DateEqualsViewFilterItemCreate( + field_id=field.id, + type="date", + operator="equal", + value=Date(year=2024, month=1, day=15), + mode="exact_date", + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter(view=view, field=field, type="date_is").exists() + + +@pytest.mark.django_db +def test_create_date_not_equal_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_date_field(table=table) + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + DateNotEqualsViewFilterItemCreate( + field_id=field.id, + type="date", + operator="not_equal", + value=None, + mode="today", + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter( + view=view, field=field, type="date_is_not" + ).exists() + + +@pytest.mark.django_db +def test_create_date_after_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = 
data_fixture.create_date_field(table=table) + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + DateAfterViewFilterItemCreate( + field_id=field.id, + type="date", + operator="after", + value=7, + mode="nr_days_ago", + or_equal=False, + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter( + view=view, field=field, type="date_is_after" + ).exists() + + +@pytest.mark.django_db +def test_create_date_before_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_date_field(table=table) + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + DateBeforeViewFilterItemCreate( + field_id=field.id, + type="date", + operator="before", + value=None, + mode="tomorrow", + or_equal=True, + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter( + view=view, field=field, type="date_is_on_or_before" + ).exists() + + +# Single select filter tests +@pytest.mark.django_db +def test_create_single_select_is_any_of_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_single_select_field(table=table) + data_fixture.create_select_option(field=field, value="Option 1") + data_fixture.create_select_option(field=field, value="Option 2") + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + SingleSelectIsAnyViewFilterItemCreate( + field_id=field.id, + type="single_select", + operator="is_any_of", + value=["Option 1", "Option 2"], + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter( + view=view, field=field, type="single_select_is_any_of" + ).exists() + + +@pytest.mark.django_db +def test_create_single_select_is_none_of_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_single_select_field(table=table) + data_fixture.create_select_option(field=field, value="Bad Option") + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + SingleSelectIsNoneOfNotViewFilterItemCreate( + field_id=field.id, + type="single_select", + operator="is_none_of", + value=["Bad Option"], + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter( + view=view, field=field, type="single_select_is_none_of" + ).exists() + + +# Boolean filter tests +@pytest.mark.django_db +def test_create_boolean_is_true_filter(data_fixture): + user = data_fixture.create_user() + workspace = 
data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_boolean_field(table=table, name="Active") + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + BooleanIsViewFilterItemCreate( + field_id=field.id, type="boolean", operator="is", value=True + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter(view=view, field=field, type="boolean").exists() + + +@pytest.mark.django_db +def test_create_boolean_is_false_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_boolean_field(table=table, name="Active") + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + BooleanIsViewFilterItemCreate( + field_id=field.id, type="boolean", operator="is", value=False + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter(view=view, field=field, type="boolean").exists() + + +# Multiple select filter tests +@pytest.mark.django_db +def test_create_multiple_select_is_any_of_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_multiple_select_field(table=table) + data_fixture.create_select_option(field=field, value="Tag 1") + data_fixture.create_select_option(field=field, value="Tag 2") + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + MultipleSelectIsAnyViewFilterItemCreate( + field_id=field.id, + type="multiple_select", + operator="is_any_of", + value=["Tag 1", "Tag 2"], + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter( + view=view, field=field, type="multiple_select_has" + ).exists() + + +@pytest.mark.django_db +def test_create_multiple_select_is_none_of_filter(data_fixture): + user = data_fixture.create_user() + workspace = data_fixture.create_workspace(user=user) + database = data_fixture.create_database_application(workspace=workspace) + table = data_fixture.create_database_table(database=database) + field = data_fixture.create_multiple_select_field(table=table) + data_fixture.create_select_option(field=field, value="Bad Tag") + view = data_fixture.create_grid_view(table=table) + + tool = get_create_view_filters_tool(user, workspace, fake_tool_helpers) + response = tool( + view_id=view.id, + filters=[ + MultipleSelectIsNoneOfNotViewFilterItemCreate( + field_id=field.id, + type="multiple_select", + operator="is_none_of", + value=["Bad Tag"], + ) + ], + ) + + assert len(response["created_view_filters"]) == 1 + assert ViewFilter.objects.filter( + view=view, field=field, type="multiple_select_has_not" + ).exists() diff --git 
a/enterprise/backend/tests/baserow_enterprise_tests/assistant/utils.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/utils.py new file mode 100644 index 0000000000..d6d975e160 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/utils.py @@ -0,0 +1,3 @@ +from baserow_enterprise.assistant.tools.registries import ToolHelpers + +fake_tool_helpers = ToolHelpers(lambda x: None, lambda x: None) diff --git a/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss index 2a1d655188..bdeac73cdf 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss +++ b/enterprise/web-frontend/modules/baserow_enterprise/assets/scss/components/assistant.scss @@ -157,6 +157,12 @@ } } +.assistant__status-message { + @extend %ellipsis; + + display: inline-block; +} + .assistant__input-section { background: $palette-neutral-25; diff --git a/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantInputMessage.vue b/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantInputMessage.vue index 4d80f6267c..46628bbc61 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantInputMessage.vue +++ b/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantInputMessage.vue @@ -2,11 +2,11 @@
- + {{ $t('assistantInputMessage.statusWaiting') }} - - {{ getRunningMessage() }} + + {{ runningMessage || $t('assistant.statusThinking') }}
@@ -44,15 +44,6 @@