Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from rest_framework.status import HTTP_404_NOT_FOUND
from rest_framework.status import HTTP_400_BAD_REQUEST, HTTP_404_NOT_FOUND

ERROR_ASSISTANT_CHAT_DOES_NOT_EXIST = (
"ERROR_ASSISTANT_CHAT_DOES_NOT_EXIST",
Expand All @@ -9,11 +9,17 @@

ERROR_ASSISTANT_MODEL_NOT_SUPPORTED = (
"ERROR_ASSISTANT_MODEL_NOT_SUPPORTED",
400,
HTTP_400_BAD_REQUEST,
(
"The specified language model is not supported or the provided API key is missing/invalid. "
"Ensure you have set the correct provider API key and selected a compatible model in "
"`BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL`. See https://docs.litellm.ai/docs/providers for "
"supported models, required environment variables, and example configuration."
),
)

ERROR_CANNOT_SUBMIT_MESSAGE_FEEDBACK = (
"ERROR_CANNOT_SUBMIT_MESSAGE_FEEDBACK",
HTTP_400_BAD_REQUEST,
"This message cannot be submitted for feedback because it has no associated prediction.",
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from drf_spectacular.plumbing import force_instance
from rest_framework import serializers

from baserow_enterprise.assistant.models import AssistantChat
from baserow_enterprise.assistant.models import AssistantChat, AssistantChatPrediction
from baserow_enterprise.assistant.types import (
AssistantMessageType,
AssistantMessageUnion,
Expand Down Expand Up @@ -138,6 +138,19 @@ class AiMessageSerializer(serializers.Serializer):
"The list of relevant source URLs referenced in the knowledge. Can be empty or null."
),
)
can_submit_feedback = serializers.BooleanField(
default=False,
help_text=(
"Whether the user can submit feedback for this message. "
"Only true for messages with an associated prediction."
),
)
human_sentiment = serializers.ChoiceField(
required=False,
allow_null=True,
choices=["LIKE", "DISLIKE"],
help_text="The sentiment for the message, if it has been rated.",
)


class AiThinkingSerializer(serializers.Serializer):
Expand Down Expand Up @@ -295,3 +308,28 @@ def _map_serializer(self, auto_schema, direction, mapping):
},
},
}


class AssistantRateChatMessageSerializer(serializers.Serializer):
    """Validates a sentiment/feedback rating for an AI assistant chat message.

    The incoming ``sentiment`` choice is translated to the model-level value
    via ``AssistantChatPrediction.SENTIMENT_MAP``, and free-form ``feedback``
    is only kept for negative (DISLIKE) ratings.
    """

    sentiment = serializers.ChoiceField(
        required=True,
        allow_null=True,
        choices=["LIKE", "DISLIKE"],
        help_text="The sentiment for the message.",
    )
    feedback = serializers.CharField(
        help_text="Optional feedback about the message.",
        required=False,
        allow_blank=True,
        allow_null=True,
    )

    def to_internal_value(self, data):
        """Map the validated sentiment to its model value.

        Reads the sentiment from the already-validated data rather than the
        raw input mapping; the previous implementation accessed
        ``data["sentiment"]`` directly, which raises ``KeyError`` (a 500)
        whenever the key is absent from the raw payload.
        """

        validated_data = super().to_internal_value(data)
        # Validated sentiment is "LIKE", "DISLIKE" or None at this point.
        sentiment = validated_data.get("sentiment")
        validated_data["sentiment"] = AssistantChatPrediction.SENTIMENT_MAP.get(
            sentiment
        )
        # Additional feedback is only allowed for DISLIKE sentiment.
        if sentiment != "DISLIKE":
            validated_data["feedback"] = ""
        return validated_data
11 changes: 10 additions & 1 deletion enterprise/backend/src/baserow_enterprise/api/assistant/urls.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from django.urls import path

from .views import AssistantChatsView, AssistantChatView
from .views import (
AssistantChatMessageFeedbackView,
AssistantChatsView,
AssistantChatView,
)

app_name = "baserow_enterprise.api.assistant"

Expand All @@ -15,4 +19,9 @@
AssistantChatsView.as_view(),
name="list",
),
path(
"messages/<int:message_id>/feedback/",
AssistantChatMessageFeedbackView.as_view(),
name="message_feedback",
),
]
91 changes: 82 additions & 9 deletions enterprise/backend/src/baserow_enterprise/api/assistant/views.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import json
from urllib.request import Request
from uuid import uuid4

