Skip to content

Commit f6d79db

Browse files
author
Anders Brams
committed
feat: allow docstring passon from typeddicts
1 parent 8855987 commit f6d79db

6 files changed

Lines changed: 228 additions & 10 deletions

File tree

openapi_python/generator/render.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,19 +48,26 @@ def _field_annotation(field: FieldDef) -> str:
4848
"""
4949
Jinja2 filter to format a FieldDef's annotation.
5050
"""
51-
annotation = repr(_render_annotation(field.annotation))
51+
annotation = repr(_render_field_annotation(field))
5252
if not field.required:
5353
annotation = f"NotRequired[{annotation}]"
5454
return annotation
5555

5656

5757
def _class_field_annotation(field: FieldDef, total_optional: bool) -> str:
58-
annotation = _render_annotation(field.annotation)
58+
annotation = _render_field_annotation(field)
5959
if not field.required and not total_optional:
6060
annotation = f"NotRequired[{annotation}]"
6161
return annotation
6262

6363

64+
def _render_field_annotation(field: FieldDef) -> str:
65+
annotation = _render_annotation(field.annotation)
66+
if field.description is None:
67+
return annotation
68+
return f"Annotated[{annotation}, _openapi_python_field({field.description!r})]"
69+
70+
6471
def _string_literal(value: str) -> str:
6572
return repr(value)
6673

@@ -238,6 +245,10 @@ def _format_type_definition(defn: TypeAliasDef | TypedDictDef) -> str:
238245
return _format_typeddict(defn)
239246

240247

