Skip to content

Commit c15b04f

Browse files
committed
perf: optimize hot paths with caching and O(1) operations
- Replace list.pop(0) with deque.popleft() for O(1) queue dequeue - Cache compiled regex patterns in ResourceTemplate for URI matching - Cache field info mapping in FuncMetadata via lazy property - Throttle expired task cleanup with interval-based execution These optimizations target high-frequency operations in message queuing, resource lookups, tool calls, and task store access.
1 parent a9cc822 commit c15b04f

File tree

4 files changed

+49
-95
lines changed

4 files changed

+49
-95
lines changed

src/mcp/server/fastmcp/resources/templates.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class ResourceTemplate(BaseModel):
3333
fn: Callable[..., Any] = Field(exclude=True)
3434
parameters: dict[str, Any] = Field(description="JSON schema for function parameters")
3535
context_kwarg: str | None = Field(None, description="Name of the kwarg that should receive context")
36+
_uri_pattern: re.Pattern[str] | None = None
3637

3738
@classmethod
3839
def from_function(
@@ -66,7 +67,10 @@ def from_function(
6667
# ensure the arguments are properly cast
6768
fn = validate_call(fn)
6869

69-
return cls(
70+
pattern_str = uri_template.replace("{", "(?P<").replace("}", ">[^/]+)")
71+
compiled_pattern = re.compile(f"^{pattern_str}$")
72+
73+
instance = cls(
7074
uri_template=uri_template,
7175
name=func_name,
7276
title=title,
@@ -78,15 +82,15 @@ def from_function(
7882
parameters=parameters,
7983
context_kwarg=context_kwarg,
8084
)
85+
instance._uri_pattern = compiled_pattern
86+
return instance
8187

8288
def matches(self, uri: str) -> dict[str, Any] | None:
83-
"""Check if URI matches template and extract parameters."""
84-
# Convert template to regex pattern
85-
pattern = self.uri_template.replace("{", "(?P<").replace("}", ">[^/]+)")
86-
match = re.match(f"^{pattern}$", uri)
87-
if match:
88-
return match.groupdict()
89-
return None
89+
if self._uri_pattern is None:
90+
pattern_str = self.uri_template.replace("{", "(?P<").replace("}", ">[^/]+)")
91+
self._uri_pattern = re.compile(f"^{pattern_str}$")
92+
match = self._uri_pattern.match(uri)
93+
return match.groupdict() if match else None
9094

9195
async def create_resource(
9296
self,

src/mcp/server/fastmcp/utilities/func_metadata.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,18 @@ class FuncMetadata(BaseModel):
7070
output_schema: dict[str, Any] | None = None
7171
output_model: Annotated[type[BaseModel], WithJsonSchema(None)] | None = None
7272
wrap_output: bool = False
73+
_key_to_field_info: dict[str, FieldInfo] | None = None
74+
75+
@property
76+
def key_to_field_info(self) -> dict[str, FieldInfo]:
77+
if self._key_to_field_info is None:
78+
mapping: dict[str, FieldInfo] = {}
79+
for field_name, field_info in self.arg_model.model_fields.items():
80+
mapping[field_name] = field_info
81+
if field_info.alias:
82+
mapping[field_info.alias] = field_info
83+
self._key_to_field_info = mapping
84+
return self._key_to_field_info
7385

7486
async def call_fn_with_arg_validation(
7587
self,
@@ -141,30 +153,19 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
141153
it seems incapable of NOT doing this. For sub-models, it tends to pass
142154
dicts (JSON objects) as JSON strings, which can be pre-parsed here.
143155
"""
144-
new_data = data.copy() # Shallow copy
145-
146-
# Build a mapping from input keys (including aliases) to field info
147-
key_to_field_info: dict[str, FieldInfo] = {}
148-
for field_name, field_info in self.arg_model.model_fields.items():
149-
# Map both the field name and its alias (if any) to the field info
150-
key_to_field_info[field_name] = field_info
151-
if field_info.alias:
152-
key_to_field_info[field_info.alias] = field_info
156+
new_data = data.copy()
153157

154158
for data_key, data_value in data.items():
155-
if data_key not in key_to_field_info: # pragma: no cover
159+
if data_key not in self.key_to_field_info: # pragma: no cover
156160
continue
157161

158-
field_info = key_to_field_info[data_key]
162+
field_info = self.key_to_field_info[data_key]
159163
if isinstance(data_value, str) and field_info.annotation is not str:
160164
try:
161165
pre_parsed = json.loads(data_value)
162166
except json.JSONDecodeError:
163-
continue # Not JSON - skip
167+
continue
164168
if isinstance(pre_parsed, str | int | float):
165-
# This is likely that the raw value is e.g. `"hello"` which we
166-
# Should really be parsed as '"hello"' in Python - but if we parse
167-
# it as JSON it'll turn into just 'hello'. So we skip it.
168169
continue
169170
new_data[data_key] = pre_parsed
170171
assert new_data.keys() == data.keys()

src/mcp/shared/experimental/tasks/in_memory_task_store.py

Lines changed: 14 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@
1717
from mcp.shared.experimental.tasks.store import TaskStore
1818
from mcp.types import Result, Task, TaskMetadata, TaskStatus
1919

20+
CLEANUP_INTERVAL_SECONDS = 1.0
21+
2022

2123
@dataclass
2224
class StoredTask:
23-
"""Internal storage representation of a task."""
24-
2525
task: Task
2626
result: Result | None = None
27-
# Time when this task should be removed (None = never)
2827
expires_at: datetime | None = field(default=None)
2928

3029

@@ -49,21 +48,26 @@ def __init__(self, page_size: int = 10) -> None:
4948
self._tasks: dict[str, StoredTask] = {}
5049
self._page_size = page_size
5150
self._update_events: dict[str, anyio.Event] = {}
51+
self._last_cleanup: datetime | None = None
5252

5353
def _calculate_expiry(self, ttl_ms: int | None) -> datetime | None:
54-
"""Calculate expiry time from TTL in milliseconds."""
5554
if ttl_ms is None:
5655
return None
5756
return datetime.now(timezone.utc) + timedelta(milliseconds=ttl_ms)
5857

5958
def _is_expired(self, stored: StoredTask) -> bool:
60-
"""Check if a task has expired."""
6159
if stored.expires_at is None:
6260
return False
6361
return datetime.now(timezone.utc) >= stored.expires_at
6462

6563
def _cleanup_expired(self) -> None:
66-
"""Remove all expired tasks. Called lazily during access operations."""
64+
now = datetime.now(timezone.utc)
65+
if self._last_cleanup is not None:
66+
elapsed = (now - self._last_cleanup).total_seconds()
67+
if elapsed < CLEANUP_INTERVAL_SECONDS:
68+
return
69+
70+
self._last_cleanup = now
6771
expired_ids = [task_id for task_id, stored in self._tasks.items() if self._is_expired(stored)]
6872
for task_id in expired_ids:
6973
del self._tasks[task_id]
@@ -73,34 +77,21 @@ async def create_task(
7377
metadata: TaskMetadata,
7478
task_id: str | None = None,
7579
) -> Task:
76-
"""Create a new task with the given metadata."""
77-
# Cleanup expired tasks on access
7880
self._cleanup_expired()
79-
8081
task = create_task_state(metadata, task_id)
8182

8283
if task.taskId in self._tasks:
8384
raise ValueError(f"Task with ID {task.taskId} already exists")
8485

85-
stored = StoredTask(
86-
task=task,
87-
expires_at=self._calculate_expiry(metadata.ttl),
88-
)
86+
stored = StoredTask(task=task, expires_at=self._calculate_expiry(metadata.ttl))
8987
self._tasks[task.taskId] = stored
90-
91-
# Return a copy to prevent external modification
9288
return Task(**task.model_dump())
9389

9490
async def get_task(self, task_id: str) -> Task | None:
95-
"""Get a task by ID."""
96-
# Cleanup expired tasks on access
9791
self._cleanup_expired()
98-
9992
stored = self._tasks.get(task_id)
10093
if stored is None:
10194
return None
102-
103-
# Return a copy to prevent external modification
10495
return Task(**stored.task.model_dump())
10596

10697
async def update_task(
@@ -109,12 +100,10 @@ async def update_task(
109100
status: TaskStatus | None = None,
110101
status_message: str | None = None,
111102
) -> Task:
112-
"""Update a task's status and/or message."""
113103
stored = self._tasks.get(task_id)
114104
if stored is None:
115105
raise ValueError(f"Task with ID {task_id} not found")
116106

117-
# Per spec: Terminal states MUST NOT transition to any other status
118107
if status is not None and status != stored.task.status and is_terminal(stored.task.status):
119108
raise ValueError(f"Cannot transition from terminal status '{stored.task.status}'")
120109

@@ -126,94 +115,69 @@ async def update_task(
126115
if status_message is not None:
127116
stored.task.statusMessage = status_message
128117

129-
# Update lastUpdatedAt on any change
130118
stored.task.lastUpdatedAt = datetime.now(timezone.utc)
131119

132-
# If task is now terminal and has TTL, reset expiry timer
133120
if status is not None and is_terminal(status) and stored.task.ttl is not None:
134121
stored.expires_at = self._calculate_expiry(stored.task.ttl)
135122

136-
# Notify waiters if status changed
137123
if status_changed:
138124
await self.notify_update(task_id)
139125

140126
return Task(**stored.task.model_dump())
141127

142128
async def store_result(self, task_id: str, result: Result) -> None:
143-
"""Store the result for a task."""
144129
stored = self._tasks.get(task_id)
145130
if stored is None:
146131
raise ValueError(f"Task with ID {task_id} not found")
147-
148132
stored.result = result
149133

150134
async def get_result(self, task_id: str) -> Result | None:
151-
"""Get the stored result for a task."""
152135
stored = self._tasks.get(task_id)
153-
if stored is None:
154-
return None
155-
156-
return stored.result
136+
return stored.result if stored else None
157137

158138
async def list_tasks(
159139
self,
160140
cursor: str | None = None,
161141
) -> tuple[list[Task], str | None]:
162-
"""List tasks with pagination."""
163-
# Cleanup expired tasks on access
164142
self._cleanup_expired()
165-
166143
all_task_ids = list(self._tasks.keys())
167144

168145
start_index = 0
169146
if cursor is not None:
170147
try:
171-
cursor_index = all_task_ids.index(cursor)
172-
start_index = cursor_index + 1
148+
start_index = all_task_ids.index(cursor) + 1
173149
except ValueError:
174150
raise ValueError(f"Invalid cursor: {cursor}")
175151

176152
page_task_ids = all_task_ids[start_index : start_index + self._page_size]
177153
tasks = [Task(**self._tasks[tid].task.model_dump()) for tid in page_task_ids]
178154

179-
# Determine next cursor
180155
next_cursor = None
181156
if start_index + self._page_size < len(all_task_ids) and page_task_ids:
182157
next_cursor = page_task_ids[-1]
183158

184159
return tasks, next_cursor
185160

186161
async def delete_task(self, task_id: str) -> bool:
187-
"""Delete a task."""
188162
if task_id not in self._tasks:
189163
return False
190-
191164
del self._tasks[task_id]
192165
return True
193166

194167
async def wait_for_update(self, task_id: str) -> None:
195-
"""Wait until the task status changes."""
196168
if task_id not in self._tasks:
197169
raise ValueError(f"Task with ID {task_id} not found")
198-
199-
# Create a fresh event for waiting (anyio.Event can't be cleared)
200170
self._update_events[task_id] = anyio.Event()
201-
event = self._update_events[task_id]
202-
await event.wait()
171+
await self._update_events[task_id].wait()
203172

204173
async def notify_update(self, task_id: str) -> None:
205-
"""Signal that a task has been updated."""
206174
if task_id in self._update_events:
207175
self._update_events[task_id].set()
208176

209-
# --- Testing/debugging helpers ---
210-
211177
def cleanup(self) -> None:
212-
"""Cleanup all tasks (useful for testing or graceful shutdown)."""
213178
self._tasks.clear()
214179
self._update_events.clear()
215180

216181
def get_all_tasks(self) -> list[Task]:
217-
"""Get all tasks (useful for debugging). Returns copies to prevent modification."""
218182
self._cleanup_expired()
219183
return [Task(**stored.task.model_dump()) for stored in self._tasks.values()]

src/mcp/shared/experimental/tasks/message_queue.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""
1414

1515
from abc import ABC, abstractmethod
16+
from collections import deque
1617
from dataclasses import dataclass, field
1718
from datetime import datetime, timezone
1819
from typing import Any, Literal
@@ -162,67 +163,51 @@ class InMemoryTaskMessageQueue(TaskMessageQueue):
162163
"""
163164

164165
def __init__(self) -> None:
165-
self._queues: dict[str, list[QueuedMessage]] = {}
166+
self._queues: dict[str, deque[QueuedMessage]] = {}
166167
self._events: dict[str, anyio.Event] = {}
167168

168-
def _get_queue(self, task_id: str) -> list[QueuedMessage]:
169-
"""Get or create the queue for a task."""
169+
def _get_queue(self, task_id: str) -> deque[QueuedMessage]:
170170
if task_id not in self._queues:
171-
self._queues[task_id] = []
171+
self._queues[task_id] = deque()
172172
return self._queues[task_id]
173173

174174
async def enqueue(self, task_id: str, message: QueuedMessage) -> None:
175-
"""Add a message to the queue."""
176175
queue = self._get_queue(task_id)
177176
queue.append(message)
178-
# Signal that a message is available
179177
await self.notify_message_available(task_id)
180178

181179
async def dequeue(self, task_id: str) -> QueuedMessage | None:
182-
"""Remove and return the next message."""
183180
queue = self._get_queue(task_id)
184181
if not queue:
185182
return None
186-
return queue.pop(0)
183+
return queue.popleft()
187184

188185
async def peek(self, task_id: str) -> QueuedMessage | None:
189-
"""Return the next message without removing it."""
190186
queue = self._get_queue(task_id)
191-
if not queue:
192-
return None
193-
return queue[0]
187+
return queue[0] if queue else None
194188

195189
async def is_empty(self, task_id: str) -> bool:
196-
"""Check if the queue is empty."""
197-
queue = self._get_queue(task_id)
198-
return len(queue) == 0
190+
return len(self._get_queue(task_id)) == 0
199191

200192
async def clear(self, task_id: str) -> list[QueuedMessage]:
201-
"""Remove and return all messages."""
202193
queue = self._get_queue(task_id)
203194
messages = list(queue)
204195
queue.clear()
205196
return messages
206197

207198
async def wait_for_message(self, task_id: str) -> None:
208-
"""Wait until a message is available."""
209-
# Check if there are already messages
210199
if not await self.is_empty(task_id):
211200
return
212201

213-
# Create a fresh event for waiting (anyio.Event can't be cleared)
214202
self._events[task_id] = anyio.Event()
215203
event = self._events[task_id]
216204

217-
# Double-check after creating event (avoid race condition)
218205
if not await self.is_empty(task_id):
219206
return
220207

221-
# Wait for a new message
222208
await event.wait()
223209

224210
async def notify_message_available(self, task_id: str) -> None:
225-
"""Signal that a message is available."""
226211
if task_id in self._events:
227212
self._events[task_id].set()
228213

0 commit comments

Comments
 (0)