diff --git a/dash/_callback.py b/dash/_callback.py index 319272995e..fdd92e4a7a 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -1,6 +1,7 @@ import collections import hashlib import inspect +from datetime import datetime, timezone from functools import wraps from typing import Callable, Optional, Any, List, Tuple, Union, Dict @@ -425,6 +426,12 @@ def _setup_background_callback( ctx_value, ) + callback_manager.handle.set( + f"{cache_key}-created_at", + datetime.now(timezone.utc).isoformat(), + expire=callback_manager.expire, + ) + data = { "cacheKey": cache_key, "job": job, diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index 64323c09af..d4f91ddd48 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -33,6 +33,7 @@ list_tools, read_resource, ) +from dash.mcp.tasks import get_task, get_task_result, cancel_task from dash.mcp.primitives.tools.callback_adapter_collection import ( CallbackAdapterCollection, ) @@ -163,11 +164,16 @@ def _process_mcp_message(data: dict[str, Any]) -> dict[str, Any] | None: "initialize": _handle_initialize, "tools/list": list_tools, "tools/call": lambda: call_tool( - params.get("name", ""), params.get("arguments", {}) + tool_name=params.get("name", ""), + arguments=params.get("arguments", {}), + task=params.get("task"), ), "resources/list": list_resources, "resources/templates/list": list_resource_templates, "resources/read": lambda: read_resource(params.get("uri", "")), + "tasks/get": lambda: get_task(task_id=params.get("taskId", "")), + "tasks/result": lambda: get_task_result(task_id=params.get("taskId", "")), + "tasks/cancel": lambda: cancel_task(task_id=params.get("taskId", "")), } try: diff --git a/dash/mcp/primitives/tools/__init__.py b/dash/mcp/primitives/tools/__init__.py index b8f12a1dbd..eea7af43c1 100644 --- a/dash/mcp/primitives/tools/__init__.py +++ b/dash/mcp/primitives/tools/__init__.py @@ -4,17 +4,19 @@ from typing import Any -from mcp.types import CallToolResult, ListToolsResult +from mcp.types import CallToolResult, CreateTaskResult, ListToolsResult from dash.mcp.types import ToolNotFoundError from .base import MCPToolProvider +from .tool_background_tasks import BackgroundTaskTools from .tool_decorated_mcp_functions import DecoratedFunctionTools from .tool_get_dash_component import GetDashComponentTool from .tools_callbacks import CallbackTools _TOOL_PROVIDERS: list[type[MCPToolProvider]] = [ CallbackTools, + BackgroundTaskTools, GetDashComponentTool, DecoratedFunctionTools, ] @@ -28,11 +30,17 @@ def list_tools() -> ListToolsResult: return ListToolsResult(tools=tools) -def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: - """Route a tools/call request by tool name.""" +def call_tool( + tool_name: str, arguments: dict[str, Any], task: dict | None = None +) -> CallToolResult | CreateTaskResult: + """Route a tools/call request by tool name. + + The optional ``task`` parameter (per MCP Tasks protocol) is passed + through to providers that support background callbacks. + """ for provider in _TOOL_PROVIDERS: if tool_name in provider.get_tool_names(): - return provider.call_tool(tool_name, arguments) + return provider.call_tool(tool_name, arguments, task=task) raise ToolNotFoundError( f"Tool not found: {tool_name}." " The app's callbacks may have changed." diff --git a/dash/mcp/primitives/tools/base.py b/dash/mcp/primitives/tools/base.py index 60fa7374d6..f7a5c54aac 100644 --- a/dash/mcp/primitives/tools/base.py +++ b/dash/mcp/primitives/tools/base.py @@ -4,7 +4,7 @@ from typing import Any -from mcp.types import CallToolResult, Tool +from mcp.types import CallToolResult, CreateTaskResult, Tool class MCPToolProvider: @@ -24,5 +24,7 @@ def list_tools(cls) -> list[Tool]: raise NotImplementedError @classmethod - def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + def call_tool( + cls, tool_name: str, arguments: dict[str, Any], task: dict | None = None + ) -> CallToolResult | CreateTaskResult: raise NotImplementedError diff --git a/dash/mcp/primitives/tools/descriptions/__init__.py b/dash/mcp/primitives/tools/descriptions/__init__.py index b32238992c..a4227868d8 100644 --- a/dash/mcp/primitives/tools/descriptions/__init__.py +++ b/dash/mcp/primitives/tools/descriptions/__init__.py @@ -13,6 +13,7 @@ from typing import TYPE_CHECKING from .base import ToolDescriptionSource +from .description_background_callbacks import BackgroundCallbackDescription from .description_docstring import DocstringDescription from .description_outputs import OutputSummaryDescription @@ -22,6 +23,7 @@ _SOURCES: list[type[ToolDescriptionSource]] = [ OutputSummaryDescription, DocstringDescription, + BackgroundCallbackDescription, ] diff --git a/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py b/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py new file mode 100644 index 0000000000..eada24a01e --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py @@ -0,0 +1,32 @@ +"""Description for background (long-running) callbacks. + +Informs the LLM that the tool returns a taskId immediately +and must be polled via the background task result tool. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..tool_background_tasks import GET_RESULT_TOOL_NAME +from .base import ToolDescriptionSource + +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + + +class BackgroundCallbackDescription(ToolDescriptionSource): + """Add async polling instructions for background callbacks.""" + + @classmethod + def describe(cls, callback: CallbackAdapter) -> list[str]: + # pylint: disable-next=protected-access + if not callback._cb_info.get("background"): + return [] + + return [ + "", + "This is a long-running background operation. " + "It returns a taskId immediately. " + f"Call tool `{GET_RESULT_TOOL_NAME}` with the taskId to poll for the result.", + ] diff --git a/dash/mcp/primitives/tools/results/__init__.py b/dash/mcp/primitives/tools/results/__init__.py index 09e86410a7..12f9507da7 100644 --- a/dash/mcp/primitives/tools/results/__init__.py +++ b/dash/mcp/primitives/tools/results/__init__.py @@ -7,17 +7,19 @@ from __future__ import annotations import json -from typing import Any +from typing import TYPE_CHECKING, Any -from mcp.types import CallToolResult, TextContent +from mcp.types import CallToolResult, CreateTaskResult, TextContent from dash.types import CallbackExecutionResponse -from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter from .base import ResultFormatter from .result_dataframe import DataFrameResult from .result_plotly_figure import PlotlyFigureResult +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + _RESULT_FORMATTERS: list[type[ResultFormatter]] = [ PlotlyFigureResult, DataFrameResult, @@ -50,3 +52,32 @@ def format_callback_response( content=content, structuredContent=dict(response), ) + + +def task_result_to_tool_result(create_task_result: CreateTaskResult) -> CallToolResult: + """Wrap a CreateTaskResult as a CallToolResult with polling instructions. + + MCP Tasks are not yet supported by LLM clients, so this converts the + task metadata into a tool response that guides the LLM to poll via + the get_background_task_result tool. + """ + task = create_task_result.task + return CallToolResult( + content=[ + TextContent( + type="text", + text=json.dumps( + { + "taskId": task.taskId, + "status": task.status, + "pollInterval": task.pollInterval, + "message": ( + "This is a long-running background callback. " + "Call the get_background_task_result tool with this taskId " + "to poll for the result." + ), + } + ), + ) + ], + ) diff --git a/dash/mcp/primitives/tools/tool_background_tasks.py b/dash/mcp/primitives/tools/tool_background_tasks.py new file mode 100644 index 0000000000..a4dffa23f4 --- /dev/null +++ b/dash/mcp/primitives/tools/tool_background_tasks.py @@ -0,0 +1,102 @@ +"""Built-in tools for background callback task lifecycle. + +Thin wrappers around the spec-aligned core in dash.mcp.tasks. +Only registered when the app has background callbacks. +""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import CallToolResult, TextContent, Tool + +from dash import get_app +from dash.mcp.tasks import get_task, get_task_result, cancel_task + +from .base import MCPToolProvider + + +GET_RESULT_TOOL_NAME = "get_background_task_result" +CANCEL_TOOL_NAME = "cancel_background_task" + + +def _has_background_callbacks() -> bool: + return any(cb_info.get("background") for cb_info in get_app().callback_map.values()) + + +class BackgroundTaskTools(MCPToolProvider): + """Built-in tools for polling and cancelling background callback tasks. + + Only registered when the app has background callbacks. + """ + + @classmethod + def get_tool_names(cls) -> set[str]: + if not _has_background_callbacks(): + return set() + return {GET_RESULT_TOOL_NAME, CANCEL_TOOL_NAME} + + @classmethod + def list_tools(cls) -> list[Tool]: + if not _has_background_callbacks(): + return [] + return [ + Tool( + name=GET_RESULT_TOOL_NAME, + description=( + "Poll for the result of a long-running background callback. " + "Pass the taskId returned by the original tool call. " + "If the task is still running, call this tool again. " + "If complete, returns the callback result." + ), + inputSchema={ + "type": "object", + "properties": { + "taskId": { + "type": "string", + "description": "The taskId returned by the background callback tool.", + }, + }, + "required": ["taskId"], + }, + ), + Tool( + name=CANCEL_TOOL_NAME, + description="Cancel a running background callback.", + inputSchema={ + "type": "object", + "properties": { + "taskId": { + "type": "string", + "description": "The taskId of the background task to cancel.", + }, + }, + "required": ["taskId"], + }, + ), + ] + + @classmethod + def call_tool( + cls, + tool_name: str, + arguments: dict[str, Any], + task: dict | None = None, + ) -> CallToolResult: + task_id = arguments.get("taskId", "") + + if tool_name == GET_RESULT_TOOL_NAME: + task_status = get_task(task_id) + if task_status.status == "completed": + return get_task_result(task_id) + return CallToolResult( + content=[TextContent(type="text", text=task_status.model_dump_json())], + ) + + if tool_name == CANCEL_TOOL_NAME: + result = cancel_task(task_id) + return CallToolResult( + content=[TextContent(type="text", text=result.model_dump_json())], + ) + + raise ValueError(f"Unknown tool: {tool_name}") diff --git a/dash/mcp/primitives/tools/tool_decorated_mcp_functions.py b/dash/mcp/primitives/tools/tool_decorated_mcp_functions.py index c135455c88..0b3edbbcbe 100644 --- a/dash/mcp/primitives/tools/tool_decorated_mcp_functions.py +++ b/dash/mcp/primitives/tools/tool_decorated_mcp_functions.py @@ -125,7 +125,9 @@ def list_tools(cls) -> list[Tool]: return [_build_tool(name, reg) for name, reg in cls._registry().items()] @classmethod - def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + def call_tool( + cls, tool_name: str, arguments: dict[str, Any], task: dict | None = None + ) -> CallToolResult: reg = cls._registry().get(tool_name) if reg is None: return CallToolResult( diff --git a/dash/mcp/primitives/tools/tool_get_dash_component.py b/dash/mcp/primitives/tools/tool_get_dash_component.py index f03b93293f..7d1adfe0b5 100644 --- a/dash/mcp/primitives/tools/tool_get_dash_component.py +++ b/dash/mcp/primitives/tools/tool_get_dash_component.py @@ -58,7 +58,12 @@ def list_tools(cls) -> list[Tool]: ] @classmethod - def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + def call_tool( + cls, + tool_name: str, + arguments: dict[str, Any], + task: dict | None = None, + ) -> CallToolResult: comp_id = arguments.get("component_id", "") if not comp_id: raise ValueError("component_id is required") diff --git a/dash/mcp/primitives/tools/tools_callbacks.py b/dash/mcp/primitives/tools/tools_callbacks.py index 716b777326..97970c5df7 100644 --- a/dash/mcp/primitives/tools/tools_callbacks.py +++ b/dash/mcp/primitives/tools/tools_callbacks.py @@ -7,14 +7,15 @@ from typing import Any -from mcp.types import CallToolResult, TextContent, Tool +from mcp.types import CallToolResult, CreateTaskResult, TextContent, Tool from dash import get_app +from dash.mcp.tasks import create_task from dash.mcp.types import CallbackExecutionError, ToolNotFoundError from .base import MCPToolProvider from .callback_utils import run_callback -from .results import format_callback_response +from .results import format_callback_response, task_result_to_tool_result class CallbackTools(MCPToolProvider): @@ -30,7 +31,12 @@ def list_tools(cls) -> list[Tool]: return get_app().mcp_callback_map.as_mcp_tools() @classmethod - def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + def call_tool( + cls, + tool_name: str, + arguments: dict[str, Any], + task: dict | None = None, + ) -> CallToolResult | CreateTaskResult: """Execute a callback tool by name.""" callback_map = get_app().mcp_callback_map cb = callback_map.find_by_tool_name(tool_name) @@ -41,6 +47,9 @@ def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: " Please call tools/list to refresh your tool list." ) + # pylint: disable-next=protected-access + is_background = bool(cb._cb_info.get("background")) + try: callback_response = run_callback(cb, arguments) except CallbackExecutionError as e: @@ -48,4 +57,11 @@ def call_tool(cls, tool_name: str, arguments: dict[str, Any]) -> CallToolResult: content=[TextContent(type="text", text=str(e))], isError=True, ) + + if is_background: + task_result = create_task(callback_response, cb) + if task is not None: + return task_result + return task_result_to_tool_result(task_result) + return format_callback_response(callback_response, cb) diff --git a/dash/mcp/tasks/__init__.py b/dash/mcp/tasks/__init__.py new file mode 100644 index 0000000000..8b78741d60 --- /dev/null +++ b/dash/mcp/tasks/__init__.py @@ -0,0 +1,5 @@ +"""MCP Tasks — lifecycle management for background callback execution.""" + +from .tasks import create_task, get_task, get_task_result, cancel_task + +__all__ = ["create_task", "get_task", "get_task_result", "cancel_task"] diff --git a/dash/mcp/tasks/tasks.py b/dash/mcp/tasks/tasks.py new file mode 100644 index 0000000000..aab1a98f85 --- /dev/null +++ b/dash/mcp/tasks/tasks.py @@ -0,0 +1,154 @@ +"""Handler functions for MCP tasks/* methods.""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from typing import Any + +from mcp.types import CancelTaskResult, CreateTaskResult, GetTaskResult, Task + +from dash import get_app +from dash.mcp.primitives.tools.results import format_callback_response +from dash.mcp.types import MCPError + + +def parse_task_id(task_id: str) -> tuple[str, str, str]: + """Parse a taskId into (tool_name, job_id, cache_key).""" + return task_id.split(":", 2) + + +def _get_callback_manager(): + """Get the background callback manager from the app's callback_map.""" + app = get_app() + for cb_info in app.callback_map.values(): + manager = cb_info.get("manager") + if manager is not None: + return manager + return None + + +def create_task(dispatch_response: dict[str, Any], callback) -> CreateTaskResult: + """Create a Task from a background callback's initial dispatch response.""" + cache_key = dispatch_response["cacheKey"] + job_id = str(dispatch_response["job"]) + task_id = f"{callback.tool_name}:{job_id}:{cache_key}" + # pylint: disable-next=protected-access + interval = callback._cb_info.get("background", {}).get("interval", 1000) + now = datetime.now(timezone.utc) + return CreateTaskResult( + task=Task( + taskId=task_id, + status="working", + createdAt=now, + lastUpdatedAt=now, + ttl=None, + pollInterval=interval, + ), + ) + + +def get_task(task_id: str) -> GetTaskResult: + """Handle tasks/get — derive status from the callback manager.""" + tool_name, job_id, cache_key = parse_task_id(task_id) + + manager = _get_callback_manager() + if manager is None: + return GetTaskResult( + taskId=task_id, + status="failed", + statusMessage="No background callback manager configured.", + createdAt=datetime.now(timezone.utc), + lastUpdatedAt=datetime.now(timezone.utc), + ttl=None, + ) + + running = manager.job_running(job_id) + progress = manager.get_progress(cache_key) + + if running: + status = "working" + elif manager.result_ready(cache_key): + status = "completed" + else: + status = "failed" + + adapter = get_app().mcp_callback_map.find_by_tool_name(tool_name) + interval = None + if adapter is not None: + # pylint: disable-next=protected-access + interval = adapter._cb_info.get("background", {}).get("interval", 1000) + + now = datetime.now(timezone.utc) + return GetTaskResult( + taskId=task_id, + status=status, + statusMessage=str(progress) if progress else None, + createdAt=datetime.fromisoformat( + manager.handle.get(f"{cache_key}-created_at") or now.isoformat() + ), + lastUpdatedAt=now, + ttl=manager.expire * 1000 if manager.expire else None, + pollInterval=interval, + ) + + +def get_task_result(task_id: str) -> Any: + """Handle tasks/result — retrieve and format the callback result. + + Mirrors the Dash renderer: calls get_result() which clears from cache. + """ + tool_name, job_id, cache_key = parse_task_id(task_id) + + manager = _get_callback_manager() + if manager is None: + raise MCPError("No background callback manager configured.") + + # Mirror the renderer: dispatch with cacheKey/job query params. + # The framework handles result retrieval, wrapping, and cleanup. + adapter = get_app().mcp_callback_map.find_by_tool_name(tool_name) + body = adapter.as_callback_body({}) + app = get_app() + + with app.server.test_request_context( + f"/_dash-update-component?cacheKey={cache_key}&job={job_id}", + method="POST", + data=json.dumps(body, default=str), + content_type="application/json", + ): + response = app.dispatch() + + response_data = json.loads(response.get_data(as_text=True)) + + if "response" not in response_data: + raise MCPError( + "Task result not ready. Poll tasks/get until status is 'completed'." + ) + + return format_callback_response(response_data, adapter) + + +def cancel_task(task_id: str) -> Any: + """Handle tasks/cancel — terminate the background job. + + Same underlying mechanism as the renderer's cancelJob query param. + """ + _tool_name, job_id, cache_key = parse_task_id(task_id) + + manager = _get_callback_manager() + if manager is None: + raise MCPError("No background callback manager configured.") + + manager.terminate_job(job_id) + + now = datetime.now(timezone.utc) + created_at = manager.handle.get(f"{cache_key}-created_at") + manager.handle.delete(f"{cache_key}-created_at") + + return CancelTaskResult( + taskId=task_id, + status="cancelled", + createdAt=datetime.fromisoformat(created_at) if created_at else now, + lastUpdatedAt=now, + ttl=manager.expire * 1000 if manager.expire else None, + ) diff --git a/requirements/install.txt b/requirements/install.txt index b813a6ce55..caf1e34d0d 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -8,4 +8,4 @@ retrying nest-asyncio setuptools pydantic>=2.10 -mcp>=1.0.0; python_version>="3.10" +mcp>=1.23.0; python_version>="3.10" diff --git a/tests/integration/mcp/test_mcp_background_tasks.py b/tests/integration/mcp/test_mcp_background_tasks.py new file mode 100644 index 0000000000..e3e0c0acdf --- /dev/null +++ b/tests/integration/mcp/test_mcp_background_tasks.py @@ -0,0 +1,298 @@ +"""Background callback support through the MCP HTTP endpoint. + +End-to-end flows: trigger a background callback, poll via +``get_background_task_result``, observe progress (``set_progress``), +confirm the cache-expiry behavior, and verify the background-only tools +appear in ``tools/list``. +""" + +import json +import re +import time +from datetime import datetime + +import diskcache +from dash import Dash, Input, Output, html +from dash.background_callback.managers.diskcache_manager import DiskcacheManager + +MCP_PATH = "_mcp" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_background_app(): + cache = diskcache.Cache() + manager = DiskcacheManager(cache) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + background=True, + manager=manager, + ) + def slow_callback(value): + time.sleep(0.5) + return f"done: {value}" + + return app + + +def _post(client, method, params=None, request_id=1): + return client.post( + f"/{MCP_PATH}", + data=json.dumps( + { + "jsonrpc": "2.0", + "method": method, + "id": request_id, + "params": params or {}, + } + ), + headers={"Content-Type": "application/json"}, + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_mcpbg012_trigger_poll_and_retrieve(): + app = _make_background_app() + client = app.server.test_client() + + # Trigger + r = _post( + client, + "tools/call", + {"name": "slow_callback", "arguments": {"value": "hello"}}, + ) + assert r.status_code == 200 + data = json.loads(r.data) + task_info = json.loads(data["result"]["content"][0]["text"]) + task_id = task_info["taskId"] + assert task_info["status"] == "working" + + # Read createdAt from the callback manager directly + _, _, cache_key = task_id.split(":", 2) + stored_created_at = app.callback_map["output.children"]["manager"].handle.get( + f"{cache_key}-created_at" + ) + assert stored_created_at is not None + + # Poll — should be working, with createdAt matching the stored value + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=2, + ) + assert r.status_code == 200 + poll_data = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + assert datetime.fromisoformat(poll_data["createdAt"]) == datetime.fromisoformat( + stored_created_at + ) + + # Wait for completion + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + deadline = time.time() + 5 + while time.time() < deadline: + if not manager.job_running(job_id): + break + time.sleep(0.1) + + # Get result + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=3, + ) + assert r.status_code == 200 + data = json.loads(r.data) + text = data["result"]["content"][0]["text"] + assert "done:" in text + + +def test_mcpbg013_result_expires(): + """Result and createdAt are available until the cache expires.""" + cache = diskcache.Cache() + manager = DiskcacheManager(cache, cache_by=[lambda: "fixed"], expire=2) + + app = Dash(__name__) + app.layout = html.Div([html.Div(id="input"), html.Div(id="output")]) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + background=True, + manager=manager, + ) + def fast_cb(value): + return f"done: {value}" + + client = app.server.test_client() + + # Trigger + r = _post( + client, + "tools/call", + {"name": "fast_cb", "arguments": {"value": "hi"}}, + ) + task_info = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + task_id = task_info["taskId"] + _, job_id, cache_key = task_id.split(":", 2) + + # Wait for job to finish + deadline = time.time() + 3 + while time.time() < deadline: + if not manager.job_running(job_id): + break + time.sleep(0.1) + + # First retrieval — result and createdAt available + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=2, + ) + text = json.loads(r.data)["result"]["content"][0]["text"] + assert "done:" in text + created_at = manager.handle.get(f"{cache_key}-created_at") + assert created_at is not None + + # Second retrieval — still available (cache_by keeps it) + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=3, + ) + text = json.loads(r.data)["result"]["content"][0]["text"] + assert "done:" in text + assert manager.handle.get(f"{cache_key}-created_at") == created_at + + # Wait for expiry + time.sleep(2.5) + + # After expiry — tool reports failure, createdAt is fresh (stored value gone) + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=4, + ) + poll_data = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + assert poll_data["status"] == "failed" + assert datetime.fromisoformat(poll_data["createdAt"]) > datetime.fromisoformat( + created_at + ) + + +def test_mcpbg014_progress_in_poll_response(): + """Progress reported via set_progress appears in poll statusMessage.""" + cache = diskcache.Cache() + manager = DiskcacheManager(cache) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input"), + html.Div(id="status"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + progress=Output("status", "children"), + background=True, + manager=manager, + interval=200, + ) + def progress_cb(set_progress, value): + for i in range(10): + set_progress(f"Step {i + 1} of 10") + time.sleep(0.2) + return f"done: {value}" + + client = app.server.test_client() + + # Trigger + r = _post( + client, + "tools/call", + {"name": "progress_cb", "arguments": {"value": "hi"}}, + ) + task_info = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + task_id = task_info["taskId"] + + # Poll and collect all progress messages + progress_pattern = re.compile(r"Step \d+ of 10") + progress_messages = [] + deadline = time.time() + 10 + while time.time() < deadline: + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + request_id=2, + ) + text = json.loads(r.data)["result"]["content"][0]["text"] + try: + poll_data = json.loads(text) + msg = poll_data.get("statusMessage") + if msg is not None: + progress_messages.append(msg) + if poll_data.get("status") == "completed": + break + except (json.JSONDecodeError, KeyError): + break + time.sleep(0.3) + + assert len(progress_messages) > 0, "Expected progress updates during polling" + for msg in progress_messages: + assert progress_pattern.search(msg), f"Unexpected progress format: {msg}" + + +def test_mcpbg015_background_tools_in_tools_list(): + app = _make_background_app() + client = app.server.test_client() + r = _post(client, "tools/list") + data = json.loads(r.data) + names = [t["name"] for t in data["result"]["tools"]] + assert "get_background_task_result" in names + assert "cancel_background_task" in names + assert "slow_callback" in names diff --git a/tests/unit/mcp/tools/test_mcp_background_callbacks.py b/tests/unit/mcp/tools/test_mcp_background_callbacks.py new file mode 100644 index 0000000000..bc6f6b32ed --- /dev/null +++ b/tests/unit/mcp/tools/test_mcp_background_callbacks.py @@ -0,0 +1,277 @@ +"""Background callback support via MCP Tasks. + +Covers both layers: +- Layer 1 (``dash/mcp/tasks/``): ``tasks/get``, ``tasks/result``, ``tasks/cancel`` + derived on-demand from the callback manager. +- Layer 2 (tool wrappers): ``get_background_task_result`` and + ``cancel_background_task`` — only registered when the app has + background callbacks. +""" + +import json +import time + +import diskcache +from dash import Dash, Input, Output, html +from dash.background_callback.managers.diskcache_manager import DiskcacheManager +from dash.mcp._server import _process_mcp_message + +from tests.unit.mcp.conftest import _setup_mcp + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _make_background_app(): + cache = diskcache.Cache() + manager = DiskcacheManager(cache) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + background=True, + manager=manager, + ) + def slow_callback(value): + """A background callback.""" + time.sleep(0.3) + return f"done: {value}" + + return app + + +def _trigger_task(app): + """Call slow_callback via tools/call and return its taskId.""" + result = _mcp( + app, + "tools/call", + {"name": "slow_callback", "arguments": {"value": "hello"}}, + ) + return json.loads(result["result"]["content"][0]["text"])["taskId"] + + +def _wait_for_completion(app, task_id, timeout=3): + """Block until the callback manager reports the job is no longer running.""" + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + deadline = time.time() + timeout + while time.time() < deadline: + if not manager.job_running(job_id): + return + time.sleep(0.1) + + +# --------------------------------------------------------------------------- +# Tool-layer: cancel_background_task, get_background_task_result, registration +# --------------------------------------------------------------------------- + + +def test_mcpbg001_cancel_via_tool(): + app = _make_background_app() + task_id = _trigger_task(app) + + cancel = _mcp( + app, + "tools/call", + { + "name": "cancel_background_task", + "arguments": {"taskId": task_id}, + }, + ) + assert cancel["result"].get("isError") is not True + + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + assert not manager.job_running(job_id) + + +def test_mcpbg002_present_with_background_callbacks(): + app = _make_background_app() + tools = _mcp(app, "tools/list")["result"]["tools"] + names = [t["name"] for t in tools] + assert "get_background_task_result" in names + assert "cancel_background_task" in names + + +def test_mcpbg003_absent_without_background_callbacks(): + app = Dash(__name__) + app.layout = html.Div([html.Div(id="in"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("in", "children")) + def normal_cb(v): + return v + + tools = _mcp(app, "tools/list")["result"]["tools"] + names = [t["name"] for t in tools] + assert "get_background_task_result" not in names + assert "cancel_background_task" not in names + + +def test_mcpbg004_returns_working_while_running(): + app = _make_background_app() + task_id = _trigger_task(app) + poll = _mcp( + app, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + ) + text = poll["result"]["content"][0]["text"] + assert "working" in text.lower() + + +def test_mcpbg005_returns_result_when_complete(): + app = _make_background_app() + task_id = _trigger_task(app) + _wait_for_completion(app, task_id) + + result = _mcp( + app, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + ) + text = result["result"]["content"][0]["text"] + assert "done:" in text + + +def test_mcpbg006_returns_task_id(): + """Calling a background callback tool returns a taskId immediately.""" + app = _make_background_app() + result = _mcp( + app, + "tools/call", + {"name": "slow_callback", "arguments": {"value": "hello"}}, + ) + text = result["result"]["content"][0]["text"] + assert "taskId" in text + assert "slow_callback:" in text + + +# --------------------------------------------------------------------------- +# Tasks-protocol layer: tasks/get, tasks/result, tasks/cancel +# --------------------------------------------------------------------------- + + +def test_mcpbg007_tasks_get_working_status_while_running(): + app = _make_background_app() + create_result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = create_result["result"]["task"]["taskId"] + + get_result = _mcp(app, "tasks/get", {"taskId": task_id}) + assert get_result["result"]["status"] == "working" + assert get_result["result"]["taskId"] == task_id + + +def test_mcpbg008_tasks_result_returns_formatted_result(): + app = _make_background_app() + create_result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = create_result["result"]["task"]["taskId"] + _wait_for_completion(app, task_id) + + result = _mcp(app, "tasks/result", {"taskId": task_id}) + assert "content" in result["result"] + text = result["result"]["content"][0]["text"] + assert "done:" in text + + +def test_mcpbg009_tasks_cancel_terminates_job(): + app = _make_background_app() + create_result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = create_result["result"]["task"]["taskId"] + + cancel_result = _mcp(app, "tasks/cancel", {"taskId": task_id}) + assert "error" not in cancel_result + + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + assert not manager.job_running(job_id) + + +# --------------------------------------------------------------------------- +# tools/call with task metadata → CreateTaskResult + taskId encoding +# --------------------------------------------------------------------------- + + +def test_mcpbg010_returns_create_task_result(): + app = _make_background_app() + result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task = result["result"]["task"] + assert task["status"] == "working" + assert "taskId" in task + assert "pollInterval" in task + + +def test_mcpbg011_task_id_encodes_tool_name_job_id_cache_key(): + app = _make_background_app() + result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = result["result"]["task"]["taskId"] + tool_name, _job_id, cache_key = task_id.split(":", 2) + assert tool_name == "slow_callback" + assert len(cache_key) == 64 # SHA256 hex