Skip to content

Commit 3b3f740

Browse files
author
Anders Brams
committed
feat: ability to selectively generate protocol parts
1 parent a43266e commit 3b3f740

12 files changed

Lines changed: 346 additions & 39 deletions

File tree

README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ Generate a client from an OpenAPI spec in `openapi.json`:
2424
# Types + HTTP transport
2525
uv run openapi-python generate --spec ./openapi.json --out ./generated
2626

27-
# Types
28-
uv run openapi-python generate --spec ./openapi.json --out ./generated --transport-mode protocol-only
27+
# Types + custom transport protocol
28+
uv run openapi-python generate --spec ./openapi.json --out ./generated --protocol-only
2929
```
3030

3131
... or programatically:
@@ -87,7 +87,9 @@ book = client.get("/books/{book_id}")(params={"book_id": 1})
8787

8888
Generated clients expose a transport protocol. You can plug in your own transport while keeping route-level typing guarantees.
8989

90-
Use `--transport-mode protocol-only` to generate clients that require a supplied transport and do not emit the built-in `httpx` transport classes. The default `--transport-mode default` includes `DefaultTransport` and `DefaultAsyncTransport`, which require the `httpx` extra when instantiated.
90+
Use `--protocol-only` to generate clients that require a supplied transport and do not emit the built-in `httpx` transport classes. By default, generated clients include `DefaultTransport` and `DefaultAsyncTransport`, which require the `httpx` extra when instantiated.
91+
92+
Protocol typing can be relaxed independently with `--no-routes`, `--no-requests`, and `--no-responses`. Those flags replace the corresponding route literals, request payload types, or response types with broad catch-all types.
9193

9294
### Built-in `httpx` transport
9395

@@ -128,7 +130,7 @@ uv run openapi-python generate \
128130
--spec ./openapi.json \
129131
--out ./generated \
130132
--package my_client \
131-
--transport-mode protocol-only
133+
--protocol-only
132134
```
133135

134136
Then provide an object that satisfies the generated `Transport` protocol:

openapi_python/cli.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,27 @@ def _build_parser() -> argparse.ArgumentParser:
3030
help="Disable SSL certificate verification for URL specs",
3131
)
3232
generate.add_argument(
33-
"--transport-mode",
34-
choices=["default", "protocol-only"],
35-
default="default",
36-
help="Generation mode for transport integration",
33+
"--protocol-only",
34+
action="store_true",
35+
help="Require supplied transports instead of generating built-in httpx transports",
36+
)
37+
generate.add_argument(
38+
"--routes",
39+
action=argparse.BooleanOptionalAction,
40+
default=True,
41+
help="Generate route literal types and route-specific client overloads (DEFAULT: True)",
42+
)
43+
generate.add_argument(
44+
"--requests",
45+
action=argparse.BooleanOptionalAction,
46+
default=True,
47+
help="Generate typed params, query, header, and body protocol arguments (DEFAULT: True)",
48+
)
49+
generate.add_argument(
50+
"--responses",
51+
action=argparse.BooleanOptionalAction,
52+
default=True,
53+
help="Generate typed protocol response values (DEFAULT: True)",
3754
)
3855
return parser
3956

@@ -53,7 +70,10 @@ def main(argv: list[str] | None = None) -> int:
5370
package_name=args.package,
5471
overwrite=args.overwrite,
5572
verify_ssl=not args.no_ssl,
56-
transport_mode=args.transport_mode,
73+
protocol_only=args.protocol_only,
74+
generate_routes=args.routes,
75+
generate_requests=args.requests,
76+
generate_responses=args.responses,
5777
)
5878
result = try_generate_client(request)
5979

openapi_python/generator/api.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ class GenerationRequest:
1919
package_name: str = "my_client"
2020
overwrite: bool = False
2121
verify_ssl: bool = True
22-
transport_mode: str = "default"
22+
protocol_only: bool = False
23+
generate_routes: bool = True
24+
generate_requests: bool = True
25+
generate_responses: bool = True
2326
extensions: GeneratorExtensions | None = None
2427
spec_json: str | None = None
2528

