Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,14 @@ WARP.md
**/tmpclaude*

# Azurite storage emulator files
*/__azurite_db_blob__.json
*/__azurite_db_blob_extent__.json
*/__azurite_db_queue__.json
*/__azurite_db_queue_extent__.json
*/__azurite_db_table__.json
*/__azurite_db_blob__.json*
*/__azurite_db_blob_extent__.json*
*/__azurite_db_queue__.json*
*/__azurite_db_queue_extent__.json*
*/__azurite_db_table__.json*
*/__blobstorage__/
*/__queuestorage__/
*/AzuriteConfig

# Azure Functions local settings
local.settings.json
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,13 @@
import importlib.metadata

from ._app import AgentFunctionApp
from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol
from ._orchestration import DurableAIAgent

try:
__version__ = importlib.metadata.version(__name__)
except importlib.metadata.PackageNotFoundError:
__version__ = "0.0.0" # Fallback for development mode

__all__ = [
"AgentCallbackContext",
"AgentFunctionApp",
"AgentResponseCallbackProtocol",
"DurableAIAgent",
"__version__",
]
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@

import json
import re
import uuid
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, TypeVar, cast

import azure.durable_functions as df
import azure.functions as func
from agent_framework import AgentProtocol, get_logger

from ._callbacks import AgentResponseCallbackProtocol
from ._constants import (
from agent_framework_durabletask import (
DEFAULT_MAX_POLL_RETRIES,
DEFAULT_POLL_INTERVAL_SECONDS,
MIMETYPE_APPLICATION_JSON,
Expand All @@ -28,12 +28,17 @@
THREAD_ID_HEADER,
WAIT_FOR_RESPONSE_FIELD,
WAIT_FOR_RESPONSE_HEADER,
AgentResponseCallbackProtocol,
AgentSessionId,
ApiResponseFields,
DurableAgentState,
DurableAIAgent,
RunRequest,
)
from ._durable_agent_state import DurableAgentState

from ._entities import create_agent_entity
from ._errors import IncomingRequestError
from ._models import AgentSessionId, RunRequest
from ._orchestration import AgentOrchestrationContextType, DurableAIAgent
from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor

logger = get_logger("agent_framework.azurefunctions")

Expand Down Expand Up @@ -294,7 +299,7 @@ def get_agent(
self,
context: AgentOrchestrationContextType,
agent_name: str,
) -> DurableAIAgent:
) -> DurableAIAgent[AgentTask]:
"""Return a DurableAIAgent proxy for a registered agent.

Args:
Expand All @@ -305,14 +310,15 @@ def get_agent(
ValueError: If the requested agent has not been registered.

Returns:
DurableAIAgent wrapper bound to the orchestration context.
DurableAIAgent[AgentTask] wrapper bound to the orchestration context.
"""
normalized_name = str(agent_name)

if normalized_name not in self._agent_metadata:
raise ValueError(f"Agent '{normalized_name}' is not registered with this app.")

return DurableAIAgent(context, normalized_name)
executor = AzureFunctionsAgentExecutor(context)
return DurableAIAgent(executor, normalized_name)

