Skip to content

Commit 7da4e00

Browse files
committed
Refine test_requests_batch.py
1 parent 21963e6 commit 7da4e00

File tree

1 file changed

+162
-24
lines changed

1 file changed

+162
-24
lines changed

tests/test_requests_batch.py

Lines changed: 162 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -73,14 +73,28 @@ def test_code():
7373

7474
query = [GraphQLRequest(query1_str)]
7575

76-
# Execute query synchronously
77-
results = session.execute_batch(query)
76+
# Execute query synchronously with timeout and get_execution_result=True
77+
execution_results = session.execute_batch(
78+
query, get_execution_result=True, timeout=10
79+
)
7880

79-
continents = results[0]["continents"]
81+
# Verify the ExecutionResult structure
82+
assert len(execution_results) == 1
83+
result = execution_results[0]
8084

81-
africa = continents[0]
85+
# Verify errors is None (not just missing)
86+
assert result.errors is None
87+
# Verify extensions is None (response has no extensions)
88+
assert result.extensions is None
89+
90+
# Verify data content
91+
assert result.data is not None
92+
continents = result.data["continents"]
93+
assert len(continents) == 7
8294

95+
africa = continents[0]
8396
assert africa["code"] == "AF"
97+
assert africa["name"] == "Africa"
8498

8599
# Checking response headers are saved in the transport
86100
assert hasattr(transport, "response_headers")
@@ -90,6 +104,55 @@ def test_code():
90104
await run_sync_test(server, test_code)
91105

92106

107+
@pytest.mark.aiohttp
108+
@pytest.mark.asyncio
109+
async def test_requests_query_two_requests(aiohttp_server, run_sync_test):
110+
from aiohttp import web
111+
112+
from gql.transport.requests import RequestsHTTPTransport
113+
114+
async def handler(request):
115+
return web.Response(
116+
text=query1_server_answer_twice_list,
117+
content_type="application/json",
118+
)
119+
120+
app = web.Application()
121+
app.router.add_route("POST", "/", handler)
122+
server = await aiohttp_server(app)
123+
124+
url = server.make_url("/")
125+
126+
def test_code():
127+
transport = RequestsHTTPTransport(url=url)
128+
129+
with Client(transport=transport) as session:
130+
131+
# Two requests in batch
132+
query = [GraphQLRequest(query1_str), GraphQLRequest(query1_str)]
133+
134+
# Execute query synchronously with timeout and get_execution_result=True
135+
execution_results = session.execute_batch(
136+
query, get_execution_result=True, timeout=10
137+
)
138+
139+
# Verify we got exactly 2 results
140+
assert len(execution_results) == 2
141+
142+
# Verify both results
143+
for result in execution_results:
144+
assert result.errors is None
145+
assert result.extensions is None
146+
assert result.data is not None
147+
148+
continents = result.data["continents"]
149+
assert len(continents) == 7
150+
assert continents[0]["code"] == "AF"
151+
assert continents[0]["name"] == "Africa"
152+
153+
await run_sync_test(server, test_code)
154+
155+
93156
@pytest.mark.aiohttp
94157
@pytest.mark.asyncio
95158
async def test_requests_query_auto_batch_enabled(aiohttp_server, run_sync_test):
@@ -227,14 +290,22 @@ def test_code():
227290

228291
query = [GraphQLRequest(query1_str)]
229292

230-
# Execute query synchronously
231-
results = session.execute_batch(query)
293+
# Execute query synchronously with timeout and get_execution_result=True
294+
execution_results = session.execute_batch(
295+
query, get_execution_result=True, timeout=10
296+
)
232297

233-
continents = results[0]["continents"]
298+
result = execution_results[0]
234299

235-
africa = continents[0]
300+
# Verify ExecutionResult structure
301+
assert result.errors is None
302+
assert result.extensions is None
303+
assert result.data is not None
236304

305+
continents = result.data["continents"]
306+
africa = continents[0]
237307
assert africa["code"] == "AF"
308+
assert africa["name"] == "Africa"
238309

239310
await run_sync_test(server, test_code)
240311

@@ -271,6 +342,7 @@ def test_code():
271342
session.execute_batch(query)
272343

273344
assert "401 Client Error: Unauthorized" in str(exc_info.value)
345+
assert exc_info.value.code == 401
274346

275347
await run_sync_test(server, test_code)
276348

@@ -312,6 +384,7 @@ def test_code():
312384
session.execute(request)
313385

314386
assert "401 Client Error: Unauthorized" in str(exc_info.value)
387+
assert exc_info.value.code == 401
315388

316389
await run_sync_test(server, test_code)
317390

@@ -359,6 +432,7 @@ def test_code():
359432
session.execute_batch(query)
360433

361434
assert "429, message='Too Many Requests'" in str(exc_info.value)
435+
assert exc_info.value.code == 429
362436

363437
# Checking response headers are saved in the transport
364438
assert hasattr(transport, "response_headers")
@@ -390,9 +464,13 @@ def test_code():
390464

391465
query = [GraphQLRequest(query1_str)]
392466