@@ -38,8 +41,6 @@ def generate_client(request: GenerationRequest) -> GenerationResult:
3841
raise invalid_request("Exactly one of spec_source or spec_json is required")
3942
if not request.package_name:
4043
raise invalid_request("package_name is required")
41-
if request.transport_mode not in {"default", "protocol-only"}:
42-
raise invalid_request("transport_mode must be 'default' or 'protocol-only'")
4344

4445
if request.spec_json is not None:
4546
document = load_openapi_json(request.spec_json)
@@ -57,7 +58,12 @@ def generate_client(request: GenerationRequest) -> GenerationResult:
5758
normalized = candidate
5859

5960
artifacts = render_package(
60-
normalized, request.extensions, transport_mode=request.transport_mode
61+
normalized,
62+
request.extensions,
63+
protocol_only=request.protocol_only,
64+
generate_routes=request.generate_routes,
65+
generate_requests=request.generate_requests,
66+
generate_responses=request.generate_responses,
6167
)
6268
written = write_artifacts(
6369
output_dir=request.output_dir, artifacts=artifacts, overwrite=request.overwrite

openapi_python/generator/render.py

Lines changed: 77 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,15 @@ def _format_type_definition(defn: TypeAliasDef | TypedDictDef) -> str:
225225
return _format_typeddict(defn)
226226

227227

228-
def _call_parameters(op: OperationDef) -> dict[str, str]:
228+
def _call_parameters(op: OperationDef, *, generate_requests: bool) -> dict[str, str]:
229+
if not generate_requests:
230+
return {
231+
"params": "params: dict[str, Any] | None = None",
232+
"query": "query: dict[str, Any] | None = None",
233+
"headers": "headers: dict[str, Any] | None = None",
234+
"body": "body: object | None = None",
235+
}
236+
229237
params = "params: " + _render_annotation(op.params_type)
230238
if not op.params_required:
231239
params += " | None = None"
@@ -256,13 +264,22 @@ def _protocol_name(op: OperationDef, *, is_async: bool = False) -> str:
256264
return f"Async{op.protocol_name}" if is_async else op.protocol_name
257265

258266

259-
def _protocol_block(op: OperationDef, *, is_async: bool = False) -> str:
267+
def _protocol_block(
268+
op: OperationDef,
269+
*,
270+
generate_requests: bool,
271+
generate_responses: bool,
272+
is_async: bool = False,
273+
) -> str:
260274
return _render_template(
261275
"protocol.py.j2",
262276
op=op,
263277
is_async=is_async,
264278
protocol_name=_protocol_name(op, is_async=is_async),
265-
call_parameters=_call_parameters(op),
279+
call_parameters=_call_parameters(op, generate_requests=generate_requests),
280+
response_type=(
281+
_render_annotation(op.response_type) if generate_responses else "Any"
282+
),
266283
)
267284

268285

@@ -287,8 +304,8 @@ def _fallback_method_block(
287304
)
288305

289306

290-
def _render_types(spec: NormalizedSpec) -> str:
291-
aliases = (*_route_aliases(spec), *spec.aliases)
307+
def _render_types(spec: NormalizedSpec, *, generate_routes: bool) -> str:
308+
aliases = (*_route_aliases(spec, generate_routes=generate_routes), *spec.aliases)
292309
type_definitions = _order_type_definitions(aliases, spec.typed_dicts)
293310
blocks = [_format_enum(item) for item in spec.enums] + [
294311
_format_type_definition(item) for item in type_definitions
@@ -303,7 +320,12 @@ def _literal_annotation(values: set[str]) -> LiteralAnnotation:
303320
return LiteralAnnotation(tuple(sorted(values)))
304321

305322

306-
def _route_aliases(spec: NormalizedSpec) -> tuple[TypeAliasDef, ...]:
323+
def _route_aliases(
324+
spec: NormalizedSpec, *, generate_routes: bool
325+
) -> tuple[TypeAliasDef, ...]:
326+
if not generate_routes:
327+
return (TypeAliasDef(name="RouteLiteral", annotation=NamedAnnotation("str")),)
328+
307329
routes_by_method: dict[str, set[str]] = {}
308330
for op in spec.operations:
309331
routes_by_method.setdefault(op.method.upper(), set()).add(op.route_literal)
@@ -331,28 +353,50 @@ def _route_aliases(spec: NormalizedSpec) -> tuple[TypeAliasDef, ...]:
331353
return tuple(aliases)
332354

333355

334-
def _render_transport(spec: NormalizedSpec, *, transport_mode: str) -> str:
356+
def _render_transport(spec: NormalizedSpec, *, protocol_only: bool) -> str:
335357
return _render_template(
336358
"transport.py.j2",
337-
typing_imports=(
338-
"TYPE_CHECKING, Protocol" if transport_mode == "default" else "Protocol"
339-
),
340-
include_default_transport=transport_mode == "default",
359+
typing_imports="Protocol" if protocol_only else "TYPE_CHECKING, Protocol",
360+
include_default_transport=not protocol_only,
341361
)
342362

343363

344-
def _render_client(spec: NormalizedSpec, *, transport_mode: str) -> str:
364+
def _render_client(
365+
spec: NormalizedSpec,
366+
*,
367+
protocol_only: bool,
368+
generate_routes: bool,
369+
generate_requests: bool,
370+
generate_responses: bool,
371+
) -> str:
345372
protocols: list[str] = []
346373
async_protocols: list[str] = []
347374
method_overloads: dict[str, list[str]] = {}
348375
async_method_overloads: dict[str, list[str]] = {}
349376
for op in spec.operations:
350-
protocols.append(_protocol_block(op))
351-
async_protocols.append(_protocol_block(op, is_async=True))
352-
method_overloads.setdefault(op.method, []).append(_method_overload_line(op))
353-
async_method_overloads.setdefault(op.method, []).append(
354-
_method_overload_line(op, is_async=True)
355-
)
377+
if generate_routes:
378+
protocols.append(
379+
_protocol_block(
380+
op,
381+
generate_requests=generate_requests,
382+
generate_responses=generate_responses,
383+
)
384+
)
385+
async_protocols.append(
386+
_protocol_block(
387+
op,
388+
generate_requests=generate_requests,
389+
generate_responses=generate_responses,
390+
is_async=True,
391+
)
392+
)
393+
method_overloads.setdefault(op.method, []).append(_method_overload_line(op))
394+
async_method_overloads.setdefault(op.method, []).append(
395+
_method_overload_line(op, is_async=True)
396+
)
397+
else:
398+
method_overloads.setdefault(op.method, [])
399+
async_method_overloads.setdefault(op.method, [])
356400

357401
method_blocks: list[str] = []
358402
for method in sorted(method_overloads):
@@ -372,7 +416,7 @@ def _render_client(spec: NormalizedSpec, *, transport_mode: str) -> str:
372416
)
373417
)
374418

