diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 4ed1b525a0..49bdd64387 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -911,19 +911,21 @@ async def _post_hook(response: AgentResponse) -> None: if ctx is None: return # No context available (shouldn't happen in normal flow) + # Update thread with conversation_id derived from streaming raw updates. + # Using response_id here can break function-call continuation for APIs + # where response IDs are not valid conversation handles. + conversation_id = self._extract_conversation_id_from_streaming_response(response) # Ensure author names are set for all messages for message in response.messages: if message.author_name is None: message.author_name = ctx["agent_name"] - # Propagate conversation_id back to session from streaming updates + # Propagate conversation_id back to session from streaming updates. + # For Responses-style APIs this can rotate every turn (response_id-based continuation), + # so refresh when a newer value is returned. sess = ctx["session"] - if sess and not sess.service_session_id and response.raw_representation: - raw_items = response.raw_representation if isinstance(response.raw_representation, list) else [] - for item in raw_items: - if hasattr(item, "conversation_id") and item.conversation_id: - sess.service_session_id = item.conversation_id - break + if sess and conversation_id and sess.service_session_id != conversation_id: + sess.service_session_id = conversation_id # Run after_run providers (reverse order) session_context = ctx["session_context"] @@ -974,6 +976,27 @@ def _finalize_response_updates( output_format_type = response_format if isinstance(response_format, type) else None return AgentResponse.from_updates(updates, output_format_type=output_format_type) + @staticmethod + def _extract_conversation_id_from_streaming_response(response: AgentResponse[Any]) -> str | None: + """Extract conversation_id from streaming raw updates, if present.""" + raw = response.raw_representation + if raw is None: + return None + + raw_items: list[Any] = raw if isinstance(raw, list) else [raw] + for item in reversed(raw_items): + if isinstance(item, Mapping): + value = item.get("conversation_id") + if isinstance(value, str) and value: + return value + continue + + value = getattr(item, "conversation_id", None) + if isinstance(value, str) and value: + return value + + return None + async def _prepare_run_context( self, *, @@ -1100,8 +1123,10 @@ async def _finalize_response( if message.author_name is None: message.author_name = agent_name - # Propagate conversation_id back to session (e.g. thread ID from Assistants API) - if session and response.conversation_id and not session.service_session_id: + # Propagate conversation_id back to session (e.g. thread ID from Assistants API). + # For Responses-style APIs this can rotate every turn (response_id-based continuation), + # so refresh when a newer value is returned. + if session and response.conversation_id and session.service_session_id != response.conversation_id: session.service_session_id = response.conversation_id # Set the response on the context for after_run providers diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index b52d2b252b..d9a2b5579d 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -872,7 +872,16 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str: k: v for k, v in kwargs.items() if k - not in {"chat_options", "tools", "tool_choice", "session", "thread", "conversation_id", "options", "response_format"} + not in { + "chat_options", + "tools", + "tool_choice", + "session", + "thread", + "conversation_id", + "options", + "response_format", + } } parser = self.parse_tool_results or _parse_tool_result_from_mcp diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 3f0bd63086..3b10579055 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -2,7 +2,7 @@ import logging import sys -from collections.abc import Mapping +from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass from typing import Any, cast @@ -358,22 +358,31 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY) or {}) updates: list[AgentResponseUpdate] = [] - user_input_requests: list[Content] = [] - async for update in self._agent.run( + streamed_user_input_requests: list[Content] = [] + stream = self._agent.run( self._cache, stream=True, session=self._session, options=options, **run_kwargs, - ): + ) + async for update in stream: updates.append(update) await ctx.yield_output(update) - if update.user_input_requests: - user_input_requests.extend(update.user_input_requests) - - # Build the final AgentResponse from the collected updates - if is_chat_agent(self._agent): + streamed_user_input_requests.extend(update.user_input_requests) + + # Prefer stream finalization when available so result hooks run + # (e.g., thread conversation updates). Fall back to reconstructing from updates + # for legacy/custom agents that return a plain async iterable. + # TODO(evmattso): Integrate workflow agent run handling around ResponseStream so + # AgentExecutor does not need this conditional stream-finalization branch. + maybe_get_final_response = getattr(stream, "get_final_response", None) + get_final_response = maybe_get_final_response if callable(maybe_get_final_response) else None + response: AgentResponse[Any] + if get_final_response is not None: + response = await cast(Callable[[], Awaitable[AgentResponse[Any]]], get_final_response)() + elif is_chat_agent(self._agent): response_format = self._agent.default_options.get("response_format") response = AgentResponse.from_updates( updates, @@ -383,6 +392,16 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp response = AgentResponse.from_updates(updates) # Handle any user input requests after the streaming completes + user_input_requests: list[Content] = [] + seen_request_ids: set[str] = set() + for user_input_request in [*streamed_user_input_requests, *response.user_input_requests]: + request_id = getattr(user_input_request, "id", None) + if isinstance(request_id, str) and request_id: + if request_id in seen_request_ids: + continue + seen_request_ids.add(request_id) + user_input_requests.append(user_input_request) + if user_input_requests: for user_input_request in user_input_requests: self._pending_agent_requests[user_input_request.id] = user_input_request # type: ignore[index] diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index 16365a7a5f..6d776a8b43 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -17,6 +17,7 @@ BaseContextProvider, ChatOptions, ChatResponse, + ChatResponseUpdate, Content, FunctionTool, Message, @@ -154,6 +155,111 @@ async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChat assert session.service_session_id == "123" +async def test_chat_client_agent_updates_existing_session_id_non_streaming( + chat_client_base: SupportsChatGetResponse, +) -> None: + chat_client_base.run_responses = [ + ChatResponse( + messages=[Message(role="assistant", contents=[Content.from_text("test response")])], + conversation_id="resp_new_123", + ) + ] + + agent = Agent(client=chat_client_base) + session = agent.get_session(service_session_id="resp_old_123") + + await agent.run("Hello", session=session) + assert session.service_session_id == "resp_new_123" + + +async def test_chat_client_agent_update_session_id_streaming_uses_conversation_id( + chat_client_base: SupportsChatGetResponse, +) -> None: + chat_client_base.streaming_responses = [ + [ + ChatResponseUpdate( + contents=[Content.from_text("stream part 1")], + role="assistant", + response_id="resp_stream_123", + conversation_id="conv_stream_456", + ), + ChatResponseUpdate( + contents=[Content.from_text(" stream part 2")], + role="assistant", + response_id="resp_stream_123", + conversation_id="conv_stream_456", + finish_reason="stop", + ), + ] + ] + + agent = Agent(client=chat_client_base) + session = agent.create_session() + + stream = agent.run("Hello", session=session, stream=True) + async for _ in stream: + pass + result = await stream.get_final_response() + assert result.text == "stream part 1 stream part 2" + assert session.service_session_id == "conv_stream_456" + + +async def test_chat_client_agent_updates_existing_session_id_streaming( + chat_client_base: SupportsChatGetResponse, +) -> None: + chat_client_base.streaming_responses = [ + [ + ChatResponseUpdate( + contents=[Content.from_text("stream part 1")], + role="assistant", + response_id="resp_stream_123", + conversation_id="resp_new_456", + ), + ChatResponseUpdate( + contents=[Content.from_text(" stream part 2")], + role="assistant", + response_id="resp_stream_123", + conversation_id="resp_new_456", + finish_reason="stop", + ), + ] + ] + + agent = Agent(client=chat_client_base) + session = agent.get_session(service_session_id="resp_old_456") + + stream = agent.run("Hello", session=session, stream=True) + async for _ in stream: + pass + await stream.get_final_response() + assert session.service_session_id == "resp_new_456" + + +async def test_chat_client_agent_update_session_id_streaming_does_not_use_response_id( + chat_client_base: SupportsChatGetResponse, +) -> None: + chat_client_base.streaming_responses = [ + [ + ChatResponseUpdate( + contents=[Content.from_text("stream response without conversation id")], + role="assistant", + response_id="resp_only_123", + finish_reason="stop", + ), + ] + ] + + agent = Agent(client=chat_client_base) + session = agent.create_session() + + stream = agent.run("Hello", session=session, stream=True) + async for _ in stream: + pass + result = await stream.get_final_response() + assert result.text == "stream response without conversation id" + assert session.service_session_id is None + + async def test_chat_client_agent_update_session_messages(client: SupportsChatGetResponse) -> None: agent = Agent(client=client) session = agent.create_session() diff --git a/python/packages/core/tests/workflow/__init__.py b/python/packages/core/tests/workflow/__init__.py index e69de29bb2..2a50eae894 100644 --- a/python/packages/core/tests/workflow/__init__.py +++ b/python/packages/core/tests/workflow/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Microsoft. All rights reserved. diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 07b15d5bf1..7c2e6fc356 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -50,6 +50,57 @@ async def _run() -> AgentResponse: return _run() +class _StreamingHookAgent(BaseAgent): + """Agent that exposes whether its streaming result hook was executed.""" + + def __init__(self, **kwargs: Any): + super().__init__(**kwargs) + self.result_hook_called = False + + def run( + self, + messages: str | Message | list[str] | list[Message] | None = None, + *, + stream: bool = False, + **kwargs: Any, + ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + if stream: + + async def _stream() -> AsyncIterable[AgentResponseUpdate]: + yield AgentResponseUpdate( + contents=[Content.from_text(text="hook test")], + role="assistant", + ) + + async def _mark_result_hook_called(response: AgentResponse) -> AgentResponse: + self.result_hook_called = True + return response + + return ResponseStream(_stream(), finalizer=AgentResponse.from_updates).with_result_hook( + _mark_result_hook_called + ) + + async def _run() -> AgentResponse: + return AgentResponse(messages=[Message("assistant", ["hook test"])]) + + return _run() + + +async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> None: + """AgentExecutor should call get_final_response() so stream result hooks execute.""" + agent = _StreamingHookAgent(id="hook_agent", name="HookAgent") + executor = AgentExecutor(agent, id="hook_exec") + workflow = SequentialBuilder(participants=[executor]).build() + + output_events: list[Any] = [] + async for event in workflow.run("run hook test", stream=True): + if event.type == "output": + output_events.append(event) + + assert output_events + assert agent.result_hook_called + + async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: """Test that workflow checkpoint stores AgentExecutor's cache and session states and restores them correctly.""" storage = InMemoryCheckpointStorage() diff --git a/python/packages/core/tests/workflow/test_request_info_event_rehydrate.py b/python/packages/core/tests/workflow/test_request_info_event_rehydrate.py index dbb01d6e66..9400084692 100644 --- a/python/packages/core/tests/workflow/test_request_info_event_rehydrate.py +++ b/python/packages/core/tests/workflow/test_request_info_event_rehydrate.py @@ -12,7 +12,7 @@ WorkflowRunState, ) from agent_framework._workflows._checkpoint_encoding import ( - _PICKLE_MARKER, + _PICKLE_MARKER, # type: ignore encode_checkpoint_value, ) from agent_framework._workflows._events import WorkflowEvent diff --git a/python/packages/devui/agent_framework_devui/_conversations.py b/python/packages/devui/agent_framework_devui/_conversations.py index 2778c2da78..f0e91e0d87 100644 --- a/python/packages/devui/agent_framework_devui/_conversations.py +++ b/python/packages/devui/agent_framework_devui/_conversations.py @@ -14,7 +14,7 @@ from typing import Any, Literal, cast from agent_framework import AgentSession, Message -from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage +from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage, WorkflowCheckpoint from openai.types.conversations import Conversation, ConversationDeletedResource from openai.types.conversations.conversation_item import ConversationItem from openai.types.conversations.message import Message as OpenAIMessage @@ -480,7 +480,7 @@ async def list_items( checkpoint_storage = conv_data.get("checkpoint_storage") if checkpoint_storage: # Get all checkpoints for this conversation - checkpoints = await checkpoint_storage.list_checkpoints() + checkpoints = self._list_all_checkpoints(checkpoint_storage) for checkpoint in checkpoints: # Create a conversation item for each checkpoint with summary metadata # Full checkpoint state is NOT included here (too large for list view) @@ -495,7 +495,9 @@ async def list_items( "id": f"checkpoint_{checkpoint.checkpoint_id}", "type": "checkpoint", "checkpoint_id": checkpoint.checkpoint_id, - "workflow_id": checkpoint.workflow_id, + # Keep workflow_id for backward compatibility with existing UI payloads. + "workflow_id": checkpoint.workflow_name, + "workflow_name": checkpoint.workflow_name, "timestamp": checkpoint.timestamp, "status": "completed", "metadata": { @@ -506,6 +508,7 @@ async def list_items( "message_count": sum(len(msgs) for msgs in checkpoint.messages.values()), "size_bytes": checkpoint_size, "version": checkpoint.version, + "graph_signature_hash": checkpoint.graph_signature_hash, }, } items.append(cast(ConversationItem, checkpoint_item)) @@ -551,8 +554,9 @@ async def get_item(self, conversation_id: str, item_id: str) -> ConversationItem return None # Load full checkpoint from storage - checkpoint = await checkpoint_storage.load_checkpoint(checkpoint_id) - if not checkpoint: + try: + checkpoint = await checkpoint_storage.load(checkpoint_id) + except Exception: return None # Calculate size of checkpoint @@ -566,7 +570,9 @@ async def get_item(self, conversation_id: str, item_id: str) -> ConversationItem "id": item_id, "type": "checkpoint", "checkpoint_id": checkpoint.checkpoint_id, - "workflow_id": checkpoint.workflow_id, + # Keep workflow_id for backward compatibility with existing UI payloads. + "workflow_id": checkpoint.workflow_name, + "workflow_name": checkpoint.workflow_name, "timestamp": checkpoint.timestamp, "status": "completed", "metadata": { @@ -577,6 +583,7 @@ async def get_item(self, conversation_id: str, item_id: str) -> ConversationItem "message_count": sum(len(msgs) for msgs in checkpoint.messages.values()), "size_bytes": checkpoint_size, "version": checkpoint.version, + "graph_signature_hash": checkpoint.graph_signature_hash, # 🔥 FULL checkpoint state (lazy loaded) "full_checkpoint": checkpoint.to_dict(), }, @@ -631,8 +638,8 @@ async def list_conversations_by_metadata(self, metadata_filter: dict[str, str]) if conv_meta.get("type") == "workflow_session": checkpoint_storage = conv_data.get("checkpoint_storage") if checkpoint_storage: - checkpoints = await checkpoint_storage.list_checkpoints() - latest = checkpoints[0] if checkpoints else None + checkpoints = self._list_all_checkpoints(checkpoint_storage) + latest = max(checkpoints, key=lambda cp: cp.timestamp) if checkpoints else None conv_meta["checkpoint_summary"] = { "count": len(checkpoints), "latest_iteration": latest.iteration_count if latest else 0, @@ -654,6 +661,19 @@ async def list_conversations_by_metadata(self, metadata_filter: dict[str, str]) return results + @staticmethod + def _list_all_checkpoints(checkpoint_storage: Any) -> list[WorkflowCheckpoint]: + """Return all checkpoints from a conversation-scoped storage instance. + + DevUI uses one checkpoint storage per conversation. Core storage APIs now + require workflow_name filters, so we gather directly from in-memory storage + internals to provide conversation-wide listing for UI views. + """ + checkpoint_map = getattr(checkpoint_storage, "_checkpoints", None) + if isinstance(checkpoint_map, dict): + return list(cast(dict[str, WorkflowCheckpoint], checkpoint_map).values()) + return [] + class CheckpointConversationManager: """Manages checkpoint storage for workflow sessions - SESSION-SCOPED. diff --git a/python/packages/devui/tests/devui/test_checkpoints.py b/python/packages/devui/tests/devui/test_checkpoints.py index ffbbf93022..1e87187e41 100644 --- a/python/packages/devui/tests/devui/test_checkpoints.py +++ b/python/packages/devui/tests/devui/test_checkpoints.py @@ -104,17 +104,21 @@ async def test_conversation_scoped_checkpoint_save(self, checkpoint_manager, tes from agent_framework._workflows._checkpoint import WorkflowCheckpoint checkpoint = WorkflowCheckpoint( - checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, state={"test": "data"} + checkpoint_id=str(uuid.uuid4()), + workflow_name=test_workflow.name, + graph_signature_hash=test_workflow.graph_signature_hash, + messages={}, + state={"test": "data"}, ) # Get checkpoint storage for this conversation and save storage = checkpoint_manager.get_checkpoint_storage(conversation_id) - checkpoint_id = await storage.save_checkpoint(checkpoint) + checkpoint_id = await storage.save(checkpoint) assert checkpoint_id == checkpoint.checkpoint_id # Verify checkpoint stored in THIS conversation only - checkpoints = await storage.list_checkpoints() + checkpoints = await storage.list_checkpoints(workflow_name=test_workflow.name) assert len(checkpoints) == 1 assert checkpoints[0].checkpoint_id == checkpoint.checkpoint_id @@ -140,20 +144,21 @@ async def test_conversation_isolation(self, checkpoint_manager, test_workflow): checkpoint_a = WorkflowCheckpoint( checkpoint_id=str(uuid.uuid4()), - workflow_id=test_workflow.id, + workflow_name=test_workflow.name, + graph_signature_hash=test_workflow.graph_signature_hash, messages={}, state={"conversation": "A"}, ) storage_a = checkpoint_manager.get_checkpoint_storage(conv_a) - await storage_a.save_checkpoint(checkpoint_a) + await storage_a.save(checkpoint_a) # Verify conversation A has checkpoint - checkpoints_a = await storage_a.list_checkpoints() + checkpoints_a = await storage_a.list_checkpoints(workflow_name=test_workflow.name) assert len(checkpoints_a) == 1 # Verify conversation B has NO checkpoints (isolation) storage_b = checkpoint_manager.get_checkpoint_storage(conv_b) - checkpoints_b = await storage_b.list_checkpoints() + checkpoints_b = await storage_b.list_checkpoints(workflow_name=test_workflow.name) assert len(checkpoints_b) == 0 @pytest.mark.asyncio @@ -177,15 +182,16 @@ async def test_list_checkpoints_in_session(self, checkpoint_manager, test_workfl for i in range(3): checkpoint = WorkflowCheckpoint( checkpoint_id=str(uuid.uuid4()), - workflow_id=test_workflow.id, + workflow_name=test_workflow.name, + graph_signature_hash=test_workflow.graph_signature_hash, messages={}, state={"iteration": i}, ) - saved_id = await storage.save_checkpoint(checkpoint) + saved_id = await storage.save(checkpoint) checkpoint_ids.append(saved_id) # List checkpoints using the storage - checkpoints_list = await storage.list_checkpoints() + checkpoints_list = await storage.list_checkpoints(workflow_name=test_workflow.name) assert len(checkpoints_list) == 3 # Verify all checkpoint IDs are present @@ -213,11 +219,12 @@ async def test_checkpoints_appear_as_conversation_items(self, checkpoint_manager for i in range(2): checkpoint = WorkflowCheckpoint( checkpoint_id=f"checkpoint_{i}", - workflow_id=test_workflow.id, + workflow_name=test_workflow.name, + graph_signature_hash=test_workflow.graph_signature_hash, messages={}, state={"iteration": i}, ) - saved_id = await storage.save_checkpoint(checkpoint) + saved_id = await storage.save(checkpoint) checkpoint_ids.append(saved_id) # List conversation items - should include checkpoints @@ -233,7 +240,7 @@ async def test_checkpoints_appear_as_conversation_items(self, checkpoint_manager for item in checkpoint_items: assert item.get("type") == "checkpoint" assert item.get("checkpoint_id") in checkpoint_ids - assert item.get("workflow_id") == test_workflow.id + assert item.get("workflow_name") == test_workflow.name assert "timestamp" in item assert item.get("id").startswith("checkpoint_") # ID format: checkpoint_{checkpoint_id} @@ -255,21 +262,22 @@ async def test_load_checkpoint_from_session(self, checkpoint_manager, test_workf original_checkpoint = WorkflowCheckpoint( checkpoint_id=str(uuid.uuid4()), - workflow_id=test_workflow.id, + workflow_name=test_workflow.name, + graph_signature_hash=test_workflow.graph_signature_hash, messages={}, state={"test_key": "test_value"}, ) # Save to this session storage = checkpoint_manager.get_checkpoint_storage(conversation_id) - await storage.save_checkpoint(original_checkpoint) + await storage.save(original_checkpoint) # Load checkpoint from this session - loaded_checkpoint = await storage.load_checkpoint(original_checkpoint.checkpoint_id) + loaded_checkpoint = await storage.load(original_checkpoint.checkpoint_id) assert loaded_checkpoint is not None assert loaded_checkpoint.checkpoint_id == original_checkpoint.checkpoint_id - assert loaded_checkpoint.workflow_id == original_checkpoint.workflow_id + assert loaded_checkpoint.workflow_name == original_checkpoint.workflow_name assert loaded_checkpoint.state == {"test_key": "test_value"} @@ -296,24 +304,28 @@ async def test_checkpoint_storage_protocol(self, checkpoint_manager, test_workfl from agent_framework._workflows._checkpoint import WorkflowCheckpoint checkpoint = WorkflowCheckpoint( - checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, state={"test": "data"} + checkpoint_id=str(uuid.uuid4()), + workflow_name=test_workflow.name, + graph_signature_hash=test_workflow.graph_signature_hash, + messages={}, + state={"test": "data"}, ) - # Test save_checkpoint - checkpoint_id = await storage.save_checkpoint(checkpoint) + # Test save + checkpoint_id = await storage.save(checkpoint) assert checkpoint_id == checkpoint.checkpoint_id - # Test load_checkpoint - loaded = await storage.load_checkpoint(checkpoint_id) + # Test load + loaded = await storage.load(checkpoint_id) assert loaded is not None assert loaded.checkpoint_id == checkpoint_id # Test list_checkpoint_ids - ids = await storage.list_checkpoint_ids(workflow_id=test_workflow.id) + ids = await storage.list_checkpoint_ids(workflow_name=test_workflow.name) assert checkpoint_id in ids # Test list_checkpoints - checkpoints_list = await storage.list_checkpoints(workflow_id=test_workflow.id) + checkpoints_list = await storage.list_checkpoints(workflow_name=test_workflow.name) assert len(checkpoints_list) >= 1 assert any(cp.checkpoint_id == checkpoint_id for cp in checkpoints_list) @@ -346,12 +358,16 @@ async def test_manual_checkpoint_save_via_injected_storage(self, checkpoint_mana from agent_framework._workflows._checkpoint import WorkflowCheckpoint checkpoint = WorkflowCheckpoint( - checkpoint_id=str(uuid.uuid4()), workflow_id=test_workflow.id, messages={}, state={"injected": True} + checkpoint_id=str(uuid.uuid4()), + workflow_name=test_workflow.name, + graph_signature_hash=test_workflow.graph_signature_hash, + messages={}, + state={"injected": True}, ) - await checkpoint_storage.save_checkpoint(checkpoint) + await checkpoint_storage.save(checkpoint) # Verify checkpoint is accessible via storage (in this session) - storage_checkpoints = await checkpoint_storage.list_checkpoints() + storage_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name) assert len(storage_checkpoints) > 0 assert storage_checkpoints[0].checkpoint_id == checkpoint.checkpoint_id @@ -377,20 +393,21 @@ async def test_checkpoint_roundtrip_via_storage(self, checkpoint_manager, test_w checkpoint = WorkflowCheckpoint( checkpoint_id=str(uuid.uuid4()), - workflow_id=test_workflow.id, + workflow_name=test_workflow.name, + graph_signature_hash=test_workflow.graph_signature_hash, messages={}, state={"ready_to_resume": True}, ) - checkpoint_id = await checkpoint_storage.save_checkpoint(checkpoint) + checkpoint_id = await checkpoint_storage.save(checkpoint) # Verify checkpoint can be loaded for resume - loaded = await checkpoint_storage.load_checkpoint(checkpoint_id) + loaded = await checkpoint_storage.load(checkpoint_id) assert loaded is not None assert loaded.checkpoint_id == checkpoint_id assert loaded.state == {"ready_to_resume": True} # Verify checkpoint is accessible via storage (for UI to list checkpoints) - checkpoints = await checkpoint_storage.list_checkpoints() + checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name) assert len(checkpoints) > 0 assert checkpoints[0].checkpoint_id == checkpoint_id @@ -420,7 +437,7 @@ async def test_workflow_auto_saves_checkpoints_to_injected_storage(self, checkpo test_workflow._runner.context._checkpoint_storage = checkpoint_storage # Verify no checkpoints initially - checkpoints_before = await checkpoint_storage.list_checkpoints() + checkpoints_before = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name) assert len(checkpoints_before) == 0 # Run workflow until it reaches IDLE_WITH_PENDING_REQUESTS (after checkpoint is created) @@ -435,9 +452,9 @@ async def test_workflow_auto_saves_checkpoints_to_injected_storage(self, checkpo assert saw_request_event, "Test workflow should have emitted request_info event (type='request_info')" # Verify checkpoint was AUTOMATICALLY saved to our storage by the framework - checkpoints_after = await checkpoint_storage.list_checkpoints() + checkpoints_after = await checkpoint_storage.list_checkpoints(workflow_name=test_workflow.name) assert len(checkpoints_after) > 0, "Workflow should have auto-saved checkpoint at HIL pause" - # Verify checkpoint has correct workflow_id + # Verify checkpoint has correct workflow identity checkpoint = checkpoints_after[0] - assert checkpoint.workflow_id == test_workflow.id + assert checkpoint.workflow_name == test_workflow.name diff --git a/python/packages/devui/tests/devui/test_server.py b/python/packages/devui/tests/devui/test_server.py index b6215ddab5..8fbc127946 100644 --- a/python/packages/devui/tests/devui/test_server.py +++ b/python/packages/devui/tests/devui/test_server.py @@ -379,26 +379,27 @@ async def test_checkpoint_api_endpoints(test_entities_dir): storage = executor.checkpoint_manager.get_checkpoint_storage(conv_id) checkpoint = WorkflowCheckpoint( checkpoint_id="test_checkpoint_1", - workflow_id="test_workflow", + workflow_name="test_workflow", + graph_signature_hash="test_graph_hash", state={"key": "value"}, iteration_count=1, ) - await storage.save_checkpoint(checkpoint) + await storage.save(checkpoint) # Test list checkpoints endpoint - checkpoints = await storage.list_checkpoints() + checkpoints = await storage.list_checkpoints(workflow_name="test_workflow") assert len(checkpoints) == 1 assert checkpoints[0].checkpoint_id == "test_checkpoint_1" - assert checkpoints[0].workflow_id == "test_workflow" + assert checkpoints[0].workflow_name == "test_workflow" # Test delete checkpoint endpoint - deleted = await storage.delete_checkpoint("test_checkpoint_1") + deleted = await storage.delete("test_checkpoint_1") assert deleted is True # Verify checkpoint was deleted - remaining = await storage.list_checkpoints() + remaining = await storage.list_checkpoints(workflow_name="test_workflow") assert len(remaining) == 0 # Test delete non-existent checkpoint - deleted = await storage.delete_checkpoint("nonexistent") + deleted = await storage.delete("nonexistent") assert deleted is False diff --git a/python/packages/durabletask/tests/test_executors.py b/python/packages/durabletask/tests/test_executors.py index 36aa8d46b2..2333f08106 100644 --- a/python/packages/durabletask/tests/test_executors.py +++ b/python/packages/durabletask/tests/test_executors.py @@ -189,19 +189,29 @@ def test_executor_respects_custom_max_poll_retries(self, mock_client: Mock, samp # Verify get_entity was called 2 times (max_poll_retries) assert mock_client.get_entity.call_count == 2 - def test_executor_respects_custom_poll_interval(self, mock_client: Mock, sample_run_request: RunRequest) -> None: + def test_executor_respects_custom_poll_interval( + self, + mock_client: Mock, + sample_run_request: RunRequest, + monkeypatch: pytest.MonkeyPatch, + ) -> None: """Verify executor respects custom poll_interval_seconds during polling.""" # Create executor with very short interval executor = ClientAgentExecutor(mock_client, max_poll_retries=3, poll_interval_seconds=0.01) - # Measure time taken - start = time.time() + sleep_calls: list[float] = [] + + def fake_sleep(seconds: float) -> None: + sleep_calls.append(seconds) + + # Use deterministic assertions instead of wall-clock timing to avoid CI flakiness. + monkeypatch.setattr("agent_framework_durabletask._executors.time.sleep", fake_sleep) + result = executor.run_durable_agent("test_agent", sample_run_request) - elapsed = time.time() - start - # Should take roughly 3 * 0.01 = 0.03 seconds (plus overhead) - # Be generous with timing to avoid flakiness - assert elapsed < 0.2 # Should be quick with 0.01 interval + assert len(sleep_calls) == 3 + assert sleep_calls == pytest.approx([0.01, 0.01, 0.01]) + assert mock_client.get_entity.call_count == 3 assert isinstance(result, AgentResponse) diff --git a/python/samples/03-workflows/orchestrations/handoff_with_tool_approval_checkpoint_resume.py b/python/samples/03-workflows/orchestrations/handoff_with_tool_approval_checkpoint_resume.py index a774674c5d..7fb9daef13 100644 --- a/python/samples/03-workflows/orchestrations/handoff_with_tool_approval_checkpoint_resume.py +++ b/python/samples/03-workflows/orchestrations/handoff_with_tool_approval_checkpoint_resume.py @@ -8,9 +8,11 @@ from agent_framework import ( Agent, + AgentResponseUpdate, Content, FileCheckpointStorage, Workflow, + WorkflowEvent, tool, ) from agent_framework.azure import AzureOpenAIResponsesClient @@ -183,8 +185,16 @@ async def main() -> None: initial_request = "Hi, my order 12345 arrived damaged. I need a refund." # Phase 1: Initial run - workflow will pause when it needs user input - results = await workflow.run(message=initial_request) - request_events = results.get_request_info_events() + print("Running initial workflow...") + results = await workflow.run(message=initial_request, stream=True) + + # Iterate through streamed events and collect request_info events + request_events: list[WorkflowEvent] = [] + async for event in results: + event: WorkflowEvent + if event.type == "request_info": + request_events.append(event) + if not request_events: print("Workflow completed without needing user input") return @@ -224,8 +234,17 @@ async def main() -> None: raise RuntimeError("No checkpoints found.") checkpoint_id = checkpoint.checkpoint_id - results = await workflow.run(responses=responses, checkpoint_id=checkpoint_id) - request_events = results.get_request_info_events() + print("Resuming workflow from checkpoint...") + results = await workflow.run(responses=responses, checkpoint_id=checkpoint_id, stream=True) + + # Iterate through streamed events and collect request_info events + request_events: list[WorkflowEvent] = [] + async for event in results: + event: WorkflowEvent + if event.type == "request_info": + request_events.append(event) + elif event.type == "output" and isinstance(event.data, AgentResponseUpdate): + print(event.data.text, end="", flush=True) print("\n" + "=" * 60) print("DEMO COMPLETE") diff --git a/python/samples/demos/chatkit-integration/uploads/atc_3df183a3 b/python/samples/demos/chatkit-integration/uploads/atc_3df183a3 new file mode 100644 index 0000000000..f760b6553f Binary files /dev/null and b/python/samples/demos/chatkit-integration/uploads/atc_3df183a3 differ diff --git a/python/samples/demos/chatkit-integration/uploads/atc_967c57ef b/python/samples/demos/chatkit-integration/uploads/atc_967c57ef new file mode 100644 index 0000000000..f760b6553f Binary files /dev/null and b/python/samples/demos/chatkit-integration/uploads/atc_967c57ef differ diff --git a/python/samples/demos/chatkit-integration/uploads/atc_f4a18d5e b/python/samples/demos/chatkit-integration/uploads/atc_f4a18d5e new file mode 100644 index 0000000000..f760b6553f Binary files /dev/null and b/python/samples/demos/chatkit-integration/uploads/atc_f4a18d5e differ diff --git a/python/samples/demos/chatkit-integration/uploads/atc_fa77f9c0 b/python/samples/demos/chatkit-integration/uploads/atc_fa77f9c0 new file mode 100644 index 0000000000..f760b6553f Binary files /dev/null and b/python/samples/demos/chatkit-integration/uploads/atc_fa77f9c0 differ diff --git a/python/samples/getting_started/orchestrations/handoff/handoff_with_code_interpreter_file.py b/python/samples/getting_started/orchestrations/handoff/handoff_with_code_interpreter_file.py deleted file mode 100644 index 9801d64888..0000000000 --- a/python/samples/getting_started/orchestrations/handoff/handoff_with_code_interpreter_file.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -""" -Handoff Workflow with Code Interpreter File Generation Sample - -This sample demonstrates retrieving file IDs from code interpreter output -in a handoff workflow context. A triage agent routes to a code specialist -that generates a text file, and we verify the file_id is captured correctly -from the streaming workflow events. - -Verifies GitHub issue #2718: files generated by code interpreter in -HandoffBuilder workflows can be properly retrieved. - -Prerequisites: - - AZURE_AI_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint. - - `az login` (Azure CLI authentication) - - AZURE_AI_MODEL_DEPLOYMENT_NAME -""" - -import asyncio -import os -from collections.abc import AsyncIterable -from typing import cast - -from agent_framework import ( - AgentResponseUpdate, - Message, - WorkflowEvent, - WorkflowRunState, -) -from agent_framework.azure import AzureOpenAIResponsesClient -from agent_framework.orchestrations import HandoffAgentUserRequest, HandoffBuilder -from azure.identity import AzureCliCredential - - -async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]: - """Collect all events from an async stream.""" - return [event async for event in stream] - - -def _handle_events(events: list[WorkflowEvent]) -> tuple[list[WorkflowEvent[HandoffAgentUserRequest]], list[str]]: - """Process workflow events and extract file IDs and pending requests. - - Returns: - Tuple of (pending_requests, file_ids_found) - """ - - requests: list[WorkflowEvent[HandoffAgentUserRequest]] = [] - file_ids: list[str] = [] - - for event in events: - if event.type == "handoff_sent": - print(f"\n[Handoff from {event.data.source} to {event.data.target} initiated.]") - elif event.type == "status" and event.state in { - WorkflowRunState.IDLE, - WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, - }: - print(f"[status] {event.state}") - elif event.type == "request_info" and isinstance(event.data, HandoffAgentUserRequest): - requests.append(cast(WorkflowEvent[HandoffAgentUserRequest], event)) - elif event.type == "output": - data = event.data - if isinstance(data, AgentResponseUpdate): - for content in data.contents: - if content.type == "hosted_file": - file_ids.append(content.file_id) # type: ignore - print(f"[Found HostedFileContent: file_id={content.file_id}]") - elif content.type == "text" and content.annotations: - for annotation in content.annotations: - file_id = annotation["file_id"] # type: ignore - file_ids.append(file_id) - print(f"[Found file annotation: file_id={file_id}]") - elif isinstance(data, list): - conversation = cast(list[Message], data) - if isinstance(conversation, list): - print("\n=== Final Conversation Snapshot ===") - for message in conversation: - speaker = message.author_name or message.role - print(f"- {speaker}: {message.text or [content.type for content in message.contents]}") - print("===================================") - - return requests, file_ids - - -async def main() -> None: - """Run a simple handoff workflow with code interpreter file generation.""" - print("=== Handoff Workflow with Code Interpreter File Generation ===\n") - - client = AzureOpenAIResponsesClient( - project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"], - deployment_name=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"], - credential=AzureCliCredential(), - ) - - triage = client.as_agent( - name="triage_agent", - instructions=( - "You are a triage agent. Route code-related requests to the code_specialist. " - "When the user asks to create or generate files, hand off to code_specialist " - "by calling handoff_to_code_specialist." - ), - ) - - code_interpreter_tool = client.get_code_interpreter_tool() - - code_specialist = client.as_agent( - name="code_specialist", - instructions=( - "You are a Python code specialist. Use the code interpreter to execute Python code " - "and create files when requested. Always save files to /mnt/data/ directory." - ), - tools=[code_interpreter_tool], - ) - - workflow = ( - HandoffBuilder( - termination_condition=lambda conv: sum(1 for msg in conv if msg.role == "user") >= 2, - ) - .participants([triage, code_specialist]) - .with_start_agent(triage) - .build() - ) - - user_inputs = [ - "Please create a text file called hello.txt with 'Hello from handoff workflow!' inside it.", - "exit", - ] - input_index = 0 - all_file_ids: list[str] = [] - - print(f"User: {user_inputs[0]}") - events = await _drain(workflow.run(user_inputs[0], stream=True)) - requests, file_ids = _handle_events(events) - all_file_ids.extend(file_ids) - input_index += 1 - - while requests: - request = requests[0] - if input_index >= len(user_inputs): - break - user_input = user_inputs[input_index] - print(f"\nUser: {user_input}") - - responses = {request.request_id: HandoffAgentUserRequest.create_response(user_input)} - events = await _drain(workflow.run(stream=True, responses=responses)) - requests, file_ids = _handle_events(events) - all_file_ids.extend(file_ids) - input_index += 1 - - print("\n" + "=" * 50) - if all_file_ids: - print(f"SUCCESS: Found {len(all_file_ids)} file ID(s) in handoff workflow:") - for fid in all_file_ids: - print(f" - {fid}") - else: - print("WARNING: No file IDs captured from the handoff workflow.") - print("=" * 50) - - """ - Sample Output: - - User: Please create a text file called hello.txt with 'Hello from handoff workflow!' inside it. - [Found HostedFileContent: file_id=assistant-JT1sA...] - - === Conversation So Far === - - user: Please create a text file called hello.txt with 'Hello from handoff workflow!' inside it. - - triage_agent: I am handing off your request to create the text file "hello.txt" with the specified content to the code specialist. They will assist you shortly. - - code_specialist: The file "hello.txt" has been created with the content "Hello from handoff workflow!". You can download it using the link below: - - [hello.txt](sandbox:/mnt/data/hello.txt) - =========================== - - [status] IDLE_WITH_PENDING_REQUESTS - - User: exit - [status] IDLE - - ================================================== - SUCCESS: Found 1 file ID(s) in handoff workflow: - - assistant-JT1sA... - ================================================== - """ # noqa: E501 - - -if __name__ == "__main__": - asyncio.run(main())