Skip to content

Commit 35ea72f

Browse files
authored
Fix mypy arg-type errors in generated discriminated union encoders (#175)
Why === The codegen for discriminated union TypedDict encoders produces ternary chains like: ```python encode_Foo(x) if x["kind"] == "foo" else encode_Bar(x) ``` mypy can't narrow union types through these ternary conditions, so it flags every encoder call as receiving the wrong type (`arg-type`). This broke the pid2 codegen CI when new discriminated union variants were added to a schema. What changed ============ Use `cast()` to explicitly narrow the type to the correct variant after the discriminator check, instead of suppressing with `# type: ignore[arg-type]`. This preserves type safety in the generated code. Before: ```python encode_Foo(x) # type: ignore[arg-type] if x["kind"] == "foo" else encode_Bar(x) # type: ignore[arg-type] ``` After: ```python encode_Foo(cast('Foo', x)) if x["kind"] == "foo" else encode_Bar(cast('Bar', x)) ``` Affects both the single-variant and multi-variant discriminator code paths. Test plan ========= CI
1 parent 6c7a537 commit 35ea72f

File tree

12 files changed

+23
-4
lines changed

12 files changed

+23
-4
lines changed

src/replit_river/codegen/client.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@
8080
Mapping,
8181
NotRequired,
8282
TypedDict,
83+
cast,
8384
)
8485
from typing_extensions import Annotated
8586
@@ -301,9 +302,10 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
301302
# "encoder_names" is only a TypedDict thing
302303
encoder_names.add(encoder_name)
303304
_field_name = render_literal_type(encoder_name)
305+
_type_name = render_literal_type(type_name)
304306
typeddict_encoder.append(
305307
f"""\
306-
{_field_name}(x) # type: ignore[arg-type]
308+
{_field_name}(cast('{_type_name}', x))
307309
""".strip()
308310
)
309311
if local_discriminators:
@@ -333,8 +335,10 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
333335
# TODO(dstewart): Figure out why uncommenting this breaks
334336
# generated code
335337
# encoder_names.add(encoder_name)
338+
_type_name = render_literal_type(type_name)
336339
typeddict_encoder.append(
337-
f"{render_literal_type(encoder_name)}(x)"
340+
f"{render_literal_type(encoder_name)}"
341+
f"(cast('{_type_name}', x))"
338342
)
339343
typeddict_encoder.append(
340344
f"""

tests/v1/codegen/rpc/generated/test_service/rpc_method.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Mapping,
88
NotRequired,
99
TypedDict,
10+
cast,
1011
)
1112
from typing_extensions import Annotated
1213

tests/v1/codegen/rpc/generated_special_chars/test_service/rpc_method.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Mapping,
88
NotRequired,
99
TypedDict,
10+
cast,
1011
)
1112
from typing_extensions import Annotated
1213

tests/v1/codegen/snapshot/snapshots/test_anyof_mixed_types/test_service/anyof_mixed_method.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Mapping,
88
NotRequired,
99
TypedDict,
10+
cast,
1011
)
1112
from typing_extensions import Annotated
1213

tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/emit_error.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Mapping,
88
NotRequired,
99
TypedDict,
10+
cast,
1011
)
1112
from typing_extensions import Annotated
1213

tests/v1/codegen/snapshot/snapshots/test_basic_stream/test_service/stream_method.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Mapping,
88
NotRequired,
99
TypedDict,
10+
cast,
1011
)
1112
from typing_extensions import Annotated
1213

tests/v1/codegen/snapshot/snapshots/test_pathological_types/test_service/pathological_method.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Mapping,
88
NotRequired,
99
TypedDict,
10+
cast,
1011
)
1112
from typing_extensions import Annotated
1213

tests/v1/codegen/snapshot/snapshots/test_recursive_types/recursiveService/getTree.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Mapping,
88
NotRequired,
99
TypedDict,
10+
cast,
1011
)
1112
from typing_extensions import Annotated
1213

tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnum.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Mapping,
88
NotRequired,
99
TypedDict,
10+
cast,
1011
)
1112
from typing_extensions import Annotated
1213

tests/v1/codegen/snapshot/snapshots/test_unknown_enum/enumService/needsEnumObject.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Mapping,
88
NotRequired,
99
TypedDict,
10+
cast,
1011
)
1112
from typing_extensions import Annotated
1213

@@ -71,9 +72,13 @@ def encode_NeedsenumobjectInput(
7172
x: "NeedsenumobjectInput",
7273
) -> Any:
7374
return (
74-
encode_NeedsenumobjectInputOneOf_in_first(x)
75+
encode_NeedsenumobjectInputOneOf_in_first(
76+
cast("NeedsenumobjectInputOneOf_in_first", x)
77+
)
7578
if x["kind"] == "in_first"
76-
else encode_NeedsenumobjectInputOneOf_in_second(x)
79+
else encode_NeedsenumobjectInputOneOf_in_second(
80+
cast("NeedsenumobjectInputOneOf_in_second", x)
81+
)
7782
)
7883

7984

0 commit comments

Comments
 (0)