Skip to content

Commit 329f09c

Browse files
committed
improve type hints, apply type related fixes
1 parent 3961aa7 commit 329f09c

File tree

1 file changed

+39
-27
lines changed

1 file changed

+39
-27
lines changed

pytest_httpserver/httpserver.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from typing import Pattern
2525
from typing import Tuple
2626
from typing import Union
27+
from typing import cast
28+
from wsgiref.types import WSGIApplication
2729

2830
import werkzeug.http
2931
from werkzeug import Request
@@ -113,11 +115,13 @@ def complete(self, result: bool): # noqa: FBT001
113115

114116
@property
115117
def result(self) -> bool:
116-
return self._result
118+
return bool(self._result)
117119

118120
@property
119121
def elapsed_time(self) -> float:
120122
"""Elapsed time in seconds"""
123+
if self._stop is None:
124+
raise TypeError("unsupported operand type(s) for -: 'NoneType' and 'float'")
121125
return self._stop - self._start
122126

123127

@@ -139,7 +143,8 @@ def authorization_header_value_matcher(actual: str | None, expected: str) -> boo
139143
func = getattr(Authorization, "from_header", None)
140144
if func is None: # Werkzeug < 2.3.0
141145
func = werkzeug.http.parse_authorization_header # type: ignore[attr-defined]
142-
return func(actual) == func(expected)
146+
147+
return func(actual) == func(expected) # type: ignore
143148

144149
@staticmethod
145150
def default_header_value_matcher(actual: str | None, expected: str) -> bool:
@@ -174,7 +179,7 @@ def match(self, request_query_string: bytes) -> bool:
174179
return values[0] == values[1]
175180

176181
@abc.abstractmethod
177-
def get_comparing_values(self, request_query_string: bytes) -> tuple:
182+
def get_comparing_values(self, request_query_string: bytes) -> tuple[Any, Any]:
178183
pass
179184

180185

@@ -195,10 +200,10 @@ def __init__(self, query_string: bytes | str):
195200

196201
self.query_string = query_string
197202

198-
def get_comparing_values(self, request_query_string: bytes) -> tuple:
203+
def get_comparing_values(self, request_query_string: bytes) -> tuple[bytes, bytes]:
199204
if isinstance(self.query_string, str):
200205
query_string = self.query_string.encode()
201-
elif isinstance(self.query_string, bytes):
206+
elif isinstance(self.query_string, bytes): # type: ignore
202207
query_string = self.query_string
203208
else:
204209
raise TypeError("query_string must be a string, or a bytes-like object")
@@ -211,7 +216,7 @@ class MappingQueryMatcher(QueryMatcher):
211216
Matches a query string to a dictionary or MultiDict specified
212217
"""
213218

214-
def __init__(self, query_dict: Mapping | MultiDict):
219+
def __init__(self, query_dict: Mapping[str, str] | MultiDict[str, str]):
215220
"""
216221
:param query_dict: if dictionary (Mapping) is specified, it will be used as a
217222
key-value mapping where both key and value should be string. If there are multiple
@@ -221,7 +226,7 @@ def __init__(self, query_dict: Mapping | MultiDict):
221226
"""
222227
self.query_dict = query_dict
223228

224-
def get_comparing_values(self, request_query_string: bytes) -> tuple:
229+
def get_comparing_values(self, request_query_string: bytes) -> tuple[Mapping[str, str], Mapping[str, str]]:
225230
query = MultiDict(urllib.parse.parse_qsl(request_query_string.decode("utf-8")))
226231
if isinstance(self.query_dict, MultiDict):
227232
return (query, self.query_dict)
@@ -241,14 +246,14 @@ def __init__(self, result: bool): # noqa: FBT001
241246
"""
242247
self.result = result
243248

244-
def get_comparing_values(self, request_query_string): # noqa: ARG002
249+
def get_comparing_values(self, request_query_string: bytes): # noqa: ARG002
245250
if self.result:
246251
return (True, True)
247252
else:
248253
return (True, False)
249254

250255

251-
def _create_query_matcher(query_string: None | QueryMatcher | str | bytes | Mapping) -> QueryMatcher:
256+
def _create_query_matcher(query_string: None | QueryMatcher | str | bytes | Mapping[str, str]) -> QueryMatcher:
252257
if isinstance(query_string, QueryMatcher):
253258
return query_string
254259

@@ -312,7 +317,7 @@ def __init__(
312317
data: str | bytes | None = None,
313318
data_encoding: str = "utf-8",
314319
headers: Mapping[str, str] | None = None,
315-
query_string: None | QueryMatcher | str | bytes | Mapping = None,
320+
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
316321
header_value_matcher: HVMATCHER_T | None = None,
317322
json: Any = UNDEFINED,
318323
):
@@ -410,7 +415,7 @@ def match_json(self, request: Request) -> bool:
410415

411416
return json_received == self.json
412417

413-
def difference(self, request: Request) -> list[tuple]:
418+
def difference(self, request: Request) -> list[tuple[str, str, str | URIPattern]]:
414419
"""
415420
Calculates the difference between the matcher and the request.
416421
@@ -422,7 +427,7 @@ def difference(self, request: Request) -> list[tuple]:
422427
matches the fields set in the matcher object.
423428
"""
424429

425-
retval: list[tuple] = []
430+
retval: list[tuple[str, Any, Any]] = []
426431

427432
if not self.match_uri(request):
428433
retval.append(("uri", request.path, self.uri))
@@ -433,8 +438,8 @@ def difference(self, request: Request) -> list[tuple]:
433438
if not self.query_matcher.match(request.query_string):
434439
retval.append(("query_string", request.query_string, self.query_string))
435440

