Skip to content

Commit 90c7103

Browse files
author
Anders Brams
committed
feat: simplify protocols when --no-requests and --no-responses
1 parent 206fa18 commit 90c7103

3 files changed

Lines changed: 132 additions & 13 deletions

File tree

openapi_python/generator/api.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .loader import load_openapi, load_openapi_json
99
from .model import NormalizedSpec
1010
from .normalize import normalize_openapi
11-
from .render import render_package
11+
from .render import render_package, rendered_type_definition_count
1212
from .write import write_artifacts
1313

1414

@@ -73,7 +73,12 @@ def generate_client(request: GenerationRequest) -> GenerationResult:
7373
success=True,
7474
written_files=tuple(written),
7575
operations=len(normalized.operations),
76-
type_definitions=len(normalized.aliases) + len(normalized.typed_dicts),
76+
type_definitions=rendered_type_definition_count(
77+
normalized,
78+
generate_routes=request.generate_routes,
79+
generate_requests=request.generate_requests,
80+
generate_responses=request.generate_responses,
81+
),
7782
diagnostics=(),
7883
)
7984

openapi_python/generator/render.py

Lines changed: 124 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -283,10 +283,13 @@ def _protocol_block(
283283
)
284284

285285

286-
def _method_overload_line(op: OperationDef, *, is_async: bool = False) -> str:
286+
def _method_overload_line(
287+
op: OperationDef, *, return_type: str, is_async: bool = False
288+
) -> str:
287289
return _render_template(
288290
"method_overload.py.j2",
289291
op=op,
292+
return_type=return_type,
290293
protocol_name=_protocol_name(op, is_async=is_async),
291294
)
292295

@@ -304,10 +307,77 @@ def _fallback_method_block(
304307
)
305308

306309

