diff --git a/src/uipath_langchain/agent/react/agent.py b/src/uipath_langchain/agent/react/agent.py index 0756a5a4..1bcb6428 100644 --- a/src/uipath_langchain/agent/react/agent.py +++ b/src/uipath_langchain/agent/react/agent.py @@ -157,6 +157,9 @@ def create_agent( target_node_names, ) + if config.is_conversational: + target_node_names.append(AgentGraphNode.AGENT) + for tool_name in tool_node_names: builder.add_conditional_edges(tool_name, route_agent, target_node_names) builder.add_edge(AgentGraphNode.TERMINATE, END) diff --git a/src/uipath_langchain/runtime/factory.py b/src/uipath_langchain/runtime/factory.py index 88a85bda..51a98953 100644 --- a/src/uipath_langchain/runtime/factory.py +++ b/src/uipath_langchain/runtime/factory.py @@ -264,16 +264,17 @@ async def _create_runtime_instance( Returns: Configured runtime instance """ + memory = await self._get_memory() + storage = SqliteResumableStorage(memory) + trigger_manager = UiPathResumeTriggerHandler() + base_runtime = UiPathLangGraphRuntime( graph=compiled_graph, runtime_id=runtime_id, entrypoint=entrypoint, + storage=storage, ) - memory = await self._get_memory() - storage = SqliteResumableStorage(memory) - trigger_manager = UiPathResumeTriggerHandler() - return UiPathResumableRuntime( delegate=base_runtime, storage=storage, diff --git a/src/uipath_langchain/runtime/messages.py b/src/uipath_langchain/runtime/messages.py index 6a19d8f8..1ab4e014 100644 --- a/src/uipath_langchain/runtime/messages.py +++ b/src/uipath_langchain/runtime/messages.py @@ -1,15 +1,15 @@ +import asyncio import json import logging from datetime import datetime, timezone from typing import Any, cast -from uuid import uuid4 from langchain_core.messages import ( AIMessageChunk, BaseMessage, HumanMessage, TextContentBlock, - ToolCallChunk, + ToolCall, ToolMessage, ) from pydantic import ValidationError @@ -27,9 +27,13 @@ UiPathConversationToolCallStartEvent, UiPathInlineValue, ) +from uipath.runtime import UiPathRuntimeStorageProtocol logger = logging.getLogger(__name__) +STORAGE_NAMESPACE_EVENT_MAPPER = "chat-event-mapper" +STORAGE_KEY_TOOL_CALL_ID_TO_MESSAGE_ID_MAP = "tool_call_map" + class UiPathChatMessagesMapper: """Stateful mapper that converts LangChain messages to UiPath message events. @@ -38,11 +42,13 @@ class UiPathChatMessagesMapper: - The AI message ID associated with each tool call for proper correlation with ToolMessage """ - def __init__(self): + def __init__(self, runtime_id: str, storage: UiPathRuntimeStorageProtocol | None): """Initialize the mapper with empty state.""" - self.tool_call_to_ai_message: dict[str, str] = {} + self.runtime_id = runtime_id + self.storage = storage self.current_message: AIMessageChunk self.seen_message_ids: set[str] = set() + self._storage_lock = asyncio.Lock() def _extract_text(self, content: Any) -> str: """Normalize LangGraph message.content to plain text.""" @@ -139,7 +145,7 @@ def _map_messages_internal( return human_messages - def map_event( + async def map_event( self, message: BaseMessage, ) -> list[UiPathConversationMessageEvent] | None: @@ -151,164 +157,273 @@ def map_event( Returns: A UiPathConversationMessageEvent if the message should be emitted, None otherwise. """ - # Format timestamp as ISO 8601 UTC with milliseconds: 2025-01-04T10:30:00.123Z - timestamp = ( + # --- Streaming AIMessageChunk --- + if isinstance(message, AIMessageChunk): + return await self.map_ai_message_chunk_to_events(message) + + # --- ToolMessage --- + if isinstance(message, ToolMessage): + return await self.map_tool_message_to_events(message) + + # Don't send events for system or user messages. Agent messages are handled above. + return None + + def get_timestamp(self): + """Format current time as ISO 8601 UTC with milliseconds: 2025-01-04T10:30:00.123Z""" + return ( datetime.now(timezone.utc) .isoformat(timespec="milliseconds") .replace("+00:00", "Z") ) - # --- Streaming AIMessageChunk --- - if isinstance(message, AIMessageChunk): - if message.id is None: - return None - - msg_event = UiPathConversationMessageEvent( - message_id=message.id, - ) + def get_content_part_id(self, message_id: str) -> str: + return f"chunk-{message_id}-0" - # Check if this is the last chunk by examining chunk_position - if message.chunk_position == "last": - events: list[UiPathConversationMessageEvent] = [] - - # Loop through all content_blocks in current_message and create toolCallStart events for each tool_call_chunk - if self.current_message and self.current_message.content_blocks: - for block in self.current_message.content_blocks: - if block.get("type") == "tool_call_chunk": - tool_chunk_block = cast(ToolCallChunk, block) - tool_call_id = tool_chunk_block.get("id") - tool_name = tool_chunk_block.get("name") - tool_args = tool_chunk_block.get("args") - - if tool_call_id: - tool_event = UiPathConversationMessageEvent( - message_id=message.id, - tool_call=UiPathConversationToolCallEvent( - tool_call_id=tool_call_id, - start=UiPathConversationToolCallStartEvent( - tool_name=tool_name, - timestamp=timestamp, - input=UiPathInlineValue(inline=tool_args), - ), - ), - ) - events.append(tool_event) - - # Create the final event for the message - msg_event.end = UiPathConversationMessageEndEvent(timestamp=timestamp) - msg_event.content_part = UiPathConversationContentPartEvent( - content_part_id=f"chunk-{message.id}-0", - end=UiPathConversationContentPartEndEvent(), - ) - events.append(msg_event) + async def map_ai_message_chunk_to_events( + self, message: AIMessageChunk + ) -> list[UiPathConversationMessageEvent]: + if message.id is None: # Should we throw instead? + return [] - return events + events: list[UiPathConversationMessageEvent] = [] + + # For every new message_id, start a new message + if message.id not in self.seen_message_ids: + self.current_message = message + self.seen_message_ids.add(message.id) + events.append(self.map_to_message_start_event(message.id)) + + if message.content_blocks: + # Generate events for each chunk + for block in message.content_blocks: + block_type = block.get("type") + match block_type: + case "text": + events.append( + self.map_chunk_to_content_part_chunk_event( + message.id, cast(TextContentBlock, block) + ) + ) + case "tool_call_chunk": + # Accumulate the message chunk + self.current_message = self.current_message + message - # For every new message_id, start a new message - if message.id not in self.seen_message_ids: - self.seen_message_ids.add(message.id) - self.current_message = message - msg_event.start = UiPathConversationMessageStartEvent( - role="assistant", timestamp=timestamp - ) - msg_event.content_part = UiPathConversationContentPartEvent( - content_part_id=f"chunk-{message.id}-0", - start=UiPathConversationContentPartStartEvent( - mime_type="text/plain" - ), + elif isinstance(message.content, str) and message.content: + # Fallback: raw string content on the chunk (rare when using content_blocks) + events.append( + self.map_content_to_content_part_chunk_event( + message.id, message.content ) + ) - elif message.content_blocks: - for block in message.content_blocks: - block_type = block.get("type") + # Check if this is the last chunk by examining chunk_position, send end message event only if there are no pending tool calls + if message.chunk_position == "last": + if ( + self.current_message.tool_calls is not None + and len(self.current_message.tool_calls) > 0 + ): + events.extend( + await self.map_current_message_to_start_tool_call_events() + ) + else: + events.append(self.map_to_message_end_event(message.id)) + + return events + + async def map_current_message_to_start_tool_call_events(self): + events: list[UiPathConversationMessageEvent] = [] + if ( + self.current_message + and self.current_message.id is not None + and self.current_message.tool_calls + ): + async with self._storage_lock: + if self.storage is not None: + tool_call_id_to_message_id_map: dict[ + str, str + ] = await self.storage.get_value( + self.runtime_id, + STORAGE_NAMESPACE_EVENT_MAPPER, + STORAGE_KEY_TOOL_CALL_ID_TO_MESSAGE_ID_MAP, + ) - if block_type == "text": - text_block = cast(TextContentBlock, block) - text = text_block["text"] + if tool_call_id_to_message_id_map is None: + tool_call_id_to_message_id_map = {} + else: + tool_call_id_to_message_id_map = {} - msg_event.content_part = UiPathConversationContentPartEvent( - content_part_id=f"chunk-{message.id}-0", - chunk=UiPathConversationContentPartChunkEvent( - data=text, - ), + for tool_call in self.current_message.tool_calls: + tool_call_id = tool_call["id"] + if tool_call_id is not None: + tool_call_id_to_message_id_map[tool_call_id] = ( + self.current_message.id + ) + events.append( + self.map_tool_call_to_tool_call_start_event( + self.current_message.id, tool_call + ) ) - elif block_type == "tool_call_chunk": - tool_chunk_block = cast(ToolCallChunk, block) + if self.storage is not None: + await self.storage.set_value( + self.runtime_id, + STORAGE_NAMESPACE_EVENT_MAPPER, + STORAGE_KEY_TOOL_CALL_ID_TO_MESSAGE_ID_MAP, + tool_call_id_to_message_id_map, + ) - tool_call_id = tool_chunk_block.get("id") - if tool_call_id: - # Track tool_call_id -> ai_message_id mapping - self.tool_call_to_ai_message[tool_call_id] = message.id + return events - # Accumulate the message chunk - self.current_message = self.current_message + message - continue + async def map_tool_message_to_events( + self, message: ToolMessage + ) -> list[UiPathConversationMessageEvent]: + # Look up the AI message ID using the tool_call_id + message_id, is_last_tool_call = await self.get_message_id_for_tool_call( + message.tool_call_id + ) + if message_id is None: + logger.warning( + f"Tool message {message.tool_call_id} has no associated AI message ID. Skipping." + ) + return [] - # Fallback: raw string content on the chunk (rare when using content_blocks) - elif isinstance(message.content, str) and message.content: - msg_event.content_part = UiPathConversationContentPartEvent( - content_part_id=f"content-{message.id}", - chunk=UiPathConversationContentPartChunkEvent( - data=message.content, + content_value: Any = message.content + if isinstance(content_value, str): + try: + content_value = json.loads(content_value) + except (json.JSONDecodeError, TypeError): + # Keep as string if not valid JSON + pass + + events = [ + UiPathConversationMessageEvent( + message_id=message_id, + tool_call=UiPathConversationToolCallEvent( + tool_call_id=message.tool_call_id, + end=UiPathConversationToolCallEndEvent( + timestamp=self.get_timestamp(), + output=UiPathInlineValue(inline=content_value), ), - ) + ), + ) + ] - if ( - msg_event.start - or msg_event.content_part - or msg_event.tool_call - or msg_event.end - ): - return [msg_event] + if is_last_tool_call: + events.append(self.map_to_message_end_event(message_id)) - return None + return events - # --- ToolMessage --- - if isinstance(message, ToolMessage): - # Look up the AI message ID using the tool_call_id - result_message_id = ( - self.tool_call_to_ai_message.get(message.tool_call_id) - if message.tool_call_id - else None + async def get_message_id_for_tool_call( + self, tool_call_id: str + ) -> tuple[str | None, bool]: + if self.storage is None: + logger.error( + f"attempt to lookup tool call id {tool_call_id} when no storage provided" + ) + return None, False + + async with self._storage_lock: + tool_call_id_to_message_id_map: dict[ + str, str + ] = await self.storage.get_value( + self.runtime_id, + STORAGE_NAMESPACE_EVENT_MAPPER, + STORAGE_KEY_TOOL_CALL_ID_TO_MESSAGE_ID_MAP, ) - # If no AI message ID was found, we cannot properly associate this tool result - if not result_message_id: - logger.warning( - f"Tool message {message.tool_call_id} has no associated AI message ID. Skipping." + if tool_call_id_to_message_id_map is None: + logger.error( + f"attempt to lookup tool call id {tool_call_id} when no map present in storage" ) + return None, False - # Clean up the mapping after use - if ( - message.tool_call_id - and message.tool_call_id in self.tool_call_to_ai_message - ): - del self.tool_call_to_ai_message[message.tool_call_id] - - content_value: Any = message.content - if isinstance(content_value, str): - try: - content_value = json.loads(content_value) - except (json.JSONDecodeError, TypeError): - # Keep as string if not valid JSON - pass - - return [ - UiPathConversationMessageEvent( - message_id=result_message_id or str(uuid4()), - tool_call=UiPathConversationToolCallEvent( - tool_call_id=message.tool_call_id, - end=UiPathConversationToolCallEndEvent( - timestamp=timestamp, - output=UiPathInlineValue(inline=content_value), - ), - ), + message_id = tool_call_id_to_message_id_map.get(tool_call_id) + if message_id is None: + logger.error( + f"tool call to message map does not contain tool call id {tool_call_id}" ) - ] + return None, False - # Don't send events for system or user messages. Agent messages are handled above. - return [] + del tool_call_id_to_message_id_map[tool_call_id] + + await self.storage.set_value( + self.runtime_id, + STORAGE_NAMESPACE_EVENT_MAPPER, + STORAGE_KEY_TOOL_CALL_ID_TO_MESSAGE_ID_MAP, + tool_call_id_to_message_id_map, + ) + + is_last = message_id not in tool_call_id_to_message_id_map.values() + + return message_id, is_last + + def map_tool_call_to_tool_call_start_event( + self, message_id: str, tool_call: ToolCall + ) -> UiPathConversationMessageEvent: + return UiPathConversationMessageEvent( + message_id=message_id, + tool_call=UiPathConversationToolCallEvent( + tool_call_id=tool_call["id"], + start=UiPathConversationToolCallStartEvent( + tool_name=tool_call["name"], + timestamp=self.get_timestamp(), + input=UiPathInlineValue(inline=tool_call["args"]), + ), + ), + ) + + def map_chunk_to_content_part_chunk_event( + self, message_id: str, block: TextContentBlock + ) -> UiPathConversationMessageEvent: + text = block["text"] + return UiPathConversationMessageEvent( + message_id=message_id, + content_part=UiPathConversationContentPartEvent( + content_part_id=self.get_content_part_id(message_id), + chunk=UiPathConversationContentPartChunkEvent( + data=text, + ), + ), + ) + + def map_content_to_content_part_chunk_event( + self, message_id: str, content: str + ) -> UiPathConversationMessageEvent: + return UiPathConversationMessageEvent( + message_id=message_id, + content_part=UiPathConversationContentPartEvent( + content_part_id=self.get_content_part_id(message_id), + chunk=UiPathConversationContentPartChunkEvent( + data=content, + ), + ), + ) + + def map_to_message_start_event( + self, message_id: str + ) -> UiPathConversationMessageEvent: + return UiPathConversationMessageEvent( + message_id=message_id, + start=UiPathConversationMessageStartEvent( + role="assistant", timestamp=self.get_timestamp() + ), + content_part=UiPathConversationContentPartEvent( + content_part_id=self.get_content_part_id(message_id), + start=UiPathConversationContentPartStartEvent(mime_type="text/plain"), + ), + ) + + def map_to_message_end_event( + self, message_id: str + ) -> UiPathConversationMessageEvent: + return UiPathConversationMessageEvent( + message_id=message_id, + end=UiPathConversationMessageEndEvent(), + content_part=UiPathConversationContentPartEvent( + content_part_id=self.get_content_part_id(message_id), + end=UiPathConversationContentPartEndEvent(), + ), + ) __all__ = ["UiPathChatMessagesMapper"] diff --git a/src/uipath_langchain/runtime/runtime.py b/src/uipath_langchain/runtime/runtime.py index f30ebf87..5e9ac903 100644 --- a/src/uipath_langchain/runtime/runtime.py +++ b/src/uipath_langchain/runtime/runtime.py @@ -13,6 +13,7 @@ UiPathExecuteOptions, UiPathRuntimeResult, UiPathRuntimeStatus, + UiPathRuntimeStorageProtocol, UiPathStreamOptions, ) from uipath.runtime.errors import UiPathErrorCategory, UiPathErrorCode @@ -43,6 +44,7 @@ def __init__( runtime_id: str | None = None, entrypoint: str | None = None, callbacks: list[BaseCallbackHandler] | None = None, + storage: UiPathRuntimeStorageProtocol | None = None, ): """ Initialize the runtime. @@ -56,7 +58,7 @@ def __init__( self.runtime_id: str = runtime_id or "default" self.entrypoint: str | None = entrypoint self.callbacks: list[BaseCallbackHandler] = callbacks or [] - self.chat = UiPathChatMessagesMapper() + self.chat = UiPathChatMessagesMapper(self.runtime_id, storage) self._middleware_node_names: set[str] = self._detect_middleware_nodes() async def execute( @@ -139,7 +141,7 @@ async def stream( if isinstance(data, tuple): message, _ = data try: - events = self.chat.map_event(message) + events = await self.chat.map_event(message) except Exception as e: logger.warning(f"Error mapping message event: {e}") events = None diff --git a/tests/agent/react/test_create_agent.py b/tests/agent/react/test_create_agent.py new file mode 100644 index 00000000..0ea8e325 --- /dev/null +++ b/tests/agent/react/test_create_agent.py @@ -0,0 +1,317 @@ +"""Tests for create_agent function in agent.py module.""" + +from typing import Any +from unittest.mock import MagicMock, Mock, patch + +import pytest +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import SystemMessage +from langchain_core.runnables.graph import Edge +from langchain_core.tools import BaseTool +from langgraph.graph import StateGraph + +from uipath_langchain.agent.react.agent import create_agent +from uipath_langchain.agent.react.init_node import create_init_node +from uipath_langchain.agent.react.router import create_route_agent +from uipath_langchain.agent.react.router_conversational import ( + create_route_agent_conversational, +) +from uipath_langchain.agent.react.terminate_node import create_terminate_node +from uipath_langchain.agent.react.tools.tools import create_flow_control_tools +from uipath_langchain.agent.react.types import ( + AgentGraphConfig, + AgentGraphNode, + AgentSettings, +) +from uipath_langchain.chat.types import APIFlavor, LLMProvider + + +def _make_mock_model() -> MagicMock: + """Create a mock chat model that satisfies UiPathPassthroughChatModel protocol.""" + model = MagicMock(spec=BaseChatModel) + model.llm_provider = LLMProvider.OPENAI + model.api_flavor = APIFlavor.OPENAI_RESPONSES + + # Protocol check: make isinstance(model, UiPathPassthroughChatModel) return True + from uipath_langchain.chat.types import UiPathPassthroughChatModel + + model.__class__ = type( + "MockUiPathModel", + (UiPathPassthroughChatModel,), + { + "llm_provider": property(lambda self: LLMProvider.OPENAI), + "api_flavor": property(lambda self: APIFlavor.OPENAI_RESPONSES), + }, + ) + return model + + +def _make_mock_tool(name: str = "test_tool") -> Mock: + """Create a mock BaseTool.""" + tool = Mock(spec=BaseTool) + tool.name = name + return tool + + +# Patch targets (all internal functions called by create_agent) +_PATCH_BASE = "uipath_langchain.agent.react.agent" +_PATCHES = { + "create_init_node": { + "target": f"{_PATCH_BASE}.create_init_node", + "wraps": create_init_node, + }, + "create_terminate_node": { + "target": f"{_PATCH_BASE}.create_terminate_node", + "wraps": create_terminate_node, + }, + "create_flow_control_tools": { + "target": f"{_PATCH_BASE}.create_flow_control_tools", + "wraps": create_flow_control_tools, + }, + "create_route_agent": { + "target": f"{_PATCH_BASE}.create_route_agent", + "wraps": create_route_agent, + }, + "create_route_agent_conversational": { + "target": f"{_PATCH_BASE}.create_route_agent_conversational", + "wraps": create_route_agent_conversational, + }, +} + + +def _patch(name, **overrides): + kwargs = {**_PATCHES[name], **overrides} + target = kwargs.pop("target") + return patch(target, **kwargs) + + +mock_tool_a_name = "mock_tool_a" +mock_tool_b_name = "mock_tool_b" + + +class TestCreateAgent: + """Test that create_agent wires up nodes/edges correctly with default config.""" + + @pytest.fixture + def mock_model(self): + return _make_mock_model() + + @pytest.fixture + def mock_tool_a(self): + return _make_mock_tool(mock_tool_a_name) + + @pytest.fixture + def mock_tool_b(self): + return _make_mock_tool(mock_tool_b_name) + + @pytest.fixture + def messages(self): + return [SystemMessage(content="You are a helpful assistant.")] + + @_patch("create_route_agent_conversational") + @_patch("create_route_agent") + @_patch("create_flow_control_tools") + @_patch("create_init_node") + @_patch("create_terminate_node") + def test_autonomous_agent_with_tools( + self, + mock_create_terminate_node, + mock_create_init_node, + mock_create_flow_control_tools, + mock_route_agent, + mock_route_agent_conversational, + mock_model, + mock_tool_a, + mock_tool_b, + messages, + ): + """Should return a StateGraph instance.""" + result: StateGraph[Any] = create_agent( + mock_model, [mock_tool_a, mock_tool_b], messages + ) + graph = result.compile().get_graph() + assert set(graph.nodes.keys()) == set( + [ + "__start__", + "__end__", + mock_tool_a_name, + mock_tool_b_name, + AgentGraphNode.TERMINATE, + AgentGraphNode.AGENT, + AgentGraphNode.INIT, + ] + ) + assert set(graph.edges) == set( + [ + Edge("__start__", AgentGraphNode.INIT), + Edge(AgentGraphNode.INIT, AgentGraphNode.AGENT), + Edge(AgentGraphNode.TERMINATE, "__end__"), + Edge(AgentGraphNode.AGENT, mock_tool_a_name, conditional=True), + Edge(AgentGraphNode.AGENT, mock_tool_b_name, conditional=True), + Edge(AgentGraphNode.AGENT, AgentGraphNode.TERMINATE, conditional=True), + Edge(AgentGraphNode.AGENT, AgentGraphNode.AGENT, conditional=True), + Edge(mock_tool_a_name, AgentGraphNode.TERMINATE, conditional=True), + Edge(mock_tool_a_name, AgentGraphNode.AGENT, conditional=True), + Edge(mock_tool_a_name, mock_tool_a_name, conditional=True), + Edge(mock_tool_a_name, mock_tool_b_name, conditional=True), + Edge(mock_tool_b_name, AgentGraphNode.TERMINATE, conditional=True), + Edge(mock_tool_b_name, AgentGraphNode.AGENT, conditional=True), + Edge(mock_tool_b_name, mock_tool_a_name, conditional=True), + Edge(mock_tool_b_name, mock_tool_b_name, conditional=True), + ] + ) + mock_route_agent.assert_called_once() + mock_route_agent_conversational.assert_not_called() + mock_create_flow_control_tools.assert_called_once() + mock_create_init_node.assert_called_once_with( + messages, + None, # input schema + False, # is_conversational + AgentSettings( + llm_provider=mock_model.llm_provider, + api_flavor=mock_model.api_flavor, + ), + ) + mock_create_terminate_node.assert_called_once_with( + None, # output schema + False, # is_conversational + ) + + @_patch("create_route_agent_conversational") + @_patch("create_route_agent") + @_patch("create_flow_control_tools") + def test_autonomous_agent_without_tools( + self, + mock_create_flow_control_tools, + mock_route_agent, + mock_route_agent_conversational, + mock_model, + messages, + ): + """Should return a StateGraph instance.""" + result: StateGraph[Any] = create_agent(mock_model, [], messages) + graph = result.compile().get_graph() + assert set(graph.nodes.keys()) == set( + [ + "__start__", + "__end__", + AgentGraphNode.TERMINATE, + AgentGraphNode.AGENT, + AgentGraphNode.INIT, + ] + ) + assert set(graph.edges) == set( + [ + Edge("__start__", AgentGraphNode.INIT), + Edge(AgentGraphNode.INIT, AgentGraphNode.AGENT), + Edge(AgentGraphNode.TERMINATE, "__end__"), + Edge(AgentGraphNode.AGENT, AgentGraphNode.TERMINATE, conditional=True), + Edge(AgentGraphNode.AGENT, AgentGraphNode.AGENT, conditional=True), + ] + ) + mock_route_agent.assert_called_once() + mock_route_agent_conversational.assert_not_called() + mock_create_flow_control_tools.assert_called_once() + + @_patch("create_route_agent_conversational") + @_patch("create_route_agent") + @_patch("create_flow_control_tools") + @_patch("create_init_node") + @_patch("create_terminate_node") + def test_conversational_agent_with_tools( + self, + mock_create_terminate_node, + mock_create_init_node, + mock_create_flow_control_tools, + mock_route_agent, + mock_route_agent_conversational, + mock_model, + mock_tool_a, + messages, + ): + """Should return a StateGraph instance.""" + result: StateGraph[Any] = create_agent( + mock_model, + [mock_tool_a], + messages, + config=AgentGraphConfig(is_conversational=True), + ) + graph = result.compile().get_graph() + assert set(graph.nodes.keys()) == set( + [ + "__start__", + "__end__", + mock_tool_a_name, + AgentGraphNode.TERMINATE, + AgentGraphNode.AGENT, + AgentGraphNode.INIT, + ] + ) + assert set(graph.edges) == set( + [ + Edge("__start__", AgentGraphNode.INIT), + Edge(AgentGraphNode.INIT, AgentGraphNode.AGENT), + Edge(AgentGraphNode.TERMINATE, "__end__"), + Edge(AgentGraphNode.AGENT, mock_tool_a_name, conditional=True), + Edge(AgentGraphNode.AGENT, AgentGraphNode.TERMINATE, conditional=True), + Edge(mock_tool_a_name, AgentGraphNode.TERMINATE, conditional=True), + Edge(mock_tool_a_name, mock_tool_a_name, conditional=True), + Edge(mock_tool_a_name, AgentGraphNode.AGENT, conditional=True), + ] + ) + mock_route_agent.assert_not_called() + mock_route_agent_conversational.assert_called_once() + mock_create_flow_control_tools.assert_not_called() + mock_create_init_node.assert_called_once_with( + messages, + None, # input schema + True, # is_conversational + AgentSettings( + llm_provider=mock_model.llm_provider, + api_flavor=mock_model.api_flavor, + ), + ) + mock_create_terminate_node.assert_called_once_with( + None, # output schema + True, # is_conversational + ) + + @_patch("create_route_agent_conversational") + @_patch("create_route_agent") + @_patch("create_flow_control_tools") + def test_conversational_agent_without_tools( + self, + mock_create_flow_control_tools, + mock_route_agent, + mock_route_agent_conversational, + mock_model, + messages, + ): + """Should return a StateGraph instance.""" + result: StateGraph[Any] = create_agent( + mock_model, + [], + messages, + config=AgentGraphConfig(is_conversational=True), + ) + graph = result.compile().get_graph() + assert set(graph.nodes.keys()) == set( + [ + "__start__", + "__end__", + AgentGraphNode.TERMINATE, + AgentGraphNode.AGENT, + AgentGraphNode.INIT, + ] + ) + assert set(graph.edges) == set( + [ + Edge("__start__", AgentGraphNode.INIT), + Edge(AgentGraphNode.INIT, AgentGraphNode.AGENT), + Edge(AgentGraphNode.TERMINATE, "__end__"), + Edge(AgentGraphNode.AGENT, AgentGraphNode.TERMINATE, conditional=True), + ] + ) + mock_route_agent.assert_not_called() + mock_route_agent_conversational.assert_called_once() + mock_create_flow_control_tools.assert_not_called() diff --git a/tests/agent/react/test_router_conversational.py b/tests/agent/react/test_router_conversational.py index 52868eae..16724c28 100644 --- a/tests/agent/react/test_router_conversational.py +++ b/tests/agent/react/test_router_conversational.py @@ -3,7 +3,7 @@ from typing import Any import pytest -from langchain_core.messages import AIMessage, HumanMessage +from langchain_core.messages import AIMessage, HumanMessage, ToolMessage from pydantic import BaseModel from uipath_langchain.agent.exceptions import AgentNodeRoutingException @@ -171,6 +171,44 @@ def test_routes_to_first_tool_in_sequence(self, route_function): assert result == "first_tool" + def test_routes_to_sequence_tool_in_sequence_when_first_tool_completed( + self, route_function + ): + """Should route to first tool in sequential execution.""" + ai_message = AIMessage( + content="Using tools in order", + tool_calls=[ + {"name": "first_tool", "args": {}, "id": "call_1"}, + {"name": "second_tool", "args": {}, "id": "call_2"}, + {"name": "third_tool", "args": {}, "id": "call_3"}, + ], + ) + tool_message = ToolMessage(tool_call_id="call_1") + state = MockAgentGraphState(messages=[ai_message, tool_message]) + + result = route_function(state) + + assert result == "second_tool" + + def test_routes_to_agent_when_tools_calls_completed(self, route_function): + """Should route to first tool in sequential execution.""" + ai_message = AIMessage( + content="Using tools in order", + tool_calls=[ + {"name": "first_tool", "args": {}, "id": "call_1"}, + {"name": "second_tool", "args": {}, "id": "call_2"}, + ], + ) + tool_message_1 = ToolMessage(tool_call_id="call_1") + tool_message_2 = ToolMessage(tool_call_id="call_2") + state = MockAgentGraphState( + messages=[ai_message, tool_message_1, tool_message_2] + ) + + result = route_function(state) + + assert result == AgentGraphNode.AGENT + class TestRouteAgentConversationalFactory: """Test cases for the factory function behavior.""" diff --git a/tests/runtime/chat_message_mapper.py b/tests/runtime/chat_message_mapper.py new file mode 100644 index 00000000..fa4d2963 --- /dev/null +++ b/tests/runtime/chat_message_mapper.py @@ -0,0 +1,598 @@ +"""Tests for UiPathChatMessagesMapper.""" + +from unittest.mock import AsyncMock, patch + +import pytest +from langchain_core.messages import ( + AIMessage, + AIMessageChunk, + HumanMessage, + SystemMessage, + ToolMessage, +) +from uipath.core.chat import ( + UiPathConversationContentPart, + UiPathConversationMessage, + UiPathInlineValue, +) + +from uipath_langchain.runtime.messages import UiPathChatMessagesMapper + +# Helper timestamp string for tests +TEST_TIMESTAMP = "2025-01-15T10:30:00Z" + + +def create_mock_storage(): + """Create a mock storage object for testing.""" + storage = AsyncMock() + storage.get_value = AsyncMock(return_value=None) + storage.set_value = AsyncMock() + return storage + + +class TestExtractText: + """Tests for the _extract_text method.""" + + def test_extract_text_from_string(self): + """Should return string content as-is.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + + result = mapper._extract_text("hello world") + + assert result == "hello world" + + def test_extract_text_from_list_with_text_parts(self): + """Should extract text from list of content parts.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + content = [ + {"type": "text", "text": "hello "}, + {"type": "text", "text": "world"}, + ] + + result = mapper._extract_text(content) + + assert result == "hello world" + + def test_extract_text_from_list_ignores_non_text_parts(self): + """Should ignore non-text parts in content list.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + content = [ + {"type": "text", "text": "hello"}, + {"type": "image", "url": "http://example.com/img.png"}, + {"type": "text", "text": " world"}, + ] + + result = mapper._extract_text(content) + + assert result == "hello world" + + def test_extract_text_from_empty_list(self): + """Should return empty string for empty list.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + + result = mapper._extract_text([]) + + assert result == "" + + def test_extract_text_from_none(self): + """Should return empty string for None.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + + result = mapper._extract_text(None) + + assert result == "" + + def test_extract_text_from_other_type(self): + """Should convert other types to string.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + + result = mapper._extract_text(123) + + assert result == "123" + + +class TestMapMessages: + """Tests for the map_messages method.""" + + def test_map_messages_raises_on_non_list(self): + """Should raise TypeError when messages is not a list.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + + with pytest.raises(TypeError, match="messages must be a list"): + mapper.map_messages("not a list") # type: ignore[arg-type] + + def test_map_messages_returns_empty_list_for_empty_input(self): + """Should return empty list for empty input.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + + result = mapper.map_messages([]) + + assert result == [] + + def test_map_messages_passes_through_langchain_messages(self): + """Should return LangChain messages as-is.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + messages = [ + HumanMessage(content="hello"), + AIMessage(content="hi there"), + ] + + result = mapper.map_messages(messages) + + assert result == messages + + def test_map_messages_converts_uipath_messages(self): + """Should convert UiPath messages to HumanMessages.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + uipath_msg = UiPathConversationMessage( + message_id="msg-1", + role="user", + created_at=TEST_TIMESTAMP, + updated_at=TEST_TIMESTAMP, + content_parts=[ + UiPathConversationContentPart( + content_part_id="part-1", + mime_type="text/plain", + data=UiPathInlineValue(inline="hello world"), + ) + ], + ) + + result = mapper.map_messages([uipath_msg]) + + assert len(result) == 1 + msg = result[0] + assert isinstance(msg, HumanMessage) + assert msg.content == "hello world" + assert msg.metadata["message_id"] == "msg-1" # type: ignore[attr-defined] + assert msg.metadata["content_part_id"] == "part-1" # type: ignore[attr-defined] + assert msg.metadata["mime_type"] == "text/plain" # type: ignore[attr-defined] + + def test_map_messages_converts_dict_messages(self): + """Should convert dict messages to HumanMessages.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + dict_msg = { + "message_id": "msg-1", + "role": "user", + "createdAt": "2025-01-15T10:30:00Z", + "updatedAt": "2025-01-15T10:30:00Z", + "content_parts": [ + { + "content_part_id": "part-1", + "mime_type": "text/plain", + "data": {"inline": "hello from dict"}, + } + ], + } + + result = mapper.map_messages([dict_msg]) + + assert len(result) == 1 + assert isinstance(result[0], HumanMessage) + assert result[0].content == "hello from dict" + + def test_map_messages_raises_on_mixed_uipath_types(self): + """Should raise TypeError for mixed UiPath message types.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + uipath_msg = UiPathConversationMessage( + message_id="msg-1", + role="user", + created_at=TEST_TIMESTAMP, + updated_at=TEST_TIMESTAMP, + ) + + with pytest.raises(TypeError, match="Mixed message types not supported"): + mapper.map_messages([uipath_msg, "not a uipath message"]) + + def test_map_messages_passthrough_unknown_types(self): + """Should pass through unknown types.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + unknown = [{"unknown_field": "value"}] + + result = mapper.map_messages(unknown) + + assert result == unknown + + def test_map_messages_handles_multiple_content_parts(self): + """Should create separate HumanMessages for each content part.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + uipath_msg = UiPathConversationMessage( + message_id="msg-1", + role="user", + created_at=TEST_TIMESTAMP, + updated_at=TEST_TIMESTAMP, + content_parts=[ + UiPathConversationContentPart( + content_part_id="part-1", + mime_type="text/plain", + data=UiPathInlineValue(inline="first part"), + ), + UiPathConversationContentPart( + content_part_id="part-2", + mime_type="text/plain", + data=UiPathInlineValue(inline="second part"), + ), + ], + ) + + result = mapper.map_messages([uipath_msg]) + + assert len(result) == 2 + msg0 = result[0] + msg1 = result[1] + assert isinstance(msg0, HumanMessage) + assert isinstance(msg1, HumanMessage) + assert msg0.content == "first part" + assert msg0.metadata["content_part_id"] == "part-1" # type: ignore[attr-defined] + assert msg1.content == "second part" + assert msg1.metadata["content_part_id"] == "part-2" # type: ignore[attr-defined] + + def test_map_messages_handles_message_without_content_parts(self): + """Should handle UiPath message without content parts.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + uipath_msg = UiPathConversationMessage( + message_id="msg-1", + role="user", + created_at=TEST_TIMESTAMP, + updated_at=TEST_TIMESTAMP, + ) + + result = mapper.map_messages([uipath_msg]) + + assert len(result) == 1 + msg = result[0] + assert isinstance(msg, HumanMessage) + assert msg.content == "" + assert msg.metadata["message_id"] == "msg-1" # type: ignore[attr-defined] + assert msg.metadata["role"] == "user" # type: ignore[attr-defined] + + +class TestMapEvent: + """Tests for the map_event method.""" + + @pytest.mark.asyncio + async def test_map_event_returns_empty_list_for_ai_chunk_without_id(self): + """Should return empty list for AIMessageChunk without id.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + chunk = AIMessageChunk(content="hello", id=None) + + result = await mapper.map_event(chunk) + + assert result == [] + + @pytest.mark.asyncio + async def test_map_event_starts_new_message_for_new_id(self): + """Should emit start event for new message id.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + chunk = AIMessageChunk(content="", id="msg-123") + + result = await mapper.map_event(chunk) + + assert result is not None + assert len(result) == 1 + event = result[0] + assert event.message_id == "msg-123" + assert event.start is not None + assert event.start.role == "assistant" + assert event.content_part is not None + assert event.content_part.start is not None + + @pytest.mark.asyncio + async def test_map_event_tracks_seen_message_ids(self): + """Should track seen message ids.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + chunk = AIMessageChunk(content="", id="msg-123") + + await mapper.map_event(chunk) + + assert "msg-123" in mapper.seen_message_ids + + @pytest.mark.asyncio + async def test_map_event_emits_text_chunk_for_subsequent_messages(self): + """Should emit text chunk event for subsequent messages with same id.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + # First chunk starts the message + first_chunk = AIMessageChunk(content="", id="msg-123") + await mapper.map_event(first_chunk) + + # Second chunk with text content + second_chunk = AIMessageChunk( + content="", + id="msg-123", + content_blocks=[{"type": "text", "text": "hello"}], + ) + result = await mapper.map_event(second_chunk) + + assert result is not None + assert len(result) == 1 + event = result[0] + assert event.content_part is not None + assert event.content_part.chunk is not None + assert event.content_part.chunk.data == "hello" + + @pytest.mark.asyncio + async def test_map_event_handles_raw_string_content(self): + """Should handle raw string content on chunk.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + # First chunk starts the message + first_chunk = AIMessageChunk(content="", id="msg-123") + await mapper.map_event(first_chunk) + + # Second chunk with string content + second_chunk = AIMessageChunk(content="raw content", id="msg-123") + result = await mapper.map_event(second_chunk) + + assert result is not None + assert len(result) == 1 + event = result[0] + assert event.content_part is not None + assert event.content_part.chunk is not None + assert event.content_part.chunk.data == "raw content" + + @pytest.mark.asyncio + async def test_map_event_tracks_tool_call_to_message_mapping_in_storage(self): + """Should track tool_call_id to ai_message_id mapping in storage.""" + storage = create_mock_storage() + storage.get_value.return_value = {} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + + # First chunk starts the message with tool_calls + first_chunk = AIMessageChunk( + content="", + id="msg-123", + tool_calls=[{"id": "tool-1", "name": "test_tool", "args": {}}], + ) + await mapper.map_event(first_chunk) + + # Last chunk triggers tool call start events + last_chunk = AIMessageChunk( + content="", + id="msg-123", + ) + object.__setattr__(last_chunk, "chunk_position", "last") + await mapper.map_event(last_chunk) + + # Verify storage was called with the tool call mapping + storage.set_value.assert_called() + call_args = storage.set_value.call_args + assert call_args[0][0] == "test-runtime" + assert call_args[0][1] == "chat-event-mapper" + assert call_args[0][2] == "tool_call_map" + assert "tool-1" in call_args[0][3] + assert call_args[0][3]["tool-1"] == "msg-123" + + @pytest.mark.asyncio + async def test_map_event_emits_end_event_for_last_chunk_without_tool_calls(self): + """Should emit end event for chunk with position 'last' when no tool calls.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + # First chunk starts the message + first_chunk = AIMessageChunk(content="", id="msg-123") + await mapper.map_event(first_chunk) + + # Last chunk with no tool calls + last_chunk = AIMessageChunk(content="", id="msg-123") + object.__setattr__(last_chunk, "chunk_position", "last") + + result = await mapper.map_event(last_chunk) + + assert result is not None + # Should have the end event + end_event = result[-1] + assert end_event.end is not None + assert end_event.content_part is not None + assert end_event.content_part.end is not None + + @pytest.mark.asyncio + async def test_map_event_emits_tool_call_start_events_on_last_chunk(self): + """Should emit tool call start events when chunk is last.""" + storage = create_mock_storage() + storage.get_value.return_value = {} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + + # First chunk starts the message with tool_calls + first_chunk = AIMessageChunk( + content="", + id="msg-123", + tool_calls=[ + {"id": "tool-1", "name": "test_tool", "args": {"arg": "value"}} + ], + ) + await mapper.map_event(first_chunk) + + # Last chunk triggers tool call start events + last_chunk = AIMessageChunk( + content="", + id="msg-123", + ) + object.__setattr__(last_chunk, "chunk_position", "last") + + result = await mapper.map_event(last_chunk) + + assert result is not None + # Find the tool call events - may have duplicates from accumulation + tool_events = [e for e in result if e.tool_call is not None] + assert len(tool_events) >= 1 + # Check that at least one has the expected tool call start + tool_start_events = [ + e for e in tool_events if e.tool_call and e.tool_call.start is not None + ] + assert len(tool_start_events) >= 1 + tool_event = tool_start_events[0] + assert tool_event.tool_call is not None + assert tool_event.tool_call.tool_call_id == "tool-1" + assert tool_event.tool_call.start is not None + assert tool_event.tool_call.start.tool_name == "test_tool" + + @pytest.mark.asyncio + async def test_map_event_handles_tool_message(self): + """Should convert ToolMessage to tool call end event.""" + storage = create_mock_storage() + # Pre-populate the tool call mapping in storage + storage.get_value.return_value = {"tool-1": "msg-123"} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + + tool_msg = ToolMessage( + content='{"result": "success"}', + tool_call_id="tool-1", + ) + + result = await mapper.map_event(tool_msg) + + assert result is not None + assert len(result) == 2 # tool call end event + message end event + event = result[0] + assert event.message_id == "msg-123" + assert event.tool_call is not None + assert event.tool_call.tool_call_id == "tool-1" + assert event.tool_call.end is not None + assert event.tool_call.end.output.inline == {"result": "success"} + + @pytest.mark.asyncio + async def test_map_event_cleans_up_tool_mapping_after_use(self): + """Should remove tool_call_id from storage mapping after processing ToolMessage.""" + storage = create_mock_storage() + storage.get_value.return_value = {"tool-1": "msg-123"} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + + tool_msg = ToolMessage(content="result", tool_call_id="tool-1") + await mapper.map_event(tool_msg) + + # Verify storage was updated without the tool-1 key + storage.set_value.assert_called() + call_args = storage.set_value.call_args + assert "tool-1" not in call_args[0][3] + + @pytest.mark.asyncio + async def test_map_event_handles_tool_message_without_mapping(self): + """Should handle ToolMessage when no mapping exists.""" + storage = create_mock_storage() + storage.get_value.return_value = {} # Empty mapping + mapper = UiPathChatMessagesMapper("test-runtime", storage) + tool_msg = ToolMessage(content="result", tool_call_id="unknown-tool") + + with patch("uipath_langchain.runtime.messages.logger") as mock_logger: + result = await mapper.map_event(tool_msg) + + mock_logger.error.assert_called_once() + assert result == [] + + @pytest.mark.asyncio + async def test_map_event_handles_tool_message_without_storage(self): + """Should handle ToolMessage when no storage is configured.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + tool_msg = ToolMessage(content="result", tool_call_id="tool-1") + + with patch("uipath_langchain.runtime.messages.logger") as mock_logger: + result = await mapper.map_event(tool_msg) + + mock_logger.error.assert_called_once() + assert result == [] + + @pytest.mark.asyncio + async def test_map_event_parses_json_tool_content(self): + """Should parse JSON content in ToolMessage.""" + storage = create_mock_storage() + storage.get_value.return_value = {"tool-1": "msg-123"} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + + tool_msg = ToolMessage( + content='{"key": "value", "number": 42}', + tool_call_id="tool-1", + ) + + result = await mapper.map_event(tool_msg) + + assert result is not None + event = result[0] + assert event.tool_call is not None + assert event.tool_call.end is not None + assert event.tool_call.end.output.inline == {"key": "value", "number": 42} + + @pytest.mark.asyncio + async def test_map_event_keeps_string_content_when_not_json(self): + """Should keep string content when not valid JSON.""" + storage = create_mock_storage() + storage.get_value.return_value = {"tool-1": "msg-123"} + mapper = UiPathChatMessagesMapper("test-runtime", storage) + + tool_msg = ToolMessage( + content="not json content", + tool_call_id="tool-1", + ) + + result = await mapper.map_event(tool_msg) + + assert result is not None + event = result[0] + assert event.tool_call is not None + assert event.tool_call.end is not None + assert event.tool_call.end.output.inline == "not json content" + + @pytest.mark.asyncio + async def test_map_event_returns_none_for_system_message(self): + """Should return None for SystemMessage.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + msg = SystemMessage(content="system prompt") + + result = await mapper.map_event(msg) + + assert result is None + + @pytest.mark.asyncio + async def test_map_event_returns_none_for_human_message(self): + """Should return None for HumanMessage.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + msg = HumanMessage(content="user input") + + result = await mapper.map_event(msg) + + assert result is None + + @pytest.mark.asyncio + async def test_map_event_returns_empty_list_when_no_content_to_emit(self): + """Should return empty list when chunk has no content to emit.""" + mapper = UiPathChatMessagesMapper("test-runtime", None) + # First chunk starts the message + first_chunk = AIMessageChunk(content="", id="msg-123") + await mapper.map_event(first_chunk) + + # Empty subsequent chunk with no content blocks and no string content + empty_chunk = AIMessageChunk(content="", id="msg-123") + + result = await mapper.map_event(empty_chunk) + + assert result == [] + + @pytest.mark.asyncio + async def test_map_event_emits_message_end_after_last_tool_result(self): + """Should emit message end event after the last tool result for a message.""" + storage = create_mock_storage() + # Two tool calls for the same message + storage.get_value.side_effect = [ + {"tool-1": "msg-123", "tool-2": "msg-123"}, # First lookup + {"tool-2": "msg-123"}, # After removing tool-1 + ] + mapper = UiPathChatMessagesMapper("test-runtime", storage) + + # First tool result - not the last one + tool_msg1 = ToolMessage(content="result1", tool_call_id="tool-1") + result1 = await mapper.map_event(tool_msg1) + + # Should have tool call end but NOT message end + assert result1 is not None + assert len(result1) == 1 # Only tool call end + assert result1[0].tool_call is not None + assert result1[0].end is None + + # Second tool result - the last one + storage.get_value.side_effect = [ + {"tool-2": "msg-123"}, # Lookup for tool-2 + ] + tool_msg2 = ToolMessage(content="result2", tool_call_id="tool-2") + result2 = await mapper.map_event(tool_msg2) + + # Should have both tool call end AND message end + assert result2 is not None + assert len(result2) == 2 # Tool call end + message end + assert result2[0].tool_call is not None + assert result2[1].end is not None