Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion haystack/components/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ def __init__(
List variables that must be provided as input to user_prompt or system_prompt.
If a variable listed as required is not provided, an exception is raised.
If set to `"*"`, all variables found in the prompts are required. Optional.
You can also include `"messages"` to make the `messages` run-time input required.
:param exit_conditions: List of conditions that will cause the agent to return.
Can include "text" if the agent should return when it generates a message without tool calls,
or tool names that will cause the agent to return once the tool was executed. Defaults to ["text"].
Expand Down Expand Up @@ -363,8 +364,12 @@ def _register_prompt_variables(self) -> None:
"""
required_variables = self.required_variables

# Check whether required_variables targets any prompt template variables (i.e. not just "messages")
has_prompt_required_vars = required_variables == "*" or (
isinstance(required_variables, list) and any(v != "messages" for v in required_variables)
)
if (
required_variables is not None
has_prompt_required_vars
and self._system_chat_prompt_builder is None
and self._user_chat_prompt_builder is None
):
Expand Down Expand Up @@ -406,6 +411,17 @@ def _register_prompt_variables(self) -> None:
else:
component.set_input_type(self, name=var_name, type=Any, default=None)

def __post_init__(self) -> None:
"""
Called by ComponentMeta after input/output sockets are parsed from the run method signature.

Used to retroactively make the ``messages`` input required when ``"messages"`` is listed
in ``required_variables``, which cannot be done inside ``__init__`` because the socket
hasn't been parsed yet at that point.
"""
if isinstance(self.required_variables, list) and "messages" in self.required_variables:
component.set_input_type(self, name="messages", type=list[ChatMessage])
Comment on lines +414 to +423
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@julian-risch a __post_init__ was needed to do this because we are unable to run set_input_type on a param in the run method that is positional and already has a type.


def warm_up(self) -> None:
"""
Warm up the Agent.
Expand Down
8 changes: 7 additions & 1 deletion haystack/core/component/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ def inner(method: Callable[..., Any], sockets: Sockets) -> inspect.Signature:
existing_socket = sockets.get(param_name)
if existing_socket is not None and existing_socket != new_socket:
raise ComponentError(
"set_input_types()/set_input_type() cannot override the parameters of the 'run' method"
"set_input_types()/set_input_type() cannot override the parameters of the 'run' method.\n"
f"Conflict found for parameter '{param_name}': {existing_socket} vs {new_socket}"
Comment on lines 263 to +265
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@julian-risch this is what throws an error if the set_input_type is done in the init method.

)

sockets[param_name] = new_socket
Expand Down Expand Up @@ -322,6 +323,11 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any:
ComponentMeta._parse_and_set_input_sockets(cls, instance)
ComponentMeta._parse_and_set_output_sockets(instance)

# Call __post_init__ if defined, allowing components to adjust sockets after
# the run method signature has been parsed (e.g. to make an existing socket required).
if callable(getattr(instance, "__post_init__", None)):
instance.__post_init__()
Comment on lines +326 to +329
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@julian-risch this is what I'm most concerned about since one could argue it's a pretty big feature to add support for __post_init__ to initializing components. It could be fine/a good idea to add, but I'd appreciate your thoughts.


# Since a Component can't be used in multiple Pipelines at the same time
# we need to know if it's already owned by a Pipeline when adding it to one.
# We use this flag to check that.
Expand Down
14 changes: 14 additions & 0 deletions test/components/agents/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1392,6 +1392,20 @@ def test_register_prompt_variables_warning_when_no_prompt_and_required_variables
make_agent(required_variables=["name"])
assert "The parameter required_variables is provided but neither" in caplog.text

def test_register_prompt_variables_messages_required(self, make_agent):
agent = make_agent(required_variables=["messages"])
messages_socket = agent.__haystack_input__._sockets_dict["messages"]
assert messages_socket.is_mandatory

def test_register_prompt_variables_messages_optional_by_default(self, make_agent):
agent = make_agent()
messages_socket = agent.__haystack_input__._sockets_dict["messages"]
assert not messages_socket.is_mandatory

def test_register_prompt_variables_no_warning_when_only_messages_required(self, make_agent, caplog):
make_agent(required_variables=["messages"])
assert "The parameter required_variables is provided but neither" not in caplog.text

def test_register_prompt_variables_set_all_variables_as_required(self, make_agent):
agent = make_agent(user_prompt=_user_msg("Question: {{question}}"), required_variables="*")
assert agent._user_chat_prompt_builder.required_variables == "*"
Expand Down
Loading