Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy

from haystack.dataclasses.breakpoints import AgentSnapshot, ToolBreakpoint
from haystack.utils import _deserialize_value_with_schema

Expand Down Expand Up @@ -31,7 +33,7 @@ def get_tool_calls_and_descriptions_from_snapshot(
tool_caused_break_point = break_point.tool_name

# Deserialize the tool invoker inputs from the snapshot
tool_invoker_inputs = _deserialize_value_with_schema(agent_snapshot.component_inputs["tool_invoker"])
tool_invoker_inputs = _deserialize_value_with_schema(deepcopy(agent_snapshot.component_inputs["tool_invoker"]))
tool_call_messages = tool_invoker_inputs["messages"]
state = tool_invoker_inputs["state"]
tool_name_to_tool = {t.name: t for t in tool_invoker_inputs["tools"]}
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ classifiers = [
dependencies = [
"haystack-ai",
"rich", # For pretty printing in the console used by human-in-the-loop utilities
"lazy-imports<1.2.0" # 1.2.0 requires Python 3.10+, see https://github.com/bachorp/lazy-imports/releases/tag/1.2.0
]

[project.urls]
Expand Down
49 changes: 46 additions & 3 deletions test/components/agents/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
# SPDX-License-Identifier: Apache-2.0

import copy
import os
from pathlib import Path
from typing import Any, Optional
Expand All @@ -12,13 +13,11 @@
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.core.errors import BreakpointException
from haystack.core.pipeline.breakpoint import load_pipeline_snapshot
from haystack.dataclasses import ChatMessage
from haystack.dataclasses import ChatMessage, ToolCall
from haystack.dataclasses.breakpoints import PipelineSnapshot
from haystack.tools import Tool, create_tool_from_function

from haystack_experimental.chat_message_stores.in_memory import InMemoryChatMessageStore
from haystack_experimental.components.retrievers import ChatMessageRetriever
from haystack_experimental.components.writers import ChatMessageWriter
from haystack_experimental.components.agents.agent import Agent
from haystack_experimental.components.agents.human_in_the_loop import (
AlwaysAskPolicy,
Expand All @@ -34,6 +33,8 @@
from haystack_experimental.components.agents.human_in_the_loop.breakpoint import (
get_tool_calls_and_descriptions_from_snapshot,
)
from haystack_experimental.components.retrievers import ChatMessageRetriever
from haystack_experimental.components.writers import ChatMessageWriter


@pytest.fixture
Expand All @@ -50,6 +51,19 @@ def run(self, messages: list[ChatMessage], tools: Any) -> dict[str, list[ChatMes
return {"replies": [ChatMessage.from_assistant("This is a mock response.")]}


@component
class MockChatGeneratorToolsResponse:
@component.output_types(replies=list[ChatMessage])
def run(self, messages: list[ChatMessage], tools: Any) -> dict[str, list[ChatMessage]]:
return {
"replies": [
ChatMessage.from_assistant(
tool_calls=[ToolCall(tool_name="addition_tool", arguments={"a": 2, "b": 3})]
)
]
}


@component
class MockAgent:
def __init__(self, system_prompt: Optional[str] = None):
Expand Down Expand Up @@ -257,6 +271,35 @@ def test_from_dict(self, tools, confirmation_strategies, monkeypatch):


class TestAgentConfirmationStrategy:
def test_get_tool_calls_and_descriptions_from_snapshot_no_mutation_of_snapshot(self, tools, tmp_path):
agent = Agent(
chat_generator=MockChatGeneratorToolsResponse(),
tools=tools,
confirmation_strategies={
"addition_tool": BreakpointConfirmationStrategy(snapshot_file_path=str(tmp_path)),
},
)
agent.warm_up()

# Run the agent to create a snapshot with a breakpoint
try:
agent.run([ChatMessage.from_user("What is 2+2?")])
except BreakpointException:
pass

# Load the latest snapshot from disk
loaded_snapshot = get_latest_snapshot(snapshot_file_path=str(tmp_path))

original_snapshot = copy.deepcopy(loaded_snapshot)

# Extract tool calls and descriptions
_ = get_tool_calls_and_descriptions_from_snapshot(
agent_snapshot=loaded_snapshot.agent_snapshot, breakpoint_tool_only=True
)

# Verify that the original snapshot has not been mutated
assert loaded_snapshot == original_snapshot

@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY"), reason="OPENAI_API_KEY not set")
@pytest.mark.integration
def test_run_blocking_confirmation_strategy_modify(self, tools):
Expand Down
Loading