375-
if transport_mode == "default":
419+
if not protocol_only:
376420
transport_imports = (
377421
"from .transport import AsyncTransport, DefaultAsyncTransport, "
378422
"DefaultTransport, Transport"
@@ -406,20 +450,29 @@ def render_package(
406450
spec: NormalizedSpec,
407451
extensions: GeneratorExtensions | None = None,
408452
*,
409-
transport_mode: str = "default",
453+
protocol_only: bool = False,
454+
generate_routes: bool = True,
455+
generate_requests: bool = True,
456+
generate_responses: bool = True,
410457
) -> list[GeneratedArtifact]:
411458
context = {
412-
"types": _render_types(spec),
413-
"transport": _render_transport(spec, transport_mode=transport_mode),
414-
"client": _render_client(spec, transport_mode=transport_mode),
459+
"types": _render_types(spec, generate_routes=generate_routes),
460+
"transport": _render_transport(spec, protocol_only=protocol_only),
461+
"client": _render_client(
462+
spec,
463+
protocol_only=protocol_only,
464+
generate_routes=generate_routes,
465+
generate_requests=generate_requests,
466+
generate_responses=generate_responses,
467+
),
415468
}
416469

417470
if extensions:
418471
for hook in extensions.render_context_hooks:
419472
context = hook(spec, context)
420473

421474
init_content = _render_template(
422-
"init.py.j2", include_default_transport=transport_mode == "default"
475+
"init.py.j2", include_default_transport=not protocol_only
423476
)
424477

425478
return [

openapi_python/generator/templates/method_block.py.j2

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
{{ overloads }}
2+
{% if overloads %}
23
@overload
34
def {{ method }}(self, route: str) -> Callable[..., {{ callable_return }}]: ...
5+
{% endif %}
46
def {{ method }}(self, route: str) -> Callable[..., {{ callable_return }}]:
5-
{{ "async " if is_async else "" }}def _call(*, params: dict[str, object] | None = None, query: dict[str, object] | None = None, headers: dict[str, object] | None = None, body: object | None = None) -> {{ call_return }}:
7+
{{ "async " if is_async else "" }}def _call(*, params: dict[str, Any] | None = None, query: dict[str, Any] | None = None, headers: dict[str, Any] | None = None, body: object | None = None) -> {{ call_return }}:
68
return {{ "await " if is_async else "" }}self._transport.request(
79
method={{ method|repr }},
810
route=route,
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
class {{ protocol_name }}(Protocol):
2-
{{ "async " if is_async else "" }}def __call__(self, *, {{ call_parameters.params }}, {{ call_parameters.query }}, {{ call_parameters.headers }}, {{ call_parameters.body }}) -> {{ op.response_type | annotation }}: ...
2+
{{ "async " if is_async else "" }}def __call__(self, *, {{ call_parameters.params }}, {{ call_parameters.query }}, {{ call_parameters.headers }}, {{ call_parameters.body }}) -> {{ response_type }}: ...
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from __future__ import annotations
2+
3+
import json
4+
from pathlib import Path
5+
6+
from openapi_python.generator import GenerationRequest, generate_client
7+
8+
SPEC = {
9+
"openapi": "3.1.0",
10+
"info": {"title": "Protocol Flags", "version": "1.0.0"},
11+
"paths": {
12+
"/books/{book_id}": {
13+
"get": {
14+
"parameters": [
15+
{
16+
"name": "book_id",
17+
"in": "path",
18+
"required": True,
19+
"schema": {"type": "integer"},
20+
}
21+
],
22+
"responses": {
23+
"200": {
24+
"description": "OK",
25+
"content": {
26+
"application/json": {
27+
"schema": {"$ref": "#/components/schemas/Book"}
28+
}
29+
},
30+
}
31+
},
32+
}
33+
}
34+
},
35+
"components": {
36+
"schemas": {
37+
"Book": {
38+
"type": "object",
39+
"required": ["id", "title"],
40+
"properties": {
41+
"id": {"type": "integer"},
42+
"title": {"type": "string"},
43+
},
44+
}
45+
}
46+
},
47+
}
48+
49+
50+
def main() -> None:
51+
generate_client(
52+
GenerationRequest(
53+
output_dir=Path(__file__).parent / "generated",
54+
spec_json=json.dumps(SPEC),
55+
overwrite=True,
56+
generate_requests=False,
57+
)
58+
)
59+
60+
61+
if __name__ == "__main__":
62+
main()
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from __future__ import annotations
2+
3+
from typing import assert_type
4+
5+
from generated.my_client import Client
6+
from generated.my_client.types import Book
7+
8+
client = Client(base_url="http://testserver")
9+
10+
book = client.get("/books/{book_id}")(
11+
params={"book_id": "not statically constrained"},
12+
query={"unexpected": object()},
13+
headers={"x-test": object()},
14+
)
15+
assert_type(book, Book)
16+
assert_type(book["title"], str)

0 commit comments

Comments
 (0)