def _setup_agent_functions(
self,
Expand Down Expand Up @@ -375,8 +381,6 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien
"enable_tool_calls": true|false (optional, default: true)
}
"""
logger.debug(f"[HTTP Trigger] Received request on route: /api/agents/{agent_name}/run")

request_response_format: str = REQUEST_RESPONSE_FORMAT_JSON
thread_id: str | None = None

Expand All @@ -385,9 +389,9 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien
thread_id = self._resolve_thread_id(req=req, req_body=req_body)
wait_for_response = self._should_wait_for_response(req=req, req_body=req_body)

logger.debug(f"[HTTP Trigger] Message: {message}")
logger.debug(f"[HTTP Trigger] Thread ID: {thread_id}")
logger.debug(f"[HTTP Trigger] wait_for_response: {wait_for_response}")
logger.debug(
f"[HTTP Trigger] Message: {message}, Thread ID: {thread_id}, wait_for_response: {wait_for_response}"
)

if not message:
logger.warning("[HTTP Trigger] Request rejected: Missing message")
Expand All @@ -401,15 +405,18 @@ async def http_start(req: func.HttpRequest, client: df.DurableOrchestrationClien
session_id = self._create_session_id(agent_name, thread_id)
correlation_id = self._generate_unique_id()

logger.debug(f"[HTTP Trigger] Using session ID: {session_id}")
logger.debug(f"[HTTP Trigger] Generated correlation ID: {correlation_id}")
logger.debug("[HTTP Trigger] Calling entity to run agent...")
logger.debug(
f"[HTTP Trigger] Calling entity to run agent using session ID: {session_id} "
f"and correlation ID: {correlation_id}"
)

entity_instance_id = session_id.to_entity_id()
entity_instance_id = df.EntityId(
name=session_id.entity_name,
key=session_id.key,
)
run_request = self._build_request_data(
req_body,
message,
thread_id,
correlation_id,
request_response_format,
)
Expand Down Expand Up @@ -622,14 +629,16 @@ async def _handle_mcp_tool_invocation(
session_id = AgentSessionId.with_random_key(agent_name)

# Build entity instance ID
entity_instance_id = session_id.to_entity_id()
entity_instance_id = df.EntityId(
name=session_id.entity_name,
key=session_id.key,
)

# Create run request
correlation_id = self._generate_unique_id()
run_request = self._build_request_data(
req_body={"message": query, "role": "user"},
message=query,
thread_id=str(session_id),
correlation_id=correlation_id,
request_response_format=REQUEST_RESPONSE_FORMAT_TEXT,
)
Expand Down Expand Up @@ -781,7 +790,7 @@ async def _poll_entity_for_response(
agent_response = state.try_get_agent_response(correlation_id)
if agent_response:
result = self._build_success_result(
response_data=agent_response,
response_message=agent_response.text,
message=message,
thread_id=thread_id,
correlation_id=correlation_id,
Expand Down Expand Up @@ -827,23 +836,22 @@ async def _build_timeout_result(self, message: str, thread_id: str, correlation_
)

def _build_success_result(
self, response_data: dict[str, Any], message: str, thread_id: str, correlation_id: str, state: DurableAgentState
self, response_message: str, message: str, thread_id: str, correlation_id: str, state: DurableAgentState
) -> dict[str, Any]:
"""Build the success result returned to the HTTP caller."""
return self._build_response_payload(
response=response_data.get("content"),
response=response_message,
message=message,
thread_id=thread_id,
status="success",
correlation_id=correlation_id,
extra_fields={"message_count": response_data.get("message_count", state.message_count)},
extra_fields={ApiResponseFields.MESSAGE_COUNT: state.message_count},
)

def _build_request_data(
self,
req_body: dict[str, Any],
message: str,
thread_id: str,
correlation_id: str,
request_response_format: str,
) -> dict[str, Any]:
Expand All @@ -857,8 +865,8 @@ def _build_request_data(
request_response_format=request_response_format,
response_format=req_body.get("response_format"),
enable_tool_calls=enable_tool_calls,
thread_id=thread_id,
correlation_id=correlation_id,
created_at=datetime.now(timezone.utc),
).to_dict()

def _build_accepted_response(self, message: str, thread_id: str, correlation_id: str) -> dict[str, Any]:
Expand Down Expand Up @@ -910,15 +918,13 @@ def _convert_payload_to_text(self, payload: dict[str, Any]) -> str:

def _generate_unique_id(self) -> str:
"""Generate a new unique identifier."""
import uuid

return uuid.uuid4().hex

def _create_session_id(self, func_name: str, thread_id: str | None) -> AgentSessionId:
def _create_session_id(self, agent_name: str, thread_id: str | None) -> AgentSessionId:
"""Create a session identifier using the provided thread id or a random value."""
if thread_id:
return AgentSessionId(name=func_name, key=thread_id)
return AgentSessionId.with_random_key(name=func_name)
return AgentSessionId(name=agent_name, key=thread_id)
return AgentSessionId.with_random_key(name=agent_name)

def _resolve_thread_id(self, req: func.HttpRequest, req_body: dict[str, Any]) -> str:
"""Retrieve the thread identifier from request body or query parameters."""
Expand Down
Loading
Loading