Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
310 changes: 196 additions & 114 deletions python/packages/azurefunctions/agent_framework_azurefunctions/_app.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class AgentCallbackContext:

agent_name: str
correlation_id: str
conversation_id: str | None = None
thread_id: str | None = None
request_message: str | None = None


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
from typing import Any, cast

import azure.durable_functions as df
from agent_framework import AgentProtocol, AgentRunResponse, AgentRunResponseUpdate, get_logger
from agent_framework import AgentProtocol, AgentRunResponse, AgentRunResponseUpdate, Role, get_logger

from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol
from ._models import AgentResponse, ChatRole, RunRequest
from ._models import AgentResponse, RunRequest
from ._state import AgentState

logger = get_logger("agent_framework.azurefunctions.entities")
Expand Down Expand Up @@ -81,33 +81,32 @@ async def run_agent(
"""
# Convert string or dict to RunRequest
if isinstance(request, str):
run_request = RunRequest(message=request, role=ChatRole.USER)
run_request = RunRequest(message=request, role=Role.USER)
elif isinstance(request, dict):
run_request = RunRequest.from_dict(request)
else:
run_request = request

message = run_request.message
conversation_id = run_request.conversation_id
thread_id = run_request.thread_id
correlation_id = run_request.correlation_id
if not conversation_id:
raise ValueError("RunRequest must include a conversation_id")
if not thread_id:
raise ValueError("RunRequest must include a thread_id")
if not correlation_id:
raise ValueError("RunRequest must include a correlation_id")
role = run_request.role or ChatRole.USER
role = run_request.role or Role.USER
response_format = run_request.response_format
enable_tool_calls = run_request.enable_tool_calls

logger.debug(f"[AgentEntity.run_agent] Received message: {message}")
logger.debug(f"[AgentEntity.run_agent] Conversation ID: {conversation_id}")
logger.debug(f"[AgentEntity.run_agent] Thread ID: {thread_id}")
logger.debug(f"[AgentEntity.run_agent] Correlation ID: {correlation_id}")
logger.debug(f"[AgentEntity.run_agent] Role: {role.value if isinstance(role, ChatRole) else role}")
logger.debug(f"[AgentEntity.run_agent] Role: {role.value}")
logger.debug(f"[AgentEntity.run_agent] Enable tool calls: {enable_tool_calls}")
logger.debug(f"[AgentEntity.run_agent] Response format: {'provided' if response_format else 'none'}")

# Store message in history with role
role_str = role.value if isinstance(role, ChatRole) else role
self.state.add_user_message(message, role=role_str, correlation_id=correlation_id)
self.state.add_user_message(message, role=role, correlation_id=correlation_id)

logger.debug("[AgentEntity.run_agent] Executing agent...")

Expand All @@ -123,7 +122,7 @@ async def run_agent(
agent_run_response: AgentRunResponse = await self._invoke_agent(
run_kwargs=run_kwargs,
correlation_id=correlation_id,
conversation_id=conversation_id,
thread_id=thread_id,
request_message=message,
)

Expand Down Expand Up @@ -160,7 +159,7 @@ async def run_agent(
agent_response = AgentResponse(
response=response_text,
message=str(message),
conversation_id=str(conversation_id),
thread_id=str(thread_id),
status="success",
message_count=self.state.message_count,
structured_response=structured_response,
Expand All @@ -185,7 +184,7 @@ async def run_agent(
error_response = AgentResponse(
response=f"Error: {exc!s}",
message=str(message),
conversation_id=str(conversation_id),
thread_id=str(thread_id),
status="error",
message_count=self.state.message_count,
error=str(exc),
Expand All @@ -197,15 +196,15 @@ async def _invoke_agent(
self,
run_kwargs: dict[str, Any],
correlation_id: str,
conversation_id: str,
thread_id: str,
request_message: str,
) -> AgentRunResponse:
"""Execute the agent, preferring streaming when available."""
callback_context: AgentCallbackContext | None = None
if self.callback is not None:
callback_context = self._build_callback_context(
correlation_id=correlation_id,
conversation_id=conversation_id,
thread_id=thread_id,
request_message=request_message,
)

Expand Down Expand Up @@ -319,15 +318,15 @@ async def _notify_final_response(
def _build_callback_context(
self,
correlation_id: str,
conversation_id: str,
thread_id: str,
request_message: str,
) -> AgentCallbackContext:
"""Create the callback context provided to consumers."""
agent_name = getattr(self.agent, "name", None) or type(self.agent).__name__
return AgentCallbackContext(
agent_name=agent_name,
correlation_id=correlation_id,
conversation_id=conversation_id,
thread_id=thread_id,
request_message=request_message,
)

Expand Down Expand Up @@ -375,11 +374,8 @@ async def _entity_coroutine(context: df.DurableEntityContext) -> None:
if operation == "run_agent":
input_data: Any = context.get_input()

# Support both old format (message + conversation_id) and new format (RunRequest dict)
# This provides backward compatibility
request: str | dict[str, Any]
if isinstance(input_data, dict) and "message" in input_data:
# Input can be either old format or new RunRequest format
request = cast(dict[str, Any], input_data)
else:
# Fall back to treating input as message string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,23 @@
This module defines the request and response models used by the framework.
"""

from __future__ import annotations

import inspect
import uuid
from collections.abc import MutableMapping
from dataclasses import dataclass
from enum import Enum
from importlib import import_module
from typing import TYPE_CHECKING, Any, cast

import azure.durable_functions as df
from agent_framework import AgentThread
from agent_framework import AgentThread, Role

if TYPE_CHECKING: # pragma: no cover - type checking imports only
from pydantic import BaseModel

_PydanticBaseModel: type["BaseModel"] | None
_PydanticBaseModel: type[BaseModel] | None

try:
from pydantic import BaseModel as _RuntimeBaseModel
except ImportError: # pragma: no cover - optional dependency
Expand All @@ -28,14 +30,6 @@
_PydanticBaseModel = _RuntimeBaseModel


class ChatRole(str, Enum):
"""Chat message role enum."""

USER = "user"
SYSTEM = "system"
ASSISTANT = "assistant"


@dataclass
class AgentSessionId:
"""Represents an agent session ID, which is used to identify a long-running agent session.
Expand Down Expand Up @@ -63,7 +57,7 @@ def to_entity_name(name: str) -> str:
return f"{AgentSessionId.ENTITY_NAME_PREFIX}{name}"

@staticmethod
def with_random_key(name: str) -> "AgentSessionId":
def with_random_key(name: str) -> AgentSessionId:
"""Creates a new AgentSessionId with the specified name and a randomly generated key.

Args:
Expand All @@ -83,7 +77,7 @@ def to_entity_id(self) -> df.EntityId:
return df.EntityId(self.to_entity_name(self.name), self.key)

@staticmethod
def from_entity_id(entity_id: df.EntityId) -> "AgentSessionId":
def from_entity_id(entity_id: df.EntityId) -> AgentSessionId:
"""Creates an AgentSessionId from a Durable Functions EntityId.

Args:
Expand Down Expand Up @@ -113,7 +107,7 @@ def __repr__(self) -> str:
return f"AgentSessionId(name='{self.name}', key='{self.key}')"

@staticmethod
def parse(session_id_string: str) -> "AgentSessionId":
def parse(session_id_string: str) -> AgentSessionId:
"""Parses a string representation of an agent session ID.

Args:
Expand Down Expand Up @@ -172,7 +166,7 @@ def from_session_id(
service_thread_id: str | None = None,
message_store: Any = None,
context_provider: Any = None,
) -> "DurableAgentThread":
) -> DurableAgentThread:
"""Creates a durable thread pre-associated with the supplied session ID."""
return cls(
session_id=session_id,
Expand All @@ -195,7 +189,7 @@ async def deserialize(
*,
message_store: Any = None,
**kwargs: Any,
) -> "DurableAgentThread":
) -> DurableAgentThread:
"""Restores a durable thread, rehydrating the stored session identifier."""
state_payload = dict(serialized_thread_state)
session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None)
Expand All @@ -217,7 +211,7 @@ async def deserialize(
return thread


def _serialize_response_format(response_format: type["BaseModel"] | None) -> Any:
def _serialize_response_format(response_format: type[BaseModel] | None) -> Any:
"""Serialize response format for transport across durable function boundaries."""
if response_format is None:
return None
Expand All @@ -235,7 +229,7 @@ def _serialize_response_format(response_format: type["BaseModel"] | None) -> Any
}


