33
44from django .conf import settings
55
6- import dspy
7- from dspy .primitives .prediction import Prediction
8- from dspy .streaming import StreamListener , StreamResponse
9- from dspy .utils .callback import BaseCallback
10- from litellm import get_supported_openai_params
11-
126from baserow .api .sessions import get_client_undo_redo_action_group_id
137from baserow_enterprise .assistant .exceptions import AssistantModelNotSupportedError
148from baserow_enterprise .assistant .tools .registries import assistant_tool_registry
159
16- from .adapter import ChatAdapter
10+ from .adapter import get_chat_adapter
1711from .models import AssistantChat , AssistantChatMessage , AssistantChatPrediction
18- from .react import ReAct
1912from .types import (
2013 AiMessage ,
2114 AiMessageChunk ,
2821)
2922
3023
31- class ChatSignature (dspy .Signature ):
32- question : str = dspy .InputField ()
33- history : dspy .History = dspy .InputField ()
34- ui_context : UIContext | None = dspy .InputField (
35- default = None ,
36- desc = (
37- "The frontend UI content the user is currently in. "
38- "Whenever make sense, use it to ground your answer."
39- ),
40- )
41- answer : str = dspy .OutputField ()
42-
43-
4424class AssistantMessagePair (TypedDict ):
4525 question : str
4626 answer : str
4727
4828
49- class AssistantCallbacks (BaseCallback ):
50- def __init__ (self ):
51- self .tool_calls = {}
52- self .sources = []
29+ def get_assistant_callbacks ():
30+ from dspy .utils .callback import BaseCallback
5331
54- def extend_sources (self , sources : list [str ]) -> None :
55- """
56- Extends the current list of sources with new ones, avoiding duplicates.
32+ class AssistantCallbacks (BaseCallback ):
33+ def __init__ (self ):
34+ self .tool_calls = {}
35+ self .sources = []
5736
58- :param sources: The list of new source URLs to add.
59- :return: None
60- """
37+ def extend_sources ( self , sources : list [ str ]) -> None :
38+ """
39+ Extends the current list of sources with new ones, avoiding duplicates.
6140
62- self .sources .extend ([s for s in sources if s not in self .sources ])
41+ :param sources: The list of new source URLs to add.
42+ :return: None
43+ """
6344
64- def on_tool_start (
65- self ,
66- call_id : str ,
67- instance : Any ,
68- inputs : dict [str , Any ],
69- ) -> None :
70- """
71- Called when a tool starts. It records the tool call and invokes the
72- corresponding tool's on_tool_start method if it exists.
45+ self .sources .extend ([s for s in sources if s not in self .sources ])
7346
74- :param call_id: The unique identifier of the tool call.
75- :param instance: The instance of the tool being called.
76- :param inputs: The inputs provided to the tool.
77- """
47+ def on_tool_start (
48+ self ,
49+ call_id : str ,
50+ instance : Any ,
51+ inputs : dict [str , Any ],
52+ ) -> None :
53+ """
54+ Called when a tool starts. It records the tool call and invokes the
55+ corresponding tool's on_tool_start method if it exists.
7856
79- try :
80- assistant_tool_registry .get (instance .name ).on_tool_start (
81- call_id , instance , inputs
57+ :param call_id: The unique identifier of the tool call.
58+ :param instance: The instance of the tool being called.
59+ :param inputs: The inputs provided to the tool.
60+ """
61+
62+ try :
63+ assistant_tool_registry .get (instance .name ).on_tool_start (
64+ call_id , instance , inputs
65+ )
66+ self .tool_calls [call_id ] = (instance , inputs )
67+ except assistant_tool_registry .does_not_exist_exception_class :
68+ pass
69+
70+ def on_tool_end (
71+ self ,
72+ call_id : str ,
73+ outputs : dict [str , Any ] | None ,
74+ exception : Exception | None = None ,
75+ ) -> None :
76+ """
77+ Called when a tool ends. It invokes the corresponding tool's on_tool_end
78+ method if it exists and updates the sources if the tool produced any.
79+
80+ :param call_id: The unique identifier of the tool call.
81+ :param outputs: The outputs returned by the tool, or None if there was an
82+ exception.
83+ :param exception: The exception raised by the tool, or None if it was
84+ successful.
85+ """
86+
87+ if call_id not in self .tool_calls :
88+ return
89+
90+ instance , inputs = self .tool_calls .pop (call_id )
91+ assistant_tool_registry .get (instance .name ).on_tool_end (
92+ call_id , instance , inputs , outputs , exception
8293 )
83- self .tool_calls [call_id ] = (instance , inputs )
84- except assistant_tool_registry .does_not_exist_exception_class :
85- pass
8694
87- def on_tool_end (
88- self ,
89- call_id : str ,
90- outputs : dict [str , Any ] | None ,
91- exception : Exception | None = None ,
92- ) -> None :
93- """
94- Called when a tool ends. It invokes the corresponding tool's on_tool_end
95- method if it exists and updates the sources if the tool produced any.
96-
97- :param call_id: The unique identifier of the tool call.
98- :param outputs: The outputs returned by the tool, or None if there was an
99- exception.
100- :param exception: The exception raised by the tool, or None if it was
101- successful.
102- """
95+ # If the tool produced sources, add them to the overall list of sources.
96+ if isinstance (outputs , dict ) and "sources" in outputs :
97+ self .extend_sources (outputs ["sources" ])
10398
104- if call_id not in self .tool_calls :
105- return
99+ return AssistantCallbacks ()
106100
107- instance , inputs = self .tool_calls .pop (call_id )
108- assistant_tool_registry .get (instance .name ).on_tool_end (
109- call_id , instance , inputs , outputs , exception
101+
102+ def get_chat_signature ():
103+ import dspy # local import to save memory when not used
104+
105+ class ChatSignature (dspy .Signature ):
106+ question : str = dspy .InputField ()
107+ history : dspy .History = dspy .InputField ()
108+ ui_context : UIContext | None = dspy .InputField (
109+ default = None ,
110+ desc = (
111+ "The frontend UI content the user is currently in. "
112+ "Whenever make sense, use it to ground your answer."
113+ ),
110114 )
115+ answer : str = dspy .OutputField ()
111116
112- # If the tool produced sources, add them to the overall list of sources.
113- if isinstance (outputs , dict ) and "sources" in outputs :
114- self .extend_sources (outputs ["sources" ])
117+ return ChatSignature
115118
116119
117120class Assistant :
@@ -120,17 +123,27 @@ def __init__(self, chat: AssistantChat):
120123 self ._user = chat .user
121124 self ._workspace = chat .workspace
122125
126+ self ._init_lm_client ()
127+ self ._init_assistant ()
128+
129+ def _init_lm_client (self ):
130+ import dspy # local import to save memory when not used
131+
123132 lm_model = settings .BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL
133+
124134 self ._lm_client = dspy .LM (
125135 model = lm_model ,
126136 cache = not settings .DEBUG ,
127137 max_retries = 5 ,
128138 )
129139
140+ def _init_assistant (self ):
141+ from .react import ReAct # local import to save memory when not used
142+
130143 tools = assistant_tool_registry .list_all_usable_tools (
131144 self ._user , self ._workspace
132145 )
133- self ._assistant = ReAct (ChatSignature , tools = tools )
146+ self ._assistant = ReAct (get_chat_signature () , tools = tools )
134147 self .history = None
135148
136149 async def acreate_chat_message (
@@ -216,6 +229,8 @@ async def aload_chat_history(self, limit=20):
216229 :return: None
217230 """
218231
232+ import dspy # local import to save memory when not used
233+
219234 last_saved_messages : list [AssistantChatMessage ] = [
220235 msg async for msg in self ._chat .messages .order_by ("-created_on" )[:limit ]
221236 ]
@@ -243,6 +258,9 @@ async def aload_chat_history(self, limit=20):
243258
244259 @lru_cache (maxsize = 1 )
245260 def check_llm_ready_or_raise (self ):
261+ import dspy # local import to save memory when not used
262+ from litellm import get_supported_openai_params
263+
246264 lm = self ._lm_client
247265 params = get_supported_openai_params (lm .model )
248266 if params is None or "tools" not in params :
@@ -270,13 +288,17 @@ async def astream_messages(
270288 :return: An async generator that yields the response messages.
271289 """
272290
273- callback_manager = AssistantCallbacks ()
291+ import dspy # local import to save memory when not used
292+ from dspy .primitives .prediction import Prediction
293+ from dspy .streaming import StreamListener , StreamResponse
294+
295+ callback_manager = get_assistant_callbacks ()
274296
275297 with dspy .context (
276298 lm = self ._lm_client ,
277299 cache = not settings .DEBUG ,
278300 callbacks = [* dspy .settings .config .callbacks , callback_manager ],
279- adapter = ChatAdapter (),
301+ adapter = get_chat_adapter (),
280302 ):
281303 if self .history is None :
282304 await self .aload_chat_history ()
0 commit comments