from django.http import StreamingHttpResponse

from baserow_premium.license.handler import LicenseHandler
from drf_spectacular.openapi import OpenApiParameter, OpenApiTypes
from drf_spectacular.utils import OpenApiResponse, extend_schema
from loguru import logger
from rest_framework.response import Response
from rest_framework.status import HTTP_204_NO_CONTENT
from rest_framework.views import APIView

from baserow.api.decorators import (
Expand All @@ -18,32 +21,38 @@
from baserow.api.pagination import LimitOffsetPagination
from baserow.api.schemas import get_error_schema
from baserow.api.serializers import get_example_pagination_serializer_class
from baserow.api.sessions import set_client_undo_redo_action_group_id
from baserow.core.exceptions import UserNotInWorkspace, WorkspaceDoesNotExist
from baserow.core.feature_flags import FF_ASSISTANT, feature_flag_is_enabled
from baserow.core.handler import CoreHandler
from baserow_enterprise.api.assistant.errors import (
ERROR_ASSISTANT_CHAT_DOES_NOT_EXIST,
ERROR_ASSISTANT_MODEL_NOT_SUPPORTED,
)
from baserow_enterprise.assistant.exceptions import (
AssistantChatDoesNotExist,
AssistantChatMessagePredictionDoesNotExist,
AssistantModelNotSupportedError,
)
from baserow_enterprise.assistant.handler import AssistantHandler
from baserow_enterprise.assistant.models import AssistantChatPrediction
from baserow_enterprise.assistant.operations import ChatAssistantChatOperationType
from baserow_enterprise.assistant.types import (
AiErrorMessage,
AssistantMessageUnion,
HumanMessage,
UIContext,
)
from baserow_enterprise.features import ASSISTANT

from .errors import (
ERROR_ASSISTANT_CHAT_DOES_NOT_EXIST,
ERROR_ASSISTANT_MODEL_NOT_SUPPORTED,
ERROR_CANNOT_SUBMIT_MESSAGE_FEEDBACK,
)
from .serializers import (
AssistantChatMessagesSerializer,
AssistantChatSerializer,
AssistantChatsRequestSerializer,
AssistantMessageRequestSerializer,
AssistantMessageSerializer,
AssistantRateChatMessageSerializer,
)


Expand Down Expand Up @@ -139,7 +148,6 @@ class AssistantChatView(APIView):
{
UserNotInWorkspace: ERROR_USER_NOT_IN_GROUP,
WorkspaceDoesNotExist: ERROR_GROUP_DOES_NOT_EXIST,
AssistantChatDoesNotExist: ERROR_ASSISTANT_CHAT_DOES_NOT_EXIST,
AssistantModelNotSupportedError: ERROR_ASSISTANT_MODEL_NOT_SUPPORTED,
}
)
Expand All @@ -164,16 +172,33 @@ def post(self, request: Request, chat_uuid: str, data) -> StreamingHttpResponse:

# Clearing the user websocket_id will make sure real-time updates are sent
chat.user.web_socket_id = None
# FIXME: As long as we don't allow users to change it, temporarily set the
# timezone to the one provided in the UI context

# Used to group all the actions done to produce this message together
# so they can be undone in one go.
set_client_undo_redo_action_group_id(chat.user, str(uuid4()))

# As long as we don't allow users to change it, temporarily set the timezone to
# the one provided in the UI context so tools can use it if needed.
chat.user.profile.timezone = ui_context.timezone

assistant = handler.get_assistant(chat)
assistant.check_llm_ready_or_raise()
human_message = HumanMessage(content=data["content"], ui_context=ui_context)

async def stream_assistant_messages():
async for msg in assistant.astream_messages(human_message):
yield self._stream_assistant_message(msg)
try:
async for msg in assistant.astream_messages(human_message):
yield self._stream_assistant_message(msg)
except Exception:
logger.exception("Error while streaming assistant messages")
yield self._stream_assistant_message(
AiErrorMessage(
content=(
"Oops, something went wrong and I cannot continue the conversation. "
"Please try again."
)
)
)

