From ab2979ea2a6575622a272188275fdc3b5a29611a Mon Sep 17 00:00:00 2001 From: William Bergamin Date: Thu, 19 Mar 2026 14:20:28 -0400 Subject: [PATCH] feat: widen the availability of set_status --- .../context/assistant/assistant_utilities.py | 8 + .../assistant/async_assistant_utilities.py | 8 + .../async_attaching_agent_kwargs.py | 7 +- .../attaching_agent_kwargs.py | 7 +- ...est_events_assistant_without_middleware.py | 6 +- .../scenario_tests/test_events_set_status.py | 171 ++++++++++++++++ ...est_events_assistant_without_middleware.py | 6 +- .../test_events_set_status.py | 183 ++++++++++++++++++ .../test_attaching_agent_kwargs.py | 18 +- .../test_async_attaching_agent_kwargs.py | 18 +- 10 files changed, 414 insertions(+), 18 deletions(-) create mode 100644 tests/scenario_tests/test_events_set_status.py create mode 100644 tests/scenario_tests_async/test_events_set_status.py diff --git a/slack_bolt/context/assistant/assistant_utilities.py b/slack_bolt/context/assistant/assistant_utilities.py index 53500efdb..42f05c94b 100644 --- a/slack_bolt/context/assistant/assistant_utilities.py +++ b/slack_bolt/context/assistant/assistant_utilities.py @@ -1,3 +1,4 @@ +import warnings from typing import Optional from slack_sdk.web import WebClient @@ -51,6 +52,13 @@ def is_valid(self) -> bool: @property def set_status(self) -> SetStatus: + warnings.warn( + "AssistantUtilities.set_status is deprecated. " + "Use the set_status argument directly in your listener function " + "or access it via context.set_status instead.", + DeprecationWarning, + stacklevel=2, + ) return SetStatus(self.client, self.channel_id, self.thread_ts) @property diff --git a/slack_bolt/context/assistant/async_assistant_utilities.py b/slack_bolt/context/assistant/async_assistant_utilities.py index 5a7324e99..b40b2619c 100644 --- a/slack_bolt/context/assistant/async_assistant_utilities.py +++ b/slack_bolt/context/assistant/async_assistant_utilities.py @@ -1,3 +1,4 @@ +import warnings from typing import Optional from slack_sdk.web.async_client import AsyncWebClient @@ -54,6 +55,13 @@ def is_valid(self) -> bool: @property def set_status(self) -> AsyncSetStatus: + warnings.warn( + "AsyncAssistantUtilities.set_status is deprecated. " + "Use the set_status argument directly in your listener function " + "or access it via context.set_status instead.", + DeprecationWarning, + stacklevel=2, + ) return AsyncSetStatus(self.client, self.channel_id, self.thread_ts) @property diff --git a/slack_bolt/middleware/attaching_agent_kwargs/async_attaching_agent_kwargs.py b/slack_bolt/middleware/attaching_agent_kwargs/async_attaching_agent_kwargs.py index 08851c1eb..82f1a7671 100644 --- a/slack_bolt/middleware/attaching_agent_kwargs/async_attaching_agent_kwargs.py +++ b/slack_bolt/middleware/attaching_agent_kwargs/async_attaching_agent_kwargs.py @@ -3,6 +3,7 @@ from slack_bolt.context.assistant.async_assistant_utilities import AsyncAssistantUtilities from slack_bolt.context.assistant.thread_context_store.async_store import AsyncAssistantThreadContextStore from slack_bolt.context.say_stream.async_say_stream import AsyncSayStream +from slack_bolt.context.set_status.async_set_status import AsyncSetStatus from slack_bolt.middleware.async_middleware import AsyncMiddleware from slack_bolt.request.async_request import AsyncBoltRequest from slack_bolt.request.payload_utils import is_assistant_event, to_event @@ -32,7 +33,6 @@ async def async_process( thread_context_store=self.thread_context_store, ) req.context["say"] = assistant.say - req.context["set_status"] = assistant.set_status req.context["set_title"] = assistant.set_title req.context["set_suggested_prompts"] = assistant.set_suggested_prompts req.context["get_thread_context"] = assistant.get_thread_context @@ -41,6 +41,11 @@ async def async_process( # TODO: in the future we might want to introduce a "proper" extract_ts utility thread_ts = req.context.thread_ts or event.get("ts") if req.context.channel_id and thread_ts: + req.context["set_status"] = AsyncSetStatus( + client=req.context.client, + channel_id=req.context.channel_id, + thread_ts=thread_ts, + ) req.context["say_stream"] = AsyncSayStream( client=req.context.client, channel=req.context.channel_id, diff --git a/slack_bolt/middleware/attaching_agent_kwargs/attaching_agent_kwargs.py b/slack_bolt/middleware/attaching_agent_kwargs/attaching_agent_kwargs.py index 38a62c0c8..70f41d561 100644 --- a/slack_bolt/middleware/attaching_agent_kwargs/attaching_agent_kwargs.py +++ b/slack_bolt/middleware/attaching_agent_kwargs/attaching_agent_kwargs.py @@ -3,6 +3,7 @@ from slack_bolt.context.assistant.assistant_utilities import AssistantUtilities from slack_bolt.context.assistant.thread_context_store.store import AssistantThreadContextStore from slack_bolt.context.say_stream.say_stream import SayStream +from slack_bolt.context.set_status.set_status import SetStatus from slack_bolt.middleware import Middleware from slack_bolt.request.payload_utils import is_assistant_event, to_event from slack_bolt.request.request import BoltRequest @@ -26,7 +27,6 @@ def process(self, *, req: BoltRequest, resp: BoltResponse, next: Callable[[], Bo thread_context_store=self.thread_context_store, ) req.context["say"] = assistant.say - req.context["set_status"] = assistant.set_status req.context["set_title"] = assistant.set_title req.context["set_suggested_prompts"] = assistant.set_suggested_prompts req.context["get_thread_context"] = assistant.get_thread_context @@ -35,6 +35,11 @@ def process(self, *, req: BoltRequest, resp: BoltResponse, next: Callable[[], Bo # TODO: in the future we might want to introduce a "proper" extract_ts utility thread_ts = req.context.thread_ts or event.get("ts") if req.context.channel_id and thread_ts: + req.context["set_status"] = SetStatus( + client=req.context.client, + channel_id=req.context.channel_id, + thread_ts=thread_ts, + ) req.context["say_stream"] = SayStream( client=req.context.client, channel=req.context.channel_id, diff --git a/tests/scenario_tests/test_events_assistant_without_middleware.py b/tests/scenario_tests/test_events_assistant_without_middleware.py index 36d86c43a..6a9381a33 100644 --- a/tests/scenario_tests/test_events_assistant_without_middleware.py +++ b/tests/scenario_tests/test_events_assistant_without_middleware.py @@ -180,7 +180,7 @@ def handle_message_event( ): assert context.thread_ts == "1726133698.626339" assert say.thread_ts == None - assert set_status is None + assert set_status is not None assert set_title is None assert set_suggested_prompts is None assert get_thread_context is None @@ -208,7 +208,7 @@ def handle_message_event( ): assert context.thread_ts == "1726133698.626339" assert say.thread_ts == None - assert set_status is None + assert set_status is not None assert set_title is None assert set_suggested_prompts is None assert get_thread_context is None @@ -236,7 +236,7 @@ def handle_message_event( ): assert context.thread_ts == "1726133698.626339" assert say.thread_ts == None - assert set_status is None + assert set_status is not None assert set_title is None assert set_suggested_prompts is None assert get_thread_context is None diff --git a/tests/scenario_tests/test_events_set_status.py b/tests/scenario_tests/test_events_set_status.py new file mode 100644 index 000000000..2dbdd38b8 --- /dev/null +++ b/tests/scenario_tests/test_events_set_status.py @@ -0,0 +1,171 @@ +import json +from threading import Event +from urllib.parse import quote + +from slack_sdk.web import WebClient + +from slack_bolt import App, BoltContext, BoltRequest +from slack_bolt.context.set_status.set_status import SetStatus +from slack_bolt.middleware.assistant import Assistant +from tests.mock_web_api_server import ( + assert_auth_test_count, + assert_received_request_count, + cleanup_mock_web_api_server, + setup_mock_web_api_server, +) +from tests.scenario_tests.test_app import app_mention_event_body +from tests.scenario_tests.test_events_assistant import thread_started_event_body +from tests.scenario_tests.test_events_assistant import user_message_event_body as threaded_user_message_event_body +from tests.scenario_tests.test_message_bot import bot_message_event_payload, user_message_event_payload +from tests.scenario_tests.test_view_submission import body as view_submission_body +from tests.utils import remove_os_env_temporarily, restore_os_env + + +class TestEventsSetStatus: + valid_token = "xoxb-valid" + mock_api_server_base_url = "http://localhost:8888" + web_client = WebClient( + token=valid_token, + base_url=mock_api_server_base_url, + ) + + def setup_method(self): + self.old_os_env = remove_os_env_temporarily() + setup_mock_web_api_server(self) + + def teardown_method(self): + cleanup_mock_web_api_server(self) + restore_os_env(self.old_os_env) + + def test_set_status_injected_for_app_mention(self): + app = App(client=self.web_client) + + @app.event("app_mention") + def handle_mention(set_status: SetStatus, context: BoltContext): + assert set_status is not None + assert isinstance(set_status, SetStatus) + assert set_status == context.set_status + assert set_status.channel_id == "C111" + assert set_status.thread_ts == "1595926230.009600" + set_status(status="Thinking...") + + request = BoltRequest(body=app_mention_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_auth_test_count(self, 1) + assert_received_request_count(self, path="/assistant.threads.setStatus", min_count=1) + + def test_set_status_injected_for_threaded_message(self): + app = App(client=self.web_client) + + @app.event("message") + def handle_message(set_status: SetStatus, context: BoltContext): + assert set_status is not None + assert isinstance(set_status, SetStatus) + assert set_status == context.set_status + assert set_status.channel_id == "D111" + assert set_status.thread_ts == "1726133698.626339" + set_status(status="Thinking...") + + request = BoltRequest(body=threaded_user_message_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_auth_test_count(self, 1) + assert_received_request_count(self, path="/assistant.threads.setStatus", min_count=1) + + def test_set_status_in_user_message(self): + app = App(client=self.web_client) + + @app.message("") + def handle_user_message(set_status: SetStatus, context: BoltContext): + assert set_status is not None + assert isinstance(set_status, SetStatus) + assert set_status == context.set_status + assert set_status.channel_id == "C111" + assert set_status.thread_ts == "1610261659.001400" + set_status(status="Thinking...") + + request = BoltRequest(body=user_message_event_payload, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_auth_test_count(self, 1) + assert_received_request_count(self, path="/assistant.threads.setStatus", min_count=1) + + def test_set_status_in_bot_message(self): + app = App(client=self.web_client) + + @app.message("") + def handle_bot_message(set_status: SetStatus, context: BoltContext): + assert set_status is not None + assert isinstance(set_status, SetStatus) + assert set_status == context.set_status + assert set_status.channel_id == "C111" + assert set_status.thread_ts == "1610261539.000900" + set_status(status="Thinking...") + + request = BoltRequest(body=bot_message_event_payload, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_auth_test_count(self, 1) + assert_received_request_count(self, path="/assistant.threads.setStatus", min_count=1) + + def test_set_status_in_assistant_thread_started(self): + app = App(client=self.web_client) + assistant = Assistant() + + @assistant.thread_started + def start_thread(set_status: SetStatus, context: BoltContext): + assert set_status is not None + assert isinstance(set_status, SetStatus) + assert set_status == context.set_status + assert set_status.channel_id == "D111" + assert set_status.thread_ts == "1726133698.626339" + set_status(status="Thinking...") + + app.assistant(assistant) + + request = BoltRequest(body=thread_started_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_auth_test_count(self, 1) + assert_received_request_count(self, path="/assistant.threads.setStatus", min_count=1) + + def test_set_status_in_assistant_user_message(self): + app = App(client=self.web_client) + assistant = Assistant() + + @assistant.user_message + def handle_user_message(set_status: SetStatus, context: BoltContext): + assert set_status is not None + assert isinstance(set_status, SetStatus) + assert set_status == context.set_status + assert set_status.channel_id == "D111" + assert set_status.thread_ts == "1726133698.626339" + set_status(status="Thinking...") + + app.assistant(assistant) + + request = BoltRequest(body=threaded_user_message_event_body, mode="socket_mode") + response = app.dispatch(request) + assert response.status == 200 + assert_auth_test_count(self, 1) + assert_received_request_count(self, path="/assistant.threads.setStatus", min_count=1) + + def test_set_status_is_none_for_view_submission(self): + app = App(client=self.web_client, request_verification_enabled=False) + listener_called = Event() + + @app.view("view-id") + def handle_view(ack, set_status, context: BoltContext): + ack() + assert set_status is None + assert context.set_status is None + listener_called.set() + + request = BoltRequest( + body=f"payload={quote(json.dumps(view_submission_body))}", + ) + response = app.dispatch(request) + assert response.status == 200 + assert_auth_test_count(self, 1) + assert listener_called.is_set() diff --git a/tests/scenario_tests_async/test_events_assistant_without_middleware.py b/tests/scenario_tests_async/test_events_assistant_without_middleware.py index be6c2b166..916dfd467 100644 --- a/tests/scenario_tests_async/test_events_assistant_without_middleware.py +++ b/tests/scenario_tests_async/test_events_assistant_without_middleware.py @@ -197,7 +197,7 @@ async def handle_message_event( ): assert context.thread_ts == "1726133698.626339" assert say.thread_ts == None - assert set_status is None + assert set_status is not None assert set_title is None assert set_suggested_prompts is None assert get_thread_context is None @@ -226,7 +226,7 @@ async def handle_message_event( ): assert context.thread_ts == "1726133698.626339" assert say.thread_ts == None - assert set_status is None + assert set_status is not None assert set_title is None assert set_suggested_prompts is None assert get_thread_context is None @@ -255,7 +255,7 @@ async def handle_message_event( ): assert context.thread_ts == "1726133698.626339" assert say.thread_ts == None - assert set_status is None + assert set_status is not None assert set_title is None assert set_suggested_prompts is None assert get_thread_context is None diff --git a/tests/scenario_tests_async/test_events_set_status.py b/tests/scenario_tests_async/test_events_set_status.py new file mode 100644 index 000000000..0e5be3349 --- /dev/null +++ b/tests/scenario_tests_async/test_events_set_status.py @@ -0,0 +1,183 @@ +import asyncio +import json +from urllib.parse import quote + +import pytest +from slack_sdk.web.async_client import AsyncWebClient + +from slack_bolt.app.async_app import AsyncApp +from slack_bolt.async_app import AsyncAssistant +from slack_bolt.context.async_context import AsyncBoltContext +from slack_bolt.context.set_status.async_set_status import AsyncSetStatus +from slack_bolt.request.async_request import AsyncBoltRequest +from tests.mock_web_api_server import ( + assert_auth_test_count_async, + assert_received_request_count_async, + cleanup_mock_web_api_server_async, + setup_mock_web_api_server_async, +) +from tests.scenario_tests_async.test_app import app_mention_event_body +from tests.scenario_tests_async.test_events_assistant import thread_started_event_body +from tests.scenario_tests_async.test_events_assistant import user_message_event_body as threaded_user_message_event_body +from tests.scenario_tests_async.test_message_bot import bot_message_event_payload, user_message_event_payload +from tests.scenario_tests_async.test_view_submission import body as view_submission_body +from tests.utils import remove_os_env_temporarily, restore_os_env + + +class TestAsyncEventsSetStatus: + valid_token = "xoxb-valid" + mock_api_server_base_url = "http://localhost:8888" + web_client = AsyncWebClient( + token=valid_token, + base_url=mock_api_server_base_url, + ) + + @pytest.fixture(scope="function", autouse=True) + def setup_teardown(self): + old_os_env = remove_os_env_temporarily() + setup_mock_web_api_server_async(self) + try: + yield + finally: + cleanup_mock_web_api_server_async(self) + restore_os_env(old_os_env) + + @pytest.mark.asyncio + async def test_set_status_injected_for_app_mention(self): + app = AsyncApp(client=self.web_client) + + @app.event("app_mention") + async def handle_mention(set_status: AsyncSetStatus, context: AsyncBoltContext): + assert set_status is not None + assert isinstance(set_status, AsyncSetStatus) + assert set_status == context.set_status + assert set_status.channel_id == "C111" + assert set_status.thread_ts == "1595926230.009600" + await set_status(status="Thinking...") + + request = AsyncBoltRequest(body=app_mention_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_auth_test_count_async(self, 1) + await assert_received_request_count_async(self, path="/assistant.threads.setStatus", min_count=1) + + @pytest.mark.asyncio + async def test_set_status_injected_for_threaded_message(self): + app = AsyncApp(client=self.web_client) + + @app.event("message") + async def handle_message(set_status: AsyncSetStatus, context: AsyncBoltContext): + assert set_status is not None + assert isinstance(set_status, AsyncSetStatus) + assert set_status == context.set_status + assert set_status.channel_id == "D111" + assert set_status.thread_ts == "1726133698.626339" + await set_status(status="Thinking...") + + request = AsyncBoltRequest(body=threaded_user_message_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_auth_test_count_async(self, 1) + await assert_received_request_count_async(self, path="/assistant.threads.setStatus", min_count=1) + + @pytest.mark.asyncio + async def test_set_status_in_user_message(self): + app = AsyncApp(client=self.web_client) + + @app.message("") + async def handle_user_message(set_status: AsyncSetStatus, context: AsyncBoltContext): + assert set_status is not None + assert isinstance(set_status, AsyncSetStatus) + assert set_status == context.set_status + assert set_status.channel_id == "C111" + assert set_status.thread_ts == "1610261659.001400" + await set_status(status="Thinking...") + + request = AsyncBoltRequest(body=user_message_event_payload, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_auth_test_count_async(self, 1) + await assert_received_request_count_async(self, path="/assistant.threads.setStatus", min_count=1) + + @pytest.mark.asyncio + async def test_set_status_in_bot_message(self): + app = AsyncApp(client=self.web_client) + + @app.message("") + async def handle_user_message(set_status: AsyncSetStatus, context: AsyncBoltContext): + assert set_status is not None + assert isinstance(set_status, AsyncSetStatus) + assert set_status == context.set_status + assert set_status.channel_id == "C111" + assert set_status.thread_ts == "1610261539.000900" + await set_status(status="Thinking...") + + request = AsyncBoltRequest(body=bot_message_event_payload, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_auth_test_count_async(self, 1) + await assert_received_request_count_async(self, path="/assistant.threads.setStatus", min_count=1) + + @pytest.mark.asyncio + async def test_set_status_in_assistant_thread_started(self): + app = AsyncApp(client=self.web_client) + assistant = AsyncAssistant() + + @assistant.thread_started + async def start_thread(set_status: AsyncSetStatus, context: AsyncBoltContext): + assert set_status is not None + assert isinstance(set_status, AsyncSetStatus) + assert set_status == context.set_status + assert set_status.channel_id == "D111" + assert set_status.thread_ts == "1726133698.626339" + await set_status(status="Thinking...") + + app.assistant(assistant) + + request = AsyncBoltRequest(body=thread_started_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_auth_test_count_async(self, 1) + await assert_received_request_count_async(self, path="/assistant.threads.setStatus", min_count=1) + + @pytest.mark.asyncio + async def test_set_status_in_assistant_user_message(self): + app = AsyncApp(client=self.web_client) + assistant = AsyncAssistant() + + @assistant.user_message + async def handle_user_message(set_status: AsyncSetStatus, context: AsyncBoltContext): + assert set_status is not None + assert isinstance(set_status, AsyncSetStatus) + assert set_status == context.set_status + assert set_status.channel_id == "D111" + assert set_status.thread_ts == "1726133698.626339" + await set_status(status="Thinking...") + + app.assistant(assistant) + + request = AsyncBoltRequest(body=threaded_user_message_event_body, mode="socket_mode") + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_auth_test_count_async(self, 1) + await assert_received_request_count_async(self, path="/assistant.threads.setStatus", min_count=1) + + @pytest.mark.asyncio + async def test_set_status_is_none_for_view_submission(self): + app = AsyncApp(client=self.web_client, request_verification_enabled=False) + listener_called = asyncio.Event() + + @app.view("view-id") + async def handle_view(ack, set_status, context: AsyncBoltContext): + await ack() + assert set_status is None + assert context.set_status is None + listener_called.set() + + request = AsyncBoltRequest( + body=f"payload={quote(json.dumps(view_submission_body))}", + ) + response = await app.async_dispatch(request) + assert response.status == 200 + await assert_auth_test_count_async(self, 1) + assert listener_called.is_set() diff --git a/tests/slack_bolt/middleware/attaching_agent_kwargs/test_attaching_agent_kwargs.py b/tests/slack_bolt/middleware/attaching_agent_kwargs/test_attaching_agent_kwargs.py index f56bd2e62..8e626fd0c 100644 --- a/tests/slack_bolt/middleware/attaching_agent_kwargs/test_attaching_agent_kwargs.py +++ b/tests/slack_bolt/middleware/attaching_agent_kwargs/test_attaching_agent_kwargs.py @@ -14,7 +14,7 @@ def next(): return BoltResponse(status=200) -AGENT_KWARGS = ("say", "set_status", "set_title", "set_suggested_prompts", "get_thread_context", "save_thread_context") +ASSISTANT_KWARGS = ("say", "set_title", "set_suggested_prompts", "get_thread_context", "save_thread_context") class TestAttachingAgentKwargs: @@ -26,9 +26,11 @@ def test_assistant_event_attaches_kwargs(self): resp = middleware.process(req=req, resp=BoltResponse(status=404), next=next) assert resp.status == 200 - for key in AGENT_KWARGS: + for key in ASSISTANT_KWARGS: assert key in req.context, f"{key} should be set on context" assert req.context["say"].thread_ts == "1726133698.626339" + assert "say_stream" in req.context + assert "set_status" in req.context def test_user_message_event_attaches_kwargs(self): middleware = AttachingAgentKwargs() @@ -38,9 +40,11 @@ def test_user_message_event_attaches_kwargs(self): resp = middleware.process(req=req, resp=BoltResponse(status=404), next=next) assert resp.status == 200 - for key in AGENT_KWARGS: + for key in ASSISTANT_KWARGS: assert key in req.context, f"{key} should be set on context" assert req.context["say"].thread_ts == "1726133698.626339" + assert "say_stream" in req.context + assert "set_status" in req.context def test_non_assistant_event_does_not_attach_kwargs(self): middleware = AttachingAgentKwargs() @@ -50,8 +54,10 @@ def test_non_assistant_event_does_not_attach_kwargs(self): resp = middleware.process(req=req, resp=BoltResponse(status=404), next=next) assert resp.status == 200 - for key in AGENT_KWARGS: + for key in ASSISTANT_KWARGS: assert key not in req.context, f"{key} should not be set on context" + assert "say_stream" in req.context + assert "set_status" in req.context def test_non_event_does_not_attach_kwargs(self): middleware = AttachingAgentKwargs() @@ -60,5 +66,7 @@ def test_non_event_does_not_attach_kwargs(self): resp = middleware.process(req=req, resp=BoltResponse(status=404), next=next) assert resp.status == 200 - for key in AGENT_KWARGS: + for key in ASSISTANT_KWARGS: assert key not in req.context, f"{key} should not be set on context" + assert "say_stream" not in req.context + assert "set_status" not in req.context diff --git a/tests/slack_bolt_async/middleware/attaching_agent_kwargs/test_async_attaching_agent_kwargs.py b/tests/slack_bolt_async/middleware/attaching_agent_kwargs/test_async_attaching_agent_kwargs.py index 55883e5f3..61aa0b59e 100644 --- a/tests/slack_bolt_async/middleware/attaching_agent_kwargs/test_async_attaching_agent_kwargs.py +++ b/tests/slack_bolt_async/middleware/attaching_agent_kwargs/test_async_attaching_agent_kwargs.py @@ -15,7 +15,7 @@ async def next(): return BoltResponse(status=200) -AGENT_KWARGS = ("say", "set_status", "set_title", "set_suggested_prompts", "get_thread_context", "save_thread_context") +ASSISTANT_KWARGS = ("say", "set_title", "set_suggested_prompts", "get_thread_context", "save_thread_context") class TestAsyncAttachingAgentKwargs: @@ -28,9 +28,11 @@ async def test_assistant_event_attaches_kwargs(self): resp = await middleware.async_process(req=req, resp=BoltResponse(status=404), next=next) assert resp.status == 200 - for key in AGENT_KWARGS: + for key in ASSISTANT_KWARGS: assert key in req.context, f"{key} should be set on context" assert req.context["say"].thread_ts == "1726133698.626339" + assert "say_stream" in req.context + assert "set_status" in req.context @pytest.mark.asyncio async def test_user_message_event_attaches_kwargs(self): @@ -41,9 +43,11 @@ async def test_user_message_event_attaches_kwargs(self): resp = await middleware.async_process(req=req, resp=BoltResponse(status=404), next=next) assert resp.status == 200 - for key in AGENT_KWARGS: + for key in ASSISTANT_KWARGS: assert key in req.context, f"{key} should be set on context" assert req.context["say"].thread_ts == "1726133698.626339" + assert "say_stream" in req.context + assert "set_status" in req.context @pytest.mark.asyncio async def test_non_assistant_event_does_not_attach_kwargs(self): @@ -54,8 +58,10 @@ async def test_non_assistant_event_does_not_attach_kwargs(self): resp = await middleware.async_process(req=req, resp=BoltResponse(status=404), next=next) assert resp.status == 200 - for key in AGENT_KWARGS: + for key in ASSISTANT_KWARGS: assert key not in req.context, f"{key} should not be set on context" + assert "say_stream" in req.context + assert "set_status" in req.context @pytest.mark.asyncio async def test_non_event_does_not_attach_kwargs(self): @@ -65,5 +71,7 @@ async def test_non_event_does_not_attach_kwargs(self): resp = await middleware.async_process(req=req, resp=BoltResponse(status=404), next=next) assert resp.status == 200 - for key in AGENT_KWARGS: + for key in ASSISTANT_KWARGS: assert key not in req.context, f"{key} should not be set on context" + assert "say_stream" not in req.context + assert "set_status" not in req.context