436-
request_headers = {}
437-
expected_headers = {}
441+
request_headers: dict[str, str | None] = {}
442+
expected_headers: dict[str, str] = {}
438443
for key, value in self.headers.items():
439444
if not self.header_value_matcher(key, request.headers.get(key), value):
440445
request_headers[key] = request.headers.get(key)
@@ -467,7 +472,7 @@ class RequestHandlerBase(abc.ABC):
467472

468473
def respond_with_json(
469474
self,
470-
response_json,
475+
response_json: Any,
471476
status: int = 200,
472477
headers: Mapping[str, str] | None = None,
473478
content_type: str = "application/json",
@@ -578,7 +583,7 @@ def __repr__(self) -> str:
578583
return retval
579584

580585

581-
class RequestHandlerList(list):
586+
class RequestHandlerList(list[RequestHandler]):
582587
"""
583588
Represents a list of :py:class:`RequestHandler` objects.
584589
@@ -640,7 +645,7 @@ def __init__(
640645
self.port = port
641646
self.server = None
642647
self.server_thread = None
643-
self.assertions: list[str] = []
648+
self.assertions: list[str | AssertionError] = []
644649
self.handler_errors: list[Exception] = []
645650
self.log: list[tuple[Request, Response]] = []
646651
self.ssl_context = ssl_context
@@ -727,7 +732,7 @@ def thread_target(self):
727732
728733
This should not be called directly, but can be overridden to tailor it to your needs.
729734
"""
730-
735+
assert self.server is not None
731736
self.server.serve_forever()
732737

733738
def is_running(self) -> bool:
@@ -755,8 +760,13 @@ def start(self):
755760
if self.is_running():
756761
raise HTTPServerError("Server is already running")
757762

763+
app = cast(WSGIApplication, self.application)
758764
self.server = make_server(
759-
self.host, self.port, self.application, ssl_context=self.ssl_context, threaded=self.threaded
765+
self.host,
766+
self.port,
767+
app,
768+
ssl_context=self.ssl_context,
769+
threaded=self.threaded,
760770
)
761771
self.port = self.server.port # Update port (needed if `port` was set to 0)
762772
self.server_thread = threading.Thread(target=self.thread_target)
@@ -772,14 +782,16 @@ def stop(self):
772782
Only a running server can be stopped. If the sever is not running, :py:class`HTTPServerError`
773783
will be raised.
774784
"""
785+
assert self.server is not None
786+
assert self.server_thread is not None
775787
if not self.is_running():
776788
raise HTTPServerError("Server is not running")
777789
self.server.shutdown()
778790
self.server_thread.join()
779791
self.server = None
780792
self.server_thread = None
781793

782-
def add_assertion(self, obj):
794+
def add_assertion(self, obj: str | AssertionError):
783795
"""
784796
Add a new assertion
785797
@@ -849,7 +861,7 @@ def dispatch(self, request: Request) -> Response:
849861
"""
850862

851863
@Request.application # type: ignore
852-
def application(self, request: Request):
864+
def application(self, request: Request) -> Response:
853865
"""
854866
Entry point of werkzeug.
855867
@@ -886,7 +898,7 @@ def __exit__(self, *args, **kwargs):
886898
self.stop()
887899

888900
@staticmethod
889-
def format_host(host):
901+
def format_host(host: str):
890902
"""
891903
Formats a hostname so it can be used in a URL.
892904
Notably, this adds brackets around IPV6 addresses when
@@ -929,8 +941,8 @@ class HTTPServer(HTTPServerBase): # pylint: disable=too-many-instance-attribute
929941

930942
def __init__(
931943
self,
932-
host=DEFAULT_LISTEN_HOST,
933-
port=DEFAULT_LISTEN_PORT,
944+
host: str = DEFAULT_LISTEN_HOST,
945+
port: int = DEFAULT_LISTEN_PORT,
934946
ssl_context: SSLContext | None = None,
935947
default_waiting_settings: WaitingSettings | None = None,
936948
*,
@@ -979,7 +991,7 @@ def expect_request(
979991
data: str | bytes | None = None,
980992
data_encoding: str = "utf-8",
981993
headers: Mapping[str, str] | None = None,
982-
query_string: None | QueryMatcher | str | bytes | Mapping = None,
994+
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
983995
header_value_matcher: HVMATCHER_T | None = None,
984996
handler_type: HandlerType = HandlerType.PERMANENT,
985997
json: Any = UNDEFINED,
@@ -1062,7 +1074,7 @@ def expect_oneshot_request(
10621074
data: str | bytes | None = None,
10631075
data_encoding: str = "utf-8",
10641076
headers: Mapping[str, str] | None = None,
1065-
query_string: None | QueryMatcher | str | bytes | Mapping = None,
1077+
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
10661078
header_value_matcher: HVMATCHER_T | None = None,
10671079
json: Any = UNDEFINED,
10681080
) -> RequestHandler:
@@ -1117,7 +1129,7 @@ def expect_ordered_request(
11171129
data: str | bytes | None = None,
11181130
data_encoding: str = "utf-8",
11191131
headers: Mapping[str, str] | None = None,
1120-
query_string: None | QueryMatcher | str | bytes | Mapping = None,
1132+
query_string: None | QueryMatcher | str | bytes | Mapping[str, str] = None,
11211133
header_value_matcher: HVMATCHER_T | None = None,
11221134
json: Any = UNDEFINED,
11231135
) -> RequestHandler:

0 commit comments

Comments
 (0)