393-
with pytest.raises(TransportServerError):
467+
with pytest.raises(TransportServerError) as exc_info:
394468
session.execute_batch(query)
395469

470+
# Verify the error contains 500 status code
471+
assert "500" in str(exc_info.value)
472+
assert exc_info.value.code == 500
473+
396474
await run_sync_test(server, test_code)
397475

398476

@@ -424,28 +502,71 @@ def test_code():
424502

425503
query = [GraphQLRequest(query1_str)]
426504

427-
with pytest.raises(TransportQueryError):
505+
with pytest.raises(TransportQueryError) as exc_info:
428506
session.execute_batch(query)
429507

508+
# Verify error message
509+
assert str(exc_info.value) == "Error 1"
510+
511+
# Verify errors list from exception
512+
assert exc_info.value.errors == ["Error 1", "Error 2"]
513+
514+
# Verify data is None (no data in response)
515+
assert exc_info.value.data is None
516+
517+
# Verify extensions is None (no extensions in response)
518+
assert exc_info.value.extensions is None
519+
430520
await run_sync_test(server, test_code)
431521

432522

433523
invalid_protocol_responses = [
434-
"{}",
435-
"qlsjfqsdlkj",
436-
'{"not_data_or_errors": 35}',
437-
"[{}]",
438-
"[qlsjfqsdlkj]",
439-
'[{"not_data_or_errors": 35}]',
440-
"[]",
441-
"[1]",
524+
(
525+
"{}",
526+
"Server did not return a valid GraphQL result: Answer is not a list: {}",
527+
),
528+
(
529+
"qlsjfqsdlkj",
530+
"Server did not return a GraphQL result: Not a JSON answer: qlsjfqsdlkj",
531+
),
532+
(
533+
'{"not_data_or_errors": 35}',
534+
"Server did not return a valid GraphQL result: "
535+
"Answer is not a list: {'not_data_or_errors': 35}",
536+
),
537+
(
538+
"[{}]",
539+
"Server did not return a valid GraphQL result: "
540+
'No "data" or "errors" keys in answer: [{}]',
541+
),
542+
(
543+
"[qlsjfqsdlkj]",
544+
"Server did not return a GraphQL result: Not a JSON answer: [qlsjfqsdlkj]",
545+
),
546+
(
547+
'[{"not_data_or_errors": 35}]',
548+
"Server did not return a valid GraphQL result: "
549+
'No "data" or "errors" keys in answer: [{\'not_data_or_errors\': 35}]',
550+
),
551+
(
552+
"[]",
553+
"Server did not return a valid GraphQL result: "
554+
"Invalid number of answers: 0 answers received for 1 requests: []",
555+
),
556+
(
557+
"[1]",
558+
"Server did not return a valid GraphQL result: "
559+
"Not every answer is dict: [1]",
560+
),
442561
]
443562

444563

445564
@pytest.mark.aiohttp
446565
@pytest.mark.asyncio
447-
@pytest.mark.parametrize("response", invalid_protocol_responses)
448-
async def test_requests_invalid_protocol(aiohttp_server, response, run_sync_test):
566+
@pytest.mark.parametrize("response,expected_message", invalid_protocol_responses)
567+
async def test_requests_invalid_protocol(
568+
aiohttp_server, response, expected_message, run_sync_test
569+
):
449570
from aiohttp import web
450571

451572
from gql.transport.requests import RequestsHTTPTransport
@@ -466,9 +587,11 @@ def test_code():
466587

467588
query = [GraphQLRequest(query1_str)]
468589

469-
with pytest.raises(TransportProtocolError):
590+
with pytest.raises(TransportProtocolError) as exc_info:
470591
session.execute_batch(query)
471592

593+
assert str(exc_info.value) == expected_message
594+
472595
await run_sync_test(server, test_code)
473596

474597

@@ -495,9 +618,11 @@ def test_code():
495618

496619
query = [GraphQLRequest(query1_str)]
497620

498-
with pytest.raises(TransportClosed):
621+
with pytest.raises(TransportClosed) as exc_info:
499622
transport.execute_batch(query)
500623

624+
assert str(exc_info.value) == "Transport is not connected"
625+
501626
await run_sync_test(server, test_code)
502627

503628

@@ -538,9 +663,22 @@ def test_code():
538663

539664
query = [GraphQLRequest(query1_str)]
540665

541-
execution_results = session.execute_batch(query, get_execution_result=True)
666+
execution_results = session.execute_batch(
667+
query, get_execution_result=True, timeout=10
668+
)
669+
670+
result = execution_results[0]
671+
672+
# Verify errors is None
673+
assert result.errors is None
674+
675+
# Verify data content
676+
assert result.data is not None
677+
assert result.data["continents"][0]["code"] == "AF"
542678

543-
assert execution_results[0].extensions["key1"] == "val1"
679+
# Verify extensions
680+
assert result.extensions is not None
681+
assert result.extensions["key1"] == "val1"
544682

545683
await run_sync_test(server, test_code)
546684

0 commit comments

Comments
 (0)