Skip to content

Commit 137aa1a

Browse files
authored
Lazy load dspy to save memory on startup (baserow#4117)
1 parent 2765525 commit 137aa1a

File tree

12 files changed

+317
-140
lines changed

12 files changed

+317
-140
lines changed
Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
1-
import dspy
2-
31
from .prompts import ASSISTANT_SYSTEM_PROMPT
42

53

6-
class ChatAdapter(dspy.ChatAdapter):
7-
def format_field_description(self, signature: type[dspy.Signature]) -> str:
8-
"""
9-
This is the first part of the prompt the LLM sees, so we prepend our custom
10-
system prompt to it to give it the personality and context of Baserow.
11-
"""
4+
def get_chat_adapter():
5+
import dspy # local import to save memory when not used
6+
7+
class ChatAdapter(dspy.ChatAdapter):
8+
def format_field_description(self, signature: type[dspy.Signature]) -> str:
9+
"""
10+
This is the first part of the prompt the LLM sees, so we prepend our custom
11+
system prompt to it to give it the personality and context of Baserow.
12+
"""
13+
14+
field_description = super().format_field_description(signature)
15+
return (
16+
ASSISTANT_SYSTEM_PROMPT
17+
+ "## TASK INSTRUCTIONS:\n\n"
18+
+ field_description
19+
)
1220

13-
field_description = super().format_field_description(signature)
14-
return ASSISTANT_SYSTEM_PROMPT + "## TASK INSTRUCTIONS:\n\n" + field_description
21+
return ChatAdapter()

enterprise/backend/src/baserow_enterprise/assistant/assistant.py

Lines changed: 100 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,12 @@
33

44
from django.conf import settings
55

6-
import dspy
7-
from dspy.primitives.prediction import Prediction
8-
from dspy.streaming import StreamListener, StreamResponse
9-
from dspy.utils.callback import BaseCallback
10-
from litellm import get_supported_openai_params
11-
126
from baserow.api.sessions import get_client_undo_redo_action_group_id
137
from baserow_enterprise.assistant.exceptions import AssistantModelNotSupportedError
148
from baserow_enterprise.assistant.tools.registries import assistant_tool_registry
159

16-
from .adapter import ChatAdapter
10+
from .adapter import get_chat_adapter
1711
from .models import AssistantChat, AssistantChatMessage, AssistantChatPrediction
18-
from .react import ReAct
1912
from .types import (
2013
AiMessage,
2114
AiMessageChunk,
@@ -28,90 +21,100 @@
2821
)
2922

3023

31-
class ChatSignature(dspy.Signature):
32-
question: str = dspy.InputField()
33-
history: dspy.History = dspy.InputField()
34-
ui_context: UIContext | None = dspy.InputField(
35-
default=None,
36-
desc=(
37-
"The frontend UI content the user is currently in. "
38-
"Whenever make sense, use it to ground your answer."
39-
),
40-
)
41-
answer: str = dspy.OutputField()
42-
43-
4424
class AssistantMessagePair(TypedDict):
4525
question: str
4626
answer: str
4727

4828

49-
class AssistantCallbacks(BaseCallback):
50-
def __init__(self):
51-
self.tool_calls = {}
52-
self.sources = []
29+
def get_assistant_callbacks():
30+
from dspy.utils.callback import BaseCallback
5331

54-
def extend_sources(self, sources: list[str]) -> None:
55-
"""
56-
Extends the current list of sources with new ones, avoiding duplicates.
32+
class AssistantCallbacks(BaseCallback):
33+
def __init__(self):
34+
self.tool_calls = {}
35+
self.sources = []
5736

58-
:param sources: The list of new source URLs to add.
59-
:return: None
60-
"""
37+
def extend_sources(self, sources: list[str]) -> None:
38+
"""
39+
Extends the current list of sources with new ones, avoiding duplicates.
6140
62-
self.sources.extend([s for s in sources if s not in self.sources])
41+
:param sources: The list of new source URLs to add.
42+
:return: None
43+
"""
6344

64-
def on_tool_start(
65-
self,
66-
call_id: str,
67-
instance: Any,
68-
inputs: dict[str, Any],
69-
) -> None:
70-
"""
71-
Called when a tool starts. It records the tool call and invokes the
72-
corresponding tool's on_tool_start method if it exists.
45+
self.sources.extend([s for s in sources if s not in self.sources])
7346

74-
:param call_id: The unique identifier of the tool call.
75-
:param instance: The instance of the tool being called.
76-
:param inputs: The inputs provided to the tool.
77-
"""
47+
def on_tool_start(
48+
self,
49+
call_id: str,
50+
instance: Any,
51+
inputs: dict[str, Any],
52+
) -> None:
53+
"""
54+
Called when a tool starts. It records the tool call and invokes the
55+
corresponding tool's on_tool_start method if it exists.
7856
79-
try:
80-
assistant_tool_registry.get(instance.name).on_tool_start(
81-
call_id, instance, inputs
57+
:param call_id: The unique identifier of the tool call.
58+
:param instance: The instance of the tool being called.
59+
:param inputs: The inputs provided to the tool.
60+
"""
61+
62+
try:
63+
assistant_tool_registry.get(instance.name).on_tool_start(
64+
call_id, instance, inputs
65+
)
66+
self.tool_calls[call_id] = (instance, inputs)
67+
except assistant_tool_registry.does_not_exist_exception_class:
68+
pass
69+
70+
def on_tool_end(
71+
self,
72+
call_id: str,
73+
outputs: dict[str, Any] | None,
74+
exception: Exception | None = None,
75+
) -> None:
76+
"""
77+
Called when a tool ends. It invokes the corresponding tool's on_tool_end
78+
method if it exists and updates the sources if the tool produced any.
79+
80+
:param call_id: The unique identifier of the tool call.
81+
:param outputs: The outputs returned by the tool, or None if there was an
82+
exception.
83+
:param exception: The exception raised by the tool, or None if it was
84+
successful.
85+
"""
86+
87+
if call_id not in self.tool_calls:
88+
return
89+
90+
instance, inputs = self.tool_calls.pop(call_id)
91+
assistant_tool_registry.get(instance.name).on_tool_end(
92+
call_id, instance, inputs, outputs, exception
8293
)
83-
self.tool_calls[call_id] = (instance, inputs)
84-
except assistant_tool_registry.does_not_exist_exception_class:
85-
pass
8694

87-
def on_tool_end(
88-
self,
89-
call_id: str,
90-
outputs: dict[str, Any] | None,
91-
exception: Exception | None = None,
92-
) -> None:
93-
"""
94-
Called when a tool ends. It invokes the corresponding tool's on_tool_end
95-
method if it exists and updates the sources if the tool produced any.
96-
97-
:param call_id: The unique identifier of the tool call.
98-
:param outputs: The outputs returned by the tool, or None if there was an
99-
exception.
100-
:param exception: The exception raised by the tool, or None if it was
101-
successful.
102-
"""
95+
# If the tool produced sources, add them to the overall list of sources.
96+
if isinstance(outputs, dict) and "sources" in outputs:
97+
self.extend_sources(outputs["sources"])
10398

104-
if call_id not in self.tool_calls:
105-
return
99+
return AssistantCallbacks()
106100

107-
instance, inputs = self.tool_calls.pop(call_id)
108-
assistant_tool_registry.get(instance.name).on_tool_end(
109-
call_id, instance, inputs, outputs, exception
101+
102+
def get_chat_signature():
103+
import dspy # local import to save memory when not used
104+
105+
class ChatSignature(dspy.Signature):
106+
question: str = dspy.InputField()
107+
history: dspy.History = dspy.InputField()
108+
ui_context: UIContext | None = dspy.InputField(
109+
default=None,
110+
desc=(
111+
"The frontend UI content the user is currently in. "
112+
"Whenever make sense, use it to ground your answer."
113+
),
110114
)
115+
answer: str = dspy.OutputField()
111116

112-
# If the tool produced sources, add them to the overall list of sources.
113-
if isinstance(outputs, dict) and "sources" in outputs:
114-
self.extend_sources(outputs["sources"])
117+
return ChatSignature
115118

116119

117120
class Assistant:
@@ -120,17 +123,27 @@ def __init__(self, chat: AssistantChat):
120123
self._user = chat.user
121124
self._workspace = chat.workspace
122125

126+
self._init_lm_client()
127+
self._init_assistant()
128+
129+
def _init_lm_client(self):
130+
import dspy # local import to save memory when not used
131+
123132
lm_model = settings.BASEROW_ENTERPRISE_ASSISTANT_LLM_MODEL
133+
124134
self._lm_client = dspy.LM(
125135
model=lm_model,
126136
cache=not settings.DEBUG,
127137
max_retries=5,
128138
)
129139

140+
def _init_assistant(self):
141+
from .react import ReAct # local import to save memory when not used
142+
130143
tools = assistant_tool_registry.list_all_usable_tools(
131144
self._user, self._workspace
132145
)
133-
self._assistant = ReAct(ChatSignature, tools=tools)
146+
self._assistant = ReAct(get_chat_signature(), tools=tools)
134147
self.history = None
135148

136149
async def acreate_chat_message(
@@ -216,6 +229,8 @@ async def aload_chat_history(self, limit=20):
216229
:return: None
217230
"""
218231

232+
import dspy # local import to save memory when not used
233+
219234
last_saved_messages: list[AssistantChatMessage] = [
220235
msg async for msg in self._chat.messages.order_by("-created_on")[:limit]
221236
]
@@ -243,6 +258,9 @@ async def aload_chat_history(self, limit=20):
243258

244259
@lru_cache(maxsize=1)
245260
def check_llm_ready_or_raise(self):
261+
import dspy # local import to save memory when not used
262+
from litellm import get_supported_openai_params
263+
246264
lm = self._lm_client
247265
params = get_supported_openai_params(lm.model)
248266
if params is None or "tools" not in params:
@@ -270,13 +288,17 @@ async def astream_messages(
270288
:return: An async generator that yields the response messages.
271289
"""
272290

273-
callback_manager = AssistantCallbacks()
291+
import dspy # local import to save memory when not used
292+
from dspy.primitives.prediction import Prediction
293+
from dspy.streaming import StreamListener, StreamResponse
294+
295+
callback_manager = get_assistant_callbacks()
274296

275297
with dspy.context(
276298
lm=self._lm_client,
277299
cache=not settings.DEBUG,
278300
callbacks=[*dspy.settings.config.callbacks, callback_manager],
279-
adapter=ChatAdapter(),
301+
adapter=get_chat_adapter(),
280302
):
281303
if self.history is None:
282304
await self.aload_chat_history()

enterprise/backend/src/baserow_enterprise/assistant/react.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Any, Callable, Literal
1+
from typing import Any, Callable, Literal
22

33
import dspy
44
from dspy.adapters.types.tool import Tool
@@ -10,20 +10,14 @@
1010

1111
from .types import ToolsUpgradeResponse
1212

13-
if TYPE_CHECKING:
14-
from dspy.signatures.signature import Signature
15-
16-
1713
# Variant of dspy.predict.react.ReAct that accepts a "meta-tool":
1814
# a callable that can produce tools at runtime (e.g. per-table schemas).
1915
# This lets a single ReAct instance handle many different table signatures
2016
# without creating a new agent for each request.
2117

2218

2319
class ReAct(Module):
24-
def __init__(
25-
self, signature: type["Signature"], tools: list[Callable], max_iters: int = 100
26-
):
20+
def __init__(self, signature, tools: list[Callable], max_iters: int = 100):
2721
"""
2822
ReAct stands for "Reasoning and Acting," a popular paradigm for building
2923
tool-using agents. In this approach, the language model is iteratively provided

enterprise/backend/src/baserow_enterprise/assistant/tools/database/tools.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from django.db import transaction
66
from django.utils.translation import gettext as _
77

8-
import dspy
98
from loguru import logger
109
from pydantic import create_model
1110

@@ -25,9 +24,9 @@
2524
from baserow_enterprise.assistant.tools.registries import AssistantToolType, ToolHelpers
2625
from baserow_enterprise.assistant.types import (
2726
TableNavigationType,
28-
ToolSignature,
2927
ToolsUpgradeResponse,
3028
ViewNavigationType,
29+
get_tool_signature,
3130
)
3231

3332
from . import utils
@@ -284,6 +283,8 @@ def create_tables(
284283
- if add_sample_rows is True (default), add some example rows to each table
285284
"""
286285

286+
import dspy # local import to save memory when not used
287+
287288
nonlocal user, workspace, tool_helpers
288289

289290
if not tables:
@@ -354,7 +355,7 @@ def create_tables(
354355
f"- Create 5 example rows for table_{created_table.id}. Fill every relationship with valid data when possible."
355356
)
356357

357-
predictor = dspy.Predict(ToolSignature)
358+
predictor = dspy.Predict(get_tool_signature())
358359
result = predictor(
359360
question=("\n".join(instructions)),
360361
tools=list(tools.values()),

enterprise/backend/src/baserow_enterprise/assistant/tools/database/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
from django.db.models import Q
88
from django.utils.translation import gettext as _
99

10-
import dspy
11-
from dspy.adapters.types.tool import _resolve_json_schema_reference
1210
from pydantic import ConfigDict, Field, create_model
1311

1412
from baserow.contrib.database.fields.actions import CreateFieldActionType
@@ -385,6 +383,9 @@ def get_view(user, view_id: int):
385383
def get_table_rows_tools(
386384
user, workspace: Workspace, tool_helpers: ToolHelpers, table: Table
387385
):
386+
import dspy # local import to save memory when not used
387+
from dspy.adapters.types.tool import _resolve_json_schema_reference
388+
388389
row_model_for_create = get_create_row_model(table)
389390
row_model_for_update = get_update_row_model(table)
390391
row_model_for_response = create_model(

enterprise/backend/src/baserow_enterprise/assistant/tools/navigation/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
from dspy.dsp.utils.settings import settings
2-
from dspy.streaming.messages import sync_send_to_stream
3-
41
from baserow_enterprise.assistant.types import AiNavigationMessage, AnyNavigationType
52

63

@@ -13,6 +10,9 @@ def unsafe_navigate_to(location: AnyNavigationType) -> str:
1310
:param navigation_type: The type of navigation to perform.
1411
"""
1512

13+
from dspy.dsp.utils.settings import settings
14+
from dspy.streaming.messages import sync_send_to_stream
15+
1616
stream = settings.send_stream
1717
if stream is not None:
1818
sync_send_to_stream(stream, AiNavigationMessage(location=location))

0 commit comments

Comments
 (0)