From 05fd7bba0a4e85a24b4c1b8b728c786db4d1b307 Mon Sep 17 00:00:00 2001 From: Davide Silvestri <75379892+silvestrid@users.noreply.github.com> Date: Tue, 23 Dec 2025 13:34:41 +0100 Subject: [PATCH] chore: dynamically import backend library to reduce memory footprint and startup time (#4315) * lazy imports * Fix tests * Add warning checks for libraries that should be lazy imported * Address feedback --- backend/src/baserow/config/asgi.py | 9 ++-- backend/src/baserow/config/celery.py | 13 ++++++ backend/src/baserow/config/helpers.py | 46 +++++++++++++++++++ backend/src/baserow/config/settings/base.py | 29 ++++++++++-- backend/src/baserow/config/settings/dev.py | 5 ++ backend/src/baserow/config/wsgi.py | 9 ++++ .../contrib/database/mcp/rows/tools.py | 18 +++++++- .../contrib/database/mcp/table/tools.py | 6 ++- .../contrib/integrations/ai/service_types.py | 9 ++-- .../generative_ai_model_types.py | 36 +++++++++++---- backend/src/baserow/core/mcp/__init__.py | 42 ++++++++++++----- backend/src/baserow/core/mcp/registries.py | 15 +++--- backend/src/baserow/core/mcp/sse.py | 34 +++++++++----- backend/src/baserow/core/output_parsers.py | 45 +++++++++--------- backend/src/baserow/core/posthog.py | 29 ++++++++---- backend/src/baserow/core/sentry.py | 4 +- backend/src/baserow/core/telemetry/tasks.py | 3 +- .../src/baserow/core/telemetry/telemetry.py | 44 ++++++++++-------- .../core/two_factor_auth/registries.py | 7 ++- .../src/baserow/core/user_files/handler.py | 12 +++-- backend/tests/baserow/core/test_posthog.py | 37 +++++++++++---- ...es_to_reduce_initial_memory_footprint.json | 9 ++++ docker-compose.dev.yml | 2 +- .../api/sso/saml/validators.py | 5 +- .../baserow_enterprise/assistant/telemetry.py | 15 ++++-- .../data_sync/jira_issues_data_sync.py | 3 +- .../baserow_enterprise/sso/saml/handler.py | 20 +++++--- .../assistant/test_telemetry.py | 46 ++++++++++++++----- .../docker-compose.multi-service.dev.yml | 2 +- .../src/baserow_premium/api/fields/views.py | 4 +- 
.../baserow_premium/export/exporter_types.py | 5 +- .../fields/ai_field_output_types.py | 11 +++-- .../src/baserow_premium/fields/exceptions.py | 6 +++ .../src/baserow_premium/fields/handler.py | 10 ++-- .../fields/test_ai_field_handler.py | 4 +- .../fields/test_ai_field_output_types.py | 4 +- 36 files changed, 435 insertions(+), 163 deletions(-) create mode 100644 changelog/entries/unreleased/refactor/lazy_import_libraries_to_reduce_initial_memory_footprint.json diff --git a/backend/src/baserow/config/asgi.py b/backend/src/baserow/config/asgi.py index 33f44bf8a9..0afdbb6bd8 100644 --- a/backend/src/baserow/config/asgi.py +++ b/backend/src/baserow/config/asgi.py @@ -4,8 +4,8 @@ from channels.routing import ProtocolTypeRouter, URLRouter -from baserow.config.helpers import ConcurrencyLimiterASGI -from baserow.core.mcp import baserow_mcp +from baserow.config.helpers import ConcurrencyLimiterASGI, check_lazy_loaded_libraries +from baserow.core.mcp import get_baserow_mcp_server from baserow.core.telemetry.telemetry import setup_logging, setup_telemetry from baserow.ws.routers import websocket_router @@ -18,13 +18,16 @@ # logging setup. Otherwise Django will try to destroy and log handlers we added prior. setup_logging() +# Check that libraries meant to be lazy-loaded haven't been imported at startup. +# This runs after Django is fully loaded, so it catches imports from all apps. 
+check_lazy_loaded_libraries() application = ProtocolTypeRouter( { "http": ConcurrencyLimiterASGI( URLRouter( [ - re_path(r"^mcp", baserow_mcp.sse_app()), + re_path(r"^mcp", get_baserow_mcp_server().sse_app()), re_path(r"", django_asgi_app), ] ), diff --git a/backend/src/baserow/config/celery.py b/backend/src/baserow/config/celery.py index 2ef0b6370e..eb95d46b0d 100644 --- a/backend/src/baserow/config/celery.py +++ b/backend/src/baserow/config/celery.py @@ -1,5 +1,8 @@ +from django.conf import settings + from celery import Celery, signals +from baserow.config.helpers import check_lazy_loaded_libraries from baserow.core.telemetry.tasks import BaserowTelemetryTask app = Celery("baserow") @@ -26,3 +29,13 @@ def clear_local(*args, **kwargs): signals.task_prerun.connect(clear_local) signals.task_postrun.connect(clear_local) + + +@signals.worker_process_init.connect +def on_worker_init(**kwargs): + # This is only needed in asgi.py + settings.BASEROW_LAZY_LOADED_LIBRARIES.append("mcp") + + # Check that libraries meant to be lazy-loaded haven't been imported at startup. + # This runs after Django is fully loaded, so it catches imports from all apps. + check_lazy_loaded_libraries() diff --git a/backend/src/baserow/config/helpers.py b/backend/src/baserow/config/helpers.py index 8fdce13dd1..adf35aa1ea 100644 --- a/backend/src/baserow/config/helpers.py +++ b/backend/src/baserow/config/helpers.py @@ -1,8 +1,54 @@ import asyncio +import sys + +from django.conf import settings from loguru import logger +def check_lazy_loaded_libraries(): + """ + Check if any libraries that should be lazy-loaded have been imported at startup. + + This function checks sys.modules against settings.BASEROW_LAZY_LOADED_LIBRARIES + and emits a warning if any of them have been loaded prematurely. This helps + catch accidental top-level imports that defeat the purpose of lazy loading + these heavy libraries to reduce memory footprint. + + Only runs when DEBUG is True. 
+ """ + + if not settings.DEBUG: + return + + lazy_libs = getattr(settings, "BASEROW_LAZY_LOADED_LIBRARIES", []) + loaded_early = [] + + for lib in lazy_libs: + if lib in sys.modules: + loaded_early.append(lib) + + if loaded_early: + libs_list = ", ".join(f'"{lib}"' for lib in loaded_early) + logger.warning( + f"The following libraries were loaded during startup but should be " + f"lazy-loaded to reduce memory footprint: {', '.join(loaded_early)}. " + f"Either import them inside functions/methods where they're used, or " + f"remove them from BASEROW_LAZY_LOADED_LIBRARIES if they're legitimately " + f"needed at startup. " + f"To debug, add the following code at the very top of your settings file " + f"(e.g., settings/dev.py, before any other imports):\n\n" + f"import sys, traceback\n" + f"class _T:\n" + f" def find_module(self, n, p=None):\n" + f" for lib in [{libs_list}]:\n" + f" if n == lib or n.startswith(lib + '.'):\n" + f" print(f'IMPORT: {{n}}'); traceback.print_stack(); sys.exit(1)\n" + f" return None\n" + f"sys.meta_path.insert(0, _T())\n" + ) + + class dummy_context: async def __aenter__(self): pass diff --git a/backend/src/baserow/config/settings/base.py b/backend/src/baserow/config/settings/base.py index 4b35533444..ebcadc5301 100644 --- a/backend/src/baserow/config/settings/base.py +++ b/backend/src/baserow/config/settings/base.py @@ -11,10 +11,7 @@ from django.core.exceptions import ImproperlyConfigured import dj_database_url -import sentry_sdk from corsheaders.defaults import default_headers -from sentry_sdk.integrations.django import DjangoIntegration -from sentry_sdk.scrubber import DEFAULT_DENYLIST, EventScrubber from baserow.config.settings.utils import ( Setting, @@ -1303,11 +1300,33 @@ def __setitem__(self, key, value): print(e) +# Libraries that should be lazy-loaded (imported inside functions/methods) to reduce +# memory footprint at startup. 
If any of these are found in sys.modules during startup, +# a warning will be shown suggesting to either lazy-load them or remove them from this +# list if they're legitimately needed at startup. +BASEROW_LAZY_LOADED_LIBRARIES = [ + "openai", + "anthropic", + "mistralai", + "ollama", + "langchain_core", + "jira2markdown", + "saml2", + "openpyxl", + "numpy", +] + + SENTRY_BACKEND_DSN = os.getenv("SENTRY_BACKEND_DSN") SENTRY_DSN = SENTRY_BACKEND_DSN or os.getenv("SENTRY_DSN") -SENTRY_DENYLIST = DEFAULT_DENYLIST + ["username", "email", "name"] if SENTRY_DSN: + import sentry_sdk + from sentry_sdk.integrations.django import DjangoIntegration + from sentry_sdk.scrubber import DEFAULT_DENYLIST, EventScrubber + + SENTRY_DENYLIST = DEFAULT_DENYLIST + ["username", "email", "name"] + sentry_sdk.init( dsn=SENTRY_DSN, integrations=[DjangoIntegration(signals_spans=False, middleware_spans=False)], @@ -1315,6 +1334,8 @@ def __setitem__(self, key, value): event_scrubber=EventScrubber(recursive=True, denylist=SENTRY_DENYLIST), environment=os.getenv("SENTRY_ENVIRONMENT", ""), ) +else: + BASEROW_LAZY_LOADED_LIBRARIES.append("sentry_sdk") BASEROW_OPENAI_API_KEY = os.getenv("BASEROW_OPENAI_API_KEY", None) BASEROW_OPENAI_ORGANIZATION = os.getenv("BASEROW_OPENAI_ORGANIZATION", "") or None diff --git a/backend/src/baserow/config/settings/dev.py b/backend/src/baserow/config/settings/dev.py index 359bd5ed42..c71b5ed367 100755 --- a/backend/src/baserow/config/settings/dev.py +++ b/backend/src/baserow/config/settings/dev.py @@ -20,6 +20,11 @@ INSTALLED_APPS.insert(0, "daphne") # noqa: F405 INSTALLED_APPS += ["django_extensions"] # noqa: F405 +# daphne imports numpy via autobahn -> flatbuffers, so we exclude it from the +# lazy-load check in dev mode. In production, numpy should still be lazy-loaded. 
+if "numpy" in BASEROW_LAZY_LOADED_LIBRARIES: # noqa: F405 + BASEROW_LAZY_LOADED_LIBRARIES.remove("numpy") # noqa: F405 + BASEROW_ENABLE_SILK = str_to_bool(os.getenv("BASEROW_ENABLE_SILK", "on")) if BASEROW_ENABLE_SILK: INSTALLED_APPS += ["silk"] # noqa: F405 diff --git a/backend/src/baserow/config/wsgi.py b/backend/src/baserow/config/wsgi.py index f80362267e..6eb510255c 100644 --- a/backend/src/baserow/config/wsgi.py +++ b/backend/src/baserow/config/wsgi.py @@ -7,8 +7,10 @@ https://docs.djangoproject.com/en/2.2/howto/deployment/wsgi/ """ +from django.conf import settings from django.core.wsgi import get_wsgi_application +from baserow.config.helpers import check_lazy_loaded_libraries from baserow.core.telemetry.telemetry import setup_logging, setup_telemetry # The telemetry instrumentation library setup needs to run prior to django's setup. @@ -19,3 +21,10 @@ # It is critical to setup our own logging after django has been setup and done its own # logging setup. Otherwise Django will try to destroy and log handlers we added prior. setup_logging() + +# This is only needed in asgi.py +settings.BASEROW_LAZY_LOADED_LIBRARIES.append("mcp") + +# Check that libraries meant to be lazy-loaded haven't been imported at startup. +# This runs after Django is fully loaded, so it catches imports from all apps. 
+check_lazy_loaded_libraries() diff --git a/backend/src/baserow/contrib/database/mcp/rows/tools.py b/backend/src/baserow/contrib/database/mcp/rows/tools.py index 95cba5adf3..523039e537 100644 --- a/backend/src/baserow/contrib/database/mcp/rows/tools.py +++ b/backend/src/baserow/contrib/database/mcp/rows/tools.py @@ -1,6 +1,4 @@ from asgiref.sync import sync_to_async -from mcp import Tool -from mcp.types import TextContent from rest_framework.response import Response from starlette.status import HTTP_204_NO_CONTENT @@ -23,6 +21,8 @@ class ListRowsMcpTool(MCPTool): name = "list_table_rows" async def list(self, endpoint): + from mcp import Tool + return [ Tool( name=self.name, @@ -64,6 +64,8 @@ async def call( name_parameters, call_arguments, ): + from mcp.types import TextContent + table_id = call_arguments["table_id"] if not await sync_to_async(table_in_workspace_of_endpoint)(endpoint, table_id): return [TextContent(type="text", text="Table not in endpoint workspace.")] @@ -92,6 +94,8 @@ class CreateRowMcpTool(MCPTool): name = "create_row_table_{id}" async def list(self, endpoint): + from mcp import Tool + tables = await sync_to_async(get_all_tables)(endpoint) tables = await sync_to_async(remove_table_no_permission)( endpoint, tables, CreateRowDatabaseTableOperationType @@ -127,6 +131,8 @@ async def call( name_parameters, call_arguments, ): + from mcp.types import TextContent + table_id = name_parameters["id"] if not await sync_to_async(table_in_workspace_of_endpoint)(endpoint, table_id): return [TextContent(type="text", text="Table not in endpoint workspace.")] @@ -148,6 +154,8 @@ class UpdateRowMcpTool(MCPTool): name = "update_row_table_{id}" async def list(self, endpoint): + from mcp import Tool + tables = await sync_to_async(get_all_tables)(endpoint) tables = await sync_to_async(remove_table_no_permission)( endpoint, tables, UpdateDatabaseRowOperationType @@ -187,6 +195,8 @@ async def call( name_parameters, call_arguments, ): + from mcp.types import TextContent 
+ table_id = name_parameters["id"] if not await sync_to_async(table_in_workspace_of_endpoint)(endpoint, table_id): return [TextContent(type="text", text="Table not in endpoint workspace.")] @@ -211,6 +221,8 @@ class DeleteRowMcpTool(MCPTool): name = "delete_table_row" async def list(self, endpoint): + from mcp import Tool + return [ Tool( name=self.name, @@ -241,6 +253,8 @@ async def call( name_parameters, call_arguments, ): + from mcp.types import TextContent + table_id = call_arguments["table_id"] if not await sync_to_async(table_in_workspace_of_endpoint)(endpoint, table_id): return [TextContent(type="text", text="Table not in endpoint workspace.")] diff --git a/backend/src/baserow/contrib/database/mcp/table/tools.py b/backend/src/baserow/contrib/database/mcp/table/tools.py index 8cfccc1398..17f959dac8 100644 --- a/backend/src/baserow/contrib/database/mcp/table/tools.py +++ b/backend/src/baserow/contrib/database/mcp/table/tools.py @@ -1,8 +1,6 @@ import json from asgiref.sync import sync_to_async -from mcp import Tool -from mcp.types import TextContent from baserow.contrib.database.api.tables.serializers import ( TableWithoutDataSyncSerializer, @@ -16,6 +14,8 @@ class ListTablesMcpTool(MCPTool): name = "list_tables" async def list(self, endpoint): + from mcp import Tool + return [ Tool( name=self.name, @@ -34,6 +34,8 @@ async def call( name_parameters, call_arguments, ): + from mcp.types import TextContent + tables = await sync_to_async(get_all_tables)(endpoint) serializer = TableWithoutDataSyncSerializer(tables, many=True) table_json = json.dumps(serializer.data) diff --git a/backend/src/baserow/contrib/integrations/ai/service_types.py b/backend/src/baserow/contrib/integrations/ai/service_types.py index 55873f8e76..4a64240561 100644 --- a/backend/src/baserow/contrib/integrations/ai/service_types.py +++ b/backend/src/baserow/contrib/integrations/ai/service_types.py @@ -3,8 +3,6 @@ from django.contrib.auth.models import AbstractUser -from langchain_core.exceptions 
import OutputParserException -from langchain_core.prompts import PromptTemplate from rest_framework import serializers from rest_framework.exceptions import ValidationError as DRFValidationError @@ -18,7 +16,7 @@ ) from baserow.core.generative_ai.registries import generative_ai_model_type_registry from baserow.core.integrations.handler import IntegrationHandler -from baserow.core.output_parsers import StrictEnumOutputParser +from baserow.core.output_parsers import get_strict_enum_output_parser from baserow.core.services.dispatch_context import DispatchContext from baserow.core.services.exceptions import ( ServiceImproperlyConfiguredDispatchException, @@ -170,6 +168,9 @@ def dispatch_data( resolved_values: Dict[str, Any], dispatch_context: DispatchContext, ) -> Dict[str, Any]: + from langchain_core.exceptions import OutputParserException + from langchain_core.prompts import PromptTemplate + if not service.ai_generative_ai_type: raise ServiceImproperlyConfiguredDispatchException( "The AI provider type is missing." 
@@ -228,7 +229,7 @@ def dispatch_data( choices_enum = enum.Enum( "Choices", {f"OPTION_{i}": choice for i, choice in enumerate(choices)} ) - output_parser = StrictEnumOutputParser(enum=choices_enum) + output_parser = get_strict_enum_output_parser(enum=choices_enum) format_instructions = output_parser.get_format_instructions() prompt_template = PromptTemplate( template=prompt + "\n\nGiven this user query:\n\n{format_instructions}", diff --git a/backend/src/baserow/core/generative_ai/generative_ai_model_types.py b/backend/src/baserow/core/generative_ai/generative_ai_model_types.py index 9c59cd77da..c057445489 100644 --- a/backend/src/baserow/core/generative_ai/generative_ai_model_types.py +++ b/backend/src/baserow/core/generative_ai/generative_ai_model_types.py @@ -3,15 +3,6 @@ from django.conf import settings -from anthropic import Anthropic, APIStatusError -from mistralai import Mistral -from mistralai.models import HTTPValidationError, SDKError -from ollama import Client as OllamaClient -from ollama import RequestError as OllamaRequestError -from ollama import ResponseError as OllamaResponseError -from openai import APIStatusError as OpenAIAPIStatusError -from openai import OpenAI, OpenAIError - from baserow.core.generative_ai.exceptions import AIFileError, GenerativeAIPromptError from baserow.core.generative_ai.types import FileId @@ -49,6 +40,8 @@ def is_enabled(self, workspace=None, settings_override=None): ) def get_client(self, workspace=None, settings_override=None): + from openai import OpenAI + api_key = self.get_api_key(workspace, settings_override) organization = self.get_organization(workspace, settings_override) base_url = self.get_base_url(workspace, settings_override) @@ -62,6 +55,9 @@ def get_settings_serializer(self): def prompt( self, model, prompt, workspace=None, temperature=None, settings_override=None ): + from openai import APIStatusError as OpenAIAPIStatusError + from openai import OpenAIError + try: client = self.get_client(workspace, 
settings_override) kwargs = {} @@ -118,6 +114,9 @@ def get_max_file_size(self) -> int: return min(512, settings.BASEROW_OPENAI_UPLOADED_FILE_SIZE_LIMIT_MB) def upload_file(self, file_name: str, file: bytes, workspace=None) -> FileId: + from openai import APIStatusError as OpenAIAPIStatusError + from openai import OpenAIError + try: client = self.get_client(workspace=workspace) openai_file = client.files.create( @@ -128,6 +127,9 @@ def upload_file(self, file_name: str, file: bytes, workspace=None) -> FileId: raise AIFileError(str(exc)) from exc def delete_files(self, file_ids: list[FileId], workspace=None): + from openai import APIStatusError as OpenAIAPIStatusError + from openai import OpenAIError + try: client = self.get_client(workspace=workspace) for file_id in file_ids: @@ -138,6 +140,9 @@ def delete_files(self, file_ids: list[FileId], workspace=None): def prompt_with_files( self, model, prompt, file_ids: list[FileId], workspace=None, temperature=None ): + from openai import APIStatusError as OpenAIAPIStatusError + from openai import OpenAIError + run, thread, assistant = None, None, None try: client = self.get_client(workspace) @@ -229,6 +234,8 @@ def is_enabled(self, workspace=None, settings_override=None): ) def get_client(self, workspace=None, settings_override=None): + from anthropic import Anthropic + api_key = self.get_api_key(workspace, settings_override) return Anthropic(api_key=api_key) @@ -240,6 +247,8 @@ def get_settings_serializer(self): def prompt( self, model, prompt, workspace=None, temperature=None, settings_override=None ): + from anthropic import APIStatusError + try: client = self.get_client(workspace, settings_override) kwargs = {} @@ -286,6 +295,8 @@ def is_enabled(self, workspace=None, settings_override=None): ) def get_client(self, workspace=None, settings_override=None): + from mistralai import Mistral + api_key = self.get_api_key(workspace, settings_override) return Mistral(api_key=api_key) @@ -297,6 +308,8 @@ def 
get_settings_serializer(self): def prompt( self, model, prompt, workspace=None, temperature=None, settings_override=None ): + from mistralai.models import HTTPValidationError, SDKError + try: client = self.get_client(workspace, settings_override) kwargs = {} @@ -339,12 +352,17 @@ def is_enabled(self, workspace=None, settings_override=None): ) def get_client(self, workspace=None, settings_override=None): + from ollama import Client as OllamaClient + ollama_host = self.get_host(workspace, settings_override) return OllamaClient(host=ollama_host) def prompt( self, model, prompt, workspace=None, temperature=None, settings_override=None ): + from ollama import RequestError as OllamaRequestError + from ollama import ResponseError as OllamaResponseError + client = self.get_client(workspace, settings_override) options = {} if temperature: diff --git a/backend/src/baserow/core/mcp/__init__.py b/backend/src/baserow/core/mcp/__init__.py index 3a485bb62f..72c19f8d58 100644 --- a/backend/src/baserow/core/mcp/__init__.py +++ b/backend/src/baserow/core/mcp/__init__.py @@ -1,18 +1,15 @@ import contextvars +from typing import TYPE_CHECKING from asgiref.sync import sync_to_async from loguru import logger -from mcp.server.lowlevel.server import Server -from mcp.server.lowlevel.server import lifespan as default_lifespan -from mcp.types import TextContent -from mcp.types import Tool as MCPTool -from starlette.applications import Starlette -from starlette.requests import Request -from starlette.responses import Response -from starlette.routing import Mount, Route from baserow.core.mcp.sse import DjangoChannelsSseServerTransport +if TYPE_CHECKING: + from mcp.types import Tool + from starlette.applications import Starlette + current_key: contextvars.ContextVar[str] = contextvars.ContextVar("current_key") @@ -30,6 +27,9 @@ class BaserowMCPServer: """ def __init__(self): + from mcp.server.lowlevel.server import Server + from mcp.server.lowlevel.server import lifespan as default_lifespan + 
self._mcp_server = Server( name="Baserow MCP", instructions="Handles all the actions, operations, mutations, and tools " @@ -78,6 +78,8 @@ async def get_endpoint(self): return None async def call_tool(self, name: str, arguments): + from mcp.types import TextContent + from baserow.core.mcp.registries import mcp_tool_registry endpoint = await self.get_endpoint() @@ -88,7 +90,7 @@ async def call_tool(self, name: str, arguments): return [TextContent(type="text", text=f"Tool '{name}' not found.")] return await tool.call(endpoint, name, params, arguments) - async def list_tools(self) -> list[MCPTool]: + async def list_tools(self) -> list["Tool"]: from baserow.core.mcp.registries import mcp_tool_registry endpoint = await self.get_endpoint() @@ -99,7 +101,18 @@ async def list_tools(self) -> list[MCPTool]: return [] return await mcp_tool_registry.list_all_tools(endpoint) - def sse_app(self) -> Starlette: + def sse_app(self) -> "Starlette": + """ + Returns an ASGI application that can handle MCP SSE connections. + + :return: The Starlette ASGI application for handling MCP SSE connections. 
+ """ + + from starlette.applications import Starlette + from starlette.requests import Request + from starlette.responses import Response + from starlette.routing import Mount, Route + sse_path = "/mcp/{key}/sse" messages_path = "/mcp/messages/" sse = DjangoChannelsSseServerTransport(messages_path) @@ -157,4 +170,11 @@ async def handle_sse(request: Request) -> None: ) -baserow_mcp = BaserowMCPServer() +_baserow_mcp = None + + +def get_baserow_mcp_server() -> BaserowMCPServer: + global _baserow_mcp + if _baserow_mcp is None: + _baserow_mcp = BaserowMCPServer() + return _baserow_mcp diff --git a/backend/src/baserow/core/mcp/registries.py b/backend/src/baserow/core/mcp/registries.py index fcf6d69ded..782ad2f014 100644 --- a/backend/src/baserow/core/mcp/registries.py +++ b/backend/src/baserow/core/mcp/registries.py @@ -1,12 +1,13 @@ -from typing import Any, Dict, List, Optional, Sequence, Union - -from mcp import Tool -from mcp.types import EmbeddedResource, ImageContent, TextContent +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union from baserow.core.mcp.models import MCPEndpoint from baserow.core.mcp.utils import NameRoute from baserow.core.registry import Instance, Registry +if TYPE_CHECKING: + from mcp import Tool + from mcp.types import EmbeddedResource, ImageContent, TextContent + class MCPTool(Instance): name = None @@ -23,7 +24,7 @@ def get_name(self): ) return self.name - async def list(self, endpoint: MCPEndpoint) -> List[Tool]: + async def list(self, endpoint: MCPEndpoint) -> List["Tool"]: """ :param endpoint: The endpoint related to the request. Can be used to dynamically check which tools the user has access to. 
@@ -37,7 +38,7 @@ async def call( endpoint: MCPEndpoint, name_parameters: Dict[str, Any], call_arguments: Dict[str, Any], - ) -> Sequence[TextContent | ImageContent | EmbeddedResource]: + ) -> Sequence[Union["TextContent", "ImageContent", "EmbeddedResource"]]: """ :param endpoint: The endpoint related to the authenticated user. @@ -57,7 +58,7 @@ def resolve_name(self, **kwargs): class MCPToolRegistry(Registry[MCPTool]): name = "mcp_tools" - async def list_all_tools(self, endpoint: MCPEndpoint) -> List[Tool]: + async def list_all_tools(self, endpoint: MCPEndpoint) -> List["Tool"]: """ :param endpoint: The endpoint related to the request. Can be used to dynamically check which tools the user has access to. diff --git a/backend/src/baserow/core/mcp/sse.py b/backend/src/baserow/core/mcp/sse.py index acfbd76251..96ba7c01ca 100644 --- a/backend/src/baserow/core/mcp/sse.py +++ b/backend/src/baserow/core/mcp/sse.py @@ -33,23 +33,17 @@ async def handle_sse(request): import logging from contextlib import asynccontextmanager -from typing import Any +from typing import TYPE_CHECKING, Any from urllib.parse import quote from uuid import UUID, uuid4 -import anyio -import mcp.types as types -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from channels.layers import get_channel_layer -from mcp.shared.message import SessionMessage from pydantic import ValidationError -from sse_starlette import EventSourceResponse -from starlette.requests import Request -from starlette.responses import Response -from starlette.types import Receive, Scope, Send logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from starlette.types import Receive, Scope, Send + class DjangoChannelsSseServerTransport: """ @@ -73,7 +67,17 @@ def __init__(self, endpoint: str) -> None: ) @asynccontextmanager - async def connect_sse(self, scope: Scope, receive: Receive, send: Send): + async def connect_sse(self, scope: "Scope", receive: "Receive", send: "Send"): + import anyio + 
import mcp.types as types + from anyio.streams.memory import ( + MemoryObjectReceiveStream, + MemoryObjectSendStream, + ) + from channels.layers import get_channel_layer + from mcp.shared.message import SessionMessage + from sse_starlette import EventSourceResponse + if scope["type"] != "http": logger.error("connect_sse received non-HTTP request") raise ValueError("connect_sse can only handle HTTP requests") @@ -156,8 +160,14 @@ async def response_wrapper(scope: Scope, receive: Receive, send: Send): yield (read_stream, write_stream) async def handle_post_message( - self, scope: Scope, receive: Receive, send: Send + self, scope: "Scope", receive: "Receive", send: "Send" ) -> None: + import mcp.types as types + from channels.layers import get_channel_layer + from mcp.shared.message import SessionMessage + from starlette.requests import Request + from starlette.responses import Response + logger.debug("Handling POST message") request = Request(scope, receive) diff --git a/backend/src/baserow/core/output_parsers.py b/backend/src/baserow/core/output_parsers.py index 04692feebf..720e791bbf 100644 --- a/backend/src/baserow/core/output_parsers.py +++ b/backend/src/baserow/core/output_parsers.py @@ -2,29 +2,32 @@ from difflib import get_close_matches from typing import Any -from langchain.output_parsers.enum import EnumOutputParser +def get_strict_enum_output_parser(enum: type) -> Any: + from langchain.output_parsers.enum import EnumOutputParser -class StrictEnumOutputParser(EnumOutputParser): - def get_format_instructions(self) -> str: - json_array = json.dumps(self._valid_values) - return f"""Categorize the result following these requirements: + class StrictEnumOutputParser(EnumOutputParser): + def get_format_instructions(self) -> str: + json_array = json.dumps(self._valid_values) + return f"""Categorize the result following these requirements: -- Select only one option from the JSON array below. -- Don't use quotes or commas or partial values, just the option name. 
-- Choose the option that most closely matches the row values. + - Select only one option from the JSON array below. + - Don't use quotes or commas or partial values, just the option name. + - Choose the option that most closely matches the row values. -```json -{json_array} -```""" # nosec this falsely marks as hardcoded sql expression, but it's not related - # to SQL at all. + ```json + {json_array} + ```""" # nosec this falsely marks as hardcoded sql expression, but it's not related + # to SQL at all. - def parse(self, response: str) -> Any: - response = response.strip() - # Sometimes the LLM responds with a quotes value or with part of the value if - # it contains a comma. Finding the close matches helps with selecting the - # right value. - closest_matches = get_close_matches( - response, self._valid_values, n=1, cutoff=0.0 - ) - return super().parse(closest_matches[0]) + def parse(self, response: str) -> Any: + response = response.strip() + # Sometimes the LLM responds with a quotes value or with part of the value + # if it contains a comma. Finding the close matches helps with selecting the + # right value. 
+ closest_matches = get_close_matches( + response, self._valid_values, n=1, cutoff=0.0 + ) + return super().parse(closest_matches[0]) + + return StrictEnumOutputParser(enum=enum) diff --git a/backend/src/baserow/core/posthog.py b/backend/src/baserow/core/posthog.py index 90803f5ca3..5a11500a66 100644 --- a/backend/src/baserow/core/posthog.py +++ b/backend/src/baserow/core/posthog.py @@ -7,18 +7,30 @@ from django.dispatch import receiver from loguru import logger -from posthog import Posthog from baserow.core.action.signals import ActionCommandType, action_done from baserow.core.models import Workspace from baserow.core.utils import exception_capturer -posthog_client = Posthog( - settings.POSTHOG_PROJECT_API_KEY, - settings.POSTHOG_HOST, - # disabled=True will automatically avoid sending any data, even if capture is called - disabled=not settings.POSTHOG_ENABLED, -) +_posthog = None + + +def get_posthog_client(): + """ + Returns the Posthog instance configured according to the settings. If Posthog is + disabled, the instance will have the `disabled` attribute set to True. + """ + + from posthog import Posthog + + global _posthog + if _posthog is None: + _posthog = Posthog( + settings.POSTHOG_PROJECT_API_KEY, + host=settings.POSTHOG_HOST, + disabled=not settings.POSTHOG_ENABLED, + ) + return _posthog def capture_event(distinct_id: str, event: str, properties: dict): @@ -31,7 +43,8 @@ def capture_event(distinct_id: str, event: str, properties: dict): the event. 
""" - if not settings.POSTHOG_ENABLED: + posthog_client = get_posthog_client() + if posthog_client.disabled: return try: diff --git a/backend/src/baserow/core/sentry.py b/backend/src/baserow/core/sentry.py index 932b2932e7..7b35bde147 100644 --- a/backend/src/baserow/core/sentry.py +++ b/backend/src/baserow/core/sentry.py @@ -1,7 +1,5 @@ from django.contrib.auth import get_user_model -from sentry_sdk import set_user - def setup_user_in_sentry(user): """ @@ -11,6 +9,8 @@ def setup_user_in_sentry(user): :param user: The user that needs to be set in the Sentry context. """ + from sentry_sdk import set_user + set_user({"id": user.id}) diff --git a/backend/src/baserow/core/telemetry/tasks.py b/backend/src/baserow/core/telemetry/tasks.py index 70a6ea66e4..1d503c7e30 100644 --- a/backend/src/baserow/core/telemetry/tasks.py +++ b/backend/src/baserow/core/telemetry/tasks.py @@ -1,6 +1,5 @@ from celery import Task from celery.signals import worker_process_init -from opentelemetry import baggage, context from baserow.core.telemetry.telemetry import setup_logging, setup_telemetry from baserow.core.telemetry.utils import otel_is_enabled @@ -16,6 +15,8 @@ def initialize_otel(**kwargs): class BaserowTelemetryTask(Task): def __call__(self, *args, **kwargs): + from opentelemetry import baggage, context + if otel_is_enabled(): # Safely attach and detach baggage context within the same task call curr_ctx = context.get_current() diff --git a/backend/src/baserow/core/telemetry/telemetry.py b/backend/src/baserow/core/telemetry/telemetry.py index 6e2d22cf99..9e67e6b9d0 100644 --- a/backend/src/baserow/core/telemetry/telemetry.py +++ b/backend/src/baserow/core/telemetry/telemetry.py @@ -5,31 +5,12 @@ from celery import signals from opentelemetry import metrics, trace from opentelemetry._logs import set_logger_provider -from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter -from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter 
-from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter -from opentelemetry.instrumentation.botocore import BotocoreInstrumentor -from opentelemetry.instrumentation.celery import CeleryInstrumentor -from opentelemetry.instrumentation.django import DjangoInstrumentor -from opentelemetry.instrumentation.redis import RedisInstrumentor -from opentelemetry.instrumentation.requests import RequestsInstrumentor from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler -from opentelemetry.sdk._logs._internal.export import BatchLogRecordProcessor -from opentelemetry.sdk.metrics import MeterProvider -from opentelemetry.sdk.metrics._internal.export import PeriodicExportingMetricReader -from opentelemetry.trace import ProxyTracerProvider from baserow.core.psycopg import is_psycopg3 from baserow.core.telemetry.provider import DifferentSamplerPerLibraryTracerProvider from baserow.core.telemetry.utils import BatchBaggageSpanProcessor, otel_is_enabled -if is_psycopg3: - from opentelemetry.instrumentation.psycopg import PsycopgInstrumentor -else: - from opentelemetry.instrumentation.psycopg2 import ( - Psycopg2Instrumentor as PsycopgInstrumentor, - ) - class LogGuruCompatibleLoggerHandler(LoggingHandler): def emit(self, record: logging.LogRecord) -> None: @@ -94,6 +75,14 @@ def setup_telemetry(add_django_instrumentation: bool): process that is processing requests. Don't enable this for a celery process etc. 
""" + from opentelemetry.exporter.otlp.proto.http.metric_exporter import ( + OTLPMetricExporter, + ) + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + from opentelemetry.sdk.metrics import MeterProvider + from opentelemetry.sdk.metrics._internal.export import PeriodicExportingMetricReader + from opentelemetry.trace import ProxyTracerProvider + if otel_is_enabled(): existing_provider = trace.get_tracer_provider() if not isinstance(existing_provider, ProxyTracerProvider): @@ -126,6 +115,9 @@ def setup_telemetry(add_django_instrumentation: bool): def _setup_log_exporting(logger): from django.conf import settings + from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter + from opentelemetry.sdk._logs._internal.export import BatchLogRecordProcessor + logger_provider = LoggerProvider() set_logger_provider(logger_provider) exporter = OTLPLogExporter() @@ -154,6 +146,18 @@ def count_task(sender, **kwargs): def _setup_standard_backend_instrumentation(): + from opentelemetry.instrumentation.botocore import BotocoreInstrumentor + from opentelemetry.instrumentation.celery import CeleryInstrumentor + from opentelemetry.instrumentation.redis import RedisInstrumentor + from opentelemetry.instrumentation.requests import RequestsInstrumentor + + if is_psycopg3: + from opentelemetry.instrumentation.psycopg import PsycopgInstrumentor + else: + from opentelemetry.instrumentation.psycopg2 import ( + Psycopg2Instrumentor as PsycopgInstrumentor, + ) + BotocoreInstrumentor().instrument() PsycopgInstrumentor().instrument() RedisInstrumentor().instrument() @@ -162,4 +166,6 @@ def _setup_standard_backend_instrumentation(): def _setup_django_process_instrumentation(): + from opentelemetry.instrumentation.django import DjangoInstrumentor + DjangoInstrumentor().instrument() diff --git a/backend/src/baserow/core/two_factor_auth/registries.py b/backend/src/baserow/core/two_factor_auth/registries.py index 7b9517798f..35fb939041 100644 
--- a/backend/src/baserow/core/two_factor_auth/registries.py +++ b/backend/src/baserow/core/two_factor_auth/registries.py @@ -9,8 +9,6 @@ from django.conf import settings from django.contrib.auth.models import AbstractUser -import pyotp -import qrcode from rest_framework import serializers from baserow.core.registry import ( @@ -112,6 +110,9 @@ def configure( provider: TOTPAuthProviderModel | None = None, **kwargs, ) -> TOTPAuthProviderModel: + import pyotp + import qrcode + if provider and provider.enabled: raise TwoFactorAuthAlreadyConfigured @@ -192,6 +193,8 @@ def is_enabled(self, provider) -> bool: return provider.enabled def verify(self, **kwargs) -> bool: + import pyotp + email = kwargs.get("email") code = kwargs.get("code") backup_code = kwargs.get("backup_code") diff --git a/backend/src/baserow/core/user_files/handler.py b/backend/src/baserow/core/user_files/handler.py index 3db8931292..12131d5a1e 100644 --- a/backend/src/baserow/core/user_files/handler.py +++ b/backend/src/baserow/core/user_files/handler.py @@ -6,7 +6,7 @@ import secrets from io import BytesIO from os.path import join -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from urllib.parse import urlparse from zipfile import ZipFile @@ -19,7 +19,6 @@ import advocate from advocate.exceptions import UnacceptableAddressException from loguru import logger -from PIL import Image, ImageOps from requests.exceptions import RequestException from baserow.core.import_export.utils import file_chunk_generator @@ -40,6 +39,9 @@ ) from .models import deconstruct_user_file_regex +if TYPE_CHECKING: + from PIL import Image + MIME_TYPE_UNKNOWN = "application/octet-stream" @@ -167,7 +169,7 @@ def generate_unique(self, sha256_hash, extension, length=32, max_tries=1000): def generate_and_save_image_thumbnails( self, - image: Image, + image: "Image", user_file_name: str, storage: Storage | None = None, only_with_name: str | None = None, @@ -186,6 +188,8 @@ def 
generate_and_save_image_thumbnails( :raises ValueError: If the provided user file is not a valid image. """ + from PIL import Image, ImageOps + storage = storage or get_default_storage() # adjust image orientation, if exif data differs from the image data @@ -246,6 +250,8 @@ def upload_user_file(self, user, file_name, stream, storage=None): :rtype: UserFile """ + from PIL import Image + if not hasattr(stream, "read"): raise InvalidFileStreamError("The provided stream is not readable.") diff --git a/backend/tests/baserow/core/test_posthog.py b/backend/tests/baserow/core/test_posthog.py index 7e5b4e7248..fe5b78b0ec 100644 --- a/backend/tests/baserow/core/test_posthog.py +++ b/backend/tests/baserow/core/test_posthog.py @@ -1,4 +1,4 @@ -from unittest.mock import patch +from unittest.mock import MagicMock, patch from django.test.utils import override_settings @@ -6,7 +6,20 @@ from baserow.core.action.registries import ActionType from baserow.core.action.signals import ActionCommandType -from baserow.core.posthog import capture_event_action_done, capture_user_event +from baserow.core.posthog import ( + capture_event_action_done, + capture_user_event, + get_posthog_client, +) + + +@pytest.fixture(autouse=True) +def reset_posthog_instance(): + from baserow.core import posthog + + posthog._posthog = None + yield + posthog._posthog = None class TestActionType(ActionType): @@ -26,21 +39,29 @@ def scope(cls, *args, **kwargs): @pytest.mark.django_db @override_settings(POSTHOG_ENABLED=False) -@patch("baserow.core.posthog.posthog_client") -def test_not_capture_event_if_not_enabled(mock_posthog, data_fixture): +def test_not_capture_event_if_not_enabled(data_fixture): + posthog = get_posthog_client() + + assert posthog.disabled is True + posthog.capture = MagicMock() + user = data_fixture.create_user() capture_user_event(user, "test", {}) - mock_posthog.capture.assert_not_called() + posthog.capture.assert_not_called() @pytest.mark.django_db @override_settings(POSTHOG_ENABLED=True) 
-@patch("baserow.core.posthog.posthog_client") -def test_capture_event_if_enabled(mock_posthog, data_fixture): +def test_capture_event_if_enabled(data_fixture): + posthog = get_posthog_client() + + assert posthog.disabled is False + posthog.capture = MagicMock() + user = data_fixture.create_user() workspace = data_fixture.create_workspace() capture_user_event(user, "test", {}, session="session", workspace=workspace) - mock_posthog.capture.assert_called_once_with( + posthog.capture.assert_called_once_with( distinct_id=user.id, event="test", properties={ diff --git a/changelog/entries/unreleased/refactor/lazy_import_libraries_to_reduce_initial_memory_footprint.json b/changelog/entries/unreleased/refactor/lazy_import_libraries_to_reduce_initial_memory_footprint.json new file mode 100644 index 0000000000..bcace755a6 --- /dev/null +++ b/changelog/entries/unreleased/refactor/lazy_import_libraries_to_reduce_initial_memory_footprint.json @@ -0,0 +1,9 @@ +{ + "type": "refactor", + "message": "Lazy import libraries to reduce initial memory footprint", + "issue_origin": "github", + "issue_number": null, + "domain": "core", + "bullet_points": [], + "created_at": "2025-12-12" +} \ No newline at end of file diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index a66db0347a..48ef065bbf 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -20,7 +20,7 @@ services: backend: image: baserow_backend_dev:latest environment: - - BASEROW_BACKEND_DEBUGGER_ENABLED=${BASEROW_BACKEND_DEBUGGER_ENABLED:-True} + - BASEROW_BACKEND_DEBUGGER_ENABLED - BASEROW_BACKEND_DEBUGGER_PORT=${BASEROW_BACKEND_DEBUGGER_PORT:-5678} - BASEROW_DANGEROUS_SILKY_ANALYZE_QUERIES - OTEL_EXPORTER_OTLP_ENDPOINT=http://otel-collector:4318 diff --git a/enterprise/backend/src/baserow_enterprise/api/sso/saml/validators.py b/enterprise/backend/src/baserow_enterprise/api/sso/saml/validators.py index 66f8b15228..01ea65af59 100644 --- 
a/enterprise/backend/src/baserow_enterprise/api/sso/saml/validators.py +++ b/enterprise/backend/src/baserow_enterprise/api/sso/saml/validators.py @@ -3,8 +3,6 @@ from django.db.models import QuerySet from rest_framework import serializers -from saml2.xml.schema import XMLSchemaError -from saml2.xml.schema import validate as validate_saml_metadata_schema from baserow_enterprise.sso.saml.exceptions import SamlProviderForDomainAlreadyExists from baserow_enterprise.sso.saml.models import SamlAuthProviderModel @@ -27,6 +25,9 @@ def validate_unique_saml_domain( def validate_saml_metadata(value): + from saml2.xml.schema import XMLSchemaError + from saml2.xml.schema import validate as validate_saml_metadata_schema + metadata = io.StringIO(value) try: validate_saml_metadata_schema(metadata) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/telemetry.py b/enterprise/backend/src/baserow_enterprise/assistant/telemetry.py index facea32576..9b5259092d 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/telemetry.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/telemetry.py @@ -11,10 +11,9 @@ from uuid import uuid4 import udspy -from posthog.ai.openai import AsyncOpenAI from udspy.callback import BaseCallback -from baserow.core.posthog import posthog_client +from baserow.core.posthog import get_posthog_client from baserow_enterprise.assistant.models import AssistantChat @@ -64,6 +63,8 @@ def trace(self, chat: AssistantChat, human_message: str): :param human_message: The initial user message """ + from posthog.ai.openai import AsyncOpenAI + self.chat = chat self.human_msg = human_message @@ -81,11 +82,16 @@ def trace(self, chat: AssistantChat, human_message: str): # patch the OpenAI client to automatically send the generation event lm = udspy.settings._context_lm.get() openai_client = lm.client - if not isinstance(openai_client, AsyncOpenAI): + + # Check if client is already a PostHog-wrapped client by checking its + # module. 
We avoid isinstance() here because it can fail when the class + # is mocked in tests. + is_posthog_client = "posthog" in type(openai_client).__module__ + if not is_posthog_client: lm.client = AsyncOpenAI( api_key=openai_client.api_key, base_url=openai_client.base_url, - posthog_client=posthog_client, + posthog_client=get_posthog_client(), ) exception = None @@ -130,6 +136,7 @@ def _capture_event(self, event: str, **kwargs): else: kwargs["properties"] = default_props + posthog_client = get_posthog_client() posthog_client.capture( distinct_id=str(self.user_id), event=event, diff --git a/enterprise/backend/src/baserow_enterprise/data_sync/jira_issues_data_sync.py b/enterprise/backend/src/baserow_enterprise/data_sync/jira_issues_data_sync.py index 0763e759a1..bc26b777d8 100644 --- a/enterprise/backend/src/baserow_enterprise/data_sync/jira_issues_data_sync.py +++ b/enterprise/backend/src/baserow_enterprise/data_sync/jira_issues_data_sync.py @@ -5,7 +5,6 @@ import advocate from advocate import UnacceptableAddressException from baserow_premium.license.handler import LicenseHandler -from jira2markdown import convert from requests.auth import HTTPBasicAuth from requests.exceptions import JSONDecodeError, RequestException @@ -293,6 +292,8 @@ def get_all_rows( instance, progress_builder: Optional[ChildProgressBuilder] = None, ) -> List[Dict]: + from jira2markdown import convert + issue_list = [] progress = ChildProgressBuilder.build(progress_builder, child_total=10) fetched_issues = self._fetch_issues( diff --git a/enterprise/backend/src/baserow_enterprise/sso/saml/handler.py b/enterprise/backend/src/baserow_enterprise/sso/saml/handler.py index 45a4d3649f..eb6907d6e5 100644 --- a/enterprise/backend/src/baserow_enterprise/sso/saml/handler.py +++ b/enterprise/backend/src/baserow_enterprise/sso/saml/handler.py @@ -1,6 +1,6 @@ import base64 import binascii -from typing import Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict, Optional from django.conf import 
settings from django.contrib.auth.models import AbstractUser @@ -8,10 +8,6 @@ from defusedxml import ElementTree from loguru import logger -from saml2 import BINDING_HTTP_POST, BINDING_HTTP_REDIRECT, entity -from saml2.client import Saml2Client -from saml2.config import Config as Saml2Config -from saml2.response import AuthnResponse from baserow.core.auth_provider.types import UserInfo from baserow_enterprise.api.sso.utils import get_valid_frontend_url @@ -26,6 +22,10 @@ InvalidSamlResponse, ) +if TYPE_CHECKING: + from saml2.client import Saml2Client + from saml2.response import AuthnResponse + class SamlAuthProviderHandler: model_class: Model = SamlAuthProviderModel @@ -34,7 +34,7 @@ class SamlAuthProviderHandler: def prepare_saml_client( cls, saml_auth_provider: SamlAuthProviderModelMixin, - ) -> Saml2Client: + ) -> "Saml2Client": """ Returns a SAML client with the correct configuration for the given authentication provider. @@ -44,6 +44,10 @@ def prepare_saml_client( :return: The SAML client that can be used to authenticate the user. """ + from saml2 import BINDING_HTTP_POST, BINDING_HTTP_REDIRECT + from saml2.client import Saml2Client + from saml2.config import Config as Saml2Config + acs_url = saml_auth_provider.get_type().get_acs_absolute_url( saml_auth_provider.specific ) @@ -74,7 +78,7 @@ def prepare_saml_client( return Saml2Client(config=sp_config) @classmethod - def check_authn_response_is_valid_or_raise(cls, authn_response: AuthnResponse): + def check_authn_response_is_valid_or_raise(cls, authn_response: "AuthnResponse"): """ Checks if the authn response is valid and raises an exception if not. @@ -245,6 +249,8 @@ def sign_in_user_from_saml_response( :return: The user that was signed in. 
""" + from saml2 import entity + try: saml_auth_provider = cls.get_saml_auth_provider_from_saml_response( saml_response, base_queryset=base_queryset diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_telemetry.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_telemetry.py index 0bebff3a50..387baa0dd9 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_telemetry.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_telemetry.py @@ -29,10 +29,15 @@ def mock_posthog_openai(): @pytest.mark.django_db class TestPosthogTracingCallback: - @patch("baserow_enterprise.assistant.telemetry.posthog_client") - def test_trace_context_manager_success(self, mock_posthog, assistant_chat_fixture): + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_trace_context_manager_success( + self, mock_get_client, assistant_chat_fixture + ): """Test the trace context manager in a successful execution flow.""" + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog + callback = PosthogTracingCallback() with callback.trace(assistant_chat_fixture, "Hello"): @@ -61,12 +66,15 @@ def test_trace_context_manager_success(self, mock_posthog, assistant_chat_fixtur assert props["$ai_input_state"] == {"user_message": "Hello"} assert props["$ai_output_state"] is None - @patch("baserow_enterprise.assistant.telemetry.posthog_client") + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") def test_trace_context_manager_exception( - self, mock_posthog, assistant_chat_fixture + self, mock_get_client, assistant_chat_fixture ): """Test the trace context manager when an exception occurs.""" + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog + callback = PosthogTracingCallback() with pytest.raises(ValueError): @@ -79,10 +87,13 @@ def test_trace_context_manager_exception( assert call_args.kwargs["event"] == "$ai_trace" assert 
call_args.kwargs["properties"]["$ai_is_error"] is True - @patch("baserow_enterprise.assistant.telemetry.posthog_client") - def test_on_module_start_end(self, mock_posthog, assistant_chat_fixture): + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_on_module_start_end(self, mock_get_client, assistant_chat_fixture): """Test module execution tracing.""" + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog + callback = PosthogTracingCallback() # Initialize context manually callback.chat = assistant_chat_fixture @@ -165,10 +176,13 @@ def test_on_lm_start(self, assistant_chat_fixture): assert inputs["kwargs"]["posthog_trace_id"] == "trace-1" assert inputs["kwargs"]["posthog_properties"]["$ai_provider"] == "openai" - @patch("baserow_enterprise.assistant.telemetry.posthog_client") - def test_on_tool_start_end(self, mock_posthog, assistant_chat_fixture): + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_on_tool_start_end(self, mock_get_client, assistant_chat_fixture): """Test tool execution tracing.""" + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog + callback = PosthogTracingCallback() callback.chat = assistant_chat_fixture callback.user_id = str(assistant_chat_fixture.user_id) @@ -199,10 +213,15 @@ def test_on_tool_start_end(self, mock_posthog, assistant_chat_fixture): assert props["$ai_input_state"] == {"arg": "val"} assert props["$ai_output_state"] == "result" - @patch("baserow_enterprise.assistant.telemetry.posthog_client") - def test_on_module_end_with_exception(self, mock_posthog, assistant_chat_fixture): + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_on_module_end_with_exception( + self, mock_get_client, assistant_chat_fixture + ): """Test that exception string is captured in $ai_output_state.""" + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog + callback = PosthogTracingCallback() callback.chat = 
assistant_chat_fixture callback.user_id = str(assistant_chat_fixture.user_id) @@ -239,10 +258,13 @@ def test_on_module_end_with_exception(self, mock_posthog, assistant_chat_fixture assert props["$ai_is_error"] is True assert props["$ai_output_state"] == "Test error message" - @patch("baserow_enterprise.assistant.telemetry.posthog_client") - def test_on_tool_end_with_exception(self, mock_posthog, assistant_chat_fixture): + @patch("baserow_enterprise.assistant.telemetry.get_posthog_client") + def test_on_tool_end_with_exception(self, mock_get_client, assistant_chat_fixture): """Test that exception string is captured in $ai_output_state for tools.""" + mock_posthog = MagicMock() + mock_get_client.return_value = mock_posthog + callback = PosthogTracingCallback() callback.chat = assistant_chat_fixture callback.user_id = str(assistant_chat_fixture.user_id) diff --git a/plugin-boilerplate/{{ cookiecutter.project_slug }}/docker-compose.multi-service.dev.yml b/plugin-boilerplate/{{ cookiecutter.project_slug }}/docker-compose.multi-service.dev.yml index e8c8d37022..8ea682114e 100644 --- a/plugin-boilerplate/{{ cookiecutter.project_slug }}/docker-compose.multi-service.dev.yml +++ b/plugin-boilerplate/{{ cookiecutter.project_slug }}/docker-compose.multi-service.dev.yml @@ -36,7 +36,7 @@ services: - PUBLIC_WEB_FRONTEND_URL=http://localhost:3000 - MEDIA_URL=http://localhost:4000/ - BASEROW_PUBLIC_URL= - - BASEROW_BACKEND_DEBUGGER_ENABLED=${BASEROW_BACKEND_DEBUGGER_ENABLED:-True} + - BASEROW_BACKEND_DEBUGGER_ENABLED - BASEROW_BACKEND_DEBUGGER_PORT=${BASEROW_BACKEND_DEBUGGER_PORT:-5678} ports: - "${HOST_PUBLISH_IP:-127.0.0.1}:8000:8000" diff --git a/premium/backend/src/baserow_premium/api/fields/views.py b/premium/backend/src/baserow_premium/api/fields/views.py index 9465ae40f7..41cbcdc75b 100644 --- a/premium/backend/src/baserow_premium/api/fields/views.py +++ b/premium/backend/src/baserow_premium/api/fields/views.py @@ -1,13 +1,13 @@ from django.db import transaction from 
baserow_premium.fields.actions import GenerateFormulaWithAIActionType +from baserow_premium.fields.exceptions import AiFieldOutputParserException from baserow_premium.fields.job_types import GenerateAIValuesJobType from baserow_premium.fields.models import AIField from baserow_premium.license.features import PREMIUM from baserow_premium.license.handler import LicenseHandler from drf_spectacular.openapi import OpenApiParameter, OpenApiTypes from drf_spectacular.utils import extend_schema -from langchain_core.exceptions import OutputParserException from rest_framework import status from rest_framework.permissions import IsAuthenticated from rest_framework.request import Request @@ -187,7 +187,7 @@ class GenerateFormulaWithAIView(APIView): ModelDoesNotBelongToType: ERROR_MODEL_DOES_NOT_BELONG_TO_TYPE, TableDoesNotExist: ERROR_TABLE_DOES_NOT_EXIST, GenerativeAIPromptError: ERROR_GENERATIVE_AI_PROMPT, - OutputParserException: ERROR_OUTPUT_PARSER, + AiFieldOutputParserException: ERROR_OUTPUT_PARSER, } ) @validate_body(GenerateFormulaWithAIRequestSerializer) diff --git a/premium/backend/src/baserow_premium/export/exporter_types.py b/premium/backend/src/baserow_premium/export/exporter_types.py index 257d14b404..800a6e8c63 100644 --- a/premium/backend/src/baserow_premium/export/exporter_types.py +++ b/premium/backend/src/baserow_premium/export/exporter_types.py @@ -3,8 +3,8 @@ from typing import List, Optional, Type import zipstream +from baserow_premium.license.features import PREMIUM from baserow_premium.license.handler import LicenseHandler -from openpyxl import Workbook from baserow.config.settings.base import BASEROW_DEFAULT_ZIP_COMPRESS_LEVEL from baserow.contrib.database.api.export.serializers import ( @@ -18,7 +18,6 @@ from baserow.contrib.database.views.view_types import GridViewType from baserow.core.storage import ExportZipFile, get_default_storage -from ..license.features import PREMIUM from .serializers import ExcelExporterOptionsSerializer, 
FileExporterOptionsSerializer from .utils import get_unique_name, safe_xml_tag_name, to_xml @@ -190,6 +189,8 @@ def write_to_file( Excel file. """ + from openpyxl import Workbook + workbook = Workbook(write_only=True) worksheet = workbook.create_sheet() diff --git a/premium/backend/src/baserow_premium/fields/ai_field_output_types.py b/premium/backend/src/baserow_premium/fields/ai_field_output_types.py index 98d8adccaf..d281e6872a 100644 --- a/premium/backend/src/baserow_premium/fields/ai_field_output_types.py +++ b/premium/backend/src/baserow_premium/fields/ai_field_output_types.py @@ -1,13 +1,10 @@ import enum -from langchain_core.exceptions import OutputParserException -from langchain_core.prompts import PromptTemplate - from baserow.contrib.database.fields.field_types import ( LongTextFieldType, SingleSelectFieldType, ) -from baserow.core.output_parsers import StrictEnumOutputParser +from baserow.core.output_parsers import get_strict_enum_output_parser from .registries import AIFieldOutputType @@ -29,9 +26,11 @@ def get_output_parser(self, ai_field): for option in ai_field.select_options.all() }, ) - return StrictEnumOutputParser(enum=choices) + return get_strict_enum_output_parser(enum=choices) def format_prompt(self, prompt, ai_field): + from langchain_core.prompts import PromptTemplate + output_parser = self.get_output_parser(ai_field) format_instructions = output_parser.get_format_instructions() prompt = PromptTemplate( @@ -43,6 +42,8 @@ def format_prompt(self, prompt, ai_field): return message def parse_output(self, output, ai_field): + from langchain_core.exceptions import OutputParserException + if not output: return None diff --git a/premium/backend/src/baserow_premium/fields/exceptions.py b/premium/backend/src/baserow_premium/fields/exceptions.py index 69bafa6227..8c8e923b31 100644 --- a/premium/backend/src/baserow_premium/fields/exceptions.py +++ b/premium/backend/src/baserow_premium/fields/exceptions.py @@ -3,3 +3,9 @@ class 
GenerativeAITypeDoesNotSupportFileField(Exception): Raised when file field is not supported for the particular generative AI model type. """ + + +class AiFieldOutputParserException(Exception): + """ + Raised when the output from the AI model could not be parsed correctly. + """ diff --git a/premium/backend/src/baserow_premium/fields/handler.py b/premium/backend/src/baserow_premium/fields/handler.py index 3096c514a8..f33bc8ed67 100644 --- a/premium/backend/src/baserow_premium/fields/handler.py +++ b/premium/backend/src/baserow_premium/fields/handler.py @@ -1,10 +1,8 @@ import json from typing import Optional +from baserow_premium.fields.exceptions import AiFieldOutputParserException from baserow_premium.prompts import get_generate_formula_prompt -from langchain_core.exceptions import OutputParserException -from langchain_core.output_parsers import JsonOutputParser -from langchain_core.prompts import PromptTemplate from baserow.contrib.database.fields.registries import field_type_registry from baserow.contrib.database.table.models import Table @@ -38,6 +36,10 @@ def generate_formula_with_ai( :return: The generated model. """ + from langchain_core.exceptions import OutputParserException + from langchain_core.output_parsers import JsonOutputParser + from langchain_core.prompts import PromptTemplate + generative_ai_model_type = generative_ai_model_type_registry.get(ai_type) ai_models = generative_ai_model_type.get_enabled_models( table.database.workspace @@ -72,6 +74,6 @@ def generate_formula_with_ai( try: return output_parser.parse(response)["formula"] except (OutputParserException, TypeError) as e: - raise OutputParserException( + raise AiFieldOutputParserException( "The model didn't respond with the correct output. " "Please try again." 
) from e diff --git a/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_handler.py b/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_handler.py index 0c8fe06427..101e1f3fce 100644 --- a/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_handler.py +++ b/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_handler.py @@ -1,8 +1,8 @@ from unittest.mock import patch import pytest +from baserow_premium.fields.exceptions import AiFieldOutputParserException from baserow_premium.fields.handler import AIFieldHandler -from langchain_core.exceptions import OutputParserException from baserow.core.generative_ai.exceptions import ( GenerativeAITypeDoesNotExist, @@ -78,7 +78,7 @@ def test_generate_formula_output_parser_error(premium_data_fixture, api_client): ) table = premium_data_fixture.create_database_table(name="table", database=database) - with pytest.raises(OutputParserException): + with pytest.raises(AiFieldOutputParserException): AIFieldHandler.generate_formula_with_ai( table, "test_generative_ai", diff --git a/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_output_types.py b/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_output_types.py index 2b9b09ab66..00b5aab6a4 100644 --- a/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_output_types.py +++ b/premium/backend/tests/baserow_premium_tests/fields/test_ai_field_output_types.py @@ -1,7 +1,7 @@ import enum import pytest -from baserow_premium.fields.ai_field_output_types import StrictEnumOutputParser +from baserow_premium.fields.ai_field_output_types import get_strict_enum_output_parser from langchain_core.prompts import PromptTemplate from baserow.core.generative_ai.registries import ( @@ -21,7 +21,7 @@ def test_strict_enum_output_parser(): "OPTION_4": "A,B,C", }, ) - output_parser = StrictEnumOutputParser(enum=choices) + output_parser = get_strict_enum_output_parser(enum=choices) format_instructions = 
output_parser.get_format_instructions() prompt = "What is a motorcycle?" prompt = PromptTemplate(