diff --git a/enterprise/backend/src/baserow_enterprise/api/assistant/errors.py b/enterprise/backend/src/baserow_enterprise/api/assistant/errors.py index 69b13705b6..f53870db51 100644 --- a/enterprise/backend/src/baserow_enterprise/api/assistant/errors.py +++ b/enterprise/backend/src/baserow_enterprise/api/assistant/errors.py @@ -1,4 +1,4 @@ -from rest_framework.status import HTTP_404_NOT_FOUND +from rest_framework.status import HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND ERROR_ASSISTANT_CHAT_DOES_NOT_EXIST = ( "ERROR_ASSISTANT_CHAT_DOES_NOT_EXIST", @@ -9,7 +9,7 @@ ERROR_ASSISTANT_MODEL_NOT_SUPPORTED = ( "ERROR_ASSISTANT_MODEL_NOT_SUPPORTED", - 400, + HTTP_400_BAD_REQUEST, ( "The specified language model is not supported or the provided API key is missing/invalid. " "Ensure you have set the correct provider API key and selected a compatible model in " @@ -17,3 +17,9 @@ "supported models, required environment variables, and example configuration." ), ) + +ERROR_CANNOT_SUBMIT_MESSAGE_FEEDBACK = ( + "ERROR_CANNOT_SUBMIT_MESSAGE_FEEDBACK", + HTTP_400_BAD_REQUEST, + "This message cannot be submitted for feedback because it has no associated prediction.", +) diff --git a/enterprise/backend/src/baserow_enterprise/api/assistant/serializers.py b/enterprise/backend/src/baserow_enterprise/api/assistant/serializers.py index ae4abdb25c..437c2df373 100644 --- a/enterprise/backend/src/baserow_enterprise/api/assistant/serializers.py +++ b/enterprise/backend/src/baserow_enterprise/api/assistant/serializers.py @@ -4,7 +4,7 @@ from drf_spectacular.plumbing import force_instance from rest_framework import serializers -from baserow_enterprise.assistant.models import AssistantChat +from baserow_enterprise.assistant.models import AssistantChat, AssistantChatPrediction from baserow_enterprise.assistant.types import ( AssistantMessageType, AssistantMessageUnion, @@ -138,6 +138,19 @@ class AiMessageSerializer(serializers.Serializer): "The list of relevant source URLs referenced in the 
knowledge. Can be empty or null." ), ) + can_submit_feedback = serializers.BooleanField( + default=False, + help_text=( + "Whether the user can submit feedback for this message. " + "Only true for messages with an associated prediction." + ), + ) + human_sentiment = serializers.ChoiceField( + required=False, + allow_null=True, + choices=["LIKE", "DISLIKE"], + help_text="The sentiment for the message, if it has been rated.", + ) class AiThinkingSerializer(serializers.Serializer): @@ -295,3 +308,28 @@ def _map_serializer(self, auto_schema, direction, mapping): }, }, } + + +class AssistantRateChatMessageSerializer(serializers.Serializer): + sentiment = serializers.ChoiceField( + required=True, + allow_null=True, + choices=["LIKE", "DISLIKE"], + help_text="The sentiment for the message.", + ) + feedback = serializers.CharField( + help_text="Optional feedback about the message.", + required=False, + allow_blank=True, + allow_null=True, + ) + + def to_internal_value(self, data): + validated_data = super().to_internal_value(data) + validated_data["sentiment"] = AssistantChatPrediction.SENTIMENT_MAP.get( + data.get("sentiment") + ) + # Additional feedback is only allowed for DISLIKE sentiment + if data["sentiment"] != "DISLIKE": + validated_data["feedback"] = "" + return validated_data diff --git a/enterprise/backend/src/baserow_enterprise/api/assistant/urls.py b/enterprise/backend/src/baserow_enterprise/api/assistant/urls.py index 67aef49166..ed455e1477 100644 --- a/enterprise/backend/src/baserow_enterprise/api/assistant/urls.py +++ b/enterprise/backend/src/baserow_enterprise/api/assistant/urls.py @@ -1,6 +1,10 @@ from django.urls import path -from .views import AssistantChatsView, AssistantChatView +from .views import ( + AssistantChatMessageFeedbackView, + AssistantChatsView, + AssistantChatView, +) app_name = "baserow_enterprise.api.assistant" @@ -15,4 +19,9 @@ AssistantChatsView.as_view(), name="list", ), + path( + "messages/<int:message_id>/feedback/",
AssistantChatMessageFeedbackView.as_view(), + name="message_feedback", + ), ] diff --git a/enterprise/backend/src/baserow_enterprise/api/assistant/views.py b/enterprise/backend/src/baserow_enterprise/api/assistant/views.py index c0b2baa1df..a97e1ee2e4 100644 --- a/enterprise/backend/src/baserow_enterprise/api/assistant/views.py +++ b/enterprise/backend/src/baserow_enterprise/api/assistant/views.py @@ -1,12 +1,15 @@ import json from urllib.request import Request +from uuid import uuid4 from django.http import StreamingHttpResponse from baserow_premium.license.handler import LicenseHandler from drf_spectacular.openapi import OpenApiParameter, OpenApiTypes from drf_spectacular.utils import OpenApiResponse, extend_schema +from loguru import logger from rest_framework.response import Response +from rest_framework.status import HTTP_204_NO_CONTENT from rest_framework.views import APIView from baserow.api.decorators import ( @@ -18,32 +21,38 @@ from baserow.api.pagination import LimitOffsetPagination from baserow.api.schemas import get_error_schema from baserow.api.serializers import get_example_pagination_serializer_class +from baserow.api.sessions import set_client_undo_redo_action_group_id from baserow.core.exceptions import UserNotInWorkspace, WorkspaceDoesNotExist from baserow.core.feature_flags import FF_ASSISTANT, feature_flag_is_enabled from baserow.core.handler import CoreHandler -from baserow_enterprise.api.assistant.errors import ( - ERROR_ASSISTANT_CHAT_DOES_NOT_EXIST, - ERROR_ASSISTANT_MODEL_NOT_SUPPORTED, -) from baserow_enterprise.assistant.exceptions import ( AssistantChatDoesNotExist, + AssistantChatMessagePredictionDoesNotExist, AssistantModelNotSupportedError, ) from baserow_enterprise.assistant.handler import AssistantHandler +from baserow_enterprise.assistant.models import AssistantChatPrediction from baserow_enterprise.assistant.operations import ChatAssistantChatOperationType from baserow_enterprise.assistant.types import ( + AiErrorMessage, 
AssistantMessageUnion, HumanMessage, UIContext, ) from baserow_enterprise.features import ASSISTANT +from .errors import ( + ERROR_ASSISTANT_CHAT_DOES_NOT_EXIST, + ERROR_ASSISTANT_MODEL_NOT_SUPPORTED, + ERROR_CANNOT_SUBMIT_MESSAGE_FEEDBACK, +) from .serializers import ( AssistantChatMessagesSerializer, AssistantChatSerializer, AssistantChatsRequestSerializer, AssistantMessageRequestSerializer, AssistantMessageSerializer, + AssistantRateChatMessageSerializer, ) @@ -139,7 +148,6 @@ class AssistantChatView(APIView): { UserNotInWorkspace: ERROR_USER_NOT_IN_GROUP, WorkspaceDoesNotExist: ERROR_GROUP_DOES_NOT_EXIST, - AssistantChatDoesNotExist: ERROR_ASSISTANT_CHAT_DOES_NOT_EXIST, AssistantModelNotSupportedError: ERROR_ASSISTANT_MODEL_NOT_SUPPORTED, } ) @@ -164,16 +172,33 @@ def post(self, request: Request, chat_uuid: str, data) -> StreamingHttpResponse: # Clearing the user websocket_id will make sure real-time updates are sent chat.user.web_socket_id = None - # FIXME: As long as we don't allow users to change it, temporarily set the - # timezone to the one provided in the UI context + + # Used to group all the actions done to produce this message together + # so they can be undone in one go. + set_client_undo_redo_action_group_id(chat.user, str(uuid4())) + + # As long as we don't allow users to change it, temporarily set the timezone to + # the one provided in the UI context so tools can use it if needed. 
chat.user.profile.timezone = ui_context.timezone assistant = handler.get_assistant(chat) + assistant.check_llm_ready_or_raise() human_message = HumanMessage(content=data["content"], ui_context=ui_context) async def stream_assistant_messages(): - async for msg in assistant.astream_messages(human_message): - yield self._stream_assistant_message(msg) + try: + async for msg in assistant.astream_messages(human_message): + yield self._stream_assistant_message(msg) + except Exception: + logger.exception("Error while streaming assistant messages") + yield self._stream_assistant_message( + AiErrorMessage( + content=( + "Oops, something went wrong and I cannot continue the conversation. " + "Please try again." + ) + ) + ) response = StreamingHttpResponse( stream_assistant_messages(), @@ -230,3 +255,51 @@ def get(self, request: Request, chat_uuid: str) -> Response: serializer = AssistantChatMessagesSerializer({"messages": messages}) return Response(serializer.data) + + +class AssistantChatMessageFeedbackView(APIView): + @extend_schema( + tags=["AI Assistant"], + operation_id="submit_assistant_message_feedback", + description=( + "Provide sentiment and feedback for the given AI assistant chat message.\n\n" + "This is an **advanced/enterprise** feature." 
+ ), + responses={ + 200: None, + 400: get_error_schema( + ["ERROR_USER_NOT_IN_GROUP", "ERROR_CANNOT_SUBMIT_MESSAGE_FEEDBACK"] + ), + }, + ) + @validate_body(AssistantRateChatMessageSerializer, return_validated=True) + @map_exceptions( + { + UserNotInWorkspace: ERROR_USER_NOT_IN_GROUP, + WorkspaceDoesNotExist: ERROR_GROUP_DOES_NOT_EXIST, + AssistantChatDoesNotExist: ERROR_ASSISTANT_CHAT_DOES_NOT_EXIST, + AssistantChatMessagePredictionDoesNotExist: ERROR_CANNOT_SUBMIT_MESSAGE_FEEDBACK, + } + ) + def put(self, request: Request, message_id: int, data) -> Response: + feature_flag_is_enabled(FF_ASSISTANT, raise_if_disabled=True) + + handler = AssistantHandler() + message = handler.get_chat_message_by_id(request.user, message_id) + LicenseHandler.raise_if_user_doesnt_have_feature( + ASSISTANT, request.user, message.chat.workspace + ) + + try: + prediction: AssistantChatPrediction = message.prediction + except AttributeError: + raise AssistantChatMessagePredictionDoesNotExist( + f"Message with ID {message_id} does not have an associated prediction." + ) + + prediction.human_sentiment = data["sentiment"] + prediction.human_feedback = data.get("feedback") or "" + prediction.save( + update_fields=["human_sentiment", "human_feedback", "updated_on"] + ) + return Response(status=HTTP_204_NO_CONTENT) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/adapter.py b/enterprise/backend/src/baserow_enterprise/assistant/adapter.py index 1702fb1c3c..c6c2ed321b 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/adapter.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/adapter.py @@ -1,14 +1,21 @@ -import dspy - from .prompts import ASSISTANT_SYSTEM_PROMPT -class ChatAdapter(dspy.ChatAdapter): - def format_field_description(self, signature: type[dspy.Signature]) -> str: - """ - This is the first part of the prompt the LLM sees, so we prepend our custom - system prompt to it to give it the personality and context of Baserow. 
- """ +def get_chat_adapter(): + import dspy # local import to save memory when not used + + class ChatAdapter(dspy.ChatAdapter): + def format_field_description(self, signature: type[dspy.Signature]) -> str: + """ + This is the first part of the prompt the LLM sees, so we prepend our custom + system prompt to it to give it the personality and context of Baserow. + """ + + field_description = super().format_field_description(signature) + return ( + ASSISTANT_SYSTEM_PROMPT + + "## TASK INSTRUCTIONS:\n\n" + + field_description + ) - field_description = super().format_field_description(signature) - return ASSISTANT_SYSTEM_PROMPT + "## TASK INSTRUCTIONS:\n\n" + field_description + return ChatAdapter() diff --git a/enterprise/backend/src/baserow_enterprise/assistant/assistant.py b/enterprise/backend/src/baserow_enterprise/assistant/assistant.py index 7c2bee09b2..10f73c44d4 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/assistant.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/assistant.py @@ -1,17 +1,14 @@ +from functools import lru_cache from typing import Any, AsyncGenerator, TypedDict from django.conf import settings -import dspy -from dspy.primitives.prediction import Prediction -from dspy.streaming import StreamListener, StreamResponse -from dspy.utils.callback import BaseCallback - +from baserow.api.sessions import get_client_undo_redo_action_group_id +from baserow_enterprise.assistant.exceptions import AssistantModelNotSupportedError from baserow_enterprise.assistant.tools.registries import assistant_tool_registry -from .adapter import ChatAdapter -from .models import AssistantChat, AssistantChatMessage -from .react import ReAct +from .adapter import get_chat_adapter +from .models import AssistantChat, AssistantChatMessage, AssistantChatPrediction from .types import ( AiMessage, AiMessageChunk, @@ -22,20 +19,6 @@ HumanMessage, UIContext, ) -from .utils import ensure_llm_model_accessible - - -class ChatSignature(dspy.Signature): - 
question: str = dspy.InputField() - history: dspy.History = dspy.InputField() - ui_context: UIContext | None = dspy.InputField( - default=None, - desc=( - "The frontend UI content the user is currently in. " - "Whenever make sense, use it to ground your answer." - ), - ) - answer: str = dspy.OutputField() class AssistantMessagePair(TypedDict): @@ -43,72 +26,95 @@ class AssistantMessagePair(TypedDict): answer: str -class AssistantCallbacks(BaseCallback): - def __init__(self): - self.tool_calls = {} - self.sources = [] +def get_assistant_callbacks(): + from dspy.utils.callback import BaseCallback - def extend_sources(self, sources: list[str]) -> None: - """ - Extends the current list of sources with new ones, avoiding duplicates. + class AssistantCallbacks(BaseCallback): + def __init__(self): + self.tool_calls = {} + self.sources = [] - :param sources: The list of new source URLs to add. - :return: None - """ + def extend_sources(self, sources: list[str]) -> None: + """ + Extends the current list of sources with new ones, avoiding duplicates. - self.sources.extend([s for s in sources if s not in self.sources]) + :param sources: The list of new source URLs to add. + :return: None + """ - def on_tool_start( - self, - call_id: str, - instance: Any, - inputs: dict[str, Any], - ) -> None: - """ - Called when a tool starts. It records the tool call and invokes the - corresponding tool's on_tool_start method if it exists. + self.sources.extend([s for s in sources if s not in self.sources]) - :param call_id: The unique identifier of the tool call. - :param instance: The instance of the tool being called. - :param inputs: The inputs provided to the tool. - """ + def on_tool_start( + self, + call_id: str, + instance: Any, + inputs: dict[str, Any], + ) -> None: + """ + Called when a tool starts. It records the tool call and invokes the + corresponding tool's on_tool_start method if it exists. 
- try: - assistant_tool_registry.get(instance.name).on_tool_start( - call_id, instance, inputs + :param call_id: The unique identifier of the tool call. + :param instance: The instance of the tool being called. + :param inputs: The inputs provided to the tool. + """ + + try: + assistant_tool_registry.get(instance.name).on_tool_start( + call_id, instance, inputs + ) + self.tool_calls[call_id] = (instance, inputs) + except assistant_tool_registry.does_not_exist_exception_class: + pass + + def on_tool_end( + self, + call_id: str, + outputs: dict[str, Any] | None, + exception: Exception | None = None, + ) -> None: + """ + Called when a tool ends. It invokes the corresponding tool's on_tool_end + method if it exists and updates the sources if the tool produced any. + + :param call_id: The unique identifier of the tool call. + :param outputs: The outputs returned by the tool, or None if there was an + exception. + :param exception: The exception raised by the tool, or None if it was + successful. + """ + + if call_id not in self.tool_calls: + return + + instance, inputs = self.tool_calls.pop(call_id) + assistant_tool_registry.get(instance.name).on_tool_end( + call_id, instance, inputs, outputs, exception ) - self.tool_calls[call_id] = (instance, inputs) - except assistant_tool_registry.does_not_exist_exception_class: - pass - def on_tool_end( - self, - call_id: str, - outputs: dict[str, Any] | None, - exception: Exception | None = None, - ) -> None: - """ - Called when a tool ends. It invokes the corresponding tool's on_tool_end - method if it exists and updates the sources if the tool produced any. - - :param call_id: The unique identifier of the tool call. - :param outputs: The outputs returned by the tool, or None if there was an - exception. - :param exception: The exception raised by the tool, or None if it was - successful. - """ + # If the tool produced sources, add them to the overall list of sources. 
+ if isinstance(outputs, dict) and "sources" in outputs: + self.extend_sources(outputs["sources"]) - if call_id not in self.tool_calls: - return + return AssistantCallbacks() - instance, inputs = self.tool_calls.pop(call_id) - assistant_tool_registry.get(instance.name).on_tool_end( - call_id, instance, inputs, outputs, exception + +def get_chat_signature(): + import dspy # local import to save memory when not used + + class ChatSignature(dspy.Signature): + question: str = dspy.InputField() + history: dspy.History = dspy.InputField() + ui_context: UIContext | None = dspy.InputField( + default=None, + desc=( + "The frontend UI content the user is currently in. " + "Whenever make sense, use it to ground your answer." + ), ) + answer: str = dspy.OutputField() - # If the tool produced sources, add them to the overall list of sources. - if isinstance(outputs, dict) and "sources" in outputs: - self.extend_sources(outputs["sources"]) + return ChatSignature class Assistant: @@ -117,17 +123,27 @@ def __init__(self, chat: AssistantChat): self._user = chat.user self._workspace = chat.workspace + self._init_lm_client() + self._init_assistant() + + def _init_lm_client(self): + import dspy # local import to save memory when not used + lm_model = settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL + self._lm_client = dspy.LM( model=lm_model, cache=not settings.DEBUG, max_retries=5, ) + def _init_assistant(self): + from .react import ReAct # local import to save memory when not used + tools = assistant_tool_registry.list_all_usable_tools( self._user, self._workspace ) - self._assistant = ReAct(ChatSignature, tools=tools) + self._assistant = ReAct(get_chat_signature(), tools=tools) self.history = None async def acreate_chat_message( @@ -135,6 +151,7 @@ async def acreate_chat_message( role: AssistantChatMessage.Role, content: str, artifacts: dict[str, Any] | None = None, + **kwargs, ) -> AssistantChatMessage: """ Creates and saves a new chat message. 
@@ -149,10 +166,13 @@ async def acreate_chat_message( chat=self._chat, role=role, content=content, + **kwargs, ) if artifacts: message.artifacts = artifacts - return await message.asave() + + await message.asave() + return message def list_chat_messages( self, last_message_id: int | None = None, limit: int = 100 @@ -166,7 +186,11 @@ def list_chat_messages( :return: A list of AssistantChatMessage instances. """ - queryset = self._chat.messages.all().order_by("-created_on") + queryset = ( + self._chat.messages.all() + .select_related("prediction") + .order_by("-created_on") + ) if last_message_id is not None: queryset = queryset.filter(id__lt=last_message_id) @@ -179,8 +203,19 @@ def list_chat_messages( ) ) else: + sentiment_data = {} + if getattr(msg, "prediction", None): + sentiment_data = { + "can_submit_feedback": True, + "human_sentiment": msg.prediction.get_human_sentiment_display(), + } messages.append( - AiMessage(content=msg.content, id=msg.id, timestamp=msg.created_on) + AiMessage( + content=msg.content, + id=msg.id, + timestamp=msg.created_on, + **sentiment_data, + ) ) return list(reversed(messages)) @@ -194,6 +229,8 @@ async def aload_chat_history(self, limit=20): :return: None """ + import dspy # local import to save memory when not used + last_saved_messages: list[AssistantChatMessage] = [ msg async for msg in self._chat.messages.order_by("-created_on")[:limit] ] @@ -219,6 +256,28 @@ async def aload_chat_history(self, limit=20): self.history = dspy.History(messages=messages) + @lru_cache(maxsize=1) + def check_llm_ready_or_raise(self): + import dspy # local import to save memory when not used + from litellm import get_supported_openai_params + + lm = self._lm_client + params = get_supported_openai_params(lm.model) + if params is None or "tools" not in params: + raise AssistantModelNotSupportedError( + f"The model '{lm.model}' is not supported or could not be found. 
" + "Please make sure the model name is correct, it can use tools, " + "and that your API key has access to it." + ) + + try: + with dspy.context(lm=lm): + lm("Say ok if you can read this.") + except Exception as e: + raise AssistantModelNotSupportedError( + f"The model '{lm.model}' is not supported or accessible: {e}" + ) + async def astream_messages( self, human_message: HumanMessage ) -> AsyncGenerator[AssistantMessageUnion, None]: @@ -229,16 +288,17 @@ async def astream_messages( :return: An async generator that yields the response messages. """ - # The first time, make sure the model and api_key are setup correctly - ensure_llm_model_accessible(self._lm_client) + import dspy # local import to save memory when not used + from dspy.primitives.prediction import Prediction + from dspy.streaming import StreamListener, StreamResponse - callback_manager = AssistantCallbacks() + callback_manager = get_assistant_callbacks() with dspy.context( lm=self._lm_client, cache=not settings.DEBUG, callbacks=[*dspy.settings.config.callbacks, callback_manager], - adapter=ChatAdapter(), + adapter=get_chat_adapter(), ): if self.history is None: await self.aload_chat_history() @@ -260,7 +320,7 @@ async def astream_messages( ), ) - await self.acreate_chat_message( + human_msg = await self.acreate_chat_message( AssistantChatMessage.Role.HUMAN, human_message.content ) @@ -273,19 +333,37 @@ async def astream_messages( yield AiMessageChunk( content=answer, sources=callback_manager.sources ) + elif isinstance(stream_chunk, (AiThinkingMessage, AiNavigationMessage)): + # forward thinking/navigation messages as-is to the frontend + yield stream_chunk elif isinstance(stream_chunk, Prediction): - yield AiMessageChunk( - content=stream_chunk.answer, sources=callback_manager.sources - ) - await self.acreate_chat_message( + # At the end of the prediction, save the AI message and the + # prediction details for future analysis and feedback. 
+ ai_msg = await self.acreate_chat_message( AssistantChatMessage.Role.AI, answer, artifacts={"sources": callback_manager.sources}, + action_group_id=get_client_undo_redo_action_group_id( + self._user + ), + ) + await AssistantChatPrediction.objects.acreate( + human_message=human_msg, + ai_response=ai_msg, + prediction={ + "model": self._lm_client.model, + "trajectory": stream_chunk.trajectory, + "reasoning": stream_chunk.reasoning, + }, + ) + # In case the streaming didn't work, make sure we yield at least one + # final message with the complete answer. + yield AiMessage( + id=ai_msg.id, + content=stream_chunk.answer, + sources=callback_manager.sources, + can_submit_feedback=True, ) - - elif isinstance(stream_chunk, (AiThinkingMessage, AiNavigationMessage)): - # forward thinking/navigation messages as-is to the frontend - yield stream_chunk if not self._chat.title: title_generator = dspy.Predict("question -> chat_title") diff --git a/enterprise/backend/src/baserow_enterprise/assistant/exceptions.py b/enterprise/backend/src/baserow_enterprise/assistant/exceptions.py index ee148ad4cb..e5182a79db 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/exceptions.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/exceptions.py @@ -8,3 +8,7 @@ class AssistantChatDoesNotExist(AssistantException): class AssistantModelNotSupportedError(AssistantException): pass + + +class AssistantChatMessagePredictionDoesNotExist(AssistantException): + pass diff --git a/enterprise/backend/src/baserow_enterprise/assistant/handler.py b/enterprise/backend/src/baserow_enterprise/assistant/handler.py index bd4993e2a9..fa7033fffc 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/handler.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/handler.py @@ -1,3 +1,4 @@ +from datetime import datetime, timedelta, timezone from typing import AsyncGenerator from uuid import UUID @@ -7,7 +8,7 @@ from .assistant import Assistant from .exceptions import 
AssistantChatDoesNotExist -from .models import AssistantChat +from .models import AssistantChat, AssistantChatMessage, AssistantChatPrediction from .types import AiMessage, AssistantMessageUnion, HumanMessage, UIContext @@ -66,6 +67,25 @@ def list_chats(self, user: AbstractUser, workspace_id: int) -> list[AssistantCha workspace_id=workspace_id, user=user ).order_by("-updated_on", "id") + def get_chat_message_by_id(self, user: AbstractUser, message_id: int) -> AiMessage: + """ + Get a specific message from the AI assistant chat by its ID. + + :param user: The user requesting the message. + :param message_id: The ID of the message to retrieve. + :return: The AI assistant message. + :raises AssistantChatDoesNotExist: If the chat or message does not exist. + """ + + try: + return AssistantChatMessage.objects.select_related( + "chat__workspace", "prediction" + ).get(chat__user=user, id=message_id) + except AssistantChatMessage.DoesNotExist: + raise AssistantChatDoesNotExist( + f"Message with ID {message_id} does not exist." + ) + def list_chat_messages(self, chat: AssistantChat) -> list[AiMessage | HumanMessage]: """ Get all messages from the AI assistant chat. @@ -87,6 +107,27 @@ def get_assistant(self, chat: AssistantChat) -> Assistant: return Assistant(chat) + def delete_predictions( + self, older_than_days: int = 30, exclude_rated: bool = True + ) -> tuple[int, dict]: + """ + Delete predictions older than the specified number of days. + + :param older_than_days: The number of days to retain predictions. + :param exclude_rated: Whether to exclude predictions that have been rated by + users. + :return: A tuple containing the number of deleted predictions and a dict with + details. 
+ """ + + cutoff_date = datetime.now(timezone.utc) - timedelta(days=older_than_days) + queryset = AssistantChatPrediction.objects.filter(created_on__lt=cutoff_date) + + if exclude_rated: + queryset = queryset.filter(human_sentiment__isnull=True) + + return queryset.delete() + async def astream_assistant_messages( self, chat: AssistantChat, diff --git a/enterprise/backend/src/baserow_enterprise/assistant/models.py b/enterprise/backend/src/baserow_enterprise/assistant/models.py index 6ee1903da3..fadc669025 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/models.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/models.py @@ -80,6 +80,13 @@ class Role(models.TextChoices): "such as metadata or processing results." ), ) + action_group_id = models.UUIDField( + null=True, + help_text=( + "Unique identifier for the action group. Can be provided by the client. " + "All the actions done to produce this message can be undone by referencing this ID." + ), + ) class Meta: indexes = [ @@ -87,6 +94,60 @@ class Meta: ] +class AssistantChatPrediction( + BigAutoFieldMixin, CreatedAndUpdatedOnMixin, models.Model +): + """ + Model representing a prediction for an assistant chat message, including the + reasoning and any tool calls made by the AI. It also captures optional feedback from + the human user regarding the prediction. + """ + + SENTIMENT_MAP = { + "LIKE": 1, + "DISLIKE": -1, + # Add also the reverse mapping for convenience. 
+ 1: "LIKE", + -1: "DISLIKE", + } + + human_message = models.OneToOneField( + AssistantChatMessage, + on_delete=models.CASCADE, + related_name="+", + help_text="The human message that caused this prediction.", + ) + ai_response = models.OneToOneField( + AssistantChatMessage, + on_delete=models.CASCADE, + related_name="prediction", + help_text="The AI response message generated as a prediction.", + ) + prediction = models.JSONField( + default=dict, + help_text="The prediction data, including the reasoning and any tool calls.", + ) + human_sentiment = models.SmallIntegerField( + choices=[ + (SENTIMENT_MAP["LIKE"], "Like"), + (SENTIMENT_MAP["DISLIKE"], "Dislike"), + ], + null=True, + help_text="Optional feedback provided by the human user on the prediction.", + ) + human_feedback = models.TextField( + blank=True, + help_text="Optional feedback provided by the human user on the prediction.", + ) + + def get_human_sentiment_display(self): + """ + Returns the display value of the human sentiment. + """ + + return self.SENTIMENT_MAP.get(self.human_sentiment) + + class DocumentCategory(NamedTuple): name: str parent: str diff --git a/enterprise/backend/src/baserow_enterprise/assistant/react.py b/enterprise/backend/src/baserow_enterprise/assistant/react.py index 42474ccd11..e2c50f8e9c 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/react.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/react.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Callable, Literal +from typing import Any, Callable, Literal import dspy from dspy.adapters.types.tool import Tool @@ -10,10 +10,6 @@ from .types import ToolsUpgradeResponse -if TYPE_CHECKING: - from dspy.signatures.signature import Signature - - # Variant of dspy.predict.react.ReAct that accepts a "meta-tool": # a callable that can produce tools at runtime (e.g. per-table schemas). 
# This lets a single ReAct instance handle many different table signatures @@ -21,9 +17,7 @@ class ReAct(Module): - def __init__( - self, signature: type["Signature"], tools: list[Callable], max_iters: int = 100 - ): + def __init__(self, signature, tools: list[Callable], max_iters: int = 100): """ ReAct stands for "Reasoning and Acting," a popular paradigm for building tool-using agents. In this approach, the language model is iteratively provided diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tasks.py b/enterprise/backend/src/baserow_enterprise/assistant/tasks.py new file mode 100644 index 0000000000..aff4d4eeec --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/assistant/tasks.py @@ -0,0 +1,18 @@ +from datetime import timedelta + +from baserow.config.celery import app + +from .handler import AssistantHandler + + +@app.task(bind=True) +def delete_old_unrated_predictions(self): + AssistantHandler().delete_predictions(older_than_days=30, exclude_rated=True) + + +@app.on_after_finalize.connect +def setup_period_trash_tasks(sender, **kwargs): + sender.add_periodic_task( + timedelta(days=1), + delete_old_unrated_predictions.s(), + ) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py index bd12a3daa2..87b306626d 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py @@ -5,7 +5,6 @@ from django.db import transaction from django.utils.translation import gettext as _ -import dspy from loguru import logger from pydantic import create_model @@ -25,9 +24,9 @@ from baserow_enterprise.assistant.tools.registries import AssistantToolType, ToolHelpers from baserow_enterprise.assistant.types import ( TableNavigationType, - ToolSignature, ToolsUpgradeResponse, ViewNavigationType, + get_tool_signature, ) from . 
import utils @@ -284,6 +283,8 @@ def create_tables( - if add_sample_rows is True (default), add some example rows to each table """ + import dspy # local import to save memory when not used + nonlocal user, workspace, tool_helpers if not tables: @@ -354,7 +355,7 @@ def create_tables( f"- Create 5 example rows for table_{created_table.id}. Fill every relationship with valid data when possible." ) - predictor = dspy.Predict(ToolSignature) + predictor = dspy.Predict(get_tool_signature()) result = predictor( question=("\n".join(instructions)), tools=list(tools.values()), diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py index b864e00ae6..42fc3d8242 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py @@ -7,8 +7,6 @@ from django.db.models import Q from django.utils.translation import gettext as _ -import dspy -from dspy.adapters.types.tool import _resolve_json_schema_reference from pydantic import ConfigDict, Field, create_model from baserow.contrib.database.fields.actions import CreateFieldActionType @@ -385,6 +383,9 @@ def get_view(user, view_id: int): def get_table_rows_tools( user, workspace: Workspace, tool_helpers: ToolHelpers, table: Table ): + import dspy # local import to save memory when not used + from dspy.adapters.types.tool import _resolve_json_schema_reference + row_model_for_create = get_create_row_model(table) row_model_for_update = get_update_row_model(table) row_model_for_response = create_model( diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py index 9e42d7629a..52c4de4d73 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py +++ 
b/enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py @@ -1,6 +1,3 @@ -from dspy.dsp.utils.settings import settings -from dspy.streaming.messages import sync_send_to_stream - from baserow_enterprise.assistant.types import AiNavigationMessage, AnyNavigationType @@ -13,6 +10,9 @@ def unsafe_navigate_to(location: AnyNavigationType) -> str: :param navigation_type: The type of navigation to perform. """ + from dspy.dsp.utils.settings import settings + from dspy.streaming.messages import sync_send_to_stream + stream = settings.send_stream if stream is not None: sync_send_to_stream(stream, AiNavigationMessage(location=location)) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py index 83f036a596..c3a2783de8 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/registries.py @@ -4,9 +4,6 @@ from django.contrib.auth.models import AbstractUser from django.utils import translation -from dspy.dsp.utils.settings import settings -from dspy.streaming.messages import sync_send_to_stream - from baserow.core.exceptions import ( InstanceTypeAlreadyRegistered, InstanceTypeDoesNotExist, @@ -123,6 +120,9 @@ def update_status_localized(status: str): :param status: The status message to send. 
""" + from dspy.dsp.utils.settings import settings + from dspy.streaming.messages import sync_send_to_stream + nonlocal user with translation.override(user.profile.language): diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/handler.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/handler.py index 6c0177e943..874f2d7fc6 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/handler.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/handler.py @@ -5,7 +5,6 @@ from django.core import serializers from django.db.models import Q -import dspy from httpx import Client as httpxClient from pgvector.django import L2Distance @@ -62,11 +61,13 @@ def __call__(self, texts: list[str]) -> list[list[float]]: class VectorHandler: - def __init__(self, embedder: dspy.Embedder | None = None): + def __init__(self, embedder=None): self._embedder = embedder @property - def embedder(self) -> dspy.Embedder: + def embedder(self): + import dspy # local import to save memory when not used + if self._embedder is None: self._embedder = dspy.Embedder( BaserowEmbedder(settings.BASEROW_EMBEDDINGS_API_URL) diff --git a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/tools.py b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/tools.py index 0a228bdbf3..d30c41cb92 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/tools.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/tools/search_docs/tools.py @@ -3,8 +3,6 @@ from django.contrib.auth.models import AbstractUser from django.utils.translation import gettext as _ -import dspy - from baserow.core.models import Workspace from baserow_enterprise.assistant.tools.registries import AssistantToolType, ToolHelpers @@ -13,13 +11,18 @@ MAX_SOURCES = 3 -class SearchDocsSignature(dspy.Signature): - question: str = dspy.InputField() - context: list[str] = 
dspy.InputField() - response: str = dspy.OutputField() - sources: list[str] = dspy.OutputField( - desc=f"List of unique and relevant source URLs. Max {MAX_SOURCES}." - ) +def get_search_predictor(): + import dspy # local import to save memory when not used + + class SearchDocsSignature(dspy.Signature): + question: str = dspy.InputField() + context: list[str] = dspy.InputField() + response: str = dspy.OutputField() + sources: list[str] = dspy.OutputField( + desc=f"List of unique and relevant source URLs. Max {MAX_SOURCES}." + ) + + return dspy.ChainOfThought(SearchDocsSignature) class SearchDocsToolOutput(TypedDict): @@ -39,10 +42,20 @@ def search_docs(query: str) -> SearchDocsToolOutput: Search Baserow documentation. """ + import dspy # local import to save memory when not used + nonlocal tool_helpers tool_helpers.update_status(_("Exploring the knowledge base...")) + class SearchDocsRAG(dspy.Module): + def __init__(self): + self.respond = get_search_predictor() + + def forward(self, question): + context = KnowledgeBaseHandler().search(question, num_results=10) + return self.respond(context=context, question=question) + tool = SearchDocsRAG() result = tool(query) @@ -61,15 +74,6 @@ def search_docs(query: str) -> SearchDocsToolOutput: return search_docs -class SearchDocsRAG(dspy.Module): - def __init__(self): - self.respond = dspy.ChainOfThought(SearchDocsSignature) - - def forward(self, question): - context = KnowledgeBaseHandler().search(question, num_results=10) - return self.respond(context=context, question=question) - - class SearchDocsToolType(AssistantToolType): type = "search_docs" diff --git a/enterprise/backend/src/baserow_enterprise/assistant/types.py b/enterprise/backend/src/baserow_enterprise/assistant/types.py index 693ad8d7bb..620e30e91c 100644 --- a/enterprise/backend/src/baserow_enterprise/assistant/types.py +++ b/enterprise/backend/src/baserow_enterprise/assistant/types.py @@ -4,7 +4,6 @@ from django.utils.translation import gettext as _ -import 
dspy from pydantic import BaseModel as PydanticBaseModel from pydantic import ConfigDict, Field @@ -130,6 +129,20 @@ class AiMessage(AiMessageChunk): description="The unique UUID of the message", ) timestamp: datetime | None = Field(default=None) + can_submit_feedback: bool = Field( + default=False, + description=( + "Whether the message can be submitted for feedback. This is true if the " + "message has an associated prediction." + ), + ) + human_sentiment: Optional[Literal["LIKE", "DISLIKE"]] = Field( + default=None, + description=( + "The sentiment of the message as submitted by the user. It can be 'LIKE', " + "'DISLIKE', or None if no sentiment has been submitted." + ), + ) class AiThinkingMessage(BaseModel): @@ -156,7 +169,9 @@ class AiErrorMessageCode(StrEnum): class AiErrorMessage(BaseModel): type: Literal["ai/error"] = AssistantMessageType.AI_ERROR.value - code: AiErrorMessageCode = Field(description="The type of error that occurred") + code: AiErrorMessageCode = Field( + AiErrorMessageCode.UNKNOWN, description="The type of error that occurred" + ) content: str = Field(description="Error message content") @@ -207,12 +222,17 @@ class AiNavigationMessage(BaseModel): class ToolsUpgradeResponse(BaseModel): observation: str - new_tools: list[dspy.Tool | Callable[[Any], Any]] + new_tools: list[Callable[[Any], Any]] + + +def get_tool_signature(): + import dspy # local import to save memory when not used + class ToolSignature(dspy.Signature): + """Signature for manual tool handling.""" -class ToolSignature(dspy.Signature): - """Signature for manual tool handling.""" + question: str = dspy.InputField() + tools: list[dspy.Tool] = dspy.InputField() + outputs: dspy.ToolCalls = dspy.OutputField() - question: str = dspy.InputField() - tools: list[dspy.Tool] = dspy.InputField() - outputs: dspy.ToolCalls = dspy.OutputField() + return ToolSignature diff --git a/enterprise/backend/src/baserow_enterprise/assistant/utils/__init__.py 
b/enterprise/backend/src/baserow_enterprise/assistant/utils/__init__.py deleted file mode 100644 index 1a3450c05f..0000000000 --- a/enterprise/backend/src/baserow_enterprise/assistant/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .llm import * # noqa: F403, F401 diff --git a/enterprise/backend/src/baserow_enterprise/assistant/utils/llm.py b/enterprise/backend/src/baserow_enterprise/assistant/utils/llm.py deleted file mode 100644 index 11c9b12ab0..0000000000 --- a/enterprise/backend/src/baserow_enterprise/assistant/utils/llm.py +++ /dev/null @@ -1,32 +0,0 @@ -from functools import lru_cache - -import dspy -from litellm import get_supported_openai_params - -from baserow_enterprise.assistant.exceptions import AssistantModelNotSupportedError - - -@lru_cache(maxsize=1) -def ensure_llm_model_accessible(lm: dspy.LM) -> None: - """ - Ensure the given model is accessible and works with the current API key and - settings. - - :param model: The model name to validate. - :raises AssistantModelNotSupportedError: If the model is not supported or not - accessible. - """ - - params = get_supported_openai_params(lm.model) - if params is None: - raise AssistantModelNotSupportedError( - f"The model '{lm.model}' is not supported or could not be found." 
- ) - - with dspy.context(lm=lm): - try: - lm("Say ok if you can read this.") - except Exception as e: - raise AssistantModelNotSupportedError( - f"The model '{lm.model}' is not supported or accessible: {e}" - ) from e diff --git a/enterprise/backend/src/baserow_enterprise/migrations/0055_assistantchatmessage_action_group_id_and_more.py b/enterprise/backend/src/baserow_enterprise/migrations/0055_assistantchatmessage_action_group_id_and_more.py new file mode 100644 index 0000000000..67f5e5d0c6 --- /dev/null +++ b/enterprise/backend/src/baserow_enterprise/migrations/0055_assistantchatmessage_action_group_id_and_more.py @@ -0,0 +1,82 @@ +# Generated by Django 5.0.14 on 2025-10-13 09:43 + +import django.db.models.deletion +from django.db import migrations, models + +import baserow.core.fields + + +class Migration(migrations.Migration): + dependencies = [ + ("baserow_enterprise", "0054_assistant_knowledgebase_and_more"), + ] + + operations = [ + migrations.AddField( + model_name="assistantchatmessage", + name="action_group_id", + field=models.UUIDField( + help_text="Unique identifier for the action group. Can be provided by the client. 
All the actions done to produce this message can be undone by referencing this ID.", + null=True, + ), + ), + migrations.CreateModel( + name="AssistantChatPrediction", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("created_on", models.DateTimeField(auto_now_add=True)), + ("updated_on", baserow.core.fields.SyncedDateTimeField(auto_now=True)), + ( + "prediction", + models.JSONField( + default=dict, + help_text="The prediction data, including the reasoning and any tool calls.", + ), + ), + ( + "human_sentiment", + models.SmallIntegerField( + choices=[(1, "Like"), (-1, "Dislike")], + help_text="Optional feedback provided by the human user on the prediction.", + null=True, + ), + ), + ( + "human_feedback", + models.TextField( + blank=True, + help_text="Optional feedback provided by the human user on the prediction.", + ), + ), + ( + "ai_response", + models.OneToOneField( + help_text="The AI response message generated as a prediction.", + on_delete=django.db.models.deletion.CASCADE, + related_name="prediction", + to="baserow_enterprise.assistantchatmessage", + ), + ), + ( + "human_message", + models.OneToOneField( + help_text="The human message that caused this prediction.", + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to="baserow_enterprise.assistantchatmessage", + ), + ), + ], + options={ + "abstract": False, + }, + ), + ] diff --git a/enterprise/backend/tests/baserow_enterprise_tests/api/assistant/test_assistant_views.py b/enterprise/backend/tests/baserow_enterprise_tests/api/assistant/test_assistant_views.py index 06187ed86c..668e73d0ad 100644 --- a/enterprise/backend/tests/baserow_enterprise_tests/api/assistant/test_assistant_views.py +++ b/enterprise/backend/tests/baserow_enterprise_tests/api/assistant/test_assistant_views.py @@ -1,5 +1,4 @@ import json -from datetime import datetime, timezone from unittest.mock import MagicMock, patch from uuid 
import uuid4 @@ -10,7 +9,11 @@ from freezegun import freeze_time from baserow.test_utils.helpers import AnyStr -from baserow_enterprise.assistant.models import AssistantChat +from baserow_enterprise.assistant.models import ( + AssistantChat, + AssistantChatMessage, + AssistantChatPrediction, +) from baserow_enterprise.assistant.types import ( AiErrorMessage, AiMessage, @@ -455,10 +458,7 @@ def test_cannot_get_messages_from_another_users_chat( @pytest.mark.django_db @override_settings(DEBUG=True) -@patch("baserow_enterprise.api.assistant.views.AssistantHandler") -def test_get_messages_returns_chat_history( - mock_handler_class, api_client, enterprise_data_fixture -): +def test_get_messages_returns_chat_history(api_client, enterprise_data_fixture): """Test that the endpoint returns the chat message history""" user, token = enterprise_data_fixture.create_user_and_token() @@ -470,43 +470,34 @@ def test_get_messages_returns_chat_history( user=user, workspace=workspace, title="Test Chat" ) - # Mock the handler - mock_handler = MagicMock() - mock_handler_class.return_value = mock_handler - - # Mock get_chat to return the chat - mock_handler.get_chat.return_value = chat - # Mock message history - only HumanMessage and AiMessage are returned message_history = [ - HumanMessage( + AssistantChatMessage( id=1, + role=AssistantChatMessage.Role.HUMAN, content="What's the weather like?", - ui_context=UIContext( - workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), - user=UserUIContext(id=user.id, name=user.first_name, email=user.email), - ), + chat=chat, ), - AiMessage( + AssistantChatMessage( id=2, + role=AssistantChatMessage.Role.AI, content="I don't have access to real-time weather data.", - timestamp=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + chat=chat, ), - HumanMessage( + AssistantChatMessage( id=3, + role=AssistantChatMessage.Role.HUMAN, content="Can you help me with Baserow?", - ui_context=UIContext( - 
workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), - user=UserUIContext(id=user.id, name=user.first_name, email=user.email), - ), + chat=chat, ), - AiMessage( + AssistantChatMessage( id=4, + role=AssistantChatMessage.Role.AI, content="Of course! I'd be happy to help you with Baserow.", - timestamp=datetime(2024, 1, 1, 12, 1, 0, tzinfo=timezone.utc), + chat=chat, ), ] - mock_handler.list_chat_messages.return_value = message_history + AssistantChatMessage.objects.bulk_create(message_history) rsp = api_client.get( reverse( @@ -550,16 +541,11 @@ def test_get_messages_returns_chat_history( assert data["messages"][3]["id"] == 4 assert "timestamp" in data["messages"][3] - # Verify handler was called correctly - mock_handler.get_chat.assert_called_once_with(user, chat.uuid) - mock_handler.list_chat_messages.assert_called_once_with(chat) - @pytest.mark.django_db @override_settings(DEBUG=True) -@patch("baserow_enterprise.api.assistant.views.AssistantHandler") def test_get_messages_returns_empty_list_for_new_chat( - mock_handler_class, api_client, enterprise_data_fixture + api_client, enterprise_data_fixture ): """Test that the endpoint returns an empty list for a chat with no messages""" @@ -572,16 +558,6 @@ def test_get_messages_returns_empty_list_for_new_chat( user=user, workspace=workspace, title="Empty Chat" ) - # Mock the handler - mock_handler = MagicMock() - mock_handler_class.return_value = mock_handler - - # Mock get_chat to return the chat - mock_handler.get_chat.return_value = chat - - # Mock empty message history - mock_handler.get_chat_messages.return_value = [] - rsp = api_client.get( reverse( "assistant:chat_messages", @@ -599,10 +575,7 @@ def test_get_messages_returns_empty_list_for_new_chat( @pytest.mark.django_db @override_settings(DEBUG=True) -@patch("baserow_enterprise.api.assistant.views.AssistantHandler") -def test_get_messages_with_different_message_types( - mock_handler_class, api_client, enterprise_data_fixture -): +def 
test_get_messages_with_different_message_types(api_client, enterprise_data_fixture): """Test that the endpoint correctly handles different message types""" user, token = enterprise_data_fixture.create_user_and_token() @@ -614,43 +587,31 @@ def test_get_messages_with_different_message_types( user=user, workspace=workspace, title="Test Chat" ) - # Mock the handler - mock_handler = MagicMock() - mock_handler_class.return_value = mock_handler - - # Mock get_chat to return the chat - mock_handler.get_chat.return_value = chat - # Mock message history - only HumanMessage and AiMessage are returned message_history = [ - HumanMessage( - id=1, - content="Hello", - ui_context=UIContext( - workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), - user=UserUIContext(id=user.id, name=user.first_name, email=user.email), - ), + AssistantChatMessage( + id=1, role=AssistantChatMessage.Role.HUMAN, content="Hello", chat=chat ), - AiMessage( + AssistantChatMessage( id=2, + role=AssistantChatMessage.Role.AI, content="Hi there! 
How can I help you?", - timestamp=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), + chat=chat, ), - HumanMessage( + AssistantChatMessage( id=3, + role=AssistantChatMessage.Role.HUMAN, content="Tell me about Baserow", - ui_context=UIContext( - workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), - user=UserUIContext(id=user.id, name=user.first_name, email=user.email), - ), + chat=chat, ), - AiMessage( + AssistantChatMessage( id=4, + role=AssistantChatMessage.Role.AI, content="Baserow is an open-source no-code database platform.", - timestamp=datetime(2024, 1, 1, 12, 1, 0, tzinfo=timezone.utc), + chat=chat, ), ] - mock_handler.list_chat_messages.return_value = message_history + AssistantChatMessage.objects.bulk_create(message_history) rsp = api_client.get( reverse( @@ -691,6 +652,226 @@ def test_get_messages_with_different_message_types( assert "timestamp" in data["messages"][3] +@pytest.mark.django_db +@override_settings(DEBUG=True) +def test_get_messages_includes_can_submit_feedback_field( + api_client, enterprise_data_fixture +): + """ + Test that AI messages include can_submit_feedback field based on prediction state + """ + + user, token = enterprise_data_fixture.create_user_and_token() + workspace = enterprise_data_fixture.create_workspace(user=user) + enterprise_data_fixture.enable_enterprise() + + # Create a chat with messages + chat = AssistantChat.objects.create( + user=user, workspace=workspace, title="Test Chat" + ) + + # Create human message + human_message_1 = AssistantChatMessage.objects.create( + chat=chat, + role=AssistantChatMessage.Role.HUMAN, + content="First question", + ) + + # Create AI message WITH prediction (no feedback yet) + ai_message_1 = AssistantChatMessage.objects.create( + chat=chat, + role=AssistantChatMessage.Role.AI, + content="First answer", + ) + AssistantChatPrediction.objects.create( + human_message=human_message_1, + ai_response=ai_message_1, + prediction={"reasoning": "test"}, + ) + + # Create second human 
message + human_message_2 = AssistantChatMessage.objects.create( + chat=chat, + role=AssistantChatMessage.Role.HUMAN, + content="Second question", + ) + + # Create AI message WITHOUT prediction + ai_message_2 = AssistantChatMessage.objects.create( + chat=chat, + role=AssistantChatMessage.Role.AI, + content="Second answer", + ) + + # Create third human message + human_message_3 = AssistantChatMessage.objects.create( + chat=chat, + role=AssistantChatMessage.Role.HUMAN, + content="Third question", + ) + + # Create AI message WITH prediction AND existing feedback + ai_message_3 = AssistantChatMessage.objects.create( + chat=chat, + role=AssistantChatMessage.Role.AI, + content="Third answer", + ) + AssistantChatPrediction.objects.create( + human_message=human_message_3, + ai_response=ai_message_3, + prediction={"reasoning": "test"}, + human_sentiment=1, # Already has feedback + human_feedback="Great answer", + ) + + rsp = api_client.get( + reverse( + "assistant:chat_messages", + kwargs={"chat_uuid": str(chat.uuid)}, + ), + HTTP_AUTHORIZATION=f"JWT {token}", + ) + + assert rsp.status_code == 200 + data = rsp.json() + + assert len(data["messages"]) == 6 + + assert data["messages"][0]["type"] == "human" + assert "can_submit_feedback" not in data["messages"][0] + assert "human_sentiment" not in data["messages"][0] + + # First AI message: has prediction, no feedback yet -> can submit + assert data["messages"][1]["type"] == "ai/message" + assert data["messages"][1]["can_submit_feedback"] is True + assert data["messages"][1]["human_sentiment"] is None + + assert data["messages"][2]["type"] == "human" + assert "can_submit_feedback" not in data["messages"][2] + assert "human_sentiment" not in data["messages"][2] + + # Second AI message: no prediction -> cannot submit + assert data["messages"][3]["type"] == "ai/message" + assert data["messages"][3]["can_submit_feedback"] is False + assert data["messages"][3]["human_sentiment"] is None + + assert data["messages"][4]["type"] == 
"human" + assert "can_submit_feedback" not in data["messages"][4] + assert "human_sentiment" not in data["messages"][4] + + # Third AI message: has prediction with existing feedback + assert data["messages"][5]["type"] == "ai/message" + assert data["messages"][5]["can_submit_feedback"] is True + assert data["messages"][5]["human_sentiment"] == "LIKE" + + +@pytest.mark.django_db +@override_settings(DEBUG=True) +def test_get_messages_includes_human_sentiment_when_feedback_exists( + api_client, enterprise_data_fixture +): + """Test that human_sentiment is included in AI messages when feedback exists""" + + user, token = enterprise_data_fixture.create_user_and_token() + workspace = enterprise_data_fixture.create_workspace(user=user) + enterprise_data_fixture.enable_enterprise() + + # Create a chat + chat = AssistantChat.objects.create( + user=user, workspace=workspace, title="Test Chat" + ) + + # Create messages with LIKE feedback + human_message_1 = AssistantChatMessage.objects.create( + chat=chat, + role=AssistantChatMessage.Role.HUMAN, + content="Question 1", + ) + ai_message_1 = AssistantChatMessage.objects.create( + chat=chat, + role=AssistantChatMessage.Role.AI, + content="Answer 1", + ) + AssistantChatPrediction.objects.create( + human_message=human_message_1, + ai_response=ai_message_1, + prediction={"reasoning": "test"}, + human_sentiment=1, # LIKE + human_feedback="Very helpful", + ) + + # Create messages with DISLIKE feedback + human_message_2 = AssistantChatMessage.objects.create( + chat=chat, + role=AssistantChatMessage.Role.HUMAN, + content="Question 2", + ) + ai_message_2 = AssistantChatMessage.objects.create( + chat=chat, + role=AssistantChatMessage.Role.AI, + content="Answer 2", + ) + AssistantChatPrediction.objects.create( + human_message=human_message_2, + ai_response=ai_message_2, + prediction={"reasoning": "test"}, + human_sentiment=-1, # DISLIKE + human_feedback="Not accurate", + ) + + message_history = [ + HumanMessage( + id=human_message_1.id, + 
content="Question 1", + ui_context=UIContext( + workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), + user=UserUIContext(id=user.id, name=user.first_name, email=user.email), + ), + ), + AiMessage( + id=ai_message_1.id, + content="Answer 1", + can_submit_feedback=False, + human_sentiment="LIKE", + ), + HumanMessage( + id=human_message_2.id, + content="Question 2", + ui_context=UIContext( + workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), + user=UserUIContext(id=user.id, name=user.first_name, email=user.email), + ), + ), + AiMessage( + id=ai_message_2.id, + content="Answer 2", + can_submit_feedback=False, + human_sentiment="DISLIKE", + ), + ] + + rsp = api_client.get( + reverse( + "assistant:chat_messages", + kwargs={"chat_uuid": str(chat.uuid)}, + ), + HTTP_AUTHORIZATION=f"JWT {token}", + ) + + assert rsp.status_code == 200 + data = rsp.json() + + # First AI message: LIKE sentiment + assert data["messages"][1]["type"] == "ai/message" + assert data["messages"][1]["human_sentiment"] == "LIKE" + assert data["messages"][1]["can_submit_feedback"] is True + + # Second AI message: DISLIKE sentiment + assert data["messages"][3]["type"] == "ai/message" + assert data["messages"][3]["human_sentiment"] == "DISLIKE" + assert data["messages"][3]["can_submit_feedback"] is True + + @pytest.mark.django_db @override_settings(DEBUG=True) @patch("baserow_enterprise.assistant.handler.Assistant") @@ -1467,3 +1648,522 @@ async def mock_astream(human_message): assert received_message is not None assert received_message.ui_context.dashboard.id == "dash-789" assert received_message.ui_context.dashboard.name == "Sales Dashboard" + + +# ============================================================================= +# Tests for AssistantChatMessageFeedbackView +# ============================================================================= + + +@pytest.mark.django_db +@override_settings(DEBUG=True) +def test_submit_feedback_with_like_sentiment(api_client, 
enterprise_data_fixture): + """Test submitting positive feedback (LIKE) for a message""" + + user, token = enterprise_data_fixture.create_user_and_token() + workspace = enterprise_data_fixture.create_workspace(user=user) + enterprise_data_fixture.enable_enterprise() + + # Create a chat with messages and prediction + chat = AssistantChat.objects.create( + user=user, workspace=workspace, title="Test Chat" + ) + + # Create human message + human_message = AssistantChatMessage.objects.create( + chat=chat, + role=AssistantChatMessage.Role.HUMAN, + content="Hello", + ) + + # Create AI message + ai_message = AssistantChatMessage.objects.create( + chat=chat, + role=AssistantChatMessage.Role.AI, + content="Hi there!", + ) + + # Create prediction + prediction = AssistantChatPrediction.objects.create( + human_message=human_message, + ai_response=ai_message, + prediction={"reasoning": "test"}, + ) + + # Submit feedback + rsp = api_client.put( + reverse("assistant:message_feedback", kwargs={"message_id": ai_message.id}), + data={"sentiment": "LIKE"}, + format="json", + HTTP_AUTHORIZATION=f"JWT {token}", + ) + + assert rsp.status_code == 204 + + # Verify feedback was saved + prediction.refresh_from_db() + assert prediction.human_sentiment == 1 # LIKE = 1 + assert prediction.human_feedback == "" + + +@pytest.mark.django_db +@override_settings(DEBUG=True) +def test_submit_feedback_with_dislike_sentiment_and_text( + api_client, enterprise_data_fixture +): + """Test submitting negative feedback (DISLIKE) with feedback text""" + + user, token = enterprise_data_fixture.create_user_and_token() + workspace = enterprise_data_fixture.create_workspace(user=user) + enterprise_data_fixture.enable_enterprise() + + # Create chat and messages + chat = AssistantChat.objects.create( + user=user, workspace=workspace, title="Test Chat" + ) + human_message = AssistantChatMessage.objects.create( + chat=chat, role=AssistantChatMessage.Role.HUMAN, content="Question" + ) + ai_message = 
AssistantChatMessage.objects.create( + chat=chat, role=AssistantChatMessage.Role.AI, content="Answer" + ) + prediction = AssistantChatPrediction.objects.create( + human_message=human_message, + ai_response=ai_message, + prediction={"reasoning": "test"}, + ) + + # Submit negative feedback with text + feedback_text = "The answer was not helpful" + rsp = api_client.put( + reverse("assistant:message_feedback", kwargs={"message_id": ai_message.id}), + data={"sentiment": "DISLIKE", "feedback": feedback_text}, + format="json", + HTTP_AUTHORIZATION=f"JWT {token}", + ) + + assert rsp.status_code == 204 + + # Verify feedback was saved + prediction.refresh_from_db() + assert prediction.human_sentiment == -1 # DISLIKE = -1 + assert prediction.human_feedback == feedback_text + + +@pytest.mark.django_db +@override_settings(DEBUG=True) +def test_update_existing_feedback(api_client, enterprise_data_fixture): + """Test updating feedback that was already submitted""" + + user, token = enterprise_data_fixture.create_user_and_token() + workspace = enterprise_data_fixture.create_workspace(user=user) + enterprise_data_fixture.enable_enterprise() + + # Create chat and messages with existing feedback + chat = AssistantChat.objects.create( + user=user, workspace=workspace, title="Test Chat" + ) + human_message = AssistantChatMessage.objects.create( + chat=chat, role=AssistantChatMessage.Role.HUMAN, content="Question" + ) + ai_message = AssistantChatMessage.objects.create( + chat=chat, role=AssistantChatMessage.Role.AI, content="Answer" + ) + prediction = AssistantChatPrediction.objects.create( + human_message=human_message, + ai_response=ai_message, + prediction={"reasoning": "test"}, + human_sentiment=1, # Initially LIKE + human_feedback="Was helpful", + ) + + # Update to DISLIKE with new feedback + new_feedback = "Actually, it wasn't accurate" + rsp = api_client.put( + reverse("assistant:message_feedback", kwargs={"message_id": ai_message.id}), + data={"sentiment": "DISLIKE", "feedback": 
new_feedback}, + format="json", + HTTP_AUTHORIZATION=f"JWT {token}", + ) + + assert rsp.status_code == 204 + + # Verify feedback was updated + prediction.refresh_from_db() + assert prediction.human_sentiment == -1 # Changed to DISLIKE + assert prediction.human_feedback == new_feedback + + +@pytest.mark.django_db +@override_settings(DEBUG=True) +def test_submit_feedback_with_null_sentiment(api_client, enterprise_data_fixture): + """Test clearing/removing feedback by setting sentiment to null""" + + user, token = enterprise_data_fixture.create_user_and_token() + workspace = enterprise_data_fixture.create_workspace(user=user) + enterprise_data_fixture.enable_enterprise() + + # Create chat and messages with existing feedback + chat = AssistantChat.objects.create( + user=user, workspace=workspace, title="Test Chat" + ) + human_message = AssistantChatMessage.objects.create( + chat=chat, role=AssistantChatMessage.Role.HUMAN, content="Question" + ) + ai_message = AssistantChatMessage.objects.create( + chat=chat, role=AssistantChatMessage.Role.AI, content="Answer" + ) + prediction = AssistantChatPrediction.objects.create( + human_message=human_message, + ai_response=ai_message, + prediction={"reasoning": "test"}, + human_sentiment=1, + human_feedback="Was helpful", + ) + + # Clear feedback by sending null sentiment + rsp = api_client.put( + reverse("assistant:message_feedback", kwargs={"message_id": ai_message.id}), + data={"sentiment": None}, + format="json", + HTTP_AUTHORIZATION=f"JWT {token}", + ) + + assert rsp.status_code == 204 + + # Verify feedback was cleared + prediction.refresh_from_db() + assert prediction.human_sentiment is None + assert prediction.human_feedback == "" + + +@pytest.mark.django_db +@override_settings(DEBUG=True) +def test_cannot_submit_feedback_for_message_without_prediction( + api_client, enterprise_data_fixture +): + """Test that submitting feedback fails if message has no prediction""" + + user, token = 
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_cannot_submit_feedback_for_nonexistent_message(
    api_client, enterprise_data_fixture
):
    """Feedback for a message id that was never created yields a 404.

    No chat or message is created; the request targets id ``999999`` and the
    response must carry the ``ERROR_ASSISTANT_CHAT_DOES_NOT_EXIST`` code.
    """

    _, token = enterprise_data_fixture.create_user_and_token()
    enterprise_data_fixture.enable_enterprise()

    missing_message_url = reverse(
        "assistant:message_feedback", kwargs={"message_id": 999999}
    )
    response = api_client.put(
        missing_message_url,
        data={"sentiment": "LIKE"},
        format="json",
        HTTP_AUTHORIZATION=f"JWT {token}",
    )

    assert response.status_code == 404
    assert response.json()["error"] == "ERROR_ASSISTANT_CHAT_DOES_NOT_EXIST"
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_cannot_submit_feedback_without_license(api_client, enterprise_data_fixture):
    """Without an enabled enterprise license the endpoint answers 402.

    The message and its prediction exist and belong to the caller, so the
    only thing missing is the ``enable_enterprise()`` call.
    """

    user, token = enterprise_data_fixture.create_user_and_token()
    workspace = enterprise_data_fixture.create_workspace(user=user)
    # Deliberately no enterprise_data_fixture.enable_enterprise() here.

    chat = AssistantChat.objects.create(
        user=user, workspace=workspace, title="Test Chat"
    )
    question = AssistantChatMessage.objects.create(
        chat=chat, role=AssistantChatMessage.Role.HUMAN, content="Question"
    )
    answer = AssistantChatMessage.objects.create(
        chat=chat, role=AssistantChatMessage.Role.AI, content="Answer"
    )
    AssistantChatPrediction.objects.create(
        human_message=question,
        ai_response=answer,
        prediction={"reasoning": "test"},
    )

    feedback_url = reverse(
        "assistant:message_feedback", kwargs={"message_id": answer.id}
    )
    response = api_client.put(
        feedback_url,
        data={"sentiment": "LIKE"},
        format="json",
        HTTP_AUTHORIZATION=f"JWT {token}",
    )

    assert response.status_code == 402
    assert response.json()["error"] == "ERROR_FEATURE_NOT_AVAILABLE"
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_submit_feedback_without_feedback_text(api_client, enterprise_data_fixture):
    """A DISLIKE sent without a ``feedback`` key is accepted.

    The request body contains only the sentiment; the stored sentiment must
    be ``-1`` and the stored feedback text must be the empty string.
    """

    user, token = enterprise_data_fixture.create_user_and_token()
    workspace = enterprise_data_fixture.create_workspace(user=user)
    enterprise_data_fixture.enable_enterprise()

    # Unrated question/answer pair with an attached prediction.
    chat = AssistantChat.objects.create(
        user=user, workspace=workspace, title="Test Chat"
    )
    question = AssistantChatMessage.objects.create(
        chat=chat, role=AssistantChatMessage.Role.HUMAN, content="Question"
    )
    answer = AssistantChatMessage.objects.create(
        chat=chat, role=AssistantChatMessage.Role.AI, content="Answer"
    )
    unrated_prediction = AssistantChatPrediction.objects.create(
        human_message=question,
        ai_response=answer,
        prediction={"reasoning": "test"},
    )

    feedback_url = reverse(
        "assistant:message_feedback", kwargs={"message_id": answer.id}
    )
    sentiment_only = {"sentiment": "DISLIKE"}
    response = api_client.put(
        feedback_url,
        data=sentiment_only,
        format="json",
        HTTP_AUTHORIZATION=f"JWT {token}",
    )

    assert response.status_code == 204

    # Re-read the row so the assertions see what was persisted.
    unrated_prediction.refresh_from_db()
    assert unrated_prediction.human_sentiment == -1
    assert unrated_prediction.human_feedback == ""
@pytest.mark.django_db
@override_settings(DEBUG=True)
def test_submit_feedback_toggles_sentiment_from_like_to_dislike(
    api_client, enterprise_data_fixture
):
    """An existing LIKE (``1``) can be overwritten by a DISLIKE (``-1``).

    The DISLIKE request also carries feedback text, which must be stored
    verbatim alongside the new sentiment.
    """

    user, token = enterprise_data_fixture.create_user_and_token()
    workspace = enterprise_data_fixture.create_workspace(user=user)
    enterprise_data_fixture.enable_enterprise()

    chat = AssistantChat.objects.create(
        user=user, workspace=workspace, title="Test Chat"
    )
    question = AssistantChatMessage.objects.create(
        chat=chat, role=AssistantChatMessage.Role.HUMAN, content="Question"
    )
    answer = AssistantChatMessage.objects.create(
        chat=chat, role=AssistantChatMessage.Role.AI, content="Answer"
    )
    # The prediction starts out liked, with no feedback text supplied.
    liked_prediction = AssistantChatPrediction.objects.create(
        human_message=question,
        ai_response=answer,
        prediction={"reasoning": "test"},
        human_sentiment=1,
    )

    feedback_url = reverse(
        "assistant:message_feedback", kwargs={"message_id": answer.id}
    )
    dislike_payload = {"sentiment": "DISLIKE", "feedback": "Changed my mind"}
    response = api_client.put(
        feedback_url,
        data=dislike_payload,
        format="json",
        HTTP_AUTHORIZATION=f"JWT {token}",
    )

    assert response.status_code == 204

    # Re-read the row so the assertions see what was persisted.
    liked_prediction.refresh_from_db()
    assert liked_prediction.human_sentiment == -1
    assert liked_prediction.human_feedback == "Changed my mind"
callbacks.extend_sources(["https://example.com/b"]) @@ -76,7 +76,7 @@ def test_extend_sources_preserves_order(self): def test_on_tool_end_extracts_sources_from_outputs(self): """Test that sources are extracted from tool outputs""" - callbacks = AssistantCallbacks() + callbacks = get_assistant_callbacks() # Mock tool instance and inputs tool_instance = MagicMock() @@ -107,7 +107,7 @@ def test_on_tool_end_extracts_sources_from_outputs(self): def test_on_tool_end_handles_missing_sources(self): """Test that tool outputs without sources don't cause errors""" - callbacks = AssistantCallbacks() + callbacks = get_assistant_callbacks() tool_instance = MagicMock() tool_instance.name = "some_tool" @@ -284,8 +284,8 @@ def test_aload_chat_history_handles_incomplete_pairs(self, enterprise_data_fixtu class TestAssistantMessagePersistence: """Test that messages are persisted correctly during streaming""" - @patch("baserow_enterprise.assistant.assistant.dspy.streamify") - @patch("baserow_enterprise.assistant.assistant.dspy.LM") + @patch("dspy.streamify") + @patch("dspy.LM") def test_astream_messages_persists_human_message( self, mock_lm, mock_streamify, enterprise_data_fixture ): @@ -306,10 +306,13 @@ async def mock_stream(*args, **kwargs): predict_name="ReAct", is_last_chunk=False, ) - yield Prediction(answer="Hello") + yield Prediction(answer="Hello", trajectory=[], reasoning="") mock_streamify.return_value = MagicMock(return_value=mock_stream()) + # Configure mock LM to return a serializable model name + mock_lm.return_value.model = "test-model" + assistant = Assistant(chat) ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), @@ -335,8 +338,8 @@ async def consume_stream(): ).first() assert saved_message.content == "Test message" - @patch("baserow_enterprise.assistant.assistant.dspy.streamify") - @patch("baserow_enterprise.assistant.assistant.dspy.LM") + @patch("dspy.streamify") + @patch("dspy.LM") def 
test_astream_messages_persists_ai_message_with_sources( self, mock_lm, mock_streamify, enterprise_data_fixture ): @@ -356,10 +359,13 @@ async def mock_stream(*args, **kwargs): predict_name="ReAct", is_last_chunk=False, ) - yield Prediction(answer="Based on docs") + yield Prediction(answer="Based on docs", trajectory=[], reasoning="") mock_streamify.return_value = MagicMock(return_value=mock_stream()) + # Configure mock LM to return a serializable model name + mock_lm.return_value.model = "test-model" + assistant = Assistant(chat) ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), @@ -382,14 +388,12 @@ async def consume_stream(): ).count() assert ai_messages == 1 - @patch("baserow_enterprise.assistant.assistant.ensure_llm_model_accessible") - @patch("baserow_enterprise.assistant.assistant.dspy.streamify") - @patch("baserow_enterprise.assistant.assistant.dspy.Predict") + @patch("dspy.streamify") + @patch("dspy.Predict") def test_astream_messages_persists_chat_title( self, mock_predict_class, mock_streamify, - mock_ensure_llm, enterprise_data_fixture, ): """Test that chat titles are persisted to the database""" @@ -408,7 +412,7 @@ async def mock_stream(*args, **kwargs): predict_name="ReAct", is_last_chunk=False, ) - yield Prediction(answer="Hello") + yield Prediction(answer="Hello", trajectory=[], reasoning="") mock_streamify.return_value = MagicMock(return_value=mock_stream()) @@ -445,8 +449,8 @@ async def consume_stream(): class TestAssistantStreaming: """Test streaming behavior of the Assistant""" - @patch("baserow_enterprise.assistant.assistant.dspy.streamify") - @patch("baserow_enterprise.assistant.assistant.dspy.LM") + @patch("dspy.streamify") + @patch("dspy.LM") def test_astream_messages_yields_answer_chunks( self, mock_lm, mock_streamify, enterprise_data_fixture ): @@ -472,10 +476,13 @@ async def mock_stream(*args, **kwargs): predict_name="ReAct", is_last_chunk=False, ) - yield Prediction(answer="Hello world") + yield 
Prediction(answer="Hello world", trajectory=[], reasoning="") mock_streamify.return_value = MagicMock(return_value=mock_stream()) + # Configure mock LM to return a serializable model name + mock_lm.return_value.model = "test-model" + assistant = Assistant(chat) ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), @@ -498,14 +505,12 @@ async def consume_stream(): assert chunks[1].content == "Hello world" assert chunks[2].content == "Hello world" # Final chunk repeats full answer - @patch("baserow_enterprise.assistant.assistant.ensure_llm_model_accessible") - @patch("baserow_enterprise.assistant.assistant.dspy.streamify") - @patch("baserow_enterprise.assistant.assistant.dspy.Predict") + @patch("dspy.streamify") + @patch("dspy.Predict") def test_astream_messages_yields_title_chunks( self, mock_predict_class, mock_streamify, - mock_ensure_llm, enterprise_data_fixture, ): """Test that title chunks are yielded for new chats""" @@ -524,7 +529,7 @@ async def mock_stream(*args, **kwargs): predict_name="ReAct", is_last_chunk=False, ) - yield Prediction(answer="Answer") + yield Prediction(answer="Answer", trajectory=[], reasoning="") mock_streamify.return_value = MagicMock(return_value=mock_stream()) @@ -556,8 +561,8 @@ async def consume_stream(): assert len(title_messages) == 1 assert title_messages[0].content == "Title" - @patch("baserow_enterprise.assistant.assistant.dspy.streamify") - @patch("baserow_enterprise.assistant.assistant.dspy.LM") + @patch("dspy.streamify") + @patch("dspy.LM") def test_astream_messages_yields_thinking_messages( self, mock_lm, mock_streamify, enterprise_data_fixture ): @@ -578,10 +583,13 @@ async def mock_stream(*args, **kwargs): predict_name="ReAct", is_last_chunk=False, ) - yield Prediction(answer="Answer") + yield Prediction(answer="Answer", trajectory=[], reasoning="") mock_streamify.return_value = MagicMock(return_value=mock_stream()) + # Configure mock LM to return a serializable model name + 
mock_lm.return_value.model = "test-model" + assistant = Assistant(chat) ui_context = UIContext( workspace=WorkspaceUIContext(id=workspace.id, name=workspace.name), diff --git a/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_handler.py b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_handler.py new file mode 100644 index 0000000000..2fa77bbc74 --- /dev/null +++ b/enterprise/backend/tests/baserow_enterprise_tests/assistant/test_assistant_handler.py @@ -0,0 +1,378 @@ +from datetime import datetime, timedelta, timezone + +import pytest + +from baserow_enterprise.assistant.handler import AssistantHandler +from baserow_enterprise.assistant.models import ( + AssistantChat, + AssistantChatMessage, + AssistantChatPrediction, +) + + +@pytest.mark.django_db +def test_delete_predictions_removes_old_unrated_predictions(enterprise_data_fixture): + """Test that old predictions without sentiment are deleted.""" + + user = enterprise_data_fixture.create_user() + workspace = enterprise_data_fixture.create_workspace(user=user) + + # Create a chat + chat = AssistantChat.objects.create(user=user, workspace=workspace) + + # Create old predictions (older than 30 days) without sentiment + old_date = datetime.now(timezone.utc) - timedelta(days=35) + + human_msg_1 = AssistantChatMessage.objects.create( + chat=chat, role="human", content="Question 1", created_on=old_date + ) + ai_msg_1 = AssistantChatMessage.objects.create( + chat=chat, role="ai", content="Answer 1", created_on=old_date + ) + prediction_1 = AssistantChatPrediction.objects.create( + human_message=human_msg_1, + ai_response=ai_msg_1, + prediction={"test": "data"}, + ) + prediction_1.created_on = old_date + prediction_1.save() + + human_msg_2 = AssistantChatMessage.objects.create( + chat=chat, role="human", content="Question 2", created_on=old_date + ) + ai_msg_2 = AssistantChatMessage.objects.create( + chat=chat, role="ai", content="Answer 2", created_on=old_date + ) + 
@pytest.mark.django_db
def test_delete_predictions_preserves_recent_predictions(enterprise_data_fixture):
    """A 10-day-old prediction survives a 30-day cleanup."""

    user = enterprise_data_fixture.create_user()
    workspace = enterprise_data_fixture.create_workspace(user=user)
    chat = AssistantChat.objects.create(user=user, workspace=workspace)

    # Ten days old: inside the 30-day retention window used below.
    ten_days_ago = datetime.now(timezone.utc) - timedelta(days=10)

    question = AssistantChatMessage.objects.create(
        chat=chat, role="human", content="Question", created_on=ten_days_ago
    )
    answer = AssistantChatMessage.objects.create(
        chat=chat, role="ai", content="Answer", created_on=ten_days_ago
    )
    recent_prediction = AssistantChatPrediction.objects.create(
        human_message=question,
        ai_response=answer,
        prediction={"test": "data"},
    )
    # The timestamp is written in a second step after creation
    # (presumably created_on is auto-populated on insert — TODO confirm).
    recent_prediction.created_on = ten_days_ago
    recent_prediction.save()

    removed_count, _ = AssistantHandler().delete_predictions(older_than_days=30)

    # Nothing is old enough to be removed.
    assert removed_count == 0
    assert AssistantChatPrediction.objects.count() == 1
@pytest.mark.django_db
def test_delete_predictions_includes_rated_when_specified(enterprise_data_fixture):
    """
    With ``exclude_rated=False`` an old prediction is deleted even though it
    carries a sentiment and feedback text.
    """

    user = enterprise_data_fixture.create_user()
    workspace = enterprise_data_fixture.create_workspace(user=user)
    chat = AssistantChat.objects.create(user=user, workspace=workspace)

    # 35 days old: past the 30-day threshold used below.
    thirty_five_days_ago = datetime.now(timezone.utc) - timedelta(days=35)

    question = AssistantChatMessage.objects.create(
        chat=chat, role="human", content="Question", created_on=thirty_five_days_ago
    )
    answer = AssistantChatMessage.objects.create(
        chat=chat, role="ai", content="Answer", created_on=thirty_five_days_ago
    )
    rated_prediction = AssistantChatPrediction.objects.create(
        human_message=question,
        ai_response=answer,
        prediction={"test": "data"},
        human_sentiment=AssistantChatPrediction.SENTIMENT_MAP["LIKE"],
        human_feedback="Great answer!",
    )
    # The timestamp is written in a second step after creation
    # (presumably created_on is auto-populated on insert — TODO confirm).
    rated_prediction.created_on = thirty_five_days_ago
    rated_prediction.save()

    cleanup = AssistantHandler()
    removed_count, _ = cleanup.delete_predictions(
        older_than_days=30, exclude_rated=False
    )

    # The rating does not protect the row when exclude_rated is False.
    assert removed_count == 1
    assert AssistantChatPrediction.objects.count() == 0
+ """ + + user = enterprise_data_fixture.create_user() + workspace = enterprise_data_fixture.create_workspace(user=user) + + chat = AssistantChat.objects.create(user=user, workspace=workspace) + old_date = datetime.now(timezone.utc) - timedelta(days=35) + recent_date = datetime.now(timezone.utc) - timedelta(days=10) + + # Old + unrated (should be deleted) + human_msg_1 = AssistantChatMessage.objects.create( + chat=chat, role="human", content="Q1", created_on=old_date + ) + ai_msg_1 = AssistantChatMessage.objects.create( + chat=chat, role="ai", content="A1", created_on=old_date + ) + pred_1 = AssistantChatPrediction.objects.create( + human_message=human_msg_1, ai_response=ai_msg_1, prediction={} + ) + pred_1.created_on = old_date + pred_1.save() + + # Old + rated (should NOT be deleted) + human_msg_2 = AssistantChatMessage.objects.create( + chat=chat, role="human", content="Q2", created_on=old_date + ) + ai_msg_2 = AssistantChatMessage.objects.create( + chat=chat, role="ai", content="A2", created_on=old_date + ) + pred_2 = AssistantChatPrediction.objects.create( + human_message=human_msg_2, + ai_response=ai_msg_2, + prediction={}, + human_sentiment=AssistantChatPrediction.SENTIMENT_MAP["LIKE"], + ) + pred_2.created_on = old_date + pred_2.save() + + # Recent + unrated (should NOT be deleted) + human_msg_3 = AssistantChatMessage.objects.create( + chat=chat, role="human", content="Q3", created_on=recent_date + ) + ai_msg_3 = AssistantChatMessage.objects.create( + chat=chat, role="ai", content="A3", created_on=recent_date + ) + pred_3 = AssistantChatPrediction.objects.create( + human_message=human_msg_3, ai_response=ai_msg_3, prediction={} + ) + pred_3.created_on = recent_date + pred_3.save() + + # Recent + rated (should NOT be deleted) + human_msg_4 = AssistantChatMessage.objects.create( + chat=chat, role="human", content="Q4", created_on=recent_date + ) + ai_msg_4 = AssistantChatMessage.objects.create( + chat=chat, role="ai", content="A4", created_on=recent_date + ) + 
pred_4 = AssistantChatPrediction.objects.create( + human_message=human_msg_4, + ai_response=ai_msg_4, + prediction={}, + human_sentiment=AssistantChatPrediction.SENTIMENT_MAP["DISLIKE"], + ) + pred_4.created_on = recent_date + pred_4.save() + + handler = AssistantHandler() + deleted_count, _ = handler.delete_predictions( + older_than_days=30, exclude_rated=True + ) + + # Only old unrated should be deleted + assert deleted_count == 1 + assert AssistantChatPrediction.objects.count() == 3 + + # Verify the correct prediction was deleted + assert not AssistantChatPrediction.objects.filter(id=pred_1.id).exists() + assert AssistantChatPrediction.objects.filter(id=pred_2.id).exists() + assert AssistantChatPrediction.objects.filter(id=pred_3.id).exists() + assert AssistantChatPrediction.objects.filter(id=pred_4.id).exists() + + +@pytest.mark.django_db +def test_delete_predictions_custom_days_threshold(enterprise_data_fixture): + """Test deletion with different day thresholds.""" + + user = enterprise_data_fixture.create_user() + workspace = enterprise_data_fixture.create_workspace(user=user) + + chat = AssistantChat.objects.create(user=user, workspace=workspace) + + # Create predictions at different ages + very_old = datetime.now(timezone.utc) - timedelta(days=100) + medium_old = datetime.now(timezone.utc) - timedelta(days=50) + recent = datetime.now(timezone.utc) - timedelta(days=5) + + for age, label in [ + (very_old, "very_old"), + (medium_old, "medium"), + (recent, "recent"), + ]: + human_msg = AssistantChatMessage.objects.create( + chat=chat, role="human", content=f"Q {label}", created_on=age + ) + ai_msg = AssistantChatMessage.objects.create( + chat=chat, role="ai", content=f"A {label}", created_on=age + ) + pred = AssistantChatPrediction.objects.create( + human_message=human_msg, ai_response=ai_msg, prediction={} + ) + pred.created_on = age + pred.save() + + handler = AssistantHandler() + + # Delete predictions older than 60 days (should delete 1) + deleted_count, _ 
@pytest.mark.django_db
def test_delete_predictions_empty_database():
    """Cleaning up when no predictions exist reports zero deletions."""

    # The first element of the returned pair is the deletion count.
    removed_count, _ = AssistantHandler().delete_predictions(older_than_days=30)

    assert removed_count == 0
@pytest.mark.django_db
class TestDspyLazyLoading:
    """Verify that dspy is only loaded when Assistant is instantiated."""

    @staticmethod
    def _purge_llm_modules():
        """
        Remove every cached ``dspy*`` / ``litellm*`` entry from ``sys.modules``.

        Each test calls this first so that its assertions observe imports
        triggered by the test itself rather than by tests that ran earlier.
        """

        # Collect the names before deleting: sys.modules must not change size
        # while it is being iterated.
        stale_names = [
            name for name in sys.modules if name.startswith(("dspy", "litellm"))
        ]
        for name in stale_names:
            del sys.modules[name]

    def test_dspy_not_loaded_on_django_startup(self):
        """
        Test that dspy is NOT loaded when Django starts up.

        This is critical for memory efficiency - dspy should only be loaded
        when the AI Assistant feature is actually used.
        """

        self._purge_llm_modules()

        # Import the handler module (which is what gets imported at Django startup)
        # NOTE(review): if this module is already cached in sys.modules, the
        # import below does not re-execute it, so the assertions would pass
        # even with a top-level dspy import — confirm whether that is the case
        # in the test session.
        from baserow_enterprise.assistant import handler  # noqa: F401

        assert "dspy" not in sys.modules, (
            "dspy should not be loaded on import. "
            "Check for top-level dspy imports in assistant module files."
        )

        assert "litellm" not in sys.modules, (
            "litellm should not be loaded on import. "
            "Check for top-level litellm imports in assistant module files."
        )

    def test_dspy_loaded_when_assistant_created(
        self, data_fixture, enterprise_data_fixture
    ):
        """
        Test that dspy IS loaded when an Assistant object is created.

        This verifies that lazy loading works correctly and dspy is available
        when needed.
        """

        self._purge_llm_modules()

        user = data_fixture.create_user()
        workspace = data_fixture.create_workspace(user=user)
        enterprise_data_fixture.enable_enterprise()

        # Imported inside the test on purpose: importing the handler and the
        # models is itself part of what must not pull in dspy.
        from baserow_enterprise.assistant.handler import AssistantHandler
        from baserow_enterprise.assistant.models import AssistantChat

        assert (
            "dspy" not in sys.modules
        ), "dspy should not be loaded after importing handler"

        assert (
            "litellm" not in sys.modules
        ), "litellm should not be loaded after importing handler"

        chat = AssistantChat.objects.create(
            user=user,
            workspace=workspace,
        )

        # Building the Assistant is the step expected to import dspy.
        handler = AssistantHandler()
        assistant = handler.get_assistant(chat)

        assert "dspy" in sys.modules, (
            "dspy should be loaded after creating Assistant instance. "
            "Check that Assistant.__init__ imports dspy."
        )

        assert "litellm" in sys.modules, (
            "litellm should be loaded after creating Assistant instance. "
            "Check that Assistant.__init__ imports dspy."
        )

        assert assistant is not None

    def test_assistant_handler_does_not_load_dspy(self, data_fixture):
        """
        Test that using AssistantHandler methods (other than get_assistant)
        does not load dspy.
        """

        self._purge_llm_modules()

        user = data_fixture.create_user()
        workspace = data_fixture.create_workspace(user=user)

        from baserow_enterprise.assistant.handler import AssistantHandler

        handler = AssistantHandler()

        # Listing chats is a handler operation that should not need dspy.
        chats = handler.list_chats(user, workspace.id)
        assert chats is not None

        assert "dspy" not in sys.modules, (
            "dspy should not be loaded by AssistantHandler methods "
            "(except get_assistant)"
        )

        assert "litellm" not in sys.modules, (
            "litellm should not be loaded by AssistantHandler methods "
            "(except get_assistant)"
        )
align-items: center; + gap: 3px; + + i { + font-size: 13px; + } + + &:hover { + color: $palette-neutral-700; + } + + &--active { + color: $palette-neutral-1200; + } +} + +// Feedback Context (popup near thumb down) +.assistant__feedback-context { + min-width: 280px; + + textarea { + width: 100%; + min-height: 52px; + resize: vertical; + } +} + +.assistant__feedback-context-content { + padding: 12px; +} + +.assistant__feedback-context-actions { + display: flex; + justify-content: flex-end; + gap: 8px; + margin-top: 10px; +} diff --git a/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageActions.vue b/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageActions.vue new file mode 100644 index 0000000000..434e78c78b --- /dev/null +++ b/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageActions.vue @@ -0,0 +1,197 @@ + + + diff --git a/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue b/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue index 3fc1e75658..62ca4beec7 100644 --- a/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue +++ b/enterprise/web-frontend/modules/baserow_enterprise/components/assistant/AssistantMessageList.vue @@ -37,6 +37,8 @@ /> + + @@ -45,6 +47,7 @@