248+
def _has_field_descriptions(defns: tuple[TypedDictDef, ...]) -> bool:
249+
return any(field.description is not None for defn in defns for field in defn.fields)
250+
251+
241252
def _call_parameters(op: OperationDef, *, generate_requests: bool) -> dict[str, str]:
242253
if not generate_requests:
243254
return {
@@ -395,6 +406,7 @@ def _render_types(
395406
]
396407
return _render_template(
397408
"types.py.j2",
409+
has_field_descriptions=_has_field_descriptions(typed_dicts),
398410
type_blocks="\n".join(blocks).strip() + "\n",
399411
)
400412

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,21 @@
11
from __future__ import annotations
22

33
from enum import Enum
4-
from typing import Any, Literal, NotRequired, TypeAlias, TypedDict
4+
{% if has_field_descriptions -%}
5+
from importlib import import_module
6+
{% endif -%}
7+
from typing import Annotated, Any, Literal, NotRequired, TypeAlias, TypedDict
8+
9+
{% if has_field_descriptions %}
10+
def _openapi_python_field(description: str) -> object:
11+
try:
12+
# Allow FastAPI applications and other OpenAPI-spec-auto-generating tools
13+
# to use the same field description syntax as the generated client code
14+
# without a hard dependency on Pydantic.
15+
field = import_module("pydantic").Field
16+
except ImportError:
17+
return description
18+
return field(description=description)
19+
{% endif %}
520

621
{{ type_blocks }}

tests/contract/docstrings/app.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from typing import Annotated, TypedDict
4+
35
from fastapi import FastAPI
46
from pydantic import BaseModel, Field
57

@@ -12,6 +14,115 @@ class MyDTO(BaseModel):
1214
a: int = Field(description="This is the docstring for a single field")
1315

1416

17+
class CiscoAccessSwitchVLAN(TypedDict):
18+
vlan_id: int
19+
20+
21+
class CiscoAccessSwitchVLANInterface(TypedDict):
22+
name: str
23+
24+
25+
class CiscoAccessSwitchInterface(TypedDict):
26+
name: str
27+
28+
29+
class ISEProfile(TypedDict):
30+
name: str
31+
32+
33+
class CiscoAccessSwitchTemplateParams(TypedDict):
34+
hostname: Annotated[str, Field(description="The hostname of the switch")]
35+
vlans: Annotated[
36+
list[CiscoAccessSwitchVLAN],
37+
Field(
38+
description=(
39+
"All the VLANs that should be configured on the switch. These are "
40+
"all the VLANs associated with the location in Nautobot."
41+
)
42+
),
43+
]
44+
vlan_interfaces: Annotated[
45+
list[CiscoAccessSwitchVLANInterface],
46+
Field(
47+
description="All the virtual interfaces on the switch with the tag 'VLAN'."
48+
),
49+
]
50+
client_interfaces: Annotated[
51+
list[CiscoAccessSwitchInterface],
52+
Field(description="All the client interfaces on the switch."),
53+
]
54+
downlink_interfaces: Annotated[
55+
list[CiscoAccessSwitchInterface],
56+
Field(description="All the downlink interfaces on the switch."),
57+
]
58+
uplink_interfaces: Annotated[
59+
list[CiscoAccessSwitchInterface],
60+
Field(description="All the uplink interfaces on the switch."),
61+
]
62+
device_mgmt_ip: Annotated[
63+
str,
64+
Field(
65+
description="The IP address assigned to the switch for management purposes."
66+
),
67+
]
68+
management_vlan_id: Annotated[
69+
int,
70+
Field(
71+
description=(
72+
"The VLAN ID of the management VLAN. This is used for the "
73+
"management interface and default gateway."
74+
)
75+
),
76+
]
77+
management_interface: Annotated[
78+
str,
79+
Field(
80+
description=(
81+
"The interface on which the management IP is configured. This is "
82+
"used to determine which interface should be used for out-of-band "
83+
"management access to the switch."
84+
)
85+
),
86+
]
87+
default_gateway: Annotated[
88+
str,
89+
Field(description="The default gateway for the management VLAN."),
90+
]
91+
snmp_contact: Annotated[
92+
str,
93+
Field(description="The SNMP contact information for the switch."),
94+
]
95+
snmp_location: Annotated[
96+
str,
97+
Field(description="The SNMP location information for the switch."),
98+
]
99+
ise_profile: Annotated[
100+
ISEProfile,
101+
Field(description="The ISE profile to use for this switch."),
102+
]
103+
104+
15105
@app.get("/dto", response_model=MyDTO)
16106
def get_dto() -> MyDTO:
17107
return MyDTO(a=1)
108+
109+
110+
@app.get("/switch-template", response_model=CiscoAccessSwitchTemplateParams)
111+
def get_switch_template() -> CiscoAccessSwitchTemplateParams:
112+
switch_vlan: CiscoAccessSwitchVLAN = {"vlan_id": 10}
113+
switch_interface: CiscoAccessSwitchInterface = {"name": "GigabitEthernet1/0/1"}
114+
return {
115+
"hostname": "switch-01",
116+
"vlans": [switch_vlan],
117+
"vlan_interfaces": [{"name": "Vlan10"}],
118+
"client_interfaces": [switch_interface],
119+
"downlink_interfaces": [switch_interface],
120+
"uplink_interfaces": [switch_interface],
121+
"device_mgmt_ip": "192.0.2.10",
122+
"management_vlan_id": 10,
123+
"management_interface": "Vlan10",
124+
"default_gateway": "192.0.2.1",
125+
"snmp_contact": "Network Operations",
126+
"snmp_location": "DC1",
127+
"ise_profile": {"name": "default"},
128+
}

tests/contract/docstrings/generate.py

Lines changed: 75 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,69 @@
66
from pathlib import Path
77

88
from app import app
9+
from fastapi import FastAPI
910

1011
from openapi_python.generator import GenerationRequest, generate_client
1112

13+
SWITCH_TEMPLATE_FIELD_DOCSTRINGS = {
14+
"hostname": "The hostname of the switch",
15+
"vlans": (
16+
"All the VLANs that should be configured on the switch. These are all "
17+
"the VLANs associated with the location in Nautobot."
18+
),
19+
"vlan_interfaces": "All the virtual interfaces on the switch with the tag 'VLAN'.",
20+
"client_interfaces": "All the client interfaces on the switch.",
21+
"downlink_interfaces": "All the downlink interfaces on the switch.",
22+
"uplink_interfaces": "All the uplink interfaces on the switch.",
23+
"device_mgmt_ip": "The IP address assigned to the switch for management purposes.",
24+
"management_vlan_id": (
25+
"The VLAN ID of the management VLAN. This is used for the management "
26+
"interface and default gateway."
27+
),
28+
"management_interface": (
29+
"The interface on which the management IP is configured. This is used "
30+
"to determine which interface should be used for out-of-band management "
31+
"access to the switch."
32+
),
33+
"default_gateway": "The default gateway for the management VLAN.",
34+
"snmp_contact": "The SNMP contact information for the switch.",
35+
"snmp_location": "The SNMP location information for the switch.",
36+
"ise_profile": "The ISE profile to use for this switch.",
37+
}
38+
39+
40+
def _class_def(module: ast.Module, name: str) -> ast.ClassDef:
41+
return next(
42+
node
43+
for node in module.body
44+
if isinstance(node, ast.ClassDef) and node.name == name
45+
)
46+
47+
48+
def _field_docstrings(class_def: ast.ClassDef) -> dict[str, str]:
49+
docs: dict[str, str] = {}
50+
for field, docstring in zip(class_def.body, class_def.body[1:], strict=False):
51+
if not isinstance(field, ast.AnnAssign) or not isinstance(
52+
field.target, ast.Name
53+
):
54+
continue
55+
if not isinstance(docstring, ast.Expr) or not isinstance(
56+
docstring.value, ast.Constant
57+
):
58+
continue
59+
if isinstance(docstring.value.value, str):
60+
docs[field.target.id] = docstring.value.value
61+
return docs
62+
63+
64+
def _schema_property_descriptions(schema: dict) -> dict[str, str]:
65+
properties = schema["properties"]
66+
return {
67+
name: prop["description"]
68+
for name, prop in properties.items()
69+
if "description" in prop
70+
}
71+
1272

1373
def main() -> None:
1474
output_dir = Path(__file__).parent / "generated"
@@ -22,16 +82,26 @@ def main() -> None:
2282

2383
source = (output_dir / "my_client" / "types.py").read_text()
2484
module = ast.parse(source)
25-
dto = next(
26-
node
27-
for node in module.body
28-
if isinstance(node, ast.ClassDef) and node.name == "MyDTO"
29-
)
85+
dto = _class_def(module, "MyDTO")
86+
switch_template = _class_def(module, "CiscoAccessSwitchTemplateParams")
3087
generated_types = importlib.import_module("generated.my_client.types")
3188

3289
assert ast.get_docstring(dto) == "This is a descriptive docstring"
3390
assert generated_types.MyDTO.__doc__ == "This is a descriptive docstring"
3491
assert "This is the docstring for a single field" in source
92+
assert _field_docstrings(switch_template) == SWITCH_TEMPLATE_FIELD_DOCSTRINGS
93+
94+
api_b = FastAPI()
95+
96+
def get_switch_template() -> object:
97+
return {}
98+
99+
api_b.get(
100+
"/switch-template",
101+
response_model=generated_types.CiscoAccessSwitchTemplateParams,
102+
)(get_switch_template)
103+
schema = api_b.openapi()["components"]["schemas"]["CiscoAccessSwitchTemplateParams"]
104+
assert _schema_property_descriptions(schema) == SWITCH_TEMPLATE_FIELD_DOCSTRINGS
35105

36106

37107
if __name__ == "__main__":

tests/contract/docstrings/usage_async.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import assert_type
44

55
from generated.my_client import AsyncClient
6-
from generated.my_client.types import MyDTO
6+
from generated.my_client.types import CiscoAccessSwitchTemplateParams, MyDTO
77

88
client = AsyncClient(base_url="http://testserver")
99

@@ -12,3 +12,8 @@ async def main() -> None:
1212
result = await client.get("/dto")()
1313
assert_type(result, MyDTO)
1414
assert_type(result["a"], int)
15+
16+
switch_template = await client.get("/switch-template")()
17+
assert_type(switch_template, CiscoAccessSwitchTemplateParams)
18+
assert_type(switch_template["hostname"], str)
19+
assert_type(switch_template["management_vlan_id"], int)

tests/contract/docstrings/usage_sync.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,15 @@
33
from typing import assert_type
44

55
from generated.my_client import Client
6-
from generated.my_client.types import MyDTO
6+
from generated.my_client.types import CiscoAccessSwitchTemplateParams, MyDTO
77

88
client = Client(base_url="http://testserver")
99

1010
result = client.get("/dto")()
1111
assert_type(result, MyDTO)
1212
assert_type(result["a"], int)
13+
14+
switch_template = client.get("/switch-template")()
15+
assert_type(switch_template, CiscoAccessSwitchTemplateParams)
16+
assert_type(switch_template["hostname"], str)
17+
assert_type(switch_template["management_vlan_id"], int)

0 commit comments

Comments
 (0)