response = StreamingHttpResponse(
stream_assistant_messages(),
Expand Down Expand Up @@ -230,3 +255,51 @@ def get(self, request: Request, chat_uuid: str) -> Response:
serializer = AssistantChatMessagesSerializer({"messages": messages})

return Response(serializer.data)


class AssistantChatMessageFeedbackView(APIView):
    """Lets a workspace member rate an AI assistant chat message.

    Feedback can only be stored for messages that have an associated
    ``AssistantChatPrediction``; other messages result in a 400 error.
    """

    @extend_schema(
        tags=["AI Assistant"],
        operation_id="submit_assistant_message_feedback",
        description=(
            "Provide sentiment and feedback for the given AI assistant chat message.\n\n"
            "This is an **advanced/enterprise** feature."
        ),
        responses={
            # The view returns HTTP_204_NO_CONTENT on success, so the schema
            # must advertise 204 (the previous 200 entry was incorrect).
            204: None,
            400: get_error_schema(
                ["ERROR_USER_NOT_IN_GROUP", "ERROR_CANNOT_SUBMIT_MESSAGE_FEEDBACK"]
            ),
        },
    )
    @validate_body(AssistantRateChatMessageSerializer, return_validated=True)
    @map_exceptions(
        {
            UserNotInWorkspace: ERROR_USER_NOT_IN_GROUP,
            WorkspaceDoesNotExist: ERROR_GROUP_DOES_NOT_EXIST,
            AssistantChatDoesNotExist: ERROR_ASSISTANT_CHAT_DOES_NOT_EXIST,
            AssistantChatMessagePredictionDoesNotExist: ERROR_CANNOT_SUBMIT_MESSAGE_FEEDBACK,
        }
    )
    def put(self, request: Request, message_id: int, data) -> Response:
        """Store the user's sentiment (and optional feedback) on a message.

        :param request: The HTTP request of the authenticated user.
        :param message_id: The ID of the chat message being rated.
        :param data: Validated body produced by
            AssistantRateChatMessageSerializer (``sentiment`` already mapped
            to the model value, ``feedback`` optional).
        :raises AssistantChatMessagePredictionDoesNotExist: If the message has
            no associated prediction to attach the feedback to.
        :return: An empty 204 response on success.
        """

        feature_flag_is_enabled(FF_ASSISTANT, raise_if_disabled=True)

        handler = AssistantHandler()
        message = handler.get_chat_message_by_id(request.user, message_id)
        LicenseHandler.raise_if_user_doesnt_have_feature(
            ASSISTANT, request.user, message.chat.workspace
        )

        try:
            # Accessing a missing reverse one-to-one raises Django's
            # RelatedObjectDoesNotExist, which subclasses AttributeError.
            prediction: AssistantChatPrediction = message.prediction
        except AttributeError:
            raise AssistantChatMessagePredictionDoesNotExist(
                f"Message with ID {message_id} does not have an associated prediction."
            )

        prediction.human_sentiment = data["sentiment"]
        # Feedback is optional; normalize missing/None to an empty string.
        prediction.human_feedback = data.get("feedback") or ""
        prediction.save(
            update_fields=["human_sentiment", "human_feedback", "updated_on"]
        )
        return Response(status=HTTP_204_NO_CONTENT)
27 changes: 17 additions & 10 deletions enterprise/backend/src/baserow_enterprise/assistant/adapter.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import dspy

from .prompts import ASSISTANT_SYSTEM_PROMPT


class ChatAdapter(dspy.ChatAdapter):
def format_field_description(self, signature: type[dspy.Signature]) -> str:
"""
This is the first part of the prompt the LLM sees, so we prepend our custom
system prompt to it to give it the personality and context of Baserow.
"""
def get_chat_adapter():
import dspy # local import to save memory when not used

class ChatAdapter(dspy.ChatAdapter):
def format_field_description(self, signature: type[dspy.Signature]) -> str:
"""
This is the first part of the prompt the LLM sees, so we prepend our custom
system prompt to it to give it the personality and context of Baserow.
"""

field_description = super().format_field_description(signature)
return (
ASSISTANT_SYSTEM_PROMPT
+ "## TASK INSTRUCTIONS:\n\n"
+ field_description
)

field_description = super().format_field_description(signature)
return ASSISTANT_SYSTEM_PROMPT + "## TASK INSTRUCTIONS:\n\n" + field_description
return ChatAdapter()
Loading
Loading