|
1 | 1 | import re |
2 | 2 | from inspect import isawaitable |
| 3 | +from typing import Any, Awaitable, cast |
3 | 4 |
|
4 | 5 | from pytest import mark # type: ignore |
5 | 6 |
|
6 | | -from graphql.execution import execute |
| 7 | +from graphql.execution import execute, ExecutionResult |
7 | 8 | from graphql.language import parse |
| 9 | +from graphql.pyutils import AwaitableOrValue |
8 | 10 | from graphql.type import ( |
9 | 11 | GraphQLArgument, |
10 | 12 | GraphQLField, |
@@ -95,22 +97,25 @@ async def promiseNonNullNest(self, _info): |
95 | 97 | ) |
96 | 98 |
|
97 | 99 |
|
98 | | -def execute_query(query, root_value): |
| 100 | +def execute_query(query: str, root_value: Any) -> AwaitableOrValue[ExecutionResult]: |
99 | 101 | return execute(schema=schema, document=parse(query), root_value=root_value) |
100 | 102 |
|
101 | 103 |
|
102 | 104 | # avoids also doing any nests |
103 | | -def patch(data): |
| 105 | +def patch(data: str) -> str: |
104 | 106 | return re.sub( |
105 | 107 | r"\bsyncNonNull\b", "promiseNonNull", re.sub(r"\bsync\b", "promise", data) |
106 | 108 | ) |
107 | 109 |
|
108 | 110 |
|
109 | | -async def execute_sync_and_async(query, root_value): |
| 111 | +async def execute_sync_and_async(query: str, root_value: Any) -> ExecutionResult: |
110 | 112 | sync_result = execute_query(query, root_value) |
111 | 113 | if isawaitable(sync_result): |
112 | | - sync_result = await sync_result |
113 | | - async_result = await execute_query(patch(query), root_value) |
| 114 | + sync_result = await cast(Awaitable[ExecutionResult], sync_result) |
| 115 | + sync_result = cast(ExecutionResult, sync_result) |
| 116 | + async_result = await cast( |
| 117 | + Awaitable[ExecutionResult], execute_query(patch(query), root_value) |
| 118 | + ) |
114 | 119 |
|
115 | 120 | assert repr(async_result) == patch(repr(sync_result)) |
116 | 121 | return sync_result |
@@ -254,12 +259,16 @@ def describe_nulls_a_complex_tree_of_nullable_fields_each(): |
254 | 259 |
|
255 | 260 | @mark.asyncio |
256 | 261 | async def returns_null(): |
257 | | - result = await execute_query(query, NullingData()) |
| 262 | + result = await cast( |
| 263 | + Awaitable[ExecutionResult], execute_query(query, NullingData()) |
| 264 | + ) |
258 | 265 | assert result == (data, None) |
259 | 266 |
|
260 | 267 | @mark.asyncio |
261 | 268 | async def throws(): |
262 | | - result = await execute_query(query, ThrowingData()) |
| 269 | + result = await cast( |
| 270 | + Awaitable[ExecutionResult], execute_query(query, ThrowingData()) |
| 271 | + ) |
263 | 272 | assert result == ( |
264 | 273 | data, |
265 | 274 | [ |
@@ -384,7 +393,9 @@ def describe_nulls_first_nullable_after_long_chain_of_non_null_fields(): |
384 | 393 |
|
385 | 394 | @mark.asyncio |
386 | 395 | async def returns_null(): |
387 | | - result = await execute_query(query, NullingData()) |
| 396 | + result = await cast( |
| 397 | + Awaitable[ExecutionResult], execute_query(query, NullingData()) |
| 398 | + ) |
388 | 399 | assert result == ( |
389 | 400 | data, |
390 | 401 | [ |
@@ -445,7 +456,9 @@ async def returns_null(): |
445 | 456 |
|
446 | 457 | @mark.asyncio |
447 | 458 | async def throws(): |
448 | | - result = await execute_query(query, ThrowingData()) |
| 459 | + result = await cast( |
| 460 | + Awaitable[ExecutionResult], execute_query(query, ThrowingData()) |
| 461 | + ) |
449 | 462 | assert result == ( |
450 | 463 | data, |
451 | 464 | [ |
|
0 commit comments