307-
def _render_types(spec: NormalizedSpec, *, generate_routes: bool) -> str:
308-
aliases = (*_route_aliases(spec, generate_routes=generate_routes), *spec.aliases)
309-
type_definitions = _order_type_definitions(aliases, spec.typed_dicts)
310-
blocks = [_format_enum(item) for item in spec.enums] + [
310+
def _included_annotations(
311+
spec: NormalizedSpec, *, generate_requests: bool, generate_responses: bool
312+
) -> tuple[TypeAnnotation, ...]:
313+
"""
314+
Returns the set of type annotations that be included
315+
in the rendered output.
316+
"""
317+
roots: list[TypeAnnotation] = []
318+
for op in spec.operations:
319+
if generate_requests:
320+
roots.extend((op.params_type, op.query_type, op.headers_type))
321+
if op.body_type is not None:
322+
roots.append(op.body_type)
323+
if generate_responses:
324+
roots.append(op.response_type)
325+
return tuple(roots)
326+
327+
328+
def _used_type_names(
329+
spec: NormalizedSpec, *, generate_requests: bool, generate_responses: bool
330+
) -> set[str]:
331+
"""
332+
Returns the set of type definition names that are transitively referenced by
333+
the client protocols.
334+
"""
335+
by_name: dict[str, TypeAliasDef | TypedDictDef | EnumDef] = {
336+
item.name: item for item in (*spec.aliases, *spec.typed_dicts, *spec.enums)
337+
}
338+
all_names = set(by_name)
339+
used: set[str] = set()
340+
pending: list[str] = []
341+
342+
for annotation in _included_annotations(
343+
spec,
344+
generate_requests=generate_requests,
345+
generate_responses=generate_responses,
346+
):
347+
pending.extend(_annotation_dependencies(annotation, all_names) - used)
348+
349+
while pending:
350+
name = pending.pop()
351+
if name in used:
352+
continue
353+
used.add(name)
354+
item = by_name[name]
355+
match item:
356+
case TypeAliasDef() | TypedDictDef():
357+
pending.extend(_type_dependencies(item, all_names) - used)
358+
case EnumDef():
359+
pass
360+
return used
361+
362+
363+
def _render_types(
364+
spec: NormalizedSpec,
365+
*,
366+
generate_routes: bool,
367+
generate_requests: bool,
368+
generate_responses: bool,
369+
) -> str:
370+
route_aliases = _route_aliases(spec, generate_routes=generate_routes)
371+
used_names = _used_type_names(
372+
spec,
373+
generate_requests=generate_requests,
374+
generate_responses=generate_responses,
375+
)
376+
aliases = tuple(item for item in spec.aliases if item.name in used_names)
377+
typed_dicts = tuple(item for item in spec.typed_dicts if item.name in used_names)
378+
enums = tuple(item for item in spec.enums if item.name in used_names)
379+
type_definitions = _order_type_definitions((*route_aliases, *aliases), typed_dicts)
380+
blocks = [_format_enum(item) for item in enums] + [
311381
_format_type_definition(item) for item in type_definitions
312382
]
313383
return _render_template(
@@ -316,6 +386,26 @@ def _render_types(spec: NormalizedSpec, *, generate_routes: bool) -> str:
316386
)
317387

318388

389+
def rendered_type_definition_count(
390+
spec: NormalizedSpec,
391+
*,
392+
generate_routes: bool,
393+
generate_requests: bool,
394+
generate_responses: bool,
395+
) -> int:
396+
used_names = _used_type_names(
397+
spec,
398+
generate_requests=generate_requests,
399+
generate_responses=generate_responses,
400+
)
401+
return (
402+
len(_route_aliases(spec, generate_routes=generate_routes))
403+
+ sum(1 for item in spec.aliases if item.name in used_names)
404+
+ sum(1 for item in spec.typed_dicts if item.name in used_names)
405+
+ sum(1 for item in spec.enums if item.name in used_names)
406+
)
407+
408+
319409
def _literal_annotation(values: set[str]) -> LiteralAnnotation:
320410
return LiteralAnnotation(tuple(sorted(values)))
321411

@@ -393,13 +483,32 @@ def _render_client(
393483
is_async=True,
394484
)
395485
)
396-
method_overloads.setdefault(op.method, []).append(_method_overload_line(op))
486+
method_overloads.setdefault(op.method, []).append(
487+
_method_overload_line(
488+
op, return_type=_protocol_name(op), is_async=False
489+
)
490+
)
397491
async_method_overloads.setdefault(op.method, []).append(
398-
_method_overload_line(op, is_async=True)
492+
_method_overload_line(
493+
op,
494+
return_type=_protocol_name(op, is_async=True),
495+
is_async=True,
496+
)
399497
)
400498
else:
401-
method_overloads.setdefault(op.method, [])
402-
async_method_overloads.setdefault(op.method, [])
499+
overloads = method_overloads.setdefault(op.method, [])
500+
async_overloads = async_method_overloads.setdefault(op.method, [])
501+
if generate_routes:
502+
overloads.append(
503+
_method_overload_line(
504+
op, return_type="Callable[..., object]", is_async=False
505+
)
506+
)
507+
async_overloads.append(
508+
_method_overload_line(
509+
op, return_type="Callable[..., Awaitable[Any]]", is_async=True
510+
)
511+
)
403512

404513
method_blocks: list[str] = []
405514
for method in sorted(method_overloads):
@@ -459,7 +568,12 @@ def render_package(
459568
generate_responses: bool = True,
460569
) -> list[GeneratedArtifact]:
461570
context = {
462-
"types": _render_types(spec, generate_routes=generate_routes),
571+
"types": _render_types(
572+
spec,
573+
generate_routes=generate_routes,
574+
generate_requests=generate_requests,
575+
generate_responses=generate_responses,
576+
),
463577
"transport": _render_transport(spec, protocol_only=protocol_only),
464578
"client": _render_client(
465579
spec,
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
@overload
2-
def {{ op.method }}(self, route: Literal[{{ op.route_literal|repr }}]) -> {{ protocol_name }}: ...
2+
def {{ op.method }}(self, route: Literal[{{ op.route_literal|repr }}]) -> {{ return_type }}: ...

0 commit comments

Comments
 (0)