Skip to content

Commit 61cf8aa

Browse files
fix #3153
1 parent 6db2558 commit 61cf8aa

2 files changed

Lines changed: 337 additions & 6 deletions

File tree

python/packages/core/agent_framework/_tools.py

Lines changed: 78 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,15 @@
44
import inspect
55
import json
66
import sys
7-
from collections.abc import AsyncIterable, Awaitable, Callable, Collection, Mapping, MutableMapping, Sequence
7+
from collections.abc import (
8+
AsyncIterable,
9+
Awaitable,
10+
Callable,
11+
Collection,
12+
Mapping,
13+
MutableMapping,
14+
Sequence,
15+
)
816
from functools import wraps
917
from time import perf_counter, time_ns
1018
from typing import (
@@ -18,6 +26,7 @@
1826
Protocol,
1927
TypedDict,
2028
TypeVar,
29+
Union,
2130
cast,
2231
get_args,
2332
get_origin,
@@ -121,7 +130,13 @@ def _parse_inputs(
121130
if inputs is None:
122131
return []
123132

124-
from ._types import BaseContent, DataContent, HostedFileContent, HostedVectorStoreContent, UriContent
133+
from ._types import (
134+
BaseContent,
135+
DataContent,
136+
HostedFileContent,
137+
HostedVectorStoreContent,
138+
UriContent,
139+
)
125140

126141
parsed_inputs: list["Contents"] = []
127142
if not isinstance(inputs, list):
@@ -1008,6 +1023,27 @@ def _build_pydantic_model_from_json_schema(
10081023
if not properties:
10091024
return create_model(f"{model_name}_input")
10101025

1026+
def _resolve_literal_type(prop_details: dict[str, Any]) -> type | None:
1027+
"""Check if property should be a Literal type (const or enum).
1028+
1029+
Args:
1030+
prop_details: The JSON Schema property details
1031+
1032+
Returns:
1033+
Literal type if const or enum is present, None otherwise
1034+
"""
1035+
# const → Literal["value"]
1036+
if "const" in prop_details:
1037+
return Literal[prop_details["const"]] # type: ignore
1038+
1039+
# enum → Literal["a", "b", ...]
1040+
if "enum" in prop_details and isinstance(prop_details["enum"], list):
1041+
enum_values = prop_details["enum"]
1042+
if enum_values:
1043+
return Literal[tuple(enum_values)] # type: ignore
1044+
1045+
return None
1046+
10111047
def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type:
10121048
"""Resolve JSON Schema type to Python type, handling $ref, nested objects, and typed arrays.
10131049
@@ -1018,6 +1054,31 @@ def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type:
10181054
Returns:
10191055
Python type annotation (could be int, str, list[str], or a nested Pydantic model)
10201056
"""
1057+
# Handle oneOf + discriminator (polymorphic objects)
1058+
if "oneOf" in prop_details and "discriminator" in prop_details:
1059+
discriminator = prop_details["discriminator"]
1060+
disc_field = discriminator.get("propertyName")
1061+
1062+
variants = []
1063+
for variant in prop_details["oneOf"]:
1064+
if "$ref" in variant:
1065+
ref = variant["$ref"]
1066+
if ref.startswith("#/$defs/"):
1067+
def_name = ref.split("/")[-1]
1068+
resolved = definitions.get(def_name)
1069+
if resolved:
1070+
variant_model = _resolve_type(
1071+
resolved,
1072+
parent_name=f"{parent_name}_{def_name}",
1073+
)
1074+
variants.append(variant_model)
1075+
1076+
if variants and disc_field:
1077+
return Annotated[
1078+
Union[tuple(variants)], # type: ignore
1079+
Field(discriminator=disc_field),
1080+
]
1081+
10211082
# Handle $ref by resolving the reference
10221083
if "$ref" in prop_details:
10231084
ref = prop_details["$ref"]
@@ -1068,9 +1129,15 @@ def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type:
10681129
else nested_prop_details
10691130
)
10701131

1071-
nested_python_type = _resolve_type(
1072-
nested_prop_details, f"{nested_model_name}_{nested_prop_name}"
1073-
)
1132+
# Check for Literal types first (const/enum)
1133+
literal_type = _resolve_literal_type(nested_prop_details)
1134+
if literal_type is not None:
1135+
nested_python_type = literal_type
1136+
else:
1137+
nested_python_type = _resolve_type(
1138+
nested_prop_details,
1139+
f"{nested_model_name}_{nested_prop_name}",
1140+
)
10741141
nested_description = nested_prop_details.get("description", "")
10751142

10761143
# Build field kwargs for nested property
@@ -1107,7 +1174,12 @@ def _resolve_type(prop_details: dict[str, Any], parent_name: str = "") -> type:
11071174
for prop_name, prop_details in properties.items():
11081175
prop_details = json.loads(prop_details) if isinstance(prop_details, str) else prop_details
11091176

1110-
python_type = _resolve_type(prop_details, f"{model_name}_{prop_name}")
1177+
# Check for Literal types first (const/enum)
1178+
literal_type = _resolve_literal_type(prop_details)
1179+
if literal_type is not None:
1180+
python_type = literal_type
1181+
else:
1182+
python_type = _resolve_type(prop_details, f"{model_name}_{prop_name}")
11111183
description = prop_details.get("description", "")
11121184

11131185
# Build field kwargs (description, etc.)

python/packages/core/tests/core/test_tools.py

Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1756,4 +1756,263 @@ def dummy_func(**kwargs):
17561756
)
17571757

17581758

1759+
def test_one_of_discriminator_polymorphism():
1760+
"""Test that oneOf with discriminator creates proper polymorphic union types.
1761+
1762+
Tests that oneOf + discriminator patterns are properly converted to Pydantic discriminated unions.
1763+
"""
1764+
schema = {
1765+
"$defs": {
1766+
"CreateProject": {
1767+
"description": "Action: Create an Azure DevOps project.",
1768+
"properties": {
1769+
"name": {
1770+
"const": "create_project",
1771+
"default": "create_project",
1772+
"type": "string",
1773+
},
1774+
"params": {"$ref": "#/$defs/CreateProjectParams"},
1775+
},
1776+
"required": ["params"],
1777+
"type": "object",
1778+
},
1779+
"CreateProjectParams": {
1780+
"description": "Parameters for the create_project action.",
1781+
"properties": {
1782+
"orgUrl": {"minLength": 1, "type": "string"},
1783+
"projectName": {"minLength": 1, "type": "string"},
1784+
"description": {"default": "", "type": "string"},
1785+
"template": {"default": "Agile", "type": "string"},
1786+
"sourceControl": {
1787+
"default": "Git",
1788+
"enum": ["Git", "Tfvc"],
1789+
"type": "string",
1790+
},
1791+
"visibility": {"default": "private", "type": "string"},
1792+
},
1793+
"required": ["orgUrl", "projectName"],
1794+
"type": "object",
1795+
},
1796+
"DeployRequest": {
1797+
"description": "Request to deploy Azure DevOps resources.",
1798+
"properties": {
1799+
"projectName": {"minLength": 1, "type": "string"},
1800+
"organization": {"minLength": 1, "type": "string"},
1801+
"actions": {
1802+
"items": {
1803+
"discriminator": {
1804+
"mapping": {
1805+
"create_project": "#/$defs/CreateProject",
1806+
"hello_world": "#/$defs/HelloWorld",
1807+
},
1808+
"propertyName": "name",
1809+
},
1810+
"oneOf": [
1811+
{"$ref": "#/$defs/HelloWorld"},
1812+
{"$ref": "#/$defs/CreateProject"},
1813+
],
1814+
},
1815+
"type": "array",
1816+
},
1817+
},
1818+
"required": ["projectName", "organization"],
1819+
"type": "object",
1820+
},
1821+
"HelloWorld": {
1822+
"description": "Action: Prints a greeting message.",
1823+
"properties": {
1824+
"name": {
1825+
"const": "hello_world",
1826+
"default": "hello_world",
1827+
"type": "string",
1828+
},
1829+
"params": {"$ref": "#/$defs/HelloWorldParams"},
1830+
},
1831+
"required": ["params"],
1832+
"type": "object",
1833+
},
1834+
"HelloWorldParams": {
1835+
"description": "Parameters for the hello_world action.",
1836+
"properties": {
1837+
"name": {
1838+
"description": "Name to greet",
1839+
"minLength": 1,
1840+
"type": "string",
1841+
}
1842+
},
1843+
"required": ["name"],
1844+
"type": "object",
1845+
},
1846+
},
1847+
"properties": {"params": {"$ref": "#/$defs/DeployRequest"}},
1848+
"required": ["params"],
1849+
"type": "object",
1850+
}
1851+
1852+
# Build the model
1853+
model = _build_pydantic_model_from_json_schema("deploy_tool", schema)
1854+
1855+
# Verify the model structure
1856+
assert model is not None
1857+
assert issubclass(model, BaseModel)
1858+
1859+
# Test with HelloWorld action
1860+
hello_world_data = {
1861+
"params": {
1862+
"projectName": "MyProject",
1863+
"organization": "MyOrg",
1864+
"actions": [
1865+
{
1866+
"name": "hello_world",
1867+
"params": {"name": "Alice"},
1868+
}
1869+
],
1870+
}
1871+
}
1872+
1873+
instance = model(**hello_world_data)
1874+
assert instance.params.projectName == "MyProject"
1875+
assert instance.params.organization == "MyOrg"
1876+
assert len(instance.params.actions) == 1
1877+
assert instance.params.actions[0].name == "hello_world"
1878+
assert instance.params.actions[0].params.name == "Alice"
1879+
1880+
# Test with CreateProject action
1881+
create_project_data = {
1882+
"params": {
1883+
"projectName": "MyProject",
1884+
"organization": "MyOrg",
1885+
"actions": [
1886+
{
1887+
"name": "create_project",
1888+
"params": {
1889+
"orgUrl": "https://dev.azure.com/myorg",
1890+
"projectName": "NewProject",
1891+
"sourceControl": "Git",
1892+
},
1893+
}
1894+
],
1895+
}
1896+
}
1897+
1898+
instance2 = model(**create_project_data)
1899+
assert instance2.params.actions[0].name == "create_project"
1900+
assert instance2.params.actions[0].params.projectName == "NewProject"
1901+
assert instance2.params.actions[0].params.sourceControl == "Git"
1902+
1903+
# Test with mixed actions
1904+
mixed_data = {
1905+
"params": {
1906+
"projectName": "MyProject",
1907+
"organization": "MyOrg",
1908+
"actions": [
1909+
{"name": "hello_world", "params": {"name": "Bob"}},
1910+
{
1911+
"name": "create_project",
1912+
"params": {
1913+
"orgUrl": "https://dev.azure.com/myorg",
1914+
"projectName": "AnotherProject",
1915+
},
1916+
},
1917+
],
1918+
}
1919+
}
1920+
1921+
instance3 = model(**mixed_data)
1922+
assert len(instance3.params.actions) == 2
1923+
assert instance3.params.actions[0].name == "hello_world"
1924+
assert instance3.params.actions[1].name == "create_project"
1925+
1926+
1927+
def test_const_creates_literal():
1928+
"""Test that const in JSON Schema creates Literal type."""
1929+
schema = {
1930+
"properties": {
1931+
"action": {
1932+
"const": "create",
1933+
"type": "string",
1934+
"description": "Action type",
1935+
},
1936+
"value": {"type": "integer"},
1937+
},
1938+
"required": ["action", "value"],
1939+
}
1940+
1941+
model = _build_pydantic_model_from_json_schema("test_const", schema)
1942+
1943+
# Verify valid const value works
1944+
instance = model(action="create", value=42)
1945+
assert instance.action == "create"
1946+
assert instance.value == 42
1947+
1948+
# Verify incorrect const value fails
1949+
with pytest.raises(ValidationError):
1950+
model(action="delete", value=42)
1951+
1952+
1953+
def test_enum_creates_literal():
1954+
"""Test that enum in JSON Schema creates Literal type."""
1955+
schema = {
1956+
"properties": {
1957+
"status": {
1958+
"enum": ["pending", "approved", "rejected"],
1959+
"type": "string",
1960+
"description": "Status",
1961+
},
1962+
"priority": {"enum": [1, 2, 3], "type": "integer"},
1963+
},
1964+
"required": ["status"],
1965+
}
1966+
1967+
model = _build_pydantic_model_from_json_schema("test_enum", schema)
1968+
1969+
# Verify valid enum values work
1970+
instance = model(status="approved", priority=2)
1971+
assert instance.status == "approved"
1972+
assert instance.priority == 2
1973+
1974+
# Verify invalid enum value fails
1975+
with pytest.raises(ValidationError):
1976+
model(status="unknown")
1977+
1978+
with pytest.raises(ValidationError):
1979+
model(status="pending", priority=5)
1980+
1981+
1982+
def test_nested_object_with_const_and_enum():
1983+
"""Test that const and enum work in nested objects."""
1984+
schema = {
1985+
"properties": {
1986+
"config": {
1987+
"type": "object",
1988+
"properties": {
1989+
"type": {
1990+
"const": "production",
1991+
"default": "production",
1992+
"type": "string",
1993+
},
1994+
"level": {"enum": ["low", "medium", "high"], "type": "string"},
1995+
},
1996+
"required": ["level"],
1997+
}
1998+
},
1999+
"required": ["config"],
2000+
}
2001+
2002+
model = _build_pydantic_model_from_json_schema("test_nested", schema)
2003+
2004+
# Valid data
2005+
instance = model(config={"type": "production", "level": "high"})
2006+
assert instance.config.type == "production"
2007+
assert instance.config.level == "high"
2008+
2009+
# Invalid const in nested object
2010+
with pytest.raises(ValidationError):
2011+
model(config={"type": "development", "level": "low"})
2012+
2013+
# Invalid enum in nested object
2014+
with pytest.raises(ValidationError):
2015+
model(config={"type": "production", "level": "critical"})
2016+
2017+
17592018
# endregion

0 commit comments

Comments
 (0)