def _deserialize_response_format(response_format: Any) -> type["BaseModel"] | None:
def _deserialize_response_format(response_format: Any) -> type[BaseModel] | None:
"""Deserialize response format back into actionable type if possible."""
if response_format is None:
return None
Expand Down Expand Up @@ -287,17 +281,45 @@ class RunRequest:
role: The role of the message sender (user, system, or assistant)
response_format: Optional Pydantic BaseModel type describing the structured response format
enable_tool_calls: Whether to enable tool calls for this request
conversation_id: Optional conversation/session ID for tracking
thread_id: Optional thread ID for tracking
correlation_id: Optional correlation ID for tracking the response to this specific request
"""

message: str
role: ChatRole = ChatRole.USER
response_format: type["BaseModel"] | None = None
role: Role = Role.USER
response_format: type[BaseModel] | None = None
enable_tool_calls: bool = True
conversation_id: str | None = None
thread_id: str | None = None
correlation_id: str | None = None

def __init__(
self,
message: str,
role: Role | str | None = Role.USER,
response_format: type[BaseModel] | None = None,
enable_tool_calls: bool = True,
thread_id: str | None = None,
correlation_id: str | None = None,
) -> None:
self.message = message
self.role = self.coerce_role(role)
self.response_format = response_format
self.enable_tool_calls = enable_tool_calls
self.thread_id = thread_id
self.correlation_id = correlation_id

@staticmethod
def coerce_role(value: Role | str | None) -> Role:
"""Normalize various role representations into a Role instance."""
if isinstance(value, Role):
return value
if isinstance(value, str):
normalized = value.strip()
if not normalized:
return Role.USER
return Role(value=normalized.lower())
Comment thread
larohra marked this conversation as resolved.
Comment thread
larohra marked this conversation as resolved.
return Role.USER

def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
result = {
Expand All @@ -307,30 +329,21 @@ def to_dict(self) -> dict[str, Any]:
}
if self.response_format:
result["response_format"] = _serialize_response_format(self.response_format)
if self.conversation_id:
result["conversation_id"] = self.conversation_id
if self.thread_id:
result["thread_id"] = self.thread_id
if self.correlation_id:
result["correlation_id"] = self.correlation_id
return result

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "RunRequest":
def from_dict(cls, data: dict[str, Any]) -> RunRequest:
"""Create RunRequest from dictionary."""
role_str = data.get("role")
if role_str:
try:
role = ChatRole(role_str.lower())
except ValueError:
role = ChatRole.USER # Default to USER if invalid
else:
role = ChatRole.USER

return cls(
message=data.get("message", ""),
role=role,
role=cls.coerce_role(data.get("role")),
response_format=_deserialize_response_format(data.get("response_format")),
enable_tool_calls=data.get("enable_tool_calls", True),
conversation_id=data.get("conversation_id"),
thread_id=data.get("thread_id"),
correlation_id=data.get("correlation_id"),
)

Expand All @@ -342,7 +355,7 @@ class AgentResponse:
Attributes:
response: The agent's text response (or None for structured responses)
message: The original message sent to the agent
conversation_id: The conversation/session ID
thread_id: The thread identifier
status: Status of the execution (success, error, etc.)
message_count: Number of messages in the conversation
error: Error message if status is error
Expand All @@ -352,7 +365,7 @@ class AgentResponse:

response: str | None
message: str
conversation_id: str | None
thread_id: str | None
status: str
message_count: int = 0
error: str | None = None
Expand All @@ -363,7 +376,7 @@ def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for JSON serialization."""
result = {
"message": self.message,
"conversation_id": self.conversation_id,
"thread_id": self.thread_id,
"status": self.status,
"message_count": self.message_count,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def my_orchestration(context):
message=message_str,
enable_tool_calls=enable_tool_calls,
correlation_id=correlation_id,
conversation_id=session_id.key,
thread_id=session_id.key,
response_format=response_format,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

from collections.abc import MutableMapping
from datetime import datetime, timezone
from typing import Any, Literal, cast
from typing import Any, cast

from agent_framework import AgentRunResponse, ChatMessage, get_logger
from agent_framework import AgentRunResponse, ChatMessage, Role, get_logger

logger = get_logger("agent_framework.azurefunctions.state")

Expand Down Expand Up @@ -38,7 +38,7 @@ def _current_timestamp(self) -> str:
def add_user_message(
self,
content: str,
role: Literal["user", "system", "assistant", "tool"] = "user",
role: Role = Role.USER,
correlation_id: str | None = None,
) -> None:
"""Add a user message to the conversation history as a ChatMessage object.
Expand Down
Loading