Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@ async def mcp_tool_handler(context: str, client: df.DurableOrchestrationClient)
logger.debug("[MCP Tool Trigger] Received invocation for agent: %s", agent_name)
return await self._handle_mcp_tool_invocation(agent_name=agent_name, context=context, client=client)

_ = mcp_tool_handler
logger.debug("[AgentFunctionApp] Registered MCP tool trigger for agent: %s", agent_name)

async def _handle_mcp_tool_invocation(
Expand All @@ -587,15 +588,17 @@ async def _handle_mcp_tool_invocation(

# Parse JSON context string
try:
parsed_context = json.loads(context)
parsed_context: Any = json.loads(context)
except json.JSONDecodeError as e:
raise ValueError(f"Invalid MCP context format: {e}") from e

parsed_context = cast(Mapping[str, Any], parsed_context) if isinstance(parsed_context, dict) else {}

# Extract arguments from MCP context
arguments = parsed_context.get("arguments", {}) if isinstance(parsed_context, dict) else {}
arguments: dict[str, Any] = parsed_context.get("arguments", {})

# Validate required 'query' argument
query = arguments.get("query")
query: Any = arguments.get("query")
if not query or not isinstance(query, str):
raise ValueError("MCP Tool invocation is missing required 'query' argument of type string.")

Expand Down Expand Up @@ -951,10 +954,9 @@ def _extract_normalized_headers(self, req: func.HttpRequest) -> dict[str, str]:
"""Create a lowercase header mapping from the incoming request."""
headers: dict[str, str] = {}
raw_headers = req.headers
if isinstance(raw_headers, Mapping):
for key, value in raw_headers.items():
if value is not None:
headers[str(key).lower()] = str(value)
for key, value in cast(Mapping[str, str], raw_headers).items():
headers[key.lower()] = value

return headers

@staticmethod
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ class RunRequest:
thread_id: Optional thread ID for tracking
correlation_id: Optional correlation ID for tracking the response to this specific request
created_at: Optional timestamp when the request was created
orchestration_id: Optional ID of the orchestration that initiated this request
"""

message: str
Expand All @@ -297,6 +298,7 @@ class RunRequest:
thread_id: str | None = None
correlation_id: str | None = None
created_at: str | None = None
orchestration_id: str | None = None

def __init__(
self,
Expand All @@ -308,6 +310,7 @@ def __init__(
thread_id: str | None = None,
correlation_id: str | None = None,
created_at: str | None = None,
orchestration_id: str | None = None,
) -> None:
self.message = message
self.role = self.coerce_role(role)
Expand All @@ -317,6 +320,7 @@ def __init__(
self.thread_id = thread_id
self.correlation_id = correlation_id
self.created_at = created_at
self.orchestration_id = orchestration_id

@staticmethod
def coerce_role(value: Role | str | None) -> Role:
Expand Down Expand Up @@ -346,6 +350,8 @@ def to_dict(self) -> dict[str, Any]:
result["correlationId"] = self.correlation_id
if self.created_at:
result["created_at"] = self.created_at
if self.orchestration_id:
result["orchestrationId"] = self.orchestration_id

return result

Expand All @@ -361,4 +367,5 @@ def from_dict(cls, data: dict[str, Any]) -> RunRequest:
thread_id=data.get("thread_id"),
correlation_id=data.get("correlationId"),
created_at=data.get("created_at"),
orchestration_id=data.get("orchestrationId"),
)
Original file line number Diff line number Diff line change
Expand Up @@ -272,12 +272,14 @@ def my_orchestration(context):
)

# Prepare the request using RunRequest model
# Include the orchestration's instance_id so it can be stored in the agent's entity state
run_request = RunRequest(
message=message_str,
enable_tool_calls=enable_tool_calls,
correlation_id=correlation_id,
thread_id=session_id.key,
response_format=response_format,
orchestration_id=self.context.instance_id,
)

logger.debug("[DurableAIAgent] Calling entity %s with message: %s", entity_id, message_str[:100])
Expand Down
98 changes: 95 additions & 3 deletions python/packages/azurefunctions/tests/test_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_init_creates_entity(self) -> None:
assert entity.agent == mock_agent
assert len(entity.state.data.conversation_history) == 0
assert entity.state.data.extension_data is None
assert entity.state.schema_version == "1.0.0"
assert entity.state.schema_version == DurableAgentState.SCHEMA_VERSION

def test_init_stores_agent_reference(self) -> None:
"""Test that the agent reference is stored correctly."""
Expand Down Expand Up @@ -124,8 +124,7 @@ async def test_run_agent_executes_agent(self) -> None:
# Verify agent.run was called
mock_agent.run.assert_called_once()
_, kwargs = mock_agent.run.call_args
sent_messages = kwargs.get("messages")
assert isinstance(sent_messages, list)
sent_messages: list[Any] = kwargs.get("messages")
assert len(sent_messages) == 1
sent_message = sent_messages[0]
assert isinstance(sent_message, ChatMessage)
Expand Down Expand Up @@ -910,5 +909,98 @@ async def test_entity_function_with_run_request_dict(self) -> None:
assert text_found, f"Response text not found in message: {message}"


class TestDurableAgentStateRequestOrchestrationId:
"""Test suite for DurableAgentStateRequest orchestration_id field."""

def test_request_with_orchestration_id(self) -> None:
"""Test creating a request with an orchestration_id."""
request = DurableAgentStateRequest(
correlation_id="corr-123",
created_at=datetime.now(),
messages=[
DurableAgentStateMessage(
role="user",
contents=[DurableAgentStateTextContent(text="test")],
)
],
orchestration_id="orch-456",
)

assert request.orchestration_id == "orch-456"

def test_request_to_dict_includes_orchestration_id(self) -> None:
"""Test that to_dict includes orchestrationId when set."""
request = DurableAgentStateRequest(
correlation_id="corr-123",
created_at=datetime.now(),
messages=[
DurableAgentStateMessage(
role="user",
contents=[DurableAgentStateTextContent(text="test")],
)
],
orchestration_id="orch-789",
)

data = request.to_dict()

assert "orchestrationId" in data
assert data["orchestrationId"] == "orch-789"

def test_request_to_dict_excludes_orchestration_id_when_none(self) -> None:
"""Test that to_dict excludes orchestrationId when not set."""
request = DurableAgentStateRequest(
correlation_id="corr-123",
created_at=datetime.now(),
messages=[
DurableAgentStateMessage(
role="user",
contents=[DurableAgentStateTextContent(text="test")],
)
],
)

data = request.to_dict()

assert "orchestrationId" not in data

def test_request_from_dict_with_orchestration_id(self) -> None:
"""Test from_dict correctly parses orchestrationId."""
data = {
"$type": "request",
"correlationId": "corr-123",
"createdAt": "2024-01-01T00:00:00Z",
"messages": [{"role": "user", "contents": [{"$type": "text", "text": "test"}]}],
"orchestrationId": "orch-from-dict",
}

request = DurableAgentStateRequest.from_dict(data)

assert request.orchestration_id == "orch-from-dict"

def test_request_from_run_request_with_orchestration_id(self) -> None:
"""Test from_run_request correctly transfers orchestration_id."""
run_request = RunRequest(
message="test message",
correlation_id="corr-run",
orchestration_id="orch-from-run-request",
)

durable_request = DurableAgentStateRequest.from_run_request(run_request)

assert durable_request.orchestration_id == "orch-from-run-request"

def test_request_from_run_request_without_orchestration_id(self) -> None:
"""Test from_run_request correctly handles missing orchestration_id."""
run_request = RunRequest(
message="test message",
correlation_id="corr-run",
)

durable_request = DurableAgentStateRequest.from_run_request(run_request)

assert durable_request.orchestration_id is None


if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
65 changes: 65 additions & 0 deletions python/packages/azurefunctions/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,71 @@ def test_round_trip_with_correlationId(self) -> None:
assert restored.correlation_id == original.correlation_id
assert restored.thread_id == original.thread_id

def test_init_with_orchestration_id(self) -> None:
"""Test RunRequest initialization with orchestration_id."""
request = RunRequest(
message="Test message",
thread_id="thread-orch-init",
orchestration_id="orch-123",
)

assert request.message == "Test message"
assert request.orchestration_id == "orch-123"

def test_to_dict_with_orchestration_id(self) -> None:
"""Test to_dict includes orchestrationId."""
request = RunRequest(
message="Test",
thread_id="thread-orch-to-dict",
orchestration_id="orch-456",
)
data = request.to_dict()

assert data["message"] == "Test"
assert data["orchestrationId"] == "orch-456"

def test_to_dict_excludes_orchestration_id_when_none(self) -> None:
"""Test to_dict excludes orchestrationId when not set."""
request = RunRequest(
message="Test",
thread_id="thread-orch-none",
)
data = request.to_dict()

assert "orchestrationId" not in data

def test_from_dict_with_orchestration_id(self) -> None:
"""Test from_dict with orchestrationId."""
data = {
"message": "Test",
"orchestrationId": "orch-789",
"thread_id": "thread-orch-from-dict",
}
request = RunRequest.from_dict(data)

assert request.message == "Test"
assert request.orchestration_id == "orch-789"
assert request.thread_id == "thread-orch-from-dict"

def test_round_trip_with_orchestration_id(self) -> None:
"""Test round-trip to_dict and from_dict with orchestration_id."""
original = RunRequest(
message="Test message",
thread_id="thread-123",
role=Role.SYSTEM,
correlation_id="corr-123",
orchestration_id="orch-123",
)

data = original.to_dict()
restored = RunRequest.from_dict(data)

assert restored.message == original.message
assert restored.role == original.role
assert restored.correlation_id == original.correlation_id
assert restored.orchestration_id == original.orchestration_id
assert restored.thread_id == original.thread_id


class TestModelIntegration:
"""Test suite for integration between models."""
Expand Down
22 changes: 22 additions & 0 deletions python/packages/azurefunctions/tests/test_orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,28 @@ def test_run_creates_entity_call(self) -> None:
assert request["correlationId"] == "correlation-guid"
assert "thread_id" in request
assert request["thread_id"] == "thread-guid"
# Verify orchestration ID is set from context.instance_id
assert "orchestrationId" in request
assert request["orchestrationId"] == "test-instance-001"

def test_run_sets_orchestration_id(self) -> None:
"""Test that run() sets the orchestration_id from context.instance_id."""
mock_context = Mock()
mock_context.instance_id = "my-orchestration-123"
mock_context.new_uuid = Mock(side_effect=["thread-guid", "correlation-guid"])

entity_task = _create_entity_task()
mock_context.call_entity = Mock(return_value=entity_task)

agent = DurableAIAgent(mock_context, "TestAgent")
thread = agent.get_new_thread()

agent.run(messages="Test", thread=thread)

call_args = mock_context.call_entity.call_args
request = call_args[0][2]

assert request["orchestrationId"] == "my-orchestration-123"

def test_run_without_thread(self) -> None:
"""Test that run() works without explicit thread (creates unique session key)."""
Expand Down
2 changes: 1 addition & 1 deletion python/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading