Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 124 additions & 24 deletions python/packages/core/agent_framework/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2861,14 +2861,13 @@ def __init__(
self.created_at = created_at
self.finish_reason = finish_reason
self.usage_details = usage_details
self.value = value
self._value: Any | None = value
self._response_format: type[BaseModel] | None = response_format
self._value_parsed: bool = value is not None
self.additional_properties = additional_properties or {}
self.additional_properties.update(kwargs or {})
self.raw_representation: Any | list[Any] | None = raw_representation

if response_format:
self.try_parse_value(output_format_type=response_format)

@classmethod
def from_chat_response_updates(
cls: type[TChatResponse],
Expand Down Expand Up @@ -2933,29 +2932,78 @@ async def from_chat_response_generator(
Keyword Args:
output_format_type: Optional Pydantic model type to parse the response text into structured data.
"""
msg = cls(messages=[])
response_format = output_format_type if isinstance(output_format_type, type) else None
msg = cls(messages=[], response_format=response_format)
async for update in updates:
_process_update(msg, update)
_finalize_response(msg)
if output_format_type and isinstance(output_format_type, type) and issubclass(output_format_type, BaseModel):
msg.try_parse_value(output_format_type)
if response_format and issubclass(response_format, BaseModel):
msg.try_parse_value(response_format)
return msg

@property
def text(self) -> str:
"""Returns the concatenated text of all messages in the response."""
return ("\n".join(message.text for message in self.messages if isinstance(message, ChatMessage))).strip()

@property
def value(self) -> Any | None:
"""Get the parsed structured output value.

If a response_format was provided and parsing hasn't been attempted yet,
this will attempt to parse the text into the specified type.

Raises:
ValidationError: If the response text doesn't match the expected schema.
"""
if self._value_parsed:
return self._value
if (
self._response_format is not None
and isinstance(self._response_format, type)
and issubclass(self._response_format, BaseModel)
):
self._value = self._response_format.model_validate_json(self.text)
self._value_parsed = True
return self._value

def __str__(self) -> str:
return self.text

def try_parse_value(self, output_format_type: type[BaseModel]) -> None:
"""If there is a value, does nothing, otherwise tries to parse the text into the value."""
if self.value is None and isinstance(output_format_type, type) and issubclass(output_format_type, BaseModel):
try:
self.value = output_format_type.model_validate_json(self.text) # type: ignore[reportUnknownMemberType]
except ValidationError as ex:
logger.debug("Failed to parse value from chat response text: %s", ex)
def try_parse_value(self, output_format_type: type[_T] | None = None) -> _T | None:
"""Try to parse the text into a typed value.

This is the safe alternative to accessing the value property directly.
Returns the parsed value on success, or None on failure.

Args:
output_format_type: The Pydantic model type to parse into.
If None, uses the response_format from initialization.

Returns:
The parsed value as the specified type, or None if parsing fails.
"""
format_type = output_format_type or self._response_format
if format_type is None or not (isinstance(format_type, type) and issubclass(format_type, BaseModel)):
return None

# Cache the result unless a different schema than the configured response_format is requested.
# This prevents calls with a different schema from polluting the cached value.
use_cache = (
self._response_format is None or output_format_type is None or output_format_type is self._response_format
)

if use_cache and self._value_parsed and self._value is not None:
return self._value # type: ignore[return-value, no-any-return]
try:
parsed_value = format_type.model_validate_json(self.text) # type: ignore[reportUnknownMemberType]
if use_cache:
self._value = parsed_value
self._value_parsed = True
return parsed_value # type: ignore[return-value]
except ValidationError as ex:
logger.warning("Failed to parse value from chat response text: %s", ex)
return None


# region ChatResponseUpdate
Expand Down Expand Up @@ -3141,6 +3189,7 @@ def __init__(
created_at: CreatedAtT | None = None,
usage_details: UsageDetails | MutableMapping[str, Any] | None = None,
value: Any | None = None,
response_format: type[BaseModel] | None = None,
raw_representation: Any | None = None,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
Expand All @@ -3153,6 +3202,7 @@ def __init__(
created_at: A timestamp for the chat response.
usage_details: The usage details for the chat response.
value: The structured output of the agent run response, if applicable.
response_format: Optional response format for the agent response.
additional_properties: Any additional properties associated with the chat response.
raw_representation: The raw representation of the chat response from an underlying implementation.
**kwargs: Additional properties to set on the response.
Expand Down Expand Up @@ -3180,7 +3230,9 @@ def __init__(
self.response_id = response_id
self.created_at = created_at
self.usage_details = usage_details
self.value = value
self._value: Any | None = value
self._response_format: type[BaseModel] | None = response_format
self._value_parsed: bool = value is not None
self.additional_properties = additional_properties or {}
self.additional_properties.update(kwargs or {})
self.raw_representation = raw_representation
Expand All @@ -3190,6 +3242,27 @@ def text(self) -> str:
"""Get the concatenated text of all messages."""
return "".join(msg.text for msg in self.messages) if self.messages else ""

@property
def value(self) -> Any | None:
"""Get the parsed structured output value.

If a response_format was provided and parsing hasn't been attempted yet,
this will attempt to parse the text into the specified type.

Raises:
ValidationError: If the response text doesn't match the expected schema.
"""
if self._value_parsed:
return self._value
if (
self._response_format is not None
and isinstance(self._response_format, type)
and issubclass(self._response_format, BaseModel)
):
self._value = self._response_format.model_validate_json(self.text)
self._value_parsed = True
return self._value

@property
def user_input_requests(self) -> list[UserInputRequestContents]:
"""Get all BaseUserInputRequest messages from the response."""
Expand All @@ -3215,7 +3288,7 @@ def from_agent_run_response_updates(
Keyword Args:
output_format_type: Optional Pydantic model type to parse the response text into structured data.
"""
msg = cls(messages=[])
msg = cls(messages=[], response_format=output_format_type)
for update in updates:
_process_update(msg, update)
_finalize_response(msg)
Expand All @@ -3238,7 +3311,7 @@ async def from_agent_response_generator(
Keyword Args:
output_format_type: Optional Pydantic model type to parse the response text into structured data
"""
msg = cls(messages=[])
msg = cls(messages=[], response_format=output_format_type)
async for update in updates:
_process_update(msg, update)
_finalize_response(msg)
Expand All @@ -3249,13 +3322,40 @@ async def from_agent_response_generator(
def __str__(self) -> str:
return self.text

def try_parse_value(self, output_format_type: type[BaseModel]) -> None:
"""If there is a value, does nothing, otherwise tries to parse the text into the value."""
if self.value is None:
try:
self.value = output_format_type.model_validate_json(self.text) # type: ignore[reportUnknownMemberType]
except ValidationError as ex:
logger.debug("Failed to parse value from agent run response text: %s", ex)
def try_parse_value(self, output_format_type: type[_T] | None = None) -> _T | None:
"""Try to parse the text into a typed value.

This is the safe alternative when you need to parse the response text into a typed value.
Returns the parsed value on success, or None on failure.

Args:
output_format_type: The Pydantic model type to parse into.
If None, uses the response_format from initialization.

Returns:
The parsed value as the specified type, or None if parsing fails.
"""
format_type = output_format_type or self._response_format
if format_type is None or not (isinstance(format_type, type) and issubclass(format_type, BaseModel)):
return None

# Cache the result unless a different schema than the configured response_format is requested.
# This prevents calls with a different schema from polluting the cached value.
use_cache = (
self._response_format is None or output_format_type is None or output_format_type is self._response_format
)

if use_cache and self._value_parsed and self._value is not None:
return self._value # type: ignore[return-value, no-any-return]
try:
parsed_value = format_type.model_validate_json(self.text) # type: ignore[reportUnknownMemberType]
if use_cache:
self._value = parsed_value
self._value_parsed = True
return parsed_value # type: ignore[return-value]
except ValidationError as ex:
logger.warning("Failed to parse value from agent run response text: %s", ex)
return None


# region AgentResponseUpdate
Expand Down
118 changes: 118 additions & 0 deletions python/packages/core/tests/core/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -705,6 +705,124 @@ def test_chat_response_with_format_init():
assert response.value.response == "Hello"


def test_chat_response_value_raises_on_invalid_schema():
"""Test that value property raises ValidationError with field constraint details."""
from typing import Literal

from pydantic import Field, ValidationError

class StrictSchema(BaseModel):
id: Literal[5]
name: str = Field(min_length=10)
score: int = Field(gt=0, le=100)

message = ChatMessage(role="assistant", text='{"id": 1, "name": "test", "score": -5}')
response = ChatResponse(messages=message, response_format=StrictSchema)

with raises(ValidationError) as exc_info:
_ = response.value

errors = exc_info.value.errors()
error_fields = {e["loc"][0] for e in errors}
assert "id" in error_fields, "Expected 'id' Literal constraint error"
assert "name" in error_fields, "Expected 'name' min_length constraint error"
assert "score" in error_fields, "Expected 'score' gt constraint error"


def test_chat_response_try_parse_value_returns_none_on_invalid():
"""Test that try_parse_value returns None on validation failure with Field constraints."""
from typing import Literal

from pydantic import Field

class StrictSchema(BaseModel):
id: Literal[5]
name: str = Field(min_length=10)
score: int = Field(gt=0, le=100)

message = ChatMessage(role="assistant", text='{"id": 1, "name": "test", "score": -5}')
response = ChatResponse(messages=message)

result = response.try_parse_value(StrictSchema)
assert result is None


def test_chat_response_try_parse_value_returns_value_on_success():
"""Test that try_parse_value returns parsed value when all constraints pass."""
from pydantic import Field

class MySchema(BaseModel):
name: str = Field(min_length=3)
score: int = Field(ge=0, le=100)

message = ChatMessage(role="assistant", text='{"name": "test", "score": 85}')
response = ChatResponse(messages=message)

result = response.try_parse_value(MySchema)
assert result is not None
assert result.name == "test"
assert result.score == 85


def test_agent_response_value_raises_on_invalid_schema():
"""Test that AgentResponse.value property raises ValidationError with field constraint details."""
from typing import Literal

from pydantic import Field, ValidationError

class StrictSchema(BaseModel):
id: Literal[5]
name: str = Field(min_length=10)
score: int = Field(gt=0, le=100)

message = ChatMessage(role="assistant", text='{"id": 1, "name": "test", "score": -5}')
response = AgentResponse(messages=message, response_format=StrictSchema)

with raises(ValidationError) as exc_info:
_ = response.value

errors = exc_info.value.errors()
error_fields = {e["loc"][0] for e in errors}
assert "id" in error_fields, "Expected 'id' Literal constraint error"
assert "name" in error_fields, "Expected 'name' min_length constraint error"
assert "score" in error_fields, "Expected 'score' gt constraint error"


def test_agent_response_try_parse_value_returns_none_on_invalid():
"""Test that AgentResponse.try_parse_value returns None on Field constraint failure."""
from typing import Literal

from pydantic import Field

class StrictSchema(BaseModel):
id: Literal[5]
name: str = Field(min_length=10)
score: int = Field(gt=0, le=100)

message = ChatMessage(role="assistant", text='{"id": 1, "name": "test", "score": -5}')
response = AgentResponse(messages=message)

result = response.try_parse_value(StrictSchema)
assert result is None


def test_agent_response_try_parse_value_returns_value_on_success():
"""Test that AgentResponse.try_parse_value returns parsed value when all constraints pass."""
from pydantic import Field

class MySchema(BaseModel):
name: str = Field(min_length=3)
score: int = Field(ge=0, le=100)

message = ChatMessage(role="assistant", text='{"name": "test", "score": 85}')
response = AgentResponse(messages=message)

result = response.try_parse_value(MySchema)
assert result is not None
assert result.name == "test"
assert result.score == 85


# region ChatResponseUpdate


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ async def main() -> None:
print(f"User: {query}")
result = await agent.run(query)

if isinstance(result.value, ReleaseBrief):
release_brief = result.value
if release_brief := result.try_parse_value(ReleaseBrief):
print("Agent:")
print(f"Feature: {release_brief.feature}")
print(f"Benefit: {release_brief.benefit}")
print(f"Launch date: {release_brief.launch_date}")
else:
print(f"Failed to parse response: {result.text}")


if __name__ == "__main__":
Expand Down
Loading
Loading