diff --git a/backend/requirements/base.in b/backend/requirements/base.in index 8ae549aae6..f4d8666ab5 100644 --- a/backend/requirements/base.in +++ b/backend/requirements/base.in @@ -64,7 +64,7 @@ Brotli==1.1.0 loguru==0.7.2 django-cachalot==2.6.2 celery-singleton==0.3.1 -posthog==3.5.0 +posthog==7.0.1 https://github.com/fellowapp/prosemirror-py/archive/refs/tags/v0.3.5.zip rich==13.7.1 tzdata==2025.2 @@ -88,4 +88,4 @@ httpcore==1.0.9 # Pinned to address vulnerability. genson==1.3.0 pyotp==2.9.0 qrcode==8.2 -udspy==0.1.7 +udspy==0.1.8 diff --git a/backend/requirements/base.txt b/backend/requirements/base.txt index f93a1d83df..2b0e4ac2a3 100644 --- a/backend/requirements/base.txt +++ b/backend/requirements/base.txt @@ -125,6 +125,7 @@ distro==1.9.0 # via # anthropic # openai + # posthog dj-database-url==2.1.0 # via -r base.in django==5.0.14 @@ -311,8 +312,6 @@ mdurl==0.1.2 # via markdown-it-py mistralai==1.1.0 # via -r base.in -monotonic==1.6 - # via posthog msgpack==1.1.0 # via channels-redis mypy-extensions==1.0.0 @@ -454,7 +453,7 @@ pgvector==0.4.1 # via -r base.in pillow==10.3.0 # via -r base.in -posthog==3.5.0 +posthog==7.0.1 # via -r base.in prometheus-client==0.21.1 # via flower @@ -665,6 +664,7 @@ typing-extensions==4.11.0 # opentelemetry-exporter-otlp-proto-http # opentelemetry-sdk # opentelemetry-semantic-conventions + # posthog # prosemirror # pydantic # pydantic-core @@ -679,7 +679,7 @@ tzdata==2025.2 # -r base.in # django-celery-beat # kombu -udspy==0.1.7 +udspy==0.1.8 # via -r base.in unicodecsv==0.14.1 # via -r base.in diff --git a/backend/src/baserow/config/settings/base.py b/backend/src/baserow/config/settings/base.py index 40d402d164..fc7c8fba66 100644 --- a/backend/src/baserow/config/settings/base.py +++ b/backend/src/baserow/config/settings/base.py @@ -11,7 +11,6 @@ from django.core.exceptions import ImproperlyConfigured import dj_database_url -import posthog import sentry_sdk from corsheaders.defaults import default_headers from sentry_sdk.integrations.django import DjangoIntegration @@ -1243,13 +1242,8 @@ def __setitem__(self, key, value): ) POSTHOG_PROJECT_API_KEY = os.getenv("POSTHOG_PROJECT_API_KEY", "") -POSTHOG_HOST = os.getenv("POSTHOG_HOST", "") -POSTHOG_ENABLED = POSTHOG_PROJECT_API_KEY and POSTHOG_HOST -if POSTHOG_ENABLED: - posthog.project_api_key = POSTHOG_PROJECT_API_KEY - posthog.host = POSTHOG_HOST -else: - posthog.disabled = True +POSTHOG_HOST = os.getenv("POSTHOG_HOST") or None +POSTHOG_ENABLED = bool(POSTHOG_PROJECT_API_KEY) BASEROW_BUILDER_DOMAINS = os.getenv("BASEROW_BUILDER_DOMAINS", None) BASEROW_BUILDER_DOMAINS = ( diff --git a/backend/src/baserow/core/posthog.py b/backend/src/baserow/core/posthog.py index 42f7ea5c80..90803f5ca3 100644 --- a/backend/src/baserow/core/posthog.py +++ b/backend/src/baserow/core/posthog.py @@ -6,13 +6,20 @@ from django.contrib.auth.models import AbstractUser from django.dispatch import receiver -import posthog from loguru import logger +from posthog import Posthog from baserow.core.action.signals import ActionCommandType, action_done from baserow.core.models import Workspace from baserow.core.utils import exception_capturer +posthog_client = Posthog( + settings.POSTHOG_PROJECT_API_KEY, + settings.POSTHOG_HOST, + # disabled=True will automatically avoid sending any data, even if capture is called + disabled=not settings.POSTHOG_ENABLED, +) + def capture_event(distinct_id: str, event: str, properties: dict): """ @@ -28,7 +35,7 @@ def capture_event(distinct_id: str, event: str, properties: dict): return try: - posthog.capture( + posthog_client.capture( distinct_id=distinct_id, event=event, properties=properties, diff --git a/backend/tests/baserow/core/test_posthog.py b/backend/tests/baserow/core/test_posthog.py index cbbfbebd7a..7e5b4e7248 100644 --- a/backend/tests/baserow/core/test_posthog.py +++ b/backend/tests/baserow/core/test_posthog.py @@ -26,7 +26,7 @@ def scope(cls, *args, **kwargs): @pytest.mark.django_db @override_settings(POSTHOG_ENABLED=False) -@patch("baserow.core.posthog.posthog") +@patch("baserow.core.posthog.posthog_client") def test_not_capture_event_if_not_enabled(mock_posthog, data_fixture): user = data_fixture.create_user() capture_user_event(user, "test", {}) @@ -35,7 +35,7 @@ def test_not_capture_event_if_not_enabled(mock_posthog, data_fixture): @pytest.mark.django_db @override_settings(POSTHOG_ENABLED=True) -@patch("baserow.core.posthog.posthog") +@patch("baserow.core.posthog.posthog_client") def test_capture_event_if_enabled(mock_posthog, data_fixture): user = data_fixture.create_user() workspace = data_fixture.create_workspace() diff --git a/enterprise/backend/src/baserow_enterprise/assistant/assistant.py b/enterprise/backend/src/baserow_enterprise/assistant/assistant.py index 3b3bb99b5c..4c7ce85821 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/assistant.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/assistant.py @@ -15,6 +15,7 @@ AssistantMessageCancelled, AssistantModelNotSupportedError, ) +from baserow_enterprise.assistant.telemetry import PosthogTracingCallback from baserow_enterprise.assistant.tools.navigation.types import AnyNavigationRequestType from baserow_enterprise.assistant.tools.navigation.utils import unsafe_navigate_to from baserow_enterprise.assistant.tools.registries import assistant_tool_registry @@ -190,22 +191,42 @@ def _init_assistant(self): self._user, self._workspace, self.tool_helpers ) ] - self.callbacks = AssistantCallbacks(self.tool_helpers) + + self._assistant_callbacks = AssistantCallbacks(self.tool_helpers) + self._telemetry_callbacks = PosthogTracingCallback() + self._callbacks = [self._assistant_callbacks, self._telemetry_callbacks] module_kwargs = { "temperature": settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_TEMPERATURE, "response_format": {"type": "json_object"}, } - - self.search_user_docs_tool = next( - (tool for tool in tools if tool.name == "search_user_docs"), None - ) + self.search_user_docs_tool = self._get_search_user_docs_tool(tools) self.agent_tools = tools self._request_router = udspy.ChainOfThought(RequestRouter, **module_kwargs) self._assistant = udspy.ReAct( ChatSignature, tools=self.agent_tools, max_iters=20, **module_kwargs ) + def _get_search_user_docs_tool( + self, tools: list[udspy.Tool | Callable] + ) -> udspy.Tool | None: + """ + Retrieves the search_user_docs tool from the list of tools if available. + + :param tools: The list of tools to search through. + :return: The search_user_docs as udspy.Tool or None if not found. + """ + + search_user_docs_tool = next( + (tool for tool in tools if tool.name == "search_user_docs"), None + ) + if search_user_docs_tool is None or isinstance( + search_user_docs_tool, udspy.Tool + ): + return search_user_docs_tool + + return udspy.Tool(search_user_docs_tool) + async def acreate_chat_message( self, role: AssistantChatMessage.Role, @@ -360,7 +381,7 @@ async def _acreate_ai_message_response( :return: The created AiMessage instance to return to the user. """ - sources = self.callbacks.sources + sources = self._assistant_callbacks.sources ai_msg = await self.acreate_chat_message( AssistantChatMessage.Role.AI, prediction.answer, @@ -449,7 +470,7 @@ async def _process_router_stream( messages.append( AiMessageChunk( content=event.content, - sources=self.callbacks.sources, + sources=self._assistant_callbacks.sources, ) ) @@ -472,7 +493,6 @@ async def _process_router_stream( "the local knowledge base. \n\n" "You can find setup instructions at: https://baserow.io/user-docs" ), - sources=[], ) ) elif getattr(event, "answer", None): @@ -510,7 +530,7 @@ async def _process_agent_stream( messages.append( AiMessageChunk( content=event.content, - sources=self.callbacks.sources, + sources=self._assistant_callbacks.sources, ) ) @@ -586,11 +606,12 @@ async def astream_messages( AssistantChatMessage.Role.HUMAN, message.content, ) + default_callbacks = udspy.settings.callbacks with udspy.settings.context( lm=self._lm_client, - callbacks=[*udspy.settings.callbacks, self.callbacks], - ): + callbacks=[*default_callbacks, *self._callbacks], + ), self._telemetry_callbacks.trace(self._chat, human_msg.content): message_id = str(human_msg.id) yield AiStartedMessage(message_id=message_id) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/telemetry.py b/enterprise/backend/src/baserow_enterprise/assistant/telemetry.py new file mode 100644 index 0000000000..facea32576 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/telemetry.py @@ -0,0 +1,294 @@ +""" +Posthog telemetry integration for the Baserow Assistant. + +This module provides tracing callbacks that capture DSPy execution flows +and send structured events to Posthog for LLM analytics. +""" + +from contextlib import contextmanager +from datetime import datetime, timezone +from typing import Any +from uuid import uuid4 + +import udspy +from posthog.ai.openai import AsyncOpenAI +from udspy.callback import BaseCallback + +from baserow.core.posthog import posthog_client +from baserow_enterprise.assistant.models import AssistantChat + + +def _utc_now() -> datetime: + return datetime.now(tz=timezone.utc) + + +def _uuid() -> str: + return str(uuid4()) + + +class PosthogTracingCallback(BaseCallback): + """ + Captures uDSPy execution traces and sends events to Posthog. + + This callback tracks: + - uDSPy module execution (ChainOfThought, ReAct, Predict) + - LLM API calls (OpenAI, Groq, etc.) + - Tool invocations + - Performance metrics and token usage + + Each instance is created per Assistant call with trace context, so + multiple concurrent traces can be captured independently. + """ + + def __init__(self): + super().__init__() + + self.chat: AssistantChat | None = None + self.human_msg: str | None = None + self.trace_id: str | None = None + self.span_id: str | None = None + self.user_id: str | None = None + self.workspace_id: str | None = None + self.chat_uuid: str | None = None + self.spans: dict[str, dict] = {} + self.span_ids: list[str] = [] + + @contextmanager + def trace(self, chat: AssistantChat, human_message: str): + """ + Context manager for tracing an assistant execution. + Initializes trace context and captures the overall trace event. + It also patches the OpenAI client to auto-capture generation events. + + :param chat: The AssistantChat instance + :param human_message: The initial user message + """ + + self.chat = chat + self.human_msg = human_message + + self.trace_id = _uuid() + self.span_id = _uuid() + self.user_id = str(chat.user_id) + self.workspace_id = str(chat.workspace_id) + self.chat_uuid = str(chat.uuid) + + start_time = _utc_now() + self.spans = {} + self.span_ids = [self.span_id] + self.trace_outputs = None + + # patch the OpenAI client to automatically send the generation event + lm = udspy.settings._context_lm.get() + openai_client = lm.client + if not isinstance(openai_client, AsyncOpenAI): + lm.client = AsyncOpenAI( + api_key=openai_client.api_key, + base_url=openai_client.base_url, + posthog_client=posthog_client, + ) + + exception = None + try: + yield self + except Exception as exc: + exception = exc + raise + finally: + # Stop trace + self._capture_event( + "$ai_trace", + timestamp=start_time, + properties={ + "$ai_session_id": chat.uuid, + "$ai_span_name": f"{self.user_id}: {human_message[:20]}", + "$ai_span_id": self.span_id, + "$ai_latency": (_utc_now() - start_time).total_seconds(), + "$ai_is_error": exception is not None, + "$ai_input_state": {"user_message": human_message}, + "$ai_output_state": self.trace_outputs + if exception is None + else str(exception), + }, + ) + + def _capture_event(self, event: str, **kwargs): + """ + Capture a Posthog event if Posthog is enabled. + + :param event: Event name (e.g., "$ai_generation") + :param properties: Event properties dictionary + """ + + default_props = { + "$ai_trace_id": self.trace_id, + "$ai_session_id": self.chat_uuid, + "workspace_id": self.workspace_id, + } + if "properties" in kwargs: + kwargs["properties"].update(default_props) + else: + kwargs["properties"] = default_props + + posthog_client.capture( + distinct_id=str(self.user_id), + event=event, + **kwargs, + ) # noqa: W505 + + def on_module_start(self, call_id: str, instance: Any, inputs: dict): + """ + Track the start of a DSPy module execution. + + Captures ChainOfThought, ReAct, Predict, and other module types. + + :param call_id: Unique identifier for this call + :param instance: The DSPy module instance + :param inputs: Input dictionary passed to the module + """ + + module_type = instance.__class__.__name__ + parent_span_id = self.span_ids[-1] if self.span_ids else None + span_id = call_id + self.span_ids.append(span_id) + span = { + "start_time": _utc_now(), + "properties": { + "$ai_span_name": module_type, + "$ai_span_id": span_id, + "$ai_parent_span_id": parent_span_id, + }, + } + self.spans[span_id] = span + + def _update_span_with_signature_data(signature): + adapter = udspy.ChatAdapter() + input_fields = ", ".join(signature.get_input_fields().keys()) + output_fields = ", ".join(signature.get_output_fields()) + span["properties"]["$ai_input_state"] = { + "signature": f"{input_fields} -> {output_fields}", + "instructions": adapter.format_instructions(signature), + **inputs["kwargs"], + } + + if isinstance(instance, (udspy.Predict, udspy.ReAct)): + _update_span_with_signature_data(instance.signature) + elif isinstance(instance, udspy.ChainOfThought): + _update_span_with_signature_data(instance.original_signature) + + def on_module_end(self, call_id: str, outputs: Any, exception: Exception | None): + """ + Remove the span from the stack together with all the started $ai_generation + spans appended in `on_lm_start` + + Args: + call_id: Unique identifier for this call + outputs: Module output (if successful) + exception: Exception raised (if failed) + """ + + while (span_id := self.span_ids.pop()) != call_id: + continue + + span = self.spans.pop(span_id) + start_time = span.pop("start_time") + span["properties"].update( + { + "$ai_latency": (_utc_now() - start_time).total_seconds(), + "$ai_is_error": exception is not None, + "$ai_output_state": outputs if exception is None else str(exception), + } + ) + + if isinstance(outputs, dict) and "answer" in outputs: + self.trace_outputs = { + k: v + for k, v in outputs.items() + if k not in ["module", "native_tool_calls"] + } + + self._capture_event("$ai_span", timestamp=start_time, **span) + + def on_lm_start(self, call_id: str, instance: Any, inputs: dict): + """ + Only enrich posthog properties that will be sent automatically + by the patched openai client. + Add the span_id to the stack so any tool call will be shown + as a child span. + + Args: + call_id: Unique identifier for this call + instance: The LM instance + inputs: API call parameters (model, messages, temperature, etc.) + """ + + parent_span_id = self.span_ids[-1] if self.span_ids else None + kwargs = inputs["kwargs"] + span_id = call_id + self.span_ids.append(span_id) + kwargs["posthog_distinct_id"] = self.user_id + kwargs["posthog_trace_id"] = self.trace_id + kwargs["posthog_properties"] = { + "$ai_session_id": self.chat_uuid, + "$ai_parent_span_id": parent_span_id, + "$ai_span_id": span_id, + "workspace_id": self.workspace_id, + "$ai_provider": instance.provider, + } + + def on_lm_end(self, call_id: str, outputs: Any, exception: Exception | None): + """ + Automatically tracked by the patched openai client. + + :param call_id: Unique identifier for this call + :param outputs: LLM response object + :param exception: Exception raised (if failed) + """ + + pass + + def on_tool_start(self, call_id: str, instance: Any, inputs: dict): + """ + Track the start of a tool invocation. + + Args: + call_id: Unique identifier for this call + instance: The tool instance + inputs: Tool input parameters + """ + + tool_name = getattr(instance, "name", instance.__class__.__name__) + + span_id = call_id + parent_span_id = self.span_ids[-1] if self.span_ids else None + self.spans[span_id] = { + "start_time": _utc_now(), + "properties": { + "$ai_span_name": f"Tool: {tool_name}", + "$ai_span_id": span_id, + "$ai_parent_span_id": parent_span_id, + "$ai_input_state": inputs, + }, + } + + def on_tool_end(self, call_id: str, outputs: Any, exception: Exception | None): + """ + Track the completion of a tool invocation. + + Args: + call_id: Unique identifier for this call + outputs: Tool output + exception: Exception raised (if failed) + """ + + span_id = call_id + span = self.spans.pop(span_id) + start_time = span.pop("start_time") + span["properties"].update( + { + "$ai_latency": (_utc_now() - start_time).total_seconds(), + "$ai_is_error": exception is not None, + "$ai_output_state": outputs if exception is None else str(exception), + } + ) + self._capture_event("$ai_span", timestamp=start_time, **span) diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py index 7da32e04a5..03a8d77c1f 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant.py @@ -34,6 +34,15 @@ ) +@pytest.fixture(autouse=True) +def mock_posthog_openai(): + with patch("posthog.ai.openai.AsyncOpenAI") as mock: + # Configure the mock if needed + mock.return_value = MagicMock() + mock.return_value.model = "test-model" + yield mock + + @pytest.mark.django_db class TestAssistantCallbacks: """Test the AssistantCallbacks class for handling tool execution""" @@ -296,9 +305,8 @@ def test_aload_chat_history_handles_incomplete_pairs(self, enterprise_data_fixtu } @patch("udspy.ReAct.astream") - @patch("udspy.LM") def test_history_is_passed_to_astream_as_context( - self, mock_lm, mock_react_astream, enterprise_data_fixture + self, mock_react_astream, enterprise_data_fixture ): """ Test that chat history is loaded correctly and passed to the agent as context @@ -372,7 +380,6 @@ async def _stream(): return _stream() mock_react_astream.side_effect = mock_agent_stream_factory - mock_lm.return_value.model = "test-model" message = HumanMessage(content="How to add a view?") @@ -390,9 +397,8 @@ class TestAssistantMessagePersistence: @patch("udspy.ChainOfThought.astream") @patch("udspy.ReAct.astream") - @patch("udspy.LM") def test_astream_messages_persists_human_message( - self, mock_lm, mock_react_astream, mock_cot_astream, enterprise_data_fixture + self, mock_react_astream, mock_cot_astream, enterprise_data_fixture ): """Test that human messages are persisted to database before streaming""" @@ -426,9 +432,6 @@ async def mock_agent_stream(*args, **kwargs): mock_react_astream.return_value = mock_agent_stream() - # Configure mock LM to return a serializable model name - mock_lm.return_value.model = "test-model" - assistant = Assistant(chat) ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), @@ -456,9 +459,8 @@ async def consume_stream(): @patch("udspy.ChainOfThought.astream") @patch("udspy.ReAct.astream") - @patch("udspy.LM") def test_astream_messages_persists_ai_message_with_sources( - self, mock_lm, mock_react_astream, mock_cot_astream, enterprise_data_fixture + self, mock_react_astream, mock_cot_astream, enterprise_data_fixture ): """Test that AI messages are persisted with sources in artifacts""" @@ -468,9 +470,6 @@ def test_astream_messages_persists_ai_message_with_sources( user=user, workspace=workspace, title="Test Chat" ) - # Configure mock LM to return a serializable model name - mock_lm.return_value.model = "test-model" - assistant = Assistant(chat) # Mock the router stream @@ -599,9 +598,8 @@ class TestAssistantStreaming: @patch("udspy.ChainOfThought.astream") @patch("udspy.ReAct.astream") - @patch("udspy.LM") def test_astream_messages_yields_answer_chunks( - self, mock_lm, mock_react_astream, mock_cot_astream, enterprise_data_fixture + self, mock_react_astream, mock_cot_astream, enterprise_data_fixture ): """Test that answer chunks are yielded during streaming""" @@ -643,9 +641,6 @@ async def mock_agent_stream(*args, **kwargs): mock_react_astream.return_value = mock_agent_stream() - # Configure mock LM to return a serializable model name - mock_lm.return_value.model = "test-model" - async def consume_stream(): chunks = [] human_message = HumanMessage(content="Test") @@ -733,9 +728,8 @@ async def consume_stream(): @patch("udspy.ChainOfThought.astream") @patch("udspy.ReAct.astream") - @patch("udspy.LM") def test_astream_messages_yields_thinking_messages( - self, mock_lm, mock_react_astream, mock_cot_astream, enterprise_data_fixture + self, mock_react_astream, mock_cot_astream, enterprise_data_fixture ): """Test that thinking messages from tools are yielded""" @@ -771,9 +765,6 @@ async def mock_agent_stream(*args, **kwargs): mock_react_astream.return_value = mock_agent_stream() - # Configure mock LM to return a serializable model name - mock_lm.return_value.model = "test-model" - ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), user=UserUIContext(id=user.id, name=user.first_name, email=user.email), @@ -1056,9 +1047,8 @@ def test_check_cancellation_does_nothing_when_no_flag( @patch("udspy.ChainOfThought.astream") @patch("udspy.ReAct.astream") - @patch("udspy.LM") def test_astream_messages_yields_ai_started_message( - self, mock_lm, mock_react_astream, mock_cot_astream, enterprise_data_fixture + self, mock_react_astream, mock_cot_astream, enterprise_data_fixture ): """Test that astream_messages yields AiStartedMessage at the beginning""" @@ -1090,7 +1080,6 @@ async def mock_agent_stream(*args, **kwargs): yield Prediction(answer="Hello there!", trajectory=[], reasoning="") mock_react_astream.return_value = mock_agent_stream() - mock_lm.return_value.model = "test-model" assistant = Assistant(chat) human_message = HumanMessage(content="Hello") @@ -1111,9 +1100,8 @@ async def collect_messages(): @patch("udspy.ChainOfThought.astream") @patch("udspy.ReAct.astream") - @patch("udspy.LM") def test_astream_messages_checks_cancellation_periodically( - self, mock_lm, mock_react_astream, mock_cot_astream, enterprise_data_fixture + self, mock_react_astream, mock_cot_astream, enterprise_data_fixture ): """Test that astream_messages checks for cancellation every 10 chunks""" @@ -1147,7 +1135,6 @@ async def mock_agent_stream(*args, **kwargs): yield Prediction(answer="Complete response", trajectory=[], reasoning="") mock_react_astream.return_value = mock_agent_stream() - mock_lm.return_value.model = "test-model" assistant = Assistant(chat) cache_key = assistant._get_cancellation_cache_key() diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_telemetry.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_telemetry.py new file mode 100644 index 0000000000..0bebff3a50 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_telemetry.py @@ -0,0 +1,274 @@ +from unittest.mock import MagicMock, patch + +import pytest +import udspy + +from baserow_enterprise.assistant.models import AssistantChat +from baserow_enterprise.assistant.telemetry import PosthogTracingCallback + + +@pytest.fixture +def assistant_chat_fixture(enterprise_data_fixture): + user = enterprise_data_fixture.create_user() + workspace = enterprise_data_fixture.create_workspace(user=user) + return AssistantChat.objects.create( + user=user, workspace=workspace, title="Test Chat" + ) + + +@pytest.fixture(autouse=True) +def mock_posthog_openai(): + with udspy.settings.context(lm=udspy.LM(model="fake-model")), patch( + "posthog.ai.openai.AsyncOpenAI" + ) as mock: + # Configure the mock if needed + mock.return_value = MagicMock() + mock.return_value.model = "test-model" + yield mock + + +@pytest.mark.django_db +class TestPosthogTracingCallback: + @patch("baserow_enterprise.assistant.telemetry.posthog_client") + def test_trace_context_manager_success(self, mock_posthog, assistant_chat_fixture): + """Test the trace context manager in a successful execution flow.""" + + callback = PosthogTracingCallback() + + with callback.trace(assistant_chat_fixture, "Hello"): + assert callback.trace_id is not None + assert callback.span_id is not None + assert callback.user_id == str(assistant_chat_fixture.user_id) + + # Verify trace event captured + mock_posthog.capture.assert_called_once() + call_args = mock_posthog.capture.call_args + + # Check the event structure + assert call_args.kwargs["distinct_id"] == str(assistant_chat_fixture.user_id) + assert call_args.kwargs["event"] == "$ai_trace" + assert "timestamp" in call_args.kwargs + + # Check properties + props = call_args.kwargs["properties"] + assert props["$ai_trace_id"] == callback.trace_id + assert props["$ai_session_id"] == str(assistant_chat_fixture.uuid) + assert props["workspace_id"] == str(assistant_chat_fixture.workspace_id) + assert props["$ai_span_name"] == f"{assistant_chat_fixture.user_id}: Hello" + assert props["$ai_span_id"] == callback.span_id + assert props["$ai_latency"] >= 0 + assert props["$ai_is_error"] is False + assert props["$ai_input_state"] == {"user_message": "Hello"} + assert props["$ai_output_state"] is None + + @patch("baserow_enterprise.assistant.telemetry.posthog_client") + def test_trace_context_manager_exception( + self, mock_posthog, assistant_chat_fixture + ): + """Test the trace context manager when an exception occurs.""" + + callback = PosthogTracingCallback() + + with pytest.raises(ValueError): + with callback.trace(assistant_chat_fixture, "Hello"): + raise ValueError("Test error") + + # Verify trace event captured with error + call_args = mock_posthog.capture.call_args + assert call_args is not None + assert call_args.kwargs["event"] == "$ai_trace" + assert call_args.kwargs["properties"]["$ai_is_error"] is True + + @patch("baserow_enterprise.assistant.telemetry.posthog_client") + def test_on_module_start_end(self, mock_posthog, assistant_chat_fixture): + """Test module execution tracing.""" + + callback = PosthogTracingCallback() + # Initialize context manually + callback.chat = assistant_chat_fixture + callback.user_id = str(assistant_chat_fixture.user_id) + callback.workspace_id = str(assistant_chat_fixture.workspace_id) + callback.chat_uuid = str(assistant_chat_fixture.uuid) + callback.trace_id = "trace-123" + callback.span_ids = ["root-span"] + callback.spans = {} + callback.enabled = True + + # Mock a CoT module + mock_module = MagicMock(spec=udspy.ChainOfThought) + mock_module.__class__ = udspy.ChainOfThought + mock_signature = MagicMock() + mock_signature.get_input_fields.return_value = {"q": 1} + mock_signature.get_output_fields.return_value = { + "a": 1 + } # Should be dict, not list + mock_signature.get_instructions.return_value = "Test instructions" + mock_module.original_signature = mock_signature + + # Start module + callback.on_module_start( + call_id="call-1", instance=mock_module, inputs={"kwargs": {"q": "test"}} + ) + + assert len(callback.span_ids) == 2 + assert len(callback.spans) == 1 + + # End module + callback.on_module_end( + call_id="call-1", outputs={"a": "result"}, exception=None + ) + + assert len(callback.span_ids) == 1 + assert len(callback.spans) == 0 + + # Verify span event was called + mock_posthog.capture.assert_called_once() + call_args = mock_posthog.capture.call_args + + # Check the event structure + assert call_args.kwargs["distinct_id"] == str(assistant_chat_fixture.user_id) + assert call_args.kwargs["event"] == "$ai_span" + assert "timestamp" in call_args.kwargs + + # Check properties + props = call_args.kwargs["properties"] + assert props["$ai_trace_id"] == "trace-123" + assert props["$ai_session_id"] == str(assistant_chat_fixture.uuid) + assert props["workspace_id"] == str(assistant_chat_fixture.workspace_id) + assert props["$ai_span_name"] == "ChainOfThought" + assert props["$ai_span_id"] == "call-1" + assert props["$ai_parent_span_id"] == "root-span" + assert "$ai_input_state" in props + assert props["$ai_output_state"] == {"a": "result"} + assert props["$ai_latency"] >= 0 + assert props["$ai_is_error"] is False + + def test_on_lm_start(self, assistant_chat_fixture): + """Test LM start tracing.""" + + callback = PosthogTracingCallback() + callback.chat = assistant_chat_fixture + callback.user_id = "user-1" + callback.workspace_id = "ws-1" + callback.chat_uuid = "chat-1" + callback.trace_id = "trace-1" + callback.span_ids = ["root"] + + mock_lm = MagicMock() + mock_lm.provider = "openai" + + inputs = {"kwargs": {}} + callback.on_lm_start("call-1", mock_lm, inputs) + + assert len(callback.span_ids) == 2 + assert inputs["kwargs"]["posthog_distinct_id"] == "user-1" + assert inputs["kwargs"]["posthog_trace_id"] == "trace-1" + assert inputs["kwargs"]["posthog_properties"]["$ai_provider"] == "openai" + + @patch("baserow_enterprise.assistant.telemetry.posthog_client") + def test_on_tool_start_end(self, mock_posthog, assistant_chat_fixture): + """Test tool execution tracing.""" + + callback = PosthogTracingCallback() + callback.chat = assistant_chat_fixture + callback.user_id = str(assistant_chat_fixture.user_id) + callback.workspace_id = str(assistant_chat_fixture.workspace_id) + callback.chat_uuid = str(assistant_chat_fixture.uuid) + callback.trace_id = "trace-1" + callback.span_ids = ["root"] + callback.spans = {} + callback.enabled = True + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + + # Start tool + callback.on_tool_start( + call_id="call-1", instance=mock_tool, inputs={"arg": "val"} + ) + + assert len(callback.spans) == 1 + + # End tool + callback.on_tool_end(call_id="call-1", outputs="result", exception=None) + + # Verify event + mock_posthog.capture.assert_called() + props = mock_posthog.capture.call_args.kwargs["properties"] + assert props["$ai_span_name"] == "Tool: test_tool" + assert props["$ai_input_state"] == {"arg": "val"} + assert props["$ai_output_state"] == "result" + + @patch("baserow_enterprise.assistant.telemetry.posthog_client") + def test_on_module_end_with_exception(self, mock_posthog, assistant_chat_fixture): + """Test that exception string is captured in $ai_output_state.""" + + callback = PosthogTracingCallback() + callback.chat = assistant_chat_fixture + callback.user_id = str(assistant_chat_fixture.user_id) + callback.workspace_id = str(assistant_chat_fixture.workspace_id) + callback.chat_uuid = str(assistant_chat_fixture.uuid) + callback.trace_id = "trace-123" + callback.span_ids = ["root-span"] + callback.spans = {} + callback.enabled = True + + # Mock a module + mock_module = MagicMock(spec=udspy.ChainOfThought) + mock_module.__class__ = udspy.ChainOfThought + mock_signature = MagicMock() + mock_signature.get_input_fields.return_value = {"q": 1} + mock_signature.get_output_fields.return_value = {"a": 1} + mock_signature.get_instructions.return_value = "Test instructions" + mock_module.original_signature = mock_signature + + # Start module + callback.on_module_start( + call_id="call-1", instance=mock_module, inputs={"kwargs": {"q": "test"}} + ) + + # End module with exception + test_exception = ValueError("Test error message") + callback.on_module_end(call_id="call-1", outputs=None, exception=test_exception) + + # Verify exception string is captured + mock_posthog.capture.assert_called_once() + call_args = mock_posthog.capture.call_args + props = call_args.kwargs["properties"] + + assert props["$ai_is_error"] is True + assert props["$ai_output_state"] == "Test error message" + + @patch("baserow_enterprise.assistant.telemetry.posthog_client") + def test_on_tool_end_with_exception(self, mock_posthog, assistant_chat_fixture): + """Test that exception string is captured in $ai_output_state for tools.""" + + callback = PosthogTracingCallback() + callback.chat = assistant_chat_fixture + callback.user_id = str(assistant_chat_fixture.user_id) + callback.workspace_id = str(assistant_chat_fixture.workspace_id) + callback.chat_uuid = str(assistant_chat_fixture.uuid) + callback.trace_id = "trace-1" + callback.span_ids = ["root"] + callback.spans = {} + callback.enabled = True + + mock_tool = MagicMock() + mock_tool.name = "test_tool" + + # Start tool + callback.on_tool_start( + call_id="call-1", instance=mock_tool, inputs={"arg": "val"} + ) + + # End tool with exception + test_exception = RuntimeError("Tool execution failed") + callback.on_tool_end(call_id="call-1", outputs=None, exception=test_exception) + + # Verify exception string is captured + mock_posthog.capture.assert_called_once() + call_args = mock_posthog.capture.call_args + props = call_args.kwargs["properties"] + + assert props["$ai_is_error"] is True + assert props["$ai_output_state"] == "Tool execution failed"